| # Copyright 2022 Google Inc. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Unit tests for mobly.snippet.callback_handler_base.CallbackHandlerBase.""" |
| |
| import unittest |
| from unittest import mock |
| |
| from mobly.snippet import callback_event |
| from mobly.snippet import callback_handler_base |
| from mobly.snippet import errors |
| |
# Callback ID shared by the fake handler and the raw event fixture below.
MOCK_CALLBACK_ID = '2-1'

# A raw event dict as returned by the mocked event RPCs. Its 'callbackId'
# reuses MOCK_CALLBACK_ID so the two fixtures cannot drift apart.
MOCK_RAW_EVENT = {
    'callbackId': MOCK_CALLBACK_ID,
    'name': 'AsyncTaskResult',
    'time': 20460228696,
    'data': {
        'exampleData': "Here's a simple event.",
        'successful': True,
        'secretNumber': 12,
    },
}
| |
| |
class FakeCallbackHandler(callback_handler_base.CallbackHandlerBase):
  """Fake callback handler whose event RPCs are served by a mock, for tests.

  Tests control what the handler's RPCs return by setting attributes on
  `mock_rpc_func`, e.g. `mock_rpc_func.callEventWaitAndGetRpc` or
  `mock_rpc_func.callEventGetAllRpc`.
  """

  def __init__(self,
               callback_id=None,
               event_client=None,
               ret_value=None,
               method_name=None,
               device=None,
               rpc_max_timeout_sec=120,
               default_timeout_sec=120):
    """Builds the handler and attaches a fresh mock to back its RPCs.

    Every argument is forwarded positionally, in declaration order, to
    CallbackHandlerBase.__init__.
    """
    base_init_args = (callback_id, event_client, ret_value, method_name,
                      device, rpc_max_timeout_sec, default_timeout_sec)
    super().__init__(*base_init_args)
    self.mock_rpc_func = mock.Mock()

  def callEventWaitAndGetRpc(self, *args, **kwargs):
    """See base class; forwards to `mock_rpc_func.callEventWaitAndGetRpc`."""
    wait_and_get = self.mock_rpc_func.callEventWaitAndGetRpc
    return wait_and_get(*args, **kwargs)

  def callEventGetAllRpc(self, *args, **kwargs):
    """See base class; forwards to `mock_rpc_func.callEventGetAllRpc`."""
    get_all = self.mock_rpc_func.callEventGetAllRpc
    return get_all(*args, **kwargs)
