blob: 5dd2c5a3b1ae4afd0f99d4615a80e2baa012fd1a [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""
import itertools
import uuid
from absl.testing import parameterized
from tensorflow.python.checkpoint import checkpoint as trackable_utils
from tensorflow.python.checkpoint import checkpoint_management as ckpt_manager
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.util import variable_utils
def strategy_and_run_tf_function_combinations():
  """Returns combinations of strategies x whether to wrap in a tf.function.

  Currently only TPU strategies are generated; the mirrored/MWMS
  combinations below are commented out pending b/197981388.
  """
  # TODO(b/197981388): re-enable MWMS test
  # return combinations.combine(
  #     distribution=[
  #         strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
  #     ],
  #     mode=["graph", "eager"],
  #     experimental_run_tf_function=[True, False],
  #     use_var_policy=[True, False]) +
  return combinations.combine(
      distribution=[
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ],
      mode=["graph", "eager"],
      experimental_run_tf_function=[True],
      use_var_policy=[True, False])
def strategy_with_var_policy():
  """Returns strategy/mode combinations, each with and without a var policy."""
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          # TODO(b/197981388): re-enable MWMS test
          # strategy_combinations.multi_worker_mirrored_2x1_cpu,
          # strategy_combinations.multi_worker_mirrored_2x1_gpu,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ],
      mode=["graph", "eager"],
      use_var_policy=[True, False])
