bind.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. package wireguard
  2. import (
  3. "context"
  4. "errors"
  5. "net/netip"
  6. "strconv"
  7. "sync"
  8. "golang.zx2c4.com/wireguard/conn"
  9. "github.com/xtls/xray-core/common/net"
  10. "github.com/xtls/xray-core/features/dns"
  11. "github.com/xtls/xray-core/transport/internet"
  12. )
  13. type netReadInfo struct {
  14. // status
  15. waiter sync.WaitGroup
  16. // param
  17. buff []byte
  18. // result
  19. bytes int
  20. endpoint conn.Endpoint
  21. err error
  22. }
  23. // reduce duplicated code
  24. type netBind struct {
  25. dns dns.Client
  26. dnsOption dns.IPOption
  27. workers int
  28. readQueue chan *netReadInfo
  29. }
  30. // SetMark implements conn.Bind
  31. func (bind *netBind) SetMark(mark uint32) error {
  32. return nil
  33. }
  34. // ParseEndpoint implements conn.Bind
  35. func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  36. ipStr, port, err := net.SplitHostPort(s)
  37. if err != nil {
  38. return nil, err
  39. }
  40. portNum, err := strconv.Atoi(port)
  41. if err != nil {
  42. return nil, err
  43. }
  44. addr := net.ParseAddress(ipStr)
  45. if addr.Family() == net.AddressFamilyDomain {
  46. ips, _, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
  47. if err != nil {
  48. return nil, err
  49. } else if len(ips) == 0 {
  50. return nil, dns.ErrEmptyResponse
  51. }
  52. addr = net.IPAddress(ips[0])
  53. }
  54. dst := net.Destination{
  55. Address: addr,
  56. Port: net.Port(portNum),
  57. Network: net.Network_UDP,
  58. }
  59. return &netEndpoint{
  60. dst: dst,
  61. }, nil
  62. }
  63. // BatchSize implements conn.Bind
  64. func (bind *netBind) BatchSize() int {
  65. return 1
  66. }
  67. // Open implements conn.Bind
  68. func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
  69. bind.readQueue = make(chan *netReadInfo)
  70. fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
  71. defer func() {
  72. if r := recover(); r != nil {
  73. n = 0
  74. err = errors.New("channel closed")
  75. }
  76. }()
  77. r := &netReadInfo{
  78. buff: bufs[0],
  79. }
  80. r.waiter.Add(1)
  81. bind.readQueue <- r
  82. r.waiter.Wait() // wait read goroutine done, or we will miss the result
  83. sizes[0], eps[0] = r.bytes, r.endpoint
  84. return 1, r.err
  85. }
  86. workers := bind.workers
  87. if workers <= 0 {
  88. workers = 1
  89. }
  90. arr := make([]conn.ReceiveFunc, workers)
  91. for i := 0; i < workers; i++ {
  92. arr[i] = fun
  93. }
  94. return arr, uint16(uport), nil
  95. }
  96. // Close implements conn.Bind
  97. func (bind *netBind) Close() error {
  98. if bind.readQueue != nil {
  99. close(bind.readQueue)
  100. }
  101. return nil
  102. }
  103. type netBindClient struct {
  104. netBind
  105. ctx context.Context
  106. dialer internet.Dialer
  107. reserved []byte
  108. // Track all peer connections for unified reading
  109. connMutex sync.RWMutex
  110. conns map[*netEndpoint]net.Conn
  111. dataChan chan *receivedData
  112. closeChan chan struct{}
  113. closeOnce sync.Once
  114. }
  115. const (
  116. // Buffer size for dataChan - allows some buffering of received packets
  117. // while dispatcher matches them with read requests
  118. dataChannelBufferSize = 100
  119. )
  120. type receivedData struct {
  121. data []byte
  122. n int
  123. endpoint *netEndpoint
  124. err error
  125. }
  126. func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
  127. c, err := bind.dialer.Dial(bind.ctx, endpoint.dst)
  128. if err != nil {
  129. return err
  130. }
  131. endpoint.conn = c
  132. // Initialize channels on first connection
  133. bind.connMutex.Lock()
  134. if bind.conns == nil {
  135. bind.conns = make(map[*netEndpoint]net.Conn)
  136. bind.dataChan = make(chan *receivedData, dataChannelBufferSize)
  137. bind.closeChan = make(chan struct{})
  138. // Start unified reader dispatcher
  139. go bind.unifiedReader()
  140. }
  141. bind.conns[endpoint] = c
  142. bind.connMutex.Unlock()
  143. // Start a reader goroutine for this specific connection
  144. go func(conn net.Conn, endpoint *netEndpoint) {
  145. const maxPacketSize = 1500
  146. for {
  147. select {
  148. case <-bind.closeChan:
  149. return
  150. default:
  151. }
  152. buf := make([]byte, maxPacketSize)
  153. n, err := conn.Read(buf)
  154. // Send only the valid data portion to dispatcher
  155. dataToSend := buf
  156. if n > 0 && n < len(buf) {
  157. dataToSend = buf[:n]
  158. }
  159. // Send received data to dispatcher
  160. select {
  161. case bind.dataChan <- &receivedData{
  162. data: dataToSend,
  163. n: n,
  164. endpoint: endpoint,
  165. err: err,
  166. }:
  167. case <-bind.closeChan:
  168. return
  169. }
  170. if err != nil {
  171. bind.connMutex.Lock()
  172. delete(bind.conns, endpoint)
  173. endpoint.conn = nil
  174. bind.connMutex.Unlock()
  175. return
  176. }
  177. }
  178. }(c, endpoint)
  179. return nil
  180. }
  181. // unifiedReader dispatches received data to waiting read requests
  182. func (bind *netBindClient) unifiedReader() {
  183. for {
  184. select {
  185. case data := <-bind.dataChan:
  186. // Bounds check to prevent panic
  187. if data.n > len(data.data) {
  188. data.n = len(data.data)
  189. }
  190. // Wait for a read request with timeout to prevent blocking forever
  191. select {
  192. case v := <-bind.readQueue:
  193. // Copy data to request buffer
  194. n := copy(v.buff, data.data[:data.n])
  195. // Clear reserved bytes if needed
  196. if n > 3 {
  197. v.buff[1] = 0
  198. v.buff[2] = 0
  199. v.buff[3] = 0
  200. }
  201. v.bytes = n
  202. v.endpoint = data.endpoint
  203. v.err = data.err
  204. v.waiter.Done()
  205. case <-bind.closeChan:
  206. return
  207. }
  208. case <-bind.closeChan:
  209. return
  210. }
  211. }
  212. }
  213. // Close implements conn.Bind.Close for netBindClient
  214. func (bind *netBindClient) Close() error {
  215. // Use sync.Once to prevent double-close panic
  216. bind.closeOnce.Do(func() {
  217. bind.connMutex.Lock()
  218. if bind.closeChan != nil {
  219. close(bind.closeChan)
  220. }
  221. bind.connMutex.Unlock()
  222. })
  223. // Call parent Close
  224. return bind.netBind.Close()
  225. }
  226. func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
  227. var err error
  228. nend, ok := endpoint.(*netEndpoint)
  229. if !ok {
  230. return conn.ErrWrongEndpointType
  231. }
  232. if nend.conn == nil {
  233. err = bind.connectTo(nend)
  234. if err != nil {
  235. return err
  236. }
  237. }
  238. for _, buff := range buff {
  239. if len(buff) > 3 && len(bind.reserved) == 3 {
  240. copy(buff[1:], bind.reserved)
  241. }
  242. if _, err = nend.conn.Write(buff); err != nil {
  243. return err
  244. }
  245. }
  246. return nil
  247. }
  248. type netBindServer struct {
  249. netBind
  250. }
  251. func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
  252. var err error
  253. nend, ok := endpoint.(*netEndpoint)
  254. if !ok {
  255. return conn.ErrWrongEndpointType
  256. }
  257. if nend.conn == nil {
  258. return errors.New("connection not open yet")
  259. }
  260. for _, buff := range buff {
  261. if _, err = nend.conn.Write(buff); err != nil {
  262. return err
  263. }
  264. }
  265. return err
  266. }
  267. type netEndpoint struct {
  268. dst net.Destination
  269. conn net.Conn
  270. }
  271. func (netEndpoint) ClearSrc() {}
  272. func (e netEndpoint) DstIP() netip.Addr {
  273. return netip.Addr{}
  274. }
  275. func (e netEndpoint) SrcIP() netip.Addr {
  276. return netip.Addr{}
  277. }
  278. func (e netEndpoint) DstToBytes() []byte {
  279. var dat []byte
  280. if e.dst.Address.Family().IsIPv4() {
  281. dat = e.dst.Address.IP().To4()[:]
  282. } else {
  283. dat = e.dst.Address.IP().To16()[:]
  284. }
  285. dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
  286. return dat
  287. }
  288. func (e netEndpoint) DstToString() string {
  289. return e.dst.NetAddr()
  290. }
  291. func (e netEndpoint) SrcToString() string {
  292. return ""
  293. }
  294. func toNetIpAddr(addr net.Address) netip.Addr {
  295. if addr.Family().IsIPv4() {
  296. ip := addr.IP()
  297. return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
  298. } else {
  299. ip := addr.IP()
  300. arr := [16]byte{}
  301. for i := 0; i < 16; i++ {
  302. arr[i] = ip[i]
  303. }
  304. return netip.AddrFrom16(arr)
  305. }
  306. }