| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415 | // Copyright 2013 The Go Authors. All rights reserved.// Use of this source code is governed by a BSD-style// license that can be found in the LICENSE file.package sshimport (	"bytes"	"crypto/rand"	"errors"	"fmt"	"net"	"runtime"	"strings"	"sync"	"testing")type testChecker struct {	calls []string}func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {	if dialAddr == "bad" {		return fmt.Errorf("dialAddr is bad")	}	if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)	}	t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))	return nil}// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and// therefore is buffered (net.Pipe deadlocks if both sides start with// a write.)func netPipe() (net.Conn, net.Conn, error) {	listener, err := net.Listen("tcp", "127.0.0.1:0")	if err != nil {		return nil, nil, err	}	defer listener.Close()	c1, err := net.Dial("tcp", listener.Addr().String())	if err != nil {		return nil, nil, err	}	c2, err := listener.Accept()	if err != nil {		c1.Close()		return nil, nil, err	}	return c1, c2, nil}func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {	a, b, err := netPipe()	if err != nil {		return nil, nil, err	}	trC := newTransport(a, rand.Reader, true)	trS := newTransport(b, rand.Reader, false)	clientConf.SetDefaults()	v := []byte("version")	client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())	serverConf := &ServerConfig{}	serverConf.AddHostKey(testSigners["ecdsa"])	serverConf.AddHostKey(testSigners["rsa"])	serverConf.SetDefaults()	server = newServerTransport(trS, v, v, serverConf)	return client, server, nil}func TestHandshakeBasic(t *testing.T) {	if runtime.GOOS == "plan9" {		t.Skip("see golang.org/issue/7237")	}	checker := &testChecker{}	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")	if err != nil {		t.Fatalf("handshakePair: %v", err)	}	defer trC.Close()	defer trS.Close()	go func() {		// Client writes a bunch of stuff, and does a key		// change in the middle. This should not confuse the		// handshake in progress		for i := 0; i < 10; i++ {			p := []byte{msgRequestSuccess, byte(i)}			if err := trC.writePacket(p); err != nil {				t.Fatalf("sendPacket: %v", err)			}			if i == 5 {				// halfway through, we request a key change.				_, _, err := trC.sendKexInit()				if err != nil {					t.Fatalf("sendKexInit: %v", err)				}			}		}		trC.Close()	}()	// Server checks that client messages come in cleanly	i := 0	for {		p, err := trS.readPacket()		if err != nil {			break		}		if p[0] == msgNewKeys {			continue		}		want := []byte{msgRequestSuccess, byte(i)}		if bytes.Compare(p, want) != 0 {			t.Errorf("message %d: got %q, want %q", i, p, want)		}		i++	}	if i != 10 {		t.Errorf("received %d messages, want 10.", i)	}	// If all went well, we registered exactly 1 key change.	if len(checker.calls) != 1 {		t.Fatalf("got %d host key checks, want 1", len(checker.calls))	}	pub := testSigners["ecdsa"].PublicKey()	want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())	if want != checker.calls[0] {		t.Errorf("got %q want %q for host key check", checker.calls[0], want)	}}func TestHandshakeError(t *testing.T) {	checker := &testChecker{}	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")	if err != nil {		t.Fatalf("handshakePair: %v", err)	}	defer trC.Close()	defer trS.Close()	// send a packet	packet := []byte{msgRequestSuccess, 42}	if err := trC.writePacket(packet); err != nil {		t.Errorf("writePacket: %v", err)	}	// Now request a key change.	_, _, err = trC.sendKexInit()	if err != nil {		t.Errorf("sendKexInit: %v", err)	}	// the key change will fail, and afterwards we can't write.	if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {		t.Errorf("writePacket after botched rekey succeeded.")	}	readback, err := trS.readPacket()	if err != nil {		t.Fatalf("server closed too soon: %v", err)	}	if bytes.Compare(readback, packet) != 0 {		t.Errorf("got %q want %q", readback, packet)	}	readback, err = trS.readPacket()	if err == nil {		t.Errorf("got a message %q after failed key change", readback)	}}func TestHandshakeTwice(t *testing.T) {	checker := &testChecker{}	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")	if err != nil {		t.Fatalf("handshakePair: %v", err)	}	defer trC.Close()	defer trS.Close()	// send a packet	packet := make([]byte, 5)	packet[0] = msgRequestSuccess	if err := trC.writePacket(packet); err != nil {		t.Errorf("writePacket: %v", err)	}	// Now request a key change.	_, _, err = trC.sendKexInit()	if err != nil {		t.Errorf("sendKexInit: %v", err)	}	// Send another packet. Use a fresh one, since writePacket destroys.	packet = make([]byte, 5)	packet[0] = msgRequestSuccess	if err := trC.writePacket(packet); err != nil {		t.Errorf("writePacket: %v", err)	}	// 2nd key change.	_, _, err = trC.sendKexInit()	if err != nil {		t.Errorf("sendKexInit: %v", err)	}	packet = make([]byte, 5)	packet[0] = msgRequestSuccess	if err := trC.writePacket(packet); err != nil {		t.Errorf("writePacket: %v", err)	}	packet = make([]byte, 5)	packet[0] = msgRequestSuccess	for i := 0; i < 5; i++ {		msg, err := trS.readPacket()		if err != nil {			t.Fatalf("server closed too soon: %v", err)		}		if msg[0] == msgNewKeys {			continue		}		if bytes.Compare(msg, packet) != 0 {			t.Errorf("packet %d: got %q want %q", i, msg, packet)		}	}	if len(checker.calls) != 2 {		t.Errorf("got %d key changes, want 2", len(checker.calls))	}}func TestHandshakeAutoRekeyWrite(t *testing.T) {	checker := &testChecker{}	clientConf := &ClientConfig{HostKeyCallback: checker.Check}	clientConf.RekeyThreshold = 500	trC, trS, err := handshakePair(clientConf, "addr")	if err != nil {		t.Fatalf("handshakePair: %v", err)	}	defer trC.Close()	defer trS.Close()	for i := 0; i < 5; i++ {		packet := make([]byte, 251)		packet[0] = msgRequestSuccess		if err := trC.writePacket(packet); err != nil {			t.Errorf("writePacket: %v", err)		}	}	j := 0	for ; j < 5; j++ {		_, err := trS.readPacket()		if err != nil {			break		}	}	if j != 5 {		t.Errorf("got %d, want 5 messages", j)	}	if len(checker.calls) != 2 {		t.Errorf("got %d key changes, wanted 2", len(checker.calls))	}}type syncChecker struct {	called chan int}func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {	t.called <- 1	return nil}func TestHandshakeAutoRekeyRead(t *testing.T) {	sync := &syncChecker{make(chan int, 2)}	clientConf := &ClientConfig{		HostKeyCallback: sync.Check,	}	clientConf.RekeyThreshold = 500	trC, trS, err := handshakePair(clientConf, "addr")	if err != nil {		t.Fatalf("handshakePair: %v", err)	}	defer trC.Close()	defer trS.Close()	packet := make([]byte, 501)	packet[0] = msgRequestSuccess	if err := trS.writePacket(packet); err != nil {		t.Fatalf("writePacket: %v", err)	}	// While we read out the packet, a key change will be	// initiated.	if _, err := trC.readPacket(); err != nil {		t.Fatalf("readPacket(client): %v", err)	}	<-sync.called}// errorKeyingTransport generates errors after a given number of// read/write operations.type errorKeyingTransport struct {	packetConn	readLeft, writeLeft int}func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {	return nil}func (n *errorKeyingTransport) getSessionID() []byte {	return nil}func (n *errorKeyingTransport) writePacket(packet []byte) error {	if n.writeLeft == 0 {		n.Close()		return errors.New("barf")	}	n.writeLeft--	return n.packetConn.writePacket(packet)}func (n *errorKeyingTransport) readPacket() ([]byte, error) {	if n.readLeft == 0 {		n.Close()		return nil, errors.New("barf")	}	n.readLeft--	return n.packetConn.readPacket()}func TestHandshakeErrorHandlingRead(t *testing.T) {	for i := 0; i < 20; i++ {		testHandshakeErrorHandlingN(t, i, -1)	}}func TestHandshakeErrorHandlingWrite(t *testing.T) {	for i := 0; i < 20; i++ {		testHandshakeErrorHandlingN(t, -1, i)	}}// testHandshakeErrorHandlingN runs handshakes, injecting errors. If// handshakeTransport deadlocks, the go runtime will detect it and// panic.func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})	a, b := memPipe()	defer a.Close()	defer b.Close()	key := testSigners["ecdsa"]	serverConf := Config{RekeyThreshold: minRekeyThreshold}	serverConf.SetDefaults()	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})	serverConn.hostKeys = []Signer{key}	go serverConn.readLoop()	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}	clientConf.SetDefaults()	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}	go clientConn.readLoop()	var wg sync.WaitGroup	wg.Add(4)	for _, hs := range []packetConn{serverConn, clientConn} {		go func(c packetConn) {			for {				err := c.writePacket(msg)				if err != nil {					break				}			}			wg.Done()		}(hs)		go func(c packetConn) {			for {				_, err := c.readPacket()				if err != nil {					break				}			}			wg.Done()		}(hs)	}	wg.Wait()}
 |