Browse Source

feature/tpm: protect all TPM handle operations with a mutex (#17708)

In particular on Windows, the `transport.TPMCloser` we get is not safe
for concurrent use. This is especially noticeable because
`tpm.attestationKey.Clone` uses the same open handle as the original
key. So wrap the operations on ak.tpm with a mutex and make a deep copy
with a new connection in Clone.

Updates #15830
Updates #17662
Updates #17644

Signed-off-by: Andrew Lytvynov <[email protected]>
Andrew Lytvynov 4 months ago
parent
commit
f522b9dbb7
2 changed files with 100 additions and 6 deletions
  1. 34 6
      feature/tpm/attestation.go
  2. 66 0
      feature/tpm/attestation_test.go

+ 34 - 6
feature/tpm/attestation.go

@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"io"
 	"log"
+	"sync"
 
 	"github.com/google/go-tpm/tpm2"
 	"github.com/google/go-tpm/tpm2/transport"
@@ -19,7 +20,8 @@ import (
 )
 
 type attestationKey struct {
-	tpm transport.TPMCloser
+	tpmMu sync.Mutex
+	tpm   transport.TPMCloser
 	// private and public parts of the TPM key as returned from tpm2.Create.
 	// These are used for serialization.
 	tpmPrivate tpm2.TPM2BPrivate
@@ -144,7 +146,7 @@ type attestationKeySerialized struct {
 
 // MarshalJSON implements json.Marshaler.
 func (ak *attestationKey) MarshalJSON() ([]byte, error) {
-	if ak == nil || ak.IsZero() {
+	if ak == nil || len(ak.tpmPublic.Bytes()) == 0 || len(ak.tpmPrivate.Buffer) == 0 {
 		return []byte("null"), nil
 	}
 	return json.Marshal(attestationKeySerialized{
@@ -163,6 +165,13 @@ func (ak *attestationKey) UnmarshalJSON(data []byte) (retErr error) {
 	ak.tpmPrivate = tpm2.TPM2BPrivate{Buffer: aks.TPMPrivate}
 	ak.tpmPublic = tpm2.BytesAs2B[tpm2.TPMTPublic, *tpm2.TPMTPublic](aks.TPMPublic)
 
+	ak.tpmMu.Lock()
+	defer ak.tpmMu.Unlock()
+	if ak.tpm != nil {
+		ak.tpm.Close()
+		ak.tpm = nil
+	}
+
 	tpm, err := open()
 	if err != nil {
 		return key.ErrUnsupported
@@ -182,6 +191,9 @@ func (ak *attestationKey) Public() crypto.PublicKey {
 }
 
 func (ak *attestationKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
+	ak.tpmMu.Lock()
+	defer ak.tpmMu.Unlock()
+
 	if !ak.loaded() {
 		return nil, errors.New("tpm2 attestation key is not loaded during Sign")
 	}
@@ -247,6 +259,9 @@ func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
 }
 
 func (ak *attestationKey) Close() error {
+	ak.tpmMu.Lock()
+	defer ak.tpmMu.Unlock()
+
 	var errs []error
 	if ak.handle != nil && ak.tpm != nil {
 		_, err := tpm2.FlushContext{FlushHandle: ak.handle.Handle}.Execute(ak.tpm)
@@ -262,18 +277,31 @@ func (ak *attestationKey) Clone() key.HardwareAttestationKey {
 	if ak == nil {
 		return nil
 	}
-	return &attestationKey{
-		tpm:        ak.tpm,
+
+	tpm, err := open()
+	if err != nil {
+		log.Printf("[unexpected] failed to open a TPM connection in feature/tpm.attestationKey.Clone: %v", err)
+		return nil
+	}
+	akc := &attestationKey{
+		tpm:        tpm,
 		tpmPrivate: ak.tpmPrivate,
 		tpmPublic:  ak.tpmPublic,
-		handle:     ak.handle,
-		pub:        ak.pub,
 	}
+	if err := akc.load(); err != nil {
+		log.Printf("[unexpected] failed to load TPM key in feature/tpm.attestationKey.Clone: %v", err)
+		tpm.Close()
+		return nil
+	}
+	return akc
 }
 
 func (ak *attestationKey) IsZero() bool {
 	if ak == nil {
 		return true
 	}
+
+	ak.tpmMu.Lock()
+	defer ak.tpmMu.Unlock()
 	return !ak.loaded()
 }

+ 66 - 0
feature/tpm/attestation_test.go

@@ -10,6 +10,8 @@ import (
 	"crypto/rand"
 	"crypto/sha256"
 	"encoding/json"
+	"runtime"
+	"sync"
 	"testing"
 )
 
@@ -62,6 +64,37 @@ func TestAttestationKeySign(t *testing.T) {
 	}
 }
 
+func TestAttestationKeySignConcurrent(t *testing.T) {
+	skipWithoutTPM(t)
+	ak, err := newAttestationKey()
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Cleanup(func() {
+		if err := ak.Close(); err != nil {
+			t.Errorf("ak.Close: %v", err)
+		}
+	})
+
+	data := []byte("secrets")
+	digest := sha256.Sum256(data)
+
+	wg := sync.WaitGroup{}
+	for range runtime.GOMAXPROCS(-1) {
+		wg.Go(func() {
+			// Check signature/validation round trip.
+			sig, err := ak.Sign(rand.Reader, digest[:], crypto.SHA256)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !ecdsa.VerifyASN1(ak.Public().(*ecdsa.PublicKey), digest[:], sig) {
+				t.Errorf("ecdsa.VerifyASN1 failed")
+			}
+		})
+	}
+	wg.Wait()
+}
+
 func TestAttestationKeyUnmarshal(t *testing.T) {
 	skipWithoutTPM(t)
 	ak, err := newAttestationKey()
@@ -96,3 +129,36 @@ func TestAttestationKeyUnmarshal(t *testing.T) {
 		t.Error("unmarshalled public key is not the same as the original public key")
 	}
 }
+
+func TestAttestationKeyClone(t *testing.T) {
+	skipWithoutTPM(t)
+	ak, err := newAttestationKey()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	ak2 := ak.Clone()
+	if ak2 == nil {
+		t.Fatal("Clone failed")
+	}
+	t.Cleanup(func() {
+		if err := ak2.Close(); err != nil {
+			t.Errorf("ak2.Close: %v", err)
+		}
+	})
+	// Close the original key, ak2 should remain open and usable.
+	if err := ak.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	data := []byte("secrets")
+	digest := sha256.Sum256(data)
+	// Check signature/validation round trip using cloned key.
+	sig, err := ak2.Sign(rand.Reader, digest[:], crypto.SHA256)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !ecdsa.VerifyASN1(ak2.Public().(*ecdsa.PublicKey), digest[:], sig) {
+		t.Errorf("ecdsa.VerifyASN1 failed")
+	}
+}