# Copyright 2019 The Fuchsia Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import time

from recipe_engine.recipe_api import Property

# Recipe modules this example depends on. RunSteps, GenTests and the task
# factories below reach them as attributes of `api` (e.g. api.swarming_retry,
# api.led, api.step, api.properties, api.time).
DEPS = [
    'fuchsia/status_check',
    'fuchsia/swarming_retry',
    'recipe_engine/json',
    'recipe_engine/led',
    'recipe_engine/properties',
    'recipe_engine/step',
    'recipe_engine/swarming',
    'recipe_engine/time',
]

# Input properties for this example; each key is received by RunSteps as the
# parameter of the same name.
PROPERTIES = {
    # True: create six tasks ('pass', 'flake', 'fail' and their '_long'
    # variants). False: create a single task named 'task'.
    'full':
        Property(
            kind=bool,
            default=False,
            help='Whether to run six tasks or just one.',
        ),

    # Key into RunSteps' factory table selecting which kind of Task to build.
    'task_type':
        Property(
            kind=str,
            default='test',
            help='Type of tasks to create. Options: '
            '"test", "internal_failure", "raising", "led", "triggered".',
        ),

    # Forwarded as retry.run_tasks(max_attempts=...).
    'max_attempts':
        Property(
            kind=int,
            default=2,
            help='Overall max attempts.',
        ),

    # When truthy, assigned to the last task's max_attempts attribute,
    # overriding the overall limit for that task only.
    'last_task_max_attempts':
        Property(
            kind=int,
            default=None,
            help='Override the overall max attempts by setting on '
            'Task.max_attempts. Only set on last task.',
        ),
    # Forwarded as the launch_deadline_time kwarg when creating task(s).
    'launch_deadline_time':
        Property(
            kind=float,
            default=None,
            help='Passed through to swarming_retry.Task.__init__().'
        ),
}  # yapf: disable


def create_task_class(api):
  """Returns a fake api.swarming_retry.Task subclass for this example.

  The class has to be built inside a function because its base class is only
  reachable through the `api` object.
  """

  class Task(api.swarming_retry.Task):
    """Required subclass for testing swarming_retry.

    launch() never talks to swarming; it fabricates an Attempt whose task ID
    comes from a counter seeded with `initial_task_id`.
    """

    def __init__(self, initial_task_id, *args, **kwargs):
      """Construct a Task object.

      Args:
        initial_task_id (int or str): decimal task ID used for the first
          attempt. Any value accepted by int() works; the counter is bumped
          after every launch and converted back to str for the Attempt.
        *args, **kwargs: forwarded to api.swarming_retry.Task. `api` defaults
          to the api object this class was created from.
      """
      kwargs.setdefault('api', api)
      super(Task, self).__init__(*args, **kwargs)
      self._next_task_id = int(initial_task_id)

    def launch(self):
      """Records a fabricated attempt, emits a 'launch' step, returns it."""
      current_id = self._next_task_id
      self._next_task_id = current_id + 1
      task_id = str(current_id)

      attempt_args = dict(index=len(self.attempts), task_id=task_id)
      # Alternate between identifying the task by host and by an explicit UI
      # link: odd IDs use `host`, even IDs use `task_ui_link`. This looks
      # odd, but it is needed for coverage of Attempt.task_ui_link.
      if current_id % 2:
        attempt_args['host'] = 'testhost'
      else:
        attempt_args['task_ui_link'] = 'https://testhost/task?id=%s' % task_id

      attempt = api.swarming_retry.Attempt(**attempt_args)
      launch_step = api.step('launch %s' % self.name, None)
      launch_step.presentation.step_summary_text = attempt.task_id

      self.attempts.append(attempt)
      return attempt

  return Task


def create_task(api, name, initial_task_id, **kwargs):
  """Builds a plain fake Task named `name` with IDs from initial_task_id."""
  task_cls = create_task_class(api)
  return task_cls(name=name, initial_task_id=initial_task_id, **kwargs)


