1
0

client_bind.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. package wireguard
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "sync"
  7. "time"
  8. "github.com/sagernet/sing/common"
  9. "github.com/sagernet/sing/common/bufio"
  10. E "github.com/sagernet/sing/common/exceptions"
  11. "github.com/sagernet/sing/common/logger"
  12. M "github.com/sagernet/sing/common/metadata"
  13. N "github.com/sagernet/sing/common/network"
  14. "github.com/sagernet/sing/service"
  15. "github.com/sagernet/sing/service/pause"
  16. "github.com/sagernet/wireguard-go/conn"
  17. )
  18. var _ conn.Bind = (*ClientBind)(nil)
  19. type ClientBind struct {
  20. ctx context.Context
  21. logger logger.Logger
  22. pauseManager pause.Manager
  23. bindCtx context.Context
  24. bindDone context.CancelFunc
  25. dialer N.Dialer
  26. reservedForEndpoint map[netip.AddrPort][3]uint8
  27. connAccess sync.Mutex
  28. conn *wireConn
  29. done chan struct{}
  30. isConnect bool
  31. connectAddr netip.AddrPort
  32. reserved [3]uint8
  33. }
  34. func NewClientBind(ctx context.Context, logger logger.Logger, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
  35. return &ClientBind{
  36. ctx: ctx,
  37. logger: logger,
  38. pauseManager: service.FromContext[pause.Manager](ctx),
  39. dialer: dialer,
  40. reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
  41. done: make(chan struct{}),
  42. isConnect: isConnect,
  43. connectAddr: connectAddr,
  44. reserved: reserved,
  45. }
  46. }
  47. func (c *ClientBind) connect() (*wireConn, error) {
  48. serverConn := c.conn
  49. if serverConn != nil {
  50. select {
  51. case <-serverConn.done:
  52. serverConn = nil
  53. default:
  54. return serverConn, nil
  55. }
  56. }
  57. c.connAccess.Lock()
  58. defer c.connAccess.Unlock()
  59. select {
  60. case <-c.done:
  61. return nil, net.ErrClosed
  62. default:
  63. }
  64. serverConn = c.conn
  65. if serverConn != nil {
  66. select {
  67. case <-serverConn.done:
  68. serverConn = nil
  69. default:
  70. return serverConn, nil
  71. }
  72. }
  73. if c.isConnect {
  74. udpConn, err := c.dialer.DialContext(c.bindCtx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
  75. if err != nil {
  76. return nil, err
  77. }
  78. c.conn = &wireConn{
  79. PacketConn: bufio.NewUnbindPacketConn(udpConn),
  80. done: make(chan struct{}),
  81. }
  82. } else {
  83. udpConn, err := c.dialer.ListenPacket(c.bindCtx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
  84. if err != nil {
  85. return nil, err
  86. }
  87. c.conn = &wireConn{
  88. PacketConn: bufio.NewPacketConn(udpConn),
  89. done: make(chan struct{}),
  90. }
  91. }
  92. return c.conn, nil
  93. }
  94. func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
  95. select {
  96. case <-c.done:
  97. c.done = make(chan struct{})
  98. default:
  99. }
  100. c.bindCtx, c.bindDone = context.WithCancel(c.ctx)
  101. return []conn.ReceiveFunc{c.receive}, 0, nil
  102. }
  103. func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
  104. udpConn, err := c.connect()
  105. if err != nil {
  106. select {
  107. case <-c.done:
  108. return
  109. default:
  110. }
  111. c.logger.Error(E.Cause(err, "connect to server"))
  112. err = nil
  113. c.pauseManager.WaitActive()
  114. time.Sleep(time.Second)
  115. return
  116. }
  117. n, addr, err := udpConn.ReadFrom(packets[0])
  118. if err != nil {
  119. udpConn.Close()
  120. select {
  121. case <-c.done:
  122. default:
  123. c.logger.Error(E.Cause(err, "read packet"))
  124. err = nil
  125. }
  126. return
  127. }
  128. sizes[0] = n
  129. if n > 3 {
  130. b := packets[0]
  131. common.ClearArray(b[1:4])
  132. }
  133. eps[0] = remoteEndpoint(M.SocksaddrFromNet(addr).Unwrap().AddrPort())
  134. count = 1
  135. return
  136. }
  137. func (c *ClientBind) Close() error {
  138. select {
  139. case <-c.done:
  140. default:
  141. close(c.done)
  142. }
  143. if c.bindDone != nil {
  144. c.bindDone()
  145. }
  146. c.connAccess.Lock()
  147. defer c.connAccess.Unlock()
  148. common.Close(common.PtrOrNil(c.conn))
  149. return nil
  150. }
  151. func (c *ClientBind) SetMark(mark uint32) error {
  152. return nil
  153. }
  154. func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint, offset int) error {
  155. udpConn, err := c.connect()
  156. if err != nil {
  157. c.pauseManager.WaitActive()
  158. time.Sleep(time.Second)
  159. return err
  160. }
  161. destination := netip.AddrPort(ep.(remoteEndpoint))
  162. for _, buf := range bufs {
  163. if offset > 0 {
  164. buf = buf[offset:]
  165. }
  166. if len(buf) > 3 {
  167. reserved, loaded := c.reservedForEndpoint[destination]
  168. if !loaded {
  169. reserved = c.reserved
  170. }
  171. copy(buf[1:4], reserved[:])
  172. }
  173. _, err = udpConn.WriteToUDPAddrPort(buf, destination)
  174. if err != nil {
  175. udpConn.Close()
  176. return err
  177. }
  178. }
  179. return nil
  180. }
  181. func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  182. ap, err := netip.ParseAddrPort(s)
  183. if err != nil {
  184. return nil, err
  185. }
  186. return remoteEndpoint(ap), nil
  187. }
  188. func (c *ClientBind) BatchSize() int {
  189. return 1
  190. }
  191. func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) {
  192. c.reservedForEndpoint[destination] = reserved
  193. }
  // wireConn is one server connection owned by ClientBind. The embedded
