# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class CollectiveAllReduceStrategy implementing DistributionStrategy."""
import copy
import threading
import time
import weakref
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tsl.protobuf import coordination_config_pb2
# pylint: disable=line-too-long
@tf_export("distribute.MultiWorkerMirroredStrategy", v1=[])
class CollectiveAllReduceStrategy(distribute_lib.Strategy):
"""A distribution strategy for synchronous training on multiple workers.
This strategy implements synchronous distributed training across multiple
workers, each with potentially multiple GPUs. Similar to
`tf.distribute.MirroredStrategy`, it replicates all variables and computations
to each local device. The difference is that it uses a distributed collective
implementation (e.g. all-reduce), so that multiple workers can work together.
You need to launch your program on each worker and configure
`cluster_resolver` correctly. For example, if you are using
`tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to
have its corresponding `task_type` and `task_id` set in the `TF_CONFIG`
environment variable. An example TF_CONFIG on worker-0 of a two worker cluster
is:
```
TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
```
Your program runs on each worker as-is. Note that collectives require each
worker to participate. Both `tf.distribute` and non-`tf.distribute` APIs may use
collectives internally, e.g. checkpointing and saving, since reading a
`tf.Variable` with `tf.VariableSynchronization.ON_READ` all-reduces the value.
Therefore it's recommended to run exactly the same program on each worker.
Dispatching based on `task_type` or `task_id` of the worker is error-prone.
`cluster_resolver.num_accelerators()` determines the number of GPUs the
strategy uses. If it's zero, the strategy uses the CPU. All workers need to
use the same number of devices, otherwise the behavior is undefined.
This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy`
instead.
After setting up TF_CONFIG, using this strategy is similar to using
`tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`.
```
import numpy as np
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
model = tf.keras.Sequential([
tf.keras.layers.Dense(2, input_shape=(5,)),
])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
def dataset_fn(ctx):
x = np.random.random((2, 5)).astype(np.float32)
y = np.random.randint(2, size=(2, 1))
dataset = tf.data.Dataset.from_tensor_slices((x, y))
return dataset.repeat().batch(1, drop_remainder=True)
dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(dist_dataset, steps_per_epoch=10)
```
You can also write your own training loop:
```
@tf.function
def train_step(iterator):
def step_fn(inputs):
features, labels = inputs
with tf.GradientTape() as tape:
logits = model(features, training=True)
loss = tf.keras.losses.sparse_categorical_crossentropy(
    labels, logits, from_logits=True)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
strategy.run(step_fn, args=(next(iterator),))
iterator = iter(dist_dataset)
for _ in range(NUM_STEP):
train_step(iterator)
```
See
[Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
for a detailed tutorial.
__Saving__
You need to save and checkpoint on all workers instead of just one. This is
because variables with synchronization=ON_READ trigger aggregation during
saving. It's recommended to save to a different path on each worker to avoid
race conditions. Each worker saves the same thing. See
[Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading)
tutorial for examples.
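A minimal sketch of saving to a per-worker path (the directory layout below is
only an illustration, not a required convention):
```
task_id = strategy.cluster_resolver.task_id
model.save(f'/tmp/my_model/worker_{task_id}')
```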
__Known Issues__
* `tf.distribute.cluster_resolver.TFConfigClusterResolver` does not return the
correct number of accelerators. The strategy uses all available GPUs if
`cluster_resolver` is `tf.distribute.cluster_resolver.TFConfigClusterResolver`
or `None`.
* In eager mode, the strategy needs to be created before calling any other
TensorFlow API.
"""
# pylint: enable=line-too-long
# TODO(anjalisridhar): Update our guides with examples showing how we can use
# the cluster_resolver argument.
# The starting number for collective keys. This should only be set in tests.
_collective_key_base = 0
def __init__(self,
cluster_resolver=None,
communication_options=None):
"""Creates the strategy.
Args:
cluster_resolver: optional
`tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
`tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
communication_options: optional
`tf.distribute.experimental.CommunicationOptions`. This configures the
default options for cross-device communication. It can be overridden by
options provided to the communication APIs like
`tf.distribute.ReplicaContext.all_reduce`. See
`tf.distribute.experimental.CommunicationOptions` for details.
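A minimal example of overriding the defaults (assuming NCCL is available on
the workers' GPUs):
```
options = tf.distribute.experimental.CommunicationOptions(
    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)
strategy = tf.distribute.MultiWorkerMirroredStrategy(
    communication_options=options)
```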
"""
if communication_options is None:
communication_options = collective_util.Options()
super(CollectiveAllReduceStrategy, self).__init__(
CollectiveAllReduceExtended(
self,
cluster_resolver=cluster_resolver,
communication_options=communication_options))
distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
"MultiWorkerMirroredStrategy")
# pylint: disable=protected-access
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_workers").set(self.extended._num_workers)
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_replicas_per_worker").set(self.extended._num_devices_per_worker)
@classmethod
def _from_local_devices(cls, devices, communication_options=None):
"""A convenience method to create an object with a list of devices."""
obj = cls(communication_options=communication_options)
obj.extended._initialize_local( # pylint: disable=protected-access
tfconfig_cluster_resolver.TFConfigClusterResolver(), devices=devices)
return obj
@property
def cluster_resolver(self):
"""Returns the cluster resolver associated with this strategy.
As a multi-worker strategy, `tf.distribute.MultiWorkerMirroredStrategy`
provides the associated `tf.distribute.cluster_resolver.ClusterResolver`. If
the user provides one in `__init__`, that instance is returned; if the user
does not, a default `TFConfigClusterResolver` is provided.
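For example, a sketch of reading the task identity on a worker (assuming
`TF_CONFIG` is set as shown in the class docstring):
```
strategy = tf.distribute.MultiWorkerMirroredStrategy()
resolver = strategy.cluster_resolver
print(resolver.task_type, resolver.task_id)  # e.g. "worker", 0
```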
"""
return self.extended._cluster_resolver # pylint: disable=protected-access
class _CollectiveAllReduceStrategyExperimentalMeta(type):
@classmethod
def __instancecheck__(cls, instance):
# This is to make isinstance(tf.distribute.MultiWorkerMirroredStrategy(),
# tf.distribute.experimental.MultiWorkerMirroredStrategy) return True. Some
# libraries perform such checks.
return isinstance(instance, CollectiveAllReduceStrategy)
@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
class _CollectiveAllReduceStrategyExperimental(
CollectiveAllReduceStrategy,
metaclass=_CollectiveAllReduceStrategyExperimentalMeta):
__doc__ = CollectiveAllReduceStrategy.__doc__
@deprecation.deprecated(
None, "use distribute.MultiWorkerMirroredStrategy instead")
def __init__(self,
communication=collective_util.CommunicationImplementation.AUTO,
cluster_resolver=None):
"""Creates the strategy.
Args:
communication: optional
`tf.distribute.experimental.CommunicationImplementation`. This is a hint
on the preferred collective communication implementation. Possible
values include `AUTO`, `RING`, and `NCCL`.
cluster_resolver: optional
`tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
`tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
"""
communication_options = collective_util.Options(
implementation=communication)
super(_CollectiveAllReduceStrategyExperimental,
self).__init__(cluster_resolver, communication_options)
@classmethod
def _from_local_devices(
cls,
devices,
communication=collective_util.CommunicationImplementation.AUTO):
"""A convenience method to create an object with a list of devices."""
obj = cls(communication)
obj.extended._initialize_local(tfconfig_cluster_resolver.TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access
return obj
_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__
@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"]) # pylint: disable=missing-docstring
class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):
__doc__ = CollectiveAllReduceStrategy.__doc__
# The starting number for collective keys. This should only be set in tests.
_collective_key_base = 0
def __init__(self,
communication=collective_util.CommunicationImplementation.AUTO,
cluster_resolver=None):
"""Initializes the object."""
communication_options = collective_util.Options(
implementation=communication)
super(CollectiveAllReduceStrategyV1, self).__init__(
CollectiveAllReduceExtended(
self,
cluster_resolver=cluster_resolver,
communication_options=communication_options))
distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
"MultiWorkerMirroredStrategy")
# pylint: disable=protected-access
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_workers").set(self.extended._num_workers)
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_gpu_per_worker").set(
self.extended._num_devices_per_worker
if self.extended._local_device_type == "GPU"
else 0)
def _is_gpu_device(device):
return tf_device.DeviceSpec.from_string(device).device_type == "GPU"
class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
"""Implementation of CollectiveAllReduceStrategy."""
# Whether to periodically check the health of the cluster. If any worker is not
# reachable, collectives are aborted and the user program should get a
# tf.errors.UnavailableError. Restarting the program is required to recover.
_enable_check_health = True
# Check health interval in seconds.
_check_health_interval = 30
# Timeout in seconds for the first health check. The first check needs to
# wait for the cluster to be up, which may take longer.
_check_health_initial_timeout = 0
# Number of times to retry before considering the peer down.
_check_health_retry_limit = 3
# Timeout in seconds for each health check.
_check_health_timeout = 10
def __init__(self, container_strategy, cluster_resolver,
communication_options, devices=None):
if not isinstance(communication_options, collective_util.Options):
raise ValueError("communication_options must be an instance of "
"tf.distribute.experimental.CommunicationOptions")
if cluster_resolver and devices:
raise ValueError(
"cluster_resolver and devices cannot be set at the same time")
self._cluster_resolver = cluster_resolver or tfconfig_cluster_resolver.TFConfigClusterResolver()
if not isinstance(self._cluster_resolver, cluster_resolver_lib.ClusterResolver):
raise ValueError("cluster_resolver must be an instance of "
"tf.distribute.cluster_resolver.ClusterResolver")
distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
self._communication_options = communication_options
self._collective_key_base = container_strategy._collective_key_base # pylint: disable=protected-access
self._initialize_strategy(self._cluster_resolver, devices=devices)
self._cfer_fn_cache = weakref.WeakKeyDictionary()
self.experimental_enable_get_next_as_optional = True
assert isinstance(self._cross_device_ops,
cross_device_ops_lib.CollectiveAllReduce)
def _use_merge_call(self):
# We currently only disable merge_call when XLA is used to compile the `fn`
# passed to `strategy.run` and all devices are GPU.
return not control_flow_util.GraphOrParentsInXlaContext(
ops.get_default_graph()) or not all(
[_is_gpu_device(d) for d in self._devices])
def _initialize_strategy(self, cluster_resolver, devices):
# If devices are provided or cluster_spec is not specified, initialize for a
# single worker. Otherwise initialize for multiple workers.
if devices or not cluster_resolver.cluster_spec().as_dict():
self._initialize_local(cluster_resolver, devices=devices)
else:
self._initialize_multi_worker(cluster_resolver)
def _initialize_local_devices(self, cluster_resolver, worker_device):
# TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
# some cases.
if isinstance(cluster_resolver, tfconfig_cluster_resolver.TFConfigClusterResolver):
num_gpus = context.num_gpus()
num_tpus = 0
else:
num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
num_tpus = cluster_resolver.num_accelerators().get("TPU", 0)
if num_gpus:
local_device_type = "GPU"
num_local_devices = num_gpus
elif num_tpus:
local_device_type = "TPU"
num_local_devices = num_tpus
else:
local_device_type = "CPU"
num_local_devices = 1
local_devices = tuple(
f"{worker_device}/device:{local_device_type}:{i}"
for i in range(num_local_devices))
return local_devices, local_device_type
def _initialize_local(self, cluster_resolver, devices=None):
"""Initializes the object for local training."""
self._is_chief = True
self._num_workers = 1
if ops.executing_eagerly_outside_functions():
try:
context.context().configure_collective_ops(
scoped_allocator_enabled_ops=("CollectiveReduce",))
except RuntimeError:
logging.warning("Collective ops is not configured at program startup. "
"Some performance features may not be enabled.")
self._collective_ops_configured = True
if devices:
local_devices = devices
if "GPU" in devices[0]:
local_device_type = "GPU"
elif "TPU" in devices[0]:
local_device_type = "TPU"
else:
local_device_type = "CPU"
else:
local_devices, local_device_type = self._initialize_local_devices(
cluster_resolver, worker_device="")
self._worker_device = device_util.canonicalize("/device:CPU:0")
self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
self._collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=1 + self._collective_key_base)
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=local_devices,
group_size=len(local_devices),
options=self._communication_options,
collective_keys=self._collective_keys)
# CrossDeviceOps for per-host tensors.
self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=[self._worker_device],
group_size=self._num_workers,
options=self._communication_options,
collective_keys=self._collective_keys)
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
local_devices)
self._cluster_spec = None
self._task_type = None
self._task_id = None
self._id_in_cluster = 0
# This is a marker to tell whether we are running with a standalone client or
# as an independent worker. Right now with the standalone client, the strategy
# object is created as a local strategy and then turned into a multi-worker
# strategy via the configure call.
self._local_or_standalone_client_mode = True
# Save the num_devices_per_worker and rpc_layer for configure method.
self._num_devices_per_worker = len(local_devices)
self._local_device_type = local_device_type
self._rpc_layer = cluster_resolver.rpc_layer
self._warn_nccl_no_gpu()
logging.info(
"Single-worker MultiWorkerMirroredStrategy with local_devices "
"= %r, communication = %s", local_devices,
self._communication_options.implementation)
def _initialize_multi_worker(self, cluster_resolver):
"""Initializes the object for multi-worker training."""
cluster_spec = multi_worker_util.normalize_cluster_spec(
cluster_resolver.cluster_spec())
task_type = cluster_resolver.task_type
task_id = cluster_resolver.task_id
if task_type is None or task_id is None:
raise ValueError("When `cluster_spec` is given, you must also specify "
"`task_type` and `task_id`.")
self._cluster_spec = cluster_spec
self._task_type = task_type
self._task_id = task_id
self._id_in_cluster = multi_worker_util.id_in_cluster(
self._cluster_spec, self._task_type, self._task_id)
self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
if not self._num_workers:
raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
"in `cluster_spec`.")
self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
task_id)
self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
if (ops.executing_eagerly_outside_functions() and
not getattr(self, "_local_or_standalone_client_mode", False)):
context.context().configure_collective_ops(
collective_leader=multi_worker_util.collective_leader(
cluster_spec, task_type, task_id),
scoped_allocator_enabled_ops=("CollectiveReduce",),
device_filters=("/job:%s/task:%d" % (task_type, task_id),))
self._collective_ops_configured = True
if context.context().coordination_service is None:
coordinated_jobs = ["chief", "worker"]
if task_type in coordinated_jobs:
coordinated_job_config = []
for job in coordinated_jobs:
if job in cluster_spec.jobs:
coordinated_job_config.append(
coordination_config_pb2.CoordinatedJob(
name=job,
num_tasks=cluster_spec.num_tasks(job)))
context.context().configure_coordination_service(
service_type="standalone",
service_leader=multi_worker_util.coordination_leader(
cluster_spec),
coordinated_jobs=coordinated_job_config)
# Starting a std server in eager mode and in independent worker mode.
if (context.executing_eagerly() and
not getattr(self, "_std_server_started", False) and
not getattr(self, "_local_or_standalone_client_mode", False)):
# Checking _local_or_standalone_client_mode as well because we should not
# create the std server in standalone client mode.
config_proto = copy.deepcopy(context.context().config)
config_proto = self._update_config_proto(config_proto)
# If coordination service is enabled, use its internal heartbeat to detect
# peer failures instead of the Python-level health check.
if config_proto.experimental.coordination_config.service_type:
self._enable_check_health = False
if hasattr(cluster_resolver, "port"):
port = cluster_resolver.port
else:
port = 0
server_def = tensorflow_server_pb2.ServerDef(
cluster=cluster_spec.as_cluster_def(),
default_session_config=config_proto,
job_name=task_type,
task_index=task_id,
protocol=cluster_resolver.rpc_layer or "grpc",
port=port)
context.context().enable_collective_ops(server_def)
self._std_server_started = True
# The `ensure_initialized` is needed before calling
# `context.context().devices()`.
context.context().ensure_initialized()
logging.info(
"Enabled multi-worker collective ops with available devices: %r",
context.context().devices())
# TODO(yuefengz): The `num_gpus` is only for this particular task. It
# assumes all workers have the same number of GPUs. We should remove this
# assumption by querying all tasks for their numbers of GPUs.
# TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
# some cases.
local_devices, local_device_type = self._initialize_local_devices(
cluster_resolver, self._worker_device)
if local_device_type == "TPU":
tpu_cluster_resolver.initialize_tpu_system()
self._collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=1 + self._collective_key_base)
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=local_devices,
group_size=len(local_devices) * self._num_workers,
options=self._communication_options,
collective_keys=self._collective_keys)
# CrossDeviceOps for per-host tensors.
self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=[self._worker_device],
group_size=self._num_workers,
options=self._communication_options,
collective_keys=self._collective_keys)
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
local_devices)
# Add a default device so that ops without specified devices will not end up
# on other workers.
self._default_device = "/job:%s/task:%d" % (task_type, task_id)
# Save the num_devices_per_worker and rpc_layer for configure method.
self._num_devices_per_worker = len(local_devices)
self._local_device_type = local_device_type
self._rpc_layer = cluster_resolver.rpc_layer
self._warn_nccl_no_gpu()
if self._enable_check_health and context.executing_eagerly():
self._start_check_health_thread()
else:
logging.info("Check health not enabled.")
logging.info(
"MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
"task_id = %r, num_workers = %r, local_devices = %r, "
"communication = %s", cluster_spec.as_dict(), task_type, task_id,
self._num_workers, local_devices,
self._communication_options.implementation)
def __del__(self):
self._stop_check_health_thread()
def _input_workers_with_options(self, options=None):
host_device = device_util.get_host_for_device(self._worker_device)
if not options or options.experimental_fetch_to_device:
return input_lib.InputWorkers([(host_device, self.worker_devices)])
else:
return input_lib.InputWorkers([(
host_device,
[device_util.get_host_for_device(worker) for worker in
self.worker_devices])])
@property
def _input_workers(self):
return self._input_workers_with_options()
def _get_variable_creator_initial_value(self,
replica_id,
device,
primary_var,
**kwargs):
if replica_id == 0: # First replica on each worker.
assert device is not None
assert primary_var is None
def initial_value_fn(): # pylint: disable=g-missing-docstring
# Only the first device participates in the broadcast of initial values.
group_key = self._collective_keys.get_group_key([device])
group_size = self._num_workers
collective_instance_key = (
self._collective_keys.get_instance_key(group_key, device))
with ops.device(device):
initial_value = kwargs["initial_value"]
if callable(initial_value):
initial_value = initial_value()
if isinstance(initial_value, base.CheckpointInitialValue):
initial_value = initial_value.wrapped_value
assert not callable(initial_value)
initial_value = ops.convert_to_tensor(
initial_value, dtype=kwargs.get("dtype", None))
if self._num_workers > 1:
if self._is_chief:
bcast_send = collective_ops.broadcast_send(
initial_value, initial_value.shape, initial_value.dtype,
group_size, group_key, collective_instance_key)
with ops.control_dependencies([bcast_send]):
return array_ops.identity(initial_value)
else:
return collective_ops.broadcast_recv(initial_value.shape,
initial_value.dtype,
group_size, group_key,
collective_instance_key)
return initial_value
return initial_value_fn
else:
return super(CollectiveAllReduceExtended,
self)._get_variable_creator_initial_value(
replica_id=replica_id,
device=device,
primary_var=primary_var,
**kwargs)
def _make_input_context(self):
input_context = distribute_lib.InputContext(
num_input_pipelines=self._num_workers,
input_pipeline_id=self._id_in_cluster,
num_replicas_in_sync=self._num_replicas_in_sync)
return input_context
def _experimental_distribute_dataset(self, dataset, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`distribute_datasets_from_function` "
"of tf.distribute.MirroredStrategy"
)
input_context = self._make_input_context()
return input_util.get_distributed_dataset(
dataset,
self._input_workers_with_options(options),
self._container_strategy(),
num_replicas_in_sync=self._num_replicas_in_sync,
input_context=input_context,
options=options)
def _distribute_datasets_from_function(self, dataset_fn, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`distribute_datasets_from_function` "
"of tf.distribute.MirroredStrategy")
input_context = self._make_input_context()
return input_util.get_distributed_datasets_from_function(
dataset_fn=dataset_fn,
input_workers=self._input_workers_with_options(options),
input_contexts=[input_context],
strategy=self._container_strategy(),
options=options)
def _experimental_distribute_values_from_function(self, value_fn):
per_replica_values = []
num_local_replicas = len(self.worker_devices)
for local_replica_id in range(num_local_replicas):
replica_id = (self._id_in_cluster * num_local_replicas +
local_replica_id)
value_context = distribute_lib.ValueContext(
replica_id, self._num_replicas_in_sync)
per_replica_values.append(value_fn(value_context))
return distribute_utils.regroup(per_replica_values, always_wrap=True)
def _make_dataset_iterator(self, dataset):
"""Distributes the dataset to each local GPU."""
input_context = self._make_input_context()
return input_lib_v1.DatasetIterator(
dataset,
self._input_workers,
self._container_strategy(),
num_replicas_in_sync=self._num_replicas_in_sync,
input_context=input_context)
def _make_input_fn_iterator(
self,
input_fn,
replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
"""Distributes the input function to each local GPU."""
input_context = self._make_input_context()
return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
[input_context],
self._container_strategy())
def _configure(self,
session_config=None,
cluster_spec=None,
task_type=None,
task_id=None):
"""Configures the object.
Args:
session_config: a `tf.compat.v1.ConfigProto`
cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
cluster configurations.
task_type: the current task type, such as "worker".
task_id: the current task id.
Raises:
ValueError: if `task_type` is not in the `cluster_spec`.
"""
if cluster_spec:
cluster_resolver = cluster_resolver_lib.SimpleClusterResolver(
cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
task_type=task_type,
task_id=task_id,
num_accelerators={
self._local_device_type: self._num_devices_per_worker},
rpc_layer=self._rpc_layer)
self._initialize_multi_worker(cluster_resolver)
assert isinstance(self._cross_device_ops,
cross_device_ops_lib.CollectiveAllReduce)
if session_config:
session_config.CopyFrom(self._update_config_proto(session_config))
def _update_config_proto(self, config_proto):
updated_config = copy.deepcopy(config_proto)
# Enable the scoped allocator optimization for CollectiveOps. This
# optimization converts many small all-reduces into fewer larger
# all-reduces.
rewrite_options = updated_config.graph_options.rewrite_options
rewrite_options.scoped_allocator_optimization = (
rewriter_config_pb2.RewriterConfig.ON)
# We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op =
# ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we
# clear and then append.
del rewrite_options.scoped_allocator_opts.enable_op[:]
rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")
if (not ops.executing_eagerly_outside_functions() and
self._communication_options.implementation ==
collective_util.CommunicationImplementation.NCCL):
updated_config.experimental.collective_nccl = True
if not self._cluster_spec:
return updated_config
assert self._task_type
assert self._task_id is not None
# Collective group leader is needed for collective ops to coordinate
# workers.
updated_config.experimental.collective_group_leader = (
multi_worker_util.collective_leader(self._cluster_spec, self._task_type,
self._task_id))
# The device filters prevent communication between workers.
del updated_config.device_filters[:]
updated_config.device_filters.append(
"/job:%s/task:%d" % (self._task_type, self._task_id))
return updated_config
def _get_cross_device_ops(self, value):
# CollectiveAllReduce works on a predefined set of devices. In most cases
# they should be the compute devices, but certain use cases may reduce host
# tensors as well (e.g. early stopping). We infer the cross_device_ops to
# use based on the number of devices, since inputs don't always have device
annotations. The compute-device ops are preferred since they can potentially
# leverage NCCL.
if isinstance(value, values.DistributedValues):
num_devices = len(value._values) # pylint: disable=protected-access
else:
num_devices = 1
if num_devices == len(self.worker_devices):
return self._cross_device_ops
else:
return self._host_cross_device_ops
def _gather_to_implementation(self, value, destinations, axis, options):
return self._get_cross_device_ops(value)._gather( # pylint: disable=protected-access
value,
destinations=destinations,
axis=axis,
options=options)
def _reduce_to(self, reduce_op, value, destinations, options):
if (isinstance(value, values.Mirrored) and
reduce_op == reduce_util.ReduceOp.MEAN):
return value
assert not isinstance(value, values.Mirrored)
if (isinstance(value, values.DistributedValues) and
len(self.worker_devices) == 1):
value = value.values[0]
# When there are multiple workers, we need to reduce across workers using
# collective ops.
if (not isinstance(value, values.DistributedValues) and
self._num_workers == 1):
# This function handles reducing values that are not PerReplica or
# Mirrored values. For example, the same value could be present on all
# replicas, in which case `value` would be a single value, or `value` could
# be 0.
return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, value, destinations, len(self.worker_devices))
return self._get_cross_device_ops(value).reduce(
reduce_op,
value,
destinations=destinations,
options=self._communication_options.merge(options))
def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
"""Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
# This implementation avoids using `merge_call` and just launches collective
# ops in one replica.
if options is None:
options = collective_util.Options()
if context.executing_eagerly():
# In eager mode, fall back to the default implementation that uses
# `merge_call`. Replica functions run sequentially in eager mode, and
# because collective ops block, execution would hang if collective ops were
# launched sequentially.
return super()._replica_ctx_all_reduce(reduce_op, value, options)
replica_context = distribute_lib.get_replica_context()
assert replica_context, (
"`StrategyExtended._replica_ctx_all_reduce` must be called in a "
"replica context")
return self._cross_device_ops._all_reduce( # pylint: disable=protected-access
reduce_op,
value,
replica_context._replica_id, # pylint: disable=protected-access
options)
def _check_health(self):
while True:
if self._check_health_thread_should_stop.is_set():
return
for job in self._cluster_spec.jobs:
for task_id in range(self._cluster_spec.num_tasks(job)):
peer = "/job:{}/replica:0/task:{}".format(job, task_id)
attempts = 0
while True:
attempts += 1
try:
context.context().check_collective_ops_peer_health(
peer, timeout_in_ms=self._check_health_timeout * 1000)
# If check_collective_ops_peer_health doesn't raise an Exception,
# the peer is healthy.
break
except (errors.UnavailableError, errors.FailedPreconditionError,
errors.DeadlineExceededError) as e:
# TODO(b/151232436): Always raise UnavailableError when a peer
# fails. Now there could be many kinds of errors:
# - Unavailable: when the peer is not reachable, e.g. it's down.
# - FailedPrecondition: when the peer has restarted.
if attempts < self._check_health_retry_limit:
logging.warning("%s seems down, retrying %d/%d", peer, attempts,
self._check_health_retry_limit)
continue
logging.error(
"Cluster check alive failed, %s is down, "
"aborting collectives: %s", peer, e)
context.context().abort_collective_ops(
errors.UNAVAILABLE,
"cluster check alive failed, {} is down".format(peer))
return
except Exception as e: # pylint: disable=broad-except
logging.error("Unexpected exception in check alive: %s", e)
context.context().abort_collective_ops(
errors.INTERNAL,
"unexecpted exception in check alive: %s" % e)
return
time.sleep(self._check_health_interval)
def _start_check_health_thread(self):
# Use a dummy all-reduce as a barrier to wait for all workers to be up,
# otherwise the health check may fail immediately.
# Use array_ops.identity to create the dummy tensor so that we have a new
# Tensor. If we used a constant it might be cached on a /job:localhost
# device, which would cause code that relies on tensor.device to error.
#
# TODO(b/151232436): change to an explicit barrier if we have it.
dummy_value = array_ops.identity([])
logging.info("Waiting for the cluster, timeout = %s",
self._check_health_initial_timeout or "inf")
try:
self._host_cross_device_ops.reduce(
reduce_util.ReduceOp.SUM,
dummy_value,
dummy_value,
options=collective_util.Options(
timeout_seconds=self._check_health_initial_timeout,
implementation=collective_util.CommunicationImplementation.RING))
if context.is_async():
context.async_wait()
except errors.DeadlineExceededError:
raise RuntimeError(
"Timeout waiting for the cluster, timeout is %d seconds" %
self._check_health_initial_timeout)
logging.info("Cluster is ready.")
self._check_health_thread_should_stop = threading.Event()
# Start the thread as a daemon to avoid it blocking the program from exiting.
# We try our best to shut down the thread, but __del__ is not guaranteed to be
# called when the program exits.
self._check_health_thread = threading.Thread(
target=self._check_health,
daemon=True)
self._check_health_thread.start()
def _stop_check_health_thread(self):
if getattr(self, "_check_health_thread", None):
logging.info("stopping check health thread")
self._check_health_thread_should_stop.set()
self._check_health_thread.join()
self._check_health_thread = None
logging.info("check health thread stopped")
def _warn_nccl_no_gpu(self):
if ((self._communication_options.implementation ==
collective_util.CommunicationImplementation.NCCL) and
self._local_device_type != "GPU"):
logging.warning("Enabled NCCL communication but no GPUs detected/"
"specified.")
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings."""
return self._num_workers > 1
@property
def experimental_between_graph(self):
return True
@property
def experimental_should_init(self):
return True
@property
def should_checkpoint(self):
return self._is_chief
@property
def should_save_summary(self):
return self._is_chief
@property
def _num_replicas_in_sync(self):
return len(self.worker_devices) * self._num_workers
# TODO(priyag): Delete this once all strategies use global batch size.
@property
def _global_batch_size(self):
"""`make_dataset_iterator` and `make_numpy_iterator` use global batch size.
`make_input_fn_iterator` assumes per-replica batching.
Returns:
Boolean.
"""
return True
def _get_replica_id_in_sync_group(self, replica_id):
return self._id_in_cluster * len(self.worker_devices) + replica_id
def _get_local_replica_id(self, replica_id_in_sync_group):
return (replica_id_in_sync_group -
self._id_in_cluster * len(self.worker_devices))
def __deepcopy__(self, memo):
# We check for the check-health thread instead of whether we are in eager
# mode to limit backward incompatibility.
if hasattr(self, "_check_health_thread"):
raise ValueError(
"MultiWorkerMirroredStrategy cannot be deep copied in eager mode. "
"If you're using Estimator and see this error message, call "
"tf.compat.v1.disable_eager_execution() at the beginning of your "
"program")
# Otherwise, do a regular deepcopy.
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
return result