| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # =================================================================== |
| """Tests for python.tpu.feature_column.""" |
| |
| import numpy as np |
| |
| from tensorflow.python.client import session |
| from tensorflow.python.feature_column import feature_column as fc |
| from tensorflow.python.feature_column import feature_column_lib as fc_lib |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import lookup_ops |
| from tensorflow.python.ops import parsing_ops |
| from tensorflow.python.ops import variables as variables_lib |
| from tensorflow.python.platform import test |
| from tensorflow.python.tpu import feature_column as tpu_fc |
| |
| |
def _initialized_session():
  """Creates a session and runs the variable and table initializers in it.

  Returns:
    The new `session.Session`, after the global-variables initializer and
    then the lookup-tables initializer have been run in it.
  """
  new_session = session.Session()
  # Order matters for trace equivalence: variables first, then tables. Each
  # initializer op is only built right before it is run.
  for make_init_op in (variables_lib.global_variables_initializer,
                       lookup_ops.tables_initializer):
    new_session.run(make_init_op())
  return new_session
| |
| |
class EmbeddingColumnTest(test.TestCase):
  """Tests for columns created by `tpu_fc.embedding_column`."""

  def test_defaults(self):
    """Checks column attributes when only the required args are passed."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    embedding_column = tpu_fc.embedding_column(
        categorical_column, dimension=embedding_dimension)
    self.assertIs(categorical_column, embedding_column.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column.dimension)
    self.assertEqual('mean', embedding_column.combiner)
    self.assertEqual('aaa_embedding', embedding_column.name)
    self.assertEqual('aaa_embedding', embedding_column._var_scope_name)
    self.assertEqual((embedding_dimension,), embedding_column._variable_shape)
    self.assertEqual({
        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
    }, embedding_column._parse_example_spec)

  def test_denylisted_column(self):
    """Expects `TypeError` when wrapping a hash-bucket categorical column."""
    # HashedCategoricalColumn is denylisted and so will raise an exception.
    categorical_column = fc_lib.categorical_column_with_hash_bucket(
        key='aaa', hash_bucket_size=3)
    embedding_dimension = 2
    with self.assertRaises(TypeError):
      tpu_fc.embedding_column(categorical_column, dimension=embedding_dimension)

  def test_custom_column(self):
    """Checks column attributes for an identity column with 10 buckets."""
    # This column is not in any allowlist but should succeed because
    # it inherits from V2 CategoricalColumn.
    # NOTE(review): the column built here comes from the same
    # `categorical_column_with_identity` factory as in `test_defaults` (only
    # `num_buckets` differs) -- confirm this still exercises the
    # "not in any allowlist" path described above.
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=10)
    embedding_dimension = 2
    embedding_column = tpu_fc.embedding_column(
        categorical_column, dimension=embedding_dimension)
    self.assertIs(categorical_column, embedding_column.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column.dimension)
    self.assertEqual('mean', embedding_column.combiner)
    self.assertEqual('aaa_embedding', embedding_column.name)
    self.assertEqual('aaa_embedding', embedding_column._var_scope_name)
    self.assertEqual((embedding_dimension,), embedding_column._variable_shape)
    self.assertEqual({'aaa': parsing_ops.VarLenFeature(dtypes.int64)},
                     embedding_column._parse_example_spec)

  def test_all_constructor_args(self):
    """Checks column attributes when combiner and initializer are passed."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    embedding_column = tpu_fc.embedding_column(
        categorical_column,
        dimension=embedding_dimension,
        combiner='my_combiner',
        initializer=lambda: 'my_initializer')
    self.assertIs(categorical_column, embedding_column.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column.dimension)
    self.assertEqual('my_combiner', embedding_column.combiner)
    self.assertEqual('aaa_embedding', embedding_column.name)
    self.assertEqual('aaa_embedding', embedding_column._var_scope_name)
    self.assertEqual((embedding_dimension,), embedding_column._variable_shape)
    self.assertEqual({
        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
    }, embedding_column._parse_example_spec)

  @test_util.deprecated_graph_mode_only
  def test_get_dense_tensor(self):
    """Checks the embedding variable and the looked-up values for sparse ids."""
    # Inputs.
    vocabulary_size = 3
    sparse_input = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1]
        # example 2, ids []
        # example 3, ids [1]
        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
        values=(2, 0, 1, 1),
        dense_shape=(4, 5))

    # Embedding variable.
    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.)  # id 2
    )

    def _initializer(shape, dtype, partition_info):
      # Doubles as a check on the arguments the column passes to its
      # initializer, then seeds the variable with the known values above.
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    # Expected lookup result, using combiner='mean'.
    expected_lookups = (
        # example 0, ids [2], embedding = [7, 11]
        (7., 11.),
        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
        (2., 3.5),
        # example 2, ids [], embedding = [0, 0]
        (0., 0.),
        # example 3, ids [1], embedding = [3, 5]
        (3., 5.),
    )

    # Build columns.
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    embedding_column = tpu_fc.embedding_column(
        categorical_column,
        dimension=embedding_dimension,
        initializer=_initializer)

    # Provide sparse input and get dense result.
    embedding_lookup = embedding_column._get_dense_tensor(
        fc._LazyBuilder({
            'aaa': sparse_input
        }))

    # Assert expected embedding variable and lookups.
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    # `assertCountEqual` is the Python 3 name of the order-insensitive
    # comparison formerly spelled `assertItemsEqual`.
    self.assertCountEqual(('embedding_weights:0',),
                          tuple(v.name for v in global_vars))
    with _initialized_session():
      self.assertAllEqual(embedding_values, global_vars[0])
      self.assertAllEqual(expected_lookups, embedding_lookup)
