# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility to create MLGO policy learner."""
from typing import Callable, List, Optional, Text, Tuple
from absl import logging
import tensorflow as tf
from tf_agents.agents.ppo import ppo_agent
from tf_agents.train import interval_trigger
from tf_agents.train import learner
from tf_agents.typing import types
# A function that processes a (Trajectory, SampleInfo) tuple, where the
# Trajectory is a nested tensor from TF-Agents and the SampleInfo comes from
# Reverb.
_SequenceParamsType = Tuple[types.NestedTensor, types.ReverbSampleInfo]
_SequenceFnType = Callable[[_SequenceParamsType], _SequenceParamsType]
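# As an illustrative sketch only (not part of this module), a callable matching
# the `_SequenceFnType` signature takes a (Trajectory, SampleInfo) pair and
# returns one; `my_agent` below is an assumed `ppo_agent.PPOAgent`:
#
#   def example_per_sequence_fn(sample):
#     trajectory, sample_info = sample
#     # e.g. compute advantages over the full sequence before minibatching.
#     return my_agent.preprocess_sequence(trajectory), sample_info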
class MLGOPPOLearner(object):
"""Manages all the learning details needed.
These include:
* Using distribution strategies correctly
* Summaries
* Checkpoints
* Minimizing entering/exiting TF context:
Especially in the case of TPUs, scheduling a single TPU program to
perform multiple train steps is critical for performance.
* Generalizing the train call so it runs correctly across CPU, GPU, or TPU
executions managed by DistributionStrategies. This uses `strategy.run` and
then performs a reduce operation over the `LossInfo` returned by the
agent.
"""
def __init__(self,
root_dir: Text,
train_step: tf.Variable,
model_id: tf.Variable,
agent: ppo_agent.PPOAgent,
experience_dataset_fn: Callable[[], tf.data.Dataset],
sequence_length: int,
num_episodes_per_iteration: int,
minibatch_size: int,
shuffle_buffer_size: int,
num_epochs: int = 1,
triggers: Optional[List[
interval_trigger.IntervalTrigger]] = None,
checkpoint_interval: int = 100000,
summary_interval: int = 1000,
strategy: Optional[tf.distribute.Strategy] = None,
per_sequence_fn: Optional[_SequenceFnType] = None,
allow_variable_length_episodes: bool = False) -> None:
"""Initializes a MLGOPPOLearner instance.
Args:
root_dir: Main directory path where checkpoints, saved_models, and
summaries will be written.
train_step: a scalar tf.int64 `tf.Variable` which will keep track of the
number of train steps. This is used for artifacts created in the
root_dir, such as summaries and other outputs.
model_id: a scalar tf.int64 `tf.Variable` which will keep track of the
number of learner iterations / policy updates.
agent: `ppo_agent.PPOAgent` instance to train with. Note that
update_normalizers_in_train should be set to `False`, otherwise a
ValueError will be raised. We do not update normalizers in the agent
again because we already update them in the learner. When mini batching
is enabled, compute_value_and_advantage_in_train should be set to False,
and preprocessing should be done as part of the data pipeline provided
by `replay_buffer.as_dataset`.
experience_dataset_fn: a function that will create an instance of a
tf.data.Dataset used to sample experience for training. Each element in
the dataset is a (Trajectory, SampleInfo) pair.
sequence_length: Fixed sequence length for elements in the dataset. Used
for calculating how many minibatch train iterations to run.
num_episodes_per_iteration: The number of episodes to sample for training.
If fewer than this number of episodes exist in the dataset, the learner
will wait for more data to be added, or until the Reverb timeout is
reached.
minibatch_size: The minibatch size. The dataset used for training is
shaped `[minibatch_size, 1, ...]`. If None, full sequences will be fed
into the agent. Please set this parameter to None for RNN networks,
which require full sequences.
shuffle_buffer_size: The buffer size for shuffling the trajectories before
splitting them into mini batches. Only required when mini batch learning
is enabled (minibatch_size is set). Otherwise it is ignored. Commonly
set to a number 1-3x the episode length of your environment.
num_epochs: The number of epochs, i.e. how many times to iterate over the
same sequences.
triggers: List of callables of the form `trigger(train_step)`. After every
`run` call, every trigger is called with the current `train_step` value
as a numpy scalar.
checkpoint_interval: Number of train steps in between checkpoints. Note
these are placed into triggers and so a check to generate a checkpoint
only occurs after every `run` call. Set to -1 to disable (this is not
recommended, because it means that if the pipeline gets preempted, all
previous progress is lost). This only takes care of checkpointing the
training process. Policies must be explicitly exported through
triggers.
summary_interval: Number of train steps in between summaries. Note these
are placed into triggers and so a check to generate a summary only
occurs after every `run` call.
strategy: (Optional) `tf.distribute.Strategy` to use during training.
per_sequence_fn: (Optional) sequence-wise preprocessing, e.g. pass in
`agent.preprocess_sequence` for advantage calculation. This operation
happens after filtering and before rebatching.
allow_variable_length_episodes: Whether to support variable length
episodes for training.
Raises:
ValueError: If `agent._compute_value_and_advantage_in_train` is set to
`True`; preprocessing must be done as part of the data pipeline when
mini batching is enabled.
"""
self._strategy = strategy or tf.distribute.get_strategy()
self._agent = agent
self._minibatch_size = minibatch_size
self._shuffle_buffer_size = shuffle_buffer_size
self._num_epochs = num_epochs
self._experience_dataset_fn = experience_dataset_fn
self._num_episodes_per_iteration = num_episodes_per_iteration
# Tracks the number of times learner.run() has been called.
# This is used for filtering out data generated by older models, to ensure
# that training stays on-policy.
self._model_id = model_id
self._sequence_length = sequence_length
self._per_sequence_fn = per_sequence_fn
self._generic_learner = learner.Learner(
root_dir,
train_step,
agent,
after_train_strategy_step_fn=None,
triggers=triggers,
checkpoint_interval=checkpoint_interval,
summary_interval=summary_interval,
use_kwargs_in_agent_train=False,
strategy=self._strategy)
self.num_replicas = self._strategy.num_replicas_in_sync
self._allow_variable_length_episodes = allow_variable_length_episodes
self._num_samples = self._num_episodes_per_iteration * self._sequence_length
self._create_datasets()
self._steps_per_iter = self._get_train_steps_per_iteration()
logging.info('[Learner] Replicas for training: %d', self.num_replicas)
logging.info('[Learner] Samples per iteration: %d', self._num_samples)
logging.info('[Learner] Train steps per iteration: %d',
self._steps_per_iter)
def _create_datasets(self):
"""Create the training dataset and iterator."""
def _filter_invalid_episodes(sample):
sample_info = sample.info
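# The data collector is expected to encode the model_id of the policy that
# generated each sample in the Reverb item priority, so taking the minimum
# priority recovers that model_id here.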
data_model_id = tf.cast(
tf.reduce_min(sample_info.priority), dtype=tf.int64)
if self._allow_variable_length_episodes:
# Filter out off-policy samples.
return tf.math.equal(self._model_id, data_model_id)
else:
# Filter out episodes that are shorter than the expected sequence
# length, along with off-policy samples.
data = sample.data
return tf.math.logical_and(
tf.math.equal(tf.size(data.discount), self._sequence_length),
tf.math.equal(self._model_id, data_model_id))
def make_dataset(_):
# `experience_dataset_fn` returns a tf.data.Dataset. Each item is a
# (Trajectory, SampleInfo) tuple, and the Trajectory represents a single
# episode of a fixed sequence length. The Trajectory dimensions are
# [1, T, ...].
train_dataset = self._experience_dataset_fn()
train_dataset = train_dataset.filter(_filter_invalid_episodes)
if self._per_sequence_fn:
train_dataset = train_dataset.map(
self._per_sequence_fn,
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=False)
# We unbatch the dataset shaped [B, T, ...] to a new dataset that
# contains individual elements.
# Note that we unbatch across the time dimension, which could result
# in mini batches that contain subsets from more than one sequence.
# The PPO agent can handle mini batches that cross episode boundaries.
train_dataset = train_dataset.unbatch()
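# Re-adding a leading dimension of size 1 via batch(1) means the final
# batch() below yields minibatches shaped [minibatch_size, 1, ...], matching
# the shape described in the constructor docstring.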
train_dataset = train_dataset.batch(1, drop_remainder=True)
train_dataset = train_dataset.shuffle(self._shuffle_buffer_size)
#train_dataset = train_dataset.take(self._num_samples)
#train_dataset = train_dataset.cache()
train_dataset = train_dataset.repeat(self._num_epochs)
train_dataset = train_dataset.batch(
self._minibatch_size, drop_remainder=True)
#train_dataset = train_dataset.repeat()
options = tf.data.Options()
options.deterministic = False
options.experimental_optimization.parallel_batch = True
train_dataset = train_dataset.with_options(options)
return train_dataset
with self._strategy.scope():
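# Under a multi-replica strategy, distribute_datasets_from_function calls
# make_dataset once per worker with a tf.distribute.InputContext argument,
# which make_dataset ignores (hence the `_` parameter and the dummy 0 in the
# single-replica branch).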
if self._strategy.num_replicas_in_sync > 1:
self._train_dataset = (
self._strategy.distribute_datasets_from_function(make_dataset))
else:
self._train_dataset = make_dataset(0)
self._train_iterator = iter(self._train_dataset)
def _get_train_steps_per_iteration(self):
"""Number of train steps each time learner.run() is called."""
# We exhaust all num_episodes_per_iteration episodes taken from Reverb in
# this setup.
# Here we assume that there's only 1 episode per batch, and each episode is
# of the fixed sequence length.
num_mini_batches = int(self._num_samples * self._num_epochs /
self._minibatch_size)
train_steps = int(num_mini_batches / self.num_replicas)
return train_steps
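# As a worked example with hypothetical values: num_episodes_per_iteration=64,
# sequence_length=16, num_epochs=2, minibatch_size=64 and 2 replicas give
# num_samples = 64 * 16 = 1024, num_mini_batches = 1024 * 2 / 64 = 32, and
# train_steps = 32 / 2 = 16 train steps per learner.run() call.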
def wait_for_data(self):
"""Blocking call on dataset."""
traj, sample_info = next(self._train_iterator)
del traj
del sample_info
def run(self):
"""Train `num_episodes_per_iteration` repeating for `num_epochs`.
Returns:
The total loss computed before running the final step.
"""
loss_info = self._generic_learner.run(self._steps_per_iter,
self._train_iterator)
self._model_id.assign_add(1)
self._create_datasets()
return loss_info
@property
def train_step_numpy(self):
"""The current train_step.
Returns:
The current `train_step`. Note this returns a scalar numpy array which
holds the `train_step` value at the time this property was accessed.
"""
return self._generic_learner.train_step_numpy
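# A minimal usage sketch (illustrative only, not part of this module), assuming
# `my_agent`, `my_dataset_fn`, and `num_iterations` are defined by the caller:
#
#   mlgo_learner = MLGOPPOLearner(
#       root_dir='/tmp/mlgo_ppo',
#       train_step=tf.Variable(0, dtype=tf.int64),
#       model_id=tf.Variable(0, dtype=tf.int64),
#       agent=my_agent,
#       experience_dataset_fn=my_dataset_fn,
#       sequence_length=16,
#       num_episodes_per_iteration=64,
#       minibatch_size=64,
#       shuffle_buffer_size=128)
#   for _ in range(num_iterations):
#     mlgo_learner.wait_for_data()
#     loss_info = mlgo_learner.run()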