protocol.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. package trojan
  2. import (
  3. "crypto/sha256"
  4. "encoding/binary"
  5. "encoding/hex"
  6. "net"
  7. "os"
  8. "sync"
  9. "github.com/sagernet/sing/common"
  10. "github.com/sagernet/sing/common/buf"
  11. "github.com/sagernet/sing/common/bufio"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. M "github.com/sagernet/sing/common/metadata"
  14. N "github.com/sagernet/sing/common/network"
  15. "github.com/sagernet/sing/common/rw"
  16. )
  17. const (
  18. KeyLength = 56
  19. CommandTCP = 1
  20. CommandUDP = 3
  21. CommandMux = 0x7f
  22. )
  23. var CRLF = []byte{'\r', '\n'}
  24. var _ N.EarlyWriter = (*ClientConn)(nil)
// ClientConn is a Trojan client stream connection. The request header (key,
// CommandTCP and destination) is not sent eagerly; it is written together
// with the first payload passed to Write or WriteBuffer.
type ClientConn struct {
	N.ExtendedConn
	key           [KeyLength]byte // authentication key, as produced by Key
	destination   M.Socksaddr     // target address placed in the request header
	headerWritten bool            // set once the request header has been sent successfully
}
  31. func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn {
  32. return &ClientConn{
  33. ExtendedConn: bufio.NewExtendedConn(conn),
  34. key: key,
  35. destination: destination,
  36. }
  37. }
  38. func (c *ClientConn) NeedHandshakeForWrite() bool {
  39. return !c.headerWritten
  40. }
  41. func (c *ClientConn) Write(p []byte) (n int, err error) {
  42. if c.headerWritten {
  43. return c.ExtendedConn.Write(p)
  44. }
  45. err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p)
  46. if err != nil {
  47. return
  48. }
  49. n = len(p)
  50. c.headerWritten = true
  51. return
  52. }
  53. func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error {
  54. if c.headerWritten {
  55. return c.ExtendedConn.WriteBuffer(buffer)
  56. }
  57. err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer)
  58. if err != nil {
  59. return err
  60. }
  61. c.headerWritten = true
  62. return nil
  63. }
  64. func (c *ClientConn) FrontHeadroom() int {
  65. if !c.headerWritten {
  66. return KeyLength + 5 + M.MaxSocksaddrLength
  67. }
  68. return 0
  69. }
  70. func (c *ClientConn) Upstream() any {
  71. return c.ExtendedConn
  72. }
  73. func (c *ClientConn) ReaderReplaceable() bool {
  74. return c.headerWritten
  75. }
  76. func (c *ClientConn) WriterReplaceable() bool {
  77. return c.headerWritten
  78. }
