| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Multi-process runner for testing purpose.""" |
| |
| import collections |
| import contextlib |
| import json |
| import os |
| import signal |
| import sys |
| import threading |
| import time |
| import unittest |
| import weakref |
| |
| from absl import logging |
| import six |
| from six.moves import queue as Queue |
| |
| from tensorflow.python import tf2 |
| from tensorflow.python.compat import v2_compat |
| from tensorflow.python.distribute import multi_worker_util |
| from tensorflow.python.distribute import multi_process_lib |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.util.tf_export import tf_export |
| |
| multiprocessing = multi_process_lib.multiprocessing |
| |
| # pylint: disable=g-import-not-at-top |
| try: |
| # `faulthandler` is not available in py2. |
| import faulthandler |
| except ImportError: |
| faulthandler = None |
| |
| # TODO(b/150264776): Remove after resolving CI issue. |
| try: |
| import dill |
| except ImportError: |
| dill = None |
| |
| # TODO(b/150264776): Remove after resolving CI issue. |
| try: |
| import tblib.pickling_support |
| # For pickling traceback objects. |
| tblib.pickling_support.install() |
| except ImportError: |
| pass |
| |
| |
# Status record that a subprocess reports back to the parent process.
#   task_type, task_id: identify the reporting task.
#   is_successful: True when the subprocess ended successfully.
#   exc_info: when `is_successful` is False, the exception info (including the
#     stack trace) that the parent process re-raises.
#   return_value: the value reported by the subprocess.
_ProcessStatusInfo = collections.namedtuple(
    '_ProcessStatusInfo',
    (
        'task_type',
        'task_id',
        'is_successful',
        'exc_info',
        'return_value',
    ))

# What a MultiProcessRunner run hands back to the caller.
#   return_value: list of the non-None values returned by the subprocesses.
#   stdout: list of the output lines collected from the subprocesses.
MultiProcessRunnerResult = collections.namedtuple(
    'MultiProcessRunnerResult',
    (
        'return_value',
        'stdout',
    ))

# Per-task settings applied inside the child process.
#   visible_gpus: when non-empty, the child process exports it (comma joined)
#     as CUDA_VISIBLE_DEVICES; when None or empty the variable is left
#     untouched.
TestEnvironment = collections.namedtuple(
    'TestEnvironment',
    (
        'task_type',
        'task_id',
        'cluster_spec',
        'rpc_layer',
        'grpc_fail_fast',
        'v2_enabled',
        'executing_eagerly',
        'visible_gpus',
    ))

# Channels shared between the main process and the worker processes.
#
# `process_status_queue`: subprocess -> parent. Used internally by
#   `multi_process_runner` to report whether a subprocess succeeded and, if
#   not, what the error stack trace is.
# `parent_to_sub_queue`: parent -> subprocess. Currently this is only used to
#   terminate subprocesses.
#   TODO(rchao): Remove this once subprocess is terminated by SIGKILL.
# `streaming_pipe_w`: write end of the pipe that streams stdout and stderr of
#   a subprocess to the parent process.
# `barrier`: a barrier for the party of all subprocesses.
Resources = collections.namedtuple(
    'Resources',
    (
        'process_status_queue',
        'parent_to_sub_queue',
        'streaming_pipe_w',
        'barrier',
    ))
| |
# Default value (in seconds) of the `timeout` argument of
# `MultiProcessRunner.join()`. It is selected so that the timeout is handled
# before the default "medium" timeout of the test runs.
_DEFAULT_TIMEOUT_SEC = 200

# The timeout in seconds to wait to force kill a child process. When a child
# process times out we first try to SIGTERM it so that it has a chance to dump
# stacktraces. However dumping stacktrace can take a long time, so after this
# many seconds `MultiProcessRunner.join()` falls back to the default
# termination signal (SIGKILL where available).
_FORCE_KILL_WAIT_SEC = 30
| |
| |
class MultiProcessRunner(object):
  """A utility class to start multiple processes to simulate a cluster.

  We need to use multiple processes to simulate a cluster in TF 2.0 tests
  because TF 2.0 has some process-global data structures that have to be
  separated by processes. We also need child processes to test out our fault
  tolerance because shutting down a standard TensorFlow server within its
  process is not supported.

  Typical life cycle: construct the runner, call `start()` to launch one
  subprocess per task in the cluster spec, optionally call `terminate()` /
  `terminate_all()` or `start_single_process()` while the cluster is running,
  and finally call `join()` (at most once) to wait for the subprocesses and
  collect their results.

  Note: the main test program that uses this runner class must run main program
  via `test_main` defined in this file. Using this runner in non-test binaries
  is not supported yet.

  This class is not thread-safe. Child processes will inherit TF2 behavior flag.
  """
| |
def __init__(self,
             fn,
             cluster_spec,
             rpc_layer=None,
             max_run_time=None,
             grpc_fail_fast=None,
             stream_output=True,
             return_output=False,
             use_dill_for_args=True,
             daemon=False,
             dependence_on_chief=True,
             auto_restart=False,
             share_gpu=True,
             args=None,
             kwargs=None):
  """Creates a `MultiProcessRunner`; no subprocess is started here.

  Args:
    fn: Callable executed in the child processes of every task type.
    cluster_spec: Dict describing the cluster, mapping each task type to the
      list of addresses of its tasks. It can be conveniently created with
      `tf.__internal__.distribute.multi_process_runner.create_cluster_spec`.
      Example of a cluster with three workers and two ps's:
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
    rpc_layer: RPC layer to use; 'grpc' is used when this is not given.
    max_run_time: `None` or integer number of seconds. When an integer, the
      child processes are forced to exit approximately that long after
      `start()` is called, which is implemented with `signal.alarm()`. This
      is best effort at Python level: a Python signal handler is not
      executed while lower level C/C++ code runs, so the termination can be
      delayed for an arbitrarily long time. Child processes still running
      when the time is up are force-terminated and an
      `UnexpectedSubprocessExitError` may be raised. With `None`, child
      processes are never forced to exit.
    grpc_fail_fast: Whether GRPC connection between processes should fail
      without retrying. With the default `None` the corresponding
      environment variable is not explicitly set.
    stream_output: If True (default), output/error of the subprocesses is
      streamed into the log of the parent process.
    return_output: If True, output/error of the subprocesses is collected
      and attached to the namedtuple returned by `join()`, where it is
      available through the `stdout` attribute. Defaults to False.
    use_dill_for_args: Whether `args` and `kwargs` are pickled with dill.
      dill can pickle more objects, but doesn't work with types in the
      `multiprocessing` library like `Mutex`.
    daemon: Whether the processes are started as daemons.
    dependence_on_chief: Whether to terminate the cluster if the chief
      exits. When `auto_restart` is True, the cluster is only terminated if
      the chief exits with a zero exit code.
    auto_restart: Whether processes exiting with a non-zero exit code are
      restarted automatically.
    share_gpu: Whether workers share GPUs. If False, each worker is assigned
      different GPUs in a roundrobin fashion. Keep this True whenever
      possible for better test execution coverage; tests that run NCCL are
      an example of a situation that needs False.
    args: Positional arguments passed to `fn` in the subprocesses.
    kwargs: Keyword arguments passed to `fn` in the subprocesses.

  Raises:
    RuntimeError: if `multi_process_runner.test_main()` is not called.
    ValueError: if there are more than one chief in the `cluster_spec`, or
      if `fn` is not callable.
    SkipTest: if thread sanitizer is enabled (which is incompatible with MPR).
  """
  if test_util.is_tsan_enabled():
    raise unittest.SkipTest(
        'ThreadSanitizer is not compatible with MultiProcessRunner.')

  assert cluster_spec is not None
  if 'chief' in cluster_spec:
    chief_count = len(cluster_spec['chief'])
    if chief_count > 1:
      raise ValueError(
          'If chief exists in the cluster, there must be at most '
          'one chief. Current `cluster_spec` has {} chiefs.'.format(
              chief_count))
  _check_initialization()
  if not callable(fn):
    raise ValueError('fn is not a callable')

  # What to run, and in which cluster.
  self._fn = fn
  self._args = args or ()
  self._kwargs = kwargs or {}
  self._cluster_spec = cluster_spec
  self._rpc_layer = rpc_layer or 'grpc'
  self._grpc_fail_fast = grpc_fail_fast
  self._use_dill_for_args = use_dill_for_args

  # Process life cycle policy.
  self._max_run_time = max_run_time
  self._daemon = daemon
  self._dependence_on_chief = dependence_on_chief
  self._auto_restart = auto_restart

  # Output handling.
  self._stream_output = stream_output
  # TODO(rchao): Revisit return_output argument to consider other solution.
  self._return_output = return_output

  self._share_gpu = share_gpu
  self._total_gpu = len(context.context().list_physical_devices('GPU'))

  # Child processes should have the same v2 and eager behavior.
  self._v2_enabled = tf2.enabled()
  self._executing_eagerly = context.executing_eagerly()

  self._joined = False
  self._process_lock = threading.Lock()
  # Maps (task_type, task_id) to its process. Guarded by self._process_lock.
  self._processes = {}
  # Record which processes are terminated. Due to a bug in Python<3.7,
  # terminated processes return 255 exit code, which should cause an exception
  # in join().
  # https://bugs.python.org/issue30589
  # Guarded by self._process_lock.
  self._terminated = set()
  self._reading_threads = []

  self._manager = manager()
  self._process_status_queue = self._manager.Queue()
  self._parent_to_sub_queue = self._manager.Queue()
  # Every task in the cluster is a party of the barrier.
  self._barrier = self._manager.Barrier(
      sum(map(len, self._cluster_spec.values())))

  # We use a queue to collect outputs from worker processes since it's thread
  # safe.
  self._streaming_queue = self._manager.Queue()

  # Created lazily when the first subprocess is started.
  self._watchdog_thread = None
| |
def set_args(self, args=None, kwargs=None):
  """Replaces the arguments passed to `fn` in subprocesses started later.

  Args:
    args: New positional arguments. A falsy value (`None` or an empty
      sequence) keeps the current positional arguments.
    kwargs: New keyword arguments. A falsy value (`None` or an empty dict)
      keeps the current keyword arguments.
  """
  if args:
    self._args = args
  if kwargs:
    self._kwargs = kwargs
| |
def _continuously_readline_from_sub(self, pipe_r, task_type, task_id):
  """Reads lines from a subprocess pipe until the pipe reaches end-of-file.

  Every line is prefixed with a `[task_type-task_id]:` tag padded to 14
  characters. Tagged lines are printed when `stream_output` is set and queued
  for `join()` when `return_output` is set.

  Args:
    pipe_r: Read end of the streaming pipe; only its `fileno()` is used and
      the descriptor is not closed by this function.
    task_type: Task type of the subprocess writing into the pipe.
    task_id: Task id of the subprocess writing into the pipe.
  """
  # The tag is the same for every line of this subprocess.
  prefix = '[{}-{}]:'.format(task_type, task_id).ljust(14)
  with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader:
    for raw_line in reader:
      tagged_line = '{} {}'.format(prefix, raw_line)
      if self._stream_output:
        # TODO(rchao): Use a lock here to ensure the printed lines are not
        # broken.
        print(tagged_line, end='', flush=True)
      if self._return_output:
        self._streaming_queue.put(tagged_line)
| |
def _start_subprocess_and_reading_thread(self,
                                         task_type,
                                         task_id,
                                         cluster_spec=None,
                                         fn=None,
                                         args=None,
                                         kwargs=None):
  """Launches one subprocess and the thread that reads lines from it.

  The new process is recorded in `self._processes` (replacing a previous
  process of the same task) and the task is no longer marked as terminated.
  The watchdog thread is (re)started when it is not running. All call sites
  in this class hold `self._process_lock` while calling this method.

  Args:
    task_type: Task type of the subprocess.
    task_id: Task id of the subprocess.
    cluster_spec: Cluster spec for the subprocess; the one given at
      `__init__` is used when this is falsy.
    fn: Function to run in the subprocess. When `None`, the function and the
      arguments given at `__init__` are used and `args`/`kwargs` are ignored.
    args: Positional arguments for `fn`.
    kwargs: Keyword arguments for `fn`.

  Raises:
    SkipTest: if `dill` could not be imported.
  """
  if dill is None:
    raise unittest.SkipTest(
        'TODO(b/150264776): Resolve dependency issue in CI')

  if not cluster_spec:
    cluster_spec = self._cluster_spec

  if self._share_gpu or self._total_gpu <= 0:
    # GPUs are shared (or there is none): do not restrict visibility.
    visible_gpus = None
  else:
    # Assign GPUs in a roundrobin fashion.
    first_gpu = multi_worker_util.id_in_cluster(cluster_spec, task_type,
                                                task_id)
    stride = multi_worker_util.worker_count(cluster_spec, task_type)
    visible_gpus = list(range(first_gpu, self._total_gpu, stride))

  env = TestEnvironment(
      task_type=task_type,
      task_id=task_id,
      cluster_spec=cluster_spec,
      rpc_layer=self._rpc_layer,
      grpc_fail_fast=self._grpc_fail_fast,
      v2_enabled=self._v2_enabled,
      executing_eagerly=self._executing_eagerly,
      visible_gpus=visible_gpus,
  )
  read_end, write_end = multiprocessing.Pipe(duplex=False)
  shared = Resources(
      process_status_queue=self._process_status_queue,
      parent_to_sub_queue=self._parent_to_sub_queue,
      streaming_pipe_w=write_end,
      barrier=self._barrier,
  )

  if fn is None:
    fn = self._fn
    args = self._args
    kwargs = self._kwargs
  # Always use dill to pickle fn so that we support more callable
  # types, e.g. lambda.
  fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL)
  if self._use_dill_for_args:
    args = dill.dumps(args, dill.HIGHEST_PROTOCOL)
    kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL)

  process = _Process(
      test_env=env,
      target=_ProcFunc(),
      args=(shared, env, fn, args, kwargs, self._use_dill_for_args),
      daemon=self._daemon)
  process.start()
  task_key = (task_type, task_id)
  self._processes[task_key] = process
  self._terminated.discard(task_key)

  # For each subprocess, we dedicate a thread continuously reading lines
  # from them.
  reader_thread = threading.Thread(  # pylint: disable=unexpected-keyword-arg
      target=self._continuously_readline_from_sub,
      args=(read_end, task_type, task_id))
  reader_thread.start()
  self._reading_threads.append(reader_thread)

  watchdog = self._watchdog_thread
  if watchdog is None or not watchdog.is_alive():
    self._watchdog_thread = threading.Thread(target=self._process_watchdog)
    self._watchdog_thread.start()
| |
def start(self):
  """Starts processes, one for each task in `cluster_spec`.

  Note that this is best effort by the applicable multiprocessing library,
  and it may take up to seconds for a subprocess to be successfully started.

  If `max_run_time` was given at `__init__`, a SIGALRM handler that
  terminates all subprocesses is installed and the alarm is armed for
  `max_run_time` seconds.

  Raises:
    ValueError: if the runner has already been started, or if `join()` has
      already been called.
  """
  with self._process_lock:
    if self._processes:
      raise ValueError('MultiProcessRunner already started.')
    if self._joined:
      # The trailing space matters: the two literals are concatenated.
      raise ValueError('cannot start new processes after '
                       'MultiProcessRunner.join() is called')

    for task_type, addresses in self._cluster_spec.items():
      for task_id, _ in enumerate(addresses):
        self._start_subprocess_and_reading_thread(task_type, task_id)

  # TODO(rchao): Remove the need of using SIGALRM if possible. At this time,
  # without this the tests become very flaky.
  if self._max_run_time is not None:

    def handler(signum, frame):
      del signum, frame
      self.terminate_all()

    signal.signal(signal.SIGALRM, handler)
    signal.alarm(self._max_run_time)
| |
def start_in_process_as(self, as_task_type, as_task_id):
  """Start the processes, with the specified task run in main process.

  This is similar to `start()` except that the task with task_type
  `as_task_type` and task_id `as_task_id` is run in the main process.
  This method is particularly useful when debugging tool such as `pdb` is
  needed in some specific task. Note that since this method is blocking until
  that specific task exits, additional actions would need a thread to be
  called:

  ```python
  def fn():
    # user code to be run
    import pdb; pdb.set_trace()

  def follow_ups():
    time.sleep(5)
    mpr.start_single_process(
        task_type='evaluator',
        task_id=0)

  mpr = multi_process_runner.MultiProcessRunner(
      fn,
      multi_worker_test_base.create_cluster_spec(
          has_chief=True, num_workers=1))
  threading.Thread(target=follow_ups).start()
  mpr.start_in_process_as(as_task_type='chief', as_task_id=0)
  mpr.join()
  ```

  Note that if `return_output=True`, the logs/stdout by task
  run by the main process is not available in result.stdout.

  Args:
    as_task_type: The task type to be run in the main process.
    as_task_id: The task id to be run in the main process.

  Raises:
    ValueError: if the runner has already been started, or if `join()` has
      already been called.
  """
  with self._process_lock:
    # `self._processes` is guarded by the lock, so check it while holding it
    # (as `start()` does).
    if self._processes:
      raise ValueError('MultiProcessRunner already started.')
    if self._joined:
      # The trailing space matters: the two literals are concatenated.
      raise ValueError('cannot start new processes after '
                       'MultiProcessRunner.join() is called')
    for task_type, addresses in self._cluster_spec.items():
      for task_id, _ in enumerate(addresses):
        if not (task_type == as_task_type and task_id == as_task_id):
          self._start_subprocess_and_reading_thread(task_type, task_id)

  # The remaining task runs right here, in the calling process.
  _set_tf_config(as_task_type, as_task_id, self._cluster_spec,
                 self._rpc_layer)
  self._fn(*self._args, **self._kwargs)
| |
def start_single_process(self,
                         task_type,
                         task_id,
                         cluster_spec=None,
                         fn=None,
                         args=None,
                         kwargs=None):
  """Starts a single process.

  This starts a process in the cluster with the task type, task id, and the
  process function (`fn`). If process function is `None`, the function
  provided at `__init__` will be used. If `cluster_spec` is `None`, the
  cluster spec provided at `__init__` will be used.

  TODO(rchao): It is meant that all subprocesses will be updated with the new
  cluster spec, but this has yet to be implemented. At this time only the
  newly started subprocess picks up this updated cluster spec.

  Args:
    task_type: The task type.
    task_id: The task id.
    cluster_spec: The cluster spec to be used on the newly started
      process. If `None`, the cluster spec provided at `__init__` will be
      used.
    fn: The process function to be run on the newly started
      process. If specified, specify `args` and `kwargs` as well. If `None`,
      the function provided at `__init__` will be used.
    args: Optional positional arguments to be supplied in `fn`. Defaults to
      an empty tuple.
    kwargs: Optional keyword arguments to be supplied in `fn`. Defaults to
      an empty dict.

  Raises:
    ValueError: if `join()` has already been called.
  """
  with self._process_lock:
    if self._joined:
      # The trailing space matters: the two literals are concatenated.
      raise ValueError('cannot start new processes after '
                       'MultiProcessRunner.join() is called')
    self._start_subprocess_and_reading_thread(
        task_type,
        task_id,
        cluster_spec=cluster_spec,
        fn=fn,
        args=args or (),
        kwargs=kwargs or {})
| |
def _queue_to_list(self, queue_to_convert):
  """Drains `queue_to_convert` and returns its items as a list.

  Items are fetched without blocking until the queue reports `Empty`, so the
  queue is empty (as far as this call could see) when the list is returned.

  Args:
    queue_to_convert: The queue to drain.

  Returns:
    List of the drained items, in the order they were fetched.
  """
  drained = []
  # Calling `queue.empty()` is not reliable, so keep fetching until the queue
  # raises `Empty`.
  try:
    while True:
      drained.append(queue_to_convert.get(block=False))
  except Queue.Empty:
    pass
  return drained
| |
def _get_process_statuses(self):
  """Drains the status queue into a dict keyed by `(task_type, task_id)`.

  Returns:
    Dict mapping `(task_type, task_id)` to the status reported for that task.
  """
  # One worker may have multiple statuses. We only keep the last one: a later
  # status overwrites the earlier entry with the same key.
  return {
      (status.task_type, status.task_id): status
      for status in self._queue_to_list(self._process_status_queue)
  }
| |
def get_process_id(self, task_type, task_id):
  """Returns the subprocess id given the task type and task id.

  Args:
    task_type: The task type.
    task_id: The task id.

  Returns:
    The pid of the subprocess, or `None` if no subprocess is recorded for
    the task.
  """
  with self._process_lock:
    process = self._processes.get((task_type, task_id), None)
  if process:
    return process.pid
  return None
| |
def get_process_exit_code(self, task_type, task_id):
  """Returns the exit code of the subprocess of the given task.

  Args:
    task_type: The task type.
    task_id: The task id.

  Returns:
    The subprocess exit code; `None` if the subprocess has not exited yet.

  Raises:
    KeyError: If the corresponding subprocess is not found with `task_type`
      and `task_id`.
  """
  with self._process_lock:
    # Plain indexing on purpose: an unknown task raises KeyError.
    process = self._processes[(task_type, task_id)]
  if not process:
    return None
  return process.exitcode
| |
def process_exists(self, task_type, task_id):
  """Returns whether the subprocess of the given task is still running.

  Args:
    task_type: The task type.
    task_id: The task id.

  Returns:
    Boolean; True while the subprocess has no exit code yet, False once it
    has exited.
  """
  exit_code = self.get_process_exit_code(task_type, task_id)
  return exit_code is None
| |
def _process_watchdog(self):
  """Simulates a cluster management system.

  - If auto_restart is True, it restarts processes that exit with a non-zero
    exit code. Note that when join() times out it overrides auto_restart to
    False.
  - If dependence_on_chief is True, it terminates all processes once the chief
    exits. If auto_restart is also True, it only terminates all processes if
    the chief exit with a zero exit code, otherwise it restarts the chief.

  This runs in self._watchdog_thread. It polls once per second and returns
  (ending the thread) either after the cluster has been terminated because of
  the chief, or once every process has an exit code.
  """
  while True:
    time.sleep(1)
    # All process bookkeeping below happens while holding the lock.
    with self._process_lock:
      chief = self._processes.get(('chief', 0), None)
      # Terminate the cluster when _dependence_on_chief is True if either:
      # - chief has exited with zero exit code.
      # - chief has exited with non-zero exit code and self._auto_restart is
      #   False.
      if chief and self._dependence_on_chief and chief.exitcode is not None:
        if chief.exitcode == 0 or (not self._auto_restart):
          for p in self._processes.values():
            # Give other processes a chance to exit on their own.
            p.join(timeout=3)
          # Whatever is still running gets the default termination signal.
          self._terminate_all()
          for p in self._processes.values():
            p.join()
          return

      # Auto restart failed processes if self._auto_restart is True.
      if self._auto_restart:
        has_failure = False
        for (task_type, task_id), p in self._processes.items():
          if p.exitcode is not None and p.exitcode != 0:
            has_failure = True
            logging.info('Restarting failed %s-%d', task_type, task_id)
            # Replaces the entry of an existing key, so the dict being
            # iterated does not change size.
            self._start_subprocess_and_reading_thread(task_type, task_id)
        if has_failure:
          # Something was restarted; skip the "all exited" check this round.
          continue

      # Exit the thread if all processes have exited at this point.
      if all(p.exitcode is not None for p in self._processes.values()):
        return
| |
def _reraise_if_subprocess_error(self, process_statuses):
  """Re-raises the first failure found among the subprocess statuses.

  Before re-raising, the result of the run is attached to the exception
  object as its `mpr_result` attribute.

  Args:
    process_statuses: Dict of `_ProcessStatusInfo` values.
  """
  for status in process_statuses.values():
    assert isinstance(status, _ProcessStatusInfo)
    if status.is_successful:
      continue
    failure = status.exc_info
    # failure[1] is the exception instance that gets re-raised below.
    failure[1].mpr_result = self._get_mpr_result(process_statuses)
    six.reraise(*failure)
| |
def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
  """Joins all the processes with timeout.

  If any of the subprocesses does not exit approximately after `timeout`
  seconds has passed after `join` call, this raises a
  `SubprocessTimeoutError`.

  Note: At timeout, it uses SIGTERM to terminate the subprocesses, in order to
  log the stack traces of the subprocesses when they exit. However, this
  results in timeout when the test runs with tsan (thread sanitizer); if tsan
  is being run on the test targets that rely on timeout to assert information,
  `MultiProcessRunner.terminate_all()` must be called after `join()`, before
  the test exits, so the subprocesses are terminated with SIGKILL, and data
  race is removed.

  Args:
    timeout: optional integer or `None`. If provided as an integer, and not
      all processes report status within roughly `timeout` seconds, a
      `SubprocessTimeoutError` exception will be raised. If `None`, `join`
      never times out.

  Returns:
    A `MultiProcessRunnerResult` object, which has two attributes,
    `return_value` and `stdout`. `return_value` always contains a list of
    return values from the subprocesses, although the order is not meaningful.
    If `return_output` argument is True at `__init__`, `stdout` is available
    that contains a list of all messages from subprocesses' stdout and stderr.

  Raises:
    ValueError: if `timeout` is neither an integer nor `None`, or if `join`
      has already been called on this runner.
    SubprocessTimeoutError: if not all processes report status approximately
      within `timeout` seconds. When this is raised, a
      `MultiProcessRunnerResult` object can be retrieved by
      `SubprocessTimeoutError`'s mpr_result attribute, which has the same
      structure as above 'Returns' section describes.
    UnexpectedSubprocessExitError: If any of the subprocesses did not exit
      properly (it has a positive exit code and was not terminated by this
      runner). When this is raised, a `MultiProcessRunnerResult` object can be
      retrieved by `UnexpectedSubprocessExitError`'s mpr_result attribute,
      which has the same structure as above 'Returns' section describes. If
      `max_run_time` is not `None`, it is expected that some subprocesses may
      be force-killed when `max_run_time` is up, and this is raised in those
      cases.
    Exception: if there is an Exception propagated from any subprocess. When
      this is raised, a `MultiProcessRunnerResult` object can be retrieved by
      the re-raised exception's mpr_result attribute, which has the same
      structure as above 'Returns' section describes.
  """
  if timeout and not isinstance(timeout, int):
    raise ValueError('`timeout` must be an integer or `None`.')
  with self._process_lock:
    if self._joined:
      raise ValueError("MultiProcessRunner can't be joined twice.")
    self._joined = True

  # The watchdog thread ends once the cluster is done, so joining it waits
  # for the subprocesses.
  # NOTE(review): `self._watchdog_thread` stays `None` until a subprocess has
  # been started, so calling `join()` before any start fails right here with
  # an AttributeError.
  self._watchdog_thread.join(timeout)
  if self._watchdog_thread.is_alive():
    # Timeout. Force termination to dump worker processes stack trace.
    with self._process_lock:
      # Keep the watchdog from restarting what is about to be terminated.
      self._auto_restart = False
    logging.error('Timeout when joining for child processes. Terminating...')
    self.terminate_all(sig=signal.SIGTERM)
    # Wait for the processes to terminate by themselves first, so they have a
    # chance to dump stacktraces. After _FORCE_KILL_WAIT_SEC, we SIGKILL them.
    self._watchdog_thread.join(_FORCE_KILL_WAIT_SEC)
    if self._watchdog_thread.is_alive():
      logging.error('Timeout when waiting for child processes to '
                    'print stacktrace. Sending SIGKILL...')
      self.terminate_all()
      self._watchdog_thread.join()
    process_statuses = self._get_process_statuses()
    # A failure reported by a subprocess takes precedence over the timeout.
    self._reraise_if_subprocess_error(process_statuses)
    raise SubprocessTimeoutError(
        'One or more subprocesses timed out, where timeout was set to {}s. '
        'Please change the `timeout` argument for '
        '`MultiProcessRunner.join()` or `multi_process_runner.run()` '
        'if it should be adjusted.'.format(timeout),
        self._get_mpr_result(process_statuses))

  for (task_type, task_id), p in self._processes.items():
    logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)

  process_statuses = self._get_process_statuses()
  self._reraise_if_subprocess_error(process_statuses)

  # Checking all the processes that are expected to exit properly.
  for (task_type, task_id), p in self._processes.items():
    # Successfully exiting process has exit code 0. We ignore processes that
    # are terminated.
    assert p.exitcode is not None
    if (p.exitcode > 0 and (task_type, task_id) not in self._terminated):
      raise UnexpectedSubprocessExitError(
          'Subprocess %s-%d exited with exit code %s. See logs for details.'
          % (task_type, task_id, p.exitcode),
          self._get_mpr_result(process_statuses))

  logging.info('Joining log reading threads.')
  for thread in self._reading_threads:
    thread.join()
  logging.info('Joined log reading threads.')

  # Clear the alarm. This is only reached on the successful path; the
  # exceptions raised above leave a pending alarm untouched.
  signal.alarm(0)

  return self._get_mpr_result(process_statuses)
| |
def _get_mpr_result(self, process_statuses):
  """Builds the `MultiProcessRunnerResult` of the run.

  Args:
    process_statuses: Dict whose values expose a `return_value` attribute.

  Returns:
    A `MultiProcessRunnerResult` whose `stdout` is everything currently in
    the streaming queue (which is drained by this call) and whose
    `return_value` lists the return values that are not `None`.
  """
  collected_output = self._queue_to_list(self._streaming_queue)
  values = [
      status.return_value
      for status in process_statuses.values()
      if status.return_value is not None
  ]
  return MultiProcessRunnerResult(
      return_value=values, stdout=collected_output)
| |
def terminate(self, task_type, task_id):
  """Terminates the process with `task_type` and `task_id`.

  The termination request is sent through the parent-to-subprocess queue and
  this call blocks until the process has been joined.

  If auto_restart=True, the terminated task will be restarted unless the chief
  has already exited with zero exit code.

  Args:
    task_type: the task type.
    task_id: the task id.

  Raises:
    ValueError: if there is no process for the given task.
  """
  task_key = (task_type, task_id)
  with self._process_lock:
    process = self._processes.get(task_key, None)
    if process is None:
      raise ValueError('{}-{} does not exist'.format(task_type, task_id))
    self._terminated.add(task_key)
    # TODO(crccw): change to use Process.terminate() as well.
    request = 'terminate {} {}'.format(task_type, task_id)
    self._parent_to_sub_queue.put(request)
    process.join()
| |
def _terminate_all(self, sig=None):
  """Sends a termination signal to every subprocess that is still running.

  The caller is required to hold self._process_lock.

  Processes that already have an exit code are skipped. Every process that
  was signalled is recorded in `self._terminated`.

  Args:
    sig: the signal used to terminate the process. The default is SIGKILL.
  """
  if not sig:
    # Use SIGKILL as default. In systems where that's unavailable such as
    # windows, use SIGTERM.
    sig = getattr(signal, 'SIGKILL', signal.SIGTERM)

  for task_key, process in self._processes.items():
    task_type, task_id = task_key
    if process.exitcode is not None:
      logging.info('%s-%d has already exited. Not terminating.', task_type,
                   task_id)
      continue
    try:
      os.kill(process.pid, sig)
    except ProcessLookupError:
      logging.info('Attempting to kill %s-%d but it does not exist.',
                   task_type, task_id)
    else:
      self._terminated.add(task_key)
      logging.info('%s-%d terminated with signal %r.', task_type, task_id,
                   sig)
| |
def terminate_all(self, sig=None):
  """Terminates all subprocesses.

  Takes `self._process_lock` and delegates to `_terminate_all`.

  Args:
    sig: the signal used to terminate the processes; `None` selects the
      default of `_terminate_all`.
  """
  lock = self._process_lock
  with lock:
    self._terminate_all(sig=sig)
| |
| |
class _Process(multi_process_lib.Process):
  """A `multi_process_lib.Process` that exports its test environment first.

  The instance's `run` is replaced by a wrapper that sets the environment
  variables described by `test_env` and then calls the original `run`.
  """

  # TODO(crccw): consider moving other logics in _ProcFunc to _Process.

  def __init__(self, test_env, **kwargs):
    """Creates the process.

    Args:
      test_env: Environment description of the task; its `grpc_fail_fast`,
        `visible_gpus`, `task_type`, `task_id`, `cluster_spec` and
        `rpc_layer` fields are read when the process runs.
      **kwargs: Forwarded unchanged to the base class constructor.
    """
    super(_Process, self).__init__(**kwargs)
    self._test_env = test_env
    # Remember the regular `run` and shadow it on the instance, so that the
    # environment is prepared before the regular `run` executes.
    self._actual_run = self.run
    self.run = self._run_with_setenv

  def _run_with_setenv(self):
    """Sets the environment variables, then runs the original `run`."""
    # We need to set environment variables before doing anything because
    # setenv() is not thread-safe.
    env = self._test_env
    fail_fast = env.grpc_fail_fast
    if fail_fast is not None:
      os.environ['GRPC_FAIL_FAST'] = str(fail_fast)
    # NOTE(review): an empty `visible_gpus` list is treated like `None` here,
    # i.e. CUDA_VISIBLE_DEVICES is left untouched - confirm this is intended.
    if env.visible_gpus:
      os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
          map(str, env.visible_gpus))
    _set_tf_config(env.task_type, env.task_id, env.cluster_spec,
                   env.rpc_layer)
    return self._actual_run()
| |
| |
| class _ProcFunc(object): |
| """Represents a callable to run in a subprocess.""" |
| |
@contextlib.contextmanager
def _runtime_mode(self, executing_eagerly):
  """Context manager entering eager mode or graph mode.

  Args:
    executing_eagerly: If truthy, the body runs under `context.eager_mode()`;
      otherwise under `context.graph_mode()`.

  Yields:
    Nothing; control is handed to the body of the `with` statement.
  """
  enter_mode = context.eager_mode if executing_eagerly else context.graph_mode
  with enter_mode():
    yield
| |
def _message_checking_func(self, task_type, task_id):
  """Polls the parent-to-subprocess queue and exits when told to terminate.

  The loop ends once the termination message addressed to this task is
  received. A successful status is then reported to the parent and the
  process exits immediately with exit code 1.

  Args:
    task_type: Task type of this subprocess.
    task_id: Task id of this subprocess.

  Raises:
    ValueError: if a message that is not a termination message is received.
  """
  # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess.
  inbox = self._resources.parent_to_sub_queue
  own_request = 'terminate {} {}'.format(task_type, task_id)
  while True:
    try:
      message = inbox.get(block=False)
    except Queue.Empty:
      time.sleep(0.1)
      continue

    # Currently the only possible message is termination.
    if not message.startswith('terminate'):
      raise ValueError('Unrecognized message: {}'.format(message))

    if message == own_request:
      break

    # If the message is not targeting this process, put it back to the
    # queue.
    inbox.put(message)
    time.sleep(1)

  self._resources.process_status_queue.put(
      _ProcessStatusInfo(
          task_type=task_type,
          task_id=task_id,
          is_successful=True,
          exc_info=None,
          return_value=None))
  # `os._exit(1)` is used to more reliably terminate a subprocess.
  os._exit(1)  # pylint: disable=protected-access
| |
def _close_streaming(self):
  """Flushes and closes stdout and stderr, then closes the streaming pipe.

  We need to explicitly close them since Tensorflow may take a while to exit,
  so that the reading threads in the main process can exit more quickly.
  """
  # Flush both streams before closing either of them.
  for stream in (sys.stdout, sys.stderr):
    stream.flush()
  for stream in (sys.stdout, sys.stderr):
    stream.close()
  self._resources.streaming_pipe_w.close()
| |
| def __call__(self, resources, test_env, fn, args, kwargs, use_dill_for_args): |
| """The wrapper function that actually gets run in child process(es).""" |
| |
| global _barrier |
| |
| self._resources = resources |
| _barrier = self._resources.barrier |
| fn = dill.loads(fn) |
| if use_dill_for_args: |
| args = dill.loads(args) |
| kwargs = dill.loads(kwargs) |
| |
| if faulthandler is not None: |
| faulthandler.enable() |
| faulthandler.register(signal.SIGTERM, chain=True) |
| |
| # All logging should go to stderr to be streamed to the main process. |
| logging.set_stderrthreshold(logging.DEBUG) |
| |
| # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so |
| # print() and logging.*() write directly to `streaming_pipe_w`. |
| # Unfortunately since we cannot prepend task_type and task_id information to |
| # the streamed logs we will need a thread per subprocess to distinguish |
| # where the piece of message is from. |
| os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno()) |
| os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno()) |
| |
| pid = os.getpid() |
| logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid, |
| test_env.task_type, test_env.task_id) |
| logging.info('TF_CONFIG: %r', os.environ['TF_CONFIG']) |
| |
| # The thread will be dedicated to checking messages from the parent process. |
| threading.Thread( # pylint: disable=unexpected-keyword-arg |
| target=self._message_checking_func, |
| args=(test_env.task_type, test_env.task_id), |
| daemon=True).start() |
| |
| if test_env.v2_enabled: |
| v2_compat.enable_v2_behavior() |
| |
| with self._runtime_mode(test_env.executing_eagerly): |
| info = _run_contained(test_env.task_type, test_env.task_id, fn, args, |
| kwargs) |
| self._resources.process_status_queue.put(info) |
| |
| # Re-raise the exception in addition to reporting it to the parent |
| # process, so that even if `--test_timeout` flag is set and the |
| # error doesn't make it to be shown in parent process before bazel's |
| # timeout, the log would still show what happens in this subprocess, |
| # instead of silently suppressing the error due to early bazel |
| # timeout. Raising an error in the subprocess produces stack trace in |
| # the log, but the program continues running. |
| if not info.is_successful: |
| six.reraise(*info.exc_info) |
| |
| self._close_streaming() |
| |
| # Exit with code 0 as it's considered successful exit at this point. |
| sys.exit(0) |
| |
| |
# Active MultiProcessPoolRunners. We need to shut them down when the program
# exits, and this is done by setting the `tearDownModule` of the module
# containing `__main__`. Note that this is set in both the parent process and
# the subprocesses. Runners are held weakly, so being registered here does not
# keep them alive.
_active_pool_runners = weakref.WeakSet()
| |
| |
def _shutdown_all_pool_runners():
  """Calls `shutdown()` on every pool runner in `_active_pool_runners`."""
  for active_runner in _active_pool_runners:
    active_runner.shutdown()
| |
| |
def is_oss():
  """Returns whether the test is run under OSS.

  The check is whether the first entry of `sys.argv` contains 'bazel'. An
  empty `sys.argv` yields False.
  """
  if not sys.argv:
    return False
  return 'bazel' in sys.argv[0]
| |
| |
class MultiProcessPoolRunner(object):
  """A utility class that keeps a pool of processes simulating a cluster.

  It is similar to MultiProcessRunner, but the processes are reused across
  `run` calls to avoid the expensive initialization cost of Tensorflow.
  """

  def __init__(self, cluster_spec, initializer=None, share_gpu=True):
    """Creates a multi-process pool runner.

    No process is started here; the pool is started lazily by the first call
    to `run`.

    Args:
      cluster_spec: Dict for cluster spec. The following is an example of
        cluster with three workers.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"]}
      initializer: a callable to be called at the startup of worker processes.
      share_gpu: Whether to share GPUs among workers. If False, each worker is
        assigned different GPUs in a roundrobin fashion.
    """
    _active_pool_runners.add(self)
    self._cluster_spec = cluster_spec
    self._initializer = initializer
    self._share_gpu = share_gpu
    # Maps (task_type, task_id) to the parent end of the pipe to that worker.
    self._conn = {}
    # The underlying MultiProcessRunner; None until the pool is started.
    self._runner = None

  def __del__(self):
    self.shutdown()

  def shutdown(self):
    """Shuts down the worker pool.

    Closes the connections to all workers and then joins the underlying
    runner, if any. Exceptions raised while joining are logged and ignored.
    """
    for worker_conn in self._conn.values():
      worker_conn.close()
    self._conn = {}

    if self._runner is None:
      return
    try:
      self._runner.join()
    except Exception as join_error:  # pylint: disable=broad-except
      logging.error(
          'Ignoring exception when shutting down MultiProcessPoolRunner: %s',
          join_error)
    self._runner = None

  def _start(self):
    """Starts the worker pool.

    Raises:
      unittest.SkipTest: if `dill` could not be imported.
    """
    if dill is None:
      raise unittest.SkipTest(
          'TODO(b/150264776): Resolve dependency issue in CI')

    # Each process needs its own arguments, so the runner gets a no-op fn and
    # the processes are launched one by one with start_single_process instead.
    self._runner = MultiProcessRunner(
        fn=lambda: None,
        cluster_spec=self._cluster_spec,
        use_dill_for_args=False,
        share_gpu=self._share_gpu)
    serialized_initializer = (
        dill.dumps(self._initializer, dill.HIGHEST_PROTOCOL)
        if self._initializer else None)
    for task_type, addresses in self._cluster_spec.items():
      for task_id, _ in enumerate(addresses):
        parent_end, child_end = multiprocessing.Pipe(duplex=True)
        self._conn[(task_type, task_id)] = parent_end
        self._runner.start_single_process(
            task_type,
            task_id,
            fn=_pool_runner_worker,
            args=(task_type, task_id, serialized_initializer, child_end))

  def run(self, fn, args=None, kwargs=None):
    """Runs `fn` with `args` and `kwargs` on all jobs.

    The pool is started first if it is not running yet. If `fn` failed on a
    worker, the first failure found in the collected results is re-raised.

    Args:
      fn: The function to be run.
      args: Optional positional arguments to be supplied in `fn`.
      kwargs: Optional keyword arguments to be supplied in `fn`.

    Returns:
      A list with the return values of the workers, leaving out those that
      are `None`.

    Raises:
      NotInitializedError: if `multi_process_runner.test_main()` is not
        called.
      RuntimeError: if the connection to a worker is closed before its result
        is received.
    """
    _check_initialization()
    # TODO(b/150264776): skip in OSS until it's implemented.
    multi_process_lib.Process()
    if self._runner is None:
      self._start()

    request = (dill.dumps(fn, dill.HIGHEST_PROTOCOL), args or [], kwargs or {})
    for worker_conn in self._conn.values():
      worker_conn.send(request)

    statuses = []
    for (task_type, task_id), worker_conn in self._conn.items():
      logging.info('Waiting for the result from %s-%d', task_type, task_id)
      try:
        status = worker_conn.recv()
      except EOFError:
        # Exceptions in fn are not supposed to end up here, so this usually
        # means there is a bug in the runner.
        self.shutdown()
        raise RuntimeError('Unexpected EOF. Worker process may have died. '
                           'Please report a bug')
      statuses.append(status)

    return_values = []
    for status in statuses:
      assert isinstance(status, _ProcessStatusInfo)
      if not status.is_successful:
        six.reraise(*status.exc_info)
      value = status.return_value
      if value is not None:
        return_values.append(value)

    return return_values
| |
| |
def _pool_runner_worker(task_type, task_id, initializer, conn):
  """Function that runs on the workers in a pool.

  After running the optional initializer it serves requests from `conn` until
  `conn` is closed: each request is a callable with its arguments, and the
  `_ProcessStatusInfo` describing its result, or the exception it raised, is
  sent back through `conn`.

  Args:
    task_type: the task type.
    task_id: the task index.
    initializer: a dill-serialized callable to execute during startup, or a
      falsy value to skip initialization.
    conn: a multiprocessing.Connection object to listen for tasks and send
      results.
  """
  if initializer:
    run_initializer = dill.loads(initializer)
    run_initializer()

  while True:
    try:
      serialized_fn, args, kwargs = conn.recv()
    except EOFError:
      # The connection is closed, there is nothing left to serve.
      return
    status = _run_contained(task_type, task_id, dill.loads(serialized_fn),
                            args, kwargs)
    # Flush both output streams before handing the result back.
    sys.stdout.flush()
    sys.stderr.flush()
    conn.send(status)
| |
| |
def _run_contained(task_type, task_id, fn, args, kwargs):
  """Runs `fn` with `args` and `kwargs` and wraps the outcome.

  Exceptions deriving from `Exception` are captured rather than propagated.

  Args:
    task_type: the task type.
    task_id: the task index.
    fn: the function to be run.
    args: positional arguments to be supplied in `fn`.
    kwargs: keyword arguments to be supplied in `fn`.

  Returns:
    a _ProcessStatusInfo holding either the return value of `fn`, or the
    `sys.exc_info()` of the exception it raised.
  """
  try:
    result = fn(*args, **kwargs)
  # If `fn` ends up exiting with `sys.exit()`, the `SystemExit` is not
  # handled here.
  except Exception:  # pylint: disable=broad-except
    return _ProcessStatusInfo(
        task_type=task_type,
        task_id=task_id,
        is_successful=False,
        exc_info=sys.exc_info(),
        return_value=None)

  return _ProcessStatusInfo(
      task_type=task_type,
      task_id=task_id,
      is_successful=True,
      exc_info=None,
      return_value=result)
| |
| |
@tf_export(
    '__internal__.distribute.multi_process_runner.SubprocessTimeoutError',
    v1=[])
class SubprocessTimeoutError(RuntimeError):
  """An error that indicates there is at least one subprocess timing out.

  The namedtuple object representing the multi-process run result is carried
  by the `mpr_result` attribute of the raised
  `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`.
  See `tf.__internal__.distribute.multi_process_runner.run` for more
  information.
  """

  def __init__(self, msg, mpr_result):
    """Creates the error.

    Args:
      msg: the error message.
      mpr_result: the multi-process run result, stored as `self.mpr_result`.
    """
    self.mpr_result = mpr_result
    super(SubprocessTimeoutError, self).__init__(msg)
| |
| |
@tf_export(
    '__internal__.distribute.multi_process_runner'
    '.UnexpectedSubprocessExitError',
    v1=[])
class UnexpectedSubprocessExitError(RuntimeError):
  """An error indicating there is at least one subprocess with unexpected exit.

  The namedtuple object representing the multi-process run result is carried
  by the `mpr_result` attribute of the raised
  `tf.__internal__.distribute.multi_process_runner
  .UnexpectedSubprocessExitError`.
  See `tf.__internal__.distribute.multi_process_runner.run` for more
  information.
  """

  def __init__(self, msg, mpr_result):
    """Creates the error.

    Args:
      msg: the error message.
      mpr_result: the multi-process run result, stored as `self.mpr_result`.
    """
    self.mpr_result = mpr_result
    super(UnexpectedSubprocessExitError, self).__init__(msg)
| |
| |
@tf_export('__internal__.distribute.multi_process_runner.NotInitializedError',
           v1=[])
class NotInitializedError(RuntimeError):
  """An error indicating `multi_process_runner.run` is used without init.

  To fix it, call
  `tf.__internal__.distribute.multi_process_runner.test_main()` within the
  `if __name__ == '__main__':` block of the test module, which properly
  initializes `multi_process_runner.run`.
  """
| |
| |
def _check_initialization():
  """Raises `NotInitializedError` unless `multi_process_lib` is initialized."""
  if multi_process_lib.initialized():
    return
  message = (
      '`multi_process_runner` is not initialized. '
      'Please call `tf.__internal__.distribute.multi_process_runner.'
      'test_main()` within `if __name__ == \'__main__\':` block '
      'in your python module to properly initialize '
      '`multi_process_runner`.')
  raise NotInitializedError(message)
| |
| |
| def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None): |
| """Set TF_CONFIG environment variable.""" |
| tf_config_dict = { |
| 'cluster': cluster_spec, |
| 'task': { |
| 'type': task_type, |
| 'index': task_id, |
| }, |
| } |
| if rpc_layer is not None: |
| tf_config_dict['rpc_layer'] = rpc_layer |
| os.environ['TF_CONFIG'] = json.dumps(tf_config_dict) |
| |
| |
@tf_export('__internal__.distribute.multi_process_runner.run', v1=[])
def run(fn,
        cluster_spec,
        rpc_layer=None,
        max_run_time=None,
        return_output=False,
        timeout=_DEFAULT_TIMEOUT_SEC,
        args=None,
        kwargs=None):
  """Run `fn` in multiple processes according to `cluster_spec`.

  `tf.__internal__.distribute.multi_process_runner.run` takes a callable `fn`
  and launches multiple processes, each of which runs `fn`. These processes
  are called "subprocesses" or "child processes". Every subprocess gets its
  own `TF_CONFIG` environment variable, derived from `cluster_spec` and its
  task type. The output of the subprocesses is streamed to the main process
  and shows up in its logs with a [type-id] prefix.

  The call blocks until all subprocesses have successfully exited, and then
  returns a namedtuple object that represents the run result. Its
  `return_value` attribute is a list with the return values of `fn` from the
  subprocesses that successfully returned from `fn`; the order of that list
  is not meaningful. When the optional arg `return_output` (default to False)
  is True, the namedtuple additionally has a `stdout` attribute, a list
  containing the stdout of the subprocesses. If `fn` raises an error in any
  subprocess, that error is reraised from
  `tf.__internal__.distribute.multi_process_runner.run`, and the namedtuple
  object is available through the exception's `mpr_result` attribute.

  The utility simulates running TensorFlow programs across multiple task
  types, where each task type may contain more than one task (except for
  "chief", where more than one task is prohibited). Its main application is
  test coverage of multi-worker training: code written for multi-worker
  training can be realistically covered in unit tests.

  Any test module that uses
  `tf.__internal__.distribute.multi_process_runner.run()` must call
  `tf.__internal__.distribute.multi_process_runner.test_main()` instead of
  regular `test.main()` inside `if __name__ == '__main__':` block for proper
  initialization.

  Args:
    fn: Function to be run on child processes. This will be run on processes
      for all task types.
    cluster_spec: Dict for cluster spec. The utility function
      `tf.__internal__.distribute.multi_process_runner.create_cluster_spec`
      can be conveniently used to create such dict. The following is an
      example of cluster with three workers and two ps's.
      {"worker": ["worker0.example.com:2222",
                  "worker1.example.com:2222",
                  "worker2.example.com:2222"],
       "ps": ["ps0.example.com:2222",
              "ps1.example.com:2222"]}
    rpc_layer: RPC layer to use, for example 'grpc'. Defaults to `None`.
    max_run_time: `None` or integer. If not `None`, child processes are forced
      to exit at approximately this many seconds after this utility is called.
      We achieve this through `signal.alarm()` api. Note that this is best
      effort at Python level since Python signal handler does not get executed
      when it runs lower level C/C++ code. So it can be delayed for
      arbitrarily long time. If any of the child process is still running when
      `max_run_time` is up, they will be force-terminated and an
      `tf.__internal__.distribute.multi_process_runner
      .UnexpectedSubprocessExitError`
      may be raised. If `None`, child processes are not forced to exit.
    return_output: If True, the output/error from the subprocesses should be
      collected to be attached to the resulting namedtuple returned from this
      utility. The list of output can be retrieved via `stdout` attribute.
      Defaults to False.
    timeout: optional integer or `None`. If provided as an integer, and not
      all processes report status within roughly `timeout` seconds, a
      `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`
      exception will be raised. If `None`,
      `tf.__internal__.distribute.multi_process_runner.run` never times out.
      Defaults to the constant `_DEFAULT_TIMEOUT_SEC` defined in
      `multi_process_runner` module.
    args: Positional arguments to be sent to `fn` run on subprocesses.
    kwargs: Keyword arguments to be sent to `fn` run on subprocesses.

  Returns:
    A namedtuple object, which has two attributes,
    `return_value` and `stdout`. `return_value` always contains a list of
    return values from the subprocesses, although the order is not meaningful.
    If `return_output` argument is True, `stdout` is available that contains a
    list of all messages from subprocesses' stdout and stderr, and the order
    is mostly chronological.

  Raises:
    RuntimeError: if
      `tf.__internal__.distribute.multi_process_runner.test_main()` is
      not called in test's `if __name__ == '__main__':` block.
    ValueError: if there are more than one chief in the `cluster_spec`.
    tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError: if
      not all processes report status approximately
      within `timeout` seconds. When this is raised, a
      namedtuple object can be retrieved by the exception's
      `mpr_result` attribute, which has the same
      structure as above 'Returns' section describes.
    tf.__internal__.distribute.multi_process_runner
    .UnexpectedSubprocessExitError:
      If any of the subprocesses did not exit
      properly (for example, they exit on SIGTERM or SIGKILL signal). When
      this is raised, a namedtuple object can be retrieved by the exception's
      `mpr_result` attribute, which has the
      same structure as above 'Returns' section describes. If `max_run_time`
      is not `None`, it is expected that some subprocesses may be
      force-killed when `max_run_time` is up, and this is raised in those
      cases.
    Exception: if there is an Exception propagated from any subprocess. When
      this is raised, a namedtuple object can be retrieved by the exception's
      `mpr_result` attribute, which has the
      same structure as above 'Returns' section describes.

  Examples:

  ```python
  class SimpleMultiProcessTest(tf.test.TestCase):

    def test_simple_printing_and_return(self):

      def fn():
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

        # This will print "[chief-0]: Task type: chief , task id: 0"
        # for chief, for example.
        logging.info('Task type: %s, task id: %d',
                     resolver.task_type, resolver.task_id)

        return resolver.task_type

      result = tf.__internal__.distribute.multi_process_runner.run(
          fn=fn,
          cluster_spec=(
              tf.__internal__
              .distribute.multi_process_runner.create_cluster_spec(
                  has_chief=True, num_workers=2)))
      assert sorted(result.return_value) == ['chief', 'worker', 'worker']

    def test_error_from_fn(self):

      def fn():
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
        raise ValueError('Task type {}, task id {} errors out'.format(
            resolver.task_type, resolver.task_id))

      with self.assertRaisesRegex(ValueError,
                                  'Task type worker, task id 0 errors out'):
        cluster_spec = (
            tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
                num_workers=1))
        tf.__internal__.distribute.multi_process_runner.run(
            fn=fn, cluster_spec=cluster_spec)


  if __name__ == '__main__':
    tf.__internal__.distribute.multi_process_runner.test_main()
  ```
  """
  runner_options = {
      'max_run_time': max_run_time,
      'return_output': return_output,
      'args': args,
      'kwargs': kwargs,
  }
  mpr = MultiProcessRunner(fn, cluster_spec, rpc_layer, **runner_options)
  mpr.start()
  return mpr.join(timeout)
| |
| |
# The barrier shared by the tasks of a run. `_ProcFunc.__call__` assigns it in
# each subprocess from the resources handed over by the parent; it stays
# `None` otherwise, which `get_barrier()` reports as an error.
_barrier = None
| |
| |
@tf_export('__internal__.distribute.multi_process_runner.get_barrier', v1=[])
def get_barrier():
  """Returns a `multiprocessing.Barrier` for `multi_process_runner.run`.

  The `multiprocessing.Barrier` object returned by
  `tf.__internal__.distribute.multi_process_runner.get_barrier()` is meant to
  be used within the `fn` passed to
  `tf.__internal__.distribute.multi_process_runner`: a task calling
  `barrier.wait()` is held until all other tasks have reached their
  `barrier.wait()` call as well, after which every task proceeds individually.

  Note that all tasks (subprocesses) have to reach `barrier.wait()` call to
  proceed. Currently it is not supported to block on only a subset of tasks
  in the cluster.

  Example:
  ```python

  def fn():
    some_work_to_be_done_by_all_tasks()

    tf.__internal__.distribute.multi_process_runner.get_barrier().wait()

    # The barrier guarantees that at this point, all tasks have finished
    # `some_work_to_be_done_by_all_tasks()`
    some_other_work_to_be_done_by_all_tasks()

  result = tf.__internal__.distribute.multi_process_runner.run(
      fn=fn,
      cluster_spec=(
          tf.__internal__
          .distribute.multi_process_runner.create_cluster_spec(
              num_workers=2)))
  ```


  Returns:
    A `multiprocessing.Barrier` for `multi_process_runner.run`.

  Raises:
    ValueError: if no barrier has been set in the current process, which is
      likely because this is called in the main process.
  """
  if _barrier is not None:
    return _barrier
  raise ValueError(
      'barrier is not defined. It is likely because you are calling '
      'get_barrier() in the main process. get_barrier() can only be called '
      'in the subprocesses.'
  )
| |
| |
# The multiprocessing manager returned by `manager()`. It is created on the
# first `manager()` call and reused afterwards.
_manager = None
# Held by `manager()` while it checks for and creates `_manager`.
_manager_lock = threading.Lock()
| |
| |
def manager():
  """Returns the multiprocessing manager object for concurrency tools.

  A manager controls a server process holding python objects that can be
  shared across processes, which makes it handy for communication between the
  parent and the subprocesses:

  ```python
  manager = multi_process_runner.manager()
  some_event_happening_in_subprocess = manager.Event()
  mpr = multi_process_runner.MultiProcessRunner(fn, cluster_spec,
      args=(some_event_happening_in_subprocess,))
  mpr.start()
  some_event_happening_in_subprocess.wait()
  # Do something that should only happen after some event happens in
  # subprocess.
  ```

  The manager is created on the first call and the same object is returned by
  later calls. Users of multi_process_runner should not create additional
  `multiprocessing.Manager()` objects; doing so can result in segfault in
  some cases.

  This method should only be called after multi_process_runner.test_main() is
  called.

  Raises:
    NotInitializedError: if `multi_process_runner` is not initialized.
  """
  global _manager
  _check_initialization()
  with _manager_lock:
    shared_manager = _manager
    if shared_manager is None:
      shared_manager = _manager = multiprocessing.Manager()
    return shared_manager
| |
| |
@tf_export('__internal__.distribute.multi_process_runner.test_main', v1=[])
def test_main():
  """Main function to be called within `__main__` of a test file.

  Any test module that uses
  `tf.__internal__.distribute.multi_process_runner.run()` must call this
  instead of regular `test.main()` inside `if __name__ == '__main__':` block,
  or an error will be raised when
  `tf.__internal__.distribute.multi_process_runner.run()` is used. This method
  takes care of needed initialization for launching multiple subprocesses.

  Example:
  ```python
  class MyTestClass(tf.test.TestCase):
    def testSomething(self):
      # Testing code making use of
      # `tf.__internal__.distribute.multi_process_runner.run()`.

  if __name__ == '__main__':
    tf.__internal__.distribute.multi_process_runner.test_main()
  ```
  """
  # A tearDownModule() that shuts down all pool runners is injected into the
  # `__main__` module, since active pool runners block the program from
  # exiting. This is necessary for global pool runners. atexit was tried in the
  # past and does not work in some deployment.
  main_module = sys.modules['__main__']
  previous_tear_down = getattr(main_module, 'tearDownModule', None)

  def tear_down_module():
    _shutdown_all_pool_runners()
    # Still run the tearDownModule the test module defined itself, if any.
    if previous_tear_down is not None:
      previous_tear_down()

  main_module.tearDownModule = tear_down_module
  multi_process_lib.test_main()