ssh/agent: add checking for empty SSH requests

Previously empty SSH requests would cause a panic.

Change-Id: I8443fee50891b3d2b3b62ac01fb0b9e96244241f
GitHub-Last-Rev: 64f00d2bf2ee722f53e68b6bd4f70c722d7694bd
GitHub-Pull-Request: golang/crypto#58
Reviewed-on: https://go-review.googlesource.com/c/140237
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go
index 2f8abf1..c31853a 100644
--- a/ssh/agent/client_test.go
+++ b/ssh/agent/client_test.go
@@ -8,11 +8,13 @@
 	"bytes"
 	"crypto/rand"
 	"errors"
+	"io"
 	"net"
 	"os"
 	"os/exec"
 	"path/filepath"
 	"strconv"
+	"sync"
 	"testing"
 	"time"
 
@@ -196,6 +198,63 @@
 
 }
 
+func TestMalformedRequests(t *testing.T) {
+	keyringAgent := NewKeyring()
+	listener, err := netListener()
+	if err != nil {
+		t.Fatalf("netListener: %v", err)
+	}
+	defer listener.Close()
+
+	testCase := func(t *testing.T, requestBytes []byte, wantServerErr bool) {
+		var wg sync.WaitGroup
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			c, err := listener.Accept()
+			if err != nil {
+				t.Errorf("listener.Accept: %v", err)
+				return
+			}
+			defer c.Close()
+
+			err = ServeAgent(keyringAgent, c)
+			if err == nil {
+				t.Error("ServeAgent should have returned an error to malformed input")
+			} else {
+				if (err != io.EOF) != wantServerErr {
+					t.Errorf("ServeAgent returned expected error: %v", err)
+				}
+			}
+		}()
+
+		c, err := net.Dial("tcp", listener.Addr().String())
+		if err != nil {
+			t.Fatalf("net.Dial: %v", err)
+		}
+		_, err = c.Write(requestBytes)
+		if err != nil {
+			t.Errorf("Unexpected error writing raw bytes on connection: %v", err)
+		}
+		c.Close()
+		wg.Wait()
+	}
+
+	var testCases = []struct {
+		name          string
+		requestBytes  []byte
+		wantServerErr bool
+	}{
+		{"Empty request", []byte{}, false},
+		{"Short header", []byte{0x00}, true},
+		{"Empty body", []byte{0x00, 0x00, 0x00, 0x00}, true},
+		{"Short body", []byte{0x00, 0x00, 0x00, 0x01}, false},
+	}
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) { testCase(t, tc.requestBytes, tc.wantServerErr) })
+	}
+}
+
 func TestAgent(t *testing.T) {
 	for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
 		testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
@@ -215,17 +274,26 @@
 	testKeyringAgent(t, testPrivateKeys["rsa"], cert, 0)
 }
 
-// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
-// therefore is buffered (net.Pipe deadlocks if both sides start with
-// a write.)
-func netPipe() (net.Conn, net.Conn, error) {
+// netListener creates a localhost network listener.
+func netListener() (net.Listener, error) {
 	listener, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
 		listener, err = net.Listen("tcp", "[::1]:0")
 		if err != nil {
-			return nil, nil, err
+			return nil, err
 		}
 	}
+	return listener, nil
+}
+
+// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
+// therefore is buffered (net.Pipe deadlocks if both sides start with
+// a write.)
+func netPipe() (net.Conn, net.Conn, error) {
+	listener, err := netListener()
+	if err != nil {
+		return nil, nil, err
+	}
 	defer listener.Close()
 	c1, err := net.Dial("tcp", listener.Addr().String())
 	if err != nil {
diff --git a/ssh/agent/server.go b/ssh/agent/server.go
index a194976..6e7a1e0 100644
--- a/ssh/agent/server.go
+++ b/ssh/agent/server.go
@@ -541,6 +541,9 @@
 			return err
 		}
 		l := binary.BigEndian.Uint32(length[:])
+		if l == 0 {
+			return fmt.Errorf("agent: request size is 0")
+		}
 		if l > maxAgentResponseBytes {
 			// We also cap requests.
 			return fmt.Errorf("agent: request too large: %d", l)