Add Windows support for the port server (#25)

Add Windows support for the port server and Windows named pipe support to the portpicker client.

Contributed by Patrice Vignola
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index d50e439..88a8115 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -3,16 +3,22 @@
 
 name: Python Portpicker & Portserver
 
-on: [push]
+on:
+  push:
+    branches:
+    - 'main'
+  pull_request:
+    branches:
+    - 'main'
 
 jobs:
-  build:
+  build-ubuntu:
 
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
       matrix:
-        python-version: [3.6, 3.7, 3.8, 3.9, '3.10.0-beta.1']
+        python-version: [3.6, 3.7, 3.8, 3.9, '3.10.0-beta.3']
 
     steps:
       - uses: actions/checkout@v2
@@ -29,3 +35,27 @@
         run: |
           # Run tox using the version of Python in `PATH`
           tox -e py
+
+  build-windows:
+
+    runs-on: windows-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.6, 3.7, 3.8, 3.9, '3.10.0-beta.3']
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install pytest tox
+          if (Test-Path "requirements.txt") { pip install -r requirements.txt }
+      - name: Test with tox
+        run: |
+          # Run tox using the version of Python in `PATH`
+          tox -e py
diff --git a/ChangeLog.md b/ChangeLog.md
index b385a38..28ae395 100644
--- a/ChangeLog.md
+++ b/ChangeLog.md
@@ -1,3 +1,10 @@
+## 1.5.0
+
+*   Add portserver support to Windows using named pipes. To create or connect
+    to a server, prefix the name of the server with `@` (e.g. 
+    `@unittest-portserver`).
+
+
 ## 1.4.0
 
 *   Use `async def` instead of `@asyncio.coroutine` in order to support 3.10.
diff --git a/setup.cfg b/setup.cfg
index 742a6d8..269b517 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,7 +1,7 @@
 # https://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
 [metadata]
 name = portpicker
-version = 1.4.1b1
+version = 1.5.0b1
 maintainer = Google LLC
 maintainer_email = greg@krypto.org
 license = Apache 2.0
@@ -29,10 +29,11 @@
     Programming Language :: Python :: 3.10
     Programming Language :: Python :: Implementation :: CPython
     Programming Language :: Python :: Implementation :: PyPy
-platforms = POSIX
+platforms = POSIX, Windows
 requires =
 
 [options]
+install_requires = psutil
 python_requires = >= 3.6
 package_dir=
     =src
diff --git a/src/portpicker.py b/src/portpicker.py
index e54dcbc..4717bbc 100644
--- a/src/portpicker.py
+++ b/src/portpicker.py
@@ -43,6 +43,11 @@
 import socket
 import sys
 
+if sys.platform == 'win32':
+    import _winapi
+else:
+    _winapi = None
+
 # The legacy Bind, IsPortFree, etc. names are not exported.
 __all__ = ('bind', 'is_port_free', 'pick_unused_port', 'return_port',
            'add_reserved_port', 'get_port_from_port_server')
@@ -63,7 +68,6 @@
 
 class NoFreePortFoundError(Exception):
     """Exception indicating that no free port could be found."""
-    pass
 
 
 def add_reserved_port(port):
@@ -217,6 +221,61 @@
     raise NoFreePortFoundError()
 
 
