| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Gradients for operators defined in math_ops.py.""" |
| import numpy as np |
| |
| from tensorflow.python.compat import compat |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import gen_array_ops |
| from tensorflow.python.ops import gen_math_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import special_math_ops |
| |
| |
def _safe_shape_div(x, y):
  """Floor-divides `x` by `y` with the divisor clamped to at least 1.

  Both arguments are assumed to be non-negative. Because a zero divisor is
  replaced by 1, `0 / 0` evaluates to 0 instead of failing.
  """
  clamped_divisor = math_ops.maximum(y, 1)
  return x // clamped_divisor
| |
| |
@ops.RegisterGradient("ArgMax")
def _ArgMaxGrad(op, grad):
  """Gradient for ArgMax: no gradient is returned for either input."""
  del op, grad  # Neither argument is used.
  no_gradients = [None] * 2
  return no_gradients
| |
| |
@ops.RegisterGradient("ArgMin")
def _ArgMinGrad(op, grad):
  """Gradient for ArgMin: no gradient is returned for either input."""
  del op, grad  # Neither argument is used.
  no_gradients = [None] * 2
  return no_gradients
| |
| |
@ops.RegisterGradient("EuclideanNorm")
def _EuclideanNormGrad(op, grad):
  """Gradient for EuclideanNorm.

  The input gradient is computed as `x / (norm / grad)`. When the op dropped
  the reduced dimensions, the norm and the incoming gradient are first
  reshaped to the kept-dims shape so they broadcast against the input.

  Args:
    op: The EuclideanNorm operation; `op.inputs[0]` is the reduced tensor and
      `op.inputs[1]` the reduction indices.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair `(input_gradient, None)`; the reduction indices get no gradient.
  """
  x, axes = op.inputs[0], op.inputs[1]
  norm = op.outputs[0]

  keep_dims = op.get_attr("keep_dims")
  if not keep_dims:
    kept_dims_shape = math_ops.reduced_shape(array_ops.shape(x), axes)
    norm = array_ops.reshape(norm, kept_dims_shape)
    grad = array_ops.reshape(grad, kept_dims_shape)

  norm_over_grad = norm / grad
  return math_ops.truediv(x, norm_over_grad), None
| |
| |
def SmartBroadcastGradientArgs(x, y, grad):
  """Optimized version of `broadcast_gradient_args` that caches results.

  This implementation avoids creating `broadcast_gradient_args` ops in the case
  that the input shapes are fully defined, and provides hints to the calling
  code that can be used to avoid creating reduction and reshaping ops.

  Args:
    x: The left input tensor to a broadcasting binary op.
    y: The right input tensor to a broadcasting binary op.
    grad: The incoming gradient tensor for a broadcasting binary op.

  Returns:
    A pair of tuples, containing:
      * A 3-tuple of broadcast information for x, containing:
        * The shape of x (as a tuple or Tensor).
        * The reduction indices for x (as a tuple or Tensor).
        * A boolean, which if True, indicates that x's shape differs from
          grad's shape (and so x's gradient must be reduced and/or reshaped).
          It is always True when executing eagerly, when an argument is not
          an `ops.Tensor`, or when a static shape is not fully known.
      * A 3-tuple of broadcast information for y, containing the respective
        details for y.
  """
  # NOTE: It may be productive to apply these optimizations in the eager case
  # as well.
  if context.executing_eagerly() or not (
      isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
      and isinstance(grad, ops.Tensor)):
    sx = array_ops.shape(x)
    sy = array_ops.shape(y)
    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
    return (sx, rx, True), (sy, ry, True)

  # pylint: disable=protected-access
  x_shape_tuple = x._shape_tuple()
  y_shape_tuple = y._shape_tuple()
  grad_shape_tuple = grad._shape_tuple()
  # pylint: enable=protected-access

  # Without fully defined static shapes the reduction indices have to be
  # computed by ops at run time.
  if (x_shape_tuple is None or None in x_shape_tuple or
      y_shape_tuple is None or None in y_shape_tuple):
    sx = array_ops.shape_internal(x, optimize=False)
    sy = array_ops.shape_internal(y, optimize=False)
    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
    return (sx, rx, True), (sy, ry, True)

  x_needs_reduction = x_shape_tuple != grad_shape_tuple
  y_needs_reduction = y_shape_tuple != grad_shape_tuple

  # Get the default graph rather than relying on `x.graph`, `y.graph`, or
  # `grad.graph`, because these may be eager tensors.
  g = ops.get_default_graph()
  cache_key = (x_shape_tuple, y_shape_tuple)

  try:
    rx, ry = g._bcast_grad_args_cache[cache_key]  # pylint: disable=protected-access
    return (x_shape_tuple, rx, x_needs_reduction), (
        y_shape_tuple, ry, y_needs_reduction)
  except KeyError:
    rx, ry = array_ops.broadcast_gradient_args(x_shape_tuple, y_shape_tuple)
    # TODO(mrry): If this becomes a bottleneck, add a multi-output version of
    # `TF_TryEvaluateConstant()`.
    # Check for `None` before converting to a tuple: `tuple(None)` would raise
    # a TypeError first and the assertion would never be reached.
    rx_value = tensor_util.try_evaluate_constant(rx)
    assert rx_value is not None
    ry_value = tensor_util.try_evaluate_constant(ry)
    assert ry_value is not None
    rx_value = tuple(rx_value)
    ry_value = tuple(ry_value)
    g._bcast_grad_args_cache[cache_key] = (  # pylint: disable=protected-access
        rx_value, ry_value)

    return (x_shape_tuple, rx_value, x_needs_reduction), (
        y_shape_tuple, ry_value, y_needs_reduction)
| |
| |
_empty_tuple = ()


def _IsScalar(x):
  """Returns True if the static shape of `x` is known to be `()`.

  Args:
    x: An object exposing `_shape_tuple()`, which returns the static shape as
      a tuple, or `None` when the shape is unknown.

  Returns:
    A bool. An unknown shape (`None`) is not treated as scalar.
  """
  # Compare by value: an identity test would only work if the interpreter
  # hands out one shared object for every empty tuple.
  return x._shape_tuple() == _empty_tuple  # pylint: disable=protected-access
| |
| |
@ops.RegisterGradient("Sum")
def _SumGrad(op, grad):
  """Gradient for Sum.

  The incoming gradient is reshaped to the kept-dims output shape and then
  expanded back to the input shape. Three paths are used, depending on how
  much is known statically:

  * All dimensions are reduced and the input rank is known: reshape to all
    ones and tile.
  * Shape and axes are fully static and we are building a graph: reuse the
    kept-dims shape and tile multiples from a per-graph cache.
  * Otherwise: compute the shapes with ops and use `broadcast_to`.

  Args:
    op: The Sum operation; `op.inputs[0]` is the summed tensor and
      `op.inputs[1]` the reduction indices.
    grad: Gradient with respect to the op's output.

  Returns:
    A two-element list `[input_gradient, None]`; the reduction indices get no
    gradient.
  """
  # Fast path for when reducing to a scalar and ndims is known: adds only
  # Reshape and Tile ops (and possibly a Shape).
  input_0_shape = op.inputs[0]._shape_tuple()  # pylint: disable=protected-access
  if input_0_shape is not None:
    axes = tensor_util.constant_value(op.inputs[1])
    if axes is not None:
      rank = len(input_0_shape)
      if np.array_equal(axes, np.arange(rank)):  # Reduce all dims.
        if context.executing_eagerly():
          # Reuse the all-ones shape constant for this rank from the eager
          # context's cache, creating and storing it on a miss.
          ctx = context.context()
          new_shape = ctx.ones_rank_cache().get(rank)
          if new_shape is None:
            new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
            ctx.ones_rank_cache().put(rank, new_shape)
        else:
          new_shape = [1] * rank
        grad = array_ops.reshape(grad, new_shape)
        # If shape is not fully defined (but rank is), we use Shape.
        if None not in input_0_shape:
          input_shape = constant_op.constant(input_0_shape, dtype=dtypes.int32)
        else:
          input_shape = array_ops.shape(op.inputs[0])
        return [array_ops.tile(grad, input_shape), None]
      elif None not in input_0_shape and not context.executing_eagerly():
        # The shape and reduction indices are statically known, so we use a
        # graph-level cache to avoid recomputing `reduced_shape()` for each
        # invocation.
        graph = ops.get_default_graph()

        # Canonicalize `axes` to be a tuple of indices. The incoming
        # value may be a scalar or a vector, and may include negative indices.
        axes = tuple(axes.reshape(-1))

        try:
          output_shape_kept_dims, tile_scaling = graph._reduced_shape_cache[  # pylint: disable=protected-access
              (input_0_shape, axes)]
        except KeyError:

          # Compute and cache `output_shape_kept_dims` and `tile_scaling`.
          def EvaluateAsTuple(t):
            """Returns `t` as a tuple, constant-evaluating it if it is a TF type."""
            if tensor_util.is_tf_type(t):
              value = tensor_util.try_evaluate_constant(t)
              assert value is not None
            else:
              value = t
            return tuple(value)

          output_shape_kept_dims = EvaluateAsTuple(
              math_ops.reduced_shape(input_0_shape, axes))
          tile_scaling = EvaluateAsTuple(
              _safe_shape_div(input_0_shape, output_shape_kept_dims))
          graph._reduced_shape_cache[(input_0_shape, axes)] = (  # pylint:disable=protected-access
              output_shape_kept_dims, tile_scaling)

        grad = array_ops.reshape(grad, output_shape_kept_dims)
        return [array_ops.tile(grad, tile_scaling), None]

  # General path: shapes are computed by ops.
  input_shape = array_ops.shape(op.inputs[0])

  if not op.get_attr("keep_dims"):
    with ops.colocate_with(input_shape):
      # TODO(apassos) remove this once device placement for eager ops makes
      # more sense.
      output_shape_kept_dims = math_ops.reduced_shape(input_shape,
                                                      op.inputs[1])
    grad = array_ops.reshape(grad, output_shape_kept_dims)
  return [array_ops.broadcast_to(grad, input_shape), None]
| |
| |
def _MinOrMaxGrad(op, grad):
  """Shared gradient for the Min and Max reductions.

  The incoming gradient flows only to the input elements equal to the reduced
  value. If several elements tie, the gradient is divided evenly among them.

  Args:
    op: The Min or Max operation; `op.inputs[0]` is the reduced tensor and
      `op.inputs[1]` the reduction indices.
    grad: Gradient with respect to the op's output.

  Returns:
    A two-element list `[input_gradient, None]`.
  """
  x, axes = op.inputs[0], op.inputs[1]
  input_shape = array_ops.shape(x)
  reduced = op.outputs[0]

  if op.get_attr("keep_dims"):
    kept_dims_shape = array_ops.shape(reduced)
  else:
    # Restore the reduced dimensions (as size 1) so that the output and the
    # gradient broadcast against the input.
    kept_dims_shape = math_ops.reduced_shape(input_shape, axes)
    reduced = array_ops.reshape(reduced, kept_dims_shape)
    grad = array_ops.reshape(grad, kept_dims_shape)

  # 1 where the input equals the selected minimum/maximum, 0 elsewhere.
  is_extremum = math_ops.cast(math_ops.equal(reduced, x), grad.dtype)
  # Number of tied extrema along the reduced dimensions.
  extremum_count = array_ops.reshape(
      math_ops.reduce_sum(is_extremum, axes), kept_dims_shape)

  return [math_ops.divide(is_extremum, extremum_count) * grad, None]
| |
| |
@ops.RegisterGradient("Max")
def _MaxGrad(op, grad):
  """Gradient for Max; delegates to `_MinOrMaxGrad`."""
  input_gradients = _MinOrMaxGrad(op, grad)
  return input_gradients
| |
| |
@ops.RegisterGradient("Min")
def _MinGrad(op, grad):
  """Gradient for Min; delegates to `_MinOrMaxGrad`."""
  input_gradients = _MinOrMaxGrad(op, grad)
  return input_gradients
| |
| |
@ops.RegisterGradient("Mean")
def _MeanGrad(op, grad):
  """Gradient for Mean.

  Computed as the Sum gradient divided by the number of elements that were
  averaged into each output element.

  Args:
    op: The Mean operation; `op.inputs[0]` is the averaged tensor and
      `op.inputs[1]` the reduction indices.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair `(input_gradient, None)`.
  """
  sum_grad = _SumGrad(op, grad)[0]
  static_input_shape = op.inputs[0]._shape_tuple()  # pylint: disable=protected-access
  static_output_shape = op.outputs[0]._shape_tuple()  # pylint: disable=protected-access

  shapes_fully_known = (
      static_input_shape is not None and static_output_shape is not None and
      None not in static_input_shape and None not in static_output_shape)

  if shapes_fully_known:
    # Both shapes are static, so the factor is a compile-time constant.
    num_inputs = np.prod(static_input_shape)
    num_outputs = np.prod(static_output_shape)
    factor = constant_op.constant(
        num_inputs // max(num_outputs, 1), dtype=sum_grad.dtype)
  else:
    # Multiply the sizes of the reduced dimensions at run time, after mapping
    # negative reduction indices into [0, rank).
    dynamic_input_shape = array_ops.shape(op.inputs[0])
    input_rank = array_ops.size(dynamic_input_shape)
    axes = (op.inputs[1] + input_rank) % input_rank
    factor = math_ops.reduce_prod(array_ops.gather(dynamic_input_shape, axes))

  return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None
| |
| |
@ops.RegisterGradient("Prod")
def _ProdGrad(op, grad):
  """Gradient for Prod.

  For every input element the gradient is the incoming gradient times the
  product of all other elements that were reduced together with it. That
  product is built from two exclusive cumulative products (one forward, one
  reversed) instead of dividing the output by the element.

  Args:
    op: The Prod operation; `op.inputs[0]` is the reduced tensor and
      `op.inputs[1]` the reduction indices.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair `(input_gradient, None)`; the reduction indices get no gradient.
  """
  # The gradient can be expressed by dividing the product by each entry of the
  # input tensor, but this approach can't deal with zeros in the input.
  # Here, we avoid this problem by composing the output as a product of two
  # cumprod operations.

  input_shape = array_ops.shape(op.inputs[0])
  # Reshape reduction indices for the case where the parameter is a scalar
  reduction_indices = array_ops.reshape(op.inputs[1], [-1])

  # Expand grad to full input shape
  if not op.get_attr("keep_dims"):
    output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
    grad = array_ops.reshape(grad, output_shape_kept_dims)

  grad = array_ops.broadcast_to(grad, input_shape)

  # Pack all reduced dimensions into a single one, so we can perform the
  # cumprod ops. If the reduction dims list is empty, it defaults to float32,
  # so we need to cast here.  We put all the shape-related ops on CPU to avoid
  # copying back and forth, and since listdiff is CPU only.
  with ops.device("/cpu:0"):
    rank = array_ops.rank(op.inputs[0])
    # Map negative reduction indices into [0, rank).
    reduction_indices = (reduction_indices + rank) % rank
    reduced = math_ops.cast(reduction_indices, dtypes.int32)
    idx = math_ops.range(0, rank)
    # `other` holds the dimensions that are not reduced.
    other, _ = gen_array_ops.list_diff(idx, reduced, dtypes.int32)
    # Permutation that moves the reduced dimensions to the front.
    perm = array_ops.concat([reduced, other], 0)
    reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
    other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
  permuted = array_ops.transpose(op.inputs[0], perm)
  permuted_shape = array_ops.shape(permuted)
  # 2-D view: axis 0 runs over the reduced elements, axis 1 over the rest.
  reshaped = array_ops.reshape(permuted, (reduced_num, other_num))

  # Calculate product, leaving out the current entry
  left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
  right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
  # For complex inputs, the gradient is in the conjugate direction.
  y = array_ops.reshape(
      math_ops.conj(left) * math_ops.conj(right), permuted_shape)

  # Invert the transpose and reshape operations.
  # Make sure to set the statically known shape information through a reshape.
  out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
  return array_ops.reshape(out, input_shape), None
| |
| |
@ops.RegisterGradient("SegmentSum")
def _SegmentSumGrad(op, grad):
  """Gradient for SegmentSum: gathers `grad` by the op's second input."""
  segment_ids = op.inputs[1]
  data_grad = array_ops.gather(grad, segment_ids)
  return data_grad, None
