| # coding=utf-8 |
| # Copyright 2020 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Module for running compilation and collecting training data.""" |
| |
| import abc |
| import dataclasses |
| import os |
| import signal |
| import subprocess |
| import tempfile |
| import threading |
| from typing import Dict, List, Optional, Tuple |
| |
| from absl import flags |
| from absl import logging |
| import tensorflow as tf |
| |
| from compiler_opt.distributed.worker import Worker |
| from compiler_opt.distributed.worker import WorkerFuture |
| from compiler_opt.rl import constant |
| from compiler_opt.rl import corpus |
| from compiler_opt.rl import policy_saver |
| |
# Command-line flags read lazily (via `.value`) by the code below.
# Per-compilation wall-clock limit used by CompilationRunner when no explicit
# compilation_timeout is passed to its constructor.
_COMPILATION_TIMEOUT = flags.DEFINE_integer(
    name='compilation_timeout',
    default=60,
    help='Max duration (in seconds) after which we cancel any compilation job.')
# When set, child processes get TF_CPP_MIN_LOG_LEVEL=1; otherwise the command
# line of each child process is logged.
_QUIET = flags.DEFINE_bool(
    name='quiet',
    default=True,
    help='Whether or not to compile quietly (hiding info logging)')
| |
| |
def _calculate_reward(policy: float, baseline: float) -> float:
  """Return the relative improvement of `policy` over `baseline`.

  Both values are shifted by constant.DELTA before the ratio is taken, and the
  result is 1 minus that ratio, so a policy value below the baseline gives a
  positive reward.

  Args:
    policy: the value measured under the policy being evaluated.
    baseline: the value to compare against; must be non-negative.

  Returns:
    1 - (policy + DELTA) / (baseline + DELTA).
  """
  # A non-negative baseline keeps the divisor baseline + constant.DELTA
  # strictly positive (given a positive DELTA).
  assert baseline >= 0
  shifted_policy = policy + constant.DELTA
  shifted_baseline = baseline + constant.DELTA
  return 1 - shifted_policy / shifted_baseline
| |
| |
@dataclasses.dataclass
class RewardStat:
  """Reward statistics for one example (e.g. one function) of a module.

  Attributes:
    default_reward: the reward measured when compiling with the default
      policy; CompilationRunner.collect_data uses it as the fixed baseline
      for reported rewards.
    moving_average_reward: a moving-average baseline that collect_data uses
      for the per-step trajectory reward and updates after every collection.
  """
  default_reward: float
  moving_average_reward: float
| |
| |
def _overwrite_trajectory_reward(sequence_example: tf.train.SequenceExample,
                                 reward: float) -> tf.train.SequenceExample:
  """Overwrite the reward in the trace (sequence_example) with the given one.

  Every step of the trace is assigned the same `reward`. Any values already
  present in the 'reward' feature list are discarded first, so the result has
  exactly one reward entry per step.

  Args:
    sequence_example: A tf.SequenceExample proto describing compilation trace.
      It is modified in place.
    reward: The reward to overwrite with.

  Returns:
    The tf.SequenceExample proto after post-processing (the same object that
    was passed in).

  Raises:
    ValueError: if the trace has no feature lists, so its length is unknown.
  """
  feature_lists = sequence_example.feature_lists.feature_list
  # Measure the trace on a feature list other than 'reward' when one exists,
  # since a pre-existing 'reward' list is about to be replaced.
  step_lists = [
      feature_list for name, feature_list in feature_lists.items()
      if name != 'reward'
  ] or list(feature_lists.values())
  if not step_lists:
    raise ValueError('sequence_example has no feature lists.')
  sequence_length = len(step_lists[0].feature)

  reward_list = feature_lists['reward']
  # Clear first: appending to existing rewards would make the reward list
  # longer than the trace.
  del reward_list.feature[:]
  for _ in range(sequence_length):
    reward_list.feature.add().float_list.value.append(reward)

  return sequence_example
| |
| |
class ProcessKilledError(Exception):
  """Raised when a child process exits with return code -9 (SIGKILL).

  start_cancellable_process raises this, for example after a
  WorkerCancellationManager kills the process. It carries no arguments.
  """

  def __init__(self):
    super().__init__()
| |
| |
def kill_process_ignore_exceptions(p: 'subprocess.Popen[bytes]'):
  """Kill `p` and reap it, ignoring errors from an already-finished process.

  Killing inherently races with the process finishing (or being killed) on
  its own, so errors raised by `kill`/`wait` in that situation are
  deliberately ignored.

  Args:
    p: the process to kill.
  """
  try:
    p.kill()
    p.wait()
  except Exception:  # pylint: disable=broad-except
    # Best-effort by design. Only Exception subclasses are ignored, so
    # KeyboardInterrupt and SystemExit still propagate.
    pass
