Http status to grpc status conversion (#1195)
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 80583ab..3e5ff73 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -968,18 +968,16 @@
}
s.bytesReceived = true
var state decodeState
- for _, hf := range frame.Fields {
- if err := state.processHeaderField(hf); err != nil {
- s.mu.Lock()
- if !s.headerDone {
- close(s.headerChan)
- s.headerDone = true
- }
- s.mu.Unlock()
- s.write(recvMsg{err: err})
- // Something wrong. Stops reading even when there is remaining.
- return
+ if err := state.decodeResponseHeader(frame); err != nil {
+ s.mu.Lock()
+ if !s.headerDone {
+ close(s.headerChan)
+ s.headerDone = true
}
+ s.mu.Unlock()
+ s.write(recvMsg{err: err})
+ // Something wrong. Stops reading even when there is remaining.
+ return
}
endStream := frame.StreamEnded()
diff --git a/transport/http_util.go b/transport/http_util.go
index 795d5d1..9b31717 100644
--- a/transport/http_util.go
+++ b/transport/http_util.go
@@ -40,6 +40,7 @@
"fmt"
"io"
"net"
+ "net/http"
"strconv"
"strings"
"sync/atomic"
@@ -88,6 +89,24 @@
codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
codes.PermissionDenied: http2.ErrCodeInadequateSecurity,
}
+ httpStatusConvTab = map[int]codes.Code{
+ // 400 Bad Request - INTERNAL.
+ http.StatusBadRequest: codes.Internal,
+ // 401 Unauthorized - UNAUTHENTICATED.
+ http.StatusUnauthorized: codes.Unauthenticated,
+ // 403 Forbidden - PERMISSION_DENIED.
+ http.StatusForbidden: codes.PermissionDenied,
+ // 404 Not Found - UNIMPLEMENTED.
+ http.StatusNotFound: codes.Unimplemented,
+ // 429 Too Many Requests - UNAVAILABLE.
+ http.StatusTooManyRequests: codes.Unavailable,
+ // 502 Bad Gateway - UNAVAILABLE.
+ http.StatusBadGateway: codes.Unavailable,
+ // 503 Service Unavailable - UNAVAILABLE.
+ http.StatusServiceUnavailable: codes.Unavailable,
+ // 504 Gateway timeout - UNAVAILABLE.
+ http.StatusGatewayTimeout: codes.Unavailable,
+ }
)
// Records the states during HPACK decoding. Must be reset once the
@@ -100,8 +119,9 @@
statusGen *status.Status
// rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not
// intended for direct access outside of parsing.
- rawStatusCode int32
+ rawStatusCode *int
rawStatusMsg string
+ httpStatus *int
// Server side only fields.
timeoutSet bool
timeout time.Duration
@@ -159,7 +179,7 @@
func (d *decodeState) status() *status.Status {
if d.statusGen == nil {
// No status-details were provided; generate status using code/msg.
- d.statusGen = status.New(codes.Code(d.rawStatusCode), d.rawStatusMsg)
+ d.statusGen = status.New(codes.Code(int32(*(d.rawStatusCode))), d.rawStatusMsg)
}
return d.statusGen
}
@@ -193,6 +213,44 @@
return v, nil
}
+func (d *decodeState) decodeResponseHeader(frame *http2.MetaHeadersFrame) error {
+ for _, hf := range frame.Fields {
+ if err := d.processHeaderField(hf); err != nil {
+ return err
+ }
+ }
+
+ // If grpc status exists, no need to check further.
+ if d.rawStatusCode != nil || d.statusGen != nil {
+ return nil
+ }
+
+ // If grpc status doesn't exist and http status doesn't exist,
+ // then it's a malformed header.
+ if d.httpStatus == nil {
+ return streamErrorf(codes.Internal, "malformed header: doesn't contain status(gRPC or HTTP)")
+ }
+
+ if *(d.httpStatus) != http.StatusOK {
+ code, ok := httpStatusConvTab[*(d.httpStatus)]
+ if !ok {
+ code = codes.Unknown
+ }
+ return streamErrorf(code, http.StatusText(*(d.httpStatus)))
+ }
+
+ // gRPC status doesn't exist and http status is OK.
+ // Set rawStatusCode to be unknown and return nil error.
+ // So that, if the stream has ended this Unknown status
+ // will be propogated to the user.
+ // Otherwise, it will be ignored. In which case, status from
+ // a later trailer, that has StreamEnded flag set, is propogated.
+ code := int(codes.Unknown)
+ d.rawStatusCode = &code
+ return nil
+
+}
+
func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
switch f.Name {
case "content-type":
@@ -206,7 +264,7 @@
if err != nil {
return streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err)
}
- d.rawStatusCode = int32(code)
+ d.rawStatusCode = &code
case "grpc-message":
d.rawStatusMsg = decodeGrpcMessage(f.Value)
case "grpc-status-details-bin":
@@ -227,6 +285,12 @@
}
case ":path":
d.method = f.Value
+ case ":status":
+ code, err := strconv.Atoi(f.Value)
+ if err != nil {
+ return streamErrorf(codes.Internal, "transport: malformed http-status: %v", err)
+ }
+ d.httpStatus = &code
default:
if !isReservedHeader(f.Name) || isWhitelistedPseudoHeader(f.Name) {
if d.mdata == nil {
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 7429f2e..0b534d2 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -34,11 +34,13 @@
package transport
import (
+ "bufio"
"bytes"
"fmt"
"io"
"math"
"net"
+ "net/http"
"reflect"
"strconv"
"strings"
@@ -1416,3 +1418,192 @@
break
}
}
+
+// A function of type writeHeaders writes out
+// http status with the given stream ID using the given framer.
+type writeHeaders func(*http2.Framer, uint32, int) error
+
+func writeOneHeader(framer *http2.Framer, sid uint32, httpStatus int) error {
+ var buf bytes.Buffer
+ henc := hpack.NewEncoder(&buf)
+ henc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(httpStatus)})
+ if err := framer.WriteHeaders(http2.HeadersFrameParam{
+ StreamID: sid,
+ BlockFragment: buf.Bytes(),
+ EndStream: true,
+ EndHeaders: true,
+ }); err != nil {
+ return err
+ }
+ return nil
+}
+
+func writeTwoHeaders(framer *http2.Framer, sid uint32, httpStatus int) error {
+ var buf bytes.Buffer
+ henc := hpack.NewEncoder(&buf)
+ henc.WriteField(hpack.HeaderField{
+ Name: ":status",
+ Value: fmt.Sprint(http.StatusOK),
+ })
+ if err := framer.WriteHeaders(http2.HeadersFrameParam{
+ StreamID: sid,
+ BlockFragment: buf.Bytes(),
+ EndHeaders: true,
+ }); err != nil {
+ return err
+ }
+ buf.Reset()
+ henc.WriteField(hpack.HeaderField{
+ Name: ":status",
+ Value: fmt.Sprint(httpStatus),
+ })
+ if err := framer.WriteHeaders(http2.HeadersFrameParam{
+ StreamID: sid,
+ BlockFragment: buf.Bytes(),
+ EndStream: true,
+ EndHeaders: true,
+ }); err != nil {
+ return err
+ }
+ return nil
+}
+
+type httpServer struct {
+ conn net.Conn
+ httpStatus int
+ wh writeHeaders
+}
+
+func (s *httpServer) start(t *testing.T, lis net.Listener) {
+ // Launch an HTTP server to send back header with httpStatus.
+ go func() {
+ var err error
+ s.conn, err = lis.Accept()
+ if err != nil {
+ t.Errorf("Error accepting connection: %v", err)
+ return
+ }
+ defer s.conn.Close()
+ // Read preface sent by client.
+ if _, err = io.ReadFull(s.conn, make([]byte, len(http2.ClientPreface))); err != nil {
+ t.Errorf("Error at server-side while reading preface from cleint. Err: %v", err)
+ return
+ }
+ reader := bufio.NewReaderSize(s.conn, http2IOBufSize)
+ writer := bufio.NewWriterSize(s.conn, http2IOBufSize)
+ framer := http2.NewFramer(writer, reader)
+ if err = framer.WriteSettingsAck(); err != nil {
+ t.Errorf("Error at server-side while sending Settings ack. Err: %v", err)
+ return
+ }
+ var sid uint32
+ // Read frames until a header is received.
+ for {
+ frame, err := framer.ReadFrame()
+ if err != nil {
+ t.Errorf("Error at server-side while reading frame. Err: %v", err)
+ return
+ }
+ if hframe, ok := frame.(*http2.HeadersFrame); ok {
+ sid = hframe.Header().StreamID
+ break
+ }
+ }
+ if err = s.wh(framer, sid, s.httpStatus); err != nil {
+ t.Errorf("Error at server-side while writing headers. Err: %v", err)
+ return
+ }
+ writer.Flush()
+ }()
+}
+
+func (s *httpServer) cleanUp() {
+ if s.conn != nil {
+ s.conn.Close()
+ }
+}
+
+func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (stream *Stream, cleanUp func()) {
+ var (
+ err error
+ lis net.Listener
+ server *httpServer
+ client ClientTransport
+ )
+ cleanUp = func() {
+ if lis != nil {
+ lis.Close()
+ }
+ if server != nil {
+ server.cleanUp()
+ }
+ if client != nil {
+ client.Close()
+ }
+ }
+ defer func() {
+ if err != nil {
+ cleanUp()
+ }
+ }()
+ lis, err = net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("Failed to listen. Err: %v", err)
+ }
+ server = &httpServer{
+ httpStatus: httpStatus,
+ wh: wh,
+ }
+ server.start(t, lis)
+ client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{})
+ if err != nil {
+ t.Fatalf("Error creating client. Err: %v", err)
+ }
+ stream, err = client.NewStream(context.Background(), &CallHdr{Method: "bogus/method", Flush: true})
+ if err != nil {
+ t.Fatalf("Error creating stream at client-side. Err: %v", err)
+ }
+ return
+}
+
+func TestHTTPToGRPCStatusMapping(t *testing.T) {
+ for k := range httpStatusConvTab {
+ testHTTPToGRPCStatusMapping(t, k, writeOneHeader)
+ }
+}
+
+func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders) {
+ stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh)
+ defer cleanUp()
+ want := httpStatusConvTab[httpStatus]
+ _, err := stream.Read([]byte{})
+ if err == nil {
+ t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want)
+ }
+ serr, ok := err.(StreamError)
+ if !ok {
+ t.Fatalf("err.(Type) = %T, want StreamError", err)
+ }
+ if want != serr.Code {
+ t.Fatalf("Want error code: %v, got: %v", want, serr.Code)
+ }
+}
+
+func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) {
+ stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader)
+ defer cleanUp()
+ _, err := stream.Read([]byte{})
+ if err != io.EOF {
+ t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err)
+ }
+ want := codes.Unknown
+ stream.mu.Lock()
+ defer stream.mu.Unlock()
+ if stream.status.Code() != want {
+ t.Fatalf("Status code of stream: %v, want: %v", stream.status.Code(), want)
+ }
+}
+
+func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) {
+ testHTTPToGRPCStatusMapping(t, http.StatusUnauthorized, writeTwoHeaders)
+}