# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Variable class that is replicated to logical cores for model parallelism."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import abc
import contextlib
from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.distribute import tpu_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_tpu_partition_ops as tpu_partition_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.trackable import base as trackable
def _on_device_update(update_fn, var, value, **kwargs):
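  """Runs `update_fn` on `var` and `value` with the op placed on `var`'s device."""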
with ops.device(var.device):
return update_fn(var, value, **kwargs)
class TPUReplicatedVariable(variables_lib.Variable):
"""Container for replicated `Variables` that are treated as a single variable.
  This class maintains a list of replicated variables that are stored on
  separate logical TPU devices. The TF2XLA bridge accesses these variables
  as if they were a single variable.
"""
def __init__(self, variables, name='TPUReplicatedVariable'):
"""Treats `variables` as a replicated list of `tf.Variable`s.
Example:
```
variables = [
tf.Variable(..., shape=(10, 100), dtype=tf.float32),
tf.Variable(..., shape=(10, 100), dtype=tf.float32),
tf.Variable(..., shape=(10, 100), dtype=tf.float32),
tf.Variable(..., shape=(10, 100), dtype=tf.float32),
]
replicated_variable = TPUReplicatedVariable(variables)
assert replicated_variable.shape.as_list() == [10, 100]
```
Args:
variables: A list of `ResourceVariable`s that comprise this replicated
variable. Variables should not be shared between different
`TPUReplicatedVariable` objects.
name: String. Name of this container. Defaults to "TPUReplicatedVariable".
"""
if not isinstance(variables, abc.Sequence) or not variables or any(
not isinstance(v, variables_lib.Variable) for v in variables):
raise TypeError('Argument `variables` should be a non-empty list of '
f'`variables.Variable`s. Received {variables}')
if any(v.dtype != variables[0].dtype for v in variables):
raise ValueError(
'All elements in argument `variables` must have the same dtype. '
f'Received dtypes: {[v.dtype for v in variables]}')
if any(v.shape != variables[0].shape for v in variables):
raise ValueError(
'All elements in argument `variables` must have the same shape. '
f'Received shapes: {[v.shape for v in variables]}')
self._vars = variables
self._name = name
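    # The name with any ':<n>' suffix stripped; returned by `_shared_name`.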
self._common_name = self._name.split(':')[0]
self._cached_value = None
def __iter__(self):
"""Return an iterable for accessing the underlying sharded variables."""
return iter(self._vars)
@property
def name(self):
"""The name of this object. Used for checkpointing."""
return self._name
@property
def dtype(self):
"""The dtype of all `Variable`s in this object."""
return self._vars[0].dtype
@property
def is_initialized(self):
return self._vars[0].is_initialized
@property
def trainable(self):
return self._vars[0].trainable
@property
def device(self):
"""The device this variable is on."""
return self._vars[0].device
@contextlib.contextmanager
def _handle_graph(self):
with self.handle.graph.as_default():
yield
@contextlib.contextmanager
def _assign_dependencies(self):
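    """Makes assignments depend on the cached value, if any."""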
if self._cached_value is not None:
with ops.control_dependencies([self._cached_value]):
yield
else:
yield
@property
def constraint(self):
return self._vars[0].constraint
@property
def _in_graph_mode(self):
return self._vars[0]._in_graph_mode # pylint: disable=protected-access
@property
def _unique_id(self):
return self._vars[0]._unique_id # pylint: disable=protected-access
@property
def graph(self):
return self._vars[0].graph
@property
def _shared_name(self):
return self._common_name
@property
def synchronization(self):
return variable_scope.VariableSynchronization.NONE
@property
def aggregation(self):
return variable_scope.VariableAggregation.NONE
@property
def variables(self):
"""The list of `Variables`."""
if save_context.in_save_context():
return [self._vars[0]]
return self._vars
def _export_to_saved_model_graph(self, object_map, tensor_map,
options, **kwargs):
"""For implementing `Trackable`."""
first_var = self._vars[0]
resource_list = first_var._export_to_saved_model_graph( # pylint:disable=protected-access
object_map, tensor_map, options, **kwargs)
for v in self._vars[1:]:
object_map[v] = object_map[first_var]
tensor_map[v.handle] = tensor_map[first_var.handle]
resource_list.append(v.handle)
object_map[self] = object_map[first_var]
tensor_map[self] = tensor_map[first_var.handle]
resource_list.append(self)
return resource_list
def _gather_saveables_for_saved_model(self):
return {trackable.VARIABLE_VALUE_KEY: self._vars[0]}
@property
def shape(self):
return self._vars[0].shape
@property
def handle(self):
if save_context.in_save_context() or context.executing_eagerly():
return self._vars[0].handle
if tpu_util.enclosing_tpu_context() is None:
      raise NotImplementedError('TPUReplicatedVariable.handle is not available '
                                'outside of a TPU context or a save context.')
else:
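      # Inside a TPU context, combine the per-logical-device handles into a
      # single partitioned handle and mark it as replicated for XLA.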
with tpu_util.outside_or_skip_tpu_context():
packed_var = getattr(self, '_packed_var', None)
# TODO(b/202047549): Enable packed variables with soft device placement
if packed_var is None or config.get_soft_device_placement():
tensor = tpu_partition_ops.tpu_partitioned_input_v2(
[v.handle for v in self._vars],
partition_dims=[], is_packed=False)
else:
tensor = tpu_partition_ops.tpu_partitioned_input_v2(
[packed_var.packed_handle], partition_dims=[], is_packed=True)
return xla_sharding.replicate(tensor)
def _read_variable_op(self):
return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
"""Converts a variable to a tensor."""
# pylint: disable=protected-access
if tpu_util.enclosing_tpu_context() is None:
return self.read_value()
else:
return self._read_variable_op()
def read_value(self):
return self._vars[0].read_value()
def _update(self, update_fn, value, **kwargs):
"""Converts the value to tensor and updates the variable list."""
input_tensor = ops.convert_to_tensor(
value, name='value_in_tensor', dtype=self.dtype)
return control_flow_ops.group(
*tuple(
_on_device_update(update_fn, v, input_tensor, **kwargs)
for v in self.variables))
def assign(self, value, use_locking=False, name=None, read_value=True):
if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
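      # Outside a TPU context (or when executing eagerly), assign to every
      # replica variable individually on its own device.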
assign_fn = lambda var, *a, **ka: var.assign(*a, **ka)
return self._update(
assign_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
else:
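      # Inside a TPU context, emit a single raw assign op against the
      # replicated handle and let the TF2XLA bridge distribute it.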
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_variable_op)(
self,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
assign_sub_fn = lambda var, *a, **ka: var.assign_sub(*a, **ka)
return self._update(
assign_sub_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
else:
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_sub_variable_op)(
self,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
def assign_add(self, value, use_locking=False, name=None, read_value=True):
if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
assign_add_fn = lambda var, *a, **ka: var.assign_add(*a, **ka)
return self._update(
assign_add_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
else:
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_add_variable_op)(
self,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
def __str__(self):
debug_str = ',\n'.join(
' %d: %s' % (i, v) for i, v in enumerate(self._vars))
return '%s:{\n%s\n}' % (self.__class__.__name__, debug_str)
def __repr__(self):
debug_repr = ',\n'.join(
' %d: %r' % (i, v) for i, v in enumerate(self._vars))
return '%s:{\n%s\n}' % (self.__class__.__name__, debug_repr)
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_tpu_replicated_var(var,
dtype=None,
name=None,
as_ref=False):
return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
tensor_conversion_registry.register_tensor_conversion_function(
TPUReplicatedVariable, _tensor_conversion_tpu_replicated_var)
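
# Note: with the conversion function registered above, a `TPUReplicatedVariable`
# can be passed directly to ops that expect tensors. Illustrative sketch only
# (assumes `v0`..`v3` are per-logical-device `tf.Variable`s and `x` is a
# compatible tensor built inside a TPU context):
#
#   replicated = TPUReplicatedVariable([v0, v1, v2, v3])
#   y = tf.matmul(x, replicated)  # Implicitly calls `_dense_var_to_tensor`.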