# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import threading
from absl.testing import parameterized
from tensorflow.python.checkpoint import checkpoint as tracking
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_switch_case
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateful_random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.util import nest
# When running collectives asynchronously, we need to give each parallel device
# execution a unique ID so the collectives don't interfere. Since each op is
# replicated with its group/instance key intact, the replicated nodes will
# communicate with each other.
# TODO(allenl): Switch to using a collective manager.
_COUNTER_LOCK = threading.Lock()
_COUNTER = 100
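# Reduces each tensor in `inputs` across `num_replicas` parallel components
# using `operation`, drawing a fresh group/instance key from the counter above
# for every tensor so separate reductions don't collide.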
def _collective_reduce(inputs, operation, num_replicas):
def _reduce_tensor(tensor):
with _COUNTER_LOCK:
global _COUNTER
keys = _COUNTER
_COUNTER += 1
return collective_ops.all_reduce_v2(
t=tensor,
group_size=num_replicas,
merge_op=operation,
group_key=keys,
instance_key=keys)
return nest.map_structure(_reduce_tensor, inputs)
def _collective_sum(inputs, num_replicas):
return _collective_reduce(
inputs=inputs, operation="Add", num_replicas=num_replicas)
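# A minimal dense layer that creates its kernel and bias lazily on the first
# call, letting the tests below exercise variable creation under the parallel
# device.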
class _Dense(module.Module):
def __init__(self, output_size):
self.output_size = output_size
self.kernel = None
self.bias = None
def __call__(self, x):
if self.kernel is None:
self.kernel = variables.Variable(
array_ops.ones(
array_ops_stack.stack([self.output_size,
array_ops.shape(x)[-1]])))
self.bias = variables.Variable(array_ops.ones([self.output_size]))
return math_ops.matmul(x, self.kernel, transpose_b=True) + self.bias
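# Base test case that builds a two-component ParallelDevice. setUp picks TPU,
# GPU, or CPU; for GPU and CPU it splits the first physical device into two
# logical devices so two components are always available.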
class _VirtualDeviceTestCase(test.TestCase):
def setUp(self):
super(_VirtualDeviceTestCase, self).setUp()
ctx = context.context()
if ctx.list_physical_devices("TPU"):
self.device_type = "TPU"
tpu_cluster_resolver.initialize_tpu_system()
elif ctx.list_physical_devices("GPU"):
self.device_type = "GPU"
gpus = ctx.list_physical_devices(self.device_type)
ctx.set_logical_device_configuration(gpus[0], [
context.LogicalDeviceConfiguration(memory_limit=100),
context.LogicalDeviceConfiguration(memory_limit=100),
])
else:
self.device_type = "CPU"
cpus = ctx.list_physical_devices("CPU")
ctx.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
])
self.device = parallel_device.ParallelDevice(components=[
"/job:localhost/device:{}:0".format(self.device_type),
self.device_type + ":1"
])
self.assertIn(self.device_type + ":0", self.device.components[0])
self.assertIn(self.device_type + ":1", self.device.components[1])
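# The tests below follow a common pattern (sketch, mirroring
# test_register_parallel_device and test_collective_reduce): pack one tensor
# per component, run ops under the device scope, then unpack and check each
# component, e.g.
#   x = self.device.pack([constant_op.constant(1.), constant_op.constant(2.)])
#   with self.device:
#     y = x + 1.
#   self.assertAllClose([2., 3.], self.device.unpack(y))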
class ParallelDeviceTests(_VirtualDeviceTestCase, parameterized.TestCase):
def test_register_parallel_device(self):
with self.device:
c = constant_op.constant(1.)
d = constant_op.constant(2.)
e = c + d
outputs = self.device.unpack(e)
self.assertAllClose([3., 3.], outputs)
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
def test_no_implicit_copyon(self):
a1 = constant_op.constant(1.)
a2 = constant_op.constant(2.)
with self.device:
with self.assertRaisesRegex(
errors.InvalidArgumentError,
"First pack non-parallel tensors for each device"):
a1 + a2 # pylint:disable=pointless-statement
def test_error_message_length(self):
x = array_ops.ones([3, 3, 3, 3, 3, 3])
with self.device:
with self.assertRaisesRegex(
errors.InvalidArgumentError,
r"TensorHandle\((.|\n){1,150}\[...\], shape="):
array_ops.identity(x)
def test_one_replica_eager_control_flow(self):
device = parallel_device.ParallelDevice(components=[
"/job:localhost/device:{}:0".format(self.device_type),
])
x = constant_op.constant([2, 3, 4])
with device:
x = device.pack([x])
if math_ops.reduce_any(math_ops.equal(x, constant_op.constant(4))):
y = constant_op.constant(1)
else:
y = constant_op.constant(2)
self.assertAllEqual([1], device.unpack(y))
@parameterized.named_parameters(
("variable", variables.Variable),
("tensor", lambda x: x))
def test_string_representation(self, transform):
x = self.device.pack(
[constant_op.constant([5., 6.]),
constant_op.constant([6., 7.])])
with self.device:
x = transform(x)
parallel_str = str(x)
self.assertIn("5", parallel_str)
self.assertIn("7", parallel_str)
self.assertIn(self.device_type + ":0", parallel_str)
self.assertIn(self.device_type + ":1", parallel_str)
parallel_repr = repr(x)
self.assertIn("5", parallel_repr)
self.assertIn("7", parallel_repr)
self.assertIn(self.device_type + ":0", parallel_repr)
self.assertIn(self.device_type + ":1", parallel_repr)
def test_device_id(self):
device_ids = self.device.unpack(self.device.device_ids)
self.assertAllClose([0, 1], device_ids)
# TODO(allenl): Should device IDs be int64 so they can be placed on GPUs?
# Currently backing_device is CPU.
self.assertIn(self.device.components[0], device_ids[0].device)
self.assertIn(self.device.components[1], device_ids[1].device)
def test_zeros(self):
with self.device:
x = array_ops.zeros([array_ops.identity(constant_op.constant(10))])
for component in self.device.unpack(x):
self.assertAllClose([0.] * 10, component)
def test_generator(self):
with self.device:
g_same = stateful_random_ops.Generator.from_seed(0)
g_different = stateful_random_ops.Generator.from_seed(
self.device.device_ids)
same = g_same.normal([10])
different = g_different.normal([10])
same_unpacked = self.device.unpack(same)
different_unpacked = self.device.unpack(different)
for same_component, different_component in zip(same_unpacked[1:],
different_unpacked[1:]):
self.assertAllClose(same_component, same_unpacked[0])
self.assertNotAllClose(different_component, different_unpacked[0])
def test_collective_reduce(self):
x = self.device.pack(
[constant_op.constant(-1.5),
constant_op.constant(3.5)])
with self.device:
reduced = _collective_sum(x, num_replicas=2)
outputs = self.device.unpack(reduced)
self.assertAllClose([2., 2.], outputs)
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
def test_collective_reduce_in_function(self):
x = self.device.pack(
[constant_op.constant(-1.5),
constant_op.constant(3.5)])
with self.device:
@def_function.function
def reduce(t):
return _collective_sum(t, num_replicas=2)
reduced = reduce(x)
outputs = self.device.unpack(reduced)
self.assertAllClose([2., 2.], outputs)
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
def test_collective_reduce_async_scope(self):
# Note that ops on the parallel device currently don't execute
# asynchronously. The test is just that we don't get deadlocks.
x = self.device.pack(
[constant_op.constant(-1.5),
constant_op.constant(3.5)])
with context.async_scope(), self.device:
reduced = _collective_sum(x, num_replicas=2)
outputs = self.device.unpack(reduced)
self.assertAllClose([2., 2.], outputs)
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
def test_collective_reduce_async_context(self):
previous = config.get_synchronous_execution()
try:
context._reset_context()
config.set_synchronous_execution(False)
self.setUp()
# Note that ops on the parallel device currently don't execute
# asynchronously. The test is just that we don't get deadlocks.
x = self.device.pack(
[constant_op.constant(-1.5),
constant_op.constant(3.5)])
with self.device:
reduced = _collective_sum(x, num_replicas=2)
outputs = self.device.unpack(reduced)
self.assertAllClose([2., 2.], outputs)
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
finally:
context._reset_context()
config.set_synchronous_execution(previous)
def test_collective_broadcast_in_function(self):
if self.device_type == "TPU":
self.skipTest("ParallelDevice broadcast collectives on TPUs need work")
@def_function.function
def broadcast_send_recv(device_id):
c = constant_op.constant([2])
@def_function.function
def send():
s0 = collective_ops.broadcast_send(
c * 3, c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
with ops.control_dependencies([s0.op]):
return array_ops.identity(c)
@def_function.function
def recv():
r0 = collective_ops.broadcast_recv(
c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
return r0
return control_flow_switch_case.switch_case(
device_id, branch_fns={
0: send,
1: recv
})
with self.device:
result = broadcast_send_recv(self.device.device_ids)
self.assertAllClose([[2], [6]], self.device.unpack(result))
def test_use_in_graph_error_is_informative(self):
@def_function.function
def uses_parallel():
with self.device:
return self.device.unpack(array_ops.ones([]))
with self.assertRaisesRegex(NotImplementedError, "inside `tf.function`"):
uses_parallel()
def test_checkpointing(self):
self.skipTest("b/216201668: revisit parallel device and checkpointing.")
prefix = os.path.join(self.get_temp_dir(), "ckpt")
different_values = self.device.pack(
[constant_op.constant(-1.),
constant_op.constant(3.)])
with self.device:
v = variables.Variable(different_values)
checkpoint = tracking.Checkpoint(v=v)
save_path = checkpoint.save(prefix)
with self.device:
v.assign(constant_op.constant(0.))
checkpoint.restore(save_path).assert_consumed()
with self.device:
outputs = self.device.unpack(v)
self.assertAllClose([-1., 3.], outputs)
with self.device:
restore_on_create = tracking.Checkpoint()
restore_on_create.restore(save_path)
restore_on_create.v = variables.Variable(0.)
outputs = self.device.unpack(restore_on_create.v)
self.assertAllClose([-1., 3.], outputs)
# Changing the number of devices / restoring into a single-device copy is OK
single_device = tracking.Checkpoint(v=variables.Variable(0.))
status = single_device.restore(save_path)
status.assert_existing_objects_matched()
self.assertAllClose(-1., single_device.v)
with self.assertRaisesRegex(AssertionError, "parallel_component_1"):
# There are parts of the variable that aren't restored into a
# single-device copy.
status.assert_consumed()
def test_pack_composite(self):
if self.device_type != "CPU":
self.skipTest("Iterator GetNext doesn't work on accelerators.")
datasets = [
dataset_ops.Dataset.from_tensor_slices(
[i + 1, (i + 1) * 2, (i + 1) * 3])
for i in range(len(self.device.components))]
parallel_dataset = self.device.pack(datasets)
with self.device:
iterator = iter(parallel_dataset)
parallel_sample = next(iterator)
component_iterators = self.device.unpack(iterator)
self.assertEqual(2, next(component_iterators[0]).numpy())
self.assertEqual(1, self.device.unpack(parallel_sample)[0].numpy())
self.assertEqual(4, next(component_iterators[1]).numpy())
self.assertEqual(2, self.device.unpack(parallel_sample)[1].numpy())
def test_pack_structure(self):
x_parts = [{"a": constant_op.constant(float(i))}
for i in range(len(self.device.components))]
x = self.device.pack(x_parts)
self.assertAllClose([{"a": 0.}, {"a": 1.}], self.device.unpack(x))
def test_pack_variable_value(self):
x_parts = [variables.Variable(i)
for i in range(len(self.device.components))]
x = self.device.pack(x_parts)
with self.device:
x1 = self.device.pack(x_parts)
for v in x_parts:
v.assign(-10) # Mutating the variable does not affect previous reads.
self.assertAllClose([0, 1], self.device.unpack(x))
self.assertAllClose([0, 1], self.device.unpack(x1))
def test_unpack_variable_value(self):
x_parts = [constant_op.constant(i)
for i in range(len(self.device.components))]
x = self.device.pack(x_parts)
with self.device:
v = variables.Variable(x)
v_unpacked = self.device.unpack(v)
v.assign(-10) # Mutating the variable does not affect previous reads.
self.assertAllClose([0, 1], v_unpacked)
def test_saved_model(self):
self.skipTest("b/216201668: revisit parallel device and saved model")
different_values = self.device.pack(
[constant_op.constant(-1.),
constant_op.constant(3.)])
with self.device:
m = module.Module()
m.v = variables.Variable(different_values)
m.f = def_function.function(lambda: m.v * 2.)
self.assertAllClose([-2., 6.], self.device.unpack(m.f()))
saved_model_path = os.path.join(self.get_temp_dir(), "saved_model")
save.save(m, saved_model_path)
context._reset_context()
self.setUp()
single_device_loaded = load.load(saved_model_path)
self.assertAllClose(-2., single_device_loaded.f())
assign_value = self.device.pack(
[constant_op.constant(.1), constant_op.constant(.2)])
with self.device:
parallel_loaded = load.load(saved_model_path)
self.assertAllClose([-2., 6.], self.device.unpack(parallel_loaded.f()))
self.assertAllClose([-1., 3.], self.device.unpack(parallel_loaded.v))
parallel_loaded.v.assign(assign_value)
self.assertAllClose([.2, .4], self.device.unpack(parallel_loaded.f()))
def _assert_close_to_non_parallel(self, computation):
"""Asserts that replication of `computation` works and is equivalent."""
with self.device:
parallel_result = computation()
non_parallel_result = computation()
# The computations should have the same number and structure of Tensor
# objects, even though the tensors themselves will be on different devices
# and represent different numbers of values.
nest.assert_same_structure(parallel_result, non_parallel_result)
non_parallel_flat = nest.flatten(non_parallel_result)
parallel_flat = nest.flatten(parallel_result)
self.assertGreater(len(parallel_flat), 0)
for non_parallel, parallel in zip(non_parallel_flat, parallel_flat):
self.assertEqual(self.device._name, parallel.device)
self.assertNotEqual(self.device._name, non_parallel.device)
for parallel_component in self.device.unpack(parallel):
self.assertAllClose(non_parallel, parallel_component)
def test_capturing(self):
with self.device:
x = constant_op.constant([1., 2.])
x = array_ops.identity(x)
@def_function.function
def f(y):
return x + y
y = array_ops.ones([2])
parallel_result = f(y)
self.assertAllClose([[2., 3.]] * 2, self.device.unpack(parallel_result))
def test_euclidean_norm(self):
def _test_fn():
with backprop.GradientTape() as tape:
x = array_ops.ones([5, 5])
tape.watch(x)
y = math_ops.reduce_euclidean_norm(x, axis=constant_op.constant(1))
return y, tape.gradient(y, x)
self._assert_close_to_non_parallel(_test_fn)
def test_reduce_sum(self):
def _test_fn():
with backprop.GradientTape() as tape:
x = array_ops.ones([5, 5])
tape.watch(x)
y = math_ops.reduce_sum(x, axis=constant_op.constant(1))
return y, tape.gradient(y, x)
self._assert_close_to_non_parallel(_test_fn)
def test_variable_created_in_function(self):
captured_value = constant_op.constant(2.)
class M(module.Module):
def __init__(self):
self.v = None
self.w = None
self.x = None
self.z = None
@def_function.function(autograph=False)
def __call__(self, x):
if self.v is None:
with ops.init_scope():
initial_value = constant_op.constant(2.)
self.z = variables.Variable(initial_value)
self.x = variables.Variable(captured_value)
self.w = variables.Variable(lambda: constant_op.constant(2.))
self.v = variables.Variable(constant_op.constant(2.))
return x * self.v * self.w * self.x * self.z
with self.device:
m = M()
packed_outputs = m(array_ops.ones([]))
outputs = self.device.unpack(packed_outputs)
self.assertAllClose([16., 16.], outputs)
def test_different_shapes(self):
x = self.device.pack(
[constant_op.constant([1., 2.]),
constant_op.constant([5.])])
with self.device:
y = x * 2.
self.assertEqual([None], y.shape.as_list())
self.assertAllClose([[2., 4.], [10.]], self.device.unpack(y))
different_axes = self.device.pack(
[constant_op.constant([1., 2.]),
constant_op.constant([[5.]])])
with self.assertRaisesRegex(Exception,
"components do not all have the same rank"):
different_axes.shape # pylint: disable=pointless-statement
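# Tests that run the _Dense layer on the parallel device: forward pass,
# synchronous training with all-reduced gradients, per-replica (divergent)
# training, and a checkpointed training loop.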
class LayerTests(_VirtualDeviceTestCase):
def test_layer_forward(self):
with self.device:
layer = _Dense(5)
x = constant_op.constant([[2.]])
y = layer(x)
outputs = self.device.unpack(y)
self.assertAllClose([[3.] * 5], outputs[0])
self.assertAllClose([[3.] * 5], outputs[1])
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
# With different per-component inputs, the layer produces different outputs.
x = self.device.pack(
[constant_op.constant([[-0.5]]),
constant_op.constant([[0.5]])])
with self.device:
y = layer(x)
outputs = self.device.unpack(y)
self.assertGreater(
math_ops.reduce_max(math_ops.abs(outputs[0] - outputs[1])), 1e-5)
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
def test_layer_sync_training(self):
x = self.device.pack(
[constant_op.constant([[-0.5]]),
constant_op.constant([[0.5]])])
with self.device:
layer = _Dense(5)
with backprop.GradientTape() as tape:
y = layer(x)
loss = (y - math_ops.range(5.))**2.
parameters = layer.trainable_variables
unreduced_gradients = tape.gradient(loss, parameters)
reduced_gradients = _collective_sum(unreduced_gradients, num_replicas=2)
for grad, param in zip(reduced_gradients, parameters):
param.assign_sub(0.01 * grad)
final_kernels = self.device.unpack(layer.kernel)
self.assertAllClose(final_kernels[0], final_kernels[1])
final_bias = self.device.unpack(layer.bias)
expected_bias = (1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)) -
0.01 * 2. * (1. - .5 - math_ops.range(5.)))
self.assertAllClose(expected_bias, final_bias[0], rtol=1e-4, atol=1e-4)
self.assertAllClose(expected_bias, final_bias[1], rtol=1e-4, atol=1e-4)
self.assertIn(self.device.components[0], final_kernels[0].backing_device)
self.assertIn(self.device.components[1], final_kernels[1].backing_device)
def test_layer_divergent_buffer_training(self):
x = self.device.pack(
[constant_op.constant([[-0.5]]),
constant_op.constant([[0.5]])])
with self.device:
layer = _Dense(5)
with backprop.GradientTape() as tape:
y = layer(x)
loss = (y - math_ops.range(5.))**2.
parameters = layer.trainable_variables
unreduced_gradients = tape.gradient(loss, parameters)
for grad, param in zip(unreduced_gradients, parameters):
param.assign_sub(0.01 * grad)
final_kernels = self.device.unpack(layer.kernel)
self.assertNotAllClose(final_kernels[0], final_kernels[1])
final_bias = self.device.unpack(layer.bias)
self.assertAllClose(1. - 0.01 * 2. * (1. - .5 - math_ops.range(5.)),
final_bias[0])
self.assertAllClose(1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)),
final_bias[1])
self.assertIn(self.device.components[0], final_kernels[0].backing_device)
self.assertIn(self.device.components[1], final_kernels[1].backing_device)
def test_training_loop(self):
self.skipTest("b/216201668: revisit parallel device and checkpointing")
for _ in range(5):
layer = _Dense(5)
checkpoint = tracking.Checkpoint(layer=layer)
manager = checkpoint_management.CheckpointManager(
checkpoint, directory=self.get_temp_dir(), max_to_keep=5)
manager.restore_or_initialize()
for _ in range(10):
x = self.device.pack(
[constant_op.constant([[-0.5]]),
constant_op.constant([[0.5]])])
with self.device:
with backprop.GradientTape() as tape:
y = layer(x)
loss = (y - math_ops.range(5.))**2.
parameters = layer.trainable_variables
unreduced_gradients = tape.gradient(loss, parameters)
reduced_gradients = _collective_sum(
unreduced_gradients, num_replicas=len(self.device.components))
for grad, param in zip(reduced_gradients, parameters):
param.assign_sub(0.01 * grad)
manager.save()
if __name__ == "__main__":
ops.enable_eager_execution()
test.main()