Add MLGO environments. (#228)

Add MLGO environment abstractions.

This commit also contains an implementation of the inlining-for-size
environment.
diff --git a/compiler_opt/rl/env.py b/compiler_opt/rl/env.py
new file mode 100644
index 0000000..15246bf
--- /dev/null
+++ b/compiler_opt/rl/env.py
@@ -0,0 +1,362 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Gymlike environment definition for MLGO."""
+
+from __future__ import annotations
+
+import math
+import subprocess
+import abc
+import contextlib
+import io
+import os
+import tempfile
+from typing import Any, Generator, List, Optional, Tuple, Type
+
+import numpy as np
+
+from compiler_opt.rl import corpus
+from compiler_opt.rl import log_reader
+
+OBS_T = Any
+
+OBS_KEY = 'obs'
+REWARD_KEY = 'reward'
+SCORE_POLICY_KEY = 'score_policy'
+SCORE_DEFAULT_KEY = 'score_default'
+CONTEXT_KEY = 'context'
+MODULE_NAME_KEY = 'module_name'
+OBS_ID_KEY = 'obs_id'
+STEP_TYPE_KEY = 'step_type'
+
+FIRST_STEP_STR = 'first'
+MID_STEP_STR = 'mid'
+LAST_STEP_STR = 'last'
+
+_TERMINAL_OBS = {
+    OBS_KEY: {},
+    REWARD_KEY: 0.0,
+    SCORE_POLICY_KEY: 0.0,
+    SCORE_DEFAULT_KEY: 0.0,
+    CONTEXT_KEY: '',
+    MODULE_NAME_KEY: '',
+    OBS_ID_KEY: -1,
+    STEP_TYPE_KEY: LAST_STEP_STR,
+}
+
+_INTERACTIVE_PIPE_FILE_BASE = 'interactive-pipe-base'
+
+
+class MLGOTask(metaclass=abc.ABCMeta):
+  """Abstract base class for MLGO Tasks.
+
+  A Task is a learning problem in LLVM, for example:
+   - inlining-for-size
+   - inlining-for-speed
+   - register allocation (for speed)
+
+  The Task type for a given problem defines how to build and score modules for
+  the problem, both interactively and non-interactively.
+  """
+
+  @abc.abstractmethod
+  def get_cmdline(self, clang_path: str, base_args: List[str],
+                  interactive_base_path: Optional[str],
+                  working_dir: str) -> List[str]:
+    """Get the cmdline for building with this task.
+
+    The resulting list[str] must be directly usable as the argument list of a
+    subprocess call that executes clang.
+
+    Args:
+      clang_path: path to the clang executable.
+      base_args: base arguments for building the module. Generally, these flags
+        should not be modified and should simply be added to the result.
+      interactive_base_path: the path to the interactive pipe base. If None,
+        then don't run clang interactively.
+      working_dir: directory where all artifacts from compilation should be
+        written. This will be a temp directory whose lifetime is managed outside
+        of the Task.
+
+    Returns:
+      The constructed command line.
+    """
+    pass
+
+  @abc.abstractmethod
+  def get_module_scores(self, working_dir: str) -> dict[str, float]:
+    """Get the scores for each context in the module.
+
+    This method should not be aware of whether the module was built with the
+    default heuristic or an ML policy.
+
+    Args:
+      working_dir: Directory which was passed as working_dir to get_cmdline.
+        Used to recover binaries/artifacts from the build.
+
+    Returns:
+      A dictionary mapping [context name] -> [score].
+    """
+    pass
+
+
+class ClangProcess:
+  """Simple wrapper class around a clang process.
+
+  This is used to wrap both the clang process and the callable that returns
+  the scores; get_scores() waits for the process to exit before calling it.
+  """
+
+  def __init__(self, proc, get_scores_fn, module_name):
+    self._proc = proc
+    self._get_scores_fn = get_scores_fn
+    self._module_name = module_name
+
+  def get_scores(self, timeout: Optional[int] = None):
+    self._proc.wait(timeout=timeout)
+    return self._get_scores_fn()
+
+
+class InteractiveClang(ClangProcess):
+  """Wrapper around clang's interactive mode."""
+
+  def __init__(
+      self,
+      proc,
+      get_scores_fn,
+      module_name: str,
+      reader_pipe: io.BufferedReader,
+      writer_pipe: io.BufferedWriter,
+  ):
+    super().__init__(proc, get_scores_fn, module_name)
+    self._reader_pipe = reader_pipe
+    self._writer_pipe = writer_pipe
+    self._obs_gen = log_reader.read_log_from_file(self._reader_pipe)
+    self._is_first_obs = True
+    # Build a private terminal observation: aliasing the module-level template
+    # would let one session's writes leak into every other session.
+    self._terminal_obs = {
+        **_TERMINAL_OBS, OBS_KEY: {}, MODULE_NAME_KEY: module_name
+    }
+
+  def _running(self) -> bool:
+    return self._proc.poll() is None
+
+  def get_observation(self) -> OBS_T:
+    """Returns the next observation, or the terminal one once clang is done."""
+    if not self._running():
+      return self._terminal_obs
+    try:
+      obs: log_reader.ObservationRecord = next(self._obs_gen)
+    except StopIteration:
+      # The log stream is exhausted: report the episode as finished.
+      return self._terminal_obs
+
+    step_type = FIRST_STEP_STR if self._is_first_obs else MID_STEP_STR
+    self._is_first_obs = False
+    tv_dict = {}
+    for fv in obs.feature_values:
+      # Pass the shape positionally: `newshape=` is deprecated in newer NumPy.
+      tv_dict[fv.spec.name] = np.reshape(fv.to_numpy(), fv.spec.shape)
+    return {
+        OBS_KEY: tv_dict,
+        REWARD_KEY: obs.score if obs.score else 0.0,
+        SCORE_POLICY_KEY: 0.0,
+        SCORE_DEFAULT_KEY: 0.0,
+        CONTEXT_KEY: obs.context,
+        MODULE_NAME_KEY: self._module_name,
+        OBS_ID_KEY: obs.observation_id,
+        STEP_TYPE_KEY: step_type,
+    }
+
+  def send_action(self, action: np.ndarray) -> None:
+    """Writes the raw bytes of `action` to the writer pipe and flushes it."""
+    assert self._running()
+    data = action.tobytes()
+    bytes_sent = self._writer_pipe.write(data)
+    # Here we use the fact that for common types, the np.dtype and ctype should
+    # behave the same
+    assert bytes_sent == action.dtype.itemsize * math.prod(action.shape)
+    try:
+      self._writer_pipe.flush()
+    except BrokenPipeError:
+      # The pipe can break after we send the last action
+      pass
+
+
+_EPS = 1e-4  # Smoothing term; avoids a division by zero for a zero score.
+
+
+def compute_relative_rewards(score_a: dict[str, float],
+                             score_b: dict[str, float]) -> dict[str, float]:
+  """Returns 1 - (a + eps) / (b + eps) per key; both key sets must be equal."""
+  assert score_a.keys() == score_b.keys()
+  rewards = {}
+  for name, a_score in score_a.items():
+    rewards[name] = 1.0 - (a_score + _EPS) / (score_b[name] + _EPS)
+  return rewards
+
+
+@contextlib.contextmanager
+def clang_session(
+    clang_path: str,
+    module: corpus.LoadedModuleSpec,
+    task_type: Type[MLGOTask],
+    *,
+    interactive: bool,
+):
+  """Context manager for clang session.
+
+  Owns the temp directory, named pipes and clang process of one compilation.
+
+  Args:
+    clang_path: The clang binary to use for the session.
+    module: The module to compile with clang.
+    task_type: Type of the MLGOTask to use; instantiated with no arguments.
+    interactive: Whether to use an interactive or default clang instance.
+
+  Yields:
+    An InteractiveClang if interactive is set, otherwise a ClangProcess.
+  """
+  with tempfile.TemporaryDirectory() as td:
+    task_working_dir = os.path.join(td, '__task_working_dir__')
+    os.mkdir(task_working_dir)
+    task = task_type()
+
+    base_args = list(module.build_command_line(td))
+    interactive_base = os.path.join(
+        td, _INTERACTIVE_PIPE_FILE_BASE) if interactive else None
+    cmdline = task.get_cmdline(clang_path, base_args, interactive_base,
+                               task_working_dir)
+
+    def _get_scores() -> dict[str, float]:
+      return task.get_module_scores(task_working_dir)
+
+    writer_name = os.path.join(td, _INTERACTIVE_PIPE_FILE_BASE + '.in')
+    reader_name = os.path.join(td, _INTERACTIVE_PIPE_FILE_BASE + '.out')
+    if interactive:
+      os.mkfifo(reader_name, 0o666)
+      os.mkfifo(writer_name, 0o666)
+    # NOTE(review): stdout/stderr are piped but never drained in this function.
+    with subprocess.Popen(
+        cmdline, stderr=subprocess.PIPE, stdout=subprocess.PIPE) as proc:
+      try:
+        if interactive:
+          with io.BufferedWriter(io.FileIO(writer_name, 'wb')) as writer_pipe:
+            with io.BufferedReader(io.FileIO(reader_name, 'rb')) as reader_pipe:
+              yield InteractiveClang(
+                  proc,
+                  _get_scores,
+                  module.name,
+                  reader_pipe,
+                  writer_pipe,
+              )
+        else:
+          yield ClangProcess(
+              proc,
+              _get_scores,
+              module.name,
+          )
+
+      finally:
+        proc.kill()
+
+
+def _get_clang_generator(
+    clang_path: str,
+    task_type: Type[MLGOTask],
+) -> Generator[Optional[Tuple[InteractiveClang, ClangProcess]],
+               Optional[corpus.LoadedModuleSpec], None]:
+  """Returns a generator that builds a pair of clang sessions per module.
+
+  Drive it with send(None) to reach the point awaiting a module, then
+  send(module) to obtain the pair; the next send(None) closes both sessions.
+
+  Args:
+    clang_path: Path to the clang binary used for both sessions.
+    task_type: Type of the MLGO task to use.
+
+  Yields:
+    None while awaiting a module, else (InteractiveClang, ClangProcess).
+  """
+  while True:
+    # This should be type-hinted as `module: corpus.LoadedModuleSpec = yield`,
+    # but that triggers a yapf crash. See:
+    #   https://github.com/google/yapf/issues/1092
+    module = yield
+    with clang_session(
+        clang_path, module, task_type, interactive=True) as iclang:
+      with clang_session(
+          clang_path, module, task_type, interactive=False) as clang:
+        yield iclang, clang
+
+
+class MLGOEnvironmentBase:
+  """Base implementation for all MLGO environments.
+
+  Depending on the RL framework, one may want different implementations of an
+  environment (tf_agents: PyEnvironment, jax: dm-env, etc). This class provides
+  the core methods (reset, step, observation) needed to implement any of them.
+  """
+
+  def __init__(
+      self,
+      *,
+      clang_path: str,
+      task_type: Type[MLGOTask],
+      obs_spec,
+      action_spec,
+  ):
+    self._clang_generator = _get_clang_generator(clang_path, task_type)
+    self._obs_spec = obs_spec
+    self._action_spec = action_spec
+    self._iclang: Optional[InteractiveClang] = None
+    self._clang: Optional[ClangProcess] = None
+    self._last_obs: Optional[OBS_T] = None
+
+  @property
+  def obs_spec(self):
+    return self._obs_spec
+
+  @property
+  def action_spec(self):
+    return self._action_spec
+
+  def observation(self):
+    return self._last_obs
+
+  def _get_observation(self) -> OBS_T:
+    # Copy: the scores below must not be written into the session's own dict.
+    self._last_obs = dict(self._iclang.get_observation())
+    if self._last_obs[STEP_TYPE_KEY] == LAST_STEP_STR:
+      self._last_obs[SCORE_POLICY_KEY] = self._iclang.get_scores()
+      self._last_obs[SCORE_DEFAULT_KEY] = self._clang.get_scores()
+      self._last_obs[REWARD_KEY] = compute_relative_rewards(
+          self._last_obs[SCORE_POLICY_KEY], self._last_obs[SCORE_DEFAULT_KEY])
+    return self.observation()
+
+  def reset(self, module: corpus.LoadedModuleSpec):
+    """Starts a new episode on `module` and returns its first observation."""
+    # The first send(None) starts the coroutine; later ones resume it after
+    # the yielded pair, closing the old sessions, until it awaits a module.
+    self._clang_generator.send(None)
+    # pytype: disable=attribute-error
+    self._iclang, self._clang = self._clang_generator.send(module)
+    return self._get_observation()
+
+  def step(self, action: np.ndarray):
+    """Sends `action` to clang and returns the resulting observation."""
+    self._iclang.send_action(action)
+    return self._get_observation()
diff --git a/compiler_opt/rl/env_test.py b/compiler_opt/rl/env_test.py
new file mode 100644
index 0000000..1af508c
--- /dev/null
+++ b/compiler_opt/rl/env_test.py
@@ -0,0 +1,193 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for compiler_opt.rl.env."""
+
+import io
+import contextlib
+import ctypes
+from unittest import mock
+import subprocess
+
+from typing import Dict, List, Optional
+
+import tensorflow as tf
+import numpy as np
+
+from compiler_opt.rl import env
+from compiler_opt.rl import corpus
+from compiler_opt.rl import log_reader_test
+
+_CLANG_PATH = '/test/clang/path'
+
+_MOCK_MODULE = corpus.LoadedModuleSpec(
+    name='module',
+    loaded_ir=b'asdf',
+    orig_options=('--opt_a', 'a', '--opt_b', 'b'),
+)
+
+_NUM_STEPS = 10
+
+
+class MockTask(env.MLGOTask):
+  """Minimal MLGOTask used by the tests: fixed score, pass-through cmdline."""
+
+  def get_cmdline(self, clang_path: str, base_args: List[str],
+                  interactive_base_path: Optional[str],
+                  working_dir: str) -> List[str]:
+    """Returns clang + base args, plus --interactive=<base> if a base is set."""
+    cmdline = [clang_path] + base_args
+    # The flag is only emitted for a truthy base path; mock_interactive_clang
+    # parses it back out of the command line.
+    if interactive_base_path:
+      cmdline.append(f'--interactive={interactive_base_path}')
+    return cmdline
+
+  def get_module_scores(self, working_dir: str) -> Dict[str, float]:
+    return {'default': 47}
+
+
+@contextlib.contextmanager
+def mock_interactive_clang(cmdline, stderr, stdout):
+  """Stand-in for subprocess.Popen that yields a mock (interactive) clang."""
+  del stderr
+  del stdout
+  # Look for the --interactive=<base> flag; without it the run is default.
+  fname = None
+  for arg in cmdline:
+    if arg.startswith('--interactive='):
+      fname = arg[len('--interactive='):]
+      break
+
+  class MockProcess:
+
+    def wait(self, timeout):
+      pass
+
+    def kill(self):
+      pass
+
+  if not fname:
+    yield MockProcess()
+    return
+  # Open both pipe files from the "clang" side: <base>.out is written by the
+  # mock, <base>.in is only opened and never read.
+  with io.FileIO(fname + '.out', 'wb+') as f_out:
+    with io.FileIO(fname + '.in', 'rb+') as f_in:
+      del f_in
+      # Write the log header describing the feature and score tensors.
+      f_out.write(
+          log_reader_test.json_to_bytes({
+              'features': [{
+                  'name': 'times_called',
+                  'port': 0,
+                  'shape': [1],
+                  'type': 'int64_t',
+              },],
+              'score': {
+                  'name': 'reward',
+                  'port': 0,
+                  'shape': [1],
+                  'type': 'float',
+              },
+          }))
+      log_reader_test.write_nl(f_out)
+
+      class MockInteractiveProcess(MockProcess):
+        """Mock clang interactive process that writes the log."""
+
+        def __init__(self):
+          self._counter = 0
+
+        # The env polls the process before reading each observation, so each
+        # poll() call writes one more context + observation + outcome record;
+        # after _NUM_STEPS records the output file is closed instead.
+        def poll(self):
+          if self._counter >= _NUM_STEPS:
+            f_out.close()
+            return None
+          log_reader_test.write_context_marker(f_out,
+                                               f'context_{self._counter}')
+          log_reader_test.write_observation_marker(f_out, 0)
+          log_reader_test.write_buff(f_out, [self._counter], ctypes.c_int64)
+          log_reader_test.write_nl(f_out)
+          log_reader_test.write_outcome_marker(f_out, 0)
+          log_reader_test.write_buff(f_out, [3.14], ctypes.c_float)
+          log_reader_test.write_nl(f_out)
+          self._counter += 1
+          return None
+
+      yield MockInteractiveProcess()
+
+
+class ClangSessionTest(tf.test.TestCase):
+  """Exercises env.clang_session against a patched subprocess.Popen."""
+
+  @mock.patch('subprocess.Popen')
+  def test_clang_session(self, mock_popen):
+    """A default session launches clang with the task's plain cmdline."""
+    with env.clang_session(
+        _CLANG_PATH, _MOCK_MODULE, MockTask, interactive=False):
+      expected_cmdline = MockTask().get_cmdline(
+          _CLANG_PATH, list(_MOCK_MODULE.orig_options), None,
+          '/tmp/mock/tmp/file')
+      mock_popen.assert_called_once_with(
+          expected_cmdline, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
+
+  @mock.patch('subprocess.Popen')
+  def test_interactive_clang_session(self, mock_popen):
+    """An interactive session streams the mock's observations in order."""
+    mock_popen.side_effect = mock_interactive_clang
+
+    with env.clang_session(
+        _CLANG_PATH, _MOCK_MODULE, MockTask, interactive=True) as session:
+      for idx in range(_NUM_STEPS):
+        obs = session.get_observation()
+        self.assertEqual(
+            obs[env.OBS_KEY]['times_called'],
+            np.array([idx], dtype=np.int64),
+        )
+        self.assertEqual(obs[env.CONTEXT_KEY], f'context_{idx}')
+      mock_popen.assert_called_once()
+
+
+class MLGOEnvironmentTest(tf.test.TestCase):
+  """Runs full episodes through MLGOEnvironmentBase with a mocked clang."""
+
+  @mock.patch('subprocess.Popen')
+  def test_env(self, mock_popen):
+    """Each episode yields one first, _NUM_STEPS - 1 mid, and one last step."""
+    mock_popen.side_effect = mock_interactive_clang
+    test_env = env.MLGOEnvironmentBase(
+        clang_path=_CLANG_PATH,
+        task_type=MockTask,
+        obs_spec={},
+        action_spec={},
+    )
+
+    for _ in range(3):
+      step = test_env.reset(_MOCK_MODULE)
+      self.assertEqual(step[env.STEP_TYPE_KEY], env.FIRST_STEP_STR)
+
+      for _ in range(_NUM_STEPS - 1):
+        step = test_env.step(np.array([1], dtype=np.int64))
+        self.assertEqual(step[env.STEP_TYPE_KEY], env.MID_STEP_STR)
+
+      # The mock stops producing records after _NUM_STEPS observations.
+      step = test_env.step(np.array([1], dtype=np.int64))
+      self.assertEqual(step[env.STEP_TYPE_KEY], env.LAST_STEP_STR)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/compiler_opt/rl/inlining/__init__.py b/compiler_opt/rl/inlining/__init__.py
index 2ed744f..f7be2ca 100644
--- a/compiler_opt/rl/inlining/__init__.py
+++ b/compiler_opt/rl/inlining/__init__.py
@@ -16,15 +16,20 @@
 
 import gin
 
+from compiler_opt.rl import env
 from compiler_opt.rl import problem_configuration
 from compiler_opt.rl.inlining import config
 from compiler_opt.rl.inlining import inlining_runner
+from compiler_opt.rl.inlining import env as inlining_env
 
 
 @gin.register(module='configs')
 class InliningConfig(problem_configuration.ProblemConfiguration):
   """Expose the regalloc eviction components."""
 
+  def get_env(self) -> env.MLGOEnvironmentBase:
+    return inlining_env.get_inlining_env()
+
   def get_runner_type(self):
     return inlining_runner.InliningRunner
 
diff --git a/compiler_opt/rl/inlining/env.py b/compiler_opt/rl/inlining/env.py
new file mode 100644
index 0000000..fc33a28
--- /dev/null
+++ b/compiler_opt/rl/inlining/env.py
@@ -0,0 +1,80 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Implementation of the inlining for size environment."""
+
+import subprocess
+
+import gin
+import os
+
+from compiler_opt.rl import env
+from compiler_opt.rl.inlining import config
+
+from typing import Dict, List, Optional
+
+_COMPILED_MODULE_NAME = 'compiled_module'
+
+
+@gin.configurable
+class InliningForSizeTask(env.MLGOTask):
+  """Implementation of the inlining-for-size MLGOTask."""
+
+  def __init__(self, llvm_size_path: str):
+    super().__init__()
+    self._llvm_size_path = llvm_size_path
+
+  def get_cmdline(self, clang_path: str, base_args: List[str],
+                  interactive_base_path: Optional[str],
+                  working_dir: str) -> List[str]:
+    """Builds the clang cmdline; adds ML-inliner flags when interactive."""
+    output_path = os.path.join(working_dir, _COMPILED_MODULE_NAME)
+    cmdline = [clang_path] + base_args
+    if interactive_base_path:
+      # '-inliner-interactive-include-default' is currently not passed.
+      cmdline += [
+          '-mllvm',
+          '-enable-ml-inliner=release',
+          '-mllvm',
+          f'-inliner-interactive-channel-base={interactive_base_path}',
+      ]
+    return cmdline + ['-o', output_path]
+
+  def get_module_scores(self, working_dir: str) -> Dict[str, float]:
+    """Runs llvm-size on the compiled module; returns {'default': size}."""
+    size_cmd = [
+        self._llvm_size_path,
+        os.path.join(working_dir, _COMPILED_MODULE_NAME),
+    ]
+    size_proc = subprocess.run(size_cmd, capture_output=True, check=True)
+    if not size_proc.stdout:
+      raise RuntimeError(f'Empty llvm-size output: {" ".join(size_cmd)}')
+    output = size_proc.stdout.decode('utf-8')
+    lines = output.split('\n')
+    if len(lines) != 3:
+      raise RuntimeError(f'Wrong llvm-size output {output}')
+    # The first tab-separated field of the second line is used as the size.
+    return {'default': int(lines[1].split('\t')[0])}
+
+
+@gin.configurable
+def get_inlining_env(clang_path: str) -> env.MLGOEnvironmentBase:
+  """Builds the inlining-for-size environment that compiles with clang_path."""
+  step_spec, act_spec = config.get_inlining_signature_spec()
+  return env.MLGOEnvironmentBase(
+      clang_path=clang_path,
+      task_type=InliningForSizeTask,
+      obs_spec=step_spec.observation,
+      action_spec=act_spec,
+  )
diff --git a/compiler_opt/rl/inlining/gin_configs/common.gin b/compiler_opt/rl/inlining/gin_configs/common.gin
index d2f9b96..62c18ea 100644
--- a/compiler_opt/rl/inlining/gin_configs/common.gin
+++ b/compiler_opt/rl/inlining/gin_configs/common.gin
@@ -1,3 +1,5 @@
+import compiler_opt.rl.inlining.env
+
 config_registry.get_configuration.implementation=@configs.InliningConfig
 
 launcher_path=None
@@ -8,6 +10,10 @@
 runners.InliningRunner.clang_path=%clang_path
 runners.InliningRunner.launcher_path=%launcher_path
 
+# Set up environment paths
+env.InliningForSizeTask.llvm_size_path=%llvm_size_path
+env.get_inlining_env.clang_path=%clang_path
+
 problem_config.flags_to_add.add_flags=()
 problem_config.flags_to_delete.delete_flags=('-split-dwarf-file','-split-dwarf-output',)
 # For AFDO profile reinjection set:
diff --git a/compiler_opt/rl/problem_configuration.py b/compiler_opt/rl/problem_configuration.py
index a2240fe..203893d 100644
--- a/compiler_opt/rl/problem_configuration.py
+++ b/compiler_opt/rl/problem_configuration.py
@@ -77,6 +77,7 @@
 # used for type annotation in a string (for 3.8 compat)
 # pylint: disable=unused-import
 from compiler_opt.rl import compilation_runner
+from compiler_opt.rl import env
 
 types = tfa.typing.types
 
@@ -85,6 +86,10 @@
   """Abstraction of the APIs accessing a problem-specific configuration."""
 
   @abc.abstractmethod
+  def get_env(self) -> env.MLGOEnvironmentBase:
+    raise NotImplementedError
+
+  @abc.abstractmethod
   def get_signature_spec(
       self) -> Tuple[types.NestedTensorSpec, types.NestedTensorSpec]:
     raise NotImplementedError
diff --git a/compiler_opt/rl/regalloc/__init__.py b/compiler_opt/rl/regalloc/__init__.py
index ec684d1..5263a67 100644
--- a/compiler_opt/rl/regalloc/__init__.py
+++ b/compiler_opt/rl/regalloc/__init__.py
@@ -16,6 +16,7 @@
 
 import gin
 
+from compiler_opt.rl import env
 from compiler_opt.rl import problem_configuration
 from compiler_opt.rl.regalloc import config
 from compiler_opt.rl.regalloc import regalloc_runner
@@ -25,6 +26,9 @@
 class RegallocEvictionConfig(problem_configuration.ProblemConfiguration):
   """Expose the regalloc eviction configuration."""
 
+  def get_env(self) -> env.MLGOEnvironmentBase:
+    raise NotImplementedError
+
   def get_runner_type(self):
     return regalloc_runner.RegAllocRunner