-
Notifications
You must be signed in to change notification settings - Fork 160
Set cudaFuncAttributeMaxDynamicSharedMemorySize with thread-safety
#1771
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
rapids-bot
merged 35 commits into
rapidsai:main
from
mythrocks:cuda-invalid-argument-kernel-error
Feb 22, 2026
Merged
Changes from all commits
Commits
Show all changes
35 commits
Select commit
Hold shift + click to select a range
96af4dd
Set max-dynamic-shared-mem with thread-safety
mythrocks 305cb55
Copyright date. Formatting.
mythrocks 4463666
Merge remote-tracking branch 'origin/main' into cuda-invalid-argument…
mythrocks 677f6fe
Moved other call to set cudaFuncAttributeMaxDynamicSharedMemorySize.
mythrocks 5e9e410
Moved call-sites in other files.
mythrocks f55cbd3
Copyright date.
mythrocks b394b19
Resolved merge conflicts. Moved all call-sites to use the new `optio…
mythrocks f3efb37
Merge branch 'main' into cuda-invalid-argument-kernel-error
mythrocks 8792c40
Merge branch 'main' into cuda-invalid-argument-kernel-error
mythrocks e55eb23
Merge remote-tracking branch 'origin/main' into cuda-invalid-argument…
mythrocks 2a5a00f
Invoke kernel within critical section.
mythrocks b1b97cc
Removed old function.
mythrocks f50d4d4
Tie the kernel to its launcher.
mythrocks c2b1c43
Merge remote-tracking branch 'origin/main' into cuda-invalid-argument…
mythrocks 8d66233
Better error reporting.
mythrocks 1acb864
Merge branch 'main' into cuda-invalid-argument-kernel-error
achirkin 1d41373
Remove the safety fix for persistent kernel (only one kernel must run…
achirkin f6cf7d3
Add a reproducer
achirkin 9b13500
Fix style
achirkin 8d6fb1a
Apply suggestion from @achirkin
mythrocks 1cd8b2c
Merge branch 'main' into cuda-invalid-argument-kernel-error
mythrocks 9ed3400
Merge branch 'main' into cuda-invalid-argument-kernel-error
mythrocks 1955dce
Fixed formatting again.
mythrocks 3ba77d5
Merge remote-tracking branch 'origin/main' into cuda-invalid-argument…
mythrocks 5a8771c
Merge branch 'main' into cuda-invalid-argument-kernel-error
mythrocks 45dff11
Apply suggestion from @achirkin
achirkin 8c107e1
Merge branch 'main' into cuda-invalid-argument-kernel-error
achirkin a6f3071
Merge remote-tracking branch 'origin/main' into cuda-invalid-argument…
mythrocks d3f6ecd
Fixed format string.
mythrocks 088fcc8
Fixed thrust header.
mythrocks 856e695
Fix compile error for thrust make_counting_iterator
mythrocks eff47c8
Revert "Fix compile error for thrust make_counting_iterator"
mythrocks b59b8c3
Cherrypick fix from #1825.
mythrocks 3581c01
Merge remote-tracking branch 'origin/main' into cuda-invalid-argument…
mythrocks 80825ab
Merge branch 'main' into cuda-invalid-argument-kernel-error
mythrocks File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
| #pragma once | ||
|
|
||
| #include <raft/core/error.hpp> | ||
|
|
||
| #include <atomic> | ||
| #include <cstdint> | ||
| #include <mutex> | ||
|
|
||
| namespace cuvs::neighbors::detail { | ||
|
|
||
/**
 * @brief (Thread-)Safely invoke a kernel with a maximum dynamic shared memory size.
 * This is required because the sequence `cudaFuncSetAttribute` + kernel launch is not executed
 * atomically: another thread may lower the attribute between the set and the launch, causing the
 * launch to fail with `cudaErrorInvalidValue`.
 *
 * Used this way, the cudaFuncAttributeMaxDynamicSharedMemorySize can only grow and thus
 * guarantees that the kernel is safe to launch.
 *
 * @tparam KernelT The type of the kernel.
 * @tparam KernelLauncherT The type of the kernel launch function/lambda.
 * @param kernel The kernel function address (for which the smem-size is specified).
 * @param smem_size The size of the dynamic shared memory to be set.
 * @param launch The kernel launch function/lambda.
 */
template <typename KernelT, typename KernelLauncherT>
void safely_launch_kernel_with_smem_size(KernelT const& kernel,
                                         uint32_t smem_size,
                                         KernelLauncherT const& launch)
{
  // the last smem size is parameterized by the kernel thanks to the template parameter.
  // (Each distinct <KernelT, KernelLauncherT> instantiation gets its own copy of this static, so
  // the recorded high-water mark is per kernel/launcher combination, not global.)
  static std::atomic<uint32_t> current_smem_size{0};
  // Fast path: a relaxed load is sufficient here, because the value only ever increases and a
  // stale (smaller) read merely sends us into the slow path below.
  auto last_smem_size = current_smem_size.load(std::memory_order_relaxed);
  if (smem_size > last_smem_size) {
    // We still need a mutex for the critical section: actualize last_smem_size and set the
    // attribute.
    static auto mutex = std::mutex{};
    auto guard        = std::lock_guard<std::mutex>{mutex};
    // On CAS failure, `last_smem_size` is reloaded with the value written by a competing thread
    // between our relaxed load above and the mutex acquisition.
    if (!current_smem_size.compare_exchange_strong(
          last_smem_size, smem_size, std::memory_order_relaxed, std::memory_order_relaxed)) {
      // The value has been updated by another thread between the load and the mutex acquisition.
      if (smem_size > last_smem_size) {
        current_smem_size.store(smem_size, std::memory_order_relaxed);
      }
    }
    // Only update if the last seen value is smaller than the new one.
    // (After a successful CAS, `last_smem_size` still holds the pre-update value, so this
    // condition is true; after a failed CAS it is true only if we just stored a larger value.)
    if (smem_size > last_smem_size) {
      auto launch_status =
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
      RAFT_EXPECTS(launch_status == cudaSuccess,
                   "Failed to set max dynamic shared memory size to %u bytes",
                   smem_size);
    }
  }
  // We don't need to guard the kernel launch because the smem_size can only grow.
  return launch(kernel);
}
|
|
||
| } // namespace cuvs::neighbors::detail | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
126 changes: 126 additions & 0 deletions
126
cpp/tests/neighbors/ann_cagra/bug_issue_93_reproducer.cu
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| * | ||
| * Reproducer for https://github.com/rapidsai/cuvs-lucene/issues/93 | ||
| * cuvsCagraSearch returned 0 (Reason=cudaErrorInvalidValue:invalid argument) | ||
| * | ||
| * ROOT CAUSE: | ||
| * `cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)` | ||
| * is not thread-safe. It sets a CUDA-context-wide attribute. When two threads call it | ||
| * concurrently with different smem_size values, the following race occurs: | ||
| * 1. Thread A sets max-dynamic-shared-mem to SIZE_A (larger). | ||
| * 2. Thread B overwrites it with SIZE_B (smaller). | ||
| * 3. Thread A launches its kernel requesting SIZE_A of shared memory, | ||
| * but the CUDA context now only allows SIZE_B → cudaErrorInvalidValue. | ||
| * | ||
| * HOW IT MANIFESTS IN cuvs-lucene: | ||
| * Lucene's TaskExecutor dispatches per-segment CAGRA searches to a thread pool. | ||
| * Each segment has a different number of vectors (e.g. 25, 26, 47), leading to | ||
| * different graph degrees after reduction, and therefore different smem_size values | ||
| * in the single-CTA search kernel. The concurrent cudaFuncSetAttribute calls race. | ||
| * | ||
| * REPRODUCTION STRATEGY: | ||
| * Build CAGRA indices with different dataset sizes (different graph degrees), | ||
| * then search them concurrently from separate threads, each with its own raft::resources. | ||
| * This mirrors the cuvs-lucene setup where each thread gets a ThreadLocal CuVSResources. | ||
| */ | ||
|
|
||
| #include <gtest/gtest.h> | ||
|
|
||
| #include <cuvs/distance/distance.hpp> | ||
| #include <cuvs/neighbors/cagra.hpp> | ||
| #include <raft/core/device_mdarray.hpp> | ||
| #include <raft/core/device_mdspan.hpp> | ||
| #include <raft/core/device_resources.hpp> | ||
| #include <raft/core/resource/cuda_stream.hpp> | ||
| #include <raft/random/rng.cuh> | ||
|
|
||
| #include <cstdint> | ||
| #include <mutex> | ||
| #include <string> | ||
| #include <thread> | ||
| #include <vector> | ||
|
|
||
| namespace cuvs::neighbors::cagra { | ||
|
|
||
| // NOLINTNEXTLINE(readability-identifier-naming) | ||
| TEST(Issue93Reproducer, ConcurrentSearchDifferentGraphDegrees) | ||
| { | ||
| raft::resources handle; | ||
| raft::random::RngState rng(6181234567890123459ULL); | ||
|
|
||
| // Dataset sizes from REPRODUCER.md warnings (different sizes → different graph degrees). | ||
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) | ||
| std::vector<int> dataset_sizes = {25, 26, 47, 25}; | ||
| constexpr int dim = 64; | ||
| constexpr int top_k = 10; | ||
|
|
||
| // Build indices on the main thread. | ||
| std::vector<cagra::index<float, uint32_t>> indices; | ||
| for (int n_rows : dataset_sizes) { | ||
| auto database = raft::make_device_matrix<float, int64_t>(handle, n_rows, dim); | ||
| raft::random::uniform( | ||
| handle, rng, database.data_handle(), n_rows * dim, -1.0F, 1.0F); // NOLINT | ||
|
|
||
| cagra::index_params ip; | ||
| ip.metric = cuvs::distance::DistanceType::L2Expanded; | ||
| ip.intermediate_graph_degree = 128; // NOLINT | ||
| ip.graph_degree = 64; // NOLINT | ||
| ip.graph_build_params = | ||
| graph_build_params::nn_descent_params(ip.intermediate_graph_degree, ip.metric); | ||
|
|
||
| indices.push_back(cagra::build(handle, ip, raft::make_const_mdspan(database.view()))); | ||
| } | ||
| raft::resource::sync_stream(handle); | ||
|
|
||
| // Search concurrently from multiple threads until the first failure. | ||
| const int num_threads = static_cast<int>(indices.size()); | ||
| std::mutex error_mutex; | ||
| std::string first_error; | ||
|
|
||
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) | ||
| for (int iter = 0; iter < 50 && first_error.empty(); ++iter) { | ||
| std::vector<std::thread> threads; | ||
| for (int t = 0; t < num_threads; ++t) { | ||
| threads.emplace_back([&, t, iter]() { | ||
| raft::resources thread_handle; | ||
| raft::random::RngState thread_rng(42ULL + static_cast<uint64_t>(t) + | ||
| static_cast<uint64_t>(iter) * 1000ULL); | ||
| try { | ||
| auto query = raft::make_device_matrix<float, int64_t>(thread_handle, 1, dim); | ||
| raft::random::uniform(thread_handle, thread_rng, query.data_handle(), dim, -1.0F, 1.0F); | ||
|
|
||
| // Match cuvs-lucene params: Java's Panama zero-initializes the struct, | ||
| // and SINGLE_CTA = 0 in the enum, so algo is SINGLE_CTA. | ||
| cagra::search_params sp; | ||
| sp.itopk_size = top_k; | ||
| sp.search_width = 1; | ||
| sp.algo = search_algo::SINGLE_CTA; | ||
|
|
||
| auto neighbors = raft::make_device_matrix<uint32_t, int64_t>(thread_handle, 1, top_k); | ||
| auto distances = raft::make_device_matrix<float, int64_t>(thread_handle, 1, top_k); | ||
|
|
||
| cagra::search(thread_handle, | ||
| sp, | ||
| indices[static_cast<size_t>(t)], | ||
| raft::make_const_mdspan(query.view()), | ||
| neighbors.view(), | ||
| distances.view()); | ||
|
|
||
| raft::resource::sync_stream(thread_handle); | ||
| } catch (const std::exception& e) { | ||
| std::lock_guard<std::mutex> lock(error_mutex); | ||
| if (first_error.empty()) { first_error = e.what(); } | ||
| } | ||
| }); | ||
| } | ||
| for (auto& th : threads) { | ||
| th.join(); | ||
| } | ||
| } | ||
|
|
||
| ASSERT_TRUE(first_error.empty()) << first_error; | ||
| } | ||
|
|
||
| } // namespace cuvs::neighbors::cagra |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.