Add logic for ANDROID_SERIAL to snippet shell. (#574)


diff --git a/mobly/controllers/android_device_lib/jsonrpc_shell_base.py b/mobly/controllers/android_device_lib/jsonrpc_shell_base.py
index 162813a..82fe0ac 100755
--- a/mobly/controllers/android_device_lib/jsonrpc_shell_base.py
+++ b/mobly/controllers/android_device_lib/jsonrpc_shell_base.py
@@ -15,6 +15,7 @@
 from __future__ import print_function
 
 import code
+import os
 import pprint
 import sys
 
@@ -28,14 +29,14 @@
 class JsonRpcShellBase(object):
     def _start_services(self, console_env):
         """Starts the services needed by this client and adds them to console_env.
-  
+
         Must be implemented by subclasses.
         """
         raise NotImplemented()
 
     def _get_banner(self, serial):
         """Returns the user-friendly banner message to print before the console.
-  
+
         Must be implemented by subclasses.
         """
         raise NotImplemented()
@@ -43,19 +44,24 @@
     def load_device(self, serial=None):
         """Creates an AndroidDevice for the given serial number.
 
-        If no serial is given, it will be read from 'adb devices' if there is
-        only one.
+        If no serial is given, it will read from the ANDROID_SERIAL
+        environmental variable. If the environmental variable is not set, then
+        it will read from 'adb devices' if there is only one.
         """
         serials = android_device.list_adb_devices()
         if not serials:
             raise Error('No adb device found!')
         # No serial provided, try to pick up the device automatically.
         if not serial:
-            if len(serials) != 1:
+            env_serial = os.environ.get('ANDROID_SERIAL', None)
+            if env_serial is not None:
+                serial = env_serial
+            elif len(serials) == 1:
+                serial = serials[0]
+            else:
                 raise Error(
-                    'Expected one phone, but %d found. Use the -s flag.' %
-                    len(serials))
-            serial = serials[0]
+                    'Expected one phone, but %d found. Use the -s flag or '
+                    'specify ANDROID_SERIAL.' % len(serials))
         if serial not in serials:
             raise Error('Device "%s" is not found by adb.' % serial)
         ads = android_device.get_instances([serial])
diff --git a/tests/mobly/controllers/android_device_lib/jsonrpc_shell_base_test.py b/tests/mobly/controllers/android_device_lib/jsonrpc_shell_base_test.py
new file mode 100755
index 0000000..677cb89
--- /dev/null
+++ b/tests/mobly/controllers/android_device_lib/jsonrpc_shell_base_test.py
@@ -0,0 +1,91 @@
+# Copyright 2019 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import mock
+from future.tests.base import unittest
+
+from mobly.controllers import android_device
+from mobly.controllers.android_device_lib import jsonrpc_shell_base
+
+
+class JsonRpcClientBaseTest(unittest.TestCase):
+    """Unit tests for mobly.controllers.android_device_lib.jsonrpc_shell_base.
+    """
+
+    @mock.patch.object(android_device, 'list_adb_devices')
+    @mock.patch.object(android_device, 'get_instances')
+    @mock.patch.object(os, 'environ', new={})
+    def test_load_device(self, mock_get_instances, mock_list_adb_devices):
+        mock_list_adb_devices.return_value = ['1234', '4312']
+        mock_device = mock.MagicMock(spec=android_device.AndroidDevice)
+        mock_get_instances.return_value = [mock_device]
+        json_shell = jsonrpc_shell_base.JsonRpcShellBase()
+        json_shell.load_device(serial='1234')
+        self.assertEqual(json_shell._ad, mock_device)
+
+    @mock.patch.object(android_device, 'list_adb_devices')
+    @mock.patch.object(android_device, 'get_instances')
+    @mock.patch.object(os, 'environ', new={})
+    def test_load_device_when_one_device(self, mock_get_instances,
+                                         mock_list_adb_devices):
+        mock_list_adb_devices.return_value = ['1234']
+        mock_device = mock.MagicMock(spec=android_device.AndroidDevice)
+        mock_get_instances.return_value = [mock_device]
+        json_shell = jsonrpc_shell_base.JsonRpcShellBase()
+        json_shell.load_device()
+        self.assertEqual(json_shell._ad, mock_device)
+
+    @mock.patch.object(android_device, 'list_adb_devices')
+    @mock.patch.object(android_device, 'get_instances')
+    @mock.patch.object(os, 'environ', new={'ANDROID_SERIAL': '1234'})
+    def test_load_device_when_android_serial(self, mock_get_instances,
+                                             mock_list_adb_devices):
+        mock_list_adb_devices.return_value = ['1234', '4321']
+        mock_device = mock.MagicMock(spec=android_device.AndroidDevice)
+        mock_get_instances.return_value = [mock_device]
+        json_shell = jsonrpc_shell_base.JsonRpcShellBase()
+        json_shell.load_device()
+        self.assertEqual(json_shell._ad, mock_device)
+
+    @mock.patch.object(android_device, 'list_adb_devices')
+    def test_load_device_when_no_devices(self, mock_list_adb_devices):
+        mock_list_adb_devices.return_value = []
+        json_shell = jsonrpc_shell_base.JsonRpcShellBase()
+        with self.assertRaisesRegex(jsonrpc_shell_base.Error,
+                                    'No adb device found!'):
+            json_shell.load_device()
+
+    @mock.patch.object(android_device, 'list_adb_devices')
+    @mock.patch.object(os, 'environ', new={})
+    def test_load_device_when_unspecified_device(self, mock_list_adb_devices):
+        mock_list_adb_devices.return_value = ['1234', '4321']
+        json_shell = jsonrpc_shell_base.JsonRpcShellBase()
+        with self.assertRaisesRegex(jsonrpc_shell_base.Error,
+                                    'Expected one phone.*'):
+            json_shell.load_device()
+
+    @mock.patch.object(android_device, 'list_adb_devices')
+    @mock.patch.object(os, 'environ', new={})
+    def test_load_device_when_device_not_found(self, mock_list_adb_devices):
+        mock_list_adb_devices.return_value = ['4321']
+        json_shell = jsonrpc_shell_base.JsonRpcShellBase()
+        with self.assertRaisesRegex(jsonrpc_shell_base.Error,
+                                    'Device "1234" is not found by adb.'):
+            json_shell.load_device(serial='1234')
+
+
+if __name__ == '__main__':
+    unittest.main()