[env] Strong-typing for the observation (#299)

diff --git a/compiler_opt/rl/env.py b/compiler_opt/rl/env.py
index 15246bf..1bd96af 100644
--- a/compiler_opt/rl/env.py
+++ b/compiler_opt/rl/env.py
@@ -15,6 +15,8 @@
 """Gymlike environment definition for MLGO."""
 
 from __future__ import annotations
+import dataclasses
+from enum import Enum
 
 import math
 import subprocess
@@ -23,38 +25,31 @@
 import io
 import os
 import tempfile
-from typing import Any, Generator, List, Optional, Tuple, Type
+from typing import Callable, Generator, List, Optional, Tuple, Type
 
 import numpy as np
 
 from compiler_opt.rl import corpus
 from compiler_opt.rl import log_reader
 
-OBS_T = Any
 
-OBS_KEY = 'obs'
-REWARD_KEY = 'reward'
-SCORE_POLICY_KEY = 'score_policy'
-SCORE_DEFAULT_KEY = 'score_default'
-CONTEXT_KEY = 'context'
-MODULE_NAME_KEY = 'module_name'
-OBS_ID_KEY = 'obs_id'
-STEP_TYPE_KEY = 'step_type'
+class StepType(Enum):
+  FIRST = 1
+  MID = 2
+  LAST = 3
 
-FIRST_STEP_STR = 'first'
-MID_STEP_STR = 'mid'
-LAST_STEP_STR = 'last'
 
-_TERMINAL_OBS = {
-    OBS_KEY: {},
-    REWARD_KEY: 0.0,
-    SCORE_POLICY_KEY: 0.0,
-    SCORE_DEFAULT_KEY: 0.0,
-    CONTEXT_KEY: '',
-    MODULE_NAME_KEY: '',
-    OBS_ID_KEY: -1,
-    STEP_TYPE_KEY: LAST_STEP_STR,
-}
+@dataclasses.dataclass
+class TimeStep:
+  obs: Optional[dict[str, np.NDArray]]
+  reward: Optional[dict[str, float]]
+  score_policy: Optional[dict[str, float]]
+  score_default: Optional[dict[str, float]]
+  context: Optional[str]
+  module_name: str
+  obs_id: Optional[int]
+  step_type: StepType
+
 
 _INTERACTIVE_PIPE_FILE_BASE = 'interactive-pipe-base'
 
@@ -119,7 +114,8 @@
   associated to the default-compiled binary.
   """
 
-  def __init__(self, proc, get_scores_fn, module_name):
+  def __init__(self, proc: subprocess.Popen,
+               get_scores_fn: Callable[[], dict[str, float]], module_name):
     self._proc = proc
     self._get_scores_fn = get_scores_fn
     self._module_name = module_name
@@ -134,8 +130,8 @@
 
   def __init__(
       self,
-      proc,
-      get_scores_fn,
+      proc: subprocess.Popen,
+      get_scores_fn: Callable[[], dict[str, float]],
       module_name: str,
       reader_pipe: io.BufferedReader,
       writer_pipe: io.BufferedWriter,
@@ -147,18 +143,26 @@
 
     self._is_first_obs = True
 
-    self._terminal_obs = _TERMINAL_OBS
-    self._terminal_obs[MODULE_NAME_KEY] = module_name
+    self._terminal_obs = TimeStep(
+        obs=None,
+        reward=None,
+        score_policy=None,
+        score_default=None,
+        context=None,
+        module_name=module_name,
+        obs_id=None,
+        step_type=StepType.LAST,
+    )
 
   def _running(self) -> bool:
     return self._proc.poll() is None
 
-  def get_observation(self) -> OBS_T:
+  def get_observation(self) -> TimeStep:
     if not self._running():
       return self._terminal_obs
 
-    def _get_step_type():
-      step_type = FIRST_STEP_STR if self._is_first_obs else MID_STEP_STR
+    def _get_step_type() -> StepType:
+      step_type = StepType.FIRST if self._is_first_obs else StepType.MID
       self._is_first_obs = False
       return step_type
 
@@ -169,16 +173,16 @@
       for fv in obs.feature_values:
         array = fv.to_numpy()
         tv_dict[fv.spec.name] = np.reshape(array, newshape=fv.spec.shape)
-      return {
-          OBS_KEY: tv_dict,
-          REWARD_KEY: obs.score if obs.score else 0.0,
-          SCORE_POLICY_KEY: 0.0,
-          SCORE_DEFAULT_KEY: 0.0,
-          CONTEXT_KEY: obs.context,
-          MODULE_NAME_KEY: self._module_name,
-          OBS_ID_KEY: obs.observation_id,
-          STEP_TYPE_KEY: _get_step_type(),
-      }
+      return TimeStep(
+          obs=tv_dict,
+          reward={obs.context: obs.score} if obs.score else None,
+          score_policy=None,
+          score_default=None,
+          context=obs.context,
+          module_name=self._module_name,
+          obs_id=obs.observation_id,
+          step_type=_get_step_type(),
+      )
     except StopIteration:
       return self._terminal_obs
 
@@ -338,14 +342,14 @@
   def observation(self):
     return self._last_obs
 
-  def _get_observation(self) -> OBS_T:
+  def _get_observation(self) -> TimeStep:
     self._last_obs = self._iclang.get_observation()
-    if self._last_obs[STEP_TYPE_KEY] == 'last':
-      self._last_obs[SCORE_POLICY_KEY] = self._iclang.get_scores()
-      self._last_obs[SCORE_DEFAULT_KEY] = self._clang.get_scores()
-      self._last_obs[REWARD_KEY] = compute_relative_rewards(
-          self._last_obs[SCORE_POLICY_KEY], self._last_obs[SCORE_DEFAULT_KEY])
-    return self.observation()
+    if self._last_obs.step_type == StepType.LAST:
+      self._last_obs.score_policy = self._iclang.get_scores()
+      self._last_obs.score_default = self._clang.get_scores()
+      self._last_obs.reward = compute_relative_rewards(
+          self._last_obs.score_policy, self._last_obs.score_default)
+    return self._last_obs
 
   def reset(self, module: corpus.LoadedModuleSpec):
     # On the first call to reset(...), sending None starts the coroutine.
diff --git a/compiler_opt/rl/env_test.py b/compiler_opt/rl/env_test.py
index 1af508c..87577b3 100644
--- a/compiler_opt/rl/env_test.py
+++ b/compiler_opt/rl/env_test.py
@@ -155,10 +155,10 @@
       for idx in range(_NUM_STEPS):
         obs = clang_session.get_observation()
         self.assertEqual(
-            obs[env.OBS_KEY]['times_called'],
+            obs.obs['times_called'],
             np.array([idx], dtype=np.int64),
         )
-        self.assertEqual(obs[env.CONTEXT_KEY], f'context_{idx}')
+        self.assertEqual(obs.context, f'context_{idx}')
       mock_popen.assert_called_once()
 
 
@@ -178,15 +178,15 @@
     for env_itr in range(3):
       del env_itr
       step = test_env.reset(_MOCK_MODULE)
-      self.assertEqual(step[env.STEP_TYPE_KEY], env.FIRST_STEP_STR)
+      self.assertEqual(step.step_type, env.StepType.FIRST)
 
       for step_itr in range(_NUM_STEPS - 1):
         del step_itr
         step = test_env.step(np.array([1], dtype=np.int64))
-        self.assertEqual(step[env.STEP_TYPE_KEY], env.MID_STEP_STR)
+        self.assertEqual(step.step_type, env.StepType.MID)
 
       step = test_env.step(np.array([1], dtype=np.int64))
-      self.assertEqual(step[env.STEP_TYPE_KEY], env.LAST_STEP_STR)
+      self.assertEqual(step.step_type, env.StepType.LAST)
 
 
 if __name__ == '__main__':