blob: ab3711e996d79ece182b69c75a2ff6231d82b9d1 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPUStrategy."""
from absl import logging
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_switch_case
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu_hardware_feature
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
FLAGS = flags.FLAGS

# Flags selecting the Cloud TPU the tests talk to. Some tests treat an empty
# `--tpu` (the default) as "local TPU".
flags.DEFINE_string(name="tpu", default="", help="Name of TPU to connect to.")
flags.DEFINE_string(
    name="project", default=None, help="Name of GCP project with TPU.")
flags.DEFINE_string(
    name="zone", default=None, help="Name of GCP zone with TPU.")
def get_tpu_cluster_resolver():
  """Builds a `TPUClusterResolver` from the `--tpu`, `--zone` and `--project` flags.

  This function neither connects to the cluster nor initializes the TPU
  system; callers do that themselves.

  Returns:
    A `tpu_cluster_resolver.TPUClusterResolver`.
  """
  return tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
def get_tpu_strategy(enable_packed_var=False):
  """Connects to and initializes the TPU system, then builds a strategy.

  Args:
    enable_packed_var: Value stored in the strategy's private
      `_enable_packed_variable_in_eager_mode` attribute.

  Returns:
    A `tpu_lib.TPUStrategyV2` created without an explicit device assignment.
  """
  cluster_resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(cluster_resolver)
  tpu_cluster_resolver.initialize_tpu_system(cluster_resolver)

  tpu_strategy = tpu_lib.TPUStrategyV2(cluster_resolver)
  # Private knob; the parameterized tests flip it to cover both the packed
  # and the unpacked variable paths.
  tpu_strategy._enable_packed_variable_in_eager_mode = enable_packed_var
  return tpu_strategy
# TPU tests which don't use TPUStrategy.
@test_util.with_eager_op_as_function
class TPUTest(test.TestCase):
  """Tests that place work on TPU devices directly, not via `strategy.run`."""

  # In this case, the entire computation in foo is compiled using JIT
  # compilation.
  def test_single_tpu_jit_compile(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      b = x + get_a_plus_one()
      b = b + get_a_plus_one()
      return b + 1

    with ops.device("/device:TPU:0"):
      result = foo(a)
    # a == 1: (1 + 2) + 2 + 1 == 6.
    self.assertAllEqual(6, result)

  # In this case, the entire computation in foo is compiled using JIT
  # compilation and contains unsupported ops that should be outside compiled.
  def test_single_tpu_jit_compile_with_outside_compilation(self):
    # NOTE(review): the JIT-compile rewrite is also process-wide and is not
    # reset here; confirm whether later tests depend on it being enabled.
    context.enable_jit_compile_rewrite()
    # Called only for its side effects (connect + initialize the TPU system).
    get_tpu_strategy(True)
    # Soft device placement is process-wide state. Restore the previous value
    # when this test ends so it does not leak into tests that run afterwards.
    self.addCleanup(config.set_soft_device_placement,
                    config.get_soft_device_placement())
    config.set_soft_device_placement(True)
    with ops.device("/device:TPU:1"):
      a = variables.Variable(1)

    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      b = x + get_a_plus_one()
      # String ops and printing are the "unsupported" ops that have to be
      # outside compiled.
      my_str = string_ops.as_string(b)
      new_str = my_str + "0"
      c = string_ops.string_to_number(new_str, out_type=dtypes.int32)
      logging_ops.print_v2(c)
      b = c + get_a_plus_one()
      return b + 1

    with ops.device("/device:TPU:1"):
      result = foo(a)
    # a == 1: b == 3 -> "30" -> 30; 30 + 2 + 1 == 33.
    self.assertAllEqual(33, result)

  # In this case, each of the ops in the TPU device scope are compiled and run
  # individually.
  def test_single_tpu_on_demand(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    def get_a_plus_one():
      return a + 1

    x = 1
    with ops.device("/device:TPU:0"):
      b = x + get_a_plus_one()
      b = b + get_a_plus_one()
    result = b + 1
    self.assertAllEqual(6, result)

  # In this case, each of the ops in the tf.function and TPU device scope are
  # compiled and run individually.
  def test_single_tpu_on_demand_tf_function(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      with ops.device("/device:TPU:0"):
        b = x + get_a_plus_one()
        b = b + get_a_plus_one()
      return b + 1

    result = foo(a)
    self.assertAllEqual(6, result)

  def test_multiple_initialize_system(self):
    """Re-initializing an already initialized TPU system only warns."""
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    tpu_cluster_resolver.initialize_tpu_system(resolver)

    with test.mock.patch.object(logging, "warning") as mock_log:
      tpu_cluster_resolver.initialize_tpu_system(resolver)
      self.assertRegex(str(mock_log.call_args), "already been initialized")

  def test_initialize_tpu_system_impl_input(self):
    """A `None` resolver class is rejected with a TypeError."""
    resolver = get_tpu_cluster_resolver()
    with self.assertRaisesRegex(
        TypeError,
        r"tpu_cluster_resolver_cls is not"
        r" tf.distribute.cluster_resolver.TPUClusterResolver."):
      tpu_strategy_util.initialize_tpu_system_impl(
          resolver, tpu_cluster_resolver_cls=None)

  def test_shutdown_tpu_system_impl_input(self):
    """A `None` resolver class is rejected with a TypeError."""
    resolver = get_tpu_cluster_resolver()
    with self.assertRaisesRegex(
        TypeError,
        r"tpu_cluster_resolver_cls is not"
        r" tf.distribute.cluster_resolver.TPUClusterResolver."):
      tpu_strategy_util.shutdown_tpu_system_impl(
          resolver, tpu_cluster_resolver_cls=None)

  def test_tpu_tf_function_same_device(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    # `_noinline` keeps get_a_plus_one as a separate function call.
    @def_function.function(experimental_attributes={"_noinline": True})
    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      with ops.device("/device:TPU:0"):
        b = x + get_a_plus_one()
      return b + 1

    result = foo(a)
    self.assertAllEqual(4, result)

  def test_tpu_return_int32(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(0)

    @def_function.function
    def foo():
      return a + 1

    @def_function.function
    def bar():
      with ops.device("/device:TPU:1"):
        return foo()

    with ops.device("/device:CPU:0"):
      result = bar() + 1
      self.assertAllEqual(result, 2)

  def test_tpu_output_device(self):
    """`_OutputsOnOpDevice` keeps function outputs on the op's device."""

    def foo():
      return 1 + 1

    func1 = def_function.function(foo, jit_compile=False)
    func2 = def_function.function(
        foo,
        jit_compile=False,
        experimental_attributes={
            "_OutputsOnOpDevice": True,
        },
    )

    with ops.device("/device:TPU:0"):
      ret1 = func1()
      ret2 = func2()

    self.assertAllEqual(ret1.backing_device,
                        "/job:localhost/replica:0/task:0/device:CPU:0")
    self.assertAllEqual(ret2.backing_device,
                        "/job:localhost/replica:0/task:0/device:TPU:0")

  def test_on_demand_op_with_dynamic_output(self):
    """Ops whose output shape depends on input values run on demand on TPU."""
    with ops.device("/device:TPU:0"):
      where_output = array_ops.where([True, False, True])
    self.assertAllEqual(where_output, [[0], [2]])

    with ops.device("/device:TPU:0"):
      repeat_output = array_ops.repeat(math_ops.range(2), [1, 4])
    self.assertAllEqual(repeat_output, [0, 1, 1, 1, 1])
