forked from NVIDIA/cuCollections
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprobing_scheme_impl.inl
More file actions
174 lines (157 loc) · 5.51 KB
/
probing_scheme_impl.inl
File metadata and controls
174 lines (157 loc) · 5.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuco/detail/utils.cuh>
namespace cuco {
namespace detail {
/**
 * @brief Probing iterator class.
 *
 * Starting from `start`, every increment adds `step_size` to the current index and reduces the
 * sum modulo `upper_bound`.
 *
 * @tparam Extent Type of Extent
 */
template <typename Extent>
class probing_iterator {
 public:
  using extent_type = Extent;                            ///< Extent type
  using size_type   = typename extent_type::value_type;  ///< Size type

  /**
   * @brief Constructs a probing iterator
   *
   * @param start Iteration starting point (stored as-is; it is not reduced by the constructor)
   * @param step_size Distance added to the current index on every increment (e.g. 1 or the CG
   * size for linear probing, a key-dependent value for double hashing)
   * @param upper_bound Upper bound of the iteration; each incremented index is reduced modulo it
   */
  __host__ __device__ constexpr probing_iterator(size_type start,
                                                 size_type step_size,
                                                 extent_type upper_bound) noexcept
    : curr_index_{start}, step_size_{step_size}, upper_bound_{upper_bound}
  {
    // TODO: revise this API when introducing quadratic probing into cuco
  }

  /**
   * @brief Dereference operator
   *
   * @return Current slot index
   */
  __host__ __device__ constexpr auto operator*() const noexcept { return curr_index_; }

  /**
   * @brief Prefix increment operator: advances to the next probing position
   *
   * @return Reference to this iterator after the increment
   */
  __host__ __device__ constexpr probing_iterator& operator++() noexcept
  {
    // TODO: step_size_ can be a build time constant (e.g. linear probing)
    // Worth passing another extent type?
    // NOTE(review): the sum is formed before the modulo, so it can overflow size_type when
    // upper_bound_ is close to the maximum value of size_type.
    curr_index_ = (curr_index_ + step_size_) % upper_bound_;
    return *this;
  }

  /**
   * @brief Postfix increment operator
   *
   * @return Copy of the iterator taken before the increment
   */
  __host__ __device__ constexpr probing_iterator operator++(int) noexcept
  {
    auto temp = *this;
    ++(*this);
    return temp;
  }

