feat(pubsublite): Abstraction for leaf and composite services (#3143)

Pub/Sub Lite publisher and subscriber clients will be composed of a hierarchy of services: leaf services wrap gRPC client streams (e.g. publisher, subscriber), while composite services manage child services and propagate status changes and errors between them.
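
For illustration only, a higher-level component would embed compositeService, register its children, and drive the shared lifecycle. The sketch below lives in package wire and mirrors the test doubles in service_test.go; examplePublisher and newExamplePublisher are hypothetical names, not part of this change.

    // Hypothetical sketch (package wire). examplePublisher is illustrative only.
    type examplePublisher struct {
    	compositeService
    }

    func newExamplePublisher(children ...service) *examplePublisher {
    	pub := new(examplePublisher)
    	pub.init() // Required after creating the derived struct.
    	pub.mu.Lock()
    	defer pub.mu.Unlock()
    	pub.unsafeAddServices(children...) // Cannot fail here; children start once Start() is called.
    	return pub
    }

    // Callers then block on the composite lifecycle:
    //   pub := newExamplePublisher(leaf1, leaf2)
    //   pub.Start()
    //   if err := pub.WaitStarted(); err != nil { /* handle startup failure */ }
    //   pub.Stop()
    //   _ = pub.WaitStopped()
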
diff --git a/pubsublite/internal/wire/errors.go b/pubsublite/internal/wire/errors.go
index 881624c..5dc5bc1 100644
--- a/pubsublite/internal/wire/errors.go
+++ b/pubsublite/internal/wire/errors.go
@@ -19,4 +19,17 @@
 	// ErrOverflow indicates that the publish buffers have overflowed. See
 	// comments for PublishSettings.BufferedByteLimit.
 	ErrOverflow = errors.New("pubsublite: client-side publish buffers have overflowed")
+
+	// ErrServiceUninitialized indicates that a service (e.g. publisher or
+	// subscriber) cannot perform an operation because it is uninitialized.
+	ErrServiceUninitialized = errors.New("pubsublite: service must be started")
+
+	// ErrServiceStarting indicates that a service (e.g. publisher or subscriber)
+	// cannot perform an operation because it is starting up.
+	ErrServiceStarting = errors.New("pubsublite: service is starting up")
+
+	// ErrServiceStopped indicates that a service (e.g. publisher or subscriber)
+	// cannot perform an operation because it has stopped or is in the process of
+	// stopping.
+	ErrServiceStopped = errors.New("pubsublite: service has stopped or is stopping")
 )
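
For context, a hedged sketch of caller-side handling of these sentinel errors; doSomething is a placeholder for any operation on a service and is not part of this change.

    // Hypothetical illustration only.
    switch err := doSomething(); {
    case err == nil:
    	// Operation accepted.
    case errors.Is(err, ErrServiceUninitialized):
    	// Start() has not been called on the service.
    case errors.Is(err, ErrServiceStarting):
    	// The service has not finished starting up.
    case errors.Is(err, ErrServiceStopped):
    	// The service is stopping or has stopped; no new data is accepted.
    }
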
diff --git a/pubsublite/internal/wire/service.go b/pubsublite/internal/wire/service.go
new file mode 100644
index 0000000..928d9a8
--- /dev/null
+++ b/pubsublite/internal/wire/service.go
@@ -0,0 +1,344 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package wire
+
+import (
+	"sync"
+)
+
+// serviceStatus specifies the current status of the service. The order of the
+// values reflects the lifecycle of services. Note that some statuses may be
+// skipped.
+type serviceStatus int
+
+const (
+	// Service has not been started.
+	serviceUninitialized serviceStatus = 0
+	// Service is starting up.
+	serviceStarting serviceStatus = 1
+	// Service is active and accepting new data. Note that the underlying stream
+	// may be reconnecting due to retryable errors.
+	serviceActive serviceStatus = 2
+	// Service is gracefully shutting down by flushing all pending data. No new
+	// data is accepted.
+	serviceTerminating serviceStatus = 3
+	// Service has terminated. No new data is accepted.
+	serviceTerminated serviceStatus = 4
+)
+
+// serviceHandle is used to compare pointers to service instances.
+type serviceHandle interface{}
+
+// serviceStatusChangeFunc notifies the parent of service status changes.
+// `serviceTerminating` and `serviceTerminated` have an associated error. This
+// error may be nil if the user called Stop().
+type serviceStatusChangeFunc func(serviceHandle, serviceStatus, error)
+
+// service is the interface that must be implemented by services (essentially
+// gRPC client stream wrappers, e.g. subscriber, publisher) that can be
+// dependencies of a compositeService.
+type service interface {
+	Start()
+	Stop()
+
+	// Methods below are implemented by abstractService.
+	AddStatusChangeReceiver(serviceHandle, serviceStatusChangeFunc)
+	RemoveStatusChangeReceiver(serviceHandle)
+	Handle() serviceHandle
+	Error() error
+}
+
+// abstractService can be embedded into other structs to provide common
+// functionality for managing service status and status change receivers.
+type abstractService struct {
+	mu                    sync.Mutex
+	statusChangeReceivers []*statusChangeReceiver
+	status                serviceStatus
+	// The error that caused the service to terminate.
+	err error
+}
+
+type statusChangeReceiver struct {
+	handle         serviceHandle // For removing the receiver.
+	onStatusChange serviceStatusChangeFunc
+}
+
+func (as *abstractService) AddStatusChangeReceiver(handle serviceHandle, onStatusChange serviceStatusChangeFunc) {
+	as.mu.Lock()
+	defer as.mu.Unlock()
+	as.statusChangeReceivers = append(
+		as.statusChangeReceivers,
+		&statusChangeReceiver{handle, onStatusChange})
+}
+
+func (as *abstractService) RemoveStatusChangeReceiver(handle serviceHandle) {
+	as.mu.Lock()
+	defer as.mu.Unlock()
+
+	for i := len(as.statusChangeReceivers) - 1; i >= 0; i-- {
+		r := as.statusChangeReceivers[i]
+		if r.handle == handle {
+			// Swap with last element, erase last element and truncate the slice.
+			lastIdx := len(as.statusChangeReceivers) - 1
+			if i != lastIdx {
+				as.statusChangeReceivers[i] = as.statusChangeReceivers[lastIdx]
+			}
+			as.statusChangeReceivers[lastIdx] = nil
+			as.statusChangeReceivers = as.statusChangeReceivers[:lastIdx]
+		}
+	}
+}
+
+// Handle identifies this service instance, even when there are multiple layers
+// of embedding.
+func (as *abstractService) Handle() serviceHandle {
+	return as
+}
+
+func (as *abstractService) Error() error {
+	as.mu.Lock()
+	defer as.mu.Unlock()
+	return as.err
+}
+
+func (as *abstractService) Status() serviceStatus {
+	as.mu.Lock()
+	defer as.mu.Unlock()
+	return as.status
+}
+
+func (as *abstractService) unsafeCheckServiceStatus() error {
+	switch {
+	case as.status == serviceUninitialized:
+		return ErrServiceUninitialized
+	case as.status == serviceStarting:
+		return ErrServiceStarting
+	case as.status >= serviceTerminating:
+		return ErrServiceStopped
+	default:
+		return nil
+	}
+}
+
+// unsafeUpdateStatus assumes the caller already holds the service mutex, as
+// the status update often needs to be atomic with other operations.
+func (as *abstractService) unsafeUpdateStatus(targetStatus serviceStatus, err error) bool {
+	if as.status >= targetStatus {
+		// Already at the same or later stage of the service lifecycle.
+		return false
+	}
+
+	as.status = targetStatus
+	if as.err == nil {
+		// Prevent clobbering original error.
+		as.err = err
+	}
+
+	for _, receiver := range as.statusChangeReceivers {
+		// Notify in a goroutine to prevent deadlocks if the receiver is holding a
+		// locked mutex.
+		go receiver.onStatusChange(as.Handle(), as.status, as.err)
+	}
+	return true
+}
+
+type serviceHolder struct {
+	service    service
+	lastStatus serviceStatus
+}
+
+// compositeService can be embedded into other structs to manage child services.
+// It implements the service interface and can itself be a dependency of another
+// compositeService.
+//
+// If one child service terminates due to a permanent failure, all other child
+// services are stopped. Child services can be added and removed dynamically.
+type compositeService struct {
+	// Used to block until all dependencies have started or terminated.
+	waitStarted    chan struct{}
+	waitTerminated chan struct{}
+
+	dependencies []*serviceHolder
+	removed      []*serviceHolder
+
+	abstractService
+}
+
+// init must be called after creation of the derived struct.
+func (cs *compositeService) init() {
+	cs.waitStarted = make(chan struct{})
+	cs.waitTerminated = make(chan struct{})
+}
+
+// Start up dependencies.
+func (cs *compositeService) Start() {
+	cs.mu.Lock()
+	defer cs.mu.Unlock()
+
+	if cs.abstractService.unsafeUpdateStatus(serviceStarting, nil) {
+		for _, s := range cs.dependencies {
+			s.service.Start()
+		}
+	}
+}
+
+// WaitStarted waits for all dependencies to start.
+func (cs *compositeService) WaitStarted() error {
+	<-cs.waitStarted
+	return cs.Error()
+}
+
+// Stop all dependencies.
+func (cs *compositeService) Stop() {
+	cs.mu.Lock()
+	defer cs.mu.Unlock()
+	cs.unsafeInitiateShutdown(serviceTerminating, nil)
+}
+
+// WaitStopped waits for all dependencies to stop.
+func (cs *compositeService) WaitStopped() error {
+	<-cs.waitTerminated
+	return cs.Error()
+}
+
+func (cs *compositeService) unsafeAddServices(services ...service) error {
+	if cs.status >= serviceTerminating {
+		return ErrServiceStopped
+	}
+
+	for _, s := range services {
+		s.AddStatusChangeReceiver(cs.Handle(), cs.onServiceStatusChange)
+		cs.dependencies = append(cs.dependencies, &serviceHolder{service: s})
+		if cs.status > serviceUninitialized {
+			s.Start()
+		}
+	}
+	return nil
+}
+
+func (cs *compositeService) unsafeRemoveService(service service) {
+	removeIdx := -1
+	for i, s := range cs.dependencies {
+		if s.service.Handle() == service.Handle() {
+			// Move from the `dependencies` to the `removed` list.
+			cs.removed = append(cs.removed, s)
+			removeIdx = i
+			if s.lastStatus < serviceTerminating {
+				s.service.Stop()
+			}
+			break
+		}
+	}
+	cs.dependencies = removeFromSlice(cs.dependencies, removeIdx)
+}
+
+func (cs *compositeService) unsafeInitiateShutdown(targetStatus serviceStatus, err error) {
+	for _, s := range cs.dependencies {
+		if s.lastStatus < serviceTerminating {
+			s.service.Stop()
+		}
+	}
+	cs.unsafeUpdateStatus(targetStatus, err)
+}
+
+func (cs *compositeService) unsafeUpdateStatus(targetStatus serviceStatus, err error) (ret bool) {
+	previousStatus := cs.status
+	if ret = cs.abstractService.unsafeUpdateStatus(targetStatus, err); ret {
+		// Note: the waitStarted channel must be closed when the service fails to
+		// start.
+		if previousStatus == serviceStarting {
+			close(cs.waitStarted)
+		}
+		if targetStatus == serviceTerminated {
+			close(cs.waitTerminated)
+		}
+	}
+	return
+}
+
+func (cs *compositeService) onServiceStatusChange(handle serviceHandle, status serviceStatus, err error) {
+	cs.mu.Lock()
+	defer cs.mu.Unlock()
+
+	removeIdx := -1
+	for i, s := range cs.removed {
+		if s.service.Handle() == handle {
+			if status == serviceTerminated {
+				s.service.RemoveStatusChangeReceiver(cs.Handle())
+				removeIdx = i
+			}
+			break
+		}
+	}
+	if removeIdx >= 0 {
+		cs.removed = removeFromSlice(cs.removed, removeIdx)
+	}
+
+	// Note: absence from the removed list above does not imply the service is an
+	// active dependency. Status changes are notified asynchronously and may be
+	// received out of order, so the notification may be for a service that has
+	// already been removed from both cs.removed and cs.dependencies.
+	isDependency := false
+	for _, s := range cs.dependencies {
+		if s.service.Handle() == handle {
+			if status > s.lastStatus {
+				s.lastStatus = status
+			}
+			isDependency = true
+			break
+		}
+	}
+
+	// If a single service terminates, stop them all, but allow the others to
+	// flush pending data. Ignore removed services that are stopping.
+	shouldTerminate := status >= serviceTerminating && isDependency
+	numStarted := 0
+	numTerminated := 0
+
+	for _, s := range cs.dependencies {
+		if shouldTerminate && s.lastStatus < serviceTerminating {
+			s.service.Stop()
+		}
+		if s.lastStatus >= serviceActive {
+			numStarted++
+		}
+		if s.lastStatus == serviceTerminated {
+			numTerminated++
+		}
+	}
+
+	switch {
+	case numTerminated == len(cs.dependencies) && len(cs.removed) == 0:
+		cs.unsafeUpdateStatus(serviceTerminated, err)
+	case shouldTerminate:
+		cs.unsafeUpdateStatus(serviceTerminating, err)
+	case numStarted == len(cs.dependencies):
+		cs.unsafeUpdateStatus(serviceActive, err)
+	}
+}
+
+func removeFromSlice(services []*serviceHolder, removeIdx int) []*serviceHolder {
+	lastIdx := len(services) - 1
+	if removeIdx < 0 || removeIdx > lastIdx {
+		return services
+	}
+
+	// Swap with last element, erase last element and truncate the slice.
+	if removeIdx != lastIdx {
+		services[removeIdx] = services[lastIdx]
+	}
+	services[lastIdx] = nil
+	return services[:lastIdx]
+}
diff --git a/pubsublite/internal/wire/service_test.go b/pubsublite/internal/wire/service_test.go
new file mode 100644
index 0000000..9a5b22c
--- /dev/null
+++ b/pubsublite/internal/wire/service_test.go
@@ -0,0 +1,531 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package wire
+
+import (
+	"errors"
+	"fmt"
+	"testing"
+	"time"
+
+	"cloud.google.com/go/pubsublite/internal/test"
+)
+
+const receiveStatusTimeout = 5 * time.Second
+
+type testStatusChangeReceiver struct {
+	// Status change notifications are fired asynchronously, so a channel receives
+	// the statuses.
+	statusC    chan serviceStatus
+	lastStatus serviceStatus
+	name       string
+}
+
+func newTestStatusChangeReceiver(name string) *testStatusChangeReceiver {
+	return &testStatusChangeReceiver{
+		statusC: make(chan serviceStatus, 1),
+		name:    name,
+	}
+}
+
+func (sr *testStatusChangeReceiver) Handle() interface{} { return sr }
+
+func (sr *testStatusChangeReceiver) OnStatusChange(handle serviceHandle, status serviceStatus, err error) {
+	sr.statusC <- status
+}
+
+func (sr *testStatusChangeReceiver) VerifyStatus(t *testing.T, want serviceStatus) {
+	select {
+	case status := <-sr.statusC:
+		if status <= sr.lastStatus {
+			t.Errorf("%s: Duplicate service status: %d, last status: %d", sr.name, status, sr.lastStatus)
+		}
+		if status != want {
+			t.Errorf("%s: Got service status: %d, want: %d", sr.name, status, want)
+		}
+		sr.lastStatus = status
+	case <-time.After(receiveStatusTimeout):
+		t.Errorf("%s: Did not receive status within %v", sr.name, receiveStatusTimeout)
+	}
+}
+
+func (sr *testStatusChangeReceiver) VerifyNoStatusChanges(t *testing.T) {
+	select {
+	case status := <-sr.statusC:
+		t.Errorf("%s: Unexpected service status: %d", sr.name, status)
+	default:
+	}
+}
+
+type testService struct {
+	receiver *testStatusChangeReceiver
+	abstractService
+}
+
+func newTestService(name string) *testService {
+	receiver := newTestStatusChangeReceiver(name)
+	ts := &testService{receiver: receiver}
+	ts.AddStatusChangeReceiver(receiver.Handle(), receiver.OnStatusChange)
+	return ts
+}
+
+func (ts *testService) Start() { ts.UpdateStatus(serviceStarting, nil) }
+func (ts *testService) Stop()  { ts.UpdateStatus(serviceTerminating, nil) }
+
+func (ts *testService) UpdateStatus(targetStatus serviceStatus, err error) {
+	ts.mu.Lock()
+	defer ts.mu.Unlock()
+	ts.unsafeUpdateStatus(targetStatus, err)
+}
+
+func TestServiceUpdateStatusIsLinear(t *testing.T) {
+	err1 := errors.New("error1")
+	err2 := errors.New("error2")
+
+	service := newTestService("service")
+	service.UpdateStatus(serviceStarting, nil)
+	service.receiver.VerifyStatus(t, serviceStarting)
+
+	service.UpdateStatus(serviceActive, nil)
+	service.UpdateStatus(serviceActive, nil)
+	service.receiver.VerifyStatus(t, serviceActive)
+
+	service.UpdateStatus(serviceTerminating, err1)
+	service.UpdateStatus(serviceStarting, nil)
+	service.UpdateStatus(serviceTerminating, nil)
+	service.receiver.VerifyStatus(t, serviceTerminating)
+
+	service.UpdateStatus(serviceTerminated, err2)
+	service.UpdateStatus(serviceTerminated, nil)
+	service.receiver.VerifyStatus(t, serviceTerminated)
+
+	// Verify that the first error is not clobbered by the second.
+	if got, want := service.Error(), err1; !test.ErrorEqual(got, want) {
+		t.Errorf("service.Error(): got (%v), want (%v)", got, want)
+	}
+}
+
+func TestServiceCheckServiceStatus(t *testing.T) {
+	for _, tc := range []struct {
+		status  serviceStatus
+		wantErr error
+	}{
+		{
+			status:  serviceUninitialized,
+			wantErr: ErrServiceUninitialized,
+		},
+		{
+			status:  serviceStarting,
+			wantErr: ErrServiceStarting,
+		},
+		{
+			status:  serviceActive,
+			wantErr: nil,
+		},
+		{
+			status:  serviceTerminating,
+			wantErr: ErrServiceStopped,
+		},
+		{
+			status:  serviceTerminated,
+			wantErr: ErrServiceStopped,
+		},
+	} {
+		t.Run(fmt.Sprintf("Status=%v", tc.status), func(t *testing.T) {
+			s := newTestService("service")
+			s.UpdateStatus(tc.status, nil)
+			if gotErr := s.unsafeCheckServiceStatus(); !test.ErrorEqual(gotErr, tc.wantErr) {
+				t.Errorf("service.unsafeCheckServiceStatus(): got (%v), want (%v)", gotErr, tc.wantErr)
+			}
+		})
+	}
+}
+
+func TestServiceAddRemoveStatusChangeReceiver(t *testing.T) {
+	receiver1 := newTestStatusChangeReceiver("receiver1")
+	receiver2 := newTestStatusChangeReceiver("receiver2")
+	receiver3 := newTestStatusChangeReceiver("receiver3")
+
+	service := new(testService)
+	service.AddStatusChangeReceiver(receiver1.Handle(), receiver1.OnStatusChange)
+	service.AddStatusChangeReceiver(receiver2.Handle(), receiver2.OnStatusChange)
+	service.AddStatusChangeReceiver(receiver3.Handle(), receiver3.OnStatusChange)
+
+	t.Run("All receivers", func(t *testing.T) {
+		service.UpdateStatus(serviceActive, nil)
+
+		receiver1.VerifyStatus(t, serviceActive)
+		receiver2.VerifyStatus(t, serviceActive)
+		receiver3.VerifyStatus(t, serviceActive)
+	})
+
+	t.Run("receiver1 removed", func(t *testing.T) {
+		service.RemoveStatusChangeReceiver(receiver1.Handle())
+		service.UpdateStatus(serviceTerminating, nil)
+
+		receiver1.VerifyNoStatusChanges(t)
+		receiver2.VerifyStatus(t, serviceTerminating)
+		receiver3.VerifyStatus(t, serviceTerminating)
+	})
+
+	t.Run("receiver2 removed", func(t *testing.T) {
+		service.RemoveStatusChangeReceiver(receiver2.Handle())
+		service.UpdateStatus(serviceTerminated, nil)
+
+		receiver1.VerifyNoStatusChanges(t)
+		receiver2.VerifyNoStatusChanges(t)
+		receiver3.VerifyStatus(t, serviceTerminated)
+	})
+}
+
+type testCompositeService struct {
+	receiver *testStatusChangeReceiver
+	compositeService
+}
+
+func newTestCompositeService(name string) *testCompositeService {
+	receiver := newTestStatusChangeReceiver(name)
+	ts := &testCompositeService{receiver: receiver}
+	ts.AddStatusChangeReceiver(receiver.Handle(), receiver.OnStatusChange)
+	ts.init()
+	return ts
+}
+
+func (ts *testCompositeService) AddServices(services ...service) {
+	ts.mu.Lock()
+	defer ts.mu.Unlock()
+	ts.unsafeAddServices(services...)
+}
+
+func (ts *testCompositeService) RemoveService(service service) {
+	ts.mu.Lock()
+	defer ts.mu.Unlock()
+	ts.unsafeRemoveService(service)
+}
+
+func (ts *testCompositeService) DependenciesLen() int {
+	ts.mu.Lock()
+	defer ts.mu.Unlock()
+	return len(ts.dependencies)
+}
+
+func (ts *testCompositeService) RemovedLen() int {
+	ts.mu.Lock()
+	defer ts.mu.Unlock()
+	return len(ts.removed)
+}
+
+func TestCompositeServiceNormalStop(t *testing.T) {
+	child1 := newTestService("child1")
+	child2 := newTestService("child2")
+	child3 := newTestService("child3")
+	parent := newTestCompositeService("parent")
+	parent.AddServices(child1, child2)
+
+	t.Run("Starting", func(t *testing.T) {
+		wantState := serviceUninitialized
+		if child1.Status() != wantState {
+			t.Errorf("child1: current service status: got %d, want %d", child1.Status(), wantState)
+		}
+		if child2.Status() != wantState {
+			t.Errorf("child2: current service status: got %d, want %d", child2.Status(), wantState)
+		}
+
+		parent.Start()
+
+		child1.receiver.VerifyStatus(t, serviceStarting)
+		child2.receiver.VerifyStatus(t, serviceStarting)
+		parent.receiver.VerifyStatus(t, serviceStarting)
+
+		// child3 is added after Start() and should be automatically started.
+		if child3.Status() != wantState {
+			t.Errorf("child3: current service status: got %d, want %d", child3.Status(), wantState)
+		}
+		parent.AddServices(child3)
+		child3.receiver.VerifyStatus(t, serviceStarting)
+	})
+
+	t.Run("Active", func(t *testing.T) {
+		// parent service is active once all children are active.
+		child1.UpdateStatus(serviceActive, nil)
+		child2.UpdateStatus(serviceActive, nil)
+		parent.receiver.VerifyNoStatusChanges(t)
+		child3.UpdateStatus(serviceActive, nil)
+
+		child1.receiver.VerifyStatus(t, serviceActive)
+		child2.receiver.VerifyStatus(t, serviceActive)
+		child3.receiver.VerifyStatus(t, serviceActive)
+		parent.receiver.VerifyStatus(t, serviceActive)
+		if err := parent.WaitStarted(); err != nil {
+			t.Errorf("compositeService.WaitStarted() got err: %v", err)
+		}
+	})
+
+	t.Run("Stopping", func(t *testing.T) {
+		parent.Stop()
+
+		child1.receiver.VerifyStatus(t, serviceTerminating)
+		child2.receiver.VerifyStatus(t, serviceTerminating)
+		child3.receiver.VerifyStatus(t, serviceTerminating)
+		parent.receiver.VerifyStatus(t, serviceTerminating)
+
+		// parent service is terminated once all children have terminated.
+		child1.UpdateStatus(serviceTerminated, nil)
+		child2.UpdateStatus(serviceTerminated, nil)
+		parent.receiver.VerifyNoStatusChanges(t)
+		child3.UpdateStatus(serviceTerminated, nil)
+
+		child1.receiver.VerifyStatus(t, serviceTerminated)
+		child2.receiver.VerifyStatus(t, serviceTerminated)
+		child3.receiver.VerifyStatus(t, serviceTerminated)
+		parent.receiver.VerifyStatus(t, serviceTerminated)
+		if err := parent.WaitStopped(); err != nil {
+			t.Errorf("compositeService.WaitStopped() got err: %v", err)
+		}
+	})
+}
+
+func TestCompositeServiceErrorDuringStartup(t *testing.T) {
+	child1 := newTestService("child1")
+	child2 := newTestService("child2")
+	parent := newTestCompositeService("parent")
+	parent.AddServices(child1, child2)
+
+	t.Run("Starting", func(t *testing.T) {
+		parent.Start()
+
+		parent.receiver.VerifyStatus(t, serviceStarting)
+		child1.receiver.VerifyStatus(t, serviceStarting)
+		child2.receiver.VerifyStatus(t, serviceStarting)
+	})
+
+	t.Run("Terminating", func(t *testing.T) {
+		// child1 now errors.
+		wantErr := errors.New("err during startup")
+		child1.UpdateStatus(serviceTerminated, wantErr)
+		child1.receiver.VerifyStatus(t, serviceTerminated)
+
+		// This causes parent and child2 to start terminating.
+		parent.receiver.VerifyStatus(t, serviceTerminating)
+		child2.receiver.VerifyStatus(t, serviceTerminating)
+
+		// parent has terminated once child2 has terminated.
+		child2.UpdateStatus(serviceTerminated, nil)
+		child2.receiver.VerifyStatus(t, serviceTerminated)
+		parent.receiver.VerifyStatus(t, serviceTerminated)
+		if gotErr := parent.WaitStarted(); !test.ErrorEqual(gotErr, wantErr) {
+			t.Errorf("compositeService.WaitStarted() got err: (%v), want err: (%v)", gotErr, wantErr)
+		}
+	})
+}
+
+func TestCompositeServiceErrorWhileActive(t *testing.T) {
+	child1 := newTestService("child1")
+	child2 := newTestService("child2")
+	parent := newTestCompositeService("parent")
+	parent.AddServices(child1, child2)
+
+	t.Run("Starting", func(t *testing.T) {
+		parent.Start()
+
+		child1.receiver.VerifyStatus(t, serviceStarting)
+		child2.receiver.VerifyStatus(t, serviceStarting)
+		parent.receiver.VerifyStatus(t, serviceStarting)
+	})
+
+	t.Run("Active", func(t *testing.T) {
+		child1.UpdateStatus(serviceActive, nil)
+		child2.UpdateStatus(serviceActive, nil)
+
+		child1.receiver.VerifyStatus(t, serviceActive)
+		child2.receiver.VerifyStatus(t, serviceActive)
+		parent.receiver.VerifyStatus(t, serviceActive)
+		if err := parent.WaitStarted(); err != nil {
+			t.Errorf("compositeService.WaitStarted() got err: %v", err)
+		}
+	})
+
+	t.Run("Terminating", func(t *testing.T) {
+		// child2 now errors.
+		wantErr := errors.New("err while active")
+		child2.UpdateStatus(serviceTerminating, wantErr)
+		child2.receiver.VerifyStatus(t, serviceTerminating)
+
+		// This causes parent and child1 to start terminating.
+		child1.receiver.VerifyStatus(t, serviceTerminating)
+		parent.receiver.VerifyStatus(t, serviceTerminating)
+
+		// parent has terminated once both children have terminated.
+		child1.UpdateStatus(serviceTerminated, nil)
+		child2.UpdateStatus(serviceTerminated, nil)
+		child1.receiver.VerifyStatus(t, serviceTerminated)
+		child2.receiver.VerifyStatus(t, serviceTerminated)
+		parent.receiver.VerifyStatus(t, serviceTerminated)
+		if gotErr := parent.WaitStopped(); !test.ErrorEqual(gotErr, wantErr) {
+			t.Errorf("compositeService.WaitStopped() got err: (%v), want err: (%v)", gotErr, wantErr)
+		}
+	})
+}
+
+func TestCompositeServiceRemoveService(t *testing.T) {
+	child1 := newTestService("child1")
+	child2 := newTestService("child2")
+	parent := newTestCompositeService("parent")
+	parent.AddServices(child1, child2)
+
+	t.Run("Starting", func(t *testing.T) {
+		parent.Start()
+
+		child1.receiver.VerifyStatus(t, serviceStarting)
+		child2.receiver.VerifyStatus(t, serviceStarting)
+		parent.receiver.VerifyStatus(t, serviceStarting)
+	})
+
+	t.Run("Active", func(t *testing.T) {
+		child1.UpdateStatus(serviceActive, nil)
+		child2.UpdateStatus(serviceActive, nil)
+
+		child1.receiver.VerifyStatus(t, serviceActive)
+		child2.receiver.VerifyStatus(t, serviceActive)
+		parent.receiver.VerifyStatus(t, serviceActive)
+	})
+
+	t.Run("Remove service", func(t *testing.T) {
+		// Removing child1 should stop it, but leave everything else active.
+		parent.RemoveService(child1)
+
+		if got, want := parent.DependenciesLen(), 1; got != want {
+			t.Errorf("compositeService.dependencies: got len %d, want %d", got, want)
+		}
+		if got, want := parent.RemovedLen(), 1; got != want {
+			t.Errorf("compositeService.removed: got len %d, want %d", got, want)
+		}
+
+		child1.receiver.VerifyStatus(t, serviceTerminating)
+		child2.receiver.VerifyNoStatusChanges(t)
+		parent.receiver.VerifyNoStatusChanges(t)
+
+		// After child1 has terminated, it should be removed.
+		child1.UpdateStatus(serviceTerminated, nil)
+
+		child1.receiver.VerifyStatus(t, serviceTerminated)
+		child2.receiver.VerifyNoStatusChanges(t)
+		parent.receiver.VerifyNoStatusChanges(t)
+	})
+
+	t.Run("Terminating", func(t *testing.T) {
+		// Now stop the composite service.
+		parent.Stop()
+
+		child2.receiver.VerifyStatus(t, serviceTerminating)
+		parent.receiver.VerifyStatus(t, serviceTerminating)
+
+		child2.UpdateStatus(serviceTerminated, nil)
+
+		child2.receiver.VerifyStatus(t, serviceTerminated)
+		parent.receiver.VerifyStatus(t, serviceTerminated)
+		if err := parent.WaitStopped(); err != nil {
+			t.Errorf("compositeService.WaitStopped() got err: %v", err)
+		}
+
+		if got, want := parent.DependenciesLen(), 1; got != want {
+			t.Errorf("compositeService.dependencies: got len %d, want %d", got, want)
+		}
+		if got, want := parent.RemovedLen(), 0; got != want {
+			t.Errorf("compositeService.removed: got len %d, want %d", got, want)
+		}
+	})
+}
+
+func TestCompositeServiceTree(t *testing.T) {
+	leaf1 := newTestService("leaf1")
+	leaf2 := newTestService("leaf2")
+	intermediate1 := newTestCompositeService("intermediate1")
+	intermediate1.AddServices(leaf1, leaf2)
+
+	leaf3 := newTestService("leaf3")
+	leaf4 := newTestService("leaf4")
+	intermediate2 := newTestCompositeService("intermediate2")
+	intermediate2.AddServices(leaf3, leaf4)
+
+	root := newTestCompositeService("root")
+	root.AddServices(intermediate1, intermediate2)
+
+	wantErr := errors.New("fail")
+
+	t.Run("Starting", func(t *testing.T) {
+		// Start trickles down the tree.
+		root.Start()
+
+		leaf1.receiver.VerifyStatus(t, serviceStarting)
+		leaf2.receiver.VerifyStatus(t, serviceStarting)
+		leaf3.receiver.VerifyStatus(t, serviceStarting)
+		leaf4.receiver.VerifyStatus(t, serviceStarting)
+		intermediate1.receiver.VerifyStatus(t, serviceStarting)
+		intermediate2.receiver.VerifyStatus(t, serviceStarting)
+		root.receiver.VerifyStatus(t, serviceStarting)
+	})
+
+	t.Run("Active", func(t *testing.T) {
+		// serviceActive notification trickles up the tree.
+		leaf1.UpdateStatus(serviceActive, nil)
+		leaf2.UpdateStatus(serviceActive, nil)
+		leaf3.UpdateStatus(serviceActive, nil)
+		leaf4.UpdateStatus(serviceActive, nil)
+
+		leaf1.receiver.VerifyStatus(t, serviceActive)
+		leaf2.receiver.VerifyStatus(t, serviceActive)
+		leaf3.receiver.VerifyStatus(t, serviceActive)
+		leaf4.receiver.VerifyStatus(t, serviceActive)
+		intermediate1.receiver.VerifyStatus(t, serviceActive)
+		intermediate2.receiver.VerifyStatus(t, serviceActive)
+		root.receiver.VerifyStatus(t, serviceActive)
+		if err := root.WaitStarted(); err != nil {
+			t.Errorf("compositeService.WaitStarted() got err: %v", err)
+		}
+	})
+
+	t.Run("Leaf fails", func(t *testing.T) {
+		leaf1.UpdateStatus(serviceTerminated, wantErr)
+		leaf1.receiver.VerifyStatus(t, serviceTerminated)
+
+		// Leaf service failure should trickle up the tree and across to all other
+		// leaves, causing them all to start terminating.
+		leaf2.receiver.VerifyStatus(t, serviceTerminating)
+		leaf3.receiver.VerifyStatus(t, serviceTerminating)
+		leaf4.receiver.VerifyStatus(t, serviceTerminating)
+		intermediate1.receiver.VerifyStatus(t, serviceTerminating)
+		intermediate2.receiver.VerifyStatus(t, serviceTerminating)
+		root.receiver.VerifyStatus(t, serviceTerminating)
+	})
+
+	t.Run("Terminated", func(t *testing.T) {
+		// serviceTerminated notification trickles up the tree.
+		leaf2.UpdateStatus(serviceTerminated, nil)
+		leaf3.UpdateStatus(serviceTerminated, nil)
+		leaf4.UpdateStatus(serviceTerminated, nil)
+
+		leaf2.receiver.VerifyStatus(t, serviceTerminated)
+		leaf3.receiver.VerifyStatus(t, serviceTerminated)
+		leaf4.receiver.VerifyStatus(t, serviceTerminated)
+		intermediate1.receiver.VerifyStatus(t, serviceTerminated)
+		intermediate2.receiver.VerifyStatus(t, serviceTerminated)
+		root.receiver.VerifyStatus(t, serviceTerminated)
+
+		if gotErr := root.WaitStopped(); !test.ErrorEqual(gotErr, wantErr) {
+			t.Errorf("compositeService.WaitStopped() got err: (%v), want err: (%v)", gotErr, wantErr)
+		}
+	})
+}