netstack.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. // Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // netstack doesn't build on 32-bit machines (https://github.com/google/gvisor/issues/5241)
  5. // +build amd64 arm64 ppc64le riscv64 s390x
  6. // Package netstack wires up gVisor's netstack into Tailscale.
  7. package netstack
  8. import (
  9. "context"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "log"
  14. "net"
  15. "strconv"
  16. "strings"
  17. "sync"
  18. "gvisor.dev/gvisor/pkg/tcpip"
  19. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  20. "gvisor.dev/gvisor/pkg/tcpip/buffer"
  21. "gvisor.dev/gvisor/pkg/tcpip/header"
  22. "gvisor.dev/gvisor/pkg/tcpip/link/channel"
  23. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  24. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  25. "gvisor.dev/gvisor/pkg/tcpip/stack"
  26. "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
  27. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  28. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  29. "gvisor.dev/gvisor/pkg/waiter"
  30. "inet.af/netaddr"
  31. "tailscale.com/net/packet"
  32. "tailscale.com/types/logger"
  33. "tailscale.com/types/netmap"
  34. "tailscale.com/util/dnsname"
  35. "tailscale.com/wgengine"
  36. "tailscale.com/wgengine/filter"
  37. "tailscale.com/wgengine/magicsock"
  38. "tailscale.com/wgengine/tstun"
  39. )
// Impl contains the state for the netstack implementation,
// and implements wgengine.FakeImpl to act as a userspace network
// stack when Tailscale is running in fake mode.
type Impl struct {
	// ipstack is the gVisor network stack built by Create.
	ipstack *stack.Stack
	// linkEP is the channel-backed link endpoint attached to the stack's
	// NIC; packets enter and leave the stack through it.
	linkEP *channel.Endpoint
	// tundev is the TUN wrapper: stack output is injected into it and its
	// PostFilterIn hook feeds packets into the stack.
	tundev *tstun.TUN
	// e is the engine whose network map updates drive updateIPs.
	e wgengine.Engine
	// mc is required to be non-nil by Create; it is not otherwise used in
	// the code visible here.
	mc *magicsock.Conn
	// logf is the logging function used by the methods of Impl.
	logf logger.Logf

	// mu guards dns.
	mu sync.Mutex
	// dns is the MagicDNS name table, replaced wholesale on every network
	// map update.
	dns DNSMap
}
// nicID is the ID of the single NIC that Create registers on the stack; the
// default routes and every dialed address refer to it.
const nicID = 1
  54. // Create creates and populates a new Impl.
  55. func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) (*Impl, error) {
  56. if mc == nil {
  57. return nil, errors.New("nil magicsock.Conn")
  58. }
  59. if tundev == nil {
  60. return nil, errors.New("nil tundev")
  61. }
  62. if logf == nil {
  63. return nil, errors.New("nil logger")
  64. }
  65. if e == nil {
  66. return nil, errors.New("nil Engine")
  67. }
  68. ipstack := stack.New(stack.Options{
  69. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  70. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
  71. })
  72. const mtu = 1500
  73. linkEP := channel.New(512, mtu, "")
  74. if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
  75. return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
  76. }
  77. // Add IPv4 and IPv6 default routes, so all incoming packets from the Tailscale side
  78. // are handled by the one fake NIC we use.
  79. ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4)))
  80. ipv6Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 16)), tcpip.AddressMask(strings.Repeat("\x00", 16)))
  81. ipstack.SetRouteTable([]tcpip.Route{
  82. {
  83. Destination: ipv4Subnet,
  84. NIC: nicID,
  85. },
  86. {
  87. Destination: ipv6Subnet,
  88. NIC: nicID,
  89. },
  90. })
  91. ns := &Impl{
  92. logf: logf,
  93. ipstack: ipstack,
  94. linkEP: linkEP,
  95. tundev: tundev,
  96. e: e,
  97. mc: mc,
  98. }
  99. return ns, nil
  100. }
  101. // Start sets up all the handlers so netstack can start working. Implements
  102. // wgengine.FakeImpl.
  103. func (ns *Impl) Start() error {
  104. ns.e.AddNetworkMapCallback(ns.updateIPs)
  105. // size = 0 means use default buffer size
  106. const tcpReceiveBufferSize = 0
  107. const maxInFlightConnectionAttempts = 16
  108. tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, ns.acceptTCP)
  109. udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP)
  110. ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
  111. ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket)
  112. go ns.injectOutbound()
  113. ns.tundev.PostFilterIn = ns.injectInbound
  114. return nil
  115. }
