# blob: e7dcd958fe32405772467dc5624a8e338422411a [file] [log] [blame]
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""
from absl.testing import parameterized
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values_v2
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
class _VariableInterfaceTestBase(test.TestCase, parameterized.TestCase):
  # This test verifies that DistributedVariable/AutoSyncVariable conforms to
  # Variable and ResourceVariable interface, i.e. the methods and properties are
  # all defined. It verifies methods and properties that have the same code path
  # under different replicas/devices as well. It is not intended to verify
  # methods and properties that behave differently under different
  # replicas/devices; those should be covered by separate tests.
  #
  # Every test builds its variables through `create_variable`, which concrete
  # subclasses must override.

  def create_variable(self, initial_value=1., **kwargs):
    """Builds the variable under test; must be overridden by subclasses.

    Args:
      initial_value: Initial value of the variable.
      **kwargs: Extra keyword arguments. The tests in this class pass
        `trainable`, `synchronization`, `aggregation` and `constraint`.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError

  @property
  def devices(self):
    """Returns the device names made available to `create_variable`."""
    return ["CPU:0", "CPU:1"]

  # ==== Begin Variable interface ===
  # Please follow the same order as methods and properties defined in
  # tf.Variable.

  # `__str__` and `__repr__` are called explicitly and must both return `str`.
  def testStringify(self):
    v = self.create_variable()
    self.assertIsInstance(v.__str__(), str)
    self.assertIsInstance(v.__repr__(), str)

  # `value()` and `read_value()` both return the stored value.
  def testDenseRead(self):
    v = self.create_variable(1.)
    self.assertEqual(v.value(), 1.)
    self.assertEqual(v.read_value(), 1.)

  # `shape` and `get_shape()` report the shape; `set_shape` accepts the same
  # shape and raises ValueError for a shape of different rank.
  def testShape(self):
    v = self.create_variable([1.])
    self.assertEqual(v.shape, (1,))
    self.assertEqual(v.get_shape(), (1,))
    v.set_shape((1,))
    with self.assertRaisesRegex(ValueError, "not compatible"):
      v.set_shape((1, 1))

  # `trainable` reflects the value passed at construction.
  @combinations.generate(combinations.combine(trainable=[True, False]))
  def testTrainable(self, trainable):
    v = self.create_variable(trainable=trainable)
    self.assertEqual(v.trainable, trainable)

  # `synchronization` reflects the value passed at construction.
  @combinations.generate(
      combinations.combine(synchronization=[
          variables_lib.VariableSynchronization.ON_READ,
          variables_lib.VariableSynchronization.ON_WRITE,
          variables_lib.VariableSynchronization.AUTO,
          variables_lib.VariableSynchronization.NONE,
      ]))
  def testSynchronization(self, synchronization):
    v = self.create_variable(synchronization=synchronization)
    self.assertEqual(v.synchronization, synchronization)

  # `aggregation` reflects the value passed at construction.
  @combinations.generate(
      combinations.combine(aggregation=[
          variables_lib.VariableAggregation.MEAN,
          variables_lib.VariableAggregation.SUM,
          variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
          variables_lib.VariableAggregation.NONE,
      ]))
  def testAggregation(self, aggregation):
    v = self.create_variable(aggregation=aggregation)
    self.assertEqual(v.aggregation, aggregation)

  # Graph mode only: `eval()` returns the value once the global initializer
  # has been run in the cached session.
  @combinations.generate(combinations.combine(mode="graph"))
  def testEval(self):
    v = self.create_variable(1.)
    with self.cached_session():
      self.evaluate(variables_lib.global_variables_initializer())
      self.assertEqual(v.eval(), 1.)

  # Reading `initial_value` is expected to raise RuntimeError here (the eager
  # case, per the test name).
  def testInitialValueEager(self):
    v = self.create_variable(1.)
    with self.assertRaises(RuntimeError):
      v.initial_value  # pylint: disable=pointless-statement

  # Graph mode only: `initial_value` evaluates to the construction value.
  @combinations.generate(combinations.combine(mode="graph"))
  def testInitialValueGraph(self):
    v = self.create_variable(1.)
    self.assertEqual(self.evaluate(v.initial_value), 1.)

  # `constraint` exposes a callable that behaves like the one passed in.
  def testConstraint(self):
    v = self.create_variable(constraint=lambda x: x + 1.)
    self.assertEqual(v.constraint(1.), 2.)

  # assign / assign_add / assign_sub: with read_value=True the updated value is
  # returned, with read_value=False the call returns None outside a
  # `def_function.function` and an `ops.Operation` inside one. The expected
  # values are cumulative, so the statement order matters.
  def testDenseUpdate(self):
    v = self.create_variable(1.)
    self.assertEqual(
        v.assign(2., use_locking=True, name="assign", read_value=True), 2.)
    self.assertIsNone(v.assign(3., read_value=False))
    self.assertEqual(v, 3.)
    self.assertEqual(
        v.assign_add(1., use_locking=True, name="assign_add", read_value=True),
        4.)
    self.assertIsNone(v.assign_add(1., read_value=False))
    self.assertEqual(v, 5.)
    self.assertEqual(
        v.assign_sub(1., use_locking=True, name="assign_sub", read_value=True),
        4.)
    self.assertIsNone(v.assign_sub(1., read_value=False))
    self.assertEqual(v, 3.)

    @def_function.function
    def f():
      self.assertIsInstance(v.assign(1., read_value=False), ops.Operation)
      self.assertIsInstance(v.assign_add(1., read_value=False), ops.Operation)
      self.assertIsInstance(v.assign_sub(1., read_value=False), ops.Operation)

    f()

  # scatter_* methods taking IndexedSlices. Each call operates on the result of
  # the previous one, so every expected value builds on the one before it.
  def testSparseUpdate(self):
    v = self.create_variable([0., 0., 0.])
    self.assertAllEqual(
        v.scatter_add(
            _make_index_slices(values=[1., 2.], indices=[0, 2]),
            use_locking=True,
            name="add"), [1., 0., 2.])
    self.assertAllEqual(
        v.scatter_div(
            _make_index_slices(values=[4., 2.], indices=[0, 2]),
            use_locking=True,
            name="div"), [0.25, 0., 1.])
    self.assertAllEqual(
        v.scatter_max(
            _make_index_slices(values=[1., 0.5], indices=[1, 2]),
            use_locking=True,
            name="max"), [0.25, 1., 1.])
    self.assertAllEqual(
        v.scatter_min(
            _make_index_slices(values=[1., 0.5], indices=[0, 1]),
            use_locking=True,
            name="min"), [0.25, 0.5, 1.])
    self.assertAllEqual(
        v.scatter_mul(
            _make_index_slices(values=[2., 0.5], indices=[0, 1]),
            use_locking=True,
            name="mul"), [0.5, 0.25, 1.])
    self.assertAllEqual(
        v.scatter_sub(
            _make_index_slices(values=[2., 0.5], indices=[0, 1]),
            use_locking=True,
            name="sub"), [-1.5, -0.25, 1.])
    self.assertAllEqual(
        v.scatter_update(
            _make_index_slices(values=[2., 0.5], indices=[0, 1]),
            use_locking=True,
            name="update"), [2., 0.5, 1.])
    self.assertAllEqual(
        v.batch_scatter_update(
            _make_index_slices(values=[1., 1.5], indices=[0, 1]),
            use_locking=True,
            name="update"), [1., 1.5, 1.])

  # scatter_nd_* methods taking explicit indices and updates. As above, the
  # expected values are cumulative across the three calls.
  def testSparseNdUpdate(self):
    v = self.create_variable([0., 0., 0., 0.])
    self.assertAllEqual(
        v.scatter_nd_sub([[3], [1]], [1., 2.], name="sub"), [0., -2., 0., -1.])
    self.assertAllEqual(
        v.scatter_nd_add([[2], [0]], [1., 2.], name="add"), [2., -2., 1., -1.])
    self.assertAllEqual(
        v.scatter_nd_update([[1], [3]], [3., 3.], name="update"),
        [2., 3., 1., 3.])

  # `sparse_read` returns the requested rows in the requested order;
  # `gather_nd` returns the individually addressed elements.
  def testSparseRead(self):
    v = self.create_variable([[1., 2.], [3., 4.]])
    self.assertAllEqual(
        v.sparse_read([1, 0], name="read"), [[3., 4.], [1., 2.]])
    self.assertAllEqual(
        v.gather_nd([[1, 0], [0, 1]], name="gather_nd"), [3., 2.])

  # The variable can be passed to `ops.convert_to_tensor`.
  def testTensorConversion(self):
    v = self.create_variable([1.])
    self.assertEqual(ops.convert_to_tensor(v), [1.])

  # Using the variable itself as a dict key raises TypeError; `ref()` is usable
  # as a key and distinguishes distinct variables.
  def testHash(self):
    v = self.create_variable()
    w = self.create_variable()
    d = {}
    with self.assertRaises(TypeError):
      d[v] = 1
    d[v.ref()] = 1
    self.assertEqual(d[v.ref()], 1)
    self.assertNotIn(w.ref(), d)

  # Graph mode only: the variable itself is usable as a dict key.
  @combinations.generate(combinations.combine(mode="graph"))
  def testHashGraph(self):
    v = self.create_variable()
    w = self.create_variable()
    d = {v: 1}
    self.assertEqual(d[v], 1)
    self.assertNotIn(w, d)

  # Variables holding equal values compare equal; different values do not.
  def testEquality(self):
    v = self.create_variable(1.)
    w = self.create_variable(2.)
    x = self.create_variable(1.)
    self.assertEqual(v, x)
    self.assertNotEqual(v, w)

  @combinations.generate(combinations.combine(mode="graph"))
  def testEqualityGraph(self):
    # In legacy graph mode, tensor equality is object equality
    v = self.create_variable(1.)
    w = self.create_variable(1.)
    self.assertNotEqual(v, w)
    self.assertEqual(v, v)

  # Iterating over the variable yields its elements.
  def testIteration(self):
    v = self.create_variable([1.])
    self.assertEqual([1.], list(iter(v)))

  # `name`, `_shared_name` and `device` are strings, `dtype` is float32 and
  # `initializer` is None; `op` and `graph` raise AttributeError in this
  # (non-graph) test.
  def testProperties(self):
    v = self.create_variable()
    self.assertIsInstance(v.name, str)
    # _shared_name is also part of the interface. E.g. it's used in optimizer to
    # determine slot variable key.
    self.assertIsInstance(v._shared_name, str)
    self.assertIsNone(v.initializer)
    self.assertIsInstance(v.device, str)
    self.assertEqual(v.dtype, dtypes.float32)
    with self.assertRaises(AttributeError):
      v.op  # pylint: disable=pointless-statement
    with self.assertRaises(AttributeError):
      v.graph  # pylint: disable=pointless-statement

  # Graph mode only: `initializer` and `op` are Operations, `graph` is a Graph.
  @combinations.generate(combinations.combine(mode="graph"))
  def testPropertiesGraph(self):
    v = self.create_variable()
    self.assertIsInstance(v.initializer, ops.Operation)
    self.assertIsInstance(v.op, ops.Operation)
    self.assertIsInstance(v.graph, ops.Graph)

  # Both proto conversion entry points are expected to raise TypeError.
  def testProtoConversion(self):
    # to_proto and from_proto are not supported.
    v = self.create_variable([1, 2])
    with self.assertRaises(TypeError):
      v.to_proto()
    with self.assertRaises(TypeError):
      v.from_proto(variable_def=None)

  # The object stored by `_set_save_slice_info` is returned by the getter and
  # is also visible through the `_save_slice_info` attribute.
  def testSaveSliceInfo(self):
    v = self.create_variable()
    slice_info = variables_lib.Variable.SaveSliceInfo()
    v._set_save_slice_info(slice_info)
    self.assertIs(v._get_save_slice_info(), slice_info)
    # Some code accesses _save_slice_info directly without using the getter.
    self.assertIs(v._save_slice_info, slice_info)

  # Arithmetic, comparison, bitwise and unary operators, with the variable on
  # the left, on the right, and on both sides where applicable. Comparisons use
  # assertTrue/assertFalse on the operator expression on purpose so that the
  # variable's own comparison operators are exercised.
  def testOperatorOverride(self):
    v = self.create_variable(7)
    self.assertEqual(v + 1, 8)
    self.assertEqual(3 + v, 10)
    self.assertEqual(v + v, 14)
    self.assertEqual(v - 2, 5)
    self.assertEqual(13 - v, 6)
    self.assertEqual(v - v, 0)
    self.assertEqual(v * 2, 14)
    self.assertEqual(3 * v, 21)
    self.assertEqual(v * v, 49)
    self.assertEqual(v / 2, 3.5)
    self.assertEqual(14 / v, 2.)
    self.assertEqual(v // 2, 3)
    self.assertEqual(15 // v, 2)
    self.assertEqual(v % 2, 1)
    self.assertEqual(16 % v, 2)
    # pylint: disable=g-generic-assert
    self.assertTrue(v < 12)
    self.assertTrue(v <= 12)
    self.assertFalse(v > 12)
    self.assertFalse(v >= 12)
    self.assertFalse(12 < v)
    self.assertFalse(12 <= v)
    self.assertTrue(12 > v)
    self.assertTrue(12 >= v)
    # pylint: enable=g-generic-assert
    self.assertEqual(v & 3, 3)
    self.assertEqual(11 & v, 3)
    self.assertEqual(v | 8, 15)
    self.assertEqual(16 | v, 23)
    self.assertEqual(v ^ 3, 4)
    self.assertEqual(11 ^ v, 12)
    self.assertEqual(pow(v, 3), 343)
    # TODO(b/178748613): pow(v, 3, 10) fails.
    self.assertEqual(pow(2, v), 128)
    self.assertEqual(-v, -7)
    self.assertEqual(~v, ~7)
    self.assertEqual(abs(v), 7)

  # Indexing reads a single element, and the indexed result supports
  # `assign`, which updates that element of the variable.
  def testSlice(self):
    v = self.create_variable([1., 2., 3.])
    self.assertEqual(v[1], 2.)
    v[2].assign(4.)
    self.assertAllEqual(v, [1., 2., 4.])

  # ==== End Variable interface ===

  # ==== Begin ResourceVariable interface ===

  # `handle` is an `ops.Tensor` whose dtype is `dtypes.resource`.
  def testHandle(self):
    v = self.create_variable()
    self.assertIsInstance(v.handle, ops.Tensor)
    self.assertEqual(v.handle.dtype, dtypes.resource)

  def testInGraphMode(self):
    # This is protected but used in a lot of places internally.
    v = self.create_variable()
    self.assertFalse(v._in_graph_mode)

  def testUniqueId(self):
    # This is used in optimizer as part of slot variable key.
    v = self.create_variable()
    w = self.create_variable()
    self.assertNotEqual(v._unique_id, w._unique_id)

  # `resource_variable_ops.is_resource_variable` recognizes the variable.
  def testIsResourceVariable(self):
    v = self.create_variable()
    self.assertTrue(resource_variable_ops.is_resource_variable(v))

  # ==== End ResourceVariable interface ===

  # Graph mode only: the graph is finalized before `_as_graph_element()` is
  # called, so the call must succeed without adding operations to the graph.
  @combinations.generate(combinations.combine(mode="graph"))
  def testAsGraphElement(self):
    g = ops.Graph()
    with g.as_default():
      v = self.create_variable(1.)
      g.finalize()
      self.evaluate(v.initializer)
      # _as_graph_element shouldn't create new operations.
      self.assertEqual(self.evaluate(v._as_graph_element()), 1.)
class DistributedVariableInterfaceTest(_VariableInterfaceTestBase):
  """Runs the shared interface tests against `values_v2.DistributedVariable`."""

  def create_variable(self, initial_value=1., **kwargs):
    """Builds a DistributedVariable with one component per entry in `devices`.

    Args:
      initial_value: Initial value given to every component variable.
      **kwargs: Forwarded unchanged to each `variables_lib.Variable`.

    Returns:
      A `values_v2.DistributedVariable` wrapping the component variables, in
      the same order as `self.devices`.
    """

    def _component_on(device_name):
      # Create a single component variable under the given device scope.
      with ops.device(device_name):
        return variables_lib.Variable(initial_value, **kwargs)

    components = [_component_on(name) for name in self.devices]
    return values_v2.DistributedVariable(components)
# Prevent the base class from running: dropping the module-level name keeps it
# from being picked up as a test case in its own right, where its
# `create_variable` would raise NotImplementedError. Subclasses defined above
# keep their own reference to the class, so they are unaffected.
del _VariableInterfaceTestBase
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.tpu_strategy,
            strategy_combinations.mirrored_strategy_with_two_cpus,
            strategy_combinations.mirrored_strategy_with_two_gpus,
        ],
        enable_packed_handle=[True, False],
        tf_function=[combinations.tf_function, combinations.no_tf_function]))
class DistributedVariableTest(test.TestCase, parameterized.TestCase):
  """Per-device read/assign behavior of `values_v2.DistributedVariable`.

  Every test is parameterized over the strategy, whether packed handles are
  enabled, and whether the exercised closures are wrapped by `tf_function`.
  """

  def create_variable(self, strategy, initial_value, enable_packed_handle,
                      **kwargs):
    """Builds a DistributedVariable over the strategy's parameter devices.

    Args:
      strategy: Strategy whose `extended.parameter_devices` receive one
        component variable each.
      initial_value: Initial value given to every component variable.
      enable_packed_handle: Forwarded to `values_v2.DistributedVariable`.
      **kwargs: Forwarded unchanged to each `variables_lib.Variable`.

    Returns:
      The constructed `values_v2.DistributedVariable`.
    """

    def _component_on(device_name):
      # Create a single component variable under the given device scope.
      with ops.device(device_name):
        return variables_lib.Variable(initial_value, **kwargs)

    components = [
        _component_on(name) for name in strategy.extended.parameter_devices
    ]
    return values_v2.DistributedVariable(
        components, enable_packed_handle=enable_packed_handle)

  def assertReplica(self, distributed_var, values):
    """Asserts that component i of `distributed_var` equals `values[i]`."""
    pairs = zip(distributed_var._variables, values)
    for component, expected in pairs:
      self.assertAllEqual(component, expected)

  def testRead(self, strategy, enable_packed_handle, tf_function):
    """Reads return the value of the component on the current device."""
    dist_var = self.create_variable(strategy, 0., enable_packed_handle)
    replica_devices = strategy.extended.parameter_devices

    # Give each of the first two components a distinct value.
    with ops.device(replica_devices[0]):
      dist_var.assign(1.)
    with ops.device(replica_devices[1]):
      dist_var.assign(2.)

    def _reader_on(device_name):
      # Returns a (possibly tf_function-wrapped) reader pinned to one device.
      @tf_function
      def read():
        with ops.device(device_name):
          return dist_var.read_value(), dist_var.value()

      return read

    read_first = _reader_on(replica_devices[0])
    read_second = _reader_on(replica_devices[1])
    read_cpu0 = _reader_on("CPU:0")

    self.assertAllEqual(read_first(), [1., 1.])
    self.assertAllEqual(read_second(), [2., 2.])
    # A read placed on "CPU:0" is expected to see the first component's value.
    self.assertAllEqual(read_cpu0(), [1., 1.])

  def testAssign(self, strategy, enable_packed_handle, tf_function):
    """Assignments only touch the component on the current device."""
    dist_var = self.create_variable(strategy, 0., enable_packed_handle)
    replica_devices = strategy.extended.parameter_devices

    def _assigner_on(device_name, new_value):
      # Returns a (possibly tf_function-wrapped) updater pinned to one device.
      @tf_function
      def update():
        with ops.device(device_name):
          dist_var.assign(new_value)

      return update

    assign_first = _assigner_on(replica_devices[0], 1.)
    assign_second = _assigner_on(replica_devices[1], 2.)

    assign_first()
    assign_second()
    self.assertReplica(dist_var, [1., 2.])

    with ops.device("CPU:0"):
      # Update the primary replica.
      dist_var.assign(3.)
    self.assertReplica(dist_var, [3., 2.])

  def testStrategyRun(self, strategy, enable_packed_handle, tf_function):
    """Inside `strategy.run`, each replica reads/writes its own component."""
    on_tpu = test_util.is_tpu_strategy(strategy)
    if on_tpu and tf_function is combinations.no_tf_function:
      self.skipTest("tpu doesn't support eager")

    dist_var = self.create_variable(strategy, 0., enable_packed_handle)

    @tf_function
    def update(per_replica):
      dist_var.assign(per_replica)

    @tf_function
    def read():
      return dist_var.read_value()

    per_replica_values = test_util.create_per_replica(strategy, [1., 2.])
    strategy.run(update, args=(per_replica_values,))
    self.assertReplica(dist_var, [1., 2.])

    gathered = test_util.gather(strategy, strategy.run(read))
    self.assertAllEqual(gathered, [1., 2.])
def _make_index_slices(values, indices, dense_shape=None):
  """Builds an `IndexedSlices` for the scatter tests.

  Args:
    values: Slice values; passed through `array_ops.identity`.
    indices: Slice indices; passed through `array_ops.identity`.
    dense_shape: Optional dense shape. Left as None when omitted, otherwise
      passed through `array_ops.identity` as well.

  Returns:
    An `indexed_slices.IndexedSlices` built from the arguments above.
  """
  # Compare against None explicitly. Truth-testing the shape would raise for a
  # multi-element array or tensor, and would skip an empty shape.
  if dense_shape is not None:
    dense_shape = array_ops.identity(dense_shape)
  return indexed_slices.IndexedSlices(
      array_ops.identity(values), array_ops.identity(indices), dense_shape)
# When executed as a script, run the tests through the distribute test_util
# entry point.
if __name__ == "__main__":
  test_util.main()