| |
| |
class CallbackHandlerBaseTest(unittest.TestCase):
  """Unit tests for mobly.snippet.callback_handler_base.CallbackHandlerBase."""

  def assert_event_correct(self, actual_event, expected_raw_event_dict):
    """Asserts `actual_event` equals the event built from a raw event dict.

    The two events are compared via their string representations.
    """
    expected_event = callback_event.from_dict(expected_raw_event_dict)
    self.assertEqual(str(actual_event), str(expected_event))

  def test_default_timeout_too_large(self):
    """A default timeout above the per-RPC max timeout is rejected."""
    err_msg = ('The max timeout of a single RPC must be no smaller than '
               'the default timeout of the callback handler. '
               'Got rpc_max_timeout_sec=10, default_timeout_sec=20.')
    with self.assertRaisesRegex(ValueError, err_msg):
      _ = FakeCallbackHandler(rpc_max_timeout_sec=10, default_timeout_sec=20)

  def test_timeout_property(self):
    """Timeout properties expose constructor values and are read-only."""
    handler = FakeCallbackHandler(rpc_max_timeout_sec=20,
                                  default_timeout_sec=10)
    self.assertEqual(handler.rpc_max_timeout_sec, 20)
    self.assertEqual(handler.default_timeout_sec, 10)
    # The AttributeError message for assigning to a setter-less property
    # changed in Python 3.11 ("can't set attribute" -> "... has no setter"),
    # so only the exception type is checked.
    with self.assertRaises(AttributeError):
      handler.rpc_max_timeout_sec = 5

    with self.assertRaises(AttributeError):
      handler.default_timeout_sec = 5

  def test_callback_id_property(self):
    """callback_id exposes the constructor value and is read-only."""
    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
    self.assertEqual(handler.callback_id, MOCK_CALLBACK_ID)
    # See test_timeout_property for why the message is not matched.
    with self.assertRaises(AttributeError):
      handler.callback_id = 'ha'

  def test_event_dict_to_snippet_event(self):
    """waitAndGet converts the raw RPC result into an event object."""
    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT)

    event = handler.waitAndGet('ha', timeout=10)
    self.assert_event_correct(event, MOCK_RAW_EVENT)
    handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
        MOCK_CALLBACK_ID, 'ha', 10)

  def test_wait_and_get_timeout_default(self):
    """waitAndGet without a timeout uses the handler's default timeout."""
    handler = FakeCallbackHandler(rpc_max_timeout_sec=20, default_timeout_sec=5)
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT)
    _ = handler.waitAndGet('ha')
    handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
        mock.ANY, mock.ANY, 5)

  def test_wait_and_get_timeout_exceed_threshold(self):
    """waitAndGet rejects a timeout larger than the per-RPC max timeout."""
    rpc_max_timeout_sec = 5
    big_timeout_sec = 10
    handler = FakeCallbackHandler(rpc_max_timeout_sec=rpc_max_timeout_sec,
                                  default_timeout_sec=rpc_max_timeout_sec)
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT)

    expected_msg = (
        f'Specified timeout {big_timeout_sec} is longer than max timeout '
        f'{rpc_max_timeout_sec}.')
    with self.assertRaisesRegex(errors.CallbackHandlerBaseError, expected_msg):
      handler.waitAndGet('ha', big_timeout_sec)

  def test_wait_for_event(self):
    """waitForEvent returns the first event satisfying the predicate."""
    handler = FakeCallbackHandler()
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT)

    def some_condition(event):
      return event.data['successful']

    event = handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)
    self.assert_event_correct(event, MOCK_RAW_EVENT)

  def test_wait_for_event_negative(self):
    """waitForEvent times out when no event satisfies the predicate."""
    handler = FakeCallbackHandler()
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT)

    expected_msg = (
        'Timed out after 0.01s waiting for an "AsyncTaskResult" event that'
        ' satisfies the predicate "some_condition".')

    def some_condition(_):
      return False

    with self.assertRaisesRegex(errors.CallbackHandlerTimeoutError,
                                expected_msg):
      handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)

  def test_wait_for_event_max_timeout(self):
    """waitForEvent should not raise the timeout exceed threshold error."""
    rpc_max_timeout_sec = 5
    big_timeout_sec = 10
    handler = FakeCallbackHandler(rpc_max_timeout_sec=rpc_max_timeout_sec,
                                  default_timeout_sec=rpc_max_timeout_sec)
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT)

    def some_condition(event):
      return event.data['successful']

    # This line should not raise.
    event = handler.waitForEvent('AsyncTaskResult',
                                 some_condition,
                                 timeout=big_timeout_sec)
    self.assert_event_correct(event, MOCK_RAW_EVENT)

  def test_get_all(self):
    """getAll converts every raw event returned by the RPC."""
    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
    handler.mock_rpc_func.callEventGetAllRpc = mock.Mock(
        return_value=[MOCK_RAW_EVENT, MOCK_RAW_EVENT])

    all_events = handler.getAll('ha')
    # Guard against the loop below passing vacuously on an empty result.
    self.assertEqual(len(all_events), 2)
    for event in all_events:
      self.assert_event_correct(event, MOCK_RAW_EVENT)

    handler.mock_rpc_func.callEventGetAllRpc.assert_called_once_with(
        MOCK_CALLBACK_ID, 'ha')
| |
| |
if __name__ == '__main__':
  # Running this file directly executes the test cases defined above.
  unittest.main(module='__main__')