blob: 38b8c95529af1a9a46f41067e6714c0a22c75f3e [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the input_lib library."""
import collections
from absl.testing import parameterized
import numpy as np
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops.options import AutoShardPolicy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import extension_type
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util as framework_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
from tensorflow.python.util import nest
class DistributedIteratorTestBase(test.TestCase):
  """Shared helpers for distributed input (dataset/iterator) tests.

  Subclasses describe a worker/device layout and the values every replica is
  expected to see at each step, then call `_test_input_iteration`, which
  wraps the input, iterates it and compares the results.
  """

  # The passed input_context is to create a sharded dataset in between-graph
  # case.
  # TODO(yuefengz): rewrite the following method to reduce duplication.
  def _wrap_iterator(self,
                     input_type,
                     dataset_or_input_fn,
                     input_workers,
                     devices,
                     num_replicas_in_sync,
                     strategy,
                     input_context=None):
    """Wraps `dataset_or_input_fn` into a V1 distributed iterator.

    Args:
      input_type: "input_fn" builds an `InputFunctionIterator`; any other
        value builds a `DatasetIterator`.
      dataset_or_input_fn: an input function or a dataset, per `input_type`.
      input_workers: `input_lib.InputWorkers` describing where input lives.
      devices: flat list of replica devices; its length is used as
        `num_replicas_in_sync` of the per-worker `InputContext`s.
      num_replicas_in_sync: forwarded to `DatasetIterator`.
      strategy: the distribution strategy under test.
      input_context: optional `InputContext` used to shard the dataset. Must
        be None when `input_type` is "input_fn".

    Returns:
      The wrapped iterator.
    """
    # The `input_context` passed in is to shard dataset for
    # MultiWorkerMirroredStrategy. It doesn't apply to in-graph case where
    # multiple InputContexts are needed.
    if input_type == "input_fn":
      self.assertIsNone(
          input_context,
          msg=("The `input_context` arg is only used to shard dataset in "
               "`MultiWorkerMirroredStrategy` when the input type is dataset."))
      input_contexts = [
          distribute_lib.InputContext(
              # Note: `input_workers.num_workers` is always 1 in between-graph
              # case.
              num_input_pipelines=input_workers.num_workers,
              input_pipeline_id=i,
              num_replicas_in_sync=len(devices))
          for i in range(input_workers.num_workers)
      ]
      return input_lib_v1.InputFunctionIterator(dataset_or_input_fn,
                                                input_workers, input_contexts,
                                                strategy)
    return input_lib_v1.DatasetIterator(
        dataset_or_input_fn,
        input_workers,
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context)

  def _wrap_dataset(self,
                    input_type,
                    dataset,
                    input_workers,
                    num_replicas_in_sync,
                    strategy,
                    input_context=None):
    """Wraps `dataset` into a distributed dataset.

    For `input_type == "dataset"` this builds an `input_lib.DistributedDataset`
    when TF2 behavior is enabled and an `input_lib_v1.DistributedDatasetV1`
    otherwise. Any other `input_type` treats `dataset` as an input function
    and hands it to `strategy.distribute_datasets_from_function`; in that case
    `input_workers`, `num_replicas_in_sync` and `input_context` are unused.
    """
    if input_type == "dataset":
      if tf2.enabled():
        return input_lib.DistributedDataset(
            input_workers,
            strategy,
            dataset,
            num_replicas_in_sync=num_replicas_in_sync,
            input_context=input_context)
      else:
        return input_lib_v1.DistributedDatasetV1(
            dataset,
            input_workers,
            strategy,
            num_replicas_in_sync=num_replicas_in_sync,
            input_context=input_context)
    else:
      return strategy.distribute_datasets_from_function(dataset)

  def _assert_iterator_values(self,
                              iterator,
                              expected_values,
                              evaluate_fn,
                              devices,
                              enable_get_next_as_optional=False):
    """Pulls `len(expected_values)` steps from `iterator` and checks them.

    All steps are fetched before any comparison, so an exhausted iterator
    raises from the fetch (e.g. `OutOfRangeError`) rather than failing an
    assertion.

    Args:
      iterator: the distributed iterator to read from.
      expected_values: per-step list of per-replica expected values.
      evaluate_fn: callable turning a list of per-replica tensors into values.
      devices: flat list of replica devices; one value is selected per device.
      enable_get_next_as_optional: if True, fetch with
        `get_next_as_optional().get_value()` instead of `get_next()`.
    """
    actual_values = []
    for _ in range(len(expected_values)):
      if enable_get_next_as_optional:
        next_element = iterator.get_next_as_optional().get_value()
      else:
        next_element = iterator.get_next()
      computed_value = evaluate_fn([
          distribute_utils.select_replica(r, next_element)
          for r in range(len(devices))
      ])
      actual_values.append(computed_value)
    for expected_value, actual_value in zip(expected_values, actual_values):
      for expected, actual in zip(expected_value, actual_value):
        self.assertAllEqual(expected, actual)

  def _assert_dataset_values_for_loop(self, dataset, expected_values,
                                      evaluate_fn, devices):
    """Iterates `dataset` with a Python for-loop and checks every step.

    Args:
      dataset: the distributed dataset to iterate.
      expected_values: per-step list of per-replica expected values.
      evaluate_fn: callable turning a list of per-replica tensors into values.
      devices: flat list of replica devices; one value is selected per device.
    """
    actual_values = []
    for x in dataset:
      computed_value = evaluate_fn(
          [distribute_utils.select_replica(r, x) for r in range(len(devices))])
      actual_values.append(computed_value)
    # `zip` below silently drops unmatched steps, so pin the step count first;
    # otherwise a loop that yields too few (or no) elements would pass.
    self.assertLen(actual_values, len(expected_values))
    for expected_value, actual_value in zip(expected_values, actual_values):
      for expected, actual in zip(expected_value, actual_value):
        self.assertAllEqual(expected, actual)

  def _test_input_iteration(self,
                            input_type,
                            api_type,
                            iteration_type,
                            dataset_or_input_fn,
                            worker_device_pairs,
                            expected_values,
                            strategy,
                            sess=None,
                            num_replicas_in_sync=None,
                            input_context=None):
    """Wraps the input, iterates it and checks the per-replica values.

    Unsupported combinations of the arguments are skipped via `skipTest`.

    Args:
      input_type: "input_fn" or "dataset".
      api_type: "wrap_into_iterator" (V1 iterator classes) or
        "wrap_into_dataset" (distributed dataset).
      iteration_type: "get_next" (explicit `get_next` calls, plus
        `get_next_as_optional` under TF2) or "for_loop" (Python iteration,
        eager only).
      dataset_or_input_fn: the dataset or input function to distribute.
      worker_device_pairs: list of `(worker_device, [replica_devices])`.
      expected_values: per-step list of per-replica expected values.
      strategy: the distribution strategy under test.
      sess: optional session; when given, values are evaluated with
        `sess.run` instead of `self.evaluate`.
      num_replicas_in_sync: forwarded to the wrapping helpers.
      input_context: optional `InputContext` used to shard the dataset.
    """
    if iteration_type == "for_loop" and not context.executing_eagerly():
      self.skipTest("unsupported test combination.")
    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
      self.skipTest("unsupported test combination.")
    if api_type == "wrap_into_iterator" and input_type == "input_fn":
      self.skipTest("unsupported test combination.")

    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    if api_type == "wrap_into_iterator":
      iterator = self._wrap_iterator(
          input_type,
          dataset_or_input_fn,
          input_workers,
          devices,
          num_replicas_in_sync,
          strategy,
          input_context=input_context)
    else:
      # wrapping into a dataset:
      dataset = self._wrap_dataset(
          input_type,
          dataset_or_input_fn,
          input_workers,
          num_replicas_in_sync,
          strategy,
          input_context=input_context)

      if ops.executing_eagerly_outside_functions():
        iterator = iter(dataset)
      else:
        if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
          iterator = dataset.make_initializable_iterator()
        else:
          self.skipTest("unsupported test combination")

    # A composite-tensor iterator must match the structure of its own spec.
    if isinstance(iterator, composite_tensor.CompositeTensor):
      nest.assert_same_structure(
          iterator, iterator._type_spec, expand_composites=True)

    if iteration_type == "get_next":
      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
      if not ops.executing_eagerly_outside_functions():
        evaluate(control_flow_ops.group(iterator.initializer))

      def test_get_next(iterator):
        # Exactly len(expected_values) steps, then OutOfRangeError.
        self._assert_iterator_values(iterator, expected_values, evaluate,
                                     devices)

        with self.assertRaises(errors.OutOfRangeError):
          self._assert_iterator_values(iterator, expected_values, evaluate,
                                       devices)

        # After re-initializing the iterator, should be able to iterate again.
        if not ops.executing_eagerly_outside_functions():
          evaluate(control_flow_ops.group(iterator.initializer))
        else:
          if api_type == "wrap_into_iterator":
            self.skipTest("unsupported test combination")
          else:
            iterator = iter(dataset)

        self._assert_iterator_values(iterator, expected_values, evaluate,
                                     devices)

      def test_get_next_as_optional(iterator):
        self._assert_iterator_values(
            iterator,
            expected_values,
            evaluate,
            devices,
            enable_get_next_as_optional=True)

        # Once exhausted, the optional is empty and reading its value fails.
        next_element = iterator.get_next_as_optional()
        self.assertFalse(evaluate(next_element.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
          self._assert_iterator_values(
              iterator, [0],
              evaluate,
              devices,
              enable_get_next_as_optional=True)

      test_get_next(iterator)

      # re-initializing the iterator
      if not tf2.enabled():
        # TODO(yuefengz): we should split this function.
        return
      else:
        if api_type == "wrap_into_iterator":
          return
        else:
          iterator = iter(dataset)

      test_get_next_as_optional(iterator)

    if iteration_type == "for_loop" and context.executing_eagerly():
      self._assert_dataset_values_for_loop(dataset, expected_values,
                                           self.evaluate, devices)

  def _create_dataset_or_input_fn(self, input_type, input_fn):
    """Returns `input_fn` itself, or the dataset it builds for "dataset"."""
    if input_type == "input_fn":
      return input_fn
    else:
      return input_fn(distribute_lib.InputContext())
class DistributedIteratorTest(DistributedIteratorTestBase,
parameterized.TestCase):
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu
        ]))
def testMultiDeviceIterInitialize(self, distribution):
  """Runs a V1 one-shot distributed iterator's initializer in a tf.function."""
  if tf2.enabled():
    self.skipTest("Only V1 is supported.")

  replica_devices = ["/device:GPU:0", "/device:CPU:0"]
  input_workers = input_lib.InputWorkers([("/device:CPU:0", replica_devices)])

  def dataset_fn(_):
    return dataset_ops.DatasetV1.range(10)

  source_dataset = dataset_fn(distribute_lib.InputContext())
  dist_dataset = input_util.get_distributed_dataset(source_dataset,
                                                    input_workers,
                                                    distribution)
  one_shot_iterator = dataset_ops.make_one_shot_iterator(dist_dataset)

  @def_function.function
  def init_func_for_iter():
    self.evaluate(one_shot_iterator.initializer)

  init_func_for_iter()
@combinations.generate(
    combinations.combine(
        mode=["graph", "eager"],
        input_type=["input_fn", "dataset"],
        api_type=["wrap_into_iterator", "wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"],
        distribution=[
            strategy_combinations.one_device_strategy,
            strategy_combinations.mirrored_strategy_with_one_cpu,
        ],
        enable_get_next_as_optional=[True, False]))
def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
                     enable_get_next_as_optional):
  """Iterates range(10) on a single CPU replica, one element per step."""

  def dataset_fn(_):
    return dataset_ops.Dataset.range(10)

  source = self._create_dataset_or_input_fn(input_type, dataset_fn)
  per_step_values = [[step] for step in range(10)]
  distribution.extended.experimental_enable_get_next_as_optional = (
      enable_get_next_as_optional)
  self._test_input_iteration(
      input_type,
      api_type,
      iteration_type,
      source,
      [("/device:CPU:0", ["/device:CPU:0"])],
      per_step_values,
      distribution)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        input_type=["input_fn", "dataset"],
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_one_gpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
    )
)
def testAutoShardExplicit(self, input_type, distribution):
  """Applies `input_ops.auto_shard_dataset` to a distributed dataset.

  `range(10).batch(1)` is split into two shards. Shard 0 must yield the even
  values and shard 1 the odd values. Each step collects one element per
  local replica.
  """
  worker_device_pairs = [(
      "/device:CPU:0",
      distribution.extended.worker_devices,
  )]
  dataset_fn = lambda _: dataset_ops.Dataset.range(10).batch(1)
  dataset_or_input_fn = self._create_dataset_or_input_fn(
      input_type, dataset_fn
  )
  input_workers = input_lib.InputWorkers(worker_device_pairs)
  distribution.extended.experimental_enable_get_next_as_optional = True
  dataset = self._wrap_dataset(
      input_type,
      dataset_or_input_fn,
      input_workers,
      num_replicas_in_sync=None,
      strategy=distribution)

  # Indexed by shard: the list of per-step values, with each step's local
  # results concatenated.
  if len(distribution.extended.worker_devices) == 2:
    expected_per_shard = [
        [[0, 2], [4, 6], [8]],
        [[1, 3], [5, 7], [9]],
    ]
  else:
    expected_per_shard = [
        [[0], [2], [4], [6], [8]],
        [[1], [3], [5], [7], [9]],
    ]

  for shard_index, expected_values in enumerate(expected_per_shard):
    sharded_dataset = input_ops.auto_shard_dataset(dataset, 2, shard_index)
    actual_values = []
    for element in sharded_dataset:
      local = distribution.experimental_local_results(element)
      actual_values.append(array_ops.concat(local, axis=0).numpy().tolist())
    # Compare the whole list so that missing or extra steps fail the test.
    # A zip over the iterator would silently drop them.
    self.assertEqual(actual_values, expected_values)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        input_type=["input_fn", "dataset"],
        api_type=["wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"],
        distribution=[strategy_combinations.multi_worker_mirrored_2x1_cpu],
        enable_get_next_as_optional=[True, False]))
