| package ttrpc |
| |
| import ( |
| "context" |
| "math/rand" |
| "net" |
| "sync" |
| "sync/atomic" |
| "time" |
| |
| "github.com/containerd/containerd/log" |
| "github.com/pkg/errors" |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/status" |
| ) |
| |
var (
	// ErrServerClosed is returned by Serve when accepting fails after the
	// server has been stopped through Shutdown or Close.
	ErrServerClosed = errors.New("ttrpc: server closed")
)
| |
| type Server struct { |
| services *serviceSet |
| codec codec |
| |
| mu sync.Mutex |
| listeners map[net.Listener]struct{} |
| connections map[*serverConn]struct{} // all connections to current state |
| done chan struct{} // marks point at which we stop serving requests |
| } |
| |
| func NewServer() *Server { |
| return &Server{ |
| services: newServiceSet(), |
| done: make(chan struct{}), |
| listeners: make(map[net.Listener]struct{}), |
| connections: make(map[*serverConn]struct{}), |
| } |
| } |
| |
| func (s *Server) Register(name string, methods map[string]Method) { |
| s.services.register(name, methods) |
| } |
| |
| func (s *Server) Serve(l net.Listener) error { |
| s.addListener(l) |
| defer s.closeListener(l) |
| |
| var ( |
| ctx = context.Background() |
| backoff time.Duration |
| ) |
| |
| for { |
| conn, err := l.Accept() |
| if err != nil { |
| select { |
| case <-s.done: |
| return ErrServerClosed |
| default: |
| } |
| |
| if terr, ok := err.(interface { |
| Temporary() bool |
| }); ok && terr.Temporary() { |
| if backoff == 0 { |
| backoff = time.Millisecond |
| } else { |
| backoff *= 2 |
| } |
| |
| if max := time.Second; backoff > max { |
| backoff = max |
| } |
| |
| sleep := time.Duration(rand.Int63n(int64(backoff))) |
| log.L.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep) |
| time.Sleep(sleep) |
| continue |
| } |
| |
| return err |
| } |
| |
| backoff = 0 |
| sc := s.newConn(conn) |
| go sc.run(ctx) |
| } |
| } |
| |
| func (s *Server) Shutdown(ctx context.Context) error { |
| s.mu.Lock() |
| lnerr := s.closeListeners() |
| select { |
| case <-s.done: |
| default: |
| // protected by mutex |
| close(s.done) |
| } |
| s.mu.Unlock() |
| |
| ticker := time.NewTicker(200 * time.Millisecond) |
| defer ticker.Stop() |
| for { |
| if s.closeIdleConns() { |
| return lnerr |
| } |
| select { |
| case <-ctx.Done(): |
| return ctx.Err() |
| case <-ticker.C: |
| } |
| } |
| } |
| |
| // Close the server without waiting for active connections. |
| func (s *Server) Close() error { |
| s.mu.Lock() |
| defer s.mu.Unlock() |
| |
| select { |
| case <-s.done: |
| default: |
| // protected by mutex |
| close(s.done) |
| } |
| |
| err := s.closeListeners() |
| for c := range s.connections { |
| c.close() |
| delete(s.connections, c) |
| } |
| |
| return err |
| } |
| |
| func (s *Server) addListener(l net.Listener) { |
| s.mu.Lock() |
| defer s.mu.Unlock() |
| s.listeners[l] = struct{}{} |
| } |
| |
| func (s *Server) closeListener(l net.Listener) error { |
| s.mu.Lock() |
| defer s.mu.Unlock() |
| |
| return s.closeListenerLocked(l) |
| } |
| |
| func (s *Server) closeListenerLocked(l net.Listener) error { |
| defer delete(s.listeners, l) |
| return l.Close() |
| } |
| |
| func (s *Server) closeListeners() error { |
| var err error |
| for l := range s.listeners { |
| if cerr := s.closeListenerLocked(l); cerr != nil && err == nil { |
| err = cerr |
| } |
| } |
| return err |
| } |
| |
| func (s *Server) addConnection(c *serverConn) { |
| s.mu.Lock() |
| defer s.mu.Unlock() |
| |
| s.connections[c] = struct{}{} |
| } |
| |
| func (s *Server) closeIdleConns() bool { |
| s.mu.Lock() |
| defer s.mu.Unlock() |
| quiescent := true |
| for c := range s.connections { |
| st, ok := c.getState() |
| if !ok || st != connStateIdle { |
| quiescent = false |
| continue |
| } |
| c.close() |
| delete(s.connections, c) |
| } |
| return quiescent |
| } |
| |
// connState describes the request activity of a server connection.
type connState int

