# Copyright 2019 The Fuchsia Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

from recipe_engine.recipe_api import Property

from RECIPE_MODULES.fuchsia.swarming_retry import api as swarming_retry_api

# Recipe modules this example depends on. They are reached through the `api`
# object handed to RunSteps/GenTests (e.g. api.swarming_retry, api.led,
# api.swarming, api.step, api.time, api.status_check, api.properties).
DEPS = [
    'fuchsia/status_check',
    'fuchsia/swarming_retry',
    'recipe_engine/json',
    'recipe_engine/led',
    'recipe_engine/properties',
    'recipe_engine/step',
    'recipe_engine/swarming',
    'recipe_engine/time',
]

# Input properties; each key is passed to RunSteps as the argument of the same
# name.
PROPERTIES = {
    # False: a single task named 'task'. True: six tasks named pass, flake,
    # fail, pass_long, flake_long and fail_long.
    'full':
        Property(
            kind=bool,
            default=False,
            help='Whether to run six tasks or just one.',
        ),

    # Looked up in the task_types dict in RunSteps; any other value raises
    # KeyError there.
    'task_type':
        Property(
            kind=str,
            default='test',
            help='Type of tasks to create. Options: '
            '"test", "internal_failure", "raising", "led", "triggered".',
        ),

    # Passed to retry.run_tasks(max_attempts=...).
    'max_attempts':
        Property(
            kind=int,
            default=2,
            help='Overall max attempts.',
        ),

    # RunSteps only applies this when it is truthy, so None and 0 both leave
    # the overall max_attempts in effect for the last task.
    'last_task_max_attempts':
        Property(
            kind=int,
            default=None,
            help='Override the overall max attempts by setting on '
            'Task.max_attempts. Only set on last task.',
        ),
    # NOTE: RunSteps only forwards this in the single-task (full=False)
    # configuration; it is ignored when full=True.
    'launch_deadline_time':
        Property(
            kind=float,
            default=None,
            help='Passed through to swarming_retry.Task.__init__().'
        ),
}  # yapf: disable


class Task(swarming_retry_api.Task):
  """Fake swarming_retry Task used to exercise the retry machinery.

  launch() does not trigger any real swarming task. It fabricates an Attempt
  with the next sequential task ID and emits a placeholder 'launch <name>'
  step. The canned results for those IDs come from the collect_data() test
  data in GenTests.
  """

  def __init__(self, initial_task_id, *args, **kwargs):
    """Construct a Task object.

    Args:
      initial_task_id (int or str): task ID for the first attempt; later
        attempts use consecutive IDs. Anything int() accepts works, because
        the ID is incremented as an int but exposed as a str on each Attempt.
      *args, **kwargs: forwarded unchanged to swarming_retry_api.Task.
    """

    super(Task, self).__init__(*args, **kwargs)
    self._next_task_id = int(initial_task_id)

  def launch(self):
    """Record a fake launch of this task and return the new Attempt.

    Consumes the next task ID, emits an empty 'launch <name>' step whose
    summary text is that ID, and appends the Attempt to self.attempts.
    """
    task_id = self._next_task_id
    self._next_task_id += 1

    kwargs = {
        # Position of this attempt within self.attempts.
        'index': len(self.attempts),
        'task_id': str(task_id),
    }

    # Alternate between passing `host` and passing an explicit `task_ui_link`
    # so both Attempt construction paths (including Attempt.task_ui_link) get
    # coverage: odd IDs use the host, even IDs use the explicit link.
    if task_id % 2:
      kwargs['host'] = 'testhost'
    else:
      kwargs['task_ui_link'] = 'https://testhost/task?id=%s' % task_id

    attempt = self._api.swarming_retry.Attempt(**kwargs)
    step = self._api.step('launch %s' % self.name, None)
    step.presentation.step_summary_text = attempt.task_id

    self.attempts.append(attempt)
    return attempt


class InternalFailureTask(Task):
  """Fake task that tags its most recent attempt as an internal failure."""

  def process_result(self):
    # Only the newest attempt is touched; earlier attempts keep their state.
    latest = self.attempts[-1]
    latest.failure_reason = 'internal failure'


class RaisingTask(Task):
  """Fake task whose result processing always raises a StepFailure."""

  def process_result(self):
    error = self._api.step.StepFailure('something failed')
    raise error


