From d7aa71ed007e99741b30ecda251bd8b9a15f52d5 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 29 Jan 2026 18:47:17 +0000 Subject: [PATCH 1/3] Massive refactor, using std::packaged_tasks, std::future and variadic templating --- CMakeLists.txt | 2 - demo/demo.h | 36 ++--- src/threadpool.cpp | 24 +-- src/threadpool.h | 302 ++++++------------------------------- tests/all_tests.h | 8 +- tests/return_value_tests.h | 42 ------ tests/task_tests.h | 22 --- tests/threadpool_tests.h | 137 ++++++----------- 8 files changed, 116 insertions(+), 457 deletions(-) delete mode 100644 tests/return_value_tests.h delete mode 100644 tests/task_tests.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 5a690cb..3851b55 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,9 +6,7 @@ set(CMAKE_CXX_STANDARD 23) add_executable(ThreadPool demo/main.cpp src/threadpool.h src/threadpool.cpp - tests/return_value_tests.h tests/all_tests.h - tests/task_tests.h tests/threadpool_tests.h demo/dependency_demo.h demo/demo.h diff --git a/demo/demo.h b/demo/demo.h index 1be821e..1890914 100644 --- a/demo/demo.h +++ b/demo/demo.h @@ -15,22 +15,22 @@ inline int recursive_fibonacci(int n) { inline void fibonacci_example() { - threadpool tp(3); - - std::vector> futures = { - tp.submit( []() -> int { return recursive_fibonacci(20);} ), - tp.submit( []() -> int { return recursive_fibonacci(30);} ), - tp.submit( []() -> int { return recursive_fibonacci(40);} ), - }; - - tp.shutdown(); - - for (int i=0; i> futures = { + // tp.submit( []() -> int { return recursive_fibonacci(20);} ), + // tp.submit( []() -> int { return recursive_fibonacci(30);} ), + // tp.submit( []() -> int { return recursive_fibonacci(40);} ), + // }; + // + // tp.shutdown(); + // + // for (int i=0; i opt_task = this->poll_task(); - if (opt_task.has_value()) { - auto& task = opt_task.value(); + + if (!tasks.empty()) { + auto task = std::move(tasks.front()); + tasks.pop(); lock.unlock(); task(); } @@ -30,6 +31,9 @@ threadpool::threadpool(const int& n) { } } + + + threadpool::~threadpool() { std::unique_lock lock(queue_stop_mutex); if (!(m_Stop)) { @@ -62,19 +66,7 @@ void threadpool::shutdown_now() { } [[nodiscard]] -int threadpool::queue_size() { +size_t threadpool::queue_size() { std::lock_guard lock(queue_stop_mutex); return tasks.size(); } - -// ============ THREADPOOL PRIVATE ============ - -std::optional threadpool::poll_task() { - // No lock guard as the thread would already have the guard - if (tasks.empty()) { - return std::nullopt; - } - task front = tasks.front(); - tasks.pop(); - return front; -} \ No newline at end of file diff --git a/src/threadpool.h b/src/threadpool.h index 20649e7..82ca7c0 100644 --- a/src/threadpool.h +++ b/src/threadpool.h @@ -6,293 +6,77 @@ #include #include #include - -// Forward Declarations -template -struct return_value_handle; -class threadpool; - - -// Wrapper for the function pointers -struct task { - task(const std::function& ptr) : m_Ptr {ptr} { - }; - - void operator()() const { - m_Ptr(); - } -private: - std::function m_Ptr; -}; - -// Simplified implementation of std::future -// return_value is shared state via the shared pointer, do not allow copy / move semantics -template -struct return_value { - - friend class return_value_handle; // Allow access to set_value and set_valid - - return_value() = default; - - return_value(const return_value&) = delete; // Copy Constructor - return_value& operator=(const return_value&) = delete; // Copy Assignment Constructor - - return_value(return_value&&) = delete; // Move Constructor - return_value& operator=(return_value&&) = delete; // Move Assignment Constructor - - bool is_valid() { - std::unique_lock lock(access_mutex); - return m_IsValid; - } - - T get() { - std::unique_lock lock(access_mutex); - if (!is_valid_unsafe()) { // unsafe to prevent deadlock as we already acquired mutex - throw std::runtime_error{"Thread Return Value is Invalid!"}; - } - m_IsValid = false; - return m_Value; - } - -private: - std::mutex access_mutex; - T m_Value; - bool m_IsValid{false}; - - bool is_valid_unsafe() { - return m_IsValid; - } - - void set_value(const T& value) { - std::unique_lock lock(access_mutex); - m_Value = value; - m_IsValid = true; - } -}; - -// Specialization of void -template<> -struct return_value { - - friend class return_value_handle; // Allow access to set_value and set_valid - - return_value() = default; - - return_value(const return_value&) = delete; // Copy Constructor - return_value& operator=(const return_value&) = delete; // Copy Assignment Constructor - - return_value(return_value&&) = delete; // Move Constructor - return_value& operator=(return_value&&) = delete; // Move Assignment Constructor - - bool is_valid() { - std::unique_lock lock(access_mutex); - return m_IsValid; - } - - void get() { - std::unique_lock lock(access_mutex); - if (!is_valid_unsafe()) { - lock.unlock(); - throw std::runtime_error{"Thread Return Value is Invalid!"}; - } - // nothing to return - } - -private: - std::mutex access_mutex; - bool m_IsValid{false}; - - bool is_valid_unsafe() { - return m_IsValid; - } - - void set_value() { - std::unique_lock lock(access_mutex); - m_IsValid = false; // Never true with - // Do nothing - } -}; - - -template -struct return_value_handle { - - friend class threadpool; - -public: - return_value_handle() : m_Handle{std::make_shared>()} { - } - - return_value_handle(const return_value_handle&) = default; // Copy Constructor - return_value_handle& operator=(const return_value_handle&) = default; // Copy Assignment Constructor - - return_value_handle(return_value_handle&&) = default; // Move Constructor - return_value_handle& operator=(return_value_handle&&) = default; // Move Assignment Constructor - - bool is_valid() const { - if (m_Handle == nullptr) - return false; - - return m_Handle -> is_valid(); - } - - T get() const { - return m_Handle.get()->get(); - } - - // Dependency DAG APIs - template - return_value_handle then(threadpool& tp, Args... args) { - // Todo: actually implement the logic - return_value_handle rv{}; - - return rv; - } - -private: - std::shared_ptr> m_Handle; - void set_value(const T& value) { - m_Handle->set_value(value); - } - void set_valid(const bool& value) { - m_Handle->set_valid(value); - } -}; - -// Specialization of void -template<> -struct return_value_handle { - - friend class threadpool; - -public: - return_value_handle() : m_Handle{std::make_shared>()} { - } - - return_value_handle(const return_value_handle&) = default; // Copy Constructor - return_value_handle& operator=(const return_value_handle&) = default; // Copy Assignment Constructor - - return_value_handle(return_value_handle&&) = default; // Move Constructor - return_value_handle& operator=(return_value_handle&&) = default; // Move Assignment Constructor - - bool is_valid() const { - if (m_Handle == nullptr) - return false; - - return m_Handle -> is_valid(); - } - - void get() const { - return m_Handle.get() -> get(); - } - - // Dependency DAG APIs - template - return_value_handle then(threadpool& tp, Args... args) { - // Todo: actually implement the logic - return_value_handle rv{}; - - return rv; - } - -private: - std::shared_ptr> m_Handle; - static void set_value() { - // Do nothing - } -}; - - - - - - - - - +#include class threadpool { private: - std::queue tasks; + std::queue> tasks; std::vector workers; bool m_Stop{false}; std::mutex queue_stop_mutex; // Used for queue operations and read/write m_Stop operations std::condition_variable cv; - - std::optional poll_task(); - void write_task(const std::function& ptr) { + void write_task(const std::function& fn) { // No lock guard as submit() already contains the lock - tasks.push( task{ptr} ); + tasks.push(fn); cv.notify_one(); } public: - threadpool(const int& threads); + explicit threadpool(const int& threads); ~threadpool(); - template - [[nodiscard]] - return_value_handle submit(const std::function& ptr) { - std::unique_lock lock(queue_stop_mutex); - if (m_Stop) { - lock.unlock(); - throw std::runtime_error{"ThreadPool::submit() after shutdown called"}; - } - return_value_handle rv_handle{}; + template - write_task( - [ptr, rv_handle] () mutable { - rv_handle.set_value(ptr()); - } - ); - return rv_handle; - } + auto submit(Function &&F, Args &&...ArgList) { - template - [[nodiscard]] - return_value_handle submit(const std::function& ptr, int dependency_id) { - return_value_handle rv_handle{}; - return rv_handle; - } + using ReturnType = std::invoke_result_t; + std::shared_ptr> task = std::make_shared>(( + std::bind(std::forward(F), + std::forward(ArgList)...) + )); - void shutdown(); // finish queued tasks - void shutdown_now(); // cancel pending tasks + auto future = task->get_future(); - [[nodiscard]] - int queue_size(); + write_task([task]() mutable { (*task)(); }); - // Dependency DAG API - template - return_value_handle when_all(Args... args) { - // Todo: actually implement the logic - return_value_handle rv{}; - return rv; + + return future; // Return type is future } + // template + // [[nodiscard]] + // auto submit(const Fn&& fn) { + // using return_type = std::invoke_result_t; + // + // std::unique_lock lock(queue_stop_mutex); + // if (m_Stop) { + // throw std::runtime_error{"ThreadPool::submit() after shutdown called"}; + // } + // + // std::packaged_task task{fn}; + // write_task([&task]() { task(); }); + // return task.get_future(); + // } -}; -// Void specialization -template<> -inline return_value_handle threadpool::submit(const std::function& ptr) { + void shutdown(); // finish queued tasks + void shutdown_now(); // cancel pending tasks - std::unique_lock lock(queue_stop_mutex); - if (m_Stop) { - lock.unlock(); - throw std::runtime_error{"ThreadPool::submit() after shutdown called"}; - } + [[nodiscard]] + size_t queue_size(); + + // Dependency DAG API + // template + // return_value_handle when_all(Args... args) { + // // Todo: actually implement the logic + // return_value_handle rv{}; + // return rv; + // } - return_value_handle rv_handle{}; - write_task( - [ptr] (){ - ptr(); - } - ); - return rv_handle; -} \ No newline at end of file +}; \ No newline at end of file diff --git a/tests/all_tests.h b/tests/all_tests.h index 81ccdd2..ca51715 100644 --- a/tests/all_tests.h +++ b/tests/all_tests.h @@ -1,18 +1,14 @@ #pragma once -#include "return_value_tests.h" -#include "task_tests.h" #include "threadpool_tests.h" inline void all_tests() { // Use the threadpools lol - threadpool tp(3); + threadpool tp(1); - auto rv1 = tp.submit(task_tests); - auto rv2 = tp.submit(return_value_tests); - auto rv3 = tp.submit(threadpool_tests); + auto rv3 = tp.submit(threadpool_tests); tp.shutdown(); } \ No newline at end of file diff --git a/tests/return_value_tests.h b/tests/return_value_tests.h deleted file mode 100644 index af82ac3..0000000 --- a/tests/return_value_tests.h +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once - -#include "../src/threadpool.h" -#include -#include - -inline void return_value_tests() { - - // Handle should not be valid after creating it immediately - { - return_value_handle rv_handle1{}; - assert(!rv_handle1.is_valid()); - return_value_handle rv_handle2{}; - assert(!rv_handle2.is_valid()); - } - - // Catch exception when get is invalid - { - return_value_handle rv_handle1{}; - assert(!rv_handle1.is_valid()); - try { - auto val = rv_handle1.get(); - assert(false); - } catch (std::runtime_error& e) { - - } - assert(!rv_handle1.is_valid()); - - - return_value_handle rv_handle2{}; - assert(!rv_handle2.is_valid()); - try { - rv_handle2.get(); - assert(false); - } catch (std::runtime_error& e) { - - } - assert(!rv_handle2.is_valid()); - } - - std::cout << "return_value & return_value_handle tests passed!\n"; -} diff --git a/tests/task_tests.h b/tests/task_tests.h deleted file mode 100644 index e04b5f8..0000000 --- a/tests/task_tests.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#include "../src/threadpool.h" -#include -#include - -inline void task_tests() { - - // Task operator() overload - { - int num{0}; - std::function ptr{ - [&num]() mutable { num += 5; } - }; - task t{ptr}; - t(); - - assert(num == 5); - } - - std::cout << "task test passed! \n"; -} \ No newline at end of file diff --git a/tests/threadpool_tests.h b/tests/threadpool_tests.h index 57ac7af..8b547a9 100644 --- a/tests/threadpool_tests.h +++ b/tests/threadpool_tests.h @@ -6,132 +6,85 @@ #include #include + +inline std::string string_test() { + return "Hello world!"; +} + +inline int int_test(int input1, int input2) { + return input1 + input2; +} + inline void threadpool_tests() { - // Ensure shutdown finishes all remaining tasks + // Submit syntax { threadpool tp{1}; - int i{0}; - auto f1 = []() { std::this_thread::sleep_for(std::chrono::milliseconds(100));}; - auto f2 = [&i]() mutable{ i = 5; }; - - auto rv1 = tp.submit(f1); - auto rv2 = tp.submit(f2); + auto future = tp.submit([]() {}); tp.shutdown(); - - assert(i == 5); } - // Ensure shutdown_now clears the remaining tasks + // Verify work is done on a submit { threadpool tp{1}; - int i{0}; - auto f1 = []() { std::this_thread::sleep_for(std::chrono::milliseconds(100));}; - auto f2 = [&i]() mutable{ i = 5; }; - - auto rv1 = tp.submit(f1); - auto rv2 = tp.submit(f2); - tp.shutdown_now(); - - assert(i == 0); + int i = 0; + auto future = tp.submit([&i](){i = 42;}); + tp.shutdown(); + assert(i == 42); } - - // Submit after shutdown + // Return type syntax { threadpool tp{1}; + auto future = tp.submit([]() {return 42;}); tp.shutdown(); - try { - auto rv = tp.submit([]() {}); - assert(false); - } catch (std::runtime_error) { - - } + int work = future.get(); + assert(work == 42); } - // Submit after shutdown_now + // Variadic arguments works { threadpool tp{1}; - tp.shutdown_now(); - try { - auto rv = tp.submit([]() {}); - assert(false); - } catch (std::runtime_error) { - - } + auto future = tp.submit([](int num1, int num2, int num3) {return num1 + num2 + num3;}, 1, 2, 3); + tp.shutdown(); + assert(future.valid() && future.get() == 6); } - // Destructor stress tests + // Function pointer works { - for (int i = 0; i < 10'000; ++i) { - threadpool tp{4}; - auto rv = tp.submit([]{}); - } + threadpool tp{2}; + auto future1 = tp.submit(string_test); + auto future2 = tp.submit(int_test, 1, 2); + tp.shutdown(); + assert(future1.valid() && future1.get() == "Hello world!"); + assert(future2.valid() && future2.get() == 3); + } - // Nested submission - /* - The invariant here is a little more subtle - What happens here is the task has a sub-task to put another task onto the threadpool queue - However, what most of the time happens is that shutdown() in the main thread gets called before the task gets processed - Which means that the queue no longer accepts any tasks, and therefore would throw the runtime_error exception - This invariant is kept here - for the DAG aware pools, there would be a private internal enqueing function that would bypass this check - */ + // Function pointers with variadic arguments { threadpool tp{1}; - auto rv1 = tp.submit([&]{ - - try { - auto rv2 = tp.submit([](){ /* work */ }); - assert(false); - } catch (std::runtime_error) { - - } - }); + //std::function f1 = []() -> int {return 5;}; + std::future future = tp.submit([]() {return 5;}); tp.shutdown(); + assert(future.get() == 5); } - // Shutdown now test + + // Ensure shutdown finishes all remaining tasks { threadpool tp{1}; - auto rv1 = tp.submit([]{ std::this_thread::sleep_for(std::chrono::milliseconds(50));}); - auto rv2 = tp.submit( []() { }); + int i{0}; + auto f1 = []() { std::this_thread::sleep_for(std::chrono::milliseconds(100));}; + auto f2 = [&i]() mutable{ i = 5; }; - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); // Wait for thread to pick up the task - tp.shutdown_now(); + auto rv1 = tp.submit(f1); + auto rv2 = tp.submit(f2); + tp.shutdown(); - // Not sure exactly why, but this fails - int i; - //assert(rv1.is_valid()); - assert(!rv2.is_valid()); + assert(i == 5); } - - // // Then() syntax - // { - // threadpool tp{1}; - // - // auto rv_1 = tp.submit([](){ return 5; }); - // - // auto rv_2 = rv_1.then(tp, []() { return 10; }); - // - // tp.shutdown(); - // assert(rv_1.is_valid()); - // assert(rv_1.get() == 5); - // assert(rv_2.is_valid()); - // assert(rv_2.get() == 10); - // } - // - // // Then() actually waits for dependencies - // { - // threadpool tp{5}; - // auto rv_1 = tp.submit([](){ std::this_thread::sleep_for(std::chrono::milliseconds(10));}); - // auto rv_2 = rv_1.then(tp, []() { return 10; }); - // assert(!rv_1.is_valid() && !rv_2.is_valid()); - // tp.shutdown(); - // } - - std::cout << "threadpool tests passed!\n"; } \ No newline at end of file From d7523e9a54903d3cd27d8658e0f704bb11934e14 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 29 Jan 2026 18:51:47 +0000 Subject: [PATCH 2/3] Reinstated fibonacci demo --- demo/demo.h | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/demo/demo.h b/demo/demo.h index 1890914..fcbfd9d 100644 --- a/demo/demo.h +++ b/demo/demo.h @@ -15,22 +15,21 @@ inline int recursive_fibonacci(int n) { inline void fibonacci_example() { - // threadpool tp(3); - // - // std::vector> futures = { - // tp.submit( []() -> int { return recursive_fibonacci(20);} ), - // tp.submit( []() -> int { return recursive_fibonacci(30);} ), - // tp.submit( []() -> int { return recursive_fibonacci(40);} ), - // }; - // - // tp.shutdown(); - // - // for (int i=0; i> futures; + futures.reserve(3); + futures.emplace_back( tp.submit( []() -> int { return recursive_fibonacci(10);})); + futures.emplace_back( tp.submit( []() -> int { return recursive_fibonacci(20);})); + futures.emplace_back( tp.submit( []() -> int { return recursive_fibonacci(30);})); + tp.shutdown(); + + for (int i=0; i Date: Thu, 29 Jan 2026 19:57:22 +0000 Subject: [PATCH 3/3] Copilot suggestions review: added forgotten mutexes and test coverage --- demo/main.cpp | 2 +- src/threadpool.h | 13 +++++--- tests/threadpool_tests.h | 65 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 73 insertions(+), 7 deletions(-) diff --git a/demo/main.cpp b/demo/main.cpp index cdd0f9c..f77b282 100644 --- a/demo/main.cpp +++ b/demo/main.cpp @@ -5,7 +5,7 @@ int main() { all_tests(); - fibonacci_example(); + //fibonacci_example(); // dependency_dag_example(); // multiple_threadpool_example(); diff --git a/src/threadpool.h b/src/threadpool.h index 82ca7c0..3156b46 100644 --- a/src/threadpool.h +++ b/src/threadpool.h @@ -19,7 +19,7 @@ class threadpool { std::condition_variable cv; void write_task(const std::function& fn) { - // No lock guard as submit() already contains the lock + // Does not need the lock as submit already acquires it tasks.push(fn); cv.notify_one(); } @@ -30,9 +30,15 @@ class threadpool { template - + [[nodiscard]] auto submit(Function &&F, Args &&...ArgList) { + std::unique_lock lock(queue_stop_mutex); + if (m_Stop) { + throw std::runtime_error{"ThreadPool::submit() after shutdown called"}; + } + + using ReturnType = std::invoke_result_t; std::shared_ptr> task = std::make_shared>(( @@ -42,8 +48,7 @@ class threadpool { auto future = task->get_future(); - write_task([task]() mutable { (*task)(); }); - + write_task([task](){ (*task)(); }); return future; // Return type is future } diff --git a/tests/threadpool_tests.h b/tests/threadpool_tests.h index 8b547a9..8f33fb6 100644 --- a/tests/threadpool_tests.h +++ b/tests/threadpool_tests.h @@ -15,6 +15,9 @@ inline int int_test(int input1, int input2) { return input1 + input2; } + + + inline void threadpool_tests() { // Submit syntax @@ -70,8 +73,6 @@ inline void threadpool_tests() { assert(future.get() == 5); } - - // Ensure shutdown finishes all remaining tasks { threadpool tp{1}; @@ -86,5 +87,65 @@ inline void threadpool_tests() { assert(i == 5); } + + // Submit after shutdown throws + { + threadpool tp{1}; + auto future = tp.submit([]() {return 42;}); + tp.shutdown(); + + try { + auto rv = tp.submit([]() {}); + assert(false); + } catch (std::runtime_error) { + + } + } + + // Nested submission + /* + The invariant here is a little more subtle + What happens here is the task has a sub-task to put another task onto the threadpool queue + However, what most of the time happens is that shutdown() in the main thread gets called before the task gets processed + Which means that the queue no longer accepts any tasks, and therefore would throw the runtime_error exception + This invariant is kept here - for the DAG aware pools, there would be a private internal enqueing function that would bypass this check + */ + { + threadpool tp{1}; + auto rv1 = tp.submit([&]{ + + try { + auto rv2 = tp.submit([](){ /* work */ }); + assert(false); + } catch (std::runtime_error) { + + } + }); + tp.shutdown(); + } + + // get() called after shutdown_now called + { + threadpool tp{1}; + auto rv1 = tp.submit([]{ std::this_thread::sleep_for(std::chrono::milliseconds(1000)); return 5; }); + auto rv2 = tp.submit([]{ std::this_thread::sleep_for(std::chrono::milliseconds(500)); return 55; }); + auto rv3 = tp.submit( []() { return 55;}); + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + tp.shutdown_now(); + + assert(rv1.valid() && rv1.get() == 5); + try { + auto error = rv3.get(); + assert(false); + } catch (const std::future_error& e) { + + } + + } + + + + std::cout << "threadpool tests passed!\n"; } \ No newline at end of file