netstack.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. // Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package netstack wires up gVisor's netstack into Tailscale.
  5. package netstack
  6. import (
  7. "context"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "log"
  12. "net"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "time"
  17. "inet.af/netaddr"
  18. "inet.af/netstack/tcpip"
  19. "inet.af/netstack/tcpip/adapters/gonet"
  20. "inet.af/netstack/tcpip/buffer"
  21. "inet.af/netstack/tcpip/header"
  22. "inet.af/netstack/tcpip/link/channel"
  23. "inet.af/netstack/tcpip/network/ipv4"
  24. "inet.af/netstack/tcpip/network/ipv6"
  25. "inet.af/netstack/tcpip/stack"
  26. "inet.af/netstack/tcpip/transport/icmp"
  27. "inet.af/netstack/tcpip/transport/tcp"
  28. "inet.af/netstack/tcpip/transport/udp"
  29. "inet.af/netstack/waiter"
  30. "tailscale.com/net/packet"
  31. "tailscale.com/net/tstun"
  32. "tailscale.com/types/logger"
  33. "tailscale.com/types/netmap"
  34. "tailscale.com/util/dnsname"
  35. "tailscale.com/wgengine"
  36. "tailscale.com/wgengine/filter"
  37. "tailscale.com/wgengine/magicsock"
  38. )
// debugNetstack is a compile-time switch. When true, the code in this file
// additionally logs raw packet contents and forwarder requests.
const debugNetstack = false
// Impl contains the state for the netstack implementation,
// and implements wgengine.FakeImpl to act as a userspace network
// stack when Tailscale is running in fake mode.
type Impl struct {
	ipstack *stack.Stack      // the netstack instance built by Create
	linkEP  *channel.Endpoint // link endpoint backing the stack's single NIC
	tundev  *tstun.Wrapper    // receives outbound packets; its PostFilterIn hook feeds inbound ones
	e       wgengine.Engine   // used for netmap callbacks and IP:port identity registration
	mc      *magicsock.Conn   // stored by Create; not otherwise used in the code visible here
	logf    logger.Logf       // destination for all log output of this type

	mu  sync.Mutex // guards dns
	dns DNSMap     // most recent MagicDNS map; replaced wholesale by updateDNS
}
  53. const nicID = 1
  54. const mtu = 1500
  55. // Create creates and populates a new Impl.
  56. func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn) (*Impl, error) {
  57. if mc == nil {
  58. return nil, errors.New("nil magicsock.Conn")
  59. }
  60. if tundev == nil {
  61. return nil, errors.New("nil tundev")
  62. }
  63. if logf == nil {
  64. return nil, errors.New("nil logger")
  65. }
  66. if e == nil {
  67. return nil, errors.New("nil Engine")
  68. }
  69. ipstack := stack.New(stack.Options{
  70. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  71. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
  72. })
  73. linkEP := channel.New(512, mtu, "")
  74. if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
  75. return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
  76. }
  77. // Add IPv4 and IPv6 default routes, so all incoming packets from the Tailscale side
  78. // are handled by the one fake NIC we use.
  79. ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4)))
  80. ipv6Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 16)), tcpip.AddressMask(strings.Repeat("\x00", 16)))
  81. ipstack.SetRouteTable([]tcpip.Route{
  82. {
  83. Destination: ipv4Subnet,
  84. NIC: nicID,
  85. },
  86. {
  87. Destination: ipv6Subnet,
  88. NIC: nicID,
  89. },
  90. })
  91. ns := &Impl{
  92. logf: logf,
  93. ipstack: ipstack,
  94. linkEP: linkEP,
  95. tundev: tundev,
  96. e: e,
  97. mc: mc,
  98. }
  99. return ns, nil
  100. }
  101. // Start sets up all the handlers so netstack can start working. Implements
  102. // wgengine.FakeImpl.
  103. func (ns *Impl) Start() error {
  104. ns.e.AddNetworkMapCallback(ns.updateIPs)
  105. // size = 0 means use default buffer size
  106. const tcpReceiveBufferSize = 0
  107. const maxInFlightConnectionAttempts = 16
  108. tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, ns.acceptTCP)
  109. udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP)
  110. ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
  111. ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket)
  112. go ns.injectOutbound()
  113. ns.tundev.PostFilterIn = ns.injectInbound
  114. return nil
  115. }
