client_bind.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package wireguard
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "sync"
  7. "github.com/sagernet/sing/common"
  8. "github.com/sagernet/sing/common/buf"
  9. "github.com/sagernet/sing/common/bufio"
  10. M "github.com/sagernet/sing/common/metadata"
  11. N "github.com/sagernet/sing/common/network"
  12. "github.com/sagernet/wireguard-go/conn"
  13. )
  14. var _ conn.Bind = (*ClientBind)(nil)
  15. type ClientBind struct {
  16. ctx context.Context
  17. dialer N.Dialer
  18. reservedForEndpoint map[M.Socksaddr][3]uint8
  19. connAccess sync.Mutex
  20. conn *wireConn
  21. done chan struct{}
  22. isConnect bool
  23. connectAddr M.Socksaddr
  24. reserved [3]uint8
  25. }
  26. func NewClientBind(ctx context.Context, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
  27. return &ClientBind{
  28. ctx: ctx,
  29. dialer: dialer,
  30. reservedForEndpoint: make(map[M.Socksaddr][3]uint8),
  31. isConnect: isConnect,
  32. connectAddr: connectAddr,
  33. reserved: reserved,
  34. }
  35. }
  36. func (c *ClientBind) SetReservedForEndpoint(destination M.Socksaddr, reserved [3]byte) {
  37. c.reservedForEndpoint[destination] = reserved
  38. }
  39. func (c *ClientBind) connect() (*wireConn, error) {
  40. serverConn := c.conn
  41. if serverConn != nil {
  42. select {
  43. case <-serverConn.done:
  44. serverConn = nil
  45. default:
  46. return serverConn, nil
  47. }
  48. }
  49. c.connAccess.Lock()
  50. defer c.connAccess.Unlock()
  51. serverConn = c.conn
  52. if serverConn != nil {
  53. select {
  54. case <-serverConn.done:
  55. serverConn = nil
  56. default:
  57. return serverConn, nil
  58. }
  59. }
  60. if c.isConnect {
  61. udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr)
  62. if err != nil {
  63. return nil, &wireError{err}
  64. }
  65. c.conn = &wireConn{
  66. NetPacketConn: &bufio.UnbindPacketConn{
  67. ExtendedConn: bufio.NewExtendedConn(udpConn),
  68. Addr: c.connectAddr,
  69. },
  70. done: make(chan struct{}),
  71. }
  72. } else {
  73. udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
  74. if err != nil {
  75. return nil, &wireError{err}
  76. }
  77. c.conn = &wireConn{
  78. NetPacketConn: bufio.NewPacketConn(udpConn),
  79. done: make(chan struct{}),
  80. }
  81. }
  82. return c.conn, nil
  83. }
  84. func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
  85. select {
  86. case <-c.done:
  87. err = net.ErrClosed
  88. return
  89. default:
  90. }
  91. return []conn.ReceiveFunc{c.receive}, 0, nil
  92. }
  93. func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
  94. udpConn, err := c.connect()
  95. if err != nil {
  96. err = &wireError{err}
  97. return
  98. }
  99. buffer := buf.With(b)
  100. destination, err := udpConn.ReadPacket(buffer)
  101. if err != nil {
  102. udpConn.Close()
  103. select {
  104. case <-c.done:
  105. default:
  106. err = &wireError{err}
  107. }
  108. return
  109. }
  110. n = buffer.Len()
  111. if buffer.Start() > 0 {
  112. copy(b, buffer.Bytes())
  113. }
  114. if n > 3 {
  115. b[1] = 0
  116. b[2] = 0
  117. b[3] = 0
  118. }
  119. ep = Endpoint(destination)
  120. return
  121. }
  122. func (c *ClientBind) Reset() {
  123. common.Close(common.PtrOrNil(c.conn))
  124. }
  125. func (c *ClientBind) Close() error {
  126. common.Close(common.PtrOrNil(c.conn))
  127. if c.done == nil {
  128. c.done = make(chan struct{})
  129. return nil
  130. }
  131. select {
  132. case <-c.done:
  133. return net.ErrClosed
  134. default:
  135. close(c.done)
  136. }
  137. return nil
  138. }
  139. func (c *ClientBind) SetMark(mark uint32) error {
  140. return nil
  141. }
  142. func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
  143. udpConn, err := c.connect()
  144. if err != nil {
  145. return err
  146. }
  147. destination := M.Socksaddr(ep.(Endpoint))
  148. if len(b) > 3 {
  149. reserved, loaded := c.reservedForEndpoint[destination]
  150. if !loaded {
  151. reserved = c.reserved
  152. }
  153. b[1] = reserved[0]
  154. b[2] = reserved[1]
  155. b[3] = reserved[2]
  156. }
  157. err = udpConn.WritePacket(buf.As(b), destination)
  158. if err != nil {
  159. udpConn.Close()
  160. }
  161. return err
  162. }
  163. func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  164. return Endpoint(M.ParseSocksaddr(s)), nil
  165. }
  166. type wireConn struct {
  167. N.NetPacketConn
  168. access sync.Mutex
  169. done chan struct{}
  170. }
  171. func (w *wireConn) Close() error {
  172. w.access.Lock()
  173. defer w.access.Unlock()
  174. select {
  175. case <-w.done:
  176. return net.ErrClosed
  177. default:
  178. }
  179. w.NetPacketConn.Close()
  180. close(w.done)
  181. return nil
  182. }