Break the interdependency between tpu/tpu_strategy_util.py, distribute/cluster_resolver/tpu/tpu_cluster_resolver.py, and eager/remote.py.
Add exported wrappers for initialize_tpu_system and shutdown_tpu_system in tpu_cluster_resolver.py that delegate to new implementation functions in tpu_strategy_util.py (initialize_tpu_system_impl and shutdown_tpu_system_impl), passing them a reference to the TPUClusterResolver class. Add tests for this.
Replace the isinstance check against TPUClusterResolver in remote.py with an isinstance check against ClusterResolver plus a hasattr check for tpu_hardware_feature.
Make all inline imports regular imports and add all missing BUILD deps. Update references to the new location of initialize_tpu_system.
PiperOrigin-RevId: 537905952
GitOrigin-RevId: 6088a22ddb8fc299d18a34596b91ccf077b86b17
Change-Id: I06951d37f6f70a34df92adf7d80e3765078f2ce0
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 6823cb3..e10303e 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -2646,7 +2646,6 @@
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/platform:flags",
- "//tensorflow/python/tpu:tpu_lib",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/compiler/tests/giant_const_op_test.py b/tensorflow/compiler/tests/giant_const_op_test.py
index 014b9d5..9a73a95 100644
--- a/tensorflow/compiler/tests/giant_const_op_test.py
+++ b/tensorflow/compiler/tests/giant_const_op_test.py
@@ -25,7 +25,6 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import flags
-from tensorflow.python.tpu import tpu_strategy_util
FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
@@ -45,7 +44,7 @@
def get_tpu_strategy():
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- tpu_strategy_util.initialize_tpu_system(resolver)
+ tpu_cluster_resolver.initialize_tpu_system(resolver)
return tpu_lib.TPUStrategyV2(resolver)
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 4f164c1..8954b31 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -467,11 +467,11 @@
"//tensorflow/python:training",
"//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
"//tensorflow/python/distribute/cluster_resolver:tfconfig_cluster_resolver_py",
+ "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/python/distribute/v1:input_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/framework:device",
"//tensorflow/python/platform:tf_logging",
- "//tensorflow/python/tpu:tpu_strategy_util",
"//tensorflow/python/util:deprecation",
"//tensorflow/python/util:tf_export",
],
@@ -814,7 +814,6 @@
"//tensorflow/python/eager:test",
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/platform:flags",
- "//tensorflow/python/tpu:tpu_strategy_util",
],
)
@@ -1164,7 +1163,6 @@
"//tensorflow/python/framework:test_lib",
"//tensorflow/python/platform:flags",
"//tensorflow/python/tpu:device_assignment",
- "//tensorflow/python/tpu:tpu_lib",
"//tensorflow/python/training:server_lib",
"//tensorflow/python/util:tf_export",
],
@@ -1783,7 +1781,6 @@
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
- "//tensorflow/python/tpu:tpu_lib",
"//tensorflow/python/util:variable_utils",
"@absl_py//absl/testing:parameterized",
],
@@ -2141,10 +2138,10 @@
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
+ "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/python/distribute/v1:input_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/framework:config",
- "//tensorflow/python/tpu:tpu_lib",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
@@ -2189,10 +2186,10 @@
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
+ "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/python/distribute/v1:input_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/framework:config",
- "//tensorflow/python/tpu:tpu_strategy_util",
"//tensorflow/python/training:server_lib",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -2637,7 +2634,6 @@
"//tensorflow/python/module",
"//tensorflow/python/platform:flags",
"//tensorflow/python/tpu:device_assignment",
- "//tensorflow/python/tpu:tpu_lib",
"//tensorflow/python/tpu:tpu_replication",
"@absl_py//absl/testing:parameterized",
],
diff --git a/tensorflow/python/distribute/cluster_resolver/tpu/BUILD b/tensorflow/python/distribute/cluster_resolver/tpu/BUILD
index 787bb4a..04270cf 100644
--- a/tensorflow/python/distribute/cluster_resolver/tpu/BUILD
+++ b/tensorflow/python/distribute/cluster_resolver/tpu/BUILD
@@ -22,9 +22,11 @@
"//tensorflow/core/protobuf/tpu:topology_proto_py",
"//tensorflow/python:training_server_lib",
"//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
+ "//tensorflow/python/eager:remote",
"//tensorflow/python/framework:config",
"//tensorflow/python/framework:errors",
"//tensorflow/python/platform:tf_logging",
+ "//tensorflow/python/tpu:tpu_strategy_util",
"//tensorflow/python/tpu:tpu_system_metadata",
"//tensorflow/python/tpu/client",
"//tensorflow/python/util:compat",
diff --git a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py
index 6e40bac..58ec2ba 100644
--- a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py
+++ b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py
@@ -18,10 +18,12 @@
import re
from tensorflow.core.protobuf.tpu import topology_pb2
-from tensorflow.python.distribute.cluster_resolver import cluster_resolver
+from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
+from tensorflow.python.eager import remote
from tensorflow.python.framework import config as framework_config
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
@@ -53,7 +55,44 @@
'DeviceDetails', ['device_map', 'total_cores'])
-class TPUClusterResolver(cluster_resolver.ClusterResolver):
+def initialize_tpu_system(cluster_resolver=None):
+ """Initialize the TPU devices.
+
+ Args:
+ cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
+ which provides information about the TPU cluster.
+ Returns:
+ The tf.tpu.Topology object for the topology of the TPU cluster. If called
+ inside tf.function, it returns the serialized topology object instead.
+
+ Raises:
+ RuntimeError: If running inside a tf.function.
+ NotFoundError: If no TPU devices found in eager mode.
+ """
+ return tpu_strategy_util.initialize_tpu_system_impl(
+ cluster_resolver, TPUClusterResolver)
+
+
+def shutdown_tpu_system(cluster_resolver=None):
+ """Shuts down the TPU devices.
+
+ This will clear all caches, even those that are maintained through sequential
+ calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
+ cache.
+
+ Args:
+ cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
+ which provides information about the TPU cluster.
+
+ Raises:
+ RuntimeError: If no TPU devices found for eager execution or if run in a
+ tf.function.
+ """
+ tpu_strategy_util.shutdown_tpu_system_impl(
+ cluster_resolver, TPUClusterResolver)
+
+
+class TPUClusterResolver(cluster_resolver_lib.ClusterResolver):
"""Cluster Resolver for Google Cloud TPUs.
This is an implementation of cluster resolvers for the Google Cloud TPU
@@ -104,10 +143,8 @@
NotFoundError: If no TPU devices found in eager mode.
"""
resolver = TPUClusterResolver(tpu, zone, project)
- from tensorflow.python.eager import remote # pylint: disable=g-import-not-at-top
remote.connect_to_cluster(resolver)
- from tensorflow.python.tpu import tpu_strategy_util # pylint: disable=g-import-not-at-top
- tpu_strategy_util.initialize_tpu_system(resolver)
+ tpu_strategy_util.initialize_tpu_system_impl(resolver)
return resolver
@staticmethod
@@ -266,7 +303,7 @@
if not job_tasks:
raise ValueError('No TPUs with the specified names exist.')
master = job_tasks[0]
- return cluster_resolver.format_master_url(master, 'grpc')
+ return cluster_resolver_lib.format_master_url(master, 'grpc')
else:
return ''
@@ -384,7 +421,7 @@
while True:
try:
device_details = TPUClusterResolver._get_device_dict_and_cores(
- cluster_resolver.get_accelerator_devices(
+ cluster_resolver_lib.get_accelerator_devices(
self.master(), config_proto=config_proto))
break
except errors.DeadlineExceededError:
diff --git a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
index 0fdaf79..4aebbf2 100644
--- a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
+++ b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
@@ -15,8 +15,12 @@
"""Shim so that direct imports of tpu_cluster_resolver get correct symbols.
"""
+from tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver import initialize_tpu_system
from tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver import is_running_in_gce # pylint: disable=unused-import
+from tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver import shutdown_tpu_system
from tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.util.tf_export import tf_export
tf_export('distribute.cluster_resolver.TPUClusterResolver')(TPUClusterResolver)
+tf_export('tpu.experimental.initialize_tpu_system')(initialize_tpu_system)
+tf_export('tpu.experimental.shutdown_tpu_system')(shutdown_tpu_system)
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index 99c796d..b87ded9 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -36,6 +36,7 @@
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
+from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
@@ -45,7 +46,6 @@
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.trackable import base
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -540,7 +540,7 @@
local_devices, local_device_type = self._initialize_local_devices(
cluster_resolver, self._worker_device)
if local_device_type == "TPU":
- tpu_strategy_util.initialize_tpu_system()
+ tpu_cluster_resolver.initialize_tpu_system()
self._collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=1 + self._collective_key_base)
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
index 4ad7a8e..9721a53 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
@@ -35,6 +35,7 @@
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
+from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.framework import config as tf_config
@@ -52,7 +53,6 @@
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import server_lib
@@ -76,7 +76,7 @@
if num_tpus is None:
num_tpus = context.context().list_physical_devices('TPU')
if num_tpus:
- tpu_strategy_util.initialize_tpu_system()
+ tpu_cluster_resolver.initialize_tpu_system()
if cluster_spec and task_type and task_id is not None:
cluster_resolver = cluster_resolver_lib.SimpleClusterResolver(
diff --git a/tensorflow/python/distribute/parallel_device/BUILD b/tensorflow/python/distribute/parallel_device/BUILD
index fc148be..e1afbf9 100644
--- a/tensorflow/python/distribute/parallel_device/BUILD
+++ b/tensorflow/python/distribute/parallel_device/BUILD
@@ -60,6 +60,7 @@
"//tensorflow/python/checkpoint",
"//tensorflow/python/checkpoint:checkpoint_management",
"//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
@@ -69,7 +70,6 @@
"//tensorflow/python/module",
"//tensorflow/python/saved_model:load",
"//tensorflow/python/saved_model:save",
- "//tensorflow/python/tpu:tpu_strategy_util",
"//tensorflow/python/util:nest",
"@absl_py//absl/testing:parameterized",
],
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device_test.py b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
index d1aee83..bc78a21 100644
--- a/tensorflow/python/distribute/parallel_device/parallel_device_test.py
+++ b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
@@ -20,6 +20,7 @@
from tensorflow.python.checkpoint import checkpoint as tracking
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -39,7 +40,6 @@
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
-from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.util import nest
# When running collectives asynchronously, we need to give each parallel device
@@ -97,7 +97,7 @@
ctx = context.context()
if ctx.list_physical_devices("TPU"):
self.device_type = "TPU"
- tpu_strategy_util.initialize_tpu_system()
+ tpu_cluster_resolver.initialize_tpu_system()
elif ctx.list_physical_devices("GPU"):
self.device_type = "GPU"
gpus = ctx.list_physical_devices(self.device_type)
diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py
index 4a2bd81..c73d180 100644
--- a/tensorflow/python/distribute/strategy_combinations.py
+++ b/tensorflow/python/distribute/strategy_combinations.py
@@ -39,7 +39,6 @@
from tensorflow.python.framework import test_util as framework_test_util
from tensorflow.python.platform import flags
from tensorflow.python.tpu import device_assignment as device_assignment_lib
-from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import server_lib
from tensorflow.python.util.tf_export import tf_export
@@ -105,7 +104,7 @@
if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
remote.connect_to_cluster(resolver)
_did_connect_to_cluster = True
- _topology = tpu_strategy_util.initialize_tpu_system(resolver)
+ _topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
device_assignment = None
if use_single_core:
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index ff09791..05579c9 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -62,7 +62,6 @@
from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_hardware_feature
-from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import deprecation
@@ -1239,7 +1238,7 @@
This is a private method only to be used by Estimator. Other frameworks
should directly be calling `tf.tpu.experimental.initialize_tpu_system`
"""
- tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)
+ tpu_cluster_resolver_lib.initialize_tpu_system(self._tpu_cluster_resolver)
def _create_variable(self, next_creator, **kwargs):
"""Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
diff --git a/tensorflow/python/distribute/tpu_strategy_compilation_test.py b/tensorflow/python/distribute/tpu_strategy_compilation_test.py
index 63d54f1..defc9c2 100644
--- a/tensorflow/python/distribute/tpu_strategy_compilation_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_compilation_test.py
@@ -21,7 +21,6 @@
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import flags
-from tensorflow.python.tpu import tpu_strategy_util
FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
@@ -41,7 +40,7 @@
def get_tpu_strategy():
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- tpu_strategy_util.initialize_tpu_system(resolver)
+ tpu_cluster_resolver.initialize_tpu_system(resolver)
strategy = tpu_lib.TPUStrategyV2(resolver)
return strategy
diff --git a/tensorflow/python/distribute/tpu_strategy_model_parallelism_test.py b/tensorflow/python/distribute/tpu_strategy_model_parallelism_test.py
index 13c9e85..b6801e5c 100644
--- a/tensorflow/python/distribute/tpu_strategy_model_parallelism_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_model_parallelism_test.py
@@ -47,7 +47,6 @@
from tensorflow.python.platform import flags
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu_replication
-from tensorflow.python.tpu import tpu_strategy_util
FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
@@ -67,7 +66,7 @@
def get_tpu_strategy(enable_spmd=False):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- topology = tpu_strategy_util.initialize_tpu_system(resolver)
+ topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
num_replicas = resolver.get_tpu_system_metadata().num_cores // 2
device_assignment = device_assignment_lib.DeviceAssignment.build(
topology, num_replicas=num_replicas, computation_shape=[1, 1, 1, 2])
diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py
index 115ae1e..ab3711e 100644
--- a/tensorflow/python/distribute/tpu_strategy_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_test.py
@@ -79,7 +79,7 @@
def get_tpu_strategy(enable_packed_var=False):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- tpu_strategy_util.initialize_tpu_system(resolver)
+ tpu_cluster_resolver.initialize_tpu_system(resolver)
strategy = tpu_lib.TPUStrategyV2(resolver)
strategy._enable_packed_variable_in_eager_mode = enable_packed_var
return strategy
@@ -176,12 +176,30 @@
def test_multiple_initialize_system(self):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- tpu_strategy_util.initialize_tpu_system(resolver)
+ tpu_cluster_resolver.initialize_tpu_system(resolver)
with test.mock.patch.object(logging, "warning") as mock_log:
- tpu_strategy_util.initialize_tpu_system(resolver)
+ tpu_cluster_resolver.initialize_tpu_system(resolver)
self.assertRegex(str(mock_log.call_args), "already been initialized")
+ def test_initialize_tpu_system_impl_input(self):
+ resolver = get_tpu_cluster_resolver()
+ with self.assertRaisesRegex(
+ TypeError,
+ r"tpu_cluster_resolver_cls is not"
+ r" tf.distribute.cluster_resolver.TPUClusterResolver."):
+ tpu_strategy_util.initialize_tpu_system_impl(
+ resolver, tpu_cluster_resolver_cls=None)
+
+ def test_shutdown_tpu_system_impl_input(self):
+ resolver = get_tpu_cluster_resolver()
+ with self.assertRaisesRegex(
+ TypeError,
+ r"tpu_cluster_resolver_cls is not"
+ r" tf.distribute.cluster_resolver.TPUClusterResolver."):
+ tpu_strategy_util.shutdown_tpu_system_impl(
+ resolver, tpu_cluster_resolver_cls=None)
+
def test_tpu_tf_function_same_device(self):
with ops.device("/device:TPU:0"):
a = variables.Variable(1)
@@ -332,7 +350,7 @@
def test_sequential_runs(self, enable_packed_var):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- topology = tpu_strategy_util.initialize_tpu_system(resolver)
+ topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
# Computation replicated to all cores.
device_assignment = device_assignment_lib.DeviceAssignment.build(
topology, num_replicas=2)
@@ -443,7 +461,7 @@
def test_computation_on_subset_cores(self, enable_packed_var):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- topology = tpu_strategy_util.initialize_tpu_system(resolver)
+ topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
all_core_strategy = tpu_lib.TPUStrategyV2(resolver)
all_core_strategy._enable_packed_variable_in_eager_mode = enable_packed_var
@@ -484,7 +502,7 @@
def test_worker_devices_on_subset_cores(self, enable_packed_var):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- topology = tpu_strategy_util.initialize_tpu_system(resolver)
+ topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
# Strategy for the 1st core.
device_assignment = device_assignment_lib.DeviceAssignment.build(
@@ -1180,7 +1198,7 @@
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- topology = tpu_strategy_util.initialize_tpu_system(resolver)
+ topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
device_assignment = device_assignment_lib.DeviceAssignment(
topology, core_assignment=[[[0, 0, 0, 1]], [[0, 0, 0, 0]]]
)
@@ -1317,7 +1335,7 @@
def test_update_config_proto(self):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- tpu_strategy_util.initialize_tpu_system(resolver)
+ tpu_cluster_resolver.initialize_tpu_system(resolver)
strategy = tpu_lib.TPUStrategyV2(resolver)
config_proto = config_pb2.ConfigProto()
@@ -1449,7 +1467,7 @@
def test_core_assignment(self):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- topology = tpu_strategy_util.initialize_tpu_system(resolver)
+ topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
device_assignment = device_assignment_lib.DeviceAssignment(
topology, core_assignment=[[[0, 0, 0, 0]]])
self.assertAllEqual([[[0, 0, 0, 0]]], device_assignment.core_assignment)
@@ -1461,7 +1479,7 @@
def test_device_assignment_strategy_properties(self):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- topology = tpu_strategy_util.initialize_tpu_system(resolver)
+ topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
device_assignment = device_assignment_lib.DeviceAssignment(
topology, core_assignment=[[[0, 0, 0, 0]]])
strategy = tpu_lib.TPUStrategyV2(
@@ -1474,7 +1492,7 @@
def test_device_assignment_constants(self):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- topology = tpu_strategy_util.initialize_tpu_system(resolver)
+ topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
device_assignment = device_assignment_lib.DeviceAssignment(
topology,
core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)
@@ -1487,7 +1505,7 @@
def test_variables_mismatched_device_assignment(self):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- topology = tpu_strategy_util.initialize_tpu_system(resolver)
+ topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
strategy0 = tpu_lib.TPUStrategyV2(resolver)
self.assertEqual(
diff --git a/tensorflow/python/distribute/vars_test.py b/tensorflow/python/distribute/vars_test.py
index e1472fe..5dd2c5a 100644
--- a/tensorflow/python/distribute/vars_test.py
+++ b/tensorflow/python/distribute/vars_test.py
@@ -42,7 +42,6 @@
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables as variables_lib
-from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.util import variable_utils
@@ -971,7 +970,7 @@
for aggregation in aggregations:
if strategy_test_lib.is_tpu_strategy(distribution):
resolver = tpu_cluster_resolver.TPUClusterResolver("")
- tpu_strategy_util.initialize_tpu_system(resolver)
+ tpu_cluster_resolver.initialize_tpu_system(resolver)
with distribution.scope():
v = variable_v1.VariableV1(
0.,
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index ce25754..2499f60 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -1110,7 +1110,6 @@
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python/distribute:device_util",
"//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
- "//tensorflow/python/distribute/cluster_resolver/tpu:tpu_cluster_resolver_py",
"//tensorflow/python/framework:ops",
"//tensorflow/python/platform:remote_utils",
"//tensorflow/python/training:server_lib",
@@ -1234,7 +1233,6 @@
":remote",
"//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/python/framework:config",
- "//tensorflow/python/tpu:tpu_strategy_util",
"@absl_py//absl/flags",
"@absl_py//absl/testing:absltest",
],
diff --git a/tensorflow/python/eager/remote.py b/tensorflow/python/eager/remote.py
index 3f2b093..2f70bfc 100644
--- a/tensorflow/python/eager/remote.py
+++ b/tensorflow/python/eager/remote.py
@@ -22,7 +22,6 @@
from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
-from tensorflow.python.distribute.cluster_resolver.tpu import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.platform import remote_utils
@@ -184,8 +183,10 @@
service_leader = ""
# Maybe enable coordination service for the communication protocol
# TODO(b/243839559): Fix UPTC + Coordination service crashing
- if isinstance(cluster_spec_or_resolver,
- tpu_cluster_resolver.TPUClusterResolver):
+ # Check if cluster_spec_or_resolver is an instance of
+ # tpu_cluster_resolver.TPUClusterResolver
+ if (isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver)
+ and hasattr(cluster_spec_or_resolver, "tpu_hardware_feature")):
is_uptc_sess = ".uptc-worker." in cluster_spec_or_resolver.master()
service_type = remote_utils.coordination_service_type(
protocol, is_uptc_sess)
diff --git a/tensorflow/python/eager/remote_cloud_tpu_test.py b/tensorflow/python/eager/remote_cloud_tpu_test.py
index efad4ad..d1b2371 100644
--- a/tensorflow/python/eager/remote_cloud_tpu_test.py
+++ b/tensorflow/python/eager/remote_cloud_tpu_test.py
@@ -20,7 +20,6 @@
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import remote
from tensorflow.python.framework import config
-from tensorflow.python.tpu import tpu_strategy_util
FLAGS = flags.FLAGS
flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
@@ -75,7 +74,7 @@
expected_devices,
[device.name for device in config.list_logical_devices()])
- tpu_strategy_util.initialize_tpu_system(resolver)
+ tpu_cluster_resolver.initialize_tpu_system(resolver)
if __name__ == '__main__':
absltest.main()
diff --git a/tensorflow/python/ops/memory_tests/BUILD b/tensorflow/python/ops/memory_tests/BUILD
index 5a7404c..88b24f4 100644
--- a/tensorflow/python/ops/memory_tests/BUILD
+++ b/tensorflow/python/ops/memory_tests/BUILD
@@ -19,6 +19,7 @@
"//tensorflow/python:array_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/framework:config",
@@ -27,7 +28,6 @@
"//tensorflow/python/framework:test_lib",
"//tensorflow/python/platform:client_testlib",
"//tensorflow/python/platform:test",
- "//tensorflow/python/tpu:tpu_strategy_util",
"@absl_py//absl/testing:parameterized",
] + tf_additional_xla_deps_py(),
)
@@ -46,6 +46,7 @@
"//tensorflow/python:array_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/framework:config",
@@ -54,7 +55,6 @@
"//tensorflow/python/framework:test_lib",
"//tensorflow/python/platform:client_testlib",
"//tensorflow/python/platform:test",
- "//tensorflow/python/tpu:tpu_strategy_util",
"@absl_py//absl/testing:parameterized",
] + tf_additional_xla_deps_py(),
)
diff --git a/tensorflow/python/ops/memory_tests/custom_gradient_memory_test.py b/tensorflow/python/ops/memory_tests/custom_gradient_memory_test.py
index f5650c4..56b7204 100644
--- a/tensorflow/python/ops/memory_tests/custom_gradient_memory_test.py
+++ b/tensorflow/python/ops/memory_tests/custom_gradient_memory_test.py
@@ -18,6 +18,7 @@
from absl.testing import parameterized
from tensorflow.compiler.xla.service import hlo_pb2
+from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
@@ -29,7 +30,6 @@
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
-from tensorflow.python.tpu import tpu_strategy_util
class RecomputeGradMemoryTest(test.TestCase, parameterized.TestCase):
@@ -116,7 +116,7 @@
device_name = f"{device_type}:0"
# Necessary for TFRT tests.
if device_type == "TPU":
- tpu_strategy_util.initialize_tpu_system()
+ tpu_cluster_resolver.initialize_tpu_system()
n = 500
with ops.device(device_name):
diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD
index 368424d..3bba5e5 100644
--- a/tensorflow/python/tpu/BUILD
+++ b/tensorflow/python/tpu/BUILD
@@ -333,7 +333,7 @@
"//tensorflow/python:while_loop",
"//tensorflow/python/client:session",
"//tensorflow/python/compiler/xla",
- "//tensorflow/python/distribute/cluster_resolver/tpu:tpu_cluster_resolver_py",
+ "//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:monitoring",
@@ -613,7 +613,7 @@
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_ops",
"//tensorflow/python/client:session",
- "//tensorflow/python/distribute/cluster_resolver/tpu:tpu_cluster_resolver_py",
+ "//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:monitoring",
@@ -621,7 +621,6 @@
"//tensorflow/python/framework:errors",
"//tensorflow/python/platform:tf_logging",
"//tensorflow/python/util:compat",
- "//tensorflow/python/util:tf_export",
],
)
@@ -906,7 +905,6 @@
tags = ["no_oss"],
deps = [
":functional",
- ":tpu_lib",
":tpu_py",
":tpu_replication",
"//tensorflow/core:protos_all_py",
diff --git a/tensorflow/python/tpu/tests/BUILD b/tensorflow/python/tpu/tests/BUILD
index 7e8752a..33b966b 100644
--- a/tensorflow/python/tpu/tests/BUILD
+++ b/tensorflow/python/tpu/tests/BUILD
@@ -28,7 +28,6 @@
"//tensorflow/python/platform:client_testlib",
"//tensorflow/python/tpu:tpu_embedding_v2",
"//tensorflow/python/tpu:tpu_embedding_v2_utils",
- "//tensorflow/python/tpu:tpu_strategy_util",
"//tensorflow/python/util:nest",
"//third_party/py/numpy",
"@absl_py//absl/flags",
@@ -50,6 +49,7 @@
"//tensorflow/python:init_ops_v2",
"//tensorflow/python/checkpoint",
"//tensorflow/python/compat:v2_compat",
+ "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:dtypes",
@@ -62,7 +62,6 @@
"//tensorflow/python/tpu:tpu_embedding_for_serving",
"//tensorflow/python/tpu:tpu_embedding_v2",
"//tensorflow/python/tpu:tpu_embedding_v2_utils",
- "//tensorflow/python/tpu:tpu_strategy_util",
"//tensorflow/python/training:checkpoint_utils",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -525,7 +524,6 @@
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/python/platform:client_testlib",
- "//tensorflow/python/tpu:tpu_strategy_util",
"@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/python/tpu/tests/tpu_embedding_base_test.py b/tensorflow/python/tpu/tests/tpu_embedding_base_test.py
index 7c751e3..cbe3b5b 100644
--- a/tensorflow/python/tpu/tests/tpu_embedding_base_test.py
+++ b/tensorflow/python/tpu/tests/tpu_embedding_base_test.py
@@ -35,7 +35,6 @@
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
-from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.util import nest
FLAGS = flags.FLAGS
@@ -154,7 +153,7 @@
self.resolver._cloud_tpu_client.configure_tpu_version(
version='nightly', restart_type='always')
remote.connect_to_cluster(self.resolver)
- tpu_strategy_util.initialize_tpu_system(self.resolver)
+ tpu_cluster_resolver.initialize_tpu_system(self.resolver)
return tpu_strategy.TPUStrategy(self.resolver)
def _create_mid_level(self, optimizer=None):
diff --git a/tensorflow/python/tpu/tests/tpu_embedding_v2_checkpoint_test.py b/tensorflow/python/tpu/tests/tpu_embedding_v2_checkpoint_test.py
index ca83819..a725a2b 100644
--- a/tensorflow/python/tpu/tests/tpu_embedding_v2_checkpoint_test.py
+++ b/tensorflow/python/tpu/tests/tpu_embedding_v2_checkpoint_test.py
@@ -18,6 +18,7 @@
import numpy as np
from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.compat import v2_compat
+from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -31,7 +32,6 @@
from tensorflow.python.tpu import tpu_embedding_for_serving
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
-from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu.tests import tpu_embedding_base_test
from tensorflow.python.training import checkpoint_utils
@@ -86,7 +86,7 @@
msg='Checkpoint should contain values from the first api object.')
# Reinitialize the tpu.
- tpu_strategy_util.initialize_tpu_system(self.resolver)
+ tpu_cluster_resolver.initialize_tpu_system(self.resolver)
with strategy.scope():
second_mid_level_contents = np.ones((num_rows, 4)) * 2
@@ -148,7 +148,7 @@
first_checkpoint = util.Checkpoint(model=first_mid_level)
first_checkpoint.save(self._get_tmpdir('restore', 'save'))
- tpu_strategy_util.initialize_tpu_system(self.resolver)
+ tpu_cluster_resolver.initialize_tpu_system(self.resolver)
with strategy.scope():
second_mid_level_contents = np.ones((num_rows, 4)) * 2
diff --git a/tensorflow/python/tpu/tests/tpu_initialization_test.py b/tensorflow/python/tpu/tests/tpu_initialization_test.py
index 252f6c8..a398b2f 100644
--- a/tensorflow/python/tpu/tests/tpu_initialization_test.py
+++ b/tensorflow/python/tpu/tests/tpu_initialization_test.py
@@ -19,14 +19,13 @@
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.platform import test
-from tensorflow.python.tpu import tpu_strategy_util
class TPUInitializationTest(parameterized.TestCase, test.TestCase):
def test_tpu_initialization(self):
resolver = tpu_cluster_resolver.TPUClusterResolver('')
- tpu_strategy_util.initialize_tpu_system(resolver)
+ tpu_cluster_resolver.initialize_tpu_system(resolver)
if __name__ == '__main__':
diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py
index 7f4fcb0..eb0ce82 100644
--- a/tensorflow/python/tpu/tpu_outside_compilation_test.py
+++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py
@@ -51,7 +51,6 @@
from tensorflow.python.tpu import functional as tpu_functional
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_replication
-from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu.ops import tpu_ops
FLAGS = flags.FLAGS
@@ -72,7 +71,7 @@
def get_tpu_strategy():
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
- tpu_strategy_util.initialize_tpu_system(resolver)
+ tpu_cluster_resolver.initialize_tpu_system(resolver)
return tpu_lib.TPUStrategyV2(resolver)
diff --git a/tensorflow/python/tpu/tpu_strategy_util.py b/tensorflow/python/tpu/tpu_strategy_util.py
index 371a418..e122d48 100644
--- a/tensorflow/python/tpu/tpu_strategy_util.py
+++ b/tensorflow/python/tpu/tpu_strategy_util.py
@@ -18,12 +18,10 @@
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
-from tensorflow.python.distribute.cluster_resolver.tpu import tpu_cluster_resolver
+from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
-from tensorflow.python.eager.def_function import function
-from tensorflow.python.eager.def_function import functions_run_eagerly
-from tensorflow.python.eager.def_function import run_functions_eagerly
from tensorflow.python.framework import device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -31,7 +29,6 @@
from tensorflow.python.tpu import topology
from tensorflow.python.tpu import tpu
from tensorflow.python.util import compat
-from tensorflow.python.util.tf_export import tf_export
_INITIALIZED_TPU_SYSTEMS = {}
@@ -43,13 +40,19 @@
"The worker address that the coordinator/client connects to.", "address")
-@tf_export("tpu.experimental.initialize_tpu_system")
-def initialize_tpu_system(cluster_resolver=None):
- """Initialize the TPU devices.
+def initialize_tpu_system_impl(cluster_resolver, tpu_cluster_resolver_cls):
+ """Implementation for tpu.experimental.initialize_tpu_system.
+
+ Kept separate to avoid tpu_oss code duplication.
+
+ Initialize the TPU devices.
Args:
cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
which provides information about the TPU cluster.
+ tpu_cluster_resolver_cls: The
+ tf.distribute.cluster_resolver.TPUClusterResolver class, passed in so
+ that an instance can be created if cluster_resolver is None.
Returns:
The tf.tpu.Topology object for the topology of the TPU cluster. If called
inside tf.function, it returns the serialized topology object instead.
@@ -57,8 +60,17 @@
Raises:
RuntimeError: If running inside a tf.function.
NotFoundError: If no TPU devices found in eager mode.
+ TypeError: If tpu_cluster_resolver_cls is
+ not tf.distribute.cluster_resolver.TPUClusterResolver.
"""
-
+ # Check that tpu_cluster_resolver_cls is a ClusterResolver subclass with a
+ # tpu_hardware_feature attribute (i.e. a TPUClusterResolver class).
+ if tpu_cluster_resolver_cls is None or not issubclass(
+ tpu_cluster_resolver_cls, cluster_resolver_lib.ClusterResolver
+ ) or not hasattr(tpu_cluster_resolver_cls, "tpu_hardware_feature"):
+ raise TypeError(
+ "tpu_cluster_resolver_cls is not"
+ " tf.distribute.cluster_resolver.TPUClusterResolver.")
# Deallocate all TPU buffers by clearing out eager context caches and
# triggering garbage collection to avoid keeping invalid tpu buffer around
# after reinitialized tpu system.
@@ -76,8 +88,8 @@
if curr_device.job is not None:
job = "{}/replica:0/task:0".format(curr_device.job)
- cluster_resolver = tpu_cluster_resolver.TPUClusterResolver("")
- assert isinstance(cluster_resolver, tpu_cluster_resolver.TPUClusterResolver)
+ cluster_resolver = tpu_cluster_resolver_cls("")
+ assert isinstance(cluster_resolver, tpu_cluster_resolver_cls)
tpu_name = compat.as_text(cluster_resolver._tpu) # pylint: disable=protected-access
if tpu_name in _INITIALIZED_TPU_SYSTEMS:
@@ -99,7 +111,7 @@
job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())
if context.executing_eagerly():
- @function(autograph=False)
+ @def_function.function(autograph=False)
def _tpu_init_fn():
# In TF1, we usually close chips when compilation fails to clear the data
# in infeed. In TF2, we don't need to do this because infeed is no longer
@@ -113,7 +125,7 @@
# The TPU_SYSTEM device must match the device used in tpu.initialize_system
# exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
# devices available.
- run_eagerly = functions_run_eagerly()
+ run_eagerly = def_function.functions_run_eagerly()
if run_eagerly:
logging.warning(
"It looks like tf.function behavior was disabled, perhaps using"
@@ -121,7 +133,7 @@
" tf.tpu.experimental.initialize_tpu_system requires tf.function to"
" work. This primitive will override the disable."
)
- run_functions_eagerly(False)
+ def_function.run_functions_eagerly(False)
try:
with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access
output = _tpu_init_fn()
@@ -133,7 +145,7 @@
+ str(e))
finally:
if run_eagerly is not None:
- run_functions_eagerly(run_eagerly)
+ def_function.run_functions_eagerly(run_eagerly)
# Clear out the eager context caches since the memory is invalid now.
context.context()._initialize_logical_devices() # pylint: disable=protected-access
@@ -181,9 +193,12 @@
return _INITIALIZED_TPU_SYSTEMS.copy()
-@tf_export("tpu.experimental.shutdown_tpu_system")
-def shutdown_tpu_system(cluster_resolver=None):
- """Shuts down the TPU devices.
+def shutdown_tpu_system_impl(cluster_resolver, tpu_cluster_resolver_cls):
+ """Implementation for tpu.experimental.shutdown_tpu_system.
+
+ Kept separate to avoid tpu_oss code duplication.
+
+ Shuts down the TPU devices.
This will clear all caches, even those that are maintained through sequential
calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
@@ -192,11 +207,25 @@
Args:
cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
which provides information about the TPU cluster.
+ tpu_cluster_resolver_cls: The
+ tf.distribute.cluster_resolver.TPUClusterResolver class, passed in so
+ that an instance can be created if cluster_resolver is None.
Raises:
RuntimeError: If no TPU devices found for eager execution or if run in a
tf.function.
+ TypeError: If tpu_cluster_resolver_cls is
+ not tf.distribute.cluster_resolver.TPUClusterResolver.
"""
+ # Check that tpu_cluster_resolver_cls is a ClusterResolver subclass with a
+ # tpu_hardware_feature attribute (i.e. a TPUClusterResolver class).
+ if tpu_cluster_resolver_cls is None or not issubclass(
+ tpu_cluster_resolver_cls, cluster_resolver_lib.ClusterResolver
+ ) or not hasattr(tpu_cluster_resolver_cls, "tpu_hardware_feature"):
+ raise TypeError(
+ "tpu_cluster_resolver_cls is not"
+ " tf.distribute.cluster_resolver.TPUClusterResolver.")
+
job = None
if cluster_resolver is None:
# If no cluster resolver is specified, and running eagerly, execute the init
@@ -206,8 +235,8 @@
if curr_device.job is not None:
job = "{}/replica:0/task:0".format(curr_device.job)
- cluster_resolver = tpu_cluster_resolver.TPUClusterResolver("")
- assert isinstance(cluster_resolver, tpu_cluster_resolver.TPUClusterResolver)
+ cluster_resolver = tpu_cluster_resolver_cls("")
+ assert isinstance(cluster_resolver, tpu_cluster_resolver_cls)
tpu_name = compat.as_text(cluster_resolver._tpu) # pylint: disable=protected-access
if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
@@ -227,14 +256,14 @@
# avoid the output node match multiple devices error.
job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())
- @function(autograph=False)
+ @def_function.function(autograph=False)
def _tpu_shutdown_fn():
tpu.shutdown_system(job=job)
# The TPU_SYSTEM device must match the device used in tpu.shutdown_system
# exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
# devices available.
- run_eagerly = functions_run_eagerly()
+ run_eagerly = def_function.functions_run_eagerly()
if run_eagerly:
logging.warning(
"It looks like tf.function behavior was disabled, perhaps using"
@@ -242,13 +271,13 @@
" tf.tpu.experimental.shutdown_tpu_system requires tf.function to"
" work. This primitive will override the disable."
)
- run_functions_eagerly(False)
+ def_function.run_functions_eagerly(False)
try:
with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access
_tpu_shutdown_fn()
finally:
if run_eagerly is not None:
- run_functions_eagerly(run_eagerly)
+ def_function.run_functions_eagerly(run_eagerly)
# Clear out the eager context caches since the memory is invalid now.
logging.info("Clearing out eager caches")