@parameterized.named_parameters([("PackedVar", True), ("", False)])
@test_util.with_eager_op_as_function
class TPUStrategyTest(test.TestCase, parameterized.TestCase):
def test_handle_in_cross_replica_context(self, enable_packed_var):
  """Outside `strategy.run`, the variable handle resolves to TPU:0."""
  strategy = get_tpu_strategy(enable_packed_var)
  with strategy.scope():
    var = variables.Variable(1.0)

  @def_function.function
  def read_plus_one():
    # Traced outside strategy.run, i.e. in cross-replica context.
    self.assertEndsWith(var.handle.device, "device:TPU:0")
    return var + 1.0

  self.assertAllEqual(read_plus_one(), 2.0)
def testStaticHashTableDatasetFnHostTrainingLoop(self, enable_packed_var):
  """A table lookup inside a tf.function dataset_fn is traced only once.

  The training loop is a plain Python (host) loop that calls `strategy.run`
  on every step.
  """
  self._dataset_fn_tracing_count = 0
  strategy = get_tpu_strategy(enable_packed_var)

  with strategy.scope():
    vals = [0, 1, 2]
    keys_tensor = constant_op.constant(
        list(range(len(vals))), dtype=dtypes.int64)
    vals_tensor = constant_op.constant(vals)
    initializer = lookup_ops.KeyValueTensorInitializer(
        keys_tensor, vals_tensor)
    per_worker_table = lookup_ops.StaticHashTable(
        initializer, default_value=-1)

  @def_function.function
  def dataset_fn(input_context):
    tensor = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
    global_batch_size = 2
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = dataset_ops.Dataset.from_tensors(tensor).repeat().batch(
        batch_size, drop_remainder=True)
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    dataset = dataset.prefetch(2)  # Prefetch 2 batches from this pipeline.
    dataset = dataset.map(per_worker_table.lookup)
    # Python side effect: runs once per trace, not once per call.
    self._dataset_fn_tracing_count += 1
    return dataset

  # Use the stable `distribute_datasets_from_function` name; the
  # `experimental_` alias is deprecated and the other tests use this one.
  dist_iterator = iter(strategy.distribute_datasets_from_function(dataset_fn))

  @def_function.function
  def step_fn(inputs):
    # Keys [0, 1, 3] look up to [0, 1, -1] (3 is missing -> default_value).
    return math_ops.reduce_sum(inputs)

  def train_steps(iterator, steps):
    for _ in math_ops.range(steps):
      strategy.run(step_fn, args=(next(iterator),))

  train_steps(dist_iterator, steps=5)
  self.assertEqual(self._dataset_fn_tracing_count, 1)
def test_function_compile_with_xla(self, enable_packed_var):
  """A tf.function under a TPU:0 device scope can read a strategy variable."""
  if FLAGS.tpu_use_tfrt:
    self.skipTest(
        "This test triggers _XlaCompile and XlaLaunch which are not "
        "supported in tfrt yet. We should avoid using these kernels on TPU. "
        "However, it is a workaround to support b/129842431. We need more "
        "discussion about how to support it in the long term.")

  strategy = get_tpu_strategy(enable_packed_var)
  with strategy.scope():
    var = variables.Variable(1.0)

  @def_function.function
  def read_plus_one():
    return var.read_value() + 1.0

  with ops.device("/device:TPU:0"):
    self.assertAllEqual(read_plus_one(), 2.0)
def test_sequential_runs(self, enable_packed_var):
  """Feeds the output of a 2-replica run into a 1-replica run."""
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_cluster_resolver.initialize_tpu_system(resolver)

  # Computation replicated to two cores.
  device_assignment = device_assignment_lib.DeviceAssignment.build(
      topology, num_replicas=2)
  strategy = tpu_lib.TPUStrategyV2(
      resolver, experimental_device_assignment=device_assignment)
  strategy._enable_packed_variable_in_eager_mode = enable_packed_var

  # Computation on the 1st core.
  device_assignment2 = device_assignment_lib.DeviceAssignment.build(
      topology, num_replicas=1)
  strategy2 = tpu_lib.TPUStrategyV2(
      resolver, experimental_device_assignment=device_assignment2)
  # Keep both strategies on the parameterized packed-variable setting, as the
  # other multi-strategy tests do.
  strategy2._enable_packed_variable_in_eager_mode = enable_packed_var

  def computation(x):
    return math_ops.square(x)

  @def_function.function
  def train_step():
    outputs = strategy.experimental_local_results(
        strategy.run(computation, args=([2., 2.],)))
    # Replica 0 produced [4., 4.]; square it again on the single core.
    outputs2 = strategy2.run(computation, args=([outputs[0]],))
    return outputs2

  self.assertAllEqual([[16., 16.]], train_step())
def test_device_switch_case(self, enable_packed_var):
  """switch_case alternates between branches pinned to TPU:0 and TPU:1."""
  strategy = get_tpu_strategy(enable_packed_var)
  with strategy.scope():
    a = variables.Variable(1)

  inference_iteration = variables.Variable(-1)

  def inference_fn(x, i):
    return a + x + i

  @def_function.function
  def run_inference(x):

    def on_core(device, i):
      """Returns a branch that runs `inference_fn` under `device`."""

      def branch():
        with ops.device(device):
          return inference_fn(x, i)

      return branch

    branch_fns = {
        0: on_core("/device:TPU:0", 0),
        1: on_core("/device:TPU:1", 1),
    }
    # -1 -> 0 on the first call, 1 on the second: branches alternate.
    next_iteration = inference_iteration.assign_add(1, use_locking=True)
    return control_flow_switch_case.switch_case(next_iteration % 2, branch_fns)

  self.assertAllEqual(2., run_inference(1))  # Use TPU core 0.
  self.assertAllEqual(3., run_inference(1))  # Use TPU core 1.
