# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Optional helper for gradient handling."""

import collections

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu.ops import tpu_ops


def get_gradients_through_compute_gradients(optimizer, loss, activations):
"""Compute gradients to send to TPU embedding.
Args:
optimizer: a subclass of optimizer.Optimizer, usually CrossShardOptimizer.
Used to call compute_gradients().
loss: a Tensor to call optimizer.compute_gradients() on.
activations: an OrderedDict mapping feature_name to Tensors of activations.
Returns:
An OrderedDict mapping from feature name Strings to Tensors of gradients of
the loss wrt the activations of the features.
"""
  activation_list = list(activations.values())
grads_and_vars = optimizer.compute_gradients(loss, activation_list)
grads = [grad for grad, _ in grads_and_vars]
feature_to_gradient_dict = collections.OrderedDict(
zip(activations.keys(), grads))
return feature_to_gradient_dict
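

# A minimal usage sketch for the compute_gradients() path above; it is not
# part of the library. The feature names, the activation tensors, and
# `model_fn` below are hypothetical stand-ins, and `tpu_optimizer` refers to
# tensorflow.python.tpu.tpu_optimizer:
#
#   activations = collections.OrderedDict([
#       ('watched', watched_activations),
#       ('liked', liked_activations)])
#   loss = model_fn(activations)
#   optimizer = tpu_optimizer.CrossShardOptimizer(
#       tf.compat.v1.train.AdagradOptimizer(learning_rate=0.1))
#   grads = get_gradients_through_compute_gradients(
#       optimizer, loss, activations)
#   # grads['watched'] is d(loss)/d(watched_activations), ready to be sent
#   # back to TPU embedding.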


def create_dummy_table_variables(tpu_embedding):
"""Create dummy embedding table variables.
The sole purpose of these dummy variables are to trigger gradient
calculation wrt them so that the gradients wrt activation can be captured
and later sent to TPU embedding.
Args:
tpu_embedding: TPUEmbedding, dummy table variables will be created for use
with tpu_embedding.
Returns:
A tuple of dummy variables and their initializer.
Raises:
RuntimeError: if collection to store gradients already exists and is not
empty.
"""
dummy_table_variables = collections.OrderedDict()
for table_id, table in enumerate(tpu_embedding.table_to_features_dict):
dummy_table_variables[table] = (
# Explicitly specifying collections prevents this variable from
# being added to the GLOBAL_VARIABLES collection, so that Saver()
# ignores it.
        # But TensorFlow optimizers create slot variables for these dummy
        # variables, e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1},
        # which will be in the GLOBAL_VARIABLES collection.
variable_scope.get_variable(
'tpu_embedding_dummy_table_variable_{}'.format(table),
dtype=dtypes.float32,
shape=[1],
use_resource=True,
trainable=True,
collections=['tpu_embedding_dummy_table_variables']))
g = ops.get_default_graph()
table_gradients = g.get_collection_ref(
'tpu_embedding_gradients_table_{}'.format(table_id))
if table_gradients:
raise RuntimeError(
'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
num_features = len(tpu_embedding.table_to_features_dict[table])
    table_gradients.extend([None] * num_features)
return (dummy_table_variables,
variables.variables_initializer(
dummy_table_variables.values(),
name='tpu_embedding_dummy_table_variables_init'))
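

# A hedged sketch of typical wiring (assuming `tpu_embedding` is an
# already-constructed tpu_embedding.TPUEmbedding instance and `session` is an
# open tf.compat.v1.Session):
#
#   dummy_tables, dummy_init = create_dummy_table_variables(tpu_embedding)
#   ...
#   session.run(dummy_init)  # run once before any training step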


def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
                                              dummy_table_variables):
"""Have activations depend on dummy table variables for gradient intercept.
Args:
tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from
tpu_embedding.
activations: An OrderedDict of feature name String to activation tensors.
dummy_table_variables: An OrderedDict of table name String to dummy table
variables.
Returns:
An OrderedDict of feature name String to activation tensors, which can be
used just as the activations input.
"""
new_activations = collections.OrderedDict()
for feature in activations:
table = tpu_embedding.feature_to_config_dict[feature].table_id
new_activations[feature] = tpu_ops.tpu_embedding_activations(
dummy_table_variables[table],
activations[feature],
table_id=list(tpu_embedding.table_to_config_dict).index(table),
lookup_id=tpu_embedding.table_to_features_dict[table].index(feature))
return new_activations
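

# Note: the gradient registered for the TPUEmbeddingActivations op stores the
# gradient wrt each hooked activation into the graph collection
# 'tpu_embedding_gradients_table_{table_id}' (the slots reserved by
# create_dummy_table_variables above). A hedged sketch of hooking the
# activations, with `dummy_tables` from the sketch above and `model_fn`
# hypothetical:
#
#   new_activations = hook_dummy_table_variables_to_activations(
#       tpu_embedding, activations, dummy_tables)
#   loss = model_fn(new_activations)  # build the model on the hooked tensors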


def get_gradients_through_dummy_table_variables(tpu_embedding):
"""Get gradients wrt the activations of each feature.
Args:
tpu_embedding: TPUEmbedding, create dummy table variable to be used with
tpu_embedding.
Returns:
An OrderedDict mapping feature name to gradient.
Raises:
ValueError: if some gradients are not defined.
"""
g = ops.get_default_graph()
gradients_found = False
for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
table_gradients = g.get_collection(
'tpu_embedding_gradients_table_{}'.format(table_id))
if any(gradient is None for gradient in table_gradients):
# TODO(bfontain): create a white-list for optimizers which are compatible
# with `tf.stop_gradient`.
      logging.warn(
          'Table {} with id {} has undefined gradients: this is probably '
          'because the model asked TPUEmbedding to compute activations that '
          'were not used, or tf.stop_gradient() is applied. Gradients of '
          'zeros are sent back to TPUEmbedding instead. Gradients of zeros '
          'and no gradients are equivalent for SGD, AdaGrad, FTRL, etc., but '
          'might differ for other optimizers due to the implementation of '
          'the TPU embedding optimizers.'.format(table, table_id))
gradients_found = gradients_found or any(
gradient is not None for gradient in table_gradients)
if not gradients_found:
    logging.warn(
        'All tables have undefined gradients: this is probably because the '
        'model asked TPUEmbedding to compute activations that were not used. '
        'If all TPUEmbedding features have stop_gradient applied, consider '
        'using INFERENCE mode instead.')
feature_to_gradient_dict = collections.OrderedDict()
for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
table_gradients = g.get_collection(
'tpu_embedding_gradients_table_{}'.format(table_id))
for feature, gradient in zip(tpu_embedding.table_to_features_dict[table],
table_gradients):
if gradient is not None:
feature_to_gradient_dict[feature] = gradient
else:
dimension = tpu_embedding.table_to_config_dict[table].dimension
batch_size = tpu_embedding.batch_size_per_core
max_sequence_length = (
tpu_embedding.feature_to_config_dict[feature].max_sequence_length)
if max_sequence_length:
feature_to_gradient_dict[feature] = array_ops.zeros(
[batch_size, max_sequence_length, dimension])
else:
feature_to_gradient_dict[feature] = array_ops.zeros(
[batch_size, dimension])
return feature_to_gradient_dict
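

# End-to-end hedged sketch of the dummy-table flow, tying the helpers above
# together (names other than the helpers are hypothetical):
#
#   dummy_tables, dummy_init = create_dummy_table_variables(tpu_embedding)
#   hooked = hook_dummy_table_variables_to_activations(
#       tpu_embedding, activations, dummy_tables)
#   loss = model_fn(hooked)
#   train_op = optimizer.minimize(loss)
#   # Differentiating the loss wrt the trainable dummy variables fires the
#   # TPUEmbeddingActivations gradient, which fills the per-table collections
#   # read below.
#   feature_to_grad = get_gradients_through_dummy_table_variables(
#       tpu_embedding)
#   # feature_to_grad is then typically passed to
#   # tpu_embedding.generate_send_gradients_op() to update the real tables.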