| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """TPU specific APIs to be used in conjunction with TPU Strategy.""" |
| |
| import gc |
| |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.python.client import session as session_lib |
| from tensorflow.python.distribute.cluster_resolver.tpu import tpu_cluster_resolver |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import monitoring |
| from tensorflow.python.eager.def_function import function |
| from tensorflow.python.eager.def_function import functions_run_eagerly |
| from tensorflow.python.eager.def_function import run_functions_eagerly |
| from tensorflow.python.framework import device |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import ops |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.tpu import topology |
| from tensorflow.python.tpu import tpu |
| from tensorflow.python.util import compat |
| from tensorflow.python.util.tf_export import tf_export |
| |
| |
| _INITIALIZED_TPU_SYSTEMS = {} |
| _LOCAL_MASTERS = ("", "local") |
| |
| |
| _tpu_worker_address = monitoring.StringGauge( |
| "/tensorflow/tpu/worker_address", |
| "The worker address that the coordinator/client connects to.", "address") |
| |
| |
| @tf_export("tpu.experimental.initialize_tpu_system") |
| def initialize_tpu_system(cluster_resolver=None): |
| """Initialize the TPU devices. |
| |
| Args: |
| cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, |
| which provides information about the TPU cluster. |

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. If called
    inside tf.function, it returns the serialized topology object instead.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices are found in eager mode.
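
  Example:

  A minimal usage sketch (assumes a reachable TPU; tpu="" targets a locally
  attached TPU, and the connect step is only needed for remote workers):

  ```python
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
  tf.config.experimental_connect_to_cluster(resolver)
  topology = tf.tpu.experimental.initialize_tpu_system(resolver)
  print("TPU mesh shape:", topology.mesh_shape)
  ```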
| """ |
| |
  # Deallocate all TPU buffers by clearing out the eager context caches and
  # triggering garbage collection, to avoid keeping invalid TPU buffers around
  # after the TPU system is reinitialized.
  logging.info("Deallocating TPU buffers before initializing the TPU system.")
| context.context()._clear_caches() # pylint: disable=protected-access |
| context.context().clear_kernel_cache() |
| gc.collect() |
| |
| job = None |
| if cluster_resolver is None: |
| # If no cluster resolver is specified, and running eagerly, execute the init |
| # ops in the current device scope. |
| if context.executing_eagerly(): |
| curr_device = device.DeviceSpec.from_string(context.context().device_name) |
| if curr_device.job is not None: |
| job = "{}/replica:0/task:0".format(curr_device.job) |
| |
| cluster_resolver = tpu_cluster_resolver.TPUClusterResolver("") |
| assert isinstance(cluster_resolver, tpu_cluster_resolver.TPUClusterResolver) |
| |
| tpu_name = compat.as_text(cluster_resolver._tpu) # pylint: disable=protected-access |
| if tpu_name in _INITIALIZED_TPU_SYSTEMS: |
| logging.warning( |
| "TPU system %s has already been initialized. " |
| "Reinitializing the TPU can cause previously created " |
| "variables on TPU to be lost.", tpu_name) |
| |
| logging.info("Initializing the TPU system: %s", tpu_name) |
| |
  # This function is structured the way it is for the following non-intuitive
  # reason: tpu.initialize_system creates a dummy op whose sole purpose is to
  # trigger DistributedTPURewritePass. This pass actually adds real ops that
  # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
  # eagerly. We need to wrap it in a defun and trigger the rewrite passes on
  # it.
| if tpu_name not in _LOCAL_MASTERS: |
    # Explicitly place tpu.initialize_system on the first worker to avoid the
    # error where the output node matches multiple devices.
| job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name()) |
| |
| if context.executing_eagerly(): |
| @function(autograph=False) |
| def _tpu_init_fn(): |
      # In TF1, we usually close chips when compilation fails to clear the data
      # in infeed. In TF2, we don't need to do this because infeed is no longer
      # used, so users can recover from TPU compilation failures more smoothly.
      # The same applies to the cancellation of a TPU execution.
| return tpu.initialize_system( |
| job=job, |
| compilation_failure_closes_chips=False, |
| tpu_cancellation_closes_chips=False) |
| |
| # The TPU_SYSTEM device must match the device used in tpu.initialize_system |
| # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM |
| # devices available. |
| run_eagerly = functions_run_eagerly() |
| if run_eagerly: |
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps via"
          " tf.config.run_functions_eagerly."
          " tf.tpu.experimental.initialize_tpu_system requires tf.function to"
          " work. This primitive will temporarily override that setting."
      )
| run_functions_eagerly(False) |
| try: |
| with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access |
| output = _tpu_init_fn() |
| context.async_wait() |
| except errors.InvalidArgumentError as e: |
| raise errors.NotFoundError( |
| None, None, |
| "TPUs not found in the cluster. Failed in initialization: " |
| + str(e)) |
| finally: |
| if run_eagerly is not None: |
| run_functions_eagerly(run_eagerly) |
    # Re-initialize the logical devices since the previous TPU device state is
    # invalid now.
    context.context()._initialize_logical_devices()  # pylint: disable=protected-access
| |
| serialized_topology = output.numpy() |
| elif not ops.executing_eagerly_outside_functions(): |
| master = cluster_resolver.master() |
| cluster_spec = cluster_resolver.cluster_spec() |
| |
| session_config = config_pb2.ConfigProto(allow_soft_placement=True) |
| if cluster_spec: |
| session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) |
| |
| with ops.Graph().as_default(): |
| with session_lib.Session(config=session_config, target=master) as sess: |
| serialized_topology = sess.run(tpu.initialize_system()) |
| else: |
| with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access |
| serialized_topology = tpu.initialize_system( |
| job=job, compilation_failure_closes_chips=False) |
| # If initialize_tpu_system is called inside tf.function, we only return |
| # the serialized topology object as the tf.tpu.Topology object has to be |
| # constructed in eager mode. |
| return serialized_topology |
| |
| logging.info("Finished initializing TPU system.") |
| tpu_topology = topology.Topology(serialized=serialized_topology) |
| cluster_resolver.set_tpu_topology(serialized_topology) |
| _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology |
| |
| # Record the address of the TPU worker-0 that the coordinator connects to. |
| # This can be used to associate the TPU worker with the right coordinator when |
| # aggregating the metrics for the application. An example of the address: |
| # /bns/mb/borg/mb/bns/chienchunh/chienchunh_group_49640234.1.tfm_train_tpu_worker/0 |
| _tpu_worker_address.get_cell("address").set(cluster_resolver.get_master()) |
| |
| return tpu_topology |
| |
| |
| def get_initialized_tpu_systems(): |
| """Returns all currently initialized tpu systems. |
| |
| Returns: |
| A dictionary, with tpu name as the key and the tpu topology as the value. |
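
  Example:

  A minimal usage sketch (keys are the TPU names that were passed to the
  resolvers, e.g. "" for a locally attached TPU):

  ```python
  for name, tpu_topology in get_initialized_tpu_systems().items():
    print(name, tpu_topology.num_tasks)
  ```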
| """ |
| return _INITIALIZED_TPU_SYSTEMS.copy() |
| |
| |
| @tf_export("tpu.experimental.shutdown_tpu_system") |
| def shutdown_tpu_system(cluster_resolver=None): |
| """Shuts down the TPU devices. |
| |
  This clears all caches, even those that persist across sequential calls to
  tf.tpu.experimental.initialize_tpu_system, such as the compilation cache.
| |
| Args: |
| cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, |
| which provides information about the TPU cluster. |
| |
| Raises: |
    RuntimeError: If no TPU devices are found for eager execution or if run
      inside a tf.function.
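
  Example:

  A minimal usage sketch (assumes the TPU system was previously initialized
  with the same resolver):

  ```python
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
  tf.tpu.experimental.initialize_tpu_system(resolver)
  # ... run TPU computations ...
  tf.tpu.experimental.shutdown_tpu_system(resolver)
  ```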
| """ |
| job = None |
| if cluster_resolver is None: |
    # If no cluster resolver is specified, and running eagerly, execute the
    # shutdown ops in the current device scope.
| if context.executing_eagerly(): |
| curr_device = device.DeviceSpec.from_string(context.context().device_name) |
| if curr_device.job is not None: |
| job = "{}/replica:0/task:0".format(curr_device.job) |
| |
| cluster_resolver = tpu_cluster_resolver.TPUClusterResolver("") |
| assert isinstance(cluster_resolver, tpu_cluster_resolver.TPUClusterResolver) |
| |
| tpu_name = compat.as_text(cluster_resolver._tpu) # pylint: disable=protected-access |
| if tpu_name not in _INITIALIZED_TPU_SYSTEMS: |
| logging.warning("You are shutting down a TPU system %s that has not been " |
| "initialized." % tpu_name) |
| |
| logging.info("Shutting down the TPU system: %s", tpu_name) |
| |
| if context.executing_eagerly(): |
    # This function is structured the way it is for the following non-intuitive
    # reason: tpu.shutdown_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. This pass actually adds real ops that
    # shut down the TPU system. Thus, we can't simply run tpu.shutdown_system
    # eagerly. We need to wrap it in a defun and trigger the rewrite passes on
    # it.
| if tpu_name not in _LOCAL_MASTERS: |
      # Explicitly place tpu.shutdown_system on the first worker to avoid the
      # error where the output node matches multiple devices.
| job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name()) |
| |
| @function(autograph=False) |
| def _tpu_shutdown_fn(): |
| tpu.shutdown_system(job=job) |
| |
| # The TPU_SYSTEM device must match the device used in tpu.shutdown_system |
| # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM |
| # devices available. |
| run_eagerly = functions_run_eagerly() |
| if run_eagerly: |
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps via"
          " tf.config.run_functions_eagerly."
          " tf.tpu.experimental.shutdown_tpu_system requires tf.function to"
          " work. This primitive will temporarily override that setting."
      )
| run_functions_eagerly(False) |
| try: |
| with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access |
| _tpu_shutdown_fn() |
| finally: |
| if run_eagerly is not None: |
| run_functions_eagerly(run_eagerly) |
| |
| # Clear out the eager context caches since the memory is invalid now. |
| logging.info("Clearing out eager caches") |
| context.context()._clear_caches() # pylint: disable=protected-access |
| context.context().clear_kernel_cache() |
| elif not ops.executing_eagerly_outside_functions(): |
| master = cluster_resolver.master() |
| cluster_spec = cluster_resolver.cluster_spec() |
| |
| session_config = config_pb2.ConfigProto(allow_soft_placement=True) |
| if cluster_spec: |
| session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) |
| |
| with ops.Graph().as_default(): |
| with session_lib.Session(config=session_config, target=master) as sess: |
| sess.run(tpu.shutdown_system()) |
| else: |
    raise RuntimeError(
        "shutdown_tpu_system is not supported within tf.function. You should"
        " call shutdown_tpu_system outside of your tf.function."
    )
| |
| logging.info("Finished shutting down TPU system.") |
| if tpu_name in _INITIALIZED_TPU_SYSTEMS: |
| del _INITIALIZED_TPU_SYSTEMS[tpu_name] |