[nfc] alias tf_agents.policies.policy_saver as tf_policy_saver (#528)
diff --git a/compiler_opt/rl/policy_saver.py b/compiler_opt/rl/policy_saver.py
index fc88e75..abf8aae 100644
--- a/compiler_opt/rl/policy_saver.py
+++ b/compiler_opt/rl/policy_saver.py
@@ -21,7 +21,7 @@
import tensorflow as tf
from tf_agents.policies import tf_policy
-from tf_agents.policies import policy_saver
+from tf_agents.policies import policy_saver as tf_policy_saver
from tf_agents.typing import types as tf_agents_types
OUTPUT_SIGNATURE = 'output_spec.json'
@@ -158,8 +158,8 @@
policy_dict: A dict mapping from policy name to policy.
"""
self._policy_saver_dict: dict[str, tuple[
- policy_saver.PolicySaver, tf_policy.TFPolicy]] = {
- policy_name: (policy_saver.PolicySaver(
+ tf_policy_saver.PolicySaver, tf_policy.TFPolicy]] = {
+ policy_name: (tf_policy_saver.PolicySaver(
policy, batch_size=1, use_nest_path_signatures=False), policy
) for policy_name, policy in policy_dict.items()
}