class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
  """Tests assign/read/save-restore of variables created under a strategy scope.

  Variables here use the default ON_WRITE synchronization (except where a
  synchronization is passed explicitly).
  """

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssign(self, distribution, experimental_run_tf_function):
    """assign/assign_add/assign_sub with scalar values, both contexts."""

    def assign(fn, v, update_value, cross_replica):
      # Runs `getattr(v, fn)(update_value)` directly in cross-replica
      # context, or through strategy.run (optionally via tf.function)
      # in replica context.
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      # assign in replica context with SUM does not make sense cause you can
      # just do value * num replicas error is 1. is not a distributed value and
      # is unsupported for aggregation SUM
      if (not cross_replica and aggregation ==
          variables_lib.VariableAggregation.SUM):
        continue
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(assign(fn, v, update_value, cross_replica))
      # After the update every replica component should hold 1.
      for component in v._values:
        self.assertAllEqual(self.evaluate(component.read_value()),
                            self.evaluate(array_ops.ones_like(component)))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignOnWriteVar(self, distribution, experimental_run_tf_function):
    """Assigning another mirrored variable as the update value."""
    with distribution.scope():
      v_to_assign = variable_v1.VariableV1(
          2., aggregation=variables_lib.VariableAggregation.MEAN)
      v_to_assign_sub = variable_v1.VariableV1(
          -2., aggregation=variables_lib.VariableAggregation.MEAN)

    def assign(fn, v, update_value, cross_replica):
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", v_to_assign), ("assign_add", v_to_assign),
               ("assign_sub", v_to_assign_sub)]
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      # assign in replica context with SUM does not make sense cause you can
      # just do value * num replicas error is 1. is not a distributed value and
      # is unsupported for aggregation SUM
      if aggregation == variables_lib.VariableAggregation.SUM:
        continue
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(assign(fn, v, update_value, cross_replica))
      # All three ops land on 2. (0 + 2, 0 - (-2), or plain assign of 2).
      for component in v._values:
        self.assertAllEqual(2.0, self.evaluate(component.read_value()))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):
    """Assigning PerReplica values to a mirrored variable (replica ctx only)."""
    if strategy_test_lib.is_tpu_strategy(distribution):
      self.skipTest("Assigning PerReplica values is not supported. See"
                    " sponge/80ba41f8-4220-4516-98ce-bbad48f9f11a.")
    with distribution.scope():
      per_replica_value = values.PerReplica(
          [constant_op.constant(2.0),
           constant_op.constant(2.0)])
      per_replica_sub_value = values.PerReplica(
          [constant_op.constant(-2.0),
           constant_op.constant(-2.0)])

    def assign(fn, v, update_value, cross_replica):
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", per_replica_value), ("assign_add", per_replica_value),
               ("assign_sub", per_replica_sub_value)]
    # We don't support assigning PerReplica valus to vars in replica context
    # with aggregation=NONE.
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      if cross_replica:
        # We don't support assigning PerReplica values to MirroredVariables in
        # cross replica context
        continue
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(assign(fn, v, update_value, cross_replica))
      # SUM aggregates the two per-replica 2.0 values into 4.0; MEAN and
      # ONLY_FIRST_REPLICA leave 2.0.
      if aggregation == variables_lib.VariableAggregation.SUM:
        expected = 4.0
      else:
        expected = 2.0
      for component in v._values:
        self.assertAllEqual(expected, self.evaluate(component.read_value()))

  @combinations.generate(strategy_with_var_policy())
  def testValueInReplicaContext(self, distribution):
    """v.value() inside strategy.run reflects a prior in-function update."""
    with distribution.scope():
      v = variables_lib.Variable(
          1., aggregation=variables_lib.VariableAggregation.MEAN)
      self.evaluate(variables_lib.global_variables_initializer())

      @def_function.function
      def f():
        with ops.control_dependencies([v.assign_add(1.)]):
          return v.value()

      results = self.evaluate(
          test_util.gather(distribution, distribution.run(f)))
      for value in results:
        self.assertEqual(2., value)

  @combinations.generate(strategy_with_var_policy())
  def testValueInReplicaContextAssignDirectValue(self, distribution,
                                                 use_var_policy):
    """Same as testValueInReplicaContext, parameterized on use_var_policy."""
    with distribution.scope():
      v = variables_lib.Variable(
          1., aggregation=variables_lib.VariableAggregation.MEAN)
      self.evaluate(variables_lib.global_variables_initializer())

      @def_function.function
      def f():
        with ops.control_dependencies([v.assign_add(1.)]):
          return v.value()

      results = self.evaluate(
          test_util.gather(distribution, distribution.run(f)))
      for value in results:
        self.assertEqual(2., value)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testReadValueInReplicaContext(self, distribution,
                                    experimental_run_tf_function):
    """read_value in replica context returns each replica's component."""
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for aggregation in aggregations:
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      if experimental_run_tf_function:
        read_var_fn = def_function.function(v.read_value)
      else:
        read_var_fn = v.read_value
      results = self.evaluate(
          test_util.gather(distribution, distribution.run(read_var_fn)))
      for component, value in zip(v._values, results):
        self.assertAllEqual(self.evaluate(component.read_value()), value)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testReadValueInCrossReplicaContext(self, distribution,
                                         experimental_run_tf_function):
    """read_value in cross-replica context agrees with every component."""
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for aggregation in aggregations:
      with distribution.scope():
        v = variable_v1.VariableV1(
            2.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      if experimental_run_tf_function:
        read_var_fn = def_function.function(v.read_value)
      else:
        read_var_fn = v.read_value
      results = read_var_fn()
      for component in v._values:
        self.assertEqual(self.evaluate(component.read_value()),
                         self.evaluate(results))

  @combinations.generate(strategy_with_var_policy())
  def testAssignOutOfScope(self, distribution):
    """Assign outside the strategy scope still updates all components."""
    with distribution.scope():
      mirrored = variables_lib.Variable(1.)
    self.evaluate(mirrored.assign(3.))
    self.assertEqual(self.evaluate(mirrored.read_value()), 3.)
    for component in mirrored.values:
      self.assertEqual(self.evaluate(component.read_value()), 3.)

  @combinations.generate(strategy_with_var_policy())
  def testInitializedToSameValueInsideEagerRun(self, distribution):
    """A var created inside strategy.run gets identical initial components."""
    if not context.executing_eagerly(): self.skipTest("eager only test")
    if isinstance(distribution.extended,
                  collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("Test for more than 1 device per worker only.")
    v = [None]

    @def_function.function
    def step():

      def f():
        if v[0] is None:
          v[0] = variables_lib.Variable(random_ops.random_normal([]))

      distribution.run(f)

    context.set_global_seed(None)
    step()
    # Even with a random initializer, both components must match.
    vals = self.evaluate(v[0].values)
    self.assertAllEqual(vals[0], vals[1])

  @combinations.generate(strategy_with_var_policy())
  def testAggregationOnlyFirstReplica(self, distribution):
    """ONLY_FIRST_REPLICA broadcasts replica 0's assigned value everywhere."""
    if isinstance(distribution.extended,
                  collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("b/212945803")
    with distribution.scope():
      v = variable_v1.VariableV1(
          15.,
          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def assign():
      # Each replica tries to assign its own replica id.
      ctx = distribute_lib.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      return v.assign(math_ops.cast(replica_id, dtypes.float32))

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(assign)))
    # The per-replica values should always match the first replicas value.
    self.assertAllEqual(
        array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
        per_replica_results)

  @combinations.generate(strategy_with_var_policy())
  def testInitScope(self, distribution):
    """Variables created under ops.init_scope inside a tf.function work."""
    if not context.executing_eagerly(): self.skipTest("eager only")

    class C(object):
      pass

    obj = C()
    obj.w = None
    obj.v = None

    @def_function.function
    def assign():
      with ops.init_scope():
        if obj.w is None:
          obj.w = variables_lib.Variable(
              0., aggregation=variables_lib.VariableAggregation.MEAN)
          obj.v = variables_lib.Variable(
              obj.w.read_value(),
              aggregation=variables_lib.VariableAggregation.MEAN)
          self.evaluate(variables_lib.global_variables_initializer())
      return obj.v.assign_add(2.)

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(assign)))
    self.assertAllEqual([2., 2.], per_replica_results)

  @combinations.generate(strategy_with_var_policy())
  def testOperatorOverride(self, distribution):
    """Arithmetic operators work on the distributed var in both contexts."""
    if not context.executing_eagerly() and isinstance(
        distribution.extended,
        collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("b/212954197")
    with distribution.scope():
      v = variable_v1.VariableV1(
          1, aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(variables_lib.global_variables_initializer())
    self.assertEqual(2, self.evaluate(v + 1))

    @def_function.function
    def add():
      return v + 1

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(add)))
    self.assertAllEqual([2, 2], per_replica_results)

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          mode=["eager"],
          use_var_policy=[True, False]))
  def testSaveAndRestoreOnWrite(self, strategy):
    """Checkpoint round-trips between mirrored and regular variables."""
    aggregation = [
        variable_scope.VariableAggregation.NONE,
        variable_scope.VariableAggregation.ONLY_FIRST_REPLICA,
        variable_scope.VariableAggregation.SUM,
        variable_scope.VariableAggregation.MEAN
    ]
    for agg in aggregation:
      v_normal_restore = variables_lib.Variable(1.0)
      v_normal_save = variables_lib.Variable(3.0)
      with strategy.scope():
        v_on_write = variables_lib.Variable(2.0, aggregation=agg)

        # Save ONWRITE Restore ONWRITE
        # Save
        ckpt = trackable_utils.Checkpoint(var=v_on_write)
        manager = ckpt_manager.CheckpointManager(
            ckpt, "/tmp/ckpt_" + str(uuid.uuid4()), max_to_keep=None)
        manager.save()
        # Restore
        ckpt.restore(manager.latest_checkpoint)
        self.assertEqual(2.0, self.evaluate(v_on_write._values[0]))
        self.assertEqual(2.0, self.evaluate(v_on_write.read_value()))

        # Save Mirrored Restore Normal
        # We've already saved Mirrored, so we only need to restore normal
        ckpt_normal = trackable_utils.Checkpoint(var=v_normal_restore)
        ckpt_normal.restore(manager.latest_checkpoint)
        self.assertEqual(2.0, self.evaluate(v_on_write._values[0]))
        self.assertEqual(2.0, self.evaluate(v_normal_restore.read_value()))

        # Save Normal Restore Mirrored
        # Save
        ckpt = trackable_utils.Checkpoint(var=v_normal_save)
        manager_2 = ckpt_manager.CheckpointManager(
            ckpt, "/tmp/ckptckpt_" + str(uuid.uuid4()), max_to_keep=None)
        manager_2.save()
        # Restore
        ckpt_on_write = trackable_utils.Checkpoint(var=v_on_write)
        ckpt_on_write.restore(manager_2.latest_checkpoint)
        self.assertEqual(3.0, self.evaluate(v_on_write._values[0]))
        self.assertEqual(3.0, self.evaluate(v_on_write.read_value()))
