Add find_subclass_in_module utility. (#632)

diff --git a/mobly/test_runner.py b/mobly/test_runner.py
index d32d304..4dabfb2 100644
--- a/mobly/test_runner.py
+++ b/mobly/test_runner.py
@@ -145,14 +145,17 @@
 
     Returns:
         The test class in the test module.
+
+    Raises:
+      SystemExit: Raised if the number of test classes is not exactly one.
     """
-    test_classes = utils.find_subclasses_in_module(base_test.BaseTestClass,
-                                                   sys.modules['__main__'])
-    if len(test_classes) != 1:
-        logging.error('Expected 1 test class per file, found %s.',
-                      [t.__name__ for t in test_classes])
+    try:
+        return utils.find_subclass_in_module(base_test.BaseTestClass,
+                                             sys.modules['__main__'])
+    except ValueError:
+        logging.exception('Exactly one subclass of `base_test.BaseTestClass`'
+                          ' should be in the main file.')
         sys.exit(1)
-    return test_classes[0]
 
 
 def _print_test_names(test_class):
diff --git a/mobly/utils.py b/mobly/utils.py
index a32c779..2a60bae 100644
--- a/mobly/utils.py
+++ b/mobly/utils.py
@@ -567,3 +567,24 @@
                 if issubclass(module_member, base_class):
                     subclasses.append(module_member)
     return subclasses
+
+
+def find_subclass_in_module(base_class, module):
+    """Finds the single subclass of the given base class in the given module.
+
+    Args:
+      base_class: class, the base class to look for a subclass of in the module.
+      module: module, the module to look for the single subclass in.
+
+    Returns:
+      The single subclass of the given base class.
+
+    Raises:
+      ValueError: If the number of subclasses found was not exactly one.
+    """
+    subclasses = find_subclasses_in_module([base_class], module)
+    if len(subclasses) != 1:
+        raise ValueError('Expected 1 subclass of %s per module, found %s.' %
+                         (base_class.__name__,
+                          [subclass.__name__ for subclass in subclasses]))
+    return subclasses[0]
diff --git a/tests/mobly/test_runner_test.py b/tests/mobly/test_runner_test.py
index 90e9268..4074891 100755
--- a/tests/mobly/test_runner_test.py
+++ b/tests/mobly/test_runner_test.py
@@ -32,6 +32,7 @@
 from tests.lib import integration_test
 from tests.lib import integration2_test
 from tests.lib import integration3_test
+from tests.lib import multiple_subclasses_module
 
 
 class TestRunnerTest(unittest.TestCase):
@@ -342,6 +343,22 @@
         test_runner.main(['-c', tmp_file_path])
         mock_exit.assert_called_once_with(1)
 
+    def test__find_test_class_when_one_test_class(self):
+        with mock.patch.dict('sys.modules', __main__=integration_test):
+            test_class = test_runner._find_test_class()
+            self.assertEqual(test_class, integration_test.IntegrationTest)
+
+    def test__find_test_class_when_no_test_class(self):
+        with self.assertRaises(SystemExit):
+            with mock.patch.dict('sys.modules', __main__=mock_controller):
+                test_class = test_runner._find_test_class()
+
+    def test__find_test_class_when_multiple_test_classes(self):
+        with self.assertRaises(SystemExit):
+            with mock.patch.dict('sys.modules',
+                                 __main__=multiple_subclasses_module):
+                test_class = test_runner._find_test_class()
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/mobly/utils_test.py b/tests/mobly/utils_test.py
index 1c49377..714c238 100755
--- a/tests/mobly/utils_test.py
+++ b/tests/mobly/utils_test.py
@@ -291,6 +291,34 @@
         self.assertIn(multiple_subclasses_module.Subclass1Runner, subclasses)
         self.assertIn(multiple_subclasses_module.Subclass2Runner, subclasses)
 
+    def test_find_subclass_in_module_when_one_subclass(self):
+        subclass = utils.find_subclass_in_module(base_test.BaseTestClass,
+                                                 integration_test)
+        self.assertEqual(subclass, integration_test.IntegrationTest)
+
+    def test_find_subclass_in_module_when_indirect_subclass(self):
+        subclass = utils.find_subclass_in_module(base_test.BaseTestClass,
+                                                 mock_instrumentation_test)
+        self.assertEqual(subclass,
+                         mock_instrumentation_test.MockInstrumentationTest)
+
+    def test_find_subclass_in_module_when_no_subclasses(self):
+        with self.assertRaisesRegex(
+                ValueError,
+                '.*Expected 1 subclass of BaseTestClass per module, found'
+                r' \[\].*'):
+            _ = utils.find_subclass_in_module(base_test.BaseTestClass,
+                                              mock_controller)
+
+    def test_find_subclass_in_module_when_multiple_subclasses(self):
+        with self.assertRaisesRegex(
+                ValueError,
+                '.*Expected 1 subclass of BaseTestClass per module, found'
+                r' \[(\'Subclass1Test\', \'Subclass2Test\''
+                r'|\'Subclass2Test\', \'Subclass1Test\')\].*'):
+            _ = utils.find_subclass_in_module(base_test.BaseTestClass,
+                                              multiple_subclasses_module)
+
 
 if __name__ == '__main__':
     unittest.main()