Merge pull request #1 from pmarks-net/master

Use both IPv4+IPv6 sockets to check whether a port is free.
diff --git a/src/portpicker.py b/src/portpicker.py
index a77869e..08dde0a 100644
--- a/src/portpicker.py
+++ b/src/portpicker.py
@@ -55,6 +55,10 @@
     This is primarily a helper function for PickUnusedPort, used to see
     if a particular port number is available.
 
+    For the port to be considered available, the kernel must support at least
+    one of (IPv6, IPv4), and the port must be available on each supported
+    family.
+
     Args:
       port: The port number to bind to, or 0 to have the OS pick a free port.
       socket_type: The type of the socket (ex: socket.SOCK_STREAM).
@@ -63,15 +67,24 @@
     Returns:
       The port number on success or None on failure.
     """
-    sock = socket.socket(socket.AF_INET, socket_type, socket_proto)
-    try:
-        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        sock.bind(('', port))
-        return sock.getsockname()[1]
-    except socket.error:
-        return None
-    finally:
-        sock.close()
+    got_socket = False
+    for family in (socket.AF_INET6, socket.AF_INET):
+        try:
+            sock = socket.socket(family, socket_type, socket_proto)
+            got_socket = True
+        except socket.error:
+            continue
+        try:
+            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            sock.bind(('', port))
+            if socket_type == socket.SOCK_STREAM:
+                sock.listen(1)
+            port = sock.getsockname()[1]
+        except socket.error:
+            return None
+        finally:
+            sock.close()
+    return port if got_socket else None
 
 Bind = bind  # legacy API. pylint: disable=invalid-name
 
@@ -84,8 +97,7 @@
     Returns:
       boolean, whether it is free to use for both TCP and UDP
     """
-    return (bind(port, _PROTOS[0][0], _PROTOS[0][1]) and
-            bind(port, _PROTOS[1][0], _PROTOS[1][1]))
+    return bind(port, *_PROTOS[0]) and bind(port, *_PROTOS[1])
 
 IsPortFree = is_port_free  # legacy API. pylint: disable=invalid-name
 
diff --git a/src/portserver.py b/src/portserver.py
index 54a480f..fcade6c 100644
--- a/src/portserver.py
+++ b/src/portserver.py
@@ -38,6 +38,9 @@
 
 log = None  # Initialized to a logging.Logger by _configure_logging().
 
+_PROTOS = [(socket.SOCK_STREAM, socket.IPPROTO_TCP),
+           (socket.SOCK_DGRAM, socket.IPPROTO_UDP)]  # (type, proto) pairs
+
 
 def _get_process_command_line(pid):
     try:
@@ -55,23 +58,51 @@
         return 0
 
 
-def _port_is_available(port):
-    """Return False if the given network port is currently in use."""
-    for socket_type, proto in ((socket.SOCK_STREAM, socket.IPPROTO_TCP),
-                               (socket.SOCK_DGRAM, 0)):
-        sock = None
+# TODO: Consider importing portpicker.bind() instead of duplicating the code.
+def _bind(port, socket_type, socket_proto):
+    """Try to bind to a socket of the specified type, protocol, and port.
+
+    For the port to be considered available, the kernel must support at least
+    one of (IPv6, IPv4), and the port must be available on each supported
+    family.
+
+    Args:
+      port: The port number to bind to, or 0 to have the OS pick a free port.
+      socket_type: The type of the socket (ex: socket.SOCK_STREAM).
+      socket_proto: The protocol of the socket (ex: socket.IPPROTO_TCP).
+
+    Returns:
+      The port number on success or None on failure.
+    """
+    got_socket = False  # set once a socket is created for any family
+    for family in (socket.AF_INET6, socket.AF_INET):
         try:
-            sock = socket.socket(socket.AF_INET, socket_type, proto)
+            sock = socket.socket(family, socket_type, socket_proto)
+            got_socket = True
+        except socket.error:
+            continue  # e.g. kernel lacks this family; try the next one
+        try:
             sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             sock.bind(('', port))
             if socket_type == socket.SOCK_STREAM:
                 sock.listen(1)
