| # Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Code to generate inputs/outputs exclusion lists for GradientTape.""" |
| |
| import sys |
| |
| import gast |
| |
| from tensorflow.python.autograph.pyct import anno |
| from tensorflow.python.autograph.pyct import cfg |
| from tensorflow.python.autograph.pyct import parser |
| from tensorflow.python.autograph.pyct import qual_names |
| from tensorflow.python.autograph.pyct import transformer |
| from tensorflow.python.autograph.pyct.static_analysis import activity |
| from tensorflow.python.autograph.pyct.static_analysis import liveness |
| from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs |
| from tensorflow.python.framework import op_def_registry |
| from tensorflow.python.framework import ops |
| |
| _GENERATED_FILE_HEADER = """/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // Inputs/Outputs exclusion lists for GradientTape. |
| // |
| // This file is MACHINE GENERATED! Do not edit. |
| // Generated by: tensorflow/python/eager/gen_gradient_input_output_exclusions.py |
| """ |
| |
| _INCLUDES = """ |
| #include "tensorflow/python/eager/pywrap_gradient_exclusions.h" |
| |
| #include "absl/types/optional.h" |
| #include "tensorflow/core/lib/gtl/flatmap.h" |
| #include "tensorflow/core/lib/gtl/flatset.h" |
| |
| using tensorflow::string; |
| |
| namespace { |
| // Keep static data in a format that's easy to init statically. |
| struct OpIndexInfo { |
| const char *op_name; |
| int num_indices; |
| std::array<int, 4> unused_indices; |
| }; |
| |
| // Helper function to initialize FlatMap<string,FlatSet> from OpIndexInfo. |
| template <typename T> |
| auto OpGradientInfoInit(const T &a) { |
| auto *m = new tensorflow::gtl::FlatMap<string, tensorflow::gtl::FlatSet<int>>; |
| for (const auto &item : a) { |
| m->emplace(string(item.op_name), |
| tensorflow::gtl::FlatSet<int>( |
| item.unused_indices.begin(), |
| item.unused_indices.begin() + item.num_indices)); |
| } |
| return m; |
| } |
| } // namespace |
| """ |
| |
| _EXCLUDED_OPS = [ |
| # Composite ops with custom gradient functions. |
| "If", |
| "StatelessIf", |
| "While", |
| "StatelessWhile", |
| "Case", |
| # TF Lite. These ops only appear in OSS. |
| # TODO(srbs): Find a better way to filter these out. |
| "AudioMicrofrontend", |
| # DTensor Ops with custom gradient functions. |
    # Note that these ops only appear in OSS and fail the test in OSS.
| "CopyToMesh", |
| "CopyToMeshGrad", |
| "Relayout", |
| "RelayoutLike", |
| ] |
| |
| |
| class _SubscriptUseTracker(transformer.Base): |
| """Track uses of composite names, excluding certain names when subscripted.""" |
| |
| def __init__(self, ctx, exclude_when_subscripted): |
| super(_SubscriptUseTracker, self).__init__(ctx) |
| self.exclude = exclude_when_subscripted |
| self.reads = set() |
| self.complex_reads = set() |
| |
| def visit_Attribute(self, node): |
| """Visits attribute nodes in the AST.""" |
| if anno.hasanno(node, anno.Basic.QN): |
| qn = anno.getanno(node, anno.Basic.QN) |
| if isinstance(node.ctx, gast.Load): |
| self.reads.add(qn) |
| node = self.generic_visit(node) |
| return node |
| |
| def visit_Subscript(self, node): |
| """Visits nodes with subscript in the AST.""" |
| s = node.slice |
| if anno.hasanno(node, anno.Basic.QN): |
| qn = anno.getanno(node, anno.Basic.QN) |
| if isinstance(node.ctx, gast.Load): |
| self.reads.add(qn) |
| elif isinstance(s, (gast.Tuple, gast.Slice)): |
| if anno.hasanno(node.value, anno.Basic.QN): |
| self.complex_reads.add(anno.getanno(node.value, anno.Basic.QN)) |
| value_qn = anno.getanno(node.value, anno.Basic.QN, None) |
| if value_qn in self.exclude: |
| node.value = self.generic_visit(node.value) |
| else: |
| node.value = self.visit(node.value) |
| node.slice = self.visit(s) |
| return node |
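
# Illustrative effect of _SubscriptUseTracker on a toy function body (the
# names are hypothetical):
#
#   x = op.inputs[0]    # adds the qualified name op.inputs[0] to `reads`
#   y = op.inputs[1:3]  # slice subscript: adds op.inputs to `complex_reads`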
| |
| |
| class _FunctionCallsTracker(transformer.Base): |
| """Tracks any function calls made with a given first argument name.""" |
| |
| def __init__(self, ctx, first_argument_name): |
| super(_FunctionCallsTracker, self).__init__(ctx) |
| self.first_argument_name = first_argument_name |
| self.calls = set() |
| |
| def visit_Name(self, node): |
| node = self.generic_visit(node) |
| if isinstance(node.ctx, gast.Load) and node.id in self.ctx.info.namespace: |
| anno.setanno(node, "static_value", self.ctx.info.namespace[node.id]) |
| return node |
| |
| def visit_Attribute(self, node): |
| node = self.generic_visit(node) |
| parent_val = anno.getanno(node.value, "static_value", default=None) |
| if parent_val is not None: |
| if hasattr(parent_val, node.attr): |
| anno.setanno(node, "static_value", getattr(parent_val, node.attr)) |
| return node |
| |
| def visit_Call(self, node): |
| node = self.generic_visit(node) |
| if (node.args and anno.getanno(node.args[0], anno.Basic.QN, |
| None) == self.first_argument_name): |
| fn_object = anno.getanno(node.func, "static_value", None) |
| if fn_object is not None: |
| self.calls.add(fn_object) |
| return node |
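
# Illustrative: with first_argument_name == QN('op'), visiting a body
# containing a call such as
#
#   _helper_grad(op, grad)
#
# resolves `_helper_grad` through the function's namespace and adds the actual
# function object to `calls`, so that _live_tensors (below) can recurse into
# helpers that receive the op.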
| |
| |
| _ALL = object() |
| |
| |
| def _live_tensors(f, attr_name="inputs"): |
| """Returns the indices of the used inputs. |
| |
| Note: This currently only handles direct index accesses e.g. op.inputs[1]. |
| If the function has slicing or list comprehension on attr_name then returns |
| _ALL. This ensure that this is correct even if inefficient. |
| |
| Args: |
| f: A grad function, taking the op as first argument. |
| attr_name: op attr to track. "inputs" or "outputs". |
| |
| Returns: |
    One of:
      * a set of integers: the indices of the inputs/outputs that are used
      * the value _ALL, if inputs/outputs are used but the specific indices
        cannot be determined
      * the empty set, if no inputs/outputs are used
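
  Example (illustrative only; these are not real registered gradients):

    def _example_grad(op, grad):
      return grad * op.inputs[1]     # _live_tensors(_example_grad) -> {1}

    def _example_grad_all(op, grad):
      return list(op.inputs)         # _live_tensors(_example_grad_all) -> _ALL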
| """ |
| node, _ = parser.parse_entity(f, ()) |
| entity_info = transformer.EntityInfo( |
| name=f.__name__, |
| source_code=None, |
| source_file=None, |
| future_features=(), |
| namespace=sys.modules[f.__module__].__dict__) |
| ctx = transformer.Context(entity_info, None, None) |
| |
| graphs = cfg.build(node) |
| node = qual_names.resolve(node) |
| node = activity.resolve(node, ctx, None) |
| node = reaching_fndefs.resolve(node, ctx, graphs) |
| node = liveness.resolve(node, ctx, graphs) |
| |
| op_arg_name = anno.getanno(node.args.args[0], anno.Basic.QN) |
| op_inputs_outputs_name = qual_names.QN(op_arg_name, attr=attr_name) |
| |
| special_tracker = _SubscriptUseTracker(ctx, (op_inputs_outputs_name,)) |
| node = special_tracker.visit(node) |
| |
| live_vars_in = anno.getanno(node.body[0], anno.Static.LIVE_VARS_IN) |
| inputs_outputs_used_qns = set() |
| for v in special_tracker.complex_reads: |
| # Complicated patterns like op.inputs[:3]. Could be smarter about them |
| # if they matter much. |
| if v == op_inputs_outputs_name: |
| return _ALL |
| for v in live_vars_in: |
| if v in special_tracker.reads: |
      if v.has_subscript() and v.parent == op_inputs_outputs_name:
| inputs_outputs_used_qns.add(v) |
| elif v == op_inputs_outputs_name: |
| # When op.{attr_name} is used directly, assume all tensors are |
| # used for now. In that case, no point digging further. |
| # TODO(mdan): We can descend into tuple expansions. |
| return _ALL |
| |
| function_calls_tracker = _FunctionCallsTracker(ctx, op_arg_name) |
| node = function_calls_tracker.visit(node) |
| |
| input_output_indices = set() |
| |
| for called_f in function_calls_tracker.calls: |
| child_indices = _live_tensors(called_f, attr_name=attr_name) |
| if child_indices is _ALL: |
| return _ALL |
| input_output_indices |= child_indices |
| |
| for v in inputs_outputs_used_qns: |
| assert v.has_subscript() |
| _, subscript = v.qn |
    if not subscript.is_simple():
      # Not a simple index (e.g. a nested subscript); assume any value may be
      # used.
      return _ALL
    subscript_val, = subscript.qn
    if (not isinstance(subscript_val, qual_names.Literal) or
        not isinstance(subscript_val.value, int)):
      # Not an integer literal; assume any value may be used. Note the `or`:
      # a non-Literal subscript (e.g. a variable index) has no `.value`.
      return _ALL
    input_output_indices.add(subscript_val.value)
| return input_output_indices |
| |
| |
| def _get_num_inputs_outputs(op_type): |
| """Returns (num_inputs, num_outputs). |
| |
| Args: |
    op_type: String. The type of the Operation. Used to look up the op in the
      registry.

  Returns:
    (num_inputs, num_outputs). For either value, -1 is returned if it can't be
    statically inferred from the OpDef alone or if the OpDef lookup fails.
| """ |
| |
| def _is_list_arg(arg): |
| return arg.number_attr or arg.type_list_attr |
| |
| def _count_args(arg_defs): |
| for arg in arg_defs: |
| if _is_list_arg(arg): |
| # Op has list type args which could be variable. |
| return -1 |
| return len(arg_defs) |
| |
| op_def = op_def_registry.get(op_type) |
| if not op_def: |
| return -1, -1 |
| return _count_args(op_def.input_arg), _count_args(op_def.output_arg) |
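
# Illustrative: "MatMul" declares two plain input args and one output arg, so
# this returns (2, 1); "AddN" takes a variable-length list of inputs
# (number_attr "N"), so its input count can't be inferred statically and this
# returns (-1, 1).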
| |
| |
| def get_entries(attr_name): |
| """Returns the dict of entries. |
| |
| Each entry is of the form {op_name, {true|false, indices}} |
| |
| true: All values are unused. |
| false: `indices` are the only unused indices. |
| |
| Note: ops for which all values are used are not printed. |
| |
| Args: |
| attr_name: inputs or outputs. |
| |
| Returns: |
| A dict from op_type to formatted entry in the dict. |
| """ |
| assert attr_name in ["inputs", "outputs"] |
| entries = {} |
| for op_type in ops._gradient_registry.list(): # pylint: disable=protected-access |
| if op_type in _EXCLUDED_OPS: |
| continue |
    index = 0 if attr_name == "inputs" else 1
    num_values = _get_num_inputs_outputs(op_type)[index]
| gradient_fn = ops._gradient_registry.lookup(op_type) # pylint: disable=protected-access |
| if gradient_fn is None: |
| # NotDifferentiable |
| if num_values != -1: |
| entries[op_type] = "{\"%s\"}," % op_type |
| continue |
| used_tensors = _live_tensors(gradient_fn, attr_name=attr_name) |
| if used_tensors is _ALL: |
| continue |
| elif not used_tensors: |
| entries[op_type] = "{\"%s\"}," % op_type |
| else: |
| all_tensors = set(range(num_values)) |
| unused_tensors = all_tensors - used_tensors |
| if unused_tensors: |
| unused_tensor_list = sorted(list(unused_tensors)) |
| entries[op_type] = "{\"%s\", %d, {%s}}," % ( |
| op_type, len(unused_tensor_list), ", ".join( |
| str(i) for i in unused_tensor_list)) |
| return entries |
| |
| |
| def get_function(name, entries): |
| """Generates lookup function with given name and lookup table entries.""" |
| contents = """ |
| absl::optional<tensorflow::gtl::FlatSet<int>> {name}( |
| const tensorflow::string &op_name) {{ |
| static std::array<OpIndexInfo, {count}> a = {{{{ |
| """.format( |
| name=name, count=len(entries) + 1) |
| contents += " " |
| contents += "\n ".join(entries[op_type] for op_type in sorted(entries)) |
| contents += "\n {\"VarHandleOp\"}," |
| contents += """ |
| }}; |
| static const auto &m = *OpGradientInfoInit(a); |
| |
| auto it = m.find(op_name); |
| if (it != m.end()) { |
| return it->second; |
| } |
| return absl::nullopt; |
| } |
| """ |
| return contents |
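
# Sketch of the C++ this emits (entries and the array size depend on the
# gradient functions registered at generation time):
#
#   absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
#       const tensorflow::string &op_name) {
#     static std::array<OpIndexInfo, N> a = {{
#         {"SomeOp", 1, {0}},
#         ...
#         {"VarHandleOp"},
#     }};
#     static const auto &m = *OpGradientInfoInit(a);
#     ...
#   }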
| |
| |
| def get_contents(): |
| """Returns contents for the generated file.""" |
| contents = "" |
| contents += _GENERATED_FILE_HEADER + _INCLUDES |
| contents += get_function("OpGradientUnusedInputIndices", |
| get_entries("inputs")) |
| contents += get_function("OpGradientUnusedOutputIndices", |
| get_entries("outputs")) |
| return contents |
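
# Typical usage (a sketch; the actual driver script may differ): write the
# generated source to a path supplied by the caller, e.g.
#
#   with open(output_path, "w") as f:
#     f.write(get_contents())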