# blob: 26742b96b5e6cd5487bd3d10a2347e23e93e8f0b (source-viewer header captured with the file)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPU Embeddings mid level API on TPU."""
from absl.testing import parameterized
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu.tests import tpu_embedding_base_test
class TPUEmbeddingTest(tpu_embedding_base_test.TPUEmbeddingBaseTest):
  """Invalid-input tests for the TPU embedding mid level API.

  Every test in this class drives the API with a malformed configuration or
  input and asserts that a specific exception is raised.
  """

  def _iterator_without_device_fetch(self, strategy, dataset):
    """Distributes `dataset` and returns an iterator over the result.

    Args:
      strategy: The strategy whose `experimental_distribute_dataset` is used.
      dataset: The dataset to distribute.

    Returns:
      An iterator over the dataset distributed with
      `InputOptions(experimental_fetch_to_device=False)`.
    """
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=False)
    distributed = strategy.experimental_distribute_dataset(
        dataset, options=input_options)
    return iter(distributed)

  def test_tables_with_same_name(self):
    """Two distinct table configs sharing one name raise a ValueError."""

    def table_named_table():
      # Every call yields a fresh TableConfig object carrying the same name.
      return tpu_embedding_v2_utils.TableConfig(
          name='table',
          vocabulary_size=4,
          dim=2,
          initializer=self.initializer)

    expect_duplicate_error = self.assertRaisesRegex(
        ValueError, 'Multiple tables with name table found.')
    with expect_duplicate_error, self._get_strategy().scope():
      watched = tpu_embedding_v2_utils.FeatureConfig(
          table=table_named_table(), name='watched')
      favorited = tpu_embedding_v2_utils.FeatureConfig(
          table=table_named_table(), name='favorited')
      optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      tpu_embedding_v2.TPUEmbedding((watched, favorited), optimizer)

  def test_pass_non_tensor_to_apply_gradients(self):
    """apply_gradients rejects non-tensors and mismatched structures."""
    self.skip_if_oss()
    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    # Nothing is actually executed below, so the batch size is arbitrary.
    mid_level_api.build(64)

    @def_function.function
    def apply_tuple_of_ints():
      # Right container type (tuple), but the leaves are plain ints.
      mid_level_api.apply_gradients((1, 2, 3))

    @def_function.function
    def apply_list_of_ints():
      # feature_config is a tuple of 3 configs, so a list is the wrong
      # structure.
      mid_level_api.apply_gradients([1, 2, 3])

    with self.assertRaisesRegex(ValueError, 'found non-tensor type'):
      strategy.run(apply_tuple_of_ints)

    with self.assertRaisesRegex(
        TypeError, 'The two structures don\'t have the same nested structure.'):
      strategy.run(apply_list_of_ints)

  def test_enqueue_weight_for_dense_tensor(self):
    """Supplying weights alongside dense features raises a ValueError."""
    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    weighted_dense = self._create_dense_dataset(strategy, include_weights=True)
    batches = self._iterator_without_device_fetch(strategy, weighted_dense)

    @def_function.function
    def enqueue_with_weights():
      dense_features, dense_weights = next(batches)
      mid_level_api.enqueue(
          dense_features, weights=dense_weights, training=False)

      def dequeue_step():
        return mid_level_api.dequeue()

      return strategy.run(dequeue_step)

    with self.assertRaisesRegex(ValueError, 'Weight specified for dense input'):
      enqueue_with_weights()

  def test_enqueue_wrong_weight_type_for_sparse_and_ragged_tensor(self):
    """Weights taken from the other dataset kind raise a ValueError."""
    self.skip_if_oss()
    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    sparse_dataset = self._create_sparse_dataset(strategy, include_weights=True)
    ragged_dataset = self._create_ragged_dataset(strategy, include_weights=True)
    sparse_batches = self._iterator_without_device_fetch(
        strategy, sparse_dataset)
    ragged_batches = self._iterator_without_device_fetch(
        strategy, ragged_dataset)

    @def_function.function
    def sparse_features_with_ragged_weights():
      sparse_features, _ = next(sparse_batches)
      _, ragged_weights = next(ragged_batches)
      mid_level_api.enqueue(
          sparse_features, weights=ragged_weights, training=False)

      def dequeue_step():
        return mid_level_api.dequeue()

      return strategy.run(dequeue_step)

    with self.assertRaisesRegex(
        ValueError, 'which does not match type input which is SparseTensor.'):
      sparse_features_with_ragged_weights()

    @def_function.function
    def ragged_features_with_sparse_weights():
      _, sparse_weights = next(sparse_batches)
      ragged_features, _ = next(ragged_batches)
      mid_level_api.enqueue(
          ragged_features, weights=sparse_weights, training=False)

      def dequeue_step():
        return mid_level_api.dequeue()

      return strategy.run(dequeue_step)

    with self.assertRaisesRegex(
        ValueError, 'which does not match type input which is RaggedTensor.'):
      ragged_features_with_sparse_weights()

  def test_enqueue_incorrect_structure_for_features_and_weights(self):
    """Features or weights with the wrong nesting raise a ValueError."""
    self.skip_if_oss()
    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    sparse_dataset = self._create_sparse_dataset(strategy, include_weights=True)
    sparse_batches = self._iterator_without_device_fetch(
        strategy, sparse_dataset)

    @def_function.function
    def enqueue_rewrapped_features():
      # Keep only the first element of the batch, re-wrapped in a 1-tuple.
      rewrapped = (next(sparse_batches)[0],)
      mid_level_api.enqueue(rewrapped, training=False)

      def dequeue_step():
        return mid_level_api.dequeue()

      return strategy.run(dequeue_step)

    # nest.assert_same_structure is what raises here.
    with self.assertRaises(ValueError):
      enqueue_rewrapped_features()

    @def_function.function
    def enqueue_rewrapped_weights():
      sparse_features, sparse_weights = next(sparse_batches)
      # Keep only the first weight, re-wrapped in a 1-tuple.
      rewrapped = (sparse_weights[0],)
      mid_level_api.enqueue(sparse_features, weights=rewrapped, training=False)

      def dequeue_step():
        return mid_level_api.dequeue()

      return strategy.run(dequeue_step)

    # nest.assert_same_structure is what raises here.
    with self.assertRaises(ValueError):
      enqueue_rewrapped_weights()

  def test_enqueue_cpu_tensor(self):
    """Enqueueing batches distributed with default options raises."""
    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    dense_dataset = self._create_dense_dataset(strategy)
    # No InputOptions are passed here, unlike the other enqueue tests.
    batches = iter(strategy.experimental_distribute_dataset(dense_dataset))

    @def_function.function
    def enqueue_outside_run():
      dense_features = next(batches)
      mid_level_api.enqueue(dense_features, training=False)

      def dequeue_step():
        return mid_level_api.dequeue()

      return strategy.run(dequeue_step)

    with self.assertRaisesRegex(ValueError, 'which is on a TPU input device'):
      enqueue_outside_run()

  @parameterized.parameters([True, False])
  def test_enqueue_cpu_tensor_with_outside_compilation(self, use_mlir):
    """Enqueueing inside strategy.run raises, with or without MLIR bridge.

    Args:
      use_mlir: When true, `config.enable_mlir_bridge()` is called first.
    """
    if use_mlir:
      config.enable_mlir_bridge()

    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
    dense_dataset = self._create_dense_dataset(strategy)
    batches = iter(strategy.experimental_distribute_dataset(dense_dataset))

    @def_function.function
    def enqueue_inside_run():

      def enqueue_and_dequeue(features):
        mid_level_api.enqueue(features, training=False)
        return mid_level_api.dequeue()

      return strategy.run(enqueue_and_dequeue, args=(next(batches),))

    with self.assertRaisesRegex(ValueError, 'which is on a TPU input device'):
      enqueue_inside_run()
if __name__ == '__main__':
  # Script entry point: call v2_compat.enable_v2_behavior() first, then hand
  # control to the test runner.
  v2_compat.enable_v2_behavior()
  test.main()