+            port = sock.getsockname()[1]  # later families reuse this port
         except socket.error:
-            return False
+            return None  # bind/listen failed: treat as not free
         finally:
-            if sock:
-                sock.close()
-    return True
+            sock.close()
+    return port if got_socket else None  # None if no family was usable
+
+
+def _is_port_free(port):
+    """Check if specified port is free on every supported IP family.
+
+    Args:
+      port: integer, port to check
+    Returns:
+      boolean, whether it is free to use for both TCP and UDP
+    """
+    return bool(_bind(port, *_PROTOS[0]) and _bind(port, *_PROTOS[1]))
 
 
 def _should_allocate_port(pid):
@@ -149,7 +180,7 @@
             check_count += 1
             if (candidate.start_time == 0 or
                 candidate.start_time != _get_process_start_time(candidate.pid)):
-                if _port_is_available(candidate.pid):
+                if _is_port_free(candidate.pid):
                     candidate.pid = pid
                     candidate.start_time = _get_process_start_time(pid)
                     if not candidate.start_time:
diff --git a/src/tests/portpicker_test.py b/src/tests/portpicker_test.py
index daabb41..9e826a6 100644
--- a/src/tests/portpicker_test.py
+++ b/src/tests/portpicker_test.py
@@ -16,9 +16,11 @@
 #
 """Unittests for the portpicker module."""
 
+from __future__ import print_function
 import os
 import random
 import socket
+import sys
 import unittest
 
 try:
@@ -137,6 +139,52 @@
                 self.assertTrue(self.IsUnusedTCPPort(port))
                 self.assertTrue(self.IsUnusedUDPPort(port))
 
+    def testIsPortFree(self):
+        """This might be flaky unless this test is run with a portserver."""
+        # The port should be free initially.
+        port = portpicker.pick_unused_port()
+        self.assertTrue(portpicker.is_port_free(port))
+
+        cases = [  # (family, type, IPV6_V6ONLY value; None = default)
+            (socket.AF_INET,  socket.SOCK_STREAM, None),
+            (socket.AF_INET6, socket.SOCK_STREAM, 0),
+            (socket.AF_INET6, socket.SOCK_STREAM, 1),
+            (socket.AF_INET,  socket.SOCK_DGRAM,  None),
+            (socket.AF_INET6, socket.SOCK_DGRAM,  0),
+            (socket.AF_INET6, socket.SOCK_DGRAM,  1),
+        ]
+        for (sock_family, sock_type, v6only) in cases:
+            # Occupy the port on a subset of possible protocols.
+            try:
+                sock = socket.socket(sock_family, sock_type, 0)
+            except socket.error:
+                print('Kernel does not support sock_family=%d' % sock_family,
+                      file=sys.stderr)
+                # Skip this case, since we cannot occupy a port.
+                continue
+            if v6only is not None:
+                try:
+                    sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY,
+                                    v6only)
+                except socket.error:
+                    print('Kernel does not support IPV6_V6ONLY=%d' % v6only,
+                          file=sys.stderr)
+                    # Don't care; just proceed with the default.
+            sock.bind(('', port))  # occupy the port
+
+            # The port should be busy.
+            self.assertFalse(portpicker.is_port_free(port))
+            sock.close()  # release the port
+
+            # Now it's free again.
+            self.assertTrue(portpicker.is_port_free(port))
+
+    def testIsPortFreeException(self):
+        port = portpicker.pick_unused_port()  # pick before socket() is mocked
+        with mock.patch.object(socket, 'socket') as mock_sock:
+            mock_sock.side_effect = socket.error('fake socket error', 0)
+            self.assertFalse(portpicker.is_port_free(port))  # no family usable
+
     def testThatLegacyCapWordsAPIsExist(self):
         """The original APIs were CapWords style, 1.1 added PEP8 names."""
         self.assertEqual(portpicker.bind, portpicker.Bind)
diff --git a/src/tests/portserver_test.py b/src/tests/portserver_test.py
index 2e49595..f0475c3 100644
--- a/src/tests/portserver_test.py
+++ b/src/tests/portserver_test.py
@@ -16,6 +16,7 @@
 #
 """Tests for the example portserver."""
 