def testOneDeviceCPUMultiWorker(self, input_type, api_type, iteration_type,
                                distribution, enable_get_next_as_optional):
  """Iterates range(10) with one CPU replica on this worker."""
  cpu = "/device:CPU:0"

  def dataset_fn(_):
    return dataset_ops.DatasetV1.range(10)

  source = self._create_dataset_or_input_fn(input_type, dataset_fn)
  distribution.extended.experimental_enable_get_next_as_optional = (
      enable_get_next_as_optional)
  self._test_input_iteration(
      input_type,
      api_type,
      iteration_type,
      source,
      [(cpu, [cpu])],
      [[value] for value in range(10)],
      distribution)
@combinations.generate(
    combinations.combine(
        mode=["graph", "eager"],
        input_type=["input_fn", "dataset"],
        api_type=["wrap_into_iterator", "wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"],
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu
        ],
        enable_get_next_as_optional=[True, False]))
def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type,
                               distribution, enable_get_next_as_optional):
  """Splits range(10) across a GPU replica and a CPU replica."""
  replica_devices = ["/device:GPU:0", "/device:CPU:0"]
  worker_device_pairs = [("/device:CPU:0", replica_devices)]

  def dataset_fn(_):
    return dataset_ops.Dataset.range(10)

  source = self._create_dataset_or_input_fn(input_type, dataset_fn)
  # Each step: (GPU replica value, CPU replica value).
  per_step_values = [[first, first + 1] for first in range(0, 10, 2)]
  distribution.extended.experimental_enable_get_next_as_optional = (
      enable_get_next_as_optional)
  self._test_input_iteration(input_type, api_type, iteration_type, source,
                             worker_device_pairs, per_step_values,
                             distribution)
