Make test_runner logging a context manager. (#606)

diff --git a/mobly/test_runner.py b/mobly/test_runner.py
index 7c0b0cc..77db5cf 100644
--- a/mobly/test_runner.py
+++ b/mobly/test_runner.py
@@ -17,6 +17,7 @@
 standard_library.install_aliases()
 
 import argparse
+import contextlib
 import inspect
 import logging
 import os
@@ -73,16 +74,17 @@
     for config in test_configs:
         runner = TestRunner(log_dir=config.log_path,
                             test_bed_name=config.test_bed_name)
-        runner.add_test_class(config, test_class, tests)
-        try:
-            runner.run()
-            ok = runner.results.is_all_pass and ok
-        except signals.TestAbortAll:
-            pass
-        except:
-            logging.exception('Exception when executing %s.',
-                              config.test_bed_name)
-            ok = False
+        with runner.mobly_loggger():
+            runner.add_test_class(config, test_class, tests)
+            try:
+                runner.run()
+                ok = runner.results.is_all_pass and ok
+            except signals.TestAbortAll:
+                pass
+            except:
+                logging.exception('Exception when executing %s.',
+                                  config.test_bed_name)
+                ok = False
     if not ok:
         sys.exit(1)
 
@@ -187,12 +189,10 @@
         self.results: The test result object used to record the results of
             this test run.
     """
-
     class _TestRunInfo(object):
         """Identifies one test class to run, which tests to run, and config to
         run it with.
         """
-
         def __init__(self,
                      config,
                      test_class,
@@ -218,45 +218,22 @@
 
         self._log_path = None
 
-    def setup_logger(self):
-        """Sets up logging for the next test run.
+    @contextlib.contextmanager
+    def mobly_logger(self):
+        """Starts and stops a logging context for a Mobly test run.
 
-        This is called automatically in 'run', so normally, this method doesn't
-        need to be called. Only use this method if you want to use Mobly's
-        logger before the test run starts.
-
-        .. code-block:: python
-
-            tr = TestRunner(...)
-            tr.setup_logger()
-            logging.info(...)
-            tr.run()
-
+        Yields:
+            The host file path where the logs for the test run are stored.
         """
-        if self._log_path is not None:
-            return
-
         self._start_time = logger.get_log_file_timestamp()
         self._log_path = os.path.join(self._log_dir, self._test_bed_name,
                                       self._start_time)
         logger.setup_test_logger(self._log_path, self._test_bed_name)
-
-    def _teardown_logger(self):
-        """Tears down logging at the end of the test run.
-
-        This is called automatically in 'run', so normally, this method doesn't
-        need to be called. Only use this to change the logger teardown
-        behaviour.
-
-        Raises:
-            Error: if this is called before the logger is setup.
-        """
-        if self._log_path is None:
-            raise Error('TestRunner._teardown_logger() called before '
-                        'TestRunner.setup_logger()!')
-
-        logger.kill_test_logger(logging.getLogger())
-        self._log_path = None
+        try:
+            yield self._log_path
+        finally:
+            logger.kill_test_logger(logging.getLogger())
+            self._log_path = None
 
     def add_test_class(self, config, test_class, tests=None, name_suffix=None):
         """Adds tests to the execution plan of this TestRunner.
@@ -327,7 +304,6 @@
         if not self._test_run_infos:
             raise Error('No tests to execute.')
 
-        self.setup_logger()
         summary_writer = records.TestSummaryWriter(
             os.path.join(self._log_path, records.OUTPUT_FILE_SUMMARY))
         try:
@@ -353,4 +329,3 @@
                 self._test_bed_name, self._start_time,
                 self.results.summary_str())
             logging.info(msg.strip())
-            self._teardown_logger()
diff --git a/tests/mobly/output_test.py b/tests/mobly/output_test.py
index 6caf9b8..743c342 100755
--- a/tests/mobly/output_test.py
+++ b/tests/mobly/output_test.py
@@ -40,7 +40,6 @@
     """This test class has unit tests for the implementation of Mobly's output
     files.
     """
-
     def setUp(self):
         self.tmp_dir = tempfile.mkdtemp()
         self.base_mock_test_config = config_parser.TestRunConfig()
@@ -94,6 +93,11 @@
             for item in blacklist:
                 self.assertNotIn(item, content)
 
+    def test_yields_logging_path(self):
+        tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
+        with tr.mobly_logger() as log_path:
+            self.assertEqual(log_path, logging.log_path)
+
     @unittest.skipIf(platform.system() == 'Windows',
                      'Symlinks are usually specific to Unix operating systems')
     def test_symlink(self):
@@ -101,7 +105,8 @@
         mock_test_config = self.create_mock_test_config(
             self.base_mock_test_config)
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
-        tr.setup_logger()
+        with tr.mobly_logger():
+            pass
         symlink = os.path.join(self.log_dir, self.test_bed_name, 'latest')
         self.assertEqual(os.readlink(symlink), logging.log_path)
 
@@ -117,8 +122,8 @@
         mock_test_config = self.create_mock_test_config(
             self.base_mock_test_config)
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
-        tr.setup_logger()
-        tr._teardown_logger()
+        with tr.mobly_logger():
+            pass
         shortcut = shell.CreateShortCut(shortcut_path)
         # Normalize paths for case and truncation
         normalized_shortcut_path = os.path.normcase(
@@ -127,7 +132,7 @@
             win32file.GetLongPathName(logging.log_path))
         self.assertEqual(normalized_shortcut_path, normalized_logger_path)
 
-    def test_setup_logger_before_run(self):
+    def test_logging_before_run(self):
         """Verifies the expected output files from a test run.
 
         * Files are correctly created.
@@ -138,21 +143,23 @@
         info_uuid = 'e098d4ff-4e90-4e08-b369-aa84a7ef90ec'
         debug_uuid = 'c6f1474e-960a-4df8-8305-1c5b8b905eca'
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
-        tr.setup_logger()
-        logging.info(info_uuid)
-        logging.debug(debug_uuid)
-        tr.add_test_class(mock_test_config, integration_test.IntegrationTest)
-        tr.run()
+        with tr.mobly_logger():
+            logging.info(info_uuid)
+            logging.debug(debug_uuid)
+            tr.add_test_class(mock_test_config,
+                              integration_test.IntegrationTest)
+            tr.run()
         output_dir = logging.log_path
         (summary_file_path, debug_log_path,
          info_log_path) = self.assert_output_logs_exist(output_dir)
-        self.assert_log_contents(
-            debug_log_path, whitelist=[debug_uuid, info_uuid])
-        self.assert_log_contents(
-            info_log_path, whitelist=[info_uuid], blacklist=[debug_uuid])
+        self.assert_log_contents(debug_log_path,
+                                 whitelist=[debug_uuid, info_uuid])
+        self.assert_log_contents(info_log_path,
+                                 whitelist=[info_uuid],
+                                 blacklist=[debug_uuid])
 
-    @mock.patch(
-        'mobly.logger.get_log_file_timestamp', side_effect=str(time.time()))
+    @mock.patch('mobly.logger.get_log_file_timestamp',
+                side_effect=str(time.time()))
     def test_run_twice_for_two_sets_of_logs(self, mock_timestamp):
         """Verifies the expected output files from a test run.
 
@@ -163,17 +170,18 @@
             self.base_mock_test_config)
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
         tr.add_test_class(mock_test_config, integration_test.IntegrationTest)
-        tr.setup_logger()
-        tr.run()
+        with tr.mobly_logger():
+            tr.run()
         output_dir1 = logging.log_path
-        tr.run()
+        with tr.mobly_logger():
+            tr.run()
         output_dir2 = logging.log_path
         self.assertNotEqual(output_dir1, output_dir2)
         self.assert_output_logs_exist(output_dir1)
         self.assert_output_logs_exist(output_dir2)
 
-    @mock.patch(
-        'mobly.logger.get_log_file_timestamp', side_effect=str(time.time()))
+    @mock.patch('mobly.logger.get_log_file_timestamp',
+                side_effect=str(time.time()))
     def test_teardown_erases_logs(self, mock_timestamp):
         """Verifies the expected output files from a test run.
 
@@ -188,16 +196,14 @@
         debug_uuid2 = 'd564da87-c42f-49c3-b0bf-18fa97cf0218'
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
 
-        tr.setup_logger()
-        logging.info(info_uuid1)
-        logging.debug(debug_uuid1)
-        tr._teardown_logger()
+        with tr.mobly_logger():
+            logging.info(info_uuid1)
+            logging.debug(debug_uuid1)
         output_dir1 = logging.log_path
 
-        tr.setup_logger()
-        logging.info(info_uuid2)
-        logging.debug(debug_uuid2)
-        tr._teardown_logger()
+        with tr.mobly_logger():
+            logging.info(info_uuid2)
+            logging.debug(debug_uuid2)
         output_dir2 = logging.log_path
 
         self.assertNotEqual(output_dir1, output_dir2)
@@ -207,14 +213,12 @@
         (summary_file_path2, debug_log_path2,
          info_log_path2) = self.get_output_logs(output_dir2)
 
-        self.assert_log_contents(
-            debug_log_path1,
-            whitelist=[debug_uuid1, info_uuid1],
-            blacklist=[info_uuid2, debug_uuid2])
-        self.assert_log_contents(
-            debug_log_path2,
-            whitelist=[debug_uuid2, info_uuid2],
-            blacklist=[info_uuid1, debug_uuid1])
+        self.assert_log_contents(debug_log_path1,
+                                 whitelist=[debug_uuid1, info_uuid1],
+                                 blacklist=[info_uuid2, debug_uuid2])
+        self.assert_log_contents(debug_log_path2,
+                                 whitelist=[debug_uuid2, info_uuid2],
+                                 blacklist=[info_uuid1, debug_uuid1])
 
     def test_basic_output(self):
         """Verifies the expected output files from a test run.
@@ -225,8 +229,10 @@
         mock_test_config = self.create_mock_test_config(
             self.base_mock_test_config)
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
-        tr.add_test_class(mock_test_config, integration_test.IntegrationTest)
-        tr.run()
+        with tr.mobly_logger():
+            tr.add_test_class(mock_test_config,
+                              integration_test.IntegrationTest)
+            tr.run()
         output_dir = logging.log_path
         (summary_file_path, debug_log_path,
          info_log_path) = self.assert_output_logs_exist(output_dir)
@@ -236,8 +242,9 @@
                 self.assertTrue(entry['Type'])
                 summary_entries.append(entry)
         self.assert_log_contents(debug_log_path, whitelist=['DEBUG', 'INFO'])
-        self.assert_log_contents(
-            info_log_path, whitelist=['INFO'], blacklist=['DEBUG'])
+        self.assert_log_contents(info_log_path,
+                                 whitelist=['INFO'],
+                                 blacklist=['DEBUG'])
 
     def test_teardown_class_output(self):
         """Verifies the summary file includes the failure record for
@@ -245,9 +252,11 @@
         """
         mock_test_config = self.base_mock_test_config.copy()
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
-        tr.add_test_class(mock_test_config,
-                          teardown_class_failure_test.TearDownClassFailureTest)
-        tr.run()
+        with tr.mobly_logger():
+            tr.add_test_class(
+                mock_test_config,
+                teardown_class_failure_test.TearDownClassFailureTest)
+            tr.run()
         output_dir = logging.log_path
         summary_file_path = os.path.join(output_dir,
                                          records.OUTPUT_FILE_SUMMARY)
diff --git a/tests/mobly/test_runner_test.py b/tests/mobly/test_runner_test.py
index 6806f94..0987454 100755
--- a/tests/mobly/test_runner_test.py
+++ b/tests/mobly/test_runner_test.py
@@ -38,7 +38,6 @@
     """This test class has unit tests for the implementation of everything
     under mobly.test_runner.
     """
-
     def setUp(self):
         self.tmp_dir = tempfile.mkdtemp()
         self.base_mock_test_config = config_parser.TestRunConfig()
@@ -79,11 +78,14 @@
         }]
         mock_test_config.controller_configs[mock_ctrlr_config_name] = my_config
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
-        tr.add_test_class(mock_test_config, integration_test.IntegrationTest)
-        tr.run()
+        with tr.mobly_logger():
+            tr.add_test_class(mock_test_config,
+                              integration_test.IntegrationTest)
+            tr.run()
         self.assertTrue(
             mock_test_config.controller_configs[mock_ctrlr_config_name][0])
-        tr.run()
+        with tr.mobly_logger():
+            tr.run()
         results = tr.results.summary_dict()
         self.assertEqual(results['Requested'], 2)
         self.assertEqual(results['Executed'], 2)
@@ -127,8 +129,10 @@
         }]
         mock_test_config.controller_configs[mock_ctrlr_config_name] = my_config
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
-        tr.add_test_class(mock_test_config, integration_test.IntegrationTest)
-        tr.run()
+        with tr.mobly_logger():
+            tr.add_test_class(mock_test_config,
+                              integration_test.IntegrationTest)
+            tr.run()
         summary_path = os.path.join(logging.log_path,
                                     records.OUTPUT_FILE_SUMMARY)
         with io.open(summary_path, 'r', encoding='utf-8') as f:
@@ -144,18 +148,14 @@
         self.assertEqual(summary_entries[3]['Type'],
                          records.TestSummaryEntryType.SUMMARY.value)
 
-    @mock.patch(
-        'mobly.controllers.android_device_lib.adb.AdbProxy',
-        return_value=mock_android_device.MockAdbProxy(1))
-    @mock.patch(
-        'mobly.controllers.android_device_lib.fastboot.FastbootProxy',
-        return_value=mock_android_device.MockFastbootProxy(1))
-    @mock.patch(
-        'mobly.controllers.android_device.list_adb_devices',
-        return_value=['1'])
-    @mock.patch(
-        'mobly.controllers.android_device.get_all_instances',
-        return_value=mock_android_device.get_mock_ads(1))
+    @mock.patch('mobly.controllers.android_device_lib.adb.AdbProxy',
+                return_value=mock_android_device.MockAdbProxy(1))
+    @mock.patch('mobly.controllers.android_device_lib.fastboot.FastbootProxy',
+                return_value=mock_android_device.MockFastbootProxy(1))
+    @mock.patch('mobly.controllers.android_device.list_adb_devices',
+                return_value=['1'])
+    @mock.patch('mobly.controllers.android_device.get_all_instances',
+                return_value=mock_android_device.get_mock_ads(1))
     def test_run_two_test_classes(self, mock_get_all, mock_list_adb,
                                   mock_fastboot, mock_adb):
         """Verifies that running more than one test class in one test run works
@@ -178,9 +178,12 @@
             'serial': '1'
         }]
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
-        tr.add_test_class(mock_test_config, integration2_test.Integration2Test)
-        tr.add_test_class(mock_test_config, integration_test.IntegrationTest)
-        tr.run()
+        with tr.mobly_logger():
+            tr.add_test_class(mock_test_config,
+                              integration2_test.Integration2Test)
+            tr.add_test_class(mock_test_config,
+                              integration_test.IntegrationTest)
+            tr.run()
         results = tr.results.summary_dict()
         self.assertEqual(results['Requested'], 2)
         self.assertEqual(results['Executed'], 2)
