blob: 9967f16a2bb2aa9e2c865752e080d9541ff6e554 [file] [log] [blame]
import tensorflow as tf
import tensorflow_text as text
from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import config
# Upper bound on masked positions: passed as `max_selections_per_batch` to the
# random item selector and used as the padded width of the masked-position
# tensor in `_get_masked_language_fn`.
_MAX_PREDICTIONS_PER_BATCH = 128
# Padding id: stripped when converting opcodes to a ragged tensor, excluded
# from masking, and reused as the replacement index for padded (negative)
# masked positions.
_PAD_TOKEN = 0
# NOTE(review): not referenced anywhere in the visible code; presumably a
# label id meant to be ignored by the MLM loss -- confirm or remove.
_MLM_IGNORE_TOKEN = -1
# Token id handed to `text.MaskValuesChooser` as the mask token.
# NOTE(review): presumably the last id of an 18000-entry opcode vocabulary --
# confirm against `config._OPCODE_VOCAB_SIZE`.
_MLM_MASK_TOKEN = 18000 - 1
def _get_masked_language_fn(*, selection_rate=0.2):
  """Builds a function that applies masked-language-model masking to opcodes.

  Args:
    selection_rate: Rate forwarded to `text.RandomItemSelector` that controls
      how many selectable tokens are chosen for masking. This argument used to
      be accepted but ignored (0.2 was always used); it is now honored, and
      its default is 0.2 so that callers relying on the default keep the exact
      masking rate they were effectively getting.

  Returns:
    A callable `fn(opcodes)` that returns a 3-tuple:
      * the masked token ids, densified to
        `(config._NUM_REGISTERS, config._NUM_INSTRUCTIONS)` and padded with 0;
      * the unmodified `opcodes` (used by the caller as the MLM labels);
      * a float weight tensor of shape
        `(config._NUM_REGISTERS, config._NUM_INSTRUCTIONS, 1)` that is
        non-zero at the positions selected for masking.
  """
  random_selector = text.RandomItemSelector(
      max_selections_per_batch=_MAX_PREDICTIONS_PER_BATCH,
      selection_rate=selection_rate,
      unselectable_ids=[_PAD_TOKEN],
  )
  # Third positional argument (0.8) is the rate at which a selected item is
  # replaced by the mask token rather than a random/unchanged token.
  mask_values_chooser = text.MaskValuesChooser(
      config._OPCODE_VOCAB_SIZE, _MLM_MASK_TOKEN, 0.8
  )

  def fn(opcodes):
    # Padding is stripped first so that pad entries can never be selected.
    # The third result (the original ids at masked positions) is not needed:
    # the full `opcodes` tensor is returned as the label instead.
    masked_token_ids, masked_pos, _ = text.mask_language_model(
        tf.RaggedTensor.from_tensor(opcodes, padding=_PAD_TOKEN),
        item_selector=random_selector,
        mask_values_chooser=mask_values_chooser,
    )
    # Densify the ragged per-row masked positions; -1 marks "no selection".
    masked_pos = masked_pos.to_tensor(
        default_value=-1,
        shape=(config._NUM_REGISTERS, _MAX_PREDICTIONS_PER_BATCH),
    )
    # Row index for every (row, prediction-slot) entry, so that each masked
    # column position can be paired with its row to form a 2-D scatter index.
    ii = tf.tile(
        tf.range(config._NUM_REGISTERS, dtype=tf.int64)[:, tf.newaxis],
        [1, _MAX_PREDICTIONS_PER_BATCH],
    )
    masked_pos = tf.stack([ii, masked_pos], axis=-1)
    # Real selections scatter 1.0, padded slots scatter 0.0.
    scatter_values = tf.where(masked_pos[:, :, 1] < 0, 0.0, 1.0)
    # Negative (padded) column indices are not valid scatter targets; redirect
    # them to index `_PAD_TOKEN` (0). They carry a value of 0.0, so they do
    # not change the weight accumulated there.
    masked_pos = tf.where(
        masked_pos < 0, tf.constant(_PAD_TOKEN, dtype=tf.int64), masked_pos
    )
    weights = tf.scatter_nd(
        masked_pos,
        scatter_values[:, :, tf.newaxis],
        (config._NUM_REGISTERS, config._NUM_INSTRUCTIONS, 1),
    )
    return (
        masked_token_ids.to_tensor(
            default_value=0,
            shape=(config._NUM_REGISTERS, config._NUM_INSTRUCTIONS),
        ),
        opcodes,
        weights,
    )

  return fn
