feat(pubsublite): Abstraction for leaf and composite services (#3143)
Pub/Sub Lite publisher and subscriber clients will consist of a hierarchy of services.
diff --git a/pubsublite/internal/wire/errors.go b/pubsublite/internal/wire/errors.go
index 881624c..5dc5bc1 100644
--- a/pubsublite/internal/wire/errors.go
+++ b/pubsublite/internal/wire/errors.go
@@ -19,4 +19,17 @@
// ErrOverflow indicates that the publish buffers have overflowed. See
// comments for PublishSettings.BufferedByteLimit.
ErrOverflow = errors.New("pubsublite: client-side publish buffers have overflowed")
+
+ // ErrServiceUninitialized indicates that a service (e.g. publisher or
+ // subscriber) cannot perform an operation because it is uninitialized.
+ ErrServiceUninitialized = errors.New("pubsublite: service must be started")
+
+ // ErrServiceStarting indicates that a service (e.g. publisher or subscriber)
+ // cannot perform an operation because it is starting up.
+ ErrServiceStarting = errors.New("pubsublite: service is starting up")
+
+ // ErrServiceStopped indicates that a service (e.g. publisher or subscriber)
+ // cannot perform an operation because it has stoped or is in the process of
+ // stopping.
+ ErrServiceStopped = errors.New("pubsublite: service has stopped or is stopping")
)
diff --git a/pubsublite/internal/wire/service.go b/pubsublite/internal/wire/service.go
new file mode 100644
index 0000000..928d9a8
--- /dev/null
+++ b/pubsublite/internal/wire/service.go
@@ -0,0 +1,343 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+
+package wire
+
+import (
+ "sync"
+)
+
+// serviceStatus specifies the current status of the service. The order of the
+// values reflects the lifecycle of services. Note that some statuses may be
+// skipped.
+type serviceStatus int
+
+const (
+ // Service has not been started.
+ serviceUninitialized serviceStatus = 0
+ // Service is starting up.
+ serviceStarting serviceStatus = 1
+ // Service is active and accepting new data. Note that the underlying stream
+ // may be reconnecting due to retryable errors.
+ serviceActive serviceStatus = 2
+ // Service is gracefully shutting down by flushing all pending data. No new
+ // data is accepted.
+ serviceTerminating serviceStatus = 3
+ // Service has terminated. No new data is accepted.
+ serviceTerminated serviceStatus = 4
+)
+
+// serviceHandle is used to compare pointers to service instances.
+type serviceHandle interface{}
+
+// serviceStatusChangeFunc notifies the parent of service status changes.
+// `serviceTerminating` and `serviceTerminated` have an associated error. This
+// error may be nil if the user called Stop().
+type serviceStatusChangeFunc func(serviceHandle, serviceStatus, error)
+
+// service is the interface that must be implemented by services (essentially
+// gRPC client stream wrappers, e.g. subscriber, publisher) that can be
+// dependencies of a compositeService.
+type service interface {
+ Start()
+ Stop()
+
+ // Methods below are implemented by abstractService.
+ AddStatusChangeReceiver(serviceHandle, serviceStatusChangeFunc)
+ RemoveStatusChangeReceiver(serviceHandle)
+ Handle() serviceHandle
+ Error() error
+}
+
+// abstractService can be embedded into other structs to provide common
+// functionality for managing service status and status change receivers.
+type abstractService struct {
+ mu sync.Mutex
+ statusChangeReceivers []*statusChangeReceiver
+ status serviceStatus
+ // The error that cause the service to terminate.
+ err error
+}
+
+type statusChangeReceiver struct {
+ handle serviceHandle // For removing the receiver.
+ onStatusChange serviceStatusChangeFunc
+}
+
+func (as *abstractService) AddStatusChangeReceiver(handle serviceHandle, onStatusChange serviceStatusChangeFunc) {
+ as.mu.Lock()
+ defer as.mu.Unlock()
+ as.statusChangeReceivers = append(
+ as.statusChangeReceivers,
+ &statusChangeReceiver{handle, onStatusChange})
+}
+
+func (as *abstractService) RemoveStatusChangeReceiver(handle serviceHandle) {
+ as.mu.Lock()
+ defer as.mu.Unlock()
+
+ for i := len(as.statusChangeReceivers) - 1; i >= 0; i-- {
+ r := as.statusChangeReceivers[i]
+ if r.handle == handle {
+ // Swap with last element, erase last element and truncate the slice.
+ lastIdx := len(as.statusChangeReceivers) - 1
+ if i != lastIdx {
+ as.statusChangeReceivers[i] = as.statusChangeReceivers[lastIdx]
+ }
+ as.statusChangeReceivers[lastIdx] = nil
+ as.statusChangeReceivers = as.statusChangeReceivers[:lastIdx]
+ }
+ }
+}
+
+// Handle identifies this service instance, even when there are multiple layers
+// of embedding.
+func (as *abstractService) Handle() serviceHandle {
+ return as
+}
+
+func (as *abstractService) Error() error {
+ as.mu.Lock()
+ defer as.mu.Unlock()
+ return as.err
+}
+
+func (as *abstractService) Status() serviceStatus {
+ as.mu.Lock()
+ defer as.mu.Unlock()
+ return as.status
+}
+
+func (as *abstractService) unsafeCheckServiceStatus() error {
+ switch {
+ case as.status == serviceUninitialized:
+ return ErrServiceUninitialized
+ case as.status == serviceStarting:
+ return ErrServiceStarting
+ case as.status >= serviceTerminating:
+ return ErrServiceStopped
+ default:
+ return nil
+ }
+}
+
+// unsafeUpdateStatus assumes the service is already holding a mutex when
+// called, as it often needs to be atomic with other operations.
+func (as *abstractService) unsafeUpdateStatus(targetStatus serviceStatus, err error) bool {
+ if as.status >= targetStatus {
+ // Already at the same or later stage of the service lifecycle.
+ return false
+ }
+
+ as.status = targetStatus
+ if as.err == nil {
+ // Prevent clobbering original error.
+ as.err = err
+ }
+
+ for _, receiver := range as.statusChangeReceivers {
+ // Notify in a goroutine to prevent deadlocks if the receiver is holding a
+ // locked mutex.
+ go receiver.onStatusChange(as.Handle(), as.status, as.err)
+ }
+ return true
+}
+
+type serviceHolder struct {
+ service service
+ lastStatus serviceStatus
+}
+
+// compositeService can be embedded into other structs to manage child services.
+// It implements the service interface and can itself be a dependency of another
+// compositeService.
+//
+// If one child service terminates due to a permanent failure, all other child
+// services are stopped. Child services can be added and removed dynamically.
+type compositeService struct {
+ // Used to block until all dependencies have started or terminated.
+ waitStarted chan struct{}
+ waitTerminated chan struct{}
+
+ dependencies []*serviceHolder
+ removed []*serviceHolder
+
+ abstractService
+}
+
+// init must be called after creation of the derived struct.
+func (cs *compositeService) init() {
+ cs.waitStarted = make(chan struct{})
+ cs.waitTerminated = make(chan struct{})
+}
+
+// Start up dependencies.
+func (cs *compositeService) Start() {
+ cs.mu.Lock()
+ defer cs.mu.Unlock()
+
+ if cs.abstractService.unsafeUpdateStatus(serviceStarting, nil) {
+ for _, s := range cs.dependencies {
+ s.service.Start()
+ }
+ }
+}
+
+// WaitStarted waits for all dependencies to start.
+func (cs *compositeService) WaitStarted() error {
+ <-cs.waitStarted
+ return cs.Error()
+}
+
+// Stop all dependencies.
+func (cs *compositeService) Stop() {
+ cs.mu.Lock()
+ defer cs.mu.Unlock()
+ cs.unsafeInitiateShutdown(serviceTerminating, nil)
+}
+
+// WaitStopped waits for all dependencies to stop.
+func (cs *compositeService) WaitStopped() error {
+ <-cs.waitTerminated
+ return cs.Error()
+}
+
+func (cs *compositeService) unsafeAddServices(services ...service) error {
+ if cs.status >= serviceTerminating {
+ return ErrServiceStopped
+ }
+
+ for _, s := range services {
+ s.AddStatusChangeReceiver(cs.Handle(), cs.onServiceStatusChange)
+ cs.dependencies = append(cs.dependencies, &serviceHolder{service: s})
+ if cs.status > serviceUninitialized {
+ s.Start()
+ }
+ }
+ return nil
+}
+
+func (cs *compositeService) unsafeRemoveService(service service) {
+ removeIdx := -1
+ for i, s := range cs.dependencies {
+ if s.service.Handle() == service.Handle() {
+ // Move from the `dependencies` to the `removed` list.
+ cs.removed = append(cs.removed, s)
+ removeIdx = i
+ if s.lastStatus < serviceTerminating {
+ s.service.Stop()
+ }
+ break
+ }
+ }
+ cs.dependencies = removeFromSlice(cs.dependencies, removeIdx)
+}
+
+func (cs *compositeService) unsafeInitiateShutdown(targetStatus serviceStatus, err error) {
+ for _, s := range cs.dependencies {
+ if s.lastStatus < serviceTerminating {
+ s.service.Stop()
+ }
+ }
+ cs.unsafeUpdateStatus(targetStatus, err)
+}
+
+func (cs *compositeService) unsafeUpdateStatus(targetStatus serviceStatus, err error) (ret bool) {
+ previousStatus := cs.status
+ if ret = cs.abstractService.unsafeUpdateStatus(targetStatus, err); ret {
+ // Note: the waitStarted channel must be closed when the service fails to
+ // start.
+ if previousStatus == serviceStarting {
+ close(cs.waitStarted)
+ }
+ if targetStatus == serviceTerminated {
+ close(cs.waitTerminated)
+ }
+ }
+ return
+}
+
+func (cs *compositeService) onServiceStatusChange(handle serviceHandle, status serviceStatus, err error) {
+ cs.mu.Lock()
+ defer cs.mu.Unlock()
+
+ removeIdx := -1
+ for i, s := range cs.removed {
+ if s.service.Handle() == handle {
+ if status == serviceTerminated {
+ s.service.RemoveStatusChangeReceiver(cs.Handle())
+ removeIdx = i
+ }
+ break
+ }
+ }
+ if removeIdx >= 0 {
+ cs.removed = removeFromSlice(cs.removed, removeIdx)
+ }
+
+ // Note: we cannot rely on the service not being in the removed list above to
+ // determine whether it is an active dependency. The notification may be for a
+ // service that is no longer in cs.removed or cs.dependencies, because status
+ // changes are notified asynchronously and may be received out of order.
+ isDependency := false
+ for _, s := range cs.dependencies {
+ if s.service.Handle() == handle {
+ if status > s.lastStatus {
+ s.lastStatus = status
+ }
+ isDependency = true
+ break
+ }
+ }
+
+ // If a single service terminates, stop them all, but allow the others to
+ // flush pending data. Ignore removed services that are stopping.
+ shouldTerminate := status >= serviceTerminating && isDependency
+ numStarted := 0
+ numTerminated := 0
+
+ for _, s := range cs.dependencies {
+ if shouldTerminate && s.lastStatus < serviceTerminating {
+ s.service.Stop()
+ }
+ if s.lastStatus >= serviceActive {
+ numStarted++
+ }
+ if s.lastStatus == serviceTerminated {
+ numTerminated++
+ }
+ }
+
+ switch {
+ case numTerminated == len(cs.dependencies) && len(cs.removed) == 0:
+ cs.unsafeUpdateStatus(serviceTerminated, err)
+ case shouldTerminate:
+ cs.unsafeUpdateStatus(serviceTerminating, err)
+ case numStarted == len(cs.dependencies):
+ cs.unsafeUpdateStatus(serviceActive, err)
+ }
+}
+
+func removeFromSlice(services []*serviceHolder, removeIdx int) []*serviceHolder {
+ lastIdx := len(services) - 1
+ if removeIdx < 0 || removeIdx > lastIdx {
+ return services
+ }
+
+ // Swap with last element, erase last element and truncate the slice.
+ if removeIdx != lastIdx {
+ services[removeIdx] = services[lastIdx]
+ }
+ services[lastIdx] = nil
+ return services[:lastIdx]
+}
diff --git a/pubsublite/internal/wire/service_test.go b/pubsublite/internal/wire/service_test.go
new file mode 100644
index 0000000..9a5b22c
--- /dev/null
+++ b/pubsublite/internal/wire/service_test.go
@@ -0,0 +1,530 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+
+package wire
+
+import (
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+
+ "cloud.google.com/go/pubsublite/internal/test"
+)
+
+const receiveStatusTimeout = 5 * time.Second
+
+type testStatusChangeReceiver struct {
+ // Status change notifications are fired asynchronously, so a channel receives
+ // the statuses.
+ statusC chan serviceStatus
+ lastStatus serviceStatus
+ name string
+}
+
+func newTestStatusChangeReceiver(name string) *testStatusChangeReceiver {
+ return &testStatusChangeReceiver{
+ statusC: make(chan serviceStatus, 1),
+ name: name,
+ }
+}
+
+func (sr *testStatusChangeReceiver) Handle() interface{} { return sr }
+
+func (sr *testStatusChangeReceiver) OnStatusChange(handle serviceHandle, status serviceStatus, err error) {
+ sr.statusC <- status
+}
+
+func (sr *testStatusChangeReceiver) VerifyStatus(t *testing.T, want serviceStatus) {
+ select {
+ case status := <-sr.statusC:
+ if status <= sr.lastStatus {
+ t.Errorf("%s: Duplicate service status: %d, last status: %d", sr.name, status, sr.lastStatus)
+ }
+ if status != want {
+ t.Errorf("%s: Got service status: %d, want: %d", sr.name, status, want)
+ }
+ sr.lastStatus = status
+ case <-time.After(receiveStatusTimeout):
+ t.Errorf("%s: Did not receive status within %v", sr.name, receiveStatusTimeout)
+ }
+}
+
+func (sr *testStatusChangeReceiver) VerifyNoStatusChanges(t *testing.T) {
+ select {
+ case status := <-sr.statusC:
+ t.Errorf("%s: Unexpected service status: %d", sr.name, status)
+ default:
+ }
+}
+
+type testService struct {
+ receiver *testStatusChangeReceiver
+ abstractService
+}
+
+func newTestService(name string) *testService {
+ receiver := newTestStatusChangeReceiver(name)
+ ts := &testService{receiver: receiver}
+ ts.AddStatusChangeReceiver(receiver.Handle(), receiver.OnStatusChange)
+ return ts
+}
+
+func (ts *testService) Start() { ts.UpdateStatus(serviceStarting, nil) }
+func (ts *testService) Stop() { ts.UpdateStatus(serviceTerminating, nil) }
+
+func (ts *testService) UpdateStatus(targetStatus serviceStatus, err error) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.unsafeUpdateStatus(targetStatus, err)
+}
+
+func TestServiceUpdateStatusIsLinear(t *testing.T) {
+ err1 := errors.New("error1")
+ err2 := errors.New("error2")
+
+ service := newTestService("service")
+ service.UpdateStatus(serviceStarting, nil)
+ service.receiver.VerifyStatus(t, serviceStarting)
+
+ service.UpdateStatus(serviceActive, nil)
+ service.UpdateStatus(serviceActive, nil)
+ service.receiver.VerifyStatus(t, serviceActive)
+
+ service.UpdateStatus(serviceTerminating, err1)
+ service.UpdateStatus(serviceStarting, nil)
+ service.UpdateStatus(serviceTerminating, nil)
+ service.receiver.VerifyStatus(t, serviceTerminating)
+
+ service.UpdateStatus(serviceTerminated, err2)
+ service.UpdateStatus(serviceTerminated, nil)
+ service.receiver.VerifyStatus(t, serviceTerminated)
+
+ // Verify that the first error is not clobbered by the second.
+ if got, want := service.Error(), err1; !test.ErrorEqual(got, want) {
+ t.Errorf("service.Error(): got (%v), want (%v)", got, want)
+ }
+}
+
+func TestServiceCheckServiceStatus(t *testing.T) {
+ for _, tc := range []struct {
+ status serviceStatus
+ wantErr error
+ }{
+ {
+ status: serviceUninitialized,
+ wantErr: ErrServiceUninitialized,
+ },
+ {
+ status: serviceStarting,
+ wantErr: ErrServiceStarting,
+ },
+ {
+ status: serviceActive,
+ wantErr: nil,
+ },
+ {
+ status: serviceTerminating,
+ wantErr: ErrServiceStopped,
+ },
+ {
+ status: serviceTerminated,
+ wantErr: ErrServiceStopped,
+ },
+ } {
+ t.Run(fmt.Sprintf("Status=%v", tc.status), func(t *testing.T) {
+ s := newTestService("service")
+ s.UpdateStatus(tc.status, nil)
+ if gotErr := s.unsafeCheckServiceStatus(); !test.ErrorEqual(gotErr, tc.wantErr) {
+ t.Errorf("service.unsafeCheckServiceStatus(): got (%v), want (%v)", gotErr, tc.wantErr)
+ }
+ })
+ }
+}
+
+func TestServiceAddRemoveStatusChangeReceiver(t *testing.T) {
+ receiver1 := newTestStatusChangeReceiver("receiver1")
+ receiver2 := newTestStatusChangeReceiver("receiver2")
+ receiver3 := newTestStatusChangeReceiver("receiver3")
+
+ service := new(testService)
+ service.AddStatusChangeReceiver(receiver1.Handle(), receiver1.OnStatusChange)
+ service.AddStatusChangeReceiver(receiver2.Handle(), receiver2.OnStatusChange)
+ service.AddStatusChangeReceiver(receiver3.Handle(), receiver3.OnStatusChange)
+
+ t.Run("All receivers", func(t *testing.T) {
+ service.UpdateStatus(serviceActive, nil)
+
+ receiver1.VerifyStatus(t, serviceActive)
+ receiver2.VerifyStatus(t, serviceActive)
+ receiver3.VerifyStatus(t, serviceActive)
+ })
+
+ t.Run("receiver1 removed", func(t *testing.T) {
+ service.RemoveStatusChangeReceiver(receiver1.Handle())
+ service.UpdateStatus(serviceTerminating, nil)
+
+ receiver1.VerifyNoStatusChanges(t)
+ receiver2.VerifyStatus(t, serviceTerminating)
+ receiver3.VerifyStatus(t, serviceTerminating)
+ })
+
+ t.Run("receiver2 removed", func(t *testing.T) {
+ service.RemoveStatusChangeReceiver(receiver2.Handle())
+ service.UpdateStatus(serviceTerminated, nil)
+
+ receiver1.VerifyNoStatusChanges(t)
+ receiver2.VerifyNoStatusChanges(t)
+ receiver3.VerifyStatus(t, serviceTerminated)
+ })
+}
+
+type testCompositeService struct {
+ receiver *testStatusChangeReceiver
+ compositeService
+}
+
+func newTestCompositeService(name string) *testCompositeService {
+ receiver := newTestStatusChangeReceiver(name)
+ ts := &testCompositeService{receiver: receiver}
+ ts.AddStatusChangeReceiver(receiver.Handle(), receiver.OnStatusChange)
+ ts.init()
+ return ts
+}
+
+func (ts *testCompositeService) AddServices(services ...service) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.unsafeAddServices(services...)
+}
+
+func (ts *testCompositeService) RemoveService(service service) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.unsafeRemoveService(service)
+}
+
+func (ts *testCompositeService) DependenciesLen() int {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ return len(ts.dependencies)
+}
+
+func (ts *testCompositeService) RemovedLen() int {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ return len(ts.removed)
+}
+
+func TestCompositeServiceNormalStop(t *testing.T) {
+ child1 := newTestService("child1")
+ child2 := newTestService("child2")
+ child3 := newTestService("child3")
+ parent := newTestCompositeService("parent")
+ parent.AddServices(child1, child2)
+
+ t.Run("Starting", func(t *testing.T) {
+ wantState := serviceUninitialized
+ if child1.Status() != wantState {
+ t.Errorf("child1: current service status: got %d, want %d", child1.Status(), wantState)
+ }
+ if child2.Status() != wantState {
+ t.Errorf("child2: current service status: got %d, want %d", child2.Status(), wantState)
+ }
+
+ parent.Start()
+
+ child1.receiver.VerifyStatus(t, serviceStarting)
+ child2.receiver.VerifyStatus(t, serviceStarting)
+ parent.receiver.VerifyStatus(t, serviceStarting)
+
+ // child3 is added after Start() and should be automatically started.
+ if child3.Status() != wantState {
+ t.Errorf("child3: current service status: got %d, want %d", child3.Status(), wantState)
+ }
+ parent.AddServices(child3)
+ child3.receiver.VerifyStatus(t, serviceStarting)
+ })
+
+ t.Run("Active", func(t *testing.T) {
+ // parent service is active once all children are active.
+ child1.UpdateStatus(serviceActive, nil)
+ child2.UpdateStatus(serviceActive, nil)
+ parent.receiver.VerifyNoStatusChanges(t)
+ child3.UpdateStatus(serviceActive, nil)
+
+ child1.receiver.VerifyStatus(t, serviceActive)
+ child2.receiver.VerifyStatus(t, serviceActive)
+ child3.receiver.VerifyStatus(t, serviceActive)
+ parent.receiver.VerifyStatus(t, serviceActive)
+ if err := parent.WaitStarted(); err != nil {
+ t.Errorf("compositeService.WaitStarted() got err: %v", err)
+ }
+ })
+
+ t.Run("Stopping", func(t *testing.T) {
+ parent.Stop()
+
+ child1.receiver.VerifyStatus(t, serviceTerminating)
+ child2.receiver.VerifyStatus(t, serviceTerminating)
+ child3.receiver.VerifyStatus(t, serviceTerminating)
+ parent.receiver.VerifyStatus(t, serviceTerminating)
+
+ // parent service is terminated once all children have terminated.
+ child1.UpdateStatus(serviceTerminated, nil)
+ child2.UpdateStatus(serviceTerminated, nil)
+ parent.receiver.VerifyNoStatusChanges(t)
+ child3.UpdateStatus(serviceTerminated, nil)
+
+ child1.receiver.VerifyStatus(t, serviceTerminated)
+ child2.receiver.VerifyStatus(t, serviceTerminated)
+ child3.receiver.VerifyStatus(t, serviceTerminated)
+ parent.receiver.VerifyStatus(t, serviceTerminated)
+ if err := parent.WaitStopped(); err != nil {
+ t.Errorf("compositeService.WaitStopped() got err: %v", err)
+ }
+ })
+}
+
+func TestCompositeServiceErrorDuringStartup(t *testing.T) {
+ child1 := newTestService("child1")
+ child2 := newTestService("child2")
+ parent := newTestCompositeService("parent")
+ parent.AddServices(child1, child2)
+
+ t.Run("Starting", func(t *testing.T) {
+ parent.Start()
+
+ parent.receiver.VerifyStatus(t, serviceStarting)
+ child1.receiver.VerifyStatus(t, serviceStarting)
+ child2.receiver.VerifyStatus(t, serviceStarting)
+ })
+
+ t.Run("Terminating", func(t *testing.T) {
+ // child1 now errors.
+ wantErr := errors.New("err during startup")
+ child1.UpdateStatus(serviceTerminated, wantErr)
+ child1.receiver.VerifyStatus(t, serviceTerminated)
+
+ // This causes parent and child2 to start terminating.
+ parent.receiver.VerifyStatus(t, serviceTerminating)
+ child2.receiver.VerifyStatus(t, serviceTerminating)
+
+ // parent has terminated once child2 has terminated.
+ child2.UpdateStatus(serviceTerminated, nil)
+ child2.receiver.VerifyStatus(t, serviceTerminated)
+ parent.receiver.VerifyStatus(t, serviceTerminated)
+ if gotErr := parent.WaitStarted(); !test.ErrorEqual(gotErr, wantErr) {
+ t.Errorf("compositeService.WaitStarted() got err: (%v), want err: (%v)", gotErr, wantErr)
+ }
+ })
+}
+
+func TestCompositeServiceErrorWhileActive(t *testing.T) {
+ child1 := newTestService("child1")
+ child2 := newTestService("child2")
+ parent := newTestCompositeService("parent")
+ parent.AddServices(child1, child2)
+
+ t.Run("Starting", func(t *testing.T) {
+ parent.Start()
+
+ child1.receiver.VerifyStatus(t, serviceStarting)
+ child2.receiver.VerifyStatus(t, serviceStarting)
+ parent.receiver.VerifyStatus(t, serviceStarting)
+ })
+
+ t.Run("Active", func(t *testing.T) {
+ child1.UpdateStatus(serviceActive, nil)
+ child2.UpdateStatus(serviceActive, nil)
+
+ child1.receiver.VerifyStatus(t, serviceActive)
+ child2.receiver.VerifyStatus(t, serviceActive)
+ parent.receiver.VerifyStatus(t, serviceActive)
+ if err := parent.WaitStarted(); err != nil {
+ t.Errorf("compositeService.WaitStarted() got err: %v", err)
+ }
+ })
+
+ t.Run("Terminating", func(t *testing.T) {
+ // child2 now errors.
+ wantErr := errors.New("err while active")
+ child2.UpdateStatus(serviceTerminating, wantErr)
+ child2.receiver.VerifyStatus(t, serviceTerminating)
+
+ // This causes parent and child1 to start terminating.
+ child1.receiver.VerifyStatus(t, serviceTerminating)
+ parent.receiver.VerifyStatus(t, serviceTerminating)
+
+ // parent has terminated once both children have terminated.
+ child1.UpdateStatus(serviceTerminated, nil)
+ child2.UpdateStatus(serviceTerminated, nil)
+ child1.receiver.VerifyStatus(t, serviceTerminated)
+ child2.receiver.VerifyStatus(t, serviceTerminated)
+ parent.receiver.VerifyStatus(t, serviceTerminated)
+ if gotErr := parent.WaitStopped(); !test.ErrorEqual(gotErr, wantErr) {
+ t.Errorf("compositeService.WaitStopped() got err: (%v), want err: (%v)", gotErr, wantErr)
+ }
+ })
+}
+
+func TestCompositeServiceRemoveService(t *testing.T) {
+ child1 := newTestService("child1")
+ child2 := newTestService("child2")
+ parent := newTestCompositeService("parent")
+ parent.AddServices(child1, child2)
+
+ t.Run("Starting", func(t *testing.T) {
+ parent.Start()
+
+ child1.receiver.VerifyStatus(t, serviceStarting)
+ child2.receiver.VerifyStatus(t, serviceStarting)
+ parent.receiver.VerifyStatus(t, serviceStarting)
+ })
+
+ t.Run("Active", func(t *testing.T) {
+ child1.UpdateStatus(serviceActive, nil)
+ child2.UpdateStatus(serviceActive, nil)
+
+ child1.receiver.VerifyStatus(t, serviceActive)
+ child2.receiver.VerifyStatus(t, serviceActive)
+ parent.receiver.VerifyStatus(t, serviceActive)
+ })
+
+ t.Run("Remove service", func(t *testing.T) {
+ // Removing child1 should stop it, but leave everything else active.
+ parent.RemoveService(child1)
+
+ if got, want := parent.DependenciesLen(), 1; got != want {
+ t.Errorf("compositeService.dependencies: got len %d, want %d", got, want)
+ }
+ if got, want := parent.RemovedLen(), 1; got != want {
+ t.Errorf("compositeService.removed: got len %d, want %d", got, want)
+ }
+
+ child1.receiver.VerifyStatus(t, serviceTerminating)
+ child2.receiver.VerifyNoStatusChanges(t)
+ parent.receiver.VerifyNoStatusChanges(t)
+
+ // After child1 has terminated, it should be removed.
+ child1.UpdateStatus(serviceTerminated, nil)
+
+ child1.receiver.VerifyStatus(t, serviceTerminated)
+ child2.receiver.VerifyNoStatusChanges(t)
+ parent.receiver.VerifyNoStatusChanges(t)
+ })
+
+ t.Run("Terminating", func(t *testing.T) {
+ // Now stop the composite service.
+ parent.Stop()
+
+ child2.receiver.VerifyStatus(t, serviceTerminating)
+ parent.receiver.VerifyStatus(t, serviceTerminating)
+
+ child2.UpdateStatus(serviceTerminated, nil)
+
+ child2.receiver.VerifyStatus(t, serviceTerminated)
+ parent.receiver.VerifyStatus(t, serviceTerminated)
+ if err := parent.WaitStopped(); err != nil {
+ t.Errorf("compositeService.WaitStopped() got err: %v", err)
+ }
+
+ if got, want := parent.DependenciesLen(), 1; got != want {
+ t.Errorf("compositeService.dependencies: got len %d, want %d", got, want)
+ }
+ if got, want := parent.RemovedLen(), 0; got != want {
+ t.Errorf("compositeService.removed: got len %d, want %d", got, want)
+ }
+ })
+}
+
+func TestCompositeServiceTree(t *testing.T) {
+ leaf1 := newTestService("leaf1")
+ leaf2 := newTestService("leaf2")
+ intermediate1 := newTestCompositeService("intermediate1")
+ intermediate1.AddServices(leaf1, leaf2)
+
+ leaf3 := newTestService("leaf3")
+ leaf4 := newTestService("leaf4")
+ intermediate2 := newTestCompositeService("intermediate2")
+ intermediate2.AddServices(leaf3, leaf4)
+
+ root := newTestCompositeService("root")
+ root.AddServices(intermediate1, intermediate2)
+
+ wantErr := errors.New("fail")
+
+ t.Run("Starting", func(t *testing.T) {
+ // Start trickles down the tree.
+ root.Start()
+
+ leaf1.receiver.VerifyStatus(t, serviceStarting)
+ leaf2.receiver.VerifyStatus(t, serviceStarting)
+ leaf3.receiver.VerifyStatus(t, serviceStarting)
+ leaf4.receiver.VerifyStatus(t, serviceStarting)
+ intermediate1.receiver.VerifyStatus(t, serviceStarting)
+ intermediate2.receiver.VerifyStatus(t, serviceStarting)
+ root.receiver.VerifyStatus(t, serviceStarting)
+ })
+
+ t.Run("Active", func(t *testing.T) {
+ // serviceActive notification trickles up the tree.
+ leaf1.UpdateStatus(serviceActive, nil)
+ leaf2.UpdateStatus(serviceActive, nil)
+ leaf3.UpdateStatus(serviceActive, nil)
+ leaf4.UpdateStatus(serviceActive, nil)
+
+ leaf1.receiver.VerifyStatus(t, serviceActive)
+ leaf2.receiver.VerifyStatus(t, serviceActive)
+ leaf3.receiver.VerifyStatus(t, serviceActive)
+ leaf4.receiver.VerifyStatus(t, serviceActive)
+ intermediate1.receiver.VerifyStatus(t, serviceActive)
+ intermediate2.receiver.VerifyStatus(t, serviceActive)
+ root.receiver.VerifyStatus(t, serviceActive)
+ if err := root.WaitStarted(); err != nil {
+ t.Errorf("compositeService.WaitStarted() got err: %v", err)
+ }
+ })
+
+ t.Run("Leaf fails", func(t *testing.T) {
+ leaf1.UpdateStatus(serviceTerminated, wantErr)
+ leaf1.receiver.VerifyStatus(t, serviceTerminated)
+
+ // Leaf service failure should trickle up the tree and across to all other
+ // leaves, causing them all to start terminating.
+ leaf2.receiver.VerifyStatus(t, serviceTerminating)
+ leaf3.receiver.VerifyStatus(t, serviceTerminating)
+ leaf4.receiver.VerifyStatus(t, serviceTerminating)
+ intermediate1.receiver.VerifyStatus(t, serviceTerminating)
+ intermediate2.receiver.VerifyStatus(t, serviceTerminating)
+ root.receiver.VerifyStatus(t, serviceTerminating)
+ })
+
+ t.Run("Terminated", func(t *testing.T) {
+ // serviceTerminated notification trickles up the tree.
+ leaf2.UpdateStatus(serviceTerminated, nil)
+ leaf3.UpdateStatus(serviceTerminated, nil)
+ leaf4.UpdateStatus(serviceTerminated, nil)
+
+ leaf2.receiver.VerifyStatus(t, serviceTerminated)
+ leaf3.receiver.VerifyStatus(t, serviceTerminated)
+ leaf4.receiver.VerifyStatus(t, serviceTerminated)
+ intermediate1.receiver.VerifyStatus(t, serviceTerminated)
+ intermediate2.receiver.VerifyStatus(t, serviceTerminated)
+ root.receiver.VerifyStatus(t, serviceTerminated)
+
+ if gotErr := root.WaitStopped(); !test.ErrorEqual(gotErr, wantErr) {
+ t.Errorf("compositeService.WaitStopped() got err: (%v), want err: (%v)", gotErr, wantErr)
+ }
+ })
+}