| |
| |
class WorkerCancellationManager:
  """A thread-safe object that can be used to signal cancellation.

  This allows killing long-running processes promptly, and thus efficiently
  managing resources.

  Processes are tracked between `register_process` and `unregister_process`.
  Calling `kill_all_processes` kills them, and any process registered after
  that is killed immediately until `enable` is called.
  `pause_all_processes` and `resume_all_processes` send SIGSTOP and SIGCONT
  to the tracked processes.
  """

  def __init__(self):
    # Every field below is guarded by _lock.
    # Processes currently registered for potential cancellation.
    self._processes = set()
    # Once True, newly registered processes are killed right away.
    self._done = False
    # Whether the registered processes have been sent SIGSTOP.
    self._paused = False
    self._lock = threading.Lock()

  def enable(self):
    """Accept (rather than immediately kill) newly registered processes."""
    with self._lock:
      self._done = False

  def register_process(self, p: 'subprocess.Popen[bytes]'):
    """Register a process for potential cancellation.

    If cancellation has already been signaled, `p` is killed instead.
    """
    with self._lock:
      if not self._done:
        self._processes.add(p)
        return
    kill_process_ignore_exceptions(p)

  def kill_all_processes(self):
    """Cancel any pending work."""
    with self._lock:
      self._done = True
      for p in self._processes:
        kill_process_ignore_exceptions(p)

  def _signal_all_processes(self, sig: int):
    """Send `sig` to every registered process; the caller must hold _lock."""
    for p in self._processes:
      # Popen.send_signal skips a process this Popen has already reaped.
      # A raw os.kill(p.pid, ...) would instead raise ProcessLookupError, or
      # could signal an unrelated process that reused the pid.
      p.send_signal(sig)

  def pause_all_processes(self):
    """Suspend all registered processes; a no-op if already paused."""
    with self._lock:
      if self._paused:
        return
      self._paused = True
      # SIGSTOP suspends the process; it does not terminate it.
      self._signal_all_processes(signal.SIGSTOP)

  def resume_all_processes(self):
    """Resume processes suspended by pause_all_processes; no-op otherwise."""
    with self._lock:
      if not self._paused:
        return
      self._paused = False
      self._signal_all_processes(signal.SIGCONT)

  def unregister_process(self, p: 'subprocess.Popen[bytes]'):
    """Stop tracking `p`; unknown processes are ignored."""
    with self._lock:
      self._processes.discard(p)

  def __del__(self):
    # An exception raised here does not propagate. Python reports it as an
    # unraisable exception, which still surfaces the leaked registrations.
    if self._processes:
      raise RuntimeError('Cancellation manager deleted while containing items.')
| |
| |
def start_cancellable_process(
    cmdline: List[str],
    timeout: float,
    cancellation_manager: Optional[WorkerCancellationManager],
    want_output: bool = False) -> Optional[bytes]:
  """Start a cancellable process.

  Args:
    cmdline: the process executable and command line
    timeout: process execution timeout
    cancellation_manager: kill any running process if signaled to do so
    want_output: if True, return a buffer containing stdout

  Returns:
    The process's stdout if want_output is True, otherwise None.
  Raises:
    CalledProcessError: if the process encounters an error.
    TimeoutExpired: if the process times out (the process is killed first).
    ProcessKilledError: if the process exited with return code -9 (SIGKILL),
      e.g. because the cancellation manager killed it.
  """
  command_env = os.environ.copy()
  # Disable tensorflow info messages during data collection
  if _QUIET.value:
    command_env['TF_CPP_MIN_LOG_LEVEL'] = '1'
  else:
    logging.info(cmdline)
  with subprocess.Popen(
      cmdline,
      env=command_env,
      stdout=(subprocess.PIPE if want_output else None)) as p:
    if cancellation_manager is not None:
      cancellation_manager.register_process(p)

    try:
      # communicate() drains stdout while waiting. p.wait() with a PIPE can
      # deadlock once the child fills the OS pipe buffer.
      output, _ = p.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
      kill_process_ignore_exceptions(p)
      raise
    finally:
      if cancellation_manager is not None:
        cancellation_manager.unregister_process(p)

    retcode = p.returncode
    # -9 is -SIGKILL on POSIX: the process was killed, e.g. by cancellation.
    if retcode == -9:
      raise ProcessKilledError()
    if retcode != 0:
      raise subprocess.CalledProcessError(retcode, cmdline)
    return output if want_output else None