def test_recover_from_compilation_failures(self, enable_packed_var):
  """A failed TPU computation does not prevent a later valid one."""
  # TODO(b/148150981): Stop skipping this test once recovery works
  # for non-local TPU.
  if FLAGS.tpu:
    self.skipTest("Recovery fails for non-local TPU, see b/148150981")

  # Disable automatic outside compilation. Soft placement is process-wide
  # state, so restore the previous value once this test has finished.
  self.addCleanup(config.set_soft_device_placement,
                  config.get_soft_device_placement())
  config.set_soft_device_placement(False)
  strategy = get_tpu_strategy(enable_packed_var)

  @def_function.function
  def compilation_failure_run():

    def computation():
      # Expected to fail (the OpError is asserted below).
      return random_ops.random_gamma([10], [0.5, 1.5])

    return strategy.run(computation)

  with self.assertRaises(errors.OpError):
    compilation_failure_run()

  @def_function.function
  def good_run():

    def computation():
      return random_ops.random_normal([10])

    return strategy.run(computation)

  good_run()
def test_dynamic_shape_with_outside_compilation_failure(
    self, enable_packed_var):
  """Outside-compiling string inputs of unknown batch size raises an error."""
  # Enable automatic outside compilation. Restore the process-wide setting
  # afterwards so it does not leak into other tests.
  self.addCleanup(config.set_soft_device_placement,
                  config.get_soft_device_placement())
  config.set_soft_device_placement(True)
  strategy = get_tpu_strategy(enable_packed_var)
  # drop_remainder=False leaves the batch dimension statically unknown.
  dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
      2, drop_remainder=False)
  dataset = strategy.experimental_distribute_dataset(dataset)
  iterator = iter(dataset)

  @def_function.function
  def train_fn(iterator):

    def step_fn(inputs):
      input0, input1 = inputs
      return array_ops.size(input0), math_ops.reduce_sum(input1)

    return strategy.experimental_local_results(
        strategy.run(step_fn, args=(next(iterator),)))

  with self.assertRaises(errors.InvalidArgumentError):
    logging.info(train_fn(iterator))
def test_computation_on_subset_cores(self, enable_packed_var):
  """One variable is usable from all-core and single-core strategies."""
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_cluster_resolver.initialize_tpu_system(resolver)

  all_core_strategy = tpu_lib.TPUStrategyV2(resolver)
  all_core_strategy._enable_packed_variable_in_eager_mode = enable_packed_var
  with all_core_strategy.scope():
    v = variables.Variable(0.0, aggregation=variables.VariableAggregation.MEAN)

  # Computation on the 1st core.
  first_core_strategy = tpu_lib.TPUStrategyV2(
      resolver,
      experimental_device_assignment=(
          device_assignment_lib.DeviceAssignment.build(
              topology, num_replicas=1)))
  first_core_strategy._enable_packed_variable_in_eager_mode = (
      enable_packed_var)

  # Computation on the 2nd core.
  second_core_strategy = tpu_lib.TPUStrategyV2(
      resolver,
      experimental_device_assignment=device_assignment_lib.DeviceAssignment(
          topology, [[[0, 0, 0, 1]]]))
  second_core_strategy._enable_packed_variable_in_eager_mode = (
      enable_packed_var)

  @def_function.function
  def train_step():

    def step_fn():
      return v + 1.0

    all_core_strategy.run(step_fn)
    r1 = first_core_strategy.run(step_fn)
    r2 = second_core_strategy.run(step_fn)
    return r1 + r2

  train_step()
  # v is never modified, so each single-core run yields 1.0.
  self.assertAllEqual(2., train_step())
def test_worker_devices_on_subset_cores(self, enable_packed_var):
  """Single-core device assignments expose exactly one worker device."""
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_cluster_resolver.initialize_tpu_system(resolver)

  # Strategy for the 1st core.
  first_core_strategy = tpu_lib.TPUStrategyV2(
      resolver,
      experimental_device_assignment=(
          device_assignment_lib.DeviceAssignment.build(
              topology, num_replicas=1)))
  first_core_strategy._enable_packed_variable_in_eager_mode = (
      enable_packed_var)

  # Strategy for the 2nd core.
  second_core_strategy = tpu_lib.TPUStrategyV2(
      resolver,
      experimental_device_assignment=device_assignment_lib.DeviceAssignment(
          topology, [[[0, 0, 0, 1]]]))
  second_core_strategy._enable_packed_variable_in_eager_mode = (
      enable_packed_var)

  expectations = ((first_core_strategy, "device:TPU:0"),
                  (second_core_strategy, "device:TPU:1"))
  for core_strategy, device_suffix in expectations:
    worker_devices = core_strategy.extended.worker_devices
    self.assertLen(worker_devices, 1)
    self.assertEndsWith(worker_devices[0], device_suffix)
def test_control_output_in_while_body_fn(self, enable_packed_var):
  """Side-effect-only replica functions inside a tf.range loop still run."""
  strategy = get_tpu_strategy(enable_packed_var)
  with strategy.scope():
    counter = variables.Variable(
        0.0, aggregation=variables.VariableAggregation.MEAN)

  @def_function.function
  def train_step():

    def increment():
      counter.assign_add(1)

    for _ in math_ops.range(2):
      strategy.run(increment)

  train_step()
  self.assertEqual(2.0, counter.numpy())
def test_cluster_conditional_with_dynamic_shape(self, enable_packed_var):
  """A data-dependent conditional on a dynamically shaped tensor compiles."""
  strategy = get_tpu_strategy(enable_packed_var)

  @def_function.function
  def train_step():

    def shape_list(tensor):
      # Returns the static shape as a list, replacing each unknown (None)
      # dimension with the matching element of the runtime shape.
      shape = tensor.shape.as_list()

      non_static_indexes = []
      for (index, dim) in enumerate(shape):
        if dim is None:
          non_static_indexes.append(index)

      if not non_static_indexes:
        return shape

      dynamic_shape = array_ops.shape(input=tensor)
      for index in non_static_indexes:
        shape[index] = dynamic_shape[index]

      return shape

    def step_fn(condition):
      # `where` has a value-dependent number of rows.
      where = array_ops.where(condition)
      # Tensor-valued condition: autograph turns this `if` into a graph
      # conditional inside the TPU computation.
      if array_ops.shape(where)[0] > 0:
        tensor_shape = shape_list(where)
        d1 = tensor_shape[0]
        d2 = tensor_shape[1]
        where = array_ops.reshape(where, [d1, d2])
      return where

    return strategy.run(step_fn, args=([True, False, True],))

  outputs = strategy.experimental_local_results(train_step())
  self.assertAllEqual(outputs[0].numpy(), [[0], [2]])
def test_cluster_in_graph_and_while_body_fn(self, enable_packed_var):
  """TPU clusters before and inside a tf.range loop in one tf.function."""
  strategy = get_tpu_strategy(enable_packed_var)

  @def_function.function
  def train_step():

    def init_fn():
      return array_ops.zeros(shape=())

    def step_fn(prev):
      return prev + 1

    running = strategy.run(init_fn)
    for _ in math_ops.range(10):
      running = strategy.run(step_fn, args=(running,))
    return strategy.reduce(reduce_util.ReduceOp.SUM, running, axis=None)

  # Every replica counts to 10; the SUM reduction adds the replicas up.
  total = train_step().numpy().astype(float)
  self.assertEqual(total, strategy.num_replicas_in_sync * 10)
