blob: d4f5f4745576db7809d830a64a29782cc281768e [file] [log] [blame]
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dask - based middleware implementation."""
from absl import logging
import gin
import dask.config
import dask.utils
import multiprocessing
import tempfile
from compiler_opt.core.abstract_worker import AbstractWorker
from concurrent.futures import ThreadPoolExecutor
from dask.distributed import Client, Worker, LocalCluster
from typing import Callable, Optional, Tuple
class MTWorker(Worker):
  """Dask worker whose 'actor' executor is a multi-threaded pool.

  Once the base `Worker` has been initialised, the executor registered under
  the 'actor' key is shut down (when it is truthy) and replaced by a
  `ThreadPoolExecutor` created with `max_workers=None`, which leaves the
  pool size to the executor's own default.
  """

  def __init__(self, *args, **kwargs):
    """Initialises the base worker, then swaps in the actor thread pool.

    Args:
      *args: positional arguments forwarded unchanged to `Worker`.
      **kwargs: keyword arguments forwarded unchanged to `Worker`.
    """
    super().__init__(*args, **kwargs)
    previous_pool = self.executors['actor']
    if previous_pool:
      # Stop the executor that is about to be replaced.
      previous_pool.shutdown()
    self.executors['actor'] = ThreadPoolExecutor(
        thread_name_prefix='Dask-Actor-MT', max_workers=None)
class LocalManager:
  """Local, dask-based worker manager.

  Owns a temporary directory (used as dask's 'temporary-directory') and a
  dask `Client` created with `processes=True`, `n_workers=1` and
  `MTWorker` as the worker class. Call `shutdown()` to release both.
  """

  def __init__(self):
    """Creates the temporary directory, configures dask, starts the client.

    Raises:
      Exception: whatever `Client(...)` raises; the temporary directory is
        removed before the exception propagates.
    """
    self._tmpdir = tempfile.TemporaryDirectory()
    dask.config.set({
        'temporary-directory': self._tmpdir.name,
        'distributed.worker.daemon': False,
        # NOTE(review): distributed appears to read this setting from
        # 'distributed.scheduler.work-stealing'; this top-level key may have
        # no effect. Left unchanged - confirm against the dask configuration
        # reference before renaming.
        'work-stealing': False,
    })
    try:
      self._client = Client(
          dashboard_address=None,
          processes=True,
          n_workers=1,
          worker_class=MTWorker)
    except Exception:
      # Do not leave the temporary directory behind if startup fails.
      self._tmpdir.cleanup()
      raise
    logging.info('%s', self._client)

  def shutdown(self):
    """Closes the dask client and removes the temporary directory.

    The temporary directory is removed even when closing the client raises;
    the exception from `close()` still propagates to the caller.
    """
    try:
      self._client.close()
    finally:
      self._tmpdir.cleanup()

  def get_client(self):
    """Returns the dask `Client` created by this manager."""
    return self._client
def get_local_compilation_jobs(ctor: Callable[[], AbstractWorker],
                               count: Optional[int]) -> Tuple[Callable, list]:
  """Starts local dask actors and returns them with a shutdown callable.

  A new `LocalManager` is created and `ctor` is submitted to its client
  `count` times with `actor=True`. Each resulting stub is wrapped so that
  `cancel_all_work()` is forwarded with `separate_thread=False`, while every
  other attribute lookup is delegated to the stub.

  Args:
    ctor: zero-argument callable submitted to the dask client once per
      requested worker.
    count: number of workers to create. When falsy (`None` or `0`),
      `multiprocessing.cpu_count()` is used instead.

  Returns:
    A tuple `(shutdown, workers)`: `shutdown` is the manager's bound
    `shutdown` method, and `workers` is the list of wrapped stubs.

  Raises:
    Exception: whatever submitting `ctor` or resolving a worker future
      raises. The manager is shut down before the exception propagates,
      because the caller never receives the shutdown callable in that case.
  """

  class DaskStubWrapper(AbstractWorker):
    """Thin wrapper around the stub returned for a submitted actor."""

    def __init__(self, stub: AbstractWorker):
      self._stub = stub

    def cancel_all_work(self):
      # Always request cancellation on the calling thread of the stub.
      return self._stub.cancel_all_work(separate_thread=False)

    def __getattr__(self, name):
      # Only reached for names not found on the wrapper itself; hand the
      # lookup to the stub's own dynamic attribute hook.
      return self._stub.__getattr__(name)

  if not count:
    count = multiprocessing.cpu_count()
  instance = LocalManager()
  try:
    # Submit every constructor first so the actors can start before we block
    # on the individual results.
    pending = [
        instance.get_client().submit(ctor, actor=True) for _ in range(count)
    ]
    wrapped = [DaskStubWrapper(future.result()) for future in pending]
  except Exception:
    # Without this the client and temporary directory would be leaked, since
    # nothing else holds a reference to `instance`.
    instance.shutdown()
    raise
  return instance.shutdown, wrapped