| # Copyright 2022 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Defines a rule to check the dependencies of a given target.""" |
| |
| load("@bazel_skylib//lib:new_sets.bzl", "sets") |
| |
def _as_target_list(value):
    """Normalizes a rule attribute value to a list of targets.

    The aspect propagates along every attribute named "deps" or "data", and a
    visited rule may declare such an attribute as a label list, a single label
    (None when unset) or a label-keyed dict. Plain `list += value` only works
    for the first form, so every shape is converted to a list here.

    Args:
      value: the attribute value read from `ctx.rule.attr`, or None.

    Returns:
      A list of targets (empty when `value` is None).
    """
    if value == None:
        return []
    if type(value) == "list":
        return value
    if type(value) == "dict":
        # label_keyed_string_dict: the targets are the keys.
        return value.keys()
    return [value]

# Traverse the dependency graph along the "deps" and "data" attributes of the
# target and return a struct with one field called 'tf_collected_deps'.
# tf_collected_deps will be the union of the deps of the current target
# and the tf_collected_deps of the dependencies of this target.
# Borrowed from TensorFlow (https://github.com/tensorflow/tensorflow).
def _collect_deps_aspect_impl(target, ctx):
    """Collects the labels reachable through `deps` and `data` edges.

    Args:
      target: the target the aspect is applied to (unused; required by the
        aspect implementation signature).
      ctx: the aspect context; `ctx.rule.attr` exposes the visited rule's
        attributes.

    Returns:
      A struct whose `tf_collected_deps` field is a depset of labels. It holds
      the direct `deps`/`data` labels of this target plus the
      `tf_collected_deps` of each of those dependencies.
    """

    # NOTE(review): returning a bare `struct` is the legacy provider style,
    # which Bazel has deprecated (see --incompatible_disallow_struct_provider_syntax).
    # Migrating to a declared provider() would also require updating every
    # consumer that reads `tf_collected_deps` via hasattr() — TODO.
    direct = []
    transitive = []
    for attr_name in ("deps", "data"):
        for dep in _as_target_list(getattr(ctx.rule.attr, attr_name, None)):
            direct.append(dep.label)
            if hasattr(dep, "tf_collected_deps"):
                transitive.append(dep.tf_collected_deps)
    return struct(tf_collected_deps = depset(direct = direct, transitive = transitive))
| |
# Aspect that follows `deps` and `data` edges and attaches the labels gathered
# by _collect_deps_aspect_impl to every target it visits.
collect_deps_aspect = aspect(
    implementation = _collect_deps_aspect_impl,
    attr_aspects = ["deps", "data"],
)
| |
def _dep_label(dep):
    """Formats a dependency's label as "<package>:<name>" for error messages.

    The repository name and the leading "//" are not included.

    Args:
      dep: a target exposing a `label` with `package` and `name` fields.

    Returns:
      The string "<package>:<name>".
    """
    return "%s:%s" % (dep.label.package, dep.label.name)
| |
# This rule checks that transitive dependencies don't depend on the targets
# listed in the 'disallowed_deps' attribute, but do depend on the targets listed
# in the 'required_deps' attribute. Dependencies considered are targets in the
# 'deps' attribute or the 'data' attribute.
# Borrowed from TensorFlow (https://github.com/tensorflow/tensorflow).
def _check_deps_impl(ctx):
    """Fails the build if any target in `deps` violates the dependency rules.

    For each target in `ctx.attr.deps`, the labels in its `tf_collected_deps`
    are checked in two ways:
      * every label of `disallowed_deps` must be absent;
      * every label of `required_deps` must be present.

    All violations are gathered first and reported in a single `fail()`, so
    one build shows every problem. With exactly one violation, the message is
    the same as reporting it on its own.

    Targets without a `tf_collected_deps` field are skipped. Presumably these
    are source files admitted by `allow_files = True`; confirm.

    Args:
      ctx: the rule context.
    """
    disallowed_deps = ctx.attr.disallowed_deps
    required_deps = ctx.attr.required_deps
    errors = []
    for input_dep in ctx.attr.deps:
        if not hasattr(input_dep, "tf_collected_deps"):
            continue
        collected_deps = sets.make(input_dep.tf_collected_deps.to_list())
        for disallowed_dep in disallowed_deps:
            if sets.contains(collected_deps, disallowed_dep.label):
                errors.append(
                    _dep_label(input_dep) + " cannot depend on " +
                    _dep_label(disallowed_dep),
                )
        for required_dep in required_deps:
            if not sets.contains(collected_deps, required_dep.label):
                errors.append(
                    _dep_label(input_dep) + " must depend on " +
                    _dep_label(required_dep),
                )
    if errors:
        fail("\n".join(errors))
| |
# Rule that checks dependencies at analysis time. Each target in `deps` is
# visited by `collect_deps_aspect`, and _check_deps_impl compares the
# collected labels against `disallowed_deps` and `required_deps`.
check_deps = rule(
    implementation = _check_deps_impl,
    attrs = {
        "deps": attr.label_list(
            mandatory = True,
            allow_files = True,
            aspects = [collect_deps_aspect],
        ),
        "disallowed_deps": attr.label_list(
            allow_files = True,
            default = [],
        ),
        "required_deps": attr.label_list(
            allow_files = True,
            default = [],
        ),
    },
)