[engine] Support specifying an allow-list of checks
... user-controllable via an `--only <check1>,<check2>` flag.
This has a couple use cases:
1. Allowing users to only run one specific check to save time while
iterating on a new check implementation.
2. Rerunning checks to resolve conflicts when applying fixes, but
skipping checks for which fixes were already successfully applied.
This will come in a future change.
Change-Id: I503d5d07d20602262e364c45b16707a042eafcc9
Reviewed-on: https://fuchsia-review.googlesource.com/c/shac-project/shac/+/927712
Fuchsia-Auto-Submit: Oliver Newman <olivernewman@google.com>
Commit-Queue: Auto-Submit <auto-submit@fuchsia-infra.iam.gserviceaccount.com>
Reviewed-by: Anthony Fandrianto <atyfto@google.com>
diff --git a/README.md b/README.md
index c009e65..7518f13 100644
--- a/README.md
+++ b/README.md
@@ -267,6 +267,7 @@
- [x] Include unstaged files in analysis, including respecting unstaged
`shac.star` files
- [x] Automatic fix application with handling for conflicting suggestions
+- [ ] Rerun formatting checks after a conflict is encountered
- [ ] Provide a `.shac` cache directory that checks can write to
- [ ] Mount checkout directory read-only
- [x] By default
diff --git a/internal/cli/base.go b/internal/cli/base.go
index 15a52d5..58e7955 100644
--- a/internal/cli/base.go
+++ b/internal/cli/base.go
@@ -26,6 +26,7 @@
allFiles bool
entryPoint string
noRecurse bool
+ allowList []string
vars stringMapFlag
}
@@ -34,6 +35,7 @@
f.BoolVar(&c.allFiles, "all", false, "checks all the files instead of guess the upstream to diff against")
f.BoolVar(&c.noRecurse, "no-recurse", false, "do not look for shac.star files recursively")
f.StringVar(&c.entryPoint, "entrypoint", engine.DefaultEntryPoint, "basename of Starlark files to run")
+ f.StringSliceVar(&c.allowList, "only", nil, "comma-separated allowlist of checks to run; by default all checks are run")
c.vars = stringMapFlag{}
f.Var(&c.vars, "var", "runtime variables to set, of the form key=value")
}
@@ -49,5 +51,8 @@
Recurse: !c.noRecurse,
Vars: c.vars,
EntryPoint: c.entryPoint,
+ Filter: engine.CheckFilter{
+ AllowList: c.allowList,
+ },
}, nil
}
diff --git a/internal/cli/fix.go b/internal/cli/fix.go
index 7daa6f6..2a39aff 100644
--- a/internal/cli/fix.go
+++ b/internal/cli/fix.go
@@ -46,6 +46,6 @@
if err != nil {
return err
}
- o.Filter = engine.OnlyNonFormatters
+ o.Filter.FormatterFiltering = engine.OnlyNonFormatters
return engine.Fix(ctx, &o, c.quiet)
}
diff --git a/internal/cli/fmt.go b/internal/cli/fmt.go
index f21d29c..995fb95 100644
--- a/internal/cli/fmt.go
+++ b/internal/cli/fmt.go
@@ -46,6 +46,6 @@
if err != nil {
return err
}
- o.Filter = engine.OnlyFormatters
+ o.Filter.FormatterFiltering = engine.OnlyFormatters
return engine.Fix(ctx, &o, c.quiet)
}
diff --git a/internal/cli/main_test.go b/internal/cli/main_test.go
index 9c2afb8..9b72aa0 100644
--- a/internal/cli/main_test.go
+++ b/internal/cli/main_test.go
@@ -16,9 +16,14 @@
import (
"bytes"
+ "fmt"
+ "os"
+ "path/filepath"
"strconv"
"strings"
"testing"
+
+ "github.com/google/go-cmp/cmp"
)
func TestMainHelp(t *testing.T) {
@@ -33,6 +38,7 @@
{[]string{"shac", "fix", "--help"}, "Usage of shac fix:\n"},
{[]string{"shac", "fmt", "--help"}, "Usage of shac fmt:\n"},
{[]string{"shac", "doc", "--help"}, "Usage of shac doc:\n"},
+ {[]string{"shac", "version", "--help"}, "Usage of shac version:\n"},
}
for i, line := range data {
t.Run(strconv.Itoa(i), func(t *testing.T) {
@@ -49,7 +55,7 @@
type panicWrite struct{}
-func (panicWrite) Write(b []byte) (int, error) {
+func (panicWrite) Write([]byte) (int, error) {
panic("unexpected write!")
}
@@ -66,3 +72,77 @@
func init() {
helpOut = panicWrite{}
}
+
+func TestMainErr(t *testing.T) {
+ t.Parallel()
+ data := map[string]func(t *testing.T) (args []string, wantErr string){
+ "no shac.star files": func(t *testing.T) ([]string, string) {
+ root := t.TempDir()
+ return []string{"check", "-C", root, "--only", "foocheck"},
+ fmt.Sprintf("no shac.star files found in %s", root)
+ },
+ "--all with positional arguments": func(t *testing.T) ([]string, string) {
+ return []string{"check", "--all", "foo.txt", "bar.txt"},
+ "--all cannot be set together with positional file arguments"
+ },
+ "--only flag without value": func(t *testing.T) ([]string, string) {
+ root := t.TempDir()
+ return []string{"check", "-C", root, "--only"},
+ "flag needs an argument: --only"
+ },
+ "allowlist with invalid check": func(t *testing.T) ([]string, string) {
+ root := t.TempDir()
+ writeFile(t, root, "shac.star", "def cb(ctx): pass\nshac.register_check(cb)")
+ return []string{"check", "-C", root, "--only", "does-not-exist"},
+ "check does not exist: does-not-exist"
+ },
+ // Simple check that `shac fmt` filters out non-formatter checks.
+ "fmt with no checks to run": func(t *testing.T) ([]string, string) {
+ root := t.TempDir()
+ writeFile(t, root, "shac.star", ""+
+ "def non_formatter(ctx):\n"+
+ " pass\n"+
+ "shac.register_check(shac.check(non_formatter))\n")
+ return []string{"fmt", "-C", root, "--only", "non_formatter"},
+ "no checks to run"
+ },
+ // Simple check that `shac fix` filters out formatters.
+ "fix with no checks to run": func(t *testing.T) ([]string, string) {
+ root := t.TempDir()
+ writeFile(t, root, "shac.star", ""+
+ "def formatter(ctx):\n"+
+ " pass\n"+
+ "shac.register_check(shac.check(formatter, formatter = True))\n")
+ return []string{"fix", "-C", root, "--only", "formatter"},
+ "no checks to run"
+ },
+ }
+ for name, f := range data {
+ f := f
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ args, wantErr := f(t)
+ cmd := append([]string{"shac"}, args...)
+ err := Main(cmd)
+ if err == nil {
+ t.Fatalf("Expected error from running %s", cmd)
+ }
+ if diff := cmp.Diff(wantErr, err.Error()); diff != "" {
+ t.Fatalf("Wrong error (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func writeFile(t testing.TB, root, path, content string) {
+ t.Helper()
+ writeFileBytes(t, root, path, []byte(content), 0o600)
+}
+
+func writeFileBytes(t testing.TB, root, path string, content []byte, perm os.FileMode) {
+ t.Helper()
+ abs := filepath.Join(root, path)
+ if err := os.WriteFile(abs, content, perm); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/internal/engine/run.go b/internal/engine/run.go
index 4901e3e..512cd9c 100644
--- a/internal/engine/run.go
+++ b/internal/engine/run.go
@@ -79,19 +79,88 @@
_ struct{}
}
-// CheckFilter controls which checks get run by `Run`. It returns true for
-// checks that should be run, false for checks that should be skipped.
-type CheckFilter func(registeredCheck) bool
+// FormatterFiltering specifies whether formatting or non-formatting checks will
+// be filtered out.
+type FormatterFiltering int
-// OnlyFormatters causes only checks marked with `formatter = True` to be run.
-func OnlyFormatters(c registeredCheck) bool {
- return c.formatter
+const (
+ // AllChecks does not perform any filtering based on whether a check is a
+ // formatter or not.
+ AllChecks FormatterFiltering = iota
+ // OnlyFormatters causes only checks marked with `formatter = True` to be
+ // run.
+ OnlyFormatters
+ // OnlyNonFormatters causes only checks *not* marked with `formatter = True` to
+ // be run.
+ OnlyNonFormatters
+)
+
+// CheckFilter controls which checks are run.
+type CheckFilter struct {
+ FormatterFiltering FormatterFiltering
+ // AllowList specifies checks to run. If non-empty, all other checks will be
+ // skipped.
+ AllowList []string
}
-// OnlyNonFormatters causes only checks *not* marked with `formatter = True` to
-// be run.
-func OnlyNonFormatters(c registeredCheck) bool {
- return !c.formatter
+func (f *CheckFilter) filter(checks []*registeredCheck) ([]*registeredCheck, error) {
+ if len(checks) == 0 {
+ return checks, nil
+ }
+
+ // Keep track of the allowlist elements that correspond to valid checks so
+ // we can report any invalid allowlist elements at the end.
+ nonValidatedAllowList := make(map[string]struct{})
+ for _, name := range f.AllowList {
+ nonValidatedAllowList[name] = struct{}{}
+ }
+
+ var filtered []*registeredCheck
+ for _, check := range checks {
+ if len(f.AllowList) > 0 {
+ if _, ok := nonValidatedAllowList[check.name]; !ok {
+ // Check is not allow-listed.
+ continue
+ }
+ delete(nonValidatedAllowList, check.name)
+ }
+ switch f.FormatterFiltering {
+ case AllChecks:
+ case OnlyFormatters:
+ if !check.formatter {
+ continue
+ }
+ case OnlyNonFormatters:
+ if check.formatter {
+ continue
+ }
+ default:
+ return nil, fmt.Errorf("invalid FormatterFiltering value: %d", f.FormatterFiltering)
+ }
+ filtered = append(filtered, check)
+ }
+
+ if len(nonValidatedAllowList) > 0 {
+ var invalidChecks []string
+ for name := range nonValidatedAllowList {
+ invalidChecks = append(invalidChecks, name)
+ }
+ var msg string
+ if len(invalidChecks) == 1 {
+ msg = "check does not exist"
+ } else {
+ msg = "checks do not exist"
+ }
+ slices.Sort(invalidChecks)
+ return nil, fmt.Errorf("%s: %s", msg, strings.Join(invalidChecks, ", "))
+ }
+
+ if len(filtered) == 0 {
+ // Fail noisily if all checks are filtered out, it's probably user
+ // error.
+ return nil, errors.New("no checks to run")
+ }
+ return filtered, nil
}
// Level is one of "notice", "warning" or "error".
@@ -570,7 +639,7 @@
//
// Checks are executed sequentially after all Starlark code is loaded and not
// mutated. They run checks and emit results (results and comments).
- checks []registeredCheck
+ checks []*registeredCheck
// filter controls which checks run. If nil, all checks will run.
filter CheckFilter
passthroughEnv []*PassthroughEnv
@@ -641,18 +710,19 @@
}
args := starlark.Tuple{shacCtx}
args.Freeze()
- for i := range s.checks {
- if s.filter != nil && !s.filter(s.checks[i]) {
- continue
- }
- i := i
+ checks, err := s.filter.filter(s.checks)
+ if err != nil {
+ return err
+ }
+ for _, check := range checks {
+ check := check
ch <- func() error {
start := time.Now()
pi := func(th *starlark.Thread, msg string) {
pos := th.CallFrame(1).Pos
- s.r.Print(ctx, s.checks[i].name, pos.Filename(), int(pos.Line), msg)
+ s.r.Print(ctx, check.name, pos.Filename(), int(pos.Line), msg)
}
- err := s.checks[i].call(ctx, s.env, args, pi)
+ err := check.call(ctx, s.env, args, pi)
if err != nil && ctx.Err() != nil {
// Don't report the check completion if the context was
// canceled. The error was probably caused by the context being
@@ -661,7 +731,7 @@
// check failures.
return ctx.Err()
}
- s.r.CheckCompleted(ctx, s.checks[i].name, start, time.Since(start), s.checks[i].highestLevel, err)
+ s.r.CheckCompleted(ctx, check.name, start, time.Since(start), check.highestLevel, err)
return err
}
}
diff --git a/internal/engine/run_test.go b/internal/engine/run_test.go
index da01076..9a9a722 100644
--- a/internal/engine/run_test.go
+++ b/internal/engine/run_test.go
@@ -180,6 +180,74 @@
}(),
"var must be declared in a shac.textproto file: unknown_var",
},
+ {
+ "invalid allowlist item",
+ func() Options {
+ root := t.TempDir()
+ writeFile(t, root, "shac.star", ""+
+ "def cb(ctx):\n"+
+ " pass\n"+
+ "shac.register_check(cb)")
+ return Options{
+ Dir: root,
+ Filter: CheckFilter{
+ AllowList: []string{"does-not-exist"},
+ },
+ }
+ }(),
+ "check does not exist: does-not-exist",
+ },
+ {
+ "multiple invalid allowlist items",
+ func() Options {
+ root := t.TempDir()
+ writeFile(t, root, "shac.star", ""+
+ "def cb(ctx):\n"+
+ " pass\n"+
+ "shac.register_check(cb)")
+ return Options{
+ Dir: root,
+ Filter: CheckFilter{
+ AllowList: []string{"does-not-exist", "cb", "also-does-not-exist"},
+ },
+ }
+ }(),
+ "checks do not exist: also-does-not-exist, does-not-exist",
+ },
+ {
+ "invalid FormatterFiltering",
+ func() Options {
+ root := t.TempDir()
+ writeFile(t, root, "shac.star", ""+
+ "def cb(ctx):\n"+
+ " pass\n"+
+ "shac.register_check(cb)")
+ return Options{
+ Dir: root,
+ Filter: CheckFilter{
+ FormatterFiltering: FormatterFiltering(3),
+ },
+ }
+ }(),
+ "invalid FormatterFiltering value: 3",
+ },
+ {
+ "all checks filtered out",
+ func() Options {
+ root := t.TempDir()
+ writeFile(t, root, "shac.star", ""+
+ "def cb(ctx):\n"+
+ " pass\n"+
+ "shac.register_check(cb)")
+ return Options{
+ Dir: root,
+ Filter: CheckFilter{
+ FormatterFiltering: OnlyFormatters,
+ },
+ }
+ }(),
+ "no checks to run",
+ },
}
for i := range data {
i := i
@@ -436,6 +504,77 @@
}
}
+func TestRun_Filtering(t *testing.T) {
+ t.Parallel()
+
+ root := resolvedTempDir(t)
+
+ writeFile(t, root, "shac.star", ""+
+ "def non_formatter(ctx):\n"+
+ " print(\"non-formatter running\")\n"+
+ "def formatter(ctx):\n"+
+ " print(\"formatter running\")\n"+
+ "shac.register_check(shac.check(formatter, formatter = True))\n"+
+ "shac.register_check(shac.check(non_formatter))\n")
+
+ data := []struct {
+ name string
+ filter CheckFilter
+ want string
+ }{
+ {
+ name: "all checks",
+ want: "[//shac.star:2] non-formatter running\n" +
+ "[//shac.star:4] formatter running\n",
+ },
+ {
+ name: "only formatters",
+ filter: CheckFilter{
+ FormatterFiltering: OnlyFormatters,
+ },
+ want: "[//shac.star:4] formatter running\n",
+ },
+ {
+ name: "only non-formatters",
+ filter: CheckFilter{
+ FormatterFiltering: OnlyNonFormatters,
+ },
+ want: "[//shac.star:2] non-formatter running\n",
+ },
+ {
+ name: "only specified checks (non-formatter)",
+ filter: CheckFilter{
+ AllowList: []string{"non_formatter"},
+ },
+ want: "[//shac.star:2] non-formatter running\n",
+ },
+ {
+ name: "only specified checks (formatter)",
+ filter: CheckFilter{
+ AllowList: []string{"formatter"},
+ },
+ want: "[//shac.star:4] formatter running\n",
+ },
+ }
+ for i := range data {
+ i := i
+ t.Run(data[i].name, func(t *testing.T) {
+ r := reportPrint{reportNoPrint: reportNoPrint{t: t}}
+ o := Options{Report: &r, Dir: root, Filter: data[i].filter}
+ if err := Run(context.Background(), &o); err != nil {
+ t.Helper()
+ t.Fatal(err)
+ }
+ got := sortLines(r.b.String())
+ want := sortLines(data[i].want)
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Helper()
+ t.Fatalf("mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
func TestRun_Ignore(t *testing.T) {
t.Parallel()
root := t.TempDir()
@@ -2037,7 +2176,7 @@
err := Run(context.Background(), &o)
if data[i].err != "" {
if err == nil {
- t.Fatal("expected error")
+ t.Fatalf("expected error")
}
got := err.Error()
if data[i].err != got {
diff --git a/internal/engine/runtime_shac.go b/internal/engine/runtime_shac.go
index 2dd6ba7..f9ae715 100644
--- a/internal/engine/runtime_shac.go
+++ b/internal/engine/runtime_shac.go
@@ -70,7 +70,7 @@
return errors.New("can't register checks after done loading")
}
// Register the new callback.
- s.checks = append(s.checks, registeredCheck{check: c})
+ s.checks = append(s.checks, ®isteredCheck{check: c})
return nil
}
diff --git a/internal/engine/version.go b/internal/engine/version.go
index 9a5ab01..554d94b 100644
--- a/internal/engine/version.go
+++ b/internal/engine/version.go
@@ -26,7 +26,7 @@
// Version is the current tool version.
//
// TODO(maruel): Add proper version, preferably from git tag.
- Version = shacVersion{0, 1, 11}
+ Version = shacVersion{0, 1, 12}
)
func (v shacVersion) String() string {