client_bind.go 5.1 KB


  1. package wireguard
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "sync"
  7. "time"
  8. "github.com/sagernet/sing/common"
  9. "github.com/sagernet/sing/common/bufio"
  10. E "github.com/sagernet/sing/common/exceptions"
  11. M "github.com/sagernet/sing/common/metadata"
  12. N "github.com/sagernet/sing/common/network"
  13. "github.com/sagernet/sing/service"
  14. "github.com/sagernet/sing/service/pause"
  15. "github.com/sagernet/wireguard-go/conn"
  16. )
  17. var _ conn.Bind = (*ClientBind)(nil)
  18. type ClientBind struct {
  19. ctx context.Context
  20. pauseManager pause.Manager
  21. bindCtx context.Context
  22. bindDone context.CancelFunc
  23. errorHandler E.Handler
  24. dialer N.Dialer
  25. reservedForEndpoint map[netip.AddrPort][3]uint8
  26. connAccess sync.Mutex
  27. conn *wireConn
  28. done chan struct{}
  29. isConnect bool
  30. connectAddr netip.AddrPort
  31. reserved [3]uint8
  32. }
  33. func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
  34. return &ClientBind{
  35. ctx: ctx,
  36. pauseManager: service.FromContext[pause.Manager](ctx),
  37. errorHandler: errorHandler,
  38. dialer: dialer,
  39. reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
  40. done: make(chan struct{}),
  41. isConnect: isConnect,
  42. connectAddr: connectAddr,
  43. reserved: reserved,
  44. }
  45. }
  46. func (c *ClientBind) connect() (*wireConn, error) {
  47. serverConn := c.conn
  48. if serverConn != nil {
  49. select {
  50. case <-serverConn.done:
  51. serverConn = nil
  52. default:
  53. return serverConn, nil
  54. }
  55. }
  56. c.connAccess.Lock()
  57. defer c.connAccess.Unlock()
  58. select {
  59. case <-c.done:
  60. return nil, net.ErrClosed
  61. default:
  62. }
  63. serverConn = c.conn
  64. if serverConn != nil {
  65. select {
  66. case <-serverConn.done:
  67. serverConn = nil
  68. default:
  69. return serverConn, nil
  70. }
  71. }
  72. if c.isConnect {
  73. udpConn, err := c.dialer.DialContext(c.bindCtx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
  74. if err != nil {
  75. return nil, err
  76. }
  77. c.conn = &wireConn{
  78. PacketConn: bufio.NewUnbindPacketConn(udpConn),
  79. done: make(chan struct{}),
  80. }
  81. } else {
  82. udpConn, err := c.dialer.ListenPacket(c.bindCtx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
  83. if err != nil {
  84. return nil, err
  85. }
  86. c.conn = &wireConn{
  87. PacketConn: bufio.NewPacketConn(udpConn),
  88. done: make(chan struct{}),
  89. }
  90. }
  91. return c.conn, nil
  92. }
  93. func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
  94. select {
  95. case <-c.done:
  96. c.done = make(chan struct{})
  97. default:
  98. }
  99. c.bindCtx, c.bindDone = context.WithCancel(c.ctx)
  100. return []conn.ReceiveFunc{c.receive}, 0, nil
  101. }
  102. func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
  103. udpConn, err := c.connect()
  104. if err != nil {
  105. select {
  106. case <-c.done:
  107. return
  108. default:
  109. }
  110. c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
  111. err = nil
  112. c.pauseManager.WaitActive()
  113. time.Sleep(time.Second)
  114. return
  115. }
  116. n, addr, err := udpConn.ReadFrom(packets[0])
  117. if err != nil {
  118. udpConn.Close()
  119. select {
  120. case <-c.done:
  121. default:
  122. c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet"))
  123. err = nil
  124. }
  125. return
  126. }
  127. sizes[0] = n
  128. if n > 3 {
  129. b := packets[0]
  130. common.ClearArray(b[1:4])
  131. }
  132. eps[0] = Endpoint(M.AddrPortFromNet(addr))
  133. count = 1
  134. return
  135. }
  136. func (c *ClientBind) Close() error {
  137. select {
  138. case <-c.done:
  139. default:
  140. close(c.done)
  141. }
  142. if c.bindDone != nil {
  143. c.bindDone()
  144. }
  145. c.connAccess.Lock()
  146. defer c.connAccess.Unlock()
  147. common.Close(common.PtrOrNil(c.conn))
  148. return nil
  149. }
  150. func (c *ClientBind) SetMark(mark uint32) error {
  151. return nil
  152. }
  153. func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
  154. udpConn, err := c.connect()
  155. if err != nil {
  156. c.pauseManager.WaitActive()
  157. time.Sleep(time.Second)
  158. return err
  159. }
  160. destination := netip.AddrPort(ep.(Endpoint))
  161. for _, b := range bufs {
  162. if len(b) > 3 {
  163. reserved, loaded := c.reservedForEndpoint[destination]
  164. if !loaded {
  165. reserved = c.reserved
  166. }
  167. copy(b[1:4], reserved[:])
  168. }
  169. _, err = udpConn.WriteToUDPAddrPort(b, destination)
  170. if err != nil {
  171. udpConn.Close()
  172. return err
  173. }
  174. }
  175. return nil
  176. }
  177. func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  178. ap, err := netip.ParseAddrPort(s)
  179. if err != nil {
  180. return nil, err
  181. }
  182. return Endpoint(ap), nil
  183. }
  184. func (c *ClientBind) BatchSize() int {
  185. return 1
  186. }
  187. func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) {
  188. c.reservedForEndpoint[destination] = reserved
  189. }
  // wireConn is the packet connection shared by ClientBind's Send and receive.
// Its done channel is closed once Close has run, which tells
// ClientBind.connect to create a replacement.
type wireConn struct {
	net.PacketConn

	// conn, when non-nil, takes all writes via Write instead of WriteTo.
	// NOTE(review): nothing in this file sets it.
	conn net.Conn

	// done is closed by Close; access serializes Close calls.
	done   chan struct{}
	access sync.Mutex
}
  196. func (w *wireConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
  197. if w.conn != nil {
  198. return w.conn.Write(b)
  199. }
  200. return w.PacketConn.WriteTo(b, M.SocksaddrFromNetIP(addr).UDPAddr())
  201. }
  202. func (w *wireConn) Close() error {
  203. w.access.Lock()
  204. defer w.access.Unlock()
  205. select {
  206. case <-w.done:
  207. return net.ErrClosed
  208. default:
  209. }
  210. w.PacketConn.Close()
  211. close(w.done)
  212. return nil
  213. }