| |
| |
@ops.RegisterGradient("SegmentMean")
def _SegmentMeanGrad(op, grad):
  """Gradient for SegmentMean.

  Each segment's gradient is divided by a per-segment count (obtained by
  segment-summing a tensor of ones) and then gathered back by segment id.
  """
  data, segment_ids = op.inputs[0], op.inputs[1]
  input_rank = array_ops.rank(data)

  # Shape of the ones tensor: the shape of the segment ids followed by
  # `rank - 1` trailing dimensions of size 1.
  ids_shape = array_ops.shape(segment_ids)
  trailing_ones = array_ops.ones(
      array_ops.expand_dims(input_rank - 1, 0), dtype=dtypes.int32)
  ones_shape = array_ops.concat([ids_shape, trailing_ones], 0)

  ones = array_ops.ones(ones_shape, dtype=grad.dtype)
  segment_counts = math_ops.segment_sum(ones, segment_ids)
  scaled_grad = math_ops.divide(grad, segment_counts)
  return array_ops.gather(scaled_grad, segment_ids), None
| |
| |
@ops.RegisterGradient("SparseSegmentSum")
def _SparseSegmentSumGrad(op, grad):
  """Gradient for SparseSegmentSum.

  Uses `sparse_segment_sum_grad` once the forward-compatibility date has
  passed; otherwise falls back to a gather followed by `unsorted_segment_sum`.
  Only the first input receives a gradient.
  """
  first_dim = array_ops.shape(op.inputs[0])[0]
  if not compat.forward_compatible(2021, 6, 10):
    gathered = array_ops.gather(grad, op.inputs[2])
    data_grad = math_ops.unsorted_segment_sum(gathered, op.inputs[1], first_dim)
  else:
    data_grad = math_ops.sparse_segment_sum_grad(grad, op.inputs[1],
                                                 op.inputs[2], first_dim)
  return (data_grad, None, None)
| |
| |
@ops.RegisterGradient("SparseSegmentSumWithNumSegments")
def _SparseSegmentSumWithNumSegmentsGrad(op, grad):
  """Gradient for SparseSegmentSumWithNumSegments.

  Uses `sparse_segment_sum_grad` once the forward-compatibility date has
  passed; otherwise falls back to a gather followed by `unsorted_segment_sum`.
  Only the first of the four inputs receives a gradient.
  """
  first_dim = array_ops.shape(op.inputs[0])[0]
  if not compat.forward_compatible(2021, 6, 10):
    gathered = array_ops.gather(grad, op.inputs[2])
    data_grad = math_ops.unsorted_segment_sum(gathered, op.inputs[1], first_dim)
  else:
    data_grad = math_ops.sparse_segment_sum_grad(grad, op.inputs[1],
                                                 op.inputs[2], first_dim)
  return (data_grad, None, None, None)
| |
| |
@ops.RegisterGradient("SparseSegmentMean")
def _SparseSegmentMeanGrad(op, grad):
  """Gradient for SparseSegmentMean; only the first input gets a gradient."""
  first_dim = array_ops.shape(op.inputs[0])[0]
  data_grad = math_ops.sparse_segment_mean_grad(grad, op.inputs[1],
                                                op.inputs[2], first_dim)
  return (data_grad, None, None)
| |
| |
@ops.RegisterGradient("SparseSegmentMeanWithNumSegments")
def _SparseSegmentMeanWithNumSegmentsGrad(op, grad):
  """Gradient for SparseSegmentMeanWithNumSegments.

  Only the first of the four inputs receives a gradient.
  """
  first_dim = array_ops.shape(op.inputs[0])[0]
  data_grad = math_ops.sparse_segment_mean_grad(grad, op.inputs[1],
                                                op.inputs[2], first_dim)
  return (data_grad, None, None, None)
| |
| |
@ops.RegisterGradient("SparseSegmentSqrtN")
def _SparseSegmentSqrtNGrad(op, grad):
  """Gradient for SparseSegmentSqrtN; only the first input gets a gradient."""
  first_dim = array_ops.shape(op.inputs[0])[0]
  data_grad = math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1],
                                                  op.inputs[2], first_dim)
  return (data_grad, None, None)
| |
| |
@ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments")
def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad):
  """Gradient for SparseSegmentSqrtNWithNumSegments.

  Only the first of the four inputs receives a gradient.
  """
  first_dim = array_ops.shape(op.inputs[0])[0]
  data_grad = math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1],
                                                  op.inputs[2], first_dim)
  return (data_grad, None, None, None)
| |
| |
def _SegmentMinOrMaxGrad(op, grad):
  """Shared gradient for SegmentMin and SegmentMax.

  Each segment's incoming gradient is split evenly among the elements of that
  segment that equal the segment's output; every other element gets zero.

  Args:
    op: The SegmentMin or SegmentMax operation.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair `(data_gradient, None)`.
  """
  data, segment_ids = op.inputs[0], op.inputs[1]
  zero_grads = array_ops.zeros_like(data, dtype=data.dtype)

  # Mark the elements that attain their segment's minimum or maximum.
  output_per_element = array_ops.gather(op.outputs[0], segment_ids)
  is_selected = math_ops.equal(data, output_per_element)
  selected_as_number = math_ops.cast(is_selected, grad.dtype)
  num_selected = math_ops.segment_sum(selected_as_number, segment_ids)

  # Share each segment's gradient equally among its selected elements.
  shared_grads = math_ops.divide(grad, num_selected)
  shared_grads_per_element = array_ops.gather(shared_grads, segment_ids)
  data_grad = array_ops.where_v2(is_selected, shared_grads_per_element,
                                 zero_grads)
  return data_grad, None
| |
| |
@ops.RegisterGradient("SegmentMin")
def _SegmentMinGrad(op, grad):
  """Gradient for SegmentMin; delegates to `_SegmentMinOrMaxGrad`."""
  input_gradients = _SegmentMinOrMaxGrad(op, grad)
  return input_gradients
| |
| |
@ops.RegisterGradient("SegmentMax")
def _SegmentMaxGrad(op, grad):
  """Gradient for SegmentMax; delegates to `_SegmentMinOrMaxGrad`."""
  input_gradients = _SegmentMinOrMaxGrad(op, grad)
  return input_gradients
| |
| |
@ops.RegisterGradient("SegmentProd")
def _SegmentProdGrad(op, grad):
  """Gradient for SegmentProd.

  For each segment the gradient could be obtained by dividing the segment's
  product by each of its elements, but that breaks down when the input
  contains zeros. A cumprod-based approach is not usable either, because
  segments may differ in their number of elements. Three cases are handled:
    1) The segment has no zeros: dividing by the input is safe.
    2) The segment has exactly one zero: every element gets a zero gradient
       except the zero element itself, whose gradient is the product of the
       remaining entries of the segment.
    3) The segment has two or more zeros: the gradient is zero for all of its
       elements.

  Args:
    op: The SegmentProd operation.
    grad: Gradient with respect to the op's output.

  Returns:
    A pair `(data_gradient, None)`.
  """
  data, segment_ids = op.inputs[0], op.inputs[1]

  zero_mask = math_ops.equal(data, 0)
  zeros_per_segment = gen_math_ops.segment_sum(
      math_ops.cast(zero_mask, dtype=dtypes.int32), segment_ids)

  # Case 3: zero out the gradient of segments holding more than one zero.
  has_multiple_zeros = math_ops.greater(zeros_per_segment, 1)
  grad = array_ops.where_v2(has_multiple_zeros, array_ops.zeros_like(grad),
                            grad)

  # Substitute ones for zeros and take the per-segment product of the result.
  non_zero_data = array_ops.where_v2(zero_mask, array_ops.ones_like(data),
                                     data)
  non_zero_prod = gen_math_ops.segment_prod(non_zero_data, segment_ids)

  prod_per_element = array_ops.gather(op.outputs[0], segment_ids)
  non_zero_prod_per_element = array_ops.gather(non_zero_prod, segment_ids)
  prod_without_element = prod_per_element / non_zero_data

  # Zero elements take the product of the other entries (case 2); all other
  # elements take product / element (case 1).
  partial_derivative = array_ops.where_v2(zero_mask, non_zero_prod_per_element,
                                          prod_without_element)
  grad_per_element = array_ops.gather(grad, segment_ids)
  return grad_per_element * partial_derivative, None
| |
| |
def _GatherDropNegatives(params,
                         ids,
                         zero_clipped_indices=None,
                         is_positive=None):
  """Helper for unsorted segment ops: gathers `params`, zeroing negative ids.

  Entries of `params` are gathered for ids that are `>= 0`; positions whose id
  is negative yield zeros. The clipped indices and the boolean mask are
  returned as well so that a later call can pass them back in and skip
  recomputing them.

  Args:
    params: Tensor to gather from.
    ids: Segment ids. Only read for an argument that is not supplied below.
    zero_clipped_indices: Optional precomputed `maximum(ids, 0)`.
    is_positive: Optional precomputed mask, true where `ids >= 0`, already
      expanded to the shape of the gathered values.

  Returns:
    A tuple `(gathered, zero_clipped_indices, is_positive)`.
  """
  if zero_clipped_indices is None:
    zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids))
  gathered = array_ops.gather(params, zero_clipped_indices)

  if is_positive is None:
    non_negative = math_ops.greater_equal(ids, 0)
    # Expand the mask to the full shape of `gathered`: append size-1
    # dimensions for the trailing axes, then AND with an all-true tensor.
    mask_shape = array_ops.shape(non_negative)
    num_trailing_dims = (
        array_ops.rank(gathered) - array_ops.rank(non_negative))
    trailing_ones = array_ops.ones([num_trailing_dims], dtype=mask_shape.dtype)
    expanded_shape = array_ops.concat([mask_shape, trailing_ones], axis=0)
    non_negative = array_ops.reshape(non_negative, expanded_shape)
    is_positive = (
        non_negative & array_ops.ones_like(gathered, dtype=dtypes.bool))

  # Values gathered for negative ids are replaced with 0.
  zeros = array_ops.zeros_like(gathered)
  masked = array_ops.where_v2(is_positive, gathered, zeros)
  return (masked, zero_clipped_indices, is_positive)
