blob: b203aec3bc7f78ee8fd46c7aacf0aa905103fba1 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for cross_device_ops."""
import copy
import threading
from typing import Callable, List, Optional, Union
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import core
INSTANCE_KEY_START_NUMBER = 100
def aggregate_gradients_using_nccl(replica_grads):
"""Aggregate gradients using nccl allreduce."""
agg_all_g_and_v = []
for single_g_and_v in zip(*replica_grads):
single_grads = [g for g, _ in single_g_and_v]
agg_grads = nccl_ops.all_sum(single_grads)
agg_all_g_and_v.append(
[(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)])
agg_all_g_and_v = list(zip(*agg_all_g_and_v))
return agg_all_g_and_v
def aggregate_gradients_using_hierarchical_copy(avail_devices, replica_grads):
"""Aggregate gradients using hierarchical copies.
Args:
avail_devices: available GPU devices.
replica_grads: List of lists of (gradient, variable) tuples. The outer list
is over replicas. The inner list is over individual gradients.
Returns:
The list of (aggregated_gradient, variable), where the gradient has been
summed across all replicas and the variable is chosen from the first
replica.
"""
# This only works for DGX-1 type of machine topology
# Device peer to peer matrix
# DMA: 0 1 2 3 4 5 6 7
# 0: Y Y Y Y Y N N N
# 1: Y Y Y Y N Y N N
# 2: Y Y Y Y N N Y N
# 3: Y Y Y Y N N N Y
# 4: Y N N N Y Y Y Y
# 5: N Y N N Y Y Y Y
# 6: N N Y N Y Y Y Y
# 7: N N N Y Y Y Y Y
agg_grads = []
num_devices = len(avail_devices)
# In the special case of DGX-1 machine topology, the two groups have equal
# size.
group_size = num_devices // 2
for i, single_grads in enumerate(zip(*replica_grads)):
group_0_main_device = i % num_devices
group_1_main_device = (group_0_main_device + group_size) % num_devices
if group_0_main_device < group_size:
group_0_begin = 0
group_1_begin = group_size
else:
group_0_begin = group_size
group_1_begin = 0
# Aggregate the first group.
group_0_device_grads = single_grads[group_0_begin:
group_0_begin + group_size]
with ops.device(avail_devices[group_0_main_device]):
group_0_agg_grads, _ = aggregate_single_gradient_using_copy(
group_0_device_grads, False, False)
# Aggregate the second group.
group_1_device_grads = single_grads[group_1_begin:
group_1_begin + group_size]
with ops.device(avail_devices[group_1_main_device]):
group_1_agg_grads, _ = aggregate_single_gradient_using_copy(
group_1_device_grads, False, False)
# Aggregate between the groups.
with ops.device(avail_devices[group_0_main_device]):
(agg_total_grads, _), _ = aggregate_single_gradient_using_copy(
[group_0_agg_grads, group_1_agg_grads], False, False)
# Broadcast the result back into the root of each group.
with ops.device(avail_devices[group_0_main_device]):
group_0_agg_grads_bcast = array_ops.identity(agg_total_grads)
with ops.device(avail_devices[group_1_main_device]):
group_1_agg_grads_bcast = array_ops.identity(agg_total_grads)
agg_grads_bcast = []
for j in range(len(single_grads)):
with ops.device(avail_devices[j]):
# Broadcast the result back to each member in the group from the root.
if (group_0_main_device < group_size) == (j < group_size):
src_device_grad = group_0_agg_grads_bcast
else:
src_device_grad = group_1_agg_grads_bcast
agg_grads_bcast.append(array_ops.identity(src_device_grad))
agg_grads.append(
[(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)])
agg_grads = list(zip(*agg_grads))
return agg_grads
def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
check_inf_nan):
"""Calculate the average gradient for a shared variable across all replicas.
Note that this function provides a synchronization point across all replicas.
Args:
grad_and_vars: A list or tuple of (gradient, variable) tuples. Each
(gradient, variable) pair within the outer list represents the gradient
of the variable calculated for a single replica, and the number of pairs
equals the number of replicas.
use_mean: if True, mean is taken, else sum of gradients is taken.
check_inf_nan: check grads for nans and infs.
Returns:
The tuple ([(average_gradient, variable),], has_nan_or_inf) where the
gradient has been averaged across all replicas. The variable is chosen
from the first replica. The has_nan_or_inf indicates the grads has nan or
inf.
"""
grads = [g for g, _ in grad_and_vars]
grad = math_ops.add_n(grads)
if use_mean and len(grads) > 1:
grad = array_ops.multiply(grad, 1.0 / len(grads))
v = grad_and_vars[0][1]
if check_inf_nan:
has_nan_or_inf = array_ops.logical_not(
array_ops.reduce_all(array_ops.is_finite(grads)))
return (grad, v), has_nan_or_inf
else:
return (grad, v), None
# TODO(yuefengz): use random key starts to avoid reusing keys?
class CollectiveKeys(object):
"""Class that manages collective keys.
We need to manage three different keys for collective:
*Group key*: an integer key to identify the set of cooperative devices.
Collective ops work under the same set of devices must using the same group
key.
*Instance key*: an integer key to identify the set of same counterpart of
tensors on different devices in a device group that need to be all-reduced.
This class is thread safe.
"""
def __init__(self, group_key_start=1):
"""Initializes the object.
Args:
group_key_start: the starting integer of group key.
"""
self._group_key = group_key_start
self._instance_key_table = {}
self._lock = threading.Lock()
self._known_groups = {}
def get_group_key(self, devices):
"""Returns a group key for the list of local devices.
The same group key is returned if the list of local devices is the same.
Args:
devices: a list of local canonical device strings in a collective group.
Returns:
a group key.
"""
with self._lock:
devices_key = ','.join(devices)
if devices_key not in self._known_groups:
self._known_groups[devices_key] = self._get_new_group_key(devices)
return self._known_groups[devices_key]
def _get_new_group_key(self, devices):
"""Returns a new group key.
The caller should store and reuse the same group key for the same set of
devices. Calling this method always returns a new group key.
This method is not thread-safe.
Args:
devices: a list of canonical device strings in a collective group.
Returns:
a new group key.
"""
new_key = self._group_key
self._group_key += 1
self._instance_key_table[new_key] = {}
for device in devices:
self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER
return new_key
def get_instance_key(self, group_key, device):
"""Returns a new instance key for use in defining a collective op.
You should call this once per each collective op of a collective instance.
Args:
group_key: the group key returned by get_group_key(). You should not
assign the group key yourself.
device: a canonical device string. It should be the device this collective
op is on.
Returns:
a new instance key.
Raises:
ValueError: when the group key is invalid or the device is not in the
group.
"""
with self._lock:
group = self._instance_key_table.get(group_key, None)
if group is None:
raise ValueError(f'Group {group_key} is not found.')
if device not in group:
raise ValueError(f'Device {device} is not present in group {group_key}')
v = group[device]
group[device] += 1
return v
def __deepcopy__(self, memo):
# distribute_coordinator deep-copies the strategy object, so
# CollectiveKeys needs to support deep copy as well.
copied = CollectiveKeys()
copied._group_key = self._group_key
copied._instance_key_table = copy.deepcopy(self._instance_key_table, memo)
return copied
class CollectiveReplicaLauncher(object):
"""Launch collectives on one replica."""
_prefer_unique_instance_key = True
_prefer_ordering_token = True
def __init__(self, group_key: int, group_size: int,
collective_keys: CollectiveKeys, device: str,
options: collective_util.Options):
self._group_key = group_key
self._group_size = group_size
self._collective_keys = collective_keys
self._device = device
self._options = options
if self._use_ordering_token():
with ops.init_scope(), ops.device(device):
self._ordering_token = resource_variable_ops.ResourceVariable(0.)
else:
self._ordering_token = None
def _control_input(self, control_input: Union[core.TensorLike,
ops.Operation]):
if control_input is not None and not self._use_ordering_token():
return ops.control_dependencies([control_input])
return ops.NullContextmanager()
def _use_unique_instance_key(self):
if not ops.executing_eagerly_outside_functions():
return False
return CollectiveReplicaLauncher._prefer_unique_instance_key
def _use_ordering_token(self):
# We rely on auto control dep to insert control edges between NCCL calls,
# but for tf1 graph mode auto control dep is not used.
if not ops.executing_eagerly_outside_functions():
return False
return CollectiveReplicaLauncher._prefer_ordering_token
def _next_instance_key(self):
"""Returns the next instance key."""
if self._use_unique_instance_key():
# Assigning instance keys at function building time have issues since
# different workers may retrace the function at different times. With
# collective V2 we can use capture_call_time_value to use a placeholder as
# the instance key and feed it at function call time. In this way we also
# don't reuse instance keys, which allows for per-instance cancellation.
graph = ops.get_default_graph()
# Control flow ops don't work with capture_call_time_value, so we put the
# capture in the function graph of that control flow op.
while getattr(graph, 'is_control_flow_graph', False):
graph = graph.outer_graph
if not context.executing_eagerly() and graph.building_function:
with graph.as_default():
# Capture self._next_instance_key so that when building a function
# that calls another tf.function, the instance key assignment is
# further delayed until we actually call the function in eager. Note
# that capture_call_time_value doesn't automatically propagate the
# deferred capture to the outer function.
return graph.capture_call_time_value(
self._next_instance_key, tensor_spec.TensorSpec([], dtypes.int32))
else:
instance_key = self._collective_keys.get_instance_key(
self._group_key, self._device)
with ops.device('CPU:0'):
return ops.convert_to_tensor(instance_key, dtype=dtypes.int32)
else:
return self._collective_keys.get_instance_key(self._group_key,
self._device)
def _get_ordering_token(self):
if self._use_ordering_token():
return self._ordering_token.handle # pytype: disable=attribute-error
def can_order_nccl(self):
"""Whether this launcher can order NCCL operations."""
return self._use_ordering_token()
def all_reduce(
self,
input_tensor: core.TensorLike,
control_input: Optional[Union[core.TensorLike, ops.Operation]] = None,
options: Optional[collective_util.Options] = None) -> core.Tensor:
"""All-reduce a dense tensor.
Args:
input_tensor: a dense tensor. It must have the same shape on all replicas.
control_input: if not None, add control edges between control_input and
the all-reduce.
options: an optional tf.distribute.experimental.CommunicationOptions. If
provided, it overrides the default options.
Returns:
The reduced tensor.
"""
instance_key = self._next_instance_key()
options = self._options.merge(options)
ordering_token = self._get_ordering_token()
with ops.device(self._device), \
self._control_input(control_input):
return collective_ops.all_reduce_v2(
input_tensor,
self._group_size,
self._group_key,
instance_key,
communication_hint=options.implementation.value,
timeout=options.timeout_seconds,
ordering_token=ordering_token)
def _all_gather(self, input_tensor: core.TensorLike,
options: Optional[collective_util.Options]) -> core.Tensor:
"""All-gather a dense tensor.
Args:
input_tensor: a dense tensor. It must have the same shape on all replicas.
options: an optional tf.distribute.experimental.CommunicationOptions. If
provided, it overrides the default options.
Returns:
The reduced tensor.
"""
instance_key = self._next_instance_key()
options = self._options.merge(options)
ordering_token = self._get_ordering_token()
with ops.device(self._device):
return collective_ops.all_gather_v2(
input_tensor,
self._group_size,
self._group_key,
instance_key,
communication_hint=options.implementation.value,
timeout=options.timeout_seconds,
ordering_token=ordering_token)
def batch_all_reduce(
self,
input_tensor_packs: List[List[core.TensorLike]],
options: Optional[collective_util.Options] = None) -> core.Tensor:
"""Batch all-reduce dense tensors.
This takes a list of batches of tensors. Using multiple batches have the
benefit that it doesn't need to wait for all inputs to be ready to start the
all-reduce.
Args:
input_tensor_packs: a list of lists of dense tensors.
options: an optional tf.distribute.experimental.CommunicationOptions. If
provided, it overrides the default options.
Returns:
A flat list of reduced tensors.
"""
options = self._options.merge(options)
outputs = []
for pack in input_tensor_packs:
if context.executing_eagerly():
# We don't batch in eager as it sometimes makes the performance worse
# due the concat/split ops.
for input_tensor in pack:
outputs.append(self.all_reduce(input_tensor, None, options))
else:
# TODO(b/169168846): inserts a parallel all_gather to verify packings
# are the same on each replica.
with ops.device(self._device):
flat_tensors = [array_ops.reshape(t, [-1]) for t in pack]
shapes = [array_ops.shape(t) for t in pack]
if (options.implementation
== collective_util.CommunicationImplementation.NCCL and outputs):
control_input = outputs[-1]
else:
control_input = None
reduced = self.all_reduce(
array_ops.concat(flat_tensors, axis=0), control_input, options)
num_elements = [math_ops.reduce_prod(s) for s in shapes]
flat_outputs = array_ops.split(reduced, num_elements, axis=0)
for shape, flat_output in zip(shapes, flat_outputs):
outputs.append(array_ops.reshape(flat_output, shape))
return outputs
def all_gather(
self,
input_tensor: core.TensorLike,
axis: core.TensorLike,
options: Optional[collective_util.Options] = None) -> core.Tensor:
"""All-gather a dense tensor.
This method must be called inside a tf.function.
Args:
input_tensor: a dense tensor. It must have the same rank on all replicas,
and dimensions other than `axis` need to be the same as well.
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
range [0, rank(value)).
options: an optional tf.distribute.experimental.CommunicationOptions. If
provided, it overrides the default options.
Returns:
The gathered Tensor.
Raises:
RuntimeError: if called in eager mode.
"""
if context.executing_eagerly():
raise RuntimeError('all_gather is not supported in eager mode.')
with ops.device(self._device), \
ops.control_dependencies([array_ops.identity(input_tensor)]):
# 1. Transpose
# E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
# we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
# brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to
# place it back.
perm_pre = array_ops.concat(
([axis], math_ops.range(axis),
math_ops.range(axis + 1, array_ops.rank(input_tensor))),
axis=0)
input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
# 2. Pad
gathered_shape = self._all_gather(
array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
options)
first_dims = gathered_shape[:, 0]
full_axis_dim = math_ops.reduce_max(first_dims)
padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)
# 3. Gather
gather_padded_out_tensor = self._all_gather(padded_input_tensor, options)
# 4. Unpad
split_tensors = []
for i in range(self._group_size):
start_pos = i * full_axis_dim
split_tensors.append(gather_padded_out_tensor[start_pos:start_pos +
first_dims[i]])
out_tensor_t = array_ops.concat(split_tensors, 0)
# 5. Transpose back
perm_after = array_ops.concat(
(math_ops.range(1, axis + 1), [0],
math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
axis=0)
return array_ops.transpose(out_tensor_t, perm=perm_after)
def all_reduce_indexed_slices(
self,
input_slices: indexed_slices.IndexedSlices,
options: Optional[collective_util.Options] = None
) -> indexed_slices.IndexedSlices:
"""All-reduce an IndexedSlices.
This method can be called outside tf.function.
Args:
input_slices: an IndexedSlices.
options: an optional tf.distribute.experimental.CommunicationOptions. If
provided, it overrides the default options.
Returns:
The reduced IndexedSlices.
"""
# Current CollectiveAllGather implementations require input IndexedSlices to
# have consistent length across the board, we handle the reduction of
# IndexedSlices as follows:
# 1. Gather the lengths of IndexedSlices from all participants.
# 2. If they have consistent length, apply all_gather.
# 3. Otherwise pad IndexedSlices to be the same length across all
# participants and apply_gather.
options = self._options.merge(options)
with ops.device(self._device):
def all_gather_indexed_slices(
all_gather_fn: Callable[
[core.TensorLike, Optional[collective_util.Options]], core.Tensor]
) -> indexed_slices.IndexedSlices:
"""Use all_gather_fn to aggregate `IndexedSlices`."""
all_values = all_gather_fn(input_slices.values, options)
# Add control dependency to order the all-gather.
if (options.implementation ==
collective_util.CommunicationImplementation.NCCL):
control = [all_values]
else:
control = []
with ops.control_dependencies(control):
all_indices = all_gather_fn(input_slices.indices, options)
return indexed_slices.IndexedSlices(
values=all_values,
indices=all_indices,
dense_shape=input_slices.dense_shape)
length = array_ops.shape(input_slices.indices)
all_lengths = self._all_gather(length, options)
def all_gather_with_padding(
input_tensor: core.TensorLike,
options: Optional[collective_util.Options]) -> core.Tensor:
"""all_gather tensors of different sizes using padding."""
max_length = math_ops.reduce_max(all_lengths)
padded_tensor = _pad_util(input_tensor, max_length)
all_padded_tensors = self._all_gather(padded_tensor, options)
split_tensors = []
for i in range(self._group_size):
start_pos = i * max_length
split_tensors.append(all_padded_tensors[start_pos:start_pos +
all_lengths[i]])
return array_ops.concat(split_tensors, 0)
return cond.cond(
math_ops.equal(
math_ops.reduce_max(all_lengths),
math_ops.reduce_min(all_lengths)),
lambda: all_gather_indexed_slices(self._all_gather),
lambda: all_gather_indexed_slices(all_gather_with_padding))
def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
"""Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
if any(isinstance(v, indexed_slices.IndexedSlices) for v in values):
return backprop_util.AggregateIndexedSlicesGradients(values)
else:
return accumulation_fn(values)
def divide_by_n_tensors_or_indexed_slices(value, n):
if isinstance(value, indexed_slices.IndexedSlices):
value = backprop_util.FlattenNestedIndexedSlices(value)
return indexed_slices.IndexedSlices(value.values / n, value.indices,
value.dense_shape)
else:
return value / n
def copy_tensor_or_indexed_slices_to_device(value, device):
"""Copies a tensor or IndexedSlices to a device."""
with ops.device(device):
if isinstance(value, indexed_slices.IndexedSlices):
copied_values = array_ops.identity(value.values)
copied_indices = array_ops.identity(value.indices)
if value.dense_shape is not None:
copied_shape = array_ops.identity(value.dense_shape)
else:
copied_shape = None
result = indexed_slices.IndexedSlices(copied_values, copied_indices,
copied_shape)
else:
result = array_ops.identity(value)
return result
def is_indexed_slices(value):
if isinstance(value, indexed_slices.IndexedSlices):
return True
if isinstance(value, value_lib.DistributedValues):
return all(
isinstance(v, indexed_slices.IndexedSlices) for v in value.values)
return False
def split_by_sparsity(values):
"""Split values into dense and sparse values.
Args:
values: a list of tensors or `PerReplica`s.
Returns:
Four lists:
a list of dense values, a list of their indices in `values` and
a list of sparse values, a list of their indices in `values`.
"""
dense_values = []
dense_indices = []
sparse_values = []
sparse_indices = []
for i, v in enumerate(values):
if is_indexed_slices(v):
sparse_values.append(v)
sparse_indices.append(i)
else:
dense_values.append(v)
dense_indices.append(i)
return dense_values, dense_indices, sparse_values, sparse_indices
def stitch_values(values_and_indices_list):
"""Stitch values together according to their indices.
Args:
values_and_indices_list: a list of tuples of values and indices indicating
the values and positions in the returned list.
Returns:
a stitched list of values.
"""
length = 0
for values_and_indices in values_and_indices_list:
length += len(values_and_indices[0])
result = [None] * length
for values_and_indices in values_and_indices_list:
if values_and_indices and values_and_indices[0]:
for v, i in zip(*values_and_indices):
assert result[i] is None
result[i] = v
return result
def group_by_size(input_tensors, bytes_per_pack):
"""Groups `input_tensors` into chunks of `bytes_per_pack`.
The method preserves the original order of `input_tensors`. The grouping is
best effort, each pack could have more or less bytes than `bytes_per_pack`.
It only groups values with known shape.
Args:
input_tensors: a list of Tensor.
bytes_per_pack: an integer.
Returns:
A list of packs of Tensor. All values are grouped into one pack if
`bytes_per_pack` is zero or any of the value has unknown shape.
"""
if bytes_per_pack == 0:
return [input_tensors]
packs = []
last_pack_size = 0
for value in input_tensors:
num_elements = value.shape.num_elements()
if num_elements is None:
# Can't pack values with unknown shape.
logging.warning(
'not packing values due to the unknown or inconsistent shape of %s',
value)
return [input_tensors]
size = num_elements * value.dtype.size
# Try to keep each pack as close to bytes_per_pack as possible, while each
# pack is at least bytes_per_pack large. I.E. we err on the side of having
# few but large packs.
if not packs or last_pack_size > bytes_per_pack:
packs.append([])
last_pack_size = 0
packs[-1].append(value)
last_pack_size += size
return packs
def _pad_util(input_tensor, full_axis_dim):
"""Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
tensor_rank = array_ops.rank(input_tensor)
paddings_axis = [[0, missing_axis_dim]]
paddings = array_ops.concat([
paddings_axis,
array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
],
axis=0)
padded_input_tensor = array_ops.pad(input_tensor, paddings)
return padded_input_tensor