Skip to content

Commit 5571cd8

Browse files
Backport to 2.8: B200 tunings for histogram (#3728)
* Add b200 tunings for histogram (#3616) Co-authored-by: Giannis Gonidelis <ggonidelis@nvidia.com> * Fix SM100 histogram tunings (#3691) The tuning data member names did not match the one used when selecting tunings, so all SM100 tunings were SFINAE-ed out. Also drop tunings with no benefit. --------- Co-authored-by: Giannis Gonidelis <ggonidelis@nvidia.com>
1 parent c8bda1a commit 5571cd8

File tree

2 files changed

+104
-12
lines changed

2 files changed

+104
-12
lines changed

cub/cub/device/dispatch/dispatch_histogram.cuh

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737

3838
#include <cub/config.cuh>
3939

40+
#include <cuda/std/__type_traits/is_void.h>
41+
4042
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
4143
# pragma GCC system_header
4244
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
@@ -554,8 +556,7 @@ template <int NUM_CHANNELS,
554556
typename CounterT,
555557
typename LevelT,
556558
typename OffsetT,
557-
typename PolicyHub =
558-
detail::histogram::policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS>>
559+
typename PolicyHub = void> // if user passes a custom Policy this should not be void
559560
struct DispatchHistogram
560561
{
561562
static_assert(NUM_CHANNELS <= 4, "Histograms only support up to 4 channels");
@@ -920,8 +921,14 @@ public:
920921
cudaStream_t stream,
921922
::cuda::std::false_type /*is_byte_sample*/)
922923
{
923-
using MaxPolicyT = typename PolicyHub::MaxPolicy;
924-
cudaError error = cudaSuccess;
924+
// Should we call DispatchHistogram<....., PolicyHub=void> in DeviceHistogram?
925+
static constexpr bool isEven = 0;
926+
using fallback_policy_hub = detail::histogram::
927+
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;
928+
929+
using MaxPolicyT =
930+
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
931+
cudaError error = cudaSuccess;
925932

926933
do
927934
{
@@ -1124,8 +1131,13 @@ public:
11241131
cudaStream_t stream,
11251132
::cuda::std::true_type /*is_byte_sample*/)
11261133
{
1127-
using MaxPolicyT = typename PolicyHub::MaxPolicy;
1128-
cudaError error = cudaSuccess;
1134+
static constexpr bool isEven = 0;
1135+
using fallback_policy_hub = detail::histogram::
1136+
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;
1137+
1138+
using MaxPolicyT =
1139+
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
1140+
cudaError error = cudaSuccess;
11291141

11301142
do
11311143
{
@@ -1292,8 +1304,13 @@ public:
12921304
cudaStream_t stream,
12931305
::cuda::std::false_type /*is_byte_sample*/)
12941306
{
1295-
using MaxPolicyT = typename PolicyHub::MaxPolicy;
1296-
cudaError error = cudaSuccess;
1307+
static constexpr bool isEven = 1;
1308+
using fallback_policy_hub = detail::histogram::
1309+
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;
1310+
1311+
using MaxPolicyT =
1312+
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
1313+
cudaError error = cudaSuccess;
12971314

12981315
do
12991316
{
@@ -1513,8 +1530,13 @@ public:
15131530
cudaStream_t stream,
15141531
::cuda::std::true_type /*is_byte_sample*/)
15151532
{
1516-
using MaxPolicyT = typename PolicyHub::MaxPolicy;
1517-
cudaError error = cudaSuccess;
1533+
static constexpr bool isEven = 1;
1534+
using fallback_policy_hub = detail::histogram::
1535+
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;
1536+
1537+
using MaxPolicyT =
1538+
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
1539+
cudaError error = cudaSuccess;
15181540

15191541
do
15201542
{

cub/cub/device/dispatch/tuning/tuning_histogram.cuh

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ enum class sample_size
6060
{
6161
_1,
6262
_2,
63+
_4,
64+
_8,
6365
unknown
6466
};
6567

@@ -125,7 +127,52 @@ struct sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sampl
125127
static constexpr bool work_stealing = false;
126128
};
127129

128-
template <class SampleT, class CounterT, int NumChannels, int NumActiveChannels>
130+
template <bool IsEven,
131+
class SampleT,
132+
int NumChannels,
133+
int NumActiveChannels,
134+
counter_size CounterSize,
135+
primitive_sample PrimitiveSample = is_primitive_sample<SampleT>(),
136+
sample_size SampleSize = classify_sample_size<SampleT>()>
137+
struct sm100_tuning;
138+
139+
// even
140+
template <class SampleT>
141+
struct sm100_tuning<true, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_1>
142+
{
143+
// ipt_12.tpb_928.rle_0.ws_0.mem_1.ld_2.laid_0.vec_2 1.033332 0.940517 1.031835 1.195876
144+
static constexpr int items = 12;
145+
static constexpr int threads = 928;
146+
static constexpr bool rle_compress = false;
147+
static constexpr bool work_stealing = false;
148+
static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
149+
static constexpr CacheLoadModifier load_modifier = LOAD_CA;
150+
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
151+
static constexpr int vec_size = 1 << 2;
152+
};
153+
154+
// sample_size 2/4/8 showed no benefit over SM90 during verification benchmarks
155+
156+
// range
157+
template <class SampleT>
158+
struct sm100_tuning<false, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_1>
159+
{
160+
// ipt_12.tpb_448.rle_0.ws_0.mem_1.ld_1.laid_0.vec_2 1.078987 0.985542 1.085118 1.175637
161+
static constexpr int items = 12;
162+
static constexpr int threads = 448;
163+
static constexpr bool rle_compress = false;
164+
static constexpr bool work_stealing = false;
165+
static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
166+
static constexpr CacheLoadModifier load_modifier = LOAD_LDG;
167+
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
168+
static constexpr int vec_size = 1 << 2;
169+
};
170+
171+
// sample_size 2/4/8 showed no benefit over SM90 during verification benchmarks
172+
173+
// multi.even and multi.range: none of the found tunings surpassed the SM90 tuning during verification benchmarks
174+
175+
template <class SampleT, class CounterT, int NumChannels, int NumActiveChannels, bool IsEven>
129176
struct policy_hub
130177
{
131178
// TODO(bgruber): move inside t_scale in C++14
@@ -173,7 +220,30 @@ struct policy_hub
173220
sm90_tuning<SampleT, NumChannels, NumActiveChannels, histogram::classify_counter_size<CounterT>()>>(0));
174221
};
175222

176-
using MaxPolicy = Policy900;
223+
struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900>
224+
{
225+
// Use values from tuning if a specialization exists, otherwise pick Policy900
226+
template <typename Tuning>
227+
static auto select_agent_policy(int)
228+
-> AgentHistogramPolicy<Tuning::threads,
229+
Tuning::items,
230+
Tuning::load_algorithm,
231+
Tuning::load_modifier,
232+
Tuning::rle_compress,
233+
Tuning::mem_preference,
234+
Tuning::work_stealing,
235+
Tuning::vec_size>;
236+
237+
template <typename Tuning>
238+
static auto select_agent_policy(long) -> typename Policy900::AgentHistogramPolicyT;
239+
240+
using AgentHistogramPolicyT =
241+
decltype(select_agent_policy<
242+
sm100_tuning<IsEven, SampleT, NumChannels, NumActiveChannels, histogram::classify_counter_size<CounterT>()>>(
243+
0));
244+
};
245+
246+
using MaxPolicy = Policy1000;
177247
};
178248
} // namespace histogram
179249
} // namespace detail

0 commit comments

Comments
 (0)