| # Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for common methods in strategy classes.""" |
| |
| from absl.testing import parameterized |
| |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.distribute import combinations |
| from tensorflow.python.distribute import distribute_lib |
| from tensorflow.python.distribute import multi_worker_test_base |
| from tensorflow.python.distribute import reduce_util |
| from tensorflow.python.distribute import strategy_combinations |
| from tensorflow.python.distribute import strategy_test_lib |
| from tensorflow.python.distribute import test_util |
| from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import indexed_slices |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import test |
| from tensorflow.python.util import nest |
| |
| |
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ] + strategy_combinations.all_strategies,
        mode=['eager']))
class StrategyTest(test.TestCase, parameterized.TestCase):
  """Tests basic `strategy.run` behaviors across all strategies."""

  def testCaptureReplicaId(self, strategy):
    """Runs a replica-id-reading fn from inside another tf.function."""
    m = {}

    @def_function.function
    def f():
      # Reads the replica id from the active replica context.
      return distribute_lib.get_replica_context().replica_id_in_sync_group

    @def_function.function
    def g():
      # Make g() a stateful function so it's traced twice.
      if m.get('v', None) is None:
        m['v'] = variables.Variable(0.)
      return strategy.run(f)

    g()

  def testMergeCallInitScope(self, strategy):
    """`merge_call` works when invoked from within an `init_scope`."""
    with strategy.scope():

      @def_function.function
      def fn():

        def merge_fn(unused_strat):

          # Cross-replica value; each replica adds 1 to it below.
          y = constant_op.constant(11)
          return y

        def replica_fn():

          with ops.init_scope():
            y = distribute_lib.get_replica_context().merge_call(merge_fn)
            z = y + 1
          return z

        return strategy.run(replica_fn)

      # Every local replica should see 11 + 1 = 12.
      result = strategy.experimental_local_results(fn())
      self.assertAllClose(result, [12] * _get_num_replicas_per_client(strategy))
| |
| |
@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.tpu_strategy
        ],
        mode=['graph', 'eager']))
class StrategyLocalResultTest(test.TestCase):
  """Checks `experimental_local_results` on assorted `run` output structures.

  Every test runs a step function under the distribution, gathers the
  per-replica outputs via `experimental_local_results`, and verifies the
  resulting tuple, which has one entry per local replica.
  """

  def _run_and_fetch(self, distribution, step_fn):
    # Run the step under the distribution scope and evaluate the tuple of
    # local per-replica results.
    with distribution.scope():
      per_replica = distribution.run(step_fn)
      return self.evaluate(
          distribution.experimental_local_results(per_replica))

  def testLocalResultForDictionary(self, distribution):

    @def_function.function
    def step():
      return {'a': constant_op.constant(1.), 'b': constant_op.constant(2.)}

    outputs = self._run_and_fetch(distribution, step)
    self.assertEqual(outputs, ({'a': 1., 'b': 2.}, {'a': 1., 'b': 2.}))

  def testLocalResultForList(self, distribution):

    @def_function.function
    def step():
      return [constant_op.constant(1.), constant_op.constant(2.)]

    outputs = self._run_and_fetch(distribution, step)
    self.assertEqual(outputs, ([1., 2.], [1., 2.]))

  def testLocalResultForTuple(self, distribution):

    @def_function.function
    def step():
      return (constant_op.constant(1.), constant_op.constant(2.),
              constant_op.constant(3.))

    outputs = self._run_and_fetch(distribution, step)
    self.assertEqual(outputs, ((1., 2., 3.), (1., 2., 3.)))

  def testLocalResultForNestedStruct(self, distribution):

    @def_function.function
    def step():
      return ({'a': constant_op.constant(1.),
               'b': constant_op.constant(2.)},
              {'a': constant_op.constant(4.),
               'b': constant_op.constant(6.)})

    outputs = self._run_and_fetch(distribution, step)
    # Each replica returns the same two-dict structure.
    per_replica_expectation = ({'a': 1., 'b': 2.}, {'a': 4., 'b': 6.})
    self.assertEqual(outputs,
                     (per_replica_expectation, per_replica_expectation))

  def testLocalResultForNestedStructWithoutTensor(self, distribution):

    @def_function.function
    def step():
      # Plain python floats, no tensors at all.
      return {'a': 1., 'b': 2.}

    outputs = self._run_and_fetch(distribution, step)
    self.assertIsInstance(outputs, tuple)
    self.assertAllEqual(outputs, ({'a': 1., 'b': 2.}, {'a': 1., 'b': 2.}))

  def testLocalResultForScalarValue(self, distribution):

    @def_function.function
    def step():
      return distribution.extended._get_local_replica_id(
          distribute_lib.get_replica_context().replica_id_in_sync_group)

    outputs = self._run_and_fetch(distribution, step)
    self.assertIsInstance(outputs, tuple)
    # Two local replicas with ids 0 and 1.
    self.assertEqual(outputs, (0, 1))

  def testLocalResultForDictionaryDifferentReplicas(self, distribution):

    @def_function.function
    def step():
      replica_id = distribution.extended._get_local_replica_id(
          distribute_lib.get_replica_context().replica_id_in_sync_group)
      return {'a': math_ops.cast(replica_id + 1, dtype=float),
              'b': math_ops.cast(replica_id + 2, dtype=float)}

    outputs = self._run_and_fetch(distribution, step)
    self.assertAllEqual(outputs, ({'a': 1., 'b': 2.}, {'a': 2., 'b': 3.}))

  def testLocalResultForTensor(self, distribution):

    @def_function.function
    def step():
      return constant_op.constant([2., 3.])

    outputs = self._run_and_fetch(distribution, step)
    self.assertAllEqual(outputs, ([2., 3.], [2., 3.]))