| |
| |
def _UnsortedSegmentMinOrMaxGrad(op, grad):
  """Shared gradient for UnsortedSegmentMin and UnsortedSegmentMax.

  Each segment's incoming gradient is split evenly among the elements that
  equal the segment's output and have a non-negative segment id; every other
  element gets zero.

  Args:
    op: The UnsortedSegmentMin or UnsortedSegmentMax operation.
    grad: Gradient with respect to the op's output.

  Returns:
    A triple `(data_gradient, None, None)`.
  """
  data, segment_ids, num_segments = op.inputs[0], op.inputs[1], op.inputs[2]

  # Mark the elements that attain their segment's minimum or maximum.
  output_per_element, zero_clipped_indices, is_positive = (
      _GatherDropNegatives(op.outputs[0], segment_ids))
  matches_output = math_ops.equal(data, output_per_element)
  is_selected = math_ops.logical_and(matches_output, is_positive)
  num_selected = math_ops.unsorted_segment_sum(
      math_ops.cast(is_selected, grad.dtype), segment_ids, num_segments)

  # Share each segment's gradient equally among its selected elements. The
  # clipped indices and mask from the first call are reused here.
  shared_grads = math_ops.divide(grad, num_selected)
  shared_grads_per_element = _GatherDropNegatives(
      shared_grads, None, zero_clipped_indices, is_positive)[0]
  zero_grads = array_ops.zeros_like(shared_grads_per_element)
  data_grad = array_ops.where_v2(is_selected, shared_grads_per_element,
                                 zero_grads)
  return data_grad, None, None
| |
| |
@ops.RegisterGradient("UnsortedSegmentSum")
def _UnsortedSegmentSumGrad(op, grad):
  """Gradient for UnsortedSegmentSum.

  Gathers `grad` by segment id via `_GatherDropNegatives`, which yields zeros
  for negative ids. Only the first input receives a gradient.
  """
  data_grad, _, _ = _GatherDropNegatives(grad, op.inputs[1])
  return data_grad, None, None
| |
| |
@ops.RegisterGradient("UnsortedSegmentMax")
def _UnsortedSegmentMaxGrad(op, grad):
  """Gradient for UnsortedSegmentMax; see `_UnsortedSegmentMinOrMaxGrad`."""
  input_gradients = _UnsortedSegmentMinOrMaxGrad(op, grad)
  return input_gradients
| |
| |
@ops.RegisterGradient("UnsortedSegmentMin")
def _UnsortedSegmentMinGrad(op, grad):
  """Gradient for UnsortedSegmentMin; see `_UnsortedSegmentMinOrMaxGrad`."""
  input_gradients = _UnsortedSegmentMinOrMaxGrad(op, grad)
  return input_gradients
| |
| |
@ops.RegisterGradient("UnsortedSegmentProd")
def _UnsortedSegmentProdGrad(op, grad):
  """Gradient for UnsortedSegmentProd.

  The gradient can be expressed for each segment by dividing the segment's
  product by each element of the segment input tensor, but this approach can't
  deal with zeros in the input.
  Unlike reduce_prod we can't use cumprod here as individual segments may have
  a different number of elements. Therefore we consider three cases:
    1) A segment input contains no zeros and we can safely divide by the input
       tensor.
    2) A segment contains exactly one zero. Then the gradient of each input of
       the segment is zero except for the 0-input, there the gradient is
       the product of the remaining segment entries.
    3) A segment contains at least two zeros. The gradient is zero for all
       segment inputs.

  Args:
    op: The UnsortedSegmentProd operation.
    grad: Gradient with respect to the op's output.

  Returns:
    A triple `(data_gradient, None, None)`.
  """
  data, segment_ids, num_segments = op.inputs[0], op.inputs[1], op.inputs[2]

  # Note that unsorted_segment_sum will filter out the negative indices,
  # so we don't need to do a logical_and with is_positive here
  is_zero = math_ops.equal(data, 0)
  num_zeros = gen_math_ops.unsorted_segment_sum(
      math_ops.cast(is_zero, dtype=dtypes.int32), segment_ids, num_segments)
  # handle case 3 and set the gradient to 0 for segments with more than one
  # 0 as input
  grad = array_ops.where_v2(
      math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
  # replace all zeros with ones and compute the unsorted_segment_prod
  non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(data), data)
  non_zero_prod = gen_math_ops.unsorted_segment_prod(non_zero_data,
                                                     segment_ids, num_segments)
  # clip the indices for gather to be positive
  zero_clipped_indices = math_ops.maximum(segment_ids,
                                          array_ops.zeros_like(segment_ids))
  gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices)
  gathered_non_zero_prod = array_ops.gather(non_zero_prod, zero_clipped_indices)
  # Divide by `non_zero_data` rather than the raw input, as _SegmentProdGrad
  # does, so that no nan/inf is produced. The two differ only where `is_zero`
  # holds, and those entries are replaced by the `where_v2` below.
  prod_divided_by_el = gathered_prod / non_zero_data
  # Now fetch the individual results for segments containing 0 and those that
  # don't. is_zero will also fetch results for entries with negative index
  # but the following gather_drop_negatives sets the corresponding entry in
  # grad to 0 for these
  partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod,
                                          prod_divided_by_el)
  gathered_grad = _GatherDropNegatives(grad, segment_ids,
                                       zero_clipped_indices)[0]
  return gathered_grad * partial_derivative, None, None
| |
| |
@ops.RegisterGradient("Abs")
def _AbsGrad(op, grad):
  """Returns grad * sign(x)."""
  sign_of_input = math_ops.sign(op.inputs[0])
  return grad * sign_of_input
| |
| |
@ops.RegisterGradient("Neg")
def _NegGrad(_, grad):
  """Gradient for Neg: the negated incoming gradient."""
  negated_grad = -grad
  return negated_grad
| |
| |
@ops.RegisterGradient("Inv")
def _InvGrad(op, grad):
  """Returns -grad * (1 / x^2), computed by `reciprocal_grad`."""
  reciprocal = op.outputs[0]  # The forward result, 1 / x.
  return gen_math_ops.reciprocal_grad(reciprocal, grad)
| |
| |
@ops.RegisterGradient("Reciprocal")
def _ReciprocalGrad(op, grad):
  """Returns -grad * (1 / x^2), computed by `reciprocal_grad`."""
  reciprocal = op.outputs[0]  # The forward result, 1 / x.
  return gen_math_ops.reciprocal_grad(reciprocal, grad)
| |
| |
@ops.RegisterGradient("InvGrad")
def _InvGradGrad(op, grad):
  """Gradients for the two inputs `(a, b)` of InvGrad.

  The op's output is y = -b * conj(a)^2.

  Returns:
    A pair `(grad_a, grad_b)`.
  """
  upstream = op.inputs[1]
  with ops.control_dependencies([grad]):
    conj_a = math_ops.conj(op.inputs[0])
    conj_grad = math_ops.conj(grad)
    grad_a = conj_grad * -2.0 * upstream * conj_a
    grad_b = gen_math_ops.reciprocal_grad(conj_a, grad)
    return grad_a, grad_b
| |
| |
@ops.RegisterGradient("ReciprocalGrad")
def _ReciprocalGradGrad(op, grad):
  """Gradients for the two inputs `(a, b)` of ReciprocalGrad.

  The op's output is y = -b * conj(a)^2.

  Returns:
    A pair `(grad_a, grad_b)`.
  """
  upstream = op.inputs[1]
  with ops.control_dependencies([grad]):
    conj_a = math_ops.conj(op.inputs[0])
    conj_grad = math_ops.conj(grad)
    grad_a = conj_grad * -2.0 * upstream * conj_a
    grad_b = gen_math_ops.reciprocal_grad(conj_a, grad)
    return grad_a, grad_b
| |
| |
@ops.RegisterGradient("Square")
def _SquareGrad(op, grad):
  """Returns grad * (2 * conj(x))."""
  # The control dependency keeps 2 * x from being computed too early.
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    two = constant_op.constant(2.0, dtype=conj_x.dtype)
    doubled_x = math_ops.multiply(conj_x, two)
    return math_ops.multiply(grad, doubled_x)
| |
| |
@ops.RegisterGradient("Sqrt")
def _SqrtGrad(op, grad):
  """Gradient for Sqrt, computed by `sqrt_grad` from the op's output."""
  root = op.outputs[0]  # The forward result, x^(1/2).
  return gen_math_ops.sqrt_grad(root, grad)
| |
| |
@ops.RegisterGradient("SqrtGrad")
def _SqrtGradGrad(op, grad):
  """Gradients for the two inputs `(a, b)` of SqrtGrad.

  The op's output is y = 0.5 * b / conj(a).

  Returns:
    A pair `(-conj(grad / a) * y, 0.5 * grad / a)`.
  """
  sqrt_value = op.inputs[0]
  forward_output = op.outputs[0]
  with ops.control_dependencies([grad]):
    grad_over_a = grad / sqrt_value
    grad_a = -math_ops.conj(grad_over_a) * forward_output  # pylint: disable=invalid-unary-operand-type
    grad_b = 0.5 * grad_over_a
    return grad_a, grad_b
| |
| |
@ops.RegisterGradient("Rsqrt")
def _RsqrtGrad(op, grad):
  """Returns -0.5 * grad * conj(y)^3, computed by `rsqrt_grad`."""
  inverse_root = op.outputs[0]  # The forward result, x^(-1/2).
  return gen_math_ops.rsqrt_grad(inverse_root, grad)
| |
| |
@ops.RegisterGradient("RsqrtGrad")
def _RsqrtGradGrad(op, grad):
  """Returns backprop gradient for f(a,b) = -0.5 * b * conj(a)^3.

  Here `a = x^{-1/2}` is the op's first input and `b`, its second input, is
  the backprop gradient for `a`.

  Returns:
    A pair `(grad_a, grad_b)`.
  """
  rsqrt_value, upstream = op.inputs[0], op.inputs[1]
  with ops.control_dependencies([grad]):
    conj_a = math_ops.conj(rsqrt_value)
    conj_grad = math_ops.conj(grad)
    scaled = -1.5 * conj_grad * upstream
    grad_a = scaled * math_ops.square(conj_a)
    grad_b = gen_math_ops.rsqrt_grad(conj_a, grad)
    return grad_a, grad_b
| |
| |
@ops.RegisterGradient("Exp")
def _ExpGrad(op, grad):
  """Returns grad * exp(x), reusing the op's output for exp(x)."""
  with ops.control_dependencies([grad]):
    conj_output = math_ops.conj(op.outputs[0])  # The output is e^x.
    return grad * conj_output
| |
| |
@ops.RegisterGradient("Expm1")
def _Expm1Grad(op, grad):
  """Returns grad * exp(x)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    exp_x = math_ops.exp(conj_x)
    return grad * exp_x
| |
| |
@ops.RegisterGradient("Log")
def _LogGrad(op, grad):
  """Returns grad * (1/x)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    inverse_x = math_ops.reciprocal(conj_x)
    return grad * inverse_x
