blob: 80fcc7e51910caee46008f0c143c797c7249dc38 [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Control Flow Operations.
See the [autograph](https://www.tensorflow.org/guide/autograph) guide.
"""
# pylint: disable=g-bad-name
import abc
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_control_flow_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import compat
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export
# We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module.
_basetuple = tuple
# pylint: disable=protected-access
def _Identity(tensor, name=None):
"""Return a tensor with the same shape and contents as the input tensor.
Args:
tensor: A Tensor.
name: A name for this operation (optional).
Returns:
A Tensor with the same type and value as the input Tensor.
"""
tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
# TODO(b/246438937): Remove this when we expand ResourceVariables into
# dt_resource tensors.
tensor = variable_utils.convert_variables_to_tensors(tensor)
if isinstance(tensor, ops.Tensor):
if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access
return gen_array_ops.ref_identity(tensor, name=name)
else:
return array_ops.identity(tensor, name=name)
elif isinstance(tensor, composite_tensor.CompositeTensor):
return nest.map_structure(_Identity, tensor, expand_composites=True)
else:
raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
f"Received: {type(tensor)}.")
def _NextIteration(tensor, name=None):
tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
if isinstance(tensor, ops.Tensor):
if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access
return ref_next_iteration(tensor, name=name)
else:
return next_iteration(tensor, name=name)
elif isinstance(tensor, composite_tensor.CompositeTensor):
return nest.map_structure(_NextIteration, tensor, expand_composites=True)
else:
raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
f"Received: {type(tensor)}.")
def _Enter(tensor,
frame_name,
is_constant=False,
parallel_iterations=10,
use_ref=True,
use_input_shape=True,
name=None):
"""Creates or finds a child frame, and makes `tensor` available to it.
The unique `frame_name` is used by the `Executor` to identify frames. If
`is_constant` is true, `tensor` is a constant in the child frame; otherwise
it may be changed in the child frame. At most `parallel_iterations`
iterations are run in parallel in the child frame.
Args:
tensor: The tensor to be made available to the child frame.
frame_name: The name of the child frame.
is_constant: If true, the output is constant within the child frame.
parallel_iterations: The number of iterations allowed to run in parallel.
use_ref: If true, use ref_enter if tensor is of ref type.
use_input_shape: If true, set the result's shape based on tensor's shape.
name: A name for this operation (optional).
Returns:
The same tensor as `tensor`.
Raises:
ValueError: If any tensor in `tensor` has a less specific shape
than its corresponding shape in `shape_invariant`.
"""
tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
if isinstance(tensor, ops.Tensor):
if tensor.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access
result = gen_control_flow_ops.ref_enter(
tensor, frame_name, is_constant, parallel_iterations, name=name)
else:
result = gen_control_flow_ops.enter(
tensor, frame_name, is_constant, parallel_iterations, name=name)
if use_input_shape:
result.set_shape(tensor.get_shape())
return result
elif isinstance(tensor, composite_tensor.CompositeTensor):
def enter_component(t):
return _Enter(t, frame_name, is_constant, parallel_iterations, use_ref,
use_input_shape)
return nest.map_structure(enter_component, tensor, expand_composites=True)
else:
raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
f"Received: {type(tensor)}.")
def exit(tensor, name=None): # pylint: disable=redefined-builtin
"""Exits the current frame to its parent frame.
Exit makes its input `tensor` available to the parent frame.
Args:
tensor: The tensor to be made available to the parent frame.
name: A name for this operation (optional).
Returns:
The same tensor as `tensor`.
"""
tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
if isinstance(tensor, ops.Tensor):
if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access
return gen_control_flow_ops.ref_exit(tensor, name)
else:
return gen_control_flow_ops._exit(tensor, name)
elif isinstance(tensor, composite_tensor.CompositeTensor):
return nest.map_structure(exit, tensor, expand_composites=True)
else:
raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
f"Received: {type(tensor)}.")
def switch(data, pred, dtype=None, name=None):
"""Forwards `data` to an output determined by `pred`.
If `pred` is false, the `data` input is forwarded to the first output.
Otherwise, the data goes to the second output.
This op handles `Tensor`s and `IndexedSlices`.
Args:
data: The tensor to be forwarded to the appropriate output.
pred: A scalar that specifies which output port will receive data.
dtype: Optional element type for the returned tensor. If missing, the type
is inferred from the type of `value`.
name: A name for this operation (optional).
Returns:
`(output_false, output_true)`: If `pred` is true, data will be forwarded
to `output_true`, otherwise it goes to `output_false`.
"""
with ops.name_scope(name, "Switch", [data, pred]) as name:
data = ops.internal_convert_to_tensor_or_composite(
data, dtype=dtype, name="data", as_ref=True)
pred = ops.convert_to_tensor(pred, name="pred")
if isinstance(data, ops.Tensor):
return gen_control_flow_ops.switch(data, pred, name=name)
else:
if not isinstance(data, composite_tensor.CompositeTensor):
raise TypeError(
"'data' must be a Tensor or CompositeTensor. "
f"Received: {type(data)}.")
tensors = nest.flatten(data, expand_composites=True)
mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors]
mapped_f, mapped_t = zip(*mapped)
return (nest.pack_sequence_as(data, mapped_f, expand_composites=True),
nest.pack_sequence_as(data, mapped_t, expand_composites=True))
def _SwitchRefOrTensor(data, pred, name="Switch"):
"""Forwards `data` to an output determined by `pred`.
If `pred` is false, the `data` input is forwarded to the first output.
Otherwise, the data goes to the second output.
This op handles `Tensor`s and `IndexedSlices`.
Args:
data: The tensor to be forwarded to the appropriate output.
pred: A scalar that specifies which output port will receive data.
name: A name for this operation (optional).
Returns:
`(output_false, output_true)`: If `pred` is true, data will be forwarded to
`output_true`, otherwise it goes to `output_false`.
Raises:
TypeError: if data is not a Tensor or IndexedSlices
"""
data = ops.convert_to_tensor_or_composite(data, name="data")
# NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
# addresses the following scenario.
#
# Assume you execute Optimizer.apply_gradients() in a branch of a cond().
#
# 1. The update op is created inside a `with ops.colocate(var):` block
#
# 2. Some tensor `data` is captured and a switch is created in a
# `with ops.colocate_with(data):` block.
#
# with ops.colocate_with(var):
# with ops.colocate_with(data):
# op = ...
#
# var and data may be pinned to different devices, so we want to ops
# created within ops.colocate_with(data) to ignore the existing stack.
with ops.colocate_with(data, ignore_existing=True):
if isinstance(data, ops.Tensor):
if data.dtype._is_ref_dtype: # pylint: disable=protected-access
return ref_switch(data, pred, name=name)
return switch(data, pred, name=name)
def merge(inputs, name=None):
"""Returns the value of an available element of `inputs`.
This op tests each of the tensors in `inputs` in turn to determine if any of
them is available. If it finds an available tensor, it returns it and its
index in `inputs`.
It is an error if more than one tensor in `inputs` is available. If no tensor
in `inputs` is available, the returned tensor and index are not set.
This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
`Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
before merging.
Args:
inputs: The input tensors, at most one of which is available.
name: A name for this operation (optional).
Returns:
A tuple containing the chosen input tensor and its index in `inputs`.
Raises:
ValueError: If any of the inputs is None, or inputs are IndexedSlices and
some but not all have a dense_shape property.
"""
if any(inp is None for inp in inputs):
raise ValueError("At least one of the merge inputs is None: %s" % inputs)
with ops.name_scope(name, "Merge", inputs) as name:
inputs = [
ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
for inp in inputs
]
if all(isinstance(v, ops.Tensor) for v in inputs):
if all(v.dtype._is_ref_dtype for v in inputs): # pylint: disable=protected-access
return gen_control_flow_ops.ref_merge(inputs, name)
else:
return gen_control_flow_ops.merge(inputs, name)
else:
# If there is a mix of tensors and indexed slices, then convert the
# tensors to indexed slices.
if all(
isinstance(v, (indexed_slices.IndexedSlices, ops.Tensor))
for v in inputs):
inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)
for v in inputs:
if not isinstance(v, composite_tensor.CompositeTensor):
raise TypeError("Type %s not supported" % type(v))
for v in inputs[1:]:
nest.assert_same_structure(inputs[0], v, expand_composites=True)
flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
merged_results = [
gen_control_flow_ops.merge(component)
for component in zip(*flat_inputs)
]
flat_merged = [tensor for (tensor, _) in merged_results]
chosen_index = merged_results[0][1]
merged_inputs = nest.pack_sequence_as(
inputs[0], flat_merged, expand_composites=True)
return (merged_inputs, chosen_index)
def _convert_tensorarray_to_flow(tensor_or_tensor_array):
if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
return tensor_or_tensor_array.flow
else:
return tensor_or_tensor_array
def _convert_flow_to_tensorarray(tensor_or_tensor_array, tensor_or_flow):
if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
return tensor_array_ops.build_ta_with_new_flow(tensor_or_tensor_array,
tensor_or_flow)
else:
return tensor_or_flow
def _convert_to_tensor_or_composite_or_tensorarray(var):
if isinstance(var, tensor_array_ops.TensorArray):
return var
return ops.convert_to_tensor_or_composite(var)
# TODO(xjun): replace this with is_subtype_of after it is landed.
def _ShapeLessThanOrEqual(shape1, shape2):
if shape2.dims is None:
return True
if shape1.ndims != shape2.ndims:
return False
for dim1, dim2 in zip(shape1.dims, shape2.dims):
if dim2.value is not None and dim1.value != dim2.value:
return False
return True
def _shape_invariant_to_type_spec(var, shape=None):
"""Converts a shape invariant to a TypeSpec.
If `var` is a TensorArray, it will first be converted to its flow.
Args:
var: The tensor, tensor array or composite tensor whose shape is described
by the shape invariant.
shape: A `TypeSpec` or `TensorShape`. If `shape` is already a `TypeSpec`,
then it is simply returned as-is.
Returns:
A `TypeSpec` for `var`, consistent with the given shape.
Raises:
TypeError: If `shape` is a TypeSpec and not compatible with `var`.
TypeError: If `shape` is not None, a TypeSpec, or a TensorShape.
TypeError: If `shape` is a TensorShape, `var` is a CompositeTensor, and
`var` doesn't implement the `_shape_invariant_to_type_spec` method.
"""
var = _convert_tensorarray_to_flow(var)
if shape is None:
return type_spec.type_spec_from_value(var)
elif isinstance(shape, type_spec.TypeSpec):
if not shape.is_compatible_with(var):
raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
return shape
elif not isinstance(shape, tensor_shape.TensorShape):
raise TypeError(
"'shape' must be one of TypeSpec, TensorShape or None. "
f"Received: {type(shape)}")
if isinstance(var, ops.Tensor):
return tensor_spec.TensorSpec(shape, var.dtype)
else:
try:
return var._shape_invariant_to_type_spec(shape) # pylint: disable=protected-access
except NotImplementedError as e:
raise TypeError(
f"To describe or constrain a {type(var).__name__}, use a "
f"{type(var._type_spec).__name__} instead of a TensorShape.") from e # pylint: disable=protected-access
def _EnforceShapeInvariant(merge_var, next_var):
"""Check if the shapes of the loops variables are invariants.
Args:
merge_var: The tensor representing the initial values of the loop
variables.
next_var: The tensor representing the values of the loop variables
after one loop iteration.
Raises:
ValueError: If any tensor in `merge_var` has a more specific shape than
its corresponding tensor in `next_var`.
"""
if isinstance(merge_var, ops.Tensor):
m_shape = merge_var.get_shape()
n_shape = next_var.get_shape()
if not _ShapeLessThanOrEqual(n_shape, m_shape):
enter = merge_var.op.inputs[0].op
assert util.IsLoopEnter(enter)
input_t = enter.inputs[0]
raise ValueError(
"Input tensor '%s' enters the loop with shape %s, but has shape %s "
"after one iteration. To allow the shape to vary across iterations, "
"use the `shape_invariants` argument of tf.while_loop to specify a "
"less-specific shape." % (input_t.name, input_t.shape, n_shape))
else:
raise TypeError("'merge_var' must be a Tensor. "
f"Received: {type(merge_var)}.")
def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
"""Add NextIteration and back edge from v to m."""
if isinstance(m, ops.Tensor):
v = ops.convert_to_tensor(v)
v = _NextIteration(v)
if enforce_shape_invariant:
# Make sure the shapes of loop outputs are correct. We do this before
# calling _update_input, which will raise a less-helpful error message if
# the types don't match.
# TODO(skyewm): call this for other cases below (needs testing)
_EnforceShapeInvariant(m, v)
m.op._update_input(1, v) # pylint: disable=protected-access
elif isinstance(m, composite_tensor.CompositeTensor):
# pylint: disable=protected-access
def update_component(m_component, v_component):
m_component.op._update_input(1, v_component)
if isinstance(m, indexed_slices.IndexedSlices):
v = math_ops._as_indexed_slices(v, optimize=False)
# pylint: enable=protected-access
v = _NextIteration(v)
return nest.map_structure(update_component, m, v, expand_composites=True)
else:
raise TypeError("'m' must be a Tensor or CompositeTensor. "
f"Received: {type(m)}.")
return v
class ControlFlowContext(metaclass=abc.ABCMeta):
"""The base class for control flow context.
The usage pattern is a sequence of (Enter, Exit) followed by a final
ExitResult.
We maintain the following state for control flow contexts during graph
construction:
1. graph has _control_flow_context: the current context used to
construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
2. op has _control_flow_context: the context to which the op belongs.
Set at the time the op is created. Immutable.
3. A ControlFlowContext has _outer_context: the context in which this
context is created. Set at the time a context is created. Immutable.
4. A ControlFlowContext has _context_stack.
Pushed and popped by ctxt.Enter() and ctxt.Exit()
"""
def __init__(self, values_def=None, import_scope=None):
self._nested_contexts = []
self._outer_context = ops.get_default_graph()._get_control_flow_context()
if self._outer_context:
self._outer_context._nested_contexts.append(self) # pylint: disable=protected-access
self._context_stack = []
if values_def:
self._init_values_from_proto(values_def, import_scope=import_scope)
else:
# The names of tensors that have been already seen in this context.
self._values = set()
# The keys are the names of tensors referenced by but external to this
# context. Each value is the Tensor that should be used by this context to
# access the key value (e.g. a switch output guarding a cond input value).
self._external_values = {}
def _init_values_from_proto(self, values_def, import_scope=None):
"""Initializes values and external_values from `ValuesDef` protocol buffer.
Args:
values_def: `ValuesDef` protocol buffer.
import_scope: Optional `string`. Name scope to add.
"""
assert isinstance(values_def, control_flow_pb2.ValuesDef)
self._values = set(
ops.prepend_name_scope(value, import_scope)
for value in values_def.values)
g = ops.get_default_graph()
self._external_values = {}
for k, v in values_def.external_values.items():
k = ops.prepend_name_scope(k, import_scope)
self._external_values[k] = g.as_graph_element(
ops.prepend_name_scope(v, import_scope))
op_names = set([
op.split(":")[0]
for op in self._values - set(self._external_values.keys())
])
for op in op_names:
# pylint: disable=protected-access
g.as_graph_element(op)._set_control_flow_context(self)
# pylint: enable=protected-access
@property
def name(self):
return self._name
@property
def outer_context(self):
"""Return the context containing this context."""
return self._outer_context
@property
def grad_state(self):
raise NotImplementedError("Abstract method")
@property
def back_prop(self):
raise NotImplementedError("Abstract method")
@abc.abstractmethod
def to_control_flow_context_def(self, context_def, export_scope=None):
"""Serializes this into `context_def`.
Args:
context_def: a `ControlFlowContextDef` protocol buffer.
export_scope: Optional `string`. Name scope to remove.
"""
raise NotImplementedError("Abstract method")
def _to_values_def(self, export_scope=None):
"""Converts the values to a `ValuesDef` protocol buffer.
Args:
export_scope: Optional `string`. Name scope to remove.
Returns:
A `ValuesDef` protocol buffer.
"""
values_def = control_flow_pb2.ValuesDef()
values_def.values.extend(
[ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
for k, v in self._external_values.items():
k = ops.strip_name_scope(k, export_scope)
values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
return values_def
def AddName(self, name):
self._values.add(name)
# pylint: disable=protected-access
def Enter(self):
"""Enter this control flow context."""
graph = ops.get_default_graph()
self._context_stack.append(graph._get_control_flow_context())
graph._set_control_flow_context(self)
def Exit(self):
"""Exit this control flow context."""
graph = ops.get_default_graph()
last_context = self._context_stack.pop()
graph._set_control_flow_context(last_context)
def EnterGradientColocation(self, op, gradient_uid):
"""Start building a gradient colocated with an op."""
if self._outer_context:
self._outer_context.EnterGradientColocation(op, gradient_uid)
def ExitGradientColocation(self, op, gradient_uid):
"""Start building a gradient colocated with an op."""
if self._outer_context:
self._outer_context.ExitGradientColocation(op, gradient_uid)
def ExitResult(self, result):
"""Make a list of tensors available in the outer context."""
if self._outer_context:
def fn(x):
self._outer_context.AddName(x.name)
return x
nest.map_structure(fn, result, expand_composites=True)
def GetWhileContext(self):
"""Return the while context containing this context."""
if self._outer_context:
return self._outer_context.GetWhileContext()
return None
def _RemoveExternalControlEdges(self, op):
"""Remove any external control dependency on this op."""
while_ctxt = self.GetWhileContext()
# A control input of `op` is internal if it is in the same while
# loop context as the enclosing while loop context of self.
if while_ctxt is None:
internal_control_inputs, external_control_inputs = op.control_inputs, []
else:
internal_control_inputs, external_control_inputs = [], []
for x in op.control_inputs:
ctxt = util.GetOutputContext(x)
if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
internal_control_inputs.append(x)
else:
external_control_inputs.append(x)
if len(internal_control_inputs) != len(op.control_inputs):
# TODO(mdan): perhaps there should be a replace_control_inputs()
op._remove_all_control_inputs()
op._add_control_inputs(internal_control_inputs)
return internal_control_inputs, external_control_inputs
# pylint: enable=protected-access
def AddInnerOp(self, op):
"""Notifies a scope about an operator added to an inner scope."""
if self._outer_context:
self._outer_context.AddInnerOp(op)
def GetControlPivot(self):
"""Returns the pivot node for this context, or None."""
return None
def IsWhileContext(self):
return False
def IsCondContext(self):
return False
def IsXLAContext(self):
return False
def __str__(self):
return self.name
class CondContext(ControlFlowContext):
"""The context for the conditional construct."""
def __init__(self,
pred=None,
pivot=None,
branch=None,
name="cond_text",
context_def=None,
import_scope=None):
"""Creates a `CondContext`.
Args:
pred: The `boolean` tensor for the conditional predicate.
pivot: The predicate tensor in this branch.
branch: 0 or 1 representing this branch.
name: Name of the `CondContext` python object.
context_def: Optional `ContextDef` protocol buffer to initialize the
`CondContext` object from.
import_scope: Optional `string`. Name scope to add. Only used when
initialing from protocol buffer.
"""
self._name = ops.get_default_graph().unique_name(name)
if context_def:
self._init_from_proto(context_def, import_scope=import_scope)
else:
# Initializes the default fields.
ControlFlowContext.__init__(self)
self._pred = pred # The boolean tensor for the cond predicate
self._pivot = pivot # The predicate tensor in this branch
self._branch = branch # 0 or 1 representing this branch
# Values considered to have been already seen in this context. pred is not
# included in this context.
self._values.add(pred.name)
self._external_values[pred.name] = pred
self._values.add(pivot.name)
pivot.op._set_control_flow_context(self) # pylint: disable=protected-access
def _init_from_proto(self, context_def, import_scope=None):
"""Creates a new `CondContext` from protocol buffer.
Args:
context_def: `CondContextDef` protocol buffer.
import_scope: Optional `string`. Name scope to add.
"""
assert isinstance(context_def, control_flow_pb2.CondContextDef)
# Create from context_def.
g = ops.get_default_graph()
self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
self._pred = g.as_graph_element(
ops.prepend_name_scope(context_def.pred_name, import_scope))
self._pivot = g.as_graph_element(
ops.prepend_name_scope(context_def.pivot_name, import_scope))
self._branch = context_def.branch
super(CondContext, self).__init__(
values_def=context_def.values_def, import_scope=import_scope)
@property
def pred(self):
return self._pred
@property
def pivot(self):
return self._pivot
@property
def branch(self):
return self._branch
@property
def grad_state(self):
if self.GetWhileContext():
return self.GetWhileContext().grad_state
return None
@property
def back_prop(self):
if self.GetWhileContext():
return self.GetWhileContext().back_prop
return False
def GetControlPivot(self):
return self._pivot
def to_proto(self, export_scope=None):
"""Converts a `CondContext` to a `CondContextDef` protocol buffer.
Args:
export_scope: Optional `string`. Name scope to remove.
Returns:
A `CondContextDef` protocol buffer.
"""
if (export_scope is None or self.name.startswith(export_scope)):
context_def = control_flow_pb2.CondContextDef()
context_def.context_name = ops.strip_name_scope(self.name, export_scope)
context_def.pred_name = ops.strip_name_scope(self._pred.name,
export_scope)
context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
export_scope)
context_def.branch = self._branch
context_def.values_def.MergeFrom(
super(CondContext, self)._to_values_def(export_scope))
for nested in self._nested_contexts:
nested_def = context_def.nested_contexts.add()
nested.to_control_flow_context_def(nested_def)
return context_def
else:
return None
@staticmethod
def from_proto(context_def, import_scope=None):
"""Returns a `CondContext` object created from `context_def`."""
ret = CondContext(context_def=context_def, import_scope=import_scope)
ret.Enter()
for nested_def in context_def.nested_contexts:
from_control_flow_context_def(nested_def, import_scope=import_scope)
ret.Exit()
return ret
def to_control_flow_context_def(self, context_def, export_scope=None):
context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
def AddValue(self, val):
"""Add `val` to the current context and its outer context recursively."""
if val.name in self._values:
# Use the real value if it comes from outer context. This is needed in
# particular for nested conds.
result = self._external_values.get(val.name)
result = val if result is None else result
else:
result = val
self._values.add(val.name)
if self._outer_context:
result = self._outer_context.AddValue(val)
self._values.add(result.name)
self._external_values[result.name] = result
with ops.control_dependencies(None):
result = _SwitchRefOrTensor(result, self._pred)[self._branch]
if self._outer_context:
self._outer_context.AddInnerOp(result.op)
result.op.graph.prevent_fetching(result.op)
# pylint: disable=protected-access
result.op._set_control_flow_context(self)
# pylint: enable=protected-access
# Mark Switch output as seen by this context and any outer contexts,
# just like what we do for normal op outputs in _AddOpInternal() below.
ctxt = self
while ctxt is not None:
# pylint: disable=protected-access
ctxt._values.add(result.name)
ctxt = ctxt._outer_context
# pylint: enable=protected-access
self._external_values[val.name] = result
return result
def AddOp(self, op):
self._AddOpInternal(op)
def _AddOpInternal(self, op):
"""Add `op` to the current context."""
if not op.inputs:
# If we're in a while loop, remove any control inputs from outside the
# loop.
self._RemoveExternalControlEdges(op)
if not any(
util.OpInContext(input_op, self) for input_op in op.control_inputs):
# pylint: disable=protected-access
op._add_control_input(self._pivot.op)
# pylint: enable=protected-access
else:
# Make each input to 'op' available in this CondContext. If an input is
# already part of this context there's nothing to do, but if it's
# external, AddValue() will handle adding the appropriate Switch node and
# other bookkeeping.
for index in range(len(op.inputs)):
x = op.inputs[index]
if op.type == "Merge" and x.op.type == "NextIteration":
# Edge case: if we're importing a while loop inside this CondContext,
# AddValue() will not correctly handle the NextIteration inputs to
# Merge node. The problem is that the NextIteration should also be
# part of this context, but if we're importing it won't have been
# processed and added to the context yet, so AddValue() will try to
# add a Switch which results in an invalid graph. Instead, we use the
# NextIteration input as-is here, and it will eventually be added to
# the context via AddOp().
real_x = x
else:
real_x = self.AddValue(x)
if real_x != x:
# pylint: disable=protected-access
op._update_input(index, real_x)
# pylint: enable=protected-access
# Remove any external control dependency on this op.
self._RemoveExternalControlEdges(op)
# pylint: disable=protected-access
if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
op._add_control_input(self._pivot.op)
# pylint: enable=protected-access
# Mark op's outputs as seen by this context and any outer contexts.
output_names = [x.name for x in op.outputs]
ctxt = self
while ctxt is not None:
# pylint: disable=protected-access
ctxt._values.update(output_names)
ctxt = ctxt._outer_context
# pylint: enable=protected-access
if self._outer_context or not util.IsLoopExit(op):
op.graph.prevent_fetching(op)
if self._outer_context:
self._outer_context.AddInnerOp(op)
def _ProcessOutputTensor(self, val):
"""Process an output tensor of a conditional branch."""
real_val = val
if val.name not in self._values:
# Handle the special case of lambda: x
self._values.add(val.name)
if self._outer_context:
real_val = self._outer_context.AddValue(val)
self._values.add(real_val.name)
self._external_values[real_val.name] = real_val
real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
self._external_values[val.name] = real_val
else:
external_val = self._external_values.get(val.name)
if external_val is not None:
real_val = external_val
return real_val
def _BuildCondTensor(self, v):
if isinstance(v, ops.Operation):
# Use pivot as the proxy for this op.
return with_dependencies([v], self._pivot)
else:
v = nest.map_structure(
_convert_tensorarray_to_flow, v, expand_composites=True)
return self._ProcessOutputTensor(ops.convert_to_tensor(v))
def BuildCondBranch(self, fn):
"""Add the subgraph defined by fn() to the graph."""
pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
original_result = fn()
post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
if len(post_summaries) > len(pre_summaries):
new_summaries = post_summaries[len(pre_summaries):]
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
summary_ref[:] = pre_summaries
with ops.control_dependencies(new_summaries):
if original_result is None:
return no_op(), None
elif not isinstance(original_result, ops.Operation):
original_result = variable_utils.convert_variables_to_tensors(
original_result)
original_result = nest.map_structure(
array_ops.identity, original_result, expand_composites=True)
if original_result is None:
return None, None
original_result = variable_utils.convert_variables_to_tensors(
original_result)
result = nest.map_structure(
self._BuildCondTensor, original_result, expand_composites=True)
if not isinstance(result, (list, _basetuple)):
result = [result]
return original_result, result
def IsCondContext(self):
return True
# pylint: enable=g-doc-args
# pylint: enable=redefined-outer-name
def _resource_safe_shape(t):
"""Returns the shape of t or the variable it points to."""
if t.dtype == dtypes.resource:
while t.op.inputs:
t = t.op.inputs[0]
return tensor_shape.TensorShape(t.op.get_attr("shape"))
return array_ops.shape_internal(t, optimize=False)
# TODO(yuanbyu): Consider having a unified notion of context for
# not only conditionals and loops but also control dependency and
# subgraphs.
class WhileContext(ControlFlowContext):
"""The context for the loop construct."""
def __init__(self,
maximum_iterations=None,
parallel_iterations=10,
back_prop=True,
swap_memory=False,
name="while_context",
grad_state=None,
context_def=None,
import_scope=None):
""""Creates a `WhileContext`.
Args:
maximum_iterations: Optional upper bound on number of loop iterations.
parallel_iterations: The number of iterations allowed to run in parallel.
back_prop: Whether backprop is enabled for this while loop.
swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
name: Optional name prefix for the returned tensors.
grad_state: The gradient loop state.
context_def: Optional `WhileContextDef` protocol buffer to initialize the
`Whilecontext` python object from.
import_scope: Optional `string`. Name scope to add. Only used when
initialing from protocol buffer.
"""
if context_def:
self._init_from_proto(context_def, import_scope=import_scope)
else:
ControlFlowContext.__init__(self)
self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
swap_memory, name)
# The gradient loop state.
self._grad_state = grad_state
def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
swap_memory, name):
"""Creates a new `WhileContext` from arguments.
Args:
maximum_iterations: Optional upper bound on number of loop iterations.
parallel_iterations: The number of iterations allowed to run in parallel.
back_prop: Whether backprop is enabled for this while loop.
swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
name: Optional name prefix for the returned tensors.
Raises:
ValueError: If `parallel_iterations` has invalid value.
"""
if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
raise ValueError("'parallel_iterations' must be a positive integer: "
"%s" % parallel_iterations)
self._name = ops.get_default_graph().unique_name(name)
self._maximum_iterations = maximum_iterations
self._parallel_iterations = parallel_iterations
self._back_prop = back_prop
self._swap_memory = swap_memory
# We use this node to control constants created by the pred lambda.
self._pivot_for_pred = None
# We use this node to control constants created by the body lambda.
self._pivot_for_body = None
# The boolean tensor for loop termination condition. Used in code
# generation for gradient computation
self._pivot = None
# The list of exit tensors for loop variables.
self._loop_exits = []
# The list of enter tensors for loop variables.
self._loop_enters = []
self._graph = ops.get_default_graph()
def _init_from_proto(self, context_def, import_scope=None):
"""Creates a new `WhileContext` from protocol buffer.
Args:
context_def: `WhileContextDef` protocol buffer.
import_scope: Optional `string`. Name scope to add.
"""
assert isinstance(context_def, control_flow_pb2.WhileContextDef)
# Create from context_def.
g = ops.get_default_graph()
self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
if context_def.maximum_iterations_name:
self._maximum_iterations = g.as_graph_element(
ops.prepend_name_scope(context_def.maximum_iterations_name,
import_scope))
else:
self._maximum_iterations = None
self._parallel_iterations = context_def.parallel_iterations
self._back_prop = context_def.back_prop
self._swap_memory = context_def.swap_memory
self._pivot_for_pred = g.as_graph_element(
ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
# We use this node to control constants created by the body lambda.
self._pivot_for_body = g.as_graph_element(
ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
# The boolean tensor for loop termination condition. Used in code
# generation for gradient computation.
self._pivot = g.as_graph_element(
ops.prepend_name_scope(context_def.pivot_name, import_scope))
# The list of exit tensors for loop variables.
self._loop_exits = [
g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
for exit_name in context_def.loop_exit_names
]
# The list of enter tensors for loop variables.
self._loop_enters = [
g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
for enter_name in context_def.loop_enter_names
]
super(WhileContext, self).__init__(
values_def=context_def.values_def, import_scope=import_scope)
# import_scope causes self.name to be different from the original serialized
# context's name. Rewrite "frame_name" attrs with the new name.
if import_scope:
for tensor_name in self._values:
op = g.as_graph_element(tensor_name).op
if util.IsLoopEnter(op):
# pylint: disable=protected-access
op._set_attr("frame_name",
attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
# pylint: enable=protected-access
self._graph = ops.get_default_graph()
@property
def maximum_iterations(self):
"""The maximum number of iterations that will be executed."""
return self._maximum_iterations
@property
def parallel_iterations(self):
"""The number of iterations allowed to run in parallel."""
return self._parallel_iterations
@property
def back_prop(self):
"""True iff backprop is enabled for this while loop."""
return self._back_prop
@property
def swap_memory(self):
"""True iff GPU-CPU memory swap is enabled for this while loop."""
return self._swap_memory
@property
def pivot(self):
"""The boolean tensor representing the loop termination condition."""
return self._pivot
@property
def loop_enters(self):
"""The list of enter tensors for loop variables."""
return self._loop_enters
@property
def loop_exits(self):
"""The list of exit tensors for loop variables."""
return self._loop_exits
@property
def grad_state(self):
"""The gradient loop state."""
return self._grad_state
def to_proto(self, export_scope=None):
"""Converts a `WhileContext` to a `WhileContextDef` protocol buffer.
Args:
export_scope: Optional `string`. Name scope to remove.
Returns:
A `WhileContextDef` protocol buffer.
"""
if (export_scope is None or self.name.startswith(export_scope)):
context_def = control_flow_pb2.WhileContextDef()
context_def.context_name = ops.strip_name_scope(self.name, export_scope)
context_def.parallel_iterations = self._parallel_iterations
if self._maximum_iterations is not None:
context_def.maximum_iterations_name = ops.strip_name_scope(
self._maximum_iterations.name, export_scope)
context_def.back_prop = self._back_prop
context_def.swap_memory = self._swap_memory
context_def.pivot_for_pred_name = ops.strip_name_scope(
self._pivot_for_pred.name, export_scope)
context_def.pivot_for_body_name = ops.strip_name_scope(
self._pivot_for_body.name, export_scope)
context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
export_scope)
context_def.loop_exit_names.extend([
ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
])
context_def.loop_enter_names.extend([
ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
])
context_def.values_def.MergeFrom(
super(WhileContext, self)._to_values_def(export_scope=export_scope))
for nested in self._nested_contexts:
nested_def = context_def.nested_contexts.add()
nested.to_control_flow_context_def(nested_def)
return context_def
else:
return None
def to_control_flow_context_def(self, context_def, export_scope=None):
context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
@staticmethod
def from_proto(context_def, import_scope=None):
"""Returns a `WhileContext` object created from `context_def`.
Args:
context_def: A `WhileContextDef` protocol buffer.
import_scope: Optional `string`. Name scope to add.
Returns:
A `WhileContext` Python object.
"""
ret = WhileContext(context_def=context_def, import_scope=import_scope)
ret.Enter()
for nested_def in context_def.nested_contexts:
from_control_flow_context_def(nested_def, import_scope=import_scope)
ret.Exit()
return ret
def GetWhileContext(self):
return self
def GetControlPivot(self):
if self._pivot_for_body is not None:
return self._pivot_for_body
return self._pivot_for_pred
def AddValue(self, val):
"""Add `val` to the current context and its outer context recursively."""
result = val
new_value = val.name not in self._values
# Don't treat ops in this context as new values. Usually all known values
# are in self._values, except when we're importing a while loop inside this
# WhileContext. Since there's a cycle in this case, `val` may be part of the
# imported while loop but not yet processed by this context and added to
# self._values in _AddOpInternal. We only want to process external input
# tensors to the while loop here.
new_value &= val.op._control_flow_context is not self # pylint: disable=protected-access
if new_value:
self._values.add(val.name)
# If we are in a grad context and val is from its forward context,
# use GetRealValue(), which adds the logic to save the history of
# val in forward.
grad_ctxt = ops.get_default_graph()._get_control_flow_context()
if grad_ctxt:
grad_ctxt = grad_ctxt.GetWhileContext()
if grad_ctxt.grad_state:
forward_ctxt = util.GetWhileContext(val.op)
if util.IsLoopExit(val.op):
forward_ctxt = forward_ctxt.outer_context
if forward_ctxt:
forward_ctxt = forward_ctxt.GetWhileContext()
if forward_ctxt == grad_ctxt.grad_state.forward_context:
real_val = grad_ctxt.grad_state.GetRealValue(val)
self._external_values[val.name] = real_val
return real_val
if self._outer_context is not None:
result = self._outer_context.AddValue(val)
# Create an Enter to make `result` known to this loop context.
with ops.control_dependencies(None):
enter = _Enter(
result,
self._name,
is_constant=True,
parallel_iterations=self._parallel_iterations)
enter.graph.prevent_feeding(enter)
if self._outer_context:
self._outer_context.AddInnerOp(enter.op)
# Fix the control inputs and control flow context of these enter ops.
self._FixControlInputsAndContext([enter])
# Add `enter` in this context.
self._values.add(enter.name)
self._external_values[val.name] = enter
result = enter
else:
actual_val = self._external_values.get(val.name)
if actual_val is not None:
result = actual_val
return result
def AddOp(self, op):
"""Add `op` to the current context."""
# For a reduction op, if op is in a grad context and its input is from
# its forward context, moving op to the forward context means we would
# store the tensor after the reduction as opposed to the tensor before
# reduction, and therefore could significantly reduce memory consumption.
# For now, we do this only for a few ops.
#
# If in XLA context, do not move constant ops to forward pass as pushing to
# and popping from a stack removes the constant property of an op and breaks
# XLA compilation, which requires certain inputs to be constant for certain
# ops.
if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}:
grad_ctxt = ops.get_default_graph()._get_control_flow_context()
if grad_ctxt:
grad_ctxt = grad_ctxt.GetWhileContext()
if grad_ctxt.grad_state:
op_input_forward_ctxt = util.GetWhileContext(op.inputs[0].op)
if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:
op_input_ctxt = op.inputs[0].op._get_control_flow_context()
op._set_control_flow_context(op_input_ctxt)
op_input_ctxt._AddOpInternal(op)
return
self._AddOpInternal(op)
def _AddOpInternal(self, op):
"""Add `op` to the current context.
We move any external control dependencies of the op to the loop pivot, to
ensure they get executed.
"""
# This is needed to prevent frame mismatch errors where there are Const
# nodes inside tf.function in v1 while_loop and inlining is turned on.
if op.type in ["PartitionedCall", "StatefulPartitionedCall"]:
op._add_control_input(self.GetControlPivot().op) # pylint: disable=protected-access
if not op.inputs:
# Remove any external control dependency on this op
control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
# Add a control edge from the control pivot to this op.
if not control_inputs:
# pylint: disable=protected-access
op._add_control_input(self.GetControlPivot().op)
# pylint: enable=protected-access
for x in op.outputs:
self._values.add(x.name)
else:
for index in range(len(op.inputs)):
x = op.inputs[index]
real_x = self.AddValue(x)
if real_x != x:
op._update_input(index, real_x) # pylint: disable=protected-access
# Remove any external control dependency on this op.
_, external_inputs = self._RemoveExternalControlEdges(op)
# Add a control dependency to prevent loop invariants from
# enabling ops that should not be executed.
self._MaybeAddControlDependency(op)
for x in op.outputs:
self._values.add(x.name)
if external_inputs:
# Use an identity to pull control inputs as data inputs. Note that we
# ignore ops which don't have outputs. TODO(apassos): fix that
with ops.control_dependencies(None):
self.Enter()
external_inputs = [
array_ops.identity(x.outputs[0]).op
for x in external_inputs
if x.outputs
]
self.Exit()
op._add_control_inputs(external_inputs) # pylint: disable=protected-access
if self._outer_context or not util.IsLoopExit(op):
op.graph.prevent_fetching(op)
for x in op.outputs:
op.graph.prevent_feeding(x)
if self._outer_context:
self._outer_context.AddInnerOp(op)
def _MaybeAddControlDependency(self, op):
"""Add a control input to the op if it only depends on loop invariants."""
def _IsOpFree(op):
"""Determines if `op` needs a control dependency."""
if op.control_inputs:
return False
# pylint: disable=protected-access
if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
return True
# pylint: enable=protected-access
for x in op.inputs:
if not util.IsLoopConstantEnter(x.op):
return False
return True
if _IsOpFree(op):
# pylint: disable=protected-access
op._add_control_input(self.GetControlPivot().op)
# pylint: enable=protected-access
def AddForwardLoopCounter(self, outer_grad_state):
"""Adds a loop that counts the number of iterations.
This is added to the forward loop at the time when we start to
create the loop for backprop gradient computation. Called in
the outer context of this forward context.
The pseudocode is:
`n = 0; while (_pivot) { n++; }`
Note that a control dependency is added to `n` to ensure the correct
execution order of stack push ops.
Args:
outer_grad_state: The outer grad state. None if not nested.
Returns:
The number of iterations taken by the forward loop and the loop index.
"""
n = constant_op.constant(0, name="f_count")
if outer_grad_state is not None:
# Force the stack pushes of i-th execution of an inner loop to be ordered
# before the pushes of (i+1)-th execution of the same inner loop.
outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
n.op._add_control_input(outer_add_op) # pylint: disable=protected-access
self.Enter()
self.AddName(n.name)
enter_n = _Enter(
n,
self._name,
is_constant=False,
parallel_iterations=self._parallel_iterations,
name="f_count")
self.loop_enters.append(enter_n)
merge_n = merge([enter_n, enter_n])[0]
switch_n = switch(merge_n, self._pivot)
index = math_ops.add(switch_n[1], 1)
next_n = _NextIteration(index)
merge_n.op._update_input(1, next_n)
total_iterations = exit(switch_n[0], name="f_count")
self.loop_exits.append(total_iterations)
self.ExitResult([total_iterations])
self.Exit()
return total_iterations, next_n
def AddBackpropLoopCounter(self, count, outer_grad_state):
"""Add the backprop loop that controls the iterations.
This is added to the backprop loop. It is used to control the loop
termination of the backprop loop. Called in the outer context of
this grad context.
The pseudocode is:
`n = count; while (n >= 1) { n--; }`
Note that a control dependency is added to `final_zero` to ensure the
correct execution order of stack pop ops.
Args:
count: The number of iterations for backprop.
outer_grad_state: The outer grad state. None if not nested.
Returns:
The loop index.
"""
in_separate_functions = count.graph is not ops.get_default_graph()
if in_separate_functions:
# Brings the count into this graph
count = array_ops.identity(count)
else:
# TODO(apassos) XLA expects this constant to be created outside the loop,
# so doing that for now.
one = constant_op.constant(1, name="b_count")
self.Enter()
self.AddName(count.name)
enter_count = _Enter(
count,
self._name,
is_constant=False,
parallel_iterations=self._parallel_iterations,
name="b_count")
self.loop_enters.append(enter_count)
merge_count = merge([enter_count, enter_count])[0]
self._pivot_for_pred = merge_count
if in_separate_functions:
one = constant_op.constant(1, name="b_count")
pred = math_ops.greater_equal(merge_count, one)
self._pivot = loop_cond(pred, name="b_count")
switch_count = switch(merge_count, self._pivot)
index = math_ops.subtract(switch_count[1], one)
self._pivot_for_body = index
next_count = _NextIteration(index)
merge_count.op._update_input(1, next_count)
final_zero = exit(switch_count[0], name="b_count")
self.loop_exits.append(final_zero)
if outer_grad_state is not None:
# Force the stack pops of i-th execution of an inner loop to be ordered
# before the pops of (i+1)-th execution of the same inner loop.
# pylint: disable=protected-access
outer_grad_state.grad_sync._add_control_input(final_zero.op)
# pylint: enable=protected-access
self.ExitResult([final_zero])
self.Exit()
return next_count
def AddBackpropAccumulator(self, op, grad):
"""Add an accumulation loop for every loop invariant.
This is added to the backprop loop. It is used to accumulate partial
gradients within each loop iteration. Called when in the gradient while
context.
The pseudocode is:
```
acc = 0.0;
while (_pivot) {
acc += grad;
}
```
Args:
op: The Enter op for a loop invariant.
grad: The partial gradient of an iteration for a loop invariant.
Returns:
The gradient for a loop invariant.
"""
self.Exit()
# Create a zeros tensor with the right shape for acc. If we don't
# know the full shape statically, we will have to get the shape
# dynamically from the forward inference. Getting the shape right
# for the zeros is only needed for the base case when the loop exits
# without running any iterations.
shape = grad.get_shape()
if shape.is_fully_defined():
if self.outer_context:
self.outer_context.Enter()
acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
if self.outer_context:
self.outer_context.Exit()
else:
value = op.inputs[0]
if (isinstance(self.outer_context, WhileContext) and
self.outer_context.grad_state is not None):
# We are in a nested while loop.
forward_ctxt = self.grad_state.forward_context
forward_ctxt.outer_context.Enter()
zeros_shape = array_ops.shape_internal(value, optimize=False)
forward_ctxt.outer_context.Exit()
outer_grad_state = self.grad_state.outer_grad_state
history_zeros_shape = outer_grad_state.AddForwardAccumulator(
zeros_shape)
self.outer_context.Enter()
real_shape = outer_grad_state.AddBackpropAccumulatedValue(
history_zeros_shape, zeros_shape)
acc = array_ops.zeros(real_shape, grad.dtype)
self.outer_context.Exit()
else:
if self.outer_context:
self.outer_context.Enter()
zeros_shape = array_ops.shape_internal(value, optimize=False)
acc = array_ops.zeros(zeros_shape, grad.dtype)
if self.outer_context:
self.outer_context.Exit()
self.Enter()
self.AddName(acc.name)
enter_acc = _Enter(
acc,
self._name,
is_constant=False,
parallel_iterations=self._parallel_iterations,
name="b_acc")
self.loop_enters.append(enter_acc)
merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)
add_acc = math_ops.add(switch_acc_true, grad)
next_acc = _NextIteration(add_acc)
merge_acc.op._update_input(1, next_acc) # pylint: disable=protected-access
result_acc = exit(switch_acc_false, name="b_acc")
self.loop_exits.append(result_acc)
self.ExitResult([result_acc])
return result_acc
def AddBackpropIndexedSlicesAccumulator(self, op, grad):
"""This is used for accumulating gradients that are IndexedSlices.
This is essentially the equivalent of AddBackpropAccumulator but optimized
for things like updating embeddings from within a while loop.
Args:
op: The Enter op for a loop invariant.
grad: The partial gradients represented as an IndexedSlices.
Returns:
The accumulated IndexedSlices gradient of the loop invariant.
"""
values = grad.values
indices = grad.indices
dense_shape = grad.dense_shape
self.Exit()
if self.outer_context:
self.outer_context.Enter()
if values.get_shape().is_fully_defined():
values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
values.get_shape().dims[1:])
if self.outer_context:
self.outer_context.Enter()
values_acc = constant_op.constant(
0, values.dtype, shape=values_shape, name="b_acc")
if self.outer_context:
self.outer_context.Exit()
else:
values_shape = _resource_safe_shape(op.inputs[0])[1:]
values_shape = array_ops.concat([[1], values_shape], 0)
values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
indices_acc = constant_op.constant([0], indices.dtype)
shape_acc = None
if dense_shape is not None:
if dense_shape.get_shape().is_fully_defined():
if self.outer_context:
self.outer_context.Enter()
shape_acc = constant_op.constant(
0, dense_shape.dtype, shape=dense_shape.get_shape())
if self.outer_context:
self.outer_context.Exit()
else:
shape_acc = array_ops.zeros_like(
array_ops.shape_internal(
op.inputs[0], optimize=False, out_type=dense_shape.dtype),
optimize=False)
if self.outer_context:
self.outer_context.Exit()
self.Enter()
self.AddName(values_acc.name)
self.AddName(indices_acc.name)
init_acc = [indices_acc, values_acc]
if shape_acc is not None:
self.AddName(shape_acc.name)
init_acc.append(shape_acc)
# Set use_input_shape=False since the accumulator tensors will grow in
# size. If use_input_shape=True, the _update_input call below will result in
# incompatible shapes.
enter_acc = [
_Enter(
x,
self._name,
is_constant=False,
parallel_iterations=self._parallel_iterations,
use_input_shape=False,
name="b_acc") for x in init_acc
]
# Manually set appropriate partial shapes.
enter_acc[0].set_shape([None])
if values_acc.shape.dims is not None:
enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
self.loop_enters.extend(enter_acc)
merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
switch_acc = [switch(x, self._pivot) for x in merge_acc]
# The actual accumulation.
acc_indexed_slices = [
array_ops.concat([xa[1], xv], 0)
for xa, xv in zip(switch_acc[:2], [indices, values])
]
if shape_acc is not None:
# For the shape we just keep the maximum
acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))
next_acc = [_NextIteration(x) for x in acc_indexed_slices]
for xm, xn in zip(merge_acc, next_acc):
xm.op._update_input(1, xn) # pylint: disable=protected-access
exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
self.loop_exits.extend(exit_acc)
self.ExitResult(exit_acc)
return indexed_slices.IndexedSlices(
indices=exit_acc[0],
values=exit_acc[1],
dense_shape=exit_acc[2] if shape_acc is not None else None)
def _InitializeValues(self, values):
"""Makes the values known to this context."""
self._values = set()
for x in values:
if isinstance(x, ops.Tensor):
self._values.add(x.name)
else:
raise TypeError("'values' must be a list of Tensors. "
f"Received: {type(x)}.")
def _BuildLoop(self, pred, body, flat_orig_loop_vars, flat_loop_vars,
loop_vars_signature):
"""Core: Add the loop termination condition and body to the graph."""
flat_shape_invariants = nest.map_structure(
lambda spec: spec.shape,
nest.flatten(loop_vars_signature, expand_composites=True))
# Let the context know the loop variables so the loop variables
# would be added in the outer contexts properly.
self._InitializeValues(flat_loop_vars)
if self._outer_context:
real_vars = [self._outer_context.AddValue(x) for x in flat_loop_vars]
else:
real_vars = flat_loop_vars
enter_vars = []
with ops.control_dependencies(None):
for real_var, shape_invariant in zip(real_vars, flat_shape_invariants):
enter_var = _Enter(
real_var,
self._name,
is_constant=False,
parallel_iterations=self._parallel_iterations,
use_input_shape=False)
if _ShapeLessThanOrEqual(real_var.get_shape(), shape_invariant):
enter_var.set_shape(shape_invariant)
else:
raise ValueError(
f"The shape invariant specified for {real_var.name} is not "
"compatible with the initial shape of the loop variable. It "
f"enters the loop with shape {real_var.get_shape()}, but the "
f"specified shape invariant is {shape_invariant}.")
enter_var.graph.prevent_feeding(enter_var)
if self._outer_context:
self._outer_context.AddInnerOp(enter_var.op)
enter_vars.append(enter_var)
# Finds the closest enclosing non-None control pivot.
outer_context = self._outer_context
control_pivot = None
while outer_context is not None and control_pivot is None:
control_pivot = outer_context.GetControlPivot()
# pylint: disable=protected-access
outer_context = outer_context._outer_context
# pylint: enable=protected-access
if control_pivot is not None:
for var in enter_vars:
if util.IsLoopConstantEnter(var.op.inputs[0].op):
# pylint: disable=protected-access
var.op._add_control_input(control_pivot.op)
# pylint: enable=protected-access
# Fix the control inputs and control flow context of these enter ops.
self._FixControlInputsAndContext(enter_vars)
self._InitializeValues(enter_vars)
self._loop_enters = enter_vars
merge_vars = [merge([x, x])[0] for x in enter_vars]
self._pivot_for_pred = merge_vars[0]
merge_vars_with_tensorarrays = nest.map_structure(
_convert_flow_to_tensorarray, flat_orig_loop_vars, merge_vars)
# Build the graph for pred.
packed_vars = nest.pack_sequence_as(
structure=loop_vars_signature,
flat_sequence=merge_vars_with_tensorarrays,
expand_composites=True)
c = ops.convert_to_tensor(pred(*packed_vars))
self._pivot = loop_cond(c, name="LoopCond")
switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
# Build the graph for body.
vars_for_body = [_Identity(x[1]) for x in switch_vars]
self._pivot_for_body = vars_for_body[0]
# Convert TensorArray flow variables inside the context back into
# their associated TensorArrays for calling the body.
vars_for_body_with_tensorarrays = nest.map_structure(
_convert_flow_to_tensorarray, flat_orig_loop_vars, vars_for_body)
packed_vars_for_body = nest.pack_sequence_as(
structure=loop_vars_signature,
flat_sequence=vars_for_body_with_tensorarrays,
expand_composites=True)
pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
body_result = body(*packed_vars_for_body)
post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
if not nest.is_nested(body_result):
body_result = [body_result]
if len(post_summaries) > len(pre_summaries):
new_summaries = post_summaries[len(pre_summaries):]
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
summary_ref[:] = pre_summaries
with ops.control_dependencies(new_summaries):
def map_fn(x):
# TODO(apassos) figure out how to trigger with tensor arrays as well
if isinstance(x, tensor_array_ops.TensorArray):
return x
return array_ops.identity(x)
body_result = nest.map_structure(
map_fn, body_result, expand_composites=True)
body_result = variable_utils.convert_variables_to_tensors(body_result)
# Compare the structure types of input and output of body.
# For backwards compatibility, the first layer is forced to a list
# during this comparison, because inputs are typically lists and
# outputs of the body are typically tuples.
nest.assert_same_structure(
list(packed_vars_for_body), list(body_result), expand_composites=True)
# Store body_result to keep track of TensorArrays returned by body
original_body_result = body_result
# Convert TensorArrays returned by body into their flow variables
result = nest.map_structure(
_convert_tensorarray_to_flow,
nest.flatten(body_result, expand_composites=True),
expand_composites=True)
result = ops.convert_n_to_tensor_or_composite(result)
# Add NextIteration and the back edges to complete the loop.
if len(merge_vars) != len(result):
raise ValueError("Number of inputs and outputs of 'body' must match "
f"'loop_vars'. Got {len(merge_vars)} for the number of "
f"inputs/outputs, and {len(result)} for 'loop_vars'.")
next_vars = []
for m, v in zip(merge_vars, result):
next_vars.append(_AddNextAndBackEdge(m, v))
# Add the exit ops.
exit_vars = [exit(x[0]) for x in switch_vars]
self._loop_exits = exit_vars
# Exit the loop.
self.ExitResult(exit_vars)
return original_body_result, exit_vars
def BuildLoop(self, pred, body, loop_vars, shape_invariants,
return_same_structure):
"""Add the loop termination condition and body to the graph."""
# Keep flat_orig_loop_vars to identify which are TensorArrays
flat_orig_loop_vars = nest.flatten(loop_vars, expand_composites=True)
loop_vars = nest.map_structure(
_convert_to_tensor_or_composite_or_tensorarray, loop_vars)
# Convert TensorArrays to their flow variables
flat_loop_vars = nest.map_structure(
_convert_tensorarray_to_flow,
nest.flatten(loop_vars, expand_composites=True))
if shape_invariants is not None:
loop_vars_signature = nest.map_structure(
_shape_invariant_to_type_spec, loop_vars, shape_invariants)
else:
loop_vars_signature = nest.map_structure(
_shape_invariant_to_type_spec, loop_vars)
try:
self.Enter()
# _BuildLoop calls _update_input in several places. _mutation_lock()
# ensures a Session.run call cannot occur between creating and mutating
# new ops.
with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
original_body_result, exit_vars = self._BuildLoop(
pred, body, flat_orig_loop_vars, flat_loop_vars,
loop_vars_signature)
finally:
self.Exit()
flat_result = nest.flatten(original_body_result, expand_composites=True)
# Convert TensorArray flow variables outside the context back into
# their associated TensorArrays for returning to caller.
exit_vars_with_tensorarrays = nest.map_structure(
_convert_flow_to_tensorarray, flat_result, exit_vars)
packed_exit_vars = nest.pack_sequence_as(
structure=original_body_result,
flat_sequence=exit_vars_with_tensorarrays,
expand_composites=True)
if return_same_structure:
return packed_exit_vars
else:
return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars
def _FixControlInputsAndContext(self, enters):
graph = ops.get_default_graph()
# pylint: disable=protected-access
for e in enters:
if isinstance(e, ops.Tensor):
xs = [e]
else:
raise TypeError("'enters' must be a list of Tensors. "
f"Received: {type(e)}.")
for x in xs:
inp_op = x.op.inputs[0].op
control_inputs = graph._control_dependencies_for_inputs([inp_op])
outer_control_inputs = []
for op in control_inputs:
# We need to keep control inputs that are in any ancestor
# ControlFlowContext, and within outer WhileContext.
keep_as_control_input = True
op_ctxt = util.GetOutputContext(op)
outer_ctxt = self.outer_context
outer_while_context = (None if outer_ctxt is None else
outer_ctxt.GetWhileContext())
while outer_ctxt != op_ctxt:
if outer_ctxt is None or outer_ctxt == outer_while_context:
keep_as_control_input = False
break
outer_ctxt = outer_ctxt.outer_context
if keep_as_control_input:
outer_control_inputs.append(op)
x.op._set_control_flow_context(self)
x.op._add_control_inputs(outer_control_inputs)
graph._record_op_seen_by_control_dependencies(x.op)
# pylint: enable=protected-access
def IsWhileContext(self):
return True
# pylint: enable=redefined-outer-name
def _AsTensorList(x, p):
"""Return x as a list of Tensors or IndexedSlices.
For entries of `x` that are Operations, this returns an Identity of `p`
with a dependency on the operation.
Args:
x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
p: A Tensor to return for entries in `x` that are Operations.
Returns:
A list of Tensors or IndexedSlices.
"""
if not isinstance(x, (list, _basetuple)):
x = [x]
l = []
for v in x:
if isinstance(v, ops.Operation):
v = with_dependencies([v], p)
v = ops.convert_to_tensor_or_composite(v)
if isinstance(v, ops.Tensor):
l.append(array_ops.identity(v))
else:
l.append(
indexed_slices.IndexedSlices(
array_ops.identity(v.values), array_ops.identity(v.indices)))
return l
def _CheckResults(a, b):
assert len(a) == len(b), (
"Values returned by a() and b() must have the same length.")
for x, y in zip(a, b):
assert x.dtype == y.dtype, (
"Values returned by a() [%s] and b() [%s] must have "
"the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
def with_dependencies(dependencies, output_tensor, name=None):
"""Produces the content of `output_tensor` only after `dependencies`.
In some cases, a user may want the output of an operation to be
consumed externally only after some other dependencies have run
first. This function ensures returns `output_tensor`, but only after all
operations in `dependencies` have run. Note that this means that there is
no guarantee that `output_tensor` will be evaluated after any `dependencies`
have run.
See also `tf.tuple` and `tf.group`.
Args:
dependencies: Iterable of operations to run before this op finishes.
output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
name: (Optional) A name for this operation.
Returns:
Same as `output_tensor`.
Raises:
TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
"""
if context.executing_eagerly():
return output_tensor
with ops.name_scope(name, "control_dependency",
list(dependencies) + [output_tensor]) as name:
with ops.colocate_with(output_tensor):
with ops.control_dependencies(dependencies):
output_tensor = ops.convert_to_tensor_or_composite(output_tensor)
if isinstance(output_tensor, indexed_slices.IndexedSlices):
return indexed_slices.IndexedSlices(
_Identity(output_tensor.values, name=name), output_tensor.indices,
output_tensor.dense_shape)
else:
return _Identity(output_tensor, name=name)
def _GroupControlDeps(dev, deps, name=None):
with ops.control_dependencies(deps):
if dev is None:
return no_op(name=name)
else:
with ops.device(dev):
return no_op(name=name)
# TODO(touts): Accept "inputs" as a list.
@tf_export("group")
def group(*inputs, **kwargs):
"""Create an op that groups multiple operations.
When this op finishes, all ops in `inputs` have finished. This op has no
output.
Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
this method, as ops execute in the expected order thanks to automatic control
dependencies.* Only use `tf.group` when working with v1
`tf.Graph` code.
When operating in a v1-style graph context, ops are not executed in the same
order as specified in the code; TensorFlow will attempt to execute ops in
parallel or in an order convenient to the result it is computing. `tf.group`
allows you to request that one or more results finish before execution
continues.
`tf.group` creates a single op (of type `NoOp`), and then adds appropriate
control dependencies. Thus, `c = tf.group(a, b)` will compute the same graph
as this:
with tf.control_dependencies([a, b]):
c = tf.no_op()
See also `tf.tuple` and
`tf.control_dependencies`.
Args:
*inputs: Zero or more tensors to group.
name: A name for this operation (optional).
Returns:
An Operation that executes all its inputs.
Raises:
ValueError: If an unknown keyword argument is provided.
"""
if context.executing_eagerly():
return None
name = kwargs.pop("name", None)
if kwargs:
raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
with ops.name_scope(name, "group_deps", inputs) as name:
# Grouping no inputs means do nothing
if not inputs:
return no_op(name=name)
# Sorts *inputs according to their devices.
ops_on_device = {} # device -> operations specified on the device.
for inp in nest.flatten(inputs, expand_composites=True):
if not hasattr(inp, "device"):
raise TypeError("'inputs' should be zero or more (nested) Tensors. "
f"Received '{inp}' with type '{type(inp)}'.")
dev = inp.device
if dev in ops_on_device:
ops_on_device[dev].append(inp)
else:
ops_on_device[dev] = [inp]
if len(ops_on_device) == 1:
# 1-level tree. The root node is the returned NoOp node.
(dev, deps), = ops_on_device.items()
return _GroupControlDeps(dev, deps, name=name)
# 2-level tree. The root node is the returned NoOp node.
# deps contains 1 NoOp node for each device.
deps = []
def device_key(dev):
"""A sort key that allows None to be compared to strings."""
return "" if dev is None else dev
for dev in sorted(ops_on_device, key=device_key):
deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
with ops.control_dependencies(deps):
return no_op(name=name)
@tf_export("tuple", v1=[])
@dispatch.add_dispatch_support
def tuple_v2(tensors, control_inputs=None, name=None):
"""Groups tensors together.
The returned tensors have the same value as the input tensors, but they
are computed only after all the input tensors have been computed.
Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
this method, as ops execute in the expected order thanks to automatic control
dependencies.* Only use `tf.tuple` when working with v1 `tf.Graph` code.
See also `tf.group` and `tf.control_dependencies`.
Example:
>>> with tf.Graph().as_default():
... with tf.compat.v1.Session() as sess:
... v = tf.Variable(0.0)
... a = tf.constant(1.0)
... sess.run(tf.compat.v1.global_variables_initializer())
... for i in range(5):
... update_op = v.assign_add(1.0)
... b = a + v
... res_b = sess.run(b)
... res_v = sess.run(v)
... print(res_v)
0.0
0.0
0.0
0.0
0.0
>>> with tf.Graph().as_default():
... with tf.compat.v1.Session() as sess:
... v = tf.Variable(0.0)
... a = tf.constant(1.0)
... sess.run(tf.compat.v1.global_variables_initializer())
... for i in range(5):
... update_op = v.assign_add(1.0)
... calc = [a + v]
... # `tf.tuple` ensures `update_op` is run before `b`
... b = tf.tuple(calc, [tf.group(update_op)])
... res_b = sess.run(b)
... res_v = sess.run(v)
... print(res_v)
1.0
2.0
3.0
4.0
5.0
Args:
tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
control_inputs: List of additional ops to finish before returning.
name: (optional) A name to use as a `name_scope` for the operation.
Returns:
Same as `tensors`.
Raises:
ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
objects.
"""
return tuple(tensors=tensors, name=name, control_inputs=control_inputs) # pylint: disable=redefined-builtin
@tf_export(v1=["tuple"])
@dispatch.add_dispatch_support
def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin
"""Group tensors together.
This creates a tuple of tensors with the same values as the `tensors`
argument, except that the value of each tensor is only returned after the
values of all tensors have been computed.
`control_inputs` contains additional ops that have to finish before this op
finishes, but whose outputs are not returned.
This can be used as a "join" mechanism for parallel computations: all the
argument tensors can be computed in parallel, but the values of any tensor
returned by `tuple` are only available after all the parallel computations
are done.
See also `tf.group` and
`tf.control_dependencies`.
Args:
tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
name: (optional) A name to use as a `name_scope` for the operation.
control_inputs: List of additional ops to finish before returning.
Returns:
Same as `tensors`.
Raises:
ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
objects.
"""
if context.executing_eagerly():
return tensors
with ops.name_scope(name, "tuple", tensors) as name:
tensors = [
t if (isinstance(t, ops.Operation) or tensor_util.is_tf_type(t) or
t is None) else ops.convert_to_tensor(t) for t in tensors
]
gating_ops = [
t if isinstance(t, ops.Operation) else t.op
for t in tensors
if t is not None
]
if control_inputs:
for c in control_inputs:
if isinstance(c, ops.Tensor):
c = c.op
elif not isinstance(c, ops.Operation):
raise TypeError(
"'control_inputs' must only contain Operation or Tensor. "
f"Received: {type(c)}")
gating_ops.append(c)
# Note that in order to ensure ordering in the pbtxt, we must take care to
# ensure the order here.
gating_ops = sorted(set(gating_ops), key=lambda op: op._id) # Uniquify ops.
if not gating_ops:
raise ValueError("'tensors' must have at least one Tensor. "
f"Received: {tensors}.")
gate = group(*gating_ops)
tpl = []
for t in tensors:
if tensor_util.is_tf_type(t):
tpl.append(with_dependencies([gate], t))
elif isinstance(t, ops.Operation):
with ops.control_dependencies([gate]):
tpl.append(group(t))
else:
tpl.append(None)
return tpl
class XLAControlFlowContext(ControlFlowContext):
"""Base class for XLA and TPU control flow contexts."""
def __init__(self):
super(XLAControlFlowContext, self).__init__()
self._name = "XLAControlFlowContext"
def to_control_flow_context_def(self, context_def, export_scope=None):
# pylint: disable=useless-super-delegation
# NOTE(slebedev): the method is required by `ControlFlowContext`.
super(XLAControlFlowContext,
self).to_control_flow_context_def(context_def, export_scope)
def IsXLAContext(self):
return True
def AddOp(self, _):
pass
def AddValue(self, x):
return x
def RequiresUniqueFunctionRetracing(self):
"""Returns whether the tf.function should be retraced if the context changes.
"""
return False
@tf_export("__internal__.get_enclosing_xla_context", v1=[])
def get_enclosing_xla_context():
"""Recursively find and return the XLAControlFlowContext."""
graph = ops.get_default_graph()
while graph is not None:
# pylint: disable=protected-access
context_ = graph._get_control_flow_context()
# pylint: enable=protected-access
while context_ is not None:
if isinstance(context_, XLAControlFlowContext):
return context_
context_ = context_.outer_context
# This may be a FuncGraph due to defuns or v2 control flow. We need to
# find the original graph with the XLAControlFlowContext.
graph = getattr(graph, "outer_graph", None)
return None
def from_control_flow_context_def(context_def, import_scope=None):
"""Deserializes `context_def` into the appropriate ControlFlowContext.
Args:
context_def: ControlFlowContextDef proto
import_scope: Optional `string`. Name scope to add.
Returns:
A ControlFlowContext subclass
"""
if context_def.HasField("cond_ctxt"):
return CondContext.from_proto(
context_def.cond_ctxt, import_scope=import_scope)
if context_def.HasField("while_ctxt"):
return WhileContext.from_proto(
context_def.while_ctxt, import_scope=import_scope)
raise NotImplementedError("Unknown ControlFlowContextDef field: %s" %
context_def.WhichOneof("ctxt"))
ops.register_proto_function(
ops.GraphKeys.COND_CONTEXT,
proto_type=control_flow_pb2.CondContextDef,
to_proto=CondContext.to_proto,
from_proto=CondContext.from_proto)
ops.register_proto_function(
ops.GraphKeys.WHILE_CONTEXT,
proto_type=control_flow_pb2.WhileContextDef,
to_proto=WhileContext.to_proto,
from_proto=WhileContext.from_proto)