Merge pull request #661 from winterfroststrom/concurrent_exec_raises
Improve concurrent_exec util
diff --git a/mobly/utils.py b/mobly/utils.py
index 68ce39f..fa644a2 100644
--- a/mobly/utils.py
+++ b/mobly/utils.py
@@ -29,6 +29,7 @@
import signal
import string
import subprocess
+import sys
import threading
import time
import traceback
@@ -256,7 +257,8 @@


 # Thead/Process related functions.
-def concurrent_exec(func, param_list):
+def concurrent_exec(func, param_list, max_workers=30,
+                    raise_on_exception=False):
     """Executes a function with different parameters pseudo-concurrently.

     This is basically a map function. Each element (should be an iterable) in
@@ -267,16 +269,29 @@
         func: The function that parforms a task.
         param_list: A list of iterables, each being a set of params to be
             passed into the function.
+        max_workers: int, the number of worker threads used to run the tasks.
+            `None` lets ThreadPoolExecutor pick its own default (unsupported
+            on Python 3.4). By default, this is 30 workers.
+        raise_on_exception: bool, if `True`, raise a single RuntimeError
+            describing every task failure once all tasks have finished,
+            instead of returning the results. By default, this is `False`.

     Returns:
         A list of return values from each function execution. If an execution
         caused an exception, the exception object will be the corresponding
         result.
+
+    Raises:
+        RuntimeError: If executing any of the tasks failed and
+            `raise_on_exception` is True. The message holds every failure
+            (with its traceback on Python 3), separated by blank lines.
     """
-    with concurrent.futures.ThreadPoolExecutor(max_workers=30) as executor:
+    with concurrent.futures.ThreadPoolExecutor(
+            max_workers=max_workers) as executor:
         # Start the load operations and mark each future with its params
         future_to_params = {executor.submit(func, *p): p for p in param_list}
         return_vals = []
+        exceptions = []
         for future in concurrent.futures.as_completed(future_to_params):
             params = future_to_params[future]
             try:
@@ -285,6 +300,28 @@
                 logging.exception("{} generated an exception: {}".format(
                     params, traceback.format_exc()))
                 return_vals.append(exc)
+                exceptions.append(exc)
+        if raise_on_exception and exceptions:
+            error_messages = []
+            if sys.version_info < (3, 0):
+                # Python 2 exceptions have no `__traceback__`. Their deprecated
+                # `message` attribute is empty for multi-argument exceptions and
+                # may be unicode or a non-string, which `unicode(msg, encoding)`
+                # rejects, so format the exception type and value instead.
+                for exception in exceptions:
+                    message = ''.join(
+                        traceback.format_exception_only(
+                            exception.__class__, exception))
+                    if isinstance(message, bytes):
+                        message = message.decode('utf-8', 'replace')
+                    error_messages.append(message)
+            else:
+                for exception in exceptions:
+                    error_messages.append(''.join(
+                        traceback.format_exception(exception.__class__,
+                                                   exception,
+                                                   exception.__traceback__)))
+            raise RuntimeError('\n\n'.join(error_messages))
         return return_vals


