client_bind.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. package wireguard
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "sync"
  7. "time"
  8. "github.com/sagernet/sing/common"
  9. "github.com/sagernet/sing/common/bufio"
  10. E "github.com/sagernet/sing/common/exceptions"
  11. M "github.com/sagernet/sing/common/metadata"
  12. N "github.com/sagernet/sing/common/network"
  13. "github.com/sagernet/wireguard-go/conn"
  14. )
  15. var _ conn.Bind = (*ClientBind)(nil)
  16. type ClientBind struct {
  17. ctx context.Context
  18. errorHandler E.Handler
  19. dialer N.Dialer
  20. reservedForEndpoint map[netip.AddrPort][3]uint8
  21. connAccess sync.Mutex
  22. conn *wireConn
  23. done chan struct{}
  24. isConnect bool
  25. connectAddr netip.AddrPort
  26. reserved [3]uint8
  27. }
  28. func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
  29. return &ClientBind{
  30. ctx: ctx,
  31. errorHandler: errorHandler,
  32. dialer: dialer,
  33. reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
  34. isConnect: isConnect,
  35. connectAddr: connectAddr,
  36. reserved: reserved,
  37. }
  38. }
  39. func (c *ClientBind) connect() (*wireConn, error) {
  40. serverConn := c.conn
  41. if serverConn != nil {
  42. select {
  43. case <-serverConn.done:
  44. serverConn = nil
  45. default:
  46. return serverConn, nil
  47. }
  48. }
  49. c.connAccess.Lock()
  50. defer c.connAccess.Unlock()
  51. serverConn = c.conn
  52. if serverConn != nil {
  53. select {
  54. case <-serverConn.done:
  55. serverConn = nil
  56. default:
  57. return serverConn, nil
  58. }
  59. }
  60. if c.isConnect {
  61. udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
  62. if err != nil {
  63. return nil, err
  64. }
  65. c.conn = &wireConn{
  66. PacketConn: bufio.NewUnbindPacketConn(udpConn),
  67. done: make(chan struct{}),
  68. }
  69. } else {
  70. udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
  71. if err != nil {
  72. return nil, err
  73. }
  74. c.conn = &wireConn{
  75. PacketConn: bufio.NewPacketConn(udpConn),
  76. done: make(chan struct{}),
  77. }
  78. }
  79. return c.conn, nil
  80. }
  81. func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
  82. select {
  83. case <-c.done:
  84. err = net.ErrClosed
  85. return
  86. default:
  87. }
  88. return []conn.ReceiveFunc{c.receive}, 0, nil
  89. }
  90. func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
  91. udpConn, err := c.connect()
  92. if err != nil {
  93. select {
  94. case <-c.done:
  95. return
  96. default:
  97. }
  98. c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
  99. err = nil
  100. time.Sleep(time.Second)
  101. return
  102. }
  103. n, addr, err := udpConn.ReadFrom(packets[0])
  104. if err != nil {
  105. udpConn.Close()
  106. select {
  107. case <-c.done:
  108. default:
  109. c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet"))
  110. err = nil
  111. }
  112. return
  113. }
  114. sizes[0] = n
  115. if n > 3 {
  116. b := packets[0]
  117. common.ClearArray(b[1:4])
  118. }
  119. eps[0] = Endpoint(M.AddrPortFromNet(addr))
  120. count = 1
  121. return
  122. }
  123. func (c *ClientBind) Reset() {
  124. common.Close(common.PtrOrNil(c.conn))
  125. }
  126. func (c *ClientBind) Close() error {
  127. common.Close(common.PtrOrNil(c.conn))
  128. if c.done == nil {
  129. c.done = make(chan struct{})
  130. return nil
  131. }
  132. select {
  133. case <-c.done:
  134. default:
  135. close(c.done)
  136. }
  137. return nil
  138. }
  139. func (c *ClientBind) SetMark(mark uint32) error {
  140. return nil
  141. }
  142. func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
  143. udpConn, err := c.connect()
  144. if err != nil {
  145. return err
  146. }
  147. destination := netip.AddrPort(ep.(Endpoint))
  148. for _, b := range bufs {
  149. if len(b) > 3 {
  150. reserved, loaded := c.reservedForEndpoint[destination]
  151. if !loaded {
  152. reserved = c.reserved
  153. }
  154. copy(b[1:4], reserved[:])
  155. }
  156. _, err = udpConn.WriteTo(b, M.SocksaddrFromNetIP(destination))
  157. if err != nil {
  158. udpConn.Close()
  159. return err
  160. }
  161. }
  162. return nil
  163. }
  164. func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  165. ap, err := netip.ParseAddrPort(s)
  166. if err != nil {
  167. return nil, err
  168. }
  169. return Endpoint(ap), nil
  170. }
  171. func (c *ClientBind) BatchSize() int {
  172. return 1
  173. }
  174. func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) {
  175. c.reservedForEndpoint[destination] = reserved
  176. }
  // wireConn is the packet connection shared by a ClientBind. It is paired with
// a done channel that is closed exactly once, when the connection is closed.
type wireConn struct {
	net.PacketConn
	// access serialises Close so done is closed at most once.
	access sync.Mutex
	// done is closed by Close. ClientBind.connect checks it to decide whether
	// this connection must be replaced.
	done chan struct{}
}

// Close closes the underlying packet connection and closes done.
//
// The first call returns the underlying Close error, if any. done is closed
// regardless of that error, so the connection is never handed out again. Every
// later call returns net.ErrClosed.
func (w *wireConn) Close() error {
	w.access.Lock()
	defer w.access.Unlock()
	select {
	case <-w.done:
		return net.ErrClosed
	default:
	}
	err := w.PacketConn.Close()
	close(w.done)
	return err
}