# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dask - based middleware implementation."""
from concurrent.futures import ThreadPoolExecutor
import concurrent.futures
import functools
import multiprocessing
import queue
import threading

from compiler_opt.core.abstract_worker import AbstractWorker
from typing import Callable, Dict, Optional, Tuple


def worker(ctor, in_q: queue.Queue, out_q: queue.Queue):
  """Entry point of the worker process.

  Reads (msgid, method name, args, kwargs, urgent) tuples from in_q, invokes
  the method on the object built by ctor, and writes (msgid, success,
  result-or-exception) tuples to out_q.
  """
  pool = ThreadPoolExecutor()
  obj = ctor()

  def make_ondone(msgid):

    def on_done(f: concurrent.futures.Future):
      if f.exception():
        out_q.put((msgid, False, f.exception()))
      else:
        out_q.put((msgid, True, f.result()))

    return on_done

  while True:
    msgid, fname, args, kwargs, urgent = in_q.get()
    the_func = getattr(obj, fname)
    if urgent:
      # Urgent calls (e.g. cancellation) run inline so they are not queued
      # behind work already submitted to the thread pool.
      try:
        res = the_func(*args, **kwargs)
        out_q.put((msgid, True, res))
      except Exception as e:  # pylint: disable=broad-except
        out_q.put((msgid, False, e))
    else:
      pool.submit(the_func, *args,
                  **kwargs).add_done_callback(make_ondone(msgid))


class Stub:
  """Proxy for a worker object living in a separate process."""

  def __init__(self, ctor):
    self._send = multiprocessing.get_context().Queue()
    self._receive = multiprocessing.get_context().Queue()
    self._process = multiprocessing.Process(
        target=functools.partial(worker, ctor, self._send, self._receive))
    # Guards _map, which maps pending message ids to their Futures.
    self._lock = threading.Lock()
    self._map: Dict[object, concurrent.futures.Future] = {}
    # The pump drains replies from the worker and resolves Futures. It is
    # daemonic so a killed Stub does not keep the interpreter alive.
    self._pump = threading.Thread(target=self._msg_pump, daemon=True)
    self._done = threading.Event()
    self._msgidlock = threading.Lock()
    self._msgid = 0
    self._process.start()
    self._pump.start()

  def _msg_pump(self):
    while not self._done.is_set():
      msgid, succeeded, value = self._receive.get()
      with self._lock:
        future = self._map[msgid]
        del self._map[msgid]
      if succeeded:
        future.set_result(value)
      else:
        future.set_exception(value)

  def __getattr__(self, name):
    with self._msgidlock:
      msgid = self._msgid
      self._msgid += 1

    def remote_call(*args, **kwargs):
      future = concurrent.futures.Future()
      # Register the future before sending, so the reply cannot arrive
      # before the message id is in the map.
      with self._lock:
        self._map[msgid] = future
      # 'cancel_all_work' is flagged urgent so the worker runs it inline.
      self._send.put((msgid, name, args, kwargs, name == 'cancel_all_work'))
      return future

    return remote_call

  def kill(self):
    try:
      self._process.kill()
    except Exception:  # pylint: disable=broad-except
      pass
    self._done.set()
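
# Usage note (illustrative; `compile_job` is a hypothetical method name):
#   stub = Stub(ctor)
#   future = stub.compile_job(job)   # dispatched to the worker process
#   result = future.result()         # blocks until the remote call completes
#   stub.kill()                      # terminates the worker process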


def get_compilation_jobs(ctor: Callable[[], AbstractWorker],
                         count: Optional[int]) -> Tuple[Callable, list]:
  """Creates count stubs (one per CPU if count is falsy) and a shutdown hook."""
  if not count:
    count = multiprocessing.cpu_count()
  stubs = [Stub(ctor) for _ in range(count)]

  def shutdown():
    for s in stubs:
      s.kill()

  # Return the stubs shutdown() closes over, not a fresh, un-killable set.
  return shutdown, stubs
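

# A minimal smoke test, guarded so it only runs when the module is executed
# directly. _EchoWorker and its `echo` method are hypothetical stand-ins for a
# real AbstractWorker implementation; the 'fork' start method (the Linux
# default) is assumed so the locally defined class is visible to the worker
# process.
if __name__ == '__main__':

  class _EchoWorker:  # a real worker would subclass AbstractWorker

    def echo(self, value):
      return value

  _shutdown, _stubs = get_compilation_jobs(_EchoWorker, 2)
  # Every stub call returns a Future resolved by that stub's message pump.
  print([s.echo(i).result() for i, s in enumerate(_stubs)])
  _shutdown()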