blob: f09a280a9438a3aaee57ecde0d936ec39e20f891 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""cond_v2 and gradient.
This is a version of cond that emits a single If op, as well as the gradient
function for If ops produced by cond_v2. This will eventually replace the
current tf.cond implementation once it reaches feature and performance parity.
"""
import collections
from tensorflow.core.framework import types_pb2
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_optional_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify
# that they aren't part of the official public API. These protected members
# often need to be used by implementation code however. Rather than litter the
# code with pylint comments, we ignore protected access violations for
# readability.
# pylint: disable=protected-access
# Construct identifiers passed as `op_type` to the shared helpers in this file.
# Helpers compare `op_type == _COND` to choose tf.cond wording ("true_fn",
# "tf.cond") versus tf.switch_case wording ("branches[i]", "tf.switch_case")
# in their error messages.
_COND = 1
_CASE = 2
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Builds a conditional like tf.cond, but emitted as a single If op.

  Both branch callables are traced into their own branch FuncGraphs, which are
  then handed to `_build_cond` together with the tensors each branch captures
  from the enclosing graph.

  Args:
    pred: A boolean Tensor (or value convertible to one) selecting the branch.
      Python bools are rejected. Unless `pred` is statically known to be rank
      0, it is squeezed before use.
    true_fn: Callable traced into the "true" branch function.
    false_fn: Callable traced into the "false" branch function.
    name: Name scope for the conditional; any falsy value means "cond".

  Returns:
    The conditional's outputs, packed like the branches' structured outputs
    (as produced by `_build_cond`).

  Raises:
    TypeError: If `pred` is a Python bool, or (from the helpers) if the two
      branches' outputs cannot be reconciled.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)
  name = name or "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns but not in v1 graphs;
    # the branch graphs inherit whichever behavior the enclosing graph has.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    pred = ops.convert_to_tensor(pred)
    if tensor_util.is_tf_type(pred):
      pred_dims = pred.shape.dims
      # Squeeze unless pred is statically known to be rank 0 (dims == []).
      if pred_dims is None or pred_dims:
        pred = array_ops.squeeze_v2(pred)

    def _trace_branch(fn_name, branch_fn):
      """Traces `branch_fn` into a fresh CondBranchFuncGraph named `fn_name`."""
      branch_graph = util.CondBranchFuncGraph(
          fn_name, collections=ops.get_default_graph()._collections)
      return func_graph_module.func_graph_from_py_func(
          fn_name,
          branch_fn, [], {},
          func_graph=branch_graph,
          add_control_dependencies=add_control_dependencies,
          op_return_value=pred)

    true_graph = _trace_branch(true_name, true_fn)
    false_graph = _trace_branch(false_name, false_fn)

    verify_captures(_COND, [true_graph, false_graph])
    return _build_cond(
        pred,
        true_graph,
        false_graph,
        true_graph.external_captures,
        false_graph.external_captures,
        building_gradient=False,
        name=scope)
@ops.RegisterGradient("StatelessIf")
@ops.RegisterGradient("If")
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of an If op produced by cond_v2.

  Builds one gradient FuncGraph per forward branch. If either gradient graph
  needs forward intermediates, the forward If op is rewritten in place so it
  also outputs them. A gradient If op on the same predicate is then emitted.

  Args:
    op: The forward If/StatelessIf operation, or a stand-in whose first
      output's `.op` is the real If op (see the MockOp note below).
    *grads: Incoming gradients; `_grad_fn` asserts there is one per output of
      each forward branch graph.

  Returns:
    A list whose first element is None (the predicate has no gradient),
    followed by the outputs of the gradient If op.
  """
  # Get the if operator (this logic handles the case where op is a MockOp)
  if_op = op.outputs[0].op
  true_graph, false_graph = get_func_graphs(if_op)
  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
  # of a nested cond.
  assert true_graph.outer_graph == if_op.graph
  assert false_graph.outer_graph == if_op.graph

  # Create grad functions that compute the gradient of the true/false forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  true_grad_graph = _create_grad_func(
      true_graph, grads, util.unique_grad_fn_name(true_graph.name))
  false_grad_graph = _create_grad_func(
      false_graph, grads, util.unique_grad_fn_name(false_graph.name))

  # Replaces output None grads with zeros if at least one branch has non-None
  # grad at that index.
  _create_zeros_for_none_grads([true_graph, false_graph],
                               [true_grad_graph, false_grad_graph])

  if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
    # Modify 'op' to output the intermediates needed by the grad functions. Note
    # that all needed intermediates are wrapped in optionals. Each optional
    # intermediate output will have a value iff its corresponding branch is
    # taken.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so output intermediates directly and
      # make them match via FakeParams, which can be converted to zeros in XLA.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      true_intermediates = true_grad_graph.xla_intermediates
      false_intermediates = false_grad_graph.xla_intermediates
      extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
          [true_graph, false_graph], [true_intermediates, false_intermediates])
    else:
      true_intermediates = true_grad_graph.wrapped_intermediates
      false_intermediates = false_grad_graph.wrapped_intermediates
      # Make outputs match by adding none optionals.
      extra_true_outputs, extra_false_outputs = _make_intermediates_match(
          [true_graph, false_graph], [true_intermediates, false_intermediates])

    # Append the padded intermediates to both forward branches so their output
    # signatures stay identical.
    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    # TODO(skyewm): indicate it's an internal bug if this fails.
    _check_same_outputs(_COND, [true_graph, false_graph])

    # Rename the branches, since new function definitions are registered for
    # them below.
    true_graph.name += "_rewritten"
    false_graph.name += "_rewritten"

    # Point the existing forward op at the rewritten branch functions and
    # extend its declared output types/shapes and actual outputs to match.
    if_op._set_func_attr("then_branch", util.create_new_tf_function(true_graph))
    if_op._set_func_attr("else_branch",
                         util.create_new_tf_function(false_graph))
    if_op._set_type_list_attr("Tout", true_graph.output_types)
    if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes)
    if_op._add_outputs(
        [t.dtype for t in extra_true_outputs],
        [t.shape for t in extra_true_outputs])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
  false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)

  # This modifies true_grad_graph and false_grad_graph.
  _make_output_composite_tensors_match(_COND,
                                       [true_grad_graph, false_grad_graph])

  # The gradient If op reuses the forward op's predicate (input 0).
  outputs = _build_cond(
      if_op.inputs[0],
      true_grad_graph,
      false_grad_graph,
      true_grad_inputs,
      false_grad_inputs,
      building_gradient=True,
  )

  # The predicate has no gradient.
  return [None] + outputs
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                building_gradient,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match.
  When `building_gradient` is False and `util.output_all_intermediates()` is
  True, it also modifies them to output all intermediate values (wrapped in
  optionals) so they're available for the gradient computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    building_gradient: Whether this is a gradient If op.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.

  Raises:
    TypeError, ValueError: Raised by `_make_indexed_slices_indices_types_match`
      or `_check_same_outputs` when the branch outputs cannot be reconciled.
  """
  _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
  _check_same_outputs(_COND, [true_graph, false_graph])

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match([true_graph, false_graph],
                                   [true_inputs, false_inputs])
  # We do not output intermediates of the gradient If op since this is just
  # for backwards compatibility with existing code.
  if not building_gradient and util.output_all_intermediates():
    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation. Since the outputs of the two functions must
    # match, we wrap all the intermediates in optionals. Each intermediate
    # output will have a value iff its corresponding branch is taken.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Wrap intermediates in optionals.
    wrapped_true_intermediates = _wrap_intermediates(true_graph,
                                                     true_intermediates)
    wrapped_false_intermediates = _wrap_intermediates(false_graph,
                                                      false_intermediates)

    # Make outputs match by adding none optionals.
    extra_true_outputs, extra_false_outputs = _make_intermediates_match(  # pylint: disable=unbalanced-tuple-unpacking
        [true_graph, false_graph],
        [wrapped_true_intermediates, wrapped_false_intermediates])

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    _check_same_outputs(_COND, [true_graph, false_graph])

  # Create the If op.
  # Control-dependency captures of either branch become control inputs of the
  # If op.
  with ops.control_dependencies(
      list(true_graph.function_captures.control) + list(
          false_graph.function_captures.control)):
    true_stateful_ops = [
        op for op in true_graph.get_operations() if op._is_stateful
    ]
    false_stateful_ops = [
        op for op in false_graph.get_operations() if op._is_stateful
    ]
    # Emit a plain `If` when either branch contains a stateful op, otherwise
    # the `StatelessIf` variant.
    if (true_stateful_ops or false_stateful_ops):
      op_fn = gen_functional_ops._if
    else:
      op_fn = gen_functional_ops.stateless_if

    def _make_op(inputs):
      """Emits the If op on `inputs` and wires its branch graphs to it."""
      if_op, tensors = util.get_op_and_outputs(op_fn(
          pred,
          inputs, [t.dtype for t in true_graph.outputs],
          util.create_new_tf_function(true_graph),
          util.create_new_tf_function(false_graph),
          output_shapes=_get_output_shapes(true_graph.outputs,
                                           false_graph.outputs),
          name=name))
      _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
      # `if_op` is None if this is a `StatelessIf` op with no outputs.
      if if_op is not None:
        # The true and false graphs have already been created, and we need that
        # to happen before we know which tensors will be captured and so whether
        # to wrap the cond in a tf.function. Post-hoc mutation of the branch
        # `outer_graph` properties seems like the only option if we want to
        # conditionally wrap in a function.
        true_graph.outer_graph = ops.get_default_graph()
        false_graph.outer_graph = ops.get_default_graph()
        # Cache the branch graphs on the op; `get_func_graphs` reads these
        # attributes back instead of rebuilding the graphs.
        if_op._true_graph = true_graph
        if_op._false_graph = false_graph
        util.maybe_set_lowering_attr(if_op)
        util.maybe_propagate_compile_time_consts_in_xla(if_op)
        _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
        # Prevent fetching since the variant outputs can't be fetched directly.
        if_op.graph.prevent_fetching(if_op)
      return tensors
    tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we covert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  structured_output_specs = _get_compatible_structured_output_specs(true_graph,
                                                                    false_graph)
  return _pack_sequence_as(structured_output_specs, tensors)
def get_func_graphs(op):
  """Recovers the branch `FuncGraph`s of an If or Case operation.

  A branch graph cached on the op (e.g. `op._true_graph`) is reused when
  present; otherwise it is rebuilt from the op's function attribute. Every
  returned graph has its captures reset to the op's non-selector inputs, has
  handle data copied from those inputs, and has `_forward_cond` set to `op`.

  Args:
    op: The If or Case Operation.

  Returns:
    For If/StatelessIf, a `(then_graph, else_graph)` tuple. For
    Case/StatelessCase, a list with one `FuncGraph` per branch.

  Raises:
    ValueError: If `op` is neither an If nor a Case op.
  """

  def _branch_func_graph(branch_attr, cache_attr=None):
    """Builds (or reuses) the FuncGraph for a single branch of `op`."""
    graph = None if cache_attr is None else getattr(op, cache_attr, None)
    # Input 0 selects the branch; the remaining inputs feed the branch body.
    branch_args = op.inputs[1:]
    if graph is None:
      arg_shapes = [arg.shape for arg in branch_args]
      graph = util.get_func_graph(op, arg_shapes, branch_attr.name)
    for outer_arg, inner_arg in zip(branch_args, graph.inputs):
      handle_data_util.copy_handle_data(outer_arg, inner_arg)
    graph.function_captures.reset_captures(branch_args, graph.inputs)
    # Link the op so that the gradient code can use it.
    graph._forward_cond = op
    return graph

  if op.type in ("If", "StatelessIf"):
    then_graph = _branch_func_graph(op.get_attr("then_branch"), "_true_graph")
    else_graph = _branch_func_graph(op.get_attr("else_branch"), "_false_graph")
    return then_graph, else_graph

  if op.type in ("Case", "StatelessCase"):
    return [
        _branch_func_graph(branch_attr, "_branch_graph_{}".format(idx))
        for idx, branch_attr in enumerate(op.get_attr("branches"))
    ]

  raise ValueError("Unsupported op type: {}".format(op.type))
def _get_compatible_structured_output_specs(true_graph, false_graph):
  """Merges both branches' structured outputs into common TypeSpecs.

  The two structures are walked in lockstep, and each pair of leaves is
  replaced by the most specific spec compatible with both.
  """
  true_structure = true_graph.structured_outputs
  false_structure = false_graph.structured_outputs
  return nest.map_structure(
      _get_compatible_spec, true_structure, false_structure)
def _get_compatible_spec(value_or_spec1, value_or_spec2):
  """Returns the most specific compatible spec.

  Both inputs are first normalized to TypeSpecs. Tensor names are stripped
  from both (via `_without_tensor_names`) before the common supertype is
  computed.

  Args:
    value_or_spec1: A TypeSpec or a value that has a defined TypeSpec.
    value_or_spec2: A TypeSpec or a value that has a defined TypeSpec.

  Returns:
    The most specific TypeSpec that is a common supertype of both inputs.

  Raises:
    TypeError: If value_or_spec1 and value_or_spec2 have no common supertype.
  """
  spec1 = _get_spec_for(value_or_spec1)
  spec2 = _get_spec_for(value_or_spec2)

  common = spec1._without_tensor_names().most_specific_common_supertype(
      [spec2._without_tensor_names()])
  # `most_specific_common_supertype` signals incompatibility by returning None.
  if common is None:
    raise TypeError(f"No common supertype of {spec1} and {spec2}.")
  return common
def _get_spec_for(value_or_spec):
  """Normalizes `value_or_spec` to a TypeSpec.

  TypeSpec instances are passed through unchanged; anything else is converted
  with `type_spec.type_spec_from_value`.
  """
  already_spec = isinstance(value_or_spec, type_spec.TypeSpec)
  return (value_or_spec if already_spec
          else type_spec.type_spec_from_value(value_or_spec))
def _grad_fn(func_graph, grads):
  """Differentiates one forward conditional branch.

  Only outputs that `backprop_util.IsTrainable` accepts take part. Without
  this filter, untrainable outputs (e.g. boolean tensors) can make
  `_GradientsHelper` raise. The gradient ops are created in the current
  default graph and capture tensors from `func_graph`; those captures are
  later mapped to outer tensors by `_resolve_grad_inputs`.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors, one per `func_graph.outputs`.

  Returns:
    The output gradient Tensors, one per `func_graph.inputs`.
  """
  assert len(func_graph.outputs) == len(grads)
  trainable_pairs = [
      (out, out_grad)
      for out, out_grad in zip(func_graph.outputs, grads)
      if backprop_util.IsTrainable(out)
  ]
  ys = [out for out, _ in trainable_pairs]
  grad_ys = [out_grad for _, out_grad in trainable_pairs]
  return gradients_util._GradientsHelper(
      ys, func_graph.inputs, grad_ys=grad_ys, src_graph=func_graph)
def _create_grad_func(func_graph, grads, name):
  """Traces the gradient of `func_graph` into its own `_CondGradFuncGraph`.

  Args:
    func_graph: The forward branch FuncGraph to differentiate.
    grads: Incoming gradients for `func_graph.outputs`.
    name: Name for the new gradient function graph.

  Returns:
    The gradient FuncGraph produced by `func_graph_from_py_func`.
  """
  grad_func_graph = _CondGradFuncGraph(name, func_graph)
  return func_graph_module.func_graph_from_py_func(
      name,
      lambda: _grad_fn(func_graph, grads),
      [],
      {},
      func_graph=grad_func_graph)
def _resolve_grad_inputs(cond_graph, grad_graph):
  """Returns the tensors to pass as inputs to `grad_graph`.

  The `grad_graph` may have external references to
  1. Its outer graph containing the input gradients. These references are kept
     as is.
  2. Tensors in the forward pass graph. These tensors may not be "live"
     when the gradient is being computed. We replace such references by their
     corresponding tensor in `cond_graph.outer_graph`. In the case of nested
     control flow or functions, the gradient logic handling
     `grad_graph.outer_graph` will make sure the tensor from
     `cond_graph.outer_graph` is also correctly captured.

  Args:
    cond_graph: FuncGraph. The forward-pass function.
    grad_graph: FuncGraph. The gradients function.

  Returns:
    A list of inputs tensors to be passed to grad_graph, in the same order as
    `grad_graph.external_captures`.

  Raises:
    ValueError: If a forward-graph tensor is neither an output nor an internal
      capture of `cond_graph`.
  """
  new_inputs = []

  for t in grad_graph.external_captures:
    # `t` must either be in `grad_graph.outer_graph` or in the forward
    # `cond_graph`.
    if t.graph != grad_graph.outer_graph:
      assert t.graph == cond_graph
      # Case A: `t` is a forward-branch output. Map it to the output at the
      # same index on the forward If op (`_forward_cond`).
      # Case B (the `else` below): `internal_captures` are not treated as
      # intermediates and hence not added to If op outputs. So we get the outer
      # tensor corresponding to those from the list of `external_captures`.
      for i, output in enumerate(t.graph.outputs):
        if output is t:
          t = t.graph._forward_cond.outputs[i]
          break
      else:
        for i, output in enumerate(t.graph.internal_captures):
          if output is t:
            t = t.graph.external_captures[i]
            break
        else:
          raise ValueError("Could not find external tensor capture {tensor} in "
                           "captures or outputs".format(tensor=t))

      # Note: We rely on the capturing logic of the gradient If op graph to
      # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
      # and while_v2 handle this while building their gradient functions.
      assert t.graph == cond_graph.outer_graph

    new_inputs.append(t)

  return new_inputs
def _get_intermediates(func_graph):
  """Returns intermediate tensors of `func_graph` for gradient computation.

  An intermediate is any op output that is not one of `func_graph.inputs` or
  `func_graph.outputs`. Resource-dtype tensors and every output of a
  `MutexLock` op are excluded.

  Args:
    func_graph: FuncGraph whose operations are scanned.

  Returns:
    A list of Tensors, in the order produced by iterating
    `func_graph.get_operations()` and each op's outputs.
  """
  # Compare by tensor id instead of `t in <list>`. This makes each check O(1)
  # instead of a scan of the input/output lists. It also avoids falling back to
  # `Tensor.__eq__`, which is not guaranteed to be a plain identity check.
  excluded_ids = {ops.tensor_id(t) for t in func_graph.inputs}
  excluded_ids.update(ops.tensor_id(t) for t in func_graph.outputs)

  intermediates = []
  for op in func_graph.get_operations():
    # Accumulating mutexes can cause deadlock.
    if op.type == "MutexLock":
      continue
    for t in op.outputs:
      if ops.tensor_id(t) in excluded_ids:
        continue
      if t.dtype == dtypes.resource:
        continue
      intermediates.append(t)
  return intermediates
def _make_intermediates_match(branch_graphs, branch_optionals):
  """Pads every branch's optionals list with none-optionals to a common length.

  Each list is extended at its end with `None` optionals created inside the
  corresponding branch graph, until every list is as long as the longest one.
  Like optionals are never merged.

  Args:
    branch_graphs: `list` of `FuncGraph`.
    branch_optionals: `list` of `list`s of optional `Tensor`s, one list per
      entry in `branch_graphs`.

  Returns:
    A `list` of `list`s of `Tensor`s, one per branch graph, all of the same
    length and all optionals.
  """
  # All optionals share the variant dtype, so position-by-position matching
  # only requires every list to reach the longest list's length.
  target_len = max(len(optionals) for optionals in branch_optionals)
  padded = []
  for idx, graph in enumerate(branch_graphs):
    own_optionals = branch_optionals[idx]
    padding = _create_none_optionals(graph, target_len - len(own_optionals))
    padded.append(own_optionals + padding)
  return padded
def _make_intermediates_match_xla(branch_graphs, branch_intermediates):
  """Like _make_intermediates_match but for the XLA case.

  XLA cannot use optionals, so intermediates are output directly. Every branch
  receives a FakeParam for each intermediate of every other branch, so all
  branches share one layout: the intermediates of branch 0, then those of
  branch 1, and so on. Each branch supplies its real tensors in its own slot
  and FakeParams (same dtype) everywhere else.

  Args:
    branch_graphs: `list` of `FuncGraph`.
    branch_intermediates: `list` of `list`s of intermediate `Tensor`s, one list
      per entry in `branch_graphs`.

  Returns:
    A `list` of `list`s of `Tensor`s, one per branch graph, all of the same
    length.
  """
  new_branch_intermediates = []
  for i, branch_graph in enumerate(branch_graphs):
    # Select the other branches by position, not by list identity. An identity
    # test would drop both entries if two branches were given the same list
    # object.
    other_intermediates = [
        t for j, bi in enumerate(branch_intermediates) if j != i for t in bi
    ]
    other_fakeparams = _create_fakeparams(branch_graph, other_intermediates)
    # FakeParams for branches before `i` go in front of this branch's own
    # intermediates; those for branches after `i` go behind.
    num_preceding = sum(len(bi) for bi in branch_intermediates[:i])
    new_branch_intermediates.append(other_fakeparams[:num_preceding] +
                                    branch_intermediates[i] +
                                    other_fakeparams[num_preceding:])
  return new_branch_intermediates
def _make_inputs_match(branch_graphs, branch_inputs):
  """Gives every branch graph the same input signature.

  The union of all branches' outer input tensors is computed (deduplicated by
  tensor id, keeping first-seen order). Each branch graph's parameters are then
  reordered to follow that union. Where a branch had no parameter for an outer
  tensor, a placeholder is created as a dummy input. Each graph's `inputs` and
  captures are reset to match.

  Args:
    branch_graphs: a `list` of `FuncGraph`
    branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The
      inputs for the corresponding graph in `branch_graphs`.

  Returns:
    A new list of Tensors from the outer graph that are the new inputs for each
    branch_graph. This is a deduped version of `sum(branch_inputs)`.
  """
  assert len(branch_graphs) == len(branch_inputs)

  # Dicts preserve insertion order, so setdefault keeps the first occurrence.
  unique_by_id = {}
  for branch_in in branch_inputs:
    for tensor in branch_in:
      unique_by_id.setdefault(ops.tensor_id(tensor), tensor)
  new_inputs = list(unique_by_id.values())

  for graph, branch_in in zip(branch_graphs, branch_inputs):
    param_by_id = {
        ops.tensor_id(outer): param
        for outer, param in zip(branch_in, graph.inputs)
    }
    matched_params = []
    for outer in new_inputs:
      param = param_by_id.get(ops.tensor_id(outer))
      if param is None:
        # This branch never used `outer`; give it a placeholder so the
        # signatures line up.
        param = _create_dummy_input(graph, outer)
      matched_params.append(param)
    graph.inputs = matched_params

    # Rewrite the FuncGraphs' state to reflect the new inputs.
    graph.function_captures.reset_captures(new_inputs, graph.inputs)

  return new_inputs
def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
  """Fills in zero gradients where branches disagree about None.

  For each gradient output position where some branches produced None and
  others did not, each None is replaced (in that gradient graph's
  `structured_outputs`) by `zeros_like` of the forward branch's input at the
  same position. Positions that are None in every branch stay None. Each
  gradient graph's `outputs` is then rebuilt from its structured outputs with
  the Nones dropped.

  Args:
    forward_graphs: List of forward FuncGraphs.
    grad_graphs: List of grad FuncGraphs, parallel to `forward_graphs`.
  """
  assert len(forward_graphs) == len(grad_graphs)
  branch_outputs = [g.structured_outputs for g in grad_graphs]
  num_outputs_per_branch = [len(outs) for outs in branch_outputs]
  assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch

  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
    missing = [grad is None for grad in branch_outs]
    # Patch only the positions where the branches disagree.
    if not (any(missing) and not all(missing)):
      continue
    for branch_index, is_missing in enumerate(missing):
      if not is_missing:
        continue
      grad_graph = grad_graphs[branch_index]
      with grad_graph.as_default():
        zeros = default_gradient.zeros_like(
            forward_graphs[branch_index].inputs[output_idx])
        grad_graph.structured_outputs[output_idx] = zeros

  for grad_graph in grad_graphs:
    flat_grads = func_graph_module.flatten(grad_graph.structured_outputs)
    grad_graph.outputs = [grad for grad in flat_grads if grad is not None]
def _make_output_composite_tensors_match(op_type, branch_graphs):
  """Aligns IndexedSlices/Tensor outputs across branches.

  Suppose branches disagree on the type of an output position and at least one
  of them returns an IndexedSlices there. Then every plain Tensor at that
  position is converted, inside its own branch graph, to an equivalent
  IndexedSlices. Any other kind of value at such a position is rejected. Each
  graph's `structured_outputs` and `outputs` are rewritten afterwards.

  Args:
    op_type: _COND or _CASE
    branch_graphs: `list` of `FuncGraph`

  Raises:
    TypeError: if a set of outputs cannot be rewritten.
  """
  # Only used for gradient graphs, whose outputs are expected to be flat (not
  # nested lists), so positions are indexed directly without nest.flatten.
  assert branch_graphs
  branch_outputs = [g.structured_outputs for g in branch_graphs]
  outputs_per_branch = [len(outs) for outs in branch_outputs]
  assert len(set(outputs_per_branch)) == 1, outputs_per_branch

  def _is_indexed_slices(value):
    return isinstance(value, indexed_slices.IndexedSlices)

  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
    all_same_type = len({type(out) for out in branch_outs}) == 1
    if all_same_type or not any(_is_indexed_slices(o) for o in branch_outs):
      continue
    for branch_idx, branch_out in enumerate(branch_outs):
      if _is_indexed_slices(branch_out):
        continue
      if not isinstance(branch_out, ops.Tensor):
        raise TypeError(
            "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
            " outputs from all branches: {outputs}".format(
                op_name="tf.cond" if op_type == _COND else "tf.switch_case",
                output_idx=output_idx,
                outputs=branch_outs))
      with branch_graphs[branch_idx].as_default():
        branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices(
            branch_out)

  for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
    branch_graph.structured_outputs = branch_outs
    branch_graph.outputs = [
        t for t in func_graph_module.flatten(branch_outs) if t is not None
    ]
def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
  """Match dtype of IndexedSlices.indices in outputs of branch_graphs.

  When branches return IndexedSlices at the same position but with different
  `indices` dtypes, the int32 indices are cast to int64 in place in each
  affected branch graph. Every graph's `structured_outputs` is then repacked
  from its `outputs`.

  Args:
    op_type: _COND or _CASE; only used to word error messages.
    branch_graphs: non-empty `list` of `FuncGraph`.

  Raises:
    TypeError: If branches disagree on whether an output is an IndexedSlices,
      or if any `indices` dtype is neither int32 nor int64.
    ValueError: If the computed flat output count does not equal
      `len(branch_graphs[0].outputs)`.
  """
  assert branch_graphs
  # Indices of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`.
  indexed_slice_indices = []
  current_index = 0
  # Note that this still contains Nones. We leave those in so that error
  # messages contain the correct indices. We handle the Nones later when
  # updating `current_index`.
  branch_outputs_flat_with_composites = [
      nest.flatten(branch_graph.structured_outputs, expand_composites=False)
      for branch_graph in branch_graphs
  ]
  outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
  assert len(set(outs_per_branch)) == 1, outs_per_branch
  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
  for output_idx, branch_outs in enumerate(
      zip(*branch_outputs_flat_with_composites)):
    # Either every branch returns an IndexedSlices here or none of them does.
    if len(
        set(
            isinstance(out, indexed_slices.IndexedSlices)
            for out in branch_outs)) != 1:
      raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
                      " branches returned: {outputs}".format(
                          op_name="cond" if op_type == _COND else "switch_case",
                          output_idx=output_idx,
                          outputs=branch_outs))
    if isinstance(branch_outs[0], indexed_slices.IndexedSlices):
      # indices is the second component of the composite tensor.
      indexed_slice_indices.append(current_index + 1)
    # Advance `current_index` by the number of flat (expanded) tensors this
    # position contributes to `FuncGraph.outputs`.
    if nest.is_nested_or_composite(branch_outs[0]):
      current_index += len(nest.flatten(branch_outs[0], expand_composites=True))
    elif branch_outs[0] is not None:
      # `FuncGraph.outputs` does not contain Nones so no need to update the
      # counter in that case.
      current_index += 1

  if not indexed_slice_indices:
    return

  # `FuncGraph.outputs` is the flattened `FuncGraph.structured_outputs` minus
  # the Nones.
  if current_index != len(branch_graphs[0].outputs):
    raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" %
                     (current_index, len(branch_graphs[0].outputs)))

  # Cast indices with mismatching types to int64.
  for index in indexed_slice_indices:
    if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
           for bg in branch_graphs):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" %
                      str([bg.outputs[index].dtype for bg in branch_graphs]))
    # Only cast when the branches actually disagree; uniform int32 is kept.
    if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
      for branch_graph in branch_graphs:
        if branch_graph.outputs[index].dtype == dtypes.int32:
          with branch_graph.as_default():
            branch_graph.outputs[index] = math_ops.cast(
                branch_graph.outputs[index], dtypes.int64)

  # Rebuild `structured_outputs` so it reflects any casts made to `outputs`.
  for branch_graph in branch_graphs:
    branch_graph.structured_outputs = _pack_sequence_as(
        branch_graph.structured_outputs, branch_graph.outputs)
def _pack_sequence_as(structured_outputs, op_outputs):
  """Re-inserts Nones into `op_outputs` and packs it like `structured_outputs`.

  Branch functions may have None entries in `structured_outputs`, but the op
  has no output for them. Walking the composite-expanded leaves of
  `structured_outputs` in order, each None leaf is kept as None and each
  non-None leaf consumes the next tensor from `op_outputs`. The result is then
  packed back into the original structure.

  Args:
    structured_outputs: structured_outputs from one of the branch functions.
    op_outputs: List of output tensors of the op.

  Returns:
    `op_outputs` packed like `structured_outputs`.
  """
  leaves = nest.flatten(structured_outputs, expand_composites=True)
  outputs_with_nones = []
  next_op_output = 0
  for leaf in leaves:
    if leaf is not None:
      outputs_with_nones.append(op_outputs[next_op_output])
      next_op_output += 1
    else:
      outputs_with_nones.append(None)
  return func_graph_module.pack_sequence_as(structured_outputs,
                                            outputs_with_nones)
def _wrap_intermediates(func_graph, intermediates):
  """Wraps each tensor in `intermediates` in an optional built in `func_graph`.

  Args:
    func_graph: FuncGraph in which the `optional_from_value` ops are created.
    intermediates: Tensors of `func_graph` to wrap.

  Returns:
    A list of optional-valued Tensors, one per intermediate, in the same order.
  """
  wrapped = []
  with func_graph.as_default():
    for tensor in intermediates:
      wrapped.append(gen_optional_ops.optional_from_value([tensor]))
  return wrapped
def _create_dummy_input(func_graph, template_tensor):
  """Makes a placeholder in `func_graph` shaped and typed like a template.

  Args:
    func_graph: FuncGraph in which the placeholder is created.
    template_tensor: a tensor in the outer graph whose dtype and shape the
      placeholder copies.

  Returns:
    A placeholder Tensor in func_graph.
  """
  dtype = template_tensor.dtype
  shape = template_tensor.shape
  with func_graph.as_default():
    return array_ops.placeholder(dtype, shape=shape)
def _create_none_optionals(func_graph, n):
  """Builds `n` empty (`optional_none`) optionals inside `func_graph`.

  Args:
    func_graph: FuncGraph in which the ops are created.
    n: `int` number of optionals to create; non-positive values yield none.

  Returns:
    A list of `n` optional Tensors in func_graph.
  """
  none_optionals = []
  with func_graph.as_default():
    for _ in range(n):
      none_optionals.append(gen_optional_ops.optional_none())
  return none_optionals
# TODO(b/265317139): remove this function and move this dynamic dimension
# handling logic to XLA once XLA shape is ready for dynamic dimensions.
def _convert_dynamic_dimension_to_zero(shape):
  """Replaces every unknown (None) dimension of `shape` with 0.

  FakeParams that stand in for intermediates of other branches may have
  dynamic dimensions, which XLA shapes cannot represent in a TF TensorShape.
  Using size 0 avoids failing safety checks in the bridge. When XLA's
  DynamicConditional op reconciles branch differences, XLA replaces the 0 with
  a bounded dimension taken from the real argument in the other branch.

  Shapes of unknown rank are returned unchanged.

  Args:
    shape: The TensorShape of fake param.

  Returns:
    A TensorShape with all dynamic dimensions set to zero (or `shape` itself
    when its rank is unknown).
  """
  if shape.rank is None:
    return shape
  static_dims = []
  for dim in shape.as_list():
    static_dims.append(dim if dim is not None else 0)
  return tensor_shape.TensorShape(static_dims)
def _create_fakeparams(func_graph, template_tensors):
  """Creates one FakeParam per template tensor for the XLA case.

  Each FakeParam copies its template's dtype. Its shape is the template's
  shape with dynamic dimensions replaced by 0.

  Args:
    func_graph: FuncGraph in which the FakeParam ops are created.
    template_tensors: Tensors whose dtype/shape each FakeParam mirrors.

  Returns:
    A list of FakeParam Tensors, in the same order as `template_tensors`.
  """
  fakeparams = []
  with func_graph.as_default():
    for template in template_tensors:
      fakeparams.append(
          gen_functional_ops.fake_param(
              dtype=template.dtype,
              shape=_convert_dynamic_dimension_to_zero(template.shape)))
  return fakeparams
def _check_same_outputs(op_type, graphs):
  """Verifies that every graph's outputs agree with those of `graphs[0]`.

  For each later graph, three things are compared with `graphs[0]`: the nested
  structure (composites expanded), the number of flat outputs, and the dtype
  of each flat output.

  Args:
    op_type: _COND or _CASE; selects the wording of error messages.
    graphs: `list` of `FuncGraph`.

  Raises:
    TypeError: If structures or output dtypes differ.
    ValueError: If the numbers of flat outputs differ.
  """
  is_cond = op_type == _COND
  op_type_str = "cond" if is_cond else "case"

  def _raise_mismatch(branch_idx, error_detail):
    raise TypeError(
        "{b0_name} and {bn_name} arguments to {op_name} must have the same "
        "number, type, and overall structure of return values.\n"
        "\n"
        "{b0_name} output: {b0_out}\n"
        "{bn_name} output: {bn_out}\n"
        "\n"
        "Error details:\n"
        "{detail}".format(
            b0_name="true_fn" if is_cond else "branches[0]",
            bn_name=("false_fn" if is_cond else
                     "branches[{}]".format(branch_idx)),
            op_name="tf.cond" if is_cond else "tf.switch_case",
            b0_out=graphs[0].structured_outputs,
            bn_out=graphs[branch_idx].structured_outputs,
            detail=error_detail))

  for b in range(1, len(graphs)):
    try:
      nest.assert_same_structure(
          graphs[0].structured_outputs,
          graphs[b].structured_outputs,
          expand_composites=True)
    except (ValueError, TypeError) as e:
      _raise_mismatch(b, str(e))

    if len(graphs[0].outputs) != len(graphs[b].outputs):
      raise ValueError("Lengths of branch outputs of {op_type} must match.\n"
                       "len(graphs[0].outputs): {len_0}\n"
                       "len(graphs[{b}].outputs): {len_b}\n".format(
                           op_type=op_type_str,
                           len_0=len(graphs[0].outputs),
                           b=b,
                           len_b=len(graphs[b].outputs)))

    for b0_out, bn_out in zip(graphs[0].outputs, graphs[b].outputs):
      if b0_out.dtype != bn_out.dtype:
        _raise_mismatch(b, "%s and %s have different types" % (b0_out, bn_out))
def _get_output_shapes(*branch_graph_outputs):
  """Computes one merged static shape per output position across branches.

  Args:
    *branch_graph_outputs: One list of output Tensors per branch, all of the
      same length.

  Returns:
    A list with one TensorShape per position. Each is obtained by folding
    `most_specific_compatible_shape` over that position's shapes, starting
    from the first branch.
  """

  def _merge_position(per_branch):
    merged = per_branch[0].shape
    for other in per_branch[1:]:
      merged = merged.most_specific_compatible_shape(other.shape)
    return merged

  return [
      _merge_position(per_branch)
      for per_branch in zip(*branch_graph_outputs)
  ]
def _copy_handle_data(external_tensors, *branch_graph_outputs):
  """Combines shapes in handle data and sets metadata on `external_tensors`.

  For each output position, gathers the resource handle data of that output
  from every branch graph. A combined entry is written to the matching
  external tensor only when every branch has handle data set with exactly one
  ShapeAndType entry:

    * shape: the first branch's shape folded with
      `most_specific_compatible_shape` against every other branch's shape;
    * dtype: the branches' common dtype, or DT_INVALID if they disagree.

  Otherwise the external tensor's handle data is left untouched.

  Args:
    external_tensors: output tensors of the If/Case op.
    *branch_graph_outputs: one sequence of output tensors per branch graph,
      positionally aligned with `external_tensors`.
  """
  for tensors in zip(external_tensors, *branch_graph_outputs):
    external = tensors[0]
    internal = tensors[1:]
    internal_handle_data = []
    for tensor in internal:
      handle_data = handle_data_util.get_resource_handle_data(tensor)
      # NOTE: Assumes handle data has only one ShapeAndType entry. It's
      # unclear how to combine different lengths across branches.
      if not handle_data.is_set or len(handle_data.shape_and_type) != 1:
        break
      internal_handle_data.append(handle_data)
    else:  # There is handle data, so we need to combine it.
      if not internal_handle_data:
        # No branch outputs at this position; nothing to combine.
        continue
      # Seed the fold with the first branch's shape. Seeding with
      # `TensorShape(None)` is wrong: `most_specific_compatible_shape`
      # returns an unknown shape whenever either operand has unknown rank, so
      # every combined shape would collapse to fully unknown.
      combined_shape = None
      combined_dtype = None
      for handle_data in internal_handle_data:
        handle_shape = tensor_shape.TensorShape(
            handle_data.shape_and_type[0].shape)
        if combined_shape is None:
          combined_shape = handle_shape
        else:
          combined_shape = combined_shape.most_specific_compatible_shape(
              handle_shape)
        if combined_dtype is None:
          combined_dtype = handle_data.shape_and_type[0].dtype
        elif handle_data.shape_and_type[0].dtype != combined_dtype:
          # Variants from different branches have different dtypes. The
          # combined variant has no static dtype.
          combined_dtype = types_pb2.DT_INVALID
      # Reuse the first branch's proto as the container for the merged entry.
      combined_handle_data = internal_handle_data[0]
      combined_handle_data.shape_and_type[0].shape.CopyFrom(
          combined_shape.as_proto())
      combined_handle_data.shape_and_type[0].dtype = combined_dtype
      handle_data_util.set_handle_data(external, combined_handle_data)
def verify_captures(op_type, branch_graphs):
  """Verify that a branch's tensor is not accessed in another branch fn.

  Note: It is technically not possible for lower-branch_index branches to
  capture tensors from higher-branch_index branches, because of the order of
  branch graph construction, but every branch is checked for completeness and
  to guard against potential future changes.

  Args:
    op_type: `_COND` or `_CASE`; selects how branches are named in the error.
    branch_graphs: list of branch FuncGraphs.

  Raises:
    ValueError: if a branch's external captures include a non-eager tensor
      whose graph is one of `branch_graphs`.
  """
  index_of_graph = {graph: idx for idx, graph in enumerate(branch_graphs)}

  def _branch_label(idx):
    if op_type == _COND:
      return ["true_fn", "false_fn"][idx]
    return "branch {}".format(idx)

  for consumer_idx, consumer_graph in enumerate(branch_graphs):
    for captured in consumer_graph.external_captures:
      # Eager tensors have no graph to compare against.
      if isinstance(captured, ops.EagerTensor):
        continue
      if captured.graph not in index_of_graph:
        continue
      raise ValueError(
          "Tensor {tname} in {b0name} is accessed from {b1name}.".format(
              tname=captured.name,
              b0name=_branch_label(index_of_graph[captured.graph]),
              b1name=_branch_label(consumer_idx)))
class _CondGradFuncGraph(util.CondBranchFuncGraph):
  """FuncGraph for the gradient function of the branch of an If op.

  Handles wrapping and unwrapping intermediate values that are captured by the
  gradient computation in optionals.

  Attributes:
    op_needs_rewrite: True if any intermediates were captured, meaning the
      forward If op needs to be rewritten to output the wrapped intermediates.
  """

  def __init__(self, name, forward_graph):
    super(_CondGradFuncGraph, self).__init__(
        name, collections=ops.get_default_graph()._collections)  # pylint: disable=protected-access
    self.op_needs_rewrite = False
    # The forward branch graph whose intermediates this graph may capture.
    self._forward_graph = forward_graph
    # Maps forward intermediate tensor id -> the unwrapped captured
    # intermediate (or, for resources, the captured forward input).
    self._indirect_captures = {}
    # Maps forward intermediate tensor id -> optional-wrapped intermediate in
    # the forward graph. Ordered so `wrapped_intermediates` is deterministic.
    self._wrapped_intermediates = collections.OrderedDict()
    # Raw intermediates captured from the forward graph. Populated iff we're in
    # an XLA context.
    self._xla_intermediates = []
    # Maps forward intermediate constant valued tensor's id to the constant
    # created in this graph for that tensor.
    self._captured_constants = {}

  @property
  def wrapped_intermediates(self):
    """The optional-wrapped intermediates captured from the forward graph."""
    return list(self._wrapped_intermediates.values())

  @property
  def xla_intermediates(self):
    """Raw intermediates captured from the forward graph if XLA is enabled."""
    return self._xla_intermediates

  def _capture_helper(self, tensor, name):
    """Captures `tensor` into this gradient graph.

    Tensors outside the forward graph, and forward-graph inputs/outputs, are
    captured normally by the parent class. For other forward intermediates:

      * graph-building-time constants are recreated as constants here;
      * in an XLA context the raw tensor is captured and recorded in
        `xla_intermediates`;
      * resource tensors are traced back to the corresponding forward-graph
        input, which is captured instead;
      * any other tensor is wrapped in an optional in the forward graph
        (reusing an existing OptionalFromValue that is already a forward
        output, if any), the optional is captured, and its value is unwrapped
        in this graph.

    Args:
      tensor: the tensor to capture.
      name: name forwarded to the parent class's `_capture_helper`.

    Returns:
      The tensor to use inside this graph in place of `tensor`.
    """
    if (tensor.graph is not self._forward_graph or
        any(tensor is t for t in self._forward_graph.inputs) or
        any(tensor is t for t in self._forward_graph.outputs)):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    tensor_id = ops.tensor_id(tensor)

    # If `tensor` is a graph-building time constant, we create a constant with
    # the same value in the backward graph instead of capturing it.
    if tensor_id in self._captured_constants:
      return self._captured_constants[tensor_id]
    elif constant_op.is_constant(tensor):
      self._captured_constants[tensor_id] = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      return self._captured_constants[tensor_id]

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      # Record each raw intermediate only once (first time it is captured).
      if all(tensor is not capture for capture in self.external_captures):
        self.xla_intermediates.append(tensor)
        self.op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    # Reuse a previous indirect capture of the same forward tensor.
    captured_tensor = self._indirect_captures.get(tensor_id)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the forward
    # graph and capture that.
    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor_id not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              any(consumer.outputs[0] is output
                  for output in self._forward_graph.outputs)):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_optional_ops.optional_from_value([tensor])
          # The forward op must be rewritten to output the new optional.
          self.op_needs_rewrite = True
        self._wrapped_intermediates[tensor_id] = optional

      optional = self._wrapped_intermediates[tensor_id]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      # Unwrap the captured optional back to the raw intermediate value.
      captured_tensor = gen_optional_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape]
      )[0]

    self._indirect_captures[tensor_id] = captured_tensor
    return captured_tensor
def indexed_case(branch_index,
                 branch_fns,
                 name="indexed_case",
                 lower_using_switch_merge=None):
  """Like cond_v2, except emits a Case op instead of an If.

  Each callable in `branch_fns` is traced into its own branch FuncGraph. A
  single Case op that selects among them by `branch_index` is then built.

  Args:
    branch_index: integer Tensor (or a value accepted by
      `ops.convert_to_tensor`) selecting the branch to run. Python ints are
      rejected.
    branch_fns: sequence of callables taking no arguments, one per branch.
    name: name scope for the created ops.
    lower_using_switch_merge: whether to lower the Case op using switch/merge
      ops; forwarded to `_build_case` (optional).

  Returns:
    The outputs of the Case op as returned by `_build_case`, packed into the
    structure of the first branch's outputs.

  Raises:
    TypeError: if `branch_index` is a Python int.
    ValueError: if a branch accesses a tensor created in another branch.
  """
  if isinstance(branch_index, int):
    raise TypeError("branch_index must not be a Python int", branch_index)

  with ops.name_scope(name) as scope:
    branch_names = [
        util.unique_fn_name(scope, "branch{}".format(b))
        for b in range(len(branch_fns))
    ]

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies  # pylint: disable=protected-access
    branch_index = ops.convert_to_tensor(branch_index, name="branch_index")

    branch_graphs = []
    for branch_name, branch_fn in zip(branch_names, branch_fns):
      branch_graphs.append(
          func_graph_module.func_graph_from_py_func(
              branch_name,
              branch_fn,
              [],
              {},
              func_graph=util.CondBranchFuncGraph(
                  branch_name,
                  collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
              add_control_dependencies=add_control_dependencies,
              op_return_value=branch_index))

    verify_captures(_CASE, branch_graphs)
    return _build_case(
        branch_index,
        branch_graphs, [g.external_captures for g in branch_graphs],
        name=scope,
        lower_using_switch_merge=lower_using_switch_merge)
@ops.RegisterGradient("Case")
@ops.RegisterGradient("StatelessCase")
def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a Case op produced by tf.switch_case.

  Builds one gradient FuncGraph per forward branch. If any of them captured
  forward intermediates, the forward Case op is rewritten so its branches also
  output those intermediates. A new Case op named "gradient" is then emitted.
  It selects among the gradient graphs using the forward op's branch index.

  Args:
    op: the forward Case/StatelessCase op (possibly a MockOp).
    *grads: the incoming gradients with respect to the outputs of `op`.

  Returns:
    A list whose first entry is `None` (for the branch index input), followed
    by the outputs of the gradient Case op.
  """
  # Get the Case operator (this logic handles the case where op is a MockOp)
  case_op = op.outputs[0].op
  branch_graphs = get_func_graphs(case_op)
  assert branch_graphs
  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
  # of a nested cond.
  for branch_graph in branch_graphs:
    assert branch_graph.outer_graph == case_op.graph

  # Create grad functions that compute the gradient of the branch forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  branch_grad_graphs = []
  for branch_graph in branch_graphs:
    branch_grad_graphs.append(
        _create_grad_func(branch_graph, grads,
                          util.unique_grad_fn_name(branch_graph.name)))
  # Replaces output None grads with zeros if at least one branch has non-None
  # grad at that index.
  _create_zeros_for_none_grads(branch_graphs, branch_grad_graphs)

  if any(g.op_needs_rewrite for g in branch_grad_graphs):
    # Modify 'op' to output the intermediates needed by the grad functions. Note
    # that all needed intermediates are wrapped in optionals. Each optional
    # intermediate output will have a value iff its corresponding branch is
    # taken.
    # NOTE(bjp): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so output intermediates directly and
      # make them match via FakeParams, which can be converted to zeros in XLA.
      # TODO(bjp,jpienaar): can XLA support optionals?
      branches_intermediates = [
          branch_grad_graph.xla_intermediates
          for branch_grad_graph in branch_grad_graphs
      ]
      extra_branch_outputs = _make_intermediates_match_xla(
          branch_graphs, branches_intermediates)
    else:
      branch_intermediates = [
          g.wrapped_intermediates for g in branch_grad_graphs
      ]
      # Make outputs match by adding none optionals.
      extra_branch_outputs = _make_intermediates_match(branch_graphs,
                                                       branch_intermediates)

    # Append the extra (intermediate) outputs to every forward branch graph.
    for branch_graph, extra_outputs in zip(branch_graphs, extra_branch_outputs):
      branch_graph.outputs.extend(extra_outputs)
    # TODO(bjp): indicate it's an internal bug if this fails.
    _check_same_outputs(_CASE, branch_graphs)

    for branch_graph in branch_graphs:
      branch_graph.name += "_rewritten"

    # Point the existing forward Case op at the rewritten branch functions and
    # register the additional outputs on it.
    case_op._set_func_list_attr("branches", [
        util.create_new_tf_function(branch_graph)
        for branch_graph in branch_graphs
    ])
    case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
    case_op._set_shape_list_attr("output_shapes",
                                 branch_graphs[0].output_shapes)
    case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
                         [t.shape for t in extra_branch_outputs[0]])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  branches_grad_inputs = [
      _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph,
      branch_grad_graph in zip(branch_graphs, branch_grad_graphs)
  ]

  # This modifies the graphs in branch_grad_graphs.
  _make_output_composite_tensors_match(_CASE, branch_grad_graphs)

  # Lower the gradient Case the same way as the forward op, if the attr is set.
  try:
    lowering = case_op._get_attr_bool("_lower_using_switch_merge")
  except errors_impl.NotFoundError:
    lowering = None

  outputs = _build_case(
      case_op.inputs[0],
      branch_grad_graphs,
      branches_grad_inputs,
      name="gradient",
      lower_using_switch_merge=lowering)

  # The branch index (input 0 of the Case op) has no gradient.
  return [None] + outputs
