# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU Strategy."""
import atexit
import collections
import contextlib
import copy
import functools
import weakref
from absl import logging
import numpy as np
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_replicated_variable
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as tpu_cluster_resolver_lib
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import device_spec
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_hardware_feature
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
_XLA_OP_BY_OP_INPUTS_LIMIT = 200
@contextlib.contextmanager
def maybe_init_scope():
if ops.executing_eagerly_outside_functions():
yield
else:
with ops.init_scope():
yield
def validate_run_function(fn):
"""Validate the function passed into strategy.run."""
# We allow three types of functions/objects passed into TPUStrategy
# run in eager mode:
# 1. a user annotated tf.function
# 2. a ConcreteFunction, this is mostly what you get from loading a saved
# model.
# 3. a callable object and the `__call__` method itself is a tf.function.
#
# Otherwise we return an error, because we don't support eagerly running
# run in TPUStrategy.
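  # A minimal sketch of inputs that pass this check in eager mode (`strategy`,
  # `step_fn` and `inputs` are hypothetical):
  #
  #   @tf.function
  #   def step_fn(x):
  #     return x * 2
  #   strategy.run(step_fn, args=(inputs,))      # case 1: tf.function
  #
  #   concrete = step_fn.get_concrete_function(
  #       tf.TensorSpec([None], tf.float32))
  #   strategy.run(concrete, args=(inputs,))     # case 2: ConcreteFunction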
if (context.executing_eagerly()
and not isinstance(fn, def_function.Function)
and not isinstance(fn, function.ConcreteFunction)
and not (
callable(fn) and isinstance(fn.__call__, def_function.Function))
):
raise NotImplementedError(
"TPUStrategy.run(fn, ...) does not support pure eager "
"execution. please make sure the function passed into "
"`strategy.run` is a `tf.function` or "
"`strategy.run` is called inside a `tf.function` if "
"eager behavior is enabled.")
def _maybe_partial_apply_variables(fn, args, kwargs):
"""Inspects arguments to partially apply any DistributedVariable.
This avoids an automatic cast of the current variable value to tensor.
Note that a variable may be captured implicitly with Python scope instead of
passing it to run(), but supporting run() keeps behavior consistent
with MirroredStrategy.
Since positional arguments must be applied from left to right, this function
does some tricky function inspection to move variable positional arguments
into kwargs. As a result of this, we can't support passing Variables as *args,
nor as args to functions which combine both explicit positional arguments and
*args.
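  For example (an illustrative sketch; `fn`, `v` and `x` are hypothetical),
  given

      def fn(v, x):
        return v.assign_add(x)

  a call `run(fn, args=(distributed_variable, some_tensor))` moves the
  variable into keyword arguments, returning roughly
  `(functools.partial(fn, v=distributed_variable), [], {"x": some_tensor})`.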
Args:
fn: The function to run, as passed to run().
args: Positional arguments to fn, as passed to run().
kwargs: Keyword arguments to fn, as passed to run().
Returns:
A tuple of the function (possibly wrapped), args, kwargs (both
possibly filtered, with members of args possibly moved to kwargs).
If no variables are found, this function is a noop.
Raises:
ValueError: If the function signature makes unsupported use of *args, or if
too many arguments are passed.
"""
def is_distributed_var(x):
flat = nest.flatten(x)
return flat and isinstance(flat[0], values.DistributedVariable)
# We will split kwargs into two dicts, one of which will be applied now.
var_kwargs = {}
nonvar_kwargs = {}
if kwargs:
var_kwargs = {k: v for k, v in kwargs.items() if is_distributed_var(v)}
if var_kwargs:
nonvar_kwargs = {
k: v for k, v in kwargs.items() if not is_distributed_var(v)
}
# Dump the argument names of `fn` to a list. This will include both positional
# and keyword arguments, but since positional arguments come first we can
# look up names of positional arguments by index.
positional_args = []
index_of_star_args = None
for i, p in enumerate(tf_inspect.signature(fn).parameters.values()):
# Class methods define "self" as first argument, but we don't pass "self".
# Note that this is a heuristic, as a method can name its first argument
# something else, and a function can define a first argument "self" as well.
# In both of these cases, using a Variable will fail with an unfortunate
# error about the number of arguments.
# inspect.is_method() seems not to work here, possibly due to the use of
# tf.function().
if i == 0 and p.name == "self":
continue
if p.kind == tf_inspect.Parameter.POSITIONAL_OR_KEYWORD:
positional_args.append(p.name)
elif p.kind == tf_inspect.Parameter.VAR_POSITIONAL:
# We'll raise an error later if a variable is passed to *args, since we
# can neither pass it by name nor partially apply it. This case only
# happens once at most.
index_of_star_args = i
elif p.kind == tf_inspect.Parameter.POSITIONAL_ONLY:
# This is a rare Python feature, indicating a / in the arg list.
if var_kwargs or any(is_distributed_var(a) for a in args):
raise ValueError(
"Mixing Variables and positional-only parameters not supported by "
f"TPUStrategy. Received {len(var_kwargs)} DistributedVariables in "
f"**kwargs and {sum(is_distributed_var(a) for a in args)} in *args,"
" expected zero for both."
)
return fn, args, kwargs
star_args = []
have_seen_var_arg = False
for i, a in enumerate(args):
if is_distributed_var(a):
if index_of_star_args is not None and i >= index_of_star_args:
raise ValueError(
"TPUStrategy.run() cannot handle Variables passed to *args. "
"Either name the function argument, or capture the Variable "
"implicitly.")
if len(positional_args) <= i:
raise ValueError(
"Too many positional arguments passed to call to TPUStrategy.run()."
)
var_kwargs[positional_args[i]] = a
have_seen_var_arg = True
else:
if index_of_star_args is not None and i >= index_of_star_args:
if have_seen_var_arg:
raise ValueError(
"TPUStrategy.run() cannot handle both Variables and a mix of "
"positional args and *args. Either remove the *args, or capture "
"the Variable implicitly.")
else:
star_args.append(a)
continue
if len(positional_args) <= i:
raise ValueError(
"Too many positional arguments passed to call to TPUStrategy.run()."
)
nonvar_kwargs[positional_args[i]] = a
if var_kwargs:
return functools.partial(fn, **var_kwargs), star_args, nonvar_kwargs
return fn, args, kwargs
@tf_export("distribute.TPUStrategy", v1=[])
class TPUStrategyV2(distribute_lib.Strategy):
"""Synchronous training on TPUs and TPU Pods.
To construct a TPUStrategy object, you need to run the
initialization code as below:
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
>>> tf.config.experimental_connect_to_cluster(resolver)
>>> tf.tpu.experimental.initialize_tpu_system(resolver)
>>> strategy = tf.distribute.TPUStrategy(resolver)
While using distribution strategies, the variables created within the
strategy's scope will be replicated across all the replicas and can be kept in
sync using all-reduce algorithms.
To run TF2 programs on TPUs, you can either use `.compile` and
`.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
training loop by calling `strategy.run` directly. Note that
TPUStrategy doesn't support pure eager execution, so please make sure the
function passed into `strategy.run` is a `tf.function` or
`strategy.run` is called inside a `tf.function` if eager
behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu.
`distribute_datasets_from_function` and
`experimental_distribute_dataset` APIs can be used to distribute the dataset
across the TPU workers when writing your own training loop. If you are using
`fit` and `compile` methods available in `tf.keras.Model`, then Keras will
handle the distribution for you.
An example of writing customized training loop on TPUs:
>>> with strategy.scope():
... model = tf.keras.Sequential([
... tf.keras.layers.Dense(2, input_shape=(5,)),
... ])
... optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
>>> def dataset_fn(ctx):
... x = np.random.random((2, 5)).astype(np.float32)
... y = np.random.randint(2, size=(2, 1))
... dataset = tf.data.Dataset.from_tensor_slices((x, y))
... return dataset.repeat().batch(1, drop_remainder=True)
>>> dist_dataset = strategy.distribute_datasets_from_function(
... dataset_fn)
>>> iterator = iter(dist_dataset)
>>> @tf.function()
... def train_step(iterator):
...
... def step_fn(inputs):
... features, labels = inputs
... with tf.GradientTape() as tape:
... logits = model(features, training=True)
... loss = tf.keras.losses.sparse_categorical_crossentropy(
... labels, logits)
...
... grads = tape.gradient(loss, model.trainable_variables)
... optimizer.apply_gradients(zip(grads, model.trainable_variables))
...
... strategy.run(step_fn, args=(next(iterator),))
>>> train_step(iterator)
  For advanced use cases like model parallelism, you can set the
  `experimental_device_assignment` argument when creating TPUStrategy to
  specify the number of replicas and the number of logical devices. Below is
  an example to initialize the TPU system with 2 logical devices and 1 replica.
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
>>> tf.config.experimental_connect_to_cluster(resolver)
>>> topology = tf.tpu.experimental.initialize_tpu_system(resolver)
>>> device_assignment = tf.tpu.experimental.DeviceAssignment.build(
... topology,
... computation_shape=[1, 1, 1, 2],
... num_replicas=1)
>>> strategy = tf.distribute.TPUStrategy(
... resolver, experimental_device_assignment=device_assignment)
Then you can run a `tf.add` operation only on logical device 0.
>>> @tf.function()
... def step_fn(inputs):
... features, _ = inputs
... output = tf.add(features, features)
...
... # Add operation will be executed on logical device 0.
... output = strategy.experimental_assign_to_logical_device(output, 0)
... return output
>>> dist_dataset = strategy.distribute_datasets_from_function(
... dataset_fn)
>>> iterator = iter(dist_dataset)
>>> strategy.run(step_fn, args=(next(iterator),))
`experimental_spmd_xla_partitioning` enables the experimental XLA SPMD feature
for model parallelism. This flag can reduce the compilation time and HBM
requirements. When running in this mode, every input tensor must either be
partitioned (via `strategy.experimental_split_to_logical_devices`) or fully
replicated (via `strategy.experimental_replicate_to_logical_devices`) to all
logical devices. And calling `strategy.experimental_assign_to_logical_device`
will result in a ValueError in this mode.
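  A minimal construction sketch with XLA SPMD partitioning enabled, reusing
  the `resolver` and `device_assignment` from the example above:

  ```python
  strategy = tf.distribute.TPUStrategy(
      resolver,
      experimental_device_assignment=device_assignment,
      experimental_spmd_xla_partitioning=True)
  ```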
"""
def __init__(self,
tpu_cluster_resolver=None,
experimental_device_assignment=None,
experimental_spmd_xla_partitioning=False):
"""Synchronous training in TPU donuts or Pods.
Args:
tpu_cluster_resolver: A
`tf.distribute.cluster_resolver.TPUClusterResolver` instance, which
provides information about the TPU cluster. If None, it will assume
running on a local TPU worker.
experimental_device_assignment: Optional
`tf.tpu.experimental.DeviceAssignment` to specify the placement of
replicas on the TPU cluster.
experimental_spmd_xla_partitioning: If True, enable the SPMD (Single
Program Multiple Data) mode in XLA compiler. This flag only affects the
performance of XLA compilation and the HBM requirement of the compiled
        TPU program. Caveat: if this flag is True, calling
`tf.distribute.TPUStrategy.experimental_assign_to_logical_device` will
result in a ValueError.
"""
super(TPUStrategyV2, self).__init__(
TPUExtended(
self,
tpu_cluster_resolver,
device_assignment=experimental_device_assignment,
use_spmd_for_xla_partitioning=experimental_spmd_xla_partitioning,
enable_data_reorder=experimental_device_assignment is not None,
)
)
distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_workers").set(self.extended.num_hosts)
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_replicas_per_worker").set(self.extended.num_replicas_per_host)
# Packed variable is used to reduce the overhead of function execution.
# For a DistributedVariable, only one variable handle is captured into a
# function graph. It's only supported in eager mode.
self._enable_packed_variable_in_eager_mode = True
def run(self, fn, args=(), kwargs=None, options=None):
"""Run the computation defined by `fn` on each TPU replica.
Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
`tf.distribute.DistributedValues`, such as those produced by a
`tf.distribute.DistributedDataset` from
`tf.distribute.Strategy.experimental_distribute_dataset` or
`tf.distribute.Strategy.distribute_datasets_from_function`,
when `fn` is executed on a particular replica, it will be executed with the
    component of `tf.distribute.DistributedValues` that corresponds to that
replica.
`fn` may call `tf.distribute.get_replica_context()` to access members such
as `all_reduce`.
    All arguments in `args` or `kwargs` should either be nests of tensors or
`tf.distribute.DistributedValues` containing tensors or composite tensors.
Example usage:
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
>>> tf.config.experimental_connect_to_cluster(resolver)
>>> tf.tpu.experimental.initialize_tpu_system(resolver)
>>> strategy = tf.distribute.TPUStrategy(resolver)
>>> @tf.function
... def run():
... def value_fn(value_context):
... return value_context.num_replicas_in_sync
... distributed_values = (
... strategy.experimental_distribute_values_from_function(value_fn))
... def replica_fn(input):
... return input * 2
... return strategy.run(replica_fn, args=(distributed_values,))
>>> result = run()
Args:
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
args: (Optional) Positional arguments to `fn`.
kwargs: (Optional) Keyword arguments to `fn`.
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
the options to run `fn`.
Returns:
Merged return value of `fn` across replicas. The structure of the return
value is the same as the return value from `fn`. Each element in the
      structure can either be `tf.distribute.DistributedValues` or `Tensor`
      objects (for example, if running on a single replica).
"""
validate_run_function(fn)
fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)
# Note: the target function is converted to graph even when in Eager mode,
# so autograph is on by default here.
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
options = options or distribute_lib.RunOptions()
return self.extended.tpu_run(fn, args, kwargs, options)
@property
def cluster_resolver(self):
"""Returns the cluster resolver associated with this strategy.
`tf.distribute.TPUStrategy` provides the associated
`tf.distribute.cluster_resolver.ClusterResolver`. If the user provides one
in `__init__`, that instance is returned; if the user does not, a default
`tf.distribute.cluster_resolver.TPUClusterResolver` is provided.
"""
return self.extended._tpu_cluster_resolver # pylint: disable=protected-access
def experimental_assign_to_logical_device(self, tensor, logical_device_id):
"""Adds annotation that `tensor` will be assigned to a logical device.
This adds an annotation to `tensor` specifying that operations on
`tensor` will be invoked on logical core device id `logical_device_id`.
When model parallelism is used, the default behavior is that all ops
are placed on zero-th logical device.
```python
# Initializing TPU system with 2 logical devices and 4 replicas.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology,
computation_shape=[1, 1, 1, 2],
num_replicas=4)
strategy = tf.distribute.TPUStrategy(
resolver, experimental_device_assignment=device_assignment)
iterator = iter(inputs)
@tf.function()
def step_fn(inputs):
output = tf.add(inputs, inputs)
# Add operation will be executed on logical device 0.
output = strategy.experimental_assign_to_logical_device(output, 0)
return output
strategy.run(step_fn, args=(next(iterator),))
```
Args:
tensor: Input tensor to annotate.
logical_device_id: Id of the logical core to which the tensor will be
assigned.
Raises:
      ValueError: If the logical device id presented is not consistent with the
        total number of partitions specified by the device assignment, or if
        the TPUStrategy is constructed with
        `experimental_spmd_xla_partitioning=True`.
Returns:
Annotated tensor with identical value as `tensor`.
"""
if self.extended._use_spmd_for_xla_partitioning: # pylint: disable=protected-access
raise ValueError(
"Cannot assign a tensor to a logical device in SPMD mode. To disable "
"SPMD, Please construct the TPUStrategy with "
"`experimental_spmd_xla_partitioning=False`")
num_logical_devices_per_replica = self.extended._tpu_devices.shape[1] # pylint: disable=protected-access
if (logical_device_id < 0 or
logical_device_id >= num_logical_devices_per_replica):
raise ValueError("`logical_core_id` to assign must be lower then total "
"number of logical devices per replica. Received "
"logical device id {} but there are only total of {} "
"logical devices in replica.".format(
logical_device_id, num_logical_devices_per_replica))
return xla_sharding.assign_device(
tensor, logical_device_id, use_sharding_op=True)
def experimental_split_to_logical_devices(self, tensor, partition_dimensions):
"""Adds annotation that `tensor` will be split across logical devices.
This adds an annotation to tensor `tensor` specifying that operations on
`tensor` will be split among multiple logical devices. Tensor `tensor` will
be split across dimensions specified by `partition_dimensions`.
The dimensions of `tensor` must be divisible by corresponding value in
`partition_dimensions`.
    For example, for a system with 8 logical devices, if `tensor` is an image
    tensor with shape (batch_size, width, height, channel) and
    `partition_dimensions` is [1, 2, 4, 1], then `tensor` will be split
    2 ways in the width dimension and 4 ways in the height dimension and the
    split tensor values will be fed into 8 logical devices.
```python
# Initializing TPU system with 8 logical devices and 1 replica.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology,
computation_shape=[1, 2, 2, 2],
num_replicas=1)
# Construct the TPUStrategy. Since we are going to split the image across
# logical devices, here we set `experimental_spmd_xla_partitioning=True`
# so that the partitioning can be compiled in SPMD mode, which usually
# results in faster compilation and smaller HBM requirement if the size of
# input and activation tensors are much bigger than that of the model
# parameters. Note that this flag is suggested but not a hard requirement
# for `experimental_split_to_logical_devices`.
strategy = tf.distribute.TPUStrategy(
resolver, experimental_device_assignment=device_assignment,
experimental_spmd_xla_partitioning=True)
iterator = iter(inputs)
@tf.function()
def step_fn(inputs):
inputs = strategy.experimental_split_to_logical_devices(
inputs, [1, 2, 4, 1])
# model() function will be executed on 8 logical devices with `inputs`
# split 2 * 4 ways.
output = model(inputs)
return output
strategy.run(step_fn, args=(next(iterator),))
```
Args:
tensor: Input tensor to annotate.
partition_dimensions: An unnested list of integers with the size equal to
rank of `tensor` specifying how `tensor` will be partitioned. The
product of all elements in `partition_dimensions` must be equal to the
total number of logical devices per replica.
Raises:
      ValueError: 1) If the size of `partition_dimensions` does not equal the
        rank of `tensor` or 2) if the product of the elements of
        `partition_dimensions` does not match the number of logical devices per
        replica defined by the implementing DistributionStrategy's device
        specification or 3) if a known size of `tensor` is not divisible by the
        corresponding value in `partition_dimensions`.
Returns:
Annotated tensor with identical value as `tensor`.
"""
num_logical_devices_per_replica = self.extended._tpu_devices.shape[1] # pylint: disable=protected-access
num_partition_splits = np.prod(partition_dimensions)
input_shape = tensor.shape
tensor_rank = len(input_shape)
if tensor_rank != len(partition_dimensions):
raise ValueError("Length of `partition_dimensions` must equal to the "
"rank of `tensor.shape` ({}). Received "
"len(partition_dimensions)={}.".format(
tensor_rank, len(partition_dimensions)))
for dim_index, dim_size in enumerate(input_shape):
if dim_size is None:
continue
split_size = partition_dimensions[dim_index]
if dim_size % split_size != 0:
raise ValueError("Tensor shape at `partition_dimensions[{}]` must be "
"divisible by corresponding value specified "
"by `partition_dimensions` ({}). Received: {}.".format(
dim_index, split_size, dim_size))
if num_partition_splits != num_logical_devices_per_replica:
raise ValueError(
"The product of `partition_dimensions` should be the same as the "
"number of logical devices (={}). Received `partition_dimensions`={},"
"and their product is {}.".format(num_logical_devices_per_replica,
partition_dimensions,
num_partition_splits))
tile_assignment = np.arange(num_partition_splits).reshape(
partition_dimensions)
return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
def experimental_replicate_to_logical_devices(self, tensor):
"""Adds annotation that `tensor` will be replicated to all logical devices.
This adds an annotation to tensor `tensor` specifying that operations on
`tensor` will be invoked on all logical devices.
```python
    # Initializing TPU system with 8 logical devices and 1 replica.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology,
        computation_shape=[1, 2, 2, 2],
        num_replicas=1)
strategy = tf.distribute.TPUStrategy(
resolver, experimental_device_assignment=device_assignment)
iterator = iter(inputs)
@tf.function()
def step_fn(inputs):
images, labels = inputs
      images = strategy.experimental_split_to_logical_devices(
          images, [1, 2, 4, 1])
      # model() function will be executed on 8 logical devices with `images`
      # split 2 * 4 ways.
      output = model(images)
# For loss calculation, all logical devices share the same logits
# and labels.
labels = strategy.experimental_replicate_to_logical_devices(labels)
output = strategy.experimental_replicate_to_logical_devices(output)
loss = loss_fn(labels, output)
return loss
strategy.run(step_fn, args=(next(iterator),))
```
Args:
tensor: Input tensor to annotate.
Returns:
Annotated tensor with identical value as `tensor`.
"""
return xla_sharding.replicate(tensor, use_sharding_op=True)
@tf_export("distribute.experimental.TPUStrategy", v1=[])
@deprecation.deprecated_endpoints("distribute.experimental.TPUStrategy")
class TPUStrategy(distribute_lib.Strategy):
"""Synchronous training on TPUs and TPU Pods.
To construct a TPUStrategy object, you need to run the
initialization code as below:
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
>>> tf.config.experimental_connect_to_cluster(resolver)
>>> tf.tpu.experimental.initialize_tpu_system(resolver)
>>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
While using distribution strategies, the variables created within the
strategy's scope will be replicated across all the replicas and can be kept in
sync using all-reduce algorithms.
To run TF2 programs on TPUs, you can either use `.compile` and
`.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
training loop by calling `strategy.run` directly. Note that
TPUStrategy doesn't support pure eager execution, so please make sure the
function passed into `strategy.run` is a `tf.function` or
`strategy.run` is called inside a `tf.function` if eager
behavior is enabled.
"""
def __init__(self,
tpu_cluster_resolver=None,
device_assignment=None):
"""Synchronous training in TPU donuts or Pods.
Args:
tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
which provides information about the TPU cluster.
device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
specify the placement of replicas on the TPU cluster.
"""
logging.warning(
"`tf.distribute.experimental.TPUStrategy` is deprecated, please use "
"the non-experimental symbol `tf.distribute.TPUStrategy` instead.")
super(TPUStrategy, self).__init__(
TPUExtended(
self,
tpu_cluster_resolver,
device_assignment=device_assignment,
enable_data_reorder=device_assignment is not None,
)
)
distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_workers").set(self.extended.num_hosts)
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_replicas_per_worker").set(self.extended.num_replicas_per_host)
# Packed variable is used to reduce the overhead of function execution.
# For a DistributedVariable, only one variable handle is captured into a
# function graph. It's only supported in eager mode.
self._enable_packed_variable_in_eager_mode = True
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
# can use the default implementation.
# This implementation runs a single step. It does not use infeed or outfeed.
def run(self, fn, args=(), kwargs=None, options=None):
"""See base class."""
validate_run_function(fn)
fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)
# Note: the target function is converted to graph even when in Eager mode,
# so autograph is on by default here.
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
options = options or distribute_lib.RunOptions()
return self.extended.tpu_run(fn, args, kwargs, options)
@property
def cluster_resolver(self):
"""Returns the cluster resolver associated with this strategy.
`tf.distribute.experimental.TPUStrategy` provides the
associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user
provides one in `__init__`, that instance is returned; if the user does
not, a default
`tf.distribute.cluster_resolver.TPUClusterResolver` is provided.
"""
return self.extended._tpu_cluster_resolver # pylint: disable=protected-access
@tf_export(v1=["distribute.experimental.TPUStrategy"])
class TPUStrategyV1(distribute_lib.StrategyV1):
"""TPU distribution strategy implementation."""
def __init__(self,
tpu_cluster_resolver=None,
steps_per_run=None,
device_assignment=None):
"""Initializes the TPUStrategy object.
Args:
tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
which provides information about the TPU cluster.
steps_per_run: Number of steps to run on device before returning to the
host. Note that this can have side-effects on performance, hooks,
metrics, summaries etc.
This parameter is only used when Distribution Strategy is used with
estimator or keras.
device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
specify the placement of replicas on the TPU cluster. Currently only
        supports the use case of using a single core within a TPU cluster.
"""
super(TPUStrategyV1, self).__init__(TPUExtended(
self, tpu_cluster_resolver, steps_per_run, device_assignment))
distribute_lib.distribution_strategy_gauge.get_cell("V1").set("TPUStrategy")
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_workers").set(self.extended.num_hosts)
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_replicas_per_worker").set(self.extended.num_replicas_per_host)
# Packed variable is used to reduce the overhead of function execution.
# For a DistributedVariable, only one variable handle is captured into a
# function graph. It's only supported in eager mode.
self._enable_packed_variable_in_eager_mode = True
@property
def steps_per_run(self):
"""DEPRECATED: use .extended.steps_per_run instead."""
return self._extended.steps_per_run
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
# can use the default implementation.
# This implementation runs a single step. It does not use infeed or outfeed.
def run(self, fn, args=(), kwargs=None, options=None):
"""Run `fn` on each replica, with the given arguments.
Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
"per-replica" values, such as those produced by a "distributed `Dataset`",
when `fn` is executed on a particular replica, it will be executed with the
    component of those "per-replica" values that corresponds to that replica.
`fn` may call `tf.distribute.get_replica_context()` to access members such
as `all_reduce`.
    All arguments in `args` or `kwargs` should either be nests of tensors or
per-replica objects containing tensors or composite tensors.
Users can pass strategy specific options to `options` argument. An example
to enable bucketizing dynamic shapes in `TPUStrategy.run`
is:
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
>>> tf.config.experimental_connect_to_cluster(resolver)
>>> tf.tpu.experimental.initialize_tpu_system(resolver)
>>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
>>> options = tf.distribute.RunOptions(
... experimental_bucketizing_dynamic_shape=True)
>>> dataset = tf.data.Dataset.range(
... strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
... strategy.num_replicas_in_sync, drop_remainder=True)
>>> input_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> @tf.function()
... def step_fn(inputs):
... output = tf.reduce_sum(inputs)
... return output
>>> strategy.run(step_fn, args=(next(input_iterator),), options=options)
Args:
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
args: (Optional) Positional arguments to `fn`.
kwargs: (Optional) Keyword arguments to `fn`.
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
the options to run `fn`.
Returns:
Merged return value of `fn` across replicas. The structure of the return
value is the same as the return value from `fn`. Each element in the
structure can either be "per-replica" `Tensor` objects or `Tensor`s
(for example, if running on a single replica).
"""
validate_run_function(fn)
fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
options = options or distribute_lib.RunOptions()
return self.extended.tpu_run(fn, args, kwargs, options)
# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class TPUExtended(distribute_lib.StrategyExtendedV1):
"""Implementation of TPUStrategy."""
def __init__(
self,
container_strategy,
tpu_cluster_resolver=None,
steps_per_run=None,
device_assignment=None,
use_spmd_for_xla_partitioning=False,
enable_data_reorder=False,
):
super(TPUExtended, self).__init__(container_strategy)
if tpu_cluster_resolver is None:
tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver("")
if steps_per_run is None:
# TODO(frankchn): Warn when we are being used by DS/Keras and this is
# not specified.
steps_per_run = 1
# `self._tpu_function_cache` is a dict of `tf.function`s, thus if a
# `tf.function` is passed into `strategy.run` in eager mode, the
# `tf.function` won't get retraced.
self._tpu_function_cache = weakref.WeakKeyDictionary()
self._tpu_cluster_resolver = tpu_cluster_resolver
self._tpu_metadata = self._tpu_cluster_resolver.get_tpu_system_metadata()
self._device_assignment = device_assignment
tpu_devices_flat = [
d.name for d in self._tpu_metadata.devices if "device:TPU:" in d.name]
# `self._tpu_devices` is a two-dimensional NumPy array of strings. It is
# indexed using `[replica_id][logical_device_id]`.
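    # For example (illustrative), with 8 TPU cores and no device assignment the
    # array has shape (8, 1); with a device assignment of 4 replicas and 2
    # logical cores per replica it has shape (4, 2).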
if device_assignment is None:
self._tpu_devices = np.array(
[[d] for d in tpu_devices_flat], dtype=object)
else:
job_name = device_spec.DeviceSpecV2.from_string(tpu_devices_flat[0]).job
tpu_devices = []
for replica_id in range(device_assignment.num_replicas):
replica_devices = []
for logical_core in range(device_assignment.num_cores_per_replica):
replica_devices.append(
device_util.canonicalize(
device_assignment.tpu_device(
replica=replica_id,
logical_core=logical_core,
job=job_name)))
tpu_devices.append(replica_devices)
self._tpu_devices = np.array(tpu_devices, dtype=object)
self._host_device = device_util.get_host_for_device(self._tpu_devices[0][0])
# Preload the data onto the TPUs. Currently we always preload onto logical
# device 0 for each replica.
# TODO(cjfj): Create `InputWorkers` lazily, allowing users to place the
# input onto a different logical device?
self._device_input_worker_devices = collections.OrderedDict()
self._host_input_worker_devices = collections.OrderedDict()
for tpu_device in self._tpu_devices[:, 0]:
host_device = device_util.get_host_for_device(tpu_device)
self._device_input_worker_devices.setdefault(host_device, [])
self._device_input_worker_devices[host_device].append(tpu_device)
self._host_input_worker_devices.setdefault(host_device, [])
self._host_input_worker_devices[host_device].append(host_device)
# Create the replica order based on the assigned device order.
# This replica order will be used to match the IteratorGetNext ops
# with the device assigment.
self._replica_order = (
self._get_replica_order(self._tpu_devices[:, 0])
if enable_data_reorder
else None
)
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
self.steps_per_run = steps_per_run
self._require_static_shapes = True
self.experimental_enable_get_next_as_optional = True
self._logical_device_stack = [0]
if context.executing_eagerly():
# In async remote eager, we want to sync the executors before exiting the
# program.
atexit.register(context.async_wait)
# Flag to turn on VariablePolicy. Var policy is deprecated because there is
# another effort unifying DistributedVariables (see values_v2.py). SPMD XLA
# partitioning is not implemented for var policies.
# TODO(b/202048882): remove var policy from TPUStrategy.
self._use_var_policy = not use_spmd_for_xla_partitioning
# Flag to enable XLA SPMD partitioning.
self._use_spmd_for_xla_partitioning = use_spmd_for_xla_partitioning
def _get_replica_order(self, tpu_devices):
"""Get the replica order based on the tpu device order.
For example, if the tpu_devices are:
'/job:worker/replica:0/task:0/device:TPU:0',
'/job:worker/replica:0/task:0/device:TPU:2',
'/job:worker/replica:0/task:1/device:TPU:0',
'/job:worker/replica:0/task:1/device:TPU:2',
'/job:worker/replica:0/task:1/device:TPU:6',
'/job:worker/replica:0/task:1/device:TPU:4',
'/job:worker/replica:0/task:0/device:TPU:6',
'/job:worker/replica:0/task:0/device:TPU:4',
the returned replica order will be:
[0, 1, 7, 6, 2, 3, 5, 4]
    This replica order will be used to reorder the data returned by the
    iterators, so that they can be placed on the same node as their
    computation graphs.
Args:
tpu_devices (List[str]): A list of tpu device names in the order of
replicas.
Returns:
A list containing the order ids of corresponding TPU devices.
"""
devices_with_ids = []
for i, tpu_device in enumerate(tpu_devices):
spec = tf_device.DeviceSpec.from_string(tpu_device)
devices_with_ids.append((
(
spec.job,
spec.replica,
spec.device_type,
spec.task,
spec.device_index,
),
i,
))
return [i for _, i in sorted(devices_with_ids)]
def _validate_colocate_with_variable(self, colocate_with_variable):
distribute_utils.validate_colocate(colocate_with_variable, self)
def _make_dataset_iterator(self, dataset):
"""Make iterators for each of the TPU hosts."""
input_workers = input_lib.InputWorkers(
tuple(self._device_input_worker_devices.items()))
return input_lib_v1.DatasetIterator(
dataset,
input_workers,
self._container_strategy(),
num_replicas_in_sync=self._num_replicas_in_sync)
def _make_input_fn_iterator(
self,
input_fn,
replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
input_contexts = []
input_workers = input_lib.InputWorkers(
tuple(self._device_input_worker_devices.items()))
num_workers = input_workers.num_workers
for i in range(num_workers):
input_contexts.append(
distribute_lib.InputContext(
num_input_pipelines=num_workers,
input_pipeline_id=i,
num_replicas_in_sync=self._num_replicas_in_sync))
return input_lib_v1.InputFunctionIterator(input_fn, input_workers,
input_contexts,
self._container_strategy())
def _experimental_make_numpy_dataset(self, numpy_input, session):
return numpy_dataset.one_host_numpy_dataset(
numpy_input, numpy_dataset.SingleDevice(self._host_device),
session)
def _get_input_workers(self, options):
if not options or options.experimental_fetch_to_device:
return input_lib.InputWorkers(
tuple(self._device_input_worker_devices.items()))
else:
return input_lib.InputWorkers(
tuple(self._host_input_worker_devices.items()))
def _check_spec(self, element_spec):
if isinstance(element_spec, values.PerReplicaSpec):
element_spec = element_spec._component_specs # pylint: disable=protected-access
specs = nest.flatten_with_joined_string_paths(element_spec)
for path, spec in specs:
if isinstance(spec, (sparse_tensor.SparseTensorSpec,
ragged_tensor.RaggedTensorSpec)):
raise ValueError(
"Found tensor {} with spec {}. TPUStrategy does not support "
"distributed datasets with device prefetch when using sparse or "
"ragged tensors. If you intend to use sparse or ragged tensors, "
"please pass a tf.distribute.InputOptions object with "
"experimental_fetch_to_device set to False to your dataset "
"distribution function.".format(path, type(spec)))
def _experimental_distribute_dataset(self, dataset, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function`."
)
if options is None or options.experimental_fetch_to_device:
self._check_spec(dataset.element_spec)
return input_util.get_distributed_dataset(
dataset,
self._get_input_workers(options),
self._container_strategy(),
num_replicas_in_sync=self._num_replicas_in_sync,
options=options,
replica_order=self._replica_order,
)
def _distribute_datasets_from_function(self, dataset_fn, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
" `experimental_distribute_datasets_from_function` "
"of tf.distribute.MirroredStrategy")
input_workers = self._get_input_workers(options)
input_contexts = []
num_workers = input_workers.num_workers
for i in range(num_workers):
input_contexts.append(distribute_lib.InputContext(
num_input_pipelines=num_workers,
input_pipeline_id=i,
num_replicas_in_sync=self._num_replicas_in_sync))
distributed_dataset = input_util.get_distributed_datasets_from_function(
dataset_fn,
input_workers,
input_contexts,
self._container_strategy(),
options=options,
replica_order=self._replica_order,
)
# We can only check after the dataset_fn is called.
if options is None or options.experimental_fetch_to_device:
self._check_spec(distributed_dataset.element_spec)
return distributed_dataset
def _experimental_distribute_values_from_function(self, value_fn):
per_replica_values = []
for replica_id in range(self._num_replicas_in_sync):
per_replica_values.append(
value_fn(distribute_lib.ValueContext(replica_id,
self._num_replicas_in_sync)))
return distribute_utils.regroup(per_replica_values, always_wrap=True)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
# TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
# a mechanism to infer the outputs of `fn`. Pending b/110550782.
def _experimental_run_steps_on_iterator(
self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
# Wrap `fn` for repeat.
if initial_loop_values is None:
initial_loop_values = {}
initial_loop_values = nest.flatten(initial_loop_values)
ctx = input_lib.MultiStepContext()
def run_fn(inputs):
"""Single step on the TPU device."""
fn_result = fn(ctx, inputs)
flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
if flat_last_step_outputs:
with ops.control_dependencies([fn_result]):
return [array_ops.identity(f) for f in flat_last_step_outputs]
else:
return fn_result
# We capture the control_flow_context at this point, before we run `fn`
# inside a while_loop and TPU replicate context. This is useful in cases
# where we might need to exit these contexts and get back to the outer
# context to do some things, for e.g. create an op which should be
# evaluated only once at the end of the loop on the host. One such usage
# is in creating metrics' value op.
self._outer_control_flow_context = (
ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
def rewrite_fn(*args):
"""The rewritten step fn running on TPU."""
del args
per_replica_inputs = multi_worker_iterator.get_next()
replicate_inputs = []
for replica_id in range(self._num_replicas_in_sync):
select_replica = lambda x: distribute_utils.select_replica( # pylint: disable=g-long-lambda
replica_id, x) # pylint: disable=cell-var-from-loop
replicate_inputs.append((nest.map_structure(
select_replica, per_replica_inputs),))
replicate_outputs = tpu.replicate(
run_fn,
replicate_inputs,
device_assignment=self._device_assignment,
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
._use_spmd_for_xla_partitioning))
      # If run_fn has tensor outputs, tpu.replicate returns a list of lists. We
# will flatten it in this case. If run_fn has no tensor outputs,
# tpu.replicate returns a list of no_ops, we will keep the output as it
# is.
if isinstance(replicate_outputs[0], list):
replicate_outputs = nest.flatten(replicate_outputs)
return replicate_outputs
# TODO(sourabhbajaj): The input to while loop should be based on the
# output type of the step_fn
assert isinstance(initial_loop_values, list)
initial_loop_values = initial_loop_values * self._num_replicas_in_sync
# Put the while loop op on TPU host 0.
with ops.device(self._host_device):
if self.steps_per_run == 1:
replicate_outputs = rewrite_fn()
else:
replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
initial_loop_values)
del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(replicate_outputs)
if isinstance(replicate_outputs, list):
      # Filter out any ops from the outputs; typically this is the case
      # when there are no tensor outputs.
last_step_tensor_outputs = [
x for x in replicate_outputs if not isinstance(x, ops.Operation)
]
# Outputs are currently of the structure (flattened)
# [output0_device0, output1_device0, output2_device0,
# output0_device1, output1_device1, output2_device1,
# ...]
# Convert this to the following structure instead: (grouped by output)
# [[output0_device0, output0_device1],
# [output1_device0, output1_device1],
# [output2_device0, output2_device1]]
output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
last_step_tensor_outputs = [
last_step_tensor_outputs[i::output_num] for i in range(output_num)
]
else:
# no tensors returned.
last_step_tensor_outputs = []
_set_last_step_outputs(ctx, last_step_tensor_outputs)
return ctx
def _call_for_each_replica(self, fn, args, kwargs):
# TODO(jhseu): Consider making it so call_for_each_replica implies that
# we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
with _TPUReplicaContext(self._container_strategy()):
return fn(*args, **kwargs)
@contextlib.contextmanager
def experimental_logical_device(self, logical_device_id):
"""Places variables and ops on the specified logical device."""
num_logical_devices_per_replica = self._tpu_devices.shape[1]
if logical_device_id >= num_logical_devices_per_replica:
raise ValueError(
"`logical_device_id` not in range (was {}, but there are only {} "
"logical devices per replica).".format(
logical_device_id, num_logical_devices_per_replica))
self._logical_device_stack.append(logical_device_id)
try:
if tpu_util.enclosing_tpu_context() is None:
yield
else:
with ops.device(tpu.core(logical_device_id)):
yield
finally:
self._logical_device_stack.pop()
def _experimental_initialize_system(self):
"""Experimental method added to be used by Estimator.
This is a private method only to be used by Estimator. Other frameworks
should directly be calling `tf.tpu.experimental.initialize_tpu_system`
"""
tpu_cluster_resolver_lib.initialize_tpu_system(self._tpu_cluster_resolver)
def _create_variable(self, next_creator, **kwargs):
"""Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
if kwargs.pop("skip_mirrored_creator", False):
return next_creator(**kwargs)
colocate_with = kwargs.pop("colocate_with", None)
if colocate_with is None:
devices = self._tpu_devices[:, self._logical_device_stack[-1]]
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return next_creator(**kwargs)
else:
devices = colocate_with._devices # pylint: disable=protected-access
num_replicas, num_cores_per_replica = self._tpu_devices.shape
def _create_mirrored_tpu_variables(**kwargs):
"""Returns a list of `tf.Variable`s.
      The list contains `num_replicas` `tf.Variable`s and can be used to
initialize a `TPUMirroredVariable`.
Args:
**kwargs: the keyword arguments for creating a variable
"""
initial_value = None
value_list = []
for i, d in enumerate(devices):
with ops.device(d):
if i == 0:
initial_value = kwargs["initial_value"]
# Note: some v1 code expects variable initializer creation to happen
            # inside an init_scope.
with maybe_init_scope():
initial_value = initial_value() if callable(
initial_value) else initial_value
if i > 0:
# Give replicas meaningful distinct names:
var0name = value_list[0].name.split(":")[0]
# We append a / to variable names created on replicas with id > 0 to
# ensure that we ignore the name scope and instead use the given
# name as the absolute name of the variable.
kwargs["name"] = "%s/replica_%d/" % (var0name, i)
kwargs["initial_value"] = initial_value
with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
v = next_creator(**kwargs)
assert not isinstance(v, tpu_values.TPUMirroredVariable)
value_list.append(v)
return value_list
def _create_mirrored_tpu_replicated_variables(**kwargs):
"""Returns a list of `TPUReplicatedVariable`s.
The list consists of `num_replicas` `TPUReplicatedVariable`s and can be
used to initialize a `TPUMirroredVariable`. Each `TPUReplicatedVariable`
contains a list of `tf.Variable`s which are replicated to
`num_cores_per_replica` logical cores to enable XLA SPMD compilation.
Args:
**kwargs: the keyword arguments for creating a variable
"""
initial_value = kwargs["initial_value"]
# Note: some v1 code expects variable initializer creation to happen
      # inside an init_scope.
with maybe_init_scope():
initial_value = initial_value() if callable(
initial_value) else initial_value
mirrored_replicated_var_list = []
for replica_id in range(num_replicas):
replicated_var_list = []
for logic_core_id in range(num_cores_per_replica):
with ops.device(self._tpu_devices[replica_id][logic_core_id]):
kwargs["initial_value"] = initial_value
v = next_creator(**kwargs)
replicated_var_list.append(v)
replica_name = "{}/r:{}".format(kwargs["name"], replica_id)
tpu_replicated_var = tpu_replicated_variable.TPUReplicatedVariable(
variables=replicated_var_list, name=replica_name)
mirrored_replicated_var_list.append(tpu_replicated_var)
return mirrored_replicated_var_list
if self._use_spmd_for_xla_partitioning and num_cores_per_replica > 1:
real_creator = _create_mirrored_tpu_replicated_variables
else:
real_creator = _create_mirrored_tpu_variables
return distribute_utils.create_mirrored_variable(
self._container_strategy(), real_creator,
distribute_utils.TPU_VARIABLE_CLASS_MAPPING,
distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs)
def _resource_creator_scope(self):
def lookup_creator(next_creator, *args, **kwargs):
host_to_table = collections.OrderedDict()
for host_device in self._device_input_worker_devices.keys():
with ops.device(host_device):
host_to_table[host_device] = next_creator(*args, **kwargs)
return values.PerWorkerResource(self._container_strategy(), host_to_table)
# TODO(b/194362531): Define creator(s) for other resources.
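    # For example (illustrative), a `tf.lookup.StaticHashTable` created under
    # `strategy.scope()` is built once per input-worker host and wrapped in a
    # `values.PerWorkerResource`, so each host reads its local copy.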
return ops.resource_creator_scope("StaticHashTable", lookup_creator)
def _gather_to_implementation(self, value, destinations, axis, options):
if not isinstance(value, values.DistributedValues):
return value
value_list = list(value.values)
# pylint: disable=protected-access
if isinstance(
value,
values.DistributedVariable) and value._packed_variable is not None:
value_list = list(
value._packed_variable.on_device(d)
for d in value._packed_variable.devices)
# pylint: enable=protected-access
    # Currently XLA op by op mode has a limit for the number of inputs for a
    # single op, thus we break one `concat` op into a group of smaller `concat`
    # ops to work around the constraint.
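    # For example (illustrative), with 450 per-replica values and a limit of
    # 200, this performs concat(value_list[0:200]), then
    # concat([output] + value_list[200:399]), then
    # concat([output] + value_list[399:450]), so no single op exceeds the limit.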
if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
output = array_ops.concat(value_list, axis=axis)
else:
output = array_ops.concat(
value_list[:_XLA_OP_BY_OP_INPUTS_LIMIT], axis=axis)
for i in range(_XLA_OP_BY_OP_INPUTS_LIMIT, len(value_list),
_XLA_OP_BY_OP_INPUTS_LIMIT - 1):
output = array_ops.concat(
[output] + value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT - 1],
axis=axis)
output = self._broadcast_output(destinations, output)
return output
def _broadcast_output(self, destinations, output):
devices = cross_device_ops_lib.get_devices_from(destinations)
if len(devices) == 1:
# If necessary, copy to requested destination.
dest_canonical = device_util.canonicalize(devices[0])
host_canonical = device_util.canonicalize(self._host_device)
if dest_canonical != host_canonical:
with ops.device(dest_canonical):
output = array_ops.identity(output)
else:
output = cross_device_ops_lib.simple_broadcast(output, destinations)
return output
def _reduce_to(self, reduce_op, value, destinations, options):
if (isinstance(value, values.DistributedValues) or
tensor_util.is_tf_type(value)
) and tpu_util.enclosing_tpu_context() is not None:
if reduce_op == reduce_util.ReduceOp.MEAN:
# TODO(jhseu): Revisit once we support model-parallelism.
# scalar_mul maintains the type of value: tensor or IndexedSlices.
value = math_ops.scalar_mul((1./self._num_replicas_in_sync), value)
elif reduce_op != reduce_util.ReduceOp.SUM:
raise NotImplementedError(
f"`reduce_op`={reduce_op} is not supported. Currently we only "
"support ReduceOp.SUM and ReduceOp.MEAN in TPUStrategy.")
return tpu_ops.cross_replica_sum(value)
if not isinstance(value, values.DistributedValues):
# This function handles reducing values that are not PerReplica or
# Mirrored values. For example, the same value could be present on all
# replicas in which case `value` would be a single value or value could
# be 0.
return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, value, destinations, self._num_replicas_in_sync)
value_list = value.values
# pylint: disable=protected-access
if isinstance(
value,
values.DistributedVariable) and value._packed_variable is not None:
value_list = tuple(
value._packed_variable.on_device(d)
for d in value._packed_variable.devices)
# pylint: enable=protected-access
# Currently XLA op by op mode has a limit for the number of inputs for a
# single op, thus we break one `add_n` op into a group of `add_n` ops to
# work around the constraint.
# TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
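    # For example (illustrative), with 450 per-replica values and a limit of
    # 200, this sums value_list[0:200], value_list[200:400] and
    # value_list[400:450] separately and accumulates them into `output`.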
if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
output = math_ops.add_n(value_list)
else:
output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype)
for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])
if reduce_op == reduce_util.ReduceOp.MEAN:
output *= (1. / len(value_list))
output = self._broadcast_output(destinations, output)
return output
def _update(self, var, fn, args, kwargs, group):
assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
var, resource_variable_ops.BaseResourceVariable)
if tpu_util.enclosing_tpu_context() is not None:
if group:
return fn(var, *args, **kwargs)
else:
return (fn(var, *args, **kwargs),)
# Inside `tf.function`, we don't expand PackedVariable in python as it will
# be expanded later during function instantiation in the runtime.
packed_var = var._packed_variable # pylint: disable=protected-access
if packed_var is not None and not context.executing_eagerly():
if group:
return fn(packed_var, *args, **kwargs)
else:
return (fn(packed_var, *args, **kwargs),)
# Otherwise, we revert to MirroredStrategy behavior and update the variable
# on each replica directly.
updates = []
values_and_devices = []
if packed_var is not None:
for device in packed_var.devices:
values_and_devices.append((packed_var, device))
else:
for value in var.values:
values_and_devices.append((value, value.device))
if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and
var.aggregation != variables_lib.VariableAggregation.NONE):
distribute_utils.assert_mirrored(args)
distribute_utils.assert_mirrored(kwargs)
for i, value_and_device in enumerate(values_and_devices):
value = value_and_device[0]
device = value_and_device[1]
name = "update_%d" % i
with ops.device(device), \
distribute_lib.UpdateContext(i), \
ops.name_scope(name):
# If args and kwargs are not mirrored, the value is returned as is.
updates.append(
fn(value, *distribute_utils.select_replica(i, args),
**distribute_utils.select_replica(i, kwargs)))
return distribute_utils.update_regroup(self, updates, group)
def read_var(self, var):
assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
var, resource_variable_ops.BaseResourceVariable)
return var.read_value()
def value_container(self, value):
return value
def _broadcast_to(self, tensor, destinations):
del destinations
# This is both a fast path for Python constants, and a way to delay
# converting Python values to a tensor until we know what type it
# should be converted to. Otherwise we have trouble with:
# global_step.assign_add(1)
# since the `1` gets broadcast as an int32 but global_step is int64.
if isinstance(tensor, (float, int)):
return tensor
if tpu_util.enclosing_tpu_context() is not None:
broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
result = tpu_ops.all_to_all(
broadcast_tensor,
concat_dimension=0,
split_dimension=0,
split_count=self._num_replicas_in_sync)
# This uses the broadcasted value from the first replica because the only
# caller of this is for ONLY_FIRST_REPLICA variable aggregation.
return result[0]
return tensor
@property
def num_hosts(self):
if self._device_assignment is None:
return self._tpu_metadata.num_hosts
return len(set([self._device_assignment.host_device(r)
for r in range(self._device_assignment.num_replicas)]))
@property
def num_replicas_per_host(self):
if self._device_assignment is None:
return self._tpu_metadata.num_of_cores_per_host
# TODO(sourabhbajaj): Remove this method once we use inputs and remove
# infeed, as the computation of num_replicas_per_host is not a constant
# when using device_assignment. This is a temporary workaround to support
# StatefulRNN, since everything is 1 in that case.
# This method needs to take host_id as input for a correct computation.
max_models_per_host = (self._tpu_metadata.num_of_cores_per_host //
self._device_assignment.num_cores_per_replica)
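# Illustrative arithmetic: with 8 cores per host and
# num_cores_per_replica == 2, at most 4 model replicas fit on a single host.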
return min(self._device_assignment.num_replicas, max_models_per_host)
@property
def _num_replicas_in_sync(self):
if self._device_assignment is None:
return self._tpu_metadata.num_cores
return self._device_assignment.num_replicas
@property
def experimental_between_graph(self):
return False
@property
def experimental_should_init(self):
return True
@property
def should_checkpoint(self):
return True
@property
def should_save_summary(self):
return True
@property
def worker_devices(self):
return tuple(self._tpu_devices[:, self._logical_device_stack[-1]])
@property
def parameter_devices(self):
return self.worker_devices
@property
def tpu_hardware_feature(self):
"""Return the `tf.tpu.experimental.HardwareFeature` class."""
return tpu_hardware_feature.HardwareFeature(
self._tpu_cluster_resolver.tpu_hardware_feature)
def non_slot_devices(self, var_list):
return self._host_device
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
del colocate_with
with ops.device(self._host_device), distribute_lib.UpdateContext(None):
result = fn(*args, **kwargs)
if group:
return result
else:
return nest.map_structure(self._local_results, result)
def _configure(self,
session_config=None,
cluster_spec=None,
task_type=None,
task_id=None):
del cluster_spec, task_type, task_id
if session_config:
session_config.CopyFrom(self._update_config_proto(session_config))
def _update_config_proto(self, config_proto):
updated_config = copy.deepcopy(config_proto)
updated_config.isolate_session_state = True
cluster_spec = self._tpu_cluster_resolver.cluster_spec()
if cluster_spec:
updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
return updated_config
# TODO(priyag): Delete this once all strategies use global batch size.
@property
def _global_batch_size(self):
"""`make_dataset_iterator` and `make_numpy_iterator` use global batch size.
`make_input_fn_iterator` assumes per-replica batching.
Returns:
Boolean.
"""
return True
def tpu_run(self, fn, args, kwargs, options=None):
func = self._tpu_function_creator(fn, options)
return func(args, kwargs)
def _tpu_function_creator(self, fn, options):
if context.executing_eagerly() and fn in self._tpu_function_cache:
return self._tpu_function_cache[fn]
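# In eager mode, the replicated function built below is wrapped in a
# `tf.function` and cached per user `fn` (see the end of this method), so
# repeated `strategy.run` calls reuse the traced function.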
strategy = self._container_strategy()
def tpu_function(args, kwargs):
"""TF Function used to replicate the user computation."""
logging.vlog(1,
"`TPUStrategy.run` is called with [args: %s] [kwargs: %s]",
args, kwargs)
if kwargs is None:
kwargs = {}
# Used to re-structure flattened output tensors from `tpu.replicate()`
# into a structured format.
result = [[]]
def replicated_fn(replica_id, replica_args, replica_kwargs):
"""Wraps user function to provide replica ID and `Tensor` inputs."""
with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
result[0] = fn(*replica_args, **replica_kwargs)
return result[0]
replicate_inputs = [] # By replica.
for i in range(strategy.num_replicas_in_sync):
replicate_inputs.append(
[constant_op.constant(i, dtype=dtypes.int32),
distribute_utils.select_replica(i, args),
distribute_utils.select_replica(i, kwargs)])
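# Illustrative structure for two replicas (placeholder names):
#   replicate_inputs == [[0, args_for_replica_0, kwargs_for_replica_0],
#                        [1, args_for_replica_1, kwargs_for_replica_1]]
# where the per-replica args/kwargs are extracted by `select_replica`.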
# Construct and pass `maximum_shapes` so that we can support dynamic
# shapes using the dynamic padder.
if options.experimental_enable_dynamic_batch_size and replicate_inputs:
maximum_shapes = []
flattened_list = nest.flatten(replicate_inputs[0])
for input_tensor in flattened_list:
if tensor_util.is_tf_type(input_tensor):
rank = input_tensor.shape.rank
else:
rank = np.ndim(input_tensor)
if rank is None:
raise ValueError(
"input tensor {} to TPUStrategy.run() has unknown rank, "
"which is not allowed".format(input_tensor))
maximum_shape = tensor_shape.TensorShape([None] * rank)
maximum_shapes.append(maximum_shape)
maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
maximum_shapes)
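# Each entry of `maximum_shapes` has the same rank as the corresponding
# flattened input but with every dimension set to None, marking all
# dimensions as potentially dynamic; the overall structure mirrors a single
# replica's inputs.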
else:
maximum_shapes = None
if options.experimental_bucketizing_dynamic_shape:
padding_spec = tpu.PaddingSpec.POWER_OF_TWO
else:
padding_spec = None
with strategy.scope():
xla_options = options.experimental_xla_options or tpu.XLAOptions(
use_spmd_for_xla_partitioning=self._use_spmd_for_xla_partitioning)
replicate_outputs = tpu.replicate(
replicated_fn,
replicate_inputs,
device_assignment=self._device_assignment,
maximum_shapes=maximum_shapes,
padding_spec=padding_spec,
xla_options=xla_options)
# Remove all no-ops that may have been added during `tpu.replicate()`.
filter_ops = lambda x: [o for o in x if not isinstance(o, ops.Operation)]
if isinstance(result[0], list):
result[0] = filter_ops(result[0])
# Workaround for `tpu.replicate` behavior when a single `Tensor` is returned.
if result[0] is None or isinstance(result[0], ops.Operation):
replicate_outputs = [None] * len(replicate_outputs)
else:
replicate_outputs = [
nest.pack_sequence_as(result[0], filter_ops(nest.flatten(output)))
for output in replicate_outputs
]
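# `regroup` merges the per-replica outputs back into distributed values
# whose structure matches what the user's `fn` returned.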
return distribute_utils.regroup(replicate_outputs)
if context.executing_eagerly():
tpu_function = def_function.function(tpu_function)
self._tpu_function_cache[fn] = tpu_function
return tpu_function
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings."""
# TPUStrategy has a different distributed training structure: the whole
# cluster should be treated as a single worker from a higher-level (e.g.
# Keras) library's point of view.
# TODO(rchao): Revisit this as we design a fault-tolerance solution for
# TPUStrategy.
return False
def _get_local_replica_id(self, replica_id_in_sync_group):
return replica_id_in_sync_group
def _make_axis_nonnegative(axis, rank):
# Convert a potentially negative `axis` to a non-negative one.
if isinstance(axis, int):
if axis >= 0:
return axis
else:
return axis + rank
else:
return array_ops.where_v2(
math_ops.greater_equal(axis, 0),
axis,
axis + rank)
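# For example, axis == -1 with rank == 3 maps to 2, while a non-negative
# axis is returned unchanged; the `where_v2` branch applies the same rule
# elementwise when `axis` is a Tensor.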
# List of Tensor dtypes supported by cross_replica_sum().
_DTYPES_SUPPORTED_BY_CROSS_REPLICA_SUM = (
dtypes.bfloat16,
dtypes.float16,
dtypes.float32,
dtypes.float64,
dtypes.int32,
dtypes.uint32,
)
class _TPUReplicaContext(distribute_lib.ReplicaContext):
"""Replication Context class for TPU Strategy."""
# TODO(sourabhbajaj): Call for each replica should be updating this.
# TODO(b/118385803): Always properly initialize replica_id.
def __init__(self, strategy, replica_id_in_sync_group=0):
distribute_lib.ReplicaContext.__init__(
self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)
@property
def devices(self):
distribute_lib.require_replica_context(self)
ds = self._strategy
replica_id = tensor_util.constant_value(self.replica_id_in_sync_group)
if replica_id is None: # Non-constant `Tensor` inside `tpu.replicate`.
# TODO(cjfj): Return other devices when model parallelism is supported.
return (tpu.core(0),)
else:
return (ds.extended.worker_devices[replica_id],)
def experimental_logical_device(self, logical_device_id):
"""Places variables and ops on the specified logical device."""
return self.strategy.extended.experimental_logical_device(logical_device_id)
def _compute_all_gather_output_shape(self, value_shape, value_rank, axis):
if isinstance(value_rank, int):
output_shape = list(value_shape)
output_shape[axis] *= self.num_replicas_in_sync
else:
output_shape = array_ops.where_v2(
math_ops.equal(math_ops.range(value_rank), axis),
value_shape * self.num_replicas_in_sync,
value_shape)
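# For example, gathering a per-replica value of shape [4, 5] along
# axis == 0 with 8 replicas yields an output shape of [32, 5].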
return output_shape
def all_gather(self, value, axis, experimental_hints=None):
del experimental_hints
for v in nest.flatten(value):
if isinstance(v, indexed_slices.IndexedSlices):
raise NotImplementedError("all_gather does not support IndexedSlices")
def _all_gather_tensor(value, axis):
value = ops.convert_to_tensor(value)
# Compute the shape and rank of the input tensor. Use static shapes when
# possible to help with shape inference in graph mode, but fall back on
# dynamic shapes when necessary.
if value.shape.rank is None:
value_rank = array_ops.rank(value)
value_shape = array_ops.shape(value)
else:
value_rank = value.shape.rank
value_shape = value.shape.as_list()
value_shape_tensor = array_ops.shape(value)
for i in range(len(value_shape)):
if value_shape[i] is None:
value_shape[i] = value_shape_tensor[i]
# In the code below, we will insert a new "replica" dimension immediately
# *before* `axis`. To ensure that it's inserted before and not after, we
# must make `axis` non-negative.
axis = _make_axis_nonnegative(axis, value_rank)
# Create a list or 1D int Tensor such as
# [1, 1, ..., 1, num_replicas_in_sync, 1, ..., 1],
# which is equal to `num_replicas_in_sync` at index `axis`
# and is equal to 1 everywhere else.
if isinstance(value_rank, int):
replica_broadcast_shape = [1] * (value_rank + 1)
replica_broadcast_shape[axis] = self.num_replicas_in_sync
else:
replica_broadcast_shape = array_ops.where_v2(
math_ops.equal(math_ops.range(value_rank+1), axis),
self.num_replicas_in_sync,
1)
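# For example, for a rank-2 value with axis == 1 and 8 replicas,
# replica_broadcast_shape is [1, 8, 1].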
output_shape = self._compute_all_gather_output_shape(
value_shape, value_rank, axis)
if value.dtype in _DTYPES_SUPPORTED_BY_CROSS_REPLICA_SUM:
# Optimized all_gather implementation based on cross_replica_sum().
replica_id_mask = array_ops.one_hot(
self.replica_id_in_sync_group, self.num_replicas_in_sync)
replica_id_mask = array_ops.reshape(
replica_id_mask, replica_broadcast_shape)
replica_id_mask = math_ops.cast(replica_id_mask, value.dtype)
gathered_value = array_ops.expand_dims(value, axis) * replica_id_mask
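# Each replica contributes its value only at its own index along the new
# axis (all other slots are zeroed by the one-hot mask), so after the SUM
# all-reduce below, slot k of that axis holds replica k's value.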
gathered_value = self.all_reduce(
reduce_util.ReduceOp.SUM, gathered_value)
return array_ops.reshape(gathered_value, output_shape)
else:
# value.dtype isn't supported by cross_replica_sum(), so we fall back
# on a less efficient implementation based on all_to_all().
# The underlying AllToAllOp first splits the input value and then performs
# cross-replica communication and concatenation of the result, so we tile
# the local tensor along the new axis here first.
inputs = array_ops.expand_dims(value, axis=axis)
inputs = array_ops.tile(inputs, replica_broadcast_shape)
unordered_output = tpu_ops.all_to_all(
inputs,
concat_dimension=axis,
split_dimension=axis,
split_count=self.num_replicas_in_sync)
# Re-order since xla.replica_id and ReplicaContext.replica_id may not match.
# Start by computing a permutation -- a 1D Tensor which maps
# tensor[xla.replica_id] = ReplicaContext.replica_id
concat_replica_id = array_ops.reshape(
self.replica_id_in_sync_group, [1])
concat_replica_id = array_ops.tile(
concat_replica_id, [self.num_replicas_in_sync])
xla_to_replica_context_id = tpu_ops.all_to_all(
concat_replica_id,
concat_dimension=0,
split_dimension=0,
split_count=self.num_replicas_in_sync)
# Now invert the mapping to get
# tensor[ReplicaContext.replica_id] = xla.replica_id
replica_context_to_xla_id = math_ops.argmax(
array_ops.one_hot(xla_to_replica_context_id,
self.num_replicas_in_sync),
axis=0)
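# Illustrative example: if xla_to_replica_context_id == [2, 0, 1], then
# replica_context_to_xla_id == [1, 2, 0].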
# Reorder the output elements so that they're sorted based on
# ReplicaContext.replica_id instead of xla.replica_id.
sorted_with_extra_dim = array_ops.gather(
unordered_output, replica_context_to_xla_id, axis=axis)
return array_ops.reshape(sorted_with_extra_dim, output_shape)
ys = [_all_gather_tensor(t, axis=axis) for t in nest.flatten(value)]
return nest.pack_sequence_as(value, ys)
def _set_last_step_outputs(ctx, last_step_tensor_outputs):
"""Sets the last step outputs on the given context."""
# Convert replicate_outputs to the original dict structure of
# last_step_outputs.
last_step_tensor_outputs_dict = nest.pack_sequence_as(
ctx.last_step_outputs, last_step_tensor_outputs)
for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access
output = last_step_tensor_outputs_dict[name]
# For outputs that aren't reduced, return a PerReplica of all values.
# Otherwise take the first value from the list, as each value should be the
# same.
if reduce_op is None:
last_step_tensor_outputs_dict[name] = values.PerReplica(output)
else:
# TODO(priyag): Should this return the element or a list with 1 element?
last_step_tensor_outputs_dict[name] = output[0]
ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access