# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training utilities for Estimator to use Distribute Coordinator."""
import copy
import six
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
# pylint: disable=protected-access
CHIEF = dc._TaskType.CHIEF
EVALUATOR = dc._TaskType.EVALUATOR
PS = dc._TaskType.PS
WORKER = dc._TaskType.WORKER
# pylint: enable=protected-access


def _count_ps(cluster_spec):
  """Counts the number of parameter servers in cluster_spec."""
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_count_ps` does not expect empty cluster_spec.')

  return len(cluster_spec.as_dict().get(PS, []))


def _count_worker(cluster_spec, chief_task_type):
  """Counts the number of workers (including chief) in cluster_spec."""
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_count_worker` does not expect empty cluster_spec.')

  return (len(cluster_spec.as_dict().get(WORKER, [])) + len(
      cluster_spec.as_dict().get(chief_task_type, [])))


def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
  """Returns the global id of the given task type in a cluster."""
  if not task_type:
    return 0

  # Sort task names in cluster by "chief"/"master", "evaluator", "worker"
  # and "ps". More details can be found at the documentation of
  # `tf.estimator.RunConfig.global_id_in_cluster`.
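  # For example, with an illustrative cluster of 1 chief, 2 workers and 1 ps,
  # the ordered list is ["chief", "worker", "ps"], so the global ids are:
  # chief -> 0, worker 0 -> 1, worker 1 -> 2, ps 0 -> 3.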
  task_type_ordered_list = []
  if chief_task_type in cluster_spec.jobs:
    task_type_ordered_list = [chief_task_type]
  task_type_ordered_list.extend([
      t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS
  ])
  if PS in cluster_spec.jobs:
    task_type_ordered_list.append(PS)

  # Find the right global_id for the current task.
  next_global_id = 0
  for t in task_type_ordered_list:
    if t == task_type:
      return next_global_id + task_id
    # `cluster_spec.job_tasks` returns all task addresses of type `t`.
    next_global_id += len(cluster_spec.job_tasks(t))

  # Reaching here is unexpected: it means `task_type` was not found in
  # `task_type_ordered_list`.
  raise RuntimeError('Internal Error: `task_type` ({}) is not in '
                     'cluster_spec ({}).'.format(task_type, cluster_spec))


def _init_run_config_from_worker_context(config, worker_context):
  """Initializes run config from distribute coordinator's worker context."""
  # pylint: disable=protected-access
  config._service = None
  config._cluster_spec = worker_context.cluster_spec
  config._task_type = worker_context.task_type
  config._task_id = worker_context.task_id
  config._evaluation_master = worker_context.master_target
  config._master = worker_context.master_target
  config._is_chief = worker_context.is_chief

  if config._cluster_spec:
    # Distributed mode.
    if config._task_type != EVALUATOR:
      config._num_ps_replicas = _count_ps(config._cluster_spec)
      config._num_worker_replicas = _count_worker(
          config._cluster_spec, chief_task_type=CHIEF)
      config._global_id_in_cluster = _get_global_id(
          config._cluster_spec,
          config._task_type,
          config._task_id,
          chief_task_type=CHIEF)
    else:
      # Evaluator task should not be aware of the other tasks.
      config._cluster_spec = server_lib.ClusterSpec({})
      config._num_ps_replicas = 0
      config._num_worker_replicas = 0
      config._global_id_in_cluster = None  # undefined
  else:
    # Local mode.
    config._global_id_in_cluster = 0
    config._num_ps_replicas = 0
    config._num_worker_replicas = 1


def init_run_config(config, tf_config):
  """Initializes RunConfig for distribution strategies."""
  # pylint: disable=protected-access
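  # `tf_config` is the parsed TF_CONFIG environment variable. A typical value
  # looks like the following (the host addresses are only illustrative):
  #   {
  #       "cluster": {"chief": ["host0:2222"], "worker": ["host1:2222"]},
  #       "task": {"type": "worker", "index": 0}
  #   }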
  if (config._experimental_distribute and
      config._experimental_distribute.train_distribute):
    if config._train_distribute:
      raise ValueError('Either `train_distribute` or '
                       '`experimental_distribute.train_distribute` can be '
                       'set.')
    config._train_distribute = config._experimental_distribute.train_distribute

  if (config._experimental_distribute and
      config._experimental_distribute.eval_distribute):
    if config._eval_distribute:
      raise ValueError('Either `eval_distribute` or '
                       '`experimental_distribute.eval_distribute` can be '
                       'set.')
    config._eval_distribute = config._experimental_distribute.eval_distribute

  cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
  config._init_distributed_setting_from_environment_var({})

  # Use distribute coordinator with STANDALONE_CLIENT mode if
  # `experimental_distribute.remote_cluster` is set.
  if (config._train_distribute and config._experimental_distribute and
      config._experimental_distribute.remote_cluster):
    if cluster_spec:
      raise ValueError('Cannot set both "cluster_spec" of TF_CONFIG and '
                       '`experimental_distribute.remote_cluster`')
    config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT
    config._cluster_spec = config._experimental_distribute.remote_cluster
    logging.info('RunConfig initialized for Distribute Coordinator with '
                 'STANDALONE_CLIENT mode')
    return

  # Don't use distribute coordinator if it is local training, the cluster has
  # a MASTER job, or `train_distribute` is not specified.
  if (not cluster_spec or 'master' in cluster_spec.jobs or
      not config._train_distribute):
    config._distribute_coordinator_mode = None
    config._init_distributed_setting_from_environment_var(tf_config)
    config._maybe_overwrite_session_config_for_distributed_training()
    logging.info('Not using Distribute Coordinator.')
    return

  # Use distribute coordinator with INDEPENDENT_WORKER mode otherwise.
  assert tf_config

  # Set the cluster_spec only since the distributed setting will come from
  # distribute coordinator.
  config._cluster_spec = cluster_spec
  config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER
  logging.info('RunConfig initialized for Distribute Coordinator with '
               'INDEPENDENT_WORKER mode')


def should_run_distribute_coordinator(config):
  """Checks the config to see whether to run distribute coordinator."""
  # pylint: disable=protected-access
  if (not hasattr(config, '_distribute_coordinator_mode') or
      config._distribute_coordinator_mode is None):
    logging.info('Not using Distribute Coordinator.')
    return False
  if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
      config._distribute_coordinator_mode not in [
          dc.CoordinatorMode.STANDALONE_CLIENT,
          dc.CoordinatorMode.INDEPENDENT_WORKER
      ]):
    logging.warning('Unexpected distribute_coordinator_mode: %r',
                    config._distribute_coordinator_mode)
    return False
  if not config.cluster_spec:
    logging.warning('Running `train_and_evaluate` locally, ignoring '
                    '`experimental_distribute_coordinator_mode`.')
    return False
  return True


def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
  """Run distribute coordinator for Estimator's `train_and_evaluate`.

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: An `EvalSpec` instance to specify the evaluation and export
      specification.
    executor_cls: the evaluation executor class of Estimator.

  Raises:
    ValueError: if `distribute_coordinator_mode` is None in RunConfig.
  """
  run_config = estimator.config
  if not run_config._distribute_coordinator_mode:  # pylint: disable=protected-access
    raise ValueError(
        'Distribute coordinator mode is not specified in `RunConfig`.')

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    # In the standalone client, we don't need to run hooks on all threads
    # because logging hooks on all threads may be too much on the screen; also
    # a tensor passed to one hook can only be fetched with the graph where the
    # tensor is defined. Other hooks such as checkpointing hooks will be added
    # by MonitoredTrainingSession.
    # TODO(yuefengz): Is there a hook that does need to run on all threads in
    # standalone client mode?
    if (run_config._distribute_coordinator_mode ==  # pylint: disable=protected-access
        dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
      hooks = list(train_spec.hooks)
    else:
      hooks = []

    # Prevent estimator.train from calling distribute coordinator again. This
    # function calls estimator.train which will use the distribute coordinator
    # path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access

    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=hooks)

  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    # Prevent estimator.evaluate from calling distribute coordinator again.
    # This function calls estimator.evaluate which will use the distribute
    # coordinator path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
    # pylint: enable=protected-access

  # pylint: disable=protected-access
  if (run_config._distribute_coordinator_mode ==
      dc.CoordinatorMode.STANDALONE_CLIENT):
    cluster_spec = run_config.cluster_spec
    assert cluster_spec
  else:
    # The cluster_spec comes from the TF_CONFIG environment variable in
    # INDEPENDENT_WORKER mode.
    cluster_spec = None
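  # Distribute coordinator runs `_worker_fn` for the worker-type tasks and
  # `_eval_fn` for the evaluator task (if any), according to the coordinator
  # mode and the cluster spec.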
  dc.run_distribute_coordinator(
      _worker_fn,
      run_config.train_distribute,
      _eval_fn,
      run_config.eval_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)


# TODO(yuefengz): maybe merge the following two functions?
# pylint: disable=protected-access
def estimator_train(estimator, train_distributed_fn, hooks):
  """Run distribute coordinator for Estimator's `train` method."""
  assert estimator._config._distribute_coordinator_mode
  run_config = estimator._config
  assert estimator._config.cluster_spec
  cluster_spec = multi_worker_util.normalize_cluster_spec(
      estimator._config.cluster_spec)
  assert estimator._config._train_distribute

  if 'evaluator' in cluster_spec.jobs:
    raise ValueError("'evaluator' job is not supported if you don't use "
                     '`train_and_evaluate`')

  if (estimator._config._distribute_coordinator_mode !=  # pylint: disable=protected-access
      dc.CoordinatorMode.STANDALONE_CLIENT):
    raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call '
                     '`estimator.train`')

  if estimator._config._train_distribute.extended.experimental_between_graph:
    # TODO(yuefengz): remove this limitation once we figure out how to merge
    # return values from `_worker_fn`s.
    raise ValueError('`Estimator.train` API is not supported for %s with '
                     '`STANDALONE_CLIENT` mode.' %
                     estimator._config._train_distribute.__class__.__name__)

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy

    if context.is_chief:
      chief_hooks = hooks
    else:
      chief_hooks = []
    train_distributed_fn(local_estimator, strategy, chief_hooks)
    return local_estimator

  return dc.run_distribute_coordinator(
      _worker_fn,
      estimator._config.train_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)


def estimator_evaluate(estimator, evaluate_distributed_fn, hooks):
  """Run distribute coordinator for Estimator's `evaluate` method."""
  assert estimator._config._distribute_coordinator_mode
  run_config = estimator._config
  assert estimator._config.cluster_spec
  cluster_spec = multi_worker_util.normalize_cluster_spec(
      estimator._config.cluster_spec)
  assert estimator._config._eval_distribute

  if 'evaluator' in cluster_spec.jobs:
    raise ValueError("'evaluator' job is not supported if you don't use "
                     '`train_and_evaluate`')

  if (estimator._config._distribute_coordinator_mode !=
      dc.CoordinatorMode.STANDALONE_CLIENT):
    raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call '
                     '`Estimator.evaluate`')

  if estimator._config._eval_distribute.extended.experimental_between_graph:
    # TODO(yuefengz): remove this limitation once we figure out how to merge
    # return values from `_worker_fn`s.
    raise ValueError('`Estimator.evaluate` API is not supported for %s with '
                     '`STANDALONE_CLIENT` mode.' %
                     estimator._config._eval_distribute.__class__.__name__)

  def _worker_fn(strategy):
    """Function for evaluation."""
    local_estimator = copy.deepcopy(estimator)
    local_estimator._config._eval_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    if context.is_chief:
      chief_hooks = hooks
    else:
      chief_hooks = []
    return evaluate_distributed_fn(local_estimator, strategy, chief_hooks)

  return dc.run_distribute_coordinator(
      _worker_fn,
      estimator._config.eval_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)
# pylint: enable=protected-access