// client_bind.go (5.6 KB)

  1. package wireguard
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "sync"
  7. "time"
  8. "github.com/sagernet/sing/common"
  9. "github.com/sagernet/sing/common/bufio"
  10. E "github.com/sagernet/sing/common/exceptions"
  11. "github.com/sagernet/sing/common/logger"
  12. M "github.com/sagernet/sing/common/metadata"
  13. N "github.com/sagernet/sing/common/network"
  14. "github.com/sagernet/sing/service"
  15. "github.com/sagernet/sing/service/pause"
  16. "github.com/sagernet/wireguard-go/conn"
  17. )
  18. var _ conn.Bind = (*ClientBind)(nil)
  19. type ClientBind struct {
  20. ctx context.Context
  21. logger logger.Logger
  22. pauseManager pause.Manager
  23. bindCtx context.Context
  24. bindDone context.CancelFunc
  25. dialer N.Dialer
  26. reservedForEndpoint map[netip.AddrPort][3]uint8
  27. connAccess sync.Mutex
  28. conn *wireConn
  29. done chan struct{}
  30. isConnect bool
  31. connectAddr netip.AddrPort
  32. reserved [3]uint8
  33. }
  34. func NewClientBind(ctx context.Context, logger logger.Logger, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
  35. return &ClientBind{
  36. ctx: ctx,
  37. logger: logger,
  38. pauseManager: service.FromContext[pause.Manager](ctx),
  39. dialer: dialer,
  40. reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
  41. done: make(chan struct{}),
  42. isConnect: isConnect,
  43. connectAddr: connectAddr,
  44. reserved: reserved,
  45. }
  46. }
  47. func (c *ClientBind) connect() (*wireConn, error) {
  48. serverConn := c.conn
  49. if serverConn != nil {
  50. select {
  51. case <-serverConn.done:
  52. serverConn = nil
  53. default:
  54. return serverConn, nil
  55. }
  56. }
  57. c.connAccess.Lock()
  58. defer c.connAccess.Unlock()
  59. select {
  60. case <-c.done:
  61. return nil, net.ErrClosed
  62. default:
  63. }
  64. serverConn = c.conn
  65. if serverConn != nil {
  66. select {
  67. case <-serverConn.done:
  68. serverConn = nil
  69. default:
  70. return serverConn, nil
  71. }
  72. }
  73. if c.isConnect {
  74. udpConn, err := c.dialer.DialContext(c.bindCtx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
  75. if err != nil {
  76. return nil, err
  77. }
  78. c.conn = &wireConn{
  79. PacketConn: bufio.NewUnbindPacketConn(udpConn),
  80. done: make(chan struct{}),
  81. }
  82. } else {
  83. udpConn, err := c.dialer.ListenPacket(c.bindCtx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
  84. if err != nil {
  85. return nil, err
  86. }
  87. c.conn = &wireConn{
  88. PacketConn: bufio.NewPacketConn(udpConn),
  89. done: make(chan struct{}),
  90. }
  91. }
  92. return c.conn, nil
  93. }
  94. func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
  95. select {
  96. case <-c.done:
  97. c.done = make(chan struct{})
  98. default:
  99. }
  100. c.bindCtx, c.bindDone = context.WithCancel(c.ctx)
  101. return []conn.ReceiveFunc{c.receive}, 0, nil
  102. }
  103. func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
  104. udpConn, err := c.connect()
  105. if err != nil {
  106. select {
  107. case <-c.done:
  108. return
  109. default:
  110. }
  111. c.logger.Error(E.Cause(err, "connect to server"))
  112. err = nil
  113. c.pauseManager.WaitActive()
  114. time.Sleep(time.Second)
  115. return
  116. }
  117. n, addr, err := udpConn.ReadFrom(packets[0])
  118. if err != nil {
  119. udpConn.Close()
  120. select {
  121. case <-c.done:
  122. default:
  123. c.logger.Error(E.Cause(err, "read packet"))
  124. err = nil
  125. }
  126. return
  127. }
  128. sizes[0] = n
  129. if n > 3 {
  130. b := packets[0]
  131. common.ClearArray(b[1:4])
  132. }
  133. eps[0] = remoteEndpoint(M.AddrPortFromNet(addr))
  134. count = 1
  135. return
  136. }
  137. func (c *ClientBind) Close() error {
  138. select {
  139. case <-c.done:
  140. default:
  141. close(c.done)
  142. }
  143. if c.bindDone != nil {
  144. c.bindDone()
  145. }
  146. c.connAccess.Lock()
  147. defer c.connAccess.Unlock()
  148. common.Close(common.PtrOrNil(c.conn))
  149. return nil
  150. }
  151. func (c *ClientBind) SetMark(mark uint32) error {
  152. return nil
  153. }
  154. func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
  155. udpConn, err := c.connect()
  156. if err != nil {
  157. c.pauseManager.WaitActive()
  158. time.Sleep(time.Second)
  159. return err
  160. }
  161. destination := netip.AddrPort(ep.(remoteEndpoint))
  162. for _, b := range bufs {
  163. if len(b) > 3 {
  164. reserved, loaded := c.reservedForEndpoint[destination]
  165. if !loaded {
  166. reserved = c.reserved
  167. }
  168. copy(b[1:4], reserved[:])
  169. }
  170. _, err = udpConn.WriteToUDPAddrPort(b, destination)
  171. if err != nil {
  172. udpConn.Close()
  173. return err
  174. }
  175. }
  176. return nil
  177. }
  178. func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  179. ap, err := netip.ParseAddrPort(s)
  180. if err != nil {
  181. return nil, err
  182. }
  183. return remoteEndpoint(ap), nil
  184. }
  185. func (c *ClientBind) BatchSize() int {
  186. return 1
  187. }
  188. func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) {
  189. c.reservedForEndpoint[destination] = reserved
  190. }
  // wireConn is the server connection that ClientBind.connect hands out to
