tun.go 4.9 KB


  1. package wireguard
  2. import (
  3. "context"
  4. "fmt"
  5. "net"
  6. "net/netip"
  7. "runtime"
  8. "strconv"
  9. "strings"
  10. "sync"
  11. "time"
  12. "github.com/xtls/xray-core/common/errors"
  13. "github.com/xtls/xray-core/common/log"
  14. xnet "github.com/xtls/xray-core/common/net"
  15. "github.com/xtls/xray-core/proxy/wireguard/gvisortun"
  16. "gvisor.dev/gvisor/pkg/tcpip"
  17. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  18. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  19. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  20. "gvisor.dev/gvisor/pkg/waiter"
  21. "golang.zx2c4.com/wireguard/conn"
  22. "golang.zx2c4.com/wireguard/device"
  23. "golang.zx2c4.com/wireguard/tun"
  24. )
  25. type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error)
  26. type promiscuousModeHandler func(dest xnet.Destination, conn net.Conn)
  27. type Tunnel interface {
  28. BuildDevice(ipc string, bind conn.Bind) error
  29. DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
  30. DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error)
  31. Close() error
  32. }
  33. type tunnel struct {
  34. tun tun.Device
  35. device *device.Device
  36. rw sync.Mutex
  37. }
  // BuildDevice creates the WireGuard device on top of the tunnel's TUN
  // device, applies the IPC configuration string ipc, and brings it up.
  //
  // bind supplies the transport for encrypted traffic. The device's log output
  // goes to Xray's logger: Verbosef at debug severity, Errorf at error
  // severity.
  //
  // It returns an error when:
  //   - a device has already been built;
  //   - the tunnel has been closed;
  //   - configuring or starting the device fails.
  //
  // If configuring or starting fails, the partially configured device stays
  // attached to t, so a later Close still tears it down.
  func (t *tunnel) BuildDevice(ipc string, bind conn.Bind) (err error) {
  	t.rw.Lock()
  	defer t.rw.Unlock()

  	if t.device != nil {
  		return errors.New("device is already initialized")
  	}
  	// Close clears t.tun. Building on a nil tun.Device would hand
  	// device.NewDevice a nil interface (it calls methods on it) and crash
  	// rather than report an error.
  	if t.tun == nil {
  		return errors.New("tunnel is closed")
  	}

  	logger := &device.Logger{
  		Verbosef: func(format string, args ...any) {
  			log.Record(&log.GeneralMessage{
  				Severity: log.Severity_Debug,
  				Content:  fmt.Sprintf(format, args...),
  			})
  		},
  		Errorf: func(format string, args ...any) {
  			log.Record(&log.GeneralMessage{
  				Severity: log.Severity_Error,
  				Content:  fmt.Sprintf(format, args...),
  			})
  		},
  	}

  	t.device = device.NewDevice(t.tun, bind, logger)
  	if err = t.device.IpcSet(ipc); err != nil {
  		return err
  	}
  	return t.device.Up()
  }
  67. func (t *tunnel) Close() (err error) {
  68. t.rw.Lock()
  69. defer t.rw.Unlock()
  70. if t.device == nil {
  71. return nil
  72. }
  73. t.device.Close()
  74. t.device = nil
  75. err = t.tun.Close()
  76. t.tun = nil
  77. return nil
  78. }
  79. func CalculateInterfaceName(name string) (tunName string) {
  80. if runtime.GOOS == "darwin" {
  81. tunName = "utun"
  82. } else if name != "" {
  83. tunName = name
  84. } else {
  85. tunName = "tun"
  86. }
  87. interfaces, err := net.Interfaces()
  88. if err != nil {
  89. return
  90. }
  91. var tunIndex int
  92. for _, netInterface := range interfaces {
  93. if strings.HasPrefix(netInterface.Name, tunName) {
  94. index, parseErr := strconv.ParseInt(netInterface.Name[len(tunName):], 10, 16)
  95. if parseErr == nil {
  96. tunIndex = int(index) + 1
  97. }
  98. }
  99. }
  100. tunName = fmt.Sprintf("%s%d", tunName, tunIndex)
  101. return
  102. }
  103. var _ Tunnel = (*gvisorNet)(nil)
  104. type gvisorNet struct {
  105. tunnel
  106. net *gvisortun.Net
  107. }
  108. func (g *gvisorNet) Close() error {
  109. return g.tunnel.Close()
  110. }
  // DialContextTCPAddrPort opens a TCP connection to addr over the gVisor
  // stack, passing ctx through to the stack's dialer.
  //
  // On failure it returns a literal nil net.Conn. Returning the stack's
  // concrete connection pointer directly would produce a non-nil interface
  // wrapping a typed nil, so callers comparing the result against nil would
  // misbehave.
  func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error) {
  	conn, err := g.net.DialContextTCPAddrPort(ctx, addr)
  	if err != nil {
  		return nil, err
  	}
  	return conn, nil
  }
  // DialUDPAddrPort opens a UDP socket over the gVisor stack, bound to laddr
  // and connected to raddr.
  //
  // On failure it returns a literal nil net.Conn rather than an interface
  // wrapping a typed nil connection pointer.
  func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
  	conn, err := g.net.DialUDPAddrPort(laddr, raddr)
  	if err != nil {
  		return nil, err
  	}
  	return conn, nil
  }
  // createGVisorTun is a tunCreator backed by gVisor. It builds a userspace
  // TUN device and network stack for localAddresses with the given MTU.
  //
  // When handler is non-nil, the stack is created in promiscuous mode.
  // Every inbound TCP connection and UDP flow is then accepted in its own
  // goroutine and passed to handler, together with the address the peer was
  // dialing. The endpoint is closed as soon as handler returns; presumably
  // handler blocks until it is done with the connection (confirm with
  // callers).
  func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
  	promiscuous := handler != nil
  	tunDev, vnet, ipStack, err := gvisortun.CreateNetTUN(localAddresses, mtu, promiscuous)
  	if err != nil {
  		return nil, err
  	}
  	result := &gvisorNet{}
  	result.tun, result.net = tunDev, vnet
  	if !promiscuous {
  		return result, nil
  	}

  	// acceptTCP completes the handshake for one forwarded TCP connection
  	// and serves it with handler.
  	acceptTCP := func(req *tcp.ForwarderRequest) {
  		var wq waiter.Queue
  		id := req.ID()
  		// Perform a TCP three-way handshake.
  		ep, tcpErr := req.CreateEndpoint(&wq)
  		if tcpErr != nil {
  			errors.LogError(context.Background(), tcpErr.String())
  			req.Complete(true) // true: answer the peer with a RST
  			return
  		}
  		req.Complete(false)
  		defer ep.Close()
  		// Keep-alive lets dead peers be detected instead of hanging forever.
  		ep.SocketOptions().SetKeepAlive(true)
  		// The endpoint's local side is the address the peer was dialing.
  		dest := xnet.TCPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort))
  		handler(dest, gonet.NewTCPConn(&wq, ep))
  	}

  	// acceptUDP opens an endpoint for one forwarded UDP flow and serves it
  	// with handler.
  	acceptUDP := func(req *udp.ForwarderRequest) {
  		var wq waiter.Queue
  		id := req.ID()
  		ep, udpErr := req.CreateEndpoint(&wq)
  		if udpErr != nil {
  			errors.LogError(context.Background(), udpErr.String())
  			return
  		}
  		defer ep.Close()
  		// Bounded linger so the endpoint is released in a timely fashion.
  		ep.SocketOptions().SetLinger(tcpip.LingerOption{
  			Enabled: true,
  			Timeout: 15 * time.Second,
  		})
  		dest := xnet.UDPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort))
  		handler(dest, gonet.NewUDPConn(ipStack, &wq, ep))
  	}

  	tcpForwarder := tcp.NewForwarder(ipStack, 0, 65535, func(req *tcp.ForwarderRequest) {
  		go acceptTCP(req)
  	})
  	ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)

  	udpForwarder := udp.NewForwarder(ipStack, func(req *udp.ForwarderRequest) {
  		go acceptUDP(req)
  	})
  	ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)

  	return result, nil
  }