| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Decorator to overrides the gradient for a function.""" |
| |
| from tensorflow.python.eager import backprop |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import record |
| from tensorflow.python.framework import composite_tensor_gradient |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import gen_array_ops |
| from tensorflow.python.ops import handle_data_util |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import op_selector |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import tf_decorator |
| from tensorflow.python.util import tf_inspect |
| from tensorflow.python.util import variable_utils |
| from tensorflow.python.util.tf_export import tf_export |
| |
| |
| VAR_OP_TYPES = [ |
| "VariableV2", |
| "VarHandleOp", |
| ] |
| |
| |
| @tf_export("custom_gradient") |
| def custom_gradient(f=None): |
| """Decorator to define a function with a custom gradient. |
| |
This decorator allows fine-grained control over the gradients of a sequence
of operations. This may be useful for multiple reasons, including providing
a more efficient or numerically stable gradient for a sequence of operations.
| |
| For example, consider the following function that commonly occurs in the |
| computation of cross entropy and log likelihoods: |
| |
| ```python |
| def log1pexp(x): |
| return tf.math.log(1 + tf.exp(x)) |
| ``` |
| |
| Due to numerical instability, the gradient of this function evaluated at x=100 |
| is NaN. For example: |
| |
```python
x = tf.constant(100.)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = log1pexp(x)
dy_dx = tape.gradient(y, x)  # Will be NaN when evaluated.
```
| |
| The gradient expression can be analytically simplified to provide numerical |
| stability: |
| |
| ```python |
| @tf.custom_gradient |
| def log1pexp(x): |
| e = tf.exp(x) |
| def grad(upstream): |
| return upstream * (1 - 1 / (1 + e)) |
| return tf.math.log(1 + e), grad |
| ``` |
| |
| With this definition, the gradient `dy_dx` at `x = 100` will be correctly |
| evaluated as 1.0. |
| |
The variable `upstream` is the upstream gradient, i.e. the gradient flowing
back from the layers or functions that consume the output of this one. The
above example applies no further operations to the output, therefore
`upstream = dy/dy = 1.0`.
| |
| Assume that `x_i` is `log1pexp` in the forward pass `x_1 = x_1(x_0)`, |
| `x_2 = x_2(x_1)`, ..., `x_i = x_i(x_i-1)`, ..., `x_n = x_n(x_n-1)`. By |
| chain rule we know that `dx_n/dx_0 = dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * |
| dx_i/dx_i-1 * ... * dx_1/dx_0`. |
| |
In this case the gradient of our current function is defined as
`dx_i/dx_i-1 = (1 - 1 / (1 + e))`. The upstream gradient `upstream` would be
| `dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * dx_i+1/dx_i`. The upstream gradient |
| multiplied by the current gradient is then passed downstream. |
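
To make `upstream` concrete, here is a minimal sketch (the name `square` and
the trailing cube are illustrative choices only): the cube applied downstream
contributes `3 * (x**2)**2` as the upstream gradient seen by `square`'s `grad`
function.

```python
@tf.custom_gradient
def square(x):
  def grad(upstream):
    # Here upstream == d(final output)/d(square output) == 3 * (x**2)**2.
    return upstream * 2.0 * x
  return x**2, grad

x = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = square(x) ** 3  # A downstream op consuming the custom-gradient output.
dy_dx = tape.gradient(y, x)  # upstream = 48.0, so dy_dx = 48.0 * 2 * 2 = 192.0
```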
| |
If the function takes multiple inputs, the `grad` function must also return
the same number of gradients, one for each input.
We take the function `z = x * y` as an example.
| |
| >>> @tf.custom_gradient |
| ... def bar(x, y): |
| ... def grad(upstream): |
| ... dz_dx = y |
| ... dz_dy = x |
| ... return upstream * dz_dx, upstream * dz_dy |
| ... z = x * y |
| ... return z, grad |
| >>> x = tf.constant(2.0, dtype=tf.float32) |
| >>> y = tf.constant(3.0, dtype=tf.float32) |
| >>> with tf.GradientTape(persistent=True) as tape: |
| ... tape.watch(x) |
| ... tape.watch(y) |
| ... z = bar(x, y) |
| >>> z |
| <tf.Tensor: shape=(), dtype=float32, numpy=6.0> |
| >>> tape.gradient(z, x) |
| <tf.Tensor: shape=(), dtype=float32, numpy=3.0> |
| >>> tape.gradient(z, y) |
| <tf.Tensor: shape=(), dtype=float32, numpy=2.0> |
| |
| Nesting custom gradients can lead to unintuitive results. The default |
| behavior does not correspond to n-th order derivatives. For example |
| |
| ```python |
| @tf.custom_gradient |
| def op(x): |
| y = op1(x) |
| @tf.custom_gradient |
| def grad_fn(dy): |
| gdy = op2(x, y, dy) |
| def grad_grad_fn(ddy): # Not the 2nd order gradient of op w.r.t. x. |
| return op3(x, y, dy, ddy) |
| return gdy, grad_grad_fn |
| return y, grad_fn |
| ``` |
| |
| The function `grad_grad_fn` will be calculating the first order gradient |
| of `grad_fn` with respect to `dy`, which is used to generate forward-mode |
| gradient graphs from backward-mode gradient graphs, but is not the same as |
| the second order gradient of `op` with respect to `x`. |
| |
Instead, wrap nested `@tf.custom_gradient`-decorated functions in another
function:
| |
| ```python |
| @tf.custom_gradient |
| def op_with_fused_backprop(x): |
| y, x_grad = fused_op(x) |
| def first_order_gradient(dy): |
| @tf.custom_gradient |
| def first_order_custom(unused_x): |
| def second_order_and_transpose(ddy): |
| return second_order_for_x(...), gradient_wrt_dy(...) |
| return x_grad, second_order_and_transpose |
| return dy * first_order_custom(x) |
| return y, first_order_gradient |
| ``` |
| |
| Additional arguments to the inner `@tf.custom_gradient`-decorated function |
| control the expected return values of the innermost function. |
| |
| The examples above illustrate how to specify custom gradients for functions |
| which do not read from variables. The following example uses variables, which |
| require special handling because they are effectively inputs of the forward |
| function. |
| |
| >>> weights = tf.Variable(tf.ones([2])) # Trainable variable weights |
| >>> @tf.custom_gradient |
| ... def linear_poly(x): |
| ... # Creating polynomial |
| ... poly = weights[1] * x + weights[0] |
| ... |
| ... def grad_fn(dpoly, variables): |
| ... # dy/dx = weights[1] and we need to left multiply dpoly |
| ... grad_xs = dpoly * weights[1] # Scalar gradient |
| ... |
| ... grad_vars = [] # To store gradients of passed variables |
| ... assert variables is not None |
| ... assert len(variables) == 1 |
| ... assert variables[0] is weights |
| ... # Manually computing dy/dweights |
| ... dy_dw = dpoly * tf.stack([x ** 1, x ** 0]) |
| ... grad_vars.append( |
| ... tf.reduce_sum(tf.reshape(dy_dw, [2, -1]), axis=1) |
| ... ) |
| ... return grad_xs, grad_vars |
| ... return poly, grad_fn |
| >>> x = tf.constant([1., 2., 3.]) |
| >>> with tf.GradientTape(persistent=True) as tape: |
| ... tape.watch(x) |
| ... poly = linear_poly(x) |
| >>> poly # poly = x + 1 |
| <tf.Tensor: shape=(3,), |
| dtype=float32, |
| numpy=array([2., 3., 4.], dtype=float32)> |
| >>> tape.gradient(poly, x) # conventional scalar gradient dy/dx |
| <tf.Tensor: shape=(3,), |
| dtype=float32, |
| numpy=array([1., 1., 1.], dtype=float32)> |
| >>> tape.gradient(poly, weights) |
| <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 3.], dtype=float32)> |
| |
The above example illustrates the usage of the trainable variable `weights`.
In the example, the inner `grad_fn` accepts an extra `variables` input
parameter and also returns an extra `grad_vars` output. That extra argument
is passed if the forward function reads any variables. You need to
compute the gradient w.r.t. each of those `variables` and return it as a list
of `grad_vars`. Note that the default value of `variables` is `None` when
no variables are used in the forward function.
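
As a minimal sketch of that default (the name `double` is purely
illustrative), a forward function that reads no variables can still declare
the `variables` keyword; it simply receives `None`:

```python
@tf.custom_gradient
def double(x):
  def grad_fn(upstream, variables=None):
    # No variables are read in the forward pass, so `variables` stays None
    # and only the gradient w.r.t. `x` is returned.
    assert variables is None
    return upstream * 2.0
  return 2.0 * x, grad_fn
```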
| |
It should be noted that `tf.GradientTape` is still watching the forward pass
of a `tf.custom_gradient`, and will use the ops it watches. As a consequence,
calling `tf.function` while the tape is still watching leads
to a gradient graph being built. If an op is used inside the `tf.function`
without a registered gradient, a `LookupError` will be raised.
| |
| Users can insert `tf.stop_gradient` to customize this behavior. This |
| is demonstrated in the example below. `tf.random.shuffle` does not have a |
| registered gradient. As a result `tf.stop_gradient` is used to avoid the |
| `LookupError`. |
| |
| ```python |
| x = tf.constant([0.3, 0.5], dtype=tf.float32) |
| |
| @tf.custom_gradient |
| def test_func_with_stop_grad(x): |
| @tf.function |
| def _inner_func(): |
| # Avoid exception during the forward pass |
| return tf.stop_gradient(tf.random.shuffle(x)) |
| # return tf.random.shuffle(x) # This will raise |
| |
| res = _inner_func() |
| def grad(upstream): |
| return upstream # Arbitrarily defined custom gradient |
| return res, grad |
| |
| with tf.GradientTape() as g: |
| g.watch(x) |
| res = test_func_with_stop_grad(x) |
| |
| g.gradient(res, x) |
| ``` |
| |
See also `tf.RegisterGradient`, which registers a gradient function for a
primitive TensorFlow operation. `tf.custom_gradient`, on the other hand, allows
for fine-grained control over the gradient computation of a sequence of
operations.
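
For contrast, below is a minimal graph-mode sketch of `tf.RegisterGradient`
combined with `Graph.gradient_override_map` (the registration name
`"ClipGradOfIdentity"` is arbitrary and chosen only for this illustration):

```python
@tf.RegisterGradient("ClipGradOfIdentity")
def _clip_grad(op, grad):
  # Replace the gradient of the Identity op with a clipped copy of the
  # incoming gradient.
  return tf.clip_by_value(grad, -0.1, 0.1)

g = tf.Graph()
with g.as_default():  # Graph mode; `tf.gradients` is not usable eagerly.
  x = tf.constant([0.5, 2.0])
  with g.gradient_override_map({"Identity": "ClipGradOfIdentity"}):
    y = tf.identity(x)
  z = y * y
  dz_dx = tf.gradients(z, x)[0]  # Clipped to [0.1, 0.1] instead of [1.0, 4.0].
```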
| |
| Note that if the decorated function uses `Variable`s, the enclosing variable |
| scope must be using |
| [ResourceVariables](https://www.tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables). |
| |
| Args: |
| f: function `f(*x)` that returns a tuple `(y, grad_fn)` where: |
| - `x` is a sequence of (nested structures of) `Tensor` inputs to the |
| function. |
| - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow |
| operations in `f` to `x`. |
| - `grad_fn` is a function with the signature `g(*grad_ys)` which returns |
| a list of `Tensor`s the same size as (flattened) `x` - the derivatives |
| of `Tensor`s in `y` with respect to the `Tensor`s in `x`. `grad_ys` is |
| a sequence of `Tensor`s the same size as (flattened) `y` holding the |
| initial value gradients for each `Tensor` in `y`. |
| |
| In a pure mathematical sense, a vector-argument vector-valued function |
| `f`'s derivatives should be its Jacobian matrix `J`. Here we are |
| expressing the Jacobian `J` as a function `grad_fn` which defines how |
| `J` will transform a vector `grad_ys` when left-multiplied with it |
| (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional |
| representation of a matrix is convenient to use for chain-rule |
| calculation (in e.g. the back-propagation algorithm). |
| |
| If `f` uses `Variable`s (that are not part of the |
| inputs), i.e. through `get_variable`, then `grad_fn` should have |
| signature `g(*grad_ys, variables=None)`, where `variables` is a list of |
| the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where |
| `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>` |
| with the derivatives of `Tensor`s in `y` with respect to the variables |
(that is, `grad_vars` has one `Tensor` per variable in `variables`).
| |
| Returns: |
| A function `h(x)` which returns the same value as `f(x)[0]` and whose |
| gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`. |
| """ |
| |
| if f is None: |
| return lambda f: custom_gradient(f=f) |
| |
| @Bind.decorator |
| def decorated(wrapped, args, kwargs): |
| """Decorated function with custom gradient.""" |
| if context.executing_eagerly(): |
| return _eager_mode_decorator(wrapped, args, kwargs) |
| else: |
| return _graph_mode_decorator(wrapped, args, kwargs) |
| |
| return tf_decorator.make_decorator(f, decorated(f)) # pylint: disable=no-value-for-parameter |
| |
| |
| class Bind: |
| """When called evaluates `d(f, args, kwargs)` but supports binding `f`. |
| |
| >>> @Bind.decorator |
| ... def my_decorator(f, args, kwargs): |
| ... print("my_decorator called with", args, kwargs) |
| ... return f(*args, **kwargs) |
| |
| >>> class Foo: |
| ... @my_decorator |
| ... def bar(self, a, b, c): |
| ... return a * b * c |
| |
| >>> Foo.bar(None, 1, 2, c=3) |
| my_decorator called with (None, 1, 2) {'c': 3} |
| 6 |
| |
| >>> foo = Foo() |
| >>> foo.bar(1, 2, c=3) |
| my_decorator called with (1, 2) {'c': 3} |
| 6 |
| """ |
| |
| @classmethod |
| def decorator(cls, d): |
| return lambda f: Bind(f, d) |
| |
| def __init__(self, f, d): |
| self._f = f |
| self._d = d |
| |
| def __get__(self, instance, owner): |
| if instance is not None: |
| f = self._f.__get__(instance, owner) |
| return tf_decorator.make_decorator(f, Bind(f, self._d)) |
| else: |
| return self |
| |
| def __call__(self, *a, **k): |
| return self._d(self._f, a, k) |
| |
| |
| def get_variable_by_name(var_name): |
| """Given a variable name, retrieves a handle on the tensorflow Variable.""" |
| global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) |
| |
| def _filter_fn(item): |
| try: |
| return var_name == item.op.name |
| except AttributeError: |
| # Collection items without operation are ignored. |
| return False |
| |
| candidate_vars = list(filter(_filter_fn, global_vars)) |
| |
| if len(candidate_vars) >= 1: |
| # Filter out non-trainable variables. |
| candidate_vars = [v for v in candidate_vars if v.trainable] |
| else: |
| raise ValueError("Unsuccessful at finding variable {}.".format(var_name)) |
| |
| if len(candidate_vars) == 1: |
| return candidate_vars[0] |
| elif len(candidate_vars) > 1: |
| raise ValueError( |
| "Unsuccessful at finding trainable variable {}. " |
| "Number of candidates: {}. " |
| "Candidates: {}".format(var_name, len(candidate_vars), candidate_vars)) |
| else: |
| # The variable is not trainable. |
| return None |
| |
| |
| def _get_dependent_variables(input_ops, output_ops): |
| """Finds variables involved in the subgraph between input_ops and output_ops. |
| |
| Args: |
| input_ops: Flattened list of input ops |
| output_ops: Flattened list of output ops |
| |
| Returns: |
| A list of variables |
| """ |
| |
| # avoids the edge-case when input_ops == output_ops. |
| output_ops = nest.map_structure(gen_array_ops.identity, output_ops) |
| inbetween_ops = op_selector.get_backward_walk_ops( |
| seed_ops=output_ops, |
| stop_at_ts=input_ops, |
| inclusive=False, |
| only_differentiable=True) |
| var_ops = (op for op in inbetween_ops if op.type in VAR_OP_TYPES) |
| var_names = (op.name for op in var_ops) |
| tf_vars = (get_variable_by_name(var_name) for var_name in var_names) |
| tf_vars = [v for v in tf_vars if v is not None] |
| return tf_vars |
| |
| |
| def generate_name(): |
| return "CustomGradient-%s" % ops.uid() |
| |
| |
| def _graph_mode_decorator(f, args, kwargs): |
| """Implement custom gradient decorator for graph mode.""" |
| # TODO(rsepassi): Add support for kwargs |
| if kwargs: |
| raise ValueError( |
| "The custom_gradient decorator currently supports keywords " |
| "arguments only when eager execution is enabled.") |
| name = generate_name() |
| args = variable_utils.convert_variables_to_tensors(args) |
| args = nest.map_structure(ops.convert_to_tensor, args, expand_composites=True) |
| |
| # Checking global and local variables attempts to ensure that no non-resource |
| # Variables are added to the graph. |
| current_var_scope = variable_scope.get_variable_scope() |
| before_vars = set([ |
| v.ref() for v in current_var_scope.global_variables() + |
| current_var_scope.local_variables() |
| ]) |
| with record.VariableWatcher() as variable_watcher: |
| result, grad_fn = f(*args) |
| |
| flat_args = composite_tensor_gradient.get_flat_tensors_for_gradients( |
| nest.flatten(args)) |
| flat_result = composite_tensor_gradient.get_flat_tensors_for_gradients( |
| nest.flatten(result)) |
| flat_result_len = len(flat_result) |
| |
| after_vars = set([ |
| v.ref() for v in current_var_scope.global_variables() + |
| current_var_scope.local_variables() |
| ]) |
| new_vars = after_vars - before_vars |
| new_vars_list = [v.deref() for v in new_vars] |
| for v in new_vars_list: |
| if not resource_variable_ops.is_resource_variable(v): |
| raise TypeError( |
| "All variables used by a function wrapped with @custom_gradient must " |
| "be `ResourceVariable`s. Ensure that no `variable_scope` is created " |
| "with `use_resource=False`.") |
| |
| # The variables that grad_fn needs to return gradients for are the set of |
| # variables used that are *not* part of the inputs. |
| variables_in_tape = frozenset([ |
| v.ref() for v in variable_watcher.watched_variables() |
| ]) |
| |
| graphs = {getattr(o, "graph", None) for o in flat_result} |
| # Not all results may be tensors. However, we want to ensure all tensor |
| # outputs are from the same graph and get a list of captured inputs for |
| # variable search |
| graphs.discard(None) # Discard non-graph outputs |
| if graphs: |
| if len(graphs) > 1: |
| raise ValueError( |
| "All custom_gradient outputs should be from the same graph") |
| output_graph = graphs.pop() |
| filtered_input_tensors = [] |
| for i in flat_args: |
| if i.graph == output_graph: |
| filtered_input_tensors.append(i) |
| else: |
| filtered_input_tensors = flat_args |
| |
| variables_in_subgraph = frozenset([ |
| v.ref() for v in _get_dependent_variables( |
| input_ops=filtered_input_tensors, output_ops=flat_result) |
| ]) |
| variables = sorted( |
| [v.deref() for v in variables_in_subgraph.union(variables_in_tape)], |
| key=lambda v: v.name) |
| |
| grad_argspec = tf_inspect.getfullargspec(grad_fn) |
| variables_in_signature = ("variables" in grad_argspec.args or |
| "variables" in grad_argspec.kwonlyargs or |
| grad_argspec.varkw) |
| if variables and not variables_in_signature: |
| raise TypeError( |
| "@tf.custom_gradient grad_fn must accept keyword argument 'variables', " |
| "since function uses variables: {}".format(variables)) |
| if variables_in_signature and not variables: |
| # User seems to intend to use variables but none were captured. |
| logging.vlog( |
| 1, "@custom_gradient grad_fn has 'variables' in signature, " |
| "but no ResourceVariables were used on the forward pass.") |
| |
| all_tensors = flat_result + flat_args + variables |
| |
| def tape_grad_fn(*result_grad_components): |
| """Custom grad fn wrapper.""" |
| result_grads = composite_tensor_gradient.replace_flat_tensors_for_gradients( |
| nest.flatten(result), result_grad_components[:flat_result_len]) |
| if not isinstance(result_grads, (list, tuple)): |
| result_grads = [result_grads] |
| |
| if variables: |
| input_grads, variable_grads = grad_fn(*result_grads, variables=variables) |
| if len(variable_grads) != len(variables): |
| raise ValueError("Must return gradient for each variable from " |
| "@custom_gradient grad_fn.") |
| else: |
| input_grads = grad_fn(*result_grads) |
| variable_grads = [] |
| |
| # Need to return one value per input to the IdentityN, so pad the |
| # gradients of the inputs of the custom_gradient function with the |
| # gradients of the outputs as well. |
| input_grads = composite_tensor_gradient.get_flat_tensors_for_gradients( |
| nest.flatten(input_grads)) |
| return ([None] * flat_result_len) + input_grads + variable_grads |
| |
| @ops.RegisterGradient(name) |
| def internal_grad_fn(unused_op, *result_grads): # pylint: disable=unused-variable |
| """Custom grad fn wrapper.""" |
| return tape_grad_fn(*result_grads) |
| |
| original_tensors = all_tensors |
| with ops.get_default_graph().gradient_override_map({"IdentityN": name}): |
| all_tensors = array_ops.identity_n(all_tensors) |
| |
| original_tensors = [ops.convert_to_tensor(x) for x in original_tensors] |
| |
| # Propagate handle data for happier shape inference for resource variables. |
| for i, t in enumerate(original_tensors): |
| if t.dtype == dtypes.resource and hasattr(t, "_handle_data"): |
| all_tensors[i]._handle_data = t._handle_data # pylint: disable=protected-access |
| record.record_operation( |
| f.__name__, all_tensors, original_tensors, tape_grad_fn) |
| for ot, t in zip(original_tensors, all_tensors): |
| handle_data_util.copy_handle_data(ot, t) |
| flat_result = composite_tensor_gradient.replace_flat_tensors_for_gradients( |
| nest.flatten(result), all_tensors[:flat_result_len]) |
| return nest.pack_sequence_as(result, flat_result) |
| |
| |
| def _eager_mode_decorator(f, args, kwargs): |
| """Implement custom gradient decorator for eager mode.""" |
| with record.VariableWatcher() as variable_watcher: |
| result, grad_fn = f(*args, **kwargs) |
| flat_args = composite_tensor_gradient.get_flat_tensors_for_gradients( |
| nest.flatten(args)) |
| flat_kwargs = composite_tensor_gradient.get_flat_tensors_for_gradients( |
| nest.flatten(kwargs)) |
| all_inputs = flat_args + flat_kwargs |
| # The variables that grad_fn needs to return gradients for are the set of |
| # variables used that are *not* part of the inputs. |
| variables = [ |
| v.deref() # pylint: disable=g-complex-comprehension |
| for v in set(v.ref() for v in variable_watcher.watched_variables()) |
| if all(v.deref() is not i for i in all_inputs) |
| ] |
| grad_argspec = tf_inspect.getfullargspec(grad_fn) |
| if (variables and ("variables" not in grad_argspec.args) and |
| ("variables" not in grad_argspec.kwonlyargs) and |
| not grad_argspec.varkw): |
| raise TypeError( |
| "@tf.custom_gradient grad_fn must accept keyword argument 'variables', " |
| "since function uses variables: {}".format(variables)) |
| flat_result = composite_tensor_gradient.get_flat_tensors_for_gradients( |
| nest.flatten(result)) |
| # TODO(apassos) consider removing the identity below. |
| flat_result = [gen_array_ops.identity(x) for x in flat_result] |
| |
| input_tensors = [ |
| ops.convert_to_tensor(x) for x in flat_args + list(variables)] |
| |
| recorded_inputs = input_tensors |
| arg_count = len(flat_args) |
| |
| def actual_grad_fn(*result_grad_components): |
| """Custom grad fn wrapper.""" |
| result_grads = composite_tensor_gradient.replace_flat_tensors_for_gradients( |
| nest.flatten(result), result_grad_components) |
| if not isinstance(result_grads, (list, tuple)): |
| result_grads = [result_grads] |
| |
| if variables: |
| input_grads, variable_grads = grad_fn(*result_grads, variables=variables) |
| if len(variable_grads) != len(variables): |
| raise ValueError("Must return gradient for each variable from " |
| "@custom_gradient grad_fn.") |
| else: |
| input_grads = grad_fn(*result_grads) |
| variable_grads = [] |
| flat_grads = composite_tensor_gradient.get_flat_tensors_for_gradients( |
| nest.flatten(input_grads)) |
| if len(flat_grads) != arg_count: |
| raise ValueError( |
| f"custom_gradient function expected to return {arg_count} " |
| f"gradients, but returned {len(flat_grads)} instead.") |
| return flat_grads + variable_grads |
| |
| record.record_operation(f.__name__, flat_result, recorded_inputs, |
| actual_grad_fn) |
| flat_result = composite_tensor_gradient.replace_flat_tensors_for_gradients( |
| nest.flatten(result), flat_result) |
| return nest.pack_sequence_as(result, flat_result) |
| |
| |
| @tf_export("recompute_grad") |
| def recompute_grad(f): |
| """Defines a function as a recompute-checkpoint for the tape auto-diff. |
| |
| Tape checkpointing is a technique to reduce the memory consumption of the |
| auto-diff tape: |
| |
- Without tape checkpointing, operations and intermediate values are
recorded to the tape for use in the backward pass.
| |
| - With tape checkpointing, only the function call and its inputs are |
| recorded. During back-propagation the `recompute_grad` custom gradient |
| (`tf.custom_gradient`) recomputes the function under a localized Tape object. |
| This recomputation of the function during backpropagation performs redundant |
| calculation, but reduces the overall memory usage of the Tape. |
| |
| >>> y = tf.Variable(1.0) |
| |
| >>> def my_function(x): |
| ... tf.print('running') |
| ... z = x*y |
| ... return z |
| |
| >>> my_function_recompute = tf.recompute_grad(my_function) |
| |
| >>> with tf.GradientTape() as tape: |
| ... r = tf.constant(1.0) |
| ... for i in range(4): |
| ... r = my_function_recompute(r) |
| running |
| running |
| running |
| running |
| |
| >>> grad = tape.gradient(r, [y]) |
| running |
| running |
| running |
| running |
| |
Without `recompute_grad`, the tape contains all intermediate steps, and no
recomputation is performed.
| |
| >>> with tf.GradientTape() as tape: |
| ... r = tf.constant(1.0) |
| ... for i in range(4): |
| ... r = my_function(r) |
| running |
| running |
| running |
| running |
| |
| >>> grad = tape.gradient(r, [y]) |
| |
| |
If `f` is a `tf.keras` `Model` or `Layer` object, methods and attributes
such as `f.variables` are not available on the returned function `g`.
Either keep a reference to `f`, or use `g.__wrapped__` to access
these variables and methods.
| |
| |
| >>> def print_running_and_return(x): |
| ... tf.print("running") |
| ... return x |
| |
| >>> model = tf.keras.Sequential([ |
| ... tf.keras.layers.Lambda(print_running_and_return), |
| ... tf.keras.layers.Dense(2) |
| ... ]) |
| |
| >>> model_recompute = tf.recompute_grad(model) |
| |
| >>> with tf.GradientTape(persistent=True) as tape: |
| ... r = tf.constant([[1,2]]) |
| ... for i in range(4): |
| ... r = model_recompute(r) |
| running |
| running |
| running |
| running |
| |
| >>> grad = tape.gradient(r, model.variables) |
| running |
| running |
| running |
| running |
| |
| Alternatively, use the `__wrapped__` attribute to access the original |
| model object. |
| |
| >>> grad = tape.gradient(r, model_recompute.__wrapped__.variables) |
| running |
| running |
| running |
| running |
| |
| |
| Args: |
| f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs. |
| |
| Returns: |
A function `g` wrapping `f` that defines a custom gradient, which recomputes
`f` on the backward pass of a gradient call.
| """ |
| # TODO(cdfreeman) Add is_recomputing functionality from graph mode version |
| |
| @custom_gradient |
| def inner(*args, **kwargs): |
| """Inner function closure for calculating gradients.""" |
| current_var_scope = variable_scope.get_variable_scope() |
| with record.stop_recording(): |
| result = f(*args, **kwargs) |
| |
| def grad_wrapper(*wrapper_args, variables=None): |
| """Wrapper function to accomodate lack of kwargs in graph mode custom_gradient.""" |
| |
| @custom_gradient |
| def inner_recompute_grad(*dresult): |
| """Nested custom gradient function for computing grads in reverse and forward mode autodiff.""" |
| # Gradient calculation for reverse mode autodiff. |
| with backprop.GradientTape() as t: |
| id_args = nest.map_structure(gen_array_ops.identity, args) |
| # Tuple `dresult` should contain at least one tensor. |
| assert len(dresult) >= 1 |
| |
| if not context.executing_eagerly(): |
| # XLA doesn't respect `tf.control_dependencies`. The code block |
| # below manually adds a data dependency to `dresult` to ensure |
| # recomputation of `f(*args, **kwargs)` happens after `dresult`. |
| |
| # This works even if `dresult[0]` is a size 0 tensor as reduce_max |
| # of a size 0 tensor returns -inf. Use reshape here to avoid reading |
| # the entire `dresult[0]`. |
| elem = math_ops.reduce_max(array_ops.reshape(dresult[0], [-1])[:1]) |
| # Cast elem to bool in case elem is NaN. |
| elem_bool = math_ops.cast(elem, dtypes.bool) |
| dresult_dep = array_ops.where_v2( |
| elem_bool == elem_bool, 0., float("nan")) # pylint: disable=comparison-with-itself |
| id_args = nest.map_structure( |
| lambda x: x + math_ops.cast(dresult_dep, x.dtype), id_args) |
| |
| t.watch(id_args) |
| if variables is not None: |
| t.watch(variables) |
| with variable_scope.variable_scope(current_var_scope): |
| recomputed_result = f(*id_args, **kwargs) |
| kw_vars = [] |
| if variables is not None: |
| kw_vars = list(variables) |
| grads = t.gradient( |
| recomputed_result, |
| list(id_args) + kw_vars, |
| output_gradients=dresult, |
| unconnected_gradients=UnconnectedGradients.ZERO) |
| |
| def transpose(*t_args, **t_kwargs): |
| """Gradient function calculation for forward mode autodiff.""" |
| # Just throw an error since gradients / activations are not stored on |
| # tape for recompute. |
| raise NotImplementedError( |
| "recompute_grad tried to transpose grad of {}. " |
| "Consider not using recompute_grad in forward mode" |
| "autodiff".format(f.__name__)) |
| |
| return (grads[:len(id_args)], grads[len(id_args):]), transpose |
| |
| return inner_recompute_grad(*wrapper_args) |
| |
| return result, grad_wrapper |
| |
| return tf_decorator.make_decorator(f, inner) |
| |
| |
| @tf_export("grad_pass_through") |
| def grad_pass_through(f): |
| """Creates a grad-pass-through op with the forward behavior provided in f. |
| |
| Use this function to wrap any op, maintaining its behavior in the forward |
| pass, but replacing the original op in the backward graph with an identity. |
| For example: |
| |
| ```python |
| x = tf.Variable(1.0, name="x") |
| z = tf.Variable(3.0, name="z") |
| |
| with tf.GradientTape() as tape: |
| # y will evaluate to 9.0 |
| y = tf.grad_pass_through(x.assign)(z**2) |
| # grads will evaluate to 6.0 |
| grads = tape.gradient(y, z) |
| ``` |
| |
| Another example is a 'differentiable' moving average approximation, where |
| gradients are allowed to flow into the last value fed to the moving average, |
| but the moving average is still used for the forward pass: |
| |
| ```python |
| x = ... # Some scalar value |
| # A moving average object, we don't need to know how this is implemented |
| moving_average = MovingAverage() |
| with backprop.GradientTape() as tape: |
| # mavg_x will evaluate to the current running average value |
| mavg_x = tf.grad_pass_through(moving_average)(x) |
| grads = tape.gradient(mavg_x, x) # grads will evaluate to 1.0 |
| ``` |
| |
| Args: |
| f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor` |
| outputs. |
| |
| Returns: |
| A function `h(x)` which returns the same values as `f(x)` and whose |
| gradients are the same as those of an identity function. |
| """ |
| @custom_gradient |
| def _grad_pass_through_op(*args, **kwargs): |
| def grad(*args, **kwargs): |
| variables = kwargs.get("variables") |
| if variables is not None: |
| # Variables involved in the wrapped op will not receive gradients. |
| return args, [None] * len(variables) |
| return args |
| return f(*args, **kwargs), grad |
| return tf_decorator.make_decorator(f, _grad_pass_through_op) |