// receive and Send.
type wireConn struct {
	// PacketConn is the socket created by ClientBind.connect.
	net.PacketConn

	// conn, when non-nil, takes over writes in WriteToUDPAddrPort.
	// NOTE(review): nothing in this file assigns it — confirm whether another
	// file in the package sets it, or whether it is dead.
	conn net.Conn

	// access serializes Close.
	access sync.Mutex

	// done is closed once Close has run; connect treats a closed done as a
	// signal to replace this connection.
	done chan struct{}
}
  197. func (w *wireConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
  198. if w.conn != nil {
  199. return w.conn.Write(b)
  200. }
  201. return w.PacketConn.WriteTo(b, M.SocksaddrFromNetIP(addr).UDPAddr())
  202. }
  203. func (w *wireConn) Close() error {
  204. w.access.Lock()
  205. defer w.access.Unlock()
  206. select {
  207. case <-w.done:
  208. return net.ErrClosed
  209. default:
  210. }
  211. w.PacketConn.Close()
  212. close(w.done)
  213. return nil
  214. }
  215. var _ conn.Endpoint = (*remoteEndpoint)(nil)
  // remoteEndpoint is the conn.Endpoint used by ClientBind. It holds only the
// peer's address and port: no source address is tracked, so the Src* methods
// return zero values and ClearSrc does nothing.
type remoteEndpoint netip.AddrPort

// ClearSrc is a no-op because remoteEndpoint carries no source address.
func (e remoteEndpoint) ClearSrc() {}

// SrcToString always returns "" because no source address is tracked.
func (e remoteEndpoint) SrcToString() string { return "" }

// DstToString formats the destination with netip.AddrPort.String
// ("ip:port", with IPv6 addresses in brackets).
func (e remoteEndpoint) DstToString() string {
	dst := netip.AddrPort(e)
	return dst.String()
}

// DstToBytes returns the binary encoding of the destination address and port.
func (e remoteEndpoint) DstToBytes() []byte {
	// netip.AddrPort.MarshalBinary's error is always nil in the standard
	// library implementation, so discarding it loses nothing.
	encoded, _ := netip.AddrPort(e).MarshalBinary()
	return encoded
}

// DstIP returns the destination address.
func (e remoteEndpoint) DstIP() netip.Addr {
	dst := netip.AddrPort(e)
	return dst.Addr()
}

// SrcIP returns the zero netip.Addr because no source address is tracked.
func (e remoteEndpoint) SrcIP() netip.Addr {
	var none netip.Addr
	return none
}