Have the worker pool context produce a pool object (#154)
The motivation is:
- allowing worker managers to specify the level of concurrency for a
worker. This is set to '10' right now by the data collector, but it's
really specific to the worker setup. For the local case, where we max
out the hardware threads with workers, 10 is overly generous. For a
distributed case, a number approaching the hardware threads would be
more appropriate.
- allowing distributed worker managers to update the set of available
workers over time. Workers could get preempted, or new ones could become
available. Having a way to periodically check and update what's
available - although there is never a guarantee that a worker, once
discovered, stays alive - helps avoid artificial starvation.
The rest is renames that fall out of this refactoring.
diff --git a/compiler_opt/distributed/local/local_worker_manager.py b/compiler_opt/distributed/local/local_worker_manager.py
index 425d83b..e6090f0 100644
--- a/compiler_opt/distributed/local/local_worker_manager.py
+++ b/compiler_opt/distributed/local/local_worker_manager.py
@@ -36,7 +36,7 @@
from absl import logging
# pylint: disable=unused-import
-from compiler_opt.distributed.worker import Worker
+from compiler_opt.distributed.worker import Worker, FixedWorkerPool
from contextlib import AbstractContextManager
from multiprocessing import connection
@@ -238,7 +238,7 @@
return _Stub()
-class LocalWorkerPool(AbstractContextManager):
+class LocalWorkerPoolManager(AbstractContextManager):
"""A pool of workers hosted on the local machines, each in its own process."""
def __init__(self, worker_class: 'type[Worker]', count: Optional[int], *args,
@@ -249,8 +249,8 @@
_make_stub(worker_class, *args, **kwargs) for _ in range(count)
]
- def __enter__(self):
- return self._stubs
+ def __enter__(self) -> FixedWorkerPool:
+ return FixedWorkerPool(workers=self._stubs, worker_concurrency=10)
def __exit__(self, *args):
# first, trigger killing the worker process and exiting of the msg pump,
diff --git a/compiler_opt/distributed/local/local_worker_manager_test.py b/compiler_opt/distributed/local/local_worker_manager_test.py
index 75aa907..943b2cd 100644
--- a/compiler_opt/distributed/local/local_worker_manager_test.py
+++ b/compiler_opt/distributed/local/local_worker_manager_test.py
@@ -62,9 +62,9 @@
def test_pool(self):
- with local_worker_manager.LocalWorkerPool(JobNormal, 2) as pool:
- p1 = pool[0]
- p2 = pool[1]
+ with local_worker_manager.LocalWorkerPoolManager(JobNormal, 2) as pool:
+ p1 = pool.get_currently_active()[0]
+ p2 = pool.get_currently_active()[1]
set_futures = [p1.set_token(1), p2.set_token(2)]
done, not_done = concurrent.futures.wait(set_futures)
self.assertLen(done, 2)
@@ -81,16 +81,16 @@
def test_failure(self):
- with local_worker_manager.LocalWorkerPool(JobFail, 2) as pool:
+ with local_worker_manager.LocalWorkerPoolManager(JobFail, 2) as pool:
with self.assertRaises(concurrent.futures.CancelledError):
# this will fail because we didn't pass the arg to the ctor, so the
# worker hosting process will crash.
- pool[0].method().result()
+ pool.get_currently_active()[0].method().result()
def test_worker_crash_while_waiting(self):
- with local_worker_manager.LocalWorkerPool(JobSlow, 2) as pool:
- p = pool[0]
+ with local_worker_manager.LocalWorkerPoolManager(JobSlow, 2) as pool:
+ p = pool.get_currently_active()[0]
f = p.method()
self.assertFalse(f.done())
try:
diff --git a/compiler_opt/distributed/worker.py b/compiler_opt/distributed/worker.py
index ff040d8..e1a93d1 100644
--- a/compiler_opt/distributed/worker.py
+++ b/compiler_opt/distributed/worker.py
@@ -14,7 +14,8 @@
# limitations under the License.
"""Common abstraction for a worker contract."""
-from typing import Iterable, Optional, Protocol, TypeVar
+import abc
+from typing import Any, List, Iterable, Optional, Protocol, TypeVar
class Worker(Protocol):
@@ -28,6 +29,34 @@
T = TypeVar('T')
+class WorkerPool(metaclass=abc.ABCMeta):
+ """Abstraction of a pool of workers that may be refreshed."""
+
+ # Issue #155 would strongly-type the return type.
+ @abc.abstractmethod
+ def get_currently_active(self) -> List[Any]:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_worker_concurrency(self) -> int:
+ raise NotImplementedError()
+
+
+class FixedWorkerPool(WorkerPool):
+ """A WorkerPool built from a fixed list of workers."""
+
+ # Issue #155 would strongly-type `workers`
+ def __init__(self, workers: List[Any], worker_concurrency: int = 2):
+ self._workers = workers
+ self._worker_concurrency = worker_concurrency
+
+ def get_currently_active(self):
+ return self._workers
+
+ def get_worker_concurrency(self):
+ return self._worker_concurrency
+
+
# Dask's Futures are limited. This captures that.
class WorkerFuture(Protocol[T]):
diff --git a/compiler_opt/rl/local_data_collector.py b/compiler_opt/rl/local_data_collector.py
index 07fbcbd..20a72d2 100644
--- a/compiler_opt/rl/local_data_collector.py
+++ b/compiler_opt/rl/local_data_collector.py
@@ -37,7 +37,7 @@
self,
cps: corpus.Corpus,
num_modules: int,
- worker_pool: List[compilation_runner.CompilationRunnerStub],
+ worker_pool: worker.WorkerPool,
parser: Callable[[List[str]], Iterator[trajectory.Trajectory]],
reward_stat_map: Dict[str, Optional[Dict[str,
compilation_runner.RewardStat]]],
@@ -49,6 +49,9 @@
self._num_modules = num_modules
self._parser = parser
self._worker_pool = worker_pool
+ self._workers: List[
+ compilation_runner
+ .CompilationRunnerStub] = self._worker_pool.get_currently_active()
self._reward_stat_map = reward_stat_map
self._exit_checker_ctor = exit_checker_ctor
# _reset_workers is a future that resolves when post-data collection cleanup
@@ -75,8 +78,11 @@
def close_pool(self):
self._join_pending_jobs()
- for p in self._worker_pool:
+ # if the pool lost some workers, that's fine - we don't need to tell them
+ # anything anymore. To the new ones, the call is redundant (fine).
+ for p in self._workers:
p.cancel_all_work()
+ self._workers = None
self._worker_pool = None
def _join_pending_jobs(self):
@@ -110,7 +116,9 @@
return work
work = [work_factory(job) for job in jobs]
- return buffered_scheduler.schedule(work, self._worker_pool, buffer=10)
+ self._workers = self._worker_pool.get_currently_active()
+ return buffered_scheduler.schedule(
+ work, self._workers, self._worker_pool.get_worker_concurrency())
def collect_data(
self, policy: policy_saver.Policy
@@ -158,13 +166,13 @@
# signal whatever work is left to finish, and re-enable workers.
def wrapup():
- cancel_futures = [wkr.cancel_all_work() for wkr in self._worker_pool]
+ cancel_futures = [wkr.cancel_all_work() for wkr in self._workers]
worker.wait_for(cancel_futures)
# now that the workers killed pending compilations, make sure the workers
# drained their working queues first - they should all complete quickly
# since the cancellation manager is killing immediately any process starts
worker.wait_for(self._current_futures)
- worker.wait_for([wkr.enable() for wkr in self._worker_pool])
+ worker.wait_for([wkr.enable() for wkr in self._workers])
self._reset_workers = self._pool.submit(wrapup)
diff --git a/compiler_opt/rl/local_data_collector_test.py b/compiler_opt/rl/local_data_collector_test.py
index 52ea46b..b48fb1f 100644
--- a/compiler_opt/rl/local_data_collector_test.py
+++ b/compiler_opt/rl/local_data_collector_test.py
@@ -23,7 +23,7 @@
import tensorflow as tf
from tf_agents.system import system_multiprocessing as multiprocessing
-from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPool
+from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
from compiler_opt.rl import compilation_runner
from compiler_opt.rl import corpus
from compiler_opt.rl import data_collector
@@ -142,7 +142,7 @@
return _test_iterator_fn
sampler = DeterministicSampler()
- with LocalWorkerPool(worker_class=MyRunner, count=4) as lwp:
+ with LocalWorkerPoolManager(worker_class=MyRunner, count=4) as lwp:
collector = local_data_collector.LocalDataCollector(
cps=corpus.create_corpus_for_testing(
location=self.create_tempdir(),
@@ -214,7 +214,7 @@
def wait(self, _):
return False
- with LocalWorkerPool(worker_class=Sleeper, count=4) as lwp:
+ with LocalWorkerPoolManager(worker_class=Sleeper, count=4) as lwp:
collector = local_data_collector.LocalDataCollector(
cps=corpus.create_corpus_for_testing(
location=self.create_tempdir(),
diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py
index 78143b1..ddf6dd9 100644
--- a/compiler_opt/rl/train_locally.py
+++ b/compiler_opt/rl/train_locally.py
@@ -29,7 +29,7 @@
from tf_agents.system import system_multiprocessing as multiprocessing
from typing import List
-from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPool
+from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
from compiler_opt.rl import agent_creators
from compiler_opt.rl import compilation_runner
from compiler_opt.rl import constant
@@ -59,7 +59,7 @@
@gin.configurable
-def train_eval(worker_manager_class=LocalWorkerPool,
+def train_eval(worker_manager_class=LocalWorkerPoolManager,
agent_name=constant.AgentName.PPO,
warmstart_policy_dir=None,
num_policy_iterations=0,