| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for remote eager execution.""" |
| |
| import os |
| import threading |
| |
| from absl.testing import parameterized |
| import numpy as np |
| |
| from tensorflow.core.distributed_runtime.preemption import gen_check_preemption_op |
| from tensorflow.core.protobuf import cluster_pb2 |
| from tensorflow.core.protobuf import tensorflow_server_pb2 |
| from tensorflow.python import pywrap_tfe |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.eager import executor |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import test |
| from tensorflow.python.training import coordinator |
| from tensorflow.python.training import server_lib |
| |
| JOB_NAME = "remote_device" |
| |
| |
def get_server_def(job_name, local_server_port, remote_server_addresses,
                   task_index):
  """Builds a ServerDef for a single job with one local + N remote tasks.

  Args:
    job_name: Name of the single job in the cluster.
    local_server_port: Port for the local server, which becomes task 0.
    remote_server_addresses: host:port strings for tasks 1..N, in order.
    task_index: Task index of the server this def describes.

  Returns:
    A `tensorflow_server_pb2.ServerDef` using the grpc protocol with the
    standalone coordination service enabled.
  """
  cluster_def = cluster_pb2.ClusterDef()
  job = cluster_def.job.add()
  job.name = job_name

  # Task 0 is always the local server; remote servers fill tasks 1..N.
  addresses = ["localhost:%d" % local_server_port]
  addresses.extend(remote_server_addresses)
  for task_id, address in enumerate(addresses):
    job.tasks[task_id] = address

  server_def = tensorflow_server_pb2.ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol="grpc")
  # Enable the standalone coordination service (used by the preemption and
  # config key-value tests in this file).
  server_def.default_session_config.experimental.coordination_config.service_type = "standalone"

  return server_def
| |
| |
class DynamicClusterTest(test.TestCase, parameterized.TestCase):
  """Tests dynamic cluster updates for remote eager execution.

  Each test starts from a two-worker cluster (see `setUp`) and then adds,
  removes, or replaces workers via `context.update_server_def` while running
  eager ops and tf.functions against the cluster.
  """

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(DynamicClusterTest, self).__init__(methodName)
    # Four standalone local servers whose addresses are mixed and matched in
    # the server defs below to simulate adding/removing/replacing tasks.
    self._cached_server1 = server_lib.Server.create_local_server()
    self._cached_server2 = server_lib.Server.create_local_server()
    self._cached_server3 = server_lib.Server.create_local_server()
    self._cached_server4 = server_lib.Server.create_local_server()

    # Strip the "grpc://" scheme: ServerDef task entries expect host:port.
    self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
    self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
    self._cached_server3_target = self._cached_server3.target[len("grpc://"):]
    self._cached_server4_target = self._cached_server4.target[len("grpc://"):]

    # Server defs naming which cached servers participate; the suffix lists
    # the servers occupying tasks 1..N (task 0 is always the local client).
    self.server_def_s1 = get_server_def(
        JOB_NAME,
        local_server_port=0,
        remote_server_addresses=[self._cached_server1_target],
        task_index=0)
    self.server_def_s1_s2 = get_server_def(
        JOB_NAME,
        local_server_port=0,
        remote_server_addresses=[
            self._cached_server1_target, self._cached_server2_target
        ],
        task_index=0)
    self.server_def_s1_s3 = get_server_def(
        JOB_NAME,
        local_server_port=0,
        remote_server_addresses=[
            self._cached_server1_target, self._cached_server3_target
        ],
        task_index=0)
    self.server_def_s4_s3 = get_server_def(
        JOB_NAME,
        local_server_port=0,
        remote_server_addresses=[
            self._cached_server4_target, self._cached_server3_target
        ],
        task_index=0)
    self.server_def_s1_s2_s3 = get_server_def(
        JOB_NAME,
        local_server_port=0,
        remote_server_addresses=[
            self._cached_server1_target, self._cached_server2_target,
            self._cached_server3_target
        ],
        task_index=0)
    self.server_def_s1_s2_s3_s4 = get_server_def(
        JOB_NAME,
        local_server_port=0,
        remote_server_addresses=[
            self._cached_server1_target, self._cached_server2_target,
            self._cached_server3_target, self._cached_server4_target
        ],
        task_index=0)

    # Device names for task 0 (local) and remote tasks 1-4.
    self.device_local = "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME
    self.device_t1 = "/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME
    self.device_t2 = "/job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME
    self.device_t3 = "/job:%s/replica:0/task:3/device:CPU:0" % JOB_NAME
    self.device_t4 = "/job:%s/replica:0/task:4/device:CPU:0" % JOB_NAME

  def setUp(self):
    super(DynamicClusterTest, self).setUp()
    # NOTE(review): streaming enqueue is explicitly disabled for these tests;
    # presumably so remote op errors surface deterministically — confirm.
    os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = str(False)
    local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
    # Every test starts from a cluster of {local, server1, server2}.
    context.set_server_def(
        server_def=get_server_def(
            JOB_NAME,
            local_server_port=local_port,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))

  def tearDown(self):
    super(DynamicClusterTest, self).tearDown()
    # Enter a null device scope to clear any device left on the stack, then
    # destroy the eager context so the next test starts clean.
    ops.device(None).__enter__()
    context._reset_context()

  def testCheckPreemption(self):
    """check_preemption raises AbortedError once a preemption notice is set."""
    preemption_key = "TF_DEFAULT_PREEMPTION_NOTICE_KEY"
    preemption_task = "/job:worker/task:0"

    # No notice set yet: the op runs without raising.
    with ops.device(self.device_t1):
      gen_check_preemption_op.check_preemption(preemption_key=preemption_key)

    # Simulate a preemption notifier callback invocation.
    context.context().set_config_key_value(preemption_key, preemption_task)
    with self.assertRaises(errors.AbortedError) as cm:
      with ops.device(self.device_t2):
        gen_check_preemption_op.check_preemption(preemption_key=preemption_key)
    # The preempted task name is carried in the error's payload.
    self.assertEqual(
        cm.exception.experimental_payloads.get(
            b"type.googleapis.com/tensorflow.distributed_runtime.WorkerPreemption"
        ), preemption_task.encode())

  @test_util.run_in_async_and_sync_mode
  def testServerAdded(self):
    """Add a server to cluster, and run remote ops on it."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      x2 = array_ops.ones([2, 2])

    # Test new server accessing resources on old server
    with ops.device(self.device_t3):
      y = math_ops.matmul(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Test old server accessing resources on new server
    with ops.device(self.device_t2):
      y = math_ops.matmul(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testServerRemoved(self):
    """Remove a server from cluster, and run ops on cluster."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])
    with ops.device(self.device_t2):
      x2 = array_ops.ones([2, 2])

    with ops.device(self.device_t1):
      y = math_ops.matmul(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    context.update_server_def(server_def=self.server_def_s1)
    with ops.device(self.device_t1):
      y = math_ops.matmul(x1, x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Running ops on removed server s2 throws an exception
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device(self.device_t2):
        y = math_ops.matmul(x1, x2)
    self.assertIn("unknown device", cm.exception.message)

    # TODO(haoyuzhang): raise and catch exception when accessing tensors on
    # the removed servers.

  @test_util.run_in_async_and_sync_mode
  def testServerReplaced(self):
    """Replace remote host_port for a task, and run ops on cluster."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    context.update_server_def(server_def=self.server_def_s1_s3)
    with ops.device(self.device_t2):
      y = math_ops.matmul(x1, x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testFunctionServerAdded(self):
    """Add a server to cluster, and run remote function on it."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    with ops.device(self.device_t3):
      x2 = array_ops.ones([2, 2])
    with ops.device(self.device_t1):
      y = worker_fn(x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testFunctionServerRemoved(self):
    """Remove a server from cluster, and run ops on cluster."""

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1)

    with ops.device(self.device_t1):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Running functions on removed server s2 throws an exception.
    # NOTE: the message substring is checked without a leading space to match
    # the op-based test above and stay robust to message formatting.
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device(self.device_t2):
        y = worker_fn(x1)
    self.assertIn("unknown device", cm.exception.message)

    # TODO(haoyuzhang): raise and catch exception when accessing tensors on
    # the removed servers.

  @test_util.run_in_async_and_sync_mode
  def testFunctionServerRemovedAddedBack(self):
    """Add and remove a server, and run functions on cluster."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    context.update_server_def(server_def=self.server_def_s1_s2)
    with ops.device(self.device_t2):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testFunctionServerReplaced(self):
    """Replace remote host_port for a task, and run functions on cluster."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1_s3)
    with ops.device(self.device_t2):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testFunctionRegisteredAndRemoved(self):
    """Update cluster when other function are registered and removed."""
    with ops.device(self.device_local):
      x1 = array_ops.ones([2, 2])

    num_calls = 30
    self._coord = coordinator.Coordinator()

    def update_server_def_fn():
      # Alternate between two cluster configurations while the main thread
      # registers and removes functions concurrently.
      with self._coord.stop_on_exception():
        for i in range(num_calls):
          context.update_server_def(
              server_def=(self.server_def_s1_s2 if i %
                          2 == 0 else self.server_def_s1_s3))

    t = threading.Thread(target=update_server_def_fn)
    t.start()

    for _ in range(num_calls):

      @def_function.function
      def worker_fn(i):
        return math_ops.matmul(i, i)

      concrete_fn = worker_fn.get_concrete_function(x1)
      del concrete_fn
      del worker_fn

    # No exception should be thrown from the thread
    self._coord.join([t])

  def testPendingNodesServerReplaced(self):
    """Update cluster when nodes are still pending on remote workers."""
    with ops.device(self.device_local):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    # Add enough ops so they are pending when changing the cluster
    num_nodes = 10
    ret = [None] * num_nodes
    for i in range(num_nodes):
      with ops.device(self.device_t1):
        ret[i] = worker_fn(x1)
    # While nodes are still pending on worker s1, replace worker s2 with s3.
    context.update_server_def(server_def=self.server_def_s1_s3)
    with ops.device(self.device_t2):
      y = worker_fn(x1)
    for i in range(num_nodes):
      np.testing.assert_array_equal([[2, 2], [2, 2]], ret[i].numpy())
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testMultiThreadPendingNodesServerReplaced(self):
    """Update cluster when other remote function calls are being launched."""
    with ops.device(self.device_local):
      x1 = array_ops.ones([2, 2])

    num_calls = 10
    # The lock serializes function launches against cluster updates so each
    # call runs entirely under one cluster configuration.
    lock = threading.Lock()

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    def thread_fn(device, results):
      for i in range(num_calls):
        lock.acquire()
        with ops.device(device):
          y = worker_fn(x1)
        results[i] = y.numpy()
        lock.release()

    def update_server_def_fn():
      for i in range(num_calls):
        lock.acquire()
        context.update_server_def(
            server_def=(self.server_def_s1_s2 if i %
                        2 == 0 else self.server_def_s1_s3))
        lock.release()

    t1_results = [None] * num_calls
    t2_results = [None] * num_calls
    threads = []
    threads.append(
        threading.Thread(target=thread_fn, args=(self.device_t1, t1_results)))
    threads.append(
        threading.Thread(target=thread_fn, args=(self.device_t2, t2_results)))
    threads.append(threading.Thread(target=update_server_def_fn))
    for t in threads:
      t.start()
    for t in threads:
      t.join()
    for result in t1_results + t2_results:
      np.testing.assert_array_equal([[2, 2], [2, 2]], result)

  def testMultiThreadPendingNodesLockFree(self):
    """Update cluster when other remote function calls are being launched."""

    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    num_calls = 10
    self._coord = coordinator.Coordinator()

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    def thread_fn(device, results):
      # Unlike testMultiThreadPendingNodesServerReplaced, no lock: function
      # calls may interleave arbitrarily with cluster updates.
      for i in range(num_calls):
        with self._coord.stop_on_exception():
          with ops.device(device):
            results[i] = worker_fn(x1).numpy()

    def update_server_def_fn():
      for _ in range(30):
        with self._coord.stop_on_exception():
          context.update_server_def(server_def=self.server_def_s1_s2)

    t1_results = [None] * num_calls
    t2_results = [None] * num_calls
    threads = []
    threads.append(
        threading.Thread(target=thread_fn, args=(self.device_t1, t1_results)))
    threads.append(
        threading.Thread(target=thread_fn, args=(self.device_t2, t2_results)))
    threads.append(threading.Thread(target=update_server_def_fn))
    for t in threads:
      t.start()
    self._coord.join(threads)
    for result in t1_results + t2_results:
      np.testing.assert_array_equal([[2, 2], [2, 2]], result)

  @test_util.run_in_async_and_sync_mode
  def testDistributedFunctionServerAdded(self):
    """Add a server to cluster, and run distributed function on it."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      with ops.device(self.device_t2):
        mul = math_ops.matmul(i, i)
      return mul - array_ops.zeros_like(mul)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testDistributedFunctionServerRemovedAddedBack(self):
    """Add then remove a server, and run distributed function on cluster."""
    with ops.device(self.device_local):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      with ops.device(self.device_t1):
        mul = math_ops.matmul(i, i)
      return mul - array_ops.zeros_like(mul)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1)
    with ops.device(self.device_t1):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    context.update_server_def(server_def=self.server_def_s1_s2)
    with ops.device(self.device_t2):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testDistributedFunctionBothServersReplaced(self):
    """Tests that replacing servers works correctly.

    We create two servers, t1 and t2. We first replace t2, then we replace t1.

    Among other things, this ensures that both already existing, and
    restarted workers have the context view IDs correctly updated.
    """
    with ops.device(self.device_local):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      with ops.device(self.device_t1):
        mul = math_ops.matmul(i, i)
      with ops.device(self.device_t2):
        add = mul + i
      return add - i

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    # Replace task2
    context.update_server_def(server_def=self.server_def_s1_s3)
    for device in (self.device_t1, self.device_t2):
      with ops.device(device):
        y = worker_fn(x1)
      np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Then replace task1
    context.update_server_def(server_def=self.server_def_s4_s3)
    for device in (self.device_t1, self.device_t2):
      with ops.device(device):
        y = worker_fn(x1)
      np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  def testDistributedFunctionPendingNodesServerReplaced(self):
    """Update cluster while distributed function calls are still pending."""
    with ops.device(self.device_local):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      with ops.device(self.device_t1):
        mul = math_ops.matmul(i, i)
      with ops.device(self.device_t2):
        add = mul + i
      return add - i

    worker_fn.get_concrete_function(x1)

    num_calls = 10
    self._coord = coordinator.Coordinator()

    def thread_fn(device, results):
      with self._coord.stop_on_exception():
        for i in range(num_calls):
          with ops.device(device):
            y = worker_fn(x1)
          results[i] = y.numpy()

    def update_server_def_fn():
      with self._coord.stop_on_exception():
        for i in range(num_calls):
          context.update_server_def(
              server_def=(self.server_def_s1_s2_s3 if i %
                          2 == 0 else self.server_def_s1_s2))

    results = [None] * num_calls
    threads = []
    threads.append(
        threading.Thread(target=thread_fn, args=(self.device_t1, results)))
    threads.append(threading.Thread(target=update_server_def_fn))
    for t in threads:
      t.start()
    self._coord.join(threads)
    for result in results:
      np.testing.assert_array_equal([[2, 2], [2, 2]], result)

  def testParameterServerMultiExecutors(self):
    """Run variable updates from multiple executors during cluster updates."""
    context.update_server_def(server_def=self.server_def_s1_s2_s3_s4)

    with ops.device(self.device_t1):
      v1 = variables.Variable(initial_value=0.)
    with ops.device(self.device_t2):
      v2 = variables.Variable(initial_value=10.)

    @def_function.function
    def worker_fn():
      x1 = v1.read_value()
      x2 = v2.read_value()
      grad = (x1 + x2) * 0.1
      v1.assign_add(grad)
      v2.assign_sub(grad)
      return v1 + v2

    worker_fn.get_concrete_function()

    # Separate synchronous executors so the two worker threads do not share
    # the default executor.
    executor_t3 = executor.new_executor(enable_async=False)
    executor_t4 = executor.new_executor(enable_async=False)

    num_calls = 10
    self._coord = coordinator.Coordinator()

    def thread_fn(executor_obj, device, results):
      with self._coord.stop_on_exception():
        for i in range(num_calls):
          with context.executor_scope(executor_obj):
            with ops.device(device):
              results[i] = worker_fn()

    def update_server_def_fn():
      with self._coord.stop_on_exception():
        for _ in range(30):
          context.update_server_def(server_def=self.server_def_s1_s2_s3_s4)

    t3_results = [None] * num_calls
    t4_results = [None] * num_calls
    threads = []
    threads.append(
        threading.Thread(
            target=thread_fn, args=(executor_t3, self.device_t3, t3_results)))
    threads.append(
        threading.Thread(
            target=thread_fn, args=(executor_t4, self.device_t4, t4_results)))
    threads.append(threading.Thread(target=update_server_def_fn))
    for t in threads:
      t.start()
    self._coord.join(threads)

    # Cannot assert individual values since the results are non-deterministic.
    # By summing up the value we ensure that there are all reasonable and valid
    # numbers (not `None` or `NaN`).
    total = np.sum(t3_results + t4_results)
    self.assertGreater(total, 0)

  def testCheckAlive(self):
    """check_alive reports liveness of known tasks and rejects unknown ones."""
    with self.assertRaisesRegex(ValueError, "Context is not initialized."):
      context.check_alive("/job:remote_device/task:0")
    context.context().ensure_initialized()

    self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0"))
    self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1"))

    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "Unable to find worker interface"):
      context.check_alive("/job:remote_device/replica:0/task:10")
| |
| |
if __name__ == "__main__":
  # These tests exercise remote eager execution, so eager mode must be enabled
  # before test.main() starts running test cases.
  ops.enable_eager_execution()
  test.main()