// PacketConn is used for reads, and also for writes whenever conn is nil.
type wireConn struct {
	net.PacketConn

	// conn, when non-nil, is written to directly by WriteToUDPAddrPort,
	// and the destination address is ignored.
	conn net.Conn

	// done is closed by Close; access makes that check-and-close atomic so
	// that it happens only once.
	done   chan struct{}
	access sync.Mutex
}
  200. func (w *wireConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
  201. if w.conn != nil {
  202. return w.conn.Write(b)
  203. }
  204. return w.PacketConn.WriteTo(b, M.SocksaddrFromNetIP(addr).UDPAddr())
  205. }
  206. func (w *wireConn) Close() error {
  207. w.access.Lock()
  208. defer w.access.Unlock()
  209. select {
  210. case <-w.done:
  211. return net.ErrClosed
  212. default:
  213. }
  214. w.PacketConn.Close()
  215. close(w.done)
  216. return nil
  217. }
  218. var _ conn.Endpoint = (*remoteEndpoint)(nil)
  // remoteEndpoint is the conn.Endpoint used by ClientBind. It wraps only a
// destination address and port and never records a source address.
type remoteEndpoint netip.AddrPort

// ClearSrc is a no-op because there is no source address to clear.
func (e remoteEndpoint) ClearSrc() {}

// SrcToString always returns the empty string, since no source is kept.
func (e remoteEndpoint) SrcToString() string {
	return ""
}

// DstToString formats the destination using netip.AddrPort.String.
func (e remoteEndpoint) DstToString() string {
	dst := netip.AddrPort(e)
	return dst.String()
}

// DstToBytes returns the destination's netip.AddrPort.MarshalBinary
// encoding. The marshal error is discarded.
func (e remoteEndpoint) DstToBytes() []byte {
	dst := netip.AddrPort(e)
	encoded, _ := dst.MarshalBinary()
	return encoded
}

// DstIP returns the destination IP address.
func (e remoteEndpoint) DstIP() netip.Addr {
	return netip.AddrPort(e).Addr()
}

// SrcIP always returns the zero (invalid) netip.Addr.
func (e remoteEndpoint) SrcIP() netip.Addr {
	var none netip.Addr
	return none
}