def _roll_experience(seq_ex, *, shift=-1):
  """Pairs every step of a sequence with the step `shift` positions away.

  Each tensor in `seq_ex` is rolled along axis 0 by `shift`, and both the
  original and the rolled structure are then sliced with `[:shift]`. With the
  default `shift=-1` this drops the final step, whose rolled partner would be
  the wrapped-around first step.

  NOTE(review): the trimming is written as `[:shift]`, so it only removes the
  wrapped-around entries for negative shifts; `shift=0` yields empty tensors.

  Args:
    seq_ex: A (possibly nested) structure of tensors whose axis 0 is the one
      being rolled.
    shift: Roll amount passed to `tf.roll` and end bound of the slice.

  Returns:
    A tuple `(trimmed, rolled)` with the same structure as `seq_ex`.
  """

  def _trim(atom):
    return atom[:shift]

  def _roll_then_trim(atom):
    rolled = tf.roll(atom, shift=shift, axis=0)
    return rolled[:shift]

  rolled_ex = tf.nest.map_structure(_roll_then_trim, seq_ex)
  trimmed_ex = tf.nest.map_structure(_trim, seq_ex)
  return trimmed_ex, rolled_ex
def _split_sequence_example(seq_ex, seq_ex_roll):
action_name = 'index_to_evict'
obs = {k: seq_ex[k] for k in seq_ex if k != action_name}
action = seq_ex[action_name]
obs_roll = {k: seq_ex_roll[k] for k in seq_ex_roll if k != action_name}
action_roll = seq_ex_roll[action_name]
return {
'obs': obs,
'action': action,
'obs_roll': obs_roll,
'action_roll': action_roll,
}
def _get_state_preprocessing_layer(
    regalloc_input_spec, preprocessing_layer_creator
):
  """Builds a function that encodes current and rolled observations.

  Args:
    regalloc_input_spec: Mapping of feature name to spec. One preprocessing
      layer is created per entry.
    preprocessing_layer_creator: Callable that takes a spec and returns a
      callable layer for the corresponding feature.

  Returns:
    A callable `_layer(obs_dict)` that adds 'obs_cur' (the preprocessed
    'obs') to `obs_dict`, replaces 'obs_roll' with its preprocessed form, and
    returns the same (mutated) dict. The raw 'obs' entry is left untouched.
  """
  preprocessing_layers = {
      name: preprocessing_layer_creator(spec)
      for name, spec in regalloc_input_spec.items()
  }

  def _preprocessing_layer(seq_ex):
    # Apply each feature's layer and join the results on the last axis, in
    # the iteration order of the input spec.
    return tf.concat(
        [layer(seq_ex[name]) for name, layer in preprocessing_layers.items()],
        axis=-1,
    )

  def _layer(obs_dict):
    # 'obs' stays as the raw feature dict; downstream code still reads
    # individual raw features from it.
    obs_dict['obs_cur'] = _preprocessing_layer(obs_dict['obs'])
    obs_dict['obs_roll'] = _preprocessing_layer(obs_dict['obs_roll'])
    return obs_dict

  return _layer