// The constants are declared with the connState type so that they carry the
// type's methods and cannot be mixed up with plain integers.
const (
	connStateActive connState = iota + 1 // outstanding requests
	connStateIdle                        // no requests
	connStateClosed                      // closed connection
)
| |
| func (cs connState) String() string { |
| switch cs { |
| case connStateActive: |
| return "active" |
| case connStateIdle: |
| return "idle" |
| case connStateClosed: |
| return "closed" |
| default: |
| return "unknown" |
| } |
| } |
| |
| func (s *Server) newConn(conn net.Conn) *serverConn { |
| c := &serverConn{ |
| server: s, |
| conn: conn, |
| shutdown: make(chan struct{}), |
| } |
| c.setState(connStateIdle) |
| s.addConnection(c) |
| return c |
| } |
| |
| type serverConn struct { |
| server *Server |
| conn net.Conn |
| state atomic.Value |
| |
| shutdownOnce sync.Once |
| shutdown chan struct{} // forced shutdown, used by close |
| } |
| |
| func (c *serverConn) getState() (connState, bool) { |
| cs, ok := c.state.Load().(connState) |
| return cs, ok |
| } |
| |
| func (c *serverConn) setState(newstate connState) { |
| c.state.Store(newstate) |
| } |
| |
| func (c *serverConn) close() error { |
| c.shutdownOnce.Do(func() { |
| close(c.shutdown) |
| }) |
| |
| return nil |
| } |
| |
| func (c *serverConn) run(sctx context.Context) { |
| type ( |
| request struct { |
| id uint32 |
| req *Request |
| } |
| |
| response struct { |
| id uint32 |
| resp *Response |
| } |
| ) |
| |
| var ( |
| ch = newChannel(c.conn, c.conn) |
| ctx, cancel = context.WithCancel(sctx) |
| active int |
| state connState = connStateIdle |
| responses = make(chan response) |
| requests = make(chan request) |
| recvErr = make(chan error, 1) |
| shutdown = c.shutdown |
| done = make(chan struct{}) |
| ) |
| |
| defer c.conn.Close() |
| defer cancel() |
| defer close(done) |
| |
| go func(recvErr chan error) { |
| defer close(recvErr) |
| sendImmediate := func(id uint32, st *status.Status) bool { |
| select { |
| case responses <- response{ |
| // even though we've had an invalid stream id, we send it |
| // back on the same stream id so the client knows which |
| // stream id was bad. |
| id: id, |
| resp: &Response{ |
| Status: st.Proto(), |
| }, |
| }: |
| return true |
| case <-c.shutdown: |
| return false |
| case <-done: |
| return false |
| } |
| } |
| |
| for { |
| select { |
| case <-c.shutdown: |
| return |
| case <-done: |
| return |
| default: // proceed |
| } |
| |
| mh, p, err := ch.recv(ctx) |
| if err != nil { |
| status, ok := status.FromError(err) |
| if !ok { |
| recvErr <- err |
| return |
| } |
| |
| // in this case, we send an error for that particular message |
| // when the status is defined. |
| if !sendImmediate(mh.StreamID, status) { |
| return |
| } |
| |
| continue |
| } |
| |
| if mh.Type != messageTypeRequest { |
| // we must ignore this for future compat. |
| continue |
| } |
| |
| var req Request |
| if err := c.server.codec.Unmarshal(p, &req); err != nil { |
| ch.putmbuf(p) |
| if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) { |
| return |
| } |
| continue |
| } |
| ch.putmbuf(p) |
| |
| if mh.StreamID%2 != 1 { |
| // enforce odd client initiated identifiers. |
| if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) { |
| return |
| } |
| continue |
| } |
| |
| // Forward the request to the main loop. We don't wait on s.done |
| // because we have already accepted the client request. |
| select { |
| case requests <- request{ |
| id: mh.StreamID, |
| req: &req, |
| }: |
| case <-done: |
| return |
| } |
| } |
| }(recvErr) |
| |
| for { |
| newstate := state |
| switch { |
| case active > 0: |
| newstate = connStateActive |
| shutdown = nil |
| case active == 0: |
| newstate = connStateIdle |
| shutdown = c.shutdown // only enable this branch in idle mode |
| } |
| |
| if newstate != state { |
| c.setState(newstate) |
| state = newstate |
| } |
| |
| select { |
| case request := <-requests: |
| active++ |
| go func(id uint32) { |
| p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) |
| resp := &Response{ |
| Status: status.Proto(), |
| Payload: p, |
| } |
| |
| select { |
| case responses <- response{ |
| id: id, |
| resp: resp, |
| }: |
| case <-done: |
| } |
| }(request.id) |
| case response := <-responses: |
| p, err := c.server.codec.Marshal(response.resp) |
| if err != nil { |
| log.L.WithError(err).Error("failed marshaling response") |
| return |
| } |
| |
| if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil { |
| log.L.WithError(err).Error("failed sending message on channel") |
| return |
| } |
| |
| active-- |
| case err := <-recvErr: |
| // TODO(stevvooe): Not wildly clear what we should do in this |
| // branch. Basically, it means that we are no longer receiving |
| // requests due to a terminal error. |
| recvErr = nil // connection is now "closing" |
| if err != nil { |
| log.L.WithError(err).Error("error receiving message") |
| } |
| case <-shutdown: |
| return |
| } |
| } |
| } |