| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Training utilities for Estimator to use Distribute Coordinator.""" |
| |
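| # A rough, illustrative sketch of how Estimator internals are expected to use |
| # this module (`run_config`, `estimator`, `train_spec`, `eval_spec` and |
| # `executor_cls` are placeholders for objects owned by Estimator): |
| # |
| #   tf_config = json.loads(os.environ.get('TF_CONFIG', '{}')) |
| #   init_run_config(run_config, tf_config) |
| #   ... |
| #   if should_run_distribute_coordinator(run_config): |
| #     train_and_evaluate(estimator, train_spec, eval_spec, executor_cls) |
| |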
| import copy |
| |
| import six |
| |
| from tensorflow.python.distribute import distribute_coordinator as dc |
| from tensorflow.python.distribute import distribute_coordinator_context as dc_context |
| from tensorflow.python.distribute import multi_worker_util |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.training import server_lib |
| |
| # pylint: disable=protected-access |
| CHIEF = dc._TaskType.CHIEF |
| EVALUATOR = dc._TaskType.EVALUATOR |
| PS = dc._TaskType.PS |
| WORKER = dc._TaskType.WORKER |
| |
| # pylint: enable=protected-access |
| |
| |
| def _count_ps(cluster_spec): |
| """Counts the number of parameter servers in cluster_spec.""" |
| if not cluster_spec: |
| raise RuntimeError( |
| 'Internal error: `_count_ps` does not expect empty cluster_spec.') |
| |
| return len(cluster_spec.as_dict().get(PS, [])) |
| |
| |
| def _count_worker(cluster_spec, chief_task_type): |
| """Counts the number of workers (including chief) in cluster_spec.""" |
| if not cluster_spec: |
| raise RuntimeError( |
| 'Internal error: `_count_worker` does not expect empty cluster_spec.') |
| |
| return (len(cluster_spec.as_dict().get(WORKER, [])) + |
| len(cluster_spec.as_dict().get(chief_task_type, []))) |
| |
| |
| def _get_global_id(cluster_spec, task_type, task_id, chief_task_type): |
| """Returns the global id of the given task type in a cluster.""" |
| if not task_type: |
| return 0 |
| |
| # Sort task names in cluster by "chief"/"master", "evaluator", "worker" |
| # and "ps". More details can be found at the documentation of |
| # `tf.estimator.RunConfig.global_id_in_cluster`. |
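| # For example (illustrative cluster): with 1 chief, 2 workers and 1 ps, the |
| # resulting global ids are chief -> 0, workers -> 1 and 2, ps -> 3. |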
| task_type_ordered_list = [] |
| if chief_task_type in cluster_spec.jobs: |
| task_type_ordered_list = [chief_task_type] |
| task_type_ordered_list.extend([ |
| t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS |
| ]) |
| if PS in cluster_spec.jobs: |
| task_type_ordered_list.append(PS) |
| |
| # Find the right global_id for current task. |
| next_global_id = 0 |
| for t in task_type_ordered_list: |
| if t == task_type: |
| return next_global_id + task_id |
| # `cluster_spec.job_tasks` returns all task addresses of type `t`. |
| next_global_id += len(cluster_spec.job_tasks(t)) |
| |
| # We should never fall through the loop above: `task_type` must appear in |
| # `task_type_ordered_list`. |
| raise RuntimeError('Internal error: `task_type` ({}) is not in ' |
| 'cluster_spec ({}).'.format(task_type, cluster_spec)) |
| |
| |
| def _init_run_config_from_worker_context(config, worker_context): |
| """Initializes run config from distribute coordinator's worker context.""" |
| |
| # pylint: disable=protected-access |
| config._service = None |
| config._cluster_spec = worker_context.cluster_spec |
| config._task_type = worker_context.task_type |
| config._task_id = worker_context.task_id |
| config._evaluation_master = worker_context.master_target |
| config._master = worker_context.master_target |
| config._is_chief = worker_context.is_chief |
| |
| if config._cluster_spec: |
| # Distributed mode. |
| if config._task_type != EVALUATOR: |
| config._num_ps_replicas = _count_ps(config._cluster_spec) |
| config._num_worker_replicas = _count_worker( |
| config._cluster_spec, chief_task_type=CHIEF) |
| config._global_id_in_cluster = _get_global_id( |
| config._cluster_spec, |
| config._task_type, |
| config._task_id, |
| chief_task_type=CHIEF) |
| else: |
| # Evaluator task should not be aware of the other tasks. |
| config._cluster_spec = server_lib.ClusterSpec({}) |
| config._num_ps_replicas = 0 |
| config._num_worker_replicas = 0 |
| config._global_id_in_cluster = None # undefined |
| else: |
| # Local mode. |
| config._global_id_in_cluster = 0 |
| config._num_ps_replicas = 0 |
| config._num_worker_replicas = 1 |
| |
| |
| def init_run_config(config, tf_config): |
| """Initializes RunConfig for distribution strategies.""" |
| # pylint: disable=protected-access |
| if (config._experimental_distribute and |
| config._experimental_distribute.train_distribute): |
| if config._train_distribute: |
| raise ValueError('Only one of `train_distribute` and ' |
| '`experimental_distribute.train_distribute` can be set.') |
| config._train_distribute = config._experimental_distribute.train_distribute |
| |
| if (config._experimental_distribute and |
| config._experimental_distribute.eval_distribute): |
| if config._eval_distribute: |
| raise ValueError('Only one of `eval_distribute` and ' |
| '`experimental_distribute.eval_distribute` can be set.') |
| config._eval_distribute = config._experimental_distribute.eval_distribute |
| |
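| # `tf_config` is the parsed TF_CONFIG environment variable. An illustrative |
| # value (hypothetical addresses): |
| #   {"cluster": {"chief": ["host0:2222"], "worker": ["host1:2222"]}, |
| #    "task": {"type": "worker", "index": 0}} |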
| cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {})) |
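| # Reset the distributed settings first; they are filled in again below from |
| # `tf_config`, or later by the distribute coordinator's worker context. |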
| config._init_distributed_setting_from_environment_var({}) |
| |
| # Use distribute coordinator with STANDALONE_CLIENT mode if |
| # `experimental_distribute.remote_cluster` is set. |
| if (config._train_distribute and config._experimental_distribute and |
| config._experimental_distribute.remote_cluster): |
| if cluster_spec: |
| raise ValueError('Cannot set both "cluster_spec" of TF_CONFIG and ' |
| '`experimental_distribute.remote_cluster`') |
| config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT |
| config._cluster_spec = config._experimental_distribute.remote_cluster |
| logging.info('RunConfig initialized for Distribute Coordinator with ' |
| 'STANDALONE_CLIENT mode') |
| return |
| |
| # Don't use the distribute coordinator for local training, when the cluster |
| # has a MASTER job, or when `train_distribute` is not specified. |
| if (not cluster_spec or 'master' in cluster_spec.jobs or |
| not config._train_distribute): |
| config._distribute_coordinator_mode = None |
| config._init_distributed_setting_from_environment_var(tf_config) |
| config._maybe_overwrite_session_config_for_distributed_training() |
| logging.info('Not using Distribute Coordinator.') |
| return |
| |
| # Use distribute coordinator with INDEPENDENT_WORKER mode otherwise. |
| assert tf_config |
| |
| # Only set the cluster_spec here; the remaining distributed settings will |
| # come from the distribute coordinator. |
| config._cluster_spec = cluster_spec |
| config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER |
| logging.info('RunConfig initialized for Distribute Coordinator with ' |
| 'INDEPENDENT_WORKER mode') |
| |
| |
| def should_run_distribute_coordinator(config): |
| """Checks the config to see whether to run distribute coordinator.""" |
| # pylint: disable=protected-access |
| if (not hasattr(config, '_distribute_coordinator_mode') or |
| config._distribute_coordinator_mode is None): |
| logging.info('Not using Distribute Coordinator.') |
| return False |
| if (not isinstance(config._distribute_coordinator_mode, six.string_types) or |
| config._distribute_coordinator_mode not in [ |
| dc.CoordinatorMode.STANDALONE_CLIENT, |
| dc.CoordinatorMode.INDEPENDENT_WORKER |
| ]): |
| logging.warning('Unexpected distribute_coordinator_mode: %r', |
| config._distribute_coordinator_mode) |
| return False |
| if not config.cluster_spec: |
| logging.warning('Running `train_and_evaluate` locally, ignoring ' |
| '`experimental_distribute_coordinator_mode`.') |
| return False |
| return True |
| |
| |
| def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls): |
| """Run distribute coordinator for Estimator's `train_and_evaluate`. |
| |
| Args: |
| estimator: An `Estimator` instance to train and evaluate. |
| train_spec: A `TrainSpec` instance to specify the training specification. |
| eval_spec: An `EvalSpec` instance to specify the evaluation and export |
| specification. |
| executor_cls: The evaluation executor class of Estimator. |
| |
| Raises: |
| ValueError: if `distribute_coordinator_mode` is None in RunConfig. |
| """ |
| run_config = estimator.config |
| if not run_config._distribute_coordinator_mode: # pylint: disable=protected-access |
| raise ValueError( |
| 'Distribute coordinator mode is not specified in `RunConfig`.') |
| |
| def _worker_fn(strategy): |
| """Function for worker task.""" |
| local_estimator = copy.deepcopy(estimator) |
| # pylint: disable=protected-access |
| local_estimator._config._train_distribute = strategy |
| context = dc_context.get_current_worker_context() |
| _init_run_config_from_worker_context(local_estimator._config, context) |
| logging.info('Updated config: %s', str(vars(local_estimator._config))) |
| local_estimator._train_distribution = strategy |
| # pylint: enable=protected-access |
| |
| # In standalone client mode, we don't need to run hooks on all threads |
| # because running logging hooks on every thread would flood the screen; |
| # also, a tensor passed to a hook can only be fetched from the graph in |
| # which the tensor is defined. Other hooks, such as checkpointing hooks, |
| # will be added by MonitoredTrainingSession. |
| # TODO(yuefengz): Is there a hook that does need to run on all threads in |
| # standalone client mode? |
| if (run_config._distribute_coordinator_mode == # pylint: disable=protected-access |
| dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief): |
| hooks = list(train_spec.hooks) |
| else: |
| hooks = [] |
| |
| # Prevent estimator.train from calling the distribute coordinator again: |
| # this function calls estimator.train, which would go down the distribute |
| # coordinator path again if `_distribute_coordinator_mode` were still set. |
| local_estimator._config._distribute_coordinator_mode = None # pylint: disable=protected-access |
| local_estimator.train( |
| input_fn=train_spec.input_fn, |
| max_steps=train_spec.max_steps, |
| hooks=hooks) |
| |
| def _eval_fn(strategy): |
| """Function for evaluator task.""" |
| local_estimator = copy.deepcopy(estimator) |
| # pylint: disable=protected-access |
| local_estimator._config._eval_distribute = strategy |
| _init_run_config_from_worker_context( |
| local_estimator._config, dc_context.get_current_worker_context()) |
| logging.info('Updated config: %s', str(vars(local_estimator._config))) |
| local_estimator._eval_distribution = strategy |
| |
| # Prevent estimator.evaluate from calling the distribute coordinator again: |
| # this function calls estimator.evaluate, which would go down the distribute |
| # coordinator path again if `_distribute_coordinator_mode` were still set. |
| local_estimator._config._distribute_coordinator_mode = None # pylint: disable=protected-access |
| |
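| # The executor (constructed from Estimator's evaluation executor class) runs |
| # continuous evaluation on the evaluator task, evaluating checkpoints as the |
| # training job writes them. |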
| executor = executor_cls(local_estimator, train_spec, eval_spec) |
| executor._start_continuous_evaluation() |
| # pylint: enable=protected-access |
| |
| # pylint: disable=protected-access |
| if (run_config._distribute_coordinator_mode == |
| dc.CoordinatorMode.STANDALONE_CLIENT): |
| cluster_spec = run_config.cluster_spec |
| assert cluster_spec |
| else: |
| # In INDEPENDENT_WORKER mode, the cluster_spec comes from the TF_CONFIG |
| # environment variable. |
| cluster_spec = None |
| |
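| # Run the distribute coordinator: it invokes `_worker_fn` on worker/chief |
| # tasks and `_eval_fn` on the evaluator task, according to the chosen |
| # coordinator mode. |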
| dc.run_distribute_coordinator( |
| _worker_fn, |
| run_config.train_distribute, |
| _eval_fn, |
| run_config.eval_distribute, |
| mode=run_config._distribute_coordinator_mode, |
| cluster_spec=cluster_spec, |
| session_config=run_config.session_config) |
| |
| |
| # TODO(yuefengz): maybe merge the following two functions? |
| # pylint: disable=protected-access |
| def estimator_train(estimator, train_distributed_fn, hooks): |
| """Run distribute coordinator for Estimator's `train` method.""" |
| assert estimator._config._distribute_coordinator_mode |
| run_config = estimator._config |
| assert estimator._config.cluster_spec |
| cluster_spec = multi_worker_util.normalize_cluster_spec( |
| estimator._config.cluster_spec) |
| assert estimator._config._train_distribute |
| |
| if 'evaluator' in cluster_spec.jobs: |
| raise ValueError("'evaluator' job is not supported if you don't use " |
| '`train_and_evaluate`') |
| |
| if (estimator._config._distribute_coordinator_mode != # pylint: disable=protected-access |
| dc.CoordinatorMode.STANDALONE_CLIENT): |
| raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call ' |
| '`Estimator.train`') |
| |
| if estimator._config._train_distribute.extended.experimental_between_graph: |
| # TODO(yuefengz): remove this limitation once we figure out how to merge |
| # return values from `_worker_fn`s. |
| raise ValueError('`Estimator.train` API is not supported for %s with ' |
| '`STANDALONE_CLIENT` mode.' % |
| estimator._config._train_distribute.__class__.__name__) |
| |
| def _worker_fn(strategy): |
| """Function for worker task.""" |
| local_estimator = copy.deepcopy(estimator) |
| local_estimator._config._train_distribute = strategy |
| context = dc_context.get_current_worker_context() |
| _init_run_config_from_worker_context(local_estimator._config, context) |
| logging.info('Updated config: %s', str(vars(local_estimator._config))) |
| local_estimator._train_distribution = strategy |
| |
| if context.is_chief: |
| chief_hooks = hooks |
| else: |
| chief_hooks = [] |
| train_distributed_fn(local_estimator, strategy, chief_hooks) |
| return local_estimator |
| |
| return dc.run_distribute_coordinator( |
| _worker_fn, |
| estimator._config.train_distribute, |
| mode=run_config._distribute_coordinator_mode, |
| cluster_spec=cluster_spec, |
| session_config=run_config.session_config) |
| |
| |
| def estimator_evaluate(estimator, evaluate_distributed_fn, hooks): |
| """Run distribute coordinator for Estimator's `evaluate` method.""" |
| assert estimator._config._distribute_coordinator_mode |
| run_config = estimator._config |
| assert estimator._config.cluster_spec |
| cluster_spec = multi_worker_util.normalize_cluster_spec( |
| estimator._config.cluster_spec) |
| assert estimator._config._eval_distribute |
| |
| if 'evaluator' in cluster_spec.jobs: |
| raise ValueError("'evaluator' job is not supported if you don't use " |
| '`train_and_evaluate`') |
| |
| if (estimator._config._distribute_coordinator_mode != |
| dc.CoordinatorMode.STANDALONE_CLIENT): |
| raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call ' |
| '`Estimator.evaluate`') |
| |
| if estimator._config._eval_distribute.extended.experimental_between_graph: |
| # TODO(yuefengz): remove this limitation once we figure out how to merge |
| # return values from `_worker_fn`s. |
| raise ValueError('`Estimator.evaluate` API is not supported for %s with ' |
| '`STANDALONE_CLIENT` mode.' % |
| estimator._config._eval_distribute.__class__.__name__) |
| |
| def _worker_fn(strategy): |
| """Function for evaluation.""" |
| local_estimator = copy.deepcopy(estimator) |
| local_estimator._config._eval_distribute = strategy |
| context = dc_context.get_current_worker_context() |
| _init_run_config_from_worker_context(local_estimator._config, context) |
| logging.info('Updated config: %s', str(vars(local_estimator._config))) |
| local_estimator._eval_distribution = strategy |
| |
| if context.is_chief: |
| chief_hooks = hooks |
| else: |
| chief_hooks = [] |
| return evaluate_distributed_fn(local_estimator, strategy, chief_hooks) |
| |
| return dc.run_distribute_coordinator( |
| _worker_fn, |
| estimator._config.eval_distribute, |
| mode=run_config._distribute_coordinator_mode, |
| cluster_spec=cluster_spec, |
| session_config=run_config.session_config) |
| |
| # pylint: enable=protected-access |