| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Implementation of Cluster Resolvers for Cloud TPUs.""" |
| |
| import collections |
| import re |
| |
| from tensorflow.core.protobuf.tpu import topology_pb2 |
| from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib |
| from tensorflow.python.eager import remote |
| from tensorflow.python.framework import config as framework_config |
| from tensorflow.python.framework import errors |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.tpu import tpu_strategy_util |
| from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib |
| from tensorflow.python.training import server_lib |
| from tensorflow.python.util import compat |
| |
| try: |
| from cloud_tpu_client import client # pylint: disable=g-import-not-at-top |
| except ImportError: |
| logging.debug( |
      'Falling back to TensorFlow client; we recommend you install the Cloud '
| 'TPU client directly with pip install cloud-tpu-client.') |
| from tensorflow.python.tpu.client import client # pylint: disable=g-import-not-at-top |
| |
| |
def is_running_in_gce():
  """Returns whether this is running in GCE (always True in this stub)."""
  return True
| |
| |
| class _LocalCloudTpuClient(object): |
| """Dummy local Cloud TPU client.""" |
| |
| def api_available(self): |
| return False |
| |
| |
| _TPU_DEVICE_REGEX = re.compile( |
| r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$') |
| _TPU_CONN_RETRIES = 120 |
| DeviceDetails = collections.namedtuple( |
| 'DeviceDetails', ['device_map', 'total_cores']) |
| |
| |
| def initialize_tpu_system(cluster_resolver=None): |
| """Initialize the TPU devices. |
| |
| Args: |
| cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, |
| which provides information about the TPU cluster. |
| Returns: |
| The tf.tpu.Topology object for the topology of the TPU cluster. If called |
| inside tf.function, it returns the serialized topology object instead. |
| |
| Raises: |
| RuntimeError: If running inside a tf.function. |
| NotFoundError: If no TPU devices found in eager mode. |
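
  Example (illustrative; assumes a reachable Cloud TPU that can be resolved
  from the environment):

  ```python
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  tf.config.experimental_connect_to_cluster(resolver)
  topology = tf.tpu.experimental.initialize_tpu_system(resolver)
  ```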
| """ |
| return tpu_strategy_util.initialize_tpu_system_impl( |
| cluster_resolver, TPUClusterResolver) |
| |
| |
| def shutdown_tpu_system(cluster_resolver=None): |
| """Shuts down the TPU devices. |
| |
| This will clear all caches, even those that are maintained through sequential |
| calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation |
| cache. |
| |
| Args: |
| cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, |
| which provides information about the TPU cluster. |
| |
| Raises: |
| RuntimeError: If no TPU devices found for eager execution or if run in a |
| tf.function. |
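
  Example (illustrative; assumes the TPU system was previously initialized):

  ```python
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.shutdown_tpu_system(resolver)
  ```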
| """ |
| tpu_strategy_util.shutdown_tpu_system_impl( |
| cluster_resolver, TPUClusterResolver) |
| |
| |
| class TPUClusterResolver(cluster_resolver_lib.ClusterResolver): |
| """Cluster Resolver for Google Cloud TPUs. |
| |
| This is an implementation of cluster resolvers for the Google Cloud TPU |
| service. |
| |
| TPUClusterResolver supports the following distinct environments: |
| Google Compute Engine |
| Google Kubernetes Engine |
| Google internal |
| |
| It can be passed into `tf.distribute.TPUStrategy` to support TF2 training on |
| Cloud TPUs. |
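
  A minimal TF2 training setup (illustrative; assumes a Cloud TPU can be
  resolved from the environment):

  ```python
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.TPUStrategy(resolver)
  ```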
| """ |
| |
| @staticmethod |
| def connect(tpu=None, |
| zone=None, |
| project=None): |
| """Initializes TPU and returns a TPUClusterResolver. |
| |
    This API will connect to the remote TPU cluster and initialize the TPU
    hardware. Example usage:
| |
| >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect( |
| ... tpu='') |
| |
    It can be viewed as a convenient wrapper around the following code:
| |
| >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') |
| >>> tf.config.experimental_connect_to_cluster(resolver) |
| >>> tf.tpu.experimental.initialize_tpu_system(resolver) |
| |
| Args: |
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs.
| zone: Zone where the TPUs are located. If omitted or empty, we will assume |
| that the zone of the TPU is the same as the zone of the GCE VM, which we |
| will try to discover from the GCE metadata service. |
| project: Name of the GCP project containing Cloud TPUs. If omitted or |
| empty, we will try to discover the project name of the GCE VM from the |
| GCE metadata service. |
| |
| Returns: |
| An instance of TPUClusterResolver object. |
| |
| Raises: |
| NotFoundError: If no TPU devices found in eager mode. |
| """ |
| resolver = TPUClusterResolver(tpu, zone, project) |
| remote.connect_to_cluster(resolver) |
    tpu_strategy_util.initialize_tpu_system_impl(resolver, TPUClusterResolver)
| return resolver |
| |
| @staticmethod |
| def _get_device_dict_and_cores(devices): |
| """Returns a dict of hosts to cores and total cores given devices names. |
| |
| Returns a namedtuple with two attributes: |
| device_map: A map of host_ids to a list of core_ids. |
| total_cores: The total number of cores within the TPU system. |
| |
| Args: |
| devices: A list of devices returned by session.list_devices() |
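
    For example (illustrative), a device named
    '/job:worker/replica:0/task:0/device:TPU:1' is counted as core '1' on
    host '0'.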
| """ |
| device_map = collections.defaultdict(list) |
| num_cores = 0 |
| for device in devices: |
| match = _TPU_DEVICE_REGEX.match(device.name) |
| if match: |
| host_id = match.group('host_id') |
| core_id = match.group('core_id') |
| device_map[host_id].append(core_id) |
| num_cores += 1 |
| return DeviceDetails(device_map, num_cores) |
| |
| @staticmethod |
| def _verify_and_return_same_core_count(device_dict): |
| """Verifies that every device in device_dict has the same # of cores.""" |
| num_cores_per_host_set = ( |
| {len(core_ids) for core_ids in device_dict.values()}) |
| if len(num_cores_per_host_set) != 1: |
      raise RuntimeError('The number of TPU cores per host is not the same. '
                         'This should never happen. Devices: {}'.format(
                             device_dict))
| return num_cores_per_host_set.pop() |
| |
| def __init__(self, |
| tpu=None, |
| zone=None, |
| project=None, |
| job_name='worker', |
| coordinator_name=None, |
| coordinator_address=None, |
| credentials='default', |
| service=None, |
| discovery_url=None): |
| """Creates a new TPUClusterResolver object. |
| |
| The ClusterResolver will then use the parameters to query the Cloud TPU APIs |
| for the IP addresses and ports of each Cloud TPU listed. |
| |
| Args: |
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs. If set to "local", it will
        assume that the TPU is directly connected to the VM instead of over
        the network.
| zone: Zone where the TPUs are located. If omitted or empty, we will assume |
| that the zone of the TPU is the same as the zone of the GCE VM, which we |
| will try to discover from the GCE metadata service. |
| project: Name of the GCP project containing Cloud TPUs. If omitted or |
| empty, we will try to discover the project name of the GCE VM from the |
| GCE metadata service. |
| job_name: Name of the TensorFlow job the TPUs belong to. |
| coordinator_name: The name to use for the coordinator. Set to None if the |
| coordinator should not be included in the computed ClusterSpec. |
| coordinator_address: The address of the coordinator (typically an ip:port |
| pair). If set to None, a TF server will be started. If coordinator_name |
| is None, a TF server will not be started even if coordinator_address is |
| None. |
      credentials: GCE credentials. If None, then we use default credentials
        from the oauth2client library.
| service: The GCE API object returned by the googleapiclient.discovery |
| function. If you specify a custom service object, then the credentials |
| parameter will be ignored. |
| discovery_url: A URL template that points to the location of the discovery |
| service. It should have two parameters {api} and {apiVersion} that when |
| filled in produce an absolute URL to the discovery document for that |
| service. The environment variable 'TPU_API_DISCOVERY_URL' will override |
| this. |
| |
| Raises: |
| ImportError: If the googleapiclient is not installed. |
| ValueError: If no TPUs are specified. |
| RuntimeError: If an empty TPU name is specified and this is running in a |
| Google Cloud environment. |
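
    A minimal usage sketch (illustrative; 'my-tpu', 'us-central1-b', and
    'my-project' are hypothetical values):

    ```python
    resolver = TPUClusterResolver(
        tpu='my-tpu', zone='us-central1-b', project='my-project')
    ```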
| """ |
| |
| if tpu != 'local': |
| # Default Cloud environment |
| self._cloud_tpu_client = client.Client( |
| tpu=tpu, |
| zone=zone, |
| project=project, |
| credentials=credentials, |
| service=service, |
| discovery_url=discovery_url) |
| self._tpu = self._cloud_tpu_client.name() |
| else: |
| # Directly connected TPU environment |
| self._cloud_tpu_client = _LocalCloudTpuClient() |
| self._tpu = 'local' |
| |
    # By default the task_type is 'worker' and the task_id is 0 (which is the
| # first worker in the task). |
| self.task_type = job_name |
| self.task_id = 0 |
| self._coordinator_name = coordinator_name |
| if (coordinator_name and not coordinator_address): |
| self._start_local_server() |
| else: |
| self._coordinator_address = coordinator_address |
| |
    self._tpu_topology = None
    # Default value returned by the `environment` property; an empty string
    # denotes a non-Google-internal (e.g. Cloud) environment.
    self._environment = ''
| |
| def __enter__(self): |
| self._cloud_tpu_client.enter() |
| |
| def __exit__(self, type, value, traceback): # pylint: disable=redefined-builtin |
| self._cloud_tpu_client.exit(type, value, traceback) |
| |
| def master(self, task_type=None, task_id=None, rpc_layer=None): |
| """Get the Master string to be used for the session. |
| |
    In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of
    the first instance in the ClusterSpec returned by the cluster_spec
    function.
| |
| If a non-TPU name is used when constructing a TPUClusterResolver, that will |
    be returned instead (e.g. if the tpu argument's value when constructing
| this TPUClusterResolver was 'grpc://10.240.1.2:8470', |
| 'grpc://10.240.1.2:8470' will be returned). |
| |
| Args: |
| task_type: (Optional, string) The type of the TensorFlow task of the |
| master. |
| task_id: (Optional, integer) The index of the TensorFlow task of the |
| master. |
| rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to |
| communicate with TPUs. |
| |
| Returns: |
| string, the connection string to use when creating a session. |
| |
| Raises: |
| ValueError: If none of the TPUs specified exists. |
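
    Example (illustrative; the gRPC address is hypothetical):

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu='grpc://10.240.1.2:8470')
    master = resolver.master()  # 'grpc://10.240.1.2:8470'
    ```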
| """ |
| |
| if self._tpu != 'local': |
| cluster_spec = self.cluster_spec() |
| if task_type is not None and task_id is not None: |
| # task_type and task_id is from the function parameter |
| master = cluster_spec.task_address(task_type, task_id) |
| elif self.task_type is not None and self.task_id is not None: |
| # task_type and task_id is from the object |
| master = cluster_spec.task_address(self.task_type, self.task_id) |
| else: |
| # by default we take the first item in the cluster with the right name |
| job_tasks = cluster_spec.job_tasks(self.task_type) |
| if not job_tasks: |
| raise ValueError('No TPUs with the specified names exist.') |
| master = job_tasks[0] |
| return cluster_resolver_lib.format_master_url(master, 'grpc') |
| else: |
| return '' |
| |
| def get_master(self): |
| return self.master() |
| |
| def get_job_name(self): |
| return self.task_type |
| |
| def get_coordination_service_leader(self): |
| """Returns the location for coordination service. |
| |
| The coordination service should be located on TPU worker0. |
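
    For example (illustrative), with the default job name 'worker' this
    method returns '/job:worker/task:0'.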
| |
| Returns: |
      A string indicating the location path.
| """ |
| return '/job:' + self.get_job_name() + '/task:0' |
| |
| def get_tpu_system_metadata(self): |
| """Returns the metadata of the TPU system. |
| |
    Users can call this method to get some facts about the TPU system, such
    as the total number of cores, the number of TPU workers, and the devices,
    e.g.:

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tpu_system_metadata = resolver.get_tpu_system_metadata()
    num_hosts = tpu_system_metadata.num_hosts
    ```
| |
| Returns: |
| A `tf.tpu.experimental.TPUSystemMetadata` object. |
| """ |
| cluster_spec = self.cluster_spec() |
| cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None |
| tpu_system_metadata = ( |
| tpu_system_metadata_lib._query_tpu_system_metadata( # pylint: disable=protected-access |
| self.master(), |
| cluster_def=cluster_def, |
| query_topology=False)) |
| |
| return tpu_system_metadata |
| |
| def cluster_spec(self): |
| """Returns a ClusterSpec object based on the latest TPU information. |
| |
| We retrieve the information from the GCE APIs every time this method is |
| called. |
| |
| Returns: |
| A ClusterSpec containing host information returned from Cloud TPUs, |
| or None. |
| |
| Raises: |
| RuntimeError: If the provided TPU is not healthy. |
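
    Example (illustrative; 'my-tpu' and the worker addresses are made up):

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu')
    spec = resolver.cluster_spec()
    # e.g. {'worker': ['10.0.0.2:8470', '10.0.0.3:8470']}
    print(spec.as_dict())
    ```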
| """ |
| ############################################################################ |
    # There are 5 potential cases this code must handle:
| # 0. [Local case.] When a TPU is connected directly to the VM. |
| # 1. [Normal case.] We should resolve the TPU name to a set of tasks, and |
| # a. Create a ClusterSpec that includes the coordinator job |
| # b. Create a ClusterSpec without the coordinator job. |
| # 2. [GKE / No API Access.] We should not resolve the TPU name to a set of |
| # tasks and |
| # a. Create a ClusterSpec with the coordinator |
| # b. Create a ClusterSpec without the coordinator |
| ############################################################################ |
| |
| if self._tpu != 'local': |
| network_endpoints = self._cloud_tpu_client.network_endpoints() |
| worker_list = [ |
| '%s:%s' % (endpoint['ipAddress'], endpoint['port']) |
| for endpoint in network_endpoints |
| ] |
| cluster_spec = {self.task_type: worker_list} |
| if self._coordinator_address: |
| # {1, 2}.a |
| cluster_spec[self._coordinator_name] = [self._coordinator_address] |
| return server_lib.ClusterSpec(cluster_spec) |
| else: |
| return server_lib.ClusterSpec({}) |
| |
| def num_accelerators(self, |
| task_type=None, |
| task_id=None, |
| config_proto=None): |
| """Returns the number of TPU cores per worker. |
| |
    Connects to the master, lists all the devices present on the master, and
    counts them. Also verifies that the device count per host in the cluster
    is the same before returning the number of TPU cores per host.
| |
| Args: |
| task_type: Unused. |
| task_id: Unused. |
| config_proto: Used to create a connection to a TPU master in order to |
| retrieve the system metadata. |
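
    Returns:
      A dict mapping the accelerator type 'TPU' to the number of TPU cores
      per worker.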
| |
| Raises: |
| RuntimeError: If we cannot talk to a TPU worker after retrying or if the |
| number of TPU devices per host is different. |
| """ |
| if self._tpu == 'local': |
| return { |
| 'TPU': |
| len([ |
| d for d in framework_config.list_logical_devices() |
| if d.device_type == 'TPU' |
| ]) |
| } |
| |
| retry_count = 1 |
| # TODO(b/120564445): Replace with standard library for retries. |
| while True: |
| try: |
| device_details = TPUClusterResolver._get_device_dict_and_cores( |
| cluster_resolver_lib.get_accelerator_devices( |
| self.master(), config_proto=config_proto)) |
| break |
| except errors.DeadlineExceededError: |
| error_message = ('Failed to connect to master. The TPU might not be ' |
| 'ready (e.g. still scheduling) or the master ' |
| 'address is incorrect: got (%s)' % self.master()) |
| if retry_count <= _TPU_CONN_RETRIES: |
| logging.warning(error_message) |
| logging.warning('Retrying (%d/%d)...', retry_count, _TPU_CONN_RETRIES) |
| retry_count += 1 |
| else: |
| raise RuntimeError(error_message) |
| |
| if device_details.total_cores: |
| return { |
| 'TPU': |
| TPUClusterResolver._verify_and_return_same_core_count( |
| device_details.device_map) |
| } |
| return {'TPU': 0} |
| |
| def set_tpu_topology(self, serialized_tpu_topology): |
| """Sets the tpu topology info stored in this resolver.""" |
| self._tpu_topology = topology_pb2.TopologyProto() |
| self._tpu_topology.ParseFromString(serialized_tpu_topology) |
| |
| @property |
| def tpu_hardware_feature(self): |
| """Returns the tpu topology info stored.""" |
| if self._tpu_topology is None: |
| return self._tpu_topology |
| return self._tpu_topology.tpu_hardware_feature |
| |
| @property |
| def environment(self): |
| """Returns the current environment which TensorFlow is running in.""" |
| return self._environment |
| |
| def _start_local_server(self): |
| address = compat.as_text(self._cloud_tpu_client.get_local_ip()) |
| self._server = server_lib.Server({'local': ['0.0.0.0:0']}, |
| protocol='grpc', |
| config=None, |
| start=True) |
| # self._server.target is of the form: grpc://ipaddress:port |
| target = compat.as_bytes(self._server.target) |
| splits = target.split(compat.as_bytes(':')) |
| assert len(splits) == 3, self._server.target |
| assert splits[0] == compat.as_bytes('grpc'), self._server.target |
| self._coordinator_port = compat.as_text(splits[2]) |
| self._coordinator_address = '%s:%s' % ( |
| address, compat.as_text(self._coordinator_port)) |
| |
| def __deepcopy__(self, memo): |
| # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy. |
| return self |