| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """While loop for Control Flow Operations.""" |
| |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import type_spec |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import control_flow_util as util |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import tensor_array_ops |
| from tensorflow.python.ops import while_v2 |
| from tensorflow.python.util import deprecation |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import variable_utils |
| from tensorflow.python.util.tf_export import tf_export |
| |
| |
# @TODO(b/133606651) Replace "shape_invariants" with "loop_vars_signature".
# pylint: disable=redefined-outer-name
@tf_export("while_loop", v1=[])
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))""",
    warn_once=True,
    back_prop=False)
def while_loop_v2(cond,
                  body,
                  loop_vars,
                  shape_invariants=None,
                  parallel_iterations=10,
                  back_prop=True,
                  swap_memory=False,
                  maximum_iterations=None,
                  name=None):
  """Repeat `body` while the condition `cond` is true.

  Note: This op is automatically used in a `tf.function` to convert Python for-
  and while- loops when the loop variable is a `tf.Tensor`, unless
  `autograph=False` is explicitly specified in `tf.function` args. For example,
  the following are equivalent:

  >>> @tf.function
  ... def sumSquare(n):
  ...   i, result = tf.constant(0), tf.constant(0)
  ...   while i < n: # AutoGraph converts while-loop to tf.while_loop().
  ...     result += i * i
  ...     i += 1
  ...   return result
  >>> sumSquare(10).numpy()
  285

  >>> @tf.function
  ... def sumSquare2(n):
  ...   i, result = tf.constant(0), tf.constant(0)
  ...   c = lambda i, _: tf.less(i, n)
  ...   b = lambda i, result: (i + 1, result + i * i)
  ...   return tf.while_loop(c, b, [i, result])[1]
  >>> sumSquare2(10).numpy()
  285

  For more information, see [tf.function and AutoGraph guide
  ](https://www.tensorflow.org/guide/function#autograph_transformations).

  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
  `cond` and `body`. `cond` and `body` both take as many arguments as there are
  `loop_vars`.

  In addition to regular Tensors or IndexedSlices, the body may accept and
  return TensorArray objects. The flows of the TensorArray objects will
  be appropriately forwarded between loops and during gradient calculations.

  When building a graph (for example inside a `tf.function`), `while_loop`
  calls `cond` and `body` *exactly once* (inside the call to `while_loop`, and
  not at all during `Session.run()`). `while_loop` stitches together the graph
  fragments created during the `cond` and `body` calls with some additional
  graph nodes to create the graph flow that repeats `body` until `cond`
  returns false. When executing eagerly, `cond` and `body` are instead called
  from a Python loop, once per iteration.

  For correctness, `tf.while_loop()` strictly enforces shape invariants for
  the loop variables. A shape invariant is a (possibly partial) shape that
  is unchanged across the iterations of the loop. An error will be raised
  if the shape of a loop variable after an iteration is determined to be more
  general than or incompatible with its shape invariant. For example, a shape
  of `[11, None]` is more general than a shape of `[11, 17]`, and `[11, 21]` is
  not compatible with `[11, 17]`. By default (if the argument `shape_invariants`
  is not specified), it is assumed that the initial shape of each tensor in
  `loop_vars` is the same in every iteration. The `shape_invariants` argument
  allows the caller to specify a less specific shape invariant for each loop
  variable, which is needed if the shape varies between iterations. The
  `tf.Tensor.set_shape` function may also be used in the `body` function to
  indicate that the output loop variable has a particular shape. The shape
  invariants for SparseTensor and IndexedSlices are treated specially as
  follows:

  a) If a loop variable is a SparseTensor, the shape invariant must be
  `TensorShape([r])` where `r` is the rank of the dense tensor represented
  by the sparse tensor. It means the shapes of the three tensors of the
  SparseTensor are `([None], [None, r], [r])`. NOTE: The shape invariant here
  is the shape of the SparseTensor.dense_shape property. It must be the shape of
  a vector.

  b) If a loop variable is an IndexedSlices, the shape invariant must be
  a shape invariant of the values tensor of the IndexedSlices. It means
  the shapes of the three tensors of the IndexedSlices are `(shape, [shape[0]],
  [shape.ndims])`.

  `while_loop` implements non-strict semantics, enabling multiple iterations
  to run in parallel. The maximum number of parallel iterations can be
  controlled by `parallel_iterations`, which gives users some control over
  memory consumption and execution order. For correct programs, `while_loop`
  should return the same result for any `parallel_iterations > 0`.

  For training, TensorFlow stores the tensors that are produced in the
  forward inference and are needed in back propagation. These tensors are a
  main source of memory consumption and often cause OOM errors when training
  on GPUs. When the flag swap_memory is true, we swap out these tensors from
  GPU to CPU. This for example allows us to train RNN models with very long
  sequences and large batches.

  Args:
    cond: A callable that represents the termination condition of the loop.
    body: A callable that represents the loop body.
    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
      `Tensor`, and `TensorArray` objects.
    shape_invariants: The shape invariants for the loop variables.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    maximum_iterations: Optional maximum number of iterations of the while loop
      to run. If provided, the `cond` output is AND-ed with an additional
      condition ensuring the number of iterations executed is no greater than
      `maximum_iterations`.
    name: Optional name prefix for the returned tensors.

  Returns:
    The output tensors for the loop variables after the loop. The return value
      has the same structure as `loop_vars`.

  Raises:
    TypeError: if `cond` or `body` is not callable, or if
      `parallel_iterations` is less than 1.
    ValueError: if `loop_vars` is empty.

  Example:

  >>> i = tf.constant(0)
  >>> c = lambda i: tf.less(i, 10)
  >>> b = lambda i: (tf.add(i, 1), )
  >>> r = tf.while_loop(c, b, [i])[0]
  >>> r.numpy()
  10

  Example with nesting and a namedtuple:

  >>> import collections
  >>> Pair = collections.namedtuple('Pair', 'j, k')
  >>> ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
  >>> c = lambda i, p: i < 10
  >>> b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
  >>> ijk_final = tf.while_loop(c, b, ijk_0)[1]
  >>> ijk_final[0].numpy(), ijk_final[1].numpy()
  (32, 64)

  Example using shape_invariants:

  >>> i0 = tf.constant(0)
  >>> m0 = tf.ones([2, 2])
  >>> c = lambda i, m: i < 10
  >>> b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
  >>> tf.while_loop(
  ...     c, b, loop_vars=[i0, m0],
  ...     shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])[1]
  <tf.Tensor: shape=(2048, 2), dtype=float32, numpy=...>

  Example which demonstrates non-strict semantics: In the following
  example, the final value of the counter (`counter_out`) does not depend on
  `x`. So the `while_loop` can increment the counter in parallel with updates
  of `x`. However, because the loop counter at one loop iteration depends
  on the value at the previous iteration, the loop counter itself cannot
  be incremented in parallel. Hence if we only fetched the final value of the
  counter (e.g. with `sess.run(counter_out)`), then `x` would never be
  incremented, but the counter would be updated on a single thread.
  Conversely, if we fetch the value of the output (as the example does on the
  line `print(sess.run(x_out).shape)`), then the counter may be incremented
  on its own thread, while `x` can be incremented in parallel on a separate
  thread. In the extreme case, it is conceivable that the thread incrementing
  the counter runs until completion before `x` is incremented even a single
  time. The only thing that can never happen is the thread updating `x`
  getting ahead of the counter thread, because the thread incrementing `x`
  depends on the value of the counter.

  >>> with tf.compat.v1.Session() as sess:
  ...   n = 10
  ...   c = lambda i, x: i < n
  ...   b = lambda i, x: (
  ...       tf.compat.v1.Print(i + 1, [i], "Updating i based on i == "),
  ...       # Let x depend on i
  ...       tf.compat.v1.Print(x + i, [i], "Updating x based on i == "))
  ...
  ...   # Make x a big matrix so that its updating thread runs slowly
  ...   x = tf.zeros([1000, 100], dtype=tf.int32)
  ...   counter = tf.constant(0)
  ...   counter_out, x_out = tf.while_loop(c, b, (counter, x))
  ...
  ...   # The following line may increment the counter and x in parallel.
  ...   # The counter thread may get ahead of the x thread, but not the
  ...   # other way around. For example, the log may contain these messages:
  ...   # ```
  ...   # Updating i based on i == [9]
  ...   # Updating x based on i == [3]
  ...   # ```
  ...   # meaning that the counter(i) thread is on iteration 9,
  ...   # while the x thread is on iteration 3.
  ...   print(sess.run(x_out).shape)
  (1000, 100)

  """
  # The TF2 API always returns a value with the same structure as
  # `loop_vars`, so forward to the v1 implementation with
  # `return_same_structure=True`.
  return while_loop(
      cond=cond,
      body=body,
      loop_vars=loop_vars,
      shape_invariants=shape_invariants,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name,
      maximum_iterations=maximum_iterations,
      return_same_structure=True)
| |
| |
# pylint: disable=redefined-outer-name
@tf_export(v1=["while_loop"])
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               name=None,
               maximum_iterations=None,
               return_same_structure=False):
  """Repeat `body` while the condition `cond` is true.

  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
  `cond` and `body`. `cond` and `body` both take as many arguments as there are
  `loop_vars`.

  In addition to regular Tensors or IndexedSlices, the body may accept and
  return TensorArray objects. The flows of the TensorArray objects will
  be appropriately forwarded between loops and during gradient calculations.

  When building a graph, `while_loop` calls `cond` and `body` *exactly once*
  (inside the call to `while_loop`, and not at all during `Session.run()`).
  `while_loop` stitches together the graph fragments created during the `cond`
  and `body` calls with some additional graph nodes to create the graph flow
  that repeats `body` until `cond` returns false. When executing eagerly,
  `cond` and `body` are instead called from a Python loop, once per iteration.

  For correctness, `tf.while_loop()` strictly enforces shape invariants for
  the loop variables. A shape invariant is a (possibly partial) shape that
  is unchanged across the iterations of the loop. An error will be raised
  if the shape of a loop variable after an iteration is determined to be more
  general than or incompatible with its shape invariant. For example, a shape
  of `[11, None]` is more general than a shape of `[11, 17]`, and `[11, 21]` is
  not compatible with `[11, 17]`. By default (if the argument
  `shape_invariants` is not specified), it is assumed that the initial shape
  of each tensor in `loop_vars` is the same in every iteration. The
  `shape_invariants` argument allows the caller to specify a less specific
  shape invariant for each loop variable, which is needed if the shape varies
  between iterations. The `tf.Tensor.set_shape` function may also be used in
  the `body` function to indicate that the output loop variable has a
  particular shape. The shape invariants for SparseTensor and IndexedSlices
  are treated specially as follows:

  a) If a loop variable is a SparseTensor, the shape invariant must be
  `TensorShape([r])` where `r` is the rank of the dense tensor represented
  by the sparse tensor. It means the shapes of the three tensors of the
  SparseTensor are `([None], [None, r], [r])`. NOTE: The shape invariant here
  is the shape of the SparseTensor.dense_shape property. It must be the shape of
  a vector.

  b) If a loop variable is an IndexedSlices, the shape invariant must be
  a shape invariant of the values tensor of the IndexedSlices. It means
  the shapes of the three tensors of the IndexedSlices are `(shape, [shape[0]],
  [shape.ndims])`.

  `while_loop` implements non-strict semantics, enabling multiple iterations
  to run in parallel. The maximum number of parallel iterations can be
  controlled by `parallel_iterations`, which gives users some control over
  memory consumption and execution order. For correct programs, `while_loop`
  should return the same result for any `parallel_iterations > 0`.

  For training, TensorFlow stores the tensors that are produced in the
  forward inference and are needed in back propagation. These tensors are a
  main source of memory consumption and often cause OOM errors when training
  on GPUs. When the flag swap_memory is true, we swap out these tensors from
  GPU to CPU. This for example allows us to train RNN models with very long
  sequences and large batches.

  Args:
    cond: A callable that represents the termination condition of the loop.
    body: A callable that represents the loop body.
    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
      `Tensor`, and `TensorArray` objects.
    shape_invariants: The shape invariants for the loop variables.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer.
    back_prop: Whether backprop is enabled for this while loop.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    name: Optional name prefix for the returned tensors.
    maximum_iterations: Optional maximum number of iterations of the while loop
      to run. If provided, the `cond` output is AND-ed with an additional
      condition ensuring the number of iterations executed is no greater than
      `maximum_iterations`.
    return_same_structure: If True, output has same structure as `loop_vars`.
      If False, a single loop variable may be returned unpacked (see Returns).
      This flag is also consulted when executing eagerly.

  Returns:
    The output tensors for the loop variables after the loop.
    If `return_same_structure` is True, the return value has the same
    structure as `loop_vars`.
    If `return_same_structure` is False, the return value is a Tensor,
    TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list
    otherwise.

  Raises:
    TypeError: if `cond` or `body` is not callable, or if
      `parallel_iterations` is less than 1.
    ValueError: if `loop_vars` is empty, or if `maximum_iterations` is not a
      scalar. (This function performs these two checks itself only when it
      does not dispatch to `while_v2.while_loop`.)

  Example:

  ```python
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: tf.add(i, 1)
  r = tf.while_loop(c, b, [i])
  ```

  Example with nesting and a namedtuple:

  ```python
  import collections
  Pair = collections.namedtuple('Pair', 'j, k')
  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
  c = lambda i, p: i < 10
  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
  ijk_final = tf.while_loop(c, b, ijk_0)
  ```

  Example using shape_invariants:

  ```python
  i0 = tf.constant(0)
  m0 = tf.ones([2, 2])
  c = lambda i, m: i < 10
  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
  tf.while_loop(
      c, b, loop_vars=[i0, m0],
      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
  ```

  Example which demonstrates non-strict semantics: In the following
  example, the final value of the counter `i` does not depend on `x`. So
  the `while_loop` can increment the counter in parallel with updates of `x`.
  However, because the loop counter at one loop iteration depends
  on the value at the previous iteration, the loop counter itself cannot
  be incremented in parallel. Hence if we just want the final value of the
  counter (which we print on the line `print(sess.run(i))`), then
  `x` will never be incremented, but the counter will be updated on a
  single thread. Conversely, if we want the value of the output (which we
  print on the line `print(sess.run(out).shape)`), then the counter may be
  incremented on its own thread, while `x` can be incremented in
  parallel on a separate thread. In the extreme case, it is conceivable
  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is the thread updating `x` getting ahead of the counter thread,
  because the thread incrementing `x` depends on the value of the counter.

  ```python
  import tensorflow as tf

  n = 10000
  x = tf.constant(list(range(n)))
  c = lambda i, x: i < n
  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
          [i], "x:"))
  i, out = tf.while_loop(c, b, (0, x))
  with tf.compat.v1.Session() as sess:
      print(sess.run(i))  # prints 10000; the Print op logs [0] ... [9999]

      # The following line may increment the counter and x in parallel.
      # The counter thread may get ahead of the other thread, but not the
      # other way around. So you may see things like
      # [9996] x:[9987]
      # meaning that the counter thread is on iteration 9996,
      # while the other thread is on iteration 9987
      print(sess.run(out).shape)
  ```
  """
  if not callable(cond):
    raise TypeError("'cond' must be callable.")
  if not callable(body):
    raise TypeError("'body' must be callable.")
  # NOTE(review): a non-positive value is reported as TypeError rather than
  # ValueError; callers may catch TypeError, so it is documented, not changed.
  if parallel_iterations < 1:
    raise TypeError("'parallel_iterations' must be a positive integer.")

  # Convert variables among the loop vars to tensors (see
  # `variable_utils.convert_variables_to_tensors` for exact semantics).
  loop_vars = variable_utils.convert_variables_to_tensors(loop_vars)

  # Always enable control flow v2 if building a function, regardless of toggle.
  executing_eagerly = context.executing_eagerly()
  if (util.EnableControlFlowV2(ops.get_default_graph()) and
      not executing_eagerly):
    # Hand off entirely to while_v2; nothing below runs on this path.
    return while_v2.while_loop(
        cond,
        body,
        loop_vars,
        shape_invariants=shape_invariants,
        parallel_iterations=parallel_iterations,
        maximum_iterations=maximum_iterations,
        name=name,
        return_same_structure=return_same_structure,
        back_prop=back_prop)

  with ops.name_scope(name, "while", loop_vars):
    if not loop_vars:
      raise ValueError("'loop_vars' must be provided.")
    # A single loop variable may be returned unpacked unless the caller asked
    # for a result with the same structure as `loop_vars`.
    try_to_pack = (len(loop_vars) == 1 and not return_same_structure)
    if maximum_iterations is not None:
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, name="maximum_iterations")
      if maximum_iterations.shape.ndims != 0:
        raise ValueError("'maximum_iterations' must be a scalar. "
                         f"Received shape: {maximum_iterations.shape}")

      # Enforce the bound by threading an iteration counter through the loop
      # as an extra leading loop variable. Eagerly, both the counter and the
      # bound are plain Python ints; when building a graph, the counter is a
      # constant with the same dtype as `maximum_iterations`.
      if executing_eagerly:
        counter = 0
        maximum_iterations = int(maximum_iterations.numpy())
      else:
        counter = constant_op.constant(
            0, dtype=maximum_iterations.dtype, name="iteration_counter")
      orig_cond = cond
      orig_body = body
      # Wrap the user's callables so they operate on a (counter, lv) pair:
      # `cond` additionally requires `counter < maximum_iterations` and `body`
      # increments the counter. A single packable loop var is passed to the
      # user's callables directly; otherwise `lv` is splatted into them.
      if try_to_pack:
        loop_vars = (counter, loop_vars[0])
        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
            math_ops.logical_and(i < maximum_iterations, orig_cond(lv)))
        body = lambda i, lv: (i + 1, orig_body(lv))
      else:
        loop_vars = (counter, loop_vars)
        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
            math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
        body = lambda i, lv: (i + 1, orig_body(*lv))
      # The outer loop vars are now always the (counter, lv) pair, so packing
      # no longer applies to them.
      try_to_pack = False

    if executing_eagerly:
      # Eager path: run the loop directly in Python.
      packed = False  # whether the body result was packed into a 1-item tuple

      # Remember the structure of the initial loop vars so that every
      # iteration's output can be checked against it.
      loop_var_structure = nest.map_structure(type_spec.type_spec_from_value,
                                              list(loop_vars))
      while cond(*loop_vars):
        loop_vars = body(*loop_vars)
        # A body for a single loop var may return a bare value; wrap it so it
        # can be splatted into `cond`/`body` on the next iteration.
        if try_to_pack and not isinstance(loop_vars, (list, tuple)):
          packed = True
          loop_vars = (loop_vars,)
        nest.assert_same_structure(loop_var_structure, list(loop_vars))

      def convert(x):
        # TensorArrays are returned as-is; everything else becomes a tensor.
        if isinstance(x, tensor_array_ops.TensorArray):
          return x
        return ops.convert_to_tensor(x)

      loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True)
      if maximum_iterations is not None:
        # Drop the iteration counter that was prepended above.
        return loop_vars[1]
      else:
        return loop_vars[0] if packed else loop_vars

    # Graph path when control flow v2 is not enabled.
    if shape_invariants is not None:
      if maximum_iterations is not None:
        # Prepend a scalar shape invariant for the iteration counter.
        shape_invariants = (tensor_shape.TensorShape([]), shape_invariants)

    loop_context = control_flow_ops.WhileContext(
        maximum_iterations=maximum_iterations,
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory)
    # Only add non-nested loops to the collection. Any nested control flow will
    # be encapsulated in the root context.
    if loop_context.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
                                    return_same_structure)
    if maximum_iterations is not None:
      # Drop the iteration counter that was prepended above.
      return result[1]
    else:
      return result