# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for running compilation and collect training data."""
from abc import abstractmethod
import concurrent.futures
import dataclasses
import json
import subprocess
import threading
from typing import Dict, List, Optional, Tuple
from absl import flags
import tensorflow as tf
from compiler_opt.rl import constant
_COMPILATION_TIMEOUT = flags.DEFINE_integer(
'compilation_timeout', 60,
'Max duration (in seconds) after which we cancel any compilation job.')
def _calculate_reward(policy: float, baseline: float) -> float:
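  """Returns the reward of a policy measurement relative to a baseline.
  The reward is 1 - (policy + DELTA) / (baseline + DELTA), i.e. positive when
  policy < baseline.
  """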
# This assumption allows us to imply baseline + constant.DELTA > 0.
assert baseline >= 0
return 1 - (policy + constant.DELTA) / (baseline + constant.DELTA)
@dataclasses.dataclass
class RewardStat:
default_reward: float
moving_average_reward: float
class DataClassJSONEncoder(json.JSONEncoder):
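  """JSON encoder that also serializes dataclasses (e.g. RewardStat) as dicts."""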
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
return super().default(o)
def _overwrite_trajectory_reward(sequence_example: tf.train.SequenceExample,
reward: float) -> tf.train.SequenceExample:
"""Overwrite the reward in the trace (sequence_example) with the given one.
Args:
sequence_example: A tf.SequenceExample proto describing compilation trace.
reward: The reward to overwrite with.
Returns:
The tf.SequenceExample proto after post-processing.
"""
sequence_length = len(
next(iter(sequence_example.feature_lists.feature_list.values())).feature)
reward_list = sequence_example.feature_lists.feature_list['reward']
for _ in range(sequence_length):
added_feature = reward_list.feature.add()
added_feature.float_list.value.append(reward)
return sequence_example
def get_command_line_for_bundle(cmd_file: str,
ir_file: str,
thinlto: Optional[str] = None) -> List[str]:
"""Cleans up base command line.
Remove certain unnecessary flags, and add the .bc file to compile and, if
given, the thinlto index.
Args:
cmd_file: Path to a .cmd file (from corpus).
ir_file: The path to the ir file to compile.
thinlto: The path to the thinlto index, or None.
Returns:
The argument list to pass to the compiler process.
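  Example (hypothetical paths):
    get_command_line_for_bundle('corpus/foo.cmd', 'corpus/foo.bc') returns the
    options read from corpus/foo.cmd, minus the removed flags, followed by
    ['-x', 'ir', 'corpus/foo.bc'].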
"""
cmdline = []
flags_to_remove = [
'-split-dwarf-file', '-split-dwarf-output', '-fthinlto-index',
'-fprofile-sample-use'
]
with open(cmd_file, encoding='utf-8') as f:
option_iterator = iter(f.read().split('\0'))
option = next(option_iterator, None)
while option:
if any(option.startswith(flag) for flag in flags_to_remove):
if '=' not in option:
next(option_iterator, None)
else:
cmdline.append(option)
option = next(option_iterator, None)
cmdline.extend(['-x', 'ir', ir_file])
if thinlto:
cmdline.append('-fthinlto-index=' + thinlto)
return cmdline
class ProcessKilledError(Exception):
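  """Raised when a compilation process was killed, e.g. via the cancellation manager."""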
def __init__(self):
Exception.__init__(self)
def kill_process_ignore_exceptions(p: 'subprocess.Popen[bytes]'):
  # Kill the process and ignore exceptions. Exceptions would be thrown if the
  # process has already been killed or has finished, which inherently races
  # with our attempt to kill it.
try:
p.kill()
p.wait()
finally:
return # pylint: disable=lost-exception
class WorkerCancellationManager:
"""A thread-safe object that can be used to signal cancellation.
This allows killing long-running processes promptly, and thus efficiently
managing resources.
"""
def __init__(self):
    # _processes holds the registered subprocesses and is only accessed under
    # _lock. Once _done is set, no new registrations are accepted and any
    # process passed to register_process is killed immediately.
self._processes = set()
self._done = False
self._lock = threading.Lock()
def enable(self):
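    """Allows new processes to be registered again, e.g. after kill_all_processes."""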
self._done = False
def register_process(self, p: 'subprocess.Popen[bytes]'):
"""Register a process for potential cancellation."""
with self._lock:
if not self._done:
self._processes.add(p)
return
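    # Cancellation was already requested: kill the process right away.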
kill_process_ignore_exceptions(p)
def kill_all_processes(self):
"""Cancel any pending work."""
with self._lock:
self._done = True
for p in self._processes:
kill_process_ignore_exceptions(p)
return len(self._processes)
def unregister_process(self, p: 'subprocess.Popen[bytes]'):
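    """Removes a process from the tracked set once it has finished."""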
with self._lock:
if not self._done:
self._processes.remove(p)
def start_cancellable_process(
cmdline: List[str],
timeout: float,
cancellation_manager: Optional[WorkerCancellationManager],
want_output: bool = False) -> Optional[bytes]:
"""Start a cancellable process.
Args:
cmdline: the process executable and command line
timeout: process execution timeout
cancellation_manager: kill any running process if signaled to do so
want_output: if True, return a buffer containing stdout
Returns:
stdout
Raises:
CalledProcessError: if the process encounters an error.
TimeoutExpired: if the process times out.
ProcessKilledError: if the process was killed via the cancellation token.
"""
with subprocess.Popen(
cmdline, stdout=(subprocess.PIPE if want_output else None)) as p:
if cancellation_manager:
cancellation_manager.register_process(p)
try:
retcode = p.wait(timeout=timeout)
except subprocess.TimeoutExpired as e:
kill_process_ignore_exceptions(p)
raise e
finally:
if cancellation_manager:
cancellation_manager.unregister_process(p)
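    # A return code of -9 means the process was terminated by SIGKILL, which
    # here typically means the cancellation manager killed it.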
if retcode != 0:
raise ProcessKilledError(
) if retcode == -9 else subprocess.CalledProcessError(retcode, cmdline)
else:
if want_output:
ret: bytes = p.stdout.read()
p.stdout.close()
return ret
@dataclasses.dataclass(frozen=True)
class CompilationResult:
"""Result of a call to CompilationRunner.collect_data.
  sequence_examples: a list of tf.train.SequenceExample protos; an init-only
    variable that is not stored on the object.
serialized_sequence_examples: a list of tf.train.SequenceExample serialized
protos, derived from sequence_examples.
length: total length of all sequence examples, derived from sequence_examples.
reward_stats: a dictionary from keys (e.g. function names) to a RewardStat.
rewards: a list of reward values.
keys: a list of keys.
The object must observe the following invariants:
  1) The entries of sequence_examples, rewards, and keys correspond to each
     other at the same index.
2) The keys in reward stats are those in the keys field.
"""
sequence_examples: dataclasses.InitVar[List[tf.train.SequenceExample]]
serialized_sequence_examples: List[str] = dataclasses.field(init=False)
length: int = dataclasses.field(init=False)
reward_stats: Dict[str, RewardStat]
rewards: List[float]
keys: List[str]
def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
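    # The dataclass is frozen, so derived fields must be populated via
    # object.__setattr__ rather than normal assignment.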
object.__setattr__(self, 'serialized_sequence_examples',
[x.SerializeToString() for x in sequence_examples])
lengths = [
len(next(iter(x.feature_lists.feature_list.values())).feature)
for x in sequence_examples
]
object.__setattr__(self, 'length', sum(lengths))
assert (len(self.serialized_sequence_examples) == len(self.rewards) ==
(len(self.keys)))
assert set(self.keys) == set(self.reward_stats.keys())
assert not hasattr(self, 'sequence_examples')
class CompilationRunnerStub:
"""The interface of a stub to CompilationRunner."""
@abstractmethod
def collect_data(
self, file_paths: Tuple[str, ...], tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]]
) -> concurrent.futures.Future[CompilationResult]:
...
@abstractmethod
def cancel_all_work(self) -> concurrent.futures.Future:
...
class CompilationRunner:
"""Base class for collecting compilation data."""
def __init__(self,
clang_path: Optional[str] = None,
launcher_path: Optional[str] = None,
moving_average_decay_rate: float = 1):
"""Initialization of CompilationRunner class.
Args:
clang_path: path to the clang binary.
launcher_path: path to the launcher binary.
moving_average_decay_rate: moving average decay rate during training.
"""
self._clang_path = clang_path
self._launcher_path = launcher_path
self._moving_average_decay_rate = moving_average_decay_rate
self._compilation_timeout = _COMPILATION_TIMEOUT.value
self._cancellation_manager = WorkerCancellationManager()
def enable(self):
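    """Re-enables the cancellation manager so new compilations may run."""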
self._cancellation_manager.enable()
def cancel_all_work(self):
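    """Kills in-flight compilations; returns the number of processes killed."""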
return self._cancellation_manager.kill_all_processes()
def collect_data(
self, file_paths: Tuple[str, ...], tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
"""Collect data for the given IR file and policy.
Args:
file_paths: path to files needed for inlining, Tuple of (.bc, .cmd).
tf_policy_path: path to the tensorflow policy.
reward_stat: reward stat of this module, None if unknown.
Returns:
A CompilationResult. In particular:
reward_stat is the updated reward stat of this module;
        rewards are the rewards under the current ml policy.
Raises:
subprocess.CalledProcessError if process fails.
      compilation_runner.ProcessKilledError is passed through.
      ValueError if the examples under the default policy and the ml policy
        do not match.
"""
if reward_stat is None:
default_result = self._compile_fn(
file_paths,
tf_policy_path='',
reward_only=bool(tf_policy_path),
cancellation_manager=self._cancellation_manager)
reward_stat = {
k: RewardStat(v[1], v[1]) for (k, v) in default_result.items()
}
if tf_policy_path:
policy_result = self._compile_fn(
file_paths,
tf_policy_path,
reward_only=False,
cancellation_manager=self._cancellation_manager)
else:
policy_result = default_result
sequence_example_list = []
rewards = []
keys = []
for k, v in policy_result.items():
sequence_example = v[0]
policy_reward = v[1]
if k not in reward_stat:
        raise ValueError(
            (f'Example {k} does not exist under default policy for '
             f'module {file_paths[0]}'))
default_reward = reward_stat[k].default_reward
moving_average_reward = reward_stat[k].moving_average_reward
sequence_example = _overwrite_trajectory_reward(
sequence_example=sequence_example,
reward=_calculate_reward(
policy=policy_reward, baseline=moving_average_reward))
sequence_example_list.append(sequence_example)
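      # Exponentially decayed moving average of this example's reward; it is
      # used as the baseline in future reward calculations.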
reward_stat[k].moving_average_reward = (
moving_average_reward * self._moving_average_decay_rate +
policy_reward * (1 - self._moving_average_decay_rate))
rewards.append(
_calculate_reward(policy=policy_reward, baseline=default_reward))
keys.append(k)
return CompilationResult(
sequence_examples=sequence_example_list,
reward_stats=reward_stat,
rewards=rewards,
keys=keys)
def _compile_fn(
self, file_paths: Tuple[str, ...], tf_policy_path: str, reward_only: bool,
cancellation_manager: Optional[WorkerCancellationManager]
) -> Dict[str, Tuple[tf.train.SequenceExample, float]]:
"""Compiles for the given IR file under the given policy.
Args:
file_paths: path to files needed for compilation.
tf_policy_path: path to TF policy directory on local disk.
      reward_only: whether to return only the reward.
cancellation_manager: a WorkerCancellationManager to handle early
termination
Returns:
A dict mapping from example identifier to tuple containing:
sequence_example: A tf.SequenceExample proto describing compilation
trace, None if reward_only == True.
reward: reward under the policy.
Raises:
subprocess.CalledProcessError if process fails.
      ProcessKilledError if the process was killed by the cancellation manager.
"""
raise NotImplementedError('Not implemented compile fn.')