blob: 1953401a5bcbfad8a4756193c502eb390adb1da6 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Tests for python.tpu.feature_column."""
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.tpu import feature_column as tpu_fc
def _initialized_session():
  """Returns a new `Session` with global variables and tables initialized.

  The caller owns the returned session and must close it; the tests in this
  file do so by using it as a context manager
  (`with _initialized_session(): ...`).

  Returns:
    A `session.Session` on which both the global-variable initializer and
    the lookup-table initializer have been run.

  Raises:
    Whatever running either initializer raises. In that case the session is
    closed before the exception propagates, so it is not leaked.
  """
  sess = session.Session()
  try:
    sess.run(variables_lib.global_variables_initializer())
    sess.run(lookup_ops.tables_initializer())
  except BaseException:
    # The caller never receives the session on this path, so nobody else
    # could close it; release it here and let the original error propagate.
    sess.close()
    raise
  return sess
class EmbeddingColumnTest(test.TestCase):
  """Tests for `tpu_fc.embedding_column`."""

  def _assert_embedding_column_properties(self, categorical_column,
                                          embedding_column,
                                          embedding_dimension, combiner):
    """Asserts the attributes of an embedding column built on key 'aaa'.

    Args:
      categorical_column: The column that was passed to
        `tpu_fc.embedding_column`.
      embedding_column: The column returned by `tpu_fc.embedding_column`.
      embedding_dimension: The `dimension` that was passed to the constructor.
      combiner: The combiner the column is expected to report.
    """
    self.assertIs(categorical_column, embedding_column.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column.dimension)
    self.assertEqual(combiner, embedding_column.combiner)
    # Both the column name and its variable scope are expected to be
    # 'aaa_embedding' for a categorical column keyed 'aaa'.
    self.assertEqual('aaa_embedding', embedding_column.name)
    self.assertEqual('aaa_embedding', embedding_column._var_scope_name)
    self.assertEqual((embedding_dimension,), embedding_column._variable_shape)
    self.assertEqual({'aaa': parsing_ops.VarLenFeature(dtypes.int64)},
                     embedding_column._parse_example_spec)

  def test_defaults(self):
    """Only `dimension` is given; the combiner is expected to be 'mean'."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    embedding_column = tpu_fc.embedding_column(
        categorical_column, dimension=embedding_dimension)
    self._assert_embedding_column_properties(
        categorical_column, embedding_column, embedding_dimension, 'mean')

  def test_denylisted_column(self):
    """A hash-bucket categorical column is rejected with a TypeError."""
    # HashedCategoricalColumn is denylisted and so will raise an exception.
    categorical_column = fc_lib.categorical_column_with_hash_bucket(
        key='aaa', hash_bucket_size=3)
    embedding_dimension = 2
    with self.assertRaises(TypeError):
      tpu_fc.embedding_column(categorical_column, dimension=embedding_dimension)

  def test_custom_column(self):
    """An identity column with 10 buckets is accepted."""
    # Intended to cover a categorical column that is accepted because it
    # inherits from the V2 `CategoricalColumn`, rather than because it is
    # listed in an allowlist.
    # NOTE(review): the column below comes from
    # `categorical_column_with_identity`, the same constructor test_defaults
    # uses, so it is not evident that it lies outside every allowlist. A
    # bespoke `CategoricalColumn` subclass would exercise that path
    # unambiguously. TODO confirm against tpu_fc.embedding_column's checks.
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=10)
    embedding_dimension = 2
    embedding_column = tpu_fc.embedding_column(
        categorical_column, dimension=embedding_dimension)
    self._assert_embedding_column_properties(
        categorical_column, embedding_column, embedding_dimension, 'mean')

  def test_all_constructor_args(self):
    """An explicitly passed combiner is reported by the column."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    # The initializer is only checked for being accepted; it is not asserted.
    embedding_column = tpu_fc.embedding_column(
        categorical_column,
        dimension=embedding_dimension,
        combiner='my_combiner',
        initializer=lambda: 'my_initializer')
    self._assert_embedding_column_properties(
        categorical_column, embedding_column, embedding_dimension,
        'my_combiner')

  @test_util.deprecated_graph_mode_only
  def test_get_dense_tensor(self):
    """Looks up a sparse batch and checks the 'mean'-combined embeddings."""
    # Inputs.
    vocabulary_size = 3
    sparse_input = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1]
        # example 2, ids []
        # example 3, ids [1]
        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
        values=(2, 0, 1, 1),
        dense_shape=(4, 5))

    # Embedding variable.
    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.)  # id 2
    )

    def _initializer(shape, dtype, partition_info):
      # Also verifies the variable is requested with the expected shape and
      # dtype and without partitioning.
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    # Expected lookup result, using combiner='mean'.
    expected_lookups = (
        # example 0, ids [2], embedding = [7, 11]
        (7., 11.),
        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
        (2., 3.5),
        # example 2, ids [], embedding = [0, 0]
        (0., 0.),
        # example 3, ids [1], embedding = [3, 5]
        (3., 5.),
    )

    # Build columns.
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    embedding_column = tpu_fc.embedding_column(
        categorical_column,
        dimension=embedding_dimension,
        initializer=_initializer)

    # Provide sparse input and get dense result.
    embedding_lookup = embedding_column._get_dense_tensor(
        fc._LazyBuilder({
            'aaa': sparse_input
        }))

    # Assert expected embedding variable and lookups. Exactly one global
    # variable must exist, so global_vars[0] is the embedding table.
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertItemsEqual(('embedding_weights:0',),
                          tuple([v.name for v in global_vars]))
    with _initialized_session():
      self.assertAllEqual(embedding_values, global_vars[0])
      self.assertAllEqual(expected_lookups, embedding_lookup)
class SharedEmbeddingColumnTest(test.TestCase):
  """Tests for `tpu_fc.shared_embedding_columns`."""

  @test_util.deprecated_graph_mode_only
  def test_defaults(self):
    """Columns built with only a dimension report the default settings."""
    column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    column_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=3)
    dim = 2
    # The columns are passed (and unpacked) in b, a order, yet the expected
    # collection / scope name below is 'aaa_bbb_shared_embedding'.
    shared_b, shared_a = tpu_fc.shared_embedding_columns(
        [column_b, column_a], dimension=dim)
    expectations = (
        (column_a, shared_a, 'aaa_shared_embedding', 'aaa'),
        (column_b, shared_b, 'bbb_shared_embedding', 'bbb'),
    )
    for categorical, shared, name, key in expectations:
      self.assertIs(categorical, shared.categorical_column)
      self.assertEqual(dim, shared.dimension)
      self.assertEqual('mean', shared.combiner)
      self.assertIsNotNone(shared.initializer)
      self.assertEqual('aaa_bbb_shared_embedding',
                       shared.shared_embedding_collection_name)
      self.assertEqual(name, shared.name)
      self.assertEqual('aaa_bbb_shared_embedding', shared._var_scope_name)
      self.assertEqual((dim,), shared._variable_shape)
      self.assertEqual({key: parsing_ops.VarLenFeature(dtypes.int64)},
                       shared._parse_example_spec)

  @test_util.deprecated_graph_mode_only
  def test_all_constructor_args(self):
    """Explicit arguments, including the collection name, are honored."""
    column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    column_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=3)
    dim = 2
    shared_a, shared_b = tpu_fc.shared_embedding_columns(
        [column_a, column_b],
        dimension=dim,
        combiner='my_combiner',
        initializer=lambda: 'my_initializer',
        shared_embedding_collection_name='var_scope_name')
    expectations = (
        (column_a, shared_a, 'aaa_shared_embedding', 'aaa'),
        (column_b, shared_b, 'bbb_shared_embedding', 'bbb'),
    )
    for categorical, shared, name, key in expectations:
      self.assertIs(categorical, shared.categorical_column)
      self.assertEqual(dim, shared.dimension)
      self.assertEqual('my_combiner', shared.combiner)
      self.assertEqual('my_initializer', shared.initializer())
      # The explicit collection name is also expected as the scope name.
      self.assertEqual('var_scope_name',
                       shared.shared_embedding_collection_name)
      self.assertEqual(name, shared.name)
      self.assertEqual('var_scope_name', shared._var_scope_name)
      self.assertEqual((dim,), shared._variable_shape)
      self.assertEqual({key: parsing_ops.VarLenFeature(dtypes.int64)},
                       shared._parse_example_spec)

  @test_util.deprecated_graph_mode_only
  def test_get_dense_tensor(self):
    """Both columns look up their own ids in one shared embedding variable."""
    num_ids = 3
    # Dense id matrices, one row per example; -1 entries are ignored.
    ids_a = np.array([
        [2, -1, -1],  # example 0, ids [2]
        [0, 1, -1],  # example 1, ids [0, 1]
    ])
    ids_b = np.array([
        [0, -1, -1],  # example 0, ids [0]
        [-1, -1, -1],  # example 1, ids []
    ])
    features = {'aaa': ids_a, 'bbb': ids_b}

    # Contents of the embedding variable, one row per id.
    dim = 2
    table = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.),  # id 2
    )

    def _initializer(shape, dtype, partition_info):
      self.assertAllEqual((num_ids, dim), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return table

    # Expected lookups with the default 'mean' combiner, one row per example.
    want_a = (
        (7., 11.),  # example 0: ids [2] -> [7, 11]
        (2., 3.5),  # example 1: ids [0, 1] -> mean([1, 2], [3, 5]) = [2, 3.5]
    )
    want_b = (
        (1., 2.),  # example 0: ids [0] -> [1, 2]
        (0., 0.),  # example 1: ids [] -> [0, 0]
    )

    column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=num_ids)
    column_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=num_ids)
    shared_a, shared_b = tpu_fc.shared_embedding_columns(
        [column_a, column_b], dimension=dim, initializer=_initializer)

    lookup_a = shared_a._get_dense_tensor(fc._LazyBuilder(features))
    lookup_b = shared_b._get_dense_tensor(fc._LazyBuilder(features))

    # A single global variable must exist and be shared by both columns.
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertItemsEqual(('embedding_weights:0',),
                          tuple(v.name for v in global_vars))
    with _initialized_session():
      self.assertAllEqual(table, global_vars[0])
      self.assertAllEqual(want_a, lookup_a)
      self.assertAllEqual(want_b, lookup_b)
if __name__ == '__main__':
  # Run this module's test cases through TensorFlow's test runner
  # (`tensorflow.python.platform.test`).
  test.main()