# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Util functions to create and edit a tf_agent policy."""
from typing import Protocol, Sequence, Type
import gin
import numpy as np
import numpy.typing as npt
import tensorflow as tf
from tf_agents.networks import network
from tf_agents.policies import actor_policy
from tf_agents.policies import greedy_policy
from tf_agents.policies import tf_policy

from compiler_opt.rl import policy_saver
from compiler_opt.rl import registry


class HasModelVariables(Protocol):
  """Protocol for objects that expose `model_variables`, e.g. a policy
  SavedModel loaded with `tf.saved_model.load`."""
  model_variables: Sequence[tf.Variable]


# TODO(abenalaast): Issue #280
@gin.configurable(module='policy_utils')
def create_actor_policy(
    actor_network_ctor: Type[network.DistributionNetwork],
    greedy: bool = False,
) -> tf_policy.TFPolicy:
"""Creates an actor policy."""
problem_config = registry.get_configuration()
time_step_spec, action_spec = problem_config.get_signature_spec()
layers = tf.nest.map_structure(
problem_config.get_preprocessing_layer_creator(),
time_step_spec.observation)
actor_network = actor_network_ctor(
input_tensor_spec=time_step_spec.observation,
output_tensor_spec=action_spec,
preprocessing_layers=layers)
policy = actor_policy.ActorPolicy(
time_step_spec=time_step_spec,
action_spec=action_spec,
actor_network=actor_network)
if greedy:
policy = greedy_policy.GreedyPolicy(policy)
return policy
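
# Example usage of `create_actor_policy` (a sketch, not part of the library
# API; assumes gin has bound `actor_network_ctor`, e.g. to a tf_agents
# ActorDistributionNetwork, and that a problem configuration is registered):
#
#   policy = create_actor_policy(greedy=True)
#   action_step = policy.action(time_step)  # `time_step` is a hypothetical
#                                           # tf_agents TimeStep input.
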
def get_vectorized_parameters_from_policy(
    policy: 'tf_policy.TFPolicy | HasModelVariables'
) -> npt.NDArray[np.float32]:
  """Returns a policy's variable values as a single np array."""
  if isinstance(policy, tf_policy.TFPolicy):
    variables = policy.variables()
  elif hasattr(policy, 'model_variables'):
    variables = policy.model_variables
  else:
    raise ValueError('Policy must be a TFPolicy or a loaded SavedModel. '
                     f'Passed policy: {policy}')

  parameters = [var.numpy().flatten() for var in variables]
  parameters = np.concatenate(parameters, axis=0)
  return parameters
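
# Example (a sketch; assumes `policy` was already built, e.g. via
# `create_actor_policy`, so its variables exist):
#
#   params = get_vectorized_parameters_from_policy(policy)
#   # `params` is a rank-1 np.float32 array holding every weight in order.
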
def set_vectorized_parameters_for_policy(
    policy: 'tf_policy.TFPolicy | HasModelVariables',
    parameters: npt.NDArray[np.float32]) -> None:
  """Reshapes the flat `parameters` vector to match the policy's variable
  shapes and assigns those values to the policy's variables."""
  if isinstance(policy, tf_policy.TFPolicy):
    variables = policy.variables()
  elif hasattr(policy, 'model_variables'):
    variables = policy.model_variables
  else:
    raise ValueError('Policy must be a TFPolicy or a loaded SavedModel. '
                     f'Passed policy: {policy}')

  param_pos = 0
  for variable in variables:
    shape = tf.shape(variable).numpy()
    num_elems = np.prod(shape)
    param = np.reshape(parameters[param_pos:param_pos + num_elems], shape)
    variable.assign(param)
    param_pos += num_elems
  if param_pos != len(parameters):
    raise ValueError(
        f'Parameter dimensions do not match: expected {len(parameters)} '
        f'values but the policy consumed {param_pos}.')
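
# Example (a sketch of an ES-style perturbation round trip; `policy` is
# assumed to be a built TFPolicy):
#
#   params = get_vectorized_parameters_from_policy(policy)
#   noise = np.random.standard_normal(params.shape).astype(np.float32)
#   set_vectorized_parameters_for_policy(policy, params + 0.01 * noise)
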
def save_policy(policy: 'tf_policy.TFPolicy | HasModelVariables',
                parameters: npt.NDArray[np.float32], save_folder: str,
                policy_name: str) -> None:
  """Sets the policy's variables to `parameters`, then saves the policy
  under `policy_name` in `save_folder`."""
  set_vectorized_parameters_for_policy(policy, parameters)
  saver = policy_saver.PolicySaver({policy_name: policy})
  saver.save(save_folder)
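
# Example (a sketch; the folder path and policy name are hypothetical):
#
#   save_policy(policy, params, '/tmp/policies/iteration_0', 'policy')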