| package sshforward |
| |
| import ( |
| "io/ioutil" |
| "net" |
| "os" |
| "path/filepath" |
| |
| "github.com/moby/buildkit/session" |
| "github.com/pkg/errors" |
| context "golang.org/x/net/context" |
| "golang.org/x/sync/errgroup" |
| "google.golang.org/grpc/metadata" |
| ) |
| |
const (
	// DefaultID is the SSH ID used when the caller does not specify one.
	DefaultID = "default"

	// KeySSHID is the outgoing gRPC metadata key under which the SSH ID is
	// attached to each agent-forwarding stream.
	KeySSHID = "buildkit.ssh.id"
)
| |
| type server struct { |
| caller session.Caller |
| } |
| |
| func (s *server) run(ctx context.Context, l net.Listener, id string) error { |
| eg, ctx := errgroup.WithContext(ctx) |
| |
| eg.Go(func() error { |
| <-ctx.Done() |
| return ctx.Err() |
| }) |
| |
| eg.Go(func() error { |
| for { |
| conn, err := l.Accept() |
| if err != nil { |
| return err |
| } |
| |
| client := NewSSHClient(s.caller.Conn()) |
| |
| opts := make(map[string][]string) |
| opts[KeySSHID] = []string{id} |
| ctx = metadata.NewOutgoingContext(ctx, opts) |
| |
| stream, err := client.ForwardAgent(ctx) |
| if err != nil { |
| conn.Close() |
| return err |
| } |
| |
| go Copy(ctx, conn, stream, stream.CloseSend) |
| } |
| }) |
| |
| return eg.Wait() |
| } |
| |
// SocketOpt configures the unix socket created by MountSSHSocket.
type SocketOpt struct {
	// ID is the SSH ID sent with every forwarded connection. MountSSHSocket
	// substitutes DefaultID when it is empty.
	ID string
	// UID and GID are the owner and group given to the socket file via
	// os.Chown.
	UID, GID int
	// Mode is converted to os.FileMode and applied to the socket file via
	// os.Chmod.
	Mode int
}
| |
| func MountSSHSocket(ctx context.Context, c session.Caller, opt SocketOpt) (sockPath string, closer func() error, err error) { |
| dir, err := ioutil.TempDir("", ".buildkit-ssh-sock") |
| if err != nil { |
| return "", nil, errors.WithStack(err) |
| } |
| |
| defer func() { |
| if err != nil { |
| os.RemoveAll(dir) |
| } |
| }() |
| |
| sockPath = filepath.Join(dir, "ssh_auth_sock") |
| |
| l, err := net.Listen("unix", sockPath) |
| if err != nil { |
| return "", nil, errors.WithStack(err) |
| } |
| |
| if err := os.Chown(sockPath, opt.UID, opt.GID); err != nil { |
| l.Close() |
| return "", nil, errors.WithStack(err) |
| } |
| if err := os.Chmod(sockPath, os.FileMode(opt.Mode)); err != nil { |
| l.Close() |
| return "", nil, errors.WithStack(err) |
| } |
| |
| s := &server{caller: c} |
| |
| id := opt.ID |
| if id == "" { |
| id = DefaultID |
| } |
| |
| go s.run(ctx, l, id) // erroring per connection allowed |
| |
| return sockPath, func() error { |
| err := l.Close() |
| os.RemoveAll(sockPath) |
| return errors.WithStack(err) |
| }, nil |
| } |
| |
| func CheckSSHID(ctx context.Context, c session.Caller, id string) error { |
| client := NewSSHClient(c.Conn()) |
| _, err := client.CheckAgent(ctx, &CheckAgentRequest{ID: id}) |
| return errors.WithStack(err) |
| } |