blob: 7b8e13c8351673f10b642aa6825b3d705a4055e5 [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in control_flow_ops.py."""
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable,redefined-builtin
from tensorflow.python.ops.control_flow_ops import *
# pylint: enable=wildcard-import
def _SwitchGrad(op, *grad):
  """Builds the gradient of a Switch op out of a Merge op.

  A Switch that belongs to a while loop is visited twice during backprop.
  The Merge is created on the first visit; on the second visit the Merge's
  other input is updated and a next_iteration/back edge is added.

  Args:
    op: The forward Switch op.
    *grad: Incoming gradients, indexed like the outputs of `op`. Individual
      entries may be None.

  Returns:
    A `(data_grad, None)` pair. `data_grad` is a Merge output, or None when
    nothing is propagated; the second slot is always None.
  """
  default_graph = ops.get_default_graph()
  # pylint: disable=protected-access
  forward_ctxt = op._get_control_flow_context()
  backprop_ctxt = default_graph._get_control_flow_context()
  # pylint: enable=protected-access

  if isinstance(forward_ctxt, WhileContext):
    loop_state = backprop_ctxt.grad_state
    pending_merge = loop_state.switch_map.get(op)

    if pending_merge is None:
      # First visit of this Switch.
      exit_grad = grad[0]
      if exit_grad is None:
        # We arrived through the Identity branch. A `None` gradient on the
        # Exit branch means the output is not differentiable.
        return None, None
      # We arrived through the Exit branch (grad[0]); grad[1] is still empty.
      # Feed grad[0] to both Merge inputs for now; the second input is
      # rewired when this Switch is seen again.
      pending_merge = merge([exit_grad, exit_grad], name="b_switch")[0]
      loop_state.switch_map[op] = pending_merge
      return pending_merge, None

    # Second visit of this Switch: it comes from the non-exit branch, so
    # update the second input of the Merge created earlier.
    # TODO(yuanbyu): Perform shape inference with this new input.
    body_grad = grad[1]
    if body_grad is not None:
      # pylint: disable=protected-access
      control_flow_ops._AddNextAndBackEdge(
          pending_merge, body_grad, enforce_shape_invariant=False)
      # pylint: enable=protected-access
    return None, None

  if isinstance(forward_ctxt, CondContext):
    taken_branch = forward_ctxt.branch
    other_branch_grad = grad[1 - taken_branch]
    # By now the zero gradient for the other branch has been created, guarded
    # by the right switch. It can still be None for non-trainable data types.
    if other_branch_grad is not None:
      return merge(grad, name="cond_grad")[0], None
    if op.inputs[0].dtype == dtypes.resource:
      # Resource variables always yield None on the other branch, so merge
      # the taken branch's gradient with itself instead.
      taken_grad = grad[taken_branch]
      merged = merge([taken_grad, taken_grad], name="cond_resource_grad")[0]
      return merged, None
    return None, None

  # Neither a while nor a cond context: route each incoming gradient through
  # a switch on the original predicate, then merge the two results.
  predicate = op.inputs[1]
  grad_via_false = switch(grad[0], predicate)[0]
  grad_via_true = switch(grad[1], predicate)[1]
  return merge([grad_via_false, grad_via_true])[0], None
# Register `_SwitchGrad` as the gradient function for both the "Switch" and
# the "RefSwitch" op types.
ops.RegisterGradient("Switch")(_SwitchGrad)
ops.RegisterGradient("RefSwitch")(_SwitchGrad)
@ops.RegisterGradient("Merge")
def _MergeGrad(op, grad, _):
  """Builds the gradient of a Merge op out of Switch ops.

  Args:
    op: The forward Merge op.
    grad: Incoming gradient that is routed back to the inputs of `op`.
    _: Unused.

  Returns:
    For a Merge whose first input comes from a while or cond context, the
    result of `control_flow_ops._SwitchRefOrTensor`. Otherwise a list with
    one entry per input of `op`.
  """
  producer = op.inputs[0].op
  default_graph = ops.get_default_graph()
  # pylint: disable=protected-access
  forward_ctxt = control_flow_util.GetOutputContext(producer)
  backprop_ctxt = default_graph._get_control_flow_context()
  # pylint: enable=protected-access

  if isinstance(forward_ctxt, WhileContext):
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, backprop_ctxt.pivot)
    # pylint: enable=protected-access

  if isinstance(forward_ctxt, CondContext):
    switch_pred = forward_ctxt.pred
    if backprop_ctxt and backprop_ctxt.grad_state:
      # This Merge is part of a cond nested inside a loop. Backprop needs the
      # predicate value of every iteration, so the values are accumulated in
      # the forward pass and the accumulated value drives the backprop switch.
      loop_state = backprop_ctxt.grad_state
      accumulated_pred = loop_state.history_map.get(switch_pred.name)
      if accumulated_pred is None:
        # Record the value of the predicate for every iteration.
        loop_grad_ctxt = loop_state.grad_context
        loop_grad_ctxt.Exit()
        pred_history = loop_state.AddForwardAccumulator(switch_pred)
        loop_grad_ctxt.Enter()
        # Add the stack pop op. If pred.op is in an (outer) CondContext, the
        # stack pop will be guarded with a switch.
        accumulated_pred = loop_state.AddBackpropAccumulatedValue(
            pred_history, switch_pred)
        loop_state.history_map[switch_pred.name] = accumulated_pred
      switch_pred = accumulated_pred
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(
        grad, switch_pred, name="cond_grad")
    # pylint: enable=protected-access

  # No while/cond context: compare the second output of the Merge with each
  # input position, then switch the gradient on each comparison result. All
  # comparisons are built before any switch, as in a two-pass construction.
  input_count = len(op.inputs)
  is_selected = [
      math_ops.equal(op.outputs[1], position)
      for position in range(input_count)
  ]
  # pylint: disable=protected-access
  return [
      control_flow_ops._SwitchRefOrTensor(grad, selector)[1]
      for selector in is_selected
  ]
  # pylint: enable=protected-access
@ops.RegisterGradient("RefMerge")
def _RefMergeGrad(op, grad, _):
  """Gradient for a RefMerge op; delegates unchanged to `_MergeGrad`."""
  return _MergeGrad(op, grad, _)
@ops.RegisterGradient("Exit")
def _ExitGrad(op, grad):
  """Gradients for an exit op are calculated using an Enter op.

  Args:
    op: The forward Exit op.
    grad: Incoming gradient. Must be an `ops.Tensor`, an
      `indexed_slices.IndexedSlices` or a `sparse_tensor.SparseTensor`.

  Returns:
    None if the current gradient context has `back_prop` disabled; otherwise
    the result of `control_flow_ops._Enter` applied to `grad`.

  Raises:
    TypeError: If the context of `op` already carries a gradient state
      (second-order gradient of a while loop), or if `grad` is of an
      unsupported type.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if not grad_ctxt.back_prop:
    # The flag `back_prop` is set by users to suppress gradient
    # computation for this loop. If the attribute `back_prop` is false,
    # no gradient computation.
    return None

  if op_ctxt.grad_state:
    raise TypeError("Second-order gradient for while loops not supported.")

  # Record the names of the gradient's component tensors in the gradient
  # context before entering it.
  if isinstance(grad, ops.Tensor):
    grad_ctxt.AddName(grad.name)
  else:
    if not isinstance(
        grad, (indexed_slices.IndexedSlices, sparse_tensor.SparseTensor)):
      # The trailing space after "either" matters: the two literals are
      # concatenated verbatim into one message.
      raise TypeError(f"Type {type(grad)} not supported, must be either "
                      "`indexed_slices.IndexedSlices` or `SparseTensor`.")
    grad_ctxt.AddName(grad.values.name)
    grad_ctxt.AddName(grad.indices.name)
    dense_shape = grad.dense_shape
    if dense_shape is not None:
      grad_ctxt.AddName(dense_shape.name)

  grad_ctxt.Enter()
  # pylint: disable=protected-access
  result = control_flow_ops._Enter(
      grad, grad_ctxt.name, is_constant=False,
      parallel_iterations=grad_ctxt.parallel_iterations,
      name="b_exit")
  # pylint: enable=protected-access
  grad_ctxt.loop_enters.append(result)
  grad_ctxt.Exit()
  return result
# "RefExit" uses the same gradient function as "Exit".
ops.RegisterGradient("RefExit")(_ExitGrad)
@ops.RegisterGradient("NextIteration")
def _NextIterationGrad(_, grad):
  """A forward next_iteration is translated into a backprop identity.

  Note that the backprop next_iteration is added in switch grad.

  Args:
    _: Unused.
    grad: Incoming gradient.

  Returns:
    `grad`, unchanged.
  """
  return grad
@ops.RegisterGradient("RefNextIteration")
def _RefNextIterationGrad(_, grad):
  """Gradient for RefNextIteration; delegates to `_NextIterationGrad`."""
  return _NextIterationGrad(_, grad)
@ops.RegisterGradient("Enter")
def _EnterGrad(op, grad):
  """Gradients for an Enter are calculated using an Exit op.

  For loop variables, grad is the gradient so just add an exit.
  For loop invariants, we need to add an accumulator loop.

  Args:
    op: The forward Enter op. Its `is_constant` attribute selects between
      the loop-invariant and the loop-variable handling.
    grad: Incoming gradient.

  Returns:
    `grad` unchanged when there is no gradient control-flow context, when
    `back_prop` is disabled, or when the context has no gradient state.
    Otherwise the accumulator result (loop invariants) or the exit of `grad`
    (loop variables).

  Raises:
    TypeError: If `op` is a loop invariant and `grad` is neither an
      `ops.Tensor` nor an `indexed_slices.IndexedSlices`.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if grad_ctxt is None:
    return grad
  if not grad_ctxt.back_prop:
    # Skip gradient computation, if the attribute `back_prop` is false.
    return grad
  if grad_ctxt.grad_state is None:
    # Pass the gradient through if we are not in a gradient while context.
    return grad

  if op.get_attr("is_constant"):
    # Add a gradient accumulator for each loop invariant.
    if isinstance(grad, ops.Tensor):
      result = grad_ctxt.AddBackpropAccumulator(op, grad)
    elif isinstance(grad, indexed_slices.IndexedSlices):
      result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
    else:
      # TODO(yuanbyu, lukasr): Add support for SparseTensor.
      # The trailing space after the comma matters: the two literals are
      # concatenated verbatim into one message.
      raise TypeError(f"Type {type(grad)} not supported, "
                      "must be Tensor or Indexed Slices")
  else:
    # NOTE(review): `exit` is presumably the control-flow op pulled in by the
    # wildcard import of control_flow_ops, not the Python builtin -- confirm.
    result = exit(grad)
    grad_ctxt.loop_exits.append(result)
    grad_ctxt.ExitResult([result])
  return result
@ops.RegisterGradient("RefEnter")
def _RefEnterGrad(op, grad):
  """Gradient for a RefEnter op; delegates unchanged to `_EnterGrad`."""
  return _EnterGrad(op, grad)
@ops.RegisterGradient("LoopCond")
def _LoopCondGrad(_):
  """Stop backprop for the predicate of a while loop.

  Args:
    _: Unused.

  Returns:
    Always None.
  """
  # NOTE(review): unlike the other gradient functions in this file, which take
  # `(op, grad)`, this one accepts a single positional argument -- confirm it
  # is never invoked with an incoming gradient.
  return None