Add self-supervised training code.
This is only for staging - DO NOT MERGE.
diff --git a/compiler_opt/rl/attention.py b/compiler_opt/rl/attention.py
new file mode 100644
index 0000000..30915b6
--- /dev/null
+++ b/compiler_opt/rl/attention.py
@@ -0,0 +1,183 @@
+"""attention.py
+
+Defines various building-blocks for attention in neural networks. Namely, the
+transformer encoder, which is a sequence-to-sequence model which uses
+self-attention to model relationships within the sequence.
+"""
+
+import tensorflow as tf
+import numpy as np
+
+from typing import Optional, List
+
+
+def positional_encoding(length, depth):
+ """Build a positional encoding tensor.
+
+ Taken from https://www.tensorflow.org/text/tutorials/transformer.
+
+ Args:
+ length: the number of sin/cos samples to generate.
+ depth: the depth of the embedding which the encoding will be summed with.
+
+ Returns:
+ A tensor of shape (length, depth) representing the positional encoding.
+ """
+ depth = depth / 2
+
+ positions = np.arange(length)[:, np.newaxis] # (seq, 1)
+ depths = np.arange(depth)[np.newaxis, :] / depth # (1, depth)
+
+ angle_rates = 1 / (10000**depths) # (1, depth)
+ angle_rads = positions * angle_rates # (pos, depth)
+
+ pos_encoding = np.concatenate(
+ [np.sin(angle_rads), np.cos(angle_rads)], axis=-1
+ )
+
+ return tf.cast(pos_encoding, dtype=tf.float32)
+
+
+class PositionalEmbedding(tf.keras.layers.Layer):
+ """A positional embedding layer.
+
+ A "positional embedding" is a sum of a token embedding with a positional
+ encoding, which is used as the initial layer of a transformer encoder
+ network.
+
+ Taken from https://www.tensorflow.org/text/tutorials/transformer.
+ """
+
+ def __init__(self, vocab_size, d_model):
+ """Initialize the positional embedding.
+
+ Args:
+ vocab_size: the size of the vocab, which should be one more than the
+ maximum token value which will be seen during training/inference.
+ d_model: the dimension of the model (size of the embedding vector)
+ """
+ super().__init__()
+ self.d_model = d_model
+ self.embedding = tf.keras.layers.Embedding(
+ vocab_size, d_model, mask_zero=True
+ )
+ self.pos_encoding = tf.constant(
+ positional_encoding(length=2048, depth=d_model), dtype=tf.float32
+ )
+
+ def compute_mask(self, *args, **kwargs):
+ """Returns a mask for the given input."""
+ return self.embedding.compute_mask(*args, **kwargs)
+
+ def call(self, x):
+ """Perform the positional embedding."""
+ length = tf.shape(x)[-1]
+ x = self.embedding(x)
+ # This factor sets the relative scale of the embedding and positonal_encoding.
+ x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
+ x = x + self.pos_encoding[tf.newaxis, :length, :]
+ return x
+
+
+class TransformerEncoderLayer(tf.keras.layers.Layer):
+ """Transformer Encoder.
+
+ See https://arxiv.org/abs/1706.03762 for more details.
+ """
+
+ def __init__(
+ self,
+ num_heads: int,
+ model_dim: int,
+ fcn_dim: int,
+ attention_axes: Optional[List] = None,
+ ):
+ """Initialize the transformer encoder.
+
+ Args:
+ num_heads: number of distinct attention heads within the layer.
+ model_dim: dimension of the model, also the dimension of the embedding.
+ fcn_dim: dimension of the fully-connected layers between attention layers.
+ attention_axes: which axes in the input tensors perform attention across.
+ """
+ super().__init__()
+ self._mha = tf.keras.layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=model_dim, attention_axes=attention_axes
+ )
+ self._mha_norm = tf.keras.layers.LayerNormalization()
+ self._mha_add = tf.keras.layers.Add()
+
+ self._fcn = tf.keras.Sequential([
+ tf.keras.layers.Dense(fcn_dim),
+ tf.keras.layers.ReLU(),
+ tf.keras.layers.Dense(model_dim),
+ tf.keras.layers.Dropout(0.1),
+ ])
+ self._fcn_norm = tf.keras.layers.LayerNormalization()
+ self._fcn_add = tf.keras.layers.Add()
+
+ def call(self, x, attention_mask=None):
+ """Call the transformer encoder."""
+ x_attended = self._mha(
+ query=x, value=x, key=x, attention_mask=attention_mask
+ )
+ x = self._mha_add([x_attended, x])
+ x = self._mha_norm(x)
+
+ x_fcn = self._fcn(x)
+ x = self._fcn_add([x_fcn, x])
+ x = self._fcn_norm(x)
+
+ return x
+
+
+class TransformerClassifier(tf.keras.layers.Layer):
+
+ def __init__(
+ self,
+ *,
+ num_tokens: int,
+ num_layers: int,
+ num_heads: int,
+ model_dim: int,
+ fcn_dim: int,
+ num_extra_features: int = 0,
+ ):
+ super().__init__()
+
+ self._model_dim = model_dim
+ self._ctx_token = tf.constant(num_tokens, dtype=tf.int64)
+ self._embed_layer = PositionalEmbedding(
+ num_tokens + 1, model_dim - num_extra_features
+ )
+ self._transformer_layers = [
+ TransformerEncoderLayer(
+ num_heads=num_heads,
+ model_dim=model_dim,
+ fcn_dim=fcn_dim,
+ attention_axes=(2,),
+ )
+ for _ in range(num_layers)
+ ]
+
+ def __call__(self, x, extra_hidden_state):
+ # [B, 33, 1 + I] --> [B, 33, 1 + I, E]
+ mask = self._embed_layer.compute_mask(x)
+ x = self._embed_layer(x)
+
+ # Append the extra hidden state
+ x = tf.concat([x, extra_hidden_state], axis=-1)
+
+ mask1 = mask[:, :, :, tf.newaxis]
+ mask2 = mask[:, :, tf.newaxis, :]
+ attn_mask = tf.cast(mask1, dtype=tf.int64) + tf.cast(mask2, dtype=tf.int64)
+ attn_mask = attn_mask > 0
+
+ for transformer_layer in self._transformer_layers:
+ x = transformer_layer(x, attention_mask=attn_mask)
+
+ mask_reduce = tf.cast(mask[:, :, :, tf.newaxis], dtype=tf.float32)
+ x_reduced = (tf.reduce_sum(mask_reduce * x, axis=-2)) / (
+ tf.reduce_sum(mask_reduce, axis=-2) + 1e-3
+ )
+ return x, x_reduced
diff --git a/compiler_opt/rl/regalloc/lr_encoder/BUILD b/compiler_opt/rl/regalloc/lr_encoder/BUILD
new file mode 100644
index 0000000..7fde6aa
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/BUILD
@@ -0,0 +1,96 @@
+# LR (live-range) encoder training configurations for register allocation.
+load("//devtools/python/blaze:strict.bzl", "py_strict_library")
+load("//devtools/python/blaze:pytype.bzl", "pytype_strict_binary", "pytype_strict_library")
+
+licenses(["notice"])
+
+package(
+ default_applicable_licenses = ["//third_party/ml_compiler_opt:license"],
+ default_visibility = [
+ "//third_party/ml_compiler_opt:default_visibility",
+ ],
+)
+
+# All gin configuration files under gin_configs/; used as data by :train.
+filegroup(
+ name = "gin_files",
+ srcs = glob(["gin_configs/**"]),
+)
+
+# Observation/output specs, losses, metrics and preprocessing for the encoder.
+# NOTE(review): this target uses py_strict_library while the other libraries
+# use pytype_strict_library -- confirm that is intended.
+py_strict_library(
+ name = "config",
+ srcs = ["config.py"],
+ deps = [
+ "//third_party/ml_compiler_opt/compiler_opt/rl:feature_ops",
+ "//third_party/ml_compiler_opt/compiler_opt/rl/regalloc:config",
+ "//third_party/py/gin",
+ "//third_party/py/tensorflow:tensorflow_no_contrib",
+ "//third_party/py/tf_agents",
+ "//third_party/py/tf_agents/specs:tensor_spec",
+ "//third_party/py/tf_agents/trajectories:time_step",
+ ],
+)
+
+# Keras encoder model and the multi-head model used to train it.
+# NOTE(review): the compiler_opt/rl:attention target is not added in this
+# BUILD file -- confirm it is declared in compiler_opt/rl/BUILD.
+pytype_strict_library(
+ name = "model",
+ srcs = ["model.py"],
+ deps = [
+ ":config",
+ "//third_party/ml_compiler_opt/compiler_opt/rl:attention",
+ "//third_party/py/gin",
+ "//third_party/py/tensorflow:tensorflow_no_contrib",
+ "//third_party/py/tf_agents/utils:nest_utils",
+ ],
+)
+
+# tf.data transformations (MLM masking, next-step labels) for training.
+pytype_strict_library(
+ name = "dataset_ops",
+ srcs = ["dataset_ops.py"],
+ deps = [
+ ":config",
+ "//third_party/py/tensorflow:tensorflow_no_contrib",
+ "//third_party/py/tensorflow_text",
+ ],
+)
+
+# Compilation runner that collects LR encoder training logs from clang.
+pytype_strict_library(
+ name = "lr_encoder_runner",
+ srcs = ["lr_encoder_runner.py"],
+ deps = [
+ "//third_party/ml_compiler_opt/compiler_opt/rl:compilation_runner",
+ "//third_party/ml_compiler_opt/compiler_opt/rl:corpus",
+ "//third_party/ml_compiler_opt/compiler_opt/rl:log_reader",
+ "//third_party/py/gin",
+ "//third_party/py/tensorflow:tensorflow_no_contrib",
+ ],
+)
+
+# Registers the gin-visible problem configuration (configs.LREncoderConfig).
+pytype_strict_library(
+ name = "lr_encoder_problem_config",
+ srcs = ["__init__.py"],
+ deps = [
+ ":config",
+ ":lr_encoder_runner",
+ "//third_party/ml_compiler_opt/compiler_opt/rl:problem_configuration",
+ "//third_party/py/gin",
+ ],
+)
+
+# Self-supervised training binary for the LR encoder.
+pytype_strict_binary(
+ name = "train",
+ srcs = ["train.py"],
+ data = [":gin_files"],
+ deps = [
+ ":config",
+ ":dataset_ops",
+ ":model",
+ "//third_party/ml_compiler_opt/compiler_opt/rl:policy_saver",
+ "//third_party/ml_compiler_opt/compiler_opt/rl:registry",
+ "//third_party/ml_compiler_opt/compiler_opt/rl/regalloc:config",
+ "//third_party/py/absl:app",
+ "//third_party/py/absl/flags",
+ "//third_party/py/absl/logging",
+ "//third_party/py/gin",
+ "//third_party/py/tensorflow:tensorflow_google", # build_cleaner: keep
+ "//third_party/py/tensorflow:tensorflow_no_contrib",
+ ],
+)
diff --git a/compiler_opt/rl/regalloc/lr_encoder/__init__.py b/compiler_opt/rl/regalloc/lr_encoder/__init__.py
new file mode 100644
index 0000000..8100aa8
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/__init__.py
@@ -0,0 +1,24 @@
+"""Implementation of the 'lr_encoder' problem."""
+
+import gin
+
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import problem_configuration
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import config
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import lr_encoder_runner
+
+
+@gin.register(module='configs')
+class LREncoderConfig(problem_configuration.ProblemConfiguration):
+ """Expose the LR encoder configuration."""
+
+ def get_runner_type(self):
+ return lr_encoder_runner.LREncoderRunner
+
+ def get_signature_spec(self):
+ return config.get_lr_encoder_signature_spec()
+
+ def get_preprocessing_layer_creator(self):
+ return config.get_preprocessing_layer_creator()
+
+ def get_nonnormalized_features(self):
+ return config.get_nonnormalized_features()
diff --git a/compiler_opt/rl/regalloc/lr_encoder/config.py b/compiler_opt/rl/regalloc/lr_encoder/config.py
new file mode 100644
index 0000000..11a458e
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/config.py
@@ -0,0 +1,219 @@
+"""LR encoder training config."""
+
+import gin
+import tensorflow as tf
+import tf_agents
+from tf_agents.specs import tensor_spec
+from tf_agents.trajectories import time_step
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import feature_ops
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc import config
+
+# Number of rows in each LR encoder observation (first axis of every feature).
+_NUM_REGISTERS = 33
+# Length of each row's use/def instruction sequence (second feature axis).
+_NUM_INSTRUCTIONS = 64
+# Opcode ids are bounded to [0, _OPCODE_VOCAB_SIZE] by the signature spec.
+_OPCODE_VOCAB_SIZE = 20000
+# Width of each per-register encoding produced by the encoder.
+_ENCODING_SIZE = 16
+_OPCODE_KEY = 'lr_use_def_opcode'
+
+# Observation keys starting with this prefix are the encoder's input features.
+_ENCODER_FEATURE_PREFIX = 'lr_use_def'
+
+# Names of the training output heads (see get_output_specs / get_loss).
+_ACTION_KEY = 'action'
+_STATE_KEY = 'state'
+_NEXT_STATE_KEY = 'next_state'
+_NEXT_ACTION_KEY = 'next_action'
+_MLM_KEY = 'mlm'
+
+
+@gin.configurable
+def get_lr_encoder_signature_spec():
+ observation_spec = {}
+ observation_spec[_OPCODE_KEY] = tensor_spec.BoundedTensorSpec(
+ dtype=tf.int64,
+ shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS),
+ name=_OPCODE_KEY,
+ minimum=0,
+ maximum=_OPCODE_VOCAB_SIZE,
+ )
+ observation_spec['lr_use_def_read'] = tensor_spec.BoundedTensorSpec(
+ dtype=tf.int64,
+ shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS),
+ name='lr_use_def_read',
+ minimum=0,
+ maximum=1,
+ )
+ observation_spec['lr_use_def_write'] = tensor_spec.BoundedTensorSpec(
+ dtype=tf.int64,
+ shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS),
+ name='lr_use_def_write',
+ minimum=0,
+ maximum=1,
+ )
+ observation_spec.update(
+ {
+ name: tensor_spec.BoundedTensorSpec(
+ dtype=tf.int64,
+ shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS),
+ name=name,
+ minimum=0,
+ maximum=1,
+ )
+ for name in [
+ 'lr_use_def_is_use',
+ 'lr_use_def_is_def',
+ 'lr_use_def_is_implicit',
+ 'lr_use_def_is_renamable',
+ 'lr_use_def_is_ind_var_update',
+ 'lr_use_def_is_hint',
+ ]
+ }
+ )
+ observation_spec['lr_use_def_freq'] = tf.TensorSpec(
+ dtype=tf.float32,
+ shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS),
+ name='lr_use_def_freq',
+ )
+
+ encoding_spec = tf.TensorSpec(
+ dtype=tf.float32, shape=(33, _ENCODING_SIZE), name='lr_encoding'
+ )
+
+ return observation_spec, encoding_spec
+
+
+def get_input_specs():
+ encoder_observation_spec, encoding_spec = get_lr_encoder_signature_spec()
+ del encoding_spec
+
+ regalloc_time_step_spec, regalloc_action_spec = (
+ config.get_regalloc_signature_spec()
+ )
+
+ # Ensure that there are no overlaps in feature names between the encoder and the RL problem
+ common_keys = (
+ encoder_observation_spec.keys()
+ & regalloc_time_step_spec.observation.keys()
+ )
+ assert len(common_keys) == 0
+
+ supervised_input_spec = {
+ **encoder_observation_spec,
+ **regalloc_time_step_spec.observation,
+ }
+
+ return supervised_input_spec
+
+
+def get_output_specs(regalloc_preprocessing_layer_creator):
+ regalloc_time_step_spec, regalloc_action_spec = (
+ config.get_regalloc_signature_spec()
+ )
+ del regalloc_action_spec
+
+ random_regalloc_state = tf_agents.specs.sample_spec_nest(
+ regalloc_time_step_spec.observation
+ )
+ for key in random_regalloc_state:
+ preprocessing_layer = regalloc_preprocessing_layer_creator(
+ regalloc_time_step_spec.observation[key]
+ )
+ random_regalloc_state[key] = preprocessing_layer(
+ tf.expand_dims(random_regalloc_state[key], axis=0)
+ )
+ random_regalloc_state = tf.concat(
+ list(random_regalloc_state.values()), axis=-1
+ )
+ random_regalloc_state = tf.squeeze(random_regalloc_state, axis=0)
+ regalloc_state_shape = random_regalloc_state.shape
+
+ action_spec = tf.TensorSpec(
+ dtype=tf.float32, shape=(_NUM_REGISTERS, 1), name=_ACTION_KEY
+ )
+ state_spec = tf.TensorSpec(
+ dtype=tf.float32, shape=regalloc_state_shape, name=_STATE_KEY
+ )
+ mlm_spec = tf.TensorSpec(
+ dtype=tf.float32,
+ shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS, _OPCODE_VOCAB_SIZE),
+ name=_MLM_KEY,
+ )
+ return {
+ _ACTION_KEY: action_spec,
+ _STATE_KEY: state_spec,
+ _NEXT_ACTION_KEY: action_spec,
+ _NEXT_STATE_KEY: state_spec,
+ _MLM_KEY: mlm_spec,
+ }
+
+
+def get_loss():
+ loss_fns = {
+ _STATE_KEY: 'mse',
+ _ACTION_KEY: tf.keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True
+ ),
+ _NEXT_STATE_KEY: 'mse',
+ _NEXT_ACTION_KEY: tf.keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True
+ ),
+ _MLM_KEY: tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ }
+ loss_weights = {
+ _ACTION_KEY: 1.0,
+ _STATE_KEY: 1.0,
+ _NEXT_ACTION_KEY: 1.0,
+ _NEXT_STATE_KEY: 1.0,
+ _MLM_KEY: 10.0,
+ }
+ return loss_fns, loss_weights
+
+
+def get_metrics():
+ return {
+ _ACTION_KEY: ['accuracy'],
+ _STATE_KEY: [],
+ _NEXT_ACTION_KEY: ['accuracy'],
+ _NEXT_STATE_KEY: [],
+ _MLM_KEY: ['accuracy'],
+ }
+
+
+@gin.configurable
+def get_preprocessing_layer_creator(
+ quantile_file_dir='/cns/oz-d/home/mlcompileropt-dev/regalloc-transformer/vocab',
+ with_sqrt=True,
+ with_z_score_normalization=True,
+ eps=1e-8,
+):
+ """Wrapper for observation_processing_layer."""
+ quantile_map = feature_ops.build_quantile_map(quantile_file_dir)
+
+ def preprocessing_layer_creator(obs_spec):
+ """Creates the layer to process observation given obs_spec."""
+ if obs_spec.name == 'lr_use_def_freq':
+ quantile = quantile_map[obs_spec.name]
+ first_non_zero = 0
+ for x in quantile:
+ if x > 0:
+ first_non_zero = x
+ break
+
+ normalize_fn = feature_ops.get_normalize_fn(
+ quantile, with_sqrt, with_z_score_normalization, eps
+ )
+ return tf.keras.layers.Lambda(normalize_fn)
+ return tf.keras.layers.Lambda(feature_ops.identity_fn)
+
+ return preprocessing_layer_creator
+
+
+def get_nonnormalized_features():
+ return [
+ 'lr_use_def_opcode',
+ 'lr_use_def_read',
+ 'lr_use_def_write',
+ 'lr_use_def_is_use',
+ 'lr_use_def_is_def',
+ 'lr_use_def_is_implicit',
+ 'lr_use_def_is_renamable',
+ 'lr_use_def_is_ind_var_update',
+ 'lr_use_def_is_hint',
+ ]
diff --git a/compiler_opt/rl/regalloc/lr_encoder/dataset_ops.py b/compiler_opt/rl/regalloc/lr_encoder/dataset_ops.py
new file mode 100644
index 0000000..9967f16
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/dataset_ops.py
@@ -0,0 +1,158 @@
+import tensorflow as tf
+import tensorflow_text as text
+
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import config
+
+# Maximum number of tokens selected for MLM prediction per call of the masking
+# function (across all rows), and the padded width of the selected positions.
+_MAX_PREDICTIONS_PER_BATCH = 128
+
+# Token id treated as padding; it is never selected for masking.
+_PAD_TOKEN = 0
+# NOTE(review): not referenced in this module.
+_MLM_IGNORE_TOKEN = -1
+# Mask token id passed to text.MaskValuesChooser. NOTE(review): it lies inside
+# the opcode id range allowed by config (<= _OPCODE_VOCAB_SIZE); confirm no real
+# opcode uses this id.
+_MLM_MASK_TOKEN = 18000 - 1
+
+
+def _get_masked_language_fn(*, selection_rate=0.33):
+ random_selector = text.RandomItemSelector(
+ max_selections_per_batch=_MAX_PREDICTIONS_PER_BATCH,
+ selection_rate=0.2,
+ unselectable_ids=[_PAD_TOKEN],
+ )
+ mask_values_chooser = text.MaskValuesChooser(
+ config._OPCODE_VOCAB_SIZE, _MLM_MASK_TOKEN, 0.8
+ )
+
+ def fn(opcodes):
+ masked_token_ids, masked_pos, masked_lm_ids = text.mask_language_model(
+ tf.RaggedTensor.from_tensor(opcodes, padding=_PAD_TOKEN),
+ item_selector=random_selector,
+ mask_values_chooser=mask_values_chooser,
+ )
+
+ masked_pos = masked_pos.to_tensor(
+ default_value=-1,
+ shape=(config._NUM_REGISTERS, _MAX_PREDICTIONS_PER_BATCH),
+ )
+
+ ii = tf.tile(
+ tf.range(config._NUM_REGISTERS, dtype=tf.int64)[:, tf.newaxis],
+ [1, _MAX_PREDICTIONS_PER_BATCH],
+ )
+ masked_pos = tf.stack([ii, masked_pos], axis=-1)
+ scatter_values = tf.where(masked_pos[:, :, 1] < 0, 0.0, 1.0)
+ masked_pos = tf.where(
+ masked_pos < 0, tf.constant(_PAD_TOKEN, dtype=tf.int64), masked_pos
+ )
+ weights = tf.scatter_nd(
+ masked_pos,
+ scatter_values[:, :, tf.newaxis],
+ (config._NUM_REGISTERS, config._NUM_INSTRUCTIONS, 1),
+ )
+ return (
+ masked_token_ids.to_tensor(
+ default_value=0,
+ shape=(config._NUM_REGISTERS, config._NUM_INSTRUCTIONS),
+ ),
+ opcodes,
+ weights,
+ )
+
+ return fn
+
+
+def _roll_experience(seq_ex, *, shift=-1):
+ def _roll(atom):
+ return tf.roll(atom, shift=shift, axis=0)
+
+ def _cutoff(atom):
+ return atom[:shift]
+
+ seq_ex_roll = tf.nest.map_structure(_roll, seq_ex)
+ seq_ex_roll = tf.nest.map_structure(_cutoff, seq_ex_roll)
+ seq_ex = tf.nest.map_structure(_cutoff, seq_ex)
+ return seq_ex, seq_ex_roll
+
+
+def _split_sequence_example(seq_ex, seq_ex_roll):
+ action_name = 'index_to_evict'
+ obs = {k: seq_ex[k] for k in seq_ex if k != action_name}
+ action = seq_ex[action_name]
+ obs_roll = {k: seq_ex_roll[k] for k in seq_ex_roll if k != action_name}
+ action_roll = seq_ex_roll[action_name]
+ return {
+ 'obs': obs,
+ 'action': action,
+ 'obs_roll': obs_roll,
+ 'action_roll': action_roll,
+ }
+
+
+def _get_state_preprocessing_layer(
+ regalloc_input_spec, preprocessing_layer_creator
+):
+ preprocessing_layers = {}
+ for name, spec in regalloc_input_spec.items():
+ preprocessing_layers[name] = preprocessing_layer_creator(spec)
+
+ def _preprocessing_layer(seq_ex):
+ pp = []
+ for layer_name, layer in preprocessing_layers.items():
+ pp.append(layer(seq_ex[layer_name]))
+ return tf.concat(pp, axis=-1)
+
+ def _layer(obs_dict):
+ obs_dict['obs'] = obs_dict['obs']
+ obs_dict['obs_cur'] = _preprocessing_layer(obs_dict['obs'])
+ obs_dict['obs_roll'] = _preprocessing_layer(obs_dict['obs_roll'])
+ return obs_dict
+
+ return _layer
+
+
+def _get_to_inputs_and_labels_fn():
+ masked_language_fn = _get_masked_language_fn()
+
+ def fn(obs_dict):
+ inputs = {'obs': obs_dict['obs'], 'action': obs_dict['action']}
+ mlm_input, mlm_label, mlm_weight = masked_language_fn(
+ obs_dict['obs']['lr_use_def_opcode'][:, : config._NUM_INSTRUCTIONS]
+ )
+ inputs['lr_use_def_opcode'] = mlm_input
+
+ labels = {
+ config._STATE_KEY: obs_dict['obs_cur'],
+ config._ACTION_KEY: tf.expand_dims(obs_dict['action'], axis=-1),
+ config._NEXT_STATE_KEY: obs_dict['obs_roll'],
+ config._NEXT_ACTION_KEY: tf.expand_dims(
+ obs_dict['action_roll'], axis=-1
+ ),
+ config._MLM_KEY: mlm_label,
+ }
+ mask = obs_dict['obs']['mask']
+ sample_weights = {
+ config._STATE_KEY: mask,
+ config._ACTION_KEY: mask,
+ config._NEXT_STATE_KEY: None,
+ config._NEXT_ACTION_KEY: None,
+ config._MLM_KEY: mlm_weight,
+ }
+ return (inputs, labels, sample_weights)
+
+ return fn
+
+
+def process_dataset(
+ dataset, regalloc_input_spec, regalloc_preprocessing_layer_creator
+):
+ shuffle_buffer_size = 128
+ num_map_threads = 128
+ state_preprocessing_layer = _get_state_preprocessing_layer(
+ regalloc_input_spec, regalloc_preprocessing_layer_creator
+ )
+ to_inputs_and_labels_fn = _get_to_inputs_and_labels_fn()
+ return (
+ dataset.map(_roll_experience, num_parallel_calls=num_map_threads)
+ .map(_split_sequence_example, num_parallel_calls=num_map_threads)
+ .map(state_preprocessing_layer, num_parallel_calls=num_map_threads)
+ .unbatch()
+ .shuffle(shuffle_buffer_size)
+ .map(to_inputs_and_labels_fn, num_parallel_calls=num_map_threads)
+ )
diff --git a/compiler_opt/rl/regalloc/lr_encoder/gin_configs/common.gin b/compiler_opt/rl/regalloc/lr_encoder/gin_configs/common.gin
new file mode 100644
index 0000000..c0fcb35
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/gin_configs/common.gin
@@ -0,0 +1,14 @@
+# Select the LR encoder problem configuration.
+config_registry.get_configuration.implementation=@configs.LREncoderConfig
+
+# Paths to the optional launcher and to clang; both default to None here.
+launcher_path=None
+clang_path=None
+
+runners.LREncoderRunner.clang_path=%clang_path
+runners.LREncoderRunner.launcher_path=%launcher_path
+
+problem_config.flags_to_add.add_flags=()
+problem_config.flags_to_delete.delete_flags=('-split-dwarf-file','-split-dwarf-output',)
+# For AFDO profile reinjection set:
+# problem_config.flags_to_replace.replace_flags={'-fprofile-sample-use':'/path/to/gwp.afdo','-fprofile-remapping-file':'/path/to/prof_remap.txt'}
+# NOTE(review): machine-local profile path; must be overridden for other corpora.
+problem_config.flags_to_replace.replace_flags={'-fprofile-instrument-use-path': '/tmp/corpus_1_21/MergedCS.profdata'}
+#problem_config.flags_to_replace.replace_flags={'-fprofile-instrument-use-path': 'MergedCS_profdata_jacobhegna/MergedCS_profdata'}
diff --git a/compiler_opt/rl/regalloc/lr_encoder/lr_encoder_runner.py b/compiler_opt/rl/regalloc/lr_encoder/lr_encoder_runner.py
new file mode 100644
index 0000000..3c90edc
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/lr_encoder_runner.py
@@ -0,0 +1,94 @@
+"""Module for collect data of the LR encoder."""
+
+import os
+import tempfile
+from typing import Dict, Tuple
+
+import gin
+import tensorflow as tf
+
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import compilation_runner
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import corpus
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import log_reader
+
+
+@gin.configurable(module='runners')
+class LREncoderRunner(compilation_runner.CompilationRunner):
+ """Class for collecting data for the LR encoder."""
+
+ def compile_fn(
+ self,
+ command_line: corpus.FullyQualifiedCmdLine,
+ tf_policy_path: str,
+ reward_only: bool,
+ ) -> Dict[str, Tuple[tf.train.SequenceExample, float]]:
+ """Run compilation for the given IR file under the given policy.
+
+ Args:
+ command_line: the fully qualified command line.
+ tf_policy_path: path to TF policy direcoty on local disk.
+ reward_only: whether only return reward.
+
+ Returns:
+ A dict mapping from example identifier to tuple containing:
+ sequence_example: A tf.SequenceExample proto describing compilation
+ trace, None if reward_only == True.
+ reward: reward of register allocation.
+
+ Raises:
+ subprocess.CalledProcessError: if process fails.
+ compilation_runner.ProcessKilledError: (which it must pass through) on
+ cancelled work.
+ RuntimeError: if llvm-size produces unexpected output.
+ """
+ assert not tf_policy_path
+
+ working_dir = tempfile.mkdtemp()
+
+ log_path = os.path.join(working_dir, 'log')
+ output_native_path = os.path.join(working_dir, 'native')
+
+ result = {}
+ try:
+ cmdline = []
+ if self._launcher_path:
+ cmdline.append(self._launcher_path)
+ cmdline.extend(
+ [self._clang_path]
+ + list(command_line)
+ + [
+ '-mllvm',
+ '-regalloc-enable-advisor=development',
+ '-mllvm',
+ '-regalloc-lr-encoder-training-log=' + log_path,
+ '-mllvm',
+ '-regalloc-training-log=/dev/null',
+ '-o',
+ output_native_path,
+ ]
+ )
+
+ compilation_runner.start_cancellable_process(
+ cmdline, self._compilation_timeout, self._cancellation_manager
+ )
+
+ if not os.path.exists(log_path):
+ return {}
+
+ # TODO(#202)
+ log_result = log_reader.read_log_as_sequence_examples(log_path)
+
+ for fct_name, trajectory in log_result.items():
+ if not trajectory.HasField('feature_lists'):
+ continue
+ # score = (
+ # trajectory.feature_lists.feature_list['reward']
+ # .feature[-1]
+ # .float_list.value[0]
+ # )
+ result[fct_name] = (trajectory, 1.0)
+
+ finally:
+ tf.io.gfile.rmtree(working_dir)
+
+ return result
diff --git a/compiler_opt/rl/regalloc/lr_encoder/model.py b/compiler_opt/rl/regalloc/lr_encoder/model.py
new file mode 100644
index 0000000..a953aa4
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/model.py
@@ -0,0 +1,114 @@
+from typing import Callable
+
+import gin
+import tensorflow as tf
+from tf_agents.utils import nest_utils
+
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import attention
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import config
+
+
+class MultiHeadEncoderModel(tf.keras.Model):
+
+ def __init__(self, encoder, heads):
+ super().__init__(name='MultiHeadEncoderModel')
+
+ self._encoder = encoder
+ self._heads = heads
+
+ def get_encoder(self):
+ return self._encoder
+
+ def call(self, inputs, *args):
+ observation = inputs['obs']
+ action = tf.one_hot(inputs['action'], depth=config._NUM_REGISTERS)[
+ :, :, tf.newaxis
+ ]
+
+ use_def_obs = {
+ k: v
+ for k, v in observation.items()
+ if k.startswith(config._ENCODER_FEATURE_PREFIX)
+ }
+ encoded_state_per_token, encoded_state = self._encoder(use_def_obs)
+ encoded_state_with_action = tf.concat([encoded_state, action], axis=-1)
+
+ def get_input(head_name):
+ if head_name == 'mlm':
+ return encoded_state_per_token
+ if head_name.startswith('next_'):
+ return encoded_state_with_action
+ return encoded_state
+
+ outputs = {}
+ for name in self._heads:
+ head = self._heads[name]
+ head_input = get_input(name)
+ outputs[name] = head(head_input)
+
+ return outputs
+
+
+@gin.configurable
+class LiveRangeEncoder(tf.keras.Model):
+
+ def __init__(
+ self,
+ input_specs,
+ preprocessing_layer_creator,
+ *,
+ num_layers=2,
+ num_heads=2,
+ model_dim=32,
+ fcn_dim=128,
+ num_extra_features=10,
+ ):
+ super().__init__(name='LiveRangeEncoder')
+ self._encoder = attention.TransformerClassifier(
+ num_tokens=config._OPCODE_VOCAB_SIZE,
+ num_layers=num_layers,
+ num_heads=num_heads,
+ model_dim=model_dim,
+ fcn_dim=fcn_dim,
+ num_extra_features=num_extra_features,
+ )
+ self._linear_reshape = tf.keras.layers.Dense(config._ENCODING_SIZE)
+
+ self._preprocessing_layers = {}
+ for key in input_specs:
+ self._preprocessing_layers[key] = preprocessing_layer_creator(key)
+
+ def preprocessor(observations):
+ pp_inputs = []
+ for key in input_specs:
+ pp_inputs.append(self._preprocessing_layers[key](observations[key]))
+ return tf.concat(pp_inputs, axis=-1)
+
+ self._preprocessor = preprocessor
+
+ def call(self, observations):
+ extra_hidden_state = self._preprocessor(observations)
+ tokens = observations[config._OPCODE_KEY]
+ encoded_state_per_token, encoded_state = self._encoder(
+ tokens, extra_hidden_state
+ )
+ encoded_state = self._linear_reshape(encoded_state)
+ return encoded_state_per_token, encoded_state
+
+
+def create_model(
+ input_specs: dict,
+ output_specs: dict,
+ preprocessing_layer_creator: dict,
+ *,
+ mlp_width: int = 128,
+):
+ encoder = LiveRangeEncoder(input_specs, preprocessing_layer_creator)
+
+ output_heads = {}
+ for name, spec in output_specs.items():
+ size = spec.shape[-1]
+ output_heads[name] = tf.keras.Sequential([
+ tf.keras.layers.Dense(mlp_width),
+ tf.keras.layers.ReLU(),
+ tf.keras.layers.Dense(size),
+ ])
+
+ return MultiHeadEncoderModel(encoder=encoder, heads=output_heads)
diff --git a/compiler_opt/rl/regalloc/lr_encoder/train.py b/compiler_opt/rl/regalloc/lr_encoder/train.py
new file mode 100644
index 0000000..ca0b628
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/train.py
@@ -0,0 +1,183 @@
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+
+import gin
+import tensorflow as tf
+
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import policy_saver
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc import config as regalloc_config
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import dataset_ops
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import model
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import config as encoder_config
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import registry # pylint: disable=unused-import
+
+
+flags.DEFINE_multi_string(
+ 'gin_files', [], 'List of paths to gin configuration files.'
+)
+
+flags.DEFINE_string('trace', None, '')
+flags.DEFINE_string('root_dir', None, '')
+
+flags.DEFINE_string('tpu', None, 'BNS address for the TPU')
+
+FLAGS = flags.FLAGS
+
+
+def _create_parser_fn(time_step_spec, action_spec):
+ context_features = {}
+ sequence_features = {
+ tensor_spec.name: tf.io.FixedLenSequenceFeature(
+ shape=tensor_spec.shape, dtype=tensor_spec.dtype
+ )
+ for tensor_spec in time_step_spec.values()
+ }
+ sequence_features[action_spec.name] = tf.io.FixedLenSequenceFeature(
+ shape=action_spec.shape, dtype=action_spec.dtype
+ )
+
+ def _parser_fn(serialized_proto):
+ with tf.name_scope('parse'):
+ _, parsed_sequence = tf.io.parse_single_sequence_example(
+ serialized_proto,
+ context_features=context_features,
+ sequence_features=sequence_features,
+ )
+ return parsed_sequence
+
+ return _parser_fn
+
+
+def get_dataset_creator_fn(
+ batch_size,
+ input_spec,
+ action_spec,
+ regalloc_input_spec,
+ regalloc_preprocessing_layer_creator,
+):
+ files_buffer_size = 64
+ num_map_threads = 128
+ num_readers = 64
+ parser_fn = _create_parser_fn(input_spec, action_spec)
+
+ def _file_dataset_fn(data_path):
+ dataset = (
+ tf.data.Dataset.list_files(data_path)
+ .shuffle(files_buffer_size)
+ .interleave(
+ tf.data.TFRecordDataset,
+ cycle_length=num_readers,
+ block_length=1,
+ )
+ .filter(lambda string: tf.strings.length(string) > 0)
+ .map(parser_fn, num_parallel_calls=num_map_threads)
+ )
+ dataset = dataset_ops.process_dataset(
+ dataset, regalloc_input_spec, regalloc_preprocessing_layer_creator
+ )
+ return dataset.batch(batch_size, drop_remainder=True).prefetch(
+ tf.data.experimental.AUTOTUNE
+ )
+
+ return _file_dataset_fn
+
+
+def get_strategy():
+ if FLAGS.tpu:
+ logging.info('Using TPU strategy.')
+ resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
+ tf.config.experimental_connect_to_cluster(resolver)
+ tf.tpu.experimental.initialize_tpu_system(resolver)
+ return tf.distribute.TPUStrategy(resolver)
+ logging.info('Using CPU strategy.')
+ return tf.distribute.get_strategy()
+
+
+class SaveEncoderCallback(tf.keras.callbacks.Callback):
+
+ def __init__(self, encoder, path):
+ self._encoder = encoder
+ self._path = path
+
+ def on_epoch_end(self, epoch, logs=None):
+ sm_path = os.path.join(self._path, f'epoch{epoch}')
+ tflite_path = os.path.join(sm_path, policy_saver.TFLITE_MODEL_NAME)
+ self._encoder.save(sm_path)
+ policy_saver.convert_saved_model(sm_path, tflite_path)
+
+
+def main(_):
+ gin.parse_config_files_and_bindings(
+ FLAGS.gin_files, bindings=None, skip_unknown=False
+ )
+ logging.info(gin.config_str())
+
+ lr_input_specs, lr_encoding_spec = (
+ encoder_config.get_lr_encoder_signature_spec()
+ )
+
+ regalloc_time_step_spec, regalloc_action_spec = (
+ regalloc_config.get_regalloc_signature_spec()
+ )
+ regalloc_preprocessing_creator = (
+ regalloc_config.get_observation_processing_layer_creator()
+ )
+
+ dataset_fn = get_dataset_creator_fn(
+ batch_size=64,
+ input_spec=encoder_config.get_input_specs(),
+ action_spec=regalloc_action_spec,
+ regalloc_input_spec=regalloc_time_step_spec.observation,
+ regalloc_preprocessing_layer_creator=regalloc_preprocessing_creator,
+ )
+
+ strategy = get_strategy()
+ with strategy.scope():
+ logging.info('Creating model.')
+ lr_model = model.create_model(
+ lr_input_specs,
+ encoder_config.get_output_specs(regalloc_preprocessing_creator),
+ encoder_config.get_preprocessing_layer_creator(),
+ )
+
+ logging.info('Compiling model.')
+ loss, loss_weights = encoder_config.get_loss()
+ lr_model.compile(
+ optimizer=tf.keras.optimizers.Adam(global_clipnorm=1.0),
+ loss=loss,
+ loss_weights=loss_weights,
+ metrics=encoder_config.get_metrics(),
+ )
+
+ logging.info('Creating dataset.')
+ dataset = dataset_fn(FLAGS.trace)
+ tb = tf.keras.callbacks.TensorBoard(
+ log_dir=FLAGS.root_dir,
+ histogram_freq=0,
+ write_graph=True,
+ write_steps_per_second=True,
+ update_freq='batch',
+ )
+
+ policy_dir = os.path.join(FLAGS.root_dir, 'policy')
+ checkpoint = tf.keras.callbacks.ModelCheckpoint(
+ filepath=policy_dir, save_weights_only=False, save_freq=1024
+ )
+
+ encoder_dir = os.path.join(FLAGS.root_dir, 'encoder')
+ logging.info('Saving the encoder to %s', encoder_dir)
+ encoder_saver = SaveEncoderCallback(
+ encoder=lr_model.get_encoder(), path=encoder_dir
+ )
+
+ logging.info('Starting training.')
+ lr_model.fit(dataset, epochs=8, callbacks=[tb, checkpoint, encoder_saver])
+
+ logging.info('Training complete.')
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/compiler_opt/rl/train_semisupervised.py b/compiler_opt/rl/train_semisupervised.py
new file mode 100644
index 0000000..209f2f9
--- /dev/null
+++ b/compiler_opt/rl/train_semisupervised.py
@@ -0,0 +1,490 @@
+from absl import app
+import tensorflow as tf
+import tensorflow_text as text
+
+import numpy as np
+
+import os
+from typing import Optional, List
+
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import feature_ops
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import registry
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import attention
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc import regalloc_network
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import policy_saver
+from absl import flags
+from absl import logging
+import gin
+# from google3.third_party.ml_compiler_opt.compiler_opt.rl import policy_saver
+
+
+# Have one class for the Encoder
+# Have one class for the full semisupervised model
+# only save the Encoder
+# Output heads:
+# - reward
+# - next state
+# -
+
flags.DEFINE_multi_string(
    'gin_files', [], 'List of paths to gin configuration files.'
)

