Ver código fonte

control/controlbase: don't enforce a max protocol version at handshake time.

Doing so makes development unpleasant, because we have to first break the
client by bumping to a version the control server rejects, then upgrade
the control server to make it accept the new version.

This strict rejection at handshake time is only necessary if we want to
blocklist some vulnerable protocol versions in the future. So, switch
to a default-permissive stance: until we have such a version that we
have to eagerly block early, we'll accept whatever version the client
presents, and leave it to the user of controlbase.Conn to make decisions
based on that version.

Noise still enforces that the client and server *agree* on what protocol
version is being used, and the control server still has the option to
finish the handshake and then hang up with an in-noise error, rather
than abort at the handshake level.

Updates #3488

Signed-off-by: David Anderson <[email protected]>
David Anderson 3 anos atrás
pai
commit
f570372b4d

+ 2 - 2
control/controlbase/conn_test.go

@@ -206,7 +206,7 @@ func TestConnStd(t *testing.T) {
 		serverErr := make(chan error, 1)
 		serverErr := make(chan error, 1)
 		go func() {
 		go func() {
 			var err error
 			var err error
-			c2, err = Server(context.Background(), s2, controlKey, testProtocolVersion, nil)
+			c2, err = Server(context.Background(), s2, controlKey, nil)
 			serverErr <- err
 			serverErr <- err
 		}()
 		}()
 		c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
 		c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
@@ -398,7 +398,7 @@ func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn)
 	)
 	)
 	go func() {
 	go func() {
 		var err error
 		var err error
-		server, err = Server(context.Background(), serverConn, controlKey, testProtocolVersion, nil)
+		server, err = Server(context.Background(), serverConn, controlKey, nil)
 		serverErr <- err
 		serverErr <- err
 	}()
 	}()
 
 

+ 6 - 16
control/controlbase/handshake.go

@@ -193,19 +193,13 @@ func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricSta
 // Server initiates a control server handshake, returning the resulting
 // Server initiates a control server handshake, returning the resulting
 // control connection.
 // control connection.
 //
 //
-// maxSupportedVersion is the highest handshake version the server is
-// willing to handshake with. The server will handshake with any
-// version from 0 to maxSupportedVersion inclusive, the caller should
-// inspect conn.Version() to determine what version of the handshake
-// was executed.
-//
 // optionalInit can be the client's initial handshake message as
 // optionalInit can be the client's initial handshake message as
 // returned by ClientDeferred, or nil in which case the initial
 // returned by ClientDeferred, or nil in which case the initial
 // message is read from conn.
 // message is read from conn.
 //
 //
 // The context deadline, if any, covers the entire handshaking
 // The context deadline, if any, covers the entire handshaking
 // process.
 // process.