class LedTask(swarming_retry_api.LedTask):
  """LedTask built from `led get-build builder`."""

  def __init__(self, initial_task_id, api, **kwargs):
    # Accepted so RunSteps can build every task type with the same keyword
    # arguments; this task type does not use it.
    del initial_task_id
    led_data = api.led('get-build', 'builder')
    super(LedTask, self).__init__(led_data, api=api, **kwargs)


class TriggeredTask(swarming_retry_api.TriggeredTask):
  """TriggeredTask wrapping a minimal swarming request named after the task."""

  def __init__(self, api, name, initial_task_id, **kwargs):
    # Accepted for signature parity with the other task types; unused here.
    del initial_task_id

    base_request = api.swarming.task_request().with_name(name)
    first_slice = base_request[0].with_dimensions(
        pool='pool', device_type='device_type')
    request = base_request.with_slice(0, first_slice)

    super(TriggeredTask, self).__init__(request, api=api, name=name, **kwargs)


# pylint: disable=invalid-name
def RunSteps(api, full, task_type, max_attempts, last_task_max_attempts,
             launch_deadline_time):
  """Builds the requested fake tasks and runs them through swarming_retry.

  Args:
    api: the recipe API object.
    full (bool): build six named tasks instead of a single 'task'.
    task_type (str): selects the task class; see PROPERTIES for options.
    max_attempts (int): overall attempt limit passed to run_tasks().
    last_task_max_attempts (int or None): if truthy, overrides max_attempts
      on the last task only.
    launch_deadline_time (float or None): forwarded to the task only in the
      single-task (full=False) configuration.
  """
  task_class = {
      'test': Task,
      'internal_failure': InternalFailureTask,
      'raising': RaisingTask,
      'led': LedTask,
      'triggered': TriggeredTask,
  }[task_type]

  if not full:
    tasks = [
        task_class(
            api=api,
            name='task',
            initial_task_id=100,
            launch_deadline_time=launch_deadline_time),
    ]
  else:
    # Names and starting IDs line up with the canned data in GenTests.
    name_and_first_id = (
        ('pass', 100),
        ('flake', 200),
        ('fail', 300),
        ('pass_long', 400),
        ('flake_long', 500),
        ('fail_long', 600),
    )
    tasks = [
        task_class(api=api, name=task_name, initial_task_id=first_id)
        for task_name, first_id in name_and_first_id
    ]

  if last_task_max_attempts:
    tasks[-1].max_attempts = last_task_max_attempts

  with api.swarming_retry.retry(tasks) as retry:
    retry.run_tasks(max_attempts=max_attempts)
    retry.present_tasks()

    # Read Task.result purely so it gets code coverage.
    _ = tasks[0].result

    # raise_failures() is intentionally not called here: leaving the retry
    # context invokes it via __exit__.


