Add a new base class for snippet client (#795)

This is the first PR to kick off the standardization of Mobly snippet client, which would enable us to scale the snippet mechanism to a broader set of platforms.

Once the new snippet client code is proven, we will enable it by default and remove the old snippet client under `android_device_lib`.
diff --git a/mobly/snippet/__init__.py b/mobly/snippet/__init__.py
new file mode 100644
index 0000000..ac3f9e6
--- /dev/null
+++ b/mobly/snippet/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/mobly/snippet/client_base.py b/mobly/snippet/client_base.py
new file mode 100644
index 0000000..8fdd6c9
--- /dev/null
+++ b/mobly/snippet/client_base.py
@@ -0,0 +1,466 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The JSON RPC client base for communicating with snippet servers.
+
+The JSON RPC protocol expected by this module is:
+
+.. code-block:: json
+
+  Request:
+  {
+    'id': <Required. Monotonically increasing integer containing the ID of this
+          request.>,
+    'method': <Required. String containing the name of the method to execute.>,
+    'params': <Required. JSON array containing the arguments to the method,
+              `null` if no positional arguments for the RPC method.>,
+    'kwargs': <Optional. JSON dict containing the keyword arguments for the
+              method, `null` if no positional arguments for the RPC method.>,
+  }
+
+  Response:
+  {
+    'error': <Required. String containing the error thrown by executing the
+             method, `null` if no error occurred.>,
+    'id': <Required. Int id of request that this response maps to.>,
+    'result': <Required. Arbitrary JSON object containing the result of
+              executing the method, `null` if the method could not be executed
+              or returned void.>,
+    'callback': <Required. String that represents a callback ID used to
+                identify events associated with a particular CallbackHandler
+                object, `null` if this is not a async RPC.>,
+  }
+"""
+
+import abc
+import contextlib
+import enum
+import json
+import threading
+import time
+
+from mobly.snippet import errors
+
+# Maximum logging length of RPC response in DEBUG level when verbose logging is
+# off.
+_MAX_RPC_RESP_LOGGING_LENGTH = 1024
+
+# The required field names of RPC response.
+RPC_RESPONSE_REQUIRED_FIELDS = ('id', 'error', 'result', 'callback')
+
+
+class StartServerStages(enum.Enum):
+  """The stages for the starting server process."""
+  BEFORE_STARTING_SERVER = 1
+  DO_START_SERVER = 2
+  BUILD_CONNECTION = 3
+  AFTER_STARTING_SERVER = 4
+
+
+class ClientBase(abc.ABC):
+  """Base class for JSON RPC clients that connect to snippet servers.
+
+  Connects to a remote device running a JSON RPC compatible server. Users call
+  the function `start_server` to start the server on the remote device before
+  sending any RPC. After sending all RPCs, users call the function `stop_server`
+  to stop all the running instances.
+
+  Attributes:
+    package: str, the user-visible name of the snippet library being
+      communicated with.
+    host_port: int, the host port of this RPC client.
+    device_port: int, the device port of this RPC client.
+    log: Logger, the logger of the corresponding device controller.
+    verbose_logging: bool, if True, prints more detailed log
+      information. Default is True.
+  """
+
+  def __init__(self, package, device):
+    """Initializes the instance of ClientBase.
+
+    Args:
+      package: str, the user-visible name of the snippet library being
+        communicated with.
+      device: DeviceController, the device object associated with a client.
+    """
+
+    self.package = package
+    self.host_port = None
+    self.device_port = None
+    self.log = device.log
+    self.verbose_logging = True
+    self._device = device
+    self._counter = None
+    self._lock = threading.Lock()
+    self._event_client = None
+
+  def __del__(self):
+    self.close_connection()
+
+  def start_server(self):
+    """Starts the server on the remote device and connects to it.
+
+    This process contains four stages:
+      - before_starting_server: prepares for starting the server.
+      - do_start_server: starts the server on the remote device.
+      - build_connection: builds a connection with the server.
+      - after_starting_server: does the things after the server is available.
+
+    After this, the self.host_port and self.device_port attributes must be
+    set.
+
+    Raises:
+      errors.ProtocolError: something went wrong when exchanging data with the
+        server.
+      errors.ServerStartPreCheckError: when prechecks for starting the server
+        failed.
+      errors.ServerStartError: when failed to start the snippet server.
+    """
+
+    @contextlib.contextmanager
+    def _execute_one_stage(stage):
+      """Context manager for executing one stage.
+
+      Args:
+        stage: StartServerStages, the stage which is running under this
+          context manager.
+
+      Yields:
+        None.
+      """
+      self.log.debug('[START_SERVER] Running the stage %s.', stage.name)
+      yield
+      self.log.debug('[START_SERVER] Finished the stage %s.', stage.name)
+
+    self.log.debug('Starting the server.')
+    start_time = time.perf_counter()
+
+    with _execute_one_stage(StartServerStages.BEFORE_STARTING_SERVER):
+      self.before_starting_server()
+
+    try:
+      with _execute_one_stage(StartServerStages.DO_START_SERVER):
+        self.do_start_server()
+
+      with _execute_one_stage(StartServerStages.BUILD_CONNECTION):
+        self._build_connection()
+
+      with _execute_one_stage(StartServerStages.AFTER_STARTING_SERVER):
+        self.after_starting_server()
+
+    except Exception:
+      self.log.error('[START SERVER] Error occurs when starting the server.')
+      try:
+        self.stop_server()
+      except Exception:  # pylint: disable=broad-except
+        # Only prints this exception and re-raises the original exception
+        self.log.exception('[START_SERVER] Failed to stop server because of '
+                           'new exception.')
+
+      raise
+
+    self.log.debug('Snippet %s started after %.1fs on host port %d.',
+                   self.package,
+                   time.perf_counter() - start_time, self.host_port)
+
+  @abc.abstractmethod
+  def before_starting_server(self):
+    """Prepares for starting the server.
+
+    For example, subclass can check or modify the device settings at this
+    stage.
+
+    Raises:
+      errors.ServerStartPreCheckError: when prechecks for starting the server
+        failed.
+    """
+
+  @abc.abstractmethod
+  def do_start_server(self):
+    """Starts the server on the remote device.
+
+    The client has completed the preparations, so the client calls this
+    function to start the server.
+    """
+
+  def _build_connection(self):
+    """Proxy function of build_connection.
+
+    This function resets the RPC id counter before calling `build_connection`.
+    """
+    self._counter = self._id_counter()
+    self.build_connection()
+
+  @abc.abstractmethod
+  def build_connection(self):
+    """Builds a connection with the server on the remote device.
+
+    The command to start the server has been already sent before calling this
+    function. So the client builds a connection to it and sends a handshake
+    to ensure the server is available for upcoming RPCs.
+
+    This function uses self.host_port for communicating with the server. If
+    self.host_port is 0 or None, this function finds an available host port to
+    build connection and set self.host_port to the found port.
+
+    Raises:
+      errors.ProtocolError: something went wrong when exchanging data with the
+        server.
+    """
+
+  @abc.abstractmethod
+  def after_starting_server(self):
+    """Does the things after the server is available.
+
+    For example, subclass can get device information from the server.
+    """
+
+  def __getattr__(self, name):
+    """Wrapper for python magic to turn method calls into RPCs."""
+
+    def rpc_call(*args, **kwargs):
+      return self._rpc(name, *args, **kwargs)
+
+    return rpc_call
+
+  def _id_counter(self):
+    """Returns an id generator."""
+    i = 0
+    while True:
+      yield i
+      i += 1
+
+  def set_snippet_client_verbose_logging(self, verbose):
+    """Switches verbose logging. True for logging full RPC responses.
+
+    By default it will write full messages returned from RPCs. Turning off the
+    verbose logging will result in writing no more than
+    _MAX_RPC_RESP_LOGGING_LENGTH characters per RPC returned string.
+
+    _MAX_RPC_RESP_LOGGING_LENGTH will be set to 1024 by default. The length
+    contains the full RPC response in JSON format, not just the RPC result
+    field.
+
+    Args:
+      verbose: bool, if True, turns on verbose logging, otherwise turns off.
+    """
+    self.log.info('Sets verbose logging to %s.', verbose)
+    self.verbose_logging = verbose
+
+  @abc.abstractmethod
+  def restore_server_connection(self, port=None):
+    """Reconnects to the server after the device was disconnected.
+
+    Instead of creating a new instance of the client:
+      - Uses the given port (or finds a new available host_port if 0 or None is
+      given).
+      - Tries to connect to the remote server with the selected port.
+
+    Args:
+      port: int, if given, this is the host port from which to connect to the
+        remote device port. Otherwise, finds a new available port as host
+        port.
+
+    Raises:
+      errors.ServerRestoreConnectionError: when failed to restore the connection
+        with the snippet server.
+    """
+
+  def _rpc(self, rpc_func_name, *args, **kwargs):
+    """Sends a RPC to the server.
+
+    Args:
+      rpc_func_name: str, the name of the snippet function to execute on the
+        server.
+      *args: any, the positional arguments of the RPC request.
+      **kwargs: any, the keyword arguments of the RPC request.
+
+    Returns:
+      The result of the RPC.
+
+    Raises:
+      errors.ProtocolError: something went wrong when exchanging data with the
+        server.
+      errors.ApiError: the RPC went through, however executed with errors.
+    """
+    try:
+      self.check_server_proc_running()
+    except Exception:
+      self.log.error(
+          'Server process running check failed, skip sending RPC method(%s).',
+          rpc_func_name)
+      raise
+
+    with self._lock:
+      rpc_id = next(self._counter)
+      request = self._gen_rpc_request(rpc_id, rpc_func_name, *args, **kwargs)
+
+      self.log.debug('Sending RPC request %s.', request)
+      response = self.send_rpc_request(request)
+      self.log.debug('RPC request sent.')
+
+      if self.verbose_logging or _MAX_RPC_RESP_LOGGING_LENGTH >= len(response):
+        self.log.debug('Snippet received: %s', response)
+      else:
+        self.log.debug('Snippet received: %s... %d chars are truncated',
+                       response[:_MAX_RPC_RESP_LOGGING_LENGTH],
+                       len(response) - _MAX_RPC_RESP_LOGGING_LENGTH)
+
+    response_decoded = self._decode_response_string_and_validate_format(
+        rpc_id, response)
+    return self._handle_rpc_response(rpc_func_name, response_decoded)
+
+  @abc.abstractmethod
+  def check_server_proc_running(self):
+    """Checks whether the server is still running.
+
+    If the server is not running, it throws an error. As this function is called
+    each time the client tries to send a RPC, this should be a quick check
+    without affecting performance. Otherwise it is fine to not check anything.
+
+    Raises:
+      errors.ServerDiedError: if the server died.
+    """
+
+  def _gen_rpc_request(self, rpc_id, rpc_func_name, *args, **kwargs):
+    """Generates the JSON RPC request.
+
+    In the generated JSON string, the fields are sorted by keys in ascending
+    order.
+
+    Args:
+      rpc_id: int, the id of this RPC.
+      rpc_func_name: str, the name of the snippet function to execute
+        on the server.
+      *args: any, the positional arguments of the RPC.
+      **kwargs: any, the keyword arguments of the RPC.
+
+    Returns:
+      A string of the JSON RPC request.
+    """
+    data = {'id': rpc_id, 'method': rpc_func_name, 'params': args}
+    if kwargs:
+      data['kwargs'] = kwargs
+    return json.dumps(data, sort_keys=True)
+
+  @abc.abstractmethod
+  def send_rpc_request(self, request):
+    """Sends the JSON RPC request to the server and gets a response.
+
+    Note that the request and response are both in string format. So if the
+    connection with server provides interfaces in bytes format, please
+    transform them to string in the implementation of this function.
+
+    Args:
+      request: str, a string of the RPC request.
+
+    Returns:
+      A string of the RPC response.
+
+    Raises:
+      errors.ProtocolError: something went wrong when exchanging data with the
+        server.
+    """
+
+  def _decode_response_string_and_validate_format(self, rpc_id, response):
+    """Decodes response JSON string to python dict and validates its format.
+
+    Args:
+      rpc_id: int, the actual id of this RPC. It should be the same with the id
+        in the response, otherwise throws an error.
+      response: str, the JSON string of the RPC response.
+
+    Returns:
+      A dict decoded from the response JSON string.
+
+    Raises:
+      errors.ProtocolError: if the response format is invalid.
+    """
+    if not response:
+      raise errors.ProtocolError(self._device,
+                                 errors.ProtocolError.NO_RESPONSE_FROM_SERVER)
+
+    result = json.loads(response)
+    for field_name in RPC_RESPONSE_REQUIRED_FIELDS:
+      if field_name not in result:
+        raise errors.ProtocolError(
+            self._device,
+            errors.ProtocolError.RESPONSE_MISSING_FIELD % field_name)
+
+    if result['id'] != rpc_id:
+      raise errors.ProtocolError(self._device,
+                                 errors.ProtocolError.MISMATCHED_API_ID)
+
+    return result
+
+  def _handle_rpc_response(self, rpc_func_name, response):
+    """Handles the content of RPC response.
+
+    If the RPC response contains error information, it throws an error. If the
+    RPC is asynchronous, it creates and returns a callback handler
+    object. Otherwise, it returns the result field of the response.
+
+    Args:
+      rpc_func_name: str, the name of the snippet function that this RPC
+        triggered on the snippet server.
+      response: dict, the object decoded from the response JSON string.
+
+    Returns:
+      The result of the RPC. If synchronous RPC, it is the result field of the
+      response. If asynchronous RPC, it is the callback handler object.
+
+    Raises:
+      errors.ApiError: if the snippet function executed with errors.
+    """
+
+    if response['error']:
+      raise errors.ApiError(self._device, response['error'])
+    if response['callback'] is not None:
+      return self.handle_callback(response['callback'], response['result'],
+                                  rpc_func_name)
+    return response['result']
+
+  @abc.abstractmethod
+  def handle_callback(self, callback_id, ret_value, rpc_func_name):
+    """Creates a callback handler for the asynchronous RPC.
+
+    Args:
+      callback_id: str, the callback ID for creating a callback handler object.
+      ret_value: any, the result field of the RPC response.
+      rpc_func_name: str, the name of the snippet function executed on the
+        server.
+
+    Returns:
+      The callback handler object.
+    """
+
+  def stop_server(self):
+    """Proxy function of do_stop_server."""
+    self.log.debug('Stopping snippet %s.', self.package)
+    self.do_stop_server()
+    self.log.debug('Snippet %s stopped.', self.package)
+
+  @abc.abstractmethod
+  def do_stop_server(self):
+    """Kills any running instance of the server."""
+
+  @abc.abstractmethod
+  def close_connection(self):
+    """Closes the connection to the snippet server on the device.
+
+    This is a unilateral closing from the client side, without tearing down
+    the snippet server running on the device.
+
+    The connection to the snippet server can be re-established by calling
+    `restore_server_connection`.
+    """
diff --git a/mobly/snippet/errors.py b/mobly/snippet/errors.py
new file mode 100644
index 0000000..d22e9fc
--- /dev/null
+++ b/mobly/snippet/errors.py
@@ -0,0 +1,52 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for errors thrown from snippet client objects."""
+# TODO(mhaoli): Package `mobly.snippet` should not import errors from
+# android_device_lib. However, android_device_lib.DeviceError is the base error
+# for the errors thrown from Android snippet clients and device controllers.
+# We should resolve this legacy problem.
+from mobly.controllers.android_device_lib import errors
+
+
+class Error(errors.DeviceError):
+  """Root error type for snippet clients."""
+
+
+class ServerRestoreConnectionError(Error):
+  """Raised when failed to restore the connection with the snippet server."""
+
+
+class ServerStartError(Error):
+  """Raised when failed to start the snippet server."""
+
+
+class ServerStartPreCheckError(Error):
+  """Raised when prechecks for starting the snippet server failed."""
+
+
+class ApiError(Error):
+  """Raised when remote API reported an error."""
+
+
+class ProtocolError(Error):
+  """Raised when there was an error in exchanging data with server."""
+  NO_RESPONSE_FROM_HANDSHAKE = 'No response from handshake.'
+  NO_RESPONSE_FROM_SERVER = ('No response from server. '
+                             'Check the device logcat for crashes.')
+  MISMATCHED_API_ID = 'RPC request-response ID mismatch.'
+  RESPONSE_MISSING_FIELD = 'Missing required field in the RPC response: %s.'
+
+
+class ServerDiedError(Error):
+  """Raised if the snippet server died before all tests finish."""
diff --git a/tests/mobly/snippet/__init__.py b/tests/mobly/snippet/__init__.py
new file mode 100644
index 0000000..ac3f9e6
--- /dev/null
+++ b/tests/mobly/snippet/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/mobly/snippet/client_base_test.py b/tests/mobly/snippet/client_base_test.py
new file mode 100755
index 0000000..4da74d6
--- /dev/null
+++ b/tests/mobly/snippet/client_base_test.py
@@ -0,0 +1,442 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for mobly.snippet.client_base."""
+
+import logging
+import random
+import string
+import unittest
+from unittest import mock
+
+from mobly.snippet import client_base
+from mobly.snippet import errors
+
+
+def _generate_fix_length_rpc_response(
+    response_length,
+    template='{"id": 0, "result": "%s", "error": null, "callback": null}'):
+  """Generates a RPC response string with specified length.
+
+  This function generates a random string and formats the template with the
+  generated random string to get the response string. This function formats
+  the template with printf style string formatting.
+
+  Args:
+    response_length: int, the length of the response string to generate.
+    template: str, the template used for generating the response string.
+
+  Returns:
+    The generated response string.
+
+  Raises:
+    ValueError: if the specified length is too small to generate a response.
+  """
+  # We need to -2 here because the string formatting will delete the substring
+  # '%s' in the template, of which the length is 2.
+  result_length = response_length - (len(template) - 2)
+  if result_length < 0:
+    raise ValueError(f'The response_length should be no smaller than '
+                     f'template_length + 2. Got response_length '
+                     f'{response_length}, template_length {len(template)}.')
+  chars = string.ascii_letters + string.digits
+  return template % ''.join(random.choice(chars) for _ in range(result_length))
+
+
+class FakeClient(client_base.ClientBase):
+  """Fake client class for unit tests."""
+
+  def __init__(self):
+    """Initializes the instance by mocking a device controller."""
+    mock_device = mock.Mock()
+    mock_device.log = logging
+    super().__init__(package='FakeClient', device=mock_device)
+
+  # Override abstract methods to enable initialization
+  def before_starting_server(self):
+    pass
+
+  def do_start_server(self):
+    pass
+
+  def build_connection(self):
+    pass
+
+  def after_starting_server(self):
+    pass
+
+  def restore_server_connection(self, port=None):
+    pass
+
+  def check_server_proc_running(self):
+    pass
+
+  def send_rpc_request(self, request):
+    pass
+
+  def handle_callback(self, callback_id, ret_value, rpc_func_name):
+    pass
+
+  def do_stop_server(self):
+    pass
+
+  def close_connection(self):
+    pass
+
+
+class ClientBaseTest(unittest.TestCase):
+  """Unit tests for mobly.snippet.client_base.ClientBase."""
+
+  def setUp(self):
+    super().setUp()
+    self.client = FakeClient()
+    self.client.host_port = 12345
+
+  @mock.patch.object(FakeClient, 'before_starting_server')
+  @mock.patch.object(FakeClient, 'do_start_server')
+  @mock.patch.object(FakeClient, '_build_connection')
+  @mock.patch.object(FakeClient, 'after_starting_server')
+  def test_start_server_stage_order(self, mock_after_func, mock_build_conn_func,
+                                    mock_do_start_func, mock_before_func):
+    """Test that starting server runs its stages in expected order."""
+    order_manager = mock.Mock()
+    order_manager.attach_mock(mock_before_func, 'mock_before_func')
+    order_manager.attach_mock(mock_do_start_func, 'mock_do_start_func')
+    order_manager.attach_mock(mock_build_conn_func, 'mock_build_conn_func')
+    order_manager.attach_mock(mock_after_func, 'mock_after_func')
+
+    self.client.start_server()
+
+    expected_call_order = [
+        mock.call.mock_before_func(),
+        mock.call.mock_do_start_func(),
+        mock.call.mock_build_conn_func(),
+        mock.call.mock_after_func(),
+    ]
+    self.assertListEqual(order_manager.mock_calls, expected_call_order)
+
+  @mock.patch.object(FakeClient, 'stop_server')
+  @mock.patch.object(FakeClient, 'before_starting_server')
+  def test_start_server_before_starting_server_fail(self, mock_before_func,
+                                                    mock_stop_server):
+    """Test starting server's stage before_starting_server fails."""
+    mock_before_func.side_effect = Exception('ha')
+
+    with self.assertRaisesRegex(Exception, 'ha'):
+      self.client.start_server()
+    mock_stop_server.assert_not_called()
+
+  @mock.patch.object(FakeClient, 'stop_server')
+  @mock.patch.object(FakeClient, 'do_start_server')
+  def test_start_server_do_start_server_fail(self, mock_do_start_func,
+                                             mock_stop_server):
+    """Test starting server's stage do_start_server fails."""
+    mock_do_start_func.side_effect = Exception('ha')
+
+    with self.assertRaisesRegex(Exception, 'ha'):
+      self.client.start_server()
+    mock_stop_server.assert_called()
+
+  @mock.patch.object(FakeClient, 'stop_server')
+  @mock.patch.object(FakeClient, '_build_connection')
+  def test_start_server_build_connection_fail(self, mock_build_conn_func,
+                                              mock_stop_server):
+    """Test starting server's stage _build_connection fails."""
+    mock_build_conn_func.side_effect = Exception('ha')
+
+    with self.assertRaisesRegex(Exception, 'ha'):
+      self.client.start_server()
+    mock_stop_server.assert_called()
+
+  @mock.patch.object(FakeClient, 'stop_server')
+  @mock.patch.object(FakeClient, 'after_starting_server')
+  def test_start_server_after_starting_server_fail(self, mock_after_func,
+                                                   mock_stop_server):
+    """Test starting server's stage after_starting_server fails."""
+    mock_after_func.side_effect = Exception('ha')
+
+    with self.assertRaisesRegex(Exception, 'ha'):
+      self.client.start_server()
+    mock_stop_server.assert_called()
+
+  @mock.patch.object(FakeClient, 'check_server_proc_running')
+  @mock.patch.object(FakeClient, '_gen_rpc_request')
+  @mock.patch.object(FakeClient, 'send_rpc_request')
+  @mock.patch.object(FakeClient, '_decode_response_string_and_validate_format')
+  @mock.patch.object(FakeClient, '_handle_rpc_response')
+  def test_rpc_stage_dependencies(self, mock_handle_resp, mock_decode_resp_str,
+                                  mock_send_request, mock_gen_request,
+                                  mock_precheck):
+    """Test the internal dependencies when sending a RPC.
+
+    When sending a RPC, it calls multiple functions in specific order, and
+    each function uses the output of the previously called function. This test
+    case checks above dependencies.
+
+    Args:
+      mock_handle_resp: the mock function of FakeClient._handle_rpc_response.
+      mock_decode_resp_str: the mock function of
+        FakeClient._decode_response_string_and_validate_format.
+      mock_send_request: the mock function of FakeClient.send_rpc_request.
+      mock_gen_request: the mock function of FakeClient._gen_rpc_request.
+      mock_precheck: the mock function of FakeClient.check_server_proc_running.
+    """
+    self.client.start_server()
+
+    expected_response_str = ('{"id": 0, "result": 123, "error": null, '
+                             '"callback": null}')
+    expected_response_dict = {
+        'id': 0,
+        'result': 123,
+        'error': None,
+        'callback': None,
+    }
+    expected_request = ('{"id": 10, "method": "some_rpc", "params": [1, 2],'
+                        '"kwargs": {"test_key": 3}')
+    expected_result = 123
+
+    mock_gen_request.return_value = expected_request
+    mock_send_request.return_value = expected_response_str
+    mock_decode_resp_str.return_value = expected_response_dict
+    mock_handle_resp.return_value = expected_result
+    rpc_result = self.client.some_rpc(1, 2, test_key=3)
+
+    mock_precheck.assert_called()
+    mock_gen_request.assert_called_with(0, 'some_rpc', 1, 2, test_key=3)
+    mock_send_request.assert_called_with(expected_request)
+    mock_decode_resp_str.assert_called_with(0, expected_response_str)
+    mock_handle_resp.assert_called_with('some_rpc', expected_response_dict)
+    self.assertEqual(rpc_result, expected_result)
+
+  @mock.patch.object(FakeClient, 'check_server_proc_running')
+  @mock.patch.object(FakeClient, '_gen_rpc_request')
+  @mock.patch.object(FakeClient, 'send_rpc_request')
+  @mock.patch.object(FakeClient, '_decode_response_string_and_validate_format')
+  @mock.patch.object(FakeClient, '_handle_rpc_response')
+  def test_rpc_precheck_fail(self, mock_handle_resp, mock_decode_resp_str,
+                             mock_send_request, mock_gen_request,
+                             mock_precheck):
+    """Test when RPC precheck fails it will skip sending RPC."""
+    self.client.start_server()
+    mock_precheck.side_effect = Exception('server_died')
+
+    with self.assertRaisesRegex(Exception, 'server_died'):
+      self.client.some_rpc(1, 2)
+
+    mock_gen_request.assert_not_called()
+    mock_send_request.assert_not_called()
+    mock_handle_resp.assert_not_called()
+    mock_decode_resp_str.assert_not_called()
+
+  def test_gen_request(self):
+    """Test generating a RPC request.
+
+    Test that _gen_rpc_request returns a string represents a JSON dict
+    with all required fields.
+    """
+    request = self.client._gen_rpc_request(0, 'test_rpc', 1, 2, test_key=3)
+    expected_result = ('{"id": 0, "kwargs": {"test_key": 3}, '
+                       '"method": "test_rpc", "params": [1, 2]}')
+    self.assertEqual(request, expected_result)
+
+  def test_gen_request_without_kwargs(self):
+    """Test no keyword arguments.
+
+    Test that _gen_rpc_request ignores the kwargs field when no
+    keyword arguments.
+    """
+    request = self.client._gen_rpc_request(0, 'test_rpc', 1, 2)
+    expected_result = '{"id": 0, "method": "test_rpc", "params": [1, 2]}'
+    self.assertEqual(request, expected_result)
+
+  def test_rpc_no_response(self):
+    """Test parsing an empty RPC response."""
+    with self.assertRaisesRegex(errors.ProtocolError,
+                                errors.ProtocolError.NO_RESPONSE_FROM_SERVER):
+      self.client._decode_response_string_and_validate_format(0, '')
+
+    with self.assertRaisesRegex(errors.ProtocolError,
+                                errors.ProtocolError.NO_RESPONSE_FROM_SERVER):
+      self.client._decode_response_string_and_validate_format(0, None)
+
+  def test_rpc_response_missing_fields(self):
+    """Test parsing a RPC response that misses some required fields."""
+    mock_resp_without_id = '{"result": 123, "error": null, "callback": null}'
+    with self.assertRaisesRegex(
+        errors.ProtocolError,
+        errors.ProtocolError.RESPONSE_MISSING_FIELD % 'id'):
+      self.client._decode_response_string_and_validate_format(
+          10, mock_resp_without_id)
+
+    mock_resp_without_result = '{"id": 10, "error": null, "callback": null}'
+    with self.assertRaisesRegex(
+        errors.ProtocolError,
+        errors.ProtocolError.RESPONSE_MISSING_FIELD % 'result'):
+      self.client._decode_response_string_and_validate_format(
+          10, mock_resp_without_result)
+
+    mock_resp_without_error = '{"id": 10, "result": 123, "callback": null}'
+    with self.assertRaisesRegex(
+        errors.ProtocolError,
+        errors.ProtocolError.RESPONSE_MISSING_FIELD % 'error'):
+      self.client._decode_response_string_and_validate_format(
+          10, mock_resp_without_error)
+
+    mock_resp_without_callback = '{"id": 10, "result": 123, "error": null}'
+    with self.assertRaisesRegex(
+        errors.ProtocolError,
+        errors.ProtocolError.RESPONSE_MISSING_FIELD % 'callback'):
+      self.client._decode_response_string_and_validate_format(
+          10, mock_resp_without_callback)
+
+  def test_rpc_response_error(self):
+    """Test parsing a RPC response with a non-empty error field."""
+    mock_resp_with_error = {
+        'id': 10,
+        'result': 123,
+        'error': 'some_error',
+        'callback': None,
+    }
+    with self.assertRaisesRegex(errors.ApiError, 'some_error'):
+      self.client._handle_rpc_response('some_rpc', mock_resp_with_error)
+
+  def test_rpc_response_callback(self):
+    """Test parsing response function handles the callback field well."""
+    # Call handle_callback function if the "callback" field is not None
+    mock_resp_with_callback = {
+        'id': 10,
+        'result': 123,
+        'error': None,
+        'callback': '1-0'
+    }
+    with mock.patch.object(self.client,
+                           'handle_callback') as mock_handle_callback:
+      expected_callback = mock.Mock()
+      mock_handle_callback.return_value = expected_callback
+
+      rpc_result = self.client._handle_rpc_response('some_rpc',
+                                                    mock_resp_with_callback)
+      mock_handle_callback.assert_called_with('1-0', 123, 'some_rpc')
+      # Ensure the RPC function returns what handle_callback returned
+      self.assertIs(expected_callback, rpc_result)
+
+    # Do not call handle_callback function if the "callback" field is None
+    mock_resp_without_callback = {
+        'id': 10,
+        'result': 123,
+        'error': None,
+        'callback': None
+    }
+    with mock.patch.object(self.client,
+                           'handle_callback') as mock_handle_callback:
+      self.client._handle_rpc_response('some_rpc', mock_resp_without_callback)
+      mock_handle_callback.assert_not_called()
+
+  def test_rpc_response_id_mismatch(self):
+    """Test parsing a RPC response with wrong id."""
+    right_id = 5
+    wrong_id = 20
+    resp = f'{{"id": {right_id}, "result": 1, "error": null, "callback": null}}'
+
+    with self.assertRaisesRegex(errors.ProtocolError,
+                                errors.ProtocolError.MISMATCHED_API_ID):
+      self.client._decode_response_string_and_validate_format(wrong_id, resp)
+
+  @mock.patch.object(FakeClient, 'send_rpc_request')
+  def test_rpc_verbose_logging_with_long_string(self, mock_send_request):
+    """Test RPC response isn't truncated when verbose logging is on."""
+    mock_log = mock.Mock()
+    self.client.log = mock_log
+    self.client.set_snippet_client_verbose_logging(True)
+    self.client.start_server()
+
+    resp = _generate_fix_length_rpc_response(
+        client_base._MAX_RPC_RESP_LOGGING_LENGTH * 2)
+    mock_send_request.return_value = resp
+    self.client.some_rpc(1, 2)
+    mock_log.debug.assert_called_with('Snippet received: %s', resp)
+
+  @mock.patch.object(FakeClient, 'send_rpc_request')
+  def test_rpc_truncated_logging_short_response(self, mock_send_request):
+    """Test RPC response isn't truncated with small length."""
+    mock_log = mock.Mock()
+    self.client.log = mock_log
+    self.client.set_snippet_client_verbose_logging(False)
+    self.client.start_server()
+
+    resp = _generate_fix_length_rpc_response(
+        int(client_base._MAX_RPC_RESP_LOGGING_LENGTH // 2))
+    mock_send_request.return_value = resp
+    self.client.some_rpc(1, 2)
+    mock_log.debug.assert_called_with('Snippet received: %s', resp)
+
+  @mock.patch.object(FakeClient, 'send_rpc_request')
+  def test_rpc_truncated_logging_fit_size_response(self, mock_send_request):
+    """Test RPC response isn't truncated with length equal to the threshold."""
+    mock_log = mock.Mock()
+    self.client.log = mock_log
+    self.client.set_snippet_client_verbose_logging(False)
+    self.client.start_server()
+
+    resp = _generate_fix_length_rpc_response(
+        client_base._MAX_RPC_RESP_LOGGING_LENGTH)
+    mock_send_request.return_value = resp
+    self.client.some_rpc(1, 2)
+    mock_log.debug.assert_called_with('Snippet received: %s', resp)
+
+  @mock.patch.object(FakeClient, 'send_rpc_request')
+  def test_rpc_truncated_logging_long_response(self, mock_send_request):
+    """Test RPC response is truncated with length larger than the threshold."""
+    mock_log = mock.Mock()
+    self.client.log = mock_log
+    self.client.set_snippet_client_verbose_logging(False)
+    self.client.start_server()
+
+    max_len = client_base._MAX_RPC_RESP_LOGGING_LENGTH
+    resp = _generate_fix_length_rpc_response(max_len * 40)
+    mock_send_request.return_value = resp
+    self.client.some_rpc(1, 2)
+    mock_log.debug.assert_called_with(
+        'Snippet received: %s... %d chars are truncated',
+        resp[:client_base._MAX_RPC_RESP_LOGGING_LENGTH],
+        len(resp) - max_len)
+
+  @mock.patch.object(FakeClient, 'send_rpc_request')
+  def test_rpc_call_increment_counter(self, mock_send_request):
+    """Test that with each RPC call the counter is incremented by 1."""
+    self.client.start_server()
+    resp = '{"id": %d, "result": 123, "error": null, "callback": null}'
+    mock_send_request.side_effect = (resp % (i,) for i in range(10))
+
+    for _ in range(0, 10):
+      self.client.some_rpc()
+
+    self.assertEqual(next(self.client._counter), 10)
+
+  @mock.patch.object(FakeClient, 'send_rpc_request')
+  def test_build_connection_reset_counter(self, mock_send_request):
+    """Test that _build_connection resets the counter to zero."""
+    self.client.start_server()
+    resp = '{"id": %d, "result": 123, "error": null, "callback": null}'
+    mock_send_request.side_effect = (resp % (i,) for i in range(10))
+
+    for _ in range(0, 10):
+      self.client.some_rpc()
+
+    self.assertEqual(next(self.client._counter), 10)
+    self.client._build_connection()
+    self.assertEqual(next(self.client._counter), 0)
+
+
+if __name__ == '__main__':
+  unittest.main()