client_bind.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. package wireguard
  2. import (
  3. "context"
  4. "net"
  5. "sync"
  6. "github.com/sagernet/sing/common"
  7. M "github.com/sagernet/sing/common/metadata"
  8. N "github.com/sagernet/sing/common/network"
  9. "github.com/sagernet/wireguard-go/conn"
  10. )
  11. var _ conn.Bind = (*ClientBind)(nil)
  12. type ClientBind struct {
  13. ctx context.Context
  14. dialer N.Dialer
  15. peerAddr M.Socksaddr
  16. reserved [3]uint8
  17. connAccess sync.Mutex
  18. conn *wireConn
  19. done chan struct{}
  20. }
  21. func NewClientBind(ctx context.Context, dialer N.Dialer, peerAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
  22. return &ClientBind{
  23. ctx: ctx,
  24. dialer: dialer,
  25. peerAddr: peerAddr,
  26. reserved: reserved,
  27. }
  28. }
  29. func (c *ClientBind) connect() (*wireConn, error) {
  30. serverConn := c.conn
  31. if serverConn != nil {
  32. select {
  33. case <-serverConn.done:
  34. serverConn = nil
  35. default:
  36. return serverConn, nil
  37. }
  38. }
  39. c.connAccess.Lock()
  40. defer c.connAccess.Unlock()
  41. serverConn = c.conn
  42. if serverConn != nil {
  43. select {
  44. case <-serverConn.done:
  45. serverConn = nil
  46. default:
  47. return serverConn, nil
  48. }
  49. }
  50. udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.peerAddr)
  51. if err != nil {
  52. return nil, &wireError{err}
  53. }
  54. c.conn = &wireConn{
  55. Conn: udpConn,
  56. done: make(chan struct{}),
  57. }
  58. return c.conn, nil
  59. }
  60. func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
  61. select {
  62. case <-c.done:
  63. err = net.ErrClosed
  64. return
  65. default:
  66. }
  67. return []conn.ReceiveFunc{c.receive}, 0, nil
  68. }
  69. func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
  70. udpConn, err := c.connect()
  71. if err != nil {
  72. err = &wireError{err}
  73. return
  74. }
  75. n, err = udpConn.Read(b)
  76. if err != nil {
  77. udpConn.Close()
  78. select {
  79. case <-c.done:
  80. default:
  81. err = &wireError{err}
  82. }
  83. return
  84. }
  85. if n > 3 {
  86. b[1] = 0
  87. b[2] = 0
  88. b[3] = 0
  89. }
  90. ep = Endpoint(c.peerAddr)
  91. return
  92. }
  93. func (c *ClientBind) Reset() {
  94. c.connAccess.Lock()
  95. defer c.connAccess.Unlock()
  96. common.Close(common.PtrOrNil(c.conn))
  97. }
  98. func (c *ClientBind) Close() error {
  99. c.connAccess.Lock()
  100. defer c.connAccess.Unlock()
  101. common.Close(common.PtrOrNil(c.conn))
  102. if c.done == nil {
  103. c.done = make(chan struct{})
  104. return nil
  105. }
  106. select {
  107. case <-c.done:
  108. return net.ErrClosed
  109. default:
  110. close(c.done)
  111. }
  112. return nil
  113. }
  114. func (c *ClientBind) SetMark(mark uint32) error {
  115. return nil
  116. }
  117. func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
  118. udpConn, err := c.connect()
  119. if err != nil {
  120. return err
  121. }
  122. if len(b) > 3 {
  123. b[1] = c.reserved[0]
  124. b[2] = c.reserved[1]
  125. b[3] = c.reserved[2]
  126. }
  127. _, err = udpConn.Write(b)
  128. if err != nil {
  129. udpConn.Close()
  130. }
  131. return err
  132. }
  133. func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  134. return Endpoint(c.peerAddr), nil
  135. }
  136. func (c *ClientBind) Endpoint() conn.Endpoint {
  137. return Endpoint(c.peerAddr)
  138. }
  // wireConn wraps a dialed net.Conn so that it is closed at most once and so
// that other code can observe the closure through done.
type wireConn struct {
	net.Conn
	// access serializes Close.
	access sync.Mutex
	// done is closed by Close once the underlying connection has been closed.
	done chan struct{}
}
  144. func (w *wireConn) Close() error {
  145. w.access.Lock()
  146. defer w.access.Unlock()
  147. select {
  148. case <-w.done:
  149. return net.ErrClosed
  150. default:
  151. }
  152. w.Conn.Close()
  153. close(w.done)
  154. return nil
  155. }