Add self-supervised training code.

This is only for staging - DO NOT MERGE.
diff --git a/compiler_opt/rl/attention.py b/compiler_opt/rl/attention.py
new file mode 100644
index 0000000..30915b6
--- /dev/null
+++ b/compiler_opt/rl/attention.py
@@ -0,0 +1,183 @@
+"""attention.py
+
+Defines various building-blocks for attention in neural networks. Namely, the
+transformer encoder, which is a sequence-to-sequence model which uses
+self-attention to model relationships within the sequence.
+"""
+
+import tensorflow as tf
+import numpy as np
+
+from typing import Optional, List
+
+
+def positional_encoding(length, depth):
+  """Build a positional encoding tensor.
+
+  Taken from https://www.tensorflow.org/text/tutorials/transformer.
+
+  Args:
+    length: the number of sin/cos samples to generate.
+    depth: the depth of the embedding which the encoding will be summed with.
+
+  Returns:
+    A tensor of shape (length, depth) representing the positional encoding.
+  """
+  depth = depth / 2
+
+  positions = np.arange(length)[:, np.newaxis]  # (seq, 1)
+  depths = np.arange(depth)[np.newaxis, :] / depth  # (1, depth)
+
+  angle_rates = 1 / (10000**depths)  # (1, depth)
+  angle_rads = positions * angle_rates  # (pos, depth)
+
+  pos_encoding = np.concatenate(
+      [np.sin(angle_rads), np.cos(angle_rads)], axis=-1
+  )
+
+  return tf.cast(pos_encoding, dtype=tf.float32)
+
+
+class PositionalEmbedding(tf.keras.layers.Layer):
+  """A positional embedding layer.
+
+  A "positional embedding" is a sum of a token embedding with a positional
+  encoding, which is used as the initial layer of a transformer encoder
+  network.
+
+  Taken from https://www.tensorflow.org/text/tutorials/transformer.
+  """
+
+  def __init__(self, vocab_size, d_model):
+    """Initialize the positional embedding.
+
+    Args:
+      vocab_size: the size of the vocab, which should be one more than the
+        maximum token value which will be seen during training/inference.
+      d_model: the dimension of the model (size of the embedding vector)
+    """
+    super().__init__()
+    self.d_model = d_model
+    self.embedding = tf.keras.layers.Embedding(
+        vocab_size, d_model, mask_zero=True
+    )
+    self.pos_encoding = tf.constant(
+        positional_encoding(length=2048, depth=d_model), dtype=tf.float32
+    )
+
+  def compute_mask(self, *args, **kwargs):
+    """Returns a mask for the given input."""
+    return self.embedding.compute_mask(*args, **kwargs)
+
+  def call(self, x):
+    """Perform the positional embedding."""
+    length = tf.shape(x)[-1]
+    x = self.embedding(x)
+    # This factor sets the relative scale of the embedding and positional encoding.
+    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
+    x = x + self.pos_encoding[tf.newaxis, :length, :]
+    return x
+
+
+class TransformerEncoderLayer(tf.keras.layers.Layer):
+  """Transformer Encoder.
+
+  See https://arxiv.org/abs/1706.03762 for more details.
+  """
+
+  def __init__(
+      self,
+      num_heads: int,
+      model_dim: int,
+      fcn_dim: int,
+      attention_axes: Optional[List] = None,
+  ):
+    """Initialize the transformer encoder.
+
+    Args:
+      num_heads: number of distinct attention heads within the layer.
+      model_dim: dimension of the model, also the dimension of the embedding.
+      fcn_dim: dimension of the fully-connected layers between attention layers.
+      attention_axes: which axes in the input tensors perform attention across.
+    """
+    super().__init__()
+    self._mha = tf.keras.layers.MultiHeadAttention(
+        num_heads=num_heads, key_dim=model_dim, attention_axes=attention_axes
+    )
+    self._mha_norm = tf.keras.layers.LayerNormalization()
+    self._mha_add = tf.keras.layers.Add()
+
+    self._fcn = tf.keras.Sequential([
+        tf.keras.layers.Dense(fcn_dim),
+        tf.keras.layers.ReLU(),
+        tf.keras.layers.Dense(model_dim),
+        tf.keras.layers.Dropout(0.1),
+    ])
+    self._fcn_norm = tf.keras.layers.LayerNormalization()
+    self._fcn_add = tf.keras.layers.Add()
+
+  def call(self, x, attention_mask=None):
+    """Call the transformer encoder."""
+    x_attended = self._mha(
+        query=x, value=x, key=x, attention_mask=attention_mask
+    )
+    x = self._mha_add([x_attended, x])
+    x = self._mha_norm(x)
+
+    x_fcn = self._fcn(x)
+    x = self._fcn_add([x_fcn, x])
+    x = self._fcn_norm(x)
+
+    return x
+
+
+class TransformerClassifier(tf.keras.layers.Layer):
+
+  def __init__(
+      self,
+      *,
+      num_tokens: int,
+      num_layers: int,
+      num_heads: int,
+      model_dim: int,
+      fcn_dim: int,
+      num_extra_features: int = 0,
+  ):
+    super().__init__()
+
+    self._model_dim = model_dim
+    self._ctx_token = tf.constant(num_tokens, dtype=tf.int64)
+    self._embed_layer = PositionalEmbedding(
+        num_tokens + 1, model_dim - num_extra_features
+    )
+    self._transformer_layers = [
+        TransformerEncoderLayer(
+            num_heads=num_heads,
+            model_dim=model_dim,
+            fcn_dim=fcn_dim,
+            attention_axes=(2,),
+        )
+        for _ in range(num_layers)
+    ]
+
+  def __call__(self, x, extra_hidden_state):
+    # [B, 33, 1 + I] --> [B, 33, 1 + I, E]
+    mask = self._embed_layer.compute_mask(x)
+    x = self._embed_layer(x)
+
+    # Append the extra hidden state
+    x = tf.concat([x, extra_hidden_state], axis=-1)
+
+    mask1 = mask[:, :, :, tf.newaxis]
+    mask2 = mask[:, :, tf.newaxis, :]
+    attn_mask = tf.cast(mask1, dtype=tf.int64) + tf.cast(mask2, dtype=tf.int64)
+    attn_mask = attn_mask > 0
+
+    for transformer_layer in self._transformer_layers:
+      x = transformer_layer(x, attention_mask=attn_mask)
+
+    mask_reduce = tf.cast(mask[:, :, :, tf.newaxis], dtype=tf.float32)
+    x_reduced = (tf.reduce_sum(mask_reduce * x, axis=-2)) / (
+        tf.reduce_sum(mask_reduce, axis=-2) + 1e-3
+    )
+    return x, x_reduced
diff --git a/compiler_opt/rl/regalloc/lr_encoder/BUILD b/compiler_opt/rl/regalloc/lr_encoder/BUILD
new file mode 100644
index 0000000..7fde6aa
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/BUILD
@@ -0,0 +1,96 @@
+# Build targets for the live-range (LR) encoder problem.
+load("//devtools/python/blaze:strict.bzl", "py_strict_library")
+load("//devtools/python/blaze:pytype.bzl", "pytype_strict_binary", "pytype_strict_library")
+
+licenses(["notice"])
+
+package(
+    default_applicable_licenses = ["//third_party/ml_compiler_opt:license"],
+    default_visibility = [
+        "//third_party/ml_compiler_opt:default_visibility",
+    ],
+)
+
+filegroup(
+    name = "gin_files",
+    srcs = glob(["gin_configs/**"]),
+)
+
+py_strict_library(
+    name = "config",
+    srcs = ["config.py"],
+    deps = [
+        "//third_party/ml_compiler_opt/compiler_opt/rl:feature_ops",
+        "//third_party/ml_compiler_opt/compiler_opt/rl/regalloc:config",
+        "//third_party/py/gin",
+        "//third_party/py/tensorflow:tensorflow_no_contrib",
+        "//third_party/py/tf_agents",
+        "//third_party/py/tf_agents/specs:tensor_spec",
+        "//third_party/py/tf_agents/trajectories:time_step",
+    ],
+)
+
+pytype_strict_library(
+    name = "model",
+    srcs = ["model.py"],
+    deps = [
+        ":config",
+        "//third_party/ml_compiler_opt/compiler_opt/rl:attention",
+        "//third_party/py/gin",
+        "//third_party/py/tensorflow:tensorflow_no_contrib",
+        "//third_party/py/tf_agents/utils:nest_utils",
+    ],
+)
+
+pytype_strict_library(
+    name = "dataset_ops",
+    srcs = ["dataset_ops.py"],
+    deps = [
+        ":config",
+        "//third_party/py/tensorflow:tensorflow_no_contrib",
+        "//third_party/py/tensorflow_text",
+    ],
+)
+
+pytype_strict_library(
+    name = "lr_encoder_runner",
+    srcs = ["lr_encoder_runner.py"],
+    deps = [
+        "//third_party/ml_compiler_opt/compiler_opt/rl:compilation_runner",
+        "//third_party/ml_compiler_opt/compiler_opt/rl:corpus",
+        "//third_party/ml_compiler_opt/compiler_opt/rl:log_reader",
+        "//third_party/py/gin",
+        "//third_party/py/tensorflow:tensorflow_no_contrib",
+    ],
+)
+
+pytype_strict_library(
+    name = "lr_encoder_problem_config",
+    srcs = ["__init__.py"],
+    deps = [
+        ":config",
+        ":lr_encoder_runner",
+        "//third_party/ml_compiler_opt/compiler_opt/rl:problem_configuration",
+        "//third_party/py/gin",
+    ],
+)
+
+pytype_strict_binary(
+    name = "train",
+    srcs = ["train.py"],
+    data = [":gin_files"],
+    deps = [
+        ":config",
+        ":dataset_ops",
+        ":model",
+        "//third_party/ml_compiler_opt/compiler_opt/rl:policy_saver",
+        "//third_party/ml_compiler_opt/compiler_opt/rl:registry",
+        "//third_party/ml_compiler_opt/compiler_opt/rl/regalloc:config",
+        "//third_party/py/absl:app",
+        "//third_party/py/absl/flags",
+        "//third_party/py/absl/logging",
+        "//third_party/py/gin",
+        "//third_party/py/tensorflow:tensorflow_google",  # build_cleaner: keep
+        "//third_party/py/tensorflow:tensorflow_no_contrib",
+    ],
+)
diff --git a/compiler_opt/rl/regalloc/lr_encoder/__init__.py b/compiler_opt/rl/regalloc/lr_encoder/__init__.py
new file mode 100644
index 0000000..8100aa8
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/__init__.py
@@ -0,0 +1,24 @@
+"""Implementation of the 'lr_encoder' problem."""
+
+import gin
+
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import problem_configuration
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import config
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import lr_encoder_runner
+
+
+@gin.register(module='configs')
+class LREncoderConfig(problem_configuration.ProblemConfiguration):
+  """Expose the LR encoder configuration."""
+
+  def get_runner_type(self):
+    return lr_encoder_runner.LREncoderRunner
+
+  def get_signature_spec(self):
+    return config.get_lr_encoder_signature_spec()
+
+  def get_preprocessing_layer_creator(self):
+    return config.get_preprocessing_layer_creator()
+
+  def get_nonnormalized_features(self):
+    return config.get_nonnormalized_features()
diff --git a/compiler_opt/rl/regalloc/lr_encoder/config.py b/compiler_opt/rl/regalloc/lr_encoder/config.py
new file mode 100644
index 0000000..11a458e
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/config.py
@@ -0,0 +1,219 @@
+"""LR encoder training config."""
+
+import gin
+import tensorflow as tf
+import tf_agents
+from tf_agents.specs import tensor_spec
+from tf_agents.trajectories import time_step
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import feature_ops
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc import config
+
+_NUM_REGISTERS = 33
+_NUM_INSTRUCTIONS = 64
+_OPCODE_VOCAB_SIZE = 20000
+_ENCODING_SIZE = 16
+_OPCODE_KEY = 'lr_use_def_opcode'
+
+_ENCODER_FEATURE_PREFIX = 'lr_use_def'
+
+_ACTION_KEY = 'action'
+_STATE_KEY = 'state'
+_NEXT_STATE_KEY = 'next_state'
+_NEXT_ACTION_KEY = 'next_action'
+_MLM_KEY = 'mlm'
+
+
+@gin.configurable
+def get_lr_encoder_signature_spec():
+  """Build the observation and encoding specs of the LR encoder problem.
+
+  Returns:
+    A tuple `(observation_spec, encoding_spec)`. `observation_spec` is a dict
+    mapping each `lr_use_def_*` feature name to a spec of shape
+    `(_NUM_REGISTERS, _NUM_INSTRUCTIONS)`; `encoding_spec` is the float32 spec
+    of shape `(_NUM_REGISTERS, _ENCODING_SIZE)` named 'lr_encoding'.
+  """
+  feature_shape = (_NUM_REGISTERS, _NUM_INSTRUCTIONS)
+  observation_spec = {}
+  observation_spec[_OPCODE_KEY] = tensor_spec.BoundedTensorSpec(
+      dtype=tf.int64,
+      shape=feature_shape,
+      name=_OPCODE_KEY,
+      minimum=0,
+      maximum=_OPCODE_VOCAB_SIZE,
+  )
+  # All binary per-instruction flags share the same int64 {0, 1} spec.
+  observation_spec.update(
+      {
+          name: tensor_spec.BoundedTensorSpec(
+              dtype=tf.int64,
+              shape=feature_shape,
+              name=name,
+              minimum=0,
+              maximum=1,
+          )
+          for name in [
+              'lr_use_def_read',
+              'lr_use_def_write',
+              'lr_use_def_is_use',
+              'lr_use_def_is_def',
+              'lr_use_def_is_implicit',
+              'lr_use_def_is_renamable',
+              'lr_use_def_is_ind_var_update',
+              'lr_use_def_is_hint',
+          ]
+      }
+  )
+  observation_spec['lr_use_def_freq'] = tf.TensorSpec(
+      dtype=tf.float32,
+      shape=feature_shape,
+      name='lr_use_def_freq',
+  )
+
+  encoding_spec = tf.TensorSpec(
+      dtype=tf.float32,
+      shape=(_NUM_REGISTERS, _ENCODING_SIZE),
+      name='lr_encoding',
+  )
+
+  return observation_spec, encoding_spec
+
+
+def get_input_specs():
+  encoder_observation_spec, encoding_spec = get_lr_encoder_signature_spec()
+  del encoding_spec
+
+  regalloc_time_step_spec, regalloc_action_spec = (
+      config.get_regalloc_signature_spec()
+  )
+
+  # Ensure that there are no overlaps in feature names between the encoder and the RL problem
+  common_keys = (
+      encoder_observation_spec.keys()
+      & regalloc_time_step_spec.observation.keys()
+  )
+  assert len(common_keys) == 0
+
+  supervised_input_spec = {
+      **encoder_observation_spec,
+      **regalloc_time_step_spec.observation,
+  }
+
+  return supervised_input_spec
+
+
+def get_output_specs(regalloc_preprocessing_layer_creator):
+  regalloc_time_step_spec, regalloc_action_spec = (
+      config.get_regalloc_signature_spec()
+  )
+  del regalloc_action_spec
+
+  random_regalloc_state = tf_agents.specs.sample_spec_nest(
+      regalloc_time_step_spec.observation
+  )
+  for key in random_regalloc_state:
+    preprocessing_layer = regalloc_preprocessing_layer_creator(
+        regalloc_time_step_spec.observation[key]
+    )
+    random_regalloc_state[key] = preprocessing_layer(
+        tf.expand_dims(random_regalloc_state[key], axis=0)
+    )
+  random_regalloc_state = tf.concat(
+      list(random_regalloc_state.values()), axis=-1
+  )
+  random_regalloc_state = tf.squeeze(random_regalloc_state, axis=0)
+  regalloc_state_shape = random_regalloc_state.shape
+
+  action_spec = tf.TensorSpec(
+      dtype=tf.float32, shape=(_NUM_REGISTERS, 1), name=_ACTION_KEY
+  )
+  state_spec = tf.TensorSpec(
+      dtype=tf.float32, shape=regalloc_state_shape, name=_STATE_KEY
+  )
+  mlm_spec = tf.TensorSpec(
+      dtype=tf.float32,
+      shape=(_NUM_REGISTERS, _NUM_INSTRUCTIONS, _OPCODE_VOCAB_SIZE),
+      name=_MLM_KEY,
+  )
+  return {
+      _ACTION_KEY: action_spec,
+      _STATE_KEY: state_spec,
+      _NEXT_ACTION_KEY: action_spec,
+      _NEXT_STATE_KEY: state_spec,
+      _MLM_KEY: mlm_spec,
+  }
+
+
+def get_loss():
+  loss_fns = {
+      _STATE_KEY: 'mse',
+      _ACTION_KEY: tf.keras.losses.SparseCategoricalCrossentropy(
+          from_logits=True
+      ),
+      _NEXT_STATE_KEY: 'mse',
+      _NEXT_ACTION_KEY: tf.keras.losses.SparseCategoricalCrossentropy(
+          from_logits=True
+      ),
+      _MLM_KEY: tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+  }
+  loss_weights = {
+      _ACTION_KEY: 1.0,
+      _STATE_KEY: 1.0,
+      _NEXT_ACTION_KEY: 1.0,
+      _NEXT_STATE_KEY: 1.0,
+      _MLM_KEY: 10.0,
+  }
+  return loss_fns, loss_weights
+
+
+def get_metrics():
+  return {
+      _ACTION_KEY: ['accuracy'],
+      _STATE_KEY: [],
+      _NEXT_ACTION_KEY: ['accuracy'],
+      _NEXT_STATE_KEY: [],
+      _MLM_KEY: ['accuracy'],
+  }
+
+
+@gin.configurable
+def get_preprocessing_layer_creator(
+    quantile_file_dir='/cns/oz-d/home/mlcompileropt-dev/regalloc-transformer/vocab',
+    with_sqrt=True,
+    with_z_score_normalization=True,
+    eps=1e-8,
+):
+  """Return a factory building the preprocessing layer for a feature spec.
+
+  Args:
+    quantile_file_dir: directory passed to `feature_ops.build_quantile_map`.
+    with_sqrt, with_z_score_normalization, eps: forwarded unchanged to
+      `feature_ops.get_normalize_fn` for the 'lr_use_def_freq' feature.
+  """
+  quantile_map = feature_ops.build_quantile_map(quantile_file_dir)
+
+  def preprocessing_layer_creator(obs_spec):
+    """Creates the layer to process observation given obs_spec."""
+    # Only the frequency feature is normalized; all others pass through.
+    if obs_spec.name != 'lr_use_def_freq':
+      return tf.keras.layers.Lambda(feature_ops.identity_fn)
+    normalize_fn = feature_ops.get_normalize_fn(
+        quantile_map[obs_spec.name], with_sqrt, with_z_score_normalization, eps
+    )
+    return tf.keras.layers.Lambda(normalize_fn)
+
+  return preprocessing_layer_creator
+
+
+def get_nonnormalized_features():
+  return [
+      'lr_use_def_opcode',
+      'lr_use_def_read',
+      'lr_use_def_write',
+      'lr_use_def_is_use',
+      'lr_use_def_is_def',
+      'lr_use_def_is_implicit',
+      'lr_use_def_is_renamable',
+      'lr_use_def_is_ind_var_update',
+      'lr_use_def_is_hint',
+  ]
diff --git a/compiler_opt/rl/regalloc/lr_encoder/dataset_ops.py b/compiler_opt/rl/regalloc/lr_encoder/dataset_ops.py
new file mode 100644
index 0000000..9967f16
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/dataset_ops.py
@@ -0,0 +1,158 @@
+import tensorflow as tf
+import tensorflow_text as text
+
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import config
+
+_MAX_PREDICTIONS_PER_BATCH = 128
+
+_PAD_TOKEN = 0
+_MLM_IGNORE_TOKEN = -1
+_MLM_MASK_TOKEN = 18000 - 1
+
+
+def _get_masked_language_fn(*, selection_rate=0.2):
+  """Return fn(opcodes) -> (masked opcodes, original opcodes, weights).
+
+  selection_rate is the `selection_rate` given to text.RandomItemSelector.
+  """
+  random_selector = text.RandomItemSelector(
+      max_selections_per_batch=_MAX_PREDICTIONS_PER_BATCH,
+      selection_rate=selection_rate,
+      unselectable_ids=[_PAD_TOKEN],
+  )
+  mask_values_chooser = text.MaskValuesChooser(
+      config._OPCODE_VOCAB_SIZE, _MLM_MASK_TOKEN, 0.8
+  )
+
+  def fn(opcodes):
+    masked_token_ids, masked_pos, _ = text.mask_language_model(
+        tf.RaggedTensor.from_tensor(opcodes, padding=_PAD_TOKEN),
+        item_selector=random_selector,
+        mask_values_chooser=mask_values_chooser,
+    )
+    # Unused prediction slots get position -1 and are given zero weight below.
+    masked_pos = masked_pos.to_tensor(
+        default_value=-1,
+        shape=(config._NUM_REGISTERS, _MAX_PREDICTIONS_PER_BATCH),
+    )
+    ii = tf.tile(
+        tf.range(config._NUM_REGISTERS, dtype=tf.int64)[:, tf.newaxis],
+        [1, _MAX_PREDICTIONS_PER_BATCH],
+    )
+    masked_pos = tf.stack([ii, masked_pos], axis=-1)
+    scatter_values = tf.where(masked_pos[:, :, 1] < 0, 0.0, 1.0)
+    masked_pos = tf.where(
+        masked_pos < 0, tf.constant(_PAD_TOKEN, dtype=tf.int64), masked_pos
+    )
+    weights = tf.scatter_nd(
+        masked_pos,
+        scatter_values[:, :, tf.newaxis],
+        (config._NUM_REGISTERS, config._NUM_INSTRUCTIONS, 1),
+    )
+    masked_tokens = masked_token_ids.to_tensor(
+        default_value=0,
+        shape=(config._NUM_REGISTERS, config._NUM_INSTRUCTIONS),
+    )
+    return masked_tokens, opcodes, weights
+
+  return fn
+
+
+def _roll_experience(seq_ex, *, shift=-1):
+  def _roll(atom):
+    return tf.roll(atom, shift=shift, axis=0)
+
+  def _cutoff(atom):
+    return atom[:shift]
+
+  seq_ex_roll = tf.nest.map_structure(_roll, seq_ex)
+  seq_ex_roll = tf.nest.map_structure(_cutoff, seq_ex_roll)
+  seq_ex = tf.nest.map_structure(_cutoff, seq_ex)
+  return seq_ex, seq_ex_roll
+
+
+def _split_sequence_example(seq_ex, seq_ex_roll):
+  action_name = 'index_to_evict'
+  obs = {k: seq_ex[k] for k in seq_ex if k != action_name}
+  action = seq_ex[action_name]
+  obs_roll = {k: seq_ex_roll[k] for k in seq_ex_roll if k != action_name}
+  action_roll = seq_ex_roll[action_name]
+  return {
+      'obs': obs,
+      'action': action,
+      'obs_roll': obs_roll,
+      'action_roll': action_roll,
+  }
+
+
+def _get_state_preprocessing_layer(
+    regalloc_input_spec, preprocessing_layer_creator
+):
+  preprocessing_layers = {}
+  for name, spec in regalloc_input_spec.items():
+    preprocessing_layers[name] = preprocessing_layer_creator(spec)
+
+  def _preprocessing_layer(seq_ex):
+    pp = []
+    for layer_name, layer in preprocessing_layers.items():
+      pp.append(layer(seq_ex[layer_name]))
+    return tf.concat(pp, axis=-1)
+
+  def _layer(obs_dict):
+    obs_dict['obs'] = obs_dict['obs']
+    obs_dict['obs_cur'] = _preprocessing_layer(obs_dict['obs'])
+    obs_dict['obs_roll'] = _preprocessing_layer(obs_dict['obs_roll'])
+    return obs_dict
+
+  return _layer
+
+
+def _get_to_inputs_and_labels_fn():
+  masked_language_fn = _get_masked_language_fn()
+
+  def fn(obs_dict):
+    inputs = {'obs': obs_dict['obs'], 'action': obs_dict['action']}
+    mlm_input, mlm_label, mlm_weight = masked_language_fn(
+        obs_dict['obs']['lr_use_def_opcode'][:, : config._NUM_INSTRUCTIONS]
+    )
+    inputs['lr_use_def_opcode'] = mlm_input
+
+    labels = {
+        config._STATE_KEY: obs_dict['obs_cur'],
+        config._ACTION_KEY: tf.expand_dims(obs_dict['action'], axis=-1),
+        config._NEXT_STATE_KEY: obs_dict['obs_roll'],
+        config._NEXT_ACTION_KEY: tf.expand_dims(
+            obs_dict['action_roll'], axis=-1
+        ),
+        config._MLM_KEY: mlm_label,
+    }
+    mask = obs_dict['obs']['mask']
+    sample_weights = {
+        config._STATE_KEY: mask,
+        config._ACTION_KEY: mask,
+        config._NEXT_STATE_KEY: None,
+        config._NEXT_ACTION_KEY: None,
+        config._MLM_KEY: mlm_weight,
+    }
+    return (inputs, labels, sample_weights)
+
+  return fn
+
+
+def process_dataset(
+    dataset, regalloc_input_spec, regalloc_preprocessing_layer_creator
+):
+  shuffle_buffer_size = 128
+  num_map_threads = 128
+  state_preprocessing_layer = _get_state_preprocessing_layer(
+      regalloc_input_spec, regalloc_preprocessing_layer_creator
+  )
+  to_inputs_and_labels_fn = _get_to_inputs_and_labels_fn()
+  return (
+      dataset.map(_roll_experience, num_parallel_calls=num_map_threads)
+      .map(_split_sequence_example, num_parallel_calls=num_map_threads)
+      .map(state_preprocessing_layer, num_parallel_calls=num_map_threads)
+      .unbatch()
+      .shuffle(shuffle_buffer_size)
+      .map(to_inputs_and_labels_fn, num_parallel_calls=num_map_threads)
+  )
diff --git a/compiler_opt/rl/regalloc/lr_encoder/gin_configs/common.gin b/compiler_opt/rl/regalloc/lr_encoder/gin_configs/common.gin
new file mode 100644
index 0000000..c0fcb35
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/gin_configs/common.gin
@@ -0,0 +1,14 @@
+config_registry.get_configuration.implementation=@configs.LREncoderConfig
+
+launcher_path=None
+clang_path=None
+
+runners.LREncoderRunner.clang_path=%clang_path
+runners.LREncoderRunner.launcher_path=%launcher_path
+
+problem_config.flags_to_add.add_flags=()
+problem_config.flags_to_delete.delete_flags=('-split-dwarf-file','-split-dwarf-output',)
+# For AFDO profile reinjection set:
+# problem_config.flags_to_replace.replace_flags={'-fprofile-sample-use':'/path/to/gwp.afdo','-fprofile-remapping-file':'/path/to/prof_remap.txt'}
+problem_config.flags_to_replace.replace_flags={'-fprofile-instrument-use-path': '/tmp/corpus_1_21/MergedCS.profdata'}
+#problem_config.flags_to_replace.replace_flags={'-fprofile-instrument-use-path': 'MergedCS_profdata_jacobhegna/MergedCS_profdata'}
diff --git a/compiler_opt/rl/regalloc/lr_encoder/lr_encoder_runner.py b/compiler_opt/rl/regalloc/lr_encoder/lr_encoder_runner.py
new file mode 100644
index 0000000..3c90edc
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/lr_encoder_runner.py
@@ -0,0 +1,94 @@
+"""Module for collecting data for the LR encoder."""
+
+import os
+import tempfile
+from typing import Dict, Tuple
+
+import gin
+import tensorflow as tf
+
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import compilation_runner
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import corpus
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import log_reader
+
+
+@gin.configurable(module='runners')
+class LREncoderRunner(compilation_runner.CompilationRunner):
+  """Class for collecting data for the LR encoder."""
+
+  def compile_fn(
+      self,
+      command_line: corpus.FullyQualifiedCmdLine,
+      tf_policy_path: str,
+      reward_only: bool,
+  ) -> Dict[str, Tuple[tf.train.SequenceExample, float]]:
+    """Run compilation for the given IR file under the given policy.
+
+    Args:
+      command_line: the fully qualified command line.
+      tf_policy_path: path to TF policy directory on local disk.
+      reward_only: whether to only return the reward.
+
+    Returns:
+      A dict mapping from example identifier to tuple containing:
+        sequence_example: A tf.SequenceExample proto describing compilation
+          trace, None if reward_only == True.
+        reward: reward of register allocation.
+
+    Raises:
+      subprocess.CalledProcessError: if process fails.
+      compilation_runner.ProcessKilledError: (which it must pass through) on
+        cancelled work.
+      RuntimeError: if llvm-size produces unexpected output.
+    """
+    assert not tf_policy_path
+
+    working_dir = tempfile.mkdtemp()
+
+    log_path = os.path.join(working_dir, 'log')
+    output_native_path = os.path.join(working_dir, 'native')
+
+    result = {}
+    try:
+      cmdline = []
+      if self._launcher_path:
+        cmdline.append(self._launcher_path)
+      cmdline.extend(
+          [self._clang_path]
+          + list(command_line)
+          + [
+              '-mllvm',
+              '-regalloc-enable-advisor=development',
+              '-mllvm',
+              '-regalloc-lr-encoder-training-log=' + log_path,
+              '-mllvm',
+              '-regalloc-training-log=/dev/null',
+              '-o',
+              output_native_path,
+          ]
+      )
+
+      compilation_runner.start_cancellable_process(
+          cmdline, self._compilation_timeout, self._cancellation_manager
+      )
+
+      if not os.path.exists(log_path):
+        return {}
+
+      # TODO(#202)
+      log_result = log_reader.read_log_as_sequence_examples(log_path)
+
+      for fct_name, trajectory in log_result.items():
+        if not trajectory.HasField('feature_lists'):
+          continue
+        # score = (
+        #    trajectory.feature_lists.feature_list['reward']
+        #    .feature[-1]
+        #    .float_list.value[0]
+        # )
+        result[fct_name] = (trajectory, 1.0)
+
+    finally:
+      tf.io.gfile.rmtree(working_dir)
+
+    return result
diff --git a/compiler_opt/rl/regalloc/lr_encoder/model.py b/compiler_opt/rl/regalloc/lr_encoder/model.py
new file mode 100644
index 0000000..a953aa4
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/model.py
@@ -0,0 +1,114 @@
+import gin
+import tensorflow as tf
+from tf_agents.utils import nest_utils
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import config
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import attention
+
+
+class MultiHeadEncoderModel(tf.keras.Model):
+
+  def __init__(self, encoder, heads):
+    super().__init__(name='MultiHeadEncoderModel')
+
+    self._encoder = encoder
+    self._heads = heads
+
+  def get_encoder(self):
+    return self._encoder
+
+  def call(self, inputs, *args):
+    observation = inputs['obs']
+    action = tf.one_hot(inputs['action'], depth=config._NUM_REGISTERS)[
+        :, :, tf.newaxis
+    ]
+
+    use_def_obs = {
+        k: v
+        for k, v in observation.items()
+        if k.startswith(config._ENCODER_FEATURE_PREFIX)
+    }
+    encoded_state_per_token, encoded_state = self._encoder(use_def_obs)
+    encoded_state_with_action = tf.concat([encoded_state, action], axis=-1)
+
+    def get_input(head_name):
+      if head_name == 'mlm':
+        return encoded_state_per_token
+      if head_name.startswith('next_'):
+        return encoded_state_with_action
+      return encoded_state
+
+    outputs = {}
+    for name in self._heads:
+      head = self._heads[name]
+      head_input = get_input(name)
+      outputs[name] = head(head_input)
+
+    return outputs
+
+
+@gin.configurable
+class LiveRangeEncoder(tf.keras.Model):
+
+  def __init__(
+      self,
+      input_specs,
+      preprocessing_layer_creator,
+      *,
+      num_layers=2,
+      num_heads=2,
+      model_dim=32,
+      fcn_dim=128,
+      num_extra_features=10,
+  ):
+    super().__init__(name='LiveRangeEncoder')
+    self._encoder = attention.TransformerClassifier(
+        num_tokens=config._OPCODE_VOCAB_SIZE,
+        num_layers=num_layers,
+        num_heads=num_heads,
+        model_dim=model_dim,
+        fcn_dim=fcn_dim,
+        num_extra_features=num_extra_features,
+    )
+    self._linear_reshape = tf.keras.layers.Dense(config._ENCODING_SIZE)
+
+    self._preprocessing_layers = {}
+    for key in input_specs:
+      self._preprocessing_layers[key] = preprocessing_layer_creator(key)
+
+    def preprocessor(observations):
+      pp_inputs = []
+      for key in input_specs:
+        pp_inputs.append(self._preprocessing_layers[key](observations[key]))
+      return tf.concat(pp_inputs, axis=-1)
+
+    self._preprocessor = preprocessor
+
+  def call(self, observations):
+    extra_hidden_state = self._preprocessor(observations)
+    tokens = observations[config._OPCODE_KEY]
+    encoded_state_per_token, encoded_state = self._encoder(
+        tokens, extra_hidden_state
+    )
+    encoded_state = self._linear_reshape(encoded_state)
+    return encoded_state_per_token, encoded_state
+
+
+def create_model(
+    input_specs: dict,
+    output_specs: dict,
+    preprocessing_layer_creator: dict,
+    *,
+    mlp_width: int = 128,
+):
+  encoder = LiveRangeEncoder(input_specs, preprocessing_layer_creator)
+
+  output_heads = {}
+  for name, spec in output_specs.items():
+    size = spec.shape[-1]
+    output_heads[name] = tf.keras.Sequential([
+        tf.keras.layers.Dense(mlp_width),
+        tf.keras.layers.ReLU(),
+        tf.keras.layers.Dense(size),
+    ])
+
+  return MultiHeadEncoderModel(encoder=encoder, heads=output_heads)
diff --git a/compiler_opt/rl/regalloc/lr_encoder/train.py b/compiler_opt/rl/regalloc/lr_encoder/train.py
new file mode 100644
index 0000000..ca0b628
--- /dev/null
+++ b/compiler_opt/rl/regalloc/lr_encoder/train.py
@@ -0,0 +1,183 @@
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+
+import gin
+import tensorflow as tf
+
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import policy_saver
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc import config as regalloc_config
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import dataset_ops
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import model
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc.lr_encoder import config as encoder_config
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import registry  # pylint: disable=unused-import
+
+
+# Command-line flags for the live-range encoder training script.
+flags.DEFINE_multi_string(
+    'gin_files', [], 'List of paths to gin configuration files.'
+)
+
+# `trace` is handed to the dataset creator as the input file pattern;
+# `root_dir` is the TensorBoard log directory and the parent of the 'policy'
+# and 'encoder' output directories. Both default to None and carry no help
+# text.
+flags.DEFINE_string('trace', None, '')
+flags.DEFINE_string('root_dir', None, '')
+
+# When set, training runs under a TPUStrategy connected to this address.
+flags.DEFINE_string('tpu', None, 'BNS address for the TPU')
+
+FLAGS = flags.FLAGS
+
+
+def _create_parser_fn(time_step_spec, action_spec):
+  """Build a function that parses one serialized sequence example.
+
+  Args:
+    time_step_spec: mapping whose values are tensor specs; each spec becomes
+      one fixed-length sequence feature keyed by the spec's name.
+    action_spec: tensor spec of the action, parsed as one more sequence
+      feature.
+
+  Returns:
+    A callable mapping a serialized proto to the dict of parsed sequence
+    features (the context part of the example is discarded).
+  """
+
+  def _sequence_feature(spec):
+    return tf.io.FixedLenSequenceFeature(shape=spec.shape, dtype=spec.dtype)
+
+  feature_schema = {}
+  for spec in time_step_spec.values():
+    feature_schema[spec.name] = _sequence_feature(spec)
+  feature_schema[action_spec.name] = _sequence_feature(action_spec)
+
+  # No context features are requested; only sequence features are parsed.
+  no_context = {}
+
+  def _parser_fn(serialized_proto):
+    with tf.name_scope('parse'):
+      _, sequence_part = tf.io.parse_single_sequence_example(
+          serialized_proto,
+          context_features=no_context,
+          sequence_features=feature_schema,
+      )
+      return sequence_part
+
+  return _parser_fn
+
+
+def get_dataset_creator_fn(
+    batch_size,
+    input_spec,
+    action_spec,
+    regalloc_input_spec,
+    regalloc_preprocessing_layer_creator,
+):
+  """Return a function that turns a file pattern into a batched dataset.
+
+  Args:
+    batch_size: number of elements per batch; incomplete batches are dropped.
+    input_spec: mapping of tensor specs describing the recorded features.
+    action_spec: tensor spec of the recorded action.
+    regalloc_input_spec: forwarded to `dataset_ops.process_dataset`.
+    regalloc_preprocessing_layer_creator: forwarded to
+      `dataset_ops.process_dataset`.
+
+  Returns:
+    A callable taking a data path (file pattern) and returning a prefetched,
+    batched `tf.data.Dataset`.
+  """
+  parse_example = _create_parser_fn(input_spec, action_spec)
+
+  def _file_dataset_fn(data_path):
+    # Shuffle the file list, then read 64 files at a time, one record each.
+    file_names = tf.data.Dataset.list_files(data_path).shuffle(64)
+    records = file_names.interleave(
+        tf.data.TFRecordDataset,
+        cycle_length=64,
+        block_length=1,
+    )
+    # Skip empty records before parsing.
+    non_empty_records = records.filter(
+        lambda string: tf.strings.length(string) > 0
+    )
+    parsed = non_empty_records.map(parse_example, num_parallel_calls=128)
+    processed = dataset_ops.process_dataset(
+        parsed, regalloc_input_spec, regalloc_preprocessing_layer_creator
+    )
+    batched = processed.batch(batch_size, drop_remainder=True)
+    return batched.prefetch(tf.data.experimental.AUTOTUNE)
+
+  return _file_dataset_fn
+
+
+def get_strategy():
+  """Pick the distribution strategy based on the `--tpu` flag.
+
+  Returns:
+    A `tf.distribute.TPUStrategy` connected to the address in `FLAGS.tpu`
+    when that flag is set; otherwise whatever `tf.distribute.get_strategy()`
+    returns.
+  """
+  if not FLAGS.tpu:
+    logging.info('Using CPU strategy.')
+    return tf.distribute.get_strategy()
+
+  logging.info('Using TPU strategy.')
+  tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
+      tpu=FLAGS.tpu
+  )
+  # Connect to the cluster and initialize the TPU system before building the
+  # strategy.
+  tf.config.experimental_connect_to_cluster(tpu_resolver)
+  tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
+  return tf.distribute.TPUStrategy(tpu_resolver)
+
+
+class SaveEncoderCallback(tf.keras.callbacks.Callback):
+  """Keras callback that exports the encoder at the end of every epoch.
+
+  Each epoch gets its own directory `<path>/epoch<N>` holding the encoder as
+  saved by its `save` method, plus a converted model written by
+  `policy_saver.convert_saved_model`.
+  """
+
+  def __init__(self, encoder, path):
+    """Initialize the callback.
+
+    Args:
+      encoder: the model to export; must provide a `save(path)` method.
+      path: directory under which the per-epoch exports are written.
+    """
+    # Let the Keras base class initialize its own state first.
+    super().__init__()
+    self._encoder = encoder
+    self._path = path
+
+  def on_epoch_end(self, epoch, logs=None):
+    """Save the encoder for `epoch` and convert the saved model.
+
+    Args:
+      epoch: index of the epoch that just finished; used in the directory
+        name.
+      logs: unused.
+    """
+    sm_path = os.path.join(self._path, f'epoch{epoch}')
+    tflite_path = os.path.join(sm_path, policy_saver.TFLITE_MODEL_NAME)
+    self._encoder.save(sm_path)
+    policy_saver.convert_saved_model(sm_path, tflite_path)
+
+
+def main(_):
+  """Train the live-range encoder model and export the encoder every epoch.
+
+  Reads the gin configuration, builds the dataset function and the model,
+  compiles the model inside the selected distribution strategy scope, and
+  runs `fit` for 8 epochs with TensorBoard, checkpoint and encoder-export
+  callbacks.
+
+  Args:
+    _: unused positional arguments passed by `app.run`.
+  """
+  gin.parse_config_files_and_bindings(
+      FLAGS.gin_files, bindings=None, skip_unknown=False
+  )
+  logging.info(gin.config_str())
+
+  # NOTE(review): `lr_encoding_spec` is not used anywhere below.
+  lr_input_specs, lr_encoding_spec = (
+      encoder_config.get_lr_encoder_signature_spec()
+  )
+
+  regalloc_time_step_spec, regalloc_action_spec = (
+      regalloc_config.get_regalloc_signature_spec()
+  )
+  regalloc_preprocessing_creator = (
+      regalloc_config.get_observation_processing_layer_creator()
+  )
+
+  dataset_fn = get_dataset_creator_fn(
+      batch_size=64,
+      input_spec=encoder_config.get_input_specs(),
+      action_spec=regalloc_action_spec,
+      regalloc_input_spec=regalloc_time_step_spec.observation,
+      regalloc_preprocessing_layer_creator=regalloc_preprocessing_creator,
+  )
+
+  # The model is created and compiled inside the strategy scope.
+  strategy = get_strategy()
+  with strategy.scope():
+    logging.info('Creating model.')
+    lr_model = model.create_model(
+        lr_input_specs,
+        encoder_config.get_output_specs(regalloc_preprocessing_creator),
+        encoder_config.get_preprocessing_layer_creator(),
+    )
+
+    logging.info('Compiling model.')
+    loss, loss_weights = encoder_config.get_loss()
+    lr_model.compile(
+        optimizer=tf.keras.optimizers.Adam(global_clipnorm=1.0),
+        loss=loss,
+        loss_weights=loss_weights,
+        metrics=encoder_config.get_metrics(),
+    )
+
+  logging.info('Creating dataset.')
+  # `--trace` is the file pattern of the training records.
+  dataset = dataset_fn(FLAGS.trace)
+  tb = tf.keras.callbacks.TensorBoard(
+      log_dir=FLAGS.root_dir,
+      histogram_freq=0,
+      write_graph=True,
+      write_steps_per_second=True,
+      update_freq='batch',
+  )
+
+  # Full-model checkpoints go to <root_dir>/policy; an integer `save_freq` is
+  # counted in batches per the Keras ModelCheckpoint documentation.
+  policy_dir = os.path.join(FLAGS.root_dir, 'policy')
+  checkpoint = tf.keras.callbacks.ModelCheckpoint(
+      filepath=policy_dir, save_weights_only=False, save_freq=1024
+  )
+
+  # The encoder alone is exported to <root_dir>/encoder/epoch<N> by the
+  # callback at the end of each epoch; nothing is written at this point.
+  encoder_dir = os.path.join(FLAGS.root_dir, 'encoder')
+  logging.info('Saving the encoder to %s', encoder_dir)
+  encoder_saver = SaveEncoderCallback(
+      encoder=lr_model.get_encoder(), path=encoder_dir
+  )
+
+  logging.info('Starting training.')
+  lr_model.fit(dataset, epochs=8, callbacks=[tb, checkpoint, encoder_saver])
+
+  logging.info('Training complete.')
+
+
+# Script entry point: absl parses the command-line flags, then calls main.
+if __name__ == '__main__':
+  app.run(main)
diff --git a/compiler_opt/rl/train_semisupervised.py b/compiler_opt/rl/train_semisupervised.py
new file mode 100644
index 0000000..209f2f9
--- /dev/null
+++ b/compiler_opt/rl/train_semisupervised.py
@@ -0,0 +1,490 @@
+from absl import app
+import tensorflow as tf
+import tensorflow_text as text
+
+import numpy as np
+
+import os
+from typing import Optional, List
+
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import feature_ops
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import registry
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import attention
+from google3.third_party.ml_compiler_opt.compiler_opt.rl.regalloc import regalloc_network
+from google3.third_party.ml_compiler_opt.compiler_opt.rl import policy_saver
+from absl import flags
+from absl import logging
+import gin
+# from google3.third_party.ml_compiler_opt.compiler_opt.rl import policy_saver
+
+
+# Design notes:
+#   - Have one class for the encoder.
+#   - Have one class for the full semi-supervised model.
+#   - Only save the encoder.
+#   - Output heads:
+#       - reward
+#       - next state
+
+# Command-line flags for the semi-supervised training script.
+flags.DEFINE_multi_string(
+    'gin_files', [], 'List of paths to gin configuration files.'
+)
+
+# `trace` is handed to the dataset function as the input file pattern;
+# `root_dir` is the TensorBoard log directory and the parent of the 'policy'
+# and 'encoder' output directories. Both default to None and carry no help
+# text.
+flags.DEFINE_string('trace', None, '')
+flags.DEFINE_string('root_dir', None, '')
+
+# When set, training runs under a TPUStrategy connected to this address.
+flags.DEFINE_string('tpu', None, 'BNS address for the TPU')
+
+
+FLAGS = flags.FLAGS
+
+# Names of the model outputs; the same keys index the labels, sample weights,
+# losses, loss weights and metrics.
+_ACTION_KEY = 'action'
+_CUR_STATE_KEY = 'cur_state'
+_NEXT_STATE_KEY = 'next_state'
+_NEXT_ACTION_KEY = 'next_action'
+_MLM_KEY = 'mlm'
+
+# Number of opcode token columns kept when building the masked-language-model
+# inputs and weights.
+_INSTR_COUNT = 64
+
+# Label value marking positions that carry no masked-language-model target.
+_MLM_IGNORE_TOKEN = -1
+# Token id used as the [MASK] token: the last id of an 18000-entry vocabulary
+# (18000 is also the vocab size given to the mask chooser and the MLM head).
+_MLM_MASK_TOKEN = 18000 - 1
+
+
+class Model(tf.keras.Model):
+  """Encoder network with five prediction heads for semi-supervised training.
+
+  The heads predict the current state, the current action, the next state,
+  the next action, and the masked-language-model logits per token.
+  """
+
+  def __init__(
+      self,
+      encoder_network,
+      state_head,
+      action_head,
+      next_state_head,
+      next_action_head,
+      mlm_head,
+  ):
+    """Store the encoder and the output heads.
+
+    Args:
+      encoder_network: network returning `(per_token_encoding, encoding)`.
+      state_head: head applied to the encoding to predict the current state.
+      action_head: head applied to the encoding to predict the action.
+      next_state_head: head applied to the encoding concatenated with the
+        one-hot action.
+      next_action_head: head applied to the encoding concatenated with the
+        one-hot action.
+      mlm_head: head applied to the per-token encoding.
+    """
+    super().__init__(name='Model')
+    self._encoder_network = encoder_network
+    self._state_head = state_head
+    self._action_head = action_head
+    self._next_state_head = next_state_head
+    self._next_action_head = next_action_head
+    self._mlm_head = mlm_head
+
+  def call(self, inputs):
+    """Run the encoder and every head.
+
+    Args:
+      inputs: dict with an 'obs' dict of observations and an 'action' tensor
+        of action indices.
+
+    Returns:
+      A dict keyed by the output-name constants with one prediction per head.
+    """
+    # TODO:
+    # 1) Export the non-reduced tensor, since the MLM outputs need it.
+    # 2) When reducing, mask the tensors that don't matter. For masking, the
+    #    relevant outputs should probably be set to zero and the labels
+    #    zeroed out as well.
+    # 3) Figure out the correct masks when training; this might need a custom
+    #    training loop or custom loss functions.
+    action_one_hot = tf.one_hot(inputs['action'], depth=33)[:, :, tf.newaxis]
+
+    # Only the 'use_def_*' observations are fed to the encoder.
+    encoder_inputs = {}
+    for key, value in inputs['obs'].items():
+      if key.startswith('use_def_'):
+        encoder_inputs[key] = value
+    per_token_encoding, encoding = self._encoder_network(encoder_inputs)
+
+    encoding_and_action = tf.concat([encoding, action_one_hot], axis=-1)
+
+    # Heads are invoked in the same order as the keys below.
+    return {
+        _CUR_STATE_KEY: self._state_head(encoding),
+        _ACTION_KEY: self._action_head(encoding),
+        _NEXT_STATE_KEY: self._next_state_head(encoding_and_action),
+        _NEXT_ACTION_KEY: self._next_action_head(encoding_and_action),
+        _MLM_KEY: self._mlm_head(per_token_encoding),
+    }
+
+
+def action_loss(label, pred):
+  """Placeholder for the action loss; prints `label` and returns None.
+
+  The intended implementation should use the current mask.
+  """
+  del pred  # Not used by the placeholder.
+  print(label)
+  return None
+
+
+def state_loss(label, pred):
+  """Placeholder for the state loss; prints `label` and returns None.
+
+  Still to be decided: whether to use the intersection of the masks or just
+  the current mask (probably the current mask).
+  """
+  del pred  # Not used by the placeholder.
+  print(label)
+  return None
+
+
+def masked_language_loss(label, pred):
+  """Sparse categorical cross-entropy averaged over positions with label != 0.
+
+  Args:
+    label: integer class ids; positions whose label equals 0 are excluded
+      from the average.
+    pred: logits whose last axis indexes the classes.
+
+  Returns:
+    A scalar: the sum of the per-position losses at the included positions
+    divided by the number of included positions, or 0 when no position is
+    included.
+  """
+  # NOTE(review): positions are excluded where label == 0, while this file
+  # also defines _MLM_IGNORE_TOKEN = -1 -- confirm which value marks ignored
+  # positions.
+  # The Keras class is spelled `SparseCategoricalCrossentropy`.
+  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
+      from_logits=True, reduction='none'
+  )
+  loss = loss_object(label, pred)
+
+  mask = tf.cast(label != 0, dtype=loss.dtype)
+  loss *= mask
+
+  # divide_no_nan yields 0 instead of NaN when every position is excluded.
+  return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))
+
+
+def create_model(encoder_network):
+  """Wrap `encoder_network` in a `Model` with freshly built output heads.
+
+  The action and next-action heads end in a single unit, the state and
+  next-state heads end in 65 units, and the MLM head is a single dense layer
+  with 18000 units.
+
+  Args:
+    encoder_network: the encoder shared by all heads.
+
+  Returns:
+    The assembled `Model`.
+  """
+
+  def _mlp_head(output_units):
+    return tf.keras.Sequential([
+        tf.keras.layers.Dense(128),
+        tf.keras.layers.ReLU(),
+        tf.keras.layers.Dense(output_units),
+    ])
+
+  # Heads are created in a fixed order: action, state, next state, next
+  # action, MLM.
+  action_mlp = _mlp_head(1)
+  state_mlp = _mlp_head(65)
+  next_state_mlp = _mlp_head(65)
+  next_action_mlp = _mlp_head(1)
+  mlm_projection = tf.keras.Sequential([
+      tf.keras.layers.Dense(18000),
+  ])
+  return Model(
+      encoder_network=encoder_network,
+      action_head=action_mlp,
+      state_head=state_mlp,
+      next_state_head=next_state_mlp,
+      next_action_head=next_action_mlp,
+      mlm_head=mlm_projection,
+  )
+
+
+def _create_parser_fn(time_step_spec, action_spec):
+  """Build a function that parses one serialized sequence example.
+
+  Args:
+    time_step_spec: object whose `observation` mapping holds tensor specs;
+      each spec becomes one fixed-length sequence feature keyed by its name.
+    action_spec: tensor spec of the action, parsed as one more sequence
+      feature.
+
+  Returns:
+    A callable mapping a serialized proto to the dict of parsed sequence
+    features (the context part of the example is discarded).
+  """
+
+  def _sequence_feature(spec):
+    return tf.io.FixedLenSequenceFeature(shape=spec.shape, dtype=spec.dtype)
+
+  feature_schema = {}
+  for spec in time_step_spec.observation.values():
+    feature_schema[spec.name] = _sequence_feature(spec)
+  feature_schema[action_spec.name] = _sequence_feature(action_spec)
+
+  # No context features are requested; only sequence features are parsed.
+  no_context = {}
+
+  def _parser_fn(serialized_proto):
+    with tf.name_scope('parse'):
+      _, sequence_part = tf.io.parse_single_sequence_example(
+          serialized_proto,
+          context_features=no_context,
+          sequence_features=feature_schema,
+      )
+      return sequence_part
+
+  return _parser_fn
+
+
+def _get_masked_input_and_labels(encoded_texts):
+  """Apply BERT-style random masking to a tensor of token ids.
+
+  Adapted from https://keras.io/examples/nlp/masked_language_modeling/.
+  Tensors cannot be modified through item assignment, so every update is
+  expressed with `tf.where`.
+
+  Args:
+    encoded_texts: integer tensor of token ids; ids <= 0 are never selected.
+
+  Returns:
+    A tuple `(encoded_texts_masked, y_labels, sample_weights)`:
+      encoded_texts_masked: the input with selected tokens replaced by
+        `_MLM_MASK_TOKEN` or by a random token id.
+      y_labels: a copy of the unmodified input tokens.
+      sample_weights: float32 tensor, 1 at selected positions and 0 elsewhere.
+  """
+  encoded_texts = tf.convert_to_tensor(encoded_texts)
+  token_dtype = encoded_texts.dtype
+  encoded_texts_shape = tf.shape(encoded_texts)
+
+  # Select 15% of the positions, skipping ids <= 0.
+  inp_mask = tf.logical_and(
+      tf.random.uniform(encoded_texts_shape) < 0.15, encoded_texts > 0
+  )
+
+  # Replace 90% of the selected tokens with [MASK]; the remaining 10% of the
+  # selected tokens are left unchanged.
+  inp_mask_2mask = tf.logical_and(
+      inp_mask, tf.random.uniform(encoded_texts_shape) < 0.90
+  )
+  encoded_texts_masked = tf.where(
+      inp_mask_2mask, tf.cast(_MLM_MASK_TOKEN, token_dtype), encoded_texts
+  )
+
+  # Turn 1/9 of the [MASK] positions (10% of the selection) into a random
+  # token id drawn from [3, _MLM_MASK_TOKEN).
+  inp_mask_2random = tf.logical_and(
+      inp_mask_2mask, tf.random.uniform(encoded_texts_shape) < 1 / 9
+  )
+  random_tokens = tf.random.uniform(
+      encoded_texts_shape, minval=3, maxval=_MLM_MASK_TOKEN, dtype=token_dtype
+  )
+  encoded_texts_masked = tf.where(
+      inp_mask_2random, random_tokens, encoded_texts_masked
+  )
+
+  # Sample weights for .fit(): only the selected positions contribute.
+  sample_weights = tf.cast(inp_mask, tf.float32)
+
+  # The labels are the original input tokens.
+  y_labels = tf.identity(encoded_texts)
+
+  return encoded_texts_masked, y_labels, sample_weights
+
+
+# Cap passed to the item selector; also the padded width of the selected
+# positions in get_masked_input_and_labels.
+_MAX_PREDICTIONS_PER_BATCH = 32
+# Selector handed to text.mask_language_model: selection rate 0.2, and token
+# id 0 is listed as unselectable.
+random_selector = text.RandomItemSelector(
+    max_selections_per_batch=_MAX_PREDICTIONS_PER_BATCH,
+    selection_rate=0.2,
+    unselectable_ids=[0],
+)
+# Chooser handed to text.mask_language_model, built from the vocabulary size
+# (18000), the mask token id and a rate of 0.8 -- presumably the mask-token
+# rate; confirm against the tensorflow_text documentation.
+mask_values_chooser = text.MaskValuesChooser(18000, _MLM_MASK_TOKEN, 0.8)
+
+
+def get_masked_input_and_labels(encoded_texts):
+  """Mask opcode tokens for masked-language-model training.
+
+  Args:
+    encoded_texts: dense token tensor; 0 is treated as padding when it is
+      converted to a ragged tensor. The hard-coded shapes below assume 33
+      rows and `_INSTR_COUNT` columns -- TODO confirm against the caller.
+
+  Returns:
+    A tuple `(masked_tokens, labels, weights)`: the masked tokens densified
+    to shape (33, _INSTR_COUNT), the unmodified `encoded_texts`, and a
+    (33, _INSTR_COUNT, 1) tensor built by scattering 1.0 at every selected
+    position.
+  """
+  # NOTE(review): `masked_lm_ids` is not used below.
+  masked_token_ids, masked_pos, masked_lm_ids = text.mask_language_model(
+      tf.RaggedTensor.from_tensor(encoded_texts, padding=0),
+      item_selector=random_selector,
+      mask_values_chooser=mask_values_chooser,
+  )
+  # Need to fix this.
+  # Produce a tensor 0 to 32, tile it to produce a tensor of indices, then
+  # flatten. Then it can be used in the scatter as intended.
+  # Pad the ragged positions to a dense (33, 32) tensor; -1 marks padding.
+  masked_pos = masked_pos.to_tensor(
+      default_value=-1, shape=(33, _MAX_PREDICTIONS_PER_BATCH)
+  )
+
+  # Row index of every entry, repeated across the prediction axis.
+  ii = tf.tile(
+      tf.range(33, dtype=tf.int64)[:, tf.newaxis],
+      [1, _MAX_PREDICTIONS_PER_BATCH],
+  )
+  # Pair each row index with its position to form (row, column) indices.
+  masked_pos = tf.stack([ii, masked_pos], axis=-1)
+  # Padded entries scatter 0.0, real entries scatter 1.0.
+  scatter_values = tf.where(masked_pos[:, :, 1] < 0, 0.0, 1.0)
+  # Redirect padded entries to column 0 so the index is valid; their scatter
+  # value is 0.0.
+  masked_pos = tf.where(
+      masked_pos < 0, tf.constant(0, dtype=tf.int64), masked_pos
+  )
+  weights = tf.scatter_nd(
+      masked_pos, scatter_values[:, :, tf.newaxis], (33, _INSTR_COUNT, 1)
+  )
+  return (
+      masked_token_ids.to_tensor(default_value=0, shape=(33, _INSTR_COUNT)),
+      encoded_texts,
+      weights,
+  )
+
+
+def create_dataset_fn(
+    time_step_spec, action_spec, batch_size, shift, preprocessing_layer_creator
+):
+  """Return a function that builds the semi-supervised training dataset.
+
+  Args:
+    time_step_spec: spec whose `observation` mapping describes the recorded
+      features.
+    action_spec: tensor spec of the recorded action.
+    batch_size: number of elements per batch; incomplete batches are dropped.
+    shift: negative offset used to pair each step with a later step of the
+      same sequence.
+    preprocessing_layer_creator: callable applied to every observation spec
+      to create its preprocessing layer.
+
+  Returns:
+    A callable taking a data path (file pattern) and returning a dataset of
+    `(inputs, labels, sample_weights)` batches.
+  """
+  assert shift < 0
+  files_buffer_size = 100
+  num_readers = 10
+  num_map_threads = 8
+  shuffle_buffer_size = 256
+
+  parser_fn = _create_parser_fn(time_step_spec, action_spec)
+
+  def _roll_experience(seq_ex):
+    """Pair a sequence with a copy of itself shifted by `shift` steps."""
+    # Use tf agents nest map fn here
+    def _roll(atom):
+      return tf.roll(atom, shift=shift, axis=0)
+
+    def _cutoff(atom):
+      # `shift` is negative, so this drops the last |shift| steps.
+      return atom[:shift]
+
+    seq_ex_roll = tf.nest.map_structure(_roll, seq_ex)
+    seq_ex_roll = tf.nest.map_structure(_cutoff, seq_ex_roll)
+    seq_ex = tf.nest.map_structure(_cutoff, seq_ex)
+    return seq_ex, seq_ex_roll
+
+  # One preprocessing layer per observation spec.
+  preprocessing_layers = tf.nest.map_structure(
+      preprocessing_layer_creator, time_step_spec.observation
+  )
+
+  def split_experience(seq_ex, seq_ex_roll):
+    """Separate the action from the observations in both sequences."""
+    obs = {k: seq_ex[k] for k in seq_ex if k != action_spec.name}
+    action = seq_ex[action_spec.name]
+    obs_roll = {k: seq_ex_roll[k] for k in seq_ex_roll if k != action_spec.name}
+    action_roll = seq_ex_roll[action_spec.name]
+    return {
+        'obs': obs,
+        'action': action,
+        'obs_roll': obs_roll,
+        'action_roll': action_roll,
+    }
+
+  def _preprocessing_layer(seq_ex):
+    """Apply every preprocessing layer to its entry of `seq_ex`, in place."""
+    for layer_name, layer in preprocessing_layers.items():
+      seq_ex[layer_name] = layer(seq_ex[layer_name])
+    return seq_ex
+
+  def preprocess_experience(obs_dict):
+    """Add the preprocessed, concatenated current and shifted observations."""
+    # Work on a copy so the raw observations under 'obs' stay available.
+    obs_dict['obs_cur'] = _preprocessing_layer(obs_dict['obs'].copy())
+    # Features whose key starts with 'use_def_' are left out of the targets.
+    obs_dict['obs_cur'] = tf.concat(
+        [
+            v
+            for k, v in obs_dict['obs_cur'].items()
+            if not k.startswith('use_def_')
+        ],
+        axis=-1,
+    )
+    # The shifted observations are preprocessed in place and then replaced by
+    # their concatenation.
+    obs_dict['obs_roll'] = _preprocessing_layer(obs_dict['obs_roll'])
+    obs_dict['obs_roll'] = tf.concat(
+        [
+            v
+            for k, v in obs_dict['obs_roll'].items()
+            if not k.startswith('use_def_')
+        ],
+        axis=-1,
+    )
+    return obs_dict
+
+  def to_inputs_and_labels(obs_dict):
+    """Convert one element into `(inputs, labels, sample_weights)`."""
+    inputs = {'obs': obs_dict['obs'], 'action': obs_dict['action']}
+    mlm_input, mlm_label, mlm_weight = get_masked_input_and_labels(
+        obs_dict['obs']['use_def_opcode'][:, :_INSTR_COUNT]
+    )
+    # NOTE(review): Model.call in this file reads only inputs['obs'] and
+    # inputs['action'], so this masked copy does not appear to reach the
+    # encoder -- confirm the intended wiring.
+    inputs['use_def_opcode'] = mlm_input
+
+    labels = {
+        _CUR_STATE_KEY: obs_dict['obs_cur'],
+        _ACTION_KEY: tf.expand_dims(obs_dict['action'], axis=-1),
+        _NEXT_STATE_KEY: obs_dict['obs_roll'],
+        _NEXT_ACTION_KEY: tf.expand_dims(obs_dict['action_roll'], axis=-1),
+        _MLM_KEY: mlm_label,
+    }
+    mask = obs_dict['obs']['mask']
+    # The current-step outputs are weighted by the observation's 'mask'
+    # entry; the next-step outputs get no sample weight.
+    sample_weights = {
+        _CUR_STATE_KEY: mask,
+        _ACTION_KEY: mask,
+        _NEXT_STATE_KEY: None,
+        _NEXT_ACTION_KEY: None,
+        _MLM_KEY: mlm_weight,
+    }
+    return (inputs, labels, sample_weights)
+
+  def _file_dataset_fn(data_path):
+    """Build the dataset from the files matching `data_path`."""
+    return (
+        tf.data.Dataset.list_files(data_path)
+        .shuffle(files_buffer_size)
+        .interleave(
+            tf.data.TFRecordDataset,
+            cycle_length=num_readers,
+            block_length=1,
+        )
+        .filter(lambda string: tf.strings.length(string) > 0)
+        .map(parser_fn, num_parallel_calls=num_map_threads)
+        .map(_roll_experience, num_parallel_calls=num_map_threads)
+        .map(split_experience, num_parallel_calls=num_map_threads)
+        .map(preprocess_experience, num_parallel_calls=num_map_threads)
+        .unbatch()
+        .shuffle(shuffle_buffer_size)
+        .map(to_inputs_and_labels, num_parallel_calls=num_map_threads)
+        .batch(batch_size, drop_remainder=True)
+        .prefetch(tf.data.experimental.AUTOTUNE)
+    )
+
+  return _file_dataset_fn
+
+
+def get_strategy():
+  """Pick the distribution strategy based on the `--tpu` flag.
+
+  Returns:
+    A `tf.distribute.TPUStrategy` connected to the address in `FLAGS.tpu`
+    when that flag is set; otherwise whatever `tf.distribute.get_strategy()`
+    returns.
+  """
+  if not FLAGS.tpu:
+    logging.info('Using CPU strategy.')
+    return tf.distribute.get_strategy()
+
+  logging.info('Using TPU strategy.')
+  tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
+      tpu=FLAGS.tpu
+  )
+  # Connect to the cluster and initialize the TPU system before building the
+  # strategy.
+  tf.config.experimental_connect_to_cluster(tpu_resolver)
+  tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
+  return tf.distribute.TPUStrategy(tpu_resolver)
+
+
+class SaveEncoderCallback(tf.keras.callbacks.Callback):
+  """Keras callback that exports the encoder at the end of every epoch.
+
+  Each epoch gets its own directory `<path>/epoch<N>` holding the encoder as
+  saved by its `save` method, plus a converted model written by
+  `policy_saver.convert_saved_model`.
+  """
+
+  def __init__(self, encoder, path):
+    """Initialize the callback.
+
+    Args:
+      encoder: the model to export; must provide a `save(path)` method.
+      path: directory under which the per-epoch exports are written.
+    """
+    # Let the Keras base class initialize its own state first.
+    super().__init__()
+    self._encoder = encoder
+    self._path = path
+
+  def on_epoch_end(self, epoch, logs=None):
+    """Save the encoder for `epoch` and convert the saved model.
+
+    Args:
+      epoch: index of the epoch that just finished; used in the directory
+        name.
+      logs: unused.
+    """
+    sm_path = os.path.join(self._path, f'epoch{epoch}')
+    tflite_path = os.path.join(sm_path, policy_saver.TFLITE_MODEL_NAME)
+    self._encoder.save(sm_path)
+    policy_saver.convert_saved_model(sm_path, tflite_path)
+
+
+def main(_):
+  """Train the semi-supervised model and export the encoder every epoch.
+
+  Reads the gin configuration, builds the dataset function and the model,
+  compiles the model inside the selected distribution strategy scope, and
+  runs `fit` for 8 epochs with TensorBoard, checkpoint and encoder-export
+  callbacks.
+
+  Args:
+    _: unused positional arguments passed by `app.run`.
+  """
+  gin.parse_config_files_and_bindings(
+      FLAGS.gin_files, bindings=None, skip_unknown=False
+  )
+  logging.info(gin.config_str())
+
+  problem_config = registry.get_configuration()
+  time_step_spec, action_spec = problem_config.get_signature_spec()
+  preprocessing_layer_creator = problem_config.get_preprocessing_layer_creator()
+  # shift=-1 pairs every step with the step that follows it.
+  dataset_fn = create_dataset_fn(
+      time_step_spec,
+      action_spec,
+      batch_size=256,
+      shift=-1,
+      preprocessing_layer_creator=preprocessing_layer_creator,
+  )
+
+  # inputs, labels, weights = next(iter(dataset))
+  # import sys
+  # sys.exit()
+
+  # The encoder and the model are created and compiled inside the strategy
+  # scope.
+  strategy = get_strategy()
+  with strategy.scope():
+    logging.info('Creating model.')
+    encoder_network = regalloc_network.InstructionEncoderNetwork(
+        preprocessing_layer_creator
+    )
+    model = create_model(encoder_network)
+
+    logging.info('Compiling model.')
+
+    opt = tf.keras.optimizers.Adam(global_clipnorm=1.0)
+
+    # One loss per output key: MSE for the state outputs, sparse categorical
+    # cross-entropy on logits for the action and MLM outputs.
+    # NOTE(review): `loss_weights` has no entry for _CUR_STATE_KEY -- confirm
+    # that leaving it at the Keras default is intended.
+    model.compile(
+        optimizer=opt,
+        loss={
+            _CUR_STATE_KEY: 'mse',
+            _ACTION_KEY: tf.keras.losses.SparseCategoricalCrossentropy(
+                from_logits=True
+            ),
+            _NEXT_STATE_KEY: 'mse',
+            _NEXT_ACTION_KEY: tf.keras.losses.SparseCategoricalCrossentropy(
+                from_logits=True
+            ),
+            _MLM_KEY: tf.keras.losses.SparseCategoricalCrossentropy(
+                from_logits=True
+            ),
+        },
+        loss_weights={
+            _ACTION_KEY: 1,
+            _NEXT_STATE_KEY: 1,
+            _NEXT_ACTION_KEY: 1,
+            _MLM_KEY: 10,
+        },
+        metrics={
+            _CUR_STATE_KEY: [],
+            _ACTION_KEY: ['accuracy'],
+            _NEXT_STATE_KEY: [],
+            _NEXT_ACTION_KEY: ['accuracy'],
+            _MLM_KEY: ['accuracy'],
+        },
+    )
+
+  logging.info('Creating dataset.')
+  # `--trace` is the file pattern of the training records.
+  dataset = dataset_fn(FLAGS.trace)
+  tb = tf.keras.callbacks.TensorBoard(
+      log_dir=FLAGS.root_dir,
+      histogram_freq=0,
+      write_graph=True,
+      write_steps_per_second=True,
+      update_freq='batch',
+  )
+
+  # Full-model checkpoints go to <root_dir>/policy; an integer `save_freq` is
+  # counted in batches per the Keras ModelCheckpoint documentation.
+  policy_dir = os.path.join(FLAGS.root_dir, 'policy')
+  checkpoint = tf.keras.callbacks.ModelCheckpoint(
+      filepath=policy_dir, save_weights_only=False, save_freq=1024
+  )
+
+  # The encoder alone is exported to <root_dir>/encoder/epoch<N> by the
+  # callback at the end of each epoch; nothing is written at this point.
+  encoder_dir = os.path.join(FLAGS.root_dir, 'encoder')
+  logging.info('Saving the encoder to %s', encoder_dir)
+  encoder_saver = SaveEncoderCallback(encoder=encoder_network, path=encoder_dir)
+
+  logging.info('Starting training.')
+  model.fit(dataset, epochs=8, callbacks=[tb, checkpoint, encoder_saver])
+
+  logging.info('Training complete.')
+
+
+# Script entry point: absl parses the command-line flags, then calls main.
+if __name__ == '__main__':
+  app.run(main)