# blob: 64f7ed8d40db6158db2ee70106ece465c7ab48fb [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Methods for rewriting while_v2 grad functions with IndexedSlices output."""
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.util import nest
def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
                                forward_inputs):
  """Reconciles IndexedSlices outputs of a while grad body with its inputs.

  A gradient function may emit an IndexedSlices rather than a dense Tensor
  (the gradient of Gather ops is one example). Inside the gradient of a while
  body this leaves the body function with mismatched signatures: the incoming
  gradient is a single Tensor, whereas the IndexedSlices result is unnested
  into three output Tensors.

  Every IndexedSlices output is handled in one of two ways:
    * If the matching forward input is a resource handle, the grad body input
      is rewritten into an IndexedSlices as well (three inputs for the three
      outputs) and `loop_vars` is updated to reflect the new inputs.
    * Otherwise the output is converted to a dense Tensor.

  Args:
    grads: the input gradient Tensors to the while gradient computation.
      Entries may be None.
    body_grad_graph: _WhileBodyGradFuncGraph.
    loop_vars: list of Tensors. The inputs to body_grad_graph.
    forward_inputs: list of Tensors. The (flat) inputs to the forward-pass
      While op.

  Returns:
    The loop_vars to pass to body_grad_graph, updated for rewritten inputs.
  """
  # Pair each entry of body_grad_graph.structured_outputs with the forward
  # input it belongs to. Only forward inputs whose incoming gradient is not
  # None take part in the pairing.
  #
  # Gradient computations are not expected to return nested structures, so
  # structured_outputs is used without flattening. Unlike
  # body_grad_graph.outputs (flattened composites), it still holds composite
  # tensors such as IndexedSlices as single entries.
  inputs_with_grads = [
      forward_input for grad, forward_input in zip(grads, forward_inputs)
      if grad is not None
  ]
  # Skip loop counter, maximum_iterations and total number of loop iterations.
  # The slice is a snapshot, so the in-place rewrites below do not disturb
  # the iteration.
  grad_outputs = body_grad_graph.structured_outputs[3:]

  for forward_input, grad_output in zip(inputs_with_grads, grad_outputs):
    if isinstance(grad_output, indexed_slices.IndexedSlices):
      if forward_input.dtype == dtypes.resource:
        # TODO(skyewm): In theory we should use this for all captured inputs,
        # not just resource handles (which can only be captured). We can do
        # this by checking that forward_input is passed straight through to
        # its output.
        loop_vars = _rewrite_input_as_indexed_slices(
            body_grad_graph, grad_output, forward_input, loop_vars)
      else:
        _rewrite_output_as_tensor(body_grad_graph, grad_output)

  return loop_vars
def _get_tensor_index_in_iterable(iterable, t):
"""Returns index of first occurence of `t`, raises ValueError if not found."""
for i, elem in enumerate(iterable):
if t is elem:
return i
raise ValueError(f"Element `{t!r}` is not found in iterable `{iterable!r}`.")
def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
  """Replaces an IndexedSlices output of the grad body with a Tensor output.

  The IndexedSlices is converted to a Tensor inside `body_grad_graph`, the
  matching entry of `body_grad_graph.structured_outputs` is overwritten in
  place, and `body_grad_graph.outputs` is recomputed from the structured
  outputs.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
  """
  # Do the conversion with the grad body as the default graph.
  with body_grad_graph.as_default():
    dense_output = tensor_conversion.convert_to_tensor_v2(grad_output_slices)

  slot = _get_tensor_index_in_iterable(
      body_grad_graph.structured_outputs, grad_output_slices)
  body_grad_graph.structured_outputs[slot] = dense_output
  # Keep the flat output list in sync with the structured one.
  flat_outputs = func_graph.flatten(body_grad_graph.structured_outputs)
  body_grad_graph.outputs = flat_outputs
