bind.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. package wireguard
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "net"
  7. "net/netip"
  8. "strconv"
  9. "sync"
  10. "golang.zx2c4.com/wireguard/conn"
  11. xnet "github.com/xtls/xray-core/common/net"
  12. "github.com/xtls/xray-core/features/dns"
  13. "github.com/xtls/xray-core/transport/internet"
  14. )
  15. type netReadInfo struct {
  16. // status
  17. waiter sync.WaitGroup
  18. // param
  19. buff []byte
  20. // result
  21. bytes int
  22. endpoint conn.Endpoint
  23. err error
  24. }
  25. // reduce duplicated code
  26. type netBind struct {
  27. dns dns.Client
  28. dnsOption dns.IPOption
  29. workers int
  30. readQueue chan *netReadInfo
  31. }
  32. // SetMark implements conn.Bind
  33. func (bind *netBind) SetMark(mark uint32) error {
  34. return nil
  35. }
  36. // ParseEndpoint implements conn.Bind
  37. func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  38. ipStr, port, err := net.SplitHostPort(s)
  39. if err != nil {
  40. return nil, err
  41. }
  42. portNum, err := strconv.Atoi(port)
  43. if err != nil {
  44. return nil, err
  45. }
  46. addr := xnet.ParseAddress(ipStr)
  47. if addr.Family() == xnet.AddressFamilyDomain {
  48. ips, _, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
  49. if err != nil {
  50. return nil, err
  51. } else if len(ips) == 0 {
  52. return nil, dns.ErrEmptyResponse
  53. }
  54. addr = xnet.IPAddress(ips[0])
  55. }
  56. dst := xnet.Destination{
  57. Address: addr,
  58. Port: xnet.Port(portNum),
  59. Network: xnet.Network_UDP,
  60. }
  61. return &netEndpoint{
  62. dst: dst,
  63. }, nil
  64. }
  65. // BatchSize implements conn.Bind
  66. func (bind *netBind) BatchSize() int {
  67. return 1
  68. }
  69. // Open implements conn.Bind
  70. func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
  71. bind.readQueue = make(chan *netReadInfo)
  72. fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
  73. defer func() {
  74. if r := recover(); r != nil {
  75. n = 0
  76. err = errors.New("channel closed")
  77. }
  78. }()
  79. r := &netReadInfo{
  80. buff: bufs[0],
  81. }
  82. r.waiter.Add(1)
  83. bind.readQueue <- r
  84. r.waiter.Wait() // wait read goroutine done, or we will miss the result
  85. sizes[0], eps[0] = r.bytes, r.endpoint
  86. return 1, r.err
  87. }
  88. workers := bind.workers
  89. if workers <= 0 {
  90. workers = 1
  91. }
  92. arr := make([]conn.ReceiveFunc, workers)
  93. for i := 0; i < workers; i++ {
  94. arr[i] = fun
  95. }
  96. return arr, uint16(uport), nil
  97. }
  98. // Close implements conn.Bind
  99. func (bind *netBind) Close() error {
  100. if bind.readQueue != nil {
  101. close(bind.readQueue)
  102. }
  103. return nil
  104. }
  105. type netBindClient struct {
  106. netBind
  107. ctx context.Context
  108. dialer internet.Dialer
  109. reserved []byte
  110. }
  111. func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
  112. c, err := bind.dialer.Dial(bind.ctx, endpoint.dst)
  113. if err != nil {
  114. return err
  115. }
  116. endpoint.conn = c
  117. go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
  118. for {
  119. v, ok := <-readQueue
  120. if !ok {
  121. return
  122. }
  123. i, err := c.Read(v.buff)
  124. if i > 3 {
  125. v.buff[1] = 0
  126. v.buff[2] = 0
  127. v.buff[3] = 0
  128. }
  129. v.bytes = i
  130. v.endpoint = endpoint
  131. v.err = err
  132. v.waiter.Done()
  133. if err != nil && errors.Is(err, io.EOF) {
  134. endpoint.conn = nil
  135. return
  136. }
  137. }
  138. }(bind.readQueue, endpoint)
  139. return nil
  140. }
  141. func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
  142. var err error
  143. nend, ok := endpoint.(*netEndpoint)
  144. if !ok {
  145. return conn.ErrWrongEndpointType
  146. }
  147. if nend.conn == nil {
  148. err = bind.connectTo(nend)
  149. if err != nil {
  150. return err
  151. }
  152. }
  153. for _, buff := range buff {
  154. if len(buff) > 3 && len(bind.reserved) == 3 {
  155. copy(buff[1:], bind.reserved)
  156. }
  157. if _, err = nend.conn.Write(buff); err != nil {
  158. return err
  159. }
  160. }
  161. return nil
  162. }
  163. type netBindServer struct {
  164. netBind
  165. }
  166. func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
  167. var err error
  168. nend, ok := endpoint.(*netEndpoint)
  169. if !ok {
  170. return conn.ErrWrongEndpointType
  171. }
  172. if nend.conn == nil {
  173. return errors.New("connection not open yet")
  174. }
  175. for _, buff := range buff {
  176. if _, err = nend.conn.Write(buff); err != nil {
  177. return err
  178. }
  179. }
  180. return err
  181. }
  182. type netEndpoint struct {
  183. dst xnet.Destination
  184. conn net.Conn
  185. }
  186. func (netEndpoint) ClearSrc() {}
  187. func (e netEndpoint) DstIP() netip.Addr {
  188. return netip.Addr{}
  189. }
  190. func (e netEndpoint) SrcIP() netip.Addr {
  191. return netip.Addr{}
  192. }
  193. func (e netEndpoint) DstToBytes() []byte {
  194. var dat []byte
  195. if e.dst.Address.Family().IsIPv4() {
  196. dat = e.dst.Address.IP().To4()[:]
  197. } else {
  198. dat = e.dst.Address.IP().To16()[:]
  199. }
  200. dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
  201. return dat
  202. }
  203. func (e netEndpoint) DstToString() string {
  204. return e.dst.NetAddr()
  205. }
  206. func (e netEndpoint) SrcToString() string {
  207. return ""
  208. }
  209. func toNetIpAddr(addr xnet.Address) netip.Addr {
  210. if addr.Family().IsIPv4() {
  211. ip := addr.IP()
  212. return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
  213. } else {
  214. ip := addr.IP()
  215. arr := [16]byte{}
  216. for i := 0; i < 16; i++ {
  217. arr[i] = ip[i]
  218. }
  219. return netip.AddrFrom16(arr)
  220. }
  221. }