# Both of the following are consumed unconditionally by main(): `trace` is
# passed to tf.data.Dataset.list_files, `root_dir` hosts all training outputs.
flags.DEFINE_string(
    'trace',
    None,
    'File pattern matching the TFRecord trace files to train on.',
)
flags.DEFINE_string(
    'root_dir',
    None,
    'Output directory for TensorBoard logs, model checkpoints (policy/) and '
    'the per-epoch encoder exports (encoder/).',
)

flags.DEFINE_string('tpu', None, 'BNS address for the TPU')


FLAGS = flags.FLAGS
+
# Names of the model outputs; the dataset emits its labels and sample weights
# under the same keys so Keras can match them per head.
_CUR_STATE_KEY = 'cur_state'
_ACTION_KEY = 'action'
_NEXT_STATE_KEY = 'next_state'
_NEXT_ACTION_KEY = 'next_action'
_MLM_KEY = 'mlm'

# Number of opcode tokens per row that take part in masked-language modelling.
_INSTR_COUNT = 64

# Label value marking positions that do not contribute to the MLM loss.
_MLM_IGNORE_TOKEN = -1
# Id substituted for masked tokens: the last id of the 18000-token vocabulary.
_MLM_MASK_TOKEN = 18000 - 1
+
+
class Model(tf.keras.Model):
  """Multi-head self-supervised model sharing a single encoder.

  The encoder consumes the `use_def_*` observation features. Five heads then
  predict the current state, the action, the next state, the next action and
  the masked tokens (MLM).
  """

  def __init__(
      self,
      encoder_network,
      state_head,
      action_head,
      next_state_head,
      next_action_head,
      mlm_head,
  ):
    """Stores the encoder and the prediction heads.

    Args:
      encoder_network: callable returning `(per_token_encoding,
        state_encoding)` for a dict of `use_def_*` features.
      state_head: head applied to the state encoding.
      action_head: head applied to the state encoding.
      next_state_head: head applied to the state encoding joined with the
        one-hot action.
      next_action_head: head applied to the state encoding joined with the
        one-hot action.
      mlm_head: head applied to the per-token encoding.
    """
    super().__init__(name='Model')
    # Attribute names are part of the checkpoint layout; keep them stable.
    self._encoder_network = encoder_network
    self._state_head = state_head
    self._action_head = action_head
    self._next_state_head = next_state_head
    self._next_action_head = next_action_head
    self._mlm_head = mlm_head

  def call(self, inputs):
    """Encodes the observation and evaluates every head.

    Args:
      inputs: dict holding an 'obs' dict of observation tensors and an
        'action' tensor of integer action indices.

    Returns:
      A dict keyed by the module-level output keys holding each head's output.
    """
    # TODO:
    # 1) Export the non-reduced tensor, which the MLM outputs need.
    # 2) When reducing, mask the tensors that don't matter; probably zero the
    #    relevant outputs and the labels.
    # 3) Work out the correct masks for training; this may need a custom
    #    training loop or custom loss functions.
    # NOTE(review): the dataset stores the MLM-masked opcodes under
    # inputs['use_def_opcode'], but only inputs['obs'] is read here, so the
    # encoder sees the unmasked opcodes. Confirm whether this is intended.
    use_def_features = {
        name: tensor
        for name, tensor in inputs['obs'].items()
        if name.startswith('use_def_')
    }
    token_encoding, state_encoding = self._encoder_network(use_def_features)

    # One-hot of the taken action with a trailing unit axis, appended to the
    # feature axis of the state encoding.
    action_one_hot = tf.one_hot(inputs['action'], depth=33)[:, :, tf.newaxis]
    state_and_action = tf.concat([state_encoding, action_one_hot], axis=-1)

    return {
        _CUR_STATE_KEY: self._state_head(state_encoding),
        _ACTION_KEY: self._action_head(state_encoding),
        _NEXT_STATE_KEY: self._next_state_head(state_and_action),
        _NEXT_ACTION_KEY: self._next_action_head(state_and_action),
        _MLM_KEY: self._mlm_head(token_encoding),
    }
