| #!/usr/bin/python3 |
| # |
| # Copyright 2015 Google Inc. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| """Tests for the example portserver.""" |
| |
| from __future__ import print_function |
| import asyncio |
| import os |
| import socket |
| import sys |
| import unittest |
| from unittest import mock |
| |
| import portpicker |
| import portserver |
| |
| |
def setUpModule():
    """Module-level fixture: enable verbose portserver logging for all tests."""
    verbose_logging = True
    portserver._configure_logging(verbose=verbose_logging)
| |
| |
class PortserverFunctionsTest(unittest.TestCase):
    """Tests for the module-level helper functions in portserver."""

    def setUp(self):
        # A plain instance fixture: each test gets its own freshly picked,
        # currently-unused port stored on the test instance.
        self.port = portpicker.PickUnusedPort()

    def test_get_process_command_line(self):
        # Smoke test: only checks that the call does not raise for our own pid.
        portserver._get_process_command_line(os.getpid())

    def test_get_process_start_time(self):
        self.assertGreater(portserver._get_process_start_time(os.getpid()), 0)

    def test_is_port_free(self):
        """This might be flaky unless this test is run with a portserver."""
        # The port should be free initially.
        self.assertTrue(portserver._is_port_free(self.port))

        cases = [
            (socket.AF_INET, socket.SOCK_STREAM, None),
            (socket.AF_INET6, socket.SOCK_STREAM, 0),
            (socket.AF_INET6, socket.SOCK_STREAM, 1),
            (socket.AF_INET, socket.SOCK_DGRAM, None),
            (socket.AF_INET6, socket.SOCK_DGRAM, 0),
            (socket.AF_INET6, socket.SOCK_DGRAM, 1),
        ]
        for (sock_family, sock_type, v6only) in cases:
            # subTest reports which (family, type, v6only) combination failed
            # instead of aborting on the first failing case.
            with self.subTest(sock_family=sock_family, sock_type=sock_type,
                              v6only=v6only):
                # Occupy the port on a subset of possible protocols.
                try:
                    sock = socket.socket(sock_family, sock_type, 0)
                except socket.error:
                    print('Kernel does not support sock_family=%d' % sock_family,
                          file=sys.stderr)
                    # Skip this case, since we cannot occupy a port.
                    continue
                # The with-block closes the socket even if bind() or the
                # assertion fails, so a failing case cannot leave the port
                # occupied for the cases that follow.
                with sock:
                    if v6only is not None:
                        try:
                            sock.setsockopt(socket.IPPROTO_IPV6,
                                            socket.IPV6_V6ONLY, v6only)
                        except socket.error:
                            print('Kernel does not support IPV6_V6ONLY=%d' % v6only,
                                  file=sys.stderr)
                            # Don't care; just proceed with the default.
                    sock.bind(('', self.port))

                    # The port should be busy.
                    self.assertFalse(portserver._is_port_free(self.port))

                # Now it's free again.
                self.assertTrue(portserver._is_port_free(self.port))

    def test_is_port_free_exception(self):
        # Any socket.error while probing must be reported as "not free".
        with mock.patch.object(socket, 'socket') as mock_sock:
            mock_sock.side_effect = socket.error('fake socket error', 0)
            self.assertFalse(portserver._is_port_free(self.port))

    def test_should_allocate_port(self):
        # pids 0 and 1 are expected to be refused.
        self.assertFalse(portserver._should_allocate_port(0))
        self.assertFalse(portserver._should_allocate_port(1))
        # This (live) test process must be allowed an allocation; the
        # function has to be *called* with the pid, not merely referenced.
        self.assertTrue(portserver._should_allocate_port(os.getpid()))
        child_pid = os.fork()
        if child_pid == 0:
            os._exit(0)
        else:
            os.waitpid(child_pid, 0)
        # This test assumes that after waitpid returns the kernel has finished
        # cleaning the process. We also assume that the kernel will not reuse
        # the former child's pid before our next call checks for its existence.
        # Likely assumptions, but not guaranteed.
        self.assertFalse(portserver._should_allocate_port(child_pid))

    def test_parse_command_line(self):
        with mock.patch.object(
                sys, 'argv', ['program_name', '--verbose',
                              '--portserver_static_pool=1-1,3-8',
                              '--portserver_unix_socket_address=@hello-test']):
            portserver._parse_command_line()

    def test_parse_port_ranges(self):
        self.assertFalse(portserver._parse_port_ranges(''))
        self.assertCountEqual(portserver._parse_port_ranges('1-1'), {1})
        self.assertCountEqual(portserver._parse_port_ranges('1-1,3-8,375-378'),
                              {1, 3, 4, 5, 6, 7, 8, 375, 376, 377, 378})
        # Unparsable parts are logged but ignored.
        self.assertEqual({1, 2},
                         portserver._parse_port_ranges('1-2,not,numbers'))
        self.assertEqual(set(), portserver._parse_port_ranges('8080-8081x'))
        # Port ranges that go out of bounds are logged but ignored.
        self.assertEqual(set(), portserver._parse_port_ranges('0-1138'))
        self.assertEqual(set(range(19, 84 + 1)),
                         portserver._parse_port_ranges('1138-65536,19-84'))

    def test_configure_logging(self):
        """Just code coverage really."""
        portserver._configure_logging(False)
        portserver._configure_logging(True)

    # Patch decorators are applied bottom-up, so the mocks arrive as
    # (start_unix_server, get_event_loop, _parse_port_ranges). The argv patch
    # supplies an explicit replacement value and therefore adds no argument.
    @mock.patch.object(
        sys, 'argv', ['PortserverFunctionsTest.test_main',
                      '--portserver_unix_socket_address=@TST-%d' % os.getpid()]
    )
    @mock.patch.object(portserver, '_parse_port_ranges')
    @mock.patch('asyncio.get_event_loop')
    @mock.patch('asyncio.start_unix_server')
    def test_main(self, mock_start_unix_server, mock_get_event_loop,
                  mock_parse_port_ranges):
        # With no ports to serve, main() is expected to exit.
        mock_parse_port_ranges.return_value = set()
        with self.assertRaises(SystemExit):
            portserver.main()

        # Give it at least one port and try again.
        mock_parse_port_ranges.return_value = {self.port}

        mock_event_loop = mock.Mock(spec=asyncio.base_events.BaseEventLoop)
        mock_get_event_loop.return_value = mock_event_loop
        mock_start_unix_server.return_value = mock.Mock()
        # Simulate Ctrl-C so main() leaves its serve loop and cleans up.
        mock_event_loop.run_forever.side_effect = KeyboardInterrupt

        portserver.main()

        mock_event_loop.run_until_complete.assert_any_call(
            mock_start_unix_server.return_value)
        mock_event_loop.close.assert_called_once_with()
        # NOTE: This could be improved. Tests of main() are often gross.