// DNSMap maps MagicDNS names (both base + FQDN) to their first IP.
//
// As built by DNSMapFromNetworkMap, each node appears under its name with
// trailing dots removed and, when the name carries the MagicDNS suffix, also
// under the name with that suffix trimmed.
//
// It should not be mutated once created: Impl hands the same map to readers
// after releasing its mutex rather than copying it.
type DNSMap map[string]netaddr.IP
  119. func DNSMapFromNetworkMap(nm *netmap.NetworkMap) DNSMap {
  120. ret := make(DNSMap)
  121. suffix := nm.MagicDNSSuffix()
  122. if nm.Name != "" && len(nm.Addresses) > 0 {
  123. ip := nm.Addresses[0].IP
  124. ret[strings.TrimRight(nm.Name, ".")] = ip
  125. if dnsname.HasSuffix(nm.Name, suffix) {
  126. ret[dnsname.TrimSuffix(nm.Name, suffix)] = ip
  127. }
  128. }
  129. for _, p := range nm.Peers {
  130. if p.Name != "" && len(p.Addresses) > 0 {
  131. ip := p.Addresses[0].IP
  132. ret[strings.TrimRight(p.Name, ".")] = ip
  133. if dnsname.HasSuffix(p.Name, suffix) {
  134. ret[dnsname.TrimSuffix(p.Name, suffix)] = ip
  135. }
  136. }
  137. }
  138. return ret
  139. }
  140. func (ns *Impl) updateDNS(nm *netmap.NetworkMap) {
  141. ns.mu.Lock()
  142. defer ns.mu.Unlock()
  143. ns.dns = DNSMapFromNetworkMap(nm)
  144. }
  145. func (ns *Impl) updateIPs(nm *netmap.NetworkMap) {
  146. ns.updateDNS(nm)
  147. oldIPs := make(map[tcpip.Address]bool)
  148. for _, ip := range ns.ipstack.AllAddresses()[nicID] {
  149. oldIPs[ip.AddressWithPrefix.Address] = true
  150. }
  151. newIPs := make(map[tcpip.Address]bool)
  152. for _, ip := range nm.Addresses {
  153. newIPs[tcpip.Address(ip.IP.IPAddr().IP)] = true
  154. }
  155. ipsToBeAdded := make(map[tcpip.Address]bool)
  156. for ip := range newIPs {
  157. if !oldIPs[ip] {
  158. ipsToBeAdded[ip] = true
  159. }
  160. }
  161. ipsToBeRemoved := make(map[tcpip.Address]bool)
  162. for ip := range oldIPs {
  163. if !newIPs[ip] {
  164. ipsToBeRemoved[ip] = true
  165. }
  166. }
  167. for ip := range ipsToBeRemoved {
  168. err := ns.ipstack.RemoveAddress(nicID, ip)
  169. if err != nil {
  170. ns.logf("netstack: could not deregister IP %s: %v", ip, err)
  171. } else {
  172. ns.logf("[v2] netstack: deregistered IP %s", ip)
  173. }
  174. }
  175. for ip := range ipsToBeAdded {
  176. var err *tcpip.Error
  177. if ip.To4() == "" {
  178. err = ns.ipstack.AddAddress(nicID, ipv6.ProtocolNumber, ip)
  179. } else {
  180. err = ns.ipstack.AddAddress(nicID, ipv4.ProtocolNumber, ip)
  181. }
  182. if err != nil {
  183. ns.logf("netstack: could not register IP %s: %v", ip, err)
  184. } else {
  185. ns.logf("[v2] netstack: registered IP %s", ip)
  186. }
  187. }
  188. }
  189. // Resolve resolves addr into an IP:port using first the MagicDNS contents
  190. // of m, else using the system resolver.
  191. func (m DNSMap) Resolve(ctx context.Context, addr string) (netaddr.IPPort, error) {
  192. ipp, pippErr := netaddr.ParseIPPort(addr)
  193. if pippErr == nil {
  194. return ipp, nil
  195. }
  196. host, port, err := net.SplitHostPort(addr)
  197. if err != nil {
  198. // addr is malformed.
  199. return netaddr.IPPort{}, err
  200. }
  201. if net.ParseIP(host) != nil {
  202. // The host part of addr was an IP, so the netaddr.ParseIPPort above should've
  203. // passed. Must've been a bad port number. Return the original error.
  204. return netaddr.IPPort{}, pippErr
  205. }
  206. port16, err := strconv.ParseUint(port, 10, 16)
  207. if err != nil {
  208. return netaddr.IPPort{}, fmt.Errorf("invalid port in address %q", addr)
  209. }
  210. // Host is not an IP, so assume it's a DNS name.
  211. // Try MagicDNS first, else otherwise a real DNS lookup.
  212. ip := m[host]
  213. if !ip.IsZero() {
  214. return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil
  215. }
  216. // No Magic DNS name so try real DNS.
  217. var r net.Resolver
  218. ips, err := r.LookupIP(ctx, "ip", host)
  219. if err != nil {
  220. return netaddr.IPPort{}, err
  221. }
  222. if len(ips) == 0 {
  223. return netaddr.IPPort{}, fmt.Errorf("DNS lookup returned no results for %q", host)
  224. }
  225. ip, _ = netaddr.FromStdIP(ips[0])
  226. return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil
  227. }
  228. func (ns *Impl) DialContextTCP(ctx context.Context, addr string) (*gonet.TCPConn, error) {
  229. ns.mu.Lock()
  230. dnsMap := ns.dns
  231. ns.mu.Unlock()
  232. remoteIPPort, err := dnsMap.Resolve(ctx, addr)
  233. if err != nil {
  234. return nil, err
  235. }
  236. remoteAddress := tcpip.FullAddress{
  237. NIC: nicID,
  238. Addr: tcpip.Address(remoteIPPort.IP.IPAddr().IP),
  239. Port: remoteIPPort.Port,
  240. }
  241. var ipType tcpip.NetworkProtocolNumber
  242. if remoteIPPort.IP.Is4() {
  243. ipType = ipv4.ProtocolNumber
  244. } else {
  245. ipType = ipv6.ProtocolNumber
  246. }
  247. return gonet.DialContextTCP(ctx, ns.ipstack, remoteAddress, ipType)
  248. }
  249. func (ns *Impl) injectOutbound() {
  250. for {
  251. packetInfo, ok := ns.linkEP.ReadContext(context.Background())
  252. if !ok {
  253. ns.logf("[v2] ReadContext-for-write = ok=false")
  254. continue
  255. }
  256. pkt := packetInfo.Pkt
  257. hdrNetwork := pkt.NetworkHeader()
  258. hdrTransport := pkt.TransportHeader()
  259. full := make([]byte, 0, pkt.Size())
  260. full = append(full, hdrNetwork.View()...)
  261. full = append(full, hdrTransport.View()...)
  262. full = append(full, pkt.Data.ToView()...)
  263. ns.logf("[v2] packet Write out: % x", full)
  264. if err := ns.tundev.InjectOutbound(full); err != nil {
  265. log.Printf("netstack inject outbound: %v", err)
  266. return
  267. }
  268. }
  269. }
  270. func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.TUN) filter.Response {
  271. var pn tcpip.NetworkProtocolNumber
  272. switch p.IPVersion {
  273. case 4:
  274. pn = header.IPv4ProtocolNumber
  275. case 6:
  276. pn = header.IPv6ProtocolNumber
  277. }
  278. ns.logf("[v2] packet in (from %v): % x", p.Src, p.Buffer())
  279. vv := buffer.View(append([]byte(nil), p.Buffer()...)).ToVectorisedView()
  280. packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
  281. Data: vv,
  282. })
  283. ns.linkEP.InjectInbound(pn, packetBuf)
  284. return filter.Accept
  285. }
  286. func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
  287. ns.logf("[v2] ForwarderRequest: %v", r)
  288. var wq waiter.Queue
  289. ep, err := r.CreateEndpoint(&wq)
  290. if err != nil {
  291. r.Complete(true)
  292. return
  293. }
  294. localAddr, err := ep.GetLocalAddress()
  295. ns.logf("[v2] forwarding port %v to 100.101.102.103:80", localAddr.Port)
  296. if err != nil {
  297. r.Complete(true)
  298. return
  299. }
  300. r.Complete(false)
  301. c := gonet.NewTCPConn(&wq, ep)
  302. go ns.forwardTCP(c, &wq, "100.101.102.103:80")
  303. }
  304. func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, address string) {
  305. defer client.Close()
  306. ns.logf("[v2] netstack: forwarding to address %s", address)
  307. ctx, cancel := context.WithCancel(context.Background())
  308. defer cancel()
  309. waitEntry, notifyCh := waiter.NewChannelEntry(nil)
  310. wq.EventRegister(&waitEntry, waiter.EventHUp)
  311. defer wq.EventUnregister(&waitEntry)
  312. done := make(chan bool)
  313. // netstack doesn't close the notification channel automatically if there was no
  314. // hup signal, so we close done after we're done to not leak the goroutine below.
  315. defer close(done)
  316. go func() {
  317. select {
  318. case <-notifyCh:
  319. case <-done:
  320. }
  321. cancel()
  322. }()
  323. server, err := ns.DialContextTCP(ctx, address)
  324. if err != nil {
  325. ns.logf("netstack: could not connect to server %s: %s", address, err)
  326. return
  327. }
  328. defer server.Close()
  329. connClosed := make(chan bool, 2)
  330. go func() {
  331. io.Copy(server, client)
  332. connClosed <- true
  333. }()
  334. go func() {
  335. io.Copy(client, server)
  336. connClosed <- true
  337. }()
  338. <-connClosed
  339. ns.logf("[v2] netstack: forwarder connection to %s closed", address)
  340. }
  341. func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
  342. ns.logf("[v2] UDP ForwarderRequest: %v", r)
  343. var wq waiter.Queue
  344. ep, err := r.CreateEndpoint(&wq)
  345. if err != nil {
  346. ns.logf("Could not create endpoint, exiting")
  347. return
  348. }
  349. c := gonet.NewUDPConn(ns.ipstack, &wq, ep)
  350. go echoUDP(c)
  351. }
  352. func echoUDP(c *gonet.UDPConn) {
  353. buf := make([]byte, 1500)
  354. for {
  355. n, err := c.Read(buf)
  356. if err != nil {
  357. break
  358. }
  359. c.Write(buf[:n])
  360. }
  361. c.Close()
  362. }