ssh: allow server auth callbacks to send additional banners

Add a new BannerError error type that auth callbacks can return to send a
banner to the client. While the BannerCallback can send the initial
banner message, auth callbacks might want to communicate more
information to the client to help users diagnose authentication failures.

Updates golang/go#64962

Change-Id: I97a26480ff4064b95a0a26042b0a5e19737cfb62
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/558695
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Auto-Submit: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
diff --git a/ssh/server.go b/ssh/server.go
index e2ae4f8..3ca9e89 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -462,6 +462,24 @@
 // It is returned in ServerAuthError.Errors from NewServerConn.
 var ErrNoAuth = errors.New("ssh: no auth passed yet")
 
+// BannerError may be returned, possibly wrapped, by a ServerConfig auth
+// callback; a non-empty Message is then sent to the client as a banner.
+type BannerError struct {
+	Err     error  // underlying cause, if any; exposed through Unwrap
+	Message string // banner text; also the error text when Err is nil
+}
+
+func (e *BannerError) Unwrap() error {
+	return e.Err
+}
+
+func (e *BannerError) Error() string {
+	if e.Err != nil {
+		return e.Err.Error()
+	}
+	return e.Message
+}
+
 func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
 	sessionID := s.transport.getSessionID()
 	var cache pubKeyCache
@@ -734,6 +752,18 @@
 			config.AuthLogCallback(s, userAuthReq.Method, authErr)
 		}
 
+		var bannerErr *BannerError
+		if errors.As(authErr, &bannerErr) {
+			if bannerErr.Message != "" {
+				bannerMsg := &userAuthBannerMsg{
+					Message: bannerErr.Message,
+				}
+				if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
+					return nil, err
+				}
+			}
+		}
+
 		if authErr == nil {
 			break userAuthLoop
 		}
diff --git a/ssh/server_test.go b/ssh/server_test.go
index 5b47b9e..9057a9b 100644
--- a/ssh/server_test.go
+++ b/ssh/server_test.go
@@ -6,8 +6,10 @@
 
 import (
 	"errors"
+	"fmt"
 	"io"
 	"net"
+	"slices"
 	"strings"
 	"sync/atomic"
 	"testing"
@@ -225,6 +227,78 @@
 	}
 }
 
+func TestBannerError(t *testing.T) {
+	// Every auth method except password fails with a BannerError; each banner must reach the client in order.
+	wantBanners := []string{
+		"banner from BannerCallback",
+		"banner from NoClientAuthCallback",
+		"banner from PublicKeyCallback",
+		"banner from KeyboardInteractiveCallback",
+	}
+	var gotBanners []string
+
+	srvConf := &ServerConfig{
+		BannerCallback: func(ConnMetadata) string {
+			return "banner from BannerCallback"
+		},
+		NoClientAuth: true,
+		// Wrap the BannerError so the server has to find it with errors.As.
+		NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
+			return nil, fmt.Errorf("wrapped: %w", &BannerError{
+				Err:     errors.New("error from NoClientAuthCallback"),
+				Message: "banner from NoClientAuthCallback",
+			})
+		},
+		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+			return nil, &BannerError{
+				Err:     errors.New("error from PublicKeyCallback"),
+				Message: "banner from PublicKeyCallback",
+			}
+		},
+		// A BannerError with a nil inner error must still be accepted.
+		KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) {
+			return nil, &BannerError{Message: "banner from KeyboardInteractiveCallback"}
+		},
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			return &Permissions{}, nil
+		},
+	}
+	srvConf.AddHostKey(testSigners["rsa"])
+
+	cliConf := &ClientConfig{
+		User:            "test",
+		HostKeyCallback: InsecureIgnoreHostKey(),
+		BannerCallback: func(msg string) error {
+			gotBanners = append(gotBanners, msg)
+			return nil
+		},
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+			KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
+				return []string{"letmein"}, nil
+			}),
+			Password(clientPassword),
+		},
+	}
+
+	serverSide, clientSide, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer serverSide.Close()
+	defer clientSide.Close()
+	go newServer(serverSide, srvConf)
+	conn, _, _, err := NewClientConn(clientSide, "", cliConf)
+	if err != nil {
+		t.Fatalf("client connection failed: %v", err)
+	}
+	defer conn.Close()
+
+	if !slices.Equal(gotBanners, wantBanners) {
+		t.Errorf("got banners:\n%q\nwant banners:\n%q", gotBanners, wantBanners)
+	}
+}
+
 type markerConn struct {
 	closed uint32
 	used   uint32