| |
| |
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ] + strategy_combinations.all_strategies,
        mode=['eager']))
class ReduceTest(test.TestCase, parameterized.TestCase):
  """Tests `strategy.reduce` inside and outside of a strategy scope."""

  def testBasic(self, strategy):
    """SUM-reduces a per-replica value of ones, eagerly and in a function."""
    distributed_ones = strategy.experimental_distribute_values_from_function(
        lambda _: array_ops.ones((), dtypes.float32))

    def reduce_sum():
      return strategy.reduce(
          reduce_util.ReduceOp.SUM, value=distributed_ones, axis=None)

    reduce_sum_graph = def_function.function(reduce_sum)
    expected = 1.0 * strategy.num_replicas_in_sync

    # Run reduce under the strategy scope to explicitly enter
    # strategy default_device scope.
    with strategy.scope():
      self.assertEqual(reduce_sum().numpy(), expected)
      self.assertEqual(reduce_sum_graph().numpy(), expected)

    # Run reduce without a strategy scope to implicitly enter
    # strategy default_device scope.
    self.assertEqual(reduce_sum().numpy(), expected)
    self.assertEqual(reduce_sum_graph().numpy(), expected)

  def testAxis(self, strategy):
    """Reduces over the per-replica tensor axis as well as across replicas."""

    @def_function.function
    def constant_pair():
      return constant_op.constant([1., 2.])

    per_replica = strategy.run(constant_pair)

    # MEAN over axis 0 and replicas: mean of all the 1s and 2s.
    self.assertEqual(
        1.5, strategy.reduce(reduce_util.ReduceOp.MEAN, per_replica, axis=0))
    # SUM over axis 0 and replicas: (1 + 2) per replica.
    self.assertEqual(
        3 * strategy.num_replicas_in_sync,
        strategy.reduce(reduce_util.ReduceOp.SUM, per_replica, axis=0))
| |
| |
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.default_strategy,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
            strategy_combinations.tpu_strategy,
            strategy_combinations.tpu_strategy_packed_var,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
        ],
        update_fn=['assign', 'assign_add', 'assign_sub'],
        tf_function=[True, False],
        mode=['eager']))
class ReplicaCtxUpdateTest(test.TestCase, parameterized.TestCase):
  """Tests `ReplicaContext._update` with tensor and python-literal values."""

  def testDenseUpdate(self, strategy, tf_function, update_fn):
    """Applies `update_fn` twice through `_update` and checks the result."""
    if strategy_test_lib.is_tpu_strategy(strategy) and (not tf_function):
      self.skipTest('Skip TPUStrategy + eager combination.')
    with strategy.scope():
      distributed_variable1 = variables.Variable(5.0)

    def replica_fn():
      replica_context = distribute_lib.get_replica_context()

      def apply_update(var, update_value):
        # Dispatch to var.assign / var.assign_add / var.assign_sub by name.
        return getattr(var, update_fn)(update_value)

      # First update with a tensor value, then with a python literal.
      replica_context._update(
          distributed_variable1, apply_update, args=(array_ops.constant(2.),))
      replica_context._update(
          distributed_variable1, apply_update, args=(1.,))

    if tf_function:
      replica_fn = def_function.function(replica_fn)
    strategy.run(replica_fn)

    # Starting from 5.0: assign -> 2 then 1; add -> 7 then 8; sub -> 3 then 2.
    expected_result = {'assign': 1., 'assign_add': 8., 'assign_sub': 2.}
    self.assertAllEqual(
        strategy.experimental_local_results(distributed_variable1),
        [expected_result[update_fn]] * _get_num_replicas_per_client(strategy))