@combinations.generate(
    combinations.combine(
        mode=["graph", "eager"],
        input_type=["input_fn", "dataset"],
        api_type=["wrap_into_iterator", "wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"],
        distribution=[strategy_combinations.tpu_strategy],
        enable_get_next_as_optional=[True, False]))
def testTPU(self, input_type, api_type, iteration_type, distribution,
            enable_get_next_as_optional):
  """Iterates range(10) on TPU replicas grouped by their host device."""
  devices_by_host = collections.OrderedDict()
  for device in distribution.extended.worker_devices:
    host = device_util.get_host_for_device(device)
    devices_by_host.setdefault(host, []).append(device)

  def dataset_fn(_):
    return dataset_ops.Dataset.range(10)

  source = self._create_dataset_or_input_fn(input_type, dataset_fn)
  per_step_values = [[first, first + 1] for first in range(0, 10, 2)]
  distribution.extended.experimental_enable_get_next_as_optional = (
      enable_get_next_as_optional)
  self._test_input_iteration(input_type, api_type, iteration_type, source,
                             devices_by_host.items(), per_step_values,
                             distribution)
@combinations.generate(
    combinations.combine(
        mode=["graph", "eager"],
        input_type=["input_fn", "dataset"],
        api_type=["wrap_into_iterator", "wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"],
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
        ],
        enable_get_next_as_optional=[True, False]))
def testTupleDataset(self, input_type, api_type, iteration_type, distribution,
                     enable_get_next_as_optional):
  """Distributes (x, x**2) tuples across a GPU and a CPU replica."""
  worker_device_pairs = [
      ("/device:CPU:0", ["/device:GPU:0", "/device:CPU:0"])]

  def dataset_fn(ctx):
    del ctx  # Unused.
    numbers = dataset_ops.Dataset.range(10)
    squares = dataset_ops.Dataset.range(10).map(lambda x: x**2)
    return dataset_ops.Dataset.zip((numbers, squares))

  source = self._create_dataset_or_input_fn(input_type, dataset_fn)
  per_step_values = []
  for first in range(0, 10, 2):
    second = first + 1
    per_step_values.append([(first, first**2), (second, second**2)])
  distribution.extended.experimental_enable_get_next_as_optional = (
      enable_get_next_as_optional)
  self._test_input_iteration(input_type, api_type, iteration_type, source,
                             worker_device_pairs, per_step_values,
                             distribution)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        input_type=["input_fn", "dataset"],
        api_type=["wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"],
        distribution=[
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
        ],
        enable_get_next_as_optional=[True, False]))
def testTupleDatasetMultiworker(self, input_type, api_type, iteration_type,
                                distribution, enable_get_next_as_optional):
  """Distributes (x, x**2) tuples across two GPU replicas on this worker."""
  worker_device_pairs = [
      ("/device:CPU:0", ["/device:GPU:0", "/device:GPU:1"])]

  def dataset_fn(ctx):
    del ctx  # Unused.
    numbers = dataset_ops.Dataset.range(10)
    squares = dataset_ops.Dataset.range(10).map(lambda x: x**2)
    return dataset_ops.Dataset.zip((numbers, squares))

  source = self._create_dataset_or_input_fn(input_type, dataset_fn)
  per_step_values = []
  for first in range(0, 10, 2):
    second = first + 1
    per_step_values.append([(first, first**2), (second, second**2)])
  distribution.extended.experimental_enable_get_next_as_optional = (
      enable_get_next_as_optional)
  # No input_context is passed in, so the dataset is not sharded.
  self._test_input_iteration(input_type, api_type, iteration_type, source,
                             worker_device_pairs, per_step_values,
                             distribution)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.one_device_strategy,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
        ]))
def testIterableIterator(self, distribution):
  """A distributed iterator can be consumed with a Python for-loop.

  Each of the 10 elements of range(10) must come out in order, one per step.
  """
  worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
  input_workers = input_lib.InputWorkers(worker_device_pairs)

  dataset = dataset_ops.Dataset.range(10)
  dist_dataset = input_util.get_distributed_dataset(dataset, input_workers,
                                                    distribution)

  iterator = iter(dist_dataset)
  num_elements = 0
  for i, element in enumerate(iterator):
    self.assertAllEqual(distribution.experimental_local_results(element), [i])
    num_elements += 1
  # Without this, an iterator that yields nothing would pass vacuously.
  self.assertEqual(num_elements, 10)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_two_cpus,
        ],
        use_iterator=[False, True]))
def testIteratorAndDatasetEnumerateError(self, distribution, use_iterator):
  """`enumerate` over a distributed dataset/iterator fails in tf.function."""
  # enumerate is not supported within tf.function for these types.
  dist_dataset = distribution.experimental_distribute_dataset(
      dataset_ops.Dataset.range(10).batch(2))
  iterable = iter(dist_dataset) if use_iterator else dist_dataset

  @def_function.function
  def enumerate_fn(iterable):
    for _, batch in enumerate(iterable):
      distribution.experimental_local_results(batch)

  with self.assertRaises(NotImplementedError):
    enumerate_fn(iterable)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_two_cpus,
        ]))
def testIterableIteratorError(self, distribution):
  """Checks `next()` on a distributed iterator inside and outside a scope.

  Calling `next` inside `distribution.run` without the strategy scope raises
  `ValueError`. Inside `distribution.scope()` it returns the first batch, and
  plain Python iteration yields every batch.
  """
  dataset = dataset_ops.Dataset.range(10).batch(2)
  dist_dataset = distribution.experimental_distribute_dataset(dataset)

  iterator = iter(dist_dataset)
  # Raises error when next(iterator) is called without strategy scope
  with self.assertRaises(ValueError):

    def replica_fn1(iterator):
      return next(iterator)

    distribution.run(replica_fn1, args=(iterator,))

  if distribution.num_replicas_in_sync == 1:
    expected_result = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8, 9]]]
  elif distribution.num_replicas_in_sync == 2:
    expected_result = [[[0], [1]], [[2], [3]], [[4], [5]], [[6], [7]],
                       [[8], [9]]]
  else:
    # Fail explicitly instead of with a NameError on `expected_result`.
    self.fail("Unexpected num_replicas_in_sync: %d" %
              distribution.num_replicas_in_sync)

  with distribution.scope():

    def replica_fn2(iterator):
      return iterator

    result = distribution.run(replica_fn2, args=(next(iterator),))
    self.assertAllEqual(
        distribution.experimental_local_results(result), expected_result[0])

  # Confirm default ReplicaContext also works
  iterator = iter(dist_dataset)
  num_batches = 0
  for i, element in enumerate(iterator):
    self.assertAllEqual(
        distribution.experimental_local_results(element), expected_result[i])
    num_batches += 1
  # Guard against the loop passing vacuously on a short iterator.
  self.assertLen(expected_result, num_batches)
