Skip to content

Commit b031ecf

Browse files
committed
Support struct-like inputs
1 parent 61afc05 commit b031ecf

File tree

6 files changed

+164
-13
lines changed

6 files changed

+164
-13
lines changed

python/cuda_parallel/cuda/parallel/experimental/_cccl.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ def _type_to_enum(numba_type: types.Type) -> TypeEnum:
121121
def _numba_type_to_info(numba_type: types.Type) -> TypeInfo:
    """Build the ``TypeInfo`` (size, alignment, type enum) for a Numba type.

    The size and alignment are the ABI values reported by the LLVM type that
    the CUDA target context associates with ``numba_type``.
    """
    target_context = cuda.descriptor.cuda_target.target_context
    llvm_type = target_context.get_value_type(numba_type)
    if isinstance(numba_type, types.Record):
        # For records the value type is a pointer; size and alignment must
        # be taken from the type it points to, not from the pointer itself.
        llvm_type = llvm_type.pointee
    return TypeInfo(
        llvm_type.get_abi_size(target_context.target_data),
        llvm_type.get_abi_alignment(target_context.target_data),
        _type_to_enum(numba_type),
    )
@@ -211,4 +215,9 @@ def to_cccl_iter(array_or_iterator) -> Iterator:
211215

212216
def host_array_to_value(array: np.ndarray) -> Value:
    """Wrap a host-side value in a ``Value`` (type info plus data pointer).

    ``array`` is either a NumPy array or an object produced by the
    ``gpu_struct`` decorator; in both cases its ``.dtype`` attribute provides
    the type info.
    """
    info = _numpy_type_to_info(array.dtype)
    if not isinstance(array, np.ndarray):
        # Not an ndarray, so it's a gpudataclass: point at the ctypes object
        # it keeps in `._data`.
        return Value(
            info, ctypes.cast(ctypes.pointer(array._data), ctypes.c_void_p)
        )
    return Value(info, array.ctypes.data)

python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@
1515

1616

1717
def get_dtype(arr: DeviceArrayLike) -> np.dtype:
    """Return the NumPy dtype described by ``arr.__cuda_array_interface__``.

    Args:
        arr: Any object exposing ``__cuda_array_interface__``.

    Returns:
        For a structured array (``typestr`` starting with ``"|V"``) that
        provides a ``descr`` entry, the dtype built from ``descr`` so that
        field names and types are preserved. Otherwise the dtype built from
        ``typestr``.

    Raises:
        KeyError: If the interface dict has no ``"typestr"`` entry.
    """
    interface = arr.__cuda_array_interface__
    typestr = interface["typestr"]

    if typestr.startswith("|V"):
        # `descr` is an optional entry of the interface; only use it when the
        # producer actually supplied it, instead of failing with a KeyError.
        descr = interface.get("descr")
        if descr is not None:
            return np.dtype(descr)

    # A simple dtype, or a void dtype without field information.
    return np.dtype(typestr)
1926

2027

2128
def get_strides(arr: DeviceArrayLike) -> Optional[Tuple]:

python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
from .._caching import CachableFunction, cache_with_key
1919
from .._utils import cai
2020
from ..iterators._iterators import IteratorBase
21-
from ..typing import DeviceArrayLike
21+
from ..typing import DeviceArrayLike, GpuStruct
2222

2323

2424
class _Op:
25-
def __init__(self, h_init: np.ndarray | GpuStruct, op: Callable):
    """Compile the binary operator ``op`` to LTO-IR for the type of ``h_init``.

    The Numba value type is derived from ``h_init.dtype`` when ``h_init`` is
    a NumPy array, and from ``numba.typeof(h_init)`` otherwise. ``op`` is
    compiled with that type for both of its arguments.
    """
    value_type = (
        numba.from_dtype(h_init.dtype)
        if isinstance(h_init, np.ndarray)
        else numba.typeof(h_init)
    )
    compiled, _ = cuda.compile(op, sig=(value_type, value_type), output="ltoir")
    self.ltoir = compiled
    # The operator's Python name, as bytes.
    self.name = op.__name__.encode("utf-8")
3132