+
+
def action_loss(label, pred):
  """Placeholder for an action loss restricted to the current-step mask.

  Not implemented yet. It only prints `label` and returns None, which is not
  a usable loss value.
  """
  del pred  # Unused until the loss is implemented.
  print(label)
  return None
+
+
def state_loss(label, pred):
  """Placeholder for a masked state loss.

  Still open: whether to use the current-step mask (the likely choice) or
  the intersection of the current and next masks. It currently only prints
  `label` and returns None.
  """
  del pred  # Unused until the loss is implemented.
  print(label)
  return None
+
+
def masked_language_loss(label, pred):
  """Sparse cross-entropy averaged over the non-padding positions.

  Args:
    label: integer token ids; positions whose label is 0 are treated as
      padding and excluded from the average.
    pred: unnormalised logits with a trailing vocabulary axis.

  Returns:
    The scalar mean loss over the non-padding positions, or 0 when every
    position is padding (instead of NaN from 0 / 0).
  """
  # The Keras class is `SparseCategoricalCrossentropy` (lower-case "e");
  # `SparseCategoricalCrossEntropy` does not exist and raised AttributeError.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction='none'
  )
  per_token_loss = loss_object(label, pred)

  mask = tf.cast(tf.not_equal(label, 0), dtype=per_token_loss.dtype)
  per_token_loss *= mask

  return tf.math.divide_no_nan(
      tf.reduce_sum(per_token_loss), tf.reduce_sum(mask)
  )