@combinations.generate(
    combinations.combine(
        mode=["graph", "eager"],
        input_type=["input_fn", "dataset"],
        api_type=["wrap_into_iterator", "wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"],
        drop_remainder=[True, False],
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu
        ]))
def testUnevenDatasetBatches(self, input_type, api_type, iteration_type,
                             drop_remainder, distribution):
  """range(9).batch(2) on two replicas, with and without drop_remainder."""
  worker_device_pairs = [
      ("/device:CPU:0", ["/device:GPU:0", "/device:CPU:0"])]

  def dataset_fn(_):
    return dataset_ops.Dataset.range(9).batch(
        2, drop_remainder=drop_remainder)

  source = self._create_dataset_or_input_fn(input_type, dataset_fn)
  per_step_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
  if not drop_remainder:
    # The last global batch only contains data for one replica.
    per_step_values.append([[8], []])
  distribution.extended.experimental_enable_get_next_as_optional = True
  self._test_input_iteration(input_type, api_type, iteration_type, source,
                             worker_device_pairs, per_step_values,
                             distribution)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        input_type=["input_fn", "dataset"],
        api_type=["wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"],
        drop_remainder=[True, False],
        distribution=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ]))
def testUnevenDatasetBatchesMultiWorker(self, input_type, api_type,
                                        iteration_type, drop_remainder,
                                        distribution):
  """Uneven final batch of range(9) across two workers, one replica each.

  The values checked are the ones seen by this task; they depend on its
  `id_in_cluster`.
  """
  # Actual devices don't matter in this test as long as the number of global
  # replicas is 2.
  worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
  cr = distribution.cluster_resolver
  self.assertIsNotNone(cr)
  worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
                                                cr.task_type)
  id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                  cr.task_type, cr.task_id)

  def dataset_fn(_):
    dataset = dataset_ops.Dataset.range(9)

    if input_type == "input_fn":
      # When input_fn is used, there is no automatic rebatching and sharding,
      # so we add them here.
      return dataset.shard(worker_count, id_in_cluster).batch(1)
    else:
      return dataset.batch(2, drop_remainder=drop_remainder)

  dataset_or_input_fn = self._create_dataset_or_input_fn(
      input_type, dataset_fn)

  if drop_remainder and input_type == "dataset":
    if id_in_cluster == 0:
      expected_values = [[[0]], [[2]], [[4]], [[6]]]
    else:
      expected_values = [[[1]], [[3]], [[5]], [[7]]]
  else:
    # The last global batch only contains data for one replica.
    if id_in_cluster == 0:
      expected_values = [[[0]], [[2]], [[4]], [[6]], [[8]]]
    else:
      expected_values = [[[1]], [[3]], [[5]], [[7]], [[]]]
  distribution.extended.experimental_enable_get_next_as_optional = True
  self._test_input_iteration(
      input_type,
      api_type,
      iteration_type,
      dataset_or_input_fn,
      worker_device_pairs,
      expected_values,
      distribution,
      num_replicas_in_sync=distribution.num_replicas_in_sync,
      input_context=distribution.extended._make_input_context())
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        input_type=["input_fn", "dataset"],
        api_type=["wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"],
        drop_remainder=[True, False],
        distribution=[
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
        ]))
def testUnevenDatasetBatchesMultiWorkerFourReplicas(self, input_type,
                                                    api_type, iteration_type,
                                                    drop_remainder,
                                                    distribution):
  """Uneven final batch of range(15) across two workers, two replicas each.

  The values checked are the ones seen by this task's two replicas; they
  depend on its `id_in_cluster`.
  """
  # Actual devices don't matter in this test as long as the number of global
  # replicas is 4 (two workers with two replicas each).
  worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                            "/device:GPU:1"])]
  cr = distribution.cluster_resolver
  self.assertIsNotNone(cr)
  worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
                                                cr.task_type)
  id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                  cr.task_type, cr.task_id)

  def dataset_fn(_):
    dataset = dataset_ops.Dataset.range(15)

    if input_type == "input_fn":
      # When input_fn is used, there is no automatic rebatching and sharding,
      # so we add them here.
      return dataset.shard(worker_count, id_in_cluster).batch(1)
    else:
      return dataset.batch(4, drop_remainder=drop_remainder)

  dataset_or_input_fn = self._create_dataset_or_input_fn(
      input_type, dataset_fn)

  if drop_remainder and input_type == "dataset":
    if id_in_cluster == 0:
      expected_values = [[[0], [2]], [[4], [6]], [[8], [10]]]
    else:
      expected_values = [[[1], [3]], [[5], [7]], [[9], [11]]]
  else:
    # The last global batch is partial: three of the four replicas receive
    # data and one (on worker 1) receives an empty batch.
    if id_in_cluster == 0:
      expected_values = [[[0], [2]], [[4], [6]], [[8], [10]], [[12], [14]]]
    else:
      expected_values = [[[1], [3]], [[5], [7]], [[9], [11]], [[13], []]]
  distribution.extended.experimental_enable_get_next_as_optional = True
  self._test_input_iteration(
      input_type,
      api_type,
      iteration_type,
      dataset_or_input_fn,
      worker_device_pairs,
      expected_values,
      distribution,
      num_replicas_in_sync=distribution.num_replicas_in_sync,
      input_context=distribution.extended._make_input_context())
@combinations.generate(
    combinations.combine(
        mode=["graph", "eager"],
        input_type=["dataset"],
        api_type=["wrap_into_iterator", "wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"],
        num_replicas_in_sync=[None, 2],
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu
        ],
        enable_get_next_as_optional=[True, False]))
def testBatchSplitting(self, input_type, api_type, iteration_type,
                       num_replicas_in_sync, distribution,
                       enable_get_next_as_optional):
  """Global batches of 10 are split per replica when num_replicas_in_sync set.

  With `num_replicas_in_sync=None` each replica receives a whole batch.
  """
  worker_device_pairs = [
      ("/device:CPU:0", ["/device:GPU:0", "/device:CPU:0"])]
  batch_size = 10

  def dataset_fn(_):
    return dataset_ops.Dataset.range(100).batch(batch_size)

  source = self._create_dataset_or_input_fn(input_type, dataset_fn)

  if num_replicas_in_sync:
    per_replica_batch = batch_size // num_replicas_in_sync
  else:
    per_replica_batch = batch_size
  per_step_values = []
  for start in range(0, 100, 2 * per_replica_batch):
    middle = start + per_replica_batch
    per_step_values.append(
        [range(start, middle), range(middle, middle + per_replica_batch)])

  distribution.extended.experimental_enable_get_next_as_optional = (
      enable_get_next_as_optional)
  self._test_input_iteration(
      input_type,
      api_type,
      iteration_type,
      source,
      worker_device_pairs,
      per_step_values,
      distribution,
      sess=None,
      num_replicas_in_sync=num_replicas_in_sync)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        input_type=["dataset"],
        api_type=["wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"],
        num_replicas_in_sync=[None, 2],
        distribution=[
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
        ],
        enable_get_next_as_optional=[True, False]))
