diff --git a/include/cuco/detail/open_addressing/kernels.cuh b/include/cuco/detail/open_addressing/kernels.cuh index 266335a50..24fce230c 100644 --- a/include/cuco/detail/open_addressing/kernels.cuh +++ b/include/cuco/detail/open_addressing/kernels.cuh @@ -182,6 +182,46 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void erase(InputIt first, } } +/** + * @brief For each key in the range [first, first + n), applies the function object `callback_op` to + * the copy of all corresponding matches found in the container. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam CGSize Number of threads in each CG + * @tparam BlockSize Number of threads in each block + * @tparam InputIt Device accessible input iterator whose `value_type` is + * convertible to the `key_type` of the data structure + * @tparam CallbackOp Type of unary callback function object + * @tparam Ref Type of non-owning device ref allowing access to storage + * + * @param first Beginning of the sequence of input elements + * @param n Number of input elements + * @param callback_op Function to call on every matched slot found in the container + * @param ref Non-owning container device ref used to access the slot storage + */ +template +CUCO_KERNEL __launch_bounds__(BlockSize) void for_each_n(InputIt first, + cuco::detail::index_type n, + CallbackOp callback_op, + Ref ref) +{ + auto const loop_stride = cuco::detail::grid_stride() / CGSize; + auto idx = cuco::detail::global_thread_id() / CGSize; + + while (idx < n) { + typename std::iterator_traits::value_type const& key{*(first + idx)}; + if constexpr (CGSize == 1) { + ref.for_each(key, callback_op); + } else { + auto const tile = + cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); + ref.for_each(tile, key, callback_op); + } + idx += loop_stride; + } +} + /** * @brief Indicates whether the keys in the range `[first, first + n)` are contained in the data * structure if `pred` of the corresponding 
stencil returns true. diff --git a/include/cuco/detail/open_addressing/open_addressing_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_impl.cuh index 9dabff990..f9c35e0ff 100644 --- a/include/cuco/detail/open_addressing/open_addressing_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_impl.cuh @@ -27,6 +27,7 @@ #include #include +#include #include #include #include @@ -681,6 +682,67 @@ class open_addressing_impl { return output_begin + h_num_out; } + /** + * @brief Asynchronously applies the given function object `callback_op` to the copy of every + * filled slot in the container + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam CallbackOp Type of unary callback function object + * + * @param callback_op Function to call on every filled slot in the container + * @param stream CUDA stream used for this operation + */ + template + void for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream) const + { + auto const is_filled = open_addressing_ns::detail::slot_is_filled{ + this->empty_key_sentinel(), this->erased_key_sentinel()}; + + auto storage_ref = this->storage_ref(); + auto const op = [callback_op, is_filled] __device__(auto const window_slots) { + for (auto const slot : window_slots) { + if (is_filled(slot)) { callback_op(slot); } + } + }; + + CUCO_CUDA_TRY(cub::DeviceFor::ForEachCopyN( + storage_ref.data(), storage_ref.num_windows(), op, stream.get())); + } + + /** + * @brief For each key in the range [first, last), asynchronously applies the function object + * `callback_op` to the copy of all corresponding matches found in the container. + * + * @note The return value of `callback_op`, if any, is ignored. 
+ * + * @tparam InputIt Device accessible random access input iterator + * @tparam CallbackOp Type of unary callback function object + * @tparam Ref Type of non-owning device container ref allowing access to storage + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param callback_op Function to call on every match found in the container + * @param container_ref Non-owning device container ref used to access the slot storage + * @param stream CUDA stream used for this operation + */ + template + void for_each_async(InputIt first, + InputIt last, + CallbackOp&& callback_op, + Ref container_ref, + cuda::stream_ref stream) const noexcept + { + auto const num_keys = cuco::detail::distance(first, last); + if (num_keys == 0) { return; } + + auto const grid_size = cuco::detail::grid_size(num_keys, cg_size); + + detail::for_each_n + <<>>( + first, num_keys, std::forward(callback_op), container_ref); + } + /** * @brief Gets the number of elements in the container * diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 0be26e482..12a306a71 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -966,17 +966,16 @@ class open_addressing_ref_impl { } /** - * @brief Executes a callback on every element in the container with key equivalent to the probe - * key. + * @brief For a given key, applies the function object `callback_op` to the copy of all + * corresponding matches found in the container. * - * @note Passes an un-incrementable input iterator to the element whose key is equivalent to - * `key` to the callback. + * @note The return value of `callback_op`, if any, is ignored. 
* * @tparam ProbeKey Probe key type - * @tparam CallbackOp Unary callback functor or device lambda + * @tparam CallbackOp Type of unary callback function object * * @param key The key to search for - * @param callback_op Function to call on every element found + * @param callback_op Function to apply to every matched slot */ template __device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept @@ -995,7 +994,7 @@ class open_addressing_ref_impl { return; } case detail::equal_result::EQUAL: { - callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); + callback_op(window_slots[i]); continue; } default: continue; @@ -1006,24 +1005,23 @@ class open_addressing_ref_impl { } /** - * @brief Executes a callback on every element in the container with key equivalent to the probe - * key. - * - * @note Passes an un-incrementable input iterator to the element whose key is equivalent to - * `key` to the callback. + * @brief For a given key, applies the function object `callback_op` to the copy of all + * corresponding matches found in the container. * * @note This function uses cooperative group semantics, meaning that any thread may call the * callback if it finds a matching element. If multiple elements are found within the same group, * each thread with a match will call the callback with its associated element. * + * @note The return value of `callback_op`, if any, is ignored. + * * @note Synchronizing `group` within `callback_op` is undefined behavior. 
* * @tparam ProbeKey Probe key type - * @tparam CallbackOp Unary callback functor or device lambda + * @tparam CallbackOp Type of unary callback function object * * @param group The Cooperative Group used to perform this operation * @param key The key to search for - * @param callback_op Function to call on every element found + * @param callback_op Function to apply to every matched slot */ template __device__ void for_each(cooperative_groups::thread_block_tile const& group, @@ -1045,7 +1043,7 @@ class open_addressing_ref_impl { continue; } case detail::equal_result::EQUAL: { - callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); + callback_op(window_slots[i]); continue; } default: { @@ -1060,12 +1058,9 @@ class open_addressing_ref_impl { } /** - * @brief Executes a callback on every element in the container with key equivalent to the probe - * key and can additionally perform work that requires synchronizing the Cooperative Group - * performing this operation. - * - * @note Passes an un-incrementable input iterator to the element whose key is equivalent to - * `key` to the callback. + * @brief Applies the function object `callback_op` to the copy of every slot in the container + * with key equivalent to the probe key and can additionally perform work that requires + * synchronizing the Cooperative Group performing this operation. * * @note This function uses cooperative group semantics, meaning that any thread may call the * callback if it finds a matching element. If multiple elements are found within the same group, @@ -1073,18 +1068,20 @@ class open_addressing_ref_impl { * * @note Synchronizing `group` within `callback_op` is undefined behavior. * + * @note The return value of `callback_op`, if any, is ignored. 
+ * * @note The `sync_op` function can be used to perform work that requires synchronizing threads in * `group` inbetween probing steps, where the number of probing steps performed between * synchronization points is capped by `window_size * cg_size`. The functor will be called right * after the current probing window has been traversed. * * @tparam ProbeKey Probe key type - * @tparam CallbackOp Unary callback functor or device lambda - * @tparam SyncOp Functor or device lambda which accepts the current `group` object + * @tparam CallbackOp Type of unary callback function object + * @tparam SyncOp Type of function object which accepts the current `group` object * * @param group The Cooperative Group used to perform this operation * @param key The key to search for - * @param callback_op Function to call on every element found + * @param callback_op Function to apply to every matched slot * @param sync_op Function that is allowed to synchronize `group` inbetween probing windows */ template @@ -1108,7 +1105,7 @@ class open_addressing_ref_impl { continue; } case detail::equal_result::EQUAL: { - callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); + callback_op(window_slots[i]); continue; } default: { diff --git a/include/cuco/detail/static_map/static_map.inl b/include/cuco/detail/static_map/static_map.inl index 86b75507d..e575114de 100644 --- a/include/cuco/detail/static_map/static_map.inl +++ b/include/cuco/detail/static_map/static_map.inl @@ -499,6 +499,70 @@ void static_mapfind_async(first, last, output_begin, ref(op::find), stream); } +template +template +void static_map::for_each( + CallbackOp&& callback_op, cuda::stream_ref stream) const +{ + impl_->for_each_async(std::forward(callback_op), stream); + stream.wait(); +} + +template +template +void static_map::for_each_async( + CallbackOp&& callback_op, cuda::stream_ref stream) const +{ + impl_->for_each_async(std::forward(callback_op), stream); +} + +template +template +void 
static_map::for_each( + InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const +{ + impl_->for_each_async( + first, last, std::forward(callback_op), ref(op::for_each), stream); + stream.wait(); +} + +template +template +void static_map::for_each_async( + InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const noexcept +{ + impl_->for_each_async( + first, last, std::forward(callback_op), ref(op::for_each), stream); +} + template +class operator_impl< + op::for_each_tag, + static_map_ref> { + using base_type = static_map_ref; + using ref_type = static_map_ref; + using key_type = typename base_type::key_type; + using value_type = typename base_type::value_type; + using iterator = typename base_type::iterator; + using const_iterator = typename base_type::const_iterator; + + static constexpr auto cg_size = base_type::cg_size; + static constexpr auto window_size = base_type::window_size; + + public: + /** + * @brief For a given key, applies the function object `callback_op` to its match found in the + * container. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Type of unary callback function object + * + * @param key The key to search for + * @param callback_op Function to apply to the copy of the matched key-value pair + */ + template + __device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(key, std::forward(callback_op)); + } + + /** + * @brief For a given key, applies the function object `callback_op` to its match found in the + * container. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching key-value pair. + * + * @note The return value of `callback_op`, if any, is ignored. 
+ * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Type of unary callback function object + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to apply to the copy of the matched key-value pair + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + CallbackOp&& callback_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(group, key, std::forward(callback_op)); + } +}; + } // namespace detail } // namespace cuco diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh index 9c87e45a9..01a39ad5d 100644 --- a/include/cuco/static_map.cuh +++ b/include/cuco/static_map.cuh @@ -762,6 +762,74 @@ class static_map { OutputIt output_begin, cuda::stream_ref stream = {}) const; + /** + * @brief Applies the given function object `callback_op` to the copy of every filled slot in the + * container, then waits on `stream` before returning + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam CallbackOp Type of unary callback function object + * + * @param callback_op Function to apply to the copy of every filled slot + * @param stream CUDA stream used for this operation + */ + template + void for_each(CallbackOp&& callback_op, cuda::stream_ref stream = {}) const; + + /** + * @brief Asynchronously applies the given function object `callback_op` to the copy of every + * filled slot in the container + * + * @note The return value of `callback_op`, if any, is ignored. 
+ * + * @tparam CallbackOp Type of unary callback function object + * + * @param callback_op Function to apply to the copy of every filled slot + * @param stream CUDA stream used for this operation + */ + template + void for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream = {}) const; + + /** + * @brief For each key in the range [first, last), applies the function object `callback_op` to + * the copy of all corresponding matches found in the container, then waits on `stream`. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam InputIt Device accessible random access input iterator + * @tparam CallbackOp Type of unary callback function object + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param callback_op Function to apply to the copy of the matched key-value pair + * @param stream CUDA stream used for this operation + */ + template + void for_each(InputIt first, + InputIt last, + CallbackOp&& callback_op, + cuda::stream_ref stream = {}) const; + + /** + * @brief For each key in the range [first, last), asynchronously applies the function object + * `callback_op` to the copy of all corresponding matches found in the container. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam InputIt Device accessible random access input iterator + * @tparam CallbackOp Type of unary callback function object + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param callback_op Function to apply to the copy of the matched key-value pair + * @param stream CUDA stream used for this operation + */ + template + void for_each_async(InputIt first, + InputIt last, + CallbackOp&& callback_op, + cuda::stream_ref stream = {}) const noexcept; + /** * @brief Retrieves all of the keys and their associated values. 
* diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index dde1317b0..dc610af5b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -77,6 +77,7 @@ ConfigureTest(STATIC_MAP_TEST static_map/custom_type_test.cu static_map/duplicate_keys_test.cu static_map/erase_test.cu + static_map/for_each_test.cu static_map/hash_test.cu static_map/heterogeneous_lookup_test.cu static_map/insert_and_find_test.cu diff --git a/tests/static_map/for_each_test.cu b/tests/static_map/for_each_test.cu new file mode 100644 index 000000000..1c72a2e58 --- /dev/null +++ b/tests/static_map/for_each_test.cu @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +using size_type = std::size_t; + +template +void test_for_each(Map& map, size_type num_keys) +{ + using Key = typename Map::key_type; + using Value = typename Map::mapped_type; + + REQUIRE(num_keys % 2 == 0); + + // Insert pairs + auto pairs_begin = thrust::make_transform_iterator( + thrust::counting_iterator(0), + cuda::proclaim_return_type>([] __device__(auto i) { + // use payload as 1 for even keys and 2 for odd keys + return cuco::pair{i, i % 2 + 1}; + })); + + cuda::stream_ref stream{}; + + map.insert(pairs_begin, pairs_begin + num_keys, stream); + + using Allocator = cuco::cuda_allocator>; + cuco::detail::counter_storage counter_storage( + Allocator{}); + counter_storage.reset(stream); + + // count all the keys which are even and whose payload has value 1 + map.for_each( + [counter = counter_storage.data()] __device__(auto const slot) { + auto const& [key, value] = slot; + if (((key % 2 == 0)) and (value == 1)) { counter->fetch_add(1, cuda::memory_order_relaxed); } + }, + stream); + + auto const res = counter_storage.load_to_host(stream); + REQUIRE(res == num_keys / 2); + + counter_storage.reset(stream); + + map.for_each( + thrust::counting_iterator(0), + thrust::counting_iterator(2 * num_keys), // test for false-positives + [counter = counter_storage.data()] __device__(auto const slot) { + auto const& [key, value] = slot; + if (((key % 2 == 0)) and (value == 1)) { counter->fetch_add(1, cuda::memory_order_relaxed); } + }, + stream); + REQUIRE(counter_storage.load_to_host(stream) == num_keys / 2); // reload: `res` holds the first pass +} + +TEMPLATE_TEST_CASE_SIG( + "static_map for_each tests", + "", + ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), + Key, + Value, + Probe, + CGSize), + (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, 
int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2)) +{ + constexpr size_type num_keys{100}; + using probe = std::conditional_t< + Probe == cuco::test::probe_sequence::linear_probing, + cuco::linear_probing>, + cuco::double_hashing, cuco::murmurhash3_32>>; + + using map_type = cuco::static_map, + cuda::thread_scope_device, + thrust::equal_to, + probe, + cuco::cuda_allocator, + cuco::storage<2>>; + + auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; + test_for_each(map, num_keys); +} diff --git a/tests/static_multiset/for_each_test.cu b/tests/static_multiset/for_each_test.cu index 1872586b7..b987ba660 100644 --- a/tests/static_multiset/for_each_test.cu +++ b/tests/static_multiset/for_each_test.cu @@ -45,8 +45,8 @@ CUCO_KERNEL void for_each_check_scalar(Ref ref, while (idx < n) { auto const& key = *(first + idx); std::size_t matches = 0; - ref.for_each(key, [&] __device__(auto const it) { - if (ref.key_eq()(key, *it)) { matches++; } + ref.for_each(key, [&] __device__(auto const slot) { + if (ref.key_eq()(key, slot)) { matches++; } }); if (matches != multiplicity) { error_counter->fetch_add(1, cuda::memory_order_relaxed); } idx 
+= loop_stride; @@ -73,13 +73,13 @@ CUCO_KERNEL void for_each_check_cooperative(Ref ref, ref.for_each( tile, key, - [&] __device__(auto const it) { - if (ref.key_eq()(key, *it)) { thread_matches++; } + [&] __device__(auto const slot) { + if (ref.key_eq()(key, slot)) { thread_matches++; } }, [] __device__(auto const& group) { group.sync(); }); } else { - ref.for_each(tile, key, [&] __device__(auto const it) { - if (ref.key_eq()(key, *it)) { thread_matches++; } + ref.for_each(tile, key, [&] __device__(auto const slot) { + if (ref.key_eq()(key, slot)) { thread_matches++; } }); } auto const tile_matches =