| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Test DistributionStrategy, ReplicaContext, and supporting APIs.""" |
| |
| from absl.testing import parameterized |
| |
| from tensorflow.python.autograph.core import converter_testing |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.distribute import combinations |
| from tensorflow.python.distribute import distribute_lib |
| from tensorflow.python.distribute import input_lib |
| from tensorflow.python.distribute import reduce_util |
| from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib |
| from tensorflow.python.distribute.v1 import input_lib as input_lib_v1 |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variable_v1 |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import test |
| from tensorflow.python.training import server_lib |
| from tensorflow.python.util import nest |
| |
| |
class _TestReplicaContext(distribute_lib.ReplicaContext):
  """Replica context stub whose `merge_call` echoes back a keyword argument."""

  def merge_call(self, fn, *args, **kwargs):
    """Returns `kwargs["test_arg"]`; `fn` and positional args are ignored."""
    del fn, args  # Deliberately unused by this stub.
    echoed = kwargs["test_arg"]
    return echoed
| |
| |
| def _get_test_variable(name, synchronization, aggregation): |
| return { |
| "name": name, |
| "synchronization": synchronization, |
| "aggregation": aggregation |
| } |
| |
| |
def _test_input_fn(input_context):
  """Input function returning a repeated dataset built from the value 1.0.

  Args:
    input_context: accepted to match the input-function signature; discarded.

  Returns:
    `DatasetV2.from_tensors(1.)` with `.repeat()` applied.
  """
  del input_context  # Unused.
  single_element = dataset_ops.DatasetV2.from_tensors(1.)
  return single_element.repeat()
| |
| |
class _TestStrategy(distribute_lib.Strategy):
  """Strategy used by the tests; its extended object is a `_TestExtended`."""

  def __init__(self):
    extended = _TestExtended(self)
    super().__init__(extended)
| |
| |
class _TestExtended(distribute_lib.StrategyExtendedV1):
  """Single-replica `StrategyExtendedV1` stub backing `_TestStrategy`.

  Variable creation is short-circuited to `_get_test_variable`, so creating a
  variable under this strategy yields a plain dict rather than a real variable.
  """

  def __init__(self, distribute):
    """Initializes the base class and registers one CPU input worker.

    Args:
      distribute: the strategy object that owns this extended instance.
    """
    super(_TestExtended, self).__init__(distribute)
    worker_device_pairs = [("", ["/device:CPU:0"])]
    self._input_workers = input_lib.InputWorkers(worker_device_pairs)

  def _call_for_each_replica(self, fn, args, kwargs):
    """Calls `fn` once inside a `_TestReplicaContext` with replica id 0."""
    with _TestReplicaContext(
        self._container_strategy(), replica_id_in_sync_group=0):
      return fn(*args, **kwargs)

  def _create_variable(self, next_creator, **kwargs):
    """Returns a dict describing the variable; `next_creator` is not called."""
    return _get_test_variable(kwargs["name"], kwargs["synchronization"],
                              kwargs["aggregation"])

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Wraps `input_fn` in a v1 `InputFunctionIterator`.

    `replication_mode` is accepted for signature compatibility and not used.
    """
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              [distribute_lib.InputContext()],
                                              self._container_strategy())

  def _distribute_datasets_from_function(self, dataset_fn, options):
    """Returns `dataset_fn` applied to a default `InputContext`."""
    return dataset_fn(distribute_lib.InputContext())

  def _local_results(self, value):
    """Returns `value` wrapped in a one-element tuple."""
    return (value,)

  def _reduce_to(self, reduce_op, value, destinations, options):
    """Returns `value` unchanged; the other arguments are ignored."""
    del reduce_op, destinations, options
    return value

  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    """Calls `fn(ctx, iterator.get_next())` `iterations` times; returns `ctx`.

    `initial_loop_values` is accepted but not used.
    """
    # TODO(tomhennigan) This is missing many things (e.g. ctx.run_op).
    ctx = input_lib.MultiStepContext()
    for _ in range(iterations):
      fn(ctx, iterator.get_next())
    return ctx

  def _update(self, var, fn, args, kwargs, group):
    """Runs `fn(var, *args, **kwargs)` by delegating to `_update_non_slot`."""
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    """Runs `fn(*args, **kwargs)` and returns its result.

    Args:
      colocate_with: ignored.
      fn: callable to run.
      args: positional arguments for `fn`.
      kwargs: keyword arguments for `fn`.
      group: if truthy, the result of `fn` is returned as-is; otherwise each
        element of the result structure is passed through `_local_results`.

    Returns:
      The (possibly unwrapped) result of `fn`.
    """
    del colocate_with
    result = fn(*args, **kwargs)
    if group:
      return result
    # Ungrouped results are unwrapped per value with `_local_results`, the
    # unwrapping hook this class defines.
    return nest.map_structure(self._local_results, result)

  def _get_local_replica_id(self, replica_id_in_sync_group):
    """Returns `replica_id_in_sync_group` unchanged."""
    return replica_id_in_sync_group
