Create base suite class and add run_suite_class to suite runner. (#817)

diff --git a/mobly/base_suite.py b/mobly/base_suite.py
new file mode 100644
index 0000000..6febcd8
--- /dev/null
+++ b/mobly/base_suite.py
@@ -0,0 +1,73 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+
+class BaseSuite(abc.ABC):
+  """Class used to define a Mobly suite.
+
+  To create a suite, inherit from this class and implement setup_suite.
+
+  Use `BaseSuite.add_test_class` to specify which classes to run with which
+  configs and test selectors.
+
+  After defining the sub class, the suite can be executed using
+  suite_runner.run_suite_class.
+
+  Users can use this class if they need to define their own setup and teardown
+  steps on the suite level. Otherwise, just use suite_runner.run_suite on the
+  list of test classes.
+  """
+
+  def __init__(self, runner, config):
+    self._runner = runner
+    self._config = config.copy()
+
+  @property
+  def user_params(self):
+    return self._config.user_params
+
+  def add_test_class(self, clazz, config=None, tests=None, name_suffix=None):
+    """Adds a test class to the suite.
+
+    Args:
+      clazz: class, a Mobly test class.
+      config: config_parser.TestRunConfig, the config to run the class with. If
+        not specified, the default config passed from google3 infra is used.
+      tests: list of strings, names of the tests to run in this test class, in
+        the execution order. If not specified, all tests in the class are
+        executed.
+      name_suffix: string, suffix to append to the class name for reporting.
+        This is used for differentiating the same class executed with different
+        parameters in a suite.
+    """
+    if not config:
+      config = self._config
+    self._runner.add_test_class(config, clazz, tests, name_suffix)
+
+  @abc.abstractmethod
+  def setup_suite(self, config):
+    """Function used to add test classes, has to be implemented by child class.
+
+    Args:
+      config: config_parser.TestRunConfig, the config provided by google3 infra.
+
+    Raises:
+      Error: when setup_suite is not implemented by child class.
+    """
+    pass
+
+  def teardown_suite(self):
+    """Function used to add post tests cleanup tasks (optional)."""
+    pass
diff --git a/mobly/suite_runner.py b/mobly/suite_runner.py
index d921199..037476c 100644
--- a/mobly/suite_runner.py
+++ b/mobly/suite_runner.py
@@ -13,8 +13,14 @@
 # limitations under the License.
 """Runner for Mobly test suites.
 
-To create a test suite, call suite_runner.run_suite() with one or more
-individual test classes. For example:
+These is just example code to help users run a collection of Mobly test
+classes. Users can use it as is or customize it based on their requirements.
+
+There are two ways to use this runner.
+
+1. Call suite_runner.run_suite() with one or more individual test classes. This
+is for users who just need to execute a collection of test classes without any
+additional steps.
 
 .. code-block:: python
 
@@ -25,14 +31,47 @@
   ...
   if __name__ == '__main__':
     suite_runner.run_suite(foo_test.FooTest, bar_test.BarTest)
-"""
 
+2. Create a subclass of base_suite.BaseSuite and add the individual test
+classes. Using the BaseSuite class allows users to define their own setup
+and teardown steps on the suite level as well as custom config for each test
+class.
+
+.. code-block:: python
+
+  from mobly import base_suite
+  from mobly import suite_runner
+
+  from my.path import MyFooTest
+  from my.path import MyBarTest
+
+
+  class MySuite(base_suite.BaseSuite):
+
+    def setup_suite(self, config):
+      # Add a class with default config.
+      self.add_test_class(MyFooTest)
+      # Add a class with test selection.
+      self.add_test_class(MyBarTest,
+                          tests=['test_a', 'test_b'])
+      # Add the same class again with a custom config and suffix.
+      my_config = some_config_logic(config)
+      self.add_test_class(MyBarTest,
+                          config=my_config,
+                          name_suffix='WithCustomConfig')
+
+
+  if __name__ == '__main__':
+    suite_runner.run_suite_class()
+"""
 import argparse
 import collections
+import inspect
 import logging
 import sys
 
 from mobly import base_test
+from mobly import base_suite
 from mobly import config_parser
 from mobly import signals
 from mobly import test_runner
@@ -42,18 +81,16 @@
   pass
 
 
-def run_suite(test_classes, argv=None):
-  """Executes multiple test classes as a suite.
-
-  This is the default entry point for running a test suite script file
-  directly.
+def _parse_cli_args(argv):
+  """Parses cli args that are consumed by Mobly.
 
   Args:
-    test_classes: List of python classes containing Mobly tests.
     argv: A list that is then parsed as cli args. If None, defaults to cli
       input.
+
+  Returns:
+    Namespace containing the parsed args.
   """
-  # Parse cli args.
   parser = argparse.ArgumentParser(description='Mobly Suite Executable.')
   parser.add_argument('-c',
                       '--config',
@@ -70,7 +107,77 @@
       help='A list of test classes and optional tests to execute.')
   if not argv:
     argv = sys.argv[1:]
-  args = parser.parse_args(argv)
+  return parser.parse_args(argv)
+
+
+def _find_suite_class():
+  """Finds the test suite class in the current module.
+
+  Walk through module members and find the subclass of BaseSuite. Only
+  one subclass is allowed in a module.
+
+  Returns:
+      The test suite class in the test module.
+  """
+  test_suites = []
+  main_module_members = sys.modules['__main__']
+  for _, module_member in main_module_members.__dict__.items():
+    if inspect.isclass(module_member):
+      if issubclass(module_member, base_suite.BaseSuite):
+        test_suites.append(module_member)
+  if len(test_suites) != 1:
+    logging.error('Expected 1 test class per file, found %s.',
+                  [t.__name__ for t in test_suites])
+    sys.exit(1)
+  return test_suites[0]
+
+
+def run_suite_class(argv=None):
+  """Executes tests in the test suite.
+
+  Args:
+    argv: A list that is then parsed as CLI args. If None, defaults to sys.argv.
+  """
+  if argv is None:
+    argv = sys.argv
+  cli_args = _parse_cli_args(argv)
+  test_configs = config_parser.load_test_config_file(cli_args.config)
+  config_count = len(test_configs)
+  if config_count != 1:
+    logging.error('Expect exactly one test config, found %d', config_count)
+  config = test_configs[0]
+  runner = test_runner.TestRunner(
+      log_dir=config.log_path, testbed_name=config.testbed_name)
+  suite_class = _find_suite_class()
+  suite = suite_class(runner, config)
+  ok = False
+  with runner.mobly_logger():
+    try:
+      suite.setup_suite(config.copy())
+      try:
+        runner.run()
+        ok = runner.results.is_all_pass
+        print(ok)
+      except signals.TestAbortAll:
+        pass
+    finally:
+      suite.teardown_suite()
+  if not ok:
+    sys.exit(1)
+
+
+def run_suite(test_classes, argv=None):
+  """Executes multiple test classes as a suite.
+
+  This is the default entry point for running a test suite script file
+  directly.
+
+  Args:
+    test_classes: List of python classes containing Mobly tests.
+    argv: A list that is then parsed as cli args. If None, defaults to cli
+      input.
+  """
+  args = _parse_cli_args(argv)
   # Load test config file.
   test_configs = config_parser.load_test_config_file(args.config)
 
diff --git a/tests/mobly/suite_runner_test.py b/tests/mobly/suite_runner_test.py
index 7297236..653bbfe 100755
--- a/tests/mobly/suite_runner_test.py
+++ b/tests/mobly/suite_runner_test.py
@@ -12,18 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import io
 import os
 import shutil
+import sys
 import tempfile
 import unittest
 from unittest import mock
 
+from mobly import base_suite
+from mobly import base_test
+from mobly import config_parser
+from mobly import test_runner
 from mobly import suite_runner
 from tests.lib import integration2_test
 from tests.lib import integration_test
 
 
+class FakeTest1(base_test.BaseTestClass):
+  pass
+
+
 class SuiteRunnerTest(unittest.TestCase):
 
   def setUp(self):
@@ -108,6 +118,44 @@
                            argv=['-c', tmp_file_path])
     mock_exit.assert_called_once_with(1)
 
