add stats tagger APIs and connection stats. (#992)
* add stats.tagger APIs and connection stats.
* fix comments
use ac.ctx in http2client
change name and comments
small fixes stats_tests
* add a TODO to ConnTagInfo
* rename handle to handleRPC
* modify stats comments
diff --git a/call.go b/call.go
index 5d9214d..fc8e18a 100644
--- a/call.go
+++ b/call.go
@@ -82,7 +82,7 @@
if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK {
// TODO in the current implementation, inTrailer may be handled before inPayload in some cases.
// Fix the order if necessary.
- stats.Handle(ctx, inPayload)
+ stats.HandleRPC(ctx, inPayload)
}
c.trailerMD = stream.Trailer()
return nil
@@ -121,7 +121,7 @@
err = t.Write(stream, outBuf, opts)
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
- stats.Handle(ctx, outPayload)
+ stats.HandleRPC(ctx, outPayload)
}
// t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
@@ -172,12 +172,13 @@
}()
}
if stats.On() {
+ ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
begin := &stats.Begin{
Client: true,
BeginTime: time.Now(),
FailFast: c.failFast,
}
- stats.Handle(ctx, begin)
+ stats.HandleRPC(ctx, begin)
}
defer func() {
if stats.On() {
@@ -186,7 +187,7 @@
EndTime: time.Now(),
Error: e,
}
- stats.Handle(ctx, end)
+ stats.HandleRPC(ctx, end)
}
}()
topts := &transport.Options{
diff --git a/server.go b/server.go
index 3af001a..22aa33b 100644
--- a/server.go
+++ b/server.go
@@ -583,7 +583,7 @@
err = t.Write(stream, p, opts)
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
- stats.Handle(stream.Context(), outPayload)
+ stats.HandleRPC(stream.Context(), outPayload)
}
return err
}
@@ -593,7 +593,7 @@
begin := &stats.Begin{
BeginTime: time.Now(),
}
- stats.Handle(stream.Context(), begin)
+ stats.HandleRPC(stream.Context(), begin)
}
defer func() {
if stats.On() {
@@ -603,7 +603,7 @@
if err != nil && err != io.EOF {
end.Error = toRPCErr(err)
}
- stats.Handle(stream.Context(), end)
+ stats.HandleRPC(stream.Context(), end)
}
}()
if trInfo != nil {
@@ -698,7 +698,7 @@
inPayload.Payload = v
inPayload.Data = req
inPayload.Length = len(req)
- stats.Handle(stream.Context(), inPayload)
+ stats.HandleRPC(stream.Context(), inPayload)
}
if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
@@ -759,7 +759,7 @@
begin := &stats.Begin{
BeginTime: time.Now(),
}
- stats.Handle(stream.Context(), begin)
+ stats.HandleRPC(stream.Context(), begin)
}
defer func() {
if stats.On() {
@@ -769,7 +769,7 @@
if err != nil && err != io.EOF {
end.Error = toRPCErr(err)
}
- stats.Handle(stream.Context(), end)
+ stats.HandleRPC(stream.Context(), end)
}
}()
if s.opts.cp != nil {
diff --git a/stats/handlers.go b/stats/handlers.go
new file mode 100644
index 0000000..d41c524
--- /dev/null
+++ b/stats/handlers.go
@@ -0,0 +1,152 @@
+/*
+ *
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package stats
+
+import (
+ "net"
+ "sync/atomic"
+
+ "golang.org/x/net/context"
+ "google.golang.org/grpc/grpclog"
+)
+
+// ConnTagInfo defines the relevant information needed by connection context tagger.
+type ConnTagInfo struct {
+ // RemoteAddr is the remote address of the corresponding connection.
+ RemoteAddr net.Addr
+ // LocalAddr is the local address of the corresponding connection.
+ LocalAddr net.Addr
+ // TODO add QOS related fields.
+}
+
+// RPCTagInfo defines the relevant information needed by RPC context tagger.
+type RPCTagInfo struct {
+ // FullMethodName is the RPC method in the format of /package.service/method.
+ FullMethodName string
+}
+
+var (
+ on = new(int32)
+ rpcHandler func(context.Context, RPCStats)
+ connHandler func(context.Context, ConnStats)
+ connTagger func(context.Context, *ConnTagInfo) context.Context
+ rpcTagger func(context.Context, *RPCTagInfo) context.Context
+)
+
+// HandleRPC processes the RPC stats using the rpc handler registered by the user.
+func HandleRPC(ctx context.Context, s RPCStats) {
+ if rpcHandler == nil {
+ return
+ }
+ rpcHandler(ctx, s)
+}
+
+// RegisterRPCHandler registers the user handler function for RPC stats processing.
+// It should be called only once. The later call will overwrite the former value if it is called multiple times.
+// This handler function will be called to process the rpc stats.
+func RegisterRPCHandler(f func(context.Context, RPCStats)) {
+ rpcHandler = f
+}
+
+// HandleConn processes the stats using the call back function registered by user.
+func HandleConn(ctx context.Context, s ConnStats) {
+ if connHandler == nil {
+ return
+ }
+ connHandler(ctx, s)
+}
+
+// RegisterConnHandler registers the user handler function for conn stats.
+// It should be called only once. The later call will overwrite the former value if it is called multiple times.
+// This handler function will be called to process the conn stats.
+func RegisterConnHandler(f func(context.Context, ConnStats)) {
+ connHandler = f
+}
+
+// TagConn calls user registered connection context tagger.
+func TagConn(ctx context.Context, info *ConnTagInfo) context.Context {
+ if connTagger == nil {
+ return ctx
+ }
+ return connTagger(ctx, info)
+}
+
+// RegisterConnTagger registers the user connection context tagger function.
+// The connection context tagger can attach some information to the given context.
+// The returned context will be used for stats handling.
+// For conn stats handling, the context used in connHandler for this
+// connection will be derived from the context returned.
+// For RPC stats handling,
+// - On server side, the context used in rpcHandler for all RPCs on this
+// connection will be derived from the context returned.
+// - On client side, the context is not derived from the context returned.
+func RegisterConnTagger(t func(context.Context, *ConnTagInfo) context.Context) {
+ connTagger = t
+}
+
+// TagRPC calls the user registered RPC context tagger.
+func TagRPC(ctx context.Context, info *RPCTagInfo) context.Context {
+ if rpcTagger == nil {
+ return ctx
+ }
+ return rpcTagger(ctx, info)
+}
+
+// RegisterRPCTagger registers the user RPC context tagger function.
+// The RPC context tagger can attach some information to the given context.
+// The context used in stats rpcHandler for this RPC will be derived from the
+// context returned.
+func RegisterRPCTagger(t func(context.Context, *RPCTagInfo) context.Context) {
+ rpcTagger = t
+}
+
+// Start starts the stats collection and processing if there is a registered stats handle.
+func Start() {
+ if rpcHandler == nil && connHandler == nil {
+ grpclog.Println("rpcHandler and connHandler are both nil when starting stats. Stats is not started")
+ return
+ }
+ atomic.StoreInt32(on, 1)
+}
+
+// Stop stops the stats collection and processing.
+// Stop does not unregister the handlers.
+func Stop() {
+ atomic.StoreInt32(on, 0)
+}
+
+// On indicates whether the stats collection and processing is on.
+func On() bool {
+ return atomic.CompareAndSwapInt32(on, 1, 1)
+}
diff --git a/stats/stats.go b/stats/stats.go
index 4b030d9..a82448a 100644
--- a/stats/stats.go
+++ b/stats/stats.go
@@ -38,16 +38,12 @@
import (
"net"
- "sync/atomic"
"time"
-
- "golang.org/x/net/context"
- "google.golang.org/grpc/grpclog"
)
// RPCStats contains stats information about RPCs.
-// All stats types in this package implements this interface.
type RPCStats interface {
+ isRPCStats()
// IsClient returns true if this RPCStats is from client side.
IsClient() bool
}
@@ -66,6 +62,8 @@
// IsClient indicates if this is from client side.
func (s *Begin) IsClient() bool { return s.Client }
+func (s *Begin) isRPCStats() {}
+
// InPayload contains the information for an incoming payload.
type InPayload struct {
// Client is true if this InPayload is from client side.
@@ -85,6 +83,8 @@
// IsClient indicates if this is from client side.
func (s *InPayload) IsClient() bool { return s.Client }
+func (s *InPayload) isRPCStats() {}
+
// InHeader contains stats when a header is received.
// FullMethod, addresses and Compression are only valid if Client is false.
type InHeader struct {
@@ -106,6 +106,8 @@
// IsClient indicates if this is from client side.
func (s *InHeader) IsClient() bool { return s.Client }
+func (s *InHeader) isRPCStats() {}
+
// InTrailer contains stats when a trailer is received.
type InTrailer struct {
// Client is true if this InTrailer is from client side.
@@ -117,6 +119,8 @@
// IsClient indicates if this is from client side.
func (s *InTrailer) IsClient() bool { return s.Client }
+func (s *InTrailer) isRPCStats() {}
+
// OutPayload contains the information for an outgoing payload.
type OutPayload struct {
// Client is true if this OutPayload is from client side.
@@ -136,6 +140,8 @@
// IsClient indicates if this is from client side.
func (s *OutPayload) IsClient() bool { return s.Client }
+func (s *OutPayload) isRPCStats() {}
+
// OutHeader contains stats when a header is sent.
// FullMethod, addresses and Compression are only valid if Client is true.
type OutHeader struct {
@@ -157,6 +163,8 @@
// IsClient indicates if this is from client side.
func (s *OutHeader) IsClient() bool { return s.Client }
+func (s *OutHeader) isRPCStats() {}
+
// OutTrailer contains stats when a trailer is sent.
type OutTrailer struct {
// Client is true if this OutTrailer is from client side.
@@ -168,6 +176,8 @@
// IsClient indicates if this is from client side.
func (s *OutTrailer) IsClient() bool { return s.Client }
+func (s *OutTrailer) isRPCStats() {}
+
// End contains stats when an RPC ends.
type End struct {
// Client is true if this End is from client side.
@@ -181,39 +191,33 @@
// IsClient indicates if this is from client side.
func (s *End) IsClient() bool { return s.Client }
-var (
- on = new(int32)
- handler func(context.Context, RPCStats)
-)
+func (s *End) isRPCStats() {}
-// On indicates whether stats is started.
-func On() bool {
- return atomic.CompareAndSwapInt32(on, 1, 1)
+// ConnStats contains stats information about connections.
+type ConnStats interface {
+ isConnStats()
+ // IsClient returns true if this ConnStats is from client side.
+ IsClient() bool
}
-// Handle processes the stats using the call back function registered by user.
-func Handle(ctx context.Context, s RPCStats) {
- handler(ctx, s)
+// ConnBegin contains the stats of a connection when it is established.
+type ConnBegin struct {
+ // Client is true if this ConnBegin is from client side.
+ Client bool
}
-// RegisterHandler registers the user handler function.
-// If another handler was registered before, this new handler will overwrite the old one.
-// This handler function will be called to process the stats.
-func RegisterHandler(f func(context.Context, RPCStats)) {
- handler = f
+// IsClient indicates if this is from client side.
+func (s *ConnBegin) IsClient() bool { return s.Client }
+
+func (s *ConnBegin) isConnStats() {}
+
+// ConnEnd contains the stats of a connection when it ends.
+type ConnEnd struct {
+ // Client is true if this ConnEnd is from client side.
+ Client bool
}
-// Start starts the stats collection and reporting if there is a registered stats handle.
-func Start() {
- if handler == nil {
- grpclog.Println("handler is nil when starting stats. Stats is not started")
- return
- }
- atomic.StoreInt32(on, 1)
-}
+// IsClient indicates if this is from client side.
+func (s *ConnEnd) IsClient() bool { return s.Client }
-// Stop stops the stats collection and processing.
-// Stop does not unregister handler.
-func Stop() {
- atomic.StoreInt32(on, 0)
-}
+func (s *ConnEnd) isConnStats() {}
diff --git a/stats/stats_test.go b/stats/stats_test.go
index e904810..1761e79 100644
--- a/stats/stats_test.go
+++ b/stats/stats_test.go
@@ -49,26 +49,87 @@
testpb "google.golang.org/grpc/stats/grpc_testing"
)
+func init() {
+ grpc.EnableTracing = false
+}
+
func TestStartStop(t *testing.T) {
- stats.RegisterHandler(nil)
+ stats.RegisterRPCHandler(nil)
+ stats.RegisterConnHandler(nil)
stats.Start()
- if stats.On() != false {
+ if stats.On() {
t.Fatalf("stats.Start() with nil handler, stats.On() = true, want false")
}
- stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {})
- if stats.On() != false {
- t.Fatalf("after stats.RegisterHandler(), stats.On() = true, want false")
- }
+
+ stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {})
+ stats.RegisterConnHandler(nil)
stats.Start()
- if stats.On() != true {
- t.Fatalf("after stats.Start(_), stats.On() = false, want true")
+ if !stats.On() {
+ t.Fatalf("stats.Start() with non-nil handler, stats.On() = false, want true")
}
stats.Stop()
- if stats.On() != false {
+
+ stats.RegisterRPCHandler(nil)
+ stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {})
+ stats.Start()
+ if !stats.On() {
+ t.Fatalf("stats.Start() with non-nil conn handler, stats.On() = false, want true")
+ }
+ stats.Stop()
+
+ stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {})
+ stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {})
+ if stats.On() {
+ t.Fatalf("after stats.RegisterRPCHandler(), stats.On() = true, want false")
+ }
+ stats.Start()
+ if !stats.On() {
+ t.Fatalf("after stats.Start(_), stats.On() = false, want true")
+ }
+
+ stats.Stop()
+ if stats.On() {
t.Fatalf("after stats.Stop(), stats.On() = true, want false")
}
}
+type connCtxKey struct{}
+type rpcCtxKey struct{}
+
+func TestTagConnCtx(t *testing.T) {
+ defer stats.RegisterConnTagger(nil)
+ ctx1 := context.Background()
+ stats.RegisterConnTagger(nil)
+ ctx2 := stats.TagConn(ctx1, nil)
+ if ctx2 != ctx1 {
+ t.Fatalf("nil conn ctx tagger should not modify context, got %v; want %v", ctx2, ctx1)
+ }
+ stats.RegisterConnTagger(func(ctx context.Context, info *stats.ConnTagInfo) context.Context {
+ return context.WithValue(ctx, connCtxKey{}, "connctxvalue")
+ })
+ ctx3 := stats.TagConn(ctx1, nil)
+ if v, ok := ctx3.Value(connCtxKey{}).(string); !ok || v != "connctxvalue" {
+ t.Fatalf("got context %v; want %v", ctx3, context.WithValue(ctx1, connCtxKey{}, "connctxvalue"))
+ }
+}
+
+func TestTagRPCCtx(t *testing.T) {
+ defer stats.RegisterRPCTagger(nil)
+ ctx1 := context.Background()
+ stats.RegisterRPCTagger(nil)
+ ctx2 := stats.TagRPC(ctx1, nil)
+ if ctx2 != ctx1 {
+ t.Fatalf("nil rpc ctx tagger should not modify context, got %v; want %v", ctx2, ctx1)
+ }
+ stats.RegisterRPCTagger(func(ctx context.Context, info *stats.RPCTagInfo) context.Context {
+ return context.WithValue(ctx, rpcCtxKey{}, "rpcctxvalue")
+ })
+ ctx3 := stats.TagRPC(ctx1, nil)
+ if v, ok := ctx3.Value(rpcCtxKey{}).(string); !ok || v != "rpcctxvalue" {
+ t.Fatalf("got context %v; want %v", ctx3, context.WithValue(ctx1, rpcCtxKey{}, "rpcctxvalue"))
+ }
+}
+
var (
// For headers:
testMetadata = metadata.MD{
@@ -242,10 +303,6 @@
ctx := metadata.NewContext(context.Background(), testMetadata)
resp, err = tc.UnaryCall(ctx, req, grpc.FailFast(c.failfast))
- if err != nil {
- return req, resp, err
- }
-
return req, resp, err
}
@@ -303,7 +360,7 @@
type gotData struct {
ctx context.Context
client bool
- s stats.RPCStats
+ s interface{} // This could be RPCStats or ConnStats.
}
const (
@@ -315,6 +372,8 @@
outPayload
outHeader
outTrailer
+ connbegin
+ connend
)
func checkBegin(t *testing.T, d *gotData, e *expectedData) {
@@ -363,6 +422,24 @@
if st.Compression != e.compression {
t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
}
+
+ if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok {
+ if connInfo.RemoteAddr != st.RemoteAddr {
+ t.Fatalf("connInfo.RemoteAddr = %v, want %v", connInfo.RemoteAddr, st.RemoteAddr)
+ }
+ if connInfo.LocalAddr != st.LocalAddr {
+ t.Fatalf("connInfo.LocalAddr = %v, want %v", connInfo.LocalAddr, st.LocalAddr)
+ }
+ } else {
+ t.Fatalf("got context %v, want one with connCtxKey", d.ctx)
+ }
+ if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
+ if rpcInfo.FullMethodName != st.FullMethod {
+ t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
+ }
+ } else {
+ t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
+ }
}
}
@@ -451,11 +528,19 @@
t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
}
if st.RemoteAddr.String() != e.serverAddr {
- t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr)
+ t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr)
}
if st.Compression != e.compression {
t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
}
+
+ if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
+ if rpcInfo.FullMethodName != st.FullMethod {
+ t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
+ }
+ } else {
+ t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
+ }
}
}
@@ -546,14 +631,91 @@
}
}
-func TestServerStatsUnaryRPC(t *testing.T) {
- var got []*gotData
+func checkConnBegin(t *testing.T, d *gotData, e *expectedData) {
+ var (
+ ok bool
+ st *stats.ConnBegin
+ )
+ if st, ok = d.s.(*stats.ConnBegin); !ok {
+ t.Fatalf("got %T, want ConnBegin", d.s)
+ }
+ if d.ctx == nil {
+ t.Fatalf("d.ctx = nil, want <non-nil>")
+ }
+ st.IsClient() // TODO remove this.
+}
- stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+func checkConnEnd(t *testing.T, d *gotData, e *expectedData) {
+ var (
+ ok bool
+ st *stats.ConnEnd
+ )
+ if st, ok = d.s.(*stats.ConnEnd); !ok {
+ t.Fatalf("got %T, want ConnEnd", d.s)
+ }
+ if d.ctx == nil {
+ t.Fatalf("d.ctx = nil, want <non-nil>")
+ }
+ st.IsClient() // TODO remove this.
+}
+
+func tagConnCtx(ctx context.Context, info *stats.ConnTagInfo) context.Context {
+ return context.WithValue(ctx, connCtxKey{}, info)
+}
+
+func tagRPCCtx(ctx context.Context, info *stats.RPCTagInfo) context.Context {
+ return context.WithValue(ctx, rpcCtxKey{}, info)
+}
+
+func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
+ if len(got) != len(checkFuncs) {
+ t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
+ }
+
+ var (
+ rpcctx context.Context
+ connctx context.Context
+ )
+ for i := 0; i < len(got); i++ {
+ if _, ok := got[i].s.(stats.RPCStats); ok {
+ if rpcctx != nil && got[i].ctx != rpcctx {
+ t.Fatalf("got different contexts with stats %T", got[i].s)
+ }
+ rpcctx = got[i].ctx
+ } else {
+ if connctx != nil && got[i].ctx != connctx {
+ t.Fatalf("got different contexts with stats %T", got[i].s)
+ }
+ connctx = got[i].ctx
+ }
+ }
+
+ for i, f := range checkFuncs {
+ f(t, got[i], expect)
+ }
+}
+
+func TestServerStatsUnaryRPC(t *testing.T) {
+ var (
+ mu sync.Mutex
+ got []*gotData
+ )
+ stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {
+ mu.Lock()
+ defer mu.Unlock()
if !s.IsClient() {
got = append(got, &gotData{ctx, false, s})
}
})
+ stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {
+ mu.Lock()
+ defer mu.Unlock()
+ if !s.IsClient() {
+ got = append(got, &gotData{ctx, false, s})
+ }
+ })
+ stats.RegisterConnTagger(tagConnCtx)
+ stats.RegisterRPCTagger(tagRPCCtx)
stats.Start()
defer stats.Stop()
@@ -575,6 +737,7 @@
}
checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
+ checkConnBegin,
checkInHeader,
checkBegin,
checkInPayload,
@@ -582,30 +745,33 @@
checkOutPayload,
checkOutTrailer,
checkEnd,
+ checkConnEnd,
}
- if len(got) != len(checkFuncs) {
- t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
- }
-
- for i := 0; i < len(got)-1; i++ {
- if got[i].ctx != got[i+1].ctx {
- t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s)
- }
- }
-
- for i, f := range checkFuncs {
- f(t, got[i], expect)
- }
+ checkServerStats(t, got, expect, checkFuncs)
}
func TestServerStatsUnaryRPCError(t *testing.T) {
- var got []*gotData
- stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+ var (
+ mu sync.Mutex
+ got []*gotData
+ )
+ stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {
+ mu.Lock()
+ defer mu.Unlock()
if !s.IsClient() {
got = append(got, &gotData{ctx, false, s})
}
})
+ stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {
+ mu.Lock()
+ defer mu.Unlock()
+ if !s.IsClient() {
+ got = append(got, &gotData{ctx, false, s})
+ }
+ })
+ stats.RegisterConnTagger(tagConnCtx)
+ stats.RegisterRPCTagger(tagRPCCtx)
stats.Start()
defer stats.Stop()
@@ -628,36 +794,40 @@
}
checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
+ checkConnBegin,
checkInHeader,
checkBegin,
checkInPayload,
checkOutHeader,
checkOutTrailer,
checkEnd,
+ checkConnEnd,
}
- if len(got) != len(checkFuncs) {
- t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
- }
-
- for i := 0; i < len(got)-1; i++ {
- if got[i].ctx != got[i+1].ctx {
- t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s)
- }
- }
-
- for i, f := range checkFuncs {
- f(t, got[i], expect)
- }
+ checkServerStats(t, got, expect, checkFuncs)
}
func TestServerStatsStreamingRPC(t *testing.T) {
- var got []*gotData
- stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+ var (
+ mu sync.Mutex
+ got []*gotData
+ )
+ stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {
+ mu.Lock()
+ defer mu.Unlock()
if !s.IsClient() {
got = append(got, &gotData{ctx, false, s})
}
})
+ stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {
+ mu.Lock()
+ defer mu.Unlock()
+ if !s.IsClient() {
+ got = append(got, &gotData{ctx, false, s})
+ }
+ })
+ stats.RegisterConnTagger(tagConnCtx)
+ stats.RegisterRPCTagger(tagRPCCtx)
stats.Start()
defer stats.Stop()
@@ -681,6 +851,7 @@
}
checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
+ checkConnBegin,
checkInHeader,
checkBegin,
checkOutHeader,
@@ -692,31 +863,36 @@
for i := 0; i < count; i++ {
checkFuncs = append(checkFuncs, ioPayFuncs...)
}
- checkFuncs = append(checkFuncs, checkOutTrailer, checkEnd)
+ checkFuncs = append(checkFuncs,
+ checkOutTrailer,
+ checkEnd,
+ checkConnEnd,
+ )
- if len(got) != len(checkFuncs) {
- t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
- }
-
- for i := 0; i < len(got)-1; i++ {
- if got[i].ctx != got[i+1].ctx {
- t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s)
- }
- }
-
- for i, f := range checkFuncs {
- f(t, got[i], expect)
- }
+ checkServerStats(t, got, expect, checkFuncs)
}
func TestServerStatsStreamingRPCError(t *testing.T) {
- var got []*gotData
-
- stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+ var (
+ mu sync.Mutex
+ got []*gotData
+ )
+ stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {
+ mu.Lock()
+ defer mu.Unlock()
if !s.IsClient() {
got = append(got, &gotData{ctx, false, s})
}
})
+ stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {
+ mu.Lock()
+ defer mu.Unlock()
+ if !s.IsClient() {
+ got = append(got, &gotData{ctx, false, s})
+ }
+ })
+ stats.RegisterConnTagger(tagConnCtx)
+ stats.RegisterRPCTagger(tagRPCCtx)
stats.Start()
defer stats.Stop()
@@ -741,27 +917,17 @@
}
checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
+ checkConnBegin,
checkInHeader,
checkBegin,
checkOutHeader,
checkInPayload,
checkOutTrailer,
checkEnd,
+ checkConnEnd,
}
- if len(got) != len(checkFuncs) {
- t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
- }
-
- for i := 0; i < len(got)-1; i++ {
- if got[i].ctx != got[i+1].ctx {
- t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s)
- }
- }
-
- for i, f := range checkFuncs {
- f(t, got[i], expect)
- }
+ checkServerStats(t, got, expect, checkFuncs)
}
type checkFuncWithCount struct {
@@ -778,9 +944,21 @@
t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
}
- for i := 0; i < len(got)-1; i++ {
- if got[i].ctx != got[i+1].ctx {
- t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s)
+ var (
+ rpcctx context.Context
+ connctx context.Context
+ )
+ for i := 0; i < len(got); i++ {
+ if _, ok := got[i].s.(stats.RPCStats); ok {
+ if rpcctx != nil && got[i].ctx != rpcctx {
+ t.Fatalf("got different contexts with stats %T", got[i].s)
+ }
+ rpcctx = got[i].ctx
+ } else {
+ if connctx != nil && got[i].ctx != connctx {
+ t.Fatalf("got different contexts with stats %T", got[i].s)
+ }
+ connctx = got[i].ctx
}
}
@@ -788,48 +966,60 @@
switch s.s.(type) {
case *stats.Begin:
if checkFuncs[begin].c <= 0 {
- t.Fatalf("unexpected stats: %T", s)
+ t.Fatalf("unexpected stats: %T", s.s)
}
checkFuncs[begin].f(t, s, expect)
checkFuncs[begin].c--
case *stats.OutHeader:
if checkFuncs[outHeader].c <= 0 {
- t.Fatalf("unexpected stats: %T", s)
+ t.Fatalf("unexpected stats: %T", s.s)
}
checkFuncs[outHeader].f(t, s, expect)
checkFuncs[outHeader].c--
case *stats.OutPayload:
if checkFuncs[outPayload].c <= 0 {
- t.Fatalf("unexpected stats: %T", s)
+ t.Fatalf("unexpected stats: %T", s.s)
}
checkFuncs[outPayload].f(t, s, expect)
checkFuncs[outPayload].c--
case *stats.InHeader:
if checkFuncs[inHeader].c <= 0 {
- t.Fatalf("unexpected stats: %T", s)
+ t.Fatalf("unexpected stats: %T", s.s)
}
checkFuncs[inHeader].f(t, s, expect)
checkFuncs[inHeader].c--
case *stats.InPayload:
if checkFuncs[inPayload].c <= 0 {
- t.Fatalf("unexpected stats: %T", s)
+ t.Fatalf("unexpected stats: %T", s.s)
}
checkFuncs[inPayload].f(t, s, expect)
checkFuncs[inPayload].c--
case *stats.InTrailer:
if checkFuncs[inTrailer].c <= 0 {
- t.Fatalf("unexpected stats: %T", s)
+ t.Fatalf("unexpected stats: %T", s.s)
}
checkFuncs[inTrailer].f(t, s, expect)
checkFuncs[inTrailer].c--
case *stats.End:
if checkFuncs[end].c <= 0 {
- t.Fatalf("unexpected stats: %T", s)
+ t.Fatalf("unexpected stats: %T", s.s)
}
checkFuncs[end].f(t, s, expect)
checkFuncs[end].c--
+ case *stats.ConnBegin:
+ if checkFuncs[connbegin].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[connbegin].f(t, s, expect)
+ checkFuncs[connbegin].c--
+ case *stats.ConnEnd:
+ if checkFuncs[connend].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[connend].f(t, s, expect)
+ checkFuncs[connend].c--
default:
- t.Fatalf("unexpected stats: %T", s)
+ t.Fatalf("unexpected stats: %T", s.s)
}
}
}
@@ -839,13 +1029,22 @@
mu sync.Mutex
got []*gotData
)
- stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+ stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {
mu.Lock()
defer mu.Unlock()
if s.IsClient() {
got = append(got, &gotData{ctx, true, s})
}
})
+ stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {
+ mu.Lock()
+ defer mu.Unlock()
+ if s.IsClient() {
+ got = append(got, &gotData{ctx, true, s})
+ }
+ })
+ stats.RegisterConnTagger(tagConnCtx)
+ stats.RegisterRPCTagger(tagRPCCtx)
stats.Start()
defer stats.Stop()
@@ -869,6 +1068,7 @@
}
checkFuncs := map[int]*checkFuncWithCount{
+ connbegin: {checkConnBegin, 1},
begin: {checkBegin, 1},
outHeader: {checkOutHeader, 1},
outPayload: {checkOutPayload, 1},
@@ -876,6 +1076,7 @@
inPayload: {checkInPayload, 1},
inTrailer: {checkInTrailer, 1},
end: {checkEnd, 1},
+ connend: {checkConnEnd, 1},
}
checkClientStats(t, got, expect, checkFuncs)
@@ -886,13 +1087,22 @@
mu sync.Mutex
got []*gotData
)
- stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+ stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {
mu.Lock()
defer mu.Unlock()
if s.IsClient() {
got = append(got, &gotData{ctx, true, s})
}
})
+ stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {
+ mu.Lock()
+ defer mu.Unlock()
+ if s.IsClient() {
+ got = append(got, &gotData{ctx, true, s})
+ }
+ })
+ stats.RegisterConnTagger(tagConnCtx)
+ stats.RegisterRPCTagger(tagRPCCtx)
stats.Start()
defer stats.Stop()
@@ -917,12 +1127,14 @@
}
checkFuncs := map[int]*checkFuncWithCount{
+ connbegin: {checkConnBegin, 1},
begin: {checkBegin, 1},
outHeader: {checkOutHeader, 1},
outPayload: {checkOutPayload, 1},
inHeader: {checkInHeader, 1},
inTrailer: {checkInTrailer, 1},
end: {checkEnd, 1},
+ connend: {checkConnEnd, 1},
}
checkClientStats(t, got, expect, checkFuncs)
@@ -933,14 +1145,22 @@
mu sync.Mutex
got []*gotData
)
- stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+ stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {
mu.Lock()
defer mu.Unlock()
if s.IsClient() {
- // t.Logf(" == %T %v", s, s.IsClient())
got = append(got, &gotData{ctx, true, s})
}
})
+ stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {
+ mu.Lock()
+ defer mu.Unlock()
+ if s.IsClient() {
+ got = append(got, &gotData{ctx, true, s})
+ }
+ })
+ stats.RegisterConnTagger(tagConnCtx)
+ stats.RegisterRPCTagger(tagRPCCtx)
stats.Start()
defer stats.Stop()
@@ -966,6 +1186,7 @@
}
checkFuncs := map[int]*checkFuncWithCount{
+ connbegin: {checkConnBegin, 1},
begin: {checkBegin, 1},
outHeader: {checkOutHeader, 1},
outPayload: {checkOutPayload, count},
@@ -973,6 +1194,7 @@
inPayload: {checkInPayload, count},
inTrailer: {checkInTrailer, 1},
end: {checkEnd, 1},
+ connend: {checkConnEnd, 1},
}
checkClientStats(t, got, expect, checkFuncs)
@@ -983,13 +1205,22 @@
mu sync.Mutex
got []*gotData
)
- stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+ stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {
mu.Lock()
defer mu.Unlock()
if s.IsClient() {
got = append(got, &gotData{ctx, true, s})
}
})
+ stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {
+ mu.Lock()
+ defer mu.Unlock()
+ if s.IsClient() {
+ got = append(got, &gotData{ctx, true, s})
+ }
+ })
+ stats.RegisterConnTagger(tagConnCtx)
+ stats.RegisterRPCTagger(tagRPCCtx)
stats.Start()
defer stats.Stop()
@@ -1016,12 +1247,14 @@
}
checkFuncs := map[int]*checkFuncWithCount{
+ connbegin: {checkConnBegin, 1},
begin: {checkBegin, 1},
outHeader: {checkOutHeader, 1},
outPayload: {checkOutPayload, 1},
inHeader: {checkInHeader, 1},
inTrailer: {checkInTrailer, 1},
end: {checkEnd, 1},
+ connend: {checkConnEnd, 1},
}
checkClientStats(t, got, expect, checkFuncs)
diff --git a/stream.go b/stream.go
index 95c8acf..1bcd218 100644
--- a/stream.go
+++ b/stream.go
@@ -145,12 +145,13 @@
}()
}
if stats.On() {
+ ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
begin := &stats.Begin{
Client: true,
BeginTime: time.Now(),
FailFast: c.failFast,
}
- stats.Handle(ctx, begin)
+ stats.HandleRPC(ctx, begin)
}
defer func() {
if err != nil && stats.On() {
@@ -159,7 +160,7 @@
Client: true,
Error: err,
}
- stats.Handle(ctx, end)
+ stats.HandleRPC(ctx, end)
}
}()
gopts := BalancerGetOptions{
@@ -342,7 +343,7 @@
err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
- stats.Handle(cs.statsCtx, outPayload)
+ stats.HandleRPC(cs.statsCtx, outPayload)
}
return err
}
@@ -360,7 +361,7 @@
if err != io.EOF {
end.Error = toRPCErr(err)
}
- stats.Handle(cs.statsCtx, end)
+ stats.HandleRPC(cs.statsCtx, end)
}
}()
var inPayload *stats.InPayload
@@ -385,7 +386,7 @@
cs.mu.Unlock()
}
if inPayload != nil {
- stats.Handle(cs.statsCtx, inPayload)
+ stats.HandleRPC(cs.statsCtx, inPayload)
}
if !cs.desc.ClientStreams || cs.desc.ServerStreams {
return
@@ -565,7 +566,7 @@
}
if outPayload != nil {
outPayload.SentTime = time.Now()
- stats.Handle(ss.s.Context(), outPayload)
+ stats.HandleRPC(ss.s.Context(), outPayload)
}
return nil
}
@@ -599,7 +600,7 @@
return toRPCErr(err)
}
if inPayload != nil {
- stats.Handle(ss.s.Context(), inPayload)
+ stats.HandleRPC(ss.s.Context(), inPayload)
}
return nil
}
diff --git a/transport/http2_client.go b/transport/http2_client.go
index cbd9f32..5640aea 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -56,6 +56,7 @@
// http2Client implements the ClientTransport interface with HTTP2.
type http2Client struct {
+ ctx context.Context
target string // server name/addr
userAgent string
md interface{}
@@ -181,6 +182,7 @@
}
var buf bytes.Buffer
t := &http2Client{
+ ctx: ctx,
target: addr.Addr,
userAgent: ua,
md: addr.Metadata,
@@ -242,6 +244,16 @@
}
go t.controller()
t.writableChan <- 0
+ if stats.On() {
+ t.ctx = stats.TagConn(t.ctx, &stats.ConnTagInfo{
+ RemoteAddr: t.remoteAddr,
+ LocalAddr: t.localAddr,
+ })
+ connBegin := &stats.ConnBegin{
+ Client: true,
+ }
+ stats.HandleConn(t.ctx, connBegin)
+ }
return t, nil
}
@@ -467,7 +479,7 @@
LocalAddr: t.localAddr,
Compression: callHdr.SendCompress,
}
- stats.Handle(s.clientStatsCtx, outHeader)
+ stats.HandleRPC(s.clientStatsCtx, outHeader)
}
t.writableChan <- 0
return s, nil
@@ -547,6 +559,12 @@
s.mu.Unlock()
s.write(recvMsg{err: ErrConnClosing})
}
+ if stats.On() {
+ connEnd := &stats.ConnEnd{
+ Client: true,
+ }
+ stats.HandleConn(t.ctx, connEnd)
+ }
return
}
@@ -904,13 +922,13 @@
Client: true,
WireLength: int(frame.Header().Length),
}
- stats.Handle(s.clientStatsCtx, inHeader)
+ stats.HandleRPC(s.clientStatsCtx, inHeader)
} else {
inTrailer := &stats.InTrailer{
Client: true,
WireLength: int(frame.Header().Length),
}
- stats.Handle(s.clientStatsCtx, inTrailer)
+ stats.HandleRPC(s.clientStatsCtx, inTrailer)
}
}
}()
diff --git a/transport/http2_server.go b/transport/http2_server.go
index db9beb9..62ea303 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -60,6 +60,7 @@
// http2Server implements the ServerTransport interface with HTTP2.
type http2Server struct {
+ ctx context.Context
conn net.Conn
remoteAddr net.Addr
localAddr net.Addr
@@ -127,6 +128,7 @@
}
var buf bytes.Buffer
t := &http2Server{
+ ctx: context.Background(),
conn: conn,
remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(),
@@ -145,6 +147,14 @@
activeStreams: make(map[uint32]*Stream),
streamSendQuota: defaultWindowSize,
}
+ if stats.On() {
+ t.ctx = stats.TagConn(t.ctx, &stats.ConnTagInfo{
+ RemoteAddr: t.remoteAddr,
+ LocalAddr: t.localAddr,
+ })
+ connBegin := &stats.ConnBegin{}
+ stats.HandleConn(t.ctx, connBegin)
+ }
go t.controller()
t.writableChan <- 0
return t, nil
@@ -177,9 +187,9 @@
}
s.recvCompress = state.encoding
if state.timeoutSet {
- s.ctx, s.cancel = context.WithTimeout(context.TODO(), state.timeout)
+ s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout)
} else {
- s.ctx, s.cancel = context.WithCancel(context.TODO())
+ s.ctx, s.cancel = context.WithCancel(t.ctx)
}
pr := &peer.Peer{
Addr: t.remoteAddr,
@@ -241,6 +251,7 @@
}
s.ctx = traceCtx(s.ctx, s.method)
if stats.On() {
+ s.ctx = stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
FullMethod: s.method,
RemoteAddr: t.remoteAddr,
@@ -248,7 +259,7 @@
Compression: s.recvCompress,
WireLength: int(frame.Header().Length),
}
- stats.Handle(s.ctx, inHeader)
+ stats.HandleRPC(s.ctx, inHeader)
}
handle(s)
return
@@ -533,7 +544,7 @@
outHeader := &stats.OutHeader{
WireLength: bufLen,
}
- stats.Handle(s.Context(), outHeader)
+ stats.HandleRPC(s.Context(), outHeader)
}
t.writableChan <- 0
return nil
@@ -596,7 +607,7 @@
outTrailer := &stats.OutTrailer{
WireLength: bufLen,
}
- stats.Handle(s.Context(), outTrailer)
+ stats.HandleRPC(s.Context(), outTrailer)
}
t.closeStream(s)
t.writableChan <- 0
@@ -783,6 +794,10 @@
for _, s := range streams {
s.cancel()
}
+ if stats.On() {
+ connEnd := &stats.ConnEnd{}
+ stats.HandleConn(t.ctx, connEnd)
+ }
return
}