| |
| |
class PortPoolTest(unittest.TestCase):
    """Exercises portserver._PortPool bookkeeping and allocation choices."""

    @classmethod
    def setUpClass(cls):
        # One real unused port, shared by every test in this class.
        cls.port = portpicker.PickUnusedPort()

    def setUp(self):
        # Every test starts from a brand-new, empty pool.
        self.pool = portserver._PortPool()

    def test_initialization(self):
        pool = self.pool
        self.assertEqual(0, pool.num_ports())
        # Each accepted port grows the pool by exactly one.
        for expected_size, port in enumerate((self.port, 1138), start=1):
            pool.add_port_to_free_pool(port)
            self.assertEqual(expected_size, pool.num_ports())
        # Values outside the valid port range are rejected.
        for bad_port in (0, 65536):
            with self.assertRaises(ValueError):
                pool.add_port_to_free_pool(bad_port)

    @mock.patch.object(portserver, '_is_port_free')
    def test_get_port_for_process_ok(self, mock_is_port_free):
        pool = self.pool
        pool.add_port_to_free_pool(self.port)
        mock_is_port_free.return_value = True
        granted = pool.get_port_for_process(os.getpid())
        self.assertEqual(self.port, granted)
        self.assertEqual(1, pool.ports_checked_for_last_request)

    @mock.patch.object(portserver, '_is_port_free')
    def test_get_port_for_process_none_left(self, mock_is_port_free):
        pool = self.pool
        for port in (self.port, 22):
            pool.add_port_to_free_pool(port)
        # Every candidate appears busy, so nothing can be handed out.
        mock_is_port_free.return_value = False
        self.assertEqual(2, pool.num_ports())
        granted = pool.get_port_for_process(os.getpid())
        self.assertEqual(0, granted)
        self.assertEqual(2, pool.num_ports())
        self.assertEqual(2, pool.ports_checked_for_last_request)

    @mock.patch.object(portserver, '_is_port_free')
    @mock.patch.object(os, 'getpid')
    def test_get_port_for_process_pid_eq_port(
            self, mock_getpid, mock_is_port_free):
        pool = self.pool
        for port in (12345, 12344):
            pool.add_port_to_free_pool(port)
        # Only the port whose number equals the (mocked) pid reports free.
        mock_is_port_free.side_effect = lambda port: port == os.getpid()
        mock_getpid.return_value = 12345
        self.assertEqual(2, pool.num_ports())
        granted = pool.get_port_for_process(os.getpid())
        self.assertEqual(12345, granted)
        self.assertEqual(2, pool.ports_checked_for_last_request)

    @mock.patch.object(portserver, '_is_port_free')
    @mock.patch.object(os, 'getpid')
    def test_get_port_for_process_pid_ne_port(
            self, mock_getpid, mock_is_port_free):
        pool = self.pool
        for port in (12344, 12345):
            pool.add_port_to_free_pool(port)
        # Every port except the one equal to the (mocked) pid reports free.
        mock_is_port_free.side_effect = lambda port: port != os.getpid()
        mock_getpid.return_value = 12345
        self.assertEqual(2, pool.num_ports())
        granted = pool.get_port_for_process(os.getpid())
        self.assertEqual(12344, granted)
        self.assertEqual(2, pool.ports_checked_for_last_request)