+  @mock.patch('sys.exit')
+  @mock.patch.object(suite_runner, '_find_suite_class', autospec=True)
+  def test_run_suite_class(self, mock_find_suite_class, mock_exit):
+    mock_called = mock.MagicMock()
+
+    class FakeTestSuite(base_suite.BaseSuite):
+
+      def setup_suite(self, config):
+        mock_called.setup_suite()
+        super().setup_suite(config)
+        self.add_test_class(FakeTest1)
+
+      def teardown_suite(self):
+        mock_called.teardown_suite()
+        super().teardown_suite()
+
+    mock_find_suite_class.return_value = FakeTestSuite
+
+    tmp_file_path = os.path.join(self.tmp_dir, 'config.yml')
+    with io.open(tmp_file_path, 'w', encoding='utf-8') as f:
+      f.write(u"""
+        TestBeds:
+          # A test bed where adb will find Android devices.
+          - Name: SampleTestBed
+            Controllers:
+              MagicDevice: '*'
+      """)
+
+    mock_cli_args = [f'--config={tmp_file_path}']
+
+    with mock.patch.object(sys, 'argv', new=mock_cli_args):
+      suite_runner.run_suite_class()
+
+    mock_find_suite_class.assert_called_once()
+    mock_called.setup_suite.assert_called_once_with()
+    mock_called.teardown_suite.assert_called_once_with()
+    mock_exit.assert_not_called()
+
 
 if __name__ == "__main__":
   unittest.main()