# blob: 5240229585824f4b912833063fa672eddf96715b [file] [log] [blame]
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPU Embeddings mid level API on TPU."""
from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.eager import def_function
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu.tests import tpu_embedding_v2_correctness_base_test
from tensorflow.python.util import nest
class TPUEmbeddingCorrectnessTest(
    tpu_embedding_v2_correctness_base_test.TPUEmbeddingCorrectnessBaseTest):
  """Checks sequence feature lookups and updates of the TPU embedding API."""

  @parameterized.parameters([True, False])
  def test_sequence_embeddings(self, sparse):
    """Looks up sequence features, applies one SGD step and checks the result.

    The same batch is looked up twice: once together with a gradient of all
    ones, and once more afterwards. The second lookup must equal the first one
    minus the expected per-row update on every non-padding position.

    Args:
      sparse: If true the input is built by `_create_sparse_dataset`, otherwise
        by `_create_ragged_dataset`.
    """
    # Two sequence features share the video table, one uses the user table.
    features = (
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_video, name='watched', max_sequence_length=2),
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_video, name='favorited', max_sequence_length=2),
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_user, name='friends', max_sequence_length=3))
    sgd = tpu_embedding_v2_utils.SGD(learning_rate=0.1)

    strategy = self._get_strategy()
    num_replicas = strategy.num_replicas_in_sync

    with strategy.scope():
      embedding = tpu_embedding_v2.TPUEmbedding(
          feature_config=features, optimizer=sgd)

    # Build explicitly: 'next' is called outside of the tf.function, which
    # yields data whose sparse tensor shape is itself a tensor, so the shape
    # cannot be determined at tracing time.
    embedding.build(self.batch_size)

    make_dataset = (
        self._create_sparse_dataset if sparse else self._create_ragged_dataset)
    dataset = make_dataset(strategy)

    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=False)
    distributed_dataset = strategy.experimental_distribute_dataset(
        dataset, options=input_options)
    batch = next(iter(distributed_dataset))

    @def_function.function
    def embedding_and_set_gradients(data):
      """Enqueues `data`, then dequeues and applies all-ones gradients."""

      def tpu_fn():
        looked_up = embedding.dequeue()
        unit_gradients = nest.map_structure(array_ops.ones_like, looked_up)
        embedding.apply_gradients(unit_gradients)
        return looked_up

      embedding.enqueue(data)
      return strategy.run(tpu_fn)

    @def_function.function
    def embedding_only(data):
      """Enqueues `data` and returns the dequeued activations unchanged."""

      def tpu_fn():
        return embedding.dequeue()

      embedding.enqueue(data)
      return strategy.run(tpu_fn)

    # Only check core 0.
    initial_activations = self._get_replica_numpy(
        embedding_and_set_gradients(batch), strategy, 0)
    updated_activations = self._get_replica_numpy(
        embedding_only(batch), strategy, 0)

    # Expected updates.
    #
    # Videos table: rows 0 and 1 are each looked up 3*num_replicas times, as
    # they occur 3 times per replica across features 0 and 1 (both of which
    # use the videos table):
    #   Feature 0 has ids [0, 0, 1], [0, 1, 1], ... repeated over num_replicas
    #   Feature 1 has ids [0, 1, 1], [0, 0, 1], ... repeated over num_replicas
    # So both rows 0 and 1 get a -0.1*3*num_replicas update.
    #
    # Users table: each row is looked up twice:
    #   Feature 2 has ids [3, 0, 1, 2], .. repeated over num_replicas
    # So we get a -0.1*num_replicas update to the third feature.
    #
    # In general, after the update a lookup of features 0 and 1 yields values
    # that are 0.3*num_replicas lower per entry, and a lookup of feature 2
    # yields values that are 0.1*num_replicas lower.
    #
    # The one issue is that these lookups contain padding values. Core 0 gets
    # the first 2 elements of the 4 element batch:
    #   Feature 0: indices [[0, 0], [1, 0], [1, 1]], max sequence length 2,
    #     so [0, 1] will be 0s.
    #   Feature 1: indices [[0, 0], [0, 1], [1, 0]], max sequence length 2,
    #     so [1, 1] will be 0s.
    #   Feature 2: indices [[0, 0], [1, 0], [1, 1], [1, 2]], max sequence
    #     length 3, so [0, 1], [0, 2] will be 0s.
    # The masks below encode this so that the updates above are applied to the
    # non-padding rows only.
    non_padding_masks = (
        np.array([[[1], [0]], [[1], [1]]]),
        np.array([[[1], [1]], [[1], [0]]]),
        np.array([[[1], [0], [0]], [[1], [1], [1]]]))
    expected_deltas = (
        0.3 * num_replicas, 0.3 * num_replicas, 0.1 * num_replicas)

    expected = tuple(
        initial - delta * mask
        for initial, delta, mask in zip(
            initial_activations, expected_deltas, non_padding_masks))
    self.assertAllClose(expected, updated_activations)
if __name__ == '__main__':
  # Switch on TF2 behavior before handing control to the test runner.
  v2_compat.enable_v2_behavior()
  test.main()