# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Distribute Coordinator."""
import contextlib
import copy
import json
import os
import sys
import threading
import time
import six
# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_manager
CHIEF = distribute_coordinator._TaskType.CHIEF
WORKER = distribute_coordinator._TaskType.WORKER
PS = distribute_coordinator._TaskType.PS
EVALUATOR = distribute_coordinator._TaskType.EVALUATOR
STANDALONE_CLIENT = distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT
INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER
NUM_WORKERS = 3
NUM_PS = 2
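# `sys.exit` is patched to `os._exit` in the __main__ block below, so keep a
# reference to the original for tests that need to end a thread by raising
# SystemExit (see `_fake_sleep` in testRunStdServerInGoogleEnvironment).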
original_sys_exit = sys.exit
def _bytes_to_str(maybe_bytes):
if isinstance(maybe_bytes, six.string_types):
return maybe_bytes
else:
return str(maybe_bytes, "utf-8")
def _strip_protocol(target):
# cluster_spec expects "host:port" strings.
if "//" in target:
return target.split("//")[1]
else:
return target
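# Minimal stand-ins for a DistributionStrategy and its `extended` object that
# expose only the attributes the distribute coordinator reads.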
class MockExtended(object):
def __init__(self,
between_graph=False,
should_init=None,
should_checkpoint=None,
should_save_summary=None):
self.experimental_between_graph = between_graph
self.experimental_should_init = should_init
self.should_checkpoint = should_checkpoint
self.should_save_summary = should_save_summary
class MockStrategy(object):
def __init__(self,
between_graph=False,
should_init=None,
should_checkpoint=None,
should_save_summary=None):
self.extended = MockExtended(between_graph, should_init, should_checkpoint,
should_save_summary)
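# Mimics `Strategy.configure`: task 0 acts as the chief (it initializes
# variables, checkpoints and saves summaries), and the session config is
# modified so tests can verify it was passed through to workers and servers.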
def configure(self,
session_config=None,
cluster_spec=None,
task_type=None,
task_id=None):
if self.extended.experimental_should_init is None:
if task_id == 0:
self.extended.experimental_should_init = True
else:
self.extended.experimental_should_init = False
if self.extended.should_checkpoint is None:
if task_id == 0:
self.extended.should_checkpoint = True
else:
self.extended.should_checkpoint = False
if self.extended.should_save_summary is None:
if task_id == 0:
self.extended.should_save_summary = True
else:
self.extended.should_save_summary = False
if session_config:
if (cluster_spec and task_type and task_id is not None and
self.extended.experimental_between_graph):
session_config.intra_op_parallelism_threads += 1
if task_type in ["chief", "worker"]:
session_config.device_filters.extend(
["/job:%s/task:%d" % (task_type, task_id), "/job:ps"])
else:
session_config.inter_op_parallelism_threads += 1
session_config.device_filters.append("/job:somejob")
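# A fake in-process server that only records whether start() and join() were
# called.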
class MockServer(object):
def __init__(self):
self._joined = False
self._started = False
def start(self):
self._started = True
def join(self):
assert not self._joined
self._joined = True
@property
def joined(self):
return self._joined
@property
def started(self):
return self._started
class DistributeCoordinatorTestBase(test.TestCase):
@classmethod
def setUpClass(cls):
# We have to create a global in-process cluster because once an in-process
# TensorFlow server is created, there is no way to terminate it. Please see
# multi_worker_test_base.py for more details.
# TODO(yuefengz): use the utility from multi_worker_test_base.
cls._workers, cls._ps = test_util.create_local_cluster(
NUM_WORKERS, num_ps=NUM_PS)
cls._cluster_spec = {
WORKER: [
_strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
],
PS: [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps]
}
def setUp(self):
self._result_correct = 0
self._lock = threading.Lock()
self._worker_context = {}
self._strategy_property = {}
self._std_servers = {}
self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)
self._coord = coordinator.Coordinator()
@contextlib.contextmanager
def _test_session(self, target):
config = config_pb2.ConfigProto(allow_soft_placement=True)
config.graph_options.optimizer_options.opt_level = -1
with session.Session(graph=None, config=config, target=target) as sess:
yield sess
# TODO(yuefengz): use the utility from multi_worker_test_base.
def _create_cluster_spec(self,
has_chief=False,
num_workers=1,
num_ps=0,
has_eval=False):
cluster_spec = {}
if has_chief:
cluster_spec[CHIEF] = ["localhost:%s" % test_util.pick_unused_port()]
if num_workers:
cluster_spec[WORKER] = [
"localhost:%s" % test_util.pick_unused_port()
for _ in range(num_workers)
]
if num_ps:
cluster_spec[PS] = [
"localhost:%s" % test_util.pick_unused_port() for _ in range(num_ps)
]
if has_eval:
cluster_spec[EVALUATOR] = ["localhost:%s" % test_util.pick_unused_port()]
return cluster_spec
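# In-graph replication: a single client builds one graph that places a
# variable on every worker task and checks the summed result.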
def _in_graph_worker_fn(self, strategy):
context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
xs = []
expected = 0.0
for i in range(context.num_workers):
with ops.device("/job:worker/task:%d" % i):
x = variable_scope.get_variable("x_%d" % i, initializer=10.0)
x_add = x.assign_add(float(i))
xs.append(x_add)
expected += i + 10.0
with ops.device("/job:worker/task:0"):
result = math_ops.add_n(xs)
self.evaluate(variables.global_variables_initializer())
result_value = sess.run(result)
self.assertEqual(result_value, expected)
if result_value == expected:
self._result_correct += 1
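# Wraps worker_fn so that exceptions raised in coordinator threads are
# reported to `self._coord` (and re-raised by `_join_threads`) instead of
# being silently dropped.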
def _wrapped_worker_fn(self, worker_fn):
def wrapped(*args, **kwargs):
with self._coord.stop_on_exception():
return worker_fn(*args, **kwargs)
return wrapped
def _run_coordinator_in_thread(self, worker_fn, strategy, **kwargs):
t = threading.Thread(
target=distribute_coordinator.run_distribute_coordinator,
args=(self._wrapped_worker_fn(worker_fn), strategy),
kwargs=kwargs)
t.start()
return t
def _run_multiple_coordinator_in_threads(self, worker_fn, strategy,
cluster_spec, **kwargs):
threads = {}
for task_type in cluster_spec.keys():
threads[task_type] = []
for task_id in range(len(cluster_spec[task_type])):
t = self._run_coordinator_in_thread(
worker_fn,
strategy,
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
**kwargs)
threads[task_type].append(t)
return threads
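# Joins coordinator threads and skips the test when the std servers could not
# be started (e.g. the picked ports are unavailable).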
def _join_threads(self, threads):
try:
self._coord.join(threads)
except errors.UnknownError as e:
if "Could not start gRPC server" in e.message:
self.skipTest("Cannot start std servers.")
else:
raise
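# Between-graph replication: every worker builds its own graph against two
# variables on the PS tasks, runs one training step, and then checks that the
# final values reflect one step from each of the NUM_WORKERS workers.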
def _between_graph_worker_fn(self, strategy):
context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
with ops.device("/job:ps/task:0"):
# TODO(yuefengz): investigate why not using resource variable will make
# the test flaky.
x = variable_scope.get_variable(
"x", initializer=10.0, use_resource=True)
with ops.device("/job:ps/task:1"):
y = variable_scope.get_variable(
"y", initializer=20.0, use_resource=True)
x_add = x.assign_add(2.0)
y_sub = y.assign_sub(2.0)
train_op = control_flow_ops.group([x_add, y_sub])
if context.is_chief:
self.evaluate(variables.global_variables_initializer())
# Synchronize workers after initialization.
if context.has_barrier:
context.wait_for_other_workers()
else:
while True:
uninit_vars = sess.run(variables.report_uninitialized_variables())
# pylint: disable=g-explicit-length-test
if len(uninit_vars) == 0:
break
sess.run(train_op)
# Synchronize workers after one step to make sure they all have finished
# training.
if context.has_barrier:
context.wait_for_other_workers()
else:
self._barrier.wait()
x_val, y_val = sess.run([x, y])
self.assertEqual(x_val, 16.0)
self.assertEqual(y_val, 14.0)
if x_val == 16.0 and y_val == 14.0:
with self._lock:
self._result_correct += 1
def _between_graph_with_monitored_session(self, strategy):
context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
with ops.device("/job:ps/task:0"):
# TODO(yuefengz): investigate why not using resource variable will make
# the test flaky.
x = variable_scope.get_variable("xx", initializer=10.0, use_resource=True)
with ops.device("/job:ps/task:1"):
y = variable_scope.get_variable("yy", initializer=20.0, use_resource=True)
x_add = x.assign_add(2.0)
y_sub = y.assign_sub(2.0)
train_op = control_flow_ops.group([x_add, y_sub])
# The monitored session will run init or ready ops.
with monitored_session.MonitoredSession() as sess:
sess.run(train_op)
# Synchronize workers after one step to make sure they all have finished
# training.
if context.has_barrier:
context.wait_for_other_workers()
else:
self._barrier.wait()
x_val, y_val = sess.run([x, y])
self.assertEqual(x_val, 16.0)
self.assertEqual(y_val, 14.0)
if x_val == 16.0 and y_val == 14.0:
with self._lock:
self._result_correct += 1
def _dump_worker_context(self, strategy):
"""Dumps the propoerties of each worker context.
It dumps the context properties to a dict mapping from task_type to a list
of tuples of master_target, num_workers, is_chief and distribute_mode, where
the list is indexed by the task_id.
Args:
strategy: a `DistributionStrategy` object.
"""
context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
task_type = str(context.task_type)
task_id = context.task_id or 0
with self._lock:
if task_type not in self._worker_context:
self._worker_context[task_type] = []
while len(self._worker_context[task_type]) <= task_id:
self._worker_context[task_type].append(None)
self._worker_context[task_type][task_id] = (context.master_target,
context.num_workers,
context.is_chief,
context.distributed_mode)
def _dump_strategy_property(self, strategy):
context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
self.assertEqual(context._strategy.extended.experimental_should_init,
strategy.extended.experimental_should_init)
self.assertEqual(context.should_checkpoint,
strategy.extended.should_checkpoint)
self.assertEqual(context.should_save_summary,
strategy.extended.should_save_summary)
task_type = str(context.task_type)
task_id = context.task_id or 0
with self._lock:
if task_type not in self._strategy_property:
self._strategy_property[task_type] = []
while len(self._strategy_property[task_type]) <= task_id:
self._strategy_property[task_type].append(None)
self._strategy_property[task_type][task_id] = (
context._strategy.extended.experimental_should_init,
context.should_checkpoint,
context.should_save_summary)
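# Stand-in for `_run_std_server` that records a MockServer per
# (task_type, task_id) instead of starting a real gRPC server.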
def _run_mock_std_server(self,
session_config=None,
cluster_spec=None,
task_type=None,
task_id=None,
rpc_layer=None,
environment=None):
task_type = str(task_type)
task_id = task_id or 0
with self._lock:
if task_type not in self._std_servers:
self._std_servers[task_type] = []
while len(self._std_servers[task_type]) <= task_id:
self._std_servers[task_type].append(None)
server = MockServer()
self._std_servers[task_type][task_id] = server
return server
class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
def testInGraphStandaloneMode(self):
"""Test it runs in-graph replication in standalone client mode."""
distribute_coordinator.run_distribute_coordinator(
self._in_graph_worker_fn,
MockStrategy(between_graph=False),
cluster_spec=self._cluster_spec)
self.assertEqual(self._result_correct, 1)
def testBetweenGraph(self):
"""Test it runs between-graph replication in standalone client mode."""
distribute_coordinator.run_distribute_coordinator(
self._between_graph_worker_fn,
MockStrategy(between_graph=True),
cluster_spec=self._cluster_spec)
# Each finished worker will increment self._result_correct.
self.assertEqual(self._result_correct, NUM_WORKERS)
@test_util.run_v1_only("MonitoredSession removed from v2")
def testBetweenGraphWithMonitoredSession(self):
"""Test monitored session in standalone client mode."""
distribute_coordinator.run_distribute_coordinator(
self._between_graph_with_monitored_session,
MockStrategy(between_graph=True),
cluster_spec=self._cluster_spec)
# Each finished worker will increment self._result_correct.
self.assertEqual(self._result_correct, NUM_WORKERS)
def testBetweenGraphContext(self):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
MockStrategy(between_graph=True),
cluster_spec=self._cluster_spec)
# There is only one type of task and there are three such tasks.
self.assertEqual(len(self._worker_context), 1)
self.assertTrue(WORKER in self._worker_context)
self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(
self._worker_context[WORKER][0],
(_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
self.assertEqual(
self._worker_context[WORKER][1],
(_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True))
self.assertEqual(
self._worker_context[WORKER][2],
(_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
def testBetweenGraphStrategyProperties(self):
# Dumps properties of the strategy objects.
distribute_coordinator.run_distribute_coordinator(
self._dump_strategy_property,
MockStrategy(between_graph=True, should_init=True),
cluster_spec=self._cluster_spec)
# There is only one type of task and there are three such tasks.
self.assertEqual(len(self._strategy_property), 1)
self.assertTrue(WORKER in self._strategy_property)
self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
# Check whether each task has the right properties of should_init,
# should_checkpoint and should_save_summary.
self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
def testInGraphContext(self):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
MockStrategy(between_graph=False),
cluster_spec=self._cluster_spec)
# There is only a "None" task in the dumped task context.
self.assertEqual(len(self._worker_context), 1)
self.assertTrue("None" in self._worker_context)
self.assertEqual(len(self._worker_context["None"]), 1)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(
self._worker_context["None"][0],
(_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
def testLocalContext(self):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
MockStrategy(between_graph=False),
cluster_spec=None)
# There is only a "None" task.
self.assertEqual(len(self._worker_context), 1)
self.assertTrue("None" in self._worker_context)
self.assertEqual(len(self._worker_context["None"]), 1)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(self._worker_context["None"][0], ("", 0, True, False))
def testBetweenGraphContextWithChief(self):
# Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
cluster_spec = copy.deepcopy(self._cluster_spec)
cluster_spec[CHIEF] = ["fake_chief"]
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
MockStrategy(between_graph=True),
cluster_spec=cluster_spec,
rpc_layer="grpc")
# There is one CHIEF task and three WORKER tasks.
self.assertEqual(len(self._worker_context), 2)
self.assertTrue(CHIEF in self._worker_context)
self.assertTrue(WORKER in self._worker_context)
self.assertEqual(len(self._worker_context[CHIEF]), 1)
self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(self._worker_context[CHIEF][0],
("grpc://fake_chief", 4, True, True))
self.assertEqual(
self._worker_context[WORKER][0],
(_bytes_to_str(self._workers[0].target), NUM_WORKERS + 1, False, True))
self.assertEqual(
self._worker_context[WORKER][1],
(_bytes_to_str(self._workers[1].target), NUM_WORKERS + 1, False, True))
self.assertEqual(
self._worker_context[WORKER][2],
(_bytes_to_str(self._workers[2].target), NUM_WORKERS + 1, False, True))
def testInGraphContextWithEval(self):
# Adds an EVALUATOR job.
cluster_spec = copy.deepcopy(self._cluster_spec)
cluster_spec[EVALUATOR] = ["fake_evaluator"]
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
MockStrategy(between_graph=False),
cluster_spec=cluster_spec,
rpc_layer=None)
# There is one "None" task and one EVALUATOR task.
self.assertEqual(len(self._worker_context), 2)
self.assertTrue("None" in self._worker_context)
self.assertTrue(EVALUATOR in self._worker_context)
self.assertEqual(len(self._worker_context["None"]), 1)
self.assertEqual(len(self._worker_context[EVALUATOR]), 1)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(self._worker_context["None"][0], (_strip_protocol(
_bytes_to_str(self._workers[0].target)), 3, True, True))
self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))
class DistributeCoordinatorTestIndependentWorkerMode(
DistributeCoordinatorTestBase):
def testInGraph(self):
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
threads = self._run_multiple_coordinator_in_threads(
self._in_graph_worker_fn,
MockStrategy(between_graph=False),
cluster_spec,
mode=INDEPENDENT_WORKER)
self._join_threads([threads[WORKER][0]])
self.assertEqual(self._result_correct, 1)
def testBetweenGraph(self):
cluster_spec = self._create_cluster_spec(
num_workers=NUM_WORKERS, num_ps=NUM_PS)
threads = self._run_multiple_coordinator_in_threads(
self._between_graph_worker_fn,
MockStrategy(between_graph=True),
cluster_spec,
mode=INDEPENDENT_WORKER)
self._join_threads(threads[WORKER])
# Each finished worker will increment self._result_correct.
self.assertEqual(self._result_correct, NUM_WORKERS)
@test_util.run_v1_only("MonitoredSession removed from v2")
def testBetweenGraphWithMonitoredSession(self):
cluster_spec = self._create_cluster_spec(
num_workers=NUM_WORKERS, num_ps=NUM_PS)
threads = self._run_multiple_coordinator_in_threads(
self._between_graph_with_monitored_session,
MockStrategy(between_graph=True),
cluster_spec,
mode=INDEPENDENT_WORKER)
self._join_threads(threads[WORKER])
# Each finished worker will increment self._result_correct.
self.assertEqual(self._result_correct, NUM_WORKERS)
def testBetweenGraphContext(self):
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
# Dumps the task contexts and std server arguments.
with test.mock.patch.object(distribute_coordinator, "_run_std_server",
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
MockStrategy(between_graph=True),
cluster_spec,
mode=INDEPENDENT_WORKER,
rpc_layer=None)
self._join_threads(threads[WORKER])
# There is only one type of task and three such tasks.
self.assertEqual(len(self._worker_context), 1)
self.assertTrue(WORKER in self._worker_context)
self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(
self._worker_context[WORKER][0],
(_bytes_to_str(cluster_spec[WORKER][0]), NUM_WORKERS, True, True))
self.assertEqual(
self._worker_context[WORKER][1],
(_bytes_to_str(cluster_spec[WORKER][1]), NUM_WORKERS, False, True))
self.assertEqual(
self._worker_context[WORKER][2],
(_bytes_to_str(cluster_spec[WORKER][2]), NUM_WORKERS, False, True))
# Make sure each worker runs a std server.
self.assertEqual(len(self._std_servers), 1)
self.assertTrue(WORKER in self._std_servers)
self.assertEqual(len(self._std_servers[WORKER]), 3)
self.assertFalse(self._std_servers[WORKER][0].joined)
self.assertFalse(self._std_servers[WORKER][1].joined)
self.assertFalse(self._std_servers[WORKER][2].joined)
def testBetweenGraphStrategyProperties(self):
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
# Dumps properties of the strategy objects.
with test.mock.patch.object(distribute_coordinator, "_run_std_server",
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_strategy_property,
MockStrategy(between_graph=True, should_init=True),
cluster_spec,
mode=INDEPENDENT_WORKER,
rpc_layer=None)
self._join_threads(threads[WORKER])
# There is only one type of task and there are three such tasks.
self.assertEqual(len(self._strategy_property), 1)
self.assertTrue(WORKER in self._strategy_property)
self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
# Check whether each task has the right properties of should_init,
# should_checkpoint and should_save_summary.
self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
def testInGraphContext(self):
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
# Dumps the task contexts and std server arguments.
with test.mock.patch.object(distribute_coordinator, "_run_std_server",
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
MockStrategy(between_graph=False),
cluster_spec,
mode=INDEPENDENT_WORKER,
rpc_layer=None)
self._join_threads(threads[WORKER])
# There is only a "None" task in the dumped task context.
self.assertEqual(len(self._worker_context), 1)
self.assertTrue("None" in self._worker_context)
self.assertEqual(len(self._worker_context["None"]), 1)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(
self._worker_context["None"][0],
(_bytes_to_str(cluster_spec[WORKER][0]), NUM_WORKERS, True, True))
# Make sure each worker runs a std server.
self.assertEqual(len(self._std_servers), 1)
self.assertTrue(WORKER in self._std_servers)
self.assertEqual(len(self._std_servers[WORKER]), 3)
self.assertFalse(self._std_servers[WORKER][0].joined)
self.assertTrue(self._std_servers[WORKER][1].joined)
self.assertTrue(self._std_servers[WORKER][2].joined)
def testInGraphContextWithEval(self):
# Adds an EVALUATOR job.
cluster_spec = self._create_cluster_spec(
num_workers=NUM_WORKERS, has_eval=True)
# Dumps the task contexts and std server arguments.
with test.mock.patch.object(distribute_coordinator, "_run_std_server",
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
MockStrategy(between_graph=False),
cluster_spec,
mode=INDEPENDENT_WORKER,
rpc_layer=None)
self._join_threads(threads[WORKER])
self._join_threads([threads[EVALUATOR][0]])
# There is one "None" task and one EVALUATOR task.
self.assertEqual(len(self._worker_context), 2)
self.assertTrue("None" in self._worker_context)
self.assertTrue(EVALUATOR in self._worker_context)
self.assertEqual(len(self._worker_context["None"]), 1)
self.assertEqual(len(self._worker_context[EVALUATOR]), 1)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(self._worker_context["None"][0],
(_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True))
self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))
# Make sure each worker runs a std server.
self.assertEqual(len(self._std_servers), 1)
self.assertTrue(WORKER in self._std_servers)
self.assertEqual(len(self._std_servers[WORKER]), 3)
self.assertFalse(self._std_servers[WORKER][0].joined)
self.assertTrue(self._std_servers[WORKER][1].joined)
self.assertTrue(self._std_servers[WORKER][2].joined)
def testRunStdServerInGoogleEnvironment(self):
cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
tf_config = {"cluster": cluster_spec, "environment": "google"}
joined = [False]
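# For this configuration the coordinator blocks the PS task in a sleep loop;
# patch time.sleep so the thread records it was reached and then exits via
# the original sys.exit instead of blocking forever.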
def _fake_sleep(_):
joined[0] = True
original_sys_exit(0)
def _thread_fn(cluster_spec):
distribute_coordinator.run_distribute_coordinator(
None,
MockStrategy(between_graph=True),
mode=INDEPENDENT_WORKER,
cluster_spec=cluster_spec,
task_type="ps",
task_id=0)
with test.mock.patch.dict(
"os.environ",
{"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
time, "sleep", _fake_sleep):
t = threading.Thread(target=_thread_fn, args=(cluster_spec,))
t.start()
t.join()
self.assertTrue(joined[0])
def testRpcLayerEnvironmentVariable(self):
cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
tf_config = {"cluster": cluster_spec, "rpc_layer": "cake"}
rpc_layer_from_coordinator = [None]
def _run_mock_server(cluster_spec=None,
task_type=None,
task_id=None,
session_config=None,
rpc_layer=None,
environment=None):
del cluster_spec, task_type, task_id, session_config, environment
rpc_layer_from_coordinator[0] = rpc_layer
return MockServer()
with test.mock.patch.dict(
"os.environ",
{"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
distribute_coordinator, "_run_std_server", _run_mock_server):
distribute_coordinator.run_distribute_coordinator(
None,
MockStrategy(between_graph=True),
mode=INDEPENDENT_WORKER,
cluster_spec=cluster_spec,
task_type="ps",
task_id=0)
self.assertEqual(rpc_layer_from_coordinator[0], "cake")
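# Verifies that the session config modified by MockStrategy.configure reaches
# both the std server and the worker's session config.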
class StrategyConfigureTest(test.TestCase):
def setUp(self):
self._device_filters = []
self._intra_op_parallelism_threads = None
self._inter_op_parallelism_threads = None
super(StrategyConfigureTest, self).setUp()
def _dump_device_filters(self, *args, **kwargs):
session_config = kwargs.get("session_config", None)
self._device_filters.extend(session_config.device_filters)
self._intra_op_parallelism_threads = (
session_config.intra_op_parallelism_threads)
self._inter_op_parallelism_threads = (
session_config.inter_op_parallelism_threads)
return MockServer()
def _worker_fn(self, strategy):
worker_context = distribute_coordinator_context.get_current_worker_context()
session_config = worker_context._session_config
self._device_filters.extend(session_config.device_filters)
self._intra_op_parallelism_threads = (
session_config.intra_op_parallelism_threads)
self._inter_op_parallelism_threads = (
session_config.inter_op_parallelism_threads)
return MockServer()
def test_session_config_in_std_server(self):
cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
tf_config = {"cluster": cluster_spec}
with test.mock.patch.dict(
"os.environ",
{"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
distribute_coordinator, "_run_std_server",
self._dump_device_filters):
distribute_coordinator.run_distribute_coordinator(
lambda _: None,
MockStrategy(between_graph=True),
mode=INDEPENDENT_WORKER,
cluster_spec=cluster_spec,
task_type="worker",
task_id=0)
self.assertEqual(self._intra_op_parallelism_threads, 1)
self.assertEqual(self._inter_op_parallelism_threads, 0)
def test_session_config_in_session_creator(self):
cluster_spec = {"worker": ["localhost:0"]}
tf_config = {"cluster": cluster_spec}
# Reset the saved Server state.
distribute_coordinator._thread_local = threading.local() # pylint: disable=protected-access
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(tf_config)}):
distribute_coordinator.run_distribute_coordinator(
self._worker_fn,
MockStrategy(between_graph=True),
mode=INDEPENDENT_WORKER,
cluster_spec=cluster_spec,
task_type="worker",
task_id=0)
self.assertEqual(self._device_filters, ["/job:worker/task:0", "/job:ps"])
self.assertEqual(self._intra_op_parallelism_threads, 2)
self.assertEqual(self._inter_op_parallelism_threads, 0)
def test_eval_strategy_configure(self):
cluster_spec = {"evaluator": ["localhost:0"]}
tf_config = {"cluster": cluster_spec}
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(tf_config)}):
distribute_coordinator.run_distribute_coordinator(
lambda _: None,
MockStrategy(between_graph=False),
eval_fn=self._worker_fn,
eval_strategy=MockStrategy(between_graph=True),
mode=INDEPENDENT_WORKER,
cluster_spec=cluster_spec,
task_type="evaluator",
task_id=0)
self.assertEqual(self._device_filters, ["/job:somejob"])
self.assertEqual(self._intra_op_parallelism_threads, 0)
self.assertEqual(self._inter_op_parallelism_threads, 2)
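# Verifies the arguments that run_standard_tensorflow_server forwards to
# _run_std_server when the cluster comes from TF_CONFIG.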
class RunStandardTensorflowServerTest(test.TestCase):
def test_std_server_arguments(self):
cs = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
tf_config = {"cluster": cs, "task": {"type": "ps", "id": 0}}
def _mock_run_std_server(cluster_spec=None,
task_type=None,
task_id=None,
session_config=None,
rpc_layer=None):
self.assertEqual(cluster_spec.as_dict(), cs)
self.assertEqual(task_type, "ps")
self.assertEqual(task_id, 0)
self.assertEqual(session_config.experimental.collective_group_leader,
"/job:worker/replica:0/task:0")
self.assertEqual(session_config.intra_op_parallelism_threads, 1)
self.assertEqual(rpc_layer, "grpc")
return MockServer()
with test.mock.patch.dict(
"os.environ",
{"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
distribute_coordinator, "_run_std_server", _mock_run_std_server):
session_config = config_pb2.ConfigProto()
session_config.intra_op_parallelism_threads = 1
mock_server = distribute_coordinator.run_standard_tensorflow_server(
session_config)
self.assertTrue(mock_server.started)
if __name__ == "__main__":
# TODO(yuefengz): find a smart way to terminate std server threads.
with test.mock.patch.object(sys, "exit", os._exit):
# Reduce `recovery_wait_secs` from 30 seconds so the test completes quickly.
orig_init = session_manager.SessionManager.__init__
def new_init(*args, **kwargs):
kwargs.pop("recovery_wait_secs", None)
kwargs["recovery_wait_secs"] = 0.5
orig_init(*args, **kwargs)
session_manager.SessionManager.__init__ = new_init
test.main()