def test_two_clusters_with_same_fn(self, enable_packed_var):
  """Calling one strategy.run-wrapping function twice in a caller works."""
  strategy = get_tpu_strategy(enable_packed_var)

  @def_function.function
  def run_cluster(x):

    def add_one(x):
      return x + 1

    return strategy.run(add_one, (x,))

  @def_function.function
  def run_cluster_twice(x):
    run_cluster(x)
    return run_cluster(x)

  run_cluster_twice(1)
def test_tpu_variable_run_argument(self, enable_packed_var):
  """A TPUDistributedVariable passed to `strategy.run` stays a variable.

  TPUStrategy.run() converts its inputs to Tensors but special-cases
  variables to avoid unintuitive errors; this checks that special case.
  """
  strategy = get_tpu_strategy(enable_packed_var)
  with strategy.scope():
    tpu_variable = variables.Variable(1)

  def replica_step(first_arg, variable):
    del first_arg  # Only present so the variable is not the first argument.
    if variable is None:
      return
    self.assertIsInstance(variable, tpu_values.TPUDistributedVariable)

  @def_function.function
  def step():
    strategy.run(replica_step, args=(2, tpu_variable))

  step()
def test_tpu_run_arg_parsing(self, enable_packed_var):
  """Checks how `strategy.run` matches args/kwargs to replica fn signatures.

  Each `step` below is a fresh tf.function (hence the pylint disable) that
  exercises one signature shape; some are expected to raise.
  """
  strategy = get_tpu_strategy(enable_packed_var)
  with strategy.scope():
    tpu_vars = [variables.Variable(1)]

  # Replica functions with different signature shapes; they do no work.
  def only_star_args(*args):
    del args

  def pos_and_star_args(first_arg, *args):
    del first_arg
    del args

  def named_args(first_arg, second_arg):
    del first_arg
    del second_arg

  def star_args_and_kw_only(*args, kw):
    del args
    del kw

  # pylint:disable=function-redefined
  # Plain values through *args.
  @def_function.function
  def step():
    strategy.run(only_star_args, args=(2,))

  step()

  # All parameters supplied by keyword.
  @def_function.function
  def step():
    strategy.run(named_args, kwargs={"first_arg": 2, "second_arg": 3})

  step()

  # `first_arg` supplied both positionally and by keyword.
  with self.assertRaisesRegex(TypeError, r"got multiple values for argument"):

    @def_function.function
    def step():
      strategy.run(
          named_args, args=(1,), kwargs={
              "first_arg": 2,
              "second_arg": 3
          })

    step()

  # Variables landing in *args are rejected.
  with self.assertRaisesRegex(ValueError,
                              r"cannot handle Variables passed to \*args"):

    @def_function.function
    def step():
      strategy.run(
          only_star_args, args=(
              2,
              tpu_vars,
          ))

    step()

  # Plain values through a positional parameter plus *args.
  @def_function.function
  def step():
    strategy.run(pos_and_star_args, args=(2, 3, 4))

  step()

  # Variables passed through a keyword-only parameter are accepted.
  @def_function.function
  def step():
    strategy.run(star_args_and_kw_only, args=(2, 3), kwargs={"kw": tpu_vars})

  step()

  # Variables in the positional slot of a function that also takes *args.
  with self.assertRaisesRegex(ValueError,
                              r"mix of positional args and \*args"):

    @def_function.function
    def step():
      strategy.run(pos_and_star_args, args=(tpu_vars, 3, 4))

    step()

  # More positional args than the function accepts.
  with self.assertRaisesRegex(ValueError, r"Too many positional arguments"):

    @def_function.function
    def step():
      strategy.run(named_args, args=(2, 3, 4))

    step()

  # A bound tf.function method receiving variables is accepted.
  class DummyClass:

    @def_function.function
    def method(self, arg_1):
      del arg_1

    def step(self):
      strategy.run(self.method, args=(tpu_vars,))

  DummyClass().step()
  # pylint:enable=function-redefined
def test_using_external_variable_inside_tf_function(self, enable_packed_var):
  """A variable created outside strategy.scope() is readable in replicas."""
  strategy = get_tpu_strategy(enable_packed_var)
  num_replicas = strategy.num_replicas_in_sync
  dataset = dataset_ops.Dataset.range(
      num_replicas * 2, output_type=dtypes.float32).batch(num_replicas)
  input_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  v = variables.Variable(2.0)

  @def_function.function
  def train_step(data):

    def computation(inputs):
      return inputs + v

    return strategy.run(computation, args=(data,))

  # The first global batch is [0, ..., num_replicas - 1], one per replica.
  expected_result = [[replica + 2.] for replica in range(num_replicas)]
  self.assertAllEqual(
      expected_result,
      strategy.experimental_local_results(train_step(next(input_iterator))))
# TODO(b/145574622): Remove this test once it is re-enabled in values_test.py.
def test_all_reduce_on_sync_on_read_variable(self, enable_packed_var):
  """All-reducing an ON_READ variable in replica context sums the replicas."""
  strategy = get_tpu_strategy(enable_packed_var)
  num_replicas = strategy.num_replicas_in_sync
  dataset = dataset_ops.Dataset.range(
      num_replicas, output_type=dtypes.float32).batch(
          num_replicas, drop_remainder=True)
  input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

  with strategy.scope():
    w = variables.Variable(
        (0.,),
        shape=(1,),
        trainable=False,
        synchronization=variables.VariableSynchronization.ON_READ,
        aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
  self.assertFalse(w._is_mirrored())

  @def_function.function
  def run(iterator):

    def computation(x):
      # Each replica adds its own input to its local copy of `w`.
      w.assign(x + w)
      return w

    def all_reduce(x):
      replica_ctx = distribute_lib.get_replica_context()
      return replica_ctx.all_reduce("SUM", w) + x

    per_replica = strategy.run(computation, args=(next(iterator),))
    return strategy.experimental_local_results(
        strategy.run(all_reduce, args=(per_replica,)))

  data_sum = sum(range(num_replicas))
  expected_result = [[x + data_sum] for x in range(num_replicas)]
  self.assertAllEqual(expected_result, run(input_iterator))
  # Presumably the ONLY_FIRST_REPLICA read reflects replica 0, whose input
  # was 0.
  self.assertAllEqual((0.,), w.read_value())
def test_run_output_on_device(self, enable_packed_var):
  """Each replica's output stays on the TPU core that produced it."""
  strategy = get_tpu_strategy(enable_packed_var)
  # The expected values and devices below are written out for two replicas,
  # matching the guard used by the composite-tensor tests.
  if strategy.num_replicas_in_sync != 2:
    self.skipTest("Test assumes two replicas.")

  def computation(x):
    return math_ops.square(x)

  @def_function.function
  def train_step():
    outputs = strategy.experimental_local_results(
        strategy.run(computation, args=(2,)))
    return outputs

  results = train_step()
  self.assertAllEqual([4., 4.], results)
  # `backing_device` is a plain string; compare it with assertEqual.
  self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:0",
                   results[0].backing_device)
  self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                   results[1].backing_device)
