Have the worker pool context produce a pool object (#154)

The motivation is:
- allowing worker managers to specify the level of concurrency for a
  worker. This is currently set to 10 by the data collector, but it is
  really specific to the worker setup. For the local case, where we max
  out the hardware threads with workers, 10 is overly generous; for a
  distributed case, a value approaching the worker machine's hardware
  thread count would be more appropriate.
- allowing distributed worker managers to update the set of available
  workers over time. Workers could get preempted, or new ones could
  become available. Having a way to periodically check and update what's
  available - although there is never a guarantee that a worker, once
  discovered, stays alive - helps avoid artificial starvation. A sketch
  of such a manager appears below.

The rest is renames that fall out of this refactoring.
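
For illustration only (not part of this change), here is a minimal sketch of a
distributed worker manager's pool under the new contract: get_currently_active()
re-queries a discovery callback on every call, so the scheduler can pick up
preempted or newly available workers, and get_worker_concurrency() reports
whatever value the manager chose instead of the previously hard-coded 10. The
RefreshingWorkerPool class and the discover_workers callback are hypothetical
names used only for this sketch.

    from typing import Any, Callable, List

    from compiler_opt.distributed.worker import WorkerPool


    class RefreshingWorkerPool(WorkerPool):
      """Illustrative pool whose worker set is re-discovered on each query."""

      def __init__(self, discover_workers: Callable[[], List[Any]],
                   worker_concurrency: int):
        # discover_workers is a manager-provided callback returning stubs for
        # the workers currently reachable; no liveness guarantee is implied.
        self._discover_workers = discover_workers
        self._worker_concurrency = worker_concurrency

      def get_currently_active(self) -> List[Any]:
        return self._discover_workers()

      def get_worker_concurrency(self) -> int:
        return self._worker_concurrency

LocalDataCollector only talks to the pool through these two methods, so it
would work unchanged with either FixedWorkerPool or a refreshing pool like the
one above.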
diff --git a/compiler_opt/distributed/local/local_worker_manager.py b/compiler_opt/distributed/local/local_worker_manager.py
index 425d83b..e6090f0 100644
--- a/compiler_opt/distributed/local/local_worker_manager.py
+++ b/compiler_opt/distributed/local/local_worker_manager.py
@@ -36,7 +36,7 @@
 
 from absl import logging
 # pylint: disable=unused-import
-from compiler_opt.distributed.worker import Worker
+from compiler_opt.distributed.worker import Worker, FixedWorkerPool
 
 from contextlib import AbstractContextManager
 from multiprocessing import connection
@@ -238,7 +238,7 @@
   return _Stub()
 
 
-class LocalWorkerPool(AbstractContextManager):
+class LocalWorkerPoolManager(AbstractContextManager):
   """A pool of workers hosted on the local machines, each in its own process."""
 
   def __init__(self, worker_class: 'type[Worker]', count: Optional[int], *args,
@@ -249,8 +249,8 @@
         _make_stub(worker_class, *args, **kwargs) for _ in range(count)
     ]
 
-  def __enter__(self):
-    return self._stubs
+  def __enter__(self) -> FixedWorkerPool:
+    return FixedWorkerPool(workers=self._stubs, worker_concurrency=10)
 
   def __exit__(self, *args):
     # first, trigger killing the worker process and exiting of the msg pump,
diff --git a/compiler_opt/distributed/local/local_worker_manager_test.py b/compiler_opt/distributed/local/local_worker_manager_test.py
index 75aa907..943b2cd 100644
--- a/compiler_opt/distributed/local/local_worker_manager_test.py
+++ b/compiler_opt/distributed/local/local_worker_manager_test.py
@@ -62,9 +62,9 @@
 
   def test_pool(self):
 
-    with local_worker_manager.LocalWorkerPool(JobNormal, 2) as pool:
-      p1 = pool[0]
-      p2 = pool[1]
+    with local_worker_manager.LocalWorkerPoolManager(JobNormal, 2) as pool:
+      p1 = pool.get_currently_active()[0]
+      p2 = pool.get_currently_active()[1]
       set_futures = [p1.set_token(1), p2.set_token(2)]
       done, not_done = concurrent.futures.wait(set_futures)
       self.assertLen(done, 2)
@@ -81,16 +81,16 @@
 
   def test_failure(self):
 
-    with local_worker_manager.LocalWorkerPool(JobFail, 2) as pool:
+    with local_worker_manager.LocalWorkerPoolManager(JobFail, 2) as pool:
       with self.assertRaises(concurrent.futures.CancelledError):
         # this will fail because we didn't pass the arg to the ctor, so the
         # worker hosting process will crash.
-        pool[0].method().result()
+        pool.get_currently_active()[0].method().result()
 
   def test_worker_crash_while_waiting(self):
 
-    with local_worker_manager.LocalWorkerPool(JobSlow, 2) as pool:
-      p = pool[0]
+    with local_worker_manager.LocalWorkerPoolManager(JobSlow, 2) as pool:
+      p = pool.get_currently_active()[0]
       f = p.method()
       self.assertFalse(f.done())
       try:
diff --git a/compiler_opt/distributed/worker.py b/compiler_opt/distributed/worker.py
index ff040d8..e1a93d1 100644
--- a/compiler_opt/distributed/worker.py
+++ b/compiler_opt/distributed/worker.py
@@ -14,7 +14,8 @@
 # limitations under the License.
 """Common abstraction for a worker contract."""
 
-from typing import Iterable, Optional, Protocol, TypeVar
+import abc
+from typing import Any, List, Iterable, Optional, Protocol, TypeVar
 
 
 class Worker(Protocol):
@@ -28,6 +29,34 @@
 T = TypeVar('T')
 
 
+class WorkerPool(metaclass=abc.ABCMeta):
+  """Abstraction of a pool of workers that may be refreshed."""
+
+  # Issue #155 would strongly-type the return type.
+  @abc.abstractmethod
+  def get_currently_active(self) -> List[Any]:
+    raise NotImplementedError()
+
+  @abc.abstractmethod
+  def get_worker_concurrency(self) -> int:
+    raise NotImplementedError()
+
+
+class FixedWorkerPool(WorkerPool):
+  """A WorkerPool built from a fixed list of workers."""
+
+  # Issue #155 would strongly-type `workers`
+  def __init__(self, workers: List[Any], worker_concurrency: int = 2):
+    self._workers = workers
+    self._worker_concurrency = worker_concurrency
+
+  def get_currently_active(self):
+    return self._workers
+
+  def get_worker_concurrency(self):
+    return self._worker_concurrency
+
+
 # Dask's Futures are limited. This captures that.
 class WorkerFuture(Protocol[T]):
 
diff --git a/compiler_opt/rl/local_data_collector.py b/compiler_opt/rl/local_data_collector.py
index 07fbcbd..20a72d2 100644
--- a/compiler_opt/rl/local_data_collector.py
+++ b/compiler_opt/rl/local_data_collector.py
@@ -37,7 +37,7 @@
       self,
       cps: corpus.Corpus,
       num_modules: int,
-      worker_pool: List[compilation_runner.CompilationRunnerStub],
+      worker_pool: worker.WorkerPool,
       parser: Callable[[List[str]], Iterator[trajectory.Trajectory]],
       reward_stat_map: Dict[str, Optional[Dict[str,
                                                compilation_runner.RewardStat]]],
@@ -49,6 +49,9 @@
     self._num_modules = num_modules
     self._parser = parser
     self._worker_pool = worker_pool
+    self._workers: List[
+        compilation_runner
+        .CompilationRunnerStub] = self._worker_pool.get_currently_active()
     self._reward_stat_map = reward_stat_map
     self._exit_checker_ctor = exit_checker_ctor
     # _reset_workers is a future that resolves when post-data collection cleanup
@@ -75,8 +78,11 @@
 
   def close_pool(self):
     self._join_pending_jobs()
-    for p in self._worker_pool:
+    # If the pool lost some workers, that's fine - we don't need to tell them
+    # anything anymore. For any new workers, the call is redundant (also fine).
+    for p in self._workers:
       p.cancel_all_work()
+    self._workers = None
     self._worker_pool = None
 
   def _join_pending_jobs(self):
@@ -110,7 +116,9 @@
       return work
 
     work = [work_factory(job) for job in jobs]
-    return buffered_scheduler.schedule(work, self._worker_pool, buffer=10)
+    self._workers = self._worker_pool.get_currently_active()
+    return buffered_scheduler.schedule(
+        work, self._workers, self._worker_pool.get_worker_concurrency())
 
   def collect_data(
       self, policy: policy_saver.Policy
@@ -158,13 +166,13 @@
 
     # signal whatever work is left to finish, and re-enable workers.
     def wrapup():
-      cancel_futures = [wkr.cancel_all_work() for wkr in self._worker_pool]
+      cancel_futures = [wkr.cancel_all_work() for wkr in self._workers]
       worker.wait_for(cancel_futures)
       # now that the workers killed pending compilations, make sure the workers
       # drained their working queues first - they should all complete quickly
       # since the cancellation manager is killing immediately any process starts
       worker.wait_for(self._current_futures)
-      worker.wait_for([wkr.enable() for wkr in self._worker_pool])
+      worker.wait_for([wkr.enable() for wkr in self._workers])
 
     self._reset_workers = self._pool.submit(wrapup)
 
diff --git a/compiler_opt/rl/local_data_collector_test.py b/compiler_opt/rl/local_data_collector_test.py
index 52ea46b..b48fb1f 100644
--- a/compiler_opt/rl/local_data_collector_test.py
+++ b/compiler_opt/rl/local_data_collector_test.py
@@ -23,7 +23,7 @@
 import tensorflow as tf
 from tf_agents.system import system_multiprocessing as multiprocessing
 
-from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPool
+from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import corpus
 from compiler_opt.rl import data_collector
@@ -142,7 +142,7 @@
       return _test_iterator_fn
 
     sampler = DeterministicSampler()
-    with LocalWorkerPool(worker_class=MyRunner, count=4) as lwp:
+    with LocalWorkerPoolManager(worker_class=MyRunner, count=4) as lwp:
       collector = local_data_collector.LocalDataCollector(
           cps=corpus.create_corpus_for_testing(
               location=self.create_tempdir(),
@@ -214,7 +214,7 @@
       def wait(self, _):
         return False
 
-    with LocalWorkerPool(worker_class=Sleeper, count=4) as lwp:
+    with LocalWorkerPoolManager(worker_class=Sleeper, count=4) as lwp:
       collector = local_data_collector.LocalDataCollector(
           cps=corpus.create_corpus_for_testing(
               location=self.create_tempdir(),
diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py
index 78143b1..ddf6dd9 100644
--- a/compiler_opt/rl/train_locally.py
+++ b/compiler_opt/rl/train_locally.py
@@ -29,7 +29,7 @@
 from tf_agents.system import system_multiprocessing as multiprocessing
 from typing import List
 
-from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPool
+from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
 from compiler_opt.rl import agent_creators
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import constant
@@ -59,7 +59,7 @@
 
 
 @gin.configurable
-def train_eval(worker_manager_class=LocalWorkerPool,
+def train_eval(worker_manager_class=LocalWorkerPoolManager,
                agent_name=constant.AgentName.PPO,
                warmstart_policy_dir=None,
                num_policy_iterations=0,