Skip to content

Commit 13d8d19

Browse files
Introduce _bindings.call_build utility
This finds compute capability and include paths and appends them to the algorithm-specific arguments. Used the utility in segmented_reduce.
1 parent d6d39fa commit 13d8d19

File tree

2 files changed

+34
-19
lines changed

2 files changed

+34
-19
lines changed

python/cuda_parallel/cuda/parallel/experimental/_bindings.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import sys
88
from functools import lru_cache
99
from pathlib import Path
10-
from typing import List
10+
from typing import Callable, List, Tuple
11+
12+
from numba import cuda
1113

1214
from cuda.cccl import get_include_paths # type: ignore[import-not-found]
1315

@@ -41,3 +43,22 @@ def get_paths() -> List[bytes]:
4143
if path is not None
4244
]
4345
return paths
46+
47+
48+
def call_build(build_impl_fn: Callable, args: Tuple):
49+
"""Calls given build_impl_fn callable while providing compute capability and paths
50+
51+
Returns result of the call.
52+
"""
53+
cc_major, cc_minor = cuda.get_current_device().compute_capability
54+
cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
55+
error = build_impl_fn(
56+
*args,
57+
cc_major,
58+
cc_minor,
59+
ctypes.c_char_p(cub_path),
60+
ctypes.c_char_p(thrust_path),
61+
ctypes.c_char_p(libcudacxx_path),
62+
ctypes.c_char_p(cuda_include_path),
63+
)
64+
return error

python/cuda_parallel/cuda/parallel/experimental/algorithms/segmented_reduce.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33

44
import numba
55
import numpy as np
6-
from numba import cuda
76
from numba.cuda.cudadrv import enums
87

98
from .. import _cccl as cccl
10-
from .._bindings import get_bindings, get_paths
9+
from .._bindings import call_build, get_bindings
1110
from .._caching import CachableFunction, cache_with_key
1211
from .._utils import protocols
1312
from ..iterators._iterators import IteratorBase
@@ -36,8 +35,6 @@ def __init__(
3635
self.start_offsets_in_cccl = cccl.to_cccl_iter(start_offsets_in)
3736
self.end_offsets_in_cccl = cccl.to_cccl_iter(end_offsets_in)
3837
self.h_init_cccl = cccl.to_cccl_value(h_init)
39-
cc_major, cc_minor = cuda.get_current_device().compute_capability
40-
cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
4138
if isinstance(h_init, np.ndarray):
4239
value_type = numba.from_dtype(h_init.dtype)
4340
else:
@@ -46,20 +43,17 @@ def __init__(
4643
self.op_wrapper = cccl.to_cccl_op(op, sig)
4744
self.build_result = cccl.DeviceSegmentedReduceBuildResult()
4845
self.bindings = get_bindings()
49-
error = self.bindings.cccl_device_segmented_reduce_build(
50-
ctypes.byref(self.build_result),
51-
self.d_in_cccl,
52-
self.d_out_cccl,
53-
self.start_offsets_in_cccl,
54-
self.end_offsets_in_cccl,
55-
self.op_wrapper,
56-
self.h_init_cccl,
57-
cc_major,
58-
cc_minor,
59-
ctypes.c_char_p(cub_path),
60-
ctypes.c_char_p(thrust_path),
61-
ctypes.c_char_p(libcudacxx_path),
62-
ctypes.c_char_p(cuda_include_path),
46+
error = call_build(
47+
self.bindings.cccl_device_segmented_reduce_build,
48+
(
49+
ctypes.byref(self.build_result),
50+
self.d_in_cccl,
51+
self.d_out_cccl,
52+
self.start_offsets_in_cccl,
53+
self.end_offsets_in_cccl,
54+
self.op_wrapper,
55+
self.h_init_cccl,
56+
),
6357
)
6458
if error != enums.CUDA_SUCCESS:
6559
raise ValueError("Error building reduce")

0 commit comments

Comments
 (0)