# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code to generate inputs/outputs exclusion lists for GradientTape."""
import sys
import gast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
_GENERATED_FILE_HEADER = """/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Inputs/Outputs exclusion lists for GradientTape.
//
// This file is MACHINE GENERATED! Do not edit.
// Generated by: tensorflow/python/eager/gen_gradient_input_output_exclusions.py
"""
_INCLUDES = """
#include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
using tensorflow::string;
namespace {
// Keep static data in a format that's easy to init statically.
struct OpIndexInfo {
const char *op_name;
int num_indices;
std::array<int, 4> unused_indices;
};
// Helper function to initialize FlatMap<string,FlatSet> from OpIndexInfo.
template <typename T>
auto OpGradientInfoInit(const T &a) {
auto *m = new tensorflow::gtl::FlatMap<string, tensorflow::gtl::FlatSet<int>>;
for (const auto &item : a) {
m->emplace(string(item.op_name),
tensorflow::gtl::FlatSet<int>(
item.unused_indices.begin(),
item.unused_indices.begin() + item.num_indices));
}
return m;
}
} // namespace
"""
_EXCLUDED_OPS = [
# Composite ops with custom gradient functions.
"If",
"StatelessIf",
"While",
"StatelessWhile",
"Case",
# TF Lite. These ops only appear in OSS.
# TODO(srbs): Find a better way to filter these out.
"AudioMicrofrontend",
# DTensor Ops with custom gradient functions.
# Note that these ops only appear in OSS and fail the test in OSS.
"CopyToMesh",
"CopyToMeshGrad",
"Relayout",
"RelayoutLike",
]
class _SubscriptUseTracker(transformer.Base):
"""Track uses of composite names, excluding certain names when subscripted."""
def __init__(self, ctx, exclude_when_subscripted):
super(_SubscriptUseTracker, self).__init__(ctx)
self.exclude = exclude_when_subscripted
self.reads = set()
self.complex_reads = set()
def visit_Attribute(self, node):
"""Visits attribute nodes in the AST."""
if anno.hasanno(node, anno.Basic.QN):
qn = anno.getanno(node, anno.Basic.QN)
if isinstance(node.ctx, gast.Load):
self.reads.add(qn)
node = self.generic_visit(node)
return node
def visit_Subscript(self, node):
"""Visits nodes with subscript in the AST."""
s = node.slice
if anno.hasanno(node, anno.Basic.QN):
qn = anno.getanno(node, anno.Basic.QN)
if isinstance(node.ctx, gast.Load):
self.reads.add(qn)
elif isinstance(s, (gast.Tuple, gast.Slice)):
if anno.hasanno(node.value, anno.Basic.QN):
self.complex_reads.add(anno.getanno(node.value, anno.Basic.QN))
value_qn = anno.getanno(node.value, anno.Basic.QN, None)
if value_qn in self.exclude:
node.value = self.generic_visit(node.value)
else:
node.value = self.visit(node.value)
node.slice = self.visit(s)
return node
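# Rough illustration of what _SubscriptUseTracker records, assuming the
# qualified name `op.inputs` was passed in `exclude_when_subscripted`
# (hypothetical statements, not executed here):
#
#   y = op.inputs[0]     # simple index: `op.inputs[0]` is added to `reads`
#   z = op.inputs[1:3]   # slice: `op.inputs` is added to `complex_reads`
#
# Entries in `complex_reads` are what later make _live_tensors fall back to
# _ALL, since the exact indices cannot be determined statically.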
class _FunctionCallsTracker(transformer.Base):
"""Tracks any function calls made with a given first argument name."""
def __init__(self, ctx, first_argument_name):
super(_FunctionCallsTracker, self).__init__(ctx)
self.first_argument_name = first_argument_name
self.calls = set()
def visit_Name(self, node):
node = self.generic_visit(node)
if isinstance(node.ctx, gast.Load) and node.id in self.ctx.info.namespace:
anno.setanno(node, "static_value", self.ctx.info.namespace[node.id])
return node
def visit_Attribute(self, node):
node = self.generic_visit(node)
parent_val = anno.getanno(node.value, "static_value", default=None)
if parent_val is not None:
if hasattr(parent_val, node.attr):
anno.setanno(node, "static_value", getattr(parent_val, node.attr))
return node
def visit_Call(self, node):
node = self.generic_visit(node)
if (node.args and anno.getanno(node.args[0], anno.Basic.QN,
None) == self.first_argument_name):
fn_object = anno.getanno(node.func, "static_value", None)
if fn_object is not None:
self.calls.add(fn_object)
return node
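# Rough illustration of what _FunctionCallsTracker records, assuming a
# gradient function that delegates to a helper in the same module
# (hypothetical names, not executed here):
#
#   def _FooGrad(op, grad):
#     return _shared_helper(op, grad)
#
# With first_argument_name == "op", the `_shared_helper` function object is
# added to `calls`, which lets _live_tensors recurse into the helper as well.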
_ALL = object()
def _live_tensors(f, attr_name="inputs"):
"""Returns the indices of the used inputs.
Note: This currently only handles direct index accesses, e.g. op.inputs[1].
If the function slices attr_name or uses it in a list comprehension, this
returns _ALL. This ensures the result is correct even if inefficient.
Args:
f: A grad function, taking the op as first argument.
attr_name: op attr to track. "inputs" or "outputs".
Returns:
Either one of:
* set of integers representing individual indices of inputs used
* the value _ALL, if indices are used but it cannot be determined which ones
* empty set, if no inputs are used
"""
node, _ = parser.parse_entity(f, ())
entity_info = transformer.EntityInfo(
name=f.__name__,
source_code=None,
source_file=None,
future_features=(),
namespace=sys.modules[f.__module__].__dict__)
ctx = transformer.Context(entity_info, None, None)
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = reaching_fndefs.resolve(node, ctx, graphs)
node = liveness.resolve(node, ctx, graphs)
op_arg_name = anno.getanno(node.args.args[0], anno.Basic.QN)
op_inputs_outputs_name = qual_names.QN(op_arg_name, attr=attr_name)
special_tracker = _SubscriptUseTracker(ctx, (op_inputs_outputs_name,))
node = special_tracker.visit(node)
live_vars_in = anno.getanno(node.body[0], anno.Static.LIVE_VARS_IN)
inputs_outputs_used_qns = set()
for v in special_tracker.complex_reads:
# Complicated patterns like op.inputs[:3]. Could be smarter about them
# if they matter much.
if v == op_inputs_outputs_name:
return _ALL
for v in live_vars_in:
if v in special_tracker.reads:
if (v.has_subscript() and v.parent == op_inputs_outputs_name):
inputs_outputs_used_qns.add(v)
elif v == op_inputs_outputs_name:
# When op.{attr_name} is used directly, assume all tensors are
# used for now. In that case, no point digging further.
# TODO(mdan): We can descend into tuple expansions.
return _ALL
function_calls_tracker = _FunctionCallsTracker(ctx, op_arg_name)
node = function_calls_tracker.visit(node)
input_output_indices = set()
for called_f in function_calls_tracker.calls:
child_indices = _live_tensors(called_f, attr_name=attr_name)
if child_indices is _ALL:
return _ALL
input_output_indices |= child_indices
for v in inputs_outputs_used_qns:
assert v.has_subscript()
_, subscript = v.qn
if not subscript.is_simple():
# Not a number, assuming it can be anything.
return _ALL
subscript_val, = subscript.qn
if (not isinstance(subscript_val, qual_names.Literal) or
not isinstance(subscript_val.value, int)):
# Not a number, assuming it can be anything.
return _ALL
input_output_indices.add(subscript_val.value)
return input_output_indices
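# Sketch of the expected behavior of _live_tensors on a few hypothetical
# gradient functions (illustrative only, not executed here):
#
#   def _UsesFirstInput(op, grad):
#     return op.inputs[0] * grad   # -> {0}
#
#   def _SlicesInputs(op, grad):
#     return op.inputs[:2]         # -> _ALL (indices not statically known)
#
#   def _UsesNoInputs(op, grad):
#     return grad                  # -> set() (no inputs are used)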
def _get_num_inputs_outputs(op_type):
"""Returns (num_inputs, num_outputs).
Args:
op_type: String. The type of the Operation. Used to look up the op in the
registry.
Returns:
(num_inputs, num_outputs). For either value, -1 is returned if it cannot be
statically inferred from the OpDef alone or if the OpDef lookup fails.
"""
def _is_list_arg(arg):
return arg.number_attr or arg.type_list_attr
def _count_args(arg_defs):
for arg in arg_defs:
if _is_list_arg(arg):
# Op has list type args which could be variable.
return -1
return len(arg_defs)
op_def = op_def_registry.get(op_type)
if not op_def:
return -1, -1
return _count_args(op_def.input_arg), _count_args(op_def.output_arg)
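# Illustration of the expected results for two common ops, based on their
# registered OpDefs (not executed here):
#
#   _get_num_inputs_outputs("MatMul")  # -> (2, 1): fixed-arity args
#   _get_num_inputs_outputs("AddN")    # -> (-1, 1): "inputs" is a list arg,
#                                      #    so its length depends on the N attr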
def get_entries(attr_name):
"""Returns the dict of entries.
Each entry is of the form {op_name} or {op_name, num_indices, {indices}}.
{op_name}: All values are unused.
{op_name, num_indices, {indices}}: `indices` are the only unused indices.
Note: ops for which all values are used are not printed.
Args:
attr_name: inputs or outputs.
Returns:
A dict mapping each op_type to its formatted entry string.
"""
assert attr_name in ["inputs", "outputs"]
entries = {}
for op_type in ops._gradient_registry.list(): # pylint: disable=protected-access
if op_type in _EXCLUDED_OPS:
continue
num_values = _get_num_inputs_outputs(op_type)[0 if attr_name ==
"inputs" else 1]
gradient_fn = ops._gradient_registry.lookup(op_type) # pylint: disable=protected-access
if gradient_fn is None:
# NotDifferentiable
if num_values != -1:
entries[op_type] = "{\"%s\"}," % op_type
continue
used_tensors = _live_tensors(gradient_fn, attr_name=attr_name)
if used_tensors is _ALL:
continue
elif not used_tensors:
entries[op_type] = "{\"%s\"}," % op_type
else:
all_tensors = set(range(num_values))
unused_tensors = all_tensors - used_tensors
if unused_tensors:
unused_tensor_list = sorted(list(unused_tensors))
entries[op_type] = "{\"%s\", %d, {%s}}," % (
op_type, len(unused_tensor_list), ", ".join(
str(i) for i in unused_tensor_list))
return entries
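# Worked example for a hypothetical op "Foo" with 3 inputs whose gradient only
# reads op.inputs[1]: used_tensors == {1}, all_tensors == {0, 1, 2}, so the
# emitted entry is '{"Foo", 2, {0, 2}},'. Ops whose gradients use every value
# produce no entry at all.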
def get_function(name, entries):
"""Generates lookup function with given name and lookup table entries."""
contents = """
absl::optional<tensorflow::gtl::FlatSet<int>> {name}(
const tensorflow::string &op_name) {{
static std::array<OpIndexInfo, {count}> a = {{{{
""".format(
name=name, count=len(entries) + 1)
contents += " "
contents += "\n ".join(entries[op_type] for op_type in sorted(entries))
contents += "\n {\"VarHandleOp\"},"
contents += """
}};
static const auto &m = *OpGradientInfoInit(a);
auto it = m.find(op_name);
if (it != m.end()) {
return it->second;
}
return absl::nullopt;
}
"""
return contents
def get_contents():
"""Returns contents for the generated file."""
contents = ""
contents += _GENERATED_FILE_HEADER + _INCLUDES
contents += get_function("OpGradientUnusedInputIndices",
get_entries("inputs"))
contents += get_function("OpGradientUnusedOutputIndices",
get_entries("outputs"))
return contents
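# A minimal sketch of how the generated source could be written out, assuming
# this module is imported as `gen`; the actual build integration may differ
# and the output path below is only an example:
#
#   with open("pywrap_gradient_exclusions.cc", "w") as f:
#     f.write(gen.get_contents())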