Standardize Async RPC for snippet client v2 (#826)
* Provide a simple abstract base class for callback handlers of all platforms.
* Implement the callback handler v2 for Android.
diff --git a/mobly/controllers/android_device_lib/callback_handler_v2.py b/mobly/controllers/android_device_lib/callback_handler_v2.py
new file mode 100644
index 0000000..5675f7a
--- /dev/null
+++ b/mobly/controllers/android_device_lib/callback_handler_v2.py
@@ -0,0 +1,67 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The callback handler V2 module for Android Mobly Snippet Lib."""
+
+from mobly.snippet import callback_handler_base
+from mobly.snippet import errors
+
+# The timeout error meesage when pulling events from the server
+TIMEOUT_ERROR_MESSAGE = 'EventSnippetException: timeout.'
+
+
+class CallbackHandlerV2(callback_handler_base.CallbackHandlerBase):
+ """The callback handler V2 class for Android Mobly Snippet Lib."""
+
+ def callEventWaitAndGetRpc(self, callback_id, event_name, timeout_sec):
+ """Waits and returns an existing CallbackEvent for the specified identifier.
+
+ This function calls snippet lib's eventWaitAndGet RPC.
+
+ Args:
+ callback_id: str, the callback identifier.
+ event_name: str, the callback name.
+ timeout_sec: float, the number of seconds to wait for the event.
+
+ Returns:
+ The event dictionary.
+
+ Raises:
+ errors.CallbackHandlerTimeoutError: The expected event does not occur
+ within the time limit.
+ """
+ timeout_ms = int(timeout_sec * 1000)
+ try:
+ return self._event_client.eventWaitAndGet(callback_id, event_name,
+ timeout_ms)
+ except Exception as e:
+ if TIMEOUT_ERROR_MESSAGE in str(e):
+ raise errors.CallbackHandlerTimeoutError(
+ self._device, (f'Timed out after waiting {timeout_sec}s for event '
+ f'"{event_name}" triggered by {self._method_name} '
+ f'({self.callback_id}).')) from e
+ raise
+
+ def callEventGetAllRpc(self, callback_id, event_name):
+ """Gets all existing events for the specified identifier without waiting.
+
+ This function calls snippet lib's eventGetAll RPC.
+
+ Args:
+ callback_id: str, the callback identifier.
+ event_name: str, the callback name.
+
+ Returns:
+ A list of event dictionaries.
+ """
+ return self._event_client.eventGetAll(callback_id, event_name)
diff --git a/mobly/controllers/android_device_lib/snippet_client_v2.py b/mobly/controllers/android_device_lib/snippet_client_v2.py
index 0c9540a..3adfde5 100644
--- a/mobly/controllers/android_device_lib/snippet_client_v2.py
+++ b/mobly/controllers/android_device_lib/snippet_client_v2.py
@@ -20,7 +20,7 @@
from mobly import utils
from mobly.controllers.android_device_lib import adb
-from mobly.controllers.android_device_lib import callback_handler
+from mobly.controllers.android_device_lib import callback_handler_v2
from mobly.controllers.android_device_lib import errors as android_device_lib_errors
from mobly.snippet import client_base
from mobly.snippet import errors
@@ -70,7 +70,10 @@
_SOCKET_CONNECTION_TIMEOUT = 60
# Maximum time to wait for a response message on the socket.
-_SOCKET_READ_TIMEOUT = callback_handler.MAX_TIMEOUT
+_SOCKET_READ_TIMEOUT = 60 * 10
+
+# The default timeout for callback handlers returned by this client
+_CALLBACK_DEFAULT_TIMEOUT_SEC = 60 * 2
class ConnectionHandshakeCommand(enum.Enum):
@@ -505,11 +508,14 @@
"""
if self._event_client is None:
self._create_event_client()
- return callback_handler.CallbackHandler(callback_id=callback_id,
- event_client=self._event_client,
- ret_value=ret_value,
- method_name=rpc_func_name,
- ad=self._device)
+ return callback_handler_v2.CallbackHandlerV2(
+ callback_id=callback_id,
+ event_client=self._event_client,
+ ret_value=ret_value,
+ method_name=rpc_func_name,
+ device=self._device,
+ rpc_max_timeout_sec=_SOCKET_READ_TIMEOUT,
+ default_timeout_sec=_CALLBACK_DEFAULT_TIMEOUT_SEC)
def _create_event_client(self):
"""Creates a separate client to the same session for propagating events.
diff --git a/mobly/snippet/callback_event.py b/mobly/snippet/callback_event.py
new file mode 100644
index 0000000..55471c7
--- /dev/null
+++ b/mobly/snippet/callback_event.py
@@ -0,0 +1,52 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The class that represents callback events for Mobly Snippet Lib."""
+
+
+def from_dict(event_dict):
+ """Creates a CallbackEvent object from a dictionary.
+
+ Args:
+ event_dict: dict, a dictionary representing an event.
+
+ Returns:
+ A CallbackEvent object.
+ """
+ return CallbackEvent(callback_id=event_dict['callbackId'],
+ name=event_dict['name'],
+ creation_time=event_dict['time'],
+ data=event_dict['data'])
+
+
+class CallbackEvent:
+ """The class that represents callback events for Mobly Snippet Library.
+
+ Attributes:
+ callback_id: str, the callback ID associated with the event.
+ name: str, the name of the event.
+ creation_time: int, the epoch time when the event is created on the
+ RPC server side.
+ data: dict, the data held by the event. Can be None.
+ """
+
+ def __init__(self, callback_id, name, creation_time, data):
+ self.callback_id = callback_id
+ self.name = name
+ self.creation_time = creation_time
+ self.data = data
+
+ def __repr__(self):
+ return (
+ f'CallbackEvent(callback_id: {self.callback_id}, name: {self.name}, '
+ f'creation_time: {self.creation_time}, data: {self.data})')
diff --git a/mobly/snippet/callback_handler_base.py b/mobly/snippet/callback_handler_base.py
new file mode 100644
index 0000000..50465d1
--- /dev/null
+++ b/mobly/snippet/callback_handler_base.py
@@ -0,0 +1,240 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for the base class to handle Mobly Snippet Lib's callback events."""
+import abc
+import time
+
+from mobly.snippet import callback_event
+from mobly.snippet import errors
+
+
+class CallbackHandlerBase(abc.ABC):
+ """Base class for handling Mobly Snippet Lib's callback events.
+
+ All the events handled by a callback handler are originally triggered by one
+ async RPC call. All the events are tagged with a callback_id specific to a
+ call to an async RPC method defined on the server side.
+
+ The raw message representing an event looks like:
+
+ .. code-block:: python
+
+ {
+ 'callbackId': <string, callbackId>,
+ 'name': <string, name of the event>,
+ 'time': <long, epoch time of when the event was created on the
+ server side>,
+ 'data': <dict, extra data from the callback on the server side>
+ }
+
+ Each message is then used to create a CallbackEvent object on the client
+ side.
+
+ Attributes:
+ ret_value: any, the direct return value of the async RPC call.
+ """
+
+ def __init__(self,
+ callback_id,
+ event_client,
+ ret_value,
+ method_name,
+ device,
+ rpc_max_timeout_sec,
+ default_timeout_sec=120):
+ """Initializes a callback handler base object.
+
+ Args:
+ callback_id: str, the callback ID which associates with a group of
+ callback events.
+ event_client: SnippetClientV2, the client object used to send RPC to the
+ server and receive response.
+ ret_value: any, the direct return value of the async RPC call.
+ method_name: str, the name of the executed Async snippet function.
+ device: DeviceController, the device object associated with this handler.
+ rpc_max_timeout_sec: float, maximum time for sending a single RPC call.
+ default_timeout_sec: float, the default timeout for this handler. It
+ must be no longer than rpc_max_timeout_sec.
+ """
+ self._id = callback_id
+ self.ret_value = ret_value
+ self._device = device
+ self._event_client = event_client
+ self._method_name = method_name
+
+ if rpc_max_timeout_sec < default_timeout_sec:
+ raise ValueError('The max timeout of a single RPC must be no smaller '
+ 'than the default timeout of the callback handler. '
+ f'Got rpc_max_timeout_sec={rpc_max_timeout_sec}, '
+ f'default_timeout_sec={default_timeout_sec}.')
+ self._rpc_max_timeout_sec = rpc_max_timeout_sec
+ self._default_timeout_sec = default_timeout_sec
+
+ @property
+ def rpc_max_timeout_sec(self):
+ """Maximum time for sending a single RPC call."""
+ return self._rpc_max_timeout_sec
+
+ @property
+ def default_timeout_sec(self):
+ """Default timeout used by this callback handler."""
+ return self._default_timeout_sec
+
+ @property
+ def callback_id(self):
+ """The callback ID which associates a group of callback events."""
+ return self._id
+
+ @abc.abstractmethod
+ def callEventWaitAndGetRpc(self, callback_id, event_name, timeout_sec):
+ """Calls snippet lib's RPC to wait for a callback event.
+
+ Override this method to use this class with various snippet lib
+ implementations.
+
+ This function waits and gets a CallbackEvent with the specified identifier
+ from the server. It will raise a timeout error if the expected event does
+ not occur within the time limit.
+
+ Args:
+ callback_id: str, the callback identifier.
+ event_name: str, the callback name.
+ timeout_sec: float, the number of seconds to wait for the event. It is
+ already checked that this argument is no longer than the max timeout
+ of a single RPC.
+
+ Returns:
+ The event dictionary.
+
+ Raises:
+ errors.CallbackHandlerTimeoutError: Raised if the expected event does not
+ occur within the time limit.
+ """
+
+ @abc.abstractmethod
+ def callEventGetAllRpc(self, callback_id, event_name):
+ """Calls snippet lib's RPC to get all existing snippet events.
+
+ Override this method to use this class with various snippet lib
+ implementations.
+
+ This function gets all existing events in the server with the specified
+ identifier without waiting.
+
+ Args:
+ callback_id: str, the callback identifier.
+ event_name: str, the callback name.
+
+ Returns:
+ A list of event dictionaries.
+ """
+
+ def waitAndGet(self, event_name, timeout=None):
+ """Waits and gets a CallbackEvent with the specified identifier.
+
+ It will raise a timeout error if the expected event does not occur within
+ the time limit.
+
+ Args:
+ event_name: str, the name of the event to get.
+ timeout: float, the number of seconds to wait before giving up. If None,
+ it will be set to self.default_timeout_sec.
+
+ Returns:
+ CallbackEvent, the oldest entry of the specified event.
+
+ Raises:
+ errors.CallbackHandlerBaseError: If the specified timeout is longer than
+ the max timeout supported.
+ errors.CallbackHandlerTimeoutError: The expected event does not occur
+ within the time limit.
+ """
+ if timeout is None:
+ timeout = self.default_timeout_sec
+
+ if timeout:
+ if timeout > self.rpc_max_timeout_sec:
+ raise errors.CallbackHandlerBaseError(
+ self._device,
+ f'Specified timeout {timeout} is longer than max timeout '
+ f'{self.rpc_max_timeout_sec}.')
+
+ raw_event = self.callEventWaitAndGetRpc(self._id, event_name, timeout)
+ return callback_event.from_dict(raw_event)
+
+ def waitForEvent(self, event_name, predicate, timeout=None):
+ """Waits for an event of the specific name that satisfies the predicate.
+
+ This call will block until the expected event has been received or time
+ out.
+
+ The predicate function defines the condition the event is expected to
+ satisfy. It takes an event and returns True if the condition is
+ satisfied, False otherwise.
+
+ Note all events of the same name that are received but don't satisfy
+ the predicate will be discarded and not be available for further
+ consumption.
+
+ Args:
+ event_name: str, the name of the event to wait for.
+ predicate: function, a function that takes an event (dictionary) and
+ returns a bool.
+ timeout: float, the number of seconds to wait before giving up. If None,
+ it will be set to self.default_timeout_sec.
+
+ Returns:
+ dictionary, the event that satisfies the predicate if received.
+
+ Raises:
+ errors.CallbackHandlerTimeoutError: raised if no event that satisfies the
+ predicate is received after timeout seconds.
+ """
+ if timeout is None:
+ timeout = self.default_timeout_sec
+
+ deadline = time.perf_counter() + timeout
+ while time.perf_counter() <= deadline:
+ single_rpc_timeout = deadline - time.perf_counter()
+ if single_rpc_timeout < 0:
+ break
+
+ single_rpc_timeout = min(single_rpc_timeout, self.rpc_max_timeout_sec)
+ try:
+ event = self.waitAndGet(event_name, single_rpc_timeout)
+ except errors.CallbackHandlerTimeoutError:
+ # Ignoring errors.CallbackHandlerTimeoutError since we need to throw
+ # one with a more specific message.
+ break
+ if predicate(event):
+ return event
+
+ raise errors.CallbackHandlerTimeoutError(
+ self._device,
+ f'Timed out after {timeout}s waiting for an "{event_name}" event that '
+ f'satisfies the predicate "{predicate.__name__}".')
+
+ def getAll(self, event_name):
+ """Gets all existing events in the server with the specified identifier.
+
+ This is a non-blocking call.
+
+ Args:
+ event_name: str, the name of the event to get.
+
+ Returns:
+ A list of CallbackEvent, each representing an event from the Server side.
+ """
+ raw_events = self.callEventGetAllRpc(self._id, event_name)
+ return [callback_event.from_dict(msg) for msg in raw_events]
diff --git a/mobly/snippet/errors.py b/mobly/snippet/errors.py
index 764aea4..4d41adb 100644
--- a/mobly/snippet/errors.py
+++ b/mobly/snippet/errors.py
@@ -60,3 +60,12 @@
class ServerDiedError(Error):
"""Raised if the snippet server died before all tests finish."""
+
+
+# Error types for callback handlers
+class CallbackHandlerBaseError(errors.DeviceError):
+ """Base error type for snippet clients."""
+
+
+class CallbackHandlerTimeoutError(Error):
+ """Raised if the expected event does not occur within the time limit."""
diff --git a/tests/mobly/controllers/android_device_lib/callback_handler_v2_test.py b/tests/mobly/controllers/android_device_lib/callback_handler_v2_test.py
new file mode 100644
index 0000000..b598cae
--- /dev/null
+++ b/tests/mobly/controllers/android_device_lib/callback_handler_v2_test.py
@@ -0,0 +1,152 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for callback_handler_v2.CallbackHandlerV2."""
+
+import unittest
+from unittest import mock
+
+from mobly.controllers.android_device_lib import callback_handler_v2
+from mobly.snippet import callback_event
+from mobly.snippet import errors
+
+MOCK_CALLBACK_ID = '2-1'
+MOCK_RAW_EVENT = {
+ 'callbackId': '2-1',
+ 'name': 'AsyncTaskResult',
+ 'time': 20460228696,
+ 'data': {
+ 'exampleData': "Here's a simple event.",
+ 'successful': True,
+ 'secretNumber': 12
+ }
+}
+
+
+class CallbackHandlerV2Test(unittest.TestCase):
+ """Unit tests for callback_handler_v2.CallbackHandlerV2."""
+
+ def _make_callback_handler(self,
+ callback_id=None,
+ event_client=None,
+ ret_value=None,
+ method_name=None,
+ device=None,
+ rpc_max_timeout_sec=600,
+ default_timeout_sec=120):
+ return callback_handler_v2.CallbackHandlerV2(
+ callback_id=callback_id,
+ event_client=event_client,
+ ret_value=ret_value,
+ method_name=method_name,
+ device=device,
+ rpc_max_timeout_sec=rpc_max_timeout_sec,
+ default_timeout_sec=default_timeout_sec)
+
+ def assert_event_correct(self, actual_event, expected_raw_event_dict):
+ expected_event = callback_event.from_dict(expected_raw_event_dict)
+ self.assertEqual(str(actual_event), str(expected_event))
+
+ def test_wait_and_get(self):
+ mock_event_client = mock.Mock()
+ mock_event_client.eventWaitAndGet = mock.Mock(return_value=MOCK_RAW_EVENT)
+ handler = self._make_callback_handler(callback_id=MOCK_CALLBACK_ID,
+ event_client=mock_event_client)
+ event = handler.waitAndGet('ha')
+ self.assert_event_correct(event, MOCK_RAW_EVENT)
+ mock_event_client.eventWaitAndGet.assert_called_once_with(
+ MOCK_CALLBACK_ID, 'ha', mock.ANY)
+
+ def test_wait_and_get_timeout_arg_transform(self):
+ mock_event_client = mock.Mock()
+ mock_event_client.eventWaitAndGet = mock.Mock(return_value=MOCK_RAW_EVENT)
+ handler = self._make_callback_handler(event_client=mock_event_client)
+
+ wait_and_get_timeout_sec = 10
+ expected_rpc_timeout_ms = 10000
+ _ = handler.waitAndGet('ha', timeout=wait_and_get_timeout_sec)
+ mock_event_client.eventWaitAndGet.assert_called_once_with(
+ mock.ANY, mock.ANY, expected_rpc_timeout_ms)
+
+ def test_wait_for_event(self):
+ mock_event_client = mock.Mock()
+ handler = self._make_callback_handler(callback_id=MOCK_CALLBACK_ID,
+ event_client=mock_event_client)
+
+ event_should_ignore = {
+ 'callbackId': '2-1',
+ 'name': 'AsyncTaskResult',
+ 'time': 20460228696,
+ 'data': {
+ 'successful': False,
+ }
+ }
+ mock_event_client.eventWaitAndGet.side_effect = [
+ event_should_ignore, MOCK_RAW_EVENT
+ ]
+
+ def some_condition(event):
+ return event.data['successful']
+
+ event = handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)
+ self.assert_event_correct(event, MOCK_RAW_EVENT)
+ mock_event_client.eventWaitAndGet.assert_has_calls([
+ mock.call(MOCK_CALLBACK_ID, 'AsyncTaskResult', mock.ANY),
+ mock.call(MOCK_CALLBACK_ID, 'AsyncTaskResult', mock.ANY),
+ ])
+
+ def test_get_all(self):
+ mock_event_client = mock.Mock()
+ handler = self._make_callback_handler(callback_id=MOCK_CALLBACK_ID,
+ event_client=mock_event_client)
+
+ mock_event_client.eventGetAll = mock.Mock(
+ return_value=[MOCK_RAW_EVENT, MOCK_RAW_EVENT])
+
+ all_events = handler.getAll('ha')
+ self.assertEqual(len(all_events), 2)
+ for event in all_events:
+ self.assert_event_correct(event, MOCK_RAW_EVENT)
+
+ mock_event_client.eventGetAll.assert_called_once_with(
+ MOCK_CALLBACK_ID, 'ha')
+
+ def test_wait_and_get_timeout_message_pattern_matches(self):
+ mock_event_client = mock.Mock()
+ android_snippet_timeout_msg = (
+ 'com.google.android.mobly.snippet.event.EventSnippet$'
+ 'EventSnippetException: timeout.')
+ mock_event_client.eventWaitAndGet = mock.Mock(
+ side_effect=errors.ApiError(mock.Mock(), android_snippet_timeout_msg))
+ handler = self._make_callback_handler(event_client=mock_event_client,
+ method_name='test_method')
+
+ expected_msg = ('Timed out after waiting .*s for event "ha" triggered by '
+ 'test_method .*')
+ with self.assertRaisesRegex(errors.CallbackHandlerTimeoutError,
+ expected_msg):
+ handler.waitAndGet('ha')
+
+ def test_wait_and_get_reraise_if_pattern_not_match(self):
+ mock_event_client = mock.Mock()
+ snippet_timeout_msg = 'Snippet executed with error.'
+ mock_event_client.eventWaitAndGet = mock.Mock(
+ side_effect=errors.ApiError(mock.Mock(), snippet_timeout_msg))
+ handler = self._make_callback_handler(event_client=mock_event_client)
+
+ with self.assertRaisesRegex(errors.ApiError, snippet_timeout_msg):
+ handler.waitAndGet('ha')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/mobly/controllers/android_device_lib/snippet_client_v2_test.py b/tests/mobly/controllers/android_device_lib/snippet_client_v2_test.py
index c7e8ef1..1943abb 100644
--- a/tests/mobly/controllers/android_device_lib/snippet_client_v2_test.py
+++ b/tests/mobly/controllers/android_device_lib/snippet_client_v2_test.py
@@ -209,8 +209,8 @@
@mock.patch('mobly.utils.stop_standing_subprocess')
@mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
'utils.start_standing_subprocess')
- @mock.patch('mobly.controllers.android_device_lib.callback_handler.'
- 'CallbackHandler')
+ @mock.patch('mobly.controllers.android_device_lib.callback_handler_v2.'
+ 'CallbackHandlerV2')
def test_the_whole_lifecycle_with_an_async_rpc(self, mock_callback_class,
mock_start_subprocess,
mock_stop_standing_subprocess,
@@ -244,11 +244,14 @@
self.assertListEqual(self.mock_socket_file.write.call_args_list,
expected_socket_writes)
- mock_callback_class.assert_called_with(callback_id='1-0',
- event_client=event_client,
- ret_value=123,
- method_name='some_async_rpc',
- ad=self.device)
+ mock_callback_class.assert_called_with(
+ callback_id='1-0',
+ event_client=event_client,
+ ret_value=123,
+ method_name='some_async_rpc',
+ device=self.device,
+ rpc_max_timeout_sec=snippet_client_v2._SOCKET_READ_TIMEOUT,
+ default_timeout_sec=snippet_client_v2._CALLBACK_DEFAULT_TIMEOUT_SEC)
self.assertIs(rpc_result, mock_callback_class.return_value)
self.assertIsNone(event_client.host_port, None)
self.assertIsNone(event_client.device_port, None)
@@ -260,8 +263,8 @@
@mock.patch('mobly.utils.stop_standing_subprocess')
@mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
'utils.start_standing_subprocess')
- @mock.patch('mobly.controllers.android_device_lib.callback_handler.'
- 'CallbackHandler')
+ @mock.patch('mobly.controllers.android_device_lib.callback_handler_v2.'
+ 'CallbackHandlerV2')
def test_the_whole_lifecycle_with_multiple_rpcs(self, mock_callback_class,
mock_start_subprocess,
mock_stop_standing_subprocess,
@@ -311,12 +314,18 @@
event_client=event_client,
ret_value=456,
method_name='some_async_rpc',
- ad=self.device),
- mock.call(callback_id='2-0',
- event_client=event_client,
- ret_value=321,
- method_name='some_async_rpc',
- ad=self.device),
+ device=self.device,
+ rpc_max_timeout_sec=snippet_client_v2._SOCKET_READ_TIMEOUT,
+ default_timeout_sec=(
+ snippet_client_v2._CALLBACK_DEFAULT_TIMEOUT_SEC)),
+ mock.call(
+ callback_id='2-0',
+ event_client=event_client,
+ ret_value=321,
+ method_name='some_async_rpc',
+ device=self.device,
+ rpc_max_timeout_sec=snippet_client_v2._SOCKET_READ_TIMEOUT,
+ default_timeout_sec=snippet_client_v2._CALLBACK_DEFAULT_TIMEOUT_SEC)
]
self.assertListEqual(rpc_results, rpc_results_expected)
mock_callback_class.assert_has_calls(mock_callback_class_calls_expected)
@@ -821,8 +830,8 @@
@mock.patch('socket.create_connection')
@mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
'utils.start_standing_subprocess')
- @mock.patch('mobly.controllers.android_device_lib.callback_handler.'
- 'CallbackHandler')
+ @mock.patch('mobly.controllers.android_device_lib.callback_handler_v2.'
+ 'CallbackHandlerV2')
def test_async_rpc_start_event_client(self, mock_callback_class,
mock_start_subprocess,
mock_socket_create_conn):
@@ -855,7 +864,9 @@
event_client=self.client._event_client,
ret_value=123,
method_name='some_async_rpc',
- ad=self.device)
+ device=self.device,
+ rpc_max_timeout_sec=snippet_client_v2._SOCKET_READ_TIMEOUT,
+ default_timeout_sec=snippet_client_v2._CALLBACK_DEFAULT_TIMEOUT_SEC)
self.assertIs(rpc_result, mock_callback_class.return_value)
# Ensure the event client is alive
diff --git a/tests/mobly/snippet/callback_event_test.py b/tests/mobly/snippet/callback_event_test.py
new file mode 100755
index 0000000..2593cc3
--- /dev/null
+++ b/tests/mobly/snippet/callback_event_test.py
@@ -0,0 +1,40 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for mobly.snippet.callback_event.CallbackEvent."""
+
+import unittest
+
+from mobly.snippet import callback_event
+
+MOCK_CALLBACK_ID = 'myCallbackId'
+MOCK_EVENT_NAME = 'onXyzEvent'
+MOCK_CREATION_TIME = '12345678'
+MOCK_DATA = {'foo': 'bar'}
+
+
+class CallbackEventTest(unittest.TestCase):
+ """Unit tests for mobly.snippet.callback_event.CallbackEvent."""
+
+ def test_basic(self):
+ """Verifies that an event object can be created and logged properly."""
+ event = callback_event.CallbackEvent(MOCK_CALLBACK_ID, MOCK_EVENT_NAME,
+ MOCK_CREATION_TIME, MOCK_DATA)
+ self.assertEqual(
+ repr(event),
+ "CallbackEvent(callback_id: myCallbackId, name: onXyzEvent, "
+ "creation_time: 12345678, data: {'foo': 'bar'})")
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/mobly/snippet/callback_handler_base_test.py b/tests/mobly/snippet/callback_handler_base_test.py
new file mode 100644
index 0000000..cda26ef
--- /dev/null
+++ b/tests/mobly/snippet/callback_handler_base_test.py
@@ -0,0 +1,183 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for mobly.snippet.callback_handler_base.CallbackHandlerBase."""
+
+import unittest
+from unittest import mock
+
+from mobly.snippet import callback_event
+from mobly.snippet import callback_handler_base
+from mobly.snippet import errors
+
+MOCK_CALLBACK_ID = '2-1'
+MOCK_RAW_EVENT = {
+ 'callbackId': '2-1',
+ 'name': 'AsyncTaskResult',
+ 'time': 20460228696,
+ 'data': {
+ 'exampleData': "Here's a simple event.",
+ 'successful': True,
+ 'secretNumber': 12
+ }
+}
+
+
+class FakeCallbackHandler(callback_handler_base.CallbackHandlerBase):
+ """Fake client class for unit tests."""
+
+ def __init__(self,
+ callback_id=None,
+ event_client=None,
+ ret_value=None,
+ method_name=None,
+ device=None,
+ rpc_max_timeout_sec=120,
+ default_timeout_sec=120):
+ """Initializes a fake callback handler object used for unit tests."""
+ super().__init__(callback_id, event_client, ret_value, method_name, device,
+ rpc_max_timeout_sec, default_timeout_sec)
+ self.mock_rpc_func = mock.Mock()
+
+ def callEventWaitAndGetRpc(self, *args, **kwargs):
+ """See base class."""
+ return self.mock_rpc_func.callEventWaitAndGetRpc(*args, **kwargs)
+
+ def callEventGetAllRpc(self, *args, **kwargs):
+ """See base class."""
+ return self.mock_rpc_func.callEventGetAllRpc(*args, **kwargs)
+
+
+class CallbackHandlerBaseTest(unittest.TestCase):
+ """Unit tests for mobly.snippet.callback_handler_base.CallbackHandlerBase."""
+
+ def assert_event_correct(self, actual_event, expected_raw_event_dict):
+ expected_event = callback_event.from_dict(expected_raw_event_dict)
+ self.assertEqual(str(actual_event), str(expected_event))
+
+ def test_default_timeout_too_large(self):
+ err_msg = ('The max timeout of a single RPC must be no smaller than '
+ 'the default timeout of the callback handler. '
+ 'Got rpc_max_timeout_sec=10, default_timeout_sec=20.')
+ with self.assertRaisesRegex(ValueError, err_msg):
+ _ = FakeCallbackHandler(rpc_max_timeout_sec=10, default_timeout_sec=20)
+
+ def test_timeout_property(self):
+ handler = FakeCallbackHandler(rpc_max_timeout_sec=20,
+ default_timeout_sec=10)
+ self.assertEqual(handler.rpc_max_timeout_sec, 20)
+ self.assertEqual(handler.default_timeout_sec, 10)
+ with self.assertRaisesRegex(AttributeError, "can't set attribute"):
+ handler.rpc_max_timeout_sec = 5
+
+ with self.assertRaisesRegex(AttributeError, "can't set attribute"):
+ handler.default_timeout_sec = 5
+
+ def test_callback_id_property(self):
+ handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
+ self.assertEqual(handler.callback_id, MOCK_CALLBACK_ID)
+ with self.assertRaisesRegex(AttributeError, "can't set attribute"):
+ handler.callback_id = 'ha'
+
+ def test_event_dict_to_snippet_event(self):
+ handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
+ handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
+ return_value=MOCK_RAW_EVENT)
+
+ event = handler.waitAndGet('ha', timeout=10)
+ self.assert_event_correct(event, MOCK_RAW_EVENT)
+ handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
+ MOCK_CALLBACK_ID, 'ha', 10)
+
+ def test_wait_and_get_timeout_default(self):
+ handler = FakeCallbackHandler(rpc_max_timeout_sec=20, default_timeout_sec=5)
+ handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
+ return_value=MOCK_RAW_EVENT)
+ _ = handler.waitAndGet('ha')
+ handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
+ mock.ANY, mock.ANY, 5)
+
+ def test_wait_and_get_timeout_ecxeed_threshold(self):
+ rpc_max_timeout_sec = 5
+ big_timeout_sec = 10
+ handler = FakeCallbackHandler(rpc_max_timeout_sec=rpc_max_timeout_sec,
+ default_timeout_sec=rpc_max_timeout_sec)
+ handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
+ return_value=MOCK_RAW_EVENT)
+
+ expected_msg = (
+ f'Specified timeout {big_timeout_sec} is longer than max timeout '
+ f'{rpc_max_timeout_sec}.')
+ with self.assertRaisesRegex(errors.CallbackHandlerBaseError, expected_msg):
+ handler.waitAndGet('ha', big_timeout_sec)
+
+ def test_wait_for_event(self):
+ handler = FakeCallbackHandler()
+ handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
+ return_value=MOCK_RAW_EVENT)
+
+ def some_condition(event):
+ return event.data['successful']
+
+ event = handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)
+ self.assert_event_correct(event, MOCK_RAW_EVENT)
+
+ def test_wait_for_event_negative(self):
+ handler = FakeCallbackHandler()
+ handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
+ return_value=MOCK_RAW_EVENT)
+
+ expected_msg = (
+ 'Timed out after 0.01s waiting for an "AsyncTaskResult" event that'
+ ' satisfies the predicate "some_condition".')
+
+ def some_condition(_):
+ return False
+
+ with self.assertRaisesRegex(errors.CallbackHandlerTimeoutError,
+ expected_msg):
+ handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)
+
+ def test_wait_for_event_max_timeout(self):
+ """waitForEvent should not raise the timeout exceed threshold error."""
+ rpc_max_timeout_sec = 5
+ big_timeout_sec = 10
+ handler = FakeCallbackHandler(rpc_max_timeout_sec=rpc_max_timeout_sec,
+ default_timeout_sec=rpc_max_timeout_sec)
+ handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
+ return_value=MOCK_RAW_EVENT)
+
+ def some_condition(event):
+ return event.data['successful']
+
+ # This line should not raise.
+ event = handler.waitForEvent('AsyncTaskResult',
+ some_condition,
+ timeout=big_timeout_sec)
+ self.assert_event_correct(event, MOCK_RAW_EVENT)
+
+ def test_get_all(self):
+ handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
+ handler.mock_rpc_func.callEventGetAllRpc = mock.Mock(
+ return_value=[MOCK_RAW_EVENT, MOCK_RAW_EVENT])
+
+ all_events = handler.getAll('ha')
+ for event in all_events:
+ self.assert_event_correct(event, MOCK_RAW_EVENT)
+
+ handler.mock_rpc_func.callEventGetAllRpc.assert_called_once_with(
+ MOCK_CALLBACK_ID, 'ha')
+
+
+if __name__ == '__main__':
+ unittest.main()