# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPU Embeddings mid level API on TPU."""
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework.tensor_shape import TensorShape
from tensorflow.python.platform import test
from tensorflow.python.tpu.tests import tpu_embedding_base_test


class TPUEmbeddingCorrectnessBaseTest(
    tpu_embedding_base_test.TPUEmbeddingBaseTest):
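  """Base class for TPU embedding mid level API correctness tests."""
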
def _test_embedding(self, optimizer_name, training, sparse,
is_high_dimensional):
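    """Builds the mid level API, runs one step, and verifies the results.

    Args:
      optimizer_name: Embedding optimizer to test, passed to
        `_create_strategy_and_mid_level`.
      training: Whether to compute and apply gradients or only run a
        forward pass.
      sparse: Whether inputs are sparse tensors (otherwise ragged).
      is_high_dimensional: Whether inputs carry an extra batch dimension,
        which requires building the API with explicit shapes.
    """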
strategy, mid_level_api, optimizer = (
self._create_strategy_and_mid_level(optimizer_name))
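    # Pick the global-batch dataset that matches the requested input format:
    # sparse or ragged, in the rank-2 or high dimensional (rank-3) layout.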
if sparse:
if is_high_dimensional:
dataset = self._create_high_dimensional_sparse_dataset(strategy)
else:
dataset = self._create_sparse_dataset(strategy)
else:
if is_high_dimensional:
        dataset = self._create_high_dimensional_ragged_dataset(strategy)
else:
dataset = self._create_ragged_dataset(strategy)
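    # High dimensional inputs need an explicit build with known per-feature
    # shapes: the sparse inputs here have a fixed number of ids per sample,
    # while ragged rows vary in length, hence the None in the last dimension.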
if is_high_dimensional:
if sparse:
mid_level_api.build([
TensorShape([self.batch_size, self.data_batch_size, 2]),
TensorShape([self.batch_size, self.data_batch_size, 2]),
TensorShape([self.batch_size, self.data_batch_size, 3]),
])
else:
mid_level_api.build([
TensorShape([self.batch_size, self.data_batch_size, None]),
TensorShape([self.batch_size, self.data_batch_size, None]),
TensorShape([self.batch_size, self.data_batch_size, None]),
])
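    # Keep batches on the host: the mid level API enqueues input from the CPU,
    # so the distributed dataset must not prefetch batches to the TPU device.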
dist = strategy.experimental_distribute_dataset(
dataset,
options=distribute_lib.InputOptions(experimental_fetch_to_device=False))
dist_iter = iter(dist)
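
    # A single traced step: enqueue the next global batch on the host, then
    # run the replica computation that dequeues the activations.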
@def_function.function
def test_fn():

      def step():
"""Create and run computation that returns the embedding activations."""
if not training:
activations = mid_level_api.dequeue()
total_loss = self._get_total_loss_tensor(activations)
ret_val = [total_loss] + list(activations)
return ret_val
else:
with backprop.GradientTape() as tape:
activations = mid_level_api.dequeue()
tape.watch(activations)
total_loss = self._get_total_loss_tensor(activations)
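            # Scale the loss by the replica count so the gradients applied
            # across all replicas sum to the gradient of the global-batch loss.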
loss_per_replica = total_loss / strategy.num_replicas_in_sync
gradients = tape.gradient(loss_per_replica, activations)
mid_level_api.apply_gradients(gradients)
ret_val = [total_loss] + list(activations)
return ret_val
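
      # Enqueue runs on the host outside the replica function; `step` then
      # dequeues the matching activations on the TPU cores.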
mid_level_api.enqueue(next(dist_iter), training=training)
result = strategy.run(step)
return result

    # Run model.
    shard_out_val = test_fn()
# Retrieve TPU weights to CPU.
mid_level_api._retrieve_variables()
    # Recreate the global batch as sparse tensors for the CPU-side check.
if is_high_dimensional:
input_data = next(
iter(self._create_high_dimensional_sparse_dataset(strategy)))
else:
input_data = next(iter(self._create_sparse_dataset(strategy)))
# Check results.
self._check_results(strategy, shard_out_val, training, input_data,
mid_level_api._variables, optimizer,
is_high_dimensional)
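
  # A concrete subclass would typically drive this helper across a parameter
  # matrix; a hypothetical example (optimizer names follow whatever
  # _create_strategy_and_mid_level accepts):
  #
  #   def test_sgd_training_sparse(self):
  #     self._test_embedding('sgd', training=True, sparse=True,
  #                          is_high_dimensional=False)

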
if __name__ == '__main__':
v2_compat.enable_v2_behavior()
test.main()