blob: c6a5fe74bd55ead10f2e4885799de50dc32828e7 [file] [log] [blame]
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package test
import (
"context"
"io"
"testing"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
)
// ctxKey is a dedicated string-based type used as the key for context.WithValue
// in these tests. Because context keys compare by both type and value, a
// ctxKey never matches a plain string key with the same text.
type (
	ctxKey string
)
// TestChainUnaryServerInterceptor verifies that grpc.ChainUnaryInterceptor runs
// its interceptors in the order given (first, second, last) on the way in and
// unwinds them in reverse order on the way out.
//
// Inbound ordering is checked through context values: each interceptor asserts
// which keys its predecessors have planted. Outbound ordering is checked by
// having each interceptor append a marker byte ('1', '2', '3') to the response
// payload body. The handler returns an empty body and the innermost interceptor
// appends first, so the client must observe "321".
func (s) TestChainUnaryServerInterceptor(t *testing.T) {
	var (
		firstIntKey  = ctxKey("firstIntKey")
		secondIntKey = ctxKey("secondIntKey")
	)

	// appendMarker invokes handler with ctx and req, and returns a new
	// *testpb.SimpleResponse whose payload body is the handler's body with
	// marker appended. name identifies the calling interceptor in errors.
	appendMarker := func(ctx context.Context, req interface{}, handler grpc.UnaryHandler, name string, marker byte) (interface{}, error) {
		resp, err := handler(ctx, req)
		if err != nil {
			return nil, status.Errorf(codes.Internal, "failed to handle request at %s: %v", name, err)
		}
		simpleResp, ok := resp.(*testpb.SimpleResponse)
		if !ok {
			return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at %s", name)
		}
		return &testpb.SimpleResponse{
			Payload: &testpb.Payload{
				Type: simpleResp.GetPayload().GetType(),
				Body: append(simpleResp.GetPayload().GetBody(), marker),
			},
		}, nil
	}

	// firstInt runs first, so neither key may be present yet.
	firstInt := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if ctx.Value(firstIntKey) != nil {
			return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", firstIntKey)
		}
		if ctx.Value(secondIntKey) != nil {
			return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", secondIntKey)
		}
		return appendMarker(context.WithValue(ctx, firstIntKey, 0), req, handler, "firstInt", '1')
	}
	// secondInt must see only the key planted by firstInt.
	secondInt := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if ctx.Value(firstIntKey) == nil {
			return nil, status.Errorf(codes.Internal, "second interceptor should have %v in context", firstIntKey)
		}
		if ctx.Value(secondIntKey) != nil {
			return nil, status.Errorf(codes.Internal, "second interceptor should not have %v in context", secondIntKey)
		}
		return appendMarker(context.WithValue(ctx, secondIntKey, 1), req, handler, "secondInt", '2')
	}
	// lastInt must see the keys planted by both earlier interceptors.
	lastInt := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if ctx.Value(firstIntKey) == nil {
			return nil, status.Errorf(codes.Internal, "last interceptor should have %v in context", firstIntKey)
		}
		if ctx.Value(secondIntKey) == nil {
			return nil, status.Errorf(codes.Internal, "last interceptor should have %v in context", secondIntKey)
		}
		return appendMarker(ctx, req, handler, "lastInt", '3')
	}

	sopts := []grpc.ServerOption{
		grpc.ChainUnaryInterceptor(firstInt, secondInt, lastInt),
	}
	ss := &stubServer{
		unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
			payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 0)
			if err != nil {
				return nil, status.Errorf(codes.Aborted, "failed to make payload: %v", err)
			}
			return &testpb.SimpleResponse{
				Payload: payload,
			}, nil
		},
	}
	if err := ss.Start(sopts); err != nil {
		t.Fatalf("Error starting endpoint server: %v", err)
	}
	defer ss.Stop()

	resp, err := ss.client.UnaryCall(context.Background(), &testpb.SimpleRequest{})
	if st, ok := status.FromError(err); !ok || st.Code() != codes.OK {
		t.Fatalf("ss.client.UnaryCall(context.Background(), _) = %v, %v; want nil, <status with Code()=OK>", resp, err)
	}
	if got, want := string(resp.GetPayload().GetBody()), "321"; got != want {
		t.Fatalf("invalid response body: want=%q, but got=%q", want, got)
	}
}
// TestChainOnBaseUnaryServerInterceptor checks how an interceptor installed
// with grpc.UnaryInterceptor combines with one added via
// grpc.ChainUnaryInterceptor. The base interceptor must run first: it requires
// a clean context and plants a value that the chained interceptor then insists
// on seeing.
func (s) TestChainOnBaseUnaryServerInterceptor(t *testing.T) {
	const key = ctxKey("baseIntKey")

	// base rejects a context that already carries key, then tags the context
	// before delegating.
	base := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, next grpc.UnaryHandler) (interface{}, error) {
		if v := ctx.Value(key); v != nil {
			return nil, status.Errorf(codes.Internal, "base interceptor should not have %v in context", key)
		}
		return next(context.WithValue(ctx, key, 1), req)
	}
	// chained only lets the call through if base has already run.
	chained := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, next grpc.UnaryHandler) (interface{}, error) {
		if ctx.Value(key) == nil {
			return nil, status.Errorf(codes.Internal, "chain interceptor should have %v in context", key)
		}
		return next(ctx, req)
	}

	srv := &stubServer{
		emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
			return &testpb.Empty{}, nil
		},
	}
	opts := []grpc.ServerOption{grpc.UnaryInterceptor(base), grpc.ChainUnaryInterceptor(chained)}
	if err := srv.Start(opts); err != nil {
		t.Fatalf("Error starting endpoint server: %v", err)
	}
	defer srv.Stop()

	reply, err := srv.client.EmptyCall(context.Background(), &testpb.Empty{})
	st, ok := status.FromError(err)
	if !ok || st.Code() != codes.OK {
		t.Fatalf("ss.client.EmptyCall(context.Background(), _) = %v, %v; want nil, <status with Code()=OK>", reply, err)
	}
}
// TestChainStreamServerInterceptor verifies that grpc.ChainStreamInterceptor
// invokes its stream interceptors in the order given before reaching the
// handler.
//
// callCounts[0], [1] and [2] are incremented by the first, second and last
// interceptor respectively, and callCounts[3] by the handler. Each stage first
// checks that exactly the preceding stages have run once, then records its own
// call.
func (s) TestChainStreamServerInterceptor(t *testing.T) {
	callCounts := make([]int, 4)

	// checkCounts returns a status error naming the first index i whose count
	// differs from want[i], or nil if every listed count matches. Building the
	// message from want keeps it consistent with the condition being checked.
	checkCounts := func(want ...int) error {
		for i, w := range want {
			if callCounts[i] != w {
				return status.Errorf(codes.Internal, "callCounts[%d] should be %d, but got=%d", i, w, callCounts[i])
			}
		}
		return nil
	}

	firstInt := func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if err := checkCounts(0, 0, 0, 0); err != nil {
			return err
		}
		callCounts[0]++
		return handler(srv, stream)
	}
	secondInt := func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if err := checkCounts(1, 0, 0, 0); err != nil {
			return err
		}
		callCounts[1]++
		return handler(srv, stream)
	}
	lastInt := func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if err := checkCounts(1, 1, 0, 0); err != nil {
			return err
		}
		callCounts[2]++
		return handler(srv, stream)
	}

	sopts := []grpc.ServerOption{
		grpc.ChainStreamInterceptor(firstInt, secondInt, lastInt),
	}
	ss := &stubServer{
		fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
			if err := checkCounts(1, 1, 1, 0); err != nil {
				return err
			}
			callCounts[3]++
			return nil
		},
	}
	if err := ss.Start(sopts); err != nil {
		t.Fatalf("Error starting endpoint server: %v", err)
	}
	defer ss.Stop()

	stream, err := ss.client.FullDuplexCall(context.Background())
	if err != nil {
		t.Fatalf("failed to FullDuplexCall: %v", err)
	}
	// The handler returns without sending, so the client expects io.EOF here.
	// Any interceptor or handler failure surfaces as a non-EOF status error.
	if _, err := stream.Recv(); err != io.EOF {
		t.Fatalf("failed to recv from stream: %v", err)
	}
	if callCounts[3] != 1 {
		t.Fatalf("callCounts[3] should be 1, but got=%d", callCounts[3])
	}
}