tun.go 4.8 KB


  1. package wireguard
  2. import (
  3. "context"
  4. "fmt"
  5. "net/netip"
  6. "runtime"
  7. "strconv"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/xtls/xray-core/common/errors"
  12. "github.com/xtls/xray-core/common/log"
  13. "github.com/xtls/xray-core/common/net"
  14. "github.com/xtls/xray-core/proxy/wireguard/gvisortun"
  15. "gvisor.dev/gvisor/pkg/tcpip"
  16. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  17. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  18. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  19. "gvisor.dev/gvisor/pkg/waiter"
  20. "golang.zx2c4.com/wireguard/conn"
  21. "golang.zx2c4.com/wireguard/device"
  22. "golang.zx2c4.com/wireguard/tun"
  23. )
  // tunCreator constructs a Tunnel for the given local addresses and MTU.
// createGVisorTun has this signature; there, a non-nil handler switches the
// tunnel into promiscuous mode.
type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error)

// promiscuousModeHandler is called for each connection accepted in
// promiscuous mode. dest is the address/port the peer was trying to reach
// (createGVisorTun derives it from the flow's local address) and conn is the
// accepted connection.
type promiscuousModeHandler func(dest net.Destination, conn net.Conn)
  // Tunnel is a WireGuard tunnel: a wireguard-go device on top of a tun device,
// plus the ability to dial connections through it.
type Tunnel interface {
	// BuildDevice creates the WireGuard device using bind as its conn.Bind,
	// applies the IPC configuration string ipc, and brings the device up.
	BuildDevice(ipc string, bind conn.Bind) error
	// DialContextTCPAddrPort opens a TCP connection to addr through the tunnel.
	DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
	// DialUDPAddrPort opens a UDP connection from laddr to raddr through the tunnel.
	DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error)
	// Close releases the WireGuard device and the tun device.
	Close() error
}
  // tunnel holds the state shared by Tunnel implementations in this file: the
// tun device and the wireguard-go device built on top of it.
type tunnel struct {
	tun    tun.Device     // the tun device; set to nil by Close
	device *device.Device // set by BuildDevice; reset to nil by Close
	rw     sync.Mutex     // serializes BuildDevice and Close
}
  // BuildDevice creates the wireguard-go device on top of t.tun using bind,
// applies the IPC configuration string ipc, and brings the device up.
//
// It returns an error if a device has already been built. wireguard-go's
// verbose and error log output is forwarded to xray's log at Debug and Error
// severity respectively.
//
// NOTE(review): t.device is assigned before IpcSet/Up run. If either fails,
// the half-configured device stays attached. A later BuildDevice then reports
// "already initialized", and Close is what releases it.
func (t *tunnel) BuildDevice(ipc string, bind conn.Bind) (err error) {
	t.rw.Lock()
	defer t.rw.Unlock()

	if t.device != nil {
		return errors.New("device is already initialized")
	}

	logger := new(device.Logger)
	logger.Verbosef = func(format string, args ...any) {
		msg := &log.GeneralMessage{Severity: log.Severity_Debug}
		msg.Content = fmt.Sprintf(format, args...)
		log.Record(msg)
	}
	logger.Errorf = func(format string, args ...any) {
		msg := &log.GeneralMessage{Severity: log.Severity_Error}
		msg.Content = fmt.Sprintf(format, args...)
		log.Record(msg)
	}

	dev := device.NewDevice(t.tun, bind, logger)
	t.device = dev
	if err = dev.IpcSet(ipc); err != nil {
		return err
	}
	return dev.Up()
}
  66. func (t *tunnel) Close() (err error) {
  67. t.rw.Lock()
  68. defer t.rw.Unlock()
  69. if t.device == nil {
  70. return nil
  71. }
  72. t.device.Close()
  73. t.device = nil
  74. err = t.tun.Close()
  75. t.tun = nil
  76. return nil
  77. }
  // CalculateInterfaceName returns a tun interface name that does not collide
