| package session |
| |
| import ( |
| "context" |
| "net" |
| "net/http" |
| "strings" |
| "sync" |
| |
| "github.com/pkg/errors" |
| "google.golang.org/grpc" |
| ) |
| |
| // Caller can invoke requests on the session |
| type Caller interface { |
| Context() context.Context |
| Supports(method string) bool |
| Conn() *grpc.ClientConn |
| Name() string |
| SharedKey() string |
| } |
| |
| type client struct { |
| Session |
| cc *grpc.ClientConn |
| supported map[string]struct{} |
| } |
| |
| // Manager is a controller for accessing currently active sessions |
| type Manager struct { |
| sessions map[string]*client |
| mu sync.Mutex |
| updateCondition *sync.Cond |
| } |
| |
| // NewManager returns a new Manager |
| func NewManager() (*Manager, error) { |
| sm := &Manager{ |
| sessions: make(map[string]*client), |
| } |
| sm.updateCondition = sync.NewCond(&sm.mu) |
| return sm, nil |
| } |
| |
| // HandleHTTPRequest handles an incoming HTTP request |
| func (sm *Manager) HandleHTTPRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) error { |
| hijacker, ok := w.(http.Hijacker) |
| if !ok { |
| return errors.New("handler does not support hijack") |
| } |
| |
| id := r.Header.Get(headerSessionID) |
| |
| proto := r.Header.Get("Upgrade") |
| |
| sm.mu.Lock() |
| if _, ok := sm.sessions[id]; ok { |
| sm.mu.Unlock() |
| return errors.Errorf("session %s already exists", id) |
| } |
| |
| if proto == "" { |
| sm.mu.Unlock() |
| return errors.New("no upgrade proto in request") |
| } |
| |
| if proto != "h2c" { |
| sm.mu.Unlock() |
| return errors.Errorf("protocol %s not supported", proto) |
| } |
| |
| conn, _, err := hijacker.Hijack() |
| if err != nil { |
| sm.mu.Unlock() |
| return errors.Wrap(err, "failed to hijack connection") |
| } |
| |
| resp := &http.Response{ |
| StatusCode: http.StatusSwitchingProtocols, |
| ProtoMajor: 1, |
| ProtoMinor: 1, |
| Header: http.Header{}, |
| } |
| resp.Header.Set("Connection", "Upgrade") |
| resp.Header.Set("Upgrade", proto) |
| |
| // set raw mode |
| conn.Write([]byte{}) |
| resp.Write(conn) |
| |
| return sm.handleConn(ctx, conn, r.Header) |
| } |
| |
| // HandleConn handles an incoming raw connection |
| func (sm *Manager) HandleConn(ctx context.Context, conn net.Conn, opts map[string][]string) error { |
| sm.mu.Lock() |
| return sm.handleConn(ctx, conn, opts) |
| } |
| |
| // caller needs to take lock, this function will release it |
| func (sm *Manager) handleConn(ctx context.Context, conn net.Conn, opts map[string][]string) error { |
| ctx, cancel := context.WithCancel(ctx) |
| defer cancel() |
| |
| opts = canonicalHeaders(opts) |
| |
| h := http.Header(opts) |
| id := h.Get(headerSessionID) |
| name := h.Get(headerSessionName) |
| sharedKey := h.Get(headerSessionSharedKey) |
| |
| ctx, cc, err := grpcClientConn(ctx, conn) |
| if err != nil { |
| sm.mu.Unlock() |
| return err |
| } |
| |
| c := &client{ |
| Session: Session{ |
| id: id, |
| name: name, |
| sharedKey: sharedKey, |
| ctx: ctx, |
| cancelCtx: cancel, |
| done: make(chan struct{}), |
| }, |
| cc: cc, |
| supported: make(map[string]struct{}), |
| } |
| |
| for _, m := range opts[headerSessionMethod] { |
| c.supported[strings.ToLower(m)] = struct{}{} |
| } |
| sm.sessions[id] = c |
| sm.updateCondition.Broadcast() |
| sm.mu.Unlock() |
| |
| defer func() { |
| sm.mu.Lock() |
| delete(sm.sessions, id) |
| sm.mu.Unlock() |
| }() |
| |
| <-c.ctx.Done() |
| conn.Close() |
| close(c.done) |
| |
| return nil |
| } |
| |
| // Get returns a session by ID |
| func (sm *Manager) Get(ctx context.Context, id string, noWait bool) (Caller, error) { |
| // session prefix is used to identify vertexes with different contexts so |
| // they would not collide, but for lookup we don't need the prefix |
| if p := strings.SplitN(id, ":", 2); len(p) == 2 && len(p[1]) > 0 { |
| id = p[1] |
| } |
| |
| ctx, cancel := context.WithCancel(ctx) |
| defer cancel() |
| |
| go func() { |
| select { |
| case <-ctx.Done(): |
| sm.mu.Lock() |
| sm.updateCondition.Broadcast() |
| sm.mu.Unlock() |
| } |
| }() |
| |
| var c *client |
| |
| sm.mu.Lock() |
| for { |
| select { |
| case <-ctx.Done(): |
| sm.mu.Unlock() |
| return nil, errors.Wrapf(ctx.Err(), "no active session for %s", id) |
| default: |
| } |
| var ok bool |
| c, ok = sm.sessions[id] |
| if (!ok || c.closed()) && !noWait { |
| sm.updateCondition.Wait() |
| continue |
| } |
| sm.mu.Unlock() |
| break |
| } |
| |
| if c == nil { |
| return nil, nil |
| } |
| |
| return c, nil |
| } |
| |
| func (c *client) Context() context.Context { |
| return c.context() |
| } |
| |
| func (c *client) Name() string { |
| return c.name |
| } |
| |
| func (c *client) SharedKey() string { |
| return c.sharedKey |
| } |
| |
| func (c *client) Supports(url string) bool { |
| _, ok := c.supported[strings.ToLower(url)] |
| return ok |
| } |
| func (c *client) Conn() *grpc.ClientConn { |
| return c.cc |
| } |
| |
// canonicalHeaders returns a copy of in whose keys are in canonical MIME
// header form (see http.CanonicalHeaderKey), so the result can be used as an
// http.Header.
//
// When several input keys map to the same canonical key (for example "x-foo"
// and "X-Foo") their values are merged instead of one spelling silently
// replacing the other. Values stored under the already-canonical spelling
// come first, so http.Header.Get prefers them; the relative order of values
// from several non-canonical spellings of the same key is unspecified.
//
// The result is never nil, and its value slices do not share memory with in.
func canonicalHeaders(in map[string][]string) map[string][]string {
	out := make(map[string][]string, len(in))

	// First pass: keys that are already canonical. Map keys are unique, so
	// there is at most one of these per canonical key.
	for k, v := range in {
		if http.CanonicalHeaderKey(k) == k {
			out[k] = append(out[k], v...)
		}
	}

	// Second pass: every other spelling is appended after the canonical one.
	for k, v := range in {
		if ck := http.CanonicalHeaderKey(k); ck != k {
			out[ck] = append(out[ck], v...)
		}
	}

	return out
}