# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""
import functools
import sys
import time
import six
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import cardinality as cardinality_lib
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.distribute_lib import InputReplicationMode
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond as tf_cond
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import while_loop
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import distribute as distribute_types
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
_distributed_dataset_initialization_time_milliseconds = monitoring.Sampler(
"/tensorflow/api/distribution_strategy/"
"distributed_dataset_initialization_time_milliseconds",
monitoring.ExponentialBuckets(scale=1, growth_factor=2, bucket_count=26),
"Track the time (in milliseconds) to initialize distributed datasets.",
"strategy", "workers")
_distributed_dataset_from_function_initialization_time_milliseconds = (
monitoring.Sampler(
"/tensorflow/api/distribution_strategy/"
"distributed_dataset_from_function_initialization_time_milliseconds",
monitoring.ExponentialBuckets(
scale=1, growth_factor=2, bucket_count=26),
"Track the time (in milliseconds) to initialize distributed datasets "
"from function.",
"strategy", "workers"))
def get_iterator_spec_from_dataset(strategy, dataset):
"""Returns an iterator spec from dataset function.
This function constructs type spec for iterator obtained from
iter(dataset).
Args:
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
dataset: A tf.data.Dataset instance. If using a function that returns a
tf.data.Dataset instance, pass dataset_fn.structured_outputs.
Returns:
A type spec for the iterator of the dataset instance.
"""
# pylint: disable=protected-access
output_element_spec = dataset.element_spec
if isinstance(dataset._type_spec,
(DistributedDatasetSpec,
DistributedDatasetsFromFunctionSpec)):
iterator_type_spec = DistributedIteratorSpec(
strategy.extended._input_workers_with_options(),
output_element_spec,
strategy.extended._container_strategy(),
options=None,
cardinality=dataset.cardinality,
enable_get_next_as_optional=True)
else:
if strategy.extended._num_gpus_per_worker:
logging.warning(
f"{strategy.extended._num_gpus_per_worker} GPUs "
"are allocated per worker. Please use DistributedDataset by "
"calling strategy.experimental_distribute_dataset or strategy."
"distribute_datasets_from_function to make best use of GPU "
"resources"
)
iterator_type_spec = iterator_ops.IteratorSpec(output_element_spec)
return iterator_type_spec
# pylint: enable=protected-access
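# Illustrative usage of get_iterator_spec_from_dataset (comment-only sketch;
# the strategy, devices and dataset below are assumptions, not part of this
# module):
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   dist_ds = strategy.experimental_distribute_dataset(
#       tf.data.Dataset.range(8).batch(4))
#   spec = get_iterator_spec_from_dataset(strategy, dist_ds)
#   # `spec` can then be used, for example, as the input_signature of a
#   # tf.function that receives the iterator as an argument.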
class InputWorkers(object):
"""A 1-to-many mapping from input worker devices to compute devices."""
# TODO(ishark): Remove option canonicalize_devices and make all the callers
# pass canonicalized or raw device strings as relevant from strategy.
def __init__(self,
worker_device_pairs,
canonicalize_devices=True):
"""Initialize an `InputWorkers` object.
Args:
worker_device_pairs: A sequence of pairs: `(input device, a tuple of
compute devices fed by that input device)`.
canonicalize_devices: Whether to fully canonicalize the worker devices. If
False, devices are only partially canonicalized by stripping the job and
task fields.
"""
self._worker_device_pairs = worker_device_pairs
self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs)
self._canonicalize_devices = canonicalize_devices
if canonicalize_devices:
self._fed_devices = tuple(
tuple(device_util.canonicalize(d)
for d in f)
for _, f in self._worker_device_pairs)
else:
self._fed_devices = tuple(
tuple(device_util.canonicalize_without_job_and_task(d)
for d in f)
for _, f in self._worker_device_pairs)
@property
def num_workers(self):
return len(self._input_worker_devices)
@property
def worker_devices(self):
return self._input_worker_devices
def compute_devices_for_worker(self, worker_index):
return self._fed_devices[worker_index]
def __repr__(self):
devices = self.worker_devices
debug_repr = ",\n".join(" %d %s: %s" %
(i, devices[i], self._fed_devices[i])
for i in range(len(devices)))
return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)
def serialize(self):
return (self._worker_device_pairs, self._canonicalize_devices)
def deserialize(self, serialized):
return InputWorkers(*serialized)
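# Illustrative construction of InputWorkers (comment-only sketch; the device
# strings are assumptions): one input worker on the host CPU feeding two
# local GPUs.
#
#   worker_device_pairs = [
#       ("/job:localhost/replica:0/task:0/device:CPU:0",
#        ("/job:localhost/replica:0/task:0/device:GPU:0",
#         "/job:localhost/replica:0/task:0/device:GPU:1")),
#   ]
#   input_workers = InputWorkers(worker_device_pairs)
#   input_workers.num_workers                    # -> 1
#   input_workers.compute_devices_for_worker(0)  # -> the two GPU devices,
#                                                #    canonicalized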
def _calculate_replicas_with_values(strategy, input_workers, optional_list):
"""Calcualates the number of replicas that have values.
Args:
strategy: the `tf.distribute.Strategy`.
input_workers: the `InputWorkers`.
optional_list: a list of lists of `tf.experimental.Optional`. The values
from each compute device, grouped by the input device.
Returns:
A scalar Tensor.
"""
worker_has_values = []
for worker, optionals in zip(input_workers.worker_devices, optional_list):
with ops.device(worker):
device_has_values = [
math_ops.cast(v.has_value(), dtypes.int64) for v in optionals
]
worker_has_values.append(
math_ops.reduce_sum(device_has_values, keepdims=True))
client_has_values = math_ops.reduce_sum(worker_has_values, keepdims=True)
if strategy.extended._in_multi_worker_mode(): # pylint: disable=protected-access
global_has_values = strategy.reduce(
reduce_util.ReduceOp.SUM, client_has_values, axis=None)
return array_ops.reshape(global_has_values, [])
else:
return array_ops.reshape(client_has_values, [])
def _is_statically_shaped(element_spec):
"""Test if an iterator output is statically shaped.
For sparse and ragged tensors this only tests the batch dimension.
Args:
element_spec: a nest structure of `tf.TypeSpec`. The element spec of the
dataset of the iterator.
Returns:
True if the shape is static, false otherwise.
"""
for spec in nest.flatten(element_spec):
if isinstance(
spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)):
# For sparse or ragged tensors, we only need to check the first (batch)
# dimension for get_next_as_optional. This is because when these tensors
# get batched by the dataset, only the batch dimension is set.
if spec.shape.rank > 0 and spec.shape.as_list()[0] is None:
return False
else:
for component in spec._flat_tensor_specs: # pylint: disable=protected-access
if not component.shape.is_fully_defined():
return False
return True
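# Illustrative behavior of _is_statically_shaped (comment-only sketch using
# assumed specs):
#
#   _is_statically_shaped(tf.TensorSpec([None, 28, 28], tf.float32))  # False
#   _is_statically_shaped(tf.TensorSpec([32, 28, 28], tf.float32))    # True
#
# For SparseTensorSpec and RaggedTensorSpec only the first (batch) dimension
# is checked, since batching by tf.data only sets that dimension.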
class DistributedIteratorBase(collections_abc.Iterator,
distribute_types.DistributedIteratorInterface):
"""Common implementation for all input iterators."""
# pylint: disable=super-init-not-called
def __init__(
self,
input_workers,
iterators,
strategy,
cardinality,
enable_get_next_as_optional,
replica_order=None,
):
assert isinstance(input_workers, InputWorkers)
if not input_workers.worker_devices:
raise ValueError("Should have at least one worker for input iterator.")
self._iterators = iterators
self._input_workers = input_workers
self._strategy = strategy
self._cardinality = cardinality
self._enable_get_next_as_optional = enable_get_next_as_optional
self._replica_order = replica_order
def next(self):
return self.__next__()
def __next__(self):
try:
return self.get_next()
except errors.OutOfRangeError:
raise StopIteration
def __iter__(self):
return self
def get_next_as_optional(self):
# Ideally get_next_as_optional() should be consistent with get_next(), but
# we used to always do partial batch handling in get_next_as_optional(). We
# are keeping this behavior for now until we understand the impact.
# Skip partial batch handling when the dataset is infinite or empty, as
# there won't be any partial batches in those cases. This gives the user
# more static shapes as it avoids the tf.cond. Note that for empty datasets,
# we can only skip in single client mode, as the dataset can be non-empty on
# other workers.
if self._cardinality == cardinality_lib.INFINITE:
return optional_ops.Optional.from_value(
self._get_next_no_partial_batch_handling())
if (self._cardinality == 0 and
not self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
return optional_ops.Optional.empty(self._element_spec)
optional_list = []
for i, worker in enumerate(self._input_workers.worker_devices):
with ops.device(worker):
optional_list.append(self._iterators[i].get_next_as_optional_list())
def _create_optional_with_dummy():
value_list = _get_value_or_dummy(
self._input_workers, optional_list, produce_dummy=True)
if self._replica_order is not None:
value_list = self._reorder_replicas(value_list)
per_replica = _create_per_replica(value_list, self._strategy)
return optional_ops.Optional.from_value(per_replica)
def _create_empty_optional():
return optional_ops.Optional.empty(self._element_spec)
num_replicas_with_values = _calculate_replicas_with_values(
self._strategy, self._input_workers, optional_list)
return tf_cond.cond(
num_replicas_with_values > 0,
_create_optional_with_dummy,
_create_empty_optional,
strict=True)
def get_next(self, name=None):
"""Returns the next input from the iterator for all replicas."""
with distribute_lib.enter_or_assert_strategy(
self._strategy):
if distribute_lib.get_replica_context() is not None:
raise ValueError("next(iterator) should be called from outside of "
"replica_fn. e.g. strategy.run(replica_fn, "
"args=(next(iterator),))")
if not self._enable_get_next_as_optional:
return self._get_next_no_partial_batch_handling(name)
optional_list = []
for i, worker in enumerate(self._input_workers.worker_devices):
with ops.device(worker):
optional_list.append(self._iterators[i].get_next_as_optional_list())
num_replicas_with_values = _calculate_replicas_with_values(
self._strategy, self._input_workers, optional_list)
def _value_or_dummy():
value_list = _get_value_or_dummy(
self._input_workers, optional_list, produce_dummy=True)
if self._replica_order is not None:
value_list = self._reorder_replicas(value_list)
return _create_per_replica(value_list, self._strategy)
def _eof():
# Optional.get_value raises InvalidArgumentError when there's no value,
# so we need to call GetNext to raise OutOfRangeError.
return self._get_next_no_partial_batch_handling()
return tf_cond.cond(
num_replicas_with_values > 0, _value_or_dummy, _eof, strict=True)
def _get_next_no_partial_batch_handling(self, name=None):
replicas = []
for i, worker in enumerate(self._input_workers.worker_devices):
if name is not None:
d = tf_device.DeviceSpec.from_string(worker)
new_name = "%s_%s_%d" % (name, d.job, d.task)
else:
new_name = None
with ops.device(worker):
# Make `replicas` a flat list of values across all replicas.
replicas.extend(self._iterators[i].get_next_as_list(new_name))
if self._replica_order is not None:
replicas = self._reorder_replicas(replicas)
return _create_per_replica(replicas, self._strategy)
def _reorder_replicas(self, replicas):
assert len(self._replica_order) == len(
replicas
), "replica order size ({}) != replicas size ({})!".format(
len(self._replica_order), len(replicas)
)
return [replicas[i] for i in self._replica_order]
class DistributedDatasetAndIteratorSpec(type_spec.TypeSpec):
"""Common Type specification for `DistributedDataset and DistributedDatasetsFromFunction."""
__slots__ = [
"_input_workers", "_element_spec", "_strategy", "_cardinality",
"_enable_get_next_as_optional", "_options", "_canonicalize_devices"
]
def __init__(
self,
input_workers,
element_spec,
strategy,
options,
cardinality=cardinality_lib.UNKNOWN,
enable_get_next_as_optional=None,
replica_order=None,
):
# We don't want to allow deserialization of this class because we don't
# serialize the strategy object. Currently the only place where
# _deserialize is called is when we save/restore using SavedModel.
if isinstance(input_workers, tuple):
raise NotImplementedError("DistributedIteratorSpec does not have support "
"for deserialization.")
else:
self._input_workers = input_workers
self._element_spec = element_spec
self._strategy = strategy
self._cardinality = cardinality
self._enable_get_next_as_optional = enable_get_next_as_optional
self._options = options
if self._strategy:
self._canonicalize_devices = getattr(self._strategy,
"_canonicalize_devices", True)
else:
self._canonicalize_devices = True
self._replica_order = replica_order
def _serialize(self):
# We cannot serialize the strategy object so we convert it to an id that we
# can use for comparison.
return (self._input_workers.serialize(), self._element_spec,
id(self._strategy), id(self._options))
def _deserialize(self):
raise ValueError(
f"Deserialization is currently unsupported for {type(self)}.")
def sanity_check_type(self, other):
"""Returns the most specific TypeSpec compatible with `self` and `other`.
Args:
other: A `TypeSpec`.
Raises:
ValueError: If there is no TypeSpec that is compatible with both `self`
and `other`.
"""
# pylint: disable=protected-access
if type(self) is not type(other):
raise ValueError("No TypeSpec is compatible with both %s and %s" %
(self, other))
if self._input_workers.serialize() != other._input_workers.serialize():
raise ValueError("_input_workers is not compatible with both %s "
"and %s" % (self, other))
if self._strategy is not other._strategy:
raise ValueError("tf.distribute strategy is not compatible with both %s "
"and %s" % (self, other))
def is_subtype_of(self, other):
"""Returns True if `self` is subtype of `other`.
Args:
other: A `TypeSpec`.
"""
try:
self.sanity_check_type(other)
nest.assert_same_structure(self._element_spec, other._element_spec) # pylint: disable=protected-access
except (TypeError, ValueError):
return False
self_elements = nest.flatten(self._element_spec)
other_elements = nest.flatten(other._element_spec) # pylint: disable=protected-access
return all(
self_element.is_subtype_of(other_element)
for (self_element, other_element) in zip(self_elements, other_elements))
def most_specific_common_supertype(self, others):
"""Returns the most specific supertype of `self` and `others`.
Args:
others: A Sequence of `TypeSpec`.
Returns `None` if a supertype does not exist.
"""
try:
for other in others:
self.sanity_check_type(other)
nest.assert_same_structure(self._element_spec, other._element_spec) # pylint: disable=protected-access
except (TypeError, ValueError):
return None
self_elements = nest.flatten(self._element_spec)
others_elements = [nest.flatten(other._element_spec) for other in others] # pylint: disable=protected-access
common_elements = [None] * len(self_elements)
for i, self_element in enumerate(self_elements):
common_elements[i] = self_element.most_specific_common_supertype(
[other_elements[i] for other_elements in others_elements])
if common_elements[i] is None:
return None
common_element_spec = nest.pack_sequence_as(self._element_spec,
common_elements)
return type(self)(
self._input_workers,
common_element_spec,
self._strategy,
self._options,
cardinality=self._cardinality,
enable_get_next_as_optional=self._enable_get_next_as_optional)
def _with_tensor_ranks_only(self):
element_spec = nest.map_structure(
lambda s: s._with_tensor_ranks_only(), # pylint: disable=protected-access
self._element_spec)
return type(self)(
self._input_workers,
element_spec,
self._strategy,
self._options,
cardinality=self._cardinality,
enable_get_next_as_optional=self._enable_get_next_as_optional)
# TODO(b/206014848): Remove once names are not used.
def _without_tensor_names(self):
element_spec = nest.map_structure(
lambda s: s._without_tensor_names(), # pylint: disable=protected-access
self._element_spec)
return type(self)(
self._input_workers,
element_spec,
self._strategy,
self._options,
cardinality=self._cardinality,
enable_get_next_as_optional=self._enable_get_next_as_optional)
class DistributedIteratorSpec(DistributedDatasetAndIteratorSpec):
"""Type specification for `DistributedIterator`."""
@property
def value_type(self):
return DistributedIterator
@property
def _component_specs(self):
specs = []
worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access
for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
element_spec = nest.map_structure(
functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
specs.append(
_SingleWorkerDatasetIteratorSpec(input_device, compute_devices,
element_spec, self._options,
self._canonicalize_devices))
return specs
def _to_components(self, value):
return value._iterators # pylint: disable=protected-access
def _from_components(self, components):
return DistributedIterator(
input_workers=self._input_workers,
iterators=None,
components=components,
element_spec=self._element_spec,
strategy=self._strategy,
cardinality=self._cardinality,
enable_get_next_as_optional=self._enable_get_next_as_optional,
options=self._options,
replica_order=self._replica_order,
)
@staticmethod
def from_value(value):
# pylint: disable=protected-access
return DistributedIteratorSpec(
value._input_workers,
value._element_spec,
value._strategy,
value._options,
cardinality=value._cardinality,
enable_get_next_as_optional=value._enable_get_next_as_optional)
class DistributedIterator(DistributedIteratorBase,
composite_tensor.CompositeTensor):
"""Input Iterator for a distributed dataset."""
def __init__(
self,
input_workers=None,
iterators=None,
strategy=None,
components=None,
element_spec=None,
cardinality=cardinality_lib.UNKNOWN,
enable_get_next_as_optional=False,
options=None,
replica_order=None,
):
if input_workers is None:
raise ValueError("`input_workers` should be "
"provided.")
error_message = ("Either `input_workers` or "
"both `components` and `element_spec` need to be "
"provided.")
self._options = options
if iterators is None:
if (components is None or element_spec is None):
raise ValueError(error_message)
self._element_spec = element_spec
self._input_workers = input_workers
self._iterators = components
self._strategy = strategy
self._cardinality = cardinality
self._enable_get_next_as_optional = enable_get_next_as_optional
self._replica_order = replica_order
else:
if (components is not None and element_spec is not None):
raise ValueError(error_message)
super(DistributedIterator, self).__init__(
input_workers,
iterators,
strategy,
cardinality,
enable_get_next_as_optional,
replica_order,
)
@property
def element_spec(self):
# When partial batch handling is enabled, always set the batch dimension to
# None, otherwise we just follow element_spec of the underlying dataset
# (whose batch dimension may also be None). This is because with partial
# batch handling we could always produce empty batches.
if (self._enable_get_next_as_optional and
self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
return nest.map_structure(
_rebatch_as_dynamic, self._element_spec, expand_composites=False)
return self._element_spec
@property
def _type_spec(self):
# Note that we use actual element_spec instead of the rebatched-as-dynamic
# one to create DistributedIteratorSpec, to be consistent with the
# underlying iterators' specs.
return DistributedIteratorSpec(
self._input_workers,
self._element_spec,
self._strategy,
self._options,
self._cardinality,
self._enable_get_next_as_optional,
self._replica_order,
)
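# Illustrative use of DistributedIterator (comment-only sketch; the strategy,
# dataset and replica function below are assumptions). Users normally obtain
# one by calling iter() on a distributed dataset and feed it to strategy.run:
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   dist_ds = strategy.experimental_distribute_dataset(
#       tf.data.Dataset.range(8).batch(4))
#   iterator = iter(dist_ds)
#
#   @tf.function
#   def step(inputs):
#     return strategy.run(lambda x: x * 2, args=(inputs,))
#
#   outputs = step(next(iterator))  # a PerReplica of per-replica results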
class _IterableInput(collections_abc.Iterable,
distribute_types.DistributedDatasetInterface):
"""Base class for iterable inputs for distribution strategies."""
# pylint: disable=super-init-not-called
def __init__(self, input_workers):
assert isinstance(input_workers, InputWorkers)
self._input_workers = input_workers
def __iter__(self):
raise NotImplementedError("must be implemented in descendants")
def reduce(self, initial_state, reduce_fn):
"""Execute a `reduce_fn` over all the elements of the input."""
iterator = iter(self)
optional_data = iterator.get_next_as_optional()
def cond(optional_data, state):
del state # Unused.
return optional_data.has_value()
def loop_body(optional_data, state):
"""Executes `reduce_fn` in a loop till the dataset is empty."""
state = reduce_fn(state, optional_data.get_value())
optional_data = iterator.get_next_as_optional()
return optional_data, state
optional_data, final_state = while_loop.while_loop(
cond,
loop_body, [optional_data, initial_state],
parallel_iterations=1,
return_same_structure=True)
return final_state
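# Illustrative use of _IterableInput.reduce (comment-only sketch; the strategy
# and dataset are assumptions): summing all elements of a distributed dataset.
#
#   dist_ds = strategy.experimental_distribute_dataset(
#       tf.data.Dataset.range(10).batch(2))
#
#   def reduce_fn(state, per_replica_batch):
#     return state + strategy.reduce(
#         tf.distribute.ReduceOp.SUM, per_replica_batch, axis=0)
#
#   total = dist_ds.reduce(tf.constant(0, dtype=tf.int64), reduce_fn)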
class DistributedDatasetSpec(DistributedDatasetAndIteratorSpec):
"""Type specification for `DistributedDataset."""
@property
def value_type(self):
return DistributedDataset
@property
def _component_specs(self):
specs = []
worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access
for i, _ in enumerate(worker_device_pairs):
element_spec = nest.map_structure(
functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
specs.append(dataset_ops.DatasetSpec(element_spec))
return specs
def _to_components(self, value):
return value._cloned_datasets # pylint: disable=protected-access
def _from_components(self, components):
return DistributedDataset(
input_workers=self._input_workers,
strategy=self._strategy,
components=components,
element_spec=self._element_spec,
enable_get_next_as_optional=self._enable_get_next_as_optional,
options=self._options,
replica_order=self._replica_order,
)
@staticmethod
def from_value(value):
# pylint: disable=protected-access
return DistributedDatasetSpec(
value._input_workers,
value._element_spec,
value._strategy,
value._options,
enable_get_next_as_optional=value._enable_get_next_as_optional)
# pylint: enable=protected-access
class DistributedDataset(_IterableInput, composite_tensor.CompositeTensor):
"""Distributed dataset that supports prefetching to multiple devices."""
def __init__(
self,
input_workers,
strategy,
dataset=None,
num_replicas_in_sync=None,
input_context=None,
components=None,
element_spec=None,
enable_get_next_as_optional=None,
build=True,
options=None,
replica_order=None,
):
"""Distribute the dataset on all workers.
If `num_replicas_in_sync` is not None, we split each batch of the dataset
into `num_replicas_in_sync` smaller batches, to be distributed among that
worker's replicas, so that the batch size for a global step (across all
workers and replicas) is as expected.
Args:
input_workers: an `InputWorkers` object.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
dataset: `tf.data.Dataset` that will be used as the input source. Either
dataset or components field should be passed when constructing
DistributedDataset. Use this when constructing DistributedDataset from a
new `tf.data.Dataset`. Use components when constructing using
DistributedDatasetSpec.
num_replicas_in_sync: Optional integer. If this is not None, the value is
used to decide how to rebatch datasets into smaller batches so that the
total batch size for each step (across all workers and replicas) adds up
to `dataset`'s batch size.
input_context: `InputContext` for sharding. Only pass this in for between
graph multi-worker cases where there is only one `input_worker`. In
these cases, we will shard based on the `input_pipeline_id` and
`num_input_pipelines` in the `InputContext`.
components: datasets when DistributedDataset is constructed from
DistributedDatasetSpec. Either the dataset or the components argument
should be passed.
element_spec: element spec for DistributedDataset when constructing from
DistributedDatasetSpec. This will be used to set the element_spec for
DistributedDataset and verified against element_spec from components.
enable_get_next_as_optional: this is required when components is passed
instead of dataset.
build: whether to build underlying datasets when this object is created.
This is currently only used by `ParameterServerStrategy`.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
replica_order: the order of the replicas, which will be used to reorder
the iterators to match the device order.
"""
super(DistributedDataset, self).__init__(input_workers=input_workers)
if input_workers is None or strategy is None:
raise ValueError("input_workers and strategy are required arguments")
if dataset is not None and components is not None:
raise ValueError("Only one of dataset or components should be present")
if dataset is None and components is None:
raise ValueError("At least one of dataset or components should be passed")
self._input_workers = input_workers
self._strategy = strategy
self._options = options
self._input_context = input_context
self._num_replicas_in_sync = num_replicas_in_sync
self._replica_order = replica_order
if dataset is not None:
self._original_dataset = dataset
self._built = False
if build:
self.build()
else:
if not build:
raise ValueError(
"When constructing DistributedDataset with components, build "
"should not be False. This is an internal error. Please file a "
"bug.")
if enable_get_next_as_optional is None:
raise ValueError(
"When constructing DistributedDataset with components, " +
"enable_get_next_as_optional should also be passed")
self._cloned_datasets = components
self._cardinality = _cardinality(self._cloned_datasets[0])
self._enable_get_next_as_optional = enable_get_next_as_optional
assert element_spec is not None
if element_spec != _create_distributed_tensor_spec(
self._strategy, self._cloned_datasets[0].element_spec):
raise ValueError("Mismatched element_spec from the passed components")
self._element_spec = element_spec
self._built = True
def build(self, dataset_to_replace=None):
assert not self._built
dataset = dataset_to_replace or self._original_dataset
self._cardinality = _cardinality(dataset)
self._enable_get_next_as_optional = _enable_get_next_as_optional(
self._strategy, dataset, self._cardinality)
distribute_start_time_ns = time.time_ns()
self._create_cloned_datasets_from_dataset(dataset, self._input_context,
self._input_workers,
self._strategy,
self._num_replicas_in_sync)
if context.executing_eagerly():
# Records the time to initialize the distributed dataset.
context.async_wait()
distribute_duration_ms = (time.time_ns() -
distribute_start_time_ns) // 1_000_000
_distributed_dataset_initialization_time_milliseconds.get_cell(
self._strategy.__class__.__name__,
str(self._input_workers.num_workers)).add(distribute_duration_ms)
self._element_spec = _create_distributed_tensor_spec(
self._strategy, self._cloned_datasets[0].element_spec)
self._built = True
def auto_shard(self, num_shards, shard_ix):
assert (
len(self._cloned_datasets) == len(self._input_workers.worker_devices)
), (
f"datasets: {len(self._cloned_datasets)}, "
f"input workers: {len(self._input_workers.worker_devices)}"
)
sharded_datasets = []
for i in range(len(self._input_workers.worker_devices)):
with ops.colocate_with(self._cloned_datasets[i]._variant_tensor): # pylint:disable=protected-access
sharded_datasets.append(
input_ops.auto_shard_dataset(
self._cloned_datasets[i], num_shards, shard_ix,
self._num_replicas_in_sync
))
return DistributedDataset(
self._input_workers,
self._strategy,
components=sharded_datasets,
element_spec=self._element_spec,
options=self._options,
enable_get_next_as_optional=self._enable_get_next_as_optional)
@property
def cardinality(self):
if not self._built:
raise ValueError(
"Cannot get the cardinality of a dataset that is not built")
return self._cardinality
def _create_cloned_datasets_from_dataset(self, dataset, input_context,
input_workers, strategy,
num_replicas_in_sync):
# We clone and shard the dataset on each worker. The current setup tries to
# shard the dataset by files if possible so that each worker sees a
# different subset of files. If that is not possible, we will attempt to
# shard the final input such that each worker runs the entire preprocessing
# pipeline and only receives its own shard of the dataset.
# Additionally, we rebatch the dataset on each worker into
# `num_replicas_in_sync` smaller batches to be distributed among that
# worker's replicas, so that the batch size for a global step (across all
# workers and replicas) adds up to the original dataset's batch size.
if num_replicas_in_sync is not None and num_replicas_in_sync > 1:
num_workers = input_context.num_input_pipelines if input_context else len(
input_workers.worker_devices)
rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
num_replicas_in_sync)
else:
rebatch_fn = None
self._cloned_datasets = []
if input_context:
# Between-graph where we rely on the input_context for sharding
assert input_workers.num_workers == 1
if rebatch_fn is not None:
dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
dataset = input_ops.auto_shard_dataset(dataset,
input_context.num_input_pipelines,
input_context.input_pipeline_id,
num_replicas_in_sync)
self._cloned_datasets.append(dataset)
else:
replicated_ds = distribute.replicate(dataset,
input_workers.worker_devices)
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
cloned_dataset = replicated_ds[worker]
if rebatch_fn is not None:
cloned_dataset = rebatch_fn(cloned_dataset, i)
cloned_dataset = input_ops.auto_shard_dataset(
cloned_dataset, len(input_workers.worker_devices), i,
num_replicas_in_sync)
self._cloned_datasets.append(cloned_dataset)
def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
"""Returns a callable that rebatches the input dataset.
Args:
dataset: A `tf.data.Dataset` representing the dataset to be distributed.
num_workers: An integer representing the number of workers to distribute
`dataset` among.
num_replicas_in_sync: An integer representing the number of replicas in
sync across all workers.
"""
if num_replicas_in_sync % num_workers:
raise ValueError(
"tf.distribute expects every worker to have the same number of "
"replicas. However, encountered `num_replicas_in_sync` ({}) that "
"cannot be divided by `num_workers` ({})".format(
num_replicas_in_sync, num_workers))
num_replicas_per_worker = num_replicas_in_sync // num_workers
with ops.colocate_with(dataset._variant_tensor): # pylint: disable=protected-access
batch_size = distribute.compute_batch_size(dataset)
def rebatch_fn(dataset, worker_index):
try:
def apply_rebatch():
batch_sizes = distribute.batch_sizes_for_worker(
batch_size, num_workers, num_replicas_per_worker, worker_index)
return dataset.rebatch(batch_sizes).prefetch(num_replicas_per_worker)
# pylint: disable=protected-access
def apply_legacy_rebatch():
return distribute._LegacyRebatchDataset(
dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)
with ops.colocate_with(dataset._variant_tensor):
return tf_cond.cond(
math_ops.not_equal(batch_size, -1),
true_fn=apply_rebatch,
false_fn=apply_legacy_rebatch)
except errors.InvalidArgumentError as e:
if "without encountering a batch" in str(e):
six.reraise(
ValueError,
ValueError(
"Call the `batch` method on the input Dataset in order to be "
"able to split your input across {} replicas.\n Please see "
"the tf.distribute.Strategy guide. {}".format(
num_replicas_in_sync, e)),
sys.exc_info()[2])
else:
raise
return rebatch_fn
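# Worked example for the rebatching above (comment-only): with a global batch
# size of 8, num_workers = 2 and num_replicas_in_sync = 4, each worker gets
# num_replicas_per_worker = 4 // 2 = 2, and each global batch of 8 is
# rebatched on a worker into 2 per-replica batches of size 8 // 4 = 2, so a
# single global step still consumes 8 elements across all replicas.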
def __iter__(self):
if not (context.executing_eagerly() or
ops.get_default_graph().building_function):
raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.")
if not self._built:
raise ValueError("To use this dataset, you need to pass this dataset to "
"ClusterCoordinator.create_per_worker_dataset.")
canonicalize_devices = getattr(self._strategy, "_canonicalize_devices",
True)
worker_iterators = _create_iterators_per_worker(
self._cloned_datasets,
self._input_workers,
options=self._options,
canonicalize_devices=canonicalize_devices)
iterator = DistributedIterator(
self._input_workers,
worker_iterators,
self._strategy,
cardinality=self._cardinality,
enable_get_next_as_optional=self._enable_get_next_as_optional,
options=self._options,
replica_order=self._replica_order,
)
iterator._element_spec = self._element_spec # pylint: disable=protected-access
# When async eager is enabled, sometimes the iterator may not finish
# initialization before being passed to a multi-device function, so we add a
# sync point here to make sure all underlying iterators are initialized.
if context.executing_eagerly():
context.async_wait()
return iterator
@property
def element_spec(self):
"""The type specification of an element of this dataset."""
# When partial batch handling is enabled, always set the batch dimension to
# None, otherwise we just follow element_spec of the underlying dataset
# (whose batch dimension may also be None). This is because with partial
# batch handling we could always produce empty batches.
if (self._enable_get_next_as_optional and
self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
return nest.map_structure(
_rebatch_as_dynamic, self._element_spec, expand_composites=False)
return self._element_spec
@property
def _type_spec(self):
return DistributedDatasetSpec(
self._input_workers,
self._element_spec,
self._strategy,
self._options,
enable_get_next_as_optional=self._enable_get_next_as_optional)
class DistributedDatasetsFromFunctionSpec(DistributedDatasetAndIteratorSpec):
"""Type specification for `DistributedDatasetsFromFunction."""
@property
def value_type(self):
return DistributedDatasetsFromFunction
@property
def _component_specs(self):
specs = []
worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access
for i, _ in enumerate(worker_device_pairs):
element_spec = nest.map_structure(
functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
specs.append(dataset_ops.DatasetSpec(element_spec))
return specs
def _to_components(self, value):
return value._datasets # pylint: disable=protected-access
def _from_components(self, components):
return DistributedDatasetsFromFunction(
input_workers=self._input_workers,
strategy=self._strategy,
components=components,
element_spec=self._element_spec,
options=self._options)
@staticmethod
def from_value(value):
# pylint: disable=protected-access
return DistributedDatasetsFromFunctionSpec(
input_workers=value._input_workers,
element_spec=value._element_spec,
strategy=value._strategy,
options=value._options)
# TODO(priyag): Add other replication modes.
class DistributedDatasetsFromFunction(_IterableInput,
composite_tensor.CompositeTensor):
"""Inputs created from dataset function."""
def __init__(
self,
input_workers,
strategy,
input_contexts=None,
dataset_fn=None,
options=None,
components=None,
element_spec=None,
build=True,
replica_order=None,
):
"""Makes an iterable from datasets created by the given function.
Args:
input_workers: an `InputWorkers` object.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
input_contexts: A list of `InputContext` instances to be passed to call(s)
to `dataset_fn`. Length and order should match worker order in
`worker_device_pairs`.
dataset_fn: A function that returns a `Dataset` given an `InputContext`.
Either dataset_fn or components should be passed to construct
DistributedDatasetsFromFunction. Use this when constructing
DistributedDatasetsFromFunction using a function. Use components when
constructing using DistributedDatasetsFromFunctionSpec.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
components: datasets when DistributedDatasetsFromFunction is constructed
from DistributedDatasetsFromFunctionSpec. Only one of dataset_fn or
components should be passed.
element_spec: element spec for DistributedDatasetsFromFunction when
constructing from DistributedDatasetsFromFunctionSpec. This will be used
to set the element_spec for DistributedDatasetsFromFunction and verified
against the element_spec from components.
build: whether to build underlying datasets when this object is created.
This is currently only used by `ParameterServerStrategy`.
replica_order: the order of the replicas, which will be used to reorder
the iterators to match the device order.
"""
super(DistributedDatasetsFromFunction, self).__init__(
input_workers=input_workers)
self._input_workers = input_workers
self._strategy = strategy
self._options = options
self._replica_order = replica_order
if dataset_fn is not None and components is not None:
raise ValueError("Only one of dataset_fn or components should be set")
if dataset_fn is None and components is None:
raise ValueError("At least one of dataset_fn or components should be set")
if dataset_fn is not None:
if input_workers.num_workers != len(input_contexts):
raise ValueError(
"Number of input workers (%d) is not same as number of "
"input_contexts (%d)" %
(input_workers.num_workers, len(input_contexts)))
self._input_contexts = input_contexts
self._num_replicas_in_sync = self._input_contexts[0].num_replicas_in_sync
self._dataset_fn = dataset_fn
self._built = False
if build:
self.build()
else:
if element_spec is None:
raise ValueError(
"element_spec should also be passed when passing components")
if not build:
raise ValueError(
"When constructing DistributedDatasetFromFunction with components, "
"build should not be False. This is an internal error. Please file "
"a bug.")
self._element_spec = element_spec
self._datasets = components
self._num_replicas_in_sync = None
self._built = True
self._cardinality = _cardinality(self._datasets[0])
self._enable_get_next_as_optional = _enable_get_next_as_optional(
self._strategy, self._datasets[0], self._cardinality)
def build(self):
assert not self._built
distribute_start_time_ns = time.time_ns()
self._datasets, element_spec = (
_create_datasets_from_function_with_input_context(
self._input_contexts, self._input_workers, self._dataset_fn))
if context.executing_eagerly():
# Records the time to initialize the distributed dataset.
context.async_wait()
distribute_duration_ms = (time.time_ns() -
distribute_start_time_ns) // 1_000_000
_distributed_dataset_from_function_initialization_time_milliseconds.get_cell(
self._strategy.__class__.__name__,
str(self._input_workers.num_workers)).add(distribute_duration_ms)
self._element_spec = _create_distributed_tensor_spec(
self._strategy, element_spec)
self._cardinality = _cardinality(self._datasets[0])
self._enable_get_next_as_optional = _enable_get_next_as_optional(
self._strategy, self._datasets[0], self._cardinality)
self._built = True
def auto_shard(self, num_shards, shard_ix):
assert (
len(self._datasets) == len(self._input_workers.worker_devices)
), (
f"datasets: {len(self._datasets)}, "
f"input workers: {len(self._input_workers.worker_devices)}"
)
sharded_datasets = []
for i in range(len(self._input_workers.worker_devices)):
with ops.colocate_with(self._datasets[i]._variant_tensor): # pylint: disable=protected-access
sharded_datasets.append(
input_ops.auto_shard_dataset(
self._datasets[i], num_shards, shard_ix,
self._num_replicas_in_sync
)
)
return DistributedDatasetsFromFunction(self._input_workers, self._strategy,
components=sharded_datasets,
element_spec=self._element_spec,
options=self._options)
@property
def cardinality(self):
if not self._built:
raise ValueError(
"Cannot get the cardinality of a dataset that is not built")
return self._cardinality
def __iter__(self):
if not (ops.executing_eagerly_outside_functions() or
ops.get_default_graph().building_function):
raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.")
if not self._built:
raise ValueError("You need to use this dataset in "
"ClusterCoordinator.create_per_worker_dataset.")
canonicalize_devices = getattr(self._strategy, "_canonicalize_devices",
True)
iterators = _create_iterators_per_worker(
self._datasets,
self._input_workers,
options=self._options,
canonicalize_devices=canonicalize_devices)
iterator = DistributedIterator(
input_workers=self._input_workers,
iterators=iterators,
strategy=self._strategy,
cardinality=self._cardinality,
enable_get_next_as_optional=self._enable_get_next_as_optional,
options=self._options,
replica_order=self._replica_order,
)
iterator._element_spec = self._element_spec # pylint: disable=protected-access
# When async eager is enabled, sometimes the iterator may not finish
# initialization before being passed to a multi-device function, so we add a
# sync point here to make sure all underlying iterators are initialized.
if context.executing_eagerly():
context.async_wait()
return iterator
@property
def element_spec(self):
"""The type specification of an element of this dataset."""
# When partial batch handling is enabled, always set the batch dimension to
# None, otherwise we just follow element_spec of the underlying dataset
# (whose batch dimension may also be None). This is because with partial
# batch handling we could always produce empty batches.
if (self._enable_get_next_as_optional and
self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
return nest.map_structure(
_rebatch_as_dynamic, self._element_spec, expand_composites=False)
return self._element_spec
@property
def _type_spec(self):
return DistributedDatasetsFromFunctionSpec(self._input_workers,
self._element_spec,
self._strategy, self._options)
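# Illustrative use of DistributedDatasetsFromFunction (comment-only sketch;
# the strategy and train_step below are assumptions). It is normally created
# via strategy.distribute_datasets_from_function, which calls the user's
# dataset_fn once per worker with an InputContext:
#
#   def dataset_fn(input_context):
#     batch_size = input_context.get_per_replica_batch_size(
#         global_batch_size=8)
#     ds = tf.data.Dataset.range(64).batch(batch_size)
#     return ds.shard(input_context.num_input_pipelines,
#                     input_context.input_pipeline_id)
#
#   dist_ds = strategy.distribute_datasets_from_function(dataset_fn)
#   for per_replica_batch in dist_ds:
#     strategy.run(train_step, args=(per_replica_batch,))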
def _dummy_tensor_fn(value_structure):
"""A function to create dummy tensors from `value_structure`."""
def create_dummy_tensor(spec):
"""Create a dummy tensor with possible batch dimensions set to 0."""
if hasattr(spec, "_create_empty_value"):
# Type spec may overwrite default dummy values behavior by declaring the
# `_create_empty_value(self)` method. This method must return a value
# compatible with the type spec with batch dimensions set to 0 or fail if
# such a value does not exist. This allows a composite tensor to customize
# dummy values creation as, in general, its dummy value is not composed
# from dummy components (e.g. `row_splits` tensor of a RaggedTensor is
# never allowed to be empty). See b/183969859 for more discussions.
# TODO(b/186079336): reconsider CompositeTensor support.
return spec._create_empty_value() # pylint: disable=protected-access
if isinstance(spec, ragged_tensor.RaggedTensorSpec):
# Splice out the ragged dimensions.
# pylint: disable=protected-access
feature_shape = spec._shape[:1].concatenate(
spec._shape[(1 + spec._ragged_rank):])
feature_type = spec._dtype
# pylint: enable=protected-access
else:
feature_shape = spec.shape
feature_type = spec.dtype
# Ideally we should set the batch dimension to 0. However, since in
# DistributionStrategy we don't know the batch dimension, we make a
# best-effort guess. If the feature has unknown dimensions, we set them to
# 0. If the feature shape is already static, we treat the first dimension
# as the batch dimension and set it to 0.
dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()]
if feature_shape else [])
if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or
feature_shape.is_fully_defined()):
dims[0] = tensor_shape.Dimension(0)
if isinstance(spec, sparse_tensor.SparseTensorSpec):
return sparse_tensor.SparseTensor(
values=array_ops.zeros(0, feature_type),
indices=array_ops.zeros((0, len(dims)), dtypes.int64),
dense_shape=dims)
# Create the dummy tensor.
dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
if isinstance(spec, ragged_tensor.RaggedTensorSpec):
# Reinsert the ragged dimensions with size 0.
# pylint: disable=protected-access
row_splits = array_ops.zeros(1, spec._row_splits_dtype)
dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits(
dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False)
# pylint: enable=protected-access
return dummy_tensor
return nest.map_structure(create_dummy_tensor, value_structure)
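# Illustrative outputs of _dummy_tensor_fn (comment-only sketch using assumed
# specs):
#
#   _dummy_tensor_fn(tf.TensorSpec([None, 3], tf.float32))
#     -> a float32 zeros tensor of shape [0, 3] (unknown dims become 0)
#   _dummy_tensor_fn({"x": tf.TensorSpec([4], tf.int32)})
#     -> {"x": an int32 zeros tensor of shape [0]} (for fully defined shapes
#        the first dimension is treated as the batch dimension and zeroed)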
def _get_value_or_dummy(input_workers, optional_list, produce_dummy):
"""Returns the value of the optionals or dummy values.
Args:
input_workers: the `InputWorkers`.
optional_list: a list of lists of `tf.experimental.Optional`. The values
from each compute device, grouped by the input device.
produce_dummy: a bool. Whether to produce dummy tensors when the optional
doesn't have a value.
Returns:
A flattened list of Tensors.
"""
value_list = []
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
devices = input_workers.compute_devices_for_worker(i)
for j, device in enumerate(devices):
with ops.device(device):
if produce_dummy:
# pylint: disable=cell-var-from-loop
value_list.append(
tf_cond.cond(
optional_list[i][j].has_value(),
lambda: optional_list[i][j].get_value(), # pylint: disable=unnecessary-lambda
lambda: _dummy_tensor_fn(optional_list[i][j].element_spec),
strict=True,
))
# pylint: enable=cell-var-from-loop
else:
value_list.append(optional_list[i][j].get_value())
return value_list
class _SingleWorkerDatasetIteratorBase(object):
"""Iterator for a single `tf.data.Dataset`."""
def __init__(self, dataset, worker, devices, options=None):
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch
input to the devices on the given worker.
Args:
dataset: A `tf.data.Dataset` instance.
worker: Worker on which ops should be created.
devices: Distribute data from `dataset` to these devices.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
"""
self._dataset = dataset
self._worker = worker
self._devices = devices
self._element_spec = dataset.element_spec
self._options = options
self._make_iterator()
def _make_iterator(self):
raise NotImplementedError("must be implemented in descendants")
def _format_data_list_with_options(self, data_list):
"""Change the data in to a list type if required.
The OwnedMultiDeviceIterator returns the list data type,
while the PER_REPLICA iterator (when used with prefetch disabled)
returns without the enclosed list. This is to fix the inconsistency.
Args:
data_list: data_list
Returns:
list
"""
if (self._options and self._options.experimental_replication_mode ==
InputReplicationMode.PER_REPLICA and
not self._options.experimental_fetch_to_device):
return [data_list]
else:
return data_list
def get_next(self, device, name=None):
"""Get next element for the given device."""
del name
with ops.device(self._worker):
if _should_use_multi_device_iterator(self._options):
return self._iterator.get_next(device)
else:
return self._iterator.get_next()
def get_next_as_list(self, name=None):
"""Get next element from the underlying iterator.
Runs the iterator get_next() within a device scope. Since this doesn't use
get_next_as_optional(), it is considerably faster than
get_next_as_optional_list(), but it raises OutOfRangeError if any of the
devices doesn't get any data.
Args:
name: not used.
Returns:
A list consisting of the next data from each device.
"""
del name
with ops.device(self._worker):
return self._format_data_list_with_options(self._iterator.get_next())
def get_next_as_optional_list(self):
with ops.device(self._worker):
return self._format_data_list_with_options(
self._iterator.get_next_as_optional())
class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
"""Type specification for `_SingleWorkerOwnedDatasetIterator`."""
__slots__ = [
"_worker", "_devices", "_element_spec", "_options",
"_canonicalize_devices"
]
def __init__(self, worker, devices, element_spec, options,
canonicalize_devices=True):
self._worker = worker
if canonicalize_devices:
self._devices = tuple(device_util.canonicalize(d) for d in devices)
else:
self._devices = tuple(
device_util.canonicalize_without_job_and_task(d) for d in devices)
self._element_spec = element_spec
# `self._options` intentionally made not `None` for proper serialization.
self._options = (options if options is not None else
distribute_lib.InputOptions())
self._canonicalize_devices = canonicalize_devices
@property
def value_type(self):
return _SingleWorkerOwnedDatasetIterator
def _serialize(self):
return (self._worker, self._devices, self._element_spec, self._options,
self._canonicalize_devices)
def _get_multi_device_iterator_spec(self, specs):
device_scope = device_util.canonicalize(self._worker, device_util.current())
host_device = device_util.get_host_for_device(device_scope)
# The source_device used while creating the iterator governs the worker
# device in the iterator spec.
worker = host_device
specs.append(
multi_device_iterator_ops.MultiDeviceIteratorSpec(
self._devices, worker, element_spec=self._element_spec))
@property
def _component_specs(self):
specs = []
if _should_use_multi_device_iterator(self._options):
self._get_multi_device_iterator_spec(specs)
else:
specs.append(iterator_ops.IteratorSpec(element_spec=self._element_spec))
return specs
def _to_components(self, value):
return [value._iterator] # pylint: disable=protected-access
def _from_components(self, components):
return _SingleWorkerOwnedDatasetIterator(
dataset=None,
worker=self._worker,
devices=self._devices,
components=components,
element_spec=self._element_spec,
options=self._options,
canonicalize_devices=self._canonicalize_devices)
@staticmethod
def from_value(value):
# pylint: disable=protected-access
return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
value._element_spec, value._options,
value._canonicalize_devices)
class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
composite_tensor.CompositeTensor):
"""Iterator for a DistributedDataset instance."""
def __init__(self,
dataset=None,
worker=None,
devices=None,
components=None,
element_spec=None,
options=None,
canonicalize_devices=None):
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
`OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
given worker. The lifetime of this iterator is tied to the encompassing
python object. Once we go out of scope of the python object or return from
a tf.function the underlying iterator resource is deleted.
Args:
dataset: A `tf.data.Dataset` instance.
worker: Worker on which ops should be created.
devices: Distribute data from `dataset` to these devices.
components: Tensor components to construct the
_SingleWorkerOwnedDatasetIterator from.
element_spec: A nested structure of `TypeSpec` objects that represents the
type specification of elements of the iterator.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
canonicalize_devices: Whether to fully canonicalize the worker devices. If
False, devices are only partially canonicalized by stripping the job and
task fields.
"""
if worker is None or devices is None:
raise ValueError("Both `worker` and `devices` should be provided")
error_message = ("Either `dataset` or both `components` and `element_spec` "
"need to be provided.")
self._options = options
self._canonicalize_devices = canonicalize_devices
if dataset is None:
if (components is None or element_spec is None):
raise ValueError(error_message)
self._element_spec = element_spec
self._worker = worker
self._devices = devices
self._iterator = components[0]
else:
if (components is not None or element_spec is not None):
raise ValueError(error_message)
super(_SingleWorkerOwnedDatasetIterator,
self).__init__(dataset, worker, devices, self._options)
def _create_owned_multi_device_iterator(self):
# If the worker devices are already canonicalized, canonicalizing again
# would have no impact.
# For strategies running on remote workers such as PS Strategy, the device
# scope will be derived from current worker, if used under init_scope().
if not ops.inside_function():
device_scope = device_util.canonicalize(self._worker,
device_util.current())
host_device = device_util.get_host_for_device(device_scope)
else:
# In general, iterators should not be created within tf.functions. For
# exact visitation guarantee solutions for parameter server training,
# however, we do create iterators within the tf.functions that are
# dispatched to workers. In these cases, the traced device must match the
# runtime device. Since tracing occurs on the chief, we do not want to use
# the current device scope, which would be the chief, but rather use the
# relative worker device scope explicitly.
device_scope, host_device = self._worker, self._worker
with ops.device(device_scope):
if self._options is not None:
self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
self._dataset,
self._devices,
source_device=host_device,
max_buffer_size=self._options
.experimental_per_replica_buffer_size,
prefetch_buffer_size=self._options
.experimental_per_replica_buffer_size)
else:
self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
self._dataset, self._devices, source_device=host_device)
def _make_iterator(self):
"""Make appropriate iterator on the dataset."""
if not self._worker:
raise ValueError("Worker device must be specified when creating an "
"owned iterator.")
if _should_use_multi_device_iterator(self._options):
self._create_owned_multi_device_iterator()
else:
with ops.device(self._worker):
self._iterator = iter(self._dataset)
@property
def element_spec(self):
return self._element_spec
@property
def _type_spec(self):
return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
self._element_spec, self._options,
self._canonicalize_devices)
@property
def output_classes(self):
"""Returns the class of each component of an element of this iterator.
The expected values are `tf.Tensor` and `tf.SparseTensor`.
Returns:
A nested structure of Python `type` objects corresponding to each
component of an element of this dataset.
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
self._element_spec)
@property
def output_shapes(self):
"""Returns the shape of each component of an element of this iterator.
Returns:
A nested structure of `tf.TensorShape` objects corresponding to each
component of an element of this dataset.
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
self._element_spec)
@property
def output_types(self):
"""Returns the type of each component of an element of this iterator.
Returns:
A nested structure of `tf.DType` objects corresponding to each component
of an element of this dataset.
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
self._element_spec)
def _create_iterators_per_worker(worker_datasets,
input_workers,
options=None,
canonicalize_devices=False):
"""Create a multidevice iterator on each of the workers."""
assert isinstance(input_workers, InputWorkers)
assert len(worker_datasets) == len(input_workers.worker_devices)
iterators = []
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
worker_devices = input_workers.compute_devices_for_worker(i)
iterator = _SingleWorkerOwnedDatasetIterator(
dataset=worker_datasets[i],
worker=worker,
devices=worker_devices,
options=options,
canonicalize_devices=canonicalize_devices)
iterators.append(iterator)
return iterators
def _create_datasets_from_function_with_input_context(input_contexts,
input_workers,
dataset_fn):
"""Create device datasets per worker given a dataset function."""
datasets = []
for i, ctx in enumerate(input_contexts):
worker = input_workers.worker_devices[i]
with ops.device(worker):
dataset = dataset_fn(ctx)
datasets.append(dataset)
  # All per-worker datasets come from the same `dataset_fn`, so their element
  # specs are expected to match; returning the last one's spec is sufficient.
  return datasets, dataset.element_spec
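# Illustrative sketch only (hypothetical helper, never called by this module):
# a single-worker use of the function above. It assumes `InputWorkers` accepts
# a list of (input device, compute devices) pairs, as used elsewhere in this
# module, and that `distribute_lib.InputContext` takes the keyword arguments
# shown; both are assumptions made for this sketch.
def _example_datasets_from_function():
  input_workers = InputWorkers([("/cpu:0", ("/cpu:0",))])
  input_contexts = [
      distribute_lib.InputContext(
          num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=1)
  ]
  def dataset_fn(ctx):
    # Shard by input pipeline so each worker sees a disjoint slice of the data.
    return dataset_ops.DatasetV2.range(8).shard(
        ctx.num_input_pipelines, ctx.input_pipeline_id).batch(2)
  return _create_datasets_from_function_with_input_context(
      input_contexts, input_workers, dataset_fn)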
# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_batched_dataset(d):
"""Get the batched dataset from `d`."""
# pylint: disable=protected-access
if isinstance(d, dataset_ops.DatasetV1Adapter):
d = d._dataset
if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
return d
elif isinstance(d, (dataset_ops.PrefetchDataset,
dataset_ops._OptionsDataset)):
return _get_batched_dataset(d._input_dataset)
raise ValueError(
"Unable to get batched dataset from the input dataset. `batch` "
"`map_and_batch` need to be the last operations on the dataset. "
"The batch operations can be followed by a prefetch.")
def _get_batched_dataset_attributes(d):
"""Get `batch_size`, `drop_remainder` of dataset."""
# pylint: disable=protected-access
assert isinstance(d,
(dataset_ops.BatchDataset, batching._MapAndBatchDataset))
if isinstance(d, dataset_ops.BatchDataset):
batch_size = d._batch_size
drop_remainder = d._drop_remainder
elif isinstance(d, batching._MapAndBatchDataset):
batch_size = d._batch_size_t
drop_remainder = d._drop_remainder_t
# pylint: enable=protected-access
if tensor_util.is_tf_type(batch_size):
batch_size = tensor_util.constant_value(batch_size)
if tensor_util.is_tf_type(drop_remainder):
drop_remainder = tensor_util.constant_value(drop_remainder)
return batch_size, drop_remainder
# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_dataset_attributes(dataset):
"""Get the underlying attributes from the dataset object."""
# pylint: disable=protected-access
# First, get batch_size and drop_remainder from the dataset. We need
# to walk back the dataset creation process and find the batched version in
# order to get the attributes.
batched_dataset = _get_batched_dataset(dataset)
batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset)
  # Second, the prefetch buffer size should be obtained from the original
  # dataset.
prefetch_buffer = None
if isinstance(dataset, dataset_ops.PrefetchDataset):
prefetch_buffer = dataset._buffer_size
elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
prefetch_buffer = dataset._dataset._buffer_size
return batch_size, drop_remainder, prefetch_buffer
def _should_use_multi_device_iterator(options):
"""Determine whether to use multi_device_iterator_ops."""
if (options is None or
options.experimental_replication_mode == InputReplicationMode.PER_WORKER
or
(options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
and options.experimental_fetch_to_device)):
return True
return False
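# Illustrative sketch only (hypothetical helper, never called by this module):
# how the decision above plays out for a few option settings. It assumes
# `distribute_lib.InputOptions` exposes `experimental_replication_mode` and
# `experimental_fetch_to_device`, mirroring `tf.distribute.InputOptions`.
def _example_should_use_multi_device_iterator():
  per_worker = distribute_lib.InputOptions(
      experimental_replication_mode=InputReplicationMode.PER_WORKER)
  per_replica_no_fetch = distribute_lib.InputOptions(
      experimental_replication_mode=InputReplicationMode.PER_REPLICA,
      experimental_fetch_to_device=False)
  # No options or PER_WORKER replication -> multi-device iterator; PER_REPLICA
  # without fetching to device -> plain per-worker iterator.
  return (_should_use_multi_device_iterator(None),
          _should_use_multi_device_iterator(per_worker),
          _should_use_multi_device_iterator(per_replica_no_fetch))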
class MultiStepContext(object):
"""A context object that can be used to capture things when running steps.
This context object is useful when running multiple steps at a time using the
`experimental_run_steps_on_iterator` API. For e.g. it allows the user's step
function to specify which outputs to emit at what frequency. Currently it
supports capturing output from the last step, as well as capturing non tensor
outputs. In the future it will be augmented to support other use cases such
as output each N steps.
"""
def __init__(self):
"""Initialize an output context.
Returns:
A context object.
"""
self._last_step_outputs = {}
self._last_step_outputs_reduce_ops = {}
self._non_tensor_outputs = {}
@property
def last_step_outputs(self):
"""A dictionary consisting of outputs to be captured on last step.
Keys in the dictionary are names of tensors to be captured, as specified
when `set_last_step_output` is called.
Values in the dictionary are the tensors themselves. If
`set_last_step_output` was called with a `reduce_op` for this output,
then the value is the reduced value.
Returns:
A dictionary with last step outputs.
"""
return self._last_step_outputs
def _set_last_step_outputs(self, outputs):
"""Replace the entire dictionary of last step outputs."""
if not isinstance(outputs, dict):
raise ValueError("Need a dictionary to set last_step_outputs.")
self._last_step_outputs = outputs
def set_last_step_output(self, name, output, reduce_op=None):
"""Set `output` with `name` to be outputted from the last step.
Args:
name: String, name to identify the output. Doesn't need to match tensor
name.
output: The tensors that should be outputted with `name`. See below for
actual types supported.
reduce_op: Reduction method to use to reduce outputs from multiple
replicas. Required if `set_last_step_output` is called in a replica
context. Optional in cross_replica_context.
When present, the outputs from all the replicas are reduced using the
current distribution strategy's `reduce` method. Hence, the type of
`output` must be what's supported by the corresponding `reduce` method.
For e.g. if using MirroredStrategy and reduction is set, output
must be a `PerReplica` value.
The reduce method is also recorded in a dictionary
`_last_step_outputs_reduce_ops` for later interpreting of the
outputs as already reduced or not.
"""
if distribute_lib.in_cross_replica_context():
self._last_step_outputs_reduce_ops[name] = reduce_op
if reduce_op is None:
self._last_step_outputs[name] = output
else:
distribution = distribute_lib.get_strategy()
self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
axis=None)
else:
assert reduce_op is not None
def merge_fn(distribution, value):
self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
axis=None)
# Setting this inside the `merge_fn` because all replicas share the same
# context object, so it's more robust to set it only once (even if all
# the replicas are trying to set the same value).
self._last_step_outputs_reduce_ops[name] = reduce_op
distribute_lib.get_replica_context().merge_call(
merge_fn, args=(output,))
@property
def non_tensor_outputs(self):
"""A dictionary consisting of any non tensor outputs to be captured."""
return self._non_tensor_outputs
def set_non_tensor_output(self, name, output):
"""Set `output` with `name` to be captured as a non tensor output."""
if distribute_lib.in_cross_replica_context():
self._non_tensor_outputs[name] = output
else:
def merge_fn(distribution, value):
        # NOTE(priyag): For non-tensor outputs, we simply return all the values
        # in a list, as reduction doesn't make sense for non-tensors.
self._non_tensor_outputs[name] = (
distribution.experimental_local_results(value))
distribute_lib.get_replica_context().merge_call(
merge_fn, args=(output,))
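# Illustrative sketch only (hypothetical helper, never called by this module):
# minimal use of `MultiStepContext` from a cross-replica context (e.g. outside
# any replica function). With no `reduce_op`, the value is stored unreduced.
def _example_multi_step_context(loss_tensor):
  ctx = MultiStepContext()
  ctx.set_last_step_output("loss", loss_tensor)
  ctx.set_non_tensor_output("history", [1.0, 2.0])
  return ctx.last_step_outputs["loss"], ctx.non_tensor_outputs["history"]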
def _create_distributed_tensor_spec(strategy, tensor_spec):
"""Create a `tf.TypeSpec` for a given strategy and input `tensor_spec`.
Args:
strategy: The given `tf.distribute` strategy.
tensor_spec: `tf.TensorSpec` of a given value. The batch dimension of the
shape should be None if you have partial batches.
Returns:
A `tf.TypeSpec` that matches the values produced by a given strategy. This
    can be a `tf.TensorSpec` or a `PerReplicaSpec`.
"""
num_replicas = len(strategy.extended.worker_devices)
  # For a single-device strategy that is not multi-worker (i.e. not
  # MultiWorkerMirroredStrategy), return the tensor_spec as is, since we don't
  # wrap the output with PerReplica in this case.
# TODO(b/166464552): remove after we always wrap for all strategies.
if not _always_wrap(strategy):
return tensor_spec
  # For other cases we assume the input to the tf.function is a per-replica
  # type.
def _get_value_per_replica(tensor_spec_per_input):
value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
return values.PerReplicaSpec(*value_specs)
return nest.map_structure(_get_value_per_replica, tensor_spec)
def _replace_per_replica_spec(spec, i):
"""If `spec` is a `PerReplicaSpec`, then return its `i`th value_spec."""
if isinstance(spec, values.PerReplicaSpec):
return spec._value_specs[i] # pylint: disable=protected-access
else:
return spec
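# Illustrative sketch only (hypothetical helper, never called by this module):
# what `_create_distributed_tensor_spec` produces for a two-replica strategy
# and how `_replace_per_replica_spec` recovers a single replica's spec. The
# local import keeps the sketch self-contained.
def _example_per_replica_spec_roundtrip():
  from tensorflow.python.framework import tensor_spec
  component = tensor_spec.TensorSpec([None, 5], dtypes.float32)
  per_replica = values.PerReplicaSpec(component, component)
  # Returns the spec of replica 0, i.e. `component` itself.
  return _replace_per_replica_spec(per_replica, 0)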
def _cardinality(dataset):
"""Returns the cardinality of the dataset."""
if context.executing_eagerly():
with ops.device(dataset._variant_tensor.device): # pylint: disable=protected-access
return dataset.cardinality().numpy()
return cardinality_lib.UNKNOWN
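# Illustrative sketch only (hypothetical helper, never called by this module):
# in eager mode `_cardinality` reports exact sizes, while `repeat()` makes the
# dataset INFINITE; in graph mode it conservatively returns UNKNOWN.
def _example_dataset_cardinality():
  finite = dataset_ops.DatasetV2.range(8)
  infinite = finite.repeat()
  # In eager mode this evaluates to (8, True); in graph mode `_cardinality`
  # returns UNKNOWN for both datasets.
  return (_cardinality(finite),
          _cardinality(infinite) == cardinality_lib.INFINITE)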
def _enable_get_next_as_optional(strategy, dataset, cardinality):
"""Returns whether to enable using partial batch handling."""
# TODO(b/133073708): we currently need a flag to control the usage because
# there is a performance difference between get_next() and
# get_next_as_optional(). And we only enable get_next_as_optional when the
# output shapes are not static.
#
  # TODO(rxsang): We want to always enable the get_next_as_optional behavior
  # when the user passes an input_fn instead of a dataset.
if not getattr(
strategy.extended, "enable_partial_batch_handling",
getattr(strategy.extended, "experimental_enable_get_next_as_optional",
False)):
return False
  # If the dataset is infinite, we don't need to enable last partial batch
  # support. Note that we can only evaluate the cardinality of the dataset in
  # eager mode.
if cardinality == cardinality_lib.INFINITE:
return False
return not _is_statically_shaped(
dataset.element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access
def _create_per_replica(value_list, strategy):
"""Creates a PerReplica.
  For strategies other than OneDeviceStrategy, it creates a PerReplica whose
  type spec is set to the element spec of the dataset. This helps avoid
  retracing for partial batches. Retracing is problematic for multi-client
  setups when different clients retrace at different times, since retracing
  changes the collective keys in the tf.function and causes mismatches among
  clients.
  For single-client strategies, this simply calls distribute_utils.regroup().
Args:
value_list: a list of values, one for each replica.
strategy: the `tf.distribute.Strategy`.
Returns:
a structure of PerReplica.
"""
# TODO(b/166464552): always wrap for all one device strategies as well.
always_wrap = _always_wrap(strategy)
per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
return per_replicas
def _always_wrap(strategy):
"""Returns whether to always wrap the values in a DistributedValues."""
return strategy.extended._in_multi_worker_mode() or len( # pylint: disable=protected-access
strategy.extended.worker_devices) > 1
def _rebatch_as_dynamic(per_replica_spec):
"""Rebatch the spec to have a dynamic batch dimension."""
assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec
# pylint: disable=protected-access
def _rebatch(spec):
# Rebatch if possible.
try:
return spec._unbatch()._batch(None)
except ValueError:
pass
return spec
return values.PerReplicaSpec(
*nest.map_structure(_rebatch, per_replica_spec._value_specs))
# pylint: enable=protected-access
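# Illustrative sketch only (hypothetical helper, never called by this module):
# `_rebatch_as_dynamic` rewrites each per-replica component spec so its leading
# (batch) dimension is None. The local import keeps the sketch self-contained.
def _example_rebatch_as_dynamic():
  from tensorflow.python.framework import tensor_spec
  static_spec = values.PerReplicaSpec(
      tensor_spec.TensorSpec([32, 5], dtypes.float32),
      tensor_spec.TensorSpec([32, 5], dtypes.float32))
  # Each component becomes TensorSpec([None, 5], tf.float32).
  return _rebatch_as_dynamic(static_spec)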
def _ag_enumerate_not_implemented(s, unused_start):
msg = (
f"enumerate not supported with {s.__class__.__name__} types within "
"tf.functions. Use a for loop over the dataset and keep a separate "
"counter instead."
)
raise NotImplementedError(msg)
py_builtins.enumerate_registry.register(
DistributedIterator, _ag_enumerate_not_implemented
)
py_builtins.enumerate_registry.register(
DistributedDataset, _ag_enumerate_not_implemented
)