Add MLGO environments. (#228)
Add MLGO environment abstractions.
This commit also contains an implementation of the inlining-for-size
environment.
diff --git a/compiler_opt/rl/env.py b/compiler_opt/rl/env.py
new file mode 100644
index 0000000..15246bf
--- /dev/null
+++ b/compiler_opt/rl/env.py
@@ -0,0 +1,362 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Gymlike environment definition for MLGO."""
+
+from __future__ import annotations
+
+import math
+import subprocess
+import abc
+import contextlib
+import io
+import os
+import tempfile
+from typing import Any, Generator, List, Optional, Tuple, Type
+
+import numpy as np
+
+from compiler_opt.rl import corpus
+from compiler_opt.rl import log_reader
+
+OBS_T = Any
+
+OBS_KEY = 'obs'
+REWARD_KEY = 'reward'
+SCORE_POLICY_KEY = 'score_policy'
+SCORE_DEFAULT_KEY = 'score_default'
+CONTEXT_KEY = 'context'
+MODULE_NAME_KEY = 'module_name'
+OBS_ID_KEY = 'obs_id'
+STEP_TYPE_KEY = 'step_type'
+
+FIRST_STEP_STR = 'first'
+MID_STEP_STR = 'mid'
+LAST_STEP_STR = 'last'
+
+_TERMINAL_OBS = {
+ OBS_KEY: {},
+ REWARD_KEY: 0.0,
+ SCORE_POLICY_KEY: 0.0,
+ SCORE_DEFAULT_KEY: 0.0,
+ CONTEXT_KEY: '',
+ MODULE_NAME_KEY: '',
+ OBS_ID_KEY: -1,
+ STEP_TYPE_KEY: LAST_STEP_STR,
+}
+
+_INTERACTIVE_PIPE_FILE_BASE = 'interactive-pipe-base'
+
+
+class MLGOTask(metaclass=abc.ABCMeta):
+ """Abstract base class for MLGO Tasks.
+
+ A Task is an learning problem in LLVM, for example:
+ - inlining-for-size
+ - inlining-for-speed
+ - register allocation (for speed)
+
+ The Task type for a given problem defines how to build and score modules for
+ the problem, both interactively and non-interactively.
+ """
+
+ @abc.abstractmethod
+ def get_cmdline(self, clang_path: str, base_args: List[str],
+ interactive_base_path: Optional[str],
+ working_dir: str) -> List[str]:
+ """Get the cmdline for building with this task.
+
+ The resulting list[str] should be able to be passed to subprocess.run to
+ execute clang.
+
+ Args:
+ clang_path: path to the clang executable.
+ base_args: base arguments for building the module. Generally, these flags
+ should not be modified and simply added to the result.
+ interactive_base_path: the path to the interactive pipe base. if None,
+ then don't run clang interactively.
+ working_dir: directory where all artifacts from compilation should be
+ written. This will be a temp directory whose lifetime is managed outside
+ of the Task.
+
+ Returns:
+ The constructed command line.
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_module_scores(self, working_dir: str) -> dict[str, float]:
+ """Get the scores for each context in the module.
+
+ This method should not be aware of whether the module was built with the
+ default heuristic or a ML policy.
+
+ Args:
+ working_dir: Directory which was passed as working_dir to get_cmdline.
+ Used to recover binaries/artifacts from the build
+
+ Returns:
+ A dictionary mapping [context name] -> [score].
+ """
+ pass
+
+
+class ClangProcess:
+ """Simple wrapper class around a clang process.
+
+ This is used wrap both the clang process and the method to return the scores
+ associated to the default-compiled binary.
+ """
+
+ def __init__(self, proc, get_scores_fn, module_name):
+ self._proc = proc
+ self._get_scores_fn = get_scores_fn
+ self._module_name = module_name
+
+ def get_scores(self, timeout: Optional[int] = None):
+ self._proc.wait(timeout=timeout)
+ return self._get_scores_fn()
+
+
+class InteractiveClang(ClangProcess):
+ """Wrapper around clang's interactive mode."""
+
+ def __init__(
+ self,
+ proc,
+ get_scores_fn,
+ module_name: str,
+ reader_pipe: io.BufferedReader,
+ writer_pipe: io.BufferedWriter,
+ ):
+ super().__init__(proc, get_scores_fn, module_name)
+ self._reader_pipe = reader_pipe
+ self._writer_pipe = writer_pipe
+ self._obs_gen = log_reader.read_log_from_file(self._reader_pipe)
+
+ self._is_first_obs = True
+
+ self._terminal_obs = _TERMINAL_OBS
+ self._terminal_obs[MODULE_NAME_KEY] = module_name
+
+ def _running(self) -> bool:
+ return self._proc.poll() is None
+
+ def get_observation(self) -> OBS_T:
+ if not self._running():
+ return self._terminal_obs
+
+ def _get_step_type():
+ step_type = FIRST_STEP_STR if self._is_first_obs else MID_STEP_STR
+ self._is_first_obs = False
+ return step_type
+
+ try:
+ obs: log_reader.ObservationRecord = next(self._obs_gen)
+
+ tv_dict = {}
+ for fv in obs.feature_values:
+ array = fv.to_numpy()
+ tv_dict[fv.spec.name] = np.reshape(array, newshape=fv.spec.shape)
+ return {
+ OBS_KEY: tv_dict,
+ REWARD_KEY: obs.score if obs.score else 0.0,
+ SCORE_POLICY_KEY: 0.0,
+ SCORE_DEFAULT_KEY: 0.0,
+ CONTEXT_KEY: obs.context,
+ MODULE_NAME_KEY: self._module_name,
+ OBS_ID_KEY: obs.observation_id,
+ STEP_TYPE_KEY: _get_step_type(),
+ }
+ except StopIteration:
+ return self._terminal_obs
+
+ def send_action(self, action: np.ndarray) -> None:
+ assert self._running()
+ data = action.tobytes()
+ bytes_sent = self._writer_pipe.write(data)
+ # Here we use the fact that for common types, the np.dtype and ctype should
+ # behave the same
+ assert bytes_sent == action.dtype.itemsize * math.prod(action.shape)
+ try:
+ self._writer_pipe.flush()
+ except BrokenPipeError:
+ # The pipe can break after we send the last action
+ pass
+
+
+_EPS = 1e-4
+
+
+def compute_relative_rewards(score_a: dict[str, float],
+ score_b: dict[str, float]) -> dict[str, float]:
+
+ def _reward_fn(a: float, b: float) -> float:
+ return 1.0 - (a + _EPS) / (b + _EPS)
+
+ assert score_a.keys() == score_b.keys()
+ return {key: _reward_fn(score_a[key], score_b[key]) for key in score_a}
+
+
+@contextlib.contextmanager
+def clang_session(
+ clang_path: str,
+ module: corpus.LoadedModuleSpec,
+ task_type: Type[MLGOTask],
+ *,
+ interactive: bool,
+):
+ """Context manager for clang session.
+
+ We need to manage the context so resources like tempfiles and pipes have
+ their lifetimes managed appropriately.
+
+ Args:
+ clang_path: The clang binary to use for the InteractiveClang session.
+ module: The module to compile with clang.
+ task_type: Type of the MLGOTask to use.
+ interactive: Whether to use an interactive or default clang instance
+
+ Yields:
+ Either the constructed InteractiveClang or DefaultClang object.
+ """
+ with tempfile.TemporaryDirectory() as td:
+ task_working_dir = os.path.join(td, '__task_working_dir__')
+ os.mkdir(task_working_dir)
+ task = task_type()
+
+ base_args = list(module.build_command_line(td))
+ interactive_base = os.path.join(
+ td, _INTERACTIVE_PIPE_FILE_BASE) if interactive else None
+ cmdline = task.get_cmdline(clang_path, base_args, interactive_base,
+ task_working_dir)
+
+ def _get_scores() -> dict[str, float]:
+ return task.get_module_scores(task_working_dir)
+
+ writer_name = os.path.join(td, _INTERACTIVE_PIPE_FILE_BASE + '.in')
+ reader_name = os.path.join(td, _INTERACTIVE_PIPE_FILE_BASE + '.out')
+ if interactive:
+ os.mkfifo(reader_name, 0o666)
+ os.mkfifo(writer_name, 0o666)
+ with subprocess.Popen(
+ cmdline, stderr=subprocess.PIPE, stdout=subprocess.PIPE) as proc:
+ try:
+ if interactive:
+ with io.BufferedWriter(io.FileIO(writer_name, 'wb')) as writer_pipe:
+ with io.BufferedReader(io.FileIO(reader_name, 'rb')) as reader_pipe:
+ yield InteractiveClang(
+ proc,
+ _get_scores,
+ module.name,
+ reader_pipe,
+ writer_pipe,
+ )
+ else:
+ yield ClangProcess(
+ proc,
+ _get_scores,
+ module.name,
+ )
+
+ finally:
+ proc.kill()
+
+
+def _get_clang_generator(
+ clang_path: str,
+ task_type: Type[MLGOTask],
+) -> Generator[Optional[Tuple[ClangProcess, InteractiveClang]],
+ Optional[corpus.LoadedModuleSpec], None]:
+ """Returns a generator for creating InteractiveClang objects.
+
+ TODO: fix this docstring
+
+ Args:
+ clang_path: Path to the clang binary to use within InteractiveClang.
+ task_type: Type of the MLGO task to use.
+
+ Returns:
+ The generator for InteractiveClang objects.
+ """
+ while True:
+ # The following line should be type-hinted as follows:
+ # module: corpus.LoadedModuleSpec = yield
+ # However, this triggers a yapf crash. See:
+ # https://github.com/google/yapf/issues/1092
+ module = yield
+ with clang_session(
+ clang_path, module, task_type, interactive=True) as iclang:
+ with clang_session(
+ clang_path, module, task_type, interactive=False) as clang:
+ yield iclang, clang
+
+
+class MLGOEnvironmentBase:
+ """Base implementation for all MLGO environments.
+
+ Depending on the RL framework, one may want different implementations of an
+ enviroment (tf_agents: PyEnvironment, jax: dm-env, etc). This class
+ implements the core methods that are needed to then implement any of these
+ other environments as well.
+ """
+
+ def __init__(
+ self,
+ *,
+ clang_path: str,
+ task_type: Type[MLGOTask],
+ obs_spec,
+ action_spec,
+ ):
+ self._clang_generator = _get_clang_generator(clang_path, task_type)
+ self._obs_spec = obs_spec
+ self._action_spec = action_spec
+
+ self._iclang: Optional[InteractiveClang] = None
+ self._clang: Optional[ClangProcess] = None
+
+ @property
+ def obs_spec(self):
+ return self._obs_spec
+
+ @property
+ def action_spec(self):
+ return self._action_spec
+
+ def observation(self):
+ return self._last_obs
+
+ def _get_observation(self) -> OBS_T:
+ self._last_obs = self._iclang.get_observation()
+ if self._last_obs[STEP_TYPE_KEY] == 'last':
+ self._last_obs[SCORE_POLICY_KEY] = self._iclang.get_scores()
+ self._last_obs[SCORE_DEFAULT_KEY] = self._clang.get_scores()
+ self._last_obs[REWARD_KEY] = compute_relative_rewards(
+ self._last_obs[SCORE_POLICY_KEY], self._last_obs[SCORE_DEFAULT_KEY])
+ return self.observation()
+
+ def reset(self, module: corpus.LoadedModuleSpec):
+ # On the first call to reset(...), sending None starts the coroutine.
+ # On subsequent calls, this resumes execution after
+ # yielding the clang pair, which terminates the session pauses execution in
+ # the coroutine where it awaits a module
+ self._clang_generator.send(None)
+ # pytype: disable=attribute-error
+ self._iclang, self._clang = self._clang_generator.send(module)
+ return self._get_observation()
+
+ def step(self, action: np.ndarray):
+ self._iclang.send_action(action)
+ return self._get_observation()
diff --git a/compiler_opt/rl/env_test.py b/compiler_opt/rl/env_test.py
new file mode 100644
index 0000000..1af508c
--- /dev/null
+++ b/compiler_opt/rl/env_test.py
@@ -0,0 +1,193 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for compiler_opt.rl.env."""
+
+import io
+import contextlib
+import ctypes
+from unittest import mock
+import subprocess
+
+from typing import Dict, List, Optional
+
+import tensorflow as tf
+import numpy as np
+
+from compiler_opt.rl import env
+from compiler_opt.rl import corpus
+from compiler_opt.rl import log_reader_test
+
+_CLANG_PATH = '/test/clang/path'
+
+_MOCK_MODULE = corpus.LoadedModuleSpec(
+ name='module',
+ loaded_ir=b'asdf',
+ orig_options=('--opt_a', 'a', '--opt_b', 'b'),
+)
+
+_NUM_STEPS = 10
+
+
+class MockTask(env.MLGOTask):
+ """Implementation of mock task for testing."""
+
+ def get_cmdline(self, clang_path: str, base_args: List[str],
+ interactive_base_path: Optional[str],
+ working_dir: str) -> List[str]:
+ if interactive_base_path:
+ interactive_args = [
+ f'--interactive={interactive_base_path}',
+ ]
+ else:
+ interactive_args = []
+ return [clang_path] + base_args + interactive_args
+
+ def get_module_scores(self, working_dir: str) -> Dict[str, float]:
+ return {'default': 47}
+
+
+# This mocks subprocess.Popen for interactive clang sessions
+@contextlib.contextmanager
+def mock_interactive_clang(cmdline, stderr, stdout):
+ del stderr
+ del stdout
+ # do basic argument parsing
+ fname = None
+ for arg in cmdline:
+ if arg.startswith('--interactive='):
+ fname = arg[len('--interactive='):]
+ break
+
+ class MockProcess:
+
+ def wait(self, timeout):
+ pass
+
+ def kill(self):
+ pass
+
+ if not fname:
+ yield MockProcess()
+ return
+ # Create the fds for the pipes
+ # (the env doesn't create the files, it assumes they are opened by clang)
+ with io.FileIO(fname + '.out', 'wb+') as f_out:
+ with io.FileIO(fname + '.in', 'rb+') as f_in:
+ del f_in
+ # Write the header describing the features/rewards
+ f_out.write(
+ log_reader_test.json_to_bytes({
+ 'features': [{
+ 'name': 'times_called',
+ 'port': 0,
+ 'shape': [1],
+ 'type': 'int64_t',
+ },],
+ 'score': {
+ 'name': 'reward',
+ 'port': 0,
+ 'shape': [1],
+ 'type': 'float',
+ },
+ }))
+ log_reader_test.write_nl(f_out)
+
+ class MockInteractiveProcess(MockProcess):
+ """Mock clang interactive process that writes the log."""
+
+ def __init__(self):
+ self._counter = 0
+
+ # We poll the process at every call to get_observation to ensure the
+ # clang process is still alive. So here, each time poll() is called,
+ # write a new context
+ def poll(self):
+ if self._counter >= _NUM_STEPS:
+ f_out.close()
+ return None
+ log_reader_test.write_context_marker(f_out,
+ f'context_{self._counter}')
+ log_reader_test.write_observation_marker(f_out, 0)
+ log_reader_test.write_buff(f_out, [self._counter], ctypes.c_int64)
+ log_reader_test.write_nl(f_out)
+ log_reader_test.write_outcome_marker(f_out, 0)
+ log_reader_test.write_buff(f_out, [3.14], ctypes.c_float)
+ log_reader_test.write_nl(f_out)
+ self._counter += 1
+ return None
+
+ yield MockInteractiveProcess()
+
+
+class ClangSessionTest(tf.test.TestCase):
+
+ @mock.patch('subprocess.Popen')
+ def test_clang_session(self, mock_popen):
+ mock_task = MockTask()
+ with env.clang_session(
+ _CLANG_PATH, _MOCK_MODULE, MockTask,
+ interactive=False) as clang_session:
+ del clang_session
+ cmdline = mock_task.get_cmdline(_CLANG_PATH,
+ list(_MOCK_MODULE.orig_options), None,
+ '/tmp/mock/tmp/file')
+ mock_popen.assert_called_once_with(
+ cmdline, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
+
+ @mock.patch('subprocess.Popen')
+ def test_interactive_clang_session(self, mock_popen):
+ mock_popen.side_effect = mock_interactive_clang
+
+ with env.clang_session(
+ _CLANG_PATH, _MOCK_MODULE, MockTask, interactive=True) as clang_session:
+ for idx in range(_NUM_STEPS):
+ obs = clang_session.get_observation()
+ self.assertEqual(
+ obs[env.OBS_KEY]['times_called'],
+ np.array([idx], dtype=np.int64),
+ )
+ self.assertEqual(obs[env.CONTEXT_KEY], f'context_{idx}')
+ mock_popen.assert_called_once()
+
+
+class MLGOEnvironmentTest(tf.test.TestCase):
+
+ @mock.patch('subprocess.Popen')
+ def test_env(self, mock_popen):
+ mock_popen.side_effect = mock_interactive_clang
+
+ test_env = env.MLGOEnvironmentBase(
+ clang_path=_CLANG_PATH,
+ task_type=MockTask,
+ obs_spec={},
+ action_spec={},
+ )
+
+ for env_itr in range(3):
+ del env_itr
+ step = test_env.reset(_MOCK_MODULE)
+ self.assertEqual(step[env.STEP_TYPE_KEY], env.FIRST_STEP_STR)
+
+ for step_itr in range(_NUM_STEPS - 1):
+ del step_itr
+ step = test_env.step(np.array([1], dtype=np.int64))
+ self.assertEqual(step[env.STEP_TYPE_KEY], env.MID_STEP_STR)
+
+ step = test_env.step(np.array([1], dtype=np.int64))
+ self.assertEqual(step[env.STEP_TYPE_KEY], env.LAST_STEP_STR)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/compiler_opt/rl/inlining/__init__.py b/compiler_opt/rl/inlining/__init__.py
index 2ed744f..f7be2ca 100644
--- a/compiler_opt/rl/inlining/__init__.py
+++ b/compiler_opt/rl/inlining/__init__.py
@@ -16,15 +16,20 @@
import gin
+from compiler_opt.rl import env
from compiler_opt.rl import problem_configuration
from compiler_opt.rl.inlining import config
from compiler_opt.rl.inlining import inlining_runner
+from compiler_opt.rl.inlining import env as inlining_env
@gin.register(module='configs')
class InliningConfig(problem_configuration.ProblemConfiguration):
"""Expose the regalloc eviction components."""
+ def get_env(self) -> env.MLGOEnvironmentBase:
+ return inlining_env.get_inlining_env()
+
def get_runner_type(self):
return inlining_runner.InliningRunner
diff --git a/compiler_opt/rl/inlining/env.py b/compiler_opt/rl/inlining/env.py
new file mode 100644
index 0000000..fc33a28
--- /dev/null
+++ b/compiler_opt/rl/inlining/env.py
@@ -0,0 +1,80 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Implementation of the inlining for size environment."""
+
+import subprocess
+
+import gin
+import os
+
+from compiler_opt.rl import env
+from compiler_opt.rl.inlining import config
+
+from typing import Dict, List, Optional
+
+_COMPILED_MODULE_NAME = 'compiled_module'
+
+
+@gin.configurable
+class InliningForSizeTask(env.MLGOTask):
+ """Implementation of the inlining-for-size MLGOTask."""
+
+ def __init__(self, llvm_size_path: str):
+ super().__init__()
+ self._llvm_size_path = llvm_size_path
+
+ def get_cmdline(self, clang_path: str, base_args: List[str],
+ interactive_base_path: Optional[str],
+ working_dir: str) -> List[str]:
+ if interactive_base_path:
+ interactive_args = [
+ '-mllvm',
+ '-enable-ml-inliner=release',
+ '-mllvm',
+ f'-inliner-interactive-channel-base={interactive_base_path}',
+ #'-mllvm',
+ #'-inliner-interactive-include-default',
+ ]
+ else:
+ interactive_args = []
+ compiled_module_path = os.path.join(working_dir, _COMPILED_MODULE_NAME)
+ return [clang_path
+ ] + base_args + interactive_args + ['-o', compiled_module_path]
+
+ def get_module_scores(self, working_dir: str) -> Dict[str, float]:
+ compiled_module_path = os.path.join(working_dir, _COMPILED_MODULE_NAME)
+ cmdline = [self._llvm_size_path, compiled_module_path]
+ completed_proc = subprocess.run(cmdline, capture_output=True, check=True)
+ if not completed_proc.stdout:
+ raise RuntimeError(f'Empty llvm-size output: {" ".join(cmdline)}')
+ output = completed_proc.stdout.decode('utf-8')
+ tmp = output.split('\n')
+ if len(tmp) != 3:
+ raise RuntimeError(f'Wrong llvm-size output {output}')
+ tmp = tmp[1].split('\t')
+ native_size = int(tmp[0])
+ return {'default': native_size}
+
+
+@gin.configurable
+def get_inlining_env(clang_path: str) -> env.MLGOEnvironmentBase:
+ time_step_spec, action_spec = config.get_inlining_signature_spec()
+
+ return env.MLGOEnvironmentBase(
+ clang_path=clang_path,
+ task_type=InliningForSizeTask,
+ obs_spec=time_step_spec.observation,
+ action_spec=action_spec,
+ )
diff --git a/compiler_opt/rl/inlining/gin_configs/common.gin b/compiler_opt/rl/inlining/gin_configs/common.gin
index d2f9b96..62c18ea 100644
--- a/compiler_opt/rl/inlining/gin_configs/common.gin
+++ b/compiler_opt/rl/inlining/gin_configs/common.gin
@@ -1,3 +1,5 @@
+import compiler_opt.rl.inlining.env
+
config_registry.get_configuration.implementation=@configs.InliningConfig
launcher_path=None
@@ -8,6 +10,10 @@
runners.InliningRunner.clang_path=%clang_path
runners.InliningRunner.launcher_path=%launcher_path
+# Setup environment paths
+env.InliningForSizeTask.llvm_size_path=%llvm_size_path
+env.get_inlining_env.clang_path=%clang_path
+
problem_config.flags_to_add.add_flags=()
problem_config.flags_to_delete.delete_flags=('-split-dwarf-file','-split-dwarf-output',)
# For AFDO profile reinjection set:
diff --git a/compiler_opt/rl/problem_configuration.py b/compiler_opt/rl/problem_configuration.py
index a2240fe..203893d 100644
--- a/compiler_opt/rl/problem_configuration.py
+++ b/compiler_opt/rl/problem_configuration.py
@@ -77,6 +77,7 @@
# used for type annotation in a string (for 3.8 compat)
# pylint: disable=unused-import
from compiler_opt.rl import compilation_runner
+from compiler_opt.rl import env
types = tfa.typing.types
@@ -85,6 +86,10 @@
"""Abstraction of the APIs accessing a problem-specific configuration."""
@abc.abstractmethod
+ def get_env(self) -> env.MLGOEnvironmentBase:
+ raise NotImplementedError
+
+ @abc.abstractmethod
def get_signature_spec(
self) -> Tuple[types.NestedTensorSpec, types.NestedTensorSpec]:
raise NotImplementedError
diff --git a/compiler_opt/rl/regalloc/__init__.py b/compiler_opt/rl/regalloc/__init__.py
index ec684d1..5263a67 100644
--- a/compiler_opt/rl/regalloc/__init__.py
+++ b/compiler_opt/rl/regalloc/__init__.py
@@ -16,6 +16,7 @@
import gin
+from compiler_opt.rl import env
from compiler_opt.rl import problem_configuration
from compiler_opt.rl.regalloc import config
from compiler_opt.rl.regalloc import regalloc_runner
@@ -25,6 +26,9 @@
class RegallocEvictionConfig(problem_configuration.ProblemConfiguration):
"""Expose the regalloc eviction configuration."""
+ def get_env(self) -> env.MLGOEnvironmentBase:
+ raise NotImplementedError
+
def get_runner_type(self):
return regalloc_runner.RegAllocRunner