// DNSMap maps MagicDNS names (both base + FQDN) to their first IP.
// Keys are stored without a trailing dot. A DNSMap should not be mutated
// once created; updates are made by building a new map and swapping it in.
type DNSMap map[string]netaddr.IP
  119. func DNSMapFromNetworkMap(nm *netmap.NetworkMap) DNSMap {
  120. ret := make(DNSMap)
  121. suffix := nm.MagicDNSSuffix()
  122. if nm.Name != "" && len(nm.Addresses) > 0 {
  123. ip := nm.Addresses[0].IP
  124. ret[strings.TrimRight(nm.Name, ".")] = ip
  125. if dnsname.HasSuffix(nm.Name, suffix) {
  126. ret[dnsname.TrimSuffix(nm.Name, suffix)] = ip
  127. }
  128. }
  129. for _, p := range nm.Peers {
  130. if p.Name != "" && len(p.Addresses) > 0 {
  131. ip := p.Addresses[0].IP
  132. ret[strings.TrimRight(p.Name, ".")] = ip
  133. if dnsname.HasSuffix(p.Name, suffix) {
  134. ret[dnsname.TrimSuffix(p.Name, suffix)] = ip
  135. }
  136. }
  137. }
  138. return ret
  139. }
// updateDNS replaces ns.dns with a fresh DNSMap built from nm.
// The map is built and stored while holding ns.mu.
func (ns *Impl) updateDNS(nm *netmap.NetworkMap) {
	ns.mu.Lock()
	defer ns.mu.Unlock()
	ns.dns = DNSMapFromNetworkMap(nm)
}
  145. func (ns *Impl) updateIPs(nm *netmap.NetworkMap) {
  146. ns.updateDNS(nm)
  147. oldIPs := make(map[tcpip.Address]bool)
  148. for _, ip := range ns.ipstack.AllAddresses()[nicID] {
  149. oldIPs[ip.AddressWithPrefix.Address] = true
  150. }
  151. newIPs := make(map[tcpip.Address]bool)
  152. for _, ip := range nm.Addresses {
  153. newIPs[tcpip.Address(ip.IP.IPAddr().IP)] = true
  154. }
  155. ipsToBeAdded := make(map[tcpip.Address]bool)
  156. for ip := range newIPs {
  157. if !oldIPs[ip] {
  158. ipsToBeAdded[ip] = true
  159. }
  160. }
  161. ipsToBeRemoved := make(map[tcpip.Address]bool)
  162. for ip := range oldIPs {
  163. if !newIPs[ip] {
  164. ipsToBeRemoved[ip] = true
  165. }
  166. }
  167. for ip := range ipsToBeRemoved {
  168. err := ns.ipstack.RemoveAddress(nicID, ip)
  169. if err != nil {
  170. ns.logf("netstack: could not deregister IP %s: %v", ip, err)
  171. } else {
  172. ns.logf("[v2] netstack: deregistered IP %s", ip)
  173. }
  174. }
  175. for ip := range ipsToBeAdded {
  176. var err tcpip.Error
  177. if ip.To4() == "" {
  178. err = ns.ipstack.AddAddress(nicID, ipv6.ProtocolNumber, ip)
  179. } else {
  180. err = ns.ipstack.AddAddress(nicID, ipv4.ProtocolNumber, ip)
  181. }
  182. if err != nil {
  183. ns.logf("netstack: could not register IP %s: %v", ip, err)
  184. } else {
  185. ns.logf("[v2] netstack: registered IP %s", ip)
  186. }
  187. }
  188. }
  189. // Resolve resolves addr into an IP:port using first the MagicDNS contents
  190. // of m, else using the system resolver.
  191. func (m DNSMap) Resolve(ctx context.Context, addr string) (netaddr.IPPort, error) {
  192. ipp, pippErr := netaddr.ParseIPPort(addr)
  193. if pippErr == nil {
  194. return ipp, nil
  195. }
  196. host, port, err := net.SplitHostPort(addr)
  197. if err != nil {
  198. // addr is malformed.
  199. return netaddr.IPPort{}, err
  200. }
  201. if net.ParseIP(host) != nil {
  202. // The host part of addr was an IP, so the netaddr.ParseIPPort above should've
  203. // passed. Must've been a bad port number. Return the original error.
  204. return netaddr.IPPort{}, pippErr
  205. }
  206. port16, err := strconv.ParseUint(port, 10, 16)
  207. if err != nil {
  208. return netaddr.IPPort{}, fmt.Errorf("invalid port in address %q", addr)
  209. }
  210. // Host is not an IP, so assume it's a DNS name.
  211. // Try MagicDNS first, else otherwise a real DNS lookup.
  212. ip := m[host]
  213. if !ip.IsZero() {
  214. return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil
  215. }
  216. // No MagicDNS name so try real DNS.
  217. var r net.Resolver
  218. ips, err := r.LookupIP(ctx, "ip", host)
  219. if err != nil {
  220. return netaddr.IPPort{}, err
  221. }
  222. if len(ips) == 0 {
  223. return netaddr.IPPort{}, fmt.Errorf("DNS lookup returned no results for %q", host)
  224. }
  225. ip, _ = netaddr.FromStdIP(ips[0])
  226. return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil
  227. }
  228. func (ns *Impl) DialContextTCP(ctx context.Context, addr string) (*gonet.TCPConn, error) {
  229. ns.mu.Lock()
  230. dnsMap := ns.dns
  231. ns.mu.Unlock()
  232. remoteIPPort, err := dnsMap.Resolve(ctx, addr)
  233. if err != nil {
  234. return nil, err
  235. }
  236. remoteAddress := tcpip.FullAddress{
  237. NIC: nicID,
  238. Addr: tcpip.Address(remoteIPPort.IP.IPAddr().IP),
  239. Port: remoteIPPort.Port,
  240. }
  241. var ipType tcpip.NetworkProtocolNumber
  242. if remoteIPPort.IP.Is4() {
  243. ipType = ipv4.ProtocolNumber
  244. } else {
  245. ipType = ipv6.ProtocolNumber
  246. }
  247. return gonet.DialContextTCP(ctx, ns.ipstack, remoteAddress, ipType)
  248. }
  249. func (ns *Impl) injectOutbound() {
  250. for {
  251. packetInfo, ok := ns.linkEP.ReadContext(context.Background())
  252. if !ok {
  253. ns.logf("[v2] ReadContext-for-write = ok=false")
  254. continue
  255. }
  256. pkt := packetInfo.Pkt
  257. hdrNetwork := pkt.NetworkHeader()
  258. hdrTransport := pkt.TransportHeader()
  259. full := make([]byte, 0, pkt.Size())
  260. full = append(full, hdrNetwork.View()...)
  261. full = append(full, hdrTransport.View()...)
  262. full = append(full, pkt.Data().AsRange().AsView()...)
  263. if debugNetstack {
  264. ns.logf("[v2] packet Write out: % x", full)
  265. }
  266. if err := ns.tundev.InjectOutbound(full); err != nil {
  267. log.Printf("netstack inject outbound: %v", err)
  268. return
  269. }
  270. }
  271. }
  272. func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Response {
  273. var pn tcpip.NetworkProtocolNumber
  274. switch p.IPVersion {
  275. case 4:
  276. pn = header.IPv4ProtocolNumber
  277. case 6:
  278. pn = header.IPv6ProtocolNumber
  279. }
  280. if debugNetstack {
  281. ns.logf("[v2] packet in (from %v): % x", p.Src, p.Buffer())
  282. }
  283. vv := buffer.View(append([]byte(nil), p.Buffer()...)).ToVectorisedView()
  284. packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
  285. Data: vv,
  286. })
  287. ns.linkEP.InjectInbound(pn, packetBuf)
  288. return filter.Accept
  289. }
  290. func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
  291. if debugNetstack {
  292. // Kinda ugly:
  293. // ForwarderRequest: &{{{{0 0}}} 0xc0001c30b0 0xc0004c3d40 {1240 6 true 826109390 0 true}
  294. ns.logf("[v2] ForwarderRequest: %v", r)
  295. }
  296. var wq waiter.Queue
  297. ep, err := r.CreateEndpoint(&wq)
  298. if err != nil {
  299. r.Complete(true)
  300. return
  301. }
  302. localAddr, err := ep.GetLocalAddress()
  303. if err != nil {
  304. r.Complete(true)
  305. return
  306. }
  307. r.Complete(false)
  308. c := gonet.NewTCPConn(&wq, ep)
  309. go ns.forwardTCP(c, &wq, localAddr.Port)
  310. }
  311. func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, port uint16) {
  312. defer client.Close()
  313. ns.logf("[v2] netstack: forwarding incoming connection on port %v", port)
  314. ctx, cancel := context.WithCancel(context.Background())
  315. defer cancel()
  316. waitEntry, notifyCh := waiter.NewChannelEntry(nil)
  317. wq.EventRegister(&waitEntry, waiter.EventHUp)
  318. defer wq.EventUnregister(&waitEntry)
  319. done := make(chan bool)
  320. // netstack doesn't close the notification channel automatically if there was no
  321. // hup signal, so we close done after we're done to not leak the goroutine below.
  322. defer close(done)
  323. go func() {
  324. select {
  325. case <-notifyCh:
  326. case <-done:
  327. }
  328. cancel()
  329. }()
  330. var stdDialer net.Dialer
  331. server, err := stdDialer.DialContext(ctx, "tcp", net.JoinHostPort("localhost", strconv.Itoa(int(port))))
  332. if err != nil {
  333. ns.logf("netstack: could not connect to local server on port %v: %v", port, err)
  334. return
  335. }
  336. defer server.Close()
  337. backendLocalAddr := server.LocalAddr().(*net.TCPAddr)
  338. backendLocalIPPort, _ := netaddr.FromStdAddr(backendLocalAddr.IP, backendLocalAddr.Port, backendLocalAddr.Zone)
  339. clientRemoteIP, _ := netaddr.FromStdIP(client.RemoteAddr().(*net.TCPAddr).IP)
  340. ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP)
  341. defer ns.e.UnregisterIPPortIdentity(backendLocalIPPort)
  342. connClosed := make(chan error, 2)
  343. go func() {
  344. _, err := io.Copy(server, client)
  345. connClosed <- err
  346. }()
  347. go func() {
  348. _, err := io.Copy(client, server)
  349. connClosed <- err
  350. }()
  351. err = <-connClosed
  352. if err != nil {
  353. ns.logf("proxy connection closed with error: %v", err)
  354. }
  355. ns.logf("[v2] netstack: forwarder connection on port %v closed", port)
  356. }
  357. func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
  358. ns.logf("[v2] UDP ForwarderRequest: %v", r)
  359. var wq waiter.Queue
  360. ep, err := r.CreateEndpoint(&wq)
  361. if err != nil {
  362. ns.logf("Could not create endpoint, exiting")
  363. return
  364. }
  365. localAddr, err := ep.GetLocalAddress()
  366. if err != nil {
  367. return
  368. }
  369. remoteAddr, err := ep.GetRemoteAddress()
  370. if err != nil {
  371. return
  372. }
  373. c := gonet.NewUDPConn(ns.ipstack, &wq, ep)
  374. go ns.forwardUDP(c, &wq, localAddr, remoteAddr)
  375. }
  376. func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalAddr, clientRemoteAddr tcpip.FullAddress) {
  377. port := clientLocalAddr.Port
  378. ns.logf("[v2] netstack: forwarding incoming UDP connection on port %v", port)
  379. backendListenAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(clientRemoteAddr.Port)}
  380. backendRemoteAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}
  381. backendConn, err := net.ListenUDP("udp4", backendListenAddr)
  382. if err != nil {
  383. ns.logf("netstack: could not bind local port %v: %v, trying again with random port", clientRemoteAddr.Port, err)
  384. backendListenAddr.Port = 0
  385. backendConn, err = net.ListenUDP("udp4", backendListenAddr)
  386. if err != nil {
  387. ns.logf("netstack: could not connect to local UDP server on port %v: %v", port, err)
  388. return
  389. }
  390. }
  391. backendLocalAddr := backendConn.LocalAddr().(*net.UDPAddr)
  392. backendLocalIPPort, ok := netaddr.FromStdAddr(backendListenAddr.IP, backendLocalAddr.Port, backendLocalAddr.Zone)
  393. if !ok {
  394. ns.logf("could not get backend local IP:port from %v:%v", backendLocalAddr.IP, backendLocalAddr.Port)
  395. }
  396. clientRemoteIP, _ := netaddr.FromStdIP(net.ParseIP(clientRemoteAddr.Addr.String()))
  397. ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP)
  398. ctx, cancel := context.WithCancel(context.Background())
  399. timer := time.AfterFunc(2*time.Minute, func() {
  400. ns.e.UnregisterIPPortIdentity(backendLocalIPPort)
  401. ns.logf("netstack: UDP session between %s and %s timed out", clientRemoteAddr, backendRemoteAddr)
  402. cancel()
  403. client.Close()
  404. backendConn.Close()
  405. })
  406. extend := func() {
  407. timer.Reset(2 * time.Minute)
  408. }
  409. startPacketCopy(ctx, cancel, client, &net.UDPAddr{
  410. IP: net.ParseIP(clientRemoteAddr.Addr.String()),
  411. Port: int(clientRemoteAddr.Port),
  412. }, backendConn, ns.logf, extend)
  413. startPacketCopy(ctx, cancel, backendConn, backendRemoteAddr, client, ns.logf, extend)
  414. }
  415. func startPacketCopy(ctx context.Context, cancel context.CancelFunc, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn, logf logger.Logf, extend func()) {
  416. go func() {
  417. defer cancel() // tear down the other direction's copy
  418. pkt := make([]byte, mtu)
  419. for {
  420. select {
  421. case <-ctx.Done():
  422. return
  423. default:
  424. n, srcAddr, err := src.ReadFrom(pkt)
  425. if err != nil {
  426. if ctx.Err() == nil {
  427. logf("read packet from %s failed: %v", srcAddr, err)
  428. }
  429. return
  430. }
  431. _, err = dst.WriteTo(pkt[:n], dstAddr)
  432. if err != nil {
  433. if ctx.Err() == nil {
  434. logf("write packet to %s failed: %v", dstAddr, err)
  435. }
  436. return
  437. }
  438. if debugNetstack {
  439. logf("[v2] wrote UDP packet %s -> %s", srcAddr, dstAddr)
  440. }
  441. extend()
  442. }
  443. }
  444. }()
  445. }