| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for Cluster Resolvers.""" |
| |
| from tensorflow.python.client import session |
| from tensorflow.python.distribute.cluster_resolver import cluster_resolver |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import config |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.platform import test |
| from tensorflow.python.training import server_lib |
| |
| mock = test.mock |
| |
| |
class MockBaseClusterResolver(cluster_resolver.ClusterResolver):
  """Minimal concrete ClusterResolver used to exercise base-class behavior.

  It reports no cluster spec, an empty master address and an empty
  environment, so tests can call inherited methods such as
  `num_accelerators` without configuring a real cluster.
  """

  def cluster_spec(self):
    """Returns None; this fake resolver describes no cluster."""
    return None

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns an empty master address regardless of the arguments."""
    del task_type, task_id, rpc_layer  # Unused.
    return ""

  @property
  def environment(self):
    """Returns an empty environment string.

    Defined as a property so it is read like the other resolvers'
    `environment` (e.g. `resolver.environment == "cloud"`), instead of
    evaluating to a bound method when accessed without a call.
    """
    return ""
| |
| |
@test_util.run_all_in_graph_and_eager_modes
class BaseClusterResolverTest(test.TestCase):
  """Tests ClusterResolver.num_accelerators against mocked device listings."""

  def _stubDevices(self, mock_list_devices, mock_eager_list_devices, devices):
    """Makes both patched device-listing functions report `devices`.

    Args:
      mock_list_devices: Mock patched over `session.BaseSession.list_devices`;
        it returns `session._DeviceAttributes` built from `devices`.
      mock_eager_list_devices: Mock patched over `config.list_logical_devices`;
        it returns `devices` unchanged.
      devices: List of `context.LogicalDevice` objects to expose.
    """
    # NOTE(review): both listing APIs are stubbed, presumably because the
    # resolver consults one or the other depending on graph vs. eager mode.
    mock_eager_list_devices.return_value = devices
    mock_list_devices.return_value = [
        session._DeviceAttributes(  # pylint: disable=protected-access
            device.name, device.device_type, 1024, 0)
        for device in devices
    ]

  @mock.patch.object(config, "list_logical_devices")
  @mock.patch.object(session.BaseSession, "list_devices")
  def testNumAcceleratorsSuccess(self, mock_list_devices,
                                 mock_eager_list_devices):
    """All four GPUs of a single task are counted."""
    gpus = [
        context.LogicalDevice("/job:worker/task:0/device:GPU:0", "GPU"),
        context.LogicalDevice("/job:worker/task:0/device:GPU:1", "GPU"),
        context.LogicalDevice("/job:worker/task:0/device:GPU:2", "GPU"),
        context.LogicalDevice("/job:worker/task:0/device:GPU:3", "GPU"),
    ]
    self._stubDevices(mock_list_devices, mock_eager_list_devices, gpus)

    resolver = MockBaseClusterResolver()
    self.assertEqual(resolver.num_accelerators(), {"GPU": 4})

  @mock.patch.object(config, "list_logical_devices")
  @mock.patch.object(session.BaseSession, "list_devices")
  def testNumAcceleratorsMultiDeviceSuccess(self, mock_list_devices,
                                            mock_eager_list_devices):
    """Different accelerator types are counted under separate keys."""
    accelerators = [
        context.LogicalDevice("/job:worker/task:0/device:TPU:0", "TPU"),
        context.LogicalDevice("/job:worker/task:0/device:TPU:1", "TPU"),
        context.LogicalDevice("/job:worker/task:0/device:TPU:2", "TPU"),
        context.LogicalDevice("/job:worker/task:0/device:TPU:3", "TPU"),
        context.LogicalDevice("/job:worker/task:0/device:GPU:0", "GPU"),
        context.LogicalDevice("/job:worker/task:0/device:GPU:1", "GPU"),
        context.LogicalDevice("/job:worker/task:0/device:GPU:2", "GPU"),
        context.LogicalDevice("/job:worker/task:0/device:GPU:3", "GPU"),
    ]
    self._stubDevices(mock_list_devices, mock_eager_list_devices, accelerators)

    resolver = MockBaseClusterResolver()
    self.assertEqual(resolver.num_accelerators(), {"TPU": 4, "GPU": 4})

  @mock.patch.object(config, "list_logical_devices")
  @mock.patch.object(session.BaseSession, "list_devices")
  def testNumAcceleratorsFilterTasks(self, mock_list_devices,
                                     mock_eager_list_devices):
    """Only devices whose job and task match task_type/task_id are counted."""
    accelerators = [
        context.LogicalDevice("/job:worker1/task:0/device:TPU:0", "TPU"),
        context.LogicalDevice("/job:worker1/task:0/device:TPU:1", "TPU"),
        context.LogicalDevice("/job:worker1/task:0/device:GPU:0", "GPU"),
        context.LogicalDevice("/job:worker1/task:0/device:GPU:1", "GPU"),
        context.LogicalDevice("/job:worker2/task:1/device:TPU:2", "TPU"),
        context.LogicalDevice("/job:worker2/task:2/device:TPU:3", "TPU"),
        context.LogicalDevice("/job:worker2/task:3/device:GPU:2", "GPU"),
        context.LogicalDevice("/job:worker2/task:4/device:GPU:3", "GPU"),
    ]
    self._stubDevices(mock_list_devices, mock_eager_list_devices, accelerators)

    resolver = MockBaseClusterResolver()
    # (task_type, task_id, expected counts), checked in order; the first
    # mismatch fails the test.
    cases = (
        ("worker1", 0, {"TPU": 2, "GPU": 2}),
        ("worker2", 3, {"GPU": 1}),
        ("worker2", 4, {"GPU": 1}),
    )
    for task_type, task_id, expected in cases:
      self.assertEqual(
          resolver.num_accelerators(task_type=task_type, task_id=task_id),
          expected)
