Restore best model test functionality

Some of the tests around the "save best model" functionality were
changed in 6cf15b30ea80282a7a74e7c07092af2887cc1a86 because that commit
touched SamplingBlackboxEvaluator, which is used in the tests, in a way
that made them stop working. This patch restores the functionality by
adjusting how the tests work so that they test what they are intended
to.

Reviewers: svkeerthy, mtrofin

Reviewed By: svkeerthy

Pull Request: https://github.com/google/ml-compiler-opt/pull/521
diff --git a/compiler_opt/es/blackbox_learner_test.py b/compiler_opt/es/blackbox_learner_test.py
index 99ad925..d08dc99 100644
--- a/compiler_opt/es/blackbox_learner_test.py
+++ b/compiler_opt/es/blackbox_learner_test.py
@@ -176,13 +176,15 @@
         count=1,
         pickle_func=cloudpickle.dumps,
         worker_args=(),
-        worker_kwargs={}) as pool:
+        worker_kwargs={
+            'delta': -1.0,
+            'initial_value': 5
+        }) as pool:
       self._learner.set_baseline(pool)
       self._learner.run_step(pool)
-      self.assertEqual(len(self._saved_policies), 1)
-      self.assertIn('iteration0', self._saved_policies)
+      self.assertIn('best_policy_1.01_step_0', self._saved_policies)
       self._learner.run_step(pool)
-      self.assertIn('iteration1', self._saved_policies)
+      self.assertIn('best_policy_1.07_step_1', self._saved_policies)
 
   def test_save_best_model_only_saves_best(self):
     with local_worker_manager.LocalWorkerPoolManager(
@@ -191,9 +193,15 @@
         pickle_func=cloudpickle.dumps,
         worker_args=(),
         worker_kwargs={
-            'delta': -1.0,
+            'delta': 1.0,
             'initial_value': 5
         }) as pool:
       self._learner.set_baseline(pool)
       self._learner.run_step(pool)
-      self.assertIn('best_policy_100.0_step_0', self._saved_policies)
+      self.assertIn('best_policy_0.94_step_0', self._saved_policies)
+
+      # Check that within the next step we only get a new iteration
+      # policy and do not save any new best.
+      current_policies_count = len(self._saved_policies)
+      self._learner.run_step(pool)
+      self.assertLen(self._saved_policies, current_policies_count + 1)
diff --git a/compiler_opt/es/blackbox_test_utils.py b/compiler_opt/es/blackbox_test_utils.py
index 714db24..99e9930 100644
--- a/compiler_opt/es/blackbox_test_utils.py
+++ b/compiler_opt/es/blackbox_test_utils.py
@@ -20,6 +20,7 @@
 from compiler_opt.distributed import worker
 from compiler_opt.rl import corpus
 from compiler_opt.rl import policy_saver
+from compiler_opt.rl import constant
 
 
 @gin.configurable
@@ -34,11 +35,14 @@
 
   def compile(self, policy: policy_saver.Policy,
               modules: list[corpus.LoadedModuleSpec]) -> float:
+    # We return the values with constant.DELTA subtracted so that we get
+    # exact values we can assert against when writing tests that only see
+    # the relative reward.
     if policy and modules:
       self.function_value += self._delta
-      return self.function_value
+      return self.function_value - constant.DELTA
     else:
-      return 0.0
+      return 100 - constant.DELTA
 
 
 class SizeReturningESWorker(worker.Worker):