def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
                                     forward_input, loop_vars):
  """Turns the grad body input matching grad_output_slices into IndexedSlices.

  This rewrite is only valid when forward_input was captured in the forward
  loop, i.e. it is not a user-specified loop variable. The rewrite assumes
  forward_input is passed through to its corresponding output unchanged; the
  output rewrite (see _rewrite_grad_indexed_slices_output) depends on the
  exact gradient structure produced by the input's fanout.

  Compared with _rewrite_output_as_tensor this can yield a more efficient
  computation, because the IndexedSlices structure is preserved instead of
  being converted to a dense Tensor.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
    forward_input: the corresponding Tensor input to the forward loop.
    loop_vars: list of Tensors. The inputs to body_grad_graph.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Initial IndexedSlices fed to the grad While op. It starts out as zeros and
  # accumulates the IndexedSlices grad output. Since forward_input is captured
  # rather than a loop var, its incoming gradient should always be zero.
  init_slices = _create_grad_indexed_slices_init(grad_output_slices,
                                                 forward_input)

  # Build a second version of grad_output_slices's gradient computation that
  # consumes the new IndexedSlices input instead of the original Tensor input.
  # The new computation is returned; the old one is left behind as dead code.
  # TODO(skyewm): considering pruning body_grad_graph to remove the old
  # computation.
  with body_grad_graph.as_default():
    captured_values = body_grad_graph.capture(
        init_slices.values, allowlisted=True)
    captured_indices = body_grad_graph.capture(
        init_slices.indices, allowlisted=True)
    captured_dense_shape = body_grad_graph.capture(
        init_slices.dense_shape, allowlisted=True)
    input_slices = indexed_slices.IndexedSlices(
        values=captured_values,
        indices=captured_indices,
        dense_shape=captured_dense_shape)

    # Drop the captured tensors from the function inputs again; they are
    # re-inserted at the correct index by _update_indexed_slices_param.
    for external_tensor in _flatten(init_slices):
      placeholder = body_grad_graph.captures.pop(external_tensor)
      body_grad_graph.inputs.remove(placeholder)

    new_output_slices = _rewrite_grad_indexed_slices_output(
        grad_output_slices, input_slices)

  # Make body_grad_graph's inputs and outputs reflect the new IndexedSlices
  # computation.
  return _update_indexed_slices_param(
      body_grad_graph, loop_vars, init_slices, input_slices,
      new_output_slices, grad_output_slices)
def _create_grad_indexed_slices_init(grad_output_slices, forward_input):
  """Builds the initial IndexedSlices passed into the while grad function.

  The result has a `values` tensor whose leading dimension is 0, an empty
  `indices` tensor, and a `dense_shape` taken from `forward_input`.

  Args:
    grad_output_slices: IndexedSlices. The corresponding while grad function
      output.
    forward_input: Tensor. The corresponding input to the forward while op.

  Returns:
    Zeros IndexedSlices, created in current Graph.
  """
  assert isinstance(grad_output_slices, indexed_slices.IndexedSlices)
  assert isinstance(forward_input, ops.Tensor)

  def forward_shape_op(name=None):
    # Resource inputs are queried through variable_shape; every other tensor
    # goes through array_ops.shape.
    if forward_input.dtype == dtypes.resource:
      return gen_resource_variable_ops.variable_shape(forward_input, name=name)
    return array_ops.shape(forward_input, name=name)

  grad_values = grad_output_slices.values
  grad_indices = grad_output_slices.indices

  # Initial values tensor: zero rows, trailing dimensions as in the grad
  # output. Use the static shape when it is fully known; otherwise derive the
  # trailing dimensions from forward_input's runtime shape.
  static_values_shape = grad_values.shape
  if static_values_shape.is_fully_defined():
    values_shape = tensor_shape.TensorShape(
        [0] + static_values_shape.as_list()[1:])
  else:
    values_shape = array_ops.concat([[0], forward_shape_op()[1:]], 0)
  values = array_ops.zeros(
      values_shape, dtype=grad_values.dtype, name="values_init")

  # Initial indices tensor: empty, with the dtype of the grad output indices.
  indices = constant_op.constant([], grad_indices.dtype, name="indices_init")

  # Initial dense_shape tensor. We assume it is the same shape as
  # forward_input, since captured tensors don't change shape across loop
  # iterations.
  dense_shape = forward_shape_op(name="shape_init")

  return indexed_slices.IndexedSlices(
      values=values, indices=indices, dense_shape=dense_shape)
def _rewrite_grad_indexed_slices_output(old_output_slices, new_input_slices):
  """Rebuilds old_output_slices so that it consumes new_input_slices.

  The rewrite relies on old_output_slices.{values,indices} each being an
  Identity of a ConcatV2 that concatenates the incoming gradient Tensor input
  with the IndexedSlices produced by the gradient computation of the while
  body. See backprop.aggregate_indexed_slices_gradients for where these
  concats are constructed. New concats are built that take new_input_slices
  in place of the original Tensor input.

  Args:
    old_output_slices: original IndexedSlices output of while gradient.
    new_input_slices: new IndexedSlices to use as input to while gradient.

  Returns:
    A new IndexedSlices to replace old_output_slices.
  """

  def rebuild_concat(identity_op, replacement_input):
    assert identity_op.type == "Identity"
    producer = identity_op.inputs[0].op
    assert producer.type == "ConcatV2"
    # Drop the last input (the axis arg) and the first input, which we assume
    # was the original gradient input to the concat op.
    # TODO(skyewm): do this in a more robust way.
    kept_args = producer.inputs[1:-1]
    return array_ops.concat([replacement_input] + kept_args, 0)

  new_values = rebuild_concat(
      old_output_slices.values.op, new_input_slices.values)
  new_indices = rebuild_concat(
      old_output_slices.indices.op, new_input_slices.indices)

  return indexed_slices.IndexedSlices(
      values=new_values,
      indices=new_indices,
      dense_shape=new_input_slices.dense_shape)
def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
  """Installs the new IndexedSlices input/output on graph.

  Rewrites graph's metadata so that it outputs the gradient computation
  defined by init_slices, input_slices and output_slices rather than
  old_output_slices, and builds a new loop_vars list in which init_slices
  replaces the old input.

  The flat position of old_output_slices in graph.outputs is reused to splice
  both graph.inputs and loop_vars, so those lists are treated as positionally
  aligned with graph.outputs.

  Args:
    graph: _WhileBodyGradFuncGraph.
    loop_vars: the inputs to graph.
    init_slices: the new IndexedSlices to use as input to graph.
    input_slices: the new IndexedSlices in graph that should be fed by
      init_slices.
    output_slices: the new IndexedSlices in graph that should be the
      corresponding output to input_slices.
    old_output_slices: the IndexedSlices in graph that are currently being
      output.

  Returns:
    New loop_vars to pass to graph.
  """
  # Both positions are looked up before anything is mutated.
  structured_pos = _get_tensor_index_in_iterable(graph.structured_outputs,
                                                 old_output_slices)
  # The component tensors of old_output_slices are assumed to sit next to each
  # other in graph.outputs; the first component serves as the reference index.
  first_component = func_graph.flatten(old_output_slices)[0]
  flat_pos = _get_tensor_index_in_iterable(graph.outputs, first_component)

  def splice(sequence, replacement):
    # Swap the single element at flat_pos for all elements of `replacement`.
    return sequence[:flat_pos] + replacement + sequence[flat_pos + 1:]

  graph.structured_outputs[structured_pos] = output_slices
  graph.outputs = func_graph.flatten(graph.structured_outputs)

  graph.inputs = splice(graph.inputs, _flatten(input_slices))

  return splice(loop_vars, _flatten(init_slices))
def _flatten(arg):
  """Flattens `arg` via nest.flatten with composite expansion enabled."""
  flat_components = nest.flatten(arg, expand_composites=True)
  return flat_components