# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for random_network_distillation."""
from absl.testing import parameterized
import tensorflow as tf
from tf_agents.networks import encoding_network
from tf_agents.networks import expand_dims_layer
from tf_agents.trajectories import time_step
from tf_agents.trajectories import trajectory
from compiler_opt.rl import random_net_distillation


def _processing_layer_creator():
  """Returns a callable that creates a toy observation-processing layer."""
  expand_dims_op = expand_dims_layer.ExpandDims(-1)

  def observation_processing_layer(_):
    """Creates a toy layer to process the observation."""

    def discard_feature(obs):
      # Ignore the observation values; emit all-ones of the expanded shape.
      expanded_obs = expand_dims_op(obs)
      return tf.ones_like(expanded_obs, dtype=tf.float32)

    return tf.keras.layers.Lambda(discard_feature)

  return observation_processing_layer
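
# The toy preprocessing layer above discards the observation values and emits
# all-ones of the expanded shape. A usage sketch (values assumed purely for
# illustration):
#   layer = _processing_layer_creator()('edge_count')
#   layer(tf.constant([10, 10], dtype=tf.int64))  # float32 ones, shape [2, 1]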


def _create_test_data(batch_size, sequence_length):
  test_trajectory = trajectory.Trajectory(
      step_type=tf.fill([batch_size, sequence_length], 1),
      observation={
          'edge_count':
              tf.fill([batch_size, sequence_length],
                      tf.constant(10, dtype=tf.int64))
      },
      action=tf.fill([batch_size, sequence_length],
                     tf.constant(1, dtype=tf.int64)),
      policy_info=(),
      next_step_type=tf.fill([batch_size, sequence_length], 1),
      reward=tf.fill([batch_size, sequence_length], 2.0),
      discount=tf.fill([batch_size, sequence_length], 1.0),
  )

  def test_data_iterator():
    while True:
      yield test_trajectory

  return test_data_iterator()
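
# `test_data_iterator` yields the identical batched Trajectory on every
# `next()` call. A usage sketch (shapes follow from the arguments):
#   it = _create_test_data(batch_size=3, sequence_length=3)
#   experience = next(it)  # trajectory.Trajectory with reward shape [3, 3]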


class RandomNetworkDistillationTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(RandomNetworkDistillationTest, self).setUp()
    self._update_frequency = 1
    observation_spec = {
        'edge_count':
            tf.TensorSpec(dtype=tf.int64, shape=(), name='edge_count')
    }
    self._time_step_spec = time_step.time_step_spec(observation_spec)

  def test_train_on_batch(self):
    data_iterator = _create_test_data(batch_size=3, sequence_length=3)
    # Initialize the RandomNetworkDistillation instance.
    random_network_distillation = random_net_distillation.RandomNetworkDistillation(
        time_step_spec=self._time_step_spec,
        preprocessing_layer_creator=_processing_layer_creator(),
        encoding_network=encoding_network.EncodingNetwork,
        update_frequency=self._update_frequency)
    experience = next(data_iterator)

    # Test the RND train function return type.
    for _ in range(5):
      new_experience = random_network_distillation.train(experience)
      self.assertIsInstance(new_experience, trajectory.Trajectory)
      # Everything except the reward should remain unchanged.
      self.assertAllEqual(experience.step_type, new_experience.step_type)
      self.assertAllEqual(experience.observation, new_experience.observation)
      self.assertAllEqual(experience.action, new_experience.action)
      self.assertAllEqual(experience.policy_info, new_experience.policy_info)
      self.assertAllEqual(experience.next_step_type,
                          new_experience.next_step_type)
      self.assertAllEqual(experience.discount, new_experience.discount)
      # The reward should keep its shape.
      self.assertAllEqual(experience.reward.shape, new_experience.reward.shape)
      # The new reward should be finite.
      self.assertFalse(tf.math.is_inf(tf.reduce_sum(new_experience.reward)))
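
  def test_train_on_batch_different_shape(self):
    # A minimal additional sketch, assuming the same constructor arguments as
    # in test_train_on_batch: training should also work on a different batch
    # size and sequence length, and preserve the reward shape.
    data_iterator = _create_test_data(batch_size=2, sequence_length=4)
    random_network_distillation = random_net_distillation.RandomNetworkDistillation(
        time_step_spec=self._time_step_spec,
        preprocessing_layer_creator=_processing_layer_creator(),
        encoding_network=encoding_network.EncodingNetwork,
        update_frequency=self._update_frequency)
    experience = next(data_iterator)
    new_experience = random_network_distillation.train(experience)
    self.assertAllEqual(experience.reward.shape, new_experience.reward.shape)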


if __name__ == '__main__':
  tf.test.main()