# blob: f8bdc811edd7ba25611e0c53b13126c0df8f71b9 [file] [log] [blame]
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An optimal push-pull-based load balancer which attempts to keep up to
`buffer` tasks assigned to each worker, handing a worker a new task whenever
one of its assigned tasks completes.
"""
import concurrent.futures
import threading
from typing import List, Callable, TypeVar
from compiler_opt.distributed import worker
# Type of a worker handle: the element type of `workers` and the single
# argument each work callable accepts.
T = TypeVar('T')
def schedule(work: List[Callable[[T], worker.WorkerFuture]],
             workers: List[T],
             buffer: int = 2) -> List[worker.WorkerFuture]:
    """Dispatch `work` across `workers`, refilling a worker as its tasks finish.

    Each worker is initially handed up to `buffer` work items. Whenever one of
    a worker's futures completes, its outcome is copied onto the matching
    returned future and the next unassigned work item is dispatched to that
    same worker.

    Args:
      work: Callables that each take a worker and return a future for the
        submitted task.
      workers: List of workers; each is the single argument to a work callable.
      buffer: Number of work items to keep in flight on each worker. A value
        below 1 dispatches nothing, so the returned futures never complete.

    Returns:
      A list with one `concurrent.futures.Future` per entry of `work`, in the
      same order as `work`.

    Raises:
      ValueError: If there is work to schedule but `workers` is empty.
    """
    if not work:
        return []
    if not workers:
        raise ValueError('schedule() needs at least one worker to run work on.')

    # Create the futures to be returned first. They are not bound to anything
    # yet; each one is resolved later by the handler of its work item.
    results = [concurrent.futures.Future() for _ in work]
    idx = -1
    idx_lock = threading.Lock()

    # Simple atomic increment-and-get. Used to iterate over `work` like a
    # thread-safe queue without making a copy.
    def fetch_idx():
        nonlocal idx
        with idx_lock:
            idx += 1
            return idx

    def forward_outcome(worker_future: concurrent.futures.Future,
                        result_future: concurrent.futures.Future):
        """Mirror the outcome of `worker_future` onto `result_future`."""
        try:
            error = worker_future.exception()
        except concurrent.futures.CancelledError:
            # The worker-side future was cancelled: surface that as a
            # cancellation instead of leaving the result pending forever.
            result_future.cancel()
            return
        try:
            if error is not None:
                result_future.set_exception(error)
            else:
                result_future.set_result(worker_future.result())
        except concurrent.futures.InvalidStateError:
            # The consumer already cancelled `result_future`; there is nobody
            # left to deliver this outcome to.
            pass

    def make_result_handler(wkr: T, result_future: concurrent.futures.Future):

        def handler(worker_future: concurrent.futures.Future):
            try:
                forward_outcome(worker_future, result_future)
            finally:
                # Always refill this worker's slot. If a failure while
                # forwarding skipped this, the worker would silently lose
                # buffer capacity and the remaining work could stall.
                chain_work(wkr)

        return handler

    def chain_work(wkr: T):
        if (i := fetch_idx()) < len(work):
            # This potentially causes a deadlock if chain_work is called from
            # a future.set_result() context which holds a resource that is
            # also required to complete the work[i](wkr) call below. For an
            # example, see:
            # https://gist.github.com/Northbadge/a57f2d4e0a71e8f3934bdb47e59e343e
            # A fix/workaround would be dispatching from a new thread here,
            # but that introduces the overhead of creating a thread.
            work[i](wkr).add_done_callback(
                make_result_handler(wkr, results[i]))

    # Use min() in case buffer is huge: there is no point in running more
    # priming rounds than are needed to hand out every work item.
    for _ in range(min(buffer, (len(work) // len(workers)) + 1)):
        for w in workers:
            chain_work(w)
    return results