فهرست منبع

control/controlhttp: move Dial options into options struct (#5661)

This turns 'dialParams' into something more like net.Dialer, where
configuration fields are public on the struct.

Split out of #5648

Change-Id: I0c56fd151dc5489c3c94fb40d18fd639e06473bc
Signed-off-by: Andrew Dunham <[email protected]>
Andrew Dunham 3 سال پیش
والد
کامیت
9b71008ef2

+ 9 - 1
cmd/tailscale/cli/debug.go

@@ -489,7 +489,15 @@ func runTS2021(ctx context.Context, args []string) error {
 		return c, err
 		return c, err
 	}
 	}
 
 
-	conn, err := controlhttp.Dial(ctx, ts2021Args.host, "80", "443", machinePrivate, keys.PublicKey, uint16(ts2021Args.version), dialFunc)
+	conn, err := (&controlhttp.Dialer{
+		Hostname:        ts2021Args.host,
+		HTTPPort:        "80",
+		HTTPSPort:       "443",
+		MachineKey:      machinePrivate,
+		ControlKey:      keys.PublicKey,
+		ProtocolVersion: uint16(ts2021Args.version),
+		Dialer:          dialFunc,
+	}).Dial(ctx)
 	log.Printf("controlhttp.Dial = %p, %v", conn, err)
 	log.Printf("controlhttp.Dial = %p, %v", conn, err)
 	if err != nil {
 	if err != nil {
 		return err
 		return err

+ 9 - 1
control/controlclient/noise.go

@@ -165,7 +165,15 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
 		// thousand version numbers before getting to this point.
 		// thousand version numbers before getting to this point.
 		panic("capability version is too high to fit in the wire protocol")
 		panic("capability version is too high to fit in the wire protocol")
 	}
 	}
-	conn, err := controlhttp.Dial(ctx, nc.host, nc.httpPort, nc.httpsPort, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion), nc.dialer.SystemDial)
+	conn, err := (&controlhttp.Dialer{
+		Hostname:        nc.host,
+		HTTPPort:        nc.httpPort,
+		HTTPSPort:       nc.httpsPort,
+		MachineKey:      nc.privKey,
+		ControlKey:      nc.serverPubKey,
+		ProtocolVersion: uint16(tailcfg.CurrentCapabilityVersion),
+		Dialer:          nc.dialer.SystemDial,
+	}).Dial(ctx)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}

+ 40 - 40
control/controlhttp/client.go

@@ -40,57 +40,49 @@ import (
 	"tailscale.com/net/netutil"
 	"tailscale.com/net/netutil"
 	"tailscale.com/net/tlsdial"
 	"tailscale.com/net/tlsdial"
 	"tailscale.com/net/tshttpproxy"
 	"tailscale.com/net/tshttpproxy"
-	"tailscale.com/types/key"
 )
 )
 
 
-// Dial connects to the HTTP server at host:httpPort, requests to switch to the
-// Tailscale control protocol, and returns an established control
+var stdDialer net.Dialer
+
+// Dial connects to the HTTP server at this Dialer's Host:HTTPPort, requests to
+// switch to the Tailscale control protocol, and returns an established control
 // protocol connection.
 // protocol connection.
 //
 //
-// If Dial fails to connect using addr, it also tries to tunnel over
-// TLS to host:httpsPort as a compatibility fallback.
+// If Dial fails to connect using HTTP, it also tries to tunnel over TLS to the
+// Dialer's Host:HTTPSPort as a compatibility fallback.
 //
 //
 // The provided ctx is only used for the initial connection, until
 // The provided ctx is only used for the initial connection, until
 // Dial returns. It does not affect the connection once established.
 // Dial returns. It does not affect the connection once established.
