Skip to content

Commit fd7058c

Browse files
authored
Use doctest for testing python example docstrings (rapidsai#1073)
Similar to rapidsai/cudf#9815, this change uses doctest to test that the pylibraft example docstrings run without issue. This caught several errors in the example docstrings, that are also fixed in this PR: * a missing ‘device_ndarray’ import in kmeans fit when the centroids weren’t explicitly passed in * an error in the fused_l2_nn_argmin docstring where output wasn’t defined * An `AttributeError: module 'pylibraft.neighbors.ivf_pq' has no attribute 'np'` error in ivf_pq Closes rapidsai#981 Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: rapidsai#1073
1 parent 81e01a1 commit fd7058c

File tree

12 files changed

+271
-147
lines changed

12 files changed

+271
-147
lines changed

python/pylibraft/pylibraft/cluster/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@
1313
# limitations under the License.
1414
#
1515

16-
from .kmeans import compute_new_centroids
16+
from .kmeans import KMeansParams, cluster_cost, compute_new_centroids, fit
17+
18+
__all__ = ["KMeansParams", "cluster_cost", "compute_new_centroids", "fit"]

python/pylibraft/pylibraft/cluster/kmeans.pyx

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ from libcpp cimport nullptr
2727
from collections import namedtuple
2828
from enum import IntEnum
2929

30-
from pylibraft.common import Handle, cai_wrapper
30+
from pylibraft.common import Handle, cai_wrapper, device_ndarray
3131
from pylibraft.common.handle import auto_sync_handle
3232

3333
from pylibraft.common.handle cimport handle_t
@@ -81,33 +81,33 @@ def compute_new_centroids(X,
8181
--------
8282
8383
>>> import cupy as cp
84-
>>>
84+
8585
>>> from pylibraft.common import Handle
8686
>>> from pylibraft.cluster.kmeans import compute_new_centroids
87-
>>>
87+
8888
>>> # A single RAFT handle can optionally be reused across
8989
>>> # pylibraft functions.
9090
>>> handle = Handle()
91-
>>>
91+
9292
>>> n_samples = 5000
9393
>>> n_features = 50
9494
>>> n_clusters = 3
95-
>>>
95+
9696
>>> X = cp.random.random_sample((n_samples, n_features),
97-
>>> dtype=cp.float32)
98-
>>>
97+
... dtype=cp.float32)
98+
9999
>>> centroids = cp.random.random_sample((n_clusters, n_features),
100-
>>> dtype=cp.float32)
101-
>>>
100+
... dtype=cp.float32)
101+
...
102102
>>> labels = cp.random.randint(0, high=n_clusters, size=n_samples,
103-
>>> dtype=cp.int32)
104-
>>>
103+
... dtype=cp.int32)
104+
105105
>>> new_centroids = cp.empty((n_clusters, n_features), dtype=cp.float32)
106-
>>>
106+
107107
>>> compute_new_centroids(
108-
>>> X, centroids, labels, new_centroids, handle=handle
109-
>>> )
110-
>>>
108+
... X, centroids, labels, new_centroids, handle=handle
109+
... )
110+
111111
>>> # pylibraft functions are often asynchronous so the
112112
>>> # handle needs to be explicitly synchronized
113113
>>> handle.sync()
@@ -211,22 +211,21 @@ def cluster_cost(X, centroids, handle=None):
211211
Examples
212212
--------
213213
214-
.. code-block:: python
215-
import cupy as cp
216-
217-
from pylibraft.cluster.kmeans import cluster_cost
218-
219-
n_samples = 5000
220-
n_features = 50
221-
n_clusters = 3
222-
223-
X = cp.random.random_sample((n_samples, n_features),
224-
dtype=cp.float32)
214+
>>> import cupy as cp
215+
>>>
216+
>>> from pylibraft.cluster.kmeans import cluster_cost
217+
>>>
218+
>>> n_samples = 5000
219+
>>> n_features = 50
220+
>>> n_clusters = 3
221+
>>>
222+
>>> X = cp.random.random_sample((n_samples, n_features),
223+
... dtype=cp.float32)
225224
226-
centroids = cp.random.random_sample((n_clusters, n_features),
227-
dtype=cp.float32)
225+
>>> centroids = cp.random.random_sample((n_clusters, n_features),
226+
... dtype=cp.float32)
228227
229-
inertia = cluster_cost(X, centroids)
228+
>>> inertia = cluster_cost(X, centroids)
230229
"""
231230
x_cai = X.__cuda_array_interface__
232231
centroids_cai = centroids.__cuda_array_interface__
@@ -434,21 +433,19 @@ def fit(
434433
Examples
435434
--------
436435
437-
.. code-block:: python
438-
439-
import cupy as cp
440-
441-
from pylibraft.cluster.kmeans import fit, KMeansParams
442-
443-
n_samples = 5000
444-
n_features = 50
445-
n_clusters = 3
446-
447-
X = cp.random.random_sample((n_samples, n_features),
448-
dtype=cp.float32)
436+
>>> import cupy as cp
437+
>>>
438+
>>> from pylibraft.cluster.kmeans import fit, KMeansParams
439+
>>>
440+
>>> n_samples = 5000
441+
>>> n_features = 50
442+
>>> n_clusters = 3
443+
>>>
444+
>>> X = cp.random.random_sample((n_samples, n_features),
445+
... dtype=cp.float32)
449446
450-
params = KMeansParams(n_clusters=n_clusters)
451-
centroids, inertia, n_iter = fit(params, X)
447+
>>> params = KMeansParams(n_clusters=n_clusters)
448+
>>> centroids, inertia, n_iter = fit(params, X)
452449
"""
453450
cdef handle_t *h = <handle_t*><size_t>handle.getHandle()
454451

