Skip to content

Commit a79a4ee

Browse files
committed
Add multi-dimensional support to block_radix_sort routines.
1 parent 29f4d52 commit a79a4ee

File tree

2 files changed

+145
-91
lines changed

2 files changed

+145
-91
lines changed

python/cuda_cooperative/cuda/cooperative/experimental/block/_block_radix_sort.py

Lines changed: 124 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from typing import TYPE_CHECKING, Tuple, Union
6+
57
import numba
68

79
from cuda.cooperative.experimental._common import (
10+
CUB_BLOCK_SCAN_ALGOS,
11+
CudaSharedMemConfig,
12+
dim3,
813
make_binary_tempfile,
14+
normalize_dim_param,
915
normalize_dtype_param,
1016
)
1117
from cuda.cooperative.experimental._types import (
@@ -18,6 +24,112 @@
1824
Value,
1925
)
2026

27+
if TYPE_CHECKING:
28+
import numpy as np
29+
30+
31+
TEMPLATE_PARAMETERS = [
32+
TemplateParameter("KeyT"),
33+
TemplateParameter("BLOCK_DIM_X"),
34+
TemplateParameter("ITEMS_PER_THREAD"),
35+
TemplateParameter("ValueT"),
36+
TemplateParameter("RADIX_BITS"),
37+
TemplateParameter("MEMOIZE_OUTER_SCAN"),
38+
TemplateParameter("INNER_SCAN_ALGORITHM"),
39+
TemplateParameter("SMEM_CONFIG"),
40+
TemplateParameter("BLOCK_DIM_Y"),
41+
TemplateParameter("BLOCK_DIM_Z"),
42+
]
43+
44+
45+
METHOD_PARAMETERS_VARIANTS = [
46+
[
47+
Pointer(numba.uint8),
48+
DependentArray(Dependency("KeyT"), Dependency("ITEMS_PER_THREAD")),
49+
],
50+
[
51+
Pointer(numba.uint8),
52+
DependentArray(Dependency("KeyT"), Dependency("ITEMS_PER_THREAD")),
53+
Value(numba.int32),
54+
Value(numba.int32),
55+
],
56+
]
57+
58+
59+
# N.B. In order to support multi-dimensional block dimensions, we have to provide
60+
# defaults for all the template parameters preceding the final Y and
61+
# Z dimensions. This will be improved in the future, allowing users
62+
# to provide overrides for the default values.
63+
64+
TEMPLATE_PARAMETER_DEFAULTS = {
65+
"ValueT": "::cub::NullType", # Indicates keys-only sort
66+
"RADIX_BITS": 4,
67+
"MEMOIZE_OUTER_SCAN": "true",
68+
"INNER_SCAN_ALGORITHM": CUB_BLOCK_SCAN_ALGOS["warp_scans"],
69+
"SMEM_CONFIG": str(CudaSharedMemConfig.BankSizeFourByte),
70+
}
71+
72+
73+
def _get_template_parameter_specializations(
74+
dtype: numba.types.Type, dim: dim3, items_per_thread: int
75+
) -> dict:
76+
"""
77+
Returns a dictionary of template parameter specializations for the block
78+
radix sort algorithm.
79+
80+
Args:
81+
dtype: Supplies the Numba data type.
82+
83+
dim: Supplies the block dimensions.
84+
85+
items_per_thread: Supplies the number of items each thread owns.
86+
87+
Returns:
88+
A dictionary of template parameter specializations.
89+
"""
90+
specialization = {
91+
"KeyT": dtype,
92+
"BLOCK_DIM_X": dim[0],
93+
"ITEMS_PER_THREAD": items_per_thread,
94+
"BLOCK_DIM_Y": dim[1],
95+
"BLOCK_DIM_Z": dim[2],
96+
}
97+
98+
specialization.update(TEMPLATE_PARAMETER_DEFAULTS)
99+
100+
return specialization
101+
102+
103+
def _radix_sort(
104+
dtype: Union[str, type, "np.dtype", "numba.types.Type"],
105+
threads_per_block: Union[int, Tuple[int, int], Tuple[int, int, int], dim3],
106+
items_per_thread: int,
107+
descending: bool,
108+
) -> Invocable:
109+
dim = normalize_dim_param(threads_per_block)
110+
dtype = normalize_dtype_param(dtype)
111+
112+
method_name = "SortDescending" if descending else "Sort"
113+
template = Algorithm(
114+
"BlockRadixSort",
115+
method_name,
116+
"block_radix_sort",
117+
["cub/block/block_radix_sort.cuh"],
118+
TEMPLATE_PARAMETERS,
119+
METHOD_PARAMETERS_VARIANTS,
120+
)
121+
specialization = template.specialize(
122+
_get_template_parameter_specializations(dtype, dim, items_per_thread)
123+
)
124+
return Invocable(
125+
temp_files=[
126+
make_binary_tempfile(ltoir, ".ltoir")
127+
for ltoir in specialization.get_lto_ir()
128+
],
129+
temp_storage_bytes=specialization.get_temp_storage_bytes(),
130+
algorithm=specialization,
131+
)
132+
21133

