blob: e4558bbc932b7fcad08ccb27d61ac23f24c8db2e [file] [log] [blame]
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data collection module."""
import abc
import time
from typing import Dict, Iterator, Tuple, Sequence
import numpy as np
from tf_agents.trajectories import trajectory
# Deadline for data collection, in seconds.
DEADLINE_IN_SECONDS = 30
# We don't wait for all data collection to finish --- waiting stops as soon as
# any one of the WAIT_TERMINATION conditions is met.
# Each entry is (fraction of data collections finished, fraction of
# DEADLINE_IN_SECONDS waited). E.g. (0.8, 0.5) means we stop waiting for more
# data collection to finish once 80% of the data collections have finished and
# 50% of DEADLINE_IN_SECONDS has elapsed.
WAIT_TERMINATION = ((0.9, 0), (0.8, 0.5), (0, 1))
# NOTE(review): DELTA is not referenced anywhere in the visible part of this
# file; presumably a tolerance consumed by other modules -- confirm.
DELTA = 0.01
# Percentiles (on the 0-100 scale) that build_distribution_monitor reports.
REWARD_QUANTILE_MONITOR = (0.1, 0.5, 1, 2, 3, 4, 5, 6, 8, 10, 20, 30, 40, 50,
                           60, 70, 80, 90, 95, 99, 99.5, 99.9)
def build_distribution_monitor(data: Sequence[float]) -> Dict[str, float]:
  """Summarizes `data` as a dict of percentiles plus the mean.

  Args:
    data: the values to summarize.

  Returns:
    A dict with one `p_<q>` entry for every q in REWARD_QUANTILE_MONITOR,
    holding the q-th percentile of `data`, and a `mean` entry holding the
    arithmetic mean. Percentiles use the 'lower' estimation method, i.e. they
    are always actual values from `data` rather than interpolated ones.
  """
  try:
    # NumPy >= 1.22 names the estimation-method keyword `method`; the previous
    # `interpolation` spelling is deprecated there.
    quantiles = np.percentile(data, REWARD_QUANTILE_MONITOR, method='lower')
  except TypeError:
    # Older NumPy releases only accept the `interpolation` keyword. If the
    # TypeError came from `data` itself, this call re-raises it unchanged.
    quantiles = np.percentile(
        data, REWARD_QUANTILE_MONITOR, interpolation='lower')
  monitor_dict = {
      f'p_{x}': y for (x, y) in zip(REWARD_QUANTILE_MONITOR, quantiles)
  }
  monitor_dict['mean'] = np.mean(data)
  return monitor_dict
class DataCollector(metaclass=abc.ABCMeta):
  """Abstract interface for data collection.

  Implementations provide `collect_data` to produce training data for a policy
  and `on_dataset_consumed` to clean up once that data has been used.
  """

  @abc.abstractmethod
  def collect_data(
      self, policy_path: str
  ) -> Tuple[Iterator[trajectory.Trajectory], Dict[str, float]]:
    """Collect data for a given policy.

    Args:
      policy_path: the path to the policy directory to collect data with.

    Returns:
      An iterator of batched trajectory.Trajectory that are ready to be fed to
        training.
      A dict of extra monitoring information, e.g., how many modules succeeded.
        It is reported using `tf.summary.scalar` by the trainer so that this
        information is viewable in TensorBoard.
    """

  @abc.abstractmethod
  def on_dataset_consumed(self,
                          dataset_iterator: Iterator[trajectory.Trajectory]):
    """Clean up after the data has been consumed.

    Args:
      dataset_iterator: the dataset_iterator that has been consumed.
    """
class EarlyExitChecker:
  """Decides when it is acceptable to stop waiting for data collection.

  The clock starts when the checker is constructed. Progress is compared
  against a list of (data fraction, deadline fraction) thresholds; waiting may
  stop as soon as any single threshold is satisfied on both counts.
  """

  def __init__(self,
               num_modules: int,
               deadline: float = DEADLINE_IN_SECONDS,
               thresholds: Tuple[Tuple[float, float], ...] = WAIT_TERMINATION):
    """Initializes the early exit checker and starts its clock.

    Args:
      num_modules: How many total modules we are waiting for.
      deadline: The deadline for data collection, in seconds.
      thresholds: Early exit thresholds, e.g. [(0.8, 0.5)] means early exit is
        allowable if 80% of data has been collected and we've waited 50% of the
        maximum waiting time.
    """
    # NOTE(review): time.time() is wall-clock time, so system clock adjustments
    # affect the measured wait. time.monotonic() may be preferable -- confirm
    # nothing relies on patching time.time before switching.
    self._start_time = time.time()
    self._waited_time = 0
    self._num_modules = num_modules
    self._deadline = deadline
    self._thresholds = thresholds

  def _should_exit(self, collected: int):
    """Checks whether we should exit early.

    A negative `collected` never permits an exit and does not refresh the
    recorded waiting time.

    Args:
      collected: The amount of data we have collected.

    Returns:
      True if at least one threshold is met both in collected data and in
      elapsed time, otherwise False.
    """
    if collected < 0:
      return False

    # Elapsed time is tracked in whole seconds.
    elapsed = round(time.time() - self._start_time)
    self._waited_time = elapsed

    return any(
        collected >= self._num_modules * data_fraction and
        elapsed >= self._deadline * time_fraction
        for data_fraction, time_fraction in self._thresholds)

  def wait(self, get_num_finished_work):
    """Blocks until one of the early exit thresholds is satisfied.

    Polls `get_num_finished_work` once per second.

    Args:
      get_num_finished_work: a callable object which returns the amount of
        finished work.

    Returns:
      The amount of time waited, in whole seconds.
    """
    while True:
      if self._should_exit(get_num_finished_work()):
        return self.waited_time()
      time.sleep(1)

  def waited_time(self):
    """Returns the waiting time recorded by the most recent exit check."""
    return self._waited_time