client.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. package mux
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "io"
  6. "net"
  7. "sync"
  8. "github.com/sagernet/sing-box/option"
  9. "github.com/sagernet/sing/common"
  10. "github.com/sagernet/sing/common/buf"
  11. "github.com/sagernet/sing/common/bufio"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. M "github.com/sagernet/sing/common/metadata"
  14. N "github.com/sagernet/sing/common/network"
  15. "github.com/sagernet/sing/common/x/list"
  16. )
  17. var _ N.Dialer = (*Client)(nil)
  18. type Client struct {
  19. access sync.Mutex
  20. connections list.List[abstractSession]
  21. ctx context.Context
  22. dialer N.Dialer
  23. protocol Protocol
  24. maxConnections int
  25. minStreams int
  26. maxStreams int
  27. }
  28. func NewClient(ctx context.Context, dialer N.Dialer, protocol Protocol, maxConnections int, minStreams int, maxStreams int) *Client {
  29. return &Client{
  30. ctx: ctx,
  31. dialer: dialer,
  32. protocol: protocol,
  33. maxConnections: maxConnections,
  34. minStreams: minStreams,
  35. maxStreams: maxStreams,
  36. }
  37. }
  38. func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) (N.Dialer, error) {
  39. if !options.Enabled {
  40. return dialer, nil
  41. }
  42. if options.MaxConnections == 0 && options.MaxStreams == 0 {
  43. options.MinStreams = 8
  44. }
  45. protocol, err := ParseProtocol(options.Protocol)
  46. if err != nil {
  47. return nil, err
  48. }
  49. return NewClient(ctx, dialer, protocol, options.MaxConnections, options.MinStreams, options.MaxStreams), nil
  50. }
  51. func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  52. switch N.NetworkName(network) {
  53. case N.NetworkTCP:
  54. stream, err := c.openStream()
  55. if err != nil {
  56. return nil, err
  57. }
  58. return &ClientConn{Conn: stream, destination: destination}, nil
  59. case N.NetworkUDP:
  60. stream, err := c.openStream()
  61. if err != nil {
  62. return nil, err
  63. }
  64. return bufio.NewUnbindPacketConn(&ClientPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}), nil
  65. default:
  66. return nil, E.Extend(N.ErrUnknownNetwork, network)
  67. }
  68. }
  69. func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  70. stream, err := c.openStream()
  71. if err != nil {
  72. return nil, err
  73. }
  74. return &ClientPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}, nil
  75. }
  76. func (c *Client) openStream() (net.Conn, error) {
  77. var (
  78. session abstractSession
  79. stream net.Conn
  80. err error
  81. )
  82. for attempts := 0; attempts < 2; attempts++ {
  83. session, err = c.offer()
  84. if err != nil {
  85. continue
  86. }
  87. stream, err = session.Open()
  88. if err != nil {
  89. continue
  90. }
  91. break
  92. }
  93. if err != nil {
  94. return nil, err
  95. }
  96. return &wrapStream{stream}, nil
  97. }
  98. func (c *Client) offer() (abstractSession, error) {
  99. c.access.Lock()
  100. defer c.access.Unlock()
  101. sessions := make([]abstractSession, 0, c.maxConnections)
  102. for element := c.connections.Front(); element != nil; {
  103. if element.Value.IsClosed() {
  104. nextElement := element.Next()
  105. c.connections.Remove(element)
  106. element = nextElement
  107. continue
  108. }
  109. sessions = append(sessions, element.Value)
  110. element = element.Next()
  111. }
  112. sLen := len(sessions)
  113. if sLen == 0 {
  114. return c.offerNew()
  115. }
  116. session := common.MinBy(sessions, abstractSession.NumStreams)
  117. numStreams := session.NumStreams()
  118. if numStreams == 0 {
  119. return session, nil
  120. }
  121. if c.maxConnections > 0 {
  122. if sLen >= c.maxConnections || numStreams < c.minStreams {
  123. return session, nil
  124. }
  125. } else {
  126. if c.maxStreams > 0 && numStreams < c.maxStreams {
  127. return session, nil
  128. }
  129. }
  130. return c.offerNew()
  131. }
  132. func (c *Client) offerNew() (abstractSession, error) {
  133. conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, Destination)
  134. if err != nil {
  135. return nil, err
  136. }
  137. session, err := c.protocol.newClient(&protocolConn{Conn: conn, protocol: c.protocol})
  138. if err != nil {
  139. return nil, err
  140. }
  141. c.connections.PushBack(session)
  142. return session, nil
  143. }
  144. func (c *Client) Close() error {
  145. c.access.Lock()
  146. defer c.access.Unlock()
  147. for _, session := range c.connections.Array() {
  148. session.Close()
  149. }
  150. return nil
  151. }
  152. type ClientConn struct {
  153. net.Conn
  154. destination M.Socksaddr
  155. requestWrite bool
  156. responseRead bool
  157. }
  158. func (c *ClientConn) readResponse() error {
  159. response, err := ReadStreamResponse(c.Conn)
  160. if err != nil {
  161. return err
  162. }
  163. if response.Status == statusError {
  164. return E.New("remote error: ", response.Message)
  165. }
  166. return nil
  167. }
  168. func (c *ClientConn) Read(b []byte) (n int, err error) {
  169. if !c.responseRead {
  170. err = c.readResponse()
  171. if err != nil {
  172. return
  173. }
  174. c.responseRead = true
  175. }
  176. return c.Conn.Read(b)
  177. }
  178. func (c *ClientConn) Write(b []byte) (n int, err error) {
  179. if c.requestWrite {
  180. return c.Conn.Write(b)
  181. }
  182. request := StreamRequest{
  183. Network: N.NetworkTCP,
  184. Destination: c.destination,
  185. }
  186. _buffer := buf.StackNewSize(requestLen(request) + len(b))
  187. defer common.KeepAlive(_buffer)
  188. buffer := common.Dup(_buffer)
  189. defer buffer.Release()
  190. EncodeStreamRequest(request, buffer)
  191. buffer.Write(b)
  192. _, err = c.Conn.Write(buffer.Bytes())
  193. if err != nil {
  194. return
  195. }
  196. c.requestWrite = true
  197. return len(b), nil
  198. }
  199. func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) {
  200. if !c.requestWrite {
  201. return bufio.ReadFrom0(c, r)
  202. }
  203. return bufio.Copy(c.Conn, r)
  204. }
  205. func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) {
  206. if !c.responseRead {
  207. return bufio.WriteTo0(c, w)
  208. }
  209. return bufio.Copy(w, c.Conn)
  210. }
  211. func (c *ClientConn) LocalAddr() net.Addr {
  212. return c.Conn.LocalAddr()
  213. }
  214. func (c *ClientConn) RemoteAddr() net.Addr {
  215. return c.destination.TCPAddr()
  216. }
  217. func (c *ClientConn) ReaderReplaceable() bool {
  218. return c.responseRead
  219. }
  220. func (c *ClientConn) WriterReplaceable() bool {
  221. return c.requestWrite
  222. }
  223. func (c *ClientConn) Upstream() any {
  224. return c.Conn
  225. }
  226. type ClientPacketConn struct {
  227. N.ExtendedConn
  228. destination M.Socksaddr
  229. requestWrite bool
  230. responseRead bool
  231. }
  232. func (c *ClientPacketConn) readResponse() error {
  233. response, err := ReadStreamResponse(c.ExtendedConn)
  234. if err != nil {
  235. return err
  236. }
  237. if response.Status == statusError {
  238. return E.New("remote error: ", response.Message)
  239. }
  240. return nil
  241. }
  242. func (c *ClientPacketConn) Read(b []byte) (n int, err error) {
  243. if !c.responseRead {
  244. err = c.readResponse()
  245. if err != nil {
  246. return
  247. }
  248. c.responseRead = true
  249. }
  250. var length uint16
  251. err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
  252. if err != nil {
  253. return
  254. }
  255. if cap(b) < int(length) {
  256. return 0, io.ErrShortBuffer
  257. }
  258. return io.ReadFull(c.ExtendedConn, b[:length])
  259. }
  260. func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
  261. request := StreamRequest{
  262. Network: N.NetworkUDP,
  263. Destination: c.destination,
  264. }
  265. rLen := requestLen(request)
  266. if len(payload) > 0 {
  267. rLen += 2 + len(payload)
  268. }
  269. _buffer := buf.StackNewSize(rLen)
  270. defer common.KeepAlive(_buffer)
  271. buffer := common.Dup(_buffer)
  272. defer buffer.Release()
  273. EncodeStreamRequest(request, buffer)
  274. if len(payload) > 0 {
  275. common.Must(
  276. binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
  277. common.Error(buffer.Write(payload)),
  278. )
  279. }
  280. _, err = c.ExtendedConn.Write(buffer.Bytes())
  281. if err != nil {
  282. return
  283. }
  284. c.requestWrite = true
  285. return len(payload), nil
  286. }
  287. func (c *ClientPacketConn) Write(b []byte) (n int, err error) {
  288. if !c.requestWrite {
  289. return c.writeRequest(b)
  290. }
  291. err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b)))
  292. if err != nil {
  293. return
  294. }
  295. return c.ExtendedConn.Write(b)
  296. }
  297. func (c *ClientPacketConn) WriteBuffer(buffer *buf.Buffer) error {
  298. if !c.requestWrite {
  299. defer buffer.Release()
  300. return common.Error(c.writeRequest(buffer.Bytes()))
  301. }
  302. bLen := buffer.Len()
  303. binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen))
  304. return c.ExtendedConn.WriteBuffer(buffer)
  305. }
  306. func (c *ClientPacketConn) Headroom() int {
  307. return 2
  308. }
  309. func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  310. return c.WriteBuffer(buffer)
  311. }
  312. func (c *ClientPacketConn) LocalAddr() net.Addr {
  313. return c.ExtendedConn.LocalAddr()
  314. }
  315. func (c *ClientPacketConn) RemoteAddr() net.Addr {
  316. return c.destination.UDPAddr()
  317. }
  318. func (c *ClientPacketConn) Upstream() any {
  319. return c.ExtendedConn
  320. }
  321. var _ N.NetPacketConn = (*ClientPacketAddrConn)(nil)
  322. type ClientPacketAddrConn struct {
  323. N.ExtendedConn
  324. destination M.Socksaddr
  325. requestWrite bool
  326. responseRead bool
  327. }
  328. func (c *ClientPacketAddrConn) readResponse() error {
  329. response, err := ReadStreamResponse(c.ExtendedConn)
  330. if err != nil {
  331. return err
  332. }
  333. if response.Status == statusError {
  334. return E.New("remote error: ", response.Message)
  335. }
  336. return nil
  337. }
  338. func (c *ClientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
  339. if !c.responseRead {
  340. err = c.readResponse()
  341. if err != nil {
  342. return
  343. }
  344. c.responseRead = true
  345. }
  346. destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
  347. if err != nil {
  348. return
  349. }
  350. addr = destination.UDPAddr()
  351. var length uint16
  352. err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
  353. if err != nil {
  354. return
  355. }
  356. if cap(p) < int(length) {
  357. return 0, nil, io.ErrShortBuffer
  358. }
  359. n, err = io.ReadFull(c.ExtendedConn, p[:length])
  360. return
  361. }
  362. func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) {
  363. request := StreamRequest{
  364. Network: N.NetworkUDP,
  365. Destination: c.destination,
  366. PacketAddr: true,
  367. }
  368. rLen := requestLen(request)
  369. if len(payload) > 0 {
  370. rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload)
  371. }
  372. _buffer := buf.StackNewSize(rLen)
  373. defer common.KeepAlive(_buffer)
  374. buffer := common.Dup(_buffer)
  375. defer buffer.Release()
  376. EncodeStreamRequest(request, buffer)
  377. if len(payload) > 0 {
  378. common.Must(
  379. M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
  380. binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
  381. common.Error(buffer.Write(payload)),
  382. )
  383. }
  384. _, err = c.ExtendedConn.Write(buffer.Bytes())
  385. if err != nil {
  386. return
  387. }
  388. c.requestWrite = true
  389. return len(payload), nil
  390. }
  391. func (c *ClientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
  392. if !c.requestWrite {
  393. return c.writeRequest(p, M.SocksaddrFromNet(addr))
  394. }
  395. err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr))
  396. if err != nil {
  397. return
  398. }
  399. err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
  400. if err != nil {
  401. return
  402. }
  403. return c.ExtendedConn.Write(p)
  404. }
  405. func (c *ClientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
  406. if !c.responseRead {
  407. err = c.readResponse()
  408. if err != nil {
  409. return
  410. }
  411. c.responseRead = true
  412. }
  413. destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
  414. if err != nil {
  415. return
  416. }
  417. var length uint16
  418. err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
  419. if err != nil {
  420. return
  421. }
  422. if buffer.FreeLen() < int(length) {
  423. return destination, io.ErrShortBuffer
  424. }
  425. _, err = io.ReadFull(c.ExtendedConn, buffer.Extend(int(length)))
  426. return
  427. }
  428. func (c *ClientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  429. if !c.requestWrite {
  430. defer buffer.Release()
  431. return common.Error(c.writeRequest(buffer.Bytes(), destination))
  432. }
  433. bLen := buffer.Len()
  434. header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2))
  435. common.Must(
  436. M.SocksaddrSerializer.WriteAddrPort(header, destination),
  437. binary.Write(header, binary.BigEndian, uint16(bLen)),
  438. )
  439. return c.ExtendedConn.WriteBuffer(buffer)
  440. }
  441. func (c *ClientPacketAddrConn) LocalAddr() net.Addr {
  442. return c.ExtendedConn.LocalAddr()
  443. }
  444. func (c *ClientPacketAddrConn) Headroom() int {
  445. return 2 + M.MaxSocksaddrLength
  446. }
  447. func (c *ClientPacketAddrConn) Upstream() any {
  448. return c.ExtendedConn
  449. }