| |
| |
def _assert_in_default_state(t):
  """Asserts that the default replica context and default strategy are active.

  Args:
    t: test case object providing `assertIs` and `assertFalse`.
  """
  default_replica_context = distribute_lib._get_default_replica_context()
  t.assertIs(default_replica_context, distribute_lib.get_replica_context())

  # No cross-replica context should be reported in the default state.
  t.assertIs(None, distribute_lib.get_cross_replica_context())
  in_cross_replica = distribute_lib.in_cross_replica_context()
  t.assertFalse(in_cross_replica)

  default_strategy = distribute_lib._get_default_strategy()
  t.assertIs(default_strategy, distribute_lib.get_strategy())
  has_strategy = distribute_lib.has_strategy()
  t.assertFalse(has_strategy)
| |
| |
def _run_in_and_out_of_scope(unbound_test_method):
  """Decorator that runs a `(test_case, strategy)` test in several scopes.

  The decorated method is called with a fresh `_TestStrategy`:
    1. outside any strategy scope,
    2. inside that strategy's own scope, and
    3. inside a different strategy's scope, where a `RuntimeError` matching
       "Mixing different .*Strategy objects" is expected.

  Args:
    unbound_test_method: function taking `(test_case, dist)`.

  Returns:
    A function taking only `test_case`.
  """

  def wrapper(test_case):
    strategy = _TestStrategy()

    # 1. Default (replica) scope.
    _assert_in_default_state(test_case)
    unbound_test_method(test_case, strategy)

    # 2. The strategy's own scope.
    with strategy.scope():
      unbound_test_method(test_case, strategy)
    _assert_in_default_state(test_case)

    # 3. A different strategy's scope: the test method is expected to fail.
    other_strategy = _TestStrategy()
    expected_error = "Mixing different .*Strategy objects"
    with test_case.assertRaisesRegex(RuntimeError, expected_error):
      with other_strategy.scope():
        unbound_test_method(test_case, strategy)

  return wrapper
