Use the baseline cache for the sampling evaluator (#550)

diff --git a/compiler_opt/es/blackbox_evaluator.py b/compiler_opt/es/blackbox_evaluator.py
index 8e0f783..d84e8ae 100644
--- a/compiler_opt/es/blackbox_evaluator.py
+++ b/compiler_opt/es/blackbox_evaluator.py
@@ -28,6 +28,7 @@
 from compiler_opt.es import blackbox_optimizers
 from compiler_opt.distributed import buffered_scheduler
 from compiler_opt.rl import compilation_runner
+from compiler_opt import baseline_cache
 
 
 def _extract_results(futures: list[concurrent.futures.Future]) -> list[Any]:
@@ -51,6 +52,8 @@
                estimator_type: blackbox_optimizers.EstimatorType):
     self._train_corpus = train_corpus
     self._estimator_type = estimator_type
+    self._baseline_cache = baseline_cache.BaselineCache(
+        get_key=lambda x: x.name)
 
   @abc.abstractmethod
   def get_results(
@@ -73,12 +76,17 @@
                num_ir_repeats_within_worker: int = 1,
                **kwargs):
     super().__init__(**kwargs)
-    self._samples: list[list[corpus.LoadedModuleSpec]] = []
     self._total_num_perturbations = total_num_perturbations
     self._num_ir_repeats_within_worker = num_ir_repeats_within_worker
-    self._baselines: list[float | None] | None = None
+    self._reset()
 
-  def _load_samples(self) -> None:
+  def _reset(self):
+    # TODO: this object is stateful (samples/baselines are populated, used, then
+    # cleared here); removing that state machine would ease maintenance.
+    self._samples = None
+    self._baselines = None
+
+  def load_samples(self) -> None:
     """Samples and loads modules if not already done.
 
     Ensures self._samples contains the expected number of loaded samples.
@@ -89,6 +97,7 @@
     """
     if self._samples:
       raise RuntimeError('Samples have already been loaded.')
+    self._samples = []
     for _ in range(self._total_num_perturbations):
       samples = self._train_corpus.sample(self._num_ir_repeats_within_worker)
       loaded_samples = [
@@ -108,15 +117,15 @@
     if len(self._samples) != expected_count:
       raise RuntimeError('Some samples could not be loaded correctly.')
 
-  def _launch_compilation_workers(self,
-                                  pool: FixedWorkerPool,
-                                  perturbations: list[bytes] | None = None
-                                 ) -> list[concurrent.futures.Future]:
-    if self._samples is None:
-      raise RuntimeError('Loaded samples are not available.')
+  def _launch_compilation_workers(
+      self,
+      pool: FixedWorkerPool,
+      samples: list[list[corpus.LoadedModuleSpec]],
+      perturbations: list[bytes] | None = None
+  ) -> list[concurrent.futures.Future]:
     if perturbations is None:
-      perturbations = [None] * len(self._samples)
-    compile_args = zip(perturbations, self._samples)
+      perturbations = [None] * len(samples)
+    compile_args = zip(perturbations, samples)
     _, futures = buffered_scheduler.schedule_on_worker_pool(
         action=lambda w, args: w.compile(policy=args[0], modules=args[1]),
         jobs=compile_args,
@@ -130,24 +139,43 @@
           not_done, return_when=concurrent.futures.FIRST_COMPLETED)
     return futures
 
+  def ensure_baselines(self, pool):
+    if self._samples is None:
+      raise RuntimeError('Loaded samples are not available.')
+    # Flatten so that every module is scored (and cached) individually.
+    flat_samples = [item for sublist in self._samples for item in sublist]
+
+    def _get_scores(some_list):
+      futures = self._launch_compilation_workers(pool, [[x] for x in some_list])
+      return _extract_results(futures)
+
+    baselines = self._baseline_cache.get_score(flat_samples, _get_scores)
+
+    # TODO: the business of accumulating compilation results is now shared
+    # with the worker.
+    def sum_or_none(lst):
+      return sum(lst) if all(x is not None for x in lst) else None
+    # Regroup flat scores in order; a group may hold more than one module.
+    scores = iter(baselines)
+    self._baselines = [
+        sum_or_none([next(scores) for _ in group]) for group in self._samples
+    ]
+
   def get_results(
       self, pool: FixedWorkerPool,
       perturbations: list[bytes]) -> list[concurrent.futures.Future]:
-    # We should have _samples by now.
     if not self._samples:
-      raise RuntimeError('Loaded samples are not available.')
-    return self._launch_compilation_workers(pool, perturbations)
+      self.load_samples()
+    self.ensure_baselines(pool)
+    return self._launch_compilation_workers(pool, self._samples, perturbations)
 
   def set_baseline(self, pool: FixedWorkerPool) -> None:
-    if self._baselines is not None:
-      raise RuntimeError('The baseline has already been set.')
-    self._load_samples()
-    results_futures = self._launch_compilation_workers(pool)
-    self._baselines = _extract_results(results_futures)
+    pass
 
   def get_rewards(
       self,
       results_futures: list[concurrent.futures.Future]) -> list[float | None]:
+    # we need a pool to get the baselines, so we should have gotten them already
     if self._baselines is None:
       raise RuntimeError('The baseline has not been set.')
 
@@ -165,6 +193,7 @@
       else:
         rewards.append(
             compilation_runner.calculate_reward(policy_result, baseline))
+    self._reset()
     return rewards
 
 
diff --git a/compiler_opt/es/blackbox_evaluator_test.py b/compiler_opt/es/blackbox_evaluator_test.py
index 1d68d2c..0c71bd0 100644
--- a/compiler_opt/es/blackbox_evaluator_test.py
+++ b/compiler_opt/es/blackbox_evaluator_test.py
@@ -61,8 +61,8 @@
           train_corpus=test_corpus,
           estimator_type=blackbox_optimizers.EstimatorType.FORWARD_FD,
           total_num_perturbations=1)
-
-      evaluator.set_baseline(pool)
+      evaluator.load_samples()
+      evaluator.ensure_baselines(pool)
       # pylint: disable=protected-access
       self.assertAlmostEqual(evaluator._baselines, [10])
 
@@ -90,7 +90,8 @@
           estimator_type=blackbox_optimizers.EstimatorType.FORWARD_FD,
           total_num_perturbations=2)
 
-      evaluator.set_baseline(pool)
+      evaluator.load_samples()
+      evaluator.ensure_baselines(pool)
 
       f_policy1 = concurrent.futures.Future()
       f_policy1.set_result(1.5)