# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base testing class for strategies that require multiple nodes."""
import contextlib
import copy
import json
import os
import subprocess
import sys
import threading
import unittest
import six
# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export
original_run_std_server = dc._run_std_server # pylint: disable=protected-access
pick_unused_port = test_util.pick_unused_port
def _create_cluster(num_workers,
num_ps,
has_chief=False,
has_eval=False,
protocol='grpc',
worker_config=None,
ps_config=None,
eval_config=None,
worker_name='worker',
ps_name='ps',
chief_name='chief'):
"""Creates and starts local servers and returns the cluster_spec dict."""
worker_ports = [pick_unused_port() for _ in range(num_workers)]
ps_ports = [pick_unused_port() for _ in range(num_ps)]
cluster_dict = {}
if num_workers > 0:
cluster_dict[worker_name] = ['localhost:%s' % port for port in worker_ports]
if num_ps > 0:
cluster_dict[ps_name] = ['localhost:%s' % port for port in ps_ports]
if has_eval:
cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
if has_chief:
cluster_dict[chief_name] = ['localhost:%s' % pick_unused_port()]
cs = server_lib.ClusterSpec(cluster_dict)
for i in range(num_workers):
server_lib.Server(
cs,
job_name=worker_name,
protocol=protocol,
task_index=i,
config=worker_config,
start=True)
for i in range(num_ps):
server_lib.Server(
cs,
job_name=ps_name,
protocol=protocol,
task_index=i,
config=ps_config,
start=True)
if has_chief:
server_lib.Server(
cs,
job_name=chief_name,
protocol=protocol,
task_index=0,
config=worker_config,
start=True)
if has_eval:
server_lib.Server(
cs,
job_name='evaluator',
protocol=protocol,
task_index=0,
config=eval_config,
start=True)
return cluster_dict
def create_in_process_cluster(num_workers,
num_ps,
has_chief=False,
has_eval=False,
rpc_layer='grpc'):
"""Create an in-process cluster that consists of only standard server."""
# Leave some memory for cuda runtime.
gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
worker_config = config_pb2.ConfigProto()
worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
# The cluster may hang if workers don't have enough inter_op threads. See
# b/172296720 for more details.
if worker_config.inter_op_parallelism_threads < num_workers + 1:
worker_config.inter_op_parallelism_threads = num_workers + 1
# Enable collective ops, which has no impact on non-collective ops.
# TODO(yuefengz, tucker): remove this after we move the initialization of the
# collective mgr to the session level.
if has_chief:
worker_config.experimental.collective_group_leader = (
'/job:chief/replica:0/task:0')
else:
worker_config.experimental.collective_group_leader = (
'/job:worker/replica:0/task:0')
ps_config = config_pb2.ConfigProto()
ps_config.device_count['GPU'] = 0
eval_config = config_pb2.ConfigProto()
eval_config.experimental.collective_group_leader = ''
# Create in-process servers. Once an in-process TensorFlow server is created,
# there is no way to terminate it, so we create one cluster per test process.
# We could have started each server in a separate process and killed that
# process to terminate the server, but we avoid multiple processes because
# 1) it is more difficult to manage these processes;
# 2) there is something global in CUDA such that if we initialize CUDA in the
# parent process, the child process cannot initialize it again and thus cannot
# use GPUs (https://stackoverflow.com/questions/22950047).
cluster = None
try:
cluster = _create_cluster(
num_workers,
num_ps=num_ps,
has_chief=has_chief,
has_eval=has_eval,
worker_config=worker_config,
ps_config=ps_config,
eval_config=eval_config,
protocol=rpc_layer)
except errors.UnknownError as e:
if 'Could not start gRPC server' in e.message:
raise unittest.SkipTest('Cannot start std servers.')
else:
raise
return cluster
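# A hedged usage sketch for `create_in_process_cluster`, mirroring how
# `MultiWorkerTestBase.setUpClass` below uses it (the test class name is
# hypothetical):
#
#   class MyStrategyTest(test.TestCase):
#
#     @classmethod
#     def setUpClass(cls):
#       # Starts 2 worker servers and 1 PS server inside this test process and
#       # returns a dict mapping task types to 'host:port' addresses.
#       cls._cluster_spec = create_in_process_cluster(num_workers=2, num_ps=1)
#       cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]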
class MultiProcessCluster(object):
"""A cluster of TensorFlow servers in separate processes.
This class is not thread-safe.
"""
def __init__(self,
cluster_resolver,
stream_output=False,
collective_leader=None):
self._cluster_resolver = cluster_resolver
self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
self._rpc_layer = cluster_resolver.rpc_layer
self._stream_output = stream_output
self._start_events = {}
self._finish_events = {}
self._mpr_manager = multi_process_runner.manager()
def task_function(start_events, finish_events):
cluster_resolver = tfconfig_cluster_resolver.TFConfigClusterResolver()
cluster_spec = cluster_resolver.cluster_spec()
task_type = cluster_resolver.task_type
task_id = cluster_resolver.task_id
rpc_layer = cluster_resolver.rpc_layer
# TODO(yuefengz): support GPU clusters.
server_config = config_pb2.ConfigProto()
server_config.device_count['GPU'] = 0
if collective_leader:
server_config.experimental.collective_group_leader = collective_leader
server_config.experimental.collective_nccl = False
logging.info(
'Enabling collective ops with cluster_spec = %r, task_type = %r, '
'task_id = %r, rpc_layer = %r, collective_leader = %s',
cluster_spec, task_type, task_id, rpc_layer, collective_leader)
else:
logging.info(
'Starting server with cluster_spec = %r, task_type = %r, '
'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
rpc_layer)
server_lib.Server(
cluster_spec,
job_name=task_type,
protocol=rpc_layer,
task_index=task_id,
config=server_config,
start=True)
start_event = start_events[task_type][task_id]
start_event.set()
finish_event = finish_events[task_type][task_id]
finish_event.wait()
os._exit(0) # pylint: disable=protected-access
self._task_function = task_function
self._mpr = None
def start(self):
"""Starts one TensorFlow server for each task in the cluster_resolver.
It will wait until all the servers are up before returning.
"""
if self._mpr:
raise ValueError('The cluster has already been started.')
for task_type, task_addresses in self._cluster_spec.items():
self._start_events[task_type] = []
self._finish_events[task_type] = []
for _ in task_addresses:
self._start_events[task_type].append(self._mpr_manager.Event())
self._finish_events[task_type].append(self._mpr_manager.Event())
self._mpr = multi_process_runner.MultiProcessRunner(
self._task_function,
self._cluster_spec,
args=(self._start_events, self._finish_events),
rpc_layer=self._rpc_layer,
stream_output=self._stream_output,
return_output=False,
use_dill_for_args=False)
self._mpr.start()
for task_type, task_addresses in self._cluster_spec.items():
for i in range(len(task_addresses)):
self._start_events[task_type][i].wait()
def stop(self):
"""Stops all the servers."""
for task_type, task_addresses in self._cluster_spec.items():
for i in range(len(task_addresses)):
self._finish_events[task_type][i].set()
try:
self._mpr.join()
except multi_process_runner.UnexpectedSubprocessExitError:
# TODO(yuefengz): investigate why processes exit with 255.
pass
self._mpr = None
self._start_events = {}
self._finish_events = {}
def kill_task(self, task_type, task_id):
"""Kill a server given task_type and task_id.
Args:
task_type: the type of the task such as "worker".
task_id: the id the task such as 1.
"""
assert self._mpr
if (not self._start_events[task_type][task_id].is_set() or
self._finish_events[task_type][task_id].is_set()):
raise ValueError("The task %s:%d doesn't exist." % (task_type, task_id))
self._finish_events[task_type][task_id].set()
self._mpr._processes[(task_type, task_id)].join()
def start_task(self, task_type, task_id):
"""Starts a server given task_type and task_id.
Args:
task_type: the type of the task, such as "worker".
task_id: the id of the task, such as 1.
Raises:
ValueError: if the server already exists.
"""
assert self._mpr
if (not self._start_events[task_type][task_id].is_set() or
not self._finish_events[task_type][task_id].is_set()):
raise ValueError(
'The task %s:%d is still alive. You cannot start another one.' %
(task_type, task_id))
self._start_events[task_type][task_id] = self._mpr_manager.Event()
self._finish_events[task_type][task_id] = self._mpr_manager.Event()
self._mpr.start_single_process(task_type=task_type, task_id=task_id)
self._start_events[task_type][task_id].wait()
@property
def cluster_resolver(self):
return copy.deepcopy(self._cluster_resolver)
def create_multi_process_cluster(num_workers,
num_ps,
has_chief=False,
has_eval=False,
rpc_layer='grpc',
stream_output=False,
collective_leader=None):
logging.info('Now creating a MultiProcessCluster with '
f'num_workers={num_workers}, num_ps={num_ps}.')
cluster_spec = create_cluster_spec(
has_chief=has_chief,
num_workers=num_workers,
num_ps=num_ps,
has_eval=has_eval)
cluster = MultiProcessCluster(
cluster_resolver_lib.SimpleClusterResolver(
server_lib.ClusterSpec(cluster_spec), rpc_layer=rpc_layer),
stream_output=stream_output,
collective_leader=collective_leader)
cluster.start()
return cluster
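# A hedged usage sketch for `create_multi_process_cluster`: every task runs a
# standard TensorFlow server in its own subprocess, and tasks can be killed
# and restarted to simulate failures (the task type/id below are examples):
#
#   cluster = create_multi_process_cluster(num_workers=2, num_ps=1)
#   resolver = cluster.cluster_resolver  # a deep copy of the resolver
#   cluster.kill_task('worker', 0)       # simulate a worker failure
#   cluster.start_task('worker', 0)      # bring the worker back up
#   cluster.stop()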
@tf_export(
'__internal__.distribute.multi_process_runner.create_cluster_spec', v1=[])
def create_cluster_spec(has_chief=False,
num_workers=1,
num_ps=0,
has_eval=False):
"""Create a cluster spec with tasks with unused local ports.
This utility finds available ports at localhost, and returns a dict that
represents the cluster spec that utilizes those ports, according to the
arguments. The dict representing the cluster spec contains task types, and
their instances' addresses. Note that this is usually only for testing purpose
using multiple processes in the local machine, and should not be used for real
multi-worker TensorFlow programs, where the addresses need to point to the
processes at separate machines.
This util is useful when creating the `cluster_spec` arg for
`tf.__internal__.distribute.multi_process_runner.run`.
Args:
has_chief: Whether the generated cluster spec should contain "chief" task
type.
num_workers: Number of workers to use in the cluster spec.
num_ps: Number of parameter servers to use in the cluster spec.
has_eval: Whether this cluster spec has an evaluator.
Returns:
A dict that represents the cluster spec using localhost ports for the tasks.
Example:
```python
cluster_spec =
tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
has_chief=True, num_workers=2, num_ps=2)
# An example of cluster_spec is
# {'chief': ['localhost:23381'],
# 'worker': ['localhost:19197', 'localhost:22903'],
# 'ps': ['localhost:16912', 'localhost:21535']}
cluster_spec =
tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
has_chief=False, num_workers=0, num_ps=0, has_eval=True)
# An example of cluster_spec is
# {'evaluator': ['localhost:23381']}
```
"""
cluster_spec = {}
if has_chief:
cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
if num_workers:
cluster_spec['worker'] = [
'localhost:%s' % pick_unused_port() for _ in range(num_workers)
]
if num_ps:
cluster_spec['ps'] = [
'localhost:%s' % pick_unused_port() for _ in range(num_ps)
]
if has_eval:
cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
return cluster_spec
@contextlib.contextmanager
def skip_if_grpc_server_cant_be_started(test_obj):
try:
yield
except errors.UnknownError as e:
if 'Could not start gRPC server' in e.message:
reason = 'Cannot start std servers.'
test_obj.test_skipped_reason = reason
test_obj.skipTest(reason)
else:
raise
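# A minimal sketch of wrapping server startup with
# `skip_if_grpc_server_cant_be_started` (hypothetical test body; `self` is a
# `test.TestCase` instance and `cluster_spec` is assumed to exist):
#
#   with skip_if_grpc_server_cant_be_started(self):
#     server_lib.Server(
#         server_lib.ClusterSpec(cluster_spec),
#         job_name='worker',
#         task_index=0,
#         start=True)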
class MultiWorkerTestBase(test.TestCase):
"""Base class for testing multi node strategy and dataset."""
@classmethod
def setUpClass(cls, num_workers=2, num_ps=1): # pylint: disable=g-missing-super-call
"""Create a local cluster with 2 workers."""
cls._cluster_spec = create_in_process_cluster(num_workers=num_workers,
num_ps=num_ps)
cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]
def setUp(self):
# We only cache the session in one test because another test may have a
# different session config or master target.
self._thread_local = threading.local()
self._thread_local.cached_session = None
self._coord = coordinator.Coordinator()
@contextlib.contextmanager
def session(self, graph=None, config=None, target=None):
"""Create a test session with master target set to the testing cluster.
Creates a test session that connects to the local testing cluster.
Args:
graph: Optional graph to use during the returned session.
config: An optional config_pb2.ConfigProto to use to configure the
session.
target: the target of the session to connect to.
Yields:
A Session object that should be used as a context manager to surround
the graph building and execution code in a test case.
"""
config = self._create_config(config)
if target is None:
target = self._default_target
with session.Session(graph=graph, config=config, target=target) as sess:
yield sess
@contextlib.contextmanager
# TODO(b/117573461): Overwrite self.evaluate() to use this function.
def cached_session(self, graph=None, config=None, target=None):
"""Create a test session with master target set to the testing cluster.
Creates a test session that connects to the local testing cluster.
The session is only created once per test and then reused.
Args:
graph: Optional graph to use during the returned session.
config: An optional config_pb2.ConfigProto to use to configure the
session.
target: the target of the session to connect to.
Yields:
A Session object that should be used as a context manager to surround
the graph building and execution code in a test case. Note that the
session will live until the end of the test.
"""
config = self._create_config(config)
if target is None:
target = self._default_target
if getattr(self._thread_local, 'cached_session', None) is None:
self._thread_local.cached_session = session.Session(
graph=None, config=config, target=target)
sess = self._thread_local.cached_session
with sess.graph.as_default(), sess.as_default():
yield sess
def _create_config(self, config):
if config is None:
config = config_pb2.ConfigProto(allow_soft_placement=True)
else:
config = copy.deepcopy(config)
# Don't perform optimizations for tests so we don't inadvertently run
# GPU ops on CPU.
config.graph_options.optimizer_options.opt_level = -1
config.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
return config
def _run_client(self, client_fn, task_type, task_id, num_gpus, eager_mode,
*args, **kwargs):
def wrapped_client_fn():
with self._coord.stop_on_exception():
client_fn(task_type, task_id, num_gpus, *args, **kwargs)
if eager_mode:
with context.eager_mode():
wrapped_client_fn()
else:
with context.graph_mode():
wrapped_client_fn()
def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args,
**kwargs):
"""Runs several clients for between-graph replication.
Args:
client_fn: a function that needs to accept `task_type`, `task_id`,
`num_gpus`.
cluster_spec: a dict specifying jobs in a cluster.
num_gpus: number of GPUs per worker.
*args: will be passed to `client_fn`.
**kwargs: will be passed to `client_fn`.
"""
threads = []
for task_type in ['chief', 'worker']:
for task_id in range(len(cluster_spec.get(task_type, []))):
t = threading.Thread(
target=self._run_client,
args=(client_fn, task_type, task_id, num_gpus,
context.executing_eagerly()) + args,
kwargs=kwargs)
t.start()
threads.append(t)
self._coord.join(threads)
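# A hedged sketch of a between-graph test built on `MultiWorkerTestBase` (the
# client function and test class are hypothetical):
#
#   class MyBetweenGraphTest(MultiWorkerTestBase):
#
#     def testBetweenGraph(self):
#       def _client_fn(task_type, task_id, num_gpus):
#         # Each chief/worker task runs this in its own thread; build the
#         # per-worker graph and run it against the in-process cluster here.
#         del task_type, task_id, num_gpus
#       self._run_between_graph_clients(
#           _client_fn, self._cluster_spec, num_gpus=0)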
class SingleWorkerTestBaseGraph(MultiWorkerTestBase):
"""Base class for testing remote single worker strategy graph and dataset."""
@classmethod
def setUpClass(cls):
super(SingleWorkerTestBaseGraph, cls).setUpClass(num_workers=1)
class SingleWorkerTestBaseEager(test.TestCase):
"""Base class for testing remote single worker strategy eager and dataset."""
def setUp(self):
super(SingleWorkerTestBaseEager, self).setUp()
workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
remote.connect_to_remote_host(workers[0].target)
def cached_session(self):
return DummySession()
class DummySession(object):
def __enter__(self):
return
def __exit__(self, exception_type, exception_value, traceback):
pass
class MockOsEnv(collections_abc.Mapping):
"""A class that allows per-thread TF_CONFIG."""
def __init__(self, *args):
self._dict = dict()
self._thread_local = threading.local()
super(MockOsEnv, self).__init__(*args)
def get(self, key, default=None):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
if key == 'TF_CONFIG':
return dict.get(self._thread_local.dict, key, default)
else:
return dict.get(self._dict, key, default)
def __getitem__(self, key):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
if key == 'TF_CONFIG':
return dict.__getitem__(self._thread_local.dict, key)
else:
return dict.__getitem__(self._dict, key)
def __setitem__(self, key, val):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
if key == 'TF_CONFIG':
return dict.__setitem__(self._thread_local.dict, key, val)
else:
return dict.__setitem__(self._dict, key, val)
def __iter__(self):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
for x in self._thread_local.dict:
yield x
for x in self._dict:
yield x
def __len__(self):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
return self._thread_local.dict.__len__() + self._dict.__len__()
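# `MockOsEnv` keeps TF_CONFIG per thread while every other key is shared, so
# several simulated workers in one process can each see a different TF_CONFIG.
# A hedged sketch (values are illustrative):
#
#   mock_env = MockOsEnv()
#   mock_env['PATH'] = '/usr/bin'                     # visible to all threads
#   mock_env['TF_CONFIG'] = '{"task": {"index": 0}}'  # visible to this thread
#   mock_env.get('TF_CONFIG')  # -> the current thread's value, if any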
class IndependentWorkerTestBase(test.TestCase):
"""Testing infra for independent workers."""
def _make_mock_run_std_server(self):
def _mock_run_std_server(*args, **kwargs):
"""Returns the std server once all threads have started it."""
with skip_if_grpc_server_cant_be_started(self):
ret = original_run_std_server(*args, **kwargs)
# Wait for all std servers to be brought up in order to reduce the chance
# of remote sessions taking local ports that have been assigned to std
# servers. Only call this barrier the first time this function is run for
# each thread.
if not getattr(self._thread_local, 'server_started', False):
self._barrier.wait()
self._thread_local.server_started = True
return ret
return _mock_run_std_server
def setUp(self):
self._mock_os_env = MockOsEnv()
self._mock_context = test.mock.patch.object(os, 'environ',
self._mock_os_env)
self._coord = coordinator.Coordinator()
super(IndependentWorkerTestBase, self).setUp()
self._mock_context.__enter__()
# Thread-local object to be shared by all threads.
self._thread_local = threading.local()
def tearDown(self):
self._mock_context.__exit__(None, None, None)
super(IndependentWorkerTestBase, self).tearDown()
def _task_thread(self, task_fn, tf_config, executing_eagerly, *args,
**kwargs):
with self._coord.stop_on_exception():
os.environ['TF_CONFIG'] = json.dumps(tf_config)
# Force the new thread simulating a worker to run in the same context
# mode as the parent thread does.
if executing_eagerly:
with context.eager_mode():
task_fn(*args, **kwargs)
else:
with ops.Graph().as_default(), context.graph_mode():
task_fn(*args, **kwargs)
def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id,
*args, **kwargs):
"""Run tasks in a thread.
If `tf_config` is provided, use it for the new thread; if not, construct one
from `cluster_spec`, `task_type`, and `task_id`, and provide it to the new
thread to be set as `TF_CONFIG` environment.
Args:
task_fn: The function to run in the new thread.
cluster_spec: The cluster spec.
task_type: The task type.
task_id: The task id.
*args: Additional positional arguments to provide to the thread's task_fn.
**kwargs: Additional keyword arguments to provide to the thread's task_fn.
If `tf_config` is provided, that dict will be used as the TF_CONFIG for
the new thread.
Returns:
The thread that has started.
"""
tf_config = kwargs.pop('tf_config', None)
if tf_config is None:
if task_type:
tf_config = {
'cluster': cluster_spec,
'task': {
'type': task_type,
'index': task_id
}
}
else:
tf_config = {
'cluster': cluster_spec,
}
t = threading.Thread(
target=self._task_thread,
args=(task_fn, tf_config, context.executing_eagerly()) + args,
kwargs=kwargs)
t.start()
return t
def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args,
**kwargs):
# The task_fn should create the std server by itself.
threads = {}
for task_type in cluster_spec.keys():
threads[task_type] = []
for task_id in range(len(cluster_spec[task_type])):
t = self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id,
*args, **kwargs)
threads[task_type].append(t)
return threads
def join_independent_workers(self, worker_threads):
with skip_if_grpc_server_cant_be_started(self):
self._coord.join(worker_threads)
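# A hedged sketch of an independent-worker test: each task runs `task_fn` in
# its own thread with a per-thread TF_CONFIG, and the test joins the threads
# at the end (the task function and test class are hypothetical):
#
#   class MyIndependentWorkerTest(IndependentWorkerTestBase):
#
#     def testWorkers(self):
#       def _task_fn():
#         # Typically starts a std server (e.g. via distribute_coordinator)
#         # and runs this task's work.
#         pass
#       cluster_spec = create_cluster_spec(num_workers=2)
#       threads = self.run_multiple_tasks_in_threads(_task_fn, cluster_spec)
#       self.join_independent_workers(threads['worker'])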
class MultiWorkerMultiProcessTest(test.TestCase):
"""Testing infra for independent workers using multiple processes."""
def _run_task_in_process(self, cmd_args, cluster_spec, task_type, task_id):
env = os.environ.copy()
env['TF_CONFIG'] = json.dumps({
'cluster': cluster_spec,
'task': {
'type': task_type,
'index': task_id
}
})
return subprocess.Popen(
cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
@deprecation.deprecated(
None, '`run_multiple_tasks_in_processes` is deprecated; any new test '
'requiring multiple processes should use `multi_process_runner` for '
'better support of log printing, streaming, and more functionality.')
def run_multiple_tasks_in_processes(self, cmd_args, cluster_spec):
"""Run `cmd_args` in a process for each task in `cluster_spec`."""
processes = {}
for task_type in cluster_spec.keys():
processes[task_type] = []
for task_id in range(len(cluster_spec[task_type])):
p = self._run_task_in_process(cmd_args, cluster_spec, task_type,
task_id)
processes[task_type].append(p)
return processes
@deprecation.deprecated(
None, '`join_independent_workers` is deprecated; any new test '
'requiring multiple processes should use `multi_process_runner` for '
'better support of log printing, streaming, and more functionality.')
def join_independent_workers(self, worker_processes):
return_codes = []
for p in nest.flatten(worker_processes):
try:
# Calling p.wait() will hang if we don't consume its output.
p.communicate()
except ValueError:
# The output of the process may have been consumed, in which case
# calling `p.communicate()` will raise a ValueError.
pass
finally:
return_codes.append(p.returncode)
for return_code in return_codes:
self.assertEqual(return_code, 0)
@deprecation.deprecated(
None, '`stream_stderr` is deprecated; any new test '
'requiring multiple processes should use `multi_process_runner` for '
'better support of log printing, streaming, and more functionality.')
def stream_stderr(self, processes, print_only_first=False):
"""Consume stderr of all processes and print to stdout.
To reduce the amount of logging, the caller can set print_only_first to True.
In that case, this function only prints stderr from the first process of
each type.
Args:
processes: A dictionary from process type string -> list of processes.
print_only_first: If true, only print output from first process of each
type.
"""
def _stream_stderr_single_process(process, type_string, index,
print_to_stdout):
"""Consume a single process's stderr and optionally print to stdout."""
while True:
output = process.stderr.readline()
if not output and process.poll() is not None:
break
if output and print_to_stdout:
print('{}{} {}'.format(type_string, index, output.strip()))
sys.stdout.flush()
stream_threads = []
for process_type, process_list in six.iteritems(processes):
for i in range(len(process_list)):
print_to_stdout = (not print_only_first) or (i == 0)
thread = threading.Thread(
target=_stream_stderr_single_process,
args=(process_list[i], process_type, i, print_to_stdout))
thread.start()
stream_threads.append(thread)
for thread in stream_threads:
thread.join()
def get_tf_config_task():
return json.loads(os.environ['TF_CONFIG'])['task']
def get_tf_config_cluster_spec():
return json.loads(os.environ['TF_CONFIG'])['cluster']
def get_task_type():
return get_tf_config_task()['type']
def get_task_index():
return get_tf_config_task()['index']
def is_chief():
return ('chief' not in get_tf_config_cluster_spec()
and get_task_type() == 'worker'
and get_task_index() == 0)
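# A hedged example of the TF_CONFIG these helpers parse (addresses are
# illustrative):
#
#   os.environ['TF_CONFIG'] = json.dumps({
#       'cluster': {'worker': ['localhost:12345', 'localhost:23456']},
#       'task': {'type': 'worker', 'index': 0},
#   })
#   get_task_type()   # -> 'worker'
#   get_task_index()  # -> 0
#   is_chief()        # -> True: no 'chief' in the cluster and index == 0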