# Copyright 2018 The Fuchsia Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Recipe module that wraps the testsharder tool, which searches a Fuchsia build for
test specifications and groups them into shards.

The tool assigns a unique name to each produced shard.

Testsharder tool:
https://fuchsia.googlesource.com/fuchsia/+/refs/heads/main/tools/integration/testsharder/cmd/
"""

import re

import attr

from recipe_engine import recipe_api


@attr.s(frozen=True)
class Shard:
    """An immutable group of tests that all run in one shared environment."""

    # Name of the shard.
    name = attr.ib(type=str)

    # The tests in the shard, each one a dict; any iterable given here is
    # stored as a tuple.
    tests = attr.ib(converter=tuple)

    # Swarming dimensions as a name -> value dict. Every value is converted to
    # a string on construction.
    _dimensions = attr.ib(
        type=dict,
        converter=lambda dims: {key: str(value) for key, value in dims.items()},
    )

    # Execution timeout for the task that runs this shard.
    timeout_secs = attr.ib(type=int, default=0)

    # Service account to attach to the swarming task that runs this shard.
    service_account = attr.ib(type=str, default="")

    # Whether to netboot rather than pave before running the tests.
    netboot = attr.ib(type=bool, default=False)

    # Runtime dependencies of the shard's tests, as paths relative to the
    # build directory. Only meaningful for Linux and Mac tests.
    deps = attr.ib(factory=tuple, converter=tuple)

    # Path to the package repository specific to this shard.
    pkg_repo = attr.ib(type=str, default="")

    # Images to override in this shard, as a dict.
    image_overrides = attr.ib(factory=dict)

    # Summary of this shard's test results. Only populated when every test in
    # the shard was skipped.
    summary = attr.ib(factory=dict)

    @classmethod
    def from_jsonish(cls, jsond):
        """Builds a Shard from a JSON-compatible Python dict.

        The "name", "tests" and "environment" keys (and "dimensions" within
        the environment) are required; every other field falls back to its
        default when absent.
        """
        shard_name = jsond["name"]
        shard_tests = jsond["tests"]
        environment = jsond["environment"]
        return cls(
            name=shard_name,
            tests=shard_tests,
            dimensions=environment["dimensions"],
            service_account=environment.get("service_account", ""),
            deps=jsond.get("deps", ()),
            netboot=environment.get("netboot", False),
            timeout_secs=jsond.get("timeout_secs", 0),
            pkg_repo=jsond.get("pkg_repo", ""),
            image_overrides=jsond.get("image_overrides", {}),
            summary=jsond.get("summary", {}),
        )

    @property
    def dimensions(self):
        """A copy of the shard's Swarming dimensions (dict[str]str)."""
        return dict(self._dimensions)

    @property
    def os(self):
        """The "os" dimension of the shard, or None if it is not set."""
        return self._dimensions.get("os")

    @property
    def device_type(self):
        """The "device_type" dimension of the shard, or None if it is not set."""
        return self._dimensions.get("device_type")

    @property
    def targets_fuchsia(self):
        """Whether the shard's tests are meant to run on or against fuchsia.

        True for every shard whose OS is neither "Linux" nor "Mac".
        """
        return self.os not in ("Linux", "Mac")

    @property
    def should_skip(self):
        """Whether the shard should be skipped.

        True exactly when the shard's summary has a non-empty "tests" entry.
        """
        return bool(self.summary.get("tests"))


