| # coding=utf-8 |
| # Copyright 2020 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Module for collecting data locally.""" |
| |
| import concurrent.futures |
| import itertools |
| import random |
| import time |
| from typing import Callable, Dict, Iterator, Iterable, List, Tuple, Optional |
| |
| from absl import logging |
| import gin |
| from tf_agents.trajectories import trajectory |
| |
| from compiler_opt import worker_generator |
| from compiler_opt.rl import compilation_runner |
| from compiler_opt.rl import data_collector |
| |
| |
def wait_for(futures: Iterable[concurrent.futures.Future]):
  """Blocks until every future in `futures` has completed.

  Only `result()` is used to wait, because (per the original note) Dask
  futures don't support more than `result()`.

  Failures of the individual futures, including cancellation, are
  deliberately ignored: callers use this only to make sure the work is no
  longer running, not to inspect its outcome.

  Args:
    futures: the futures to wait on; may be empty.

  Raises:
    KeyboardInterrupt, SystemExit: propagated, so that a user interrupt or an
      interpreter shutdown request is never silently swallowed while blocking.
  """
  for f in futures:
    try:
      f.result()
    except (KeyboardInterrupt, SystemExit):
      # Never hide a request to stop the process.
      raise
    except BaseException:  # pylint: disable=broad-except
      # Best effort: any other failure (including cancellation errors that
      # may derive from BaseException) just means the future is finished.
      pass
