| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for the input_lib library.""" |
| |
| import collections |
| |
| from absl.testing import parameterized |
| import numpy as np |
| |
| from tensorflow.python import tf2 |
| from tensorflow.python.data.experimental.ops import data_service_ops |
| from tensorflow.python.data.experimental.service import server_lib |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.data.ops import options as options_lib |
| from tensorflow.python.data.ops.options import AutoShardPolicy |
| from tensorflow.python.distribute import combinations |
| from tensorflow.python.distribute import device_util |
| from tensorflow.python.distribute import distribute_lib |
| from tensorflow.python.distribute import distribute_utils |
| from tensorflow.python.distribute import input_lib |
| from tensorflow.python.distribute import input_ops |
| from tensorflow.python.distribute import input_util |
| from tensorflow.python.distribute import multi_worker_util |
| from tensorflow.python.distribute import reduce_util |
| from tensorflow.python.distribute import strategy_combinations |
| from tensorflow.python.distribute import test_util |
| from tensorflow.python.distribute.v1 import input_lib as input_lib_v1 |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.eager import test |
| from tensorflow.python.framework import composite_tensor |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import extension_type |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.framework import test_util as framework_test_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import sparse_ops |
| from tensorflow.python.ops import variables |
| from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib |
| from tensorflow.python.util import nest |
| |
| |
| class DistributedIteratorTestBase(test.TestCase): |
| |
| # The passed `input_context` is used to create a sharded dataset in the |
| # between-graph case. |
| # TODO(yuefengz): rewrite the following method to make it less repetitive. |
| def _wrap_iterator(self, |
| input_type, |
| dataset_or_input_fn, |
| input_workers, |
| devices, |
| num_replicas_in_sync, |
| strategy, |
| input_context=None): |
| # The `input_context` passed in is used to shard the dataset for |
| # MultiWorkerMirroredStrategy. It doesn't apply to the in-graph case, where |
| # multiple InputContexts are needed. |
| if input_type == "input_fn": |
| self.assertIsNone( |
| input_context, |
| msg=("The `input_context` arg is only used to shard the dataset in " |
| "`MultiWorkerMirroredStrategy` when the input type is dataset.")) |
| |
| input_contexts = [] |
| for i in range(input_workers.num_workers): |
| input_contexts.append( |
| distribute_lib.InputContext( |
| # Note: `input_workers.num_workers` is always 1 in the between-graph |
| # case. |
| num_input_pipelines=input_workers.num_workers, |
| input_pipeline_id=i, |
| num_replicas_in_sync=len(devices))) |
| |
| iterator = input_lib_v1.InputFunctionIterator(dataset_or_input_fn, |
| input_workers, |
| input_contexts, strategy) |
| else: |
| iterator = input_lib_v1.DatasetIterator( |
| dataset_or_input_fn, |
| input_workers, |
| strategy, |
| num_replicas_in_sync=num_replicas_in_sync, |
| input_context=input_context) |
| return iterator |
| |
| def _wrap_dataset(self, |
| input_type, |
| dataset, |
| input_workers, |
| num_replicas_in_sync, |
| strategy, |
| input_context=None): |
| if input_type == "dataset": |
| if tf2.enabled(): |
| return input_lib.DistributedDataset( |
| input_workers, |
| strategy, |
| dataset, |
| num_replicas_in_sync=num_replicas_in_sync, |
| input_context=input_context) |
| else: |
| return input_lib_v1.DistributedDatasetV1( |
| dataset, |
| input_workers, |
| strategy, |
| num_replicas_in_sync=num_replicas_in_sync, |
| input_context=input_context) |
| else: |
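| # For the "input_fn" case, `dataset` is actually an input function taking |
| # an InputContext; the strategy builds one dataset per worker from it. |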
| return strategy.distribute_datasets_from_function(dataset) |
| |
| def _assert_iterator_values(self, |
| iterator, |
| expected_values, |
| evaluate_fn, |
| devices, |
| enable_get_next_as_optional=False): |
| actual_values = [] |
| for _ in range(len(expected_values)): |
| if enable_get_next_as_optional: |
| next_element = iterator.get_next_as_optional().get_value() |
| else: |
| next_element = iterator.get_next() |
| computed_value = evaluate_fn([ |
| distribute_utils.select_replica(r, next_element) |
| for r in range(len(devices)) |
| ]) |
| actual_values.append(computed_value) |
| for expected_value, actual_value in zip(expected_values, actual_values): |
| for expected, actual in zip(expected_value, actual_value): |
| self.assertAllEqual(expected, actual) |
| |
| def _assert_dataset_values_for_loop(self, dataset, expected_values, |
| evaluate_fn, devices): |
| actual_values = [] |
| for x in dataset: |
| computed_value = self.evaluate( |
| [distribute_utils.select_replica(r, x) for r in range(len(devices))]) |
| actual_values.append(computed_value) |
| for expected_value, actual_value in zip(expected_values, actual_values): |
| for expected, actual in zip(expected_value, actual_value): |
| self.assertAllEqual(expected, actual) |
| |
| def _test_input_iteration(self, |
| input_type, |
| api_type, |
| iteration_type, |
| dataset_or_input_fn, |
| worker_device_pairs, |
| expected_values, |
| strategy, |
| sess=None, |
| num_replicas_in_sync=None, |
| input_context=None): |
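| """Wraps the input, iterates over it, and compares against expected_values.""" |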
| if iteration_type == "for_loop" and not context.executing_eagerly(): |
| self.skipTest("unsupported test combination.") |
| |
| if api_type == "wrap_into_iterator" and iteration_type == "for_loop": |
| self.skipTest("unsupported test combination.") |
| |
| if api_type == "wrap_into_iterator" and input_type == "input_fn": |
| self.skipTest("unsupported test combination.") |
| |
| devices = nest.flatten([ds for _, ds in worker_device_pairs]) |
| input_workers = input_lib.InputWorkers(worker_device_pairs) |
| |
| if api_type == "wrap_into_iterator": |
| iterator = self._wrap_iterator( |
| input_type, |
| dataset_or_input_fn, |
| input_workers, |
| devices, |
| num_replicas_in_sync, |
| strategy, |
| input_context=input_context) |
| else: |
| # wrapping into a dataset: |
| dataset = self._wrap_dataset( |
| input_type, |
| dataset_or_input_fn, |
| input_workers, |
| num_replicas_in_sync, |
| strategy, |
| input_context=input_context) |
| |
| if ops.executing_eagerly_outside_functions(): |
| iterator = iter(dataset) |
| else: |
| if isinstance(dataset, input_lib_v1.DistributedDatasetV1): |
| iterator = dataset.make_initializable_iterator() |
| else: |
| self.skipTest("unsupported test combination") |
| |
| if isinstance(iterator, composite_tensor.CompositeTensor): |
| nest.assert_same_structure( |
| iterator, iterator._type_spec, expand_composites=True) |
| |
| if iteration_type == "get_next": |
| evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) |
| if not ops.executing_eagerly_outside_functions(): |
| evaluate(control_flow_ops.group(iterator.initializer)) |
| |
| def test_get_next(iterator): |
| self._assert_iterator_values(iterator, expected_values, evaluate, |
| devices) |
| |
| with self.assertRaises(errors.OutOfRangeError): |
| self._assert_iterator_values(iterator, expected_values, evaluate, |
| devices) |
| |
| # After re-initializing the iterator, we should be able to iterate again. |
| if not ops.executing_eagerly_outside_functions(): |
| evaluate(control_flow_ops.group(iterator.initializer)) |
| else: |
| if api_type == "wrap_into_iterator": |
| self.skipTest("unsupported test combination") |
| else: |
| iterator = iter(dataset) |
| |
| self._assert_iterator_values(iterator, expected_values, evaluate, |
| devices) |
| |
| def test_get_next_as_optional(iterator): |
| self._assert_iterator_values( |
| iterator, |
| expected_values, |
| evaluate, |
| devices, |
| enable_get_next_as_optional=True) |
| |
| next_element = iterator.get_next_as_optional() |
| self.assertFalse(self.evaluate(next_element.has_value())) |
| with self.assertRaises(errors.InvalidArgumentError): |
| self._assert_iterator_values( |
| iterator, [0], |
| evaluate, |
| devices, |
| enable_get_next_as_optional=True) |
| |
| test_get_next(iterator) |
| |
| # re-initializing the iterator |
| if not tf2.enabled(): |
| # TODO(yuefengz): we should split this function. |
| return |
| else: |
| if api_type == "wrap_into_iterator": |
| return |
| else: |
| iterator = iter(dataset) |
| |
| test_get_next_as_optional(iterator) |
| |
| if iteration_type == "for_loop" and context.executing_eagerly(): |
| self._assert_dataset_values_for_loop(dataset, expected_values, |
| self.evaluate, devices) |
| |
| def _create_dataset_or_input_fn(self, input_type, input_fn): |
| if input_type == "input_fn": |
| return input_fn |
| else: |
| return input_fn(distribute_lib.InputContext()) |
| |
| |
| class DistributedIteratorTest(DistributedIteratorTestBase, |
| parameterized.TestCase): |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu |
| ])) |
| def testMultiDeviceIterInitialize(self, distribution): |
| if tf2.enabled(): |
| self.skipTest("Only V1 is supported.") |
| worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", |
| "/device:CPU:0"])] |
| dataset_fn = lambda _: dataset_ops.DatasetV1.range(10) |
| |
| input_workers = input_lib.InputWorkers(worker_device_pairs) |
| |
| dist_dataset = input_util.get_distributed_dataset( |
| dataset_fn(distribute_lib.InputContext()), input_workers, distribution) |
| |
| iterator = dataset_ops.make_one_shot_iterator(dist_dataset) |
| |
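| # Wrapping the initializer evaluation in a tf.function checks that the |
| # multi-device iterator can be initialized from inside a function. |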
| @def_function.function |
| def init_func_for_iter(): |
| self.evaluate(iterator.initializer) |
| |
| init_func_for_iter() |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["graph", "eager"], |
| input_type=["input_fn", "dataset"], |
| api_type=["wrap_into_iterator", "wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| distribution=[ |
| strategy_combinations.one_device_strategy, |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| ], |
| enable_get_next_as_optional=[True, False])) |
| def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution, |
| enable_get_next_as_optional): |
| worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] |
| dataset_fn = lambda _: dataset_ops.Dataset.range(10) |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| expected_values = [[i] for i in range(10)] |
| |
| distribution.extended.experimental_enable_get_next_as_optional = ( |
| enable_get_next_as_optional) |
| self._test_input_iteration(input_type, api_type, iteration_type, |
| dataset_or_input_fn, worker_device_pairs, |
| expected_values, distribution) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| input_type=["input_fn", "dataset"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_one_gpu, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ], |
| ) |
| ) |
| def testAutoShardExplicit(self, input_type, distribution): |
| worker_device_pairs = [( |
| "/device:CPU:0", |
| distribution.extended.worker_devices, |
| )] |
| dataset_fn = lambda _: dataset_ops.Dataset.range(10).batch(1) |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn |
| ) |
| input_workers = input_lib.InputWorkers(worker_device_pairs) |
| distribution.extended.experimental_enable_get_next_as_optional = True |
| dataset = self._wrap_dataset( |
| input_type, |
| dataset_or_input_fn, |
| input_workers, |
| num_replicas_in_sync=None, |
| strategy=distribution) |
| |
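| # `auto_shard_dataset(dataset, num_shards, index)` shards the distributed |
| # dataset explicitly: shard 0 should see the even elements of range(10) and |
| # shard 1 the odd elements, as checked below. |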
| dataset1 = input_ops.auto_shard_dataset(dataset, 2, 0) |
| iterator = iter(dataset1) |
| |
| if len(distribution.extended.worker_devices) == 2: |
| expected_values = [[0, 2], [4, 6], [8]] |
| else: |
| expected_values = [[0], [2], [4], [6], [8]] |
| for element, expected in zip(iterator, expected_values): |
| local = distribution.experimental_local_results(element) |
| local_list = array_ops.concat(local, axis=0).numpy().tolist() |
| self.assertAllEqual(local_list, expected) |
| |
| if len(distribution.extended.worker_devices) == 2: |
| expected_values = [[1, 3], [5, 7], [9]] |
| else: |
| expected_values = [[1], [3], [5], [7], [9]] |
| dataset2 = input_ops.auto_shard_dataset(dataset, 2, 1) |
| iterator = iter(dataset2) |
| for element, expected in zip(iterator, expected_values): |
| local = distribution.experimental_local_results(element) |
| local_list = array_ops.concat(local, axis=0).numpy().tolist() |
| self.assertAllEqual(local_list, expected) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| input_type=["input_fn", "dataset"], |
| api_type=["wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| distribution=[strategy_combinations.multi_worker_mirrored_2x1_cpu], |
| enable_get_next_as_optional=[True, False])) |
| def testOneDeviceCPUMultiWorker(self, input_type, api_type, iteration_type, |
| distribution, enable_get_next_as_optional): |
| worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] |
| dataset_fn = lambda _: dataset_ops.DatasetV1.range(10) |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| expected_values = [[i] for i in range(10)] |
| |
| distribution.extended.experimental_enable_get_next_as_optional = ( |
| enable_get_next_as_optional) |
| self._test_input_iteration(input_type, api_type, iteration_type, |
| dataset_or_input_fn, worker_device_pairs, |
| expected_values, distribution) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["graph", "eager"], |
| input_type=["input_fn", "dataset"], |
| api_type=["wrap_into_iterator", "wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.central_storage_strategy_with_gpu_and_cpu |
| ], |
| enable_get_next_as_optional=[True, False])) |
| def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type, |
| distribution, enable_get_next_as_optional): |
| worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", |
| "/device:CPU:0"])] |
| dataset_fn = lambda _: dataset_ops.Dataset.range(10) |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| expected_values = [[i, i + 1] for i in range(0, 10, 2)] |
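| # Each step feeds consecutive elements to the two replicas, e.g. the first |
| # step yields 0 on the GPU replica and 1 on the CPU replica. |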
| |
| distribution.extended.experimental_enable_get_next_as_optional = ( |
| enable_get_next_as_optional) |
| self._test_input_iteration(input_type, api_type, iteration_type, |
| dataset_or_input_fn, worker_device_pairs, |
| expected_values, distribution) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["graph", "eager"], |
| input_type=["input_fn", "dataset"], |
| api_type=["wrap_into_iterator", "wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| distribution=[strategy_combinations.tpu_strategy], |
| enable_get_next_as_optional=[True, False])) |
| def testTPU(self, input_type, api_type, iteration_type, distribution, |
| enable_get_next_as_optional): |
| worker_device_pairs = collections.OrderedDict() |
| for tpu_device in distribution.extended.worker_devices: |
| host_device = device_util.get_host_for_device(tpu_device) |
| worker_device_pairs.setdefault(host_device, []) |
| worker_device_pairs[host_device].append(tpu_device) |
| worker_device_pairs = worker_device_pairs.items() |
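| # The pairs above group each TPU device under its host, so that every host |
| # (input worker) feeds its own local TPU devices. |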
| dataset_fn = lambda _: dataset_ops.Dataset.range(10) |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| expected_values = [[i, i + 1] for i in range(0, 10, 2)] |
| |
| distribution.extended.experimental_enable_get_next_as_optional = ( |
| enable_get_next_as_optional) |
| self._test_input_iteration(input_type, api_type, iteration_type, |
| dataset_or_input_fn, worker_device_pairs, |
| expected_values, distribution) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["graph", "eager"], |
| input_type=["input_fn", "dataset"], |
| api_type=["wrap_into_iterator", "wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.central_storage_strategy_with_gpu_and_cpu, |
| ], |
| enable_get_next_as_optional=[True, False])) |
| def testTupleDataset(self, input_type, api_type, iteration_type, distribution, |
| enable_get_next_as_optional): |
| worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", |
| "/device:CPU:0"])] |
| |
| def dataset_fn(ctx): |
| del ctx |
| dataset1 = dataset_ops.Dataset.range(10) |
| dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) |
| return dataset_ops.Dataset.zip((dataset1, dataset2)) |
| |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| expected_values = [ |
| [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 10, 2) |
| ] |
| |
| distribution.extended.experimental_enable_get_next_as_optional = ( |
| enable_get_next_as_optional) |
| self._test_input_iteration(input_type, api_type, iteration_type, |
| dataset_or_input_fn, worker_device_pairs, |
| expected_values, distribution) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| input_type=["input_fn", "dataset"], |
| api_type=["wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| distribution=[ |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call |
| ], |
| enable_get_next_as_optional=[True, False])) |
| def testTupleDatasetMultiworker(self, input_type, api_type, iteration_type, |
| distribution, enable_get_next_as_optional): |
| worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", |
| "/device:GPU:1"])] |
| |
| def dataset_fn(ctx): |
| del ctx |
| dataset1 = dataset_ops.Dataset.range(10) |
| dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) |
| return dataset_ops.Dataset.zip((dataset1, dataset2)) |
| |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| expected_values = [ |
| [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 10, 2) |
| ] |
| |
| distribution.extended.experimental_enable_get_next_as_optional = ( |
| enable_get_next_as_optional) |
| |
| # `input_context` is not passed in, so no sharding is done. |
| self._test_input_iteration(input_type, api_type, iteration_type, |
| dataset_or_input_fn, worker_device_pairs, |
| expected_values, distribution) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.one_device_strategy, |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| ])) |
| def testIterableIterator(self, distribution): |
| worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] |
| input_workers = input_lib.InputWorkers(worker_device_pairs) |
| |
| dataset = dataset_ops.Dataset.range(10) |
| dist_dataset = input_util.get_distributed_dataset(dataset, input_workers, |
| distribution) |
| |
| iterator = iter(dist_dataset) |
| for i, element in enumerate(iterator): |
| self.assertAllEqual(distribution.experimental_local_results(element), [i]) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_two_cpus, |
| ], |
| use_iterator=[False, True])) |
| def testIteratorAndDatasetEnumerateError(self, distribution, use_iterator): |
| # enumerate is not supported within tf.function for these types. |
| dataset = dataset_ops.Dataset.range(10).batch(2) |
| dist_dataset = distribution.experimental_distribute_dataset(dataset) |
| |
| if use_iterator: |
| iterable = iter(dist_dataset) |
| else: |
| iterable = dist_dataset |
| |
| @def_function.function |
| def enumerate_fn(iterable): |
| for _, batch in enumerate(iterable): |
| distribution.experimental_local_results(batch) |
| |
| with self.assertRaises(NotImplementedError): |
| enumerate_fn(iterable) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_two_cpus, |
| ])) |
| def testIterableIteratorError(self, distribution): |
| dataset = dataset_ops.Dataset.range(10).batch(2) |
| dist_dataset = distribution.experimental_distribute_dataset(dataset) |
| |
| iterator = iter(dist_dataset) |
| # Raises an error when next(iterator) is called without a strategy scope. |
| with self.assertRaises(ValueError): |
| |
| def replica_fn1(iterator): |
| return next(iterator) |
| |
| distribution.run(replica_fn1, args=(iterator,)) |
| |
| if distribution.num_replicas_in_sync == 1: |
| expected_result = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8, 9]]] |
| elif distribution.num_replicas_in_sync == 2: |
| expected_result = [[[0], [1]], [[2], [3]], [[4], [5]], [[6], [7]], |
| [[8], [9]]] |
| |
| with distribution.scope(): |
| |
| def replica_fn2(iterator): |
| return iterator |
| |
| result = distribution.run(replica_fn2, args=(next(iterator),)) |
| self.assertAllEqual( |
| distribution.experimental_local_results(result), expected_result[0]) |
| |
| # Confirm default ReplicaContext also works |
| iterator = iter(dist_dataset) |
| for i, element in enumerate(iterator): |
| self.assertAllEqual( |
| distribution.experimental_local_results(element), expected_result[i]) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["graph", "eager"], |
| input_type=["input_fn", "dataset"], |
| api_type=["wrap_into_iterator", "wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| drop_remainder=[True, False], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.central_storage_strategy_with_gpu_and_cpu |
| ])) |
| def testUnevenDatasetBatches(self, input_type, api_type, iteration_type, |
| drop_remainder, distribution): |
| worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", |
| "/device:CPU:0"])] |
| dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch( # pylint: disable=g-long-lambda |
| 2, drop_remainder=drop_remainder) |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| # The last global batch only contains data for one replica. |
| if drop_remainder: |
| expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] |
| else: |
| expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]] |
| distribution.extended.experimental_enable_get_next_as_optional = True |
| self._test_input_iteration(input_type, api_type, iteration_type, |
| dataset_or_input_fn, worker_device_pairs, |
| expected_values, distribution) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| input_type=["input_fn", "dataset"], |
| api_type=["wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| drop_remainder=[True, False], |
| distribution=[ |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| strategy_combinations.multi_worker_mirrored_2x1_gpu, |
| ])) |
| def testUnevenDatasetBatchesMultiWorker(self, input_type, api_type, |
| iteration_type, drop_remainder, |
| distribution): |
| # Actual devices don't matter in this test as long as the number of global |
| # replicas is 2. |
| worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] |
| cr = distribution.cluster_resolver |
| self.assertIsNotNone(cr) |
| worker_count = multi_worker_util.worker_count(cr.cluster_spec(), |
| cr.task_type) |
| id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(), |
| cr.task_type, cr.task_id) |
| |
| def dataset_fn(_): |
| dataset = dataset_ops.Dataset.range(9) |
| |
| if input_type == "input_fn": |
| # When input_fn is used, there is no automatic rebatching and sharding, |
| # so we add them here. |
| return dataset.shard(worker_count, id_in_cluster).batch(1) |
| else: |
| return dataset.batch(2, drop_remainder=drop_remainder) |
| |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| if drop_remainder and input_type == "dataset": |
| if id_in_cluster == 0: |
| expected_values = [[[0]], [[2]], [[4]], [[6]]] |
| else: |
| expected_values = [[[1]], [[3]], [[5]], [[7]]] |
| else: |
| # The last global batch only contains data for one replica. |
| if id_in_cluster == 0: |
| expected_values = [[[0]], [[2]], [[4]], [[6]], [[8]]] |
| else: |
| expected_values = [[[1]], [[3]], [[5]], [[7]], [[]]] |
| distribution.extended.experimental_enable_get_next_as_optional = True |
| self._test_input_iteration( |
| input_type, |
| api_type, |
| iteration_type, |
| dataset_or_input_fn, |
| worker_device_pairs, |
| expected_values, |
| distribution, |
| num_replicas_in_sync=distribution.num_replicas_in_sync, |
| input_context=distribution.extended._make_input_context()) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| input_type=["input_fn", "dataset"], |
| api_type=["wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| drop_remainder=[True, False], |
| distribution=[ |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call |
| ])) |
| def testUnevenDatasetBatchesMultiWorkerFourReplicas(self, input_type, |
| api_type, iteration_type, |
| drop_remainder, |
| distribution): |
| # Actual devices don't matter in this test as long as the number of global |
| # replicas is 4. |
| worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", |
| "/device:GPU:1"])] |
| cr = distribution.cluster_resolver |
| self.assertIsNotNone(cr) |
| worker_count = multi_worker_util.worker_count(cr.cluster_spec(), |
| cr.task_type) |
| id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(), |
| cr.task_type, cr.task_id) |
| |
| def dataset_fn(_): |
| dataset = dataset_ops.Dataset.range(15) |
| |
| if input_type == "input_fn": |
| # When input_fn is used, there is no automatic rebatching and sharding, |
| # so we add them here. |
| return dataset.shard(worker_count, id_in_cluster).batch(1) |
| else: |
| return dataset.batch(4, drop_remainder=drop_remainder) |
| |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| # The last global batch only contains data for one replica. |
| if drop_remainder and input_type == "dataset": |
| if id_in_cluster == 0: |
| expected_values = [[[0], [2]], [[4], [6]], [[8], [10]]] |
| else: |
| expected_values = [[[1], [3]], [[5], [7]], [[9], [11]]] |
| else: |
| if id_in_cluster == 0: |
| expected_values = [[[0], [2]], [[4], [6]], [[8], [10]], [[12], [14]]] |
| else: |
| expected_values = [[[1], [3]], [[5], [7]], [[9], [11]], [[13], []]] |
| distribution.extended.experimental_enable_get_next_as_optional = True |
| self._test_input_iteration( |
| input_type, |
| api_type, |
| iteration_type, |
| dataset_or_input_fn, |
| worker_device_pairs, |
| expected_values, |
| distribution, |
| num_replicas_in_sync=distribution.num_replicas_in_sync, |
| input_context=distribution.extended._make_input_context()) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["graph", "eager"], |
| input_type=["dataset"], |
| api_type=["wrap_into_iterator", "wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| num_replicas_in_sync=[None, 2], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.central_storage_strategy_with_gpu_and_cpu |
| ], |
| enable_get_next_as_optional=[True, False])) |
| def testBatchSplitting(self, input_type, api_type, iteration_type, |
| num_replicas_in_sync, distribution, |
| enable_get_next_as_optional): |
| worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", |
| "/device:CPU:0"])] |
| batch_size = 10 |
| dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size) |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| updated_batch_size = ( |
| batch_size // |
| num_replicas_in_sync if num_replicas_in_sync else batch_size) |
| expected_values = [[ |
| range(i, i + updated_batch_size), |
| range(i + updated_batch_size, i + 2 * updated_batch_size) |
| ] for i in range(0, 100, updated_batch_size * 2)] |
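| # For example, with num_replicas_in_sync=2 each global batch of 10 is split |
| # into per-replica batches of 5, so the first step yields range(0, 5) on |
| # replica 0 and range(5, 10) on replica 1. |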
| |
| distribution.extended.experimental_enable_get_next_as_optional = ( |
| enable_get_next_as_optional) |
| self._test_input_iteration( |
| input_type, |
| api_type, |
| iteration_type, |
| dataset_or_input_fn, |
| worker_device_pairs, |
| expected_values, |
| distribution, |
| sess=None, |
| num_replicas_in_sync=num_replicas_in_sync) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| input_type=["dataset"], |
| api_type=["wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| num_replicas_in_sync=[None, 2], |
| distribution=[ |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call |
| ], |
| enable_get_next_as_optional=[True, False])) |
| def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type, |
| num_replicas_in_sync, distribution, |
| enable_get_next_as_optional): |
| worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", |
| "/device:GPU:1"])] |
| batch_size = 10 |
| cr = distribution.cluster_resolver |
| self.assertIsNotNone(cr) |
| |
| def dataset_fn(_): |
| dataset = dataset_ops.Dataset.range(100).batch(batch_size) |
| return dataset |
| |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| updated_batch_size = ( |
| batch_size // |
| num_replicas_in_sync if num_replicas_in_sync else batch_size) |
| expected_values = [ |
| [ # pylint: disable=g-complex-comprehension |
| range(i, i + updated_batch_size), |
| range(i + updated_batch_size, i + 2 * updated_batch_size) |
| ] for i in range(0, 100, updated_batch_size * 2) |
| ] |
| |
| distribution.extended.experimental_enable_get_next_as_optional = ( |
| enable_get_next_as_optional) |
| self._test_input_iteration( |
| input_type, |
| api_type, |
| iteration_type, |
| dataset_or_input_fn, |
| worker_device_pairs, |
| expected_values, |
| distribution, |
| sess=None, |
| num_replicas_in_sync=num_replicas_in_sync) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.one_device_strategy, |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.tpu_strategy, |
| strategy_combinations.central_storage_strategy_with_two_gpus, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call, |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| ], |
| )) |
| def testCacheAcrossIteration(self, distribution): |
| if not tf2.enabled(): |
| self.skipTest("Only V2 is supported.") |
| |
| dataset = dataset_ops.Dataset.range(16).shuffle(16).cache().batch(4) |
| dist_dataset = distribution.experimental_distribute_dataset(dataset) |
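| # `cache()` records the shuffled order on the first pass, so a second pass |
| # over `dist_dataset` is expected to replay exactly the same batches. |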
| |
| first_epoch = list( |
| distribution.experimental_local_results(x) for x in dist_dataset) |
| second_epoch = list( |
| distribution.experimental_local_results(x) for x in dist_dataset) |
| |
| self.assertAllEqual(first_epoch, second_epoch) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.one_device_strategy, |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.tpu_strategy, |
| strategy_combinations.central_storage_strategy_with_two_gpus, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call, |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| ], |
| reshuffle=[True, False])) |
| def testShuffleAcrossIterations(self, distribution, reshuffle): |
| if not tf2.enabled(): |
| self.skipTest("Only V2 is supported.") |
| |
| dataset = dataset_ops.Dataset.range(12).shuffle( |
| 12, reshuffle_each_iteration=reshuffle).batch(4) |
| dist_dataset = distribution.experimental_distribute_dataset(dataset) |
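| # With reshuffle_each_iteration=True each pass uses a new shuffle order, so |
| # the two epochs should differ; with False the same order is replayed. |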
| |
| first_epoch = list( |
| distribution.experimental_local_results(x) for x in dist_dataset) |
| second_epoch = list( |
| distribution.experimental_local_results(x) for x in dist_dataset) |
| |
| if reshuffle: |
| self.assertNotAllEqual(first_epoch, second_epoch) |
| else: |
| self.assertAllEqual(first_epoch, second_epoch) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.one_device_strategy, |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.tpu_strategy, |
| strategy_combinations.central_storage_strategy_with_two_gpus, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call, |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| ])) |
| def testGetNextOptionalShapeFinite(self, distribution): |
| batch_size = 8 |
| dataset = dataset_ops.DatasetV2.from_tensor_slices({ |
| "feature": array_ops.ones([batch_size, 10]), |
| "label": array_ops.ones([batch_size]), |
| }) |
| dataset = dataset.batch(batch_size, drop_remainder=True) |
| dist_dataset = distribution.experimental_distribute_dataset(dataset) |
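| # The dataset is finite, so the last per-replica batch may be partial and |
| # the static batch dimension is expected to be relaxed to None below. |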
| |
| @def_function.function |
| def train_fn(): |
| for data in dist_dataset: |
| data = nest.map_structure(distribution.experimental_local_results, data) |
| feature = data["feature"] |
| label = data["label"] |
| |
| # Assert that the shapes are still static for all replicas. |
| for replica_id in range(len(distribution.extended.worker_devices)): |
| self.assertEqual([None, 10], |
| feature[replica_id].shape.as_list()) |
| self.assertEqual([None], label[replica_id].shape.as_list()) |
| |
| train_fn() |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.one_device_strategy, |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.tpu_strategy, |
| strategy_combinations.central_storage_strategy_with_two_gpus, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call, |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| ])) |
| def testGetNextOptionalShapeInfinite(self, distribution): |
| batch_size = 8 |
| dataset = dataset_ops.DatasetV2.from_tensor_slices({ |
| "feature": array_ops.ones([batch_size, 10]), |
| "label": array_ops.ones([batch_size]), |
| }) |
| dataset = dataset.batch(batch_size, drop_remainder=True) |
| dataset = dataset.repeat() |
| dist_dataset = distribution.experimental_distribute_dataset(dataset) |
| per_replica_batch_size = batch_size // distribution.num_replicas_in_sync |
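| # The repeated dataset is infinite, so every batch is full and the static |
| # per-replica batch size should be preserved in the element shapes. |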
| |
| @def_function.function |
| def train_fn(): |
| data = iter(dist_dataset).get_next_as_optional().get_value() |
| data = nest.map_structure(distribution.experimental_local_results, data) |
| feature = data["feature"] |
| label = data["label"] |
| |
| # Assert that the shapes are still static for all replicas. |
| for replica_id in range(len(distribution.extended.worker_devices)): |
| self.assertEqual([per_replica_batch_size, 10], |
| feature[replica_id].shape.as_list()) |
| self.assertEqual([per_replica_batch_size], |
| label[replica_id].shape.as_list()) |
| |
| train_fn() |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.one_device_strategy, |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.tpu_strategy, |
| strategy_combinations.central_storage_strategy_with_two_gpus, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call, |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| ])) |
| def testGetNextOptionalShapeEmpty(self, distribution): |
| batch_size = 8 |
| dataset = dataset_ops.DatasetV2.from_tensor_slices({ |
| "feature": array_ops.ones([batch_size, 10]), |
| "label": array_ops.ones([batch_size]), |
| }) |
| dataset = dataset.batch(batch_size, drop_remainder=True) |
| dataset = dataset.repeat() |
| dist_dataset = distribution.experimental_distribute_dataset(dataset) |
| per_replica_batch_size = batch_size // distribution.num_replicas_in_sync |
| |
| @def_function.function |
| def train_fn(): |
| data = iter(dist_dataset).get_next_as_optional() |
| feature_specs = data.element_spec["feature"]._component_specs |
| value_specs = data.element_spec["label"]._component_specs |
| if not isinstance(feature_specs, tuple): |
| feature_specs = (feature_specs,) |
| value_specs = (value_specs,) |
| # Assert that the shapes are still static for all replicas. |
| for replica_id in range(len(distribution.extended.worker_devices)): |
| self.assertEqual([per_replica_batch_size, 10], |
| feature_specs[replica_id].shape.as_list()) |
| self.assertEqual([per_replica_batch_size], |
| value_specs[replica_id].shape.as_list()) |
| |
| train_fn() |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| ], |
| input_type=["dataset"], |
| api_type=["wrap_into_iterator", "wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.OFF])) |
| def testAutoshardingOption(self, distribution, input_type, api_type, |
| iteration_type, auto_shard_policy): |
| cr = distribution.cluster_resolver |
| self.assertIsNotNone(cr) |
| id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(), |
| cr.task_type, cr.task_id) |
| ds_option = options_lib.Options() |
| ds_option.experimental_distribute.auto_shard_policy = auto_shard_policy |
| dataset_fn = ( |
| lambda _: dataset_ops.Dataset.range(4).with_options(ds_option)) |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] |
| if auto_shard_policy == AutoShardPolicy.AUTO: |
| if id_in_cluster == 0: |
| expected_values = [[0], [2]] |
| else: |
| expected_values = [[1], [3]] |
| else: |
| expected_values = [[0], [1], [2], [3]] |
| self._test_input_iteration( |
| input_type, |
| api_type, |
| iteration_type, |
| dataset_or_input_fn, |
| worker_device_pairs, |
| expected_values, |
| distribution, |
| input_context=distribution.extended._make_input_context()) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| ], |
| input_type=["input_fn"], |
| api_type=["wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"])) |
| def testDifferentDatasetsMultiWorker(self, distribution, input_type, api_type, |
| iteration_type): |
| cr = distribution.cluster_resolver |
| self.assertIsNotNone(cr) |
| id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(), |
| cr.task_type, cr.task_id) |
| |
| def dataset_fn(ctx): |
| if ctx.input_pipeline_id == 0: |
| return dataset_ops.Dataset.range(8).batch(2) |
| else: |
| return dataset_ops.Dataset.range(9).batch(2) |
| |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| |
| worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] |
| |
| if id_in_cluster == 0: |
| expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[]]] |
| else: |
| expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8]]] |
| distribution.extended.experimental_enable_get_next_as_optional = True |
| self._test_input_iteration(input_type, api_type, iteration_type, |
| dataset_or_input_fn, worker_device_pairs, |
| expected_values, distribution) |
| |
| @combinations.generate( |
| combinations.combine( |
| strategy=[ |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| strategy_combinations.multi_worker_mirrored_2x1_gpu, |
| ], |
| mode=["eager"])) |
| def testLoopOverDatasetInTFFunction(self, strategy): |
| dataset = dataset_ops.Dataset.range(10).map(lambda x: { # pylint: disable=g-long-lambda |
| "y": math_ops.cast(x, dtypes.float32) ** 2, |
| }).batch(4) |
| dist_dataset = strategy.experimental_distribute_dataset(dataset) |
| |
| with strategy.scope(): |
| v = variables.Variable(0.0, aggregation=variables.VariableAggregation.SUM) |
| |
| @def_function.function |
| def iterator_fn(dist_dataset): |
| |
| def assign_add_fn(data): |
| v.assign_add(math_ops.reduce_sum(data["y"])) |
| |
| for data in dist_dataset: |
| strategy.run(assign_add_fn, args=(data,)) |
| |
| iterator_fn(dist_dataset) |
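| # The variable accumulates sum(x**2 for x in range(10)) = 285 across all |
| # replicas via SUM aggregation. |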
| self.assertEqual(v.numpy(), 285.0) |
| |
| |
| class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, |
| parameterized.TestCase): |
| """Tests for DistributedDataset with non-dense tensors.""" |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.central_storage_strategy_with_gpu_and_cpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| ], |
| input_type=["dataset", "input_fn"], |
| drop_remainder=[False, True], |
| defun_type=["lambda", "tf_function"], |
| )) |
| def testRaggedSparse(self, distribution, input_type, drop_remainder, |
| defun_type): |
| """Test with `RaggedTensor`s and `SparseTensor`s.""" |
| self.skipTest("b/213596871, b/214574707") |
| |
| if not tf2.enabled(): |
| self.skipTest("Only V2 is supported.") |
| |
| defun = { |
| "lambda": lambda f: f, |
| "tf_function": def_function.function |
| }[defun_type] |
| distribution.extended.experimental_enable_get_next_as_optional = True |
| global_batch_size = 8 |
| |
| def dataset_fn(ctx=None): |
| ctx = ctx or distribute_lib.InputContext() |
| batch_size = ctx.get_per_replica_batch_size(global_batch_size) |
| # Use 20 which isn't divisible by 8 to test partial batch behavior. |
| row_lengths = np.mod(np.arange(20), 4).astype(np.int64) |
| ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths( |
| np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths) |
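| # row_lengths cycles through 0, 1, 2, 3, so the first four rows densify to |
| # [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]], matching the per-replica |
| # assertions below. |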
| dataset = dataset_ops.DatasetV2.from_tensor_slices({ |
| "dense": ragged_tensor.to_tensor(), |
| "ragged": ragged_tensor, |
| "sparse": ragged_tensor.to_sparse(), |
| }) |
| dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) |
| return dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id) |
| |
| dataset_or_input_fn = self._create_dataset_or_input_fn( |
| input_type, dataset_fn) |
| dataset = self._wrap_dataset(input_type, dataset_or_input_fn, |
| distribution.extended._input_workers, |
| distribution.num_replicas_in_sync, |
| distribution) |
| # Assert that the tensors are rebatched and sparsity is preserved. |
| per_replica_batch = defun(lambda x: next(iter(x)))(dataset) |
| self.assertAllEqual( |
| distribute_utils.select_replica(0, per_replica_batch["dense"]), |
| [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]]) |
| self.assertAllEqual( |
| distribute_utils.select_replica(1, per_replica_batch["dense"]), |
| [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]]) |
| # Transitively check the ragged and sparse tensors by densification. |
| for i in range(2): |
| self.assertLen( |
| distribute_utils.select_replica(i, |
| per_replica_batch["ragged"]).values, |
| 6) |
| self.assertAllEqual( |
| distribute_utils.select_replica( |
| i, per_replica_batch["ragged"]).to_tensor(), |
| distribute_utils.select_replica(i, per_replica_batch["dense"])) |
| self.assertLen( |
| distribute_utils.select_replica(i, |
| per_replica_batch["sparse"]).indices, |
| 6) |
| self.assertAllEqual( |
| sparse_ops.sparse_tensor_to_dense( |
| distribute_utils.select_replica(i, per_replica_batch["sparse"])), |
| distribute_utils.select_replica(i, per_replica_batch["dense"])) |
| # Iterate through all the batches and sum them up. |
| def sum_batch(per_replica_features): |
| """Sums the `PerReplica` values in the `per_replica_features` map.""" |
| |
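| # If every per-replica value is sparse, sum its `.values`; otherwise apply |
| # reduce_sum directly. The per-replica sums are then combined with a SUM |
| # cross-replica reduction. |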
| def map_fn(per_replica_values): |
| per_replica_sums = distribution.run( |
| (lambda x: math_ops.reduce_sum(x.values)) if all( |
| map(sparse_tensor.is_sparse, per_replica_values.values)) else |
| math_ops.reduce_sum, (per_replica_values,)) |
| return distribution.reduce( |
| reduce_util.ReduceOp.SUM, per_replica_sums, axis=None) |
| |
| return nest.map_structure(map_fn, per_replica_features) |
| |
| def _reduce(state, batch): |
| sums = sum_batch(batch) |
| return {name: value + sums[name] for name, value in state.items()} |
| |
| def sum_for_loop(dataset): |
| sums = {"dense": 0., "ragged": 0., "sparse": 0.} |
| for batch in dataset: |
| sums = _reduce(sums, batch) |
| return sums |
| |
| def sum_while_loop(iterator, reduce_fn): |
| sums = {"dense": 0., "ragged": 0., "sparse": 0.} |
| while True: |
| try: |
| sums = reduce_fn(sums, iterator) |
| except (StopIteration, errors.OutOfRangeError): |
| return sums |
| |
| while_sums = sum_while_loop( |
| iter(dataset), |
| defun(lambda state, iterator: _reduce(state, next(iterator)))) |
| self.assertAllEqual( |
| nest.flatten(while_sums), |
| # When there's no partial batch, the sum is smaller. |
| [200. if drop_remainder else 310.] * 3) |
| for_sums = defun(sum_for_loop)(dataset) |
| # For loops always call get_next_as_optional inside tf.functions, so we |
| # expect 310 here when using an input function (as there are 5 batches of |
| # size 4 round-robined over 2 replicas). |
| expected_for_sum = 200. |
| if (not drop_remainder or |
| (defun_type == "tf_function" and input_type == "input_fn")): |
| expected_for_sum = 310. |
| self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.central_storage_strategy_with_gpu_and_cpu, |
| strategy_combinations.one_device_strategy, |
| strategy_combinations.mirrored_strategy_with_one_cpu |
| ], |
| input_type=["dataset", "input_fn"], |
| drop_remainder=[False, True], |
| tensor_type=["sparse", "ragged"], |
| enable_get_next_as_optional=[True, False])) |
| def testRaggedSparseGetNextAsOptional(self, distribution, input_type, |
| drop_remainder, tensor_type, |
| enable_get_next_as_optional): |
| """Test with `RaggedTensor`s and `SparseTensor`s.""" |
| if not tf2.enabled(): |
| self.skipTest("Only V2 is supported.") |
| |
| distribution.extended.experimental_enable_get_next_as_optional = ( |
| enable_get_next_as_optional) |
| global_batch_size = 8 |
| |
| def dataset_fn(ctx=None): |
| ctx = ctx or distribute_lib.InputContext() |
| batch_size = ctx.get_per_replica_batch_size(global_batch_size) |
| # Use 20 which isn't divisible by 8 to test partial batch behavior. |
| row_lengths = np.mod(np.arange(20), 4).astype(np.int64) |
| ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths( |
| np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths) |
| dataset = dataset_ops.DatasetV2.from_tensor_slices({ |
| tensor_type: (ragged_tensor if tensor_type == "ragged" else |
| ragged_tensor.to_sparse()), |
| }) |
| dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id) |
| return dataset.batch(batch_size, drop_remainder=drop_remainder) |
| |
| if input_type == "dataset": |
| ds = distribution.experimental_distribute_dataset( |
| dataset_fn(distribute_lib.InputContext())) |
| else: |
| ds = distribution.distribute_datasets_from_function(dataset_fn) |
| iterator = iter(ds) |
| |
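| # With drop_remainder=True every batch is full, so the iterator should not |
| # enable the get_next_as_optional path even if it was requested. |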
| self.assertEqual(iterator._enable_get_next_as_optional, |
| (not drop_remainder) and enable_get_next_as_optional) |
| |
| @combinations.generate( |
| combinations.combine( |
| tf_api_version=2, |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.central_storage_strategy_with_gpu_and_cpu, |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| strategy_combinations.multi_worker_mirrored_2x1_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| ], |
| input_type=["dataset", "input_fn"], |
| drop_remainder=[False, True], |
| )) |
| def testRaggedSparseGetNextAsOptionalInLoop(self, distribution, input_type, |
| drop_remainder): |
| """Test with `RaggedTensor`s and `SparseTensor`s.""" |
| global_batch_size = 8 |
| |
| def dataset_fn(ctx=None): |
| ctx = ctx or distribute_lib.InputContext() |
| batch_size = ctx.get_per_replica_batch_size(global_batch_size) |
| # Use 20 which isn't divisible by 8 to test partial batch behavior. |
| row_lengths = np.mod(np.arange(20), 4).astype(np.int64) |
| ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths( |
| np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths) |
| dataset = dataset_ops.DatasetV2.from_tensor_slices({ |
| "dense": ragged_tensor.to_tensor(), |
| "ragged": ragged_tensor, |
| "sparse": ragged_tensor.to_sparse(), |
| }) |
| dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) |
| return dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id) |
| |
| if input_type == "dataset": |
| ds = distribution.experimental_distribute_dataset( |
| dataset_fn(distribute_lib.InputContext())) |
| else: |
| ds = distribution.distribute_datasets_from_function(dataset_fn) |
| |
| # Iterate through all the batches and sum them up. |
| def sum_batch(per_replica_features): |
| """Sums the `PerReplica` values in the `per_replica_features` map.""" |
| |
| def map_fn(per_replica_values): |
| |
| def _sum(value): |
| if sparse_tensor.is_sparse(value): |
| return math_ops.reduce_sum(value.values) |
| else: |
| return math_ops.reduce_sum(value) |
| |
| per_replica_sums = distribution.run(_sum, args=(per_replica_values,)) |
| return distribution.reduce( |
| reduce_util.ReduceOp.SUM, per_replica_sums, axis=None) |
| |
| return nest.map_structure(map_fn, per_replica_features) |
| |
| def _reduce(state, batch): |
| sums = sum_batch(batch) |
| return {name: value + sums[name] for name, value in state.items()} |
| |
| def sum_while_loop(ds): |
| iterator = iter(ds) |
| sums = {"dense": 0., "ragged": 0., "sparse": 0.} |
| try_next = constant_op.constant(True) |
| |
| while try_next: |
| opt_iterate = iterator.get_next_as_optional() |
| if opt_iterate.has_value(): |
| sums = _reduce(sums, opt_iterate.get_value()) |
| else: |
| try_next = False |
| return sums |
| |
| sums = def_function.function(sum_while_loop)(ds) |
| # For loops always call get_next_as_optional inside tf.functions, so we |
| # expect 310 here when using an input function (as there are 5 batches of |
| # size 4 round-robined over 2 replicas). |
| expected_for_sum = 200. |
| if not drop_remainder or input_type == "input_fn": |
| expected_for_sum = 310. |
| self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| input_type=["dataset"], |
| api_type=["wrap_into_iterator", "wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| distribution=[ |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| strategy_combinations.multi_worker_mirrored_2x1_gpu, |
| ])) |
| def testMWMSPartialBatch(self, input_type, api_type, iteration_type, |
| distribution): |
| # Test case: 2 workers, 1 replica each. |
| # This test simulates the sharded behavior when we have two files each with |
| # 12 elements and a global batch size of 8. When we consider the dataset in |
| # aggregate (non-distributed), there are 24 elements divided into 3 batches |
| # of size 8. Hence, the correct distributed behavior is for each replica to |
| # see sub-batches of size 4, over three steps. |
| def dataset_fn(ctx): |
| del ctx |
| dataset = dataset_ops.Dataset.range(12).batch(8) |
| |
| # Set the sharding behavior to OFF for simplicity of test setup; namely, |
| # `dataset` defines the per-worker dataset and will not be further |
| # sharded. Each worker will see a dataset that is |
| # tf.data.Dataset.range(12).batch(8).rebatch(...). |
| options = options_lib.Options() |
| options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF |
| dataset = dataset.with_options(options) |
| return dataset |
| |
| dataset = self._create_dataset_or_input_fn(input_type, dataset_fn) |
| |
| # Actual devices don't matter in this test as long as there is 1 local |
| # replica. |
| worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] |
| |
| # Each test runs individually on each worker, so we compare the |
| # values on each worker. Each worker should rebatch its dataset into |
| # smaller batches of size 4. |
| expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]] |
| self._test_input_iteration( |
| input_type, |
| api_type, |
| iteration_type, |
| dataset, |
| worker_device_pairs, |
| expected_values, |
| distribution, |
| num_replicas_in_sync=distribution.num_replicas_in_sync, |
| input_context=distribution.extended._make_input_context()) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| input_type=["dataset"], |
| api_type=["wrap_into_iterator", "wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| distribution=[ |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| strategy_combinations.multi_worker_mirrored_2x1_gpu, |
| ])) |
| def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type, |
| iteration_type, distribution): |
| # Test case: 2 workers, 1 replica each. |
| # This test simulates the sharded behavior when we have two files each with |
| # 12 elements and a global batch size of 8. When we consider the dataset in |
| # aggregate (non-distributed), there are 24 elements divided into 3 batches |
| # of size 8. Hence, the correct distributed behavior is for each replica to |
| # see sub-batches of size 4, over three steps. However, when we create a |
| # DistributedDataset and cannot statically infer the intended global batch |
| # size (e.g. if the user does not use a batching dataset), each worker will |
| # rebatch based on the dynamic batch size of the data encountered, even when |
| # it encounters partial batches. The last per-worker partial batch (size 4) |
| # ends up being split into two replicas, resulting in 4 steps in total, of |
| # (global) batch sizes 8, 8, 4, 4. |
| def dataset_fn(ctx): |
| del ctx |
| # The following dataset is equivalent to |
| # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset. |
| # This causes DistributedDataset to use LegacyRebatch instead. |
| batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4]) |
| offsets = dataset_ops.Dataset.from_tensor_slices([0, 8]) |
| dataset = dataset_ops.Dataset.zip((offsets, batch_sizes)) |
| |
| def map_fn(offset, batch_size): |
| return math_ops.range(offset, offset + batch_size) |
| |
| dataset = dataset.map(map_fn) |
| |
| # Set the sharding behavior to OFF for simplicity of test setup; namely, |
| # `dataset` defines the per-worker dataset and will not be further |
| # sharded. Each worker will see a dataset that is equivalent to |
| # tf.data.Dataset.range(12).batch(8).rebatch(...). |
| options = options_lib.Options() |
| options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF |
| dataset = dataset.with_options(options) |
| return dataset |
| |
| dataset = self._create_dataset_or_input_fn(input_type, dataset_fn) |
| |
| # Actual devices don't matter in this test as long as the number of global |
| # replicas is 2. |
| worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] |
| |
    # Each test runs individually on each worker, so we compare the
    # values on each worker. Each worker rebatches its full batches into
    # smaller batches of size 4; the final partial batch of size 4 is split
    # into two batches of size 2.
| expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]] |
| self._test_input_iteration( |
| input_type, |
| api_type, |
| iteration_type, |
| dataset, |
| worker_device_pairs, |
| expected_values, |
| distribution, |
| num_replicas_in_sync=distribution.num_replicas_in_sync, |
| input_context=distribution.extended._make_input_context()) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| input_type=["dataset"], |
| api_type=["wrap_into_iterator", "wrap_into_dataset"], |
| iteration_type=["get_next", "for_loop"], |
| distribution=[ |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| strategy_combinations.multi_worker_mirrored_2x1_gpu, |
| ], |
| auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA])) |
| def testMWMSWithDataSharding(self, input_type, api_type, iteration_type, |
| distribution, auto_shard_policy): |
| # Test case: 2 workers, 1 replica each. |
    # This test simulates the sharded behavior when the dataset is sharded by
    # data and the global batch size (3) is not evenly divisible by the number
    # of replicas. It checks that the elements are as expected and that the
    # per-step batch sizes across all workers add up to the global batch size
    # of 3 (2 for the final partial batch). This test only passes if the
    # autoshard rewrite replaces RebatchDatasetV2 with the legacy
    # RebatchDataset when sharding by data.
| def dataset_fn(ctx): |
| del ctx |
| dataset = dataset_ops.Dataset.range(8).batch(3) |
| |
      # Unlike the tests above, sharding is not disabled here: the policy is
      # the parameterized `auto_shard_policy` (AUTO or DATA; AUTO falls back
      # to sharding by data because there is no file-based source). Each
      # worker therefore sees a data-sharded view of
      # tf.data.Dataset.range(8).batch(3).
| options = options_lib.Options() |
| options.experimental_distribute.auto_shard_policy = auto_shard_policy |
| dataset = dataset.with_options(options) |
| return dataset |
| |
| dataset = self._create_dataset_or_input_fn(input_type, dataset_fn) |
| |
| # Actual devices don't matter in this test as long as there is 1 local |
| # replica. |
| worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] |
| |
| # Each test runs individually on each worker, so we compare the |
| # values on each worker. We expect each worker to see different shards of |
| # data. |
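    # range(8).batch(3) yields batches [0, 1, 2], [3, 4, 5], and [6, 7];
    # sharding by data splits each batch across the two workers, so the
    # per-worker batch sizes at each step add up to 3, 3, and 2.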
| cr = distribution.cluster_resolver |
| worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type, |
| cr.task_id) |
| |
| if worker_id == 0: |
| expected_values = [[[0, 1]], [[3, 4]], [[6]]] |
| elif worker_id == 1: |
| expected_values = [[[2]], [[5]], [[7]]] |
| |
| self._test_input_iteration( |
| input_type, |
| api_type, |
| iteration_type, |
| dataset, |
| worker_device_pairs, |
| expected_values, |
| distribution, |
| num_replicas_in_sync=distribution.num_replicas_in_sync, |
| input_context=distribution.extended._make_input_context()) |
| |
| |
| @framework_test_util.with_eager_op_as_function |
| class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, |
| parameterized.TestCase): |
| """Tests for PER_WORKER and PER_REPLICA's InputOptions variants.""" |
| |
| @combinations.generate( |
| combinations.combine( |
| input_options=[ |
| distribute_lib.InputOptions( |
| experimental_place_dataset_on_device=False, |
| experimental_fetch_to_device=True, |
| experimental_replication_mode=distribute_lib |
| .InputReplicationMode.PER_WORKER), |
| distribute_lib.InputOptions( |
| experimental_place_dataset_on_device=False, |
| experimental_fetch_to_device=True, |
| experimental_replication_mode=distribute_lib |
| .InputReplicationMode.PER_REPLICA), |
| ], |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_two_gpus, |
| strategy_combinations |
| .mirrored_strategy_with_two_gpus_no_merge_call, |
| strategy_combinations.mirrored_strategy_with_two_cpus, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ])) |
| def testDevicePlacementForPerWorkerValuesWithPrefetch(self, distribution, |
| input_options): |
| |
    def dataset_fn(input_context):  # pylint: disable=unused-argument
| return dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4]) |
| |
| ds = distribution.experimental_distribute_datasets_from_function( |
| dataset_fn, input_options) |
| |
| for x in ds: |
| assert x.values[0].device == distribution.extended.worker_devices[0] |
      assert (x.values[0].backing_device ==
              distribution.extended.worker_devices[0])
      assert x.values[1].device == distribution.extended.worker_devices[1]
      assert (x.values[1].backing_device ==
              distribution.extended.worker_devices[1])
| |
| @combinations.generate( |
| combinations.combine( |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_two_gpus, |
| strategy_combinations |
| .mirrored_strategy_with_two_gpus_no_merge_call, |
| strategy_combinations.mirrored_strategy_with_two_cpus, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ], |
| input_options=[ |
| distribute_lib.InputOptions( |
| experimental_place_dataset_on_device=False, |
| experimental_fetch_to_device=False, |
| experimental_replication_mode=distribute_lib |
| .InputReplicationMode.PER_WORKER) |
| ], |
| mode=["eager"], |
| )) |
| def testDevicePlacementForPerWorkerValuesWithoutPrefetch( |
| self, distribution, input_options): |
| |
| def dataset_fn(input_context): |
| return dataset_ops.Dataset.from_tensor_slices( |
| np.full(4, input_context.input_pipeline_id)) |
| |
| ds = distribution.experimental_distribute_datasets_from_function( |
| dataset_fn, input_options) |
| |
| for x in ds: |
| x = distribution.run(lambda inputs: inputs, args=(x,)) |
      assert (x.values[0].device ==
              "/job:localhost/replica:0/task:0/device:CPU:0")
      assert (x.values[0].backing_device ==
              "/job:localhost/replica:0/task:0/device:CPU:0")
      assert (x.values[1].device ==
              "/job:localhost/replica:0/task:0/device:CPU:0")
      assert (x.values[1].backing_device ==
              "/job:localhost/replica:0/task:0/device:CPU:0")
| |
| @combinations.generate( |
| combinations.combine( |
| input_options=[ |
| distribute_lib.InputOptions( |
| experimental_place_dataset_on_device=True, |
| experimental_fetch_to_device=False, |
| experimental_replication_mode=distribute_lib |
| .InputReplicationMode.PER_WORKER), |
| distribute_lib.InputOptions( |
| experimental_place_dataset_on_device=True, |
| experimental_fetch_to_device=True, |
| experimental_replication_mode=distribute_lib |
| .InputReplicationMode.PER_REPLICA) |
| ], |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_two_gpus, |
| strategy_combinations |
| .mirrored_strategy_with_two_gpus_no_merge_call, |
| strategy_combinations.mirrored_strategy_with_two_cpus, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ])) |
| def testDevicePlacementForInvalidCombinations(self, distribution, |
| input_options): |
| |
| def dataset_fn(input_context): |
| return dataset_ops.Dataset.from_tensor_slices( |
| np.full(4, input_context.input_pipeline_id)) |
| |
| with self.assertRaises(ValueError): |
| distribution.experimental_distribute_datasets_from_function( |
| dataset_fn, input_options) |
| |
| @combinations.generate( |
| combinations.combine( |
| input_options=[ |
| distribute_lib.InputOptions( |
| experimental_place_dataset_on_device=False, |
| experimental_fetch_to_device=False, |
| experimental_per_replica_buffer_size=2), |
| distribute_lib.InputOptions( |
| experimental_place_dataset_on_device=False, |
| experimental_fetch_to_device=True, |
| experimental_per_replica_buffer_size=2), |
| ], |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_two_gpus, |
| strategy_combinations |
| .mirrored_strategy_with_two_gpus_no_merge_call, |
| strategy_combinations.mirrored_strategy_with_two_cpus, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ])) |
| def testPrefetchBufferSizeInputOptions(self, distribution, input_options): |
| |
| def dataset_fn(input_context): |
| return dataset_ops.Dataset.from_tensor_slices( |
| np.arange(1, 11).reshape( |
| (2, 5)) * (input_context.input_pipeline_id + 1)) |
| |
| ds = distribution.experimental_distribute_datasets_from_function( |
| dataset_fn, input_options) |
| |
    # Validate the values.
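    # dataset_fn above is called with input_pipeline_id 0 (single worker), so
    # it yields the rows [1, 2, 3, 4, 5] and [6, 7, 8, 9, 10]; each of the two
    # replicas receives one row of the first element.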
| x = next(iter(ds)) |
| assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5])) |
| assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10])) |
| |
| @combinations.generate( |
| combinations.combine( |
| input_options=[ |
| distribute_lib.InputOptions( |
| experimental_place_dataset_on_device=False, |
| experimental_fetch_to_device=False, |
| experimental_replication_mode=distribute_lib |
| .InputReplicationMode.PER_WORKER), |
| distribute_lib.InputOptions( |
| experimental_place_dataset_on_device=False, |
| experimental_fetch_to_device=True, |
| experimental_replication_mode=distribute_lib |
| .InputReplicationMode.PER_WORKER), |
| ], |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_two_gpus, |
| strategy_combinations |
| .mirrored_strategy_with_two_gpus_no_merge_call, |
| strategy_combinations.mirrored_strategy_with_two_cpus, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ])) |
| def testOutputValuesForPerWorkerInputOptions(self, distribution, |
| input_options): |
| |
| def dataset_fn(input_context): |
| return dataset_ops.Dataset.from_tensor_slices( |
| np.arange(1, 11).reshape( |
| (2, 5)) * (input_context.input_pipeline_id + 1)) |
| |
| ds = distribution.experimental_distribute_datasets_from_function( |
| dataset_fn, input_options) |
| |
    # Validate the values.
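    # As in the test above, each of the two replicas receives one of the rows
    # [1, 2, 3, 4, 5] and [6, 7, 8, 9, 10].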
| x = next(iter(ds)) |
| assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5])) |
| assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10])) |
| |
| @combinations.generate( |
| combinations.combine( |
| input_options=[ |
| distribute_lib.InputOptions( |
| experimental_place_dataset_on_device=True, |
| experimental_fetch_to_device=False, |
| experimental_replication_mode=distribute_lib |
| .InputReplicationMode.PER_REPLICA), |
| distribute_lib.InputOptions( |
| experimental_place_dataset_on_device=False, |
| experimental_fetch_to_device=False, |
| experimental_replication_mode=distribute_lib |
| .InputReplicationMode.PER_REPLICA), |
| distribute_lib.InputOptions( |
| experimental_place_dataset_on_device=False, |
| experimental_fetch_to_device=True, |
| experimental_replication_mode=distribute_lib |
| .InputReplicationMode.PER_REPLICA), |
| ], |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.mirrored_strategy_with_two_gpus, |
| strategy_combinations |
| .mirrored_strategy_with_two_gpus_no_merge_call, |
| strategy_combinations.mirrored_strategy_with_two_cpus, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| ])) |
| def testOutputValuesForPerReplicaInputOptions(self, distribution, |
| input_options): |
| |
| def dataset_fn(input_context): |
| return dataset_ops.Dataset.from_tensor_slices( |
| np.arange(1, 10) * (input_context.input_pipeline_id + 1)) |
| |
| ds = distribution.experimental_distribute_datasets_from_function( |
| dataset_fn, input_options) |
| expected = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) |
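    # With PER_REPLICA replication, dataset_fn runs once per replica with its
    # own input_pipeline_id (0 and 1 here), so replica 0 should see 1..9 and
    # replica 1 the same values doubled.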
    loop_num = -1
    for i, x in enumerate(ds):
      # Validate the values.
      assert x.values[0].numpy() == expected[i]
      assert x.values[1].numpy() == expected[i] * 2
      loop_num = i
    assert loop_num == len(expected) - 1
| |
| |
| class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase, |
| parameterized.TestCase): |
| """Tests for distributed iterators which read from tf.data service.""" |
| |
| def setUp(self): |
| super(DistributedIteratorTfDataServiceTest, self).setUp() |
| self.num_workers = 3 |
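    # The tf.data service dispatcher and workers are started only in the main
    # process; multi-worker child processes pick up the dispatcher address
    # from combinations.env().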
| if combinations.in_main_process(): |
| self.dispatcher = server_lib.DispatchServer() |
| self.workers = [] |
| for _ in range(self.num_workers): |
| self.workers.append( |
| server_lib.WorkerServer( |
| server_lib.WorkerConfig( |
| dispatcher_address=self.dispatcher.target.split("://")[1], |
| heartbeat_interval_ms=100, |
| dispatcher_timeout_ms=1000))) |
| combinations.env().tf_data_service_dispatcher = self.dispatcher.target |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.one_device_strategy, |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.tpu_strategy, |
| strategy_combinations.central_storage_strategy_with_two_gpus, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call, |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| ])) |
| def testTfDataService(self, distribution): |
| worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] |
| input_workers = input_lib.InputWorkers(worker_device_pairs) |
| |
| dataset = dataset_ops.Dataset.range(1, 50) |
| dataset = dataset.apply( |
| data_service_ops._distribute( |
| processing_mode=data_service_ops.ShardingPolicy.OFF, |
| service=combinations.env().tf_data_service_dispatcher, |
| job_name="foo")) |
| |
| dist_dataset = input_util.get_distributed_dataset(dataset, input_workers, |
| distribution) |
| iterator = iter(dist_dataset) |
| results = [] |
| for element in iterator: |
| local_results = distribution.experimental_local_results(element) |
| for result in local_results: |
| # input_lib.distributed_dataset may add extra '0' elements to pad |
| # per-replica results. |
| if result.numpy() != 0: |
| results.append(result.numpy()) |
| self.assertNotEmpty(results) |
| gathered = distribution.gather(constant_op.constant(results), axis=0) |
| self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered) |
| |
| histogram_proto = ( |
| input_lib._distributed_dataset_initialization_time_milliseconds |
| .get_cell(distribution.__class__.__name__, "1").value()) |
| self.assertGreater(histogram_proto.num, 0.0) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.one_device_strategy, |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.tpu_strategy, |
| strategy_combinations.central_storage_strategy_with_two_gpus, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call, |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| ])) |
| def testDistributeDatasetFromFunction(self, distribution): |
| worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] |
| input_workers = input_lib.InputWorkers(worker_device_pairs) |
| input_contexts = [] |
| num_workers = input_workers.num_workers |
| for i in range(num_workers): |
| input_contexts.append(distribute_lib.InputContext( |
| num_input_pipelines=num_workers, |
| input_pipeline_id=i, |
| num_replicas_in_sync=num_workers)) |
| |
| dataset = dataset_ops.Dataset.range(1, 50) |
| dataset_id = "dataset_id" |
| # The body of this test is run on both the chief and the workers, so |
| # `register_dataset` will be called multiple times. We use a pre-defined |
| # `dataset_id` so that the tf.data service will understand that the |
| # registered datasets are the same. |
| data_service_ops.register_dataset( |
| service=combinations.env().tf_data_service_dispatcher, |
| dataset=dataset, dataset_id=dataset_id) |
| |
| def dataset_fn(input_context): |
| del input_context |
| return data_service_ops.from_dataset_id( |
| processing_mode=data_service_ops.ShardingPolicy.OFF, |
| service=combinations.env().tf_data_service_dispatcher, |
| dataset_id=dataset_id, |
| element_spec=dataset.element_spec, |
| job_name="shared_job") |
| |
| dist_dataset = input_util.get_distributed_datasets_from_function( |
| dataset_fn, input_workers, input_contexts, distribution) |
| |
| iterator = iter(dist_dataset) |
| results = [] |
| for element in iterator: |
| local_results = distribution.experimental_local_results(element) |
| for result in local_results: |
| # input_lib.distributed_dataset may add extra '0' elements to pad |
| # per-replica results. |
| if result.numpy() != 0: |
| results.append(result.numpy()) |
| self.assertNotEmpty(results) |
| gathered = distribution.gather(constant_op.constant(results), axis=0) |
| self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered) |
| histogram_proto = ( |
| input_lib |
| ._distributed_dataset_from_function_initialization_time_milliseconds |
| .get_cell(distribution.__class__.__name__, "1").value()) |
| self.assertGreater(histogram_proto.num, 0.0) |
| |
| @combinations.generate( |
| combinations.combine( |
| mode=["eager"], |
| distribution=[ |
| strategy_combinations.one_device_strategy, |
| strategy_combinations.mirrored_strategy_with_one_cpu, |
| strategy_combinations.mirrored_strategy_with_gpu_and_cpu, |
| strategy_combinations.mirrored_strategy_with_two_gpus, |
| strategy_combinations.tpu_strategy, |
| strategy_combinations.central_storage_strategy_with_two_gpus, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu, |
| strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call, |
| strategy_combinations.multi_worker_mirrored_2x1_cpu, |
| ])) |
| def testDistributeDatasetFromFunctionNested(self, distribution): |
| worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] |
| input_workers = input_lib.InputWorkers(worker_device_pairs) |
| input_contexts = [] |
| num_workers = input_workers.num_workers |
| for i in range(num_workers): |
| input_contexts.append( |
| distribute_lib.InputContext( |
| num_input_pipelines=num_workers, |
| input_pipeline_id=i, |
| num_replicas_in_sync=num_workers)) |
| |
| class InnerType(extension_type.ExtensionType): |
| tensor: ops.Tensor |
| |
| class OuterType(extension_type.ExtensionType): |
| inner: InnerType |
| |
| def dataset_fn(input_context): |
| del input_context |
| |
| def data_fn(batch_id) -> OuterType: |
| del batch_id |
| |
| return OuterType( |
| inner=InnerType(tensor=constant_op.constant([[0., 1.], [2., 3.]]))) |
| |
| return dataset_ops.Dataset.range(1, 10).map(data_fn) |
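    # Every element is the same nested ExtensionType value; range(1, 10) only
    # controls the number of elements (9 per input pipeline).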
| |
| dist_dataset = input_util.get_distributed_datasets_from_function( |
| dataset_fn, input_workers, input_contexts, distribution) |
| |
| iterator = iter(dist_dataset) |
| results = [] |
| for element in iterator: |
| local_results = distribution.experimental_local_results(element) |
| for result in local_results: |
| results.append(result) |
| |
| expect_component = OuterType( |
| inner=InnerType(tensor=constant_op.constant([[0., 1.], [2., 3.]]))) |
| self.assertCountEqual( |
| num_workers * [expect_component for _ in range(1, 10)], results) |
| |

if __name__ == "__main__":
| test_util.main() |