+
+
def create_model(encoder_network):
  """Builds a `Model` with freshly initialised heads around `encoder_network`.

  The action and next-action heads end in 1 unit and the state and
  next-state heads in 65 units; each is Dense(128) -> ReLU -> Dense(n). The
  MLM head is a single Dense layer over the 18000-token vocabulary.
  """

  def _mlp_head(output_units):
    """Returns Dense(128) -> ReLU -> Dense(output_units)."""
    return tf.keras.Sequential([
        tf.keras.layers.Dense(128),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(output_units),
    ])

  # Heads are built in the original order; Keras derives default layer names
  # from construction order.
  action_head = _mlp_head(1)
  state_head = _mlp_head(65)
  next_state_head = _mlp_head(65)
  next_action_head = _mlp_head(1)
  mlm_head = tf.keras.Sequential([tf.keras.layers.Dense(18000)])

  return Model(
      encoder_network=encoder_network,
      state_head=state_head,
      action_head=action_head,
      next_state_head=next_state_head,
      next_action_head=next_action_head,
      mlm_head=mlm_head,
  )
+
+
def _create_parser_fn(time_step_spec, action_spec):
  """Returns a function that parses one serialized `tf.train.SequenceExample`.

  Every observation spec in `time_step_spec` and the `action_spec` are read
  as fixed-length sequence features keyed by the spec's name. No context
  features are parsed.
  """

  def _as_sequence_feature(spec):
    return tf.io.FixedLenSequenceFeature(shape=spec.shape, dtype=spec.dtype)

  sequence_features = {
      spec.name: _as_sequence_feature(spec)
      for spec in time_step_spec.observation.values()
  }
  sequence_features[action_spec.name] = _as_sequence_feature(action_spec)

  def _parser_fn(serialized_proto):
    with tf.name_scope('parse'):
      _, parsed_sequence = tf.io.parse_single_sequence_example(
          serialized_proto,
          context_features={},
          sequence_features=sequence_features,
      )
    return parsed_sequence

  return _parser_fn
