ssh: prevent double kex at connection start, 2nd try

The previous attempt would fail in the following scenario:

* select picks "first" kex from requestKex

* read loop receives a remote kex, posts on requestKex (which is now
  empty) [*] for sending out a response, and sends pendingKex on startKex.

* select picks pendingKex from startKex, and proceeds to run the key
  exchange.

* the posting on requestKex in [*] now triggers a second key exchange.

Fixes #18861. 

Change-Id: I443e82f1d04c7f17d1485fdb87072b9feec26aa8
Reviewed-on: https://go-review.googlesource.com/36055
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
diff --git a/ssh/handshake.go b/ssh/handshake.go
index e3f82c4..8de6506 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -66,8 +66,8 @@
 
 	// If the read loop wants to schedule a kex, it pings this
 	// channel, and the write loop will send out a kex
-	// message. The boolean is whether this is the first request or not.
-	requestKex chan bool
+	// message.
+	requestKex chan struct{}
 
 	// If the other side requests or confirms a kex, its kexInit
 	// packet is sent here for the write loop to find it.
@@ -102,14 +102,14 @@
 		serverVersion: serverVersion,
 		clientVersion: clientVersion,
 		incoming:      make(chan []byte, chanSize),
-		requestKex:    make(chan bool, 1),
+		requestKex:    make(chan struct{}, 1),
 		startKex:      make(chan *pendingKex, 1),
 
 		config: config,
 	}
 
 	// We always start with a mandatory key exchange.
-	t.requestKex <- true
+	t.requestKex <- struct{}{}
 	return t
 }
 
@@ -166,6 +166,7 @@
 	if write {
 		action = "sent"
 	}
+
 	if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
 		log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
 	} else {
@@ -230,14 +231,13 @@
 
 func (t *handshakeTransport) requestKeyExchange() {
 	select {
-	case t.requestKex <- false:
+	case t.requestKex <- struct{}{}:
 	default:
 		// something already requested a kex, so do nothing.
 	}
 }
 
 func (t *handshakeTransport) kexLoop() {
-	firstSent := false
 
 write:
 	for t.getWriteError() == nil {
@@ -251,18 +251,8 @@
 				if !ok {
 					break write
 				}
-			case requestFirst := <-t.requestKex:
-				// For the first key exchange, both
-				// sides will initiate a key exchange,
-				// and both channels will fire. To
-				// avoid doing two key exchanges in a
-				// row, ignore our own request for an
-				// initial kex if we have already sent
-				// it out.
-				if firstSent && requestFirst {
-
-					continue
-				}
+			case <-t.requestKex:
+				break
 			}
 
 			if !sent {
@@ -270,7 +260,6 @@
 					t.recordWriteError(err)
 					break
 				}
-				firstSent = true
 				sent = true
 			}
 		}
@@ -287,7 +276,8 @@
 
 		// We're not servicing t.startKex, but the remote end
 		// has just sent us a kexInitMsg, so it can't send
-		// another key change request.
+		// another key change request, until we close the done
+		// channel on the pendingKex request.
 
 		err := t.enterKeyExchange(request.otherInit)
 
@@ -301,6 +291,23 @@
 		} else if t.algorithms != nil {
 			t.writeBytesLeft = t.algorithms.w.rekeyBytes()
 		}
+
+		// we have completed the key exchange. Since the
+		// reader is still blocked, it is safe to clear out
+		// the requestKex channel. This avoids the situation
+		// where: 1) we consumed our own request for the
+		// initial kex, and 2) the kex from the remote side
+		// caused another send on the requestKex channel,
+	clear:
+		for {
+			select {
+			case <-t.requestKex:
+				//
+			default:
+				break clear
+			}
+		}
+
 		request.done <- t.writeError
 
 		// kex finished. Push packets that we received while