ssh: invert algorithm choices on the server

At the protocol level, SSH lets client and server specify different
algorithms for the read and write half of the connection. This has
never worked correctly, as Client-to-Server was always interpreted as
the "write" side, even if we were the server.

This has never been a problem because, apparently, there are no
clients that insist on different algorithm choices running against Go
SSH servers.

Since the SSH package does not expose a mechanism to specify
algorithms for read/write separately, there is end-to-end for this
change, so add a unittest instead.

Change-Id: Ie3aa781630a3bb7a3b0e3754cb67b3ce12581544
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/172538
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/ssh/common.go b/ssh/common.go
index 04f3620..d97415d 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -109,6 +109,7 @@
 	return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
 }
 
+// directionAlgorithms records algorithm choices in one direction (either read or write)
 type directionAlgorithms struct {
 	Cipher      string
 	MAC         string
@@ -137,7 +138,7 @@
 	r       directionAlgorithms
 }
 
-func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
+func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
 	result := &algorithms{}
 
 	result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
@@ -150,32 +151,37 @@
 		return
 	}
 
-	result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
+	stoc, ctos := &result.w, &result.r
+	if isClient {
+		ctos, stoc = stoc, ctos
+	}
+
+	ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
 	if err != nil {
 		return
 	}
 
-	result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
+	stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
 	if err != nil {
 		return
 	}
 
-	result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
+	ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
 	if err != nil {
 		return
 	}
 
-	result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
+	stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
 	if err != nil {
 		return
 	}
 
-	result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
+	ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
 	if err != nil {
 		return
 	}
 
-	result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
+	stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
 	if err != nil {
 		return
 	}
diff --git a/ssh/common_test.go b/ssh/common_test.go
new file mode 100644
index 0000000..96744dc
--- /dev/null
+++ b/ssh/common_test.go
@@ -0,0 +1,176 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestFindAgreedAlgorithms(t *testing.T) {
+	initKex := func(k *kexInitMsg) {
+		if k.KexAlgos == nil {
+			k.KexAlgos = []string{"kex1"}
+		}
+		if k.ServerHostKeyAlgos == nil {
+			k.ServerHostKeyAlgos = []string{"hostkey1"}
+		}
+		if k.CiphersClientServer == nil {
+			k.CiphersClientServer = []string{"cipher1"}
+
+		}
+		if k.CiphersServerClient == nil {
+			k.CiphersServerClient = []string{"cipher1"}
+
+		}
+		if k.MACsClientServer == nil {
+			k.MACsClientServer = []string{"mac1"}
+
+		}
+		if k.MACsServerClient == nil {
+			k.MACsServerClient = []string{"mac1"}
+
+		}
+		if k.CompressionClientServer == nil {
+			k.CompressionClientServer = []string{"compression1"}
+
+		}
+		if k.CompressionServerClient == nil {
+			k.CompressionServerClient = []string{"compression1"}
+
+		}
+		if k.LanguagesClientServer == nil {
+			k.LanguagesClientServer = []string{"language1"}
+
+		}
+		if k.LanguagesServerClient == nil {
+			k.LanguagesServerClient = []string{"language1"}
+
+		}
+	}
+
+	initDirAlgs := func(a *directionAlgorithms) {
+		if a.Cipher == "" {
+			a.Cipher = "cipher1"
+		}
+		if a.MAC == "" {
+			a.MAC = "mac1"
+		}
+		if a.Compression == "" {
+			a.Compression = "compression1"
+		}
+	}
+
+	initAlgs := func(a *algorithms) {
+		if a.kex == "" {
+			a.kex = "kex1"
+		}
+		if a.hostKey == "" {
+			a.hostKey = "hostkey1"
+		}
+		initDirAlgs(&a.r)
+		initDirAlgs(&a.w)
+	}
+
+	type testcase struct {
+		name                   string
+		clientIn, serverIn     kexInitMsg
+		wantClient, wantServer algorithms
+		wantErr                bool
+	}
+
+	cases := []testcase{
+		testcase{
+			name: "standard",
+		},
+
+		testcase{
+			name: "no common hostkey",
+			serverIn: kexInitMsg{
+				ServerHostKeyAlgos: []string{"hostkey2"},
+			},
+			wantErr: true,
+		},
+
+		testcase{
+			name: "no common kex",
+			serverIn: kexInitMsg{
+				KexAlgos: []string{"kex2"},
+			},
+			wantErr: true,
+		},
+
+		testcase{
+			name: "no common cipher",
+			serverIn: kexInitMsg{
+				CiphersClientServer: []string{"cipher2"},
+			},
+			wantErr: true,
+		},
+
+		testcase{
+			name: "client decides cipher",
+			serverIn: kexInitMsg{
+				CiphersClientServer: []string{"cipher1", "cipher2"},
+				CiphersServerClient: []string{"cipher2", "cipher3"},
+			},
+			clientIn: kexInitMsg{
+				CiphersClientServer: []string{"cipher2", "cipher1"},
+				CiphersServerClient: []string{"cipher3", "cipher2"},
+			},
+			wantClient: algorithms{
+				r: directionAlgorithms{
+					Cipher: "cipher3",
+				},
+				w: directionAlgorithms{
+					Cipher: "cipher2",
+				},
+			},
+			wantServer: algorithms{
+				w: directionAlgorithms{
+					Cipher: "cipher3",
+				},
+				r: directionAlgorithms{
+					Cipher: "cipher2",
+				},
+			},
+		},
+
+		// TODO(hanwen): fix and add tests for AEAD ignoring
+		// the MACs field
+	}
+
+	for i := range cases {
+		initKex(&cases[i].clientIn)
+		initKex(&cases[i].serverIn)
+		initAlgs(&cases[i].wantClient)
+		initAlgs(&cases[i].wantServer)
+	}
+
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn)
+			clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn)
+
+			serverHasErr := serverErr != nil
+			clientHasErr := clientErr != nil
+			if c.wantErr != serverHasErr || c.wantErr != clientHasErr {
+				t.Fatalf("got client/server error (%v, %v), want hasError %v",
+					clientErr, serverErr, c.wantErr)
+
+			}
+			if c.wantErr {
+				return
+			}
+
+			if !reflect.DeepEqual(serverAlgs, &c.wantServer) {
+				t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer)
+			}
+			if !reflect.DeepEqual(clientAlgs, &c.wantClient) {
+				t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient)
+			}
+		})
+	}
+}
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 4f7912e..2b10b05 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -543,7 +543,8 @@
 
 	clientInit := otherInit
 	serverInit := t.sentInitMsg
-	if len(t.hostKeys) == 0 {
+	isClient := len(t.hostKeys) == 0
+	if isClient {
 		clientInit, serverInit = serverInit, clientInit
 
 		magics.clientKexInit = t.sentInitPacket
@@ -551,7 +552,7 @@
 	}
 
 	var err error
-	t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
+	t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit)
 	if err != nil {
 		return err
 	}