python/pylibraft/pylibraft/distance/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@
1515

1616
from .fused_l2_nn import fused_l2_nn_argmin
1717
from .pairwise_distance import DISTANCE_TYPES, distance as pairwise_distance
18+
19+
__all__ = ["fused_l2_nn_argmin", "pairwise_distance"]

python/pylibraft/pylibraft/distance/fused_l2_nn.pyx

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,15 @@ def fused_l2_nn_argmin(X, Y, out=None, sqrt=True, handle=None):
8080
>>> n_clusters = 5
8181
>>> n_features = 50
8282
>>> in1 = cp.random.random_sample((n_samples, n_features),
83-
>>> dtype=cp.float32)
83+
... dtype=cp.float32)
8484
>>> in2 = cp.random.random_sample((n_clusters, n_features),
85-
>>> dtype=cp.float32)
85+
... dtype=cp.float32)
8686
>>> # A single RAFT handle can optionally be reused across
8787
>>> # pylibraft functions.
8888
>>> handle = Handle()
89-
>>> ...
90-
>>> output = fused_l2_nn_argmin(in1, in2, output, handle=handle)
91-
>>> ...
89+
90+
>>> output = fused_l2_nn_argmin(in1, in2, handle=handle)
91+
9292
>>> # pylibraft functions are often asynchronous so the
9393
>>> # handle needs to be explicitly synchronized
9494
>>> handle.sync()
@@ -103,20 +103,20 @@ def fused_l2_nn_argmin(X, Y, out=None, sqrt=True, handle=None):
103103
>>> n_clusters = 5
104104
>>> n_features = 50
105105
>>> in1 = cp.random.random_sample((n_samples, n_features),
106-
>>> dtype=cp.float32)
106+
... dtype=cp.float32)
107107
>>> in2 = cp.random.random_sample((n_clusters, n_features),
108-
>>> dtype=cp.float32)
108+
... dtype=cp.float32)
109109
>>> output = cp.empty((n_samples, 1), dtype=cp.int32)
110110
>>> # A single RAFT handle can optionally be reused across
111111
>>> # pylibraft functions.
112112
>>> handle = Handle()
113-
>>> ...
113+
114114
>>> fused_l2_nn_argmin(in1, in2, out=output, handle=handle)
115-
>>> ...
115+
array(...)
116+
116117
>>> # pylibraft functions are often asynchronous so the
117118
>>> # handle needs to be explicitly synchronized
118119
>>> handle.sync()
119-
120120
"""
121121

122122
x_cai = cai_wrapper(X)

python/pylibraft/pylibraft/distance/pairwise_distance.pyx

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,9 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None):
124124
>>> n_samples = 5000
125125
>>> n_features = 50
126126
>>> in1 = cp.random.random_sample((n_samples, n_features),
127-
>>> dtype=cp.float32)
127+
... dtype=cp.float32)
128128
>>> in2 = cp.random.random_sample((n_samples, n_features),
129-
>>> dtype=cp.float32)
129+
... dtype=cp.float32)
130130
131131
A single RAFT handle can optionally be reused across
132132
pylibraft functions.
@@ -147,9 +147,9 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None):
147147
>>> n_samples = 5000
148148
>>> n_features = 50
149149
>>> in1 = cp.random.random_sample((n_samples, n_features),
150-
>>> dtype=cp.float32)
150+
... dtype=cp.float32)
151151
>>> in2 = cp.random.random_sample((n_samples, n_features),
152-
>>> dtype=cp.float32)
152+
... dtype=cp.float32)
153153
>>> output = cp.empty((n_samples, n_samples), dtype=cp.float32)
154154
155155
A single RAFT handle can optionally be reused across
@@ -158,7 +158,8 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None):
158158
>>>
159159
>>> handle = Handle()
160160
>>> pairwise_distance(in1, in2, out=output,
161-
>>> metric="euclidean", handle=handle)
161+
... metric="euclidean", handle=handle)
162+
array(...)
162163
163164
pylibraft functions are often asynchronous so the
164165
handle needs to be explicitly synchronized

python/pylibraft/pylibraft/neighbors/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@
1313
# limitations under the License.
1414
#
1515
from .refine import refine
16+
17+
__all__ = ["refine"]

python/pylibraft/pylibraft/neighbors/ivf_pq/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@
1414
#
1515

1616
from .ivf_pq import Index, IndexParams, SearchParams, build, extend, search
17+
18+
__all__ = ["Index", "IndexParams", "SearchParams", "build", "extend", "search"]

0 commit comments

Comments
 (0)