Merge pull request #937 from menghanl/reflection_tutorial
Add server-reflection-tutorial.md
diff --git a/Documentation/grpc-auth-support.md b/Documentation/grpc-auth-support.md
index 5312b57..248d272 100644
--- a/Documentation/grpc-auth-support.md
+++ b/Documentation/grpc-auth-support.md
@@ -26,7 +26,7 @@
## Google Compute Engine (GCE)
```Go
-conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, ""), grpc.WithPerRPCCredentials(oauth.NewComputeEngine())))
+conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")), grpc.WithPerRPCCredentials(oauth.NewComputeEngine()))
```
## JWT
@@ -36,6 +36,6 @@
if err != nil {
log.Fatalf("Failed to create JWT credentials: %v", err)
}
-conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, ""), grpc.WithPerRPCCredentials(jwtCreds)))
+conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")), grpc.WithPerRPCCredentials(jwtCreds))
```
diff --git a/README.md b/README.md
index 660658b..39120c2 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,23 @@
Prerequisites
-------------
-This requires Go 1.5 or later .
+This requires Go 1.5 or later.
+
+A note on the Go version: benchmarks of grpc-go show significant performance
+improvements when upgrading Go from 1.5 to the latest release at the time of
+writing, 1.7.1.
+
+From https://golang.org/doc/install, one way to install the latest version of Go is:
+```
+$ GO_VERSION=1.7.1
+$ OS=linux
+$ ARCH=amd64
+$ curl -O https://storage.googleapis.com/golang/go${GO_VERSION}.${OS}-${ARCH}.tar.gz
+$ sudo tar -C /usr/local -xzf go$GO_VERSION.$OS-$ARCH.tar.gz
+$ # Put go on the PATH, keep the usual installation dir
+$ sudo ln -s /usr/local/go/bin/go /usr/bin/go
+$ rm go$GO_VERSION.$OS-$ARCH.tar.gz
+```
Constraints
-----------
@@ -30,3 +46,12 @@
------
GA
+FAQ
+---
+
+#### Compiling error, undefined: grpc.SupportPackageIsVersion
+
+Please update the proto package and the gRPC package, then rebuild the proto files:
+ - `go get -u github.com/golang/protobuf/{proto,protoc-gen-go}`
+ - `go get -u google.golang.org/grpc`
+ - `protoc --go_out=plugins=grpc:. *.proto`
diff --git a/benchmark/grpc_testing/services.pb.go b/benchmark/grpc_testing/services.pb.go
index 15d0864..2aae317 100644
--- a/benchmark/grpc_testing/services.pb.go
+++ b/benchmark/grpc_testing/services.pb.go
@@ -24,7 +24,7 @@
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion3
+const _ = grpc.SupportPackageIsVersion4
// Client API for BenchmarkService service
@@ -161,7 +161,7 @@
ClientStreams: true,
},
},
- Metadata: fileDescriptor3,
+ Metadata: "services.proto",
}
// Client API for WorkerService service
@@ -417,7 +417,7 @@
ClientStreams: true,
},
},
- Metadata: fileDescriptor3,
+ Metadata: "services.proto",
}
func init() { proto.RegisterFile("services.proto", fileDescriptor3) }
diff --git a/call.go b/call.go
index 788b3d9..5d9214d 100644
--- a/call.go
+++ b/call.go
@@ -42,6 +42,7 @@
"golang.org/x/net/context"
"golang.org/x/net/trace"
"google.golang.org/grpc/codes"
+ "google.golang.org/grpc/stats"
"google.golang.org/grpc/transport"
)
@@ -49,9 +50,9 @@
// On error, it returns the error and indicates whether the call should be retried.
//
// TODO(zhaoq): Check whether the received message sequence is valid.
-func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
+// TODO ctx is used for stats collection and processing. It is the context passed from the application.
+func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) {
// Try to acquire header metadata from the server if there is any.
- var err error
defer func() {
if err != nil {
if _, ok := err.(transport.ConnectionError); !ok {
@@ -61,17 +62,28 @@
}()
c.headerMD, err = stream.Header()
if err != nil {
- return err
+ return
}
p := &parser{r: stream}
+ var inPayload *stats.InPayload
+ if stats.On() {
+ inPayload = &stats.InPayload{
+ Client: true,
+ }
+ }
for {
- if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil {
+ if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32, inPayload); err != nil {
if err == io.EOF {
break
}
- return err
+ return
}
}
+ if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK {
+ // TODO in the current implementation, inTrailer may be handled before inPayload in some cases.
+ // Fix the order if necessary.
+ stats.Handle(ctx, inPayload)
+ }
c.trailerMD = stream.Trailer()
return nil
}
@@ -90,15 +102,27 @@
}
}
}()
- var cbuf *bytes.Buffer
+ var (
+ cbuf *bytes.Buffer
+ outPayload *stats.OutPayload
+ )
if compressor != nil {
cbuf = new(bytes.Buffer)
}
- outBuf, err := encode(codec, args, compressor, cbuf)
+ if stats.On() {
+ outPayload = &stats.OutPayload{
+ Client: true,
+ }
+ }
+ outBuf, err := encode(codec, args, compressor, cbuf, outPayload)
if err != nil {
return nil, Errorf(codes.Internal, "grpc: %v", err)
}
err = t.Write(stream, outBuf, opts)
+ if err == nil && outPayload != nil {
+ outPayload.SentTime = time.Now()
+ stats.Handle(ctx, outPayload)
+ }
// t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
// recvResponse to get the final status.
@@ -119,7 +143,7 @@
return invoke(ctx, method, args, reply, cc, opts...)
}
-func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
+func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) {
c := defaultCallInfo
for _, o := range opts {
if err := o.before(&c); err != nil {
@@ -141,12 +165,30 @@
c.traceInfo.tr.LazyLog(&c.traceInfo.firstLine, false)
// TODO(dsymonds): Arrange for c.traceInfo.firstLine.remoteAddr to be set.
defer func() {
- if err != nil {
- c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
+ if e != nil {
+ c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{e}}, true)
c.traceInfo.tr.SetError()
}
}()
}
+ if stats.On() {
+ begin := &stats.Begin{
+ Client: true,
+ BeginTime: time.Now(),
+ FailFast: c.failFast,
+ }
+ stats.Handle(ctx, begin)
+ }
+ defer func() {
+ if stats.On() {
+ end := &stats.End{
+ Client: true,
+ EndTime: time.Now(),
+ Error: e,
+ }
+ stats.Handle(ctx, end)
+ }
+ }()
topts := &transport.Options{
Last: true,
Delay: false,
@@ -206,7 +248,7 @@
}
return toRPCErr(err)
}
- err = recvResponse(cc.dopts, t, &c, stream, reply)
+ err = recvResponse(ctx, cc.dopts, t, &c, stream, reply)
if err != nil {
if put != nil {
put()
diff --git a/call_test.go b/call_test.go
index 64976d7..3c2165e 100644
--- a/call_test.go
+++ b/call_test.go
@@ -118,7 +118,7 @@
}
}
// send a response back to end the stream.
- reply, err := encode(testCodec{}, &expectedResponse, nil, nil)
+ reply, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
if err != nil {
t.Errorf("Failed to encode the response: %v", err)
return
@@ -164,7 +164,10 @@
if err != nil {
return
}
- st, err := transport.NewServerTransport("http2", conn, maxStreams, nil)
+ config := &transport.ServerConfig{
+ MaxStreams: maxStreams,
+ }
+ st, err := transport.NewServerTransport("http2", conn, config)
if err != nil {
continue
}
@@ -182,6 +185,8 @@
}
go st.HandleStreams(func(s *transport.Stream) {
go h.handleStream(t, s)
+ }, func(ctx context.Context, method string) context.Context {
+ return ctx
})
}
}
diff --git a/clientconn.go b/clientconn.go
index 6167472..f6dab4b 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -199,6 +199,8 @@
}
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
+// If FailOnNonTempDialError(true) is set and f returns an error, gRPC checks the error's
+// Temporary() method to decide whether it should try to reconnect to the network address.
func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
return func(o *dialOptions) {
o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
@@ -210,6 +212,17 @@
}
}
+// FailOnNonTempDialError returns a DialOption that specifies whether gRPC fails on non-temporary dial errors.
+// If f is true and the dialer returns a non-temporary error, gRPC will fail the connection to the network
+// address and will not try to reconnect.
+// The default value of FailOnNonTempDialError is false.
+// This is an EXPERIMENTAL API.
+func FailOnNonTempDialError(f bool) DialOption {
+ return func(o *dialOptions) {
+ o.copts.FailOnNonTempDialError = f
+ }
+}
+
// WithUserAgent returns a DialOption that specifies a user agent string for all the RPCs.
func WithUserAgent(s string) DialOption {
return func(o *dialOptions) {
diff --git a/clientconn_test.go b/clientconn_test.go
index 179be25..98612f4 100644
--- a/clientconn_test.go
+++ b/clientconn_test.go
@@ -34,13 +34,13 @@
package grpc
import (
+ "net"
"testing"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/credentials"
- "google.golang.org/grpc/credentials/oauth"
)
const tlsDir = "testdata/"
@@ -130,6 +130,17 @@
<-dialDone
}
+// securePerRPCCredentials always requires transport security.
+type securePerRPCCredentials struct{}
+
+func (c securePerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
+ return nil, nil
+}
+
+func (c securePerRPCCredentials) RequireTransportSecurity() bool {
+ return true
+}
+
func TestCredentialsMisuse(t *testing.T) {
tlsCreds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil {
@@ -139,12 +150,8 @@
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithBlock(), WithInsecure()); err != errCredentialsConflict {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict)
}
- rpcCreds, err := oauth.NewJWTAccessFromKey(nil)
- if err != nil {
- t.Fatalf("Failed to create credentials %v", err)
- }
// security info on insecure connection
- if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(rpcCreds), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing {
+ if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(securePerRPCCredentials{}), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing)
}
}
@@ -188,3 +195,33 @@
}
conn.Close()
}
+
+type testErr struct {
+ temp bool
+}
+
+func (e *testErr) Error() string {
+ return "test error"
+}
+
+func (e *testErr) Temporary() bool {
+ return e.temp
+}
+
+var nonTemporaryError = &testErr{false}
+
+func nonTemporaryErrorDialer(addr string, timeout time.Duration) (net.Conn, error) {
+ return nil, nonTemporaryError
+}
+
+func TestDialWithBlockErrorOnNonTemporaryErrorDialer(t *testing.T) {
+ ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ if _, err := DialContext(ctx, "", WithInsecure(), WithDialer(nonTemporaryErrorDialer), WithBlock(), FailOnNonTempDialError(true)); err != nonTemporaryError {
+ t.Fatalf("Dial(%q) = %v, want %v", "", err, nonTemporaryError)
+ }
+
+ // Without FailOnNonTempDialError, gRPC keeps retrying the connection, so dial should exit with a timeout error.
+ if _, err := DialContext(ctx, "", WithInsecure(), WithDialer(nonTemporaryErrorDialer), WithBlock()); err != context.DeadlineExceeded {
+ t.Fatalf("Dial(%q) = %v, want %v", "", err, context.DeadlineExceeded)
+ }
+}
diff --git a/credentials/oauth/oauth.go b/credentials/oauth/oauth.go
index 8e68c4d..25393cc 100644
--- a/credentials/oauth/oauth.go
+++ b/credentials/oauth/oauth.go
@@ -61,7 +61,7 @@
}, nil
}
-// RequireTransportSecurity indicates whether the credentails requires transport security.
+// RequireTransportSecurity indicates whether the credentials require transport security.
func (ts TokenSource) RequireTransportSecurity() bool {
return true
}
diff --git a/examples/README.md b/examples/README.md
index b65f8c5..6ea6b35 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -8,7 +8,7 @@
PREREQUISITES
-------------
-- This requires Go 1.4
+- This requires Go 1.5 or later
- Requires that [GOPATH is set](https://golang.org/doc/code.html#GOPATH)
```
diff --git a/examples/gotutorial.md b/examples/gotutorial.md
index 25c0a2d..6770b52 100644
--- a/examples/gotutorial.md
+++ b/examples/gotutorial.md
@@ -33,7 +33,7 @@
## Defining the service
-Our first step (as you'll know from the [quick start](http://www.grpc.io/docs/#quick-start)) is to define the gRPC *service* and the method *request* and *response* types using [protocol buffers] (https://developers.google.com/protocol-buffers/docs/overview). You can see the complete .proto file in [examples/route_guide/routeguide/route_guide.proto](https://github.com/grpc/grpc-go/tree/master/examples/route_guide/routeguide/route_guide.proto).
+Our first step (as you'll know from the [quick start](http://www.grpc.io/docs/#quick-start)) is to define the gRPC *service* and the method *request* and *response* types using [protocol buffers](https://developers.google.com/protocol-buffers/docs/overview). You can see the complete .proto file in [examples/route_guide/routeguide/route_guide.proto](https://github.com/grpc/grpc-go/tree/master/examples/route_guide/routeguide/route_guide.proto).
To define a service, you specify a named `service` in your .proto file:
diff --git a/examples/helloworld/helloworld/helloworld.pb.go b/examples/helloworld/helloworld/helloworld.pb.go
index 0419f6a..c8c8942 100644
--- a/examples/helloworld/helloworld/helloworld.pb.go
+++ b/examples/helloworld/helloworld/helloworld.pb.go
@@ -65,7 +65,7 @@
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion3
+const _ = grpc.SupportPackageIsVersion4
// Client API for Greeter service
@@ -130,7 +130,7 @@
},
},
Streams: []grpc.StreamDesc{},
- Metadata: fileDescriptor0,
+ Metadata: "helloworld.proto",
}
func init() { proto.RegisterFile("helloworld.proto", fileDescriptor0) }
diff --git a/examples/route_guide/routeguide/route_guide.pb.go b/examples/route_guide/routeguide/route_guide.pb.go
index 9bb1d60..cbcf2f3 100644
--- a/examples/route_guide/routeguide/route_guide.pb.go
+++ b/examples/route_guide/routeguide/route_guide.pb.go
@@ -156,7 +156,7 @@
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion3
+const _ = grpc.SupportPackageIsVersion4
// Client API for RouteGuide service
@@ -452,7 +452,7 @@
ClientStreams: true,
},
},
- Metadata: fileDescriptor0,
+ Metadata: "route_guide.proto",
}
func init() { proto.RegisterFile("route_guide.proto", fileDescriptor0) }
diff --git a/grpclb/grpc_lb_v1/grpclb.pb.go b/grpclb/grpc_lb_v1/grpclb.pb.go
index da371e5..7be8947 100644
--- a/grpclb/grpc_lb_v1/grpclb.pb.go
+++ b/grpclb/grpc_lb_v1/grpclb.pb.go
@@ -420,7 +420,7 @@
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion3
+const _ = grpc.SupportPackageIsVersion4
// Client API for LoadBalancer service
@@ -517,7 +517,7 @@
ClientStreams: true,
},
},
- Metadata: fileDescriptor0,
+ Metadata: "grpclb.proto",
}
func init() { proto.RegisterFile("grpclb.proto", fileDescriptor0) }
diff --git a/grpclb/grpclb.go b/grpclb/grpclb.go
index 996d27a..d9a1a8b 100644
--- a/grpclb/grpclb.go
+++ b/grpclb/grpclb.go
@@ -40,6 +40,7 @@
"errors"
"fmt"
"sync"
+ "time"
"golang.org/x/net/context"
"google.golang.org/grpc"
@@ -93,16 +94,17 @@
}
type balancer struct {
- r naming.Resolver
- mu sync.Mutex
- seq int // a sequence number to make sure addrCh does not get stale addresses.
- w naming.Watcher
- addrCh chan []grpc.Address
- rbs []remoteBalancerInfo
- addrs []*addrInfo
- next int
- waitCh chan struct{}
- done bool
+ r naming.Resolver
+ mu sync.Mutex
+ seq int // a sequence number to make sure addrCh does not get stale addresses.
+ w naming.Watcher
+ addrCh chan []grpc.Address
+ rbs []remoteBalancerInfo
+ addrs []*addrInfo
+ next int
+ waitCh chan struct{}
+ done bool
+ expTimer *time.Timer
}
func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo) error {
@@ -180,14 +182,39 @@
return nil
}
+func (b *balancer) serverListExpire(seq int) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ // TODO: gRPC internals do not clear the connections when the server list is stale.
+ // This means RPCs will keep using the existing server list until b receives a new
+ // server list, even though the current list has expired. Revisit this behavior later.
+ if b.done || seq < b.seq {
+ return
+ }
+ b.next = 0
+ b.addrs = nil
+ // Ask grpc internals to close all the corresponding connections.
+ b.addrCh <- nil
+}
+
+func convertDuration(d *lbpb.Duration) time.Duration {
+ if d == nil {
+ return 0
+ }
+ return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond
+}
+
func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
+ if l == nil {
+ return
+ }
servers := l.GetServers()
+ expiration := convertDuration(l.GetExpirationInterval())
var (
sl []*addrInfo
addrs []grpc.Address
)
for _, s := range servers {
- // TODO: Support ExpirationInterval
md := metadata.Pairs("lb-token", s.LoadBalanceToken)
addr := grpc.Address{
Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port),
@@ -209,11 +236,20 @@
b.next = 0
b.addrs = sl
b.addrCh <- addrs
+ if b.expTimer != nil {
+ b.expTimer.Stop()
+ b.expTimer = nil
+ }
+ if expiration > 0 {
+ b.expTimer = time.AfterFunc(expiration, func() {
+ b.serverListExpire(seq)
+ })
+ }
}
return
}
-func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) {
+func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient, seq int) (retry bool) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false))
@@ -226,8 +262,6 @@
b.mu.Unlock()
return
}
- b.seq++
- seq := b.seq
b.mu.Unlock()
initReq := &lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
@@ -260,6 +294,14 @@
if err != nil {
break
}
+ b.mu.Lock()
+ if b.done || seq < b.seq {
+ b.mu.Unlock()
+ return
+ }
+ b.seq++ // tick when receiving a new list of servers.
+ seq = b.seq
+ b.mu.Unlock()
if serverList := reply.GetServerList(); serverList != nil {
b.processServerList(serverList, seq)
}
@@ -326,10 +368,15 @@
grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
return
}
+ b.mu.Lock()
+ b.seq++ // tick when getting a new balancer address
+ seq := b.seq
+ b.next = 0
+ b.mu.Unlock()
go func(cc *grpc.ClientConn) {
lbc := lbpb.NewLoadBalancerClient(cc)
for {
- if retry := b.callRemoteBalancer(lbc); !retry {
+ if retry := b.callRemoteBalancer(lbc, seq); !retry {
cc.Close()
return
}
@@ -497,6 +544,9 @@
b.mu.Lock()
defer b.mu.Unlock()
b.done = true
+ if b.expTimer != nil {
+ b.expTimer.Stop()
+ }
if b.waitCh != nil {
close(b.waitCh)
}
diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go
index 3215bea..f034b6b 100644
--- a/grpclb/grpclb_test.go
+++ b/grpclb/grpclb_test.go
@@ -162,14 +162,16 @@
}
type remoteBalancer struct {
- servers *lbpb.ServerList
- done chan struct{}
+ sls []*lbpb.ServerList
+ intervals []time.Duration
+ done chan struct{}
}
-func newRemoteBalancer(servers *lbpb.ServerList) *remoteBalancer {
+func newRemoteBalancer(sls []*lbpb.ServerList, intervals []time.Duration) *remoteBalancer {
return &remoteBalancer{
- servers: servers,
- done: make(chan struct{}),
+ sls: sls,
+ intervals: intervals,
+ done: make(chan struct{}),
}
}
@@ -186,13 +188,16 @@
if err := stream.Send(resp); err != nil {
return err
}
- resp = &lbpb.LoadBalanceResponse{
- LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
- ServerList: b.servers,
- },
- }
- if err := stream.Send(resp); err != nil {
- return err
+ for k, v := range b.sls {
+ time.Sleep(b.intervals[k])
+ resp = &lbpb.LoadBalanceResponse{
+ LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
+ ServerList: v,
+ },
+ }
+ if err := stream.Send(resp); err != nil {
+ return err
+ }
}
<-b.done
return nil
@@ -268,7 +273,9 @@
sl := &lbpb.ServerList{
Servers: bes,
}
- ls := newRemoteBalancer(sl)
+ sls := []*lbpb.ServerList{sl}
+ intervals := []time.Duration{0}
+ ls := newRemoteBalancer(sls, intervals)
lbpb.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
@@ -343,7 +350,9 @@
sl := &lbpb.ServerList{
Servers: bes,
}
- ls := newRemoteBalancer(sl)
+ sls := []*lbpb.ServerList{sl}
+ intervals := []time.Duration{0}
+ ls := newRemoteBalancer(sls, intervals)
lbpb.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
@@ -413,7 +422,9 @@
sl := &lbpb.ServerList{
Servers: bes,
}
- ls := newRemoteBalancer(sl)
+ sls := []*lbpb.ServerList{sl}
+ intervals := []time.Duration{0}
+ ls := newRemoteBalancer(sls, intervals)
lbpb.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
@@ -439,3 +450,86 @@
}
cc.Close()
}
+
+func TestServerExpiration(t *testing.T) {
+ // Start a backend.
+ beLis, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("Failed to listen %v", err)
+ }
+ beAddr := strings.Split(beLis.Addr().String(), ":")
+ bePort, err := strconv.Atoi(beAddr[1])
+ backends := startBackends(t, besn, beLis)
+ defer stopBackends(backends)
+
+ // Start a load balancer.
+ lbLis, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("Failed to create the listener for the load balancer %v", err)
+ }
+ lbCreds := &serverNameCheckCreds{
+ sn: lbsn,
+ }
+ lb := grpc.NewServer(grpc.Creds(lbCreds))
+ if err != nil {
+ t.Fatalf("Failed to generate the port number %v", err)
+ }
+ be := &lbpb.Server{
+ IpAddress: []byte(beAddr[0]),
+ Port: int32(bePort),
+ LoadBalanceToken: lbToken,
+ }
+ var bes []*lbpb.Server
+ bes = append(bes, be)
+ exp := &lbpb.Duration{
+ Seconds: 0,
+ Nanos: 100000000, // 100ms
+ }
+ var sls []*lbpb.ServerList
+ sl := &lbpb.ServerList{
+ Servers: bes,
+ ExpirationInterval: exp,
+ }
+ sls = append(sls, sl)
+ sl = &lbpb.ServerList{
+ Servers: bes,
+ }
+ sls = append(sls, sl)
+ var intervals []time.Duration
+ intervals = append(intervals, 0)
+ intervals = append(intervals, 500*time.Millisecond)
+ ls := newRemoteBalancer(sls, intervals)
+ lbpb.RegisterLoadBalancerServer(lb, ls)
+ go func() {
+ lb.Serve(lbLis)
+ }()
+ defer func() {
+ ls.stop()
+ lb.Stop()
+ }()
+ creds := serverNameCheckCreds{
+ expected: besn,
+ }
+ ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
+ cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
+ addr: lbLis.Addr().String(),
+ })), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
+ if err != nil {
+ t.Fatalf("Failed to dial to the backend %v", err)
+ }
+ helloC := hwpb.NewGreeterClient(cc)
+ if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
+ t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
+ }
+ // Sleep until the first server list has expired.
+ time.Sleep(150 * time.Millisecond)
+ if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable {
+ t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable)
+ }
+ // A non-failfast RPC should succeed once the second server list is received from
+ // the remote load balancer.
+ if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil {
+ t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
+ }
+ cc.Close()
+}
diff --git a/grpclog/glogger/glogger.go b/grpclog/glogger/glogger.go
index 53e3c53..43b886c 100644
--- a/grpclog/glogger/glogger.go
+++ b/grpclog/glogger/glogger.go
@@ -37,6 +37,8 @@
package glogger
import (
+ "fmt"
+
"github.com/golang/glog"
"google.golang.org/grpc/grpclog"
)
@@ -48,25 +50,25 @@
type glogger struct{}
func (g *glogger) Fatal(args ...interface{}) {
- glog.Fatal(args...)
+ glog.FatalDepth(2, args...)
}
func (g *glogger) Fatalf(format string, args ...interface{}) {
- glog.Fatalf(format, args...)
+ glog.FatalDepth(2, fmt.Sprintf(format, args...))
}
func (g *glogger) Fatalln(args ...interface{}) {
- glog.Fatalln(args...)
+ glog.FatalDepth(2, fmt.Sprintln(args...))
}
func (g *glogger) Print(args ...interface{}) {
- glog.Info(args...)
+ glog.InfoDepth(2, args...)
}
func (g *glogger) Printf(format string, args ...interface{}) {
- glog.Infof(format, args...)
+ glog.InfoDepth(2, fmt.Sprintf(format, args...))
}
func (g *glogger) Println(args ...interface{}) {
- glog.Infoln(args...)
+ glog.InfoDepth(2, fmt.Sprintln(args...))
}
diff --git a/health/grpc_health_v1/health.pb.go b/health/grpc_health_v1/health.pb.go
index 0e6a910..89c4d45 100644
--- a/health/grpc_health_v1/health.pb.go
+++ b/health/grpc_health_v1/health.pb.go
@@ -90,7 +90,7 @@
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion3
+const _ = grpc.SupportPackageIsVersion4
// Client API for Health service
@@ -153,7 +153,7 @@
},
},
Streams: []grpc.StreamDesc{},
- Metadata: fileDescriptor0,
+ Metadata: "health.proto",
}
func init() { proto.RegisterFile("health.proto", fileDescriptor0) }
diff --git a/interop/client/client.go b/interop/client/client.go
index 7ae864e..7961752 100644
--- a/interop/client/client.go
+++ b/interop/client/client.go
@@ -71,7 +71,10 @@
oauth2_auth_token: large_unary with oauth2 token auth;
cancel_after_begin: cancellation after metadata has been sent but before payloads are sent;
cancel_after_first_response: cancellation after receiving 1st message from the server;
- status_code_and_message: status code propagated back to client.`)
+ status_code_and_message: status code propagated back to client;
+ custom_metadata: server will echo custom metadata;
+ unimplemented_method: client attempts to call unimplemented method;
+ unimplemented_service: client attempts to call unimplemented service.`)
// The test CA root cert file
testCAFile = "testdata/ca.pem"
@@ -184,6 +187,15 @@
case "status_code_and_message":
interop.DoStatusCodeAndMessage(tc)
grpclog.Println("StatusCodeAndMessage done")
+ case "custom_metadata":
+ interop.DoCustomMetadata(tc)
+ grpclog.Println("CustomMetadata done")
+ case "unimplemented_method":
+ interop.DoUnimplementedMethod(conn)
+ grpclog.Println("UnimplementedMethod done")
+ case "unimplemented_service":
+ interop.DoUnimplementedService(testpb.NewUnimplementedServiceClient(conn))
+ grpclog.Println("UnimplementedService done")
default:
grpclog.Fatal("Unsupported test case: ", *testCase)
}
diff --git a/interop/grpc_testing/test.pb.go b/interop/grpc_testing/test.pb.go
index 54ead93..76ae564 100755
--- a/interop/grpc_testing/test.pb.go
+++ b/interop/grpc_testing/test.pb.go
@@ -407,7 +407,7 @@
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion3
+const _ = grpc.SupportPackageIsVersion4
// Client API for TestService service
@@ -789,51 +789,118 @@
ClientStreams: true,
},
},
+ Metadata: "test.proto",
+}
+
+// Client API for UnimplementedService service
+
+type UnimplementedServiceClient interface {
+ // A call that no server should implement
+ UnimplementedCall(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
+}
+
+type unimplementedServiceClient struct {
+ cc *grpc.ClientConn
+}
+
+func NewUnimplementedServiceClient(cc *grpc.ClientConn) UnimplementedServiceClient {
+ return &unimplementedServiceClient{cc}
+}
+
+func (c *unimplementedServiceClient) UnimplementedCall(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
+ out := new(Empty)
+ err := grpc.Invoke(ctx, "/grpc.testing.UnimplementedService/UnimplementedCall", in, out, c.cc, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+// Server API for UnimplementedService service
+
+type UnimplementedServiceServer interface {
+ // A call that no server should implement
+ UnimplementedCall(context.Context, *Empty) (*Empty, error)
+}
+
+func RegisterUnimplementedServiceServer(s *grpc.Server, srv UnimplementedServiceServer) {
+ s.RegisterService(&_UnimplementedService_serviceDesc, srv)
+}
+
+func _UnimplementedService_UnimplementedCall_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(Empty)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(UnimplementedServiceServer).UnimplementedCall(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/grpc.testing.UnimplementedService/UnimplementedCall",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(UnimplementedServiceServer).UnimplementedCall(ctx, req.(*Empty))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+var _UnimplementedService_serviceDesc = grpc.ServiceDesc{
+ ServiceName: "grpc.testing.UnimplementedService",
+ HandlerType: (*UnimplementedServiceServer)(nil),
+ Methods: []grpc.MethodDesc{
+ {
+ MethodName: "UnimplementedCall",
+ Handler: _UnimplementedService_UnimplementedCall_Handler,
+ },
+ },
+ Streams: []grpc.StreamDesc{},
Metadata: fileDescriptor0,
}
func init() { proto.RegisterFile("test.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
- // 625 bytes of a gzipped FileDescriptorProto
- 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xbc, 0x54, 0x4d, 0x6f, 0xd3, 0x4c,
- 0x10, 0x7e, 0x9d, 0x8f, 0x37, 0xcd, 0x24, 0x35, 0xd1, 0x46, 0x15, 0xae, 0x8b, 0x44, 0x65, 0x0e,
- 0x18, 0x24, 0x02, 0x8a, 0x04, 0x07, 0x0e, 0xa0, 0xd0, 0xa6, 0xa2, 0x52, 0x9b, 0x04, 0x3b, 0x39,
- 0x47, 0x4b, 0xb2, 0x75, 0x2d, 0x39, 0xb6, 0xb1, 0xd7, 0x88, 0x70, 0xe0, 0xcf, 0xf0, 0x23, 0x38,
- 0xf0, 0xe7, 0xd8, 0x5d, 0x7f, 0xc4, 0x49, 0x5c, 0x91, 0xf2, 0x75, 0xdb, 0x7d, 0xf6, 0x99, 0x67,
- 0xe6, 0x99, 0x19, 0x1b, 0x80, 0x92, 0x90, 0x76, 0xfc, 0xc0, 0xa3, 0x1e, 0x6a, 0x5a, 0x81, 0x3f,
- 0xeb, 0x70, 0xc0, 0x76, 0x2d, 0xad, 0x06, 0xd5, 0xfe, 0xc2, 0xa7, 0x4b, 0xed, 0x02, 0x6a, 0x23,
- 0xbc, 0x74, 0x3c, 0x3c, 0x47, 0x4f, 0xa0, 0x42, 0x97, 0x3e, 0x51, 0xa4, 0x63, 0x49, 0x97, 0xbb,
- 0x87, 0x9d, 0x7c, 0x40, 0x27, 0x21, 0x8d, 0x19, 0xc1, 0x10, 0x34, 0x84, 0xa0, 0xf2, 0xde, 0x9b,
- 0x2f, 0x95, 0x12, 0xa3, 0x37, 0x0d, 0x71, 0xd6, 0x5e, 0x02, 0xf4, 0x67, 0xd7, 0x9e, 0x49, 0x31,
- 0x8d, 0x42, 0xce, 0x98, 0x79, 0xf3, 0x58, 0xb0, 0x6a, 0x88, 0x33, 0x52, 0xa0, 0xb6, 0x20, 0x61,
- 0x88, 0x2d, 0x22, 0x02, 0xeb, 0x46, 0x7a, 0xd5, 0xbe, 0x95, 0x60, 0xdf, 0xb4, 0x17, 0xbe, 0x43,
- 0x0c, 0xf2, 0x21, 0x62, 0x69, 0xd1, 0x2b, 0xd8, 0x0f, 0x48, 0xe8, 0x7b, 0x6e, 0x48, 0xa6, 0xbb,
- 0x55, 0xd6, 0x4c, 0xf9, 0xfc, 0x86, 0x1e, 0xe4, 0xe2, 0x43, 0xfb, 0x73, 0x9c, 0xb1, 0xba, 0x22,
- 0x99, 0x0c, 0x43, 0x4f, 0xa1, 0xe6, 0xc7, 0x0a, 0x4a, 0x99, 0x3d, 0x37, 0xba, 0x07, 0x85, 0xf2,
- 0x46, 0xca, 0xe2, 0xaa, 0x57, 0xb6, 0xe3, 0x4c, 0xa3, 0x90, 0x04, 0x2e, 0x5e, 0x10, 0xa5, 0xc2,
- 0xc2, 0xf6, 0x8c, 0x26, 0x07, 0x27, 0x09, 0x86, 0x74, 0x68, 0x09, 0x92, 0x87, 0x23, 0x7a, 0x3d,
- 0x0d, 0x67, 0x1e, 0xab, 0xbe, 0x2a, 0x78, 0x32, 0xc7, 0x87, 0x1c, 0x36, 0x39, 0x8a, 0x7a, 0x70,
- 0x67, 0x55, 0xa4, 0xe8, 0x9b, 0x52, 0x13, 0x75, 0x28, 0xeb, 0x75, 0xac, 0xfa, 0x6a, 0xc8, 0x99,
- 0x01, 0x71, 0xd7, 0xbe, 0x80, 0x9c, 0x36, 0x2e, 0xc6, 0xf3, 0xa6, 0xa4, 0x9d, 0x4c, 0xa9, 0xb0,
- 0x97, 0xf9, 0x89, 0xe7, 0x92, 0xdd, 0xd1, 0x7d, 0x68, 0xe4, 0x6d, 0x94, 0xc5, 0x33, 0x78, 0x99,
- 0x05, 0xb6, 0x43, 0x87, 0x26, 0x0d, 0x08, 0x5e, 0x30, 0xe9, 0x73, 0xd7, 0x8f, 0xe8, 0x09, 0x76,
- 0x9c, 0x74, 0x88, 0xb7, 0x2d, 0x45, 0x1b, 0x83, 0x5a, 0xa4, 0x96, 0x38, 0x7b, 0x01, 0x77, 0xb1,
- 0x65, 0x05, 0xc4, 0xc2, 0x94, 0xcc, 0xa7, 0x49, 0x4c, 0x3c, 0xdd, 0x78, 0xcd, 0x0e, 0x56, 0xcf,
- 0x89, 0x34, 0x1f, 0xb3, 0x76, 0x0e, 0x28, 0xd5, 0x18, 0xe1, 0x80, 0xd9, 0xa2, 0x24, 0x10, 0x1b,
- 0x9a, 0x0b, 0x15, 0x67, 0x6e, 0xd7, 0x76, 0xd9, 0xeb, 0x47, 0xcc, 0x67, 0x9c, 0xec, 0x0c, 0xa4,
- 0xd0, 0x24, 0xd4, 0xbe, 0x96, 0x72, 0x15, 0x0e, 0x23, 0xba, 0x61, 0xf8, 0x77, 0xb7, 0xf6, 0x1d,
- 0xb4, 0xb3, 0x78, 0x3f, 0x2b, 0x95, 0xd5, 0x51, 0x66, 0xcd, 0x3b, 0x5e, 0x57, 0xd9, 0xb6, 0x64,
- 0xa0, 0x60, 0xdb, 0xe6, 0xad, 0x77, 0xfc, 0x0f, 0x2c, 0xe5, 0x00, 0x8e, 0x0a, 0x9b, 0xf4, 0x8b,
- 0x1b, 0xfa, 0xf8, 0x35, 0x34, 0x72, 0x3d, 0x43, 0x2d, 0x68, 0x9e, 0x0c, 0x2f, 0x47, 0x46, 0xdf,
- 0x34, 0x7b, 0x6f, 0x2e, 0xfa, 0xad, 0xff, 0xd8, 0x2c, 0xe5, 0xc9, 0x60, 0x0d, 0x93, 0x10, 0xc0,
- 0xff, 0x46, 0x6f, 0x70, 0x3a, 0xbc, 0x6c, 0x95, 0xba, 0xdf, 0x2b, 0xd0, 0x18, 0x33, 0x75, 0x93,
- 0xcd, 0xd1, 0x9e, 0x11, 0xf4, 0x1c, 0xea, 0xe2, 0x17, 0xc8, 0xcb, 0x42, 0xed, 0x0d, 0x5f, 0xfc,
- 0x41, 0x2d, 0x02, 0xd1, 0x19, 0xd4, 0x27, 0x2e, 0x0e, 0xe2, 0xb0, 0xa3, 0x75, 0xc6, 0xda, 0xef,
- 0x4b, 0xbd, 0x57, 0xfc, 0x98, 0x34, 0xc0, 0x81, 0x76, 0x41, 0x7f, 0x90, 0xbe, 0x11, 0x74, 0xe3,
- 0x9e, 0xa9, 0x8f, 0x76, 0x60, 0xc6, 0xb9, 0x9e, 0x49, 0xc8, 0x06, 0xb4, 0xfd, 0x51, 0xa1, 0x87,
- 0x37, 0x48, 0x6c, 0x7e, 0xc4, 0xaa, 0xfe, 0x73, 0x62, 0x9c, 0x4a, 0xe7, 0xa9, 0xe4, 0xb3, 0xc8,
- 0x71, 0x4e, 0x23, 0xe6, 0xf6, 0xd3, 0x5f, 0xf3, 0xa4, 0x4b, 0xc2, 0x95, 0xfc, 0x16, 0x3b, 0x57,
- 0xff, 0x20, 0xd5, 0x8f, 0x00, 0x00, 0x00, 0xff, 0xff, 0xbb, 0x7f, 0x47, 0xd6, 0x4b, 0x07, 0x00,
- 0x00,
+ // 649 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xbc, 0x54, 0x4d, 0x6f, 0xd3, 0x40,
+ 0x10, 0xc5, 0x69, 0x42, 0xda, 0x49, 0x6a, 0xc2, 0x94, 0x0a, 0x37, 0x45, 0x22, 0x32, 0x07, 0x0c,
+ 0x12, 0x01, 0x45, 0x82, 0x03, 0x12, 0xa0, 0xd2, 0xa6, 0xa2, 0x52, 0xdb, 0x14, 0xbb, 0x39, 0x47,
+ 0x4b, 0x32, 0x75, 0x2d, 0xf9, 0x0b, 0x7b, 0x5d, 0x91, 0x1e, 0xf8, 0x33, 0xfc, 0x08, 0x0e, 0xfc,
+ 0x39, 0xb4, 0x6b, 0x3b, 0x71, 0xd2, 0x54, 0x34, 0x7c, 0xdd, 0x76, 0xdf, 0xbe, 0xf9, 0x78, 0x33,
+ 0xcf, 0x06, 0xe0, 0x14, 0xf3, 0x76, 0x18, 0x05, 0x3c, 0xc0, 0xba, 0x1d, 0x85, 0xc3, 0xb6, 0x00,
+ 0x1c, 0xdf, 0xd6, 0xab, 0x50, 0xe9, 0x7a, 0x21, 0x1f, 0xeb, 0x87, 0x50, 0x3d, 0x61, 0x63, 0x37,
+ 0x60, 0x23, 0x7c, 0x06, 0x65, 0x3e, 0x0e, 0x49, 0x53, 0x5a, 0x8a, 0xa1, 0x76, 0xb6, 0xda, 0xc5,
+ 0x80, 0x76, 0x46, 0x3a, 0x1d, 0x87, 0x64, 0x4a, 0x1a, 0x22, 0x94, 0x3f, 0x05, 0xa3, 0xb1, 0x56,
+ 0x6a, 0x29, 0x46, 0xdd, 0x94, 0x67, 0xfd, 0x35, 0x40, 0x77, 0x78, 0x1e, 0x58, 0x9c, 0xf1, 0x24,
+ 0x16, 0x8c, 0x61, 0x30, 0x4a, 0x13, 0x56, 0x4c, 0x79, 0x46, 0x0d, 0xaa, 0x1e, 0xc5, 0x31, 0xb3,
+ 0x49, 0x06, 0xae, 0x99, 0xf9, 0x55, 0xff, 0x5e, 0x82, 0x75, 0xcb, 0xf1, 0x42, 0x97, 0x4c, 0xfa,
+ 0x9c, 0x50, 0xcc, 0xf1, 0x2d, 0xac, 0x47, 0x14, 0x87, 0x81, 0x1f, 0xd3, 0xe0, 0x66, 0x9d, 0xd5,
+ 0x73, 0xbe, 0xb8, 0xe1, 0xa3, 0x42, 0x7c, 0xec, 0x5c, 0xa6, 0x15, 0x2b, 0x53, 0x92, 0xe5, 0x5c,
+ 0x12, 0x3e, 0x87, 0x6a, 0x98, 0x66, 0xd0, 0x56, 0x5a, 0x8a, 0x51, 0xeb, 0x6c, 0x2e, 0x4c, 0x6f,
+ 0xe6, 0x2c, 0x91, 0xf5, 0xcc, 0x71, 0xdd, 0x41, 0x12, 0x53, 0xe4, 0x33, 0x8f, 0xb4, 0x72, 0x4b,
+ 0x31, 0x56, 0xcd, 0xba, 0x00, 0xfb, 0x19, 0x86, 0x06, 0x34, 0x24, 0x29, 0x60, 0x09, 0x3f, 0x1f,
+ 0xc4, 0xc3, 0x20, 0x24, 0xad, 0x22, 0x79, 0xaa, 0xc0, 0x7b, 0x02, 0xb6, 0x04, 0x8a, 0x3b, 0x70,
+ 0x67, 0xda, 0xa4, 0x9c, 0x9b, 0x56, 0x95, 0x7d, 0x68, 0xb3, 0x7d, 0x4c, 0xe7, 0x6a, 0xaa, 0x13,
+ 0x01, 0xf2, 0xae, 0x7f, 0x05, 0x35, 0x1f, 0x5c, 0x8a, 0x17, 0x45, 0x29, 0x37, 0x12, 0xd5, 0x84,
+ 0xd5, 0x89, 0x9e, 0x74, 0x2f, 0x93, 0x3b, 0x3e, 0x84, 0x5a, 0x51, 0xc6, 0x8a, 0x7c, 0x86, 0x60,
+ 0x22, 0x41, 0x3f, 0x84, 0x2d, 0x8b, 0x47, 0xc4, 0x3c, 0xc7, 0xb7, 0x0f, 0xfc, 0x30, 0xe1, 0xbb,
+ 0xcc, 0x75, 0xf3, 0x25, 0x2e, 0xdb, 0x8a, 0x7e, 0x0a, 0xcd, 0x45, 0xd9, 0x32, 0x65, 0xaf, 0xe0,
+ 0x3e, 0xb3, 0xed, 0x88, 0x6c, 0xc6, 0x69, 0x34, 0xc8, 0x62, 0xd2, 0xed, 0xa6, 0x36, 0xdb, 0x9c,
+ 0x3e, 0x67, 0xa9, 0xc5, 0x9a, 0xf5, 0x03, 0xc0, 0x3c, 0xc7, 0x09, 0x8b, 0x98, 0x47, 0x9c, 0x22,
+ 0xe9, 0xd0, 0x42, 0xa8, 0x3c, 0x0b, 0xb9, 0x8e, 0xcf, 0x29, 0xba, 0x60, 0x62, 0xc7, 0x99, 0x67,
+ 0x20, 0x87, 0xfa, 0xb1, 0xfe, 0xad, 0x54, 0xe8, 0xb0, 0x97, 0xf0, 0x39, 0xc1, 0x7f, 0xea, 0xda,
+ 0x8f, 0xb0, 0x31, 0x89, 0x0f, 0x27, 0xad, 0x6a, 0xa5, 0xd6, 0x8a, 0x51, 0xeb, 0xb4, 0x66, 0xb3,
+ 0x5c, 0x95, 0x64, 0x62, 0x74, 0x55, 0xe6, 0xd2, 0x1e, 0xff, 0x0b, 0xa6, 0x3c, 0x86, 0xed, 0x85,
+ 0x43, 0xfa, 0x4d, 0x87, 0x3e, 0x7d, 0x07, 0xb5, 0xc2, 0xcc, 0xb0, 0x01, 0xf5, 0xdd, 0xde, 0xd1,
+ 0x89, 0xd9, 0xb5, 0xac, 0x9d, 0xf7, 0x87, 0xdd, 0xc6, 0x2d, 0x44, 0x50, 0xfb, 0xc7, 0x33, 0x98,
+ 0x82, 0x00, 0xb7, 0xcd, 0x9d, 0xe3, 0xbd, 0xde, 0x51, 0xa3, 0xd4, 0xf9, 0x51, 0x86, 0xda, 0x29,
+ 0xc5, 0xdc, 0xa2, 0xe8, 0xc2, 0x19, 0x12, 0xbe, 0x84, 0x35, 0xf9, 0x0b, 0x14, 0x6d, 0xe1, 0xc6,
+ 0x9c, 0x2e, 0xf1, 0xd0, 0x5c, 0x04, 0xe2, 0x3e, 0xac, 0xf5, 0x7d, 0x16, 0xa5, 0x61, 0xdb, 0xb3,
+ 0x8c, 0x99, 0xdf, 0x57, 0xf3, 0xc1, 0xe2, 0xc7, 0x6c, 0x00, 0x2e, 0x6c, 0x2c, 0x98, 0x0f, 0x1a,
+ 0x73, 0x41, 0xd7, 0xfa, 0xac, 0xf9, 0xe4, 0x06, 0xcc, 0xb4, 0xd6, 0x0b, 0x05, 0x1d, 0xc0, 0xab,
+ 0x1f, 0x15, 0x3e, 0xbe, 0x26, 0xc5, 0xfc, 0x47, 0xdc, 0x34, 0x7e, 0x4d, 0x4c, 0x4b, 0x19, 0xa2,
+ 0x94, 0xba, 0x9f, 0xb8, 0xee, 0x5e, 0x12, 0xba, 0xf4, 0xe5, 0x9f, 0x69, 0x32, 0x14, 0xa9, 0x4a,
+ 0xfd, 0xc0, 0xdc, 0xb3, 0xff, 0x50, 0xaa, 0xd3, 0x87, 0x7b, 0x7d, 0x5f, 0x6e, 0xd0, 0x23, 0x9f,
+ 0xd3, 0x28, 0x77, 0xd1, 0x1b, 0xb8, 0x3b, 0x83, 0x2f, 0xe7, 0xa6, 0x9f, 0x01, 0x00, 0x00, 0xff,
+ 0xff, 0xdd, 0xb5, 0x50, 0x6f, 0xa2, 0x07, 0x00, 0x00,
}
diff --git a/interop/grpc_testing/test.proto b/interop/grpc_testing/test.proto
index d4ce2c1..cc2bb74 100644
--- a/interop/grpc_testing/test.proto
+++ b/interop/grpc_testing/test.proto
@@ -151,3 +151,10 @@
rpc HalfDuplexCall(stream StreamingOutputCallRequest)
returns (stream StreamingOutputCallResponse);
}
+
+// A simple service NOT implemented at servers so clients can test for
+// that case.
+service UnimplementedService {
+ // A call that no server should implement
+ rpc UnimplementedCall(grpc.testing.Empty) returns (grpc.testing.Empty);
+}
diff --git a/interop/test_utils.go b/interop/test_utils.go
index 908dd8d..e9d734e 100644
--- a/interop/test_utils.go
+++ b/interop/test_utils.go
@@ -52,10 +52,12 @@
)
var (
- reqSizes = []int{27182, 8, 1828, 45904}
- respSizes = []int{31415, 9, 2653, 58979}
- largeReqSize = 271828
- largeRespSize = 314159
+ reqSizes = []int{27182, 8, 1828, 45904}
+ respSizes = []int{31415, 9, 2653, 58979}
+ largeReqSize = 271828
+ largeRespSize = 314159
+ initialMetadataKey = "x-grpc-test-echo-initial"
+ trailingMetadataKey = "x-grpc-test-echo-trailing-bin"
)
func clientNewPayload(t testpb.PayloadType, size int) *testpb.Payload {
@@ -252,7 +254,9 @@
Payload: pl,
}
if err := stream.Send(req); err != nil {
- grpclog.Fatalf("%v.Send(%v) = %v", stream, req, err)
+ if grpc.Code(err) != codes.DeadlineExceeded {
+ grpclog.Fatalf("%v.Send(_) = %v", stream, err)
+ }
}
if _, err := stream.Recv(); grpc.Code(err) != codes.DeadlineExceeded {
grpclog.Fatalf("%v.Recv() = _, %v, want error code %d", stream, err, codes.DeadlineExceeded)
@@ -454,6 +458,92 @@
}
}
+var (
+ initialMetadataValue = "test_initial_metadata_value"
+ trailingMetadataValue = "\x0a\x0b\x0a\x0b\x0a\x0b"
+ customMetadata = metadata.Pairs(
+ initialMetadataKey, initialMetadataValue,
+ trailingMetadataKey, trailingMetadataValue,
+ )
+)
+
+func validateMetadata(header, trailer metadata.MD) {
+ if len(header[initialMetadataKey]) != 1 {
+ grpclog.Fatalf("Expected exactly one header from server. Received %d", len(header[initialMetadataKey]))
+ }
+ if header[initialMetadataKey][0] != initialMetadataValue {
+ grpclog.Fatalf("Got header %s; want %s", header[initialMetadataKey][0], initialMetadataValue)
+ }
+ if len(trailer[trailingMetadataKey]) != 1 {
+ grpclog.Fatalf("Expected exactly one trailer from server. Received %d", len(trailer[trailingMetadataKey]))
+ }
+ if trailer[trailingMetadataKey][0] != trailingMetadataValue {
+ grpclog.Fatalf("Got trailer %s; want %s", trailer[trailingMetadataKey][0], trailingMetadataValue)
+ }
+}
+
+// DoCustomMetadata checks that metadata is echoed back to the client.
+func DoCustomMetadata(tc testpb.TestServiceClient) {
+ // Testing with UnaryCall.
+ pl := clientNewPayload(testpb.PayloadType_COMPRESSABLE, 1)
+ req := &testpb.SimpleRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+ ResponseSize: proto.Int32(int32(1)),
+ Payload: pl,
+ }
+ ctx := metadata.NewContext(context.Background(), customMetadata)
+ var header, trailer metadata.MD
+ reply, err := tc.UnaryCall(
+ ctx,
+ req,
+ grpc.Header(&header),
+ grpc.Trailer(&trailer),
+ )
+ if err != nil {
+ grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err)
+ }
+ t := reply.GetPayload().GetType()
+ s := len(reply.GetPayload().GetBody())
+ if t != testpb.PayloadType_COMPRESSABLE || s != 1 {
+ grpclog.Fatalf("Got the reply with type %d len %d; want %d, %d", t, s, testpb.PayloadType_COMPRESSABLE, 1)
+ }
+ validateMetadata(header, trailer)
+
+ // Testing with FullDuplex.
+ stream, err := tc.FullDuplexCall(ctx)
+ if err != nil {
+ grpclog.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+ }
+ respParam := []*testpb.ResponseParameters{
+ {
+ Size: proto.Int32(1),
+ },
+ }
+ streamReq := &testpb.StreamingOutputCallRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+ ResponseParameters: respParam,
+ Payload: pl,
+ }
+ if err := stream.Send(streamReq); err != nil {
+ grpclog.Fatalf("%v.Send(%v) = %v", stream, streamReq, err)
+ }
+ streamHeader, err := stream.Header()
+ if err != nil {
+ grpclog.Fatalf("%v.Header() = %v", stream, err)
+ }
+ if _, err := stream.Recv(); err != nil {
+ grpclog.Fatalf("%v.Recv() = %v", stream, err)
+ }
+ if err := stream.CloseSend(); err != nil {
+ grpclog.Fatalf("%v.CloseSend() = %v, want <nil>", stream, err)
+ }
+ if _, err := stream.Recv(); err != io.EOF {
+ grpclog.Fatalf("%v failed to complete the custom metadata test: %v", stream, err)
+ }
+ streamTrailer := stream.Trailer()
+ validateMetadata(streamHeader, streamTrailer)
+}
+
// DoStatusCodeAndMessage checks that the status code is propagated back to the client.
func DoStatusCodeAndMessage(tc testpb.TestServiceClient) {
var code int32 = 2
@@ -489,6 +579,22 @@
}
}
+// DoUnimplementedService attempts to call a method from an unimplemented service.
+func DoUnimplementedService(tc testpb.UnimplementedServiceClient) {
+ _, err := tc.UnimplementedCall(context.Background(), &testpb.Empty{})
+ if grpc.Code(err) != codes.Unimplemented {
+ grpclog.Fatalf("%v.UnimplementedCall() = _, %v, want _, %v", tc, grpc.Code(err), codes.Unimplemented)
+ }
+}
+
+// DoUnimplementedMethod attempts to call an unimplemented method.
+func DoUnimplementedMethod(cc *grpc.ClientConn) {
+ var req, reply proto.Message
+ if err := grpc.Invoke(context.Background(), "/grpc.testing.TestService/UnimplementedCall", req, reply, cc); err == nil || grpc.Code(err) != codes.Unimplemented {
+ grpclog.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented)
+ }
+}
+
type testServer struct {
}
@@ -521,6 +627,16 @@
func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
status := in.GetResponseStatus()
+ if md, ok := metadata.FromContext(ctx); ok {
+ if initialMetadata, ok := md[initialMetadataKey]; ok {
+ header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
+ grpc.SendHeader(ctx, header)
+ }
+ if trailingMetadata, ok := md[trailingMetadataKey]; ok {
+ trailer := metadata.Pairs(trailingMetadataKey, trailingMetadata[0])
+ grpc.SetTrailer(ctx, trailer)
+ }
+ }
if status != nil && *status.Code != 0 {
return nil, grpc.Errorf(codes.Code(*status.Code), *status.Message)
}
@@ -570,6 +686,16 @@
}
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
+ if md, ok := metadata.FromContext(stream.Context()); ok {
+ if initialMetadata, ok := md[initialMetadataKey]; ok {
+ header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
+ stream.SendHeader(header)
+ }
+ if trailingMetadata, ok := md[trailingMetadataKey]; ok {
+ trailer := metadata.Pairs(trailingMetadataKey, trailingMetadata[0])
+ stream.SetTrailer(trailer)
+ }
+ }
for {
in, err := stream.Recv()
if err == io.EOF {
diff --git a/metadata/metadata.go b/metadata/metadata.go
index 3c0ca7a..65dc5af 100644
--- a/metadata/metadata.go
+++ b/metadata/metadata.go
@@ -141,6 +141,8 @@
}
// FromContext returns the MD in ctx if it exists.
+// The returned md should be immutable, writing to it may cause races.
+// Modification should be made to the copies of the returned md.
func FromContext(ctx context.Context) (md MD, ok bool) {
md, ok = ctx.Value(mdKey{}).(MD)
return
diff --git a/reflection/grpc_reflection_v1alpha/reflection.pb.go b/reflection/grpc_reflection_v1alpha/reflection.pb.go
index da90479..76987a4 100644
--- a/reflection/grpc_reflection_v1alpha/reflection.pb.go
+++ b/reflection/grpc_reflection_v1alpha/reflection.pb.go
@@ -544,7 +544,7 @@
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion3
+const _ = grpc.SupportPackageIsVersion4
// Client API for ServerReflection service
@@ -643,7 +643,7 @@
ClientStreams: true,
},
},
- Metadata: fileDescriptor0,
+ Metadata: "reflection.proto",
}
func init() { proto.RegisterFile("reflection.proto", fileDescriptor0) }
diff --git a/reflection/grpc_testing/test.pb.go b/reflection/grpc_testing/test.pb.go
index add7abd..607dfd3 100644
--- a/reflection/grpc_testing/test.pb.go
+++ b/reflection/grpc_testing/test.pb.go
@@ -66,7 +66,7 @@
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion3
+const _ = grpc.SupportPackageIsVersion4
// Client API for SearchService service
@@ -195,7 +195,7 @@
ClientStreams: true,
},
},
- Metadata: fileDescriptor2,
+ Metadata: "test.proto",
}
func init() { proto.RegisterFile("test.proto", fileDescriptor2) }
diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go
index 686090a..d26eac3 100644
--- a/reflection/serverreflection.go
+++ b/reflection/serverreflection.go
@@ -119,11 +119,11 @@
func decompress(b []byte) ([]byte, error) {
r, err := gzip.NewReader(bytes.NewReader(b))
if err != nil {
- return nil, fmt.Errorf("bad gzipped descriptor: %v\n", err)
+ return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
}
out, err := ioutil.ReadAll(r)
if err != nil {
- return nil, fmt.Errorf("bad gzipped descriptor: %v\n", err)
+ return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
}
return out, nil
}
@@ -251,11 +251,12 @@
}
// Metadata not valid.
- enc, ok := meta.([]byte)
+ fileNameForMeta, ok := meta.(string)
if !ok {
return nil, fmt.Errorf("invalid file descriptor for symbol: %v", name)
}
+ enc := proto.FileDescriptor(fileNameForMeta)
fd, err = s.decodeFileDesc(enc)
if err != nil {
return nil, err
diff --git a/reflection/serverreflection_test.go b/reflection/serverreflection_test.go
index ca9610e..1759e66 100644
--- a/reflection/serverreflection_test.go
+++ b/reflection/serverreflection_test.go
@@ -192,6 +192,9 @@
c := rpb.NewServerReflectionClient(conn)
stream, err := c.ServerReflectionInfo(context.Background())
+ if err != nil {
+ t.Fatalf("cannot get ServerReflectionInfo: %v", err)
+ }
testFileByFilename(t, stream)
testFileByFilenameError(t, stream)
diff --git a/rpc_util.go b/rpc_util.go
index 6b60095..66d08b5 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -42,11 +42,13 @@
"io/ioutil"
"math"
"os"
+ "time"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/stats"
"google.golang.org/grpc/transport"
)
@@ -255,9 +257,11 @@
// encode serializes msg and prepends the message header. If msg is nil, it
// generates the message header of 0 message length.
-func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer) ([]byte, error) {
- var b []byte
- var length uint
+func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, error) {
+ var (
+ b []byte
+ length uint
+ )
if msg != nil {
var err error
// TODO(zhaoq): optimize to reduce memory alloc and copying.
@@ -265,6 +269,12 @@
if err != nil {
return nil, err
}
+ if outPayload != nil {
+ outPayload.Payload = msg
+ // TODO truncate large payload.
+ outPayload.Data = b
+ outPayload.Length = len(b)
+ }
if cp != nil {
if err := cp.Do(cbuf, b); err != nil {
return nil, err
@@ -295,6 +305,10 @@
// Copy encoded msg to buf
copy(buf[5:], b)
+ if outPayload != nil {
+ outPayload.WireLength = len(buf)
+ }
+
return buf, nil
}
@@ -311,11 +325,14 @@
return nil
}
-func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
+func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int, inPayload *stats.InPayload) error {
pf, d, err := p.recvMsg(maxMsgSize)
if err != nil {
return err
}
+ if inPayload != nil {
+ inPayload.WireLength = len(d)
+ }
if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil {
return err
}
@@ -333,6 +350,13 @@
if err := c.Unmarshal(d, m); err != nil {
return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
}
+ if inPayload != nil {
+ inPayload.RecvTime = time.Now()
+ inPayload.Payload = m
+ // TODO truncate large payload.
+ inPayload.Data = d
+ inPayload.Length = len(d)
+ }
return nil
}
@@ -448,10 +472,10 @@
return codes.Unknown
}
-// SupportPackageIsVersion3 is referenced from generated protocol buffer files
+// SupportPackageIsVersion4 is referenced from generated protocol buffer files
// to assert that that code is compatible with this version of the grpc package.
//
// This constant may be renamed in the future if a change in the generated code
// requires a synchronised update of grpc-go and protoc-gen-go. This constant
// should not be referenced from any other code.
-const SupportPackageIsVersion3 = true
+const SupportPackageIsVersion4 = true
diff --git a/rpc_util_test.go b/rpc_util_test.go
index 0ba2d44..375e42b 100644
--- a/rpc_util_test.go
+++ b/rpc_util_test.go
@@ -114,7 +114,7 @@
}{
{nil, nil, []byte{0, 0, 0, 0, 0}, nil},
} {
- b, err := encode(protoCodec{}, test.msg, nil, nil)
+ b, err := encode(protoCodec{}, test.msg, nil, nil, nil)
if err != test.err || !bytes.Equal(b, test.b) {
t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, b, err, test.b, test.err)
}
@@ -199,12 +199,12 @@
// bytes.
func bmEncode(b *testing.B, mSize int) {
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
- encoded, _ := encode(protoCodec{}, msg, nil, nil)
+ encoded, _ := encode(protoCodec{}, msg, nil, nil, nil)
encodedSz := int64(len(encoded))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
- encode(protoCodec{}, msg, nil, nil)
+ encode(protoCodec{}, msg, nil, nil, nil)
}
b.SetBytes(encodedSz)
}
diff --git a/server.go b/server.go
index e0bb187..3af001a 100644
--- a/server.go
+++ b/server.go
@@ -54,6 +54,8 @@
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/stats"
+ "google.golang.org/grpc/tap"
"google.golang.org/grpc/transport"
)
@@ -110,6 +112,7 @@
maxMsgSize int
unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor
+ inTapHandle tap.ServerInHandle
maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server
}
@@ -186,6 +189,17 @@
}
}
+// InTapHandle returns a ServerOption that sets the tap handle for all the server
+// transport to be created. Only one can be installed.
+func InTapHandle(h tap.ServerInHandle) ServerOption {
+ return func(o *options) {
+ if o.inTapHandle != nil {
+ panic("The tap handle has been set.")
+ }
+ o.inTapHandle = h
+ }
+}
+
// NewServer creates a gRPC server which has no service registered and has not
// started to accept requests yet.
func NewServer(opt ...ServerOption) *Server {
@@ -412,17 +426,22 @@
if s.opts.useHandlerImpl {
s.serveUsingHandler(conn)
} else {
- s.serveNewHTTP2Transport(conn, authInfo)
+ s.serveHTTP2Transport(conn, authInfo)
}
}
-// serveNewHTTP2Transport sets up a new http/2 transport (using the
+// serveHTTP2Transport sets up a http/2 transport (using the
// gRPC http2 server transport in transport/http2_server.go) and
// serves streams on it.
// This is run in its own goroutine (it does network I/O in
// transport.NewServerTransport).
-func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
- st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo)
+func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
+ config := &transport.ServerConfig{
+ MaxStreams: s.opts.maxConcurrentStreams,
+ AuthInfo: authInfo,
+ InTapHandle: s.opts.inTapHandle,
+ }
+ st, err := transport.NewServerTransport("http2", c, config)
if err != nil {
s.mu.Lock()
s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
@@ -448,6 +467,12 @@
defer wg.Done()
s.handleStream(st, stream, s.traceInfo(st, stream))
}()
+ }, func(ctx context.Context, method string) context.Context {
+ if !EnableTracing {
+ return ctx
+ }
+ tr := trace.New("grpc.Recv."+methodFamily(method), method)
+ return trace.NewContext(ctx, tr)
})
wg.Wait()
}
@@ -497,15 +522,17 @@
// traceInfo returns a traceInfo and associates it with stream, if tracing is enabled.
// If tracing is not enabled, it returns nil.
func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) {
- if !EnableTracing {
+ tr, ok := trace.FromContext(stream.Context())
+ if !ok {
return nil
}
+
trInfo = &traceInfo{
- tr: trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()),
+ tr: tr,
}
trInfo.firstLine.client = false
trInfo.firstLine.remoteAddr = st.RemoteAddr()
- stream.TraceContext(trInfo.tr)
+
if dl, ok := stream.Context().Deadline(); ok {
trInfo.firstLine.deadline = dl.Sub(time.Now())
}
@@ -532,11 +559,17 @@
}
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error {
- var cbuf *bytes.Buffer
+ var (
+ cbuf *bytes.Buffer
+ outPayload *stats.OutPayload
+ )
if cp != nil {
cbuf = new(bytes.Buffer)
}
- p, err := encode(s.opts.codec, msg, cp, cbuf)
+ if stats.On() {
+ outPayload = &stats.OutPayload{}
+ }
+ p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
if err != nil {
// This typically indicates a fatal issue (e.g., memory
// corruption or hardware faults) the application program
@@ -547,10 +580,32 @@
// the optimal option.
grpclog.Fatalf("grpc: Server failed to encode response %v", err)
}
- return t.Write(stream, p, opts)
+ err = t.Write(stream, p, opts)
+ if err == nil && outPayload != nil {
+ outPayload.SentTime = time.Now()
+ stats.Handle(stream.Context(), outPayload)
+ }
+ return err
}
func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
+ if stats.On() {
+ begin := &stats.Begin{
+ BeginTime: time.Now(),
+ }
+ stats.Handle(stream.Context(), begin)
+ }
+ defer func() {
+ if stats.On() {
+ end := &stats.End{
+ EndTime: time.Now(),
+ }
+ if err != nil && err != io.EOF {
+ end.Error = toRPCErr(err)
+ }
+ stats.Handle(stream.Context(), end)
+ }
+ }()
if trInfo != nil {
defer trInfo.tr.Finish()
trInfo.firstLine.client = false
@@ -579,14 +634,14 @@
if err != nil {
switch err := err.(type) {
case *rpcError:
- if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
+ if e := t.WriteStatus(stream, err.code, err.desc); e != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
- if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
+ if e := t.WriteStatus(stream, err.Code, err.Desc); e != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err))
@@ -597,20 +652,29 @@
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
switch err := err.(type) {
case *rpcError:
- if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
+ if e := t.WriteStatus(stream, err.code, err.desc); e != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
+ return err
default:
- if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
+ if e := t.WriteStatus(stream, codes.Internal, err.Error()); e != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
-
+ // TODO checkRecvPayload always return RPC error. Add a return here if necessary.
}
- return err
+ }
+ var inPayload *stats.InPayload
+ if stats.On() {
+ inPayload = &stats.InPayload{
+ RecvTime: time.Now(),
+ }
}
statusCode := codes.OK
statusDesc := ""
df := func(v interface{}) error {
+ if inPayload != nil {
+ inPayload.WireLength = len(req)
+ }
if pf == compressionMade {
var err error
req, err = s.opts.dc.Do(bytes.NewReader(req))
@@ -618,7 +682,7 @@
if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
- return err
+ return Errorf(codes.Internal, err.Error())
}
}
if len(req) > s.opts.maxMsgSize {
@@ -630,6 +694,12 @@
if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err
}
+ if inPayload != nil {
+ inPayload.Payload = v
+ inPayload.Data = req
+ inPayload.Length = len(req)
+ stats.Handle(stream.Context(), inPayload)
+ }
if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
}
@@ -650,9 +720,8 @@
}
if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
- return err
}
- return nil
+ return Errorf(statusCode, statusDesc)
}
if trInfo != nil {
trInfo.tr.LazyLog(stringer("OK"), false)
@@ -677,11 +746,32 @@
if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
}
- return t.WriteStatus(stream, statusCode, statusDesc)
+ errWrite := t.WriteStatus(stream, statusCode, statusDesc)
+ if statusCode != codes.OK {
+ return Errorf(statusCode, statusDesc)
+ }
+ return errWrite
}
}
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
+ if stats.On() {
+ begin := &stats.Begin{
+ BeginTime: time.Now(),
+ }
+ stats.Handle(stream.Context(), begin)
+ }
+ defer func() {
+ if stats.On() {
+ end := &stats.End{
+ EndTime: time.Now(),
+ }
+ if err != nil && err != io.EOF {
+ end.Error = toRPCErr(err)
+ }
+ stats.Handle(stream.Context(), end)
+ }
+ }()
if s.opts.cp != nil {
stream.SetSendCompress(s.opts.cp.Type())
}
@@ -744,7 +834,11 @@
}
ss.mu.Unlock()
}
- return t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc)
+ errWrite := t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc)
+ if ss.statusCode != codes.OK {
+ return Errorf(ss.statusCode, ss.statusDesc)
+ }
+ return errWrite
}
@@ -759,7 +853,8 @@
trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []interface{}{sm}}, true)
trInfo.tr.SetError()
}
- if err := t.WriteStatus(stream, codes.InvalidArgument, fmt.Sprintf("malformed method name: %q", stream.Method())); err != nil {
+ errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
+ if err := t.WriteStatus(stream, codes.InvalidArgument, errDesc); err != nil {
if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError()
@@ -779,7 +874,8 @@
trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true)
trInfo.tr.SetError()
}
- if err := t.WriteStatus(stream, codes.Unimplemented, fmt.Sprintf("unknown service %v", service)); err != nil {
+ errDesc := fmt.Sprintf("unknown service %v", service)
+ if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil {
if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError()
@@ -804,7 +900,8 @@
trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true)
trInfo.tr.SetError()
}
- if err := t.WriteStatus(stream, codes.Unimplemented, fmt.Sprintf("unknown method %v", method)); err != nil {
+ errDesc := fmt.Sprintf("unknown method %v", method)
+ if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil {
if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError()
diff --git a/stats/grpc_testing/test.pb.go b/stats/grpc_testing/test.pb.go
new file mode 100644
index 0000000..b24dcd8
--- /dev/null
+++ b/stats/grpc_testing/test.pb.go
@@ -0,0 +1,225 @@
+// Code generated by protoc-gen-go.
+// source: test.proto
+// DO NOT EDIT!
+
+/*
+Package grpc_testing is a generated protocol buffer package.
+
+It is generated from these files:
+ test.proto
+
+It has these top-level messages:
+ SimpleRequest
+ SimpleResponse
+*/
+package grpc_testing
+
+import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
+import math "math"
+
+import (
+ context "golang.org/x/net/context"
+ grpc "google.golang.org/grpc"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
+
+// Unary request.
+type SimpleRequest struct {
+ Id int32 `protobuf:"varint,2,opt,name=id" json:"id,omitempty"`
+}
+
+func (m *SimpleRequest) Reset() { *m = SimpleRequest{} }
+func (m *SimpleRequest) String() string { return proto.CompactTextString(m) }
+func (*SimpleRequest) ProtoMessage() {}
+func (*SimpleRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
+
+// Unary response, as configured by the request.
+type SimpleResponse struct {
+ Id int32 `protobuf:"varint,3,opt,name=id" json:"id,omitempty"`
+}
+
+func (m *SimpleResponse) Reset() { *m = SimpleResponse{} }
+func (m *SimpleResponse) String() string { return proto.CompactTextString(m) }
+func (*SimpleResponse) ProtoMessage() {}
+func (*SimpleResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
+
+func init() {
+ proto.RegisterType((*SimpleRequest)(nil), "grpc.testing.SimpleRequest")
+ proto.RegisterType((*SimpleResponse)(nil), "grpc.testing.SimpleResponse")
+}
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ context.Context
+var _ grpc.ClientConn
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the grpc package it is being compiled against.
+const _ = grpc.SupportPackageIsVersion4
+
+// Client API for TestService service
+
+type TestServiceClient interface {
+ // One request followed by one response.
+ // The server returns the client id as-is.
+ UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error)
+ // A sequence of requests with each request served by the server immediately.
+ // As one request could lead to multiple responses, this interface
+ // demonstrates the idea of full duplexing.
+ FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error)
+}
+
+type testServiceClient struct {
+ cc *grpc.ClientConn
+}
+
+func NewTestServiceClient(cc *grpc.ClientConn) TestServiceClient {
+ return &testServiceClient{cc}
+}
+
+func (c *testServiceClient) UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error) {
+ out := new(SimpleResponse)
+ err := grpc.Invoke(ctx, "/grpc.testing.TestService/UnaryCall", in, out, c.cc, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+func (c *testServiceClient) FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) {
+ stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[0], c.cc, "/grpc.testing.TestService/FullDuplexCall", opts...)
+ if err != nil {
+ return nil, err
+ }
+ x := &testServiceFullDuplexCallClient{stream}
+ return x, nil
+}
+
+type TestService_FullDuplexCallClient interface {
+ Send(*SimpleRequest) error
+ Recv() (*SimpleResponse, error)
+ grpc.ClientStream
+}
+
+type testServiceFullDuplexCallClient struct {
+ grpc.ClientStream
+}
+
+func (x *testServiceFullDuplexCallClient) Send(m *SimpleRequest) error {
+ return x.ClientStream.SendMsg(m)
+}
+
+func (x *testServiceFullDuplexCallClient) Recv() (*SimpleResponse, error) {
+ m := new(SimpleResponse)
+ if err := x.ClientStream.RecvMsg(m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
+// Server API for TestService service
+
+type TestServiceServer interface {
+ // One request followed by one response.
+ // The server returns the client id as-is.
+ UnaryCall(context.Context, *SimpleRequest) (*SimpleResponse, error)
+ // A sequence of requests with each request served by the server immediately.
+ // As one request could lead to multiple responses, this interface
+ // demonstrates the idea of full duplexing.
+ FullDuplexCall(TestService_FullDuplexCallServer) error
+}
+
+func RegisterTestServiceServer(s *grpc.Server, srv TestServiceServer) {
+ s.RegisterService(&_TestService_serviceDesc, srv)
+}
+
+func _TestService_UnaryCall_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(SimpleRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(TestServiceServer).UnaryCall(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/grpc.testing.TestService/UnaryCall",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(TestServiceServer).UnaryCall(ctx, req.(*SimpleRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+func _TestService_FullDuplexCall_Handler(srv interface{}, stream grpc.ServerStream) error {
+ return srv.(TestServiceServer).FullDuplexCall(&testServiceFullDuplexCallServer{stream})
+}
+
+type TestService_FullDuplexCallServer interface {
+ Send(*SimpleResponse) error
+ Recv() (*SimpleRequest, error)
+ grpc.ServerStream
+}
+
+type testServiceFullDuplexCallServer struct {
+ grpc.ServerStream
+}
+
+func (x *testServiceFullDuplexCallServer) Send(m *SimpleResponse) error {
+ return x.ServerStream.SendMsg(m)
+}
+
+func (x *testServiceFullDuplexCallServer) Recv() (*SimpleRequest, error) {
+ m := new(SimpleRequest)
+ if err := x.ServerStream.RecvMsg(m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
+var _TestService_serviceDesc = grpc.ServiceDesc{
+ ServiceName: "grpc.testing.TestService",
+ HandlerType: (*TestServiceServer)(nil),
+ Methods: []grpc.MethodDesc{
+ {
+ MethodName: "UnaryCall",
+ Handler: _TestService_UnaryCall_Handler,
+ },
+ },
+ Streams: []grpc.StreamDesc{
+ {
+ StreamName: "FullDuplexCall",
+ Handler: _TestService_FullDuplexCall_Handler,
+ ServerStreams: true,
+ ClientStreams: true,
+ },
+ },
+ Metadata: "test.proto",
+}
+
+func init() { proto.RegisterFile("test.proto", fileDescriptor0) }
+
+var fileDescriptor0 = []byte{
+ // 167 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e,
+ 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x49, 0x2f, 0x2a, 0x48, 0xd6, 0x03, 0x09, 0x64,
+ 0xe6, 0xa5, 0x2b, 0xc9, 0x73, 0xf1, 0x06, 0x67, 0xe6, 0x16, 0xe4, 0xa4, 0x06, 0xa5, 0x16, 0x96,
+ 0xa6, 0x16, 0x97, 0x08, 0xf1, 0x71, 0x31, 0x65, 0xa6, 0x48, 0x30, 0x29, 0x30, 0x6a, 0xb0, 0x06,
+ 0x31, 0x65, 0xa6, 0x28, 0x29, 0x70, 0xf1, 0xc1, 0x14, 0x14, 0x17, 0xe4, 0xe7, 0x15, 0xa7, 0x42,
+ 0x55, 0x30, 0xc3, 0x54, 0x18, 0x2d, 0x63, 0xe4, 0xe2, 0x0e, 0x49, 0x2d, 0x2e, 0x09, 0x4e, 0x2d,
+ 0x2a, 0xcb, 0x4c, 0x4e, 0x15, 0x72, 0xe3, 0xe2, 0x0c, 0xcd, 0x4b, 0x2c, 0xaa, 0x74, 0x4e, 0xcc,
+ 0xc9, 0x11, 0x92, 0xd6, 0x43, 0xb6, 0x4e, 0x0f, 0xc5, 0x2e, 0x29, 0x19, 0xec, 0x92, 0x50, 0x7b,
+ 0xfc, 0xb9, 0xf8, 0xdc, 0x4a, 0x73, 0x72, 0x5c, 0x4a, 0x0b, 0x72, 0x52, 0x2b, 0x28, 0x34, 0x4c,
+ 0x83, 0xd1, 0x80, 0x31, 0x89, 0x0d, 0x1c, 0x00, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x8d,
+ 0x82, 0x5b, 0xdd, 0x0e, 0x01, 0x00, 0x00,
+}
diff --git a/stats/grpc_testing/test.proto b/stats/grpc_testing/test.proto
new file mode 100644
index 0000000..54e6f74
--- /dev/null
+++ b/stats/grpc_testing/test.proto
@@ -0,0 +1,23 @@
+syntax = "proto3";
+
+package grpc.testing;
+
+message SimpleRequest {
+ int32 id = 2; // client-chosen id; the test service echoes it back
+}
+
+message SimpleResponse {
+ int32 id = 3; // id of the request being answered
+}
+
+// A simple test service used by the stats package tests.
+service TestService {
+ // One request followed by one response.
+ // The server returns the client id as-is.
+ rpc UnaryCall(SimpleRequest) returns (SimpleResponse);
+
+ // A sequence of requests with each request served by the server immediately.
+ // As one request could lead to multiple responses, this interface
+ // demonstrates the idea of full duplexing.
+ rpc FullDuplexCall(stream SimpleRequest) returns (stream SimpleResponse);
+}
diff --git a/stats/stats.go b/stats/stats.go
new file mode 100644
index 0000000..4b030d9
--- /dev/null
+++ b/stats/stats.go
@@ -0,0 +1,219 @@
+/*
+ *
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+// Package stats is for collecting and reporting various network and RPC stats.
+// This package is for monitoring purposes only. All fields are read-only.
+// All APIs are experimental.
+package stats // import "google.golang.org/grpc/stats"
+
+import (
+ "net"
+ "sync/atomic"
+ "time"
+
+ "golang.org/x/net/context"
+ "google.golang.org/grpc/grpclog"
+)
+
+// RPCStats contains stats information about RPCs.
+// All stats types in this package implement this interface.
+type RPCStats interface {
+ // IsClient reports whether this RPCStats is from the client side.
+ IsClient() bool
+}
+
+// Begin contains stats when an RPC begins.
+// FailFast is only valid if Client is true.
+type Begin struct {
+ // Client is true if this Begin is from client side.
+ Client bool
+ // BeginTime is the time when the RPC begins.
+ BeginTime time.Time
+ // FailFast indicates if this RPC is failfast.
+ FailFast bool
+}
+
+// IsClient reports whether this Begin is from the client side.
+func (s *Begin) IsClient() bool { return s.Client }
+
+// InPayload contains the information for an incoming payload.
+type InPayload struct {
+ // Client is true if this InPayload is from client side.
+ Client bool
+ // Payload is the payload with its original type.
+ Payload interface{}
+ // Data is the serialized message payload.
+ Data []byte
+ // Length is the length of uncompressed data.
+ Length int
+ // WireLength is the length of data on wire (compressed, signed, encrypted).
+ WireLength int
+ // RecvTime is the time when the payload is received.
+ RecvTime time.Time
+}
+
+// IsClient reports whether this InPayload is from the client side.
+func (s *InPayload) IsClient() bool { return s.Client }
+
+// InHeader contains stats when a header is received.
+// FullMethod, addresses and Compression are only valid if Client is false.
+type InHeader struct {
+ // Client is true if this InHeader is from client side.
+ Client bool
+ // WireLength is the wire length of header.
+ WireLength int
+
+ // FullMethod is the full RPC method string, i.e., /package.service/method.
+ FullMethod string
+ // RemoteAddr is the remote address of the corresponding connection.
+ RemoteAddr net.Addr
+ // LocalAddr is the local address of the corresponding connection.
+ LocalAddr net.Addr
+ // Compression is the compression algorithm used for the RPC.
+ Compression string
+}
+
+// IsClient reports whether this InHeader is from the client side.
+func (s *InHeader) IsClient() bool { return s.Client }
+
+// InTrailer contains stats when a trailer is received.
+type InTrailer struct {
+ // Client is true if this InTrailer is from client side.
+ Client bool
+ // WireLength is the wire length of trailer.
+ WireLength int
+}
+
+// IsClient reports whether this InTrailer is from the client side.
+func (s *InTrailer) IsClient() bool { return s.Client }
+
+// OutPayload contains the information for an outgoing payload.
+type OutPayload struct {
+ // Client is true if this OutPayload is from client side.
+ Client bool
+ // Payload is the payload with its original type.
+ Payload interface{}
+ // Data is the serialized message payload.
+ Data []byte
+ // Length is the length of uncompressed data.
+ Length int
+ // WireLength is the length of data on wire (compressed, signed, encrypted).
+ WireLength int
+ // SentTime is the time when the payload is sent.
+ SentTime time.Time
+}
+
+// IsClient reports whether this OutPayload is from the client side.
+func (s *OutPayload) IsClient() bool { return s.Client }
+
+// OutHeader contains stats when a header is sent.
+// FullMethod, addresses and Compression are only valid if Client is true.
+type OutHeader struct {
+ // Client is true if this OutHeader is from client side.
+ Client bool
+ // WireLength is the wire length of header.
+ WireLength int
+
+ // FullMethod is the full RPC method string, i.e., /package.service/method.
+ FullMethod string
+ // RemoteAddr is the remote address of the corresponding connection.
+ RemoteAddr net.Addr
+ // LocalAddr is the local address of the corresponding connection.
+ LocalAddr net.Addr
+ // Compression is the compression algorithm used for the RPC.
+ Compression string
+}
+
+// IsClient reports whether this OutHeader is from the client side.
+func (s *OutHeader) IsClient() bool { return s.Client }
+
+// OutTrailer contains stats when a trailer is sent.
+type OutTrailer struct {
+ // Client is true if this OutTrailer is from client side.
+ Client bool
+ // WireLength is the wire length of trailer.
+ WireLength int
+}
+
+// IsClient reports whether this OutTrailer is from the client side.
+func (s *OutTrailer) IsClient() bool { return s.Client }
+
+// End contains stats when an RPC ends.
+type End struct {
+ // Client is true if this End is from client side.
+ Client bool
+ // EndTime is the time when the RPC ends.
+ EndTime time.Time
+ // Error is the error the RPC ended with, or nil. Its type is gRPC error.
+ Error error
+}
+
+// IsClient reports whether this End is from the client side.
+func (s *End) IsClient() bool { return s.Client }
+
+var (
+ on = new(int32) // 1 after Start (with a handler), 0 after Stop or initially
+ handler func(context.Context, RPCStats) // set by RegisterHandler, invoked by Handle
+)
+
+// On reports whether stats collection is started (Start succeeded, no Stop since).
+func On() bool {
+ return atomic.LoadInt32(on) == 1
+}
+
+// Handle passes s to the registered handler; it is a no-op if none is registered.
+func Handle(ctx context.Context, s RPCStats) {
+ if handler != nil {
+ handler(ctx, s)
+ }
+}
+
+// RegisterHandler registers f to process the stats, overwriting any earlier handler.
+func RegisterHandler(f func(context.Context, RPCStats)) {
+ handler = f
+}
+
+// Start turns on stats collection and reporting; without a registered handler it only logs.
+func Start() {
+ if handler != nil {
+ atomic.StoreInt32(on, 1)
+ return
+ }
+ grpclog.Println("handler is nil when starting stats. Stats is not started")
+}
+
+// Stop stops the stats collection and processing.
+// Stop does not unregister the handler, so a later Start resumes with it.
+func Stop() {
+ atomic.StoreInt32(on, 0)
+}
diff --git a/stats/stats_test.go b/stats/stats_test.go
new file mode 100644
index 0000000..e904810
--- /dev/null
+++ b/stats/stats_test.go
@@ -0,0 +1,1028 @@
+/*
+ *
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package stats_test
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "reflect"
+ "sync"
+ "testing"
+
+ "github.com/golang/protobuf/proto"
+ "golang.org/x/net/context"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/stats"
+ testpb "google.golang.org/grpc/stats/grpc_testing"
+)
+
+func TestStartStop(t *testing.T) {
+ stats.RegisterHandler(nil) // Start must refuse to turn stats on without a handler.
+ stats.Start()
+ if stats.On() {
+ t.Fatal("stats.Start() with nil handler, stats.On() = true, want false")
+ }
+ stats.RegisterHandler(func(context.Context, stats.RPCStats) {})
+ if stats.On() {
+ t.Fatal("after stats.RegisterHandler(), stats.On() = true, want false")
+ }
+ stats.Start()
+ if !stats.On() {
+ t.Fatal("after stats.Start(_), stats.On() = false, want true")
+ }
+ stats.Stop()
+ if stats.On() {
+ t.Fatal("after stats.Stop(), stats.On() = true, want false")
+ }
+}
+
+var (
+ // For headers: the client sends testMetadata and testServer echoes it back.
+ testMetadata = metadata.MD{
+ "key1": []string{"value1"},
+ "key2": []string{"value2"},
+ }
+ // For trailers: testServer sets testTrailerMetadata when the call has metadata.
+ testTrailerMetadata = metadata.MD{
+ "tkey1": []string{"trailerValue1"},
+ "tkey2": []string{"trailerValue2"},
+ }
+ // The id for which the service handler should return error.
+ errorID int32 = 32202
+)
+
+// testServer implements testpb.TestServiceServer, echoing ids and failing errorID.
+type testServer struct{}
+
+func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
+ // Echo incoming metadata as headers and attach the fixed trailers.
+ if md, ok := metadata.FromContext(ctx); ok {
+ if err := grpc.SendHeader(ctx, md); err != nil {
+ return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err)
+ }
+ if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
+ return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
+ }
+ }
+ id := in.Id
+ if id == errorID {
+ return nil, fmt.Errorf("got error id: %v", id)
+ }
+ return &testpb.SimpleResponse{Id: id}, nil
+}
+
+// FullDuplexCall answers each request with its own id until the client closes
+// its send side; a request whose id is errorID ends the stream with an error.
+// If the call carries metadata, it is echoed as headers and trailers are set.
+func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
+ if md, ok := metadata.FromContext(stream.Context()); ok {
+ if err := stream.SendHeader(md); err != nil {
+ return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
+ }
+ stream.SetTrailer(testTrailerMetadata)
+ }
+ for {
+ req, recvErr := stream.Recv()
+ switch {
+ case recvErr == io.EOF:
+ // read done.
+ return nil
+ case recvErr != nil:
+ return recvErr
+ case req.Id == errorID:
+ return fmt.Errorf("got error id: %v", req.Id)
+ }
+
+ if err := stream.Send(&testpb.SimpleResponse{Id: req.Id}); err != nil {
+ return err
+ }
+ }
+}
+
+// test is an end-to-end test. It should be created with the newTest
+// func, modified as needed, and then started with its startServer method.
+// It should be cleaned up with the tearDown method.
+type test struct {
+ t *testing.T
+ compress string // "gzip" enables gzip on both server and client; anything else disables it
+
+ ctx context.Context // valid for life of test, before tearDown
+ cancel context.CancelFunc // cancels ctx; called by tearDown
+
+ testServer testpb.TestServiceServer // nil means none
+ // srv and srvAddr are set once startServer is called.
+ srv *grpc.Server
+ srvAddr string
+
+ cc *grpc.ClientConn // nil until requested via clientConn
+}
+
+func (te *test) tearDown() { // cancel ctx, close the client conn, then stop the server
+ if cancel := te.cancel; cancel != nil {
+ te.cancel = nil
+ cancel()
+ }
+ if cc := te.cc; cc != nil {
+ te.cc = nil
+ cc.Close()
+ }
+ te.srv.Stop() // NOTE(review): assumes startServer ran; te.srv is nil otherwise
+}
+
+// newTest returns a test bound to t. compress selects gzip compression when it
+// is "gzip" and none otherwise. The test owns a cancellable background context
+// that tearDown cancels; adjust fields before calling startServer/clientConn.
+func newTest(t *testing.T, compress string) *test {
+ ctx, cancel := context.WithCancel(context.Background())
+ te := &test{t: t, compress: compress, ctx: ctx, cancel: cancel}
+ return te
+}
+
+// startServer registers ts (when non-nil) on a new gRPC server configured from
+// te.compress and serves it on an ephemeral localhost port, recording the
+// address in te.srvAddr. Callers should defer a call to te.tearDown to clean up.
+func (te *test) startServer(ts testpb.TestServiceServer) {
+ te.testServer = ts
+ lis, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ te.t.Fatalf("Failed to listen: %v", err)
+ }
+ var opts []grpc.ServerOption
+ if te.compress == "gzip" {
+ opts = []grpc.ServerOption{
+ grpc.RPCCompressor(grpc.NewGZIPCompressor()),
+ grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
+ }
+ }
+ te.srv = grpc.NewServer(opts...)
+ if te.testServer != nil {
+ testpb.RegisterTestServiceServer(te.srv, te.testServer)
+ }
+ _, port, err := net.SplitHostPort(lis.Addr().String())
+ if err != nil {
+ te.t.Fatalf("Failed to parse listener address: %v", err)
+ }
+ // The stats checks compare addresses against this exact string, so it uses
+ // the IPv4 loopback literal rather than "localhost".
+ te.srvAddr = "127.0.0.1:" + port
+ go te.srv.Serve(lis)
+}
+
+// clientConn dials te.srvAddr on first use (enabling gzip when te.compress is
+// "gzip") and caches the connection in te.cc; later calls return the cached
+// connection, which tearDown closes.
+func (te *test) clientConn() *grpc.ClientConn {
+ if te.cc != nil {
+ return te.cc
+ }
+ opts := []grpc.DialOption{grpc.WithInsecure()}
+ if te.compress == "gzip" {
+ opts = append(opts, grpc.WithCompressor(grpc.NewGZIPCompressor()))
+ opts = append(opts, grpc.WithDecompressor(grpc.NewGZIPDecompressor()))
+ }
+
+ var err error
+ if te.cc, err = grpc.Dial(te.srvAddr, opts...); err != nil {
+ te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err)
+ }
+ return te.cc
+}
+
+type rpcConfig struct {
+ count int // Number of requests and responses for streaming RPCs.
+ success bool // Whether the RPC should succeed or return error.
+ failfast bool // Passed to the RPC via grpc.FailFast.
+}
+
+func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
+ var (
+ resp *testpb.SimpleResponse
+ req *testpb.SimpleRequest
+ err error
+ )
+ tc := testpb.NewTestServiceClient(te.clientConn())
+ if c.success {
+ req = &testpb.SimpleRequest{Id: errorID + 1}
+ } else {
+ req = &testpb.SimpleRequest{Id: errorID}
+ }
+ ctx := metadata.NewContext(context.Background(), testMetadata)
+
+ resp, err = tc.UnaryCall(ctx, req, grpc.FailFast(c.failfast))
+ if err != nil {
+ return req, resp, err
+ }
+
+ return req, resp, err
+}
+
+func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest, []*testpb.SimpleResponse, error) { // c.count Send/Recv pairs on one stream
+ var (
+ reqs []*testpb.SimpleRequest
+ resps []*testpb.SimpleResponse
+ err error
+ )
+ tc := testpb.NewTestServiceClient(te.clientConn())
+ stream, err := tc.FullDuplexCall(metadata.NewContext(context.Background(), testMetadata), grpc.FailFast(c.failfast))
+ if err != nil {
+ return reqs, resps, err
+ }
+ var startID int32
+ if !c.success { // the first request id becomes errorID, which testServer rejects
+ startID = errorID
+ }
+ for i := 0; i < c.count; i++ {
+ req := &testpb.SimpleRequest{
+ Id: int32(i) + startID,
+ }
+ reqs = append(reqs, req)
+ if err = stream.Send(req); err != nil {
+ return reqs, resps, err
+ }
+ var resp *testpb.SimpleResponse
+ if resp, err = stream.Recv(); err != nil {
+ return reqs, resps, err
+ }
+ resps = append(resps, resp)
+ }
+ if err = stream.CloseSend(); err != nil {
+ return reqs, resps, err
+ }
+ if _, err = stream.Recv(); err != io.EOF { // after CloseSend the server should end the stream
+ return reqs, resps, err
+ }
+ // NOTE(review): on success err is io.EOF here, not nil; the streaming test below appears to rely on it.
+ return reqs, resps, err
+}
+
+type expectedData struct {
+ method string // full method name, e.g. /grpc.testing.TestService/UnaryCall
+ serverAddr string // compared against the server address reported in header stats
+ compression string
+ reqIdx int // next request to check; advanced by the payload checks
+ requests []*testpb.SimpleRequest
+ respIdx int // next response to check; advanced by the payload checks
+ responses []*testpb.SimpleResponse
+ err error // expected End.Error, compared by gRPC code and description
+ failfast bool // expected Begin.FailFast for client-side stats
+}
+
+type gotData struct {
+ ctx context.Context // context passed to the stats handler
+ client bool // whether the stats were recorded on the client side
+ s stats.RPCStats
+}
+
+const ( // NOTE(review): not referenced in the code visible here; confirm use or remove.
+ begin int = iota
+ end
+ inPayload
+ inHeader
+ inTrailer
+ outPayload
+ outHeader
+ outTrailer
+)
+
+func checkBegin(t *testing.T, d *gotData, e *expectedData) {
+ var (
+ ok bool
+ st *stats.Begin
+ )
+ if st, ok = d.s.(*stats.Begin); !ok {
+ t.Fatalf("got %T, want Begin", d.s)
+ }
+ if d.ctx == nil {
+ t.Fatalf("d.ctx = nil, want <non-nil>")
+ }
+ if st.BeginTime.IsZero() {
+ t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
+ }
+ if d.client {
+ if st.FailFast != e.failfast {
+ t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast)
+ }
+ }
+}
+
+func checkInHeader(t *testing.T, d *gotData, e *expectedData) {
+ var (
+ ok bool
+ st *stats.InHeader
+ )
+ if st, ok = d.s.(*stats.InHeader); !ok {
+ t.Fatalf("got %T, want InHeader", d.s)
+ }
+ if d.ctx == nil {
+ t.Fatalf("d.ctx = nil, want <non-nil>")
+ }
+ // TODO check real length, not just > 0.
+ if st.WireLength <= 0 {
+ t.Fatalf("st.Lenght = 0, want > 0")
+ }
+ if !d.client {
+ if st.FullMethod != e.method {
+ t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
+ }
+ if st.LocalAddr.String() != e.serverAddr {
+ t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr)
+ }
+ if st.Compression != e.compression {
+ t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
+ }
+ }
+}
+
+func checkInPayload(t *testing.T, d *gotData, e *expectedData) {
+ var (
+ ok bool
+ st *stats.InPayload
+ )
+ if st, ok = d.s.(*stats.InPayload); !ok {
+ t.Fatalf("got %T, want InPayload", d.s)
+ }
+ if d.ctx == nil {
+ t.Fatalf("d.ctx = nil, want <non-nil>")
+ }
+ if d.client {
+ b, err := proto.Marshal(e.responses[e.respIdx])
+ if err != nil {
+ t.Fatalf("failed to marshal message: %v", err)
+ }
+ if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
+ t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
+ }
+ e.respIdx++
+ if string(st.Data) != string(b) {
+ t.Fatalf("st.Data = %v, want %v", st.Data, b)
+ }
+ if st.Length != len(b) {
+ t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
+ }
+ } else {
+ b, err := proto.Marshal(e.requests[e.reqIdx])
+ if err != nil {
+ t.Fatalf("failed to marshal message: %v", err)
+ }
+ if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
+ t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
+ }
+ e.reqIdx++
+ if string(st.Data) != string(b) {
+ t.Fatalf("st.Data = %v, want %v", st.Data, b)
+ }
+ if st.Length != len(b) {
+ t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
+ }
+ }
+ // TODO check WireLength and ReceivedTime.
+ if st.RecvTime.IsZero() {
+ t.Fatalf("st.ReceivedTime = %v, want <non-zero>", st.RecvTime)
+ }
+}
+
+func checkInTrailer(t *testing.T, d *gotData, e *expectedData) {
+ var (
+ ok bool
+ st *stats.InTrailer
+ )
+ if st, ok = d.s.(*stats.InTrailer); !ok {
+ t.Fatalf("got %T, want InTrailer", d.s)
+ }
+ if d.ctx == nil {
+ t.Fatalf("d.ctx = nil, want <non-nil>")
+ }
+ // TODO check real length, not just > 0.
+ if st.WireLength <= 0 {
+ t.Fatalf("st.Lenght = 0, want > 0")
+ }
+}
+
+func checkOutHeader(t *testing.T, d *gotData, e *expectedData) {
+ var (
+ ok bool
+ st *stats.OutHeader
+ )
+ if st, ok = d.s.(*stats.OutHeader); !ok {
+ t.Fatalf("got %T, want OutHeader", d.s)
+ }
+ if d.ctx == nil {
+ t.Fatalf("d.ctx = nil, want <non-nil>")
+ }
+ // TODO check real length, not just > 0.
+ if st.WireLength <= 0 {
+ t.Fatalf("st.Lenght = 0, want > 0")
+ }
+ if d.client {
+ if st.FullMethod != e.method {
+ t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
+ }
+ if st.RemoteAddr.String() != e.serverAddr {
+ t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr)
+ }
+ if st.Compression != e.compression {
+ t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
+ }
+ }
+}
+
+func checkOutPayload(t *testing.T, d *gotData, e *expectedData) {
+ var (
+ ok bool
+ st *stats.OutPayload
+ )
+ if st, ok = d.s.(*stats.OutPayload); !ok {
+ t.Fatalf("got %T, want OutPayload", d.s)
+ }
+ if d.ctx == nil {
+ t.Fatalf("d.ctx = nil, want <non-nil>")
+ }
+ if d.client {
+ b, err := proto.Marshal(e.requests[e.reqIdx])
+ if err != nil {
+ t.Fatalf("failed to marshal message: %v", err)
+ }
+ if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
+ t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
+ }
+ e.reqIdx++
+ if string(st.Data) != string(b) {
+ t.Fatalf("st.Data = %v, want %v", st.Data, b)
+ }
+ if st.Length != len(b) {
+ t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
+ }
+ } else {
+ b, err := proto.Marshal(e.responses[e.respIdx])
+ if err != nil {
+ t.Fatalf("failed to marshal message: %v", err)
+ }
+ if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
+ t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
+ }
+ e.respIdx++
+ if string(st.Data) != string(b) {
+ t.Fatalf("st.Data = %v, want %v", st.Data, b)
+ }
+ if st.Length != len(b) {
+ t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
+ }
+ }
+ // TODO check WireLength and ReceivedTime.
+ if st.SentTime.IsZero() {
+ t.Fatalf("st.SentTime = %v, want <non-zero>", st.SentTime)
+ }
+}
+
+func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) {
+ var (
+ ok bool
+ st *stats.OutTrailer
+ )
+ if st, ok = d.s.(*stats.OutTrailer); !ok {
+ t.Fatalf("got %T, want OutTrailer", d.s)
+ }
+ if d.ctx == nil {
+ t.Fatalf("d.ctx = nil, want <non-nil>")
+ }
+ if st.Client {
+ t.Fatalf("st IsClient = true, want false")
+ }
+ // TODO check real length, not just > 0.
+ if st.WireLength <= 0 {
+ t.Fatalf("st.Lenght = 0, want > 0")
+ }
+}
+
+func checkEnd(t *testing.T, d *gotData, e *expectedData) {
+ var (
+ ok bool
+ st *stats.End
+ )
+ if st, ok = d.s.(*stats.End); !ok {
+ t.Fatalf("got %T, want End", d.s)
+ }
+ if d.ctx == nil {
+ t.Fatalf("d.ctx = nil, want <non-nil>")
+ }
+ if st.EndTime.IsZero() {
+ t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime)
+ }
+ if grpc.Code(st.Error) != grpc.Code(e.err) || grpc.ErrorDesc(st.Error) != grpc.ErrorDesc(e.err) {
+ t.Fatalf("st.Error = %v, want %v", st.Error, e.err)
+ }
+}
+
+// TestServerStatsUnaryRPC checks the server-side stats events of a successful unary RPC.
+func TestServerStatsUnaryRPC(t *testing.T) {
+ var got []*gotData
+ stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+ if !s.IsClient() {
+ got = append(got, &gotData{ctx, false, s})
+ }
+ })
+ stats.Start()
+ defer stats.Stop()
+
+ te := newTest(t, "")
+ te.startServer(&testServer{})
+ defer te.tearDown()
+
+ req, resp, err := te.doUnaryCall(&rpcConfig{success: true})
+ if err != nil {
+ t.Fatal(err) // not Fatalf: the error text must not be used as a format string
+ }
+ te.srv.GracefulStop() // Wait for the server to stop.
+
+ expect := &expectedData{
+ method: "/grpc.testing.TestService/UnaryCall",
+ serverAddr: te.srvAddr,
+ requests: []*testpb.SimpleRequest{req},
+ responses: []*testpb.SimpleResponse{resp},
+ }
+
+ checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
+ checkInHeader,
+ checkBegin,
+ checkInPayload,
+ checkOutHeader,
+ checkOutPayload,
+ checkOutTrailer,
+ checkEnd,
+ }
+
+ if len(got) != len(checkFuncs) {
+ t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
+ }
+
+ for i := 0; i < len(got)-1; i++ {
+ if got[i].ctx != got[i+1].ctx {
+ t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s)
+ }
+ }
+
+ for i, f := range checkFuncs {
+ f(t, got[i], expect)
+ }
+}
+
+// TestServerStatsUnaryRPCError verifies the server-side stats recorded for a
+// failing unary RPC: no OutPayload is expected, and End must carry the error
+// the client observed.
+func TestServerStatsUnaryRPCError(t *testing.T) {
+	var collected []*gotData
+	stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+		if s.IsClient() {
+			return
+		}
+		collected = append(collected, &gotData{ctx, false, s})
+	})
+	stats.Start()
+	defer stats.Stop()
+
+	tt := newTest(t, "")
+	tt.startServer(&testServer{})
+	defer tt.tearDown()
+
+	req, resp, rpcErr := tt.doUnaryCall(&rpcConfig{success: false})
+	if rpcErr == nil {
+		t.Fatalf("got error <nil>; want <non-nil>")
+	}
+	// Wait for the server to stop before inspecting the collected stats.
+	tt.srv.GracefulStop()
+
+	want := &expectedData{
+		method:     "/grpc.testing.TestService/UnaryCall",
+		serverAddr: tt.srvAddr,
+		requests:   []*testpb.SimpleRequest{req},
+		responses:  []*testpb.SimpleResponse{resp},
+		err:        rpcErr,
+	}
+
+	// Expected stats, in the order they must have been recorded.
+	checks := []func(t *testing.T, d *gotData, e *expectedData){
+		checkInHeader,
+		checkBegin,
+		checkInPayload,
+		checkOutHeader,
+		checkOutTrailer,
+		checkEnd,
+	}
+	if len(collected) != len(checks) {
+		t.Fatalf("got %v stats, want %v stats", len(collected), len(checks))
+	}
+
+	for i := 1; i < len(collected); i++ {
+		prev, cur := collected[i-1], collected[i]
+		if prev.ctx != cur.ctx {
+			t.Fatalf("got different contexts with two stats %T %T", prev.s, cur.s)
+		}
+	}
+
+	for i := range checks {
+		checks[i](t, collected[i], want)
+	}
+}
+
+// TestServerStatsStreamingRPC checks the server-side stats recorded for a
+// successful gzip-compressed full-duplex RPC of count round trips:
+// InHeader, Begin, OutHeader, then count InPayload/OutPayload pairs, then
+// OutTrailer and End, all sharing one context.
+func TestServerStatsStreamingRPC(t *testing.T) {
+	var got []*gotData
+	stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+		if !s.IsClient() {
+			got = append(got, &gotData{ctx, false, s})
+		}
+	})
+	stats.Start()
+	defer stats.Stop()
+
+	te := newTest(t, "gzip")
+	te.startServer(&testServer{})
+	defer te.tearDown()
+
+	count := 5
+	reqs, resps, err := te.doFullDuplexCallRoundtrip(&rpcConfig{count: count, success: true})
+	// A successful round trip may still report io.EOF (end of the response
+	// stream); only other errors are failures.
+	if err != nil && err != io.EOF {
+		t.Fatal(err)
+	}
+	te.srv.GracefulStop() // Wait for the server to stop.
+
+	expect := &expectedData{
+		method:      "/grpc.testing.TestService/FullDuplexCall",
+		serverAddr:  te.srvAddr,
+		compression: "gzip",
+		requests:    reqs,
+		responses:   resps,
+	}
+
+	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
+		checkInHeader,
+		checkBegin,
+		checkOutHeader,
+	}
+	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
+		checkInPayload,
+		checkOutPayload,
+	}
+	for i := 0; i < count; i++ {
+		checkFuncs = append(checkFuncs, ioPayFuncs...)
+	}
+	checkFuncs = append(checkFuncs, checkOutTrailer, checkEnd)
+
+	if len(got) != len(checkFuncs) {
+		t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
+	}
+
+	for i := 0; i < len(got)-1; i++ {
+		if got[i].ctx != got[i+1].ctx {
+			t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s)
+		}
+	}
+
+	for i, f := range checkFuncs {
+		f(t, got[i], expect)
+	}
+}
+
+// TestServerStatsStreamingRPCError verifies the server-side stats recorded
+// for a failing gzip-compressed full-duplex RPC: a single InPayload, no
+// OutPayload, and an End carrying the error the client observed.
+func TestServerStatsStreamingRPCError(t *testing.T) {
+	var collected []*gotData
+
+	stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+		if s.IsClient() {
+			return
+		}
+		collected = append(collected, &gotData{ctx, false, s})
+	})
+	stats.Start()
+	defer stats.Stop()
+
+	tt := newTest(t, "gzip")
+	tt.startServer(&testServer{})
+	defer tt.tearDown()
+
+	const count = 5
+	reqs, resps, rpcErr := tt.doFullDuplexCallRoundtrip(&rpcConfig{count: count, success: false})
+	if rpcErr == nil {
+		t.Fatalf("got error <nil>; want <non-nil>")
+	}
+	// Wait for the server to stop before inspecting the collected stats.
+	tt.srv.GracefulStop()
+
+	want := &expectedData{
+		method:      "/grpc.testing.TestService/FullDuplexCall",
+		serverAddr:  tt.srvAddr,
+		compression: "gzip",
+		requests:    reqs,
+		responses:   resps,
+		err:         rpcErr,
+	}
+
+	// Expected stats, in the order they must have been recorded.
+	checks := []func(t *testing.T, d *gotData, e *expectedData){
+		checkInHeader,
+		checkBegin,
+		checkOutHeader,
+		checkInPayload,
+		checkOutTrailer,
+		checkEnd,
+	}
+	if len(collected) != len(checks) {
+		t.Fatalf("got %v stats, want %v stats", len(collected), len(checks))
+	}
+
+	for i := 1; i < len(collected); i++ {
+		prev, cur := collected[i-1], collected[i]
+		if prev.ctx != cur.ctx {
+			t.Fatalf("got different contexts with two stats %T %T", prev.s, cur.s)
+		}
+	}
+
+	for i := range checks {
+		checks[i](t, collected[i], want)
+	}
+}
+
+// checkFuncWithCount pairs a stats check with the number of times the
+// corresponding stats kind is expected. checkClientStats runs f on each
+// matching stat and decrements c as stats are consumed.
+type checkFuncWithCount struct {
+ // f is the check run on each stat of this kind.
+ f func(t *testing.T, d *gotData, e *expectedData)
+ c int // expected count
+}
+
+// checkClientStats verifies the client-side stats collected in got.
+//
+// checkFuncs maps each stats kind (begin, outHeader, outPayload, inHeader,
+// inPayload, inTrailer, end) to the check to run on it and the number of
+// times that kind must appear. The total number of stats must equal the sum
+// of the counts, and all stats must share one context. Beyond that, stats are
+// matched by type, not by position. A kind that is absent from checkFuncs, or
+// that appears more often than its count allows, fails the test. The counts
+// in checkFuncs are decremented as stats are consumed.
+func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs map[int]*checkFuncWithCount) {
+	var expectLen int
+	for _, v := range checkFuncs {
+		expectLen += v.c
+	}
+	if len(got) != expectLen {
+		t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
+	}
+
+	for i := 0; i < len(got)-1; i++ {
+		if got[i].ctx != got[i+1].ctx {
+			t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s)
+		}
+	}
+
+	for _, s := range got {
+		// Map the concrete stats type to its checkFuncs key.
+		var kind int
+		switch s.s.(type) {
+		case *stats.Begin:
+			kind = begin
+		case *stats.OutHeader:
+			kind = outHeader
+		case *stats.OutPayload:
+			kind = outPayload
+		case *stats.InHeader:
+			kind = inHeader
+		case *stats.InPayload:
+			kind = inPayload
+		case *stats.InTrailer:
+			kind = inTrailer
+		case *stats.End:
+			kind = end
+		default:
+			t.Fatalf("unexpected stats: %T", s.s)
+		}
+		// A kind missing from checkFuncs must fail the test rather than
+		// dereference a nil *checkFuncWithCount.
+		cf, ok := checkFuncs[kind]
+		if !ok || cf.c <= 0 {
+			t.Fatalf("unexpected stats: %T", s.s)
+		}
+		cf.f(t, s, expect)
+		cf.c--
+	}
+}
+
+// TestClientStatsUnaryRPC checks that a successful non-failfast unary RPC
+// produces exactly one of each client-side stat (Begin, OutHeader,
+// OutPayload, InHeader, InPayload, InTrailer, End), all sharing one context.
+func TestClientStatsUnaryRPC(t *testing.T) {
+	var (
+		mu  sync.Mutex
+		got []*gotData
+	)
+	stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+		mu.Lock()
+		defer mu.Unlock()
+		if s.IsClient() {
+			got = append(got, &gotData{ctx, true, s})
+		}
+	})
+	stats.Start()
+	defer stats.Stop()
+
+	te := newTest(t, "")
+	te.startServer(&testServer{})
+	defer te.tearDown()
+
+	failfast := false
+	req, resp, err := te.doUnaryCall(&rpcConfig{success: true, failfast: failfast})
+	if err != nil {
+		// t.Fatal rather than t.Fatalf(err.Error()): the error text must not be
+		// interpreted as a format string.
+		t.Fatal(err)
+	}
+	te.srv.GracefulStop() // Wait for the server to stop.
+
+	expect := &expectedData{
+		method:     "/grpc.testing.TestService/UnaryCall",
+		serverAddr: te.srvAddr,
+		requests:   []*testpb.SimpleRequest{req},
+		responses:  []*testpb.SimpleResponse{resp},
+		failfast:   failfast,
+	}
+
+	checkFuncs := map[int]*checkFuncWithCount{
+		begin:      {checkBegin, 1},
+		outHeader:  {checkOutHeader, 1},
+		outPayload: {checkOutPayload, 1},
+		inHeader:   {checkInHeader, 1},
+		inPayload:  {checkInPayload, 1},
+		inTrailer:  {checkInTrailer, 1},
+		end:        {checkEnd, 1},
+	}
+
+	checkClientStats(t, got, expect, checkFuncs)
+}
+
+// TestClientStatsUnaryRPCError verifies the client-side stats for a failing
+// failfast unary RPC. No InPayload is expected, and End must carry the error
+// returned to the caller.
+func TestClientStatsUnaryRPCError(t *testing.T) {
+	var (
+		mu        sync.Mutex
+		collected []*gotData
+	)
+	stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+		mu.Lock()
+		defer mu.Unlock()
+		if !s.IsClient() {
+			return
+		}
+		collected = append(collected, &gotData{ctx, true, s})
+	})
+	stats.Start()
+	defer stats.Stop()
+
+	tt := newTest(t, "")
+	tt.startServer(&testServer{})
+	defer tt.tearDown()
+
+	const failfast = true
+	req, resp, rpcErr := tt.doUnaryCall(&rpcConfig{success: false, failfast: failfast})
+	if rpcErr == nil {
+		t.Fatalf("got error <nil>; want <non-nil>")
+	}
+	// Wait for the server to stop before inspecting the collected stats.
+	tt.srv.GracefulStop()
+
+	want := &expectedData{
+		method:     "/grpc.testing.TestService/UnaryCall",
+		serverAddr: tt.srvAddr,
+		requests:   []*testpb.SimpleRequest{req},
+		responses:  []*testpb.SimpleResponse{resp},
+		err:        rpcErr,
+		failfast:   failfast,
+	}
+
+	checkClientStats(t, collected, want, map[int]*checkFuncWithCount{
+		begin:      {checkBegin, 1},
+		outHeader:  {checkOutHeader, 1},
+		outPayload: {checkOutPayload, 1},
+		inHeader:   {checkInHeader, 1},
+		inTrailer:  {checkInTrailer, 1},
+		end:        {checkEnd, 1},
+	})
+}
+
+// TestClientStatsStreamingRPC checks the client-side stats for a successful
+// gzip-compressed, non-failfast full-duplex RPC of count round trips: count
+// OutPayload and InPayload stats plus one of each other kind, all sharing one
+// context.
+func TestClientStatsStreamingRPC(t *testing.T) {
+	var (
+		mu  sync.Mutex
+		got []*gotData
+	)
+	stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+		mu.Lock()
+		defer mu.Unlock()
+		if s.IsClient() {
+			got = append(got, &gotData{ctx, true, s})
+		}
+	})
+	stats.Start()
+	defer stats.Stop()
+
+	te := newTest(t, "gzip")
+	te.startServer(&testServer{})
+	defer te.tearDown()
+
+	count := 5
+	failfast := false
+	reqs, resps, err := te.doFullDuplexCallRoundtrip(&rpcConfig{count: count, success: true, failfast: failfast})
+	// A successful round trip may still report io.EOF (end of the response
+	// stream); only other errors are failures.
+	if err != nil && err != io.EOF {
+		t.Fatal(err)
+	}
+	te.srv.GracefulStop() // Wait for the server to stop.
+
+	expect := &expectedData{
+		method:      "/grpc.testing.TestService/FullDuplexCall",
+		serverAddr:  te.srvAddr,
+		compression: "gzip",
+		requests:    reqs,
+		responses:   resps,
+		failfast:    failfast,
+	}
+
+	checkFuncs := map[int]*checkFuncWithCount{
+		begin:      {checkBegin, 1},
+		outHeader:  {checkOutHeader, 1},
+		outPayload: {checkOutPayload, count},
+		inHeader:   {checkInHeader, 1},
+		inPayload:  {checkInPayload, count},
+		inTrailer:  {checkInTrailer, 1},
+		end:        {checkEnd, 1},
+	}
+
+	checkClientStats(t, got, expect, checkFuncs)
+}
+
+// TestClientStatsStreamingRPCError verifies the client-side stats for a
+// failing gzip-compressed failfast full-duplex RPC. Only one OutPayload and
+// no InPayload are expected, and End must carry the error returned to the
+// caller.
+func TestClientStatsStreamingRPCError(t *testing.T) {
+	var (
+		mu        sync.Mutex
+		collected []*gotData
+	)
+	stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+		mu.Lock()
+		defer mu.Unlock()
+		if !s.IsClient() {
+			return
+		}
+		collected = append(collected, &gotData{ctx, true, s})
+	})
+	stats.Start()
+	defer stats.Stop()
+
+	tt := newTest(t, "gzip")
+	tt.startServer(&testServer{})
+	defer tt.tearDown()
+
+	const (
+		count    = 5
+		failfast = true
+	)
+	reqs, resps, rpcErr := tt.doFullDuplexCallRoundtrip(&rpcConfig{count: count, success: false, failfast: failfast})
+	if rpcErr == nil {
+		t.Fatalf("got error <nil>; want <non-nil>")
+	}
+	// Wait for the server to stop before inspecting the collected stats.
+	tt.srv.GracefulStop()
+
+	want := &expectedData{
+		method:      "/grpc.testing.TestService/FullDuplexCall",
+		serverAddr:  tt.srvAddr,
+		compression: "gzip",
+		requests:    reqs,
+		responses:   resps,
+		err:         rpcErr,
+		failfast:    failfast,
+	}
+
+	checkClientStats(t, collected, want, map[int]*checkFuncWithCount{
+		begin:      {checkBegin, 1},
+		outHeader:  {checkOutHeader, 1},
+		outPayload: {checkOutPayload, 1},
+		inHeader:   {checkInHeader, 1},
+		inTrailer:  {checkInTrailer, 1},
+		end:        {checkEnd, 1},
+	})
+}
diff --git a/stream.go b/stream.go
index 4681054..95c8acf 100644
--- a/stream.go
+++ b/stream.go
@@ -45,6 +45,7 @@
"golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/stats"
"google.golang.org/grpc/transport"
)
@@ -97,7 +98,7 @@
// NewClientStream creates a new Stream for the client side. This is called
// by generated code.
-func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
+func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
if cc.dopts.streamInt != nil {
return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
}
@@ -143,6 +144,24 @@
}
}()
}
+ if stats.On() {
+ begin := &stats.Begin{
+ Client: true,
+ BeginTime: time.Now(),
+ FailFast: c.failFast,
+ }
+ stats.Handle(ctx, begin)
+ }
+ defer func() {
+ if err != nil && stats.On() {
+ // Only handle end stats if err != nil.
+ end := &stats.End{
+ Client: true,
+ Error: err,
+ }
+ stats.Handle(ctx, end)
+ }
+ }()
gopts := BalancerGetOptions{
BlockingWait: !c.failFast,
}
@@ -194,6 +213,8 @@
tracing: EnableTracing,
trInfo: trInfo,
+
+ statsCtx: ctx,
}
if cc.dopts.cp != nil {
cs.cbuf = new(bytes.Buffer)
@@ -246,6 +267,11 @@
// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
// and is set to nil when the clientStream's finish method is called.
trInfo traceInfo
+
+ // statsCtx keeps the user context for stats handling.
+ // All stats collection should use the statsCtx (instead of the stream context)
+ // so that all the generated stats for a particular RPC can be associated in the processing phase.
+ statsCtx context.Context
}
func (cs *clientStream) Context() context.Context {
@@ -274,6 +300,8 @@
}
cs.mu.Unlock()
}
+ // TODO Investigate how to signal the stats handling party.
+ // generate error stats if err != nil && err != io.EOF?
defer func() {
if err != nil {
cs.finish(err)
@@ -296,7 +324,13 @@
}
err = toRPCErr(err)
}()
- out, err := encode(cs.codec, m, cs.cp, cs.cbuf)
+ var outPayload *stats.OutPayload
+ if stats.On() {
+ outPayload = &stats.OutPayload{
+ Client: true,
+ }
+ }
+ out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload)
defer func() {
if cs.cbuf != nil {
cs.cbuf.Reset()
@@ -305,11 +339,37 @@
if err != nil {
return Errorf(codes.Internal, "grpc: %v", err)
}
- return cs.t.Write(cs.s, out, &transport.Options{Last: false})
+ err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
+ if err == nil && outPayload != nil {
+ outPayload.SentTime = time.Now()
+ stats.Handle(cs.statsCtx, outPayload)
+ }
+ return err
}
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
+ defer func() {
+ if err != nil && stats.On() {
+ // Only generate End if err != nil.
+ // If err == nil, it's not the last RecvMsg.
+ // The last RecvMsg gets either an RPC error or io.EOF.
+ end := &stats.End{
+ Client: true,
+ EndTime: time.Now(),
+ }
+ if err != io.EOF {
+ end.Error = toRPCErr(err)
+ }
+ stats.Handle(cs.statsCtx, end)
+ }
+ }()
+ var inPayload *stats.InPayload
+ if stats.On() {
+ inPayload = &stats.InPayload{
+ Client: true,
+ }
+ }
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, inPayload)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
@@ -324,11 +384,15 @@
}
cs.mu.Unlock()
}
+ if inPayload != nil {
+ stats.Handle(cs.statsCtx, inPayload)
+ }
if !cs.desc.ClientStreams || cs.desc.ServerStreams {
return
}
// Special handling for client streaming rpc.
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
+ // This recv expects EOF or errors, so we don't collect inPayload.
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@@ -482,7 +546,11 @@
ss.mu.Unlock()
}
}()
- out, err := encode(ss.codec, m, ss.cp, ss.cbuf)
+ var outPayload *stats.OutPayload
+ if stats.On() {
+ outPayload = &stats.OutPayload{}
+ }
+ out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
defer func() {
if ss.cbuf != nil {
ss.cbuf.Reset()
@@ -495,6 +563,10 @@
if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil {
return toRPCErr(err)
}
+ if outPayload != nil {
+ outPayload.SentTime = time.Now()
+ stats.Handle(ss.s.Context(), outPayload)
+ }
return nil
}
@@ -513,7 +585,11 @@
ss.mu.Unlock()
}
}()
- if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize); err != nil {
+ var inPayload *stats.InPayload
+ if stats.On() {
+ inPayload = &stats.InPayload{}
+ }
+ if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, inPayload); err != nil {
if err == io.EOF {
return err
}
@@ -522,5 +598,8 @@
}
return toRPCErr(err)
}
+ if inPayload != nil {
+ stats.Handle(ss.s.Context(), inPayload)
+ }
return nil
}
diff --git a/stress/client/main.go b/stress/client/main.go
index 4579aab..99e164b 100644
--- a/stress/client/main.go
+++ b/stress/client/main.go
@@ -47,6 +47,7 @@
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
+ "google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/interop"
testpb "google.golang.org/grpc/interop/grpc_testing"
@@ -60,6 +61,12 @@
numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
+ useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
+ testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
+ tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
+
+ // The test CA root cert file
+ testCAFile = "testdata/ca.pem"
)
// testCaseWithWeight contains the test case type and its weight.
@@ -180,7 +187,7 @@
return nil, grpc.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
}
-// createGauge creates a guage using the given name in metrics server.
+// createGauge creates a gauge using the given name in metrics server.
func (s *server) createGauge(name string) *gauge {
s.mutex.Lock()
defer s.mutex.Unlock()
@@ -242,10 +249,13 @@
func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
grpclog.Printf("server_addresses: %s", *serverAddresses)
grpclog.Printf("test_cases: %s", *testCases)
- grpclog.Printf("test_duration-secs: %d", *testDurationSecs)
+ grpclog.Printf("test_duration_secs: %d", *testDurationSecs)
grpclog.Printf("num_channels_per_server: %d", *numChannelsPerServer)
grpclog.Printf("num_stubs_per_channel: %d", *numStubsPerChannel)
grpclog.Printf("metrics_port: %d", *metricsPort)
+ grpclog.Printf("use_tls: %t", *useTLS)
+ grpclog.Printf("use_test_ca: %t", *testCA)
+ grpclog.Printf("server_host_override: %s", *tlsServerName)
grpclog.Println("addresses:")
for i, addr := range addresses {
@@ -257,6 +267,30 @@
}
}
+// newConn dials address. With useTLS it uses TLS transport credentials,
+// trusting only the test CA in testCAFile when testCA is set; otherwise it
+// dials insecurely. tlsServerName is passed through as the TLS server-name
+// override (an empty string means no override).
+func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
+	var opts []grpc.DialOption
+	if useTLS {
+		var creds credentials.TransportCredentials
+		if testCA {
+			var err error
+			creds, err = credentials.NewClientTLSFromFile(testCAFile, tlsServerName)
+			if err != nil {
+				grpclog.Fatalf("Failed to create TLS credentials %v", err)
+			}
+		} else {
+			creds = credentials.NewClientTLSFromCert(nil, tlsServerName)
+		}
+		opts = append(opts, grpc.WithTransportCredentials(creds))
+	} else {
+		opts = append(opts, grpc.WithInsecure())
+	}
+	return grpc.Dial(address, opts...)
+}
+
func main() {
flag.Parse()
addresses := strings.Split(*serverAddresses, ",")
@@ -271,7 +305,7 @@
for serverIndex, address := range addresses {
for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ {
- conn, err := grpc.Dial(address, grpc.WithInsecure())
+ conn, err := newConn(address, *useTLS, *testCA, *tlsServerName)
if err != nil {
grpclog.Fatalf("Fail to dial: %v", err)
}
diff --git a/stress/client/testdata/ca.pem b/stress/client/testdata/ca.pem
new file mode 100644
index 0000000..6c8511a
--- /dev/null
+++ b/stress/client/testdata/ca.pem
@@ -0,0 +1,15 @@
+-----BEGIN CERTIFICATE-----
+MIICSjCCAbOgAwIBAgIJAJHGGR4dGioHMA0GCSqGSIb3DQEBCwUAMFYxCzAJBgNV
+BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX
+aWRnaXRzIFB0eSBMdGQxDzANBgNVBAMTBnRlc3RjYTAeFw0xNDExMTEyMjMxMjla
+Fw0yNDExMDgyMjMxMjlaMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0
+YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMT
+BnRlc3RjYTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAwEDfBV5MYdlHVHJ7
++L4nxrZy7mBfAVXpOc5vMYztssUI7mL2/iYujiIXM+weZYNTEpLdjyJdu7R5gGUu
+g1jSVK/EPHfc74O7AyZU34PNIP4Sh33N+/A5YexrNgJlPY+E3GdVYi4ldWJjgkAd
+Qah2PH5ACLrIIC6tRka9hcaBlIECAwEAAaMgMB4wDAYDVR0TBAUwAwEB/zAOBgNV
+HQ8BAf8EBAMCAgQwDQYJKoZIhvcNAQELBQADgYEAHzC7jdYlzAVmddi/gdAeKPau
+sPBG/C2HCWqHzpCUHcKuvMzDVkY/MP2o6JIW2DBbY64bO/FceExhjcykgaYtCH/m
+oIU63+CFOTtR7otyQAWHqXa7q4SbCDlG7DyRFxqG0txPtGvy12lgldA2+RgcigQG
+Dfcog5wrJytaQ6UA0wE=
+-----END CERTIFICATE-----
diff --git a/stress/client/testdata/server1.key b/stress/client/testdata/server1.key
new file mode 100644
index 0000000..143a5b8
--- /dev/null
+++ b/stress/client/testdata/server1.key
@@ -0,0 +1,16 @@
+-----BEGIN PRIVATE KEY-----
+MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAOHDFScoLCVJpYDD
+M4HYtIdV6Ake/sMNaaKdODjDMsux/4tDydlumN+fm+AjPEK5GHhGn1BgzkWF+slf
+3BxhrA/8dNsnunstVA7ZBgA/5qQxMfGAq4wHNVX77fBZOgp9VlSMVfyd9N8YwbBY
+AckOeUQadTi2X1S6OgJXgQ0m3MWhAgMBAAECgYAn7qGnM2vbjJNBm0VZCkOkTIWm
+V10okw7EPJrdL2mkre9NasghNXbE1y5zDshx5Nt3KsazKOxTT8d0Jwh/3KbaN+YY
+tTCbKGW0pXDRBhwUHRcuRzScjli8Rih5UOCiZkhefUTcRb6xIhZJuQy71tjaSy0p
+dHZRmYyBYO2YEQ8xoQJBAPrJPhMBkzmEYFtyIEqAxQ/o/A6E+E4w8i+KM7nQCK7q
+K4JXzyXVAjLfyBZWHGM2uro/fjqPggGD6QH1qXCkI4MCQQDmdKeb2TrKRh5BY1LR
+81aJGKcJ2XbcDu6wMZK4oqWbTX2KiYn9GB0woM6nSr/Y6iy1u145YzYxEV/iMwff
+DJULAkB8B2MnyzOg0pNFJqBJuH29bKCcHa8gHJzqXhNO5lAlEbMK95p/P2Wi+4Hd
+aiEIAF1BF326QJcvYKmwSmrORp85AkAlSNxRJ50OWrfMZnBgzVjDx3xG6KsFQVk2
+ol6VhqL6dFgKUORFUWBvnKSyhjJxurlPEahV6oo6+A+mPhFY8eUvAkAZQyTdupP3
+XEFQKctGz+9+gKkemDp7LBBMEMBXrGTLPhpEfcjv/7KPdnFHYmhYeBTBnuVmTVWe
+F98XJ7tIFfJq
+-----END PRIVATE KEY-----
diff --git a/stress/client/testdata/server1.pem b/stress/client/testdata/server1.pem
new file mode 100644
index 0000000..f3d43fc
--- /dev/null
+++ b/stress/client/testdata/server1.pem
@@ -0,0 +1,16 @@
+-----BEGIN CERTIFICATE-----
+MIICnDCCAgWgAwIBAgIBBzANBgkqhkiG9w0BAQsFADBWMQswCQYDVQQGEwJBVTET
+MBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQ
+dHkgTHRkMQ8wDQYDVQQDEwZ0ZXN0Y2EwHhcNMTUxMTA0MDIyMDI0WhcNMjUxMTAx
+MDIyMDI0WjBlMQswCQYDVQQGEwJVUzERMA8GA1UECBMISWxsaW5vaXMxEDAOBgNV
+BAcTB0NoaWNhZ28xFTATBgNVBAoTDEV4YW1wbGUsIENvLjEaMBgGA1UEAxQRKi50
+ZXN0Lmdvb2dsZS5jb20wgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAOHDFSco
+LCVJpYDDM4HYtIdV6Ake/sMNaaKdODjDMsux/4tDydlumN+fm+AjPEK5GHhGn1Bg
+zkWF+slf3BxhrA/8dNsnunstVA7ZBgA/5qQxMfGAq4wHNVX77fBZOgp9VlSMVfyd
+9N8YwbBYAckOeUQadTi2X1S6OgJXgQ0m3MWhAgMBAAGjazBpMAkGA1UdEwQCMAAw
+CwYDVR0PBAQDAgXgME8GA1UdEQRIMEaCECoudGVzdC5nb29nbGUuZnKCGHdhdGVy
+em9vaS50ZXN0Lmdvb2dsZS5iZYISKi50ZXN0LnlvdXR1YmUuY29thwTAqAEDMA0G
+CSqGSIb3DQEBCwUAA4GBAJFXVifQNub1LUP4JlnX5lXNlo8FxZ2a12AFQs+bzoJ6
+hM044EDjqyxUqSbVePK0ni3w1fHQB5rY9yYC5f8G7aqqTY1QOhoUk8ZTSTRpnkTh
+y4jjdvTZeLDVBlueZUTDRmy2feY5aZIU18vFDK08dTG0A87pppuv1LNIR3loveU8
+-----END CERTIFICATE-----
diff --git a/stress/grpc_testing/metrics.pb.go b/stress/grpc_testing/metrics.pb.go
index 4ad4ccd..a1310b5 100644
--- a/stress/grpc_testing/metrics.pb.go
+++ b/stress/grpc_testing/metrics.pb.go
@@ -205,7 +205,7 @@
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion3
+const _ = grpc.SupportPackageIsVersion4
// Client API for MetricsService service
@@ -335,7 +335,7 @@
ServerStreams: true,
},
},
- Metadata: fileDescriptor0,
+ Metadata: "metrics.proto",
}
func init() { proto.RegisterFile("metrics.proto", fileDescriptor0) }
diff --git a/tap/tap.go b/tap/tap.go
new file mode 100644
index 0000000..0f36647
--- /dev/null
+++ b/tap/tap.go
@@ -0,0 +1,54 @@
+/*
+ *
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+// Package tap defines the function handles that are executed in the transport
+// layer of gRPC-Go, along with the information passed to them. Everything here
+// is EXPERIMENTAL.
+package tap
+
+import (
+ "golang.org/x/net/context"
+)
+
+// Info carries the information about a new server-side stream that is passed
+// to tap handles such as ServerInHandle.
+type Info struct {
+ // FullMethodName is the full name of the invoked gRPC method, in the format
+ // /package.service/method.
+ FullMethodName string
+ // TODO: More to be added.
+}
+
+// ServerInHandle defines the function which runs when a new stream is created
+// on the server side. Note that it is executed in the per-connection I/O
+// goroutine(s) instead of a per-RPC goroutine. Therefore, users should NOT do
+// any blocking or time-consuming work in this handle; otherwise all the RPCs
+// on the connection would slow down.
+//
+// NOTE(review): judging from the accompanying end-to-end tests, a non-nil
+// error fails the RPC (the client observes codes.Unavailable), and the
+// returned context presumably replaces ctx for the new stream. Confirm both
+// against the transport implementation.
+type ServerInHandle func(ctx context.Context, info *Info) (context.Context, error)
diff --git a/test/end2end_test.go b/test/end2end_test.go
index d12a1f9..88c3626 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -65,6 +65,7 @@
"google.golang.org/grpc/internal"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
+ "google.golang.org/grpc/tap"
testpb "google.golang.org/grpc/test/grpc_testing"
)
@@ -418,6 +419,7 @@
testServer testpb.TestServiceServer // nil means none
healthServer *health.Server // nil means disabled
maxStream uint32
+ tapHandle tap.ServerInHandle
maxMsgSize int
userAgent string
clientCompression bool
@@ -473,6 +475,9 @@
if te.maxMsgSize > 0 {
sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize))
}
+ if te.tapHandle != nil {
+ sopts = append(sopts, grpc.InTapHandle(te.tapHandle))
+ }
if te.serverCompression {
sopts = append(sopts,
grpc.RPCCompressor(grpc.NewGZIPCompressor()),
@@ -625,7 +630,10 @@
}
te.srv.Stop()
ctx, _ := context.WithTimeout(context.Background(), time.Millisecond)
- if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
+ _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false))
+ if e.balancer && grpc.Code(err) != codes.DeadlineExceeded {
+ // If e.balancer == nil, the ac will stop reconnecting because the dialer returns non-temp error,
+ // the error will be an internal error.
t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %s", ctx, err, codes.DeadlineExceeded)
}
awaitNewConnLogOutput()
@@ -1005,6 +1013,68 @@
awaitNewConnLogOutput()
}
+// TestTap runs testTap in every test environment except "handler-tls"
+// (presumably because the ServeHTTP-based transport does not invoke tap
+// handles; confirm).
+func TestTap(t *testing.T) {
+	defer leakCheck(t)()
+	for _, e := range listTestEnv() {
+		if e.name != "handler-tls" {
+			testTap(t, e)
+		}
+	}
+}
+
+// myTap is the tap handle used by testTap. Its handle method counts EmptyCall
+// streams and rejects UnaryCall streams.
+type myTap struct {
+ // cnt is the number of EmptyCall streams seen by handle.
+ cnt int
+}
+
+// handle is a tap.ServerInHandle for tests. It increments cnt for EmptyCall,
+// rejects UnaryCall with an error, and passes every other stream (including a
+// nil info) through with its context unchanged.
+func (mt *myTap) handle(ctx context.Context, info *tap.Info) (context.Context, error) {
+	if info == nil {
+		return ctx, nil
+	}
+	switch info.FullMethodName {
+	case "/grpc.testing.TestService/EmptyCall":
+		mt.cnt++
+	case "/grpc.testing.TestService/UnaryCall":
+		return nil, fmt.Errorf("tap error")
+	}
+	return ctx, nil
+}
+
+// testTap starts a server whose tap handle is a myTap. It checks that
+// EmptyCall passes the tap and bumps its counter once, and that UnaryCall is
+// rejected with codes.Unavailable.
+func testTap(t *testing.T, e env) {
+	ttap := &myTap{}
+	te := newTest(t, e)
+	te.userAgent = testAppUA
+	te.tapHandle = ttap.handle
+	te.declareLogNoise(
+		"transport: http2Client.notifyError got notified that the client transport was broken EOF",
+		"grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing",
+		"grpc: addrConn.resetTransport failed to create client transport: connection error",
+	)
+	te.startServer(&testServer{security: e.security})
+	defer te.tearDown()
+
+	tc := testpb.NewTestServiceClient(te.clientConn())
+	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
+	}
+	if n := ttap.cnt; n != 1 {
+		t.Fatalf("Get the count in ttap %d, want 1", n)
+	}
+
+	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 31)
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, err = tc.UnaryCall(context.Background(), &testpb.SimpleRequest{
+		ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+		ResponseSize: proto.Int32(45),
+		Payload:      payload,
+	})
+	if grpc.Code(err) != codes.Unavailable {
+		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.Unavailable)
+	}
+}
+
func healthCheck(d time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) {
ctx, _ := context.WithTimeout(context.Background(), d)
hc := healthpb.NewHealthClient(cc)
@@ -2504,8 +2574,7 @@
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
- ctx, cancel := context.WithCancel(context.Background())
- if _, err := tc.StreamingInputCall(ctx); err != nil {
+ if _, err := tc.StreamingInputCall(context.Background()); err != nil {
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
}
// Loop until the new max stream setting is effective.
@@ -2522,18 +2591,26 @@
}
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %s", tc, err, codes.DeadlineExceeded)
}
- cancel()
var wg sync.WaitGroup
- for i := 0; i < 100; i++ {
+ for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
- ctx, cancel := context.WithCancel(context.Background())
- if _, err := tc.StreamingInputCall(ctx); err != nil {
- t.Errorf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
+ payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 314)
+ if err != nil {
+ t.Fatal(err)
}
- cancel()
+ req := &testpb.SimpleRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+ ResponseSize: proto.Int32(1592),
+ Payload: payload,
+ }
+ // No rpc should go through due to the max streams limit.
+ ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
+ if _, err := tc.UnaryCall(ctx, req, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
+ t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
+ }
}()
}
wg.Wait()
diff --git a/test/grpc_testing/test.pb.go b/test/grpc_testing/test.pb.go
index 0ceb12d..e584c4d 100644
--- a/test/grpc_testing/test.pb.go
+++ b/test/grpc_testing/test.pb.go
@@ -360,7 +360,7 @@
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion3
+const _ = grpc.SupportPackageIsVersion4
// Client API for TestService service
@@ -742,7 +742,7 @@
ClientStreams: true,
},
},
- Metadata: fileDescriptor0,
+ Metadata: "test.proto",
}
func init() { proto.RegisterFile("test.proto", fileDescriptor0) }
diff --git a/transport/handler_server.go b/transport/handler_server.go
index 114e349..10b6dc0 100644
--- a/transport/handler_server.go
+++ b/transport/handler_server.go
@@ -268,7 +268,7 @@
})
}
-func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
+func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
// With this transport type there will be exactly 1 stream: this HTTP request.
var ctx context.Context
diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go
index 84fc917..9843d36 100644
--- a/transport/handler_server_test.go
+++ b/transport/handler_server_test.go
@@ -300,7 +300,10 @@
st.bodyw.Close() // no body
st.ht.WriteStatus(s, codes.OK, "")
}
- st.ht.HandleStreams(func(s *Stream) { go handleStream(s) })
+ st.ht.HandleStreams(
+ func(s *Stream) { go handleStream(s) },
+ func(ctx context.Context, method string) context.Context { return ctx },
+ )
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
@@ -327,7 +330,10 @@
handleStream := func(s *Stream) {
st.ht.WriteStatus(s, statusCode, msg)
}
- st.ht.HandleStreams(func(s *Stream) { go handleStream(s) })
+ st.ht.HandleStreams(
+ func(s *Stream) { go handleStream(s) },
+ func(ctx context.Context, method string) context.Context { return ctx },
+ )
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
@@ -375,7 +381,10 @@
}
ht.WriteStatus(s, codes.DeadlineExceeded, "too slow")
}
- ht.HandleStreams(func(s *Stream) { go runStream(s) })
+ ht.HandleStreams(
+ func(s *Stream) { go runStream(s) },
+ func(ctx context.Context, method string) context.Context { return ctx },
+ )
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 2b0f680..cbd9f32 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -51,16 +51,19 @@
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
+ "google.golang.org/grpc/stats"
)
// http2Client implements the ClientTransport interface with HTTP2.
type http2Client struct {
- target string // server name/addr
- userAgent string
- md interface{}
- conn net.Conn // underlying communication channel
- authInfo credentials.AuthInfo // auth info about the connection
- nextID uint32 // the next stream ID to be used
+ target string // server name/addr
+ userAgent string
+ md interface{}
+ conn net.Conn // underlying communication channel
+ remoteAddr net.Addr
+ localAddr net.Addr
+ authInfo credentials.AuthInfo // auth info about the connection
+ nextID uint32 // the next stream ID to be used
// writableChan synchronizes write access to the transport.
// A writer acquires the write lock by sending a value on writableChan
@@ -150,6 +153,9 @@
scheme := "http"
conn, err := dial(ctx, opts.Dialer, addr.Addr)
if err != nil {
+ if opts.FailOnNonTempDialError {
+ return nil, connectionErrorf(isTemporary(err), err, "transport: %v", err)
+ }
return nil, connectionErrorf(true, err, "transport: %v", err)
}
// Any further errors will close the underlying connection
@@ -175,11 +181,13 @@
}
var buf bytes.Buffer
t := &http2Client{
- target: addr.Addr,
- userAgent: ua,
- md: addr.Metadata,
- conn: conn,
- authInfo: authInfo,
+ target: addr.Addr,
+ userAgent: ua,
+ md: addr.Metadata,
+ conn: conn,
+ remoteAddr: conn.RemoteAddr(),
+ localAddr: conn.LocalAddr(),
+ authInfo: authInfo,
// The client initiated stream id is odd starting from 1.
nextID: 1,
writableChan: make(chan int, 1),
@@ -270,12 +278,13 @@
// streams.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
pr := &peer.Peer{
- Addr: t.conn.RemoteAddr(),
+ Addr: t.remoteAddr,
}
// Attach Auth info if there is any.
if t.authInfo != nil {
pr.AuthInfo = t.authInfo
}
+ userCtx := ctx
ctx = peer.NewContext(ctx, pr)
authData := make(map[string]string)
for _, c := range t.creds {
@@ -347,6 +356,7 @@
return nil, ErrConnClosing
}
s := t.newStream(ctx, callHdr)
+ s.clientStatsCtx = userCtx
t.activeStreams[s.id] = s
// This stream is not counted when applySetings(...) initialize t.streamsQuota.
@@ -413,6 +423,7 @@
}
}
first := true
+ bufLen := t.hBuf.Len()
// Sends the headers in a single batch even when they span multiple frames.
for !endHeaders {
size := t.hBuf.Len()
@@ -447,6 +458,17 @@
return nil, connectionErrorf(true, err, "transport: %v", err)
}
}
+ if stats.On() {
+ outHeader := &stats.OutHeader{
+ Client: true,
+ WireLength: bufLen,
+ FullMethod: callHdr.Method,
+ RemoteAddr: t.remoteAddr,
+ LocalAddr: t.localAddr,
+ Compression: callHdr.SendCompress,
+ }
+ stats.Handle(s.clientStatsCtx, outHeader)
+ }
t.writableChan <- 0
return s, nil
}
@@ -874,6 +896,24 @@
}
endStream := frame.StreamEnded()
+ var isHeader bool
+ defer func() {
+ if stats.On() {
+ if isHeader {
+ inHeader := &stats.InHeader{
+ Client: true,
+ WireLength: int(frame.Header().Length),
+ }
+ stats.Handle(s.clientStatsCtx, inHeader)
+ } else {
+ inTrailer := &stats.InTrailer{
+ Client: true,
+ WireLength: int(frame.Header().Length),
+ }
+ stats.Handle(s.clientStatsCtx, inTrailer)
+ }
+ }
+ }()
s.mu.Lock()
if !endStream {
@@ -885,6 +925,7 @@
}
close(s.headerChan)
s.headerDone = true
+ isHeader = true
}
if !endStream || s.state == streamDone {
s.mu.Unlock()
diff --git a/transport/http2_server.go b/transport/http2_server.go
index a62fb7c..db9beb9 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -50,6 +50,8 @@
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
+ "google.golang.org/grpc/stats"
+ "google.golang.org/grpc/tap"
)
// ErrIllegalHeaderWrite indicates that setting header is illegal because of
@@ -59,8 +61,11 @@
// http2Server implements the ServerTransport interface with HTTP2.
type http2Server struct {
conn net.Conn
+ remoteAddr net.Addr
+ localAddr net.Addr
maxStreamID uint32 // max stream ID ever seen
authInfo credentials.AuthInfo // auth info about the connection
+ inTapHandle tap.ServerInHandle
// writableChan synchronizes write access to the transport.
// A writer acquires the write lock by receiving a value on writableChan
// and releases it by sending on writableChan.
@@ -91,12 +96,13 @@
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
// returned if something goes wrong.
-func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo) (_ ServerTransport, err error) {
+func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) {
framer := newFramer(conn)
// Send initial settings as connection preface to client.
var settings []http2.Setting
// TODO(zhaoq): Have a better way to signal "no limit" because 0 is
// permitted in the HTTP2 spec.
+ maxStreams := config.MaxStreams
if maxStreams == 0 {
maxStreams = math.MaxUint32
} else {
@@ -122,11 +128,14 @@
var buf bytes.Buffer
t := &http2Server{
conn: conn,
- authInfo: authInfo,
+ remoteAddr: conn.RemoteAddr(),
+ localAddr: conn.LocalAddr(),
+ authInfo: config.AuthInfo,
framer: framer,
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
maxStreams: maxStreams,
+ inTapHandle: config.InTapHandle,
controlBuf: newRecvBuffer(),
fc: &inFlow{limit: initialConnWindowSize},
sendQuotaPool: newQuotaPool(defaultWindowSize),
@@ -142,7 +151,7 @@
}
// operateHeader takes action on the decoded headers.
-func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) (close bool) {
+func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) {
buf := newRecvBuffer()
s := &Stream{
id: frame.Header().StreamID,
@@ -173,7 +182,7 @@
s.ctx, s.cancel = context.WithCancel(context.TODO())
}
pr := &peer.Peer{
- Addr: t.conn.RemoteAddr(),
+ Addr: t.remoteAddr,
}
// Attach Auth info if there is any.
if t.authInfo != nil {
@@ -195,6 +204,18 @@
}
s.recvCompress = state.encoding
s.method = state.method
+ if t.inTapHandle != nil {
+ var err error
+ info := &tap.Info{
+ FullMethodName: state.method,
+ }
+ s.ctx, err = t.inTapHandle(s.ctx, info)
+ if err != nil {
+ // TODO: Log the real error.
+ t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
+ return
+ }
+ }
t.mu.Lock()
if t.state != reachable {
t.mu.Unlock()
@@ -218,13 +239,25 @@
s.windowHandler = func(n int) {
t.updateWindow(s, uint32(n))
}
+ s.ctx = traceCtx(s.ctx, s.method)
+ if stats.On() {
+ inHeader := &stats.InHeader{
+ FullMethod: s.method,
+ RemoteAddr: t.remoteAddr,
+ LocalAddr: t.localAddr,
+ Compression: s.recvCompress,
+ WireLength: int(frame.Header().Length),
+ }
+ stats.Handle(s.ctx, inHeader)
+ }
handle(s)
return
}
// HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine.
-func (t *http2Server) HandleStreams(handle func(*Stream)) {
+// traceCtx attaches trace to ctx and returns the new context.
+func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) {
// Check the validity of client preface.
preface := make([]byte, len(clientPreface))
if _, err := io.ReadFull(t.conn, preface); err != nil {
@@ -279,7 +312,7 @@
}
switch frame := frame.(type) {
case *http2.MetaHeadersFrame:
- if t.operateHeaders(frame, handle) {
+ if t.operateHeaders(frame, handle, traceCtx) {
t.Close()
break
}
@@ -492,9 +525,16 @@
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
}
}
+ bufLen := t.hBuf.Len()
if err := t.writeHeaders(s, t.hBuf, false); err != nil {
return err
}
+ if stats.On() {
+ outHeader := &stats.OutHeader{
+ WireLength: bufLen,
+ }
+ stats.Handle(s.Context(), outHeader)
+ }
t.writableChan <- 0
return nil
}
@@ -547,10 +587,17 @@
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
}
}
+ bufLen := t.hBuf.Len()
if err := t.writeHeaders(s, t.hBuf, true); err != nil {
t.Close()
return err
}
+ if stats.On() {
+ outTrailer := &stats.OutTrailer{
+ WireLength: bufLen,
+ }
+ stats.Handle(s.Context(), outTrailer)
+ }
t.closeStream(s)
t.writableChan <- 0
return nil
@@ -767,7 +814,7 @@
}
func (t *http2Server) RemoteAddr() net.Addr {
- return t.conn.RemoteAddr()
+ return t.remoteAddr
}
func (t *http2Server) Drain() {
diff --git a/transport/transport.go b/transport/transport.go
index 413f749..4726bb2 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -45,10 +45,10 @@
"sync"
"golang.org/x/net/context"
- "golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/tap"
)
// recvMsg represents the received msg from the transport. All transport
@@ -167,6 +167,11 @@
id uint32
// nil for client side Stream.
st ServerTransport
+ // clientStatsCtx keeps the user context for stats handling.
+ // It's only valid on client side. Server side stats context is same as s.ctx.
+ // All client side stats collection should use the clientStatsCtx (instead of the stream context)
+ // so that all the generated stats for a particular RPC can be associated in the processing phase.
+ clientStatsCtx context.Context
// ctx is the associated context of the stream.
ctx context.Context
// cancel is always nil for client side Stream.
@@ -266,11 +271,6 @@
return s.ctx
}
-// TraceContext recreates the context of s with a trace.Trace.
-func (s *Stream) TraceContext(tr trace.Trace) {
- s.ctx = trace.NewContext(s.ctx, tr)
-}
-
// Method returns the method for the stream.
func (s *Stream) Method() string {
return s.method
@@ -355,10 +355,17 @@
draining
)
+// ServerConfig consists of all the configurations to establish a server transport.
+type ServerConfig struct {
+ MaxStreams uint32
+ AuthInfo credentials.AuthInfo
+ InTapHandle tap.ServerInHandle
+}
+
// NewServerTransport creates a ServerTransport with conn or non-nil error
// if it fails.
-func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo) (ServerTransport, error) {
- return newHTTP2Server(conn, maxStreams, authInfo)
+func NewServerTransport(protocol string, conn net.Conn, config *ServerConfig) (ServerTransport, error) {
+ return newHTTP2Server(conn, config)
}
// ConnectOptions covers all relevant options for communicating with the server.
@@ -367,6 +374,8 @@
UserAgent string
// Dialer specifies how to dial a network address.
Dialer func(context.Context, string) (net.Conn, error)
+ // FailOnNonTempDialError specifies if gRPC fails on non-temporary dial errors.
+ FailOnNonTempDialError bool
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
PerRPCCredentials []credentials.PerRPCCredentials
// TransportCredentials stores the Authenticator required to setup a client connection.
@@ -466,7 +475,7 @@
// Write methods for a given Stream will be called serially.
type ServerTransport interface {
// HandleStreams receives incoming streams using the given handler.
- HandleStreams(func(*Stream))
+ HandleStreams(func(*Stream), func(context.Context, string) context.Context)
// WriteHeader sends the header metadata for the given stream.
// WriteHeader may not be called on all streams.
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 81320e6..1ca6eb1 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -179,7 +179,10 @@
if err != nil {
return
}
- transport, err := NewServerTransport("http2", conn, maxStreams, nil)
+ config := &ServerConfig{
+ MaxStreams: maxStreams,
+ }
+ transport, err := NewServerTransport("http2", conn, config)
if err != nil {
return
}
@@ -194,22 +197,33 @@
h := &testStreamHandler{transport.(*http2Server)}
switch ht {
case suspended:
- go transport.HandleStreams(h.handleStreamSuspension)
+ go transport.HandleStreams(h.handleStreamSuspension,
+ func(ctx context.Context, method string) context.Context {
+ return ctx
+ })
case misbehaved:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamMisbehave(t, s)
+ }, func(ctx context.Context, method string) context.Context {
+ return ctx
})
case encodingRequiredStatus:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamEncodingRequiredStatus(t, s)
+ }, func(ctx context.Context, method string) context.Context {
+ return ctx
})
case invalidHeaderField:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamInvalidHeaderField(t, s)
+ }, func(ctx context.Context, method string) context.Context {
+ return ctx
})
default:
go transport.HandleStreams(func(s *Stream) {
go h.handleStream(t, s)
+ }, func(ctx context.Context, method string) context.Context {
+ return ctx
})
}
}