# Copyright 2019 The Fuchsia Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Launch and retry swarming jobs until they pass or we hit max attempts."""

import itertools
import traceback

import attr
from recipe_engine import recipe_api

# Attempts allowed per task when neither the task itself (Task.max_attempts)
# nor the run_tasks() caller (max_attempts argument) specifies a limit.
DEFAULT_MAX_ATTEMPTS = 2
# Duration string handed to api.swarming.collect() when the run_tasks() caller
# does not supply collect_timeout.
DEFAULT_COLLECT_TIMEOUT = '5m'


@attr.s
class Attempt(object):
  """A single launch of a task, identified by its swarming task id.

  Construction requires index and task_id plus at least one of host or
  task_ui_link; when only host is given, task_ui_link is derived from host
  and task_id.
  """

  # Number of prior attempts (so the first attempt has index 0).
  index = attr.ib(type=int)
  task_id = attr.ib(type=str)
  host = attr.ib(type=str, default=None)
  task_ui_link = attr.ib(type=str, default=None)
  # api.swarming.TaskResult from api.swarming.collect() call. Stays None
  # until the attempt has completed.
  result = attr.ib(default=None)
  # Set by overrides of Task.process_result() to flag that something failed
  # inside the task even though the task may have passed at the swarming
  # level. Any non-empty value makes the attempt count as failed.
  failure_reason = attr.ib(type=str, default='')
  # Neither of the next two attributes is written by code in this module;
  # presumably Task subclasses populate them -- confirm against callers.
  task_outputs_link = attr.ib(type=str, default=None)
  logs = attr.ib(type=dict, default=attr.Factory(dict))

  def __attrs_post_init__(self):
    """Validate the link-related fields and fill in task_ui_link if missing."""
    # The led module supplies host + id while the swarming module supplies
    # link + id. The id is mandatory (no default above); one of the other
    # two must be present as well.
    assert self.host or self.task_ui_link
    if self.task_ui_link:
      return
    self.task_ui_link = 'https://%s/task?id=%s' % (self.host, self.task_id)

  @property
  def name(self):
    """Short label for this attempt, e.g. 'attempt 0'."""
    return 'attempt %d' % self.index

  @property
  def in_progress(self):
    """True while no result has been recorded for this attempt."""
    return self.result is None

  # TODO(mohrr) add hook for pass/fail beyond swarming task level.
  # In some cases may need to examine isolated output to determine pass/fail.
  @property
  def success(self):
    """True only if a result exists, analyzes cleanly, and nothing was flagged.

    A non-empty failure_reason or a missing result means failure. Otherwise
    result.analyze() decides: a StepFailure raised by it means failure.
    """
    if self.failure_reason or not self.result:
      return False

    try:
      self.result.analyze()
    except recipe_api.StepFailure:
      return False
    else:
      return True