def _build_case(branch_index,
                branch_graphs,
                branch_inputs,
                name=None,
                lower_using_switch_merge=None):
  """Creates a `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make their inputs match (see
  `_make_inputs_match`). Intermediate outputs needed by the gradient are added
  to the branch graphs separately, when the gradient is built (`_CaseGrad`).

  `branch_graphs` need not have the same input types, but they must
  have the same output types.

  Args:
    branch_index: integer Tensor
    branch_graphs: List of FuncGraph
    branch_inputs: List of lists of Tensors to be passed to corresponding
      branch_graph as input.
    name: the name for the Case op.
    lower_using_switch_merge: Lower this op using switch merge ops (optional).

  Returns:
    A list of Tensors which are the outputs of the Case op. Does not include
    added intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
  _check_same_outputs(_CASE, branch_graphs)

  # Add inputs to branch_graphs to make them match. Note that this modifies the
  # graphs in `branch_graphs`.
  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

  # Emit the stateless variant only when no branch contains a stateful op.
  # `any` stops at the first stateful op instead of collecting all of them.
  has_stateful_ops = any(
      auto_control_deps.op_is_stateful(op)
      for bg in branch_graphs
      for op in bg.get_operations())
  if has_stateful_ops:
    op_fn = gen_functional_ops.case
  else:
    op_fn = gen_functional_ops.stateless_case

  # Create the Case op.
  with ops.control_dependencies(
      sum((list(bg.function_captures.control) for bg in branch_graphs), [])):

    def _make_op(inputs):
      case_op, tensors = util.get_op_and_outputs(op_fn(
          branch_index,
          inputs, [t.dtype for t in branch_graphs[0].outputs],
          [util.create_new_tf_function(g) for g in branch_graphs],
          output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
          name=name))
      _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
      if case_op is not None:
        util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
        util.maybe_propagate_compile_time_consts_in_xla(case_op)
        _set_read_only_resource_inputs_attr(case_op, branch_graphs)
        # Prevent fetching since the variant outputs can't be fetched directly.
        case_op.graph.prevent_fetching(case_op)

        # Store the branch graphs so they can be reused during the gradient
        # pass.
        for i, bg in enumerate(branch_graphs):
          bg.outer_graph = ops.get_default_graph()
          setattr(case_op, "_branch_graph_{}".format(i), bg)

      return tensors
    tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs)

  # Return identities for each output of the Case op, rather than the output of
  # the Case op directly. This makes pruning work if the output of switch_case()
  # is fetched: the lowering pass converts the Case outputs into IdentityN
  # outputs, which if fetched will cause all ops in the taken branch to be run
  # (since it takes all merge ops as input). After lowering, each output
  # identity op will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)
def _set_read_only_resource_inputs_attr(op, branch_graphs):
  """Sets the list of resource inputs which are read-only.

  An input of `op` is recorded as read-only only if every branch graph reports
  the corresponding branch input as read-only via
  `acd.get_read_only_resource_input_indices_graph`. The result is stored on
  `op` under `acd.READ_ONLY_RESOURCE_INPUTS_ATTR` and is used by
  AutomaticControlDependencies.

  Args:
    op: If or Case Operation.
    branch_graphs: List of branch FuncGraphs.
  """
  # op.inputs[0] is the predicate / branch index, which is not passed to the
  # branch graphs, so branch input k corresponds to op input k + 1.
  num_branch_inputs = len(op.inputs) - 1
  still_read_only = set(range(num_branch_inputs))
  for graph in branch_graphs:
    assert len(graph.inputs) == num_branch_inputs, "should never happen"
    if not still_read_only:
      # Nothing left that could stay read-only.
      break
    still_read_only &= set(
        acd.get_read_only_resource_input_indices_graph(graph))
  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        sorted(idx + 1 for idx in still_read_only))