Merge branch 'master' into service_config_pr
diff --git a/benchmark/worker/benchmark_client.go b/benchmark/worker/benchmark_client.go
index 199bbe1..dfe8c8f 100644
--- a/benchmark/worker/benchmark_client.go
+++ b/benchmark/worker/benchmark_client.go
@@ -37,6 +37,7 @@
"math"
"runtime"
"sync"
+ "syscall"
"time"
"golang.org/x/net/context"
@@ -85,6 +86,7 @@
lastResetTime time.Time
histogramOptions stats.HistogramOptions
lockingHistograms []lockingHistogram
+ rusageLastReset *syscall.Rusage
}
func printClientConfig(config *testpb.ClientConfig) {
@@ -226,6 +228,9 @@
return nil, err
}
+ rusage := new(syscall.Rusage)
+ syscall.Getrusage(syscall.RUSAGE_SELF, rusage)
+
rpcCountPerConn := int(config.OutstandingRpcsPerChannel)
bc := &benchmarkClient{
histogramOptions: stats.HistogramOptions{
@@ -236,9 +241,10 @@
},
lockingHistograms: make([]lockingHistogram, rpcCountPerConn*len(conns), rpcCountPerConn*len(conns)),
- stop: make(chan bool),
- lastResetTime: time.Now(),
- closeConns: closeConns,
+ stop: make(chan bool),
+ lastResetTime: time.Now(),
+ closeConns: closeConns,
+ rusageLastReset: rusage,
}
if err = performRPCs(config, conns, bc); err != nil {
@@ -338,8 +344,9 @@
// getStats returns the stats for benchmark client.
// It resets lastResetTime and all histograms if argument reset is true.
func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats {
- var timeElapsed float64
+ var wallTimeElapsed, uTimeElapsed, sTimeElapsed float64
mergedHistogram := stats.NewHistogram(bc.histogramOptions)
+ latestRusage := new(syscall.Rusage)
if reset {
// Merging histogram may take some time.
@@ -353,14 +360,21 @@
mergedHistogram.Merge(toMerge[i])
}
- timeElapsed = time.Since(bc.lastResetTime).Seconds()
+ wallTimeElapsed = time.Since(bc.lastResetTime).Seconds()
+ syscall.Getrusage(syscall.RUSAGE_SELF, latestRusage)
+ uTimeElapsed, sTimeElapsed = cpuTimeDiff(bc.rusageLastReset, latestRusage)
+
+ bc.rusageLastReset = latestRusage
bc.lastResetTime = time.Now()
} else {
// Merge only, not reset.
for i := range bc.lockingHistograms {
bc.lockingHistograms[i].mergeInto(mergedHistogram)
}
- timeElapsed = time.Since(bc.lastResetTime).Seconds()
+
+ wallTimeElapsed = time.Since(bc.lastResetTime).Seconds()
+ syscall.Getrusage(syscall.RUSAGE_SELF, latestRusage)
+ uTimeElapsed, sTimeElapsed = cpuTimeDiff(bc.rusageLastReset, latestRusage)
}
b := make([]uint32, len(mergedHistogram.Buckets), len(mergedHistogram.Buckets))
@@ -376,9 +390,9 @@
SumOfSquares: float64(mergedHistogram.SumOfSquares),
Count: float64(mergedHistogram.Count),
},
- TimeElapsed: timeElapsed,
- TimeUser: 0,
- TimeSystem: 0,
+ TimeElapsed: wallTimeElapsed,
+ TimeUser: uTimeElapsed,
+ TimeSystem: sTimeElapsed,
}
}
diff --git a/benchmark/worker/benchmark_server.go b/benchmark/worker/benchmark_server.go
index 667ef2c..0d20581 100644
--- a/benchmark/worker/benchmark_server.go
+++ b/benchmark/worker/benchmark_server.go
@@ -38,6 +38,7 @@
"strconv"
"strings"
"sync"
+ "syscall"
"time"
"google.golang.org/grpc"
@@ -55,11 +56,12 @@
)
type benchmarkServer struct {
- port int
- cores int
- closeFunc func()
- mu sync.RWMutex
- lastResetTime time.Time
+ port int
+ cores int
+ closeFunc func()
+ mu sync.RWMutex
+ lastResetTime time.Time
+ rusageLastReset *syscall.Rusage
}
func printServerConfig(config *testpb.ServerConfig) {
@@ -156,18 +158,35 @@
grpclog.Fatalf("failed to get port number from server address: %v", err)
}
- return &benchmarkServer{port: p, cores: numOfCores, closeFunc: closeFunc, lastResetTime: time.Now()}, nil
+ rusage := new(syscall.Rusage)
+ syscall.Getrusage(syscall.RUSAGE_SELF, rusage)
+
+ return &benchmarkServer{
+ port: p,
+ cores: numOfCores,
+ closeFunc: closeFunc,
+ lastResetTime: time.Now(),
+ rusageLastReset: rusage,
+ }, nil
}
// getStats returns the stats for benchmark server.
// It resets lastResetTime if argument reset is true.
func (bs *benchmarkServer) getStats(reset bool) *testpb.ServerStats {
- // TODO wall time, sys time, user time.
bs.mu.RLock()
defer bs.mu.RUnlock()
- timeElapsed := time.Since(bs.lastResetTime).Seconds()
+ wallTimeElapsed := time.Since(bs.lastResetTime).Seconds()
+ rusageLatest := new(syscall.Rusage)
+ syscall.Getrusage(syscall.RUSAGE_SELF, rusageLatest)
+ uTimeElapsed, sTimeElapsed := cpuTimeDiff(bs.rusageLastReset, rusageLatest)
+
if reset {
bs.lastResetTime = time.Now()
+ bs.rusageLastReset = rusageLatest
}
- return &testpb.ServerStats{TimeElapsed: timeElapsed, TimeUser: 0, TimeSystem: 0}
+ return &testpb.ServerStats{
+ TimeElapsed: wallTimeElapsed,
+ TimeUser: uTimeElapsed,
+ TimeSystem: sTimeElapsed,
+ }
}
diff --git a/benchmark/worker/main.go b/benchmark/worker/main.go
index 17c5251..8a80406 100644
--- a/benchmark/worker/main.go
+++ b/benchmark/worker/main.go
@@ -38,6 +38,8 @@
"fmt"
"io"
"net"
+ "net/http"
+ _ "net/http/pprof"
"runtime"
"strconv"
"time"
@@ -50,8 +52,10 @@
)
var (
- driverPort = flag.Int("driver_port", 10000, "port for communication with driver")
- serverPort = flag.Int("server_port", 0, "port for benchmark server if not specified by server config message")
+ driverPort = flag.Int("driver_port", 10000, "port for communication with driver")
+ serverPort = flag.Int("server_port", 0, "port for benchmark server if not specified by server config message")
+ pprofPort = flag.Int("pprof_port", -1, "Port for pprof debug server to listen on. Pprof server doesn't start if unset")
+ blockProfRate = flag.Int("block_prof_rate", 0, "fraction of goroutine blocking events to report in blocking profile")
)
type byteBufCodec struct {
@@ -227,5 +231,14 @@
s.Stop()
}()
+ runtime.SetBlockProfileRate(*blockProfRate)
+
+ if *pprofPort >= 0 {
+ go func() {
+ grpclog.Println("Starting pprof server on port " + strconv.Itoa(*pprofPort))
+ grpclog.Println(http.ListenAndServe("localhost:"+strconv.Itoa(*pprofPort), nil))
+ }()
+ }
+
s.Serve(lis)
}
diff --git a/benchmark/worker/util.go b/benchmark/worker/util.go
index f0016ce..6f9b2b0 100644
--- a/benchmark/worker/util.go
+++ b/benchmark/worker/util.go
@@ -36,6 +36,7 @@
"log"
"os"
"path/filepath"
+ "syscall"
)
// abs returns the absolute path the given relative file or directory path,
@@ -52,6 +53,20 @@
return filepath.Join(v, rel)
}
+func cpuTimeDiff(first *syscall.Rusage, latest *syscall.Rusage) (float64, float64) {
+ var (
+ utimeDiffs = latest.Utime.Sec - first.Utime.Sec
+ utimeDiffus = latest.Utime.Usec - first.Utime.Usec
+ stimeDiffs = latest.Stime.Sec - first.Stime.Sec
+ stimeDiffus = latest.Stime.Usec - first.Stime.Usec
+ )
+
+ uTimeElapsed := float64(utimeDiffs) + float64(utimeDiffus)*1.0e-6
+ sTimeElapsed := float64(stimeDiffs) + float64(stimeDiffus)*1.0e-6
+
+ return uTimeElapsed, sTimeElapsed
+}
+
func goPackagePath(pkg string) (path string, err error) {
gp := os.Getenv("GOPATH")
if gp == "" {
diff --git a/call.go b/call.go
index e92a4bc..e370aee 100644
--- a/call.go
+++ b/call.go
@@ -43,6 +43,7 @@
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
+ "google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
@@ -79,7 +80,7 @@
return
}
}
- if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK {
+ if inPayload != nil && err == io.EOF && stream.Status().Code() == codes.OK {
// TODO in the current implementation, inTrailer may be handled before inPayload in some cases.
// Fix the order if necessary.
dopts.copts.StatsHandler.HandleRPC(ctx, inPayload)
@@ -267,7 +268,7 @@
t, put, err = cc.getTransport(ctx, gopts)
if err != nil {
// TODO(zhaoq): Probably revisit the error handling.
- if _, ok := err.(*rpcError); ok {
+ if _, ok := err.(status.Status); ok {
return err
}
if err == errConnClosing || err == errConnUnavailable {
@@ -321,6 +322,6 @@
put()
put = nil
}
- return Errorf(stream.StatusCode(), "%s", stream.StatusDesc())
+ return stream.Status().Err()
}
}
diff --git a/call_test.go b/call_test.go
index 3c2165e..63e87c2 100644
--- a/call_test.go
+++ b/call_test.go
@@ -46,6 +46,7 @@
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
@@ -99,21 +100,21 @@
return
}
if v == "weird error" {
- h.t.WriteStatus(s, codes.Internal, weirdError)
+ h.t.WriteStatus(s, status.New(codes.Internal, weirdError))
return
}
if v == "canceled" {
canceled++
- h.t.WriteStatus(s, codes.Internal, "")
+ h.t.WriteStatus(s, status.New(codes.Internal, ""))
return
}
if v == "port" {
- h.t.WriteStatus(s, codes.Internal, h.port)
+ h.t.WriteStatus(s, status.New(codes.Internal, h.port))
return
}
if v != expectedRequest {
- h.t.WriteStatus(s, codes.Internal, strings.Repeat("A", sizeLargeErr))
+ h.t.WriteStatus(s, status.New(codes.Internal, strings.Repeat("A", sizeLargeErr)))
return
}
}
@@ -124,7 +125,7 @@
return
}
h.t.Write(s, reply, &transport.Options{})
- h.t.WriteStatus(s, codes.OK, "")
+ h.t.WriteStatus(s, status.New(codes.OK, ""))
}
type server struct {
@@ -239,7 +240,7 @@
var reply string
req := "hello"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
- if _, ok := err.(*rpcError); !ok {
+ if _, ok := err.(status.Status); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr {
@@ -255,7 +256,7 @@
var reply string
req := "weird error"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
- if _, ok := err.(*rpcError); !ok {
+ if _, ok := err.(status.Status); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
if got, want := ErrorDesc(err), weirdError; got != want {
diff --git a/rpc_util.go b/rpc_util.go
index 832fd46..1fad9d4 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -37,7 +37,6 @@
"bytes"
"compress/gzip"
"encoding/binary"
- "fmt"
"io"
"io/ioutil"
"math"
@@ -50,6 +49,7 @@
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
+ "google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
@@ -189,7 +189,9 @@
// unary RPC.
func Peer(peer *peer.Peer) CallOption {
return afterCall(func(c *callInfo) {
- *peer = *c.peer
+ if c.peer != nil {
+ *peer = *c.peer
+ }
})
}
@@ -370,88 +372,56 @@
return nil
}
-// rpcError defines the status from an RPC.
-type rpcError struct {
- code codes.Code
- desc string
-}
-
-func (e *rpcError) Error() string {
- return fmt.Sprintf("rpc error: code = %s desc = %s", e.code, e.desc)
-}
-
// Code returns the error code for err if it was produced by the rpc system.
// Otherwise, it returns codes.Unknown.
+//
+// Deprecated; use status.FromError and Code method instead.
func Code(err error) codes.Code {
- if err == nil {
- return codes.OK
- }
- if e, ok := err.(*rpcError); ok {
- return e.code
+ if s, ok := status.FromError(err); ok {
+ return s.Code()
}
return codes.Unknown
}
// ErrorDesc returns the error description of err if it was produced by the rpc system.
// Otherwise, it returns err.Error() or empty string when err is nil.
+//
+// Deprecated; use status.FromError and Message method instead.
func ErrorDesc(err error) string {
- if err == nil {
- return ""
- }
- if e, ok := err.(*rpcError); ok {
- return e.desc
+ if s, ok := status.FromError(err); ok {
+ return s.Message()
}
return err.Error()
}
// Errorf returns an error containing an error code and a description;
// Errorf returns nil if c is OK.
+//
+// Deprecated; use status.Errorf instead.
func Errorf(c codes.Code, format string, a ...interface{}) error {
- if c == codes.OK {
- return nil
- }
- return &rpcError{
- code: c,
- desc: fmt.Sprintf(format, a...),
- }
+ return status.Errorf(c, format, a...)
}
-// toRPCErr converts an error into a rpcError.
+// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
switch e := err.(type) {
- case *rpcError:
+ case status.Status:
return err
case transport.StreamError:
- return &rpcError{
- code: e.Code,
- desc: e.Desc,
- }
+ return status.Error(e.Code, e.Desc)
case transport.ConnectionError:
- return &rpcError{
- code: codes.Internal,
- desc: e.Desc,
- }
+ return status.Error(codes.Internal, e.Desc)
default:
switch err {
case context.DeadlineExceeded:
- return &rpcError{
- code: codes.DeadlineExceeded,
- desc: err.Error(),
- }
+ return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled:
- return &rpcError{
- code: codes.Canceled,
- desc: err.Error(),
- }
+ return status.Error(codes.Canceled, err.Error())
case ErrClientConnClosing:
- return &rpcError{
- code: codes.FailedPrecondition,
- desc: err.Error(),
- }
+ return status.Error(codes.FailedPrecondition, err.Error())
}
-
}
- return Errorf(codes.Unknown, "%v", err)
+ return status.Error(codes.Unknown, err.Error())
}
// convertCode converts a standard Go error into its canonical code. Note that
diff --git a/rpc_util_test.go b/rpc_util_test.go
index 375e42b..f2b43f0 100644
--- a/rpc_util_test.go
+++ b/rpc_util_test.go
@@ -41,8 +41,8 @@
"testing"
"github.com/golang/protobuf/proto"
- "golang.org/x/net/context"
"google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
perfpb "google.golang.org/grpc/test/codec_perf"
"google.golang.org/grpc/transport"
)
@@ -150,51 +150,21 @@
// input
errIn error
// outputs
- errOut *rpcError
+ errOut error
}{
- {transport.StreamError{codes.Unknown, ""}, Errorf(codes.Unknown, "").(*rpcError)},
- {transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc).(*rpcError)},
+ {transport.StreamError{Code: codes.Unknown, Desc: ""}, status.Error(codes.Unknown, "")},
+ {transport.ErrConnClosing, status.Error(codes.Internal, transport.ErrConnClosing.Desc)},
} {
err := toRPCErr(test.errIn)
- rpcErr, ok := err.(*rpcError)
- if !ok {
- t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, rpcError{})
+ if _, ok := err.(status.Status); !ok {
+ t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, status.Error(codes.Unknown, ""))
}
- if *rpcErr != *test.errOut {
+ if !reflect.DeepEqual(err, test.errOut) {
t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut)
}
}
}
-func TestContextErr(t *testing.T) {
- for _, test := range []struct {
- // input
- errIn error
- // outputs
- errOut transport.StreamError
- }{
- {context.DeadlineExceeded, transport.StreamError{codes.DeadlineExceeded, context.DeadlineExceeded.Error()}},
- {context.Canceled, transport.StreamError{codes.Canceled, context.Canceled.Error()}},
- } {
- err := transport.ContextErr(test.errIn)
- if err != test.errOut {
- t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut)
- }
- }
-}
-
-func TestErrorsWithSameParameters(t *testing.T) {
- const description = "some description"
- e1 := Errorf(codes.AlreadyExists, description).(*rpcError)
- e2 := Errorf(codes.AlreadyExists, description).(*rpcError)
- if e1 == e2 {
- t.Fatalf("Error interfaces should not be considered equal - e1: %p - %v e2: %p - %v", e1, e1, e2, e2)
- }
- if Code(e1) != Code(e2) || ErrorDesc(e1) != ErrorDesc(e2) {
- t.Fatalf("Expected errors to have same code and description - e1: %p - %v e2: %p - %v", e1, e1, e2, e2)
- }
-}
-
// bmEncode benchmarks encoding a Protocol Buffer message containing mSize
// bytes.
func bmEncode(b *testing.B, mSize int) {
diff --git a/server.go b/server.go
index 5049763..e2f4683 100644
--- a/server.go
+++ b/server.go
@@ -56,6 +56,7 @@
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
+ "google.golang.org/grpc/status"
"google.golang.org/grpc/tap"
"google.golang.org/grpc/transport"
)
@@ -694,7 +695,7 @@
stream.SetSendCompress(s.opts.cp.Type())
}
p := &parser{r: stream}
- for {
+ for { // TODO: delete
pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize)
if err == io.EOF {
// The entire stream is done (for unary RPC only).
@@ -704,36 +705,35 @@
err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error())
}
if err != nil {
- switch err := err.(type) {
- case *rpcError:
- if e := t.WriteStatus(stream, err.code, err.desc); e != nil {
+ switch st := err.(type) {
+ case status.Status:
+ if e := t.WriteStatus(stream, st); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
- if e := t.WriteStatus(stream, err.Code, err.Desc); e != nil {
+ if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
default:
- panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err))
+ panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st))
}
return err
}
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
- switch err := err.(type) {
- case *rpcError:
- if e := t.WriteStatus(stream, err.code, err.desc); e != nil {
+ if st, ok := err.(status.Status); ok {
+ if e := t.WriteStatus(stream, st); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
return err
- default:
- if e := t.WriteStatus(stream, codes.Internal, err.Error()); e != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
- }
- // TODO checkRecvPayload always return RPC error. Add a return here if necessary.
}
+ if e := t.WriteStatus(stream, status.New(codes.Internal, err.Error())); e != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
+ }
+
+ // TODO checkRecvPayload always return RPC error. Add a return here if necessary.
}
var inPayload *stats.InPayload
if sh != nil {
@@ -741,8 +741,6 @@
RecvTime: time.Now(),
}
}
- statusCode := codes.OK
- statusDesc := ""
df := func(v interface{}) error {
if inPayload != nil {
inPayload.WireLength = len(req)
@@ -751,20 +749,16 @@
var err error
req, err = s.opts.dc.Do(bytes.NewReader(req))
if err != nil {
- if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
- }
return Errorf(codes.Internal, err.Error())
}
}
if len(req) > s.opts.maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
- statusCode = codes.InvalidArgument
- statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxReceiveMessageSize)
+ return status.Errorf(codes.InvalidArgument, "grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxReceiveMessageSize)
}
if err := s.opts.codec.Unmarshal(req, v); err != nil {
- return err
+ return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
}
if inPayload != nil {
inPayload.Payload = v
@@ -779,21 +773,20 @@
}
reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
if appErr != nil {
- if err, ok := appErr.(*rpcError); ok {
- statusCode = err.code
- statusDesc = err.desc
- } else {
- statusCode = convertCode(appErr)
- statusDesc = appErr.Error()
+ appStatus, ok := status.FromError(appErr)
+ if !ok {
+ // Convert appErr if it is not a grpc status error.
+ appErr = status.Error(convertCode(appErr), appErr.Error())
+ appStatus, _ = status.FromError(appErr)
}
- if trInfo != nil && statusCode != codes.OK {
- trInfo.tr.LazyLog(stringer(statusDesc), true)
+ if trInfo != nil {
+ trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
trInfo.tr.SetError()
}
- if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
+ if e := t.WriteStatus(stream, appStatus); e != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e)
}
- return Errorf(statusCode, statusDesc)
+ return appErr
}
if trInfo != nil {
trInfo.tr.LazyLog(stringer("OK"), false)
@@ -803,25 +796,17 @@
Delay: false,
}
if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
- switch err := err.(type) {
- case transport.ConnectionError:
- // Nothing to do here.
- case transport.StreamError:
- statusCode = err.Code
- statusDesc = err.Desc
- default:
- statusCode = codes.Unknown
- statusDesc = err.Error()
- }
+ // TODO: Translate error into a status.Status error if necessary?
+ // TODO: Write status when appropriate.
+ return err
}
if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
}
- errWrite := t.WriteStatus(stream, statusCode, statusDesc)
- if statusCode != codes.OK {
- return Errorf(statusCode, statusDesc)
- }
- return errWrite
+ // TODO: Should we be logging if writing status failed here, like above?
+ // Should the logging be in WriteStatus? Should we ignore the WriteStatus
+ // error or allow the stats handler to see it?
+ return t.WriteStatus(stream, status.New(codes.OK, ""))
}
}
@@ -891,32 +876,31 @@
appErr = s.opts.streamInt(server, ss, info, sd.Handler)
}
if appErr != nil {
- if err, ok := appErr.(*rpcError); ok {
- ss.statusCode = err.code
- ss.statusDesc = err.desc
- } else if err, ok := appErr.(transport.StreamError); ok {
- ss.statusCode = err.Code
- ss.statusDesc = err.Desc
- } else {
- ss.statusCode = convertCode(appErr)
- ss.statusDesc = appErr.Error()
+ switch err := appErr.(type) {
+ case status.Status:
+ // Do nothing
+ case transport.StreamError:
+ appErr = status.Error(err.Code, err.Desc)
+ default:
+ appErr = status.Error(convertCode(appErr), appErr.Error())
}
+ appStatus, _ := status.FromError(appErr)
+ if trInfo != nil {
+ ss.mu.Lock()
+ ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
+ ss.trInfo.tr.SetError()
+ ss.mu.Unlock()
+ }
+ t.WriteStatus(ss.s, appStatus)
+ // TODO: Should we log an error from WriteStatus here and below?
+ return appErr
}
if trInfo != nil {
ss.mu.Lock()
- if ss.statusCode != codes.OK {
- ss.trInfo.tr.LazyLog(stringer(ss.statusDesc), true)
- ss.trInfo.tr.SetError()
- } else {
- ss.trInfo.tr.LazyLog(stringer("OK"), false)
- }
+ ss.trInfo.tr.LazyLog(stringer("OK"), false)
ss.mu.Unlock()
}
- errWrite := t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc)
- if ss.statusCode != codes.OK {
- return Errorf(ss.statusCode, ss.statusDesc)
- }
- return errWrite
+ return t.WriteStatus(ss.s, status.New(codes.OK, ""))
}
@@ -932,7 +916,7 @@
trInfo.tr.SetError()
}
errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
- if err := t.WriteStatus(stream, codes.InvalidArgument, errDesc); err != nil {
+ if err := t.WriteStatus(stream, status.New(codes.InvalidArgument, errDesc)); err != nil {
if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError()
@@ -957,7 +941,7 @@
trInfo.tr.SetError()
}
errDesc := fmt.Sprintf("unknown service %v", service)
- if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil {
+ if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError()
@@ -987,7 +971,7 @@
return
}
errDesc := fmt.Sprintf("unknown method %v", method)
- if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil {
+ if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError()
diff --git a/stats/stats.go b/stats/stats.go
index a82448a..43d6f00 100644
--- a/stats/stats.go
+++ b/stats/stats.go
@@ -184,7 +184,7 @@
Client bool
// EndTime is the time when the RPC ends.
EndTime time.Time
- // Error is the error just happened. Its type is gRPC error.
+ // Error is the error just happened. It implements status.Status if non-nil.
Error error
}
diff --git a/status/status.go b/status/status.go
new file mode 100644
index 0000000..0e40208
--- /dev/null
+++ b/status/status.go
@@ -0,0 +1,160 @@
+/*
+ *
+ * Copyright 2017, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+// Package status implements errors returned by gRPC. These errors are
+// serialized and transmitted on the wire between server and client, and allow
+// for additional data to be transmitted via the Details field in the status
+// proto. gRPC service handlers should return an error created by this
+// package, and gRPC clients should expect a corresponding error to be
+// returned from the RPC call.
+//
+// This package upholds the invariants that a non-nil error may not
+// contain an OK code, and an OK code must result in a nil error.
+package status
+
+import (
+ "fmt"
+
+ "github.com/golang/protobuf/proto"
+ spb "github.com/google/go-genproto/googleapis/rpc/status"
+ "google.golang.org/grpc/codes"
+)
+
+// Status provides access to grpc status details and is implemented by all
+// errors returned from this package except nil errors, which are not typed.
+// Note: gRPC users should not implement their own Statuses. Custom data may
+// be attached to the spb.Status proto's Details field.
+type Status interface {
+ // Code returns the status code.
+ Code() codes.Code
+ // Message returns the status message.
+ Message() string
+ // Proto returns a copy of the status in proto form.
+ Proto() *spb.Status
+ // Err returns an error representing the status.
+ Err() error
+}
+
+// okStatus is a Status whose Code method returns codes.OK, but does not
+// implement error. To represent an OK code as an error, use an untyped nil.
+type okStatus struct{}
+
+func (okStatus) Code() codes.Code {
+ return codes.OK
+}
+
+func (okStatus) Message() string {
+ return ""
+}
+
+func (okStatus) Proto() *spb.Status {
+ return nil
+}
+
+func (okStatus) Err() error {
+ return nil
+}
+
+// statusError contains a status proto. It is embedded and not aliased to
+// allow for accessor functions of the same name. It implements error and
+// Status, and a nil statusError should never be returned by this package.
+type statusError struct {
+ *spb.Status
+}
+
+func (se *statusError) Error() string {
+ return fmt.Sprintf("rpc error: code = %s desc = %s", se.Code(), se.Message())
+}
+
+func (se *statusError) Code() codes.Code {
+ return codes.Code(se.Status.Code)
+}
+
+func (se *statusError) Message() string {
+ return se.Status.Message
+}
+
+func (se *statusError) Proto() *spb.Status {
+ return proto.Clone(se.Status).(*spb.Status)
+}
+
+func (se *statusError) Err() error {
+ return se
+}
+
+// New returns a Status representing c and msg.
+func New(c codes.Code, msg string) Status {
+ if c == codes.OK {
+ return okStatus{}
+ }
+ return &statusError{Status: &spb.Status{Code: int32(c), Message: msg}}
+}
+
+// Newf returns New(c, fmt.Sprintf(format, a...)).
+func Newf(c codes.Code, format string, a ...interface{}) Status {
+ return New(c, fmt.Sprintf(format, a...))
+}
+
+// Error returns an error representing c and msg. If c is OK, returns nil.
+func Error(c codes.Code, msg string) error {
+ return New(c, msg).Err()
+}
+
+// Errorf returns Error(c, fmt.Sprintf(format, a...)).
+func Errorf(c codes.Code, format string, a ...interface{}) error {
+ return Error(c, fmt.Sprintf(format, a...))
+}
+
+// ErrorProto returns an error representing s. If s.Code is OK, returns nil.
+func ErrorProto(s *spb.Status) error {
+ return FromProto(s).Err()
+}
+
+// FromProto returns a Status representing s. If s.Code is OK, Message and
+// Details may be lost.
+func FromProto(s *spb.Status) Status {
+ if s.GetCode() == int32(codes.OK) {
+ return okStatus{}
+ }
+ return &statusError{Status: proto.Clone(s).(*spb.Status)}
+}
+
+// FromError returns a Status representing err if it was produced from this
+// package, otherwise it returns nil, false.
+func FromError(err error) (s Status, ok bool) {
+ if err == nil {
+ return okStatus{}, true
+ }
+ s, ok = err.(Status)
+ return s, ok
+}
diff --git a/status/status_test.go b/status/status_test.go
new file mode 100644
index 0000000..34de196
--- /dev/null
+++ b/status/status_test.go
@@ -0,0 +1,110 @@
+/*
+ *
+ * Copyright 2017, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package status
+
+import (
+ "reflect"
+ "testing"
+
+ apb "github.com/golang/protobuf/ptypes/any"
+ spb "github.com/google/go-genproto/googleapis/rpc/status"
+ "google.golang.org/grpc/codes"
+)
+
+func TestErrorsWithSameParameters(t *testing.T) {
+ const description = "some description"
+ e1 := Errorf(codes.AlreadyExists, description)
+ e2 := Errorf(codes.AlreadyExists, description)
+ if e1 == e2 || !reflect.DeepEqual(e1, e2) {
+ t.Fatalf("Errors should be equivalent but unique - e1: %v, %v e2: %p, %v", e1.(*statusError), e1, e2.(*statusError), e2)
+ }
+}
+
+func TestFromToProto(t *testing.T) {
+ s := &spb.Status{
+ Code: int32(codes.Internal),
+ Message: "test test test",
+ Details: []*apb.Any{{TypeUrl: "foo", Value: []byte{3, 2, 1}}},
+ }
+
+ err := FromProto(s)
+ if got := err.Proto(); !reflect.DeepEqual(s, got) {
+ t.Fatalf("Expected errors to be identical - s: %v got: %v", s, got)
+ }
+}
+
+func TestError(t *testing.T) {
+ err := Error(codes.Internal, "test description")
+ if got, want := err.Error(), "rpc error: code = Internal desc = test description"; got != want {
+ t.Fatalf("err.Error() = %q; want %q", got, want)
+ }
+ s := err.(Status)
+ if got, want := s.Code(), codes.Internal; got != want {
+ t.Fatalf("err.Code() = %s; want %s", got, want)
+ }
+ if got, want := s.Message(), "test description"; got != want {
+ t.Fatalf("err.Message() = %s; want %s", got, want)
+ }
+}
+
+func TestErrorOK(t *testing.T) {
+ err := Error(codes.OK, "foo")
+ if err != nil {
+ t.Fatalf("Error(codes.OK, _) = %p; want nil", err.(*statusError))
+ }
+}
+
+func TestErrorProtoOK(t *testing.T) {
+ s := &spb.Status{Code: int32(codes.OK)}
+ if got := ErrorProto(s); got != nil {
+ t.Fatalf("ErrorProto(%v) = %v; want nil", s, got)
+ }
+}
+
+func TestFromError(t *testing.T) {
+ code, message := codes.Internal, "test description"
+ err := Error(code, message)
+ s, ok := FromError(err)
+ if !ok || s.Code() != code || s.Message() != message || s.Err() == nil {
+ t.Fatalf("FromError(%v) = %v, %v; want <Code()=%s, Message()=%q, Err()!=nil>, true", err, s, ok, code, message)
+ }
+}
+
+func TestFromErrorOK(t *testing.T) {
+ code, message := codes.OK, ""
+ s, ok := FromError(nil)
+ if !ok || s.Code() != code || s.Message() != message || s.Err() != nil {
+ t.Fatalf("FromError(nil) = %v, %v; want <Code()=%s, Message()=%q, Err=nil>, true", s, ok, code, message)
+ }
+}
diff --git a/stream.go b/stream.go
index 034bcf7..2f05486 100644
--- a/stream.go
+++ b/stream.go
@@ -45,6 +45,7 @@
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
+ "google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
@@ -205,7 +206,7 @@
t, put, err = cc.getTransport(ctx, gopts)
if err != nil {
// TODO(zhaoq): Probably revisit the error handling.
- if _, ok := err.(*rpcError); ok {
+ if _, ok := err.(status.Status); ok {
return nil, err
}
if err == errConnClosing || err == errConnUnavailable {
@@ -268,11 +269,7 @@
case <-s.Done():
// TODO: The trace of the RPC is terminated here when there is no pending
// I/O, which is probably not the optimal solution.
- if s.StatusCode() == codes.OK {
- cs.finish(nil)
- } else {
- cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
- }
+ cs.finish(s.Status().Err())
cs.closeTransportStream(nil)
case <-s.GoAway():
cs.finish(errConnDrain)
@@ -445,11 +442,11 @@
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
if err == io.EOF {
- if cs.s.StatusCode() == codes.OK {
- cs.finish(err)
- return nil
+ if se := cs.s.Status().Err(); se != nil {
+ return se
}
- return Errorf(cs.s.StatusCode(), "%s", cs.s.StatusDesc())
+ cs.finish(err)
+ return nil
}
return toRPCErr(err)
}
@@ -457,11 +454,11 @@
cs.closeTransportStream(err)
}
if err == io.EOF {
- if cs.s.StatusCode() == codes.OK {
- // Returns io.EOF to indicate the end of the stream.
- return
+ if statusErr := cs.s.Status().Err(); statusErr != nil {
+ return statusErr
}
- return Errorf(cs.s.StatusCode(), "%s", cs.s.StatusDesc())
+ // Returns io.EOF to indicate the end of the stream.
+ return
}
return toRPCErr(err)
}
@@ -545,18 +542,16 @@
// serverStream implements a server side Stream.
type serverStream struct {
- t transport.ServerTransport
- s *transport.Stream
- p *parser
- codec Codec
- cp Compressor
- dc Decompressor
- cbuf *bytes.Buffer
+ t transport.ServerTransport
+ s *transport.Stream
+ p *parser
+ codec Codec
+ cp Compressor
+ dc Decompressor
+ cbuf *bytes.Buffer
maxReceiveMessageSize int
maxSendMessageSize int
- statusCode codes.Code
- statusDesc string
- trInfo *traceInfo
+ trInfo *traceInfo
statsHandler stats.Handler
diff --git a/test/end2end_test.go b/test/end2end_test.go
index b857bb6..32f20c7 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -54,6 +54,8 @@
"time"
"github.com/golang/protobuf/proto"
+ anypb "github.com/golang/protobuf/ptypes/any"
+ spb "github.com/google/go-genproto/googleapis/rpc/status"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"google.golang.org/grpc"
@@ -65,6 +67,7 @@
"google.golang.org/grpc/internal"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
+ "google.golang.org/grpc/status"
"google.golang.org/grpc/tap"
testpb "google.golang.org/grpc/test/grpc_testing"
)
@@ -92,8 +95,16 @@
malformedHTTP2Metadata = metadata.MD{
"Key": []string{"foo"},
}
- testAppUA = "myApp1/1.0 myApp2/0.9"
- failAppUA = "fail-this-RPC"
+ testAppUA = "myApp1/1.0 myApp2/0.9"
+ failAppUA = "fail-this-RPC"
+ detailedError = status.ErrorProto(&spb.Status{
+ Code: int32(codes.DataLoss),
+ Message: "error for testing: " + failAppUA,
+ Details: []*anypb.Any{{
+ TypeUrl: "url",
+ Value: []byte{6, 0, 0, 6, 1, 3},
+ }},
+ })
)
var raceMode bool // set by race_test.go in race mode
@@ -111,7 +122,7 @@
// For testing purpose, returns an error if user-agent is failAppUA.
// To test that client gets the correct error.
if ua, ok := md["user-agent"]; !ok || strings.HasPrefix(ua[0], failAppUA) {
- return nil, grpc.Errorf(codes.DataLoss, "error for testing: "+failAppUA)
+ return nil, detailedError
}
var str []string
for _, entry := range md["user-agent"] {
@@ -1815,7 +1826,7 @@
cc := te.clientConn()
wantErr := grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded")
- if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); !equalErrors(err, wantErr) {
+ if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); !reflect.DeepEqual(err, wantErr) {
t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %s", err, codes.DeadlineExceeded)
}
awaitNewConnLogOutput()
@@ -1837,7 +1848,7 @@
te.startServer(&testServer{security: e.security})
defer te.tearDown()
want := grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1.Health")
- if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !equalErrors(err, want) {
+ if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !reflect.DeepEqual(err, want) {
t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want)
}
}
@@ -1864,7 +1875,7 @@
te.startServer(&testServer{security: e.security})
defer te.tearDown()
want := grpc.Errorf(codes.Unauthenticated, "user unauthenticated")
- if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !equalErrors(err, want) {
+ if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !reflect.DeepEqual(err, want) {
t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want)
}
}
@@ -1892,7 +1903,7 @@
t.Fatalf("Got the serving status %v, want SERVING", out.Status)
}
wantErr := grpc.Errorf(codes.NotFound, "unknown service")
- if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); !equalErrors(err, wantErr) {
+ if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); !reflect.DeepEqual(err, wantErr) {
t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %s", err, codes.NotFound)
}
hs.SetServingStatus("grpc.health.v1.Health", healthpb.HealthCheckResponse_SERVING)
@@ -1974,8 +1985,8 @@
tc := testpb.NewTestServiceClient(te.clientConn())
ctx := metadata.NewContext(context.Background(), testMetadata)
- wantErr := grpc.Errorf(codes.DataLoss, "error for testing: "+failAppUA)
- if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !equalErrors(err, wantErr) {
+ wantErr := detailedError
+ if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(err, wantErr) {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr)
}
}
@@ -2141,6 +2152,29 @@
}
}
+// TestPeerNegative tests that passing grpc.Peer to a call that fails
+// doesn't cause a segmentation fault.
+// issue#1141 https://github.com/grpc/grpc-go/issues/1141
+func TestPeerNegative(t *testing.T) {
+ defer leakCheck(t)()
+ for _, e := range listTestEnv() {
+ testPeerNegative(t, e)
+ }
+}
+
+func testPeerNegative(t *testing.T, e env) {
+ te := newTest(t, e)
+ te.startServer(&testServer{security: e.security})
+ defer te.tearDown()
+
+ cc := te.clientConn()
+ tc := testpb.NewTestServiceClient(cc)
+ p := new(peer.Peer)
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ tc.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(p))
+}
+
func TestMetadataUnaryRPC(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
@@ -3055,7 +3089,7 @@
t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want <nil>", tc, err)
}
wantErr := grpc.Errorf(codes.DataLoss, "error for testing: "+failAppUA)
- if _, err := stream.Recv(); !equalErrors(err, wantErr) {
+ if _, err := stream.Recv(); !reflect.DeepEqual(err, wantErr) {
t.Fatalf("%v.Recv() = _, %v, want _, %v", stream, err, wantErr)
}
}
@@ -4245,7 +4279,3 @@
}
return fw.dst.Write(p)
}
-
-func equalErrors(l, r error) bool {
- return grpc.Code(l) == grpc.Code(r) && grpc.ErrorDesc(l) == grpc.ErrorDesc(r)
-}
diff --git a/transport/handler_server.go b/transport/handler_server.go
index 10b6dc0..5bf6363 100644
--- a/transport/handler_server.go
+++ b/transport/handler_server.go
@@ -53,6 +53,7 @@
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
+ "google.golang.org/grpc/status"
)
// NewServerHandlerTransport returns a ServerTransport handling gRPC
@@ -182,7 +183,7 @@
}
}
-func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
+func (ht *serverHandlerTransport) WriteStatus(s *Stream, st status.Status) error {
err := ht.do(func() {
ht.writeCommonHeaders(s)
@@ -192,10 +193,13 @@
ht.rw.(http.Flusher).Flush()
h := ht.rw.Header()
- h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
- if statusDesc != "" {
- h.Set("Grpc-Message", encodeGrpcMessage(statusDesc))
+ h.Set("Grpc-Status", fmt.Sprintf("%d", st.Code()))
+ if m := st.Message(); m != "" {
+ h.Set("Grpc-Message", encodeGrpcMessage(m))
}
+
+ // TODO: Support Grpc-Status-Details-Bin
+
if md := s.Trailer(); len(md) > 0 {
for k, vv := range md {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
@@ -234,6 +238,7 @@
// and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
h.Add("Trailer", "Grpc-Status")
h.Add("Trailer", "Grpc-Message")
+ // TODO: Support Grpc-Status-Details-Bin
if s.sendCompress != "" {
h.Set("Grpc-Encoding", s.sendCompress)
diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go
index 44adf2e..8437848 100644
--- a/transport/handler_server_test.go
+++ b/transport/handler_server_test.go
@@ -46,6 +46,7 @@
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/status"
)
func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
@@ -298,7 +299,7 @@
t.Errorf("stream method = %q; want %q", s.method, want)
}
st.bodyw.Close() // no body
- st.ht.WriteStatus(s, codes.OK, "")
+ st.ht.WriteStatus(s, status.New(codes.OK, ""))
}
st.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
@@ -328,7 +329,7 @@
func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
st := newHandleStreamTest(t)
handleStream := func(s *Stream) {
- st.ht.WriteStatus(s, statusCode, msg)
+ st.ht.WriteStatus(s, status.New(statusCode, msg))
}
st.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
@@ -379,7 +380,7 @@
t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded)
return
}
- ht.WriteStatus(s, codes.DeadlineExceeded, "too slow")
+ ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
}
ht.HandleStreams(
func(s *Stream) { go runStream(s) },
diff --git a/transport/http2_client.go b/transport/http2_client.go
index d6e2998..7d72698 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -35,7 +35,6 @@
import (
"bytes"
- "fmt"
"io"
"math"
"net"
@@ -54,6 +53,7 @@
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
+ "google.golang.org/grpc/status"
)
// http2Client implements the ClientTransport interface with HTTP2.
@@ -311,7 +311,7 @@
return s
}
-// NewStream creates a stream and register it into the transport as "active"
+// NewStream creates a stream and registers it into the transport as "active"
// streams.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
pr := &peer.Peer{
@@ -802,12 +802,9 @@
return
}
if err := s.fc.onData(uint32(size)); err != nil {
- s.state = streamDone
- s.statusCode = codes.Internal
- s.statusDesc = err.Error()
s.rstStream = true
s.rstError = http2.ErrCodeFlowControl
- close(s.done)
+ s.finish(status.New(codes.Internal, err.Error()))
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
return
@@ -835,10 +832,7 @@
s.mu.Unlock()
return
}
- s.state = streamDone
- s.statusCode = codes.Internal
- s.statusDesc = "server closed the stream without sending trailers"
- close(s.done)
+ s.finish(status.New(codes.Internal, "server closed the stream without sending trailers"))
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
}
@@ -854,18 +848,16 @@
s.mu.Unlock()
return
}
- s.state = streamDone
if !s.headerDone {
close(s.headerChan)
s.headerDone = true
}
- s.statusCode, ok = http2ErrConvTab[http2.ErrCode(f.ErrCode)]
+ statusCode, ok := http2ErrConvTab[http2.ErrCode(f.ErrCode)]
if !ok {
grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
- s.statusCode = codes.Unknown
+ statusCode = codes.Unknown
}
- s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode)
- close(s.done)
+ s.finish(status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %d", f.ErrCode))
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
}
@@ -944,18 +936,17 @@
}
var state decodeState
for _, hf := range frame.Fields {
- state.processHeaderField(hf)
- }
- if state.err != nil {
- s.mu.Lock()
- if !s.headerDone {
- close(s.headerChan)
- s.headerDone = true
+ if err := state.processHeaderField(hf); err != nil {
+ s.mu.Lock()
+ if !s.headerDone {
+ close(s.headerChan)
+ s.headerDone = true
+ }
+ s.mu.Unlock()
+ s.write(recvMsg{err: err})
+ // Something wrong. Stops reading even when there is remaining.
+ return
}
- s.mu.Unlock()
- s.write(recvMsg{err: state.err})
- // Something wrong. Stops reading even when there is remaining.
- return
}
endStream := frame.StreamEnded()
@@ -998,10 +989,7 @@
if len(state.mdata) > 0 {
s.trailer = state.mdata
}
- s.statusCode = state.statusCode
- s.statusDesc = state.statusDesc
- close(s.done)
- s.state = streamDone
+ s.finish(state.status())
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
}
diff --git a/transport/http2_server.go b/transport/http2_server.go
index f3bc569..9972a83 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -45,6 +45,7 @@
"sync/atomic"
"time"
+ "github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
@@ -55,6 +56,7 @@
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
+ "google.golang.org/grpc/status"
"google.golang.org/grpc/tap"
)
@@ -227,13 +229,12 @@
var state decodeState
for _, hf := range frame.Fields {
- state.processHeaderField(hf)
- }
- if err := state.err; err != nil {
- if se, ok := err.(StreamError); ok {
- t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]})
+ if err := state.processHeaderField(hf); err != nil {
+ if se, ok := err.(StreamError); ok {
+ t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]})
+ }
+ return
}
- return
}
if frame.StreamEnded() {
@@ -670,7 +671,7 @@
// There is no further I/O operations being able to perform on this stream.
// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early
// OK is adopted.
-func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
+func (t *http2Server) WriteStatus(s *Stream, st status.Status) error {
var headersSent, hasHeader bool
s.mu.Lock()
if s.state == streamDone {
@@ -701,9 +702,24 @@
t.hEnc.WriteField(
hpack.HeaderField{
Name: "grpc-status",
- Value: strconv.Itoa(int(statusCode)),
+ Value: strconv.Itoa(int(st.Code())),
})
- t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
+ t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
+
+ if p := st.Proto(); p != nil && len(p.Details) > 0 {
+ stBytes, err := proto.Marshal(p)
+ if err != nil {
+ // TODO: return error instead, when callers are able to handle it.
+ panic(err)
+ }
+
+ for k, v := range metadata.New(map[string]string{"grpc-status-details-bin": (string)(stBytes)}) {
+ for _, entry := range v {
+ t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
+ }
+ }
+ }
+
// Attach the trailer metadata.
for k, v := range s.trailer {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
diff --git a/transport/http_util.go b/transport/http_util.go
index 6b96884..57aad62 100644
--- a/transport/http_util.go
+++ b/transport/http_util.go
@@ -44,11 +44,14 @@
"sync/atomic"
"time"
+ "github.com/golang/protobuf/proto"
+ spb "github.com/google/go-genproto/googleapis/rpc/status"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/status"
)
const (
@@ -90,13 +93,15 @@
// Records the states during HPACK decoding. Must be reset once the
// decoding of the entire headers are finished.
type decodeState struct {
- err error // first error encountered decoding
-
encoding string
- // statusCode caches the stream status received from the trailer
- // the server sent. Client side only.
- statusCode codes.Code
- statusDesc string
+ // statusGen caches the stream status received from the trailer the server
+ // sent. Client side only. Do not access directly. After all trailers are
+ // parsed, use the status method to retrieve the status.
+ statusGen status.Status
+ // rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not
+ // intended for direct access outside of parsing.
+ rawStatusCode int32
+ rawStatusMsg string
// Server side only fields.
timeoutSet bool
timeout time.Duration
@@ -119,6 +124,7 @@
"grpc-message",
"grpc-status",
"grpc-timeout",
+ "grpc-status-details-bin",
"te":
return true
default:
@@ -137,12 +143,6 @@
}
}
-func (d *decodeState) setErr(err error) {
- if d.err == nil {
- d.err = err
- }
-}
-
func validContentType(t string) bool {
e := "application/grpc"
if !strings.HasPrefix(t, e) {
@@ -156,31 +156,45 @@
return true
}
-func (d *decodeState) processHeaderField(f hpack.HeaderField) {
+func (d *decodeState) status() status.Status {
+ if d.statusGen == nil {
+ // No status-details were provided; generate status using code/msg.
+ d.statusGen = status.New(codes.Code(d.rawStatusCode), d.rawStatusMsg)
+ }
+ return d.statusGen
+}
+
+func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
switch f.Name {
case "content-type":
if !validContentType(f.Value) {
- d.setErr(streamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value))
- return
+ return streamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value)
}
case "grpc-encoding":
d.encoding = f.Value
case "grpc-status":
code, err := strconv.Atoi(f.Value)
if err != nil {
- d.setErr(streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err))
- return
+ return streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err)
}
- d.statusCode = codes.Code(code)
+ d.rawStatusCode = int32(code)
case "grpc-message":
- d.statusDesc = decodeGrpcMessage(f.Value)
+ d.rawStatusMsg = decodeGrpcMessage(f.Value)
+ case "grpc-status-details-bin":
+ _, v, err := metadata.DecodeKeyValue("grpc-status-details-bin", f.Value)
+ if err != nil {
+ return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
+ }
+ s := &spb.Status{}
+ if err := proto.Unmarshal([]byte(v), s); err != nil {
+ return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
+ }
+ d.statusGen = status.FromProto(s)
case "grpc-timeout":
d.timeoutSet = true
var err error
- d.timeout, err = decodeTimeout(f.Value)
- if err != nil {
- d.setErr(streamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
- return
+ if d.timeout, err = decodeTimeout(f.Value); err != nil {
+ return streamErrorf(codes.Internal, "transport: malformed time-out: %v", err)
}
case ":path":
d.method = f.Value
@@ -192,11 +206,12 @@
k, v, err := metadata.DecodeKeyValue(f.Name, f.Value)
if err != nil {
grpclog.Printf("Failed to decode (%q, %q): %v", f.Name, f.Value, err)
- return
+ return nil
}
d.mdata[k] = append(d.mdata[k], v)
}
}
+ return nil
}
type timeoutUnit uint8
diff --git a/transport/transport.go b/transport/transport.go
index 5171680..3b8bd01 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -51,6 +51,7 @@
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
+ "google.golang.org/grpc/status"
"google.golang.org/grpc/tap"
)
@@ -212,9 +213,8 @@
// true iff headerChan is closed. Used to avoid closing headerChan
// multiple times.
headerDone bool
- // the status received from the server.
- statusCode codes.Code
- statusDesc string
+ // the status error received from the server.
+ status status.Status
// rstStream indicates whether a RST_STREAM frame needs to be sent
// to the server to signify that this stream is closing.
rstStream bool
@@ -284,14 +284,9 @@
return s.method
}
-// StatusCode returns statusCode received from the server.
-func (s *Stream) StatusCode() codes.Code {
- return s.statusCode
-}
-
-// StatusDesc returns statusDesc received from the server.
-func (s *Stream) StatusDesc() string {
- return s.statusDesc
+// Status returns the status received from the server.
+func (s *Stream) Status() status.Status {
+ return s.status
}
// SetHeader sets the header metadata. This can be called multiple times.
@@ -338,6 +333,14 @@
return
}
+// finish sets the stream's state and status, and closes the done channel.
+// s.mu must be held by the caller.
+func (s *Stream) finish(st status.Status) {
+ s.status = st
+ s.state = streamDone
+ close(s.done)
+}
+
// The key to save transport.Stream in the context.
type streamKey struct{}
@@ -503,10 +506,9 @@
// Write may not be called on all streams.
Write(s *Stream, data []byte, opts *Options) error
- // WriteStatus sends the status of a stream to the client.
- // WriteStatus is the final call made on a stream and always
- // occurs.
- WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error
+ // WriteStatus sends the status of a stream to the client. WriteStatus is
+ // the final call made on a stream and always occurs.
+ WriteStatus(s *Stream, st status.Status) error
// Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their
@@ -572,6 +574,8 @@
ErrStreamDrain = streamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
)
+// TODO: See if we can replace StreamError with status package errors.
+
// StreamError is an error that only affects one stream within a connection.
type StreamError struct {
Code codes.Code
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 3108b98..4e986e5 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -39,6 +39,7 @@
"io"
"math"
"net"
+ "reflect"
"strconv"
"strings"
"sync"
@@ -50,6 +51,7 @@
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/keepalive"
+ "google.golang.org/grpc/status"
)
type server struct {
@@ -100,7 +102,7 @@
// send a response back to the client.
h.t.Write(s, resp, &Options{})
// send the trailer to end the stream.
- h.t.WriteStatus(s, codes.OK, "")
+ h.t.WriteStatus(s, status.New(codes.OK, ""))
}
// handleStreamSuspension blocks until s.ctx is canceled.
@@ -142,7 +144,7 @@
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) {
// raw newline is not accepted by http2 framer so it must be encoded.
- h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc)
+ h.t.WriteStatus(s, encodingTestStatus)
}
func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) {
@@ -1070,8 +1072,11 @@
}
// Server sent a resetStream for s already.
code := http2ErrConvTab[http2.ErrCodeFlowControl]
- if _, err := io.ReadFull(s, make([]byte, 1)); err != io.EOF || s.statusCode != code {
- t.Fatalf("%v got err %v with statusCode %d, want err <EOF> with statusCode %d", s, err, s.statusCode, code)
+ if _, err := io.ReadFull(s, make([]byte, 1)); err != io.EOF {
+ t.Fatalf("%v got err %v want <EOF>", s, err)
+ }
+ if s.status.Code() != code {
+ t.Fatalf("%v got status %v; want Code=%v", s, s.status, code)
}
if ss.fc.pendingData != 0 || ss.fc.pendingUpdate != 0 || sc.fc.pendingData != 0 || sc.fc.pendingUpdate <= initialWindowSize {
@@ -1125,9 +1130,14 @@
if s.fc.pendingData <= initialWindowSize || s.fc.pendingUpdate != 0 || conn.fc.pendingData <= initialWindowSize || conn.fc.pendingUpdate != 0 {
t.Fatalf("Client mistakenly updates inbound flow control params: got %d, %d, %d, %d; want >%d, %d, >%d, %d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize, 0, initialWindowSize, 0)
}
- if err != io.EOF || s.statusCode != codes.Internal {
- t.Fatalf("Got err %v and the status code %d, want <EOF> and the code %d", err, s.statusCode, codes.Internal)
+
+ if err != io.EOF {
+ t.Fatalf("Got err %v, want <EOF>", err)
}
+ if s.status.Code() != codes.Internal {
+ t.Fatalf("Got s.status %v, want s.status.Code()=Internal", s.status)
+ }
+
conn.CloseStream(s, err)
if s.fc.pendingData != 0 || s.fc.pendingUpdate != 0 || conn.fc.pendingData != 0 || conn.fc.pendingUpdate <= initialWindowSize {
t.Fatalf("Client mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, >%d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize)
@@ -1152,10 +1162,7 @@
server.stop()
}
-var (
- encodingTestStatusCode = codes.Internal
- encodingTestStatusDesc = "\n"
-)
+var encodingTestStatus = status.New(codes.Internal, "\n")
func TestEncodingRequiredStatus(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
@@ -1178,8 +1185,8 @@
if _, err := s.dec.Read(p); err != io.EOF {
t.Fatalf("Read got error %v, want %v", err, io.EOF)
}
- if s.StatusCode() != encodingTestStatusCode || s.StatusDesc() != encodingTestStatusDesc {
- t.Fatalf("stream with status code %d, status desc %v, want %d, %v", s.StatusCode(), s.StatusDesc(), encodingTestStatusCode, encodingTestStatusDesc)
+ if !reflect.DeepEqual(s.Status(), encodingTestStatus) {
+ t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus)
}
ct.Close()
server.stop()
@@ -1242,3 +1249,20 @@
}
}
}
+
+func TestContextErr(t *testing.T) {
+ for _, test := range []struct {
+ // input
+ errIn error
+ // output
+ errOut StreamError
+ }{
+ {context.DeadlineExceeded, StreamError{codes.DeadlineExceeded, context.DeadlineExceeded.Error()}},
+ {context.Canceled, StreamError{codes.Canceled, context.Canceled.Error()}},
+ } {
+ err := ContextErr(test.errIn)
+ if err != test.errOut {
+ t.Errorf("ContextErr(%v) = %v\nwant %v", test.errIn, err, test.errOut)
+ }
+ }
+}