|
|
@@ -50,21 +50,23 @@ func protocolVersionPrologue(version uint16) []byte {
|
|
|
return strconv.AppendUint(ret, uint64(version), 10)
|
|
|
}
|
|
|
|
|
|
-// Client initiates a control client handshake, returning the resulting
|
|
|
-// control connection.
|
|
|
+// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
|
|
|
+// is assumed to have already sent the client>server handshake
|
|
|
+// initiation message.
|
|
|
+type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
|
|
|
+
|
|
|
+// ClientDeferred initiates a control client handshake, returning the
|
|
|
+// initial message to send to the server and a continuation to
|
|
|
+// finalize the handshake.
|
|
|
//
|
|
|
-// The context deadline, if any, covers the entire handshaking
|
|
|
-// process. Any preexisting Conn deadline is removed.
|
|
|
-func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
|
|
|
- if deadline, ok := ctx.Deadline(); ok {
|
|
|
- if err := conn.SetDeadline(deadline); err != nil {
|
|
|
- return nil, fmt.Errorf("setting conn deadline: %w", err)
|
|
|
- }
|
|
|
- defer func() {
|
|
|
- conn.SetDeadline(time.Time{})
|
|
|
- }()
|
|
|
- }
|
|
|
-
|
|
|
+// ClientDeferred is split in this way for RTT reduction: we run this
|
|
|
+// protocol after negotiating a protocol switch from HTTP/HTTPS. If we
|
|
|
+// completely serialized the negotiation followed by the handshake,
|
|
|
+// we'd pay an extra RTT to transmit the handshake initiation after
|
|
|
+// protocol switching. By splitting the handshake into an initial
|
|
|
+// message and a continuation, we can embed the handshake initiation
|
|
|
+// into the HTTP protocol switching request and avoid a bit of delay.
|
|
|
+func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
|
|
|
var s symmetricState
|
|
|
s.Initialize()
|
|
|
|
|
|
@@ -83,18 +85,53 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
|
|
|
s.MixHash(machineEphemeralPub.UntypedBytes())
|
|
|
cipher, err := s.MixDH(machineEphemeral, controlKey)
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("computing es: %w", err)
|
|
|
+ return nil, nil, fmt.Errorf("computing es: %w", err)
|
|
|
}
|
|
|
machineKeyPub := machineKey.Public()
|
|
|
s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
|
|
|
cipher, err = s.MixDH(machineKey, controlKey)
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("computing ss: %w", err)
|
|
|
+ return nil, nil, fmt.Errorf("computing ss: %w", err)
|
|
|
}
|
|
|
s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
|
|
|
|
|
|
- if _, err := conn.Write(init[:]); err != nil {
|
|
|
- return nil, fmt.Errorf("writing initiation: %w", err)
|
|
|
+ cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
|
|
|
+ return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey)
|
|
|
+ }
|
|
|
+ return init[:], cont, nil
|
|
|
+}
|
|
|
+
|
|
|
+// Client wraps ClientDeferred and immediately invokes the returned
|
|
|
+// continuation with conn.
|
|
|
+//
|
|
|
+// This is a helper for when you don't need the fancy
|
|
|
+// continuation-style handshake, and just want to synchronously
|
|
|
+// upgrade a net.Conn to a secure transport.
|
|
|
+func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
|
|
|
+ init, cont, err := ClientDeferred(machineKey, controlKey)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if _, err := conn.Write(init); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return cont(ctx, conn)
|
|
|
+}
|
|
|
+
|
|
|
+func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
|
|
|
+ // No matter what, this function can only run once per s. Ensure
|
|
|
+ // attempted reuse causes a panic.
|
|
|
+ defer func() {
|
|
|
+ s.finished = true
|
|
|
+ }()
|
|
|
+
|
|
|
+ if deadline, ok := ctx.Deadline(); ok {
|
|
|
+ if err := conn.SetDeadline(deadline); err != nil {
|
|
|
+ return nil, fmt.Errorf("setting conn deadline: %w", err)
|
|
|
+ }
|
|
|
+ defer func() {
|
|
|
+ conn.SetDeadline(time.Time{})
|
|
|
+ }()
|
|
|
}
|
|
|
|
|
|
// Read in the payload and look for errors/protocol violations from the server.
|
|
|
@@ -122,10 +159,10 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
|
|
|
// <- e, ee, se
|
|
|
controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
|
|
|
s.MixHash(controlEphemeralPub.UntypedBytes())
|
|
|
- if _, err = s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
|
|
|
+ if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
|
|
|
return nil, fmt.Errorf("computing ee: %w", err)
|
|
|
}
|
|
|
- cipher, err = s.MixDH(machineKey, controlEphemeralPub)
|
|
|
+ cipher, err := s.MixDH(machineKey, controlEphemeralPub)
|
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("computing se: %w", err)
|
|
|
}
|
|
|
@@ -156,9 +193,13 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
|
|
|
// Server initiates a control server handshake, returning the resulting
|
|
|
// control connection.
|
|
|
//
|
|
|
+// optionalInit can be the client's initial handshake message as
|
|
|
+// returned by ClientDeferred, or nil in which case the initial
|
|
|
+// message is read from conn.
|
|
|
+//
|
|
|
// The context deadline, if any, covers the entire handshaking
|
|
|
// process.
|
|
|
-func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (*Conn, error) {
|
|
|
+func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
|
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
|
if err := conn.SetDeadline(deadline); err != nil {
|
|
|
return nil, fmt.Errorf("setting conn deadline: %w", err)
|
|
|
@@ -190,7 +231,12 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (
|
|
|
s.Initialize()
|
|
|
|
|
|
var init initiationMessage
|
|
|
- if _, err := io.ReadFull(conn, init.Header()); err != nil {
|
|
|
+ if optionalInit != nil {
|
|
|
+ if len(optionalInit) != len(init) {
|
|
|
+ return nil, sendErr("wrong handshake initiation size")
|
|
|
+ }
|
|
|
+ copy(init[:], optionalInit)
|
|
|
+ } else if _, err := io.ReadFull(conn, init.Header()); err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
if init.Version() != protocolVersion {
|
|
|
@@ -202,8 +248,11 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (
|
|
|
if init.Length() != len(init.Payload()) {
|
|
|
return nil, sendErr("wrong handshake initiation length")
|
|
|
}
|
|
|
- if _, err := io.ReadFull(conn, init.Payload()); err != nil {
|
|
|
- return nil, err
|
|
|
+ // if optionalInit was provided, we have the payload already.
|
|
|
+ if optionalInit == nil {
|
|
|
+ if _, err := io.ReadFull(conn, init.Payload()); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// prologue. Can only do this once we at least think the client is
|