+
+
def _get_masked_input_and_labels(encoded_texts):
  """Applies BERT-style masking to a dense tensor of token ids.

  This is a TensorFlow port of
  https://keras.io/examples/nlp/masked_language_modeling/. About 15% of the
  positive (non-padding) tokens are selected. Of the selected tokens, about
  80% become `_MLM_MASK_TOKEN`, about 10% become a random token and the rest
  are left unchanged.

  Args:
    encoded_texts: integer tensor of token ids; ids <= 0 are never selected.

  Returns:
    A tuple `(masked_ids, labels, sample_weights)`:
      masked_ids: the int64 ids after masking.
      labels: the original ids as int64.
      sample_weights: float32, 1.0 exactly at the selected positions and 0.0
        elsewhere.
  """
  # The previous version used numpy-style item assignment, which tf.Tensors do
  # not support. It also referenced the undefined names `fp` and
  # `mask_token_id`. Every update is now expressed with tf.where.
  encoded_texts = tf.cast(encoded_texts, tf.int64)
  shape = tf.shape(encoded_texts)

  # 15% selection, restricted to real (positive) tokens.
  inp_mask = (tf.random.uniform(shape) < 0.15) & (encoded_texts > 0)

  # 90% of the selected tokens are replaced by [MASK]; the remaining 10% keep
  # their original value.
  inp_mask_2mask = inp_mask & (tf.random.uniform(shape) < 0.90)
  masked_ids = tf.where(
      inp_mask_2mask, tf.constant(_MLM_MASK_TOKEN, tf.int64), encoded_texts
  )

  # 1/9 of the [MASK] subset (10% of the selection) instead gets a random
  # token below the mask id. The lower bound of 3 follows the Keras example,
  # which reserves low ids for special tokens. TODO confirm for this vocab.
  inp_mask_2random = inp_mask_2mask & (tf.random.uniform(shape) < 1 / 9)
  random_ids = tf.random.uniform(
      shape, minval=3, maxval=_MLM_MASK_TOKEN, dtype=tf.int64
  )
  masked_ids = tf.where(inp_mask_2random, random_ids, masked_ids)

  # Only the selected positions contribute to the loss.
  sample_weights = tf.cast(inp_mask, tf.float32)

  # Labels are the unmodified input tokens.
  y_labels = tf.identity(encoded_texts)

  return masked_ids, y_labels, sample_weights
