bind.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. package wireguard
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "net"
  7. "net/netip"
  8. "strconv"
  9. "sync"
  10. "github.com/sagernet/wireguard-go/conn"
  11. xnet "github.com/xtls/xray-core/common/net"
  12. "github.com/xtls/xray-core/features/dns"
  13. "github.com/xtls/xray-core/transport/internet"
  14. )
  15. type netReadInfo struct {
  16. // status
  17. waiter sync.WaitGroup
  18. // param
  19. buff []byte
  20. // result
  21. bytes int
  22. endpoint conn.Endpoint
  23. err error
  24. }
  25. type netBindClient struct {
  26. workers int
  27. dialer internet.Dialer
  28. dns dns.Client
  29. dnsOption dns.IPOption
  30. reserved []byte
  31. readQueue chan *netReadInfo
  32. }
  33. func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
  34. ipStr, port, _, err := splitAddrPort(s)
  35. if err != nil {
  36. return nil, err
  37. }
  38. var addr net.IP
  39. if IsDomainName(ipStr) {
  40. ips, err := n.dns.LookupIP(ipStr, n.dnsOption)
  41. if err != nil {
  42. return nil, err
  43. } else if len(ips) == 0 {
  44. return nil, dns.ErrEmptyResponse
  45. }
  46. addr = ips[0]
  47. } else {
  48. addr = net.ParseIP(ipStr)
  49. }
  50. if addr == nil {
  51. return nil, errors.New("failed to parse ip: " + ipStr)
  52. }
  53. var ip xnet.Address
  54. if p4 := addr.To4(); len(p4) == net.IPv4len {
  55. ip = xnet.IPAddress(p4[:])
  56. } else {
  57. ip = xnet.IPAddress(addr[:])
  58. }
  59. dst := xnet.Destination{
  60. Address: ip,
  61. Port: xnet.Port(port),
  62. Network: xnet.Network_UDP,
  63. }
  64. return &netEndpoint{
  65. dst: dst,
  66. }, nil
  67. }
  68. func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
  69. bind.readQueue = make(chan *netReadInfo)
  70. fun := func(buff []byte) (cap int, ep conn.Endpoint, err error) {
  71. defer func() {
  72. if r := recover(); r != nil {
  73. cap = 0
  74. ep = nil
  75. err = errors.New("channel closed")
  76. }
  77. }()
  78. r := &netReadInfo{
  79. buff: buff,
  80. }
  81. r.waiter.Add(1)
  82. bind.readQueue <- r
  83. r.waiter.Wait() // wait read goroutine done, or we will miss the result
  84. return r.bytes, r.endpoint, r.err
  85. }
  86. workers := bind.workers
  87. if workers <= 0 {
  88. workers = 1
  89. }
  90. arr := make([]conn.ReceiveFunc, workers)
  91. for i := 0; i < workers; i++ {
  92. arr[i] = fun
  93. }
  94. return arr, uint16(uport), nil
  95. }
  96. func (bind *netBindClient) Close() error {
  97. if bind.readQueue != nil {
  98. close(bind.readQueue)
  99. }
  100. return nil
  101. }
  102. func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
  103. c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
  104. if err != nil {
  105. return err
  106. }
  107. endpoint.conn = c
  108. go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
  109. for {
  110. v, ok := <-readQueue
  111. if !ok {
  112. return
  113. }
  114. i, err := c.Read(v.buff)
  115. if i > 3 {
  116. v.buff[1] = 0
  117. v.buff[2] = 0
  118. v.buff[3] = 0
  119. }
  120. v.bytes = i
  121. v.endpoint = endpoint
  122. v.err = err
  123. v.waiter.Done()
  124. if err != nil && errors.Is(err, io.EOF) {
  125. endpoint.conn = nil
  126. return
  127. }
  128. }
  129. }(bind.readQueue, endpoint)
  130. return nil
  131. }
  132. func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error {
  133. var err error
  134. nend, ok := endpoint.(*netEndpoint)
  135. if !ok {
  136. return conn.ErrWrongEndpointType
  137. }
  138. if nend.conn == nil {
  139. err = bind.connectTo(nend)
  140. if err != nil {
  141. return err
  142. }
  143. }
  144. if len(buff) > 3 && len(bind.reserved) == 3 {
  145. copy(buff[1:], bind.reserved)
  146. }
  147. _, err = nend.conn.Write(buff)
  148. return err
  149. }
  150. func (bind *netBindClient) SetMark(mark uint32) error {
  151. return nil
  152. }
  153. type netEndpoint struct {
  154. dst xnet.Destination
  155. conn net.Conn
  156. }
  157. func (netEndpoint) ClearSrc() {}
  158. func (e netEndpoint) DstIP() netip.Addr {
  159. return toNetIpAddr(e.dst.Address)
  160. }
  161. func (e netEndpoint) SrcIP() netip.Addr {
  162. return netip.Addr{}
  163. }
  164. func (e netEndpoint) DstToBytes() []byte {
  165. var dat []byte
  166. if e.dst.Address.Family().IsIPv4() {
  167. dat = e.dst.Address.IP().To4()[:]
  168. } else {
  169. dat = e.dst.Address.IP().To16()[:]
  170. }
  171. dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
  172. return dat
  173. }
  174. func (e netEndpoint) DstToString() string {
  175. return e.dst.NetAddr()
  176. }
  177. func (e netEndpoint) SrcToString() string {
  178. return ""
  179. }
  180. func toNetIpAddr(addr xnet.Address) netip.Addr {
  181. if addr.Family().IsIPv4() {
  182. ip := addr.IP()
  183. return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
  184. } else {
  185. ip := addr.IP()
  186. arr := [16]byte{}
  187. for i := 0; i < 16; i++ {
  188. arr[i] = ip[i]
  189. }
  190. return netip.AddrFrom16(arr)
  191. }
  192. }
  193. func stringsLastIndexByte(s string, b byte) int {
  194. for i := len(s) - 1; i >= 0; i-- {
  195. if s[i] == b {
  196. return i
  197. }
  198. }
  199. return -1
  200. }
  201. func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
  202. i := stringsLastIndexByte(s, ':')
  203. if i == -1 {
  204. return "", 0, false, errors.New("not an ip:port")
  205. }
  206. ip = s[:i]
  207. portStr := s[i+1:]
  208. if len(ip) == 0 {
  209. return "", 0, false, errors.New("no IP")
  210. }
  211. if len(portStr) == 0 {
  212. return "", 0, false, errors.New("no port")
  213. }
  214. port64, err := strconv.ParseUint(portStr, 10, 16)
  215. if err != nil {
  216. return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s))
  217. }
  218. port = uint16(port64)
  219. if ip[0] == '[' {
  220. if len(ip) < 2 || ip[len(ip)-1] != ']' {
  221. return "", 0, false, errors.New("missing ]")
  222. }
  223. ip = ip[1 : len(ip)-1]
  224. v6 = true
  225. }
  226. return ip, port, v6, nil
  227. }