# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for collecting data locally."""
import collections
import random
from typing import Callable, Dict, Iterator, List, Optional, Tuple

from absl import logging
from tf_agents.system import system_multiprocessing as multiprocessing
from tf_agents.trajectories import trajectory

from compiler_opt.rl import data_collector

# How much work is allowed to be unfinished, relative to number of tasks
# requested in a data collection session, before we stop accepting new work and
# start waiting for it to drain.
_UNFINISHED_WORK_RATIO = 0.5


def default_overload_handler(total_unfinished_work):
  logging.warning('Too much unfinished work: %d unfinished!',
                  total_unfinished_work)


class LocalDataCollector(data_collector.DataCollector):
  """Class for local data collection."""

  def __init__(
      self,
      file_paths: Tuple[Tuple[str, ...], ...],
      num_workers: int,
      num_modules: int,
      runner: Callable[
          [Tuple[str, ...], str, Optional[float], Optional[float]],
          Tuple[str, float, float, float]],
      parser: Callable[[List[str]], Iterator[trajectory.Trajectory]],
      use_stale_results: bool = False,
      max_unfinished_tasks: Optional[int] = None,
      overload_handler: Callable[[int], None] = default_overload_handler):
    super(LocalDataCollector, self).__init__()
    self._file_paths = file_paths
    self._num_modules = num_modules
    self._runner = runner
    self._parser = parser
    self._unfinished_work = []
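    # 'spawn' starts each worker in a fresh interpreter rather than forking
    # the parent, which is unsafe once the parent process has active threads.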
    self._pool = multiprocessing.get_context('spawn').Pool(num_workers)
    self._max_unfinished_tasks = max_unfinished_tasks
    if self._max_unfinished_tasks is None:
      self._max_unfinished_tasks = _UNFINISHED_WORK_RATIO * num_modules
    self._use_stale_results = use_stale_results
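    # Per-module size caches fed back to the runner on later rounds;
    # defaultdict(lambda: None) yields None for not-yet-measured modules,
    # matching the runner's Optional[float] parameters.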
    self._default_policy_size_map = collections.defaultdict(lambda: None)
    self._moving_average_size_map = collections.defaultdict(lambda: None)
    self._overloaded_workers_handler = overload_handler

  def close_pool(self):
    if self._pool:
      # Stop accepting new work, then wait for in-flight jobs to finish.
      self._pool.close()
      self._pool.join()
      self._pool = None

  def inject_unfinished_work_for_test(self, work):
    self._unfinished_work = work

  @property
  def unfinished_work(self):
    return self._unfinished_work

  def _schedule_jobs(self, policy_path, sampled_file_paths):
    # Build one job per sampled module, mirroring the runner's signature:
    # (file_paths, policy_path, default_policy_size, moving_average_size).
    jobs = [(file_paths, policy_path, self._default_policy_size_map[file_paths],
             self._moving_average_size_map[file_paths])
            for file_paths in sampled_file_paths]
    # apply_async returns AsyncResult handles, polled via ready()/get() below.
    return [self._pool.apply_async(self._runner, job) for job in jobs]

  def collect_data(
      self, policy_path: str
  ) -> Tuple[Iterator[trajectory.Trajectory], Dict[str, Dict[str, float]]]:
    """Collect data for a given policy.

    Args:
      policy_path: the path to the policy directory to collect data with.

    Returns:
      An iterator of batched trajectory.Trajectory that are ready to be fed to
      training.
      A dict of extra monitoring information, e.g., how many modules succeeded.
      It will be reported via `tf.summary.scalar` by the trainer so that this
      information is viewable in TensorBoard.
    """
    sampled_file_paths = random.sample(self._file_paths, k=self._num_modules)
    results = self._schedule_jobs(policy_path, sampled_file_paths)
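
    # Block until EarlyExitChecker deems enough modules finished;
    # get_num_finished_work returns -1 once the backlog of unfinished work
    # reaches _max_unfinished_tasks, flagging the overload to the checker.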
    def wait_for_termination():
      early_exit = data_collector.EarlyExitChecker(
          num_modules=self._num_modules)

      def get_num_finished_work():
        finished_work = sum(res.ready() for res in results)
        unfinished_work = len(results) - finished_work
        prev_unfinished_work = sum(
            not res.ready() for _, res in self._unfinished_work)

        # Handle overworked workers.
        total_unfinished_work = unfinished_work + prev_unfinished_work
        if total_unfinished_work >= self._max_unfinished_tasks:
          self._overloaded_workers_handler(total_unfinished_work)
          return -1
        return finished_work

      return early_exit.wait(get_num_finished_work)

    wait_seconds = wait_for_termination()
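
    # Partition this round's jobs into finished and unfinished, and sweep
    # earlier rounds' leftovers: those now complete become 'stale' results,
    # those still pending are carried forward in self._unfinished_work.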
    current_work = list(zip(sampled_file_paths, results))
    finished_work = [(paths, res) for paths, res in current_work if res.ready()]
    unfinished_current_work = list(set(current_work) - set(finished_work))
    successful_work = [
        (paths, res) for paths, res in finished_work if res.successful()
    ]
    stale_results = [(paths, res)
                     for paths, res in self._unfinished_work
                     if res.ready() and res.successful()]
    self._unfinished_work = unfinished_current_work + [
        (paths, res) for paths, res in self._unfinished_work if not res.ready()
    ]

    logging.info(('%d of %d modules finished in %d seconds (%d failures).'
                  ' Currently %d unfinished work items.'), len(finished_work),
                 self._num_modules, wait_seconds,
                 len(finished_work) - len(successful_work),
                 len(self._unfinished_work))

    if self._use_stale_results:
      logging.info('Using %d stale results and %d fresh ones.',
                   len(stale_results), len(successful_work))
      successful_work += stale_results

    # Drop results whose serialized sequence example came back empty.
    successful_work = [
        (paths, res) for paths, res in successful_work if res.get()[0]
    ]
    sequence_examples = [res.get()[0] for (_, res) in successful_work]
    self._default_policy_size_map.update(
        {file_paths: res.get()[1] for (file_paths, res) in successful_work})
    self._moving_average_size_map.update(
        {file_paths: res.get()[2] for (file_paths, res) in successful_work})

    monitor_dict = {}
    monitor_dict['default'] = {'success_modules': len(successful_work)}
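
    # Reward is the relative size saving versus the default policy:
    # 1 - (new_size + DELTA) / (default_size + DELTA), with DELTA keeping the
    # ratio finite for zero sizes. Index 3 of the runner's result tuple is
    # taken to be the size measured under the policy being evaluated.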
    rewards = [
        1 - (res.get()[3] + data_collector.DELTA) /
        (res.get()[1] + data_collector.DELTA) for (_, res) in successful_work
    ]
    monitor_dict['reward_distribution'] = (
        data_collector.build_distribution_monitor(rewards))

    return self._parser(sequence_examples), monitor_dict

  def on_dataset_consumed(self,
                          dataset_iterator: Iterator[trajectory.Trajectory]):
    pass