Fix behaviors of `abort_all` in procedural functions. (#461)
* `abort_all` should skip remaining tests in a class just like `abort_class`.
* A record for `setup_class` should exist if `setup_class` failed and
`abort_all` is called in `on_fail`
* Handle the case of calling `abort_all` in `teardown_class`.
diff --git a/mobly/base_test.py b/mobly/base_test.py
index 9a2c263..0a2dd13 100644
--- a/mobly/base_test.py
+++ b/mobly/base_test.py
@@ -190,6 +190,9 @@
record.test_begin()
try:
self.teardown_class()
+ except signals.TestAbortAll as e:
+ setattr(e, 'results', self.results)
+ raise
except Exception as e:
logging.exception('Error encountered in teardown_class.')
record.test_error(e)
@@ -636,10 +639,10 @@
# Fail the class and skip all tests.
logging.exception('Error in setup_class %s.', self.TAG)
class_record.test_error(e)
- self._exec_procedure_func(self._on_fail, class_record)
self.results.add_class_error(class_record)
self.summary_writer.dump(class_record.to_dict(),
records.TestSummaryEntryType.RECORD)
+ self._exec_procedure_func(self._on_fail, class_record)
self._skip_remaining_tests(e)
return self.results
# Run tests in order.
@@ -651,6 +654,8 @@
self._skip_remaining_tests(e)
return self.results
except signals.TestAbortAll as e:
+ e.details = 'All remaining tests aborted due to: %s' % e.details
+ self._skip_remaining_tests(e)
# Piggy-back test results on this exception object so we don't lose
# results from this test class.
setattr(e, 'results', self.results)
diff --git a/tests/mobly/base_test_test.py b/tests/mobly/base_test_test.py
index a38b532..26fe3da 100755
--- a/tests/mobly/base_test_test.py
+++ b/tests/mobly/base_test_test.py
@@ -874,7 +874,7 @@
("Error 0, Executed 2, Failed 1, Passed 1, "
"Requested 3, Skipped 1"))
- def test_abort_all_setup_class(self):
+ def test_abort_all_in_setup_class(self):
class MockBaseTest(base_test.BaseTestClass):
def setup_class(self):
asserts.abort_all(MSG_EXPECTED_EXCEPTION)
@@ -895,7 +895,27 @@
self.assertTrue(hasattr(context.exception, 'results'))
self.assertEqual(bt_cls.results.summary_str(),
("Error 0, Executed 0, Failed 0, Passed 0, "
- "Requested 3, Skipped 0"))
+ "Requested 3, Skipped 3"))
+
+ def test_abort_all_in_teardown_class(self):
+ class MockBaseTest(base_test.BaseTestClass):
+ def test_1(self):
+ pass
+
+ def test_2(self):
+ pass
+
+ def teardown_class(self):
+ asserts.abort_all(MSG_EXPECTED_EXCEPTION)
+
+ bt_cls = MockBaseTest(self.mock_test_cls_configs)
+ with self.assertRaisesRegex(signals.TestAbortAll,
+ MSG_EXPECTED_EXCEPTION) as context:
+ bt_cls.run(test_names=["test_1", "test_2"])
+ self.assertTrue(hasattr(context.exception, 'results'))
+ self.assertEqual(bt_cls.results.summary_str(),
+ ("Error 0, Executed 2, Failed 0, Passed 2, "
+ "Requested 2, Skipped 0"))
def test_abort_all_in_setup_test(self):
class MockBaseTest(base_test.BaseTestClass):
@@ -918,7 +938,7 @@
self.assertTrue(hasattr(context.exception, 'results'))
self.assertEqual(bt_cls.results.summary_str(),
("Error 0, Executed 1, Failed 1, Passed 0, "
- "Requested 3, Skipped 0"))
+ "Requested 3, Skipped 2"))
def test_abort_all_in_on_fail(self):
class MockBaseTest(base_test.BaseTestClass):
@@ -941,7 +961,7 @@
self.assertTrue(hasattr(context.exception, 'results'))
self.assertEqual(bt_cls.results.summary_str(),
("Error 0, Executed 1, Failed 1, Passed 0, "
- "Requested 3, Skipped 0"))
+ "Requested 3, Skipped 2"))
def test_abort_all_in_on_fail_from_setup_class(self):
class MockBaseTest(base_test.BaseTestClass):
@@ -964,10 +984,12 @@
with self.assertRaisesRegex(signals.TestAbortAll,
MSG_EXPECTED_EXCEPTION) as context:
bt_cls.run(test_names=["test_1", "test_2", "test_3"])
+ setup_class_record = bt_cls.results.error[0]
+ self.assertEqual(setup_class_record.test_name, 'setup_class')
self.assertTrue(hasattr(context.exception, 'results'))
self.assertEqual(bt_cls.results.summary_str(),
- ("Error 0, Executed 0, Failed 0, Passed 0, "
- "Requested 3, Skipped 0"))
+ ("Error 1, Executed 0, Failed 0, Passed 0, "
+ "Requested 3, Skipped 3"))
def test_abort_all_in_test(self):
class MockBaseTest(base_test.BaseTestClass):
@@ -991,7 +1013,7 @@
MSG_EXPECTED_EXCEPTION)
self.assertEqual(bt_cls.results.summary_str(),
("Error 0, Executed 2, Failed 1, Passed 1, "
- "Requested 3, Skipped 0"))
+ "Requested 3, Skipped 1"))
def test_uncaught_exception(self):
class MockBaseTest(base_test.BaseTestClass):