| /* |
| * |
| * Copyright 2021 gRPC authors. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * |
| */ |
| |
| package rls |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "crypto/x509" |
| "errors" |
| "fmt" |
| "os" |
| "regexp" |
| "testing" |
| "time" |
| |
| "github.com/google/go-cmp/cmp" |
| "google.golang.org/grpc" |
| "google.golang.org/grpc/balancer" |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/credentials" |
| "google.golang.org/grpc/internal" |
| rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1" |
| rlstest "google.golang.org/grpc/internal/testutils/rls" |
| "google.golang.org/grpc/metadata" |
| "google.golang.org/grpc/status" |
| "google.golang.org/grpc/testdata" |
| "google.golang.org/protobuf/proto" |
| ) |
| |
// TestControlChannelThrottled tests the case where the adaptive throttler
// indicates that the control channel needs to be throttled. In that case the
// RouteLookup RPC must never reach the RLS server.
func (s) TestControlChannelThrottled(t *testing.T) {
	// Start an RLS server and set the throttler to always throttle requests.
	rlsServer, rlsReqCh := rlstest.SetupFakeRLSServer(t, nil)
	overrideAdaptiveThrottler(t, alwaysThrottlingThrottler())

	// Create a control channel to the fake RLS server.
	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{}, nil)
	if err != nil {
		t.Fatalf("Failed to create control channel to RLS server: %v", err)
	}
	defer ctrlCh.close()

	// Perform the lookup and expect the attempt to be throttled.
	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, nil)

	// Wait a short while to make sure the fake server does not receive a
	// request. A request arriving on rlsReqCh means throttling did not happen.
	select {
	case <-rlsReqCh:
		t.Fatal("RouteLookup RPC invoked when control channel is throttled")
	case <-time.After(defaultTestShortTimeout):
	}
}
| |
| // TestLookupFailure tests the case where the RLS server responds with an error. |
| func (s) TestLookupFailure(t *testing.T) { |
| // Start an RLS server and set the throttler to never throttle requests. |
| rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil) |
| overrideAdaptiveThrottler(t, neverThrottlingThrottler()) |
| |
| // Setup the RLS server to respond with errors. |
| rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse { |
| return &rlstest.RouteLookupResponse{Err: errors.New("rls failure")} |
| }) |
| |
| // Create a control channel to the fake RLS server. |
| ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{}, nil) |
| if err != nil { |
| t.Fatalf("Failed to create control channel to RLS server: %v", err) |
| } |
| defer ctrlCh.close() |
| |
| // Perform the lookup and expect the callback to be invoked with an error. |
| errCh := make(chan error, 1) |
| ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) { |
| if err == nil { |
| errCh <- errors.New("rlsClient.lookup() succeeded, should have failed") |
| return |
| } |
| errCh <- nil |
| }) |
| |
| select { |
| case <-time.After(defaultTestTimeout): |
| t.Fatal("timeout when waiting for lookup callback to be invoked") |
| case err := <-errCh: |
| if err != nil { |
| t.Fatal(err) |
| } |
| } |
| } |
| |
// TestLookupDeadlineExceeded tests the case where the RLS server does not
// respond within the configured rpc timeout. A server-side interceptor fails
// every RPC with codes.DeadlineExceeded, and the lookup callback is expected
// to observe exactly that code.
func (s) TestLookupDeadlineExceeded(t *testing.T) {
	// A unary interceptor which returns a status error with DeadlineExceeded
	// without ever invoking the handler.
	interceptor := func(context.Context, interface{}, *grpc.UnaryServerInfo, grpc.UnaryHandler) (interface{}, error) {
		return nil, status.Error(codes.DeadlineExceeded, "deadline exceeded")
	}

	// Start an RLS server and set the throttler to never throttle.
	rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil, grpc.UnaryInterceptor(interceptor))
	overrideAdaptiveThrottler(t, neverThrottlingThrottler())

	// Create a control channel with a small deadline.
	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestShortTimeout, balancer.BuildOptions{}, nil)
	if err != nil {
		t.Fatalf("Failed to create control channel to RLS server: %v", err)
	}
	defer ctrlCh.close()

	// Perform the lookup and expect the callback to be invoked with an error.
	// The channel is buffered so the callback's send can always complete. With
	// an unbuffered channel, a callback that fires after the timeout below
	// would block forever and leak its goroutine.
	errCh := make(chan error, 1)
	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
		if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
			errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded)
			return
		}
		errCh <- nil
	})

	select {
	case <-time.After(defaultTestTimeout):
		t.Fatal("timeout when waiting for lookup callback to be invoked")
	case err := <-errCh:
		if err != nil {
			t.Fatal(err)
		}
	}
}
| |
| // testCredsBundle wraps a test call creds and real transport creds. |
| type testCredsBundle struct { |
| transportCreds credentials.TransportCredentials |
| callCreds credentials.PerRPCCredentials |
| } |
| |
| func (f *testCredsBundle) TransportCredentials() credentials.TransportCredentials { |
| return f.transportCreds |
| } |
| |
| func (f *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials { |
| return f.callCreds |
| } |
| |
| func (f *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) { |
| if mode != internal.CredsBundleModeFallback { |
| return nil, fmt.Errorf("unsupported mode: %v", mode) |
| } |
| return &testCredsBundle{ |
| transportCreds: f.transportCreds, |
| callCreds: f.callCreds, |
| }, nil |
| } |
| |
// perRPCCredsData is the set of call credentials that the client sends through
// testPerRPCCredentials and that callCredsValidatingServerInterceptor checks
// for on the server. The "test-key-bin" entry holds non-printable bytes.
var perRPCCredsData = map[string]string{
	"test-key":     "test-value",
	"test-key-bin": string([]byte{1, 2, 3}),
}
| |
| type testPerRPCCredentials struct { |
| callCreds map[string]string |
| } |
| |
| func (f *testPerRPCCredentials) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { |
| return f.callCreds, nil |
| } |
| |
| func (f *testPerRPCCredentials) RequireTransportSecurity() bool { |
| return true |
| } |
| |
// callCredsValidatingServerInterceptor is a unary server interceptor that
// fails the RPC with codes.PermissionDenied unless the incoming metadata
// contains every key in perRPCCredsData, and the first value for each key
// equals the expected one. On success it invokes handler.
func callCredsValidatingServerInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return nil, status.Error(codes.PermissionDenied, "didn't find metadata in context")
	}
	for k, want := range perRPCCredsData {
		got, ok := md[k]
		// A key that is present with no values counts as missing. Otherwise
		// got[0] below would panic.
		if !ok || len(got) == 0 {
			return nil, status.Errorf(codes.PermissionDenied, "didn't find call creds key %v in context", k)
		}
		// %q keeps the binary "-bin" value readable in failure messages.
		if got[0] != want {
			return nil, status.Errorf(codes.PermissionDenied, "for key %v, got value %q, want %q", k, got[0], want)
		}
	}
	return handler(ctx, req)
}
| |
// makeTLSCreds is a test helper which creates TLS-based transport credentials
// from the files named in the arguments. The paths are resolved with
// testdata.Path. Any failure to load the files aborts the test.
//
// certPath and keyPath name the local certificate and private key. rootsPath
// names a PEM bundle that is installed as tls.Config.RootCAs.
//
// NOTE(review): RootCAs is what a TLS client uses to verify servers. When the
// result is used as server credentials, client certificates are not verified,
// because ClientCAs and ClientAuth are not set.
func makeTLSCreds(t *testing.T, certPath, keyPath, rootsPath string) credentials.TransportCredentials {
	t.Helper()

	cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
	if err != nil {
		t.Fatalf("tls.LoadX509KeyPair(%q, %q) failed: %v", certPath, keyPath, err)
	}
	b, err := os.ReadFile(testdata.Path(rootsPath))
	if err != nil {
		t.Fatalf("os.ReadFile(%q) failed: %v", rootsPath, err)
	}
	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(b) {
		t.Fatalf("failed to append certificates from %q", rootsPath)
	}
	return credentials.NewTLS(&tls.Config{
		Certificates: []tls.Certificate{cert},
		RootCAs:      roots,
	})
}
| |
// Header data values used by the lookup tests. staleHeaderData is what the
// client sends in the RouteLookupRequest. wantHeaderData is what the fake RLS
// server returns in its RouteLookupResponse.
const (
	staleHeaderData = "staleHeaderData"
	wantHeaderData  = "headerData"
)
| |
| var ( |
| keyMap = map[string]string{ |
| "k1": "v1", |
| "k2": "v2", |
| } |
| wantTargets = []string{"us_east_1.firestore.googleapis.com"} |
| lookupRequest = &rlspb.RouteLookupRequest{ |
| TargetType: "grpc", |
| KeyMap: keyMap, |
| Reason: rlspb.RouteLookupRequest_REASON_MISS, |
| StaleHeaderData: staleHeaderData, |
| } |
| lookupResponse = &rlstest.RouteLookupResponse{ |
| Resp: &rlspb.RouteLookupResponse{ |
| Targets: wantTargets, |
| HeaderData: wantHeaderData, |
| }, |
| } |
| ) |
| |
// testControlChannelCredsSuccess starts a fake RLS server with sopts, dials it
// through a control channel built with bopts, and checks that a lookup works
// end to end. The server must see lookupRequest, and the client callback must
// receive wantTargets and wantHeaderData with no error.
func testControlChannelCredsSuccess(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions) {
	// Fake RLS server, with throttling disabled so that the request goes out.
	rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil, sopts...)
	overrideAdaptiveThrottler(t, neverThrottlingThrottler())

	// Always answer with the canned successful response.
	rlsServer.SetResponseCallback(func(context.Context, *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
		return lookupResponse
	})

	// On the server side, compare each incoming request with the fixture.
	rlsServer.SetRequestCallback(func(got *rlspb.RouteLookupRequest) {
		diff := cmp.Diff(lookupRequest, got, cmp.Comparer(proto.Equal))
		if diff != "" {
			t.Errorf("RouteLookupRequest diff (-want, +got):\n%s", diff)
		}
	})

	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, bopts, nil)
	if err != nil {
		t.Fatalf("Failed to create control channel to RLS server: %v", err)
	}
	defer ctrlCh.close()

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()

	// The callback reports nil on the expected result. The buffered channel
	// means the send never blocks.
	resultCh := make(chan error, 1)
	ctrlCh.lookup(keyMap, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(targets []string, headerData string, lookupErr error) {
		switch {
		case lookupErr != nil:
			resultCh <- fmt.Errorf("rlsClient.lookup() failed with err: %v", lookupErr)
		case !cmp.Equal(targets, wantTargets) || headerData != wantHeaderData:
			resultCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, headerData, wantTargets, wantHeaderData)
		default:
			resultCh <- nil
		}
	})

	select {
	case err := <-resultCh:
		if err != nil {
			t.Fatal(err)
		}
	case <-ctx.Done():
		t.Fatal("timeout when waiting for lookup callback to be invoked")
	}
}
| |
| // TestControlChannelCredsSuccess tests creation of the control channel with |
| // different credentials, which are expected to succeed. |
| func (s) TestControlChannelCredsSuccess(t *testing.T) { |
| serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") |
| clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem") |
| |
| tests := []struct { |
| name string |
| sopts []grpc.ServerOption |
| bopts balancer.BuildOptions |
| }{ |
| { |
| name: "insecure", |
| sopts: nil, |
| bopts: balancer.BuildOptions{}, |
| }, |
| { |
| name: "transport creds only", |
| sopts: []grpc.ServerOption{grpc.Creds(serverCreds)}, |
| bopts: balancer.BuildOptions{ |
| DialCreds: clientCreds, |
| Authority: "x.test.example.com", |
| }, |
| }, |
| { |
| name: "creds bundle", |
| sopts: []grpc.ServerOption{ |
| grpc.Creds(serverCreds), |
| grpc.UnaryInterceptor(callCredsValidatingServerInterceptor), |
| }, |
| bopts: balancer.BuildOptions{ |
| CredsBundle: &testCredsBundle{ |
| transportCreds: clientCreds, |
| callCreds: &testPerRPCCredentials{callCreds: perRPCCredsData}, |
| }, |
| Authority: "x.test.example.com", |
| }, |
| }, |
| } |
| for _, test := range tests { |
| t.Run(test.name, func(t *testing.T) { |
| testControlChannelCredsSuccess(t, test.sopts, test.bopts) |
| }) |
| } |
| } |
| |
// testControlChannelCredsFailure starts a fake RLS server with sopts, creates
// a control channel to it with bopts, and checks that a lookup fails. The
// failure must be a status error whose code is wantCode and whose string form
// matches wantErrRegex.
func testControlChannelCredsFailure(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions, wantCode codes.Code, wantErrRegex *regexp.Regexp) {
	// Start an RLS server and set the throttler to never throttle requests. The
	// creds failures happen before the RPC handler on the server is invoked,
	// so there is no need to set up requests and responses on the fake server.
	rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil, sopts...)
	overrideAdaptiveThrottler(t, neverThrottlingThrottler())

	// Create the control channel to the fake server.
	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, bopts, nil)
	if err != nil {
		t.Fatalf("Failed to create control channel to RLS server: %v", err)
	}
	defer ctrlCh.close()

	// Perform the lookup and expect the callback to be invoked with an error.
	// The channel is buffered so the callback's send can always complete. With
	// an unbuffered channel, a callback that fires after the timeout below
	// would block forever and leak its goroutine.
	errCh := make(chan error, 1)
	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
		if st, ok := status.FromError(err); !ok || st.Code() != wantCode || !wantErrRegex.MatchString(st.String()) {
			errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, wantCode: %v, wantErr: %s", err, wantCode, wantErrRegex.String())
			return
		}
		errCh <- nil
	})

	select {
	case <-time.After(defaultTestTimeout):
		t.Fatal("timeout when waiting for lookup callback to be invoked")
	case err := <-errCh:
		if err != nil {
			t.Fatal(err)
		}
	}
}
| |
| // TestControlChannelCredsFailure tests creation of the control channel with |
| // different credentials, which are expected to fail. |
| func (s) TestControlChannelCredsFailure(t *testing.T) { |
| serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") |
| clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem") |
| |
| tests := []struct { |
| name string |
| sopts []grpc.ServerOption |
| bopts balancer.BuildOptions |
| wantCode codes.Code |
| wantErrRegex *regexp.Regexp |
| }{ |
| { |
| name: "transport creds authority mismatch", |
| sopts: []grpc.ServerOption{grpc.Creds(serverCreds)}, |
| bopts: balancer.BuildOptions{ |
| DialCreds: clientCreds, |
| Authority: "authority-mismatch", |
| }, |
| wantCode: codes.Unavailable, |
| wantErrRegex: regexp.MustCompile(`transport: authentication handshake failed: .* \*\.test\.example\.com.*authority-mismatch`), |
| }, |
| { |
| name: "transport creds handshake failure", |
| sopts: nil, // server expects insecure connection |
| bopts: balancer.BuildOptions{ |
| DialCreds: clientCreds, |
| Authority: "x.test.example.com", |
| }, |
| wantCode: codes.Unavailable, |
| wantErrRegex: regexp.MustCompile("transport: authentication handshake failed: .*"), |
| }, |
| { |
| name: "call creds mismatch", |
| sopts: []grpc.ServerOption{ |
| grpc.Creds(serverCreds), |
| grpc.UnaryInterceptor(callCredsValidatingServerInterceptor), // server expects call creds |
| }, |
| bopts: balancer.BuildOptions{ |
| CredsBundle: &testCredsBundle{ |
| transportCreds: clientCreds, |
| callCreds: &testPerRPCCredentials{}, // sends no call creds |
| }, |
| Authority: "x.test.example.com", |
| }, |
| wantCode: codes.PermissionDenied, |
| wantErrRegex: regexp.MustCompile("didn't find call creds"), |
| }, |
| } |
| for _, test := range tests { |
| t.Run(test.name, func(t *testing.T) { |
| testControlChannelCredsFailure(t, test.sopts, test.bopts, test.wantCode, test.wantErrRegex) |
| }) |
| } |
| } |
| |
| type unsupportedCredsBundle struct { |
| credentials.Bundle |
| } |
| |
| func (*unsupportedCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) { |
| return nil, fmt.Errorf("unsupported mode: %v", mode) |
| } |
| |
// TestNewControlChannelUnsupportedCredsBundle checks that newControlChannel
// returns an error when the configured credentials bundle cannot produce a
// bundle for the mode the control channel asks for.
func (s) TestNewControlChannelUnsupportedCredsBundle(t *testing.T) {
	rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil)

	bopts := balancer.BuildOptions{CredsBundle: &unsupportedCredsBundle{}}
	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, bopts, nil)
	if err != nil {
		// Expected outcome: the bundle rejected the requested mode.
		return
	}
	// Creation unexpectedly succeeded. Release the channel before failing.
	ctrlCh.close()
	t.Fatal("newControlChannel succeeded when expected to fail")
}