def create_internal_failure_task(api, name, initial_task_id, **kwargs):
  """Builds a fake Task whose process_result() flags an internal failure."""
  base_cls = create_task_class(api)

  class InternalFailureTask(base_cls):
    """Sets failure_reason on the most recent attempt when processing it."""

    def process_result(self):
      latest = self.attempts[-1]
      latest.failure_reason = 'internal failure'

  return InternalFailureTask(
      name=name, initial_task_id=initial_task_id, **kwargs)


def create_raising_task(api, name, initial_task_id, **kwargs):
  """Builds a fake Task whose process_result() always raises StepFailure."""
  base_cls = create_task_class(api)

  class RaisingTask(base_cls):
    """Exercises the path where process_result() raises."""

    def process_result(self):
      failure = self._api.step.StepFailure('something failed')
      raise failure

  return RaisingTask(name=name, initial_task_id=initial_task_id, **kwargs)


def create_led_task(api, name, initial_task_id, **kwargs):
  """Builds a LedTask from `led get-build builder`.

  initial_task_id is ignored; it is accepted only so every factory used by
  RunSteps shares the same signature.
  """
  del initial_task_id  # Unused.

  build_data = api.led('get-build', 'builder')
  return api.swarming_retry.LedTask(
      api=api, name=name, led_data=build_data, **kwargs)


def create_triggered_task(api, name, initial_task_id, **kwargs):
  """Builds a TriggeredTask around a minimal one-slice swarming request.

  initial_task_id is ignored; it is accepted only so every factory used by
  RunSteps shares the same signature.
  """
  del initial_task_id  # Unused.

  base_request = api.swarming.task_request().with_name(name)
  # Give slice 0 fixed placeholder dimensions.
  first_slice = base_request[0].with_dimensions(
      pool='pool', device_type='device_type')
  request = base_request.with_slice(0, first_slice)

  return api.swarming_retry.TriggeredTask(
      api=api, name=name, request=request, **kwargs)


# pylint: disable=invalid-name
def RunSteps(api, full, task_type, max_attempts, last_task_max_attempts,
             launch_deadline_time):
  """Creates tasks of the requested type and runs them via swarming_retry.

  Args:
    api: the recipe api object.
    full (bool): if True, create six tasks ('pass', 'flake', 'fail' and their
      '_long' variants); otherwise create a single task named 'task'.
    task_type (str): key into `task_factories` choosing the kind of task.
    max_attempts (int): overall attempt limit given to retry.run_tasks().
    last_task_max_attempts (int or None): when truthy, overrides
      max_attempts for the last task only.
    launch_deadline_time (float or None): forwarded to every task's
      constructor.
  """
  task_factories = {
      'test': create_task,
      'internal_failure': create_internal_failure_task,
      'raising': create_raising_task,
      'led': create_led_task,
      'triggered': create_triggered_task,
  }
  create = task_factories[task_type]

  # (name, initial_task_id) for each task to run, in launch order.
  if full:
    specs = [
        ('pass', 100),
        ('flake', 200),
        ('fail', 300),
        ('pass_long', 400),
        ('flake_long', 500),
        ('fail_long', 600),
    ]
  else:
    specs = [('task', 100)]

  # Forward launch_deadline_time in both modes so the property is never
  # silently ignored when full=True.
  tasks = [
      create(
          api=api,
          name=name,
          initial_task_id=initial_task_id,
          launch_deadline_time=launch_deadline_time)
      for name, initial_task_id in specs
  ]

  # A falsy override (None or 0) leaves the task's max_attempts untouched.
  if last_task_max_attempts:
    tasks[-1].max_attempts = last_task_max_attempts

  with api.swarming_retry.retry(tasks) as retry:
    retry.run_tasks(max_attempts=max_attempts)
    retry.present_tasks()

    # Needed for coverage of Task.result.
    _ = tasks[0].result

    # Deliberately not calling retry.raise_failures() here. The __exit__
    # portion of the context will call it for us.


