Use dask for worker management.

The intent is to make the middleware replaceable, but under assumptions
that are strongly Dask-inspired: workers are actor-style stubs whose
method calls return futures. Both a Dask-based and a multiprocessing-based
implementation are provided; the multiprocessing one is currently the
default.
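
Illustrative sketch of the contract (EchoWorker is a hypothetical
example; worker_generator is bound in compiler_opt/__init__.py below):

    from compiler_opt import worker_generator
    from compiler_opt.core.abstract_worker import AbstractWorker

    class EchoWorker(AbstractWorker):
      def echo(self, x):
        return x

      def cancel_all_work(self):
        pass

    # The factory returns a shutdown callable plus `count` stubs; stub
    # method calls return concurrent.futures.Future objects.
    shutdown, stubs = worker_generator(EchoWorker, count=2)
    futures = [stub.echo(i) for i, stub in enumerate(stubs)]
    print([f.result() for f in futures])
    shutdown()
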
diff --git a/compiler_opt/__init__.py b/compiler_opt/__init__.py
index 6078d4f..7c35352 100644
--- a/compiler_opt/__init__.py
+++ b/compiler_opt/__init__.py
@@ -15,6 +15,8 @@
 """Ensure flags are initialized for e.g. pytest harness case."""
 
 import sys
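+# To use the Dask-based middleware instead, swap in the commented import: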
+# import compiler_opt.core.dask.worker_manager as default_factory
+import compiler_opt.core.multiprocessing.worker_manager as default_factory
 
 from absl import flags
 
@@ -28,3 +30,5 @@
 if not flags.FLAGS.is_parsed():
   flags.FLAGS(sys.argv, known_only=True)
 assert flags.FLAGS.is_parsed()
+
+worker_generator = default_factory.get_compilation_jobs
diff --git a/compiler_opt/core/abstract_worker.py b/compiler_opt/core/abstract_worker.py
new file mode 100644
index 0000000..a1ed868
--- /dev/null
+++ b/compiler_opt/core/abstract_worker.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Common abstraction for a worker contract."""
+
+from abc import ABC, abstractmethod
+
+
+class AbstractWorker(ABC):
+  """Contract for workers managed by any of the middleware implementations."""
+
+  @abstractmethod
+  def cancel_all_work(self):
+    ...
diff --git a/compiler_opt/core/dask/worker_manager.py b/compiler_opt/core/dask/worker_manager.py
new file mode 100644
index 0000000..d4f5f47
--- /dev/null
+++ b/compiler_opt/core/dask/worker_manager.py
@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dask - based middleware implementation."""
+from absl import logging
+import dask.config
+import multiprocessing
+import tempfile
+
+from compiler_opt.core.abstract_worker import AbstractWorker
+from concurrent.futures import ThreadPoolExecutor
+from dask.distributed import Client, Worker
+from typing import Callable, Optional, Tuple
+
+
+class MTWorker(Worker):
+  """Multi-threaded worker."""
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
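+    # Dask gives actors a single-threaded executor by default; replace it
+    # with a multi-threaded one so actor calls can run concurrently.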
+    if self.executors['actor']:
+      self.executors['actor'].shutdown()
+    self.executors['actor'] = ThreadPoolExecutor(
+        max_workers=None, thread_name_prefix='Dask-Actor-MT')
+
+
+class LocalManager:
+  """Local, dask-based worker manager."""
+
+  def __init__(self):
+    self._tmpdir = tempfile.TemporaryDirectory()
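+    # Keep dask scratch space in a private temp dir; workers must not be
+    # daemonic, or they could not spawn compiler subprocesses.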
+    dask.config.set({
+        'temporary-directory': self._tmpdir.name,
+        'distributed.worker.daemon': False,
+        'distributed.scheduler.work-stealing': False
+    })
+
+    self._client = Client(
+        dashboard_address=None,
+        processes=True,
+        n_workers=1,
+        worker_class=MTWorker)
+    logging.info('%s', self._client)
+
+  def shutdown(self):
+    self._client.close()
+    self._tmpdir.cleanup()
+
+  def get_client(self):
+    return self._client
+
+
+def get_compilation_jobs(ctor: Callable[[], AbstractWorker],
+                         count: Optional[int]) -> Tuple[Callable, list]:
+  """Returns a (shutdown callable, list of worker stubs) pair."""
+
+  class DaskStubWrapper(AbstractWorker):
+
+    def __init__(self, stub: AbstractWorker):
+      self._stub = stub
+
+    def cancel_all_work(self):
+      return self._stub.cancel_all_work(separate_thread=False)
+
+    def __getattr__(self, name):
+      return getattr(self._stub, name)
+
+  if not count:
+    count = multiprocessing.cpu_count()
+  instance = LocalManager()
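+  # submit(ctor, actor=True) constructs the object on a dask worker and
+  # returns a handle whose method calls yield futures.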
+  workers = [
+      instance.get_client().submit(ctor, actor=True) for _ in range(count)
+  ]
+  return instance.shutdown, [
+      DaskStubWrapper(worker.result()) for worker in workers
+  ]
diff --git a/compiler_opt/core/multiprocessing/worker_manager.py b/compiler_opt/core/multiprocessing/worker_manager.py
new file mode 100644
index 0000000..46fd911
--- /dev/null
+++ b/compiler_opt/core/multiprocessing/worker_manager.py
@@ -0,0 +1,111 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dask - based middleware implementation."""
+from concurrent.futures import ThreadPoolExecutor
+import concurrent.futures
+import functools
+import multiprocessing
+import threading
+
+from compiler_opt.core.abstract_worker import AbstractWorker
+from typing import Callable, Dict, Optional, Tuple
+
+
+def worker(ctor, in_q, out_q):
+  """Instantiates ctor() in this process and services its method calls.
+
+  Args:
+    ctor: argless callable that builds the wrapped object.
+    in_q: multiprocessing queue of (msgid, method name, args, kwargs, urgent).
+    out_q: multiprocessing queue of (msgid, success, result or exception).
+  """
+  pool = ThreadPoolExecutor()
+  obj = ctor()
+
+  def make_ondone(msgid):
+
+    def on_done(f: concurrent.futures.Future):
+      if f.exception():
+        out_q.put((msgid, False, f.exception()))
+      else:
+        out_q.put((msgid, True, f.result()))
+    return on_done
+
+  while True:
+    msgid, fname, args, kwargs, urgent = in_q.get()
+    the_func = getattr(obj, fname)
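+    # Urgent calls (cancellation) run inline on this dispatch thread so they
+    # cannot queue behind long compilations in the thread pool.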
+    if urgent:
+      try:
+        res = the_func(*args, **kwargs)
+        out_q.put((msgid, True, res))
+      except Exception as e:
+        out_q.put((msgid, False, e))
+    else:
+      pool.submit(the_func, *args,
+                  **kwargs).add_done_callback(make_ondone(msgid))
+
+
+class Stub:
+  """Client-side proxy: attribute calls become messages to a worker process."""
+
+  def __init__(self, ctor):
+    self._send = multiprocessing.get_context().Queue()
+    self._receive = multiprocessing.get_context().Queue()
+
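+    # A dedicated process hosts the wrapped object; the pump thread below
+    # matches its replies to the futures handed out by remote calls.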
+    self._process = multiprocessing.Process(
+        target=functools.partial(worker, ctor, self._send, self._receive))
+    self._lock = threading.Lock()
+    self._map: Dict[int, concurrent.futures.Future] = {}
+    self._pump = threading.Thread(target=self._msg_pump, daemon=True)
+    self._done = threading.Event()
+    self._msgidlock = threading.Lock()
+    self._msgid = 0
+    self._process.start()
+    self._pump.start()
+
+  def _msg_pump(self):
+    while not self._done.is_set():
+      msgid, succeeded, value = self._receive.get()
+      with self._lock:
+        future = self._map[msgid]
+        del self._map[msgid]
+        if succeeded:
+          future.set_result(value)
+        else:
+          future.set_exception(value)
+
+  def __getattr__(self, name):
+
+    def remote_call(*args, **kwargs):
+      with self._msgidlock:
+        msgid = self._msgid
+        self._msgid += 1
+      future = concurrent.futures.Future()
+      # Register the future before sending, otherwise the pump thread could
+      # see the reply first and not find it in the map.
+      with self._lock:
+        self._map[msgid] = future
+      self._send.put((msgid, name, args, kwargs, name == 'cancel_all_work'))
+      return future
+
+    return remote_call
+
+  def kill(self):
+    # Set _done first so the (daemon) pump thread stops promptly.
+    self._done.set()
+    try:
+      self._process.kill()
+    except:  # pylint: disable=bare-except
+      pass
+
+
+def get_compilation_jobs(ctor: Callable[[], AbstractWorker],
+                         count: Optional[int]) -> Tuple[Callable, list]:
+  """Returns a (shutdown callable, list of worker stubs) pair."""
+  if not count:
+    count = multiprocessing.cpu_count()
+  stubs = [Stub(ctor) for _ in range(count)]
+
+  def shutdown():
+    for s in stubs:
+      s.kill()
+
+  return shutdown, stubs
diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py
index 464d908..4722d1c 100644
--- a/compiler_opt/rl/compilation_runner.py
+++ b/compiler_opt/rl/compilation_runner.py
@@ -14,10 +14,10 @@
 # limitations under the License.
 """Module for running compilation and collect training data."""
 
-import concurrent
+from abc import abstractmethod
+import concurrent.futures
 import dataclasses
 import json
-import multiprocessing
 import subprocess
 import threading
 from typing import Dict, List, Optional, Tuple
@@ -119,18 +119,6 @@
     Exception.__init__(self)
 
 
-class ProcessCancellationToken:
-
-  def __init__(self):
-    self._event = multiprocessing.Manager().Event()
-
-  def signal(self):
-    self._event.set()
-
-  def wait(self):
-    self._event.wait()
-
-
 def kill_process_ignore_exceptions(p: 'subprocess.Popen[bytes]'):
   # kill the process and ignore exceptions. Exceptions would be thrown if the
   # process has already been killed/finished (which is inherently in a race
@@ -157,6 +145,9 @@
     self._done = False
     self._lock = threading.Lock()
 
+  def enable(self):
+    """Re-arms the manager after a cancellation round."""
+    with self._lock:
+      self._done = False
+
   def register_process(self, p: 'subprocess.Popen[bytes]'):
     """Register a process for potential cancellation."""
     with self._lock:
@@ -165,12 +156,13 @@
         return
     kill_process_ignore_exceptions(p)
 
-  def signal(self):
+  def kill_all_processes(self):
     """Cancel any pending work."""
     with self._lock:
       self._done = True
     for p in self._processes:
       kill_process_ignore_exceptions(p)
+    return len(self._processes)
 
   def unregister_process(self, p: 'subprocess.Popen[bytes]'):
     with self._lock:
@@ -262,22 +254,24 @@
     assert not hasattr(self, 'sequence_examples')
 
 
+class CompilationRunnerStub:
+  """The interface of a stub to CompilationRunner."""
+
+  @abstractmethod
+  def collect_data(
+      self, file_paths: Tuple[str, ...], tf_policy_path: str,
+      reward_stat: Optional[Dict[str, RewardStat]]
+  ) -> concurrent.futures.Future[CompilationResult]:
+    ...
+
+  @abstractmethod
+  def cancel_all_work(self) -> concurrent.futures.Future:
+    ...
+
+  @abstractmethod
+  def enable(self) -> concurrent.futures.Future:
+    ...
+
+
 class CompilationRunner:
   """Base class for collecting compilation data."""
 
-  _POOL: concurrent.futures.ThreadPoolExecutor = None
-
-  @staticmethod
-  def init_pool():
-    """Worker process initialization."""
-    CompilationRunner._POOL = concurrent.futures.ThreadPoolExecutor()
-
-  @staticmethod
-  def _get_pool():
-    """Internal API for fetching the cancellation token waiting pool."""
-    assert CompilationRunner._POOL
-    return CompilationRunner._POOL
-
   def __init__(self,
                clang_path: Optional[str] = None,
                launcher_path: Optional[str] = None,
@@ -293,40 +287,17 @@
     self._launcher_path = launcher_path
     self._moving_average_decay_rate = moving_average_decay_rate
     self._compilation_timeout = _COMPILATION_TIMEOUT.value
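+    # One manager per runner lifetime: cancel_all_work() trips it, enable()
+    # re-arms it for the next collection iteration.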
+    self._cancellation_manager = WorkerCancellationManager()
 
-  def _get_cancellation_manager(
-      self, cancellation_token: Optional[ProcessCancellationToken]
-  ) -> Optional[WorkerCancellationManager]:
-    """Convert the ProcessCancellationToken into a WorkerCancellationManager.
+  def enable(self):
+    self._cancellation_manager.enable()
 
-    The conversion also registers the ProcessCancellationToken wait() on a
-    thread which will call the WorkerCancellationManager upon completion.
-    Since the token is always signaled, the thread always completes its work.
-
-    Args:
-      cancellation_token: the ProcessCancellationToken to convert.
-
-    Returns:
-      a WorkerCancellationManager, if a ProcessCancellationToken was given.
-    """
-    if not cancellation_token:
-      return None
-    ret = WorkerCancellationManager()
-
-    def signaler():
-      cancellation_token.wait()
-      ret.signal()
-
-    CompilationRunner._get_pool().submit(signaler)
-    return ret
+  def cancel_all_work(self):
+    return self._cancellation_manager.kill_all_processes()
 
   def collect_data(
-      self,
-      file_paths: Tuple[str, ...],
-      tf_policy_path: str,
-      reward_stat: Optional[Dict[str, RewardStat]],
-      cancellation_token: Optional[ProcessCancellationToken] = None
-  ) -> CompilationResult:
+      self, file_paths: Tuple[str, ...], tf_policy_path: str,
+      reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
     """Collect data for the given IR file and policy.
 
     Args:
@@ -346,14 +317,12 @@
       compilation_runner.ProcessKilledException is passed through.
       ValueError if example under default policy and ml policy does not match.
     """
-    cancellation_manager = self._get_cancellation_manager(cancellation_token)
-
     if reward_stat is None:
       default_result = self._compile_fn(
           file_paths,
           tf_policy_path='',
           reward_only=bool(tf_policy_path),
-          cancellation_manager=cancellation_manager)
+          cancellation_manager=self._cancellation_manager)
       reward_stat = {
           k: RewardStat(v[1], v[1]) for (k, v) in default_result.items()
       }
@@ -363,7 +332,7 @@
           file_paths,
           tf_policy_path,
           reward_only=False,
-          cancellation_manager=cancellation_manager)
+          cancellation_manager=self._cancellation_manager)
     else:
       policy_result = default_result
 
diff --git a/compiler_opt/rl/inlining/__init__.py b/compiler_opt/rl/inlining/__init__.py
index 6487ac8..bce5988 100644
--- a/compiler_opt/rl/inlining/__init__.py
+++ b/compiler_opt/rl/inlining/__init__.py
@@ -28,6 +28,9 @@
   def get_runner(self, *args, **kwargs):
     return inlining_runner.InliningRunner(*args, **kwargs)
 
+  def get_runner_ctor(self):
+    return inlining_runner.InliningRunner
+
   def get_signature_spec(self):
     return config.get_inlining_signature_spec()
 
diff --git a/compiler_opt/rl/local_data_collector.py b/compiler_opt/rl/local_data_collector.py
index 3b80bbc..fea34e4 100644
--- a/compiler_opt/rl/local_data_collector.py
+++ b/compiler_opt/rl/local_data_collector.py
@@ -14,19 +14,31 @@
 # limitations under the License.
 """Module for collecting data locally."""
 
+import concurrent.futures
 import itertools
 import random
 import time
-from typing import Callable, Dict, Iterator, List, Tuple, Optional
+from typing import Callable, Dict, Iterator, Iterable, List, Tuple, Optional
 
 from absl import logging
-import multiprocessing  # for Pool
+import gin
 from tf_agents.trajectories import trajectory
 
+from compiler_opt import worker_generator
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import data_collector
 
 
+def wait_for(futures: Iterable[concurrent.futures.Future]):
+  """Waits for all futures, swallowing exceptions.
+
+  Dask futures support little beyond result(), so result() doubles as wait().
+  """
+  for f in futures:
+    try:
+      _ = f.result()
+    except:  # pylint: disable=bare-except
+      pass
+
+
+@gin.configurable()
 class LocalDataCollector(data_collector.DataCollector):
   """class for local data collection."""
 
@@ -35,71 +47,67 @@
       file_paths: Tuple[Tuple[str, ...], ...],
       num_workers: int,
       num_modules: int,
-      runner: compilation_runner.CompilationRunner,
       parser: Callable[[List[str]], Iterator[trajectory.Trajectory]],
       reward_stat_map: Dict[str, Optional[Dict[str,
                                                compilation_runner.RewardStat]]],
+      worker_ctor: Callable[[], compilation_runner.CompilationRunnerStub],
       exit_checker_ctor=data_collector.EarlyExitChecker):
     # TODO(mtrofin): type exit_checker_ctor when we get typing.Protocol support
     super().__init__()
 
     self._file_paths = file_paths
     self._num_modules = num_modules
-    self._runner = runner
     self._parser = parser
-    self._pool = multiprocessing.get_context().Pool(
-        num_workers, initializer=compilation_runner.CompilationRunner.init_pool)
+    self._shutdown, self._worker_pool = worker_generator(
+        worker_ctor, count=num_workers)
     self._reward_stat_map = reward_stat_map
 
     self._exit_checker_ctor = exit_checker_ctor
-    self._pending_work = None
-    # hold on to the token so it won't get GCed before all its wait()
-    # complete
-    self._last_token = None
+    self._pending_work: Optional[concurrent.futures.Future] = None
+    self._current_work: List[Tuple[Tuple[str, ...],
+                                   concurrent.futures.Future]] = []
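+    # Thread pool that runs the post-iteration wrapup (cancel, drain,
+    # re-enable) so collect_data can return promptly.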
+    self._pool = concurrent.futures.ThreadPoolExecutor()
+
+  def get_last_work(self):
+    return self._current_work
 
   def close_pool(self):
-    self._join_pending_jobs()
-    if self._pool:
-      # Stop accepting new work
-      self._pool.close()
-      self._pool.join()
-      self._pool = None
+    self.join_pending_jobs()
+    # Cancel anything the workers may still be running before tearing down.
+    wait_for([p.cancel_all_work() for p in self._worker_pool])
+    self._worker_pool = None
+    self._shutdown()
+    self._pool.shutdown()
 
-  def _join_pending_jobs(self):
+  def join_pending_jobs(self):
+    t1 = time.time()
     if self._pending_work:
-      t1 = time.time()
-      for w in self._pending_work:
-        w.wait()
+      concurrent.futures.wait([self._pending_work])
 
-      self._pending_work = None
-      # this should have taken negligible time, normally, since all the work
-      # has been cancelled and the workers had time to process the cancellation
-      # while training was unfolding.
-      logging.info('Waiting for pending work from last iteration took %f',
-                   time.time() - t1)
-    self._last_token = None
+    self._pending_work = None
+    self._current_work = []
+    # this should have taken negligible time, normally, since all the work
+    # has been cancelled and the workers had time to process the cancellation
+    # while training was unfolding.
+    logging.info('Waiting for pending work from last iteration took %f',
+                 time.time() - t1)
 
-  def _schedule_jobs(self, policy_path, sampled_file_paths):
+  def _schedule_jobs(
+      self, policy_path, sampled_file_paths
+  ) -> List[concurrent.futures.Future[compilation_runner.CompilationResult]]:
     # by now, all the pending work, which was signaled to cancel, must've
     # finished
-    self._join_pending_jobs()
-    cancellation_token = compilation_runner.ProcessCancellationToken()
+    self.join_pending_jobs()
     jobs = [(file_paths, policy_path,
-             self._reward_stat_map['-'.join(file_paths)], cancellation_token)
+             self._reward_stat_map['-'.join(file_paths)])
             for file_paths in sampled_file_paths]
 
-    # Make sure we're not missing failures in workers. All but
-    # ProcessKilledError, which we want to ignore.
-    def error_callback(e):
-      if isinstance(e, compilation_runner.ProcessKilledError):
-        return
-      logging.exception('Error in worker: %r', e)
-
-    return (cancellation_token, [
-        self._pool.apply_async(
-            self._runner.collect_data, job, error_callback=error_callback)
-        for job in jobs
-    ])
+    # Naive round-robin load balancing across the worker pool.
+    return [
+        self._worker_pool[i % len(self._worker_pool)].collect_data(*job)
+        for i, job in enumerate(jobs)
+    ]
 
   def collect_data(
       self, policy_path: str
@@ -117,39 +125,50 @@
       information is viewable in TensorBoard.
     """
     sampled_file_paths = random.sample(self._file_paths, k=self._num_modules)
-    ct, results = self._schedule_jobs(policy_path, sampled_file_paths)
+    results = self._schedule_jobs(policy_path, sampled_file_paths)
 
     def wait_for_termination():
       early_exit = self._exit_checker_ctor(num_modules=self._num_modules)
 
       def get_num_finished_work():
-        finished_work = sum(res.ready() for res in results)
+        finished_work = sum(res.done() for res in results)
         return finished_work
 
       return early_exit.wait(get_num_finished_work)
 
     wait_seconds = wait_for_termination()
-    # signal whatever work is left to finish
-    ct.signal()
-    current_work = zip(sampled_file_paths, results)
-    finished_work = [(paths, res) for paths, res in current_work if res.ready()]
-    successful_work = [
-        (paths, res) for paths, res in finished_work if res.successful()
+    self._current_work = list(zip(sampled_file_paths, results))
+    finished_work = [
+        (paths, res) for paths, res in self._current_work if res.done()
     ]
-    failures = len(finished_work) - len(successful_work)
 
+    def is_successful(f):
+      try:
+        _ = f.result()
+        return True
+      except:  # pylint: disable=bare-except
+        return False
+
+    successful_work = [(paths, res.result())
+                       for paths, res in finished_work
+                       if is_successful(res)]
+    failures = len(finished_work) - len(successful_work)
     logging.info(('%d of %d modules finished in %d seconds (%d failures).'),
                  len(finished_work), self._num_modules, wait_seconds, failures)
+
+    # Signal whatever work is left to finish, then re-enable the workers for
+    # the next iteration; run it on a background thread so data processing
+    # can proceed meanwhile.
+    def wrapup():
+      cancel_futures = [wkr.cancel_all_work() for wkr in self._worker_pool]
+      wait_for(cancel_futures)
+      wait_for(results)
+      wait_for([wkr.enable() for wkr in self._worker_pool])
+
+    self._pending_work = self._pool.submit(wrapup)
 
     sequence_examples = list(
-        itertools.chain.from_iterable([
-            res.get().serialized_sequence_examples
-            for (_, res) in successful_work
-        ]))
-    total_trajectory_length = sum(
-        res.get().length for (_, res) in successful_work)
+        itertools.chain.from_iterable(
+            [res.serialized_sequence_examples for (_, res) in successful_work]))
+    total_trajectory_length = sum(res.length for (_, res) in successful_work)
     self._reward_stat_map.update({
-        '-'.join(file_paths): res.get().reward_stats
+        '-'.join(file_paths): res.reward_stats
         for (file_paths, res) in successful_work
     })
 
@@ -160,19 +179,13 @@
     }
     rewards = list(
         itertools.chain.from_iterable(
-            [res.get().rewards for (_, res) in successful_work]))
+            [res.rewards for (_, res) in successful_work]))
     monitor_dict[
         'reward_distribution'] = data_collector.build_distribution_monitor(
             rewards)
 
     parsed = self._parser(sequence_examples)
 
-    self._pending_work = [res for res in results if not res.ready()]
-    # if some of the cancelled work hasn't yet processed the signal, let's let
-    # it do that while we process the data. We also need to hold on to the
-    # current token, so its Event object doesn't get GC-ed here.
-    if self._pending_work:
-      self._last_token = ct
     return parsed, monitor_dict
 
   def on_dataset_consumed(self,
diff --git a/compiler_opt/rl/local_data_collector_test.py b/compiler_opt/rl/local_data_collector_test.py
index 51a35b4..84f0e2e 100644
--- a/compiler_opt/rl/local_data_collector_test.py
+++ b/compiler_opt/rl/local_data_collector_test.py
@@ -44,7 +44,7 @@
   return text_format.Parse(sequence_example_text, tf.train.SequenceExample())
 
 
-def mock_collect_data(file_paths, tf_policy_dir, reward_stat, _):
+def mock_collect_data(file_paths, tf_policy_dir, reward_stat):
   assert file_paths == ('a', 'b')
   assert tf_policy_dir == 'policy'
   assert reward_stat is None or reward_stat == {
@@ -77,30 +77,10 @@
 class Sleeper(compilation_runner.CompilationRunner):
   """Test CompilationRunner that just sleeps."""
 
-  # pylint: disable=super-init-not-called
-  def __init__(self, killed, timedout, living):
-    self._killed = killed
-    self._timedout = timedout
-    self._living = living
-    self._lock = mp.Manager().Lock()
-
-  def collect_data(self, file_paths, tf_policy_path, reward_stat,
-                   cancellation_token):
-    _ = reward_stat
-    cancellation_manager = self._get_cancellation_manager(cancellation_token)
-    try:
-      compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600,
-                                                   cancellation_manager)
-    except compilation_runner.ProcessKilledError as e:
-      with self._lock:
-        self._killed.value += 1
-      raise e
-    except subprocess.TimeoutExpired as e:
-      with self._lock:
-        self._timedout.value += 1
-      raise e
-    with self._lock:
-      self._living.value += 1
+  def collect_data(self, file_paths, tf_policy_path, reward_stat):
+    _ = file_paths, tf_policy_path, reward_stat
+    compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600,
+                                                 self._cancellation_manager)
     return compilation_runner.CompilationResult(
         sequence_examples=[], reward_stats={}, rewards=[], keys=[])
 
@@ -108,10 +88,15 @@
 class LocalDataCollectorTest(tf.test.TestCase):
 
   def test_local_data_collector(self):
-    mock_compilation_runner = mock.create_autospec(
-        compilation_runner.CompilationRunner)
 
-    mock_compilation_runner.collect_data = mock_collect_data
+    def make_runner():
+
+      class MyRunner(compilation_runner.CompilationRunner):
+
+        def collect_data(self, *args, **kwargs):
+          return mock_collect_data(*args, **kwargs)
+
+      return MyRunner()
 
     def create_test_iterator_fn():
 
@@ -132,8 +117,8 @@
         file_paths=tuple([('a', 'b')] * 100),
         num_workers=4,
         num_modules=9,
-        runner=mock_compilation_runner,
         parser=create_test_iterator_fn(),
+        worker_ctor=make_runner,
         reward_stat_map=collections.defaultdict(lambda: None))
 
     data_iterator, monitor_dict = collector.collect_data(policy_path='policy')
@@ -161,11 +146,6 @@
     collector.close_pool()
 
   def test_local_data_collector_task_management(self):
-    killed = mp.Manager().Value('i', value=0)
-    timedout = mp.Manager().Value('i', value=0)
-    living = mp.Manager().Value('i', value=0)
-
-    mock_compilation_runner = Sleeper(killed, timedout, living)
 
     def parser(_):
       pass
@@ -182,17 +162,18 @@
         file_paths=tuple([('a', 'b')] * 200),
         num_workers=4,
         num_modules=4,
-        runner=mock_compilation_runner,
+        worker_ctor=Sleeper,
         parser=parser,
         reward_stat_map=collections.defaultdict(lambda: None),
         exit_checker_ctor=QuickExiter)
     collector.collect_data(policy_path='policy')
-    # close the pool first, so we are forced to wait for the workers to process
-    # their cancellation.
+    collector.join_pending_jobs()
+    killed = 0
+    for _, w in collector.get_last_work():
+      self.assertRaises(compilation_runner.ProcessKilledError, w.result)
+      killed += 1
+    self.assertEqual(killed, 4)
     collector.close_pool()
-    self.assertEqual(killed.value, 4)
-    self.assertEqual(living.value, 0)
-    self.assertEqual(timedout.value, 0)
 
 
 if __name__ == '__main__':
diff --git a/compiler_opt/rl/problem_configuration.py b/compiler_opt/rl/problem_configuration.py
index ff6d089..28f169d 100644
--- a/compiler_opt/rl/problem_configuration.py
+++ b/compiler_opt/rl/problem_configuration.py
@@ -94,7 +94,10 @@
   @abc.abstractmethod
   def get_runner(self, *args, **kwargs) -> compilation_runner.CompilationRunner:
     raise NotImplementedError
-
+
+  @abc.abstractmethod
+  def get_runner_ctor(self):
+    raise NotImplementedError
+
 
 def is_thinlto(module_paths: Iterable[str]) -> bool:
   return tf.io.gfile.exists(next(iter(module_paths)) + '.thinlto.bc')
diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py
index 5a80a9b..f4576c8 100644
--- a/compiler_opt/rl/train_locally.py
+++ b/compiler_opt/rl/train_locally.py
@@ -25,7 +25,6 @@
 from absl import logging
 import gin
 import tensorflow as tf
-from tf_agents.system import system_multiprocessing as multiprocessing
 from typing import List
 
 from compiler_opt.rl import agent_creators
@@ -111,9 +110,6 @@
       file_paths = [(path + '.bc', path + '.cmd', path + '.thinlto.bc')
                     for path in module_paths]
 
-  runner = problem_config.get_runner(
-      moving_average_decay_rate=moving_average_decay_rate)
-
   dataset_fn = data_reader.create_sequence_example_dataset_fn(
       agent_name=agent_name,
       time_step_spec=time_step_spec,
@@ -145,7 +141,9 @@
       file_paths=file_paths,
       num_workers=FLAGS.num_workers,
       num_modules=FLAGS.num_modules,
-      runner=runner,
+      worker_ctor=functools.partial(
+          problem_config.get_runner_ctor(),
+          moving_average_decay_rate=moving_average_decay_rate),
       parser=sequence_example_iterator_fn,
       reward_stat_map=reward_stat_map)
 
@@ -185,4 +183,4 @@
 
 if __name__ == '__main__':
   flags.mark_flag_as_required('data_path')
-  multiprocessing.handle_main(functools.partial(app.run, main))
+  app.run(main)
diff --git a/requirements.txt b/requirements.txt
index 45a37e8..c61d9d6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,10 +3,14 @@
 cachetools==4.2.2
 certifi==2021.5.30
 charset_normalizer==2.0.4
+click==8.1.3
 cloudpickle==1.6.0
+dask[distributed]==2022.6.0
 decorator==5.0.9
+distributed==2022.6.0
 dm-tree==0.1.6
 flatbuffers==2.0
+fsspec==2022.5.0
 future==0.18.2
 gast==0.4.0
 gin==0.1.006
@@ -16,15 +20,21 @@
 google-pasta==0.2.0
 grpcio==1.39.0
 gym==0.19.0
+HeapDict==1.0.1
 h5py==3.6.0
 idna==3.2
+Jinja2==3.1.2
 keras==2.7.0
 keras-preprocessing==1.1.2
 libclang==12.0.0
+locket==1.0.0
 markdown==3.3.4
+MarkupSafe==2.1.1
+msgpack==1.0.4
 numpy==1.22.1
 oauthlib==3.1.1
 opt-einsum==3.3.0
+partd==1.2.0
 pillow==8.3.1
 protobuf==3.17.3
 pyasn1==0.4.8
@@ -36,6 +46,8 @@
 scipy==1.7.1
 setuptools==57.4.0
 six==1.16.0
+sortedcontainers==2.4.0
+tblib==1.7.0
 tensorboard==2.7.0
 tensorboard-data-server==0.6.1
 tensorboard-plugin-wit==1.8.0
@@ -45,10 +57,12 @@
 tensorflow-probability==0.15.0
 termcolor==1.1.0
 tf-agents==0.11.0
+toolz==0.11.2
+tornado==6.1
 typing-extensions==4.0.1
 urllib3==1.26.6
 werkzeug==2.0.1
 wheel==0.37.0
 wrapt==1.13.3
+zict==2.2.0
 zipp==3.7.0
-