| |
| |
@ops.RegisterGradient("Log1p")
def _Log1pGrad(op, grad):
  """Returns grad * (1/(1 + x))."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    inverse_one_plus_x = math_ops.reciprocal(1 + conj_x)
    return grad * inverse_one_plus_x
| |
| |
@ops.RegisterGradient("Xlogy")
def _XLogyGrad(op, grad):
  """Returns gradient of xlogy(x, y) with respect to x and y.

  The partial for x is `xlogy(x != 0, y)` and the partial for y is
  `xdivy(x, y)`. Each is multiplied by `grad`, summed over the broadcast
  dimensions and reshaped to the shape of its input.
  """
  x, y = op.inputs[0], op.inputs[1]
  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  x_axes, y_axes = gen_array_ops.broadcast_gradient_args(x_shape, y_shape)
  with ops.control_dependencies([grad]):
    zero = math_ops.cast(0., dtype=x.dtype)
    x_is_nonzero = math_ops.cast(math_ops.not_equal(x, zero), dtype=x.dtype)
    partial_x = gen_math_ops.xlogy(x_is_nonzero, y)
    partial_y = gen_math_ops.xdivy(x, y)
    x_grad = array_ops.reshape(
        math_ops.reduce_sum(partial_x * grad, x_axes), x_shape)
    y_grad = array_ops.reshape(
        math_ops.reduce_sum(partial_y * grad, y_axes), y_shape)
    return (x_grad, y_grad)
| |
| |
@ops.RegisterGradient("Xlog1py")
def _XLog1pyGrad(op, grad):
  """Returns gradient of xlog1py(x, y) with respect to x and y.

  The partial for x is `xlog1py(x != 0, y)` and the partial for y is
  `xdivy(x, y + 1)`. Each is multiplied by `grad`, summed over the broadcast
  dimensions and reshaped to the shape of its input.
  """
  x, y = op.inputs[0], op.inputs[1]
  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  x_axes, y_axes = gen_array_ops.broadcast_gradient_args(x_shape, y_shape)
  with ops.control_dependencies([grad]):
    zero = math_ops.cast(0., dtype=x.dtype)
    x_is_nonzero = math_ops.cast(math_ops.not_equal(x, zero), dtype=x.dtype)
    partial_x = gen_math_ops.xlog1py(x_is_nonzero, y)
    partial_y = gen_math_ops.xdivy(x, y + 1.)
    x_grad = array_ops.reshape(
        math_ops.reduce_sum(partial_x * grad, x_axes), x_shape)
    y_grad = array_ops.reshape(
        math_ops.reduce_sum(partial_y * grad, y_axes), y_shape)
    return (x_grad, y_grad)
| |
| |
@ops.RegisterGradient("Xdivy")
def _XDivyGrad(op, grad):
  """Returns gradient of xdivy(x, y) with respect to x and y.

  The partial for x is `xdivy(x != 0, y)` and the partial for y is
  `xdivy(-x, y**2)`. Each is multiplied by `grad`, summed over the broadcast
  dimensions and reshaped to the shape of its input.
  """
  x, y = op.inputs[0], op.inputs[1]
  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  x_axes, y_axes = gen_array_ops.broadcast_gradient_args(x_shape, y_shape)
  with ops.control_dependencies([grad]):
    zero = math_ops.cast(0., dtype=x.dtype)
    x_is_nonzero = math_ops.cast(math_ops.not_equal(x, zero), dtype=x.dtype)
    partial_x = gen_math_ops.xdivy(x_is_nonzero, y)
    partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2)
    x_grad = array_ops.reshape(
        math_ops.reduce_sum(partial_x * grad, x_axes), x_shape)
    y_grad = array_ops.reshape(
        math_ops.reduce_sum(partial_y * grad, y_axes), y_shape)
    return (x_grad, y_grad)
| |
| |
@ops.RegisterGradient("Sinh")
def _SinhGrad(op, grad):
  """Returns grad * cosh(x)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    cosh_x = math_ops.cosh(conj_x)
    return grad * cosh_x
| |
| |
@ops.RegisterGradient("Cosh")
def _CoshGrad(op, grad):
  """Returns grad * sinh(x)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    sinh_x = math_ops.sinh(conj_x)
    return grad * sinh_x
| |
| |
@ops.RegisterGradient("Tanh")
def _TanhGrad(op, grad):
  """Returns grad * (1 - tanh(x) * tanh(x)), computed by `tanh_grad`."""
  with ops.control_dependencies([grad]):
    conj_tanh = math_ops.conj(op.outputs[0])  # The output is tanh(x).
    return gen_math_ops.tanh_grad(conj_tanh, grad)
| |
| |
@ops.RegisterGradient("Asinh")
def _AsinhGrad(op, grad):
  """Returns grad * 1/cosh(y), where y is the op's output."""
  with ops.control_dependencies([grad]):
    conj_output = math_ops.conj(op.outputs[0])
    cosh_y = math_ops.cosh(conj_output)
    return grad / cosh_y
| |
| |
@ops.RegisterGradient("Acosh")
def _AcoshGrad(op, grad):
  """Returns grad * 1/sinh(y), where y is the op's output."""
  with ops.control_dependencies([grad]):
    conj_output = math_ops.conj(op.outputs[0])
    sinh_y = math_ops.sinh(conj_output)
    return grad / sinh_y
| |
| |
@ops.RegisterGradient("Atanh")
def _AtanhGrad(op, grad):
  """Returns grad * 1/ (1 - x^2)."""
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(op.inputs[0])
    x_squared = math_ops.square(conj_x)
    one = constant_op.constant(1, dtype=grad.dtype)
    denominator = math_ops.subtract(one, x_squared)
    return grad * math_ops.reciprocal(denominator)
| |
| |
@ops.RegisterGradient("TanhGrad")
def _TanhGradGrad(op, grad):
  """Gradients for the two inputs of TanhGrad.

  With `a` and `b` the conjugates of the op's first and second input, the
  result is `(grad * -2.0 * b * a, tanh_grad(a, grad))`.
  """
  with ops.control_dependencies([grad]):
    conj_a = math_ops.conj(op.inputs[0])
    conj_b = math_ops.conj(op.inputs[1])
    grad_a = grad * -2.0 * conj_b * conj_a
    grad_b = gen_math_ops.tanh_grad(conj_a, grad)
    return grad_a, grad_b
| |
| |
@ops.RegisterGradient("Erf")
def _ErfGrad(op, grad):
  """Gradient of erf: grad * 2/sqrt(pi) * exp(-x**2)."""
  operand = op.inputs[0]
  # The scale constant is created outside the control-dependency scope.
  scale = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    conj_operand = math_ops.conj(operand)
    scaled_grad = grad * scale
    return scaled_grad * math_ops.exp(-math_ops.square(conj_operand))
