[engine] Support specifying an allow-list of checks

... user-controllable via an `--only <check1>,<check2>` flag.

This has a couple use cases:
1. Allowing users to only run one specific check to save time while
   iterating on a new check implementation.
2. Rerunning checks to resolve conflicts when applying fixes, but
   skipping checks for which fixes were already successfully applied.
   This will come in a future change.

Change-Id: I503d5d07d20602262e364c45b16707a042eafcc9
Reviewed-on: https://fuchsia-review.googlesource.com/c/shac-project/shac/+/927712
Fuchsia-Auto-Submit: Oliver Newman <olivernewman@google.com>
Commit-Queue: Auto-Submit <auto-submit@fuchsia-infra.iam.gserviceaccount.com>
Reviewed-by: Anthony Fandrianto <atyfto@google.com>
diff --git a/README.md b/README.md
index c009e65..7518f13 100644
--- a/README.md
+++ b/README.md
@@ -267,6 +267,7 @@
 - [x] Include unstaged files in analysis, including respecting unstaged
       `shac.star` files
 - [x] Automatic fix application with handling for conflicting suggestions
+- [ ] Rerun formatting checks after a conflict is encountered
 - [ ] Provide a `.shac` cache directory that checks can write to
 - [ ] Mount checkout directory read-only
   - [x] By default
diff --git a/internal/cli/base.go b/internal/cli/base.go
index 15a52d5..58e7955 100644
--- a/internal/cli/base.go
+++ b/internal/cli/base.go
@@ -26,6 +26,7 @@
 	allFiles   bool
 	entryPoint string
 	noRecurse  bool
+	allowList  []string
 	vars       stringMapFlag
 }
 
@@ -34,6 +35,7 @@
 	f.BoolVar(&c.allFiles, "all", false, "checks all the files instead of guess the upstream to diff against")
 	f.BoolVar(&c.noRecurse, "no-recurse", false, "do not look for shac.star files recursively")
 	f.StringVar(&c.entryPoint, "entrypoint", engine.DefaultEntryPoint, "basename of Starlark files to run")
+	f.StringSliceVar(&c.allowList, "only", nil, "comma-separated allowlist of checks to run; by default all checks are run")
 	c.vars = stringMapFlag{}
 	f.Var(&c.vars, "var", "runtime variables to set, of the form key=value")
 }
@@ -49,5 +51,8 @@
 		Recurse:    !c.noRecurse,
 		Vars:       c.vars,
 		EntryPoint: c.entryPoint,
+		Filter: engine.CheckFilter{
+			AllowList: c.allowList,
+		},
 	}, nil
 }
diff --git a/internal/cli/fix.go b/internal/cli/fix.go
index 7daa6f6..2a39aff 100644
--- a/internal/cli/fix.go
+++ b/internal/cli/fix.go
@@ -46,6 +46,6 @@
 	if err != nil {
 		return err
 	}
-	o.Filter = engine.OnlyNonFormatters
+	o.Filter.FormatterFiltering = engine.OnlyNonFormatters
 	return engine.Fix(ctx, &o, c.quiet)
 }
diff --git a/internal/cli/fmt.go b/internal/cli/fmt.go
index f21d29c..995fb95 100644
--- a/internal/cli/fmt.go
+++ b/internal/cli/fmt.go
@@ -46,6 +46,6 @@
 	if err != nil {
 		return err
 	}
-	o.Filter = engine.OnlyFormatters
+	o.Filter.FormatterFiltering = engine.OnlyFormatters
 	return engine.Fix(ctx, &o, c.quiet)
 }