| |
| |
@gin.configurable()
class LocalDataCollector(data_collector.DataCollector):
  """Data collector that runs compilations on a local pool of workers.

  Each `collect_data` call samples `num_modules` entries from `file_paths`,
  assigns them round-robin to the workers, waits (through the exit checker)
  for results, and feeds the successful ones to `parser`. Work that is still
  outstanding at that point is cancelled on a background thread and joined
  before the next batch is scheduled.
  """

  def __init__(
      self,
      file_paths: Tuple[Tuple[str, ...], ...],
      num_workers: int,
      num_modules: int,
      parser: Callable[[List[str]], Iterator[trajectory.Trajectory]],
      reward_stat_map: Dict[str, Optional[Dict[str,
                                               compilation_runner.RewardStat]]],
      worker_ctor: Callable[[], compilation_runner.CompilationRunnerStub],
      exit_checker_ctor=data_collector.EarlyExitChecker):
    """Creates the collector and its worker pool.

    Args:
      file_paths: the modules to sample from; each entry is a tuple of paths.
      num_workers: passed as `count` to `worker_generator`.
      num_modules: how many entries of `file_paths` are sampled per
        `collect_data` call.
      parser: turns the collected serialized sequence examples into an
        iterator of trajectories.
      reward_stat_map: per-module reward statistics, keyed by
        `'-'.join(file_paths)`; updated in place after every collection.
      worker_ctor: passed to `worker_generator` to build the workers.
      exit_checker_ctor: called with `num_modules=...`; the returned object's
        `wait(fn)` decides how long to wait for results.
    """
    # TODO(mtrofin): type exit_checker_ctor when we get typing.Protocol support
    super().__init__()

    self._file_paths = file_paths
    self._num_modules = num_modules
    self._parser = parser
    self._shutdown, self._worker_pool = worker_generator(
        worker_ctor, count=num_workers)
    self._reward_stat_map = reward_stat_map

    self._exit_checker_ctor = exit_checker_ctor
    # Background clean-up (cancel + re-enable workers) of the previous batch.
    self._pending_work: Optional[concurrent.futures.Future] = None
    # (file paths, future) pairs of the most recently scheduled batch.
    self._current_work: List[Tuple[Tuple[str, ...],
                                   concurrent.futures.Future]] = []
    self._pool = concurrent.futures.ThreadPoolExecutor()

  def get_last_work(
      self) -> List[Tuple[Tuple[str, ...], concurrent.futures.Future]]:
    """Returns the (file paths, future) pairs of the last scheduled batch.

    The list is reset to empty whenever `join_pending_jobs` runs.
    """
    return self._current_work

  def close_pool(self):
    """Joins pending work, cancels the workers' work and shuts everything down.

    Calling this again after the pool has been closed is a no-op.
    """
    if self._worker_pool is None:
      # Already closed; iterating the pool again would fail.
      return
    self.join_pending_jobs()
    # Stop accepting new work
    for worker in self._worker_pool:
      worker.cancel_all_work()
    self._worker_pool = None
    self._shutdown()
    self._pool.shutdown()

  def join_pending_jobs(self):
    """Waits for the previous batch's background clean-up and resets state."""
    t1 = time.time()
    if self._pending_work is not None:
      concurrent.futures.wait([self._pending_work])

    self._pending_work = None
    self._current_work = []
    # this should have taken negligible time, normally, since all the work
    # has been cancelled and the workers had time to process the cancellation
    # while training was unfolding.
    logging.info('Waiting for pending work from last iteration took %f',
                 time.time() - t1)

  def _schedule_jobs(
      self, policy_path, sampled_file_paths
  ) -> List[concurrent.futures.Future[compilation_runner.CompilationResult]]:
    """Submits one `collect_data` job per sampled module to the workers.

    Args:
      policy_path: the policy directory handed to each worker.
      sampled_file_paths: the modules to compile, each a tuple of paths.

    Returns:
      One future per entry of `sampled_file_paths`, in the same order.
    """
    # by now, all the pending work, which was signaled to cancel, must've
    # finished
    self.join_pending_jobs()
    # Build every job before submitting any, so a missing reward-stat entry
    # fails before work is handed to the workers.
    jobs = [(file_paths, policy_path,
             self._reward_stat_map['-'.join(file_paths)])
            for file_paths in sampled_file_paths]

    # Naive load balancing: round-robin over the worker pool.
    return [
        self._worker_pool[i % len(self._worker_pool)].collect_data(*job)
        for i, job in enumerate(jobs)
    ]

  def collect_data(
      self, policy_path: str
  ) -> Tuple[Iterator[trajectory.Trajectory], Dict[str, Dict[str, float]]]:
    """Collect data for a given policy.

    Args:
      policy_path: the path to the policy directory to collect data with.

    Returns:
      An iterator of batched trajectory.Trajectory that are ready to be fed to
        training.
      A dict of extra monitoring information, e.g., how many modules succeeded.
        They will be reported using `tf.scalar.summary` by the trainer so this
        information is viewable in TensorBoard.
    """
    sampled_file_paths = random.sample(self._file_paths, k=self._num_modules)
    results = self._schedule_jobs(policy_path, sampled_file_paths)

    def get_num_finished_work():
      return sum(res.done() for res in results)

    early_exit = self._exit_checker_ctor(num_modules=self._num_modules)
    wait_seconds = early_exit.wait(get_num_finished_work)

    self._current_work = list(zip(sampled_file_paths, results))
    finished_work = [
        (paths, res) for paths, res in self._current_work if res.done()
    ]

    # Fetch each finished result exactly once; a future whose result() raises
    # counts as a failure and is dropped.
    successful_work = []
    for paths, res in finished_work:
      try:
        outcome = res.result()
      except (KeyboardInterrupt, SystemExit):
        # Never hide a request to stop the process.
        raise
      except BaseException:  # pylint: disable=broad-except
        continue
      successful_work.append((paths, outcome))

    failures = len(finished_work) - len(successful_work)
    logging.info(('%d of %d modules finished in %d seconds (%d failures).'),
                 len(finished_work), self._num_modules, wait_seconds, failures)

    # signal whatever work is left to finish
    def wrapup():
      cancel_futures = [wkr.cancel_all_work() for wkr in self._worker_pool]
      wait_for(cancel_futures)
      wait_for(results)
      wait_for([wkr.enable() for wkr in self._worker_pool])

    # Runs in the background; joined by the next join_pending_jobs().
    self._pending_work = self._pool.submit(wrapup)

    sequence_examples = list(
        itertools.chain.from_iterable(
            res.serialized_sequence_examples for (_, res) in successful_work))
    total_trajectory_length = sum(res.length for (_, res) in successful_work)
    self._reward_stat_map.update({
        '-'.join(file_paths): res.reward_stats
        for (file_paths, res) in successful_work
    })

    rewards = list(
        itertools.chain.from_iterable(
            res.rewards for (_, res) in successful_work))
    monitor_dict = {
        'default': {
            'success_modules': len(successful_work),
            'total_trajectory_length': total_trajectory_length,
        },
        'reward_distribution':
            data_collector.build_distribution_monitor(rewards),
    }

    parsed = self._parser(sequence_examples)

    return parsed, monitor_dict

  def on_dataset_consumed(self,
                          dataset_iterator: Iterator[trajectory.Trajectory]):
    """Does nothing; this collector needs no action once a dataset is used."""
    pass