protocol.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. package trojan
  2. import (
  3. "encoding/binary"
  4. "io"
  5. "github.com/xtls/xray-core/common/buf"
  6. "github.com/xtls/xray-core/common/errors"
  7. "github.com/xtls/xray-core/common/net"
  8. "github.com/xtls/xray-core/common/protocol"
  9. )
  10. var (
  11. crlf = []byte{'\r', '\n'}
  12. addrParser = protocol.NewAddressParser(
  13. protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
  14. protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
  15. protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
  16. )
  17. )
  // maxLength caps the declared payload length of a UDP frame; PacketReader
// rejects frames whose length field exceeds it.
const maxLength = 8192

// Command bytes of the trojan request header, written right after the first
// CRLF: commandTCP requests a stream connection, commandUDP a UDP association.
const (
	commandTCP byte = 1
	commandUDP byte = 3
)
  23. // ConnWriter is TCP Connection Writer Wrapper for trojan protocol
  24. type ConnWriter struct {
  25. io.Writer
  26. Target net.Destination
  27. Account *MemoryAccount
  28. headerSent bool
  29. }
  30. // Write implements io.Writer
  31. func (c *ConnWriter) Write(p []byte) (n int, err error) {
  32. if !c.headerSent {
  33. if err := c.writeHeader(); err != nil {
  34. return 0, errors.New("failed to write request header").Base(err)
  35. }
  36. }
  37. return c.Writer.Write(p)
  38. }
  39. // WriteMultiBuffer implements buf.Writer
  40. func (c *ConnWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  41. defer buf.ReleaseMulti(mb)
  42. for _, b := range mb {
  43. if !b.IsEmpty() {
  44. if _, err := c.Write(b.Bytes()); err != nil {
  45. return err
  46. }
  47. }
  48. }
  49. return nil
  50. }
  51. func (c *ConnWriter) writeHeader() error {
  52. buffer := buf.StackNew()
  53. defer buffer.Release()
  54. command := commandTCP
  55. if c.Target.Network == net.Network_UDP {
  56. command = commandUDP
  57. }
  58. if _, err := buffer.Write(c.Account.Key); err != nil {
  59. return err
  60. }
  61. if _, err := buffer.Write(crlf); err != nil {
  62. return err
  63. }
  64. if err := buffer.WriteByte(command); err != nil {
  65. return err
  66. }
  67. if err := addrParser.WriteAddressPort(&buffer, c.Target.Address, c.Target.Port); err != nil {
  68. return err
  69. }
  70. if _, err := buffer.Write(crlf); err != nil {
  71. return err
  72. }
  73. _, err := c.Writer.Write(buffer.Bytes())
  74. if err == nil {
  75. c.headerSent = true
  76. }
  77. return err
  78. }
  79. // PacketWriter UDP Connection Writer Wrapper for trojan protocol
  80. type PacketWriter struct {
  81. io.Writer
  82. Target net.Destination
  83. }
  84. // WriteMultiBuffer implements buf.Writer
  85. func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  86. for {
  87. mb2, b := buf.SplitFirst(mb)
  88. mb = mb2
  89. if b == nil {
  90. break
  91. }
  92. target := &w.Target
  93. if b.UDP != nil {
  94. target = b.UDP
  95. }
  96. if _, err := w.writePacket(b.Bytes(), *target); err != nil {
  97. b.Release()
  98. buf.ReleaseMulti(mb)
  99. return err
  100. }
  101. b.Release()
  102. }
  103. return nil
  104. }
  105. func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) {
  106. buffer := buf.StackNew()
  107. defer buffer.Release()
  108. length := len(payload)
  109. lengthBuf := [2]byte{}
  110. binary.BigEndian.PutUint16(lengthBuf[:], uint16(length))
  111. if err := addrParser.WriteAddressPort(&buffer, dest.Address, dest.Port); err != nil {
  112. return 0, err
  113. }
  114. if _, err := buffer.Write(lengthBuf[:]); err != nil {
  115. return 0, err
  116. }
  117. if _, err := buffer.Write(crlf); err != nil {
  118. return 0, err
  119. }
  120. if _, err := buffer.Write(payload); err != nil {
  121. return 0, err
  122. }
  123. _, err := w.Write(buffer.Bytes())
  124. if err != nil {
  125. return 0, err
  126. }
  127. return length, nil
  128. }
  129. // ConnReader is TCP Connection Reader Wrapper for trojan protocol
  130. type ConnReader struct {
  131. io.Reader
  132. Target net.Destination
  133. Flow string
  134. headerParsed bool
  135. }
  136. // ParseHeader parses the trojan protocol header
  137. func (c *ConnReader) ParseHeader() error {
  138. var crlf [2]byte
  139. var command [1]byte
  140. var hash [56]byte
  141. if _, err := io.ReadFull(c.Reader, hash[:]); err != nil {
  142. return errors.New("failed to read user hash").Base(err)
  143. }
  144. if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
  145. return errors.New("failed to read crlf").Base(err)
  146. }
  147. if _, err := io.ReadFull(c.Reader, command[:]); err != nil {
  148. return errors.New("failed to read command").Base(err)
  149. }
  150. network := net.Network_TCP
  151. if command[0] == commandUDP {
  152. network = net.Network_UDP
  153. }
  154. addr, port, err := addrParser.ReadAddressPort(nil, c.Reader)
  155. if err != nil {
  156. return errors.New("failed to read address and port").Base(err)
  157. }
  158. c.Target = net.Destination{Network: network, Address: addr, Port: port}
  159. if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
  160. return errors.New("failed to read crlf").Base(err)
  161. }
  162. c.headerParsed = true
  163. return nil
  164. }
  165. // Read implements io.Reader
  166. func (c *ConnReader) Read(p []byte) (int, error) {
  167. if !c.headerParsed {
  168. if err := c.ParseHeader(); err != nil {
  169. return 0, err
  170. }
  171. }
  172. return c.Reader.Read(p)
  173. }
  174. // ReadMultiBuffer implements buf.Reader
  175. func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
  176. b := buf.New()
  177. _, err := b.ReadFrom(c)
  178. return buf.MultiBuffer{b}, err
  179. }
  // PacketReader is UDP Connection Reader Wrapper for trojan protocol.
