tun.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. /* SPDX-License-Identifier: MIT
  2. *
  3. * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
  4. */
  5. package gvisortun
  6. import (
  7. "context"
  8. "fmt"
  9. "net/netip"
  10. "os"
  11. "sync"
  12. "syscall"
  13. "golang.zx2c4.com/wireguard/tun"
  14. "gvisor.dev/gvisor/pkg/buffer"
  15. "gvisor.dev/gvisor/pkg/tcpip"
  16. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  17. "gvisor.dev/gvisor/pkg/tcpip/header"
  18. "gvisor.dev/gvisor/pkg/tcpip/link/channel"
  19. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  20. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  21. "gvisor.dev/gvisor/pkg/tcpip/stack"
  22. "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
  23. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  24. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  25. )
  26. type netTun struct {
  27. ep *channel.Endpoint
  28. stack *stack.Stack
  29. events chan tun.Event
  30. incomingPacket chan *buffer.View
  31. mtu int
  32. hasV4, hasV6 bool
  33. closeOnce sync.Once
  34. }
// Net is an exported type with the same underlying structure as netTun.
// CreateNetTUN returns the device converted to *Net, and the Dial* methods
// declared on Net open connections through the device's network stack.
type Net netTun
  36. func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) {
  37. opts := stack.Options{
  38. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  39. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
  40. HandleLocal: !promiscuousMode,
  41. }
  42. dev := &netTun{
  43. ep: channel.New(1024, uint32(mtu), ""),
  44. stack: stack.New(opts),
  45. events: make(chan tun.Event, 1),
  46. incomingPacket: make(chan *buffer.View),
  47. mtu: mtu,
  48. }
  49. dev.ep.AddNotify(dev)
  50. tcpipErr := dev.stack.CreateNIC(1, dev.ep)
  51. if tcpipErr != nil {
  52. return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
  53. }
  54. for _, ip := range localAddresses {
  55. var protoNumber tcpip.NetworkProtocolNumber
  56. if ip.Is4() {
  57. protoNumber = ipv4.ProtocolNumber
  58. } else if ip.Is6() {
  59. protoNumber = ipv6.ProtocolNumber
  60. }
  61. protoAddr := tcpip.ProtocolAddress{
  62. Protocol: protoNumber,
  63. AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
  64. }
  65. tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
  66. if tcpipErr != nil {
  67. return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
  68. }
  69. if ip.Is4() {
  70. dev.hasV4 = true
  71. } else if ip.Is6() {
  72. dev.hasV6 = true
  73. }
  74. }
  75. if dev.hasV4 {
  76. dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
  77. }
  78. if dev.hasV6 {
  79. dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
  80. }
  81. if promiscuousMode {
  82. // enable promiscuous mode to handle all packets processed by netstack
  83. dev.stack.SetPromiscuousMode(1, true)
  84. dev.stack.SetSpoofing(1, true)
  85. }
  86. opt := tcpip.CongestionControlOption("cubic")
  87. if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  88. return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
  89. }
  90. dev.events <- tun.EventUp
  91. return dev, (*Net)(dev), dev.stack, nil
  92. }
  93. // BatchSize implements tun.Device
  94. func (tun *netTun) BatchSize() int {
  95. return 1
  96. }
  97. // Name implements tun.Device
  98. func (tun *netTun) Name() (string, error) {
  99. return "go", nil
  100. }
  101. // File implements tun.Device
  102. func (tun *netTun) File() *os.File {
  103. return nil
  104. }
  105. // Events implements tun.Device
  106. func (tun *netTun) Events() <-chan tun.Event {
  107. return tun.events
  108. }
  109. // Read implements tun.Device
  110. func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
  111. view, ok := <-tun.incomingPacket
  112. if !ok {
  113. return 0, os.ErrClosed
  114. }
  115. n, err := view.Read(buf[0][offset:])
  116. if err != nil {
  117. return 0, err
  118. }
  119. sizes[0] = n
  120. return 1, nil
  121. }
  122. // Write implements tun.Device
  123. func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
  124. for _, buf := range buf {
  125. packet := buf[offset:]
  126. if len(packet) == 0 {
  127. continue
  128. }
  129. pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
  130. switch packet[0] >> 4 {
  131. case 4:
  132. tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
  133. case 6:
  134. tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
  135. default:
  136. return 0, syscall.EAFNOSUPPORT
  137. }
  138. }
  139. return len(buf), nil
  140. }
  141. // WriteNotify implements channel.Notification
  142. func (tun *netTun) WriteNotify() {
  143. pkt := tun.ep.Read()
  144. if pkt == nil {
  145. return
  146. }
  147. view := pkt.ToView()
  148. pkt.DecRef()
  149. tun.incomingPacket <- view
  150. }
  151. // Flush implements tun.Device
  152. func (tun *netTun) Flush() error {
  153. return nil
  154. }
  155. // Close implements tun.Device
  156. func (tun *netTun) Close() error {
  157. tun.closeOnce.Do(func() {
  158. tun.stack.RemoveNIC(1)
  159. close(tun.events)
  160. tun.ep.Close()
  161. close(tun.incomingPacket)
  162. })
  163. return nil
  164. }
  165. // MTU implements tun.Device
  166. func (tun *netTun) MTU() (int, error) {
  167. return tun.mtu, nil
  168. }
  169. func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
  170. var protoNumber tcpip.NetworkProtocolNumber
  171. if endpoint.Addr().Is4() {
  172. protoNumber = ipv4.ProtocolNumber
  173. } else {
  174. protoNumber = ipv6.ProtocolNumber
  175. }
  176. return tcpip.FullAddress{
  177. NIC: 1,
  178. Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
  179. Port: endpoint.Port(),
  180. }, protoNumber
  181. }
  182. func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
  183. fa, pn := convertToFullAddr(addr)
  184. return gonet.DialContextTCP(ctx, net.stack, fa, pn)
  185. }
  186. func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
  187. var lfa, rfa *tcpip.FullAddress
  188. var pn tcpip.NetworkProtocolNumber
  189. if laddr.IsValid() || laddr.Port() > 0 {
  190. var addr tcpip.FullAddress
  191. addr, pn = convertToFullAddr(laddr)
  192. lfa = &addr
  193. }
  194. if raddr.IsValid() || raddr.Port() > 0 {
  195. var addr tcpip.FullAddress
  196. addr, pn = convertToFullAddr(raddr)
  197. rfa = &addr
  198. }
  199. return gonet.DialUDP(net.stack, lfa, rfa, pn)
  200. }