| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for MirroredStrategy.""" |
| |
| import json |
| import sys |
| |
| from absl.testing import parameterized |
| |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.python import tf2 |
| from tensorflow.python.autograph.core import converter_testing |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.distribute import collective_util |
| from tensorflow.python.distribute import combinations |
| from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib |
| from tensorflow.python.distribute import device_util |
| from tensorflow.python.distribute import distribute_lib |
| from tensorflow.python.distribute import distribute_utils |
| from tensorflow.python.distribute import mirrored_strategy |
| from tensorflow.python.distribute import multi_worker_test_base |
| from tensorflow.python.distribute import reduce_util |
| from tensorflow.python.distribute import strategy_combinations |
| from tensorflow.python.distribute import strategy_test_lib |
| from tensorflow.python.distribute import test_util |
| from tensorflow.python.distribute import values |
| from tensorflow.python.distribute.v1 import input_lib as input_lib_v1 |
| from tensorflow.python.eager import backprop |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.eager import test |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import device as tf_device |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import func_graph |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.framework import test_util as util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import gradients |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variable_v1 |
| from tensorflow.python.ops import variables |
| from tensorflow.python.ops import while_loop |
| from tensorflow.python.training import server_lib |
| from tensorflow.python.util import traceback_utils |
| |
| |
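| # True when this test binary is the GPU variant, i.e. its name contains "test_gpu". |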
| GPU_TEST = "test_gpu" in sys.argv[0] |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.mirrored_strategy_with_two_gpus, |
| strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call, |
| ], |
| mode=["graph", "eager"])) |
| class MirroredTwoDeviceDistributionTest( |
| strategy_test_lib.DistributionTestBase, |
| strategy_test_lib.TwoDeviceDistributionTestBase, |
| parameterized.TestCase): |
| |
| def testMinimizeLoss(self, distribution): |
| if context.executing_eagerly(): |
| self._test_minimize_loss_eager(distribution) |
| else: |
| self._test_minimize_loss_graph(distribution) |
| |
| def testReplicaId(self, distribution): |
| self._test_replica_id(distribution) |
| |
| def testNumReplicasInSync(self, distribution): |
| self.assertEqual(2, distribution.num_replicas_in_sync) |
| |
| def testCallAndMergeExceptions(self, distribution): |
| self._test_call_and_merge_exceptions(distribution) |
| |
| def testRunRegroupError(self, distribution): |
| if not distribution.extended._use_merge_call(): |
| self.skipTest("Collective all-reduce does not support int32 on GPU.") |
| def run_fn(): |
| replica_id = int(self.evaluate(_replica_id())) |
| # Generates a list with different lengths on different devices. |
| # Will fail in _regroup() (if more than one device). |
| return list(range(replica_id)) |
| |
| with distribution.scope(), self.assertRaises(AssertionError): |
| distribution.extended.call_for_each_replica(run_fn) |
| |
| def testReduceToCpu(self, distribution): |
| if not distribution.extended._use_merge_call(): |
| self.skipTest("Collective all-reduce does not support int32 on GPU.") |
| |
| with distribution.scope(): |
| result = distribution.extended.call_for_each_replica(_replica_id) |
| reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=None) |
| expected = sum(range(distribution.num_replicas_in_sync)) |
| self.assertEqual(expected, self.evaluate(reduced)) |
| |
| def testReduceToCpuNested(self, distribution): |
| if not distribution.extended._use_merge_call(): |
| self.skipTest("Collective all-reduce does not support int32 on GPU.") |
| |
| with distribution.scope(): |
| def replica_fn(input_tensor): |
| return input_tensor + constant_op.constant( |
| 1.0), input_tensor - constant_op.constant(1.0) |
| |
| input_tensor = constant_op.constant(3.0) |
| run_result = distribution.run(replica_fn, args=(input_tensor,)) |
| reduced_result = distribution.reduce("SUM", run_result, axis=None) |
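|       # replica_fn returns (input + 1, input - 1) = (4.0, 2.0) on every replica, so |
|       # the SUM reduction across replicas yields (4 * n, 2 * n). |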
| expected_result = (4 * distribution.num_replicas_in_sync, |
| 2 * distribution.num_replicas_in_sync) |
| |
| self.assertEqual(expected_result, self.evaluate(reduced_result)) |
| |
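|   # Helper: reduces the per-replica results of `replica_squared_fn` along axis 0 |
|   # and checks both the SUM and the MEAN against values computed in Python. |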
| def reduce_axis_helper(self, distribution, replica_squared_fn): |
| with distribution.scope(): |
| num_replicas = distribution.num_replicas_in_sync |
| result = distribution.extended.call_for_each_replica(replica_squared_fn) |
| # sum |
| reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=0) |
| expected = sum(x * (x + 1) for x in range(num_replicas)) |
| self.assertNear(expected, self.evaluate(reduced), 0.00001) |
| |
| # mean |
| reduced = distribution.reduce(reduce_util.ReduceOp.MEAN, result, axis=0) |
| expected /= sum(x + 1 for x in range(num_replicas)) |
| self.assertNear(expected, self.evaluate(reduced), 0.00001) |
| |
| def testReduceAxisToCpu(self, distribution): |
| if not distribution.extended._use_merge_call(): |
| self.skipTest("Collective all-reduce does not support int32 on GPU.") |
| for dtype in (dtypes.float32, dtypes.int32): |
| def replica_squared_fn(dtype=dtype): |
| # Lists with different lengths on different replicas. |
| replica_id = _replica_id_as_int() |
| return array_ops.identity( |
| math_ops.cast([replica_id] * (replica_id + 1), dtype)) |
| |
| self.reduce_axis_helper(distribution, replica_squared_fn) |
| |
| def set_v2_tensorshape(self, v2): |
| if v2: |
| tensor_shape.enable_v2_tensorshape() |
| else: |
| tensor_shape.disable_v2_tensorshape() |
| |
| def testReduceAxisToCpuUnknownShape(self, distribution): |
| if not distribution.extended._use_merge_call(): |
| self.skipTest("Collective all-reduce does not support int32 on GPU.") |
| original_v2 = tensor_shape._TENSORSHAPE_V2_OVERRIDE # pylint: disable=protected-access |
| try: |
| for v2 in (False, True): |
| self.set_v2_tensorshape(v2) |
| for dtype in (dtypes.float32, dtypes.int32): |
| for shape in ((None,), None): # Test both unknown size and rank. |
| def replica_squared_fn(dtype=dtype, shape=shape): |
| # Lists with different lengths on different replicas. |
| replica_id = _replica_id_as_int() |
| tensor = math_ops.cast([replica_id] * (replica_id + 1), dtype) |
| # Erase shape information |
| return array_ops.placeholder_with_default(tensor, shape=shape) |
| |
| self.reduce_axis_helper(distribution, replica_squared_fn) |
| finally: |
| self.set_v2_tensorshape(original_v2) |
| |
| def testReplicateDataset(self, distribution): |
| if tf2.enabled() and not context.executing_eagerly(): |
| self.skipTest("Skipping test since we do not support graph mode in TF 2") |
| |
| dataset_fn = lambda: dataset_ops.Dataset.range(10) |
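|     # With two replicas in sync, each step yields two consecutive elements, one |
|     # per replica. |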
| expected_values = [[i, i+1] for i in range(0, 10, 2)] |
| input_fn = self._input_fn_to_test_input_context( |
| dataset_fn, |
| expected_num_replicas_in_sync=2, |
| expected_num_input_pipelines=1, |
| expected_input_pipeline_id=0) |
| self._test_input_fn_iterable(distribution, input_fn, expected_values) |
| |
| def testMakeInputFnIteratorWithDataset(self, distribution): |
| dataset_fn = lambda: dataset_ops.Dataset.range(10) |
| expected_values = [[i, i+1] for i in range(0, 10, 2)] |
| |
| input_fn = self._input_fn_to_test_input_context( |
| dataset_fn, |
| expected_num_replicas_in_sync=2, |
| expected_num_input_pipelines=1, |
| expected_input_pipeline_id=0) |
| iterator = distribution.make_input_fn_iterator(input_fn) |
| self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, |
| expected_values) |
| |
| def testMakeInputFnIteratorWithCallable(self, distribution): |
| def fn(): |
| dataset = dataset_ops.Dataset.range(2).interleave( |
| (lambda _: dataset_ops.Dataset.range(10)), cycle_length=2) |
| it = dataset_ops.make_one_shot_iterator(dataset) |
| return it.get_next |
| expected_values = [[i, i] for i in range(0, 10)] |
| |
| input_fn = self._input_fn_to_test_input_context( |
| fn, |
| expected_num_replicas_in_sync=2, |
| expected_num_input_pipelines=1, |
| expected_input_pipeline_id=0) |
| iterator = distribution.make_input_fn_iterator(input_fn) |
| self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, |
| expected_values, test_reinitialize=False, |
| ignore_order=True) |
| |
| def testNumpyDataset(self, distribution): |
| self._test_numpy_dataset(distribution) |
| |
| def testGlobalStepUpdate(self, distribution): |
| self._test_global_step_update(distribution) |
| |
| def testRun(self, distribution): |
| self._test_run(distribution) |
| |
| def testAllReduceSum(self, distribution): |
| self._test_all_reduce_sum(distribution) |
| |
| def testAllReduceSumGradients(self, distribution): |
| self._test_all_reduce_sum_gradients(distribution) |
| |
| def testAllReduceSumGradientTape(self, distribution): |
| self._test_all_reduce_sum_gradient_tape(distribution) |
| |
| def testAllReduceMean(self, distribution): |
| self._test_all_reduce_mean(distribution) |
| |
| def testAllReduceMeanGradients(self, distribution): |
| self._test_all_reduce_mean_gradients(distribution) |
| |
| def testAllReduceMeanGradientTape(self, distribution): |
| self._test_all_reduce_mean_gradient_tape(distribution) |
| |
| def testSummaryForReplicaZeroOnly(self, distribution): |
| self._test_summary_for_replica_zero_only(distribution) |
| |
| def testTrainableVariables(self, distribution): |
| self._test_trainable_variable(distribution) |
| |
| def test_prefetch_to_device_dataset(self, distribution): |
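|     # With experimental_fetch_to_device=True, each replica's element should be |
|     # placed on the corresponding worker device. |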
| input_options = distribute_lib.InputOptions( |
| experimental_fetch_to_device=True) |
| dataset = dataset_ops.Dataset.range(100) |
| dataset = dataset.batch(distribution.num_replicas_in_sync) |
| dataset = distribution.experimental_distribute_dataset( |
| dataset, options=input_options) |
| if context.executing_eagerly(): |
| item = next(iter(dataset)) |
| else: |
| if isinstance(dataset, input_lib_v1.DistributedDatasetV1): |
| item = dataset.make_initializable_iterator().get_next() |
| else: |
| self.skipTest("unsupported test combination") |
| device_types = [ |
| tf_device.DeviceSpec.from_string(tensor.device).device_type for |
| tensor in item.values] |
| expected_device_types = [ |
| tf_device.DeviceSpec.from_string(device).device_type for |
| device in distribution.extended.worker_devices] |
| self.assertAllEqual(device_types, expected_device_types) |
| |
| def test_prefetch_to_host_dataset(self, distribution): |
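|     # With experimental_fetch_to_device=False, every element should stay on the |
|     # host CPU. |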
| input_options = distribute_lib.InputOptions( |
| experimental_fetch_to_device=False) |
| dataset = dataset_ops.Dataset.range(100) |
| dataset = dataset.batch(distribution.num_replicas_in_sync) |
| dataset = distribution.experimental_distribute_dataset( |
| dataset, options=input_options) |
| if context.executing_eagerly(): |
| item = next(iter(dataset)) |
| else: |
| if isinstance(dataset, input_lib_v1.DistributedDatasetV1): |
| item = dataset.make_initializable_iterator().get_next() |
| else: |
| self.skipTest("unsupported test combination") |
| device_types = { |
| tf_device.DeviceSpec.from_string(tensor.device).device_type for |
| tensor in item.values} |
| self.assertAllEqual(list(device_types), ["CPU"]) |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager", "graph"], required_gpus=[2])) |
| class MirroredCollectiveOpTest(strategy_test_lib.DistributionTestBase, |
| strategy_test_lib.TwoDeviceDistributionTestBase, |
| parameterized.TestCase): |
| |
| def tearDown(self): |
| super(MirroredCollectiveOpTest, self).tearDown() |
| context._reset_context() |
| |
| def testAllCpu(self): |
| @def_function.function |
| def fn(): |
| strategy = mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"]) |
| if ops.executing_eagerly_outside_functions(): |
| self.assertIsInstance( |
| strategy.extended._collective_ops, |
| cross_device_ops_lib.CollectiveAllReduce) |
| self.assertEqual( |
| strategy.extended._collective_ops._options.implementation, |
| collective_util.CommunicationImplementation.RING) |
| else: |
| self.assertIsInstance(strategy.extended._collective_ops, |
| cross_device_ops_lib.ReductionToOneDevice) |
| fn() |
| |
| def testMixedDevices(self): |
| @def_function.function |
| def fn(): |
| strategy = mirrored_strategy.MirroredStrategy(["CPU:0", "GPU:0"]) |
| self.assertIsInstance( |
| strategy.extended._collective_ops, |
| cross_device_ops_lib.ReductionToOneDevice) |
| fn() |
| |
| def testAllPhysicalGpu(self): |
| @def_function.function |
| def fn(): |
| strategy = mirrored_strategy.MirroredStrategy(["GPU:0", "GPU:1"]) |
| self.assertIsInstance( |
| strategy.extended._collective_ops, |
| cross_device_ops_lib.CollectiveAllReduce) |
| self.assertEqual( |
| strategy.extended._collective_ops._options.implementation, |
| collective_util.CommunicationImplementation.NCCL) |
| fn() |
| |
| def testVirtualGpu(self): |
| # Logical devices cannot be changed after context initialization. |
| context._reset_context() |
| |
| physical_gpus = context.context().list_physical_devices(device_type="GPU") |
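|     # Split the second physical GPU into two logical GPUs so that three logical |
|     # GPUs are available. |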
| context.context().set_logical_device_configuration(physical_gpus[1], [ |
| context.LogicalDeviceConfiguration(memory_limit=1024), |
| context.LogicalDeviceConfiguration(memory_limit=1024) |
| ]) |
| @def_function.function |
| def fn(): |
| strategy = mirrored_strategy.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2"]) |
| if ops.executing_eagerly_outside_functions(): |
| self.assertIsInstance( |
| strategy.extended._collective_ops, |
| cross_device_ops_lib.CollectiveAllReduce) |
| self.assertEqual( |
| strategy.extended._collective_ops._options.implementation, |
| collective_util.CommunicationImplementation.RING) |
| else: |
|         self.assertIsInstance(strategy.extended._collective_ops, |
|                               cross_device_ops_lib.ReductionToOneDevice) |
| fn() |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["graph", "eager"], required_gpus=[2], use_default=[True, False])) |
| class MirroredGetCrossDeviceOpTest( |
| strategy_test_lib.DistributionTestBase, |
| strategy_test_lib.TwoDeviceDistributionTestBase, parameterized.TestCase): |
| |
| def tearDown(self): |
| super().tearDown() |
| context._reset_context() |
| |
| def testGpusCollectiveOp(self, use_default): |
| |
| @def_function.function(jit_compile=util.is_xla_enabled()) |
| def fn(var, use_default): |
| |
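|       # With the default cross_device_ops (or when XLA is enabled) the strategy is |
|       # expected to use collective ops; otherwise it should use the explicitly |
|       # passed NcclAllReduce instance. |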
| if use_default or util.is_xla_enabled(): |
| self.assertIsInstance( |
| strategy.extended._get_cross_device_ops(var), |
| cross_device_ops_lib.CollectiveAllReduce) |
| else: |
| self.assertIsInstance( |
| strategy.extended._get_cross_device_ops(var), |
| cross_device_ops_lib.NcclAllReduce) |
| |
| strategy = mirrored_strategy.MirroredStrategy( |
| ["GPU:0", "GPU:1"], |
| cross_device_ops=None |
| if use_default else cross_device_ops_lib.NcclAllReduce()) |
| with strategy.scope(): |
| var = variables.Variable(1.) |
| |
| fn(var, use_default) |
| |
| def testVirtualGpusCollectiveOp(self, use_default): |
| # Logical devices cannot be changed after context initialization. |
| context._reset_context() |
| |
| physical_gpus = context.context().list_physical_devices(device_type="GPU") |
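|     # As above, split the second physical GPU into two logical GPUs. |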
| context.context().set_logical_device_configuration(physical_gpus[1], [ |
| context.LogicalDeviceConfiguration(memory_limit=1024), |
| context.LogicalDeviceConfiguration(memory_limit=1024) |
| ]) |
| |
| @def_function.function(jit_compile=util.is_xla_enabled()) |
| def fn(var, use_default): |
| |
| if use_default or util.is_xla_enabled(): |
| self.assertIsInstance( |
| strategy.extended._get_cross_device_ops(var), |
| cross_device_ops_lib.CollectiveAllReduce) |
| self.assertEqual( |
| strategy.extended._get_cross_device_ops( |
| var)._options.implementation, |
| collective_util.CommunicationImplementation.RING) |
| else: |
| self.assertIsInstance( |
| strategy.extended._get_cross_device_ops(var), |
| cross_device_ops_lib.NcclAllReduce) |
| |
| strategy = mirrored_strategy.MirroredStrategy( |
| ["GPU:0", "GPU:1", "GPU:2"], |
| cross_device_ops=None |
| if use_default else cross_device_ops_lib.NcclAllReduce()) |
| |
| with strategy.scope(): |
| var = variables.Variable(1.) |
| |
| fn(var, use_default) |
| |
| def testCpusCollectiveOp(self, use_default): |
| del use_default |
| if util.is_xla_enabled(): |
| self.skipTest("Only expected to run under non-XLA context.") |
| |
| @def_function.function(jit_compile=True) |
| def fn(var): |
| |
| if not ops.executing_eagerly_outside_functions(): |
| self.assertIsInstance( |
| strategy.extended._get_cross_device_ops(var), |
| cross_device_ops_lib.ReductionToOneDevice) |
| else: |
| self.assertIsInstance( |
| strategy.extended._get_cross_device_ops(var), |
| cross_device_ops_lib.CollectiveAllReduce) |
| |
| strategy = mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"]) |
| with strategy.scope(): |
| var = variables.Variable(1.) |
| |
| fn(var) |
| |
| def testMixedDevicesCollectiveOp(self, use_default): |
| del use_default |
| if util.is_xla_enabled(): |
| self.skipTest("All devices should be identical in XLA context.") |
| |
| # XLA is not supported if devices are not of the same type. |
| strategy = mirrored_strategy.MirroredStrategy(["CPU:0", "GPU:0"]) |
| with strategy.scope(): |
| var = variables.Variable(1.) |
| |
| self.assertIsInstance( |
| strategy.extended._get_cross_device_ops(var), |
| cross_device_ops_lib.ReductionToOneDevice) |
| |
| def testMirroredStrategyInt32VariableCollectiveOp(self, use_default): |
| if util.is_xla_enabled(): |
| self.skipTest("Only expected to run under non-XLA context.") |
| |
| strategy = mirrored_strategy.MirroredStrategy( |
| ["GPU:0", "GPU:1"], |
| cross_device_ops=None |
| if use_default else cross_device_ops_lib.NcclAllReduce()) |
| with strategy.scope(): |
|       # Collective ops do not support int32 on GPU. |
| var = variables.Variable(1) |
| |
| self.assertIsInstance( |
| strategy.extended._get_cross_device_ops(var), |
| cross_device_ops_lib.ReductionToOneDevice) |
| |
| |
| def one_device_combinations(): |
| return combinations.combine( |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_one_gpu, |
| ], |
| mode=["graph", "eager"]) |
| |
| |
| @combinations.generate(one_device_combinations()) |
| class MirroredOneDeviceDistributionTest( |
| strategy_test_lib.DistributionTestBase, |
| strategy_test_lib.OneDeviceDistributionTestBase, |
| parameterized.TestCase): |
| |
| def testMinimizeLoss(self, distribution): |
| if context.executing_eagerly(): |
| self._test_minimize_loss_eager(distribution) |
| else: |
| self._test_minimize_loss_graph(distribution) |
| |
| def testReplicaId(self, distribution): |
| self._test_replica_id(distribution) |
| |
| def testCallAndMergeExceptions(self, distribution): |
| self._test_call_and_merge_exceptions(distribution) |
| |
| def testRun(self, distribution): |
| self._test_run(distribution) |
| |
| def testAllReduceSum(self, distribution): |
| self._test_all_reduce_sum(distribution) |
| |
| def testAllReduceSumGradients(self, distribution): |
| self._test_all_reduce_sum_gradients(distribution) |
| |
| def testAllReduceSumGradientTape(self, distribution): |
| self._test_all_reduce_sum_gradient_tape(distribution) |
| |
| def testAllReduceMean(self, distribution): |
| self._test_all_reduce_mean(distribution) |
| |
| def testAllReduceMeanGradients(self, distribution): |
| self._test_all_reduce_mean_gradients(distribution) |
| |
| def testAllReduceMeanGradientTape(self, distribution): |
| self._test_all_reduce_mean_gradient_tape(distribution) |
| |
| |
| class MirroredStrategyVariableCreatorStackTest( |
| test.TestCase, parameterized.TestCase): |
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ], |
| mode=["graph"])) |
| def testCreatorStacksAreThreadLocal(self, distribution): |
| def model_fn(): |
| replica_id_str = str(self.evaluate(_replica_id())) |
| |
| def thread_creator_fn(next_creator, **kwargs): |
| return next_creator(**kwargs) + ":thread_" + replica_id_str |
| |
| with variable_scope.variable_creator_scope(thread_creator_fn): |
| # Create a variable in this scope. |
| v = variable_v1.VariableV1(1.0) |
| |
| # This will pause the current thread, and execute the other thread. |
| distribute_lib.get_replica_context().merge_call(lambda _: _) |
| return v |
| |
| def main_thread_creator(next_creator, **kwargs): |
|       # The underlying next_creator is deliberately not used in this test. |
| del next_creator, kwargs |
| return "main_thread" |
| |
| with context.graph_mode(), \ |
| distribution.scope(), \ |
| variable_scope.variable_creator_scope(main_thread_creator): |
| result = distribution.extended.call_for_each_replica(model_fn) |
| result = distribution.experimental_local_results(result) |
| expected = ("main_thread:thread_0", "main_thread:thread_1") |
| self.assertEqual(expected, result) |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ], |
| mode=["graph", "eager"])) |
| class MirroredStrategyCallForEachReplicaTest(test.TestCase): |
| |
| def testExecutingEagerlyOutsideFunction(self, distribution): |
| """Verify we preserve the value of executing_eagerly_outside_functions().""" |
| def model_fn(): |
| return ops.executing_eagerly_outside_functions() |
| |
| originally = ops.executing_eagerly_outside_functions() |
| with distribution.scope(): |
| in_scope = ops.executing_eagerly_outside_functions() |
| in_model_fn = distribution.extended.call_for_each_replica(model_fn) |
| unwrapped = distribution.experimental_local_results(in_model_fn) |
| self.assertEqual(in_scope, unwrapped[0]) |
| self.assertEqual(in_scope, originally) |
| |
| # Verify this all again, but this time in a FuncGraph. |
| with func_graph.FuncGraph("fg").as_default(), distribution.scope(): |
| in_scope = ops.executing_eagerly_outside_functions() |
| in_model_fn = distribution.extended.call_for_each_replica(model_fn) |
| unwrapped = distribution.experimental_local_results(in_model_fn) |
| self.assertEqual(in_scope, unwrapped[0]) |
| self.assertEqual(in_scope, originally) |
| |
| def testFunctionInCallForEachReplica(self, distribution): |
| traces = [] |
| @def_function.function |
| def model_fn(): |
| traces.append(1) |
| return distribute_lib.get_replica_context().replica_id_in_sync_group |
| |
| with distribution.scope(): |
| result = distribution.extended.call_for_each_replica(model_fn) |
| self.assertEqual( |
| (0, 1), |
| self.evaluate(distribution.experimental_local_results(result))) |
| self.assertLen(traces, distribution.num_replicas_in_sync) |
| |
| def testFunctionInCallForEachReplicaInsideAnotherFunction(self, distribution): |
| traces = [] |
| @def_function.function |
| def model_fn(): |
| traces.append(1) |
| return distribute_lib.get_replica_context().replica_id_in_sync_group |
| |
| @def_function.function |
| def step(): |
| return distribution.extended.call_for_each_replica(model_fn) |
| |
| with distribution.scope(): |
| result = step() |
| self.assertEqual( |
| (0, 1), |
| self.evaluate(distribution.experimental_local_results(result))) |
| self.assertLen(traces, distribution.num_replicas_in_sync) |
| |
| def testControlFlowFunctionInCallForEachReplicaWithMergeCall( |
| self, distribution): |
| |
| def merge_fn(strategy, value): |
| return strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None) |
| |
| @def_function.function |
| def model_fn(): |
| |
| def body_fn(i): |
| return distribute_lib.get_replica_context().merge_call( |
| merge_fn, args=(i,)) |
| |
| return while_loop.while_loop_v2(lambda i: i < 2, body_fn, [0]) |
| |
| with distribution.scope(): |
| with self.assertRaisesRegex( |
| RuntimeError, "`merge_call` called while defining a new graph."): |
| distribution.extended.call_for_each_replica(model_fn) |
| |
| def testNestedFunctionInCallForEachReplicaWithMergeCall(self, distribution): |
| |
| def merge_fn(strategy, value): |
| return strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None) |
| |
| def model_fn(): |
| |
| @def_function.function |
| def model_fn_nested(): |
| t = constant_op.constant(1) |
| return distribute_lib.get_replica_context().merge_call( |
| merge_fn, args=(t,)) |
| |
| return model_fn_nested() |
| |
| with distribution.scope(): |
| with self.assertRaisesRegex( |
| RuntimeError, "`merge_call` called while defining a new graph."): |
| distribution.extended.call_for_each_replica(model_fn) |
| |
| def testFunctionInCallForEachReplicaWithMergeCall(self, distribution): |
| def merge_fn(_): |
| pass |
| |
| @def_function.function |
| def model_fn(): |
| distribute_lib.get_replica_context().merge_call(merge_fn) |
| return 0. |
| |
| with distribution.scope(): |
| self.assertEqual( |
| self.evaluate(distribution.extended.call_for_each_replica(model_fn)), |
| 0.) |
| |
| def testFunctionInCallForEachReplicaCached(self, distribution): |
| traces = [] |
| |
| @def_function.function |
| def model_fn(): |
| traces.append(None) |
| |
| self.assertEmpty(traces) |
| |
| for i in range(10): |
| distribution.extended.call_for_each_replica(model_fn) |
| |
| if i == 0: |
| num_devices = len(traces) |
| self.assertGreater(num_devices, 0) |
| else: |
|         # model_fn should not have been re-traced, so the number of recorded |
|         # traces should remain the same. |
| self.assertLen(traces, num_devices) |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ], |
| mode=["graph"])) |
| class MirroredStrategyNameScopeTest(test.TestCase): |
| # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not |
| # testing this in eager mode. |
| |
| def testNameScope(self, distribution): |
| def model_fn(): |
| with ops.name_scope("foo"): |
| a = constant_op.constant(1.0, name="a") |
| distribute_lib.get_replica_context().merge_call(lambda _: _) |
| b = constant_op.constant(1.0, name="b") |
| return a, b |
| |
| with context.graph_mode(), distribution.scope(): |
| with ops.name_scope("main"): |
| result = distribution.extended.call_for_each_replica(model_fn) |
| self.assertEqual(2, len(result)) |
| for v, name in zip(result, ["a", "b"]): |
| self.assertIsInstance(v, values.DistributedValues) |
| v0, v1 = distribution.experimental_local_results(v) |
| self.assertEqual("main/foo/" + name + ":0", v0.name) |
| self.assertEqual("main/replica_1/foo/" + name + ":0", v1.name) |
| |
| def testWithDefaultName(self, distribution): |
| def model_fn(): |
| with ops.name_scope(None, "foo"): |
| a = constant_op.constant(1.0, name="a") |
| distribute_lib.get_replica_context().merge_call(lambda _: _) |
| b = constant_op.constant(2.0, name="b") |
| return a, b |
| |
| with context.graph_mode(), distribution.scope(): |
| result = distribution.extended.call_for_each_replica(model_fn) |
| self.assertEqual(2, len(result)) |
| for v, name in zip(result, ["a", "b"]): |
| self.assertIsInstance(v, values.DistributedValues) |
| v0, v1 = distribution.experimental_local_results(v) |
| self.assertEqual("foo/" + name + ":0", v0.name) |
| self.assertEqual("replica_1/foo/" + name + ":0", v1.name) |
| |
|   # variable_v1.VariableV1() respects name scopes when creating variables. On |
|   # the other hand, variable_scope.get_variable() ignores name scopes but |
|   # respects variable scopes. We test both methods of creating variables to |
|   # make sure that we get the expected variable names in each case. |
| def testNameScopeWithVariable(self, distribution): |
| def in_cross_replica(_): |
| c = variable_v1.VariableV1(1.0, name="c") |
| return c |
| |
| def model_fn(): |
| b = variable_v1.VariableV1(1.0, name="b") |
| with ops.name_scope("foo"): |
| c = distribute_lib.get_replica_context().merge_call(in_cross_replica) |
| return b, c |
| |
| with context.graph_mode(), distribution.scope(): |
| with ops.name_scope("main"): |
| a = variable_v1.VariableV1(1.0, name="a") |
| result = distribution.extended.call_for_each_replica(model_fn) |
| result_b = result[0] |
| result_c = result[1] |
| self.assertIsInstance(result_b, values.DistributedValues) |
| self.assertIsInstance(result_c, values.DistributedValues) |
| a0, a1 = distribution.experimental_local_results(a) |
| b0, b1 = distribution.experimental_local_results(result_b) |
| c0, c1 = distribution.experimental_local_results(result_c) |
| self.assertEqual("main/a:0", a0.name) |
| self.assertEqual("main/a/replica_1:0", a1.name) |
| self.assertEqual("main/b:0", b0.name) |
| self.assertEqual("main/b/replica_1:0", b1.name) |
| self.assertEqual("main/foo/c:0", c0.name) |
| self.assertEqual("main/foo/c/replica_1:0", c1.name) |
| |
| def testNameScopeWithGetVariable(self, distribution): |
| def in_cross_replica(_): |
| c = variable_scope.get_variable("c", [1]) |
| return c |
| |
| def model_fn(): |
| b = variable_scope.get_variable("b", [1]) |
| with ops.name_scope("foo"): |
| c = distribute_lib.get_replica_context().merge_call(in_cross_replica) |
| return b, c |
| |
| with context.graph_mode(), distribution.scope(): |
| with ops.name_scope("main"): |
| a = variable_scope.get_variable("a", [1]) |
| result = distribution.extended.call_for_each_replica(model_fn) |
| result_b = result[0] |
| result_c = result[1] |
| self.assertIsInstance(result_b, values.DistributedValues) |
| self.assertIsInstance(result_c, values.DistributedValues) |
| a0, a1 = distribution.experimental_local_results(a) |
| b0, b1 = distribution.experimental_local_results(result_b) |
| c0, c1 = distribution.experimental_local_results(result_c) |
| self.assertEqual("a:0", a0.name) |
| self.assertEqual("a/replica_1:0", a1.name) |
| self.assertEqual("b:0", b0.name) |
| self.assertEqual("b/replica_1:0", b1.name) |
| self.assertEqual("c:0", c0.name) |
| self.assertEqual("c/replica_1:0", c1.name) |
| |
| def testVariableScopeWithGetVariable(self, distribution): |
| |
| def in_cross_replica(_): |
| c = variable_scope.get_variable("c", [1]) |
| return c |
| |
| def model_fn(): |
| b = variable_scope.get_variable("b", [1]) |
| with variable_scope.variable_scope("foo"): |
| c = distribute_lib.get_replica_context().merge_call(in_cross_replica) |
| return b, c |
| |
| with context.graph_mode(), distribution.scope(): |
| with variable_scope.variable_scope("main"): |
| a = variable_scope.get_variable("a", [1]) |
| result = distribution.extended.call_for_each_replica(model_fn) |
| result_b = result[0] |
| result_c = result[1] |
| self.assertIsInstance(result_b, values.DistributedValues) |
| self.assertIsInstance(result_c, values.DistributedValues) |
| a0, a1 = distribution.experimental_local_results(a) |
| b0, b1 = distribution.experimental_local_results(result_b) |
| c0, c1 = distribution.experimental_local_results(result_c) |
| self.assertEqual("main/a:0", a0.name) |
| self.assertEqual("main/a/replica_1:0", a1.name) |
| self.assertEqual("main/b:0", b0.name) |
| self.assertEqual("main/b/replica_1:0", b1.name) |
| self.assertEqual("main/foo/c:0", c0.name) |
| self.assertEqual("main/foo/c/replica_1:0", c1.name) |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| combinations.NamedDistribution( |
| "Mirrored3Devices", |
| # pylint: disable=g-long-lambda |
| lambda: mirrored_strategy.MirroredStrategy( |
| ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), |
| required_gpus=2) |
| ], |
| mode=["graph", "eager"])) |
| class MirroredThreeDeviceDistributionTest( |
| strategy_test_lib.DistributionTestBase, |
| parameterized.TestCase): |
| |
| def testThreeDevices(self, distribution): |
| def model_fn(): |
| v = variable_v1.VariableV1(1.0, name="foo") |
| distribute_lib.get_replica_context().merge_call(lambda _: _) |
| return v |
| |
| with distribution.scope(): |
| result = distribution.extended.call_for_each_replica(model_fn) |
| self.assertTrue(distribute_utils.is_mirrored(result)) |
| self.assertEqual("foo:0", result.name) |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ], |
| mode=["graph", "eager"])) |
| class MirroredVariableUpdateTest(test.TestCase): |
| # The following tests check assign, assign_add and assign_sub on Mirrored |
| # variables in replica and cross replica context. |
| |
| def testAssignMirroredVarReplicaContextWithoutAggregationType(self, |
| distribution): |
| def var_fn(): |
| v = variable_v1.VariableV1(1.0, name="foo") |
| return v |
| |
| with distribution.scope(): |
| mirrored_var = distribution.extended.call_for_each_replica(var_fn) |
| self.assertTrue(distribute_utils.is_mirrored(mirrored_var)) |
| self.evaluate(variables.global_variables_initializer()) |
| |
| def model_fn(): |
| return mirrored_var.assign(5.0) |
| |
| self.evaluate(distribution.experimental_local_results( |
| distribution.extended.call_for_each_replica(model_fn))) |
| self.assertEqual(5.0, self.evaluate(mirrored_var)) |
| |
| def testAssignMirroredVarReplicaContextWithSum(self, distribution): |
| # Test that we don't reduce a non-per-replica value with the "sum" |
| # aggregation type. |
| def var_fn(): |
| v = variable_v1.VariableV1( |
| 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM) |
| return v |
| |
| with distribution.scope(): |
| mirrored_var = distribution.extended.call_for_each_replica(var_fn) |
| self.assertTrue(distribute_utils.is_mirrored(mirrored_var)) |
| self.evaluate(variables.global_variables_initializer()) |
| |
| def model_fn(): |
| return mirrored_var.assign(5.0) |
| |
| if distribution.extended._use_merge_call(): |
| with self.assertRaisesRegex( |
| ValueError, "A non-DistributedValues value 5.0 cannot be reduced " |
| "with the given reduce op ReduceOp.SUM."): |
| self.evaluate(distribution.experimental_local_results( |
| distribution.extended.call_for_each_replica(model_fn))) |
| else: |
| result = self.evaluate( |
| distribution.experimental_local_results( |
| distribution.extended.call_for_each_replica(model_fn))) |
| self.assertAllEqual(result[0], 5.0) |
| |
| def testAssignMirroredVarCrossDeviceContext(self, distribution): |
| def var_fn(): |
| return variable_v1.VariableV1(1.0, name="foo") |
| |
| with distribution.scope(): |
| mirrored_var = distribution.extended.call_for_each_replica(var_fn) |
| self.assertTrue(distribute_utils.is_mirrored(mirrored_var)) |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(1.0, self.evaluate(mirrored_var)) |
| mirrored_var_result = self.evaluate(mirrored_var.assign(6.0)) |
| self.assertEqual(6.0, mirrored_var_result) |
| |
| def testAssignMirroredVarReplicaContext(self, distribution): |
| def var_fn(): |
| return variable_v1.VariableV1( |
| 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) |
| |
| with distribution.scope(): |
| mirrored_var = distribution.extended.call_for_each_replica(var_fn) |
| self.assertTrue(distribute_utils.is_mirrored(mirrored_var)) |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(1.0, self.evaluate(mirrored_var)) |
| |
| def model_fn(): |
| value = math_ops.cast( |
| distribute_lib.get_replica_context().replica_id_in_sync_group, |
| mirrored_var.dtype) |
| return mirrored_var.assign(value) |
| |
| self.evaluate(distribution.experimental_local_results( |
| distribution.extended.call_for_each_replica(model_fn))) |
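|     # With MEAN aggregation, assigning the per-replica ids 0.0 and 1.0 results in |
|     # a value of 0.5. |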
| self.assertEqual(0.5, self.evaluate(mirrored_var)) |
| |
| def testAssignMirroredVarReplicaContextWithSingleValue(self, distribution): |
| def var_fn(): |
| return variable_v1.VariableV1( |
| 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) |
| |
| with distribution.scope(): |
| mirrored_var = distribution.extended.call_for_each_replica(var_fn) |
| self.assertTrue(distribute_utils.is_mirrored(mirrored_var)) |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(1.0, self.evaluate(mirrored_var)) |
| |
| def model_fn(): |
| return mirrored_var.assign(5.0) |
| |
| self.evaluate(distribution.experimental_local_results( |
| distribution.extended.call_for_each_replica(model_fn))) |
| self.assertEqual(5.0, self.evaluate(mirrored_var)) |
| |
| def testAssignAddMirroredVarCrossDeviceContext(self, distribution): |
| def var_fn(): |
| return variable_v1.VariableV1(1.0, name="foo") |
| |
| with distribution.scope(): |
| mirrored_var = distribution.extended.call_for_each_replica(var_fn) |
| self.assertTrue(distribute_utils.is_mirrored(mirrored_var)) |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(1.0, self.evaluate(mirrored_var)) |
| |
| # read_value == True |
| mirrored_var_result = self.evaluate( |
| mirrored_var.assign_add(6.0, read_value=True)) |
| self.assertEqual(7.0, mirrored_var_result) |
| self.assertEqual( |
| 7.0, |
| self.evaluate( |
| distribution.experimental_local_results(mirrored_var)[0])) |
| self.assertEqual( |
| 7.0, |
| self.evaluate( |
| distribution.experimental_local_results(mirrored_var)[1])) |
| self.assertEqual( |
| distribution.extended.worker_devices[0], mirrored_var._devices[0]) |
| self.assertEqual( |
| distribution.extended.worker_devices[1], mirrored_var._devices[1]) |
| |
| # read_value == False |
| self.evaluate(mirrored_var.assign_add(2.0, read_value=False)) |
| self.assertEqual( |
| 9.0, |
| self.evaluate( |
| distribution.experimental_local_results(mirrored_var)[0])) |
| self.assertEqual( |
| 9.0, |
| self.evaluate( |
| distribution.experimental_local_results(mirrored_var)[1])) |
| self.assertEqual( |
| distribution.extended.worker_devices[0], mirrored_var._devices[0]) |
| self.assertEqual( |
| distribution.extended.worker_devices[1], mirrored_var._devices[1]) |
| |
| def testAssignAddMirroredVarReplicaContext(self, distribution): |
| def var_fn(): |
| return variable_v1.VariableV1( |
| 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) |
| |
| with distribution.scope(): |
| mirrored_var = distribution.extended.call_for_each_replica(var_fn) |
| self.assertTrue(distribute_utils.is_mirrored(mirrored_var)) |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(1.0, self.evaluate(mirrored_var)) |
| |
| def model_fn(): |
| value = math_ops.cast( |
| distribute_lib.get_replica_context().replica_id_in_sync_group, |
| mirrored_var.dtype) |
| return mirrored_var.assign_add(value) |
| |
| self.evaluate(distribution.experimental_local_results( |
| distribution.extended.call_for_each_replica(model_fn))) |
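|     # assign_add with MEAN aggregation adds mean(0.0, 1.0) = 0.5 to the initial |
|     # value of 1.0. |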
| self.assertEqual(1.5, self.evaluate(mirrored_var)) |
| |
| def testAssignAddMirroredVarReplicaContextWithSingleValue(self, distribution): |
| def var_fn(): |
| return variable_v1.VariableV1( |
| 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) |
| |
| with distribution.scope(): |
| mirrored_var = distribution.extended.call_for_each_replica(var_fn) |
| self.assertTrue(distribute_utils.is_mirrored(mirrored_var)) |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(1.0, self.evaluate(mirrored_var)) |
| |
| def model_fn(): |
| return mirrored_var.assign_add(5.0) |
| |
| self.evaluate(distribution.experimental_local_results( |
| distribution.extended.call_for_each_replica(model_fn))) |
| self.assertEqual(6.0, self.evaluate(mirrored_var)) |
| |
| def testAssignSubMirroredVarCrossDeviceContext(self, distribution): |
| def var_fn(): |
| return variable_v1.VariableV1(5.0, name="foo") |
| |
| with distribution.scope(): |
| mirrored_var = distribution.extended.call_for_each_replica(var_fn) |
| self.assertTrue(distribute_utils.is_mirrored(mirrored_var)) |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(5.0, self.evaluate(mirrored_var)) |
| mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0)) |
| self.assertEqual(3.0, mirrored_var_result) |
| self.assertEqual( |
| 3.0, |
| self.evaluate( |
| distribution.experimental_local_results(mirrored_var)[0])) |
| self.assertEqual( |
| 3.0, |
| self.evaluate( |
| distribution.experimental_local_results(mirrored_var)[1])) |
| self.assertEqual( |
| distribution.extended.worker_devices[0], mirrored_var._devices[0]) |
| self.assertEqual( |
| distribution.extended.worker_devices[1], mirrored_var._devices[1]) |
| |
| def testAssignSubMirroredVarReplicaContext(self, distribution): |
| def var_fn(): |
| return variable_v1.VariableV1( |
| 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) |
| |
| with distribution.scope(): |
| mirrored_var = distribution.extended.call_for_each_replica(var_fn) |
| self.assertTrue(distribute_utils.is_mirrored(mirrored_var)) |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(5.0, self.evaluate(mirrored_var)) |
| |
| def model_fn(): |
| value = math_ops.cast( |
| distribute_lib.get_replica_context().replica_id_in_sync_group, |
| mirrored_var.dtype) |
| return mirrored_var.assign_sub(value) |
| |
| self.evaluate(distribution.experimental_local_results( |
| distribution.extended.call_for_each_replica(model_fn))) |
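|     # assign_sub with MEAN aggregation subtracts mean(0.0, 1.0) = 0.5 from the |
|     # initial value of 5.0. |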
| self.assertEqual(4.5, self.evaluate(mirrored_var)) |
| |
| def testAssignSubMirroredVarReplicaContextWithSingleValue(self, distribution): |
| def var_fn(): |
| return variable_v1.VariableV1( |
| 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) |
| |
| with distribution.scope(): |
| mirrored_var = distribution.extended.call_for_each_replica(var_fn) |
| self.assertTrue(distribute_utils.is_mirrored(mirrored_var)) |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(5.0, self.evaluate(mirrored_var)) |
| |
| def model_fn(): |
| return mirrored_var.assign_sub(1.0) |
| |
| self.evaluate(distribution.experimental_local_results( |
| distribution.extended.call_for_each_replica(model_fn))) |
| self.assertEqual(4.0, self.evaluate(mirrored_var)) |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ], |
| mode=["graph", "eager"])) |
| class MirroredAndSyncOnReadVariableInitializerTest(test.TestCase): |
| |
| def testAssignMirroredVarInitializer(self, distribution): |
|     # This test is not eager compatible because, in eager mode, variables are |
|     # initialized upon construction rather than when the initialization op is |
|     # run. |
| with context.graph_mode(): |
| def var_fn(): |
| v = variable_v1.VariableV1(1.0, name="foo") |
| return v |
| |
| with distribution.scope(): |
| mirrored_var = distribution.extended.call_for_each_replica(var_fn) |
| self.assertTrue(distribute_utils.is_mirrored(mirrored_var)) |
| self.assertFalse(self.evaluate(mirrored_var.is_initialized())) |
| self.evaluate(mirrored_var.initializer) |
| self.assertTrue(self.evaluate(mirrored_var.is_initialized())) |
| |
| def testAssignReplicaLocalVarInitializer(self, distribution): |
|     # This test is not eager compatible because, in eager mode, variables are |
|     # initialized upon construction rather than when the initialization op is |
|     # run. |
| with context.graph_mode(): |
| def model_fn(): |
| v_sum = variable_v1.VariableV1( |
| 1.0, |
| synchronization=variable_scope.VariableSynchronization.ON_READ, |
| aggregation=variable_scope.VariableAggregation.SUM) |
| self.assertTrue(distribute_utils.is_sync_on_read(v_sum)) |
| return v_sum |
| |
| with distribution.scope(): |
| sync_on_read_var = distribution.extended.call_for_each_replica( |
| model_fn) |
| self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var)) |
| self.assertFalse(self.evaluate(sync_on_read_var.is_initialized())) |
| self.evaluate(sync_on_read_var.initializer) |
| self.assertTrue(self.evaluate(sync_on_read_var.is_initialized())) |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ], |
| mode=["graph", "eager"])) |
| class SyncOnReadVariableAssignTest(test.TestCase): |
| |
| def testAssignReplicaLocalVarSumAggregation(self, distribution): |
| def model_fn(): |
| v_sum = variable_v1.VariableV1( |
| 1.0, |
| synchronization=variable_scope.VariableSynchronization.ON_READ, |
| aggregation=variable_scope.VariableAggregation.SUM) |
| return v_sum |
| |
| with distribution.scope(): |
| sync_on_read_var = distribution.extended.call_for_each_replica(model_fn) |
| self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var)) |
| self.evaluate(variables.global_variables_initializer()) |
|       # Each replica has a value of 1.0 assigned to it in replica context. |
|       # When we read the value using `read_var` we should see the SUM of the |
|       # values across all the replicas. |
| self.assertEqual(2.0, self.evaluate( |
| distribution.extended.read_var(sync_on_read_var))) |
| # Assigning 6.0 in cross replica context will assign a value of |
| # 6.0/num_replicas to each replica. |
| tlv_ops = sync_on_read_var.assign(6.0) |
| self.evaluate(tlv_ops) |
|       # On reading the sync-on-read variable we should get the assigned value |
|       # back, since the per-replica values are summed before being returned by |
|       # `read_var`. |
| self.assertEqual(6.0, self.evaluate( |
| distribution.extended.read_var(sync_on_read_var))) |
| |
| def testAssignReplicaLocalVarMeanAggregation(self, distribution): |
| def model_fn(): |
| v_sum = variable_v1.VariableV1( |
| 1.0, |
| synchronization=variable_scope.VariableSynchronization.ON_READ, |
| aggregation=variable_scope.VariableAggregation.MEAN) |
| return v_sum |
| |
| with distribution.scope(): |
| sync_on_read_var = distribution.extended.call_for_each_replica(model_fn) |
| self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var)) |
| self.evaluate(variables.global_variables_initializer()) |
|       # Each replica has a value of 1.0 assigned to it in replica context. |
|       # When we read the value using `read_var` we should see the MEAN of the |
|       # values on all replicas, which equals the value assigned in replica |
|       # context. |
| self.assertEqual(1.0, self.evaluate( |
| distribution.extended.read_var(sync_on_read_var))) |
| tlv_ops = sync_on_read_var.assign(6.0) |
| self.evaluate(tlv_ops) |
|       # On reading the sync-on-read variable we should get the MEAN of all the |
|       # values, which is equal to the value assigned. |
| self.assertEqual(6.0, self.evaluate( |
| distribution.extended.read_var(sync_on_read_var))) |
| |
| |
| class MockModel(object): |
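|   """Simple callable test model holding one or two variables.""" |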
| |
| def __init__(self, two_variables=False): |
| self.variables = [] |
| self.variables.append(variable_v1.VariableV1(1.25, name="dummy_var1")) |
| if two_variables: |
| self.variables.append(variable_v1.VariableV1(2.0, name="dummy_var2")) |
| |
| def __call__(self, factor=2): |
| x = factor * self.variables[0] |
| if len(self.variables) > 1: |
| x += self.variables[1] |
| return x |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| combinations.NamedDistribution( |
| "Mirrored", |
| # pylint: disable=g-long-lambda |
| lambda: mirrored_strategy.MirroredStrategy( |
| devices=mirrored_strategy.all_local_devices(), |
| cross_device_ops=cross_device_ops_lib.ReductionToOneDevice( |
| ), |
| ), |
| required_gpus=1) |
| ], |
| mode=["graph"])) |
| class MultiWorkerMirroredStrategyTest( |
| multi_worker_test_base.MultiWorkerTestBase, |
| strategy_test_lib.DistributionTestBase): |
| |
| def _configure_distribution_strategy(self, distribution): |
| cluster_spec = server_lib.ClusterSpec({ |
| "worker": ["/job:worker/task:0", "/job:worker/task:1"] |
| }) |
| distribution.configure(cluster_spec=cluster_spec) |
| |
| def test_num_replicas_in_sync(self, distribution): |
| self._configure_distribution_strategy(distribution) |
|     # The expected value is the total number of GPUs across the two workers |
|     # specified in the cluster spec. |
| self.assertEqual(context.num_gpus() * 2, distribution.num_replicas_in_sync) |
| |
| def testMinimizeLossGraph(self, distribution): |
| self._configure_distribution_strategy(distribution) |
| self._test_minimize_loss_graph(distribution, learning_rate=0.05) |
| |
| def testDeviceScope(self, distribution): |
| """Test the device scope of multi-worker MirroredStrategy.""" |
| self._configure_distribution_strategy(distribution) |
| with distribution.scope(): |
| a = constant_op.constant(1.) |
| with ops.device("/cpu:0"): |
| b = constant_op.constant(1.) |
| self.assertEqual(a.device, "/job:worker/task:0") |
| self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") |
| |
| def testMakeInputFnIteratorWithDataset(self, distribution): |
| self._configure_distribution_strategy(distribution) |
| dataset_fn = lambda: dataset_ops.Dataset.range(100) |
| num_gpus = context.num_gpus() |
| num_workers = 2 |
| |
| expected_values = [[i+j for j in range(num_gpus)] * num_workers |
| for i in range(0, 100, num_gpus)] |
| |
| with context.graph_mode(), self.cached_session() as sess: |
| # `expected_input_pipeline_id` is None because the input_fn will be called |
| # multiple times, each with a different input_pipeline_id. |
| input_fn = self._input_fn_to_test_input_context( |
| dataset_fn, |
| expected_num_replicas_in_sync=num_workers*num_gpus, |
| expected_num_input_pipelines=num_workers, |
| expected_input_pipeline_id=None) |
| iterator = distribution.make_input_fn_iterator(input_fn) |
| self._test_input_fn_iterator( |
| iterator, distribution.extended.worker_devices, expected_values, sess) |
| |
| def testMakeInputFnIteratorWithCallable(self, distribution): |
| self._configure_distribution_strategy(distribution) |
| def fn(): |
| dataset = dataset_ops.Dataset.range(100) |
| it = dataset_ops.make_one_shot_iterator(dataset) |
| return it.get_next |
| num_gpus = context.num_gpus() |
| num_workers = 2 |
| |
| expected_values = [] |
| for i in range(0, 100, num_gpus): |
| expected_values.append([i+j for j in range(num_gpus)] * num_workers) |
| |
| with context.graph_mode(), self.cached_session() as sess: |
| # `expected_input_pipeline_id` is None because the input_fn will be called |
| # multiple times, each with a different input_pipeline_id. |
| input_fn = self._input_fn_to_test_input_context( |
| fn, |
| expected_num_replicas_in_sync=num_workers*num_gpus, |
| expected_num_input_pipelines=num_workers, |
| expected_input_pipeline_id=None) |
| iterator = distribution.make_input_fn_iterator(input_fn) |
| self._test_input_fn_iterator( |
| iterator, distribution.extended.worker_devices, expected_values, sess, |
| test_reinitialize=False, ignore_order=True) |
| |
| def testUpdateConfigProto(self, distribution): |
| distribution.configure(cluster_spec={"worker": ["fake1", "fake2"]}) |
| |
| config_proto = config_pb2.ConfigProto() |
| new_config = distribution.update_config_proto(config_proto) |
| |
| # Verify isolate_session_state |
| self.assertTrue(new_config.isolate_session_state) |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| combinations.NamedDistribution( |
| "Mirrored", |
| # pylint: disable=g-long-lambda |
| lambda: mirrored_strategy.MirroredStrategy( |
| devices=["/job:worker/task:0/gpu:{}".format( |
| i) for i in range(context.num_gpus())]), |
| required_gpus=1) |
| ], |
| mode=["graph"])) |
| class RemoteSingleWorkerMirroredStrategyGraph( |
| multi_worker_test_base.SingleWorkerTestBaseGraph, |
| strategy_test_lib.RemoteSingleWorkerMirroredStrategyBase): |
| |
| def _get_num_gpus(self): |
| return context.num_gpus() |
| |
| def testNumReplicasInSync(self, distribution): |
| self._testNumReplicasInSync(distribution) |
| |
| def testMinimizeLoss(self, distribution): |
| self._testMinimizeLoss(distribution) |
| |
| def testDeviceScope(self, distribution): |
| self._testDeviceScope(distribution) |
| |
| def testMakeInputFnIteratorWithDataset(self, distribution): |
| self._testMakeInputFnIteratorWithDataset(distribution) |
| |
| def testMakeInputFnIteratorWithCallable(self, distribution): |
| self._testMakeInputFnIteratorWithCallable(distribution) |
| |
| |
| class MultiWorkerMirroredStrategyTestWithChief( |
| multi_worker_test_base.MultiWorkerTestBase, |
| strategy_test_lib.DistributionTestBase): |
| |
| @classmethod |
| def setUpClass(cls): |
| """Create a local cluster with 2 workers and 1 chief.""" |
| cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( |
| num_workers=2, num_ps=0, has_chief=True) |
| cls._default_target = "grpc://" + cls._cluster_spec["chief"][0] |
| |
| def _make_cross_device_ops(self): |
| return cross_device_ops_lib.ReductionToOneDevice() |
| |
| def testMinimizeLossGraph(self): |
| with context.graph_mode(): |
| strategy = mirrored_strategy.MirroredStrategy( |
| cross_device_ops=self._make_cross_device_ops()) |
| strategy.configure(cluster_spec=self._cluster_spec) |
| self._test_minimize_loss_graph(strategy, learning_rate=0.05) |
| |
| def testMinimizeLossGraphMirroredStrategy(self): |
| with context.graph_mode(): |
| strategy = mirrored_strategy.MirroredStrategy( |
| mirrored_strategy.all_local_devices(), |
| cross_device_ops=self._make_cross_device_ops()) |
| strategy.configure(cluster_spec=self._cluster_spec) |
| self._test_minimize_loss_graph(strategy, learning_rate=0.05) |
| |
| def testMinimizeLossGraphMirroredStrategyWithOneNode(self): |
| with context.graph_mode(): |
| cluster_spec = {} |
| cluster_spec["chief"] = self._cluster_spec["chief"] |
| tf_config = {"cluster": cluster_spec} |
| with test.mock.patch.dict("os.environ", |
| {"TF_CONFIG": json.dumps(tf_config)}): |
| strategy = mirrored_strategy.MirroredStrategy() |
| if context.num_gpus() == 0: |
| self.assertIsInstance(strategy.extended._cross_device_ops, |
| cross_device_ops_lib.ReductionToOneDevice) |
| self.skipTest("b/130551176, run the following once fixed.") |
| self._test_minimize_loss_graph(strategy, learning_rate=0.05) |
| |
| def testInitializeFromTFConfig(self): |
| with context.graph_mode(): |
| tf_config = {"cluster": self._cluster_spec} |
| with test.mock.patch.dict("os.environ", |
| {"TF_CONFIG": json.dumps(tf_config)}): |
| strategy = mirrored_strategy.MirroredStrategy( |
| cross_device_ops=self._make_cross_device_ops()) |
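|         # The TF_CONFIG cluster has three tasks (one chief and two workers), each |
|         # contributing max(num_gpus, 1) local replicas. |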
| self.assertEqual( |
| max(context.num_gpus(), 1) * 3, strategy.num_replicas_in_sync) |
| |
| def testSummaryForReplicaZeroOnly(self): |
| with context.graph_mode(): |
| strategy = mirrored_strategy.MirroredStrategy( |
| mirrored_strategy.all_local_devices(), |
| cross_device_ops=self._make_cross_device_ops()) |
| strategy.configure(cluster_spec=self._cluster_spec) |
| self._test_summary_for_replica_zero_only(strategy) |
| |
| |
| class MirroredVariableStopGradientTest(test.TestCase, parameterized.TestCase): |
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_one_gpu, |
| ], |
| mode=["graph"])) |
| def testMirroredVariableAsStopGradient(self, distribution): |
| with distribution.scope(): |
| inp = constant_op.constant(1.0) |
| x = variables.Variable(1.0) |
| y = inp*x |
| grads = gradients.gradients(x, y, stop_gradients=x) |
| self.assertIsNone(grads[0]) |
| |
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ], |
| mode=["eager"])) |
| class FunctionTest(test.TestCase, parameterized.TestCase): |
| |
| def testBackwardFunctionDevicePlacement(self, distribution): |
| with distribution.scope(): |
| w = variable_v1.VariableV1([1.5], name="w") |
| b = variable_v1.VariableV1([0.5], name="b") |
| |
| @def_function.function |
| def forward(x, w, b): |
| return x * w + b |
| |
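|     # `forward` is traced once outside the strategy scope; each replica below |
|     # reuses this concrete function. |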
| x = array_ops.identity([1.0], name="x_useless") |
| concrete_forward = forward.get_concrete_function(x, w._primary, b._primary) |
| |
| with distribution.scope(): |
| |
| def replica_fn(): |
| with backprop.GradientTape() as t: |
| x = array_ops.identity([1.0], name="x") |
| loss = concrete_forward(x, w._get(), b._get()) - [1.0] |
| return t.gradient(loss, [w, b]) |
| |
| def step_fn(): |
| return distribution.run(replica_fn) |
| |
| context.enable_run_metadata() |
| g1, g2 = step_fn() |
| run_metadata = context.export_run_metadata() |
| context.disable_run_metadata() |
| self.assertEqual(self.evaluate(g1._primary), 1.0) |
| self.assertEqual(self.evaluate(g2._primary), 1.0) |
| |
| # Verify that this node runs on both devices. |
| node_name = "gradients_mul_grad_mul_1_x" |
| devices_for_this_node = set() |
| for partition_graph in run_metadata.partition_graphs: |
| for node in partition_graph.node: |
| if node.name == node_name: |
| devices_for_this_node.add(node.device) |
| devices = [device_util.resolve("/device:GPU:0"), |
| device_util.resolve("/device:CPU:0")] |
| self.assertSetEqual(devices_for_this_node, set(devices)) |
| |
|   def testFunctionPreservesAutoGraph(self, distribution): |
| def f(): |
| self.assertTrue(converter_testing.is_inside_generated_code()) |
| return 1 |
| |
| with distribution.scope(): |
| |
| @def_function.function |
| def replica_fn(): |
| return f() |
| |
| distribution.run(replica_fn) |
| |
| def testPreserveTracebackFiltering(self, distribution): |
| traceback_utils.disable_traceback_filtering() |
| self.assertFalse(traceback_utils.is_traceback_filtering_enabled()) |
| |
| def f(): |
| self.assertFalse(traceback_utils.is_traceback_filtering_enabled()) |
| |
| distribution.run(f) |
| |
| |
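| # Returns the current replica id as a tensor, wrapping a Python int id in a |
| # constant so the result can always be evaluated. |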
| def _replica_id(): |
| replica_id = distribute_lib.get_replica_context().replica_id_in_sync_group |
| if not isinstance(replica_id, ops.Tensor): |
| replica_id = constant_op.constant(replica_id) |
| return array_ops.identity(replica_id) |
| |
| |
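| # Returns the current replica id as a Python integer, resolving the constant |
| # value if the id is a tensor. |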
| def _replica_id_as_int(): |
| replica_id = distribute_lib.get_replica_context().replica_id_in_sync_group |
| if isinstance(replica_id, ops.Tensor): |
| replica_id = tensor_util.constant_value(replica_id) |
| return replica_id |
| |
| |
| if __name__ == "__main__": |
| # TODO(b/172304955) |
| test_util.main(config_logical_devices=False) |