Skip to content

Commit 81e01a1

Browse files
authored
Minor cython fixes / cleanup (rapidsai#1072)
* Release GIL on C++ cython declarations * Remove 'valid values for metric' mention from the compute_new_centroids docstring (since it doesn't take a metric parameter) * Remove `is_c_cont` in favour of `input_validation.is_c_contiguous` in kmeans.pyx Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: rapidsai#1072
1 parent a9b09e4 commit 81e01a1

File tree

5 files changed

+9
-18
lines changed

5 files changed

+9
-18
lines changed

python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ from pylibraft.common.handle cimport handle_t
3131

3232

3333
cdef extern from "raft_runtime/cluster/kmeans.hpp" \
34-
namespace "raft::runtime::cluster::kmeans":
34+
namespace "raft::runtime::cluster::kmeans" nogil:
3535

3636
cdef void update_centroids(
3737
const handle_t& handle,

python/pylibraft/pylibraft/cluster/kmeans.pyx

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,6 @@ from pylibraft.common.cpp.optional cimport optional
4646
from pylibraft.common.handle cimport handle_t
4747

4848

49-
def is_c_cont(cai, dt):
50-
return "strides" not in cai or \
51-
cai["strides"] is None or \
52-
cai["strides"][1] == dt.itemsize
53-
54-
5549
@auto_sync_handle
5650
def compute_new_centroids(X,
5751
centroids,
@@ -63,9 +57,6 @@ def compute_new_centroids(X,
6357
"""
6458
Compute new centroids given an input matrix and existing centroids
6559
66-
Valid values for metric:
67-
["euclidean", "sqeuclidean"]
68-
6960
Parameters
7061
----------
7162
@@ -167,9 +158,9 @@ def compute_new_centroids(X,
167158
handle = handle if handle is not None else Handle()
168159
cdef handle_t *h = <handle_t*><size_t>handle.getHandle()
169160

170-
x_c_contiguous = is_c_cont(x_cai, x_dt)
171-
centroids_c_contiguous = is_c_cont(centroids_cai, centroids_dt)
172-
new_centroids_c_contiguous = is_c_cont(new_centroids_cai, new_centroids_dt)
161+
x_c_contiguous = is_c_contiguous(x_cai)
162+
centroids_c_contiguous = is_c_contiguous(centroids_cai)
163+
new_centroids_c_contiguous = is_c_contiguous(new_centroids_cai)
173164

174165
if not x_c_contiguous or not centroids_c_contiguous \
175166
or not new_centroids_c_contiguous:
@@ -258,8 +249,8 @@ def cluster_cost(X, centroids, handle=None):
258249
handle = handle if handle is not None else Handle()
259250
cdef handle_t *h = <handle_t*><size_t>handle.getHandle()
260251

261-
x_c_contiguous = is_c_cont(x_cai, x_dt)
262-
centroids_c_contiguous = is_c_cont(centroids_cai, centroids_dt)
252+
x_c_contiguous = is_c_contiguous(x_cai)
253+
centroids_c_contiguous = is_c_contiguous(centroids_cai)
263254

264255
if not x_c_contiguous or not centroids_c_contiguous:
265256
raise ValueError("Inputs must all be c contiguous")

python/pylibraft/pylibraft/distance/fused_l2_nn.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ from pylibraft.common.handle cimport handle_t
3333

3434

3535
cdef extern from "raft_runtime/distance/fused_l2_nn.hpp" \
36-
namespace "raft::runtime::distance":
36+
namespace "raft::runtime::distance" nogil:
3737

3838
void fused_l2_nn_min_arg(
3939
const handle_t &handle,

python/pylibraft/pylibraft/distance/pairwise_distance.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ from pylibraft.common import cai_wrapper, device_ndarray
3535

3636

3737
cdef extern from "raft_runtime/distance/pairwise_distance.hpp" \
38-
namespace "raft::runtime::distance":
38+
namespace "raft::runtime::distance" nogil:
3939

4040
cdef void pairwise_distance(const handle_t &handle,
4141
float *x,

python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ from pylibraft.random.cpp.rng_state cimport RngState
3333

3434

3535
cdef extern from "raft_runtime/random/rmat_rectangular_generator.hpp" \
36-
namespace "raft::runtime::random":
36+
namespace "raft::runtime::random" nogil:
3737

3838
cdef void rmat_rectangular_gen(const handle_t &handle,
3939
int* out,

0 commit comments

Comments
 (0)