+from __future__ import print_function
 import asyncio
 import os
 import socket
@@ -43,15 +44,49 @@
     def test_get_process_start_time(self):
         self.assertGreater(portserver._get_process_start_time(os.getpid()), 0)
 
-    def test_port_is_available_true(self):
+    def test_is_port_free(self):
         """This might be flaky unless this test is run with a portserver."""
-        # Insert Inception "we must go deeper" meme here.
-        self.assertTrue(portserver._port_is_available(self.port))
+        # The port should be free initially.
+        self.assertTrue(portserver._is_port_free(self.port))

-    def test_port_is_available_false(self):
+        cases = [  # (family, type, IPV6_V6ONLY value; None = default)
+            (socket.AF_INET,  socket.SOCK_STREAM, None),
+            (socket.AF_INET6, socket.SOCK_STREAM, 0),
+            (socket.AF_INET6, socket.SOCK_STREAM, 1),
+            (socket.AF_INET,  socket.SOCK_DGRAM,  None),
+            (socket.AF_INET6, socket.SOCK_DGRAM,  0),
+            (socket.AF_INET6, socket.SOCK_DGRAM,  1),
+        ]
+        for (sock_family, sock_type, v6only) in cases:
+            # Occupy the port on a subset of possible protocols.
+            try:
+                sock = socket.socket(sock_family, sock_type, 0)
+            except socket.error:
+                print('Kernel does not support sock_family=%d' % sock_family,
+                      file=sys.stderr)  # NOTE(review): confirm sys is imported
+                # Skip this case, since we cannot occupy a port.
+                continue
+            if v6only is not None:
+                try:
+                    sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY,
+                                    v6only)
+                except socket.error:
+                    print('Kernel does not support IPV6_V6ONLY=%d' % v6only,
+                          file=sys.stderr)
+                    # Don't care; just proceed with the default.
+            sock.bind(('', self.port))  # occupy the port
+
+            # The port should be busy.
+            self.assertFalse(portserver._is_port_free(self.port))
+            sock.close()  # release the port
+
+            # Now it's free again.
+            self.assertTrue(portserver._is_port_free(self.port))
+
+    def test_is_port_free_exception(self):  # every socket() call raises
         with mock.patch.object(socket, 'socket') as mock_sock:
             mock_sock.side_effect = socket.error('fake socket error', 0)
-            self.assertFalse(portserver._port_is_available(self.port))
+            self.assertFalse(portserver._is_port_free(self.port))
 
     def test_should_allocate_port(self):
         self.assertFalse(portserver._should_allocate_port(0))
@@ -140,18 +175,18 @@
         self.assertRaises(ValueError, self.pool.add_port_to_free_pool, 0)
         self.assertRaises(ValueError, self.pool.add_port_to_free_pool, 65536)
 
-    @mock.patch.object(portserver, '_port_is_available')
-    def test_get_port_for_process_ok(self, mock_port_is_available):
+    @mock.patch.object(portserver, '_is_port_free')
+    def test_get_port_for_process_ok(self, mock_is_port_free):
         self.pool.add_port_to_free_pool(self.port)
-        mock_port_is_available.return_value = True
+        mock_is_port_free.return_value = True  # every probed port looks free
         self.assertEqual(self.port, self.pool.get_port_for_process(os.getpid()))
         self.assertEqual(1, self.pool.ports_checked_for_last_request)
 
-    @mock.patch.object(portserver, '_port_is_available')
-    def test_get_port_for_process_none_left(self, mock_port_is_available):
+    @mock.patch.object(portserver, '_is_port_free')
+    def test_get_port_for_process_none_left(self, mock_is_port_free):
         self.pool.add_port_to_free_pool(self.port)
         self.pool.add_port_to_free_pool(22)
-        mock_port_is_available.return_value = False
+        mock_is_port_free.return_value = False
         self.assertEqual(2, self.pool.num_ports())
         self.assertEqual(0, self.pool.get_port_for_process(os.getpid()))
         self.assertEqual(2, self.pool.num_ports())