blob: 239e9f42f2c82a35326a00b921ccafb9b87c6025 [file] [log] [blame]
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for sequence features in the TPU Embedding mid-level API on TPU."""
from absl.testing import parameterized
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.eager import def_function
from tensorflow.python.framework.tensor_shape import TensorShape
from tensorflow.python.platform import test
from tensorflow.python.tpu.tests import tpu_embedding_base_test
class TPUEmbeddingTest(tpu_embedding_base_test.TPUEmbeddingBaseTest):
  """Tests sequence features (`max_sequence_length`) of the mid-level API."""

  def _distribute_on_host(self, strategy, dataset):
    """Returns an iterator over `dataset` distributed by `strategy`.

    `experimental_fetch_to_device=False` keeps the elements from being
    prefetched into device memory (presumably because `enqueue` expects
    host-side inputs).

    Args:
      strategy: The strategy returned by `_create_strategy_and_mid_level`.
      dataset: The dataset to distribute.

    Returns:
      An iterator over the distributed dataset.
    """
    return iter(
        strategy.experimental_distribute_dataset(
            dataset,
            options=distribute_lib.InputOptions(
                experimental_fetch_to_device=False)))

  def _enqueue_and_dequeue_once(self, strategy, mid_level_api, feature_iter):
    """Enqueues the next element of `feature_iter` and dequeues the result.

    Args:
      strategy: The strategy returned by `_create_strategy_and_mid_level`.
      mid_level_api: The mid-level API object to enqueue into and dequeue
        from.
      feature_iter: Iterator whose next element is enqueued with
        `training=False`.

    Returns:
      The result of running `mid_level_api.dequeue()` under `strategy.run`;
      it is indexed per feature by the callers.
    """

    @def_function.function
    def test_fn():

      def step():
        return mid_level_api.dequeue()

      mid_level_api.enqueue(next(feature_iter), training=False)
      return strategy.run(step)

    return test_fn()

  def _assert_sequence_output_shapes(self, output, strategy, seq_length):
    """Asserts the dequeued shapes on replica 0 for the three features.

    Args:
      output: The result of `_enqueue_and_dequeue_once`.
      strategy: The strategy the output was produced under.
      seq_length: The `max_sequence_length` set on every feature config.
    """
    # NOTE(review): presumably (per-replica batch, seq_length, embedding dim).
    # The leading 2 and the 4/4/2 widths come from the base test's setup;
    # confirm there if that setup changes.
    expected_shapes = [
        (2, seq_length, 4),
        (2, seq_length, 4),
        (2, seq_length, 2),
    ]
    for index, expected_shape in enumerate(expected_shapes):
      self.assertEqual(
          self._get_replica_numpy(output[index], strategy, 0).shape,
          expected_shape)

  @parameterized.parameters([True, False])
  def test_sequence_feature(self, is_sparse):
    """Dequeues sequence activations from sparse or ragged input."""
    seq_length = 3
    # Set the max_sequence_length in every feature config before the
    # mid-level API is created from it.
    for feature in self.feature_config:
      feature.max_sequence_length = seq_length
    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    if is_sparse:
      dataset = self._create_sparse_dataset(strategy)
    else:
      dataset = self._create_ragged_dataset(strategy)
    feature_iter = self._distribute_on_host(strategy, dataset)

    output = self._enqueue_and_dequeue_once(strategy, mid_level_api,
                                            feature_iter)
    self._assert_sequence_output_shapes(output, strategy, seq_length)

  @parameterized.parameters([True, False])
  def test_sequence_feature_with_build(self, is_updated_shape):
    """Dequeues sequence activations after an explicit `build` call.

    The output shapes are expected to be the same whether or not the shapes
    passed to `build` already include the sequence dimension.
    """
    seq_length = 3
    # Set the max_sequence_length in every feature config before the
    # mid-level API is created from it.
    for feature in self.feature_config:
      feature.max_sequence_length = seq_length
    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    dataset = self._create_sparse_dataset(strategy)
    feature_iter = self._distribute_on_host(strategy, dataset)

    if is_updated_shape:
      # Shapes that already include the sequence dimension.
      leading_dims = [self.batch_size, seq_length]
    else:
      # Rank-2 shapes without the sequence dimension.
      leading_dims = [self.batch_size]
    mid_level_api.build([
        TensorShape(leading_dims + [2]),
        TensorShape(leading_dims + [2]),
        TensorShape(leading_dims + [3]),
    ])

    output = self._enqueue_and_dequeue_once(strategy, mid_level_api,
                                            feature_iter)
    self._assert_sequence_output_shapes(output, strategy, seq_length)
if __name__ == '__main__':
  # Enable TF2 behavior before the test runner starts executing any tests.
  v2_compat.enable_v2_behavior()
  test.main()