# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values."""
import copy
from typing import Optional
import weakref
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import record
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.types import core
from tensorflow.python.types import distribute as ds_types
from tensorflow.python.types import trace
def _on_write_update_replica(var, update_fn, value, **kwargs):
"""Updates variables with ON_WRITE synchronization in replica context."""
if var.aggregation == vs.VariableAggregation.NONE:
return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access
if not distribute_lib.get_strategy().extended._use_merge_call(): # pylint: disable=protected-access
# Don't allow MEAN with non float dtype, since it may cause unexpected
# precision loss. Python3 and NumPy automatically upcast integers to
# float in division, but we should always preserve the type.
if var.aggregation == vs.VariableAggregation.MEAN and (
not var.dtype.is_floating) and tensor_util.is_tf_type(value):
raise ValueError(
"Cannot update non-float variables with "
"tf.VariableAggregation.MEAN aggregation in replica context. "
"Either change the variable dtype to float or update it in "
"cross-replica context.")
aggregated_value = apply_aggregation_replica_context(
value, var.aggregation, var)
values_util.mark_as_unsaveable()
return distribute_lib.get_replica_context()._update( # pylint: disable=protected-access
var,
update_fn,
args=(aggregated_value,),
kwargs=kwargs,
group=True)
else:
def merge_fn(strategy, value, **kwargs):
"""Aggregate values and update all variables in cross replica context."""
# Don't allow MEAN with non float dtype, since it may cause unexpected
# precision loss. Python3 and NumPy automatically upcast integers to
# float in division, but we should always preserve the type.
#
# Note that, to be backward compatible, we allow the case where the value
# is *always* the same on each replica, i.e. the value is not a
# PerReplica. Refer to regroup() to see how values are grouped.
if var.aggregation == vs.VariableAggregation.MEAN and (
not var.dtype.is_floating) and isinstance(value, PerReplica):
raise ValueError(
"Cannot update non-float variables with "
"tf.VariableAggregation.MEAN aggregation in replica context. "
"Either change the variable dtype to float or update it in "
"cross-replica context.")
assert strategy == var.distribute_strategy
v = values_util.apply_aggregation(strategy, value, var.aggregation, var)
return var._update_cross_replica(update_fn, v, **kwargs) # pylint: disable=protected-access
return distribute_lib.get_replica_context().merge_call(
merge_fn, args=(value,), kwargs=kwargs)
def apply_aggregation_replica_context(value, aggregation, destinations):
"""Aggregate `value` to `destinations` as specified by `aggregation`."""
# If `value` is a Python literal, return it without aggregation.
if isinstance(value, DistributedValues):
raise TypeError(
"Cannot use DistributedValues to update variables in replica context.")
if not tensor_util.is_tf_type(value):
return value
if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
# Switch to cross-replica context to broadcast
def merge_fn(strategy, value):
return strategy.extended.broadcast_to(
strategy.experimental_local_results(value)[0],
destinations=destinations)
return distribute_lib.get_replica_context().merge_call(
merge_fn, args=(value,))
else:
reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
aggregated_value = distribute_lib.get_strategy( # pylint: disable=protected-access
).extended._replica_ctx_all_reduce(reduce_op, value)
return aggregated_value
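# Illustrative sketch (not part of this module, assuming `import tensorflow as
# tf`): a replica-context assignment to a mirrored variable goes through the
# aggregation path above; MEAN/SUM map to an all-reduce, ONLY_FIRST_REPLICA
# broadcasts replica 0's value. The device list is hypothetical.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     v = tf.Variable(0.0, aggregation=tf.VariableAggregation.MEAN)
#   @tf.function
#   def step():
#     rid = tf.distribute.get_replica_context().replica_id_in_sync_group
#     v.assign(tf.cast(rid, tf.float32) + 1.0)  # replicas propose 1.0 and 2.0
#   strategy.run(step)  # MEAN aggregation leaves v == 1.5 on every replica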
class DistributedValues(ds_types.DistributedValues):
"""Base class for representing distributed values."""
def __init__(self, values):
"""Should only be called by subclass __init__."""
self._values = tuple(values)
def _get(self):
"""Returns the value for the current device or raises a ValueError."""
replica_id = values_util.get_current_replica_id_as_int()
if replica_id is None:
return self._get_cross_replica()
else:
return self._values[replica_id]
def _get_cross_replica(self):
raise NotImplementedError(
"DistributedValues._get_cross_replica should be implemented by "
"sub-classes which support cross-replica accesses.")
def _get_on_device_or_primary(self):
"""Returns value in same replica or device if possible, else the _primary."""
replica_id = values_util.get_current_replica_id_as_int()
if replica_id is None:
# Try to find a value on the current device.
current_device = device_util.canonicalize(device_util.current())
for value in self._values:
if device_util.canonicalize(value.device) == current_device:
return value
return self._primary
else:
return self._values[replica_id]
@property
def _primary(self):
"""Returns a representative component."""
return self._values[0]
@property
def _devices(self):
return tuple(v.device for v in self._values)
def __str__(self):
debug_str = ",\n".join(
" %d: %s" % (i, v) for i, v in enumerate(self._values))
return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)
def __repr__(self):
debug_repr = ",\n".join(
" %d: %r" % (i, v) for i, v in enumerate(self._values))
return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)
# NOTE(josh11b,apassos): It would be great if we could inspect the values this was
# initialized with and use that to generate the overloaded operators here.
# Unfortunately, Python's rules for special methods don't allow this, see
# https://docs.python.org/3/reference/datamodel.html#special-method-names
# "if a class defines a method named __getitem__(), and x is an instance of
# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)."
# In particular, these special methods don't go through __getattr__, and
# it will only use those methods if they are defined in the class, not the
# object.
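# For example (illustrative), a pure `__getattr__` delegate cannot support
# indexing, which is why the explicit operator methods below are needed:
#
#   class Delegate:
#     def __init__(self, target): self._target = target
#     def __getattr__(self, name): return getattr(self._target, name)
#
#   d = Delegate([1, 2, 3])
#   d.count(2)  # OK: resolved through __getattr__.
#   d[0]        # TypeError: __getitem__ is looked up on type(d) and is absent.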
class DistributedDelegate(DistributedValues):
"""A map from device to values; acts as the same type as the values."""
def __getattr__(self, name):
# The '_use_resource_variables' attribute and attributes starting with
# '_self' are used for restoring the saved_model proto, and
# '_attribute_sentinel' is used for Layer tracking. At the point these attrs
# are queried, the variable has not been initialized. Thus it should not
# query those of the underlying components.
if name.startswith("_self_") or name in ("_use_resource_variables",
"_attribute_sentinel",
"_distributed_container"):
return super(DistributedDelegate, self).__getattr__(name)
# This allows copy.copy(DistributedDelegate). When copying an object,
# copy.copy doesn't invoke its __init__ method, instead it makes a new
# empty object, then copies the attributes over. copy.copy looks for
# attributes like "__getstate__" in case the object implements its custom
# copying. Since DistributedDelegate doesn't have those attributes defined,
# __getattr__ will be invoked, which tries to access "_values" attributes,
# but that doesn't exist either because this is an empty object, and again
# __getattr__ is invoked, leading to an infinite recursion.
if name == "_values":
raise AttributeError()
# TODO(priyag): This needs to be made robust against pitfalls from mix use
# __getattr__ and @property. See b/120402273.
return getattr(self._get(), name)
@property
def values(self):
"""Returns the per replica values."""
return self._values
def _get_as_operand(self):
"""Returns the value for operations for the current device.
Some implementations, e.g. `TPUMirroredVariable`, are not able to return the
value type within a replica context. They can, however, return a value that
can be used by the operations below.
"""
return self._get()
# pylint: disable=multiple-statements
def __add__(self, o):
return self._get_as_operand() + o
def __radd__(self, o):
return o + self._get_as_operand()
def __sub__(self, o):
return self._get_as_operand() - o
def __rsub__(self, o):
return o - self._get_as_operand()
def __mul__(self, o):
return self._get_as_operand() * o
def __rmul__(self, o):
return o * self._get_as_operand()
def __truediv__(self, o):
return self._get_as_operand() / o
def __rtruediv__(self, o):
return o / self._get_as_operand()
def __floordiv__(self, o):
return self._get_as_operand() // o
def __rfloordiv__(self, o):
return o // self._get_as_operand()
def __mod__(self, o):
return self._get_as_operand() % o
def __rmod__(self, o):
return o % self._get_as_operand()
def __lt__(self, o):
return self._get_as_operand() < o
def __le__(self, o):
return self._get_as_operand() <= o
def __gt__(self, o):
return self._get_as_operand() > o
def __ge__(self, o):
return self._get_as_operand() >= o
def __and__(self, o):
return self._get_as_operand() & o
def __rand__(self, o):
return o & self._get_as_operand()
def __or__(self, o):
return self._get_as_operand() | o
def __ror__(self, o):
return o | self._get_as_operand()
def __xor__(self, o):
return self._get_as_operand() ^ o
def __rxor__(self, o):
return o ^ self._get_as_operand()
def __getitem__(self, o):
return self._get_as_operand()[o]
def __pow__(self, o, modulo=None):
return pow(self._get_as_operand(), o, modulo)
def __rpow__(self, o):
return pow(o, self._get_as_operand())
def __invert__(self):
return ~self._get_as_operand()
def __neg__(self):
return -self._get_as_operand()
def __abs__(self):
return abs(self._get_as_operand())
def __div__(self, o):
try:
return self._get_as_operand().__div__(o)
except AttributeError:
# See https://docs.python.org/3/library/constants.html#NotImplemented
return NotImplemented
def __rdiv__(self, o):
try:
return self._get_as_operand().__rdiv__(o)
except AttributeError:
# See https://docs.python.org/3/library/constants.html#NotImplemented
return NotImplemented
def __matmul__(self, o):
try:
return self._get_as_operand().__matmul__(o)
except AttributeError:
# See https://docs.python.org/3/library/constants.html#NotImplemented
return NotImplemented
def __rmatmul__(self, o):
try:
return self._get_as_operand().__rmatmul__(o)
except AttributeError:
# See https://docs.python.org/3/library/constants.html#NotImplemented
return NotImplemented
# TODO(josh11b): Even more operator overloads.
class PerReplica(DistributedValues, composite_tensor.CompositeTensor,
ds_types.PerReplica):
"""Holds a map from replica to unsynchronized values."""
@property
def _type_spec(self):
return PerReplicaSpec(
*(type_spec.type_spec_from_value(v) for v in self._values))
@property
def values(self):
"""Returns the per replica values."""
return self._values
def _per_replica_to_tensor(var, dtype=None, name=None, as_ref=False):
"""Converts a `PerReplica` to a `Tensor`."""
del name
if dtype is not None and not dtype.is_compatible_with(var.dtype):
raise ValueError(
"Incompatible type conversion requested to type {!r} for variable "
"of type {!r}".format(dtype.name, var.dtype.name))
if as_ref:
raise NotImplementedError(
"PerReplica doesn't support being used as a reference.")
if (distribute_lib.in_cross_replica_context() or
not distribute_lib.has_strategy()):
raise ValueError("It looks like you are using a PerReplica object while "
"not inside a replica context, which is not supported. "
"Try running your op or function inside a replica context "
"by using `strategy.run`")
else:
replica_id = values_util.get_current_replica_id_as_int()
return var.values[replica_id]
# Register a conversion function to provide a useful error message when users
# try to use PerReplica values in the wrong contexts.
tensor_conversion_registry.register_tensor_conversion_function(
PerReplica, _per_replica_to_tensor)
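# Illustrative sketch (assuming `import tensorflow as tf`): `strategy.run`
# returns a PerReplica when replicas produce different values. Treating it as a
# single tensor in cross-replica context triggers the error registered above;
# use `strategy.reduce` or `experimental_local_results` instead.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   per_replica = strategy.run(
#       lambda: tf.distribute.get_replica_context().replica_id_in_sync_group)
#   # tf.add(per_replica, 1)  # would raise the ValueError above
#   total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
#   local_values = strategy.experimental_local_results(per_replica)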
class PerReplicaSpec(type_spec.TypeSpec):
"""Type specification for a `PerReplica`."""
__slots__ = ["_value_specs"]
value_type = property(lambda self: PerReplica)
def __init__(self, *value_specs):
self._value_specs = tuple(value_specs)
def _serialize(self):
return self._value_specs
@property
def _component_specs(self):
return self._value_specs
def _to_components(self, value):
replica_context = distribute_lib.get_replica_context()
if replica_context is not None and replica_context.num_replicas_in_sync > 1:
raise ValueError(
"Flattening a PerReplica to components is not supported in replica "
"context.")
return value._values # pylint: disable=protected-access
def _from_components(self, tensor_list):
return PerReplica(tensor_list)
nested_structure_coder.register_codec(
nested_structure_coder.BuiltInTypeSpecCodec(
PerReplicaSpec, struct_pb2.TypeSpecProto.PER_REPLICA_SPEC
)
)
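# Continuing the sketch above (same assumptions): PerReplicaSpec is what
# `tf.type_spec_from_value` reports for such a value, which is how PerReplica
# objects are described when they cross `tf.function` boundaries.
#
#   spec = tf.type_spec_from_value(per_replica)  # -> PerReplicaSpec(...)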
# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
# TODO(tomhennigan) Should this extend CompositeTensor?
class Mirrored(DistributedDelegate, ds_types.Mirrored):
"""Holds a map from replica to values which are kept in sync."""
def _get_cross_replica(self):
return self._get_on_device_or_primary()
def _as_graph_element(self):
obj = self._get()
conv_fn = getattr(obj, "_as_graph_element", None)
if conv_fn and callable(conv_fn):
return conv_fn()
return obj
def _is_mirrored(self):
return True
class DistributedVarOp(object):
"""A class that looks like `tf.Operation`."""
def __init__(self, name, graph, traceback, typ):
self.name = name
self.graph = graph
self.traceback = traceback
self.type = typ
def __eq__(self, o):
if not isinstance(o, self.__class__):
raise NotImplementedError
return (self.name == o.name and self.graph == o.graph and
self.traceback == o.traceback and self.type == o.type)
def __hash__(self):
return hash((self.name, self.graph, tuple(self.traceback), self.type))
# TODO(b/209081027): Remove this once Variable is a CompositeTensor.
class DistributedVariableTraceType(trace.TraceType):
"""TraceType of DistributedVariable objects."""
def __init__(self, distributed_variable):
self.distributed_variable = distributed_variable
self.components = (tuple(distributed_variable.shape.as_list()),
distributed_variable.dtype)
def is_subtype_of(self, other):
return self == other
def most_specific_common_supertype(self, others):
return self if all(self == other for other in others) else None
def placeholder_value(self, placeholder_context=None):
return self.distributed_variable
def _to_tensors(self, value):
return []
def __hash__(self) -> int:
return hash(self.components)
def __eq__(self, other) -> bool:
if not isinstance(other, DistributedVariableTraceType):
return False
return self.components == other.components
class DistributedVariable(DistributedDelegate, variables_lib.Variable,
core.Tensor):
"""Holds a map from replica to variables."""
def __init__(self, strategy, values, aggregation, var_policy=None):
if (aggregation == variables_lib.VariableAggregation.MEAN and
not values[0].dtype.is_floating):
raise ValueError(
"creating distributed tf.Variable with aggregation=MEAN and a "
"non-floating dtype is not supported, please use a different "
"aggregation or dtype")
self._distribute_strategy = strategy
self._aggregation = aggregation
super(DistributedVariable, self).__init__(values)
self._common_name = self._primary.name.split(":")[0]
# Use a weakref to make it easy to map from the contained values
# to the container without introducing a reference cycle.
for v in values:
# ResourceVariable is a CompositeTensor. Attributes added to
# CompositeTensors will get lost through tf.nest packing and unpacking.
if isinstance(v, composite_tensor.CompositeTensor) and hasattr(
v, "handle"):
v.handle._distributed_container = weakref.ref(self) # pylint: disable=protected-access
else:
v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
# A packed variable is used to reduce the overhead of function execution.
# For a DistributedVariable, only one variable handle is captured into a
# function graph. It's only supported in eager mode.
if ops.executing_eagerly_outside_functions() and getattr(
strategy, "_enable_packed_variable_in_eager_mode", False):
name = "%s/packed/" % self._common_name
if hasattr(values[0], "_vars"):
# Handle when the resource variables are "nested" underneath another
# layer of values, e.g., TPUReplicatedVariable, by packing them all
# together and pushing the packed var down a level.
# pylint: disable=protected-access
packed_var = packed.PackedDistributedVariable(
sum((value._vars for value in values), []), name=name)
for value in values:
value._packed_var = packed_var
self._packed_var = None
# pylint: enable=protected-access
else:
self._packed_var = packed.PackedDistributedVariable(values, name=name)
else:
self._packed_var = None
# tf.keras uses this attribute to keep track of which variables have been
# initialized. When tf.keras gets the default session, it initializes all
# uninitialized variables. We need to make _keras_initialized a member of
# DistributedVariable because otherwise attribute lookup falls through to
# `__getattr__`, which delegates to a component variable.
self._keras_initialized = False
# Typically, a `DistributedVariable`'s initializer is composed of the
# initializers of the component variables. However, in some cases, such as
# when restoring from a checkpoint, we may set the _initializer_op
# property on the entire `DistributedVariable`.
self._initializer_op = None
# Set a VariablePolicy which decides how we replicate/aggregate the given
# variable.
self._policy = var_policy
def __deepcopy__(self, memo):
"""Perform a deepcopy of the `DistributedVariable`.
Unlike the deepcopy of a regular tf.Variable, this keeps the original
strategy and devices of the `DistributedVariable`. To avoid confusion
with the behavior of deepcopy on a regular `Variable` (which does
copy into new devices), we only allow a deepcopy of a `DistributedVariable`
within its originating strategy scope.
Args:
memo: The memoization object for `deepcopy`.
Returns:
A deep copy of the current `DistributedVariable`.
Raises:
RuntimeError: If trying to deepcopy into a different strategy.
"""
with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
new_values = []
for value in self._values:
with ops.device(value.device):
new_values.append(copy.deepcopy(value, memo))
copied_variable = type(self)(
strategy=self._distribute_strategy,
values=new_values,
aggregation=self._aggregation,
var_policy=copy.deepcopy(self._policy, memo))
memo[id(self)] = copied_variable
return copied_variable
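# Illustrative sketch (assuming `import copy` and `import tensorflow as tf`):
# deepcopy must run under the variable's own strategy scope and the copy keeps
# the same per-replica devices; only the buffers are new.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     v = tf.Variable(1.0)
#     v_copy = copy.deepcopy(v)  # same strategy and devices as `v`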
def _use_packed_variable(self):
# Don't use packed variable when under a SaveContext to avoid explicit
# device placement on variable consuming ops.
return self._packed_var is not None and (
not values_util.is_saving_non_distributed())
def is_initialized(self, name=None):
"""Identifies if all the component variables are initialized.
Args:
name: Name of the final `logical_and` op.
Returns:
The op that evaluates to True or False depending on if all the
component variables are initialized.
"""
if values_util.is_saving_non_distributed():
return self._primary.is_initialized()
if self._use_packed_variable():
return self._packed_var.is_initialized()
result = self._primary.is_initialized()
# We fold in the component values one at a time, leaving the last one out of
# the loop so that the final `logical_and` op can be given the name the user
# passed to `is_initialized`. For distributed variables, the `is_initialized`
# op is a `logical_and` op.
for v in self._values[1:-1]:
result = math_ops.logical_and(result, v.is_initialized())
result = math_ops.logical_and(
result, self._values[-1].is_initialized(), name=name)
return result
@property
def initializer(self):
if values_util.is_saving_non_distributed():
return self._primary.initializer
if self._initializer_op:
init_op = self._initializer_op
else:
# Return a grouped op of the initializers of all component values of
# the mirrored variable.
init_op = control_flow_ops.group(
tuple(v.initializer for v in self._values))
return init_op
def initialized_value(self):
return self._get_on_device_or_primary().initialized_value()
def _is_mirrored(self):
return (self._policy is not None) and (self._policy._is_mirrored()) # pylint: disable=protected-access
@property
def initial_value(self):
return self._get_on_device_or_primary().initial_value
@property
def constraint(self):
return self._primary.constraint
@property
def graph(self):
return self._primary.graph
@property
def _shared_name(self):
return self._common_name
@property
def _unique_id(self):
return self._primary._unique_id # pylint: disable=protected-access
@property
def _graph_key(self):
"""Lets Optimizers know which graph this variable is from."""
return self._primary._graph_key # pylint: disable=protected-access
@property
def name(self):
return self._primary.name
@property
def dtype(self):
return self._primary.dtype
@property
def shape(self):
return self._primary.shape
@property
def synchronization(self):
return self._primary.synchronization
@property
def aggregation(self):
return self._aggregation
@property
def _packed_variable(self):
if self._use_packed_variable():
return self._packed_var
return None
@property
def handle(self):
if values_util.is_saving_non_distributed():
return self._primary.handle
replica_id = values_util.get_current_replica_id_as_int()
if replica_id is None:
raise ValueError(
"DistributedVariable.handle is not available outside the replica "
"context or a `tf.distribute.Strategy.update()` call.")
else:
if self._use_packed_variable():
return self._packed_var.handle
return self._values[replica_id].handle
def eval(self, session=None):
return self._get_on_device_or_primary().eval(session)
@property
def _save_slice_info(self):
return self._primary._save_slice_info # pylint: disable=protected-access
def _get_save_slice_info(self):
return self._primary._get_save_slice_info() # pylint: disable=protected-access
def _set_save_slice_info(self, save_slice_info):
for v in self._values:
v._set_save_slice_info(save_slice_info) # pylint: disable=protected-access
@property
def device(self):
return self._get_on_device_or_primary().device
@property
def trainable(self):
return self._primary.trainable
@property
def distribute_strategy(self):
return self._distribute_strategy
def get_shape(self):
return self._primary.get_shape()
def to_proto(self, export_scope=None):
return self._primary.to_proto(export_scope=export_scope)
@property
def op(self):
if values_util.is_saving_non_distributed():
return self._primary.op
# We want cross-replica code that does some var.op.X calls
# to work (even if the current device isn't in self._devices), but
# other uses of var.op in a cross-replica context to fail.
if distribute_lib.in_cross_replica_context():
return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
self._primary.op.traceback, self._primary.op.type)
return self._get().op
@property
def _in_graph_mode(self):
return self._primary._in_graph_mode # pylint: disable=protected-access
def _get_replica(self, replica_id):
"""Returns the value on a device with the given replica_id."""
value = self._values[replica_id]
if self._use_packed_variable():
return self._packed_var.on_device(value.device)
else:
return value
def _get(self):
"""Returns the value for the current device or raises a ValueError."""
if values_util.is_saving_non_distributed():
return self._primary
replica_id = values_util.get_current_replica_id_as_int()
if replica_id is None:
return self._get_cross_replica()
else:
return self._get_replica(replica_id)
def _get_on_device_or_primary(self):
"""Returns value in same replica or device if possible, else the _primary."""
if values_util.is_saving_non_distributed():
return self._primary
replica_id = values_util.get_current_replica_id_as_int()
if replica_id is None:
# Try to find a value on the current device.
current_device = device_util.canonicalize(device_util.current())
for i, value in enumerate(self._values):
if device_util.canonicalize(value.device) == current_device:
return self._get_replica(i)
return self._get_replica(0)
else:
return self._get_replica(replica_id)
def read_value(self):
if values_util.is_saving_non_distributed():
return self._primary.read_value()
with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
return array_ops.identity(self._get())
def value(self):
if values_util.is_saving_non_distributed():
return self._primary.value()
if self._policy:
return self._policy.value(self)
return self._get_on_device_or_primary().value()
def numpy(self):
if context.executing_eagerly():
return self.read_value().numpy()
else:
raise NotImplementedError("DistributedVariable.numpy() is only available "
"when eager execution is enabled.")
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
if values_util.is_saving_non_distributed():
return self._primary.assign_sub(value, use_locking, name, read_value)
if self._policy:
return self._policy.assign_sub(
self,
value,
use_locking=use_locking,
name=name,
read_value=read_value)
return values_util.on_write_assign_sub(
self, value, use_locking=use_locking, name=name, read_value=read_value)
def assign_add(self, value, use_locking=False, name=None, read_value=True):
if values_util.is_saving_non_distributed():
return self._primary.assign_add(value, use_locking, name, read_value)
if self._policy:
return self._policy.assign_add(
self,
value,
use_locking=use_locking,
name=name,
read_value=read_value)
return values_util.on_write_assign_add(
self, value, use_locking=use_locking, name=name, read_value=read_value)
def assign(self, value, use_locking=False, name=None, read_value=True):
if values_util.is_saving_non_distributed():
return self._primary.assign(value, use_locking, name, read_value)
if self._policy:
return self._policy.assign(
self,
value,
use_locking=use_locking,
name=name,
read_value=read_value)
return values_util.on_write_assign(
self, value, use_locking=use_locking, name=name, read_value=read_value)
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_sub(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_sub(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_sub(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_add(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_add(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_add(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_add(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_mul(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_mul(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_mul(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_div(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_div(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_div(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_div(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_min(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_min(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_min(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_min(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_max(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_max(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_max(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_max(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_update(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_update(sparse_delta, use_locking, name)
if self._policy:
return self._policy.scatter_update(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_update(
self, sparse_delta, use_locking=use_locking, name=name)
def __tf_tracing_type__(self, _):
return DistributedVariableTraceType(self)
def _gather_saveables_for_checkpoint(self):
"""Overrides Trackable method.
This allows both name-based and object-based save and restore of
DistributedVariables.
Returns:
A dictionary mapping attribute names to `SaveableObject` factories.
"""
def _saveable_factory(name=self._common_name):
return _DistributedVariableSaveable(self, self._primary, name)
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
def _as_graph_element(self):
if values_util.is_saving_non_distributed():
return self._primary._as_graph_element() # pylint: disable=protected-access
if self._policy:
return self._policy._as_graph_element(self) # pylint: disable=protected-access
raise NotImplementedError(
"DistributedVariable._as_graph_element requires a valid "
"VariablePolicy. Please set the policy via the `var_policy` argument "
"in the constructor, or override this method in sub-classes which "
"support cross-replica accesses.")
def _get_cross_replica(self):
if values_util.is_saving_non_distributed():
return self._primary
if self._policy:
return self._policy._get_cross_replica(self) # pylint: disable=protected-access
raise NotImplementedError(
"DistributedVariable._get_cross_replica requires a valid "
"VariablePolicy. Please set the policy via the `var_policy` argument "
"in the constructor, or override this method in sub-classes which "
"support cross-replica accesses.")
def _update_cross_replica(self, update_fn, value, **kwargs):
"""Applies updates across replicas.
Args:
update_fn: A callable to pass to `strategy.extended.update` to update the
variable. It should have the same signature as `Variable.assign()`.
value: value to be passed to `update_fn`.
**kwargs: remaining arguments to `update_fn`.
Returns:
Updated variable or `tf.Operation`.
"""
values_util.mark_as_unsaveable()
return self.distribute_strategy.extended.update(
self, update_fn, args=(value,), kwargs=kwargs, group=True)
def _update_replica(self, update_fn, value, **kwargs):
"""Applies updates in one replica.
Args:
update_fn: A callable to update the variable. It should have the same
signature as `Variable.assign()`.
value: value to be passed to `update_fn`.
**kwargs: remaining arguments to `update_fn`.
Returns:
Updated variable or `tf.Operation`.
"""
if self._policy:
return self._policy._update_replica(self, update_fn, value, **kwargs) # pylint: disable=protected-access
raise NotImplementedError(
"DistributedVariable._update_replica requires a valid VariablePolicy. "
"Please set the policy via the `var_policy` argument in the "
"constructor, or override this method in sub-classes which support "
"cross-replica accesses.")
def _update(self, update_fn, value, **kwargs):
"""Applies updates depending on the context.
The method calls `_update_replica` in replica context,
`_update_cross_replica` in cross replica context, and `update_fn` in update
context.
If `read_value` is True, the method returns the updated Variable. If
`read_value` is False, the method returns the update `tf.Operation`.
Args:
update_fn: A callable to pass to `strategy.extended.update` to update the
variable. It should have the same signature as `Variable.assign()`.
value: value to be passed to `update_fn`.
**kwargs: keyword arguments to `update_fn`.
Returns:
Updated variable or `tf.Operation`.
"""
if values_util.is_saving_non_distributed():
return update_fn(self._primary, value, **kwargs)
with distribute_lib.enter_or_assert_strategy(self.distribute_strategy):
if distribute_lib.in_cross_replica_context():
update_replica_id = distribute_lib.get_update_replica_id()
if update_replica_id is not None:
replica_value = self._get_replica(update_replica_id)
return update_fn(replica_value, value, **kwargs)
return self._update_cross_replica(update_fn, value, **kwargs)
else:
values_util.assert_replica_context(self.distribute_strategy)
return self._update_replica(update_fn, value, **kwargs)
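# Illustrative sketch of the three paths handled by _update (assuming
# `import tensorflow as tf`; the variable and values are hypothetical):
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     v = tf.Variable(0.0)
#   v.assign(3.0)                            # cross-replica: _update_cross_replica
#   strategy.run(lambda: v.assign_add(1.0))  # replica context: _update_replica
#   strategy.extended.update(
#       v, lambda var, t: var.assign(t), args=(5.0,))  # update ctx: update_fn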
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
pass
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
"""Converts a variable to a tensor."""
if values_util.is_saving_non_distributed():
return ops.convert_to_tensor(
self._primary, dtype=dtype, name=name, as_ref=as_ref)
with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
return ops.convert_to_tensor(
self._get(), dtype=dtype, name=name, as_ref=as_ref)
def __tf_tensor__(self,
dtype: Optional[dtypes.DType] = None,
name: Optional[str] = None) -> ops.Tensor:
return self._dense_var_to_tensor(dtype, name)
def _export_to_saved_model_graph(self,
object_map=None,
tensor_map=None,
options=None,
**kwargs):
# Initialize self._primary first, so that object_map[self._primary] and
# tensor_map[self._primary.handle] contain mapped values.
resource_list = self._primary._export_to_saved_model_graph( # pylint:disable=protected-access
object_map=object_map,
tensor_map=tensor_map,
options=options,
**kwargs)
for v in [v for v in self._values if v != self._primary]:
if (options.experimental_variable_policy # pylint:disable=protected-access
._expand_distributed_variables()):
resource_list.extend(
v._export_to_saved_model_graph( # pylint:disable=protected-access
object_map=object_map,
tensor_map=tensor_map,
options=options,
**kwargs)) # pylint:disable=protected-access
else:
object_map[v] = object_map[self._primary]
tensor_map[v.handle] = tensor_map[self._primary.handle]
resource_list.append(v.handle)
object_map[self] = object_map[self._primary]
tensor_map[self] = tensor_map[self._primary.handle]
resource_list.append(self)
if self._packed_var is not None:
tensor_map[self._packed_var.packed_handle] = tensor_map[
self._primary.handle]
resource_list.append(self._packed_var.packed_handle)
return resource_list
def _write_object_proto(self, proto, options):
"""Update a SavedObject proto for the caller.
If a DistributedVariable object supports this method, it will be called when
saving with a pre-built `SavedObject` proto representing the object, plus an
instance of `SaveOptions`. This method is then free to modify that proto
instance.
`DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally
write out information about their components to the
`experimental_distributed_variable_components` field of a
`SavedVariable` (depending on the `SaveOptions` variable policy).
Args:
proto: A pre-built `SavedObject` proto for this object. It is assumed this
will be a `SavedVariable` instance.
options: A `SaveOptions` instance.
"""
resource_variable_ops.write_object_proto_for_resource_variable(
self, proto, options)
if self._is_mirrored():
values_util.write_object_proto(self, proto, options)
@property
def is_distributed_variable(self):
return True
def __tf_experimental_restore_capture__(
self, concrete_function, internal_capture):
graph = concrete_function.graph
# Add given distributed variable to captures with given placeholder.
graph.replace_capture(self, internal_capture)
record.record_operation(
"captured_value", [internal_capture], [self],
backward_function=lambda x: [x],
forward_function=lambda x: [x])
return self
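# Illustrative sketch (assuming `import tensorflow as tf`; the path is
# hypothetical): checkpointing a distributed variable goes through the saveable
# machinery defined below and writes a single value that is restored into
# every component.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     v = tf.Variable(1.0)
#   ckpt = tf.train.Checkpoint(v=v)
#   path = ckpt.save("/tmp/ckpt")  # one value saved for all replicas
#   ckpt.restore(path)             # restored into each component variable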
# We extend from `saveable_object.SaveableObject` instead of
# `saveable_object_util.ResourceVariableSaveable` since we need to read the
# value of ON_READ variables when saving. `SaveableObject` provides a way to
# specify the function to run to get the value of the variable or tensor at
# saving time. We can use this for both ON_READ and ON_WRITE variables.
# TODO(b/164586507): Consolidate ON_WRITE and ON_READ saving/restoring logic
# if possible.
class _DistributedVariableSaveable(saveable_object.SaveableObject):
"""Class for defining how to restore a DistributedVariable."""
def __init__(self, distributed_variable, primary_variable, name):
self._distributed_variable = distributed_variable
if not self._distributed_variable._policy:
raise ValueError(
"The VariablePolicy of the argument `distributed_variable` must be "
"set to create a _DistributedVariableSaveable. Please set it via "
"the `var_policy` argument in the constructor of DistributedVariable."
)
tensor, spec = distributed_variable._policy.get_saveable(
distributed_variable, primary_variable, name)
super(_DistributedVariableSaveable, self).__init__(tensor, spec, name)
def restore(self, restored_tensors, restored_shapes):
"""Restore the same value into all variables."""
tensor, = restored_tensors
return self._distributed_variable._policy.get_restore_ops( # pylint: disable=protected-access
self._distributed_variable, tensor)
class _MirroredSaveable(saveable_object.SaveableObject):
"""Class for defining how to restore a MirroredVariable."""
def __init__(self, mirrored_variable, primary_variable, name):
self._mirrored_variable = mirrored_variable
tensor, spec = values_util.get_on_write_saveable(self._mirrored_variable,
primary_variable, name)
super(_MirroredSaveable, self).__init__(tensor, spec, name)
def restore(self, restored_tensors, restored_shapes):
"""Restore the same value into all variables."""
tensor, = restored_tensors
return values_util.get_on_write_restore_ops(self._mirrored_variable, tensor)
class MirroredVariable(DistributedVariable, Mirrored):
"""Holds a map from replica to variables whose values are kept in sync."""
def _is_mirrored(self):
return Mirrored._is_mirrored(self) # Use correct parent class.
def _update_replica(self, update_fn, value, **kwargs):
return _on_write_update_replica(self, update_fn, value, **kwargs)
def scatter_min(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_min(*args, **kwargs)
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError(
values_util.scatter_error_msg.format(
op_name="scatter_min", aggregation=self._aggregation))
return super(MirroredVariable, self).scatter_min(*args, **kwargs)
def scatter_max(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_max(*args, **kwargs)
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError(
values_util.scatter_error_msg.format(
op_name="scatter_max", aggregation=self._aggregation))
return super(MirroredVariable, self).scatter_max(*args, **kwargs)
def scatter_update(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_update(*args, **kwargs)
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError(
values_util.scatter_error_msg.format(
op_name="scatter_update", aggregation=self._aggregation))
return super(MirroredVariable, self).scatter_update(*args, **kwargs)
def _get_cross_replica(self):
# Return identity, to avoid directly exposing the variable to the user and
# allowing it to be modified by mistake.
return array_ops.identity(Mirrored._get_cross_replica(self))
def _as_graph_element(self):
return self._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access
def _gather_saveables_for_checkpoint(self):
"""Overrides Trackable method.
This allows both name-based and object-based save and restore of
MirroredVariables.
Returns:
A dictionary mapping attribute names to `SaveableObject` factories.
"""
def _saveable_factory(name=self._common_name):
return _MirroredSaveable(self, self._primary, name)
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
"""Converts a variable to a tensor."""
# TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ
# and ON_WRITE.
# Try to avoid assignments to and other mutations of MirroredVariable
# state except through a DistributionStrategy.extended.update() or any of
# the `assign*` and `scatter*` calls.
if as_ref:
# A TF 1.x case where the variable is a boolean and is used like:
# tf.cond(v, true_fn, false_fn).
raise ValueError(
"You may be using variable created under distribute strategy in TF "
"1.x control flows. Try explicitly converting the variable to Tensor "
"using variable.read_value(), or switch to TF 2.x.")
return ops.convert_to_tensor(
self._get(), dtype=dtype, name=name, as_ref=as_ref)
class _SyncOnReadSaveable(saveable_object.SaveableObject):
"""Class for defining how to restore a SyncOnReadVariable."""
def __init__(self, sync_on_read_variable, name):
self._sync_on_read_variable = sync_on_read_variable
tensor, spec = values_util.get_on_read_saveable(
sync_on_read_variable, sync_on_read_variable._primary, name)
super(_SyncOnReadSaveable, self).__init__(tensor, spec, name)
def restore(self, restored_tensors, restored_shapes):
"""Restore the same value into all variables."""
tensor, = restored_tensors
return values_util.get_on_read_restore_ops(
self._sync_on_read_variable, tensor,
self._sync_on_read_variable.aggregation)
class SyncOnReadVariable(DistributedVariable):
"""Holds a map from replica to variables whose values are reduced on save."""
def _update_replica(self, update_fn, value, **kwargs):
return update_fn(self._get_on_device_or_primary(), value, **kwargs)
def _get(self):
"""Returns the value of SyncOnReadVariable based on surrounding context.
If called under a non-default replica-context, returns the corresponding
variable on that replica.
If called under default replica-context or cross-replica context, returns
the synced value.
"""
with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
return super(SyncOnReadVariable, self)._get()
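# Illustrative sketch of an ON_READ variable, e.g. a metric accumulator
# (assuming `import tensorflow as tf`): each replica updates its own copy in
# replica context, and a cross-replica read aggregates per `aggregation`.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     total = tf.Variable(
#         0.0,
#         synchronization=tf.VariableSynchronization.ON_READ,
#         aggregation=tf.VariableAggregation.SUM)
#   strategy.run(lambda: total.assign_add(1.0))  # each replica adds locally
#   total.read_value()                           # cross-replica read -> 2.0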
# TODO(b/154017756): Make assign behavior in cross replica context consistent
# with MirroredVariable.
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
if values_util.is_saving_non_distributed():
return self._primary.assign_sub(value, use_locking, name, read_value)
with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
if (distribute_lib.in_cross_replica_context() and
not values_util.in_replica_update_context()):
values_util.mark_as_unsaveable()
return values_util.on_read_assign_sub_cross_replica(
self, value, read_value=read_value)
else:
return super(SyncOnReadVariable,
self).assign_sub(value, use_locking, name, read_value)
def assign_add(self, value, use_locking=False, name=None, read_value=True):
if values_util.is_saving_non_distributed():
return self._primary.assign_add(value, use_locking, name, read_value)
with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
if (distribute_lib.in_cross_replica_context() and
not values_util.in_replica_update_context()):
values_util.mark_as_unsaveable()
return values_util.on_read_assign_add_cross_replica(
self, value, read_value=read_value)
else:
return super(SyncOnReadVariable,
self).assign_add(value, use_locking, name, read_value)
def assign(self, value, use_locking=False, name=None, read_value=True):
if values_util.is_saving_non_distributed():
return self._primary.assign(value, use_locking, name, read_value)
with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
if (distribute_lib.in_cross_replica_context() and
not values_util.in_replica_update_context()):
values_util.mark_as_unsaveable()
return values_util.on_read_assign_cross_replica(
self, value, read_value=read_value)
else:
return super(SyncOnReadVariable, self).assign(value, use_locking, name,
read_value)
def _scatter_not_implemented(self, method):
raise NotImplementedError(
f"Variables with `synchronization=ON_READ` doesn't support `{method}`")
def scatter_sub(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_sub(*args, **kwargs)
self._scatter_not_implemented("scatter_sub")
def scatter_add(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_add(*args, **kwargs)
self._scatter_not_implemented("scatter_add")
def scatter_mul(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_mul(*args, **kwargs)
self._scatter_not_implemented("scatter_mul")
def scatter_div(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_div(*args, **kwargs)
self._scatter_not_implemented("scatter_div")
def scatter_min(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_min(*args, **kwargs)
self._scatter_not_implemented("scatter_min")
def scatter_max(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_max(*args, **kwargs)
self._scatter_not_implemented("scatter_max")
def scatter_update(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_update(*args, **kwargs)
self._scatter_not_implemented("scatter_update")
def value(self):
if distribute_lib.in_variable_sync_on_read_context():
raise NotImplementedError(
"call `variable.value()` inside variable_sync_on_read_context is not "
"supported")
if values_util.is_saving_non_distributed():
return self._primary.value()
with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
if (distribute_lib.in_cross_replica_context() and
not values_util.in_replica_update_context()):
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
return self._get_replica(0).value()
return self._get_cross_replica()
else:
# _get_on_device_or_primary() returns a Variable.
return self._get_on_device_or_primary().value()
def read_value(self):
if distribute_lib.in_variable_sync_on_read_context():
raise NotImplementedError(
"call `variable.read_value()` inside variable_sync_on_read_context is"
" not supported")
return super().read_value()
def _get_cross_replica(self):
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
# Consider returning a tensor value here to make the return value of
# _get_cross_replica consistent.
return self._get_replica(0)
if self._aggregation == vs.VariableAggregation.SUM:
values_util.mark_as_unsaveable()
with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
return self._distribute_strategy.reduce(
reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
self,
axis=None)
def _as_graph_element(self):
if values_util.is_saving_non_distributed():
return self._primary._as_graph_element() # pylint: disable=protected-access
# pylint: disable=protected-access
with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
if distribute_lib.in_cross_replica_context():
return ops.convert_to_tensor(self._get_cross_replica())
return self._get()._as_graph_element()
def _gather_saveables_for_checkpoint(self):
"""Overrides Trackable method.
This allows both name-based and object-based save and restore of
`SyncOnReadVariable`s.
Returns:
A dictionary mapping attribute names to `SaveableObject` factories.
"""
def _saveable_factory(name=self._common_name):
return _SyncOnReadSaveable(self, name)
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
"""Converts a SyncOnReadVariable to a tensor."""
if values_util.is_saving_non_distributed():
return ops.convert_to_tensor(
self._primary, dtype=dtype, name=name, as_ref=as_ref)
with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
replica_context = distribute_lib.get_replica_context()
if (replica_context is not None and
distribute_lib.in_variable_sync_on_read_context()):
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
return ops.convert_to_tensor(
self._get_replica(0), dtype=dtype, name=name, as_ref=as_ref)
if self._aggregation == vs.VariableAggregation.SUM:
values_util.mark_as_unsaveable()
# pylint: disable=protected-access
reduced = (
replica_context.strategy.extended._replica_ctx_all_reduce(
reduce_util.ReduceOp.from_variable_aggregation(
self._aggregation),
self._get().read_value()))
return ops.convert_to_tensor(
reduced, dtype=dtype, name=name, as_ref=as_ref)
return ops.convert_to_tensor(
self._get(), dtype=dtype, name=name, as_ref=as_ref)
# Register conversion functions which read the value of the variable,
# allowing instances of these classes to be used as tensors.
# DistributedVariable
def _tensor_conversion_distributed_var(var,
dtype=None,
name=None,
as_ref=False):
return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
tensor_conversion_registry.register_tensor_conversion_function(
DistributedVariable, _tensor_conversion_distributed_var)
# MirroredVariables
def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
tensor_conversion_registry.register_tensor_conversion_function(
MirroredVariable, _tensor_conversion_mirrored)
# Mirrored Values
def _tensor_conversion_mirrored_val(value, dtype=None, name=None, as_ref=False):
return ops.convert_to_tensor(
value._get(), dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
tensor_conversion_registry.register_tensor_conversion_function(
Mirrored, _tensor_conversion_mirrored_val)
# SyncOnReadVariables
def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
tensor_conversion_registry.register_tensor_conversion_function(
SyncOnReadVariable, _tensor_conversion_sync_on_read)
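# Illustrative sketch (assuming `import tensorflow as tf`): because of the
# registrations above, distributed variables and Mirrored values can be passed
# directly to ops; conversion reads the local (or primary) component.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     v = tf.Variable(2.0)
#   t = tf.convert_to_tensor(v)  # goes through _dense_var_to_tensor
#   y = v * 3.0                  # regular ops accept the variable as input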
class VariablePolicy(object):
"""Policy defining synchronization and aggregation of a distributed variable.
Given `synchronization` and `aggregation` parameters set on a `tf.Variable`
during variable creation within `tf.distribute` scope, `tf.distribute` creates
an appropriate policy object and assigns it to the distributed variable. All
variable operations are delegated to the respective policy object.
"""
def __init__(self, aggregation):
self._aggregation = aggregation
def value(self):
raise NotImplementedError(
"VariablePolicy.value should be overriden by sub-classes.")
def _is_mirrored(self):
raise NotImplementedError(
"VariablePolicy._is_mirrored should be overriden by sub-classes.")
def _as_graph_element(self, _):
raise NotImplementedError(
"VariablePolicy._as_graph_element should be overriden by sub-classes.")
def _get_cross_replica(self, var):
raise NotImplementedError(
"VariablePolicy._get_cross_replica should be overriden by sub-classes.")
def _update_replica(self, var, update_fn, value, **kwargs):
raise NotImplementedError(
"VariablePolicy._update_replica should be overriden by sub-classes.")
class OnReadPolicy(VariablePolicy):
"""Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.
This policy is created when `synchronization` is set to
`tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
values allowed by the `tf.VariableAggregation` enum, such as `NONE`, `SUM`,
`MEAN` or `ONLY_FIRST_REPLICA`, when creating a `tf.Variable` in `tf.distribute`
scope.
"""
def _is_mirrored(self):
return False
def value(self, var):
with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
if (distribute_lib.in_cross_replica_context() and
not values_util.in_replica_update_context()):
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
return var._get_replica(0).value() # pylint: disable=protected-access
return var._get_cross_replica() # pylint: disable=protected-access
else:
return var._get_on_device_or_primary().value() # pylint: disable=protected-access
def _as_graph_element(self, var):
with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
if distribute_lib.in_cross_replica_context():
return ops.convert_to_tensor(var._get_cross_replica()) # pylint: disable=protected-access
return var._get()._as_graph_element() # pylint: disable=protected-access
def _get_cross_replica(self, var):
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
return var._get_replica(0) # pylint: disable=protected-access
if self._aggregation == vs.VariableAggregation.SUM:
values_util.mark_as_unsaveable()
with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
return var.distribute_strategy.reduce(
reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
var,
axis=None)
def _update_replica(self, var, update_fn, value, **kwargs):
return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access
def _scatter_not_implemented(self, method):
raise NotImplementedError(f"ON_READ variables doesn't support `{method}` "
"in cross replica context")
def assign_sub(self,
var,
value,
use_locking=False,
name=None,
read_value=True):
"""Subtracts a value from this variable."""
with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
if (distribute_lib.in_cross_replica_context() and
not values_util.in_replica_update_context()):
values_util.mark_as_unsaveable()
return values_util.on_read_assign_sub_cross_replica(
var, value, read_value=read_value)
else:
return values_util.on_write_assign_sub(
var,
value,
use_locking=use_locking,
name=name,
read_value=read_value)
def assign_add(self,
var,
value,
use_locking=False,
name=None,
read_value=True):
"""Adds a value to this variable."""
with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
if (distribute_lib.in_cross_replica_context() and
not values_util.in_replica_update_context()):
values_util.mark_as_unsaveable()
return values_util.on_read_assign_add_cross_replica(
var, value, read_value=read_value)
else:
return values_util.on_write_assign_add(
var,
value,
use_locking=use_locking,
name=name,
read_value=read_value)
def assign(self, var, value, use_locking=False, name=None, read_value=True):
with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
if (distribute_lib.in_cross_replica_context() and
not values_util.in_replica_update_context()):
values_util.mark_as_unsaveable()
return values_util.on_read_assign_cross_replica(
var, value, read_value=read_value)
else:
return values_util.on_write_assign(
var,
value,
use_locking=use_locking,
name=name,
read_value=read_value)
def scatter_sub(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_sub")
def scatter_add(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_add")
def scatter_mul(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_mul")
def scatter_div(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_div")
def scatter_min(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_min")
def scatter_max(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_max")
def scatter_update(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_update")
def get_saveable(self, var, primary_var, name):
"""Create a saveable object for the given variable."""
return values_util.get_on_read_saveable(var, primary_var, name)
def get_restore_ops(self, var, tensor):
"""Restore the same value into all variables."""
return values_util.get_on_read_restore_ops(var, tensor, self._aggregation)
class OnWritePolicy(VariablePolicy):
"""Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.
This policy is created when `synchronization` is set to
`tf.VariableSynchronization.ON_WRITE` or `tf.VariableSynchronization.AUTO`
when creating a `tf.Variable` in `tf.distribute` scope.
"""
def _is_mirrored(self):
return True
def value(self, var):
return var._get_on_device_or_primary().value() # pylint: disable=protected-access
def _as_graph_element(self, var):
return var._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access
def _get_cross_replica(self, var):
# Return identity, to avoid directly exposing the variable to the user and
# allowing it to be modified by mistake.
return array_ops.identity(var._get_on_device_or_primary()) # pylint: disable=protected-access
def _update_replica(self, var, update_fn, value, **kwargs):
if var.aggregation == variables_lib.VariableAggregation.NONE:
return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access
return _on_write_update_replica(var, update_fn, value, **kwargs)
def assign(self, var, value, use_locking=False, name=None, read_value=True):
return values_util.on_write_assign(
var, value, use_locking=use_locking, name=name, read_value=read_value)
def assign_add(self,
var,
value,
use_locking=False,
name=None,
read_value=True):
return values_util.on_write_assign_add(
var, value, use_locking=use_locking, name=name, read_value=read_value)
def assign_sub(self,
var,
value,
use_locking=False,
name=None,
read_value=True):
return values_util.on_write_assign_sub(
var, value, use_locking=use_locking, name=name, read_value=read_value)
def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
return values_util.scatter_sub(
var, sparse_delta, use_locking=use_locking, name=name)
def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
return values_util.scatter_add(
var, sparse_delta, use_locking=use_locking, name=name)
def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
return values_util.scatter_mul(
var, sparse_delta, use_locking=use_locking, name=name)
def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
return values_util.scatter_div(
var, sparse_delta, use_locking=use_locking, name=name)
def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError(
values_util.scatter_error_msg.format(
op_name="scatter_min", aggregation=self._aggregation))
return values_util.scatter_min(
var, sparse_delta, use_locking=use_locking, name=name)
def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError(
values_util.scatter_error_msg.format(
op_name="scatter_max", aggregation=self._aggregation))
return values_util.scatter_max(
var, sparse_delta, use_locking=use_locking, name=name)
def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError(
values_util.scatter_error_msg.format(
op_name="scatter_update", aggregation=self._aggregation))
return values_util.scatter_update(
var, sparse_delta, use_locking=use_locking, name=name)
def get_saveable(self, var, primary_var, name):
"""Saveable ops for AUTO variables."""
return values_util.get_on_write_saveable(var, primary_var, name)
def get_restore_ops(self, var, tensor):
return values_util.get_on_write_restore_ops(var, tensor)
class PerWorkerResource():
"""A per-worker CapturableResource class for non-ParameterServer strategy.
Resources that populate `host_to_resources` should be instances of classes
subclassing CapturableResource, although currently it's only used and tested
for StaticHashTable with TPUStrategy.
"""
def __init__(self, strategy, host_to_resources):
distribute_lib.distribution_strategy_input_api_counter.get_cell(
"PerWorkerResource", "TPUDistributedLookupTable").increase_by(1)
self._strategy = strategy
self._host_to_resources = host_to_resources
def __getattribute__(self, name):
if name not in ("__init__", "__getattribute__", "_host_to_resources",
"_strategy", "local_resource"):
return getattr(self.local_resource(), name)
return super(PerWorkerResource, self).__getattribute__(name)
def __setattr__(self, name, value):
if name not in ("_strategy", "_host_to_resources"):
return setattr(self.local_resource(), name, value)
return super(PerWorkerResource, self).__setattr__(name, value)
def local_resource(self):
"""Returns the resource on the local worker."""
current_device = device_util.canonicalize(device_util.current())
host_device = device_util.canonicalize(
device_util.get_host_for_device(current_device))
return self._host_to_resources.get(
host_device,
self._host_to_resources[next(iter(self._host_to_resources))])