blob: 0a399b224783e9bcfccfb27ffb03d5975f27b793 [file] [log] [blame]
"""Cross platform abstractions for inter-process communication
On Unix, this uses AF_UNIX sockets.
On Windows, this uses NamedPipes.
"""
import base64
import os
import shutil
import sys
import tempfile
from typing import Optional, Callable
# ``MYPY`` is always False at runtime, so the imports it guards are only seen
# by the type checker; annotations that use them are written as strings.
MYPY = False
if MYPY:
    from typing import Type

from types import TracebackType

if sys.platform != 'win32':
    import socket
    _IPCHandle = socket.socket
else:
    # _winapi is private, but it is required for named-pipe IPC on Windows and
    # its interface is basically stable.
    import ctypes
    import _winapi
    _IPCHandle = int

    kernel32 = ctypes.windll.kernel32
    DisconnectNamedPipe = kernel32.DisconnectNamedPipe  # type: Callable[[_IPCHandle], int]
    FlushFileBuffers = kernel32.FlushFileBuffers  # type: Callable[[_IPCHandle], int]
class IPCException(Exception):
    """Raised when an inter-process communication operation fails."""
class IPCBase:
    """Base class for communication between the dmypy client and server.

    This contains logic shared between the client and server, such as reading
    and writing. Subclasses are responsible for setting ``connection``: a
    socket on Unix, a named-pipe handle on Windows.
    """

    connection = None  # type: _IPCHandle

    def __init__(self, name: str, timeout: Optional[float]) -> None:
        # timeout is in seconds; on Windows a falsy value means "wait forever".
        self.name = name
        self.timeout = timeout

    def read(self, size: int = 100000) -> bytes:
        """Read bytes from the IPC connection until it is exhausted.

        On Unix this reads until the peer shuts down its write side (EOF).
        On Windows this keeps issuing reads of ``size`` bytes while the pipe
        reports ERROR_MORE_DATA, and stops when a read completes cleanly.

        Raises IPCException on Windows if the wait fails, the read is
        aborted, or an unexpected status is reported.
        """
        bdata = bytearray()
        if sys.platform == 'win32':
            while True:
                ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
                # TODO: remove once typeshed supports Literal types
                assert isinstance(ov, _winapi.Overlapped)
                assert isinstance(err, int)
                try:
                    if err == _winapi.ERROR_IO_PENDING:
                        timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                        res = _winapi.WaitForSingleObject(ov.event, timeout)
                        if res != _winapi.WAIT_OBJECT_0:
                            raise IPCException("Bad result from I/O wait: {}".format(res))
                except BaseException:
                    # Cancel the outstanding I/O before propagating.
                    ov.cancel()
                    raise
                _, err = ov.GetOverlappedResult(True)
                more = ov.getbuffer()
                if more:
                    bdata.extend(more)
                if err == 0:
                    # we are done!
                    break
                elif err == _winapi.ERROR_MORE_DATA:
                    # The message is larger than `size`; read the rest.
                    continue
                elif err == _winapi.ERROR_OPERATION_ABORTED:
                    raise IPCException("ReadFile operation aborted.")
                else:
                    # Defensive: never loop forever on an unexpected status.
                    raise IPCException("ReadFile failed with error: {}".format(err))
        else:
            while True:
                more = self.connection.recv(size)
                if not more:
                    # Empty recv() means the peer closed its write side.
                    break
                bdata.extend(more)
        return bytes(bdata)

    def write(self, data: bytes) -> None:
        """Write ``data`` to the IPC connection.

        On Unix the socket's write side is shut down afterwards so the peer's
        read() sees EOF; no further writes are possible on that connection.
        On Windows the data is sent with a single overlapped WriteFile call.

        Raises IPCException on Windows write failures, timeouts, or short
        writes.
        """
        if sys.platform == 'win32':
            try:
                ov, err = _winapi.WriteFile(self.connection, data, overlapped=True)
                # TODO: remove once typeshed supports Literal types
                assert isinstance(ov, _winapi.Overlapped)
                assert isinstance(err, int)
                try:
                    if err == _winapi.ERROR_IO_PENDING:
                        timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                        res = _winapi.WaitForSingleObject(ov.event, timeout)
                        if res != _winapi.WAIT_OBJECT_0:
                            raise IPCException("Bad result from I/O wait: {}".format(res))
                    elif err != 0:
                        raise IPCException("Failed writing to pipe with error: {}".format(err))
                except BaseException:
                    ov.cancel()
                    raise
                bytes_written, err = ov.GetOverlappedResult(True)
            except WindowsError as e:
                raise IPCException("Failed to write with error: {}".format(e.winerror))
            # Real exceptions (not asserts) so these checks survive `python -O`.
            if err != 0:
                raise IPCException("Failed writing to pipe with error: {}".format(err))
            if bytes_written != len(data):
                raise IPCException("Short write to pipe: wrote {} of {} bytes".format(
                    bytes_written, len(data)))
        else:
            self.connection.sendall(data)
            self.connection.shutdown(socket.SHUT_WR)

    def close(self) -> None:
        """Close the underlying socket or pipe handle, if one was opened."""
        if self.connection is None:
            # Nothing was ever connected (class-level default).
            return
        if sys.platform == 'win32':
            if self.connection != _winapi.NULL:
                _winapi.CloseHandle(self.connection)
        else:
            self.connection.close()
