blob: 160e35d79a65d2105871771ce712bf707d69f994 [file] [log] [blame]
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tf_agents.networks.actor_distribution_network."""
import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.specs import tensor_spec
from compiler_opt.rl import regalloc_network
from compiler_opt.rl.regalloc import config
def _observation_processing_layer(obs_spec):
  """Returns a Keras Lambda layer that preprocesses one observation feature.

  Args:
    obs_spec: the spec of the observation feature; only its `name` is read.

  Returns:
    A `tf.keras.layers.Lambda` wrapping the per-feature transformation: the
    input is cast to float32 and given a new trailing axis. When the spec is
    named 'progress', the input is first given a trailing axis and tiled
    `config.get_num_registers()` times along it.
  """

  def expand_progress(obs):
    if obs_spec.name == 'progress':
      # The tile multiples have two entries, so the expanded tensor is
      # expected to be rank 2 here (presumably [batch, 1] -- confirm against
      # how the network feeds its preprocessing layers).
      with_register_axis = tf.expand_dims(obs, -1)
      feature = tf.tile(with_register_axis, [1, config.get_num_registers()])
    else:
      feature = obs
    as_float = tf.cast(feature, tf.float32)
    return tf.expand_dims(as_float, -1)

  return tf.keras.layers.Lambda(expand_progress)
class RegAllocNetworkTest(tf.test.TestCase):
  """Checks that RegAllocNetwork can be built and produces valid actions."""

  def setUp(self):
    """Creates the regalloc specs and a randomly sampled batch of time steps.

    The base-class setUp is invoked first so that the fixtures below are
    created only after tf.test.TestCase has finished its own per-test
    initialization.
    NOTE(review): in current TensorFlow that initialization is believed to
    include resetting random seeds and the default graph -- confirm against
    the installed version.
    """
    super(RegAllocNetworkTest, self).setUp()
    time_step_spec, action_spec = config.get_regalloc_signature_spec()
    self._time_step_spec = time_step_spec
    self._action_spec = action_spec
    # outer_dims=(2, 3) is what testBuilds relies on when it expects the
    # mode of the resulting distribution to have shape [2, 3].
    self._random_observation = tensor_spec.sample_spec_nest(
        time_step_spec, outer_dims=(2, 3))

  def testBuilds(self):
    """Builds the network and checks type, shape and range of its output."""
    # One preprocessing layer per observation feature, mirroring the
    # structure of the observation spec.
    layers = tf.nest.map_structure(_observation_processing_layer,
                                   self._time_step_spec.observation)
    net = regalloc_network.RegAllocNetwork(
        self._time_step_spec.observation,
        self._action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(10,))

    action_distributions, _ = net(
        self._random_observation.observation,
        step_type=self._random_observation.step_type)

    self.assertIsInstance(action_distributions, tfp.distributions.Categorical)
    # Compute the mode once and reuse it for both the shape and range checks.
    mode = action_distributions.mode()
    self.assertEqual([2, 3], mode.shape.as_list())
    self.assertAllInRange(mode, 0, config.get_num_registers() - 1)
# Run the TensorFlow test runner only when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()