tun.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. /* SPDX-License-Identifier: MIT
  2. *
  3. * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
  4. */
  5. package gvisortun
  6. import (
  7. "context"
  8. "fmt"
  9. "net/netip"
  10. "os"
  11. "syscall"
  12. "golang.zx2c4.com/wireguard/tun"
  13. "gvisor.dev/gvisor/pkg/buffer"
  14. "gvisor.dev/gvisor/pkg/tcpip"
  15. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  16. "gvisor.dev/gvisor/pkg/tcpip/header"
  17. "gvisor.dev/gvisor/pkg/tcpip/link/channel"
  18. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  19. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  20. "gvisor.dev/gvisor/pkg/tcpip/stack"
  21. "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
  22. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  23. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  24. )
  25. type netTun struct {
  26. ep *channel.Endpoint
  27. stack *stack.Stack
  28. events chan tun.Event
  29. incomingPacket chan *buffer.View
  30. mtu int
  31. hasV4, hasV6 bool
  32. }
  33. type Net netTun
  34. func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) {
  35. opts := stack.Options{
  36. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  37. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
  38. HandleLocal: !promiscuousMode,
  39. }
  40. dev := &netTun{
  41. ep: channel.New(1024, uint32(mtu), ""),
  42. stack: stack.New(opts),
  43. events: make(chan tun.Event, 1),
  44. incomingPacket: make(chan *buffer.View),
  45. mtu: mtu,
  46. }
  47. dev.ep.AddNotify(dev)
  48. tcpipErr := dev.stack.CreateNIC(1, dev.ep)
  49. if tcpipErr != nil {
  50. return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
  51. }
  52. for _, ip := range localAddresses {
  53. var protoNumber tcpip.NetworkProtocolNumber
  54. if ip.Is4() {
  55. protoNumber = ipv4.ProtocolNumber
  56. } else if ip.Is6() {
  57. protoNumber = ipv6.ProtocolNumber
  58. }
  59. protoAddr := tcpip.ProtocolAddress{
  60. Protocol: protoNumber,
  61. AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
  62. }
  63. tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
  64. if tcpipErr != nil {
  65. return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
  66. }
  67. if ip.Is4() {
  68. dev.hasV4 = true
  69. } else if ip.Is6() {
  70. dev.hasV6 = true
  71. }
  72. }
  73. if dev.hasV4 {
  74. dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
  75. }
  76. if dev.hasV6 {
  77. dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
  78. }
  79. if promiscuousMode {
  80. // enable promiscuous mode to handle all packets processed by netstack
  81. dev.stack.SetPromiscuousMode(1, true)
  82. dev.stack.SetSpoofing(1, true)
  83. }
  84. opt := tcpip.CongestionControlOption("cubic")
  85. if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  86. return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
  87. }
  88. dev.events <- tun.EventUp
  89. return dev, (*Net)(dev), dev.stack, nil
  90. }
  91. // BatchSize implements tun.Device
  92. func (tun *netTun) BatchSize() int {
  93. return 1
  94. }
  95. // Name implements tun.Device
  96. func (tun *netTun) Name() (string, error) {
  97. return "go", nil
  98. }
  99. // File implements tun.Device
  100. func (tun *netTun) File() *os.File {
  101. return nil
  102. }
  103. // Events implements tun.Device
  104. func (tun *netTun) Events() <-chan tun.Event {
  105. return tun.events
  106. }
  107. // Read implements tun.Device
  108. func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
  109. view, ok := <-tun.incomingPacket
  110. if !ok {
  111. return 0, os.ErrClosed
  112. }
  113. n, err := view.Read(buf[0][offset:])
  114. if err != nil {
  115. return 0, err
  116. }
  117. sizes[0] = n
  118. return 1, nil
  119. }
  120. // Write implements tun.Device
  121. func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
  122. for _, buf := range buf {
  123. packet := buf[offset:]
  124. if len(packet) == 0 {
  125. continue
  126. }
  127. pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
  128. switch packet[0] >> 4 {
  129. case 4:
  130. tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
  131. case 6:
  132. tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
  133. default:
  134. return 0, syscall.EAFNOSUPPORT
  135. }
  136. }
  137. return len(buf), nil
  138. }
  139. // WriteNotify implements channel.Notification
  140. func (tun *netTun) WriteNotify() {
  141. pkt := tun.ep.Read()
  142. if pkt == nil {
  143. return
  144. }
  145. view := pkt.ToView()
  146. pkt.DecRef()
  147. tun.incomingPacket <- view
  148. }
  149. // Flush implements tun.Device
  150. func (tun *netTun) Flush() error {
  151. return nil
  152. }
  153. // Close implements tun.Device
  154. func (tun *netTun) Close() error {
  155. tun.stack.RemoveNIC(1)
  156. if tun.events != nil {
  157. close(tun.events)
  158. }
  159. tun.ep.Close()
  160. if tun.incomingPacket != nil {
  161. close(tun.incomingPacket)
  162. }
  163. return nil
  164. }
  165. // MTU implements tun.Device
  166. func (tun *netTun) MTU() (int, error) {
  167. return tun.mtu, nil
  168. }
  169. func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
  170. var protoNumber tcpip.NetworkProtocolNumber
  171. if endpoint.Addr().Is4() {
  172. protoNumber = ipv4.ProtocolNumber
  173. } else {
  174. protoNumber = ipv6.ProtocolNumber
  175. }
  176. return tcpip.FullAddress{
  177. NIC: 1,
  178. Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
  179. Port: endpoint.Port(),
  180. }, protoNumber
  181. }
  182. func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
  183. fa, pn := convertToFullAddr(addr)
  184. return gonet.DialContextTCP(ctx, net.stack, fa, pn)
  185. }
  186. func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
  187. var lfa, rfa *tcpip.FullAddress
  188. var pn tcpip.NetworkProtocolNumber
  189. if laddr.IsValid() || laddr.Port() > 0 {
  190. var addr tcpip.FullAddress
  191. addr, pn = convertToFullAddr(laddr)
  192. lfa = &addr
  193. }
  194. if raddr.IsValid() || raddr.Port() > 0 {
  195. var addr tcpip.FullAddress
  196. addr, pn = convertToFullAddr(raddr)
  197. rfa = &addr
  198. }
  199. return gonet.DialUDP(net.stack, lfa, rfa, pn)
  200. }