netstack.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build !ts_omit_netstack
  4. package main
  5. import (
  6. "context"
  7. "expvar"
  8. "net"
  9. "net/netip"
  10. "tailscale.com/tsd"
  11. "tailscale.com/types/logger"
  12. "tailscale.com/wgengine/netstack"
  13. )
  14. func init() {
  15. hookNewNetstack.Set(newNetstack)
  16. }
  17. func newNetstack(logf logger.Logf, sys *tsd.System, onlyNetstack bool) (tsd.NetstackImpl, error) {
  18. ns, err := netstack.Create(logf,
  19. sys.Tun.Get(),
  20. sys.Engine.Get(),
  21. sys.MagicSock.Get(),
  22. sys.Dialer.Get(),
  23. sys.DNSManager.Get(),
  24. sys.ProxyMapper(),
  25. )
  26. if err != nil {
  27. return nil, err
  28. }
  29. // Only register debug info if we have a debug mux
  30. if debugMux != nil {
  31. expvar.Publish("netstack", ns.ExpVar())
  32. }
  33. sys.Set(ns)
  34. ns.ProcessLocalIPs = onlyNetstack
  35. ns.ProcessSubnets = onlyNetstack || handleSubnetsInNetstack()
  36. dialer := sys.Dialer.Get() // must be set by caller already
  37. if onlyNetstack {
  38. e := sys.Engine.Get()
  39. dialer.UseNetstackForIP = func(ip netip.Addr) bool {
  40. _, ok := e.PeerForIP(ip)
  41. return ok
  42. }
  43. dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
  44. // Note: don't just return ns.DialContextTCP or we'll return
  45. // *gonet.TCPConn(nil) instead of a nil interface which trips up
  46. // callers.
  47. tcpConn, err := ns.DialContextTCP(ctx, dst)
  48. if err != nil {
  49. return nil, err
  50. }
  51. return tcpConn, nil
  52. }
  53. dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
  54. // Note: don't just return ns.DialContextUDP or we'll return
  55. // *gonet.UDPConn(nil) instead of a nil interface which trips up
  56. // callers.
  57. udpConn, err := ns.DialContextUDP(ctx, dst)
  58. if err != nil {
  59. return nil, err
  60. }
  61. return udpConn, nil
  62. }
  63. }
  64. return ns, nil
  65. }