blob: c317c3c6c5107406e4a9876787258eb09f9f5aae [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for managing state of v1 control flow for computing gradients."""
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import resource_variable_ops
# pylint: disable=protected-access
def _GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
"""Calculate a max_size for use by stack ops inside an XLA while_loop.
Args:
value: The value inside the while_loop forward context. Used for printing
error messages.
while_ctxt: The forward context inside which value resides. This does not
always match the value's immediate context, as `value` may be inside e.g.
a cond context inside the while_loop.
Returns:
A tensor containing the `max_size` to feed to a Stack initializer.
Raises:
ValueError: If `value` is nested inside a `while_loop` that either
lacks a `maximum_iterations` parameter, or the `maximum_iterations`
parameter:
- is inside a `while_loop` that is a parent of the calling context, and
- cannot be evaluated at graph build time to a constant.
"""
value_name = value.name
# curr_ctxt is the context that tf.gradients was called in.
curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
curr_ctxt_name = curr_ctxt.name if curr_ctxt is not None else ""
max_size = constant_op.constant(1)
# Loop through all containing while contexts between value and the
# current context, multiplying together each context's
# max_iterations to get the maximum stack size.
while while_ctxt not in (None, curr_ctxt):
max_iter = while_ctxt.maximum_iterations
if max_iter is None:
raise ValueError(
"Cannot create a gradient accumulator for tensor '%s' inside "
"XLA while_loop because maximum_iterations was not passed to "
"the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))
# pylint: disable=protected-access
max_iter_ctxt = max_iter.op._get_control_flow_context()
# pylint: enable=protected-access
# If max_iter_ctxt (non-strictly) contains curr_ctxt, then it's OK to use.
if util.IsContainingContext(curr_ctxt, max_iter_ctxt):
max_size *= max_iter
else:
# We cannot use max_iter because it's defined in a nested while
# or cond context, so will fail if we try to use it as input to
# any ops in curr_ctxt (e.g. max_size or the final accumulator
# stack). Attempt to get a constant value out to use instead.
const_max_iter = tensor_util.constant_value(max_iter)
if const_max_iter is None:
raise ValueError(
"Cannot create a gradient accumulator for tensor '%s' inside XLA "
"while_loop. maximum_iterations tensor '%s' for while_loop context "
"'%s' must be statically known (e.g. a constant value or known "
"shape dimension), or be defined at or outside the while loop "
"context '%s' (currently defined in '%s')." %
(value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
max_iter_ctxt.name))
max_size *= const_max_iter
# Find the next outer WhileContext (or stop if we reach the
# tf.gradient's context).
while_ctxt = util.GetContainingWhileContext(
while_ctxt.outer_context, stop_ctxt=curr_ctxt)
return max_size
class _GradLoopState:
"""The state used for constructing the gradient graph for a while loop.
We create a _GradLoopState for each while loop in forward and its
corresponding while loop in backprop. This gives us access to both
the forward and the backprop WhileContexts.
During the construction of gradient graph, any time when we detect
a forward value that is needed for backprop, we create a history
accumulator and add it to `history_map`. Any time when we backprop
a loop switch op (in _SwitchGrad), we add the grad merge op in
`switch_map`.
"""
def __init__(self, forward_ctxt, outer_grad_state):
# The grad loop state for the outer while loop.
self._outer_grad_state = None
# The while loop context for forward.
self._forward_context = None
# The loop counter added by AddForwardLoopCounter. It is the value
# of the loop counter for the next iteration.
self._forward_index = None
# A sync op for forward.
self._forward_sync = None
# The while loop context for backprop.
self._grad_context = None
# The loop counter added by AddBackpropLoopCounter. It is the value
# of the loop counter for the current iteration.
self._grad_index = None
# A sync op for backprop.
self._grad_sync = None
# Information needed by backprop.
self._history_map = {}
self._switch_map = {}
self._unused_exits = []
self._deferred_exits = []
self._forward_loop_exits = list(forward_ctxt.loop_exits)
self._pending_exits_count = len(forward_ctxt.loop_exits)
self._outer_grad_state = outer_grad_state
if outer_grad_state:
outer_forward_ctxt = outer_grad_state.forward_context
else:
if not hasattr(forward_ctxt, "outer_context"):
raise ValueError("Failed to call gradients on a while loop without"
"properly serializing graph via MetaGraphDef")
outer_forward_ctxt = forward_ctxt.outer_context
# Add the forward loop counter.
with forward_ctxt._graph.as_default(): # pylint: disable=protected-access
if outer_forward_ctxt:
outer_forward_ctxt.Enter()
cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
if outer_forward_ctxt:
outer_forward_ctxt.Exit()
self._forward_context = forward_ctxt
self._forward_index = forward_index
# Add the backprop WhileContext, and the backprop loop counter.
if outer_grad_state:
# This is a nested loop. Remember the iteration counts for each
# execution of this inner loop.
outer_forward_ctxt.AddName(cnt.name)
history_cnt = outer_grad_state.AddForwardAccumulator(cnt)
outer_grad_ctxt = outer_grad_state.grad_context
outer_grad_ctxt.Enter()
self._grad_context = control_flow_ops.WhileContext(
maximum_iterations=forward_ctxt.maximum_iterations,
parallel_iterations=forward_ctxt.parallel_iterations,
back_prop=forward_ctxt.back_prop,
swap_memory=forward_ctxt.swap_memory,
name=forward_ctxt.name,
grad_state=self)
real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
self._grad_index = self._grad_context.AddBackpropLoopCounter(
real_cnt, outer_grad_state)
outer_grad_ctxt.Exit()
else:
if outer_forward_ctxt:
outer_forward_ctxt.Enter()
self._grad_context = control_flow_ops.WhileContext(
maximum_iterations=forward_ctxt.maximum_iterations,
parallel_iterations=forward_ctxt.parallel_iterations,
back_prop=forward_ctxt.back_prop,
swap_memory=forward_ctxt.swap_memory,
name=forward_ctxt.name,
grad_state=self)
self._grad_index = self._grad_context.AddBackpropLoopCounter(
cnt, outer_grad_state)
if outer_forward_ctxt:
outer_forward_ctxt.Exit()
@property
def outer_grad_state(self):
"""The grad loop state for outer loop."""
return self._outer_grad_state
@property
def forward_context(self):
"""The while loop context for forward."""
return self._forward_context
@property
def forward_index(self):
"""The loop index of forward loop."""
return self._forward_index
@property
def forward_sync(self):
"""A control trigger node for synchronization in the forward loop.
One main use is to keep the push ops of a stack executed in the
iteration order.
"""
if self._forward_sync is None:
with ops.control_dependencies(None):
self._forward_sync = control_flow_ops.control_trigger(name="f_sync")
self._forward_sync._set_control_flow_context(self._forward_context)
self._forward_index.op._add_control_input(self._forward_sync)
return self._forward_sync
@property
def grad_context(self):
"""The corresponding WhileContext for gradient."""
return self._grad_context
@property
def grad_index(self):
"""The loop index of backprop loop."""
return self._grad_index
@property
def grad_sync(self):
"""A control trigger node for synchronization in the grad loop.
One main use is to keep the pop ops of a stack executed in the
iteration order.
"""
if self._grad_sync is None:
with ops.control_dependencies(None):
self._grad_sync = control_flow_ops.control_trigger(name="b_sync")
self._grad_sync._set_control_flow_context(self._grad_context)
self._grad_index.op._add_control_input(self._grad_sync)
if self._grad_context.outer_context:
self._grad_context.outer_context.AddInnerOp(self._grad_sync)
return self._grad_sync
@property
def history_map(self):
"""The map that records all the tensors needed for backprop."""
return self._history_map
@property
def switch_map(self):
"""The map that records all the Switch ops for the while loop."""
return self._switch_map
@property
def unused_exits(self):
"""The list of "unused" exits."""
return self._unused_exits
@property
def deferred_exits(self):
"""The list of "deferred" exits."""
return self._deferred_exits
@property
def forward_loop_exits(self):
"""The list of exits of the forward loop."""
return self._forward_loop_exits
@property
def pending_exits_count(self):
"""The number of exits we expect to see but haven't."""
return self._pending_exits_count
@pending_exits_count.setter
def pending_exits_count(self, cnt):
"""Set the pending count to cnt."""
self._pending_exits_count = cnt
def AddForwardAccumulator(self, value, dead_branch=False):
"""Add an accumulator for each forward tensor that is needed in backprop.
This is added to the forward loop at the first time when a tensor
in the forward loop is used by backprop gradient computation loop.
We create an accumulator that accumulates the value of tensor at each
iteration. Called in the control flow context where gradients() is called.
The pseudocode is:
```
acc = stack();
while (_pivot) {
acc = stack_push(acc, value);
}
```
We make sure that the stack push op in one iteration is executed before
next iteration. This is achieved by adding a control edge from
`forward_index.op.inputs[0].op` to the push op, and another control
edge from the push op to either `forward_index.op` or `forward_sync`.
Args:
value: The source tensor in forward that is to be accumulated.
dead_branch: True iff the tensor is on a dead branch of a cond.
Returns:
The stack that contains the accumulated history of the tensor.
Raises:
TypeError: For internal errors involving the value condition context.
ValueError: If `value` is inside a XLA scope and a valid max size
for the stack can't be found.
"""
# curr_ctxt is the context that tf.gradients was called in.
with self._forward_index.graph.as_default():
curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
with ops.control_dependencies(None):
if curr_ctxt:
curr_ctxt.Enter()
with ops.colocate_with(value):
# We only need to pass maximum_iterations to the stack if
# we're inside an XLA context.
if not util.IsInXLAContext(value.op):
max_size = constant_op.constant(-1, dtypes.int32)
else:
max_size = _GetMaxSizeFromNestedMaximumIterations(
value, self.forward_context)
acc = gen_data_flow_ops.stack_v2(
max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
if curr_ctxt:
curr_ctxt.Exit()
# Make acc available in the forward context.
enter_acc = self.forward_context.AddValue(acc)
# Add the stack_push op in the context of value.op.
swap_enabled = self.forward_context.swap_memory
value_ctxt = util.GetOutputContext(value.op)
if value_ctxt == self.forward_context:
# value is not nested in the forward context.
self.forward_context.Enter()
push = gen_data_flow_ops.stack_push_v2(
enter_acc, value, swap_memory=swap_enabled)
self.forward_context.Exit()
# Protect stack push and order it before forward_index.
self.forward_index.op._add_control_input(push.op)
else:
# value is in a cond context within the forward context.
if not isinstance(value_ctxt, control_flow_ops.CondContext):
raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
if dead_branch:
# The special case for creating a zero tensor for a dead
# branch of a switch. See _ControlFlowState.ZerosLikeV1WhileLoop().
value_ctxt.outer_context.Enter()
push = gen_data_flow_ops.stack_push_v2(
enter_acc, value, swap_memory=swap_enabled)
value_ctxt.outer_context.Exit()
push.op._set_control_flow_context(value_ctxt)
else:
value_ctxt.Enter()
push = gen_data_flow_ops.stack_push_v2(
enter_acc, value, swap_memory=swap_enabled)
value_ctxt.Exit()
# Protect stack push and order it before forward_sync.
self.forward_sync._add_control_input(push.op)
# Order stack push after the successor of forward_index
add_op = self.forward_index.op.inputs[0].op
push.op._add_control_input(add_op)
return acc
def AddBackpropAccumulatedValue(self, history_value, value,
dead_branch=False):
"""Add the getter for an accumulated value in the grad context.
This is added to the backprop loop. Called in the grad context to
get the value of an accumulated value. The stack pop op must be guarded
by the pred of the controlling cond.
Args:
history_value: The history (a stack) of a value.
value: The value that is pushed onto the stack.
dead_branch: True iff the tensor is on a dead branch of a cond.
Returns:
The current value (the top of the stack).
"""
history_ctxt = history_value.op._get_control_flow_context()
# Find the cond context that controls history_value if any.
cond_ctxt = None
value_ctxt = value.op._get_control_flow_context()
while value_ctxt and value_ctxt != history_ctxt:
if isinstance(value_ctxt, control_flow_ops.CondContext):
cond_ctxt = value_ctxt
break
value_ctxt = value_ctxt.outer_context
with ops.control_dependencies(None):
self.grad_context.Enter()
if cond_ctxt:
# Guard stack pop with a switch if it is controlled by a cond.
grad_state = self
pred = None
while pred is None and grad_state:
pred = grad_state.history_map.get(cond_ctxt.pred.name)
grad_state = grad_state.outer_grad_state
if pred is None:
pred = cond_ctxt.pred
branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
history_value = control_flow_ops._SwitchRefOrTensor(
history_value, pred)[branch]
pop = gen_data_flow_ops.stack_pop_v2(history_value,
value.dtype.base_dtype)
pop.set_shape(value.get_shape())
self.grad_context.Exit()
parallel_iterations = self.grad_context.parallel_iterations
if parallel_iterations > 1:
# All pops are ordered after pivot_for_body and before grad_sync.
self.grad_sync._add_control_input(pop.op)
return pop
def GetRealValue(self, value):
"""Get the real value of `value`.
If backprop "uses" a value produced by forward inference, an accumulator
is added in the forward loop to accumulate its values. We use the
accumulated value. This method must be called in the grad loop context.
`value` must be in forward and needed for backprop.
Args:
value: A tensor to be captured.
Returns:
The same tensor obtained from the saved history.
"""
assert value.op.type not in ["Variable", "VariableV2"]
real_value = self._history_map.get(value.name)
if real_value is None:
cur_value = value
cur_grad_state = self
while True:
enter_op = util.GetLoopConstantEnter(cur_value)
if enter_op:
# Special case: cur_value comes from a constant Enter node.
cur_value = enter_op.inputs[0]
cur_grad_state = cur_grad_state.outer_grad_state
if cur_grad_state is None:
# We are now outside all nested loops for this gradient(),
# so `value` is a loop invariant and there is no need to
# save the history of value. Just make cur_value to enter
# the right control flow context.
real_value = self._grad_context.AddValue(cur_value)
break
elif constant_op.is_constant(cur_value):
# If the value to be forwarded is a constant, clone the constant in
# the gradient loop rather than using a stack.
# TODO(phawkins): consider hoisting the constant out of the loop
# instead.
real_value = constant_op.constant(
tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
break
else:
# Record the history of this value in forward_ctxt.
self._grad_context.Exit()
history_value = cur_grad_state.AddForwardAccumulator(cur_value)
self._grad_context.Enter()
break
if real_value is None:
# Add the stack pop op in the grad context.
real_value = cur_grad_state.AddBackpropAccumulatedValue(
history_value, cur_value)
if cur_grad_state != self:
real_value = self._grad_context.AddValue(real_value)
self._history_map[value.name] = real_value
return real_value
class _ControlFlowState:
"""Maintain the mapping from the loops to their grad states."""
def __init__(self):
self._map = {} # maps forward loop context to _GradLoopState
def GetGradState(self, op, before):
"""Return the grad state for this op if it's in a forward loop context."""
if before and util.IsLoopExit(op):
forward_ctxt = op._get_control_flow_context() # pylint: disable=protected-access
forward_ctxt = forward_ctxt.outer_context
if forward_ctxt:
forward_ctxt = forward_ctxt.GetWhileContext()
else:
forward_ctxt = util.GetWhileContext(op)
if forward_ctxt:
return self._map.get(forward_ctxt)
return None
def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
"""Process all the "unused" loop exits.
The "unused" exits of the loops are added to `unused_exits`. An exit is
unused if its pending_count is 0. If there is an exit with real gradient,
all these deferred exits will enter the backprop loop with zero gradient.
Otherwise, they will enter the backprop loop with None. As an example,
people often write:
```python
v1, _ = tf.while_loop(p, b, [x1, x2])
result = gradients(v1, x1)
```
The exit node for x2 is not included by the betweenness analysis. But we
need to backprop x2 if x2 is involved in computing v1.
Args:
pending_count: The number of backprop inputs for every op.
to_ops_set: The set of ops for ys in gradients(ys, xs)
Returns:
The set of unused loop exits that we know at this point we need
to backprop.
"""
loop_exits = []
for grad_state in self._map.values():
for y in grad_state.forward_loop_exits:
if pending_count[y.op] == 0:
grad_state.pending_exits_count -= 1
if y.op not in to_ops_set:
grad_state.unused_exits.append(y)
if grad_state.pending_exits_count == 0:
loop_exits.extend(grad_state.unused_exits)
# Need to include Enters in backprop for higher-order gradients.
for y in grad_state.forward_context.loop_enters:
if pending_count[y.op] == 0:
pending_count[y.op] = 1
return loop_exits
def EnterGradWhileContext(self, op, before):
"""Enter the WhileContext for gradient computation."""
grad_state = self.GetGradState(op, before)
if grad_state:
grad_state.grad_context.Enter()
def ExitGradWhileContext(self, op, before):
"""Exit the WhileContext for gradient computation."""
grad_state = self.GetGradState(op, before)
if grad_state:
grad_state.grad_context.Exit()
def AddWhileContext(self, op, between_op_list, between_ops):
"""Add the grad state for the while loop that op belongs to.
Note that op is an Exit, and this method must be called in
the control flow context where gradients() is called.
Note that this method modifies `between_op_list` and `between_ops`.
"""
forward_ctxt = util.GetWhileContext(op)
grad_state = self._map.get(forward_ctxt)
if grad_state is None:
# This is a new while loop so create a grad state for it.
outer_forward_ctxt = forward_ctxt.outer_context
if outer_forward_ctxt:
outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
outer_grad_state = None
if outer_forward_ctxt:
outer_grad_state = self._map.get(outer_forward_ctxt)
grad_state = _GradLoopState(forward_ctxt, outer_grad_state)
self._map[forward_ctxt] = grad_state
# We need to include all exits of a loop for backprop.
for loop_exit in grad_state.forward_loop_exits:
if loop_exit.op not in between_ops:
between_ops.add(loop_exit.op)
between_op_list.append(loop_exit.op)
def ZerosLikeForExit(self, val):
"""Create zeros_like gradient for a loop exit.
If the result of a loop variable is not used but is involved in
computing the result of some needed loop variable, we create a
zero-valued tensor that is fed as gradient for the Exit node of that
loop variable. Note that val.op is an Exit, and this method must be
called in the control flow context where gradients() is called.
Args:
val: The output tensor of an Exit op.
Returns:
A zero tensor of the same shape of val.
"""
val_shape = val.get_shape()
forward_ctxt = val.op._get_control_flow_context()
outer_forward_ctxt = forward_ctxt.outer_context
if outer_forward_ctxt:
outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
outer_grad_state = None
if outer_forward_ctxt:
outer_grad_state = self._map.get(outer_forward_ctxt)
if outer_grad_state:
# This is a nested loop.
if val_shape.is_fully_defined():
# If the shape is known statically, just create a zero tensor
# with the right shape in the right context.
outer_grad_state.grad_context.Enter()
result = array_ops.zeros(val_shape.dims, val.dtype)
outer_grad_state.grad_context.Exit()
else:
# Only the shape of value is needed for backprop.
forward_ctxt.outer_context.Enter()
shape = array_ops.shape_internal(val, optimize=False)
forward_ctxt.outer_context.Exit()
# Save the shape to a stack.
history_shape = outer_grad_state.AddForwardAccumulator(shape)
# Get the shape back from the stack.
outer_grad_ctxt = outer_grad_state.grad_context
outer_grad_ctxt.Enter()
real_shape = outer_grad_state.AddBackpropAccumulatedValue(
history_shape, shape)
result = array_ops.zeros(real_shape, val.dtype)
outer_grad_ctxt.Exit()
else:
# This is not a nested loop.
if val_shape.is_fully_defined():
# If the shape is known statically, just create a zero tensor
# with the right shape.
result = array_ops.zeros(val_shape.dims, val.dtype)
else:
result = array_ops.zeros_like(val, optimize=False)
return result
def ZerosLikeV1WhileLoop(self, op, index):
"""Create zeros_like for the specified output of an op.
If op is in a while loop that is part of gradients(), this method
must be called in its grad loop context.
Args:
op: A tensorflow operation.
index: the index for a specific output of the op.
Returns:
A zero tensor of the same shape of op.outputs[index].
"""
if util.IsLoopSwitch(op):
return None
if op.graph.building_function:
# The optimization here is tricky to apply to functions
return array_ops.zeros_like(op.outputs[index])
dead_branch = util.IsSwitch(op)
forward_ctxt = util.GetWhileContext(op)
grad_state = self._map.get(forward_ctxt)
if grad_state is None:
# op is not in a while loop that is part of gradients().
return ZerosLike(op, index)
op_ctxt = op._get_control_flow_context()
val = ops.convert_to_tensor(op.outputs[index], name="tensor")
shape = val.get_shape()
if shape.is_fully_defined():
# If the shape is known statically, just create a zero tensor with
# the right shape in the grad loop context.
if val.dtype == dtypes.resource:
result = array_ops.zeros(
resource_variable_ops.variable_shape(val),
dtype=default_gradient.get_zeros_dtype(val))
else:
result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
if dead_branch:
# op is a cond switch. Guard the zero tensor with a switch.
pred = grad_state.history_map.get(op_ctxt.pred.name)
branch = op_ctxt.branch
result = control_flow_ops._SwitchRefOrTensor(result, pred)[1 - branch]
else:
# Unknown shape so keep a history of the shape at runtime.
if dead_branch:
# Need to add a special switch to guard the value.
pred = op_ctxt.pred
branch = op_ctxt.branch
op_ctxt.outer_context.Enter()
val = control_flow_ops._SwitchRefOrTensor(op.inputs[0],
pred)[1 - branch]
zeros_shape = array_ops.shape_internal(val, optimize=False)
op_ctxt.outer_context.Exit()
val.op._set_control_flow_context(op_ctxt)
zeros_shape.op._set_control_flow_context(op_ctxt)
else:
op_ctxt.Enter()
zeros_shape = array_ops.shape_internal(val, optimize=False)
op_ctxt.Exit()
# Add forward accumulator for shape.
grad_state.grad_context.Exit()
history_zeros_shape = grad_state.AddForwardAccumulator(
zeros_shape, dead_branch=dead_branch)
grad_state.grad_context.Enter()
# Create a zero tensor with the right shape.
shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
zeros_shape, dead_branch)
result = array_ops.zeros(shape, val.dtype)
return result
def PostProcessing(self):
"""Perform postprocessing at the end of gradients().
We have created the gradient graph at this point. So this function
can be used to perform any postprocessing on the gradient graph.
We currently perform the following postprocessing:
1. Patch the gradient graph if the output of a loop variable
doesn't depend on its input.
"""
for _, grad_state in self._map.items():
for _, b_merge in grad_state.switch_map.items():
if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
# The value of this loop variable at iteration i+1 doesn't
# depend on its value at iteration i. So use zeros as the
# gradients for all iterations > 0.
dtype = b_merge.op.inputs[0].dtype
shape = b_merge.op.inputs[0].get_shape()
# pylint: disable=protected-access
if shape.is_fully_defined():
grad_state.grad_context.Enter()
# Create a zeros and use it for iterations > 0.
grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
next_grad_val = control_flow_ops._NextIteration(grad_val)
grad_state.grad_context.Exit()
else:
# Create a zeros in the outer grad context.
outer_grad_ctxt = grad_state.grad_context.outer_context
if outer_grad_ctxt:
outer_grad_ctxt.Enter()
enter_grad_op = b_merge.op.inputs[0].op
enter_grad = enter_grad_op.inputs[0]
grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
grad_val = array_ops.zeros(grad_shape)
if outer_grad_ctxt:
outer_grad_ctxt.Exit()
# Use the zeros for iterations > 0.
grad_state.grad_context.Enter()
next_grad_val = control_flow_ops._NextIteration(grad_val)
grad_state.grad_context.Exit()
b_merge.op._update_input(1, next_grad_val)
# pylint: enable=protected-access
def MaybeCreateControlFlowState(between_op_list, between_ops,
colocate_gradients_with_ops):
"""Create the state for all the while loops involved in one gradients().
We create a _ControlFlowState when there are while loops involved in
gradients(). In gradients(), control flow logic is only invoked when
the _ControlFlowState is not None.
Note that this method modifies `between_op_list` and `between_ops`.
"""
loop_state = None
for op in between_op_list:
if util.IsLoopExit(op):
if loop_state is None:
loop_state = _ControlFlowState()
if colocate_gradients_with_ops:
with ops.colocate_with(op):
loop_state.AddWhileContext(op, between_op_list, between_ops)
else:
loop_state.AddWhileContext(op, between_op_list, between_ops)
return loop_state
def _ZerosLikeV1(op, index):
"""Branch of ZerosLike for TF1."""
val = op.outputs[index]
op_ctxt = op._get_control_flow_context() # pylint: disable=protected-access
if op_ctxt:
# We are in a cond context. Use a switch to create zeros only when needed.
pred = op_ctxt.pred
branch = op_ctxt.branch
switch_val = control_flow_ops.switch(op.inputs[0], pred)[1 - branch]
# A op is created along the branch taken as control dependencies are on
# the whole op and not on the tensor output.
pivot = array_ops.identity(switch_val)
if val.dtype == dtypes.resource:
with ops.control_dependencies([pivot]):
return array_ops.zeros(
gen_resource_variable_ops.variable_shape(switch_val),
dtype=default_gradient.get_zeros_dtype(val))
zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
# Ensure ops created within array_ops.zeros are dominated by switch in
# cond context.
with ops.control_dependencies([pivot]):
return array_ops.zeros(zeros_shape, dtype=val.dtype)
else:
return array_ops.zeros_like(val, optimize=False)
def _ZerosLikeV2(op, index):
"""Branch of ZerosLike for TF2."""
val = op.outputs[index]
if val.dtype == dtypes.resource:
return array_ops.zeros(
gen_resource_variable_ops.variable_shape(val),
dtype=default_gradient.get_zeros_dtype(val))
if (isinstance(val.op.graph, control_flow_v2_func_graphs.WhileBodyFuncGraph)
and val.dtype != dtypes.variant):
# In while_v2 we do not want to add a `ZerosLike` op because that will
# trigger accumulation of `val`. Normally `ZerosLike` is preferred because
# it helps avoid creating extra nodes(possibly Consts) for the shape.
# For variants, we must use ZerosLike.
if val.shape.is_fully_defined():
return constant_op.constant(0, shape=val.shape.dims, dtype=val.dtype)
else:
# Note: Even though we add `Shape` in the default graph, while_v2 is smart
# enough to place it in the forward graph i.e. `val.graph`.
zeros_shape = array_ops.shape_internal(val, optimize=False)
return array_ops.zeros(zeros_shape, val.dtype)
else:
return array_ops.zeros_like(val, optimize=False)
def ZerosLike(op, index):
"""Create zeros_like for the specified output of an op."""
if not util.IsSwitch(op):
return _ZerosLikeV2(op, index)
else:
return _ZerosLikeV1(op, index)