Add support for dumping best trajectories during training (#153)
* add the option to dump best trajectories during training
* track->dump
* fix test
* typo
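Overview note: this change threads a per-key raw "reward under the current policy" (policy_rewards) through CompilationResult and the local data collector, and optionally accumulates the best trajectory seen so far per module/key into a BestTrajectoryRepo that train_locally.py persists under root_dir. The repository implementation lives in compiler_opt/rl/best_trajectory.py and is not part of this diff; the sketch below only mirrors the method names and signatures visible at the call sites here, and its storage layout and comparison direction are labeled assumptions.

import dataclasses
from typing import Dict


@dataclasses.dataclass
class _BestTrajectory:
  reward: float
  sequence_example: bytes


class BestTrajectoryRepoSketch:
  """Illustrative stand-in for best_trajectory.BestTrajectoryRepo."""

  def __init__(self, action_name: str):
    # action_name mirrors the constructor argument used in train_locally.py.
    self._action_name = action_name
    # module_name -> identifier (e.g. function name) -> best trajectory so far.
    self._best: Dict[str, Dict[str, _BestTrajectory]] = {}

  def update_if_better_trajectory(self, module_name: str, identifier: str,
                                  reward: float, sequence_example: bytes):
    # Assumption: a lower raw policy reward (e.g. smaller native size) is
    # "better"; flip the comparison if higher is better in your setup.
    per_module = self._best.setdefault(module_name, {})
    current = per_module.get(identifier)
    if current is None or reward < current.reward:
      per_module[identifier] = _BestTrajectory(reward, sequence_example)

The real class also provides load_from_json_file and sink_to_json_file, which train_locally.py uses below to resume and checkpoint the repository across policy iterations.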
diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py
index d85da23..5e807d2 100644
--- a/compiler_opt/rl/compilation_runner.py
+++ b/compiler_opt/rl/compilation_runner.py
@@ -25,12 +25,13 @@
from absl import flags
from absl import logging
+import tensorflow as tf
+
from compiler_opt.distributed.worker import Worker
from compiler_opt.distributed.worker import WorkerFuture
from compiler_opt.rl import constant
-from compiler_opt.rl import policy_saver
from compiler_opt.rl import corpus
-import tensorflow as tf
+from compiler_opt.rl import policy_saver
_COMPILATION_TIMEOUT = flags.DEFINE_integer(
'compilation_timeout', 60,
@@ -219,11 +220,12 @@
length: total length of all sequence examples, derived from sequence_examples.
reward_stats: a dictionary from keys (e.g. function names) to a RewardStat.
rewards: a list of reward values.
+ policy_rewards: a list of reward values under the policy.
keys: a list of keys.
The object must observe the following invariants:
- 1) The entries of sequence_examples, rewards, and keys correspond to eachoter
- at the same index.
+ 1) The entries of sequence_examples, rewards, policy_rewards and keys
+ correspond to each other at the same index.
2) The keys in reward stats are those in the keys field.
"""
@@ -232,6 +234,7 @@
length: int = dataclasses.field(init=False)
reward_stats: Dict[str, RewardStat]
rewards: List[float]
+ policy_rewards: List[float]
keys: List[str]
def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
@@ -243,8 +246,8 @@
]
object.__setattr__(self, 'length', sum(lengths))
- assert (len(self.serialized_sequence_examples) == len(self.rewards) ==
- (len(self.keys)))
+ assert (len(self.serialized_sequence_examples) == len(self.rewards) == len(
+ self.policy_rewards) == len(self.keys))
assert set(self.keys) == set(self.reward_stats.keys())
assert not hasattr(self, 'sequence_examples')
@@ -356,6 +359,7 @@
sequence_example_list = []
rewards = []
+ policy_rewards = []
keys = []
for k, v in policy_result.items():
sequence_example = v[0]
@@ -376,12 +380,14 @@
policy_reward * (1 - self._moving_average_decay_rate))
rewards.append(
_calculate_reward(policy=policy_reward, baseline=default_reward))
+ policy_rewards.append(policy_reward)
keys.append(k)
return CompilationResult(
sequence_examples=sequence_example_list,
reward_stats=reward_stat,
rewards=rewards,
+ policy_rewards=policy_rewards,
keys=keys)
def compile_fn(
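For reference, a toy construction of the extended CompilationResult. The values and the feature name are illustrative and the SequenceExample is a stand-in for a real compilation log, but the field names and the new length invariant match the code above.

import tensorflow as tf
from compiler_opt.rl import compilation_runner

# One toy SequenceExample standing in for a compiled module's training log.
se = tf.train.SequenceExample()
se.feature_lists.feature_list['inlining_decision'].feature.add(
).int64_list.value.append(1)

result = compilation_runner.CompilationResult(
    sequence_examples=[se],
    reward_stats={
        'main':
            compilation_runner.RewardStat(
                default_reward=10, moving_average_reward=10)
    },
    rewards=[0.2],       # normalized improvement over the baseline
    policy_rewards=[8],  # raw reward under the current policy, newly tracked
    keys=['main'])
# __post_init__ now asserts that serialized_sequence_examples, rewards,
# policy_rewards and keys all have the same length, and that the keys match
# reward_stats.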
diff --git a/compiler_opt/rl/compilation_runner_test.py b/compiler_opt/rl/compilation_runner_test.py
index 56412fc..7bc28aa 100644
--- a/compiler_opt/rl/compilation_runner_test.py
+++ b/compiler_opt/rl/compilation_runner_test.py
@@ -134,6 +134,7 @@
(1 - _MOVING_AVERAGE_DECAY_RATE))
}, data.reward_stats)
self.assertAllClose([0.1998002], data.rewards)
+ self.assertAllClose([8], data.policy_rewards)
@mock.patch(constant.BASE_MODULE_DIR +
'.compilation_runner.CompilationRunner.compile_fn')
@@ -162,6 +163,7 @@
moving_average_reward=_DEFAULT_REWARD)
}, data.reward_stats)
self.assertAllClose([0], data.rewards)
+ self.assertAllClose([10], data.policy_rewards)
@mock.patch(constant.BASE_MODULE_DIR +
'.compilation_runner.CompilationRunner.compile_fn')
@@ -197,6 +199,7 @@
_POLICY_REWARD * (1 - _MOVING_AVERAGE_DECAY_RATE))
}, data.reward_stats)
self.assertAllClose([0.199800], data.rewards)
+ self.assertAllClose([8], data.policy_rewards)
@mock.patch(constant.BASE_MODULE_DIR +
'.compilation_runner.CompilationRunner.compile_fn')
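The new assertions pin policy_rewards to the raw per-key reward produced under the current policy (8 and 10 in these tests), while the existing rewards entries remain the normalized values from _calculate_reward. The expected 0.1998002 is consistent with a relative-savings formula of the form (baseline - policy) / (baseline + delta) with a small delta; that reading is inferred from the test constants and is not something this patch defines or changes.

# Sanity check of the existing expectation, under the assumed formula
# (baseline - policy) / (baseline + delta) with delta = 0.01.
baseline, policy, delta = 10, 8, 0.01
print((baseline - policy) / (baseline + delta))  # ~0.1998002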
diff --git a/compiler_opt/rl/local_data_collector.py b/compiler_opt/rl/local_data_collector.py
index eabdc72..59729d6 100644
--- a/compiler_opt/rl/local_data_collector.py
+++ b/compiler_opt/rl/local_data_collector.py
@@ -17,13 +17,14 @@
import concurrent.futures
import itertools
import time
-from typing import Callable, Dict, Iterator, List, Tuple, Optional
+from typing import Callable, Dict, Iterator, List, Optional, Tuple
from absl import logging
from tf_agents.trajectories import trajectory
from compiler_opt.distributed import worker
from compiler_opt.distributed import buffered_scheduler
+from compiler_opt.rl import best_trajectory
from compiler_opt.rl import compilation_runner
from compiler_opt.rl import corpus
from compiler_opt.rl import data_collector
@@ -41,6 +42,7 @@
parser: Callable[[List[str]], Iterator[trajectory.Trajectory]],
reward_stat_map: Dict[str, Optional[Dict[str,
compilation_runner.RewardStat]]],
+ best_trajectory_repo: Optional[best_trajectory.BestTrajectoryRepo],
exit_checker_ctor=data_collector.EarlyExitChecker):
# TODO(mtrofin): type exit_checker_ctor when we get typing.Protocol support
super().__init__()
@@ -53,6 +55,7 @@
compilation_runner
.CompilationRunnerStub] = self._worker_pool.get_currently_active()
self._reward_stat_map = reward_stat_map
+ self._best_trajectory_repo = best_trajectory_repo
self._exit_checker_ctor = exit_checker_ctor
# _reset_workers is a future that resolves when post-data collection cleanup
# work completes, i.e. cancelling all work and re-enabling the workers.
@@ -126,7 +129,7 @@
"""Collect data for a given policy.
Args:
- policy_path: the path to the policy directory to collect data with.
+ policy: a policy_saver.Policy object to collect data with.
Returns:
An iterator of batched trajectory.Trajectory that are ready to be fed to
@@ -183,6 +186,15 @@
self._reward_stat_map.update(
{spec.name: res.reward_stats for (spec, res) in successful_work})
+ if self._best_trajectory_repo is not None:
+ for spec, res in successful_work:
+ module_name = spec.name
+ for (identifier, reward,
+ sequence_example) in zip(res.keys, res.policy_rewards,
+ res.serialized_sequence_examples):
+ self._best_trajectory_repo.update_if_better_trajectory(
+ module_name, identifier, reward, sequence_example)
+
monitor_dict = {}
monitor_dict['default'] = {
'success_modules': len(successful_work),
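Since the collector only calls update_if_better_trajectory on the repository, a test double for the new code path can be a few lines; the signature below is taken from the call above, everything else is illustrative (the shipped tests simply pass best_trajectory_repo=None).

class _FakeBestTrajectoryRepo:
  """Records update calls instead of keeping real best trajectories."""

  def __init__(self):
    self.calls = []

  def update_if_better_trajectory(self, module_name, identifier, reward,
                                  sequence_example):
    self.calls.append((module_name, identifier, reward, sequence_example))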
diff --git a/compiler_opt/rl/local_data_collector_test.py b/compiler_opt/rl/local_data_collector_test.py
index b48fb1f..c9fceb5 100644
--- a/compiler_opt/rl/local_data_collector_test.py
+++ b/compiler_opt/rl/local_data_collector_test.py
@@ -16,13 +16,15 @@
# pylint: disable=protected-access
import collections
-
import string
import sys
+from typing import List, Tuple
import tensorflow as tf
from tf_agents.system import system_multiprocessing as multiprocessing
+# This is https://github.com/google/pytype/issues/764
+from google.protobuf import text_format # pytype: disable=pyi-error
from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
from compiler_opt.rl import compilation_runner
from compiler_opt.rl import corpus
@@ -30,10 +32,6 @@
from compiler_opt.rl import local_data_collector
from compiler_opt.rl import policy_saver
-# This is https://github.com/google/pytype/issues/764
-from google.protobuf import text_format # pytype: disable=pyi-error
-from typing import List, Tuple
-
_policy_str = 'policy'.encode(encoding='utf-8')
_mock_policy = policy_saver.Policy(output_spec=bytes(), policy=_policy_str)
@@ -71,6 +69,7 @@
default_reward=1, moving_average_reward=2)
},
rewards=[1.2],
+ policy_rewards=[36],
keys=['default'])
else:
return compilation_runner.CompilationResult(
@@ -81,6 +80,7 @@
default_reward=1, moving_average_reward=3)
},
rewards=[3.4],
+ policy_rewards=[18],
keys=['default'])
@@ -93,7 +93,11 @@
self._cancellation_manager)
return compilation_runner.CompilationResult(
- sequence_examples=[], reward_stats={}, rewards=[], keys=[])
+ sequence_examples=[],
+ reward_stats={},
+ rewards=[],
+ policy_rewards=[],
+ keys=[])
class MyRunner(compilation_runner.CompilationRunner):
@@ -154,7 +158,8 @@
num_modules=9,
worker_pool=lwp,
parser=create_test_iterator_fn(),
- reward_stat_map=collections.defaultdict(lambda: None))
+ reward_stat_map=collections.defaultdict(lambda: None),
+ best_trajectory_repo=None)
# reset the sampler, so the next time we collect, we collect the same
# modules. We do it before the collect_data call, because that's when
@@ -226,6 +231,7 @@
worker_pool=lwp,
parser=parser,
reward_stat_map=collections.defaultdict(lambda: None),
+ best_trajectory_repo=None,
exit_checker_ctor=QuickExiter)
collector.collect_data(policy=_mock_policy)
collector._join_pending_jobs()
diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py
index ddf6dd9..26526e5 100644
--- a/compiler_opt/rl/train_locally.py
+++ b/compiler_opt/rl/train_locally.py
@@ -19,6 +19,7 @@
import json
import os
import time
+from typing import List
from absl import app
from absl import flags
@@ -27,10 +28,10 @@
import tensorflow as tf
from tf_agents.agents import tf_agent
from tf_agents.system import system_multiprocessing as multiprocessing
-from typing import List
from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
from compiler_opt.rl import agent_creators
+from compiler_opt.rl import best_trajectory
from compiler_opt.rl import compilation_runner
from compiler_opt.rl import constant
from compiler_opt.rl import corpus
@@ -69,6 +70,7 @@
train_sequence_length=1,
deploy_policy_name='saved_policy',
use_random_network_distillation=False,
+ dump_best_trajectory=False,
moving_average_decay_rate=1):
"""Train for LLVM inliner."""
root_dir = FLAGS.root_dir
@@ -134,6 +136,15 @@
logging.info('Loaded Reward Stat Map from disk, containing %d modules',
len(reward_stat_map))
+ best_trajectory_repo = None
+ best_trajectory_repo_path = os.path.join(root_dir,
+ 'best_trajectory_repo.json')
+ if dump_best_trajectory:
+ best_trajectory_repo = best_trajectory.BestTrajectoryRepo(
+ action_name=action_spec.name)
+ if tf.io.gfile.exists(best_trajectory_repo_path):
+ best_trajectory_repo.load_from_json_file(best_trajectory_repo_path)
+
with worker_manager_class(
worker_class=problem_config.get_runner_type(),
count=FLAGS.num_workers,
@@ -143,7 +154,8 @@
num_modules=num_modules,
worker_pool=worker_pool,
parser=sequence_example_iterator_fn,
- reward_stat_map=reward_stat_map)
+ reward_stat_map=reward_stat_map,
+ best_trajectory_repo=best_trajectory_repo)
# Repeat for num_policy_iterations iterations.
t1 = time.time()
@@ -155,6 +167,9 @@
with tf.io.gfile.GFile(reward_stat_map_path, 'w') as f:
json.dump(reward_stat_map, f, cls=constant.DataClassJSONEncoder)
+ if best_trajectory_repo is not None:
+ best_trajectory_repo.sink_to_json_file(best_trajectory_repo_path)
+
policy_path = os.path.join(root_dir, 'policy',
str(llvm_trainer.global_step_numpy()))
saver.save(policy_path)
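Usage note: the new behavior is off by default (dump_best_trajectory=False). Assuming train_eval stays gin-configurable like the rest of train_locally.py, it can be enabled with a binding such as train_eval.dump_best_trajectory=True; the repository is then reloaded from root_dir/best_trajectory_repo.json on restart and re-dumped after each policy iteration, alongside the existing reward-stat map.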
diff --git a/compiler_opt/tools/generate_default_trace_test.py b/compiler_opt/tools/generate_default_trace_test.py
index 540009f..94e825f 100644
--- a/compiler_opt/tools/generate_default_trace_test.py
+++ b/compiler_opt/tools/generate_default_trace_test.py
@@ -58,6 +58,7 @@
default_reward=1, moving_average_reward=2)
},
rewards=[1.2],
+ policy_rewards=[18],
keys=['default'])