+def _get_linux_port_from_port_server(portserver_address, pid):
+    # An AF_UNIX address may start with a zero byte, in which case it is in the
+    # "abstract namespace", and doesn't have any filesystem representation.
+    # See 'man 7 unix' for details.
+    # The convention is to write '@' in the address to represent this zero byte.
+    if portserver_address[0] == '@':
+        portserver_address = '\0' + portserver_address[1:]
+
+    try:
+        # Create socket.
+        if hasattr(socket, 'AF_UNIX'):
+            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) # pylint: disable=no-member
+        else:
+            # fallback to AF_INET if this is not unix
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        try:
+            # Connect to portserver.
+            sock.connect(portserver_address)
+
+            # Write request.
+            sock.sendall(('%d\n' % pid).encode('ascii'))
+
+            # Read response.
+            # 1K should be ample buffer space.
+            return sock.recv(1024)
+        finally:
+            sock.close()
+    except socket.error as error:
+        print('Socket error when connecting to portserver:', error,
+              file=sys.stderr)
+        return None
+
+
+def _get_windows_port_from_port_server(portserver_address, pid):
+    if portserver_address[0] == '@':
+        portserver_address = '\\\\.\\pipe\\' + portserver_address[1:]
+
+    try:
+        handle = _winapi.CreateFile(
+            portserver_address,
+            _winapi.GENERIC_READ | _winapi.GENERIC_WRITE,
+            0,
+            0,
+            _winapi.OPEN_EXISTING,
+            0,
+            0)
+
+        _winapi.WriteFile(handle, ('%d\n' % pid).encode('ascii'))
+        data, _ = _winapi.ReadFile(handle, 6, 0)
+        return data
+    except FileNotFoundError as error:
+        print('File error when connecting to portserver:', error,
+              file=sys.stderr)
+        return None
+
 def get_port_from_port_server(portserver_address, pid=None):
     """Request a free a port from a system-wide portserver.
 
@@ -240,38 +299,16 @@
     """
     if not portserver_address:
         return None
-    # An AF_UNIX address may start with a zero byte, in which case it is in the
-    # "abstract namespace", and doesn't have any filesystem representation.
-    # See 'man 7 unix' for details.
-    # The convention is to write '@' in the address to represent this zero byte.
-    if portserver_address[0] == '@':
-        portserver_address = '\0' + portserver_address[1:]
 
     if pid is None:
         pid = os.getpid()
 
-    try:
-        # Create socket.
-        if hasattr(socket, 'AF_UNIX'):
-            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-        else:
-            # fallback to AF_INET if this is not unix
-            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        try:
-            # Connect to portserver.
-            sock.connect(portserver_address)
+    if _winapi:
+        buf = _get_windows_port_from_port_server(portserver_address, pid)
+    else:
+        buf = _get_linux_port_from_port_server(portserver_address, pid)
 
-            # Write request.
-            sock.sendall(('%d\n' % pid).encode('ascii'))
-
-            # Read response.
-            # 1K should be ample buffer space.
-            buf = sock.recv(1024)
-        finally:
-            sock.close()
-    except socket.error as e:
-        print('Socket error when connecting to portserver:', e,
-              file=sys.stderr)
+    if buf is None:
         return None
 
     try:
diff --git a/src/portserver.py b/src/portserver.py
index 58b7ecd..f986f3f 100644
--- a/src/portserver.py
+++ b/src/portserver.py
@@ -31,10 +31,12 @@
 import asyncio
 import collections
 import logging
-import os
 import signal
 import socket
 import sys
+import psutil
+import subprocess
+from datetime import datetime, timezone, timedelta
 
 log = None  # Initialized to a logging.Logger by _configure_logging().
 
@@ -44,18 +46,16 @@
 
 def _get_process_command_line(pid):
     try:
-        with open('/proc/{}/cmdline'.format(pid), 'rt') as cmdline_f:
-            return cmdline_f.read()
-    except IOError:
+        return psutil.Process(pid).cmdline()
+    except psutil.NoSuchProcess:
         return ''
 
 
 def _get_process_start_time(pid):
     try:
-        with open('/proc/{}/stat'.format(pid), 'rt') as pid_stat_f:
-            return int(pid_stat_f.readline().split()[21])
-    except IOError:
-        return 0
+        return psutil.Process(pid).create_time()
+    except psutil.NoSuchProcess:
+        return 0.0
 
 
 # TODO: Consider importing portpicker.bind() instead of duplicating the code.
@@ -115,14 +115,27 @@
         # had been reparented to init.
         log.info('Not allocating a port to init.')
         return False
-    try:
-        os.kill(pid, 0)
-    except (ProcessLookupError, OverflowError):
+
+    if not psutil.pid_exists(pid):
         log.info('Not allocating a port to a non-existent process')
         return False
     return True
 
 
