From 6545cad57f0a5e9a230a60a6c418e063fcf795a4 Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Mon, 15 Sep 2025 22:27:00 +0800 Subject: [PATCH 1/2] use blocking sync when idling --- src/turbomind/comm/CMakeLists.txt | 7 +- src/turbomind/comm/barrier.h | 33 +++++++ src/turbomind/comm/host_comm.h | 2 +- src/turbomind/comm/test_comm.cu | 40 -------- src/turbomind/comm/test_host_comm.cc | 63 ++++++++++++ src/turbomind/comm/thread_comm.cc | 121 ++++++++++++----------- src/turbomind/models/llama/LlamaBatch.cc | 8 +- 7 files changed, 173 insertions(+), 101 deletions(-) create mode 100644 src/turbomind/comm/test_host_comm.cc diff --git a/src/turbomind/comm/CMakeLists.txt b/src/turbomind/comm/CMakeLists.txt index a8b6ff75a2..0a3b2b4ea3 100644 --- a/src/turbomind/comm/CMakeLists.txt +++ b/src/turbomind/comm/CMakeLists.txt @@ -2,8 +2,10 @@ cmake_minimum_required(VERSION 3.8) +find_package(Threads) + add_library(host_comm STATIC host_comm.cc thread_comm.cc) -target_link_libraries(host_comm PRIVATE core logger) +target_link_libraries(host_comm PRIVATE core logger Threads::Threads) set_property(TARGET host_comm PROPERTY POSITION_INDEPENDENT_CODE ON) add_library(device_comm STATIC device_comm.cc) @@ -24,5 +26,8 @@ if (BUILD_MULTI_GPU) add_executable(test_comm test_comm.cu) target_link_libraries(test_comm PRIVATE device_comm host_comm core pthread nvtx_utils) target_compile_options(test_comm PRIVATE -march=native -mtune=native) + + add_executable(test_host_comm test_host_comm.cc) + target_link_libraries(test_host_comm PRIVATE host_comm core Threads::Threads) endif () endif () diff --git a/src/turbomind/comm/barrier.h b/src/turbomind/comm/barrier.h index 91ac3b65c9..194c450554 100644 --- a/src/turbomind/comm/barrier.h +++ b/src/turbomind/comm/barrier.h @@ -2,6 +2,8 @@ #pragma once +#if defined(_MSC_VER) && !defined(__clang__) + #include #include #include @@ -37,3 +39,34 @@ class Barrier { }; } // namespace turbomind::comm + +#else + +#include + +namespace turbomind::comm { + +class Barrier { +public: + explicit Barrier(int count): barrier_{} + { + pthread_barrier_init(&barrier_, {}, count); + } + + ~Barrier() + { + pthread_barrier_destroy(&barrier_); + } + + void arrive_and_wait() + { + pthread_barrier_wait(&barrier_); + } + +private: + pthread_barrier_t barrier_; +}; + +} // namespace turbomind::comm + +#endif diff --git a/src/turbomind/comm/host_comm.h b/src/turbomind/comm/host_comm.h index b036142264..e9e25d2b8f 100644 --- a/src/turbomind/comm/host_comm.h +++ b/src/turbomind/comm/host_comm.h @@ -35,7 +35,7 @@ class HostCommImpl { virtual std::shared_ptr Split(int color, int key) = 0; - virtual void Sync() = 0; + virtual void Sync(bool blocking = false) = 0; virtual void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy) = 0; diff --git a/src/turbomind/comm/test_comm.cu b/src/turbomind/comm/test_comm.cu index 2f049a07d3..37415090a6 100644 --- a/src/turbomind/comm/test_comm.cu +++ b/src/turbomind/comm/test_comm.cu @@ -964,46 +964,6 @@ struct TestComm { int main(int argc, char* argv[]) { -#if 0 - const int N = 8; - auto state = std::make_shared(N); - std::vector threads; - for (int r = 0; r < N; ++r) { - threads.emplace_back([&, r] { - HostComm comm(N, r, state); - int group = 0; - // group = comm.Split(r / (N / 2), 0); - group = comm.Split(r % 4, 0); - auto tick = std::chrono::steady_clock::now(); - volatile int a; - volatile int b; - for (int i = 0; i < 1; ++i) { - a = Allreduce(comm, r, group); - auto v = Allgather(comm, r, group); - b = std::accumulate(v.begin(), v.end(), 0); - for (int j = 0; j < N; ++j) { - comm.Sync(); - if (j == r) { - std::cout << a << " " << b << std::endl; - } - } - } - auto tock = std::chrono::steady_clock::now(); - - for (int i = 0; i < N; ++i) { - comm.Sync(); - if (i == r) { - std::cout << std::chrono::duration(tock - tick).count() << std::endl; - } - } - }); - } - std::cout << "main thread waiting.\n"; - for (auto& t : threads) { - t.join(); - } - return 0; -#endif TestComm test; diff --git a/src/turbomind/comm/test_host_comm.cc b/src/turbomind/comm/test_host_comm.cc new file mode 100644 index 0000000000..c0a1925b62 --- /dev/null +++ b/src/turbomind/comm/test_host_comm.cc @@ -0,0 +1,63 @@ + +#include +#include +#include + +#include "src/turbomind/comm/host_comm.h" + +using namespace turbomind; +using namespace turbomind::comm; + +int main(int argc, char* argv[]) +{ + const int N = 32; + std::unique_ptr group_id = CreateHostGroupId({}); + group_id->Initialize(); + std::vector threads; + for (int r = 0; r < N; ++r) { + threads.emplace_back([&, r] { + HostComm world = group_id->CreateCommunicator(N, r); + + HostComm group = world; + group = world->Split(r / (N / 4), 0); + + auto tick = std::chrono::steady_clock::now(); + + // int data = 100; + // for (int i = 0; i < 10000; ++i, ++data) { + // group->Sync(true); + // } + + volatile int a; + volatile int b; + for (int i = 0; i < 1; ++i) { + a = AllReduce(group, r, RedOp::kSum); + auto v = AllGather(group, r); + b = std::accumulate(v.begin(), v.end(), 0); + for (int j = 0; j < N; ++j) { + world->Sync(); + if (j == r) { + std::cout << a << " " << b << std::endl; + } + } + } + + auto tock = std::chrono::steady_clock::now(); + + for (int i = 0; i < N; ++i) { + world->Sync(); + if (i == r) { + std::cout << std::chrono::duration(tock - tick).count() << std::endl; + } + } + }); + } + + std::cout << "main thread waiting.\n"; + + for (auto& t : threads) { + t.join(); + } + + return 0; +} \ No newline at end of file diff --git a/src/turbomind/comm/thread_comm.cc b/src/turbomind/comm/thread_comm.cc index 017d83abb0..509681271e 100644 --- a/src/turbomind/comm/thread_comm.cc +++ b/src/turbomind/comm/thread_comm.cc @@ -2,13 +2,12 @@ #include #include -#include #include #include #include #include -#include +#include "src/turbomind/comm/barrier.h" #include "src/turbomind/comm/host_comm.h" #include "src/turbomind/core/check.h" #include "src/turbomind/core/data_type.h" @@ -18,44 +17,42 @@ struct ThreadCommImpl: public HostCommImpl { class State { public: - explicit State(int n): n_{n}, channels_(n * n) {} + explicit State(int n): n_{n}, channels_(n * n), barrier_{n} {} + std::atomic& channel(int from, int to) { return channels_[from * n_ + to]; } + void sync() + { + barrier_.arrive_and_wait(); + } + private: int n_; std::deque> channels_; + Barrier barrier_; }; std::shared_ptr state_; - int rank_; // global rank - - std::vector l2g_; - std::vector g2l_; - - ThreadCommImpl(int n_ranks, std::shared_ptr state, int rank): state_{std::move(state)}, rank_{rank} - { - l2g_.resize(n_ranks); - std::iota(l2g_.begin(), l2g_.end(), 0); - g2l_ = l2g_; - } + int n_ranks_; + int rank_; - ThreadCommImpl(std::vector l2g, std::vector g2l, std::shared_ptr state, int rank): - state_{std::move(state)}, rank_{rank}, l2g_{std::move(l2g)}, g2l_{std::move(g2l)} + ThreadCommImpl(int n_ranks, std::shared_ptr state, int rank): + state_{std::move(state)}, n_ranks_{n_ranks}, rank_{rank} { } int rank() const override { - return g2l_.at(rank_); + return rank_; } int n_ranks() const override { - return l2g_.size(); + return n_ranks_; } bool is_same_process() const override @@ -71,37 +68,50 @@ struct ThreadCommImpl: public HostCommImpl { std::shared_ptr Split(int color, int key) override { TM_CHECK(color >= 0); - TM_CHECK(g2l_[rank_] >= 0); - - // `g2l_[rank_]` imposes proper ordering when keys are equal - auto vec = comm::AllGather(this, std::make_tuple(color, key, g2l_[rank_])); - - auto last = std::stable_partition(vec.begin(), vec.end(), [&](auto x) { // - return std::get<0>(x) == color; - }); - vec.erase(last, vec.end()); - std::stable_sort(vec.begin(), vec.end(), [](auto& a, auto& b) { // - return a < b; - }); - - std::vector l2g; - std::vector g2l(g2l_.size(), -1); - - for (size_t i = 0; i < vec.size(); ++i) { - int r = l2g_.at(std::get<2>(vec[i])); - l2g.push_back(r); - g2l[r] = i; + + auto ranks = comm::AllGather(this, std::make_tuple(color, key, rank_)); + + auto same_color = [&](auto x) { return std::get<0>(x) == color; }; + ranks.erase(std::stable_partition(ranks.begin(), ranks.end(), same_color), ranks.end()); + + std::stable_sort(ranks.begin(), ranks.end(), [](auto& a, auto& b) { return a < b; }); + + std::shared_ptr state; + + int rank = -1; + for (int i = 0; i < ranks.size(); ++i) { + if (std::get<2>(ranks[i]) == rank_) { + rank = i; + } + } + + TM_CHECK_GE(rank, 0); + + if (rank == 0) { + state = std::make_shared(ranks.size()); } - return std::make_shared(std::move(l2g), std::move(g2l), state_, rank_); + auto states = comm::AllGather(this, state); + if (rank != 0) { + const int root = std::get<2>(ranks[0]); + state = states[root]; + } + + return std::make_shared(ranks.size(), state, rank); } - void Sync() override + void Sync(bool blocking) override { - if (n_ranks() == 1) { + if (n_ranks_ == 1) { return; } - for (const auto& r : l2g_) { + + if (blocking) { + state_->sync(); + return; + } + + for (int r = 0; r < n_ranks_; ++r) { if (r != rank_) { auto& c = channel(rank_, r); void* expected{}; @@ -110,7 +120,7 @@ struct ThreadCommImpl: public HostCommImpl { } } } - for (const auto& r : l2g_) { + for (int r = 0; r < n_ranks_; ++r) { if (r != rank_) { auto& c = channel(r, rank_); void* expected = (void*)1; @@ -124,13 +134,12 @@ struct ThreadCommImpl: public HostCommImpl { void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy) override { TM_CHECK(copy); - if (n_ranks() == 1) { + if (n_ranks_ == 1) { return; } // transform root to global rank - root = l2g_.at(root); if (rank_ == root) { - for (const auto& r : l2g_) { + for (int r = 0; r < n_ranks_; ++r) { if (r != rank_) { auto& c = channel(rank_, r); void* expected{}; @@ -139,7 +148,7 @@ struct ThreadCommImpl: public HostCommImpl { } } } - for (const auto& r : l2g_) { + for (int r = 0; r < n_ranks_; ++r) { if (r != rank_) { auto& c = channel(rank_, r); while (c.load(std::memory_order_relaxed)) {} @@ -158,10 +167,10 @@ struct ThreadCommImpl: public HostCommImpl { void AllGather(void* data, int count, DataType dtype, copy_fn copy) override { TM_CHECK(copy); - if (n_ranks() == 1) { + if (n_ranks_ == 1) { return; } - for (const auto& r : l2g_) { + for (int r = 0; r < n_ranks_; ++r) { if (r != rank_) { auto& c = channel(rank_, r); void* expected{}; @@ -170,16 +179,16 @@ struct ThreadCommImpl: public HostCommImpl { } } } - for (const auto& r : l2g_) { + for (int r = 0; r < n_ranks_; ++r) { if (r != rank_) { auto& c = channel(r, rank_); void* incoming{}; while (!(incoming = c.load(std::memory_order_acquire))) {} - copy(incoming, count, data, g2l_[r] * count); + copy(incoming, count, data, r * count); c.store(nullptr, std::memory_order_relaxed); } } - for (const auto& r : l2g_) { + for (int r = 0; r < n_ranks_; ++r) { if (r != rank_) { auto& c = channel(rank_, r); while (c.load(std::memory_order_relaxed)) {} @@ -250,12 +259,12 @@ struct ThreadCommImpl: public HostCommImpl { { const auto reduce = get_reduce(dtype, red_op); const auto elem_size = byte_size(dtype); - if (n_ranks() == 1) { + if (n_ranks_ == 1) { return; } std::unique_ptr tmp((char*)::operator new[](elem_size* count)); std::copy_n((char*)data, elem_size * count, tmp.get()); - for (const auto& r : l2g_) { + for (int r = 0; r < n_ranks_; ++r) { if (r != rank_) { auto& c = channel(rank_, r); void* expected{}; @@ -264,7 +273,7 @@ struct ThreadCommImpl: public HostCommImpl { } } } - for (const auto& r : l2g_) { + for (int r = 0; r < n_ranks_; ++r) { if (r != rank_) { auto& c = channel(r, rank_); void* incoming{}; @@ -273,7 +282,7 @@ struct ThreadCommImpl: public HostCommImpl { c.store(nullptr, std::memory_order_relaxed); } } - for (const auto& r : l2g_) { + for (int r = 0; r < n_ranks_; ++r) { if (r != rank_) { auto& c = channel(rank_, r); while (c.load(std::memory_order_relaxed)) {} diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 586c2e692c..f23de6c0bd 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -1365,13 +1365,15 @@ void LlamaBatch::InternalThreadEntry() FindCanceledIndices(req->cancel); } + if (state_->size == g.finished_count) { + // Batch is empty, use blocking sync to avoid spinning + comm_.h_tp_group->Sync(true); + } + NvtxScope scope("mainloop"); // 1. Wait while rank-0 is dequeueing // 2. Broadcast `ec` from rank-0 - // shared_state_->barrier->wait(); - // comm_.h_comm->Sync(comm_.h_comm_tp_group); - Broadcast(comm_.h_tp_group, req, 0); if (req->abort) { From 3d05298a71e6e1455f45a4f87d9b66129f194dfb Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Mon, 15 Sep 2025 22:39:38 +0800 Subject: [PATCH 2/2] lint --- src/turbomind/comm/test_host_comm.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/turbomind/comm/test_host_comm.cc b/src/turbomind/comm/test_host_comm.cc index c0a1925b62..fbacd10a35 100644 --- a/src/turbomind/comm/test_host_comm.cc +++ b/src/turbomind/comm/test_host_comm.cc @@ -60,4 +60,4 @@ int main(int argc, char* argv[]) } return 0; -} \ No newline at end of file +}