# blob: 3c90edc47823caa224f4510cf4361f163419da77 [file] [log] [blame]
"""Module for collect data of the LR encoder."""
import os
import tempfile
from typing import Dict, Tuple
import gin
import tensorflow as tf
from google3.third_party.ml_compiler_opt.compiler_opt.rl import compilation_runner
from google3.third_party.ml_compiler_opt.compiler_opt.rl import corpus
from google3.third_party.ml_compiler_opt.compiler_opt.rl import log_reader
@gin.configurable(module='runners')
class LREncoderRunner(compilation_runner.CompilationRunner):
  """Compilation runner that gathers training data for the LR encoder."""

  def compile_fn(
      self,
      command_line: corpus.FullyQualifiedCmdLine,
      tf_policy_path: str,
      reward_only: bool,
  ) -> Dict[str, Tuple[tf.train.SequenceExample, float]]:
    """Compile one module and harvest the LR encoder training log.

    Clang is invoked (optionally through the configured launcher) with the
    register allocation advisor in development mode. The LR encoder log is
    written into a temporary working directory, while the regular regalloc
    training log is sent to /dev/null. The working directory is removed
    before this method returns or raises.

    Args:
      command_line: the fully qualified command line.
      tf_policy_path: path to the TF policy directory on local disk. Must be
        empty; this runner does not pass a policy to the compiler.
      reward_only: not consulted by this implementation; trajectories are
        always returned.

    Returns:
      A dict mapping each identifier found in the log to a tuple of:
        sequence_example: the tf.train.SequenceExample read from the log.
          Entries without a `feature_lists` field are skipped.
        reward: currently always the constant placeholder 1.0.
      The dict is empty when the compiler did not produce a log file.

    Raises:
      AssertionError: if tf_policy_path is non-empty.
      Any exception raised by compilation_runner.start_cancellable_process
      propagates unchanged (presumably subprocess.CalledProcessError on a
      failed compile and compilation_runner.ProcessKilledError on cancelled
      work -- confirm against that helper).
    """
    assert not tf_policy_path

    scratch_dir = tempfile.mkdtemp()
    encoder_log = os.path.join(scratch_dir, 'log')
    native_output = os.path.join(scratch_dir, 'native')

    try:
      # The launcher, when configured, wraps the clang invocation.
      launcher_prefix = [self._launcher_path] if self._launcher_path else []
      clang_invocation = [
          *launcher_prefix,
          self._clang_path,
          *command_line,
          '-mllvm',
          '-regalloc-enable-advisor=development',
          '-mllvm',
          f'-regalloc-lr-encoder-training-log={encoder_log}',
          '-mllvm',
          '-regalloc-training-log=/dev/null',
          '-o',
          native_output,
      ]
      compilation_runner.start_cancellable_process(
          clang_invocation,
          self._compilation_timeout,
          self._cancellation_manager,
      )

      # No log file means there is nothing to collect for this module.
      if not os.path.exists(encoder_log):
        return {}

      # TODO(#202)
      trajectories = log_reader.read_log_as_sequence_examples(encoder_log)

      # The reward is a fixed placeholder of 1.0; reading the last value of
      # the trajectory's 'reward' feature list is currently disabled.
      return {
          identifier: (example, 1.0)
          for identifier, example in trajectories.items()
          if example.HasField('feature_lists')
      }
    finally:
      # Always discard the scratch directory, on success and on failure.
      tf.io.gfile.rmtree(scratch_dir)