| |
| |
class TestStrategyTest(test.TestCase):
  """Tests scope handling and the extended API using `_TestStrategy`."""

  def testCallForEachReplica(self):
    """Checks the context seen by a function run via call_for_each_replica.

    The function is run both outside and inside `dist.scope()`.
    """
    _assert_in_default_state(self)
    dist = _TestStrategy()

    def run_fn():
      replica_context = distribute_lib.get_replica_context()
      self.assertIsNotNone(replica_context)
      self.assertIs(None, distribute_lib.get_cross_replica_context())
      self.assertFalse(distribute_lib.in_cross_replica_context())
      self.assertTrue(distribute_lib.has_strategy())
      self.assertIs(dist, distribute_lib.get_strategy())
      self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
      expected_value = _get_test_variable(
          "bar", variable_scope.VariableSynchronization.AUTO,
          variable_scope.VariableAggregation.NONE)
      self.assertDictEqual(expected_value,
                           variable_v1.VariableV1(1.0, name="bar"))

    dist.extended.call_for_each_replica(run_fn)
    with dist.scope():
      dist.extended.call_for_each_replica(run_fn)
    _assert_in_default_state(self)

  def testScope(self):
    """Checks the cross-replica state reported inside `dist.scope()`."""
    _assert_in_default_state(self)
    dist = _TestStrategy()
    with dist.scope():
      self.assertIs(None, distribute_lib.get_replica_context())
      self.assertIs(dist, distribute_lib.get_cross_replica_context())
      self.assertTrue(distribute_lib.in_cross_replica_context())
      self.assertTrue(distribute_lib.has_strategy())
      self.assertIs(dist, distribute_lib.get_strategy())
      expected_value = _get_test_variable(
          "baz", variable_scope.VariableSynchronization.AUTO,
          variable_scope.VariableAggregation.NONE)
      self.assertDictEqual(expected_value,
                           variable_v1.VariableV1(1.0, name="baz"))
    _assert_in_default_state(self)

  def testScopeDeviceNestingError(self):
    """Exiting the scope inside an extra device scope raises RuntimeError."""
    _assert_in_default_state(self)
    dist = _TestStrategy()
    # Open a device scope with dist.scope().
    dist.extended._default_device = "/device:GPU:0"
    scope = dist.scope()
    scope.__enter__()
    self.assertIs(dist, distribute_lib.get_strategy())
    with ops.device("/device:CPU:0"):
      with self.assertRaisesRegex(RuntimeError, "Device scope nesting error"):
        scope.__exit__(None, None, None)
    # Exit again outside the extra device scope to restore the default state.
    scope.__exit__(None, None, None)
    _assert_in_default_state(self)

  def testScopeVarCreatorNestingError(self):
    """Exiting the scope inside an extra creator scope raises RuntimeError."""

    def creator(next_creator, **kwargs):
      return next_creator(**kwargs)

    _assert_in_default_state(self)
    dist = _TestStrategy()
    scope = dist.scope()
    scope.__enter__()
    self.assertIs(dist, distribute_lib.get_strategy())
    with variable_scope.variable_creator_scope(creator):
      with self.assertRaisesRegex(RuntimeError,
                                  "Variable creator scope nesting error"):
        scope.__exit__(None, None, None)
    # Exit again outside the extra creator scope to restore the default state.
    scope.__exit__(None, None, None)
    _assert_in_default_state(self)

  def testScopeVarScopeNestingError(self):
    """Exiting the scope inside an extra variable scope raises RuntimeError."""
    # We create a new graph here to simplify clean-up, since the error
    # we are triggering happens in the middle of scope.__exit__() and
    # leaves us in a weird state.
    with ops.Graph().as_default():
      _assert_in_default_state(self)
      dist = _TestStrategy()
      scope = dist.scope()
      scope.__enter__()
      self.assertIs(dist, distribute_lib.get_strategy())
      with variable_scope.variable_scope("AA"):
        with self.assertRaisesRegex(RuntimeError,
                                    "Variable scope nesting error"):
          scope.__exit__(None, None, None)
    _assert_in_default_state(self)

  def testSettingSynchronizationAndAggregation(self):
    """Explicit synchronization/aggregation reach `_create_variable`."""
    _assert_in_default_state(self)
    dist = _TestStrategy()
    with dist.scope():
      expected_value = _get_test_variable(
          "baz", variable_scope.VariableSynchronization.ON_WRITE,
          variable_scope.VariableAggregation.MEAN)
      self.assertDictEqual(
          expected_value,
          variable_v1.VariableV1(
              1.0,
              name="baz",
              synchronization=variable_scope.VariableSynchronization.ON_WRITE,
              aggregation=variable_scope.VariableAggregation.MEAN))
    _assert_in_default_state(self)

  def testSetStrategy(self):
    """`experimental_set_strategy` installs, replaces and clears a strategy."""
    _assert_in_default_state(self)
    dist = _TestStrategy()
    dist2 = _TestStrategy()
    distribute_lib.experimental_set_strategy(dist)
    self.assertIs(None, distribute_lib.get_replica_context())
    self.assertIs(dist, distribute_lib.get_cross_replica_context())
    self.assertTrue(distribute_lib.in_cross_replica_context())
    self.assertTrue(distribute_lib.has_strategy())
    self.assertIs(dist, distribute_lib.get_strategy())
    expected_value = _get_test_variable(
        "baz", variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    self.assertDictEqual(expected_value,
                         variable_v1.VariableV1(1.0, name="baz"))
    distribute_lib.experimental_set_strategy(dist2)
    self.assertIs(dist2, distribute_lib.get_strategy())
    distribute_lib.experimental_set_strategy(None)
    _assert_in_default_state(self)

  def testSetStrategyInScope(self):
    """`experimental_set_strategy` raises when called inside a scope.

    Covers a new strategy, the current strategy, and `None`.
    """
    _assert_in_default_state(self)
    dist = _TestStrategy()
    with dist.scope():
      with self.assertRaisesRegex(
          RuntimeError,
          "Must not be called inside a `tf.distribute.Strategy` scope"):
        distribute_lib.experimental_set_strategy(_TestStrategy())
      with self.assertRaisesRegex(
          RuntimeError,
          "Must not be called inside a `tf.distribute.Strategy` scope"):
        distribute_lib.experimental_set_strategy(dist)
      with self.assertRaisesRegex(
          RuntimeError,
          "Must not be called inside a `tf.distribute.Strategy` scope"):
        distribute_lib.experimental_set_strategy(None)
    _assert_in_default_state(self)

  def testSameScopeNesting(self):
    """Scopes of one strategy may nest and be reused; another strategy's may not."""
    _assert_in_default_state(self)
    dist = _TestStrategy()
    scope_a = dist.scope()
    with scope_a:
      self.assertIs(dist, distribute_lib.get_strategy())
      scope_b = dist.scope()
      with scope_b:
        self.assertIs(dist, distribute_lib.get_strategy())
        with scope_a:
          self.assertIs(dist, distribute_lib.get_strategy())
        self.assertIs(dist, distribute_lib.get_strategy())
      self.assertIs(dist, distribute_lib.get_strategy())
      dist2 = _TestStrategy()
      scope2 = dist2.scope()
      with self.assertRaisesRegex(
          RuntimeError, "Mixing different tf.distribute.Strategy objects"):
        with scope2:
          pass
    _assert_in_default_state(self)
    # A scope object created earlier can be entered again later.
    with scope_b:
      self.assertIs(dist, distribute_lib.get_strategy())
    _assert_in_default_state(self)

  @_run_in_and_out_of_scope
  def testMakeInputFnIterator(self, dist):
    """`make_input_fn_iterator` returns a non-None iterator."""
    self.assertIsNotNone(dist.make_input_fn_iterator(_test_input_fn))

  @_run_in_and_out_of_scope
  def testReduce(self, dist):
    """`reduce` with `axis=None` returns a value equal to its input."""
    x = constant_op.constant(1.)
    x_r = dist.reduce(reduce_util.ReduceOp.MEAN, x, axis=None)
    self.assertEqual(self.evaluate(x), self.evaluate(x_r))

  def testReductions_acceptStringOps(self):
    """The reduce APIs accept the reduce op as a string in either case."""
    dist = _TestStrategy()
    for op in ("mean", "MEAN", "sum", "SUM"):
      x = constant_op.constant(1.)
      y = constant_op.constant(1.)
      x_r = dist.reduce(op, x, axis=None)
      self.assertEqual(self.evaluate(x), self.evaluate(x_r))
      x_r = dist.extended.reduce_to(op, x, "/CPU:0")
      self.assertEqual(self.evaluate(x), self.evaluate(x_r))
      x_r, y_r = dist.extended.batch_reduce_to(op,
                                               ((x, "/CPU:0"), (y, "/CPU:0")))
      self.assertEqual(self.evaluate(x), self.evaluate(x_r))
      self.assertEqual(self.evaluate(y), self.evaluate(y_r))

  @_run_in_and_out_of_scope
  def testReduceMeanAxis(self, dist):
    """MEAN `reduce` over `axis` None, 0 and (0, 1) of a 2x2 constant."""
    x = constant_op.constant([[1., 2.], [3., 4.]])
    x_r = dist.reduce(reduce_util.ReduceOp.MEAN, x, axis=None)
    self.assertAllEqual(self.evaluate(x), self.evaluate(x_r))
    x_r = dist.reduce(reduce_util.ReduceOp.MEAN, x, axis=0)
    self.assertAllEqual([2., 3.], self.evaluate(x_r))
    x_r = dist.reduce(reduce_util.ReduceOp.MEAN, x, axis=(0, 1))
    self.assertEqual(2.5, self.evaluate(x_r))

  @_run_in_and_out_of_scope
  def testReduceSumAxis(self, dist):
    """SUM `reduce` over `axis` None, 0 and (0, 1) of a 2x2 constant."""
    x = constant_op.constant([[1., 2.], [3., 4.]])
    x_r = dist.reduce(reduce_util.ReduceOp.SUM, x, axis=None)
    self.assertAllEqual(self.evaluate(x), self.evaluate(x_r))
    x_r = dist.reduce(reduce_util.ReduceOp.SUM, x, axis=0)
    self.assertAllEqual([4., 6.], self.evaluate(x_r))
    x_r = dist.reduce(reduce_util.ReduceOp.SUM, x, axis=(0, 1))
    self.assertEqual(10., self.evaluate(x_r))

  @_run_in_and_out_of_scope
  def testExperimentalRunStepsOnIterator(self, dist):
    """The step function receives one input when `iterations` is not passed."""
    all_inputs = []
    dataset = dataset_ops.Dataset.from_tensors(1.).repeat()
    dist.extended.experimental_run_steps_on_iterator(
        lambda _, inputs: all_inputs.append(self.evaluate(inputs)),
        dataset_ops.make_one_shot_iterator(dataset))
    self.assertEqual(all_inputs, [1.])

  @_run_in_and_out_of_scope
  def testReduceTo(self, dist):
    """`extended.reduce_to` returns a value equal to its input."""
    x = constant_op.constant(1.)
    x_r = dist.extended.reduce_to(reduce_util.ReduceOp.MEAN, x, "/CPU:0")
    self.assertEqual(self.evaluate(x), self.evaluate(x_r))

  @_run_in_and_out_of_scope
  def testBatchReduceTo(self, dist):
    """`extended.batch_reduce_to` returns values equal to its inputs."""
    x = constant_op.constant(1.)
    y = constant_op.constant(1.)
    x_r, y_r = dist.extended.batch_reduce_to(reduce_util.ReduceOp.MEAN,
                                             ((x, "/CPU:0"), (y, "/CPU:0")))
    self.assertEqual(self.evaluate(x), self.evaluate(x_r))
    self.assertEqual(self.evaluate(y), self.evaluate(y_r))

  @_run_in_and_out_of_scope
  def testUpdate(self, dist):
    """`extended.update` passes the variable and the extra args to `fn`."""
    with dist.scope():
      v = variables.Variable(1.)
    t = constant_op.constant(2.)

    def assign_fn(vv, tt):
      self.assertIs(vv, v)
      self.assertIs(tt, t)
    dist.extended.update(v, assign_fn, (t,))

  @_run_in_and_out_of_scope
  def testUpdateAutoGraph(self, dist):
    """The function given to `extended.update` runs as generated code."""
    with dist.scope():
      v = variables.Variable(1.)
    t = constant_op.constant(2.)

    def assign_fn(unused_vv, unused_tt):
      self.assertTrue(converter_testing.is_inside_generated_code())

    @def_function.function  # AutoGraph is default-on only within tf.function
    def test_fn():
      dist.extended.update(v, assign_fn, (t,))

    test_fn()

  @_run_in_and_out_of_scope
  def testUpdateNonSlot(self, dist):
    """`extended.update_non_slot` calls `fn` exactly once."""
    t = constant_op.constant(2.)
    update_calls = []
    dist.extended.update_non_slot(t, lambda: update_calls.append(1))
    self.assertEqual(len(update_calls), 1)

  @_run_in_and_out_of_scope
  def testUpdateNonSlotAutoGraph(self, dist):
    """The function given to `update_non_slot` runs as generated code."""
    t = constant_op.constant(2.)

    def update_fn():
      self.assertTrue(converter_testing.is_inside_generated_code())

    @def_function.function  # AutoGraph is default-on only within tf.function
    def test_fn():
      dist.extended.update_non_slot(t, update_fn)

    test_fn()

  def testClusterResolverDefaultNotImplemented(self):
    """`cluster_resolver` is None by default and reflects `_cluster_resolver`."""
    dist = _TestStrategy()
    self.assertIsNone(dist.cluster_resolver)
    base_cluster_spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
    })
    cluster_resolver = cluster_resolver_lib.SimpleClusterResolver(
        base_cluster_spec)
    dist.extended._cluster_resolver = cluster_resolver
    self.assertIs(dist.cluster_resolver, cluster_resolver)