@attr.s(frozen=True)
class TestModifier:
    """An immutable set of per-test adjustments, keyed by test name."""

    # Name of the test.
    name = attr.ib(type=str)
    # OS for the modifier; a falsy value (e.g. None) is stored as "".
    os = attr.ib(type=str, converter=lambda value: value or "", default="")
    # Number of runs; a falsy value (e.g. None or "") is stored as 0 and
    # anything else is converted with int().
    total_runs = attr.ib(
        type=int, converter=lambda value: int(value) if value else 0, default=0
    )
    # Whether the test is labeled as affected.
    affected = attr.ib(type=bool, default=False)
    # Maximum number of attempts for the test.
    max_attempts = attr.ib(type=int, default=1)

    def render_to_jsonish(self):
        """Returns the modifier's fields as a plain dict."""
        return attr.asdict(self)

    def merge(self, mod):
        """Returns a new TestModifier combining the fields of self and mod.

        For `total_runs` and `max_attempts` the larger of the two values is
        kept. Every other field must not be set to two different truthy
        values across the modifiers; whichever value is truthy is kept.
        """
        merged = {}
        for attribute in attr.fields(TestModifier):
            key = attribute.name
            mine = getattr(self, key)
            theirs = getattr(mod, key)
            if key in ("total_runs", "max_attempts"):
                # For total_runs, 0 means "run as many times as will fit in
                # one shard" while a negative value means "do not multiply at
                # all", so max() correctly prefers 0 over a negative value.
                merged[key] = max(mine, theirs)
            else:
                assert (
                    not mine or not theirs or mine == theirs
                ), "multiple modifiers have conflicting values for the same test"
                merged[key] = mine or theirs
        return TestModifier(**merged)