-func Dial(ctx context.Context, host string, httpPort string, httpsPort string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16, dialer dnscache.DialContextFunc) (*controlbase.Conn, error) {
-	a := &dialParams{
-		host:       host,
-		httpPort:   httpPort,
-		httpsPort:  httpsPort,
-		machineKey: machineKey,
-		controlKey: controlKey,
-		version:    protocolVersion,
-		proxyFunc:  tshttpproxy.ProxyFromEnvironment,
-		dialer:     dialer,
+func (a *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) {
+	if a.Hostname == "" {
+		return nil, errors.New("required Dialer.Hostname empty")
 	}
 	}
 	return a.dial(ctx)
 	return a.dial(ctx)
 }
 }
 
 
-type dialParams struct {
-	host       string
-	httpPort   string
-	httpsPort  string
-	machineKey key.MachinePrivate
-	controlKey key.MachinePublic
-	version    uint16
-	proxyFunc  func(*http.Request) (*url.URL, error) // or nil
-	dialer     dnscache.DialContextFunc
+func (a *Dialer) logf(format string, args ...any) {
+	if a.Logf != nil {
+		a.Logf(format, args...)
+	}
+}
 
 
-	// For tests only
-	insecureTLS       bool
-	testFallbackDelay time.Duration
+func (a *Dialer) getProxyFunc() func(*http.Request) (*url.URL, error) {
+	if a.proxyFunc != nil {
+		return a.proxyFunc
+	}
+	return tshttpproxy.ProxyFromEnvironment
 }
 }
 
 
-// httpsFallbackDelay is how long we'll wait for a.httpPort to work before
-// starting to try a.httpsPort.
-func (a *dialParams) httpsFallbackDelay() time.Duration {
+// httpsFallbackDelay is how long we'll wait for a.HTTPPort to work before
+// starting to try a.HTTPSPort.
+func (a *Dialer) httpsFallbackDelay() time.Duration {
 	if v := a.testFallbackDelay; v != 0 {
 	if v := a.testFallbackDelay; v != 0 {
 		return v
 		return v
 	}
 	}
 	return 500 * time.Millisecond
 	return 500 * time.Millisecond
 }
 }
 
 
-func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) {
+func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
 	// Create one shared context used by both port 80 and port 443 dials.
 	// Create one shared context used by both port 80 and port 443 dials.
 	// If port 80 is still in flight when 443 returns, this deferred cancel
 	// If port 80 is still in flight when 443 returns, this deferred cancel
 	// will stop the port 80 dial.
 	// will stop the port 80 dial.
@@ -102,12 +94,12 @@ func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) {
 	// we'll speak Noise.
 	// we'll speak Noise.
 	u80 := &url.URL{
 	u80 := &url.URL{
 		Scheme: "http",
 		Scheme: "http",
-		Host:   net.JoinHostPort(a.host, a.httpPort),
+		Host:   net.JoinHostPort(a.Hostname, strDef(a.HTTPPort, "80")),
 		Path:   serverUpgradePath,
 		Path:   serverUpgradePath,
 	}
 	}
 	u443 := &url.URL{
 	u443 := &url.URL{
 		Scheme: "https",
 		Scheme: "https",
-		Host:   net.JoinHostPort(a.host, a.httpsPort),
+		Host:   net.JoinHostPort(a.Hostname, strDef(a.HTTPSPort, "443")),
 		Path:   serverUpgradePath,
 		Path:   serverUpgradePath,
 	}
 	}
 
 
@@ -169,8 +161,8 @@ func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) {
 }
 }
 
 
 // dialURL attempts to connect to the given URL.
 // dialURL attempts to connect to the given URL.