| |
| |
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.tpu_strategy,
        ] + strategy_combinations.strategies_minus_tpu,
        tf_function=[combinations.tf_function, combinations.no_tf_function],
        mode=['eager']))
class ReplicaCtxAllReduceTest(test.TestCase, parameterized.TestCase):
  """Tests `StrategyExtended._replica_ctx_all_reduce` on various inputs."""

  def testDense(self, strategy, tf_function):
    """All-reduces a scalar dense tensor across replicas."""
    if (strategy_test_lib.is_tpu_strategy(strategy) and
        tf_function is combinations.no_tf_function):
      self.skipTest('Skip TPUStrategy + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = array_ops.identity(1.0)
        reduced = strategy.extended._replica_ctx_all_reduce(
            reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
    # Each replica contributes 1.0, so the SUM equals the replica count.
    self.assertEqual(got, 1.0 * strategy.num_replicas_in_sync)

  def testSparse(self, strategy, tf_function):
    """All-reduces an IndexedSlices value across replicas."""
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = indexed_slices.IndexedSlices(
            values=array_ops.identity([[1.0]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1]))
        reduced = strategy.extended._replica_ctx_all_reduce(
            reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
    # All replicas write to index 0, so the summed slice scales with the
    # replica count; comparison is done on the densified tensors.
    expect = indexed_slices.IndexedSlices(
        values=array_ops.identity([[1.0 * strategy.num_replicas_in_sync]]),
        indices=array_ops.identity([0]),
        dense_shape=array_ops.identity([5, 1]))
    self.assertAllEqual(
        ops.convert_to_tensor(got), ops.convert_to_tensor(expect))

  def testNestedInput(self, strategy, tf_function):
    """All-reduces a tuple mixing dense tensors and IndexedSlices."""
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = (array_ops.identity(1.0),
                 indexed_slices.IndexedSlices(
                     values=array_ops.identity([[1.0]]),
                     indices=array_ops.identity([0]),
                     dense_shape=array_ops.identity([5, 1])),
                 array_ops.identity(2.0),
                 indexed_slices.IndexedSlices(
                     values=array_ops.identity([[2.0]]),
                     indices=array_ops.identity([1]),
                     dense_shape=array_ops.identity([5, 1])))
        reduced = strategy.extended._replica_ctx_all_reduce(
            reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
    # Every element of the structure is SUM-reduced independently.
    expect = (1.0 * strategy.num_replicas_in_sync,
              indexed_slices.IndexedSlices(
                  values=array_ops.identity(
                      [[1.0 * strategy.num_replicas_in_sync]]),
                  indices=array_ops.identity([0]),
                  dense_shape=array_ops.identity([5, 1])),
              2.0 * strategy.num_replicas_in_sync,
              indexed_slices.IndexedSlices(
                  values=array_ops.identity(
                      [[2.0 * strategy.num_replicas_in_sync]]),
                  indices=array_ops.identity([1]),
                  dense_shape=array_ops.identity([5, 1])))

    self.assertAllClose(
        nest.map_structure(ops.convert_to_tensor, got),
        nest.map_structure(ops.convert_to_tensor, expect))

  def testSyncOnReadVariableInput(self, strategy, tf_function):
    """All-reduces a SyncOnRead variable read inside the replica fn."""
    if (not strategy_test_lib.is_mirrored_strategy(strategy) and
        not strategy_test_lib.is_multi_worker_mirrored_strategy(strategy) and
        not strategy_test_lib.is_tpu_strategy(strategy)):
      self.skipTest('Skip strategies not using SyncOnReadVariables.')
    if (strategy_test_lib.is_tpu_strategy(strategy) and
        tf_function is combinations.no_tf_function):
      self.skipTest('Skip TPUStrategy + eager combination.')
    if (strategy_test_lib.is_multi_worker_mirrored_strategy(strategy) and
        tf_function is combinations.tf_function):
      self.skipTest('Skip MWMS + graph combination until b/228512201 is fixed.')

    with strategy.scope():
      var = variables.Variable(
          0.0,
          synchronization=variables.VariableSynchronization.ON_READ,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    @tf_function
    def replica_fn():
      replica_context = distribute_lib.get_replica_context()
      replica_id = replica_context.replica_id_in_sync_group
      # Each replica writes replica_id * 3 into its local copy of `var`.
      var.assign(math_ops.cast(replica_id, dtype=float) * 3.0)

      return replica_context.all_reduce(reduce_util.ReduceOp.SUM, var)

    if strategy_test_lib.is_multi_worker_mirrored_strategy(strategy):
      client_local_replica_num = strategy.extended._num_devices_per_worker
    else:
      client_local_replica_num = strategy.num_replicas_in_sync

    # NOTE(review): despite the name, this is the global replica count.
    workers_num = strategy.num_replicas_in_sync
    # Sum of replica_id * 3 over all global replica ids.
    expected_sum = sum(range(workers_num)) * 3.0

    # Expand the values on each replica if multiple devices are used; otherwise
    # simply read the value of the Tensor.
    result = strategy.run(replica_fn)
    if hasattr(result, 'values'):
      result = result.values
    result = nest.flatten(result)

    # Iterate through all replicas and verify the reduce sum result.
    for i in range(client_local_replica_num):
      self.assertEqual(result[i].numpy(), expected_sum)
| |
| |
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.tpu_strategy,
        ] + strategy_combinations.strategies_minus_tpu,
        tf_function=[combinations.tf_function, combinations.no_tf_function],
        mode=['eager']))
class AllReduceTest(test.TestCase, parameterized.TestCase):
  """Tests `ReplicaContext.all_reduce` with dense, sparse and nested values."""

  def _identity_slices(self, values, indices):
    # Builds an IndexedSlices of dense shape [5, 1] backed by identity ops.
    return indexed_slices.IndexedSlices(
        values=array_ops.identity(values),
        indices=array_ops.identity(indices),
        dense_shape=array_ops.identity([5, 1]))

  def testDense(self, strategy, tf_function):
    """SUM all-reduce of a scalar equals the replica count."""
    if (strategy_test_lib.is_tpu_strategy(strategy) and
        tf_function is combinations.no_tf_function):
      self.skipTest('Skip TPUStrategy + eager combination.')

    @tf_function
    def run_all_reduce():

      def replica_fn():
        return distribute_lib.get_replica_context().all_reduce(
            reduce_util.ReduceOp.SUM, array_ops.identity(1.0))

      return strategy.experimental_local_results(strategy.run(replica_fn))

    self.assertEqual(
        run_all_reduce()[0], 1.0 * strategy.num_replicas_in_sync)

  def testSparse(self, strategy, tf_function):
    """MEAN all-reduce of identical IndexedSlices leaves the value unchanged."""
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def run_all_reduce():

      def replica_fn():
        return distribute_lib.get_replica_context().all_reduce(
            reduce_util.ReduceOp.MEAN, self._identity_slices([[1.0]], [0]))

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = run_all_reduce()[0]

    # The IndexedSlices type check does not apply under TPU strategies.
    if not strategy_test_lib.is_tpu_strategy(strategy):
      self.assertIsInstance(got, indexed_slices.IndexedSlices)
    expect = self._identity_slices([[1.0]], [0])
    self.assertAllEqual(
        ops.convert_to_tensor(got), ops.convert_to_tensor(expect))

  def testSparseTuple(self, strategy, tf_function):
    """SUM all-reduce of a list of IndexedSlices reduces each element."""
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def run_all_reduce():

      def replica_fn():
        pair = [
            self._identity_slices([[1.0]], [0]),
            self._identity_slices([[2.0]], [0]),
        ]
        return distribute_lib.get_replica_context().all_reduce(
            reduce_util.ReduceOp.SUM, pair)

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = run_all_reduce()[0]

    # The IndexedSlices type check does not apply under TPU strategies.
    if not strategy_test_lib.is_tpu_strategy(strategy):
      for element in got:
        self.assertIsInstance(element, indexed_slices.IndexedSlices)
    replicas = strategy.num_replicas_in_sync
    expect = [
        self._identity_slices([[1.0 * replicas]], [0]),
        self._identity_slices([[2.0 * replicas]], [0]),
    ]
    self.assertAllEqual(
        nest.map_structure(ops.convert_to_tensor, got),
        nest.map_structure(ops.convert_to_tensor, expect))

  def testNestedInput(self, strategy, tf_function):
    """SUM all-reduce of a tuple mixing dense tensors and IndexedSlices."""
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def run_all_reduce():

      def replica_fn():
        mixed = (array_ops.identity(1.0),
                 self._identity_slices([[1.0]], [0]),
                 array_ops.identity(2.0),
                 self._identity_slices([[2.0]], [1]))
        return distribute_lib.get_replica_context().all_reduce(
            reduce_util.ReduceOp.SUM, mixed)

      return strategy.experimental_local_results(strategy.run(replica_fn))

    replicas = strategy.num_replicas_in_sync
    expect = (1.0 * replicas,
              self._identity_slices([[1.0 * replicas]], [0]),
              2.0 * replicas,
              self._identity_slices([[2.0 * replicas]], [1]))

    self.assertAllClose(
        nest.map_structure(ops.convert_to_tensor, run_all_reduce()[0]),
        nest.map_structure(ops.convert_to_tensor, expect))
| |
| |
def _make_indexed_slices(values, indices, dense_shape):
  """Wraps the given python values into a constant-backed IndexedSlices."""
  return indexed_slices.IndexedSlices(
      values=constant_op.constant(values),
      indices=constant_op.constant(indices),
      dense_shape=constant_op.constant(dense_shape))
| |
| |
def _get_num_replicas_per_client(strategy):
  """Returns the number of replicas owned by the current client/worker."""
  if not isinstance(strategy, CollectiveAllReduceStrategy):
    return strategy.num_replicas_in_sync
  # For multi-worker strategies, use the local accelerator count (at least
  # one, to account for CPU-only workers).
  num_accelerators = nest.flatten(
      strategy.cluster_resolver.num_accelerators())[0]
  return max(num_accelerators, 1)
| |
| |
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
        mode=['eager']))
