| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Implementation of ClusterResolvers for GCE instance groups.""" |
| |
| from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver |
| from tensorflow.python.training.server_lib import ClusterSpec |
| from tensorflow.python.util.tf_export import tf_export |
| |
| |
# googleapiclient and oauth2client are optional dependencies. Record whether
# they could be imported instead of failing at module import time, so that
# GCEClusterResolver can raise a descriptive ImportError only when it actually
# has to build a GCE API client itself.
_GOOGLE_API_CLIENT_INSTALLED = True
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False
| |
| |
@tf_export('distribute.cluster_resolver.GCEClusterResolver')
class GCEClusterResolver(ClusterResolver):
  """ClusterResolver for Google Compute Engine.

  This is an implementation of cluster resolvers for the Google Compute Engine
  instance group platform. By specifying a project, zone, and instance group,
  this will retrieve the IP address of all the instances within the instance
  group and return a ClusterResolver object suitable for use for distributed
  TensorFlow.

  Note: this cluster resolver cannot retrieve `task_type`, `task_id` or
  `rpc_layer`. To use it with some distribution strategies like
  `tf.distribute.experimental.MultiWorkerMirroredStrategy`, you will need to
  specify `task_type` and `task_id` in the constructor.

  Usage example with tf.distribute.Strategy:

    ```Python
    # On worker 0
    cluster_resolver = GCEClusterResolver("my-project", "us-west1",
                                          "my-instance-group",
                                          task_type="worker", task_id=0)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)

    # On worker 1
    cluster_resolver = GCEClusterResolver("my-project", "us-west1",
                                          "my-instance-group",
                                          task_type="worker", task_id=1)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)
    ```
  """

  def __init__(self,
               project,
               zone,
               instance_group,
               port=8470,
               task_type='worker',
               task_id=0,
               rpc_layer='grpc',
               credentials='default',
               service=None):
    """Creates a new GCEClusterResolver object.

    This takes in a few parameters and creates a GCEClusterResolver object. It
    will then use these parameters to query the GCE API for the IP addresses of
    each instance in the instance group.

    Args:
      project: Name of the GCE project.
      zone: Zone of the GCE instance group.
      instance_group: Name of the GCE instance group.
      port: Port of the listening TensorFlow server (default: 8470).
      task_type: Name of the TensorFlow job this GCE instance group of VM
        instances belong to.
      task_id: The task index for this particular VM, within the GCE
        instance group. In particular, every single instance should be assigned
        a unique ordinal index within an instance group manually so that they
        can be distinguished from each other.
      rpc_layer: The RPC layer TensorFlow should use to communicate across
        instances.
      credentials: GCE Credentials. If left as 'default' and the Google API
        client libraries are installed, this resolves to
        GoogleCredentials.get_application_default().
      service: The GCE API object returned by the googleapiclient.discovery
        function. (Default: discovery.build('compute', 'v1')). If you specify a
        custom service object, then the credentials parameter will be ignored.

    Raises:
      ImportError: If no `service` is given and the googleapiclient is not
        installed.
    """
    self._project = project
    self._zone = zone
    self._instance_group = instance_group
    self._task_type = task_type
    self._task_id = task_id
    self._rpc_layer = rpc_layer
    self._port = port
    self._credentials = credentials

    # Only resolve application-default credentials when the optional Google
    # client libraries imported successfully; otherwise keep the raw value.
    if credentials == 'default' and _GOOGLE_API_CLIENT_INSTALLED:
      self._credentials = GoogleCredentials.get_application_default()

    if service is None:
      if not _GOOGLE_API_CLIENT_INSTALLED:
        raise ImportError('googleapiclient must be installed before using the '
                          'GCE cluster resolver')
      self._service = discovery.build(
          'compute', 'v1',
          credentials=self._credentials)
    else:
      self._service = service

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest instance group info.

    This returns a ClusterSpec object for use based on information from the
    specified instance group. We will retrieve the information from the GCE APIs
    every time this method is called.

    Returns:
      A ClusterSpec containing host information retrieved from GCE. The single
      job is named after `task_type` and lists one `ip:port` entry per running
      instance; the list is empty when no instance is reported.
    """
    request_body = {'instanceState': 'RUNNING'}
    # NOTE(review): the public Compute Engine API appears to name this path
    # parameter `instanceGroup` (singular) -- confirm against the API
    # reference. The keyword is left unchanged because custom `service`
    # objects may rely on it.
    request = self._service.instanceGroups().listInstances(
        project=self._project,
        zone=self._zone,
        instanceGroups=self._instance_group,
        body=request_body,
        orderBy='name')

    worker_list = []

    # listInstances_next() yields None once there are no more pages.
    while request is not None:
      response = request.execute()

      # A response without an 'items' key (e.g. no running instances) is
      # treated as an empty page instead of raising KeyError.
      for instance in response.get('items', []):
        # The 'instance' field is a resource URL; its last path segment is the
        # instance name expected by instances().get().
        instance_name = instance['instance'].split('/')[-1]

        instance_request = self._service.instances().get(
            project=self._project,
            zone=self._zone,
            instance=instance_name)

        if instance_request is not None:
          instance_details = instance_request.execute()
          # Use the internal IP of the first network interface.
          ip_address = instance_details['networkInterfaces'][0]['networkIP']
          instance_url = '%s:%s' % (ip_address, self._port)
          worker_list.append(instance_url)

      request = self._service.instanceGroups().listInstances_next(
          previous_request=request,
          previous_response=response)

    # Task indices follow the lexicographic order of the 'ip:port' strings.
    worker_list.sort()
    return ClusterSpec({self._task_type: worker_list})

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the address of a task in the cluster.

    The cluster spec is rebuilt through `cluster_spec()` on every call.

    Args:
      task_type: Job name to look up. Falls back to the resolver's own
        `task_type` when None.
      task_id: Task index to look up. Falls back to the resolver's own
        `task_id` when None.
      rpc_layer: RPC layer used as the address prefix. Falls back to the
        resolver's own `rpc_layer` when empty or None.

    Returns:
      The task address, formatted as `<rpc_layer>://<address>` when an RPC
      layer is available and as the bare address otherwise. An empty string is
      returned when no task type or task id can be determined.
    """
    task_type = task_type if task_type is not None else self._task_type
    task_id = task_id if task_id is not None else self._task_id

    if task_type is None or task_id is None:
      return ''

    master = self.cluster_spec().task_address(task_type, task_id)
    effective_rpc_layer = rpc_layer or self._rpc_layer
    if effective_rpc_layer:
      return '%s://%s' % (effective_rpc_layer, master)
    return master

  @property
  def task_type(self):
    """The job name given at construction time."""
    return self._task_type

  @property
  def task_id(self):
    """The task index currently assigned to this resolver."""
    return self._task_id

  @task_type.setter
  def task_type(self, task_type):
    # The task type is fixed at construction, so reassignment is rejected.
    raise RuntimeError(
        'You cannot reset the task_type of the GCEClusterResolver after it has '
        'been created.')

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def rpc_layer(self):
    """The RPC layer used to prefix addresses returned by `master()`."""
    return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer