# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for distributed_table."""
import copy
import os

from absl.testing import parameterized

# The following import helps load the keras injection function we use in
# parameter_server_strategy_v2 -- keras_deps.get_load_context_function.
from tensorflow import keras # pylint: disable=unused-import
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.distribute.coordinator import coordinator_context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.saved_model import load as tf_load
from tensorflow.python.saved_model import save as tf_save

source_combination = combinations.combine(source=["textfile", "keyvaluetensor"])


class DistributedTableTest(test.TestCase, parameterized.TestCase):

@classmethod
def setUpClass(cls):
super(DistributedTableTest, cls).setUpClass()
cls.cluster = multi_worker_test_base.create_multi_process_cluster(
num_workers=2, num_ps=3, rpc_layer="grpc")
cls.cluster_resolver = cls.cluster.cluster_resolver

  @classmethod
def tearDownClass(cls):
super(DistributedTableTest, cls).tearDownClass()
cls.cluster.stop()

  def make_initializer(self, init_source, vals):
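    """Returns a table initializer for `vals`, keyed by 0..len(vals) - 1."""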
if init_source == "textfile":
file = os.path.join(self.get_temp_dir(), "text_file_initializer")
with open(file, "w") as f:
f.write("\n".join(str(v) for v in vals) + "\n")
return lookup_ops.TextFileInitializer(
filename=file,
key_dtype=dtypes.int64,
key_index=lookup_ops.TextFileIndex.LINE_NUMBER,
value_dtype=dtypes.int64,
value_index=lookup_ops.TextFileIndex.WHOLE_LINE)
elif init_source == "keyvaluetensor":
keys_tensor = constant_op.constant(
list(range(len(vals))), dtype=dtypes.int64)
vals_tensor = constant_op.constant(vals, dtype=dtypes.int64)
return lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor)
else:
raise ValueError("Unrecognized init_source: " + init_source)

  def createStaticHashTable(self,
init_source=None,
vals=None,
default_value=None,
initializer=None):
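    """Returns a StaticHashTable, building an initializer unless given one."""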
if not initializer:
initializer = self.make_initializer(init_source, vals)
return lookup_ops.StaticHashTable(
initializer=initializer, default_value=default_value)

  def makeDatasetFromTensorWithoutUsingResource(self, input_context, tensor):
"""Returns a dataset made from `tensor`. To be called in a dataset_fn."""
global_batch_size = 24
batch_size = input_context.get_per_replica_batch_size(global_batch_size)
dataset = dataset_ops.DatasetV2.from_tensors(tensor).repeat().batch(
batch_size, drop_remainder=True)
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
dataset = dataset.prefetch(2) # This prefetches 2 batches per device.
return dataset

  @combinations.generate(source_combination)
def testCreateDistributedTableInScope(self, source):
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver)
coordinator_lib.ClusterCoordinator(strategy=strategy)
with strategy.scope():
lookuptable = self.createStaticHashTable(
init_source=source, vals=[0, 1, 2], default_value=-2)
self.assertIsInstance(lookuptable, ps_values.DistributedTable)
self.assertEqual(self.evaluate(lookuptable.size()), 3)
# Lookup on the coordinator.
output = lookuptable.lookup(
constant_op.constant([0, 1, -1], dtype=dtypes.int64))
self.assertAllEqual([0, 1, -2], output)
self.assertEqual(lookuptable.size(), 3)

  @combinations.generate(source_combination)
def testCopyDistributedTable(self, source):
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver)
coordinator_lib.ClusterCoordinator(strategy=strategy)
with strategy.scope():
lookuptable = self.createStaticHashTable(
init_source=source, vals=[0, 1, 2], default_value=-2)
new_table = copy.copy(lookuptable)
    # No new coordinator instance or distributed table is created by the copy.
self.assertDictEqual(lookuptable.__dict__, new_table.__dict__)

  @combinations.generate(source_combination)
def testCreateLookupInDatasetFnUnderScope(self, source):
    # TODO(wxinyi): warn the user about the inefficiency of this workflow,
    # i.e. creating a `StaticHashTable` inside a `@tf.function`-wrapped
    # `dataset_fn` to be distributed with `distribute_datasets_from_function`
    # and `create_per_worker_dataset`.
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver)
coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)
with strategy.scope():
def dataset_fn(input_context):
some_out_of_range_tensor = constant_op.constant(10, dtype=dtypes.int64)
lookuptable = self.createStaticHashTable(
init_source=source, vals=[0, 1, 2], default_value=-2)
self.assertNotIsInstance(lookuptable, ps_values.DistributedTable)
generation_tensor = lookuptable.lookup(some_out_of_range_tensor)
dataset = self.makeDatasetFromTensorWithoutUsingResource(
input_context, generation_tensor)
return dataset
@def_function.function
def per_worker_dataset_fn():
return strategy.distribute_datasets_from_function(dataset_fn)
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
@def_function.function
def worker_fn(iterator):
return math_ops.reduce_sum(next(iterator))
result = []
for _ in range(10):
result.append(
coordinator.schedule(worker_fn, args=(per_worker_iterator,)))
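    # Each batch contains 24 copies of lookup(10) == -2 (the default value),
    # so every per-batch sum is -48.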
for r in result:
returned_input = r.fetch()
self.assertAllClose(-48, returned_input)

  @combinations.generate(source_combination)
def testAccessingResourceHandleInDatasetFnWithoutMap(self, source):
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver)
coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)
with strategy.scope():
lookuptable = self.createStaticHashTable(
init_source=source, vals=[0, 1, 2], default_value=-2)
def dataset_fn(input_context):
some_out_of_range_tensor = constant_op.constant(10, dtype=dtypes.int64)
self.assertIsInstance(lookuptable, ps_values.DistributedTable)
generation_tensor = lookuptable.lookup(some_out_of_range_tensor)
dataset = self.makeDatasetFromTensorWithoutUsingResource(
input_context, generation_tensor)
return dataset
@def_function.function
def per_worker_dataset_fn():
return strategy.distribute_datasets_from_function(dataset_fn)
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
@def_function.function
def worker_fn(iterator):
return math_ops.reduce_sum(next(iterator))
result = []
for _ in range(10):
result.append(
coordinator.schedule(worker_fn, args=(per_worker_iterator,)))
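    # As above, each batch sums 24 copies of the default value -2, i.e. -48.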
for r in result:
returned_input = r.fetch()
self.assertAllClose(-48, returned_input)

  @combinations.generate(
combinations.combine(
source=["textfile", "keyvaluetensor"],
create_datasets_under_scope=[True, False],
using_dataset_instance_not_function=[True, False],
create_per_worker_dataset_takes_instance=[True, False]))
def testCreateTableUnderScopeCombo(self, source,
create_datasets_under_scope,
using_dataset_instance_not_function,
create_per_worker_dataset_takes_instance):
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver)
coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)
with strategy.scope():
lookup_table = self.createStaticHashTable(
init_source=source, vals=[0, 1, 2], default_value=-2)
if using_dataset_instance_not_function:
def per_worker_dataset_fn():
dataset = dataset_ops.DatasetV2.from_tensors(
constant_op.constant([0, 1, 3], dtype=dtypes.int64))
dataset = dataset.repeat().batch(24, drop_remainder=True).prefetch(2)
dataset = dataset.map(lookup_table.lookup)
return strategy.experimental_distribute_dataset(dataset)
else:
def per_worker_dataset_fn():
def dataset_fn(input_context):
batch_size = input_context.get_per_replica_batch_size(24)
dataset = dataset_ops.DatasetV2.from_tensors(
constant_op.constant([0, 1, 3], dtype=dtypes.int64))
dataset = dataset.repeat().batch(batch_size, drop_remainder=True)
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
dataset = dataset.prefetch(2) # This prefetches 2 batches per device.
dataset = dataset.map(lookup_table.lookup)
return dataset
return strategy.distribute_datasets_from_function(dataset_fn)
if create_datasets_under_scope:
with strategy.scope():
if create_per_worker_dataset_takes_instance:
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn())
else:
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
else:
if create_per_worker_dataset_takes_instance:
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn())
else:
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
@def_function.function
def worker_fn(iterator):
return math_ops.reduce_sum(next(iterator))
result = []
for _ in range(10):
result.append(
coordinator.schedule(worker_fn, args=(per_worker_iterator,)))
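    # Each batch holds 24 copies of lookup([0, 1, 3]) == [0, 1, -2], so each
    # per-batch sum is -24.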
for r in result:
returned_input = r.fetch()
self.assertAllClose(-24, returned_input)

  @combinations.generate(
combinations.combine(
source=["textfile", "keyvaluetensor"],
create_datasets_under_scope=[True, False],
using_dataset_instance_not_function=[True, False],
create_per_worker_dataset_takes_instance=[True, False]))
def testCreateTableInDatasetCombo(self, source, create_datasets_under_scope,
using_dataset_instance_not_function,
create_per_worker_dataset_takes_instance):
    if using_dataset_instance_not_function and (
        not create_per_worker_dataset_takes_instance):
      # This case uses the `experimental_distribute_dataset` API to distribute
      # the dataset (instead of the `distribute_datasets_from_function` API),
      # and passes `create_per_worker_dataset` a function that returns the
      # distributed dataset (instead of passing it the distributed dataset
      # directly).
      # TODO(b/201775366): evaluate whether we need to handle this case.
      self.skipTest("Failed to serialize the input pipeline graph")
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver)
coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)
if using_dataset_instance_not_function:
def per_worker_dataset_fn():
        # If this line is called under strategy.scope(), the table becomes a
        # DistributedTable. Interestingly, after
        # `experimental_distribute_dataset` serializes the dataset on the
        # chief and deserializes it on the workers, `lookup_table` becomes a
        # RestoredDistributedTable instead of a DistributedTable. When its
        # `resource_handle` is accessed on a worker, no DispatchContext is
        # detected, so the restored resource handle is returned, which is the
        # one on the local worker. The LookupTableFindV2 op runs on the local
        # worker, too.
lookup_table = self.createStaticHashTable(
init_source=source, vals=[0, 1, 2], default_value=-2)
if create_datasets_under_scope:
self.assertIsInstance(lookup_table, ps_values.DistributedTable)
dataset = dataset_ops.DatasetV2.from_tensors(
constant_op.constant([0, 1, 3], dtype=dtypes.int64))
dataset = dataset.repeat().batch(24, drop_remainder=True).prefetch(2)
dataset = dataset.map(lookup_table.lookup)
return strategy.experimental_distribute_dataset(dataset)
else:
def per_worker_dataset_fn():
def dataset_fn(input_context):
          # When the creation of a StaticHashTable is wrapped inside a
          # `dataset_fn` to be distributed with
          # `distribute_datasets_from_function`, whether or not the call is
          # made under strategy.scope(), it creates a StaticHashTable on the
          # chief rather than a DistributedTable on the chief and workers.
          # Correspondingly, the LookupTableFindV2 op runs on the chief, and
          # there is send-recv communication for the lookup.
lookup_table = self.createStaticHashTable(
init_source=source, vals=[0, 1, 2], default_value=-2)
if create_datasets_under_scope:
self.assertIsInstance(lookup_table, lookup_ops.StaticHashTable)
self.assertNotIsInstance(lookup_table, ps_values.DistributedTable)
batch_size = input_context.get_per_replica_batch_size(24)
dataset = dataset_ops.DatasetV2.from_tensors(
constant_op.constant([0, 1, 3], dtype=dtypes.int64))
dataset = dataset.repeat().batch(batch_size, drop_remainder=True)
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
dataset = dataset.prefetch(2) # This prefetches 2 batches per device.
dataset = dataset.map(lookup_table.lookup)
return dataset
return strategy.distribute_datasets_from_function(dataset_fn)
if create_datasets_under_scope:
with strategy.scope():
if create_per_worker_dataset_takes_instance:
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn())
else:
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
else:
if create_per_worker_dataset_takes_instance:
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn())
else:
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
@def_function.function
def worker_fn(iterator):
return math_ops.reduce_sum(next(iterator))
result = []
for _ in range(10):
result.append(
coordinator.schedule(worker_fn, args=(per_worker_iterator,)))
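    # As in the test above, each batch sums 24 copies of [0, 1, -2] to -24.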
for r in result:
returned_input = r.fetch()
self.assertAllClose(-24, returned_input)

  @combinations.generate(source_combination)
def testAccessingTableInStepFunction(self, source):
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver)
coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)
with strategy.scope():
lookup_table = self.createStaticHashTable(
init_source=source, vals=[0, 1, 2], default_value=-2)
dataset = (
dataset_ops.DatasetV2.from_tensors(
constant_op.constant([0, 1, 3], dtype=dtypes.int64)).repeat().batch(
24, drop_remainder=True).prefetch(2))
dataset = dataset.map(lookup_table.lookup)
distributed_dataset = strategy.experimental_distribute_dataset(dataset)
distributed_dataset = coordinator.create_per_worker_dataset(
distributed_dataset)
@def_function.function
def worker_fn(iterator):
def replica_fn(inputs):
return math_ops.reduce_sum(lookup_table.lookup(inputs))
all_results = strategy.run(replica_fn, args=(next(iterator),))
return all_results
steps_per_epoch = 10
distributed_iterator = iter(distributed_dataset)
result = []
for _ in range(steps_per_epoch):
result.append(
coordinator.schedule(worker_fn, args=(distributed_iterator,)))
coordinator.join()
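    # The dataset map produces [0, 1, -2]; the second lookup in `replica_fn`
    # maps it to itself (0 -> 0, 1 -> 1, and the default -2 for the missing
    # key -2), so each batch of 24 again sums to -24.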
for r in result:
returned_input = r.fetch()
self.assertAllClose(-24, returned_input)

  @combinations.generate(source_combination)
def testAccessingResourceHandleInDatasetFnWithMapFnDefinedOutside(
self, source):
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver)
coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)
with strategy.scope():
lookuptable = self.createStaticHashTable(
init_source=source, vals=[0, 1, 2], default_value=-2)
def map_fn(vals):
return lookuptable.lookup(vals)
def dataset_fn(input_context):
generation_tensor = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
dataset = self.makeDatasetFromTensorWithoutUsingResource(
input_context, generation_tensor)
dataset = dataset.map(map_fn)
return dataset
@def_function.function
def per_worker_dataset_fn():
return strategy.distribute_datasets_from_function(dataset_fn)
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
@def_function.function
def worker_fn(iterator):
return math_ops.reduce_sum(next(iterator))
result = []
for _ in range(10):
# batch_size == 24 and each input is [0, 1, -2]
result.append(
coordinator.schedule(worker_fn, args=(per_worker_iterator,)))
for r in result:
returned_input = r.fetch()
self.assertAllClose(-24, returned_input)

  class Model(module.Module):
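    """A `tf.Module` that looks values up in an owned `StaticHashTable`."""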

    def __init__(self, init_source, filepath):
vals = [0, 1, 2]
if init_source == "textfile":
with open(filepath, "w") as f:
f.write("\n".join(str(v) for v in vals) + "\n")
self.initializer = lookup_ops.TextFileInitializer(
filepath, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER,
dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE)
else:
keys_tensor = constant_op.constant(
list(range(len(vals))), dtype=dtypes.int64)
vals_tensor = constant_op.constant(vals, dtype=dtypes.int64)
self.initializer = lookup_ops.KeyValueTensorInitializer(
keys_tensor, vals_tensor)
self.table = lookup_ops.StaticHashTable(
self.initializer, default_value=-2)

    @def_function.function(
input_signature=[tensor_spec.TensorSpec(None, dtypes.int64)])
def use_table(self, x):
return self.table.lookup(x)

  def verifyWorkerLocalInstance(self, coordinator, model):
    # Assert that the concrete function captures a worker-local resource on
    # each worker.
for worker in coordinator._cluster.workers:
with coordinator_context.with_dispatch_context(worker):
captures = model.use_table.get_concrete_function().captured_inputs
resource_capture = [t for t in captures if t.dtype == dtypes.resource]
self.assertNotEmpty(resource_capture)
for capture in resource_capture:
self.assertEqual(
capture.device,
device_util.canonicalize("/CPU:0", default=worker.device_name))

  @combinations.generate(source_combination)
def testInModelAndCapture(self, source):
file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")
model = self.Model(source, file_path)
func_captures = model.use_table.get_concrete_function(
).graph.external_captures
self.assertLen(func_captures, 2)
self.assertTrue(
any(model.table.resource_handle is t for t in func_captures))
deferred_captures = model.use_table.get_concrete_function(
).graph.deferred_external_captures
self.assertEmpty(deferred_captures)
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver)
coordinator = coordinator_lib.ClusterCoordinator(strategy)
with strategy.scope():
      distributed_model = self.Model(source, file_path)
func_captures = distributed_model.use_table.get_concrete_function(
).graph.external_captures
    # One fewer external capture, since the table handle becomes a closure in
    # the deferred external captures.
self.assertLen(func_captures, 1)
self.assertFalse(
any(model.table.resource_handle is t for t in func_captures))
deferred_captures = distributed_model.use_table.get_concrete_function(
).graph.deferred_external_captures
self.assertNotEmpty(deferred_captures)
self.verifyWorkerLocalInstance(coordinator, distributed_model)

  @combinations.generate(source_combination)
def testLookupInNestedTFWhileLoop(self, source):
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver)
coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)
file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")
with strategy.scope():
model = self.Model(source, file_path)
@def_function.function
def replica_fn(batch_data):
replica_result = array_ops.zeros(shape=(), dtype=dtypes.int64)
for _ in math_ops.range(10):
replica_result += math_ops.reduce_sum(model.use_table(batch_data))
return replica_result
@def_function.function
def step_fn(iterator):
step_result = array_ops.zeros(shape=(), dtype=dtypes.int64)
for _ in math_ops.range(10):
step_result += strategy.run(replica_fn, args=(next(iterator),))
return step_result
dataset = (
dataset_ops.DatasetV2.from_tensors(
constant_op.constant([0, 1, 3], dtype=dtypes.int64)).repeat().batch(
24, drop_remainder=True).prefetch(2))
distributed_dataset = coordinator.create_per_worker_dataset(
strategy.experimental_distribute_dataset(dataset))
results = []
for _ in range(10):
results.append(
coordinator.schedule(step_fn, args=(iter(distributed_dataset),)))
coordinator.join()
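    # `replica_fn` adds up ten lookups over its batch (10 * -24 = -240), and
    # `step_fn` accumulates ten replica results, so each result is -2400.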
for r in results:
self.assertAllClose(-2400, r.fetch())

  @combinations.generate(source_combination)
def testDistributeTableSaveAndServe(self, source):
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver)
file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")
with strategy.scope():
model = self.Model(source, file_path)
model_dir = self.get_temp_dir()
tf_save.save(model, model_dir)
loaded_without_strategy = tf_load.load(model_dir)
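    # Loaded outside a strategy scope, the table is restored as a plain table
    # whose handle appears as a regular capture, with no deferred captures.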
loaded_func_captures_without_strategy = (
loaded_without_strategy.use_table.get_concrete_function().graph
.external_captures)
loaded_func_deferred_captures_without_strategy = (
loaded_without_strategy.use_table.get_concrete_function().graph
.deferred_external_captures)
self.assertLen(loaded_func_captures_without_strategy, 2)
self.assertEmpty(loaded_func_deferred_captures_without_strategy)
self.assertAllEqual(
loaded_without_strategy.use_table(
constant_op.constant([0, 1, 3], dtype=dtypes.int64)), [0, 1, -2])

  @combinations.generate(source_combination)
def testDistributeTableSaveAndLoadUnderStrategy(self, source):
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver)
coordinator = coordinator_lib.ClusterCoordinator(strategy)
file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")
with strategy.scope():
model = self.Model(source, file_path)
model_dir = self.get_temp_dir()
tf_save.save(model, model_dir)
with strategy.scope():
loaded = tf_load.load(model_dir)
loaded_func_captures = (
loaded.use_table.get_concrete_function().graph.external_captures)
loaded_func_deferred_captures = (
loaded.use_table.get_concrete_function().graph
.deferred_external_captures)
    # Compared with loading without a strategy, there is one fewer external
    # capture, since the captured table handle has been swapped for a closure
    # in the deferred external captures.
self.assertLen(loaded_func_captures, 1)
self.assertNotEmpty(loaded_func_deferred_captures)
self.assertIsInstance(loaded.table, ps_values.DistributedTable)
self.assertLen([
t for t in loaded.use_table.get_concrete_function().captured_inputs
if t.dtype == dtypes.resource
], 1)
self.verifyWorkerLocalInstance(coordinator, loaded)


if __name__ == "__main__":
v2_compat.enable_v2_behavior()
multi_process_runner.test_main()