@@ -203,15 +206,14 @@
         config2 = config1.copy()
         config2.user_params['icecream'] = 10
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
-        tr.add_test_class(
-            config1,
-            integration_test.IntegrationTest,
-            name_suffix='FirstConfig')
-        tr.add_test_class(
-            config2,
-            integration_test.IntegrationTest,
-            name_suffix='SecondConfig')
-        tr.run()
+        with tr.mobly_logger():
+            tr.add_test_class(config1,
+                              integration_test.IntegrationTest,
+                              name_suffix='FirstConfig')
+            tr.add_test_class(config2,
+                              integration_test.IntegrationTest,
+                              name_suffix='SecondConfig')
+            tr.run()
         results = tr.results.summary_dict()
         self.assertEqual(results['Requested'], 2)
         self.assertEqual(results['Executed'], 2)
@@ -226,9 +228,11 @@
     def test_run_with_abort_all(self):
         mock_test_config = self.base_mock_test_config.copy()
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
-        tr.add_test_class(mock_test_config, integration3_test.Integration3Test)
-        with self.assertRaises(signals.TestAbortAll):
-            tr.run()
+        with tr.mobly_logger():
+            tr.add_test_class(mock_test_config,
+                              integration3_test.Integration3Test)
+            with self.assertRaises(signals.TestAbortAll):
+                tr.run()
         results = tr.results.summary_dict()
         self.assertEqual(results['Requested'], 1)
         self.assertEqual(results['Executed'], 0)
