blob: fee9a49af4cb2c53d963f79fcb8bc40457f942f1 [file] [log] [blame]
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for compiler_opt.rl.policy_saver."""
import json
import os
import tensorflow as tf
from tf_agents.agents.behavioral_cloning import behavioral_cloning_agent
from tf_agents.networks import q_rnn_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step
from compiler_opt.rl import policy_saver
class PolicySaverTest(tf.test.TestCase):
  """Tests for `policy_saver.PolicySaver`."""

  def setUp(self):
    """Builds the specs and Q-RNN network shared by the tests.

    The observation is a scalar int64 feature named 'callee_users'; the action
    is a scalar int64 named 'inlining_decision' bounded to [0, 1].
    """
    super().setUp()
    observation_spec = tf.TensorSpec(
        dtype=tf.int64, shape=(), name='callee_users')
    self._time_step_spec = time_step.time_step_spec(observation_spec)
    self._action_spec = tensor_spec.BoundedTensorSpec(
        dtype=tf.int64,
        shape=(),
        minimum=0,
        maximum=1,
        name='inlining_decision')
    self._network = q_rnn_network.QRnnNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec,
        lstm_size=(40,))

  def test_save_policy(self):
    """Saves two policies and checks the files written for each.

    Every key of the policy dict is expected to become a sub-directory of the
    save root. Each one holds `saved_model.pb`, the variables data shard and
    an `output_spec.json` describing the single action output.
    """
    test_agent = behavioral_cloning_agent.BehavioralCloningAgent(
        self._time_step_spec, self._action_spec, self._network,
        tf.compat.v1.train.AdamOptimizer())
    policy_dict = {
        'saved_policy': test_agent.policy,
        'saved_collect_policy': test_agent.collect_policy,
    }
    test_policy_saver = policy_saver.PolicySaver(policy_dict=policy_dict)
    root_dir = self.get_temp_dir()
    test_policy_saver.save(root_dir)

    # Derive the expected directory names from policy_dict so the assertions
    # cannot drift out of sync with the policies actually saved.
    sub_dirs = tf.io.gfile.listdir(root_dir)
    self.assertCountEqual(list(policy_dict), sub_dirs)

    expected_output_spec = [{
        'logging_name': 'inlining_decision',
        'tensor_spec': {
            'name': 'StatefulPartitionedCall',
            'port': 0,
            'type': 'int64_t',
            'shape': [1],
        }
    }]
    for sub_dir in policy_dict:
      policy_dir = os.path.join(root_dir, sub_dir)

      saved_model_path = os.path.join(policy_dir, 'saved_model.pb')
      self.assertTrue(
          tf.io.gfile.exists(saved_model_path),
          msg='missing %s' % saved_model_path)

      variables_path = os.path.join(policy_dir, 'variables',
                                    'variables.data-00000-of-00001')
      self.assertTrue(
          tf.io.gfile.exists(variables_path),
          msg='missing %s' % variables_path)

      output_signature_fn = os.path.join(policy_dir, 'output_spec.json')
      self.assertTrue(
          tf.io.gfile.exists(output_signature_fn),
          msg='missing %s' % output_signature_fn)
      # Use a context manager so the file handle is closed deterministically
      # instead of being leaked until garbage collection.
      with tf.io.gfile.GFile(output_signature_fn) as spec_file:
        output_spec = json.load(spec_file)
      self.assertEqual(expected_output_spec, output_spec)
if __name__ == '__main__':
  # Only runs when this file is executed directly (not on import): hand
  # control to TensorFlow's test entry point to run the test cases.
  tf.test.main()