#pragma once

#include "SafePriorityQueue.h"
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

namespace MyThreadPool {

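/**
 * A fixed pool of worker threads that pulls callables from a SafePriorityQueue
 * and runs them, delivering results through std::future.
 */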
class ThreadPool {

public:
    ThreadPool()
        : ThreadPool(0)
    {
    }

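    /**
     * Create a pool with the given number of worker threads.
     * Passing 0 sizes the pool to std::thread::hardware_concurrency().
     */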
    explicit ThreadPool(size_t workers)
        : threadCount(workers)
    {
        createWorkers();
    }

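    /**
     * Enqueue a callable with the given priority and return a future for its result.
     * How priorities are ordered is up to SafePriorityQueue.
     */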
    template <typename F, typename... Args, typename R = std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>>
    std::future<R> enqueue(int priority, F&& f, Args&&... args)
    {
        // Encapsulate the task in a shared_ptr so the copyable lambda below can own it
        // (std::packaged_task is move-only); use R so the task's future matches the declared return type.
        auto funcShared = std::make_shared<std::packaged_task<R()>>(
            std::bind(std::forward<F>(f), std::forward<Args>(args)...));

        auto fut = funcShared->get_future();

        // Wrap the packaged task into a type-erased lambda for the queue.
        auto wrapperFunc = [task = std::move(funcShared)]() {
            (*task)();
        };

        enqueueAndNotify(priority, std::move(wrapperFunc));

        return fut;
    }

    size_t getTasksLeft() const
    {
        return queue.length();
    }

    size_t getTasksRunning() const
    {
        return threadsWorking;
    }

    size_t getTotalWorkers() const
    {
        return threadCount;
    }

    bool isIdle() const
    {
        return queue.empty() && threadsWorking == 0;
    }

    bool isClosed() const
    {
        return done;
    }

    /**
     * Wait for all tasks to finish. Other threads can still insert new tasks into the pool.
     * With untilDone == false, returns as soon as the queue has been emptied, even if tasks are still running.
     */
    void waitForTasks(bool untilDone = true)
    {
        // The queue empties when the last task is picked up, not when it finishes,
        // so optionally keep waiting until no worker is busy anymore.
        std::unique_lock<std::mutex> poolLock(poolMtx);
        poolCond.wait(poolLock, [&]() { return queue.empty() && (!untilDone || threadsWorking == 0); });
    }

    ~ThreadPool()
    {
        /*
         * No need to wait until all tasks are truly finished. If, for whatever reason, one of the threads
         * crashes or never finishes executing its task, we can still join (wait for all threads to finish
         * their current tasks) and gracefully (though with possible corruption?) destroy the pool.
         */
        waitForTasks(false);
        destroyThreads();
    }

private:
    void enqueueAndNotify(int priority, std::function<void()>&& func)
    {
        {
            // Push under taskMtx so a worker cannot miss the wakeup between
            // evaluating its wait predicate and going to sleep.
            std::lock_guard<std::mutex> taskLock(taskMtx);
            queue.push(priority, std::move(func));
        }
        taskCond.notify_one();
    }

    /**
     * Worker function for each thread in the pool.
     */
    void worker()
    {
        while (!done) {
            {
                std::unique_lock<std::mutex> taskLock(taskMtx);
                taskCond.wait(taskLock, [&]() { return done || !queue.empty(); });
            }

            // Don't even attempt to do another task if we're destroying the threads. Just bail.
            if (done) {
                return;
            }

            auto callable = queue.pop();
            if (callable.has_value()) {
                threadsWorking++;
                callable.value()();
                {
                    // Decrement under poolMtx so waitForTasks() cannot miss its wakeup.
                    std::lock_guard<std::mutex> poolLock(poolMtx);
                    threadsWorking--;
                }
                poolCond.notify_all();
            }
        }
    }

    void createWorkers()
    {
        assert(threads.empty());

        if (threadCount == 0) {
            threadCount = std::thread::hardware_concurrency();
        }

        assert(threadCount > 0);

        threads.reserve(threadCount);
        for (size_t i = 0; i < threadCount; i++) {
            threads.emplace_back(&ThreadPool::worker, this);
        }
    }

    void destroyThreads()
    {
        done = true;
        taskCond.notify_all();
        for (auto& thread : threads) {
            if (thread.joinable()) {
                thread.join();
            }
        }
    }

    SafePriorityQueue<std::function<void()>> queue;
    std::vector<std::thread> threads;
    size_t threadCount;
    std::atomic<size_t> threadsWorking { 0 };
    std::atomic<bool> done { false };

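    // Wakes idle workers when a new task is pushed.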
    std::condition_variable taskCond;
    mutable std::mutex taskMtx;

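    // Wakes waitForTasks() whenever a worker finishes a task.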
    std::condition_variable poolCond;
    mutable std::mutex poolMtx;
};

} // namespace MyThreadPool
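
/*
 * Usage sketch: how the pool is meant to be driven, based only on the public
 * interface above. The file name "ThreadPool.h" and the task bodies are
 * illustrative assumptions, not part of this header.
 *
 *   #include "ThreadPool.h"   // assumed name of this header
 *   #include <cstdio>
 *
 *   int main()
 *   {
 *       MyThreadPool::ThreadPool pool(4); // ThreadPool() would size to hardware_concurrency()
 *
 *       // Each enqueue() returns a std::future for the task's result; the relative
 *       // ordering of the integer priorities is defined by SafePriorityQueue.
 *       auto sum = pool.enqueue(1, [](int a, int b) { return a + b; }, 2, 3);
 *       pool.enqueue(10, []() { std::puts("another task"); });
 *
 *       std::printf("2 + 3 = %d\n", sum.get()); // blocks until that task has run
 *       pool.waitForTasks();                    // drain the queue and running tasks
 *   }
 */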