Add a flag for whether to use Snippet Client V2 (#815)

diff --git a/mobly/controllers/android_device_lib/services/snippet_management_service.py b/mobly/controllers/android_device_lib/services/snippet_management_service.py
index 1cd0587..7d1dd94 100644
--- a/mobly/controllers/android_device_lib/services/snippet_management_service.py
+++ b/mobly/controllers/android_device_lib/services/snippet_management_service.py
@@ -14,6 +14,7 @@
 """Module for the snippet management service."""
 from mobly.controllers.android_device_lib import errors
 from mobly.controllers.android_device_lib import snippet_client
+from mobly.controllers.android_device_lib import snippet_client_v2
 from mobly.controllers.android_device_lib.services import base_service
 
 MISSING_SNIPPET_CLIENT_MSG = 'No snippet client is registered with name "%s".'
@@ -37,6 +38,23 @@
     self._is_alive = False
     self._snippet_clients = {}
     super().__init__(device)
+    self._use_client_v2_switch = None
+
+  def _is_using_client_v2(self):
+    """Is this service using snippet client V2.
+
+    Do not call this function in the constructor, as this function depends on
+    the device configuration and the device object will load its configuration
+    right after constructing the services.
+
+    NOTE: This is a transient function when we are migrating the snippet client
+    from v1 to v2. It will be removed after the migration is completed.
+    """
+    if self._use_client_v2_switch is None:
+      device_dimensions = getattr(self._device, 'dimensions', {})
+      self._use_client_v2_switch = (device_dimensions.get(
+          'use_mobly_snippet_client_v2', 'false').lower() == 'true')
+    return self._use_client_v2_switch
 
   @property
   def is_alive(self):
@@ -78,8 +96,14 @@
         raise Error(
             self, 'Snippet package "%s" has already been loaded under name'
             ' "%s".' % (package, snippet_name))
-    client = snippet_client.SnippetClient(package=package, ad=self._device)
-    client.start_app_and_connect()
+
+    if self._is_using_client_v2():
+      client = snippet_client_v2.SnippetClientV2(package=package,
+                                                 ad=self._device)
+      client.initialize()
+    else:
+      client = snippet_client.SnippetClient(package=package, ad=self._device)
+      client.start_app_and_connect()
     self._snippet_clients[name] = client
 
   def remove_snippet_client(self, name):
@@ -94,14 +118,20 @@
     if name not in self._snippet_clients:
       raise Error(self._device, MISSING_SNIPPET_CLIENT_MSG % name)
     client = self._snippet_clients.pop(name)
-    client.stop_app()
+    if self._is_using_client_v2():
+      client.stop()
+    else:
+      client.stop_app()
 
   def start(self):
     """Starts all the snippet clients under management."""
     for client in self._snippet_clients.values():
       if not client.is_alive:
         self._device.log.debug('Starting SnippetClient<%s>.', client.package)
-        client.start_app_and_connect()
+        if self._is_using_client_v2():
+          client.initialize()
+        else:
+          client.start_app_and_connect()
       else:
         self._device.log.debug(
             'Not startng SnippetClient<%s> because it is already alive.',
@@ -112,7 +142,10 @@
     for client in self._snippet_clients.values():
       if client.is_alive:
         self._device.log.debug('Stopping SnippetClient<%s>.', client.package)
-        client.stop_app()
+        if self._is_using_client_v2():
+          client.stop()
+        else:
+          client.stop_app()
       else:
         self._device.log.debug(
             'Not stopping SnippetClient<%s> because it is not alive.',
@@ -126,14 +159,20 @@
     """
     for client in self._snippet_clients.values():
       self._device.log.debug('Pausing SnippetClient<%s>.', client.package)
-      client.disconnect()
+      if self._is_using_client_v2():
+        client.close_connection()
+      else:
+        client.disconnect()
 
   def resume(self):
     """Resumes all paused snippet clients."""
     for client in self._snippet_clients.values():
       if not client.is_alive:
         self._device.log.debug('Resuming SnippetClient<%s>.', client.package)
-        client.restore_app_connection()
+        if self._is_using_client_v2():
+          client.restore_server_connection()
+        else:
+          client.restore_app_connection()
       else:
         self._device.log.debug('Not resuming SnippetClient<%s>.',
                                client.package)
diff --git a/tests/mobly/controllers/android_device_lib/services/snippet_management_service_test.py b/tests/mobly/controllers/android_device_lib/services/snippet_management_service_test.py
index 469ce4e..54bf60d 100755
--- a/tests/mobly/controllers/android_device_lib/services/snippet_management_service_test.py
+++ b/tests/mobly/controllers/android_device_lib/services/snippet_management_service_test.py
@@ -19,6 +19,7 @@
 
 MOCK_PACKAGE = 'com.mock.package'
 SNIPPET_CLIENT_CLASS_PATH = 'mobly.controllers.android_device_lib.snippet_client.SnippetClient'
+SNIPPET_CLIENT_V2_CLASS_PATH = 'mobly.controllers.android_device_lib.snippet_client_v2.SnippetClientV2'
 
 
 class SnippetManagementServiceTest(unittest.TestCase):
@@ -161,6 +162,105 @@
     manager.foo.ha('param')
     mock_client.ha.assert_called_once_with('param')
 
+  def test_client_v2_flag_default_value(self):
+    mock_device = mock.MagicMock()
+    mock_device.dimensions = {}
+    manager = snippet_management_service.SnippetManagementService(mock_device)
+    self.assertFalse(manager._is_using_client_v2())
+
+  def test_client_v2_flag_false(self):
+    mock_device = mock.MagicMock(
+        dimensions={'use_mobly_snippet_client_v2': 'false'})
+    manager = snippet_management_service.SnippetManagementService(mock_device)
+    self.assertFalse(manager._is_using_client_v2())
+
+  def test_client_v2_flag_true(self):
+    mock_device = mock.MagicMock(
+        dimensions={'use_mobly_snippet_client_v2': 'true'})
+    manager = snippet_management_service.SnippetManagementService(mock_device)
+    self.assertTrue(manager._is_using_client_v2())
+
+  @mock.patch(SNIPPET_CLIENT_V2_CLASS_PATH)
+  def test_client_v2_add_snippet_client(self, mock_class):
+    mock_client = mock.MagicMock()
+    mock_class.return_value = mock_client
+    mock_device = mock.MagicMock(
+        dimensions={'use_mobly_snippet_client_v2': 'true'})
+    manager = snippet_management_service.SnippetManagementService(mock_device)
+    manager.add_snippet_client('foo', MOCK_PACKAGE)
+    self.assertIs(manager.get_snippet_client('foo'), mock_client)
+    mock_client.initialize.assert_called_once_with()
+
+  @mock.patch(SNIPPET_CLIENT_V2_CLASS_PATH)
+  def test_client_v2_remove_snippet_client(self, mock_class):
+    mock_client = mock.MagicMock()
+    mock_class.return_value = mock_client
+    mock_device = mock.MagicMock(
+        dimensions={'use_mobly_snippet_client_v2': 'true'})
+    manager = snippet_management_service.SnippetManagementService(mock_device)
+    manager.add_snippet_client('foo', MOCK_PACKAGE)
+    manager.remove_snippet_client('foo')
+    mock_client.stop.assert_called_once_with()
+
+  @mock.patch(SNIPPET_CLIENT_V2_CLASS_PATH)
+  def test_client_v2_start(self, mock_class):
+    mock_client = mock.MagicMock()
+    mock_class.return_value = mock_client
+    mock_device = mock.MagicMock(
+        dimensions={'use_mobly_snippet_client_v2': 'true'})
+    manager = snippet_management_service.SnippetManagementService(mock_device)
+    manager.add_snippet_client('foo', MOCK_PACKAGE)
+
+    mock_client.initialize.reset_mock()
+    mock_client.is_alive = False
+    manager.start()
+
+    mock_client.initialize.assert_called_once_with()
+
+  @mock.patch(SNIPPET_CLIENT_V2_CLASS_PATH)
+  def test_client_v2_stop(self, mock_class):
+    mock_client = mock.MagicMock()
+    mock_class.return_value = mock_client
+    mock_device = mock.MagicMock(
+        dimensions={'use_mobly_snippet_client_v2': 'true'})
+    manager = snippet_management_service.SnippetManagementService(mock_device)
+    manager.add_snippet_client('foo', MOCK_PACKAGE)
+
+    mock_client.stop.reset_mock()
+    mock_client.is_alive = True
+    manager.stop()
+
+    mock_client.stop.assert_called_once_with()
+
+  @mock.patch(SNIPPET_CLIENT_V2_CLASS_PATH)
+  def test_client_v2_pause(self, mock_class):
+    mock_client = mock.MagicMock()
+    mock_class.return_value = mock_client
+    mock_device = mock.MagicMock(
+        dimensions={'use_mobly_snippet_client_v2': 'true'})
+    manager = snippet_management_service.SnippetManagementService(mock_device)
+    manager.add_snippet_client('foo', MOCK_PACKAGE)
+
+    mock_client.close_connection.reset_mock()
+    manager.pause()
+
+    mock_client.close_connection.assert_called_once_with()
+
+  @mock.patch(SNIPPET_CLIENT_V2_CLASS_PATH)
+  def test_client_v2_resume(self, mock_class):
+    mock_client = mock.MagicMock()
+    mock_class.return_value = mock_client
+    mock_device = mock.MagicMock(
+        dimensions={'use_mobly_snippet_client_v2': 'true'})
+    manager = snippet_management_service.SnippetManagementService(mock_device)
+    manager.add_snippet_client('foo', MOCK_PACKAGE)
+
+    mock_client.restore_server_connection.reset_mock()
+    mock_client.is_alive = False
+    manager.resume()
+
+    mock_client.restore_server_connection.assert_called_once_with()
+
 
 if __name__ == '__main__':
   unittest.main()