blob: 4f5cd9d55886977ad5dc5b8f099fb19a6b81646f [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in nn_ops.py."""
import functools
import itertools
import operator
from tensorflow.python.eager import backprop
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@ops.RegisterGradient("Conv2DBackpropInput")
def _Conv2DBackpropInputGrad(op, grad):
  """Gradient function for the Conv2DBackpropInput (deconvolution) op.

  Args:
    op: the Conv2DBackpropInput op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A three-element list: `None` for `op.inputs[0]`, followed by the
    gradients w.r.t. `op.inputs[1]` (the filter) and `op.inputs[2]`.
  """
  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. See _Conv2DGrad.
  conv_attrs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      explicit_paddings=op.get_attr("explicit_paddings"),
      use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
      data_format=op.get_attr("data_format").decode())
  filter_in = op.inputs[1]
  backprop_in = op.inputs[2]
  filter_grad = gen_nn_ops.conv2d_backprop_filter(
      grad, array_ops.shape(filter_in), backprop_in, **conv_attrs)
  backprop_grad = gen_nn_ops.conv2d(grad, filter_in, **conv_attrs)
  return [None, filter_grad, backprop_grad]
@ops.RegisterGradient("Conv2DBackpropFilter")
def _Conv2DBackpropFilterGrad(op, grad):
  """Gradient function for the Conv2DBackpropFilter op.

  Args:
    op: the Conv2DBackpropFilter op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A three-element list: the gradient w.r.t. `op.inputs[0]`, `None` for
    `op.inputs[1]`, and the gradient w.r.t. `op.inputs[2]`.
  """
  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. See _Conv2DGrad.
  conv_attrs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      explicit_paddings=op.get_attr("explicit_paddings"),
      use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
      data_format=op.get_attr("data_format").decode())
  image_in = op.inputs[0]
  backprop_in = op.inputs[2]
  image_grad = gen_nn_ops.conv2d_backprop_input(
      array_ops.shape(image_in), grad, backprop_in, **conv_attrs)
  backprop_grad = gen_nn_ops.conv2d(image_in, grad, **conv_attrs)
  return [image_grad, None, backprop_grad]
@ops.RegisterGradient("DepthwiseConv2dNativeBackpropInput")
def _DepthwiseConv2dNativeBackpropInputGrad(op, grad):
  """Gradient function for the DepthwiseConv2dNativeBackpropInput op.

  Args:
    op: the DepthwiseConv2dNativeBackpropInput op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A three-element list: `None` for `op.inputs[0]`, followed by the
    gradients w.r.t. `op.inputs[1]` (the filter) and `op.inputs[2]`.
  """
  conv_attrs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      explicit_paddings=op.get_attr("explicit_paddings"),
      data_format=op.get_attr("data_format"))
  filter_in = op.inputs[1]
  backprop_in = op.inputs[2]
  filter_grad = gen_nn_ops.depthwise_conv2d_native_backprop_filter(
      grad, array_ops.shape(filter_in), backprop_in, **conv_attrs)
  backprop_grad = gen_nn_ops.depthwise_conv2d_native(
      grad, filter_in, **conv_attrs)
  return [None, filter_grad, backprop_grad]
@ops.RegisterGradient("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterGrad(op, grad):
  """Gradient function for the DepthwiseConv2dNativeBackpropFilter op.

  Args:
    op: the DepthwiseConv2dNativeBackpropFilter op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A three-element list: the gradient w.r.t. `op.inputs[0]`, `None` for
    `op.inputs[1]`, and the gradient w.r.t. `op.inputs[2]`.
  """
  conv_attrs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      explicit_paddings=op.get_attr("explicit_paddings"),
      data_format=op.get_attr("data_format"))
  image_in = op.inputs[0]
  backprop_in = op.inputs[2]
  image_grad = gen_nn_ops.depthwise_conv2d_native_backprop_input(
      array_ops.shape(image_in), grad, backprop_in, **conv_attrs)
  backprop_grad = gen_nn_ops.depthwise_conv2d_native(
      image_in, grad, **conv_attrs)
  return [image_grad, None, backprop_grad]
@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
  """Gradient function for Conv3D.

  Args:
    op: the Conv3D op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A two-element list with the gradients w.r.t. `op.inputs[0]` and
    `op.inputs[1]` (the filter).
  """
  conv_attrs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  image_in = op.inputs[0]
  filter_in = op.inputs[1]
  # A single shape_n call fetches both shapes at once.
  image_shape, filter_shape = array_ops.shape_n([image_in, filter_in])
  image_grad = nn_ops.conv3d_backprop_input_v2(
      image_shape, filter_in, grad, **conv_attrs)
  filter_grad = nn_ops.conv3d_backprop_filter_v2(
      image_in, filter_shape, grad, **conv_attrs)
  return [image_grad, filter_grad]
@ops.RegisterGradient("Conv3DBackpropInputV2")
def _Conv3DBackpropInputGrad(op, grad):
  """Gradient function for the Conv3DBackpropInputV2 op.

  Args:
    op: the Conv3DBackpropInputV2 op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A three-element list: `None` for `op.inputs[0]`, followed by the
    gradients w.r.t. `op.inputs[1]` (the filter) and `op.inputs[2]`.
  """
  conv_attrs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  filter_in = op.inputs[1]
  backprop_in = op.inputs[2]
  filter_grad = nn_ops.conv3d_backprop_filter_v2(
      grad, array_ops.shape(filter_in), backprop_in, **conv_attrs)
  backprop_grad = nn_ops.conv3d(grad, filter_in, **conv_attrs)
  return [None, filter_grad, backprop_grad]
@ops.RegisterGradient("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterGrad(op, grad):
  """Gradient function for the Conv3DBackpropFilterV2 op.

  Args:
    op: the Conv3DBackpropFilterV2 op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A three-element list: the gradient w.r.t. `op.inputs[0]`, `None` for
    `op.inputs[1]`, and the gradient w.r.t. `op.inputs[2]`.
  """
  conv_attrs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  image_in = op.inputs[0]
  backprop_in = op.inputs[2]
  image_grad = nn_ops.conv3d_backprop_input_v2(
      array_ops.shape(image_in), grad, backprop_in, **conv_attrs)
  backprop_grad = nn_ops.conv3d(image_in, grad, **conv_attrs)
  return [image_grad, None, backprop_grad]
@ops.RegisterGradient("AvgPool3D")
def _AvgPool3DGrad(op, grad):
  """Gradient function for AvgPool3D.

  Args:
    op: the AvgPool3D op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    The gradient w.r.t. `op.inputs[0]`, produced by `avg_pool3d_grad` from
    the input's dynamic shape and the op's pooling attributes.
  """
  input_shape = array_ops.shape(op.inputs[0])
  pool_attrs = dict(
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  return gen_nn_ops.avg_pool3d_grad(input_shape, grad, **pool_attrs)
@ops.RegisterGradient("AvgPool3DGrad")
def _AvgPool3DGradGrad(op, grad):
  """Gradient function for the AvgPool3DGrad op.

  Args:
    op: the AvgPool3DGrad op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A pair: `op.inputs[0]` wrapped in `stop_gradient`, and `grad` average
    pooled with the op's own attributes.
  """
  shape_passthrough = array_ops.stop_gradient(op.inputs[0])
  pooled_grad = gen_nn_ops.avg_pool3d(
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  return (shape_passthrough, pooled_grad)
@ops.RegisterGradient("MaxPool3D")
def _MaxPool3DGrad(op, grad):
  """Gradient function for MaxPool3D.

  Args:
    op: the MaxPool3D op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    The gradient w.r.t. `op.inputs[0]`, produced by `max_pool3d_grad` from
    the op's input, its output and `grad`.
  """
  pool_attrs = dict(
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  pooled_in = op.inputs[0]
  pooled_out = op.outputs[0]
  return gen_nn_ops.max_pool3d_grad(pooled_in, pooled_out, grad, **pool_attrs)
@ops.RegisterGradient("MaxPool3DGrad")
def _MaxPool3DGradGrad(op, grad):
  """Gradient function for the MaxPool3DGrad op.

  Args:
    op: the MaxPool3DGrad op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A triple: zeros shaped like `op.inputs[0]`, zeros shaped like
    `op.inputs[1]`, and the result of `max_pool3d_grad_grad` on `grad`.
  """
  first_in = op.inputs[0]
  second_in = op.inputs[1]
  zeros_first = array_ops.zeros_like(first_in)
  zeros_second = array_ops.zeros_like(second_in)
  routed_grad = gen_nn_ops.max_pool3d_grad_grad(
      first_in,
      second_in,
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  return (zeros_first, zeros_second, routed_grad)
@ops.RegisterGradient("MaxPool3DGradGrad")
def _MaxPool3DGradGradGrad(op, grad):
  """Gradient function for the MaxPool3DGradGrad op.

  Args:
    op: the MaxPool3DGradGrad op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A triple: zeros shaped like `op.inputs[0]`, zeros shaped like
    `op.inputs[1]`, and the result of `max_pool3d_grad` on `grad`.
  """
  first_in = op.inputs[0]
  second_in = op.inputs[1]
  zeros_first = array_ops.zeros_like(first_in)
  zeros_second = array_ops.zeros_like(second_in)
  routed_grad = gen_nn_ops.max_pool3d_grad(
      first_in,
      second_in,
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())
  return (zeros_first, zeros_second, routed_grad)
@ops.RegisterGradient("Softmax")
def _SoftmaxGrad(op, grad_softmax):
  """The derivative of the softmax nonlinearity.

  The Jacobian dsoftmax / dx = (diag(softmax) - softmax * softmax') is a
  diagonal matrix minus a rank-one matrix, so the product with the incoming
  gradient reduces to:

    grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax

  where the sum runs over the last axis.

  Args:
    op: the Softmax op.
    grad_softmax: the tensor representing the gradient w.r.t. the softmax
      output.

  Returns:
    gradient w.r.t the input to the softmax
  """
  probs = op.outputs[0]
  # Per-row inner product of the incoming gradient with the probabilities.
  weighted_total = math_ops.reduce_sum(grad_softmax * probs, -1, keepdims=True)
  centered = grad_softmax - weighted_total
  return centered * probs
@ops.RegisterGradient("LogSoftmax")
def _LogSoftmaxGrad(op, grad):
  """The gradient for log_softmax.

      log_softmax = input - log(sum(exp(input)))
      dlog_softmax/dinput = diag - softmax(input)

  Args:
    op: The log softmax op.
    grad: The tensor representing the gradient w.r.t. the output.

  Returns:
    The gradients w.r.t. the input.
  """
  # Recover softmax(input) by exponentiating the op's own output.
  probs = math_ops.exp(op.outputs[0])
  grad_total = math_ops.reduce_sum(grad, -1, keepdims=True)
  return grad - grad_total * probs
@ops.RegisterGradient("BiasAdd")
def _BiasAddGrad(op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of `op` is the tensor t, and its gradient is just the
  gradient `op` received. The second input is the bias vector; its gradient
  is computed by `bias_add_grad` from the received gradient and the op's
  data format.

  Args:
    op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor.  The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  # Ops without a "data_format" attr fall back to the kernel default.
  try:
    fmt = op.get_attr("data_format")
  except ValueError:
    fmt = None

  bias_grad = gen_nn_ops.bias_add_grad(
      out_backprop=received_grad, data_format=fmt)
  return (received_grad, bias_grad)
@ops.RegisterGradient("BiasAddGrad")
def _BiasAddGradGrad(op, received_grad):
  """Gradient for the BiasAddGrad op.

  The received gradient is reshaped so that it has size 1 along every axis
  except the channel axis, then tiled back up to the shape of the op's input.

  Args:
    op: BiasAddGrad op for which we are calculating gradients.
    received_grad: The gradients passed to the BiasAddGrad op.

  Returns:
    A single gradient Tensor for the input to BiasAddGrad (which
    is the gradient of the bias term in BiasAdd)
  """
  try:
    fmt = op.get_attr("data_format")
  except ValueError:
    fmt = None

  full_shape = array_ops.shape(op.inputs[0])
  channel_shape = array_ops.shape(received_grad)

  if fmt != b"NCHW":
    # Channels-last (also used when no data_format attr exists): the bias
    # occupies the final axis.
    broadcast_shape = array_ops.concat(
        [array_ops.ones_like(full_shape[:-1]), channel_shape], 0)
    repeats = array_ops.concat([full_shape[:-1], [1]], 0)
  else:
    # Channels-first: the bias occupies axis 1.
    broadcast_shape = array_ops.concat([
        array_ops.ones_like(full_shape[:1]), channel_shape,
        array_ops.ones_like(full_shape[2:])
    ], 0)
    repeats = array_ops.concat([full_shape[:1], [1], full_shape[2:]], 0)

  reshaped = array_ops.reshape(received_grad, broadcast_shape)
  return array_ops.tile(reshaped, repeats)
@ops.RegisterGradient("BiasAddV1")
def _BiasAddGradV1(unused_bias_op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of unused_bias_op is the tensor t, and its gradient is
  just the gradient the unused_bias_op received.

  The second input of unused_bias_op is the bias vector. Its gradient is the
  received gradient summed over every axis except the last one.

  Args:
    unused_bias_op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor.  The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  # Axes 0 .. rank-2, i.e. all but the trailing axis.
  leading_axes = math_ops.range(array_ops.rank(received_grad) - 1)
  bias_grad = math_ops.reduce_sum(received_grad, leading_axes)
  return (received_grad, bias_grad)
@ops.RegisterGradient("Relu")
def _ReluGrad(op, grad):
  """Gradient function for Relu, evaluated from the op's output."""
  activations = op.outputs[0]
  return gen_nn_ops.relu_grad(grad, activations)
@ops.RegisterGradient("EluGrad")
def _EluGradGrad(op, grad):
  """Gradient function for the EluGrad op.

  Args:
    op: the EluGrad op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A pair of gradients, one for each of the op's two inputs.
  """
  upstream = op.inputs[0]
  features = op.inputs[1]
  wrt_upstream = gen_nn_ops.elu_grad(grad, features)
  # Only the negative region contributes; elsewhere the result is zero.
  wrt_features = array_ops.where(
      features < 0, grad * upstream, array_ops.zeros_like(features))
  return (wrt_upstream, wrt_features)
@ops.RegisterGradient("SeluGrad")
def _SeluGradGrad(op, grad):
  """Gradient function for the SeluGrad op.

  Args:
    op: the SeluGrad op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A pair of gradients, one for each of the op's two inputs.
  """
  upstream = op.inputs[0]
  features = op.inputs[1]
  wrt_upstream = gen_nn_ops.selu_grad(grad, features)
  # Only the negative region contributes; elsewhere the result is zero.
  wrt_features = array_ops.where(
      features < 0., grad * upstream, array_ops.zeros_like(features))
  return (wrt_upstream, wrt_features)
@ops.RegisterGradient("Relu6")
def _Relu6Grad(op, grad):
  """Gradient function for Relu6, evaluated from the op's output."""
  activations = op.outputs[0]
  return gen_nn_ops.relu6_grad(grad, activations)
@ops.RegisterGradient("Relu6Grad")
def _Relu6GradGrad(op, grad):
  """Gradient function for the Relu6Grad op.

  Returns:
    A pair: `relu6_grad` applied to `grad` for `op.inputs[0]`, and zeros
    shaped like `op.inputs[1]` for the second input.
  """
  features = op.inputs[1]
  wrt_upstream = gen_nn_ops.relu6_grad(grad, features)
  wrt_features = array_ops.zeros_like(features)
  return (wrt_upstream, wrt_features)
@ops.RegisterGradient("LeakyRelu")
def _LeakyReluGrad(op, grad):
  """Gradient function for LeakyRelu, using the op's input and alpha attr."""
  return gen_nn_ops.leaky_relu_grad(
      grad, op.inputs[0], alpha=op.get_attr("alpha"))
@ops.RegisterGradient("LeakyReluGrad")
def _LeakyReluGradGrad(op, grad):
  """Gradient function for the LeakyReluGrad op.

  Returns:
    A pair: `leaky_relu_grad` applied to `grad` for `op.inputs[0]`, and
    zeros shaped like `op.inputs[1]` for the second input.
  """
  features = op.inputs[1]
  slope = op.get_attr("alpha")
  wrt_upstream = gen_nn_ops.leaky_relu_grad(grad, features, alpha=slope)
  wrt_features = array_ops.zeros_like(features)
  return (wrt_upstream, wrt_features)
@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
  """Gradient function for Elu, evaluated from the op's output."""
  activations = op.outputs[0]
  return gen_nn_ops.elu_grad(grad, activations)
@ops.RegisterGradient("Selu")
def _SeluGrad(op, grad):
  """Gradient function for Selu, evaluated from the op's output."""
  activations = op.outputs[0]
  return gen_nn_ops.selu_grad(grad, activations)
@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
  """Gradient function for Softplus: `grad` scaled by sigmoid of the input."""
  gate = math_ops.sigmoid(op.inputs[0])
  return grad * gate
@ops.RegisterGradient("SoftplusGrad")
def _SoftplusGradGrad(op, grad):
  """Gradient function for the SoftplusGrad op.

  Let:
    y = tf.nn.softplus(x)
    dx = gen_nn_ops.softplus_grad(dy, x) = dy / (1 + exp(-x))
  This computes (ddy, d2x) from op.inputs == [dy, x] and grad == ddx.

  Args:
    op: the SoftplusGrad op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output (ddx).

  Returns:
    The pair (ddy, d2x).
  """
  dy, x = op.inputs
  # Both results are built under a control dependency on the incoming grad.
  with ops.control_dependencies([grad]):
    ddy = gen_nn_ops.softplus_grad(grad, x)
    denominator = math_ops.exp(-x) + 2.0 + math_ops.exp(x)
    d2x = grad * dy / denominator
    return (ddy, d2x)
@ops.RegisterGradient("Softsign")
def _SoftsignGrad(op, grad):
  """Gradient function for Softsign, evaluated from the op's input."""
  features = op.inputs[0]
  return gen_nn_ops.softsign_grad(grad, features)
@ops.RegisterGradient("ReluGrad")
def _ReluGradGrad(op, grad):
  """Gradient function for the ReluGrad op.

  Returns:
    A pair: `relu_grad` applied to `grad` for `op.inputs[0]`, and zeros
    shaped like `op.inputs[1]` for the second input.
  """
  features = op.inputs[1]
  wrt_upstream = gen_nn_ops.relu_grad(grad, features)
  wrt_features = array_ops.zeros_like(features)
  return (wrt_upstream, wrt_features)
def _BroadcastMul(vec, mat):
  """Multiply after broadcasting vec to match dimensions of mat.

  Args:
    vec: A 1-D tensor of dimension [D0]
    mat: A 2-D tensor of dimension [D0, D1]

  Returns:
    A tensor of dimension [D0, D1], the result of vec * mat
  """
  # A trailing unit axis turns vec into a [D0, 1] column that broadcasts
  # across the columns of mat.
  column = array_ops.expand_dims(vec, -1)
  return column * mat
@ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SoftmaxCrossEntropyWithLogits.

  Args:
    op: the SoftmaxCrossEntropyWithLogits op.
    grad_loss: the backprop for the cost (the op's first output).
    grad_grad: the backprop for the softmax gradient (the op's second
      output); may be `None` or a tensor tagged `_is_zeros_tensor`.

  Returns:
    A pair: the gradient w.r.t. the logits (`op.inputs[0]`) and the gradient
    w.r.t. the op's second input.
  """
  # grad_loss scales the gradient the op already computed (output[1]).
  logits_grad = _BroadcastMul(grad_loss, op.outputs[1])
  logits = op.inputs[0]

  has_second_order = (
      grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False))
  if has_second_order:
    # Second derivative is just softmax derivative w.r.t. logits.
    probs = nn_ops.softmax(logits)
    # Batched row-by-column product of grad_grad with probs, one per example.
    row_dot = array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(probs, 2)),
        axis=1)
    logits_grad = logits_grad + (grad_grad - row_dot) * probs

  second_input_grad = _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))  # pylint: disable=invalid-unary-operand-type
  return logits_grad, second_input_grad
@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SparseSoftmaxCrossEntropyWithLogits.

  Args:
    op: the SparseSoftmaxCrossEntropyWithLogits op.
    grad_loss: the backprop for the cost (the op's first output).
    grad_grad: the backprop for the softmax gradient (the op's second
      output); may be `None` or a tensor tagged `_is_zeros_tensor`.

  Returns:
    A pair: the gradient w.r.t. the logits (`op.inputs[0]`) and `None`,
    because there is no gradient for the labels.
  """
  # grad_loss scales the gradient the op already computed (output[1]).
  logits_grad = _BroadcastMul(grad_loss, op.outputs[1])
  logits = op.inputs[0]

  has_second_order = (
      grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False))
  if has_second_order:
    # Second derivative is just softmax derivative w.r.t. logits.
    probs = nn_ops.softmax(logits)
    # Batched row-by-column product of grad_grad with probs, one per example.
    row_dot = array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(probs, 2)),
        axis=1)
    logits_grad = logits_grad + (grad_grad - row_dot) * probs

  return logits_grad, None
