[swarm_docker] Add swarm_docker

swarm_docker is a daemon that manages a pool of Docker containers running
swarming bots: it creates the containers from a configured image, restarts
them when they exit, and drains the pool on SIGTERM.

Change-Id: I9e14a0efd5400599de68215d38f87059f26d8c78
diff --git a/Gopkg.toml b/Gopkg.toml
index 6a8f639..af7232a 100644
--- a/Gopkg.toml
+++ b/Gopkg.toml
@@ -41,3 +41,23 @@
   name = "google.golang.org/appengine"
   branch = "master"
   source = "https://fuchsia.googlesource.com/third_party/github.com/golang/appengine.git"
+
+[[constraint]]
+  name = "github.com/docker/docker"
+  revision = "daded8da9178ec097e59a6b0b2bf12754b5472d8"
+  source = "https://github.com/moby/moby"
+
+[[override]]
+  name = "github.com/docker/distribution"
+  revision = "f4118485915abb8b163442717326597908eee6aa"
+  source = "https://github.com/docker/distribution"
+
+[[override]]
+  name = "github.com/docker/go-connections"
+  version = "v0.3.0"
+  source = "https://github.com/docker/go-connections"
+
+[[override]]
+  name = "github.com/docker/go-units"
+  version = "v0.3.2"
+  source = "https://github.com/docker/go-units"
diff --git a/cmd/swarm_docker/swarm_docker.go b/cmd/swarm_docker/swarm_docker.go
new file mode 100644
index 0000000..d04273c
--- /dev/null
+++ b/cmd/swarm_docker/swarm_docker.go
@@ -0,0 +1,595 @@
+// Copyright 2017 The Fuchsia Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// This is a daemon for managing a pool of Docker containers on swarming bots.
+// The daemon is event-driven, processing the stream of events from the Docker
+// daemon.
+//
+// The daemon creates and starts a set number of containers from a specified
+// image which is pulled from the Google container registry. These containers
+// are then monitored and automatically restarted on exit (i.e. when
+// swarming_bot itself has shut down the container by invoking /sbin/shutdown).
+//
+// When the daemon receives SIGTERM, it starts draining the pool: it sends
+// SIGTERM to every container, which causes the swarming_bot to exit at the
+// next opportunity (i.e. when it is not running a task). Once all containers
+// have successfully exited, the daemon itself terminates.
+//
+// When the daemon receives SIGINT (or SIGHUP), it immediately terminates all
+// containers without any grace period.
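+//
+// The daemon reads its configuration from the file given by the -config flag.
+// An illustrative configuration might look like the following (all values are
+// hypothetical; "memory" is interpreted as GiB per container):
+//
+//   {
+//     "memory": 4,
+//     "cpus": 2,
+//     "user": "swarming",
+//     "network_mode": "bridge",
+//     "image_name": "swarming-bot",
+//     "project": "example-project",
+//     "swarming_server": "https://chromium-swarm.appspot.com",
+//     "credentials": "/etc/swarm_docker/key.json",
+//     "containers": [
+//       {"name": "test001", "mounts": [{"source": "/b/test001", "target": "/b"}]}
+//     ]
+//   }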
+package main
+
+import (
+	"context"
+	"encoding/base64"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"log"
+	"net"
+	"os"
+	"os/signal"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"syscall"
+	"time"
+
+	"github.com/docker/docker/api/types"
+	"github.com/docker/docker/api/types/container"
+	"github.com/docker/docker/api/types/filters"
+	"github.com/docker/docker/api/types/mount"
+	"github.com/docker/docker/api/types/network"
+	docker "github.com/docker/docker/client"
+)
+
+const (
+	registryDomain = "gcr.io"
+	swarmingURL    = "https://chromium-swarm.appspot.com"
+)
+
+var (
+	configPath string        // config file path
+	timeout    time.Duration // default timeout for all operations
+)
+
+func init() {
+	flag.StringVar(&configPath, "config", "/etc/swarm_docker/config.json", "config file path")
+	flag.DurationVar(&timeout, "timeout", 1*time.Minute, "default timeout")
+}
+
+// Config contains the service configuration.
+type Config struct {
+	Memory         int         `json:"memory"`
+	Cpus           int         `json:"cpus"`
+	User           string      `json:"user"`
+	NetworkMode    string      `json:"network_mode"`
+	ImageName      string      `json:"image_name"`
+	Project        string      `json:"project"`
+	SwarmingServer string      `json:"swarming_server"`
+	Credentials    string      `json:"credentials"`
+	Containers     []Container `json:"containers"`
+}
+
+// Container describes the container instance.
+type Container struct {
+	Name    string `json:"name"`
+	Devices []struct {
+		PathOnHost      string `json:"path_on_host"`
+		PathInContainer string `json:"path_in_container"`
+		Permissions     string `json:"permissions"`
+	} `json:"devices,omitempty"`
+	Mounts []struct {
+		Source   string `json:"source"`
+		Target   string `json:"target"`
+		ReadOnly bool   `json:"readonly,omitempty"`
+	} `json:"mounts,omitempty"`
+	hostname   string `json:"-"`
+	domainname string `json:"-"`
+	cpuset     string `json:"-"`
+	memory     int    `json:"-"`
+}
+
+// stringsFlag implements flag.Value interface for array of strings.
+type stringsFlag []string
+
+func (s *stringsFlag) String() string {
+	return strings.Join(*s, ", ")
+}
+
+func (s *stringsFlag) Set(value string) error {
+	*s = append(*s, value)
+	return nil
+}
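+
+// A stringsFlag would typically be registered with flag.Var; for example (the
+// flag name below is hypothetical, nothing in this file registers one yet):
+//
+//   var extraMounts stringsFlag
+//   flag.Var(&extraMounts, "mount", "additional bind mount (may be repeated)")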
+
+// Pool is used to manage a static pool of container instances.
+type Pool struct {
+	client     docker.CommonAPIClient
+	containers map[string]Container
+	draining   int32 // accessed atomically
+	started    time.Time
+	lock       sync.RWMutex
+	ctx        context.Context
+	cancel     context.CancelFunc
+}
+
+// NewPool creates a new Pool instance.
+func NewPool(client docker.CommonAPIClient) *Pool {
+	return &Pool{
+		client:     client,
+		containers: map[string]Container{},
+	}
+}
+
+// Create instantiates all the containers but doesn't start them.
+func (p *Pool) Create(ctx context.Context, cfg *Config, containers []*Container) <-chan error {
+	errc := make(chan error)
+	var wg sync.WaitGroup
+	p.lock.RLock()
+	for _, c := range containers {
+		wg.Add(1)
+		go func(c *Container) {
+			defer wg.Done()
+			log.Printf("creating container %s\n", c.Name)
+			if id, err := p.create(ctx, cfg, c); err != nil {
+				errc <- fmt.Errorf("failed to create container %s: %v", c.Name, err)
+			} else {
+				log.Printf("container %s created\n", id)
+				p.lock.Lock()
+				p.containers[id] = *c
+				p.lock.Unlock()
+			}
+		}(c)
+	}
+	p.lock.RUnlock()
+	go func() {
+		wg.Wait()
+		close(errc)
+	}()
+	return errc
+}
+
+// Serve runs the main loop responsible for managing containers, restarting
+// them as needed based on the Docker event stream.
+func (p *Pool) Serve(ctx context.Context) <-chan error {
+	errc := make(chan error)
+	p.started = time.Now()
+
+	var wg sync.WaitGroup
+	p.lock.RLock()
+	for id := range p.containers {
+		wg.Add(1)
+		go func(id string) {
+			defer wg.Done()
+			log.Printf("starting container %s\n", id)
+			if err := p.client.ContainerStart(ctx, id, types.ContainerStartOptions{}); err != nil {
+				log.Printf("failed to start container %s: %v\n", id, err)
+				p.lock.Lock()
+				delete(p.containers, id)
+				p.lock.Unlock()
+			} else {
+				log.Printf("container %s started\n", id)
+			}
+		}(id)
+	}
+	p.lock.RUnlock()
+	wg.Wait()
+
+	args := []filters.KeyValuePair{{Key: "type", Value: "container"}}
+	for id := range p.containers {
+		args = append(args, filters.KeyValuePair{Key: "container", Value: id})
+	}
+	eventFilters := filters.NewArgs(args...)
+
+	go func() {
+		msgs, errs := p.client.Events(ctx, types.EventsOptions{
+			Filters: eventFilters,
+			Since:   p.started.Format(time.UnixDate),
+		})
+		for {
+			select {
+			case msg := <-msgs:
+				switch action := msg.Action; action {
+				case "die":
+					name := msg.Actor.Attributes["name"]
+					if atomic.LoadInt32(&p.draining) != 0 {
+						continue
+					}
+					go func() {
+						log.Printf("restart %s\n", name)
+						if err := p.client.ContainerStart(ctx, msg.Actor.ID, types.ContainerStartOptions{}); err != nil {
+							errc <- err
+						}
+					}()
+				}
+			case <-ctx.Done():
+				if err := ctx.Err(); err != context.Canceled {
+					errc <- err
+				}
+				return
+			case err := <-errs:
+				if err != io.EOF {
+					errc <- err
+				}
+				return
+			}
+		}
+	}()
+
+	return errc
+}
+
+// Drain sends a SIGTERM signal to all containers and waits for their exit.
+func (p *Pool) Drain(ctx context.Context) <-chan error {
+	atomic.AddInt32(&p.draining, 1)
+
+	var wg sync.WaitGroup
+	errc := make(chan error)
+	for id := range p.containers {
+		wg.Add(1)
+		go func(id string) {
+			defer wg.Done()
+			log.Printf("send termination to %s\n", id)
+			if err := p.client.ContainerKill(ctx, id, "TERM"); err != nil {
+				errc <- err
+			}
+			msgs, errs := p.client.ContainerWait(ctx, id, container.WaitConditionNotRunning)
+			select {
+			case body := <-msgs:
+				log.Printf("container %s exited (status %d)\n", id, body.StatusCode)
+			case err := <-errs:
+				errc <- err
+			}
+		}(id)
+	}
+	go func() {
+		wg.Wait()
+		atomic.AddInt32(&p.draining, -1)
+		close(errc)
+	}()
+	return errc
+}
+
+// Remove forcibly stops and removes all containers.
+func (p *Pool) Remove(ctx context.Context) <-chan error {
+	ctx, cancel := context.WithTimeout(ctx, timeout)
+	atomic.AddInt32(&p.draining, 1)
+
+	var wg sync.WaitGroup
+	errc := make(chan error)
+	for id := range p.containers {
+		wg.Add(1)
+		go func(id string) {
+			defer wg.Done()
+			log.Printf("stop and remove %s\n", id)
+			if err := p.client.ContainerStop(ctx, id, nil); err != nil {
+				errc <- err
+			}
+			if err := p.client.ContainerRemove(ctx, id, types.ContainerRemoveOptions{Force: true}); err != nil {
+				errc <- err
+			}
+		}(id)
+	}
+	go func() {
+		wg.Wait()
+		atomic.AddInt32(&p.draining, -1)
+		cancel()
+		close(errc)
+	}()
+	return errc
+}
+
+// create creates a new swarming container with the appropriate configuration.
+func (p *Pool) create(ctx context.Context, cfg *Config, c *Container) (string, error) {
+	config := container.Config{
+		Hostname:   c.hostname,
+		Domainname: c.domainname, // should be the same as the host's
+		Image:      fmt.Sprintf("%s/%s/%s", registryDomain, cfg.Project, cfg.ImageName),
+		Cmd:        []string{"-swarming-server", cfg.SwarmingServer, "-user", cfg.User},
+	}
+	hostConfig := container.HostConfig{
+		Resources: container.Resources{
+			Memory:     int64(c.memory) * 1024 * 1024 * 1024,
+			CpusetCpus: c.cpuset,
+		},
+		Mounts: []mount.Mount{{
+			Type:     mount.TypeBind,
+			Source:   "/etc/group",
+			Target:   "/etc/group",
+			ReadOnly: true,
+		}, {
+			Type:     mount.TypeBind,
+			Source:   "/etc/passwd",
+			Target:   "/etc/passwd",
+			ReadOnly: true,
+		}, {
+			Type:     mount.TypeBind,
+			Source:   "/etc/shadow",
+			Target:   "/etc/shadow",
+			ReadOnly: true,
+		}, {
+			Type:     mount.TypeBind,
+			Source:   "/home/swarming",
+			Target:   "/home/swarming",
+			ReadOnly: false,
+		}, {
+			// Needed by swarming bot to auth with server.
+			Type:     mount.TypeBind,
+			Source:   "/var/lib/luci_machine_tokend",
+			Target:   "/var/lib/luci_machine_tokend",
+			ReadOnly: true,
+		}},
+		NetworkMode: container.NetworkMode(cfg.NetworkMode),
+	}
+	for _, d := range c.Devices {
+		hostConfig.Resources.Devices = append(hostConfig.Resources.Devices, container.DeviceMapping{
+			PathOnHost:        d.PathOnHost,
+			PathInContainer:   d.PathInContainer,
+			CgroupPermissions: d.Permissions,
+		})
+	}
+	for _, m := range c.Mounts {
+		hostConfig.Mounts = append(hostConfig.Mounts, mount.Mount{
+			Type:     mount.TypeBind,
+			Source:   m.Source,
+			Target:   m.Target,
+			ReadOnly: m.ReadOnly,
+		})
+	}
+	networkingConfig := network.NetworkingConfig{}
+	res, err := p.client.ContainerCreate(ctx, &config, &hostConfig, &networkingConfig, c.Name)
+	if err != nil {
+		return "", fmt.Errorf("failed to create new container: %v", err)
+	}
+	return res.ID, nil
+}
+
+// Image provides operations for checking and pulling Docker images.
+type Image struct {
+	client docker.ImageAPIClient
+}
+
+// NewImage creates a new Image instance.
+func NewImage(client docker.ImageAPIClient) *Image {
+	return &Image{
+		client: client,
+	}
+}
+
+// Exists reports whether the image is already present on the host.
+func (r *Image) Exists(ctx context.Context, project, imageName string) (bool, error) {
+	reference := fmt.Sprintf("%s/%s/%s", registryDomain, project, imageName)
+	listFilters := filters.NewArgs(filters.KeyValuePair{
+		Key:   "reference",
+		Value: reference,
+	})
+	summary, err := r.client.ImageList(ctx, types.ImageListOptions{Filters: listFilters})
+	if err != nil {
+		return false, fmt.Errorf("cannot list images: %v", err)
+	}
+	return len(summary) != 0, nil
+}
+
+// Pull checks whether the image is present on the host and, if not, fetches
+// it from the remote container registry using credentials for authentication.
+func (r *Image) Pull(ctx context.Context, project, imageName, credentials string) error {
+	if exists, err := r.Exists(ctx, project, imageName); err != nil {
+		return err
+	} else if exists {
+		return nil
+	}
+	reference := fmt.Sprintf("%s/%s/%s", registryDomain, project, imageName)
+
+	buf, err := json.Marshal(types.AuthConfig{
+		Username:      "_json_key",
+		Password:      credentials,
+		ServerAddress: fmt.Sprintf("https://%s", registryDomain),
+	})
+	if err != nil {
+		return fmt.Errorf("failed to marshal auth: %v", err)
+	}
+
+	// Fetch the image from the remote container registry.
+	log.Printf("pulling the image %s\n", reference)
+	res, err := r.client.ImagePull(ctx, reference, types.ImagePullOptions{
+		RegistryAuth: base64.URLEncoding.EncodeToString(buf),
+		All:          true,
+	})
+	if err != nil {
+		return fmt.Errorf("image pull failed: %v", err)
+	}
+	defer res.Close()
+
+	type JSONMessage struct {
+		Status   string `json:"status,omitempty"`
+		Progress string `json:"progress,omitempty"`
+		Error    string `json:"error,omitempty"`
+	}
+
+	// Report progress as the image is being downloaded.
+	dec := json.NewDecoder(res)
+	for dec.More() {
+		var m JSONMessage
+		err := dec.Decode(&m)
+		if err != nil {
+			return err
+		}
+		log.Printf("%s %s\n", m.Status, m.Progress)
+		if m.Error != "" {
+			log.Printf("%s\n", m.Error)
+		}
+	}
+	return nil
+}
+
+// getHostDomain returns the host and domain name components of the FQDN.
+func getHostDomain() (string, string, error) {
+	hostname, err := os.Hostname()
+	if err != nil {
+		return "", "", err
+	}
+	strs := strings.SplitN(hostname, ".", 2)
+	if len(strs) == 2 {
+		return strs[0], strs[1], nil
+	}
+	return strs[0], "", nil
+}
+
+// loadConfig reads the service configuration from a file.
+func loadConfig(ctx context.Context, path string) (*Config, error) {
+	file, err := os.Open(path)
+	if err != nil {
+		return nil, err
+	}
+	defer file.Close()
+
+	var config Config
+	if err := json.NewDecoder(file).Decode(&config); err != nil {
+		return nil, err
+	}
+
+	return &config, nil
+}
+
+// sdNotify sends a message to the systemd init daemon using the Unix domain
+// socket referenced in the $NOTIFY_SOCKET environment variable.
+func sdNotify(unsetEnvironment bool, state string) (sent bool, err error) {
+	addr := &net.UnixAddr{
+		Name: os.Getenv("NOTIFY_SOCKET"),
+		Net:  "unixgram",
+	}
+
+	// NOTIFY_SOCKET not set
+	if addr.Name == "" {
+		return false, nil
+	}
+
+	if unsetEnvironment {
+		err = os.Unsetenv("NOTIFY_SOCKET")
+	}
+	if err != nil {
+		return false, err
+	}
+
+	conn, err := net.DialUnix(addr.Net, nil, addr)
+	if err != nil {
+		return false, fmt.Errorf("error connecting to NOTIFY_SOCKET: %v", err)
+	}
+	defer conn.Close()
+
+	_, err = conn.Write([]byte(state))
+	if err != nil {
+		return false, fmt.Errorf("error sending the message: %v", err)
+	}
+	return true, nil
+}
+
+func main() {
+	flag.Parse()
+
+	client, err := docker.NewEnvClient()
+	if err != nil {
+		log.Fatalln("cannot create new client", err)
+	}
+	defer client.Close()
+
+	ctx := context.Background()
+
+	config, err := loadConfig(ctx, configPath)
+	if err != nil {
+		log.Fatalln("cannot read configuration", err)
+	}
+
+	// Let systemd know that we're ready.
+	if _, err := sdNotify(false, "READY=1"); err != nil {
+		log.Printf("failed to notify systemd: %v", err)
+	}
+
+	// Prune stale images.
+	if _, err := client.ImagesPrune(ctx, filters.Args{}); err != nil {
+		log.Fatalln("failed to prune images", err)
+	}
+
+	// Read the credentials from the specified file.
+	bytes, err := ioutil.ReadFile(config.Credentials)
+	if err != nil {
+		log.Fatalln("failed to read credentials", err)
+	}
+
+	// Try to pull the image if not already present.
+	registry := NewImage(client)
+	if err := registry.Pull(ctx, config.Project, config.ImageName, string(bytes)); err != nil {
+		log.Fatalln("failed to pull image", err)
+	}
+
+	hostname, domainname, err := getHostDomain()
+	if err != nil {
+		log.Fatalln("failed to get hostname", err)
+	}
+
+	// Prune stale containers.
+	if _, err := client.ContainersPrune(ctx, filters.Args{}); err != nil {
+		log.Fatalln("failed to prune containers", err)
+	}
+
+	// Create the specified number of containers.
+	containers := make([]*Container, len(config.Containers))
+	for i, c := range config.Containers {
+		cpuset := ""
+		if config.Cpus != 0 {
+			cpuset = fmt.Sprintf("%d-%d", i*config.Cpus, (i+1)*config.Cpus-1)
+		}
+		containers[i] = &Container{
+			Name:       c.Name,
+			Devices:    c.Devices,
+			Mounts:     c.Mounts,
+			hostname:   fmt.Sprintf("%s--%s", hostname, c.Name),
+			domainname: domainname,
+			cpuset:     cpuset,
+			memory:     config.Memory,
+		}
+	}
+
+	// Create and start the pool.
+	pool := NewPool(client)
+	for err := range pool.Create(ctx, config, containers) {
+		log.Fatalln("failed to create pool", err)
+	}
+
+	// SIGTERM terminates the process with graceful shutdown draining the pool.
+	signals := make(chan os.Signal, 1)
+	signal.Notify(signals, syscall.SIGTERM)
+
+	// SIGINT and SIGHUP terminate the process forcibly stopping containers.
+	interrupts := make(chan os.Signal, 1)
+	signal.Notify(interrupts, syscall.SIGINT, syscall.SIGHUP)
+
+	// Serve gets its own context so cancelling it doesn't affect Remove below.
+	serveCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	errs := pool.Serve(serveCtx)
+
+loop:
+	for {
+		select {
+		case <-signals:
+			for err := range pool.Drain(ctx) {
+				log.Printf("failed to drain pool: %s\n", err)
+			}
+			break loop
+		case <-interrupts:
+			cancel()
+			break loop
+		case err := <-errs:
+			log.Printf("error in the pool: %v", err)
+		}
+	}
+
+	// Let systemd know that we have begun to shut down.
+	if _, err := sdNotify(false, "STOPPING=1"); err != nil {
+		log.Printf("failed to notify systemd: %v", err)
+	}
+
+	for err := range pool.Remove(ctx) {
+		log.Printf("failed to remove container: %v", err)
+	}
+}
diff --git a/cmd/swarm_docker/swarm_docker_test.go b/cmd/swarm_docker/swarm_docker_test.go
new file mode 100644
index 0000000..77d6ccb
--- /dev/null
+++ b/cmd/swarm_docker/swarm_docker_test.go
@@ -0,0 +1,186 @@
+// Copyright 2017 The Fuchsia Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package main
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"io"
+	"io/ioutil"
+	"testing"
+	"time"
+
+	"github.com/docker/docker/api/types"
+	"github.com/docker/docker/api/types/container"
+	"github.com/docker/docker/api/types/events"
+	"github.com/docker/docker/api/types/network"
+	"github.com/docker/docker/client"
+)
+
+// mockedCommonAPIClient provides mock implementations of Docker client methods needed by Pool.
+type mockedCommonAPIClient struct {
+	client.CommonAPIClient
+	containers map[string]bool
+	messages   chan events.Message
+	errors     chan error
+}
+
+func (m *mockedCommonAPIClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, containerName string) (container.ContainerCreateCreatedBody, error) {
+	return container.ContainerCreateCreatedBody{ID: "test"}, nil
+}
+
+func (m *mockedCommonAPIClient) ContainerStart(ctx context.Context, containerID string, options types.ContainerStartOptions) error {
+	if containerID != "test" {
+		return errors.New("invalid ID")
+	}
+	return nil
+}
+
+func (m *mockedCommonAPIClient) Events(ctx context.Context, options types.EventsOptions) (<-chan events.Message, <-chan error) {
+	return m.messages, m.errors
+}
+
+func (m *mockedCommonAPIClient) ContainerRestart(ctx context.Context, containerID string, timeout *time.Duration) error {
+	if containerID != "test" {
+		return errors.New("invalid ID")
+	}
+	return nil
+}
+
+func (m *mockedCommonAPIClient) ContainerStop(ctx context.Context, containerID string, timeout *time.Duration) error {
+	if containerID != "test" {
+		return errors.New("invalid ID")
+	}
+	return nil
+}
+
+func (m *mockedCommonAPIClient) ContainerWait(ctx context.Context, containerID string, condition container.WaitCondition) (<-chan container.ContainerWaitOKBody, <-chan error) {
+	msgs := make(chan container.ContainerWaitOKBody, 1)
+	msgs <- container.ContainerWaitOKBody{StatusCode: 0}
+	errs := make(chan error, 0)
+	return msgs, errs
+}
+
+func (m *mockedCommonAPIClient) ContainerRemove(ctx context.Context, containerID string, options types.ContainerRemoveOptions) error {
+	if containerID != "test" {
+		return errors.New("invalid ID")
+	}
+	return nil
+}
+
+func (m *mockedCommonAPIClient) ContainerKill(ctx context.Context, container, signal string) error {
+	return nil
+}
+
+// mockedImageAPIClient provides mock implementations of Docker client methods needed by Image.
+type mockedImageAPIClient struct {
+	client.ImageAPIClient
+	images map[string]types.ImageSummary
+}
+
+func (m *mockedImageAPIClient) ImageList(ctx context.Context, options types.ImageListOptions) ([]types.ImageSummary, error) {
+	reference := options.Filters.Get("reference")
+	images := []types.ImageSummary{}
+	for _, r := range reference {
+		if image, ok := m.images[r]; ok {
+			images = append(images, image)
+		}
+	}
+	return images, nil
+}
+
+func (m *mockedImageAPIClient) ImagePull(ctx context.Context, ref string, options types.ImagePullOptions) (io.ReadCloser, error) {
+	if _, ok := m.images[ref]; !ok {
+		return nil, errors.New("image does not exist")
+	}
+	return ioutil.NopCloser(bytes.NewReader([]byte(`{"status": "done"}`))), nil
+}
+
+func TestImage(t *testing.T) {
+	ctx := context.Background()
+
+	mockClient := mockedImageAPIClient{
+		images: map[string]types.ImageSummary{
+			registryDomain + "/foo/bar": types.ImageSummary{},
+		},
+	}
+
+	t.Run("image exists", func(t *testing.T) {
+		image := NewImage(&mockClient)
+		if err := image.Pull(ctx, "foo", "bar", "key"); err != nil {
+			t.Error("image pull failed", err)
+		}
+	})
+
+	t.Run("image does not exist", func(t *testing.T) {
+		mockClient := mockedImageAPIClient{}
+		image := NewImage(&mockClient)
+		if err := image.Pull(ctx, "foo", "baz", "key"); err == nil {
+			t.Error("image pull didn't fail", err)
+		}
+	})
+}
+
+func TestPool(t *testing.T) {
+	ctx := context.Background()
+
+	// TODO: test behavior with multiple containers
+	containers := []*Container{{
+		Name:       "test",
+		hostname:   "test--test001",
+		domainname: "local",
+		cpuset:     "0-1",
+	}}
+
+	// Minimal configuration for Pool.Create; the mocked client ignores it.
+	cfg := &Config{
+		Project:   "foo",
+		ImageName: "bar",
+	}
+
+	t.Run("pool create, restart and remove", func(t *testing.T) {
+		msgs := make(chan events.Message, 1)
+		errs := make(chan error, 1)
+		mockClient := mockedCommonAPIClient{
+			containers: map[string]bool{
+				"test--test001": true,
+			},
+			messages: msgs,
+			errors:   errs,
+		}
+
+		pool := NewPool(&mockClient)
+		if err := <-pool.Create(ctx, cfg, containers); err != nil {
+			t.Error("pool creation failed", err)
+		}
+		pool.Serve(ctx)
+		mockClient.messages <- events.Message{
+			Action: "die",
+			Actor: events.Actor{
+				Attributes: map[string]string{
+					"name": "test",
+				},
+			},
+		}
+		for err := range pool.Remove(ctx) {
+			t.Error("pool removal failed", err)
+		}
+	})
+
+	t.Run("pool create, drain and remove", func(t *testing.T) {
+		mockClient := mockedCommonAPIClient{
+			containers: map[string]bool{
+				"test--test001": true,
+			},
+		}
+
+		pool := NewPool(&mockClient)
+		if err := <-pool.Create(ctx, cfg, containers); err != nil {
+			t.Error("pool creation failed", err)
+		}
+		pool.Serve(ctx)
+		for err := range pool.Drain(ctx) {
+			t.Error("pool drain failed", err)
+		}
+		for err := range pool.Remove(ctx) {
+			t.Error("pool removal failed", err)
+		}
+	})
+}