| // Copyright (C) 2019 The Android Open Source Project |
| // Copyright (C) 2019 Google Inc. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| #include "aemu/base/threads/AndroidWorkPool.h" |
| |
| #include "aemu/base/threads/AndroidFunctorThread.h" |
| #include "aemu/base/synchronization/AndroidLock.h" |
| #include "aemu/base/synchronization/AndroidConditionVariable.h" |
| #include "aemu/base/synchronization/AndroidMessageChannel.h" |
| |
| #include <atomic> |
| #include <memory> |
| #include <unordered_map> |
| #include <sys/time.h> |
| |
| using android::base::guest::AutoLock; |
| using android::base::guest::ConditionVariable; |
| using android::base::guest::FunctorThread; |
| using android::base::guest::Lock; |
| using android::base::guest::MessageChannel; |
| |
| namespace android { |
| namespace base { |
| namespace guest { |
| |
| class WaitGroup { // intrusive refcounted |
| public: |
| |
| WaitGroup(int numTasksRemaining) : |
| mNumTasksInitial(numTasksRemaining), |
| mNumTasksRemaining(numTasksRemaining) { } |
| |
| ~WaitGroup() = default; |
| |
| android::base::guest::Lock& getLock() { return mLock; } |
| |
| void acquire() { |
| if (0 == mRefCount.fetch_add(1, std::memory_order_seq_cst)) { |
| ALOGE("%s: goofed, refcount0 acquire\n", __func__); |
| abort(); |
| } |
| } |
| |
| bool release() { |
| if (0 == mRefCount) { |
| ALOGE("%s: goofed, refcount0 release\n", __func__); |
| abort(); |
| } |
| if (1 == mRefCount.fetch_sub(1, std::memory_order_seq_cst)) { |
| std::atomic_thread_fence(std::memory_order_acquire); |
| delete this; |
| return true; |
| } |
| return false; |
| } |
| |
| // wait on all of or any of the associated tasks to complete. |
| bool waitAllLocked(WorkPool::TimeoutUs timeout) { |
| return conditionalTimeoutLocked( |
| [this] { return mNumTasksRemaining > 0; }, |
| timeout); |
| } |
| |
| bool waitAnyLocked(WorkPool::TimeoutUs timeout) { |
| return conditionalTimeoutLocked( |
| [this] { return mNumTasksRemaining == mNumTasksInitial; }, |
| timeout); |
| } |
| |
| // broadcasts to all waiters that there has been a new job that has completed |
| bool decrementBroadcast() { |
| AutoLock<Lock> lock(mLock); |
| bool done = |
| (1 == mNumTasksRemaining.fetch_sub(1, std::memory_order_seq_cst)); |
| std::atomic_thread_fence(std::memory_order_acquire); |
| mCv.broadcast(); |
| return done; |
| } |
| |
| private: |
| |
| bool doWait(WorkPool::TimeoutUs timeout) { |
| if (timeout == ~0ULL) { |
| ALOGV("%s: uncond wait\n", __func__); |
| mCv.wait(&mLock); |
| return true; |
| } else { |
| return mCv.timedWait(&mLock, getDeadline(timeout)); |
| } |
| } |
| |
| struct timespec getDeadline(WorkPool::TimeoutUs relative) { |
| struct timeval deadlineUs; |
| struct timespec deadlineNs; |
| gettimeofday(&deadlineUs, 0); |
| |
| auto prevDeadlineUs = deadlineUs.tv_usec; |
| |
| deadlineUs.tv_usec += relative; |
| |
| // Wrap around |
| if (prevDeadlineUs > deadlineUs.tv_usec) { |
| ++deadlineUs.tv_sec; |
| } |
| |
| deadlineNs.tv_sec = deadlineUs.tv_sec; |
| deadlineNs.tv_nsec = deadlineUs.tv_usec * 1000LL; |
| return deadlineNs; |
| } |
| |
| uint64_t currTimeUs() { |
| struct timeval tv; |
| gettimeofday(&tv, 0); |
| return (uint64_t)(tv.tv_sec * 1000000LL + tv.tv_usec); |
| } |
| |
| bool conditionalTimeoutLocked(std::function<bool()> conditionFunc, WorkPool::TimeoutUs timeout) { |
| uint64_t currTime = currTimeUs(); |
| WorkPool::TimeoutUs currTimeout = timeout; |
| |
| while (conditionFunc()) { |
| doWait(currTimeout); |
| if (!conditionFunc()) { |
| // Decrement timeout for wakeups |
| uint64_t nextTime = currTimeUs(); |
| WorkPool::TimeoutUs waited = |
| nextTime - currTime; |
| currTime = nextTime; |
| |
| if (currTimeout > waited) { |
| currTimeout -= waited; |
| } else { |
| return conditionFunc(); |
| } |
| } |
| } |
| |
| return true; |
| } |
| |
| std::atomic<int> mRefCount = { 1 }; |
| int mNumTasksInitial; |
| std::atomic<int> mNumTasksRemaining; |
| |
| Lock mLock; |
| ConditionVariable mCv; |
| }; |
| |
| class WorkPoolThread { |
| public: |
| // State diagram for each work pool thread |
| // |
| // Unacquired: (Start state) When no one else has claimed the thread. |
| // Acquired: When the thread has been claimed for work, |
| // but work has not been issued to it yet. |
| // Scheduled: When the thread is running tasks from the acquirer. |
| // Exiting: cleanup |
| // |
| // Messages: |
| // |
| // Acquire |
| // Run |
| // Exit |
| // |
| // Transitions: |
| // |
| // Note: While task is being run, messages will come back with a failure value. |
| // |
| // Unacquired: |
| // message Acquire -> Acquired. effect: return success value |
| // message Run -> Unacquired. effect: return failure value |
| // message Exit -> Exiting. effect: return success value |
| // |
| // Acquired: |
| // message Acquire -> Acquired. effect: return failure value |
| // message Run -> Scheduled. effect: run the task, return success |
| // message Exit -> Exiting. effect: return success value |
| // |
| // Scheduled: |
| // implicit effect: after task is run, transition back to Unacquired. |
| // message Acquire -> Scheduled. effect: return failure value |
| // message Run -> Scheduled. effect: return failure value |
| // message Exit -> queue up exit message, then transition to Exiting after that is done. |
| // effect: return success value |
| // |
| enum State { |
| Unacquired = 0, |
| Acquired = 1, |
| Scheduled = 2, |
| Exiting = 3, |
| }; |
| |
| WorkPoolThread() : mThread([this] { threadFunc(); }) { |
| mThread.start(); |
| } |
| |
| ~WorkPoolThread() { |
| exit(); |
| mThread.wait(); |
| } |
| |
| bool acquire() { |
| AutoLock<Lock> lock(mLock); |
| switch (mState) { |
| case State::Unacquired: |
| mState = State::Acquired; |
| return true; |
| case State::Acquired: |
| case State::Scheduled: |
| case State::Exiting: |
| return false; |
| } |
| } |
| |
| bool run(WorkPool::WaitGroupHandle waitGroupHandle, WaitGroup* waitGroup, WorkPool::Task task) { |
| AutoLock<Lock> lock(mLock); |
| switch (mState) { |
| case State::Unacquired: |
| return false; |
| case State::Acquired: { |
| mState = State::Scheduled; |
| mToCleanupWaitGroupHandle = waitGroupHandle; |
| waitGroup->acquire(); |
| mToCleanupWaitGroup = waitGroup; |
| mShouldCleanupWaitGroup = false; |
| TaskInfo msg = { |
| Command::Run, |
| waitGroup, task, |
| }; |
| mRunMessages.send(msg); |
| return true; |
| } |
| case State::Scheduled: |
| case State::Exiting: |
| return false; |
| } |
| } |
| |
| bool shouldCleanupWaitGroup(WorkPool::WaitGroupHandle* waitGroupHandle, WaitGroup** waitGroup) { |
| AutoLock<Lock> lock(mLock); |
| bool res = mShouldCleanupWaitGroup; |
| *waitGroupHandle = mToCleanupWaitGroupHandle; |
| *waitGroup = mToCleanupWaitGroup; |
| mShouldCleanupWaitGroup = false; |
| return res; |
| } |
| |
| private: |
| enum Command { |
| Run = 0, |
| Exit = 1, |
| }; |
| |
| struct TaskInfo { |
| Command cmd; |
| WaitGroup* waitGroup = nullptr; |
| WorkPool::Task task = {}; |
| }; |
| |
| bool exit() { |
| AutoLock<Lock> lock(mLock); |
| TaskInfo msg { Command::Exit, }; |
| mRunMessages.send(msg); |
| return true; |
| } |
| |
| void threadFunc() { |
| TaskInfo taskInfo; |
| bool done = false; |
| |
| while (!done) { |
| mRunMessages.receive(&taskInfo); |
| switch (taskInfo.cmd) { |
| case Command::Run: |
| doRun(taskInfo); |
| break; |
| case Command::Exit: { |
| AutoLock<Lock> lock(mLock); |
| mState = State::Exiting; |
| break; |
| } |
| } |
| AutoLock<Lock> lock(mLock); |
| done = mState == State::Exiting; |
| } |
| } |
| |
| // Assumption: the wait group refcount is >= 1 when entering |
| // this function (before decrement).. |
| // at least it doesn't get to 0 |
| void doRun(TaskInfo& msg) { |
| WaitGroup* waitGroup = msg.waitGroup; |
| |
| if (msg.task) msg.task(); |
| |
| bool lastTask = |
| waitGroup->decrementBroadcast(); |
| |
| AutoLock<Lock> lock(mLock); |
| mState = State::Unacquired; |
| |
| if (lastTask) { |
| mShouldCleanupWaitGroup = true; |
| } |
| |
| waitGroup->release(); |
| } |
| |
| FunctorThread mThread; |
| Lock mLock; |
| State mState = State::Unacquired; |
| MessageChannel<TaskInfo, 4> mRunMessages; |
| WorkPool::WaitGroupHandle mToCleanupWaitGroupHandle = 0; |
| WaitGroup* mToCleanupWaitGroup = nullptr; |
| bool mShouldCleanupWaitGroup = false; |
| }; |
| |
| class WorkPool::Impl { |
| public: |
| Impl(int numInitialThreads) : mThreads(numInitialThreads) { |
| for (size_t i = 0; i < mThreads.size(); ++i) { |
| mThreads[i].reset(new WorkPoolThread); |
| } |
| } |
| |
| ~Impl() = default; |
| |
| WorkPool::WaitGroupHandle schedule(const std::vector<WorkPool::Task>& tasks) { |
| |
| if (tasks.empty()) abort(); |
| |
| AutoLock<Lock> lock(mLock); |
| |
| // Sweep old wait groups |
| for (size_t i = 0; i < mThreads.size(); ++i) { |
| WaitGroupHandle handle; |
| WaitGroup* waitGroup; |
| bool cleanup = mThreads[i]->shouldCleanupWaitGroup(&handle, &waitGroup); |
| if (cleanup) { |
| mWaitGroups.erase(handle); |
| waitGroup->release(); |
| } |
| } |
| |
| WorkPool::WaitGroupHandle resHandle = genWaitGroupHandleLocked(); |
| WaitGroup* waitGroup = |
| new WaitGroup(tasks.size()); |
| |
| mWaitGroups[resHandle] = waitGroup; |
| |
| std::vector<size_t> threadIndices; |
| |
| while (threadIndices.size() < tasks.size()) { |
| for (size_t i = 0; i < mThreads.size(); ++i) { |
| if (!mThreads[i]->acquire()) continue; |
| threadIndices.push_back(i); |
| if (threadIndices.size() == tasks.size()) break; |
| } |
| if (threadIndices.size() < tasks.size()) { |
| mThreads.resize(mThreads.size() + 1); |
| mThreads[mThreads.size() - 1].reset(new WorkPoolThread); |
| } |
| } |
| |
| // every thread here is acquired |
| for (size_t i = 0; i < threadIndices.size(); ++i) { |
| mThreads[threadIndices[i]]->run(resHandle, waitGroup, tasks[i]); |
| } |
| |
| return resHandle; |
| } |
| |
| bool waitAny(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) { |
| AutoLock<Lock> lock(mLock); |
| auto it = mWaitGroups.find(waitGroupHandle); |
| if (it == mWaitGroups.end()) return true; |
| |
| auto waitGroup = it->second; |
| waitGroup->acquire(); |
| lock.unlock(); |
| |
| bool waitRes = false; |
| |
| { |
| AutoLock<Lock> waitGroupLock(waitGroup->getLock()); |
| waitRes = waitGroup->waitAnyLocked(timeout); |
| } |
| |
| waitGroup->release(); |
| |
| return waitRes; |
| } |
| |
| bool waitAll(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) { |
| auto waitGroup = acquireWaitGroupFromHandle(waitGroupHandle); |
| if (!waitGroup) return true; |
| |
| bool waitRes = false; |
| |
| { |
| AutoLock<Lock> waitGroupLock(waitGroup->getLock()); |
| waitRes = waitGroup->waitAllLocked(timeout); |
| } |
| |
| waitGroup->release(); |
| |
| return waitRes; |
| } |
| |
| private: |
| // Increments wait group refcount by 1. |
| WaitGroup* acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle) { |
| AutoLock<Lock> lock(mLock); |
| auto it = mWaitGroups.find(waitGroupHandle); |
| if (it == mWaitGroups.end()) return nullptr; |
| |
| auto waitGroup = it->second; |
| waitGroup->acquire(); |
| |
| return waitGroup; |
| } |
| |
| using WaitGroupStore = std::unordered_map<WorkPool::WaitGroupHandle, WaitGroup*>; |
| |
| WorkPool::WaitGroupHandle genWaitGroupHandleLocked() { |
| WorkPool::WaitGroupHandle res = mNextWaitGroupHandle; |
| ++mNextWaitGroupHandle; |
| return res; |
| } |
| |
| Lock mLock; |
| uint64_t mNextWaitGroupHandle = 0; |
| WaitGroupStore mWaitGroups; |
| std::vector<std::unique_ptr<WorkPoolThread>> mThreads; |
| }; |
| |
| WorkPool::WorkPool(int numInitialThreads) : mImpl(new WorkPool::Impl(numInitialThreads)) { } |
| WorkPool::~WorkPool() = default; |
| |
| WorkPool::WaitGroupHandle WorkPool::schedule(const std::vector<WorkPool::Task>& tasks) { |
| return mImpl->schedule(tasks); |
| } |
| |
| bool WorkPool::waitAny(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) { |
| return mImpl->waitAny(waitGroup, timeout); |
| } |
| |
| bool WorkPool::waitAll(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) { |
| return mImpl->waitAll(waitGroup, timeout); |
| } |
| |
| } // namespace guest |
| } // namespace base |
| } // namespace android |