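Make worker management pluggable, with Dask and multiprocessing backends.
The intention is to make the middleware replaceable, albeit under
assumptions that are strongly Dask-inspired. compiler_opt/__init__.py
currently selects the multiprocessing backend.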
diff --git a/compiler_opt/__init__.py b/compiler_opt/__init__.py
index 6078d4f..7c35352 100644
--- a/compiler_opt/__init__.py
+++ b/compiler_opt/__init__.py
@@ -15,6 +15,8 @@
"""Ensure flags are initialized for e.g. pytest harness case."""
import sys
+# Worker backend; compiler_opt.core.dask.worker_manager is the Dask alternative.
+import compiler_opt.core.multiprocessing.worker_manager as default_factory
from absl import flags
@@ -28,3 +30,5 @@
if not flags.FLAGS.is_parsed():
flags.FLAGS(sys.argv, known_only=True)
assert flags.FLAGS.is_parsed()
+
+worker_generator = default_factory.get_compilation_jobs  # -> (shutdown_fn, workers)
diff --git a/compiler_opt/core/abstract_worker.py b/compiler_opt/core/abstract_worker.py
new file mode 100644
index 0000000..a1ed868
--- /dev/null
+++ b/compiler_opt/core/abstract_worker.py
@@ -0,0 +1,29 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Common abstraction for a worker contract."""
+
+from abc import ABC, abstractmethod
+
+
+class AbstractWorker(ABC):
+  """Contract shared by the workers that worker managers hand out.
+
+  Inheriting from ABC makes @abstractmethod effective: a subclass that does
+  not implement cancel_all_work() cannot be instantiated.
+  """
+
+  @abstractmethod
+  def cancel_all_work(self):
+    """Cancels whatever work the worker is currently doing."""
diff --git a/compiler_opt/core/dask/worker_manager.py b/compiler_opt/core/dask/worker_manager.py
new file mode 100644
index 0000000..d4f5f47
--- /dev/null
+++ b/compiler_opt/core/dask/worker_manager.py
@@ -0,0 +1,126 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dask-based middleware implementation."""
+from absl import logging
+import gin
+import dask.config
+import dask.utils
+import multiprocessing
+import tempfile
+
+from compiler_opt.core.abstract_worker import AbstractWorker
+from concurrent.futures import ThreadPoolExecutor
+from dask.distributed import Client, Worker, LocalCluster
+from typing import Callable, Optional, Tuple
+
+
+class MTWorker(Worker):
+  """Multi-threaded worker.
+
+  Replaces the worker's 'actor' executor with a default-sized
+  ThreadPoolExecutor, presumably so that several calls on the actors hosted
+  by this worker can run concurrently.
+  """
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    previous = self.executors.get('actor')
+    if previous:
+      previous.shutdown()
+    self.executors['actor'] = ThreadPoolExecutor(
+        max_workers=None, thread_name_prefix='Dask-Actor-MT')
+
+
+class LocalManager:
+  """Local, dask-based worker manager.
+
+  Owns a temporary directory (configured as Dask's temporary-directory) and
+  a Client started with a single MTWorker process.
+  """
+
+  def __init__(self):
+    self._tmpdir = tempfile.TemporaryDirectory()
+    dask.config.set({
+        'temporary-directory': self._tmpdir.name,
+        # Non-daemonic workers may start child processes of their own.
+        'distributed.worker.daemon': False,
+        # Must be the fully-qualified key; a top-level 'work-stealing' entry
+        # is not read by the scheduler.
+        'distributed.scheduler.work-stealing': False
+    })
+
+    self._client = Client(
+        dashboard_address=None,
+        processes=True,
+        n_workers=1,
+        worker_class=MTWorker)
+    logging.info('Dask client: %s', self._client)
+
+  def shutdown(self):
+    """Closes the Dask client, then removes the temporary directory."""
+    self._client.close()
+    self._tmpdir.cleanup()
+
+  def get_client(self):
+    return self._client
+
+
+def get_compilation_jobs(ctor: Callable[[], AbstractWorker],
+                         count: Optional[int]) -> Tuple[Callable, list]:
+  """Creates `count` Dask actors, each constructed by `ctor`.
+
+  Args:
+    ctor: zero-argument callable constructing the worker object.
+    count: number of actors; None or 0 means multiprocessing.cpu_count().
+
+  Returns:
+    A (shutdown, workers) pair: shutdown() closes the Dask client and removes
+    its temporary directory; each worker wraps one actor handle.
+  """
+
+  class DaskStubWrapper(AbstractWorker):
+    """Adapts a Dask actor handle to the AbstractWorker contract."""
+
+    def __init__(self, stub):
+      self._stub = stub
+
+    def cancel_all_work(self):
+      # NOTE(review): `separate_thread` is consumed by Dask, not by the actor;
+      # False runs the call on the worker's event loop instead of the actor
+      # executor, so it is not queued behind busy calls. Confirm per version.
+      return self._stub.cancel_all_work(separate_thread=False)
+
+    def __getattr__(self, name):
+      # Everything else (e.g. collect_data, enable) is forwarded to the actor.
+      return self._stub.__getattr__(name)
+
+  if not count:
+    count = multiprocessing.cpu_count()
+  instance = LocalManager()
+  try:
+    futures = [
+        instance.get_client().submit(ctor, actor=True) for _ in range(count)
+    ]
+    workers = [DaskStubWrapper(f.result()) for f in futures]
+  except BaseException:
+    # Do not leak the cluster and temporary directory if any actor fails.
+    instance.shutdown()
+    raise
+  return instance.shutdown, workers
+
+
+# Keep the original name for existing callers; get_compilation_jobs is the
+# name compiler_opt/__init__.py requires of a worker-manager module.
+get_local_compilation_jobs = get_compilation_jobs
diff --git a/compiler_opt/core/multiprocessing/worker_manager.py b/compiler_opt/core/multiprocessing/worker_manager.py
new file mode 100644
index 0000000..46fd911
--- /dev/null
+++ b/compiler_opt/core/multiprocessing/worker_manager.py
@@ -0,0 +1,186 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Multiprocessing-based middleware implementation.
+
+Each worker object lives in its own process; the parent reaches it through a
+`Stub`, whose method calls are sent over a queue and return
+`concurrent.futures.Future`s.
+"""
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+import concurrent.futures
+import functools
+import multiprocessing
+import multiprocessing.connection
+import queue
+import threading
+
+from compiler_opt.core.abstract_worker import AbstractWorker
+from typing import Callable, Optional, Tuple
+
+# How long the reply pump blocks on its queue before re-checking for kill().
+_PUMP_POLL_SECONDS = 0.5
+
+
+def worker(ctor, in_q: queue.Queue, out_q: queue.Queue):
+  """Worker process main loop.
+
+  Builds the worker object with `ctor`, then serves requests from `in_q`
+  forever. A request is (msgid, method name, args, kwargs, urgent); its reply
+  (msgid, succeeded, result or exception) is put on `out_q`. Urgent requests
+  run inline on this loop; the others run on a thread pool so long-running
+  calls do not delay later requests.
+  """
+  pool = ThreadPoolExecutor()
+  obj = ctor()
+
+  def make_ondone(msgid):
+
+    def on_done(f: concurrent.futures.Future):
+      exc = f.exception()
+      if exc is not None:
+        out_q.put((msgid, False, exc))
+      else:
+        out_q.put((msgid, True, f.result()))
+
+    return on_done
+
+  while True:
+    msgid, fname, args, kwargs, urgent = in_q.get()
+    try:
+      the_func = getattr(obj, fname)
+    except AttributeError as e:
+      # Reply with the error instead of letting it end this loop, which would
+      # leave every outstanding call in the parent unresolved.
+      out_q.put((msgid, False, e))
+      continue
+    if urgent:
+      try:
+        res = the_func(*args, **kwargs)
+        out_q.put((msgid, True, res))
+      except Exception as e:  # pylint: disable=broad-except
+        out_q.put((msgid, False, e))
+    else:
+      pool.submit(the_func, *args,
+                  **kwargs).add_done_callback(make_ondone(msgid))
+
+
+class Stub:
+  """Parent-side proxy for a worker object running in its own process.
+
+  Looking up any non-dunder attribute returns a function; calling it sends the
+  request to the worker process and returns a concurrent.futures.Future for
+  the result. `cancel_all_work` calls are flagged urgent, so the worker runs
+  them inline rather than behind the work on its thread pool.
+  """
+
+  def __init__(self, ctor):
+    ctx = multiprocessing.get_context()
+    self._send = ctx.Queue()
+    self._receive = ctx.Queue()
+
+    self._process = ctx.Process(
+        target=functools.partial(worker, ctor, self._send, self._receive))
+    self._lock = threading.Lock()
+    self._map: dict[int, concurrent.futures.Future] = {}
+    # Daemon thread: it must never keep the interpreter alive on its own.
+    self._pump = threading.Thread(target=self._msg_pump, daemon=True)
+    self._done = threading.Event()
+    self._msgidlock = threading.Lock()
+    self._msgid = 0
+    self._process.start()
+    self._pump.start()
+
+  def _msg_pump(self):
+    """Resolves pending futures as replies arrive, until kill() is called."""
+    while not self._done.is_set():
+      try:
+        msgid, succeeded, value = self._receive.get(
+            timeout=_PUMP_POLL_SECONDS)
+      except queue.Empty:
+        continue
+      with self._lock:
+        future = self._map.pop(msgid, None)
+      if future is None:
+        # Already failed by kill().
+        continue
+      if succeeded:
+        future.set_result(value)
+      else:
+        future.set_exception(value)
+
+  def __getattr__(self, name):
+    """Returns a function that invokes method `name` on the worker object."""
+    if name.startswith('__') and name.endswith('__'):
+      # Protocol probes (copy, pickle, ...) must not turn into remote calls.
+      raise AttributeError(name)
+
+    def remote_call(*args, **kwargs):
+      # Allocate the id per call, not per attribute lookup, so a function
+      # object called more than once never reuses an id.
+      with self._msgidlock:
+        msgid = self._msgid
+        self._msgid += 1
+      future = concurrent.futures.Future()
+      # A sent request cannot be withdrawn: mark the future running, which
+      # also makes Future.cancel() a no-op for callers.
+      future.set_running_or_notify_cancel()
+      # Register before sending, since the reply may arrive before put()
+      # returns.
+      with self._lock:
+        self._map[msgid] = future
+      self._send.put(
+          (msgid, name, args, kwargs, name == 'cancel_all_work'))
+      return future
+
+    return remote_call
+
+  def kill(self):
+    """Kills the worker process and fails any calls still awaiting a reply."""
+    try:
+      self._process.kill()
+    except Exception:  # pylint: disable=broad-except
+      pass  # Best effort: the process may already be gone.
+    self._done.set()
+    with self._lock:
+      pending = list(self._map.values())
+      self._map.clear()
+    for future in pending:
+      future.set_exception(
+          RuntimeError('Worker process was killed before replying.'))
+
+
+def get_compilation_jobs(ctor: Callable[[], AbstractWorker],
+                         count: Optional[int]) -> Tuple[Callable, list]:
+  """Starts `count` worker processes, each hosting an object built by `ctor`.
+
+  Args:
+    ctor: zero-argument callable constructing the worker object; it must be
+      picklable if the multiprocessing start method is spawn or forkserver.
+    count: number of workers; None or 0 means multiprocessing.cpu_count().
+
+  Returns:
+    A (shutdown, stubs) pair; shutdown() kills the process behind each of the
+    returned stubs.
+  """
+  if not count:
+    count = multiprocessing.cpu_count()
+  stubs = [Stub(ctor) for _ in range(count)]
+
+  def shutdown():
+    for s in stubs:
+      s.kill()
+
+  # Hand out the same stubs that shutdown() kills (not a second, leaked set).
+  return shutdown, stubs
diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py
index 464d908..4722d1c 100644
--- a/compiler_opt/rl/compilation_runner.py
+++ b/compiler_opt/rl/compilation_runner.py
@@ -14,10 +14,10 @@
# limitations under the License.
"""Module for running compilation and collect training data."""
-import concurrent
+from abc import abstractclassmethod, abstractmethod
+import concurrent.futures
import dataclasses
import json
-import multiprocessing
import subprocess
import threading
from typing import Dict, List, Optional, Tuple
@@ -119,18 +119,6 @@
Exception.__init__(self)
-class ProcessCancellationToken:
-
- def __init__(self):
- self._event = multiprocessing.Manager().Event()
-
- def signal(self):
- self._event.set()
-
- def wait(self):
- self._event.wait()
-
-
def kill_process_ignore_exceptions(p: 'subprocess.Popen[bytes]'):
# kill the process and ignore exceptions. Exceptions would be thrown if the
# process has already been killed/finished (which is inherently in a race
@@ -157,6 +145,9 @@
self._done = False
self._lock = threading.Lock()
+ def enable(self):
+ self._done = False
+
def register_process(self, p: 'subprocess.Popen[bytes]'):
"""Register a process for potential cancellation."""
with self._lock:
@@ -165,12 +156,13 @@
return
kill_process_ignore_exceptions(p)
- def signal(self):
+ def kill_all_processes(self):
"""Cancel any pending work."""
with self._lock:
self._done = True
for p in self._processes:
kill_process_ignore_exceptions(p)
+ return len(self._processes)
def unregister_process(self, p: 'subprocess.Popen[bytes]'):
with self._lock:
@@ -262,22 +254,24 @@
assert not hasattr(self, 'sequence_examples')
+class CompilationRunnerStub:  # NOTE(review): no ABC base, so @abstractmethod is not enforced.
+ """Interface of a (possibly out-of-process) CompilationRunner; calls return futures."""
+
+ @abstractmethod
+ def collect_data(
+ self, file_paths: Tuple[str, ...], tf_policy_path: str,
+ reward_stat: Optional[Dict[str, RewardStat]]
+ ) -> concurrent.futures.Future[CompilationResult]:
+ ...  # The future yields the CompilationResult, or the error the worker raised.
+
+ @abstractmethod
+ def cancel_all_work(self) -> concurrent.futures.Future:
+ ...  # The future completes once the worker has handled the cancellation.
+
+
class CompilationRunner:
"""Base class for collecting compilation data."""
- _POOL: concurrent.futures.ThreadPoolExecutor = None
-
- @staticmethod
- def init_pool():
- """Worker process initialization."""
- CompilationRunner._POOL = concurrent.futures.ThreadPoolExecutor()
-
- @staticmethod
- def _get_pool():
- """Internal API for fetching the cancellation token waiting pool."""
- assert CompilationRunner._POOL
- return CompilationRunner._POOL
-
def __init__(self,
clang_path: Optional[str] = None,
launcher_path: Optional[str] = None,
@@ -293,40 +287,17 @@
self._launcher_path = launcher_path
self._moving_average_decay_rate = moving_average_decay_rate
self._compilation_timeout = _COMPILATION_TIMEOUT.value
+ self._cancellation_manager = WorkerCancellationManager()
- def _get_cancellation_manager(
- self, cancellation_token: Optional[ProcessCancellationToken]
- ) -> Optional[WorkerCancellationManager]:
- """Convert the ProcessCancellationToken into a WorkerCancellationManager.
+ def enable(self):
+ self._cancellation_manager.enable()
- The conversion also registers the ProcessCancellationToken wait() on a
- thread which will call the WorkerCancellationManager upon completion.
- Since the token is always signaled, the thread always completes its work.
-
- Args:
- cancellation_token: the ProcessCancellationToken to convert.
-
- Returns:
- a WorkerCancellationManager, if a ProcessCancellationToken was given.
- """
- if not cancellation_token:
- return None
- ret = WorkerCancellationManager()
-
- def signaler():
- cancellation_token.wait()
- ret.signal()
-
- CompilationRunner._get_pool().submit(signaler)
- return ret
+ def cancel_all_work(self):
+ return self._cancellation_manager.kill_all_processes()
def collect_data(
- self,
- file_paths: Tuple[str, ...],
- tf_policy_path: str,
- reward_stat: Optional[Dict[str, RewardStat]],
- cancellation_token: Optional[ProcessCancellationToken] = None
- ) -> CompilationResult:
+ self, file_paths: Tuple[str, ...], tf_policy_path: str,
+ reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
"""Collect data for the given IR file and policy.
Args:
@@ -346,14 +317,12 @@
compilation_runner.ProcessKilledException is passed through.
ValueError if example under default policy and ml policy does not match.
"""
- cancellation_manager = self._get_cancellation_manager(cancellation_token)
-
if reward_stat is None:
default_result = self._compile_fn(
file_paths,
tf_policy_path='',
reward_only=bool(tf_policy_path),
- cancellation_manager=cancellation_manager)
+ cancellation_manager=self._cancellation_manager)
reward_stat = {
k: RewardStat(v[1], v[1]) for (k, v) in default_result.items()
}
@@ -363,7 +332,7 @@
file_paths,
tf_policy_path,
reward_only=False,
- cancellation_manager=cancellation_manager)
+ cancellation_manager=self._cancellation_manager)
else:
policy_result = default_result
diff --git a/compiler_opt/rl/inlining/__init__.py b/compiler_opt/rl/inlining/__init__.py
index 6487ac8..bce5988 100644
--- a/compiler_opt/rl/inlining/__init__.py
+++ b/compiler_opt/rl/inlining/__init__.py
@@ -28,6 +28,9 @@
def get_runner(self, *args, **kwargs):
return inlining_runner.InliningRunner(*args, **kwargs)
+ def get_runner_ctor(self):
+ return inlining_runner.InliningRunner
+
def get_signature_spec(self):
return config.get_inlining_signature_spec()
diff --git a/compiler_opt/rl/local_data_collector.py b/compiler_opt/rl/local_data_collector.py
index 3b80bbc..fea34e4 100644
--- a/compiler_opt/rl/local_data_collector.py
+++ b/compiler_opt/rl/local_data_collector.py
@@ -14,19 +14,31 @@
# limitations under the License.
"""Module for collecting data locally."""
+import concurrent.futures
import itertools
import random
import time
-from typing import Callable, Dict, Iterator, List, Tuple, Optional
+from typing import Callable, Dict, Iterator, Iterable, List, Tuple, Optional
from absl import logging
-import multiprocessing # for Pool
+import gin
from tf_agents.trajectories import trajectory
+from compiler_opt import worker_generator
from compiler_opt.rl import compilation_runner
from compiler_opt.rl import data_collector
+def wait_for(futures: Iterable[concurrent.futures.Future]):
+  """Waits for every future to settle, ignoring errors (Dask: result() only)."""
+  for f in futures:
+    try:
+      f.result()
+    except Exception:  # pylint: disable=broad-except
+      pass  # Only completion matters to callers, not the outcome.
+
+
+@gin.configurable()
class LocalDataCollector(data_collector.DataCollector):
"""class for local data collection."""
@@ -35,71 +47,67 @@
file_paths: Tuple[Tuple[str, ...], ...],
num_workers: int,
num_modules: int,
- runner: compilation_runner.CompilationRunner,
parser: Callable[[List[str]], Iterator[trajectory.Trajectory]],
reward_stat_map: Dict[str, Optional[Dict[str,
compilation_runner.RewardStat]]],
+ worker_ctor: Callable[[], compilation_runner.CompilationRunnerStub],
exit_checker_ctor=data_collector.EarlyExitChecker):
# TODO(mtrofin): type exit_checker_ctor when we get typing.Protocol support
super().__init__()
self._file_paths = file_paths
self._num_modules = num_modules
- self._runner = runner
self._parser = parser
- self._pool = multiprocessing.get_context().Pool(
- num_workers, initializer=compilation_runner.CompilationRunner.init_pool)
+ self._shutdown, self._worker_pool = worker_generator(
+ worker_ctor, count=num_workers)
self._reward_stat_map = reward_stat_map
self._exit_checker_ctor = exit_checker_ctor
- self._pending_work = None
- # hold on to the token so it won't get GCed before all its wait()
- # complete
- self._last_token = None
+ self._pending_work: concurrent.futures.Future = None
+ self._current_work: list[Tuple[str, concurrent.futures.Future]] = []
+ self._pool = concurrent.futures.ThreadPoolExecutor()
+
+ def get_last_work(self):
+ return self._current_work
def close_pool(self):
- self._join_pending_jobs()
- if self._pool:
- # Stop accepting new work
- self._pool.close()
- self._pool.join()
- self._pool = None
+ self.join_pending_jobs()
+ # Best-effort cancellation (futures not awaited); join_pending_jobs() already waited.
+ for p in self._worker_pool:  # NOTE(review): None after a first close_pool().
+ p.cancel_all_work()
+ self._worker_pool = None
+ self._shutdown()  # Tears down the worker backend returned by worker_generator.
+ self._pool.shutdown()  # Waits for the wrap-up thread pool to drain.
- def _join_pending_jobs(self):
+ def join_pending_jobs(self):
+ t1 = time.time()
if self._pending_work:
- t1 = time.time()
- for w in self._pending_work:
- w.wait()
+ concurrent.futures.wait([self._pending_work])
- self._pending_work = None
- # this should have taken negligible time, normally, since all the work
- # has been cancelled and the workers had time to process the cancellation
- # while training was unfolding.
- logging.info('Waiting for pending work from last iteration took %f',
- time.time() - t1)
- self._last_token = None
+ self._pending_work = None
+ self._current_work = []
+ # this should have taken negligible time, normally, since all the work
+ # has been cancelled and the workers had time to process the cancellation
+ # while training was unfolding.
+ logging.info('Waiting for pending work from last iteration took %f',
+ time.time() - t1)
- def _schedule_jobs(self, policy_path, sampled_file_paths):
+ def _schedule_jobs(
+ self, policy_path, sampled_file_paths
+ ) -> List[concurrent.futures.Future[compilation_runner.CompilationResult]]:
# by now, all the pending work, which was signaled to cancel, must've
# finished
- self._join_pending_jobs()
- cancellation_token = compilation_runner.ProcessCancellationToken()
+ self.join_pending_jobs()
jobs = [(file_paths, policy_path,
- self._reward_stat_map['-'.join(file_paths)], cancellation_token)
+ self._reward_stat_map['-'.join(file_paths)])
for file_paths in sampled_file_paths]
- # Make sure we're not missing failures in workers. All but
- # ProcessKilledError, which we want to ignore.
- def error_callback(e):
- if isinstance(e, compilation_runner.ProcessKilledError):
- return
- logging.exception('Error in worker: %r', e)
-
- return (cancellation_token, [
- self._pool.apply_async(
- self._runner.collect_data, job, error_callback=error_callback)
- for job in jobs
- ])
+ # Naive load balancing
+ ret = []
+ for i in range(len(jobs)):
+ ret.append(self._worker_pool[i % len(self._worker_pool)].collect_data(
+ *(jobs[i])))
+ return ret
def collect_data(
self, policy_path: str
@@ -117,39 +125,50 @@
information is viewable in TensorBoard.
"""
sampled_file_paths = random.sample(self._file_paths, k=self._num_modules)
- ct, results = self._schedule_jobs(policy_path, sampled_file_paths)
+ results = self._schedule_jobs(policy_path, sampled_file_paths)
def wait_for_termination():
early_exit = self._exit_checker_ctor(num_modules=self._num_modules)
def get_num_finished_work():
- finished_work = sum(res.ready() for res in results)
+ finished_work = sum(res.done() for res in results)
return finished_work
return early_exit.wait(get_num_finished_work)
wait_seconds = wait_for_termination()
- # signal whatever work is left to finish
- ct.signal()
- current_work = zip(sampled_file_paths, results)
- finished_work = [(paths, res) for paths, res in current_work if res.ready()]
- successful_work = [
- (paths, res) for paths, res in finished_work if res.successful()
+ self._current_work = list(zip(sampled_file_paths, results))
+ finished_work = [
+ (paths, res) for paths, res in self._current_work if res.done()
]
- failures = len(finished_work) - len(successful_work)
+ def is_successful(f):
+ try:
+ _ = f.result()
+ return True
+ except: # pylint: disable=bare-except
+ return False
+
+ successful_work = [(paths, res.result())
+ for paths, res in finished_work
+ if is_successful(res)]
+ failures = len(finished_work) - len(successful_work)
logging.info(('%d of %d modules finished in %d seconds (%d failures).'),
len(finished_work), self._num_modules, wait_seconds, failures)
+ # signal whatever work is left to finish
+ def wrapup():
+ cancel_futures = [wkr.cancel_all_work() for wkr in self._worker_pool]
+ wait_for(cancel_futures)
+ wait_for(results)
+ wait_for([wkr.enable() for wkr in self._worker_pool])
+ self._pending_work = self._pool.submit(wrapup)
sequence_examples = list(
- itertools.chain.from_iterable([
- res.get().serialized_sequence_examples
- for (_, res) in successful_work
- ]))
- total_trajectory_length = sum(
- res.get().length for (_, res) in successful_work)
+ itertools.chain.from_iterable(
+ [res.serialized_sequence_examples for (_, res) in successful_work]))
+ total_trajectory_length = sum(res.length for (_, res) in successful_work)
self._reward_stat_map.update({
- '-'.join(file_paths): res.get().reward_stats
+ '-'.join(file_paths): res.reward_stats
for (file_paths, res) in successful_work
})
@@ -160,19 +179,13 @@
}
rewards = list(
itertools.chain.from_iterable(
- [res.get().rewards for (_, res) in successful_work]))
+ [res.rewards for (_, res) in successful_work]))
monitor_dict[
'reward_distribution'] = data_collector.build_distribution_monitor(
rewards)
parsed = self._parser(sequence_examples)
- self._pending_work = [res for res in results if not res.ready()]
- # if some of the cancelled work hasn't yet processed the signal, let's let
- # it do that while we process the data. We also need to hold on to the
- # current token, so its Event object doesn't get GC-ed here.
- if self._pending_work:
- self._last_token = ct
return parsed, monitor_dict
def on_dataset_consumed(self,
diff --git a/compiler_opt/rl/local_data_collector_test.py b/compiler_opt/rl/local_data_collector_test.py
index 51a35b4..84f0e2e 100644
--- a/compiler_opt/rl/local_data_collector_test.py
+++ b/compiler_opt/rl/local_data_collector_test.py
@@ -44,7 +44,7 @@
return text_format.Parse(sequence_example_text, tf.train.SequenceExample())
-def mock_collect_data(file_paths, tf_policy_dir, reward_stat, _):
+def mock_collect_data(file_paths, tf_policy_dir, reward_stat):
assert file_paths == ('a', 'b')
assert tf_policy_dir == 'policy'
assert reward_stat is None or reward_stat == {
@@ -77,30 +77,10 @@
class Sleeper(compilation_runner.CompilationRunner):
"""Test CompilationRunner that just sleeps."""
- # pylint: disable=super-init-not-called
- def __init__(self, killed, timedout, living):
- self._killed = killed
- self._timedout = timedout
- self._living = living
- self._lock = mp.Manager().Lock()
-
- def collect_data(self, file_paths, tf_policy_path, reward_stat,
- cancellation_token):
- _ = reward_stat
- cancellation_manager = self._get_cancellation_manager(cancellation_token)
- try:
- compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600,
- cancellation_manager)
- except compilation_runner.ProcessKilledError as e:
- with self._lock:
- self._killed.value += 1
- raise e
- except subprocess.TimeoutExpired as e:
- with self._lock:
- self._timedout.value += 1
- raise e
- with self._lock:
- self._living.value += 1
+ def collect_data(self, file_paths, tf_policy_path, reward_stat):
+ _ = file_paths, tf_policy_path, reward_stat
+ compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600,
+ self._cancellation_manager)
return compilation_runner.CompilationResult(
sequence_examples=[], reward_stats={}, rewards=[], keys=[])
@@ -108,10 +88,15 @@
class LocalDataCollectorTest(tf.test.TestCase):
def test_local_data_collector(self):
- mock_compilation_runner = mock.create_autospec(
- compilation_runner.CompilationRunner)
- mock_compilation_runner.collect_data = mock_collect_data
+ def make_runner():
+
+ class MyRunner(compilation_runner.CompilationRunner):
+
+ def collect_data(self, *args, **kwargs):
+ return mock_collect_data(*args, **kwargs)
+
+ return MyRunner()
def create_test_iterator_fn():
@@ -132,8 +117,8 @@
file_paths=tuple([('a', 'b')] * 100),
num_workers=4,
num_modules=9,
- runner=mock_compilation_runner,
parser=create_test_iterator_fn(),
+ worker_ctor=make_runner,
reward_stat_map=collections.defaultdict(lambda: None))
data_iterator, monitor_dict = collector.collect_data(policy_path='policy')
@@ -161,11 +146,6 @@
collector.close_pool()
def test_local_data_collector_task_management(self):
- killed = mp.Manager().Value('i', value=0)
- timedout = mp.Manager().Value('i', value=0)
- living = mp.Manager().Value('i', value=0)
-
- mock_compilation_runner = Sleeper(killed, timedout, living)
def parser(_):
pass
@@ -182,17 +162,18 @@
file_paths=tuple([('a', 'b')] * 200),
num_workers=4,
num_modules=4,
- runner=mock_compilation_runner,
+ worker_ctor=Sleeper,
parser=parser,
reward_stat_map=collections.defaultdict(lambda: None),
exit_checker_ctor=QuickExiter)
collector.collect_data(policy_path='policy')
- # close the pool first, so we are forced to wait for the workers to process
- # their cancellation.
+ collector.join_pending_jobs()
+ killed = 0
+ for _, w in collector.get_last_work():
+ self.assertRaises(compilation_runner.ProcessKilledError, w.result)
+ killed += 1
+ self.assertEquals(killed, 4)
collector.close_pool()
- self.assertEqual(killed.value, 4)
- self.assertEqual(living.value, 0)
- self.assertEqual(timedout.value, 0)
if __name__ == '__main__':
diff --git a/compiler_opt/rl/problem_configuration.py b/compiler_opt/rl/problem_configuration.py
index ff6d089..28f169d 100644
--- a/compiler_opt/rl/problem_configuration.py
+++ b/compiler_opt/rl/problem_configuration.py
@@ -94,7 +94,10 @@
@abc.abstractmethod
def get_runner(self, *args, **kwargs) -> compilation_runner.CompilationRunner:
raise NotImplementedError
-
+
+ @abc.abstractmethod
+ def get_runner_ctor(self):
+ raise NotImplementedError
def is_thinlto(module_paths: Iterable[str]) -> bool:
return tf.io.gfile.exists(next(iter(module_paths)) + '.thinlto.bc')
diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py
index 5a80a9b..f4576c8 100644
--- a/compiler_opt/rl/train_locally.py
+++ b/compiler_opt/rl/train_locally.py
@@ -25,7 +25,6 @@
from absl import logging
import gin
import tensorflow as tf
-from tf_agents.system import system_multiprocessing as multiprocessing
from typing import List
from compiler_opt.rl import agent_creators
@@ -111,9 +110,6 @@
file_paths = [(path + '.bc', path + '.cmd', path + '.thinlto.bc')
for path in module_paths]
- runner = problem_config.get_runner(
- moving_average_decay_rate=moving_average_decay_rate)
-
dataset_fn = data_reader.create_sequence_example_dataset_fn(
agent_name=agent_name,
time_step_spec=time_step_spec,
@@ -145,7 +141,9 @@
file_paths=file_paths,
num_workers=FLAGS.num_workers,
num_modules=FLAGS.num_modules,
- runner=runner,
+ worker_ctor=functools.partial(
+ problem_config.get_runner_ctor(),
+ moving_average_decay_rate=moving_average_decay_rate),
parser=sequence_example_iterator_fn,
reward_stat_map=reward_stat_map)
@@ -185,4 +183,4 @@
if __name__ == '__main__':
flags.mark_flag_as_required('data_path')
- multiprocessing.handle_main(functools.partial(app.run, main))
+ app.run(main)
diff --git a/requirements.txt b/requirements.txt
index 45a37e8..c61d9d6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,10 +3,14 @@
cachetools==4.2.2
certifi==2021.5.30
charset_normalizer==2.0.4
+click==8.1.3
cloudpickle==1.6.0
+dask[distributed]==2022.6.0
decorator==5.0.9
+distributed==2022.6.0
dm-tree==0.1.6
flatbuffers==2.0
+fsspec==2022.5.0
future==0.18.2
gast==0.4.0
gin==0.1.006
@@ -16,15 +20,21 @@
google-pasta==0.2.0
grpcio==1.39.0
gym==0.19.0
+HeapDict==1.0.1
h5py==3.6.0
idna==3.2
+Jinja2==3.1.2
keras==2.7.0
keras-preprocessing==1.1.2
libclang==12.0.0
+locket==1.0.0
markdown==3.3.4
+MarkupSafe==2.1.1
+msgpack==1.0.4
numpy==1.22.1
oauthlib==3.1.1
opt-einsum==3.3.0
+partd==1.2.0
pillow==8.3.1
protobuf==3.17.3
pyasn1==0.4.8
@@ -36,6 +46,8 @@
scipy==1.7.1
setuptools==57.4.0
six==1.16.0
+sortedcontainers==2.4.0
+tblib==1.7.0
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
@@ -45,10 +57,12 @@
tensorflow-probability==0.15.0
termcolor==1.1.0
tf-agents==0.11.0
+toolz==0.11.2
+tornado==6.1
typing-extensions==4.0.1
urllib3==1.26.6
werkzeug==2.0.1
wheel==0.37.0
wrapt==1.13.3
+zict==2.2.0
zipp==3.7.0
-