client_bind.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. package wireguard
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "sync"
  7. "time"
  8. "github.com/sagernet/sing/common"
  9. "github.com/sagernet/sing/common/bufio"
  10. E "github.com/sagernet/sing/common/exceptions"
  11. M "github.com/sagernet/sing/common/metadata"
  12. N "github.com/sagernet/sing/common/network"
  13. "github.com/sagernet/sing/service/pause"
  14. "github.com/sagernet/wireguard-go/conn"
  15. )
  16. var _ conn.Bind = (*ClientBind)(nil)
  17. type ClientBind struct {
  18. ctx context.Context
  19. errorHandler E.Handler
  20. dialer N.Dialer
  21. reservedForEndpoint map[M.Socksaddr][3]uint8
  22. connAccess sync.Mutex
  23. conn *wireConn
  24. done chan struct{}
  25. isConnect bool
  26. connectAddr M.Socksaddr
  27. reserved [3]uint8
  28. pauseManager pause.Manager
  29. }
  // NewClientBind creates a ClientBind that dials through dialer using ctx.
// When isConnect is true the bind uses a socket dialed to connectAddr;
// otherwise it listens on the unspecified IPv4 address. reserved is the
// default value stamped into bytes 1..3 of outgoing packets. The pause
// manager is taken from ctx. The done channel is deliberately left nil; Close
// creates it lazily.
func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
	bind := new(ClientBind)
	bind.ctx = ctx
	bind.errorHandler = errorHandler
	bind.dialer = dialer
	bind.reservedForEndpoint = make(map[M.Socksaddr][3]uint8)
	bind.isConnect = isConnect
	bind.connectAddr = connectAddr
	bind.reserved = reserved
	bind.pauseManager = pause.ManagerFromContext(ctx)
	return bind
}
  42. func (c *ClientBind) SetReservedForEndpoint(destination M.Socksaddr, reserved [3]byte) {
  43. c.reservedForEndpoint[destination] = reserved
  44. }
  45. func (c *ClientBind) connect() (*wireConn, error) {
  46. serverConn := c.conn
  47. if serverConn != nil {
  48. select {
  49. case <-serverConn.done:
  50. serverConn = nil
  51. default:
  52. return serverConn, nil
  53. }
  54. }
  55. c.connAccess.Lock()
  56. defer c.connAccess.Unlock()
  57. serverConn = c.conn
  58. if serverConn != nil {
  59. select {
  60. case <-serverConn.done:
  61. serverConn = nil
  62. default:
  63. return serverConn, nil
  64. }
  65. }
  66. if c.isConnect {
  67. udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr)
  68. if err != nil {
  69. return nil, err
  70. }
  71. c.conn = &wireConn{
  72. PacketConn: &bufio.UnbindPacketConn{
  73. ExtendedConn: bufio.NewExtendedConn(udpConn),
  74. Addr: c.connectAddr,
  75. },
  76. done: make(chan struct{}),
  77. }
  78. } else {
  79. udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
  80. if err != nil {
  81. return nil, err
  82. }
  83. c.conn = &wireConn{
  84. PacketConn: bufio.NewPacketConn(udpConn),
  85. done: make(chan struct{}),
  86. }
  87. }
  88. return c.conn, nil
  89. }
  // Open hands the WireGuard device its single receive function. port is
// ignored and actualPort is always 0, because the socket is created lazily by
// connect rather than bound here. After the bind's done channel has been
// closed, Open fails with net.ErrClosed.
func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
	select {
	case <-c.done:
		return nil, 0, net.ErrClosed
	default:
		return []conn.ReceiveFunc{c.receive}, 0, nil
	}
}
  // receive is the conn.ReceiveFunc returned by Open. Each call reads one
// packet into packets[0] and records its length in sizes[0]. It zeroes bytes
// 1..3 of packets longer than three bytes, stores the sender in eps[0], and
// returns count 1.
//
// Error handling:
//   - If connect fails while the bind is open, the error is reported through
//     errorHandler. receive then waits one second, waking early if the bind is
//     closed, waits for the pause manager to be active, and returns (0, nil)
//     so the caller retries.
//   - A read error closes the socket so the next call redials. It is reported
//     and swallowed while the bind is open, and returned once it is closed.
func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
	udpConn, err := c.connect()
	if err != nil {
		select {
		case <-c.done:
			return 0, err
		default:
		}
		c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
		// Back off before the caller retries. Unlike time.Sleep, this wakes
		// as soon as Close signals done, so shutdown is not held up.
		backoff := time.NewTimer(time.Second)
		defer backoff.Stop()
		select {
		case <-c.done:
			return 0, net.ErrClosed
		case <-backoff.C:
		}
		c.pauseManager.WaitActive()
		return 0, nil
	}
	n, addr, err := udpConn.ReadFrom(packets[0])
	if err != nil {
		// Drop the socket so the next call redials.
		udpConn.Close()
		select {
		case <-c.done:
			return 0, err
		default:
		}
		c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet"))
		return 0, nil
	}
	sizes[0] = n
	if n > 3 {
		// Mirror of Send, which overwrites these same bytes with reserved values.
		b := packets[0]
		b[1], b[2], b[3] = 0, 0, 0
	}
	eps[0] = Endpoint(M.SocksaddrFromNet(addr))
	return 1, nil
}
  135. func (c *ClientBind) Reset() {
  136. common.Close(common.PtrOrNil(c.conn))
  137. }
  138. func (c *ClientBind) Close() error {
  139. common.Close(common.PtrOrNil(c.conn))
  140. if c.done == nil {
  141. c.done = make(chan struct{})
  142. return nil
  143. }
  144. select {
  145. case <-c.done:
  146. return net.ErrClosed
  147. default:
  148. close(c.done)
  149. }
  150. return nil
  151. }
  152. func (c *ClientBind) SetMark(mark uint32) error {
  153. return nil
  154. }
  155. func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
  156. udpConn, err := c.connect()
  157. if err != nil {
  158. return err
  159. }
  160. destination := M.Socksaddr(ep.(Endpoint))
  161. for _, b := range bufs {
  162. if len(b) > 3 {
  163. reserved, loaded := c.reservedForEndpoint[destination]
  164. if !loaded {
  165. reserved = c.reserved
  166. }
  167. b[1] = reserved[0]
  168. b[2] = reserved[1]
  169. b[3] = reserved[2]
  170. }
  171. _, err = udpConn.WriteTo(b, destination)
  172. if err != nil {
  173. udpConn.Close()
  174. return err
  175. }
  176. }
  177. return nil
  178. }
  179. func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  180. return Endpoint(M.ParseSocksaddr(s)), nil
  181. }
  182. func (c *ClientBind) BatchSize() int {
  183. return 1
  184. }
  185. type wireConn struct {
  186. net.PacketConn
  187. access sync.Mutex
  188. done chan struct{}
  189. }
  190. func (w *wireConn) Close() error {
  191. w.access.Lock()
  192. defer w.access.Unlock()
  193. select {
  194. case <-w.done:
  195. return net.ErrClosed
  196. default:
  197. }
  198. w.PacketConn.Close()
  199. close(w.done)
  200. return nil
  201. }