| |
| |
# _TestStrategy2 mirrors _TestStrategy but leaves variable creation alone: its
# extended object is a _TestExtended2.
class _TestStrategy2(distribute_lib.Strategy):

  def __init__(self):
    extended = _TestExtended2(self)
    super().__init__(extended)
| |
| |
class _TestExtended2(_TestExtended):
  """`_TestExtended` variant that does not intercept variable creation."""

  def _create_variable(self, next_creator, **kwargs):
    """Forwards the keyword arguments to `next_creator` unchanged."""
    created = next_creator(**kwargs)
    return created
| |
| |
class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase):
  """Tests the behavior of the default strategy and replica context."""

  def testMergeCall(self):
    """`merge_call` on the default replica context runs `merge_fn` in
    cross-replica context with the default strategy."""
    _assert_in_default_state(self)

    def merge_fn(dist, s):
      self.assertIs(distribute_lib._get_default_strategy(), dist)
      self.assertIs(None, distribute_lib.get_replica_context())
      self.assertIs(dist, distribute_lib.get_cross_replica_context())
      self.assertTrue(distribute_lib.in_cross_replica_context())
      self.assertIs(dist, distribute_lib.get_strategy())
      self.assertFalse(distribute_lib.has_strategy())
      return "foo_" + s

    replica_ctx = distribute_lib.get_replica_context()
    self.assertIs(distribute_lib._get_default_replica_context(), replica_ctx)
    self.assertEqual("foo_bar", replica_ctx.merge_call(merge_fn, args=("bar",)))
    _assert_in_default_state(self)

  def testMergeCallAutoGraph(self):
    """The function given to `merge_call` runs as generated code."""
    _assert_in_default_state(self)

    def merge_fn(_, s):
      self.assertTrue(converter_testing.is_inside_generated_code())
      return s

    @def_function.function  # AutoGraph is default-on only within tf.function
    def test_fn():
      replica_ctx = distribute_lib.get_replica_context()
      replica_ctx.merge_call(merge_fn, args=("bar",))

    test_fn()

  def testScopeMostlyNoOp(self):
    """Entering the default strategy's scope keeps the default state.

    Inside that scope, creating a variable under another strategy's scope is
    expected to raise a "Mixing different tf.distribute.Strategy objects"
    error; before and after it, the other strategy's scope works normally.
    """
    _assert_in_default_state(self)

    test_strategy = _TestStrategy2()
    with test_strategy.scope():
      variable_v1.VariableV1(1.0, name="before")

    default_strategy = distribute_lib._get_default_strategy()
    scope = default_strategy.scope()
    with scope:
      _assert_in_default_state(self)

      with test_strategy.scope():
        with self.assertRaisesRegex(
            RuntimeError, "Mixing different tf.distribute.Strategy objects"):
          variable_v1.VariableV1(1.0, name="error")

      # Re-entering the same default scope object while it is already active.
      with scope:
        _assert_in_default_state(self)

        with test_strategy.scope():
          with self.assertRaisesRegex(
              RuntimeError, "Mixing different tf.distribute.Strategy objects"):
            variable_v1.VariableV1(1.0, name="also_error")

      _assert_in_default_state(self)

    _assert_in_default_state(self)
    with test_strategy.scope():
      variable_v1.VariableV1(1.0, name="after")

  def testExperimentalRunV2(self):
    """`run` on the default strategy can be called repeatedly (smoke test)."""
    default_strategy = distribute_lib._get_default_strategy()
    dataset = dataset_ops.Dataset.range(10).batch(2)
    iterator = default_strategy.extended._make_dataset_iterator(dataset)
    next_val = iterator.get_next()

    def train_step(input_data):
      return input_data

    for _ in range(2):
      default_strategy.run(train_step, args=(next_val,))

  @combinations.generate(combinations.combine(mode=["graph", "eager"]))
  def testDistributedDatasets(self):
    """First element of a distributed `range(10).batch(2)` dataset is [0, 1]."""
    default_strategy = distribute_lib._get_default_strategy()
    if context.executing_eagerly():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
      dist_dataset = default_strategy.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
      next_val = next(iter(dist_dataset))
    else:
      # Graph mode: use a V1 dataset and an explicitly initialized iterator.
      dataset_fn = lambda _: dataset_ops.DatasetV1.range(10).batch(2)
      dist_dataset = default_strategy.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
      iterator = dist_dataset.make_initializable_iterator()
      self.evaluate(iterator.initializer)
      next_val = iterator.get_next()
    self.assertAllEqual([0, 1], self.evaluate(next_val))

  @combinations.generate(combinations.combine(mode=["graph", "eager"]))
  def testDistributedDatasetsFromFunction(self):
    """Exercises `distribute_datasets_from_function` on the default strategy.

    In eager mode the first element is checked; in graph mode only the
    creation of an initializable iterator is exercised.
    """
    default_strategy = distribute_lib._get_default_strategy()
    if context.executing_eagerly():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
      dist_dataset_from_func = \
          default_strategy.distribute_datasets_from_function(
              dataset_fn)
      next_val = next(iter(dist_dataset_from_func))
      self.assertAllEqual([0, 1], self.evaluate(next_val))
    else:
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
      dist_dataset_from_func = \
          default_strategy.distribute_datasets_from_function(
              dataset_fn)
      dataset_ops.make_initializable_iterator(dist_dataset_from_func)

  @combinations.generate(combinations.combine(tf_api_version=1))
  def testV1(self):
    """With `tf_api_version=1` the current strategy is a `StrategyV1`."""
    self.assertIsInstance(
        distribute_lib.get_strategy(), distribute_lib.StrategyV1)

  @combinations.generate(combinations.combine(tf_api_version=2))
  def testV2(self):
    """With `tf_api_version=2` the current strategy is a `Strategy`."""
    self.assertIsInstance(
        distribute_lib.get_strategy(), distribute_lib.Strategy)
