protocol.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. package trojan
  2. import (
  3. "crypto/sha256"
  4. "encoding/binary"
  5. "encoding/hex"
  6. "io"
  7. "net"
  8. "os"
  9. "github.com/sagernet/sing/common"
  10. "github.com/sagernet/sing/common/buf"
  11. "github.com/sagernet/sing/common/bufio"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. M "github.com/sagernet/sing/common/metadata"
  14. N "github.com/sagernet/sing/common/network"
  15. "github.com/sagernet/sing/common/rw"
  16. )
  // Wire-level constants of the trojan protocol.
  const (
  	// KeyLength is the size of a client key: the hex encoding of a SHA-224
  	// digest (28 bytes -> 56 hex characters), as produced by Key.
  	KeyLength = 56

  	// Command bytes written into the request header right after the key and
  	// its trailing CRLF.
  	CommandTCP = 1
  	CommandUDP = 3
  	CommandMux = 0x7F
  )

  // CRLF is the two-byte separator written between request header fields.
  var CRLF = []byte{0x0D, 0x0A}
  24. type ClientConn struct {
  25. N.ExtendedConn
  26. key [KeyLength]byte
  27. destination M.Socksaddr
  28. headerWritten bool
  29. }
  30. func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn {
  31. return &ClientConn{
  32. ExtendedConn: bufio.NewExtendedConn(conn),
  33. key: key,
  34. destination: destination,
  35. }
  36. }
  37. func (c *ClientConn) Write(p []byte) (n int, err error) {
  38. if c.headerWritten {
  39. return c.ExtendedConn.Write(p)
  40. }
  41. err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p)
  42. if err != nil {
  43. return
  44. }
  45. n = len(p)
  46. c.headerWritten = true
  47. return
  48. }
  49. func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error {
  50. if c.headerWritten {
  51. return c.ExtendedConn.WriteBuffer(buffer)
  52. }
  53. err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer)
  54. if err != nil {
  55. return err
  56. }
  57. c.headerWritten = true
  58. return nil
  59. }
  60. func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) {
  61. if !c.headerWritten {
  62. return bufio.ReadFrom0(c, r)
  63. }
  64. return bufio.Copy(c.ExtendedConn, r)
  65. }
  66. func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) {
  67. return bufio.Copy(w, c.ExtendedConn)
  68. }
  69. func (c *ClientConn) FrontHeadroom() int {
  70. if !c.headerWritten {
  71. return KeyLength + 5 + M.MaxSocksaddrLength
  72. }
  73. return 0
  74. }
  75. func (c *ClientConn) Upstream() any {
  76. return c.ExtendedConn
  77. }
  78. type ClientPacketConn struct {
  79. net.Conn
  80. key [KeyLength]byte
  81. headerWritten bool
  82. }
  83. func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {
  84. return &ClientPacketConn{
  85. Conn: conn,
  86. key: key,
  87. }
  88. }
  89. func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
  90. return ReadPacket(c.Conn, buffer)
  91. }
  92. func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  93. if !c.headerWritten {
  94. err := ClientHandshakePacket(c.Conn, c.key, destination, buffer)
  95. c.headerWritten = true
  96. return err
  97. }
  98. return WritePacket(c.Conn, buffer, destination)
  99. }
  100. func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
  101. buffer := buf.With(p)
  102. destination, err := c.ReadPacket(buffer)
  103. if err != nil {
  104. return
  105. }
  106. n = buffer.Len()
  107. addr = destination.UDPAddr()
  108. return
  109. }
  110. func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
  111. return bufio.WritePacket(c, p, addr)
  112. }
  113. func (c *ClientPacketConn) Read(p []byte) (n int, err error) {
  114. n, _, err = c.ReadFrom(p)
  115. return
  116. }
  117. func (c *ClientPacketConn) Write(p []byte) (n int, err error) {
  118. return 0, os.ErrInvalid
  119. }
  120. func (c *ClientPacketConn) FrontHeadroom() int {
  121. if !c.headerWritten {
  122. return KeyLength + 2*M.MaxSocksaddrLength + 9
  123. }
  124. return M.MaxSocksaddrLength + 4
  125. }
  126. func (c *ClientPacketConn) Upstream() any {
  127. return c.Conn
  128. }
  129. func Key(password string) [KeyLength]byte {
  130. var key [KeyLength]byte
  131. hash := sha256.New224()
  132. common.Must1(hash.Write([]byte(password)))
  133. hex.Encode(key[:], hash.Sum(nil))
  134. return key
  135. }
  136. func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error {
  137. _, err := conn.Write(key[:])
  138. if err != nil {
  139. return err
  140. }
  141. _, err = conn.Write(CRLF)
  142. if err != nil {
  143. return err
  144. }
  145. _, err = conn.Write([]byte{command})
  146. if err != nil {
  147. return err
  148. }
  149. err = M.SocksaddrSerializer.WriteAddrPort(conn, destination)
  150. if err != nil {
  151. return err
  152. }
  153. _, err = conn.Write(CRLF)
  154. if err != nil {
  155. return err
  156. }
  157. if len(payload) > 0 {
  158. _, err = conn.Write(payload)
  159. if err != nil {
  160. return err
  161. }
  162. }
  163. return nil
  164. }
  165. func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error {
  166. headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5
  167. var header *buf.Buffer
  168. defer header.Release()
  169. var writeHeader bool
  170. if len(payload) > 0 && headerLen+len(payload) < 65535 {
  171. buffer := buf.StackNewSize(headerLen + len(payload))
  172. defer common.KeepAlive(buffer)
  173. header = common.Dup(buffer)
  174. } else {
  175. buffer := buf.StackNewSize(headerLen)
  176. defer common.KeepAlive(buffer)
  177. header = common.Dup(buffer)
  178. writeHeader = true
  179. }
  180. common.Must1(header.Write(key[:]))
  181. common.Must1(header.Write(CRLF))
  182. common.Must(header.WriteByte(CommandTCP))
  183. common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
  184. common.Must1(header.Write(CRLF))
  185. if !writeHeader {
  186. common.Must1(header.Write(payload))
  187. }
  188. _, err := conn.Write(header.Bytes())
  189. if err != nil {
  190. return E.Cause(err, "write request")
  191. }
  192. if writeHeader {
  193. _, err = conn.Write(payload)
  194. if err != nil {
  195. return E.Cause(err, "write payload")
  196. }
  197. }
  198. return nil
  199. }
  200. func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
  201. header := buf.With(payload.ExtendHeader(KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5))
  202. common.Must1(header.Write(key[:]))
  203. common.Must1(header.Write(CRLF))
  204. common.Must(header.WriteByte(CommandTCP))
  205. common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
  206. common.Must1(header.Write(CRLF))
  207. _, err := conn.Write(payload.Bytes())
  208. if err != nil {
  209. return E.Cause(err, "write request")
  210. }
  211. return nil
  212. }
  213. func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
  214. headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9
  215. payloadLen := payload.Len()
  216. var header *buf.Buffer
  217. defer header.Release()
  218. var writeHeader bool
  219. if payload.Start() >= headerLen {
  220. header = buf.With(payload.ExtendHeader(headerLen))
  221. } else {
  222. buffer := buf.StackNewSize(headerLen)
  223. defer common.KeepAlive(buffer)
  224. header = common.Dup(buffer)
  225. writeHeader = true
  226. }
  227. common.Must1(header.Write(key[:]))
  228. common.Must1(header.Write(CRLF))
  229. common.Must(header.WriteByte(CommandUDP))
  230. common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
  231. common.Must1(header.Write(CRLF))
  232. common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
  233. common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
  234. common.Must1(header.Write(CRLF))
  235. if writeHeader {
  236. _, err := conn.Write(header.Bytes())
  237. if err != nil {
  238. return E.Cause(err, "write request")
  239. }
  240. }
  241. _, err := conn.Write(payload.Bytes())
  242. if err != nil {
  243. return E.Cause(err, "write payload")
  244. }
  245. return nil
  246. }
  247. func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
  248. destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
  249. if err != nil {
  250. return M.Socksaddr{}, E.Cause(err, "read destination")
  251. }
  252. var length uint16
  253. err = binary.Read(conn, binary.BigEndian, &length)
  254. if err != nil {
  255. return M.Socksaddr{}, E.Cause(err, "read chunk length")
  256. }
  257. err = rw.SkipN(conn, 2)
  258. if err != nil {
  259. return M.Socksaddr{}, E.Cause(err, "skip crlf")
  260. }
  261. _, err = buffer.ReadFullFrom(conn, int(length))
  262. return destination, err
  263. }
  264. func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error {
  265. defer buffer.Release()
  266. bufferLen := buffer.Len()
  267. header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4))
  268. common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
  269. common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
  270. common.Must1(header.Write(CRLF))
  271. _, err := conn.Write(buffer.Bytes())
  272. if err != nil {
  273. return E.Cause(err, "write packet")
  274. }
  275. return nil
  276. }