def test_run_passing_and_returning_nones(self, enable_packed_var):
  """A nested None input comes back unchanged from strategy.run."""
  strategy = get_tpu_strategy(enable_packed_var)

  @def_function.function
  def train_step():

    def identity_fn(x):
      return x

    # Note that this input None is nested.
    return strategy.experimental_local_results(
        strategy.run(identity_fn, args=([1, [2, None]],)))

  first_replica = train_step()[0]
  self.assertAllEqual(1, first_replica[0])
  self.assertAllEqual(2, first_replica[1][0])
  self.assertIsNone(first_replica[1][1])
def test_run_passing_and_returning_empty_list(self, enable_packed_var):
  """An empty list input comes back as an empty list."""
  strategy = get_tpu_strategy(enable_packed_var)

  @def_function.function
  def train_step():

    def identity_fn(x):
      return x

    return strategy.experimental_local_results(
        strategy.run(identity_fn, args=([],)))

  first_replica = train_step()[0]
  self.assertEqual([], first_replica)
def test_run_passing_and_returning_empty_dict(self, enable_packed_var):
  """An empty dict input comes back as an empty dict."""
  strategy = get_tpu_strategy(enable_packed_var)

  @def_function.function
  def train_step():

    def identity_fn(x):
      return x

    return strategy.experimental_local_results(
        strategy.run(identity_fn, args=({},)))

  first_replica = train_step()[0]
  self.assertEqual({}, first_replica)
def test_composite_input_output(self, enable_packed_var):
  """A SparseTensor can flow into strategy.run and be returned from it."""
  strategy = get_tpu_strategy(enable_packed_var)
  if strategy.num_replicas_in_sync != 2:
    self.skipTest("Test assumes two replicas.")

  with strategy.scope():
    table = variables.Variable(
        initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32)

  @def_function.function
  def sparse_lookup(iterator):

    def tpu_function(sparse):
      # Assumes dense_shape is (2, *): sum the looked-up rows per sparse row.
      looked_up = array_ops.gather(table, sparse.values)
      row_ids = sparse.indices[:, 0]
      segment_sum = math_ops.unsorted_segment_sum(looked_up, row_ids, 2)
      return sparse, segment_sum

    per_replica = strategy.run(tpu_function, args=(next(iterator),))
    return nest.map_structure(strategy.experimental_local_results, per_replica)

  def dataset_fn(_):

    def make_sparse(_):
      return sparse_tensor.SparseTensor(
          indices=array_ops.constant([[0, 0], [1, 0], [1, 1]],
                                     dtype=dtypes.int64),
          values=array_ops.constant([0, 0, 1], dtype=dtypes.int32),
          dense_shape=array_ops.constant([2, 2], dtype=dtypes.int64))

    return dataset_ops.Dataset.range(2).map(make_sparse)

  dataset = iter(
      strategy.distribute_datasets_from_function(
          dataset_fn,
          distribute_lib.InputOptions(experimental_fetch_to_device=False)))

  sparse, result = sparse_lookup(dataset)

  # All replicas return identical results. Row 0 holds id 0 -> table[0];
  # row 1 holds ids 0 and 1 -> table[0] + table[1].
  for replica in range(strategy.num_replicas_in_sync):
    replica_sparse = sparse[replica]
    self.assertIsInstance(replica_sparse, sparse_tensor.SparseTensor)
    self.assertAllEqual(replica_sparse.indices, [[0, 0], [1, 0], [1, 1]])
    self.assertAllEqual(replica_sparse.values, [0, 0, 1])
    self.assertAllEqual(replica_sparse.dense_shape, [2, 2])
    self.assertAllEqual(result[replica], [[0.0, 1.0], [3.0, 8.0]])
def test_composite_input_non_flat_output(self, enable_packed_var):
  """A SparseTensor can be returned nested in a dict from strategy.run."""
  strategy = get_tpu_strategy(enable_packed_var)
  if strategy.num_replicas_in_sync != 2:
    self.skipTest("Test assumes two replicas.")

  with strategy.scope():
    table = variables.Variable(
        initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32)

  @def_function.function
  def sparse_lookup(iterator):

    def tpu_function(sparse):
      # Assumes dense_shape is (2, *): sum the looked-up rows per sparse row.
      looked_up = array_ops.gather(table, sparse.values)
      row_ids = sparse.indices[:, 0]
      segment_sum = math_ops.unsorted_segment_sum(looked_up, row_ids, 2)
      return {"sparse": sparse, "segment_sum": segment_sum}

    per_replica = strategy.run(tpu_function, args=(next(iterator),))
    return nest.map_structure(strategy.experimental_local_results, per_replica)

  def dataset_fn(_):

    def make_sparse(_):
      return sparse_tensor.SparseTensor(
          indices=array_ops.constant([[0, 0], [1, 0], [1, 1]],
                                     dtype=dtypes.int64),
          values=array_ops.constant([0, 0, 1], dtype=dtypes.int32),
          dense_shape=array_ops.constant([2, 2], dtype=dtypes.int64))

    return dataset_ops.Dataset.range(2).map(make_sparse)

  dataset = iter(
      strategy.distribute_datasets_from_function(
          dataset_fn,
          distribute_lib.InputOptions(experimental_fetch_to_device=False)))

  output = sparse_lookup(dataset)
  sparse_outputs = output["sparse"]
  segment_sums = output["segment_sum"]

  # All replicas return identical results.
  for replica in range(strategy.num_replicas_in_sync):
    replica_sparse = sparse_outputs[replica]
    self.assertIsInstance(replica_sparse, sparse_tensor.SparseTensor)
    self.assertAllEqual(replica_sparse.indices, [[0, 0], [1, 0], [1, 1]])
    self.assertAllEqual(replica_sparse.values, [0, 0, 1])
    self.assertAllEqual(replica_sparse.dense_shape, [2, 2])
    self.assertAllEqual(segment_sums[replica], [[0.0, 1.0], [3.0, 8.0]])
