#!/usr/bin/env python3
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool to fix and validate Go import grouping and aliasing.
Conventions:
1. Three sections in the import block, separated by a newline:
a. Standard library (unaliased).
b. External and internal packages (unaliased).
c. Aliased and blank (_) imports.
2. All protobuf message packages must be aliased.
3. Protobuf aliases must end in 'pb' or 'grpc'.
"""
import argparse
import re
import sys
from pathlib import Path
from typing import Any
# Known protobuf package paths and their preferred aliases.
# Keys are Go import paths; values are the alias expected for that package.
# All packages in this map MUST be aliased when imported.
KNOWN_PROTOS: dict[str, str] = {
    # Well-known protobuf types.
    "google.golang.org/protobuf/types/known/timestamppb": "tspb",
    "google.golang.org/protobuf/types/known/fieldmaskpb": "fmpb",
    "google.golang.org/protobuf/types/known/durationpb": "dpb",
    "google.golang.org/protobuf/types/known/wrapperspb": "wpb",
    "google.golang.org/protobuf/types/known/anypb": "anypb",
    "google.golang.org/protobuf/types/known/emptypb": "emptypb",
    # Generated Google API protos.
    "google.golang.org/genproto/googleapis/devtools/resultstore/v2": "rspb",
    "google.golang.org/genproto/googleapis/devtools/build/v1": "buildpb",
    "google.golang.org/genproto/googleapis/bytestream": "bspb",
    # Packages under github.com/bazelbuild.
    "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2": "repb",
    "github.com/bazelbuild/remote-apis/build/bazel/semver": "svpb",
    "github.com/bazelbuild/rsclient/api/jobstatus": "jspb",
    "github.com/bazelbuild/rsclient/third_party/bazel/build_event_stream": "bespb",
    "github.com/bazelbuild/rsclient/internal/pkg/besproxy/translation/proto": "tracepb",
}
# Regex for an import line with an optional alias. The pattern is anchored at
# both ends, so it must match the whole string it is given.
# Match: alias "path" // comment
# Groups: (1) alias, (2) import path, (3) trailing "//" comment or None.
# Alias can be an identifier, _, or . -- the last, empty alternative in group 1
# means a line with no alias matches as well (group 1 is then "").
IMPORT_LINE_WITH_ALIAS_RE: re.Pattern = re.compile(
    r'^([a-zA-Z0-9_]+|[_.]|)\s*"([^"]+)"(?:\s+(//.*))?$'
)
# Regex for an import line without an alias.
# Match: "path" // comment
# Groups: (1) import path, (2) trailing "//" comment or None.
# Every string this matches is also matched by IMPORT_LINE_WITH_ALIAS_RE.
IMPORT_LINE_NO_ALIAS_RE: re.Pattern = re.compile(r'^"([^"]+)"(?:\s+(//.*))?$')
def is_stdlib(path: str) -> bool:
    """Guess whether a Go import path names a standard library package.

    The check is a heuristic: a path is treated as standard library when its
    first slash-separated element contains no dot (i.e. it does not look like
    a host name such as ``github.com``).

    Args:
        path: The Go import path string.

    Returns:
        True for the cgo pseudo-package ``"C"`` and for any non-empty path
        whose first element has no ``"."``; False otherwise, including for an
        empty path.
    """
    # CGO is special.
    if path == "C":
        return True
    if not path:
        return False
    leading_element, _, _ = path.partition("/")
    return "." not in leading_element
def parse_import_line(line: str) -> dict[str, str | None] | None:
"""Parses a single line from an import block.
Args:
line: The raw line string from the Go source file.
Returns:
dict: A dictionary with 'alias', 'path', and 'comment' keys if the line
matches an import statement, or None if it doesn't.
"""
stripped = line.strip()
if not stripped:
return None
match = IMPORT_LINE_WITH_ALIAS_RE.match(stripped)
if match:
alias, path, comment = match.groups()
if not alias:
alias = None
return {"alias": alias, "path": path, "comment": comment}
match = IMPORT_LINE_NO_ALIAS_RE.match(stripped)
if match:
path, comment = match.groups()
return {"alias": None, "path": path, "comment": comment}
return None
def transform_import_block(
import_lines: list[str],
file_path: str = "<string>",
known_protos: dict[str, str] | None = None,
) -> tuple[list[str] | None, list[str]]:
"""Transforms a list of raw import lines into the standardized grouped format.
Args:
import_lines: A list of raw strings representing the lines inside an
import (...) block.
file_path: Optional file path for error reporting.
known_protos: Optional mapping of proto paths to preferred aliases.
Defaults to KNOWN_PROTOS.
Returns:
tuple: (new_import_block, errors). new_import_block is a list of strings
representing the new formatted block, or None if errors occurred.
errors is a list of validation error strings.
"""
if known_protos is None:
known_protos = KNOWN_PROTOS
parsed_imports: list[dict[str, any]] = []
current_header_comments: list[str] = []
for line in import_lines:
stripped = line.strip()
if not stripped:
continue
if stripped.startswith("//"):
current_header_comments.append(line)
continue
parsed = parse_import_line(line)
if parsed:
parsed["header_comments"] = current_header_comments
parsed_imports.append(parsed)
current_header_comments = []
else:
raise ValueError(
f"{file_path}: Could not parse import line: {line.strip()}"
)
# Categorize and Validate
stdlib_group: list[dict[str, any]] = []
external_group: list[dict[str, any]] = []
aliased_group: list[dict[str, any]] = []
blank_group: list[dict[str, any]] = []
errors: list[str] = []
for imp in parsed_imports:
path: str = imp["path"]
alias: str | None = imp["alias"]
# Validation: Protobufs must be aliased and match preferred alias
preferred_alias = known_protos.get(path)
if preferred_alias:
if alias is None:
errors.append(
f"Protobuf import must be aliased: {path} (preferred: {preferred_alias})"
)
elif alias != "_" and alias != preferred_alias:
# Allow grpc variant if preferred ends in pb
is_grpc_variant = preferred_alias.endswith("pb") and alias == (
preferred_alias[:-2] + "grpc"
)
if not is_grpc_variant:
msg = f"Invalid alias '{alias}' for proto '{path}'; preferred is '{preferred_alias}'"
if preferred_alias.endswith("pb"):
msg += f" (or '{preferred_alias[:-2] + 'grpc'}' if appropriate)"
errors.append(msg)
# Categorization
if alias == "_":
blank_group.append(imp)
elif alias:
aliased_group.append(imp)
elif is_stdlib(path):
stdlib_group.append(imp)
else:
external_group.append(imp)
if errors:
return None, errors
# Build the new import block
def format_group(group: list[dict[str, any]]) -> list[str]:
res: list[str] = []
for imp in group:
res.extend(imp["header_comments"])
line = "\t"
if imp["alias"]:
line += imp["alias"] + " "
line += f'"{imp["path"]}"'
if imp["comment"]:
line += " " + imp["comment"]
res.append(line + "\n")
return res
new_import_block: list[str] = []
if stdlib_group:
new_import_block.extend(format_group(stdlib_group))
if external_group:
if new_import_block:
new_import_block.append("\n")
new_import_block.extend(format_group(external_group))
if aliased_group:
if new_import_block:
new_import_block.append("\n")
new_import_block.extend(format_group(aliased_group))
if blank_group:
if new_import_block:
new_import_block.append("\n")
new_import_block.extend(format_group(blank_group))
# Append any leftover comments that were at the end of the block
if current_header_comments:
if new_import_block:
new_import_block.append("\n")
new_import_block.extend(current_header_comments)
return new_import_block, []
def process_file(file_path: Path, write: bool = False) -> bool:
    """Process a single Go file, validating and optionally fixing its imports.

    Only the first multi-line ``import (`` ... ``)`` block is examined; both
    the opening and the closing line must start at column 0.

    Args:
        file_path: Path object to the .go file.
        write: If True, writes the fixed import block back to the file.

    Returns:
        bool: True if the file is skipped (not a ``.go`` file, inside an
        excluded directory, or without a multi-line import block), is already
        compliant, or was rewritten. False if the file cannot be read, an
        import line cannot be parsed, validation errors were reported, or the
        block needs regrouping and ``write`` is False.
    """
    if file_path.suffix != ".go":
        return True

    # Safety check: skip third_party and internal tool dirs.
    exclude_dirs = {"third_party", "for_context_only", ".cipd", ".git"}
    if any(ex in file_path.parts for ex in exclude_dirs):
        return True

    try:
        # Go source files are UTF-8; do not depend on the locale's encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
    except (OSError, ValueError) as e:
        # ValueError covers undecodable bytes (UnicodeDecodeError) and
        # malformed paths.
        print(f"Error reading {file_path}: {e}")
        return False

    # Find the first multi-line import block.
    start_index = next(
        (i for i, line in enumerate(lines) if line.startswith("import (")),
        -1,
    )
    if start_index == -1:
        # No multi-line import block found. Skip single imports or files
        # without imports.
        return True
    end_index = next(
        (
            j
            for j in range(start_index + 1, len(lines))
            if lines[j].startswith(")")
        ),
        -1,
    )
    if end_index == -1:
        # Unterminated block: nothing that can be rewritten safely.
        return True
    import_lines: list[str] = lines[start_index + 1 : end_index]

    try:
        new_import_block, errors = transform_import_block(
            import_lines, str(file_path)
        )
    except ValueError as e:
        print(e, file=sys.stderr)
        return False
    if errors:
        for err in errors:
            print(f"{file_path}: {err}", file=sys.stderr)
        return False

    # Exact text comparison: any difference from the regenerated block,
    # including indentation or blank lines, counts as a change.
    if "".join(import_lines) == "".join(new_import_block):
        return True

    if write:
        new_file_content = (
            lines[: start_index + 1] + new_import_block + lines[end_index:]
        )
        with open(file_path, "w", encoding="utf-8") as f:
            f.writelines(new_file_content)
        return True

    print(
        f"{file_path}: Import block is not correctly grouped. Run with --write to fix."
    )
    return False
def main(argv: list[str] | None = None) -> None:
"""Main entry point for the tool.
Args:
argv: Optional list of command-line arguments. If None, sys.argv[1:] is used.
"""
parser = argparse.ArgumentParser(
description="Fix and validate Go import grouping."
)
parser.add_argument(
"files",
nargs="+",
type=Path,
help="Go files to process.",
)
parser.add_argument(
"-w", "--write", action="store_true", help="Write changes to files"
)
args = parser.parse_args(argv)
success = True
for file_path in args.files:
if not process_file(file_path, write=args.write):
success = False
if not success:
sys.exit(1)
if __name__ == "__main__":
main(sys.argv[1:])