Type annotations and typing fixes in `policy_saver.py` (#305)
diff --git a/compiler_opt/rl/policy_saver.py b/compiler_opt/rl/policy_saver.py
index 9ba4742..c285dd4 100644
--- a/compiler_opt/rl/policy_saver.py
+++ b/compiler_opt/rl/policy_saver.py
@@ -23,6 +23,7 @@
import tensorflow as tf
from tf_agents.policies import tf_policy
from tf_agents.policies import policy_saver
+from tf_agents.typing import types as tf_agents_types
from typing import Dict, Tuple
@@ -34,11 +35,11 @@
}
-def _split_tensor_name(name):
+def _split_tensor_name(name: str) -> Tuple[str, int]:
"""Return tuple (op, port) with the op and int port for the tensor name."""
op_port = name.split(':', 2)
if len(op_port) == 1:
- return op_port, 0
+ return op_port[0], 0
else:
return op_port[0], int(op_port[1])
@@ -166,13 +167,9 @@
for policy_name, policy in policy_dict.items()
}
- def _save_policy(self, saver, path):
- """Writes policy, model weights and model_binding.txt to path/."""
- saver.save(path)
-
- def _write_output_signature(self, saver, path):
+ def _write_output_signature(
+ self, action_signature: tf_agents_types.NestedTensorSpec, path: str):
"""Writes the output_signature json file into the SavedModel directory."""
- action_signature = saver.policy_step_spec
# We'll load the actual SavedModel to be able to map signature names to
# actual tensor names.
@@ -234,8 +231,8 @@
"""Writes policy and model_binding.txt to root_dir/policy_name/."""
for policy_name, (saver, _) in self._policy_saver_dict.items():
saved_model_dir = os.path.join(root_dir, policy_name)
- self._save_policy(saver, saved_model_dir)
- self._write_output_signature(saver, saved_model_dir)
+ saver.save(saved_model_dir)
+ self._write_output_signature(saver.policy_step_spec, saved_model_dir)
# This is not quite the most efficient way to do this - we save the model
# just to load it again and save it as tflite - but it's the minimum,
# temporary step so we can validate more thoroughly our use of tflite.