blob: 7b19500025bbc13f5daed1f39706bcf8b47827c7 [file] [log] [blame]
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mid level API for TPU Embeddings without Embedding Accelerator."""
from typing import Any, Dict, Iterable, Optional, Text, Union
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.tpu import tpu_embedding_base
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@tf_export("tpu.experimental.embedding.TPUEmbeddingV0")
class TPUEmbeddingV0(tpu_embedding_base.TPUEmbeddingBase):
"""The TPUEmbedding mid level API running on TPU without Embedding accelerator.
NOTE: This mid level API is not intended for large embedding table lookup.
Embedding tables will be replicated across devices rather than sharding
across them. To do large embedding table lookup, please use the
`tpu.experimental.embedding.TPUEmbedding` class. This class is an alternative
way to do embedding lookups when the TPU doesn't support any version of
embedding feature. See
`tpu.experimental.tpu_hardware_feature.embedding_feature` for a detailed
explanation.
This class has to be created under the `TPUStrategy`, Otherwise a RuntimeError
will be raised.
```python
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.TPUEmbeddingV0(
feature_config=feature_config,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
```
When creating a distributed dataset that is to be passed to the lookup
operation a special input option must be specified:
```python
distributed_dataset = (
strategy.distribute_datasets_from_function(
dataset_fn=...,
options=tf.distribute.InputOptions(
experimental_fetch_to_device=False))
dataset_iterator = iter(distributed_dataset)
```
Below is an example of a training and evaluation step:
```python
optimizer = tf.keras.optimizers.SGD(0.1)
@tf.function
def training_step(dataset_iterator, num_steps):
def tpu_step(embedding_features):
with tf.GradientTape() as tape:
tape.watch(embedding.embedding_table.values())
activation = embedding(embedding_features)
model_output = model(activations)
loss = ... # some function of labels and model_output
embedding_gradients = tape.gradient(loss,
embedding.embedding_table.values())
optimizer.apply_gradients(list(zip(gradients,
mid_level_api.embedding_tables.values())))
# Insert your model gradient and optimizer application here
for _ in tf.range(num_steps):
strategy.run(tpu_step, args=(next(dataset_iterator), ))
@tf.function
def evalution_step(dataset_iterator, num_steps):
def tpu_step(embedding_features):
activations = embedding(embedding_features)
model_output = model(activations)
# Insert your evaluation code here.
for _ in tf.range(num_steps):
strategy.run(tpu_step, args=(next(dataset_iterator), ))
```
NOTE: The optimizer used here is a Keras optimizer. In order to make the slot
variable creation stay consistent between Keras optimizers and
embedding optimizers, the `slot_variable_creation_fn` argument of the
embedding optimizers has to be passed with the Keras `add_slot` function. Also
note that the slot names might be slightly different between them.
```python
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)
def slot_variable_creation_fn(table, slot_names, slot_initializers):
slots = {}
for slot, initializer in zip(slot_names, slot_initializers):
slots[slot] = optimizer.add_slot(table, slot, initializer)
return slots
embedding_optimizer = tf.experimental.embedding.Adagrad(
learning_rate=0.1,
slot_variable_creation_fn=slot_variable_creation_fn)
# Use the embedding optimizer to create mid level api and keras optimizer to
# apply gradients.
```
"""
def __init__(
self,
feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable], # pylint:disable=g-bare-generic
optimizer: Optional[tpu_embedding_v2_utils._Optimizer]): # pylint:disable=protected-access
super(TPUEmbeddingV0, self).__init__(feature_config, optimizer)
self._strategy = distribute_lib.get_strategy()
if not isinstance(self._strategy,
(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2)):
raise RuntimeError(
"TPUEmbeddingV0 should be created under TPUStrategy but found {}."
.format(self._strategy))
self._built = False
@property
def embedding_tables(
self) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]:
"""Returns a dict of embedding tables, keyed by `TableConfig`."""
self._maybe_build()
# Only return the tables and not the slot variables.
return {
table: self._variables[table.name]["parameters"]
for table in self._table_config
}
def _create_variables_and_slots(
self) -> Dict[Text, Dict[Text, tf_variables.Variable]]:
"""Create variables for TPU embeddings.
Note that this will always ensure that the variable is created under the
TPUStrategy.
Returns:
A dict of dicts. The outer dict is keyed by the table names and the inner
dicts are keyed by 'parameters' and the slot variable names.
"""
variables = {}
for table in self._table_config:
# created TPUDistributedVariable.
variables[table.name] = self._create_variables(table, trainable=True)
return variables
def _maybe_build(self):
if not self._built:
# This can be called while tracing a function, so we wrap the
# initialization code with init_scope so it runs eagerly, this means that
# it will not be included in the function graph generated by tracing so
# that we can be sure that we only initialize the TPU for embeddings
# exactly once.
with ops.init_scope():
self.build()
def _apply_combiner_to_embeddings(
self,
embeddings: ops.Tensor,
weight: ops.Tensor,
combiner: Optional[Text] = None) -> ops.Tensor:
"""Apply the combiner to the embedding look up result on second to last axis.
Args:
embeddings: A Tensor of the embedding lookup result.
weight: A Tensor of weight which has the same shape of the embeddings.
combiner: One of "mean", "sum", "sqrtn". Defaults to "mean".
Raises:
ValueError: If the combiner is not one of 'mean', 'sqrtn' or 'sum'.
Returns:
A Tensor.
"""
if combiner is None:
combiner = "mean"
if combiner == "sum":
embeddings = math_ops.reduce_sum(embeddings, axis=-2)
elif combiner == "mean":
embeddings = math_ops.reduce_sum(embeddings, axis=-2)
weight_sum = math_ops.reduce_sum(weight, axis=-2)
embeddings = math_ops.div_no_nan(embeddings, weight_sum)
elif combiner == "sqrtn":
embeddings = math_ops.reduce_sum(embeddings, axis=-2)
weight_squared = math_ops.pow(weight, 2)
weight_sum = math_ops.reduce_sum(weight_squared, axis=-2)
weight_sum_sqrt = math_ops.sqrt(weight_sum)
embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt)
else:
raise ValueError(
f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}")
return embeddings
def _pad_or_truncate_with_sequence_length(self, embeddings: ops.Tensor,
sequence_length: int) -> ops.Tensor:
"""Pad or truncate the embedding lookup result based on the sequence length.
Args:
embeddings: A rank 3 Tensor of the embedding lookup result.
sequence_length: number of the max sequence length set in the feature
config.
Returns:
A Tensor with second last axis padded or truncated.
"""
original_sequence_length = embeddings.shape[1]
if original_sequence_length > sequence_length:
embeddings = array_ops.slice(
embeddings, begin=[0, 0, 0], size=[-1, sequence_length, -1])
else:
embeddings = array_ops.pad(
embeddings,
paddings=[[0, 0], [0, sequence_length - original_sequence_length],
[0, 0]])
return embeddings
def embedding_lookup(self,
features: Any,
weights: Optional[Any] = None) -> Any:
"""Apply embedding lookup on TPUs using Tensorcore.
Note that all the sparse and ragged tensors will be converted to dense
tensors on CPU and then passed to the TPU to do embedding look up. Large
embedding lookup is not supported by this API, use the TPUEmbedding mid
level api instead.
Args:
features: a nested structure of Tensors, SparseTensors or RaggedTensors.
weights: a nested structure of Tensors, SparseTensors or RaggedTensors or
None for no weights. If not None, structure must match that of inputs,
but entries are allowed to be None.
Returns:
A nested structure of Tensors with the same structure as inputs.
"""
if not self._built:
self.build()
nest.assert_same_structure(features, self._feature_config)
flat_inputs = nest.flatten(features)
flat_weights = [None] * len(flat_inputs)
if weights is not None:
nest.assert_same_structure(features, weights)
flat_weights = nest.flatten(weights)
flat_features = nest.flatten_with_joined_string_paths(self._feature_config)
outputs = []
for inp, weight, (path, feature) in zip(flat_inputs, flat_weights,
flat_features):
table = self.embedding_tables[feature.table]
if weight is not None:
if isinstance(inp, ops.Tensor):
raise ValueError(
"Weight specified for {}, but input is dense.".format(path))
elif type(weight) is not type(inp):
raise ValueError(
"Weight for {} is of type {} but it does not match type of the "
"input which is {}.".format(path, type(weight), type(inp)))
elif feature.max_sequence_length > 0:
raise ValueError("Weight specified for {}, but this is a sequence "
"feature.".format(path))
if isinstance(inp, ops.Tensor):
if feature.max_sequence_length > 0:
raise ValueError(
"Feature {} is a sequence feature but a dense tensor "
"was passed.".format(path))
outputs.append(embedding_ops.embedding_lookup_v2(table, inp))
elif isinstance(inp, sparse_tensor.SparseTensor):
outputs.append(
self._embedding_lookup_for_sparse_tensor(inp, weight, table,
feature))
elif isinstance(inp, ragged_tensor.RaggedTensor):
outputs.append(
self._embedding_lookup_for_ragged_tensor(inp, weight, table,
feature))
else:
raise ValueError("Input {} is type {}. Tensor, SparseTensor or "
"RaggedTensor expected.".format(path, type(inp)))
return nest.pack_sequence_as(self._feature_config, outputs)
def _embedding_lookup_for_sparse_tensor(
self, inp: sparse_tensor.SparseTensor,
weight: Optional[sparse_tensor.SparseTensor],
table: tf_variables.Variable,
feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor:
"""Embedding lookup for sparse tensor based on its feature config.
Args:
inp: a single SparseTensor input.
weight: None or SparseTensor which has the same shape of the input.
table: a table variable.
feature: a feature config.
Returns:
Embedding lookup result.
"""
# This computation needs to placed outside of tpu as the size of the
# indices and values can change for different batch which can cause
# the program to re-compile.
def sparse_to_dense_computation(inp, weight):
if weight is None:
weight = sparse_tensor.SparseTensor(
inp.indices,
array_ops.ones_like(inp.values, dtype=dtypes.float32),
dense_shape=inp.dense_shape)
# Pad the sparse tensor to be dense tensor.
inp = sparse_ops.sparse_tensor_to_dense(inp)
weight = sparse_ops.sparse_tensor_to_dense(weight)
return inp, weight
inp, weight = tpu_replication.outside_compilation(
sparse_to_dense_computation, inp=inp, weight=weight)
embeddings = embedding_ops.embedding_lookup_v2(table, inp)
weight = array_ops.expand_dims(weight, -1)
embeddings *= weight
if not feature.output_shape and feature.max_sequence_length > 0:
embeddings = self._pad_or_truncate_with_sequence_length(
embeddings, feature.max_sequence_length)
else:
embeddings = self._apply_combiner_to_embeddings(embeddings, weight,
feature.table.combiner)
return embeddings
def _embedding_lookup_for_ragged_tensor(
self, inp: ragged_tensor.RaggedTensor,
weight: Optional[ragged_tensor.RaggedTensor],
table: tf_variables.Variable,
feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor:
"""Embedding lookup for ragged tensor based on its feature config.
Args:
inp: a single rank 2 RaggedTensor input.
weight: None or RaggedTensor which has the same shape of the input.
table: a table variable.
feature: a feature config.
Returns:
Embedding lookup result.
Raises:
ValueError: if input ragged tensor is not rank 2 or output shape set in
the feature config doesn't match with the first dim size of the input.
"""
if inp.shape.rank != 2:
raise ValueError(
"Only rank 2 ragged tensor is supported, but got rank {}".format(
inp.shape.rank))
batch_size = inp.shape[0]
# This computation needs to placed outside of tpu as the size of the row
# splits and values can change for different batch which can cause
# the program to re-compile.
def ragged_to_dense_outside_compilation(inp, weight, batch_size, feature):
if weight is None:
weight = ragged_tensor.RaggedTensor.from_row_splits(
array_ops.ones_like(inp.values, dtype=dtypes.float32),
inp.row_splits)
if not feature.output_shape and feature.max_sequence_length > 0:
inp = inp.to_tensor(shape=(batch_size, feature.max_sequence_length))
# Ignore weight if it is a sequence feature.
weight = array_ops.ones_like(inp, dtype=dtypes.float32)
elif feature.output_shape:
# Eagerly run the following op as the result as to be a number in
# order to use it as part of the output shape.
with ops.init_scope():
output_batch_size = math_ops.reduce_prod(feature.output_shape).numpy()
# If the output batch size matches the data batch size, treat it as
# normal ragged input.
if output_batch_size == batch_size:
inp, weight = inp.to_tensor(), weight.to_tensor()
# If the data batch size is a factor of the output batch size, the
# divide result will be the sequence length. Ignore the weights and
# combiner.
elif output_batch_size > batch_size and output_batch_size % batch_size == 0:
# Pad or truncate in the sequence dimension
seq_length = output_batch_size // batch_size
inp = inp.to_tensor(shape=(batch_size, seq_length))
# Ignore weight if it is a sequence feature.
weight = array_ops.ones_like(inp, dtype=dtypes.float32)
else:
raise ValueError(
"Output shape set in the FeatureConfig should be the factor of "
"the input data batch size. But instead got output shape {}, "
"input data batch size {}".format(feature.output_shape,
batch_size))
else:
inp, weight = inp.to_tensor(), weight.to_tensor()
return inp, weight
inp, weight = tpu_replication.outside_compilation(
ragged_to_dense_outside_compilation,
inp=inp,
weight=weight,
batch_size=batch_size,
feature=feature)
embeddings = embedding_ops.embedding_lookup_v2(table, inp)
weight = array_ops.expand_dims(weight, -1)
embeddings *= weight
if feature.output_shape:
with ops.init_scope():
output_batch_size = math_ops.reduce_prod(feature.output_shape).numpy()
if output_batch_size == batch_size:
embeddings = self._apply_combiner_to_embeddings(embeddings, weight,
feature.table.combiner)
embeddings = array_ops.reshape(
embeddings, shape=feature.output_shape + [feature.table.dim])
else:
if feature.max_sequence_length == 0:
embeddings = self._apply_combiner_to_embeddings(embeddings, weight,
feature.table.combiner)
return embeddings