diff --git a/internal/cli/main_test.go b/internal/cli/main_test.go
index 9c2afb8..9b72aa0 100644
--- a/internal/cli/main_test.go
+++ b/internal/cli/main_test.go
@@ -16,9 +16,14 @@
 
 import (
 	"bytes"
+	"fmt"
+	"os"
+	"path/filepath"
 	"strconv"
 	"strings"
 	"testing"
+
+	"github.com/google/go-cmp/cmp"
 )
 
 func TestMainHelp(t *testing.T) {
@@ -33,6 +38,7 @@
 		{[]string{"shac", "fix", "--help"}, "Usage of shac fix:\n"},
 		{[]string{"shac", "fmt", "--help"}, "Usage of shac fmt:\n"},
 		{[]string{"shac", "doc", "--help"}, "Usage of shac doc:\n"},
+		{[]string{"shac", "version", "--help"}, "Usage of shac version:\n"},
 	}
 	for i, line := range data {
 		t.Run(strconv.Itoa(i), func(t *testing.T) {
@@ -49,7 +55,7 @@
 
 type panicWrite struct{}
 
-func (panicWrite) Write(b []byte) (int, error) {
+func (panicWrite) Write([]byte) (int, error) {
 	panic("unexpected write!")
 }
 
@@ -66,3 +72,77 @@
 func init() {
 	helpOut = panicWrite{}
 }
+
+func TestMainErr(t *testing.T) {
+	t.Parallel()
+	data := map[string]func(t *testing.T) (args []string, wantErr string){
+		"no shac.star files": func(t *testing.T) ([]string, string) {
+			root := t.TempDir()
+			return []string{"check", "-C", root, "--only", "foocheck"},
+				fmt.Sprintf("no shac.star files found in %s", root)
+		},
+		"--all with positional arguments": func(t *testing.T) ([]string, string) {
+			return []string{"check", "--all", "foo.txt", "bar.txt"},
+				"--all cannot be set together with positional file arguments"
+		},
+		"--only flag without value": func(t *testing.T) ([]string, string) {
+			root := t.TempDir()
+			return []string{"check", "-C", root, "--only"},
+				"flag needs an argument: --only"
+		},
+		"allowlist with invalid check": func(t *testing.T) ([]string, string) {
+			root := t.TempDir()
+			writeFile(t, root, "shac.star", "def cb(ctx): pass\nshac.register_check(cb)")
+			return []string{"check", "-C", root, "--only", "does-not-exist"},
+				"check does not exist: does-not-exist"
+		},
+		// Simple check that `shac fmt` filters out non-formatter checks.
+		"fmt with no checks to run": func(t *testing.T) ([]string, string) {
+			root := t.TempDir()
+			writeFile(t, root, "shac.star", ""+
+				"def non_formatter(ctx):\n"+
+				"    pass\n"+
+				"shac.register_check(shac.check(non_formatter))\n")
+			return []string{"fmt", "-C", root, "--only", "non_formatter"},
+				"no checks to run"
+		},
+		// Simple check that `shac fix` filters out formatters.
+		"fix with no checks to run": func(t *testing.T) ([]string, string) {
+			root := t.TempDir()
+			writeFile(t, root, "shac.star", ""+
+				"def formatter(ctx):\n"+
+				"    pass\n"+
+				"shac.register_check(shac.check(formatter, formatter = True))\n")
+			return []string{"fix", "-C", root, "--only", "formatter"},
+				"no checks to run"
+		},
+	}
+	for name, f := range data {
+		f := f
+		t.Run(name, func(t *testing.T) {
+			t.Parallel()
+			args, wantErr := f(t)
+			cmd := append([]string{"shac"}, args...)
+			err := Main(cmd)
+			if err == nil {
+				t.Fatalf("Expected error from running %s", cmd)
+			}
+			if diff := cmp.Diff(wantErr, err.Error()); diff != "" {
+				t.Fatalf("Wrong error (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+func writeFile(t testing.TB, root, path, content string) {
+	t.Helper()
+	writeFileBytes(t, root, path, []byte(content), 0o600)
+}
+
+func writeFileBytes(t testing.TB, root, path string, content []byte, perm os.FileMode) {
+	t.Helper()
+	abs := filepath.Join(root, path)
+	if err := os.WriteFile(abs, content, perm); err != nil {
+		t.Fatal(err)
+	}
+}
diff --git a/internal/engine/run.go b/internal/engine/run.go
index 4901e3e..512cd9c 100644
--- a/internal/engine/run.go
+++ b/internal/engine/run.go
@@ -79,19 +79,88 @@
 	_ struct{}
 }
 
-// CheckFilter controls which checks get run by `Run`. It returns true for
-// checks that should be run, false for checks that should be skipped.
-type CheckFilter func(registeredCheck) bool
+// FormatterFiltering specifies whether formatting or non-formatting checks will
+// be filtered out.
+type FormatterFiltering int
 
-// OnlyFormatters causes only checks marked with `formatter = True` to be run.
-func OnlyFormatters(c registeredCheck) bool {
-	return c.formatter
+const (
+	// AllChecks does not perform any filtering based on whether a check is a
+	// formatter or not.
+	AllChecks FormatterFiltering = iota
+	// OnlyFormatters causes only checks marked with `formatter = True` to be
+	// run.
+	OnlyFormatters
+	// OnlyNonFormatters causes only checks *not* marked with `formatter = True` to
+	// be run.
+	OnlyNonFormatters
+)
+
+// CheckFilter controls which checks are run.
+type CheckFilter struct {
+	FormatterFiltering FormatterFiltering
+	// AllowList specifies checks to run. If non-empty, all other checks will be
+	// skipped.
+	AllowList []string
 }
 
-// OnlyNonFormatters causes only checks *not* marked with `formatter = True` to
-// be run.
-func OnlyNonFormatters(c registeredCheck) bool {
-	return !c.formatter
+func (f *CheckFilter) filter(checks []*registeredCheck) ([]*registeredCheck, error) {
+	if len(checks) == 0 {
+		return checks, nil
+	}
+
+	// Keep track of the allowlist elements that correspond to valid checks so
+	// we can report any invalid allowlist elements at the end.
+	nonValidatedAllowList := make(map[string]struct{})
+	for _, name := range f.AllowList {
+		nonValidatedAllowList[name] = struct{}{}
+	}
+
+	var filtered []*registeredCheck
+	for _, check := range checks {
+		if len(f.AllowList) > 0 {
+			if _, ok := nonValidatedAllowList[check.name]; !ok {
+				// Check is not allow-listed.
+				continue
+			}
+			delete(nonValidatedAllowList, check.name)
+		}
+		switch f.FormatterFiltering {
+		case AllChecks:
+		case OnlyFormatters:
+			if !check.formatter {
+				continue
+			}
+		case OnlyNonFormatters:
+			if check.formatter {
+				continue
+			}
+		default:
+			return nil, fmt.Errorf("invalid FormatterFiltering value: %d", f.FormatterFiltering)
+		}
+		filtered = append(filtered, check)
+	}
+
+	if len(nonValidatedAllowList) > 0 {
+		var invalidChecks []string
+		for name := range nonValidatedAllowList {
+			invalidChecks = append(invalidChecks, name)
+		}
+		var msg string
+		if len(invalidChecks) == 1 {
+			msg = "check does not exist"
+		} else {
+			msg = "checks do not exist"
+		}
+		slices.Sort(invalidChecks)
+		return nil, fmt.Errorf("%s: %s", msg, strings.Join(invalidChecks, ", "))
+	}
+
+	if len(filtered) == 0 {
+		// Fail noisily if all checks are filtered out, it's probably user
+		// error.
+		return nil, errors.New("no checks to run")
+	}
+	return filtered, nil
 }
 
 // Level is one of "notice", "warning" or "error".
@@ -570,7 +639,7 @@
 	//
 	// Checks are executed sequentially after all Starlark code is loaded and not
 	// mutated. They run checks and emit results (results and comments).
-	checks []registeredCheck
+	checks []*registeredCheck
 	// filter controls which checks run. If nil, all checks will run.
 	filter         CheckFilter
 	passthroughEnv []*PassthroughEnv
@@ -641,18 +710,19 @@
 	}
 	args := starlark.Tuple{shacCtx}
 	args.Freeze()
-	for i := range s.checks {
-		if s.filter != nil && !s.filter(s.checks[i]) {
-			continue
-		}
-		i := i
+	checks, err := s.filter.filter(s.checks)
+	if err != nil {
+		return err
+	}
+	for _, check := range checks {
+		check := check
 		ch <- func() error {
 			start := time.Now()
 			pi := func(th *starlark.Thread, msg string) {
 				pos := th.CallFrame(1).Pos
-				s.r.Print(ctx, s.checks[i].name, pos.Filename(), int(pos.Line), msg)
+				s.r.Print(ctx, check.name, pos.Filename(), int(pos.Line), msg)
 			}
-			err := s.checks[i].call(ctx, s.env, args, pi)
+			err := check.call(ctx, s.env, args, pi)
 			if err != nil && ctx.Err() != nil {
 				// Don't report the check completion if the context was
 				// canceled. The error was probably caused by the context being
@@ -661,7 +731,7 @@
 				// check failures.
 				return ctx.Err()
 			}
-			s.r.CheckCompleted(ctx, s.checks[i].name, start, time.Since(start), s.checks[i].highestLevel, err)
+			s.r.CheckCompleted(ctx, check.name, start, time.Since(start), check.highestLevel, err)
 			return err
 		}
 	}
diff --git a/internal/engine/run_test.go b/internal/engine/run_test.go
index da01076..9a9a722 100644
--- a/internal/engine/run_test.go
+++ b/internal/engine/run_test.go
@@ -180,6 +180,74 @@
 			}(),
 			"var must be declared in a shac.textproto file: unknown_var",
 		},
+		{
+			"invalid allowlist item",
+			func() Options {
+				root := t.TempDir()
+				writeFile(t, root, "shac.star", ""+
+					"def cb(ctx):\n"+
+					"    pass\n"+
+					"shac.register_check(cb)")
+				return Options{
+					Dir: root,
+					Filter: CheckFilter{
+						AllowList: []string{"does-not-exist"},
+					},
+				}
+			}(),
+			"check does not exist: does-not-exist",
+		},
+		{
+			"multiple invalid allowlist items",
+			func() Options {
+				root := t.TempDir()
+				writeFile(t, root, "shac.star", ""+
+					"def cb(ctx):\n"+
+					"    pass\n"+
+					"shac.register_check(cb)")
+				return Options{
+					Dir: root,
+					Filter: CheckFilter{
+						AllowList: []string{"does-not-exist", "cb", "also-does-not-exist"},
+					},
+				}
+			}(),
+			"checks do not exist: also-does-not-exist, does-not-exist",
+		},
+		{
+			"invalid FormatterFiltering",
+			func() Options {
+				root := t.TempDir()
+				writeFile(t, root, "shac.star", ""+
+					"def cb(ctx):\n"+
+					"    pass\n"+
+					"shac.register_check(cb)")
+				return Options{
+					Dir: root,
+					Filter: CheckFilter{
+						FormatterFiltering: FormatterFiltering(3),
+					},
+				}
+			}(),
+			"invalid FormatterFiltering value: 3",
+		},
+		{
+			"all checks filtered out",
+			func() Options {
+				root := t.TempDir()
+				writeFile(t, root, "shac.star", ""+
+					"def cb(ctx):\n"+
+					"    pass\n"+
+					"shac.register_check(cb)")
+				return Options{
+					Dir: root,
+					Filter: CheckFilter{
+						FormatterFiltering: OnlyFormatters,
+					},
+				}
+			}(),
+			"no checks to run",
+		},
 	}
 	for i := range data {
 		i := i
@@ -436,6 +504,77 @@
 	}
 }
 
+func TestRun_Filtering(t *testing.T) {
+	t.Parallel()
+
+	root := resolvedTempDir(t)
+
+	writeFile(t, root, "shac.star", ""+
+		"def non_formatter(ctx):\n"+
+		"    print(\"non-formatter running\")\n"+
+		"def formatter(ctx):\n"+
+		"    print(\"formatter running\")\n"+
+		"shac.register_check(shac.check(formatter, formatter = True))\n"+
+		"shac.register_check(shac.check(non_formatter))\n")
+
+	data := []struct {
+		name   string
+		filter CheckFilter
+		want   string
+	}{
+		{
+			name: "all checks",
+			want: "[//shac.star:2] non-formatter running\n" +
+				"[//shac.star:4] formatter running\n",
+		},
+		{
+			name: "only formatters",
+			filter: CheckFilter{
+				FormatterFiltering: OnlyFormatters,
+			},
+			want: "[//shac.star:4] formatter running\n",
+		},
+		{
+			name: "only non-formatters",
+			filter: CheckFilter{
+				FormatterFiltering: OnlyNonFormatters,
+			},
+			want: "[//shac.star:2] non-formatter running\n",
+		},
+		{
+			name: "only specified checks (non-formatter)",
+			filter: CheckFilter{
+				AllowList: []string{"non_formatter"},
+			},
+			want: "[//shac.star:2] non-formatter running\n",
+		},
+		{
+			name: "only specified checks (formatter)",
+			filter: CheckFilter{
+				AllowList: []string{"formatter"},
+			},
+			want: "[//shac.star:4] formatter running\n",
+		},
+	}
+	for i := range data {
+		i := i
+		t.Run(data[i].name, func(t *testing.T) {
+			r := reportPrint{reportNoPrint: reportNoPrint{t: t}}
+			o := Options{Report: &r, Dir: root, Filter: data[i].filter}
+			if err := Run(context.Background(), &o); err != nil {
+				t.Helper()
+				t.Fatal(err)
+			}
+			got := sortLines(r.b.String())
+			want := sortLines(data[i].want)
+			if diff := cmp.Diff(want, got); diff != "" {
+				t.Helper()
+				t.Fatalf("mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
 func TestRun_Ignore(t *testing.T) {
 	t.Parallel()
 	root := t.TempDir()
@@ -2037,7 +2176,7 @@
 			err := Run(context.Background(), &o)
 			if data[i].err != "" {
 				if err == nil {
-					t.Fatal("expected error")
+					t.Fatalf("expected error")
 				}
 				got := err.Error()
 				if data[i].err != got {
diff --git a/internal/engine/runtime_shac.go b/internal/engine/runtime_shac.go
index 2dd6ba7..f9ae715 100644
--- a/internal/engine/runtime_shac.go
+++ b/internal/engine/runtime_shac.go
@@ -70,7 +70,7 @@
 		return errors.New("can't register checks after done loading")
 	}
 	// Register the new callback.
-	s.checks = append(s.checks, registeredCheck{check: c})
+	s.checks = append(s.checks, &registeredCheck{check: c})
 	return nil
 }
 
diff --git a/internal/engine/version.go b/internal/engine/version.go
index 9a5ab01..554d94b 100644
--- a/internal/engine/version.go
+++ b/internal/engine/version.go
@@ -26,7 +26,7 @@
 	// Version is the current tool version.
 	//
 	// TODO(maruel): Add proper version, preferably from git tag.
-	Version = shacVersion{0, 1, 11}
+	Version = shacVersion{0, 1, 12}
 )
 
 func (v shacVersion) String() string {