22134
def radix_sort_keys(dtype, threads_per_block, items_per_thread):
23135
"""Performs an ascending block-wide radix sort over a :ref:`blocked arrangement <flexible-data-arrangement>` of keys.
@@ -47,54 +159,17 @@ def radix_sort_keys(dtype, threads_per_block, items_per_thread):
47159
``{ [0, 1, 2, 3], [4, 5, 6, 7], ..., [508, 509, 510, 511] }``.
48160
49161
Args:
50-
dtype: Numba data type of the keys to be sorted
51-
threads_per_block: The number of threads in a block
162+
dtype: Data type of the keys to be sorted
163+
164+
threads_per_block: The number of threads in a block, either an integer
165+
or a tuple of 2 or 3 integers
166+
52167
items_per_thread: The number of items each thread owns
53168
54169
Returns:
55170
A callable object that can be linked to and invoked from a CUDA kernel
56171
"""
57-
# Normalize the dtype parameter.
58-
dtype = normalize_dtype_param(dtype)
59-
60-
template = Algorithm(
61-
"BlockRadixSort",
62-
"Sort",
63-
"block_radix_sort",
64-
["cub/block/block_radix_sort.cuh"],
65-
[
66-
TemplateParameter("KeyT"),
67-
TemplateParameter("BLOCK_DIM_X"),
68-
TemplateParameter("ITEMS_PER_THREAD"),
69-
],
70-
[
71-
[
72-
Pointer(numba.uint8),
73-
DependentArray(Dependency("KeyT"), Dependency("ITEMS_PER_THREAD")),
74-
],
75-
[
76-
Pointer(numba.uint8),
77-
DependentArray(Dependency("KeyT"), Dependency("ITEMS_PER_THREAD")),
78-
Value(numba.int32),
79-
Value(numba.int32),
80-
],
81-
],
82-
)
83-
specialization = template.specialize(
84-
{
85-
"KeyT": dtype,
86-
"BLOCK_DIM_X": threads_per_block,
87-
"ITEMS_PER_THREAD": items_per_thread,
88-
}
89-
)
90-
return Invocable(
91-
temp_files=[
92-
make_binary_tempfile(ltoir, ".ltoir")
93-
for ltoir in specialization.get_lto_ir()
94-
],
95-
temp_storage_bytes=specialization.get_temp_storage_bytes(),
96-
algorithm=specialization,
97-
)
172+
return _radix_sort(dtype, threads_per_block, items_per_thread, descending=False)
98173

99174

