| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Implementation of Cluster Resolvers for TF_CONFIG Environment Variables.""" |
| |
| |
| import json |
| import os |
| |
| from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver |
| from tensorflow.python.training.server_lib import ClusterSpec |
| from tensorflow.python.util.tf_export import tf_export |
| |
| _TF_CONFIG_ENV = 'TF_CONFIG' |
| _SESSION_MASTER_KEY = 'session_master' |
| _RPC_LAYER_KEY = 'rpc_layer' |
| _TASK_KEY = 'task' |
| |
| |
| def format_master_url(master, rpc_layer=None): |
| if rpc_layer: |
| return '%s://%s' % (rpc_layer, master) |
| else: |
| return master |
| |
| |
| def _load_tf_config(): |
| return json.loads(os.environ.get(_TF_CONFIG_ENV, '{}')) |
| |
| |
| def _get_value_in_tfconfig(key, default=None): |
| tf_config = _load_tf_config() |
| return tf_config[key] if key in tf_config else default |
| |
| |
| @tf_export('distribute.cluster_resolver.TFConfigClusterResolver') |
| class TFConfigClusterResolver(ClusterResolver): |
| """Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar. |
| |
| This is an implementation of cluster resolvers when using TF_CONFIG to set |
| information about the cluster. The cluster spec returned will be |
| initialized from the TF_CONFIG environment variable. |
| |
| An example to set TF_CONFIG is: |
| |
| ```Python |
| os.environ['TF_CONFIG'] = json.dumps({ |
| 'cluster': { |
| 'worker': ["localhost:12345", "localhost:23456"] |
| }, |
| 'task': {'type': 'worker', 'index': 0} |
| }) |
| ``` |
| |
| However, sometimes the container orchestration framework will set TF_CONFIG |
| for you. In this case, you can just create an instance without passing in any |
| arguments. You can find an example here to let Kuburnetes set TF_CONFIG for |
| you: https://github.com/tensorflow/ecosystem/tree/master/kubernetes. Then you |
| can use it with `tf.distribute.Strategy` as: |
| |
| ```Python |
| # `TFConfigClusterResolver` is already the default one in the following |
| # strategy. |
| strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy( |
| cluster_resolver=TFConfigClusterResolver()) |
| ``` |
| """ |
| |
| def __init__(self, |
| task_type=None, |
| task_id=None, |
| rpc_layer=None, |
| environment=None): |
| """Creates a new TFConfigClusterResolver. |
| |
| Args: |
| task_type: (String, optional) Overrides the task type specified in the |
| TF_CONFIG environment variable. |
| task_id: (Integer, optional) Overrides the task index specified in the |
| TF_CONFIG environment variable. |
| rpc_layer: (String, optional) Overrides the rpc layer TensorFlow uses. |
| environment: (String, optional) Overrides the environment TensorFlow |
| operates in. |
| """ |
| self._task_type = task_type |
| self._task_id = task_id |
| self._rpc_layer = rpc_layer |
| self._environment = environment |
| |
| @property |
| def task_type(self): |
| if self._task_type is None: |
| task_info = _get_value_in_tfconfig(_TASK_KEY, {}) |
| return str(task_info['type']) if 'type' in task_info else None |
| else: |
| return str(self._task_type) |
| |
| @property |
| def task_id(self): |
| if self._task_id is None: |
| task_info = _get_value_in_tfconfig(_TASK_KEY, {}) |
| return int(task_info['index']) if 'index' in task_info else None |
| else: |
| return int(self._task_id) |
| |
| @task_type.setter |
| def task_type(self, task_type): |
| self._task_type = task_type |
| |
| @task_id.setter |
| def task_id(self, task_id): |
| self._task_id = task_id |
| |
| @property |
| def environment(self): |
| return self._environment |
| |
| @property |
| def rpc_layer(self): |
| if self._rpc_layer is None: |
| return _get_value_in_tfconfig(_RPC_LAYER_KEY) |
| else: |
| return self._rpc_layer |
| |
| @rpc_layer.setter |
| def rpc_layer(self, rpc_layer): |
| self._rpc_layer = rpc_layer |
| |
| def num_accelerators(self, |
| task_type=None, |
| task_id=None, |
| config_proto=None): |
| task_type = self.task_type if task_type is None else task_type |
| task_id = self.task_id if task_id is None else task_id |
| return super(TFConfigClusterResolver, self).num_accelerators( |
| task_type, task_id, config_proto) |
| |
| def cluster_spec(self): |
| """Returns a ClusterSpec based on the TF_CONFIG environment variable. |
| |
| Returns: |
| A ClusterSpec with information from the TF_CONFIG environment variable. |
| """ |
| tf_config = _load_tf_config() |
| if 'cluster' not in tf_config: |
| return ClusterSpec({}) |
| return ClusterSpec(tf_config['cluster']) |
| |
| def master(self, task_type=None, task_id=None, rpc_layer=None): |
| """Returns the master address to use when creating a TensorFlow session. |
| |
| Note: this is only useful for TensorFlow 1.x. |
| |
| Args: |
| task_type: (String, optional) Overrides and sets the task_type of the |
| master. |
| task_id: (Integer, optional) Overrides and sets the task id of the |
| master. |
| rpc_layer: (String, optional) Overrides and sets the protocol over which |
| TensorFlow nodes communicate with each other. |
| |
| Returns: |
| The address of the master. |
| |
| Raises: |
| RuntimeError: If the task_type or task_id is not specified and the |
| `TF_CONFIG` environment variable does not contain a task section. |
| """ |
| |
| # If `session_master` is set, just use that. |
| session_master = _get_value_in_tfconfig(_SESSION_MASTER_KEY) |
| if session_master is not None: |
| return session_master |
| |
| # Return an empty string if we are the only job in the ClusterSpec. |
| cluster_spec = self.cluster_spec() |
| if (not cluster_spec.jobs or |
| (len(cluster_spec.jobs) == 1 and |
| len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)): |
| return '' |
| |
| # We try to auto-detect the task type and id, but uses the user-supplied one |
| # where available |
| task_type = task_type if task_type is not None else self.task_type |
| task_id = task_id if task_id is not None else self.task_id |
| rpc_layer = rpc_layer if rpc_layer is not None else self.rpc_layer |
| |
| return format_master_url(cluster_spec.task_address(task_type, task_id), |
| rpc_layer) |