Adding dial options for PerRPCCredentials (#1225)
* Adding dial options for PerRPCCredentials
* Added tests for PerRPCCredentials
* Post-review updates
* post-review updates
diff --git a/call.go b/call.go
index 0eb5f5c..0843865 100644
--- a/call.go
+++ b/call.go
@@ -219,6 +219,9 @@
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
+ if c.creds != nil {
+ callHdr.Creds = c.creds
+ }
gopts := BalancerGetOptions{
BlockingWait: !c.failFast,
diff --git a/rpc_util.go b/rpc_util.go
index 31a8732..6a32afd 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -46,6 +46,7 @@
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
+ "google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
@@ -141,6 +142,7 @@
trailerMD metadata.MD
peer *peer.Peer
traceInfo traceInfo // in trace.go
+ creds credentials.PerRPCCredentials
}
var defaultCallInfo = callInfo{failFast: true}
@@ -207,6 +209,15 @@
})
}
+// PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials
+// for a call.
+func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption {
+ return beforeCall(func(c *callInfo) error {
+ c.creds = creds
+ return nil
+ })
+}
+
// The format of the payload: compressed or not?
type payloadFormat uint8
diff --git a/stream.go b/stream.go
index 2576433..ec534a0 100644
--- a/stream.go
+++ b/stream.go
@@ -132,6 +132,9 @@
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
+ if c.creds != nil {
+ callHdr.Creds = c.creds
+ }
var trInfo traceInfo
if EnableTracing {
trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 01b3e4f..ced2509 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -449,6 +449,7 @@
serverInitialConnWindowSize int32
clientInitialWindowSize int32
clientInitialConnWindowSize int32
+ perRPCCreds credentials.PerRPCCredentials
// srv and srvAddr are set once startServer is called.
srv *grpc.Server
@@ -621,6 +622,9 @@
if te.clientInitialConnWindowSize > 0 {
opts = append(opts, grpc.WithInitialConnWindowSize(te.clientInitialConnWindowSize))
}
+ if te.perRPCCreds != nil {
+ opts = append(opts, grpc.WithPerRPCCredentials(te.perRPCCreds))
+ }
var err error
te.cc, err = grpc.Dial(te.srvAddr, opts...)
if err != nil {
@@ -3984,3 +3988,120 @@
t.Fatalf("%v.CloseSend() = %v, want <nil>", stream, err)
}
}
+
+var (
+ // test authdata
+ authdata = map[string]string{
+ "test-key": "test-value",
+ "test-key2-bin": string([]byte{1, 2, 3}),
+ }
+)
+
+type testPerRPCCredentials struct{}
+
+func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
+ return authdata, nil
+}
+
+func (cr testPerRPCCredentials) RequireTransportSecurity() bool {
+ return false
+}
+
+func authHandle(ctx context.Context, info *tap.Info) (context.Context, error) {
+ md, ok := metadata.FromIncomingContext(ctx)
+ if !ok {
+ return ctx, fmt.Errorf("didn't find metadata in context")
+ }
+ for k, vwant := range authdata {
+ vgot, ok := md[k]
+ if !ok {
+ return ctx, fmt.Errorf("didn't find authdata key %v in context", k)
+ }
+ if vgot[0] != vwant {
+ return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant)
+ }
+ }
+ return ctx, nil
+}
+
+func TestPerRPCCredentialsViaDialOptions(t *testing.T) {
+ defer leakCheck(t)()
+ for _, e := range listTestEnv() {
+ testPerRPCCredentialsViaDialOptions(t, e)
+ }
+}
+
+func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) {
+ te := newTest(t, e)
+ te.tapHandle = authHandle
+ te.perRPCCreds = testPerRPCCredentials{}
+ te.startServer(&testServer{security: e.security})
+ defer te.tearDown()
+
+ cc := te.clientConn()
+ tc := testpb.NewTestServiceClient(cc)
+ if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+ t.Fatalf("Test failed. Reason: %v", err)
+ }
+}
+
+func TestPerRPCCredentialsViaCallOptions(t *testing.T) {
+ defer leakCheck(t)()
+ for _, e := range listTestEnv() {
+ testPerRPCCredentialsViaCallOptions(t, e)
+ }
+}
+
+func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) {
+ te := newTest(t, e)
+ te.tapHandle = authHandle
+ te.startServer(&testServer{security: e.security})
+ defer te.tearDown()
+
+ cc := te.clientConn()
+ tc := testpb.NewTestServiceClient(cc)
+ if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
+ t.Fatalf("Test failed. Reason: %v", err)
+ }
+}
+
+func TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) {
+ defer leakCheck(t)()
+ for _, e := range listTestEnv() {
+ testPerRPCCredentialsViaDialOptionsAndCallOptions(t, e)
+ }
+}
+
+func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) {
+ te := newTest(t, e)
+ te.perRPCCreds = testPerRPCCredentials{}
+ // When credentials are provided via both dial options and call options,
+ // we apply both sets.
+ te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) {
+ md, ok := metadata.FromIncomingContext(ctx)
+ if !ok {
+ return ctx, fmt.Errorf("couldn't find metadata in context")
+ }
+ for k, vwant := range authdata {
+ vgot, ok := md[k]
+ if !ok {
+ return ctx, fmt.Errorf("couldn't find metadata for key %v", k)
+ }
+ if len(vgot) != 2 {
+ return ctx, fmt.Errorf("len of value for key %v was %v, want 2", k, len(vgot))
+ }
+ if vgot[0] != vwant || vgot[1] != vwant {
+ return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", k, vgot, vwant, vwant)
+ }
+ }
+ return ctx, nil
+ }
+ te.startServer(&testServer{security: e.security})
+ defer te.tearDown()
+
+ cc := te.clientConn()
+ tc := testpb.NewTestServiceClient(cc)
+ if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
+ t.Fatalf("Test failed. Reason: %v", err)
+ }
+}
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 3e5ff73..7db73d3 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -101,6 +101,8 @@
// The scheme used: https if TLS is on, http otherwise.
scheme string
+ isSecure bool
+
creds []credentials.PerRPCCredentials
// Boolean to keep track of reading activity on transport.
@@ -181,7 +183,10 @@
conn.Close()
}
}(conn)
- var authInfo credentials.AuthInfo
+ var (
+ isSecure bool
+ authInfo credentials.AuthInfo
+ )
if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn)
@@ -191,6 +196,7 @@
temp := isTemporary(err)
return nil, connectionErrorf(temp, err, "transport: %v", err)
}
+ isSecure = true
}
kp := opts.KeepaliveParams
// Validate keepalive parameters.
@@ -230,6 +236,7 @@
scheme: scheme,
state: reachable,
activeStreams: make(map[uint32]*Stream),
+ isSecure: isSecure,
creds: opts.PerRPCCredentials,
maxStreams: defaultMaxStreamsClient,
streamsQuota: newQuotaPool(defaultMaxStreamsClient),
@@ -335,8 +342,12 @@
pr.AuthInfo = t.authInfo
}
ctx = peer.NewContext(ctx, pr)
- authData := make(map[string]string)
- for _, c := range t.creds {
+ var (
+ authData = make(map[string]string)
+ audience string
+ )
+ // Create an audience string only if needed.
+ if len(t.creds) > 0 || callHdr.Creds != nil {
// Construct URI required to get auth request metadata.
var port string
if pos := strings.LastIndex(t.target, ":"); pos != -1 {
@@ -347,17 +358,39 @@
}
pos := strings.LastIndex(callHdr.Method, "/")
if pos == -1 {
- return nil, streamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method)
+ pos = len(callHdr.Method)
}
- audience := "https://" + callHdr.Host + port + callHdr.Method[:pos]
+ audience = "https://" + callHdr.Host + port + callHdr.Method[:pos]
+ }
+ for _, c := range t.creds {
data, err := c.GetRequestMetadata(ctx, audience)
if err != nil {
- return nil, streamErrorf(codes.InvalidArgument, "transport: %v", err)
+ return nil, streamErrorf(codes.Internal, "transport: %v", err)
}
for k, v := range data {
+ // Capital header names are illegal in HTTP/2.
+ k = strings.ToLower(k)
authData[k] = v
}
}
+ callAuthData := make(map[string]string)
+ // Check if credentials.PerRPCCredentials were provided via call options.
+ // Note: if these credentials are provided both via dial options and call
+ // options, then both sets of credentials will be applied.
+ if callCreds := callHdr.Creds; callCreds != nil {
+ if !t.isSecure && callCreds.RequireTransportSecurity() {
+ return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure conneciton")
+ }
+ data, err := callCreds.GetRequestMetadata(ctx, audience)
+ if err != nil {
+ return nil, streamErrorf(codes.Internal, "transport: %v", err)
+ }
+ for k, v := range data {
+ // Capital header names are illegal in HTTP/2
+ k = strings.ToLower(k)
+ callAuthData[k] = v
+ }
+ }
t.mu.Lock()
if t.activeStreams == nil {
t.mu.Unlock()
@@ -435,9 +468,10 @@
}
for k, v := range authData {
- // Capital header names are illegal in HTTP/2.
- k = strings.ToLower(k)
- t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
+ t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+ }
+ for k, v := range callAuthData {
+ t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
var (
hasMD bool
diff --git a/transport/transport.go b/transport/transport.go
index 4bd4dc4..2dff7c8 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -469,6 +469,9 @@
// outbound message.
SendCompress string
+ // Creds specifies credentials.PerRPCCredentials for a call.
+ Creds credentials.PerRPCCredentials
+
// Flush indicates whether a new stream command should be sent
// to the peer without waiting for the first data. This is
// only a hint. The transport may modify the flush decision