class DistributedCollectiveAllReduceStrategyTest(
    strategy_test_lib.DistributionTestBase,
    parameterized.TestCase):
  """Multi-worker tests for input distribution and host/device reductions.

  All tests run with two workers ('chief' and 'worker'), one replica per
  worker, so the global replica count is 2.
  """

  def testDatasetFromFunction(self, strategy):
    """Each worker receives its own shard of the function-built dataset."""
    def dataset_fn(input_context):
      global_batch_size = 10
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      d = dataset_ops.DatasetV2.range(100).repeat().batch(batch_size)
      return d.shard(input_context.num_input_pipelines,
                     input_context.input_pipeline_id)

    # With 2 workers the per-replica batch is 5: the chief's first batch is
    # [0..4] (sum 10) and the worker's is [5..9] (sum 35).
    expected_sum_on_workers = {'chief': 10, 'worker': 35}
    input_iterator = iter(
        strategy.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(iterator):
      return strategy.experimental_local_results(iterator.get_next())

    result = run(input_iterator)
    sum_value = math_ops.reduce_sum(result)
    self.assertEqual(
        sum_value.numpy(),
        expected_sum_on_workers[multi_worker_test_base.get_task_type()])

  def testSimpleInputFromDatasetLastPartialBatch(self, strategy):
    """A trailing partial batch (6 of 8 elements) is split across workers."""
    global_batch_size = 8
    dataset = dataset_ops.DatasetV2.range(14).batch(
        global_batch_size, drop_remainder=False)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(input_iterator):
      return strategy.run(lambda x: x, args=(next(input_iterator),))

    # Let the complete batch go.
    run(input_iterator)

    # `result` is an incomplete batch
    result = run(input_iterator)
    expected_data_on_workers = {'chief': [8, 9, 10], 'worker': [11, 12, 13]}
    self.assertAllEqual(
        expected_data_on_workers[multi_worker_test_base.get_task_type()],
        result.numpy(),
    )

  def testSimpleInputFromFnLastPartialBatch(self, strategy):
    """With per-worker batch-then-shard, the partial batch is uneven."""

    def dataset_fn(input_context):
      global_batch_size = 8
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      dataset = dataset_ops.DatasetV2.range(14).batch(
          batch_size, drop_remainder=False)
      return dataset.shard(input_context.num_input_pipelines,
                           input_context.input_pipeline_id)

    input_iterator = iter(
        strategy.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(input_iterator):
      return strategy.run(lambda x: x, args=(next(input_iterator),))

    # Let the complete batch go.
    run(input_iterator)
    # `result` is an incomplete batch
    result = run(input_iterator)

    # Batching range(14) by 4 then sharding by 2 gives the chief batches
    # [0..3], [8..11] and the worker [4..7], [12, 13].
    expected_data_on_worker = {'chief': [8, 9, 10, 11], 'worker': [12, 13]}
    self.assertAllEqual(
        expected_data_on_worker[multi_worker_test_base.get_task_type()],
        result.numpy())

  def testReduceHostTensor(self, strategy):
    """`strategy.reduce` accepts a plain host tensor; SUM over 2 replicas."""
    reduced = strategy.reduce(
        reduce_util.ReduceOp.SUM, array_ops.identity(1.), axis=None)
    self.assertEqual(reduced.numpy(), 2.)

  def testReduceToHostTensor(self, strategy):
    """`reduce_to` works with a host tensor as both value and destination."""
    value = array_ops.identity(1.)
    reduced = strategy.extended.reduce_to(reduce_util.ReduceOp.SUM, value,
                                          value)
    self.assertEqual(reduced.numpy(), 2.)

  def testBatchReduceToHostTensor(self, strategy):
    """`batch_reduce_to` reduces each (value, destination) pair."""
    value = array_ops.identity(1.)
    reduced = strategy.extended.batch_reduce_to(reduce_util.ReduceOp.SUM,
                                                [(value, value),
                                                 (value, value)])
    self.assertAllEqual([2., 2.], reduced)

  def testReduceDeviceTensors(self, strategy):
    """`strategy.reduce` accepts per-replica (device) values."""
    value = strategy.run(lambda: array_ops.identity(1.))
    reduced = strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None)
    self.assertEqual(reduced.numpy(), 2.)

  def testReduceToDeviceTensors(self, strategy):
    """`reduce_to` accepts per-replica values for value and destination."""
    value = strategy.run(lambda: array_ops.identity(1.))
    reduced = strategy.extended.reduce_to(reduce_util.ReduceOp.SUM, value,
                                          value)
    self.assertEqual(reduced.numpy(), 2.)

  def testBatchReduceToDeviceTensors(self, strategy):
    """`batch_reduce_to` accepts per-replica (value, destination) pairs."""
    value = strategy.run(lambda: array_ops.identity(1.))
    reduced = strategy.extended.batch_reduce_to(reduce_util.ReduceOp.SUM,
                                                [(value, value),
                                                 (value, value)])
    self.assertAllEqual([2., 2.], reduced)

  # TODO(crccw): add a test that mixes device and host tensors after multi
  # worker strategy combinations can run on a fixed number of GPUs.
| |
| |
class StrategyClusterResolverTest(test.TestCase, parameterized.TestCase):
  """Tests the `cluster_resolver` property across strategies."""

  @combinations.generate(
      combinations.combine(
          strategy=[strategy_combinations.multi_worker_mirrored_2x1_cpu] +
          strategy_combinations.all_strategies,
          mode=['eager']))
  def testClusterResolverProperty(self, strategy):
    # CollectiveAllReduceStrategy and TPUStrategy must have a cluster resolver.
    # `None` otherwise.
    resolver = strategy.cluster_resolver
    has_resolver = (
        isinstance(strategy, CollectiveAllReduceStrategy) or
        strategy_test_lib.is_tpu_strategy(strategy))
    if not has_resolver:
      self.assertIsNone(resolver)
      return

    # The property must return the same object inside the strategy scope.
    with strategy.scope():
      self.assertIs(strategy.cluster_resolver, resolver)

    # The resolver exposes the standard ClusterResolver surface.
    for attr in ('cluster_spec', 'master', 'num_accelerators', 'task_id',
                 'task_type'):
      self.assertTrue(hasattr(resolver, attr))
    if isinstance(strategy, CollectiveAllReduceStrategy):
      self.assertEqual(resolver.task_id, 0)
      self.assertAllInSet(resolver.task_type, ['chief', 'worker'])
| |
| |
if __name__ == '__main__':
  # Delegate to the distribute test harness's main entry point.
  test_util.main()