# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Methods for rewriting while_v2 grad functions with IndexedSlices output."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.util import nest


def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
                                forward_inputs):
  """Handles the special case of IndexedSlices returned from a while gradient.

  Some gradient functions return IndexedSlices instead of a Tensor (e.g. the
  gradient of Gather ops). When this happens in the gradient of a while body,
  the resulting gradient body function will have mismatched inputs and outputs,
  since the input is a single Tensor, but the IndexedSlices is unnested into
  three output Tensors (values, indices, and dense_shape).

  This function fixes this by rewriting the gradient body to have three inputs
  to match the three outputs, i.e., it effectively converts the input Tensor
  into an input IndexedSlices. It also returns new `loop_vars` to reflect the
  new inputs.

  Args:
    grads: the input gradient Tensors to the while gradient computation.
    body_grad_graph: _WhileBodyGradFuncGraph.
    loop_vars: list of Tensors. The inputs to body_grad_graph.
    forward_inputs: list of Tensors. The (flat) inputs to the forward-pass
      While op.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """

  # Match up body_grad_graph.structured_outputs with the corresponding
  # forward_inputs.
  #
  # Note that we don't expect a gradient computation to have structured output
  # (e.g. no nested lists), so no need to flatten
  # body_grad_graph.structured_outputs. However, structured_outputs may still
  # contain composite tensors such as IndexedSlices, unlike
  # body_grad_graph.outputs, which contains flattened composite tensors.
  inputs_with_grads = [
      t for g, t in zip(grads, forward_inputs) if g is not None
  ]
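  # E.g. (hypothetical values): grads = [g0, None, g2] with
  # forward_inputs = [a, b, c] yields inputs_with_grads = [a, c], lining the
  # forward inputs up one-to-one with the gradient outputs matched below.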

  # Skip the loop counter, maximum_iterations, and the total number of loop
  # iterations.
  structured_outputs = body_grad_graph.structured_outputs[3:]

  for forward_input, output in zip(inputs_with_grads, structured_outputs):
    if not isinstance(output, indexed_slices.IndexedSlices):
      continue

    if forward_input.dtype == dtypes.resource:
      # TODO(skyewm): In theory we should use this for all captured inputs, not
      # just resource handles (which can only be captured). We can do this by
      # checking that forward_input is passed straight through to its output.
      loop_vars = _rewrite_input_as_indexed_slices(body_grad_graph, output,
                                                   forward_input, loop_vars)
    else:
      _rewrite_output_as_tensor(body_grad_graph, output)

  return loop_vars


def _get_tensor_index_in_iterable(iterable, t):
  """Returns index of first occurrence of `t`; raises ValueError if not found."""
  for i, elem in enumerate(iterable):
    if t is elem:
      return i
  raise ValueError(f"Element `{t!r}` is not found in iterable `{iterable!r}`.")


def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
  """Rewrites grad_output_slices to be a Tensor output.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
  """
  with body_grad_graph.as_default():
    new_output = tensor_conversion.convert_to_tensor_v2(grad_output_slices)

  idx = _get_tensor_index_in_iterable(body_grad_graph.structured_outputs,
                                      grad_output_slices)
  body_grad_graph.structured_outputs[idx] = new_output
  body_grad_graph.outputs = func_graph.flatten(
      body_grad_graph.structured_outputs)


def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
                                     forward_input, loop_vars):
  """Rewrites grad_output_slices's corresponding input to be an IndexedSlices.

  This rewrite requires that forward_input was captured in the forward loop,
  i.e. is not a user-specified loop variable. This is important because the
  rewrite assumes that forward_input is passed through to its corresponding
  output unchanged. This assumption is relied upon by
  _rewrite_grad_indexed_slices_output, which depends on the exact gradient
  structure produced by the input's fanout.

  This can yield a more efficient computation than using
  _rewrite_output_as_tensor, since it preserves the IndexedSlices structure
  instead of converting the IndexedSlices to a dense Tensor.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
    forward_input: the corresponding Tensor input to the forward loop.
    loop_vars: list of Tensors. The inputs to body_grad_graph.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """

  # Create an initial IndexedSlices to be the input to the grad While op. This
  # will start as zeros and accumulate the IndexedSlices grad output. Note that
  # because forward_input is captured and not a loop var, its incoming gradient
  # should always be zero.
  init_slices = _create_grad_indexed_slices_init(grad_output_slices,
                                                 forward_input)

  # Create a new version of grad_output_slices's gradient computation that uses
  # the new IndexedSlices input instead of the original Tensor input. We'll
  # return the new computation and leave the old computation as dead code.
  # TODO(skyewm): consider pruning body_grad_graph to remove the old
  # computation.
  with body_grad_graph.as_default():
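    # NOTE: allowlisted=True marks these as deliberate external captures;
    # roughly speaking, _WhileBodyGradFuncGraph otherwise restricts what it
    # will capture (it special-cases captures of forward-pass tensors), and
    # init_slices lives outside the forward graph.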
    input_slices = indexed_slices.IndexedSlices(
        values=body_grad_graph.capture(init_slices.values, allowlisted=True),
        indices=body_grad_graph.capture(init_slices.indices, allowlisted=True),
        dense_shape=body_grad_graph.capture(
            init_slices.dense_shape, allowlisted=True))

    # Remove the captured tensors from the function inputs. We'll add them back
    # at the correct index in _update_indexed_slices_param.
    for t in _flatten(init_slices):
      captured_t = body_grad_graph.captures.pop(t)
      body_grad_graph.inputs.remove(captured_t)

    new_output_slices = _rewrite_grad_indexed_slices_output(
        grad_output_slices, input_slices)

  # Update body_grad_graph's inputs and outputs to reflect the new
  # IndexedSlices computation.
  return _update_indexed_slices_param(body_grad_graph, loop_vars, init_slices,
                                      input_slices, new_output_slices,
                                      grad_output_slices)


def _create_grad_indexed_slices_init(grad_output_slices, forward_input):
  """Creates an IndexedSlices to pass as input to the while grad function.

  Args:
    grad_output_slices: IndexedSlices. The corresponding while grad function
      output.
    forward_input: Tensor. The corresponding input to the forward while op.

  Returns:
    Zeros IndexedSlices, created in the current Graph.
  """
  assert isinstance(grad_output_slices, indexed_slices.IndexedSlices)
  assert isinstance(forward_input, ops.Tensor)
  values_out = grad_output_slices.values
  indices_out = grad_output_slices.indices

  # Create the initial values tensor. Its leading dimension is 0 because the
  # initial gradient contains no accumulated slices; rows are appended by
  # concat on each loop iteration.
  if values_out.shape.is_fully_defined():
    values_shape = tensor_shape.TensorShape([0] +
                                            values_out.shape.as_list()[1:])
    values = array_ops.zeros(
        values_shape, dtype=values_out.dtype, name="values_init")
  else:
    if forward_input.dtype == dtypes.resource:
      forward_shape = gen_resource_variable_ops.variable_shape(forward_input)
    else:
      forward_shape = array_ops.shape(forward_input)
    values_shape = array_ops.concat([[0], forward_shape[1:]], 0)
    values = array_ops.zeros(
        values_shape, dtype=values_out.dtype, name="values_init")

  # Create the initial indices tensor.
  indices = constant_op.constant([], indices_out.dtype, name="indices_init")

  # Create the initial dense_shape tensor. We assume it is the same shape as
  # forward_input, since captured tensors don't change shape across loop
  # iterations.
  if forward_input.dtype == dtypes.resource:
    shape = gen_resource_variable_ops.variable_shape(
        forward_input, name="shape_init")
  else:
    shape = array_ops.shape(forward_input, name="shape_init")

  return indexed_slices.IndexedSlices(
      values=values, indices=indices, dense_shape=shape)


def _rewrite_grad_indexed_slices_output(old_output_slices, new_input_slices):
  """Creates a new version of old_output_slices with new_input_slices as input.

  This method assumes that old_output_slices.{values,indices} are produced by
  concatenating the incoming gradient Tensor input with the IndexedSlices
  produced by the gradient computation of the while body. See
  backprop.aggregate_indexed_slices_gradients for where these concats are
  constructed. We build new concats that use new_input_slices instead of the
  original Tensor input.

  Args:
    old_output_slices: original IndexedSlices output of while gradient.
    new_input_slices: new IndexedSlices to use as input to while gradient.

  Returns:
    A new IndexedSlices to replace old_output_slices.
  """

  def rewrite(old_output, new_input):
    assert old_output.type == "Identity"
    concat_op = old_output.inputs[0].op
    assert concat_op.type == "ConcatV2"
    # Don't include the final input, which is the concat axis argument.
    old_concat_args = concat_op.inputs[:-1]
    # We assume that the original gradient input was the first argument to the
    # concat op.
    # TODO(skyewm): do this in a more robust way.
    return array_ops.concat([new_input] + old_concat_args[1:], 0)

  values = rewrite(old_output_slices.values.op, new_input_slices.values)
  indices = rewrite(old_output_slices.indices.op, new_input_slices.indices)
  return indexed_slices.IndexedSlices(
      values=values, indices=indices, dense_shape=new_input_slices.dense_shape)


def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
  """Updates graph with new IndexedSlices input/output.

  Updates graph's metadata to output the gradient computation defined by
  init_slices, input_slices, and output_slices, instead of outputting
  old_output_slices. Also returns a new version of loop_vars with init_slices
  replacing the old input.

  Args:
    graph: _WhileBodyGradFuncGraph.
    loop_vars: the inputs to graph.
    init_slices: the new IndexedSlices to use as input to graph.
    input_slices: the new IndexedSlices in graph that should be fed by
      init_slices.
    output_slices: the new IndexedSlices in graph that should be the
      corresponding output to input_slices.
    old_output_slices: the IndexedSlices in graph that are currently being
      output.

  Returns:
    New loop_vars to pass to graph.
  """
  structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs,
                                                 old_output_slices)
  # We assume that the component tensors of old_output_slices appear
  # sequentially in graph.outputs. We use the first of these tensors
  # as the reference index.
  flat_idx = _get_tensor_index_in_iterable(
      graph.outputs,
      func_graph.flatten(old_output_slices)[0])

  graph.structured_outputs[structured_idx] = output_slices
  graph.outputs = func_graph.flatten(graph.structured_outputs)

  graph.inputs = (
      graph.inputs[:flat_idx] + _flatten(input_slices) +
      graph.inputs[flat_idx + 1:])

  return loop_vars[:flat_idx] + _flatten(init_slices) + loop_vars[flat_idx + 1:]


def _flatten(arg):
  """Flattens `arg` into Tensors, expanding composites (e.g. IndexedSlices)."""
  return nest.flatten(arg, expand_composites=True)