Kaynağa Gözat

tsnet: add Dial method to allow dialing out to the tailnet.

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 4 yıl önce
ebeveyn
işleme
44937b59e7
1 değiştirilmiş dosya ile 23 ekleme ve 9 silme
  1. 23 9
      tsnet/tsnet.go

+ 23 - 9
tsnet/tsnet.go

@@ -63,6 +63,15 @@ type Server struct {
 
 	mu        sync.Mutex
 	listeners map[listenKey]*listener
+	dialer    *tsdial.Dialer
+}
+
+// Dial connects to the address on the tailnet.
+func (s *Server) Dial(ctx context.Context, network, address string) (net.Conn, error) {
+	if err := s.init(); err != nil {
+		return nil, err
+	}
+	return s.dialer.UserDial(ctx, network, address)
 }
 
 func (s *Server) doInit() {
@@ -71,6 +80,11 @@ func (s *Server) doInit() {
 	}
 }
 
+func (s *Server) init() error {
+	s.initOnce.Do(s.doInit)
+	return s.initErr
+}
+
 func (s *Server) start() error {
 	if v, _ := strconv.ParseBool(os.Getenv("TAILSCALE_USE_WIP_CODE")); !v {
 		return errors.New("code disabled without environment variable TAILSCALE_USE_WIP_CODE set true")
@@ -117,11 +131,11 @@ func (s *Server) start() error {
 		return err
 	}
 
-	dialer := new(tsdial.Dialer) // mutated below (before used)
+	s.dialer = new(tsdial.Dialer) // mutated below (before used)
 	eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
 		ListenPort:  0,
 		LinkMonitor: linkMon,
-		Dialer:      dialer,
+		Dialer:      s.dialer,
 	})
 	if err != nil {
 		return err
@@ -132,7 +146,7 @@ func (s *Server) start() error {
 		return fmt.Errorf("%T is not a wgengine.InternalsGetter", eng)
 	}
 
-	ns, err := netstack.Create(logf, tunDev, eng, magicConn, dialer)
+	ns, err := netstack.Create(logf, tunDev, eng, magicConn, s.dialer)
 	if err != nil {
 		return fmt.Errorf("netstack.Create: %w", err)
 	}
@@ -141,11 +155,11 @@ func (s *Server) start() error {
 	if err := ns.Start(); err != nil {
 		return fmt.Errorf("failed to start netstack: %w", err)
 	}
-	dialer.UseNetstackForIP = func(ip netaddr.IP) bool {
+	s.dialer.UseNetstackForIP = func(ip netaddr.IP) bool {
 		_, ok := eng.PeerForIP(ip)
 		return ok
 	}
-	dialer.NetstackDialTCP = func(ctx context.Context, dst netaddr.IPPort) (net.Conn, error) {
+	s.dialer.NetstackDialTCP = func(ctx context.Context, dst netaddr.IPPort) (net.Conn, error) {
 		return ns.DialContextTCP(ctx, dst)
 	}
 
@@ -156,7 +170,7 @@ func (s *Server) start() error {
 	}
 	logid := "tslib-TODO"
 
-	lb, err := ipnlocal.NewLocalBackend(logf, logid, store, dialer, eng)
+	lb, err := ipnlocal.NewLocalBackend(logf, logid, store, s.dialer, eng)
 	if err != nil {
 		return fmt.Errorf("NewLocalBackend: %v", err)
 	}
@@ -217,15 +231,15 @@ func (s *Server) forwardTCP(c net.Conn, port uint16) {
 	}
 }
 
+// Listen announces only on the Tailscale network.
 func (s *Server) Listen(network, addr string) (net.Listener, error) {
 	host, port, err := net.SplitHostPort(addr)
 	if err != nil {
 		return nil, fmt.Errorf("tsnet: %w", err)
 	}
 
-	s.initOnce.Do(s.doInit)
-	if s.initErr != nil {
-		return nil, s.initErr
+	if err := s.init(); err != nil {
+		return nil, err
 	}
 
 	key := listenKey{network, host, port}