bind.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. package wireguard
  2. import (
  3. "context"
  4. "errors"
  5. "net/netip"
  6. "strconv"
  7. "sync"
  8. "golang.zx2c4.com/wireguard/conn"
  9. "github.com/xtls/xray-core/common/net"
  10. "github.com/xtls/xray-core/features/dns"
  11. "github.com/xtls/xray-core/transport/internet"
  12. )
  13. type netReadInfo struct {
  14. // status
  15. waiter sync.WaitGroup
  16. // param
  17. buff []byte
  18. // result
  19. bytes int
  20. endpoint conn.Endpoint
  21. err error
  22. }
  23. // reduce duplicated code
  24. type netBind struct {
  25. dns dns.Client
  26. dnsOption dns.IPOption
  27. workers int
  28. readQueue chan *netReadInfo
  29. }
  30. // SetMark implements conn.Bind
  31. func (bind *netBind) SetMark(mark uint32) error {
  32. return nil
  33. }
  34. // ParseEndpoint implements conn.Bind
  35. func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  36. ipStr, port, err := net.SplitHostPort(s)
  37. if err != nil {
  38. return nil, err
  39. }
  40. portNum, err := strconv.Atoi(port)
  41. if err != nil {
  42. return nil, err
  43. }
  44. addr := net.ParseAddress(ipStr)
  45. if addr.Family() == net.AddressFamilyDomain {
  46. ips, _, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
  47. if err != nil {
  48. return nil, err
  49. } else if len(ips) == 0 {
  50. return nil, dns.ErrEmptyResponse
  51. }
  52. addr = net.IPAddress(ips[0])
  53. }
  54. dst := net.Destination{
  55. Address: addr,
  56. Port: net.Port(portNum),
  57. Network: net.Network_UDP,
  58. }
  59. return &netEndpoint{
  60. dst: dst,
  61. }, nil
  62. }
  63. // BatchSize implements conn.Bind
  64. func (bind *netBind) BatchSize() int {
  65. return 1
  66. }
  67. // Open implements conn.Bind
  68. func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
  69. bind.readQueue = make(chan *netReadInfo)
  70. fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
  71. defer func() {
  72. if r := recover(); r != nil {
  73. n = 0
  74. err = errors.New("channel closed")
  75. }
  76. }()
  77. r := &netReadInfo{
  78. buff: bufs[0],
  79. }
  80. r.waiter.Add(1)
  81. bind.readQueue <- r
  82. r.waiter.Wait() // wait read goroutine done, or we will miss the result
  83. sizes[0], eps[0] = r.bytes, r.endpoint
  84. return 1, r.err
  85. }
  86. workers := bind.workers
  87. if workers <= 0 {
  88. workers = 1
  89. }
  90. arr := make([]conn.ReceiveFunc, workers)
  91. for i := 0; i < workers; i++ {
  92. arr[i] = fun
  93. }
  94. return arr, uint16(uport), nil
  95. }
  96. // Close implements conn.Bind
  97. func (bind *netBind) Close() error {
  98. if bind.readQueue != nil {
  99. close(bind.readQueue)
  100. }
  101. return nil
  102. }
  103. type netBindClient struct {
  104. netBind
  105. ctx context.Context
  106. dialer internet.Dialer
  107. reserved []byte
  108. }
  109. func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
  110. c, err := bind.dialer.Dial(bind.ctx, endpoint.dst)
  111. if err != nil {
  112. return err
  113. }
  114. endpoint.conn = c
  115. go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
  116. for {
  117. v, ok := <-readQueue
  118. if !ok {
  119. return
  120. }
  121. i, err := c.Read(v.buff)
  122. if i > 3 {
  123. v.buff[1] = 0
  124. v.buff[2] = 0
  125. v.buff[3] = 0
  126. }
  127. v.bytes = i
  128. v.endpoint = endpoint
  129. v.err = err
  130. v.waiter.Done()
  131. if err != nil {
  132. endpoint.conn = nil
  133. return
  134. }
  135. }
  136. }(bind.readQueue, endpoint)
  137. return nil
  138. }
  139. func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
  140. var err error
  141. nend, ok := endpoint.(*netEndpoint)
  142. if !ok {
  143. return conn.ErrWrongEndpointType
  144. }
  145. if nend.conn == nil {
  146. err = bind.connectTo(nend)
  147. if err != nil {
  148. return err
  149. }
  150. }
  151. for _, buff := range buff {
  152. if len(buff) > 3 && len(bind.reserved) == 3 {
  153. copy(buff[1:], bind.reserved)
  154. }
  155. if _, err = nend.conn.Write(buff); err != nil {
  156. return err
  157. }
  158. }
  159. return nil
  160. }
  161. type netBindServer struct {
  162. netBind
  163. }
  164. func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
  165. var err error
  166. nend, ok := endpoint.(*netEndpoint)
  167. if !ok {
  168. return conn.ErrWrongEndpointType
  169. }
  170. if nend.conn == nil {
  171. return errors.New("connection not open yet")
  172. }
  173. for _, buff := range buff {
  174. if _, err = nend.conn.Write(buff); err != nil {
  175. return err
  176. }
  177. }
  178. return err
  179. }
  180. type netEndpoint struct {
  181. dst net.Destination
  182. conn net.Conn
  183. }
  184. func (netEndpoint) ClearSrc() {}
  185. func (e netEndpoint) DstIP() netip.Addr {
  186. return netip.Addr{}
  187. }
  188. func (e netEndpoint) SrcIP() netip.Addr {
  189. return netip.Addr{}
  190. }
  191. func (e netEndpoint) DstToBytes() []byte {
  192. var dat []byte
  193. if e.dst.Address.Family().IsIPv4() {
  194. dat = e.dst.Address.IP().To4()[:]
  195. } else {
  196. dat = e.dst.Address.IP().To16()[:]
  197. }
  198. dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
  199. return dat
  200. }
  201. func (e netEndpoint) DstToString() string {
  202. return e.dst.NetAddr()
  203. }
  204. func (e netEndpoint) SrcToString() string {
  205. return ""
  206. }
  207. func toNetIpAddr(addr net.Address) netip.Addr {
  208. if addr.Family().IsIPv4() {
  209. ip := addr.IP()
  210. return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
  211. } else {
  212. ip := addr.IP()
  213. arr := [16]byte{}
  214. for i := 0; i < 16; i++ {
  215. arr[i] = ip[i]
  216. }
  217. return netip.AddrFrom16(arr)
  218. }
  219. }