# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""util function to create training datasets."""
from typing import Callable, List
import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import trajectory
from tf_agents.typing import types
from compiler_opt.rl import constant


def _get_policy_info_parsing_dict(agent_name, action_spec):
"""Function to get parsing dict for policy info."""
if agent_name == constant.AgentName.PPO:
if tensor_spec.is_discrete(action_spec):
return {
'CategoricalProjectionNetwork_logits':
tf.io.FixedLenSequenceFeature(
shape=(action_spec.maximum - action_spec.minimum + 1),
dtype=tf.float32)
}
else:
return {
'NormalProjectionNetwork_scale':
tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32),
'NormalProjectionNetwork_loc':
tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32)
}
return {}
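

# A hedged illustration of the discrete branch above (the spec is
# hypothetical): for a PPO agent with an action spec such as
#   tensor_spec.BoundedTensorSpec((), tf.int64, minimum=0, maximum=1)
# the parsing dict maps 'CategoricalProjectionNetwork_logits' to
#   tf.io.FixedLenSequenceFeature(shape=2, dtype=tf.float32)
# i.e. one logit per action in [minimum, maximum].
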
def _process_parsed_sequence_and_get_policy_info(parsed_sequence, agent_name,
action_spec):
"""Function to process parsed_sequence and to return policy_info.
Args:
parsed_sequence: A dict from feature_name to feature_value parsed from TF
SequenceExample.
agent_name: AgentName, enum type of the agent.
action_spec: action spec of the optimization problem.
Returns:
policy_info: A nested policy_info for given agent.
"""
if agent_name == constant.AgentName.PPO:
if tensor_spec.is_discrete(action_spec):
policy_info = {
'dist_params': {
'logits': parsed_sequence['CategoricalProjectionNetwork_logits']
}
}
del parsed_sequence['CategoricalProjectionNetwork_logits']
else:
policy_info = {
'dist_params': {
'scale': parsed_sequence['NormalProjectionNetwork_scale'],
'loc': parsed_sequence['NormalProjectionNetwork_loc']
}
}
del parsed_sequence['NormalProjectionNetwork_scale']
del parsed_sequence['NormalProjectionNetwork_loc']
return policy_info
else:
return ()
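

# Sketch of what the discrete PPO branch produces (shapes are assumptions):
# given a parsed_sequence whose 'CategoricalProjectionNetwork_logits' entry is
# a [T, num_actions] float32 tensor, the returned policy_info is
#   {'dist_params': {'logits': <[T, num_actions] float32 tensor>}}
# and the consumed key is deleted from parsed_sequence in place, leaving only
# observation features behind.
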
def create_parser_fn(
agent_name: constant.AgentName, time_step_spec: types.NestedSpec,
action_spec: types.NestedSpec) -> Callable[[str], trajectory.Trajectory]:
"""Create a parser function for reading from a serialized tf.SequenceExample.
Args:
agent_name: AgentName, enum type of the agent.
time_step_spec: time step spec of the optimization problem.
action_spec: action spec of the optimization problem.
Returns:
A callable that takes scalar serialized proto Tensors and emits
`Trajectory` objects containing parsed tensors.
"""

  def _parser_fn(serialized_proto):
    """Helper function that is returned by `create_parser_fn`."""
# We copy through all context features at each frame, so even though we know
# they don't change from frame to frame, they are still sequence features
# and stored in the feature list.
context_features = {}
# pylint: disable=g-complex-comprehension
sequence_features = dict(
(tensor_spec.name,
tf.io.FixedLenSequenceFeature(
shape=tensor_spec.shape, dtype=tensor_spec.dtype))
for tensor_spec in time_step_spec.observation.values())
sequence_features[action_spec.name] = tf.io.FixedLenSequenceFeature(
shape=action_spec.shape, dtype=action_spec.dtype)
sequence_features[
time_step_spec.reward.name] = tf.io.FixedLenSequenceFeature(
shape=time_step_spec.reward.shape,
dtype=time_step_spec.reward.dtype)
sequence_features.update(
_get_policy_info_parsing_dict(agent_name, action_spec))
# pylint: enable=g-complex-comprehension
with tf.name_scope('parse'):
_, parsed_sequence = tf.io.parse_single_sequence_example(
serialized_proto,
context_features=context_features,
sequence_features=sequence_features)
# TODO(yundi): make the transformed reward configurable.
action = parsed_sequence[action_spec.name]
reward = tf.cast(parsed_sequence[time_step_spec.reward.name], tf.float32)
policy_info = _process_parsed_sequence_and_get_policy_info(
parsed_sequence, agent_name, action_spec)
del parsed_sequence[time_step_spec.reward.name]
del parsed_sequence[action_spec.name]
full_trajectory = trajectory.from_episode(
observation=parsed_sequence,
action=action,
policy_info=policy_info,
reward=reward)
return full_trajectory
return _parser_fn
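

# A minimal usage sketch, assuming `time_step_spec` and `action_spec` come from
# the caller and `serialized_proto` holds one tf.SequenceExample serialized to
# bytes, with feature names matching those specs:
#
#   parser_fn = create_parser_fn(
#       constant.AgentName.PPO, time_step_spec, action_spec)
#   traj = parser_fn(serialized_proto)
#   # traj is a trajectory.Trajectory with tensors of shape [T, ...].
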
def create_sequence_example_dataset_fn(
agent_name: constant.AgentName, time_step_spec: types.NestedSpec,
action_spec: types.NestedSpec, batch_size: int, train_sequence_length: int
) -> Callable[[List[str]], tf.data.Dataset]:
"""Get a function that creates a dataset from serialized sequence examples.
Args:
agent_name: AgentName, enum type of the agent.
time_step_spec: time step spec of the optimization problem.
action_spec: action spec of the optimization problem.
batch_size: int, batch_size B.
train_sequence_length: int, trajectory sequence length T.
Returns:
A callable that takes a list of serialized sequence examples and returns
a `tf.data.Dataset`. Treating this dataset as an iterator yields batched
`trajectory.Trajectory` instances with shape `[B, T, ...]`.
"""
trajectory_shuffle_buffer_size = 1024
parser_fn = create_parser_fn(agent_name, time_step_spec, action_spec)

  def _sequence_example_dataset_fn(sequence_examples):
    # The data collector returns empty strings in corner cases; filter them
    # out here.
    dataset = tf.data.Dataset.from_tensor_slices(sequence_examples).filter(
        lambda string: tf.strings.length(string) > 0).map(parser_fn).filter(
            # Only keep trajectories with more than two reward entries.
            lambda traj: tf.size(traj.reward) > 2)
dataset = (
dataset.unbatch().batch(
train_sequence_length,
drop_remainder=True).shuffle(trajectory_shuffle_buffer_size).batch(
batch_size,
drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE))
return dataset
return _sequence_example_dataset_fn
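

# A minimal usage sketch (the batch and sequence sizes are arbitrary
# assumptions):
#
#   dataset_fn = create_sequence_example_dataset_fn(
#       agent_name=constant.AgentName.PPO,
#       time_step_spec=time_step_spec,
#       action_spec=action_spec,
#       batch_size=64,
#       train_sequence_length=16)
#   dataset = dataset_fn(sequence_examples)  # list of serialized protos
#   for batch in dataset.take(1):
#     ...  # batch is a trajectory.Trajectory shaped [64, 16, ...].
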
# TODO(yundi): PyType check of input_dataset as Type[tf.data.Dataset] is not
# working.
def create_file_dataset_fn(
agent_name: constant.AgentName, time_step_spec: types.NestedSpec,
action_spec: types.NestedSpec, batch_size: int, train_sequence_length: int,
input_dataset) -> Callable[[List[str]], tf.data.Dataset]:
"""Get a function that creates an dataset from files.
Args:
agent_name: AgentName, enum type of the agent.
time_step_spec: time step spec of the optimization problem.
action_spec: action spec of the optimization problem.
batch_size: int, batch_size B.
train_sequence_length: int, trajectory sequence length T.
input_dataset: A tf.data.Dataset subclass object.
Returns:
A callable that takes file path(s) and returns a `tf.data.Dataset`.
Iterating over this dataset yields `trajectory.Trajectory` instances with
shape `[B, T, ...]`.
"""
files_buffer_size = 100
num_readers = 10
num_map_threads = 8
shuffle_buffer_size = 1024
trajectory_shuffle_buffer_size = 1024
parser_fn = create_parser_fn(agent_name, time_step_spec, action_spec)

  def _file_dataset_fn(data_path):
dataset = (
tf.data.Dataset.list_files(data_path).shuffle(
files_buffer_size).interleave(
input_dataset, cycle_length=num_readers, block_length=1)
        # Due to a bug in data collection, we sometimes get empty rows.
.filter(lambda string: tf.strings.length(string) > 0).apply(
tf.data.experimental.shuffle_and_repeat(shuffle_buffer_size)).map(
parser_fn, num_parallel_calls=num_map_threads)
        # Only keep trajectories with more than two reward entries.
.filter(lambda traj: tf.size(traj.reward) > 2))
# TODO(yundi): window and subsample data.
# TODO(yundi): verify the shuffling is correct.
dataset = (
dataset.unbatch().batch(
train_sequence_length,
drop_remainder=True).shuffle(trajectory_shuffle_buffer_size).batch(
batch_size,
drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE))
return dataset
return _file_dataset_fn
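

# A hedged usage sketch: `input_dataset` can be any callable that maps a
# filename tensor to a tf.data.Dataset of serialized records, e.g.
# tf.data.TFRecordDataset (as create_tfrecord_dataset_fn does below):
#
#   dataset_fn = create_file_dataset_fn(
#       constant.AgentName.PPO, time_step_spec, action_spec, batch_size=64,
#       train_sequence_length=16, input_dataset=tf.data.TFRecordDataset)
#   dataset = dataset_fn('/path/to/data*')  # hypothetical file glob
#
# Note that shuffle_and_repeat makes the dataset repeat indefinitely, so draw
# batches with e.g. next(iter(dataset)) rather than iterating to exhaustion.
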
def create_tfrecord_dataset_fn(
agent_name: constant.AgentName, time_step_spec: types.NestedSpec,
action_spec: types.NestedSpec, batch_size: int, train_sequence_length: int
) -> Callable[[List[str]], tf.data.Dataset]:
"""Get a function that creates an dataset from tfrecord.
Args:
agent_name: AgentName, enum type of the agent.
time_step_spec: time step spec of the optimization problem.
action_spec: action spec of the optimization problem.
batch_size: int, batch_size B.
train_sequence_length: int, trajectory sequence length T.
Returns:
A callable that takes tfrecord path(s) and returns a `tf.data.Dataset`.
Iterating over this dataset yields `trajectory.Trajectory` instances with
shape `[B, T, ...]`.
"""
return create_file_dataset_fn(
agent_name,
time_step_spec,
action_spec,
batch_size,
train_sequence_length,
input_dataset=tf.data.TFRecordDataset)
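

# A minimal end-to-end sketch for TFRecord inputs (paths and sizes are
# hypothetical):
#
#   dataset_fn = create_tfrecord_dataset_fn(
#       constant.AgentName.PPO, time_step_spec, action_spec,
#       batch_size=64, train_sequence_length=16)
#   dataset = dataset_fn(['/tmp/train-00000.tfrecord'])
#   batch = next(iter(dataset))  # Trajectory shaped [64, 16, ...]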