client_bind.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. package wireguard
  2. import (
  3. "context"
  4. "net"
  5. "sync"
  6. "github.com/sagernet/sing/common"
  7. M "github.com/sagernet/sing/common/metadata"
  8. N "github.com/sagernet/sing/common/network"
  9. "golang.zx2c4.com/wireguard/conn"
  10. )
  11. var _ conn.Bind = (*ClientBind)(nil)
  12. type ClientBind struct {
  13. ctx context.Context
  14. dialer N.Dialer
  15. peerAddr M.Socksaddr
  16. connAccess sync.Mutex
  17. conn *wireConn
  18. }
  19. func NewClientBind(ctx context.Context, dialer N.Dialer, peerAddr M.Socksaddr) *ClientBind {
  20. return &ClientBind{
  21. ctx: ctx,
  22. dialer: dialer,
  23. peerAddr: peerAddr,
  24. }
  25. }
  26. func (c *ClientBind) connect() (*wireConn, error) {
  27. serverConn := c.conn
  28. if serverConn != nil {
  29. select {
  30. case <-serverConn.done:
  31. serverConn = nil
  32. default:
  33. return serverConn, nil
  34. }
  35. }
  36. c.connAccess.Lock()
  37. defer c.connAccess.Unlock()
  38. serverConn = c.conn
  39. if serverConn != nil {
  40. select {
  41. case <-serverConn.done:
  42. serverConn = nil
  43. default:
  44. return serverConn, nil
  45. }
  46. }
  47. udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.peerAddr)
  48. if err != nil {
  49. return nil, &wireError{err}
  50. }
  51. c.conn = &wireConn{
  52. Conn: udpConn,
  53. done: make(chan struct{}),
  54. }
  55. return c.conn, nil
  56. }
  57. func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
  58. return []conn.ReceiveFunc{c.receive}, 0, nil
  59. }
  60. func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
  61. udpConn, err := c.connect()
  62. if err != nil {
  63. err = &wireError{err}
  64. return
  65. }
  66. n, err = udpConn.Read(b)
  67. if err != nil {
  68. udpConn.Close()
  69. err = &wireError{err}
  70. }
  71. ep = Endpoint(c.peerAddr)
  72. return
  73. }
  74. func (c *ClientBind) Close() error {
  75. c.connAccess.Lock()
  76. defer c.connAccess.Unlock()
  77. common.Close(common.PtrOrNil(c.conn))
  78. return nil
  79. }
  80. func (c *ClientBind) SetMark(mark uint32) error {
  81. return nil
  82. }
  83. func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
  84. udpConn, err := c.connect()
  85. if err != nil {
  86. return err
  87. }
  88. _, err = udpConn.Write(b)
  89. if err != nil {
  90. udpConn.Close()
  91. }
  92. return err
  93. }
  94. func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  95. return Endpoint(c.peerAddr), nil
  96. }
  97. func (c *ClientBind) Endpoint() conn.Endpoint {
  98. return Endpoint(c.peerAddr)
  99. }
  // wireConn wraps a dialed net.Conn and records whether it has been closed.
  // done is closed exactly once, by the first Close call.
  type wireConn struct {
  	net.Conn
  	// access serializes Close so that done is closed only once.
  	access sync.Mutex
  	// done is closed once Close has run; readers select on it to detect a
  	// dead connection without taking access.
  	done chan struct{}
  }

  // Close closes the underlying connection and then signals done.
  // It returns the underlying Close error, which was previously discarded.
  // On second and later calls it returns net.ErrClosed without touching the
  // connection again. done is closed even if the underlying Close fails, so
  // the wrapper is never left looking alive after Close.
  func (w *wireConn) Close() error {
  	w.access.Lock()
  	defer w.access.Unlock()
  	select {
  	case <-w.done:
  		return net.ErrClosed
  	default:
  	}
  	err := w.Conn.Close()
  	close(w.done)
  	return err
  }