blob: 6a63666e365da5d0f42b656f85a2a74c95b1d522 [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that sparse tensors work with GPU, such as placement of int and string.
Tests using sparse tensors with a distributed dataset. Since GPUs do not
support strings, sparse tensors containing strings should always be placed on
the CPU.
"""
from absl.testing import parameterized
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
def sparse_int64():
  """Builds an 8x4 `SparseTensor` holding the int64 values 1 through 8.

  Row `r` stores its single value, `r + 1`, in column `r % 4`.

  Returns:
    A `SparseTensor` with eight int64 entries and dense shape `[8, 4]`.
  """
  coords = [[row, row % 4] for row in range(8)]
  payload = constant_op.constant(list(range(1, 9)), dtype=dtypes.int64)
  return sparse_tensor.SparseTensor(
      indices=coords, values=payload, dense_shape=[8, 4])
def sparse_str():
  """Builds an 8x4 `SparseTensor` holding the strings '1' through '8'.

  Row `r` stores its single value, `str(r + 1)`, in column `r % 4`.

  Returns:
    A `SparseTensor` with eight string entries and dense shape `[8, 4]`.
  """
  coords = [[row, row % 4] for row in range(8)]
  payload = constant_op.constant([str(number) for number in range(1, 9)])
  return sparse_tensor.SparseTensor(
      indices=coords, values=payload, dense_shape=[8, 4])
class FactoryOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Exercises sparse tensors passed through a distributed dataset."""

  @parameterized.parameters((sparse_int64,), (sparse_str,))
  @test_util.run_gpu_only
  def testSparseWithDistributedDataset(self, sparse_factory):
    """Compares row 0 of the input with row 0 of the first distributed batch.

    Args:
      sparse_factory: Zero-argument callable returning the `SparseTensor`
        to feed into the dataset.
    """

    @def_function.function
    def distributed_dataset_producer(source):
      """Returns the first local result of the first batch of `source`."""
      strategy = mirrored_strategy.MirroredStrategy(['GPU:0', 'GPU:1'])
      batched = dataset_ops.Dataset.from_tensor_slices(source).batch(2)
      iterator = iter(strategy.experimental_distribute_dataset(batched))
      first_batch = next(iterator)
      first_local = strategy.experimental_local_results(first_batch)[0]
      # Consume the remaining elements so the iterator reaches its end.
      for _ in iterator:
        pass
      return first_local

    expected = sparse_factory()
    actual = distributed_dataset_producer(expected)
    expected_row = self.evaluate(
        sparse_ops.sparse_tensor_to_dense(expected)[0])
    actual_row = self.evaluate(
        sparse_ops.sparse_tensor_to_dense(actual)[0])
    self.assertAllEqual(expected_row, actual_row)
# Run this module's test cases via `test.main()` when executed as a script.
if __name__ == '__main__':
  test.main()