 private:
  size_type curr_index_;     ///< Current probing position
  size_type step_size_;      ///< Amount added to the position on each increment
  extent_type upper_bound_;  ///< Modulus applied after each increment
};
} // namespace detail
/**
 * @brief Builds a linear probing scheme holding a copy of the given hasher.
 *
 * @param hash Hasher used to derive the initial probing position of a key
 */
template <int32_t CGSize, typename Hash>
__host__ __device__ constexpr linear_probing<CGSize, Hash>::linear_probing(Hash const& hash)
  : hash_{hash}
{
  // Nothing else to initialize: the hasher is the only state.
}
/**
 * @brief Produces a linear probing scheme with the same CG size but a different hasher.
 *
 * @tparam NewHash Type of the replacement hasher
 * @param hash Replacement hasher
 * @return A `linear_probing<cg_size, NewHash>` built from `hash`
 */
template <int32_t CGSize, typename Hash>
template <typename NewHash>
__host__ __device__ constexpr auto linear_probing<CGSize, Hash>::with_hash_function(
  NewHash const& hash) const noexcept
{
  using rebound_scheme = linear_probing<cg_size, NewHash>;
  return rebound_scheme{hash};
}
/**
 * @brief Creates the probing iterator for a key (single-thread variant).
 *
 * @param probe_key Key whose probing sequence is requested
 * @param upper_bound Extent the probing indices are reduced against
 * @return Iterator starting at the sanitized key hash modulo `upper_bound`, advancing by 1
 */
template <int32_t CGSize, typename Hash>
template <typename ProbeKey, typename Extent>
__host__ __device__ constexpr auto linear_probing<CGSize, Hash>::operator()(
  ProbeKey const& probe_key, Extent upper_bound) const noexcept
{
  using size_type     = typename Extent::value_type;
  using iterator_type = detail::probing_iterator<Extent>;

  // Initial slot: sanitized hash of the key, reduced modulo upper_bound.
  auto const start_slot = cuco::detail::sanitize_hash<size_type>(hash_(probe_key)) % upper_bound;

  // Linear probing visits neighbouring slots one at a time.
  constexpr size_type unit_step = 1;
  return iterator_type{start_slot, unit_step, upper_bound};
}
/**
 * @brief Creates the probing iterator for a key (cooperative-group variant).
 *
 * @param g Thread tile performing the probe; forwarded to `sanitize_hash`
 * @param probe_key Key whose probing sequence is requested
 * @param upper_bound Extent the probing indices are reduced against
 * @return Iterator whose every increment adds `cg_size` to the current index
 */
template <int32_t CGSize, typename Hash>
template <typename ProbeKey, typename Extent>
__host__ __device__ constexpr auto linear_probing<CGSize, Hash>::operator()(
  cooperative_groups::thread_block_tile<cg_size> const& g,
  ProbeKey const& probe_key,
  Extent upper_bound) const noexcept
{
  using size_type     = typename Extent::value_type;
  using iterator_type = detail::probing_iterator<Extent>;

  // The tile is passed to sanitize_hash along with the key hash; the result is reduced
  // modulo upper_bound to obtain this thread's starting slot.
  auto const start_slot =
    cuco::detail::sanitize_hash<size_type>(g, hash_(probe_key)) % upper_bound;

  // Every advance moves by the width of the cooperative group.
  return iterator_type{start_slot, cg_size, upper_bound};
}
/**
 * @brief Builds a double hashing scheme holding copies of both hashers.
 *
 * @param hash1 Hasher used to derive the initial probing position
 * @param hash2 Hasher used to derive the probing step size
 */
template <int32_t CGSize, typename Hash1, typename Hash2>
__host__ __device__ constexpr double_hashing<CGSize, Hash1, Hash2>::double_hashing(
  Hash1 const& hash1, Hash2 const& hash2)
  : hash1_{hash1},
    hash2_{hash2}
{
  // Both hashers are the only state; nothing else to set up.
}
/**
 * @brief Produces a double hashing scheme with the same CG size but different hashers.
 *
 * @tparam NewHash1 Type of the replacement first hasher
 * @tparam NewHash2 Type of the replacement second hasher
 * @param hash1 Replacement hasher for the initial position
 * @param hash2 Replacement hasher for the step size
 * @return A `double_hashing<cg_size, NewHash1, NewHash2>` built from the two hashers
 */
template <int32_t CGSize, typename Hash1, typename Hash2>
template <typename NewHash1, typename NewHash2>
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::with_hash_function(
  NewHash1 const& hash1, NewHash2 const& hash2) const noexcept
{
  using rebound_scheme = double_hashing<cg_size, NewHash1, NewHash2>;
  return rebound_scheme{hash1, hash2};
}
/**
 * @brief Creates the probing iterator for a key (single-thread variant).
 *
 * @param probe_key Key whose probing sequence is requested
 * @param upper_bound Extent the probing indices are reduced against
 * @return Iterator starting at `hash1(key)` modulo `upper_bound` and stepping by a
 * key-dependent amount derived from `hash2(key)`
 */
template <int32_t CGSize, typename Hash1, typename Hash2>
template <typename ProbeKey, typename Extent>
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::operator()(
  ProbeKey const& probe_key, Extent upper_bound) const noexcept
{
  using size_type = typename Extent::value_type;

  // Initial slot comes from the first hasher.
  auto const start_slot =
    cuco::detail::sanitize_hash<size_type>(hash1_(probe_key)) % upper_bound;

  // Step comes from the second hasher, clamped so it is never zero. The resulting range is
  // [1, upper_bound - 1], i.e. [1, prime - 1] when upper_bound is prime (presumably the
  // intended configuration -- confirm with the container's extent setup).
  auto const step =
    max(size_type{1}, cuco::detail::sanitize_hash<size_type>(hash2_(probe_key)) % upper_bound);

  return detail::probing_iterator<Extent>{start_slot, step, upper_bound};
}
/**
 * @brief Creates the probing iterator for a key (cooperative-group variant).
 *
 * The step is a whole number of `cg_size`-wide windows, drawn from `[1, num_windows - 1]`
 * where `num_windows = upper_bound / cg_size`. When there are fewer than two windows, that
 * range is empty; the step then falls back to a single window (`cg_size`) instead of taking a
 * modulus by zero.
 *
 * @param g Thread tile performing the probe; forwarded to `sanitize_hash` for the start slot
 * @param probe_key Key whose probing sequence is requested
 * @param upper_bound Extent the probing indices are reduced against
 * @return Iterator starting at the tile-sanitized `hash1(key)` modulo `upper_bound` and
 * stepping by a multiple of `cg_size` derived from `hash2(key)`
 */
template <int32_t CGSize, typename Hash1, typename Hash2>
template <typename ProbeKey, typename Extent>
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::operator()(
  cooperative_groups::thread_block_tile<cg_size> const& g,
  ProbeKey const& probe_key,
  Extent upper_bound) const noexcept
{
  using size_type = typename Extent::value_type;

  // Starting slot: the tile is passed to sanitize_hash along with the first hash.
  auto const start_slot =
    cuco::detail::sanitize_hash<size_type>(g, hash1_(probe_key)) % upper_bound;

  // The step is computed without `g`, so (presumably, given a tile-uniform probe_key) every
  // member of the tile strides by the same amount.
  auto const num_windows = static_cast<size_type>(upper_bound / cg_size);
  // Guard the modulus below. With one window, `num_windows - 1` is 0 (division by zero, UB).
  // With zero windows it is -1 or, for unsigned size_type, wraps to the maximum value, which
  // lets `* cg_size` overflow. In both cases use a single-window step instead.
  auto const window_range =
    num_windows > size_type{1} ? static_cast<size_type>(num_windows - 1) : size_type{1};
  auto const step_size = static_cast<size_type>(
    (cuco::detail::sanitize_hash<size_type>(hash2_(probe_key)) % window_range + 1) * cg_size);

  return detail::probing_iterator<Extent>{
    start_slot, step_size, upper_bound};  // TODO use fast_int operator
}
} // namespace cuco