Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/turbomind/comm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

cmake_minimum_required(VERSION 3.8)

find_package(Threads)

add_library(host_comm STATIC host_comm.cc thread_comm.cc)
target_link_libraries(host_comm PRIVATE core logger)
target_link_libraries(host_comm PRIVATE core logger Threads::Threads)
set_property(TARGET host_comm PROPERTY POSITION_INDEPENDENT_CODE ON)

add_library(device_comm STATIC device_comm.cc)
Expand All @@ -24,5 +26,8 @@ if (BUILD_MULTI_GPU)
add_executable(test_comm test_comm.cu)
target_link_libraries(test_comm PRIVATE device_comm host_comm core pthread nvtx_utils)
target_compile_options(test_comm PRIVATE -march=native -mtune=native)

add_executable(test_host_comm test_host_comm.cc)
target_link_libraries(test_host_comm PRIVATE host_comm core Threads::Threads)
endif ()
endif ()
33 changes: 33 additions & 0 deletions src/turbomind/comm/barrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#pragma once

#if defined(_MSC_VER) && !defined(__clang__)

#include <condition_variable>
#include <cstdint>
#include <mutex>
Expand Down Expand Up @@ -37,3 +39,34 @@ class Barrier {
};

} // namespace turbomind::comm

#else

#include <pthread.h>

namespace turbomind::comm {

class Barrier {
public:
explicit Barrier(int count): barrier_{}
{
pthread_barrier_init(&barrier_, {}, count);
}

~Barrier()
{
pthread_barrier_destroy(&barrier_);
}

void arrive_and_wait()
{
pthread_barrier_wait(&barrier_);
}

private:
pthread_barrier_t barrier_;
};

} // namespace turbomind::comm

#endif
2 changes: 1 addition & 1 deletion src/turbomind/comm/host_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class HostCommImpl {

virtual std::shared_ptr<HostCommImpl> Split(int color, int key) = 0;

virtual void Sync() = 0;
virtual void Sync(bool blocking = false) = 0;

virtual void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy) = 0;

Expand Down
40 changes: 0 additions & 40 deletions src/turbomind/comm/test_comm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -964,46 +964,6 @@ struct TestComm {

int main(int argc, char* argv[])
{
#if 0
const int N = 8;
auto state = std::make_shared<HostComm::State>(N);
std::vector<std::thread> threads;
for (int r = 0; r < N; ++r) {
threads.emplace_back([&, r] {
HostComm comm(N, r, state);
int group = 0;
// group = comm.Split(r / (N / 2), 0);
group = comm.Split(r % 4, 0);
auto tick = std::chrono::steady_clock::now();
volatile int a;
volatile int b;
for (int i = 0; i < 1; ++i) {
a = Allreduce<RedOp::kSum>(comm, r, group);
auto v = Allgather(comm, r, group);
b = std::accumulate(v.begin(), v.end(), 0);
for (int j = 0; j < N; ++j) {
comm.Sync();
if (j == r) {
std::cout << a << " " << b << std::endl;
}
}
}
auto tock = std::chrono::steady_clock::now();

for (int i = 0; i < N; ++i) {
comm.Sync();
if (i == r) {
std::cout << std::chrono::duration<float, std::milli>(tock - tick).count() << std::endl;
}
}
});
}
std::cout << "main thread waiting.\n";
for (auto& t : threads) {
t.join();
}
return 0;
#endif

TestComm test;

Expand Down
63 changes: 63 additions & 0 deletions src/turbomind/comm/test_host_comm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@

#include <iostream>
#include <numeric>
#include <thread>

#include "src/turbomind/comm/host_comm.h"

using namespace turbomind;
using namespace turbomind::comm;

int main(int argc, char* argv[])
{
const int N = 32;
std::unique_ptr<HostGroupId> group_id = CreateHostGroupId({});
group_id->Initialize();
std::vector<std::thread> threads;
for (int r = 0; r < N; ++r) {
threads.emplace_back([&, r] {
HostComm world = group_id->CreateCommunicator(N, r);

HostComm group = world;
group = world->Split(r / (N / 4), 0);

auto tick = std::chrono::steady_clock::now();

// int data = 100;
// for (int i = 0; i < 10000; ++i, ++data) {
// group->Sync(true);
// }

volatile int a;
volatile int b;
for (int i = 0; i < 1; ++i) {
a = AllReduce(group, r, RedOp::kSum);
auto v = AllGather(group, r);
b = std::accumulate(v.begin(), v.end(), 0);
for (int j = 0; j < N; ++j) {
world->Sync();
if (j == r) {
std::cout << a << " " << b << std::endl;
}
}
}

auto tock = std::chrono::steady_clock::now();

for (int i = 0; i < N; ++i) {
world->Sync();
if (i == r) {
std::cout << std::chrono::duration<float, std::milli>(tock - tick).count() << std::endl;
}
}
});
}

std::cout << "main thread waiting.\n";

for (auto& t : threads) {
t.join();
}

return 0;
}
Loading
Loading