# Strategy/mode combinations shared by the scatter-op tests below.
ms_combination = combinations.combine(
    distribution=[strategy_combinations.mirrored_strategy_with_gpu_and_cpu],
    mode=["graph", "eager"])
tpu_combination = combinations.combine(
    distribution=[strategy_combinations.tpu_strategy_packed_var],
    mode=["graph", "eager"])
class OnWriteVariableSyncScatterTests(test.TestCase, parameterized.TestCase):
  """Scatter-op tests for variables created under a strategy scope."""

  @combinations.generate(ms_combination)
  def testScatterSub(self, distribution):
    """scatter_sub in replica context, MEAN-aggregated across replicas."""
    with distribution.scope():
      v = variables_lib.Variable(
          [0., 0., 0.], aggregation=variables_lib.VariableAggregation.MEAN)
    self.evaluate(v.initializer)

    @def_function.function
    def scatter_sub():
      ctx = distribute_lib.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      # Replica r subtracts [r, r+1] at indices [r, r+1].
      value = indexed_slices.IndexedSlices(
          values=array_ops_stack.stack([
              math_ops.cast(replica_id, dtypes.float32),
              math_ops.cast(replica_id + 1, dtypes.float32)
          ]),
          indices=array_ops_stack.stack([replica_id, replica_id + 1]),
          dense_shape=(3,))
      return v.scatter_sub(value)

    per_replica_results = self.evaluate(
        distribution.experimental_local_results(
            distribution.run(scatter_sub)))
    self.assertAllEqual([[0., -1., -1.], [0., -1., -1.]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterAdd(self, distribution):
    """scatter_add in replica context, SUM-aggregated across replicas."""
    with distribution.scope():
      v = variables_lib.Variable(
          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(v.initializer)

    @def_function.function
    def scatter_add():
      ctx = distribute_lib.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      # Replica r adds [r, r+1] at indices [r, r+1].
      value = indexed_slices.IndexedSlices(
          values=array_ops_stack.stack([replica_id, replica_id + 1]),
          indices=array_ops_stack.stack([replica_id, replica_id + 1]),
          dense_shape=(3,))
      return v.scatter_add(value)

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(scatter_add)))
    self.assertAllEqual([[0, 2, 2], [0, 2, 2]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterDiv(self, distribution):
    """scatter_div in replica context, SUM-aggregated across replicas."""
    with distribution.scope():
      v = variables_lib.Variable(
          [1, 6, 1], aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(v.initializer)

    @def_function.function
    def scatter_div():
      ctx = distribute_lib.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      # Replica r divides element r by r + 2.
      value = indexed_slices.IndexedSlices(
          values=array_ops.reshape(replica_id + 2, [1]),
          indices=array_ops.reshape(replica_id, [1]),
          dense_shape=(3,))
      return v.scatter_div(value)

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(scatter_div)))
    self.assertAllEqual([[0, 2, 1], [0, 2, 1]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterMul(self, distribution):
    """scatter_mul in replica context, MEAN-aggregated across replicas."""
    with distribution.scope():
      v = variables_lib.Variable(
          [2., 1., 1.], aggregation=variables_lib.VariableAggregation.MEAN)
    self.evaluate(v.initializer)

    @def_function.function
    def scatter_mul():
      ctx = distribute_lib.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      # Replica r multiplies element r by r + 2.
      value = indexed_slices.IndexedSlices(
          values=array_ops.reshape(
              math_ops.cast(replica_id + 2, dtypes.float32), [1]),
          indices=array_ops.reshape(replica_id, [1]),
          dense_shape=(3,))
      return v.scatter_mul(value)

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(scatter_mul)))
    self.assertAllClose([[2., 1.5, 1.], [2., 1.5, 1.]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterMin(self, distribution):
    """scatter_min requires ONLY_FIRST_REPLICA; SUM raises."""
    with distribution.scope():
      v1 = variables_lib.Variable(
          [0, 2, 0], aggregation=variables_lib.VariableAggregation.SUM)
      v2 = variables_lib.Variable(
          [0, 2, 0],
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def scatter_min(v):
      value = indexed_slices.IndexedSlices(
          values=array_ops.identity([1]),
          indices=array_ops.identity([1]),
          dense_shape=(3,))
      return v.scatter_min(value)

    # scatter_min is not implemented for SUM-aggregated variables.
    with self.assertRaisesRegex(NotImplementedError, "scatter_min.*"):
      self.evaluate(
          test_util.gather(distribution,
                           distribution.run(scatter_min, args=(v1,))))

    per_replica_results = self.evaluate(
        test_util.gather(distribution,
                         distribution.run(scatter_min, args=(v2,))))
    self.assertAllClose([[0, 1, 0], [0, 1, 0]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterMax(self, distribution):
    """scatter_max requires ONLY_FIRST_REPLICA; SUM raises."""
    with distribution.scope():
      v1 = variables_lib.Variable(
          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
      v2 = variables_lib.Variable(
          [0, 0, 0],
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def scatter_max(v):
      value = indexed_slices.IndexedSlices(
          values=array_ops.identity([1]),
          indices=array_ops.identity([0]),
          dense_shape=(3,))
      return v.scatter_max(value)

    # scatter_max is not implemented for SUM-aggregated variables.
    with self.assertRaisesRegex(NotImplementedError, "scatter_max.*"):
      self.evaluate(
          test_util.gather(distribution,
                           distribution.run(scatter_max, args=(v1,))))

    per_replica_results = self.evaluate(
        test_util.gather(distribution,
                         distribution.run(scatter_max, args=(v2,))))
    self.assertAllClose([[1, 0, 0], [1, 0, 0]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterUpdate(self, distribution):
    """scatter_update requires ONLY_FIRST_REPLICA; SUM raises."""
    with distribution.scope():
      v1 = variables_lib.Variable(
          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
      v2 = variables_lib.Variable(
          [0, 0, 0],
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def scatter_update(v):
      value = indexed_slices.IndexedSlices(
          values=array_ops.identity([3]),
          indices=array_ops.identity([1]),
          dense_shape=(3,))
      return v.scatter_update(value)

    # scatter_update is not implemented for SUM-aggregated variables.
    with self.assertRaisesRegex(NotImplementedError, "scatter_update.*"):
      self.evaluate(
          test_util.gather(distribution,
                           distribution.run(scatter_update, args=(v1,))))

    per_replica_results = self.evaluate(
        test_util.gather(distribution,
                         distribution.run(scatter_update, args=(v2,))))
    self.assertAllClose([[0, 3, 0], [0, 3, 0]], per_replica_results)

  @combinations.generate(ms_combination + tpu_combination)
  def testScatterOpsWithNoneAggregation(self, distribution):
    """All scatter ops work in replica context with aggregation=NONE."""

    def assert_close(v, op, delta, expect):
      scatter_op = getattr(v, op)

      @def_function.function
      def scatter_xxx():
        return scatter_op(delta)

      per_replica_results = self.evaluate(
          variable_utils.convert_variables_to_tensors(
              distribution.experimental_local_results(
                  distribution.run(scatter_xxx))))
      self.assertAllClose([expect, expect], per_replica_results)

    with distribution.scope():
      v = variables_lib.Variable(
          [4.], aggregation=variables_lib.VariableAggregation.NONE)
    self.evaluate(variables_lib.global_variables_initializer())

    delta = indexed_slices.IndexedSlices(
        values=array_ops.identity([2.]),
        indices=array_ops.identity([0]),
        dense_shape=(1,))
    # Each call starts from the previous call's result.
    assert_close(v, "scatter_sub", delta, [2.])
    assert_close(v, "scatter_add", delta, [4.])
    assert_close(v, "scatter_max", delta, [4.])
    assert_close(v, "scatter_min", delta, [2.])
    assert_close(v, "scatter_mul", delta, [4.])
    assert_close(v, "scatter_div", delta, [2.])
    assert_close(v, "scatter_update", delta, [2.])

  @combinations.generate(ms_combination + tpu_combination)
  def testScatterOpsInCrossReplicaContext(self, distribution):
    """scatter_add/scatter_min work when called in cross-replica context."""
    with distribution.scope():
      v1 = variables_lib.Variable(
          [1, 1, 1], aggregation=variables_lib.VariableAggregation.SUM)
      v2 = variables_lib.Variable([1, 1, 1])
    self.evaluate(variables_lib.global_variables_initializer())

    value = indexed_slices.IndexedSlices(
        values=array_ops.identity([2]),
        indices=array_ops.identity([0]),
        dense_shape=(3,))
    with distribution.scope():
      self.evaluate(v1.scatter_add(value))
      self.assertAllEqual([3, 1, 1], self.evaluate(v1.read_value()))

      self.evaluate(v2.scatter_min(value))
      self.assertAllEqual([1, 1, 1], self.evaluate(v2.read_value()))
class OnReadVariableSyncTest(test.TestCase, parameterized.TestCase):
  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssign(self, distribution, experimental_run_tf_function):
    """assign/assign_add/assign_sub on ON_READ variables, both contexts."""

    def assign(fn, v, update_value, cross_replica):
      # Runs `getattr(v, fn)(update_value)` directly in cross-replica
      # context, or through strategy.run (optionally via tf.function)
      # in replica context.
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      # VariableAggregation.SUM in cross-replica mode is tested below,
      # VariableAggregation.NONE in cross-replica mode is not supported.
      if cross_replica and aggregation in [
          variables_lib.VariableAggregation.SUM,
          variables_lib.VariableAggregation.NONE,
      ]:
        continue
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(assign(fn, v, update_value, cross_replica))
      # After the update every replica component should hold 1.
      for component in v._values:
        self.assertAllEqual(self.evaluate(component.read_value()),
                            self.evaluate(array_ops.ones_like(component)))
  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignOnReadVar(self, distribution, experimental_run_tf_function):
    """Assigning a MEAN-aggregated variable as the update value."""
    with distribution.scope():
      v_to_assign = variable_v1.VariableV1(
          2., aggregation=variables_lib.VariableAggregation.MEAN)
      v_to_assign_sub = variable_v1.VariableV1(
          -2., aggregation=variables_lib.VariableAggregation.MEAN)

    def assign(fn, v, update_value, cross_replica):
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", v_to_assign), ("assign_add", v_to_assign),
               ("assign_sub", v_to_assign_sub)]
    # Per-component expectations keyed by aggregation, for cross-replica vs
    # replica-context updates. (SUM entries are unused while SUM is skipped
    # in the loop below.)
    expected_cross_replica = {
        variables_lib.VariableAggregation.SUM: 1.0,
        variables_lib.VariableAggregation.MEAN: 2.0,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA: 2.0
    }
    expected_replica = {
        variables_lib.VariableAggregation.SUM: 2.0,
        variables_lib.VariableAggregation.MEAN: 2.0,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA: 2.0
    }
    # aggregation=NONE is not supported for OnReadVariables.
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      # assign in replica context with SUM does not make sense cause you can
      # just do value * num replicas error is 1. is not a distributed value and
      # is unsupported for aggregation SUM
      if aggregation == variables_lib.VariableAggregation.SUM:
        continue
      with distribution.scope():
        # NOTE(review): unlike the other tests in this ON_READ class, this
        # variable is created without synchronization=ON_READ — confirm
        # whether that is intentional.
        v = variable_v1.VariableV1(
            0.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(assign(fn, v, update_value, cross_replica))
      if cross_replica:
        for component in v._values:
          self.assertAllEqual(expected_cross_replica.get(aggregation),
                              self.evaluate(component.read_value()))
      else:
        for component in v._values:
          self.assertAllEqual(expected_replica.get(aggregation),
                              self.evaluate(component.read_value()))
  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):
    """Assigning PerReplica values to ON_READ variables.

    Currently disabled: the unconditional skipTest below turns this test off
    for all strategies (see the referenced sponge link).
    """
    if strategy_test_lib.is_tpu_strategy(distribution):
      self.skipTest("Assigning PerReplica values is not supported. See"
                    " sponge/80ba41f8-4220-4516-98ce-bbad48f9f11a.")
    self.skipTest("We don't support assiging PerReplica values in cross "
                  "replica context or replica context. see error in "
                  "sponge/2b2e54c1-eda6-4534-82e1-c73b1dcd517f.")
    with distribution.scope():
      per_replica_value = values.PerReplica(
          [constant_op.constant(2.0),
           constant_op.constant(2.0)])

    def assign(fn, v, update_value, cross_replica):
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", per_replica_value)]
    # We don't support assigning PerReplica valus to vars in replica context
    # with aggregation=NONE.
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      # with self.assertRaisesRegex(ValueError, "Attempt to convert a value "):
      self.evaluate(assign(fn, v, update_value, cross_replica))
      # SUM aggregates the two per-replica 2.0 values into 4.0.
      if aggregation == variables_lib.VariableAggregation.SUM:
        expected = 4.0
      else:
        expected = 2.0
      for component in v._values:
        self.assertAllEqual(expected, self.evaluate(component.read_value()))
  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignDtypeConversion(self, distribution,
                                experimental_run_tf_function):
    """Integer update values are converted to the variable's float dtype."""

    def assign(fn, v, update_value, cross_replica):
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    # Python ints here, assigned to a float variable below.
    updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      # VariableAggregation.SUM in cross-replica mode is tested below,
      # VariableAggregation.NONE in cross-replica mode is not supported.
      if cross_replica and aggregation in [
          variables_lib.VariableAggregation.SUM,
          variables_lib.VariableAggregation.NONE,
      ]:
        continue
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(assign(fn, v, update_value, cross_replica))
      # After the update every replica component should hold 1.
      for component in v._values:
        self.assertAllEqual(self.evaluate(component.read_value()),
                            self.evaluate(array_ops.ones_like(component)))
@combinations.generate(strategy_with_var_policy())
def testAssignWithAggregationSum(self, distribution):
with distribution.scope():
v = variable_v1.VariableV1(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=variables_lib.VariableAggregation.SUM)
self.evaluate(variables_lib.global_variables_initializer())
self.evaluate(v.assign(1. * distribution.num_replicas_in_sync))
for component in v._values:
self.assertAllEqual(self.evaluate(component.read_value()),
self.evaluate(array_ops.ones_like(component)))
@combinations.generate(strategy_with_var_policy())
def testAssignAddSubWithAggregationSum(self, distribution):
with distribution.scope():
v = variable_v1.VariableV1(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=variables_lib.VariableAggregation.SUM)
self.evaluate(variables_lib.global_variables_initializer())
with self.assertRaisesRegex(
ValueError, "SyncOnReadVariable does not support "):
self.evaluate(v.assign_add(1.))
with self.assertRaisesRegex(
ValueError, "SyncOnReadVariable does not support "):
self.evaluate(v.assign_sub(1.))
@combinations.generate(strategy_and_run_tf_function_combinations())
def testReadValueInReplicaContext(self, distribution,
experimental_run_tf_function):
aggregations = [
variables_lib.VariableAggregation.NONE,
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
]
for aggregation in aggregations:
with distribution.scope():
v = variable_v1.VariableV1(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(variables_lib.global_variables_initializer())
if experimental_run_tf_function:
read_var_fn = def_function.function(v.read_value)
else:
read_var_fn = v.read_value
results = self.evaluate(
test_util.gather(distribution, distribution.run(read_var_fn)))
for component, value in zip(v._values, results):
self.assertAllEqual(self.evaluate(component.read_value()), value)
@combinations.generate(strategy_and_run_tf_function_combinations())
def testReadValueInCrossReplicaContext(self, distribution,
experimental_run_tf_function):
aggregations = [
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
]
for aggregation in aggregations:
if strategy_test_lib.is_tpu_strategy(distribution):
resolver = tpu_cluster_resolver.TPUClusterResolver("")
tpu_cluster_resolver.initialize_tpu_system(resolver)
with distribution.scope():
v = variable_v1.VariableV1(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(variables_lib.global_variables_initializer())
def assign(v=v):
ctx = distribute_lib.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
return v.assign(math_ops.cast(replica_id, dtypes.float32))
if experimental_run_tf_function:
assign = def_function.function(assign)
self.evaluate(test_util.gather(distribution, distribution.run(assign)))
num_replicas = distribution.num_replicas_in_sync
sum_of_replica_values = num_replicas * (num_replicas - 1) / 2.
if aggregation == variables_lib.VariableAggregation.SUM:
expected = sum_of_replica_values
elif aggregation == variables_lib.VariableAggregation.MEAN:
expected = sum_of_replica_values / num_replicas
else:
expected = 0
self.assertEqual(expected, self.evaluate(v.read_value()), aggregation)
self.assertEqual(expected, self.evaluate(v.value()), aggregation)
self.assertEqual(expected, self.evaluate(v), aggregation)
self.assertEqual(expected, self.evaluate(array_ops.identity(v)),
aggregation)
@combinations.generate(strategy_and_run_tf_function_combinations())
def testAllReduce(self, distribution, experimental_run_tf_function):
with distribution.scope():
v = variable_v1.VariableV1(
2.,
synchronization=variables_lib.VariableSynchronization.ON_WRITE,
aggregation=variables_lib.VariableAggregation.MEAN)
self.evaluate(variables_lib.global_variables_initializer())
def all_reduce():
ctx = distribute_lib.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
return ctx.all_reduce("SUM", v) + math_ops.cast(replica_id,
dtypes.float32)
if experimental_run_tf_function:
all_reduce = def_function.function(all_reduce)
per_replica_results = self.evaluate(
test_util.gather(distribution, distribution.run(all_reduce)))
expected_result = []
for i in range(distribution.num_replicas_in_sync):
expected_result.append(2.0 * distribution.num_replicas_in_sync +
1.0 * i)
self.assertAllEqual(per_replica_results, tuple(expected_result))
@combinations.generate(strategy_and_run_tf_function_combinations())
def testAssignPerReplicaBeforeRead(self, distribution,
experimental_run_tf_function):
aggregations = [
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
]
for aggregation in aggregations:
with distribution.scope():
v = variable_v1.VariableV1(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(variables_lib.global_variables_initializer())
def assign(var=v):
ctx = distribute_lib.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
return var.assign(math_ops.cast(replica_id, dtypes.float32))
if experimental_run_tf_function:
assign = def_function.function(assign)
per_replica_results = self.evaluate(
test_util.gather(distribution, distribution.run(assign)))
expected_result = []
for i in range(distribution.num_replicas_in_sync):
expected_result.append(1.0 * i)
self.assertAllEqual(per_replica_results, tuple(expected_result))
@combinations.generate(strategy_with_var_policy())
def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
with distribution.scope():
v = variable_v1.VariableV1(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=variables_lib.VariableAggregation.NONE)
self.evaluate(variables_lib.global_variables_initializer())
with self.assertRaisesRegex(
ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
self.evaluate(v.read_value())
  @combinations.generate(strategy_with_var_policy())
  def testInitializedToSameValueInsideEagerRun(self, distribution):
    # Verifies that an ON_READ variable created lazily inside `strategy.run`
    # within a tf.function receives the same random initial value on every
    # replica (the assertion below compares the first two components).
    if not context.executing_eagerly(): self.skipTest("eager only")
    if isinstance(distribution.extended,
                  collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("Test for more than 1 device per worker only.")
    # One-element list so the replica fn can rebind the variable it creates.
    v = [None]

    @def_function.function
    def step():

      def f():
        # Create the variable only on the first call.
        if v[0] is None:
          v[0] = variables_lib.Variable(
              random_ops.random_normal([]),
              synchronization=variables_lib.VariableSynchronization.ON_READ)

      distribution.run(f)

    # Clear the global seed so the initializer is genuinely random.
    context.set_global_seed(None)
    step()
    vals = self.evaluate(v[0].values)
    self.assertAllEqual(vals[0], vals[1])
@combinations.generate(strategy_with_var_policy())
def testOperatorOverride(self, distribution):
with distribution.scope():
v = variable_v1.VariableV1(
0.0,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=variables_lib.VariableAggregation.MEAN)
self.evaluate(variables_lib.global_variables_initializer())
@def_function.function
def assign():
ctx = distribute_lib.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
return v.assign(math_ops.cast(replica_id, dtypes.float32))
# Assign different replicas with different values.
self.evaluate(test_util.gather(distribution, distribution.run(assign)))
self.assertEqual(1.5, self.evaluate(v + 1))
@def_function.function
def add():
return v + 1
per_replica_results = self.evaluate(
test_util.gather(distribution, distribution.run(add)))
self.assertAllEqual([1, 2], per_replica_results)
  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          mode=["eager"],
          use_var_policy=[True, False]))
  def testSaveAndRestoreOnRead(self, strategy):
    """Round-trips ON_READ variables through checkpoints for SUM and MEAN.

    Covers three pairings: save ON_READ / restore ON_READ, save ON_READ /
    restore a normal variable, and save a normal variable / restore ON_READ.
    """
    aggregation = [variable_scope.VariableAggregation.SUM,
                   variable_scope.VariableAggregation.MEAN]
    for agg in aggregation:
      v_normal_restore = variables_lib.Variable(1.0)
      v_normal_save = variables_lib.Variable(2.0)
      with strategy.scope():
        v_on_read = variables_lib.Variable(
            1.0, synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=agg)

      @def_function.function
      def assign_fn():
        # Write different values (3. and 4.) to the two replica components
        # (or per worker task) so SUM and MEAN checkpoints differ.
        cluster_resolver = strategy.cluster_resolver
        replica_ctx = distribute_lib.get_replica_context()
        if ((cluster_resolver and cluster_resolver.task_type == "worker") or
            math_ops.equal(replica_ctx.replica_id_in_sync_group,
                           constant_op.constant(1))):
          v_on_read.assign(3.)  # pylint:disable=cell-var-from-loop
        else:
          v_on_read.assign(4.)  # pylint:disable=cell-var-from-loop

      strategy.run(assign_fn)

      # Save ONREAD, restore ONREAD
      # Saves v[0] + v[1] = 7 for SUM and 3.5 for MEAN.
      ckpt = trackable_utils.Checkpoint(var=v_on_read)
      manager = ckpt_manager.CheckpointManager(
          ckpt, "/tmp/ckpt_" + str(uuid.uuid4()), max_to_keep=None)
      manager.save()
      # Restores a value of 7/2 = 3.5 for SUM and 3.5 for MEAN.
      ckpt.restore(manager.latest_checkpoint)
      self.assertEqual(3.5, self.evaluate(v_on_read._values[0]))

      # Save ONREAD, restore normal
      ckpt_normal = trackable_utils.Checkpoint(var=v_normal_restore)
      ckpt_normal.restore(manager.latest_checkpoint)
      if agg == variable_scope.VariableAggregation.SUM:
        self.assertEqual(7.0, self.evaluate(v_normal_restore.read_value()))
      else:
        self.assertEqual(3.5, self.evaluate(v_normal_restore.read_value()))

      # Save normal, restore ONREAD
      ckpt = trackable_utils.Checkpoint(var=v_normal_save)
      manager = ckpt_manager.CheckpointManager(
          ckpt, "/tmp/ckpt_" + str(uuid.uuid4()), max_to_keep=None)
      manager.save()
      # Restores a value of 2/2 = 1.0 for SUM and 2.0 for MEAN.
      ckpt_on_read = trackable_utils.Checkpoint(var=v_on_read)
      ckpt_on_read.restore(manager.latest_checkpoint)
      if agg == variable_scope.VariableAggregation.SUM:
        self.assertEqual(1.0, self.evaluate(v_on_read._values[0]))
      else:
        self.assertEqual(2.0, self.evaluate(v_on_read._values[0]))
@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
        aggregation=[
            variables_lib.VariableAggregation.MEAN,
            variables_lib.VariableAggregation.SUM,
            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
        ],
        mode=["graph", "eager"],
        use_var_policy=[True, False]))
class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
  """Verifies scatter_* ops are unsupported on ON_READ variables.

  Every test builds a 3-element ON_READ variable, a PerReplica of
  IndexedSlices updates, and asserts the scatter op raises
  NotImplementedError when run per-replica.
  """

  def _check_scatter_raises(self, distribution, aggregation, initial,
                            replica0_values, replica1_values, op_name):
    # Shared driver: creates the variable, builds the per-replica sparse
    # delta, and asserts NotImplementedError from the named scatter op.
    with distribution.scope():
      v = variables_lib.Variable(
          initial,
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=aggregation)
    self.evaluate(v.initializer)
    delta = values.PerReplica([
        indexed_slices.IndexedSlices(
            values=replica0_values, indices=[0, 1], dense_shape=(3,)),
        indexed_slices.IndexedSlices(
            values=replica1_values, indices=[1, 2], dense_shape=(3,)),
    ])
    with self.assertRaises(NotImplementedError):
      self.evaluate(distribution.run(getattr(v, op_name), args=(delta,)))

  def testScatterSub(self, distribution, aggregation):
    self._check_scatter_raises(distribution, aggregation, [1., 1., 1.],
                               [[0.], [1.]], [[1.], [2.]], "scatter_sub")

  def testScatterAdd(self, distribution, aggregation):
    self._check_scatter_raises(distribution, aggregation, [1., 1., 1.],
                               [[0.], [1.]], [[1.], [2.]], "scatter_add")

  def testScatterDiv(self, distribution, aggregation):
    self._check_scatter_raises(distribution, aggregation, [2., 6., 1.],
                               [[2.], [2.]], [[3.], [3.]], "scatter_div")

  def testScatterMul(self, distribution, aggregation):
    self._check_scatter_raises(distribution, aggregation, [2., 1., 1.],
                               [[2.], [3.]], [[4.], [5.]], "scatter_mul")

  def testScatterMin(self, distribution, aggregation):
    self._check_scatter_raises(distribution, aggregation, [3., 4., 5.],
                               [[1.], [8.]], [[9.], [2.]], "scatter_min")

  def testScatterMax(self, distribution, aggregation):
    self._check_scatter_raises(distribution, aggregation, [3., 4., 5.],
                               [[1.], [8.]], [[9.], [2.]], "scatter_max")

  def testScatterUpdate(self, distribution, aggregation):
    # Bug fix: this test previously called `v.scatter_min`, so
    # `scatter_update` was never exercised.
    self._check_scatter_raises(distribution, aggregation, [0., 0., 0.],
                               [[1.], [2.]], [[3.], [4.]], "scatter_update")
if __name__ == "__main__":
  # Entry point: delegates to the distribute test harness's main.
  test_util.main()