blob: 6804235486ba6a1f98b2619c960c26d5f5b3e7a9 [file] [log] [blame]
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// This file is for testing only. Runs a fake grpclb balancer server.
// The name of the service to load balance for and the addresses
// of that service are provided by command line flags.
package main
import (
"flag"
"net"
"strconv"
"strings"
"time"
"google.golang.org/grpc"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/alts"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
)
var (
port = flag.Int("port", 10000, "Port to listen on.")
backendAddrs = flag.String("backend_addrs", "", "Comma separated list of backend IP/port addresses.")
useALTS = flag.Bool("use_alts", false, "Listen on ALTS credentials.")
useTLS = flag.Bool("use_tls", false, "Listen on TLS credentials, using a test certificate.")
shortStream = flag.Bool("short_stream", false, "End the balancer stream immediately after sending the first server list.")
serviceName = flag.String("service_name", "UNSET", "Name of the service being load balanced for.")
logger = grpclog.Component("interop")
)
type loadBalancerServer struct {
lbpb.UnimplementedLoadBalancerServer
serverListResponse *lbpb.LoadBalanceResponse
}
func (l *loadBalancerServer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer) error {
logger.Info("Begin handling new BalancerLoad request.")
var lbReq *lbpb.LoadBalanceRequest
var err error
if lbReq, err = stream.Recv(); err != nil {
logger.Errorf("Error receiving LoadBalanceRequest: %v", err)
return err
}
logger.Info("LoadBalancerRequest received.")
initialReq := lbReq.GetInitialRequest()
if initialReq == nil {
logger.Info("Expected first request to be an InitialRequest. Got: %v", lbReq)
return status.Error(codes.Unknown, "First request not an InitialRequest")
}
// gRPC clients targeting foo.bar.com:443 can sometimes include the ":443" suffix in
// their requested names; handle this case. TODO: make 443 configurable?
var cleanedName string
var requestedNamePortNumber string
if cleanedName, requestedNamePortNumber, err = net.SplitHostPort(initialReq.Name); err != nil {
cleanedName = initialReq.Name
} else {
if requestedNamePortNumber != "443" {
logger.Info("Bad requested service name port number: %v.", requestedNamePortNumber)
return status.Error(codes.Unknown, "Bad requested service name port number")
}
}
if cleanedName != *serviceName {
logger.Info("Expected requested service name: %v. Got: %v", *serviceName, initialReq.Name)
return status.Error(codes.NotFound, "Bad requested service name")
}
if err := stream.Send(&lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
InitialResponse: &lbpb.InitialLoadBalanceResponse{},
},
}); err != nil {
logger.Errorf("Error sending initial LB response: %v", err)
return status.Error(codes.Unknown, "Error sending initial response")
}
logger.Info("Send LoadBalanceResponse: %v", l.serverListResponse)
if err := stream.Send(l.serverListResponse); err != nil {
logger.Errorf("Error sending LB response: %v", err)
return status.Error(codes.Unknown, "Error sending response")
}
if *shortStream {
return nil
}
for {
logger.Info("Send LoadBalanceResponse: %v", l.serverListResponse)
if err := stream.Send(l.serverListResponse); err != nil {
logger.Errorf("Error sending LB response: %v", err)
return status.Error(codes.Unknown, "Error sending response")
}
time.Sleep(10 * time.Second)
}
}
func main() {
flag.Parse()
var opts []grpc.ServerOption
if *useTLS {
certFile := testdata.Path("server1.pem")
keyFile := testdata.Path("server1.key")
creds, err := credentials.NewServerTLSFromFile(certFile, keyFile)
if err != nil {
logger.Fatalf("Failed to generate credentials %v", err)
}
opts = append(opts, grpc.Creds(creds))
} else if *useALTS {
altsOpts := alts.DefaultServerOptions()
altsTC := alts.NewServerCreds(altsOpts)
opts = append(opts, grpc.Creds(altsTC))
}
var serverList []*lbpb.Server
if len(*backendAddrs) == 0 {
serverList = make([]*lbpb.Server, 0)
} else {
rawBackendAddrs := strings.Split(*backendAddrs, ",")
serverList = make([]*lbpb.Server, len(rawBackendAddrs))
for i := range rawBackendAddrs {
rawIP, rawPort, err := net.SplitHostPort(rawBackendAddrs[i])
if err != nil {
logger.Fatalf("Failed to parse --backend_addrs[%d]=%v, error: %v", i, rawBackendAddrs[i], err)
}
ip := net.ParseIP(rawIP)
if ip == nil {
logger.Fatalf("Failed to parse ip: %v", rawIP)
}
numericPort, err := strconv.Atoi(rawPort)
if err != nil {
logger.Fatalf("Failed to convert port %v to int", rawPort)
}
logger.Infof("Adding backend ip: %v, port: %d", ip.String(), numericPort)
serverList[i] = &lbpb.Server{
IpAddress: ip,
Port: int32(numericPort),
}
}
}
serverListResponse := &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
ServerList: &lbpb.ServerList{
Servers: serverList,
},
},
}
server := grpc.NewServer(opts...)
logger.Infof("Begin listening on %d.", *port)
lis, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil {
logger.Fatalf("Failed to listen on port %v: %v", *port, err)
}
lbpb.RegisterLoadBalancerServer(server, &loadBalancerServer{
serverListResponse: serverListResponse,
})
server.Serve(lis)
}