blob: 11a458ecfc424734f7086b8c3c2e3bd13467a52c [file] [log] [blame]
"""LR encoder training config."""
import gin
import tensorflow as tf
import tf_agents
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step
from google3.third_party.ml_compiler_opt.compiler_opt.rl import feature_ops
from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc import config
_NUM_REGISTERS = 33
_NUM_INSTRUCTIONS = 64
_OPCODE_VOCAB_SIZE = 20000
_ENCODING_SIZE = 16
_OPCODE_KEY = 'lr_use_def_opcode'
_ENCODER_FEATURE_PREFIX = 'lr_use_def'
_ACTION_KEY = 'action'
_STATE_KEY = 'state'
_NEXT_STATE_KEY = 'next_state'
_NEXT_ACTION_KEY = 'next_action'
_MLM_KEY = 'mlm'
@gin.configurable
def get_lr_encoder_signature_spec():
observation_spec = {}
observation_spec[_OPCODE_KEY] = tensor_spec.BoundedTensorSpec(
dtype=tf.int64,
shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS),
name=_OPCODE_KEY,
minimum=0,
maximum=_OPCODE_VOCAB_SIZE,
)
observation_spec['lr_use_def_read'] = tensor_spec.BoundedTensorSpec(
dtype=tf.int64,
shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS),
name='lr_use_def_read',
minimum=0,
maximum=1,
)
observation_spec['lr_use_def_write'] = tensor_spec.BoundedTensorSpec(
dtype=tf.int64,
shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS),
name='lr_use_def_write',
minimum=0,
maximum=1,
)
observation_spec.update(
{
name: tensor_spec.BoundedTensorSpec(
dtype=tf.int64,
shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS),
name=name,
minimum=0,
maximum=1,
)
for name in [
'lr_use_def_is_use',
'lr_use_def_is_def',
'lr_use_def_is_implicit',
'lr_use_def_is_renamable',
'lr_use_def_is_ind_var_update',
'lr_use_def_is_hint',
]
}
)
observation_spec['lr_use_def_freq'] = tf.TensorSpec(
dtype=tf.float32,
shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS),
name='lr_use_def_freq',
)
encoding_spec = tf.TensorSpec(
dtype=tf.float32, shape=(33, _ENCODING_SIZE), name='lr_encoding'
)
return observation_spec, encoding_spec
def get_input_specs():
encoder_observation_spec, encoding_spec = get_lr_encoder_signature_spec()
del encoding_spec
regalloc_time_step_spec, regalloc_action_spec = (
config.get_regalloc_signature_spec()
)
# Ensure that there are no overlaps in feature names between the encoder and the RL problem
common_keys = (
encoder_observation_spec.keys()
& regalloc_time_step_spec.observation.keys()
)
assert len(common_keys) == 0
supervised_input_spec = {
**encoder_observation_spec,
**regalloc_time_step_spec.observation,
}
return supervised_input_spec
def get_output_specs(regalloc_preprocessing_layer_creator):
regalloc_time_step_spec, regalloc_action_spec = (
config.get_regalloc_signature_spec()
)
del regalloc_action_spec
random_regalloc_state = tf_agents.specs.sample_spec_nest(
regalloc_time_step_spec.observation
)
for key in random_regalloc_state:
preprocessing_layer = regalloc_preprocessing_layer_creator(
regalloc_time_step_spec.observation[key]
)
random_regalloc_state[key] = preprocessing_layer(
tf.expand_dims(random_regalloc_state[key], axis=0)
)
random_regalloc_state = tf.concat(
list(random_regalloc_state.values()), axis=-1
)
random_regalloc_state = tf.squeeze(random_regalloc_state, axis=0)
regalloc_state_shape = random_regalloc_state.shape
action_spec = tf.TensorSpec(
dtype=tf.float32, shape=(_NUM_REGISTERS, 1), name=_ACTION_KEY
)
state_spec = tf.TensorSpec(
dtype=tf.float32, shape=regalloc_state_shape, name=_STATE_KEY
)
mlm_spec = tf.TensorSpec(
dtype=tf.float32,
shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS, _OPCODE_VOCAB_SIZE),
name=_MLM_KEY,
)
return {
_ACTION_KEY: action_spec,
_STATE_KEY: state_spec,
_NEXT_ACTION_KEY: action_spec,
_NEXT_STATE_KEY: state_spec,
_MLM_KEY: mlm_spec,
}
def get_loss():
loss_fns = {
_STATE_KEY: 'mse',
_ACTION_KEY: tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True
),
_NEXT_STATE_KEY: 'mse',
_NEXT_ACTION_KEY: tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True
),
_MLM_KEY: tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
}
loss_weights = {
_ACTION_KEY: 1.0,
_STATE_KEY: 1.0,
_NEXT_ACTION_KEY: 1.0,
_NEXT_STATE_KEY: 1.0,
_MLM_KEY: 10.0,
}
return loss_fns, loss_weights
def get_metrics():
return {
_ACTION_KEY: ['accuracy'],
_STATE_KEY: [],
_NEXT_ACTION_KEY: ['accuracy'],
_NEXT_STATE_KEY: [],
_MLM_KEY: ['accuracy'],
}
@gin.configurable
def get_preprocessing_layer_creator(
quantile_file_dir='/cns/oz-d/home/mlcompileropt-dev/regalloc-transformer/vocab',
with_sqrt=True,
with_z_score_normalization=True,
eps=1e-8,
):
"""Wrapper for observation_processing_layer."""
quantile_map = feature_ops.build_quantile_map(quantile_file_dir)
def preprocessing_layer_creator(obs_spec):
"""Creates the layer to process observation given obs_spec."""
if obs_spec.name == 'lr_use_def_freq':
quantile = quantile_map[obs_spec.name]
first_non_zero = 0
for x in quantile:
if x > 0:
first_non_zero = x
break
normalize_fn = feature_ops.get_normalize_fn(
quantile, with_sqrt, with_z_score_normalization, eps
)
return tf.keras.layers.Lambda(normalize_fn)
return tf.keras.layers.Lambda(feature_ops.identity_fn)
return preprocessing_layer_creator
def get_nonnormalized_features():
return [
'lr_use_def_opcode',
'lr_use_def_read',
'lr_use_def_write',
'lr_use_def_is_use',
'lr_use_def_is_def',
'lr_use_def_is_implicit',
'lr_use_def_is_renamable',
'lr_use_def_is_ind_var_update',
'lr_use_def_is_hint',
]