| |
| |
class InputContextTest(test.TestCase):
  """Tests for `distribute_lib.InputContext`."""

  def testProperties(self):
    """Constructor arguments are readable through same-named properties."""
    ctx = distribute_lib.InputContext(
        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
    replicas = ctx.num_replicas_in_sync
    self.assertEqual(6, replicas)
    pipeline_id = ctx.input_pipeline_id
    self.assertEqual(1, pipeline_id)
    pipelines = ctx.num_input_pipelines
    self.assertEqual(2, pipelines)

  def testPerReplicaBatchSize(self):
    """With 6 replicas, a global batch of 12 gives 2; 13 raises ValueError."""
    ctx = distribute_lib.InputContext(
        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
    per_replica = ctx.get_per_replica_batch_size(12)
    self.assertEqual(2, per_replica)
    with self.assertRaises(ValueError):
      ctx.get_per_replica_batch_size(13)

  def testStr(self):
    """`str()` reports the pipeline id and the total number of pipelines."""
    cases = (
        (1, 0, "tf.distribute.InputContext(input pipeline id 0, total: 1)"),
        (3, 1, "tf.distribute.InputContext(input pipeline id 1, total: 3)"),
    )
    for num_pipelines, pipeline_id, expected in cases:
      ctx = distribute_lib.InputContext(
          num_input_pipelines=num_pipelines,
          input_pipeline_id=pipeline_id,
          num_replicas_in_sync=42)
      self.assertEqual(expected, str(ctx))
| |
| |
# Run this module's tests when it is executed as a script.
if __name__ == "__main__":
  test.main()