+async def _start_windows_server(client_connected_cb, path):
+    """Start the server on Windows using named pipes."""
+    def protocol_factory():
+        stream_reader = asyncio.StreamReader()
+        stream_reader_protocol = asyncio.StreamReaderProtocol(
+            stream_reader, client_connected_cb)
+        return stream_reader_protocol
+
+    loop = asyncio.get_event_loop()
+    server, *_ = await loop.start_serving_pipe(protocol_factory, address=path)
+
+    return server
+
+
 class _PortInfo(object):
     """Container class for information about a given port assignment.
 
@@ -137,7 +150,7 @@
     def __init__(self, port):
         self.port = port
         self.pid = 0
-        self.start_time = 0
+        self.start_time = 0.0
 
 
 class _PortPool(object):
@@ -178,7 +191,7 @@
             candidate = self._port_queue.pop()
             self._port_queue.appendleft(candidate)
             check_count += 1
-            if (candidate.start_time == 0 or
+            if (candidate.start_time == 0.0 or
                 candidate.start_time != _get_process_start_time(candidate.pid)):
                 if _is_port_free(candidate.port):
                     candidate.pid = pid
@@ -287,10 +300,13 @@
         default='15000-24999',
         help='Comma separated N-P Range(s) of ports to manage (inclusive).')
     parser.add_argument(
-        '--portserver_unix_socket_address',
+        '--portserver_address',
+        '--portserver_unix_socket_address', # Alias to be backward compatible
         type=str,
         default='@unittest-portserver',
-        help='Address of AF_UNIX socket on which to listen (first @ is a NUL).')
+        help='Address of AF_UNIX socket on which to listen on Unix (first @ is '
+             'a NUL) or the name of the pipe on Windows (first @ is the '
+             r'\\.\pipe\ prefix).')
     parser.add_argument('--verbose',
                         action='store_true',
                         default=False,
@@ -348,14 +364,33 @@
 
     request_handler = _PortServerRequestHandler(ports_to_serve)
 
+    if sys.platform == 'win32':
+        asyncio.set_event_loop(asyncio.ProactorEventLoop())
+
     event_loop = asyncio.get_event_loop()
-    event_loop.add_signal_handler(signal.SIGUSR1, request_handler.dump_stats)
-    old_py_loop = {'loop': event_loop} if sys.version_info < (3, 10) else {}
-    coro = asyncio.start_unix_server(
-        request_handler.handle_port_request,
-        path=config.portserver_unix_socket_address.replace('@', '\0', 1),
-        **old_py_loop)
-    server_address = config.portserver_unix_socket_address
+
+    if sys.platform == 'win32':
+        # On Windows, we need to periodically pause the loop to allow the user
+        # to send a break signal (e.g. ctrl+c)
+        def listen_for_signal():
+            event_loop.call_later(0.5, listen_for_signal)
+
+        event_loop.call_later(0.5, listen_for_signal)
+
+        coro = _start_windows_server(
+            request_handler.handle_port_request,
+            path=config.portserver_address.replace('@', '\\\\.\\pipe\\', 1))
+    else:
+        event_loop.add_signal_handler(
+            signal.SIGUSR1, request_handler.dump_stats) # pylint: disable=no-member
+
+        old_py_loop = {'loop': event_loop} if sys.version_info < (3, 10) else {}
+        coro = asyncio.start_unix_server(
+            request_handler.handle_port_request,
+            path=config.portserver_address.replace('@', '\0', 1),
+            **old_py_loop)
+
+    server_address = config.portserver_address
 
     server = event_loop.run_until_complete(coro)
     log.info('Serving on %s', server_address)
@@ -365,8 +400,12 @@
         log.info('Stopping due to ^C.')
 
     server.close()
-    event_loop.run_until_complete(server.wait_closed())
-    event_loop.remove_signal_handler(signal.SIGUSR1)
+
+    if sys.platform != 'win32':
+        # PipeServer doesn't have a wait_closed() function
+        event_loop.run_until_complete(server.wait_closed())
+        event_loop.remove_signal_handler(signal.SIGUSR1) # pylint: disable=no-member
+
     event_loop.close()
     request_handler.dump_stats()
     log.info('Goodbye.')
diff --git a/src/tests/portpicker_test.py b/src/tests/portpicker_test.py
index b82d7bf..e479a46 100644
--- a/src/tests/portpicker_test.py
+++ b/src/tests/portpicker_test.py
@@ -23,6 +23,12 @@
 import socket
 import sys
 import unittest
+from contextlib import ExitStack
+
+if sys.platform == 'win32':
+    import _winapi
+else:
+    _winapi = None
 
 try:
     # pylint: disable=no-name-in-module
@@ -100,27 +106,82 @@
             self.assertTrue(self.IsUnusedUDPPort(port))
 
     def testSendsPidToPortServer(self):
-        server = mock.Mock()
-        server.recv.return_value = b'42768\n'
-        with mock.patch.object(socket, 'socket', return_value=server):
-            port = portpicker.get_port_from_port_server('portserver', pid=1234)
-            server.sendall.assert_called_once_with(b'1234\n')
+        with ExitStack() as stack:
+            if _winapi:
+                create_file_mock = mock.Mock()
+                create_file_mock.return_value = 0
+                read_file_mock = mock.Mock()
+                write_file_mock = mock.Mock()
+                read_file_mock.return_value = (b'42768\n', 0)
+                stack.enter_context(
+                    mock.patch('_winapi.CreateFile', new=create_file_mock))
+                stack.enter_context(
+                    mock.patch('_winapi.WriteFile', new=write_file_mock))
+                stack.enter_context(
+                    mock.patch('_winapi.ReadFile', new=read_file_mock))
+                port = portpicker.get_port_from_port_server(
+                    'portserver', pid=1234)
+                write_file_mock.assert_called_once_with(0, b'1234\n')
+            else:
+                server = mock.Mock()
+                server.recv.return_value = b'42768\n'
+                stack.enter_context(
+                    mock.patch.object(socket, 'socket', return_value=server))
+                port = portpicker.get_port_from_port_server(
+                    'portserver', pid=1234)
+                server.sendall.assert_called_once_with(b'1234\n')
+
         self.assertEqual(port, 42768)
 
     def testPidDefaultsToOwnPid(self):
-        server = mock.Mock()
-        server.recv.return_value = b'52768\n'
-        with mock.patch.object(socket, 'socket', return_value=server):
-            with mock.patch.object(os, 'getpid', return_value=9876):
+        with ExitStack() as stack:
+            stack.enter_context(
+                mock.patch.object(os, 'getpid', return_value=9876))
+
+            if _winapi:
+                create_file_mock = mock.Mock()
+                create_file_mock.return_value = 0
+                read_file_mock = mock.Mock()
+                write_file_mock = mock.Mock()
+                read_file_mock.return_value = (b'52768\n', 0)
+                stack.enter_context(
+                    mock.patch('_winapi.CreateFile', new=create_file_mock))
+                stack.enter_context(
+                    mock.patch('_winapi.WriteFile', new=write_file_mock))
+                stack.enter_context(
+                    mock.patch('_winapi.ReadFile', new=read_file_mock))
+                port = portpicker.get_port_from_port_server('portserver')
+                write_file_mock.assert_called_once_with(0, b'9876\n')
+            else:
+                server = mock.Mock()
+                server.recv.return_value = b'52768\n'
+                stack.enter_context(
+                    mock.patch.object(socket, 'socket', return_value=server))
                 port = portpicker.get_port_from_port_server('portserver')
                 server.sendall.assert_called_once_with(b'9876\n')
+
         self.assertEqual(port, 52768)
 
     @mock.patch.dict(os.environ,{'PORTSERVER_ADDRESS': 'portserver'})
     def testReusesPortServerPorts(self):
-        server = mock.Mock()
-        server.recv.side_effect = [b'12345\n', b'23456\n', b'34567\n']
-        with mock.patch.object(socket, 'socket', return_value=server):
+        with ExitStack() as stack:
+            if _winapi:
+                read_file_mock = mock.Mock()
+                read_file_mock.side_effect = [
+                    (b'12345\n', 0),
+                    (b'23456\n', 0),
+                    (b'34567\n', 0),
+                ]
+                stack.enter_context(mock.patch('_winapi.CreateFile'))
+                stack.enter_context(mock.patch('_winapi.WriteFile'))
+                stack.enter_context(
+                    mock.patch('_winapi.ReadFile', new=read_file_mock))
+            else:
+                server = mock.Mock()
+                server.recv.side_effect = [b'12345\n', b'23456\n', b'34567\n']
+                stack.enter_context(
+                    mock.patch.object(socket, 'socket', return_value=server))
+
             self.assertEqual(portpicker.pick_unused_port(), 12345)
             self.assertEqual(portpicker.pick_unused_port(), 23456)
             portpicker.return_port(12345)
@@ -248,12 +309,18 @@
 
         cases = [
             (socket.AF_INET,  socket.SOCK_STREAM, None),
-            (socket.AF_INET6, socket.SOCK_STREAM, 0),
             (socket.AF_INET6, socket.SOCK_STREAM, 1),
             (socket.AF_INET,  socket.SOCK_DGRAM,  None),
-            (socket.AF_INET6, socket.SOCK_DGRAM,  0),
             (socket.AF_INET6, socket.SOCK_DGRAM,  1),
         ]
+
+        # Using v6only=0 on Windows doesn't result in collisions
+        if not _winapi:
+            cases.extend([
+                (socket.AF_INET6, socket.SOCK_STREAM, 0),
+                (socket.AF_INET6, socket.SOCK_DGRAM,  0),
+            ])
+
         for (sock_family, sock_type, v6only) in cases:
             # Occupy the port on a subset of possible protocols.
             try:
diff --git a/src/tests/portserver_test.py b/src/tests/portserver_test.py
index 394b1b5..b7de094 100644
--- a/src/tests/portserver_test.py
+++ b/src/tests/portserver_test.py
@@ -25,14 +25,23 @@
 import time
 import unittest
 from unittest import mock
+from multiprocessing import Process
 
 import portpicker
+
+# On Windows, portserver.py is located in the "Scripts" folder, which isn't
+# added to the import path by default
+if sys.platform == 'win32':
+    sys.path.append(os.path.join(os.path.split(sys.executable)[0]))
+
 import portserver
 
 
 def setUpModule():
     portserver._configure_logging(verbose=True)
 
+def exit_immediately():
+    os._exit(0)
 
 class PortserverFunctionsTest(unittest.TestCase):
 
@@ -53,12 +62,18 @@
 
         cases = [
             (socket.AF_INET,  socket.SOCK_STREAM, None),
-            (socket.AF_INET6, socket.SOCK_STREAM, 0),
             (socket.AF_INET6, socket.SOCK_STREAM, 1),
             (socket.AF_INET,  socket.SOCK_DGRAM,  None),
-            (socket.AF_INET6, socket.SOCK_DGRAM,  0),
             (socket.AF_INET6, socket.SOCK_DGRAM,  1),
         ]
+
+        # Using v6only=0 on Windows doesn't result in collisions
+        if sys.platform != 'win32':
+            cases.extend([
+                (socket.AF_INET6, socket.SOCK_STREAM, 0),
+                (socket.AF_INET6, socket.SOCK_DGRAM,  0),
+            ])
+
         for (sock_family, sock_type, v6only) in cases:
             # Occupy the port on a subset of possible protocols.
             try:
@@ -68,6 +83,10 @@
                       file=sys.stderr)
                 # Skip this case, since we cannot occupy a port.
                 continue
+
+            if not hasattr(socket, 'IPPROTO_IPV6'):
+                v6only = None
+
             if v6only is not None:
                 try:
                     sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY,
@@ -94,11 +113,12 @@
         self.assertFalse(portserver._should_allocate_port(0))
         self.assertFalse(portserver._should_allocate_port(1))
         self.assertTrue(portserver._should_allocate_port, os.getpid())
-        child_pid = os.fork()
-        if child_pid == 0:
-            os._exit(0)
-        else:
-            os.waitpid(child_pid, 0)
+
+        p = Process(target=exit_immediately)
+        p.start()
+        child_pid = p.pid
+        p.join()
+
         # This test assumes that after waitpid returns the kernel has finished
         # cleaning the process.  We also assume that the kernel will not reuse
         # the former child's pid before our next call checks for its existence.