-func (a *dialParams) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) {
-	init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
+func (a *Dialer) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) {
+	init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -189,26 +181,34 @@ func (a *dialParams) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn
 // tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn.
 // tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn.
 //
 //
 // Only the provided ctx is used, not a.ctx.
 // Only the provided ctx is used, not a.ctx.
-func (a *dialParams) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
+func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
 	dns := &dnscache.Resolver{
 	dns := &dnscache.Resolver{
 		Forward:          dnscache.Get().Forward,
 		Forward:          dnscache.Get().Forward,
 		LookupIPFallback: dnsfallback.Lookup,
 		LookupIPFallback: dnsfallback.Lookup,
 		UseLastGood:      true,
 		UseLastGood:      true,
 	}
 	}
+
+	var dialer dnscache.DialContextFunc
+	if a.Dialer != nil {
+		dialer = a.Dialer
+	} else {
+		dialer = stdDialer.DialContext
+	}
+
 	tr := http.DefaultTransport.(*http.Transport).Clone()
 	tr := http.DefaultTransport.(*http.Transport).Clone()
 	defer tr.CloseIdleConnections()
 	defer tr.CloseIdleConnections()
-	tr.Proxy = a.proxyFunc
+	tr.Proxy = a.getProxyFunc()
 	tshttpproxy.SetTransportGetProxyConnectHeader(tr)
 	tshttpproxy.SetTransportGetProxyConnectHeader(tr)
-	tr.DialContext = dnscache.Dialer(a.dialer, dns)
+	tr.DialContext = dnscache.Dialer(dialer, dns)
 	// Disable HTTP2, since h2 can't do protocol switching.
 	// Disable HTTP2, since h2 can't do protocol switching.
 	tr.TLSClientConfig.NextProtos = []string{}
 	tr.TLSClientConfig.NextProtos = []string{}
 	tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
 	tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
-	tr.TLSClientConfig = tlsdial.Config(a.host, tr.TLSClientConfig)
+	tr.TLSClientConfig = tlsdial.Config(a.Hostname, tr.TLSClientConfig)
 	if a.insecureTLS {
 	if a.insecureTLS {
 		tr.TLSClientConfig.InsecureSkipVerify = true
 		tr.TLSClientConfig.InsecureSkipVerify = true
 		tr.TLSClientConfig.VerifyConnection = nil
 		tr.TLSClientConfig.VerifyConnection = nil
 	}
 	}
-	tr.DialTLSContext = dnscache.TLSDialer(a.dialer, dns, tr.TLSClientConfig)
+	tr.DialTLSContext = dnscache.TLSDialer(dialer, dns, tr.TLSClientConfig)
 	tr.DisableCompression = true
 	tr.DisableCompression = true
 
 
 	// (mis)use httptrace to extract the underlying net.Conn from the
 	// (mis)use httptrace to extract the underlying net.Conn from the

+ 9 - 6
control/controlhttp/client_js.go

@@ -7,27 +7,31 @@ package controlhttp
 import (
 import (
 	"context"
 	"context"
 	"encoding/base64"
 	"encoding/base64"
+	"errors"
 	"net"
 	"net"
 	"net/url"
 	"net/url"
 
 
 	"nhooyr.io/websocket"
 	"nhooyr.io/websocket"
 	"tailscale.com/control/controlbase"
 	"tailscale.com/control/controlbase"
-	"tailscale.com/net/dnscache"
-	"tailscale.com/types/key"
 )
 )
 
 
 // Variant of Dial that tunnels the request over WebSockets, since we cannot do
 // Variant of Dial that tunnels the request over WebSockets, since we cannot do
 // bi-directional communication over an HTTP connection when in JS.
 // bi-directional communication over an HTTP connection when in JS.
-func Dial(ctx context.Context, host string, httpPort string, httpsPort string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16, dialer dnscache.DialContextFunc) (*controlbase.Conn, error) {
-	init, cont, err := controlbase.ClientDeferred(machineKey, controlKey, protocolVersion)
+func (d *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) {
+	if d.Hostname == "" {
+		return nil, errors.New("required Dialer.Hostname empty")
+	}
+
+	init, cont, err := controlbase.ClientDeferred(d.MachineKey, d.ControlKey, d.ProtocolVersion)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
 	wsScheme := "wss"
 	wsScheme := "wss"
+	host := d.Hostname
 	if host == "localhost" {
 	if host == "localhost" {
 		wsScheme = "ws"
 		wsScheme = "ws"
-		host = net.JoinHostPort(host, httpPort)
+		host = net.JoinHostPort(host, strDef(d.HTTPPort, "80"))
 	}
 	}
 	wsURL := &url.URL{
 	wsURL := &url.URL{
 		Scheme: wsScheme,
 		Scheme: wsScheme,
@@ -52,5 +56,4 @@ func Dial(ctx context.Context, host string, httpPort string, httpsPort string, m
 		return nil, err
 		return nil, err
 	}
 	}
 	return cbConn, nil
 	return cbConn, nil
-
 }
 }

+ 65 - 0
control/controlhttp/constants.go

@@ -4,6 +4,16 @@
 
 
 package controlhttp
 package controlhttp
 
 
+import (
+	"net/http"
+	"net/url"
+	"time"
+
+	"tailscale.com/net/dnscache"
+	"tailscale.com/types/key"
+	"tailscale.com/types/logger"
+)
+
 const (
 const (
 	// upgradeHeader is the value of the Upgrade HTTP header used to
 	// upgradeHeader is the value of the Upgrade HTTP header used to
 	// indicate the Tailscale control protocol.
 	// indicate the Tailscale control protocol.
@@ -18,3 +28,58 @@ const (
 	// to do the protocol switch is located.
 	// to do the protocol switch is located.
 	serverUpgradePath = "/ts2021"
 	serverUpgradePath = "/ts2021"
 )
 )
+
+// Dialer contains configuration on how to dial the Tailscale control server.
+type Dialer struct {
+	// Hostname is the hostname to connect to, with no port number.
+	//
+	// This field is required.
+	Hostname string
+
+	// MachineKey contains the current machine's private key.
+	//
+	// This field is required.
+	MachineKey key.MachinePrivate
+
+	// ControlKey contains the expected public key for the control server.
+	//
+	// This field is required.
+	ControlKey key.MachinePublic
+
+	// ProtocolVersion is the expected protocol version to negotiate.
+	//
+	// This field is required.
+	ProtocolVersion uint16
+
+	// HTTPPort is the port number to use when making a HTTP connection.
+	//
+	// If not specified, this defaults to port 80.
+	HTTPPort string
+
+	// HTTPSPort is the port number to use when making a HTTPS connection.
+	//
+	// If not specified, this defaults to port 443.
+	HTTPSPort string
+
+	// Dialer is the dialer used to make outbound connections.
+	//
+	// If not specified, this defaults to net.Dialer.DialContext.
+	Dialer dnscache.DialContextFunc
+
+	// Logf, if set, is a logging function to use; if unset, logs are
+	// dropped.
+	Logf logger.Logf
+
+	proxyFunc func(*http.Request) (*url.URL, error) // or nil
+
+	// For tests only
+	insecureTLS       bool
+	testFallbackDelay time.Duration
+}
+
+func strDef(v1, v2 string) string {
+	if v1 != "" {
+		return v1
+	}
+	return v2
+}

+ 9 - 8
control/controlhttp/http_test.go

@@ -170,15 +170,16 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
 		defer cancel()
 		defer cancel()
 	}
 	}
 
 
-	a := dialParams{
-		host:              "localhost",
-		httpPort:          strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
-		httpsPort:         strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
-		machineKey:        client,
-		controlKey:        server.Public(),
-		version:           testProtocolVersion,
+	a := &Dialer{
+		Hostname:          "localhost",
+		HTTPPort:          strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
+		HTTPSPort:         strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
+		MachineKey:        client,
+		ControlKey:        server.Public(),
+		ProtocolVersion:   testProtocolVersion,
+		Dialer:            new(tsdial.Dialer).SystemDial,
+		Logf:              t.Logf,
 		insecureTLS:       true,
 		insecureTLS:       true,
-		dialer:            new(tsdial.Dialer).SystemDial,
 		testFallbackDelay: 50 * time.Millisecond,
 		testFallbackDelay: 50 * time.Millisecond,
 	}
 	}