+
+
# Upper bound on the tokens selected for prediction in one
# `text.mask_language_model` call (one example's rows are masked together).
_MAX_PREDICTIONS_PER_BATCH = 32
# Fraction of the selectable tokens chosen for prediction.
_MLM_SELECTION_RATE = 0.2
# Fraction of the chosen tokens replaced by `_MLM_MASK_TOKEN`.
_MLM_MASK_TOKEN_RATE = 0.8
# Size of the opcode vocabulary.
_MLM_VOCAB_SIZE = 18000

# Id 0 is the padding value (see RaggedTensor.from_tensor(..., padding=0)
# below) and must never be selected.
random_selector = text.RandomItemSelector(
    max_selections_per_batch=_MAX_PREDICTIONS_PER_BATCH,
    selection_rate=_MLM_SELECTION_RATE,
    unselectable_ids=[0],
)
mask_values_chooser = text.MaskValuesChooser(
    _MLM_VOCAB_SIZE, _MLM_MASK_TOKEN, _MLM_MASK_TOKEN_RATE
)
+
+
def get_masked_input_and_labels(encoded_texts):
  """Masks a (33, _INSTR_COUNT) block of opcode ids for MLM training.

  Trailing zeros are treated as padding. Token selection and replacement are
  delegated to `text.mask_language_model` with the module-level selector and
  chooser.

  Returns:
    A tuple `(masked_ids, labels, weights)`:
      masked_ids: the masked ids, densified to (33, _INSTR_COUNT) with 0
        padding.
      labels: `encoded_texts`, unchanged.
      weights: float tensor of shape (33, _INSTR_COUNT, 1), 1.0 at the
        selected positions and 0.0 elsewhere.
  """
  num_rows = 33
  ragged_ids = tf.RaggedTensor.from_tensor(encoded_texts, padding=0)
  masked_ids, positions, _ = text.mask_language_model(
      ragged_ids,
      item_selector=random_selector,
      mask_values_chooser=mask_values_chooser,
  )

  # Dense (row, slot) table of selected column indices; -1 marks empty slots.
  dense_pos = positions.to_tensor(
      default_value=-1, shape=(num_rows, _MAX_PREDICTIONS_PER_BATCH)
  )
  is_empty_slot = dense_pos < 0
  scatter_values = tf.where(is_empty_slot, 0.0, 1.0)
  # Empty slots are pointed at column 0; they scatter 0.0 and add nothing.
  safe_pos = tf.where(
      is_empty_slot, tf.constant(0, dtype=tf.int64), dense_pos
  )
  row_ids = tf.tile(
      tf.range(num_rows, dtype=tf.int64)[:, tf.newaxis],
      [1, _MAX_PREDICTIONS_PER_BATCH],
  )
  scatter_indices = tf.stack([row_ids, safe_pos], axis=-1)

  weights = tf.scatter_nd(
      scatter_indices,
      scatter_values[:, :, tf.newaxis],
      (num_rows, _INSTR_COUNT, 1),
  )
  dense_masked_ids = masked_ids.to_tensor(
      default_value=0, shape=(num_rows, _INSTR_COUNT)
  )
  return dense_masked_ids, encoded_texts, weights