def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type,
                                  num_replicas_in_sync, distribution,
                                  enable_get_next_as_optional):
  """Batch splitting on this worker's two GPU replicas in a multi-worker job."""
  worker_device_pairs = [
      ("/device:CPU:0", ["/device:GPU:0", "/device:GPU:1"])]
  batch_size = 10
  resolver = distribution.cluster_resolver
  self.assertIsNotNone(resolver)

  def dataset_fn(_):
    return dataset_ops.Dataset.range(100).batch(batch_size)

  source = self._create_dataset_or_input_fn(input_type, dataset_fn)

  if num_replicas_in_sync:
    per_replica_batch = batch_size // num_replicas_in_sync
  else:
    per_replica_batch = batch_size
  per_step_values = []
  for start in range(0, 100, 2 * per_replica_batch):
    middle = start + per_replica_batch
    per_step_values.append(
        [range(start, middle), range(middle, middle + per_replica_batch)])

  distribution.extended.experimental_enable_get_next_as_optional = (
      enable_get_next_as_optional)
  self._test_input_iteration(
      input_type,
      api_type,
      iteration_type,
      source,
      worker_device_pairs,
      per_step_values,
      distribution,
      sess=None,
      num_replicas_in_sync=num_replicas_in_sync)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.one_device_strategy,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.tpu_strategy,
            strategy_combinations.central_storage_strategy_with_two_gpus,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
        ],
    ))
def testCacheAcrossIteration(self, distribution):
  """A cache() placed after shuffle() makes two epochs produce the same data."""
  if not tf2.enabled():
    self.skipTest("Only V2 is supported.")

  dist_dataset = distribution.experimental_distribute_dataset(
      dataset_ops.Dataset.range(16).shuffle(16).cache().batch(4))

  def run_epoch():
    return [distribution.experimental_local_results(x) for x in dist_dataset]

  first_epoch = run_epoch()
  second_epoch = run_epoch()
  self.assertAllEqual(first_epoch, second_epoch)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.one_device_strategy,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.tpu_strategy,
            strategy_combinations.central_storage_strategy_with_two_gpus,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
        ],
        reshuffle=[True, False]))
def testShuffleAcrossIterations(self, distribution, reshuffle):
  """Epoch order differs iff `reshuffle_each_iteration` is enabled."""
  if not tf2.enabled():
    self.skipTest("Only V2 is supported.")

  shuffled = dataset_ops.Dataset.range(12).shuffle(
      12, reshuffle_each_iteration=reshuffle)
  dist_dataset = distribution.experimental_distribute_dataset(
      shuffled.batch(4))

  def run_epoch():
    return [distribution.experimental_local_results(x) for x in dist_dataset]

  first_epoch = run_epoch()
  second_epoch = run_epoch()
  compare = self.assertNotAllEqual if reshuffle else self.assertAllEqual
  compare(first_epoch, second_epoch)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.one_device_strategy,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.tpu_strategy,
            strategy_combinations.central_storage_strategy_with_two_gpus,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
        ]))
def testGetNextOptionalShapeFinite(self, distribution):
  """Checks per-replica static shapes when looping over a finite dataset."""
  batch_size = 8
  dataset = dataset_ops.DatasetV2.from_tensor_slices({
      "feature": array_ops.ones([batch_size, 10]),
      "label": array_ops.ones([batch_size]),
  })
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dist_dataset = distribution.experimental_distribute_dataset(dataset)

  @def_function.function
  def train_fn():
    for data in dist_dataset:
      data = nest.map_structure(distribution.experimental_local_results, data)
      feature = data["feature"]
      label = data["label"]

      # Only the non-batch dimensions stay static on every replica. The batch
      # dimension is unknown (None) even though drop_remainder=True was used,
      # presumably because iterating a finite distributed dataset may yield
      # partial per-replica batches at its end.
      for replica_id in range(len(distribution.extended.worker_devices)):
        self.assertEqual([None, 10],
                         feature[replica_id].shape.as_list())
        self.assertEqual([None], label[replica_id].shape.as_list())

  train_fn()
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.one_device_strategy,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.tpu_strategy,
            strategy_combinations.central_storage_strategy_with_two_gpus,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
        ]))
def testGetNextOptionalShapeInfinite(self, distribution):
  """get_next_as_optional on a repeated dataset keeps fully static shapes."""
  batch_size = 8
  dataset = dataset_ops.DatasetV2.from_tensor_slices({
      "feature": array_ops.ones([batch_size, 10]),
      "label": array_ops.ones([batch_size]),
  }).batch(batch_size, drop_remainder=True).repeat()
  dist_dataset = distribution.experimental_distribute_dataset(dataset)
  per_replica_batch_size = batch_size // distribution.num_replicas_in_sync

  @def_function.function
  def train_fn():
    optional_element = iter(dist_dataset).get_next_as_optional()
    local_data = nest.map_structure(distribution.experimental_local_results,
                                    optional_element.get_value())
    features = local_data["feature"]
    labels = local_data["label"]

    # Assert the shapes are still static from all replicas.
    for replica_id in range(len(distribution.extended.worker_devices)):
      self.assertEqual([per_replica_batch_size, 10],
                       features[replica_id].shape.as_list())
      self.assertEqual([per_replica_batch_size],
                       labels[replica_id].shape.as_list())

  train_fn()
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.one_device_strategy,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.tpu_strategy,
            strategy_combinations.central_storage_strategy_with_two_gpus,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
        ]))
def testGetNextOptionalShapeEmpty(self, distribution):
  """Checks the specs of `get_next_as_optional` without reading its value.

  Only the returned optional's `element_spec` is inspected. Each replica's
  component spec is expected to carry the static per-replica batch size.
  """
  global_batch = 8
  dataset = (
      dataset_ops.DatasetV2.from_tensor_slices({
          "feature": array_ops.ones([global_batch, 10]),
          "label": array_ops.ones([global_batch]),
      })
      .batch(global_batch, drop_remainder=True)
      .repeat())
  dist_dataset = distribution.experimental_distribute_dataset(dataset)
  local_batch = global_batch // distribution.num_replicas_in_sync

  @def_function.function
  def train_fn():
    optional = iter(dist_dataset).get_next_as_optional()
    feature_specs = optional.element_spec["feature"]._component_specs
    label_specs = optional.element_spec["label"]._component_specs
    if not isinstance(feature_specs, tuple):
      # Wrap bare specs in 1-tuples so they can be indexed by replica id.
      feature_specs, label_specs = (feature_specs,), (label_specs,)
    for idx in range(len(distribution.extended.worker_devices)):
      self.assertEqual([local_batch, 10],
                       feature_specs[idx].shape.as_list())
      self.assertEqual([local_batch], label_specs[idx].shape.as_list())

  train_fn()
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
        ],
        input_type=["dataset"],
        api_type=["wrap_into_iterator", "wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"],
        auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.OFF]))
def testAutoshardingOption(self, distribution, input_type, api_type,
                           iteration_type, auto_shard_policy):
  """Checks that the dataset's `auto_shard_policy` option controls sharding."""
  resolver = distribution.cluster_resolver
  self.assertIsNotNone(resolver)
  worker_index = multi_worker_util.id_in_cluster(
      resolver.cluster_spec(), resolver.task_type, resolver.task_id)

  options = options_lib.Options()
  options.experimental_distribute.auto_shard_policy = auto_shard_policy
  dataset_or_input_fn = self._create_dataset_or_input_fn(
      input_type,
      lambda _: dataset_ops.Dataset.range(4).with_options(options))
  worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

  if auto_shard_policy != AutoShardPolicy.AUTO:
    # With AutoShardPolicy.OFF every worker sees the whole dataset.
    expected_values = [[0], [1], [2], [3]]
  elif worker_index == 0:
    expected_values = [[0], [2]]
  else:
    expected_values = [[1], [3]]

  self._test_input_iteration(
      input_type,
      api_type,
      iteration_type,
      dataset_or_input_fn,
      worker_device_pairs,
      expected_values,
      distribution,
      input_context=distribution.extended._make_input_context())
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
        ],
        input_type=["input_fn"],
        api_type=["wrap_into_dataset"],
        iteration_type=["get_next", "for_loop"]))