| |
| |
class SharedEmbeddingColumnTest(test.TestCase):
  """Tests for columns created by `tpu_fc.shared_embedding_columns`."""

  @test_util.deprecated_graph_mode_only
  def test_defaults(self):
    """Checks attributes of both columns when only required args are passed."""
    categorical_column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    categorical_column_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=3)
    embedding_dimension = 2
    # The categorical columns are passed in (b, a) order; the shared
    # collection / scope name asserted below is nonetheless
    # 'aaa_bbb_shared_embedding'.
    embedding_column_b, embedding_column_a = tpu_fc.shared_embedding_columns(
        [categorical_column_b, categorical_column_a],
        dimension=embedding_dimension)
    self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
    self.assertIs(categorical_column_b, embedding_column_b.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column_a.dimension)
    self.assertEqual(embedding_dimension, embedding_column_b.dimension)
    self.assertEqual('mean', embedding_column_a.combiner)
    self.assertEqual('mean', embedding_column_b.combiner)
    self.assertIsNotNone(embedding_column_a.initializer)
    self.assertIsNotNone(embedding_column_b.initializer)
    self.assertEqual('aaa_bbb_shared_embedding',
                     embedding_column_a.shared_embedding_collection_name)
    self.assertEqual('aaa_bbb_shared_embedding',
                     embedding_column_b.shared_embedding_collection_name)
    self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
    self.assertEqual('bbb_shared_embedding', embedding_column_b.name)
    self.assertEqual('aaa_bbb_shared_embedding',
                     embedding_column_a._var_scope_name)
    self.assertEqual('aaa_bbb_shared_embedding',
                     embedding_column_b._var_scope_name)
    self.assertEqual((embedding_dimension,), embedding_column_a._variable_shape)
    self.assertEqual((embedding_dimension,), embedding_column_b._variable_shape)
    self.assertEqual({
        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
    }, embedding_column_a._parse_example_spec)
    self.assertEqual({
        'bbb': parsing_ops.VarLenFeature(dtypes.int64)
    }, embedding_column_b._parse_example_spec)

  @test_util.deprecated_graph_mode_only
  def test_all_constructor_args(self):
    """Checks attributes of both columns when optional args are passed."""
    categorical_column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    categorical_column_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=3)
    embedding_dimension = 2
    embedding_column_a, embedding_column_b = tpu_fc.shared_embedding_columns(
        [categorical_column_a, categorical_column_b],
        dimension=embedding_dimension,
        combiner='my_combiner',
        initializer=lambda: 'my_initializer',
        shared_embedding_collection_name='var_scope_name')
    self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
    self.assertIs(categorical_column_b, embedding_column_b.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column_a.dimension)
    self.assertEqual(embedding_dimension, embedding_column_b.dimension)
    self.assertEqual('my_combiner', embedding_column_a.combiner)
    self.assertEqual('my_combiner', embedding_column_b.combiner)
    self.assertEqual('my_initializer', embedding_column_a.initializer())
    self.assertEqual('my_initializer', embedding_column_b.initializer())
    self.assertEqual('var_scope_name',
                     embedding_column_a.shared_embedding_collection_name)
    self.assertEqual('var_scope_name',
                     embedding_column_b.shared_embedding_collection_name)
    self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
    self.assertEqual('bbb_shared_embedding', embedding_column_b.name)
    self.assertEqual('var_scope_name', embedding_column_a._var_scope_name)
    self.assertEqual('var_scope_name', embedding_column_b._var_scope_name)
    self.assertEqual((embedding_dimension,), embedding_column_a._variable_shape)
    self.assertEqual((embedding_dimension,), embedding_column_b._variable_shape)
    self.assertEqual({
        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
    }, embedding_column_a._parse_example_spec)
    self.assertEqual({
        'bbb': parsing_ops.VarLenFeature(dtypes.int64)
    }, embedding_column_b._parse_example_spec)

  @test_util.deprecated_graph_mode_only
  def test_get_dense_tensor(self):
    """Checks the single shared variable and both columns' lookups."""
    # Inputs.
    vocabulary_size = 3
    # -1 values are ignored.
    input_a = np.array([
        [2, -1, -1],  # example 0, ids [2]
        [0, 1, -1],  # example 1, ids [0, 1]
    ])
    input_b = np.array([
        [0, -1, -1],  # example 0, ids [0]
        [-1, -1, -1],  # example 1, ids []
    ])
    input_features = {'aaa': input_a, 'bbb': input_b}

    # Embedding variable.
    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.)  # id 2
    )

    def _initializer(shape, dtype, partition_info):
      # Doubles as a check on the arguments the columns pass to the
      # initializer, then seeds the variable with the known values above.
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    # Expected lookup result, using combiner='mean'.
    expected_lookups_a = (
        # example 0:
        (7., 11.),  # ids [2], embedding = [7, 11]
        # example 1:
        (2., 3.5),  # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
    )
    expected_lookups_b = (
        # example 0:
        (1., 2.),  # ids [0], embedding = [1, 2]
        # example 1:
        (0., 0.),  # ids [], embedding = [0, 0]
    )

    # Build columns.
    categorical_column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    categorical_column_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)
    embedding_column_a, embedding_column_b = tpu_fc.shared_embedding_columns(
        [categorical_column_a, categorical_column_b],
        dimension=embedding_dimension,
        initializer=_initializer)

    # Provide the dense, -1-padded id arrays and get dense results.
    embedding_lookup_a = embedding_column_a._get_dense_tensor(
        fc._LazyBuilder(input_features))
    embedding_lookup_b = embedding_column_b._get_dense_tensor(
        fc._LazyBuilder(input_features))

    # Assert expected embedding variable and lookups. Exactly one global
    # variable is expected even though two columns were built.
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    # `assertCountEqual` is the Python 3 name of the order-insensitive
    # comparison formerly spelled `assertItemsEqual`.
    self.assertCountEqual(('embedding_weights:0',),
                          tuple(v.name for v in global_vars))
    embedding_var = global_vars[0]
    with _initialized_session():
      self.assertAllEqual(embedding_values, embedding_var)
      self.assertAllEqual(expected_lookups_a, embedding_lookup_a)
      self.assertAllEqual(expected_lookups_b, embedding_lookup_b)
| |
| |
# Run the test suite only when this file is executed as a script.
if __name__ == '__main__':
  test.main()