# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Local ES trainer."""
from absl import flags, logging
import enum
import functools
import gin
import tensorflow as tf
import os
# Pytype cannot pick up the pyi file for tensorflow.summary. Disable the error
# here as these errors are false positives.
# pytype: disable=pyi-error
from compiler_opt.distributed import worker_manager
from compiler_opt.distributed.local import local_worker_manager
from compiler_opt.es import blackbox_optimizers
from compiler_opt.es import gradient_ascent_optimization_algorithms
from compiler_opt.es import blackbox_learner
from compiler_opt.es import policy_utils
from compiler_opt.rl import corpus

FLAGS = flags.FLAGS
_GRAD_REG_ALPHA = flags.DEFINE_float(
"grad_reg_alpha", 0.01,
"Weight of regularization term in regression gradient.")
_GRAD_REG_TYPE = flags.DEFINE_string(
"grad_reg_type", "ridge",
"Regularization method to use with regression gradient.")
_OUTPUT_PATH = flags.DEFINE_string("output_path", "",
                                   "Path to write all output.")
_PRETRAINED_POLICY_PATH = flags.DEFINE_string(
    "pretrained_policy_path", None,
    "The path of the pretrained policy. If not provided, a new policy "
    "with randomly initialized weights will be constructed.")
_REQUEST_DEADLINE = flags.DEFINE_float(
    "request_deadline", 30.0,
    "Deadline in seconds for data collection requests.")
_TRAIN_CORPORA = flags.DEFINE_string("train_corpora", "",
                                     "List of paths to training corpora.")
_NUM_WORKERS = flags.DEFINE_integer("num_workers", 100,
"The number of workers to create.")


@gin.constants_from_enum(module="es_trainer_lib")
class GradientAscentOptimizerType(enum.Enum):
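  # INVALID = 0 acts as a non-selectable placeholder; the usable optimizer
  # types start at 1 via enum.auto().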
INVALID = 0
MOMENTUM = enum.auto()
ADAM = enum.auto()


@gin.configurable
def train(additional_compilation_flags=(),
delete_compilation_flags=(),
replace_compilation_flags=(),
worker_class=None,
beta1=0.9,
beta2=0.999,
momentum=0.0,
gradient_ascent_optimizer_type=GradientAscentOptimizerType.ADAM,
worker_manager_class: type[
worker_manager.WorkerManager] = local_worker_manager
.LocalWorkerPoolManager):
"""Train with ES."""
if not _TRAIN_CORPORA.value:
raise ValueError("Need to supply nonempty train corpora.")
# Create directories
if not tf.io.gfile.isdir(_OUTPUT_PATH.value):
tf.io.gfile.makedirs(_OUTPUT_PATH.value)
  # Construct the policy.
policy = policy_utils.create_actor_policy()
  # Get the initial parameters.
  if not _PRETRAINED_POLICY_PATH.value:
    # Use randomly initialized parameters.
    logging.info("Using randomly initialized parameters")
initial_parameters = policy_utils.get_vectorized_parameters_from_policy(
policy)
logging.info("Parameter dimension: %s", initial_parameters.shape)
logging.info("Initial parameters: %s", initial_parameters)
else:
# Read the parameters from the pretrained policy
logging.info("Reading policy parameters from %s",
_PRETRAINED_POLICY_PATH.value)
# Load the policy
pretrained_policy = tf.saved_model.load(_PRETRAINED_POLICY_PATH.value)
initial_parameters = policy_utils.get_vectorized_parameters_from_policy(
pretrained_policy)
policy_parameter_dimension = (
policy_utils.get_vectorized_parameters_from_policy(policy).shape[0])
    if policy_parameter_dimension != initial_parameters.shape[0]:
      raise ValueError(
          f"Pretrained policy has dimension {initial_parameters.shape[0]}, "
          f"expected {policy_parameter_dimension}.")
logging.info("Parameter dimension: %s", initial_parameters.shape)
logging.info("Initial parameters: %s", initial_parameters)
cps = corpus.Corpus(
data_path=_TRAIN_CORPORA.value,
additional_flags=additional_compilation_flags,
delete_flags=delete_compilation_flags,
replace_flags=replace_compilation_flags)
# Construct policy saver
saved_policy = policy_utils.create_actor_policy()
policy_saver_function = functools.partial(
policy_utils.save_policy,
policy=saved_policy,
save_folder=os.path.join(_OUTPUT_PATH.value, "saved_policies"))
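  # The learner calls this function to save intermediate policies under
  # <output_path>/saved_policies.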
# Get learner config
learner_config = blackbox_learner.BlackboxLearnerConfig()
  # The following are from the blackbox library.
init_current_input = initial_parameters
init_iteration = 0
metaparams = [] # Ignore meta params for state normalization for now
# TODO(linzinan): delete all unused parameters.
# ------------------ GRADIENT ASCENT OPTIMIZERS ------------------------------
if gradient_ascent_optimizer_type == GradientAscentOptimizerType.MOMENTUM:
logging.info("Running momentum gradient ascent optimizer")
# You can obtain a vanilla gradient ascent optimizer by setting momentum=0.0
# and setting step_size to the desired learning rate.
gradient_ascent_optimizer = (
gradient_ascent_optimization_algorithms.MomentumOptimizer(
learner_config.step_size, momentum))
elif gradient_ascent_optimizer_type == GradientAscentOptimizerType.ADAM:
logging.info("Running Adam gradient ascent optimizer")
gradient_ascent_optimizer = (
gradient_ascent_optimization_algorithms.AdamOptimizer(
learner_config.step_size, beta1, beta2))
else:
logging.info("No gradient ascent optimizer selected. Stopping.")
return
# ----------------------------------------------------------------------------
# ------------------ OPTIMIZERS ----------------------------------------------
if learner_config.blackbox_optimizer == (
blackbox_optimizers.Algorithm.MONTE_CARLO):
logging.info("Running ES/ARS. Filtering: %s directions",
str(learner_config.num_top_directions))
blackbox_optimizer = blackbox_optimizers.MonteCarloBlackboxOptimizer(
learner_config.precision_parameter, learner_config.estimator_type,
learner_config.fvalues_normalization,
learner_config.hyperparameters_update_method, metaparams, None,
learner_config.num_top_directions, gradient_ascent_optimizer)
elif learner_config.blackbox_optimizer == (
blackbox_optimizers.Algorithm.TRUST_REGION):
logging.info("Running trust region")
tr_params = {
"init_radius": FLAGS.tr_init_radius,
"grow_threshold": FLAGS.tr_grow_threshold,
"grow_factor": FLAGS.tr_grow_factor,
"shrink_neg_threshold": FLAGS.tr_shrink_neg_threshold,
"shrink_factor": FLAGS.tr_shrink_factor,
"reject_threshold": FLAGS.tr_reject_threshold,
"reject_factor": FLAGS.tr_reject_factor,
"dense_hessian": FLAGS.tr_dense_hessian,
"sub_termination": FLAGS.tr_sub_termination,
"subproblem_maxiter": FLAGS.tr_subproblem_maxiter,
"minimum_radius": FLAGS.tr_minimum_radius,
"grad_type": FLAGS.grad_type,
"grad_reg_type": _GRAD_REG_TYPE.value,
"grad_reg_alpha": _GRAD_REG_ALPHA.value
}
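    # Log the trust region hyperparameters for reproducibility.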
for param, value in tr_params.items():
logging.info("%s: %s", param, value)
blackbox_optimizer = blackbox_optimizers.TrustRegionOptimizer(
learner_config.precision_parameter, learner_config.estimator_type,
learner_config.fvalues_normalization,
learner_config.hyperparameters_update_method, metaparams, tr_params)
elif learner_config.blackbox_optimizer == (
blackbox_optimizers.Algorithm.SKLEARN_REGRESSION):
logging.info("Running Regression Based Optimizer")
blackbox_optimizer = blackbox_optimizers.SklearnRegressionBlackboxOptimizer(
_GRAD_REG_TYPE.value, _GRAD_REG_ALPHA.value,
learner_config.estimator_type, learner_config.fvalues_normalization,
learner_config.hyperparameters_update_method, metaparams, None,
gradient_ascent_optimizer)
else:
raise ValueError(
f"Unknown optimizer: '{learner_config.blackbox_optimizer}'")
logging.info("Initializing blackbox learner.")
learner = blackbox_learner.BlackboxLearner(
blackbox_opt=blackbox_optimizer,
train_corpus=cps,
output_dir=_OUTPUT_PATH.value,
policy_saver_fn=policy_saver_function,
model_weights=init_current_input,
config=learner_config,
initial_step=init_iteration,
deadline=_REQUEST_DEADLINE.value)
if not worker_class:
logging.info("No Worker class selected. Stopping.")
return
logging.info("Ready to train: running for %d steps.",
learner_config.total_steps)
with worker_manager_class(
worker_class,
count=_NUM_WORKERS.value,
worker_kwargs=dict(gin_config=gin.operative_config_str())) as pool:
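    # Compute the reward baseline before taking any training steps.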
learner.set_baseline(pool)
for _ in range(learner_config.total_steps):
learner.run_step(pool)
return learner.get_model_weights()