def testDifferentDatasetsMultiWorker(self, distribution, input_type, api_type,
                                     iteration_type):
  """Checks workers whose input pipelines yield different numbers of batches.

  Pipeline 0 reads `range(8)` and every other pipeline reads `range(9)`,
  both batched by 2. Worker 0 is expected to see an empty final batch,
  while the other worker sees `[8]`.
  """
  resolver = distribution.cluster_resolver
  self.assertIsNotNone(resolver)
  worker_index = multi_worker_util.id_in_cluster(
      resolver.cluster_spec(), resolver.task_type, resolver.task_id)

  def dataset_fn(ctx):
    num_elements = 8 if ctx.input_pipeline_id == 0 else 9
    return dataset_ops.Dataset.range(num_elements).batch(2)

  dataset_or_input_fn = self._create_dataset_or_input_fn(
      input_type, dataset_fn)
  worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

  shared_batches = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]]]
  final_batch = [[]] if worker_index == 0 else [[8]]
  expected_values = shared_batches + [final_batch]

  distribution.extended.experimental_enable_get_next_as_optional = True
  self._test_input_iteration(input_type, api_type, iteration_type,
                             dataset_or_input_fn, worker_device_pairs,
                             expected_values, distribution)
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
        mode=["eager"]))
def testLoopOverDatasetInTFFunction(self, strategy):
  """Accumulates squared values via a `for` loop inside a `tf.function`."""

  def square(x):
    return {"y": math_ops.cast(x, dtypes.float32) ** 2}

  dataset = dataset_ops.Dataset.range(10).map(square).batch(4)
  dist_dataset = strategy.experimental_distribute_dataset(dataset)

  with strategy.scope():
    total = variables.Variable(
        0.0, aggregation=variables.VariableAggregation.SUM)

  @def_function.function
  def iterator_fn(dist_dataset):

    def assign_add_fn(data):
      total.assign_add(math_ops.reduce_sum(data["y"]))

    for data in dist_dataset:
      strategy.run(assign_add_fn, args=(data,))

  iterator_fn(dist_dataset)
  # 0**2 + 1**2 + ... + 9**2 == 285.
  self.assertEqual(total.numpy(), 285.0)
class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
                                        parameterized.TestCase):
  """Tests for DistributedDataset with non-dense tensors."""

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          defun_type=["lambda", "tf_function"],
      ))
  def testRaggedSparse(self, distribution, input_type, drop_remainder,
                       defun_type):
    """Test with `RaggedTensor`s and `SparseTensor`s.

    NOTE: this test is currently skipped unconditionally (see the bug ids in
    the `skipTest` call below).
    """
    self.skipTest("b/213596871, b/214574707")

    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    defun = {
        "lambda": lambda f: f,
        "tf_function": def_function.function
    }[defun_type]
    distribution.extended.experimental_enable_get_next_as_optional = True
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          "dense": ragged_tensor.to_tensor(),
          "ragged": ragged_tensor,
          "sparse": ragged_tensor.to_sparse(),
      })
      dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
      return dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)
    dataset = self._wrap_dataset(input_type, dataset_or_input_fn,
                                 distribution.extended._input_workers,
                                 distribution.num_replicas_in_sync,
                                 distribution)
    # Assert that the tensors are rebatched and sparsity is preserved.
    per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
    self.assertAllEqual(
        distribute_utils.select_replica(0, per_replica_batch["dense"]),
        [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
    self.assertAllEqual(
        distribute_utils.select_replica(1, per_replica_batch["dense"]),
        [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
    # Transitively check the ragged and sparse tensors by densification.
    for i in range(2):
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["ragged"]).values,
          6)
      self.assertAllEqual(
          distribute_utils.select_replica(
              i, per_replica_batch["ragged"]).to_tensor(),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["sparse"]).indices,
          6)
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(
              distribute_utils.select_replica(i, per_replica_batch["sparse"])),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))

    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""

      def map_fn(per_replica_values):
        per_replica_sums = distribution.run(
            (lambda x: math_ops.reduce_sum(x.values)) if all(
                map(sparse_tensor.is_sparse, per_replica_values.values)) else
            math_ops.reduce_sum, (per_replica_values,))
        return distribution.reduce(
            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

      return nest.map_structure(map_fn, per_replica_features)

    def _reduce(state, batch):
      sums = sum_batch(batch)
      return {name: value + sums[name] for name, value in state.items()}

    def sum_for_loop(dataset):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      for batch in dataset:
        sums = _reduce(sums, batch)
      return sums

    def sum_while_loop(iterator, reduce_fn):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      while True:
        try:
          sums = reduce_fn(sums, iterator)
        except (StopIteration, errors.OutOfRangeError):
          return sums

    while_sums = sum_while_loop(
        iter(dataset),
        defun(lambda state, iterator: _reduce(state, next(iterator))))
    self.assertAllEqual(
        nest.flatten(while_sums),
        # When the partial batch is dropped, the sum is smaller.
        [200. if drop_remainder else 310.] * 3)
    for_sums = defun(sum_for_loop)(dataset)
    # For loops always call get next as optional inside tf functions, so we
    # expect 310 here when using an input function (as there are 5 batches of
    # size 4 round robined over 2 replicas).
    expected_for_sum = 200.
    if (not drop_remainder or
        (defun_type == "tf_function" and input_type == "input_fn")):
      expected_for_sum = 310.
    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          tensor_type=["sparse", "ragged"],
          enable_get_next_as_optional=[True, False]))
  def testRaggedSparseGetNextAsOptional(self, distribution, input_type,
                                        drop_remainder, tensor_type,
                                        enable_get_next_as_optional):
    """Checks the iterator's get-next-as-optional flag for ragged/sparse input.

    The iterator is expected to enable `get_next_as_optional` only when the
    strategy enables it and the dataset is batched without `drop_remainder`.
    """
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          tensor_type: (ragged_tensor if tensor_type == "ragged" else
                        ragged_tensor.to_sparse()),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
    else:
      ds = distribution.distribute_datasets_from_function(dataset_fn)
    iterator = iter(ds)

    self.assertEqual(iterator._enable_get_next_as_optional,
                     (not drop_remainder) and enable_get_next_as_optional)

  @combinations.generate(
      combinations.combine(
          tf_api_version=2,
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
      ))
  def testRaggedSparseGetNextAsOptionalInLoop(self, distribution, input_type,
                                              drop_remainder):
    """Sums dense/ragged/sparse batches with an explicit optional-based loop.

    The loop calls `get_next_as_optional` until the optional is empty, inside
    a `tf.function`, and checks the accumulated sums.
    """
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          "dense": ragged_tensor.to_tensor(),
          "ragged": ragged_tensor,
          "sparse": ragged_tensor.to_sparse(),
      })
      dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
      return dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
    else:
      ds = distribution.distribute_datasets_from_function(dataset_fn)

    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""

      def map_fn(per_replica_values):

        def _sum(value):
          if sparse_tensor.is_sparse(value):
            return math_ops.reduce_sum(value.values)
          else:
            return math_ops.reduce_sum(value)

        per_replica_sums = distribution.run(_sum, args=(per_replica_values,))
        return distribution.reduce(
            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

      return nest.map_structure(map_fn, per_replica_features)

    def _reduce(state, batch):
      sums = sum_batch(batch)
      return {name: value + sums[name] for name, value in state.items()}

    def sum_while_loop(ds):
      iterator = iter(ds)
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      try_next = constant_op.constant(True)

      while try_next:
        opt_iterate = iterator.get_next_as_optional()
        if opt_iterate.has_value():
          sums = _reduce(sums, opt_iterate.get_value())
        else:
          try_next = False
      return sums

    sums = def_function.function(sum_while_loop)(ds)
    # Summing all 20 rows gives 310. With `drop_remainder=True`, the `dataset`
    # input is built with a default `InputContext` and batched by the global
    # batch size (8), so its final partial batch of 4 rows is dropped and only
    # 16 rows (sum 200) remain. The `input_fn` input is batched by the
    # per-replica batch size instead and is expected to keep all rows (310).
    expected_for_sum = 200.
    if not drop_remainder or input_type == "input_fn":
      expected_for_sum = 310.
    self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testMWMSPartialBatch(self, input_type, api_type, iteration_type,
                           distribution):
    """Checks MWMS rebatching of a per-worker dataset with a partial batch."""
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when we have two files each with
    # 12 elements and a global batch size of 8. When we consider the dataset in
    # aggregate (non-distributed), there are 24 elements divided into 3 batches
    # of size 8. Hence, the correct distributed behavior is for each replica to
    # see sub-batches of size 4, over three steps.
    def dataset_fn(ctx):
      del ctx
      dataset = dataset_ops.Dataset.range(12).batch(8)

      # Set the sharding behavior to OFF for simplicity of test setup; namely,
      # `dataset` defines the per-worker dataset and will not be further
      # sharded. Each worker will see a dataset that is
      # tf.data.Dataset.range(12).batch(8).rebatch(...).
      options = options_lib.Options()
      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as there is 1 local
    # replica.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. Each worker should rebatch its dataset into
    # smaller batches of size 4.
    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type,
                                            iteration_type, distribution):
    """Checks MWMS rebatching when the global batch size cannot be inferred."""
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when we have two files each with
    # 12 elements and a global batch size of 8. When we consider the dataset in
    # aggregate (non-distributed), there are 24 elements divided into 3 batches
    # of size 8. Hence, the correct distributed behavior is for each replica to
    # see sub-batches of size 4, over three steps. However, when we create a
    # DistributedDataset and cannot statically infer the intended global batch
    # size (e.g. if the user does not use a batching dataset), each worker will
    # rebatch based on the dynamic batch size of the data encountered, even when
    # it encounters partial batches. The last per-worker partial batch (size 4)
    # ends up being split into two replicas, resulting in 4 steps in total, of
    # (global) batch sizes 8, 8, 4, 4.
    def dataset_fn(ctx):
      del ctx

      # The following dataset is equivalent to
      # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset.
      # This causes DistributedDataset to use LegacyRebatch instead.
      batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4])
      offsets = dataset_ops.Dataset.from_tensor_slices([0, 8])
      dataset = dataset_ops.Dataset.zip((offsets, batch_sizes))

      def map_fn(offset, batch_size):
        return math_ops.range(offset, offset + batch_size)

      dataset = dataset.map(map_fn)

      # Set the sharding behavior to OFF for simplicity of test setup; namely,
      # `dataset` defines the per-worker dataset and will not be further
      # sharded. Each worker will see a dataset that is equivalent to
      # tf.data.Dataset.range(12).batch(8).rebatch(...).
      options = options_lib.Options()
      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as the number of global
    # replicas is 2.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. Each worker splits each of its per-worker batches
    # (sizes 8 and 4) in half, giving local batches of sizes 4, 4, 2 and 2.
    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA]))
  def testMWMSWithDataSharding(self, input_type, api_type, iteration_type,
                               distribution, auto_shard_policy):
    """Checks MWMS data sharding with a batch size not divisible by replicas."""
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when the dataset is sharded by
    # data and the batch size is indivisible by the number of replicas. This
    # checks that the elements are as expected and the batch size across all
    # workers adds up to 3. This test will only pass if the autoshard rewrite
    # rewrites RebatchDatasetV2 to legacy RebatchDataset when sharding by data.
    def dataset_fn(ctx):
      del ctx
      dataset = dataset_ops.Dataset.range(8).batch(3)

      # Apply the parameterized auto-shard policy (AUTO or DATA). The expected
      # values below assume both policies shard
      # tf.data.Dataset.range(8).batch(3) by data across the two workers.
      options = options_lib.Options()
      options.experimental_distribute.auto_shard_policy = auto_shard_policy
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as there is 1 local
    # replica.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. We expect each worker to see different shards of
    # data.
    cr = distribution.cluster_resolver
    worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type,
                                                cr.task_id)

    if worker_id == 0:
      expected_values = [[[0, 1]], [[3, 4]], [[6]]]
    elif worker_id == 1:
      expected_values = [[[2]], [[5]], [[7]]]
    else:
      # Fail loudly instead of hitting an unbound `expected_values` below.
      self.fail("Unexpected worker id %d for a 2-worker cluster." % worker_id)

    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())
