# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for compiler_opt.rl.data_reader."""
import os
from absl.testing import parameterized
import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step
from tf_agents.trajectories import trajectory
from compiler_opt.rl import constant
from compiler_opt.rl import data_reader
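

# Builds a ten-step tf.train.SequenceExample used as raw training data: an
# observation ('feature_key'), an action ('inlining_decision' in the discrete
# case, 'live_interval_weight' in the continuous case), a 'reward', and, for
# PPO, the policy distribution parameters of the corresponding projection
# network.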
def _define_sequence_example(agent_name, is_action_discrete):
  example = tf.train.SequenceExample()
  for _ in range(10):
    example.feature_lists.feature_list['feature_key'].feature.add(
    ).int64_list.value.append(1)
    if is_action_discrete:
      example.feature_lists.feature_list['inlining_decision'].feature.add(
      ).int64_list.value.append(0)
    else:
      example.feature_lists.feature_list['live_interval_weight'].feature.add(
      ).float_list.value.append(1.23)
    example.feature_lists.feature_list['reward'].feature.add(
    ).float_list.value.append(2.3)
    if agent_name == constant.AgentName.PPO:
      if is_action_discrete:
        example.feature_lists.feature_list[
            'CategoricalProjectionNetwork_logits'].feature.add(
            ).float_list.value.extend([1.2, 3.4])
      else:
        example.feature_lists.feature_list[
            'NormalProjectionNetwork_scale'].feature.add(
            ).float_list.value.extend([1.2])
        example.feature_lists.feature_list[
            'NormalProjectionNetwork_loc'].feature.add(
            ).float_list.value.extend([3.4])
  return example
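

# Serializes `example` into a TFRecord file `num_examples` times, giving the
# TFRecord-based dataset function an on-disk source to read from.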
def _write_tmp_tfrecord(file_path, example, num_examples):
  with tf.io.TFRecordWriter(file_path) as file_writer:
    for _ in range(num_examples):
      file_writer.write(example.SerializeToString())
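

# Covers both dataset constructors in data_reader (in-memory serialized
# SequenceExamples and TFRecord files) and checks that parsed batches come
# back as tf_agents Trajectories with the expected shapes, observations,
# actions, rewards and PPO policy info.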
class DataReaderTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    self._agent_name = constant.AgentName.DQN
    observation_spec = {
        'feature_key':
            tf.TensorSpec(dtype=tf.int64, shape=(), name='feature_key')
    }
    reward_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='reward')
    self._time_step_spec = time_step.time_step_spec(observation_spec,
                                                    reward_spec)
    self._discrete_action_spec = tensor_spec.BoundedTensorSpec(
        dtype=tf.int64,
        shape=(),
        minimum=0,
        maximum=1,
        name='inlining_decision')
    self._continuous_action_spec = tensor_spec.BoundedTensorSpec(
        dtype=tf.float32,
        shape=(),
        name='live_interval_weight',
        minimum=-100,
        maximum=20)
    super(DataReaderTest, self).setUp()

  @parameterized.named_parameters(
      ('SequenceExampleDatasetFn',
       data_reader.create_sequence_example_dataset_fn),
      ('TFRecordDatasetFn', data_reader.create_tfrecord_dataset_fn))
  def test_create_dataset_fn(self, test_fn):
    example = _define_sequence_example(
        self._agent_name, is_action_discrete=True)
    data_source = None
    if test_fn == data_reader.create_sequence_example_dataset_fn:
      data_source = [example.SerializeToString() for _ in range(100)]
    elif test_fn == data_reader.create_tfrecord_dataset_fn:
      data_source = os.path.join(self.get_temp_dir(), 'data_tfrecord')
      _write_tmp_tfrecord(data_source, example, 100)

    dataset_fn = test_fn(
        agent_name=self._agent_name,
        time_step_spec=self._time_step_spec,
        action_spec=self._discrete_action_spec,
        batch_size=2,
        train_sequence_length=3)
    data_iterator = iter(dataset_fn(data_source).repeat())

    experience = next(data_iterator)

    # Batches are Trajectories shaped [batch_size=2, train_sequence_length=3].
    self.assertIsInstance(experience, trajectory.Trajectory)
    self.assertAllEqual([2, 3], experience.step_type.shape)
    self.assertCountEqual(['feature_key'], experience.observation.keys())
    self.assertAllEqual([[1, 1, 1], [1, 1, 1]],
                        experience.observation['feature_key'])
    self.assertAllEqual([[0, 0, 0], [0, 0, 0]], experience.action)
    self.assertEmpty(experience.policy_info)
    self.assertAllEqual([2, 3], experience.next_step_type.shape)
    self.assertAllClose([[2.3, 2.3, 2.3], [2.3, 2.3, 2.3]], experience.reward)
    self.assertAllEqual([[1, 1, 1], [1, 1, 1]], experience.discount)

  @parameterized.named_parameters(
      ('SequenceExampleDatasetFn',
       data_reader.create_sequence_example_dataset_fn),
      ('TFRecordDatasetFn', data_reader.create_tfrecord_dataset_fn))
  def test_ppo_policy_info_discrete(self, test_fn):
    self._agent_name = constant.AgentName.PPO
    example = _define_sequence_example(
        self._agent_name, is_action_discrete=True)
    data_source = None
    if test_fn == data_reader.create_sequence_example_dataset_fn:
      data_source = [example.SerializeToString() for _ in range(100)]
    elif test_fn == data_reader.create_tfrecord_dataset_fn:
      data_source = os.path.join(self.get_temp_dir(), 'data_tfrecord')
      _write_tmp_tfrecord(data_source, example, 100)

    dataset_fn = test_fn(
        agent_name=self._agent_name,
        time_step_spec=self._time_step_spec,
        action_spec=self._discrete_action_spec,
        batch_size=2,
        train_sequence_length=3)
    data_iterator = iter(dataset_fn(data_source).repeat())

    experience = next(data_iterator)

    self.assertAllEqual(['feature_key'], list(experience.observation.keys()))
    self.assertAllClose([[[1.2, 3.4], [1.2, 3.4], [1.2, 3.4]],
                         [[1.2, 3.4], [1.2, 3.4], [1.2, 3.4]]],
                        experience.policy_info['dist_params']['logits'])

  @parameterized.named_parameters(
      ('SequenceExampleDatasetFn',
       data_reader.create_sequence_example_dataset_fn),
      ('TFRecordDatasetFn', data_reader.create_tfrecord_dataset_fn))
  def test_ppo_policy_info_continuous(self, test_fn):
    self._agent_name = constant.AgentName.PPO
    example = _define_sequence_example(
        self._agent_name, is_action_discrete=False)
    data_source = None
    if test_fn == data_reader.create_sequence_example_dataset_fn:
      data_source = [example.SerializeToString() for _ in range(100)]
    elif test_fn == data_reader.create_tfrecord_dataset_fn:
      data_source = os.path.join(self.get_temp_dir(), 'data_tfrecord')
      _write_tmp_tfrecord(data_source, example, 100)

    dataset_fn = test_fn(
        agent_name=self._agent_name,
        time_step_spec=self._time_step_spec,
        action_spec=self._continuous_action_spec,
        batch_size=2,
        train_sequence_length=3)
    data_iterator = iter(dataset_fn(data_source).repeat())

    experience = next(data_iterator)

    self.assertAllEqual(['feature_key'], list(experience.observation.keys()))
    self.assertAllClose([[1.23, 1.23, 1.23], [1.23, 1.23, 1.23]],
                        experience.action)
    self.assertAllClose([[1.2, 1.2, 1.2], [1.2, 1.2, 1.2]],
                        experience.policy_info['dist_params']['scale'])
    self.assertAllClose([[3.4, 3.4, 3.4], [3.4, 3.4, 3.4]],
                        experience.policy_info['dist_params']['loc'])


if __name__ == '__main__':
  tf.test.main()