Standardize Async RPC for snippet client v2 (#826)

* Provide a simple abstract base class for callback handlers of all platforms.
* Implement the callback handler v2 for Android. 
diff --git a/mobly/controllers/android_device_lib/callback_handler_v2.py b/mobly/controllers/android_device_lib/callback_handler_v2.py
new file mode 100644
index 0000000..5675f7a
--- /dev/null
+++ b/mobly/controllers/android_device_lib/callback_handler_v2.py
@@ -0,0 +1,67 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The callback handler V2 module for Android Mobly Snippet Lib."""
+
+from mobly.snippet import callback_handler_base
+from mobly.snippet import errors
+
+# The timeout error meesage when pulling events from the server
+TIMEOUT_ERROR_MESSAGE = 'EventSnippetException: timeout.'
+
+
+class CallbackHandlerV2(callback_handler_base.CallbackHandlerBase):
+  """The callback handler V2 class for Android Mobly Snippet Lib."""
+
+  def callEventWaitAndGetRpc(self, callback_id, event_name, timeout_sec):
+    """Waits and returns an existing CallbackEvent for the specified identifier.
+
+    This function calls snippet lib's eventWaitAndGet RPC.
+
+    Args:
+      callback_id: str, the callback identifier.
+      event_name: str, the callback name.
+      timeout_sec: float, the number of seconds to wait for the event.
+
+    Returns:
+      The event dictionary.
+
+    Raises:
+      errors.CallbackHandlerTimeoutError: The expected event does not occur
+        within the time limit.
+    """
+    timeout_ms = int(timeout_sec * 1000)
+    try:
+      return self._event_client.eventWaitAndGet(callback_id, event_name,
+                                                timeout_ms)
+    except Exception as e:
+      if TIMEOUT_ERROR_MESSAGE in str(e):
+        raise errors.CallbackHandlerTimeoutError(
+            self._device, (f'Timed out after waiting {timeout_sec}s for event '
+                           f'"{event_name}" triggered by {self._method_name} '
+                           f'({self.callback_id}).')) from e
+      raise
+
+  def callEventGetAllRpc(self, callback_id, event_name):
+    """Gets all existing events for the specified identifier without waiting.
+
+    This function calls snippet lib's eventGetAll RPC.
+
+    Args:
+      callback_id: str, the callback identifier.
+      event_name: str, the callback name.
+
+    Returns:
+      A list of event dictionaries.
+    """
+    return self._event_client.eventGetAll(callback_id, event_name)
diff --git a/mobly/controllers/android_device_lib/snippet_client_v2.py b/mobly/controllers/android_device_lib/snippet_client_v2.py
index 0c9540a..3adfde5 100644
--- a/mobly/controllers/android_device_lib/snippet_client_v2.py
+++ b/mobly/controllers/android_device_lib/snippet_client_v2.py
@@ -20,7 +20,7 @@
 
 from mobly import utils
 from mobly.controllers.android_device_lib import adb
-from mobly.controllers.android_device_lib import callback_handler
+from mobly.controllers.android_device_lib import callback_handler_v2
 from mobly.controllers.android_device_lib import errors as android_device_lib_errors
 from mobly.snippet import client_base
 from mobly.snippet import errors
@@ -70,7 +70,10 @@
 _SOCKET_CONNECTION_TIMEOUT = 60
 
 # Maximum time to wait for a response message on the socket.
-_SOCKET_READ_TIMEOUT = callback_handler.MAX_TIMEOUT
+_SOCKET_READ_TIMEOUT = 60 * 10
+
+# The default timeout for callback handlers returned by this client
+_CALLBACK_DEFAULT_TIMEOUT_SEC = 60 * 2
 
 
 class ConnectionHandshakeCommand(enum.Enum):
@@ -505,11 +508,14 @@
     """
     if self._event_client is None:
       self._create_event_client()
-    return callback_handler.CallbackHandler(callback_id=callback_id,
-                                            event_client=self._event_client,
-                                            ret_value=ret_value,
-                                            method_name=rpc_func_name,
-                                            ad=self._device)
+    return callback_handler_v2.CallbackHandlerV2(
+        callback_id=callback_id,
+        event_client=self._event_client,
+        ret_value=ret_value,
+        method_name=rpc_func_name,
+        device=self._device,
+        rpc_max_timeout_sec=_SOCKET_READ_TIMEOUT,
+        default_timeout_sec=_CALLBACK_DEFAULT_TIMEOUT_SEC)
 
   def _create_event_client(self):
     """Creates a separate client to the same session for propagating events.
diff --git a/mobly/snippet/callback_event.py b/mobly/snippet/callback_event.py
new file mode 100644
index 0000000..55471c7
--- /dev/null
+++ b/mobly/snippet/callback_event.py
@@ -0,0 +1,52 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The class that represents callback events for Mobly Snippet Lib."""
+
+
+def from_dict(event_dict):
+  """Creates a CallbackEvent object from a dictionary.
+
+  Args:
+    event_dict: dict, a dictionary representing an event.
+
+  Returns:
+    A CallbackEvent object.
+  """
+  return CallbackEvent(callback_id=event_dict['callbackId'],
+                       name=event_dict['name'],
+                       creation_time=event_dict['time'],
+                       data=event_dict['data'])
+
+
+class CallbackEvent:
+  """The class that represents callback events for Mobly Snippet Library.
+
+  Attributes:
+    callback_id: str, the callback ID associated with the event.
+    name: str, the name of the event.
+    creation_time: int, the epoch time when the event is created on the
+      RPC server side.
+    data: dict, the data held by the event. Can be None.
+  """
+
+  def __init__(self, callback_id, name, creation_time, data):
+    self.callback_id = callback_id
+    self.name = name
+    self.creation_time = creation_time
+    self.data = data
+
+  def __repr__(self):
+    return (
+        f'CallbackEvent(callback_id: {self.callback_id}, name: {self.name}, '
+        f'creation_time: {self.creation_time}, data: {self.data})')
diff --git a/mobly/snippet/callback_handler_base.py b/mobly/snippet/callback_handler_base.py
new file mode 100644
index 0000000..50465d1
--- /dev/null
+++ b/mobly/snippet/callback_handler_base.py
@@ -0,0 +1,240 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for the base class to handle Mobly Snippet Lib's callback events."""
+import abc
+import time
+
+from mobly.snippet import callback_event
+from mobly.snippet import errors
+
+
+class CallbackHandlerBase(abc.ABC):
+  """Base class for handling Mobly Snippet Lib's callback events.
+
+  All the events handled by a callback handler are originally triggered by one
+  async RPC call. All the events are tagged with a callback_id specific to a
+  call to an async RPC method defined on the server side.
+
+  The raw message representing an event looks like:
+
+  .. code-block:: python
+
+    {
+      'callbackId': <string, callbackId>,
+      'name': <string, name of the event>,
+      'time': <long, epoch time of when the event was created on the
+        server side>,
+      'data': <dict, extra data from the callback on the server side>
+    }
+
+  Each message is then used to create a CallbackEvent object on the client
+  side.
+
+  Attributes:
+    ret_value: any, the direct return value of the async RPC call.
+  """
+
+  def __init__(self,
+               callback_id,
+               event_client,
+               ret_value,
+               method_name,
+               device,
+               rpc_max_timeout_sec,
+               default_timeout_sec=120):
+    """Initializes a callback handler base object.
+
+    Args:
+      callback_id: str, the callback ID which associates with a group of
+        callback events.
+      event_client: SnippetClientV2, the client object used to send RPC to the
+        server and receive response.
+      ret_value: any, the direct return value of the async RPC call.
+      method_name: str, the name of the executed Async snippet function.
+      device: DeviceController, the device object associated with this handler.
+      rpc_max_timeout_sec: float, maximum time for sending a single RPC call.
+      default_timeout_sec: float, the default timeout for this handler. It
+        must be no longer than rpc_max_timeout_sec.
+    """
+    self._id = callback_id
+    self.ret_value = ret_value
+    self._device = device
+    self._event_client = event_client
+    self._method_name = method_name
+
+    if rpc_max_timeout_sec < default_timeout_sec:
+      raise ValueError('The max timeout of a single RPC must be no smaller '
+                       'than the default timeout of the callback handler. '
+                       f'Got rpc_max_timeout_sec={rpc_max_timeout_sec}, '
+                       f'default_timeout_sec={default_timeout_sec}.')
+    self._rpc_max_timeout_sec = rpc_max_timeout_sec
+    self._default_timeout_sec = default_timeout_sec
+
+  @property
+  def rpc_max_timeout_sec(self):
+    """Maximum time for sending a single RPC call."""
+    return self._rpc_max_timeout_sec
+
+  @property
+  def default_timeout_sec(self):
+    """Default timeout used by this callback handler."""
+    return self._default_timeout_sec
+
+  @property
+  def callback_id(self):
+    """The callback ID which associates a group of callback events."""
+    return self._id
+
+  @abc.abstractmethod
+  def callEventWaitAndGetRpc(self, callback_id, event_name, timeout_sec):
+    """Calls snippet lib's RPC to wait for a callback event.
+
+    Override this method to use this class with various snippet lib
+    implementations.
+
+    This function waits and gets a CallbackEvent with the specified identifier
+    from the server. It will raise a timeout error if the expected event does
+    not occur within the time limit.
+
+    Args:
+      callback_id: str, the callback identifier.
+      event_name: str, the callback name.
+      timeout_sec: float, the number of seconds to wait for the event. It is
+        already checked that this argument is no longer than the max timeout
+        of a single RPC.
+
+    Returns:
+      The event dictionary.
+
+    Raises:
+      errors.CallbackHandlerTimeoutError: Raised if the expected event does not
+        occur within the time limit.
+    """
+
+  @abc.abstractmethod
+  def callEventGetAllRpc(self, callback_id, event_name):
+    """Calls snippet lib's RPC to get all existing snippet events.
+
+    Override this method to use this class with various snippet lib
+    implementations.
+
+    This function gets all existing events in the server with the specified
+    identifier without waiting.
+
+    Args:
+      callback_id: str, the callback identifier.
+      event_name: str, the callback name.
+
+    Returns:
+      A list of event dictionaries.
+    """
+
+  def waitAndGet(self, event_name, timeout=None):
+    """Waits and gets a CallbackEvent with the specified identifier.
+
+    It will raise a timeout error if the expected event does not occur within
+    the time limit.
+
+    Args:
+      event_name: str, the name of the event to get.
+      timeout: float, the number of seconds to wait before giving up. If None,
+        it will be set to self.default_timeout_sec.
+
+    Returns:
+      CallbackEvent, the oldest entry of the specified event.
+
+    Raises:
+      errors.CallbackHandlerBaseError: If the specified timeout is longer than
+        the max timeout supported.
+      errors.CallbackHandlerTimeoutError: The expected event does not occur
+        within the time limit.
+    """
+    if timeout is None:
+      timeout = self.default_timeout_sec
+
+    if timeout:
+      if timeout > self.rpc_max_timeout_sec:
+        raise errors.CallbackHandlerBaseError(
+            self._device,
+            f'Specified timeout {timeout} is longer than max timeout '
+            f'{self.rpc_max_timeout_sec}.')
+
+    raw_event = self.callEventWaitAndGetRpc(self._id, event_name, timeout)
+    return callback_event.from_dict(raw_event)
+
+  def waitForEvent(self, event_name, predicate, timeout=None):
+    """Waits for an event of the specific name that satisfies the predicate.
+
+    This call will block until the expected event has been received or time
+    out.
+
+    The predicate function defines the condition the event is expected to
+    satisfy. It takes an event and returns True if the condition is
+    satisfied, False otherwise.
+
+    Note all events of the same name that are received but don't satisfy
+    the predicate will be discarded and not be available for further
+    consumption.
+
+    Args:
+      event_name: str, the name of the event to wait for.
+      predicate: function, a function that takes an event (dictionary) and
+        returns a bool.
+      timeout: float, the number of seconds to wait before giving up. If None,
+        it will be set to self.default_timeout_sec.
+
+    Returns:
+      dictionary, the event that satisfies the predicate if received.
+
+    Raises:
+      errors.CallbackHandlerTimeoutError: raised if no event that satisfies the
+        predicate is received after timeout seconds.
+    """
+    if timeout is None:
+      timeout = self.default_timeout_sec
+
+    deadline = time.perf_counter() + timeout
+    while time.perf_counter() <= deadline:
+      single_rpc_timeout = deadline - time.perf_counter()
+      if single_rpc_timeout < 0:
+        break
+
+      single_rpc_timeout = min(single_rpc_timeout, self.rpc_max_timeout_sec)
+      try:
+        event = self.waitAndGet(event_name, single_rpc_timeout)
+      except errors.CallbackHandlerTimeoutError:
+        # Ignoring errors.CallbackHandlerTimeoutError since we need to throw
+        # one with a more specific message.
+        break
+      if predicate(event):
+        return event
+
+    raise errors.CallbackHandlerTimeoutError(
+        self._device,
+        f'Timed out after {timeout}s waiting for an "{event_name}" event that '
+        f'satisfies the predicate "{predicate.__name__}".')
+
+  def getAll(self, event_name):
+    """Gets all existing events in the server with the specified identifier.
+
+    This is a non-blocking call.
+
+    Args:
+      event_name: str, the name of the event to get.
+
+    Returns:
+      A list of CallbackEvent, each representing an event from the Server side.
+    """
+    raw_events = self.callEventGetAllRpc(self._id, event_name)
+    return [callback_event.from_dict(msg) for msg in raw_events]
diff --git a/mobly/snippet/errors.py b/mobly/snippet/errors.py
index 764aea4..4d41adb 100644
--- a/mobly/snippet/errors.py
+++ b/mobly/snippet/errors.py
@@ -60,3 +60,12 @@
 
 class ServerDiedError(Error):
   """Raised if the snippet server died before all tests finish."""
+
+
+# Error types for callback handlers
+class CallbackHandlerBaseError(errors.DeviceError):
+  """Base error type for snippet clients."""
+
+
+class CallbackHandlerTimeoutError(Error):
+  """Raised if the expected event does not occur within the time limit."""
diff --git a/tests/mobly/controllers/android_device_lib/callback_handler_v2_test.py b/tests/mobly/controllers/android_device_lib/callback_handler_v2_test.py
new file mode 100644
index 0000000..b598cae
--- /dev/null
+++ b/tests/mobly/controllers/android_device_lib/callback_handler_v2_test.py
@@ -0,0 +1,152 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for callback_handler_v2.CallbackHandlerV2."""
+
+import unittest
+from unittest import mock
+
+from mobly.controllers.android_device_lib import callback_handler_v2
+from mobly.snippet import callback_event
+from mobly.snippet import errors
+
+MOCK_CALLBACK_ID = '2-1'
+MOCK_RAW_EVENT = {
+    'callbackId': '2-1',
+    'name': 'AsyncTaskResult',
+    'time': 20460228696,
+    'data': {
+        'exampleData': "Here's a simple event.",
+        'successful': True,
+        'secretNumber': 12
+    }
+}
+
+
+class CallbackHandlerV2Test(unittest.TestCase):
+  """Unit tests for callback_handler_v2.CallbackHandlerV2."""
+
+  def _make_callback_handler(self,
+                             callback_id=None,
+                             event_client=None,
+                             ret_value=None,
+                             method_name=None,
+                             device=None,
+                             rpc_max_timeout_sec=600,
+                             default_timeout_sec=120):
+    return callback_handler_v2.CallbackHandlerV2(
+        callback_id=callback_id,
+        event_client=event_client,
+        ret_value=ret_value,
+        method_name=method_name,
+        device=device,
+        rpc_max_timeout_sec=rpc_max_timeout_sec,
+        default_timeout_sec=default_timeout_sec)
+
+  def assert_event_correct(self, actual_event, expected_raw_event_dict):
+    expected_event = callback_event.from_dict(expected_raw_event_dict)
+    self.assertEqual(str(actual_event), str(expected_event))
+
+  def test_wait_and_get(self):
+    mock_event_client = mock.Mock()
+    mock_event_client.eventWaitAndGet = mock.Mock(return_value=MOCK_RAW_EVENT)
+    handler = self._make_callback_handler(callback_id=MOCK_CALLBACK_ID,
+                                          event_client=mock_event_client)
+    event = handler.waitAndGet('ha')
+    self.assert_event_correct(event, MOCK_RAW_EVENT)
+    mock_event_client.eventWaitAndGet.assert_called_once_with(
+        MOCK_CALLBACK_ID, 'ha', mock.ANY)
+
+  def test_wait_and_get_timeout_arg_transform(self):
+    mock_event_client = mock.Mock()
+    mock_event_client.eventWaitAndGet = mock.Mock(return_value=MOCK_RAW_EVENT)
+    handler = self._make_callback_handler(event_client=mock_event_client)
+
+    wait_and_get_timeout_sec = 10
+    expected_rpc_timeout_ms = 10000
+    _ = handler.waitAndGet('ha', timeout=wait_and_get_timeout_sec)
+    mock_event_client.eventWaitAndGet.assert_called_once_with(
+        mock.ANY, mock.ANY, expected_rpc_timeout_ms)
+
+  def test_wait_for_event(self):
+    mock_event_client = mock.Mock()
+    handler = self._make_callback_handler(callback_id=MOCK_CALLBACK_ID,
+                                          event_client=mock_event_client)
+
+    event_should_ignore = {
+        'callbackId': '2-1',
+        'name': 'AsyncTaskResult',
+        'time': 20460228696,
+        'data': {
+            'successful': False,
+        }
+    }
+    mock_event_client.eventWaitAndGet.side_effect = [
+        event_should_ignore, MOCK_RAW_EVENT
+    ]
+
+    def some_condition(event):
+      return event.data['successful']
+
+    event = handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)
+    self.assert_event_correct(event, MOCK_RAW_EVENT)
+    mock_event_client.eventWaitAndGet.assert_has_calls([
+        mock.call(MOCK_CALLBACK_ID, 'AsyncTaskResult', mock.ANY),
+        mock.call(MOCK_CALLBACK_ID, 'AsyncTaskResult', mock.ANY),
+    ])
+
+  def test_get_all(self):
+    mock_event_client = mock.Mock()
+    handler = self._make_callback_handler(callback_id=MOCK_CALLBACK_ID,
+                                          event_client=mock_event_client)
+
+    mock_event_client.eventGetAll = mock.Mock(
+        return_value=[MOCK_RAW_EVENT, MOCK_RAW_EVENT])
+
+    all_events = handler.getAll('ha')
+    self.assertEqual(len(all_events), 2)
+    for event in all_events:
+      self.assert_event_correct(event, MOCK_RAW_EVENT)
+
+    mock_event_client.eventGetAll.assert_called_once_with(
+        MOCK_CALLBACK_ID, 'ha')
+
+  def test_wait_and_get_timeout_message_pattern_matches(self):
+    mock_event_client = mock.Mock()
+    android_snippet_timeout_msg = (
+        'com.google.android.mobly.snippet.event.EventSnippet$'
+        'EventSnippetException: timeout.')
+    mock_event_client.eventWaitAndGet = mock.Mock(
+        side_effect=errors.ApiError(mock.Mock(), android_snippet_timeout_msg))
+    handler = self._make_callback_handler(event_client=mock_event_client,
+                                          method_name='test_method')
+
+    expected_msg = ('Timed out after waiting .*s for event "ha" triggered by '
+                    'test_method .*')
+    with self.assertRaisesRegex(errors.CallbackHandlerTimeoutError,
+                                expected_msg):
+      handler.waitAndGet('ha')
+
+  def test_wait_and_get_reraise_if_pattern_not_match(self):
+    mock_event_client = mock.Mock()
+    snippet_timeout_msg = 'Snippet executed with error.'
+    mock_event_client.eventWaitAndGet = mock.Mock(
+        side_effect=errors.ApiError(mock.Mock(), snippet_timeout_msg))
+    handler = self._make_callback_handler(event_client=mock_event_client)
+
+    with self.assertRaisesRegex(errors.ApiError, snippet_timeout_msg):
+      handler.waitAndGet('ha')
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/tests/mobly/controllers/android_device_lib/snippet_client_v2_test.py b/tests/mobly/controllers/android_device_lib/snippet_client_v2_test.py
index c7e8ef1..1943abb 100644
--- a/tests/mobly/controllers/android_device_lib/snippet_client_v2_test.py
+++ b/tests/mobly/controllers/android_device_lib/snippet_client_v2_test.py
@@ -209,8 +209,8 @@
   @mock.patch('mobly.utils.stop_standing_subprocess')
   @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
               'utils.start_standing_subprocess')
-  @mock.patch('mobly.controllers.android_device_lib.callback_handler.'
-              'CallbackHandler')
+  @mock.patch('mobly.controllers.android_device_lib.callback_handler_v2.'
+              'CallbackHandlerV2')
   def test_the_whole_lifecycle_with_an_async_rpc(self, mock_callback_class,
                                                  mock_start_subprocess,
                                                  mock_stop_standing_subprocess,
@@ -244,11 +244,14 @@
 
     self.assertListEqual(self.mock_socket_file.write.call_args_list,
                          expected_socket_writes)
-    mock_callback_class.assert_called_with(callback_id='1-0',
-                                           event_client=event_client,
-                                           ret_value=123,
-                                           method_name='some_async_rpc',
-                                           ad=self.device)
+    mock_callback_class.assert_called_with(
+        callback_id='1-0',
+        event_client=event_client,
+        ret_value=123,
+        method_name='some_async_rpc',
+        device=self.device,
+        rpc_max_timeout_sec=snippet_client_v2._SOCKET_READ_TIMEOUT,
+        default_timeout_sec=snippet_client_v2._CALLBACK_DEFAULT_TIMEOUT_SEC)
     self.assertIs(rpc_result, mock_callback_class.return_value)
     self.assertIsNone(event_client.host_port, None)
     self.assertIsNone(event_client.device_port, None)
@@ -260,8 +263,8 @@
   @mock.patch('mobly.utils.stop_standing_subprocess')
   @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
               'utils.start_standing_subprocess')
-  @mock.patch('mobly.controllers.android_device_lib.callback_handler.'
-              'CallbackHandler')
+  @mock.patch('mobly.controllers.android_device_lib.callback_handler_v2.'
+              'CallbackHandlerV2')
   def test_the_whole_lifecycle_with_multiple_rpcs(self, mock_callback_class,
                                                   mock_start_subprocess,
                                                   mock_stop_standing_subprocess,
@@ -311,12 +314,18 @@
                   event_client=event_client,
                   ret_value=456,
                   method_name='some_async_rpc',
-                  ad=self.device),
-        mock.call(callback_id='2-0',
-                  event_client=event_client,
-                  ret_value=321,
-                  method_name='some_async_rpc',
-                  ad=self.device),
+                  device=self.device,
+                  rpc_max_timeout_sec=snippet_client_v2._SOCKET_READ_TIMEOUT,
+                  default_timeout_sec=(
+                      snippet_client_v2._CALLBACK_DEFAULT_TIMEOUT_SEC)),
+        mock.call(
+            callback_id='2-0',
+            event_client=event_client,
+            ret_value=321,
+            method_name='some_async_rpc',
+            device=self.device,
+            rpc_max_timeout_sec=snippet_client_v2._SOCKET_READ_TIMEOUT,
+            default_timeout_sec=snippet_client_v2._CALLBACK_DEFAULT_TIMEOUT_SEC)
     ]
     self.assertListEqual(rpc_results, rpc_results_expected)
     mock_callback_class.assert_has_calls(mock_callback_class_calls_expected)
@@ -821,8 +830,8 @@
   @mock.patch('socket.create_connection')
   @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
               'utils.start_standing_subprocess')
-  @mock.patch('mobly.controllers.android_device_lib.callback_handler.'
-              'CallbackHandler')
+  @mock.patch('mobly.controllers.android_device_lib.callback_handler_v2.'
+              'CallbackHandlerV2')
   def test_async_rpc_start_event_client(self, mock_callback_class,
                                         mock_start_subprocess,
                                         mock_socket_create_conn):
@@ -855,7 +864,9 @@
         event_client=self.client._event_client,
         ret_value=123,
         method_name='some_async_rpc',
-        ad=self.device)
+        device=self.device,
+        rpc_max_timeout_sec=snippet_client_v2._SOCKET_READ_TIMEOUT,
+        default_timeout_sec=snippet_client_v2._CALLBACK_DEFAULT_TIMEOUT_SEC)
     self.assertIs(rpc_result, mock_callback_class.return_value)
 
     # Ensure the event client is alive
diff --git a/tests/mobly/snippet/callback_event_test.py b/tests/mobly/snippet/callback_event_test.py
new file mode 100755
index 0000000..2593cc3
--- /dev/null
+++ b/tests/mobly/snippet/callback_event_test.py
@@ -0,0 +1,40 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for mobly.snippet.callback_event.CallbackEvent."""
+
+import unittest
+
+from mobly.snippet import callback_event
+
+MOCK_CALLBACK_ID = 'myCallbackId'
+MOCK_EVENT_NAME = 'onXyzEvent'
+MOCK_CREATION_TIME = '12345678'
+MOCK_DATA = {'foo': 'bar'}
+
+
+class CallbackEventTest(unittest.TestCase):
+  """Unit tests for mobly.snippet.callback_event.CallbackEvent."""
+
+  def test_basic(self):
+    """Verifies that an event object can be created and logged properly."""
+    event = callback_event.CallbackEvent(MOCK_CALLBACK_ID, MOCK_EVENT_NAME,
+                                         MOCK_CREATION_TIME, MOCK_DATA)
+    self.assertEqual(
+        repr(event),
+        "CallbackEvent(callback_id: myCallbackId, name: onXyzEvent, "
+        "creation_time: 12345678, data: {'foo': 'bar'})")
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/tests/mobly/snippet/callback_handler_base_test.py b/tests/mobly/snippet/callback_handler_base_test.py
new file mode 100644
index 0000000..cda26ef
--- /dev/null
+++ b/tests/mobly/snippet/callback_handler_base_test.py
@@ -0,0 +1,183 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for mobly.snippet.callback_handler_base.CallbackHandlerBase."""
+
+import unittest
+from unittest import mock
+
+from mobly.snippet import callback_event
+from mobly.snippet import callback_handler_base
+from mobly.snippet import errors
+
+MOCK_CALLBACK_ID = '2-1'
+MOCK_RAW_EVENT = {
+    'callbackId': '2-1',
+    'name': 'AsyncTaskResult',
+    'time': 20460228696,
+    'data': {
+        'exampleData': "Here's a simple event.",
+        'successful': True,
+        'secretNumber': 12
+    }
+}
+
+
+class FakeCallbackHandler(callback_handler_base.CallbackHandlerBase):
+  """Fake client class for unit tests."""
+
+  def __init__(self,
+               callback_id=None,
+               event_client=None,
+               ret_value=None,
+               method_name=None,
+               device=None,
+               rpc_max_timeout_sec=120,
+               default_timeout_sec=120):
+    """Initializes a fake callback handler object used for unit tests."""
+    super().__init__(callback_id, event_client, ret_value, method_name, device,
+                     rpc_max_timeout_sec, default_timeout_sec)
+    self.mock_rpc_func = mock.Mock()
+
+  def callEventWaitAndGetRpc(self, *args, **kwargs):
+    """See base class."""
+    return self.mock_rpc_func.callEventWaitAndGetRpc(*args, **kwargs)
+
+  def callEventGetAllRpc(self, *args, **kwargs):
+    """See base class."""
+    return self.mock_rpc_func.callEventGetAllRpc(*args, **kwargs)
+
+
+class CallbackHandlerBaseTest(unittest.TestCase):
+  """Unit tests for mobly.snippet.callback_handler_base.CallbackHandlerBase."""
+
+  def assert_event_correct(self, actual_event, expected_raw_event_dict):
+    expected_event = callback_event.from_dict(expected_raw_event_dict)
+    self.assertEqual(str(actual_event), str(expected_event))
+
+  def test_default_timeout_too_large(self):
+    err_msg = ('The max timeout of a single RPC must be no smaller than '
+               'the default timeout of the callback handler. '
+               'Got rpc_max_timeout_sec=10, default_timeout_sec=20.')
+    with self.assertRaisesRegex(ValueError, err_msg):
+      _ = FakeCallbackHandler(rpc_max_timeout_sec=10, default_timeout_sec=20)
+
+  def test_timeout_property(self):
+    handler = FakeCallbackHandler(rpc_max_timeout_sec=20,
+                                  default_timeout_sec=10)
+    self.assertEqual(handler.rpc_max_timeout_sec, 20)
+    self.assertEqual(handler.default_timeout_sec, 10)
+    with self.assertRaisesRegex(AttributeError, "can't set attribute"):
+      handler.rpc_max_timeout_sec = 5
+
+    with self.assertRaisesRegex(AttributeError, "can't set attribute"):
+      handler.default_timeout_sec = 5
+
+  def test_callback_id_property(self):
+    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
+    self.assertEqual(handler.callback_id, MOCK_CALLBACK_ID)
+    with self.assertRaisesRegex(AttributeError, "can't set attribute"):
+      handler.callback_id = 'ha'
+
+  def test_event_dict_to_snippet_event(self):
+    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
+    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
+        return_value=MOCK_RAW_EVENT)
+
+    event = handler.waitAndGet('ha', timeout=10)
+    self.assert_event_correct(event, MOCK_RAW_EVENT)
+    handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
+        MOCK_CALLBACK_ID, 'ha', 10)
+
+  def test_wait_and_get_timeout_default(self):
+    handler = FakeCallbackHandler(rpc_max_timeout_sec=20, default_timeout_sec=5)
+    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
+        return_value=MOCK_RAW_EVENT)
+    _ = handler.waitAndGet('ha')
+    handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
+        mock.ANY, mock.ANY, 5)
+
+  def test_wait_and_get_timeout_ecxeed_threshold(self):
+    rpc_max_timeout_sec = 5
+    big_timeout_sec = 10
+    handler = FakeCallbackHandler(rpc_max_timeout_sec=rpc_max_timeout_sec,
+                                  default_timeout_sec=rpc_max_timeout_sec)
+    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
+        return_value=MOCK_RAW_EVENT)
+
+    expected_msg = (
+        f'Specified timeout {big_timeout_sec} is longer than max timeout '
+        f'{rpc_max_timeout_sec}.')
+    with self.assertRaisesRegex(errors.CallbackHandlerBaseError, expected_msg):
+      handler.waitAndGet('ha', big_timeout_sec)
+
+  def test_wait_for_event(self):
+    handler = FakeCallbackHandler()
+    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
+        return_value=MOCK_RAW_EVENT)
+
+    def some_condition(event):
+      return event.data['successful']
+
+    event = handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)
+    self.assert_event_correct(event, MOCK_RAW_EVENT)
+
+  def test_wait_for_event_negative(self):
+    handler = FakeCallbackHandler()
+    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
+        return_value=MOCK_RAW_EVENT)
+
+    expected_msg = (
+        'Timed out after 0.01s waiting for an "AsyncTaskResult" event that'
+        ' satisfies the predicate "some_condition".')
+
+    def some_condition(_):
+      return False
+
+    with self.assertRaisesRegex(errors.CallbackHandlerTimeoutError,
+                                expected_msg):
+      handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)
+
+  def test_wait_for_event_max_timeout(self):
+    """waitForEvent should not raise the timeout exceed threshold error."""
+    rpc_max_timeout_sec = 5
+    big_timeout_sec = 10
+    handler = FakeCallbackHandler(rpc_max_timeout_sec=rpc_max_timeout_sec,
+                                  default_timeout_sec=rpc_max_timeout_sec)
+    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
+        return_value=MOCK_RAW_EVENT)
+
+    def some_condition(event):
+      return event.data['successful']
+
+    # This line should not raise.
+    event = handler.waitForEvent('AsyncTaskResult',
+                                 some_condition,
+                                 timeout=big_timeout_sec)
+    self.assert_event_correct(event, MOCK_RAW_EVENT)
+
+  def test_get_all(self):
+    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
+    handler.mock_rpc_func.callEventGetAllRpc = mock.Mock(
+        return_value=[MOCK_RAW_EVENT, MOCK_RAW_EVENT])
+
+    all_events = handler.getAll('ha')
+    for event in all_events:
+      self.assert_event_correct(event, MOCK_RAW_EVENT)
+
+    handler.mock_rpc_func.callEventGetAllRpc.assert_called_once_with(
+        MOCK_CALLBACK_ID, 'ha')
+
+
+if __name__ == '__main__':
+  unittest.main()