add support for dumping best trajectories during training (#153)

* add the option to dump best trajectories during training

* track->dump

* fix test

* typo
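
For reference: compiler_opt/rl/best_trajectory.py is used by this change but is not part of the diff. Below is a minimal sketch of the interface the new code relies on, inferred from the call sites in this change; the field names, the internal dict layout, and the lower-reward-is-better comparison are assumptions, not something this diff confirms.

    import dataclasses
    from typing import Dict


    @dataclasses.dataclass
    class BestTrajectory:
      reward: float
      serialized_sequence_example: bytes  # serialized tf.train.SequenceExample


    class BestTrajectoryRepo:
      """Keeps, per module and identifier, the best trajectory seen so far."""

      def __init__(self, action_name: str):
        self._action_name = action_name
        # module name -> identifier (e.g. function name) -> best trajectory.
        self._best_trajectories: Dict[str, Dict[str, BestTrajectory]] = {}

      def update_if_better_trajectory(self, module_name: str, identifier: str,
                                      reward: float, sequence_example: bytes):
        # Assumption: a lower policy reward (e.g. native code size) is better;
        # the real implementation defines the actual comparison.
        per_module = self._best_trajectories.setdefault(module_name, {})
        current = per_module.get(identifier)
        if current is None or reward < current.reward:
          per_module[identifier] = BestTrajectory(reward, sequence_example)

      def sink_to_json_file(self, path: str):
        # Persists the repo; how serialized SequenceExamples are encoded into
        # JSON is up to the real implementation.
        raise NotImplementedError()

      def load_from_json_file(self, path: str):
        raise NotImplementedError()
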
diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py
index d85da23..5e807d2 100644
--- a/compiler_opt/rl/compilation_runner.py
+++ b/compiler_opt/rl/compilation_runner.py
@@ -25,12 +25,13 @@
 
 from absl import flags
 from absl import logging
+import tensorflow as tf
+
 from compiler_opt.distributed.worker import Worker
 from compiler_opt.distributed.worker import WorkerFuture
 from compiler_opt.rl import constant
-from compiler_opt.rl import policy_saver
 from compiler_opt.rl import corpus
-import tensorflow as tf
+from compiler_opt.rl import policy_saver
 
 _COMPILATION_TIMEOUT = flags.DEFINE_integer(
     'compilation_timeout', 60,
@@ -219,11 +220,12 @@
   length: total length of all sequence examples, derived from sequence_examples.
   reward_stats: a dictionary from keys (e.g. function names) to a RewardStat.
   rewards: a list of reward values.
+  policy_rewards: a list of reward values under the policy.
   keys: a list of keys.
 
   The object must observe the following invariants:
-  1) The entries of sequence_examples, rewards, and keys correspond to eachoter
-  at the same index.
+  1) The entries of sequence_examples, rewards, policy_rewards, and keys
+  correspond to each other at the same index.
 
   2) The keys in reward stats are those in the keys field.
   """
@@ -232,6 +234,7 @@
   length: int = dataclasses.field(init=False)
   reward_stats: Dict[str, RewardStat]
   rewards: List[float]
+  policy_rewards: List[float]
   keys: List[str]
 
   def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
@@ -243,8 +246,8 @@
     ]
     object.__setattr__(self, 'length', sum(lengths))
 
-    assert (len(self.serialized_sequence_examples) == len(self.rewards) ==
-            (len(self.keys)))
+    assert (len(self.serialized_sequence_examples) == len(self.rewards) == len(
+        self.policy_rewards) == len(self.keys))
     assert set(self.keys) == set(self.reward_stats.keys())
     assert not hasattr(self, 'sequence_examples')
 
@@ -356,6 +359,7 @@
 
     sequence_example_list = []
     rewards = []
+    policy_rewards = []
     keys = []
     for k, v in policy_result.items():
       sequence_example = v[0]
@@ -376,12 +380,14 @@
           policy_reward * (1 - self._moving_average_decay_rate))
       rewards.append(
           _calculate_reward(policy=policy_reward, baseline=default_reward))
+      policy_rewards.append(policy_reward)
       keys.append(k)
 
     return CompilationResult(
         sequence_examples=sequence_example_list,
         reward_stats=reward_stat,
         rewards=rewards,
+        policy_rewards=policy_rewards,
         keys=keys)
 
   def compile_fn(
diff --git a/compiler_opt/rl/compilation_runner_test.py b/compiler_opt/rl/compilation_runner_test.py
index 56412fc..7bc28aa 100644
--- a/compiler_opt/rl/compilation_runner_test.py
+++ b/compiler_opt/rl/compilation_runner_test.py
@@ -134,6 +134,7 @@
                     (1 - _MOVING_AVERAGE_DECAY_RATE))
         }, data.reward_stats)
     self.assertAllClose([0.1998002], data.rewards)
+    self.assertAllClose([8], data.policy_rewards)
 
   @mock.patch(constant.BASE_MODULE_DIR +
               '.compilation_runner.CompilationRunner.compile_fn')
@@ -162,6 +163,7 @@
                     moving_average_reward=_DEFAULT_REWARD)
         }, data.reward_stats)
     self.assertAllClose([0], data.rewards)
+    self.assertAllClose([10], data.policy_rewards)
 
   @mock.patch(constant.BASE_MODULE_DIR +
               '.compilation_runner.CompilationRunner.compile_fn')
@@ -197,6 +199,7 @@
                     _POLICY_REWARD * (1 - _MOVING_AVERAGE_DECAY_RATE))
         }, data.reward_stats)
     self.assertAllClose([0.199800], data.rewards)
+    self.assertAllClose([8], data.policy_rewards)
 
   @mock.patch(constant.BASE_MODULE_DIR +
               '.compilation_runner.CompilationRunner.compile_fn')
diff --git a/compiler_opt/rl/local_data_collector.py b/compiler_opt/rl/local_data_collector.py
index eabdc72..59729d6 100644
--- a/compiler_opt/rl/local_data_collector.py
+++ b/compiler_opt/rl/local_data_collector.py
@@ -17,13 +17,14 @@
 import concurrent.futures
 import itertools
 import time
-from typing import Callable, Dict, Iterator, List, Tuple, Optional
+from typing import Callable, Dict, Iterator, List, Optional, Tuple
 
 from absl import logging
 from tf_agents.trajectories import trajectory
 
 from compiler_opt.distributed import worker
 from compiler_opt.distributed import buffered_scheduler
+from compiler_opt.rl import best_trajectory
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import corpus
 from compiler_opt.rl import data_collector
@@ -41,6 +42,7 @@
       parser: Callable[[List[str]], Iterator[trajectory.Trajectory]],
       reward_stat_map: Dict[str, Optional[Dict[str,
                                                compilation_runner.RewardStat]]],
+      best_trajectory_repo: Optional[best_trajectory.BestTrajectoryRepo],
       exit_checker_ctor=data_collector.EarlyExitChecker):
     # TODO(mtrofin): type exit_checker_ctor when we get typing.Protocol support
     super().__init__()
@@ -53,6 +55,7 @@
         compilation_runner
         .CompilationRunnerStub] = self._worker_pool.get_currently_active()
     self._reward_stat_map = reward_stat_map
+    self._best_trajectory_repo = best_trajectory_repo
     self._exit_checker_ctor = exit_checker_ctor
     # _reset_workers is a future that resolves when post-data collection cleanup
     # work completes, i.e. cancelling all work and re-enabling the workers.
@@ -126,7 +129,7 @@
     """Collect data for a given policy.
 
     Args:
-      policy_path: the path to the policy directory to collect data with.
+      policy: a policy_saver.Policy object to collect data with.
 
     Returns:
       An iterator of batched trajectory.Trajectory that are ready to be fed to
@@ -183,6 +186,15 @@
     self._reward_stat_map.update(
         {spec.name: res.reward_stats for (spec, res) in successful_work})
 
+    if self._best_trajectory_repo is not None:
+      for spec, res in successful_work:
+        module_name = spec.name
+        for (identifier, reward,
+             sequence_example) in zip(res.keys, res.policy_rewards,
+                                      res.serialized_sequence_examples):
+          self._best_trajectory_repo.update_if_better_trajectory(
+              module_name, identifier, reward, sequence_example)
+
     monitor_dict = {}
     monitor_dict['default'] = {
         'success_modules': len(successful_work),
diff --git a/compiler_opt/rl/local_data_collector_test.py b/compiler_opt/rl/local_data_collector_test.py
index b48fb1f..c9fceb5 100644
--- a/compiler_opt/rl/local_data_collector_test.py
+++ b/compiler_opt/rl/local_data_collector_test.py
@@ -16,13 +16,15 @@
 
 # pylint: disable=protected-access
 import collections
-
 import string
 import sys
+from typing import List, Tuple
 
 import tensorflow as tf
 from tf_agents.system import system_multiprocessing as multiprocessing
 
+# This is https://github.com/google/pytype/issues/764
+from google.protobuf import text_format  # pytype: disable=pyi-error
 from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import corpus
@@ -30,10 +32,6 @@
 from compiler_opt.rl import local_data_collector
 from compiler_opt.rl import policy_saver
 
-# This is https://github.com/google/pytype/issues/764
-from google.protobuf import text_format  # pytype: disable=pyi-error
-from typing import List, Tuple
-
 _policy_str = 'policy'.encode(encoding='utf-8')
 
 _mock_policy = policy_saver.Policy(output_spec=bytes(), policy=_policy_str)
@@ -71,6 +69,7 @@
                     default_reward=1, moving_average_reward=2)
         },
         rewards=[1.2],
+        policy_rewards=[36],
         keys=['default'])
   else:
     return compilation_runner.CompilationResult(
@@ -81,6 +80,7 @@
                     default_reward=1, moving_average_reward=3)
         },
         rewards=[3.4],
+        policy_rewards=[18],
         keys=['default'])
 
 
@@ -93,7 +93,11 @@
                                                  self._cancellation_manager)
 
     return compilation_runner.CompilationResult(
-        sequence_examples=[], reward_stats={}, rewards=[], keys=[])
+        sequence_examples=[],
+        reward_stats={},
+        rewards=[],
+        policy_rewards=[],
+        keys=[])
 
 
 class MyRunner(compilation_runner.CompilationRunner):
@@ -154,7 +158,8 @@
           num_modules=9,
           worker_pool=lwp,
           parser=create_test_iterator_fn(),
-          reward_stat_map=collections.defaultdict(lambda: None))
+          reward_stat_map=collections.defaultdict(lambda: None),
+          best_trajectory_repo=None)
 
       # reset the sampler, so the next time we collect, we collect the same
       # modules. We do it before the collect_data call, because that's when
@@ -226,6 +231,7 @@
           worker_pool=lwp,
           parser=parser,
           reward_stat_map=collections.defaultdict(lambda: None),
+          best_trajectory_repo=None,
           exit_checker_ctor=QuickExiter)
       collector.collect_data(policy=_mock_policy)
       collector._join_pending_jobs()
diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py
index ddf6dd9..26526e5 100644
--- a/compiler_opt/rl/train_locally.py
+++ b/compiler_opt/rl/train_locally.py
@@ -19,6 +19,7 @@
 import json
 import os
 import time
+from typing import List
 
 from absl import app
 from absl import flags
@@ -27,10 +28,10 @@
 import tensorflow as tf
 from tf_agents.agents import tf_agent
 from tf_agents.system import system_multiprocessing as multiprocessing
-from typing import List
 
 from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
 from compiler_opt.rl import agent_creators
+from compiler_opt.rl import best_trajectory
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import constant
 from compiler_opt.rl import corpus
@@ -69,6 +70,7 @@
                train_sequence_length=1,
                deploy_policy_name='saved_policy',
                use_random_network_distillation=False,
+               dump_best_trajectory=False,
                moving_average_decay_rate=1):
   """Train for LLVM inliner."""
   root_dir = FLAGS.root_dir
@@ -134,6 +136,15 @@
     logging.info('Loaded Reward Stat Map from disk, containing %d modules',
                  len(reward_stat_map))
 
+  best_trajectory_repo = None
+  best_trajectory_repo_path = os.path.join(root_dir,
+                                           'best_trajectory_repo.json')
+  if dump_best_trajectory:
+    best_trajectory_repo = best_trajectory.BestTrajectoryRepo(
+        action_name=action_spec.name)
+    if tf.io.gfile.exists(best_trajectory_repo_path):
+      best_trajectory_repo.load_from_json_file(best_trajectory_repo_path)
+
   with worker_manager_class(
       worker_class=problem_config.get_runner_type(),
       count=FLAGS.num_workers,
@@ -143,7 +154,8 @@
         num_modules=num_modules,
         worker_pool=worker_pool,
         parser=sequence_example_iterator_fn,
-        reward_stat_map=reward_stat_map)
+        reward_stat_map=reward_stat_map,
+        best_trajectory_repo=best_trajectory_repo)
 
     # Repeat for num_policy_iterations iterations.
     t1 = time.time()
@@ -155,6 +167,9 @@
       with tf.io.gfile.GFile(reward_stat_map_path, 'w') as f:
         json.dump(reward_stat_map, f, cls=constant.DataClassJSONEncoder)
 
+      if best_trajectory_repo is not None:
+        best_trajectory_repo.sink_to_json_file(best_trajectory_repo_path)
+
       policy_path = os.path.join(root_dir, 'policy',
                                  str(llvm_trainer.global_step_numpy()))
       saver.save(policy_path)
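
The dump is opt-in via the new train_eval argument. Assuming train_eval keeps its existing gin-configurable wiring in this file (not shown in this diff), it can be switched on with a gin binding along the lines of:

    --gin_bindings=train_eval.dump_best_trajectory=True

When enabled, the repo is loaded from <root_dir>/best_trajectory_repo.json at startup if that file exists, and re-written there after every policy iteration, next to the existing reward stat map dump.
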
diff --git a/compiler_opt/tools/generate_default_trace_test.py b/compiler_opt/tools/generate_default_trace_test.py
index 540009f..94e825f 100644
--- a/compiler_opt/tools/generate_default_trace_test.py
+++ b/compiler_opt/tools/generate_default_trace_test.py
@@ -58,6 +58,7 @@
                     default_reward=1, moving_average_reward=2)
         },
         rewards=[1.2],
+        policy_rewards=[18],
         keys=['default'])