class Task(object):
  """Metadata about tasks, meant to be subclassed.

  Subclasses must define a launch() method. It must launch a task (using
  swarming, led, or something else), create an Attempt object, and append
  it to self.attempts. The attempt object requires task_id and host (or
  task_ui_link) from the swarming or led result and an index equal to the
  prior length of self.attempts.

  In most cases Task.max_attempts should be left alone. If the caller wants
  to ensure a task has a larger or smaller number of max attempts than the
  default for other tasks, set max_attempts to that number. When it is left
  unset (None), RetrySwarmingApi.run_tasks() fills it in before launching.
  """

  def __init__(self, api, name, launch_deadline_time=None):
    """Initializer.

    Args:
      api: recipe_api.RecipeApiPlain object.
      name: str, human readable name of this task
      launch_deadline_time: float or int. If set, will keep launching until
        this time, ignoring max_attempts.
    """
    self._api = api
    self.name = name
    self._launch_deadline_time = launch_deadline_time
    # Attempt objects in launch order; the last entry is the current one.
    self.attempts = []
    self.max_attempts = None

  def process_result(self):
    """Examine the result in the last attempt for failures.

    Subclasses can set self.attempts[-1].failure_reason if they find a
    failure inside self.attempts[-1].result. failure_reason should be a
    short summary of the failure (< 50 chars).

    This is invoked shortly after api.swarming.collect() returns that a
    task completed. It cannot assume the swarming task completed
    successfully.

    This is a no-op here but can be overridden by subclasses.

    Returns:
      None
    """
    pass

  def present_status(self, parent_step_presentation, attempt, **kwargs):
    """Present an Attempt while showing progress in launch/collect step.

    Args:
      parent_step_presentation (StepPresentation): will always be for
        'passed tasks' or 'failed tasks'
      attempt (Attempt): the Attempt to present
      **kwargs (Dict): pass-through arguments for subclasses

    This method will be invoked to show details of an Attempt. This base
    class method just creates a link to the swarming results from the task,
    but subclasses are free to create a step with much more elaborate
    details of results.

    This is only invoked for completed tasks. Identical code is used for
    incomplete tasks, except it's not in Task so it can't be overridden, so
    subclasses don't need to handle incomplete tasks.

    Returns:
      None
    """
    del kwargs  # Unused.
    name = '%s (%s)' % (self.name, attempt.name)
    parent_step_presentation.links[name] = attempt.task_ui_link

  def present_attempt(self, task_step_presentation, attempt, **kwargs):
    """Present an Attempt when summarizing results at the end of the run.

    Args:
      task_step_presentation (StepPresentation): assuming present() was not
        overridden, this will always be for a step titled after the current
        task
      attempt (Attempt): the Attempt to present
      **kwargs (Dict): pass-through arguments for subclasses

    This method will be invoked to show details of an Attempt. This base
    class method just creates a link to the swarming results from the task,
    but subclasses are free to create a step with much more elaborate
    details of results.

    Returns:
      None
    """
    del kwargs  # Unused.
    name = '%s (%s)' % (attempt.name, 'pass' if attempt.success else 'fail')
    task_step_presentation.links[name] = attempt.task_ui_link

  def present(self, **kwargs):
    """Present this Task when summarizing results at the end of the run.

    Args:
      **kwargs (Dict): pass-through arguments for subclasses

    This method will be invoked to show details of this Task. This base
    class method nests with the name of the task and loops over
    Attempts, but subclasses are free to do something more elaborate.

    Returns:
      None
    """
    with self._api.step.nest(self.name) as task_step_presentation:
      for attempt in self.attempts:
        self.present_attempt(task_step_presentation, attempt, **kwargs)

      # The nest step mirrors the outcome of the most recent attempt.
      if self.success:
        task_step_presentation.status = self._api.step.SUCCESS
      else:
        task_step_presentation.status = self._api.step.FAILURE

  def should_launch(self):
    """Decide whether a new attempt of this task should be launched now.

    Returns:
      bool. False while the latest attempt is still in progress. Otherwise,
      if a launch deadline was given, True exactly while the current time
      (api.time.time()) is before that deadline, regardless of max_attempts
      or of whether an attempt already passed. With no deadline, True when
      the task has not passed and fewer than max_attempts attempts exist.
    """
    if self.in_progress:
      return False
    if self._launch_deadline_time:
      return self._api.time.time() < self._launch_deadline_time
    return (not self.success) and (len(self.attempts) < self.max_attempts)

  @property
  def success(self):
    """bool: whether the most recent attempt passed (False if none exist)."""
    # bool() so that a task with no attempts yields False rather than leaking
    # the empty attempts list to callers.
    return bool(self.attempts and self.attempts[-1].success)

  @property
  def in_progress(self):
    """bool: whether the most recent attempt is unfinished (False if none)."""
    # bool() for the same reason as in success.
    return bool(self.attempts and self.attempts[-1].in_progress)

  def launch(self):
    """Launch a new attempt; must be overridden by subclasses.

    Overrides must append the new Attempt to self.attempts and return it.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError(
        'Subclasses must define launch() method.')  # pragma: no cover

  @property
  def result(self):
    """The result of the most recent attempt (None while it is in progress).

    At least one attempt must exist.
    """
    assert self.attempts
    return self.attempts[-1].result


class TriggeredTask(Task):
  """Task whose attempts are started through api.swarming.trigger()."""

  def __init__(self, request, *args, **kwargs):
    """Initializer.

    Args:
      request: the single request passed to api.swarming.trigger() on every
        launch; must be truthy by the time launch() is called.
      *args, **kwargs: forwarded unchanged to Task.__init__().
    """
    super(TriggeredTask, self).__init__(*args, **kwargs)
    self._request = request

  def launch(self):
    """Trigger the stored request and record the resulting Attempt.

    Returns:
      The new Attempt, which has also been appended to self.attempts.
    """
    assert self._request
    triggered = self._api.swarming.trigger(
        'trigger', [self._request], cancel_extra_tasks=True)
    # Exactly one request went in, so exactly one metadata entry is expected.
    assert len(triggered) == 1
    task_metadata = triggered[0]

    attempt_index = len(self.attempts)
    new_attempt = self._api.swarming_retry.Attempt(
        index=attempt_index,
        task_ui_link=task_metadata.task_ui_link,
        task_id=task_metadata.id,
    )
    self.attempts.append(new_attempt)
    return new_attempt


class LedTask(Task):
  """Task whose attempts are started by launching a led job definition."""

  def __init__(self, led_data, *args, **kwargs):
    """Initializer.

    Args:
      led_data: object exposing then('launch'); must be truthy by the time
        launch() is called.
      *args, **kwargs: forwarded unchanged to Task.__init__().
    """
    super(LedTask, self).__init__(*args, **kwargs)
    self._led_data = led_data

  def launch(self):
    """Launch the led job and record the resulting Attempt.

    Returns:
      The new Attempt, which has also been appended to self.attempts.
    """
    assert self._led_data
    launched = self._led_data.then('launch')
    # led reports the swarming host and task id rather than a UI link;
    # Attempt derives the link from these two values.
    swarming_info = launched.result['swarming']

    new_attempt = self._api.swarming_retry.Attempt(
        index=len(self.attempts),
        host=swarming_info['host_name'],
        task_id=swarming_info['task_id'],
    )
    self.attempts.append(new_attempt)
    return new_attempt


class RetrySwarmingApi(recipe_api.RecipeApi):
  """Launch and retry swarming jobs until they pass or we hit max attempts."""

  DEFAULT_MAX_ATTEMPTS = DEFAULT_MAX_ATTEMPTS

  # Re-exported so recipes can reach these classes through the api object.
  Task = Task  # pylint: disable=invalid-name
  LedTask = LedTask  # pylint: disable=invalid-name
  TriggeredTask = TriggeredTask  # pylint: disable=invalid-name

  Attempt = Attempt  # pylint: disable=invalid-name

  def __init__(self, *args, **kwargs):
    super(RetrySwarmingApi, self).__init__(*args, **kwargs)
    # Maps swarming task id -> Task so that collect() results can be matched
    # back to the Task that launched them.
    self._tasks_by_id = {}

  def retry(self, tasks=None):
    """Build a context manager that enforces the run/present/raise ordering.

    The returned object asserts that run_tasks is called before
    present_tasks and raise_failures, and guarantees raise_failures is
    invoked on exit whenever run_tasks completed and raise_failures was not
    already called. Both examples below end up invoking raise_failures, but
    only the first does so explicitly.

    with api.swarming_retry.retry(tasks) as retry:
      retry.run_tasks()
      retry.present_tasks()
      retry.raise_failures()

    with api.swarming_retry.retry(tasks) as retry:
      retry.run_tasks()
      retry.present_tasks()
      # Not needed because this is handled by the context object.
      # retry.raise_failures()

    Args:
      tasks (seq[Task]): tasks to run, present, and check status of

    If tasks is None, retry.tasks must be set before calling
    retry.run_tasks().
    """

    class Context(object):
      """Tracks whether tasks have been run and whether failures were raised."""

      def __init__(self, api, tasks=None):
        self._api = api
        self.tasks = tasks
        self._has_run = False
        self._has_raised = False

      def __enter__(self):
        return self

      def __exit__(self, exc_type, exc_value, exc_traceback):
        del exc_type, exc_value, exc_traceback  # Unused.
        # Nothing to report if the tasks never ran, or if the caller already
        # asked for failures to be raised.
        if not self._has_run or self._has_raised:
          return
        self.raise_failures()

      def run_tasks(self, **kwargs):
        assert not self._has_run
        assert self.tasks
        retry_api = self._api.swarming_retry
        retry_api.run_tasks(self.tasks, **kwargs)
        # Only marked as run once run_tasks() returned without raising.
        self._has_run = True

      def present_tasks(self, **kwargs):
        assert self._has_run
        retry_api = self._api.swarming_retry
        retry_api.present_tasks(self.tasks, **kwargs)

      def raise_failures(self, **kwargs):
        assert self._has_run
        # Flag first: the call below may raise, and __exit__ must not then
        # raise a second time.
        self._has_raised = True
        retry_api = self._api.swarming_retry
        retry_api.raise_failures(self.tasks, **kwargs)

    return Context(api=self.m, tasks=tasks)

  def _is_complete(self, result):
    """Return whether a collect() result describes a finished task."""
    # At the moment results have a bunch of fields set to None if incomplete.
    # On the assumption this will be changed at some point the state is
    # checked explicitly as well.
    if result.name is None:
      return False

    task_state = self.m.swarming.TaskState
    return result.state not in {task_state.RUNNING, task_state.PENDING}

  def _launch(self, tasks):
    """Launch each task inside its own nested step and index it by task id."""
    for task in tasks:
      attempt_index = len(task.attempts)
      step_name = '%s (attempt %d)' % (task.name, attempt_index)
      with self.m.step.nest(step_name) as presentation:
        attempt = task.launch()
        task_id = attempt.task_id
        assert task_id not in self._tasks_by_id
        self._tasks_by_id[task_id] = task
        presentation.links['Swarming task'] = attempt.task_ui_link

  def _launch_and_collect(self, tasks, collect_timeout, collect_output_dir,
                          summary_presentation):
    """Run one round: launch what needs launching, then collect results.

    Every task whose should_launch() is true gets a new attempt. Then all
    in-progress attempts (new and old) are collected, waiting at most
    collect_timeout. Tasks that just completed are processed and presented
    as passed or failed; tasks still running are listed with their swarming
    links.

    This function is mostly stateless. The caller must pass in the
    same arguments for each invocation, and state is kept inside the
    tasks themselves.

    Args:
      tasks (list[Task]): tasks to execute
      collect_timeout (str): duration to wait for tasks to complete
        (format: https://golang.org/pkg/time/#ParseDuration)
      collect_output_dir (Path or None): output directory to pass to
        api.swarming.collect()
      summary_presentation (StepPresentation): where to attach the
        summary for this round of launch/collect

    Returns:
      Number of jobs still running or to be relaunched. As long as this
      is positive the caller should continue calling this method.
    """

    summary_parts = []

    launchable = [t for t in tasks if t.should_launch()]
    if launchable:
      with self.m.step.nest('launch'):
        self._launch(launchable)
      summary_parts.append('%d launched' % len(launchable))

    pending_ids = [t.attempts[-1].task_id for t in tasks if t.in_progress]
    collected = []
    if pending_ids:
      collected = self.m.swarming.collect(
          'collect',
          pending_ids,
          timeout=collect_timeout,
          output_dir=collect_output_dir,
      )

    # Sometimes collect doesn't respond with all the requested task ids.
    # Those tasks are ignored for this round, but they count toward the
    # return value so the caller invokes this method again, hoping collect
    # reports on them next time. (This also makes testing this module much
    # easier.)
    num_missed_by_collect = len(pending_ids) - len(collected)

    finished = []
    unfinished = []
    for result in collected:
      task = self._tasks_by_id[result.id]
      if not self._is_complete(result):
        unfinished.append(task)
        continue
      # Recording the result is what moves the attempt out of in-progress.
      task.attempts[-1].result = result
      finished.append(task)

    passed = []
    failed = []
    if finished:
      with self.m.step.nest('process results', status='last'):
        for task in finished:
          try:
            task.process_result()
          except recipe_api.StepFailure as failure:
            # Surface the exception as its own step and mark the attempt
            # failed instead of aborting the whole round.
            error_step = self.m.step('exception', None)
            error_step.presentation.step_summary_text = str(failure)
            error_step.presentation.logs['exception'] = (
                traceback.format_exc().splitlines())
            task.attempts[-1].failure_reason = (
                'exception during result processing')

          bucket = passed if task.success else failed
          bucket.append(task)

        # Add passing step at end so parent step always passes (since
        # parent step has status='last'). Any errors will be shown when
        # presenting results.
        self.m.step('always pass', None)

    for label, group in (('passed', passed), ('failed', failed)):
      if not group:
        continue
      with self.m.step.nest('%s tasks' % label) as list_step:
        for task in group:
          task.present_status(list_step, task.attempts[-1])
      summary_parts.append('%d %s' % (len(group), label))

    if unfinished:
      with self.m.step.nest('incomplete tasks') as list_step_presentation:
        for task in unfinished:
          # Always do minimal presentation of in-progress Attempts.
          latest = task.attempts[-1]
          link_name = '%s (%s)' % (task.name, latest.name)
          list_step_presentation.links[link_name] = latest.task_ui_link
      summary_parts.append('%d incomplete' % len(unfinished))

    num_to_relaunch = len([t for t in tasks if t.should_launch()])

    # Tasks that are finished, did not pass, and have used up their attempts.
    exhausted = [
        t for t in tasks
        if not (t.success or
                t.in_progress or
                len(t.attempts) < t.max_attempts)
    ]  # yapf: disable
    if exhausted:
      summary_parts.append('%d failed after max attempts' % len(exhausted))

    summary_presentation.step_summary_text = ', '.join(summary_parts)

    return num_to_relaunch + len(unfinished) + num_missed_by_collect

  def run_tasks(self,
                tasks,
                max_attempts=0,
                collect_timeout=None,
                collect_output_dir=None):
    """Launch all tasks and keep retrying until no work remains.

    Rounds of launch/collect are repeated until a round reports that nothing
    is running, waiting to be relaunched, or missing from collect.

    Args:
      tasks (seq[Task]): tasks to execute
      max_attempts (int): maximum number of attempts per task (0 means
        DEFAULT_MAX_ATTEMPTS); only applied to tasks that have not set
        their own max_attempts
      collect_timeout (str or None): duration to wait for tasks to complete
        (format: https://golang.org/pkg/time/#ParseDuration)
      collect_output_dir (Path or None): output directory to pass to
        api.swarming.collect()

    Returns:
      Number of tasks that did not pass.
    """

    attempts_limit = max_attempts or DEFAULT_MAX_ATTEMPTS

    # If there's only one task there's no need to have a timeout. The
    # timeout allows us to check on the status of one task while another
    # is still running, and that's not relevant in this case.
    if len(tasks) == 1:
      timeout = None
    else:
      timeout = collect_timeout or DEFAULT_COLLECT_TIMEOUT

    for task in tasks:
      if not task.max_attempts:
        task.max_attempts = attempts_limit

    with self.m.step.nest('launch/collect'), self.m.context(infra_steps=True):
      for round_index in itertools.count():
        with self.m.step.nest(str(round_index)) as presentation:
          remaining = self._launch_and_collect(
              tasks=tasks,
              collect_timeout=timeout,
              collect_output_dir=collect_output_dir,
              summary_presentation=presentation)
        if not remaining:
          break

    return len([t for t in tasks if not t.success])

  def present_tasks(self, tasks, **kwargs):
    """Present results as steps.

    Sort tasks by pass/fail status and create step data displaying that
    status. All passes are grouped under one step and all failures under
    another. Passing tasks that needed more than one attempt are listed a
    second time as flakes.

    Args:
      tasks (seq[Task]): tasks to examine
      **kwargs: extra args handled by subclasses of Task
    """

    # TODO(mohrr) add hooks to include task-specific data beyond pass/fail.
    passed = [t for t in tasks if t.success]
    failed = [t for t in tasks if not t.success]

    with self.m.step.nest('passes') as step_presentation:
      for task in passed:
        task.present(category='passes', **kwargs)
      step_presentation.step_summary_text = '%d passed' % len(passed)

    flaked = [t for t in passed if len(t.attempts) > 1]
    with self.m.step.nest('flakes') as step_presentation:
      for task in flaked:
        task.present(category='flakes', **kwargs)
      step_presentation.step_summary_text = '%d flaked' % len(flaked)

    with self.m.step.nest('failures') as step_presentation:
      for task in failed:
        task.present(category='failures', **kwargs)
      step_presentation.step_summary_text = '%d failed' % len(failed)

    if not failed:
      self.m.step('all tasks passed', None)  # pragma: no cover

  def raise_failures(self, tasks):
    """Raise an exception if any tasks failed.

    Args:
      tasks (seq[Task]): tasks to examine

    Raises:
      StepFailure: if at least one task did not pass; the message lists the
        names of the failed tasks.
    """
    failed_names = [t.name for t in tasks if not t.success]
    if not failed_names:
      return
    raise self.m.step.StepFailure('task(s) failed: %s' %
                                  ', '.join(failed_names))
