Break the interdependency between tpu/tpu_strategy_util.py, distribute/cluster_resolver/tpu/tpu_cluster_resolver.py, and eager/remote.py.

Make exported wrappers for initialize_tpu_system and shutdown_tpu_system in tpu_cluster_resolver.py, backed by implementation functions in tpu_strategy_util.py to which the wrappers pass a reference to the TPUClusterResolver class, as sketched below. Add tests for this.
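
A minimal sketch of the new layering (the full change is in the diff below): the exported wrapper lives in tpu_cluster_resolver.py and passes the TPUClusterResolver class down to the implementation in tpu_strategy_util.py, which validates it before use, so tpu_strategy_util.py never has to import TPUClusterResolver:

    # tpu_cluster_resolver.py (exported wrapper)
    def initialize_tpu_system(cluster_resolver=None):
      return tpu_strategy_util.initialize_tpu_system_impl(
          cluster_resolver, TPUClusterResolver)

    # tpu_strategy_util.py (implementation)
    def initialize_tpu_system_impl(cluster_resolver, tpu_cluster_resolver_cls):
      if tpu_cluster_resolver_cls is None or not issubclass(
          tpu_cluster_resolver_cls, cluster_resolver_lib.ClusterResolver
      ) or not hasattr(tpu_cluster_resolver_cls, "tpu_hardware_feature"):
        raise TypeError(
            "tpu_cluster_resolver_cls is not"
            " tf.distribute.cluster_resolver.TPUClusterResolver.")
      ...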

Replace the isinstance check against TPUClusterResolver in remote.py with an isinstance check against the base ClusterResolver combined with a hasattr check, as sketched below, so remote.py no longer needs to import TPUClusterResolver.
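
Sketch of the replacement check in remote.py (taken from the hunk below); a TPUClusterResolver is now recognized by its interface rather than by its class:

    if (isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver)
        and hasattr(cluster_spec_or_resolver, "tpu_hardware_feature")):
      # Behaves like a tpu_cluster_resolver.TPUClusterResolver.
      ...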

Convert all inline imports to regular top-level imports and add the missing BUILD deps. Update call sites to reference the new location of initialize_tpu_system.

PiperOrigin-RevId: 537905952
GitOrigin-RevId: 6088a22ddb8fc299d18a34596b91ccf077b86b17
Change-Id: I06951d37f6f70a34df92adf7d80e3765078f2ce0
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 6823cb3..e10303e 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -2646,7 +2646,6 @@
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/platform:flags",
-        "//tensorflow/python/tpu:tpu_lib",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/compiler/tests/giant_const_op_test.py b/tensorflow/compiler/tests/giant_const_op_test.py
index 014b9d5..9a73a95 100644
--- a/tensorflow/compiler/tests/giant_const_op_test.py
+++ b/tensorflow/compiler/tests/giant_const_op_test.py
@@ -25,7 +25,6 @@
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.platform import flags
-from tensorflow.python.tpu import tpu_strategy_util
 
 FLAGS = flags.FLAGS
 flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
@@ -45,7 +44,7 @@
 def get_tpu_strategy():
   resolver = get_tpu_cluster_resolver()
   remote.connect_to_cluster(resolver)
-  tpu_strategy_util.initialize_tpu_system(resolver)
+  tpu_cluster_resolver.initialize_tpu_system(resolver)
   return tpu_lib.TPUStrategyV2(resolver)
 
 
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 4f164c1..8954b31 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -467,11 +467,11 @@
         "//tensorflow/python:training",
         "//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
         "//tensorflow/python/distribute/cluster_resolver:tfconfig_cluster_resolver_py",
+        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/distribute/v1:input_lib",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/framework:device",
         "//tensorflow/python/platform:tf_logging",
-        "//tensorflow/python/tpu:tpu_strategy_util",
         "//tensorflow/python/util:deprecation",
         "//tensorflow/python/util:tf_export",
     ],
@@ -814,7 +814,6 @@
         "//tensorflow/python/eager:test",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/platform:flags",
-        "//tensorflow/python/tpu:tpu_strategy_util",
     ],
 )
 
@@ -1164,7 +1163,6 @@
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/platform:flags",
         "//tensorflow/python/tpu:device_assignment",
-        "//tensorflow/python/tpu:tpu_lib",
         "//tensorflow/python/training:server_lib",
         "//tensorflow/python/util:tf_export",
     ],
@@ -1783,7 +1781,6 @@
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:test",
-        "//tensorflow/python/tpu:tpu_lib",
         "//tensorflow/python/util:variable_utils",
         "@absl_py//absl/testing:parameterized",
     ],
@@ -2141,10 +2138,10 @@
         "//tensorflow/python:variables",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
+        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/distribute/v1:input_lib",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/framework:config",
-        "//tensorflow/python/tpu:tpu_lib",
         "//third_party/py/numpy",
         "@absl_py//absl/testing:parameterized",
     ],
@@ -2189,10 +2186,10 @@
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
+        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/distribute/v1:input_lib",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/framework:config",
-        "//tensorflow/python/tpu:tpu_strategy_util",
         "//tensorflow/python/training:server_lib",
         "//third_party/py/numpy",
         "@absl_py//absl/testing:parameterized",
@@ -2637,7 +2634,6 @@
         "//tensorflow/python/module",
         "//tensorflow/python/platform:flags",
         "//tensorflow/python/tpu:device_assignment",
-        "//tensorflow/python/tpu:tpu_lib",
         "//tensorflow/python/tpu:tpu_replication",
         "@absl_py//absl/testing:parameterized",
     ],
diff --git a/tensorflow/python/distribute/cluster_resolver/tpu/BUILD b/tensorflow/python/distribute/cluster_resolver/tpu/BUILD
index 787bb4a..04270cf 100644
--- a/tensorflow/python/distribute/cluster_resolver/tpu/BUILD
+++ b/tensorflow/python/distribute/cluster_resolver/tpu/BUILD
@@ -22,9 +22,11 @@
         "//tensorflow/core/protobuf/tpu:topology_proto_py",
         "//tensorflow/python:training_server_lib",
         "//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
+        "//tensorflow/python/eager:remote",
         "//tensorflow/python/framework:config",
         "//tensorflow/python/framework:errors",
         "//tensorflow/python/platform:tf_logging",
+        "//tensorflow/python/tpu:tpu_strategy_util",
         "//tensorflow/python/tpu:tpu_system_metadata",
         "//tensorflow/python/tpu/client",
         "//tensorflow/python/util:compat",
diff --git a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py
index 6e40bac..58ec2ba 100644
--- a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py
+++ b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py
@@ -18,10 +18,12 @@
 import re
 
 from tensorflow.core.protobuf.tpu import topology_pb2
-from tensorflow.python.distribute.cluster_resolver import cluster_resolver
+from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
+from tensorflow.python.eager import remote
 from tensorflow.python.framework import config as framework_config
 from tensorflow.python.framework import errors
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
 from tensorflow.python.training import server_lib
 from tensorflow.python.util import compat
@@ -53,7 +55,44 @@
     'DeviceDetails', ['device_map', 'total_cores'])
 
 
-class TPUClusterResolver(cluster_resolver.ClusterResolver):
+def initialize_tpu_system(cluster_resolver=None):
+  """Initialize the TPU devices.
+
+  Args:
+    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
+        which provides information about the TPU cluster.
+  Returns:
+    The tf.tpu.Topology object for the topology of the TPU cluster. If called
+    inside tf.function, it returns the serialized topology object instead.
+
+  Raises:
+    RuntimeError: If running inside a tf.function.
+    NotFoundError: If no TPU devices found in eager mode.
+  """
+  return tpu_strategy_util.initialize_tpu_system_impl(
+      cluster_resolver, TPUClusterResolver)
+
+
+def shutdown_tpu_system(cluster_resolver=None):
+  """Shuts down the TPU devices.
+
+  This will clear all caches, even those that are maintained through sequential
+  calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
+  cache.
+
+  Args:
+    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
+        which provides information about the TPU cluster.
+
+  Raises:
+    RuntimeError: If no TPU devices found for eager execution or if run in a
+        tf.function.
+  """
+  tpu_strategy_util.shutdown_tpu_system_impl(
+      cluster_resolver, TPUClusterResolver)
+
+
+class TPUClusterResolver(cluster_resolver_lib.ClusterResolver):
   """Cluster Resolver for Google Cloud TPUs.
 
   This is an implementation of cluster resolvers for the Google Cloud TPU
@@ -104,10 +143,8 @@
       NotFoundError: If no TPU devices found in eager mode.
     """
     resolver = TPUClusterResolver(tpu, zone, project)
-    from tensorflow.python.eager import remote  # pylint: disable=g-import-not-at-top
     remote.connect_to_cluster(resolver)
-    from tensorflow.python.tpu import tpu_strategy_util  # pylint: disable=g-import-not-at-top
-    tpu_strategy_util.initialize_tpu_system(resolver)
+    tpu_strategy_util.initialize_tpu_system_impl(resolver, TPUClusterResolver)
     return resolver
 
   @staticmethod
@@ -266,7 +303,7 @@
         if not job_tasks:
           raise ValueError('No TPUs with the specified names exist.')
         master = job_tasks[0]
-      return cluster_resolver.format_master_url(master, 'grpc')
+      return cluster_resolver_lib.format_master_url(master, 'grpc')
     else:
       return ''
 
@@ -384,7 +421,7 @@
     while True:
       try:
         device_details = TPUClusterResolver._get_device_dict_and_cores(
-            cluster_resolver.get_accelerator_devices(
+            cluster_resolver_lib.get_accelerator_devices(
                 self.master(), config_proto=config_proto))
         break
       except errors.DeadlineExceededError:
diff --git a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
index 0fdaf79..4aebbf2 100644
--- a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
+++ b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
@@ -15,8 +15,12 @@
 """Shim so that direct imports of tpu_cluster_resolver get correct symbols.
 """
 
+from tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver import initialize_tpu_system
 from tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver import is_running_in_gce  # pylint: disable=unused-import
+from tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver import shutdown_tpu_system
 from tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver import TPUClusterResolver
 from tensorflow.python.util.tf_export import tf_export
 
 tf_export('distribute.cluster_resolver.TPUClusterResolver')(TPUClusterResolver)
+tf_export('tpu.experimental.initialize_tpu_system')(initialize_tpu_system)
+tf_export('tpu.experimental.shutdown_tpu_system')(shutdown_tpu_system)
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index 99c796d..b87ded9 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -36,6 +36,7 @@
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
 from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
+from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
 from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
 from tensorflow.python.eager import context
 from tensorflow.python.framework import device as tf_device
@@ -45,7 +46,6 @@
 from tensorflow.python.ops import collective_ops
 from tensorflow.python.ops import control_flow_util
 from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.trackable import base
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
@@ -540,7 +540,7 @@
     local_devices, local_device_type = self._initialize_local_devices(
         cluster_resolver, self._worker_device)
     if local_device_type == "TPU":
-      tpu_strategy_util.initialize_tpu_system()
+      tpu_cluster_resolver.initialize_tpu_system()
 
     self._collective_keys = cross_device_utils.CollectiveKeys(
         group_key_start=1 + self._collective_key_base)
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
index 4ad7a8e..9721a53 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
@@ -35,6 +35,7 @@
 from tensorflow.python.distribute import strategy_test_lib
 from tensorflow.python.distribute import test_util
 from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
+from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
 from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
 from tensorflow.python.eager import context
 from tensorflow.python.framework import config as tf_config
@@ -52,7 +53,6 @@
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
-from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.training import server_lib
 
 
@@ -76,7 +76,7 @@
   if num_tpus is None:
     num_tpus = context.context().list_physical_devices('TPU')
   if num_tpus:
-    tpu_strategy_util.initialize_tpu_system()
+    tpu_cluster_resolver.initialize_tpu_system()
 
   if cluster_spec and task_type and task_id is not None:
     cluster_resolver = cluster_resolver_lib.SimpleClusterResolver(
diff --git a/tensorflow/python/distribute/parallel_device/BUILD b/tensorflow/python/distribute/parallel_device/BUILD
index fc148be..e1afbf9 100644
--- a/tensorflow/python/distribute/parallel_device/BUILD
+++ b/tensorflow/python/distribute/parallel_device/BUILD
@@ -60,6 +60,7 @@
         "//tensorflow/python/checkpoint",
         "//tensorflow/python/checkpoint:checkpoint_management",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
@@ -69,7 +70,6 @@
         "//tensorflow/python/module",
         "//tensorflow/python/saved_model:load",
         "//tensorflow/python/saved_model:save",
-        "//tensorflow/python/tpu:tpu_strategy_util",
         "//tensorflow/python/util:nest",
         "@absl_py//absl/testing:parameterized",
     ],
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device_test.py b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
index d1aee83..bc78a21 100644
--- a/tensorflow/python/distribute/parallel_device/parallel_device_test.py
+++ b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
@@ -20,6 +20,7 @@
 from tensorflow.python.checkpoint import checkpoint as tracking
 from tensorflow.python.checkpoint import checkpoint_management
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
 from tensorflow.python.distribute.parallel_device import parallel_device
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -39,7 +40,6 @@
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import load
 from tensorflow.python.saved_model import save
-from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.util import nest
 
 # When running collectives asynchronously, we need to give each parallel device
@@ -97,7 +97,7 @@
     ctx = context.context()
     if ctx.list_physical_devices("TPU"):
       self.device_type = "TPU"
-      tpu_strategy_util.initialize_tpu_system()
+      tpu_cluster_resolver.initialize_tpu_system()
     elif ctx.list_physical_devices("GPU"):
       self.device_type = "GPU"
       gpus = ctx.list_physical_devices(self.device_type)
diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py
index 4a2bd81..c73d180 100644
--- a/tensorflow/python/distribute/strategy_combinations.py
+++ b/tensorflow/python/distribute/strategy_combinations.py
@@ -39,7 +39,6 @@
 from tensorflow.python.framework import test_util as framework_test_util
 from tensorflow.python.platform import flags
 from tensorflow.python.tpu import device_assignment as device_assignment_lib
-from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.training import server_lib
 from tensorflow.python.util.tf_export import tf_export
 
@@ -105,7 +104,7 @@
       if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
         remote.connect_to_cluster(resolver)
         _did_connect_to_cluster = True
-      _topology = tpu_strategy_util.initialize_tpu_system(resolver)
+      _topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
 
     device_assignment = None
     if use_single_core:
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index ff09791..05579c9 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -62,7 +62,6 @@
 from tensorflow.python.tpu import device_assignment as device_assignment_lib  # pylint: disable=unused-import
 from tensorflow.python.tpu import tpu
 from tensorflow.python.tpu import tpu_hardware_feature
-from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.tpu import training_loop
 from tensorflow.python.tpu.ops import tpu_ops
 from tensorflow.python.util import deprecation
@@ -1239,7 +1238,7 @@
     This is a private method only to be used by Estimator. Other frameworks
     should directly be calling `tf.tpu.experimental.initialize_tpu_system`
     """
-    tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)
+    tpu_cluster_resolver_lib.initialize_tpu_system(self._tpu_cluster_resolver)
 
   def _create_variable(self, next_creator, **kwargs):
     """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
diff --git a/tensorflow/python/distribute/tpu_strategy_compilation_test.py b/tensorflow/python/distribute/tpu_strategy_compilation_test.py
index 63d54f1..defc9c2 100644
--- a/tensorflow/python/distribute/tpu_strategy_compilation_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_compilation_test.py
@@ -21,7 +21,6 @@
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.platform import flags
-from tensorflow.python.tpu import tpu_strategy_util
 
 FLAGS = flags.FLAGS
 flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
@@ -41,7 +40,7 @@
 def get_tpu_strategy():
   resolver = get_tpu_cluster_resolver()
   remote.connect_to_cluster(resolver)
-  tpu_strategy_util.initialize_tpu_system(resolver)
+  tpu_cluster_resolver.initialize_tpu_system(resolver)
   strategy = tpu_lib.TPUStrategyV2(resolver)
   return strategy
 
diff --git a/tensorflow/python/distribute/tpu_strategy_model_parallelism_test.py b/tensorflow/python/distribute/tpu_strategy_model_parallelism_test.py
index 13c9e85..b6801e5c 100644
--- a/tensorflow/python/distribute/tpu_strategy_model_parallelism_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_model_parallelism_test.py
@@ -47,7 +47,6 @@
 from tensorflow.python.platform import flags
 from tensorflow.python.tpu import device_assignment as device_assignment_lib
 from tensorflow.python.tpu import tpu_replication
-from tensorflow.python.tpu import tpu_strategy_util
 
 FLAGS = flags.FLAGS
 flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
@@ -67,7 +66,7 @@
 def get_tpu_strategy(enable_spmd=False):
   resolver = get_tpu_cluster_resolver()
   remote.connect_to_cluster(resolver)
-  topology = tpu_strategy_util.initialize_tpu_system(resolver)
+  topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
   num_replicas = resolver.get_tpu_system_metadata().num_cores // 2
   device_assignment = device_assignment_lib.DeviceAssignment.build(
       topology, num_replicas=num_replicas, computation_shape=[1, 1, 1, 2])
diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py
index 115ae1e..ab3711e 100644
--- a/tensorflow/python/distribute/tpu_strategy_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_test.py
@@ -79,7 +79,7 @@
 def get_tpu_strategy(enable_packed_var=False):
   resolver = get_tpu_cluster_resolver()
   remote.connect_to_cluster(resolver)
-  tpu_strategy_util.initialize_tpu_system(resolver)
+  tpu_cluster_resolver.initialize_tpu_system(resolver)
   strategy = tpu_lib.TPUStrategyV2(resolver)
   strategy._enable_packed_variable_in_eager_mode = enable_packed_var
   return strategy
@@ -176,12 +176,30 @@
   def test_multiple_initialize_system(self):
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)
-    tpu_strategy_util.initialize_tpu_system(resolver)
+    tpu_cluster_resolver.initialize_tpu_system(resolver)
 
     with test.mock.patch.object(logging, "warning") as mock_log:
-      tpu_strategy_util.initialize_tpu_system(resolver)
+      tpu_cluster_resolver.initialize_tpu_system(resolver)
       self.assertRegex(str(mock_log.call_args), "already been initialized")
 
+  def test_initialize_tpu_system_impl_input(self):
+    resolver = get_tpu_cluster_resolver()
+    with self.assertRaisesRegex(
+        TypeError,
+        r"tpu_cluster_resolver_cls is not"
+        r" tf.distribute.cluster_resolver.TPUClusterResolver."):
+      tpu_strategy_util.initialize_tpu_system_impl(
+          resolver, tpu_cluster_resolver_cls=None)
+
+  def test_shutdown_tpu_system_impl_input(self):
+    resolver = get_tpu_cluster_resolver()
+    with self.assertRaisesRegex(
+        TypeError,
+        r"tpu_cluster_resolver_cls is not"
+        r" tf.distribute.cluster_resolver.TPUClusterResolver."):
+      tpu_strategy_util.shutdown_tpu_system_impl(
+          resolver, tpu_cluster_resolver_cls=None)
+
   def test_tpu_tf_function_same_device(self):
     with ops.device("/device:TPU:0"):
       a = variables.Variable(1)
@@ -332,7 +350,7 @@
   def test_sequential_runs(self, enable_packed_var):
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+    topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
     # Computation replicated to all cores.
     device_assignment = device_assignment_lib.DeviceAssignment.build(
         topology, num_replicas=2)
@@ -443,7 +461,7 @@
   def test_computation_on_subset_cores(self, enable_packed_var):
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+    topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
     all_core_strategy = tpu_lib.TPUStrategyV2(resolver)
     all_core_strategy._enable_packed_variable_in_eager_mode = enable_packed_var
 
@@ -484,7 +502,7 @@
   def test_worker_devices_on_subset_cores(self, enable_packed_var):
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+    topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
 
     # Strategy for the 1st core.
     device_assignment = device_assignment_lib.DeviceAssignment.build(
@@ -1180,7 +1198,7 @@
 
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+    topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
     device_assignment = device_assignment_lib.DeviceAssignment(
         topology, core_assignment=[[[0, 0, 0, 1]], [[0, 0, 0, 0]]]
     )
@@ -1317,7 +1335,7 @@
   def test_update_config_proto(self):
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)
-    tpu_strategy_util.initialize_tpu_system(resolver)
+    tpu_cluster_resolver.initialize_tpu_system(resolver)
     strategy = tpu_lib.TPUStrategyV2(resolver)
 
     config_proto = config_pb2.ConfigProto()
@@ -1449,7 +1467,7 @@
   def test_core_assignment(self):
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+    topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
     device_assignment = device_assignment_lib.DeviceAssignment(
         topology, core_assignment=[[[0, 0, 0, 0]]])
     self.assertAllEqual([[[0, 0, 0, 0]]], device_assignment.core_assignment)
@@ -1461,7 +1479,7 @@
   def test_device_assignment_strategy_properties(self):
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+    topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
     device_assignment = device_assignment_lib.DeviceAssignment(
         topology, core_assignment=[[[0, 0, 0, 0]]])
     strategy = tpu_lib.TPUStrategyV2(
@@ -1474,7 +1492,7 @@
   def test_device_assignment_constants(self):
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+    topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
     device_assignment = device_assignment_lib.DeviceAssignment(
         topology,
         core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)
@@ -1487,7 +1505,7 @@
   def test_variables_mismatched_device_assignment(self):
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+    topology = tpu_cluster_resolver.initialize_tpu_system(resolver)
 
     strategy0 = tpu_lib.TPUStrategyV2(resolver)
     self.assertEqual(
diff --git a/tensorflow/python/distribute/vars_test.py b/tensorflow/python/distribute/vars_test.py
index e1472fe..5dd2c5a 100644
--- a/tensorflow/python/distribute/vars_test.py
+++ b/tensorflow/python/distribute/vars_test.py
@@ -42,7 +42,6 @@
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variable_v1
 from tensorflow.python.ops import variables as variables_lib
-from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.util import variable_utils
 
 
@@ -971,7 +970,7 @@
     for aggregation in aggregations:
       if strategy_test_lib.is_tpu_strategy(distribution):
         resolver = tpu_cluster_resolver.TPUClusterResolver("")
-        tpu_strategy_util.initialize_tpu_system(resolver)
+        tpu_cluster_resolver.initialize_tpu_system(resolver)
       with distribution.scope():
         v = variable_v1.VariableV1(
             0.,
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index ce25754..2499f60 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -1110,7 +1110,6 @@
         "//tensorflow/python:pywrap_tfe",
         "//tensorflow/python/distribute:device_util",
         "//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
-        "//tensorflow/python/distribute/cluster_resolver/tpu:tpu_cluster_resolver_py",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/platform:remote_utils",
         "//tensorflow/python/training:server_lib",
@@ -1234,7 +1233,6 @@
         ":remote",
         "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/framework:config",
-        "//tensorflow/python/tpu:tpu_strategy_util",
         "@absl_py//absl/flags",
         "@absl_py//absl/testing:absltest",
     ],
diff --git a/tensorflow/python/eager/remote.py b/tensorflow/python/eager/remote.py
index 3f2b093..2f70bfc 100644
--- a/tensorflow/python/eager/remote.py
+++ b/tensorflow/python/eager/remote.py
@@ -22,7 +22,6 @@
 from tensorflow.python import pywrap_tfe
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute.cluster_resolver import cluster_resolver
-from tensorflow.python.distribute.cluster_resolver.tpu import tpu_cluster_resolver
 from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import remote_utils
@@ -184,8 +183,10 @@
     service_leader = ""
     # Maybe enable coordination service for the communication protocol
     # TODO(b/243839559): Fix UPTC + Coordination service crashing
-    if isinstance(cluster_spec_or_resolver,
-                  tpu_cluster_resolver.TPUClusterResolver):
+    # Duck-typed check for tpu_cluster_resolver.TPUClusterResolver, done
+    # without importing that class to avoid a circular dependency.
+    if (isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver)
+        and hasattr(cluster_spec_or_resolver, "tpu_hardware_feature")):
       is_uptc_sess = ".uptc-worker." in cluster_spec_or_resolver.master()
       service_type = remote_utils.coordination_service_type(
           protocol, is_uptc_sess)
diff --git a/tensorflow/python/eager/remote_cloud_tpu_test.py b/tensorflow/python/eager/remote_cloud_tpu_test.py
index efad4ad..d1b2371 100644
--- a/tensorflow/python/eager/remote_cloud_tpu_test.py
+++ b/tensorflow/python/eager/remote_cloud_tpu_test.py
@@ -20,7 +20,6 @@
 from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
 from tensorflow.python.eager import remote
 from tensorflow.python.framework import config
-from tensorflow.python.tpu import tpu_strategy_util
 
 FLAGS = flags.FLAGS
 flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
@@ -75,7 +74,7 @@
         expected_devices,
         [device.name for device in config.list_logical_devices()])
 
-    tpu_strategy_util.initialize_tpu_system(resolver)
+    tpu_cluster_resolver.initialize_tpu_system(resolver)
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/tensorflow/python/ops/memory_tests/BUILD b/tensorflow/python/ops/memory_tests/BUILD
index 5a7404c..88b24f4 100644
--- a/tensorflow/python/ops/memory_tests/BUILD
+++ b/tensorflow/python/ops/memory_tests/BUILD
@@ -19,6 +19,7 @@
         "//tensorflow/python:array_ops",
         "//tensorflow/python:gradients",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/framework:config",
@@ -27,7 +28,6 @@
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/platform:client_testlib",
         "//tensorflow/python/platform:test",
-        "//tensorflow/python/tpu:tpu_strategy_util",
         "@absl_py//absl/testing:parameterized",
     ] + tf_additional_xla_deps_py(),
 )
@@ -46,6 +46,7 @@
         "//tensorflow/python:array_ops",
         "//tensorflow/python:gradients",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/framework:config",
@@ -54,7 +55,6 @@
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/platform:client_testlib",
         "//tensorflow/python/platform:test",
-        "//tensorflow/python/tpu:tpu_strategy_util",
         "@absl_py//absl/testing:parameterized",
     ] + tf_additional_xla_deps_py(),
 )
diff --git a/tensorflow/python/ops/memory_tests/custom_gradient_memory_test.py b/tensorflow/python/ops/memory_tests/custom_gradient_memory_test.py
index f5650c4..56b7204 100644
--- a/tensorflow/python/ops/memory_tests/custom_gradient_memory_test.py
+++ b/tensorflow/python/ops/memory_tests/custom_gradient_memory_test.py
@@ -18,6 +18,7 @@
 
 from absl.testing import parameterized
 from tensorflow.compiler.xla.service import hlo_pb2
+from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import config
@@ -29,7 +30,6 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import test
-from tensorflow.python.tpu import tpu_strategy_util
 
 
 class RecomputeGradMemoryTest(test.TestCase, parameterized.TestCase):
@@ -116,7 +116,7 @@
     device_name = f"{device_type}:0"
     # Necessary for TFRT tests.
     if device_type == "TPU":
-      tpu_strategy_util.initialize_tpu_system()
+      tpu_cluster_resolver.initialize_tpu_system()
 
     n = 500
     with ops.device(device_name):
diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD
index 368424d..3bba5e5 100644
--- a/tensorflow/python/tpu/BUILD
+++ b/tensorflow/python/tpu/BUILD
@@ -333,7 +333,7 @@
         "//tensorflow/python:while_loop",
         "//tensorflow/python/client:session",
         "//tensorflow/python/compiler/xla",
-        "//tensorflow/python/distribute/cluster_resolver/tpu:tpu_cluster_resolver_py",
+        "//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:monitoring",
@@ -613,7 +613,7 @@
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python/client:session",
-        "//tensorflow/python/distribute/cluster_resolver/tpu:tpu_cluster_resolver_py",
+        "//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:monitoring",
@@ -621,7 +621,6 @@
         "//tensorflow/python/framework:errors",
         "//tensorflow/python/platform:tf_logging",
         "//tensorflow/python/util:compat",
-        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -906,7 +905,6 @@
     tags = ["no_oss"],
     deps = [
         ":functional",
-        ":tpu_lib",
         ":tpu_py",
         ":tpu_replication",
         "//tensorflow/core:protos_all_py",
diff --git a/tensorflow/python/tpu/tests/BUILD b/tensorflow/python/tpu/tests/BUILD
index 7e8752a..33b966b 100644
--- a/tensorflow/python/tpu/tests/BUILD
+++ b/tensorflow/python/tpu/tests/BUILD
@@ -28,7 +28,6 @@
         "//tensorflow/python/platform:client_testlib",
         "//tensorflow/python/tpu:tpu_embedding_v2",
         "//tensorflow/python/tpu:tpu_embedding_v2_utils",
-        "//tensorflow/python/tpu:tpu_strategy_util",
         "//tensorflow/python/util:nest",
         "//third_party/py/numpy",
         "@absl_py//absl/flags",
@@ -50,6 +49,7 @@
         "//tensorflow/python:init_ops_v2",
         "//tensorflow/python/checkpoint",
         "//tensorflow/python/compat:v2_compat",
+        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
@@ -62,7 +62,6 @@
         "//tensorflow/python/tpu:tpu_embedding_for_serving",
         "//tensorflow/python/tpu:tpu_embedding_v2",
         "//tensorflow/python/tpu:tpu_embedding_v2_utils",
-        "//tensorflow/python/tpu:tpu_strategy_util",
         "//tensorflow/python/training:checkpoint_utils",
         "//third_party/py/numpy",
         "@absl_py//absl/testing:parameterized",
@@ -525,7 +524,6 @@
         "//tensorflow/python/compat:v2_compat",
         "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/platform:client_testlib",
-        "//tensorflow/python/tpu:tpu_strategy_util",
         "@absl_py//absl/testing:parameterized",
     ],
 )
diff --git a/tensorflow/python/tpu/tests/tpu_embedding_base_test.py b/tensorflow/python/tpu/tests/tpu_embedding_base_test.py
index 7c751e3..cbe3b5b 100644
--- a/tensorflow/python/tpu/tests/tpu_embedding_base_test.py
+++ b/tensorflow/python/tpu/tests/tpu_embedding_base_test.py
@@ -35,7 +35,6 @@
 from tensorflow.python.platform import test
 from tensorflow.python.tpu import tpu_embedding_v2
 from tensorflow.python.tpu import tpu_embedding_v2_utils
-from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.util import nest
 
 FLAGS = flags.FLAGS
@@ -154,7 +153,7 @@
       self.resolver._cloud_tpu_client.configure_tpu_version(
           version='nightly', restart_type='always')
     remote.connect_to_cluster(self.resolver)
-    tpu_strategy_util.initialize_tpu_system(self.resolver)
+    tpu_cluster_resolver.initialize_tpu_system(self.resolver)
     return tpu_strategy.TPUStrategy(self.resolver)
 
   def _create_mid_level(self, optimizer=None):
diff --git a/tensorflow/python/tpu/tests/tpu_embedding_v2_checkpoint_test.py b/tensorflow/python/tpu/tests/tpu_embedding_v2_checkpoint_test.py
index ca83819..a725a2b 100644
--- a/tensorflow/python/tpu/tests/tpu_embedding_v2_checkpoint_test.py
+++ b/tensorflow/python/tpu/tests/tpu_embedding_v2_checkpoint_test.py
@@ -18,6 +18,7 @@
 import numpy as np
 from tensorflow.python.checkpoint import checkpoint as util
 from tensorflow.python.compat import v2_compat
+from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -31,7 +32,6 @@
 from tensorflow.python.tpu import tpu_embedding_for_serving
 from tensorflow.python.tpu import tpu_embedding_v2
 from tensorflow.python.tpu import tpu_embedding_v2_utils
-from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.tpu.tests import tpu_embedding_base_test
 from tensorflow.python.training import checkpoint_utils
 
@@ -86,7 +86,7 @@
         msg='Checkpoint should contain values from the first api object.')
 
     # Reinitialize the tpu.
-    tpu_strategy_util.initialize_tpu_system(self.resolver)
+    tpu_cluster_resolver.initialize_tpu_system(self.resolver)
 
     with strategy.scope():
       second_mid_level_contents = np.ones((num_rows, 4)) * 2
@@ -148,7 +148,7 @@
     first_checkpoint = util.Checkpoint(model=first_mid_level)
     first_checkpoint.save(self._get_tmpdir('restore', 'save'))
 
-    tpu_strategy_util.initialize_tpu_system(self.resolver)
+    tpu_cluster_resolver.initialize_tpu_system(self.resolver)
 
     with strategy.scope():
       second_mid_level_contents = np.ones((num_rows, 4)) * 2
diff --git a/tensorflow/python/tpu/tests/tpu_initialization_test.py b/tensorflow/python/tpu/tests/tpu_initialization_test.py
index 252f6c8..a398b2f 100644
--- a/tensorflow/python/tpu/tests/tpu_initialization_test.py
+++ b/tensorflow/python/tpu/tests/tpu_initialization_test.py
@@ -19,14 +19,13 @@
 from tensorflow.python.compat import v2_compat
 from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
 from tensorflow.python.platform import test
-from tensorflow.python.tpu import tpu_strategy_util
 
 
 class TPUInitializationTest(parameterized.TestCase, test.TestCase):
 
   def test_tpu_initialization(self):
     resolver = tpu_cluster_resolver.TPUClusterResolver('')
-    tpu_strategy_util.initialize_tpu_system(resolver)
+    tpu_cluster_resolver.initialize_tpu_system(resolver)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py
index 7f4fcb0..eb0ce82 100644
--- a/tensorflow/python/tpu/tpu_outside_compilation_test.py
+++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py
@@ -51,7 +51,6 @@
 from tensorflow.python.tpu import functional as tpu_functional
 from tensorflow.python.tpu import tpu
 from tensorflow.python.tpu import tpu_replication
-from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.tpu.ops import tpu_ops
 
 FLAGS = flags.FLAGS
@@ -72,7 +71,7 @@
 def get_tpu_strategy():
   resolver = get_tpu_cluster_resolver()
   remote.connect_to_cluster(resolver)
-  tpu_strategy_util.initialize_tpu_system(resolver)
+  tpu_cluster_resolver.initialize_tpu_system(resolver)
   return tpu_lib.TPUStrategyV2(resolver)
 
 
diff --git a/tensorflow/python/tpu/tpu_strategy_util.py b/tensorflow/python/tpu/tpu_strategy_util.py
index 371a418..e122d48 100644
--- a/tensorflow/python/tpu/tpu_strategy_util.py
+++ b/tensorflow/python/tpu/tpu_strategy_util.py
@@ -18,12 +18,10 @@
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session as session_lib
-from tensorflow.python.distribute.cluster_resolver.tpu import tpu_cluster_resolver
+from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.eager import monitoring
-from tensorflow.python.eager.def_function import function
-from tensorflow.python.eager.def_function import functions_run_eagerly
-from tensorflow.python.eager.def_function import run_functions_eagerly
 from tensorflow.python.framework import device
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -31,7 +29,6 @@
 from tensorflow.python.tpu import topology
 from tensorflow.python.tpu import tpu
 from tensorflow.python.util import compat
-from tensorflow.python.util.tf_export import tf_export
 
 
 _INITIALIZED_TPU_SYSTEMS = {}
@@ -43,13 +40,19 @@
     "The worker address that the coordinator/client connects to.", "address")
 
 
-@tf_export("tpu.experimental.initialize_tpu_system")
-def initialize_tpu_system(cluster_resolver=None):
-  """Initialize the TPU devices.
+def initialize_tpu_system_impl(cluster_resolver, tpu_cluster_resolver_cls):
+  """Implementation for tpu.experimental.initialize_tpu_system.
+
+  Kept separate to avoid tpu_oss code duplication.
+
+  Initialize the TPU devices.
 
   Args:
     cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
         which provides information about the TPU cluster.
+    tpu_cluster_resolver_cls: a reference to
+        tf.distribute.cluster_resolver.TPUClusterResolver so that an instance
+        of it can be created if cluster_resolver is None.
   Returns:
     The tf.tpu.Topology object for the topology of the TPU cluster. If called
     inside tf.function, it returns the serialized topology object instead.
@@ -57,8 +60,17 @@
   Raises:
     RuntimeError: If running inside a tf.function.
     NotFoundError: If no TPU devices found in eager mode.
+    TypeError: If tpu_cluster_resolver_cls is
+        not tf.distribute.cluster_resolver.TPUClusterResolver.
   """
-
+  # Check that tpu_cluster_resolver_cls is a reference to
+  # tf.distribute.cluster_resolver.TPUClusterResolver (duck-typed check).
+  if tpu_cluster_resolver_cls is None or not issubclass(
+      tpu_cluster_resolver_cls, cluster_resolver_lib.ClusterResolver
+  ) or not hasattr(tpu_cluster_resolver_cls, "tpu_hardware_feature"):
+    raise TypeError(
+        "tpu_cluster_resolver_cls is not"
+        " tf.distribute.cluster_resolver.TPUClusterResolver.")
   # Deallocate all TPU buffers by clearing out eager context caches and
   # triggering garbage collection to avoid keeping invalid tpu buffer around
   # after reinitialized tpu system.
@@ -76,8 +88,8 @@
       if curr_device.job is not None:
         job = "{}/replica:0/task:0".format(curr_device.job)
 
-    cluster_resolver = tpu_cluster_resolver.TPUClusterResolver("")
-  assert isinstance(cluster_resolver, tpu_cluster_resolver.TPUClusterResolver)
+    cluster_resolver = tpu_cluster_resolver_cls("")
+  assert isinstance(cluster_resolver, tpu_cluster_resolver_cls)
 
   tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
   if tpu_name in _INITIALIZED_TPU_SYSTEMS:
@@ -99,7 +111,7 @@
     job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())
 
   if context.executing_eagerly():
-    @function(autograph=False)
+    @def_function.function(autograph=False)
     def _tpu_init_fn():
       # In TF1, we usually close chips when compilation fails to clear the data
       # in infeed. In TF2, we don't need to do this because infeed is no longer
@@ -113,7 +125,7 @@
     # The TPU_SYSTEM device must match the device used in tpu.initialize_system
     # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
     # devices available.
-    run_eagerly = functions_run_eagerly()
+    run_eagerly = def_function.functions_run_eagerly()
     if run_eagerly:
       logging.warning(
           "It looks like tf.function behavior was disabled, perhaps using"
@@ -121,7 +133,7 @@
           " tf.tpu.experimental.initialize_tpu_system requires tf.function to"
           " work. This primitive will override the disable."
       )
-      run_functions_eagerly(False)
+      def_function.run_functions_eagerly(False)
     try:
       with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
         output = _tpu_init_fn()
@@ -133,7 +145,7 @@
           + str(e))
     finally:
       if run_eagerly is not None:
-        run_functions_eagerly(run_eagerly)
+        def_function.run_functions_eagerly(run_eagerly)
     # Clear out the eager context caches since the memory is invalid now.
     context.context()._initialize_logical_devices()  # pylint: disable=protected-access
 
@@ -181,9 +193,12 @@
   return _INITIALIZED_TPU_SYSTEMS.copy()
 
 
-@tf_export("tpu.experimental.shutdown_tpu_system")
-def shutdown_tpu_system(cluster_resolver=None):
-  """Shuts down the TPU devices.
+def shutdown_tpu_system_impl(cluster_resolver, tpu_cluster_resolver_cls):
+  """Implementation for tpu.experimental.shutdown_tpu_system.
+
+  Kept separate to avoid tpu_oss code duplication.
+
+  Shuts down the TPU devices.
 
   This will clear all caches, even those that are maintained through sequential
   calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
@@ -192,11 +207,25 @@
   Args:
     cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
         which provides information about the TPU cluster.
+    tpu_cluster_resolver_cls: a reference to
+        tf.distribute.cluster_resolver.TPUClusterResolver so that an instance
+        of it can be created if cluster_resolver is None.
 
   Raises:
     RuntimeError: If no TPU devices found for eager execution or if run in a
         tf.function.
+    TypeError: If tpu_cluster_resolver_cls is
+        not tf.distribute.cluster_resolver.TPUClusterResolver.
   """
+  # Check that tpu_cluster_resolver_cls is a reference to
+  # tf.distribute.cluster_resolver.TPUClusterResolver (duck-typed check).
+  if tpu_cluster_resolver_cls is None or not issubclass(
+      tpu_cluster_resolver_cls, cluster_resolver_lib.ClusterResolver
+  ) or not hasattr(tpu_cluster_resolver_cls, "tpu_hardware_feature"):
+    raise TypeError(
+        "tpu_cluster_resolver_cls is not"
+        " tf.distribute.cluster_resolver.TPUClusterResolver.")
+
   job = None
   if cluster_resolver is None:
     # If no cluster resolver is specified, and running eagerly, execute the init
@@ -206,8 +235,8 @@
       if curr_device.job is not None:
         job = "{}/replica:0/task:0".format(curr_device.job)
 
-    cluster_resolver = tpu_cluster_resolver.TPUClusterResolver("")
-  assert isinstance(cluster_resolver, tpu_cluster_resolver.TPUClusterResolver)
+    cluster_resolver = tpu_cluster_resolver_cls("")
+  assert isinstance(cluster_resolver, tpu_cluster_resolver_cls)
 
   tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
   if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
@@ -227,14 +256,14 @@
       # avoid the output node match multiple devices error.
       job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())
 
-    @function(autograph=False)
+    @def_function.function(autograph=False)
     def _tpu_shutdown_fn():
       tpu.shutdown_system(job=job)
 
     # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
     # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
     # devices available.
-    run_eagerly = functions_run_eagerly()
+    run_eagerly = def_function.functions_run_eagerly()
     if run_eagerly:
       logging.warning(
           "It looks like tf.function behavior was disabled, perhaps using"
@@ -242,13 +271,13 @@
           " tf.tpu.experimental.shutdown_tpu_system requires tf.function to"
           " work. This primitive will override the disable."
       )
-      run_functions_eagerly(False)
+      def_function.run_functions_eagerly(False)
     try:
       with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
         _tpu_shutdown_fn()
     finally:
       if run_eagerly is not None:
-        run_functions_eagerly(run_eagerly)
+        def_function.run_functions_eagerly(run_eagerly)
 
     # Clear out the eager context caches since the memory is invalid now.
     logging.info("Clearing out eager caches")