| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Library for testing DistributionStrategy descendants.""" |
| |
| import functools |
| import os |
| import tempfile |
| |
| import numpy as np |
| |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.core.util import event_pb2 |
| from tensorflow.python.client import session as session_lib |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib |
| from tensorflow.python.distribute import distribute_lib |
| from tensorflow.python.distribute import distribute_utils |
| from tensorflow.python.distribute import mirrored_strategy as mirrored_lib |
| from tensorflow.python.distribute import reduce_util |
| from tensorflow.python.distribute import tpu_strategy |
| from tensorflow.python.eager import backprop |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.eager import test |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.lib.io import tf_record |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import gradients_impl |
| from tensorflow.python.ops import init_ops |
| from tensorflow.python.ops import init_ops_v2 |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import summary_ops_v2 as summary_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variable_v1 |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import gfile |
| from tensorflow.python.training import optimizer |
| from tensorflow.python.training import training_util |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import tf_inspect |
| |
| |
| class _TestException(Exception): |
| pass |
| |
| |
| # Conditionally wrap the fn in a def_function.function (so it runs in graph |
| # mode). |
| def _maybe_run_in_function(fn, run_in_function=False): |
| if not run_in_function or not context.executing_eagerly(): |
| return fn |
| else: |
| return def_function.function()(fn) |
| |
| |
def _raise_exception_fn(_=None):
  """Always raises `_TestException`.

  May be the argument to either
  `distribution.extended.call_for_each_replica()` or
  `get_replica_context().merge_call()`.

  Args:
    _: ignored.

  Raises:
    _TestException: unconditionally.
  """
  raise _TestException
| |
| |
def _merge_raises_fn():
  """Issues a `merge_call()` whose function raises `_TestException`.

  Must be the argument to a `distribution.extended.call_for_each_replica()`
  call, because it looks up the current replica context.
  """
  replica_context = distribute_lib.get_replica_context()
  replica_context.merge_call(_raise_exception_fn)
| |
| |
def _call_raises_fn(dist):
  """Issues a `call_for_each_replica()` whose function raises.

  Must be the argument to a `get_replica_context().merge_call()` call.

  Args:
    dist: the strategy whose `extended.call_for_each_replica` is invoked.
  """
  strategy_extended = dist.extended
  strategy_extended.call_for_each_replica(_raise_exception_fn)
| |
| |
def _merge_call_raises_fn():
  """Chains merge_call -> call_for_each_replica -> raise.

  Must be the argument to a `distribution.extended.call_for_each_replica()`
  call. It issues a `merge_call()` that in turn runs a
  `call_for_each_replica()` whose function raises `_TestException`.
  """
  replica_context = distribute_lib.get_replica_context()
  replica_context.merge_call(_call_raises_fn)
| |
| |
def _call_merge_raises_fn(dist):
  """Chains call_for_each_replica -> merge_call -> raise.

  Must be the argument to a `get_replica_context().merge_call()` call. It
  runs `dist.extended.call_for_each_replica()` with a function that issues a
  `merge_call()` raising `_TestException`.

  Args:
    dist: the strategy whose `extended.call_for_each_replica` is invoked.
  """
  strategy_extended = dist.extended
  strategy_extended.call_for_each_replica(_merge_raises_fn)
| |
| |
def _merge_call_merge_raises_fn():
  """Chains merge_call -> call_for_each_replica -> merge_call -> raise.

  Must be the argument to a `distribution.extended.call_for_each_replica()`
  call. It issues a `merge_call()` that runs a `call_for_each_replica()` whose
  function issues another `merge_call()` that raises `_TestException`.
  """
  replica_context = distribute_lib.get_replica_context()
  replica_context.merge_call(_call_merge_raises_fn)
| |
| |
def _events_from_logdir(test_case, logdir):
  """Reads summary events from log directory.

  Args:
    test_case: test case object used for its `assertTrue`/`assertLen` methods.
    logdir: directory that must exist and contain exactly one file.

  Returns:
    A list of `event_pb2.Event` protos, one per record in that file, in file
    order.
  """
  test_case.assertTrue(gfile.Exists(logdir))
  event_files = gfile.ListDirectory(logdir)
  test_case.assertLen(event_files, 1)
  event_path = os.path.join(logdir, event_files[0])

  def _parse_event(serialized):
    event = event_pb2.Event()
    event.ParseFromString(serialized)
    return event

  return [
      _parse_event(record)
      for record in tf_record.tf_record_iterator(event_path)
  ]
| |
| |
def create_variable_like_keras_layer(name, shape, dtype):
  """Utility for creating variables that work like variables in a Keras layer.

  The initial value is supplied as a zero-argument callable (a `GlorotUniform`
  initializer bound to `shape` and `dtype`) rather than as a concrete value.

  Args:
    name: name given to the variable.
    shape: shape forwarded to the initializer.
    dtype: dtype forwarded to the initializer.

  Returns:
    A trainable `variables.Variable`.
  """
  glorot_uniform = init_ops_v2.GlorotUniform()
  initial_value_fn = functools.partial(glorot_uniform, shape, dtype=dtype)
  return variables.Variable(
      initial_value=initial_value_fn, name=name, trainable=True)
| |
| |
def is_optimizer_v2_instance(optimizer_obj):
  """Returns whether `optimizer_obj.minimize` requires a `var_list` argument.

  For an optimizer instance, the v2 implementation has `var_list` as a
  required (non-defaulted) argument of `minimize`.

  Args:
    optimizer_obj: object exposing a `minimize` method.

  Returns:
    True if `var_list` is among the positional parameters of `minimize` that
    have no default value, False otherwise.
  """
  arg_spec = tf_inspect.getfullargspec(optimizer_obj.minimize)
  # `defaults` is None (not an empty tuple) when no parameter has a default,
  # so it must not be passed to len() directly.
  num_defaults = len(arg_spec.defaults or ())
  required_args = arg_spec.args[:len(arg_spec.args) - num_defaults]
  return "var_list" in required_args
| |
| |
def is_mirrored_strategy(strategy: distribute_lib.Strategy) -> bool:
  """Returns True if `strategy` is a MirroredStrategy or MirroredStrategyV1."""
  mirrored_classes = (
      mirrored_lib.MirroredStrategy,
      mirrored_lib.MirroredStrategyV1,
  )
  return isinstance(strategy, mirrored_classes)
| |
| |
def is_multi_worker_mirrored_strategy(
    strategy: distribute_lib.Strategy) -> bool:
  """Returns True if `strategy` is a CollectiveAllReduceStrategy (V2 or V1)."""
  multi_worker_classes = (
      mwms_lib.CollectiveAllReduceStrategy,
      mwms_lib.CollectiveAllReduceStrategyV1,
  )
  return isinstance(strategy, multi_worker_classes)
| |
| |
def is_tpu_strategy(strategy: distribute_lib.Strategy) -> bool:
  """Returns True if `strategy` is a TPUStrategy, TPUStrategyV1 or V2."""
  tpu_classes = (
      tpu_strategy.TPUStrategy,
      tpu_strategy.TPUStrategyV1,
      tpu_strategy.TPUStrategyV2,
  )
  return isinstance(strategy, tpu_classes)
| |
| |
class DistributionTestBase(test.TestCase):
  """Some tests that should work with any DistributionStrategy."""

  @staticmethod
  def _run_update_step(d, grad_fn, update_fn, inputs):
    """Performs one optimization step and reads each variable around it.

    Shared by the eager and graph variants of the minimize-loss test.

    Args:
      d: the distribution strategy under test.
      grad_fn: callable run on every replica; its result is iterated as
        (gradient, variable) pairs.
      update_fn: callable `(variable, gradient)` applied through
        `d.extended.update`.
      inputs: the single positional argument passed to `grad_fn`.

    Returns:
      A `(before_list, after_list)` tuple holding, per variable, the value
      read before the update and the value read after it.
    """
    # Run forward & backward to get gradients, variables list.
    g_v = d.extended.call_for_each_replica(grad_fn, args=(inputs,))

    # Update the variables using the gradients and the update function.
    before_list = []
    after_list = []
    for g, v in g_v:
      fetched = d.extended.read_var(v)
      before_list.append(fetched)
      # control_dependencies irrelevant but harmless in eager execution
      with ops.control_dependencies([fetched]):
        g = d.extended.reduce_to(
            reduce_util.ReduceOp.SUM, g, destinations=v)
        with ops.control_dependencies(
            d.extended.update(v, update_fn, args=(g,), group=False)):
          after_list.append(d.extended.read_var(v))
    return before_list, after_list

  def _test_minimize_loss_eager(self, d):
    """Checks that ten manual update steps move the kernel towards 1 (eager).

    Asserts that the kernel value read after the tenth step is closer to 1
    than the value read before the first step.

    Args:
      d: the distribution strategy under test.
    """
    with d.scope():
      kernel = create_variable_like_keras_layer(
          name="kernel", shape=(1, 1), dtype=dtypes.float32)

      def loss(x):
        y = array_ops.reshape(
            math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
        return y * y

      # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a
      # common `implicit_grad` function and put it in DistributionStrategy.
      grad_fn = backprop.implicit_grad(loss)
      grad_fn = optimizer.get_filtered_grad_fn(grad_fn)

      def update(v, g):
        return v.assign_sub(0.2 * g)

      one = array_ops.identity([[1.]])

      for i in range(10):
        b, a = self._run_update_step(d, grad_fn, update, one)
        if i == 0:
          before, = b  # pylint: disable=unbalanced-tuple-unpacking
        after, = a  # pylint: disable=unbalanced-tuple-unpacking

      error_before = abs(before.numpy() - 1)
      error_after = abs(after.numpy() - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_minimize_loss_graph(self,
                                d,
                                soft_placement=False,
                                learning_rate=0.2):
    """Checks that ten manual update steps move the kernel towards 1 (graph).

    The step is built once in a fresh graph and then run ten times in a
    session.

    Args:
      d: the distribution strategy under test.
      soft_placement: value for `allow_soft_placement` in the session config.
      learning_rate: factor applied to the gradient in each update.
    """
    config = config_pb2.ConfigProto()
    config.allow_soft_placement = soft_placement
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    with context.graph_mode(), \
         ops.Graph().as_default(), \
         self.cached_session(config=config) as sess, \
         d.scope():
      kernel = create_variable_like_keras_layer(
          name="kernel", shape=(1, 1), dtype=dtypes.float32)

      def loss(x):
        y = array_ops.reshape(
            math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
        return y * y

      grad_fn = backprop.implicit_grad(loss)

      def update(v, g):
        return v.assign_sub(learning_rate * g)

      one = array_ops.identity([[1.]])

      before_out, after_out = self._run_update_step(d, grad_fn, update, one)
      variables.global_variables_initializer().run()
      for i in range(10):
        b, a = sess.run((before_out, after_out))
        if i == 0:
          before, = b
        after, = a

      error_before = abs(before - 1)
      error_after = abs(after - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_summary_for_replica_zero_only(self, d):
    """Checks that exactly one summary event is written, with value 0.

    Every replica writes its replica id under tag "a"; the log directory is
    then expected to hold the file header event plus a single event whose
    value is 0.0.

    Args:
      d: the distribution strategy under test.
    """
    import shutil  # pylint: disable=g-import-not-at-top

    logdir = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the directory once the
    # test is over so repeated runs do not accumulate temp dirs.
    self.addCleanup(shutil.rmtree, logdir, ignore_errors=True)

    def run_fn():
      """Function executed for each replica."""
      with summary_writer.as_default():
        replica_ctx = distribute_lib.get_replica_context()
        replica_id = replica_ctx.replica_id_in_sync_group
        return summary_ops.write("a", replica_id)

    with self.cached_session() as sess, d.scope(), \
        summary_ops.always_record_summaries():
      # We need global_step because summary writing op *always* has global_step
      # as input, even when we always record summary or never record summary.
      global_step = training_util.get_or_create_global_step()
      if not context.executing_eagerly():
        # When executing eagerly, variables are initialized immediately after
        # creation, and its initializer will be None.
        global_step.initializer.run()
      summary_ops.set_step(0)
      summary_writer = summary_ops.create_file_writer(logdir)
      output = d.extended.call_for_each_replica(run_fn)
      unwrapped = d.unwrap(output)
      if not context.executing_eagerly():
        sess.run(summary_writer.init())
        sess.run(unwrapped)
        sess.run(summary_writer.close())

      events = _events_from_logdir(self, logdir)
      # There will be 2 entries: 1 summary file header entry, and 1 entry
      # written by replica 0.
      self.assertLen(events, 2)
      self.assertEqual(events[1].summary.value[0].tag, "a")
      self.assertEqual(events[1].summary.value[0].simple_value, 0.0)

  def _test_replica_id(self, d):
    """Checks that every replica id below the device count is seen once.

    Args:
      d: the distribution strategy under test.
    """
    with d.scope():
      expected_devices = [False] * len(d.extended.worker_devices)

      def mark_devices_fn():
        replica_id = self.evaluate(
            distribute_lib.get_replica_context().replica_id_in_sync_group)
        self.assertLess(replica_id, len(d.extended.worker_devices))
        # Each replica id must be reported at most once.
        self.assertFalse(expected_devices[replica_id])
        expected_devices[replica_id] = True

      d.extended.call_for_each_replica(mark_devices_fn)
      self.assertAllEqual(expected_devices,
                          [True] * len(d.extended.worker_devices))

  def _test_call_and_merge_exceptions(self, dist):
    """Checks `_TestException` propagates through nested call/merge chains.

    Args:
      dist: the distribution strategy under test.
    """
    raising_fns = (
        _raise_exception_fn,
        _merge_raises_fn,
        _merge_call_raises_fn,
        _merge_call_merge_raises_fn,
    )
    with dist.scope():
      for raising_fn in raising_fns:
        with self.assertRaises(_TestException):
          dist.extended.call_for_each_replica(raising_fn)

  def _input_fn_to_test_input_context(self, dataset_or_callable_fn,
                                      expected_num_replicas_in_sync,
                                      expected_num_input_pipelines,
                                      expected_input_pipeline_id):
    """Wraps a dataset factory in an input fn that validates its context.

    Args:
      dataset_or_callable_fn: zero-argument callable whose result the input fn
        returns.
      expected_num_replicas_in_sync: value the input context must report.
      expected_num_input_pipelines: value the input context must report.
      expected_input_pipeline_id: expected pipeline id, or None to expect the
        ids 0, 1, 2, ... on successive calls.

    Returns:
      A function taking an input context, asserting on it, and returning
      `dataset_or_callable_fn()`.
    """
    # Use a list of one element as counter so that it can be captured by the
    # `_input_fn`. This counter is incremented by 1 each time an input_fn is
    # called. We use this counter to check whether the `input_pipeline_id`
    # matches the counter in the in-graph replication.
    worker_id_counter = [0]

    def _input_fn(input_context):
      """Input fn for testing."""
      self.assertIsNotNone(input_context)
      self.assertEqual(expected_num_replicas_in_sync,
                       input_context.num_replicas_in_sync)
      self.assertEqual(expected_num_input_pipelines,
                       input_context.num_input_pipelines)
      if expected_input_pipeline_id is not None:
        self.assertEqual(expected_input_pipeline_id,
                         input_context.input_pipeline_id)
      else:
        self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id)
        worker_id_counter[0] += 1

      return dataset_or_callable_fn()

    return _input_fn

  def _test_input_fn_iterable(
      self, strategy, input_fn, expected_values, ignore_order=False):
    """Checks a distributed dataset built from `input_fn` yields the values.

    Under eager execution the iterable is consumed twice through `iter()`;
    otherwise an initializable iterator is checked through
    `_test_input_fn_iterator`.

    Args:
      strategy: the distribution strategy under test.
      input_fn: passed to `strategy.distribute_datasets_from_function`.
      expected_values: sequence of expected per-step local results.
      ignore_order: compare each step with `assertCountEqual` instead of
        `assertEqual`.
    """
    assert_same = self.assertCountEqual if ignore_order else self.assertEqual

    iterable = strategy.distribute_datasets_from_function(input_fn)
    if context.executing_eagerly():

      def check_one_pass(iterator):
        for expected_value in expected_values:
          computed_value = self.evaluate(
              list(strategy.experimental_local_results(next(iterator))))
          assert_same(expected_value, computed_value)

      iterator = iter(iterable)
      check_one_pass(iterator)

      with self.assertRaises(StopIteration):
        self.evaluate(strategy.experimental_local_results(next(iterator)))

      # After re-initializing the iterator, should be able to iterate again.
      check_one_pass(iter(iterable))
    else:
      iterator = dataset_ops.make_initializable_iterator(iterable)
      self._test_input_fn_iterator(iterator, strategy.extended.worker_devices,
                                   expected_values, test_reinitialize=True,
                                   ignore_order=ignore_order)

  def _test_input_fn_iterator(self,
                              iterator,
                              devices,
                              expected_values,
                              sess=None,
                              test_reinitialize=True,
                              ignore_order=False):
    """Checks an initializable iterator yields the values, then runs out.

    Args:
      iterator: object exposing `initializer` and `get_next()`.
      devices: sequence whose length is the number of replicas to select.
      expected_values: sequence of expected per-step, per-replica values.
      sess: optional session used for evaluation; `self.evaluate` is used when
        it is falsy.
      test_reinitialize: also re-run the initializer and check a second pass.
      ignore_order: compare each step with `assertCountEqual` instead of
        `assertEqual`.
    """

    def evaluate(x):
      return sess.run(x) if sess else self.evaluate(x)

    assert_same = self.assertCountEqual if ignore_order else self.assertEqual
    num_replicas = len(devices)

    def next_per_replica():
      next_element = iterator.get_next()
      return evaluate([
          distribute_utils.select_replica(r, next_element)
          for r in range(num_replicas)
      ])

    def check_one_pass():
      for expected_value in expected_values:
        assert_same(expected_value, next_per_replica())

    evaluate(iterator.initializer)
    check_one_pass()

    with self.assertRaises(errors.OutOfRangeError):
      next_per_replica()

    # After re-initializing the iterator, should be able to iterate again.
    if test_reinitialize:
      evaluate(iterator.initializer)
      check_one_pass()

  def _test_global_step_update(self, strategy):
    """Checks an ONLY_FIRST_REPLICA global step reads 1 after one increment.

    Args:
      strategy: the distribution strategy under test.
    """
    with strategy.scope():
      global_step = variable_scope.get_variable(
          "global_step",
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self.evaluate(variables.global_variables_initializer())

      def model_fn():
        train_op = global_step.assign_add(1)
        value = global_step.read_value()
        return train_op, value

      train_ops, value = strategy.extended.call_for_each_replica(model_fn)
      self.evaluate(strategy.group(train_ops))
      global_step_tensors = strategy.experimental_local_results(value)
      global_step_values = self.evaluate(global_step_tensors)
      self.assertEqual((1,) * len(global_step_tensors), global_step_values)

  def _test_numpy_dataset(self, strategy, session=None, run_in_function=False):
    """Checks a numpy-backed dataset returns the arrays for two epochs.

    Skipped unless `strategy` is a `StrategyV1`.

    Args:
      strategy: the distribution strategy under test.
      session: optional session; `self.cached_session()` is used otherwise.
      run_in_function: forwarded to `_maybe_run_in_function` for the step fn.
    """
    if not isinstance(strategy, distribute_lib.StrategyV1):
      self.skipTest("n/a: V1 only")
    cached_session = session or self.cached_session()
    with strategy.scope(), cached_session as sess:
      x = np.asarray([[1, 2], [6, 12], [2, 4], [5, 10], [3, 6], [4, 8]])
      y = np.asarray([5, 4, 3, 2, 1, 0])
      batch_size = 6
      if not strategy.extended._global_batch_size:  # pylint: disable=protected-access
        batch_size = batch_size // strategy.num_replicas_in_sync

      ds = strategy.extended.experimental_make_numpy_dataset(
          (x, y), session=sess or self.cached_session())
      ds = ds.repeat(2)  # 2 epochs
      # We need to use the drop_remainder argument to get a known static
      # input shape which is required for TPUs.
      drop_remainder = strategy.extended.experimental_require_static_shapes
      ds = ds.batch(batch_size, drop_remainder=drop_remainder)
      i = strategy.make_dataset_iterator(ds)

      self.evaluate(i.initializer)

      def run_and_concatenate(strategy, i):
        x, y = strategy.experimental_run(
            _maybe_run_in_function(lambda z: z, run_in_function), i)
        x, y = self.evaluate((strategy.experimental_local_results(x),
                              strategy.experimental_local_results(y)))
        return np.concatenate(x), np.concatenate(y)

      x_1, y_1 = run_and_concatenate(strategy, i)
      self.assertAllEqual(x, x_1)
      self.assertAllEqual(y, y_1)
      x_2, y_2 = run_and_concatenate(strategy, i)
      self.assertAllEqual(x, x_2)
      self.assertAllEqual(y, y_2)
      # Both epochs are consumed, so a third step must run out of data.
      with self.assertRaises(errors.OutOfRangeError):
        run_and_concatenate(strategy, i)

  def _test_trainable_variable(self, strategy):
    """Checks the `trainable` default for ON_READ vs. default synchronization.

    Args:
      strategy: the distribution strategy under test.
    """
    for cls in [variable_v1.VariableV1, variables.Variable]:
      with strategy.scope():
        v1 = cls(1.0)
        self.assertEqual(True, v1.trainable)

        # ON_READ variables are expected to default to non-trainable.
        v2 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ)
        self.assertEqual(False, v2.trainable)

        v3 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ,
                 trainable=True)
        self.assertEqual(True, v3.trainable)

        v4 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ,
                 trainable=False)
        self.assertEqual(False, v4.trainable)
| |
| |
class OneDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any one-device DistributionStrategy."""

  def _test_run(self, strategy):
    """Checks `strategy.run` with no args, positional args and kwargs.

    Args:
      strategy: the distribution strategy under test.
    """
    out1 = strategy.run(lambda: array_ops.identity(4.))
    self.assertAllEqual([4.], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.run(lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([8.], out2_vals["a"])
    self.assertAllEqual([16.], out2_vals["b"])

    # The dict from the previous step is passed as keyword arguments a=, b=.
    out3 = strategy.run(lambda b, a: a + 2 * b + 2, kwargs=out2)
    self.assertAllEqual([42.], self.evaluate(strategy.unwrap(out3)))

  def _test_all_reduce_sum(self, strategy):
    """Checks a SUM all-reduce returns its inputs unchanged."""
    self._test_collective_comms(
        strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.]))

  def _test_all_reduce_sum_gradients(self, strategy):
    """Checks `gradients_impl.gradients` through a SUM all-reduce."""
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_sum_gradient_tape(self, strategy):
    """Checks `GradientTape` gradients through a SUM all-reduce."""
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_mean(self, strategy):
    """Checks a MEAN all-reduce returns its inputs unchanged."""
    self._test_collective_comms(
        strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.]))

  def _test_all_reduce_mean_gradients(self, strategy):
    """Checks `gradients_impl.gradients` through a MEAN all-reduce."""
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_all_reduce_mean_gradient_tape(self, strategy):
    """Checks `GradientTape` gradients through a MEAN all-reduce."""
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected):
    """Runs `comm_fn` on `inputs` and compares the two outputs to `expected`.

    Args:
      strategy: the distribution strategy under test.
      comm_fn: function run via `strategy.experimental_run`.
      inputs: data wrapped with `Dataset.from_tensors`.
      expected: two-element sequence of expected values.
    """
    # The iterator gets its own name: the input-fn lambda closes over
    # `inputs`, so rebinding `inputs` here would make any later invocation of
    # the lambda see the iterator instead of the data.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(iterator.initialize())
    outputs = self.evaluate(
        list(
            map(strategy.experimental_local_results,
                strategy.experimental_run(comm_fn, iterator))))
    self.assertAllEqual([expected[0]], outputs[0])
    self.assertAllEqual([expected[1]], outputs[1])

  def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
                                       expected_grads):
    """Checks `gradients_impl.gradients` of `comm_fn(x) * c` w.r.t. `x`.

    Skipped under eager execution.

    Args:
      strategy: the distribution strategy under test.
      comm_fn: function applied to `x` inside the step.
      inputs: data wrapped with `Dataset.from_tensors`; supplies `c`.
      expected_grads: expected local gradient results.
    """
    if context.executing_eagerly():
      self.skipTest("`tf.gradients` is not supported with eager execution.")

    def step(c):
      x = array_ops.identity(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    # Separate name so the lambda keeps referring to the data (see
    # _test_collective_comms).
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(iterator.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(step, iterator))))

  def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
                                           expected_grads):
    """Checks `GradientTape` gradients of `comm_fn(x) * c` w.r.t. `x`.

    Args:
      strategy: the distribution strategy under test.
      comm_fn: function applied to `x` inside the step.
      inputs: data wrapped with `Dataset.from_tensors`; supplies `c`.
      expected_grads: expected local gradient results.
    """

    def step(c):
      x = array_ops.identity(42.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    # Separate name so the lambda keeps referring to the data (see
    # _test_collective_comms).
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(iterator.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(step, iterator))))

  def _test_device_and_input_device_are_colocated(self, strategy):
    """Runs one step in a session connected to the second local worker.

    Skipped under eager execution.

    Args:
      strategy: the distribution strategy under test.
    """
    if context.executing_eagerly():
      self.skipTest(
          "cross-device tests are not supported with eager execution.")
    workers, _ = test_util.create_local_cluster(2, 0)
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.range(5))
    comm_fn = lambda x: x + 1
    run_op = strategy.experimental_run(comm_fn, iterator)
    with session_lib.Session(target=workers[1].target) as sess:
      sess.run(iterator.initialize())
      sess.run(run_op)

  def _test_device_and_input_device_are_colocated_with_function(self, strategy):
    """Same as the colocation test, with `experimental_run` in a tf.function.

    Skipped under eager execution.

    Args:
      strategy: the distribution strategy under test.
    """
    if context.executing_eagerly():
      self.skipTest(
          "cross-device tests are not supported with eager execution.")
    workers, _ = test_util.create_local_cluster(2, 0)
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.range(5))
    comm_fn = lambda x: x + 1
    experimental_run = def_function.function()(strategy.experimental_run)
    with ops.device("/job:worker/replica:0/task:1/device:CPU:0"):
      # The tf.function must be defined on the right device as well.
      run_op = experimental_run(comm_fn, iterator)
    with session_lib.Session(target=workers[1].target) as sess:
      sess.run(iterator.initialize())
      sess.run(run_op)