| |
| |
@mock.patch.object(portserver, '_get_process_command_line')
@mock.patch.object(portserver, '_should_allocate_port')
@mock.patch.object(portserver._PortPool, 'get_port_for_process')
class PortServerRequestHandlerTest(unittest.TestCase):
    """Drives _PortServerRequestHandler with its collaborators patched out.

    The class-level patches wrap every test_* method. The tests reach the
    mocks through the patched module/class attributes, so the injected
    mock arguments are collected into *_mocks and not used directly.
    """

    def setUp(self):
        portserver._configure_logging(verbose=True)
        self.rh = portserver._PortServerRequestHandler([23, 42, 54])

    def test_stats_reporting(self, *_mocks):
        with mock.patch.object(portserver, 'log') as fake_log:
            self.rh.dump_stats()
        fake_log.info.assert_called_with('total-allocations 0')

    def test_handle_port_request_bad_data(self, *_mocks):
        for malformed in (b'', b'\n', b'99Z\n', b'99 8\n'):
            self._test_bad_data_from_client(malformed)
        self.assertEqual([], portserver._get_process_command_line.mock_calls)

    def _test_bad_data_from_client(self, data):
        """Feed one malformed request; it must never reach the pid check."""
        writer = mock.Mock(asyncio.StreamWriter)
        self.rh._handle_port_request(data, writer)
        self.assertFalse(portserver._should_allocate_port.mock_calls)

    def test_handle_port_request_denied_allocation(self, *_mocks):
        handler = self.rh
        portserver._should_allocate_port.return_value = False
        self.assertEqual(0, handler._denied_allocations)
        writer = mock.Mock(asyncio.StreamWriter)
        handler._handle_port_request(b'5\n', writer)
        self.assertEqual(1, handler._denied_allocations)

    def test_handle_port_request_bad_port_returned(self, *_mocks):
        handler = self.rh
        pick_port = handler._port_pool.get_port_for_process
        portserver._should_allocate_port.return_value = True
        pick_port.return_value = 0
        writer = mock.Mock(asyncio.StreamWriter)
        handler._handle_port_request(b'6\n', writer)
        pick_port.assert_called_once_with(6)
        self.assertEqual(1, handler._denied_allocations)

    def test_handle_port_request_success(self, *_mocks):
        handler = self.rh
        should_allocate = portserver._should_allocate_port
        pick_port = handler._port_pool.get_port_for_process
        should_allocate.return_value = True
        pick_port.return_value = 999
        writer = mock.Mock(asyncio.StreamWriter)
        self.assertEqual(0, handler._total_allocations)
        handler._handle_port_request(b'8', writer)
        should_allocate.assert_called_once_with(8)
        pick_port.assert_called_once_with(8)
        self.assertEqual(1, handler._total_allocations)
        self.assertEqual(0, handler._denied_allocations)
        writer.write.assert_called_once_with(b'999\n')
| |
| |
if __name__ == '__main__':
    # Run the TestCases defined in this module with the standard unittest runner.
    unittest.main()