# Source blob 054eeedb81b1cc8a5b83ba6a1c2f4c3ba0e60c87 (header left over from the repository web viewer: [file] [log] [blame]).
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of ClusterResolvers for GCE instance groups."""
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util.tf_export import tf_export
# The Google API client libraries are optional dependencies. Record whether
# they could be imported so that callers get a clear ImportError only when they
# actually need to talk to the GCE API, rather than at module import time.
_GOOGLE_API_CLIENT_INSTALLED = True
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False
@tf_export('distribute.cluster_resolver.GCEClusterResolver')
class GCEClusterResolver(ClusterResolver):
  """ClusterResolver for Google Compute Engine.

  This is an implementation of cluster resolvers for the Google Compute Engine
  instance group platform. By specifying a project, zone, and instance group,
  this will retrieve the IP address of all the instances within the instance
  group and return a ClusterResolver object suitable for use for distributed
  TensorFlow.

  Note: this cluster resolver cannot retrieve `task_type`, `task_id` or
  `rpc_layer`. To use it with some distribution strategies like
  `tf.distribute.experimental.MultiWorkerMirroredStrategy`, you will need to
  specify `task_type` and `task_id` in the constructor.

  Usage example with tf.distribute.Strategy:

    ```Python
    # On worker 0
    cluster_resolver = GCEClusterResolver("my-project", "us-west1",
                                          "my-instance-group",
                                          task_type="worker", task_id=0)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)

    # On worker 1
    cluster_resolver = GCEClusterResolver("my-project", "us-west1",
                                          "my-instance-group",
                                          task_type="worker", task_id=1)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)
    ```
  """

  def __init__(self,
               project,
               zone,
               instance_group,
               port=8470,
               task_type='worker',
               task_id=0,
               rpc_layer='grpc',
               credentials='default',
               service=None):
    """Creates a new GCEClusterResolver object.

    This takes in a few parameters and creates a GCEClusterResolver object. It
    will then use these parameters to query the GCE API for the IP addresses of
    each instance in the instance group.

    Args:
      project: Name of the GCE project.
      zone: Zone of the GCE instance group.
      instance_group: Name of the GCE instance group.
      port: Port of the listening TensorFlow server. Defaults to 8470.
      task_type: Name of the TensorFlow job this GCE instance group of VM
        instances belong to.
      task_id: The task index for this particular VM, within the GCE
        instance group. In particular, every single instance should be assigned
        a unique ordinal index within an instance group manually so that they
        can be distinguished from each other.
      rpc_layer: The RPC layer TensorFlow should use to communicate across
        instances.
      credentials: GCE Credentials. If left as 'default', this resolves to
        GoogleCredentials.get_application_default() when the service object has
        to be built.
      service: The GCE API object returned by the googleapiclient.discovery
        function. (Default: discovery.build('compute', 'v1')). If you specify a
        custom service object, then the credentials parameter will be ignored.

    Raises:
      ImportError: If no `service` is supplied and the googleapiclient is not
        installed.
    """
    self._project = project
    self._zone = zone
    self._instance_group = instance_group
    self._task_type = task_type
    self._task_id = task_id
    self._rpc_layer = rpc_layer
    self._port = port
    self._credentials = credentials

    if service is not None:
      # A caller-supplied service carries its own authentication, so the
      # credentials argument is ignored and no default lookup is attempted.
      self._service = service
    else:
      if not _GOOGLE_API_CLIENT_INSTALLED:
        raise ImportError('googleapiclient must be installed before using the '
                          'GCE cluster resolver')
      if credentials == 'default':
        self._credentials = GoogleCredentials.get_application_default()
      self._service = discovery.build(
          'compute', 'v1',
          credentials=self._credentials)

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest instance group info.

    This returns a ClusterSpec object for use based on information from the
    specified instance group. We will retrieve the information from the GCE APIs
    every time this method is called.

    Returns:
      A ClusterSpec containing host information retrieved from GCE. The single
      job is named after `task_type`; its task list holds one `ip:port` entry
      per RUNNING instance, sorted as strings, and is empty when the group has
      no running instances.
    """
    request_body = {'instanceState': 'RUNNING'}
    # The Compute API names this path parameter `instanceGroup` (singular).
    request = self._service.instanceGroups().listInstances(
        project=self._project,
        zone=self._zone,
        instanceGroup=self._instance_group,
        body=request_body,
        orderBy='name')

    worker_list = []
    # Walk every result page; listInstances_next returns None after the last.
    while request is not None:
      response = request.execute()
      # A page without instances may omit the 'items' key altogether.
      for instance in response.get('items', []):
        # 'instance' is a resource URL; its last path segment is the VM name.
        instance_name = instance['instance'].split('/')[-1]
        instance_request = self._service.instances().get(
            project=self._project,
            zone=self._zone,
            instance=instance_name)
        if instance_request is not None:
          instance_details = instance_request.execute()
          ip_address = instance_details['networkInterfaces'][0]['networkIP']
          instance_url = '%s:%s' % (ip_address, self._port)
          worker_list.append(instance_url)
      request = self._service.instanceGroups().listInstances_next(
          previous_request=request,
          previous_response=response)

    # Task indices are positions in this sorted list.
    worker_list.sort()
    return ClusterSpec({self._task_type: worker_list})

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the address of the requested task, queried from GCE.

    Args:
      task_type: Job name to look up. Falls back to the resolver's own
        `task_type` when None.
      task_id: Task index to look up. Falls back to the resolver's own
        `task_id` when None.
      rpc_layer: RPC protocol used as the address prefix. Falls back to the
        resolver's own `rpc_layer` when empty or None.

    Returns:
      The task address, prefixed with `<rpc_layer>://` when an RPC layer is
      available, or an empty string if the task type or task id is unknown.
    """
    task_type = task_type if task_type is not None else self._task_type
    task_id = task_id if task_id is not None else self._task_id

    if task_type is not None and task_id is not None:
      master = self.cluster_spec().task_address(task_type, task_id)
      if rpc_layer or self._rpc_layer:
        return '%s://%s' % (rpc_layer or self._rpc_layer, master)
      else:
        return master

    return ''

  @property
  def task_type(self):
    """The TensorFlow job name given at construction time."""
    return self._task_type

  @property
  def task_id(self):
    """The task index of this VM within the instance group."""
    return self._task_id

  @task_type.setter
  def task_type(self, task_type):
    """Rejects any attempt to change the task type after construction."""
    raise RuntimeError(
        'You cannot reset the task_type of the GCEClusterResolver after it has '
        'been created.')

  @task_id.setter
  def task_id(self, task_id):
    """Updates the task index used as the default by `master`."""
    self._task_id = task_id

  @property
  def rpc_layer(self):
    """The RPC layer used as the address prefix by `master`."""
    return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    """Updates the RPC layer used as the address prefix by `master`."""
    self._rpc_layer = rpc_layer