bind.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. package wireguard
import (
	"context"
	"errors"
	gonet "net"
	"net/netip"
	"strconv"
	"sync"

	"github.com/xtls/xray-core/common/net"
	"github.com/xtls/xray-core/features/dns"
	"github.com/xtls/xray-core/transport/internet"
	"golang.zx2c4.com/wireguard/conn"
)
  13. type netReadInfo struct {
  14. // status
  15. waiter sync.WaitGroup
  16. // param
  17. buff []byte
  18. // result
  19. bytes int
  20. endpoint conn.Endpoint
  21. err error
  22. }
  23. // reduce duplicated code
  24. type netBind struct {
  25. dns dns.Client
  26. dnsOption dns.IPOption
  27. workers int
  28. readQueue chan *netReadInfo
  29. }
  30. // SetMark implements conn.Bind
  31. func (bind *netBind) SetMark(mark uint32) error {
  32. return nil
  33. }
  34. // ParseEndpoint implements conn.Bind
  35. func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  36. ipStr, port, err := net.SplitHostPort(s)
  37. if err != nil {
  38. return nil, err
  39. }
  40. portNum, err := strconv.Atoi(port)
  41. if err != nil {
  42. return nil, err
  43. }
  44. addr := net.ParseAddress(ipStr)
  45. if addr.Family() == net.AddressFamilyDomain {
  46. ips, _, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
  47. if err != nil {
  48. return nil, err
  49. } else if len(ips) == 0 {
  50. return nil, dns.ErrEmptyResponse
  51. }
  52. addr = net.IPAddress(ips[0])
  53. }
  54. dst := net.Destination{
  55. Address: addr,
  56. Port: net.Port(portNum),
  57. Network: net.Network_UDP,
  58. }
  59. return &netEndpoint{
  60. dst: dst,
  61. }, nil
  62. }
  63. // BatchSize implements conn.Bind
  64. func (bind *netBind) BatchSize() int {
  65. return 1
  66. }
  67. // Open implements conn.Bind
  68. func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
  69. bind.readQueue = make(chan *netReadInfo)
  70. fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
  71. defer func() {
  72. if r := recover(); r != nil {
  73. n = 0
  74. err = errors.New("channel closed")
  75. }
  76. }()
  77. r, ok := <-bind.readQueue
  78. if !ok {
  79. return 0, errors.New("channel closed")
  80. }
  81. copy(bufs[0], r.buff[:r.bytes])
  82. sizes[0], eps[0] = r.bytes, r.endpoint
  83. r.waiter.Done()
  84. return 1, r.err
  85. }
  86. workers := bind.workers
  87. if workers <= 0 {
  88. workers = 1
  89. }
  90. arr := make([]conn.ReceiveFunc, workers)
  91. for i := 0; i < workers; i++ {
  92. arr[i] = fun
  93. }
  94. return arr, uint16(uport), nil
  95. }
  96. // Close implements conn.Bind
  97. func (bind *netBind) Close() error {
  98. if bind.readQueue != nil {
  99. close(bind.readQueue)
  100. }
  101. return nil
  102. }
  103. type netBindClient struct {
  104. netBind
  105. ctx context.Context
  106. dialer internet.Dialer
  107. reserved []byte
  108. }
  109. func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
  110. c, err := bind.dialer.Dial(bind.ctx, endpoint.dst)
  111. if err != nil {
  112. return err
  113. }
  114. endpoint.conn = c
  115. go func(readQueue chan<- *netReadInfo, endpoint *netEndpoint) {
  116. defer func() {
  117. _ = recover() // handle send on closed channel
  118. }()
  119. for {
  120. buff := make([]byte, 1700)
  121. i, err := c.Read(buff)
  122. if i > 3 {
  123. buff[1] = 0
  124. buff[2] = 0
  125. buff[3] = 0
  126. }
  127. r := &netReadInfo{
  128. buff: buff,
  129. bytes: i,
  130. endpoint: endpoint,
  131. err: err,
  132. }
  133. r.waiter.Add(1)
  134. readQueue <- r
  135. r.waiter.Wait()
  136. if err != nil {
  137. endpoint.conn = nil
  138. return
  139. }
  140. }
  141. }(bind.readQueue, endpoint)
  142. return nil
  143. }
  144. func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
  145. var err error
  146. nend, ok := endpoint.(*netEndpoint)
  147. if !ok {
  148. return conn.ErrWrongEndpointType
  149. }
  150. if nend.conn == nil {
  151. err = bind.connectTo(nend)
  152. if err != nil {
  153. return err
  154. }
  155. }
  156. for _, buff := range buff {
  157. if len(buff) > 3 && len(bind.reserved) == 3 {
  158. copy(buff[1:], bind.reserved)
  159. }
  160. if _, err = nend.conn.Write(buff); err != nil {
  161. return err
  162. }
  163. }
  164. return nil
  165. }
  166. type netBindServer struct {
  167. netBind
  168. }
  169. func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
  170. var err error
  171. nend, ok := endpoint.(*netEndpoint)
  172. if !ok {
  173. return conn.ErrWrongEndpointType
  174. }
  175. if nend.conn == nil {
  176. return errors.New("connection not open yet")
  177. }
  178. for _, buff := range buff {
  179. if _, err = nend.conn.Write(buff); err != nil {
  180. return err
  181. }
  182. }
  183. return err
  184. }
  185. type netEndpoint struct {
  186. dst net.Destination
  187. conn net.Conn
  188. }
  189. func (netEndpoint) ClearSrc() {}
  190. func (e netEndpoint) DstIP() netip.Addr {
  191. return netip.Addr{}
  192. }
  193. func (e netEndpoint) SrcIP() netip.Addr {
  194. return netip.Addr{}
  195. }
  196. func (e netEndpoint) DstToBytes() []byte {
  197. var dat []byte
  198. if e.dst.Address.Family().IsIPv4() {
  199. dat = e.dst.Address.IP().To4()[:]
  200. } else {
  201. dat = e.dst.Address.IP().To16()[:]
  202. }
  203. dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
  204. return dat
  205. }
  206. func (e netEndpoint) DstToString() string {
  207. return e.dst.NetAddr()
  208. }
  209. func (e netEndpoint) SrcToString() string {
  210. return ""
  211. }
  212. func toNetIpAddr(addr net.Address) netip.Addr {
  213. if addr.Family().IsIPv4() {
  214. ip := addr.IP()
  215. return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
  216. } else {
  217. ip := addr.IP()
  218. arr := [16]byte{}
  219. for i := 0; i < 16; i++ {
  220. arr[i] = ip[i]
  221. }
  222. return netip.AddrFrom16(arr)
  223. }
  224. }