@framework_test_util.with_eager_op_as_function
class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
                                       parameterized.TestCase):
  """Tests for PER_WORKER and PER_REPLICA's InputOptions variants."""

  def _assert_first_element_values(self, ds):
    """Asserts the first element of `ds` holds `[1..5]` and `[6..10]`."""
    x = next(iter(ds))
    self.assertAllEqual(x.values[0].numpy(), [1, 2, 3, 4, 5])
    self.assertAllEqual(x.values[1].numpy(), [6, 7, 8, 9, 10])

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_two_cpus,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testDevicePlacementForPerWorkerValuesWithPrefetch(self, distribution,
                                                        input_options):
    """Checks that fetched values live on their replica's worker device."""

    def dataset_fn(input_context):  # pylint: disable=[unused-argument]
      return dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4])

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    worker_devices = distribution.extended.worker_devices
    for x in ds:
      for replica_id in (0, 1):
        value = x.values[replica_id]
        self.assertEqual(value.device, worker_devices[replica_id])
        self.assertEqual(value.backing_device, worker_devices[replica_id])

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_two_cpus,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ],
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER)
          ],
          mode=["eager"],
      ))
  def testDevicePlacementForPerWorkerValuesWithoutPrefetch(
      self, distribution, input_options):
    """Checks that values stay on the host CPU when not fetched to device."""

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.full(4, input_context.input_pipeline_id))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    host_cpu = "/job:localhost/replica:0/task:0/device:CPU:0"
    for x in ds:
      x = distribution.run(lambda inputs: inputs, args=(x,))
      for replica_id in (0, 1):
        value = x.values[replica_id]
        self.assertEqual(value.device, host_cpu)
        self.assertEqual(value.backing_device, host_cpu)

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA)
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_two_cpus,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testDevicePlacementForInvalidCombinations(self, distribution,
                                                input_options):
    """Checks that the listed InputOptions combinations raise `ValueError`."""

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.full(4, input_context.input_pipeline_id))

    with self.assertRaises(ValueError):
      distribution.experimental_distribute_datasets_from_function(
          dataset_fn, input_options)

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_per_replica_buffer_size=2),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_per_replica_buffer_size=2),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_two_cpus,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testPrefetchBufferSizeInputOptions(self, distribution, input_options):
    """Checks output values with a custom per-replica prefetch buffer size."""

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 11).reshape(
              (2, 5)) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    self._assert_first_element_values(ds)

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_two_cpus,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testOutputValuesForPerWorkerInputOptions(self, distribution,
                                               input_options):
    """Checks output values for PER_WORKER replication mode."""

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 11).reshape(
              (2, 5)) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    self._assert_first_element_values(ds)

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_two_cpus,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testOutputValuesForPerReplicaInputOptions(self, distribution,
                                                input_options):
    """Checks output values for PER_REPLICA replication mode.

    `dataset_fn` scales its values by `input_pipeline_id + 1`, so the second
    replica is expected to see doubled values. The test also checks that
    exactly `len(expected)` elements are produced.
    """

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 10) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)
    expected = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
    # Count explicitly so an empty dataset fails the final check instead of
    # raising NameError on an unbound loop variable.
    num_elements = 0
    for i, x in enumerate(ds):
      # validating the values
      self.assertEqual(x.values[0].numpy(), expected[i])
      self.assertEqual(x.values[1].numpy(), expected[i] * 2)
      num_elements += 1
    self.assertEqual(num_elements, len(expected))
