Skip to content

Commit a9b09e4

Browse files
authored
Paying down some tech debt on docs, runtime API, and cython (rapidsai#1055)
This PR pays down some tech debt and cleans up some things. On the surface, you'll notice many files have been touched or modified but the modifications are largely confined to a few major categories of changes: 1. Fixes some of the issues found with the doc updates in 22.12. 2. Breaks some of the docs for the c++ namespaces down into multiple sections to make them easier to navigate and consume. 3. Renames the raft_distance directory to the more appropriately named raft_runtime. (This is also in preparation to eventually rename the libraft-distance library into libraft once we can remove the FAISS dependency.) 4. Separates out some runtime source files and APIs that were being mistakenly combined with the template specializations API. 5. Consolidates multiple mdspan.pxd files into a single file. 6. Consistently uses `cpp` directory for new(er) pxd files, nested into their respective packages. 7. Consistently uses doxygen groups in many of the namespaces. Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Ben Frederickson (https://github.com/benfred) - Divye Gala (https://github.com/divyegala) URL: rapidsai#1055
1 parent ff5c9f0 commit a9b09e4

File tree

229 files changed

+3541
-1718
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

229 files changed

+3541
-1718
lines changed

cpp/CMakeLists.txt

Lines changed: 75 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -278,81 +278,83 @@ set_target_properties(raft_distance PROPERTIES EXPORT_NAME distance)
278278
if(RAFT_COMPILE_DIST_LIBRARY)
279279
add_library(
280280
raft_distance_lib
281-
src/distance/pairwise_distance.cu
282-
src/distance/fused_l2_min_arg.cu
283-
src/distance/update_centroids_float.cu
284-
src/distance/update_centroids_double.cu
285-
src/distance/cluster_cost_float.cu
286-
src/distance/cluster_cost_double.cu
287-
src/distance/kmeans_fit_float.cu
288-
src/distance/kmeans_fit_double.cu
289-
src/distance/specializations/detail/canberra.cu
290-
src/distance/specializations/detail/chebyshev.cu
291-
src/distance/specializations/detail/correlation.cu
292-
src/distance/specializations/detail/cosine.cu
293-
src/distance/specializations/detail/cosine.cu
294-
src/distance/specializations/detail/hamming_unexpanded.cu
295-
src/distance/specializations/detail/hellinger_expanded.cu
296-
src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu
297-
src/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu
298-
src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu
299-
src/distance/specializations/detail/kernels/gram_matrix_base_double.cu
300-
src/distance/specializations/detail/kernels/gram_matrix_base_float.cu
301-
src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu
302-
src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu
281+
src/distance/distance/pairwise_distance.cu
282+
src/distance/distance/fused_l2_min_arg.cu
283+
src/distance/cluster/update_centroids_float.cu
284+
src/distance/cluster/update_centroids_double.cu
285+
src/distance/cluster/cluster_cost_float.cu
286+
src/distance/cluster/cluster_cost_double.cu
287+
src/distance/neighbors/refine.cu
288+
src/distance/neighbors/ivfpq_search.cu
289+
src/distance/cluster/kmeans_fit_float.cu
290+
src/distance/cluster/kmeans_fit_double.cu
291+
src/distance/distance/specializations/detail/canberra.cu
292+
src/distance/distance/specializations/detail/chebyshev.cu
293+
src/distance/distance/specializations/detail/correlation.cu
294+
src/distance/distance/specializations/detail/cosine.cu
295+
src/distance/distance/specializations/detail/cosine.cu
296+
src/distance/distance/specializations/detail/hamming_unexpanded.cu
297+
src/distance/distance/specializations/detail/hellinger_expanded.cu
298+
src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu
299+
src/distance/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu
300+
src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu
301+
src/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu
302+
src/distance/distance/specializations/detail/kernels/gram_matrix_base_float.cu
303+
src/distance/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu
304+
src/distance/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu
303305
# These are somehow missing a kernel definition which is causing a compile error.
304306
# src/distance/distance/specializations/detail/kernels/rbf_kernel_double.cu
305307
# src/distance/distance/specializations/detail/kernels/rbf_kernel_float.cu
306-
src/distance/specializations/detail/kernels/tanh_kernel_double.cu
307-
src/distance/specializations/detail/kernels/tanh_kernel_float.cu
308-
src/distance/specializations/detail/kl_divergence_float_float_float_int.cu
309-
src/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu
310-
src/distance/specializations/detail/kl_divergence_double_double_double_int.cu
311-
src/distance/specializations/detail/l1_float_float_float_int.cu
312-
src/distance/specializations/detail/l1_float_float_float_uint32.cu
313-
src/distance/specializations/detail/l1_double_double_double_int.cu
314-
src/distance/specializations/detail/l2_expanded_float_float_float_int.cu
315-
src/distance/specializations/detail/l2_expanded_float_float_float_uint32.cu
316-
src/distance/specializations/detail/l2_expanded_double_double_double_int.cu
317-
src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu
318-
src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_uint32.cu
319-
src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu
320-
src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu
321-
src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_uint32.cu
322-
src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu
323-
src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu
324-
src/distance/specializations/detail/l2_unexpanded_float_float_float_uint32.cu
325-
src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu
326-
src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu
327-
src/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu
328-
src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu
329-
src/distance/specializations/detail/russel_rao_double_double_double_int.cu
330-
src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu
331-
src/distance/specializations/detail/russel_rao_float_float_float_int.cu
332-
src/distance/specializations/fused_l2_nn_double_int.cu
333-
src/distance/specializations/fused_l2_nn_double_int64.cu
334-
src/distance/specializations/fused_l2_nn_float_int.cu
335-
src/distance/specializations/fused_l2_nn_float_int64.cu
336-
src/nn/specializations/detail/ivfpq_build.cu
337-
src/nn/specializations/detail/ivfpq_compute_similarity_float_fast.cu
338-
src/nn/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu
339-
src/nn/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu
340-
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu
341-
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu
342-
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu
343-
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu
344-
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu
345-
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu
346-
src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu
347-
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
348-
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
349-
src/nn/specializations/detail/ivfpq_search.cu
350-
src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu
351-
src/nn/specializations/refine.cu
352-
src/random/specializations/rmat_rectangular_generator_int_double.cu
353-
src/random/specializations/rmat_rectangular_generator_int64_double.cu
354-
src/random/specializations/rmat_rectangular_generator_int_float.cu
355-
src/random/specializations/rmat_rectangular_generator_int64_float.cu
308+
src/distance/distance/specializations/detail/kernels/tanh_kernel_double.cu
309+
src/distance/distance/specializations/detail/kernels/tanh_kernel_float.cu
310+
src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu
311+
src/distance/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu
312+
src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu
313+
src/distance/distance/specializations/detail/l1_float_float_float_int.cu
314+
src/distance/distance/specializations/detail/l1_float_float_float_uint32.cu
315+
src/distance/distance/specializations/detail/l1_double_double_double_int.cu
316+
src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu
317+
src/distance/distance/specializations/detail/l2_expanded_float_float_float_uint32.cu
318+
src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu
319+
src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu
320+
src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_uint32.cu
321+
src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu
322+
src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu
323+
src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_uint32.cu
324+
src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu
325+
src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu
326+
src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_uint32.cu
327+
src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu
328+
src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu
329+
src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu
330+
src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu
331+
src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu
332+
src/distance/distance/specializations/detail/russel_rao_float_float_float_uint32.cu
333+
src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu
334+
src/distance/distance/specializations/fused_l2_nn_double_int.cu
335+
src/distance/distance/specializations/fused_l2_nn_double_int64.cu
336+
src/distance/distance/specializations/fused_l2_nn_float_int.cu
337+
src/distance/distance/specializations/fused_l2_nn_float_int64.cu
338+
src/distance/neighbors/ivfpq_build.cu
339+
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_fast.cu
340+
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu
341+
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu
342+
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu
343+
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu
344+
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu
345+
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu
346+
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu
347+
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu
348+
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_half_fast.cu
349+
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
350+
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
351+
src/distance/neighbors/specializations/detail/ivfpq_search_float_int64_t.cu
352+
src/distance/neighbors/specializations/detail/ivfpq_search_float_uint64_t.cu
353+
src/distance/neighbors/specializations/detail/ivfpq_search_float_uint32_t.cu
354+
src/distance/random/rmat_rectangular_generator_int_double.cu
355+
src/distance/random/rmat_rectangular_generator_int64_double.cu
356+
src/distance/random/rmat_rectangular_generator_int_float.cu
357+
src/distance/random/rmat_rectangular_generator_int64_float.cu
356358
)
357359
set_target_properties(
358360
raft_distance_lib
@@ -410,23 +412,6 @@ if(RAFT_COMPILE_NN_LIBRARY)
410412
src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu
411413
src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu
412414
src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu
413-
src/nn/specializations/detail/ivfpq_compute_similarity_float_fast.cu
414-
src/nn/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu
415-
src/nn/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu
416-
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu
417-
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu
418-
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu
419-
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu
420-
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu
421-
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu
422-
src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu
423-
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
424-
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
425-
src/nn/specializations/detail/ivfpq_build.cu
426-
src/nn/specializations/detail/ivfpq_search.cu
427-
src/nn/specializations/detail/ivfpq_search_float_int64_t.cu
428-
src/nn/specializations/detail/ivfpq_search_float_uint32_t.cu
429-
src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu
430415
src/nn/specializations/fused_l2_knn_long_float_true.cu
431416
src/nn/specializations/fused_l2_knn_long_float_false.cu
432417
src/nn/specializations/fused_l2_knn_int_float_true.cu
@@ -519,7 +504,7 @@ if(TARGET raft_distance_lib)
519504
EXPORT raft-distance-lib-exports
520505
)
521506
install(
522-
DIRECTORY include/raft_distance
507+
DIRECTORY include/raft_runtime
523508
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
524509
COMPONENT distance
525510
)

cpp/include/raft/distance/distance.cuh

Lines changed: 85 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@
2525

2626
#include <raft/core/device_mdspan.hpp>
2727

28+
namespace raft {
29+
namespace distance {
30+
2831
/**
29-
* @defgroup pairwise_distance pairwise distance prims
32+
* @defgroup pairwise_distance pointer-based pairwise distance prims
3033
* @{
3134
*/
3235

33-
namespace raft {
34-
namespace distance {
35-
3636
/**
3737
* @brief Evaluate pairwise distances with the user epilogue lambda allowed
3838
* @tparam DistanceType which distance to evaluate
@@ -219,58 +219,6 @@ void distance(const InType* x,
219219
x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor, metric_arg);
220220
}
221221

222-
/**
223-
* @brief Evaluate pairwise distances for the simple use case.
224-
*
225-
* Note: Only contiguous row- or column-major layouts supported currently.
226-
*
227-
* @tparam DistanceType which distance to evaluate
228-
* @tparam InType input argument type
229-
* @tparam AccType accumulation type
230-
* @tparam OutType output type
231-
* @tparam Index_ Index type
232-
* @param handle raft handle for managing expensive resources
233-
* @param x first set of points (size n*k)
234-
* @param y second set of points (size m*k)
235-
* @param dist output distance matrix (size n*m)
236-
* @param metric_arg metric argument (used for Minkowski distance)
237-
*/
238-
template <raft::distance::DistanceType distanceType,
239-
typename InType,
240-
typename AccType,
241-
typename OutType,
242-
typename layout = raft::layout_c_contiguous,
243-
typename Index_ = int>
244-
void distance(raft::handle_t const& handle,
245-
raft::device_matrix_view<InType, Index_, layout> const x,
246-
raft::device_matrix_view<InType, Index_, layout> const y,
247-
raft::device_matrix_view<OutType, Index_, layout> dist,
248-
InType metric_arg = 2.0f)
249-
{
250-
RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal.");
251-
RAFT_EXPECTS(dist.extent(0) == x.extent(0),
252-
"Number of rows in output must be equal to "
253-
"number of rows in X");
254-
RAFT_EXPECTS(dist.extent(1) == y.extent(0),
255-
"Number of columns in output must be equal to "
256-
"number of rows in Y");
257-
258-
RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous.");
259-
RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous.");
260-
261-
constexpr auto is_rowmajor = std::is_same_v<layout, layout_c_contiguous>;
262-
263-
distance<distanceType, InType, AccType, OutType, Index_>(x.data_handle(),
264-
y.data_handle(),
265-
dist.data_handle(),
266-
x.extent(0),
267-
y.extent(0),
268-
x.extent(1),
269-
handle.get_stream(),
270-
is_rowmajor,
271-
metric_arg);
272-
}
273-
274222
/**
275223
* @brief Convenience wrapper around 'distance' prim to convert runtime metric
276224
* into compile time for the purpose of dispatch
@@ -401,6 +349,85 @@ void pairwise_distance(const raft::handle_t& handle,
401349
handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg);
402350
}
403351

352+
/** @} */
353+
354+
/**
355+
* @defgroup distance_mdspan Pairwise distance functions
356+
* @{
357+
*/
358+
359+
/**
360+
* @brief Evaluate pairwise distances for the simple use case.
361+
*
362+
* Note: Only contiguous row- or column-major layouts supported currently.
363+
*
364+
* Usage example:
365+
* @code{.cpp}
366+
* #include <raft/core/handle.hpp>
367+
* #include <raft/core/device_mdarray.hpp>
368+
* #include <raft/random/make_blobs.cuh>
369+
* #include <raft/distance/distance.cuh>
370+
*
371+
* raft::handle_t handle;
372+
* int n_samples = 5000;
373+
* int n_features = 50;
374+
*
375+
* auto input = raft::make_device_matrix<float>(handle, n_samples, n_features);
376+
* auto labels = raft::make_device_vector<int>(handle, n_samples);
377+
* auto output = raft::make_device_matrix<float>(handle, n_samples, n_samples);
378+
*
379+
* raft::random::make_blobs(handle, input.view(), labels.view());
380+
* auto metric = raft::distance::DistanceType::L2SqrtExpanded;
381+
* raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric);
382+
* @endcode
383+
*
384+
* @tparam distanceType which distance to evaluate
385+
* @tparam InType input argument type
386+
* @tparam AccType accumulation type
387+
* @tparam OutType output type
388+
* @tparam Index_ Index type
389+
* @param handle raft handle for managing expensive resources
390+
* @param x first set of points (size n*k)
391+
* @param y second set of points (size m*k)
392+
* @param dist output distance matrix (size n*m)
393+
* @param metric_arg metric argument (used for Minkowski distance)
394+
*/
395+
template <raft::distance::DistanceType distanceType,
396+
typename InType,
397+
typename AccType,
398+
typename OutType,
399+
typename layout = raft::layout_c_contiguous,
400+
typename Index_ = int>
401+
void distance(raft::handle_t const& handle,
402+
raft::device_matrix_view<InType, Index_, layout> const x,
403+
raft::device_matrix_view<InType, Index_, layout> const y,
404+
raft::device_matrix_view<OutType, Index_, layout> dist,
405+
InType metric_arg = 2.0f)
406+
{
407+
RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal.");
408+
RAFT_EXPECTS(dist.extent(0) == x.extent(0),
409+
"Number of rows in output must be equal to "
410+
"number of rows in X");
411+
RAFT_EXPECTS(dist.extent(1) == y.extent(0),
412+
"Number of columns in output must be equal to "
413+
"number of rows in Y");
414+
415+
RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous.");
416+
RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous.");
417+
418+
constexpr auto is_rowmajor = std::is_same_v<layout, layout_c_contiguous>;
419+
420+
distance<distanceType, InType, AccType, OutType, Index_>(x.data_handle(),
421+
y.data_handle(),
422+
dist.data_handle(),
423+
x.extent(0),
424+
y.extent(0),
425+
x.extent(1),
426+
handle.get_stream(),
427+
is_rowmajor,
428+
metric_arg);
429+
}
430+
404431
/**
405432
* @brief Convenience wrapper around 'distance' prim to convert runtime metric
406433
* into compile time for the purpose of dispatch
@@ -449,9 +476,9 @@ void pairwise_distance(raft::handle_t const& handle,
449476
metric_arg);
450477
}
451478

479+
/** @} */
480+
452481
}; // namespace distance
453482
}; // namespace raft
454483

455-
/** @} */
456-
457484
#endif

0 commit comments

Comments
 (0)