3233
def handle(self) -> cccl.Op:
@@ -53,7 +54,7 @@ def __init__(
5354
d_in: DeviceArrayLike | IteratorBase,
5455
d_out: DeviceArrayLike,
5556
op: Callable,
56-
h_init: np.ndarray,
57+
h_init: np.ndarray | GpuStruct,
5758
):
5859
d_in_cccl = cccl.to_cccl_iter(d_in)
5960
self._ctor_d_in_cccl_type_enum_name = cccl.type_enum_as_name(
@@ -64,11 +65,10 @@ def __init__(
6465
cc_major, cc_minor = cuda.get_current_device().compute_capability
6566
cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
6667
bindings = get_bindings()
67-
self.op_wrapper = _Op(h_init.dtype, op)
68+
self.op_wrapper = _Op(h_init, op)
6869
d_out_cccl = cccl.to_cccl_iter(d_out)
6970
self.build_result = cccl.DeviceReduceBuildResult()
7071

71-
# TODO Figure out caching
7272
error = bindings.cccl_device_reduce_build(
7373
ctypes.byref(self.build_result),
7474
d_in_cccl,
@@ -85,7 +85,9 @@ def __init__(
8585
if error != enums.CUDA_SUCCESS:
8686
raise ValueError("Error building reduce")
8787

88-
def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray):
88+
def __call__(
89+
self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray | GpuStruct
90+
):
8991
d_in_cccl = cccl.to_cccl_iter(d_in)
9092
if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
9193
assert num_items is not None
@@ -99,7 +101,7 @@ def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray
99101
self._ctor_d_in_cccl_type_enum_name,
100102
cccl.type_enum_as_name(d_in_cccl.value_type.type.value),
101103
)
102-
_dtype_validation(self._ctor_d_out_dtype, d_out.dtype)
104+
_dtype_validation(self._ctor_d_out_dtype, cai.get_dtype(d_out))
103105
_dtype_validation(self._ctor_init_dtype, h_init.dtype)
104106
bindings = get_bindings()
105107
if temp_storage is None:
python/cuda_parallel/cuda/parallel/experimental/gpu_struct.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from dataclasses import dataclass
2+
from dataclasses import fields as dataclass_fields
3+
4+
import numba
5+
import numpy as np
6+
from numba.core import cgutils
7+
from numba.core.extending import (
8+
make_attribute_wrapper,
9+
models,
10+
register_model,
11+
typeof_impl,
12+
)
13+
from numba.core.typing import signature as nb_signature
14+
from numba.core.typing.templates import AttributeTemplate, ConcreteTemplate
15+
from numba.cuda.cudadecl import registry as cuda_registry
16+
from numba.extending import lower_builtin
17+
18+
from .typing import GpuStruct
19+
20+
21+
def gpu_struct(this: type) -> GpuStruct:
    """Class decorator that makes an annotated class usable as a GPU struct.

    The decorated class is turned into a dataclass and additionally gets:

    * a ``dtype`` class attribute: the NumPy structured dtype built from the
      class annotations (field name -> field type);
    * a ``_data`` instance attribute: a ctypes object holding the field
      values, created in ``__post_init__``;
    * a Numba type, data model, attribute typing/lowering and constructor
      typing/lowering registered for the CUDA target, so instances and
      field accesses can be used in compiled operators.

    Args:
        this: A class whose annotations name the struct fields and give
            their NumPy scalar types (e.g. ``r: np.int32``).

    Returns:
        The same class, converted to a dataclass.
    """
    anns = getattr(this, "__annotations__", {})

    # Set the .dtype attribute on the class for numpy compatibility.
    # `align=True` requests C-struct-like padding; without it NumPy packs the
    # fields back to back, which would not match the naturally aligned struct
    # data model registered below when fields have different sizes.
    setattr(this, "dtype", np.dtype(list(anns.items()), align=True))

    # Define __post_init__ to create a ctypes object from the fields,
    # and keep a reference to it in the `._data` attribute.
    def __post_init__(self):
        ctypes_typ = np.ctypeslib.as_ctypes_type(this.dtype)
        self._data = ctypes_typ(*(getattr(self, name) for name in this.dtype.names))

    setattr(this, "__post_init__", __post_init__)

    # Create a dataclass:
    this = dataclass(this)
    fields = dataclass_fields(this)

    # Numba type of every field, computed once and reused by the data model,
    # the attribute typing and the constructor typing/lowering.
    field_types = [numba.from_dtype(field.type) for field in fields]

    # Define a numba type corresponding to the dataclass:
    class ThisType(numba.types.Type):
        def __init__(self):
            super().__init__(name=this.__name__)

    this_type = ThisType()

    @typeof_impl.register(this)
    def typeof_this(val, c):
        return this_type

    # Data model corresponding to ThisType:
    @register_model(ThisType)
    class ThisModel(models.StructModel):
        def __init__(self, dmm, fe_type):
            members = [
                (field.name, nb_type) for field, nb_type in zip(fields, field_types)
            ]
            super().__init__(dmm, fe_type, members)

    # Typing for accessing attributes (fields) of the dataclass:
    class ThisAttrsTemplate(AttributeTemplate):
        pass

    def make_resolver(nb_type):
        # Bind `nb_type` per field. A closure defined directly in the loop
        # below would look the variable up at call time, so every resolver
        # would report the type of the *last* field.
        def resolver(self, struct_type):
            return nb_type

        return resolver

    for field, nb_type in zip(fields, field_types):
        setattr(ThisAttrsTemplate, f"resolve_{field.name}", make_resolver(nb_type))

    @cuda_registry.register_attr
    class ThisAttrs(ThisAttrsTemplate):
        key = this_type

    # Lowering for attribute access:
    for field in fields:
        make_attribute_wrapper(ThisType, field.name, field.name)

    # Register typing for constructor.
    @cuda_registry.register
    class TypeConstructor(ConcreteTemplate):
        key = this
        cases = [nb_signature(this_type, *field_types)]

    cuda_registry.register_global(this, numba.types.Function(TypeConstructor))

    # Lowering for the constructor: fill a struct proxy field by field.
    def type_constructor(context, builder, sig, args):
        ty = sig.return_type
        retval = cgutils.create_struct_proxy(ty)(context, builder)
        for field, val in zip(fields, args):
            setattr(retval, field.name, val)
        return retval._getvalue()

    lower_builtin(this, *field_types)(type_constructor)

    return this

python/cuda_parallel/cuda/parallel/experimental/typing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any
2+
13
from typing_extensions import (
24
Protocol,
35
) # TODO: typing_extensions required for Python 3.7 docs env
@@ -10,3 +12,7 @@ class DeviceArrayLike(Protocol):
1012
"""
1113

1214
__cuda_array_interface__: dict
15+
16+
17+
# return type of @gpu_struct
18+
GpuStruct = Any

python/cuda_parallel/tests/test_reduce.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import cuda.parallel.experimental.algorithms as algorithms
1414
import cuda.parallel.experimental.iterators as iterators
15+
from cuda.parallel.experimental.gpu_struct import gpu_struct
1516

1617

1718
def random_int(shape, dtype):
@@ -550,3 +551,30 @@ def binary_op(x, y):
550551
d_in = cp.zeros(size)[::2]
551552
with pytest.raises(ValueError, match="Non-contiguous arrays are not supported."):
552553
_ = algorithms.reduce_into(d_in, d_out, binary_op, h_init)
554+
555+
556+
def test_reduce_struct_type():
    """Reduce an array of struct-typed pixels to the one with the largest ``g``."""

    @gpu_struct
    class Pixel:
        r: np.int32
        g: np.int32
        b: np.int32

    def max_g_value(x, y):
        return x if x.g > y.g else y

    # Ten random (r, g, b) rows reinterpreted as ten Pixel records.
    raw_rgb = cp.random.randint(0, 256, (10, 3), dtype=np.int32)
    d_rgb = raw_rgb.view(Pixel.dtype)
    d_out = cp.zeros(1, Pixel.dtype)
    h_init = Pixel(0, 0, 0)

    reduce_into = algorithms.reduce_into(d_rgb, d_out, max_g_value, h_init)

    # The call with `None` as temp storage yields the number of bytes to
    # allocate for the actual reduction.
    temp_storage_bytes = reduce_into(None, d_rgb, d_out, len(d_rgb), h_init)
    d_temp_storage = cp.zeros(temp_storage_bytes, dtype=np.uint8)
    _ = reduce_into(d_temp_storage, d_rgb, d_out, len(d_rgb), h_init)

    # Host-side reference: the record whose `g` field is largest.
    h_rgb = d_rgb.get()
    best_row = h_rgb["g"].argmax()
    expected = h_rgb[best_row]

    # Only `g` is compared.
    np.testing.assert_equal(expected["g"], d_out.get()["g"])

0 commit comments

Comments
 (0)