protocol.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. package trojan
  2. import (
  3. "crypto/sha256"
  4. "encoding/binary"
  5. "encoding/hex"
  6. "net"
  7. "os"
  8. "sync"
  9. "github.com/sagernet/sing/common"
  10. "github.com/sagernet/sing/common/buf"
  11. "github.com/sagernet/sing/common/bufio"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. M "github.com/sagernet/sing/common/metadata"
  14. N "github.com/sagernet/sing/common/network"
  15. "github.com/sagernet/sing/common/rw"
  16. )
  // Protocol constants.
  const (
  	// KeyLength is the size in bytes of the authentication key returned by
  	// Key: a SHA-224 digest (28 bytes) encoded as 56 hexadecimal characters.
  	KeyLength = 56

  	// Command bytes placed after the key and first CRLF of a request header.
  	CommandTCP = 0x01
  	CommandUDP = 0x03
  	CommandMux = 0x7f
  )
  // CRLF is the two-byte "\r\n" delimiter written between the fields of
  // request and packet headers.
  var CRLF = []byte("\r\n")
  24. var _ N.EarlyConn = (*ClientConn)(nil)
  25. type ClientConn struct {
  26. N.ExtendedConn
  27. key [KeyLength]byte
  28. destination M.Socksaddr
  29. headerWritten bool
  30. }
  31. func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn {
  32. return &ClientConn{
  33. ExtendedConn: bufio.NewExtendedConn(conn),
  34. key: key,
  35. destination: destination,
  36. }
  37. }
  38. func (c *ClientConn) NeedHandshake() bool {
  39. return !c.headerWritten
  40. }
  41. func (c *ClientConn) Write(p []byte) (n int, err error) {
  42. if c.headerWritten {
  43. return c.ExtendedConn.Write(p)
  44. }
  45. err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p)
  46. if err != nil {
  47. return
  48. }
  49. n = len(p)
  50. c.headerWritten = true
  51. return
  52. }
  53. func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error {
  54. if c.headerWritten {
  55. return c.ExtendedConn.WriteBuffer(buffer)
  56. }
  57. err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer)
  58. if err != nil {
  59. return err
  60. }
  61. c.headerWritten = true
  62. return nil
  63. }
  64. func (c *ClientConn) FrontHeadroom() int {
  65. if !c.headerWritten {
  66. return KeyLength + 5 + M.MaxSocksaddrLength
  67. }
  68. return 0
  69. }
  70. func (c *ClientConn) Upstream() any {
  71. return c.ExtendedConn
  72. }
  73. type ClientPacketConn struct {
  74. net.Conn
  75. access sync.Mutex
  76. key [KeyLength]byte
  77. headerWritten bool
  78. }
  79. func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {
  80. return &ClientPacketConn{
  81. Conn: conn,
  82. key: key,
  83. }
  84. }
  85. func (c *ClientPacketConn) NeedHandshake() bool {
  86. return !c.headerWritten
  87. }
  88. func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
  89. return ReadPacket(c.Conn, buffer)
  90. }
  91. func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  92. if !c.headerWritten {
  93. c.access.Lock()
  94. if c.headerWritten {
  95. c.access.Unlock()
  96. } else {
  97. err := ClientHandshakePacket(c.Conn, c.key, destination, buffer)
  98. c.headerWritten = true
  99. c.access.Unlock()
  100. return err
  101. }
  102. }
  103. return WritePacket(c.Conn, buffer, destination)
  104. }
  105. func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
  106. buffer := buf.With(p)
  107. destination, err := c.ReadPacket(buffer)
  108. if err != nil {
  109. return
  110. }
  111. n = buffer.Len()
  112. if destination.IsFqdn() {
  113. addr = destination
  114. } else {
  115. addr = destination.UDPAddr()
  116. }
  117. return
  118. }
  119. func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
  120. return bufio.WritePacket(c, p, addr)
  121. }
  122. func (c *ClientPacketConn) Read(p []byte) (n int, err error) {
  123. n, _, err = c.ReadFrom(p)
  124. return
  125. }
  126. func (c *ClientPacketConn) Write(p []byte) (n int, err error) {
  127. return 0, os.ErrInvalid
  128. }
  129. func (c *ClientPacketConn) FrontHeadroom() int {
  130. if !c.headerWritten {
  131. return KeyLength + 2*M.MaxSocksaddrLength + 9
  132. }
  133. return M.MaxSocksaddrLength + 4
  134. }
  135. func (c *ClientPacketConn) Upstream() any {
  136. return c.Conn
  137. }
  138. func Key(password string) [KeyLength]byte {
  139. var key [KeyLength]byte
  140. hash := sha256.New224()
  141. common.Must1(hash.Write([]byte(password)))
  142. hex.Encode(key[:], hash.Sum(nil))
  143. return key
  144. }
  145. func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error {
  146. _, err := conn.Write(key[:])
  147. if err != nil {
  148. return err
  149. }
  150. _, err = conn.Write(CRLF)
  151. if err != nil {
  152. return err
  153. }
  154. _, err = conn.Write([]byte{command})
  155. if err != nil {
  156. return err
  157. }
  158. err = M.SocksaddrSerializer.WriteAddrPort(conn, destination)
  159. if err != nil {
  160. return err
  161. }
  162. _, err = conn.Write(CRLF)
  163. if err != nil {
  164. return err
  165. }
  166. if len(payload) > 0 {
  167. _, err = conn.Write(payload)
  168. if err != nil {
  169. return err
  170. }
  171. }
  172. return nil
  173. }
  174. func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error {
  175. headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5
  176. header := buf.NewSize(headerLen + len(payload))
  177. defer header.Release()
  178. common.Must1(header.Write(key[:]))
  179. common.Must1(header.Write(CRLF))
  180. common.Must(header.WriteByte(CommandTCP))
  181. err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
  182. if err != nil {
  183. return err
  184. }
  185. common.Must1(header.Write(CRLF))
  186. common.Must1(header.Write(payload))
  187. _, err = conn.Write(header.Bytes())
  188. if err != nil {
  189. return E.Cause(err, "write request")
  190. }
  191. return nil
  192. }
  193. func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
  194. header := buf.With(payload.ExtendHeader(KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5))
  195. common.Must1(header.Write(key[:]))
  196. common.Must1(header.Write(CRLF))
  197. common.Must(header.WriteByte(CommandTCP))
  198. err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
  199. if err != nil {
  200. return err
  201. }
  202. common.Must1(header.Write(CRLF))
  203. _, err = conn.Write(payload.Bytes())
  204. if err != nil {
  205. return E.Cause(err, "write request")
  206. }
  207. return nil
  208. }
  209. func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
  210. headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9
  211. payloadLen := payload.Len()
  212. var header *buf.Buffer
  213. var writeHeader bool
  214. if payload.Start() >= headerLen {
  215. header = buf.With(payload.ExtendHeader(headerLen))
  216. } else {
  217. header = buf.NewSize(headerLen)
  218. defer header.Release()
  219. writeHeader = true
  220. }
  221. common.Must1(header.Write(key[:]))
  222. common.Must1(header.Write(CRLF))
  223. common.Must(header.WriteByte(CommandUDP))
  224. err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
  225. if err != nil {
  226. return err
  227. }
  228. common.Must1(header.Write(CRLF))
  229. common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
  230. common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
  231. common.Must1(header.Write(CRLF))
  232. if writeHeader {
  233. _, err := conn.Write(header.Bytes())
  234. if err != nil {
  235. return E.Cause(err, "write request")
  236. }
  237. }
  238. _, err = conn.Write(payload.Bytes())
  239. if err != nil {
  240. return E.Cause(err, "write payload")
  241. }
  242. return nil
  243. }
  244. func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
  245. destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
  246. if err != nil {
  247. return M.Socksaddr{}, E.Cause(err, "read destination")
  248. }
  249. var length uint16
  250. err = binary.Read(conn, binary.BigEndian, &length)
  251. if err != nil {
  252. return M.Socksaddr{}, E.Cause(err, "read chunk length")
  253. }
  254. err = rw.SkipN(conn, 2)
  255. if err != nil {
  256. return M.Socksaddr{}, E.Cause(err, "skip crlf")
  257. }
  258. _, err = buffer.ReadFullFrom(conn, int(length))
  259. return destination.Unwrap(), err
  260. }
  261. func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error {
  262. defer buffer.Release()
  263. bufferLen := buffer.Len()
  264. header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4))
  265. err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
  266. if err != nil {
  267. return err
  268. }
  269. common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
  270. common.Must1(header.Write(CRLF))
  271. _, err = conn.Write(buffer.Bytes())
  272. if err != nil {
  273. return E.Cause(err, "write packet")
  274. }
  275. return nil
  276. }