class IPCClient(IPCBase):
    """The client side of an IPC connection."""

    def __init__(self, name: str, timeout: Optional[float]) -> None:
        """Connect to the server at ``name`` (pipe name or socket path).

        On Windows, raises IPCException if the pipe does not exist, the wait
        times out, or the pipe is busy. On Unix, socket errors from
        settimeout()/connect() propagate unchanged. In every failure case the
        partially created socket/handle is closed before re-raising.
        """
        super().__init__(name, timeout)
        if sys.platform == 'win32':
            timeout = int(self.timeout * 1000) if self.timeout else _winapi.NMPWAIT_WAIT_FOREVER
            try:
                _winapi.WaitNamedPipe(self.name, timeout)
            except FileNotFoundError:
                raise IPCException("The NamedPipe at {} was not found.".format(self.name))
            except WindowsError as e:
                if e.winerror == _winapi.ERROR_SEM_TIMEOUT:
                    raise IPCException("Timed out waiting for connection.")
                else:
                    raise
            try:
                self.connection = _winapi.CreateFile(
                    self.name,
                    _winapi.GENERIC_READ | _winapi.GENERIC_WRITE,
                    0,
                    _winapi.NULL,
                    _winapi.OPEN_EXISTING,
                    _winapi.FILE_FLAG_OVERLAPPED,
                    _winapi.NULL,
                )
            except WindowsError as e:
                if e.winerror == _winapi.ERROR_PIPE_BUSY:
                    raise IPCException("The connection is busy.")
                else:
                    raise
            try:
                # Match the server's message-mode pipe.
                _winapi.SetNamedPipeHandleState(self.connection,
                                                _winapi.PIPE_READMODE_MESSAGE,
                                                None,
                                                None)
            except BaseException:
                # Don't leak the handle when construction fails.
                _winapi.CloseHandle(self.connection)
                raise
        else:
            sock = socket.socket(socket.AF_UNIX)
            try:
                sock.settimeout(timeout)
                sock.connect(name)
            except BaseException:
                # Don't leak the socket when construction fails.
                sock.close()
                raise
            self.connection = sock

    def __enter__(self) -> 'IPCClient':
        return self

    def __exit__(self,
                 exc_ty: 'Optional[Type[BaseException]]' = None,
                 exc_val: Optional[BaseException] = None,
                 exc_tb: Optional[TracebackType] = None,
                 ) -> bool:
        """Close the connection; never suppresses exceptions."""
        self.close()
        return False