def GenTests(api):  # pylint: disable=invalid-name
  """Yields the simulation test cases for this example recipe.

  Each case combines properties for RunSteps with canned swarming_retry test
  data (collect_data / led_data / trigger_data) for the task names and IDs
  that RunSteps creates.
  """

  # All canned swarming_retry data goes through this alias for consistency.
  test_api = api.swarming_retry

  yield (
      api.status_check.test('full_test', status='failure') +
      api.properties(full=True) +

      test_api.collect_data([
          test_api.passed_task('pass', 100),
          test_api.failed_task('flake', 200),
          test_api.failed_task('fail', 300),
      ], iteration=0) +

      test_api.collect_data([
          test_api.passed_task('flake', 201),
          test_api.failed_task('fail', 301),
      ], iteration=1) +

      test_api.collect_data([
          test_api.incomplete_task('pass_long', 400),
          test_api.incomplete_task('flake_long', 500),
          test_api.incomplete_task('fail_long', 600),
      ], iteration=2) +

      test_api.collect_data([], iteration=3) +

      test_api.collect_data([
          test_api.passed_task('pass_long', 400),
      ], iteration=4) +

      test_api.collect_data([
          test_api.failed_task('flake_long', 500),
      ], iteration=5) +

      test_api.collect_data([
          test_api.failed_task('fail_long', 600),
      ], iteration=6) +

      test_api.collect_data([], iteration=7) +
      test_api.collect_data([], iteration=8) +

      test_api.collect_data([
          test_api.passed_task('flake_long', 501),
      ], iteration=9) +

      test_api.collect_data([
          test_api.failed_task('fail_long', 601),
      ], iteration=10)
  )  # yapf: disable

  yield (
      api.status_check.test('timeout_then_pass') +
      api.properties(full=False) +
      test_api.collect_data([test_api.timed_out_task('task', 100)]) +
      test_api.collect_data([test_api.passed_task('task', 101)], iteration=1)
  )  # yapf: disable

  yield (
      api.status_check.test('internal_failure', status='failure') +
      api.properties(full=False, task_type='internal_failure') +
      test_api.collect_data([test_api.passed_task('task', 100)], iteration=0) +
      test_api.collect_data([test_api.passed_task('task', 101)], iteration=1)
  )  # yapf: disable

  yield (
      api.status_check.test('raising_process_results', status='failure') +
      api.properties(full=False, task_type='raising') +
      test_api.collect_data([test_api.passed_task('task', 100)], iteration=0) +
      test_api.collect_data([test_api.passed_task('task', 101)], iteration=1)
  )  # yapf: disable

  yield (
      api.status_check.test('led_task') +
      api.properties(full=False, task_type='led') +
      test_api.led_data('task', 1, iteration=0) +
      test_api.collect_data([test_api.failed_task('task', 1)], iteration=0) +
      test_api.led_data('task', 2, iteration=1) +
      test_api.collect_data([test_api.passed_task('task', 2)], iteration=1)
  )  # yapf: disable

  yield (
      api.status_check.test('led_task_hardcoded_attempt') +
      api.properties(full=False, task_type='led') +
      test_api.led_data('task', 1, attempt=0)
  )  # yapf: disable

  yield (
      api.status_check.test('triggered_task') +
      api.properties(full=False, task_type='triggered') +
      test_api.trigger_data('task', 1) +
      test_api.collect_data([test_api.passed_task('task', 1)], iteration=0)
  )  # yapf: disable

  yield (
      api.status_check.test('max_attempts_three', status='failure') +
      api.properties(full=False, task_type='raising', max_attempts=3) +
      test_api.collect_data([test_api.passed_task('task', 100)], iteration=0) +
      test_api.collect_data([test_api.passed_task('task', 101)], iteration=1) +
      test_api.collect_data([test_api.passed_task('task', 102)], iteration=2)
  )  # yapf: disable

  yield (
      api.status_check.test('last_task_max_attempts_low', status='failure') +
      api.properties(full=False, task_type='raising', max_attempts=3,
                     last_task_max_attempts=1) +
      test_api.collect_data([test_api.passed_task('task', 100)], iteration=0)
  )  # yapf: disable

  yield (
      api.status_check.test('last_task_max_attempts_high', status='failure') +
      api.properties(full=False, task_type='raising', max_attempts=3,
                     last_task_max_attempts=5) +
      test_api.collect_data([test_api.passed_task('task', 100)], iteration=0) +
      test_api.collect_data([test_api.passed_task('task', 101)], iteration=1) +
      test_api.collect_data([test_api.passed_task('task', 102)], iteration=2) +
      test_api.collect_data([test_api.passed_task('task', 103)], iteration=3) +
      test_api.collect_data([test_api.passed_task('task', 104)], iteration=4)
  )  # yapf: disable

  yield (
      api.status_check.test(
          'last_task_max_attempts_high_mixed', status='failure') +
      api.properties(full=True, task_type='raising', max_attempts=1,
                     last_task_max_attempts=5) +

      # Names are misleading for this test. Sorry.
      test_api.collect_data([
          test_api.failed_task('pass', 100),
          test_api.failed_task('flake', 200),
          test_api.failed_task('fail', 300),
          test_api.failed_task('pass_long', 400),
          test_api.failed_task('flake_long', 500),
          test_api.failed_task('fail_long', 600),
      ], iteration=0) +
      test_api.collect_data(
          [test_api.failed_task('fail_long', 601)], iteration=1) +
      test_api.collect_data(
          [test_api.failed_task('fail_long', 602)], iteration=2) +
      test_api.collect_data(
          [test_api.failed_task('fail_long', 603)], iteration=3) +
      test_api.collect_data(
          [test_api.failed_task('fail_long', 604)], iteration=4)
  )  # yapf: disable

  yield (
      api.status_check.test('launch_deadline_time') +
      # Settings such that launch_deadline_time is in the range
      # (time.seed() + time.step(), time.seed() + 2 * time.step()]
      # should result in the task being launched exactly once.
      api.properties(
          full=False, task_type='triggered', launch_deadline_time=1.0) +
      api.time.seed(0) + api.time.step(0.6) +
      test_api.trigger_data('task', 1) +
      test_api.collect_data([test_api.passed_task('task', 1)], iteration=0)
  )  # yapf: disable
