Add find_subclass_in_module utility. (#632)
diff --git a/mobly/test_runner.py b/mobly/test_runner.py
index d32d304..4dabfb2 100644
--- a/mobly/test_runner.py
+++ b/mobly/test_runner.py
@@ -145,14 +145,17 @@
Returns:
The test class in the test module.
+
+ Raises:
+ SystemExit: Raised if the number of test classes is not exactly one.
"""
- test_classes = utils.find_subclasses_in_module(base_test.BaseTestClass,
- sys.modules['__main__'])
- if len(test_classes) != 1:
- logging.error('Expected 1 test class per file, found %s.',
- [t.__name__ for t in test_classes])
+ try:
+ return utils.find_subclass_in_module(base_test.BaseTestClass,
+ sys.modules['__main__'])
+ except ValueError:
+ logging.exception('Exactly one subclass of `base_test.BaseTestClass`'
+ ' should be in the main file.')
sys.exit(1)
- return test_classes[0]
def _print_test_names(test_class):
diff --git a/mobly/utils.py b/mobly/utils.py
index a32c779..2a60bae 100644
--- a/mobly/utils.py
+++ b/mobly/utils.py
@@ -567,3 +567,24 @@
if issubclass(module_member, base_class):
subclasses.append(module_member)
return subclasses
+
+
+def find_subclass_in_module(base_class, module):
+ """Finds the single subclass of the given base class in the given module.
+
+ Args:
+ base_class: class, the base class to look for a subclass of in the module.
+ module: module, the module to look for the single subclass in.
+
+ Returns:
+ The single subclass of the given base class.
+
+ Raises:
+ ValueError: If the number of subclasses found was not exactly one.
+ """
+ subclasses = find_subclasses_in_module([base_class], module)
+ if len(subclasses) != 1:
+ raise ValueError('Expected 1 subclass of %s per module, found %s.' %
+ (base_class.__name__,
+ [subclass.__name__ for subclass in subclasses]))
+ return subclasses[0]
diff --git a/tests/mobly/test_runner_test.py b/tests/mobly/test_runner_test.py
index 90e9268..4074891 100755
--- a/tests/mobly/test_runner_test.py
+++ b/tests/mobly/test_runner_test.py
@@ -32,6 +32,7 @@
from tests.lib import integration_test
from tests.lib import integration2_test
from tests.lib import integration3_test
+from tests.lib import multiple_subclasses_module
class TestRunnerTest(unittest.TestCase):
@@ -342,6 +343,22 @@
test_runner.main(['-c', tmp_file_path])
mock_exit.assert_called_once_with(1)
+ def test__find_test_class_when_one_test_class(self):
+ with mock.patch.dict('sys.modules', __main__=integration_test):
+ test_class = test_runner._find_test_class()
+ self.assertEqual(test_class, integration_test.IntegrationTest)
+
+ def test__find_test_class_when_no_test_class(self):
+ with self.assertRaises(SystemExit):
+ with mock.patch.dict('sys.modules', __main__=mock_controller):
+ test_class = test_runner._find_test_class()
+
+ def test__find_test_class_when_multiple_test_classes(self):
+ with self.assertRaises(SystemExit):
+ with mock.patch.dict('sys.modules',
+ __main__=multiple_subclasses_module):
+ test_class = test_runner._find_test_class()
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/mobly/utils_test.py b/tests/mobly/utils_test.py
index 1c49377..714c238 100755
--- a/tests/mobly/utils_test.py
+++ b/tests/mobly/utils_test.py
@@ -291,6 +291,34 @@
self.assertIn(multiple_subclasses_module.Subclass1Runner, subclasses)
self.assertIn(multiple_subclasses_module.Subclass2Runner, subclasses)
+ def test_find_subclass_in_module_when_one_subclass(self):
+ subclass = utils.find_subclass_in_module(base_test.BaseTestClass,
+ integration_test)
+ self.assertEqual(subclass, integration_test.IntegrationTest)
+
+ def test_find_subclass_in_module_when_indirect_subclass(self):
+ subclass = utils.find_subclass_in_module(base_test.BaseTestClass,
+ mock_instrumentation_test)
+ self.assertEqual(subclass,
+ mock_instrumentation_test.MockInstrumentationTest)
+
+ def test_find_subclass_in_module_when_no_subclasses(self):
+ with self.assertRaisesRegex(
+ ValueError,
+ '.*Expected 1 subclass of BaseTestClass per module, found'
+ r' \[\].*'):
+ _ = utils.find_subclass_in_module(base_test.BaseTestClass,
+ mock_controller)
+
+ def test_find_subclass_in_module_when_multiple_subclasses(self):
+ with self.assertRaisesRegex(
+ ValueError,
+ '.*Expected 1 subclass of BaseTestClass per module, found'
+ r' \[(\'Subclass1Test\', \'Subclass2Test\''
+ r'|\'Subclass2Test\', \'Subclass1Test\')\].*'):
+ _ = utils.find_subclass_in_module(base_test.BaseTestClass,
+ multiple_subclasses_module)
+
if __name__ == '__main__':
unittest.main()