[env] Strong-typing for the observation (#299)
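
Replaces the untyped dict observation (OBS_T = Any plus a set of string
key constants) with a TimeStep dataclass, and the 'first'/'mid'/'last'
step-type strings with a StepType enum. Consumers access fields as
attributes and compare step types against enum members, so a typo in a
key name or step-type string now fails loudly (AttributeError or a
type-checker error) instead of silently mismatching.

A minimal sketch of the consuming side; my_env and loaded_module are
hypothetical stand-ins, not part of this change:

    import numpy as np
    from compiler_opt.rl import env

    step = my_env.reset(loaded_module)          # returns env.TimeStep
    while step.step_type != env.StepType.LAST:  # enum, not the 'last' string
      # step.obs maps feature names to np.ndarray values; a real caller
      # would derive the action from it. A constant action, as in the tests:
      action = np.array([1], dtype=np.int64)
      step = my_env.step(action)
    # score_policy, score_default and the relative reward are populated
    # on the final step only.
    print(step.reward, step.score_policy, step.score_default)
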
diff --git a/compiler_opt/rl/env.py b/compiler_opt/rl/env.py
index 15246bf..1bd96af 100644
--- a/compiler_opt/rl/env.py
+++ b/compiler_opt/rl/env.py
@@ -15,6 +15,8 @@
"""Gymlike environment definition for MLGO."""
from __future__ import annotations
+import dataclasses
+from enum import Enum
import math
import subprocess
@@ -23,38 +25,31 @@
import io
import os
import tempfile
-from typing import Any, Generator, List, Optional, Tuple, Type
+from typing import Callable, Generator, List, Optional, Tuple, Type
import numpy as np
from compiler_opt.rl import corpus
from compiler_opt.rl import log_reader
-OBS_T = Any
-OBS_KEY = 'obs'
-REWARD_KEY = 'reward'
-SCORE_POLICY_KEY = 'score_policy'
-SCORE_DEFAULT_KEY = 'score_default'
-CONTEXT_KEY = 'context'
-MODULE_NAME_KEY = 'module_name'
-OBS_ID_KEY = 'obs_id'
-STEP_TYPE_KEY = 'step_type'
+class StepType(Enum):
+ FIRST = 1
+ MID = 2
+ LAST = 3
-FIRST_STEP_STR = 'first'
-MID_STEP_STR = 'mid'
-LAST_STEP_STR = 'last'
-_TERMINAL_OBS = {
- OBS_KEY: {},
- REWARD_KEY: 0.0,
- SCORE_POLICY_KEY: 0.0,
- SCORE_DEFAULT_KEY: 0.0,
- CONTEXT_KEY: '',
- MODULE_NAME_KEY: '',
- OBS_ID_KEY: -1,
- STEP_TYPE_KEY: LAST_STEP_STR,
-}
+@dataclasses.dataclass
+class TimeStep:
+ obs: Optional[dict[str, np.ndarray]]
+ reward: Optional[dict[str, float]]
+ score_policy: Optional[dict[str, float]]
+ score_default: Optional[dict[str, float]]
+ context: Optional[str]
+ module_name: str
+ obs_id: Optional[int]
+ step_type: StepType
+
_INTERACTIVE_PIPE_FILE_BASE = 'interactive-pipe-base'
@@ -119,7 +114,8 @@
associated to the default-compiled binary.
"""
- def __init__(self, proc, get_scores_fn, module_name):
+ def __init__(self, proc: subprocess.Popen,
+ get_scores_fn: Callable[[], dict[str, float]], module_name: str):
self._proc = proc
self._get_scores_fn = get_scores_fn
self._module_name = module_name
@@ -134,8 +130,8 @@
def __init__(
self,
- proc,
- get_scores_fn,
+ proc: subprocess.Popen,
+ get_scores_fn: Callable[[], dict[str, float]],
module_name: str,
reader_pipe: io.BufferedReader,
writer_pipe: io.BufferedWriter,
@@ -147,18 +143,26 @@
self._is_first_obs = True
- self._terminal_obs = _TERMINAL_OBS
- self._terminal_obs[MODULE_NAME_KEY] = module_name
+ self._terminal_obs = TimeStep(
+ obs=None,
+ reward=None,
+ score_policy=None,
+ score_default=None,
+ context=None,
+ module_name=module_name,
+ obs_id=None,
+ step_type=StepType.LAST,
+ )
def _running(self) -> bool:
return self._proc.poll() is None
- def get_observation(self) -> OBS_T:
+ def get_observation(self) -> TimeStep:
if not self._running():
return self._terminal_obs
- def _get_step_type():
- step_type = FIRST_STEP_STR if self._is_first_obs else MID_STEP_STR
+ def _get_step_type() -> StepType:
+ step_type = StepType.FIRST if self._is_first_obs else StepType.MID
self._is_first_obs = False
return step_type
@@ -169,16 +173,16 @@
for fv in obs.feature_values:
array = fv.to_numpy()
tv_dict[fv.spec.name] = np.reshape(array, newshape=fv.spec.shape)
- return {
- OBS_KEY: tv_dict,
- REWARD_KEY: obs.score if obs.score else 0.0,
- SCORE_POLICY_KEY: 0.0,
- SCORE_DEFAULT_KEY: 0.0,
- CONTEXT_KEY: obs.context,
- MODULE_NAME_KEY: self._module_name,
- OBS_ID_KEY: obs.observation_id,
- STEP_TYPE_KEY: _get_step_type(),
- }
+ return TimeStep(
+ obs=tv_dict,
+ reward={obs.context: obs.score} if obs.score else None,
+ score_policy=None,
+ score_default=None,
+ context=obs.context,
+ module_name=self._module_name,
+ obs_id=obs.observation_id,
+ step_type=_get_step_type(),
+ )
except StopIteration:
return self._terminal_obs
@@ -338,14 +342,14 @@
def observation(self):
return self._last_obs
- def _get_observation(self) -> OBS_T:
+ def _get_observation(self) -> TimeStep:
self._last_obs = self._iclang.get_observation()
- if self._last_obs[STEP_TYPE_KEY] == 'last':
- self._last_obs[SCORE_POLICY_KEY] = self._iclang.get_scores()
- self._last_obs[SCORE_DEFAULT_KEY] = self._clang.get_scores()
- self._last_obs[REWARD_KEY] = compute_relative_rewards(
- self._last_obs[SCORE_POLICY_KEY], self._last_obs[SCORE_DEFAULT_KEY])
- return self.observation()
+ if self._last_obs.step_type == StepType.LAST:
+ self._last_obs.score_policy = self._iclang.get_scores()
+ self._last_obs.score_default = self._clang.get_scores()
+ self._last_obs.reward = compute_relative_rewards(
+ self._last_obs.score_policy, self._last_obs.score_default)
+ return self._last_obs
def reset(self, module: corpus.LoadedModuleSpec):
# On the first call to reset(...), sending None starts the coroutine.
diff --git a/compiler_opt/rl/env_test.py b/compiler_opt/rl/env_test.py
index 1af508c..87577b3 100644
--- a/compiler_opt/rl/env_test.py
+++ b/compiler_opt/rl/env_test.py
@@ -155,10 +155,10 @@
for idx in range(_NUM_STEPS):
obs = clang_session.get_observation()
self.assertEqual(
- obs[env.OBS_KEY]['times_called'],
+ obs.obs['times_called'],
np.array([idx], dtype=np.int64),
)
- self.assertEqual(obs[env.CONTEXT_KEY], f'context_{idx}')
+ self.assertEqual(obs.context, f'context_{idx}')
mock_popen.assert_called_once()
@@ -178,15 +178,15 @@
for env_itr in range(3):
del env_itr
step = test_env.reset(_MOCK_MODULE)
- self.assertEqual(step[env.STEP_TYPE_KEY], env.FIRST_STEP_STR)
+ self.assertEqual(step.step_type, env.StepType.FIRST)
for step_itr in range(_NUM_STEPS - 1):
del step_itr
step = test_env.step(np.array([1], dtype=np.int64))
- self.assertEqual(step[env.STEP_TYPE_KEY], env.MID_STEP_STR)
+ self.assertEqual(step.step_type, env.StepType.MID)
step = test_env.step(np.array([1], dtype=np.int64))
- self.assertEqual(step[env.STEP_TYPE_KEY], env.LAST_STEP_STR)
+ self.assertEqual(step.step_type, env.StepType.LAST)
if __name__ == '__main__':
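
Note on semantics, beyond the rename: reward is no longer a scalar that
defaults to 0.0. TimeStep.reward is an Optional[dict[str, float]] keyed
by context ({obs.context: obs.score} mid-episode, the result of
compute_relative_rewards(...) on the last step) and None when no score
is attached. A sketch of the adaptation that downstream readers of
obs[env.REWARD_KEY] need; step is a TimeStep as in the example above:

    # Old: reward was a bare float, 0.0 when no score was attached.
    # New: per-context dict, None when absent; sum it for a scalar view.
    total_reward = sum((step.reward or {}).values())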