| |
| |
@dataclasses.dataclass(frozen=True)
class CompilationResult:
  """Result of a call to CompilationRunner.collect_data.

  sequence_examples: a list of tf.train.SequenceExample protos, init-only
    variables.
  serialized_sequence_examples: a list of tf.train.SequenceExample serialized
    protos (bytes), derived from sequence_examples.
  length: total length of all sequence examples, derived from sequence_examples.
  reward_stats: a dictionary from keys (e.g. function names) to a RewardStat.
  rewards: a list of reward values.
  policy_rewards: a list of reward values under policy.
  keys: a list of keys.

  The object must observe the following invariants:
  1) The entries of sequence_examples, rewards, policy_rewards and keys
  correspond to each other at the same index.

  2) The keys in reward stats are those in the keys field.
  """
  sequence_examples: dataclasses.InitVar[List[tf.train.SequenceExample]]
  # SerializeToString() produces bytes, not str.
  serialized_sequence_examples: List[bytes] = dataclasses.field(init=False)
  length: int = dataclasses.field(init=False)
  reward_stats: Dict[str, RewardStat]
  rewards: List[float]
  policy_rewards: List[float]
  keys: List[str]

  def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
    # The dataclass is frozen, so derived fields are set via object.__setattr__.
    object.__setattr__(self, 'serialized_sequence_examples',
                       [x.SerializeToString() for x in sequence_examples])
    # Each example's length is the number of steps in its first feature list.
    lengths = [
        len(next(iter(x.feature_lists.feature_list.values())).feature)
        for x in sequence_examples
    ]
    object.__setattr__(self, 'length', sum(lengths))

    assert (len(self.serialized_sequence_examples) == len(self.rewards) == len(
        self.policy_rewards) == len(self.keys)), (
            'Mismatched lengths: '
            f'{len(self.serialized_sequence_examples)} sequence examples, '
            f'{len(self.rewards)} rewards, '
            f'{len(self.policy_rewards)} policy rewards, '
            f'{len(self.keys)} keys')
    assert set(self.keys) == set(self.reward_stats.keys()), (
        'keys and reward_stats keys differ')
    # InitVar fields must not be stored on the instance.
    assert not hasattr(self, 'sequence_examples')
| |
| |
class CompilationRunnerStub(metaclass=abc.ABCMeta):
  """The interface of a stub to CompilationRunner, for type checkers.

  Each method mirrors the CompilationRunner method of the same name but
  returns a WorkerFuture instead of the direct result.
  """

  @abc.abstractmethod
  def collect_data(
      self,
      loaded_module_spec: corpus.LoadedModuleSpec,
      policy: Optional[policy_saver.Policy] = None,
      reward_stat: Optional[Dict[str, RewardStat]] = None
  ) -> WorkerFuture[CompilationResult]:
    raise NotImplementedError()

  @abc.abstractmethod
  def cancel_all_work(self) -> WorkerFuture:
    raise NotImplementedError()

  @abc.abstractmethod
  def enable(self) -> WorkerFuture:
    raise NotImplementedError()

  # pause_all_work / resume_all_work mirror the runner's priority methods of
  # the same name. They are intentionally not abstract, so existing stub
  # implementations that predate them remain instantiable.
  def pause_all_work(self) -> WorkerFuture:
    raise NotImplementedError()

  def resume_all_work(self) -> WorkerFuture:
    raise NotImplementedError()