diff --git a/tests/mobly/utils_test.py b/tests/mobly/utils_test.py
index 207273c..e4dca9c 100755
--- a/tests/mobly/utils_test.py
+++ b/tests/mobly/utils_test.py
@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from concurrent import futures
import io
import mock
import os
import platform
import shutil
import socket
+import sys
import subprocess
import tempfile
import time
@@ -43,7 +45,6 @@
"""This test class has unit tests for the implementation of everything
under mobly.utils.
"""
-
def setUp(self):
system = platform.system()
self.tmp_dir = tempfile.mkdtemp()
@@ -166,6 +167,235 @@
utils.stop_standing_subprocess(p)
self.assertFalse(p1.is_running())
+    @unittest.skipIf((3, 4) <= sys.version_info < (3, 5),
+                     'Python 3.4 does not support `None` max_workers.')
+    def test_concurrent_exec_when_none_workers(self):
+        def add_pair(x, y):
+            return x + y
+
+        real_executor = futures.ThreadPoolExecutor
+        with mock.patch.object(futures, 'ThreadPoolExecutor',
+                               wraps=real_executor) as executor_spy:
+            outputs = utils.concurrent_exec(add_pair, [(1, 1), (2, 2)],
+                                            max_workers=None)
+
+        executor_spy.assert_called_once_with(max_workers=None)
+        # Results arrive in completion order, so only membership is checked.
+        self.assertEqual(len(outputs), 2)
+        for expected in (2, 4):
+            self.assertIn(expected, outputs)
+
+    def test_concurrent_exec_when_default_max_workers(self):
+        def add_pair(x, y):
+            return x + y
+
+        real_executor = futures.ThreadPoolExecutor
+        with mock.patch.object(futures, 'ThreadPoolExecutor',
+                               wraps=real_executor) as executor_spy:
+            outputs = utils.concurrent_exec(add_pair, [(1, 1), (2, 2)])
+
+        # Without an explicit max_workers, the pool is sized at 30 threads.
+        executor_spy.assert_called_once_with(max_workers=30)
+        self.assertEqual(len(outputs), 2)
+        for expected in (2, 4):
+            self.assertIn(expected, outputs)
+
+    def test_concurrent_exec_when_custom_max_workers(self):
+        def add_pair(x, y):
+            return x + y
+
+        real_executor = futures.ThreadPoolExecutor
+        with mock.patch.object(futures, 'ThreadPoolExecutor',
+                               wraps=real_executor) as executor_spy:
+            outputs = utils.concurrent_exec(add_pair, [(1, 1), (2, 2)],
+                                            max_workers=1)
+        # The caller-supplied worker count is forwarded to the pool as-is.
+        executor_spy.assert_called_once_with(max_workers=1)
+        self.assertEqual(len(outputs), 2)
+        for expected in (2, 4):
+            self.assertIn(expected, outputs)
+
+    def test_concurrent_exec_makes_all_calls(self):
+        param_sets = [
+            (1, 1),
+            (2, 2),
+            (3, 3),
+        ]
+        recorder = mock.MagicMock()
+
+        utils.concurrent_exec(recorder, param_sets)
+
+        self.assertEqual(recorder.call_count, len(param_sets))
+        recorder.assert_has_calls([mock.call(*p) for p in param_sets],
+                                  any_order=True)
+
+    def test_concurrent_exec_generates_results(self):
+        def add_pair(x, y):
+            return x + y
+
+        outputs = utils.concurrent_exec(add_pair, [(1, 1), (2, 2)])
+        self.assertEqual(len(outputs), 2)
+        for expected in (2, 4):
+            self.assertIn(expected, outputs)
+
+    def test_concurrent_exec_when_exception_makes_all_calls(self):
+        mock_call_recorder = mock.MagicMock()
+
+        def fake_int(a):
+            mock_call_recorder(a)
+            return int(a)
+
+        # Only the calls made are checked here, so the results are discarded.
+        _ = utils.concurrent_exec(fake_int, [
+            (1, ),
+            ('123', ),
+            ('not_int', ),
+            (5435, ),
+        ])
+
+        self.assertEqual(mock_call_recorder.call_count, 4)
+        mock_call_recorder.assert_has_calls([
+            mock.call(1),
+            mock.call('123'),
+            mock.call('not_int'),
+            mock.call(5435),
+        ], any_order=True)
+
+    def test_concurrent_exec_when_exception_generates_results(self):
+        call_log = mock.MagicMock()
+
+        def parse_int(value):
+            call_log(value)
+            return int(value)
+
+        outputs = utils.concurrent_exec(parse_int, [
+            (1, ),
+            ('123', ),
+            ('not_int', ),
+            (5435, ),
+        ])
+
+        self.assertEqual(len(outputs), 4)
+        for expected in (1, 123, 5435):
+            self.assertIn(expected, outputs)
+        # A failed task contributes its exception object instead of a value.
+        errors = [
+            output for output in outputs if isinstance(output, Exception)
+        ]
+        self.assertEqual(len(errors), 1)
+        self.assertIsInstance(errors[0], ValueError)
+
+    def test_concurrent_exec_when_multiple_exceptions_makes_all_calls(self):
+        mock_call_recorder = mock.MagicMock()
+
+        def fake_int(a):
+            mock_call_recorder(a)
+            return int(a)
+
+        # Only the calls made are checked here, so the results are discarded.
+        _ = utils.concurrent_exec(fake_int, [
+            (1, ),
+            ('not_int1', ),
+            ('not_int2', ),
+            (5435, ),
+        ])
+
+        self.assertEqual(mock_call_recorder.call_count, 4)
+        mock_call_recorder.assert_has_calls([
+            mock.call(1),
+            mock.call('not_int1'),
+            mock.call('not_int2'),
+            mock.call(5435),
+        ], any_order=True)
+
+    def test_concurrent_exec_when_multiple_exceptions_generates_results(self):
+        call_log = mock.MagicMock()
+
+        def parse_int(value):
+            call_log(value)
+            return int(value)
+
+        outputs = utils.concurrent_exec(parse_int, [
+            (1, ),
+            ('not_int1', ),
+            ('not_int2', ),
+            (5435, ),
+        ])
+
+        self.assertEqual(len(outputs), 4)
+        for expected in (1, 5435):
+            self.assertIn(expected, outputs)
+        errors = [
+            output for output in outputs if isinstance(output, Exception)
+        ]
+        self.assertEqual(len(errors), 2)
+        for error in errors:
+            self.assertIsInstance(error, ValueError)
+        self.assertNotEqual(errors[0], errors[1])
+
+    def test_concurrent_exec_when_raising_exception_generates_results(self):
+        def add_pair(x, y):
+            return x + y
+
+        outputs = utils.concurrent_exec(add_pair, [(1, 1), (2, 2)],
+                                        raise_on_exception=True)
+        self.assertEqual(len(outputs), 2)
+        for expected in (2, 4):
+            self.assertIn(expected, outputs)
+
+    def test_concurrent_exec_when_raising_exception_makes_all_calls(self):
+        call_log = mock.MagicMock()
+
+        def parse_int(value):
+            call_log(value)
+            return int(value)
+
+        inputs = [
+            (1, ),
+            ('123', ),
+            ('not_int', ),
+            (5435, ),
+        ]
+        # The aggregated RuntimeError must mention the failing input.
+        with self.assertRaisesRegex(RuntimeError, '.*not_int.*'):
+            utils.concurrent_exec(parse_int,
+                                  inputs,
+                                  raise_on_exception=True)
+
+        # Every task still runs even though one of them fails, so all four
+        # calls must have been recorded (in any order).
+        self.assertEqual(call_log.call_count, len(inputs))
+        call_log.assert_has_calls([mock.call(*args) for args in inputs],
+                                  any_order=True)
+
+    def test_concurrent_exec_when_raising_multiple_exceptions_makes_all_calls(
+            self):
+        mock_call_recorder = mock.MagicMock()
+
+        def fake_int(a):
+            mock_call_recorder(a)
+            return int(a)
+
+        # Both task failures must be reported, in either order. `(?s)` makes
+        # `.` match newlines too, so each alternative can span the blank line
+        # that separates the two error messages.
+        with self.assertRaisesRegex(
+                RuntimeError, r'(?s)not_int1.+not_int2|not_int2.+not_int1'):
+            _ = utils.concurrent_exec(fake_int, [
+                (1, ),
+                ('not_int1', ),
+                ('not_int2', ),
+                (5435, ),
+            ], raise_on_exception=True)
+
+        self.assertEqual(mock_call_recorder.call_count, 4)
+        mock_call_recorder.assert_has_calls([
+            mock.call(1),
+            mock.call('not_int1'),
+            mock.call('not_int2'),
+            mock.call(5435),
+        ], any_order=True)
+
def test_create_dir(self):
new_path = os.path.join(self.tmp_dir, 'haha')
self.assertFalse(os.path.exists(new_path))