| |
| |
@ops.RegisterGradient("Erfc")
def _ErfcGrad(op, grad):
  """Gradient of erfc: -grad * 2/sqrt(pi) * exp(-x**2)."""
  operand = op.inputs[0]
  # The sign is folded into the constant, created outside the `with` scope.
  negative_scale = constant_op.constant(
      -2 / np.sqrt(np.pi), dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    conj_operand = math_ops.conj(operand)
    scaled_grad = grad * negative_scale
    return scaled_grad * math_ops.exp(-math_ops.square(conj_operand))
| |
| |
@ops.RegisterGradient("Erfinv")
def _ErfinvGrad(op, grad):
  """Gradient of erfinv: grad * sqrt(pi) / 2 * exp(erfinv(x)**2).

  The forward output `op.outputs[0]` is reused as erfinv(x).
  """
  half_root_pi = constant_op.constant(np.sqrt(np.pi) / 2, dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    scaled_grad = grad * half_root_pi
    squared_output = math_ops.square(op.outputs[0])
    return scaled_grad * math_ops.exp(squared_output)
| |
| |
@ops.RegisterGradient("Ndtri")
def _NdtriGrad(op, grad):
  """Gradient of ndtri: grad * sqrt(2 * pi) * exp(ndtri(x)**2 / 2).

  The forward output `op.outputs[0]` is reused as ndtri(x).
  """
  sqrt_two_pi = constant_op.constant(np.sqrt(2 * np.pi), dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    scaled_grad = grad * sqrt_two_pi
    half_squared_output = math_ops.square(op.outputs[0]) / 2.
    return scaled_grad * math_ops.exp(half_squared_output)
| |
| |
@ops.RegisterGradient("Lgamma")
def _LgammaGrad(op, grad):
  """Gradient of lgamma: `grad` scaled by digamma of the conjugated input."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    conj_operand = math_ops.conj(operand)
    digamma_term = math_ops.digamma(conj_operand)
    return grad * digamma_term
| |
| |
@ops.RegisterGradient("Digamma")
def _DigammaGrad(op, grad):
  """Gradient of digamma: `grad` times polygamma(1, x) on the conjugated x."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    conj_operand = math_ops.conj(operand)
    # The polygamma order is a constant 1 in the input's dtype.
    order_one = array_ops.constant(1, dtype=conj_operand.dtype)
    trigamma = math_ops.polygamma(order_one, conj_operand)
    return grad * trigamma
| |
| |
@ops.RegisterGradient("Dawsn")
def _DawsnGrad(op, grad):
  """Gradient of dawsn(x): grad * (1 - 2 * x * dawsn(x))."""
  operand = op.inputs[0]
  dawsn_value = op.outputs[0]
  with ops.control_dependencies([grad]):
    two_x_y = 2 * operand * dawsn_value
    return grad * (1. - two_x_y)
| |
| |
@ops.RegisterGradient("Expint")
def _ExpintGrad(op, grad):
  """Gradient of expint(x): grad * exp(x) / x."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    numerator = grad * math_ops.exp(operand)
    return numerator / operand
| |
| |
@ops.RegisterGradient("FresnelCos")
def _FresnelCosGrad(op, grad):
  """Gradient of fresnel_cos(x): grad * cos(pi / 2 * x**2)."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    phase = (np.pi / 2.) * math_ops.square(operand)
    return grad * math_ops.cos(phase)
| |
| |
@ops.RegisterGradient("FresnelSin")
def _FresnelSinGrad(op, grad):
  """Gradient of fresnel_sin(x): grad * sin(pi / 2 * x**2)."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    phase = (np.pi / 2.) * math_ops.square(operand)
    return grad * math_ops.sin(phase)
| |
| |
@ops.RegisterGradient("Spence")
def _SpenceGrad(op, grad):
  """Gradient of spence(x): grad * log(x) / (1 - x), with -1 used at x == 1."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    log_ratio = math_ops.log(operand) / (1 - operand)
    at_one = math_ops.equal(operand, 1.)
    # log(x) / (1 - x) is 0 / 0 at x == 1, so that entry is replaced by -1.
    value_at_one = -array_ops.ones_like(operand)  # pylint: disable=invalid-unary-operand-type
    derivative = array_ops.where(at_one, value_at_one, log_ratio)
    return grad * derivative
| |
| |
@ops.RegisterGradient("BesselI0")
def _BesselI0Grad(op, grad):
  """Gradient of bessel_i0(x): grad * bessel_i1(x)."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    return grad * special_math_ops.bessel_i1(operand)
| |
| |
@ops.RegisterGradient("BesselI0e")
def _BesselI0eGrad(op, grad):
  """Gradient of bessel_i0e(x): grad * (bessel_i1e(x) - sign(x) * i0e(x))."""
  operand = op.inputs[0]
  i0e_value = op.outputs[0]
  with ops.control_dependencies([grad]):
    i1e_value = special_math_ops.bessel_i1e(operand)
    signed_output = math_ops.sign(operand) * i0e_value
    return grad * (i1e_value - signed_output)
| |
| |
@ops.RegisterGradient("BesselI1")
def _BesselI1Grad(op, grad):
  """Compute gradient of bessel_i1(x) with respect to its argument.

  Uses d/dx I1(x) = I0(x) - I1(x) / x, reusing the forward output as I1(x).

  Args:
    op: The BesselI1 op; `op.inputs[0]` is x and `op.outputs[0]` is I1(x).
    grad: Gradient with respect to the op's output.

  Returns:
    `grad` multiplied by the derivative of bessel_i1 at x.
  """
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # For x = 0, the correct gradient is 0.5: I1(x) = x / 2 + O(x^3), so
    # I0(0) - lim I1(x) / x = 1 - 0.5. This matches the value imputed in the
    # BesselI1e gradient, whose derivative at 0 is the same.
    # The main branch gives NaN at 0 because of the division by x, so we
    # impute the gradient manually.
    # An alternative solution is to express the gradient via bessel_i0 and
    # bessel_i2, but the latter is not yet implemented in Eigen.
    dy_dx = array_ops.where_v2(
        math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
        special_math_ops.bessel_i0(x) - math_ops.div(y, x))
    return grad * dy_dx
| |
| |
@ops.RegisterGradient("BesselI1e")
def _BesselI1eGrad(op, grad):
  """Gradient of bessel_i1e(x), with the value 0.5 substituted at x == 0."""
  operand = op.inputs[0]
  i1e_value = op.outputs[0]
  with ops.control_dependencies([grad]):
    # The general formula divides by x and so yields NaN at x == 0, where the
    # correct gradient is 0.5; that entry is filled in explicitly.
    # Expressing the gradient through bessel_i0e and bessel_i2e would avoid
    # the division, but bessel_i2e is not yet implemented in Eigen.
    is_zero = math_ops.equal(operand, 0.)
    value_at_zero = math_ops.cast(0.5, operand.dtype)
    i0e_value = special_math_ops.bessel_i0e(operand)
    general_case = i0e_value - i1e_value * (
        math_ops.sign(operand) + math_ops.reciprocal(operand))
    derivative = array_ops.where_v2(is_zero, value_at_zero, general_case)
    return grad * derivative
| |
| |
@ops.RegisterGradient("BesselK0")
def _BesselK0Grad(op, grad):
  """Gradient of bessel_k0(x): grad * -bessel_k1(x)."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    negated_k1 = -special_math_ops.bessel_k1(operand)
    return grad * negated_k1
| |
| |
@ops.RegisterGradient("BesselK0e")
def _BesselK0eGrad(op, grad):
  """Gradient of bessel_k0e(x): grad * (k0e(x) - bessel_k1e(x))."""
  operand = op.inputs[0]
  k0e_value = op.outputs[0]
  with ops.control_dependencies([grad]):
    k1e_value = special_math_ops.bessel_k1e(operand)
    return grad * (k0e_value - k1e_value)
| |
| |
@ops.RegisterGradient("BesselK1")
def _BesselK1Grad(op, grad):
  """Gradient of bessel_k1(x): grad * (-bessel_k0(x) - k1(x) / x)."""
  operand = op.inputs[0]
  k1_value = op.outputs[0]
  with ops.control_dependencies([grad]):
    # No special case at 0: the NaN produced there is acceptable because the
    # derivative is undefined at 0.
    negated_k0 = -special_math_ops.bessel_k0(operand)
    ratio = math_ops.div(k1_value, operand)
    return grad * (negated_k0 - ratio)
| |
| |
@ops.RegisterGradient("BesselK1e")
def _BesselK1eGrad(op, grad):
  """Gradient of bessel_k1e(x): grad * (k1e(x) * (1 - 1/x) - bessel_k0e(x))."""
  operand = op.inputs[0]
  k1e_value = op.outputs[0]
  with ops.control_dependencies([grad]):
    # No special case at 0: the NaN produced there is acceptable because the
    # derivative is undefined at 0.
    scaled_output = k1e_value * (1. - math_ops.reciprocal(operand))
    k0e_value = special_math_ops.bessel_k0e(operand)
    return grad * (scaled_output - k0e_value)
| |
| |
@ops.RegisterGradient("BesselJ0")
def _BesselJ0Grad(op, grad):
  """Gradient of bessel_j0(x): grad * -bessel_j1(x)."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    negated_j1 = -special_math_ops.bessel_j1(operand)
    return grad * negated_j1
| |
| |
@ops.RegisterGradient("BesselJ1")
def _BesselJ1Grad(op, grad):
  """Gradient of bessel_j1(x), with the value 0.5 substituted at x == 0."""
  operand = op.inputs[0]
  j1_value = op.outputs[0]
  with ops.control_dependencies([grad]):
    # The general formula j0(x) - j1(x) / x divides by x and so yields NaN at
    # x == 0, where the correct gradient is 0.5; that entry is filled in
    # explicitly.
    # NOTE(review): the gradient could instead be written with bessel_j0 and
    # a second-order bessel_j2, presumably unavailable here -- confirm.
    is_zero = math_ops.equal(operand, 0.)
    value_at_zero = math_ops.cast(0.5, operand.dtype)
    general_case = (
        special_math_ops.bessel_j0(operand) - math_ops.div(j1_value, operand))
    derivative = array_ops.where_v2(is_zero, value_at_zero, general_case)
    return grad * derivative
| |
| |
@ops.RegisterGradient("BesselY0")
def _BesselY0Grad(op, grad):
  """Gradient of bessel_y0(x): grad * -bessel_y1(x)."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    negated_y1 = -special_math_ops.bessel_y1(operand)
    return grad * negated_y1
| |
| |
@ops.RegisterGradient("BesselY1")
def _BesselY1Grad(op, grad):
  """Gradient of bessel_y1(x): grad * (bessel_y0(x) - y1(x) / x)."""
  operand = op.inputs[0]
  y1_value = op.outputs[0]
  with ops.control_dependencies([grad]):
    # No special case at 0: the NaN produced there is acceptable because the
    # derivative is undefined at 0.
    y0_value = special_math_ops.bessel_y0(operand)
    ratio = math_ops.div(y1_value, operand)
    return grad * (y0_value - ratio)
| |
| |
@ops.RegisterGradient("Igamma")
def _IgammaGrad(op, grad):
  """Gradient of igamma(a, x) with respect to both a and x.

  Returns:
    A pair (grad_a, grad_x); each is summed over its broadcast axes and
    reshaped to the shape of the corresponding input.
  """
  a, x = op.inputs[0], op.inputs[1]
  shape_a = array_ops.shape(a)
  shape_x = array_ops.shape(x)
  axes_a, axes_x = gen_array_ops.broadcast_gradient_args(shape_a, shape_x)

  with ops.control_dependencies([grad]):
    partial_a = gen_math_ops.igamma_grad_a(a, x)
    # The x-partial is assembled in log space and exponentiated once, because
    # Gamma(a) and Gamma'(a) can grow large.
    log_partial_x = -x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a)
    partial_x = math_ops.exp(log_partial_x)
    grad_a = array_ops.reshape(
        math_ops.reduce_sum(partial_a * grad, axes_a), shape_a)
    grad_x = array_ops.reshape(
        math_ops.reduce_sum(partial_x * grad, axes_x), shape_x)
    return (grad_a, grad_x)
| |
| |
@ops.RegisterGradient("Igammac")
def _IgammacGrad(op, grad):
  """Gradient of igammac(a, x) = 1 - igamma(a, x) with respect to a and x.

  Computed by negating each of the two gradients returned by `_IgammaGrad`.
  """
  return tuple(-partial for partial in _IgammaGrad(op, grad))
| |
| |
@ops.RegisterGradient("Betainc")
def _BetaincGrad(op, grad):
  """Gradient of betainc(a, b, x) with respect to x only.

  Returns:
    A triple (None, None, grad_x); no gradient is produced for a or b.
  """
  # TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b
  a, b, x = op.inputs

  # Two cases: x is a scalar and a/b are same-shaped tensors, or vice
  # versa; so it is sufficient to check against shape(a).
  shape_a = array_ops.shape(a)
  shape_x = array_ops.shape(x)
  _, axes_x = gen_array_ops.broadcast_gradient_args(shape_a, shape_x)

  # Work in log space and exponentiate once, because the individual terms
  # can grow large.
  log_gamma_sum = gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b)
  log_beta = log_gamma_sum - gen_math_ops.lgamma(a + b)
  # xlog1py and xlogy are used so that the derivative tends to zero in one of
  # the tails when a is 1. or b is 1.
  log_one_minus_x_term = math_ops.xlog1py(b - 1, -x)
  log_x_term = math_ops.xlogy(a - 1, x)
  partial_x = math_ops.exp(log_one_minus_x_term + log_x_term - log_beta)

  grad_x = array_ops.reshape(
      math_ops.reduce_sum(partial_x * grad, axes_x), shape_x)
  return (
      None,  # da
      None,  # db
      grad_x)
| |
| |
@ops.RegisterGradient("Zeta")
def _ZetaGrad(op, grad):
  """Gradient of zeta(x, q) with respect to q only.

  Returns:
    A pair (None, grad_q); grad_q is summed over broadcast axes and reshaped
    to the shape of q.
  """
  # TODO(tillahoffmann): Add derivative with respect to x
  x, q = op.inputs[0], op.inputs[1]
  # Reduction axes needed to undo broadcasting between x and q.
  shape_x = array_ops.shape(x)
  shape_q = array_ops.shape(q)
  unused_axes_x, axes_q = gen_array_ops.broadcast_gradient_args(
      shape_x, shape_q)
  with ops.control_dependencies([grad]):
    conj_x = math_ops.conj(x)
    conj_q = math_ops.conj(q)
    partial_q = -conj_x * math_ops.zeta(conj_x + 1, conj_q)  # pylint: disable=invalid-unary-operand-type
    grad_q = array_ops.reshape(
        math_ops.reduce_sum(partial_q * grad, axes_q), shape_q)
    return (None, grad_q)
| |
| |
@ops.RegisterGradient("Polygamma")
def _PolygammaGrad(op, grad):
  """Gradient of psi(n, x) with respect to x only.

  Returns:
    A pair (None, grad_x); grad_x is summed over broadcast axes and reshaped
    to the shape of x.
  """
  # TODO(tillahoffmann): Add derivative with respect to n
  n, x = op.inputs[0], op.inputs[1]
  # Reduction axes needed to undo broadcasting between n and x.
  shape_n = array_ops.shape(n)
  shape_x = array_ops.shape(x)
  unused_axes_n, axes_x = gen_array_ops.broadcast_gradient_args(
      shape_n, shape_x)
  with ops.control_dependencies([grad]):
    conj_n = math_ops.conj(n)
    conj_x = math_ops.conj(x)
    # The x-derivative of polygamma of order n is polygamma of order n + 1.
    partial_x = math_ops.polygamma(conj_n + 1, conj_x)
    grad_x = array_ops.reshape(
        math_ops.reduce_sum(partial_x * grad, axes_x), shape_x)
    return (None, grad_x)
| |
| |
@ops.RegisterGradient("Sigmoid")
def _SigmoidGrad(op, grad):
  """Gradient of sigmoid, i.e. grad * sigmoid(x) * (1 - sigmoid(x)).

  The forward output y = sigmoid(x) is reused (conjugated) and handed to the
  dedicated `sigmoid_grad` op together with `grad`.
  """
  sigmoid_out = op.outputs[0]
  with ops.control_dependencies([grad]):
    conj_out = math_ops.conj(sigmoid_out)
    return gen_math_ops.sigmoid_grad(conj_out, grad)
| |
| |
@ops.RegisterGradient("SigmoidGrad")
def _SigmoidGradGrad(op, grad):
  """Gradient of the SigmoidGrad op with respect to both of its inputs.

  Returns:
    A pair: `g - 2.0 * g * conj(inputs[0])` with `g = grad * conj(inputs[1])`
    for the first input, and `sigmoid_grad(conj(inputs[0]), grad)` for the
    second.
  """
  with ops.control_dependencies([grad]):
    first_conj = math_ops.conj(op.inputs[0])
    second_conj = math_ops.conj(op.inputs[1])
    weighted_grad = grad * second_conj
    grad_first = weighted_grad - 2.0 * weighted_grad * first_conj
    grad_second = gen_math_ops.sigmoid_grad(first_conj, grad)
    return grad_first, grad_second
| |
| |
@ops.RegisterGradient("Sign")
def _SignGrad(op, _):
  """Gradient of sign: zeros shaped like the input; the incoming gradient is unused."""
  return array_ops.zeros_like(op.inputs[0])
| |
| |
@ops.RegisterGradient("Sin")
def _SinGrad(op, grad):
  """Gradient of sin: `grad` scaled by cos of the conjugated input."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    conj_operand = math_ops.conj(operand)
    cos_term = math_ops.cos(conj_operand)
    return grad * cos_term
| |
| |
@ops.RegisterGradient("Cos")
def _CosGrad(op, grad):
  """Gradient of cos: negated `grad` scaled by sin of the conjugated input."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    conj_operand = math_ops.conj(operand)
    negated_grad = -grad
    return negated_grad * math_ops.sin(conj_operand)
| |
| |
@ops.RegisterGradient("Tan")
def _TanGrad(op, grad):
  """Gradient of tan: grad * sec^2(x), i.e. grad / cos^2(x)."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    conj_operand = math_ops.conj(operand)
    cos_term = math_ops.cos(conj_operand)
    secant = math_ops.reciprocal(cos_term)
    secant_squared = math_ops.square(secant)
    return secant_squared * grad
| |
| |
@ops.RegisterGradient("Asin")
def _AsinGrad(op, grad):
  """Gradient of asin: grad * 1 / sqrt(1 - x^2), using the conjugated input."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    conj_operand = math_ops.conj(operand)
    squared = math_ops.square(conj_operand)
    unity = constant_op.constant(1, dtype=grad.dtype)
    root = math_ops.sqrt(math_ops.subtract(unity, squared))
    return grad * math_ops.reciprocal(root)
| |
| |
@ops.RegisterGradient("Acos")
def _AcosGrad(op, grad):
  """Gradient of acos: grad * -1 / sqrt(1 - x^2), using the conjugated input."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    conj_operand = math_ops.conj(operand)
    squared = math_ops.square(conj_operand)
    unity = constant_op.constant(1, dtype=grad.dtype)
    root = math_ops.sqrt(math_ops.subtract(unity, squared))
    inverse_root = math_ops.reciprocal(root)
    return -grad * inverse_root
| |
| |
@ops.RegisterGradient("Atan")
def _AtanGrad(op, grad):
  """Gradient of atan: grad * 1 / (1 + x^2), using the conjugated input."""
  operand = op.inputs[0]
  with ops.control_dependencies([grad]):
    conj_operand = math_ops.conj(operand)
    squared = math_ops.square(conj_operand)
    unity = constant_op.constant(1, dtype=grad.dtype)
    denominator = math_ops.add(unity, squared)
    return grad * math_ops.reciprocal(denominator)
| |
| |
@ops.RegisterGradient("Atan2")
def _Atan2Grad(op, grad):
  """Gradient of atan2(y, x).

  Returns:
    A pair (grad_y, grad_x) equal to grad * x / (x^2 + y^2) and
    grad * -y / (x^2 + y^2), each summed over broadcast axes and reshaped to
    its input's shape when a reduction is required.
  """
  y, x = op.inputs[0], op.inputs[1]
  with ops.control_dependencies([grad]):
    x_args, y_args = SmartBroadcastGradientArgs(x, y, grad)
    shape_x, axes_x, reduce_x = x_args
    shape_y, axes_y, reduce_y = y_args

    # Shared factor grad / (x^2 + y^2) used by both partial derivatives.
    scaled_grad = grad / (math_ops.square(x) + math_ops.square(y))

    grad_x = -y * scaled_grad
    if reduce_x:
      grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, axes_x), shape_x)

    grad_y = x * scaled_grad
    if reduce_y:
      grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, axes_y), shape_y)
    # Inputs are ordered (y, x), so the y gradient comes first.
    return grad_y, grad_x
| |
| |
@ops.RegisterGradient("AddN")
def _AddNGrad(op, grad):
  """Gradient of AddN: every input receives the same `grad` object."""
  # AddN does not broadcast, so the gradient is passed through unchanged.
  return [grad for _ in range(len(op.inputs))]
| |
| |
def _ShapesFullySpecifiedAndEqual(x, y, grad):
  """Returns True iff x, y and grad share one fully-known static shape.

  The shapes are read via each argument's `_shape_tuple()`. The result is
  False when the shape of x is None, when the three shapes differ, or when
  any dimension is None.
  """
  # pylint: disable=protected-access
  x_shape = x._shape_tuple()
  y_shape = y._shape_tuple()
  grad_shape = grad._shape_tuple()
  # pylint: enable=protected-access
  if x_shape != y_shape or x_shape != grad_shape:
    return False
  if x_shape is None:
    return False
  return None not in x_shape
| |
| |
@ops.RegisterGradient("Add")
@ops.RegisterGradient("AddV2")
def _AddGrad(op, grad):
  """Gradient for Add and AddV2.

  Each input receives `grad`, summed over its broadcast axes and reshaped to
  the input's shape when a reduction is required. Inputs listed in
  `op.skip_input_indices` (when the op exposes that attribute) get None.
  """
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    y_is_skipped = skip_input_indices is not None and 1 in skip_input_indices
    if y_is_skipped and _IsScalar(y):
      return grad, None
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  x = op.inputs[0]
  # Fast path: identical, fully-known static shapes need no reduction.
  if isinstance(grad, ops.Tensor) and _ShapesFullySpecifiedAndEqual(
      x, y, grad):
    return grad, grad
  x_args, y_args = SmartBroadcastGradientArgs(x, y, grad)
  sx, rx, must_reduce_x = x_args
  sy, ry, must_reduce_y = y_args
  skipped = () if skip_input_indices is None else skip_input_indices

  gx = None
  if 0 not in skipped:
    if must_reduce_x:
      gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
    else:
      gx = grad

  gy = None
  if 1 not in skipped:
    if must_reduce_y:
      gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)
    else:
      gy = grad
  return (gx, gy)
| |
| |
@ops.RegisterGradient("Sub")
def _SubGrad(op, grad):
  """Gradient for Sub.

  The first input receives `grad` and the second `-grad`, each summed over
  its broadcast axes and reshaped to the input's shape when a reduction is
  required. Inputs listed in `op.skip_input_indices` (when the op exposes
  that attribute) get None.
  """
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    y_is_skipped = skip_input_indices is not None and 1 in skip_input_indices
    if y_is_skipped and _IsScalar(y):
      return grad, None
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  x = op.inputs[0]
  # Fast path: identical, fully-known static shapes need no reduction.
  if isinstance(grad, ops.Tensor) and _ShapesFullySpecifiedAndEqual(
      x, y, grad):
    return grad, -grad
  x_args, y_args = SmartBroadcastGradientArgs(x, y, grad)
  sx, rx, must_reduce_x = x_args
  sy, ry, must_reduce_y = y_args
  skipped = () if skip_input_indices is None else skip_input_indices

  gx = None
  if 0 not in skipped:
    if must_reduce_x:
      gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
    else:
      gx = grad

  gy = None
  if 1 not in skipped:
    # The subtrahend's gradient is the negated upstream gradient.
    negated_grad = -grad
    if must_reduce_y:
      gy = array_ops.reshape(math_ops.reduce_sum(negated_grad, ry), sy)
    else:
      gy = negated_grad
  return (gx, gy)
| |
| |
@ops.RegisterGradient("Mul")
def _MulGrad(op, grad):
  """Gradient for Mul.

  Each input receives `grad` times the conjugate of the other input, summed
  over its broadcast axes and reshaped to the input's shape when a reduction
  is required. Inputs listed in `op.skip_input_indices` (when the op exposes
  that attribute) get None.
  """
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    y_is_skipped = skip_input_indices is not None and 1 in skip_input_indices
    if y_is_skipped and _IsScalar(y):
      return gen_math_ops.mul(grad, math_ops.conj(y)), None
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  x = op.inputs[0]
  # Fast path for int32/float32 gradients with identical, fully-known static
  # shapes: no conjugation and no reduction.
  if (isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad) and
      grad.dtype in (dtypes.int32, dtypes.float32)):
    return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
  assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)

  x_args, y_args = SmartBroadcastGradientArgs(x, y, grad)
  sx, rx, must_reduce_x = x_args
  sy, ry, must_reduce_y = y_args
  conj_x = math_ops.conj(x)
  conj_y = math_ops.conj(y)
  skipped = () if skip_input_indices is None else skip_input_indices

  gx = None
  if 0 not in skipped:
    gx = gen_math_ops.mul(grad, conj_y)
    if must_reduce_x:
      gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)

  gy = None
  if 1 not in skipped:
    gy = gen_math_ops.mul(conj_x, grad)
    if must_reduce_y:
      gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)
  return (gx, gy)
| |
| |
@ops.RegisterGradient("MulNoNan")
def _MulNoNanGrad(op, grad):
  """Gradient for MulNoNan, itself computed with `mul_no_nan`.

  Returns:
    A pair (mul_no_nan(grad, y), mul_no_nan(x, grad)), each summed over its
    broadcast axes and reshaped to its input's shape unless all static
    shapes are fully known and equal.
  """
  x, y = op.inputs[0], op.inputs[1]
  # Fast path: identical, fully-known static shapes need no reduction.
  if isinstance(grad, ops.Tensor) and _ShapesFullySpecifiedAndEqual(
      x, y, grad):
    return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad)
  assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
  shape_x = array_ops.shape(x)
  shape_y = array_ops.shape(y)
  axes_x, axes_y = gen_array_ops.broadcast_gradient_args(shape_x, shape_y)
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(gen_math_ops.mul_no_nan(grad, y), axes_x), shape_x)
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), axes_y), shape_y)
  return (grad_x, grad_y)
| |
| |
@ops.RegisterGradient("Div")
def _DivGrad(op, grad):
  """Gradient for Div: (grad / y, grad * (-x / y) / y) on conjugated inputs.

  Each result is summed over its broadcast axes and reshaped to the shape of
  the corresponding input.
  """
  x, y = op.inputs[0], op.inputs[1]
  shape_x = array_ops.shape(x)
  shape_y = array_ops.shape(y)
  axes_x, axes_y = gen_array_ops.broadcast_gradient_args(shape_x, shape_y)
  conj_x = math_ops.conj(x)
  conj_y = math_ops.conj(y)
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(math_ops.divide(grad, conj_y), axes_x), shape_x)
  # pylint: disable=invalid-unary-operand-type
  neg_x_over_y_squared = math_ops.divide(
      math_ops.divide(-conj_x, conj_y), conj_y)
  # pylint: enable=invalid-unary-operand-type
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(grad * neg_x_over_y_squared, axes_y), shape_y)
  return (grad_x, grad_y)
| |
| |
@ops.RegisterGradient("FloorDiv")
def _FloorDivGrad(_, unused_grad):
  """Gradient for FloorDiv: None for both inputs."""
  return (None, None)
| |
| |
@ops.RegisterGradient("FloorMod")
def _FloorModGrad(op, grad):
  """Gradient for FloorMod: grad * (1, -floor(x / y)).

  Each result is summed over its broadcast axes and reshaped to the shape of
  the corresponding input.
  """
  conj_x = math_ops.conj(op.inputs[0])
  conj_y = math_ops.conj(op.inputs[1])

  shape_x = array_ops.shape(conj_x)
  shape_y = array_ops.shape(conj_y)
  axes_x, axes_y = gen_array_ops.broadcast_gradient_args(shape_x, shape_y)
  quotient = math_ops.floor_div(conj_x, conj_y)
  grad_x = array_ops.reshape(math_ops.reduce_sum(grad, axes_x), shape_x)
  weighted_grad = grad * math_ops.negative(quotient)
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(weighted_grad, axes_y), shape_y)
  return grad_x, grad_y
| |
| |
@ops.RegisterGradient("TruncateDiv")
def _TruncateDivGrad(_, unused_grad):
  """Gradient for TruncateDiv: None for both inputs."""
  return (None, None)
| |
| |
@ops.RegisterGradient("RealDiv")
def _RealDivGrad(op, grad):
  """Gradient for RealDiv: (grad / y, grad * (-x / y) / y) on conjugated inputs.

  Each result is summed over its broadcast axes and reshaped to the shape of
  the corresponding input.
  """
  x, y = op.inputs[0], op.inputs[1]
  shape_x = array_ops.shape(x)
  shape_y = array_ops.shape(y)
  axes_x, axes_y = gen_array_ops.broadcast_gradient_args(shape_x, shape_y)
  conj_x = math_ops.conj(x)
  conj_y = math_ops.conj(y)
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(math_ops.realdiv(grad, conj_y), axes_x), shape_x)
  neg_x_over_y_squared = math_ops.realdiv(
      math_ops.realdiv(-conj_x, conj_y), conj_y)  # pylint: disable=invalid-unary-operand-type
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(grad * neg_x_over_y_squared, axes_y), shape_y)
  return (grad_x, grad_y)
| |
| |
@ops.RegisterGradient("DivNoNan")
def _DivNoNanGrad(op, grad):
  """Gradient for DivNoNan, itself computed with `div_no_nan`.

  Mirrors the ordinary division gradient, (grad / y, grad * (-x / y) / y) on
  conjugated inputs, with every division done through `div_no_nan`. Each
  result is summed over its broadcast axes and reshaped to the shape of the
  corresponding input.
  """
  x, y = op.inputs[0], op.inputs[1]
  shape_x = array_ops.shape(x)
  shape_y = array_ops.shape(y)
  axes_x, axes_y = gen_array_ops.broadcast_gradient_args(shape_x, shape_y)
  conj_x = math_ops.conj(x)
  conj_y = math_ops.conj(y)
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(math_ops.div_no_nan(grad, conj_y), axes_x), shape_x)
  neg_x_over_y_squared = math_ops.div_no_nan(
      math_ops.div_no_nan(-conj_x, conj_y), conj_y)  # pylint: disable=invalid-unary-operand-type
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(grad * neg_x_over_y_squared, axes_y), shape_y)
  return (grad_x, grad_y)
| |
| |
@ops.RegisterGradient("Pow")
def _PowGrad(op, grad):
  """Gradient for Pow: grad * (y * x^(y-1), z * log(x)) with z = x^y.

  Each result is summed over its broadcast axes and reshaped to its input's
  shape when a reduction is required. Inputs listed in
  `op.skip_input_indices` (when the op exposes that attribute) get None.
  """
  x, y = op.inputs[0], op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    # TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the
    # constant `1` into a single constant op.
    y_is_skipped = skip_input_indices is not None and 1 in skip_input_indices
    if y_is_skipped and _IsScalar(y):
      x = math_ops.conj(x)
      y = math_ops.conj(y)
      return grad * y * math_ops.pow(x, y - 1), None

  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass

  x_args, y_args = SmartBroadcastGradientArgs(x, y, grad)
  sx, rx, must_reduce_x = x_args
  sy, ry, must_reduce_y = y_args
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  skipped = () if skip_input_indices is None else skip_input_indices

  gx = None
  if 0 not in skipped:
    gx = grad * y * math_ops.pow(x, y - 1)
    if must_reduce_x:
      gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)

  gy = None
  if 1 not in skipped:
    z = math_ops.conj(op.outputs[0])

    # Avoid false singularity at x = 0
    if x.dtype.is_complex:
      # real(x) < 0 is fine for the complex case
      valid = math_ops.not_equal(x, 0)
    else:
      # There's no sensible real value to return if x < 0, so return 0
      valid = x > 0
    # Masked-out entries take log(1) and are then forced to zero.
    guarded_x = array_ops.where(valid, x, array_ops.ones_like(x))
    log_x = array_ops.where(
        valid, math_ops.log(guarded_x), array_ops.zeros_like(x))
    gy = grad * z * log_x
    if must_reduce_y:
      gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)

  return gx, gy
| |
| |
def _MaximumMinimumGradInputOnly(op, grad, selector_op):
  """Maximum/Minimum gradient for the first input only.

  Returns:
    A pair (xgrad, None): xgrad is `grad` where `selector_op(x, y)` holds and
    zero elsewhere.
  """
  lhs, rhs = op.inputs[0], op.inputs[1]
  zero_fill = array_ops.zeros_like(grad)
  pick_lhs = selector_op(lhs, rhs)
  lhs_grad = array_ops.where_v2(pick_lhs, grad, zero_fill)
  # Return None for ygrad since the config allows that.
  return (lhs_grad, None)
| |
| |
def _MaximumMinimumGrad(op, grad, selector_op):
  """Shared gradient implementation for Maximum and Minimum.

  `grad` is routed to x where `selector_op(x, y)` holds and to y elsewhere;
  each result is summed over its broadcast axes and reshaped to its input's
  shape. Inputs listed in `op.skip_input_indices` (when the op exposes that
  attribute) get None.
  """
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    y_is_skipped = skip_input_indices is not None and 1 in skip_input_indices
    if y_is_skipped and _IsScalar(y):
      # When we want to get gradients for the first input only, and the second
      # input tensor is a scalar, we can do a much simpler calculation
      return _MaximumMinimumGradInputOnly(op, grad, selector_op)
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  x = op.inputs[0]
  shape_x = array_ops.shape(x)
  shape_y = array_ops.shape(y)
  zero_fill = array_ops.zeros_like(grad)
  pick_x = selector_op(x, y)
  axes_x, axes_y = gen_array_ops.broadcast_gradient_args(shape_x, shape_y)
  skipped = () if skip_input_indices is None else skip_input_indices

  gx = None
  if 0 not in skipped:
    routed_to_x = array_ops.where_v2(pick_x, grad, zero_fill)
    gx = array_ops.reshape(math_ops.reduce_sum(routed_to_x, axes_x), shape_x)

  gy = None
  if 1 not in skipped:
    routed_to_y = array_ops.where_v2(pick_x, zero_fill, grad)
    gy = array_ops.reshape(math_ops.reduce_sum(routed_to_y, axes_y), shape_y)

  return (gx, gy)
| |
| |
@ops.RegisterGradient("Maximum")
def _MaximumGrad(op, grad):
  """Gradient for Maximum: grad*(x >= y, x < y) with type of grad."""
  selector = math_ops.greater_equal
  return _MaximumMinimumGrad(op, grad, selector)
| |
| |
@ops.RegisterGradient("Minimum")
def _MinimumGrad(op, grad):
  """Gradient for Minimum: grad*(x <= y, x > y) with type of grad."""
  selector = math_ops.less_equal
  return _MaximumMinimumGrad(op, grad, selector)
| |
| |
@ops.RegisterGradient("SquaredDifference")
def _SquaredDifferenceGrad(op, grad):
  """Gradient for (x-y)^2: 2 * grad * (x - y) for x and its negation for y.

  Each result is summed over its broadcast axes and reshaped to its input's
  shape when a reduction is required. Inputs listed in
  `op.skip_input_indices` (when the op exposes that attribute) get None.
  """
  x, y = op.inputs[0], op.inputs[1]
  # Ops without the attribute do no gradient skipping.
  skip_input_indices = getattr(op, "skip_input_indices", None)

  with ops.control_dependencies([grad]):
    # The parens ensure that if grad is IndexedSlices, it'll get multiplied by
    # Tensor (not a number like 2.0) which causes it to convert to Tensor.
    x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)

  # Fast path: identical, fully-known static shapes need no reduction.
  if isinstance(grad, ops.Tensor) and _ShapesFullySpecifiedAndEqual(
      x, y, grad):
    return x_grad, -x_grad

  x_args, y_args = SmartBroadcastGradientArgs(x, y, grad)
  sx, rx, must_reduce_x = x_args
  sy, ry, must_reduce_y = y_args
  skipped = () if skip_input_indices is None else skip_input_indices

  gx = None
  if 0 not in skipped:
    if must_reduce_x:
      gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx)
    else:
      gx = x_grad

  gy = None
  if 1 not in skipped:
    if must_reduce_y:
      gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)
    else:
      gy = -x_grad
  return (gx, gy)
| |
| |
# Comparison and logical operations have no gradients: each op type below is
# marked through `ops.NotDifferentiable` instead of getting a gradient
# function.
ops.NotDifferentiable("Less")
ops.NotDifferentiable("LessEqual")
ops.NotDifferentiable("Greater")
ops.NotDifferentiable("GreaterEqual")
ops.NotDifferentiable("Equal")
ops.NotDifferentiable("ApproximateEqual")
ops.NotDifferentiable("NotEqual")
ops.NotDifferentiable("LogicalAnd")
ops.NotDifferentiable("LogicalOr")
ops.NotDifferentiable("LogicalNot")
| |
| |
@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
  """Gradient for Select: route `grad` to whichever branch was chosen.

  Returns:
    A triple (None, grad_x, grad_y): no gradient for the condition, `grad`
    where the condition holds (zero elsewhere) for x, and the complement
    for y.
  """
  condition = op.inputs[0]
  zero_fill = array_ops.zeros_like(op.inputs[1])
  grad_x = array_ops.where(condition, grad, zero_fill)
  grad_y = array_ops.where(condition, zero_fill, grad)
  return (None, grad_x, grad_y)
| |
| |
@ops.RegisterGradient("SelectV2")
def _SelectGradV2(op, grad):
  """Gradient for SelectV2, which broadcasts its inputs.

  Returns:
    A triple (None, grad_x, grad_y): no gradient for the condition; `grad`
    masked by the condition (or its complement), summed over the axes along
    which the input was broadcast to the output shape, and reshaped to the
    input's shape.
  """
  condition, x, y = op.inputs[0], op.inputs[1], op.inputs[2]
  # A scalar zero broadcasts against `grad` inside where_v2.
  scalar_zero = array_ops.zeros([], dtype=grad.dtype.base_dtype)

  grad_x = array_ops.where_v2(condition, grad, scalar_zero)
  shape_x = array_ops.shape(x)
  shape_out = array_ops.shape(op.outputs[0])
  # Reduce away broadcasted leading dims.
  axes_x, _ = gen_array_ops.broadcast_gradient_args(shape_x, shape_out)
  grad_x = math_ops.reduce_sum(grad_x, keepdims=True, axis=axes_x)
  grad_x = array_ops.reshape(grad_x, shape_x)

  grad_y = array_ops.where_v2(condition, scalar_zero, grad)
  shape_y = array_ops.shape(y)
  # Reduce away broadcasted leading dims.
  axes_y, _ = gen_array_ops.broadcast_gradient_args(shape_y, shape_out)
  grad_y = math_ops.reduce_sum(grad_y, keepdims=True, axis=axes_y)
  grad_y = array_ops.reshape(grad_y, shape_y)

  return (None, grad_x, grad_y)
| |
| |
def _MatMulGradAgainstFirstOnly(op, grad):
  """Gradient for MatMul with respect to the first input only.

  Returns:
    A pair (grad_a, None); the operand order and transpose flags of the
    product depend on the op's `transpose_a` / `transpose_b` attributes.
  """
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  b = math_ops.conj(op.inputs[1])
  if t_a:
    if t_b:
      grad_a = gen_math_ops.mat_mul(
          b, grad, transpose_a=True, transpose_b=True)
    else:
      grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True)
  elif t_b:
    grad_a = gen_math_ops.mat_mul(grad, b)
  else:
    grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True)
  return grad_a, None
| |
| |
def _MatMulGradAgainstSecondOnly(op, grad):
  """Gradient for MatMul with respect to the second input only.

  Returns:
    A pair (None, grad_b); the operand order and transpose flags of the
    product depend on the op's `transpose_a` / `transpose_b` attributes.
  """
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  a = math_ops.conj(op.inputs[0])
  if t_a:
    if t_b:
      grad_b = gen_math_ops.mat_mul(
          grad, a, transpose_a=True, transpose_b=True)
    else:
      grad_b = gen_math_ops.mat_mul(a, grad)
  elif t_b:
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
  else:
    grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True)
  return None, grad_b
| |
| |
@ops.RegisterGradient("MatMul")
def _MatMulGrad(op, grad):
  """Gradient for MatMul with respect to both inputs.

  When `op.skip_input_indices` (if the op exposes it) names one input, the
  work is delegated to the single-input helper for the other one. Otherwise
  both gradients are built, with operand order and transpose flags chosen
  from the op's `transpose_a` / `transpose_b` attributes.
  """
  try:
    skipped = op.skip_input_indices
    if skipped is not None and 1 in skipped:
      return _MatMulGradAgainstFirstOnly(op, grad)
    if skipped is not None and 0 in skipped:
      return _MatMulGradAgainstSecondOnly(op, grad)
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass

  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  a = math_ops.conj(op.inputs[0])
  b = math_ops.conj(op.inputs[1])
  if t_a:
    if t_b:
      grad_a = gen_math_ops.mat_mul(
          b, grad, transpose_a=True, transpose_b=True)
      grad_b = gen_math_ops.mat_mul(
          grad, a, transpose_a=True, transpose_b=True)
    else:
      grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True)
      grad_b = gen_math_ops.mat_mul(a, grad)
  elif t_b:
    grad_a = gen_math_ops.mat_mul(grad, b)
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
  else:
    grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True)
    grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True)
  return grad_a, grad_b
| |
| |
@ops.RegisterGradient("SparseMatMul")
def _SparseMatMulGrad(op, grad):
  """Gradient for SparseMatMul with respect to both inputs.

  Sparsity hints for the two inputs come from the op's `a_is_sparse` /
  `b_is_sparse` attributes; the hint for `grad` is a heuristic. Each returned
  gradient is cast to the dtype of the input it belongs to.
  """

  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  lhs, rhs = op.inputs[0], op.inputs[1]
  # Sparsity hints keyed by tensor reference.
  is_sparse = {}
  is_sparse[lhs.ref()] = op.get_attr("a_is_sparse")
  is_sparse[rhs.ref()] = op.get_attr("b_is_sparse")
  # Use heuristic to figure out if grad might be sparse
  is_sparse[grad.ref()] = not context.executing_eagerly() and (
      grad.op.type == "ReluGrad")

  def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
    """Multiplies t1 by t2 with sparsity hints, casting to `out_dtype`."""

    assert t1.ref() in is_sparse and t2.ref() in is_sparse
    t1_sparse = is_sparse[t1.ref()]
    t2_sparse = is_sparse[t2.ref()]
    if transpose_b:
      # The second operand is transposed explicitly rather than via the flag.
      t2 = array_ops.transpose(t2)
      transpose_b = False
    product = math_ops.matmul(
        t1,
        t2,
        transpose_a=transpose_a,
        transpose_b=transpose_b,
        a_is_sparse=t1_sparse,
        b_is_sparse=t2_sparse)
    if product.dtype != out_dtype:
      product = math_ops.cast(product, out_dtype)
    return product

  dtype_a = lhs.dtype
  dtype_b = rhs.dtype
  if t_a:
    if t_b:
      grad_a = _SparseMatMul(
          rhs, grad, dtype_a, transpose_a=True, transpose_b=True)
      grad_b = _SparseMatMul(
          grad, lhs, dtype_b, transpose_a=True, transpose_b=True)
    else:
      grad_a = _SparseMatMul(rhs, grad, dtype_a, transpose_b=True)
      grad_b = _SparseMatMul(lhs, grad, dtype_b)
  elif t_b:
    grad_a = _SparseMatMul(grad, rhs, dtype_a)
    grad_b = _SparseMatMul(grad, lhs, dtype_b, transpose_a=True)
  else:
    grad_a = _SparseMatMul(grad, rhs, dtype_a, transpose_b=True)
    grad_b = _SparseMatMul(lhs, grad, dtype_b, transpose_a=True)
  return (grad_a, grad_b)
| |
| |
@ops.RegisterGradient("Floor")
def _FloorGrad(_, unused_grad):
  """Gradient for Floor: returns `[None]`, i.e. no gradient is propagated."""
  del _, unused_grad  # Neither the op nor the incoming gradient is needed.
  return [None]
| |
| |
@ops.RegisterGradient("Ceil")
def _CeilGrad(_, unused_grad):
  """Gradient for Ceil: returns `[None]`, i.e. no gradient is propagated."""
  del _, unused_grad  # Neither the op nor the incoming gradient is needed.
  return [None]
| |
| |
@ops.RegisterGradient("Round")
def _RoundGrad(_, unused_grad):
  """Gradient for Round: returns `[None]`, i.e. no gradient is propagated."""
  del _, unused_grad  # Neither the op nor the incoming gradient is needed.
  return [None]
| |
| |
@ops.RegisterGradient("Rint")
def _RintGrad(_, unused_grad):
  """Gradient for Rint: returns `[None]`, i.e. no gradient is propagated."""
  del _, unused_grad  # Neither the op nor the incoming gradient is needed.
  # The gradient of Rint is zero, so there is nothing to pass back.
  return [None]
| |
| |
@ops.RegisterGradient("BatchMatMul")
def _BatchMatMul(op, grad):
  """Returns the gradients of x and y given the gradient of x * y.

  Args:
    op: The BatchMatMul op being differentiated. Both inputs and the `adj_x`
      / `adj_y` attributes are read.
    grad: The incoming gradient for the op's output.

  Returns:
    The pair `(grad_x, grad_y)`.
  """
  lhs = op.inputs[0]
  rhs = op.inputs[1]
  adj_x = op.get_attr("adj_x")
  adj_y = op.get_attr("adj_y")

  # One branch per combination of adjoint flags used by the forward op.
  if adj_x and adj_y:
    grad_x = math_ops.matmul(rhs, grad, adjoint_a=True, adjoint_b=True)
    grad_y = math_ops.matmul(grad, lhs, adjoint_a=True, adjoint_b=True)
  elif adj_x:
    grad_x = math_ops.matmul(rhs, grad, adjoint_a=False, adjoint_b=True)
    grad_y = math_ops.matmul(lhs, grad, adjoint_a=False, adjoint_b=False)
  elif adj_y:
    grad_x = math_ops.matmul(grad, rhs, adjoint_a=False, adjoint_b=False)
    grad_y = math_ops.matmul(grad, lhs, adjoint_a=True, adjoint_b=False)
  else:
    grad_x = math_ops.matmul(grad, rhs, adjoint_a=False, adjoint_b=True)
    grad_y = math_ops.matmul(lhs, grad, adjoint_a=True, adjoint_b=False)
  return grad_x, grad_y
| |
| |
@ops.RegisterGradient("BatchMatMulV2")
@ops.RegisterGradient("BatchMatMulV3")
def _BatchMatMulV2(op, grad):
  """Returns the gradients of x and y given the gradient of x * y.

  Args:
    op: The BatchMatMulV2/V3 op being differentiated. Both inputs and the
      `adj_x` / `adj_y` attributes are read.
    grad: The incoming gradient for the op's output.

  Returns:
    The pair `(grad_x, grad_y)`. Unless the static shapes rule it out, each
    gradient is summed over the broadcast batch dimensions and reshaped to the
    shape of its input.
  """
  lhs = op.inputs[0]
  rhs = op.inputs[1]
  adj_x = op.get_attr("adj_x")
  adj_y = op.get_attr("adj_y")

  # One branch per combination of adjoint flags used by the forward op.
  if adj_x and adj_y:
    grad_x = math_ops.matmul(rhs, grad, adjoint_a=True, adjoint_b=True)
    grad_y = math_ops.matmul(grad, lhs, adjoint_a=True, adjoint_b=True)
  elif adj_x:
    grad_x = math_ops.matmul(rhs, grad, adjoint_a=False, adjoint_b=True)
    grad_y = math_ops.matmul(lhs, grad, adjoint_a=False, adjoint_b=False)
  elif adj_y:
    grad_x = math_ops.matmul(grad, rhs, adjoint_a=False, adjoint_b=False)
    grad_y = math_ops.matmul(grad, lhs, adjoint_a=True, adjoint_b=False)
  else:
    grad_x = math_ops.matmul(grad, rhs, adjoint_a=False, adjoint_b=True)
    grad_y = math_ops.matmul(lhs, grad, adjoint_a=True, adjoint_b=False)

  # The reduction below is skipped when the static shapes show it cannot be
  # needed: both ranks are known and at most 2, or both batch shapes are
  # fully defined and equal.
  static_lhs = lhs.get_shape()
  static_rhs = rhs.get_shape()
  batch_lhs = static_lhs[:-2]
  batch_rhs = static_rhs[:-2]
  no_batch_dims_possible = (
      static_lhs.rank is not None and static_lhs.rank <= 2 and
      static_rhs.rank is not None and static_rhs.rank <= 2)
  same_static_batch = (
      batch_lhs.is_fully_defined() and
      batch_rhs.is_fully_defined() and
      batch_lhs == batch_rhs)
  if no_batch_dims_possible or same_static_batch:
    return grad_x, grad_y

  # Sum over the batch axes that were broadcast, then restore each input's
  # own (dynamic) shape.
  dyn_lhs = array_ops.shape(lhs)
  dyn_rhs = array_ops.shape(rhs)
  reduce_lhs, reduce_rhs = gen_array_ops.broadcast_gradient_args(
      dyn_lhs[:-2], dyn_rhs[:-2])
  grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, reduce_lhs), dyn_lhs)
  grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, reduce_rhs), dyn_rhs)
  return grad_x, grad_y
| |
| |
# Declare the "Range" and "LinSpace" op types as having no gradient.
ops.NotDifferentiable("Range")
ops.NotDifferentiable("LinSpace")
| |
| |
@ops.RegisterGradient("Complex")
def _ComplexGrad(op, grad):
  """Returns the real and imaginary components of 'grad', respectively.

  Args:
    op: The Complex op being differentiated; its two inputs are read.
    grad: The incoming gradient for the op's output.

  Returns:
    A pair: the real part of `grad` reduced and reshaped to the first input's
    shape, and the imaginary part reduced and reshaped to the second's.
  """
  real_in = op.inputs[0]
  imag_in = op.inputs[1]
  real_shape = array_ops.shape(real_in)
  imag_shape = array_ops.shape(imag_in)
  # Axes along which each input was broadcast against the other.
  real_axes, imag_axes = gen_array_ops.broadcast_gradient_args(
      real_shape, imag_shape)
  grad_real = array_ops.reshape(
      math_ops.reduce_sum(math_ops.real(grad), real_axes), real_shape)
  grad_imag = array_ops.reshape(
      math_ops.reduce_sum(math_ops.imag(grad), imag_axes), imag_shape)
  return (grad_real, grad_imag)
| |
| |
@ops.RegisterGradient("Real")
def _RealGrad(_, grad):
  """Returns 'grad' as the real part, with the imaginary part set to 0."""
  imag_part = constant_op.constant(0, dtype=grad.dtype)
  return math_ops.complex(grad, imag_part)
| |
| |
@ops.RegisterGradient("Imag")
def _ImagGrad(_, grad):
  """Returns 'grad' as the imaginary part, with the real part set to 0."""
  real_part = constant_op.constant(0, dtype=grad.dtype)
  return math_ops.complex(real_part, grad)
| |
| |
@ops.RegisterGradient("Angle")
def _AngleGrad(op, grad):
  """Returns -grad / (Im(x) + iRe(x)).

  Args:
    op: The Angle op being differentiated; its single input `x` is read.
    grad: The incoming gradient for the op's output.

  Returns:
    The gradient with respect to `x`.
  """
  x = op.inputs[0]
  # Every op built below carries a control dependency on `grad`.
  with ops.control_dependencies([grad]):
    real_part = math_ops.real(x)
    imag_part = math_ops.imag(x)
    # 1 / (Im(x) + i * Re(x)); note the swapped real and imaginary parts.
    inverse = math_ops.reciprocal(math_ops.complex(imag_part, real_part))
    zero = constant_op.constant(0, dtype=grad.dtype)
    negated_grad = -math_ops.complex(grad, zero)
    return negated_grad * inverse
| |
| |
@ops.RegisterGradient("Conj")
def _ConjGrad(_, grad):
  """Returns the complex conjugate of grad.

  Args:
    _: The Conj op being differentiated (unused).
    grad: The incoming gradient for the op's output.
  """
  conjugated = math_ops.conj(grad)
  return conjugated
| |
| |
@ops.RegisterGradient("ComplexAbs")
def _ComplexAbsGrad(op, grad):
  """Returns the gradient of ComplexAbs.

  Args:
    op: The ComplexAbs op being differentiated; its input and output are
      read.
    grad: The incoming gradient for the op's output.

  Returns:
    `grad * x / abs(x)` computed in the complex domain with `div_no_nan`.
  """
  # `grad` and the forward output are lifted to complex values with a zero
  # imaginary part before they are combined with the complex input.
  numerator = (
      math_ops.complex(grad, array_ops.zeros_like(grad)) * op.inputs[0])
  magnitude = op.outputs[0]
  denominator = math_ops.complex(magnitude, array_ops.zeros_like(magnitude))
  return math_ops.div_no_nan(numerator, denominator)
| |
| |
@ops.RegisterGradient("Cast")
def _CastGrad(op, grad):
  """Gradient for Cast.

  Args:
    op: The Cast op being differentiated; the dtype of its input is read.
    grad: The incoming gradient for the op's output.

  Returns:
    `grad` cast back to the input's base dtype when both that dtype and the
    base dtype of `grad` are floating-point or complex types listed below;
    `None` otherwise.
  """
  differentiable_types = (
      dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
      dtypes.complex64, dtypes.complex128,
  )
  src_type = op.inputs[0].dtype.base_dtype
  dst_type = grad.dtype.base_dtype
  if (src_type not in differentiable_types or
      dst_type not in differentiable_types):
    return None
  return math_ops.cast(grad, src_type)
| |
| |
@ops.RegisterGradient("Cross")
def _CrossGrad(op, grad):
  """Gradient for Cross.

  Args:
    op: The Cross op being differentiated; its two inputs `u` and `v` are
      read.
    grad: The incoming gradient for the op's output.

  Returns:
    The pair `(cross(v, grad), cross(grad, u))`.
  """
  first, second = op.inputs[0], op.inputs[1]
  grad_first = math_ops.cross(second, grad)
  grad_second = math_ops.cross(grad, first)
  return (grad_first, grad_second)
| |
| |
@ops.RegisterGradient("Cumsum")
def _CumsumGrad(op, grad):
  """Gradient for Cumsum.

  Args:
    op: The Cumsum op being differentiated; its axis input and its
      `exclusive` / `reverse` attributes are read.
    grad: The incoming gradient for the op's output.

  Returns:
    A two-element list: the cumulative sum of `grad` along the same axis in
    the opposite direction, and `None` for the axis input.
  """
  axis = op.inputs[1]
  exclusive = op.get_attr("exclusive")
  flipped = not op.get_attr("reverse")
  grad_x = math_ops.cumsum(grad, axis, exclusive=exclusive, reverse=flipped)
  return [grad_x, None]
| |
| |
@ops.RegisterGradient("Cumprod")
def _CumprodGrad(op, grad):
  """Gradient for Cumprod.

  Args:
    op: The Cumprod op being differentiated; its two inputs and its
      `exclusive` / `reverse` attributes are read.
    grad: The incoming gradient for the op's output.

  Returns:
    A two-element list: the gradient for the data input, and `None` for the
    axis input.
  """
  values = op.inputs[0]
  axis = op.inputs[1]
  exclusive = op.get_attr("exclusive")
  reverse = op.get_attr("reverse")

  # Recompute the forward result, weight it by the incoming gradient and
  # accumulate in the opposite direction.
  forward = math_ops.cumprod(
      values, axis, exclusive=exclusive, reverse=reverse)
  accumulated = math_ops.cumsum(
      forward * grad, axis, exclusive=exclusive, reverse=not reverse)
  # div_no_nan is used for the final division by the input.
  return [math_ops.div_no_nan(accumulated, values), None]
| |
| |
@ops.RegisterGradient("CumulativeLogsumexp")
def _CumulativeLogsumexpGrad(op, grad):
  """Gradient for CumulativeLogsumexp.

  Args:
    op: The CumulativeLogsumexp op being differentiated; its two inputs, its
      output and its `exclusive` / `reverse` attributes are read.
    grad: The incoming gradient for the op's output.

  Returns:
    A two-element list: the gradient for the data input, and `None` for the
    axis input.
  """
  x = op.inputs[0]
  axis = op.inputs[1]
  forward_result = op.outputs[0]

  exclusive = op.get_attr("exclusive")
  flipped = not op.get_attr("reverse")
  lowest = grad.dtype.min

  # The incoming gradient is split by sign so that logs can be taken of each
  # part, which is required for stable results. Entries that do not belong
  # to a part are filled with the dtype's minimum value.
  log_grad_positive = array_ops.where_v2(
      math_ops.greater(grad, 0), math_ops.log(grad), lowest)
  log_grad_negative = array_ops.where_v2(
      math_ops.less(grad, 0), math_ops.log(-grad), lowest)

  def _accumulate(log_grad_part):
    """Accumulates one sign part in log space and leaves log space again."""
    summed = math_ops.cumulative_logsumexp(
        log_grad_part - forward_result,
        axis=axis, reverse=flipped, exclusive=exclusive)
    return math_ops.exp(summed + x)

  positive_part = _accumulate(log_grad_positive)
  negative_part = _accumulate(log_grad_negative)
  return [positive_part - negative_part, None]
| |
| |
@ops.RegisterGradient("NextAfter")
def _NextAfterGrad(op, grad):
  """Returns gradient of nextafter(x1, x2) with respect to x1 and x2.

  Args:
    op: The NextAfter op being differentiated; its two inputs are read.
    grad: The incoming gradient for the op's output.

  Returns:
    A pair: `grad` times ones, reduced and reshaped to the shape of `x1`, and
    `grad` times zeros, reduced and reshaped to the shape of `x2`.
  """
  x1 = op.inputs[0]
  x2 = op.inputs[1]
  shape_x1 = array_ops.shape(x1)
  shape_x2 = array_ops.shape(x2)
  # Axes along which each input was broadcast against the other.
  axes_x1, axes_x2 = gen_array_ops.broadcast_gradient_args(shape_x1, shape_x2)
  with ops.control_dependencies([grad]):
    # The partial derivative is taken as 1 for x1 and 0 for x2.
    ones_x1 = array_ops.ones(shape_x1, dtype=x1.dtype)
    zeros_x2 = array_ops.zeros(shape_x2, dtype=x2.dtype)
    grad_x1 = array_ops.reshape(
        math_ops.reduce_sum(ones_x1 * grad, axes_x1), shape_x1)
    grad_x2 = array_ops.reshape(
        math_ops.reduce_sum(zeros_x2 * grad, axes_x2), shape_x2)
    return (grad_x1, grad_x2)