# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing TPU distributed values.
Note that the tests are in values_test.py.
"""
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import tpu_replicated_variable
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
_scatter_error_msg = ("{op_name} is only supported for a distributed "
                      "variable (a variable created within a certain "
                      "`tf.distribute.Strategy` scope) with NONE "
                      "aggregation, got: {aggregation}.")
class TPUVariableMixin(object):
"""Mixin for TPU variables."""
def __init__(self, *args, **kwargs):
super(TPUVariableMixin, self).__init__(*args, **kwargs)
# Handle ID is needed for `get_replicated_var_handle` to cache the variables
# correctly since in eager mode different variables can have the same name.
if ops.executing_eagerly_outside_functions():
self._handle_id = self._common_name + "_" + str(id(self._primary))
else:
self._handle_id = self._common_name
def __getattr__(self, name):
if tpu_util.enclosing_tpu_context() is None:
return super(TPUVariableMixin, self).__getattr__(name)
else:
raise AttributeError(
f"`TPUVariableMixin.{name}` not accessible within a TPU context.")
def get(self):
if tpu_util.enclosing_tpu_context() is None:
return super(TPUVariableMixin, self).get()
else:
raise NotImplementedError(
"`TPUVariableMixin.get()` is not supported within a TPU context.")
def _get_as_operand(self):
return self.read_value()
@property
def handle(self):
"""The handle by which this variable can be accessed."""
# If we're in a tpu.rewrite(), return the replicated handle.
tpu_context = tpu_util.enclosing_tpu_context()
if tpu_context is None or context.executing_eagerly():
var = self._get_on_device_or_primary()
if isinstance(var, packed.PackedVarAndDevice):
return var.on_device_handle()
else:
return var.handle
else:
is_packed = self._packed_var is not None
val = self._values
if is_packed:
val = [self._packed_var]
return tpu_context.get_replicated_var_handle(self._common_name,
self._handle_id, val,
self._is_mirrored(),
is_packed)
@property
def device(self):
return self.handle.device
def _read_variable_op(self):
"""Reads the value of this variable."""
if self.trainable:
tape.variable_accessed(self)
handle = self.handle
if getattr(handle, "is_packed", False):
# Add a device scope for a packed variable handle.
with ops.device(self._get_on_device_or_primary().device):
return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
else:
return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
def read_value(self):
if tpu_util.enclosing_tpu_context() is None:
return super(TPUVariableMixin, self).read_value()
else:
return self._read_variable_op()
def value(self):
if tpu_util.enclosing_tpu_context() is None:
return super(TPUVariableMixin, self).value()
else:
return self._read_variable_op()
def _as_graph_element(self):
if tpu_util.enclosing_tpu_context() is None:
return super(TPUVariableMixin, self)._as_graph_element() # pylint: disable=protected-access
else:
return None
@property
def op(self):
if values_util.is_saving_non_distributed():
return self._primary.op
return values.DistributedVarOp(self._primary.op.name,
self._primary.op.graph,
self._primary.op.traceback,
self._primary.op.type)
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
"""Converts a variable to a tensor."""
# pylint: disable=protected-access
if tpu_util.enclosing_tpu_context() is None:
return super(TPUVariableMixin, self)._dense_var_to_tensor(
dtype=dtype, name=name, as_ref=as_ref)
# pylint: enable=protected-access
elif dtype is not None and dtype != self.dtype:
return math_ops.cast(self.read_value(), dtype)
else:
return self.handle if as_ref else self.read_value()
class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
"""DistributedVariable subclass for TPUStrategy."""
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
if values_util.is_saving_non_distributed():
return self._primary.assign_sub(value, use_locking, name, read_value)
return self._policy.assign_sub(
self, value, use_locking=use_locking, name=name, read_value=read_value)
def assign_add(self, value, use_locking=False, name=None, read_value=True):
if values_util.is_saving_non_distributed():
return self._primary.assign_add(value, use_locking, name, read_value)
return self._policy.assign_add(
self, value, use_locking=use_locking, name=name, read_value=read_value)
def assign(self, value, use_locking=False, name=None, read_value=True):
if values_util.is_saving_non_distributed():
return self._primary.assign(value, use_locking, name, read_value)
return self._policy.assign(
self, value, use_locking=use_locking, name=name, read_value=read_value)
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_sub(sparse_delta, use_locking, name)
return self._policy.scatter_sub(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_add(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_add(sparse_delta, use_locking, name)
return self._policy.scatter_add(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_mul(sparse_delta, use_locking, name)
return self._policy.scatter_mul(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_div(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_div(sparse_delta, use_locking, name)
return self._policy.scatter_div(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_min(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_min(sparse_delta, use_locking, name)
return self._policy.scatter_min(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_max(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_max(sparse_delta, use_locking, name)
return self._policy.scatter_max(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_update(self, sparse_delta, use_locking=False, name=None):
if values_util.is_saving_non_distributed():
return self._primary.scatter_update(sparse_delta, use_locking, name)
return self._policy.scatter_update(
self, sparse_delta, use_locking=use_locking, name=name)
class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
"""Holds a map from replica to TPU variables whose values are kept in sync."""
def _is_replicated_or_sharded_to_logical_cores(self):
"""Returns whether each of the underlying variables is replicated or sharded to logical cores.
If True, the handles of the underlying variables are not available outside a
TPU context.
"""
return isinstance(self._primary,
tpu_replicated_variable.TPUReplicatedVariable)
@property
def device(self):
if (self._is_replicated_or_sharded_to_logical_cores() and
tpu_util.enclosing_tpu_context() is None):
return self._primary.device
return super(TPUMirroredVariable, self).device
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
tpu_context = tpu_util.enclosing_tpu_context()
if (self._is_replicated_or_sharded_to_logical_cores() and
tpu_context is None):
assign_sub_fn = lambda v, *a, **ka: v.assign_sub(*a, **ka)
return self._update(
update_fn=assign_sub_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
if (tpu_context and
self.aggregation == variable_scope.VariableAggregation.NONE):
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_sub_variable_op)(
self,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
return assign_sub(
self, value, use_locking=use_locking, name=name, read_value=read_value)
def assign_add(self, value, use_locking=False, name=None, read_value=True):
tpu_context = tpu_util.enclosing_tpu_context()
if (self._is_replicated_or_sharded_to_logical_cores() and
tpu_context is None):
assign_add_fn = lambda v, *a, **ka: v.assign_add(*a, **ka)
return self._update(
update_fn=assign_add_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
if (tpu_context and
self.aggregation == variable_scope.VariableAggregation.NONE):
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_add_variable_op)(
self,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
return assign_add(
self, value, use_locking=use_locking, name=name, read_value=read_value)
def assign(self, value, use_locking=False, name=None, read_value=True):
tpu_context = tpu_util.enclosing_tpu_context()
if (self._is_replicated_or_sharded_to_logical_cores() and
tpu_context is None):
assign_fn = lambda v, *a, **ka: v.assign(*a, **ka)
return self._update(
update_fn=assign_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
    if (tpu_context and
        self.aggregation == variable_scope.VariableAggregation.NONE):
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_variable_op)(
self,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
return assign(
self, value, use_locking=use_locking, name=name, read_value=read_value)
def scatter_sub(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_sub(*args, **kwargs)
raise NotImplementedError
def scatter_add(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_add(*args, **kwargs)
raise NotImplementedError
def scatter_max(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_max(*args, **kwargs)
raise NotImplementedError
def scatter_min(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_min(*args, **kwargs)
raise NotImplementedError
def scatter_mul(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_mul(*args, **kwargs)
raise NotImplementedError
def scatter_div(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_div(*args, **kwargs)
raise NotImplementedError
def scatter_update(self, *args, **kwargs):
if values_util.is_saving_non_distributed():
return self._primary.scatter_update(*args, **kwargs)
raise NotImplementedError
class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
"""Holds a map from replica to variables whose values are reduced on save."""
def assign_sub(self, *args, **kwargs):
if tpu_util.enclosing_tpu_context() is None:
return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
else:
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
**kwargs)
def assign_add(self, *args, **kwargs):
if tpu_util.enclosing_tpu_context() is None:
return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
else:
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_add_variable_op)(self, *args,
**kwargs)
def assign(self, *args, **kwargs):
if tpu_util.enclosing_tpu_context() is None:
return values.SyncOnReadVariable.assign(self, *args, **kwargs)
else:
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_variable_op)(self, *args, **kwargs)
# Common methods shared between OnWrite and Mirrored variables.
def assign_sub(var, value, use_locking=False, name=None, read_value=True):
assign_sub_fn = tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_sub_variable_op)
return var._update( # pylint: disable=protected-access
update_fn=assign_sub_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
def assign_add(var, value, use_locking=False, name=None, read_value=True):
assign_add_fn = tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_add_variable_op)
return var._update( # pylint: disable=protected-access
update_fn=assign_add_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
def assign(var, value, use_locking=False, name=None, read_value=True):
assign_fn = tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_variable_op)
return var._update( # pylint: disable=protected-access
update_fn=assign_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
class TPUOnWritePolicy(values.OnWritePolicy):
"""Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.
This policy is created when `synchronization` is set to
`tf.VariableSynchronization.AUTO` or `tf.VariableSynchronization.ON_WRITE`.
"""
def assign_sub(self,
var,
value,
use_locking=False,
name=None,
read_value=True):
if (tpu_util.enclosing_tpu_context() and
var.aggregation == variable_scope.VariableAggregation.NONE):
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_sub_variable_op)(
var,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
return assign_sub(
var, value, use_locking=use_locking, name=name, read_value=read_value)
def assign_add(self,
var,
value,
use_locking=False,
name=None,
read_value=True):
if (tpu_util.enclosing_tpu_context() and
var.aggregation == variable_scope.VariableAggregation.NONE):
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_add_variable_op)(
var,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
return assign_add(
var, value, use_locking=use_locking, name=name, read_value=read_value)
def assign(self, var, value, use_locking=False, name=None, read_value=True):
if (tpu_util.enclosing_tpu_context() and
var.aggregation == variable_scope.VariableAggregation.NONE):
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_variable_op)(
var,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
return assign(
var, value, use_locking=use_locking, name=name, read_value=read_value)
  def _scatter_xxx(self,
                   raw_scatter_xxx_fn,
                   op_name,
                   var,
                   sparse_delta,
                   use_locking=False,
                   name=None):
    scatter_xxx_fn = tpu_util.make_raw_scatter_xxx_fn(raw_scatter_xxx_fn)
    if tpu_util.enclosing_tpu_context():
      if self._aggregation != variable_scope.VariableAggregation.NONE:
        raise NotImplementedError(
            _scatter_error_msg.format(
                op_name=op_name, aggregation=self._aggregation))
      return scatter_xxx_fn(
          var, sparse_delta=sparse_delta, use_locking=use_locking, name=name)
    else:
      return var._update(  # pylint: disable=protected-access
          update_fn=scatter_xxx_fn,
          value=sparse_delta,
          use_locking=use_locking,
          name=name)
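  # A minimal scatter sketch (illustrative only; `strategy` is a hypothetical
  # `tf.distribute.TPUStrategy`). Inside a TPU context the variable must use
  # NONE aggregation, otherwise `_scatter_xxx` raises `NotImplementedError`
  # with `_scatter_error_msg`:
  #
  #   with strategy.scope():
  #     v = tf.Variable([1.0, 2.0, 3.0],
  #                     aggregation=tf.VariableAggregation.NONE)
  #
  #   @tf.function
  #   def step():
  #     def replica_fn():
  #       v.scatter_update(tf.IndexedSlices(values=[5.0], indices=[0]))
  #     strategy.run(replica_fn)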
def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_sub,
"scatter_sub", var, sparse_delta, use_locking,
name)
def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_add,
"scatter_add", var, sparse_delta, use_locking,
name)
def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_max,
"scatter_max", var, sparse_delta, use_locking,
name)
def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_min,
"scatter_min", var, sparse_delta, use_locking,
name)
def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_mul,
"scatter_mul", var, sparse_delta, use_locking,
name)
def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_div,
"scatter_div", var, sparse_delta, use_locking,
name)
def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_update,
"scatter_update", var, sparse_delta, use_locking,
name)
class TPUOnReadPolicy(values.OnReadPolicy):
"""Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.
This policy is created when `synchronization` is set to
`tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
  values allowed by the `tf.VariableAggregation` enum, such as `NONE`, `SUM`,
  `MEAN` or `ONLY_FIRST_REPLICA`, when creating a `tf.Variable` in a
  `tf.distribute` scope.
"""
def assign_sub(self, var, *args, **kwargs):
if tpu_util.enclosing_tpu_context() is None:
return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)
else:
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_sub_variable_op)(var, *args,
**kwargs)
def assign_add(self, var, *args, **kwargs):
if tpu_util.enclosing_tpu_context() is None:
return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)
else:
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_add_variable_op)(var, *args,
**kwargs)
def assign(self, var, *args, **kwargs):
if tpu_util.enclosing_tpu_context() is None:
return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
else:
return tpu_util.make_raw_assign_fn(
gen_resource_variable_ops.assign_variable_op)(var, *args, **kwargs)
def scatter_sub(self, *args, **kwargs):
raise NotImplementedError
def scatter_add(self, *args, **kwargs):
raise NotImplementedError
def scatter_max(self, *args, **kwargs):
raise NotImplementedError
def scatter_min(self, *args, **kwargs):
raise NotImplementedError
def scatter_mul(self, *args, **kwargs):
raise NotImplementedError
def scatter_div(self, *args, **kwargs):
raise NotImplementedError
def scatter_update(self, *args, **kwargs):
raise NotImplementedError