Synchronize WriteStatus with WriteHeader on server. (#2074)
diff --git a/transport/http2_server.go b/transport/http2_server.go
index 8b93e22..ab35618 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -682,28 +682,7 @@
})
}
-// WriteHeader sends the header metedata md back to the client.
-func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
- if s.headerOk || s.getState() == streamDone {
- return ErrIllegalHeaderWrite
- }
- s.headerOk = true
- if md.Len() > 0 {
- if s.header.Len() > 0 {
- s.header = metadata.Join(s.header, md)
- } else {
- s.header = md
- }
- }
- md = s.header
- // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields
- // first and create a slice of that exact size.
- headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else.
- headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
- headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
- if s.sendCompress != "" {
- headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
- }
+func appendHeaderFieldsFromMD(headerFields []hpack.HeaderField, md metadata.MD) []hpack.HeaderField {
for k, vv := range md {
if isReservedHeader(k) {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
@@ -713,6 +692,37 @@
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
+ return headerFields
+}
+
+// WriteHeader sends the header metedata md back to the client.
+func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
+ if s.updateHeaderSent() || s.getState() == streamDone {
+ return ErrIllegalHeaderWrite
+ }
+ s.hdrMu.Lock()
+ if md.Len() > 0 {
+ if s.header.Len() > 0 {
+ s.header = metadata.Join(s.header, md)
+ } else {
+ s.header = md
+ }
+ }
+ t.writeHeaderLocked(s)
+ s.hdrMu.Unlock()
+ return nil
+}
+
+func (t *http2Server) writeHeaderLocked(s *Stream) {
+ // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields
+ // first and create a slice of that exact size.
+ headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else.
+ headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
+ if s.sendCompress != "" {
+ headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
+ }
+ headerFields = appendHeaderFieldsFromMD(headerFields, s.header)
t.controlBuf.put(&headerFrame{
streamID: s.id,
hf: headerFields,
@@ -728,7 +738,6 @@
outHeader := &stats.OutHeader{}
t.stats.HandleRPC(s.Context(), outHeader)
}
- return nil
}
// WriteStatus sends stream status to the client and terminates the stream.
@@ -736,21 +745,20 @@
// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early
// OK is adopted.
func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
- if !s.headerOk && s.header.Len() > 0 {
- if err := t.WriteHeader(s, nil); err != nil {
- return err
- }
- } else {
- if s.getState() == streamDone {
- return nil
- }
+ if s.getState() == streamDone {
+ return nil
}
+ s.hdrMu.Lock()
// TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields
// first and create a slice of that exact size.
headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else.
- if !s.headerOk {
- headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
- headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
+ if !s.updateHeaderSent() { // No headers have been sent.
+ if len(s.header) > 0 { // Send a separate header frame.
+ t.writeHeaderLocked(s)
+ } else { // Send a trailer only response.
+ headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
+ }
}
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
@@ -766,16 +774,8 @@
}
// Attach the trailer metadata.
- for k, vv := range s.trailer {
- // Clients don't tolerate reading restricted headers after some non restricted ones were sent.
- if isReservedHeader(k) {
- continue
- }
- for _, v := range vv {
- headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
- }
- }
- trailer := &headerFrame{
+ headerFields = appendHeaderFieldsFromMD(headerFields, s.trailer)
+ trailingHeader := &headerFrame{
streamID: s.id,
hf: headerFields,
endStream: true,
@@ -783,7 +783,8 @@
atomic.StoreUint32(&t.resetPingStrikes, 1)
},
}
- t.closeStream(s, false, 0, trailer, true)
+ s.hdrMu.Unlock()
+ t.closeStream(s, false, 0, trailingHeader, true)
if t.stats != nil {
t.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
}
@@ -793,7 +794,7 @@
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error
// is returns if it fails (e.g., framing error, transport error).
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
- if !s.headerOk { // Headers haven't been written yet.
+ if !s.isHeaderSent() { // Headers haven't been written yet.
if err := t.WriteHeader(s, nil); err != nil {
// TODO(mmukhi, dfawley): Make sure this is the right code to return.
return streamErrorf(codes.Internal, "transport: %v", err)
diff --git a/transport/transport.go b/transport/transport.go
index 2f643a3..f51f878 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -185,13 +185,20 @@
headerChan chan struct{} // closed to indicate the end of header metadata.
headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
- header metadata.MD // the received header metadata.
- trailer metadata.MD // the key-value map of trailer metadata.
- headerOk bool // becomes true from the first header is about to send
- state streamState
+ // hdrMu protects header and trailer metadata on the server-side.
+ hdrMu sync.Mutex
+ header metadata.MD // the received header metadata.
+ trailer metadata.MD // the key-value map of trailer metadata.
- status *status.Status // the status error received from the server
+ // On the server-side, headerSent is atomically set to 1 when the headers are sent out.
+ headerSent uint32
+
+ state streamState
+
+ // On client-side it is the status error received from the server.
+ // On server-side it is unused.
+ status *status.Status
bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
@@ -201,6 +208,17 @@
contentSubtype string
}
+// isHeaderSent is only valid on the server-side.
+func (s *Stream) isHeaderSent() bool {
+ return atomic.LoadUint32(&s.headerSent) == 1
+}
+
+// updateHeaderSent updates headerSent and returns true
+// if it was alreay set. It is valid only on server-side.
+func (s *Stream) updateHeaderSent() bool {
+ return atomic.SwapUint32(&s.headerSent, 1) == 1
+}
+
func (s *Stream) swapState(st streamState) streamState {
return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st)))
}
@@ -313,10 +331,12 @@
if md.Len() == 0 {
return nil
}
- if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) {
+ if s.isHeaderSent() || s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
+ s.hdrMu.Lock()
s.header = metadata.Join(s.header, md)
+ s.hdrMu.Unlock()
return nil
}
@@ -335,7 +355,12 @@
if md.Len() == 0 {
return nil
}
+ if s.getState() == streamDone {
+ return ErrIllegalHeaderWrite
+ }
+ s.hdrMu.Lock()
s.trailer = metadata.Join(s.trailer, md)
+ s.hdrMu.Unlock()
return nil
}