bind.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. package wireguard
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "net"
  7. "net/netip"
  8. "strconv"
  9. "sync"
  10. "golang.zx2c4.com/wireguard/conn"
  11. xnet "github.com/xtls/xray-core/common/net"
  12. "github.com/xtls/xray-core/features/dns"
  13. "github.com/xtls/xray-core/transport/internet"
  14. )
  15. type netReadInfo struct {
  16. // status
  17. waiter sync.WaitGroup
  18. // param
  19. buff []byte
  20. // result
  21. bytes int
  22. endpoint conn.Endpoint
  23. err error
  24. }
  25. // reduce duplicated code
  26. type netBind struct {
  27. dns dns.Client
  28. dnsOption dns.IPOption
  29. workers int
  30. readQueue chan *netReadInfo
  31. }
  32. // SetMark implements conn.Bind
  33. func (bind *netBind) SetMark(mark uint32) error {
  34. return nil
  35. }
  36. // ParseEndpoint implements conn.Bind
  37. func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  38. ipStr, port, err := net.SplitHostPort(s)
  39. if err != nil {
  40. return nil, err
  41. }
  42. portNum, err := strconv.Atoi(port)
  43. if err != nil {
  44. return nil, err
  45. }
  46. addr := xnet.ParseAddress(ipStr)
  47. if addr.Family() == xnet.AddressFamilyDomain {
  48. ips, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
  49. if err != nil {
  50. return nil, err
  51. } else if len(ips) == 0 {
  52. return nil, dns.ErrEmptyResponse
  53. }
  54. addr = xnet.IPAddress(ips[0])
  55. }
  56. dst := xnet.Destination{
  57. Address: addr,
  58. Port: xnet.Port(portNum),
  59. Network: xnet.Network_UDP,
  60. }
  61. return &netEndpoint{
  62. dst: dst,
  63. }, nil
  64. }
  65. // BatchSize implements conn.Bind
  66. func (bind *netBind) BatchSize() int {
  67. return 1
  68. }
  69. // Open implements conn.Bind
  70. func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
  71. bind.readQueue = make(chan *netReadInfo)
  72. fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
  73. defer func() {
  74. if r := recover(); r != nil {
  75. n = 0
  76. err = errors.New("channel closed")
  77. }
  78. }()
  79. r := &netReadInfo{
  80. buff: bufs[0],
  81. }
  82. r.waiter.Add(1)
  83. bind.readQueue <- r
  84. r.waiter.Wait() // wait read goroutine done, or we will miss the result
  85. sizes[0], eps[0] = r.bytes, r.endpoint
  86. return 1, r.err
  87. }
  88. workers := bind.workers
  89. if workers <= 0 {
  90. workers = 1
  91. }
  92. arr := make([]conn.ReceiveFunc, workers)
  93. for i := 0; i < workers; i++ {
  94. arr[i] = fun
  95. }
  96. return arr, uint16(uport), nil
  97. }
  98. // Close implements conn.Bind
  99. func (bind *netBind) Close() error {
  100. if bind.readQueue != nil {
  101. close(bind.readQueue)
  102. }
  103. return nil
  104. }
  105. type netBindClient struct {
  106. netBind
  107. dialer internet.Dialer
  108. reserved []byte
  109. }
  110. func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
  111. c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
  112. if err != nil {
  113. return err
  114. }
  115. endpoint.conn = c
  116. go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
  117. for {
  118. v, ok := <-readQueue
  119. if !ok {
  120. return
  121. }
  122. i, err := c.Read(v.buff)
  123. if i > 3 {
  124. v.buff[1] = 0
  125. v.buff[2] = 0
  126. v.buff[3] = 0
  127. }
  128. v.bytes = i
  129. v.endpoint = endpoint
  130. v.err = err
  131. v.waiter.Done()
  132. if err != nil && errors.Is(err, io.EOF) {
  133. endpoint.conn = nil
  134. return
  135. }
  136. }
  137. }(bind.readQueue, endpoint)
  138. return nil
  139. }
  140. func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
  141. var err error
  142. nend, ok := endpoint.(*netEndpoint)
  143. if !ok {
  144. return conn.ErrWrongEndpointType
  145. }
  146. if nend.conn == nil {
  147. err = bind.connectTo(nend)
  148. if err != nil {
  149. return err
  150. }
  151. }
  152. for _, buff := range buff {
  153. if len(buff) > 3 && len(bind.reserved) == 3 {
  154. copy(buff[1:], bind.reserved)
  155. }
  156. if _, err = nend.conn.Write(buff); err != nil {
  157. return err
  158. }
  159. }
  160. return nil
  161. }
  162. type netBindServer struct {
  163. netBind
  164. }
  165. func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
  166. var err error
  167. nend, ok := endpoint.(*netEndpoint)
  168. if !ok {
  169. return conn.ErrWrongEndpointType
  170. }
  171. if nend.conn == nil {
  172. return newError("connection not open yet")
  173. }
  174. for _, buff := range buff {
  175. if _, err = nend.conn.Write(buff); err != nil {
  176. return err
  177. }
  178. }
  179. return err
  180. }
  181. type netEndpoint struct {
  182. dst xnet.Destination
  183. conn net.Conn
  184. }
  185. func (netEndpoint) ClearSrc() {}
  186. func (e netEndpoint) DstIP() netip.Addr {
  187. return netip.Addr{}
  188. }
  189. func (e netEndpoint) SrcIP() netip.Addr {
  190. return netip.Addr{}
  191. }
  192. func (e netEndpoint) DstToBytes() []byte {
  193. var dat []byte
  194. if e.dst.Address.Family().IsIPv4() {
  195. dat = e.dst.Address.IP().To4()[:]
  196. } else {
  197. dat = e.dst.Address.IP().To16()[:]
  198. }
  199. dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
  200. return dat
  201. }
  202. func (e netEndpoint) DstToString() string {
  203. return e.dst.NetAddr()
  204. }
  205. func (e netEndpoint) SrcToString() string {
  206. return ""
  207. }
  208. func toNetIpAddr(addr xnet.Address) netip.Addr {
  209. if addr.Family().IsIPv4() {
  210. ip := addr.IP()
  211. return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
  212. } else {
  213. ip := addr.IP()
  214. arr := [16]byte{}
  215. for i := 0; i < 16; i++ {
  216. arr[i] = ip[i]
  217. }
  218. return netip.AddrFrom16(arr)
  219. }
  220. }