blob: 911fccc7e56fd21440e42ddba6e77d8ce6577721 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for testing DistributionStrategy descendants."""
import functools
import os
import tempfile
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_util
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
class _TestException(Exception):
pass
# Conditionally wrap the fn in a def_function.function (so it runs in graph
# mode).
def _maybe_run_in_function(fn, run_in_function=False):
if not run_in_function or not context.executing_eagerly():
return fn
else:
return def_function.function()(fn)
def _raise_exception_fn(_=None):
  """Always raises `_TestException`.

  Accepts an optional, ignored positional argument so it can be passed either
  to `distribution.extended.call_for_each_replica()` or to
  `get_replica_context().merge_call()`.
  """
  raise _TestException
def _merge_raises_fn():
  """Replica fn whose `merge_call()` raises `_TestException`.

  Must be passed to `distribution.extended.call_for_each_replica()`.
  """
  replica_ctx = distribute_lib.get_replica_context()
  replica_ctx.merge_call(_raise_exception_fn)
def _call_raises_fn(dist):
  """Merge fn that calls `call_for_each_replica()` with a raising fn.

  Must be passed to `get_replica_context().merge_call()`, which supplies the
  strategy as `dist`.
  """
  extended = dist.extended
  extended.call_for_each_replica(_raise_exception_fn)
def _merge_call_raises_fn():
  """Replica fn: merge_call -> call_for_each_replica -> raise.

  Must be passed to `distribution.extended.call_for_each_replica()`; the
  exception originates from the innermost `call_for_each_replica()`.
  """
  replica_ctx = distribute_lib.get_replica_context()
  replica_ctx.merge_call(_call_raises_fn)
def _call_merge_raises_fn(dist):
  """Merge fn: call_for_each_replica -> merge_call -> raise.

  Must be passed to `get_replica_context().merge_call()`, which supplies the
  strategy as `dist`.
  """
  extended = dist.extended
  extended.call_for_each_replica(_merge_raises_fn)
def _merge_call_merge_raises_fn():
  """Replica fn: merge_call -> call_for_each_replica -> merge_call -> raise.

  Must be passed to `distribution.extended.call_for_each_replica()`; this is
  the most deeply nested of the exception-propagation helpers.
  """
  replica_ctx = distribute_lib.get_replica_context()
  replica_ctx.merge_call(_call_merge_raises_fn)
def _events_from_logdir(test_case, logdir):
  """Parses every `Event` record from the single file in `logdir`.

  Args:
    test_case: Test case used to assert that `logdir` exists and contains
      exactly one file.
    logdir: Directory a summary file writer wrote into.

  Returns:
    A list of `event_pb2.Event` messages, in file order.
  """
  test_case.assertTrue(gfile.Exists(logdir))
  entries = gfile.ListDirectory(logdir)
  test_case.assertLen(entries, 1)
  event_path = os.path.join(logdir, entries[0])
  # Read all serialized records first, then decode them.
  serialized_records = list(tf_record.tf_record_iterator(event_path))

  def _decode(serialized):
    parsed_event = event_pb2.Event()
    parsed_event.ParseFromString(serialized)
    return parsed_event

  return [_decode(serialized) for serialized in serialized_records]
def create_variable_like_keras_layer(name, shape, dtype):
  """Utility for creating a variable the way a Keras layer creates one.

  The initial value is passed as a callable, namely a `functools.partial` of a
  Glorot-uniform initializer bound to `shape` and `dtype`, rather than as a
  concrete tensor.

  Args:
    name: Variable name.
    shape: Variable shape.
    dtype: Variable dtype.

  Returns:
    A trainable `variables.Variable`.
  """
  glorot_uniform = init_ops_v2.GlorotUniform()
  initial_value_fn = functools.partial(glorot_uniform, shape, dtype=dtype)
  return variables.Variable(
      initial_value=initial_value_fn, name=name, trainable=True)
def is_optimizer_v2_instance(optimizer_obj):
  """Returns True if `optimizer_obj.minimize` requires a `var_list` argument.

  The check relies on the v2 optimizer implementation taking `var_list` as a
  required argument of `minimize` (no default value).

  Args:
    optimizer_obj: An optimizer instance.

  Returns:
    Whether `var_list` is among the positional arguments of `minimize` that
    have no default.
  """
  arg_spec = tf_inspect.getfullargspec(optimizer_obj.minimize)
  # `defaults` is None (not an empty tuple) when no argument has a default,
  # so `len(arg_spec.defaults)` would raise. Also, slicing with `[:-0]` would
  # wrongly yield an empty list, so slice by the required-argument count.
  num_defaults = len(arg_spec.defaults or ())
  required_args = arg_spec.args[:len(arg_spec.args) - num_defaults]
  return "var_list" in required_args
def is_mirrored_strategy(strategy: distribute_lib.Strategy) -> bool:
  """Returns whether `strategy` is a `MirroredStrategy` (V2 or V1)."""
  mirrored_types = (mirrored_lib.MirroredStrategy,
                    mirrored_lib.MirroredStrategyV1)
  return isinstance(strategy, mirrored_types)
def is_multi_worker_mirrored_strategy(
    strategy: distribute_lib.Strategy) -> bool:
  """Returns whether `strategy` is a `CollectiveAllReduceStrategy` (V2 or V1)."""
  collective_types = (mwms_lib.CollectiveAllReduceStrategy,
                      mwms_lib.CollectiveAllReduceStrategyV1)
  return isinstance(strategy, collective_types)
def is_tpu_strategy(strategy: distribute_lib.Strategy) -> bool:
  """Returns whether `strategy` is a `TPUStrategy`, `TPUStrategyV1` or `V2`."""
  tpu_types = (tpu_strategy.TPUStrategy,
               tpu_strategy.TPUStrategyV1,
               tpu_strategy.TPUStrategyV2)
  return isinstance(strategy, tpu_types)
class DistributionTestBase(test.TestCase):
  """Some tests that should work with any DistributionStrategy.

  Each `_test_*` helper takes the strategy under test as an argument and is
  meant to be called from a concrete test method of a subclass.
  """

  def _test_minimize_loss_eager(self, d):
    """Runs 10 eager update steps on a 1x1 kernel and checks it approaches 1.

    The loss is `(x @ kernel - 1)^2` with `x = [[1.]]`. Each step SUM-reduces
    the per-replica gradients onto the variable and applies `v -= 0.2 * g`.

    Args:
      d: The distribution strategy under test.
    """
    with d.scope():
      kernel = create_variable_like_keras_layer(
          name="kernel", shape=(1, 1), dtype=dtypes.float32)

      def loss(x):
        # (1, 1) matmul result reshaped to a scalar, minus the target 1.
        y = array_ops.reshape(
            math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
        return y * y
      # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a
      # common `implicit_grad` function and put it in DistributionStrategy.
      grad_fn = backprop.implicit_grad(loss)
      grad_fn = optimizer.get_filtered_grad_fn(grad_fn)

      def update(v, g):
        return v.assign_sub(0.2 * g)

      one = array_ops.identity([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))
        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.extended.read_var(v)
          before_list.append(fetched)
          # control_dependencies irrelevant but harmless in eager execution
          with ops.control_dependencies([fetched]):
            g = d.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            with ops.control_dependencies(
                d.extended.update(v, update, args=(g,), group=False)):
              after_list.append(d.extended.read_var(v))
        return before_list, after_list

      for i in range(10):
        b, a = step()
        if i == 0:
          # Single variable, so each list has exactly one element. Keep the
          # value read before the very first update...
          before, = b  # pylint: disable=unbalanced-tuple-unpacking
      # ...and the value read after the last update.
      after, = a  # pylint: disable=unbalanced-tuple-unpacking
      error_before = abs(before.numpy() - 1)
      error_after = abs(after.numpy() - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_minimize_loss_graph(self,
                                d,
                                soft_placement=False,
                                learning_rate=0.2):
    """Graph-mode counterpart of `_test_minimize_loss_eager`.

    Builds one update step in a fresh graph, then runs it 10 times in a
    session. Unlike the eager version, the gradient fn is not passed through
    `optimizer.get_filtered_grad_fn`.

    Args:
      d: The distribution strategy under test.
      soft_placement: Value for `ConfigProto.allow_soft_placement`.
      learning_rate: Step size of the `v -= learning_rate * g` update.
    """
    config = config_pb2.ConfigProto()
    config.allow_soft_placement = soft_placement
    # NOTE(review): presumably limits GPU memory so concurrent tests can share
    # a device -- confirm.
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    with context.graph_mode(), \
         ops.Graph().as_default(), \
         self.cached_session(config=config) as sess, \
         d.scope():
      kernel = create_variable_like_keras_layer(
          name="kernel", shape=(1, 1), dtype=dtypes.float32)

      def loss(x):
        y = array_ops.reshape(
            math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
        return y * y

      grad_fn = backprop.implicit_grad(loss)

      def update(v, g):
        return v.assign_sub(learning_rate * g)

      one = array_ops.identity([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))
        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.extended.read_var(v)
          before_list.append(fetched)
          # Read the old value before the reduce/update ops run.
          with ops.control_dependencies([fetched]):
            g = d.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            # Read the new value only after the update has been applied.
            with ops.control_dependencies(
                d.extended.update(v, update, args=(g,), group=False)):
              after_list.append(d.extended.read_var(v))
        return before_list, after_list

      # The step is built once as graph ops and executed repeatedly below.
      before_out, after_out = step()
      variables.global_variables_initializer().run()
      for i in range(10):
        b, a = sess.run((before_out, after_out))
        if i == 0:
          before, = b
      after, = a
      error_before = abs(before - 1)
      error_after = abs(after - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_summary_for_replica_zero_only(self, d):
    """Checks that a summary written by every replica is recorded only once.

    Each replica writes its replica id under tag "a"; the event file must hold
    the file header event plus exactly one summary whose value is 0.0, i.e.
    the one written by replica 0.

    Args:
      d: The distribution strategy under test.
    """
    logdir = tempfile.mkdtemp()

    def run_fn():
      """Function executed for each replica."""
      # `summary_writer` is assigned later in the enclosing function; run_fn
      # is only invoked after that assignment.
      with summary_writer.as_default():
        replica_id = distribute_lib.get_replica_context().replica_id_in_sync_group
        return summary_ops.write("a", replica_id)

    with self.cached_session() as sess, d.scope(), \
        summary_ops.always_record_summaries():
      # We need global_step because summary writing op *always* has global_step
      # as input, even when we always record summary or never record summary.
      global_step = training_util.get_or_create_global_step()
      if not context.executing_eagerly():
        # When executing eagerly, variables are initialized immediately after
        # creation, and its initializer will be None.
        global_step.initializer.run()
      summary_ops.set_step(0)
      summary_writer = summary_ops.create_file_writer(logdir)
      output = d.extended.call_for_each_replica(run_fn)
      unwrapped = d.unwrap(output)
      if not context.executing_eagerly():
        # In graph mode the writer ops must be run explicitly.
        sess.run(summary_writer.init())
        sess.run(unwrapped)
        sess.run(summary_writer.close())

      events = _events_from_logdir(self, logdir)
      # There will be 2 entries: 1 summary file header entry, and 1 entry
      # written by replica 0.
      self.assertLen(events, 2)
      self.assertEqual(events[1].summary.value[0].tag, "a")
      self.assertEqual(events[1].summary.value[0].simple_value, 0.0)

  def _test_replica_id(self, d):
    """Checks each replica sees a distinct, in-range replica id.

    Every index in `range(len(d.extended.worker_devices))` must be reported
    by exactly one replica.

    Args:
      d: The distribution strategy under test.
    """
    with d.scope():
      expected_devices = [False] * len(d.extended.worker_devices)

      def mark_devices_fn():
        replica_id = self.evaluate(
            distribute_lib.get_replica_context().replica_id_in_sync_group)
        self.assertLess(replica_id, len(d.extended.worker_devices))
        # A second replica reporting the same id would trip this assertion.
        self.assertFalse(expected_devices[replica_id])
        expected_devices[replica_id] = True

      d.extended.call_for_each_replica(mark_devices_fn)
      self.assertAllEqual(expected_devices,
                          [True] * len(d.extended.worker_devices))

  def _test_call_and_merge_exceptions(self, dist):
    """Checks `_TestException` escapes nested call/merge invocations.

    Covers nesting depths from a plain `call_for_each_replica()` up to
    call -> merge -> call -> merge.

    Args:
      dist: The distribution strategy under test.
    """
    with dist.scope():
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_raise_exception_fn)
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_merge_raises_fn)
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_merge_call_raises_fn)
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_merge_call_merge_raises_fn)

  def _input_fn_to_test_input_context(self, dataset_or_callable_fn,
                                      expected_num_replicas_in_sync,
                                      expected_num_input_pipelines,
                                      expected_input_pipeline_id):
    """Wraps `dataset_or_callable_fn` in an input_fn that checks its context.

    Args:
      dataset_or_callable_fn: Zero-argument fn whose result the returned
        input_fn returns.
      expected_num_replicas_in_sync: Expected
        `input_context.num_replicas_in_sync`.
      expected_num_input_pipelines: Expected
        `input_context.num_input_pipelines`.
      expected_input_pipeline_id: Expected `input_context.input_pipeline_id`,
        or None to require ids 0, 1, 2, ... on successive calls.

    Returns:
      An input_fn taking a single `input_context` argument.
    """
    # Use a list of one element as counter so that it can be captured by the
    # `_input_fn`. This counter is incremented by 1 each time an input_fn is
    # called. We use this counter to check whether the `input_pipeline_id`
    # matches the counter in the in-graph replication.
    worker_id_counter = [0]

    def _input_fn(input_context):
      """Input fn for testing."""
      self.assertIsNotNone(input_context)
      self.assertEqual(expected_num_replicas_in_sync,
                       input_context.num_replicas_in_sync)
      self.assertEqual(expected_num_input_pipelines,
                       input_context.num_input_pipelines)
      if expected_input_pipeline_id is not None:
        self.assertEqual(expected_input_pipeline_id,
                         input_context.input_pipeline_id)
      else:
        self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id)
        worker_id_counter[0] += 1

      return dataset_or_callable_fn()

    return _input_fn

  def _test_input_fn_iterable(
      self, strategy, input_fn, expected_values, ignore_order=False):
    """Checks `distribute_datasets_from_function(input_fn)` output.

    Eagerly, iterates the result twice (with a StopIteration check in between)
    comparing each step to `expected_values`; in graph mode, delegates to
    `_test_input_fn_iterator` with an initializable iterator.

    Args:
      strategy: The distribution strategy under test.
      input_fn: Input function passed to the strategy.
      expected_values: One list of per-replica values per step.
      ignore_order: If True, compare each step with `assertCountEqual`.
    """
    assert_same = self.assertCountEqual if ignore_order else self.assertEqual

    iterable = strategy.distribute_datasets_from_function(input_fn)
    if context.executing_eagerly():
      iterator = iter(iterable)

      for expected_value in expected_values:
        computed_value = self.evaluate(
            list(strategy.experimental_local_results(next(iterator))))
        assert_same(expected_value, computed_value)

      with self.assertRaises(StopIteration):
        self.evaluate(strategy.experimental_local_results(next(iterator)))

      # After re-initializing the iterator, should be able to iterate again.
      iterator = iter(iterable)

      for expected_value in expected_values:
        computed_value = self.evaluate(
            list(strategy.experimental_local_results(next(iterator))))
        assert_same(expected_value, computed_value)
    else:
      iterator = dataset_ops.make_initializable_iterator(iterable)
      self._test_input_fn_iterator(iterator, strategy.extended.worker_devices,
                                   expected_values, test_reinitialize=True,
                                   ignore_order=ignore_order)

  def _test_input_fn_iterator(self,
                              iterator,
                              devices,
                              expected_values,
                              sess=None,
                              test_reinitialize=True,
                              ignore_order=False):
    """Checks a distributed iterator yields `expected_values`, then ends.

    Each entry of `expected_values` holds the values for one step, one per
    replica in replica order (`select_replica(r, ...)` for r = 0, 1, ...).

    Args:
      iterator: Initializable distributed iterator.
      devices: Sequence of devices; only its length (replica count) is used.
      expected_values: One list of per-replica values per step.
      sess: Session to run ops in; if falsy, `self.evaluate` is used.
      test_reinitialize: If True, re-initialize and check the values again.
      ignore_order: If True, compare each step with `assertCountEqual`.
    """
    evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
    evaluate(iterator.initializer)

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [distribute_utils.select_replica(r, next_element) for r in
           range(len(devices))])
      if ignore_order:
        self.assertCountEqual(expected_value, computed_value)
      else:
        self.assertEqual(expected_value, computed_value)

    # One more step past the expected data must report end of input.
    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      evaluate(
          [distribute_utils.select_replica(r, next_element) for r in
           range(len(devices))])

    # After re-initializing the iterator, should be able to iterate again.
    if test_reinitialize:
      evaluate(iterator.initializer)

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate([
            distribute_utils.select_replica(r, next_element) for r in
            range(len(devices))
        ])
        if ignore_order:
          self.assertCountEqual(expected_value, computed_value)
        else:
          self.assertEqual(expected_value, computed_value)

  def _test_global_step_update(self, strategy):
    """Checks an ONLY_FIRST_REPLICA int64 step ends at 1 after one update.

    Every replica runs `assign_add(1)`, yet every local copy must read 1.

    Args:
      strategy: The distribution strategy under test.
    """
    with strategy.scope():
      global_step = variable_scope.get_variable(
          "global_step",
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self.evaluate(variables.global_variables_initializer())

      def model_fn():
        train_op = global_step.assign_add(1)
        value = global_step.read_value()
        return train_op, value

      train_ops, value = strategy.extended.call_for_each_replica(model_fn)
      self.evaluate(strategy.group(train_ops))
      global_step_tensors = strategy.experimental_local_results(value)
      global_step_values = self.evaluate(global_step_tensors)
      self.assertEqual((1,) * len(global_step_tensors), global_step_values)

  def _test_numpy_dataset(self, strategy, session=None, run_in_function=False):
    """Checks `experimental_make_numpy_dataset` round-trips numpy arrays.

    Skipped unless `strategy` is a `StrategyV1`. The 6-row dataset is repeated
    for 2 epochs; each identity step, concatenated across replicas, must
    reproduce the full arrays, and a third step must raise OutOfRangeError.

    Args:
      strategy: The distribution strategy under test.
      session: Optional session; defaults to `self.cached_session()`.
      run_in_function: Forwarded to `_maybe_run_in_function` for the step fn.
    """
    if not isinstance(strategy, distribute_lib.StrategyV1):
      self.skipTest("n/a: V1 only")
    cached_session = session or self.cached_session()
    with strategy.scope(), cached_session as sess:
      x = np.asarray([[1, 2], [6, 12], [2, 4], [5, 10], [3, 6], [4, 8]])
      y = np.asarray([5, 4, 3, 2, 1, 0])
      batch_size = 6
      # If `_global_batch_size` is falsy, `batch_size` is treated as
      # per-replica, so split the 6 rows among replicas; either way one step
      # covers the whole dataset (asserted below).
      if not strategy.extended._global_batch_size:  # pylint: disable=protected-access
        batch_size = batch_size // strategy.num_replicas_in_sync

      ds = strategy.extended.experimental_make_numpy_dataset(
          (x, y), session=sess or self.cached_session())
      ds = ds.repeat(2)  # 2 epochs
      # We need to use the drop_remainder argument to get a known static
      # input shape which is required for TPUs.
      drop_remainder = strategy.extended.experimental_require_static_shapes
      ds = ds.batch(batch_size, drop_remainder=drop_remainder)

      i = strategy.make_dataset_iterator(ds)
      self.evaluate(i.initializer)

      def run_and_concatenate(strategy, i):
        # Identity step; per-replica results are concatenated along axis 0.
        x, y = strategy.experimental_run(
            _maybe_run_in_function(lambda z: z, run_in_function), i)
        x, y = self.evaluate((strategy.experimental_local_results(x),
                              strategy.experimental_local_results(y)))
        return np.concatenate(x), np.concatenate(y)

      x_1, y_1 = run_and_concatenate(strategy, i)
      self.assertAllEqual(x, x_1)
      self.assertAllEqual(y, y_1)
      x_2, y_2 = run_and_concatenate(strategy, i)
      self.assertAllEqual(x, x_2)
      self.assertAllEqual(y, y_2)
      with self.assertRaises(errors.OutOfRangeError):
        run_and_concatenate(strategy, i)

  def _test_trainable_variable(self, strategy):
    """Checks the `trainable` flag of variables created under `strategy`.

    For both `VariableV1` and `Variable`: a plain variable is trainable, an
    ON_READ variable is non-trainable by default, and an explicit `trainable`
    argument overrides that default.

    Args:
      strategy: The distribution strategy under test.
    """
    for cls in [variable_v1.VariableV1, variables.Variable]:
      with strategy.scope():
        v1 = cls(1.0)
        self.assertEqual(True, v1.trainable)

        v2 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ)
        self.assertEqual(False, v2.trainable)

        v3 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ,
                 trainable=True)
        self.assertEqual(True, v3.trainable)

        v4 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ,
                 trainable=False)
        self.assertEqual(False, v4.trainable)
class OneDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any one-device DistributionStrategy.

  Each `_test_*` helper takes the strategy under test as an argument; the
  expected values assume a single replica.
  """

  def _test_run(self, strategy):
    """Checks `strategy.run` with no args, positional `args` and `kwargs`."""
    out1 = strategy.run(lambda: array_ops.identity(4.))
    self.assertAllEqual([4.], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.run(lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([8.], out2_vals["a"])
    self.assertAllEqual([16.], out2_vals["b"])

    # out2 is forwarded as keyword arguments (a=8, b=16): 8 + 32 + 2 = 42.
    out3 = strategy.run(lambda b, a: a + 2 * b + 2, kwargs=out2)
    self.assertAllEqual([42.], self.evaluate(strategy.unwrap(out3)))

  def _test_all_reduce_sum(self, strategy):
    """With one replica, a SUM all-reduce returns its input unchanged."""
    self._test_collective_comms(
        strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.]))

  def _test_all_reduce_sum_gradients(self, strategy):
    """Checks `tf.gradients` through a single-replica SUM all-reduce."""
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_sum_gradient_tape(self, strategy):
    """Checks `GradientTape` through a single-replica SUM all-reduce."""
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_mean(self, strategy):
    """With one replica, a MEAN all-reduce returns its input unchanged."""
    self._test_collective_comms(
        strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.]))

  def _test_all_reduce_mean_gradients(self, strategy):
    """Checks `tf.gradients` through a single-replica MEAN all-reduce."""
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_all_reduce_mean_gradient_tape(self, strategy):
    """Checks `GradientTape` through a single-replica MEAN all-reduce."""
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected):
    """Runs `comm_fn` on `inputs` as a single dataset element; checks outputs.

    Args:
      strategy: The one-device strategy under test.
      comm_fn: Per-replica fn, e.g. `_all_sum` or `_all_mean`.
      inputs: Structure used as the only dataset element (two components in
        the callers above).
      expected: Expected value of each of the two output components.
    """
    # Keep the iterator under its own name. The input_fn lambda closes over
    # `inputs` by name, so rebinding `inputs` to the iterator would make the
    # lambda see the iterator instead of the data if it ran again later.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))
    self.evaluate(iterator.initialize())
    outputs = self.evaluate(
        list(
            map(strategy.experimental_local_results,
                strategy.experimental_run(comm_fn, iterator))))
    self.assertAllEqual([expected[0]], outputs[0])
    self.assertAllEqual([expected[1]], outputs[1])

  def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
                                       expected_grads):
    """Checks `tf.gradients` of `comm_fn(x) * c` w.r.t. `x` (graph mode only).

    Args:
      strategy: The one-device strategy under test.
      comm_fn: Per-replica fn, e.g. `_all_sum` or `_all_mean`.
      inputs: Dataset element providing the multiplier `c`.
      expected_grads: Expected per-replica gradients.
    """
    if context.executing_eagerly():
      self.skipTest("`tf.gradients` is not supported with eager execution.")

    def step(c):
      x = array_ops.identity(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    # See `_test_collective_comms` for why `inputs` is not rebound.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))
    self.evaluate(iterator.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(step, iterator))))

  def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
                                           expected_grads):
    """Checks `GradientTape` gradients of `comm_fn(x) * c` w.r.t. `x`.

    Args:
      strategy: The one-device strategy under test.
      comm_fn: Per-replica fn, e.g. `_all_sum` or `_all_mean`.
      inputs: Dataset element providing the multiplier `c`.
      expected_grads: Expected per-replica gradients.
    """
    def step(c):
      x = array_ops.identity(42.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    # See `_test_collective_comms` for why `inputs` is not rebound.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))
    self.evaluate(iterator.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(step, iterator))))

  def _test_device_and_input_device_are_colocated(self, strategy):
    """Runs `experimental_run` ops in a session on worker 1 (graph mode only).

    Builds the ops over a `Dataset.range(5)` input, then initializes and runs
    them in a session targeting the second worker of a local 2-worker cluster.
    """
    if context.executing_eagerly():
      self.skipTest(
          "cross-device tests are not supported with eager execution.")
    workers, _ = test_util.create_local_cluster(2, 0)
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.range(5))
    comm_fn = lambda x: x + 1
    run_op = strategy.experimental_run(comm_fn, inputs)
    with session_lib.Session(target=workers[1].target) as sess:
      sess.run(inputs.initialize())
      sess.run(run_op)

  def _test_device_and_input_device_are_colocated_with_function(self, strategy):
    """Like the test above, but `experimental_run` is wrapped in a tf.function.

    The tf.function is called under the worker-1 CPU device scope.
    """
    if context.executing_eagerly():
      self.skipTest(
          "cross-device tests are not supported with eager execution.")
    workers, _ = test_util.create_local_cluster(2, 0)
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.range(5))
    comm_fn = lambda x: x + 1
    experimental_run = def_function.function()(strategy.experimental_run)
    with ops.device("/job:worker/replica:0/task:1/device:CPU:0"):
      # The tf.function must be defined on the right device as well.
      run_op = experimental_run(comm_fn, inputs)
    with session_lib.Session(target=workers[1].target) as sess:
      sess.run(inputs.initialize())
      sess.run(run_op)
class TwoDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any two-device DistributionStrategy.

  Each `_test_*` helper takes the strategy under test as an argument. When
  `run_in_function` is True, per-replica fns go through
  `_maybe_run_in_function`, which wraps them in a `def_function.function`
  while executing eagerly.
  """

  def _test_run(self, strategy, run_in_function=False):
    """Checks `strategy.run` with no args, positional `args` and `kwargs`."""
    # Replica ids are 0 and 1, so out1 is [1, 2].
    out1 = strategy.run(_maybe_run_in_function(
        lambda: distribute_lib.get_replica_context().replica_id_in_sync_group + 1,
        run_in_function))
    self.assertAllEqual([1, 2], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.run(_maybe_run_in_function(
        lambda x: {"a": x * 2, "b": x * x}, run_in_function), args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([2, 4], out2_vals["a"])
    self.assertAllEqual([1, 4], out2_vals["b"])

    # out2 is forwarded as keyword arguments: a + 2 * b + 2 = [6, 14].
    out3 = strategy.run(_maybe_run_in_function(
        lambda b, a: a + 2 * b + 2, run_in_function), kwargs=out2)
    self.assertAllEqual([6, 14], self.evaluate(strategy.unwrap(out3)))

  def _test_all_reduce_sum(self, strategy, run_in_function=False):
    """Replicas get (1., [39., 2.]) and (3., [3., 41.]); SUM is (4., [42., 43.])."""
    self._test_collective_comms(
        strategy,
        _all_sum,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(4., [42., 43.]),
        run_in_function=run_in_function)

  def _test_all_reduce_sum_gradients(self, strategy, run_in_function=False):
    """Checks `tf.gradients` through a SUM all-reduce over two replicas."""
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.],
        run_in_function=run_in_function)

  def _test_all_reduce_sum_gradient_tape(self, strategy, run_in_function=False):
    """Checks `GradientTape` through a SUM all-reduce over two replicas."""
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.],
        run_in_function=run_in_function)

  def _test_all_reduce_mean(self, strategy, run_in_function=False):
    """Same inputs as the SUM test; MEAN is (2., [21., 21.5])."""
    self._test_collective_comms(
        strategy,
        _all_mean,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(2., [21., 21.5]),
        run_in_function=run_in_function)

  def _test_all_reduce_mean_gradients(self, strategy, run_in_function=False):
    """Checks `tf.gradients` through a MEAN all-reduce over two replicas."""
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.],
        run_in_function=run_in_function)

  def _test_all_reduce_mean_gradient_tape(self, strategy,
                                          run_in_function=False):
    """Checks `GradientTape` through a MEAN all-reduce over two replicas."""
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.],
        run_in_function=run_in_function)

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected,
                             run_in_function=False):
    """Runs `comm_fn` over slices of `inputs` and checks both replicas' outputs.

    Args:
      strategy: The two-device strategy under test.
      comm_fn: Per-replica fn, e.g. `_all_sum` or `_all_mean`.
      inputs: Structure passed to `Dataset.from_tensor_slices`.
      expected: Value every replica should hold for each of the two output
        components.
      run_in_function: Forwarded to `_maybe_run_in_function` for `comm_fn`.
    """
    # Keep the iterator under its own name. The input_fn lambda closes over
    # `inputs` by name, so rebinding `inputs` to the iterator would make the
    # lambda see the iterator instead of the data if it ran again later.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))
    self.evaluate(iterator.initialize())
    outputs = self.evaluate(
        list(
            map(strategy.experimental_local_results,
                strategy.experimental_run(
                    _maybe_run_in_function(comm_fn, run_in_function),
                    iterator))))
    self.assertAllEqual([expected[0], expected[0]], outputs[0])
    self.assertAllEqual([expected[1], expected[1]], outputs[1])

  def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
                                       expected_grads, run_in_function=False):
    """Checks `tf.gradients` of `comm_fn(x) * c` w.r.t. `x`.

    Skipped when executing eagerly unless `run_in_function` is True.

    Args:
      strategy: The two-device strategy under test.
      comm_fn: Per-replica fn, e.g. `_all_sum` or `_all_mean`.
      inputs: Sliced across replicas to provide each replica's multiplier `c`.
      expected_grads: Expected per-replica gradients.
      run_in_function: Forwarded to `_maybe_run_in_function` for the step fn.
    """
    if context.executing_eagerly() and not run_in_function:
      self.skipTest("`tf.gradients` is not supported with eager execution "
                    "without using tf.functions.")

    def step(c):
      x = array_ops.identity(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    # See `_test_collective_comms` for why `inputs` is not rebound.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))
    self.evaluate(iterator.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(
                    _maybe_run_in_function(step, run_in_function),
                    iterator))))

  def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
                                           expected_grads,
                                           run_in_function=False):
    """Checks `GradientTape` gradients of `comm_fn(x) * c` w.r.t. `x`.

    Args:
      strategy: The two-device strategy under test.
      comm_fn: Per-replica fn, e.g. `_all_sum` or `_all_mean`.
      inputs: Sliced across replicas to provide each replica's multiplier `c`.
      expected_grads: Expected per-replica gradients.
      run_in_function: Forwarded to `_maybe_run_in_function` for the step fn.
    """
    def step(c):
      x = array_ops.identity(42.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    # See `_test_collective_comms` for why `inputs` is not rebound.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))
    self.evaluate(iterator.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(
                    _maybe_run_in_function(step, run_in_function),
                    iterator))))
class RemoteSingleWorkerMirroredStrategyBase(DistributionTestBase):
  """Tests for a mirrored strategy running on a single remote worker.

  Subclasses are expected to override `_get_num_gpus` to return the number of
  replicas on that worker. The base version returns None.
  """

  def _get_num_gpus(self):
    """Returns the worker's GPU (replica) count; None in this base class."""

  def _testNumReplicasInSync(self, distribution):
    """Checks `num_replicas_in_sync` against `_get_num_gpus()`."""
    expected_replicas = self._get_num_gpus()
    self.assertEqual(expected_replicas, distribution.num_replicas_in_sync)

  def _testMinimizeLoss(self, distribution):
    """Runs the shared minimize-loss check in the current execution mode."""
    if not context.executing_eagerly():
      self._test_minimize_loss_graph(distribution, learning_rate=0.05)
      return
    self._test_minimize_loss_eager(distribution)

  def _testDeviceScope(self, distribution):
    """Checks default and explicit "/cpu:0" placement on the remote worker."""
    with distribution.scope():
      default_placed = array_ops.identity(1.)
      with ops.device("/cpu:0"):
        cpu_placed = array_ops.identity(1.)
      # Eager tensors report a fully specified device; graph ops leave the
      # device type unset when no explicit device was requested.
      expected_default_device = (
          "/job:worker/replica:0/task:0/device:CPU:0"
          if context.executing_eagerly() else "/job:worker/replica:0/task:0")
      self.assertEqual(default_placed.device, expected_default_device)
      self.assertEqual(cpu_placed.device,
                       "/job:worker/replica:0/task:0/device:CPU:0")

  def _testMakeInputFnIteratorWithDataset(self, distribution):
    """Checks `make_input_fn_iterator` with an input_fn returning a Dataset."""
    make_dataset = lambda: dataset_ops.Dataset.range(100)
    num_gpus = self._get_num_gpus()  # pylint: disable=assignment-from-no-return
    num_workers = 1
    # Step k yields the num_gpus consecutive ids starting at k * num_gpus.
    expected_values = [
        list(range(start, start + num_gpus)) * num_workers
        for start in range(0, 100, num_gpus)
    ]

    # Dummy cached_session is used in Eager
    with self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be called
      # multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          make_dataset,
          expected_num_replicas_in_sync=num_workers * num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      input_iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          input_iterator, distribution.extended.worker_devices,
          expected_values, sess=sess)

  def _testMakeInputFnIteratorWithCallable(self, distribution):
    """Checks `make_input_fn_iterator` with an input_fn returning `get_next`."""
    def make_get_next():
      source = dataset_ops.Dataset.range(100)
      one_shot = dataset_ops.make_one_shot_iterator(source)
      return one_shot.get_next

    num_gpus = self._get_num_gpus()  # pylint: disable=assignment-from-no-return
    num_workers = 1
    expected_values = [
        list(range(start, start + num_gpus)) * num_workers
        for start in range(0, 100, num_gpus)
    ]

    # Dummy cached_session is used in Eager
    with self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be called
      # multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          make_get_next,
          expected_num_replicas_in_sync=num_workers * num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      input_iterator = distribution.make_input_fn_iterator(input_fn)
      # A one-shot iterator cannot be re-initialized, and per-step order is
      # not compared.
      self._test_input_fn_iterator(
          input_iterator, distribution.extended.worker_devices,
          expected_values, sess=sess,
          test_reinitialize=False, ignore_order=True)
def _all_sum(value):
  """All-reduces `value` across replicas using `ReduceOp.SUM`."""
  return distribute_lib.get_replica_context().all_reduce(
      reduce_util.ReduceOp.SUM, value)
def _all_mean(value):
  """All-reduces `value` across replicas using `ReduceOp.MEAN`."""
  return distribute_lib.get_replica_context().all_reduce(
      reduce_util.ReduceOp.MEAN, value)