| #!/usr/bin/env python3 |
| # Copyright 2026 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Tool to fix and validate Go import grouping and aliasing. |
| |
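Conventions:
  1. Up to four sections in the import block, separated by a blank line and
     emitted in this order (empty sections are omitted):
       a. Standard library (unaliased).
       b. External and internal packages (unaliased).
       c. Aliased imports (including dot imports).
       d. Blank (_) imports.
  2. Every protobuf package listed in KNOWN_PROTOS must be aliased (a blank
     `_` import is also accepted).
  3. That alias must be the package's preferred alias from KNOWN_PROTOS
     (conventionally ending in 'pb') or, when the preferred alias ends in
     'pb', its 'grpc' variant (e.g. 'repb' or 'regrpc').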
| """ |
| |
import argparse
import re
import sys
from pathlib import Path
from typing import Any
| |
# Known protobuf package paths and their preferred aliases.
# All packages in this map MUST be aliased when imported (a blank `_` import
# is also accepted). transform_import_block accepts either the preferred alias
# below or, when the preferred alias ends in "pb", the same prefix with "grpc"
# in place of "pb" (e.g. "repb" -> "regrpc"). Paths not listed here are not
# alias-checked at all.
KNOWN_PROTOS: dict[str, str] = {
    "google.golang.org/protobuf/types/known/timestamppb": "tspb",
    "google.golang.org/protobuf/types/known/fieldmaskpb": "fmpb",
    "google.golang.org/protobuf/types/known/durationpb": "dpb",
    "google.golang.org/protobuf/types/known/wrapperspb": "wpb",
    "google.golang.org/protobuf/types/known/anypb": "anypb",
    "google.golang.org/protobuf/types/known/emptypb": "emptypb",
    "google.golang.org/genproto/googleapis/devtools/resultstore/v2": "rspb",
    "google.golang.org/genproto/googleapis/devtools/build/v1": "buildpb",
    "google.golang.org/genproto/googleapis/bytestream": "bspb",
    "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2": "repb",
    "github.com/bazelbuild/remote-apis/build/bazel/semver": "svpb",
    "github.com/bazelbuild/rsclient/api/jobstatus": "jspb",
    "github.com/bazelbuild/rsclient/third_party/bazel/build_event_stream": "bespb",
    "github.com/bazelbuild/rsclient/internal/pkg/besproxy/translation/proto": "tracepb",
}
| |
# Regex for a single (already stripped) import spec with an optional alias.
# Matches: alias "path" // comment
# Group 1 (alias): a run of [a-zA-Z0-9_] (which already covers "_"), ".", or
#   the empty string when the import is unaliased.
# Group 2 (path): the quoted import path, without the quotes.
# Group 3 (comment): an optional trailing "//" comment, which must be separated
#   from the path by at least one whitespace character.
# Because group 1 may be empty, this pattern also matches every line that
# IMPORT_LINE_NO_ALIAS_RE matches. "/* */" block comments are not matched.
IMPORT_LINE_WITH_ALIAS_RE: re.Pattern[str] = re.compile(
    r'^([a-zA-Z0-9_]+|[_.]|)\s*"([^"]+)"(?:\s+(//.*))?$'
)

# Regex for an import spec without an alias.
# Matches: "path" // comment
# Group 1 is the path, group 2 the optional trailing "//" comment.
IMPORT_LINE_NO_ALIAS_RE: re.Pattern[str] = re.compile(r'^"([^"]+)"(?:\s+(//.*))?$')
| |
| |
def is_stdlib(path: str) -> bool:
    """Heuristically decides whether an import path is a Go standard-library package.

    A path counts as standard library when its first "/"-separated element
    contains no ".", on the heuristic that non-stdlib module paths usually
    begin with a domain name (e.g. "github.com/..."). The cgo pseudo-package
    "C" is always reported as standard library.

    Args:
        path: The Go import path string.

    Returns:
        bool: True if the path looks like a standard library package; False
        otherwise, including for an empty path.
    """
    if not path:
        return False
    top_level = path.partition("/")[0]
    # "C" is listed explicitly so the cgo case stays obvious to readers.
    return path == "C" or top_level.find(".") < 0
| |
| |
def parse_import_line(line: str) -> dict[str, str | None] | None:
    """Parses a single import spec line from inside an ``import (...)`` block.

    Accepted forms, after stripping surrounding whitespace::

        "path"
        alias "path"
        _ "path"
        . "path"

    Each form may be followed by whitespace and a ``//`` line comment.

    Args:
        line: The raw line string from the Go source file.

    Returns:
        dict: A dictionary with keys 'alias' (None when the import is
        unaliased), 'path', and 'comment' (the trailing ``//`` comment
        including the slashes, or None). Returns None if the line is blank or
        is not an import spec.
    """
    stripped = line.strip()
    if not stripped:
        return None

    # The alias group of IMPORT_LINE_WITH_ALIAS_RE may match the empty string,
    # so this single pattern also covers unaliased imports; a separate
    # no-alias fallback could never be reached.
    match = IMPORT_LINE_WITH_ALIAS_RE.match(stripped)
    if match is None:
        return None

    alias, path, comment = match.groups()
    # An empty alias group means the import had no alias.
    return {"alias": alias or None, "path": path, "comment": comment}
| |
| |
def _parse_import_lines(
    import_lines: list[str], file_path: str
) -> tuple[list[dict[str, Any]], list[str]]:
    """Parses raw import-block lines into import dicts with their leading comments.

    Args:
        import_lines: Raw lines from inside an ``import (...)`` block.
        file_path: File path used only in the error message.

    Returns:
        tuple: (parsed_imports, trailing_comments). Each parsed import is the
        dict from parse_import_line plus a 'header_comments' list holding the
        raw full-line ``//`` comments seen since the previous import.
        trailing_comments holds comment lines that follow the last import.

    Raises:
        ValueError: If a non-blank, non-comment line is not a valid import spec.
    """
    parsed_imports: list[dict[str, Any]] = []
    pending_comments: list[str] = []
    for line in import_lines:
        stripped = line.strip()
        if not stripped:
            # Blank lines are dropped; groups are rebuilt from scratch.
            continue
        if stripped.startswith("//"):
            # Keep the raw line (indentation and newline) to re-emit it verbatim.
            pending_comments.append(line)
            continue
        parsed = parse_import_line(line)
        if not parsed:
            raise ValueError(f"{file_path}: Could not parse import line: {stripped}")
        parsed["header_comments"] = pending_comments
        parsed_imports.append(parsed)
        pending_comments = []
    return parsed_imports, pending_comments


def _proto_alias_error(
    path: str, alias: str | None, known_protos: dict[str, str]
) -> str | None:
    """Returns a validation error for a known proto import, or None if it is acceptable."""
    preferred_alias = known_protos.get(path)
    if not preferred_alias:
        # Not a known proto package: nothing to enforce.
        return None
    if alias is None:
        return f"Protobuf import must be aliased: {path} (preferred: {preferred_alias})"
    if alias in ("_", preferred_alias):
        return None
    # When the preferred alias ends in "pb", the same prefix with "grpc"
    # (e.g. "repb" -> "regrpc") is also accepted.
    grpc_alias = preferred_alias[:-2] + "grpc" if preferred_alias.endswith("pb") else None
    if alias == grpc_alias:
        return None
    msg = f"Invalid alias '{alias}' for proto '{path}'; preferred is '{preferred_alias}'"
    if grpc_alias is not None:
        msg += f" (or '{grpc_alias}' if appropriate)"
    return msg


def _format_group(group: list[dict[str, Any]]) -> list[str]:
    """Renders one group of parsed imports as tab-indented, newline-terminated lines."""
    rendered: list[str] = []
    for imp in group:
        rendered.extend(imp["header_comments"])
        line = "\t"
        if imp["alias"]:
            line += imp["alias"] + " "
        line += f'"{imp["path"]}"'
        if imp["comment"]:
            line += " " + imp["comment"]
        rendered.append(line + "\n")
    return rendered


def transform_import_block(
    import_lines: list[str],
    file_path: str = "<string>",
    known_protos: dict[str, str] | None = None,
) -> tuple[list[str] | None, list[str]]:
    """Transforms a list of raw import lines into the standardized grouped format.

    Imports are placed into up to four groups, in this order and separated by
    a single blank line: unaliased standard library, unaliased
    external/internal, aliased (including dot imports), and blank (``_``)
    imports. The relative order of imports within a group is preserved. Full-line
    comments travel with the import that follows them, and comments after the
    last import are emitted at the end after a blank line.

    Args:
        import_lines: A list of raw strings representing the lines inside an
            import (...) block.
        file_path: Optional file path for error reporting.
        known_protos: Optional mapping of proto paths to preferred aliases.
            Defaults to KNOWN_PROTOS.

    Returns:
        tuple: (new_import_block, errors). new_import_block is a list of
        newline-terminated strings forming the new block, or None if
        validation errors occurred. errors is a list of validation error
        strings (without the file path prefix); it is empty on success.

    Raises:
        ValueError: If a line is neither blank, a ``//`` comment, nor a
            parseable import spec.
    """
    if known_protos is None:
        known_protos = KNOWN_PROTOS

    parsed_imports, trailing_comments = _parse_import_lines(import_lines, file_path)

    stdlib_group: list[dict[str, Any]] = []
    external_group: list[dict[str, Any]] = []
    aliased_group: list[dict[str, Any]] = []
    blank_group: list[dict[str, Any]] = []
    errors: list[str] = []

    for imp in parsed_imports:
        path: str = imp["path"]
        alias: str | None = imp["alias"]

        error = _proto_alias_error(path, alias, known_protos)
        if error:
            errors.append(error)

        if alias == "_":
            blank_group.append(imp)
        elif alias:
            # Any other alias, including ".", goes to the aliased group.
            aliased_group.append(imp)
        elif is_stdlib(path):
            stdlib_group.append(imp)
        else:
            external_group.append(imp)

    if errors:
        return None, errors

    new_import_block: list[str] = []
    for group in (stdlib_group, external_group, aliased_group, blank_group):
        if not group:
            continue
        if new_import_block:
            new_import_block.append("\n")
        new_import_block.extend(_format_group(group))

    # Append any leftover comments that were at the end of the block.
    if trailing_comments:
        if new_import_block:
            new_import_block.append("\n")
        new_import_block.extend(trailing_comments)

    return new_import_block, []
| |
| |
def _normalize_for_comparison(block_lines: list[str]) -> list[str]:
    """Splits a block into lines and collapses each line's whitespace runs.

    This lets an import block be compared with its regenerated form while
    ignoring whitespace-only differences that gofmt is expected to fix, such as
    indentation or extra spaces before a trailing comment. Blank lines are kept
    as empty strings because they separate import groups.
    """
    return [" ".join(line.split()) for line in "".join(block_lines).splitlines()]


def process_file(file_path: Path, write: bool = False) -> bool:
    """Processes a single Go file, validating and optionally fixing its imports.

    Only the first multi-line import block is examined. Its opening
    ``import (`` line and closing ``)`` line must both start at column 0.
    Files with no such block, non-.go files, and files under excluded
    directories are skipped.

    Args:
        file_path: Path object to the .go file.
        write: If True, writes the fixed import block back to the file.

    Returns:
        bool: True if the file is compliant, was fixed, or was skipped. False
        if violations remain or the file could not be read or written.
        Diagnostics are printed to stderr.
    """
    if file_path.suffix != ".go":
        return True

    # Safety check: skip third_party and internal tool dirs.
    exclude_dirs = {"third_party", "for_context_only", ".cipd", ".git"}
    if any(ex in file_path.parts for ex in exclude_dirs):
        return True

    try:
        # Go source is UTF-8; don't depend on the platform's default encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading {file_path}: {e}", file=sys.stderr)
        return False

    start_index = -1
    end_index = -1
    import_lines: list[str] = []

    # Find the first multi-line import block.
    for i, line in enumerate(lines):
        if line.startswith("import ("):
            start_index = i
            for j in range(i + 1, len(lines)):
                if lines[j].startswith(")"):
                    end_index = j
                    break
                import_lines.append(lines[j])
            break

    if start_index == -1 or end_index == -1:
        # No terminated multi-line import block found. Skip single imports or
        # files without imports.
        return True

    try:
        new_import_block, errors = transform_import_block(import_lines, str(file_path))
    except ValueError as e:
        print(e, file=sys.stderr)
        return False

    if errors:
        for err in errors:
            print(f"{file_path}: {err}", file=sys.stderr)
        return False

    # Compare ignoring whitespace-only differences: if only spacing differs,
    # the grouping is already correct and gofmt owns the whitespace.
    if _normalize_for_comparison(import_lines) == _normalize_for_comparison(new_import_block):
        return True

    if not write:
        print(
            f"{file_path}: Import block is not correctly grouped. Run with --write to fix.",
            file=sys.stderr,
        )
        return False

    # Keep the "import (" and ")" lines; replace only what lies between them.
    new_file_content = lines[: start_index + 1] + new_import_block + lines[end_index:]
    try:
        with open(file_path, "w", encoding="utf-8") as f:
            f.writelines(new_file_content)
    except OSError as e:
        print(f"Error writing {file_path}: {e}", file=sys.stderr)
        return False
    return True
| |
| |
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: checks (and with -w/--write, fixes) Go import blocks.

    Every listed file is processed even after an earlier file fails, so all
    violations are reported in one run. Exits with status 1 if any file failed.

    Args:
        argv: Command-line arguments excluding the program name. When None,
            argparse reads sys.argv[1:] itself.
    """
    parser = argparse.ArgumentParser(description="Fix and validate Go import grouping.")
    parser.add_argument("files", nargs="+", type=Path, help="Go files to process.")
    parser.add_argument("-w", "--write", action="store_true", help="Write changes to files")
    options = parser.parse_args(argv)

    # Build the full list first so no file is skipped by short-circuiting.
    outcomes = [process_file(go_file, write=options.write) for go_file in options.files]
    if not all(outcomes):
        sys.exit(1)
| |
| |
if __name__ == "__main__":
    # With no argv, argparse inside main() falls back to sys.argv[1:].
    main()