blob: e6090f0f69027f26a0496e260840c89deb8b1bd6 [file] [log] [blame]
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Local process pool-based middleware implementation.
This is a simple implementation of a worker pool, running on the local machine.
Each worker object is hosted by a separate process. Each worker object may
handle a number of concurrent requests. The client is given a stub object that
exposes the same methods as the worker, just that they return Futures.
There is a bidirectional pipe between a stub and its corresponding
process/worker. One direction is used to place tasks (method calls), the other
to place results. Tasks and results are correlated by a monotonically
incrementing counter maintained by the stub.
The worker process dequeues tasks promptly and either re-enqueues them to a
local thread pool or, if the task is 'urgent', executes it immediately on its
main thread.
"""
import concurrent.futures
import dataclasses
import functools
import multiprocessing
import psutil
import threading
from absl import logging
# pylint: disable=unused-import
from compiler_opt.distributed.worker import Worker, FixedWorkerPool
from contextlib import AbstractContextManager
from multiprocessing import connection
from typing import Any, Callable, Dict, List, Optional
@dataclasses.dataclass(frozen=True, eq=True)
class Task:
  """One remote method invocation, sent from a stub to its worker process.

  Instances travel over a multiprocessing pipe, so every field must be
  picklable.
  """
  # Correlation id; the worker echoes it back in the matching TaskResult.
  msgid: int
  # Name of the worker method to invoke.
  func_name: str
  # Positional and keyword arguments for that method.
  args: tuple
  kwargs: dict
  # Urgent tasks run on the worker's main loop rather than its thread pool.
  is_urgent: bool
@dataclasses.dataclass(frozen=True, eq=True)
class TaskResult:
  """The outcome of a Task, sent back from a worker process to its stub."""
  # Id of the Task this result answers.
  msgid: int
  # True when `value` is the method's return value; False when it is the
  # exception that the call raised.
  success: bool
  value: Any
def _run_impl(pipe: connection.Connection, worker_class: 'type[Worker]', *args,
              **kwargs):
  """Worker process entrypoint.

  Instantiates `worker_class(*args, **kwargs)`, then serves Tasks received on
  `pipe`, answering each with a TaskResult that carries the Task's msgid.
  Urgent tasks run immediately on this (main) thread; all others are queued
  on a single-threaded pool. Runs until an exception escapes the loop (e.g.
  from `pipe.recv()`) or the process is killed by the stub.

  Args:
    pipe: the child end of the stub's duplex pipe.
    worker_class: the Worker type to host.
    *args: positional arguments for the worker constructor.
    **kwargs: keyword arguments for the worker constructor.
  """
  # A setting of 1 does not inhibit the while loop below from running since
  # that runs on the main thread of the process. Urgent tasks will still
  # process near-immediately. `threads` only controls how many threads are
  # spawned at a time which execute given tasks. In the typical clang-spawning
  # jobs, this effectively limits the number of clang instances spawned.
  pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
  obj = worker_class(*args, **kwargs)

  # Pipes are not thread safe: results are sent both from the main loop
  # (urgent tasks) and from the pool thread (done callbacks).
  pipe_lock = threading.Lock()

  def send(task_result: TaskResult):
    with pipe_lock:
      pipe.send(task_result)

  def reply(msgid: int, success: bool, value: Any):
    """Sends a result, degrading to an error if `value` cannot be sent."""
    try:
      send(TaskResult(msgid=msgid, success=success, value=value))
    except Exception as e:  # pylint: disable=broad-except
      # Connection.send pickles the whole object before writing any bytes, so
      # a pickling failure leaves the pipe usable. Report the failure instead
      # of dropping the reply: inside a done callback the exception would be
      # swallowed and the caller's future would stay pending forever.
      send(
          TaskResult(
              msgid=msgid,
              success=False,
              value=RuntimeError(
                  f'Failed to send the result of message {msgid}: {e!r}')))

  def make_ondone(msgid: int):

    def on_done(f: concurrent.futures.Future):
      exc = f.exception()
      if exc is not None:
        reply(msgid, False, exc)
      else:
        reply(msgid, True, f.result())

    return on_done

  # Run forever. The stub will just kill the runner when done.
  while True:
    task: Task = pipe.recv()
    try:
      the_func = getattr(obj, task.func_name)
    except Exception as e:  # pylint: disable=broad-except
      # Unknown method: fail only this call rather than crashing the whole
      # worker, which would cancel every other call pending on it.
      reply(task.msgid, False, e)
      continue
    application = functools.partial(the_func, *task.args, **task.kwargs)
    if task.is_urgent:
      try:
        res = application()
      except BaseException as e:  # pylint: disable=broad-except
        reply(task.msgid, False, e)
      else:
        reply(task.msgid, True, res)
    else:
      pool.submit(application).add_done_callback(make_ondone(task.msgid))
def _run(*args, **kwargs):
  """Runs `_run_impl`, logging any exception that terminates the worker.

  All arguments are forwarded unchanged to `_run_impl`.
  """
  try:
    _run_impl(*args, **kwargs)
  except BaseException:
    # logging.exception records the full traceback (logging.error(e) only
    # logged the message); a bare `raise` preserves the original traceback.
    logging.exception('Worker process exited with an exception.')
    raise
def _make_stub(cls: 'type[Worker]', *args, **kwargs):
  """Starts a process hosting `cls(*args, **kwargs)` and returns a stub to it.

  Attribute access on the returned stub yields a callable that forwards the
  call to the hosted worker and returns a concurrent.futures.Future for the
  result.

  Args:
    cls: the Worker type to instantiate in the child process.
    *args: positional arguments for the worker constructor.
    **kwargs: keyword arguments for the worker constructor.
  """

  class _Stub:
    """Client stub to a worker hosted by a process."""

    def __init__(self):
      parent_pipe, child_pipe = multiprocessing.get_context().Pipe()
      self._pipe = parent_pipe
      self._pipe_lock = threading.Lock()

      # this is the process hosting one worker instance.
      # we set aside 1 thread to coordinate running jobs, and the main thread
      # to handle high priority requests. The expectation is that the user
      # achieves concurrency through multiprocessing, not multithreading.
      # `pipe` and `worker_class` are bound positionally: binding them by
      # keyword while also forwarding *args made any positional constructor
      # argument collide with `pipe` (TypeError: multiple values).
      self._process = multiprocessing.get_context().Process(
          target=functools.partial(_run, child_pipe, cls, *args, **kwargs))

      # lock for the msgid -> reply future map. The map will be set to None
      # when we stop.
      self._lock = threading.Lock()
      self._map: Dict[int, concurrent.futures.Future] = {}

      # thread draining the pipe
      self._pump = threading.Thread(target=self._msg_pump)

      # Set the state of this worker to "dead" if the process dies naturally.
      def observer():
        self._process.join()
        # Feed the parent pipe a poison pill, this kills msg_pump
        child_pipe.send(None)

      self._observer = threading.Thread(target=observer)

      # atomic control to _msgid
      self._msgidlock = threading.Lock()
      self._msgid = 0

      # start the worker and the message pump
      self._process.start()
      # the observer must follow the process start, otherwise join() raises.
      self._observer.start()
      self._pump.start()

    def _msg_pump(self):
      """Resolves reply futures as results arrive, until the poison pill."""
      while True:
        task_result: Optional[TaskResult] = self._pipe.recv()
        if task_result is None:  # Poison pill fed by observer
          break
        with self._lock:
          future = self._map.pop(task_result.msgid)
        # The following will trigger any callbacks defined on the future, as a
        # direct function call. If those callbacks were set by the scheduler,
        # it's important that self._lock isn't being held when they are being
        # called, otherwise a deadlock could arise from __getattr__ trying to
        # acquire the lock.
        if task_result.success:
          future.set_result(task_result.value)
        else:
          future.set_exception(task_result.value)

      # Mark ourselves as "stopped" by null-ing the map, then fail whatever
      # was still pending. The futures are failed only after releasing the
      # lock, for the same deadlock reason described above.
      with self._lock:
        pending = self._map
        self._map = None
      for future in pending.values():
        future.set_exception(concurrent.futures.CancelledError())

    def _is_stopped(self):
      return self._map is None

    def __getattr__(self, name) -> Callable[..., concurrent.futures.Future]:
      """Returns a callable that invokes worker method `name` remotely.

      Every invocation of the returned callable gets its own msgid and
      Future, so it may safely be called more than once.
      """
      if name.startswith('__') and name.endswith('__'):
        # Protocol probes (copy.deepcopy, pickling, ...) look up dunder names
        # via getattr; do not turn those into remote calls.
        raise AttributeError(name)

      def remote_call(*args, **kwargs):
        result_future = concurrent.futures.Future()
        # Remote work cannot be recalled, so leave the PENDING state right
        # away: a caller's cancel() then returns False instead of succeeding
        # and making the later set_result() in _msg_pump raise.
        result_future.set_running_or_notify_cancel()
        with self._msgidlock:
          msgid = self._msgid
          self._msgid += 1
        with self._lock:
          if self._is_stopped():
            result_future.set_exception(concurrent.futures.CancelledError())
          else:
            with self._pipe_lock:
              self._pipe.send(
                  Task(
                      msgid=msgid,
                      func_name=name,
                      args=args,
                      kwargs=kwargs,
                      is_urgent=cls.is_priority_method(name)))
            self._map[msgid] = result_future
        return result_future

      return remote_call

    def shutdown(self):
      """Kills the worker process (best effort).

      Killing the process triggers observer exit, which triggers msg_pump
      exit, which in turn fails any still-pending futures.
      """
      try:
        self._process.kill()
      except Exception as e:  # pylint: disable=broad-except
        # Best effort; e.g. kill() raises ValueError once the Process object
        # has been closed.
        logging.warning('Failed to kill worker process: %r', e)

    def set_nice(self, val: int):
      """Sets the nice-ness of the process, this modifies how the OS
      schedules it. Only works on Unix, since val is presumed to be an int.
      """
      psutil.Process(self._process.pid).nice(val)

    def set_affinity(self, val: List[int]):
      """Sets the CPU affinity of the process, this modifies which cores the OS
      schedules it on.
      """
      psutil.Process(self._process.pid).cpu_affinity(val)

    def join(self):
      """Waits for the observer thread, the pump thread and the process."""
      self._observer.join()
      self._pump.join()
      self._process.join()

    def __dir__(self):
      return [n for n in dir(cls) if not n.startswith('_')]

  return _Stub()
class LocalWorkerPoolManager(AbstractContextManager):
  """A pool of workers hosted on the local machine, each in its own process."""

  def __init__(self, worker_class: 'type[Worker]', count: Optional[int], *args,
               **kwargs):
    """Starts `count` worker processes.

    Args:
      worker_class: the Worker type each process hosts.
      count: number of workers; if falsy (None or 0), the CPU count is used.
      *args: positional arguments for each worker's constructor.
      **kwargs: keyword arguments for each worker's constructor.
    """
    if not count:
      count = multiprocessing.get_context().cpu_count()
    # Assign the list before populating it so that, if a later worker fails
    # to start, the ones already running can still be shut down.
    self._stubs = []
    try:
      for _ in range(count):
        self._stubs.append(_make_stub(worker_class, *args, **kwargs))
    except BaseException:
      self.__exit__()
      self._stubs = []
      raise

  def __enter__(self) -> FixedWorkerPool:
    # Hand out a copy so the pool cannot mutate our internal list.
    return FixedWorkerPool(workers=list(self._stubs), worker_concurrency=10)

  def __exit__(self, *args):
    # first, trigger killing the worker process and exiting of the msg pump,
    # which will also clear out any pending futures.
    for s in self._stubs:
      s.shutdown()
    # now wait for the message pumps to indicate they exit.
    for s in self._stubs:
      s.join()

  def __del__(self):
    # __init__ may have raised before `_stubs` was assigned (e.g. cpu_count()
    # failed); there is nothing to clean up in that case.
    if hasattr(self, '_stubs'):
      self.__exit__()

  @property
  def stubs(self):
    """Returns a shallow copy of the worker stubs."""
    # Return a shallow copy, to avoid something messing the internal list up
    return list(self._stubs)