| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for the distributed values library.""" |
| |
| import copy |
| import os |
| |
| from absl.testing import parameterized |
| import numpy as np |
| |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.python import tf2 |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.distribute import combinations |
| from tensorflow.python.distribute import strategy_combinations |
| from tensorflow.python.distribute import test_util as ds_test_util |
| from tensorflow.python.distribute import tpu_strategy |
| from tensorflow.python.distribute import tpu_values |
| from tensorflow.python.distribute import values as values_lib |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.eager import test |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import sparse_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables as variables_lib |
| from tensorflow.python.training import saver as saver_lib |
| |
| |
| def _device_str(d): |
| return "/device:GPU:" + str(d) |
| |
| |
| def _nested_value(d): |
| return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d) |
| |
| |
def mirrored_and_tpu_strategy_combinations():
  """Returns mirrored and TPU strategy combinations in graph and eager modes."""
  strategies = [
      strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
      strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
      strategy_combinations.tpu_strategy,
      strategy_combinations.tpu_strategy_packed_var,
      strategy_combinations.tpu_strategy_spmd,
  ]
  modes = ["graph", "eager"]
  return combinations.combine(distribution=strategies, mode=modes)
| |
| |
class DistributedValuesTest(test.TestCase, parameterized.TestCase):
  """Tests for `experimental_distribute_values_from_function`."""

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueFromTensor(self, distribution):
    """Every replica receives the same scalar tensor returned by value_fn."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    single_value = constant_op.constant(1)
    def value_fn(ctx):
      del ctx
      return single_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    # Gathering yields one element per replica; pass the shape as a real
    # 1-tuple rather than a parenthesized int.
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values),
        constant_op.constant(1., shape=(distribution.num_replicas_in_sync,)))

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
    """A numpy array returned by value_fn is replicated to every replica."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    array_value = np.array([1., 2., 3.])
    def value_fn(ctx):
      del ctx
      return array_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values).numpy(),
        [[1., 2., 3.]] * distribution.num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueTupleConstant(self, distribution):
    """A constant tuple returned by value_fn is replicated element-wise."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      del ctx
      return tuple_value
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Each tuple element becomes a list with one copy per replica, e.g. for
    # 2 replicas: ([1.0, 1.0], [2.0, 2.0], [3.0, 3.0]).
    num_replicas = distribution.num_replicas_in_sync
    expected = tuple([v] * num_replicas for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
    """value_fn may return a per-replica nested structure built from ctx."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      per_replica = []
      for val in tuple_value:
        per_replica.append(val * ctx.replica_id_in_sync_group)
      return tuple(per_replica)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([0.0, 1.0], [0.0, 2.0], [0.0, 3.0])
    replica_ids = range(distribution.num_replicas_in_sync)
    expected = tuple([v * replica_id for replica_id in replica_ids]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  # NOTE(priyag): Cannot test this with MultiWorkerMirroredStrategy because
  # collective ops do not support SparseTensors.
  # NOTE: "Spare" in the method name is a typo for "Sparse"; the name is left
  # unchanged.
  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies_minus_default,
          mode=["eager"]
      ))
  def testMakeDistributedValueSpareTensor(self, distribution):
    """A SparseTensor returned by value_fn is available on every replica."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    def value_fn(ctx):
      del ctx
      return sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    local_results = distribution.experimental_local_results(distributed_values)
    for i in range(distribution.num_replicas_in_sync):
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(local_results[i]),
          [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueExtractFromArray(self, distribution):
    """value_fn may index a Python sequence by the replica id."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    multiple_values = range(distribution.num_replicas_in_sync)
    def value_fn(ctx):
      return multiple_values[ctx.replica_id_in_sync_group]
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)
    expected = list(range(distribution.num_replicas_in_sync))
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueAndRun(self, distribution):
    """Distributed values can be created and consumed inside a tf.function."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    @def_function.function
    def run():
      multiple_values = range(distribution.num_replicas_in_sync)
      def value_fn(ctx):
        return multiple_values[ctx.replica_id_in_sync_group]
      distributed_values = (
          distribution.experimental_distribute_values_from_function(value_fn))

      def computation(x):
        return math_ops.square(x)

      outputs = ds_test_util.gather(
          distribution,
          distribution.run(computation, args=(distributed_values,)))
      return outputs

    results = run()

    expected = [i**2 for i in range(distribution.num_replicas_in_sync)]
    self.assertAllEqual(results, expected)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"]))
  def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
    """Values created without a device scope land on the default device."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    def value_fn(ctx):
      del ctx
      return constant_op.constant(1.0)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    default_device = array_ops.identity(constant_op.constant(1.0)).device
    num_local_devices = len(distribution.extended.worker_devices)
    for i in range(num_local_devices):
      self.assertAllEqual(distributed_values._values[i].device, default_device)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"],
          op_type=[constant_op.constant, array_ops.identity]))
  def testMakeDistributedValueExplicitDevicePlacement(self, distribution,
                                                      op_type):
    """Values created under an explicit device scope keep that device."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    worker_devices = distribution.extended.worker_devices
    def value_fn(ctx):
      # In multi client setup, worker_devices is just the devices on that
      # worker.
      worker_device_id = ctx.replica_id_in_sync_group % len(worker_devices)
      with ops.device(worker_devices[worker_device_id]):
        return op_type(1.0)

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    for i, device in enumerate(worker_devices):
      self.assertAllEqual(distributed_values._values[i].device, device)
| |
| |
class PerReplicaTest(test.TestCase, parameterized.TestCase):
  """Tests for per-replica values produced by distribute_values_from_function."""

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"]))
  def testUsePerReplicaInvalidContextGivesError(self, distribution):
    """Casting a per-replica value outside a replica context raises."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    replica_values = range(distribution.num_replicas_in_sync)

    def pick_for_replica(ctx):
      return replica_values[ctx.replica_id_in_sync_group]

    per_replica = distribution.experimental_distribute_values_from_function(
        pick_for_replica)
    # We are in cross-replica context here, so the op must be rejected.
    with self.assertRaisesRegex(ValueError, "not inside a replica context"):
      math_ops.cast(per_replica, dtypes.float32)
| |
| |
class PerWorkerResourceTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(dataset_fn_as_tf_function=[True, False]))
  def testMapFnTracing(self, dataset_fn_as_tf_function):
    """Counts how often `map_fn` is traced while building input pipelines.

    For a PerWorkerResource to behave correctly when used in `dataset.map`,
    `map_fn` must not be traced only once and then reused across pipelines:
    `PerWorkerResource.local_table` relies on each pipeline built outside a
    `tf.function` tracing `map_fn` again to return the correct resource.
    This test can detect a breakage of that behavior on TAP.
    """
    self._trace_count = 0

    def map_fn(x):
      # Python side effects only run while the function is traced, so this
      # counts traces.
      self._trace_count += 1
      return x

    def dataset_fn():
      batched = dataset_ops.DatasetV2.from_tensors([0, 1, 2]).repeat().batch(
          2, drop_remainder=True)
      return batched.map(map_fn)

    num_pipelines = 5
    # Wrapped in a tf.function, dataset_fn (and thus map_fn) is traced once;
    # otherwise every pipeline construction traces map_fn again.
    expected_traces = 1 if dataset_fn_as_tf_function else num_pipelines
    build_pipeline = (def_function.function(dataset_fn)
                      if dataset_fn_as_tf_function else dataset_fn)

    datasets = []
    for _ in range(num_pipelines):
      datasets.append(build_pipeline())

    self.assertEqual(self._trace_count, expected_traces)
| |
| |
class DistributedDelegateTest(test.TestCase):
  """Checks that `DistributedDelegate` forwards to its first component."""

  class _Holder(object):
    """Tiny value object exposing a single attribute `x`."""

    def __init__(self, x):
      self.x = x

  @test_util.run_in_graph_and_eager_modes
  def testGetAttr(self):
    delegate = values_lib.DistributedDelegate(
        (self._Holder(7), self._Holder(8)))
    # Existing attributes are read from the first component ...
    self.assertEqual(7, delegate.x)
    # ... and missing ones still raise AttributeError.
    with self.assertRaises(AttributeError):
      _ = delegate.y

  @test_util.run_in_graph_and_eager_modes
  def testOperatorOverride(self):
    v = values_lib.DistributedDelegate((7, 8))
    # v should act like int(7) under every operator below. Each case is a
    # thunk so the operation runs right before its own assertion.
    arithmetic_cases = [
        (8, lambda: v + 1),
        (10, lambda: 3 + v),
        (14, lambda: v + v),
        (5, lambda: v - 2),
        (6, lambda: 13 - v),
        (0, lambda: v - v),
        (14, lambda: v * 2),
        (21, lambda: 3 * v),
        (49, lambda: v * v),
        (3.5, lambda: v / 2),
        (1.5, lambda: 10.5 / v),
        (3, lambda: v // 2),
        (2, lambda: 15 // v),
        (1, lambda: v % 2),
        (2, lambda: 16 % v),
    ]
    for expected, compute in arithmetic_cases:
      self.assertEqual(expected, compute())

    # pylint: disable=g-generic-assert
    self.assertTrue(v < 12)
    self.assertTrue(v <= 12)
    self.assertFalse(v > 12)
    self.assertFalse(v >= 12)
    self.assertFalse(12 < v)
    self.assertFalse(12 <= v)
    self.assertTrue(12 > v)
    self.assertTrue(12 >= v)
    # pylint: enable=g-generic-assert

    bitwise_power_unary_cases = [
        (3, lambda: v & 3),
        (3, lambda: 11 & v),
        (15, lambda: v | 8),
        (23, lambda: 16 | v),
        (4, lambda: v ^ 3),
        (12, lambda: 11 ^ v),
        (343, lambda: pow(v, 3)),
        (3, lambda: pow(v, 3, 10)),
        (128, lambda: pow(2, v)),
        (-7, lambda: -v),
        (~7, lambda: ~v),
        (7, lambda: abs(v)),
    ]
    for expected, compute in bitwise_power_unary_cases:
      self.assertEqual(expected, compute())

    # Indexing is not delegated to an int, so it must fail.
    with self.assertRaises(TypeError):
      _ = v[2]

  @test_util.run_in_graph_and_eager_modes
  def testCopy(self):
    delegate = values_lib.DistributedDelegate(
        (self._Holder(7), self._Holder(8)))
    # Shallow copies first, then deep copies, must keep forwarding `x`.
    for make_copy in (copy.copy, copy.deepcopy):
      self.assertEqual(delegate.x, make_copy(delegate).x)
| |
| |
# Strategy classes for which `_make_replica_local` wraps the component
# variables in `tpu_values.TPUSyncOnReadVariable` instead of
# `values_lib.SyncOnReadVariable`.
_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)
| |
| |
def _make_replica_local(method, strategy=None):
  """Builds component variables and wraps them in a sync-on-read variable.

  Args:
    method: the aggregation passed to the sync-on-read wrapper.
    strategy: optional distribution strategy. When None, the components are
      placed on "/device:GPU:0" and "/device:CPU:0".

  Returns:
    A `(components, replica_local)` pair: the list of component resource
    variables (at most two, named "v" and "v/replica" and initialized to 1.
    and 2.) and the sync-on-read variable wrapping them.
  """
  if strategy is None:
    devices = ("/device:GPU:0", "/device:CPU:0")
  else:
    devices = strategy.extended.worker_devices

  components = []
  for device, var_name, initial in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(device):
      components.append(
          variable_scope.get_variable(
              name=var_name, initializer=initial, use_resource=True))

  uses_tpu = (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES)
  wrapper_cls = (tpu_values.TPUSyncOnReadVariable if uses_tpu
                 else values_lib.SyncOnReadVariable)
  wrapped = wrapper_cls(strategy, components, method)
  return components, wrapped
| |
| |
class DistributedVariableTest(test.TestCase, parameterized.TestCase):
  """Tests for replica-local (sync-on-read) distributed variables."""

  def _assign_replica_local(self, v, new):
    """Assigns `new[i]` to component `v[i]` on that component's device."""
    for var, n in zip(v, new):
      with ops.device(var.device):
        self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
    """Saves `var` under the test temp dir; returns `(save_path, saver)`."""
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    """Like `_save_return_saver`, but returns only the save path."""
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")
    v, replica_local = _make_replica_local(
        variable_scope.VariableAggregation.SUM)

    # The wrapper exposes the properties of its first component.
    self.assertEqual(v[0].constraint, replica_local.constraint)
    self.assertEqual(v[0].name, replica_local.name)
    self.assertEqual(v[0].dtype, replica_local.dtype)
    self.assertEqual(v[0].shape, replica_local.shape)
    self.assertEqual(variable_scope.VariableAggregation.SUM,
                     replica_local.aggregation)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ],
          mode=["eager"]))
  def testCanPassToDefFun(self, distribution):

    @def_function.function
    def add1(x):
      return x + 1.

    with distribution.scope():
      v = variables_lib.Variable(
          1.,
          aggregation=variables_lib.VariableAggregation.MEAN,
          synchronization=variables_lib.VariableSynchronization.ON_READ)

    self.assertEqual(2., self.evaluate(add1(v)))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testTensorConversion(self, distribution):
    with context.graph_mode():
      _, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)
      converted = ops.convert_to_tensor(replica_local, as_ref=False)
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

      converted = ops.convert_to_tensor(replica_local, as_ref=True)
      # Resource variables are converted to tensors as well when as_ref is
      # True.
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

  @combinations.generate(combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ], mode=["eager"]))
  def testValueInCrossReplicaContext(self, distribution):
    value_list, replica_local = _make_replica_local(
        variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, distribution)

    # With ONLY_FIRST_REPLICA, the cross-replica value is the first
    # component's value.
    self.assertIsInstance(replica_local.value(), ops.Tensor)
    self.assertEqual(self.evaluate(replica_local.value()),
                     self.evaluate(value_list[0].value()))

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testValueInDefaultReplicaContext(self, distribution):
    with distribution.scope():
      v1 = variables_lib.Variable(
          0.0,
          aggregation=variables_lib.VariableAggregation.SUM,
          synchronization=variables_lib.VariableSynchronization.ON_READ)
      v2 = variables_lib.Variable(
          0.0,
          aggregation=variables_lib.VariableAggregation.SUM,
          synchronization=variables_lib.VariableSynchronization.ON_READ)

    @def_function.function
    def replica_fn():
      v1.assign_add(1.0)
      v2.assign_add(2.0)

    distribution.run(replica_fn)
    # Every replica adds 1.0 to v1 and 2.0 to v2; reading the SUM-aggregated
    # variables outside a replica sums over replicas, so the total scales
    # with the replica count instead of assuming exactly two replicas.
    sum_v = v1 + v2
    self.assertEqual(sum_v, 3.0 * distribution.num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testValueInFunctionCrossReplicaContext(self, distribution):
    with distribution.scope():
      v1 = variables_lib.Variable(
          0.0,
          aggregation=variables_lib.VariableAggregation.NONE,
          synchronization=variables_lib.VariableSynchronization.ON_WRITE)

    @def_function.function
    def assign_fn():
      v1.assign(1.0)

    assign_fn()
    self.assertEqual(v1, 1.0)

    # Make sure the function graph has composite variable as inputs.
    graph_def = assign_fn.get_concrete_function().graph.as_graph_def()
    self.assertRegex(str(graph_def), "device:COMPOSITE:0")

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testReplicatedValueNameDeterministic(self, distribution):
    with distribution.scope():
      v1 = variables_lib.Variable(0.0, name="test_var_1")
      v2 = variables_lib.Variable(0.0, name="test_var_2")

    def fn():
      v1.assign_add(1.0)
      v2.assign_add(2.0)
      return v1 + v2

    @def_function.function
    def dist_run_fn():
      a = distribution.run(fn)
      return a

    concrete_fn = dist_run_fn.get_concrete_function()
    inputs = concrete_fn.graph.inputs
    self.assertLen(inputs, 2)
    # Before cl/433948982, input name will include a non-deterministic uid,
    # e.g. "test_var_1_139726389910864/handle/inputs_0:0"
    self.assertEqual(inputs[0].name, "test_var_1/handle/inputs_0:0")
    self.assertEqual(inputs[1].name, "test_var_2/handle/inputs_0:0")

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
    with self.cached_session() as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of v[0] + v[1], 7.
        save_path, saver = self._save_return_saver(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])

        # Restores the saved value of 7. which gets divided equally
        # between the variables.
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")

    with self.cached_session() as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of (v[0] + v[1])/2, 3.5.
        save_path, saver = self._save_return_saver(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])

        # Restores the saved value of 3.5 to both variables.
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  def _save_replica_local(self, distribution, aggregation, values):
    """Saves a replica-local variable from a fresh graph; returns save_path.

    The components are set to `values` before saving and overwritten with
    [5., 6.] afterwards, so a later successful restore is observable.
    """
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(aggregation, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, values)

      with distribution.scope():
        save_path = self._save(sess, replica_local)

      # Change the values between save and restore.
      self._assign_replica_local(v, [5., 6.])
    return save_path

  def _save_replica_local_mean(self, distribution):
    """Saves a MEAN replica-local variable whose value is (3. + 4.)/2 = 3.5."""
    return self._save_replica_local(
        distribution, variable_scope.VariableAggregation.MEAN, [3., 4.])

  def _save_replica_local_sum(self, distribution):
    """Saves a SUM replica-local variable whose value is 1.5 + 2. = 3.5."""
    return self._save_replica_local(
        distribution, variable_scope.VariableAggregation.SUM, [1.5, 2.])

  def _save_normal(self):
    """Save variables without mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(3.5))

      # Saves the current value of var, 3.5.
      save_path = self._save(sess, var)

      # Change the values between save and restore.
      self.evaluate(var.assign(5.))
    return save_path

  def _restore_normal(self, save_path):
    """Restore to variables without mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=7., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(8.))

      # Restores the saved value of 3.5 to `var`.
      saver = saver_lib.Saver(var_list=[var])
      saver.restore(sess, save_path)
      self.assertEqual(3.5, self.evaluate(var))

  def _restore_replica_local(self, save_path, distribution, aggregation,
                             expected):
    """Restores `save_path` into a fresh replica-local variable and checks it.

    The components are first overwritten with [7., 8.] so the final
    assertion only passes if the restore actually took effect.
    """
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(aggregation, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [7., 8.])

      with distribution.scope():
        saver = saver_lib.Saver(var_list=[replica_local])
        saver.restore(sess, save_path)
        self.assertEqual(expected, self.evaluate([v[0], v[1]]))

  def _restore_replica_local_mean(self, save_path, distribution):
    """Restores 3.5 into a MEAN variable: both components become 3.5."""
    self._restore_replica_local(
        save_path, distribution, variable_scope.VariableAggregation.MEAN,
        [3.5, 3.5])

  def _restore_replica_local_sum(self, save_path, distribution):
    """Restores 3.5 into a SUM variable: it is split as 1.75 per component."""
    self._restore_replica_local(
        save_path, distribution, variable_scope.VariableAggregation.SUM,
        [1.75, 1.75])

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
    save_path = self._save_replica_local_mean(distribution)
    self._restore_replica_local_mean(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
    save_path = self._save_replica_local_sum(distribution)
    self._restore_replica_local_sum(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
    save_path = self._save_replica_local_mean(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalSumRestoreNormal(self, distribution):
    save_path = self._save_replica_local_sum(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreReplicaLocalMean(self, distribution):
    save_path = self._save_normal()
    self._restore_replica_local_mean(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreReplicaLocalSum(self, distribution):
    save_path = self._save_normal()
    self._restore_replica_local_sum(save_path, distribution)
| |
| |
if __name__ == "__main__":
  # Run the tests through the distribute test_util entry point
  # (`ds_test_util.main`).
  ds_test_util.main()