blob: 01002b735191adf2f32ade1d230debb7777f1503 [file] [log] [blame]
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Common starlark rules. """
load("@bazel_skylib//lib:new_sets.bzl", "sets")
# Rule for simple expansion of template files. This performs a simple
# search over the template file for the keys in substitutions,
# and replaces them with the corresponding values.
#
# Borrowed from TensorFlow (https://github.com/tensorflow/tensorflow)
#
# Typical usage:
# load("/tools/build_rules/template_rule", "expand_header_template")
# template_rule(
# name = "ExpandMyTemplate",
# src = "my.template",
# out = "my.txt",
# substitutions = {
# "$VAR1": "foo",
# "$VAR2": "bar",
# }
# )
#
# Args:
# name: The name of the rule.
# template: The template file to expand
# out: The destination of the expanded file
# substitutions: A dictionary mapping strings to their substitutions
def template_rule_impl(ctx):
ctx.actions.expand_template(
template = ctx.file.src,
output = ctx.outputs.out,
substitutions = ctx.attr.substitutions,
)
template_rule = rule(
attrs = {
"src": attr.label(
mandatory = True,
allow_single_file = True,
),
"substitutions": attr.string_dict(mandatory = True),
"out": attr.output(mandatory = True),
},
# output_to_genfiles is required for header files.
output_to_genfiles = True,
implementation = template_rule_impl,
)
# Traverse the dependency graph along the "deps" attribute of the
# target and return a struct with one field called 'tf_collected_deps'.
# tf_collected_deps will be the union of the deps of the current target
# and the tf_collected_deps of the dependencies of this target.
# Borrowed from TensorFlow (https://github.com/tensorflow/tensorflow).
def _collect_deps_aspect_impl(target, ctx):
direct, transitive = [], []
all_deps = []
if hasattr(ctx.rule.attr, "deps"):
all_deps += ctx.rule.attr.deps
if hasattr(ctx.rule.attr, "data"):
all_deps += ctx.rule.attr.data
for dep in all_deps:
direct.append(dep.label)
if hasattr(dep, "tf_collected_deps"):
transitive.append(dep.tf_collected_deps)
return struct(tf_collected_deps = depset(direct = direct, transitive = transitive))
collect_deps_aspect = aspect(
attr_aspects = ["deps", "data"],
implementation = _collect_deps_aspect_impl,
)
def _dep_label(dep):
label = dep.label
return label.package + ":" + label.name
# This rule checks that transitive dependencies don't depend on the targets
# listed in the 'disallowed_deps' attribute, but do depend on the targets listed
# in the 'required_deps' attribute. Dependencies considered are targets in the
# 'deps' attribute or the 'data' attribute.
# Borrowed from TensorFlow (https://github.com/tensorflow/tensorflow).
def _check_deps_impl(ctx):
required_deps = ctx.attr.required_deps
disallowed_deps = ctx.attr.disallowed_deps
for input_dep in ctx.attr.deps:
if not hasattr(input_dep, "tf_collected_deps"):
continue
collected_deps = sets.make(input_dep.tf_collected_deps.to_list())
for disallowed_dep in disallowed_deps:
if sets.contains(collected_deps, disallowed_dep.label):
fail(
_dep_label(input_dep) + " cannot depend on " +
_dep_label(disallowed_dep),
)
for required_dep in required_deps:
if not sets.contains(collected_deps, required_dep.label):
fail(
_dep_label(input_dep) + " must depend on " +
_dep_label(required_dep),
)
check_deps = rule(
_check_deps_impl,
attrs = {
"deps": attr.label_list(
aspects = [collect_deps_aspect],
mandatory = True,
allow_files = True,
),
"disallowed_deps": attr.label_list(
default = [],
allow_files = True,
),
"required_deps": attr.label_list(
default = [],
allow_files = True,
),
},
)