| package ttrpc |
| |
| import ( |
| "context" |
| "net" |
| "sync" |
| |
| "github.com/containerd/containerd/log" |
| "github.com/gogo/protobuf/proto" |
| "github.com/pkg/errors" |
| "google.golang.org/grpc/status" |
| ) |
| |
| type Client struct { |
| codec codec |
| conn net.Conn |
| channel *channel |
| calls chan *callRequest |
| |
| closed chan struct{} |
| closeOnce sync.Once |
| done chan struct{} |
| err error |
| } |
| |
| func NewClient(conn net.Conn) *Client { |
| c := &Client{ |
| codec: codec{}, |
| conn: conn, |
| channel: newChannel(conn, conn), |
| calls: make(chan *callRequest), |
| closed: make(chan struct{}), |
| done: make(chan struct{}), |
| } |
| |
| go c.run() |
| return c |
| } |
| |
| type callRequest struct { |
| ctx context.Context |
| req *Request |
| resp *Response // response will be written back here |
| errs chan error // error written here on completion |
| } |
| |
| func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error { |
| payload, err := c.codec.Marshal(req) |
| if err != nil { |
| return err |
| } |
| |
| var ( |
| creq = &Request{ |
| Service: service, |
| Method: method, |
| Payload: payload, |
| } |
| |
| cresp = &Response{} |
| ) |
| |
| if err := c.dispatch(ctx, creq, cresp); err != nil { |
| return err |
| } |
| |
| if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil { |
| return err |
| } |
| |
| if cresp.Status == nil { |
| return errors.New("no status provided on response") |
| } |
| |
| return status.ErrorProto(cresp.Status) |
| } |
| |
| func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error { |
| errs := make(chan error, 1) |
| call := &callRequest{ |
| req: req, |
| resp: resp, |
| errs: errs, |
| } |
| |
| select { |
| case c.calls <- call: |
| case <-c.done: |
| return c.err |
| } |
| |
| select { |
| case err := <-errs: |
| return err |
| case <-c.done: |
| return c.err |
| } |
| } |
| |
| func (c *Client) Close() error { |
| c.closeOnce.Do(func() { |
| close(c.closed) |
| }) |
| |
| return nil |
| } |
| |
| type message struct { |
| messageHeader |
| p []byte |
| err error |
| } |
| |
| func (c *Client) run() { |
| var ( |
| streamID uint32 = 1 |
| waiters = make(map[uint32]*callRequest) |
| calls = c.calls |
| incoming = make(chan *message) |
| shutdown = make(chan struct{}) |
| shutdownErr error |
| ) |
| |
| go func() { |
| defer close(shutdown) |
| |
| // start one more goroutine to recv messages without blocking. |
| for { |
| mh, p, err := c.channel.recv(context.TODO()) |
| if err != nil { |
| _, ok := status.FromError(err) |
| if !ok { |
| // treat all errors that are not an rpc status as terminal. |
| // all others poison the connection. |
| shutdownErr = err |
| return |
| } |
| } |
| select { |
| case incoming <- &message{ |
| messageHeader: mh, |
| p: p[:mh.Length], |
| err: err, |
| }: |
| case <-c.done: |
| return |
| } |
| } |
| }() |
| |
| defer c.conn.Close() |
| defer close(c.done) |
| |
| for { |
| select { |
| case call := <-calls: |
| if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil { |
| call.errs <- err |
| continue |
| } |
| |
| waiters[streamID] = call |
| streamID += 2 // enforce odd client initiated request ids |
| case msg := <-incoming: |
| call, ok := waiters[msg.StreamID] |
| if !ok { |
| log.L.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID) |
| continue |
| } |
| |
| call.errs <- c.recv(call.resp, msg) |
| delete(waiters, msg.StreamID) |
| case <-shutdown: |
| shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down") |
| c.err = shutdownErr |
| for _, waiter := range waiters { |
| waiter.errs <- shutdownErr |
| } |
| c.Close() |
| return |
| case <-c.closed: |
| // broadcast the shutdown error to the remaining waiters. |
| for _, waiter := range waiters { |
| waiter.errs <- shutdownErr |
| } |
| return |
| } |
| } |
| } |
| |
| func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error { |
| p, err := c.codec.Marshal(msg) |
| if err != nil { |
| return err |
| } |
| |
| return c.channel.send(ctx, streamID, mtype, p) |
| } |
| |
| func (c *Client) recv(resp *Response, msg *message) error { |
| if msg.err != nil { |
| return msg.err |
| } |
| |
| if msg.Type != messageTypeResponse { |
| return errors.New("unkown message type received") |
| } |
| |
| defer c.channel.putmbuf(msg.p) |
| return proto.Unmarshal(msg.p, resp) |
| } |