# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in math_ops.py."""
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
def _safe_shape_div(x, y):
"""Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`."""
return x // math_ops.maximum(y, 1)
@ops.RegisterGradient("ArgMax")
def _ArgMaxGrad(op, grad):
del op, grad
return [None, None]
@ops.RegisterGradient("ArgMin")
def _ArgMinGrad(op, grad):
del op, grad
return [None, None]
@ops.RegisterGradient("EuclideanNorm")
def _EuclideanNormGrad(op, grad):
"""Gradient for EuclideanNorm."""
output = op.outputs[0]
if not op.get_attr("keep_dims"):
output_shape_kept_dims = math_ops.reduced_shape(
array_ops.shape(op.inputs[0]), op.inputs[1])
output = array_ops.reshape(output, output_shape_kept_dims)
grad = array_ops.reshape(grad, output_shape_kept_dims)
return math_ops.truediv(op.inputs[0], output / grad), None
def SmartBroadcastGradientArgs(x, y, grad):
"""Optimized version of `broadcast_gradient_args` that caches results.
This implementation avoids creating `broadcast_gradient_args` ops in the case
that the input shapes are fully defined, and provides hints to the calling
code that can be used to avoid creating reduction and reshaping ops.
Args:
x: The left input tensor to a broadcasting binary op.
y: The right input tensor to a broadcasting binary op.
grad: The incoming gradient tensor for a broadcasting binary op.
Returns:
A pair of tuples, containing:
* A 3-tuple of broadcast information for x, containing:
* The shape of x (as a tuple or Tensor).
* The reduction indices for x (as a tuple or Tensor).
* A boolean, which if True, indicates that x's shape differs from grad's
shape (and so x's gradient must be reduced and/or reshaped).
* A 3-tuple of broadcast information for y, containing the respective
details for y.
"""
# NOTE: It may be productive to apply these optimizations in the eager case
# as well.
if context.executing_eagerly() or not (
isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
and isinstance(grad, ops.Tensor)):
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
return (sx, rx, True), (sy, ry, True)
# pylint: disable=protected-access
x_shape_tuple = x._shape_tuple()
y_shape_tuple = y._shape_tuple()
grad_shape_tuple = grad._shape_tuple()
# pylint: enable=protected-access
if (x_shape_tuple is None or None in x_shape_tuple or
y_shape_tuple is None or None in y_shape_tuple):
sx = array_ops.shape_internal(x, optimize=False)
sy = array_ops.shape_internal(y, optimize=False)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
return (sx, rx, True), (sy, ry, True)
x_needs_reduction = x_shape_tuple != grad_shape_tuple
y_needs_reduction = y_shape_tuple != grad_shape_tuple
# Get the default graph rather than relying on `x.graph`, `y.graph`, or
# `grad.graph`, because these may be eager tensors.
g = ops.get_default_graph()
try:
rx, ry = g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] # pylint: disable=protected-access
return (x_shape_tuple, rx, x_needs_reduction), (
y_shape_tuple, ry, y_needs_reduction)
except KeyError:
rx, ry = array_ops.broadcast_gradient_args(x_shape_tuple, y_shape_tuple)
# TODO(mrry): If this becomes a bottleneck, add a multi-output version of
# `TF_TryEvaluateConstant()`.
rx_value = tuple(tensor_util.try_evaluate_constant(rx))
assert rx_value is not None
ry_value = tuple(tensor_util.try_evaluate_constant(ry))
assert ry_value is not None
g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] = ( # pylint: disable=protected-access
rx_value, ry_value)
return (x_shape_tuple, rx_value, x_needs_reduction), (
y_shape_tuple, ry_value, y_needs_reduction)
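# Illustrative sketch of the values produced above (hypothetical shapes, using
# the public `tf.raw_ops` API rather than this module's internals):
#   rx, ry = tf.raw_ops.BroadcastGradientArgs(s0=[3, 1], s1=[1, 4])
#   # rx == [1], ry == [0]: the incoming [3, 4] gradient is reduce-summed over
#   # axis 1 (and reshaped to [3, 1]) for x, and over axis 0 (reshaped to
#   # [1, 4]) for y. The boolean flags returned above simply record whether
#   # such a reduction/reshape is needed at all.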
_empty_tuple = ()
def _IsScalar(x):
return x._shape_tuple() is _empty_tuple # pylint: disable=protected-access
@ops.RegisterGradient("Sum")
def _SumGrad(op, grad):
"""Gradient for Sum."""
# Fast path for when reducing to a scalar and ndims is known: adds only
# Reshape and Tile ops (and possibly a Shape).
input_0_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access
if input_0_shape is not None:
axes = tensor_util.constant_value(op.inputs[1])
if axes is not None:
rank = len(input_0_shape)
if np.array_equal(axes, np.arange(rank)): # Reduce all dims.
if context.executing_eagerly():
ctx = context.context()
new_shape = ctx.ones_rank_cache().get(rank)
if new_shape is None:
new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
ctx.ones_rank_cache().put(rank, new_shape)
else:
new_shape = [1] * rank
grad = array_ops.reshape(grad, new_shape)
# If shape is not fully defined (but rank is), we use Shape.
if None not in input_0_shape:
input_shape = constant_op.constant(input_0_shape, dtype=dtypes.int32)
else:
input_shape = array_ops.shape(op.inputs[0])
return [array_ops.tile(grad, input_shape), None]
elif None not in input_0_shape and not context.executing_eagerly():
# The shape and reduction indices are statically known, so we use a
# graph-level cache to avoid recomputing `reduced_shape()` for each
# invocation.
graph = ops.get_default_graph()
# Canonicalize `axes` to be a tuple of indices. The incoming
# value may be a scalar or a vector, and may include negative indices.
axes = tuple(axes.reshape(-1))
try:
output_shape_kept_dims, tile_scaling = graph._reduced_shape_cache[ # pylint: disable=protected-access
(input_0_shape, axes)]
except KeyError:
# Compute and cache `output_shape_kept_dims` and `tile_scaling`.
def EvaluateAsTuple(t):
if tensor_util.is_tf_type(t):
value = tensor_util.try_evaluate_constant(t)
assert value is not None
else:
value = t
return tuple(value)
output_shape_kept_dims = EvaluateAsTuple(
math_ops.reduced_shape(input_0_shape, axes))
tile_scaling = EvaluateAsTuple(
_safe_shape_div(input_0_shape, output_shape_kept_dims))
graph._reduced_shape_cache[(input_0_shape, axes)] = ( # pylint:disable=protected-access
output_shape_kept_dims, tile_scaling)
grad = array_ops.reshape(grad, output_shape_kept_dims)
return [array_ops.tile(grad, tile_scaling), None]
input_shape = array_ops.shape(op.inputs[0])
if not op.get_attr("keep_dims"):
with ops.colocate_with(input_shape):
# TODO(apassos) remove this once device placement for eager ops makes
# more sense.
output_shape_kept_dims = math_ops.reduced_shape(input_shape,
op.inputs[1])
grad = array_ops.reshape(grad, output_shape_kept_dims)
return [array_ops.broadcast_to(grad, input_shape), None]
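# Illustrative sketch of the Sum gradient above (public `tf` API, eager mode;
# the variable names are hypothetical):
#   x = tf.constant([[1., 2.], [3., 4.]])
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = tf.reduce_sum(x, axis=1)        # shape [2]
#   tape.gradient(y, x)                   # ones of shape [2, 2]: the incoming
#                                         # gradient is broadcast back to the
#                                         # full input shape.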
def _MinOrMaxGrad(op, grad):
"""Gradient for Min or Max. Amazingly it's precisely the same code."""
input_shape = array_ops.shape(op.inputs[0])
y = op.outputs[0]
if not op.get_attr("keep_dims"):
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
y = array_ops.reshape(y, output_shape_kept_dims)
grad = array_ops.reshape(grad, output_shape_kept_dims)
else:
output_shape_kept_dims = array_ops.shape(y)
# Compute the number of selected (maximum or minimum) elements in each
# reduction dimension. If there are multiple minimum or maximum elements
# then the gradient will be divided between them.
indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype)
num_selected = array_ops.reshape(
math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims)
return [math_ops.divide(indicators, num_selected) * grad, None]
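# Illustrative sketch of the tie handling above (public `tf` API, eager mode;
# hypothetical values):
#   x = tf.constant([3., 3., 1.])
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = tf.reduce_max(x)
#   tape.gradient(y, x)                   # [0.5, 0.5, 0.]: the two tied maxima
#                                         # split the gradient evenly, the
#                                         # non-maximum entry gets zero.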
@ops.RegisterGradient("Max")
def _MaxGrad(op, grad):
"""Gradient for Max."""
return _MinOrMaxGrad(op, grad)
@ops.RegisterGradient("Min")
def _MinGrad(op, grad):
return _MinOrMaxGrad(op, grad)
@ops.RegisterGradient("Mean")
def _MeanGrad(op, grad):
"""Gradient for Mean."""
sum_grad = _SumGrad(op, grad)[0]
input_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access
output_shape = op.outputs[0]._shape_tuple() # pylint: disable=protected-access
if (input_shape is not None and output_shape is not None and
None not in input_shape and None not in output_shape):
input_size = np.prod(input_shape)
output_size = np.prod(output_shape)
factor = input_size // max(output_size, 1)
factor = constant_op.constant(factor, dtype=sum_grad.dtype)
else:
input_shape = array_ops.shape(op.inputs[0])
input_rank = array_ops.size(input_shape)
axes = (op.inputs[1] + input_rank) % input_rank
factor = math_ops.reduce_prod(array_ops.gather(input_shape, axes))
return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None
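# Illustrative sketch of the Mean gradient scaling above (public `tf` API,
# eager mode; hypothetical values):
#   x = tf.constant([[1., 2., 3.], [4., 5., 6.]])
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = tf.reduce_mean(x, axis=1)       # shape [2]
#   tape.gradient(y, x)                   # 1/3 everywhere: the Sum gradient
#                                         # divided by the reduction factor 3.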
@ops.RegisterGradient("Prod")
def _ProdGrad(op, grad):
"""Gradient for Prod."""
# The gradient can be expressed by dividing the product by each entry of the
# input tensor, but this approach can't deal with zeros in the input.
# Here, we avoid this problem by composing the output as a product of two
# cumprod operations.
input_shape = array_ops.shape(op.inputs[0])
# Reshape reduction indices for the case where the parameter is a scalar
reduction_indices = array_ops.reshape(op.inputs[1], [-1])
# Expand grad to full input shape
if not op.get_attr("keep_dims"):
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
grad = array_ops.reshape(grad, output_shape_kept_dims)
grad = array_ops.broadcast_to(grad, input_shape)
# Pack all reduced dimensions into a single one, so we can perform the
# cumprod ops. If the reduction-indices list is empty, its dtype defaults to
# float32, so we need to cast here. We put all the shape-related ops on CPU
# to avoid copying back and forth, and because list_diff is CPU-only.
with ops.device("/cpu:0"):
rank = array_ops.rank(op.inputs[0])
reduction_indices = (reduction_indices + rank) % rank
reduced = math_ops.cast(reduction_indices, dtypes.int32)
idx = math_ops.range(0, rank)
other, _ = gen_array_ops.list_diff(idx, reduced, dtypes.int32)
perm = array_ops.concat([reduced, other], 0)
reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
permuted = array_ops.transpose(op.inputs[0], perm)
permuted_shape = array_ops.shape(permuted)
reshaped = array_ops.reshape(permuted, (reduced_num, other_num))
# Calculate product, leaving out the current entry
left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
# For complex inputs, the gradient is in the conjugate direction.
y = array_ops.reshape(
math_ops.conj(left) * math_ops.conj(right), permuted_shape)
# Invert the transpose and reshape operations.
# Make sure to set the statically known shape information through a reshape.
out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
return array_ops.reshape(out, input_shape), None
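# Illustrative worked example of the exclusive-cumprod trick above
# (hypothetical values): for a reduced row [a, b, c],
#   left  = exclusive cumprod            -> [1,    a,   a*b]
#   right = reversed exclusive cumprod   -> [b*c,  c,   1  ]
# so left * right == [b*c, a*c, a*b], i.e. the product of all *other* entries,
# which is d(a*b*c)/d(entry) computed without dividing by a possibly-zero
# input.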
@ops.RegisterGradient("SegmentSum")
def _SegmentSumGrad(op, grad):
"""Gradient for SegmentSum."""
return array_ops.gather(grad, op.inputs[1]), None
@ops.RegisterGradient("SegmentMean")
def _SegmentMeanGrad(op, grad):
"""Gradient for SegmentMean."""
input_rank = array_ops.rank(op.inputs[0])
ones_shape = array_ops.concat([
array_ops.shape(op.inputs[1]),
array_ops.ones(
array_ops.expand_dims(input_rank - 1, 0), dtype=dtypes.int32)
], 0)
ones = array_ops.ones(ones_shape, dtype=grad.dtype)
scaled_grad = math_ops.divide(grad, math_ops.segment_sum(ones, op.inputs[1]))
return array_ops.gather(scaled_grad, op.inputs[1]), None
@ops.RegisterGradient("SparseSegmentSum")
def _SparseSegmentSumGrad(op, grad):
"""Gradient for SparseSegmentSum."""
dim0 = array_ops.shape(op.inputs[0])[0]
if compat.forward_compatible(2021, 6, 10):
return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
dim0), None, None)
else:
return (math_ops.unsorted_segment_sum(
array_ops.gather(grad, op.inputs[2]), op.inputs[1], dim0), None, None)
@ops.RegisterGradient("SparseSegmentSumWithNumSegments")
def _SparseSegmentSumWithNumSegmentsGrad(op, grad):
"""Gradient for SparseSegmentSumWithNumSegments."""
dim0 = array_ops.shape(op.inputs[0])[0]
if compat.forward_compatible(2021, 6, 10):
return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
dim0), None, None, None)
else:
return (math_ops.unsorted_segment_sum(
array_ops.gather(grad, op.inputs[2]), op.inputs[1],
dim0), None, None, None)
@ops.RegisterGradient("SparseSegmentMean")
def _SparseSegmentMeanGrad(op, grad):
"""Gradient for SparseSegmentMean."""
dim0 = array_ops.shape(op.inputs[0])[0]
return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
dim0), None, None)
@ops.RegisterGradient("SparseSegmentMeanWithNumSegments")
def _SparseSegmentMeanWithNumSegmentsGrad(op, grad):
"""Gradient for SparseSegmentMeanWithNumSegments."""
dim0 = array_ops.shape(op.inputs[0])[0]
return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
dim0), None, None, None)
@ops.RegisterGradient("SparseSegmentSqrtN")
def _SparseSegmentSqrtNGrad(op, grad):
"""Gradient for SparseSegmentSqrtN."""
dim0 = array_ops.shape(op.inputs[0])[0]
return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
dim0), None, None)
@ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments")
def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad):
"""Gradient for SparseSegmentSqrtNWithNumSegments."""
dim0 = array_ops.shape(op.inputs[0])[0]
return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
dim0), None, None, None)
def _SegmentMinOrMaxGrad(op, grad):
""" Gradient for SegmentMin and SegmentMax. """
zeros = array_ops.zeros_like(op.inputs[0], dtype=op.inputs[0].dtype)
# Get the number of selected (minimum or maximum) elements in each segment.
gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
num_selected = math_ops.segment_sum(
math_ops.cast(is_selected, grad.dtype), op.inputs[1])
# Compute the gradient for each segment. The gradient for the ith segment is
# divided evenly among the selected elements in that segment.
weighted_grads = math_ops.divide(grad, num_selected)
gathered_grads = array_ops.gather(weighted_grads, op.inputs[1])
return array_ops.where_v2(is_selected, gathered_grads, zeros), None
@ops.RegisterGradient("SegmentMin")
def _SegmentMinGrad(op, grad):
"""Gradient for SegmentMin."""
return _SegmentMinOrMaxGrad(op, grad)
@ops.RegisterGradient("SegmentMax")
def _SegmentMaxGrad(op, grad):
"""Gradient for SegmentMax."""
return _SegmentMinOrMaxGrad(op, grad)
@ops.RegisterGradient("SegmentProd")
def _SegmentProdGrad(op, grad):
"""Gradient for SegmentProd.
The gradient can be expressed for each segment by dividing the segment's
product by each element of the segment input tensor, but this approach can't
deal with zeros in the input.
Unlike reduce_prod we can't use cumsum here as individual segments may have
a different number of elements. Therefore we consider three cases:
1) A segment input contains no zeros and we can safely divide by the input
tensor.
2) A segment contains exactly one zero. Then the gradient of each input of
the segment is zero except for the 0-input, there the gradient is
the product of the remaining segment entries.
3) A segment contains at least two zeros. The gradient is zero for all
segment inputs.
"""
data = op.inputs[0]
segment_ids = op.inputs[1]
is_zero = math_ops.equal(data, 0)
num_zeros = gen_math_ops.segment_sum(
math_ops.cast(is_zero, dtype=dtypes.int32), segment_ids)
# handle case 3 and set the gradient to 0 for segments with more than one
# 0 as input
grad = array_ops.where_v2(
math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
# replace all zeros with ones and compute the segment_prod
non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(data), data)
non_zero_prod = gen_math_ops.segment_prod(non_zero_data, segment_ids)
gathered_prod = array_ops.gather(op.outputs[0], segment_ids)
gathered_non_zero_prod = array_ops.gather(non_zero_prod, segment_ids)
prod_divided_by_el = gathered_prod / non_zero_data
# Now fetch the individual results for segments containing 0 and those that
# don't.
partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod,
prod_divided_by_el)
gathered_grad = array_ops.gather(grad, segment_ids)
return gathered_grad * partial_derivative, None
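# Illustrative sketch of the three cases above (hypothetical values), assuming
# an incoming gradient of ones:
#   data = [2., 0., 3., 0., 0.],  segment_ids = [0, 1, 1, 2, 2]
#   segment 0 ([2.])     : no zeros  -> gradient [1.]
#   segment 1 ([0., 3.]) : one zero  -> gradient [3., 0.] (the zero entry gets
#                                       the product of the rest; the rest get 0)
#   segment 2 ([0., 0.]) : two zeros -> gradient [0., 0.]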
def _GatherDropNegatives(params,
ids,
zero_clipped_indices=None,
is_positive=None):
""" Helper function for unsorted segment ops.
Gathers params for
positive segment ids and gathers 0 for inputs with negative segment id.
Also returns the clipped indices and a boolean mask with the same shape
as ids where a positive id is masked as true. With this, the latter two
can be passed as arguments to this function to reuse them.
"""
if zero_clipped_indices is None:
zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids))
gathered = array_ops.gather(params, zero_clipped_indices)
if is_positive is None:
is_positive = math_ops.greater_equal(ids, 0)
# tf.where(condition, x, y) requires condition to have the same shape as x
# and y.
is_positive_shape = array_ops.shape(is_positive)
broadcastable_shape = array_ops.concat(
[is_positive_shape,
array_ops.ones([array_ops.rank(gathered)
- array_ops.rank(is_positive)],
dtype=is_positive_shape.dtype)],
axis=0)
is_positive = array_ops.reshape(is_positive, broadcastable_shape)
is_positive = (
is_positive & array_ops.ones_like(gathered, dtype=dtypes.bool))
# replace gathered params of negative indices with 0
zero_slice = array_ops.zeros_like(gathered)
return (array_ops.where_v2(is_positive, gathered,
zero_slice), zero_clipped_indices, is_positive)
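# Illustrative sketch of the helper above (hypothetical values):
#   params = [[10.], [20.]],  ids = [1, -1, 0]
#   zero_clipped_indices == [1, 0, 0], and the gathered result is
#   [[20.], [0.], [10.]]: the negative id is clipped to 0 for the gather and
#   its row is then zeroed out via the is_positive mask.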
def _UnsortedSegmentMinOrMaxGrad(op, grad):
""" Gradient for UnsortedSegmentMin and UnsortedSegmentMax. """
# Get the number of selected (minimum or maximum) elements in each segment.
gathered_outputs, zero_clipped_indices, is_positive = \
_GatherDropNegatives(op.outputs[0], op.inputs[1])
is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
is_selected = math_ops.logical_and(is_selected, is_positive)
num_selected = math_ops.unsorted_segment_sum(
math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2])
# Compute the gradient for each segment. The gradient for the ith segment is
# divided evenly among the selected elements in that segment.
weighted_grads = math_ops.divide(grad, num_selected)
gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
zero_clipped_indices, is_positive)
zeros = array_ops.zeros_like(gathered_grads)
return array_ops.where_v2(is_selected, gathered_grads, zeros), None, None
@ops.RegisterGradient("UnsortedSegmentSum")
def _UnsortedSegmentSumGrad(op, grad):
"""Gradient for UnsortedSegmentSum."""
return _GatherDropNegatives(grad, op.inputs[1])[0], None, None
@ops.RegisterGradient("UnsortedSegmentMax")
def _UnsortedSegmentMaxGrad(op, grad):
""" Gradient for UnsortedSegmentMax. """
return _UnsortedSegmentMinOrMaxGrad(op, grad)
@ops.RegisterGradient("UnsortedSegmentMin")
def _UnsortedSegmentMinGrad(op, grad):
""" Gradient for UnsortedSegmentMin. """
return _UnsortedSegmentMinOrMaxGrad(op, grad)
@ops.RegisterGradient("UnsortedSegmentProd")
def _UnsortedSegmentProdGrad(op, grad):
""" Gradient for UnsortedSegmentProd.
The gradient can be expressed for each segment by dividing the segment's
product by each element of the segment input tensor, but this approach can't
deal with zeros in the input.
Unlike reduce_prod we can't use cumsum here as individual segments may have
a different number of elements. Therefore we consider three cases:
1) A segment input contains no zeros and we can safely divide by the input
tensor.
2) A segment contains exactly one zero. Then the gradient of each input of
the segment is zero except for the 0-input, there the gradient is
the product of the remaining segment entries.
3) A segment contains at least two zeros. The gradient is zero for all
segment inputs.
"""
# Note that unsorted_segment_sum will filter out the negative indices,
# so we don't need to do a logical_and with is_positive here
is_zero = math_ops.equal(op.inputs[0], 0)
num_zeros = gen_math_ops.unsorted_segment_sum(
math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2])
# handle case 3 and set the gradient to 0 for segments with more than one
# 0 as input
grad = array_ops.where_v2(
math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
# replace all zeros with ones and compute the unsorted_segment_prod
non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(op.inputs[0]),
op.inputs[0])
non_zero_prod = gen_math_ops.unsorted_segment_prod(non_zero_data,
op.inputs[1], op.inputs[2])
# clip the indices for gather to be positive
zero_clipped_indices = math_ops.maximum(op.inputs[1],
array_ops.zeros_like(op.inputs[1]))
gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices)
gathered_non_zero_prod = array_ops.gather(non_zero_prod, zero_clipped_indices)
prod_divided_by_el = gathered_prod / op.inputs[0] # May contain nan/inf.
# Now fetch the individual results for segments containing 0 and those that
# don't. is_zero will also fetch results for entries with negative index
# but the following gather_drop_negatives sets the corresponding entry in
# grad to 0 for these
partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod,
prod_divided_by_el)
gathered_grad = _GatherDropNegatives(grad, op.inputs[1],
zero_clipped_indices)[0]
return gathered_grad * partial_derivative, None, None
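# Illustrative sketch (hypothetical values) emphasising the negative-index
# handling above, with num_segments = 2 and an incoming gradient of ones:
#   data = [0., 5., 7.],  segment_ids = [0, 0, -1]
#   segment 0 ([0., 5.]) has exactly one zero -> gradients [5., 0.]
#   the entry with segment id -1 never contributes, so its gradient is 0.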
@ops.RegisterGradient("Abs")
def _AbsGrad(op, grad):
x = op.inputs[0]
return grad * math_ops.sign(x)
@ops.RegisterGradient("Neg")
def _NegGrad(_, grad):
"""Returns -grad."""
return -grad
@ops.RegisterGradient("Inv")
def _InvGrad(op, grad):
"""Returns -grad * (1 / x^2)."""
y = op.outputs[0] # y = 1 / x
return gen_math_ops.reciprocal_grad(y, grad)
@ops.RegisterGradient("Reciprocal")
def _ReciprocalGrad(op, grad):
"""Returns -grad * (1 / x^2)."""
y = op.outputs[0] # y = 1 / x
return gen_math_ops.reciprocal_grad(y, grad)
@ops.RegisterGradient("InvGrad")
def _InvGradGrad(op, grad):
b = op.inputs[1]
# op.output[0]: y = -b * conj(a)^2
with ops.control_dependencies([grad]):
ca = math_ops.conj(op.inputs[0])
cg = math_ops.conj(grad)
return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad)
@ops.RegisterGradient("ReciprocalGrad")
def _ReciprocalGradGrad(op, grad):
b = op.inputs[1]
# op.output[0]: y = -b * conj(a)^2
with ops.control_dependencies([grad]):
ca = math_ops.conj(op.inputs[0])
cg = math_ops.conj(grad)
return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad)
@ops.RegisterGradient("Square")
def _SquareGrad(op, grad):
x = op.inputs[0]
# Added control dependencies to prevent 2*x from being computed too early.
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
y = constant_op.constant(2.0, dtype=x.dtype)
return math_ops.multiply(grad, math_ops.multiply(x, y))
@ops.RegisterGradient("Sqrt")
def _SqrtGrad(op, grad):
y = op.outputs[0] # y = x^(1/2)
return gen_math_ops.sqrt_grad(y, grad)
@ops.RegisterGradient("SqrtGrad")
def _SqrtGradGrad(op, grad):
a = op.inputs[0]
y = op.outputs[0] # y = 0.5 * b / conj(a)
with ops.control_dependencies([grad]):
ga = grad / a
return -math_ops.conj(ga) * y, 0.5 * ga # pylint: disable=invalid-unary-operand-type
@ops.RegisterGradient("Rsqrt")
def _RsqrtGrad(op, grad):
"""Returns -0.5 * grad * conj(y)^3."""
y = op.outputs[0] # y = x^(-1/2)
return gen_math_ops.rsqrt_grad(y, grad)
@ops.RegisterGradient("RsqrtGrad")
def _RsqrtGradGrad(op, grad):
"""Returns backprop gradient for f(a,b) = -0.5 * b * conj(a)^3."""
a = op.inputs[0] # a = x^{-1/2}
b = op.inputs[1] # backprop gradient for a
with ops.control_dependencies([grad]):
ca = math_ops.conj(a)
cg = math_ops.conj(grad)
grad_a = -1.5 * cg * b * math_ops.square(ca)
grad_b = gen_math_ops.rsqrt_grad(ca, grad)
return grad_a, grad_b
@ops.RegisterGradient("Exp")
def _ExpGrad(op, grad):
"""Returns grad * exp(x)."""
y = op.outputs[0] # y = e^x
with ops.control_dependencies([grad]):
y = math_ops.conj(y)
return grad * y
@ops.RegisterGradient("Expm1")
def _Expm1Grad(op, grad):
"""Returns grad * exp(x)."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
y = math_ops.exp(x)
return grad * y
@ops.RegisterGradient("Log")
def _LogGrad(op, grad):
"""Returns grad * (1/x)."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
return grad * math_ops.reciprocal(x)
@ops.RegisterGradient("Log1p")
def _Log1pGrad(op, grad):
"""Returns grad * (1/(1 + x))."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
return grad * math_ops.reciprocal(1 + x)
@ops.RegisterGradient("Xlogy")
def _XLogyGrad(op, grad):
"""Returns gradient of xlogy(x, y) with respect to x and y."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
with ops.control_dependencies([grad]):
not_zero_x = math_ops.cast(
math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
partial_x = gen_math_ops.xlogy(not_zero_x, y)
partial_y = gen_math_ops.xdivy(x, y)
return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
@ops.RegisterGradient("Xlog1py")
def _XLog1pyGrad(op, grad):
"""Returns gradient of xlog1py(x, y) with respect to x and y."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
with ops.control_dependencies([grad]):
not_zero_x = math_ops.cast(
math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
partial_x = gen_math_ops.xlog1py(not_zero_x, y)
partial_y = gen_math_ops.xdivy(x, y + 1.)
return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
@ops.RegisterGradient("Xdivy")
def _XDivyGrad(op, grad):
"""Returns gradient of xdivy(x, y) with respect to x and y."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
with ops.control_dependencies([grad]):
not_zero_x = math_ops.cast(
math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
partial_x = gen_math_ops.xdivy(not_zero_x, y)
partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2)
return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
@ops.RegisterGradient("Sinh")
def _SinhGrad(op, grad):
"""Returns grad * cosh(x)."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
return grad * math_ops.cosh(x)
@ops.RegisterGradient("Cosh")
def _CoshGrad(op, grad):
"""Returns grad * sinh(x)."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
return grad * math_ops.sinh(x)
@ops.RegisterGradient("Tanh")
def _TanhGrad(op, grad):
"""Returns grad * (1 - tanh(x) * tanh(x))."""
y = op.outputs[0] # y = tanh(x)
with ops.control_dependencies([grad]):
y = math_ops.conj(y)
return gen_math_ops.tanh_grad(y, grad)
@ops.RegisterGradient("Asinh")
def _AsinhGrad(op, grad):
"""Returns grad * 1/cosh(y)."""
y = op.outputs[0]
with ops.control_dependencies([grad]):
y = math_ops.conj(y)
return grad / math_ops.cosh(y)
@ops.RegisterGradient("Acosh")
def _AcoshGrad(op, grad):
"""Returns grad * 1/sinh(y)."""
y = op.outputs[0]
with ops.control_dependencies([grad]):
y = math_ops.conj(y)
return grad / math_ops.sinh(y)
@ops.RegisterGradient("Atanh")
def _AtanhGrad(op, grad):
"""Returns grad * 1/ (1 - x^2)."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
x2 = math_ops.square(x)
one = constant_op.constant(1, dtype=grad.dtype)
inv = math_ops.reciprocal(math_ops.subtract(one, x2))
return grad * inv
@ops.RegisterGradient("TanhGrad")
def _TanhGradGrad(op, grad):
with ops.control_dependencies([grad]):
a = math_ops.conj(op.inputs[0])
b = math_ops.conj(op.inputs[1])
return grad * -2.0 * b * a, gen_math_ops.tanh_grad(a, grad)
@ops.RegisterGradient("Erf")
def _ErfGrad(op, grad):
"""Returns grad * 2/sqrt(pi) * exp(-x**2)."""
x = op.inputs[0]
two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype)
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
return grad * two_over_root_pi * math_ops.exp(-math_ops.square(x))
@ops.RegisterGradient("Erfc")
def _ErfcGrad(op, grad):
"""Returns -grad * 2/sqrt(pi) * exp(-x**2)."""
x = op.inputs[0]
minus_two_over_root_pi = constant_op.constant(
-2 / np.sqrt(np.pi), dtype=grad.dtype)
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
return grad * minus_two_over_root_pi * math_ops.exp(-math_ops.square(x))
@ops.RegisterGradient("Erfinv")
def _ErfinvGrad(op, grad):
"""Returns grad * sqrt(pi) / 2 * exp(erfinv(x)**2)."""
root_pi_over_two = constant_op.constant(np.sqrt(np.pi) / 2, dtype=grad.dtype)
with ops.control_dependencies([grad]):
return grad * root_pi_over_two * math_ops.exp(
math_ops.square(op.outputs[0]))
@ops.RegisterGradient("Ndtri")
def _NdtriGrad(op, grad):
"""Returns grad * sqrt(2 * pi) * exp(ndtri(x)**2 / 2)."""
root_two_pi = constant_op.constant(np.sqrt(2 * np.pi), dtype=grad.dtype)
with ops.control_dependencies([grad]):
return grad * root_two_pi * math_ops.exp(
math_ops.square(op.outputs[0]) / 2.)
@ops.RegisterGradient("Lgamma")
def _LgammaGrad(op, grad):
"""Returns grad * digamma(x)."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
return grad * math_ops.digamma(x)
@ops.RegisterGradient("Digamma")
def _DigammaGrad(op, grad):
"""Compute gradient of the digamma function with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
partial_x = math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)
return grad * partial_x
@ops.RegisterGradient("Dawsn")
def _DawsnGrad(op, grad):
"""Compute gradient of dawsn(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
return grad * (1. - 2 * x * y)
@ops.RegisterGradient("Expint")
def _ExpintGrad(op, grad):
"""Compute gradient of expint(x) with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
return grad * math_ops.exp(x) / x
@ops.RegisterGradient("FresnelCos")
def _FresnelCosGrad(op, grad):
"""Compute gradient of fresnel_cos(x) with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
return grad * math_ops.cos((np.pi / 2.) * math_ops.square(x))
@ops.RegisterGradient("FresnelSin")
def _FresnelSinGrad(op, grad):
"""Compute gradient of fresnel_sin(x) with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
return grad * math_ops.sin((np.pi / 2.) * math_ops.square(x))
@ops.RegisterGradient("Spence")
def _SpenceGrad(op, grad):
"""Compute gradient of spence(x) with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
partial_x = math_ops.log(x) / (1 - x)
partial_x = array_ops.where(
math_ops.equal(x, 1.), -array_ops.ones_like(x), partial_x) # pylint: disable=invalid-unary-operand-type
return grad * partial_x
@ops.RegisterGradient("BesselI0")
def _BesselI0Grad(op, grad):
"""Compute gradient of bessel_i0(x) with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
partial_x = special_math_ops.bessel_i1(x)
return grad * partial_x
@ops.RegisterGradient("BesselI0e")
def _BesselI0eGrad(op, grad):
"""Compute gradient of bessel_i0e(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
partial_x = (special_math_ops.bessel_i1e(x) - math_ops.sign(x) * y)
return grad * partial_x
@ops.RegisterGradient("BesselI1")
def _BesselI1Grad(op, grad):
"""Compute gradient of bessel_i1(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
# For x = 0, the correct gradient is 1.0.
# However, the main branch gives NaN because of the division by x, so
# we impute the gradient manually.
# An alternative solution is to express the gradient via bessel_i0 and
# bessel_i2, but the latter is not yet implemented in Eigen.
dy_dx = array_ops.where_v2(
math_ops.equal(x, 0.), math_ops.cast(1., x.dtype),
special_math_ops.bessel_i0(x) - math_ops.div(y, x))
return grad * dy_dx
@ops.RegisterGradient("BesselI1e")
def _BesselI1eGrad(op, grad):
"""Compute gradient of bessel_i1e(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
# For x = 0, the correct gradient is 0.5.
# However, the main branch gives NaN because of the division by x, so
# we impute the gradient manually.
# An alternative solution is to express the gradient via bessel_i0e and
# bessel_i2e, but the latter is not yet implemented in Eigen.
dy_dx = array_ops.where_v2(
math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
special_math_ops.bessel_i0e(x) - y *
(math_ops.sign(x) + math_ops.reciprocal(x)))
return grad * dy_dx
@ops.RegisterGradient("BesselK0")
def _BesselK0Grad(op, grad):
"""Compute gradient of bessel_k0(x) with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
partial_x = -special_math_ops.bessel_k1(x)
return grad * partial_x
@ops.RegisterGradient("BesselK0e")
def _BesselK0eGrad(op, grad):
"""Compute gradient of bessel_k0e(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
partial_x = (y - special_math_ops.bessel_k1e(x))
return grad * partial_x
@ops.RegisterGradient("BesselK1")
def _BesselK1Grad(op, grad):
"""Compute gradient of bessel_k1(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
# At 0., this is NaN which is fine since the derivative is undefined
# at 0.
partial_x = -special_math_ops.bessel_k0(x) - math_ops.div(y, x)
return grad * partial_x
@ops.RegisterGradient("BesselK1e")
def _BesselK1eGrad(op, grad):
"""Compute gradient of bessel_k1e(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
# At 0., this is NaN which is fine since the derivative is undefined
# at 0.
partial_x = (
y * (1. - math_ops.reciprocal(x)) - special_math_ops.bessel_k0e(x))
return grad * partial_x
@ops.RegisterGradient("BesselJ0")
def _BesselJ0Grad(op, grad):
"""Compute gradient of bessel_j0(x) with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
partial_x = -special_math_ops.bessel_j1(x)
return grad * partial_x
@ops.RegisterGradient("BesselJ1")
def _BesselJ1Grad(op, grad):
"""Compute gradient of bessel_j1(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
# For x = 0, the correct gradient is 0.5.
# However, the main branch gives NaN because of the division by x, so
# we impute the gradient manually.
# An alternative solution is to express the gradient via bessel_j0 and
# bessel_j2, but the latter is not yet implemented in Eigen.
dy_dx = array_ops.where_v2(
math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
special_math_ops.bessel_j0(x) - math_ops.div(y, x))
return grad * dy_dx
@ops.RegisterGradient("BesselY0")
def _BesselY0Grad(op, grad):
"""Compute gradient of bessel_y0(x) with respect to its argument."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
partial_x = -special_math_ops.bessel_y1(x)
return grad * partial_x
@ops.RegisterGradient("BesselY1")
def _BesselY1Grad(op, grad):
"""Compute gradient of bessel_y1(x) with respect to its argument."""
x = op.inputs[0]
y = op.outputs[0]
with ops.control_dependencies([grad]):
# At 0., this is NaN which is fine since the derivative is undefined
# at 0.
partial_x = special_math_ops.bessel_y0(x) - math_ops.div(y, x)
return grad * partial_x
@ops.RegisterGradient("Igamma")
def _IgammaGrad(op, grad):
"""Returns gradient of igamma(a, x) with respect to a and x."""
a = op.inputs[0]
x = op.inputs[1]
sa = array_ops.shape(a)
sx = array_ops.shape(x)
ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx)
with ops.control_dependencies([grad]):
partial_a = gen_math_ops.igamma_grad_a(a, x)
# Perform operations in log space before summing, because Gamma(a)
# and Gamma'(a) can grow large.
partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) -
math_ops.lgamma(a))
return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa),
array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
@ops.RegisterGradient("Igammac")
def _IgammacGrad(op, grad):
"""Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x."""
igamma_grad_a, igamma_grad_x = _IgammaGrad(op, grad)
return (-igamma_grad_a, -igamma_grad_x)
@ops.RegisterGradient("Betainc")
def _BetaincGrad(op, grad):
"""Returns gradient of betainc(a, b, x) with respect to x."""
# TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b
a, b, x = op.inputs
# Two cases: x is a scalar and a/b are same-shaped tensors, or vice
# versa; so it's sufficient to check against shape(a).
sa = array_ops.shape(a)
sx = array_ops.shape(x)
_, rx = gen_array_ops.broadcast_gradient_args(sa, sx)
# Perform operations in log space before summing, because terms
# can grow large.
log_beta = (
gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) -
gen_math_ops.lgamma(a + b))
# We use xlog1py and xlogy since the derivatives should tend to
# zero in one of the tails when a is 1. or b is 1.
partial_x = math_ops.exp(math_ops.xlog1py(b - 1, -x) +
math_ops.xlogy(a - 1, x) - log_beta)
return (
None, # da
None, # db
array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
@ops.RegisterGradient("Zeta")
def _ZetaGrad(op, grad):
"""Returns gradient of zeta(x, q) with respect to x and q."""
# TODO(tillahoffmann): Add derivative with respect to x
x = op.inputs[0]
q = op.inputs[1]
# Broadcast gradients
sx = array_ops.shape(x)
sq = array_ops.shape(q)
unused_rx, rq = gen_array_ops.broadcast_gradient_args(sx, sq)
# Evaluate gradient
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
q = math_ops.conj(q)
partial_q = -x * math_ops.zeta(x + 1, q) # pylint: disable=invalid-unary-operand-type
return (None,
array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq))
@ops.RegisterGradient("Polygamma")
def _PolygammaGrad(op, grad):
"""Returns gradient of psi(n, x) with respect to n and x."""
# TODO(tillahoffmann): Add derivative with respect to n
n = op.inputs[0]
x = op.inputs[1]
# Broadcast gradients
sn = array_ops.shape(n)
sx = array_ops.shape(x)
unused_rn, rx = gen_array_ops.broadcast_gradient_args(sn, sx)
# Evaluate gradient
with ops.control_dependencies([grad]):
n = math_ops.conj(n)
x = math_ops.conj(x)
partial_x = math_ops.polygamma(n + 1, x)
return (None,
array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
@ops.RegisterGradient("Sigmoid")
def _SigmoidGrad(op, grad):
"""Returns grad * sigmoid(x) * (1 - sigmoid(x))."""
y = op.outputs[0] # y = sigmoid(x)
with ops.control_dependencies([grad]):
y = math_ops.conj(y)
return gen_math_ops.sigmoid_grad(y, grad)
@ops.RegisterGradient("SigmoidGrad")
def _SigmoidGradGrad(op, grad):
with ops.control_dependencies([grad]):
a = math_ops.conj(op.inputs[0])
b = math_ops.conj(op.inputs[1])
gb = grad * b
return gb - 2.0 * gb * a, gen_math_ops.sigmoid_grad(a, grad)
@ops.RegisterGradient("Sign")
def _SignGrad(op, _):
"""Returns 0."""
x = op.inputs[0]
return array_ops.zeros_like(x)
@ops.RegisterGradient("Sin")
def _SinGrad(op, grad):
"""Returns grad * cos(x)."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
return grad * math_ops.cos(x)
@ops.RegisterGradient("Cos")
def _CosGrad(op, grad):
"""Returns grad * -sin(x)."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
return -grad * math_ops.sin(x)
@ops.RegisterGradient("Tan")
def _TanGrad(op, grad):
"""Returns grad * 1/sec^2(x)."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
secx = math_ops.reciprocal(math_ops.cos(x))
secx2 = math_ops.square(secx)
return secx2 * grad
@ops.RegisterGradient("Asin")
def _AsinGrad(op, grad):
"""Returns grad * 1/sqrt(1-x^2)."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
x2 = math_ops.square(x)
one = constant_op.constant(1, dtype=grad.dtype)
den = math_ops.sqrt(math_ops.subtract(one, x2))
inv = math_ops.reciprocal(den)
return grad * inv
@ops.RegisterGradient("Acos")
def _AcosGrad(op, grad):
"""Returns grad * -1/sqrt(1-x^2)."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
x2 = math_ops.square(x)
one = constant_op.constant(1, dtype=grad.dtype)
den = math_ops.sqrt(math_ops.subtract(one, x2))
inv = math_ops.reciprocal(den)
return -grad * inv
@ops.RegisterGradient("Atan")
def _AtanGrad(op, grad):
"""Returns grad * 1/ (1 + x^2)."""
x = op.inputs[0]
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
x2 = math_ops.square(x)
one = constant_op.constant(1, dtype=grad.dtype)
inv = math_ops.reciprocal(math_ops.add(one, x2))
return grad * inv
@ops.RegisterGradient("Atan2")
def _Atan2Grad(op, grad):
"""Returns grad * x / (x^2 + y^2), grad * -y / (x^2 + y^2)."""
y = op.inputs[0]
x = op.inputs[1]
with ops.control_dependencies([grad]):
(sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
SmartBroadcastGradientArgs(x, y, grad)
)
grad_inv = grad / (math_ops.square(x) + math_ops.square(y))
gx = -y * grad_inv
if must_reduce_x:
gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
gy = x * grad_inv
if must_reduce_y:
gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)
return gy, gx
@ops.RegisterGradient("AddN")
def _AddNGrad(op, grad):
"""Copies the gradient to all inputs."""
# Not broadcasting.
return [grad] * len(op.inputs)
def _ShapesFullySpecifiedAndEqual(x, y, grad):
# pylint: disable=protected-access
x_shape = x._shape_tuple()
y_shape = y._shape_tuple()
grad_shape = grad._shape_tuple()
# pylint: enable=protected-access
return (x_shape == y_shape and x_shape == grad_shape and
x_shape is not None and None not in x_shape)
@ops.RegisterGradient("Add")
@ops.RegisterGradient("AddV2")
def _AddGrad(op, grad):
"""Gradient for Add."""
y = op.inputs[1]
skip_input_indices = None
try:
skip_input_indices = op.skip_input_indices
if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
y):
return grad, None
except AttributeError:
# No gradient skipping, so do the full gradient computation
pass
x = op.inputs[0]
if (isinstance(grad, ops.Tensor) and
_ShapesFullySpecifiedAndEqual(x, y, grad)):
return grad, grad
(sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
SmartBroadcastGradientArgs(x, y, grad))
if skip_input_indices is not None and 0 in skip_input_indices:
gx = None
elif not must_reduce_x:
gx = grad
else:
gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
if skip_input_indices is not None and 1 in skip_input_indices:
gy = None
elif not must_reduce_y:
gy = grad
else:
gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)
return (gx, gy)
@ops.RegisterGradient("Sub")
def _SubGrad(op, grad):
"""Gradient for Sub."""
y = op.inputs[1]
skip_input_indices = None
try:
skip_input_indices = op.skip_input_indices
if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
y):
return grad, None
except AttributeError:
# No gradient skipping, so do the full gradient computation
pass
x = op.inputs[0]
if (isinstance(grad, ops.Tensor) and
_ShapesFullySpecifiedAndEqual(x, y, grad)):
return grad, -grad
(sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
SmartBroadcastGradientArgs(x, y, grad))
if skip_input_indices is not None and 0 in skip_input_indices:
gx = None
elif not must_reduce_x:
gx = grad
else:
gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
if skip_input_indices is not None and 1 in skip_input_indices:
gy = None
elif not must_reduce_y:
gy = -grad
else:
gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy)
return (gx, gy)
@ops.RegisterGradient("Mul")
def _MulGrad(op, grad):
"""The gradient of scalar multiplication."""
y = op.inputs[1]
skip_input_indices = None
try:
skip_input_indices = op.skip_input_indices
if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
y):
return gen_math_ops.mul(grad, math_ops.conj(y)), None
except AttributeError:
# No gradient skipping, so do the full gradient computation
pass
x = op.inputs[0]
if (isinstance(grad, ops.Tensor) and
_ShapesFullySpecifiedAndEqual(x, y, grad) and
grad.dtype in (dtypes.int32, dtypes.float32)):
return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
(sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
SmartBroadcastGradientArgs(x, y, grad))
x = math_ops.conj(x)
y = math_ops.conj(y)
if skip_input_indices is not None and 0 in skip_input_indices:
gx = None
elif not must_reduce_x:
gx = gen_math_ops.mul(grad, y)
else:
gx = array_ops.reshape(
math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx)
if skip_input_indices is not None and 1 in skip_input_indices:
gy = None
elif not must_reduce_y:
gy = gen_math_ops.mul(x, grad)
else:
gy = array_ops.reshape(
math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy)
return (gx, gy)
@ops.RegisterGradient("MulNoNan")
def _MulNoNanGrad(op, grad):
"""The gradient of scalar multiplication with NaN-suppression."""
x = op.inputs[0]
y = op.inputs[1]
if (isinstance(grad, ops.Tensor) and
_ShapesFullySpecifiedAndEqual(x, y, grad)):
return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad)
assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
return (array_ops.reshape(
math_ops.reduce_sum(gen_math_ops.mul_no_nan(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry), sy))
@ops.RegisterGradient("Div")
def _DivGrad(op, grad):
"""The gradient for the Div operator."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
x = math_ops.conj(x)
y = math_ops.conj(y)
# pylint: disable=invalid-unary-operand-type
return (
array_ops.reshape(math_ops.reduce_sum(math_ops.divide(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(grad * math_ops.divide(math_ops.divide(-x, y), y),
ry), sy))
@ops.RegisterGradient("FloorDiv")
def _FloorDivGrad(_, unused_grad):
"""The gradient for the FloorDiv operator."""
return None, None
@ops.RegisterGradient("FloorMod")
def _FloorModGrad(op, grad):
"""Returns grad * (1, -floor(x/y))."""
x = math_ops.conj(op.inputs[0])
y = math_ops.conj(op.inputs[1])
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
floor_xy = math_ops.floor_div(x, y)
gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
gy = array_ops.reshape(
math_ops.reduce_sum(grad * math_ops.negative(floor_xy), ry), sy)
return gx, gy
@ops.RegisterGradient("TruncateDiv")
def _TruncateDivGrad(_, unused_grad):
return None, None
@ops.RegisterGradient("RealDiv")
def _RealDivGrad(op, grad):
"""RealDiv op gradient."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
x = math_ops.conj(x)
y = math_ops.conj(y)
return (array_ops.reshape(
math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(
grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy)) # pylint: disable=invalid-unary-operand-type
@ops.RegisterGradient("DivNoNan")
def _DivNoNanGrad(op, grad):
"""DivNoNan op gradient."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
x = math_ops.conj(x)
y = math_ops.conj(y)
return (
array_ops.reshape(
math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(
grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), # pylint: disable=invalid-unary-operand-type
ry),
sy))
@ops.RegisterGradient("Pow")
def _PowGrad(op, grad):
"""Returns grad * (y*x^(y-1), z*log(x))."""
x = op.inputs[0]
y = op.inputs[1]
skip_input_indices = None
try:
skip_input_indices = op.skip_input_indices
# TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the
# constant `1` into a single constant op.
if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
y):
x = math_ops.conj(x)
y = math_ops.conj(y)
return grad * y * math_ops.pow(x, y - 1), None
except AttributeError:
# No gradient skipping, so do the full gradient computation
pass
(sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
SmartBroadcastGradientArgs(x, y, grad))
x = math_ops.conj(x)
y = math_ops.conj(y)
if skip_input_indices is None or 0 not in skip_input_indices:
gx = grad * y * math_ops.pow(x, y - 1)
if must_reduce_x:
gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
else:
gx = None
if skip_input_indices is None or 1 not in skip_input_indices:
z = math_ops.conj(op.outputs[0])
# Avoid false singularity at x = 0
if x.dtype.is_complex:
# real(x) < 0 is fine for the complex case
mask = math_ops.not_equal(x, 0)
else:
# There's no sensible real value to return if x < 0, so return 0
mask = x > 0
safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
gy = grad * z * log_x
if must_reduce_y:
gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)
else:
gy = None
return gx, gy
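# Illustrative sketch of the masking above (hypothetical values): for
# z = x ** y with real x = [-1., 0., 2.] and y = 3.,
#   dz/dx = y * x**(y - 1)  -> [3., 0., 12.]
#   dz/dy = z * log(x)      -> log is only evaluated where x > 0, so the
#           contributions are [0., 0., 8 * ln(2)] rather than NaN at the
#           non-positive entries.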
def _MaximumMinimumGradInputOnly(op, grad, selector_op):
x = op.inputs[0]
y = op.inputs[1]
zeros = array_ops.zeros_like(grad)
xmask = selector_op(x, y)
xgrad = array_ops.where_v2(xmask, grad, zeros)
ygrad = None # Return None for ygrad since the config allows that.
return (xgrad, ygrad)
def _MaximumMinimumGrad(op, grad, selector_op):
"""Factor out the code for the gradient of Maximum or Minimum."""
y = op.inputs[1]
skip_input_indices = None
try:
skip_input_indices = op.skip_input_indices
if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
y):
# When we want to get gradients for the first input only, and the second
# input tensor is a scalar, we can do a much simpler calculation
return _MaximumMinimumGradInputOnly(op, grad, selector_op)
except AttributeError:
# No gradient skipping, so do the full gradient computation
pass
x = op.inputs[0]
sx = array_ops.shape(x)
sy = array_ops.shape(y)
zeros = array_ops.zeros_like(grad)
xmask = selector_op(x, y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
if skip_input_indices is not None and 0 in skip_input_indices:
gx = None
else:
xgrad = array_ops.where_v2(xmask, grad, zeros)
gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
if skip_input_indices is not None and 1 in skip_input_indices:
gy = None
else:
ygrad = array_ops.where_v2(xmask, zeros, grad)
gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
return (gx, gy)
@ops.RegisterGradient("Maximum")
def _MaximumGrad(op, grad):
"""Returns grad*(x >= y, x < y) with type of grad."""
return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)
@ops.RegisterGradient("Minimum")
def _MinimumGrad(op, grad):
"""Returns grad*(x <= y, x > y) with type of grad."""
return _MaximumMinimumGrad(op, grad, math_ops.less_equal)
@ops.RegisterGradient("SquaredDifference")
def _SquaredDifferenceGrad(op, grad):
"""Returns the gradient for (x-y)^2."""
x = op.inputs[0]
y = op.inputs[1]
skip_input_indices = None
try:
skip_input_indices = op.skip_input_indices
except AttributeError:
# No gradient skipping, so do the full gradient computation
pass
with ops.control_dependencies([grad]):
# The parentheses ensure that if grad is IndexedSlices, it is multiplied by
# a Tensor (not a Python scalar like 2.0), which converts it to a Tensor.
x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)
if (isinstance(grad, ops.Tensor) and
_ShapesFullySpecifiedAndEqual(x, y, grad)):
return x_grad, -x_grad
(sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
SmartBroadcastGradientArgs(x, y, grad))
if skip_input_indices is not None and 0 in skip_input_indices:
gx = None
elif must_reduce_x:
gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx)
else:
gx = x_grad
if skip_input_indices is not None and 1 in skip_input_indices:
gy = None
elif must_reduce_y:
gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)
else:
gy = -x_grad
return (gx, gy)
# Logical operations have no gradients.
ops.NotDifferentiable("Less")
ops.NotDifferentiable("LessEqual")
ops.NotDifferentiable("Greater")
ops.NotDifferentiable("GreaterEqual")
ops.NotDifferentiable("Equal")
ops.NotDifferentiable("ApproximateEqual")
ops.NotDifferentiable("NotEqual")
ops.NotDifferentiable("LogicalAnd")
ops.NotDifferentiable("LogicalOr")
ops.NotDifferentiable("LogicalNot")
@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
c = op.inputs[0]
x = op.inputs[1]
zeros = array_ops.zeros_like(x)
return (None, array_ops.where(c, grad, zeros), array_ops.where(
c, zeros, grad))
@ops.RegisterGradient("SelectV2")
def _SelectGradV2(op, grad):
c = op.inputs[0]
x = op.inputs[1]
y = op.inputs[2]
zeros = array_ops.zeros([], dtype=grad.dtype.base_dtype)
gx = array_ops.where_v2(c, grad, zeros)
x_shape = array_ops.shape(x)
output_shape = array_ops.shape(op.outputs[0])
# Reduce away broadcasted leading dims.
reduce_x, _ = gen_array_ops.broadcast_gradient_args(x_shape, output_shape)
gx = math_ops.reduce_sum(gx, keepdims=True, axis=reduce_x)
gx = array_ops.reshape(gx, x_shape)
gy = array_ops.where_v2(c, zeros, grad)
y_shape = array_ops.shape(y)
# Reduce away broadcasted leading dims.
reduce_y, _ = gen_array_ops.broadcast_gradient_args(y_shape, output_shape)
gy = math_ops.reduce_sum(gy, keepdims=True, axis=reduce_y)
gy = array_ops.reshape(gy, y_shape)
return (None, gx, gy)
def _MatMulGradAgainstFirstOnly(op, grad):
"""Gradient for MatMul, only for the first input."""
t_a = op.get_attr("transpose_a")
t_b = op.get_attr("transpose_b")
b = math_ops.conj(op.inputs[1])
if not t_a and not t_b:
grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True)
elif not t_a and t_b:
grad_a = gen_math_ops.mat_mul(grad, b)
elif t_a and not t_b:
grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True)
elif t_a and t_b:
grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True)
return grad_a, None
def _MatMulGradAgainstSecondOnly(op, grad):
"""Gradient for MatMul, only for the second input."""
t_a = op.get_attr("transpose_a")
t_b = op.get_attr("transpose_b")
a = math_ops.conj(op.inputs[0])
if not t_a and not t_b:
grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True)
elif not t_a and t_b:
grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
elif t_a and not t_b:
grad_b = gen_math_ops.mat_mul(a, grad)
elif t_a and t_b:
grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True)
return None, grad_b
@ops.RegisterGradient("MatMul")
def _MatMulGrad(op, grad):
"""Gradient for MatMul."""
try:
skip_input_indices = op.skip_input_indices
if skip_input_indices is not None:
if 1 in skip_input_indices:
return _MatMulGradAgainstFirstOnly(op, grad)
elif 0 in skip_input_indices:
return _MatMulGradAgainstSecondOnly(op, grad)
except AttributeError:
# No gradient skipping, so do the full gradient computation
pass
t_a = op.get_attr("transpose_a")
t_b = op.get_attr("transpose_b")
a = math_ops.conj(op.inputs[0])
b = math_ops.conj(op.inputs[1])
if not t_a and not t_b:
grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True)
grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True)
elif not t_a and t_b:
grad_a = gen_math_ops.mat_mul(grad, b)
grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
elif t_a and not t_b:
grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True)
grad_b = gen_math_ops.mat_mul(a, grad)
elif t_a and t_b:
grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True)
grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True)
return grad_a, grad_b
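# Illustrative sketch of the no-transpose case above (hypothetical shapes):
# for C = A @ B with A: [m, k], B: [k, n] and upstream gradient dC: [m, n],
#   dA = dC @ B^T   (shape [m, k])
#   dB = A^T @ dC   (shape [k, n])
# which matches grad_a / grad_b in the first branch; the remaining branches
# are the same identities with the corresponding transposes folded in.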
@ops.RegisterGradient("SparseMatMul")
def _SparseMatMulGrad(op, grad):
"""Gradient for SparseMatMul."""
t_a = op.get_attr("transpose_a")
t_b = op.get_attr("transpose_b")
is_sparse = {}
is_sparse[op.inputs[0].ref()] = op.get_attr("a_is_sparse")
is_sparse[op.inputs[1].ref()] = op.get_attr("b_is_sparse")
  # Use a heuristic to figure out whether grad might be sparse.
is_sparse[grad.ref()] = not context.executing_eagerly() and (
grad.op.type == "ReluGrad")
def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
"""Helper function to create SparseMatMul op."""
assert t1.ref() in is_sparse and t2.ref() in is_sparse
t1_sparse = is_sparse[t1.ref()]
t2_sparse = is_sparse[t2.ref()]
if transpose_b:
t2 = array_ops.transpose(t2)
transpose_b = False
prod = math_ops.matmul(
t1,
t2,
transpose_a=transpose_a,
transpose_b=transpose_b,
a_is_sparse=t1_sparse,
b_is_sparse=t2_sparse)
if prod.dtype != out_dtype:
prod = math_ops.cast(prod, out_dtype)
return prod
dtype_a = op.inputs[0].dtype
dtype_b = op.inputs[1].dtype
if not t_a and not t_b:
return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True),
_SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True))
elif not t_a and t_b:
return (_SparseMatMul(grad, op.inputs[1], dtype_a),
_SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True))
elif t_a and not t_b:
return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True),
_SparseMatMul(op.inputs[0], grad, dtype_b))
elif t_a and t_b:
return (_SparseMatMul(
op.inputs[1], grad, dtype_a, transpose_a=True, transpose_b=True),
_SparseMatMul(
grad, op.inputs[0], dtype_b, transpose_a=True,
transpose_b=True))
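# Floor, Ceil, Round and Rint are piecewise constant, so their derivative is
# zero almost everywhere and no gradient is propagated to their inputs.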
@ops.RegisterGradient("Floor")
def _FloorGrad(_, unused_grad):
return [None]
@ops.RegisterGradient("Ceil")
def _CeilGrad(_, unused_grad):
return [None]
@ops.RegisterGradient("Round")
def _RoundGrad(_, unused_grad):
return [None]
@ops.RegisterGradient("Rint")
def _RintGrad(_, unused_grad):
  # The gradient of Rint is zero almost everywhere.
return [None]
@ops.RegisterGradient("BatchMatMul")
def _BatchMatMul(op, grad):
"""Returns the gradient of x and y given the gradient of x * y."""
x = op.inputs[0]
y = op.inputs[1]
adj_x = op.get_attr("adj_x")
adj_y = op.get_attr("adj_y")
if not adj_x:
if not adj_y:
grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
else:
grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
else:
if not adj_y:
grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
else:
grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)
return grad_x, grad_y
@ops.RegisterGradient("BatchMatMulV2")
@ops.RegisterGradient("BatchMatMulV3")
def _BatchMatMulV2(op, grad):
"""Returns the gradient of x and y given the gradient of x * y."""
x = op.inputs[0]
y = op.inputs[1]
adj_x = op.get_attr("adj_x")
adj_y = op.get_attr("adj_y")
if not adj_x:
if not adj_y:
grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
else:
grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
else:
if not adj_y:
grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
else:
grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)
  # Reduce along the broadcast batch dimensions if broadcasting occurred.
shape_x_static = x.get_shape()
shape_y_static = y.get_shape()
output_may_have_non_empty_batch_shape = (
(shape_x_static.rank is None or shape_x_static.rank > 2) or
(shape_y_static.rank is None or shape_y_static.rank > 2))
batch_shapes_match = (
shape_x_static[:-2].is_fully_defined() and
shape_y_static[:-2].is_fully_defined() and
shape_x_static[:-2] == shape_y_static[:-2])
if (not output_may_have_non_empty_batch_shape) or batch_shapes_match:
return grad_x, grad_y
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2])
grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx)
grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy)
return grad_x, grad_y
ops.NotDifferentiable("Range")
ops.NotDifferentiable("LinSpace")
@ops.RegisterGradient("Complex")
def _ComplexGrad(op, grad):
"""Returns the real and imaginary components of 'grad', respectively."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
return (array_ops.reshape(math_ops.reduce_sum(math_ops.real(grad), rx), sx),
array_ops.reshape(math_ops.reduce_sum(math_ops.imag(grad), ry), sy))
@ops.RegisterGradient("Real")
def _RealGrad(_, grad):
"""Returns 'grad' as the real part and set the imaginary part 0."""
zero = constant_op.constant(0, dtype=grad.dtype)
return math_ops.complex(grad, zero)
@ops.RegisterGradient("Imag")
def _ImagGrad(_, grad):
"""Returns 'grad' as the imaginary part and set the real part 0."""
zero = constant_op.constant(0, dtype=grad.dtype)
return math_ops.complex(zero, grad)
@ops.RegisterGradient("Angle")
def _AngleGrad(op, grad):
"""Returns -grad / (Im(x) + iRe(x))"""
x = op.inputs[0]
with ops.control_dependencies([grad]):
re = math_ops.real(x)
im = math_ops.imag(x)
z = math_ops.reciprocal(math_ops.complex(im, re))
zero = constant_op.constant(0, dtype=grad.dtype)
complex_grad = math_ops.complex(grad, zero)
return -complex_grad * z
@ops.RegisterGradient("Conj")
def _ConjGrad(_, grad):
"""Returns the complex conjugate of grad."""
return math_ops.conj(grad)
@ops.RegisterGradient("ComplexAbs")
def _ComplexAbsGrad(op, grad):
"""Returns the gradient of ComplexAbs."""
return math_ops.div_no_nan(
math_ops.complex(
grad, array_ops.zeros_like(grad)) * op.inputs[0],
math_ops.complex(
op.outputs[0], array_ops.zeros_like(op.outputs[0])))
@ops.RegisterGradient("Cast")
def _CastGrad(op, grad):
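  # Gradients flow through Cast only between floating-point/complex dtypes;
  # any cast involving an integer, bool or other non-differentiable dtype
  # gets no gradient.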
t = [
dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
dtypes.complex64, dtypes.complex128
]
src_type = op.inputs[0].dtype.base_dtype
dst_type = grad.dtype.base_dtype
if src_type in t and dst_type in t:
return math_ops.cast(grad, src_type)
else:
return None
@ops.RegisterGradient("Cross")
def _CrossGrad(op, grad):
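  """Gradient of cross(u, v): (cross(v, grad), cross(grad, u))."""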
u = op.inputs[0]
v = op.inputs[1]
return (math_ops.cross(v, grad), math_ops.cross(grad, u))
@ops.RegisterGradient("Cumsum")
def _CumsumGrad(op, grad):
axis = op.inputs[1]
exclusive = op.get_attr("exclusive")
reverse = op.get_attr("reverse")
return [
math_ops.cumsum(grad, axis, exclusive=exclusive, reverse=not reverse),
None
]
@ops.RegisterGradient("Cumprod")
def _CumprodGrad(op, grad):
x = op.inputs[0]
axis = op.inputs[1]
exclusive = op.get_attr("exclusive")
reverse = op.get_attr("reverse")
prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
out = math_ops.cumsum(
prod * grad, axis, exclusive=exclusive, reverse=not reverse)
return [math_ops.div_no_nan(out, x), None]
@ops.RegisterGradient("CumulativeLogsumexp")
def _CumulativeLogsumexpGrad(op, grad):
x = op.inputs[0]
axis = op.inputs[1]
cumulative_logsumexp = op.outputs[0]
exclusive = op.get_attr("exclusive")
reverse = op.get_attr("reverse")
# Split the incoming gradient into positive and negative part
# in order to take logs. This is required for stable results.
log_grad_positive = array_ops.where_v2(
math_ops.greater(grad, 0),
math_ops.log(grad),
grad.dtype.min)
log_grad_negative = array_ops.where_v2(
math_ops.less(grad, 0),
math_ops.log(-grad),
grad.dtype.min)
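  # grad.dtype.min stands in for log(0): entries whose sign does not match a
  # given part contribute essentially nothing to that part's logsumexp.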
output_pos = math_ops.exp(
math_ops.cumulative_logsumexp(
log_grad_positive - cumulative_logsumexp,
axis=axis, reverse=not reverse, exclusive=exclusive) + x)
output_neg = math_ops.exp(
math_ops.cumulative_logsumexp(
log_grad_negative - cumulative_logsumexp,
axis=axis, reverse=not reverse, exclusive=exclusive) + x)
return [output_pos - output_neg, None]
@ops.RegisterGradient("NextAfter")
def _NextAfterGrad(op, grad):
"""Returns gradient of nextafter(x1, x2) with respect to x1 and x2."""
x1 = op.inputs[0]
x2 = op.inputs[1]
s_x1 = array_ops.shape(x1)
s_x2 = array_ops.shape(x2)
r_x1, r_x2 = gen_array_ops.broadcast_gradient_args(s_x1, s_x2)
with ops.control_dependencies([grad]):
partial_x1 = array_ops.ones(s_x1, dtype=x1.dtype)
partial_x2 = array_ops.zeros(s_x2, dtype=x2.dtype)
return (array_ops.reshape(
math_ops.reduce_sum(partial_x1 * grad, r_x1), s_x1),
array_ops.reshape(
math_ops.reduce_sum(partial_x2 * grad, r_x2), s_x2))