# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLVM Policy Trainer."""
import time
from absl import logging
import gin
import tensorflow as tf
from tf_agents.policies import policy_loader
from tf_agents.utils import common as common_utils
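
# Presumably the feature key naming the heuristic's default inlining decision
# in collected trajectories (an assumption; the key is not referenced
# elsewhere in this module).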
_INLINING_DEFAULT_KEY = 'inlining_default'


@gin.configurable
class Trainer(object):
  """Object that trains an LLVM policy.

  After initialization, the function 'train' can be called multiple times to
  train on different datasets. An example usage:

  ```python
  trainer = Trainer(root_dir, agent)
  trainer.train(data_iter_1, monitor_dict_1, num_iterations_1)
  trainer.train(data_iter_2, monitor_dict_2, num_iterations_2)
  ```
  """
  def __init__(
      self,
      root_dir,
      agent,
      random_network_distillation=None,
      warmstart_policy_dir=None,
      # Params for summaries and logging
      checkpoint_interval=10000,
      log_interval=100,
      summary_interval=1000,
      summaries_flush_secs=10):
"""Initialize the Trainer object.
Args:
root_dir: str, the root directory to host all required sub-directories.
agent: a tf_agents.agents.TFAgent object.
random_network_distillation: a random_net_distillation.RND object.
warmstart_policy_dir: the directory to warmstart the policy if given.
checkpoint_interval: int, the training step interval for saving
checkpoint.
log_interval: int, the training step interval for logging.
summary_interval: int, the training step interval for exporting to
tensorboard.
summaries_flush_secs: int, the seconds for flushing to tensorboard.
"""
    self._root_dir = root_dir
    self._agent = agent
    self._random_network_distillation = random_network_distillation
    self._checkpoint_interval = checkpoint_interval
    self._log_interval = log_interval
    self._summary_interval = summary_interval

    self._summary_writer = tf.summary.create_file_writer(
        self._root_dir, flush_millis=summaries_flush_secs * 1000)
    self._summary_writer.set_as_default()

    self._global_step = tf.compat.v1.train.get_or_create_global_step()

    # Initialize the agent, and wrap the train functions in a tf.function so
    # they are traced once and then run in graph mode, which is much faster
    # than eager execution.
    self._agent.initialize()
    self._agent.train = common_utils.function(self._agent.train)
    if self._random_network_distillation:
      self._random_network_distillation.train = common_utils.function(
          self._random_network_distillation.train)

    self._initialize_metrics()

    # Load the warmstart policy before restoring from checkpoint, so that an
    # existing checkpoint takes precedence over the warmstart weights
    # (tau=1.0 copies the warmstart weights outright).
    if warmstart_policy_dir:
      warmstart_policy = policy_loader.load(warmstart_policy_dir)
      self._agent.policy.update(
          policy=warmstart_policy,
          tau=1.0,
          tau_non_trainable=None,
          sort_variables_by_name=False)

    self._checkpointer = common_utils.Checkpointer(
        ckpt_dir=self._root_dir,
        agent=self._agent,
        global_step=self._global_step)
    self._checkpointer.initialize_or_restore()

    self._start_time = time.time()
    self._last_checkpoint_step = 0
    self._last_log_step = 0

  def _initialize_metrics(self):
    """Initializes metrics."""
    self._data_action_mean = tf.keras.metrics.Mean()
    self._data_reward_mean = tf.keras.metrics.Mean()
    self._num_trajectories = tf.keras.metrics.Sum()

  def _update_metrics(self, experience, monitor_dict):
    """Updates metrics and exports to TensorBoard."""
    # Boundary steps mark episode ends and carry no real decision, so mask
    # them out of the action/reward means.
    is_action = ~experience.is_boundary()
    self._data_action_mean.update_state(
        experience.action, sample_weight=is_action)
    self._data_reward_mean.update_state(
        experience.reward, sample_weight=is_action)
    # Each trajectory contributes exactly one is_first() step, so summing
    # these flags counts trajectories.
    self._num_trajectories.update_state(experience.is_first())

    with tf.name_scope('default/'):
      tf.summary.scalar(
          name='data_action_mean',
          data=self._data_action_mean.result(),
          step=self._global_step)
      tf.summary.scalar(
          name='data_reward_mean',
          data=self._data_reward_mean.result(),
          step=self._global_step)
      tf.summary.scalar(
          name='num_trajectories',
          data=self._num_trajectories.result(),
          step=self._global_step)

    # monitor_dict maps a name scope to a dict of scalars; export each scalar
    # under its scope.
    for name_scope, d in monitor_dict.items():
      with tf.name_scope(name_scope + '/'):
        for key, value in d.items():
          tf.summary.scalar(name=key, data=value, step=self._global_step)

    tf.summary.histogram(
        name='reward', data=experience.reward, step=self._global_step)

  def _reset_metrics(self):
    """Resets num_trajectories."""
    self._num_trajectories.reset_states()

  def _log_experiment(self, loss):
    """Logs training info."""
    global_step_val = self._global_step.numpy()
    if global_step_val - self._last_log_step >= self._log_interval:
      logging.info('step = %d, loss = %g', global_step_val, loss)
      # Report throughput measured over the steps since the last log line.
      time_acc = time.time() - self._start_time
      steps_per_sec = (global_step_val - self._last_log_step) / time_acc
      logging.info('%.3f steps/sec', steps_per_sec)
      self._last_log_step = global_step_val
      self._start_time = time.time()

  def _save_checkpoint(self):
    if (self._global_step.numpy() - self._last_checkpoint_step >=
        self._checkpoint_interval):
      self._checkpointer.save(global_step=self._global_step)
      self._last_checkpoint_step = self._global_step.numpy()

  def global_step_numpy(self):
    return self._global_step.numpy()

  def train(self, dataset_iter, monitor_dict, num_iterations):
    """Trains the policy with data from dataset_iter for num_iterations steps.

    Args:
      dataset_iter: an iterator yielding batches of experience Trajectories.
      monitor_dict: a dict mapping a TensorBoard name scope to a dict of
        scalar metrics to export alongside the training summaries.
      num_iterations: int, the number of train steps to run.
    """
    self._reset_metrics()
    # Only record summaries every self._summary_interval global steps.
    with tf.summary.record_if(
        lambda: tf.math.equal(self._global_step % self._summary_interval, 0)):
      for _ in range(num_iterations):
        experience = next(dataset_iter)
        # Random network distillation: train the RND nets and augment the
        # experience with an intrinsic reward before the agent trains on it.
        if self._random_network_distillation:
          experience = self._random_network_distillation.train(experience)
        loss = self._agent.train(experience)

        self._update_metrics(experience, monitor_dict)
        self._log_experiment(loss.loss)
        self._save_checkpoint()
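

# Illustrative usage sketch (commented out; not part of the module). The
# helpers `make_agent()` and `make_dataset_iter()` are hypothetical stand-ins
# for however the surrounding pipeline builds the TFAgent and the iterator of
# experience batches; `monitor_dict` may be empty. A minimal run could look
# like:
#
#   trainer = Trainer(root_dir='/tmp/inlining_policy', agent=make_agent())
#   for _ in range(10):  # outer data-collection / training rounds
#     data_iter = make_dataset_iter()
#     trainer.train(data_iter, monitor_dict={}, num_iterations=100)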