| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| |
| """Utility functions for control flow. |
| |
| This file is necessary to avoid cyclic dependencies between ops.py and |
| control_flow_ops.py. |
| """ |
| |
| import os |
| import traceback |
| |
| from tensorflow.python import tf2 |
| from tensorflow.python.platform import tf_logging as logging |
| |
# Control flow v2 is switched on when `tf2.enabled()` is true and
# TF_ENABLE_CONTROL_FLOW_V2 is not explicitly "0", or when any one of the
# opt-in environment variables listed below is set to something other than "0".
ENABLE_CONTROL_FLOW_V2 = (
    (tf2.enabled() and os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or
    any(os.getenv(env_name, "0") != "0" for env_name in (
        "TF_ENABLE_CONTROL_FLOW_V2",
        "TF_ENABLE_COND_V2",
        "TF_ENABLE_WHILE_V2",
        "TF_ENABLE_TENSOR_ARRAY_V2",
    )))
| |
| |
# TODO(b/137793122): Remove this.
def enable_control_flow_v2():  # pylint: disable=invalid-name
  """Turns on control flow v2 by setting the module-level flag.

  After this call the module global `ENABLE_CONTROL_FLOW_V2` is `True`.

  This symbol is slated for removal; do not use it in new code.
  """
  global ENABLE_CONTROL_FLOW_V2
  ENABLE_CONTROL_FLOW_V2 = True
| |
| |
def EnableControlFlowV2(graph):
  """Returns whether control flow v2 should be used in `graph`.

  Args:
    graph: the graph being built. Must expose `building_function`.

  Returns:
    The module-level `ENABLE_CONTROL_FLOW_V2` flag when it is truthy.
    Otherwise a truthy value iff `graph` is building a function and has no
    `_captured` attribute.
  """
  global_flag = ENABLE_CONTROL_FLOW_V2
  if global_flag:
    return global_flag
  # New control flow is always used in FuncGraphs, but not in legacy
  # _FuncGraphs, which are recognised here by their `_captured` attribute.
  # TODO(skyewm): do something better than hasattr without messing up imports.
  return graph.building_function and not hasattr(graph, "_captured")
| |
| |
def IsInXLAContext(op):
  """Returns whether `op` is marked for XLA compilation or is in an XLAContext.

  Args:
    op: Operation to inspect.

  Returns:
    True if the op's `_XlaCompile` attr is truthy, or if the op's control flow
    context is, or is nested inside, an XLAContext. False otherwise.
  """
  try:
    if op.get_attr("_XlaCompile"):
      return True
  except ValueError:
    # `get_attr` raised ValueError (presumably the attr is not set on this
    # op); fall back to inspecting the control flow context.
    pass
  enclosing_xla = GetContainingXLAContext(
      op._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing_xla is not None
| |
| |
def InXlaContext(graph):
  """Returns whether `graph`'s control flow context is in an XLAContext.

  Args:
    graph: graph whose current control flow context is inspected.

  Returns:
    True if that context is, or is nested inside, an XLAContext.
  """
  graph_ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
  enclosing_xla = GetContainingXLAContext(graph_ctxt)
  return enclosing_xla is not None
| |
| |
def GraphOrParentsInXlaContext(graph):
  """Returns whether `graph` or any of its outer graphs is in an XLAContext.

  Walks the `outer_graph` chain starting at `graph`, checking each graph with
  `InXlaContext`.

  Args:
    graph: the innermost graph to start from.

  Returns:
    True as soon as one graph in the chain is in an XLAContext. False once the
    chain ends, i.e. a graph has no `outer_graph` attribute or its
    `outer_graph` is None.
  """
  while True:
    if InXlaContext(graph):
      return True
    # A graph without an `outer_graph` attribute is the top of the chain. A
    # None outer graph is treated the same way instead of being handed to
    # InXlaContext, which would fail on it.
    graph = getattr(graph, "outer_graph", None)
    if graph is None:
      return False
| |
| |
def IsInWhileLoop(op):
  """Returns whether `op`'s control flow context is in a WhileContext.

  Args:
    op: Operation to inspect.

  Returns:
    True if the op's context is, or is nested inside, a WhileContext.
  """
  enclosing_while = GetContainingWhileContext(
      op._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing_while is not None
| |
| |
def IsInCond(op):
  """Returns whether `op`'s control flow context is in a CondContext.

  Args:
    op: Operation to inspect.

  Returns:
    True if the op's context is, or is nested inside, a CondContext.
  """
  enclosing_cond = GetContainingCondContext(
      op._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing_cond is not None
| |
| |
def IsSwitch(op):
  """Returns true if `op.type` is "Switch" or "RefSwitch"."""
  return op.type in ("Switch", "RefSwitch")
| |
| |
def IsMerge(op):
  """Returns true if `op.type` is "Merge" or "RefMerge"."""
  return op.type in ("Merge", "RefMerge")
| |
| |
def IsLoopEnter(op):
  """Returns true if `op.type` is "Enter" or "RefEnter"."""
  return op.type in ("Enter", "RefEnter")
| |
| |
def IsLoopExit(op):
  """Returns true if `op.type` is "Exit" or "RefExit"."""
  return op.type in ("Exit", "RefExit")
| |
| |
def IsCondSwitch(op):
  """Returns true if `op` is the Switch for a conditional.

  Args:
    op: Operation to inspect.

  Returns:
    True iff `op` is a Switch/RefSwitch with at least one output and every
    consumer of every output is in a cond context. A consumer without any
    control flow context counts as "not in a cond", so the result is False.
  """
  if not IsSwitch(op) or not op.outputs:
    return False
  # Switch nodes are not part of the cond control flow context that they
  # represent, so consider the consumers of its outputs to determine if it is
  # cond switch or not. A switch is a cond switch iff all its consumers are in
  # cond contexts.
  for output in op.outputs:
    for consumer in output.consumers():
      ctxt = consumer._get_control_flow_context()  # pylint: disable=protected-access
      # For an Enter consumer the relevant context is the one outside the
      # context the Enter itself reports. Ops may have no context at all (e.g.
      # GetOutputContext notes this for imported graphs), so guard the
      # dereference.
      if ctxt is not None and IsLoopEnter(consumer):
        ctxt = ctxt.outer_context
      if ctxt is None or not ctxt.IsCondContext():
        # One non-cond consumer settles the answer; stop scanning.
        return False
  return True
| |
| |
def IsCondMerge(op):
  """Returns true if `op` is the Merge for a conditional.

  Args:
    op: Operation to inspect.

  Returns:
    True iff `op` is a Merge/RefMerge with at least one input and the output
    context of every input's producing op is a cond context.
  """
  if not IsMerge(op) or not op.inputs:
    return False
  # Merge nodes are not part of the cond control flow context that they
  # represent, so consider the inputs to the merge to determine if it is a
  # cond merge or not: a merge is a cond merge iff all its inputs are in
  # cond contexts.
  input_ctxts = (GetOutputContext(inp.op) for inp in op.inputs)
  # `all` stops at the first input that is not in a cond context.
  return all(
      ctxt is not None and ctxt.IsCondContext() for ctxt in input_ctxts)
| |
| |
def IsLoopSwitch(op):
  """Returns true if `op` is the Switch for a while loop.

  That is: `op` is a Switch whose own control flow context is a while context
  and which is not classified as a cond switch by `IsCondSwitch`.
  """
  if not IsSwitch(op):
    return False
  own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return (own_ctxt is not None and own_ctxt.IsWhileContext() and
          not IsCondSwitch(op))
| |
| |
def IsLoopMerge(op):
  """Returns true if `op` is the Merge for a while loop.

  That is: `op` is a Merge whose own control flow context is a while context
  and which is not classified as a cond merge by `IsCondMerge`.
  """
  if not IsMerge(op):
    return False
  own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return (own_ctxt is not None and own_ctxt.IsWhileContext() and
          not IsCondMerge(op))
| |
| |
def IsLoopConstantEnter(op):
  """Returns a truthy value iff `op` is an Enter marked as a loop invariant.

  For Enter ops the result is the value of the op's `is_constant` attr; for
  any other op it is the falsy result of `IsLoopEnter`.
  """
  is_enter = IsLoopEnter(op)
  if not is_enter:
    return is_enter
  return op.get_attr("is_constant")
| |
| |
def GetLoopConstantEnter(value):
  """Returns the Enter op if `value` can be inferred to be a loop invariant.

  Starting at the op producing `value`, follows the first input through any
  chain of Switch/RefSwitch/Identity/RefIdentity ops.

  Args:
    value: a tensor-like object exposing `.op`.

  Returns:
    The op reached at the end of that chain if it is a loop-constant Enter,
    otherwise None.
  """
  pass_through_types = frozenset(
      ("Switch", "RefSwitch", "Identity", "RefIdentity"))
  producer = value.op
  while producer.type in pass_through_types:
    producer = producer.inputs[0].op
  if IsLoopConstantEnter(producer):
    return producer
  return None
| |
| |
def GetOutputContext(op):
  """Returns the control flow context that the outputs of `op` live in.

  For most ops this is the op's own context. For an Exit op that has a context
  it is that context's `outer_context`.
  """
  own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  # Exit nodes normally carry a control flow context. The exception is a graph
  # brought in through import_graph_def, where no node has one; in that case
  # there is nothing to step out of.
  if own_ctxt is None or not IsLoopExit(op):
    return own_ctxt
  return own_ctxt.outer_context
| |
| |
def GetContainingWhileContext(ctxt, stop_ctxt=None):
  """Finds the innermost WhileContext that is or encloses `ctxt`.

  The search starts at `ctxt` itself and follows `outer_context` links.

  Args:
    ctxt: ControlFlowContext to start from (may be None).
    stop_ctxt: ControlFlowContext, optional. If given, the search also ends
      when it reaches a context equal to `stop_ctxt`.

  Returns:
    The first context on the chain that is a WhileContext or equals
    `stop_ctxt`, whichever comes first. None if the chain runs out, i.e.
    `ctxt` is not in a while loop.
  """
  current = ctxt
  while current:
    if current.IsWhileContext():
      return current
    if current == stop_ctxt:
      return current
    current = current.outer_context
  return None
| |
| |
def GetContainingXLAContext(ctxt):
  """Finds the innermost XLAContext that is or encloses `ctxt`.

  The search starts at `ctxt` itself and follows `outer_context` links.

  Args:
    ctxt: ControlFlowContext to start from (may be None).

  Returns:
    `ctxt` if `ctxt` is an XLAContext, otherwise the most nested XLAContext
    containing `ctxt`, or None if `ctxt` is not in an XLA context.
  """
  current = ctxt
  while current:
    if current.IsXLAContext():
      return current
    current = current.outer_context
  return None
| |
| |
def GetContainingCondContext(ctxt):
  """Finds the innermost CondContext that is or encloses `ctxt`.

  The search starts at `ctxt` itself and follows `outer_context` links.

  Args:
    ctxt: ControlFlowContext to start from (may be None).

  Returns:
    `ctxt` if `ctxt` is a CondContext, otherwise the most nested CondContext
    containing `ctxt`, or None if `ctxt` is not in a cond.
  """
  current = ctxt
  while current:
    if current.IsCondContext():
      return current
    current = current.outer_context
  return None
| |
| |
def IsContainingContext(ctxt, maybe_containing_ctxt):
  """Returns true if `maybe_containing_ctxt` is or contains `ctxt`.

  Follows `outer_context` links upward from `ctxt`, comparing by identity.
  Because the chain ends at None, passing None as `maybe_containing_ctxt`
  always yields True.
  """
  current = ctxt
  while True:
    if current is maybe_containing_ctxt:
      return True
    if current is None:
      # Ran off the top of the chain without meeting the candidate.
      return False
    current = current.outer_context
| |
| |
def OpInContext(op, ctxt):
  """Returns true if `op`'s control flow context is `ctxt` or nested in it."""
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return IsContainingContext(op_ctxt, ctxt)
| |
| |
def TensorInContext(tensor, ctxt):
  """Returns true if the op producing `tensor` is in `ctxt` (see OpInContext)."""
  producer = tensor.op
  return OpInContext(producer, ctxt)
| |
| |
def CheckInputFromValidContext(op, input_op):
  """Checks that `input_op` can be used as an input from `op`'s context.

  Conceptually, only inputs from op's while context or any ancestor while
  context (including outside of any context) are valid. In practice, there are
  many other edge cases as well.

  When the check fails, the error message plus both while contexts and the
  creation tracebacks of both ops are written to the info log before raising.

  Args:
    op: Operation that consumes the input.
    input_op: Operation that produces the input.

  Returns:
    None. Returning normally means the input is valid.

  Raises:
    ValueError: if input_op is from an invalid context.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  input_ctxt = GetOutputContext(input_op)
  valid = False

  if not input_ctxt:
    # input_op isn't in a control flow context.
    valid = True
  elif op_ctxt is input_ctxt:
    # input_op is in the same context as op.
    valid = True
  else:
    while_ctxt = GetContainingWhileContext(op_ctxt)
    input_while_ctxt = GetContainingWhileContext(input_ctxt)

    if while_ctxt is None:
      if input_while_ctxt is None:
        # Neither op nor input_op is in a while loop, but one or both are in
        # conds. We allow this, although execution will fail if the branch
        # corresponding to input_op's cond context isn't taken.
        valid = True
      # Invalid if op isn't in a while loop and input_op is. Unless...
      if IsLoopEnter(op):
        # WhileContext._BuildLoop clears context for Enter nodes.
        valid = True
      if IsSwitch(op):
        # CondContext.AddValue clears context for Switch nodes.
        valid = True
    elif IsContainingContext(while_ctxt, input_while_ctxt):
      # input_op is in a while loop which contains op's while loop (or not in a
      # while loop at all). Every branch below can therefore rely on
      # input_while_ctxt being a real context rather than None.
      valid = True
    elif (while_ctxt.grad_state and
          IsContainingContext(while_ctxt.grad_state.forward_context,
                              input_while_ctxt)):
      # op is in a gradient context and input_op is in the associated forward
      # pass context or an ancestor thereof. This case is needed to build while
      # loop gradients.
      # NOTE(skyewm): we theoretically also need this case for custom gradient
      # functions that close over tensors from ancestor contexts, but I haven't
      # verified this.
      valid = True
    elif (while_ctxt.grad_state and
          while_ctxt.grad_state.forward_context is
          input_while_ctxt._outer_context):  # pylint: disable=protected-access
      # op is in a gradient context and input_op is in a child of the associated
      # forward pass context. This case is needed for the gradients of while
      # loops with conds.
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context is while_ctxt):
      # input_op is in the gradient context of op's context. This case is needed
      # when the gradient of a while loop gradient is requested (this will
      # eventually fail unless there is a stop_gradient() or similar).
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context.grad_state and
          (input_while_ctxt.grad_state.forward_context.grad_state
           .forward_context is while_ctxt)):
      # input_op is in the grad grad context of op's context. This case is
      # needed when the gradient of a while loop gradient is requested (this
      # will eventually fail unless there is a stop_gradient() or similar).
      # The grad state is read from input_while_ctxt, the same object the
      # guard tests; input_ctxt may be a non-while context nested in the loop.
      valid = True

  if not valid:
    # Only the final `else` branch above can leave `valid` False, so
    # while_ctxt and input_while_ctxt are always bound here.
    if while_ctxt:
      error_msg = (
          f"Cannot use '{input_op.name}' as input to '{op.name}' because they "
          "are in different while loops.")
    else:
      error_msg = (
          f"Cannot use '{input_op.name}' as input to '{op.name}' because "
          f"'{input_op.name}' is in a while loop.")

    # Log the error message plus the relevant stack traces. The stacks may be
    # useful for debugging this error, but we don't want to raise an
    # unreadable exception.
    log_msg = error_msg
    log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt)
    log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt)
    log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % (
        op.name, "".join(traceback.format_list(op.traceback)),
        input_op.name, "".join(traceback.format_list(input_op.traceback)))
    logging.info(log_msg)
    raise ValueError(error_msg + " See info log for more details.")
| |
| |
def GetWhileContext(op):
  """Returns the WhileContext to which `op` belongs.

  Args:
    op: Operation to inspect.

  Returns:
    The result of `GetWhileContext()` on the op's control flow context, or the
    (falsy) context itself when the op has none.
  """
  own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if not own_ctxt:
    return own_ctxt
  return own_ctxt.GetWhileContext()