def GenTests(api):  # pylint: disable=invalid-name
  """Yields simulation test cases for this swarming_retry example.

  Each case sets the properties RunSteps consumes and supplies fake
  collect/led/trigger data through api.swarming_retry, keyed by the retry
  `iteration` (or, in one case, by `attempt`).
  """

  test_api = api.swarming_retry

  # Six tasks with the default max_attempts=2. The short tasks report in
  # iterations 0-1; the '_long' tasks first show up as incomplete at
  # iteration 2 and report across later iterations. Iterations 3, 7 and 8
  # supply no task results. 'fail' (300, 301) and 'fail_long' (600, 601)
  # fail on both attempts, so the build is expected to fail.
  yield (
      api.status_check.test('full_test', status='failure') +
      api.properties(full=True) +

      test_api.collect_data([
          test_api.passed_task('pass', 100),
          test_api.failed_task('flake', 200),
          test_api.failed_task('fail', 300),
      ], iteration=0) +

      test_api.collect_data([
          test_api.passed_task('flake', 201),
          test_api.failed_task('fail', 301),
      ], iteration=1) +

      test_api.collect_data([
          test_api.incomplete_task('pass_long', 400),
          test_api.incomplete_task('flake_long', 500),
          test_api.incomplete_task('fail_long', 600),
      ], iteration=2) +

      test_api.collect_data([], iteration=3) +

      test_api.collect_data([
          test_api.passed_task('pass_long', 400),
      ], iteration=4) +

      test_api.collect_data([
          test_api.failed_task('flake_long', 500),
      ], iteration=5) +

      test_api.collect_data([
          test_api.failed_task('fail_long', 600),
      ], iteration=6) +

      test_api.collect_data([], iteration=7) +
      test_api.collect_data([], iteration=8) +

      test_api.collect_data([
          test_api.passed_task('flake_long', 501),
      ], iteration=9) +

      test_api.collect_data([
          test_api.failed_task('fail_long', 601),
      ], iteration=10)
  )  # yapf: disable

  # Single task: the first attempt times out, the retry passes.
  yield (
      api.status_check.test('timeout_then_pass') +
      api.properties(full=False) +
      test_api.collect_data([test_api.timed_out_task('task', 100)]) +
      test_api.collect_data([test_api.passed_task('task', 101)], iteration=1)
  )  # yapf: disable

  # Swarming reports both attempts as passed, but InternalFailureTask's
  # process_result() sets a failure_reason, so the build is expected to fail.
  yield (
      api.status_check.test('internal_failure', status='failure') +
      api.properties(full=False, task_type='internal_failure') +
      test_api.collect_data([test_api.passed_task('task', 100)], iteration=0) +
      test_api.collect_data([test_api.passed_task('task', 101)], iteration=1)
  )  # yapf: disable

  # RaisingTask's process_result() raises StepFailure; expected to fail.
  yield (
      api.status_check.test('raising_process_results', status='failure') +
      api.properties(full=False, task_type='raising') +
      test_api.collect_data([test_api.passed_task('task', 100)], iteration=0) +
      test_api.collect_data([test_api.passed_task('task', 101)], iteration=1)
  )  # yapf: disable

  # LedTask: the first attempt fails, the retry passes.
  yield (
      api.status_check.test('led_task') +
      api.properties(full=False, task_type='led') +
      api.swarming_retry.led_data('task', 1, iteration=0) +
      test_api.collect_data([test_api.failed_task('task', 1)], iteration=0) +
      api.swarming_retry.led_data('task', 2, iteration=1) +
      test_api.collect_data([test_api.passed_task('task', 2)], iteration=1)
  )  # yapf: disable

  # LedTask with led data keyed by attempt=0 instead of by iteration.
  yield (
      api.status_check.test('led_task_hardcoded_attempt') +
      api.properties(full=False, task_type='led') +
      api.swarming_retry.led_data('task', 1, attempt=0)
  )  # yapf: disable

  # TriggeredTask whose single attempt passes.
  yield (
      api.status_check.test('triggered_task') +
      api.properties(full=False, task_type='triggered') +
      api.swarming_retry.trigger_data('task', 1) +
      test_api.collect_data([test_api.passed_task('task', 1)], iteration=0)
  )  # yapf: disable

  # Raising task with max_attempts=3: data for three attempts is supplied.
  yield (
      api.status_check.test('max_attempts_three', status='failure') +
      api.properties(full=False, task_type='raising', max_attempts=3) +
      test_api.collect_data([test_api.passed_task('task', 100)], iteration=0) +
      test_api.collect_data([test_api.passed_task('task', 101)], iteration=1) +
      test_api.collect_data([test_api.passed_task('task', 102)], iteration=2)
  )  # yapf: disable

  # The per-task override (1) is below max_attempts (3): one attempt only.
  yield (
      api.status_check.test('last_task_max_attempts_low', status='failure') +
      api.properties(full=False, task_type='raising', max_attempts=3,
                     last_task_max_attempts=1) +
      test_api.collect_data([test_api.passed_task('task', 100)], iteration=0)
  )  # yapf: disable

  # The per-task override (5) exceeds max_attempts (3): five attempts.
  yield (
      api.status_check.test('last_task_max_attempts_high', status='failure') +
      api.properties(full=False, task_type='raising', max_attempts=3,
                     last_task_max_attempts=5) +
      test_api.collect_data([test_api.passed_task('task', 100)], iteration=0) +
      test_api.collect_data([test_api.passed_task('task', 101)], iteration=1) +
      test_api.collect_data([test_api.passed_task('task', 102)], iteration=2) +
      test_api.collect_data([test_api.passed_task('task', 103)], iteration=3) +
      test_api.collect_data([test_api.passed_task('task', 104)], iteration=4)
  )  # yapf: disable

  # Six tasks with max_attempts=1, but the last task ('fail_long') is given
  # 5 attempts, so only it receives retry data in iterations 1-4.
  yield (
      api.status_check.test(
          'last_task_max_attempts_high_mixed', status='failure') +
      api.properties(full=True, task_type='raising', max_attempts=1,
                     last_task_max_attempts=5) +

      # Names are misleading for this test. Sorry.
      test_api.collect_data([
          test_api.failed_task('pass', 100),
          test_api.failed_task('flake', 200),
          test_api.failed_task('fail', 300),
          test_api.failed_task('pass_long', 400),
          test_api.failed_task('flake_long', 500),
          test_api.failed_task('fail_long', 600),
      ], iteration=0) +
      test_api.collect_data(
          [test_api.failed_task('fail_long', 601)], iteration=1) +
      test_api.collect_data(
          [test_api.failed_task('fail_long', 602)], iteration=2) +
      test_api.collect_data(
          [test_api.failed_task('fail_long', 603)], iteration=3) +
      test_api.collect_data(
          [test_api.failed_task('fail_long', 604)], iteration=4)
  )  # yapf: disable

  # With seed 0 and step 0.6, a deadline of 1.0 falls in (0.6, 1.2].
  yield (api.status_check.test('launch_deadline_time') +
         # Settings such that launch_deadline_time is in the range
         # (time.seed() + time.step(), time.seed() + 2 * time.step()]
         # should result in the task being launched exactly once.
         api.properties(
             full=False, task_type='triggered', launch_deadline_time=1.0) +
         api.time.seed(0) + api.time.step(0.6) +
         api.swarming_retry.trigger_data('task', 1) +
         test_api.collect_data([test_api.passed_task('task', 1)], iteration=0)
  ) # yapf: disable
