blob: a953aa4e280ec2e4e87636afaa1e4b248415d684 [file] [log] [blame]
from typing import Any, Callable

import gin
import tensorflow as tf
from tf_agents.utils import nest_utils

from google3.third_party.ml_compiler_opt.compiler_opt.rl import attention
from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import config
class MultiHeadEncoderModel(tf.keras.Model):
  """Runs one shared encoder and feeds its outputs to several named heads.

  Each head is looked up by name and receives one of three inputs:
    * the head named 'mlm' gets the per-token encoder output;
    * heads whose name starts with 'next_' get the encoded state concatenated
      with the one-hot action along the last axis;
    * every other head gets the encoded state on its own.
  """

  def __init__(self, encoder, heads):
    """Stores the encoder and the name -> head mapping.

    Args:
      encoder: callable taking a dict of observations and returning a
        `(encoded_state_per_token, encoded_state)` pair.
      heads: mapping from head name to a callable applied to that head's
        input.
    """
    super().__init__(name='MultiHeadEncoderModel')
    self._encoder = encoder
    self._heads = heads

  def get_encoder(self):
    """Returns the encoder passed at construction time."""
    return self._encoder

  def call(self, inputs, *args):
    """Encodes the observation and evaluates every head.

    Args:
      inputs: mapping with an 'obs' entry (itself a mapping from feature name
        to value) and an 'action' entry.
      *args: accepted and ignored.

    Returns:
      A dict mapping each head name to that head's output.
    """
    obs = inputs['obs']

    # One-hot encode the action, then append a trailing axis so it can be
    # concatenated with the encoded state below.
    # NOTE(review): presumably yields [batch, num_registers, 1] - confirm.
    one_hot_action = tf.one_hot(inputs['action'], depth=config._NUM_REGISTERS)
    one_hot_action = one_hot_action[:, :, tf.newaxis]

    # Only features carrying the encoder prefix are handed to the encoder.
    encoder_features = {}
    for feature_name, feature in obs.items():
      if feature_name.startswith(config._ENCODER_FEATURE_PREFIX):
        encoder_features[feature_name] = feature

    per_token_encoding, state_encoding = self._encoder(encoder_features)
    # Built unconditionally, even when no 'next_*' head is registered.
    state_and_action = tf.concat([state_encoding, one_hot_action], axis=-1)

    def select_head_input(head_name):
      if head_name == 'mlm':
        return per_token_encoding
      return (
          state_and_action
          if head_name.startswith('next_')
          else state_encoding
      )

    return {
        head_name: head(select_head_input(head_name))
        for head_name, head in self._heads.items()
    }
@gin.configurable
class LiveRangeEncoder(tf.keras.Model):
  """Encoder built around `attention.TransformerClassifier`.

  Each observation listed in `input_specs` goes through its own preprocessing
  layer; the results are concatenated along the last axis and passed to the
  transformer together with the opcode tokens. The second transformer output
  is then projected by a `Dense(config._ENCODING_SIZE)` layer.
  """

  def __init__(
      self,
      input_specs,
      preprocessing_layer_creator,
      *,
      num_layers=2,
      num_heads=2,
      model_dim=32,
      fcn_dim=128,
      num_extra_features=10,
  ):
    """Builds the transformer, the output projection and the preprocessors.

    Args:
      input_specs: iterable of observation keys (iterating a dict yields its
        keys); one preprocessing layer is created per key.
      preprocessing_layer_creator: callable invoked as
        `preprocessing_layer_creator(key)`, returning the layer applied to
        `observations[key]`.
      num_layers: forwarded to `attention.TransformerClassifier`.
      num_heads: forwarded to `attention.TransformerClassifier`.
      model_dim: forwarded to `attention.TransformerClassifier`.
      fcn_dim: forwarded to `attention.TransformerClassifier`.
      num_extra_features: forwarded to `attention.TransformerClassifier`.
    """
    super().__init__(name='LiveRangeEncoder')

    self._encoder = attention.TransformerClassifier(
        num_tokens=config._OPCODE_VOCAB_SIZE,
        num_layers=num_layers,
        num_heads=num_heads,
        model_dim=model_dim,
        fcn_dim=fcn_dim,
        num_extra_features=num_extra_features,
    )
    self._linear_reshape = tf.keras.layers.Dense(config._ENCODING_SIZE)
    self._preprocessing_layers = {
        spec_key: preprocessing_layer_creator(spec_key)
        for spec_key in input_specs
    }

    def concat_preprocessed(observations):
      # Apply each key's layer in `input_specs` iteration order, then join
      # the results along the last axis.
      processed = [
          self._preprocessing_layers[spec_key](observations[spec_key])
          for spec_key in input_specs
      ]
      return tf.concat(processed, axis=-1)

    self._preprocessor = concat_preprocessed

  def call(self, observations):
    """Encodes a dict of observations.

    Args:
      observations: mapping containing every key of `input_specs` plus the
        `config._OPCODE_KEY` entry holding the tokens.

    Returns:
      A pair `(encoded_state_per_token, encoded_state)`: the first transformer
      output unchanged, and the second one after the dense projection.
    """
    extra_features = self._preprocessor(observations)
    opcode_tokens = observations[config._OPCODE_KEY]
    per_token_encoding, pooled_encoding = self._encoder(
        opcode_tokens, extra_features
    )
    return per_token_encoding, self._linear_reshape(pooled_encoding)
def create_model(
    input_specs: dict,
    output_specs: dict,
    preprocessing_layer_creator: Callable[[str], Callable[..., Any]],
    *,
    mlp_width: int = 128,
) -> 'MultiHeadEncoderModel':
  """Builds a `MultiHeadEncoderModel` with one MLP head per output spec.

  Args:
    input_specs: mapping whose keys name the observations preprocessed by the
      `LiveRangeEncoder`.
    output_specs: mapping from head name to a spec exposing `.shape`; the last
      dimension of that shape is the head's output width.
    preprocessing_layer_creator: callable (not a dict) forwarded to
      `LiveRangeEncoder`, which invokes it once per `input_specs` key to
      obtain that key's preprocessing layer.
    mlp_width: number of units in each head's hidden `Dense` layer.

  Returns:
    A `MultiHeadEncoderModel` wrapping a new `LiveRangeEncoder` and the
    created heads.
  """
  encoder = LiveRangeEncoder(input_specs, preprocessing_layer_creator)

  output_heads = {}
  for name, spec in output_specs.items():
    # The head's final layer matches the last dimension of its output spec.
    size = spec.shape[-1]
    output_heads[name] = tf.keras.Sequential([
        tf.keras.layers.Dense(mlp_width),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(size),
    ])

  return MultiHeadEncoderModel(encoder=encoder, heads=output_heads)