class IPCServer(IPCBase):
    """The server side of an IPC connection.

    On Windows this owns a single-instance named pipe; on Unix it owns a
    listening AF_UNIX socket inside a fresh temporary directory.
    """

    BUFFER_SIZE = 2**16

    def __init__(self, name: str, timeout: Optional[float] = None) -> None:
        if sys.platform == 'win32':
            # Random suffix makes the pipe name unique; the pipe is created
            # with FILE_FLAG_FIRST_PIPE_INSTANCE below.
            name = r'\\.\pipe\{}-{}.pipe'.format(
                name, base64.urlsafe_b64encode(os.urandom(6)).decode())
        else:
            name = '{}.sock'.format(name)
        super().__init__(name, timeout)
        if sys.platform == 'win32':
            self.connection = _winapi.CreateNamedPipe(self.name,
                _winapi.PIPE_ACCESS_DUPLEX
                | _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
                | _winapi.FILE_FLAG_OVERLAPPED,
                _winapi.PIPE_READMODE_MESSAGE
                | _winapi.PIPE_TYPE_MESSAGE
                | _winapi.PIPE_WAIT
                | 0x8,  # PIPE_REJECT_REMOTE_CLIENTS
                1,  # one instance
                self.BUFFER_SIZE,
                self.BUFFER_SIZE,
                _winapi.NMPWAIT_WAIT_FOREVER,
                0,  # Use default security descriptor
            )
            # NOTE(review): _winapi.CreateNamedPipe appears to raise OSError
            # itself on failure, so this check is likely only defensive.
            if self.connection == -1:  # INVALID_HANDLE_VALUE
                err = _winapi.GetLastError()
                raise IPCException('Invalid handle to pipe: {}'.format(err))
        else:
            self.sock_directory = tempfile.mkdtemp()
            sockfile = os.path.join(self.sock_directory, self.name)
            self.sock = socket.socket(socket.AF_UNIX)
            try:
                self.sock.bind(sockfile)
                self.sock.listen(1)
                if timeout is not None:
                    self.sock.settimeout(timeout)
            except BaseException:
                # Don't leak the socket or the temp directory on failure
                # (e.g. a socket path that is too long to bind).
                self.sock.close()
                shutil.rmtree(self.sock_directory, ignore_errors=True)
                raise

    def __enter__(self) -> 'IPCServer':
        """Wait for a client to connect, then return self.

        Raises IPCException if waiting for the client fails or times out.
        """
        if sys.platform == 'win32':
            # NOTE: With no timeout configured, this can hang forever if the
            # client never connects, though this can be "solved" by killing the server
            try:
                ov = _winapi.ConnectNamedPipe(self.connection, overlapped=True)
                # TODO: remove once typeshed supports Literal types
                assert isinstance(ov, _winapi.Overlapped)
            except WindowsError as e:
                # Don't raise if the client already exists, or the client already connected
                if e.winerror not in (_winapi.ERROR_PIPE_CONNECTED, _winapi.ERROR_NO_DATA):
                    raise
            else:
                try:
                    timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                    res = _winapi.WaitForSingleObject(ov.event, timeout)
                    if res != _winapi.WAIT_OBJECT_0:
                        raise IPCException("Bad result from I/O wait: {}".format(res))
                except BaseException:
                    ov.cancel()
                    _winapi.CloseHandle(self.connection)
                    # Mark the handle as gone so close()/cleanup() won't
                    # close it a second time.
                    self.connection = _winapi.NULL
                    raise
                _, err = ov.GetOverlappedResult(True)
                if err != 0:
                    raise IPCException("ConnectNamedPipe failed with error: {}".format(err))
        else:
            try:
                self.connection, _ = self.sock.accept()
            except socket.timeout:
                raise IPCException('The socket timed out')
        return self

    def __exit__(self,
                 exc_ty: 'Optional[Type[BaseException]]' = None,
                 exc_val: Optional[BaseException] = None,
                 exc_tb: Optional[TracebackType] = None,
                 ) -> bool:
        """Drop the client accepted by __enter__; never suppresses exceptions.

        The pipe handle (Windows) or listening socket (Unix) stays open so
        another client can be served.
        """
        if sys.platform == 'win32':
            try:
                # Wait for the client to finish reading the last write before disconnecting
                if not FlushFileBuffers(self.connection):
                    raise IPCException("Failed to flush NamedPipe buffer, "
                                       "maybe the client hung up?")
            finally:
                DisconnectNamedPipe(self.connection)
        else:
            self.close()
        return False

    def cleanup(self) -> None:
        """Release the server: the pipe handle on Windows; the listening
        socket and its temporary directory on Unix."""
        if sys.platform == 'win32':
            self.close()
        else:
            try:
                self.sock.close()
            finally:
                shutil.rmtree(self.sock_directory)

    @property
    def connection_name(self) -> str:
        """The address a client passes to IPCClient to reach this server."""
        if sys.platform == 'win32':
            return self.name
        else:
            return self.sock.getsockname()