Adds an optional timeout that a port will remain bound to after the CLI exits (#31)
* Add CLI unittests.
* Allow tests to launch their own portserver.
* Add -h --help support to the CLI.
diff --git a/ChangeLog.md b/ChangeLog.md
index d2e71b7..b059d22 100644
--- a/ChangeLog.md
+++ b/ChangeLog.md
@@ -1,5 +1,14 @@
-## 1.6.x
+## 1.6.0b1
+* Add -h and --help text to the command line tool.
+* The command line interface now defaults to associating the returned port
+ with its parent process PID (usually the calling script) when no argument
+ was given as that makes more sense.
+* When portpicker is used as a command line tool from a script, if a port is
+ chosen without a portserver it can now be kept bound to a socket by a
+ child process for a user specified timeout. When successful, this helps
+ minimize race conditions as subsequent portpicker CLI invocations within
+ the timeout window cannot choose the same port.
* Some pylint based refactorings to portpicker and portpicker\_test.
* Drop 3.6 from our CI test matrix and metadata. It probably still works
there, but expect our unittests to include 3.7-ism's in the future. We'll
diff --git a/src/portpicker.py b/src/portpicker.py
index 1d678cb..10b373d 100644
--- a/src/portpicker.py
+++ b/src/portpicker.py
@@ -45,6 +45,7 @@
import random
import socket
import sys
+import time
_winapi = None # pylint: disable=invalid-name
if sys.platform == 'win32':
@@ -112,8 +113,33 @@
Returns:
The port number on success or None on failure.
"""
+ return _bind(port, socket_type, socket_proto)
+
+
+def _bind(port, socket_type, socket_proto, return_socket=None,
+ return_family=socket.AF_INET6):
+ """Internal implementation of bind.
+
+ Args:
+ port, socket_type, socket_proto: see bind().
+ return_socket: If supplied, a list that we will append an open bound
+ reuseaddr socket on the port in question to.
+ return_family: The socket family to return in return_socket.
+
+ Returns:
+ The port number on success or None on failure.
+ """
+ # Our return family must come last when returning a bound socket
+ # as we cannot keep it bound while testing a bind on the other
+ # family with many network stack configurations.
+ if return_socket is None or return_family == socket.AF_INET:
+ socket_families = (socket.AF_INET6, socket.AF_INET)
+ elif return_family == socket.AF_INET6:
+ socket_families = (socket.AF_INET, socket.AF_INET6)
+ else:
+ raise ValueError('unknown return_family %s' % return_family)
got_socket = False
- for family in (socket.AF_INET6, socket.AF_INET):
+ for family in socket_families:
try:
sock = socket.socket(family, socket_type, socket_proto)
got_socket = True
@@ -128,27 +154,43 @@
except socket.error:
return None
finally:
- sock.close()
+ if return_socket is None or family != return_family:
+ sock.close()
+ if return_socket is not None and family == return_family:
+ return_socket.append(sock)
+ break # Final iteration due to pre-loop logic; don't close.
return port if got_socket else None
-Bind = bind # legacy API. pylint: disable=invalid-name
-
def is_port_free(port):
"""Check if specified port is free.
Args:
port: integer, port to check
- Returns:
- boolean, whether it is free to use for both TCP and UDP
- """
- return bind(port, *_PROTOS[0]) and bind(port, *_PROTOS[1])
-IsPortFree = is_port_free # legacy API. pylint: disable=invalid-name
+ Returns:
+ bool, whether port is free to use for both TCP and UDP.
+ """
+ return _is_port_free(port)
+
+
+def _is_port_free(port, return_sockets=None):
+ """Internal implementation of is_port_free.
+
+ Args:
+ port: integer, port to check
+ return_sockets: If supplied, a list that we will append open bound
+ sockets on the port in question to rather than closing them.
+
+ Returns:
+ bool, whether port is free to use for both TCP and UDP.
+ """
+ return (_bind(port, *_PROTOS[0], return_socket=return_sockets) and
+ _bind(port, *_PROTOS[1], return_socket=return_sockets))
def pick_unused_port(pid=None, portserver_address=None):
- """A pure python implementation of PickUnusedPort.
+ """Picks an unused port and reserves it for use by a given process id.
Args:
pid: PID to tell the portserver to associate the reservation with. If
@@ -161,12 +203,30 @@
address, the environment will be checked for a PORTSERVER_ADDRESS
variable. If that is not set, no port server will be used.
+ If no portserver is used, no pid based reservation is managed by any
+ central authority. Race conditions and duplicate assignments may occur.
+
Returns:
A port number that is unused on both TCP and UDP.
Raises:
NoFreePortFoundError: No free port could be found.
"""
+ return _pick_unused_port(pid, portserver_address)
+
+
+def _pick_unused_port(pid=None, portserver_address=None,
+ noserver_bind_timeout=0):
+ """Internal implementation of pick_unused_port.
+
+ Args:
+ pid, portserver_address: See pick_unused_port().
+ noserver_bind_timeout: If no portserver was used, this is the number of
+ seconds we will attempt to keep a child process around with the ports
+ returned open and bound SO_REUSEADDR style to help avoid race condition
+ port reuse. A non-zero value attempts os.fork(). Do not use it in a
+ multithreaded process.
+ """
try: # Instead of `if _free_ports:` to handle the race condition.
port = _free_ports.pop()
except KeyError:
@@ -184,12 +244,46 @@
pid=pid)
if port:
return port
- return _pick_unused_port_without_server()
-
-PickUnusedPort = pick_unused_port # legacy API. pylint: disable=invalid-name
+ return _pick_unused_port_without_server(bind_timeout=noserver_bind_timeout)
-def _pick_unused_port_without_server(): # Protected. pylint: disable=invalid-name
+def _spawn_bound_port_holding_daemon(port, bound_sockets, timeout):
+ """If possible, fork()s a daemon process to hold bound_sockets open.
+
+ Emits a warning to stderr if it cannot.
+
+ Args:
+ port: The port number the sockets are bound to (informational).
+ bound_sockets: The list of bound sockets our child process will hold
+ open. If the list is empty, no action is taken.
+ timeout: A positive number of seconds the child should sleep for before
+ closing the sockets and exiting.
+ """
+ if bound_sockets and timeout > 0:
+ try:
+ fork_pid = os.fork() # This concept only works on POSIX.
+ except Exception as err: # pylint: disable=broad-except
+ print('WARNING: Cannot timeout unbinding close of port', port,
+ ' closing on exit. -', err, file=sys.stderr)
+ else:
+ if fork_pid == 0:
+ # This child process inherits and holds bound_sockets open
+ # for bind_timeout seconds.
+ try:
+ # Close the stdio fds as may be connected to
+ # a pipe that will cause a grandparent process
+ # to wait on before returning. (cl/427587550)
+ os.close(sys.stdin.fileno())
+ os.close(sys.stdout.fileno())
+ os.close(sys.stderr.fileno())
+ time.sleep(timeout)
+ for held_socket in bound_sockets:
+ held_socket.close()
+ finally:
+ sys.exit(0)
+
+
+def _pick_unused_port_without_server(bind_timeout=0):
"""Pick an available network port without the help of a port server.
This code ensures that the port is available on both TCP and UDP.
@@ -197,6 +291,11 @@
This function is an implementation detail of PickUnusedPort(), and
should not be called by code outside of this module.
+ Args:
+ bind_timeout: number of seconds to attempt to keep a child process
+ process around bound SO_REUSEADDR style to the port. If we cannot
+ do that we emit a warning to stderr.
+
Returns:
A port number that is unused on both TCP and UDP.
@@ -206,28 +305,42 @@
# Next, try a few times to get an OS-assigned port.
# Ambrose discovered that on the 2.6 kernel, calling Bind() on UDP socket
# returns the same port over and over. So always try TCP first.
+ port = None
+ bound_sockets = [] if bind_timeout > 0 else None
for _ in range(10):
# Ask the OS for an unused port.
- port = bind(0, _PROTOS[0][0], _PROTOS[0][1])
+ port = _bind(0, socket.SOCK_STREAM, socket.IPPROTO_TCP, bound_sockets)
# Check if this port is unused on the other protocol.
if (port and port not in _random_ports and
- bind(port, _PROTOS[1][0], _PROTOS[1][1])):
+ _bind(port, socket.SOCK_DGRAM, socket.IPPROTO_UDP, bound_sockets)):
_random_ports.add(port)
+ _spawn_bound_port_holding_daemon(port, bound_sockets, bind_timeout)
return port
+ if bound_sockets:
+ for held_socket in bound_sockets:
+ held_socket.close()
+ del bound_sockets[:]
# Try random ports as a last resort.
rng = random.Random()
for _ in range(10):
port = int(rng.randrange(15000, 25000))
- if port not in _random_ports and is_port_free(port):
- _random_ports.add(port)
- return port
+ if port not in _random_ports:
+ if _is_port_free(port, bound_sockets):
+ _random_ports.add(port)
+ _spawn_bound_port_holding_daemon(
+ port, bound_sockets, bind_timeout)
+ return port
+ if bound_sockets:
+ for held_socket in bound_sockets:
+ held_socket.close()
+ del bound_sockets[:]
# Give up.
raise NoFreePortFoundError()
-def _get_linux_port_from_port_server(portserver_address, pid):
+def _posix_get_port_from_port_server(portserver_address, pid):
# An AF_UNIX address may start with a zero byte, in which case it is in the
# "abstract namespace", and doesn't have any filesystem representation.
# See 'man 7 unix' for details.
@@ -238,7 +351,7 @@
try:
# Create socket.
if hasattr(socket, 'AF_UNIX'):
- sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) # pylint: disable=no-member
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
else:
# fallback to AF_INET if this is not unix
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -260,7 +373,7 @@
return None
-def _get_windows_port_from_port_server(portserver_address, pid):
+def _windows_get_port_from_port_server(portserver_address, pid):
if portserver_address[0] == '@':
portserver_address = '\\\\.\\pipe\\' + portserver_address[1:]
@@ -282,6 +395,7 @@
file=sys.stderr)
return None
+
def get_port_from_port_server(portserver_address, pid=None):
"""Request a free a port from a system-wide portserver.
@@ -310,9 +424,9 @@
pid = os.getpid()
if _winapi:
- buf = _get_windows_port_from_port_server(portserver_address, pid)
+ buf = _windows_get_port_from_port_server(portserver_address, pid)
else:
- buf = _get_linux_port_from_port_server(portserver_address, pid)
+ buf = _posix_get_port_from_port_server(portserver_address, pid)
if buf is None:
return None
@@ -326,12 +440,48 @@
return port
-GetPortFromPortServer = get_port_from_port_server # legacy API. pylint: disable=invalid-name
+# Legacy APIs.
+# pylint: disable=invalid-name
+Bind = bind
+GetPortFromPortServer = get_port_from_port_server
+IsPortFree = is_port_free
+PickUnusedPort = pick_unused_port
+# pylint: enable=invalid-name
def main(argv):
- """If passed an arg, treat it as a PID, otherwise portpicker uses getpid."""
- port = pick_unused_port(pid=int(argv[1]) if len(argv) > 1 else None)
+ """If passed an arg, treat it as a PID, otherwise we use getppid().
+
+ A second optional argument can be a bind timeout in seconds that will be
+ used ONLY if no portserver is found. We attempt to leave a process around
+ holding the port open and bound with SO_REUSEADDR set for timeout seconds.
+ If the timeout bind was not possible, a warning is emitted to stderr.
+
+ #!/bin/bash
+ port="$(python -m portpicker $$ 1.23)"
+ test_my_server "$port"
+
+ This will pick a port for your script's PID and assign it to $port, if no
+ portserver was used, it attempts to keep a socket bound to $port for 1.23
+ seconds after the portpicker process has exited. This is a convenient hack
+ to attempt to prevent port reallocation during scripts outside of
+ portserver managed environments.
+
+ Older versions of the portpicker CLI ignore everything beyond the first arg.
+ Older versions also used getpid() instead of getppid(), so script users are
+ strongly encouraged to be explicit and pass $$ or your languages equivalent
+ to associate the port with the PID of the controlling process.
+ """
+ # Our command line is trivial so I avoid an argparse import. If we ever
+ # grow more than 1-2 args, switch to a using argparse.
+ if '-h' in argv or '--help' in argv:
+ print(argv[0], 'usage:\n')
+ import inspect
+ print(inspect.getdoc(main))
+ sys.exit(1)
+ pid=int(argv[1]) if len(argv) > 1 else os.getppid()
+ bind_timeout=float(argv[2]) if len(argv) > 2 else 0
+ port = _pick_unused_port(pid=pid, noserver_bind_timeout=bind_timeout)
if not port:
sys.exit(1)
print(port)
diff --git a/src/tests/portpicker_test.py b/src/tests/portpicker_test.py
index 033e50d..9967648 100644
--- a/src/tests/portpicker_test.py
+++ b/src/tests/portpicker_test.py
@@ -22,15 +22,19 @@
import errno
import os
import socket
+import subprocess
import sys
+import time
import unittest
from unittest import mock
import portpicker
_winapi = portpicker._winapi
+# pylint: disable=invalid-name,protected-access,missing-class-docstring,missing-function-docstring
-class PickUnusedPortTest(unittest.TestCase):
+
+class CommonTestMixin:
def IsUnusedTCPPort(self, port):
return self._bind(port, socket.SOCK_STREAM, socket.IPPROTO_TCP)
@@ -45,15 +49,62 @@
portpicker._free_ports.clear()
portpicker._random_ports.clear()
- def testPickUnusedPortActuallyWorks(self):
- """This test can be flaky."""
- for _ in range(10):
- port = portpicker.pick_unused_port()
- self.assertTrue(self.IsUnusedTCPPort(port))
- self.assertTrue(self.IsUnusedUDPPort(port))
- @unittest.skipIf('PORTSERVER_ADDRESS' not in os.environ,
- 'no port server to test against')
+@unittest.skipIf(
+ ('PORTSERVER_ADDRESS' not in os.environ) and
+ not hasattr(socket, 'AF_UNIX'),
+ 'no existing port server; test launching code requires AF_UNIX.')
+class PickUnusedPortTestWithAPortServer(CommonTestMixin, unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.portserver_process = None
+ if 'PORTSERVER_ADDRESS' not in os.environ:
+ # Launch a portserver child process for our tests to use if we are
+ # able to. Obviously not host-exclusive, but good for integration
+ # testing purposes on CI without a portserver of its own.
+ cls.portserver_address = '@pid%d-test-ports' % os.getpid()
+ try:
+ cls.portserver_process = subprocess.Popen(
+ ['portserver.py', # Installed in PATH within the venv.
+ '--portserver_address=%s' % cls.portserver_address])
+ except EnvironmentError as err:
+ raise unittest.SkipTest(
+ 'Unable to launch portserver.py: %s' % err)
+ linux_addr = '\0' + cls.portserver_address[1:] # The @ means 0.
+ # loop for a few seconds waiting for that socket to work.
+ err = '???'
+ for _ in range(123):
+ time.sleep(0.05)
+ try:
+ ps_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ ps_sock.connect(linux_addr)
+ except socket.error as err: # pylint: disable=unused-variable
+ continue
+ ps_sock.close()
+ break
+ else:
+ # The socket failed or never accepted connections, assume our
+ # portserver setup attempt failed and bail out.
+ if cls.portserver_process.poll() is not None:
+ cls.portserver_process.kill()
+ cls.portserver_process.wait()
+ cls.portserver_process = None
+ raise unittest.SkipTest(
+ 'Unable to connect to our own portserver.py: %s' % err)
+ # Point child processes at our shiny portserver process.
+ os.environ['PORTSERVER_ADDRESS'] = cls.portserver_address
+
+ @classmethod
+ def tearDownClass(cls):
+ if cls.portserver_process:
+ if os.environ.get('PORTSERVER_ADDRESS') == cls.portserver_address:
+ del os.environ['PORTSERVER_ADDRESS']
+ if cls.portserver_process.poll() is None:
+ cls.portserver_process.kill()
+ cls.portserver_process.wait()
+ cls.portserver_process = None
+
def testPickUnusedCanSuccessfullyUsePortServer(self):
with mock.patch.object(portpicker, '_pick_unused_port_without_server'):
@@ -67,8 +118,6 @@
self.assertTrue(self.IsUnusedTCPPort(port))
self.assertTrue(self.IsUnusedUDPPort(port))
- @unittest.skipIf('PORTSERVER_ADDRESS' not in os.environ,
- 'no port server to test against')
def testPickUnusedCanSuccessfullyUsePortServerAddressKwarg(self):
with mock.patch.object(portpicker, '_pick_unused_port_without_server'):
@@ -87,8 +136,6 @@
finally:
os.environ['PORTSERVER_ADDRESS'] = addr
- @unittest.skipIf('PORTSERVER_ADDRESS' not in os.environ,
- 'no port server to test against')
def testGetPortFromPortServer(self):
"""Exercise the get_port_from_port_server() helper function."""
for _ in range(10):
@@ -97,6 +144,16 @@
self.assertTrue(self.IsUnusedTCPPort(port))
self.assertTrue(self.IsUnusedUDPPort(port))
+
+class PickUnusedPortTest(CommonTestMixin, unittest.TestCase):
+
+ def testPickUnusedPortActuallyWorks(self):
+ """This test can be flaky."""
+ for _ in range(10):
+ port = portpicker.pick_unused_port()
+ self.assertTrue(self.IsUnusedTCPPort(port))
+ self.assertTrue(self.IsUnusedUDPPort(port))
+
def testSendsPidToPortServer(self):
with ExitStack() as stack:
if _winapi:
@@ -377,5 +434,92 @@
portpicker.GetPortFromPortServer)
+def get_open_listen_tcp_ports():
+ netstat = subprocess.run(['netstat', '-lnt'], capture_output=True,
+ encoding='utf-8')
+ if netstat.returncode != 0:
+ raise unittest.SkipTest('Unable to run netstat -lnt to list binds.')
+ rows = (line.split() for line in netstat.stdout.splitlines())
+ listen_addrs = (row[3] for row in rows if row[0].startswith('tcp'))
+ listen_ports = [int(addr.split(':')[-1]) for addr in listen_addrs]
+ return listen_ports
+
+
+@unittest.skipUnless((sys.executable and os.access(sys.executable, os.X_OK))
+ or (os.environ.get('TEST_PORTPICKER_CLI') and
+ os.access(os.environ['TEST_PORTPICKER_CLI'], os.X_OK)),
+ 'sys.executable portpicker.__file__ not launchable and '
+ ' no TEST_PORTPICKER_CLI supplied.')
+class PortpickerCommandLineTests(unittest.TestCase):
+ def setUp(self):
+ self.main_py = portpicker.__file__
+
+ def _run_portpicker(self, pp_args, env_override=None):
+ env = dict(os.environ)
+ if env_override:
+ env.update(env_override)
+ if os.environ.get('TEST_PORTPICKER_CLI'):
+ pp_command = [os.environ['TEST_PORTPICKER_CLI']]
+ else:
+ pp_command = [sys.executable, '-m', 'portpicker']
+ return subprocess.run(pp_command + pp_args,
+ capture_output=True,
+ env=env,
+ encoding='utf-8',
+ check=False)
+
+ def test_command_line_help(self):
+ cmd = self._run_portpicker(['-h'])
+ self.assertNotEqual(0, cmd.returncode)
+ self.assertIn('usage', cmd.stdout)
+ self.assertIn('passed an arg', cmd.stdout)
+ cmd = self._run_portpicker(['--help'])
+ self.assertNotEqual(0, cmd.returncode)
+ self.assertIn('usage', cmd.stdout)
+ self.assertIn('passed an arg', cmd.stdout)
+
+ def test_command_line_help_text_dedented(self):
+ cmd = self._run_portpicker(['-h'])
+ self.assertNotEqual(0, cmd.returncode)
+ self.assertIn('\nIf passed an arg', cmd.stdout)
+ self.assertIn('\n #!/bin/bash', cmd.stdout)
+ self.assertIn('\nOlder versions ', cmd.stdout)
+
+ def test_command_line_interface(self):
+ cmd = self._run_portpicker([str(os.getpid())])
+ cmd.check_returncode()
+ port = int(cmd.stdout)
+ self.assertNotEqual(0, port, msg=cmd)
+ listen_ports = sorted(get_open_listen_tcp_ports())
+ self.assertNotIn(port, listen_ports, msg='expected nothing to be bound to port.')
+
+ def test_command_line_interface_no_portserver(self):
+ cmd = self._run_portpicker([str(os.getpid())],
+ env_override={'PORTSERVER_ADDRESS': ''})
+ cmd.check_returncode()
+ port = int(cmd.stdout)
+ self.assertNotEqual(0, port, msg=cmd)
+ listen_ports = sorted(get_open_listen_tcp_ports())
+ self.assertNotIn(port, listen_ports, msg='expected nothing to be bound to port.')
+
+ def test_command_line_interface_no_portserver_bind_timeout(self):
+ # This test is timing sensitive and leaves that bind process hanging
+ # around consuming resources until it dies on its own unless the test
+ # runner kills the process group upon exit.
+ timeout = 9.5
+ before = time.monotonic()
+ cmd = self._run_portpicker([str(os.getpid()), str(timeout)],
+ env_override={'PORTSERVER_ADDRESS': ''})
+ self.assertEqual(0, cmd.returncode, msg=(cmd.stdout, cmd.stderr))
+ port = int(cmd.stdout)
+ self.assertNotEqual(0, port, msg=cmd)
+ if 'WARNING' in cmd.stderr:
+ raise unittest.SkipTest('bind timeout not supported on this platform.')
+ listen_ports = sorted(get_open_listen_tcp_ports())
+ self.assertIn(port, listen_ports, msg='expected port to be bound. '
+ '%f seconds elapsed of %f bind timeout.' %
+ (time.monotonic() - before, timeout))
+
+
if __name__ == '__main__':
unittest.main()