class TestsharderApi(recipe_api.RecipeApi):
    """Module for interacting with the Testsharder tool.

    The testsharder tool accepts a set of test specifications and
    produces a file containing shards of execution.
    """

    Shard = Shard
    TestModifier = TestModifier

    def execute(
        self,
        step_name,
        testsharder_path,
        build_dir,
        max_shard_size=None,
        target_duration_secs=0,
        per_test_timeout_secs=0,
        max_shards_per_env=0,
        modifiers=None,
        output_file=None,
        tags=(),
        use_affected_tests=False,
        affected_tests=None,
        affected_tests_multiply_threshold=0,
        affected_tests_max_attempts=0,
        affected_only=False,
        ffx_deps=False,
        image_deps=False,
        hermetic_deps=False,
        pave=False,
        disabled_device_types=(),
        skip_unaffected_tests=False,
        per_shard_package_repos=False,
        cache_test_packages=False,
    ):
        """Executes the testsharder tool.

        Args:
          step_name (str): name of the step.
          testsharder_path (Path): path to the testsharder tool.
          build_dir (Path): path to a Fuchsia build directory root, in
            which GN has been run (ninja need not have been executed).
          max_shard_size (int or None): Additional shards will be created if
            needed to keep the number of tests per shard <= this number.
          target_duration_secs (int): If >0, testsharder will try to produce
            shards of approximately this duration.
          per_test_timeout_secs (int): Any test that executes for longer than
            this will be considered failed.
          max_shards_per_env (int): Testsharder will limit each environment to
            this many shards, regardless of the values of max_shard_size and
            target_duration_secs. If 0, testsharder will use its hardcoded
            default max; if <0, no max will be set.
          modifiers (list(TestModifier)): the test modifiers specifying
            tests to run multiple times or to label as affected tests; if
            supplied, new shards will be created to run them. Modifiers that
            share a test name are merged before being passed to the tool.
          output_file (Path): optional file path to leak the tool's JSON
            output to.
          tags (list(str)): tags on which to filter test specs.
          use_affected_tests (bool): whether or not to pass affected tests
            flags to testsharder.
          affected_tests (list(str) or None): names of affected tests. None
            is treated the same as an empty list.
          affected_tests_multiply_threshold (int): if there are <= this many
            affected tests, they'll be considered for multiplication.
          affected_tests_max_attempts (int): max attempts for affected tests
            which are not multiplied.
          affected_only (bool): whether to only create shards for affected
            tests.
          ffx_deps (bool): whether to include ffx deps in the shard deps.
          image_deps (bool): whether to include image deps in the shard deps.
          hermetic_deps (bool): whether to include all deps in the shard deps.
          pave (bool): whether the shards should pave or netboot fuchsia.
          disabled_device_types (list(str)): the device types to not return
            shards for.
          skip_unaffected_tests (bool): whether to skip running unaffected
            tests.
          per_shard_package_repos (bool): whether to generate per-shard
            package repos.
          cache_test_packages (bool): whether to cache test packages in CAS.

        Returns:
          A list of Shards, each representing one test shard. Shards whose
          device type is in `disabled_device_types` are left out.

        Raises:
          StepFailure: if testsharder exits with a positive return code, or if
            any shard it produced has no tests.
        """
        # Extra logs to attach to the step once it has run.
        logs = {}
        cmd = [
            testsharder_path,
            "-build-dir",
            build_dir,
            "-output-file",
            self.m.json.output(leak_to=output_file),
        ]
        if max_shard_size:
            cmd += ["-max-shard-size", max_shard_size]
        if target_duration_secs:
            cmd += ["-target-duration-secs", target_duration_secs]
        if per_test_timeout_secs:
            cmd += ["-per-test-timeout-secs", per_test_timeout_secs]
        if max_shards_per_env:
            # TODO(10404): this conditional is only necessary for backwards
            # compatibility; by only setting this flag when we want to use a value
            # that's not the testsharder default, we can avoid breaking builds that
            # run against old versions of fuchsia.git from before this flag was
            # implemented. Once backwards compatibility is no longer needed and/or
            # recipe versioning is complete, we can remove the default flag value
            # from testsharder into the specs, and unconditionally set this flag.
            cmd += ["-max-shards-per-env", max_shards_per_env]
        if modifiers:
            input_json = [
                m.render_to_jsonish() for m in self._combine_modifiers(modifiers)
            ]
            cmd += ["-modifiers", self.m.json.input(input_json)]
            logs["test_modifiers.json"] = self.m.json.dumps(input_json, indent=2)
        for tag in tags:
            cmd += ["-tag", tag]

        if use_affected_tests:
            # `affected_tests` defaults to None; treat that as "no affected
            # tests" rather than letting str.join() raise a TypeError.
            affected_tests_text = "\n".join(affected_tests or ()) + "\n"
            cmd += [
                "-affected-tests",
                self.m.raw_io.input(affected_tests_text),
                "-affected-tests-max-attempts",
                affected_tests_max_attempts,
                "-affected-tests-multiply-threshold",
                affected_tests_multiply_threshold,
            ]
            logs["affected_tests.txt"] = affected_tests_text

        # Boolean switches, listed in the order they are appended to the
        # command line.
        switches = (
            (affected_only, "-affected-only"),
            (skip_unaffected_tests, "-skip-unaffected"),
            (ffx_deps, "-ffx-deps"),
            (image_deps, "-image-deps"),
            (hermetic_deps, "-hermetic-deps"),
            (per_shard_package_repos, "-per-shard-package-repos"),
            (cache_test_packages, "-cache-test-packages"),
            (pave, "-pave"),
        )
        cmd.extend(flag for enabled, flag in switches if enabled)

        # Accept any return code so that a failure can be reported below with
        # the tool's stderr in the error message.
        step = self.m.step(
            step_name,
            cmd,
            stderr=self.m.raw_io.output_text(add_output_log="on_failure"),
            ok_ret="any",
        )
        step.presentation.logs.update(logs)
        if step.retcode > 0:
            step.presentation.status = self.m.step.FAILURE
            if not step.stderr.strip():  # pragma: no cover
                # In the unlikely event that there are no logs to explain the
                # failure, emit a regular step failure error message.
                self.m.step.raise_on_failure(step)
            failure_summary = self.m.buildbucket_util.summary_message(
                step.stderr,
                truncation_message=(
                    f"(truncated, see {step_name} stderr for details)"
                ),
            )
            raise self.m.step.StepFailure(
                f"testsharder failed:\n\n{failure_summary}"
            )

        shards = []
        for s in step.json.output:
            shard = Shard.from_jsonish(s)
            # An empty shard is an error even if its device type is disabled,
            # so this check comes before the device type filter.
            if not shard.tests:
                raise self.m.step.StepFailure(f"shard {shard.name} has no tests to run")
            if disabled_device_types and shard.device_type in disabled_device_types:
                continue
            shards.append(shard)
        return shards

    def _combine_modifiers(self, modifiers):
        """Combine modifiers so there is one per test name."""
        mod_map = {}
        for m in modifiers:
            if m.name in mod_map:
                mod_map[m.name] = mod_map[m.name].merge(m)
            else:
                mod_map[m.name] = m
        return sorted(mod_map.values(), key=lambda m: m.name)

    def affected_test_modifiers(self, tests, max_attempts):
        """Returns TestModifiers for the given tests with the affected property set.

        The total_runs should be set to a negative number so that the test
        doesn't get counted as a multiplier.
        """
        return [
            TestModifier(
                name=test, total_runs=-1, affected=True, max_attempts=max_attempts
            )
            for test in tests
        ]

    def should_run_all_tests(self, commit_message):
        """Extracts the value of the "Run-All-Tests" footer.

        Args:
          commit_message (str): The commit message in which to search for the footer.

        Returns:
          True if we Run-All-Tests is set, False if not.
        """
        run_all_tests_re = r"(?i)run-all-tests:\s*(true|false)"
        match = re.search(run_all_tests_re, commit_message, re.MULTILINE)
        if match and match.group(1).lower() == "true":
            return True
        return False

    def extract_multipliers(self, commit_message):
        """Extracts the value of the "Multiply/MULTIPLY" footer.

        We support two multiply syntaxes:
        1. Raw JSON:
          Multiply: `[
            {
              "name": "test_name",
              "os": "linux",
              "total_runs": 123
            }
          ]`
        2. "Friendly":
          Multiply: testsharder_tests (linux): 123, foo_tests (fuchsia): 5

        Args:
          commit_message (str): The commit message in which to search for the footer.

        Returns:
          A list of `TestModifier`, which will be empty if the commit message
          doesn't contain the footer.
        """
        multipliers = []

        prefix = r"^\s*(MULTIPLY|Multiply)\s*[:=]\s*"
        json_multiplier_regex = prefix + r"[\s\n]*`(?P<json>(.|\n)*?)`"
        json_match = re.search(json_multiplier_regex, commit_message, re.MULTILINE)
        if json_match:
            multiplier_json = self.m.jsonutil.permissive_loads(json_match.group("json"))
            for raw_multiplier in multiplier_json:
                try:
                    multiplier = TestModifier(**raw_multiplier)
                except (ValueError, TypeError):
                    raise self.m.step.StepFailure(
                        f"invalid test multiplier: {self.m.json.dumps(raw_multiplier)}"
                    )
                multipliers.append(multiplier)
            return multipliers

        friendly_multiplier_regex = re.compile(
            r"""
        # Required name. It cannot contain commas, spaces, or parentheses. The
        # question mark makes it non-greedy to prevent the colon preceding the
        # optional run count from being considered part of the name.
        (?P<name>[^,\(\s]+?)
        \s*
        # Optional "os" in parentheses
        (
          \(
            \s*
            (?P<os>\w*)
            \s*
          \)
        )?
        \s*
        # Optional "total_runs" integer, preceded by a colon and any number of
        # spaces.
        (
          :\s*
          (?P<total_runs>\d+)
        )?
        """,
            re.VERBOSE,
        )

        for line in commit_message.splitlines():
            line = line.strip()
            footer_match = re.match(prefix, line)
            if not footer_match:
                continue

            raw_multipliers = line[footer_match.end() :].strip(" ,").split(",")
            for raw_multiplier in raw_multipliers:
                raw_multiplier = raw_multiplier.strip()
                match = friendly_multiplier_regex.fullmatch(raw_multiplier)
                if not match:
                    raise self.m.step.StepFailure(
                        f"invalid multiplier {raw_multiplier!r}"
                    )
                multipliers.append(TestModifier(**match.groupdict()))

        return multipliers
