protocol.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. package trojan
  2. import (
  3. "crypto/sha256"
  4. "encoding/binary"
  5. "encoding/hex"
  6. "io"
  7. "net"
  8. "os"
  9. "github.com/sagernet/sing/common"
  10. "github.com/sagernet/sing/common/buf"
  11. "github.com/sagernet/sing/common/bufio"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. M "github.com/sagernet/sing/common/metadata"
  14. N "github.com/sagernet/sing/common/network"
  15. "github.com/sagernet/sing/common/rw"
  16. )
  17. const (
  18. KeyLength = 56
  19. CommandTCP = 1
  20. CommandUDP = 3
  21. CommandMux = 0x7f
  22. )
  23. var CRLF = []byte{'\r', '\n'}
  24. var _ N.EarlyConn = (*ClientConn)(nil)
  25. type ClientConn struct {
  26. N.ExtendedConn
  27. key [KeyLength]byte
  28. destination M.Socksaddr
  29. headerWritten bool
  30. }
  31. func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn {
  32. return &ClientConn{
  33. ExtendedConn: bufio.NewExtendedConn(conn),
  34. key: key,
  35. destination: destination,
  36. }
  37. }
  38. func (c *ClientConn) NeedHandshake() bool {
  39. return !c.headerWritten
  40. }
  41. func (c *ClientConn) Write(p []byte) (n int, err error) {
  42. if c.headerWritten {
  43. return c.ExtendedConn.Write(p)
  44. }
  45. err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p)
  46. if err != nil {
  47. return
  48. }
  49. n = len(p)
  50. c.headerWritten = true
  51. return
  52. }
  53. func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error {
  54. if c.headerWritten {
  55. return c.ExtendedConn.WriteBuffer(buffer)
  56. }
  57. err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer)
  58. if err != nil {
  59. return err
  60. }
  61. c.headerWritten = true
  62. return nil
  63. }
  64. func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) {
  65. if !c.headerWritten {
  66. return bufio.ReadFrom0(c, r)
  67. }
  68. return bufio.Copy(c.ExtendedConn, r)
  69. }
  70. func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) {
  71. return bufio.Copy(w, c.ExtendedConn)
  72. }
  73. func (c *ClientConn) FrontHeadroom() int {
  74. if !c.headerWritten {
  75. return KeyLength + 5 + M.MaxSocksaddrLength
  76. }
  77. return 0
  78. }
  79. func (c *ClientConn) Upstream() any {
  80. return c.ExtendedConn
  81. }
  82. type ClientPacketConn struct {
  83. net.Conn
  84. key [KeyLength]byte
  85. headerWritten bool
  86. }
  87. func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {
  88. return &ClientPacketConn{
  89. Conn: conn,
  90. key: key,
  91. }
  92. }
  93. func (c *ClientPacketConn) NeedHandshake() bool {
  94. return !c.headerWritten
  95. }
  96. func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
  97. return ReadPacket(c.Conn, buffer)
  98. }
  99. func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  100. if !c.headerWritten {
  101. err := ClientHandshakePacket(c.Conn, c.key, destination, buffer)
  102. c.headerWritten = true
  103. return err
  104. }
  105. return WritePacket(c.Conn, buffer, destination)
  106. }
  107. func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
  108. buffer := buf.With(p)
  109. destination, err := c.ReadPacket(buffer)
  110. if err != nil {
  111. return
  112. }
  113. n = buffer.Len()
  114. if destination.IsFqdn() {
  115. addr = destination
  116. } else {
  117. addr = destination.UDPAddr()
  118. }
  119. return
  120. }
  121. func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
  122. return bufio.WritePacket(c, p, addr)
  123. }
  124. func (c *ClientPacketConn) Read(p []byte) (n int, err error) {
  125. n, _, err = c.ReadFrom(p)
  126. return
  127. }
  128. func (c *ClientPacketConn) Write(p []byte) (n int, err error) {
  129. return 0, os.ErrInvalid
  130. }
  131. func (c *ClientPacketConn) FrontHeadroom() int {
  132. if !c.headerWritten {
  133. return KeyLength + 2*M.MaxSocksaddrLength + 9
  134. }
  135. return M.MaxSocksaddrLength + 4
  136. }
  137. func (c *ClientPacketConn) Upstream() any {
  138. return c.Conn
  139. }
  140. func Key(password string) [KeyLength]byte {
  141. var key [KeyLength]byte
  142. hash := sha256.New224()
  143. common.Must1(hash.Write([]byte(password)))
  144. hex.Encode(key[:], hash.Sum(nil))
  145. return key
  146. }
  147. func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error {
  148. _, err := conn.Write(key[:])
  149. if err != nil {
  150. return err
  151. }
  152. _, err = conn.Write(CRLF)
  153. if err != nil {
  154. return err
  155. }
  156. _, err = conn.Write([]byte{command})
  157. if err != nil {
  158. return err
  159. }
  160. err = M.SocksaddrSerializer.WriteAddrPort(conn, destination)
  161. if err != nil {
  162. return err
  163. }
  164. _, err = conn.Write(CRLF)
  165. if err != nil {
  166. return err
  167. }
  168. if len(payload) > 0 {
  169. _, err = conn.Write(payload)
  170. if err != nil {
  171. return err
  172. }
  173. }
  174. return nil
  175. }
  176. func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error {
  177. headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5
  178. var header *buf.Buffer
  179. defer header.Release()
  180. var writeHeader bool
  181. if len(payload) > 0 && headerLen+len(payload) < 65535 {
  182. buffer := buf.StackNewSize(headerLen + len(payload))
  183. defer common.KeepAlive(buffer)
  184. header = common.Dup(buffer)
  185. } else {
  186. buffer := buf.StackNewSize(headerLen)
  187. defer common.KeepAlive(buffer)
  188. header = common.Dup(buffer)
  189. writeHeader = true
  190. }
  191. common.Must1(header.Write(key[:]))
  192. common.Must1(header.Write(CRLF))
  193. common.Must(header.WriteByte(CommandTCP))
  194. common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
  195. common.Must1(header.Write(CRLF))
  196. if !writeHeader {
  197. common.Must1(header.Write(payload))
  198. }
  199. _, err := conn.Write(header.Bytes())
  200. if err != nil {
  201. return E.Cause(err, "write request")
  202. }
  203. if writeHeader {
  204. _, err = conn.Write(payload)
  205. if err != nil {
  206. return E.Cause(err, "write payload")
  207. }
  208. }
  209. return nil
  210. }
  211. func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
  212. header := buf.With(payload.ExtendHeader(KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5))
  213. common.Must1(header.Write(key[:]))
  214. common.Must1(header.Write(CRLF))
  215. common.Must(header.WriteByte(CommandTCP))
  216. common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
  217. common.Must1(header.Write(CRLF))
  218. _, err := conn.Write(payload.Bytes())
  219. if err != nil {
  220. return E.Cause(err, "write request")
  221. }
  222. return nil
  223. }
  224. func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
  225. headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9
  226. payloadLen := payload.Len()
  227. var header *buf.Buffer
  228. defer header.Release()
  229. var writeHeader bool
  230. if payload.Start() >= headerLen {
  231. header = buf.With(payload.ExtendHeader(headerLen))
  232. } else {
  233. buffer := buf.StackNewSize(headerLen)
  234. defer common.KeepAlive(buffer)
  235. header = common.Dup(buffer)
  236. writeHeader = true
  237. }
  238. common.Must1(header.Write(key[:]))
  239. common.Must1(header.Write(CRLF))
  240. common.Must(header.WriteByte(CommandUDP))
  241. common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
  242. common.Must1(header.Write(CRLF))
  243. common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
  244. common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
  245. common.Must1(header.Write(CRLF))
  246. if writeHeader {
  247. _, err := conn.Write(header.Bytes())
  248. if err != nil {
  249. return E.Cause(err, "write request")
  250. }
  251. }
  252. _, err := conn.Write(payload.Bytes())
  253. if err != nil {
  254. return E.Cause(err, "write payload")
  255. }
  256. return nil
  257. }
  258. func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
  259. destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
  260. if err != nil {
  261. return M.Socksaddr{}, E.Cause(err, "read destination")
  262. }
  263. var length uint16
  264. err = binary.Read(conn, binary.BigEndian, &length)
  265. if err != nil {
  266. return M.Socksaddr{}, E.Cause(err, "read chunk length")
  267. }
  268. err = rw.SkipN(conn, 2)
  269. if err != nil {
  270. return M.Socksaddr{}, E.Cause(err, "skip crlf")
  271. }
  272. _, err = buffer.ReadFullFrom(conn, int(length))
  273. return destination, err
  274. }
  275. func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error {
  276. defer buffer.Release()
  277. bufferLen := buffer.Len()
  278. header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4))
  279. common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
  280. common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
  281. common.Must1(header.Write(CRLF))
  282. _, err := conn.Write(buffer.Bytes())
  283. if err != nil {
  284. return E.Cause(err, "write packet")
  285. }
  286. return nil
  287. }