class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
parameterized.TestCase):
"""Tests for distributed iterators which read from tf.data service."""
def setUp(self):
  """Starts a local tf.data service (dispatcher plus workers) for the tests.

  `self.num_workers` is always set. The servers are only created in the main
  process; the dispatcher target is then recorded in `combinations.env()` so
  that the test bodies can read it.
  """
  super(DistributedIteratorTfDataServiceTest, self).setUp()
  self.num_workers = 3
  if not combinations.in_main_process():
    return

  self.dispatcher = server_lib.DispatchServer()
  # WorkerConfig wants the bare address, without the "<protocol>://" prefix.
  dispatcher_address = self.dispatcher.target.split("://")[1]
  self.workers = [
      server_lib.WorkerServer(
          server_lib.WorkerConfig(
              dispatcher_address=dispatcher_address,
              heartbeat_interval_ms=100,
              dispatcher_timeout_ms=1000))
      for _ in range(self.num_workers)
  ]
  combinations.env().tf_data_service_dispatcher = self.dispatcher.target
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.one_device_strategy,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.tpu_strategy,
            strategy_combinations.central_storage_strategy_with_two_gpus,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
        ]))
def testTfDataService(self, distribution):
  """Reads a tf.data service dataset through `get_distributed_dataset`.

  With sharding OFF, the gathered (non-padding) values are expected to
  contain `range(1, 50)` once per tf.data service worker. The test also
  checks that the initialization-time histogram was recorded.
  """
  input_workers = input_lib.InputWorkers(
      [("/device:CPU:0", ["/device:CPU:0"])])
  dataset = dataset_ops.Dataset.range(1, 50).apply(
      data_service_ops._distribute(
          processing_mode=data_service_ops.ShardingPolicy.OFF,
          service=combinations.env().tf_data_service_dispatcher,
          job_name="foo"))
  dist_dataset = input_util.get_distributed_dataset(dataset, input_workers,
                                                    distribution)

  # input_lib.distributed_dataset may pad per-replica results with extra 0
  # elements. Real values start at 1, so zeros can be discarded.
  results = [
      local_value.numpy()
      for element in iter(dist_dataset)
      for local_value in distribution.experimental_local_results(element)
      if local_value.numpy() != 0
  ]
  self.assertNotEmpty(results)

  gathered = distribution.gather(constant_op.constant(results), axis=0)
  self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)

  init_time_histogram = (
      input_lib._distributed_dataset_initialization_time_milliseconds
      .get_cell(distribution.__class__.__name__, "1").value())
  self.assertGreater(init_time_histogram.num, 0.0)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.one_device_strategy,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.tpu_strategy,
            strategy_combinations.central_storage_strategy_with_two_gpus,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
        ]))
def testDistributeDatasetFromFunction(self, distribution):
  """Reads a registered tf.data service dataset via a dataset function.

  Every input pipeline reads the same registered dataset through one shared
  job. The gathered (non-padding) values are expected to contain
  `range(1, 50)` once per tf.data service worker.
  """
  input_workers = input_lib.InputWorkers(
      [("/device:CPU:0", ["/device:CPU:0"])])
  num_pipelines = input_workers.num_workers
  input_contexts = [
      distribute_lib.InputContext(
          num_input_pipelines=num_pipelines,
          input_pipeline_id=pipeline_id,
          num_replicas_in_sync=num_pipelines)
      for pipeline_id in range(num_pipelines)
  ]

  dataset = dataset_ops.Dataset.range(1, 50)
  dataset_id = "dataset_id"
  # This body runs on the chief and on the workers, so `register_dataset` is
  # called several times. A fixed `dataset_id` lets the tf.data service treat
  # all of those registrations as the same dataset.
  data_service_ops.register_dataset(
      service=combinations.env().tf_data_service_dispatcher,
      dataset=dataset,
      dataset_id=dataset_id)

  def dataset_fn(input_context):
    del input_context  # All pipelines read from the same shared job.
    return data_service_ops.from_dataset_id(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service=combinations.env().tf_data_service_dispatcher,
        dataset_id=dataset_id,
        element_spec=dataset.element_spec,
        job_name="shared_job")

  dist_dataset = input_util.get_distributed_datasets_from_function(
      dataset_fn, input_workers, input_contexts, distribution)

  # input_lib.distributed_dataset may pad per-replica results with extra 0
  # elements. Real values start at 1, so zeros can be discarded.
  results = [
      local_value.numpy()
      for element in iter(dist_dataset)
      for local_value in distribution.experimental_local_results(element)
      if local_value.numpy() != 0
  ]
  self.assertNotEmpty(results)

  gathered = distribution.gather(constant_op.constant(results), axis=0)
  self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)

  init_time_histogram = (
      input_lib
      ._distributed_dataset_from_function_initialization_time_milliseconds
      .get_cell(distribution.__class__.__name__, "1").value())
  self.assertGreater(init_time_histogram.num, 0.0)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.one_device_strategy,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_two_gpus,
            strategy_combinations.tpu_strategy,
            strategy_combinations.central_storage_strategy_with_two_gpus,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
        ]))
def testDistributeDatasetFromFunctionNested(self, distribution):
  """Distributes a dataset whose elements are nested ExtensionTypes.

  Each element of `range(1, 10)` is mapped to an `OuterType` wrapping an
  `InnerType` that holds the constant [[0., 1.], [2., 3.]]. The test checks
  that the local results equal that nested value, once per element and per
  input worker.
  """
  input_workers = input_lib.InputWorkers(
      [("/device:CPU:0", ["/device:CPU:0"])])
  num_workers = input_workers.num_workers
  input_contexts = [
      distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=pipeline_id,
          num_replicas_in_sync=num_workers)
      for pipeline_id in range(num_workers)
  ]

  class InnerType(extension_type.ExtensionType):
    tensor: ops.Tensor

  class OuterType(extension_type.ExtensionType):
    inner: InnerType

  def make_component() -> OuterType:
    # The fixed nested value that every dataset element is mapped to.
    return OuterType(
        inner=InnerType(tensor=constant_op.constant([[0., 1.], [2., 3.]])))

  start, stop = 1, 10

  def dataset_fn(input_context):
    del input_context  # Every input pipeline builds the same dataset.

    def data_fn(unused_batch_id) -> OuterType:
      del unused_batch_id  # The element value ignores the range index.
      return make_component()

    return dataset_ops.Dataset.range(start, stop).map(data_fn)

  dist_dataset = input_util.get_distributed_datasets_from_function(
      dataset_fn, input_workers, input_contexts, distribution)
  results = [
      local_result
      for element in iter(dist_dataset)
      for local_result in distribution.experimental_local_results(element)
  ]

  expect_component = make_component()
  expected = [expect_component] * (num_workers * (stop - start))
  self.assertCountEqual(expected, results)
if __name__ == "__main__":
  # Run the tests through tensorflow.python.distribute.test_util.main.
  # NOTE(review): presumably this runner provides the setup the multi-worker
  # strategy combinations need — confirm before switching to test.main().
  test_util.main()