blob: 6f50047edadc6f9ac854c721a2a7772b9ef0e39a [file] [log] [blame]
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package test
import (
"context"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
iresolver "google.golang.org/grpc/internal/resolver"
"google.golang.org/grpc/internal/serviceconfig"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
testpb "google.golang.org/grpc/test/grpc_testing"
)
type funcConfigSelector struct {
f func(iresolver.RPCInfo) *iresolver.RPCConfig
}
func (f funcConfigSelector) SelectConfig(i iresolver.RPCInfo) *iresolver.RPCConfig {
return f.f(i)
}
func (s) TestConfigSelector(t *testing.T) {
gotContextChan := testutils.NewChannelWithSize(1)
ss := &stubServer{
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
gotContextChan.SendContext(ctx, ctx)
return &testpb.Empty{}, nil
},
}
ss.r = manual.NewBuilderWithScheme("confSel")
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
ctxDeadline := time.Now().Add(10 * time.Second)
ctx, cancel := context.WithDeadline(context.Background(), ctxDeadline)
defer cancel()
longCtxDeadline := time.Now().Add(30 * time.Second)
longdeadlineCtx, cancel := context.WithDeadline(context.Background(), longCtxDeadline)
defer cancel()
shorterTimeout := 3 * time.Second
testMD := metadata.MD{"footest": []string{"bazbar"}}
mdOut := metadata.MD{"handler": []string{"value"}}
var onCommittedCalled bool
testCases := []struct {
name string
md metadata.MD
config *iresolver.RPCConfig
wantMD metadata.MD
wantDeadline time.Time
wantTimeout time.Duration
}{{
name: "basic",
md: testMD,
config: &iresolver.RPCConfig{},
wantMD: testMD,
wantDeadline: ctxDeadline,
}, {
name: "alter MD",
md: testMD,
config: &iresolver.RPCConfig{
Context: metadata.NewOutgoingContext(ctx, mdOut),
},
wantMD: mdOut,
wantDeadline: ctxDeadline,
}, {
name: "alter timeout; remove MD",
md: testMD,
config: &iresolver.RPCConfig{
Context: longdeadlineCtx, // no metadata
},
wantMD: nil,
wantDeadline: longCtxDeadline,
}, {
name: "nil config",
md: metadata.MD{},
config: nil,
wantMD: nil,
wantDeadline: ctxDeadline,
}, {
name: "alter timeout via method config; remove MD",
md: testMD,
config: &iresolver.RPCConfig{
MethodConfig: serviceconfig.MethodConfig{
Timeout: &shorterTimeout,
},
},
wantMD: nil,
wantTimeout: shorterTimeout,
}, {
name: "onCommitted callback",
md: testMD,
config: &iresolver.RPCConfig{
OnCommitted: func() {
onCommittedCalled = true
},
},
wantMD: testMD,
wantDeadline: ctxDeadline,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var gotInfo *iresolver.RPCInfo
state := iresolver.SetConfigSelector(resolver.State{
Addresses: []resolver.Address{{Addr: ss.address}},
ServiceConfig: parseCfg(ss.r, "{}"),
}, funcConfigSelector{
f: func(i iresolver.RPCInfo) *iresolver.RPCConfig {
gotInfo = &i
cfg := tc.config
if cfg != nil && cfg.Context == nil {
cfg.Context = i.Context
}
return cfg
},
})
ss.r.UpdateState(state) // Blocks until config selector is applied
onCommittedCalled = false
ctx := metadata.NewOutgoingContext(ctx, tc.md)
startTime := time.Now()
if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, nil", err)
}
if gotInfo == nil {
t.Fatalf("no config selector data")
}
if want := "/grpc.testing.TestService/EmptyCall"; gotInfo.Method != want {
t.Errorf("gotInfo.Method = %q; want %q", gotInfo.Method, want)
}
gotContextI, ok := gotContextChan.ReceiveOrFail()
if !ok {
t.Fatalf("no context received")
}
gotContext := gotContextI.(context.Context)
gotMD, _ := metadata.FromOutgoingContext(gotInfo.Context)
if diff := cmp.Diff(tc.md, gotMD); diff != "" {
t.Errorf("gotInfo.Context contains MD %v; want %v\nDiffs: %v", gotMD, tc.md, diff)
}
gotMD, _ = metadata.FromIncomingContext(gotContext)
// Remove entries from gotMD not in tc.wantMD (e.g. authority header).
for k := range gotMD {
if _, ok := tc.wantMD[k]; !ok {
delete(gotMD, k)
}
}
if diff := cmp.Diff(tc.wantMD, gotMD, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("received md = %v; want %v\nDiffs: %v", gotMD, tc.wantMD, diff)
}
wantDeadline := tc.wantDeadline
if wantDeadline == (time.Time{}) {
wantDeadline = startTime.Add(tc.wantTimeout)
}
deadlineGot, _ := gotContext.Deadline()
if diff := deadlineGot.Sub(wantDeadline); diff > time.Second || diff < -time.Second {
t.Errorf("received deadline = %v; want ~%v", deadlineGot, wantDeadline)
}
if tc.config != nil && tc.config.OnCommitted != nil && !onCommittedCalled {
t.Errorf("OnCommitted callback not called")
}
})
}
}