| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================= |
| """Functional operations.""" |
| |
| from tensorflow.core.framework import attr_value_pb2 |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import function |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import gen_functional_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import tensor_array_ops |
| from tensorflow.python.ops import variable_scope as vs |
| from tensorflow.python.ops import while_loop |
| # pylint: disable=unused-import |
| from tensorflow.python.ops.gen_functional_ops import remote_call |
| # pylint: enable=unused-import |
| from tensorflow.python.ops.gen_functional_ops import symbolic_gradient |
| from tensorflow.python.util import deprecation |
| from tensorflow.python.util import dispatch |
| from tensorflow.python.util import nest |
| from tensorflow.python.util.tf_export import tf_export |
| |
| |
| # TODO(yuanbyu, mrry): Handle stride to support sliding windows. |
| @tf_export(v1=["foldl"]) |
| @dispatch.add_dispatch_support |
| def foldl(fn, |
| elems, |
| initializer=None, |
| parallel_iterations=10, |
| back_prop=True, |
| swap_memory=False, |
| name=None): |
| """foldl on the list of tensors unpacked from `elems` on dimension 0. |
| |
| This foldl operator repeatedly applies the callable `fn` to a sequence |
| of elements from first to last. The elements are made of the tensors |
  unpacked from `elems` on dimension 0. The callable `fn` takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of `fn`, and the second is the value at the current
| position of `elems`. If `initializer` is None, `elems` must contain at least |
| one element, and its first element is used as the initializer. |
| |
| Suppose that `elems` is unpacked into `values`, a list of tensors. The shape |
  of the result tensor is `fn(initializer, values[0]).shape`.
| |
  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` in
  Python 2 is: `fn = lambda a, (t1, [t2, t3, [t4, t5]]):`. In Python 3, the
  element is passed as a single structured argument: `fn = lambda a, t:`.
| |
| Args: |
| fn: The callable to be performed. |
| elems: A tensor or (possibly nested) sequence of tensors, each of which will |
| be unpacked along their first dimension. The nested sequence of the |
| resulting slices will be the first argument to `fn`. |
| initializer: (optional) A tensor or (possibly nested) sequence of tensors, |
| as the initial value for the accumulator. |
| parallel_iterations: (optional) The number of iterations allowed to run in |
| parallel. |
| back_prop: (optional) True enables support for back propagation. |
| swap_memory: (optional) True enables GPU-CPU memory swapping. |
| name: (optional) Name prefix for the returned tensors. |
| |
| Returns: |
| A tensor or (possibly nested) sequence of tensors, resulting from applying |
| `fn` consecutively to the list of tensors unpacked from `elems`, from first |
| to last. |
| |
| Raises: |
| TypeError: if `fn` is not callable. |
| |
| Example: |
| ```python |
| elems = tf.constant([1, 2, 3, 4, 5, 6]) |
| sum = foldl(lambda a, x: a + x, elems) |
| # sum == 21 |
| ``` |
| """ |
| if not callable(fn): |
| raise TypeError( |
| f"{fn.__name__} is not callable. Please provide a callable function.") |
| |
| def create_ta(elem): |
| return tensor_array_ops.TensorArray( |
| dtype=elem.dtype, size=n, dynamic_size=False, |
| infer_shape=True).unstack(elem) |
| |
| in_graph_mode = not context.executing_eagerly() |
| with ops.name_scope(name, "foldl", [elems]): |
| # TODO(akshayka): Remove the in_graph_mode check once caching devices are |
| # supported in Eager |
| if in_graph_mode: |
| # Any get_variable calls in fn will cache the first call locally |
| # and not issue repeated network I/O requests for each iteration. |
| varscope = vs.get_variable_scope() |
| varscope_caching_device_was_none = False |
| if varscope.caching_device is None: |
| # TODO(ebrevdo): Change to using colocate_with here and in other |
| # methods. |
| varscope.set_caching_device(lambda op: op.device) |
| varscope_caching_device_was_none = True |
| |
| # Convert elems to tensor array. n may be known statically. |
| elems_flat = [ |
| ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems) |
| ] |
| n = ( |
| tensor_shape.dimension_value(elems_flat[0].shape[0]) or |
| array_ops.shape(elems_flat[0])[0]) |
| |
| elems_ta = nest.map_structure(create_ta, elems) |
| |
| if initializer is None: |
| a = nest.map_structure(lambda elem: elem.read(0), elems_ta) |
| i = constant_op.constant(1) |
| else: |
| a = initializer |
| i = constant_op.constant(0) |
| |
| def compute(i, a): |
| elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta) |
| a = fn(a, elem_i) |
| return [i + 1, a] |
| |
| _, r_a = while_loop.while_loop( |
| lambda i, a: i < n, |
| compute, [i, a], |
| parallel_iterations=parallel_iterations, |
| back_prop=back_prop, |
| swap_memory=swap_memory, |
| maximum_iterations=n) |
| |
| # TODO(akshayka): Remove the in_graph_mode check once caching devices are |
| # supported in Eager |
| if in_graph_mode and varscope_caching_device_was_none: |
| varscope.set_caching_device(None) |
| |
| return r_a |
| |
| |
| @tf_export("foldl", v1=[]) |
| @dispatch.add_dispatch_support |
| @deprecation.deprecated_arg_values( |
| None, |
| """back_prop=False is deprecated. Consider using tf.stop_gradient instead. |
| Instead of: |
| results = tf.foldl(fn, elems, back_prop=False) |
| Use: |
| results = tf.nest.map_structure(tf.stop_gradient, tf.foldl(fn, elems))""", |
| warn_once=True, |
| back_prop=False) |
| def foldl_v2(fn, |
| elems, |
| initializer=None, |
| parallel_iterations=10, |
| back_prop=True, |
| swap_memory=False, |
| name=None): |
| """foldl on the list of tensors unpacked from `elems` on dimension 0. |
| |
| This foldl operator repeatedly applies the callable `fn` to a sequence |
| of elements from first to last. The elements are made of the tensors |
  unpacked from `elems` on dimension 0. The callable `fn` takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of `fn`, and the second is the value at the current
| position of `elems`. If `initializer` is None, `elems` must contain at least |
| one element, and its first element is used as the initializer. |
| |
| Suppose that `elems` is unpacked into `values`, a list of tensors. The shape |
  of the result tensor is `fn(initializer, values[0]).shape`.
| |
  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` in
  Python 2 is: `fn = lambda a, (t1, [t2, t3, [t4, t5]]):`. In Python 3, the
  element is passed as a single structured argument: `fn = lambda a, t:`.
| |
| Args: |
| fn: The callable to be performed. |
| elems: A tensor or (possibly nested) sequence of tensors, each of which will |
| be unpacked along their first dimension. The nested sequence of the |
| resulting slices will be the first argument to `fn`. |
| initializer: (optional) A tensor or (possibly nested) sequence of tensors, |
| as the initial value for the accumulator. |
| parallel_iterations: (optional) The number of iterations allowed to run in |
| parallel. |
| back_prop: (optional) Deprecated. False disables support for back |
| propagation. Prefer using `tf.stop_gradient` instead. |
| swap_memory: (optional) True enables GPU-CPU memory swapping. |
| name: (optional) Name prefix for the returned tensors. |
| |
| Returns: |
| A tensor or (possibly nested) sequence of tensors, resulting from applying |
| `fn` consecutively to the list of tensors unpacked from `elems`, from first |
| to last. |
| |
| Raises: |
| TypeError: if `fn` is not callable. |
| |
| Example: |
| ```python |
| elems = tf.constant([1, 2, 3, 4, 5, 6]) |
| sum = tf.foldl(lambda a, x: a + x, elems) |
| # sum == 21 |
| ``` |
| """ |
| return foldl( |
| fn=fn, |
| elems=elems, |
| initializer=initializer, |
| parallel_iterations=parallel_iterations, |
| back_prop=back_prop, |
| swap_memory=swap_memory, |
| name=name) |
| |
| |
| @tf_export(v1=["foldr"]) |
| @dispatch.add_dispatch_support |
| def foldr(fn, |
| elems, |
| initializer=None, |
| parallel_iterations=10, |
| back_prop=True, |
| swap_memory=False, |
| name=None): |
| """foldr on the list of tensors unpacked from `elems` on dimension 0. |
| |
| This foldr operator repeatedly applies the callable `fn` to a sequence |
| of elements from last to first. The elements are made of the tensors |
  unpacked from `elems`. The callable `fn` takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of `fn`, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one
  element, and its last element is used as the initializer.
| |
  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[-1]).shape`.
| |
  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` in
  Python 2 is: `fn = lambda a, (t1, [t2, t3, [t4, t5]]):`. In Python 3, the
  element is passed as a single structured argument: `fn = lambda a, t:`.
| |
| Args: |
| fn: The callable to be performed. |
| elems: A tensor or (possibly nested) sequence of tensors, each of which will |
| be unpacked along their first dimension. The nested sequence of the |
| resulting slices will be the first argument to `fn`. |
| initializer: (optional) A tensor or (possibly nested) sequence of tensors, |
| as the initial value for the accumulator. |
| parallel_iterations: (optional) The number of iterations allowed to run in |
| parallel. |
| back_prop: (optional) True enables support for back propagation. |
| swap_memory: (optional) True enables GPU-CPU memory swapping. |
| name: (optional) Name prefix for the returned tensors. |
| |
| Returns: |
| A tensor or (possibly nested) sequence of tensors, resulting from applying |
| `fn` consecutively to the list of tensors unpacked from `elems`, from last |
| to first. |
| |
| Raises: |
| TypeError: if `fn` is not callable. |
| |
| Example: |
| ```python |
| elems = [1, 2, 3, 4, 5, 6] |
| sum = foldr(lambda a, x: a + x, elems) |
| # sum == 21 |
| ``` |
| """ |
| if not callable(fn): |
| raise TypeError( |
| f"{fn.__name__} is not callable. Please provide a callable function.") |
| |
| def create_ta(elem): |
| return tensor_array_ops.TensorArray( |
| dtype=elem.dtype, size=n, dynamic_size=False, |
| infer_shape=True).unstack(elem) |
| |
| in_graph_mode = not context.executing_eagerly() |
| with ops.name_scope(name, "foldr", [elems]): |
| # TODO(akshayka): Remove the in_graph_mode check once caching devices are |
| # supported in Eager |
| if in_graph_mode: |
| # Any get_variable calls in fn will cache the first call locally and not |
| # issue repeated network I/O requests for each iteration. |
| varscope = vs.get_variable_scope() |
| varscope_caching_device_was_none = False |
| if varscope.caching_device is None: |
| # TODO(ebrevdo): Change to using colocate_with here and in other |
| # methods. |
| varscope.set_caching_device(lambda op: op.device) |
| varscope_caching_device_was_none = True |
| |
| # Convert elems to tensor array. n may be known statically. |
| elems_flat = [ |
| ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems) |
| ] |
| n = ( |
| tensor_shape.dimension_value(elems_flat[0].shape[0]) or |
| array_ops.shape(elems_flat[0])[0]) |
| |
| elems_ta = nest.map_structure(create_ta, elems) |
| |
| if initializer is None: |
| i = n - 1 |
| a = nest.map_structure(lambda elem: elem.read(i), elems_ta) |
| else: |
| i = n |
| a = initializer |
| |
| def compute(i, a): |
| i -= 1 |
| elem = nest.map_structure(lambda elem: elem.read(i), elems_ta) |
| a_out = fn(a, elem) |
| return [i, a_out] |
| |
| _, r_a = while_loop.while_loop( |
| lambda i, a: i > 0, |
| compute, [i, a], |
| parallel_iterations=parallel_iterations, |
| back_prop=back_prop, |
| swap_memory=swap_memory, |
| maximum_iterations=n) |
| |
| # TODO(akshayka): Remove the in_graph_mode check once caching devices are |
| # supported in Eager |
| if in_graph_mode and varscope_caching_device_was_none: |
| varscope.set_caching_device(None) |
| |
| return r_a |
| |
| |
| @tf_export("foldr", v1=[]) |
| @dispatch.add_dispatch_support |
| @deprecation.deprecated_arg_values( |
| None, |
| """back_prop=False is deprecated. Consider using tf.stop_gradient instead. |
| Instead of: |
| results = tf.foldr(fn, elems, back_prop=False) |
| Use: |
| results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))""", |
| warn_once=True, |
| back_prop=False) |
| def foldr_v2(fn, |
| elems, |
| initializer=None, |
| parallel_iterations=10, |
| back_prop=True, |
| swap_memory=False, |
| name=None): |
| """foldr on the list of tensors unpacked from `elems` on dimension 0. |
| |
| This foldr operator repeatedly applies the callable `fn` to a sequence |
| of elements from last to first. The elements are made of the tensors |
  unpacked from `elems`. The callable `fn` takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of `fn`, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one
  element, and its last element is used as the initializer.
| |
  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[-1]).shape`.
| |
  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` in
  Python 2 is: `fn = lambda a, (t1, [t2, t3, [t4, t5]]):`. In Python 3, the
  element is passed as a single structured argument: `fn = lambda a, t:`.
| |
| Args: |
| fn: The callable to be performed. |
| elems: A tensor or (possibly nested) sequence of tensors, each of which will |
| be unpacked along their first dimension. The nested sequence of the |
| resulting slices will be the first argument to `fn`. |
| initializer: (optional) A tensor or (possibly nested) sequence of tensors, |
| as the initial value for the accumulator. |
| parallel_iterations: (optional) The number of iterations allowed to run in |
| parallel. |
| back_prop: (optional) Deprecated. False disables support for back |
| propagation. Prefer using `tf.stop_gradient` instead. |
| swap_memory: (optional) True enables GPU-CPU memory swapping. |
| name: (optional) Name prefix for the returned tensors. |
| |
| Returns: |
| A tensor or (possibly nested) sequence of tensors, resulting from applying |
| `fn` consecutively to the list of tensors unpacked from `elems`, from last |
| to first. |
| |
| Raises: |
| TypeError: if `fn` is not callable. |
| |
| Example: |
| ```python |
| elems = [1, 2, 3, 4, 5, 6] |
| sum = tf.foldr(lambda a, x: a + x, elems) |
| # sum == 21 |
| ``` |
| """ |
| return foldr( |
| fn=fn, |
| elems=elems, |
| initializer=initializer, |
| parallel_iterations=parallel_iterations, |
| back_prop=back_prop, |
| swap_memory=swap_memory, |
| name=name) |
| |
| |
| @tf_export(v1=["scan"]) |
| @dispatch.add_dispatch_support |
| def scan(fn, |
| elems, |
| initializer=None, |
| parallel_iterations=10, |
| back_prop=True, |
| swap_memory=False, |
| infer_shape=True, |
| reverse=False, |
| name=None): |
| """scan on the list of tensors unpacked from `elems` on dimension 0. |
| |
| See also `tf.map_fn`. |
| |
| The simplest version of `scan` repeatedly applies the callable `fn` to a |
| sequence of elements from first to last. The elements are made of the tensors |
  unpacked from `elems` on dimension 0. The callable `fn` takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of `fn`, and the second is the value at the current
| position of `elems`. If `initializer` is None, `elems` must contain at least |
| one element, and its first element is used as the initializer. |
| |
| Suppose that `elems` is unpacked into `values`, a list of tensors. The shape |
| of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`. |
  If `reverse=True`, it's `fn(initializer, values[-1]).shape`.
| |
| This method also allows multi-arity `elems` and accumulator. If `elems` |
| is a (possibly nested) list or tuple of tensors, then each of these tensors |
| must have a matching first (unpack) dimension. The second argument of |
| `fn` must match the structure of `elems`. |
| |
| If no `initializer` is provided, the output structure and dtypes of `fn` |
| are assumed to be the same as its input; and in this case, the first |
| argument of `fn` must match the structure of `elems`. |
| |
| If an `initializer` is provided, then the output of `fn` must have the same |
| structure as `initializer`; and the first argument of `fn` must match |
| this structure. |
| |
| For example, if `elems` is `(t1, [t2, t3])` and `initializer` is |
| `[i1, i2]` then an appropriate signature for `fn` in `python2` is: |
| `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list, |
| `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the |
| one that works in `python3`, is: |
| `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples. |
| |
| Args: |
| fn: The callable to be performed. It accepts two arguments. The first will |
| have the same structure as `initializer` if one is provided, otherwise it |
| will have the same structure as `elems`. The second will have the same |
| (possibly nested) structure as `elems`. Its output must have the same |
| structure as `initializer` if one is provided, otherwise it must have the |
| same structure as `elems`. |
| elems: A tensor or (possibly nested) sequence of tensors, each of which will |
| be unpacked along their first dimension. The nested sequence of the |
| resulting slices will be the first argument to `fn`. |
| initializer: (optional) A tensor or (possibly nested) sequence of tensors, |
| initial value for the accumulator, and the expected output type of `fn`. |
| parallel_iterations: (optional) The number of iterations allowed to run in |
| parallel. |
| back_prop: (optional) True enables support for back propagation. |
| swap_memory: (optional) True enables GPU-CPU memory swapping. |
| infer_shape: (optional) False disables tests for consistent output shapes. |
| reverse: (optional) True scans the tensor last to first (instead of first to |
| last). |
| name: (optional) Name prefix for the returned tensors. |
| |
| Returns: |
| A tensor or (possibly nested) sequence of tensors. Each tensor packs the |
| results of applying `fn` to tensors unpacked from `elems` along the first |
| dimension, and the previous accumulator value(s), from first to last (or |
| last to first, if `reverse=True`). |
| |
| Raises: |
| TypeError: if `fn` is not callable or the structure of the output of |
| `fn` and `initializer` do not match. |
| ValueError: if the lengths of the output of `fn` and `initializer` |
| do not match. |
| |
| Examples: |
| ```python |
| elems = np.array([1, 2, 3, 4, 5, 6]) |
| sum = scan(lambda a, x: a + x, elems) |
| # sum == [1, 3, 6, 10, 15, 21] |
| sum = scan(lambda a, x: a + x, elems, reverse=True) |
| # sum == [21, 20, 18, 15, 11, 6] |
| ``` |
| |
| ```python |
| elems = np.array([1, 2, 3, 4, 5, 6]) |
| initializer = np.array(0) |
| sum_one = scan( |
| lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer) |
| # sum_one == [1, 2, 3, 4, 5, 6] |
| ``` |
| |
| ```python |
| elems = np.array([1, 0, 0, 0, 0, 0]) |
| initializer = (np.array(0), np.array(1)) |
| fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer) |
| # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13]) |
| ``` |
| """ |
| if not callable(fn): |
| raise TypeError( |
| f"{fn.__name__} is not callable. Please provide a callable function.") |
| |
| input_is_sequence = nest.is_nested(elems) |
| input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x] |
| |
| def input_pack(x): |
| return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0] |
| |
| if initializer is None: |
| output_is_sequence = input_is_sequence |
| output_flatten = input_flatten |
| output_pack = input_pack |
| else: |
| output_is_sequence = nest.is_nested(initializer) |
| output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x] |
| |
| def output_pack(x): |
| return (nest.pack_sequence_as(initializer, x) |
| if output_is_sequence else x[0]) |
| |
| elems_flat = input_flatten(elems) |
| |
| in_graph_mode = not context.executing_eagerly() |
| with ops.name_scope(name, "scan", elems_flat): |
| # TODO(akshayka): Remove the in_graph_mode check once caching devices are |
| # supported in Eager |
| if in_graph_mode: |
| # Any get_variable calls in fn will cache the first call locally |
| # and not issue repeated network I/O requests for each iteration. |
| varscope = vs.get_variable_scope() |
| varscope_caching_device_was_none = False |
| if varscope.caching_device is None: |
| # TODO(ebrevdo): Change to using colocate_with here and in other |
| # methods. |
| varscope.set_caching_device(lambda op: op.device) |
| varscope_caching_device_was_none = True |
| |
    # Convert elems to tensors.
| elems_flat = [ |
| ops.convert_to_tensor(elem, name="elem") for elem in elems_flat |
| ] |
| |
| # Convert elems to tensor array. n may be known statically. |
| n = tensor_shape.dimension_value(elems_flat[0].shape[0]) |
| if n is None: |
| n = array_ops.shape(elems_flat[0])[0] |
| |
| # TensorArrays are always flat |
| elems_ta = [ |
| tensor_array_ops.TensorArray( |
| dtype=elem.dtype, |
| size=n, |
| dynamic_size=False, |
| element_shape=elem.shape[1:], |
| infer_shape=True) for elem in elems_flat |
| ] |
| # Unpack elements |
| elems_ta = [ |
| elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat) |
| ] |
| |
| if initializer is None: |
| a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta] |
| i = 1 |
| else: |
| initializer_flat = output_flatten(initializer) |
| a_flat = [ops.convert_to_tensor(init) for init in initializer_flat] |
| i = 0 |
| |
| # Create a tensor array to store the intermediate values. |
| accs_ta = [ |
| tensor_array_ops.TensorArray( |
| dtype=init.dtype, |
| size=n, |
| element_shape=init.shape if infer_shape else None, |
| dynamic_size=False, |
| infer_shape=infer_shape) for init in a_flat |
| ] |
| |
| if initializer is None: |
| accs_ta = [ |
| acc_ta.write(n - 1 if reverse else 0, a) |
| for (acc_ta, a) in zip(accs_ta, a_flat) |
| ] |
| |
| def compute(i, a_flat, tas): |
| """The loop body of scan. |
| |
| Args: |
| i: the loop counter. |
| a_flat: the accumulator value(s), flattened. |
| tas: the output accumulator TensorArray(s), flattened. |
| |
| Returns: |
        (next_i, flat_a_out, tas): the updated (incremented, or decremented if
          `reverse=True`) counter, the new accumulator values, and the updated
          TensorArrays.
| |
| Raises: |
| TypeError: if initializer and fn() output structure do not match |
        ValueError: if initializer and fn() output lengths do not match
| """ |
| packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta]) |
| packed_a = output_pack(a_flat) |
| a_out = fn(packed_a, packed_elems) |
| nest.assert_same_structure(elems if initializer is None else initializer, |
| a_out) |
| flat_a_out = output_flatten(a_out) |
| tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)] |
| if reverse: |
| next_i = i - 1 |
| else: |
| next_i = i + 1 |
| return (next_i, flat_a_out, tas) |
| |
| if reverse: |
| initial_i = n - 1 - i |
| condition = lambda i, _1, _2: i >= 0 |
| else: |
| initial_i = i |
| condition = lambda i, _1, _2: i < n |
| _, _, r_a = while_loop.while_loop( |
| condition, |
| compute, (initial_i, a_flat, accs_ta), |
| parallel_iterations=parallel_iterations, |
| back_prop=back_prop, |
| swap_memory=swap_memory, |
| maximum_iterations=n) |
| |
| results_flat = [r.stack() for r in r_a] |
| |
| n_static = tensor_shape.Dimension( |
| tensor_shape.dimension_value( |
| elems_flat[0].get_shape().with_rank_at_least(1)[0])) |
| for elem in elems_flat[1:]: |
| n_static.assert_is_compatible_with( |
| tensor_shape.Dimension( |
| tensor_shape.dimension_value( |
| elem.get_shape().with_rank_at_least(1)[0]))) |
| for r in results_flat: |
| r.set_shape( |
| tensor_shape.TensorShape(n_static).concatenate(r.get_shape()[1:])) |
| |
| # TODO(akshayka): Remove the in_graph_mode check once caching devices are |
| # supported in Eager |
| if in_graph_mode and varscope_caching_device_was_none: |
| varscope.set_caching_device(None) |
| |
| return output_pack(results_flat) |
| |
| |
| @tf_export("scan", v1=[]) |
| @dispatch.add_dispatch_support |
| @deprecation.deprecated_arg_values( |
| None, |
| """back_prop=False is deprecated. Consider using tf.stop_gradient instead. |
| Instead of: |
| results = tf.scan(fn, elems, back_prop=False) |
| Use: |
| results = tf.nest.map_structure(tf.stop_gradient, tf.scan(fn, elems))""", |
| warn_once=True, |
| back_prop=False) |
| def scan_v2(fn, |
| elems, |
| initializer=None, |
| parallel_iterations=10, |
| back_prop=True, |
| swap_memory=False, |
| infer_shape=True, |
| reverse=False, |
| name=None): |
| """scan on the list of tensors unpacked from `elems` on dimension 0. |
| |
| The simplest version of `scan` repeatedly applies the callable `fn` to a |
| sequence of elements from first to last. The elements are made of the tensors |
  unpacked from `elems` on dimension 0. The callable `fn` takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of `fn`, and the second is the value at the current
| position of `elems`. If `initializer` is None, `elems` must contain at least |
| one element, and its first element is used as the initializer. |
| |
| Suppose that `elems` is unpacked into `values`, a list of tensors. The shape |
| of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`. |
  If `reverse=True`, it's `fn(initializer, values[-1]).shape`.
| |
| This method also allows multi-arity `elems` and accumulator. If `elems` |
| is a (possibly nested) list or tuple of tensors, then each of these tensors |
| must have a matching first (unpack) dimension. The second argument of |
| `fn` must match the structure of `elems`. |
| |
| If no `initializer` is provided, the output structure and dtypes of `fn` |
| are assumed to be the same as its input; and in this case, the first |
| argument of `fn` must match the structure of `elems`. |
| |
| If an `initializer` is provided, then the output of `fn` must have the same |
| structure as `initializer`; and the first argument of `fn` must match |
| this structure. |
| |
| For example, if `elems` is `(t1, [t2, t3])` and `initializer` is |
| `[i1, i2]` then an appropriate signature for `fn` in `python2` is: |
| `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list, |
| `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the |
| one that works in `python3`, is: |
| `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples. |
| |
| Args: |
| fn: The callable to be performed. It accepts two arguments. The first will |
| have the same structure as `initializer` if one is provided, otherwise it |
| will have the same structure as `elems`. The second will have the same |
| (possibly nested) structure as `elems`. Its output must have the same |
| structure as `initializer` if one is provided, otherwise it must have the |
| same structure as `elems`. |
| elems: A tensor or (possibly nested) sequence of tensors, each of which will |
| be unpacked along their first dimension. The nested sequence of the |
| resulting slices will be the first argument to `fn`. |
| initializer: (optional) A tensor or (possibly nested) sequence of tensors, |
| initial value for the accumulator, and the expected output type of `fn`. |
| parallel_iterations: (optional) The number of iterations allowed to run in |
| parallel. |
| back_prop: (optional) Deprecated. False disables support for back |
| propagation. Prefer using `tf.stop_gradient` instead. |
| swap_memory: (optional) True enables GPU-CPU memory swapping. |
| infer_shape: (optional) False disables tests for consistent output shapes. |
| reverse: (optional) True scans the tensor last to first (instead of first to |
| last). |
| name: (optional) Name prefix for the returned tensors. |
| |
| Returns: |
| A tensor or (possibly nested) sequence of tensors. Each tensor packs the |
| results of applying `fn` to tensors unpacked from `elems` along the first |
| dimension, and the previous accumulator value(s), from first to last (or |
| last to first, if `reverse=True`). |
| |
| Raises: |
| TypeError: if `fn` is not callable or the structure of the output of |
| `fn` and `initializer` do not match. |
| ValueError: if the lengths of the output of `fn` and `initializer` |
| do not match. |
| |
| Examples: |
| ```python |
| elems = np.array([1, 2, 3, 4, 5, 6]) |
| sum = scan(lambda a, x: a + x, elems) |
| # sum == [1, 3, 6, 10, 15, 21] |
| sum = scan(lambda a, x: a + x, elems, reverse=True) |
| # sum == [21, 20, 18, 15, 11, 6] |
| ``` |
| |
| ```python |
| elems = np.array([1, 2, 3, 4, 5, 6]) |
| initializer = np.array(0) |
| sum_one = scan( |
| lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer) |
| # sum_one == [1, 2, 3, 4, 5, 6] |
| ``` |
| |
| ```python |
| elems = np.array([1, 0, 0, 0, 0, 0]) |
| initializer = (np.array(0), np.array(1)) |
| fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer) |
| # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13]) |
| ``` |
| """ |
| return scan( |
| fn=fn, |
| elems=elems, |
| initializer=initializer, |
| parallel_iterations=parallel_iterations, |
| back_prop=back_prop, |
| swap_memory=swap_memory, |
| infer_shape=infer_shape, |
| reverse=reverse, |
| name=name) |
| |
| |
| # pylint: disable=invalid-name |
| def If(cond, inputs, then_branch, else_branch, name=None): |
| r"""output = Cond(inputs) ? |
| |
| then_branch(inputs) : else_branch(inputs). |
| |
| Args: |
| cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is |
| converted to a boolean according to the following rule: if the scalar is a |
| numerical value, non-zero means True and zero means False; if the scalar |
| is a string, non-empty means True and empty means False. |
| inputs: A list of input tensors. |
    then_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what then_branch returns.
| name: A name for the operation (optional). |
| |
| Returns: |
| A list of tensors returned by either then_branch(inputs) |
| or else_branch(inputs). |
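
  Example (a minimal sketch, assuming `then_branch` and `else_branch` are
  ConcreteFunctions traced from `tf.function`s with matching signatures; the
  names here are illustrative, not part of the original docstring):
    ```python
    @tf.function
    def TimesTwo(x):
      return [x * 2.]

    @tf.function
    def PlusOne(x):
      return [x + 1.]

    spec = tf.TensorSpec([], tf.float32)
    out = If(tf.constant(True), [tf.constant(3.)],
             TimesTwo.get_concrete_function(spec),
             PlusOne.get_concrete_function(spec))
    # out == [6.0]
    ```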
| """ |
| # pylint: disable=protected-access |
| # Handle the Defun case until users have transitioned to tf.function. Note |
| # that composites may need to be re-packed by the caller. |
| if isinstance(then_branch, function._DefinedFunction): |
| tlist = [_.type for _ in then_branch.definition.signature.output_arg] |
| return gen_functional_ops._if( |
| cond, inputs, tlist, then_branch, else_branch, name=name) |
| |
| # We assume that `then_branch` is a ConcreteFunction here. |
| then_out = then_branch.structured_outputs |
| else_out = else_branch.structured_outputs |
| |
| # Ensure then/else are the same type of composites to avoid an invalid call |
| # to pack_sequence_as later on. |
| nest.assert_same_structure(then_out, else_out, expand_composites=True) |
| |
| tlist = nest.flatten(then_branch.output_dtypes) |
| ret = gen_functional_ops._if( |
| cond, inputs, tlist, then_branch, else_branch, name=name) |
| |
| # Re-pack the outputs to restore any CompositeTensors |
| return nest.pack_sequence_as(then_out, ret, expand_composites=True) |
| |
| |
| def Gradient(inputs, f, name=None): |
| r"""Computes the gradient function for function f via backpropagation. |
| |
| Args: |
| inputs: A list of tensors of size N + M. |
    f: The function we want to compute the gradient for. The function 'f' must
      be a numerical function which takes N inputs and produces M outputs. Its
      gradient function 'g' is a function which takes N + M inputs and
      produces N outputs. I.e. if we have (y1, y2, ..., yM) = f(x1, x2, ...,
      xN), then (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN, dL/dy1,
      dL/dy2, ..., dL/dyM), where L is a scalar-valued function of (x1, x2,
      ..., xN) (e.g., the loss function) and dL/dxi is the partial derivative
      of L with respect to xi.
| name: A name for the operation (optional). |
| |
| Returns: |
| A list of tensors of size N. |
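
  Example (a minimal graph-mode sketch using `function.Defun`; the names are
  illustrative, not part of the original docstring):
    ```python
    @function.Defun(dtypes.float32)
    def Square(x):
      return x * x

    x = constant_op.constant(3.0)
    dy = constant_op.constant(1.0)  # upstream gradient dL/dy
    (dx,) = Gradient([x, dy], Square)
    # dx == 6.0, i.e. d(x*x)/dx at x == 3, scaled by dy
    ```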
| """ |
| # TODO(zhifengc): Pretty-print the above spec in latex. |
  # TODO(zhifengc): Needs some math expert to say the comment above better.
| tlist = [_.type for _ in f.definition.signature.input_arg] |
| return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name) |
| |
| |
| def _GetInputDtypes(func): |
| """Returns the input dtypes of func, excluding dtypes for captured inputs.""" |
| if isinstance(func, function._DefinedFunction): # pylint: disable=protected-access |
| return func.declared_input_types |
| |
| # We assume that `func` is a ConcreteFunction here, but we are not able to |
| # verify since importing eager function library will cause cyclic dependence. |
| # |
| # ConcreteFunction.inputs includes captured inputs. |
| num_non_captured_inputs = len(func.inputs) - len(func.captured_inputs) |
| inputs_without_captured = func.inputs[:num_non_captured_inputs] |
| return [t.dtype for t in inputs_without_captured] |
| |
| |
| def _LoopBodyCaptureWrapper(func): |
| """Returns a wrapper for `func` that handles loop-carried captured inputs.""" |
| |
| @function.Defun(*_GetInputDtypes(func), func_name="%s_Wrapper" % func.name) |
| def Wrapper(*args): |
| """A wrapper that handles loop-carried captured inputs.""" |
| result = func(*args) |
| extra_args = tuple(function.get_extra_args()) |
| # Nullary functions return an Operation. Normal functions can't do this |
| # because their return values are converted to Tensors. |
| if isinstance(result, ops.Operation): |
| return extra_args |
| # Unary functions return a single Tensor value. |
| elif not isinstance(result, (list, tuple)): |
| return (result,) + extra_args |
| # N-ary functions return a tuple of Tensors. |
| else: |
| return result + type(result)(extra_args) |
| |
| return Wrapper |
| |
| |
| # pylint: disable=invalid-name,protected-access |
| def While(input_, cond, body, name=None, hostmem=None): |
| r"""output = input; While (Cond(output)) { output = Body(output) }. |
| |
| Args: |
| input_: A list of `Tensor` objects. A list of input tensors whose types are |
| T. |
    cond: A function that takes 'input_' and returns a tensor. If the tensor
      is a non-boolean scalar, it is converted to a boolean according to the
      following rule: if the scalar is a numerical value, non-zero means True
      and zero means False; if the scalar is a string, non-empty means True
      and empty means False. If the tensor is not a scalar, non-emptiness
      means True and emptiness means False.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as specified by T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, input_[i] is a host
      memory tensor.
| |
| Raises: |
| ValueError: if `cond` has implicitly captured inputs or if `cond` and `body` |
| have different signatures. |
| |
| Returns: |
    A list of `Tensor` objects with the same types as `input_` (T).
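
  Example (a minimal graph-mode sketch using `function.Defun`; illustrative
  only, not part of the original docstring):
    ```python
    @function.Defun(dtypes.int32)
    def Cond(x):
      return x < 10

    @function.Defun(dtypes.int32)
    def Body(x):
      return x + 1

    (result,) = While([constant_op.constant(0)], Cond, Body)
    # result == 10
    ```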
| """ |
| if cond.captured_inputs: |
| raise ValueError( |
| "The 'cond' argument can not have implicitly captured inputs. Received " |
| f"captured_inputs: {cond.captured_inputs}") |
| |
| cond_input_types = _GetInputDtypes(cond) |
| body_input_types = _GetInputDtypes(body) |
| |
| if cond_input_types != body_input_types: |
| raise ValueError( |
| "The 'cond' and 'body' signatures do not match. Received: " |
| f"cond_input_types={cond_input_types}, body_input_types=" |
| f"{body_input_types}") |
| |
| if body.captured_inputs: |
| cond_dtypes = list(body_input_types) + [ |
| t.dtype for t in body.captured_inputs |
| ] |
| |
| @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name) |
| def CondWrapper(*args): |
| """A wrapper that handles loop-carried captured inputs.""" |
| return cond(*args[:len(body_input_types)]) |
| |
| ret = gen_functional_ops._while( |
| input_ + body.captured_inputs, |
| CondWrapper, |
| _LoopBodyCaptureWrapper(body), |
| name=name) |
| # Slice off the loop-carried captured inputs. |
| ret = ret[:-len(body.captured_inputs)] |
| else: |
| ret = gen_functional_ops._while(input_, cond, body, name=name) |
| if hostmem: |
| input_attr = attr_value_pb2.AttrValue() |
| input_attr.list.i.extend(hostmem) |
| ret[0].op._set_attr("_input_hostmem", input_attr) # pylint: disable=protected-access |
| |
| output_attr = attr_value_pb2.AttrValue() |
| output_attr.list.i.extend(hostmem) |
| ret[0].op._set_attr("_output_hostmem", output_attr) # pylint: disable=protected-access |
| return ret |
| |
| |
| # b/36459430 |
| # |
# Ideally, we would not need to rewrite a For loop into a While loop.
# However, today, if a While runs on GPU and the condition returns a
# boolean, the While kernel crashes. Even if we fix the crash, the
# bool needs to be copied between GPU and CPU. So, a For loop is much
# preferred when running on GPU.
#
# On the other hand, the For op has no XLA kernel of its own. So, when we
# run a For loop under XLA, we need to rewrite it using a While op.
#
# It should be possible and probably better to write an XLA C++ kernel
# implementing the logic in _ForUsingWhile.
| def _ForUsingWhile(start, |
| limit, |
| delta, |
| inputs, |
| forbody, |
| name=None, |
| hostmem=None): |
| """Helper to implement a For loop using a While.""" |
| # To support negative delta (e.g., range(100, 0, -3)), we iterate |
| # over the range(n) and use iter * delta + start as the real |
| # iteration index. (e.g., for i in range(34): iter = i * (-3) + |
| # 100). |
| d = math_ops.abs(delta) |
| # XLA on TPUs doesn't support integer division |
| n = math_ops.cast( |
| math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) / |
| math_ops.cast(d, dtypes.float32), dtypes.int32) |
| |
| # Carried loop variables ("extra_args") are implicitly added to the input list |
| # of the WhileBody function. WhileCond does not call forbody, and so does not |
| # depend on any of forbody's extra_args. Since WhileCond and WhileBody |
| # must have identical inputs, we have to augment the cond signature to take |
| # the same types as the carried loop variables. |
| body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:] |
| |
| cond_name = "%s_Cond" % forbody.name |
| |
| @function.Defun(*body_sig, func_name=cond_name) |
| def WhileCond(i, n, *args): |
| del args |
| return i < n |
| |
| body_name = "%s_Body" % forbody.name |
| |
| @function.Defun(*body_sig, func_name=body_name) |
| def WhileBody(i, n, start, delta, *args): |
| """A While wrapper for forbody that handles loop-carried captured inputs.""" |
| for_result = forbody(start + i * delta, *args) |
| # Nullary functions return an Operation. Normal functions can't do this |
| # because their return values are converted to Tensors. |
| if isinstance(for_result, ops.Operation): |
| for_result = () |
| # Unary functions return a single Tensor value. |
| elif isinstance(for_result, ops.Tensor): |
| for_result = (for_result,) |
| return (i + 1, n, start, delta) + tuple(for_result) |
| |
| if hostmem is not None: |
| hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem] |
| else: |
| hostmem = [0, 1, 2, 3] |
| |
| results = While( |
| input_=[0, n, start, delta] + inputs, |
| cond=WhileCond, |
| body=WhileBody, |
| name=name, |
| hostmem=hostmem) |
| # Slice off the loop-carried captured inputs. |
| return list(results[4:len(results)]) |
| |
| |
| def For(start, |
| limit, |
| delta, |
| inputs, |
| body, |
| name=None, |
| hostmem=None, |
| rewrite_with_while=None): |
| r"""out = input; for i in range(start, limit, delta) out = body(i, out). |
| |
| Args: |
| start: A `Tensor` of type `int32`. |
| limit: A `Tensor` of type `int32`. |
| delta: A `Tensor` of type `int32`. |
| inputs: A list of `Tensor` objects. A list of input tensors whose types are |
| T. |
| body: A function takes a list of tensors and returns another list of |
| tensors. Both lists have the same types as (int32, T...). |
| name: A name for the operation (optional). |
    hostmem: A list of integers. If i is in the list, inputs[i] is a host
      memory tensor. In other words, the (i+1)-th argument of the body
      function expects a host memory tensor.
    rewrite_with_while: If True, use a While op to implement the For loop.
| |
| Returns: |
    A list of `Tensor` objects with the same types as `inputs` (T).
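
  Example (a minimal graph-mode sketch using `function.Defun`; illustrative
  only, not part of the original docstring):
    ```python
    @function.Defun(dtypes.int32, dtypes.int32)
    def Body(i, acc):
      # The loop index i is always the first argument.
      return acc + i

    (total,) = For(start=0, limit=5, delta=1,
                   inputs=[constant_op.constant(0)], body=Body)
    # total == 0 + 1 + 2 + 3 + 4 == 10
    ```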
| """ |
| if rewrite_with_while: |
| return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem) |
| if body.captured_inputs: |
| ret = gen_functional_ops._for( |
| start, |
| limit, |
| delta, |
| inputs + body.captured_inputs, |
| _LoopBodyCaptureWrapper(body), |
| name=name) |
| # Slice off the loop-carried captured inputs. |
| ret = ret[:-len(body.captured_inputs)] |
| else: |
| ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name) |
| if hostmem: |
| num_for_params = 3 # start/limit/delta |
| |
| input_attr = attr_value_pb2.AttrValue() |
| input_attr.list.i.extend([num_for_params + i for i in hostmem]) |
| ret[0].op._set_attr("_input_hostmem", input_attr) # pylint: disable=protected-access |
| |
| output_attr = attr_value_pb2.AttrValue() |
| output_attr.list.i.extend(hostmem) |
| ret[0].op._set_attr("_output_hostmem", output_attr) # pylint: disable=protected-access |
| return ret |