+
+
def create_dataset_fn(
    time_step_spec, action_spec, batch_size, shift, preprocessing_layer_creator
):
  """Creates a function that builds the training `tf.data.Dataset`.

  Each parsed trajectory is paired with a copy rolled by `shift` along time,
  so step t is matched with step t - shift (the next step when shift == -1).
  The final `-shift` steps, which have no successor, are dropped. The pairs
  are then processed in order:
    1. preprocessed,
    2. unbatched into individual steps,
    3. shuffled,
    4. turned into `(inputs, labels, sample_weights)` triples keyed by the
       model output keys,
    5. batched.

  Args:
    time_step_spec: spec whose `observation` dict describes the features.
    action_spec: spec of the action feature.
    batch_size: steps per batch; the final partial batch is dropped.
    shift: negative roll offset along the time axis.
    preprocessing_layer_creator: called on each observation spec to build
      that feature's preprocessing layer.

  Returns:
    A function mapping a file pattern to the dataset.

  Raises:
    ValueError: if `shift` is not negative.
  """
  # Explicit check: `assert` is stripped under `python -O`, and for shift >= 0
  # `atom[:shift]` below would keep the wrong steps.
  if shift >= 0:
    raise ValueError(f'shift must be negative, got {shift}.')

  files_buffer_size = 100
  num_readers = 10
  num_map_threads = 8
  shuffle_buffer_size = 256

  parser_fn = _create_parser_fn(time_step_spec, action_spec)

  def _roll_experience(seq_ex):
    """Pairs each step with the step `-shift` later; drops unpaired steps."""

    def _roll(atom):
      return tf.roll(atom, shift=shift, axis=0)

    def _cutoff(atom):
      # The last `-shift` rolled steps wrapped around from the start.
      return atom[:shift]

    seq_ex_roll = tf.nest.map_structure(_roll, seq_ex)
    seq_ex_roll = tf.nest.map_structure(_cutoff, seq_ex_roll)
    seq_ex = tf.nest.map_structure(_cutoff, seq_ex)
    return seq_ex, seq_ex_roll

  preprocessing_layers = tf.nest.map_structure(
      preprocessing_layer_creator, time_step_spec.observation
  )

  def split_experience(seq_ex, seq_ex_roll):
    """Separates observations from actions for both sequences."""
    obs = {k: seq_ex[k] for k in seq_ex if k != action_spec.name}
    action = seq_ex[action_spec.name]
    obs_roll = {k: seq_ex_roll[k] for k in seq_ex_roll if k != action_spec.name}
    action_roll = seq_ex_roll[action_spec.name]
    return {
        'obs': obs,
        'action': action,
        'obs_roll': obs_roll,
        'action_roll': action_roll,
    }

  def _preprocessing_layer(seq_ex):
    """Applies each feature's preprocessing layer in place; returns the dict."""
    for layer_name, layer in preprocessing_layers.items():
      seq_ex[layer_name] = layer(seq_ex[layer_name])
    return seq_ex

  def _preprocess_and_concat(obs):
    """Preprocesses `obs` and concatenates its non-`use_def_` features."""
    processed = _preprocessing_layer(obs)
    return tf.concat(
        [v for k, v in processed.items() if not k.startswith('use_def_')],
        axis=-1,
    )

  def preprocess_experience(obs_dict):
    # Copy so that obs_dict['obs'] keeps the raw features, which
    # to_inputs_and_labels feeds to the model.
    obs_dict['obs_cur'] = _preprocess_and_concat(obs_dict['obs'].copy())
    obs_dict['obs_roll'] = _preprocess_and_concat(obs_dict['obs_roll'])
    return obs_dict

  def to_inputs_and_labels(obs_dict):
    """Builds the (inputs, labels, sample_weights) triple for one step."""
    inputs = {'obs': obs_dict['obs'], 'action': obs_dict['action']}
    mlm_input, mlm_label, mlm_weight = get_masked_input_and_labels(
        obs_dict['obs']['use_def_opcode'][:, :_INSTR_COUNT]
    )
    # NOTE(review): Model.call only reads inputs['obs'], so this masked
    # input is never consumed. Confirm whether it should replace
    # inputs['obs']['use_def_opcode'].
    inputs['use_def_opcode'] = mlm_input

    labels = {
        _CUR_STATE_KEY: obs_dict['obs_cur'],
        _ACTION_KEY: tf.expand_dims(obs_dict['action'], axis=-1),
        _NEXT_STATE_KEY: obs_dict['obs_roll'],
        _NEXT_ACTION_KEY: tf.expand_dims(obs_dict['action_roll'], axis=-1),
        _MLM_KEY: mlm_label,
    }
    mask = obs_dict['obs']['mask']
    sample_weights = {
        _CUR_STATE_KEY: mask,
        _ACTION_KEY: mask,
        _NEXT_STATE_KEY: None,
        _NEXT_ACTION_KEY: None,
        _MLM_KEY: mlm_weight,
    }
    return (inputs, labels, sample_weights)

  def _file_dataset_fn(data_path):
    return (
        tf.data.Dataset.list_files(data_path)
        .shuffle(files_buffer_size)
        .interleave(
            tf.data.TFRecordDataset,
            cycle_length=num_readers,
            block_length=1,
        )
        .filter(lambda string: tf.strings.length(string) > 0)
        .map(parser_fn, num_parallel_calls=num_map_threads)
        .map(_roll_experience, num_parallel_calls=num_map_threads)
        .map(split_experience, num_parallel_calls=num_map_threads)
        .map(preprocess_experience, num_parallel_calls=num_map_threads)
        .unbatch()
        .shuffle(shuffle_buffer_size)
        .map(to_inputs_and_labels, num_parallel_calls=num_map_threads)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

  return _file_dataset_fn
+
+
def get_strategy():
  """Returns the distribution strategy to train under.

  When --tpu is set, this connects to and initialises that TPU system and
  returns a `TPUStrategy`. Otherwise it returns the current default strategy.
  """
  if not FLAGS.tpu:
    logging.info('Using CPU strategy.')
    return tf.distribute.get_strategy()

  logging.info('Using TPU strategy.')
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  return tf.distribute.TPUStrategy(resolver)
+
+
class SaveEncoderCallback(tf.keras.callbacks.Callback):
  """Keras callback that exports the encoder at the end of every epoch.

  Each export writes a SavedModel to `<path>/epoch<N>` and a TFLite
  conversion of it inside that directory.
  """

  def __init__(self, encoder, path):
    """Initialises the callback.

    Args:
      encoder: object exposing `.save(path)` (the encoder network).
      path: base directory for the per-epoch exports.
    """
    # The Callback base initialiser sets up state Keras relies on (e.g.
    # `self.model`); it was previously skipped.
    super().__init__()
    self._encoder = encoder
    self._path = path

  def on_epoch_end(self, epoch, logs=None):
    del logs  # Unused.
    sm_path = os.path.join(self._path, f'epoch{epoch}')
    tflite_path = os.path.join(sm_path, policy_saver.TFLITE_MODEL_NAME)
    self._encoder.save(sm_path)
    policy_saver.convert_saved_model(sm_path, tflite_path)
+
+
def _compile_model(model):
  """Compiles `model` with a loss, loss weight and metrics for every head."""
  model.compile(
      optimizer=tf.keras.optimizers.Adam(global_clipnorm=1.0),
      loss={
          _CUR_STATE_KEY: 'mse',
          _ACTION_KEY: tf.keras.losses.SparseCategoricalCrossentropy(
              from_logits=True
          ),
          _NEXT_STATE_KEY: 'mse',
          _NEXT_ACTION_KEY: tf.keras.losses.SparseCategoricalCrossentropy(
              from_logits=True
          ),
          _MLM_KEY: tf.keras.losses.SparseCategoricalCrossentropy(
              from_logits=True
          ),
      },
      loss_weights={
          # Previously omitted, which Keras treats as an unscaled (weight 1)
          # loss; now stated explicitly.
          _CUR_STATE_KEY: 1,
          _ACTION_KEY: 1,
          _NEXT_STATE_KEY: 1,
          _NEXT_ACTION_KEY: 1,
          _MLM_KEY: 10,
      },
      metrics={
          _CUR_STATE_KEY: [],
          _ACTION_KEY: ['accuracy'],
          _NEXT_STATE_KEY: [],
          _NEXT_ACTION_KEY: ['accuracy'],
          _MLM_KEY: ['accuracy'],
      },
  )


def _create_callbacks(root_dir, encoder_network):
  """Returns the TensorBoard, checkpoint and encoder-export callbacks."""
  tb = tf.keras.callbacks.TensorBoard(
      log_dir=root_dir,
      histogram_freq=0,
      write_graph=True,
      write_steps_per_second=True,
      update_freq='batch',
  )

  policy_dir = os.path.join(root_dir, 'policy')
  checkpoint = tf.keras.callbacks.ModelCheckpoint(
      filepath=policy_dir, save_weights_only=False, save_freq=1024
  )

  encoder_dir = os.path.join(root_dir, 'encoder')
  logging.info('Saving the encoder to %s', encoder_dir)
  encoder_saver = SaveEncoderCallback(encoder=encoder_network, path=encoder_dir)

  return [tb, checkpoint, encoder_saver]


def main(_):
  """Trains the self-supervised model and exports the encoder each epoch."""
  gin.parse_config_files_and_bindings(
      FLAGS.gin_files, bindings=None, skip_unknown=False
  )
  logging.info(gin.config_str())

  problem_config = registry.get_configuration()
  time_step_spec, action_spec = problem_config.get_signature_spec()
  preprocessing_layer_creator = problem_config.get_preprocessing_layer_creator()
  dataset_fn = create_dataset_fn(
      time_step_spec,
      action_spec,
      batch_size=256,
      shift=-1,
      preprocessing_layer_creator=preprocessing_layer_creator,
  )

  strategy = get_strategy()
  # Model variables and optimizer state must be created under the strategy.
  with strategy.scope():
    logging.info('Creating model.')
    encoder_network = regalloc_network.InstructionEncoderNetwork(
        preprocessing_layer_creator
    )
    model = create_model(encoder_network)

    logging.info('Compiling model.')
    _compile_model(model)

  logging.info('Creating dataset.')
  dataset = dataset_fn(FLAGS.trace)
  callbacks = _create_callbacks(FLAGS.root_dir, encoder_network)

  logging.info('Starting training.')
  model.fit(dataset, epochs=8, callbacks=callbacks)

  logging.info('Training complete.')
+
+
if __name__ == '__main__':
  # main() reads --trace and --root_dir unconditionally; requiring them gives
  # a clear flag error up front instead of a failure later on a None value.
  flags.mark_flag_as_required('trace')
  flags.mark_flag_as_required('root_dir')
  app.run(main)