# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base Class for TPU Embeddings Mid level APIs."""
import functools
from typing import Any, Dict, Iterable, Optional, Union, Text
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import nest
class TPUEmbeddingBase(autotrackable.AutoTrackable):
"""The TPUEmbedding Base class.
This class only contains the basic logic to check the feature config and table
config for the tpu embedding mid level APIs.
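
  A typical `feature_config` passed to one of the mid level APIs built on this
  class might look like the following (the table/feature names and sizes here
  are purely illustrative):

  ```python
  table = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=1000, dim=64, name='video')
  feature_config = {
      'watched': tf.tpu.experimental.embedding.FeatureConfig(table=table),
      'favorited': tf.tpu.experimental.embedding.FeatureConfig(table=table),
  }
  optimizer = tf.tpu.experimental.embedding.SGD(learning_rate=0.1)
  ```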
"""
def __init__(
self,
feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable], # pylint:disable=g-bare-generic
optimizer: Optional[tpu_embedding_v2_utils._Optimizer] = None): # pylint:disable=protected-access
"""Creates the TPUEmbeddingBase object."""
self._feature_config = feature_config
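    # Record each feature's output shape, in the order produced by flattening
    # the feature config.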
self._output_shapes = []
for feature in nest.flatten(feature_config):
      self._output_shapes.append(feature.output_shape)

    # Set table order here to the order of the first occurrence of the table in
# a feature provided by the user. The order of this struct must be fixed
# to provide the user with deterministic behavior over multiple
# instantiations.
self._table_config = []
for feature in nest.flatten(feature_config):
if feature.table not in self._table_config:
        self._table_config.append(feature.table)

    # Ensure tables have unique names. Also error check the optimizer as we
# specifically don't do that in the TableConfig class to allow high level
# APIs that are built on this to use strings/other classes to represent
# optimizers (before they are passed to this class).
table_names = []
for i, table in enumerate(self._table_config):
if table.optimizer is None:
# TODO(bfontain) Should we allow some sort of optimizer merging here?
table.optimizer = optimizer
if (table.optimizer is not None and
not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)): # pylint: disable=protected-access
raise ValueError("{} is an unsupported optimizer class. Please pass an "
"instance of one of the optimizer classes under "
"tf.tpu.experimental.embedding.".format(
type(table.optimizer)))
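      # Give unnamed tables a default, index-based name.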
if table.name is None:
table.name = "table_{}".format(i)
if table.name in table_names:
raise ValueError("Tables must have a unique name. "
f"Multiple tables with name {table.name} found.")
table_names.append(table.name)
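
    # Variables are created lazily in build().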
    self._built = False

  @property
def embedding_tables(self):
"""Returns a dict of embedding tables, keyed by `TableConfig`."""
    raise NotImplementedError

  def _create_variables(self, table: tpu_embedding_v2_utils.TableConfig,
trainable: bool) -> Dict[Text, tf_variables.Variable]:
"""Create all variables including table variables and slot variables."""
    variable_shape = (table.vocabulary_size, table.dim)

    def getter(name, shape, dtype, initializer, trainable):
del shape
# _add_variable_with_custom_getter clears the shape sometimes, so we
# take the global shape from outside the getter.
initial_value = functools.partial(
initializer, variable_shape, dtype=dtype)
return tf_variables.Variable(
name=name,
initial_value=initial_value,
shape=variable_shape,
dtype=dtype,
          trainable=trainable)

    def variable_creator(name, initializer, trainable=True):
# Use add_variable_with_custom_getter here so that we take advantage of
# the checkpoint loading to allow restore before the variables get
# created which avoids double initialization.
return self._add_variable_with_custom_getter(
name=name,
initializer=initializer,
shape=variable_shape,
dtype=dtypes.float32,
getter=getter,
trainable=trainable)
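
    # Create the main embedding table variable.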
parameters = variable_creator(
        table.name, table.initializer, trainable=trainable)

    def slot_creator(name, initializer):
return variable_creator(table.name + "/" + name, initializer, False)
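
    # Let the table's optimizer create its slot variables (if any); tables
    # without an optimizer have no slots.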
if table.optimizer is not None:
slot_vars = table.optimizer._create_slots(parameters, slot_creator) # pylint: disable=protected-access
else:
slot_vars = {}
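    # The embedding table itself is stored alongside the slots under the key
    # "parameters".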
slot_vars["parameters"] = parameters
    return slot_vars

  def _create_variables_and_slots(self):
"""Create variables and slots variables for TPU embeddings."""
raise NotImplementedError
def build(self):
"""Create variables and slots variables for TPU embeddings."""
if self._built:
return
self._variables = self._create_variables_and_slots()
    self._built = True

  def __call__(self, features: Any, weights: Optional[Any] = None) -> Any:
"""Call the mid level api to do embedding lookup."""
if not self._built:
self.build()
    return self.embedding_lookup(features, weights)

  def embedding_lookup(self,
features: Any,
weights: Optional[Any] = None) -> Any:
"""Lookup the embedding table using the input features."""
raise NotImplementedError