# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPU Embeddings mid level API utils on TPU."""
from absl.testing import parameterized
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.compat import v2_compat
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v2_utils
class TPUEmbeddingOptimizerTest(parameterized.TestCase, test.TestCase):
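
  # Gradient clipping relies on gradient accumulation, so constructing an
  # optimizer with use_gradient_accumulation=False and any clipvalue
  # should raise a ValueError mentioning 'accumulation'.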
  @parameterized.parameters(tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_grad_clip_with_accumulation_off(self, optimizer):
    with self.assertRaisesRegex(ValueError, 'accumulation'):
      optimizer(use_gradient_accumulation=False, clipvalue=0.)
    with self.assertRaisesRegex(ValueError, 'accumulation'):
      optimizer(use_gradient_accumulation=False, clipvalue=(None, 1.))
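
  # A (min, max) tuple for clipvalue should map directly to the
  # clip_gradient_min and clip_gradient_max attributes.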
  @parameterized.parameters(tpu_embedding_v2_utils.SGD,
                            tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_grad_clip_with_tuple(self, optimizer):
    opt = optimizer(clipvalue=(-1., 1.))
    self.assertEqual(-1., opt.clip_gradient_min)
    self.assertEqual(1., opt.clip_gradient_max)
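
  # A scalar clipvalue v should expand to the symmetric range (-v, v).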
  @parameterized.parameters(tpu_embedding_v2_utils.SGD,
                            tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_grad_clip_with_single_value(self, optimizer):
    opt = optimizer(clipvalue=1.)
    self.assertEqual(-1., opt.clip_gradient_min)
    self.assertEqual(1., opt.clip_gradient_max)
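
  # A None bound in the clipvalue tuple should leave that side of the
  # range unset.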
  @parameterized.parameters(tpu_embedding_v2_utils.SGD,
                            tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_grad_clip_with_tuple_and_none(self, optimizer):
    opt = optimizer(clipvalue=(None, 1))
    self.assertIsNone(opt.clip_gradient_min)
    self.assertEqual(1., opt.clip_gradient_max)
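
  # Optimizers built with identical hyperparameters should compare equal
  # and hash identically; differing hyperparameters should break both.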
  @parameterized.parameters(tpu_embedding_v2_utils.SGD,
                            tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_equal_and_hash_function(self, optimizer):
    opt1 = optimizer(0.1)
    opt2 = optimizer(0.1)
    opt3 = optimizer(0.2)
    self.assertEqual(opt1, opt2)
    self.assertEqual(hash(opt1), hash(opt2))
    self.assertNotEqual(opt1, opt3)
    self.assertNotEqual(hash(opt1), hash(opt3))


class ConfigTest(test.TestCase):
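
  # These tests pin the exact repr() output of the config classes,
  # including fields left at their default values.
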
  def test_table_config_repr(self):
    table = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=2, dim=4,
        combiner='sum', name='table')
    self.assertEqual(
        repr(table),
        'TableConfig(vocabulary_size=2, dim=4, initializer=None, '
        'optimizer=None, combiner=\'sum\', name=\'table\', '
        'quantization_config=None)')

  def test_feature_config_repr(self):
    table = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=2, dim=4, initializer=None,
        combiner='sum', name='table')
    feature_config = tpu_embedding_v2_utils.FeatureConfig(
        table=table, output_shape=[16, 4], name='feature')
    self.assertEqual(
        repr(feature_config),
        'FeatureConfig(table=TableConfig(vocabulary_size=2, dim=4, '
        'initializer=None, optimizer=None, combiner=\'sum\', '
        'name=\'table\', quantization_config=None), max_sequence_length=0, '
        'validate_weights_and_indices=True, output_shape=TensorShape([16, 4]), '
        'name=\'feature\')')
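
  # QuantizationConfig should reject an invalid num_buckets (zero here).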
  def test_quantization_config_num_buckets(self):
    with self.assertRaisesRegex(ValueError, 'num_buckets'):
      tpu_embedding_v2_utils.QuantizationConfig(0, -1, 1)

  def test_quantization_config_repr(self):
    quantization_config = tpu_embedding_v2_utils.QuantizationConfig(
        num_buckets=10, lower=-1.0, upper=1.0)
    self.assertEqual(
        repr(quantization_config),
        'QuantizationConfig(num_buckets=10, lower=-1.0, upper=1.0)')


class TPUEmbeddingConfigurationTest(test.TestCase):
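
  # log_tpu_embedding_configuration is expected to emit the full config
  # across multiple log statements so that no single line is long enough
  # to be truncated by the logger.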
  def test_no_truncate(self):
    # Maximum string length that can be logged without truncation,
    # determined experimentally.
    truncate_length = 14937
    config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
    for i in range(500):
      td = config.table_descriptor.add()
      td.name = 'table_{}'.format(i)
      td.vocabulary_size = i
    config.num_hosts = 2
    config.num_tensor_cores = 4
    config.batch_size_per_tensor_core = 128

    self.assertGreater(
        len(str(config)), truncate_length,
        'Test sanity check: the generated config should exceed the '
        'truncation length.')

    with self.assertLogs() as logs:
      tpu_embedding_v2_utils.log_tpu_embedding_configuration(config)

    self.assertIn('table_499', ''.join(logs.output))
    for line in logs.output:
      self.assertLess(
          len(line), truncate_length,
          'Each logged line should be shorter than the truncation length.')


class TPUEmbeddingUtilityFunctionTest(test.TestCase):
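
  # _sort_device_spec_strings should return the device specs in sorted
  # order regardless of the order in which tasks were supplied.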
  def test_sort_device_spec_strings(self):
    device_spec_strings = []
    for task in [2, 3, 0, 1]:  # Intentionally permuted
      for device in range(8):
        device_spec_strings.append(
            f'/job:trainer/replica:0/task:{task}/device:TPU:{device}'
        )
    sorted_specs = tpu_embedding_v2_utils._sort_device_spec_strings(
        device_spec_strings
    )
    self.assertEqual(sorted_specs, sorted(device_spec_strings))


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()