service.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. package mux
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "io"
  6. "net"
  7. "github.com/sagernet/sing-box/adapter"
  8. "github.com/sagernet/sing-box/log"
  9. "github.com/sagernet/sing/common"
  10. "github.com/sagernet/sing/common/buf"
  11. "github.com/sagernet/sing/common/bufio"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. M "github.com/sagernet/sing/common/metadata"
  14. N "github.com/sagernet/sing/common/network"
  15. "github.com/sagernet/sing/common/rw"
  16. )
  17. func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, conn net.Conn, metadata adapter.InboundContext) error {
  18. request, err := ReadRequest(conn)
  19. if err != nil {
  20. return err
  21. }
  22. session, err := request.Protocol.newServer(conn)
  23. if err != nil {
  24. return err
  25. }
  26. var stream net.Conn
  27. for {
  28. stream, err = session.Accept()
  29. if err != nil {
  30. return err
  31. }
  32. go newConnection(ctx, router, errorHandler, logger, stream, metadata)
  33. }
  34. }
  35. func newConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, stream net.Conn, metadata adapter.InboundContext) {
  36. stream = &wrapStream{stream}
  37. request, err := ReadStreamRequest(stream)
  38. if err != nil {
  39. logger.ErrorContext(ctx, err)
  40. return
  41. }
  42. metadata.Destination = request.Destination
  43. if request.Network == N.NetworkTCP {
  44. logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination)
  45. hErr := router.RouteConnection(ctx, &ServerConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata)
  46. stream.Close()
  47. if hErr != nil {
  48. errorHandler.NewError(ctx, hErr)
  49. }
  50. } else {
  51. var packetConn N.PacketConn
  52. if !request.PacketAddr {
  53. logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination)
  54. packetConn = &ServerPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination}
  55. } else {
  56. logger.InfoContext(ctx, "inbound multiplex packet connection")
  57. packetConn = &ServerPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)}
  58. }
  59. hErr := router.RoutePacketConnection(ctx, packetConn, metadata)
  60. stream.Close()
  61. if hErr != nil {
  62. errorHandler.NewError(ctx, hErr)
  63. }
  64. }
  65. }
  66. var _ N.HandshakeConn = (*ServerConn)(nil)
  67. type ServerConn struct {
  68. N.ExtendedConn
  69. responseWrite bool
  70. }
  71. func (c *ServerConn) HandshakeFailure(err error) error {
  72. errMessage := err.Error()
  73. _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
  74. defer common.KeepAlive(_buffer)
  75. buffer := common.Dup(_buffer)
  76. defer buffer.Release()
  77. common.Must(
  78. buffer.WriteByte(statusError),
  79. rw.WriteVString(_buffer, errMessage),
  80. )
  81. return c.ExtendedConn.WriteBuffer(buffer)
  82. }
  83. func (c *ServerConn) Write(b []byte) (n int, err error) {
  84. if c.responseWrite {
  85. return c.ExtendedConn.Write(b)
  86. }
  87. _buffer := buf.StackNewSize(1 + len(b))
  88. defer common.KeepAlive(_buffer)
  89. buffer := common.Dup(_buffer)
  90. defer buffer.Release()
  91. common.Must(
  92. buffer.WriteByte(statusSuccess),
  93. common.Error(buffer.Write(b)),
  94. )
  95. _, err = c.ExtendedConn.Write(buffer.Bytes())
  96. if err != nil {
  97. return
  98. }
  99. c.responseWrite = true
  100. return len(b), nil
  101. }
  102. func (c *ServerConn) WriteBuffer(buffer *buf.Buffer) error {
  103. if c.responseWrite {
  104. return c.ExtendedConn.WriteBuffer(buffer)
  105. }
  106. buffer.ExtendHeader(1)[0] = statusSuccess
  107. c.responseWrite = true
  108. return c.ExtendedConn.WriteBuffer(buffer)
  109. }
  110. func (c *ServerConn) FrontHeadroom() int {
  111. if !c.responseWrite {
  112. return 1
  113. }
  114. return 0
  115. }
  116. func (c *ServerConn) Upstream() any {
  117. return c.ExtendedConn
  118. }
  119. var (
  120. _ N.HandshakeConn = (*ServerPacketConn)(nil)
  121. _ N.PacketConn = (*ServerPacketConn)(nil)
  122. )
  123. type ServerPacketConn struct {
  124. N.ExtendedConn
  125. destination M.Socksaddr
  126. responseWrite bool
  127. }
  128. func (c *ServerPacketConn) HandshakeFailure(err error) error {
  129. errMessage := err.Error()
  130. _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
  131. defer common.KeepAlive(_buffer)
  132. buffer := common.Dup(_buffer)
  133. defer buffer.Release()
  134. common.Must(
  135. buffer.WriteByte(statusError),
  136. rw.WriteVString(_buffer, errMessage),
  137. )
  138. return c.ExtendedConn.WriteBuffer(buffer)
  139. }
  140. func (c *ServerPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
  141. var length uint16
  142. err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
  143. if err != nil {
  144. return
  145. }
  146. if buffer.FreeLen() < int(length) {
  147. return destination, io.ErrShortBuffer
  148. }
  149. _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
  150. if err != nil {
  151. return
  152. }
  153. destination = c.destination
  154. return
  155. }
  156. func (c *ServerPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  157. pLen := buffer.Len()
  158. common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen)))
  159. if !c.responseWrite {
  160. buffer.ExtendHeader(1)[0] = statusSuccess
  161. c.responseWrite = true
  162. }
  163. return c.ExtendedConn.WriteBuffer(buffer)
  164. }
  165. func (c *ServerPacketConn) Upstream() any {
  166. return c.ExtendedConn
  167. }
  168. func (c *ServerPacketConn) FrontHeadroom() int {
  169. if !c.responseWrite {
  170. return 3
  171. }
  172. return 2
  173. }
  174. var (
  175. _ N.HandshakeConn = (*ServerPacketAddrConn)(nil)
  176. _ N.PacketConn = (*ServerPacketAddrConn)(nil)
  177. )
  178. type ServerPacketAddrConn struct {
  179. N.ExtendedConn
  180. responseWrite bool
  181. }
  182. func (c *ServerPacketAddrConn) HandshakeFailure(err error) error {
  183. errMessage := err.Error()
  184. _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
  185. defer common.KeepAlive(_buffer)
  186. buffer := common.Dup(_buffer)
  187. defer buffer.Release()
  188. common.Must(
  189. buffer.WriteByte(statusError),
  190. rw.WriteVString(_buffer, errMessage),
  191. )
  192. return c.ExtendedConn.WriteBuffer(buffer)
  193. }
  194. func (c *ServerPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
  195. destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
  196. if err != nil {
  197. return
  198. }
  199. var length uint16
  200. err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
  201. if err != nil {
  202. return
  203. }
  204. if buffer.FreeLen() < int(length) {
  205. return destination, io.ErrShortBuffer
  206. }
  207. _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
  208. if err != nil {
  209. return
  210. }
  211. return
  212. }
  213. func (c *ServerPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  214. pLen := buffer.Len()
  215. common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen)))
  216. common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination))
  217. if !c.responseWrite {
  218. buffer.ExtendHeader(1)[0] = statusSuccess
  219. c.responseWrite = true
  220. }
  221. return c.ExtendedConn.WriteBuffer(buffer)
  222. }
  223. func (c *ServerPacketAddrConn) Upstream() any {
  224. return c.ExtendedConn
  225. }
  226. func (c *ServerPacketAddrConn) FrontHeadroom() int {
  227. if !c.responseWrite {
  228. return 3 + M.MaxSocksaddrLength
  229. }
  230. return 2 + M.MaxSocksaddrLength
  231. }