| |
| |
class CompilationRunner(Worker):
  """Base class for collecting compilation data.

  Subclasses implement `compile_fn`. This class handles several other jobs:
  - it writes the policy to a temporary directory;
  - it computes rewards against the default-policy reward and a
    moving-average baseline;
  - it forwards cancel / pause / resume requests to its
    WorkerCancellationManager.
  """

  @classmethod
  def is_priority_method(cls, method_name: str) -> bool:
    # NOTE(review): presumably the worker framework runs these ahead of queued
    # work such as collect_data; confirm in compiler_opt.distributed.
    return method_name in {
        'cancel_all_work', 'pause_all_work', 'resume_all_work'
    }

  def __init__(self,
               clang_path: Optional[str] = None,
               launcher_path: Optional[str] = None,
               moving_average_decay_rate: float = 1,
               compilation_timeout=None):
    """Initialization of CompilationRunner class.

    Args:
      clang_path: path to the clang binary.
      launcher_path: path to the launcher binary.
      moving_average_decay_rate: moving average decay rate during training.
        With the default of 1, the moving average keeps its initial value
        (the default-policy reward).
      compilation_timeout: per-compilation timeout in seconds. A falsy value
        (None or 0) falls back to the --compilation_timeout flag.
    """
    self._clang_path = clang_path
    self._launcher_path = launcher_path
    self._moving_average_decay_rate = moving_average_decay_rate
    # Avoid reading the flag during the first interpretation of this module.
    self._compilation_timeout = (
        compilation_timeout or _COMPILATION_TIMEOUT.value)
    self._cancellation_manager = WorkerCancellationManager()

  def enable(self):
    """Re-allow the cancellation manager to accept work."""
    self._cancellation_manager.enable()

  def cancel_all_work(self):
    """Kill all registered processes and reject new ones until enable()."""
    self._cancellation_manager.kill_all_processes()

  def pause_all_work(self):
    """Suspend all registered processes."""
    self._cancellation_manager.pause_all_processes()

  def resume_all_work(self):
    """Resume processes suspended by pause_all_work."""
    self._cancellation_manager.resume_all_processes()

  def collect_data(
      self,
      loaded_module_spec: corpus.LoadedModuleSpec,
      policy: Optional[policy_saver.Policy] = None,
      reward_stat: Optional[Dict[str, RewardStat]] = None) -> CompilationResult:
    """Collect data for the given IR file and policy.

    Args:
      loaded_module_spec: a LoadedModuleSpec.
      policy: serialized policy; None means compile with the default policy.
      reward_stat: reward stat of this module, None if unknown. When given,
        the RewardStat objects it contains are updated in place.

    Returns:
      A CompilationResult. In particular:
        reward_stat is the updated reward stat of this module;
        rewards is rewards under the current ml policy.

    Raises:
      subprocess.CalledProcessError if process fails.
      compilation_runner.ProcessKilledError is passed through.
      ValueError if example under default policy and ml policy does not match.
    """
    with tempfile.TemporaryDirectory() as tempdir:
      final_cmd_line = loaded_module_spec.build_command_line(tempdir)
      tf_policy_path = ''
      if policy is not None:
        tf_policy_path = os.path.join(tempdir, 'policy')
        policy.to_filesystem(tf_policy_path)

      default_result = None
      if reward_stat is None:
        # Traces from the default run are only needed when it doubles as the
        # policy run, i.e. when no policy was given.
        default_result = self.compile_fn(
            final_cmd_line, tf_policy_path='', reward_only=bool(tf_policy_path))
        reward_stat = {
            k: RewardStat(v[1], v[1]) for (k, v) in default_result.items()
        }

      if tf_policy_path:
        policy_result = self.compile_fn(
            final_cmd_line, tf_policy_path, reward_only=False)
      elif default_result is not None:
        # No policy: the default-policy run above already produced traces.
        policy_result = default_result
      else:
        # No policy and reward stats were supplied, so nothing has been
        # compiled yet; run the default policy once to obtain traces.
        policy_result = self.compile_fn(
            final_cmd_line, tf_policy_path='', reward_only=False)

      sequence_example_list = []
      rewards = []
      policy_rewards = []
      keys = []
      for k, v in policy_result.items():
        sequence_example = v[0]
        policy_reward = v[1]
        if k not in reward_stat:
          raise ValueError(
              (f'Example {k} does not exist under default policy for '
               f'cmd line: {final_cmd_line}'))
        default_reward = reward_stat[k].default_reward
        moving_average_reward = reward_stat[k].moving_average_reward
        # The trajectory reward is measured against the moving average, while
        # the reported reward is measured against the default policy.
        sequence_example = _overwrite_trajectory_reward(
            sequence_example=sequence_example,
            reward=_calculate_reward(
                policy=policy_reward, baseline=moving_average_reward))
        sequence_example_list.append(sequence_example)
        reward_stat[k].moving_average_reward = (
            moving_average_reward * self._moving_average_decay_rate +
            policy_reward * (1 - self._moving_average_decay_rate))
        rewards.append(
            _calculate_reward(policy=policy_reward, baseline=default_reward))
        policy_rewards.append(policy_reward)
        keys.append(k)

    return CompilationResult(
        sequence_examples=sequence_example_list,
        reward_stats=reward_stat,
        rewards=rewards,
        policy_rewards=policy_rewards,
        keys=keys)

  def compile_fn(
      self, command_line: corpus.FullyQualifiedCmdLine, tf_policy_path: str,
      reward_only: bool) -> Dict[str, Tuple[tf.train.SequenceExample, float]]:
    """Compiles for the given IR file under the given policy.

    Must be implemented by subclasses.

    Args:
      command_line: the fully qualified command line.
      tf_policy_path: path to TF policy directory on local disk; '' selects
        the default policy.
      reward_only: whether only return reward.

    Returns:
      A dict mapping from example identifier to tuple containing:
        sequence_example: A tf.SequenceExample proto describing compilation
        trace, None if reward_only == True.
        reward: reward under the policy.

    Raises:
      subprocess.CalledProcessError if process fails.
      ProcessKilledError if the process was killed
    """
    raise NotImplementedError('Not implemented compile fn.')