blob: ce7bacd7e694f9846da32d715e65219b568d8acd [file] [log] [blame]
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed variables library."""
import copy
import os
from absl.testing import parameterized
from tensorflow.python.checkpoint import checkpoint as trackable_utils
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.types import core
def _device_str(d):
return "/device:GPU:" + str(d)
def _nested_value(d):
return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)
def mirrored_and_tpu_strategy_combinations():
  """Returns test combinations of mirrored and TPU strategies.

  Returns:
    The result of `combinations.combine` over four strategies (two mirrored,
    two TPU) and both the "graph" and "eager" modes.
  """
  strategies = [
      strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
      strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
      strategy_combinations.tpu_strategy,
      strategy_combinations.tpu_strategy_packed_var,
  ]
  modes = ["graph", "eager"]
  return combinations.combine(distribution=strategies, mode=modes)
@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
            strategy_combinations.tpu_strategy,
            strategy_combinations.tpu_strategy_packed_var,
            strategy_combinations.tpu_strategy_spmd,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
        ],
        synchronization=[
            variables_lib.VariableSynchronization.ON_READ,
            variables_lib.VariableSynchronization.ON_WRITE,
        ],
        aggregation=[
            variables_lib.VariableAggregation.MEAN,
            variables_lib.VariableAggregation.SUM,
            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
        ],
        mode=["graph", "eager"],
        # NOTE(review): no test method accepts `use_var_policy`; presumably it
        # is consumed by the combinations framework itself -- confirm.
        use_var_policy=[True, False]))
class DistributedVariableTest(test.TestCase, parameterized.TestCase):
  """Tests variables created under a distribution strategy scope.

  Every test is generated for each combination of strategy, synchronization,
  aggregation and execution mode listed in the class decorator.
  """

  def testExtendsVariable(self, distribution, synchronization, aggregation):
    """A variable created under the scope is a `variables_lib.Variable`."""
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIsInstance(v, variables_lib.Variable)

  def testCheckpointing(self, distribution, synchronization, aggregation, mode):
    """Saving then restoring a checkpoint brings back the saved value."""
    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")

    with distribution.scope():
      v = variables_lib.Variable(
          constant_op.constant([1., 2., 3., 4]),
          synchronization=synchronization,
          aggregation=aggregation)

    self.evaluate(v.initializer)
    before_save = self.evaluate(v.read_value())

    # Save the initial weights into a checkpoint.
    checkpoint = trackable_utils.Checkpoint(v=v)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.test_session():
      save_path = checkpoint.save(prefix)

    # Assign inverted value.
    self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.])))
    after_assign = self.evaluate(v.read_value())
    self.assertNotAllClose(before_save, after_assign)

    # Restore from the checkpoint.
    with self.test_session():
      checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    after_restore = self.evaluate(v)
    self.assertAllClose(before_save, after_restore)

  def testTraceback(self, distribution, synchronization, aggregation):
    """Creating the same named variable twice raises a `ValueError`."""
    if context.executing_eagerly():
      self.skipTest("does not apply to eager")
    with distribution.scope():
      variable_scope.get_variable(
          name="testVar",
          initializer=1.,
          use_resource=True,
          synchronization=synchronization,
          aggregation=aggregation)
      with self.assertRaisesRegex(ValueError,
                                  "Variable testVar already exists"):
        variable_scope.get_variable(
            name="testVar",
            initializer=1.,
            use_resource=True,
            synchronization=synchronization,
            aggregation=aggregation)

  def testSelectReplica(self, distribution, synchronization, aggregation):
    """`select_replica` returns the distributed variable itself."""
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIs(v, distribute_utils.select_replica(0, v))

  def testIsTensorLike(self, distribution, synchronization, aggregation):
    """The variable is a `core.Tensor` in cross-replica and replica context."""
    if isinstance(distribution.extended,
                  tpu_strategy.TPUExtended) and context.executing_eagerly():
      self.skipTest("TPU doesn't support pure eager")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    # In cross replica context.
    self.assertIsInstance(v, core.Tensor)
    # In replica context.
    distribution.run(
        lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))

  def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
                                        aggregation):
    """The results of assign/assign_sub/assign_add are `core.Tensor`s."""
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      if context.executing_eagerly():
        self.skipTest("TPU doesn't support pure eager")
      else:
        self.skipTest("b/152076846")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

    def assert_is_tensor_like(v):
      # We can't use Python literals because they are treated as
      # non-distributed values, which are not allowed when aggregation is SUM.
      # See `cross_device_ops.reduce_non_distributed_value`.
      delta = array_ops.identity(1.)
      self.assertIsInstance(v.assign(delta), core.Tensor)
      self.assertIsInstance(v.assign_sub(delta), core.Tensor)
      self.assertIsInstance(v.assign_add(delta), core.Tensor)

    # In cross replica context we return a PerReplica which is not Tensor like
    # all the time yet.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        aggregation != variables_lib.VariableAggregation.SUM):
      assert_is_tensor_like(v)

    # In replica context.
    distribution.run(assert_is_tensor_like, args=(v,))

  def testDeepCopy(self, distribution, synchronization, aggregation):
    """`copy.deepcopy` yields distinct objects with equal attributes/values."""
    if not context.executing_eagerly():
      self.skipTest("deepcopy only supported in eager mode")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
      # One copy is made inside the strategy scope, one outside of it.
      in_dist_copy = copy.deepcopy(v)

    out_dist_copy = copy.deepcopy(v)

    def assert_is_deep_copy(v1, v2):
      self.assertIsInstance(v2, type(v1))
      self.assertEqual(v1.aggregation, v2.aggregation)
      self.assertEqual(v1.distribute_strategy, v2.distribute_strategy)
      if isinstance(v1, ps_values.AggregatingVariable):
        self.assertIsInstance(v2.get(), type(v1.get()))
        self.assertIsNot(v1.get(), v2.get())
      else:
        # pylint: disable=protected-access
        if v1._policy:
          self.assertIsNot(v1._policy, v2._policy)
        else:
          self.assertIs(v1._policy, v2._policy)
        # pylint: enable=protected-access
        self.assertEqual(len(v1.values), len(v2.values))
        for (v1v, v2v) in zip(v1.values, v2.values):
          self.assertEqual(v1v.device, v2v.device)
          self.assertIsNot(v1v, v2v)
        # The comparison covers all components at once, so it is done a
        # single time rather than once per component.
        self.assertAllEqual(
            self.evaluate(v1.values), self.evaluate(v2.values))

    self.evaluate(variables_lib.global_variables_initializer())
    if not isinstance(distribution.extended, tpu_strategy.TPUExtended):
      distribution.run(assert_is_deep_copy, args=(v, in_dist_copy))
      distribution.run(assert_is_deep_copy, args=(v, out_dist_copy))

  def testAssignSignature(self, distribution, synchronization, aggregation):
    """assign*() accept the same positional/keyword arguments as tf.Variable."""
    # This test verifies assign*() can be called in the same way as normal
    # variables.
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

      def assign():
        one = constant_op.constant(1.)
        v.assign(one, True, "assign", False)
        # TODO(b/154017756): SyncOnReadVariable.assign() doesn't support passing
        # value as a keyword argument.
        v.assign(one, use_locking=True, name="assign", read_value=False)
        v.assign_add(one, True, "assign", False)
        v.assign_add(one, use_locking=True, name="assign", read_value=False)
        v.assign_sub(one, True, "assign", False)
        v.assign_sub(one, use_locking=True, name="assign", read_value=False)
        # Return something for graph mode to fetch.
        return constant_op.constant(1)

      self.evaluate(variables_lib.global_variables_initializer())
      # Cross replica context; not exercised for ON_READ + SUM variables.
      if not (synchronization == variables_lib.VariableSynchronization.ON_READ
              and aggregation == variables_lib.VariableAggregation.SUM):
        self.evaluate(distribution.experimental_local_results(assign()))
      # Replica context; not exercised for TPU in eager mode.
      if not (isinstance(distribution.extended, tpu_strategy.TPUExtended) and
              context.executing_eagerly()):
        self.evaluate(
            distribution.experimental_local_results(distribution.run(assign)))

  def testStrategyExtendedUpdate(self, distribution, synchronization,
                                 aggregation):
    """`extended.update` applies per-replica values to each component."""
    if len(distribution.extended.parameter_devices) != 2:
      self.skipTest("n/a: needs exactly two parameter devices")
    if (synchronization == variables_lib.VariableSynchronization.ON_WRITE and
        aggregation != variables_lib.VariableAggregation.NONE):
      self.skipTest("n/a: doesn't apply to ON_WRITE variable with aggregation")
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    value = values_lib.PerReplica([1., 2.])

    assign_fn = lambda var, value: var.assign(value)
    self.evaluate(distribution.extended.update(v, assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    assign_add_fn = lambda var, value: var.assign_add(value)
    self.evaluate(distribution.extended.update(v, assign_add_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [2., 4.])

    assign_sub_fn = lambda var, value: var.assign_sub(value)
    self.evaluate(distribution.extended.update(v, assign_sub_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    # `value` is accepted but unused: each component is increased by twice
    # its own current value.
    read_assign_fn = lambda var, value: var.assign_add(
        var.value() + var.read_value())
    self.evaluate(
        distribution.extended.update(v, read_assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [3., 6.])

  def testSaveNonDistributed(self, distribution, synchronization, aggregation):
    """Under SaveContext the variable behaves like its primary component."""
    # This test verifies that the DistributedVariable behaves like the primary
    # variable when saving a non-distributed version of the model (the
    # default). The test asserts that the function traced under SaveContext has
    # no device annotations and only references the primary component of the
    # variable. Please avoid capturing other eager tensors in this test to keep
    # the assertion easy.
    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("b/148689177: AggregatingVariable doesn't "
                    "conform to Variable interface well")

    # tf.function requires the return value to be Tensors, which is not always
    # the case for properties and methods of Variable, so we simply discard the
    # return values.
    def _discard_return(f):
      f()
      return

    def _test(f, v):
      # This verifies that the function under SaveContext:
      #   - contains no device annotations.
      #   - only references the primary component of the variable.
      g = def_function.function(lambda: _discard_return(f))
      options = save_options.SaveOptions(
          experimental_variable_policy=save_options.VariablePolicy.NONE)
      with save_context.save_context(options):
        # The graph should contain no device.
        graph = g.get_concrete_function().graph
      for op in graph.get_operations():
        self.assertEqual(op.device, "", msg=str(op))
      # The function should only capture the primary variable. Note that it
      # may not have captures, e.g. v.aggregation.
      captures = list(graph.captures)
      self.assertLessEqual(len(captures), 1)
      if captures:
        self.assertIs(captures[0][0], v._primary.handle)

    def _assert(cond):
      return control_flow_assert.Assert(cond, [cond])

    with distribution.scope():
      # We use three variables for convenience reasons. They have no special
      # meaning.
      # - v is used whenever possible.
      # - w is used for scatter and gather, which require the variable to be
      #   non-scalar.
      # - y is used when the dtype needs to be integer. Note that aggregation
      #   cannot be MEAN for integers.
      v = variables_lib.Variable(
          0.,
          synchronization=synchronization,
          aggregation=aggregation,
          trainable=True)
      w = variables_lib.Variable([0., 0., 0.],
                                 synchronization=synchronization,
                                 aggregation=aggregation,
                                 trainable=True)
      if aggregation != variables_lib.VariableAggregation.MEAN:
        y = variables_lib.Variable(
            0, synchronization=synchronization, aggregation=aggregation)

    # NOTE: the expected values below depend on the order of the preceding
    # assignments and scatters; do not reorder these statements.

    # pylint: disable=g-long-lambda

    # tf.Variable properties.
    _test(lambda: self.assertEqual(v.aggregation, aggregation), v)
    _test(lambda: self.assertIs(v.constraint, None), v)
    # TODO(crccw): should we raise an error instead?
    _test(lambda: self.assertEqual(v.device, v._primary.device), v)
    _test(lambda: self.assertEqual(v.dtype, dtypes.float32), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.graph, v._primary.graph), v)
    if not context.executing_eagerly():
      _test(lambda: _assert(v.initial_value == 0), v)
    _test(lambda: self.assertIs(v.initializer, v._primary.initializer), v)
    _test(lambda: self.assertEqual(v.name, "Variable:0"), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.op, v._primary.op), v)
    _test(lambda: self.assertEqual(v.shape, tensor_shape.TensorShape(())), v)
    _test(lambda: self.assertEqual(v.synchronization, synchronization), v)
    _test(lambda: self.assertEqual(v.trainable, True), v)

    # tf.Variable methods.
    _test(lambda: check_ops.assert_equal_v2(v.assign(1.), 1.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_add(1.), 2.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1.), 1.), v)
    # TODO(b/148689177): Implement batch_scatter_update.
    # count_up_to() is skipped since it's deprecated.
    # eval() is skipped since it shouldn't be called in a tf.function.
    # experimental_ref() is skipped since it's deprecated.
    # from_proto() is skipped since it shouldn't be called in a tf.function.
    # TODO(b/148689177): Implement gather_nd.
    _test(
        lambda: check_ops.assert_equal_v2(v.get_shape(),
                                          tensor_shape.TensorShape(())), v)
    # initialized_value() is skipped since it shouldn't be called in a
    # tf.function.
    # load() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1.), v)
    # ref() is skipped since it shouldn't be called in a tf.function.
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_add(_make_index_slices(values=[1., 2.], indices=[0, 2])),
            [1., 0., 2.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_div(_make_index_slices(values=[4., 2.], indices=[0, 2])),
            [0.25, 0., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_max(_make_index_slices(values=[1., 0.5], indices=[1, 2])),
            [0.25, 1., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_min(_make_index_slices(values=[1., 0.5], indices=[0, 1])),
            [0.25, 0.5, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_mul(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [0.5, 0.25, 1.]), w)
    # TODO(b/148689177): Implement scatter_nd_*
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_sub(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [-1.5, -0.25, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_update(
                _make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [2., 0.5, 1.]), w)
    # set_shape() is skipped since ResourceVariable doesn't implement it.
    # to_proto() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.value(), 1.), v)

    # DistributedVariable should be treated as ResourceVariable, so it needs to
    # conform to ResourceVariable interface as well.
    _test(lambda: self.assertIs(v.handle, v._primary.handle), v)

    # Convert to tensor.
    _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1.), v)

    # Control dependency.
    def _with_control_dep():
      with ops.control_dependencies([v.assign(1.)]):
        return array_ops.identity(1)

    _test(_with_control_dep, v)

    # Operator overloads.
    _test(lambda: check_ops.assert_equal_v2(v.assign(7.), 7.), v)
    _test(lambda: check_ops.assert_equal_v2(v + 1., 8.), v)
    _test(lambda: check_ops.assert_equal_v2(3 + v, 10.), v)
    _test(lambda: check_ops.assert_equal_v2(v + v, 14.), v)
    _test(lambda: check_ops.assert_equal_v2(v - 2., 5.), v)
    _test(lambda: check_ops.assert_equal_v2(v - v, 0.), v)
    _test(lambda: check_ops.assert_equal_v2(v * 2., 14.), v)
    _test(lambda: check_ops.assert_equal_v2(3 * v, 21.), v)
    _test(lambda: check_ops.assert_equal_v2(v * v, 49.), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(v / 2., dtypes.float32), 3.5), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(14. / v, dtypes.float32), 2.), v)
    _test(lambda: _assert(v < 12.), v)
    _test(lambda: _assert(v <= 12.), v)
    _test(lambda: _assert(not v > 12.), v)
    _test(lambda: _assert(not v >= 12.), v)
    _test(lambda: _assert(not 12. < v), v)
    _test(lambda: _assert(not 12. <= v), v)
    _test(lambda: _assert(12. > v), v)
    _test(lambda: _assert(12. >= v), v)
    _test(lambda: check_ops.assert_near_v2(pow(v, 3.), 343.), v)
    _test(lambda: check_ops.assert_near_v2(pow(2., v), 128.), v)
    _test(lambda: check_ops.assert_equal_v2(abs(v), 7.), v)

    # Operator overloads that only work for integers.
    if aggregation != variables_lib.VariableAggregation.MEAN:
      _test(lambda: check_ops.assert_equal_v2(y.assign(7), 7), y)
      _test(lambda: check_ops.assert_equal_v2(y // 2, 3), y)
      _test(lambda: check_ops.assert_equal_v2(15 // y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y % 2, 1), y)
      _test(lambda: check_ops.assert_equal_v2(16 % y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y & 3, 3), y)
      _test(lambda: check_ops.assert_equal_v2(3 & y, 3), y)
      _test(lambda: check_ops.assert_equal_v2(y | 8, 15), y)
      _test(lambda: check_ops.assert_equal_v2(16 | y, 23), y)
      _test(lambda: check_ops.assert_equal_v2(y ^ 3, 4), y)
      _test(lambda: check_ops.assert_equal_v2(11 ^ y, 12), y)
      _test(lambda: check_ops.assert_equal_v2(-y, -7), y)
      _test(lambda: check_ops.assert_equal_v2(~y, ~7), y)

    # Index.
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      # TODO(b/161572567): slice assignment doesn't work for TPU.
      _test(lambda: check_ops.assert_equal_v2(w[0], 2.), w)
    else:
      _test(lambda: check_ops.assert_equal_v2(w[0].assign(1.), [1., 0.5, 1.]),
            w)
      _test(lambda: check_ops.assert_equal_v2(w[0], 1.), w)

    # pylint: enable=g-long-lambda

  def testUnsaveable(self, distribution, synchronization, aggregation, mode):
    """Functions that update the variable cannot be saved as signatures."""
    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("n/a: not appliable to AggregatingVariable")
    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")
    if not distribution.extended._use_merge_call():
      self.skipTest("Unsupported combination.")
    with distribution.scope():
      v = variables_lib.Variable([1., 1.],
                                 synchronization=synchronization,
                                 aggregation=aggregation)

    with self.cached_session():
      self.evaluate(variables_lib.global_variables_initializer())

    export_dir = self.get_temp_dir()

    def _assert_unsaveable(f):
      # Ignore if it cannot be traced. Certain combinations are not supported
      # yet or not allowed.
      try:
        f = def_function.function(f).get_concrete_function()
      except (NotImplementedError, ValueError):
        return
      with self.assertRaisesRegex(ValueError, "f_with_input_signature"):
        save.save(v, export_dir, signatures=f)

    _assert_unsaveable(lambda: v.assign(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_add(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_sub(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.scatter_add(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_sub(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_mul(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_div(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_min(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_max(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_update(_make_index_slices([1.], [0])))

    # Reading a ON_READ variable should be unsaveable if either:
    # 1) CollectiveAllReduceStrategy, and aggregation is MEAN/SUM.
    # 2) aggregation is SUM.
    # Strategies that do not use merge_call never reach this point: they are
    # skipped at the top of this test, so that case is not re-checked here.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        (aggregation == variables_lib.VariableAggregation.SUM or
         (isinstance(distribution.extended,
                     collective_all_reduce_strategy.CollectiveAllReduceExtended)
          and aggregation == variables_lib.VariableAggregation.MEAN))):
      _assert_unsaveable(v.read_value)
      _assert_unsaveable(v.value)
      _assert_unsaveable(lambda: ops.convert_to_tensor(v))
    else:
      # Otherwise reading a variable should be saveable.

      @def_function.function
      def f():
        v.read_value()
        v.value()
        return ops.convert_to_tensor(v)

      with self.cached_session():
        save.save(v, export_dir, signatures=f.get_concrete_function())
@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.tpu_strategy,
        ],
        mode=["eager"]))
class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):
  """Tests packed distributed variables in eager mode."""

  def testPackedVariable(self, distribution):
    """Packed variables resolve to the component of the current replica."""
    # Without the flag set by this test, no packed variable is attached.
    with distribution.scope():
      plain_var = variables_lib.Variable(0.)
    self.assertIsNone(plain_var._packed_var)

    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      packed_var = variables_lib.Variable(0)
      self.assertIsInstance(packed_var._packed_var,
                            packed.PackedDistributedVariable)

    devices = packed_var._devices
    # Write the replica id into every component except the first one.
    for replica_id in range(1, len(devices)):
      with distribute_lib.ReplicaContext(distribution, replica_id):
        packed_var.assign(replica_id)

    # Outside of any replica context the first device's component is used.
    cross_replica_val = packed_var._get()
    self.assertIsInstance(cross_replica_val, packed.PackedVarAndDevice)
    self.assertEqual(cross_replica_val.device, devices[0])
    self.assertEqual(self.evaluate(cross_replica_val.read_value()), 0)

    # Inside a replica context the matching device's component is used.
    for replica_id, device in enumerate(devices):
      with distribute_lib.ReplicaContext(distribution, replica_id):
        replica_val = packed_var._get()
        self.assertIsInstance(replica_val, packed.PackedVarAndDevice)
        self.assertEqual(replica_val.device, device)
        self.assertEqual(self.evaluate(replica_val.read_value()), replica_id)

  def testIgnorePackedVariableInSaveContext(self, distribution):
    """`_packed_variable` is None while inside a save context."""
    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      var = variables_lib.Variable(0)
      self.assertIsInstance(var._packed_variable,
                            packed.PackedDistributedVariable)

    default_options = save_options.SaveOptions()
    with save_context.save_context(default_options):
      self.assertIsNone(var._packed_variable)
def _make_index_slices(values, indices, dense_shape=None):
  """Builds an `IndexedSlices` from the given values and indices.

  Args:
    values: Slice values; passed through `array_ops.identity`.
    indices: Slice indices; passed through `array_ops.identity`.
    dense_shape: Optional dense shape. When provided it is also passed through
      `array_ops.identity`; when `None` it is forwarded unchanged.

  Returns:
    An `indexed_slices.IndexedSlices` built from the three arguments.
  """
  # Compare against None explicitly instead of relying on truthiness: the
  # truth value of an array-like or tensor shape can be ambiguous or
  # unsupported, which would make a legitimately provided shape fail here.
  if dense_shape is not None:
    dense_shape = array_ops.identity(dense_shape)
  return indexed_slices.IndexedSlices(
      array_ops.identity(values), array_ops.identity(indices), dense_shape)
if __name__ == "__main__":
  # Script entry point: hand control to the distribute test utility's main().
  ds_test_util.main()