internal/transport: handle h2 errcode on header decoding (#3872)

Handles HTTP2 error code when malformed request/response header appears.
Fixes: #3819
diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go
index e7f2321..a12d6b8 100644
--- a/internal/transport/http2_client.go
+++ b/internal/transport/http2_client.go
@@ -1206,8 +1206,8 @@
 	state := &decodeState{}
 	// Initialize isGRPC value to be !initialHeader, since if a gRPC Response-Headers has already been received, then it means that the peer is speaking gRPC and we are in gRPC mode.
 	state.data.isGRPC = !initialHeader
-	if err := state.decodeHeader(frame); err != nil {
-		t.closeStream(s, err, true, http2.ErrCodeProtocol, status.Convert(err), nil, endStream)
+	if h2code, err := state.decodeHeader(frame); err != nil {
+		t.closeStream(s, err, true, h2code, status.Convert(err), nil, endStream)
 		return
 	}
 
diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go
index 3be22fe..0cf1cc3 100644
--- a/internal/transport/http2_server.go
+++ b/internal/transport/http2_server.go
@@ -306,12 +306,12 @@
 	state := &decodeState{
 		serverSide: true,
 	}
-	if err := state.decodeHeader(frame); err != nil {
-		if se, ok := status.FromError(err); ok {
+	if h2code, err := state.decodeHeader(frame); err != nil {
+		if _, ok := status.FromError(err); ok {
 			t.controlBuf.put(&cleanupStream{
 				streamID: streamID,
 				rst:      true,
-				rstCode:  statusCodeConvTab[se.Code()],
+				rstCode:  h2code,
 				onWrite:  func() {},
 			})
 		}
diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go
index 5e1e7a6..4d15afb 100644
--- a/internal/transport/http_util.go
+++ b/internal/transport/http_util.go
@@ -73,13 +73,6 @@
 		http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
 		http2.ErrCodeHTTP11Required:     codes.Internal,
 	}
-	statusCodeConvTab = map[codes.Code]http2.ErrCode{
-		codes.Internal:          http2.ErrCodeInternal,
-		codes.Canceled:          http2.ErrCodeCancel,
-		codes.Unavailable:       http2.ErrCodeRefusedStream,
-		codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
-		codes.PermissionDenied:  http2.ErrCodeInadequateSecurity,
-	}
 	// HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table.
 	HTTPStatusConvTab = map[int]codes.Code{
 		// 400 Bad Request - INTERNAL.
@@ -222,11 +215,11 @@
 	return v, nil
 }
 
-func (d *decodeState) decodeHeader(frame *http2.MetaHeadersFrame) error {
+func (d *decodeState) decodeHeader(frame *http2.MetaHeadersFrame) (http2.ErrCode, error) {
 	// frame.Truncated is set to true when framer detects that the current header
 	// list size hits MaxHeaderListSize limit.
 	if frame.Truncated {
-		return status.Error(codes.Internal, "peer header list size exceeded limit")
+		return http2.ErrCodeFrameSize, status.Error(codes.Internal, "peer header list size exceeded limit")
 	}
 
 	for _, hf := range frame.Fields {
@@ -235,10 +228,10 @@
 
 	if d.data.isGRPC {
 		if d.data.grpcErr != nil {
-			return d.data.grpcErr
+			return http2.ErrCodeProtocol, d.data.grpcErr
 		}
 		if d.serverSide {
-			return nil
+			return http2.ErrCodeNo, nil
 		}
 		if d.data.rawStatusCode == nil && d.data.statusGen == nil {
 			// gRPC status doesn't exist.
@@ -250,12 +243,12 @@
 			code := int(codes.Unknown)
 			d.data.rawStatusCode = &code
 		}
-		return nil
+		return http2.ErrCodeNo, nil
 	}
 
 	// HTTP fallback mode
 	if d.data.httpErr != nil {
-		return d.data.httpErr
+		return http2.ErrCodeProtocol, d.data.httpErr
 	}
 
 	var (
@@ -270,7 +263,7 @@
 		}
 	}
 
-	return status.Error(code, d.constructHTTPErrMsg())
+	return http2.ErrCodeProtocol, status.Error(code, d.constructHTTPErrMsg())
 }
 
 // constructErrMsg constructs error message to be returned in HTTP fallback mode.
diff --git a/internal/transport/http_util_test.go b/internal/transport/http_util_test.go
index a3616f7..85a083f 100644
--- a/internal/transport/http_util_test.go
+++ b/internal/transport/http_util_test.go
@@ -23,6 +23,9 @@
 	"reflect"
 	"testing"
 	"time"
+
+	"golang.org/x/net/http2"
+	"golang.org/x/net/http2/hpack"
 )
 
 func (s) TestTimeoutDecode(t *testing.T) {
@@ -185,3 +188,65 @@
 		}
 	}
 }
+
+func (s) TestDecodeHeaderH2ErrCode(t *testing.T) {
+	for _, test := range []struct {
+		name string
+		// input
+		metaHeaderFrame *http2.MetaHeadersFrame
+		serverSide      bool
+		// output
+		wantCode http2.ErrCode
+	}{
+		{
+			name: "valid header",
+			metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{
+				{Name: "content-type", Value: "application/grpc"},
+			}},
+			wantCode: http2.ErrCodeNo,
+		},
+		{
+			name: "valid header serverSide",
+			metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{
+				{Name: "content-type", Value: "application/grpc"},
+			}},
+			serverSide: true,
+			wantCode:   http2.ErrCodeNo,
+		},
+		{
+			name: "invalid grpc status header field",
+			metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{
+				{Name: "content-type", Value: "application/grpc"},
+				{Name: "grpc-status", Value: "xxxx"},
+			}},
+			wantCode: http2.ErrCodeProtocol,
+		},
+		{
+			name: "invalid http content type",
+			metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{
+				{Name: "content-type", Value: "application/json"},
+			}},
+			wantCode: http2.ErrCodeProtocol,
+		},
+		{
+			name: "http fallback and invalid http status",
+			metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{
+				// No content type provided then fallback into handling http error.
+				{Name: ":status", Value: "xxxx"},
+			}},
+			wantCode: http2.ErrCodeProtocol,
+		},
+		{
+			name:            "http2 frame size exceeds",
+			metaHeaderFrame: &http2.MetaHeadersFrame{Fields: nil, Truncated: true},
+			wantCode:        http2.ErrCodeFrameSize,
+		},
+	} {
+		t.Run(test.name, func(t *testing.T) {
+			state := &decodeState{serverSide: test.serverSide}
+			if h2code, _ := state.decodeHeader(test.metaHeaderFrame); h2code != test.wantCode {
+				t.Fatalf("decodeState.decodeHeader(%v) = %v, want %v", test.metaHeaderFrame, h2code, test.wantCode)
+			}
+		})
+	}
+}