def test_composite_input_dynamic_shapes_outside_compilation(
    self, enable_packed_var):
  """Outside-compiled sparse lookups handle per-replica dynamic shapes."""
  strategy = get_tpu_strategy(enable_packed_var)
  if strategy.num_replicas_in_sync != 2:
    self.skipTest("Test assumes two replicas.")

  table = variables.Variable(
      initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32)

  @def_function.function
  def sparse_lookup(iterator):

    def tpu_function(sparse):
      # The embedding lookup runs on the host via outside compilation.
      lookup = tpu_replication.outside_compilation(
          embedding_ops.safe_embedding_lookup_sparse, table, sparse)
      return math_ops.reduce_sum(lookup, axis=0)

    return strategy.experimental_local_results(
        strategy.run(tpu_function, args=(next(iterator),)))

  def dataset_fn(_):

    def make_sparse(i):
      # Element i keeps the first 2 + i entries and has dense shape
      # [2, 1 + i], so the two replicas see differently shaped inputs.
      indices = array_ops.constant([[0, 0], [1, 0], [1, 1]],
                                   dtype=dtypes.int64)[0:2 + i]
      values = array_ops.constant([0, 0, 1], dtype=dtypes.int32)[0:2 + i]
      dense_shape = array_ops.concat(
          [array_ops.constant([2], dtype=dtypes.int64),
           array_ops.expand_dims(1 + i, axis=0)],
          axis=0)
      return sparse_tensor.SparseTensor(
          indices=indices, values=values, dense_shape=dense_shape)

    return dataset_ops.Dataset.range(2).map(make_sparse)

  dataset = iter(
      strategy.distribute_datasets_from_function(
          dataset_fn,
          options=distribute_lib.InputOptions(
              experimental_fetch_to_device=False)))

  result = sparse_lookup(dataset)
  self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])
def test_composite_input_with_non_flat_components(self, enable_packed_var):
  """A custom CompositeTensor whose components are nested lists round-trips."""
  strategy = get_tpu_strategy(enable_packed_var)

  # TypeSpec whose component structure is nested: [spec, [spec, spec]].
  class TestCompositeTypeSpec(type_spec.TypeSpec):

    def __init__(self, component_type_spec):
      self._component_type_spec = component_type_spec

    @property
    def value_type(self):
      return TestComposite

    def _to_components(self, value):
      return value.values

    def _from_components(self, components):
      return TestComposite(components[0], components[1][0], components[1][1])

    @property
    def _component_specs(self):
      return [self._component_type_spec,
              [self._component_type_spec, self._component_type_spec]]

    def _serialize(self):
      return (self._component_type_spec,)

  # Composite holding three tensors as [value1, [value2, value3]]; all three
  # are described by the spec of value1.
  class TestComposite(composite_tensor.CompositeTensor):

    def __init__(self, value1, value2, value3):
      self.values = [value1, [value2, value3]]

    @property
    def _type_spec(self):
      return TestCompositeTypeSpec(
          tensor_spec.TensorSpec.from_tensor(self.values[0]))

    def _shape_invariant_to_type_spec(self, shape):
      return [shape, [shape, shape]]

  @def_function.function
  def test_fn(test_composite):

    def tpu_function(composite):
      # Returns the composite itself plus values[0] + mean(values[1]).
      return (composite,
              composite.values[0] + (
                  composite.values[1][0] + composite.values[1][1])/2)

    return nest.map_structure(
        strategy.experimental_local_results,
        strategy.run(tpu_function, args=(test_composite,)))

  a = array_ops.constant([0.1])
  b = array_ops.constant([1.2])
  c = array_ops.constant([-0.4])
  test_composite = TestComposite(a, b, c)

  composite, result = test_fn(test_composite)

  # All replicas return identical results.
  for replica in range(strategy.num_replicas_in_sync):
    self.assertIsInstance(composite[replica], TestComposite)
    self.assertAllEqual(composite[replica].values[0], a)
    self.assertAllEqual(composite[replica].values[1][0], b)
    self.assertAllEqual(composite[replica].values[1][1], c)
    # 0.1 + (1.2 - 0.4) / 2, as rounded in float32.
    self.assertAllEqual(result[replica], array_ops.constant([0.50000006]))
def test_per_device_tracing_of_mirrored_variables(self, enable_packed_var):
  """A tf.function reading a mirrored variable is traced once per device."""
  # A one-element list lets the nested function update the counter without
  # rebinding a name from the enclosing scope.
  trace_count = [0]

  strategy = get_tpu_strategy(enable_packed_var)
  with strategy.scope():
    variable = variables.Variable(0.0)

  @def_function.function
  def add_one():
    trace_count[0] = trace_count[0] + 1
    return math_ops.add(variable, constant_op.constant(1.0))

  @def_function.function
  def update_variable():
    for device in set(strategy.extended.worker_devices):
      with ops.device(device):
        add_one()

  with strategy.scope():
    # Tracing the outer function traces add_one under each device scope.
    update_variable.get_concrete_function()
    self.assertLen(strategy.extended.worker_devices, trace_count[0])
def test_tpu_cancellation_does_not_close_chips(self, enable_packed_var):
  """A cancelled TPU program must leave the chips usable (TFRT runtime only).

  The first `train_steps` call requests two steps from an iterator that holds
  a single global batch, so it fails with `OutOfRangeError`. A second call on
  a fresh variable and a fresh iterator must then run to completion.
  """
  if not FLAGS.tpu_use_tfrt:
    self.skipTest(
        "`tpu_cancellation_closes_chip` only applies to TFRT TPU Runtime.")
  strategy = get_tpu_strategy(enable_packed_var)
  num_replicas = strategy.num_replicas_in_sync
  with strategy.scope():
    x = random_ops.random_normal((10240, 10240))
    y = random_ops.random_normal((10240, 10240))

    def make_dist_iterator():
      # `num_replicas` copies of `y`, batched `num_replicas` at a time, form
      # exactly one global batch, i.e. a single step. Built by one helper so
      # both runs below consume an identical pipeline.
      dataset = dataset_ops.Dataset.from_tensors(y).repeat(
          num_replicas).batch(num_replicas)
      return iter(strategy.experimental_distribute_dataset(dataset))

    v = variables.Variable(array_ops.identity(x))
    dist_iterator = make_dist_iterator()

    @def_function.function
    def train_steps(v, iterator, steps):

      def step_fn(inputs):
        for val in inputs:
          v.assign(math_ops.matmul(v, val))

      for _ in math_ops.range(steps):
        strategy.run(step_fn, args=(next(iterator),))

    with self.assertRaises(errors.OutOfRangeError):
      # The iterator has num_replicas/num_replicas = 1 step only.
      train_steps(v, dist_iterator, 2)

    # If TPU chips are not closed we can run the function on TPU again.
    w = variables.Variable(array_ops.identity(x))
    train_steps(w, make_dist_iterator(), 1)
def test_tpu_hardware_feature(self, enable_packed_var):
  """The strategy reports an embedding feature of the expected type."""
  hardware = get_tpu_strategy(enable_packed_var).extended.tpu_hardware_feature
  expected_type = tpu_hardware_feature.HardwareFeature.EmbeddingFeature
  self.assertIsInstance(hardware.embedding_feature, expected_type)
def test_get_tpu_cluster_resolver(self, enable_packed_var):
  """The strategy exposes the cluster resolver it was built with."""
  resolver = get_tpu_strategy(enable_packed_var).cluster_resolver
  self.assertIsNotNone(resolver)
