Skip to content

Commit 8f18e7f

Browse files
authored
Adds a linear accessor to RMM cuda stream pool (#696)
Adds `rmm::cuda_stream_pool::get_stream(stream_id)` and `rmm::cuda_stream_pool::get_pool_size()` accessors which allow legacy compatibility in cuML and immediate adoption of `rmm::cuda_stream_pool` in RAFT and cuGraph. This co-exist with the current features in `rmm::cuda_stream_pool`. close #689 Authors: - Alex Fender (@afender) Approvers: - Jake Hemstad (@jrhemstad) - Mark Harris (@harrism) - Rong Ou (@rongou) URL: #696
1 parent 31604e7 commit 8f18e7f

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

include/rmm/cuda_stream_pool.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <rmm/cuda_stream.hpp>
2020
#include <rmm/cuda_stream_view.hpp>
21+
#include <rmm/detail/error.hpp>
2122

2223
#include <atomic>
2324
#include <vector>
@@ -61,6 +62,30 @@ class cuda_stream_pool {
6162
return streams_[(next_stream++) % streams_.size()].view();
6263
}
6364

65+
/**
66+
* @brief Get a `cuda_stream_view` of the stream associated with `stream_id`.
67+
* Equivalent values of `stream_id` return a stream_view to the same underlying stream.
68+
*
69+
* This function is thread safe with respect to other calls to the same function.
70+
*
71+
* @param stream_id Unique identifier for the desired stream
72+
*
73+
* @return rmm::cuda_stream_view
74+
*/
75+
rmm::cuda_stream_view get_stream(std::size_t stream_id) const
76+
{
77+
return streams_[stream_id % streams_.size()].view();
78+
}
79+
80+
/**
81+
* @brief Get the number of streams in the pool.
82+
*
83+
* This function is thread safe with respect to other calls to the same function.
84+
*
85+
* @return the number of streams in the pool
86+
*/
87+
size_t get_pool_size() const noexcept { return streams_.size(); }
88+
6489
private:
6590
std::vector<rmm::cuda_stream> streams_;
6691
mutable std::atomic_size_t next_stream{};

tests/cuda_stream_pool_tests.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,23 @@ TEST_F(CudaStreamPoolTest, ValidStreams)
5454
RMM_CUDA_TRY(cudaMemsetAsync(v.data(), 0xcc, 100, stream_a.value()));
5555
stream_a.synchronize();
5656

57-
auto v2 = rmm::device_uvector<uint8_t>{v, stream_b};
57+
auto v2 = rmm::device_uvector<std::uint8_t>{v, stream_b};
5858
auto x = v2.front_element(stream_b);
5959
EXPECT_EQ(x, 0xcc);
6060
}
61+
62+
TEST_F(CudaStreamPoolTest, PoolSize) { EXPECT_GE(this->pool.get_pool_size(), 1); }
63+
64+
TEST_F(CudaStreamPoolTest, OutOfBoundLinearAccess)
65+
{
66+
auto const stream_a = this->pool.get_stream(0);
67+
auto const stream_b = this->pool.get_stream(this->pool.get_pool_size());
68+
EXPECT_EQ(stream_a, stream_b);
69+
}
70+
71+
TEST_F(CudaStreamPoolTest, ValidLinearAccess)
72+
{
73+
auto const stream_a = this->pool.get_stream(0);
74+
auto const stream_b = this->pool.get_stream(1);
75+
EXPECT_NE(stream_a, stream_b);
76+
}

0 commit comments

Comments
 (0)