// ClientPacketConn carries UDP packets over a Trojan stream connection. The
// first packet written is prefixed with the Trojan request header
// (CommandUDP); later packets carry only the per-packet framing.
type ClientPacketConn struct {
	net.Conn
	access          sync.Mutex        // serializes sending of the request header in WritePacket
	key             [KeyLength]byte   // authentication key, as produced by Key
	headerWritten   bool              // tells WritePacket whether the request header still has to be sent
	readWaitOptions N.ReadWaitOptions // NOTE(review): not used in this file; presumably used by a read-waiter elsewhere — confirm
}
  86. func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {
  87. return &ClientPacketConn{
  88. Conn: conn,
  89. key: key,
  90. }
  91. }
  92. func (c *ClientPacketConn) NeedHandshake() bool {
  93. return !c.headerWritten
  94. }
  95. func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
  96. return ReadPacket(c.Conn, buffer)
  97. }
  98. func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  99. if !c.headerWritten {
  100. c.access.Lock()
  101. if c.headerWritten {
  102. c.access.Unlock()
  103. } else {
  104. err := ClientHandshakePacket(c.Conn, c.key, destination, buffer)
  105. c.headerWritten = true
  106. c.access.Unlock()
  107. return err
  108. }
  109. }
  110. return WritePacket(c.Conn, buffer, destination)
  111. }
  112. func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
  113. buffer := buf.With(p)
  114. destination, err := c.ReadPacket(buffer)
  115. if err != nil {
  116. return
  117. }
  118. n = buffer.Len()
  119. if destination.IsFqdn() {
  120. addr = destination
  121. } else {
  122. addr = destination.UDPAddr()
  123. }
  124. return
  125. }
  126. func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
  127. return bufio.WritePacket(c, p, addr)
  128. }
  129. func (c *ClientPacketConn) Read(p []byte) (n int, err error) {
  130. n, _, err = c.ReadFrom(p)
  131. return
  132. }
  133. func (c *ClientPacketConn) Write(p []byte) (n int, err error) {
  134. return 0, os.ErrInvalid
  135. }
  136. func (c *ClientPacketConn) FrontHeadroom() int {
  137. if !c.headerWritten {
  138. return KeyLength + 2*M.MaxSocksaddrLength + 9
  139. }
  140. return M.MaxSocksaddrLength + 4
  141. }
  142. func (c *ClientPacketConn) Upstream() any {
  143. return c.Conn
  144. }
  145. func Key(password string) [KeyLength]byte {
  146. var key [KeyLength]byte
  147. hash := sha256.New224()
  148. common.Must1(hash.Write([]byte(password)))
  149. hex.Encode(key[:], hash.Sum(nil))
  150. return key
  151. }
  152. func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error {
  153. _, err := conn.Write(key[:])
  154. if err != nil {
  155. return err
  156. }
  157. _, err = conn.Write(CRLF)
  158. if err != nil {
  159. return err
  160. }
  161. _, err = conn.Write([]byte{command})
  162. if err != nil {
  163. return err
  164. }
  165. err = M.SocksaddrSerializer.WriteAddrPort(conn, destination)
  166. if err != nil {
  167. return err
  168. }
  169. _, err = conn.Write(CRLF)
  170. if err != nil {
  171. return err
  172. }
  173. if len(payload) > 0 {
  174. _, err = conn.Write(payload)
  175. if err != nil {
  176. return err
  177. }
  178. }
  179. return nil
  180. }
  181. func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error {
  182. headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5
  183. header := buf.NewSize(headerLen + len(payload))
  184. defer header.Release()
  185. common.Must1(header.Write(key[:]))
  186. common.Must1(header.Write(CRLF))
  187. common.Must(header.WriteByte(CommandTCP))
  188. err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
  189. if err != nil {
  190. return err
  191. }
  192. common.Must1(header.Write(CRLF))
  193. common.Must1(header.Write(payload))
  194. _, err = conn.Write(header.Bytes())
  195. if err != nil {
  196. return E.Cause(err, "write request")
  197. }
  198. return nil
  199. }
  200. func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
  201. header := buf.With(payload.ExtendHeader(KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5))
  202. common.Must1(header.Write(key[:]))
  203. common.Must1(header.Write(CRLF))
  204. common.Must(header.WriteByte(CommandTCP))
  205. err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
  206. if err != nil {
  207. return err
  208. }
  209. common.Must1(header.Write(CRLF))
  210. _, err = conn.Write(payload.Bytes())
  211. if err != nil {
  212. return E.Cause(err, "write request")
  213. }
  214. return nil
  215. }
  216. func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
  217. headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9
  218. payloadLen := payload.Len()
  219. var header *buf.Buffer
  220. var writeHeader bool
  221. if payload.Start() >= headerLen {
  222. header = buf.With(payload.ExtendHeader(headerLen))
  223. } else {
  224. header = buf.NewSize(headerLen)
  225. defer header.Release()
  226. writeHeader = true
  227. }
  228. common.Must1(header.Write(key[:]))
  229. common.Must1(header.Write(CRLF))
  230. common.Must(header.WriteByte(CommandUDP))
  231. err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
  232. if err != nil {
  233. return err
  234. }
  235. common.Must1(header.Write(CRLF))
  236. common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
  237. common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
  238. common.Must1(header.Write(CRLF))
  239. if writeHeader {
  240. _, err := conn.Write(header.Bytes())
  241. if err != nil {
  242. return E.Cause(err, "write request")
  243. }
  244. }
  245. _, err = conn.Write(payload.Bytes())
  246. if err != nil {
  247. return E.Cause(err, "write payload")
  248. }
  249. return nil
  250. }
  251. func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
  252. destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
  253. if err != nil {
  254. return M.Socksaddr{}, E.Cause(err, "read destination")
  255. }
  256. var length uint16
  257. err = binary.Read(conn, binary.BigEndian, &length)
  258. if err != nil {
  259. return M.Socksaddr{}, E.Cause(err, "read chunk length")
  260. }
  261. err = rw.SkipN(conn, 2)
  262. if err != nil {
  263. return M.Socksaddr{}, E.Cause(err, "skip crlf")
  264. }
  265. _, err = buffer.ReadFullFrom(conn, int(length))
  266. return destination, err
  267. }
  268. func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error {
  269. defer buffer.Release()
  270. bufferLen := buffer.Len()
  271. header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4))
  272. err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
  273. if err != nil {
  274. return err
  275. }
  276. common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
  277. common.Must1(header.Write(CRLF))
  278. _, err = conn.Write(buffer.Bytes())
  279. if err != nil {
  280. return E.Cause(err, "write packet")
  281. }
  282. return nil
  283. }