| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for `multi_process_runner`.""" |
| |
| import ctypes |
| import json |
| import os |
| import sys |
| import threading |
| import time |
| import unittest |
| |
| from absl import logging |
| from absl.testing import parameterized |
| |
| from tensorflow.python.distribute import combinations |
| from tensorflow.python.distribute import multi_process_runner |
| from tensorflow.python.distribute import multi_worker_test_base |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import test |
| |
| |
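# When available, `dill` is used to register custom reducers (see the
# `@_REGISTER_DECORATOR` functions at the bottom of this file) so that the
# test case classes can be pickled for subprocesses.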
| try: |
| import dill # pylint:disable=g-import-not-at-top |
| |
| _REGISTER_DECORATOR = dill.register |
| except ImportError: |
  # `dill.register` is a decorator factory, so the no-op fallback must also
  # return a decorator.
  _REGISTER_DECORATOR = lambda *_: lambda fn: fn
| |
| |
| def fn_that_adds_task_type_in_return_data(): |
| return multi_worker_test_base.get_task_type() |
| |
| |
| def fn_that_errors(): |
| raise ValueError('This is an error.') |
| |
| |
| def fn_that_does_nothing(): |
| pass |
| |
| |
| def fn_that_adds_simple_return_data(): |
| return 'dummy_data' |
| |
| |
| def fn_that_returns_args_and_kwargs(*args, **kwargs): |
| return list(args) + list(kwargs.items()) |
| |
| |
| def fn_with_barrier(): |
| return multi_process_runner.get_barrier() |
| |
| |
| def fn_that_returns_pid(): |
| return os.getpid() |
| |
| |
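# Module-level global mutated by `fn_that_sets_global`; each subprocess has
# its own copy.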
| V = None |
| |
| |
| def fn_that_sets_global(val): |
| global V |
| old_val = V |
| V = val |
| return old_val |
| |
| |
| @combinations.generate(combinations.combine(required_gpus=0)) |
| class MultiProcessRunnerTest(test.TestCase, parameterized.TestCase): |
| |
| def _worker_idx(self): |
| config_task = json.loads(os.environ['TF_CONFIG'])['task'] |
| return config_task['index'] |
| |
| def test_multi_process_runner(self): |
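    # Each task returns its task type; with 1 chief, 2 workers and 3 ps, each
    # job name should appear exactly as many times as the job has tasks.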
| mpr_result = multi_process_runner.run( |
| fn_that_adds_task_type_in_return_data, |
| multi_worker_test_base.create_cluster_spec( |
| num_workers=2, num_ps=3, has_chief=True)) |
| |
| job_count_dict = {'worker': 2, 'ps': 3, 'chief': 1} |
| for data in mpr_result.return_value: |
| job_count_dict[data] -= 1 |
| |
| self.assertEqual(job_count_dict['worker'], 0) |
| self.assertEqual(job_count_dict['ps'], 0) |
| self.assertEqual(job_count_dict['chief'], 0) |
| |
| def test_multi_process_runner_error_propagates_from_subprocesses(self): |
| runner = multi_process_runner.MultiProcessRunner( |
| fn_that_errors, |
| multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1), |
| max_run_time=20) |
| runner.start() |
| with self.assertRaisesRegex(ValueError, 'This is an error.'): |
| runner.join() |
| |
| def test_multi_process_runner_queue_emptied_between_runs(self): |
| cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) |
| return_value = multi_process_runner.run(fn_that_adds_simple_return_data, |
| cluster_spec).return_value |
| self.assertTrue(return_value) |
| self.assertEqual(return_value[0], 'dummy_data') |
| self.assertEqual(return_value[1], 'dummy_data') |
| return_value = multi_process_runner.run(fn_that_does_nothing, |
| cluster_spec).return_value |
| self.assertFalse(return_value) |
| |
| def test_multi_process_runner_args_passed_correctly(self): |
| return_value = multi_process_runner.run( |
| fn_that_returns_args_and_kwargs, |
| multi_worker_test_base.create_cluster_spec(num_workers=1), |
| args=('a', 'b'), |
| kwargs={ |
| 'c_k': 'c_v' |
| }).return_value |
| self.assertEqual(return_value[0][0], 'a') |
| self.assertEqual(return_value[0][1], 'b') |
| self.assertEqual(return_value[0][2], ('c_k', 'c_v')) |
| |
| def test_stdout_captured(self): |
| |
| def simple_print_func(): |
| print('This is something printed.', flush=True) |
| return 'This is returned data.' |
| |
| mpr_result = multi_process_runner.run( |
| simple_print_func, |
| multi_worker_test_base.create_cluster_spec(num_workers=2), |
| return_output=True) |
| std_stream_results = mpr_result.stdout |
| return_value = mpr_result.return_value |
| self.assertIn('[worker-0]: This is something printed.\n', |
| std_stream_results) |
| self.assertIn('[worker-1]: This is something printed.\n', |
| std_stream_results) |
| self.assertIn('This is returned data.', return_value) |
| |
| def test_termination(self): |
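    # Terminate worker-0 mid-run and verify that its output stops early while
    # worker-1 runs to completion.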
| |
| def fn(): |
      for i in range(10):
| print( |
| 'index {}, iteration {}'.format(self._worker_idx(), i), flush=True) |
| time.sleep(5) |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, |
| multi_worker_test_base.create_cluster_spec(num_workers=2), |
| return_output=True) |
| mpr.start() |
| time.sleep(5) |
| mpr.terminate('worker', 0) |
| |
| std_stream_results = mpr.join().stdout |
| |
| # Worker 0 is terminated in the middle, so it should not have iteration 9 |
| # printed. |
| self.assertIn('[worker-0]: index 0, iteration 0\n', std_stream_results) |
| self.assertNotIn('[worker-0]: index 0, iteration 9\n', |
| std_stream_results) |
| self.assertIn('[worker-1]: index 1, iteration 0\n', std_stream_results) |
| self.assertIn('[worker-1]: index 1, iteration 9\n', std_stream_results) |
| |
| def test_termination_and_start_single_process(self): |
| |
| def fn(): |
      for i in range(10):
| print( |
| 'index {}, iteration {}'.format(self._worker_idx(), i), flush=True) |
| time.sleep(1) |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, |
| multi_worker_test_base.create_cluster_spec(num_workers=2), |
| return_output=True) |
| mpr.start() |
| time.sleep(3) |
| mpr.terminate('worker', 0) |
| mpr.start_single_process('worker', 0) |
| std_stream_results = mpr.join().stdout |
| |
| # Worker 0 is terminated in the middle, but a new worker 0 is added, so it |
| # should still have iteration 9 printed. Moreover, iteration 0 of worker 0 |
| # should happen twice. |
| self.assertLen( |
| [s for s in std_stream_results if 'index 0, iteration 0' in s], 2) |
| self.assertIn('[worker-0]: index 0, iteration 9\n', std_stream_results) |
| self.assertIn('[worker-1]: index 1, iteration 0\n', std_stream_results) |
| self.assertIn('[worker-1]: index 1, iteration 9\n', std_stream_results) |
| |
| def test_streaming(self): |
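    # Start 1 chief, 2 workers and 2 ps, then manually add worker-2 and ps-2,
    # and verify that logging and print output from every task is streamed
    # back to the test process.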
| |
| def fn(): |
| for i in range(5): |
| logging.info('(logging) %s-%d, i: %d', |
| multi_worker_test_base.get_task_type(), self._worker_idx(), |
| i) |
| print( |
| '(print) {}-{}, i: {}'.format( |
| multi_worker_test_base.get_task_type(), self._worker_idx(), i), |
| flush=True) |
| time.sleep(1) |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, |
| multi_worker_test_base.create_cluster_spec( |
| has_chief=True, num_workers=2, num_ps=2), |
| return_output=True) |
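    # Presumably keeps the runner from shutting everything down when the
    # chief exits, since extra worker/ps processes are started manually below.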
| mpr._dependence_on_chief = False |
| |
| mpr.start() |
| mpr.start_single_process('worker', 2) |
| mpr.start_single_process('ps', 2) |
| mpr_result = mpr.join() |
| |
| list_to_assert = mpr_result.stdout |
| |
| for job in ['chief']: |
| for iteration in range(5): |
| self.assertTrue( |
| any('(logging) {}-0, i: {}'.format(job, iteration) in line |
| for line in list_to_assert)) |
| self.assertTrue( |
| any('(print) {}-0, i: {}'.format(job, iteration) in line |
| for line in list_to_assert)) |
| |
| for job in ['worker', 'ps']: |
| for iteration in range(5): |
| for task in range(3): |
| self.assertTrue( |
| any('(logging) {}-{}, i: {}'.format(job, task, iteration) in line |
| for line in list_to_assert)) |
| self.assertTrue( |
| any('(print) {}-{}, i: {}'.format(job, task, iteration) in line |
| for line in list_to_assert)) |
| task = 3 |
| self.assertFalse( |
| any('(logging) {}-{}, i: {}'.format(job, task, iteration) in line |
| for line in list_to_assert)) |
| self.assertFalse( |
| any('(print) {}-{}, i: {}'.format(job, task, iteration) in line |
| for line in list_to_assert)) |
| |
| def test_start_in_process_as(self): |
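    # Run the chief's fn in the main test process via start_in_process_as,
    # while an evaluator task is started from a separate thread.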
| |
| def fn(): |
| for i in range(5): |
| logging.info('%s-%d, i: %d', multi_worker_test_base.get_task_type(), |
| self._worker_idx(), i) |
| time.sleep(1) |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, |
| multi_worker_test_base.create_cluster_spec( |
| has_chief=True, num_workers=1), |
| return_output=True) |
| |
| def eval_func(): |
| time.sleep(1) |
| mpr.start_single_process(task_type='evaluator', task_id=0) |
| |
| eval_thread = threading.Thread(target=eval_func) |
| eval_thread.start() |
| mpr.start_in_process_as(as_task_type='chief', as_task_id=0) |
| eval_thread.join() |
| list_to_assert = mpr.join().stdout |
| for job in ['worker', 'evaluator']: |
| for iteration in range(5): |
| self.assertTrue( |
| any('{}-0, i: {}'.format(job, iteration) in line |
| for line in list_to_assert)) |
| |
| def test_terminate_all_does_not_ignore_error(self): |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn_that_errors, |
| multi_worker_test_base.create_cluster_spec(num_workers=2), |
| return_output=True) |
| mpr.start() |
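    # Leave ample time for the subprocesses to raise before calling
    # terminate_all(), so that join() reports the error rather than the
    # termination.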
| time.sleep(60) |
| mpr.terminate_all() |
| with self.assertRaisesRegex(ValueError, 'This is an error.'): |
| mpr.join() |
| |
| def test_barrier(self): |
| multi_process_runner.run( |
| fn_with_barrier, |
| cluster_spec=multi_worker_test_base.create_cluster_spec( |
| has_chief=True, num_workers=1), |
| ) |
| |
| def test_barrier_called_in_main_process(self): |
| with self.assertRaises(ValueError): |
| multi_process_runner.get_barrier() |
| |
| def test_stdout_available_when_timeout(self): |
| |
| def fn(): |
| logging.info('something printed') |
      time.sleep(10000)  # Intentionally make the test time out.
| |
| with self.assertRaises(multi_process_runner.SubprocessTimeoutError) as cm: |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, |
| multi_worker_test_base.create_cluster_spec(num_workers=1), |
| return_output=True) |
| mpr.start() |
| mpr.join(timeout=60) |
| mpr.terminate_all() |
| |
| list_to_assert = cm.exception.mpr_result.stdout |
| self.assertTrue( |
| any('something printed' in line for line in list_to_assert)) |
| |
| def test_seg_fault_raises_error(self): |
| |
| if multi_process_runner.is_oss() or sys.version_info >= (3, 7): |
| self.skipTest('TODO(b/171004637): Failing in OSS and Python 3.7+') |
| |
| def fn_expected_to_seg_fault(): |
      ctypes.string_at(0)  # Intentionally trigger a segmentation fault.
| |
| with self.assertRaises( |
| multi_process_runner.UnexpectedSubprocessExitError) as cm: |
| multi_process_runner.run( |
| fn_expected_to_seg_fault, |
| multi_worker_test_base.create_cluster_spec(num_workers=1), |
| return_output=True) |
| self.assertIn('Subprocess worker-0 exited with exit code', |
| str(cm.exception)) |
| list_to_assert = cm.exception.mpr_result.stdout |
| self.assertTrue( |
| any('Segmentation fault' in line for line in list_to_assert)) |
| |
| def test_seg_fault_in_chief_raises_error(self): |
| |
| if multi_process_runner.is_oss() or sys.version_info >= (3, 7): |
| self.skipTest('TODO(b/171004637): Failing in OSS and Python 3.7+') |
| |
| def fn_expected_to_seg_fault(): |
| if multi_worker_test_base.get_task_type() == 'worker': |
| time.sleep(10000) |
      ctypes.string_at(0)  # Intentionally trigger a segmentation fault.
| |
| with self.assertRaises( |
| multi_process_runner.UnexpectedSubprocessExitError) as cm: |
| multi_process_runner.run( |
| fn_expected_to_seg_fault, |
| multi_worker_test_base.create_cluster_spec( |
| has_chief=True, num_workers=1), |
| return_output=True) |
| self.assertIn('Subprocess chief-0 exited with exit code', |
| str(cm.exception)) |
| list_to_assert = cm.exception.mpr_result.stdout |
| self.assertTrue( |
| any('Segmentation fault' in line for line in list_to_assert)) |
| |
| def test_exit_code_is_reported_by_chief_subprocess(self): |
| |
| def fn_expected_to_exit_with_20(): |
| if multi_worker_test_base.get_task_type() == 'worker': |
| time.sleep(10000) |
| sys.exit(20) |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn_expected_to_exit_with_20, |
| multi_worker_test_base.create_cluster_spec( |
| has_chief=True, num_workers=1)) |
| mpr.start() |
| |
| with self.assertRaisesRegex( |
| multi_process_runner.UnexpectedSubprocessExitError, |
| 'Subprocess chief-0 exited with exit code 20'): |
| mpr.join() |
| |
| def test_exit_code_is_reported_by_subprocess(self): |
| |
| def fn_expected_to_exit_with_10(): |
| sys.exit(10) |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn_expected_to_exit_with_10, |
| multi_worker_test_base.create_cluster_spec(num_workers=1)) |
| mpr.start() |
| |
| with self.assertRaisesRegex( |
| multi_process_runner.UnexpectedSubprocessExitError, |
| 'Subprocess worker-0 exited with exit code 10'): |
| mpr.join() |
| |
| def test_auto_restart(self): |
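    # fn raises on its first run; with auto_restart=True the worker is
    # restarted and succeeds, leaving the shared counter at 2.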
| |
| def fn(counter): |
| counter.value += 1 |
| if counter.value == 1: |
| raise ValueError |
| |
| manager = multi_process_runner.manager() |
| counter = manager.Value(int, 0) |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, |
| multi_worker_test_base.create_cluster_spec(num_workers=1), |
| args=(counter,), |
| auto_restart=True) |
| mpr.start() |
| mpr.join() |
| self.assertEqual(counter.value, 2) |
| |
| def test_auto_restart_and_timeout(self): |
| |
| def fn(): |
| logging.info('Running') |
| time.sleep(1) |
| raise ValueError |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, |
| multi_worker_test_base.create_cluster_spec(num_workers=1), |
| auto_restart=True, |
| return_output=True) |
| mpr.start() |
| with self.assertRaises(ValueError) as cm: |
| mpr.join(timeout=10) |
    self.assertGreater(
        sum('Running' in msg for msg in cm.exception.mpr_result.stdout), 1)
| |
| def test_auto_restart_and_chief(self): |
| # If the chief has exited with zero exit code, auto restart should stop |
| # restarting other tasks even if they fail. |
| |
| def fn(): |
| time.sleep(1) |
| if multi_worker_test_base.get_task_type() != 'chief': |
| raise ValueError |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, |
| multi_worker_test_base.create_cluster_spec( |
| has_chief=True, num_workers=1), |
| auto_restart=True) |
| mpr.start() |
| with self.assertRaises(ValueError): |
| mpr.join(timeout=10) |
| |
| def test_auto_restart_failure_immediate_after_restart(self): |
| # Test the case when worker-0 fails immediately after worker-1 restarts. |
| |
| def fn(): |
| time.sleep(5) |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, |
| multi_worker_test_base.create_cluster_spec( |
| has_chief=False, num_workers=2), |
| auto_restart=True) |
| mpr.start() |
| pid = mpr.get_process_id('worker', 1) |
| mpr.terminate('worker', 1) |
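    # Wait until worker-1 has been restarted (i.e. it has a new pid) before
    # terminating worker-0.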
| while mpr.get_process_id('worker', 1) == pid: |
| time.sleep(0.1) |
| mpr.terminate('worker', 0) |
| mpr.join(timeout=20) |
| |
| def test_auto_restart_terminate(self): |
| # Tasks terminated by the user should also be restarted. |
| |
| def fn(counter): |
| counter.value += 1 |
| if counter.value == 1: |
| time.sleep(100) |
| |
| manager = multi_process_runner.manager() |
| counter = manager.Value(int, 0) |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, |
| multi_worker_test_base.create_cluster_spec( |
| has_chief=False, num_workers=1), |
| args=(counter,), |
| auto_restart=True) |
| mpr.start() |
| time.sleep(3) |
| mpr.terminate('worker', 0) |
| mpr.join(timeout=20) |
| self.assertEqual(counter.value, 2) |
| |
| def test_error_reporting_overrides_timeout_reporting(self): |
| |
| def fn(): |
| if self._worker_idx() == 1: |
| time.sleep(10000) |
| raise ValueError('Worker 0 errored') |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, multi_worker_test_base.create_cluster_spec(num_workers=2)) |
| mpr.start() |
| |
| with self.assertRaisesRegex( |
| ValueError, |
| 'Worker 0 errored'): |
| mpr.join(timeout=20) |
| |
| def test_process_exists(self): |
| |
| def fn(): |
| time.sleep(100000) |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, multi_worker_test_base.create_cluster_spec(num_workers=1)) |
| mpr.start() |
| self.assertTrue(mpr.process_exists('worker', 0)) |
| mpr.terminate('worker', 0) |
| # Worker 0 should exit at some point, or else the test would time out. |
| while mpr.process_exists('worker', 0): |
| time.sleep(1) |
| |
| def test_timeout_none(self): |
| |
| if multi_process_runner.is_oss(): |
| self.skipTest('Intentionally skipping longer test in OSS.') |
| |
| def fn(): |
| time.sleep(250) |
| raise ValueError('Worker 0 errored') |
| |
| mpr = multi_process_runner.MultiProcessRunner( |
| fn, multi_worker_test_base.create_cluster_spec(num_workers=1)) |
| |
| mpr.start() |
| with self.assertRaisesRegex(ValueError, 'Worker 0 errored'): |
| mpr.join(timeout=None) |
| |
| |
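# Created at module scope so that `test_global_pool` and `test_nested_pool`
# below share the same pool of subprocesses.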
| _global_pool = multi_process_runner.MultiProcessPoolRunner( |
| multi_worker_test_base.create_cluster_spec(num_workers=2)) |
| |
| |
| class MultiProcessPoolRunnerTest(test.TestCase): |
| |
| def test_same_process_across_runs(self): |
| cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) |
| runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec) |
| pid = runner.run(fn_that_returns_pid) |
| for _ in range(3): |
| self.assertAllEqual(runner.run(fn_that_returns_pid), pid) |
| |
| def test_exceptions_in_sub_process(self): |
| cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) |
| runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec) |
| pid = runner.run(fn_that_returns_pid) |
| with self.assertRaisesRegex(ValueError, 'This is an error.'): |
| runner.run(fn_that_errors) |
| self.assertAllEqual(runner.run(fn_that_returns_pid), pid) |
| |
| def test_tf_config(self): |
| cluster_spec = multi_worker_test_base.create_cluster_spec( |
| has_chief=True, num_workers=2) |
| runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec) |
| result = runner.run(fn_that_adds_task_type_in_return_data) |
| |
| job_count_dict = {'worker': 2, 'chief': 1} |
| for data in result: |
| job_count_dict[data] -= 1 |
| |
| self.assertEqual(job_count_dict['worker'], 0) |
| self.assertEqual(job_count_dict['chief'], 0) |
| |
| @unittest.expectedFailure |
| def test_exception_in_main_process(self): |
| # When there's an exception in the main process, __del__() is not called. |
| # This test is to verify MultiProcessPoolRunner can cope with __del__() not |
| # being called. |
| cluster_spec = multi_worker_test_base.create_cluster_spec( |
| has_chief=True, num_workers=2) |
| runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec) |
| runner.run(fn_that_returns_pid) |
| raise ValueError('failure') |
| |
| def test_initializer(self): |
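    # The initializer sets the global V to 1 once in each pool worker, so the
    # subsequent call returns the old value 1 from both workers.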
| cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) |
| runner = multi_process_runner.MultiProcessPoolRunner( |
| cluster_spec, initializer=lambda: fn_that_sets_global(1)) |
| result = runner.run(fn_that_sets_global, args=(2,)) |
| self.assertAllEqual(result, [1, 1]) |
| |
| def test_global_pool(self): |
| _global_pool.run(fn_that_does_nothing) |
| |
| def test_nested_pool(self): |
| |
| def fn(): |
      # This runs in subprocesses, so each of them uses its own
      # MultiProcessPoolRunner.
| _global_pool.run(fn_that_does_nothing) |
| |
| _global_pool.run(fn) |
| |
| |
| @combinations.generate(combinations.combine(required_physical_gpus=2)) |
| class MultiProcessRunnerMultiGPUTest(test.TestCase, parameterized.TestCase): |
| |
| def test_not_share_gpu(self): |
| num_gpus = len(context.context().list_physical_devices('GPU')) |
    if num_gpus not in (2, 4):
| self.skipTest('requires 2 or 4 GPUs') |
| cluster_spec = multi_worker_test_base.create_cluster_spec( |
| has_chief=True, num_workers=1) |
| |
    # Verify that CUDA_VISIBLE_DEVICES is set differently on each worker.
| |
| def cuda_visible_devices_fn(): |
| return os.getenv('CUDA_VISIBLE_DEVICES') |
| |
| runner = multi_process_runner.MultiProcessRunner( |
| cuda_visible_devices_fn, cluster_spec, share_gpu=False) |
| runner.start() |
| result = runner.join() |
| if num_gpus == 2: |
| self.assertAllEqual(sorted(result.return_value), ['0', '1']) |
| else: |
| self.assertAllEqual(sorted(result.return_value), ['0,2', '1,3']) |
| |
| # Verify that CUDA_VISIBLE_DEVICES works. |
| |
| def num_gpus_fn(): |
| return len(context.context().list_physical_devices('GPU')) |
| |
| runner = multi_process_runner.MultiProcessRunner( |
| num_gpus_fn, cluster_spec, share_gpu=False) |
| runner.start() |
| result = runner.join() |
| if num_gpus == 2: |
| self.assertAllEqual(result.return_value, [1, 1]) |
| else: |
| self.assertAllEqual(result.return_value, [2, 2]) |
| |
| |
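# Test case instances are not picklable by default; register reducers with
# dill that reconstruct a fresh instance instead.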
| @_REGISTER_DECORATOR(MultiProcessRunnerTest) |
| def _save_multi_process_runner_test(pickler, obj): |
| def reconstruct(*args, **kwargs): |
| del args, kwargs |
| return MultiProcessRunnerTest() |
| |
| return pickler.save_reduce(reconstruct, (), obj=obj) |
| |
| |
| @_REGISTER_DECORATOR(MultiProcessPoolRunnerTest) |
| def _save_multi_process_pool_runner_test(pickler, obj): |
| def reconstruct(*args, **kwargs): |
| del args, kwargs |
| return MultiProcessPoolRunnerTest() |
| |
| return pickler.save_reduce(reconstruct, (), obj=obj) |
| |
| |
| @_REGISTER_DECORATOR(MultiProcessRunnerMultiGPUTest) |
| def _save_multi_process_runner_multi_gpu_test(pickler, obj): |
| def reconstruct(*args, **kwargs): |
| del args, kwargs |
| return MultiProcessRunnerMultiGPUTest() |
| |
| return pickler.save_reduce(reconstruct, (), obj=obj) |
| |
| |
| if __name__ == '__main__': |
| multi_process_runner.test_main() |