@ops.RegisterGradient("Conv2D")
def _Conv2DGrad(op, grad):
  """Gradient function for Conv2D.

  Args:
    op: the Conv2D op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A two-element list with the gradients w.r.t. `op.inputs[0]` and
    `op.inputs[1]` (the filter).
  """
  conv_attrs = {
      name: op.get_attr(name)
      for name in ("dilations", "strides", "padding", "explicit_paddings",
                   "use_cudnn_on_gpu", "data_format")
  }
  image_in = op.inputs[0]
  filter_in = op.inputs[1]
  image_shape, filter_shape = array_ops.shape_n([image_in, filter_in])

  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. gen_nn_ops functions take a
  # `explicit_paddings` parameter, but nn_ops functions do not. So if we were
  # to use the nn_ops functions, we would have to convert `padding` and
  # `explicit_paddings` into a single `padding` parameter, increasing overhead
  # in Eager mode.
  image_grad = gen_nn_ops.conv2d_backprop_input(
      image_shape, filter_in, grad, **conv_attrs)
  filter_grad = gen_nn_ops.conv2d_backprop_filter(
      image_in, filter_shape, grad, **conv_attrs)
  return [image_grad, filter_grad]
@ops.RegisterGradient("DepthwiseConv2dNative")
def _DepthwiseConv2dNativeGrad(op, grad):
  """Gradient function for DepthwiseConv2dNative.

  Args:
    op: the DepthwiseConv2dNative op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A two-element list with the gradients w.r.t. `op.inputs[0]` and
    `op.inputs[1]` (the filter).
  """
  conv_attrs = dict(
      dilations=op.get_attr("dilations"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      explicit_paddings=op.get_attr("explicit_paddings"),
      data_format=op.get_attr("data_format"))
  image_in = op.inputs[0]
  filter_in = op.inputs[1]
  image_grad = gen_nn_ops.depthwise_conv2d_native_backprop_input(
      array_ops.shape(image_in), filter_in, grad, **conv_attrs)
  filter_grad = gen_nn_ops.depthwise_conv2d_native_backprop_filter(
      image_in, array_ops.shape(filter_in), grad, **conv_attrs)
  return [image_grad, filter_grad]
@ops.RegisterGradient("Dilation2D")
def _Dilation2DGrad(op, grad):
  """Gradient function for Dilation2D.

  Args:
    op: the Dilation2D op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A two-element list with the gradients w.r.t. `op.inputs[0]` and
    `op.inputs[1]` (the filter).
  """
  image_in = op.inputs[0]
  filter_in = op.inputs[1]
  strides = op.get_attr("strides")
  rates = op.get_attr("rates")
  padding = op.get_attr("padding")
  image_grad = nn_ops.dilation2d_backprop_input(
      image_in, filter_in, grad, strides, rates, padding)
  filter_grad = nn_ops.dilation2d_backprop_filter(
      image_in, filter_in, grad, strides, rates, padding)
  return [image_grad, filter_grad]
@ops.RegisterGradient("LRN")
def _LRNGrad(op, grad):
  """Gradient function for LRN.

  Args:
    op: the LRN op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A one-element list holding the gradient w.r.t. the op's input.
  """
  lrn_attrs = [
      op.get_attr(name) for name in ("depth_radius", "bias", "alpha", "beta")
  ]
  input_grad = gen_nn_ops.lrn_grad(
      grad, op.inputs[0], op.outputs[0], *lrn_attrs)
  return [input_grad]
@ops.RegisterGradient("AvgPool")
def _AvgPoolGrad(op, grad):
  """Gradient function for AvgPool.

  Args:
    op: the AvgPool op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    The gradient w.r.t. `op.inputs[0]`, produced by `avg_pool_grad` from the
    input's dynamic shape and the op's pooling attributes.
  """
  input_shape = array_ops.shape(op.inputs[0])
  ksize = op.get_attr("ksize")
  strides = op.get_attr("strides")
  padding = op.get_attr("padding")
  return gen_nn_ops.avg_pool_grad(
      input_shape,
      grad,
      ksize,
      strides,
      padding,
      data_format=op.get_attr("data_format"))
@ops.RegisterGradient("AvgPoolGrad")
def _AvgPoolGradGrad(op, grad):
  """Gradient function for the AvgPoolGrad op.

  Args:
    op: the AvgPoolGrad op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A pair: `op.inputs[0]` wrapped in `stop_gradient`, and `grad` average
    pooled with the op's own attributes.
  """
  shape_passthrough = array_ops.stop_gradient(op.inputs[0])
  pooled_grad = gen_nn_ops.avg_pool(
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return (shape_passthrough, pooled_grad)
@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
  """Gradient function for MaxPool.

  Args:
    op: the MaxPool op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    The gradient w.r.t. `op.inputs[0]`, produced by `max_pool_grad` from the
    op's input, its output and `grad`.
  """
  pooled_in = op.inputs[0]
  pooled_out = op.outputs[0]
  ksize = op.get_attr("ksize")
  strides = op.get_attr("strides")
  return gen_nn_ops.max_pool_grad(
      pooled_in,
      pooled_out,
      grad,
      ksize,
      strides,
      padding=op.get_attr("padding"),
      explicit_paddings=op.get_attr("explicit_paddings"),
      data_format=op.get_attr("data_format"))
@ops.RegisterGradient("MaxPoolV2")
def _MaxPoolGradV2(op, grad):
  """Gradient function for MaxPoolV2.

  Here ksize and strides are read from the op's tensor inputs
  (`op.inputs[1]` and `op.inputs[2]`), not from attributes.

  Args:
    op: the MaxPoolV2 op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A triple: the gradient w.r.t. `op.inputs[0]`, then `None` for each of
    the ksize and strides inputs.
  """
  input_grad = gen_nn_ops.max_pool_grad_v2(
      op.inputs[0],
      op.outputs[0],
      grad,
      op.inputs[1],
      op.inputs[2],
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return input_grad, None, None
@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
  """Gradient function for MaxPoolWithArgmax.

  Args:
    op: the MaxPoolWithArgmax op being differentiated.
    grad: the gradient w.r.t. the op's first output.
    unused_argmax_grad: the gradient w.r.t. the op's second output (the
      argmax); it is discarded.

  Returns:
    The gradient w.r.t. `op.inputs[0]`, routed through the argmax recorded
    in `op.outputs[1]`.
  """
  del unused_argmax_grad
  pooled_in = op.inputs[0]
  argmax = op.outputs[1]
  ksize = op.get_attr("ksize")
  strides = op.get_attr("strides")
  return gen_nn_ops.max_pool_grad_with_argmax(
      pooled_in,
      grad,
      argmax,
      ksize,
      strides,
      padding=op.get_attr("padding"),
      include_batch_in_index=op.get_attr("include_batch_in_index"))
@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
  """Gradient function for the MaxPoolGrad op.

  Args:
    op: the MaxPoolGrad op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A triple: zeros shaped like `op.inputs[0]`, zeros shaped like
    `op.inputs[1]`, and the result of `max_pool_grad_grad` on `grad`.
  """
  first_in = op.inputs[0]
  second_in = op.inputs[1]
  zeros_first = array_ops.zeros_like(first_in)
  zeros_second = array_ops.zeros_like(second_in)
  routed_grad = gen_nn_ops.max_pool_grad_grad(
      first_in,
      second_in,
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return (zeros_first, zeros_second, routed_grad)
@ops.RegisterGradient("MaxPoolGradV2")
def _MaxPoolGradGradV2(op, grad):
  """Gradient function for the MaxPoolGradV2 op.

  Here ksize and strides are read from the op's tensor inputs
  (`op.inputs[3]` and `op.inputs[4]`), not from attributes.

  Args:
    op: the MaxPoolGradV2 op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A five-element tuple: zeros shaped like `op.inputs[0]`, zeros shaped
    like `op.inputs[1]`, the result of `max_pool_grad_grad_v2` on `grad`,
    then `None` for each of the ksize and strides inputs.
  """
  first_in = op.inputs[0]
  second_in = op.inputs[1]
  zeros_first = array_ops.zeros_like(first_in)
  zeros_second = array_ops.zeros_like(second_in)
  routed_grad = gen_nn_ops.max_pool_grad_grad_v2(
      first_in,
      second_in,
      grad,
      op.inputs[3],
      op.inputs[4],
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return (zeros_first, zeros_second, routed_grad, None, None)
@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
  """Gradient function for the MaxPoolGradGrad op.

  Args:
    op: the MaxPoolGradGrad op being differentiated.
    grad: the tensor representing the gradient w.r.t. the op's output.

  Returns:
    A triple: zeros shaped like `op.inputs[0]`, zeros shaped like
    `op.inputs[1]`, and the result of `max_pool_grad` on `grad`.
  """
  first_in = op.inputs[0]
  second_in = op.inputs[1]
  zeros_first = array_ops.zeros_like(first_in)
  zeros_second = array_ops.zeros_like(second_in)
  routed_grad = gen_nn_ops.max_pool_grad(
      first_in,
      second_in,
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
  return (zeros_first, zeros_second, routed_grad)
@ops.RegisterGradient("FractionalMaxPool")
def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalMaxPool.

  FractionalMaxPool has three outputs, so three gradients are passed in, one
  per output. Only the first is used; the other two are ignored.

  Args:
    op: The FractionalMaxPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalMaxPool op.
  """
  pooled_in = op.inputs[0]
  pooled_out, row_seq, col_seq = op.outputs[0], op.outputs[1], op.outputs[2]
  overlapping = op.get_attr("overlapping")
  return gen_nn_ops.fractional_max_pool_grad(
      pooled_in, pooled_out, grad_0, row_seq, col_seq, overlapping)
@ops.RegisterGradient("FractionalAvgPool")
def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalAvgPool.

  FractionalAvgPool has three outputs, so three gradients are passed in, one
  per output. Only the first is used; the other two are ignored.

  Args:
    op: The FractionalAvgPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalAvgPool op.
  """
  orig_input = op.inputs[0]
  input_shape = orig_input.get_shape()
  if not input_shape.is_fully_defined():
    # A partially known static shape cannot be converted to the shape tensor
    # the kernel expects, so read the shape at run time instead.
    input_shape = array_ops.shape(orig_input, out_type=dtypes.int64)
  return gen_nn_ops.fractional_avg_pool_grad(input_shape, grad_0,
                                             op.outputs[1], op.outputs[2],
                                             op.get_attr("overlapping"))
@ops.RegisterGradient("BatchNormWithGlobalNormalization")
def _BatchNormWithGlobalNormalizationGrad(op, grad):
  """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.

  The work is delegated to `batch_norm_with_global_normalization_grad`, which
  receives `op.inputs[0]`, `op.inputs[1]`, `op.inputs[2]`, `op.inputs[4]` and
  `grad`; `op.inputs[3]` is not passed to it.

  NOTE(review): an earlier comment here said that nothing is backpropagated
  for the mean and var, yet the kernel's dm and dv are returned unchanged;
  confirm which is intended.

  Args:
    op: The BatchNormOp for which we need to generate gradients.
    grad: Tensor.  The gradients passed to the BatchNormOp.

  Returns:
    dx: Backprop for input, which is (grad * (g * rsqrt(v + epsilon)))
    dm: Backprop for mean, which is
        sum_over_rest(grad * g) * (-1 / rsqrt(v + epsilon))
    dv: Backprop for variance, which is
        sum_over_rest(grad * g * (x - m)) * (-1/2) * (v + epsilon) ^ (-3/2)
    db: Backprop for beta, which is grad reduced in all except the
        last dimension.
    dg: Backprop for gamma, which is (grad * ((x - m) * rsqrt(v + epsilon)))
  """
  eps = op.get_attr("variance_epsilon")
  scale_after = op.get_attr("scale_after_normalization")
  d_input, d_mean, d_var, d_beta, d_gamma = (
      gen_nn_ops.batch_norm_with_global_normalization_grad(
          op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[4], grad, eps,
          scale_after))
  return d_input, d_mean, d_var, d_beta, d_gamma
def _BaseFusedBatchNormGrad(op, version, *grad):
  """Return the gradients for the inputs of a fused BatchNorm op.

  In training mode the fused gradient kernel is called directly with the
  op's reserve-space outputs. In freeze (inference) mode the population mean
  and variance inputs are used instead, and channels-first data is transposed
  to channels-last before the kernel call and transposed back afterwards.

  Args:
    op: The BatchNormOp for which we need to compute gradients.
    version: Integer indicating which version to use of the fused batch
      norm gradient: 2 selects the V3 kernel, 1 the V2 kernel, anything
      else the original kernel.
    *grad: An argument list for tensors of gradients wrt the outputs
      with grad[0] as grad_y.

  Returns:
    A five-element tuple (grad_x, grad_scale, grad_offset, None, None):
    grad_x: gradient for x, which is scale * rsqrt(variance + epsilon) *
            [grad_y - mean(grad_y) - (x - mean(x)) *
            mean(grad_y * (x - mean(x))) / (variance + epsilon)]
            in training mode; grad_y * scale * rsqrt(pop_variance + epsilon)
            in freeze mode.
    grad_scale: gradient for scale, which is sum(grad_y * (x - mean(x)) *
                rsqrt(variance + epsilon)) in training mode;
                sum(grad_y * (x - pop_mean) * rsqrt(pop_variance + epsilon))
                in freeze mode.
    grad_offset: gradient for offset, which is sum(grad_y) in training mode;
                 sum(grad_y) in freeze mode.
  """
  x = op.inputs[0]
  grad_y = grad[0]
  scale = op.inputs[1]
  epsilon = op.get_attr("epsilon")
  data_format = op.get_attr("data_format")
  is_training = op.get_attr("is_training")

  versioned_kernels = {
      1: gen_nn_ops.fused_batch_norm_grad_v2,
      2: gen_nn_ops.fused_batch_norm_grad_v3,
  }
  grad_fun = versioned_kernels.get(version, gen_nn_ops.fused_batch_norm_grad)

  if is_training:
    kernel_args = dict(
        y_backprop=grad_y,
        x=x,
        scale=scale,
        reserve_space_1=op.outputs[3],
        reserve_space_2=op.outputs[4],
        epsilon=epsilon,
        data_format=data_format,
        is_training=is_training)
    if version == 2:
      kernel_args["reserve_space_3"] = op.outputs[5]
    dx, dscale, doffset, _, _ = grad_fun(**kernel_args)
    return dx, dscale, doffset, None, None

  # Freeze mode: differentiate against the population statistics.
  pop_mean = op.inputs[3]
  pop_var = op.inputs[4]

  # Permutations that move the channel axis last, and their inverses.
  to_channels_last = {b"NCHW": [0, 2, 3, 1], b"NCDHW": [0, 2, 3, 4, 1]}
  to_channels_first = {b"NCHW": [0, 3, 1, 2], b"NCDHW": [0, 4, 1, 2, 3]}

  forward_perm = to_channels_last.get(data_format)
  if forward_perm is not None:
    x = array_ops.transpose(x, forward_perm)
    grad_y = array_ops.transpose(grad_y, forward_perm)

  if data_format in (b"NCHW", b"NHWC"):
    target_data_format = "NHWC"
  else:
    target_data_format = "NDHWC"

  kernel_args = dict(
      y_backprop=grad_y,
      x=x,
      scale=scale,
      reserve_space_1=pop_mean,
      reserve_space_2=pop_var,
      epsilon=epsilon,
      data_format=target_data_format,
      is_training=is_training)
  if version == 2:
    kernel_args["reserve_space_3"] = op.outputs[5]
  dx, dscale, doffset, _, _ = grad_fun(**kernel_args)

  backward_perm = to_channels_first.get(data_format)
  if backward_perm is not None:
    dx = array_ops.transpose(dx, backward_perm)
  return dx, dscale, doffset, None, None
@ops.RegisterGradient("FusedBatchNorm")
def _FusedBatchNormGrad(op, *grad):
  """Gradient function for FusedBatchNorm (kernel version 0)."""
  kernel_version = 0
  return _BaseFusedBatchNormGrad(op, kernel_version, *grad)
@ops.RegisterGradient("FusedBatchNormV2")
def _FusedBatchNormV2Grad(op, *grad):
  """Gradient function for FusedBatchNormV2 (kernel version 1)."""
  kernel_version = 1
  return _BaseFusedBatchNormGrad(op, kernel_version, *grad)
@ops.RegisterGradient("FusedBatchNormV3")
def _FusedBatchNormV3Grad(op, *grad):
  """Gradient function for FusedBatchNormV3 (kernel version 2)."""
  kernel_version = 2
  return _BaseFusedBatchNormGrad(op, kernel_version, *grad)
def _BatchNormGrad(grad_y,
                   x,
                   scale,
                   pop_mean,
                   pop_var,
                   epsilon,
                   data_format,
                   is_training=True):
  """Returns the gradients for the 3 inputs of BatchNorm.

  Args:
    grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y.
    x: A `Tensor` of 4 or 5 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
      is_training=False.
    pop_var: A `Tensor` of 1 dimension for the population variance. Only used
      when is_training=False.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. One of b"NHWC", b"NDHWC" or
      b"NCHW"; any other value is treated as b"NCDHW".
    is_training: A bool value to indicate the operation is for training
      (default) or inference.

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset. grad_x is cast back to the dtype of x.
  """
  x_dtype = x.dtype.base_dtype
  if x_dtype == dtypes.float16 or x_dtype == dtypes.bfloat16:
    # float16 math is too imprecise, so we do the batch norm gradient
    # computations in float32.
    x = math_ops.cast(x, dtypes.float32)
    grad_y = math_ops.cast(grad_y, dtypes.float32)
  if is_training:
    if data_format == b"NHWC":
      keepdims = False
      reduce_axis = [0, 1, 2]
    elif data_format == b"NDHWC":
      keepdims = False
      reduce_axis = [0, 1, 2, 3]
    elif data_format == b"NCHW":
      # Channels-first: keep reduced axes so statistics broadcast against x.
      keepdims = True
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(scale), 1, 1]
      scale = array_ops.reshape(scale, shape)
    else:
      keepdims = True
      reduce_axis = [0, 2, 3, 4]
      shape = [1, array_ops.size(scale), 1, 1, 1]
      scale = array_ops.reshape(scale, shape)
    mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
    mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
    var_x = math_ops.reduce_mean(
        math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
        reduce_axis,
        keepdims=keepdims)
    grad_y_offset = grad_y - mean_grad_y
    x_offset = x - mean_x
    mean = math_ops.reduce_mean(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
        grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
    grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    if data_format == b"NCHW" or data_format == b"NCDHW":
      # Squeeze only the reduced axes. An axis-less squeeze would also drop
      # the channel axis when there is exactly one channel, yielding a scalar
      # instead of a length-1 vector.
      grad_scale = array_ops.squeeze(grad_scale, axis=reduce_axis)
    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
  else:
    if data_format == b"NHWC":
      reduce_axis = [0, 1, 2]
    elif data_format == b"NDHWC":
      reduce_axis = [0, 1, 2, 3]
    elif data_format == b"NCHW":
      # Channels-first: reshape the per-channel vectors so they broadcast
      # along axis 1.
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(pop_mean), 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)
    else:
      reduce_axis = [0, 2, 3, 4]
      shape = [1, array_ops.size(pop_mean), 1, 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)

    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
    grad_scale = math_ops.reduce_sum(
        grad_y * (x - pop_mean) * var_rsqrt, axis=reduce_axis)
    grad_x = grad_y * scale * var_rsqrt
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
@ops.RegisterGradient("FusedBatchNormGrad")
def _FusedBatchNormGradGrad(op, *grad):
  """Returns the gradients for the 3 inputs of FusedBatchNormGrad.

  The first-order gradient is recomputed with `_BatchNormGrad` under a
  `GradientTape`, and the tape then differentiates it with the incoming
  gradients as output gradients.

  Args:
    op: The FusedBatchNormGradOp for which we need to compute gradients.
    *grad: An argument list for tensors of gradients wrt the outputs with
      grad[0] as grad_grad_x, grad[1] as grad_grad_scale, grad[2] as
      grad_grad_offset.

  Returns:
    A tuple (grad_grad_y, grad_x, grad_scale, None, None), where grad_grad_y
    is the gradient for grad_y, grad_x the gradient for x, grad_scale the
    gradient for scale.
  """
  data_format = op.get_attr("data_format")
  epsilon = op.get_attr("epsilon")
  is_training = op.get_attr("is_training")
  grad_y, x, scale, pop_mean, pop_var = (op.inputs[i] for i in range(5))
  upstream = [grad[0], grad[1], grad[2]]

  with backprop.GradientTape() as tape:
    for watched in (grad_y, x, scale):
      tape.watch(watched)
    first_order = _BatchNormGrad(
        grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
    d_x, d_scale, d_offset = first_order

  grad_grad_y, grad_x, grad_scale = tape.gradient(
      [d_x, d_scale, d_offset], [grad_y, x, scale], upstream)
  return grad_grad_y, grad_x, grad_scale, None, None
@ops.RegisterGradient("FusedBatchNormGradV2")
def _FusedBatchNormGradGradV2(op, *grad):
  """Gradient function for FusedBatchNormGradV2; same as FusedBatchNormGrad."""
  result = _FusedBatchNormGradGrad(op, *grad)
  return result
@ops.RegisterGradient("FusedBatchNormGradV3")
def _FusedBatchNormGradGradV3(op, *grad):
  """Gradient function for FusedBatchNormGradV3.

  Reuses `_FusedBatchNormGradGrad` and pads the result to six entries, the
  last three of which are `None`.
  """
  shared = _FusedBatchNormGradGrad(op, *grad)
  grad_grad_y, grad_x, grad_scale, _, _ = shared
  return grad_grad_y, grad_x, grad_scale, None, None, None
@ops.RegisterGradient("L2Loss")
def _L2LossGrad(op, grad):
  """Return the gradients for L2Loss.

  Args:
    op: The L2LossOp for which we need to generate gradients.
    grad: Tensor containing a single number.

  Returns:
    The gradient, which is (x * grad).
  """
  x = op.inputs[0]
  return x * grad
@ops.RegisterGradient("TopK")
@ops.RegisterGradient("TopKV2")
def _TopKGrad(op, grad, _):
  """Return the gradients for TopK.

  The incoming gradient is scattered back to the positions recorded in the
  op's indices output (`op.outputs[1]`); every other position is zero.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp.
    _: Gradient w.r.t. the op's second output; unused.

  Returns:
    A list of two tensors: the gradient w.r.t. the input, with the input's
    shape, and a scalar int32 zero for the second gradient slot.
  """
  top_indices = op.outputs[1]
  in_shape = array_ops.shape(op.inputs[0])
  ind_shape = array_ops.shape(top_indices)

  # int32 is not supported on GPU hence up-casting
  ind_shape_64 = math_ops.cast(ind_shape, dtypes.int64)
  ind_last_axis = array_ops.size(ind_shape) - 1
  ind_lastdim = array_ops.gather(ind_shape_64, ind_last_axis)

  # Flatten indices to 2D.
  flat_ind_shape = array_ops_stack.stack([-1, ind_lastdim])
  ind_2d = array_ops.reshape(top_indices, flat_ind_shape)

  in_shape_64 = math_ops.cast(in_shape, dtypes.int64)
  in_last_axis = array_ops.size(in_shape) - 1
  in_lastdim = array_ops.gather(in_shape_64, in_last_axis)
  outerdim = array_ops.shape(ind_2d)[0]

  # Compute linear indices (flattened to 1D): row r is offset by
  # r * in_lastdim.
  flat_extent = math_ops.cast(outerdim, dtypes.int64) * in_lastdim
  row_starts = math_ops.range(0, flat_extent, in_lastdim)
  row_offsets = math_ops.cast(
      array_ops.expand_dims(row_starts, -1), dtypes.int32)
  ind = array_ops.reshape(ind_2d + row_offsets, [-1])

  # Substitute grad to appropriate locations and fill the rest with zeros,
  # finally reshaping it to the original input shape.
  scatter_index = array_ops.expand_dims(ind, -1)
  flat_grad = array_ops.reshape(grad, [-1])
  flat_size = [math_ops.reduce_prod(in_shape)]
  scattered = array_ops.scatter_nd(scatter_index, flat_grad, flat_size)
  input_grad = array_ops.reshape(scattered, in_shape)
  return [input_grad, array_ops.zeros([], dtype=dtypes.int32)]
@ops.RegisterGradient("ApproxTopK")
def _ApproxTopKGradient(op, grad, _):
  """Scatters the gradient of the ApproxTopK values back into the input.

  Args:
    op: The ApproxTopK op for which gradients are needed.
    grad: The gradient w.r.t. the op's first output.
    _: The gradient w.r.t. the op's second output (the indices); unused.

  Returns:
    The result of scattering `grad` into a tensor of the input's shape, at the
    positions given by the op's indices output.
  """
  # Build the full N-d coordinate of every selected element so that scatter_nd
  # can place the gradients.
  #
  # Static evaluation is used wherever possible to reduce the runtime cost.
  # That said, this reads the static `.shape` instead of array_ops.shape, and
  # uses functools.reduce(operator.mul, ...) instead of math_ops.reduce_prod.
  idx_shape = op.outputs[1].shape
  lifted_idx_shape = idx_shape + [1]
  flat_shape_len = functools.reduce(operator.mul, idx_shape)
  rank = idx_shape.rank

  reduction_dim = op.get_attr("reduction_dimension")
  if reduction_dim < 0:
    # Normalize a negative axis to its non-negative equivalent.
    reduction_dim += rank

  def _coordinate_along(d):
    """Returns the coordinate along axis `d`, with a trailing unit axis."""
    if d != reduction_dim:
      # Along every other axis the coordinate is simply the position itself:
      # an iota laid out along axis d and broadcast to the full shape.
      iota_len = idx_shape[d]
      iota_shape = list(itertools.repeat(1, rank + 1))
      iota_shape[d] = iota_len
      iota = array_ops.reshape(math_ops.range(iota_len), iota_shape)
      return array_ops.broadcast_to(iota, lifted_idx_shape)
    # Along the reduction axis the coordinate is the op's indices output.
    return array_ops.reshape(op.outputs[1], lifted_idx_shape)

  coordinates = [_coordinate_along(d) for d in range(rank)]
  lifted_idx = array_ops.concat(coordinates, axis=rank)
  flat_idx = array_ops.reshape(lifted_idx, [flat_shape_len, rank])
  flat_grad = array_ops.reshape(grad, [flat_shape_len])
  return array_ops.scatter_nd(flat_idx, flat_grad, op.inputs[0].shape)
@ops.RegisterGradient("NthElement")
def _NthElementGrad(op, grad):
  """Computes the gradients for NthElement.

  Args:
    op: The NthElement op for which gradients are needed.
    grad: Tensor. The gradient w.r.t. the op's output.

  Returns:
    A list of two entries: the gradient w.r.t. the input, and None for the
    op's second input (N).
  """
  values = op.inputs[0]
  selected = op.outputs[0]

  # Mark every element along the last axis that equals the selected value.
  # When several elements tie, the gradient is divided equally between them.
  matches = math_ops.equal(array_ops.expand_dims(selected, -1), values)
  indicators = math_ops.cast(matches, grad.dtype)

  grad_expanded = array_ops.expand_dims(grad, -1)
  match_count = math_ops.reduce_sum(indicators, -1)
  num_selected = array_ops.expand_dims(match_count, -1)

  input_grad = math_ops.divide(indicators, num_selected) * grad_expanded
  return [input_grad, None]
def _MeanAggregator(inputs, segments):
  """Replaces each segment with its mean along the last axis.

  Every value in `inputs` is replaced by the mean of all values that share its
  segment id within the same row.

  Args:
    inputs: A 2-tensor. Aggregation is done over dimension 1.
    segments: A 2-tensor, same shape as `inputs`.

  Returns:
    The result, same shape and type as `inputs`.
  """
  # Split along dimension 0 into one piece per row; this uses the static size
  # of dimension 0.
  input_rows = array_ops.split(inputs, inputs.shape[0])
  segment_rows = array_ops.split(segments, segments.shape[0])

  def _row_segment_means(row, row_segments):
    """Returns `row` flattened, with each value replaced by its segment mean."""
    # Note that we do not use tf.math.segment_mean, as it has no TPU support.
    means = math_ops.unsorted_segment_mean(
        row, row_segments, num_segments=math_ops.reduce_max(row_segments) + 1)
    return array_ops.reshape(array_ops.gather(means, row_segments), [-1])

  per_row = [
      _row_segment_means(row, row_segments)
      for row, row_segments in zip(input_rows, segment_rows)
  ]
  return array_ops_stack.stack(per_row, axis=0)
# We have to register the gradients for these ops so that tensorflow will know
# how to differentiate them.
@ops.RegisterGradient("IsotonicRegression")
def _IsotonicRegressionGrad(op, grad_output, grad_segments):
  """Computes the gradient of the isotonic regression op.

  Args:
    op: The IsotonicRegression tensorflow op.
    grad_output: Tensor of incoming gradients with respect to the output.
    grad_segments: Tensor of incoming gradients with respect to the segments;
      ignored.

  Returns:
    A tensor, same size as `grad_output`, holding the gradient with respect to
    the input: `grad_output` averaged within each segment.
  """
  del grad_segments  # Discrete, non-differentiable.
  return _MeanAggregator(grad_output, op.outputs[1])