# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mid level API for Serving TPU Embeddings."""
from typing import Any, Iterable, Optional, Text, Union, Dict
from absl import logging
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.tpu import tpu_embedding_base
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.types import core
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@tf_export("tpu.experimental.embedding.TPUEmbeddingForServing")
class TPUEmbeddingForServing(tpu_embedding_base.TPUEmbeddingBase):
"""The TPUEmbedding mid level API running on CPU for serving.
Note: This class is intended to be used for embedding tables that are trained
on TPU and to be served on CPU. Therefore the class should be only initialized
under non-TPU strategy. Otherwise an error will be raised.
You can first train your model using the TPUEmbedding class and save the
checkpoint. Then use this class to restore the checkpoint to do serving.
First train a model and save the checkpoint.
```python
model = model_fn(...)
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
feature_config=feature_config,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
# Your custom training code.
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.save(...)
```
Then restore the checkpoint and do serving.
```python
# Restore the model on CPU.
model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbeddingForServing(
feature_config=feature_config,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)
result = embedding(...)
tables = embedding.embedding_tables
```
NOTE: This class can also be used to do embedding training on CPU, but doing
so requires converting between the Keras optimizers and the embedding
optimizers so that the slot variables stay consistent between them.
"""
def __init__(
self,
feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable], # pylint:disable=g-bare-generic
optimizer: Optional[tpu_embedding_v2_utils._Optimizer]): # pylint:disable=protected-access
"""Creates the TPUEmbeddingForServing mid level API object.
```python
embedding = tf.tpu.experimental.embedding.TPUEmbeddingForServing(
feature_config=tf.tpu.experimental.embedding.FeatureConfig(
table=tf.tpu.experimental.embedding.TableConfig(
dim=...,
vocabulary_size=...)))
```
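
If the object is only used for serving lookups, the optimizer may be left
out so that no slot variables are created. A minimal sketch, assuming
`feature_config` is defined as above:

```python
embedding = tf.tpu.experimental.embedding.TPUEmbeddingForServing(
    feature_config=feature_config,
    optimizer=None)
```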
Args:
feature_config: A nested structure of
`tf.tpu.experimental.embedding.FeatureConfig` configs.
optimizer: An instance of one of `tf.tpu.experimental.embedding.SGD`,
`tf.tpu.experimental.embedding.Adagrad` or
`tf.tpu.experimental.embedding.Adam`. When not created under TPUStrategy,
this may be set to None to avoid creating the optimizer slot variables,
which is useful for reducing memory consumption when exporting the model
for serving, where slot variables are not needed.
Raises:
RuntimeError: If created under TPUStrategy.
"""
super(TPUEmbeddingForServing, self).__init__(feature_config, optimizer)
self._strategy = distribute_lib.get_strategy()
if isinstance(self._strategy,
(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2)):
raise RuntimeError("Serving on TPU is not yet supported.")
@property
def embedding_tables(
self) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]:
"""Returns a dict of embedding tables, keyed by `TableConfig`."""
self._maybe_build()
# Only return the tables and not the slot variables.
return {
table: self._variables[table.name]["parameters"]
for table in self._table_config
}
def _maybe_build(self):
if not self._built:
# This can be called while tracing a function, so we wrap the
# initialization code with init_scope so that it runs eagerly. This means
# it will not be included in the function graph generated by tracing, so
# we can be sure that we only initialize the TPU for embeddings exactly
# once.
with ops.init_scope():
self.build()
def _create_variables_and_slots(
self) -> Dict[Text, Dict[Text, tf_variables.Variable]]:
"""Create variables for TPU embeddings.
Returns:
A dict of dicts. The outer dict is keyed by the table names and the inner
dicts are keyed by 'parameters' and the slot variable names.
"""
variables = {}
for table in self._table_config:
variables[table.name] = self._create_variables(table, trainable=True)
return variables
def embedding_lookup(self,
features: Any,
weights: Optional[Any] = None) -> Any:
"""Apply standard lookup ops on CPU.
Args:
features: A nested structure of `tf.Tensor`s, `tf.SparseTensor`s or
`tf.RaggedTensor`s, with the same structure as `feature_config`. Inputs
will be downcast to `tf.int32`. Only one type out of `tf.SparseTensor`
or `tf.RaggedTensor` is supported per call.
weights: If not `None`, a nested structure of `tf.Tensor`s,
`tf.SparseTensor`s or `tf.RaggedTensor`s, matching the above, except
that the tensors should be of float type (and they will be downcast to
`tf.float32`). For `tf.SparseTensor`s we assume the `indices` are the
same for the parallel entries from `features`; similarly, for
`tf.RaggedTensor`s we assume the `row_splits` are the same.
Returns:
A nested structure of Tensors with the same structure as input features.
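
For example, a minimal sketch for a single sparse feature (the feature name
and the already-restored `embedding` instance are hypothetical, and
`feature_config` is assumed to be a dict with a matching structure):

```python
features = {
    'watched': tf.sparse.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0]],
        values=[3, 7, 5],
        dense_shape=[2, 2])
}
# Returns a dict with the same structure; each value is a dense Tensor of
# per-example combined embeddings.
activations = embedding.embedding_lookup(features)
```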
"""
return cpu_embedding_lookup(features, weights, self.embedding_tables,
self._feature_config)
def _ragged_embedding_lookup_with_reduce(table: tf_variables.Variable,
ragged: ragged_tensor.RaggedTensor,
weights: ragged_tensor.RaggedTensor,
combiner: Text) -> core.Tensor:
"""Compute a ragged lookup followed by a reduce on axis 1.
Args:
table: The embedding table.
ragged: A RaggedTensor of ids to look up.
weights: A RaggedTensor of weights (or None).
combiner: One of "mean", "sum", "sqrtn".
Returns:
A Tensor.
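
For example, a minimal sketch with a hypothetical 4 x 2 table and the
"mean" combiner:

```python
table = tf.Variable([[0., 0.], [1., 1.], [2., 2.], [3., 3.]])
ids = tf.ragged.constant([[1, 2], [3]])
# Row 0: (table[1] + table[2]) / 2 = [1.5, 1.5]
# Row 1: table[3] / 1 = [3., 3.]
result = _ragged_embedding_lookup_with_reduce(table, ids, None, "mean")
```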
"""
if weights is None:
weights = array_ops.ones_like(ragged, dtype=table.dtype)
weights = array_ops.expand_dims(weights, axis=2)
ragged_result = embedding_ops.embedding_lookup(table, ragged)
ragged_result = math_ops.reduce_sum(ragged_result * weights, axis=1)
if combiner == "mean":
ragged_result = math_ops.div_no_nan(ragged_result,
math_ops.reduce_sum(weights, axis=1))
elif combiner == "sqrtn":
ragged_result = math_ops.div_no_nan(
ragged_result,
math_ops.sqrt(math_ops.reduce_sum(weights * weights, axis=1)))
return ragged_result
@tf_export("tpu.experimental.embedding.serving_embedding_lookup")
def cpu_embedding_lookup(
inputs: Any,
weights: Optional[Any],
tables: Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable],
feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable] # pylint:disable=g-bare-generic
) -> Any:
"""Apply standard lookup ops with `tf.tpu.experimental.embedding` configs.
This function is a utility which allows using the
`tf.tpu.experimental.embedding` config objects with standard lookup functions.
This can be used when exporting a model which uses
`tf.tpu.experimental.embedding.TPUEmbedding` for serving on CPU. In particular
`tf.tpu.experimental.embedding.TPUEmbedding` only supports lookups on TPUs and
should not be part of your serving graph.
Note that TPU specific options (such as `max_sequence_length`) in the
configuration objects will be ignored.
In the following example we take a trained model (see the documentation for
`tf.tpu.experimental.embedding.TPUEmbedding` for the context) and create a
saved model with a serving function that will perform the embedding lookup and
pass the results to your model:
```python
model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
feature_config=feature_config,
batch_size=1024,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)
@tf.function(input_signature=[{'feature_one': tf.TensorSpec(...),
'feature_two': tf.TensorSpec(...),
'feature_three': tf.TensorSpec(...)}])
def serve_tensors(embedding_features):
embedded_features = tf.tpu.experimental.embedding.serving_embedding_lookup(
embedding_features, None, embedding.embedding_tables,
feature_config)
return model(embedded_features)
model.embedding_api = embedding
tf.saved_model.save(model,
export_dir=...,
signatures={'serving_default': serve_tensors})
```
NOTE: It's important to assign the embedding API object to a member of your
model as `tf.saved_model.save` only saves variables that are reachable from
one `Trackable` object. Since the model's weights are in `model` and the
embedding tables are managed by `embedding`, we assign `embedding` to an
attribute of `model` so that `tf.saved_model.save` can find the embedding
variables.
NOTE: The same `serve_tensors` function and `tf.saved_model.save` call will
work directly from training.
Args:
inputs: a nested structure of Tensors, SparseTensors or RaggedTensors.
weights: a nested structure of Tensors, SparseTensors or RaggedTensors or
None for no weights. If not None, structure must match that of inputs, but
entries are allowed to be None.
tables: a dict mapping TableConfig objects to Variables.
feature_config: a nested structure of FeatureConfig objects with the same
structure as inputs.
Returns:
A nested structure of Tensors with the same structure as inputs.
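
Per-id weights can be passed as a parallel structure of `tf.SparseTensor`s
sharing the same `indices` as the ids. A minimal sketch, assuming `tables`
and `feature_config` are set up as in the example above with a feature
named 'watched':

```python
ids = {'watched': tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1]], values=[3, 7], dense_shape=[2, 2])}
weights = {'watched': tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1]], values=[0.5, 2.0], dense_shape=[2, 2])}
embedded = tf.tpu.experimental.embedding.serving_embedding_lookup(
    ids, weights, tables, feature_config)
```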
"""
nest.assert_same_structure(inputs, feature_config)
flat_inputs = nest.flatten(inputs)
flat_weights = [None] * len(flat_inputs)
if weights is not None:
nest.assert_same_structure(inputs, weights)
flat_weights = nest.flatten(weights)
flat_features = nest.flatten_with_joined_string_paths(feature_config)
outputs = []
for inp, weight, (path, feature) in zip(flat_inputs, flat_weights,
flat_features):
table = tables[feature.table]
if weight is not None:
if isinstance(inp, ops.Tensor):
raise ValueError(
"Weight specified for {}, but input is dense.".format(path))
elif type(weight) is not type(inp):
raise ValueError(
"Weight for {} is of type {} but it does not match type of the "
"input which is {}.".format(path, type(weight), type(inp)))
elif feature.max_sequence_length > 0:
raise ValueError("Weight specified for {}, but this is a sequence "
"feature.".format(path))
if isinstance(inp, ops.Tensor):
if feature.max_sequence_length > 0:
raise ValueError("Feature {} is a sequence feature but a dense tensor "
"was passed.".format(path))
outputs.append(embedding_ops.embedding_lookup_v2(table, inp))
elif isinstance(inp, sparse_tensor.SparseTensor):
outputs.append(
_embedding_lookup_for_sparse_tensor(inp, weight, table, feature))
elif isinstance(inp, ragged_tensor.RaggedTensor):
outputs.append(
_embedding_lookup_for_ragged_tensor(inp, weight, table, feature))
else:
raise ValueError("Input {} is type {}. Tensor, SparseTensor or "
"RaggedTensor expected.".format(path, type(inp)))
return nest.pack_sequence_as(feature_config, outputs)
def _embedding_lookup_for_sparse_tensor(
inp: sparse_tensor.SparseTensor,
weight: Optional[sparse_tensor.SparseTensor], table: tf_variables.Variable,
feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor:
"""Embedding lookup for sparse tensor based on its feature config.
Args:
inp: a single SparseTensor input.
weight: None or SparseTensor which has the same shape as the input.
table: a table variable.
feature: a feature config.
Returns:
Embedding lookup result.
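
For example, a sequence feature (`max_sequence_length > 0`) produces a
dense `[batch, max_sequence_length, dim]` output with zero padding. A
minimal sketch with a hypothetical 16 x 4 table:

```python
table = tf.Variable(tf.random.normal([16, 4]))
feature = tf.tpu.experimental.embedding.FeatureConfig(
    table=tf.tpu.experimental.embedding.TableConfig(
        vocabulary_size=16, dim=4),
    max_sequence_length=2)
inp = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]], values=[4, 2, 9],
    dense_shape=[2, 3])
# Ids beyond position max_sequence_length - 1 are truncated; the result
# has shape [2, 2, 4] and missing positions are zero vectors.
result = _embedding_lookup_for_sparse_tensor(inp, None, table, feature)
```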
"""
inp_rank = inp.shape.rank
# The input rank can be None for a sequence input tensor.
if (
not feature.output_shape
and feature.max_sequence_length > 0
and (inp_rank is None or inp_rank == 2)
):
batch_size = math_ops.cast(array_ops.shape(inp)[0], dtype=dtypes.int64)
sparse_shape = array_ops_stack.stack(
[batch_size, feature.max_sequence_length], axis=0
)
# TPU Embedding truncates sequences to max_sequence_length, and if we
# don't truncate here, scatter_nd will error out if an index is out of
# bounds.
truncated_inp = sparse_ops.sparse_slice(
inp, start=[0, 0], size=sparse_shape)
dense_output_shape = array_ops_stack.stack(
[batch_size, feature.max_sequence_length, feature.table.dim], axis=0)
return array_ops.scatter_nd(
truncated_inp.indices,
array_ops.gather(table.read_value(), truncated_inp.values),
dense_output_shape)
else:
if feature.max_sequence_length > 0:
logging.warning(
(
"max_sequence_length setting will be ignored because the rank of"
" the input tensor is %d which is not 2."
),
inp_rank,
)
if (not feature.validate_weights_and_indices and inp_rank is not None and
inp_rank <= 2):
return embedding_ops.embedding_lookup_sparse_v2(
table, inp, sp_weights=weight, combiner=feature.table.combiner)
else:
return embedding_ops.safe_embedding_lookup_sparse_v2(
table, inp, sparse_weights=weight, combiner=feature.table.combiner)
def _embedding_lookup_for_ragged_tensor(
inp: ragged_tensor.RaggedTensor,
weight: Optional[ragged_tensor.RaggedTensor], table: tf_variables.Variable,
feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor:
"""Embedding lookup for ragged tensor based on its feature config.
Args:
inp: a single rank 2 RaggedTensor input.
weight: None or RaggedTensor which has the same shape as the input.
table: a table variable.
feature: a feature config.
Returns:
Embedding lookup result.
Raises:
ValueError: if the input ragged tensor is not rank 2, or the output shape
set in the feature config is not compatible with the batch size of the
input.
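
For example, when the batch size implied by `feature.output_shape` is a
multiple of the input batch size, the quotient is treated as a sequence
length and each row is padded or truncated to it. A minimal sketch with a
hypothetical 8 x 3 table:

```python
table = tf.Variable(tf.random.normal([8, 3]))
feature = tf.tpu.experimental.embedding.FeatureConfig(
    table=tf.tpu.experimental.embedding.TableConfig(
        vocabulary_size=8, dim=3),
    output_shape=[2, 2])  # batch size 2, sequence length 2
inp = tf.ragged.constant([[1, 5, 7], [2]])
# Row 0 is truncated to 2 ids, row 1 is padded; result shape is [2, 2, 3].
result = _embedding_lookup_for_ragged_tensor(inp, None, table, feature)
```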
"""
if inp.shape.rank != 2:
raise ValueError(
"Only rank 2 ragged tensor is supported, but got rank {}".format(
inp.shape.rank))
batch_size = inp.shape[0]
if feature.output_shape:
output_batch_size = math_ops.reduce_prod(feature.output_shape)
# If the output batch size matches the data batch size, treat it as
# normal ragged input.
if output_batch_size == batch_size:
ragged_output = _ragged_embedding_lookup_with_reduce(
table, inp, weight, feature.table.combiner)
ragged_output = array_ops.reshape(
ragged_output, shape=feature.output_shape + [feature.table.dim])
# If the data batch size evenly divides the output batch size, the
# quotient is the sequence length. The weights and combiner are ignored
# in this case.
elif output_batch_size > batch_size and output_batch_size % batch_size == 0:
ragged_output = embedding_ops.embedding_lookup_v2(table, inp)
# Pad or truncate in the sequence dimension
ragged_output = ragged_output.to_tensor(shape=[
batch_size, output_batch_size // batch_size, feature.table.dim
])
# Reshape to the desired output shape.
ragged_output = array_ops.reshape(
ragged_output, feature.output_shape + [feature.table.dim])
else:
raise ValueError(
"Output shape set in the FeatureConfig should be the factor of "
"the input data batch size. But instead got output shape {}, "
"input data batch size {}".format(feature.output_shape, batch_size))
else:
if feature.max_sequence_length > 0:
output_shape = [
batch_size, feature.max_sequence_length, feature.table.dim
]
ragged_lookup = embedding_ops.embedding_lookup_v2(table, inp)
# Unlike scatter_nd, RaggedTensor.to_tensor truncates to the given
# shape.
ragged_output = ragged_lookup.to_tensor(shape=output_shape)
else:
ragged_output = _ragged_embedding_lookup_with_reduce(
table, inp, weight, feature.table.combiner)
return ragged_output