def test_replica_order_for_distribute_datasets_from_function(
    self, enable_packed_var
):
  """Replica 0 sees element 1 and replica 1 sees element 0 (dataset_fn path)."""
  del enable_packed_var  # Unused: _test_replica_order builds its own strategy.

  def _create_dataset(strategy):

    def dataset_fn(ctx):
      del ctx
      return dataset_ops.Dataset.range(2)

    return strategy.distribute_datasets_from_function(dataset_fn)

  per_replica = self._test_replica_order(_create_dataset).values
  self.assertLen(per_replica, 2)
  self.assertEqual(1, per_replica[0].numpy())
  self.assertEqual(0, per_replica[1].numpy())
def test_replica_order_for_experimental_distribute_dataset(
    self, enable_packed_var
):
  """Replica 0 sees element 1 and replica 1 sees element 0 (dataset path)."""
  del enable_packed_var  # Unused: _test_replica_order builds its own strategy.

  def _create_dataset(strategy):
    # One global batch of two elements, split across the two replicas.
    return strategy.experimental_distribute_dataset(
        dataset_ops.Dataset.range(2).batch(2))

  per_replica = self._test_replica_order(_create_dataset).values
  self.assertLen(per_replica, 2)
  self.assertEqual(1, per_replica[0].numpy())
  self.assertEqual(0, per_replica[1].numpy())
def _test_replica_order(self, create_dist_dataset_fn):
  """Fetches one element per replica from a strategy with swapped cores.

  Args:
    create_dist_dataset_fn: callable that takes the strategy and returns a
      distributed dataset.

  Returns:
    The per-replica result of calling `next()` on an iterator over that
    dataset inside a `tf.function`.
  """
  tf2.enable()
  cluster_resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(cluster_resolver)
  tpu_topology = tpu_cluster_resolver.initialize_tpu_system(cluster_resolver)
  # Replica 0 is placed on core (0, 0, 0, 1) and replica 1 on (0, 0, 0, 0),
  # i.e. the two cores are swapped relative to the default order.
  swapped_assignment = device_assignment_lib.DeviceAssignment(
      tpu_topology, core_assignment=[[[0, 0, 0, 1]], [[0, 0, 0, 0]]])
  strategy = tpu_lib.TPUStrategyV2(
      cluster_resolver, experimental_device_assignment=swapped_assignment)
  dist_iterator = iter(create_dist_dataset_fn(strategy))

  @def_function.function
  def fetch_next(it):
    return next(it)

  return fetch_next(dist_iterator)
@test_util.with_eager_op_as_function
class TPUStrategyDataPrefetchTest(test.TestCase):
  """Tests input placement and composite-tensor rejection for TPUStrategy."""

  @staticmethod
  def _range_dataset(strategy):
    """Returns two global batches of `num_replicas_in_sync` float32 values."""
    return dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)

  def _first_value_device_type(self, strategy, options=None):
    """Returns the device type holding replica 0's value of the first batch.

    Args:
      strategy: the TPU strategy to distribute `_range_dataset` with.
      options: optional `distribute_lib.InputOptions`; `None` keeps the
        strategy's default input placement.
    """
    dist_dataset = strategy.experimental_distribute_dataset(
        self._range_dataset(strategy), options=options)
    first_batch = next(iter(dist_dataset))
    return tf_device.DeviceSpec.from_string(
        first_batch.values[0].device).device_type

  @staticmethod
  def _sparse_dataset(batch_size):
    """Returns a repeated, batched dataset of a small SparseTensor."""
    # Values here aren't important.
    dataset = dataset_ops.Dataset.from_tensors(
        sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                                   values=[1, 2, 3],
                                   dense_shape=[2, 2]))
    return dataset.repeat().batch(batch_size)

  @staticmethod
  def _ragged_dataset(batch_size):
    """Returns a repeated, batched dataset of a small RaggedTensor."""
    # Values here aren't important.
    dataset = dataset_ops.Dataset.from_tensors(
        ragged_tensor.RaggedTensor.from_row_splits(
            values=[1, 2, 3],
            row_splits=[0, 2, 3]))
    return dataset.repeat().batch(batch_size)

  def test_prefetch_to_device_default(self):
    strategy = get_tpu_strategy()
    # Check default, should prefetch to TPU.
    self.assertEqual(self._first_value_device_type(strategy), "TPU")

  def test_prefetch_to_device_tpu(self):
    strategy = get_tpu_strategy()
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=True)
    self.assertEqual(
        self._first_value_device_type(strategy, options=input_options), "TPU")

  def test_prefetch_to_device_cpu(self):
    strategy = get_tpu_strategy()
    # Should be CPU when prefetch_to_device is False.
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=False)
    self.assertEqual(
        self._first_value_device_type(strategy, options=input_options), "CPU")

  def test_prefetch_to_device_sparse_dataset(self):
    strategy = get_tpu_strategy()
    dataset = self._sparse_dataset(strategy.num_replicas_in_sync)
    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.experimental_distribute_dataset(dataset))

  def test_prefetch_to_device_ragged_dataset(self):
    strategy = get_tpu_strategy()
    dataset = self._ragged_dataset(strategy.num_replicas_in_sync)
    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.experimental_distribute_dataset(dataset))

  def test_prefetch_to_device_sparse_dataset_fn(self):
    strategy = get_tpu_strategy()

    def dataset_fn(ctx):
      del ctx
      return self._sparse_dataset(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.distribute_datasets_from_function(dataset_fn))

  def test_prefetch_to_device_ragged_dataset_fn(self):
    strategy = get_tpu_strategy()

    def dataset_fn(ctx):
      del ctx
      return self._ragged_dataset(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.distribute_datasets_from_function(dataset_fn))

  def test_create_iterator_on_device(self):
    # Passes as long as creating an anonymous iterator placed on TPU inside
    # a tf.function does not raise.
    @def_function.function
    def create_iter():
      with ops.device("/device:TPU:0"):
        return gen_dataset_ops.anonymous_iterator_v3(
            output_types=[dtypes.float32], output_shapes=[[]])

    create_iter()
