protocol.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. package trojan
  2. import (
  3. "context"
  4. "encoding/binary"
  5. fmt "fmt"
  6. "io"
  7. "runtime"
  8. "syscall"
  9. "github.com/xtls/xray-core/common/buf"
  10. "github.com/xtls/xray-core/common/errors"
  11. "github.com/xtls/xray-core/common/net"
  12. "github.com/xtls/xray-core/common/protocol"
  13. "github.com/xtls/xray-core/common/session"
  14. "github.com/xtls/xray-core/common/signal"
  15. "github.com/xtls/xray-core/features/stats"
  16. "github.com/xtls/xray-core/transport/internet"
  17. "github.com/xtls/xray-core/transport/internet/xtls"
  18. )
  19. var (
  20. crlf = []byte{'\r', '\n'}
  21. addrParser = protocol.NewAddressParser(
  22. protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
  23. protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
  24. protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
  25. )
  26. xtls_show = false
  27. )
  28. const (
  29. maxLength = 8192
  30. // XRS is constant for XTLS splice mode
  31. XRS = "xtls-rprx-splice"
  32. // XRD is constant for XTLS direct mode
  33. XRD = "xtls-rprx-direct"
  34. // XRO is constant for XTLS origin mode
  35. XRO = "xtls-rprx-origin"
  36. commandTCP byte = 1
  37. commandUDP byte = 3
  38. // for XTLS
  39. commandXRD byte = 0xf0 // XTLS direct mode
  40. commandXRO byte = 0xf1 // XTLS origin mode
  41. )
  42. // ConnWriter is TCP Connection Writer Wrapper for trojan protocol
  43. type ConnWriter struct {
  44. io.Writer
  45. Target net.Destination
  46. Account *MemoryAccount
  47. Flow string
  48. headerSent bool
  49. }
  50. // Write implements io.Writer
  51. func (c *ConnWriter) Write(p []byte) (n int, err error) {
  52. if !c.headerSent {
  53. if err := c.writeHeader(); err != nil {
  54. return 0, newError("failed to write request header").Base(err)
  55. }
  56. }
  57. return c.Writer.Write(p)
  58. }
  59. // WriteMultiBuffer implements buf.Writer
  60. func (c *ConnWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  61. defer buf.ReleaseMulti(mb)
  62. for _, b := range mb {
  63. if !b.IsEmpty() {
  64. if _, err := c.Write(b.Bytes()); err != nil {
  65. return err
  66. }
  67. }
  68. }
  69. return nil
  70. }
  71. func (c *ConnWriter) writeHeader() error {
  72. buffer := buf.StackNew()
  73. defer buffer.Release()
  74. command := commandTCP
  75. if c.Target.Network == net.Network_UDP {
  76. command = commandUDP
  77. } else if c.Flow == XRD {
  78. command = commandXRD
  79. } else if c.Flow == XRO {
  80. command = commandXRO
  81. }
  82. if _, err := buffer.Write(c.Account.Key); err != nil {
  83. return err
  84. }
  85. if _, err := buffer.Write(crlf); err != nil {
  86. return err
  87. }
  88. if err := buffer.WriteByte(command); err != nil {
  89. return err
  90. }
  91. if err := addrParser.WriteAddressPort(&buffer, c.Target.Address, c.Target.Port); err != nil {
  92. return err
  93. }
  94. if _, err := buffer.Write(crlf); err != nil {
  95. return err
  96. }
  97. _, err := c.Writer.Write(buffer.Bytes())
  98. if err == nil {
  99. c.headerSent = true
  100. }
  101. return err
  102. }
  103. // PacketWriter UDP Connection Writer Wrapper for trojan protocol
  104. type PacketWriter struct {
  105. io.Writer
  106. Target net.Destination
  107. }
  108. // WriteMultiBuffer implements buf.Writer
  109. func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  110. for {
  111. mb2, b := buf.SplitFirst(mb)
  112. mb = mb2
  113. if b == nil {
  114. break
  115. }
  116. target := &w.Target
  117. if b.UDP != nil {
  118. target = b.UDP
  119. }
  120. if _, err := w.writePacket(b.Bytes(), *target); err != nil {
  121. buf.ReleaseMulti(mb)
  122. return err
  123. }
  124. }
  125. return nil
  126. }
  127. // WriteMultiBufferWithMetadata writes udp packet with destination specified
  128. func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error {
  129. for {
  130. mb2, b := buf.SplitFirst(mb)
  131. mb = mb2
  132. if b == nil {
  133. break
  134. }
  135. source := &dest
  136. if b.UDP != nil {
  137. source = b.UDP
  138. }
  139. if _, err := w.writePacket(b.Bytes(), *source); err != nil {
  140. buf.ReleaseMulti(mb)
  141. return err
  142. }
  143. }
  144. return nil
  145. }
  146. func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) {
  147. buffer := buf.StackNew()
  148. defer buffer.Release()
  149. length := len(payload)
  150. lengthBuf := [2]byte{}
  151. binary.BigEndian.PutUint16(lengthBuf[:], uint16(length))
  152. if err := addrParser.WriteAddressPort(&buffer, dest.Address, dest.Port); err != nil {
  153. return 0, err
  154. }
  155. if _, err := buffer.Write(lengthBuf[:]); err != nil {
  156. return 0, err
  157. }
  158. if _, err := buffer.Write(crlf); err != nil {
  159. return 0, err
  160. }
  161. if _, err := buffer.Write(payload); err != nil {
  162. return 0, err
  163. }
  164. _, err := w.Write(buffer.Bytes())
  165. if err != nil {
  166. return 0, err
  167. }
  168. return length, nil
  169. }
  170. // ConnReader is TCP Connection Reader Wrapper for trojan protocol
  171. type ConnReader struct {
  172. io.Reader
  173. Target net.Destination
  174. Flow string
  175. headerParsed bool
  176. }
  177. // ParseHeader parses the trojan protocol header
  178. func (c *ConnReader) ParseHeader() error {
  179. var crlf [2]byte
  180. var command [1]byte
  181. var hash [56]byte
  182. if _, err := io.ReadFull(c.Reader, hash[:]); err != nil {
  183. return newError("failed to read user hash").Base(err)
  184. }
  185. if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
  186. return newError("failed to read crlf").Base(err)
  187. }
  188. if _, err := io.ReadFull(c.Reader, command[:]); err != nil {
  189. return newError("failed to read command").Base(err)
  190. }
  191. network := net.Network_TCP
  192. if command[0] == commandUDP {
  193. network = net.Network_UDP
  194. } else if command[0] == commandXRD {
  195. c.Flow = XRD
  196. } else if command[0] == commandXRO {
  197. c.Flow = XRO
  198. }
  199. addr, port, err := addrParser.ReadAddressPort(nil, c.Reader)
  200. if err != nil {
  201. return newError("failed to read address and port").Base(err)
  202. }
  203. c.Target = net.Destination{Network: network, Address: addr, Port: port}
  204. if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
  205. return newError("failed to read crlf").Base(err)
  206. }
  207. c.headerParsed = true
  208. return nil
  209. }
  210. // Read implements io.Reader
  211. func (c *ConnReader) Read(p []byte) (int, error) {
  212. if !c.headerParsed {
  213. if err := c.ParseHeader(); err != nil {
  214. return 0, err
  215. }
  216. }
  217. return c.Reader.Read(p)
  218. }
  219. // ReadMultiBuffer implements buf.Reader
  220. func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
  221. b := buf.New()
  222. _, err := b.ReadFrom(c)
  223. return buf.MultiBuffer{b}, err
  224. }
  225. // PacketPayload combines udp payload and destination
  226. type PacketPayload struct {
  227. Target net.Destination
  228. Buffer buf.MultiBuffer
  229. }
  230. // PacketReader is UDP Connection Reader Wrapper for trojan protocol
  231. type PacketReader struct {
  232. io.Reader
  233. }
  234. // ReadMultiBuffer implements buf.Reader
  235. func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
  236. p, err := r.ReadMultiBufferWithMetadata()
  237. if p != nil {
  238. return p.Buffer, err
  239. }
  240. return nil, err
  241. }
  242. // ReadMultiBufferWithMetadata reads udp packet with destination
  243. func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
  244. addr, port, err := addrParser.ReadAddressPort(nil, r)
  245. if err != nil {
  246. return nil, newError("failed to read address and port").Base(err)
  247. }
  248. var lengthBuf [2]byte
  249. if _, err := io.ReadFull(r, lengthBuf[:]); err != nil {
  250. return nil, newError("failed to read payload length").Base(err)
  251. }
  252. remain := int(binary.BigEndian.Uint16(lengthBuf[:]))
  253. if remain > maxLength {
  254. return nil, newError("oversize payload")
  255. }
  256. var crlf [2]byte
  257. if _, err := io.ReadFull(r, crlf[:]); err != nil {
  258. return nil, newError("failed to read crlf").Base(err)
  259. }
  260. dest := net.UDPDestination(addr, port)
  261. var mb buf.MultiBuffer
  262. for remain > 0 {
  263. length := buf.Size
  264. if remain < length {
  265. length = remain
  266. }
  267. b := buf.New()
  268. b.UDP = &dest
  269. mb = append(mb, b)
  270. n, err := b.ReadFullFrom(r, int32(length))
  271. if err != nil {
  272. buf.ReleaseMulti(mb)
  273. return nil, newError("failed to read payload").Base(err)
  274. }
  275. remain -= int(n)
  276. }
  277. return &PacketPayload{Target: dest, Buffer: mb}, nil
  278. }
  279. func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn *xtls.Conn, rawConn syscall.RawConn, counter stats.Counter, sctx context.Context) error {
  280. err := func() error {
  281. var ct stats.Counter
  282. for {
  283. if conn.DirectIn {
  284. conn.DirectIn = false
  285. if sctx != nil {
  286. if inbound := session.InboundFromContext(sctx); inbound != nil && inbound.Conn != nil {
  287. iConn := inbound.Conn
  288. statConn, ok := iConn.(*internet.StatCouterConnection)
  289. if ok {
  290. iConn = statConn.Connection
  291. }
  292. if xc, ok := iConn.(*xtls.Conn); ok {
  293. iConn = xc.Connection
  294. }
  295. if tc, ok := iConn.(*net.TCPConn); ok {
  296. if conn.SHOW {
  297. fmt.Println(conn.MARK, "Splice")
  298. }
  299. runtime.Gosched() // necessary
  300. w, err := tc.ReadFrom(conn.Connection)
  301. if counter != nil {
  302. counter.Add(w)
  303. }
  304. if statConn != nil && statConn.WriteCounter != nil {
  305. statConn.WriteCounter.Add(w)
  306. }
  307. return err
  308. } else {
  309. panic("XTLS Splice: not TCP inbound")
  310. }
  311. } else {
  312. //panic("XTLS Splice: nil inbound or nil inbound.Conn")
  313. }
  314. }
  315. reader = buf.NewReadVReader(conn.Connection, rawConn)
  316. ct = counter
  317. if conn.SHOW {
  318. fmt.Println(conn.MARK, "ReadV")
  319. }
  320. }
  321. buffer, err := reader.ReadMultiBuffer()
  322. if !buffer.IsEmpty() {
  323. if ct != nil {
  324. ct.Add(int64(buffer.Len()))
  325. }
  326. timer.Update()
  327. if werr := writer.WriteMultiBuffer(buffer); werr != nil {
  328. return werr
  329. }
  330. }
  331. if err != nil {
  332. return err
  333. }
  334. }
  335. }()
  336. if err != nil && errors.Cause(err) != io.EOF {
  337. return err
  338. }
  339. return nil
  340. }