Make the worker manager type a parameter to `train_eval`
This allows `train_locally` to be reused with other worker managers. It is
the minimal refactoring needed; subsequent refactorings would turn this into
a library and also remove the `_local` suffix from this file and a few other
places, since they are no longer "local" in any sense.
diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py
index 6348cc5..78143b1 100644
--- a/compiler_opt/rl/train_locally.py
+++ b/compiler_opt/rl/train_locally.py
@@ -59,7 +59,8 @@
@gin.configurable
-def train_eval(agent_name=constant.AgentName.PPO,
+def train_eval(worker_manager_class=LocalWorkerPool,
+ agent_name=constant.AgentName.PPO,
warmstart_policy_dir=None,
num_policy_iterations=0,
num_modules=100,
@@ -133,7 +134,7 @@
logging.info('Loaded Reward Stat Map from disk, containing %d modules',
len(reward_stat_map))
- with LocalWorkerPool(
+ with worker_manager_class(
worker_class=problem_config.get_runner_type(),
count=FLAGS.num_workers,
moving_average_decay_rate=moving_average_decay_rate) as worker_pool: