|
@@ -0,0 +1,2026 @@
|
|
|
+package dtls
|
|
|
+
|
|
|
+import (
|
|
|
+ "bytes"
|
|
|
+ "context"
|
|
|
+ "crypto/rand"
|
|
|
+ "crypto/tls"
|
|
|
+ "crypto/x509"
|
|
|
+ "errors"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "net"
|
|
|
+ "sync"
|
|
|
+ "sync/atomic"
|
|
|
+ "testing"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "github.com/pion/dtls/v2/internal/ciphersuite"
|
|
|
+ "github.com/pion/dtls/v2/internal/net/dpipe"
|
|
|
+ "github.com/pion/dtls/v2/pkg/crypto/elliptic"
|
|
|
+ "github.com/pion/dtls/v2/pkg/crypto/hash"
|
|
|
+ "github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
|
|
+ "github.com/pion/dtls/v2/pkg/crypto/signature"
|
|
|
+ "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
|
|
|
+ "github.com/pion/dtls/v2/pkg/protocol"
|
|
|
+ "github.com/pion/dtls/v2/pkg/protocol/alert"
|
|
|
+ "github.com/pion/dtls/v2/pkg/protocol/extension"
|
|
|
+ "github.com/pion/dtls/v2/pkg/protocol/handshake"
|
|
|
+ "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
|
|
+ "github.com/pion/transport/test"
|
|
|
+)
|
|
|
+
|
|
|
+var (
|
|
|
+ errTestPSKInvalidIdentity = errors.New("TestPSK: Server got invalid identity")
|
|
|
+ errPSKRejected = errors.New("PSK Rejected")
|
|
|
+ errNotExpectedChain = errors.New("not expected chain")
|
|
|
+ errExpecedChain = errors.New("expected chain")
|
|
|
+ errWrongCert = errors.New("wrong cert")
|
|
|
+)
|
|
|
+
|
|
|
+func TestStressDuplex(t *testing.T) {
|
|
|
+ // Limit runtime in case of deadlocks
|
|
|
+ lim := test.TimeOut(time.Second * 20)
|
|
|
+ defer lim.Stop()
|
|
|
+
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ // Run the test
|
|
|
+ stressDuplex(t)
|
|
|
+}
|
|
|
+
|
|
|
+func stressDuplex(t *testing.T) {
|
|
|
+ ca, cb, err := pipeMemory()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ defer func() {
|
|
|
+ err = ca.Close()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ err = cb.Close()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ opt := test.Options{
|
|
|
+ MsgSize: 2048,
|
|
|
+ MsgCount: 100,
|
|
|
+ }
|
|
|
+
|
|
|
+ err = test.StressDuplex(ca, cb, opt)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestRoutineLeakOnClose(t *testing.T) {
|
|
|
+ // Limit runtime in case of deadlocks
|
|
|
+ lim := test.TimeOut(5 * time.Second)
|
|
|
+ defer lim.Stop()
|
|
|
+
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ ca, cb, err := pipeMemory()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if _, err := ca.Write(make([]byte, 100)); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ if err := cb.Close(); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ if err := ca.Close(); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ // Packet is sent, but not read.
|
|
|
+ // inboundLoop routine should not be leaked.
|
|
|
+}
|
|
|
+
|
|
|
+func TestReadWriteDeadline(t *testing.T) {
|
|
|
+ // Limit runtime in case of deadlocks
|
|
|
+ lim := test.TimeOut(5 * time.Second)
|
|
|
+ defer lim.Stop()
|
|
|
+
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ ca, cb, err := pipeMemory()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := ca.SetDeadline(time.Unix(0, 1)); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ _, werr := ca.Write(make([]byte, 100))
|
|
|
+ if e, ok := werr.(net.Error); ok {
|
|
|
+ if !e.Timeout() {
|
|
|
+ t.Error("Deadline exceeded Write must return Timeout error")
|
|
|
+ }
|
|
|
+ if !e.Temporary() {
|
|
|
+ t.Error("Deadline exceeded Write must return Temporary error")
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ t.Error("Write must return net.Error error")
|
|
|
+ }
|
|
|
+ _, rerr := ca.Read(make([]byte, 100))
|
|
|
+ if e, ok := rerr.(net.Error); ok {
|
|
|
+ if !e.Timeout() {
|
|
|
+ t.Error("Deadline exceeded Read must return Timeout error")
|
|
|
+ }
|
|
|
+ if !e.Temporary() {
|
|
|
+ t.Error("Deadline exceeded Read must return Temporary error")
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ t.Error("Read must return net.Error error")
|
|
|
+ }
|
|
|
+ if err := ca.SetDeadline(time.Time{}); err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := ca.Close(); err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ if err := cb.Close(); err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if _, err := ca.Write(make([]byte, 100)); !errors.Is(err, ErrConnClosed) {
|
|
|
+ t.Errorf("Write must return %v after close, got %v", ErrConnClosed, err)
|
|
|
+ }
|
|
|
+ if _, err := ca.Read(make([]byte, 100)); !errors.Is(err, io.EOF) {
|
|
|
+ t.Errorf("Read must return %v after close, got %v", io.EOF, err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestSequenceNumberOverflow(t *testing.T) {
|
|
|
+ // Limit runtime in case of deadlocks
|
|
|
+ lim := test.TimeOut(5 * time.Second)
|
|
|
+ defer lim.Stop()
|
|
|
+
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ t.Run("ApplicationData", func(t *testing.T) {
|
|
|
+ ca, cb, err := pipeMemory()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ atomic.StoreUint64(&ca.state.localSequenceNumber[1], recordlayer.MaxSequenceNumber)
|
|
|
+ if _, werr := ca.Write(make([]byte, 100)); werr != nil {
|
|
|
+ t.Errorf("Write must send message with maximum sequence number, but errord: %v", werr)
|
|
|
+ }
|
|
|
+ if _, werr := ca.Write(make([]byte, 100)); !errors.Is(werr, errSequenceNumberOverflow) {
|
|
|
+ t.Errorf("Write must abandonsend message with maximum sequence number, but errord: %v", werr)
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := ca.Close(); err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ if err := cb.Close(); err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ t.Run("Handshake", func(t *testing.T) {
|
|
|
+ ca, cb, err := pipeMemory()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ atomic.StoreUint64(&ca.state.localSequenceNumber[0], recordlayer.MaxSequenceNumber+1)
|
|
|
+
|
|
|
+ // Try to send handshake packet.
|
|
|
+ if werr := ca.writePackets(ctx, []*packet{
|
|
|
+ {
|
|
|
+ record: &recordlayer.RecordLayer{
|
|
|
+ Header: recordlayer.Header{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ },
|
|
|
+ Content: &handshake.Handshake{
|
|
|
+ Message: &handshake.MessageClientHello{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ Cookie: make([]byte, 64),
|
|
|
+ CipherSuiteIDs: cipherSuiteIDs(defaultCipherSuites()),
|
|
|
+ CompressionMethods: defaultCompressionMethods(),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }); !errors.Is(werr, errSequenceNumberOverflow) {
|
|
|
+ t.Errorf("Connection must fail on handshake packet reaches maximum sequence number")
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := ca.Close(); err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ if err := cb.Close(); err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func pipeMemory() (*Conn, *Conn, error) {
|
|
|
+ // In memory pipe
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ return pipeConn(ca, cb)
|
|
|
+}
|
|
|
+
|
|
|
+func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) {
|
|
|
+ type result struct {
|
|
|
+ c *Conn
|
|
|
+ err error
|
|
|
+ }
|
|
|
+
|
|
|
+ c := make(chan result)
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ // Setup client
|
|
|
+ go func() {
|
|
|
+ client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true)
|
|
|
+ c <- result{client, err}
|
|
|
+ }()
|
|
|
+
|
|
|
+ // Setup server
|
|
|
+ server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true)
|
|
|
+ if err != nil {
|
|
|
+ return nil, nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ // Receive client
|
|
|
+ res := <-c
|
|
|
+ if res.err != nil {
|
|
|
+ return nil, nil, res.err
|
|
|
+ }
|
|
|
+
|
|
|
+ return res.c, server, nil
|
|
|
+}
|
|
|
+
|
|
|
+func testClient(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) {
|
|
|
+ if generateCertificate {
|
|
|
+ clientCert, err := selfsign.GenerateSelfSigned()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ cfg.Certificates = []tls.Certificate{clientCert}
|
|
|
+ }
|
|
|
+ cfg.InsecureSkipVerify = true
|
|
|
+ return ClientWithContext(ctx, c, cfg)
|
|
|
+}
|
|
|
+
|
|
|
+func testServer(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) {
|
|
|
+ if generateCertificate {
|
|
|
+ serverCert, err := selfsign.GenerateSelfSigned()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ cfg.Certificates = []tls.Certificate{serverCert}
|
|
|
+ }
|
|
|
+ return ServerWithContext(ctx, c, cfg)
|
|
|
+}
|
|
|
+
|
|
|
+func TestHandshakeWithAlert(t *testing.T) {
|
|
|
+ // Limit runtime in case of deadlocks
|
|
|
+ lim := test.TimeOut(time.Second * 20)
|
|
|
+ defer lim.Stop()
|
|
|
+
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ cases := map[string]struct {
|
|
|
+ configServer, configClient *Config
|
|
|
+ errServer, errClient error
|
|
|
+ }{
|
|
|
+ "CipherSuiteNoIntersection": {
|
|
|
+ configServer: &Config{
|
|
|
+ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
|
|
+ },
|
|
|
+ configClient: &Config{
|
|
|
+ CipherSuites: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
|
|
|
+ },
|
|
|
+ errServer: errCipherSuiteNoIntersection,
|
|
|
+ errClient: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
|
|
+ },
|
|
|
+ "SignatureSchemesNoIntersection": {
|
|
|
+ configServer: &Config{
|
|
|
+ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
|
|
+ SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP256AndSHA256},
|
|
|
+ },
|
|
|
+ configClient: &Config{
|
|
|
+ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
|
|
+ SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512},
|
|
|
+ },
|
|
|
+ errServer: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
|
|
+ errClient: errNoAvailableSignatureSchemes,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for name, testCase := range cases {
|
|
|
+ testCase := testCase
|
|
|
+ t.Run(name, func(t *testing.T) {
|
|
|
+ clientErr := make(chan error, 1)
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ go func() {
|
|
|
+ _, err := testClient(ctx, ca, testCase.configClient, true)
|
|
|
+ clientErr <- err
|
|
|
+ }()
|
|
|
+
|
|
|
+ _, errServer := testServer(ctx, cb, testCase.configServer, true)
|
|
|
+ if !errors.Is(errServer, testCase.errServer) {
|
|
|
+ t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer)
|
|
|
+ }
|
|
|
+
|
|
|
+ errClient := <-clientErr
|
|
|
+ if !errors.Is(errClient, testCase.errClient) {
|
|
|
+ t.Fatalf("Client error exp(%v) failed(%v)", testCase.errClient, errClient)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestExportKeyingMaterial(t *testing.T) {
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ var rand [28]byte
|
|
|
+ exportLabel := "EXTRACTOR-dtls_srtp"
|
|
|
+
|
|
|
+ expectedServerKey := []byte{0x61, 0x09, 0x9d, 0x7d, 0xcb, 0x08, 0x52, 0x2c, 0xe7, 0x7b}
|
|
|
+ expectedClientKey := []byte{0x87, 0xf0, 0x40, 0x02, 0xf6, 0x1c, 0xf1, 0xfe, 0x8c, 0x77}
|
|
|
+
|
|
|
+ c := &Conn{
|
|
|
+ state: State{
|
|
|
+ localRandom: handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand},
|
|
|
+ remoteRandom: handshake.Random{GMTUnixTime: time.Unix(1000, 0), RandomBytes: rand},
|
|
|
+ localSequenceNumber: []uint64{0, 0},
|
|
|
+ cipherSuite: &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
|
|
|
+ },
|
|
|
+ }
|
|
|
+ c.setLocalEpoch(0)
|
|
|
+ c.setRemoteEpoch(0)
|
|
|
+
|
|
|
+ state := c.ConnectionState()
|
|
|
+ _, err := state.ExportKeyingMaterial(exportLabel, nil, 0)
|
|
|
+ if !errors.Is(err, errHandshakeInProgress) {
|
|
|
+ t.Errorf("ExportKeyingMaterial when epoch == 0: expected '%s' actual '%s'", errHandshakeInProgress, err)
|
|
|
+ }
|
|
|
+
|
|
|
+ c.setLocalEpoch(1)
|
|
|
+ state = c.ConnectionState()
|
|
|
+ _, err = state.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0)
|
|
|
+ if !errors.Is(err, errContextUnsupported) {
|
|
|
+ t.Errorf("ExportKeyingMaterial with context: expected '%s' actual '%s'", errContextUnsupported, err)
|
|
|
+ }
|
|
|
+
|
|
|
+ for k := range invalidKeyingLabels() {
|
|
|
+ state = c.ConnectionState()
|
|
|
+ _, err = state.ExportKeyingMaterial(k, nil, 0)
|
|
|
+ if !errors.Is(err, errReservedExportKeyingMaterial) {
|
|
|
+ t.Errorf("ExportKeyingMaterial reserved label: expected '%s' actual '%s'", errReservedExportKeyingMaterial, err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ state = c.ConnectionState()
|
|
|
+ keyingMaterial, err := state.ExportKeyingMaterial(exportLabel, nil, 10)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err)
|
|
|
+ } else if !bytes.Equal(keyingMaterial, expectedServerKey) {
|
|
|
+ t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedServerKey, keyingMaterial)
|
|
|
+ }
|
|
|
+
|
|
|
+ c.state.isClient = true
|
|
|
+ state = c.ConnectionState()
|
|
|
+ keyingMaterial, err = state.ExportKeyingMaterial(exportLabel, nil, 10)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err)
|
|
|
+ } else if !bytes.Equal(keyingMaterial, expectedClientKey) {
|
|
|
+ t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedClientKey, keyingMaterial)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestPSK(t *testing.T) {
|
|
|
+ // Limit runtime in case of deadlocks
|
|
|
+ lim := test.TimeOut(time.Second * 20)
|
|
|
+ defer lim.Stop()
|
|
|
+
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ for _, test := range []struct {
|
|
|
+ Name string
|
|
|
+ ServerIdentity []byte
|
|
|
+ CipherSuites []CipherSuiteID
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ Name: "Server identity specified",
|
|
|
+ ServerIdentity: []byte("Test Identity"),
|
|
|
+ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "Server identity nil",
|
|
|
+ ServerIdentity: nil,
|
|
|
+ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "TLS_PSK_WITH_AES_128_CBC_SHA256",
|
|
|
+ ServerIdentity: nil,
|
|
|
+ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CBC_SHA256},
|
|
|
+ },
|
|
|
+ } {
|
|
|
+ test := test
|
|
|
+ t.Run(test.Name, func(t *testing.T) {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ clientIdentity := []byte("Client Identity")
|
|
|
+ type result struct {
|
|
|
+ c *Conn
|
|
|
+ err error
|
|
|
+ }
|
|
|
+ clientRes := make(chan result, 1)
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ go func() {
|
|
|
+ conf := &Config{
|
|
|
+ PSK: func(hint []byte) ([]byte, error) {
|
|
|
+ if !bytes.Equal(test.ServerIdentity, hint) { // nolint
|
|
|
+ return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) // nolint
|
|
|
+ }
|
|
|
+
|
|
|
+ return []byte{0xAB, 0xC1, 0x23}, nil
|
|
|
+ },
|
|
|
+ PSKIdentityHint: clientIdentity,
|
|
|
+ CipherSuites: test.CipherSuites,
|
|
|
+ }
|
|
|
+
|
|
|
+ c, err := testClient(ctx, ca, conf, false)
|
|
|
+ clientRes <- result{c, err}
|
|
|
+ }()
|
|
|
+
|
|
|
+ config := &Config{
|
|
|
+ PSK: func(hint []byte) ([]byte, error) {
|
|
|
+ if !bytes.Equal(clientIdentity, hint) {
|
|
|
+ return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, clientIdentity, hint)
|
|
|
+ }
|
|
|
+ return []byte{0xAB, 0xC1, 0x23}, nil
|
|
|
+ },
|
|
|
+ PSKIdentityHint: test.ServerIdentity,
|
|
|
+ CipherSuites: test.CipherSuites,
|
|
|
+ }
|
|
|
+
|
|
|
+ server, err := testServer(ctx, cb, config, false)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("TestPSK: Server failed(%v)", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ actualPSKIdentityHint := server.ConnectionState().IdentityHint
|
|
|
+ if !bytes.Equal(actualPSKIdentityHint, clientIdentity) {
|
|
|
+ t.Errorf("TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", test.Name, clientIdentity, actualPSKIdentityHint)
|
|
|
+ }
|
|
|
+
|
|
|
+ defer func() {
|
|
|
+ _ = server.Close()
|
|
|
+ }()
|
|
|
+
|
|
|
+ res := <-clientRes
|
|
|
+ if res.err != nil {
|
|
|
+ t.Fatal(res.err)
|
|
|
+ }
|
|
|
+ _ = res.c.Close()
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestPSKHintFail(t *testing.T) {
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ serverAlertError := &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InternalError}}
|
|
|
+ pskRejected := errPSKRejected
|
|
|
+
|
|
|
+ // Limit runtime in case of deadlocks
|
|
|
+ lim := test.TimeOut(time.Second * 20)
|
|
|
+ defer lim.Stop()
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ clientErr := make(chan error, 1)
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ go func() {
|
|
|
+ conf := &Config{
|
|
|
+ PSK: func(hint []byte) ([]byte, error) {
|
|
|
+ return nil, pskRejected
|
|
|
+ },
|
|
|
+ PSKIdentityHint: []byte{},
|
|
|
+ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
|
|
|
+ }
|
|
|
+
|
|
|
+ _, err := testClient(ctx, ca, conf, false)
|
|
|
+ clientErr <- err
|
|
|
+ }()
|
|
|
+
|
|
|
+ config := &Config{
|
|
|
+ PSK: func(hint []byte) ([]byte, error) {
|
|
|
+ return nil, pskRejected
|
|
|
+ },
|
|
|
+ PSKIdentityHint: []byte{},
|
|
|
+ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
|
|
|
+ }
|
|
|
+
|
|
|
+ if _, err := testServer(ctx, cb, config, false); !errors.Is(err, serverAlertError) {
|
|
|
+ t.Fatalf("TestPSK: Server error exp(%v) failed(%v)", serverAlertError, err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := <-clientErr; !errors.Is(err, pskRejected) {
|
|
|
+ t.Fatalf("TestPSK: Client error exp(%v) failed(%v)", pskRejected, err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestClientTimeout(t *testing.T) {
|
|
|
+ // Limit runtime in case of deadlocks
|
|
|
+ lim := test.TimeOut(time.Second * 20)
|
|
|
+ defer lim.Stop()
|
|
|
+
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ clientErr := make(chan error, 1)
|
|
|
+
|
|
|
+ ca, _ := dpipe.Pipe()
|
|
|
+ go func() {
|
|
|
+ conf := &Config{}
|
|
|
+
|
|
|
+ c, err := testClient(ctx, ca, conf, true)
|
|
|
+ if err == nil {
|
|
|
+ _ = c.Close()
|
|
|
+ }
|
|
|
+ clientErr <- err
|
|
|
+ }()
|
|
|
+
|
|
|
+ // no server!
|
|
|
+ err := <-clientErr
|
|
|
+ if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
|
|
|
+ t.Fatalf("Client error exp(Temporary network error) failed(%v)", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestSRTPConfiguration(t *testing.T) {
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ for _, test := range []struct {
|
|
|
+ Name string
|
|
|
+ ClientSRTP []SRTPProtectionProfile
|
|
|
+ ServerSRTP []SRTPProtectionProfile
|
|
|
+ ExpectedProfile SRTPProtectionProfile
|
|
|
+ WantClientError error
|
|
|
+ WantServerError error
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ Name: "No SRTP in use",
|
|
|
+ ClientSRTP: nil,
|
|
|
+ ServerSRTP: nil,
|
|
|
+ ExpectedProfile: 0,
|
|
|
+ WantClientError: nil,
|
|
|
+ WantServerError: nil,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "SRTP both ends",
|
|
|
+ ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
|
|
|
+ ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
|
|
|
+ ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
|
|
|
+ WantClientError: nil,
|
|
|
+ WantServerError: nil,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "SRTP client only",
|
|
|
+ ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
|
|
|
+ ServerSRTP: nil,
|
|
|
+ ExpectedProfile: 0,
|
|
|
+ WantClientError: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
|
|
+ WantServerError: errServerNoMatchingSRTPProfile,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "SRTP server only",
|
|
|
+ ClientSRTP: nil,
|
|
|
+ ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
|
|
|
+ ExpectedProfile: 0,
|
|
|
+ WantClientError: nil,
|
|
|
+ WantServerError: nil,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "Multiple Suites",
|
|
|
+ ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32},
|
|
|
+ ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32},
|
|
|
+ ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
|
|
|
+ WantClientError: nil,
|
|
|
+ WantServerError: nil,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "Multiple Suites, Client Chooses",
|
|
|
+ ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32},
|
|
|
+ ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_32, SRTP_AES128_CM_HMAC_SHA1_80},
|
|
|
+ ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
|
|
|
+ WantClientError: nil,
|
|
|
+ WantServerError: nil,
|
|
|
+ },
|
|
|
+ } {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ type result struct {
|
|
|
+ c *Conn
|
|
|
+ err error
|
|
|
+ }
|
|
|
+ c := make(chan result)
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: test.ClientSRTP}, true)
|
|
|
+ c <- result{client, err}
|
|
|
+ }()
|
|
|
+
|
|
|
+ server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}, true)
|
|
|
+ if !errors.Is(err, test.WantServerError) {
|
|
|
+ t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
|
|
|
+ }
|
|
|
+ if err == nil {
|
|
|
+ defer func() {
|
|
|
+ _ = server.Close()
|
|
|
+ }()
|
|
|
+ }
|
|
|
+
|
|
|
+ res := <-c
|
|
|
+ if res.err == nil {
|
|
|
+ defer func() {
|
|
|
+ _ = res.c.Close()
|
|
|
+ }()
|
|
|
+ }
|
|
|
+ if !errors.Is(res.err, test.WantClientError) {
|
|
|
+ t.Fatalf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
|
|
|
+ }
|
|
|
+ if res.c == nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile()
|
|
|
+ if actualClientSRTP != test.ExpectedProfile {
|
|
|
+ t.Errorf("TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualClientSRTP)
|
|
|
+ }
|
|
|
+
|
|
|
+ actualServerSRTP, _ := server.SelectedSRTPProtectionProfile()
|
|
|
+ if actualServerSRTP != test.ExpectedProfile {
|
|
|
+ t.Errorf("TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualServerSRTP)
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestClientCertificate(t *testing.T) {
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ srvCert, err := selfsign.GenerateSelfSigned()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ srvCAPool := x509.NewCertPool()
|
|
|
+ srvCertificate, err := x509.ParseCertificate(srvCert.Certificate[0])
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ srvCAPool.AddCert(srvCertificate)
|
|
|
+
|
|
|
+ cert, err := selfsign.GenerateSelfSigned()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ certificate, err := x509.ParseCertificate(cert.Certificate[0])
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ caPool := x509.NewCertPool()
|
|
|
+ caPool.AddCert(certificate)
|
|
|
+
|
|
|
+ t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak
|
|
|
+ tests := map[string]struct {
|
|
|
+ clientCfg *Config
|
|
|
+ serverCfg *Config
|
|
|
+ wantErr bool
|
|
|
+ }{
|
|
|
+ "NoClientCert": {
|
|
|
+ clientCfg: &Config{RootCAs: srvCAPool},
|
|
|
+ serverCfg: &Config{
|
|
|
+ Certificates: []tls.Certificate{srvCert},
|
|
|
+ ClientAuth: NoClientCert,
|
|
|
+ ClientCAs: caPool,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ "NoClientCert_cert": {
|
|
|
+ clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
|
|
|
+ serverCfg: &Config{
|
|
|
+ Certificates: []tls.Certificate{srvCert},
|
|
|
+ ClientAuth: RequireAnyClientCert,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ "RequestClientCert_cert": {
|
|
|
+ clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
|
|
|
+ serverCfg: &Config{
|
|
|
+ Certificates: []tls.Certificate{srvCert},
|
|
|
+ ClientAuth: RequestClientCert,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ "RequestClientCert_no_cert": {
|
|
|
+ clientCfg: &Config{RootCAs: srvCAPool},
|
|
|
+ serverCfg: &Config{
|
|
|
+ Certificates: []tls.Certificate{srvCert},
|
|
|
+ ClientAuth: RequestClientCert,
|
|
|
+ ClientCAs: caPool,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ "RequireAnyClientCert": {
|
|
|
+ clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
|
|
|
+ serverCfg: &Config{
|
|
|
+ Certificates: []tls.Certificate{srvCert},
|
|
|
+ ClientAuth: RequireAnyClientCert,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ "RequireAnyClientCert_error": {
|
|
|
+ clientCfg: &Config{RootCAs: srvCAPool},
|
|
|
+ serverCfg: &Config{
|
|
|
+ Certificates: []tls.Certificate{srvCert},
|
|
|
+ ClientAuth: RequireAnyClientCert,
|
|
|
+ },
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ "VerifyClientCertIfGiven_no_cert": {
|
|
|
+ clientCfg: &Config{RootCAs: srvCAPool},
|
|
|
+ serverCfg: &Config{
|
|
|
+ Certificates: []tls.Certificate{srvCert},
|
|
|
+ ClientAuth: VerifyClientCertIfGiven,
|
|
|
+ ClientCAs: caPool,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ "VerifyClientCertIfGiven_cert": {
|
|
|
+ clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
|
|
|
+ serverCfg: &Config{
|
|
|
+ Certificates: []tls.Certificate{srvCert},
|
|
|
+ ClientAuth: VerifyClientCertIfGiven,
|
|
|
+ ClientCAs: caPool,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ "VerifyClientCertIfGiven_error": {
|
|
|
+ clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
|
|
|
+ serverCfg: &Config{
|
|
|
+ Certificates: []tls.Certificate{srvCert},
|
|
|
+ ClientAuth: VerifyClientCertIfGiven,
|
|
|
+ },
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ "RequireAndVerifyClientCert": {
|
|
|
+ clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
|
|
|
+ serverCfg: &Config{
|
|
|
+ Certificates: []tls.Certificate{srvCert},
|
|
|
+ ClientAuth: RequireAndVerifyClientCert,
|
|
|
+ ClientCAs: caPool,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for name, tt := range tests {
|
|
|
+ tt := tt
|
|
|
+ t.Run(name, func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ type result struct {
|
|
|
+ c *Conn
|
|
|
+ err error
|
|
|
+ }
|
|
|
+ c := make(chan result)
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ client, err := Client(ca, tt.clientCfg)
|
|
|
+ c <- result{client, err}
|
|
|
+ }()
|
|
|
+
|
|
|
+ server, err := Server(cb, tt.serverCfg)
|
|
|
+ res := <-c
|
|
|
+ defer func() {
|
|
|
+ if err == nil {
|
|
|
+ _ = server.Close()
|
|
|
+ }
|
|
|
+ if res.err == nil {
|
|
|
+ _ = res.c.Close()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ if tt.wantErr {
|
|
|
+ if err != nil {
|
|
|
+ // Error expected, test succeeded
|
|
|
+ return
|
|
|
+ }
|
|
|
+ t.Error("Error expected")
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("Server failed(%v)", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if res.err != nil {
|
|
|
+ t.Errorf("Client failed(%v)", res.err)
|
|
|
+ }
|
|
|
+
|
|
|
+ actualClientCert := server.ConnectionState().PeerCertificates
|
|
|
+ if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert {
|
|
|
+ if actualClientCert == nil {
|
|
|
+ t.Errorf("Client did not provide a certificate")
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(actualClientCert) != len(tt.clientCfg.Certificates[0].Certificate) || !bytes.Equal(tt.clientCfg.Certificates[0].Certificate[0], actualClientCert[0]) {
|
|
|
+ t.Errorf("Client certificate was not communicated correctly")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if tt.serverCfg.ClientAuth == NoClientCert {
|
|
|
+ if actualClientCert != nil {
|
|
|
+ t.Errorf("Client certificate wasn't expected")
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ actualServerCert := res.c.ConnectionState().PeerCertificates
|
|
|
+ if actualServerCert == nil {
|
|
|
+ t.Errorf("Server did not provide a certificate")
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(actualServerCert) != len(tt.serverCfg.Certificates[0].Certificate) || !bytes.Equal(tt.serverCfg.Certificates[0].Certificate[0], actualServerCert[0]) {
|
|
|
+ t.Errorf("Server certificate was not communicated correctly")
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestExtendedMasterSecret(t *testing.T) {
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ tests := map[string]struct {
|
|
|
+ clientCfg *Config
|
|
|
+ serverCfg *Config
|
|
|
+ expectedClientErr error
|
|
|
+ expectedServerErr error
|
|
|
+ }{
|
|
|
+ "Request_Request_ExtendedMasterSecret": {
|
|
|
+ clientCfg: &Config{
|
|
|
+ ExtendedMasterSecret: RequestExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ serverCfg: &Config{
|
|
|
+ ExtendedMasterSecret: RequestExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ expectedClientErr: nil,
|
|
|
+ expectedServerErr: nil,
|
|
|
+ },
|
|
|
+ "Request_Require_ExtendedMasterSecret": {
|
|
|
+ clientCfg: &Config{
|
|
|
+ ExtendedMasterSecret: RequestExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ serverCfg: &Config{
|
|
|
+ ExtendedMasterSecret: RequireExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ expectedClientErr: nil,
|
|
|
+ expectedServerErr: nil,
|
|
|
+ },
|
|
|
+ "Request_Disable_ExtendedMasterSecret": {
|
|
|
+ clientCfg: &Config{
|
|
|
+ ExtendedMasterSecret: RequestExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ serverCfg: &Config{
|
|
|
+ ExtendedMasterSecret: DisableExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ expectedClientErr: nil,
|
|
|
+ expectedServerErr: nil,
|
|
|
+ },
|
|
|
+ "Require_Request_ExtendedMasterSecret": {
|
|
|
+ clientCfg: &Config{
|
|
|
+ ExtendedMasterSecret: RequireExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ serverCfg: &Config{
|
|
|
+ ExtendedMasterSecret: RequestExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ expectedClientErr: nil,
|
|
|
+ expectedServerErr: nil,
|
|
|
+ },
|
|
|
+ "Require_Require_ExtendedMasterSecret": {
|
|
|
+ clientCfg: &Config{
|
|
|
+ ExtendedMasterSecret: RequireExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ serverCfg: &Config{
|
|
|
+ ExtendedMasterSecret: RequireExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ expectedClientErr: nil,
|
|
|
+ expectedServerErr: nil,
|
|
|
+ },
|
|
|
+ "Require_Disable_ExtendedMasterSecret": {
|
|
|
+ clientCfg: &Config{
|
|
|
+ ExtendedMasterSecret: RequireExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ serverCfg: &Config{
|
|
|
+ ExtendedMasterSecret: DisableExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ expectedClientErr: errClientRequiredButNoServerEMS,
|
|
|
+ expectedServerErr: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
|
|
+ },
|
|
|
+ "Disable_Request_ExtendedMasterSecret": {
|
|
|
+ clientCfg: &Config{
|
|
|
+ ExtendedMasterSecret: DisableExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ serverCfg: &Config{
|
|
|
+ ExtendedMasterSecret: RequestExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ expectedClientErr: nil,
|
|
|
+ expectedServerErr: nil,
|
|
|
+ },
|
|
|
+ "Disable_Require_ExtendedMasterSecret": {
|
|
|
+ clientCfg: &Config{
|
|
|
+ ExtendedMasterSecret: DisableExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ serverCfg: &Config{
|
|
|
+ ExtendedMasterSecret: RequireExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ expectedClientErr: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
|
|
+ expectedServerErr: errServerRequiredButNoClientEMS,
|
|
|
+ },
|
|
|
+ "Disable_Disable_ExtendedMasterSecret": {
|
|
|
+ clientCfg: &Config{
|
|
|
+ ExtendedMasterSecret: DisableExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ serverCfg: &Config{
|
|
|
+ ExtendedMasterSecret: DisableExtendedMasterSecret,
|
|
|
+ },
|
|
|
+ expectedClientErr: nil,
|
|
|
+ expectedServerErr: nil,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for name, tt := range tests {
|
|
|
+ tt := tt
|
|
|
+ t.Run(name, func(t *testing.T) {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ type result struct {
|
|
|
+ c *Conn
|
|
|
+ err error
|
|
|
+ }
|
|
|
+ c := make(chan result)
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ client, err := testClient(ctx, ca, tt.clientCfg, true)
|
|
|
+ c <- result{client, err}
|
|
|
+ }()
|
|
|
+
|
|
|
+ server, err := testServer(ctx, cb, tt.serverCfg, true)
|
|
|
+ res := <-c
|
|
|
+ defer func() {
|
|
|
+ if err == nil {
|
|
|
+ _ = server.Close()
|
|
|
+ }
|
|
|
+ if res.err == nil {
|
|
|
+ _ = res.c.Close()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ if !errors.Is(res.err, tt.expectedClientErr) {
|
|
|
+ t.Errorf("Client error expected: \"%v\" but got \"%v\"", tt.expectedClientErr, res.err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if !errors.Is(err, tt.expectedServerErr) {
|
|
|
+ t.Errorf("Server error expected: \"%v\" but got \"%v\"", tt.expectedServerErr, err)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestServerCertificate(t *testing.T) {
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ cert, err := selfsign.GenerateSelfSigned()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ certificate, err := x509.ParseCertificate(cert.Certificate[0])
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ caPool := x509.NewCertPool()
|
|
|
+ caPool.AddCert(certificate)
|
|
|
+
|
|
|
+ t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak
|
|
|
+ tests := map[string]struct {
|
|
|
+ clientCfg *Config
|
|
|
+ serverCfg *Config
|
|
|
+ wantErr bool
|
|
|
+ }{
|
|
|
+ "no_ca": {
|
|
|
+ clientCfg: &Config{},
|
|
|
+ serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ "good_ca": {
|
|
|
+ clientCfg: &Config{RootCAs: caPool},
|
|
|
+ serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
|
|
|
+ },
|
|
|
+ "no_ca_skip_verify": {
|
|
|
+ clientCfg: &Config{InsecureSkipVerify: true},
|
|
|
+ serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
|
|
|
+ },
|
|
|
+ "good_ca_skip_verify_custom_verify_peer": {
|
|
|
+ clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
|
|
|
+ serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: RequireAnyClientCert, VerifyPeerCertificate: func(cert [][]byte, chain [][]*x509.Certificate) error {
|
|
|
+ if len(chain) != 0 {
|
|
|
+ return errNotExpectedChain
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }},
|
|
|
+ },
|
|
|
+ "good_ca_verify_custom_verify_peer": {
|
|
|
+ clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
|
|
|
+ serverCfg: &Config{ClientCAs: caPool, Certificates: []tls.Certificate{cert}, ClientAuth: RequireAndVerifyClientCert, VerifyPeerCertificate: func(cert [][]byte, chain [][]*x509.Certificate) error {
|
|
|
+ if len(chain) == 0 {
|
|
|
+ return errExpecedChain
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }},
|
|
|
+ },
|
|
|
+ "good_ca_custom_verify_peer": {
|
|
|
+ clientCfg: &Config{
|
|
|
+ RootCAs: caPool,
|
|
|
+ VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
|
|
|
+ return errWrongCert
|
|
|
+ },
|
|
|
+ },
|
|
|
+ serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ "server_name": {
|
|
|
+ clientCfg: &Config{RootCAs: caPool, ServerName: certificate.Subject.CommonName},
|
|
|
+ serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
|
|
|
+ },
|
|
|
+ "server_name_error": {
|
|
|
+ clientCfg: &Config{RootCAs: caPool, ServerName: "barfoo"},
|
|
|
+ serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for name, tt := range tests {
|
|
|
+ tt := tt
|
|
|
+ t.Run(name, func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+
|
|
|
+ type result struct {
|
|
|
+ c *Conn
|
|
|
+ err error
|
|
|
+ }
|
|
|
+ srvCh := make(chan result)
|
|
|
+ go func() {
|
|
|
+ s, err := Server(cb, tt.serverCfg)
|
|
|
+ srvCh <- result{s, err}
|
|
|
+ }()
|
|
|
+
|
|
|
+ cli, err := Client(ca, tt.clientCfg)
|
|
|
+ if err == nil {
|
|
|
+ _ = cli.Close()
|
|
|
+ }
|
|
|
+ if !tt.wantErr && err != nil {
|
|
|
+ t.Errorf("Client failed(%v)", err)
|
|
|
+ }
|
|
|
+ if tt.wantErr && err == nil {
|
|
|
+ t.Fatal("Error expected")
|
|
|
+ }
|
|
|
+
|
|
|
+ srv := <-srvCh
|
|
|
+ if srv.err == nil {
|
|
|
+ _ = srv.c.Close()
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestCipherSuiteConfiguration(t *testing.T) {
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ for _, test := range []struct {
|
|
|
+ Name string
|
|
|
+ ClientCipherSuites []CipherSuiteID
|
|
|
+ ServerCipherSuites []CipherSuiteID
|
|
|
+ WantClientError error
|
|
|
+ WantServerError error
|
|
|
+ WantSelectedCipherSuite CipherSuiteID
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ Name: "No CipherSuites specified",
|
|
|
+ ClientCipherSuites: nil,
|
|
|
+ ServerCipherSuites: nil,
|
|
|
+ WantClientError: nil,
|
|
|
+ WantServerError: nil,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "Invalid CipherSuite",
|
|
|
+ ClientCipherSuites: []CipherSuiteID{0x00},
|
|
|
+ ServerCipherSuites: []CipherSuiteID{0x00},
|
|
|
+ WantClientError: &invalidCipherSuite{0x00},
|
|
|
+ WantServerError: &invalidCipherSuite{0x00},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "Valid CipherSuites specified",
|
|
|
+ ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
|
|
+ ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
|
|
+ WantClientError: nil,
|
|
|
+ WantServerError: nil,
|
|
|
+ WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "CipherSuites mismatch",
|
|
|
+ ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
|
|
+ ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
|
|
|
+ WantClientError: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
|
|
+ WantServerError: errCipherSuiteNoIntersection,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "Valid CipherSuites CCM specified",
|
|
|
+ ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM},
|
|
|
+ ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM},
|
|
|
+ WantClientError: nil,
|
|
|
+ WantServerError: nil,
|
|
|
+ WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "Valid CipherSuites CCM-8 specified",
|
|
|
+ ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8},
|
|
|
+ ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8},
|
|
|
+ WantClientError: nil,
|
|
|
+ WantServerError: nil,
|
|
|
+ WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "Server supports subset of client suites",
|
|
|
+ ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
|
|
|
+ ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
|
|
|
+ WantClientError: nil,
|
|
|
+ WantServerError: nil,
|
|
|
+ WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
|
|
+ },
|
|
|
+ } {
|
|
|
+ test := test
|
|
|
+ t.Run(test.Name, func(t *testing.T) {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ type result struct {
|
|
|
+ c *Conn
|
|
|
+ err error
|
|
|
+ }
|
|
|
+ c := make(chan result)
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ client, err := testClient(ctx, ca, &Config{CipherSuites: test.ClientCipherSuites}, true)
|
|
|
+ c <- result{client, err}
|
|
|
+ }()
|
|
|
+
|
|
|
+ server, err := testServer(ctx, cb, &Config{CipherSuites: test.ServerCipherSuites}, true)
|
|
|
+ if err == nil {
|
|
|
+ defer func() {
|
|
|
+ _ = server.Close()
|
|
|
+ }()
|
|
|
+ }
|
|
|
+ if !errors.Is(err, test.WantServerError) {
|
|
|
+ t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
|
|
|
+ }
|
|
|
+
|
|
|
+ res := <-c
|
|
|
+ if res.err == nil {
|
|
|
+ _ = server.Close()
|
|
|
+ }
|
|
|
+ if !errors.Is(res.err, test.WantClientError) {
|
|
|
+ t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
|
|
|
+ }
|
|
|
+ if test.WantSelectedCipherSuite != 0x00 && res.c.state.cipherSuite.ID() != test.WantSelectedCipherSuite {
|
|
|
+ t.Errorf("TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s': expected(%v) actual(%v)", test.Name, test.WantSelectedCipherSuite, res.c.state.cipherSuite.ID())
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestCertificateAndPSKServer(t *testing.T) {
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ for _, test := range []struct {
|
|
|
+ Name string
|
|
|
+ ClientPSK bool
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ Name: "Client uses PKI",
|
|
|
+ ClientPSK: false,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "Client uses PSK",
|
|
|
+ ClientPSK: true,
|
|
|
+ },
|
|
|
+ } {
|
|
|
+ test := test
|
|
|
+ t.Run(test.Name, func(t *testing.T) {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ type result struct {
|
|
|
+ c *Conn
|
|
|
+ err error
|
|
|
+ }
|
|
|
+ c := make(chan result)
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ config := &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}}
|
|
|
+ if test.ClientPSK {
|
|
|
+ config.PSK = func([]byte) ([]byte, error) {
|
|
|
+ return []byte{0x00, 0x01, 0x02}, nil
|
|
|
+ }
|
|
|
+ config.PSKIdentityHint = []byte{0x00}
|
|
|
+ config.CipherSuites = []CipherSuiteID{TLS_PSK_WITH_AES_128_GCM_SHA256}
|
|
|
+ }
|
|
|
+
|
|
|
+ client, err := testClient(ctx, ca, config, false)
|
|
|
+ c <- result{client, err}
|
|
|
+ }()
|
|
|
+
|
|
|
+ config := &Config{
|
|
|
+ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_PSK_WITH_AES_128_GCM_SHA256},
|
|
|
+ PSK: func([]byte) ([]byte, error) {
|
|
|
+ return []byte{0x00, 0x01, 0x02}, nil
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ server, err := testServer(ctx, cb, config, true)
|
|
|
+ if err == nil {
|
|
|
+ defer func() {
|
|
|
+ _ = server.Close()
|
|
|
+ }()
|
|
|
+ } else {
|
|
|
+ t.Errorf("TestCertificateAndPSKServer: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, err)
|
|
|
+ }
|
|
|
+
|
|
|
+ res := <-c
|
|
|
+ if res.err == nil {
|
|
|
+ _ = server.Close()
|
|
|
+ } else {
|
|
|
+ t.Errorf("TestCertificateAndPSKServer: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, res.err)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestPSKConfiguration(t *testing.T) {
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ for _, test := range []struct {
|
|
|
+ Name string
|
|
|
+ ClientHasCertificate bool
|
|
|
+ ServerHasCertificate bool
|
|
|
+ ClientPSK PSKCallback
|
|
|
+ ServerPSK PSKCallback
|
|
|
+ ClientPSKIdentity []byte
|
|
|
+ ServerPSKIdentity []byte
|
|
|
+ WantClientError error
|
|
|
+ WantServerError error
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ Name: "PSK and no certificate specified",
|
|
|
+ ClientHasCertificate: false,
|
|
|
+ ServerHasCertificate: false,
|
|
|
+ ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
|
|
|
+ ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
|
|
|
+ ClientPSKIdentity: []byte{0x00},
|
|
|
+ ServerPSKIdentity: []byte{0x00},
|
|
|
+ WantClientError: errNoAvailablePSKCipherSuite,
|
|
|
+ WantServerError: errNoAvailablePSKCipherSuite,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "PSK and certificate specified",
|
|
|
+ ClientHasCertificate: true,
|
|
|
+ ServerHasCertificate: true,
|
|
|
+ ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
|
|
|
+ ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
|
|
|
+ ClientPSKIdentity: []byte{0x00},
|
|
|
+ ServerPSKIdentity: []byte{0x00},
|
|
|
+ WantClientError: errNoAvailablePSKCipherSuite,
|
|
|
+ WantServerError: errNoAvailablePSKCipherSuite,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "PSK and no identity specified",
|
|
|
+ ClientHasCertificate: false,
|
|
|
+ ServerHasCertificate: false,
|
|
|
+ ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
|
|
|
+ ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
|
|
|
+ ClientPSKIdentity: nil,
|
|
|
+ ServerPSKIdentity: nil,
|
|
|
+ WantClientError: errPSKAndIdentityMustBeSetForClient,
|
|
|
+ WantServerError: errNoAvailablePSKCipherSuite,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "No PSK and identity specified",
|
|
|
+ ClientHasCertificate: false,
|
|
|
+ ServerHasCertificate: false,
|
|
|
+ ClientPSK: nil,
|
|
|
+ ServerPSK: nil,
|
|
|
+ ClientPSKIdentity: []byte{0x00},
|
|
|
+ ServerPSKIdentity: []byte{0x00},
|
|
|
+ WantClientError: errIdentityNoPSK,
|
|
|
+ WantServerError: errIdentityNoPSK,
|
|
|
+ },
|
|
|
+ } {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ type result struct {
|
|
|
+ c *Conn
|
|
|
+ err error
|
|
|
+ }
|
|
|
+ c := make(chan result)
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ client, err := testClient(ctx, ca, &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, test.ClientHasCertificate)
|
|
|
+ c <- result{client, err}
|
|
|
+ }()
|
|
|
+
|
|
|
+ _, err := testServer(ctx, cb, &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, test.ServerHasCertificate)
|
|
|
+ if err != nil || test.WantServerError != nil {
|
|
|
+ if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) {
|
|
|
+ t.Fatalf("TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ res := <-c
|
|
|
+ if res.err != nil || test.WantClientError != nil {
|
|
|
+ if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) {
|
|
|
+ t.Fatalf("TestPSKConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestServerTimeout(t *testing.T) {
|
|
|
+ // Limit runtime in case of deadlocks
|
|
|
+ lim := test.TimeOut(time.Second * 20)
|
|
|
+ defer lim.Stop()
|
|
|
+
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ cookie := make([]byte, 20)
|
|
|
+ _, err := rand.Read(cookie)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ var rand [28]byte
|
|
|
+ random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}
|
|
|
+
|
|
|
+ cipherSuites := []CipherSuite{
|
|
|
+ &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
|
|
|
+ &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{},
|
|
|
+ }
|
|
|
+
|
|
|
+ extensions := []extension.Extension{
|
|
|
+ &extension.SupportedSignatureAlgorithms{
|
|
|
+ SignatureHashAlgorithms: []signaturehash.Algorithm{
|
|
|
+ {Hash: hash.SHA256, Signature: signature.ECDSA},
|
|
|
+ {Hash: hash.SHA384, Signature: signature.ECDSA},
|
|
|
+ {Hash: hash.SHA512, Signature: signature.ECDSA},
|
|
|
+ {Hash: hash.SHA256, Signature: signature.RSA},
|
|
|
+ {Hash: hash.SHA384, Signature: signature.RSA},
|
|
|
+ {Hash: hash.SHA512, Signature: signature.RSA},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ &extension.SupportedEllipticCurves{
|
|
|
+ EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
|
|
|
+ },
|
|
|
+ &extension.SupportedPointFormats{
|
|
|
+ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ record := &recordlayer.RecordLayer{
|
|
|
+ Header: recordlayer.Header{
|
|
|
+ SequenceNumber: 0,
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ },
|
|
|
+ Content: &handshake.Handshake{
|
|
|
+ // sequenceNumber and messageSequence line up, may need to be re-evaluated
|
|
|
+ Header: handshake.Header{
|
|
|
+ MessageSequence: 0,
|
|
|
+ },
|
|
|
+ Message: &handshake.MessageClientHello{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ Cookie: cookie,
|
|
|
+ Random: random,
|
|
|
+ CipherSuiteIDs: cipherSuiteIDs(cipherSuites),
|
|
|
+ CompressionMethods: defaultCompressionMethods(),
|
|
|
+ Extensions: extensions,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ packet, err := record.Marshal()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ defer func() {
|
|
|
+ err := ca.Close()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // Client reader
|
|
|
+ caReadChan := make(chan []byte, 1000)
|
|
|
+ go func() {
|
|
|
+ for {
|
|
|
+ data := make([]byte, 8192)
|
|
|
+ n, err := ca.Read(data)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ caReadChan <- data[:n]
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // Start sending ClientHello packets until server responds with first packet
|
|
|
+ go func() {
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-time.After(10 * time.Millisecond):
|
|
|
+ _, err := ca.Write(packet)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ case <-caReadChan:
|
|
|
+ // Once we receive the first reply from the server, stop
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ config := &Config{
|
|
|
+ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
|
|
+ FlightInterval: 100 * time.Millisecond,
|
|
|
+ }
|
|
|
+
|
|
|
+ _, serverErr := testServer(ctx, cb, config, true)
|
|
|
+ if netErr, ok := serverErr.(net.Error); !ok || !netErr.Timeout() {
|
|
|
+ t.Fatalf("Client error exp(Temporary network error) failed(%v)", serverErr)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Wait a little longer to ensure no additional messages have been sent by the server
|
|
|
+ time.Sleep(300 * time.Millisecond)
|
|
|
+ select {
|
|
|
+ case msg := <-caReadChan:
|
|
|
+ t.Fatalf("Expected no additional messages from server, got: %+v", msg)
|
|
|
+ default:
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestProtocolVersionValidation(t *testing.T) {
|
|
|
+ // Limit runtime in case of deadlocks
|
|
|
+ lim := test.TimeOut(time.Second * 20)
|
|
|
+ defer lim.Stop()
|
|
|
+
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ cookie := make([]byte, 20)
|
|
|
+ if _, err := rand.Read(cookie); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ var rand [28]byte
|
|
|
+ random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}
|
|
|
+
|
|
|
+ localKeypair, err := elliptic.GenerateKeypair(elliptic.X25519)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ config := &Config{
|
|
|
+ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
|
|
+ FlightInterval: 100 * time.Millisecond,
|
|
|
+ }
|
|
|
+
|
|
|
+ t.Run("Server", func(t *testing.T) {
|
|
|
+ serverCases := map[string]struct {
|
|
|
+ records []*recordlayer.RecordLayer
|
|
|
+ }{
|
|
|
+ "ClientHelloVersion": {
|
|
|
+ records: []*recordlayer.RecordLayer{
|
|
|
+ {
|
|
|
+ Header: recordlayer.Header{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ },
|
|
|
+ Content: &handshake.Handshake{
|
|
|
+ Message: &handshake.MessageClientHello{
|
|
|
+ Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
|
|
|
+ Cookie: cookie,
|
|
|
+ Random: random,
|
|
|
+ CipherSuiteIDs: []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
|
|
|
+ CompressionMethods: defaultCompressionMethods(),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ "SecondsClientHelloVersion": {
|
|
|
+ records: []*recordlayer.RecordLayer{
|
|
|
+ {
|
|
|
+ Header: recordlayer.Header{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ },
|
|
|
+ Content: &handshake.Handshake{
|
|
|
+ Message: &handshake.MessageClientHello{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ Cookie: cookie,
|
|
|
+ Random: random,
|
|
|
+ CipherSuiteIDs: []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
|
|
|
+ CompressionMethods: defaultCompressionMethods(),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Header: recordlayer.Header{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ SequenceNumber: 1,
|
|
|
+ },
|
|
|
+ Content: &handshake.Handshake{
|
|
|
+ Header: handshake.Header{
|
|
|
+ MessageSequence: 1,
|
|
|
+ },
|
|
|
+ Message: &handshake.MessageClientHello{
|
|
|
+ Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
|
|
|
+ Cookie: cookie,
|
|
|
+ Random: random,
|
|
|
+ CipherSuiteIDs: []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
|
|
|
+ CompressionMethods: defaultCompressionMethods(),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for name, c := range serverCases {
|
|
|
+ c := c
|
|
|
+ t.Run(name, func(t *testing.T) {
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ defer func() {
|
|
|
+ err := ca.Close()
|
|
|
+ if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ wg.Add(1)
|
|
|
+ defer wg.Wait()
|
|
|
+ go func() {
|
|
|
+ defer wg.Done()
|
|
|
+ if _, err := testServer(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
|
|
|
+ t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ time.Sleep(50 * time.Millisecond)
|
|
|
+
|
|
|
+ resp := make([]byte, 1024)
|
|
|
+ for _, record := range c.records {
|
|
|
+ packet, err := record.Marshal()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ if _, werr := ca.Write(packet); werr != nil {
|
|
|
+ t.Fatal(werr)
|
|
|
+ }
|
|
|
+ n, rerr := ca.Read(resp[:cap(resp)])
|
|
|
+ if rerr != nil {
|
|
|
+ t.Fatal(rerr)
|
|
|
+ }
|
|
|
+ resp = resp[:n]
|
|
|
+ }
|
|
|
+
|
|
|
+ h := &recordlayer.Header{}
|
|
|
+ if err := h.Unmarshal(resp); err != nil {
|
|
|
+ t.Fatal("Failed to unmarshal response")
|
|
|
+ }
|
|
|
+ if h.ContentType != protocol.ContentTypeAlert {
|
|
|
+ t.Errorf("Peer must return alert to unsupported protocol version")
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ t.Run("Client", func(t *testing.T) {
|
|
|
+ clientCases := map[string]struct {
|
|
|
+ records []*recordlayer.RecordLayer
|
|
|
+ }{
|
|
|
+ "ServerHelloVersion": {
|
|
|
+ records: []*recordlayer.RecordLayer{
|
|
|
+ {
|
|
|
+ Header: recordlayer.Header{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ },
|
|
|
+ Content: &handshake.Handshake{
|
|
|
+ Message: &handshake.MessageHelloVerifyRequest{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ Cookie: cookie,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Header: recordlayer.Header{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ SequenceNumber: 1,
|
|
|
+ },
|
|
|
+ Content: &handshake.Handshake{
|
|
|
+ Header: handshake.Header{
|
|
|
+ MessageSequence: 1,
|
|
|
+ },
|
|
|
+ Message: &handshake.MessageServerHello{
|
|
|
+ Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
|
|
|
+ Random: random,
|
|
|
+ CipherSuiteID: func() *uint16 { id := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256); return &id }(),
|
|
|
+ CompressionMethod: defaultCompressionMethods()[0],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Header: recordlayer.Header{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ SequenceNumber: 2,
|
|
|
+ },
|
|
|
+ Content: &handshake.Handshake{
|
|
|
+ Header: handshake.Header{
|
|
|
+ MessageSequence: 2,
|
|
|
+ },
|
|
|
+ Message: &handshake.MessageCertificate{},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Header: recordlayer.Header{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ SequenceNumber: 3,
|
|
|
+ },
|
|
|
+ Content: &handshake.Handshake{
|
|
|
+ Header: handshake.Header{
|
|
|
+ MessageSequence: 3,
|
|
|
+ },
|
|
|
+ Message: &handshake.MessageServerKeyExchange{
|
|
|
+ EllipticCurveType: elliptic.CurveTypeNamedCurve,
|
|
|
+ NamedCurve: elliptic.X25519,
|
|
|
+ PublicKey: localKeypair.PublicKey,
|
|
|
+ HashAlgorithm: hash.SHA256,
|
|
|
+ SignatureAlgorithm: signature.ECDSA,
|
|
|
+ Signature: make([]byte, 64),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Header: recordlayer.Header{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ SequenceNumber: 4,
|
|
|
+ },
|
|
|
+ Content: &handshake.Handshake{
|
|
|
+ Header: handshake.Header{
|
|
|
+ MessageSequence: 4,
|
|
|
+ },
|
|
|
+ Message: &handshake.MessageServerHelloDone{},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for name, c := range clientCases {
|
|
|
+ c := c
|
|
|
+ t.Run(name, func(t *testing.T) {
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ defer func() {
|
|
|
+ err := ca.Close()
|
|
|
+ if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ wg.Add(1)
|
|
|
+ defer wg.Wait()
|
|
|
+ go func() {
|
|
|
+ defer wg.Done()
|
|
|
+ if _, err := testClient(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
|
|
|
+ t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ time.Sleep(50 * time.Millisecond)
|
|
|
+
|
|
|
+ for _, record := range c.records {
|
|
|
+ if _, err := ca.Read(make([]byte, 1024)); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ packet, err := record.Marshal()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ if _, err := ca.Write(packet); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ resp := make([]byte, 1024)
|
|
|
+ n, err := ca.Read(resp)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ resp = resp[:n]
|
|
|
+
|
|
|
+ h := &recordlayer.Header{}
|
|
|
+ if err := h.Unmarshal(resp); err != nil {
|
|
|
+ t.Fatal("Failed to unmarshal response")
|
|
|
+ }
|
|
|
+ if h.ContentType != protocol.ContentTypeAlert {
|
|
|
+ t.Errorf("Peer must return alert to unsupported protocol version")
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestMultipleHelloVerifyRequest(t *testing.T) {
|
|
|
+ // Limit runtime in case of deadlocks
|
|
|
+ lim := test.TimeOut(time.Second * 20)
|
|
|
+ defer lim.Stop()
|
|
|
+
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ cookies := [][]byte{
|
|
|
+ // first clientHello contains an empty cookie
|
|
|
+ {},
|
|
|
+ }
|
|
|
+ var packets [][]byte
|
|
|
+ for i := 0; i < 2; i++ {
|
|
|
+ cookie := make([]byte, 20)
|
|
|
+ if _, err := rand.Read(cookie); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ cookies = append(cookies, cookie)
|
|
|
+
|
|
|
+ record := &recordlayer.RecordLayer{
|
|
|
+ Header: recordlayer.Header{
|
|
|
+ SequenceNumber: uint64(i),
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ },
|
|
|
+ Content: &handshake.Handshake{
|
|
|
+ Header: handshake.Header{
|
|
|
+ MessageSequence: uint16(i),
|
|
|
+ },
|
|
|
+ Message: &handshake.MessageHelloVerifyRequest{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ Cookie: cookie,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+ packet, err := record.Marshal()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ packets = append(packets, packet)
|
|
|
+ }
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ defer func() {
|
|
|
+ err := ca.Close()
|
|
|
+ if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ wg.Add(1)
|
|
|
+ defer wg.Wait()
|
|
|
+ go func() {
|
|
|
+ defer wg.Done()
|
|
|
+ _, _ = testClient(ctx, ca, &Config{}, false)
|
|
|
+ }()
|
|
|
+
|
|
|
+ for i, cookie := range cookies {
|
|
|
+ // read client hello
|
|
|
+ resp := make([]byte, 1024)
|
|
|
+ n, err := cb.Read(resp)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ record := &recordlayer.RecordLayer{}
|
|
|
+ if err := record.Unmarshal(resp[:n]); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ clientHello := record.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello)
|
|
|
+ if !bytes.Equal(clientHello.Cookie, cookie) {
|
|
|
+ t.Fatalf("Wrong cookie, expected: %x, got: %x", clientHello.Cookie, cookie)
|
|
|
+ }
|
|
|
+ if len(packets) <= i {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ // write hello verify request
|
|
|
+ if _, err := cb.Write(packets[i]); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ cancel()
|
|
|
+}
|
|
|
+
|
|
|
+// Assert that a DTLS Server always responds with RenegotiationInfo if
|
|
|
+// a ClientHello contained that extension or not
|
|
|
+func TestRenegotationInfo(t *testing.T) {
|
|
|
+ // Limit runtime in case of deadlocks
|
|
|
+ lim := test.TimeOut(10 * time.Second)
|
|
|
+ defer lim.Stop()
|
|
|
+
|
|
|
+ // Check for leaking routines
|
|
|
+ report := test.CheckRoutines(t)
|
|
|
+ defer report()
|
|
|
+
|
|
|
+ resp := make([]byte, 1024)
|
|
|
+
|
|
|
+ for _, testCase := range []struct {
|
|
|
+ Name string
|
|
|
+ SendRenegotiationInfo bool
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ "Include RenegotiationInfo",
|
|
|
+ true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "No RenegotiationInfo",
|
|
|
+ false,
|
|
|
+ },
|
|
|
+ } {
|
|
|
+ test := testCase
|
|
|
+ t.Run(test.Name, func(t *testing.T) {
|
|
|
+ sendClientHello := func(cookie []byte, ca net.Conn, sequenceNumber uint64) {
|
|
|
+ extensions := []extension.Extension{}
|
|
|
+ if test.SendRenegotiationInfo {
|
|
|
+ extensions = append(extensions, &extension.RenegotiationInfo{
|
|
|
+ RenegotiatedConnection: 0,
|
|
|
+ })
|
|
|
+ }
|
|
|
+
|
|
|
+ packet, err := (&recordlayer.RecordLayer{
|
|
|
+ Header: recordlayer.Header{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ SequenceNumber: sequenceNumber,
|
|
|
+ },
|
|
|
+ Content: &handshake.Handshake{
|
|
|
+ Header: handshake.Header{
|
|
|
+ MessageSequence: uint16(sequenceNumber),
|
|
|
+ },
|
|
|
+ Message: &handshake.MessageClientHello{
|
|
|
+ Version: protocol.Version1_2,
|
|
|
+ Cookie: cookie,
|
|
|
+ CipherSuiteIDs: cipherSuiteIDs(defaultCipherSuites()),
|
|
|
+ CompressionMethods: defaultCompressionMethods(),
|
|
|
+ Extensions: extensions,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }).Marshal()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if _, err = ca.Write(packet); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ ca, cb := dpipe.Pipe()
|
|
|
+ defer func() {
|
|
|
+ if err := ca.Close(); err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ ctx, cancel := context.WithCancel(context.Background())
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ if _, err := testServer(ctx, cb, &Config{}, true); !errors.Is(err, context.Canceled) {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ time.Sleep(50 * time.Millisecond)
|
|
|
+
|
|
|
+ sendClientHello([]byte{}, ca, 0)
|
|
|
+ n, err := ca.Read(resp)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ r := &recordlayer.RecordLayer{}
|
|
|
+ if err = r.Unmarshal(resp[:n]); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ helloVerifyRequest := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest)
|
|
|
+
|
|
|
+ sendClientHello(helloVerifyRequest.Cookie, ca, 1)
|
|
|
+ if n, err = ca.Read(resp); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ messages, err := recordlayer.UnpackDatagram(resp[:n])
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := r.Unmarshal(messages[0]); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ serverHello := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
|
|
|
+ gotNegotationInfo := false
|
|
|
+ for _, v := range serverHello.Extensions {
|
|
|
+ if _, ok := v.(*extension.RenegotiationInfo); ok {
|
|
|
+ gotNegotationInfo = true
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if !gotNegotationInfo {
|
|
|
+ t.Fatalf("Received ServerHello without RenegotiationInfo")
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|