Rename the `setup_generated_tests` stage to `pre_run`. (#844)

Users have identified more legit use cases where operations before `setup_class` are needed. The `setup_generated_tests` stage is essentially the stage before `setup_class`. Hence we are renaming it more properly.
diff --git a/mobly/base_test.py b/mobly/base_test.py
index c559b44..00ba3d9 100644
--- a/mobly/base_test.py
+++ b/mobly/base_test.py
@@ -36,6 +36,8 @@
 TEST_STAGE_END_LOG_TEMPLATE = '[{parent_token}]#{child_token} <<< END <<<'
 
 # Names of execution stages, in the order they happen during test runs.
+STAGE_NAME_PRE_RUN = 'pre_run'
+# Deprecated, use `STAGE_NAME_PRE_RUN` instead.
 STAGE_NAME_SETUP_GENERATED_TESTS = 'setup_generated_tests'
 STAGE_NAME_SETUP_CLASS = 'setup_class'
 STAGE_NAME_SETUP_TEST = 'setup_test'
@@ -339,22 +341,26 @@
       self.summary_writer.dump(record.to_dict(),
                                records.TestSummaryEntryType.CONTROLLER_INFO)
 
-  def _setup_generated_tests(self):
-    """Proxy function to guarantee the base implementation of
-    setup_generated_tests is called.
+  def _pre_run(self):
+    """Proxy function to guarantee the base implementation of `pre_run` is
+    called.
 
     Returns:
       True if setup is successful, False otherwise.
     """
-    stage_name = STAGE_NAME_SETUP_GENERATED_TESTS
+    stage_name = STAGE_NAME_PRE_RUN
     record = records.TestResultRecord(stage_name, self.TAG)
     record.test_begin()
     self.current_test_info = runtime_test_info.RuntimeTestInfo(
         stage_name, self.log_path, record)
     try:
       with self._log_test_stage(stage_name):
+        self.pre_run()
+      # TODO(angli): Remove this context block after the full deprecation of
+      # `setup_generated_tests`.
+      with self._log_test_stage(stage_name):
         self.setup_generated_tests()
-        return True
+      return True
     except Exception as e:
       logging.exception('%s failed for %s.', stage_name, self.TAG)
       record.test_error(e)
@@ -363,7 +369,7 @@
                                records.TestSummaryEntryType.RECORD)
       return False
 
-  def setup_generated_tests(self):
+  def pre_run(self):
     """Preprocesses that need to be done before setup_class.
 
     This phase is used to do pre-test processes like generating tests.
@@ -374,6 +380,19 @@
     requested is unknown at this point.
     """
 
+  def setup_generated_tests(self):
+    """[DEPRECATED] Use `pre_run` instead.
+
+    Preprocesses that need to be done before setup_class.
+
+    This phase is used to do pre-test processes like generating tests.
+    This is the only place `self.generate_tests` should be called.
+
+    If this function throws an error, the test class will be marked failure
+    and the "Requested" field will be 0 because the number of tests
+    requested is unknown at this point.
+    """
+
   def _setup_class(self):
     """Proxy function to guarantee the base implementation of setup_class
     is called.
@@ -854,7 +873,7 @@
         arguments as the test logic function and returns a string that
         is the corresponding UID.
     """
-    self._assert_function_name_in_stack(STAGE_NAME_SETUP_GENERATED_TESTS)
+    self._assert_function_name_in_stack(STAGE_NAME_PRE_RUN)
     root_msg = 'During test generation of "%s":' % test_logic.__name__
     for args in arg_sets:
       test_name = name_func(*args)
@@ -866,8 +885,8 @@
       # decorators, copy the attributes added by the decorators to the
       # generated test methods as well, so the generated test methods
       # also have the retry/repeat behavior.
-      for attr_name in (
-        ATTR_MAX_RETRY_CNT, ATTR_MAX_CONSEC_ERROR, ATTR_REPEAT_CNT):
+      for attr_name in (ATTR_MAX_RETRY_CNT, ATTR_MAX_CONSEC_ERROR,
+                        ATTR_REPEAT_CNT):
         attr = getattr(test_logic, attr_name, None)
         if attr is not None:
           setattr(test_func, attr_name, attr)
@@ -987,7 +1006,7 @@
     """
     logging.log_path = self.log_path
     # Executes pre-setup procedures, like generating test methods.
-    if not self._setup_generated_tests():
+    if not self._pre_run():
       return self.results
     logging.info('==========> %s <==========', self.TAG)
     # Devise the actual test methods to run in the test class.
diff --git a/tests/mobly/base_test_test.py b/tests/mobly/base_test_test.py
index ebac0e0..5cc6b20 100755
--- a/tests/mobly/base_test_test.py
+++ b/tests/mobly/base_test_test.py
@@ -1960,6 +1960,37 @@
     bc.unpack_userparams(arg1="haha")
     self.assertEqual(bc.arg1, "haha")
 
+  def test_pre_run_failure(self):
+    """Test code path for `pre_run` failure.
+
+    When `pre_run` fails, pre-execution calculation is incomplete and the
+    number of tests requested is unknown. This is a
+    fatal issue that blocks any test execution in a class.
+
+    A class level error record is generated.
+    Unlike `setup_class` failure, no test is considered "skipped" in this
+    case as execution stage never started.
+    """
+
+    class MockBaseTest(base_test.BaseTestClass):
+
+      def pre_run(self):
+        raise Exception(MSG_EXPECTED_EXCEPTION)
+
+      def logic(self, a, b):
+        pass
+
+      def test_foo(self):
+        pass
+
+    bt_cls = MockBaseTest(self.mock_test_cls_configs)
+    bt_cls.run()
+    self.assertEqual(len(bt_cls.results.requested), 0)
+    class_record = bt_cls.results.error[0]
+    self.assertEqual(class_record.test_name, 'pre_run')
+    self.assertEqual(bt_cls.results.skipped, [])
+
+  # TODO(angli): remove after the full deprecation of `setup_generated_tests`.
   def test_setup_generated_tests_failure(self):
     """Test code path for setup_generated_tests failure.
 
@@ -1987,14 +2018,14 @@
     bt_cls.run()
     self.assertEqual(len(bt_cls.results.requested), 0)
     class_record = bt_cls.results.error[0]
-    self.assertEqual(class_record.test_name, 'setup_generated_tests')
+    self.assertEqual(class_record.test_name, 'pre_run')
     self.assertEqual(bt_cls.results.skipped, [])
 
   def test_generate_tests_run(self):
 
     class MockBaseTest(base_test.BaseTestClass):
 
-      def setup_generated_tests(self):
+      def pre_run(self):
         self.generate_tests(test_logic=self.logic,
                             name_func=self.name_gen,
                             arg_sets=[(1, 2), (3, 4)])
@@ -2018,7 +2049,7 @@
 
     class MockBaseTest(base_test.BaseTestClass):
 
-      def setup_generated_tests(self):
+      def pre_run(self):
         self.generate_tests(test_logic=self.logic,
                             name_func=self.name_gen,
                             uid_func=self.uid_logic,
@@ -2042,7 +2073,7 @@
 
     class MockBaseTest(base_test.BaseTestClass):
 
-      def setup_generated_tests(self):
+      def pre_run(self):
         self.generate_tests(test_logic=self.logic,
                             name_func=self.name_gen,
                             uid_func=self.uid_logic,
@@ -2068,7 +2099,7 @@
 
     class MockBaseTest(base_test.BaseTestClass):
 
-      def setup_generated_tests(self):
+      def pre_run(self):
         self.generate_tests(test_logic=self.logic,
                             name_func=self.name_gen,
                             arg_sets=[(1, 2), (3, 4)])
@@ -2085,7 +2116,7 @@
     self.assertEqual(len(bt_cls.results.passed), 1)
     self.assertEqual(bt_cls.results.passed[0].test_name, 'test_3_4')
 
-  def test_generate_tests_call_outside_of_setup_generated_tests(self):
+  def test_generate_tests_call_outside_of_pre_run(self):
 
     class MockBaseTest(base_test.BaseTestClass):
 
@@ -2105,9 +2136,8 @@
     actual_record = bt_cls.results.error[0]
     utils.validate_test_result(bt_cls.results)
     self.assertEqual(actual_record.test_name, "test_ha")
-    self.assertEqual(
-        actual_record.details,
-        '"generate_tests" cannot be called outside of setup_generated_tests')
+    self.assertEqual(actual_record.details,
+                     '"generate_tests" cannot be called outside of pre_run')
     expected_summary = ("Error 1, Executed 1, Failed 0, Passed 0, "
                         "Requested 1, Skipped 0")
     self.assertEqual(bt_cls.results.summary_str(), expected_summary)
@@ -2116,7 +2146,7 @@
 
     class MockBaseTest(base_test.BaseTestClass):
 
-      def setup_generated_tests(self):
+      def pre_run(self):
         self.generate_tests(test_logic=self.logic,
                             name_func=self.name_gen,
                             arg_sets=[(1, 2), (3, 4)])
@@ -2130,7 +2160,7 @@
     bt_cls = MockBaseTest(self.mock_test_cls_configs)
     bt_cls.run()
     actual_record = bt_cls.results.error[0]
-    self.assertEqual(actual_record.test_name, "setup_generated_tests")
+    self.assertEqual(actual_record.test_name, "pre_run")
     self.assertEqual(
         actual_record.details,
         'During test generation of "logic": Test name "ha" already exists'
@@ -2300,11 +2330,10 @@
       def _run_test_logic(self, arg):
         pass
 
-      def setup_generated_tests(self):
-        self.generate_tests(
-          self._run_test_logic,
-          name_func=lambda arg: f'test_generated_{arg}',
-          arg_sets=[(1,)])
+      def pre_run(self):
+        self.generate_tests(self._run_test_logic,
+                            name_func=lambda arg: f'test_generated_{arg}',
+                            arg_sets=[(1,)])
 
     bt_cls = MockBaseTest(self.mock_test_cls_configs)
     bt_cls.run()
@@ -2480,7 +2509,8 @@
   def test_retry_generated_test_last_pass(self):
     max_count = 3
     mock_action = mock.MagicMock(
-      side_effect = [Exception('Fail 1'), Exception('Fail 2'), None])
+        side_effect=[Exception('Fail 1'),
+                     Exception('Fail 2'), None])
 
     class MockBaseTest(base_test.BaseTestClass):
 
@@ -2488,11 +2518,10 @@
       def _run_test_logic(self, arg):
         mock_action()
 
-      def setup_generated_tests(self):
-        self.generate_tests(
-          self._run_test_logic,
-          name_func=lambda arg: f'test_generated_{arg}',
-          arg_sets=[(1,)])
+      def pre_run(self):
+        self.generate_tests(self._run_test_logic,
+                            name_func=lambda arg: f'test_generated_{arg}',
+                            arg_sets=[(1,)])
 
     bt_cls = MockBaseTest(self.mock_test_cls_configs)
     bt_cls.run()