100175
def radix_sort_keys_descending(dtype, threads_per_block, items_per_thread):
@@ -125,49 +200,14 @@ def radix_sort_keys_descending(dtype, threads_per_block, items_per_thread):
125200
``{ [511, 510, 509, 508], [507, 506, 505, 504], ..., [3, 2, 1, 0] }``.
126201
127202
Args:
128-
dtype: Numba data type of the keys to be sorted
129-
threads_per_block: The number of threads in a block
203+
dtype: Data type of the keys to be sorted
204+
205+
threads_per_block: The number of threads in a block, either an integer
206+
or a tuple of 2 or 3 integers
207+
130208
items_per_thread: The number of items each thread owns
131209
132210
Returns:
133211
A callable object that can be linked to and invoked from a CUDA kernel
134212
"""
135-
template = Algorithm(
136-
"BlockRadixSort",
137-
"SortDescending",
138-
"block_radix_sort",
139-
["cub/block/block_radix_sort.cuh"],
140-
[
141-
TemplateParameter("KeyT"),
142-
TemplateParameter("BLOCK_DIM_X"),
143-
TemplateParameter("ITEMS_PER_THREAD"),
144-
],
145-
[
146-
[
147-
Pointer(numba.uint8),
148-
DependentArray(Dependency("KeyT"), Dependency("ITEMS_PER_THREAD")),
149-
],
150-
[
151-
Pointer(numba.uint8),
152-
DependentArray(Dependency("KeyT"), Dependency("ITEMS_PER_THREAD")),
153-
Value(numba.int32),
154-
Value(numba.int32),
155-
],
156-
],
157-
)
158-
specialization = template.specialize(
159-
{
160-
"KeyT": dtype,
161-
"BLOCK_DIM_X": threads_per_block,
162-
"ITEMS_PER_THREAD": items_per_thread,
163-
}
164-
)
165-
166-
return Invocable(
167-
temp_files=[
168-
make_binary_tempfile(ltoir, ".ltoir")
169-
for ltoir in specialization.get_lto_ir()
170-
],
171-
temp_storage_bytes=specialization.get_temp_storage_bytes(),
172-
algorithm=specialization,
173-
)
213+
return _radix_sort(dtype, threads_per_block, items_per_thread, descending=True)

python/cuda_cooperative/tests/test_block_radix_sort.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from functools import reduce
6+
from operator import mul
7+
58
import numba
69
import pytest
7-
from helpers import NUMBA_TYPES_TO_NP, random_int
10+
from helpers import NUMBA_TYPES_TO_NP, random_int, row_major_tid
811
from numba import cuda, types
912
from pynvjitlink import patch
1013

@@ -15,19 +18,26 @@
1518

1619

1720
@pytest.mark.parametrize("T", [types.int8, types.int16, types.uint32, types.uint64])
18-
@pytest.mark.parametrize("threads_per_block", [32, 128, 256, 1024])
21+
@pytest.mark.parametrize("threads_per_block", [32, 128, 256, 1024, (4, 8), (2, 4, 8)])
1922
@pytest.mark.parametrize("items_per_thread", [1, 3])
2023
def test_block_radix_sort_descending(T, threads_per_block, items_per_thread):
2124
begin_bit = numba.int32(0)
2225
end_bit = numba.int32(T.bitwidth)
26+
27+
num_threads_per_block = (
28+
threads_per_block
29+
if type(threads_per_block) is int
30+
else reduce(mul, threads_per_block)
31+
)
32+
2333
block_radix_sort = cudax.block.radix_sort_keys_descending(
2434
dtype=T, threads_per_block=threads_per_block, items_per_thread=items_per_thread
2535
)
2636
temp_storage_bytes = block_radix_sort.temp_storage_bytes
2737

2838
@cuda.jit(link=block_radix_sort.files)
2939
def kernel(input, output):
30-
tid = cuda.threadIdx.x
40+
tid = row_major_tid()
3141
temp_storage = cuda.shared.array(shape=temp_storage_bytes, dtype="uint8")
3242
thread_data = cuda.local.array(shape=items_per_thread, dtype=dtype)
3343
for i in range(items_per_thread):
@@ -37,7 +47,7 @@ def kernel(input, output):
3747
output[tid * items_per_thread + i] = thread_data[i]
3848

3949
dtype = NUMBA_TYPES_TO_NP[T]
40-
items_per_tile = threads_per_block * items_per_thread
50+
items_per_tile = num_threads_per_block * items_per_thread
4151
input = random_int(items_per_tile, dtype)
4252
d_input = cuda.to_device(input)
4353
d_output = cuda.device_array(items_per_tile, dtype=dtype)
@@ -57,10 +67,14 @@ def kernel(input, output):
5767

5868

5969
@pytest.mark.parametrize("T", [types.int8, types.int16, types.uint32, types.uint64])
60-
@pytest.mark.parametrize("threads_per_block", [32, 128, 256, 1024])
70+
@pytest.mark.parametrize("threads_per_block", [32, 128, 256, 1024, (4, 8), (2, 4, 8)])
6171
@pytest.mark.parametrize("items_per_thread", [1, 3])
6272
def test_block_radix_sort(T, threads_per_block, items_per_thread):
63-
items_per_tile = threads_per_block * items_per_thread
73+
items_per_tile = (
74+
threads_per_block * items_per_thread
75+
if type(threads_per_block) is int
76+
else reduce(mul, threads_per_block) * items_per_thread
77+
)
6478

6579
block_radix_sort = cudax.block.radix_sort_keys(
6680
dtype=T, threads_per_block=threads_per_block, items_per_thread=items_per_thread
@@ -69,7 +83,7 @@ def test_block_radix_sort(T, threads_per_block, items_per_thread):
6983

7084
@cuda.jit(link=block_radix_sort.files)
7185
def kernel(input, output):
72-
tid = cuda.threadIdx.x
86+
tid = row_major_tid()
7387
temp_storage = cuda.shared.array(shape=temp_storage_bytes, dtype="uint8")
7488
thread_data = cuda.local.array(shape=items_per_thread, dtype=dtype)
7589
for i in range(items_per_thread):

0 commit comments

Comments
 (0)