| |
| |
class UnionClusterResolverTest(test.TestCase):
  """Tests for SimpleClusterResolver and UnionClusterResolver merging."""
  # TODO(frankchn): Transform to parameterized test after it is included in the
  # TF open source codebase.

  # Expected ClusterDef for the two-"ps" / three-"worker" cluster, whether it
  # comes from a single resolver or from a union of a ps-only and a
  # worker-only resolver.
  _PS_AND_WORKER_PROTO = """
    job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                     tasks { key: 1 value: 'ps1:2222' } }
    job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
                         tasks { key: 1 value: 'worker1:2222' }
                         tasks { key: 2 value: 'worker2:2222' } }
    """

  def _psAndWorkerClusterSpec(self):
    """Returns a new spec with two "ps" tasks and three "worker" tasks."""
    return server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
    })

  def _psOnlyClusterSpec(self):
    """Returns a new spec containing only the two "ps" tasks."""
    return server_lib.ClusterSpec({"ps": ["ps0:2222", "ps1:2222"]})

  def _workerOnlyClusterSpec(self):
    """Returns a new spec containing only the three "worker" tasks."""
    return server_lib.ClusterSpec(
        {"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]})

  def _unionOfSimpleResolvers(self, *cluster_specs):
    """Unions one SimpleClusterResolver per spec, in argument order."""
    return cluster_resolver.UnionClusterResolver(*[
        cluster_resolver.SimpleClusterResolver(spec) for spec in cluster_specs
    ])

  def _psAndWorkerMaster(self, *args, **kwargs):
    """Calls `master(*args, **kwargs)` on a resolver over the ps+worker spec."""
    resolver = cluster_resolver.SimpleClusterResolver(
        self._psAndWorkerClusterSpec())
    return resolver.master(*args, **kwargs)

  def _overrideAndCheckTask(self, resolver, task_type, task_id, rpc_layer):
    """Assigns task_type/task_id/rpc_layer on `resolver`, then reads them."""
    resolver.task_type = task_type
    resolver.task_id = task_id
    resolver.rpc_layer = rpc_layer

    self.assertEqual(resolver.task_type, task_type)
    self.assertEqual(resolver.task_id, task_id)
    self.assertEqual(resolver.rpc_layer, rpc_layer)

  def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
    """Asserts that `cluster_spec` matches `expected_proto` in every form.

    The check runs on the spec itself and on ClusterSpecs rebuilt from the
    spec object, from its ClusterDef proto and from its dict form, so those
    round-trip conversions are covered as well.
    """

    def round_tripped_specs():
      # Lazily built, so a conversion is only attempted once every earlier
      # comparison has passed.
      yield cluster_spec
      yield server_lib.ClusterSpec(cluster_spec)
      yield server_lib.ClusterSpec(cluster_spec.as_cluster_def())
      yield server_lib.ClusterSpec(cluster_spec.as_dict())

    for spec in round_tripped_specs():
      self.assertProtoEquals(expected_proto, spec.as_cluster_def())

  def testSingleClusterResolver(self):
    """A union of one resolver reproduces that resolver's cluster."""
    union_resolver = self._unionOfSimpleResolvers(
        self._psAndWorkerClusterSpec())
    self._verifyClusterSpecEquality(union_resolver.cluster_spec(),
                                    self._PS_AND_WORKER_PROTO)

  def testInitSimpleClusterResolver(self):
    """Constructor arguments are exposed as resolver properties."""
    simple_resolver = cluster_resolver.SimpleClusterResolver(
        self._psAndWorkerClusterSpec(),
        task_type="ps",
        task_id=1,
        environment="cloud",
        num_accelerators={"GPU": 8},
        rpc_layer="grpc")

    self.assertEqual(simple_resolver.task_type, "ps")
    self.assertEqual(simple_resolver.task_id, 1)
    self.assertEqual(simple_resolver.environment, "cloud")
    self.assertEqual(simple_resolver.num_accelerators(), {"GPU": 8})
    self.assertEqual(simple_resolver.rpc_layer, "grpc")

  def testOverrideSimpleClusterResolver(self):
    """task_type, task_id and rpc_layer can be reassigned after creation."""
    simple_resolver = cluster_resolver.SimpleClusterResolver(
        self._psAndWorkerClusterSpec(),
        task_type="ps",
        task_id=1,
        environment="cloud",
        num_accelerators={"GPU": 8},
        rpc_layer="grpc")

    self._overrideAndCheckTask(simple_resolver, "worker", 2, "http")

  def testSimpleOverrideMasterWithTaskIndexZero(self):
    """Task index 0 resolves to the first worker address."""
    self.assertEqual(
        self._psAndWorkerMaster("worker", 0, rpc_layer="grpc"),
        "grpc://worker0:2222")

  def testSimpleOverrideMasterWithRpcLayer(self):
    """An explicit rpc_layer is prepended as a URL scheme."""
    self.assertEqual(
        self._psAndWorkerMaster("worker", 2, rpc_layer="grpc"),
        "grpc://worker2:2222")

  def testSimpleOverrideMaster(self):
    """Without an rpc_layer the bare task address is returned."""
    self.assertEqual(self._psAndWorkerMaster("worker", 2), "worker2:2222")

  def testUnionClusterResolverGetProperties(self):
    """Task properties come from the first resolver and can be overridden."""
    resolver1 = cluster_resolver.SimpleClusterResolver(
        self._psAndWorkerClusterSpec(),
        task_type="ps",
        task_id=1,
        environment="cloud",
        num_accelerators={"GPU": 8},
        rpc_layer="grpc")
    resolver2 = cluster_resolver.SimpleClusterResolver(
        server_lib.ClusterSpec({
            "ps": ["ps2:2222", "ps3:2222"],
            "worker": ["worker3:2222", "worker4:2222", "worker5:2222"]
        }),
        task_type="worker",
        task_id=2,
        environment="local",
        num_accelerators={"GPU": 16},
        rpc_layer="http")
    union_resolver = cluster_resolver.UnionClusterResolver(resolver1, resolver2)

    # Every value read here matches resolver1, not resolver2.
    self.assertEqual(union_resolver.task_type, "ps")
    self.assertEqual(union_resolver.task_id, 1)
    self.assertEqual(union_resolver.environment, "cloud")
    self.assertEqual(union_resolver.num_accelerators(), {"GPU": 8})
    self.assertEqual(union_resolver.rpc_layer, "grpc")

    self._overrideAndCheckTask(union_resolver, "worker", 2, "http")

  def testTwoNonOverlappingJobMergedClusterResolver(self):
    """Resolvers with disjoint job names merge into one cluster."""
    union_cluster = self._unionOfSimpleResolvers(
        self._psOnlyClusterSpec(), self._workerOnlyClusterSpec())
    self._verifyClusterSpecEquality(union_cluster.cluster_spec(),
                                    self._PS_AND_WORKER_PROTO)

  def testMergedClusterResolverMaster(self):
    """master() on a union: empty by default, else the task's address."""
    union_cluster = self._unionOfSimpleResolvers(
        self._psOnlyClusterSpec(), self._workerOnlyClusterSpec())

    self.assertEqual(union_cluster.master(), "")
    self.assertEqual(union_cluster.master("worker", 1), "worker1:2222")
    self.assertEqual(
        union_cluster.master("worker", 1, rpc_layer="grpc"),
        "grpc://worker1:2222")

  def testOverlappingJobMergedClusterResolver(self):
    """List-form jobs with the same name are concatenated and renumbered."""
    union_cluster = self._unionOfSimpleResolvers(
        server_lib.ClusterSpec({"worker": ["worker4:2222", "worker5:2222"]}),
        self._workerOnlyClusterSpec())

    # The second resolver's workers continue numbering after the first's.
    expected_proto = """
    job { name: 'worker' tasks { key: 0 value: 'worker4:2222' }
                         tasks { key: 1 value: 'worker5:2222' }
                         tasks { key: 2 value: 'worker0:2222' }
                         tasks { key: 3 value: 'worker1:2222' }
                         tasks { key: 4 value: 'worker2:2222' } }
    """
    self._verifyClusterSpecEquality(union_cluster.cluster_spec(),
                                    expected_proto)

  def testOverlappingSparseJobMergedClusterResolverThrowError(self):
    """Two sparse jobs that share a task index (7) raise KeyError."""
    union_cluster = self._unionOfSimpleResolvers(
        server_lib.ClusterSpec({
            "worker": {7: "worker4:2222", 9: "worker5:2222"}
        }),
        server_lib.ClusterSpec({
            "worker": {3: "worker0:2222", 6: "worker1:2222", 7: "worker2:2222"}
        }))

    with self.assertRaises(KeyError):
      union_cluster.cluster_spec()

  def testOverlappingDictAndListThrowError(self):
    """A list job and a dict job whose indices collide raise KeyError."""
    # Presumably the list form occupies indices 0 and 1, so dict key 1
    # collides with it.
    union_cluster = self._unionOfSimpleResolvers(
        server_lib.ClusterSpec({"worker": ["worker4:2222", "worker5:2222"]}),
        server_lib.ClusterSpec({
            "worker": {1: "worker0:2222", 2: "worker1:2222", 3: "worker2:2222"}
        }))

    with self.assertRaises(KeyError):
      union_cluster.cluster_spec()

  def testOverlappingJobNonOverlappingKey(self):
    """Sparse jobs with disjoint task indices merge, keeping their indices."""
    union_cluster = self._unionOfSimpleResolvers(
        server_lib.ClusterSpec({
            "worker": {5: "worker4:2222", 9: "worker5:2222"}
        }),
        server_lib.ClusterSpec({
            "worker": {3: "worker0:2222", 6: "worker1:2222", 7: "worker2:2222"}
        }))

    expected_proto = """
    job { name: 'worker' tasks { key: 3 value: 'worker0:2222' }
                         tasks { key: 5 value: 'worker4:2222' }
                         tasks { key: 6 value: 'worker1:2222' }
                         tasks { key: 7 value: 'worker2:2222' }
                         tasks { key: 9 value: 'worker5:2222' }}
    """
    self._verifyClusterSpecEquality(union_cluster.cluster_spec(),
                                    expected_proto)

  def testMixedModeNonOverlappingKey(self):
    """A list job and a sparse dict job merge when indices do not collide."""
    union_cluster = self._unionOfSimpleResolvers(
        server_lib.ClusterSpec({"worker": ["worker4:2222", "worker5:2222"]}),
        server_lib.ClusterSpec({
            "worker": {3: "worker0:2222", 6: "worker1:2222", 7: "worker2:2222"}
        }))

    expected_proto = """
    job { name: 'worker' tasks { key: 0 value: 'worker4:2222' }
                         tasks { key: 1 value: 'worker5:2222' }
                         tasks { key: 3 value: 'worker0:2222' }
                         tasks { key: 6 value: 'worker1:2222' }
                         tasks { key: 7 value: 'worker2:2222' }}
    """
    self._verifyClusterSpecEquality(union_cluster.cluster_spec(),
                                    expected_proto)

  def testRetainSparseJobWithNoMerging(self):
    """A single sparse job keeps its original task indices."""
    union_cluster = self._unionOfSimpleResolvers(
        server_lib.ClusterSpec({
            "worker": {1: "worker0:2222", 3: "worker1:2222", 5: "worker2:2222"}
        }))

    expected_proto = """
    job { name: 'worker' tasks { key: 1 value: 'worker0:2222' }
                         tasks { key: 3 value: 'worker1:2222' }
                         tasks { key: 5 value: 'worker2:2222' } }
    """
    self._verifyClusterSpecEquality(union_cluster.cluster_spec(),
                                    expected_proto)
| |
| |
# TODO(saeta): Include tests for master resolution beyond the basic
# SimpleClusterResolver / UnionClusterResolver cases above.

if __name__ == "__main__":
  # Hand control to TensorFlow's test runner.
  test.main()