Save models during ES training in a separate thread

This patch makes blackbox_learner save models during ES training in a
separate thread. There is a bit of additional complexity so that we save
the models concurrently while waiting for compilation results (when the
process would otherwise be blocked waiting on IO). If we just spawned
save jobs at arbitrary points, we would run into performance problems
due to the GIL and be back where we started. We cannot use subprocesses
because we cannot pickle the policy saver function.

The extra complexity is justified by significant performance gains on
training jobs with relatively short step times (up to 20%, and still
single-digit percentages with step times around five minutes).
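
The change boils down to a deferred-save pattern: model saves are queued,
started on a single worker thread right before the blocking evaluator call,
and flushed once the compilation results are back. Below is a minimal,
standalone sketch of that pattern (not the blackbox_learner code itself;
save_policy and wait_for_compilation_results are hypothetical stand-ins for
the real policy saver and the blocking wait on compilation results):

    # Minimal sketch of the deferred-save pattern; save_policy and
    # wait_for_compilation_results are hypothetical stand-ins.
    from multiprocessing.pool import ThreadPool
    import time

    def save_policy(parameters, policy_name):
      # Stand-in for the (comparatively slow) policy export.
      time.sleep(1)
      print(f'saved {policy_name}')

    def wait_for_compilation_results():
      # Stand-in for blocking on worker IO.
      time.sleep(2)
      return [1.0, 1.5]

    pool = ThreadPool(processes=1)
    models_to_save = [([0.0] * 4, 'iteration0')]  # queued by _save_model

    # Start the saves right before blocking on worker IO, so the save thread
    # runs while the main thread is parked and not competing for the GIL.
    in_flight = [pool.apply_async(save_policy, args) for args in models_to_save]
    rewards = wait_for_compilation_results()

    # Flush: wait for each save and surface any exception it raised.
    for result in in_flight:
      result.wait()
      if not result.successful():
        result.get()
    pool.close()

Confining the saves to the window where the main thread is blocked on worker
IO is what avoids the GIL contention described above.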

Reviewers: svkeerthy, mtrofin

Reviewed By: mtrofin

Pull Request: https://github.com/google/ml-compiler-opt/pull/524
diff --git a/compiler_opt/es/blackbox_learner.py b/compiler_opt/es/blackbox_learner.py
index 656e371..ecc9b5f 100644
--- a/compiler_opt/es/blackbox_learner.py
+++ b/compiler_opt/es/blackbox_learner.py
@@ -17,6 +17,7 @@
 import dataclasses
 import gin
 import math
+import multiprocessing.pool
 import numpy as np
 import numpy.typing as npt
 import tensorflow as tf
@@ -162,6 +163,15 @@
     self._evaluator = self._config.evaluator(self._train_corpus,
                                              self._config.estimator_type)
 
+    self._thread_pool = multiprocessing.pool.ThreadPool(processes=1)
+    self._models_to_save = []
+    self._models_to_flush = []
+
+  def __del__(self):
+    self._start_model_saving()
+    self._flush_models()
+    self._thread_pool.close()
+
   def _get_perturbations(self) -> list[npt.NDArray[np.float32]]:
     """Get perturbations for the model weights."""
     rng = np.random.default_rng(seed=self._seed)
@@ -225,11 +235,29 @@
           len(train_wins) / len(rewards),
           step=self._step)
 
-  def _save_model(self) -> None:
+  def _save_model(self, parameters: npt.NDArray[np.float32],
+                  policy_name: str) -> None:
     """Save the model."""
     logging.info('Saving the model.')
-    self._policy_saver_fn(
-        parameters=self._model_weights, policy_name=f'iteration{self._step}')
+    self._models_to_save.append((parameters, policy_name))
+
+  def _start_model_saving(self):
+    for model_parameters, model_name in self._models_to_save:
+      self._models_to_flush.append(
+          self._thread_pool.apply_async(self._policy_saver_fn,
+                                        (model_parameters, model_name)))
+    self._models_to_save.clear()
+
+  def _flush_models(self):
+    for model_to_flush in self._models_to_flush:
+      model_to_flush.wait()
+      if not model_to_flush.successful():
+        model_to_flush.get()
+    self._models_to_flush.clear()
+
+  def flush_models(self):
+    self._start_model_saving()
+    self._flush_models()
 
   def get_model_weights(self) -> npt.NDArray[np.float32]:
     return self._model_weights
@@ -257,8 +285,10 @@
         for perturbation in initial_perturbations
     ]
 
+    self._start_model_saving()
     results = self._evaluator.get_results(pool, perturbations_as_bytes)
     rewards = self._evaluator.get_rewards(results)
+    self._flush_models()
 
     num_pruned = _prune_skipped_perturbations(initial_perturbations, rewards)
     logging.info('Pruned [%d]', num_pruned)
@@ -281,12 +311,13 @@
                    '%d, saving.', self._global_max_reward, self._step)
       max_index = np.argmax(rewards)
       perturbation = initial_perturbations[max_index]
-      self._policy_saver_fn(
+      self._save_model(
           parameters=self._model_weights + perturbation,
           policy_name=f'best_policy_{self._global_max_reward}_step'
           f'_{self._step}',
       )
 
-    self._save_model()
+    self._save_model(
+        parameters=self._model_weights, policy_name=f'iteration{self._step}')
 
     self._step += 1
diff --git a/compiler_opt/es/blackbox_learner_test.py b/compiler_opt/es/blackbox_learner_test.py
index 232d62f..03d67aa 100644
--- a/compiler_opt/es/blackbox_learner_test.py
+++ b/compiler_opt/es/blackbox_learner_test.py
@@ -121,7 +121,7 @@
     # The directory should be unique per test and thus should not exist
     # before we create it. Raise an error otherwise.
     if os.path.exists(self._iteration_policies_path):
-      raise ValueError("Test directory already exists.")
+      raise ValueError('Test directory already exists.')
     os.mkdir(self._iteration_policies_path)
 
     def _policy_saver_fn(parameters: npt.NDArray[np.float32],
@@ -178,6 +178,12 @@
       for value in self._learner.get_model_weights()[:5]:
         self.assertNotAlmostEqual(value, 0.0)
 
+      # Normally the models would be saved asynchronously while
+      # blackbox_learner waits for compilation results. Flush them explicitly
+      # here so we can see the model.
+      self._learner.flush_models()
+      self.assertIn('iteration0', os.listdir(self._iteration_policies_path))
+
   def test_save_best_model(self):
     with local_worker_manager.LocalWorkerPoolManager(
         blackbox_test_utils.ESWorker,
@@ -190,9 +196,12 @@
         }) as pool:
       self._learner.set_baseline(pool)
       self._learner.run_step(pool)
+      self._learner.run_step(pool)
+      # Check the policy from step zero since it will be flushed in step one.
       self.assertIn('best_policy_1.01_step_0',
                     os.listdir(self._iteration_policies_path))
-      self._learner.run_step(pool)
+      # Manually flush the model since we are not going to run another step.
+      self._learner.flush_models()
       self.assertIn('best_policy_1.07_step_1',
                     os.listdir(self._iteration_policies_path))
 
@@ -208,12 +217,15 @@
         }) as pool:
       self._learner.set_baseline(pool)
       self._learner.run_step(pool)
+
+      self._learner.run_step(pool)
+      # Check the policy from step zero since it will be flushed in step one.
       self.assertIn('best_policy_0.94_step_0',
                     os.listdir(self._iteration_policies_path))
-
       # Check that the within the next step we only get a new iteration
       # policy and do not save any new best.
       current_policies_count = len(os.listdir(self._iteration_policies_path))
-      self._learner.run_step(pool)
+      # Flush the policies since we are not going to run another step.
+      self._learner.flush_models()
       self.assertLen(
           os.listdir(self._iteration_policies_path), current_policies_count + 1)
diff --git a/compiler_opt/es/es_trainer_lib.py b/compiler_opt/es/es_trainer_lib.py
index 714665b..db8ee58 100644
--- a/compiler_opt/es/es_trainer_lib.py
+++ b/compiler_opt/es/es_trainer_lib.py
@@ -224,4 +224,6 @@
     for _ in range(learner_config.total_steps):
       learner.run_step(pool)
 
+  learner.flush_models()
+
   return learner.get_model_weights()