| |
| |
class TwoDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any two-device DistributionStrategy."""

  def _test_run(self, strategy, run_in_function=False):
    """Checks `strategy.run` with no args, positional args and kwargs.

    Args:
      strategy: the distribution strategy under test.
      run_in_function: forwarded to `_maybe_run_in_function` for every step.
    """
    out1 = strategy.run(_maybe_run_in_function(
        lambda: distribute_lib.get_replica_context().replica_id_in_sync_group + 1,
        run_in_function))
    self.assertAllEqual([1, 2], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.run(_maybe_run_in_function(
        lambda x: {"a": x * 2, "b": x * x}, run_in_function), args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([2, 4], out2_vals["a"])
    self.assertAllEqual([1, 4], out2_vals["b"])

    # The dict from the previous step is passed as keyword arguments a=, b=.
    out3 = strategy.run(_maybe_run_in_function(
        lambda b, a: a + 2 * b + 2, run_in_function), kwargs=out2)
    self.assertAllEqual([6, 14], self.evaluate(strategy.unwrap(out3)))

  def _test_all_reduce_sum(self, strategy, run_in_function=False):
    """Checks a SUM all-reduce over two per-replica slices."""
    self._test_collective_comms(
        strategy,
        _all_sum,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(4., [42., 43.]),
        run_in_function=run_in_function)

  def _test_all_reduce_sum_gradients(self, strategy, run_in_function=False):
    """Checks `gradients_impl.gradients` through a SUM all-reduce."""
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.],
        run_in_function=run_in_function)

  def _test_all_reduce_sum_gradient_tape(self, strategy, run_in_function=False):
    """Checks `GradientTape` gradients through a SUM all-reduce."""
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.],
        run_in_function=run_in_function)

  def _test_all_reduce_mean(self, strategy, run_in_function=False):
    """Checks a MEAN all-reduce over two per-replica slices."""
    self._test_collective_comms(
        strategy,
        _all_mean,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(2., [21., 21.5]),
        run_in_function=run_in_function)

  def _test_all_reduce_mean_gradients(self, strategy, run_in_function=False):
    """Checks `gradients_impl.gradients` through a MEAN all-reduce."""
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.],
        run_in_function=run_in_function)

  def _test_all_reduce_mean_gradient_tape(self, strategy,
                                          run_in_function=False):
    """Checks `GradientTape` gradients through a MEAN all-reduce."""
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.],
        run_in_function=run_in_function)

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected,
                             run_in_function=False):
    """Runs `comm_fn` on sliced `inputs`; both replicas must see `expected`.

    Args:
      strategy: the distribution strategy under test.
      comm_fn: function run via `strategy.experimental_run`.
      inputs: data wrapped with `Dataset.from_tensor_slices`.
      expected: two-element sequence; each element is expected on both
        replicas.
      run_in_function: forwarded to `_maybe_run_in_function`.
    """
    # The iterator gets its own name: the input-fn lambda closes over
    # `inputs`, so rebinding `inputs` here would make any later invocation of
    # the lambda see the iterator instead of the data.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(iterator.initialize())
    outputs = self.evaluate(
        list(
            map(strategy.experimental_local_results,
                strategy.experimental_run(
                    _maybe_run_in_function(comm_fn, run_in_function),
                    iterator))))
    self.assertAllEqual([expected[0], expected[0]], outputs[0])
    self.assertAllEqual([expected[1], expected[1]], outputs[1])

  def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
                                       expected_grads, run_in_function=False):
    """Checks `gradients_impl.gradients` of `comm_fn(x) * c` w.r.t. `x`.

    Skipped under eager execution unless `run_in_function` is set.

    Args:
      strategy: the distribution strategy under test.
      comm_fn: function applied to `x` inside the step.
      inputs: data wrapped with `Dataset.from_tensor_slices`; supplies `c`.
      expected_grads: expected local gradient results.
      run_in_function: forwarded to `_maybe_run_in_function`.
    """
    if context.executing_eagerly() and not run_in_function:
      self.skipTest("`tf.gradients` is not supported with eager execution "
                    "without using tf.functions.")

    def step(c):
      x = array_ops.identity(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    # Separate name so the lambda keeps referring to the data (see
    # _test_collective_comms).
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(iterator.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(
                    _maybe_run_in_function(step, run_in_function), iterator))))

  def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
                                           expected_grads,
                                           run_in_function=False):
    """Checks `GradientTape` gradients of `comm_fn(x) * c` w.r.t. `x`.

    Args:
      strategy: the distribution strategy under test.
      comm_fn: function applied to `x` inside the step.
      inputs: data wrapped with `Dataset.from_tensor_slices`; supplies `c`.
      expected_grads: expected local gradient results.
      run_in_function: forwarded to `_maybe_run_in_function`.
    """

    def step(c):
      x = array_ops.identity(42.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    # Separate name so the lambda keeps referring to the data (see
    # _test_collective_comms).
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(iterator.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(
                    _maybe_run_in_function(step, run_in_function),
                    iterator))))
| |
| |
class RemoteSingleWorkerMirroredStrategyBase(DistributionTestBase):
  """Tests for a Remote single worker."""

  def _get_num_gpus(self):
    """Returns None here; presumably overridden by subclasses with a count."""

  def _testNumReplicasInSync(self, distribution):
    """Asserts `num_replicas_in_sync` equals what `_get_num_gpus()` returns."""
    expected_replicas = self._get_num_gpus()
    self.assertEqual(expected_replicas, distribution.num_replicas_in_sync)

  def _testMinimizeLoss(self, distribution):
    """Runs the eager or graph minimize-loss test for the current mode."""
    if not context.executing_eagerly():
      self._test_minimize_loss_graph(distribution, learning_rate=0.05)
      return
    self._test_minimize_loss_eager(distribution)

  def _testDeviceScope(self, distribution):
    """Checks device strings of ops created inside the strategy scope."""
    worker_cpu = "/job:worker/replica:0/task:0/device:CPU:0"
    with distribution.scope():
      a = array_ops.identity(1.)
      with ops.device("/cpu:0"):
        b = array_ops.identity(1.)
      # Without an explicit device, the expected string carries the CPU
      # suffix only under eager execution.
      expected_default_device = (
          worker_cpu if context.executing_eagerly()
          else "/job:worker/replica:0/task:0")
      self.assertEqual(a.device, expected_default_device)
      self.assertEqual(b.device, worker_cpu)

  def _testMakeInputFnIteratorWithDataset(self, distribution):
    """Checks `make_input_fn_iterator` with an input fn returning a dataset."""

    def dataset_fn():
      return dataset_ops.Dataset.range(100)

    num_gpus = self._get_num_gpus()  # pylint: disable=assignment-from-no-return
    num_workers = 1

    # One entry per step: `num_gpus` consecutive integers, repeated per worker.
    expected_values = []
    for start in range(0, 100, num_gpus):
      per_worker = [start + offset for offset in range(num_gpus)]
      expected_values.append(per_worker * num_workers)

    # Dummy cached_session is used in Eager
    with self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be called
      # multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          dataset_fn,
          expected_num_replicas_in_sync=num_workers * num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      worker_devices = distribution.extended.worker_devices
      self._test_input_fn_iterator(
          iterator, worker_devices, expected_values, sess)

  def _testMakeInputFnIteratorWithCallable(self, distribution):
    """Checks `make_input_fn_iterator` with an input fn returning a callable."""

    def fn():
      one_shot = dataset_ops.make_one_shot_iterator(
          dataset_ops.Dataset.range(100))
      return one_shot.get_next

    num_gpus = self._get_num_gpus()  # pylint: disable=assignment-from-no-return
    num_workers = 1

    # One entry per step: `num_gpus` consecutive integers, repeated per worker.
    expected_values = [
        [start + offset for offset in range(num_gpus)] * num_workers
        for start in range(0, 100, num_gpus)
    ]

    # Dummy cached_session is used in Eager
    with self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be called
      # multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          fn,
          expected_num_replicas_in_sync=num_workers * num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      worker_devices = distribution.extended.worker_devices
      self._test_input_fn_iterator(
          iterator, worker_devices, expected_values, sess,
          test_reinitialize=False, ignore_order=True)
| |
| |
def _all_sum(value):
  """All-reduces `value` with `ReduceOp.SUM` in the current replica context."""
  return distribute_lib.get_replica_context().all_reduce(
      reduce_util.ReduceOp.SUM, value)
| |
| |
def _all_mean(value):
  """All-reduces `value` with `ReduceOp.MEAN` in the current replica context."""
  return distribute_lib.get_replica_context().all_reduce(
      reduce_util.ReduceOp.MEAN, value)