# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for collecting data locally."""

import concurrent.futures
import itertools
import random
import time
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple

from absl import logging
import gin
from tf_agents.trajectories import trajectory

from compiler_opt import worker_generator
from compiler_opt.rl import compilation_runner
from compiler_opt.rl import data_collector


def wait_for(futures: Iterable[concurrent.futures.Future]):
  """Waits for all futures to complete, ignoring any raised exceptions.

  Dask futures don't support more than result().
  """
  for f in futures:
    try:
      _ = f.result()
    except:  # pylint: disable=bare-except
      pass


@gin.configurable()
class LocalDataCollector(data_collector.DataCollector):
  """Class for local data collection."""

  def __init__(
self,
file_paths: Tuple[Tuple[str, ...], ...],
num_workers: int,
num_modules: int,
parser: Callable[[List[str]], Iterator[trajectory.Trajectory]],
reward_stat_map: Dict[str, Optional[Dict[str,
compilation_runner.RewardStat]]],
worker_ctor: Callable[[], compilation_runner.CompilationRunnerStub],
exit_checker_ctor=data_collector.EarlyExitChecker):
# TODO(mtrofin): type exit_checker_ctor when we get typing.Protocol support
super().__init__()
self._file_paths = file_paths
self._num_modules = num_modules
self._parser = parser
self._shutdown, self._worker_pool = worker_generator(
worker_ctor, count=num_workers)
self._reward_stat_map = reward_stat_map
self._exit_checker_ctor = exit_checker_ctor
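    # _pending_work tracks the asynchronous cleanup submitted at the end of
    # collect_data; _current_work holds the (module paths, result future)
    # pairs scheduled in the current iteration; _pool is the thread pool that
    # runs that cleanup in the background.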
    self._pending_work: Optional[concurrent.futures.Future] = None
    self._current_work: List[Tuple[Tuple[str, ...],
                                   concurrent.futures.Future]] = []
self._pool = concurrent.futures.ThreadPoolExecutor()

  def get_last_work(self):
return self._current_work

  def close_pool(self):
self.join_pending_jobs()
# Stop accepting new work
for p in self._worker_pool:
p.cancel_all_work()
self._worker_pool = None
self._shutdown()
self._pool.shutdown()

  def join_pending_jobs(self):
t1 = time.time()
if self._pending_work:
concurrent.futures.wait([self._pending_work])
self._pending_work = None
self._current_work = []
# this should have taken negligible time, normally, since all the work
# has been cancelled and the workers had time to process the cancellation
# while training was unfolding.
logging.info('Waiting for pending work from last iteration took %f',
time.time() - t1)

  def _schedule_jobs(
self, policy_path, sampled_file_paths
) -> List[concurrent.futures.Future[compilation_runner.CompilationResult]]:
# by now, all the pending work, which was signaled to cancel, must've
# finished
self.join_pending_jobs()
jobs = [(file_paths, policy_path,
self._reward_stat_map['-'.join(file_paths)])
for file_paths in sampled_file_paths]
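    # Each job is the positional argument tuple passed to a worker's
    # collect_data(): (module file paths, policy path, cached reward stats
    # for that module).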
    # Naive round-robin load balancing across the worker pool.
ret = []
for i in range(len(jobs)):
ret.append(self._worker_pool[i % len(self._worker_pool)].collect_data(
*(jobs[i])))
return ret

  def collect_data(
self, policy_path: str
) -> Tuple[Iterator[trajectory.Trajectory], Dict[str, Dict[str, float]]]:
"""Collect data for a given policy.
Args:
policy_path: the path to the policy directory to collect data with.
Returns:
An iterator of batched trajectory.Trajectory that are ready to be fed to
training.
A dict of extra monitoring information, e.g., how many modules succeeded.
They will be reported using `tf.scalar.summary` by the trainer so these
information is viewable in TensorBoard.
"""
sampled_file_paths = random.sample(self._file_paths, k=self._num_modules)
results = self._schedule_jobs(policy_path, sampled_file_paths)
def wait_for_termination():
early_exit = self._exit_checker_ctor(num_modules=self._num_modules)
def get_num_finished_work():
finished_work = sum(res.done() for res in results)
return finished_work
return early_exit.wait(get_num_finished_work)
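    # Wait until the early-exit policy decides enough modules have finished;
    # the return value is the number of seconds spent waiting.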
wait_seconds = wait_for_termination()
self._current_work = list(zip(sampled_file_paths, results))
finished_work = [
(paths, res) for paths, res in self._current_work if res.done()
]
def is_successful(f):
try:
_ = f.result()
return True
except: # pylint: disable=bare-except
return False
successful_work = [(paths, res.result())
for paths, res in finished_work
if is_successful(res)]
failures = len(finished_work) - len(successful_work)
    logging.info('%d of %d modules finished in %d seconds (%d failures).',
                 len(finished_work), self._num_modules, wait_seconds, failures)
# signal whatever work is left to finish
def wrapup():
cancel_futures = [wkr.cancel_all_work() for wkr in self._worker_pool]
wait_for(cancel_futures)
wait_for(results)
wait_for([wkr.enable() for wkr in self._worker_pool])
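    # Run wrapup() in the background; join_pending_jobs() waits for it at the
    # start of the next collection round, so cancellation of leftover work
    # overlaps with training instead of blocking here.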
self._pending_work = self._pool.submit(wrapup)
sequence_examples = list(
itertools.chain.from_iterable(
[res.serialized_sequence_examples for (_, res) in successful_work]))
total_trajectory_length = sum(res.length for (_, res) in successful_work)
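    # Refresh the cached per-module reward stats; _schedule_jobs reads them
    # back (keyed by the joined file paths) when scheduling the next iteration.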
self._reward_stat_map.update({
'-'.join(file_paths): res.reward_stats
for (file_paths, res) in successful_work
})
monitor_dict = {}
monitor_dict['default'] = {
'success_modules': len(successful_work),
'total_trajectory_length': total_trajectory_length,
}
rewards = list(
itertools.chain.from_iterable(
[res.rewards for (_, res) in successful_work]))
monitor_dict[
'reward_distribution'] = data_collector.build_distribution_monitor(
rewards)
parsed = self._parser(sequence_examples)
return parsed, monitor_dict

  def on_dataset_consumed(self,
dataset_iterator: Iterator[trajectory.Trajectory]):
pass