//
// Each ReadMultiBuffer call decodes exactly one length-prefixed UDP frame
// from the wrapped io.Reader.
type PacketReader struct {
	io.Reader
}
  184. // ReadMultiBuffer implements buf.Reader
  185. func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
  186. addr, port, err := addrParser.ReadAddressPort(nil, r)
  187. if err != nil {
  188. return nil, errors.New("failed to read address and port").Base(err)
  189. }
  190. var lengthBuf [2]byte
  191. if _, err := io.ReadFull(r, lengthBuf[:]); err != nil {
  192. return nil, errors.New("failed to read payload length").Base(err)
  193. }
  194. remain := int(binary.BigEndian.Uint16(lengthBuf[:]))
  195. if remain > maxLength {
  196. return nil, errors.New("oversize payload")
  197. }
  198. var crlf [2]byte
  199. if _, err := io.ReadFull(r, crlf[:]); err != nil {
  200. return nil, errors.New("failed to read crlf").Base(err)
  201. }
  202. dest := net.UDPDestination(addr, port)
  203. var mb buf.MultiBuffer
  204. for remain > 0 {
  205. length := buf.Size
  206. if remain < length {
  207. length = remain
  208. }
  209. b := buf.New()
  210. b.UDP = &dest
  211. mb = append(mb, b)
  212. n, err := b.ReadFullFrom(r, int32(length))
  213. if err != nil {
  214. buf.ReleaseMulti(mb)
  215. return nil, errors.New("failed to read payload").Base(err)
  216. }
  217. remain -= int(n)
  218. }
  219. return mb, nil
  220. }