# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""util function to save the policy and model config file."""
import json
import os
import tensorflow as tf
from tf_agents.policies import policy_saver

OUTPUT_SIGNATURE = "output_spec.json"
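
# Maps TF dtypes to the C-style type names recorded in output_spec.json.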
_TYPE_CONVERSION_DICT = {
tf.float32: "float",
tf.float64: "double",
tf.int8: "int8_t",
tf.uint8: "uint8_t",
tf.int16: "int16_t",
tf.uint16: "uint16_t",
tf.int32: "int32_t",
tf.uint32: "uint32_t",
tf.int64: "int64_t",
tf.uint64: "uint64_t",
}


def _split_tensor_name(name):
"""Return tuple (op, port) with the op and int port for the tensor name."""
op_port = name.split(":", 2)
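  # e.g. "StatefulPartitionedCall:0" splits into ["StatefulPartitionedCall",
  # "0"], while a port-less name such as "action" yields a single element.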
if len(op_port) == 1:
    return op_port[0], 0
else:
return op_port[0], int(op_port[1])


# TODO(b/156295309): more sophisticated way of finding tensor names.
def _get_non_identity_op(tensor):
"""Get the true output op aliased by Identity `tensor`.
Output signature tensors are in a Function that refrences the true call
in the base SavedModel metagraph. Traverse the function upwards until
we find this true output op and tensor and return that.
Args:
tensor: A tensor from the unstructured output list of a signature.
Returns:
The true associated output tensor of the original function in the main
SavedModel graph.
"""
while tensor.op.name.startswith("Identity"):
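    # Identity ops simply forward their single input, so step back through
    # the chain until the op that actually produces the value is reached.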
tensor = tensor.op.inputs[0]
return tensor


class PolicySaver(object):
"""Object that saves policy and model config file required by inference.
```python
  policy_saver = PolicySaver(policy_dict)
policy_saver.save(root_dir)
```
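
  `policy_dict` maps a policy name (the subdirectory created under `root_dir`)
  to a TF-Agents policy. An illustrative sketch, assuming some previously
  constructed TF-Agents agent `tf_agent` (hypothetical name):

  ```python
  # `tf_agent` is a hypothetical, previously constructed TF-Agents agent.
  policy_dict = {
      "policy": tf_agent.policy,
      "collect_policy": tf_agent.collect_policy,
  }
  ```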
"""

  def __init__(self, policy_dict):
"""Initialize the PolicySaver object.
Args:
policy_dict: A dict mapping from policy name to policy.
"""
self._policy_saver_dict = {
policy_name: (policy_saver.PolicySaver(
policy, batch_size=1, use_nest_path_signatures=False), policy)
for policy_name, policy in policy_dict.items()
}

  def _save_policy(self, saver, path):
    """Writes the policy SavedModel and its weights to path/."""
saver.save(path)

  def _write_output_signature(self, saver, path):
"""Writes the output_signature json file into the SavedModel directory."""
action_signature = saver.policy_step_spec
# We'll load the actual SavedModel to be able to map signature names to
# actual tensor names.
saved_model = tf.saved_model.load(path)
    # Flattened list of specs from the action signature's structured outputs.
    sm_action_signature = (
        tf.nest.flatten(saved_model.signatures["action"].structured_outputs))
# Map spec name to index in flattened outputs.
sm_action_indices = dict(
(k.name, i) for i, k in enumerate(sm_action_signature))
# List mapping flattened structured outputs to tensors.
sm_action_tensors = saved_model.signatures["action"].outputs
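    # Together these let each spec name be looked up and resolved to the
    # concrete output tensor produced by the "action" signature.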
    # The first entry in the output list is the decision (the action itself).
decision_spec = tf.nest.flatten(action_signature.action)
if len(decision_spec) != 1:
raise ValueError(
"Expected action decision to have 1 tensor, but saw: {}".format(
action_signature.action))
# Find the decision's tensor in the flattened output tensor list.
sm_action_decision = (
sm_action_tensors[sm_action_indices[decision_spec[0].name]])
sm_action_decision = _get_non_identity_op(sm_action_decision)
# The first entry in the output_signature file corresponds to the decision.
(tensor_op, tensor_port) = _split_tensor_name(sm_action_decision.name)
output_list = [{
"logging_name": decision_spec[0].name, # used in SequenceExample.
"tensor_spec": {
"name": tensor_op,
"port": tensor_port,
"type": _TYPE_CONVERSION_DICT[sm_action_decision.dtype],
"shape": sm_action_decision.shape.as_list(),
}
}]
for info_spec in tf.nest.flatten(action_signature.info):
sm_action_info = sm_action_tensors[sm_action_indices[info_spec.name]]
sm_action_info = _get_non_identity_op(sm_action_info)
(tensor_op, tensor_port) = _split_tensor_name(sm_action_info.name)
output_list.append({
"logging_name": info_spec.name, # used in SequenceExample.
"tensor_spec": {
"name": tensor_op,
"port": tensor_port,
"type": _TYPE_CONVERSION_DICT[sm_action_info.dtype],
"shape": sm_action_info.shape.as_list(),
}
})
with tf.io.gfile.GFile(os.path.join(path, OUTPUT_SIGNATURE), "w") as f:
f.write(json.dumps(output_list))

  def save(self, root_dir):
    """Writes each policy and output_spec.json to root_dir/policy_name/."""
for policy_name, (saver, _) in self._policy_saver_dict.items():
self._save_policy(saver, os.path.join(root_dir, policy_name))
self._write_output_signature(saver, os.path.join(root_dir, policy_name))