Calling handleRPC with context derived from the original (#1227)
* Calling handleRPC with different context derived from the original context
* change comment for tagRPC and stats fields
diff --git a/call.go b/call.go
index b393731..0eb5f5c 100644
--- a/call.go
+++ b/call.go
@@ -182,7 +182,7 @@
ctx = newContextWithRPCInfo(ctx)
sh := cc.dopts.copts.StatsHandler
if sh != nil {
- ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
+ ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
begin := &stats.Begin{
Client: true,
BeginTime: time.Now(),
diff --git a/stats/handlers.go b/stats/handlers.go
index 26e1a8e..5fdce2f 100644
--- a/stats/handlers.go
+++ b/stats/handlers.go
@@ -45,19 +45,22 @@
RemoteAddr net.Addr
// LocalAddr is the local address of the corresponding connection.
LocalAddr net.Addr
- // TODO add QOS related fields.
}
// RPCTagInfo defines the relevant information needed by RPC context tagger.
type RPCTagInfo struct {
// FullMethodName is the RPC method in the format of /package.service/method.
FullMethodName string
+ // FailFast indicates if this RPC is failfast.
+ // This field is only valid on client side, it's always false on server side.
+ FailFast bool
}
// Handler defines the interface for the related stats handling (e.g., RPCs, connections).
type Handler interface {
// TagRPC can attach some information to the given context.
- // The returned context is used in the rest lifetime of the RPC.
+ // The context used for the rest lifetime of the RPC will be derived from
+ // the returned context.
TagRPC(context.Context, *RPCTagInfo) context.Context
// HandleRPC processes the RPC stats.
HandleRPC(context.Context, RPCStats)
diff --git a/stats/stats.go b/stats/stats.go
index 75bdd81..6c406c7 100644
--- a/stats/stats.go
+++ b/stats/stats.go
@@ -86,13 +86,13 @@
func (s *InPayload) isRPCStats() {}
// InHeader contains stats when a header is received.
-// FullMethod, addresses and Compression are only valid if Client is false.
type InHeader struct {
// Client is true if this InHeader is from client side.
Client bool
// WireLength is the wire length of header.
WireLength int
+ // The following fields are valid only if Client is false.
// FullMethod is the full RPC method string, i.e., /package.service/method.
FullMethod string
// RemoteAddr is the remote address of the corresponding connection.
@@ -143,13 +143,13 @@
func (s *OutPayload) isRPCStats() {}
// OutHeader contains stats when a header is sent.
-// FullMethod, addresses and Compression are only valid if Client is true.
type OutHeader struct {
// Client is true if this OutHeader is from client side.
Client bool
// WireLength is the wire length of header.
WireLength int
+ // The following fields are valid only if Client is true.
// FullMethod is the full RPC method string, i.e., /package.service/method.
FullMethod string
// RemoteAddr is the remote address of the corresponding connection.
diff --git a/stats/stats_test.go b/stats/stats_test.go
index c770c15..35d60a4 100644
--- a/stats/stats_test.go
+++ b/stats/stats_test.go
@@ -800,13 +800,14 @@
t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
}
- var rpcctx context.Context
+ var tagInfoInCtx *stats.RPCTagInfo
for i := 0; i < len(got); i++ {
if _, ok := got[i].s.(stats.RPCStats); ok {
- if rpcctx != nil && got[i].ctx != rpcctx {
- t.Fatalf("got different contexts with stats %T", got[i].s)
+ tagInfoInCtxNew, _ := got[i].ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo)
+ if tagInfoInCtx != nil && tagInfoInCtx != tagInfoInCtxNew {
+ t.Fatalf("got context containing different tagInfo with stats %T", got[i].s)
}
- rpcctx = got[i].ctx
+ tagInfoInCtx = tagInfoInCtxNew
}
}
diff --git a/stream.go b/stream.go
index 4204967..2576433 100644
--- a/stream.go
+++ b/stream.go
@@ -154,7 +154,7 @@
ctx = newContextWithRPCInfo(ctx)
sh := cc.dopts.copts.StatsHandler
if sh != nil {
- ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
+ ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
begin := &stats.Begin{
Client: true,
BeginTime: time.Now(),
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 736a4b3..80583ab 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -334,7 +334,6 @@
if t.authInfo != nil {
pr.AuthInfo = t.authInfo
}
- userCtx := ctx
ctx = peer.NewContext(ctx, pr)
authData := make(map[string]string)
for _, c := range t.creds {
@@ -401,7 +400,6 @@
return nil, ErrConnClosing
}
s := t.newStream(ctx, callHdr)
- s.clientStatsCtx = userCtx
t.activeStreams[s.id] = s
// If the number of active streams change from 0 to 1, then check if keepalive
// has gone dormant. If so, wake it up.
@@ -514,7 +512,7 @@
LocalAddr: t.localAddr,
Compression: callHdr.SendCompress,
}
- t.statsHandler.HandleRPC(s.clientStatsCtx, outHeader)
+ t.statsHandler.HandleRPC(s.ctx, outHeader)
}
t.writableChan <- 0
return s, nil
@@ -993,13 +991,13 @@
Client: true,
WireLength: int(frame.Header().Length),
}
- t.statsHandler.HandleRPC(s.clientStatsCtx, inHeader)
+ t.statsHandler.HandleRPC(s.ctx, inHeader)
} else {
inTrailer := &stats.InTrailer{
Client: true,
WireLength: int(frame.Header().Length),
}
- t.statsHandler.HandleRPC(s.clientStatsCtx, inTrailer)
+ t.statsHandler.HandleRPC(s.ctx, inTrailer)
}
}
}()
diff --git a/transport/transport.go b/transport/transport.go
index c22333c..4bd4dc4 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -171,11 +171,6 @@
id uint32
// nil for client side Stream.
st ServerTransport
- // clientStatsCtx keeps the user context for stats handling.
- // It's only valid on client side. Server side stats context is same as s.ctx.
- // All client side stats collection should use the clientStatsCtx (instead of the stream context)
- // so that all the generated stats for a particular RPC can be associated in the processing phase.
- clientStatsCtx context.Context
// ctx is the associated context of the stream.
ctx context.Context
// cancel is always nil for client side Stream.