|  | @@ -19,6 +19,11 @@ import (
 | 
	
		
			
				|  |  |  // messages are wrong when using ECDH.
 | 
	
		
			
				|  |  |  const debugHandshake = false
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +// chanSize sets the amount of buffering SSH connections. This is
 | 
	
		
			
				|  |  | +// primarily for testing: setting chanSize=0 uncovers deadlocks more
 | 
	
		
			
				|  |  | +// quickly.
 | 
	
		
			
				|  |  | +const chanSize = 16
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  // keyingTransport is a packet based transport that supports key
 | 
	
		
			
				|  |  |  // changes. It need not be thread-safe. It should pass through
 | 
	
		
			
				|  |  |  // msgNewKeys in both directions.
 | 
	
	
		
			
				|  | @@ -53,34 +58,58 @@ type handshakeTransport struct {
 | 
	
		
			
				|  |  |  	incoming  chan []byte
 | 
	
		
			
				|  |  |  	readError error
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	mu             sync.Mutex
 | 
	
		
			
				|  |  | +	writeError     error
 | 
	
		
			
				|  |  | +	sentInitPacket []byte
 | 
	
		
			
				|  |  | +	sentInitMsg    *kexInitMsg
 | 
	
		
			
				|  |  | +	pendingPackets [][]byte // Used when a key exchange is in progress.
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	// If the read loop wants to schedule a kex, it pings this
 | 
	
		
			
				|  |  | +	// channel, and the write loop will send out a kex
 | 
	
		
			
				|  |  | +	// message.
 | 
	
		
			
				|  |  | +	requestKex chan struct{}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	// If the other side requests or confirms a kex, its kexInit
 | 
	
		
			
				|  |  | +	// packet is sent here for the write loop to find it.
 | 
	
		
			
				|  |  | +	startKex chan *pendingKex
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	// data for host key checking
 | 
	
		
			
				|  |  |  	hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
 | 
	
		
			
				|  |  |  	dialAddress     string
 | 
	
		
			
				|  |  |  	remoteAddr      net.Addr
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	readSinceKex uint64
 | 
	
		
			
				|  |  | +	// Algorithms agreed in the last key exchange.
 | 
	
		
			
				|  |  | +	algorithms *algorithms
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	readPacketsLeft uint32
 | 
	
		
			
				|  |  | +	readBytesLeft   int64
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	// Protects the writing side of the connection
 | 
	
		
			
				|  |  | -	mu              sync.Mutex
 | 
	
		
			
				|  |  | -	cond            *sync.Cond
 | 
	
		
			
				|  |  | -	sentInitPacket  []byte
 | 
	
		
			
				|  |  | -	sentInitMsg     *kexInitMsg
 | 
	
		
			
				|  |  | -	writtenSinceKex uint64
 | 
	
		
			
				|  |  | -	writeError      error
 | 
	
		
			
				|  |  | +	writePacketsLeft uint32
 | 
	
		
			
				|  |  | +	writeBytesLeft   int64
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	// The session ID or nil if first kex did not complete yet.
 | 
	
		
			
				|  |  |  	sessionID []byte
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +type pendingKex struct {
 | 
	
		
			
				|  |  | +	otherInit []byte
 | 
	
		
			
				|  |  | +	done      chan error
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
 | 
	
		
			
				|  |  |  	t := &handshakeTransport{
 | 
	
		
			
				|  |  |  		conn:          conn,
 | 
	
		
			
				|  |  |  		serverVersion: serverVersion,
 | 
	
		
			
				|  |  |  		clientVersion: clientVersion,
 | 
	
		
			
				|  |  | -		incoming:      make(chan []byte, 16),
 | 
	
		
			
				|  |  | -		config:        config,
 | 
	
		
			
				|  |  | +		incoming:      make(chan []byte, chanSize),
 | 
	
		
			
				|  |  | +		requestKex:    make(chan struct{}, 1),
 | 
	
		
			
				|  |  | +		startKex:      make(chan *pendingKex, 1),
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		config: config,
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	t.cond = sync.NewCond(&t.mu)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	// We always start with a mandatory key exchange.
 | 
	
		
			
				|  |  | +	t.requestKex <- struct{}{}
 | 
	
		
			
				|  |  |  	return t
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -95,6 +124,7 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
 | 
	
		
			
				|  |  |  		t.hostKeyAlgorithms = supportedHostKeyAlgos
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  	go t.readLoop()
 | 
	
		
			
				|  |  | +	go t.kexLoop()
 | 
	
		
			
				|  |  |  	return t
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -102,6 +132,7 @@ func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byt
 | 
	
		
			
				|  |  |  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
 | 
	
		
			
				|  |  |  	t.hostKeys = config.hostKeys
 | 
	
		
			
				|  |  |  	go t.readLoop()
 | 
	
		
			
				|  |  | +	go t.kexLoop()
 | 
	
		
			
				|  |  |  	return t
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -109,6 +140,20 @@ func (t *handshakeTransport) getSessionID() []byte {
 | 
	
		
			
				|  |  |  	return t.sessionID
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +// waitSession waits for the session to be established. This should be
 | 
	
		
			
				|  |  | +// the first thing to call after instantiating handshakeTransport.
 | 
	
		
			
				|  |  | +func (t *handshakeTransport) waitSession() error {
 | 
	
		
			
				|  |  | +	p, err := t.readPacket()
 | 
	
		
			
				|  |  | +	if err != nil {
 | 
	
		
			
				|  |  | +		return err
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	if p[0] != msgNewKeys {
 | 
	
		
			
				|  |  | +		return fmt.Errorf("ssh: first packet should be msgNewKeys")
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	return nil
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  func (t *handshakeTransport) id() string {
 | 
	
		
			
				|  |  |  	if len(t.hostKeys) > 0 {
 | 
	
		
			
				|  |  |  		return "server"
 | 
	
	
		
			
				|  | @@ -116,6 +161,20 @@ func (t *handshakeTransport) id() string {
 | 
	
		
			
				|  |  |  	return "client"
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +func (t *handshakeTransport) printPacket(p []byte, write bool) {
 | 
	
		
			
				|  |  | +	action := "got"
 | 
	
		
			
				|  |  | +	if write {
 | 
	
		
			
				|  |  | +		action = "sent"
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
 | 
	
		
			
				|  |  | +		log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
 | 
	
		
			
				|  |  | +	} else {
 | 
	
		
			
				|  |  | +		msg, err := decode(p)
 | 
	
		
			
				|  |  | +		log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  func (t *handshakeTransport) readPacket() ([]byte, error) {
 | 
	
		
			
				|  |  |  	p, ok := <-t.incoming
 | 
	
		
			
				|  |  |  	if !ok {
 | 
	
	
		
			
				|  | @@ -125,8 +184,10 @@ func (t *handshakeTransport) readPacket() ([]byte, error) {
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  func (t *handshakeTransport) readLoop() {
 | 
	
		
			
				|  |  | +	first := true
 | 
	
		
			
				|  |  |  	for {
 | 
	
		
			
				|  |  | -		p, err := t.readOnePacket()
 | 
	
		
			
				|  |  | +		p, err := t.readOnePacket(first)
 | 
	
		
			
				|  |  | +		first = false
 | 
	
		
			
				|  |  |  		if err != nil {
 | 
	
		
			
				|  |  |  			t.readError = err
 | 
	
		
			
				|  |  |  			close(t.incoming)
 | 
	
	
		
			
				|  | @@ -138,67 +199,204 @@ func (t *handshakeTransport) readLoop() {
 | 
	
		
			
				|  |  |  		t.incoming <- p
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	// If we can't read, declare the writing part dead too.
 | 
	
		
			
				|  |  | +	// Stop writers too.
 | 
	
		
			
				|  |  | +	t.recordWriteError(t.readError)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	// Unblock the writer should it wait for this.
 | 
	
		
			
				|  |  | +	close(t.startKex)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	// Don't close t.requestKex; it's also written to from writePacket.
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func (t *handshakeTransport) pushPacket(p []byte) error {
 | 
	
		
			
				|  |  | +	if debugHandshake {
 | 
	
		
			
				|  |  | +		t.printPacket(p, true)
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	return t.conn.writePacket(p)
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func (t *handshakeTransport) getWriteError() error {
 | 
	
		
			
				|  |  |  	t.mu.Lock()
 | 
	
		
			
				|  |  |  	defer t.mu.Unlock()
 | 
	
		
			
				|  |  | -	if t.writeError == nil {
 | 
	
		
			
				|  |  | -		t.writeError = t.readError
 | 
	
		
			
				|  |  | +	return t.writeError
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func (t *handshakeTransport) recordWriteError(err error) {
 | 
	
		
			
				|  |  | +	t.mu.Lock()
 | 
	
		
			
				|  |  | +	defer t.mu.Unlock()
 | 
	
		
			
				|  |  | +	if t.writeError == nil && err != nil {
 | 
	
		
			
				|  |  | +		t.writeError = err
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	t.cond.Broadcast()
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -func (t *handshakeTransport) readOnePacket() ([]byte, error) {
 | 
	
		
			
				|  |  | -	if t.readSinceKex > t.config.RekeyThreshold {
 | 
	
		
			
				|  |  | -		if err := t.requestKeyChange(); err != nil {
 | 
	
		
			
				|  |  | -			return nil, err
 | 
	
		
			
				|  |  | +func (t *handshakeTransport) requestKeyExchange() {
 | 
	
		
			
				|  |  | +	select {
 | 
	
		
			
				|  |  | +	case t.requestKex <- struct{}{}:
 | 
	
		
			
				|  |  | +	default:
 | 
	
		
			
				|  |  | +		// something already requested a kex, so do nothing.
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func (t *handshakeTransport) kexLoop() {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +write:
 | 
	
		
			
				|  |  | +	for t.getWriteError() == nil {
 | 
	
		
			
				|  |  | +		var request *pendingKex
 | 
	
		
			
				|  |  | +		var sent bool
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		for request == nil || !sent {
 | 
	
		
			
				|  |  | +			var ok bool
 | 
	
		
			
				|  |  | +			select {
 | 
	
		
			
				|  |  | +			case request, ok = <-t.startKex:
 | 
	
		
			
				|  |  | +				if !ok {
 | 
	
		
			
				|  |  | +					break write
 | 
	
		
			
				|  |  | +				}
 | 
	
		
			
				|  |  | +			case <-t.requestKex:
 | 
	
		
			
				|  |  | +				break
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +			if !sent {
 | 
	
		
			
				|  |  | +				if err := t.sendKexInit(); err != nil {
 | 
	
		
			
				|  |  | +					t.recordWriteError(err)
 | 
	
		
			
				|  |  | +					break
 | 
	
		
			
				|  |  | +				}
 | 
	
		
			
				|  |  | +				sent = true
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		if err := t.getWriteError(); err != nil {
 | 
	
		
			
				|  |  | +			if request != nil {
 | 
	
		
			
				|  |  | +				request.done <- err
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +			break
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		// We're not servicing t.requestKex, but that is OK:
 | 
	
		
			
				|  |  | +		// we never block on sending to t.requestKex.
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		// We're not servicing t.startKex, but the remote end
 | 
	
		
			
				|  |  | +		// has just sent us a kexInitMsg, so it can't send
 | 
	
		
			
				|  |  | +		// another key change request, until we close the done
 | 
	
		
			
				|  |  | +		// channel on the pendingKex request.
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		err := t.enterKeyExchange(request.otherInit)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		t.mu.Lock()
 | 
	
		
			
				|  |  | +		t.writeError = err
 | 
	
		
			
				|  |  | +		t.sentInitPacket = nil
 | 
	
		
			
				|  |  | +		t.sentInitMsg = nil
 | 
	
		
			
				|  |  | +		t.writePacketsLeft = packetRekeyThreshold
 | 
	
		
			
				|  |  | +		if t.config.RekeyThreshold > 0 {
 | 
	
		
			
				|  |  | +			t.writeBytesLeft = int64(t.config.RekeyThreshold)
 | 
	
		
			
				|  |  | +		} else if t.algorithms != nil {
 | 
	
		
			
				|  |  | +			t.writeBytesLeft = t.algorithms.w.rekeyBytes()
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		// we have completed the key exchange. Since the
 | 
	
		
			
				|  |  | +		// reader is still blocked, it is safe to clear out
 | 
	
		
			
				|  |  | +		// the requestKex channel. This avoids the situation
 | 
	
		
			
				|  |  | +		// where: 1) we consumed our own request for the
 | 
	
		
			
				|  |  | +		// initial kex, and 2) the kex from the remote side
 | 
	
		
			
				|  |  | +		// caused another send on the requestKex channel,
 | 
	
		
			
				|  |  | +	clear:
 | 
	
		
			
				|  |  | +		for {
 | 
	
		
			
				|  |  | +			select {
 | 
	
		
			
				|  |  | +			case <-t.requestKex:
 | 
	
		
			
				|  |  | +				//
 | 
	
		
			
				|  |  | +			default:
 | 
	
		
			
				|  |  | +				break clear
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		request.done <- t.writeError
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		// kex finished. Push packets that we received while
 | 
	
		
			
				|  |  | +		// the kex was in progress. Don't look at t.startKex
 | 
	
		
			
				|  |  | +		// and don't increment writtenSinceKex: if we trigger
 | 
	
		
			
				|  |  | +		// another kex while we are still busy with the last
 | 
	
		
			
				|  |  | +		// one, things will become very confusing.
 | 
	
		
			
				|  |  | +		for _, p := range t.pendingPackets {
 | 
	
		
			
				|  |  | +			t.writeError = t.pushPacket(p)
 | 
	
		
			
				|  |  | +			if t.writeError != nil {
 | 
	
		
			
				|  |  | +				break
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  | +		t.pendingPackets = t.pendingPackets[:0]
 | 
	
		
			
				|  |  | +		t.mu.Unlock()
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	// drain startKex channel. We don't service t.requestKex
 | 
	
		
			
				|  |  | +	// because nobody does blocking sends there.
 | 
	
		
			
				|  |  | +	go func() {
 | 
	
		
			
				|  |  | +		for init := range t.startKex {
 | 
	
		
			
				|  |  | +			init.done <- t.writeError
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +	}()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	// Unblock reader.
 | 
	
		
			
				|  |  | +	t.conn.Close()
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +// The protocol uses uint32 for packet counters, so we can't let them
 | 
	
		
			
				|  |  | +// reach 1<<32.  We will actually read and write more packets than
 | 
	
		
			
				|  |  | +// this, though: the other side may send more packets, and after we
 | 
	
		
			
				|  |  | +// hit this limit on writing we will send a few more packets for the
 | 
	
		
			
				|  |  | +// key exchange itself.
 | 
	
		
			
				|  |  | +const packetRekeyThreshold = (1 << 31)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
 | 
	
		
			
				|  |  |  	p, err := t.conn.readPacket()
 | 
	
		
			
				|  |  |  	if err != nil {
 | 
	
		
			
				|  |  |  		return nil, err
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	t.readSinceKex += uint64(len(p))
 | 
	
		
			
				|  |  | +	if t.readPacketsLeft > 0 {
 | 
	
		
			
				|  |  | +		t.readPacketsLeft--
 | 
	
		
			
				|  |  | +	} else {
 | 
	
		
			
				|  |  | +		t.requestKeyExchange()
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	if t.readBytesLeft > 0 {
 | 
	
		
			
				|  |  | +		t.readBytesLeft -= int64(len(p))
 | 
	
		
			
				|  |  | +	} else {
 | 
	
		
			
				|  |  | +		t.requestKeyExchange()
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	if debugHandshake {
 | 
	
		
			
				|  |  | -		if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
 | 
	
		
			
				|  |  | -			log.Printf("%s got data (packet %d bytes)", t.id(), len(p))
 | 
	
		
			
				|  |  | -		} else {
 | 
	
		
			
				|  |  | -			msg, err := decode(p)
 | 
	
		
			
				|  |  | -			log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
 | 
	
		
			
				|  |  | -		}
 | 
	
		
			
				|  |  | +		t.printPacket(p, false)
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	if first && p[0] != msgKexInit {
 | 
	
		
			
				|  |  | +		return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	if p[0] != msgKexInit {
 | 
	
		
			
				|  |  |  		return p, nil
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	t.mu.Lock()
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |  	firstKex := t.sessionID == nil
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	err = t.enterKeyExchangeLocked(p)
 | 
	
		
			
				|  |  | -	if err != nil {
 | 
	
		
			
				|  |  | -		// drop connection
 | 
	
		
			
				|  |  | -		t.conn.Close()
 | 
	
		
			
				|  |  | -		t.writeError = err
 | 
	
		
			
				|  |  | +	kex := pendingKex{
 | 
	
		
			
				|  |  | +		done:      make(chan error, 1),
 | 
	
		
			
				|  |  | +		otherInit: p,
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | +	t.startKex <- &kex
 | 
	
		
			
				|  |  | +	err = <-kex.done
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	if debugHandshake {
 | 
	
		
			
				|  |  |  		log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	// Unblock writers.
 | 
	
		
			
				|  |  | -	t.sentInitMsg = nil
 | 
	
		
			
				|  |  | -	t.sentInitPacket = nil
 | 
	
		
			
				|  |  | -	t.cond.Broadcast()
 | 
	
		
			
				|  |  | -	t.writtenSinceKex = 0
 | 
	
		
			
				|  |  | -	t.mu.Unlock()
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |  	if err != nil {
 | 
	
		
			
				|  |  |  		return nil, err
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	t.readSinceKex = 0
 | 
	
		
			
				|  |  | +	t.readPacketsLeft = packetRekeyThreshold
 | 
	
		
			
				|  |  | +	if t.config.RekeyThreshold > 0 {
 | 
	
		
			
				|  |  | +		t.readBytesLeft = int64(t.config.RekeyThreshold)
 | 
	
		
			
				|  |  | +	} else {
 | 
	
		
			
				|  |  | +		t.readBytesLeft = t.algorithms.r.rekeyBytes()
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	// By default, a key exchange is hidden from higher layers by
 | 
	
		
			
				|  |  |  	// translating it into msgIgnore.
 | 
	
	
		
			
				|  | @@ -213,61 +411,16 @@ func (t *handshakeTransport) readOnePacket() ([]byte, error) {
 | 
	
		
			
				|  |  |  	return successPacket, nil
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -// keyChangeCategory describes whether a key exchange is the first on a
 | 
	
		
			
				|  |  | -// connection, or a subsequent one.
 | 
	
		
			
				|  |  | -type keyChangeCategory bool
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -const (
 | 
	
		
			
				|  |  | -	firstKeyExchange      keyChangeCategory = true
 | 
	
		
			
				|  |  | -	subsequentKeyExchange keyChangeCategory = false
 | 
	
		
			
				|  |  | -)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -// sendKexInit sends a key change message, and returns the message
 | 
	
		
			
				|  |  | -// that was sent. After initiating the key change, all writes will be
 | 
	
		
			
				|  |  | -// blocked until the change is done, and a failed key change will
 | 
	
		
			
				|  |  | -// close the underlying transport. This function is safe for
 | 
	
		
			
				|  |  | -// concurrent use by multiple goroutines.
 | 
	
		
			
				|  |  | -func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) error {
 | 
	
		
			
				|  |  | -	var err error
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +// sendKexInit sends a key change message.
 | 
	
		
			
				|  |  | +func (t *handshakeTransport) sendKexInit() error {
 | 
	
		
			
				|  |  |  	t.mu.Lock()
 | 
	
		
			
				|  |  | -	// If this is the initial key change, but we already have a sessionID,
 | 
	
		
			
				|  |  | -	// then do nothing because the key exchange has already completed
 | 
	
		
			
				|  |  | -	// asynchronously.
 | 
	
		
			
				|  |  | -	if !isFirst || t.sessionID == nil {
 | 
	
		
			
				|  |  | -		_, _, err = t.sendKexInitLocked(isFirst)
 | 
	
		
			
				|  |  | -	}
 | 
	
		
			
				|  |  | -	t.mu.Unlock()
 | 
	
		
			
				|  |  | -	if err != nil {
 | 
	
		
			
				|  |  | -		return err
 | 
	
		
			
				|  |  | -	}
 | 
	
		
			
				|  |  | -	if isFirst {
 | 
	
		
			
				|  |  | -		if packet, err := t.readPacket(); err != nil {
 | 
	
		
			
				|  |  | -			return err
 | 
	
		
			
				|  |  | -		} else if packet[0] != msgNewKeys {
 | 
	
		
			
				|  |  | -			return unexpectedMessageError(msgNewKeys, packet[0])
 | 
	
		
			
				|  |  | -		}
 | 
	
		
			
				|  |  | -	}
 | 
	
		
			
				|  |  | -	return nil
 | 
	
		
			
				|  |  | -}
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -func (t *handshakeTransport) requestInitialKeyChange() error {
 | 
	
		
			
				|  |  | -	return t.sendKexInit(firstKeyExchange)
 | 
	
		
			
				|  |  | -}
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -func (t *handshakeTransport) requestKeyChange() error {
 | 
	
		
			
				|  |  | -	return t.sendKexInit(subsequentKeyExchange)
 | 
	
		
			
				|  |  | -}
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -// sendKexInitLocked sends a key change message. t.mu must be locked
 | 
	
		
			
				|  |  | -// while this happens.
 | 
	
		
			
				|  |  | -func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) {
 | 
	
		
			
				|  |  | -	// kexInits may be sent either in response to the other side,
 | 
	
		
			
				|  |  | -	// or because our side wants to initiate a key change, so we
 | 
	
		
			
				|  |  | -	// may have already sent a kexInit. In that case, don't send a
 | 
	
		
			
				|  |  | -	// second kexInit.
 | 
	
		
			
				|  |  | +	defer t.mu.Unlock()
 | 
	
		
			
				|  |  |  	if t.sentInitMsg != nil {
 | 
	
		
			
				|  |  | -		return t.sentInitMsg, t.sentInitPacket, nil
 | 
	
		
			
				|  |  | +		// kexInits may be sent either in response to the other side,
 | 
	
		
			
				|  |  | +		// or because our side wants to initiate a key change, so we
 | 
	
		
			
				|  |  | +		// may have already sent a kexInit. In that case, don't send a
 | 
	
		
			
				|  |  | +		// second kexInit.
 | 
	
		
			
				|  |  | +		return nil
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	msg := &kexInitMsg{
 | 
	
	
		
			
				|  | @@ -295,53 +448,65 @@ func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexI
 | 
	
		
			
				|  |  |  	packetCopy := make([]byte, len(packet))
 | 
	
		
			
				|  |  |  	copy(packetCopy, packet)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	if err := t.conn.writePacket(packetCopy); err != nil {
 | 
	
		
			
				|  |  | -		return nil, nil, err
 | 
	
		
			
				|  |  | +	if err := t.pushPacket(packetCopy); err != nil {
 | 
	
		
			
				|  |  | +		return err
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	t.sentInitMsg = msg
 | 
	
		
			
				|  |  |  	t.sentInitPacket = packet
 | 
	
		
			
				|  |  | -	return msg, packet, nil
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	return nil
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  func (t *handshakeTransport) writePacket(p []byte) error {
 | 
	
		
			
				|  |  | +	switch p[0] {
 | 
	
		
			
				|  |  | +	case msgKexInit:
 | 
	
		
			
				|  |  | +		return errors.New("ssh: only handshakeTransport can send kexInit")
 | 
	
		
			
				|  |  | +	case msgNewKeys:
 | 
	
		
			
				|  |  | +		return errors.New("ssh: only handshakeTransport can send newKeys")
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	t.mu.Lock()
 | 
	
		
			
				|  |  |  	defer t.mu.Unlock()
 | 
	
		
			
				|  |  | +	if t.writeError != nil {
 | 
	
		
			
				|  |  | +		return t.writeError
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	if t.writtenSinceKex > t.config.RekeyThreshold {
 | 
	
		
			
				|  |  | -		t.sendKexInitLocked(subsequentKeyExchange)
 | 
	
		
			
				|  |  | +	if t.sentInitMsg != nil {
 | 
	
		
			
				|  |  | +		// Copy the packet so the writer can reuse the buffer.
 | 
	
		
			
				|  |  | +		cp := make([]byte, len(p))
 | 
	
		
			
				|  |  | +		copy(cp, p)
 | 
	
		
			
				|  |  | +		t.pendingPackets = append(t.pendingPackets, cp)
 | 
	
		
			
				|  |  | +		return nil
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	for t.sentInitMsg != nil && t.writeError == nil {
 | 
	
		
			
				|  |  | -		t.cond.Wait()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	if t.writeBytesLeft > 0 {
 | 
	
		
			
				|  |  | +		t.writeBytesLeft -= int64(len(p))
 | 
	
		
			
				|  |  | +	} else {
 | 
	
		
			
				|  |  | +		t.requestKeyExchange()
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	if t.writeError != nil {
 | 
	
		
			
				|  |  | -		return t.writeError
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	if t.writePacketsLeft > 0 {
 | 
	
		
			
				|  |  | +		t.writePacketsLeft--
 | 
	
		
			
				|  |  | +	} else {
 | 
	
		
			
				|  |  | +		t.requestKeyExchange()
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	t.writtenSinceKex += uint64(len(p))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	switch p[0] {
 | 
	
		
			
				|  |  | -	case msgKexInit:
 | 
	
		
			
				|  |  | -		return errors.New("ssh: only handshakeTransport can send kexInit")
 | 
	
		
			
				|  |  | -	case msgNewKeys:
 | 
	
		
			
				|  |  | -		return errors.New("ssh: only handshakeTransport can send newKeys")
 | 
	
		
			
				|  |  | -	default:
 | 
	
		
			
				|  |  | -		return t.conn.writePacket(p)
 | 
	
		
			
				|  |  | +	if err := t.pushPacket(p); err != nil {
 | 
	
		
			
				|  |  | +		t.writeError = err
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	return nil
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  func (t *handshakeTransport) Close() error {
 | 
	
		
			
				|  |  |  	return t.conn.Close()
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -// enterKeyExchange runs the key exchange. t.mu must be held while running this.
 | 
	
		
			
				|  |  | -func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error {
 | 
	
		
			
				|  |  | +func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 | 
	
		
			
				|  |  |  	if debugHandshake {
 | 
	
		
			
				|  |  |  		log.Printf("%s entered key exchange", t.id())
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange)
 | 
	
		
			
				|  |  | -	if err != nil {
 | 
	
		
			
				|  |  | -		return err
 | 
	
		
			
				|  |  | -	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	otherInit := &kexInitMsg{}
 | 
	
		
			
				|  |  |  	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
 | 
	
	
		
			
				|  | @@ -352,20 +517,20 @@ func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) erro
 | 
	
		
			
				|  |  |  		clientVersion: t.clientVersion,
 | 
	
		
			
				|  |  |  		serverVersion: t.serverVersion,
 | 
	
		
			
				|  |  |  		clientKexInit: otherInitPacket,
 | 
	
		
			
				|  |  | -		serverKexInit: myInitPacket,
 | 
	
		
			
				|  |  | +		serverKexInit: t.sentInitPacket,
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	clientInit := otherInit
 | 
	
		
			
				|  |  | -	serverInit := myInit
 | 
	
		
			
				|  |  | +	serverInit := t.sentInitMsg
 | 
	
		
			
				|  |  |  	if len(t.hostKeys) == 0 {
 | 
	
		
			
				|  |  | -		clientInit = myInit
 | 
	
		
			
				|  |  | -		serverInit = otherInit
 | 
	
		
			
				|  |  | +		clientInit, serverInit = serverInit, clientInit
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		magics.clientKexInit = myInitPacket
 | 
	
		
			
				|  |  | +		magics.clientKexInit = t.sentInitPacket
 | 
	
		
			
				|  |  |  		magics.serverKexInit = otherInitPacket
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	algs, err := findAgreedAlgorithms(clientInit, serverInit)
 | 
	
		
			
				|  |  | +	var err error
 | 
	
		
			
				|  |  | +	t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
 | 
	
		
			
				|  |  |  	if err != nil {
 | 
	
		
			
				|  |  |  		return err
 | 
	
		
			
				|  |  |  	}
 | 
	
	
		
			
				|  | @@ -388,16 +553,16 @@ func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) erro
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	kex, ok := kexAlgoMap[algs.kex]
 | 
	
		
			
				|  |  | +	kex, ok := kexAlgoMap[t.algorithms.kex]
 | 
	
		
			
				|  |  |  	if !ok {
 | 
	
		
			
				|  |  | -		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
 | 
	
		
			
				|  |  | +		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	var result *kexResult
 | 
	
		
			
				|  |  |  	if len(t.hostKeys) > 0 {
 | 
	
		
			
				|  |  | -		result, err = t.server(kex, algs, &magics)
 | 
	
		
			
				|  |  | +		result, err = t.server(kex, t.algorithms, &magics)
 | 
	
		
			
				|  |  |  	} else {
 | 
	
		
			
				|  |  | -		result, err = t.client(kex, algs, &magics)
 | 
	
		
			
				|  |  | +		result, err = t.client(kex, t.algorithms, &magics)
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	if err != nil {
 | 
	
	
		
			
				|  | @@ -409,7 +574,7 @@ func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) erro
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  	result.SessionID = t.sessionID
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	t.conn.prepareKeyChange(algs, result)
 | 
	
		
			
				|  |  | +	t.conn.prepareKeyChange(t.algorithms, result)
 | 
	
		
			
				|  |  |  	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
 | 
	
		
			
				|  |  |  		return err
 | 
	
		
			
				|  |  |  	}
 |