def _get_to_inputs_and_labels_fn():
  """Builds the function mapping an example to (inputs, labels, weights).

  Returns:
    A callable `fn(obs_dict)` that reads 'obs', 'action', 'obs_cur',
    'obs_roll' and 'action_roll' from `obs_dict` and returns a 3-tuple of
    dicts:
      * inputs: the raw 'obs', the 'action', and the MLM-masked opcodes under
        'lr_use_def_opcode';
      * labels: keyed by the `config` state / action / next-state /
        next-action / MLM keys;
      * sample weights: the observation's 'mask' for the state and action
        entries, `None` for the next-state and next-action entries, and the
        MLM weights for the MLM entry.
  """
  apply_mlm = _get_masked_language_fn()

  def fn(obs_dict):
    observation = obs_dict['obs']
    action = obs_dict['action']

    # Only the first `config._NUM_INSTRUCTIONS` columns take part in the MLM.
    opcodes = observation['lr_use_def_opcode'][:, : config._NUM_INSTRUCTIONS]
    masked_opcodes, mlm_targets, mlm_weights = apply_mlm(opcodes)

    inputs = {
        'obs': observation,
        'action': action,
        'lr_use_def_opcode': masked_opcodes,
    }

    # Actions get a trailing axis of size 1.
    action_label = tf.expand_dims(action, axis=-1)
    next_action_label = tf.expand_dims(obs_dict['action_roll'], axis=-1)
    labels = {
        config._STATE_KEY: obs_dict['obs_cur'],
        config._ACTION_KEY: action_label,
        config._NEXT_STATE_KEY: obs_dict['obs_roll'],
        config._NEXT_ACTION_KEY: next_action_label,
        config._MLM_KEY: mlm_targets,
    }

    validity = observation['mask']
    sample_weights = {
        config._STATE_KEY: validity,
        config._ACTION_KEY: validity,
        config._NEXT_STATE_KEY: None,
        config._NEXT_ACTION_KEY: None,
        config._MLM_KEY: mlm_weights,
    }
    return inputs, labels, sample_weights

  return fn
def process_dataset(
    dataset,
    regalloc_input_spec,
    regalloc_preprocessing_layer_creator,
    *,
    shuffle_buffer_size=128,
    num_map_threads=128,
):
  """Turns a dataset of sequence examples into (inputs, labels, weights).

  The pipeline pairs each step with its rolled counterpart, separates actions
  from observations, preprocesses the observations, unbatches, shuffles, and
  finally builds the model inputs, labels and sample weights.

  Args:
    dataset: A `tf.data.Dataset`-like object whose elements are feature
      structures accepted by `_roll_experience`.
      NOTE(review): presumably a dict of per-trajectory tensors that includes
      'index_to_evict', 'mask' and 'lr_use_def_opcode' -- confirm with the
      caller.
    regalloc_input_spec: Mapping of feature name to spec, used to create one
      preprocessing layer per feature.
    regalloc_preprocessing_layer_creator: Callable mapping a spec to a
      preprocessing layer.
    shuffle_buffer_size: Buffer size passed to `Dataset.shuffle`.
    num_map_threads: Value passed as `num_parallel_calls` to every
      `Dataset.map` stage.

  Returns:
    The transformed dataset, yielding `(inputs, labels, sample_weights)`.
  """
  state_preprocessing_layer = _get_state_preprocessing_layer(
      regalloc_input_spec, regalloc_preprocessing_layer_creator
  )
  to_inputs_and_labels_fn = _get_to_inputs_and_labels_fn()

  # Stages that run while elements are still whole sequences.
  per_sequence_stages = (
      _roll_experience,
      _split_sequence_example,
      state_preprocessing_layer,
  )
  for stage in per_sequence_stages:
    dataset = dataset.map(stage, num_parallel_calls=num_map_threads)

  # Unbatch before shuffling so that shuffling mixes individual elements
  # rather than whole sequences.
  dataset = dataset.unbatch().shuffle(shuffle_buffer_size)
  return dataset.map(
      to_inputs_and_labels_fn, num_parallel_calls=num_map_threads
  )