// with any interface currently present on the host.
//
// The base name is chosen as follows:
//   - "utun" on darwin (name is ignored there);
//   - otherwise name;
//   - "tun" when name is empty.
//
// The result is the base followed by one more than the largest decimal suffix
// among existing interfaces named base+<number>, or base+"0" when there are
// none. If the interface list cannot be read, the bare base name is returned
// with no numeric suffix.
func CalculateInterfaceName(name string) (tunName string) {
	switch {
	case runtime.GOOS == "darwin":
		tunName = "utun"
	case name != "":
		tunName = name
	default:
		tunName = "tun"
	}

	interfaces, err := net.Interfaces()
	if err != nil {
		// Best effort: fall back to the unnumbered base name.
		return tunName
	}

	var tunIndex int
	for _, netInterface := range interfaces {
		if !strings.HasPrefix(netInterface.Name, tunName) {
			continue
		}
		index, parseErr := strconv.ParseInt(netInterface.Name[len(tunName):], 10, 16)
		if parseErr != nil {
			continue
		}
		// Track the maximum rather than the last match. Nothing guarantees the
		// list is ordered by suffix, so "last + 1" could name an interface
		// that already exists (e.g. tun0, tun2, tun1 -> tun2).
		if next := int(index) + 1; next > tunIndex {
			tunIndex = next
		}
	}
	return tunName + strconv.Itoa(tunIndex)
}
  102. var _ Tunnel = (*gvisorNet)(nil)
  103. type gvisorNet struct {
  104. tunnel
  105. net *gvisortun.Net
  106. }
  107. func (g *gvisorNet) Close() error {
  108. return g.tunnel.Close()
  109. }
  110. func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
  111. net.Conn, error,
  112. ) {
  113. return g.net.DialContextTCPAddrPort(ctx, addr)
  114. }
  115. func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
  116. return g.net.DialUDPAddrPort(laddr, raddr)
  117. }
  // createGVisorTun is a tunCreator backed by a gVisor userspace network
// stack created through gvisortun.CreateNetTUN.
//
// When handler is non-nil, the stack is created in promiscuous mode and TCP
// and UDP forwarders are installed. Each inbound flow is then accepted on its
// own goroutine and passed to handler. The destination passed to handler is
// the flow's local address/port. When handler is nil, no forwarders are
// installed.
func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
	tunDev, virtualNet, netStack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
	if err != nil {
		return nil, err
	}

	if handler != nil {
		acceptTCP := func(req *tcp.ForwarderRequest) {
			go func(req *tcp.ForwarderRequest) {
				var wq waiter.Queue
				id := req.ID()
				// Creating the endpoint performs the TCP three-way handshake.
				ep, tcpErr := req.CreateEndpoint(&wq)
				if tcpErr != nil {
					errors.LogError(context.Background(), tcpErr.String())
					// Failure path: finish the request with true
					// (presumably asking gVisor to reset the peer — confirm).
					req.Complete(true)
					return
				}
				req.Complete(false)
				defer ep.Close()
				// Keep-alive prevents dead peers from leaving connections hanging.
				ep.SocketOptions().SetKeepAlive(true)
				// The flow's local address is the destination the peer dialed.
				dest := net.TCPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
				handler(dest, gonet.NewTCPConn(&wq, ep))
			}(req)
		}
		netStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcp.NewForwarder(netStack, 0, 65535, acceptTCP).HandlePacket)

		acceptUDP := func(req *udp.ForwarderRequest) {
			go func(req *udp.ForwarderRequest) {
				var wq waiter.Queue
				id := req.ID()
				ep, udpErr := req.CreateEndpoint(&wq)
				if udpErr != nil {
					errors.LogError(context.Background(), udpErr.String())
					return
				}
				defer ep.Close()
				// Enable linger with a 15s timeout so the endpoint is
				// released in a bounded time.
				ep.SocketOptions().SetLinger(tcpip.LingerOption{
					Enabled: true,
					Timeout: 15 * time.Second,
				})
				dest := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
				handler(dest, gonet.NewUDPConn(&wq, ep))
			}(req)
		}
		netStack.SetTransportProtocolHandler(udp.ProtocolNumber, udp.NewForwarder(netStack, acceptUDP).HandlePacket)
	}

	return &gvisorNet{tunnel: tunnel{tun: tunDev}, net: virtualNet}, nil
}