@@ -255,25 +259,15 @@
             tr.add_test_class(self.base_mock_test_config,
                               integration_test.IntegrationTest)
 
-    def test_teardown_logger_before_setup_logger(self):
-        tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
-        with self.assertRaisesRegex(
-                test_runner.Error,
-                r'TestRunner\._teardown_logger\(\) called before '
-                r'TestRunner\.setup_logger\(\)!'):
-            tr._teardown_logger()
-
     def test_run_no_tests(self):
         tr = test_runner.TestRunner(self.log_dir, self.test_bed_name)
         with self.assertRaisesRegex(test_runner.Error, 'No tests to execute.'):
             tr.run()
 
-    @mock.patch(
-        'mobly.test_runner._find_test_class',
-        return_value=type('SampleTest', (), {}))
-    @mock.patch(
-        'mobly.test_runner.config_parser.load_test_config_file',
-        return_value=[config_parser.TestRunConfig()])
+    @mock.patch('mobly.test_runner._find_test_class',
+                return_value=type('SampleTest', (), {}))
+    @mock.patch('mobly.test_runner.config_parser.load_test_config_file',
+                return_value=[config_parser.TestRunConfig()])
     @mock.patch('mobly.test_runner.TestRunner', return_value=mock.MagicMock())
     def test_main_parse_args(self, mock_test_runner, mock_config,
                              mock_find_test):
diff --git a/tests/mobly/test_suite_test.py b/tests/mobly/test_suite_test.py
index 9b1efc1..183042d 100755
--- a/tests/mobly/test_suite_test.py
+++ b/tests/mobly/test_suite_test.py
@@ -33,7 +33,6 @@
 
     Tests here target a combination of test_runner and base_test code.
     """
-
     def setUp(self):
         self.tmp_dir = tempfile.mkdtemp()
         self.mock_test_cls_configs = config_parser.TestRunConfig()
@@ -69,9 +68,10 @@
 
         tr = test_runner.TestRunner(self.tmp_dir,
                                     test_run_config.test_bed_name)
-        tr.add_test_class(test_run_config, FooTest)
-        tr.add_test_class(test_run_config, BarTest)
-        tr.run()
+        with tr.mobly_logger():
+            tr.add_test_class(test_run_config, FooTest)
+            tr.add_test_class(test_run_config, BarTest)
+            tr.run()
         self.assertIsNot(self.controller1, self.controller2)