-func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, maxSupportedVersion uint16, optionalInit []byte) (*Conn, error) {
+func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
 	if deadline, ok := ctx.Deadline(); ok {
 	if deadline, ok := ctx.Deadline(); ok {
 		if err := conn.SetDeadline(deadline); err != nil {
 		if err := conn.SetDeadline(deadline); err != nil {
 			return nil, fmt.Errorf("setting conn deadline: %w", err)
 			return nil, fmt.Errorf("setting conn deadline: %w", err)
@@ -245,15 +239,11 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, m
 	} else if _, err := io.ReadFull(conn, init.Header()); err != nil {
 	} else if _, err := io.ReadFull(conn, init.Header()); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	// Currently, these versions exclusively indicate what the upper
-	// RPC protocol understands, the Noise handshake is exactly the
-	// same in all versions. If that ever changes, this check will
-	// need to become more complex to handle different kinds of
-	// handshake.
-	if init.Version() > maxSupportedVersion {
-		return nil, sendErr("unsupported handshake version")
-	}
-	// Just a rename to make it more obvious what the value is
+	// Just a rename to make it more obvious what the value is. In the
+	// current implementation we don't need to block any protocol
+	// versions at this layer, it's safe to let the handshake proceed
+	// and then let the caller make decisions based on the agreed-upon
+	// protocol version.
 	clientVersion := init.Version()
 	clientVersion := init.Version()
 	if init.Type() != msgTypeInitiation {
 	if init.Type() != msgTypeInitiation {
 		return nil, sendErr("unexpected handshake message type")
 		return nil, sendErr("unexpected handshake message type")

+ 6 - 6
control/controlbase/handshake_test.go

@@ -26,7 +26,7 @@ func TestHandshake(t *testing.T) {
 	)
 	)
 	go func() {
 	go func() {
 		var err error
 		var err error
-		server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
+		server, err = Server(context.Background(), serverConn, serverKey, nil)
 		serverErr <- err
 		serverErr <- err
 	}()
 	}()
 
 
@@ -78,7 +78,7 @@ func TestNoReuse(t *testing.T) {
 		)
 		)
 		go func() {
 		go func() {
 			var err error
 			var err error
-			server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
+			server, err = Server(context.Background(), serverConn, serverKey, nil)
 			serverErr <- err
 			serverErr <- err
 		}()
 		}()
 
 
@@ -172,7 +172,7 @@ func TestTampering(t *testing.T) {
 			serverErr             = make(chan error, 1)
 			serverErr             = make(chan error, 1)
 		)
 		)
 		go func() {
 		go func() {
-			_, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
+			_, err := Server(context.Background(), serverConn, serverKey, nil)
 			// If the server failed, we have to close the Conn to
 			// If the server failed, we have to close the Conn to
 			// unblock the client.
 			// unblock the client.
 			if err != nil {
 			if err != nil {
@@ -200,7 +200,7 @@ func TestTampering(t *testing.T) {
 			serverErr             = make(chan error, 1)
 			serverErr             = make(chan error, 1)
 		)
 		)
 		go func() {
 		go func() {
-			_, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
+			_, err := Server(context.Background(), serverConn, serverKey, nil)
 			serverErr <- err
 			serverErr <- err
 		}()
 		}()
 
 
@@ -225,7 +225,7 @@ func TestTampering(t *testing.T) {
 			serverErr             = make(chan error, 1)
 			serverErr             = make(chan error, 1)
 		)
 		)
 		go func() {
 		go func() {
-			server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
+			server, err := Server(context.Background(), serverConn, serverKey, nil)
 			serverErr <- err
 			serverErr <- err
 			_, err = io.WriteString(server, strings.Repeat("a", 14))
 			_, err = io.WriteString(server, strings.Repeat("a", 14))
 			serverErr <- err
 			serverErr <- err
@@ -266,7 +266,7 @@ func TestTampering(t *testing.T) {
 			serverErr             = make(chan error, 1)
 			serverErr             = make(chan error, 1)
 		)
 		)
 		go func() {
 		go func() {
-			server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
+			server, err := Server(context.Background(), serverConn, serverKey, nil)
 			serverErr <- err
 			serverErr <- err
 			var bs [100]byte
 			var bs [100]byte
 			// The server needs a timeout if the tampering is hitting the length header.
 			// The server needs a timeout if the tampering is hitting the length header.

+ 1 - 1
control/controlbase/interop_test.go

@@ -29,7 +29,7 @@ func TestInteropClient(t *testing.T) {
 	)
 	)
 
 
 	go func() {
 	go func() {
-		server, err := Server(context.Background(), s2, controlKey, testProtocolVersion, nil)
+		server, err := Server(context.Background(), s2, controlKey, nil)
 		serverErr <- err
 		serverErr <- err
 		if err != nil {
 		if err != nil {
 			return
 			return

+ 1 - 1
control/controlhttp/http_test.go

@@ -107,7 +107,7 @@ func testControlHTTP(t *testing.T, proxy proxy) {
 	const testProtocolVersion = 1
 	const testProtocolVersion = 1
 	sch := make(chan serverResult, 1)
 	sch := make(chan serverResult, 1)
 	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		conn, err := AcceptHTTP(context.Background(), w, r, server, testProtocolVersion)
+		conn, err := AcceptHTTP(context.Background(), w, r, server)
 		if err != nil {
 		if err != nil {
 			log.Print(err)
 			log.Print(err)
 		}
 		}

+ 2 - 2
control/controlhttp/server.go

@@ -21,7 +21,7 @@ import (
 //
 //
 // AcceptHTTP always writes an HTTP response to w. The caller must not
 // AcceptHTTP always writes an HTTP response to w. The caller must not
 // attempt their own response after calling AcceptHTTP.
 // attempt their own response after calling AcceptHTTP.
-func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, maxSupportedVersion uint16) (*controlbase.Conn, error) {
+func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) {
 	next := r.Header.Get("Upgrade")
 	next := r.Header.Get("Upgrade")
 	if next == "" {
 	if next == "" {
 		http.Error(w, "missing next protocol", http.StatusBadRequest)
 		http.Error(w, "missing next protocol", http.StatusBadRequest)
@@ -63,7 +63,7 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
 	}
 	}
 	conn = netutil.NewDrainBufConn(conn, brw.Reader)
 	conn = netutil.NewDrainBufConn(conn, brw.Reader)
 
 
-	nc, err := controlbase.Server(ctx, conn, private, maxSupportedVersion, init)
+	nc, err := controlbase.Server(ctx, conn, private, init)
 	if err != nil {
 	if err != nil {
 		conn.Close()
 		conn.Close()
 		return nil, fmt.Errorf("noise handshake failed: %w", err)
 		return nil, fmt.Errorf("noise handshake failed: %w", err)