@test_util.with_eager_op_as_function
class TPUStrategyDistributionTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase):
  """Runs the shared distribution-strategy test helpers against TPUStrategy."""

  def _range_input_fn_and_expected_values(self):
    """Builds the input_fn shared by the `make_input_fn_*` tests.

    Returns:
      An `(input_fn, expected_values)` tuple. `input_fn` wraps
      `Dataset.range(10)` via `_input_fn_to_test_input_context` with the
      expected context (2 replicas in sync, 1 input pipeline, pipeline id 0);
      `expected_values` is `[[0, 1], [2, 3], ..., [8, 9]]`.
    """

    def dataset_fn():
      return dataset_ops.Dataset.range(10)

    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    expected_values = [[i, i + 1] for i in range(0, 10, 2)]
    return input_fn, expected_values

  def test_update_config_proto(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    tpu_cluster_resolver.initialize_tpu_system(resolver)
    strategy = tpu_lib.TPUStrategyV2(resolver)

    config_proto = config_pb2.ConfigProto()
    cluster_spec = server_lib.ClusterSpec({"worker": ["fake1", "fake2"]})
    with test.mock.patch.object(
        resolver, "cluster_spec", return_value=cluster_spec):
      new_config = strategy.update_config_proto(config_proto)

    # Verify cluster_def.
    self.assertProtoEquals(cluster_spec.as_cluster_def(),
                           new_config.cluster_def)

    # Verify isolate_session_state
    self.assertTrue(new_config.isolate_session_state)

  def test_make_input_fn_iterable(self):
    distribution = get_tpu_strategy()
    input_fn, expected_values = self._range_input_fn_and_expected_values()
    self._test_input_fn_iterable(distribution, input_fn, expected_values)

  def test_make_input_fn_iterator(self):
    distribution = get_tpu_strategy()
    input_fn, expected_values = self._range_input_fn_and_expected_values()
    iterator = distribution.make_input_fn_iterator(input_fn)
    self._test_input_fn_iterator(
        iterator,
        distribution.extended.worker_devices,
        expected_values)

  def test_num_replicas_in_sync(self):
    strategy = get_tpu_strategy()
    self.assertEqual(2, strategy.num_replicas_in_sync)

  def test_call_and_merge_exceptions(self):
    strategy = get_tpu_strategy()
    self._test_call_and_merge_exceptions(strategy)

  def test_numpy_dataset(self):
    strategy = get_tpu_strategy()
    self._test_numpy_dataset(strategy, run_in_function=True)

  def test_global_step_update(self):
    strategy = get_tpu_strategy()
    self._test_global_step_update(strategy)

  def test_run(self):
    strategy = get_tpu_strategy()
    self._test_run(strategy, run_in_function=True)

  def test_summary_for_replica_zero_only(self):
    strategy = get_tpu_strategy()
    self._test_summary_for_replica_zero_only(strategy)

  def test_all_reduce_sum(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_sum(strategy, run_in_function=True)

  def test_all_reduce_sum_gradients(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_sum_gradients(strategy, run_in_function=True)

  def test_all_reduce_sum_gradient_tape(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_sum_gradient_tape(strategy, run_in_function=True)

  def test_all_reduce_mean(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_mean(strategy, run_in_function=True)

  def test_all_reduce_mean_gradients(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_mean_gradients(strategy, run_in_function=True)

  def test_all_reduce_mean_gradient_tape(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_mean_gradient_tape(strategy, run_in_function=True)

  def test_reduce(self):
    strategy = get_tpu_strategy()

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices([2., 3.]))

    self.evaluate(inputs.initialize())
    per_replica_outputs = strategy.run(
        def_function.function(math_ops.square), args=(next(inputs),))

    with strategy.scope():
      mean = strategy.reduce(reduce_util.ReduceOp.MEAN, per_replica_outputs,
                             axis=None)
    # Replicas compute 2**2 = 4 and 3**2 = 9, whose mean is 6.5.
    self.assertEqual(6.5, self.evaluate(mean))

  def test_constraint(self):
    strategy = get_tpu_strategy()

    with strategy.scope():
      # The constraint maps any value to 1.
      variable = variables.Variable(initial_value=2.,
                                    constraint=lambda x: 0. * x + 1.)
    self.assertEqual(2, variable.value().numpy())

    @def_function.function
    def update_variable():
      variable.assign_add(1)
      variable.assign(variable.constraint(variable))

    update_variable()
    # 2 + 1 = 3 is then clamped back to 1 by the constraint.
    self.assertEqual(1, variable.value().numpy())

  def test_trainable_variables(self):
    strategy = get_tpu_strategy()
    self._test_trainable_variable(strategy)
@test_util.with_eager_op_as_function
class DeviceAssignmentTest(test.TestCase):
  """Tests for `DeviceAssignment` and its use by `TPUStrategyV2`."""

  def _initialize_tpu(self):
    """Connects to the TPU cluster and initializes the TPU system.

    Returns:
      A `(resolver, topology)` tuple.
    """
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
    return resolver, topology

  def _assert_single_core_on_task0(self, device_assignment):
    """Asserts one replica with one core, on TPU:0 / CPU:0 of task 0."""
    self.assertAllEqual([[[0, 0, 0, 0]]], device_assignment.core_assignment)
    self.assertEqual(1, device_assignment.num_cores_per_replica)
    self.assertEqual(1, device_assignment.num_replicas)
    self.assertEqual("/task:0/device:TPU:0", device_assignment.tpu_device())
    self.assertEqual("/task:0/device:CPU:0", device_assignment.host_device())

  def test_core_assignment(self):
    _, topology = self._initialize_tpu()
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0]]])
    self._assert_single_core_on_task0(device_assignment)

  def test_device_assignment_strategy_properties(self):
    resolver, topology = self._initialize_tpu()
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0]]])
    strategy = tpu_lib.TPUStrategyV2(
        resolver,
        experimental_device_assignment=device_assignment)
    self.assertEqual(1, strategy.extended.num_hosts)
    self.assertEqual(1, strategy.num_replicas_in_sync)
    self.assertEqual(1, strategy.extended.num_replicas_per_host)

  def test_device_assignment_constants(self):
    _, topology = self._initialize_tpu()
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology,
        core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)
    self._assert_single_core_on_task0(device_assignment)

  def test_variables_mismatched_device_assignment(self):
    resolver, topology = self._initialize_tpu()

    strategy0 = tpu_lib.TPUStrategyV2(resolver)
    self.assertEqual(
        ("/job:localhost/replica:0/task:0/device:TPU:0",
         "/job:localhost/replica:0/task:0/device:TPU:1"),
        strategy0.extended.worker_devices)

    with strategy0.scope():
      v = variables.Variable(1.)

    # Give the copy on replica 1 a distinct value so ordering is observable.
    v1_assign_op = strategy0.experimental_local_results(v)[1].assign(42.)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(v1_assign_op)
      self.assertAllEqual([1., 42.],
                          self.evaluate(
                              strategy0.experimental_local_results(v)))

    # Second strategy has devices reversed relative to the first.
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 1]], [[0, 0, 0, 0]]])
    strategy1 = tpu_lib.TPUStrategyV2(
        resolver,
        experimental_device_assignment=device_assignment)
    self.assertEqual(
        ("/job:localhost/replica:0/task:0/device:TPU:1",
         "/job:localhost/replica:0/task:0/device:TPU:0"),
        strategy1.extended.worker_devices)

    v_read = strategy1.run(def_function.function(v.read_value))

    with self.cached_session():
      self.assertAllEqual([42., 1.],
                          self.evaluate(
                              strategy0.experimental_local_results(v_read)))
# Run all test cases in this module when executed directly.
if __name__ == "__main__":
  test.main()