Browse Source

control/controlclient: export NoiseClient

This allows reusing the NoiseClient in other repos without having to reimplement the earlyPayload logic.

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 3 years ago
parent
commit
a413fa4f85
3 changed files with 43 additions and 33 deletions
  1. 7 20
      control/controlclient/direct.go
  2. 35 12
      control/controlclient/noise.go
  3. 1 1
      control/controlclient/noise_test.go

+ 7 - 20
control/controlclient/direct.go

@@ -84,8 +84,8 @@ type Direct struct {
 	serverKey      key.MachinePublic // original ("legacy") nacl crypto_box-based public key
 	serverNoiseKey key.MachinePublic
 
-	sfGroup     singleflight.Group[struct{}, *noiseClient] // protects noiseClient creation.
-	noiseClient *noiseClient
+	sfGroup     singleflight.Group[struct{}, *NoiseClient] // protects noiseClient creation.
+	noiseClient *NoiseClient
 
 	persist       persist.Persist
 	authKey       string
@@ -262,7 +262,7 @@ func NewDirect(opts Options) (*Direct, error) {
 		}
 	}
 	if opts.NoiseTestClient != nil {
-		c.noiseClient = &noiseClient{
+		c.noiseClient = &NoiseClient{
 			Client: opts.NoiseTestClient,
 		}
 		c.serverNoiseKey = key.NewMachine().Public() // prevent early error before hitting test client
@@ -1470,7 +1470,7 @@ func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<-
 }
 
 // getNoiseClient returns the noise client, creating one if one doesn't exist.
-func (c *Direct) getNoiseClient() (*noiseClient, error) {
+func (c *Direct) getNoiseClient() (*NoiseClient, error) {
 	c.mu.Lock()
 	serverNoiseKey := c.serverNoiseKey
 	nc := c.noiseClient
@@ -1485,13 +1485,13 @@ func (c *Direct) getNoiseClient() (*noiseClient, error) {
 	if c.dialPlan != nil {
 		dp = c.dialPlan.Load
 	}
-	nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*noiseClient, error) {
+	nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*NoiseClient, error) {
 		k, err := c.getMachinePrivKey()
 		if err != nil {
 			return nil, err
 		}
 		c.logf("creating new noise client")
-		nc, err := newNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer, dp)
+		nc, err := NewNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer, dp)
 		if err != nil {
 			return nil, err
 		}
@@ -1618,20 +1618,7 @@ func (c *Direct) GetSingleUseNoiseRoundTripper(ctx context.Context) (http.RoundT
 	if err != nil {
 		return nil, nil, err
 	}
-	for tries := 0; tries < 3; tries++ {
-		conn, err := nc.getConn(ctx)
-		if err != nil {
-			return nil, nil, err
-		}
-		earlyPayloadMaybeNil, err := conn.getEarlyPayload(ctx)
-		if err != nil {
-			return nil, nil, err
-		}
-		if conn.h2cc.ReserveNewRequest() {
-			return conn, earlyPayloadMaybeNil, nil
-		}
-	}
-	return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection")
+	return nc.GetSingleUseRoundTripper(ctx)
 }
 
 // doPingerPing sends a Ping to pr.IP using pinger, and sends an http request back to

+ 35 - 12
control/controlclient/noise.go

@@ -35,7 +35,7 @@ import (
 type noiseConn struct {
 	*controlbase.Conn
 	id   int
-	pool *noiseClient
+	pool *NoiseClient
 	h2cc *http2.ClientConn
 
 	readHeaderOnce    sync.Once     // guards init of reader field
@@ -135,9 +135,9 @@ func (c *noiseConn) Close() error {
 	return nil
 }
 
-// noiseClient provides a http.Client to connect to tailcontrol over
+// NoiseClient provides a http.Client to connect to tailcontrol over
 // the ts2021 protocol.
-type noiseClient struct {
+type NoiseClient struct {
 	// Client is an HTTP client to talk to the coordination server.
 	// It automatically makes a new Noise connection as needed.
 	// It does not support node key proofs. To do that, call
@@ -175,11 +175,11 @@ type noiseClient struct {
 	connPool map[int]*noiseConn // active connections not yet closed; see noiseConn.Close
 }
 
-// newNoiseClient returns a new noiseClient for the provided server and machine key.
+// NewNoiseClient returns a new noiseClient for the provided server and machine key.
 // serverURL is of the form https://<host>:<port> (no trailing slash).
 //
 // dialPlan may be nil
-func newNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*noiseClient, error) {
+func NewNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*NoiseClient, error) {
 	u, err := url.Parse(serverURL)
 	if err != nil {
 		return nil, err
@@ -200,7 +200,7 @@ func newNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic,
 		httpPort = "80"
 		httpsPort = "443"
 	}
-	np := &noiseClient{
+	np := &NoiseClient{
 		serverPubKey: serverPubKey,
 		privKey:      privKey,
 		host:         u.Hostname(),
@@ -227,7 +227,30 @@ func newNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic,
 	return np, nil
 }
 
-func (nc *noiseClient) getConn(ctx context.Context) (*noiseConn, error) {
+// GetSingleUseRoundTripper returns a RoundTripper that can be only be used once
+// (and must be used once) to make a single HTTP request over the noise channel
+// to the coordination server.
+//
+// In addition to the RoundTripper, it returns the HTTP/2 channel's early noise
+// payload, if any.
+func (nc *NoiseClient) GetSingleUseRoundTripper(ctx context.Context) (http.RoundTripper, *tailcfg.EarlyNoise, error) {
+	for tries := 0; tries < 3; tries++ {
+		conn, err := nc.getConn(ctx)
+		if err != nil {
+			return nil, nil, err
+		}
+		earlyPayloadMaybeNil, err := conn.getEarlyPayload(ctx)
+		if err != nil {
+			return nil, nil, err
+		}
+		if conn.h2cc.ReserveNewRequest() {
+			return conn, earlyPayloadMaybeNil, nil
+		}
+	}
+	return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection")
+}
+
+func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) {
 	nc.mu.Lock()
 	if last := nc.last; last != nil && last.canTakeNewRequest() {
 		nc.mu.Unlock()
@@ -242,7 +265,7 @@ func (nc *noiseClient) getConn(ctx context.Context) (*noiseConn, error) {
 	return conn, nil
 }
 
-func (nc *noiseClient) RoundTrip(req *http.Request) (*http.Response, error) {
+func (nc *NoiseClient) RoundTrip(req *http.Request) (*http.Response, error) {
 	ctx := req.Context()
 	conn, err := nc.getConn(ctx)
 	if err != nil {
@@ -253,7 +276,7 @@ func (nc *noiseClient) RoundTrip(req *http.Request) (*http.Response, error) {
 
 // connClosed removes the connection with the provided ID from the pool
 // of active connections.
-func (nc *noiseClient) connClosed(id int) {
+func (nc *NoiseClient) connClosed(id int) {
 	nc.mu.Lock()
 	defer nc.mu.Unlock()
 	conn := nc.connPool[id]
@@ -267,7 +290,7 @@ func (nc *noiseClient) connClosed(id int) {
 
 // Close closes all the underlying noise connections.
 // It is a no-op and returns nil if the connection is already closed.
-func (nc *noiseClient) Close() error {
+func (nc *NoiseClient) Close() error {
 	nc.mu.Lock()
 	conns := nc.connPool
 	nc.connPool = nil
@@ -284,7 +307,7 @@ func (nc *noiseClient) Close() error {
 
 // dial opens a new connection to tailcontrol, fetching the server noise key
 // if not cached.
-func (nc *noiseClient) dial() (*noiseConn, error) {
+func (nc *NoiseClient) dial() (*noiseConn, error) {
 	nc.mu.Lock()
 	connID := nc.nextID
 	nc.nextID++
@@ -369,7 +392,7 @@ func (nc *noiseClient) dial() (*noiseConn, error) {
 	return ncc, nil
 }
 
-func (nc *noiseClient) post(ctx context.Context, path string, body any) (*http.Response, error) {
+func (nc *NoiseClient) post(ctx context.Context, path string, body any) (*http.Response, error) {
 	jbody, err := json.Marshal(body)
 	if err != nil {
 		return nil, err

+ 1 - 1
control/controlclient/noise_test.go

@@ -75,7 +75,7 @@ func (tt noiseClientTest) run(t *testing.T) {
 	defer hs.Close()
 
 	dialer := new(tsdial.Dialer)
-	nc, err := newNoiseClient(clientPrivate, serverPrivate.Public(), hs.URL, dialer, nil)
+	nc, err := NewNoiseClient(clientPrivate, serverPrivate.Public(), hs.URL, dialer, nil)
 	if err != nil {
 		t.Fatal(err)
 	}