Skip to content

Commit 1e548f8

Browse files
authored
Allow brute_force::build to work on host matrix dataset (#562)
Closes #538 Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Micka (https://github.com/lowener) URL: #562
1 parent 2a10353 commit 1e548f8

File tree

4 files changed

+140
-59
lines changed

4 files changed

+140
-59
lines changed

cpp/include/cuvs/neighbors/brute_force.hpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,20 @@ auto build(raft::resources const& handle,
204204
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset)
205205
-> cuvs::neighbors::brute_force::index<float, float>;
206206

207+
/**
208+
* @brief Build the index from the dataset for efficient search.
209+
*
210+
* @param[in] handle the raft resources handle
211+
* @param[in] index_params parameters such as the distance metric to use
212+
* @param[in] dataset a host matrix view to a row-major matrix [n_rows, dim]
213+
*
214+
* @return the constructed brute-force index
215+
*/
216+
auto build(raft::resources const& handle,
217+
const cuvs::neighbors::brute_force::index_params& index_params,
218+
raft::host_matrix_view<const float, int64_t, raft::row_major> dataset)
219+
-> cuvs::neighbors::brute_force::index<float, float>;
220+
207221
[[deprecated]] auto build(
208222
raft::resources const& handle,
209223
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
@@ -231,6 +245,20 @@ auto build(raft::resources const& handle,
231245
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset)
232246
-> cuvs::neighbors::brute_force::index<half, float>;
233247

248+
/**
249+
* @brief Build the index from the dataset for efficient search.
250+
*
251+
* @param[in] handle the raft resources handle
252+
* @param[in] index_params parameters such as the distance metric to use
253+
* @param[in] dataset a host matrix view to a row-major matrix [n_rows, dim]
254+
*
255+
* @return the constructed brute-force index
256+
*/
257+
auto build(raft::resources const& handle,
258+
const cuvs::neighbors::brute_force::index_params& index_params,
259+
raft::host_matrix_view<const half, int64_t, raft::row_major> dataset)
260+
-> cuvs::neighbors::brute_force::index<half, float>;
261+
234262
[[deprecated]] auto build(
235263
raft::resources const& handle,
236264
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset,

cpp/src/neighbors/brute_force.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,13 @@ void index<T, DistT>::update_dataset(
168168
{ \
169169
return detail::build<T, DistT>(res, dataset, index_params.metric, index_params.metric_arg); \
170170
} \
171+
auto build(raft::resources const& res, \
172+
const cuvs::neighbors::brute_force::index_params& index_params, \
173+
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset) \
174+
->cuvs::neighbors::brute_force::index<T, DistT> \
175+
{ \
176+
return detail::build<T, DistT>(res, dataset, index_params.metric, index_params.metric_arg); \
177+
} \
171178
auto build(raft::resources const& res, \
172179
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset, \
173180
cuvs::distance::DistanceType metric, \

cpp/src/neighbors/detail/knn_brute_force.cuh

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "./knn_utils.cuh"
2929

3030
#include <raft/core/bitmap.cuh>
31+
#include <raft/core/copy.cuh>
3132
#include <raft/core/device_csr_matrix.hpp>
3233
#include <raft/core/host_mdspan.hpp>
3334
#include <raft/core/resource/cuda_stream.hpp>
@@ -750,10 +751,10 @@ void search(raft::resources const& res,
750751
}
751752
}
752753

753-
template <typename T, typename DistT, typename LayoutT = raft::row_major>
754+
template <typename T, typename DistT, typename AccessorT, typename LayoutT = raft::row_major>
754755
cuvs::neighbors::brute_force::index<T, DistT> build(
755756
raft::resources const& res,
756-
raft::device_matrix_view<const T, int64_t, LayoutT> dataset,
757+
mdspan<const T, matrix_extent<int64_t>, LayoutT, AccessorT> dataset,
757758
cuvs::distance::DistanceType metric,
758759
DistT metric_arg)
759760
{
@@ -764,18 +765,31 @@ cuvs::neighbors::brute_force::index<T, DistT> build(
764765
if (metric == cuvs::distance::DistanceType::L2Expanded ||
765766
metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
766767
metric == cuvs::distance::DistanceType::CosineExpanded) {
768+
auto dataset_storage = std::optional<device_matrix<T, int64_t, LayoutT>>{};
769+
auto dataset_view = [&res, &dataset_storage, dataset]() {
770+
if constexpr (std::is_same_v<decltype(dataset),
771+
raft::device_matrix_view<const T, int64_t, row_major>>) {
772+
return dataset;
773+
} else {
774+
dataset_storage =
775+
make_device_matrix<T, int64_t, LayoutT>(res, dataset.extent(0), dataset.extent(1));
776+
raft::copy(res, dataset_storage->view(), dataset);
777+
return raft::make_const_mdspan(dataset_storage->view());
778+
}
779+
}();
780+
767781
norms = raft::make_device_vector<DistT, int64_t>(res, dataset.extent(0));
768782
// cosine needs the L2 norm, whereas L2 distances need the squared norm
769783
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
770784
raft::linalg::norm(res,
771-
dataset,
785+
dataset_view,
772786
norms->view(),
773787
raft::linalg::NormType::L2Norm,
774788
raft::linalg::Apply::ALONG_ROWS,
775789
raft::sqrt_op{});
776790
} else {
777791
raft::linalg::norm(res,
778-
dataset,
792+
dataset_view,
779793
norms->view(),
780794
raft::linalg::NormType::L2Norm,
781795
raft::linalg::Apply::ALONG_ROWS);

cpp/test/neighbors/brute_force.cu

Lines changed: 87 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <cuvs/selection/select_k.hpp>
2222

2323
#include <cuvs/neighbors/brute_force.hpp>
24+
#include <raft/core/host_mdarray.hpp>
2425
#include <raft/core/resource/cuda_stream.hpp>
2526
#include <raft/linalg/transpose.cuh>
2627
#include <raft/matrix/init.cuh>
@@ -210,14 +211,15 @@ struct RandomKNNInputs {
210211
int k;
211212
cuvs::distance::DistanceType metric;
212213
bool row_major;
214+
bool host_dataset;
213215
};
214216

215217
std::ostream& operator<<(std::ostream& os, const RandomKNNInputs& input)
216218
{
217219
return os << "num_queries:" << input.num_queries << " num_vecs:" << input.num_db_vecs
218220
<< " dim:" << input.dim << " k:" << input.k
219221
<< " metric:" << cuvs::neighbors::print_metric{input.metric}
220-
<< " row_major:" << input.row_major;
222+
<< " row_major:" << input.row_major << " host_dataset:" << input.host_dataset;
221223
}
222224

223225
template <typename T, typename DistT = T>
@@ -399,12 +401,15 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam<RandomKNNInputs>
399401

400402
cuvs::neighbors::brute_force::search_params search_params;
401403

402-
if (params_.row_major) {
403-
auto idx =
404-
cuvs::neighbors::brute_force::build(handle_,
405-
index_params,
406-
raft::make_device_matrix_view<const T, int64_t>(
407-
database.data(), params_.num_db_vecs, params_.dim));
404+
if (params_.host_dataset) {
405+
// test building from a dataset in host memory
406+
auto host_database =
407+
raft::make_host_matrix<T, int64_t, raft::row_major>(params_.num_db_vecs, params_.dim);
408+
raft::copy(
409+
host_database.data_handle(), database.data(), params_.num_db_vecs * params_.dim, stream_);
410+
411+
auto idx = cuvs::neighbors::brute_force::build(
412+
handle_, index_params, raft::make_const_mdspan(host_database.view()));
408413

409414
cuvs::neighbors::brute_force::search(
410415
handle_,
@@ -416,21 +421,39 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam<RandomKNNInputs>
416421
distances,
417422
cuvs::neighbors::filtering::none_sample_filter{});
418423
} else {
419-
auto idx = cuvs::neighbors::brute_force::build(
420-
handle_,
421-
index_params,
422-
raft::make_device_matrix_view<const T, int64_t, raft::col_major>(
423-
database.data(), params_.num_db_vecs, params_.dim));
424+
if (params_.row_major) {
425+
auto idx =
426+
cuvs::neighbors::brute_force::build(handle_,
427+
index_params,
428+
raft::make_device_matrix_view<const T, int64_t>(
429+
database.data(), params_.num_db_vecs, params_.dim));
424430

425-
cuvs::neighbors::brute_force::search(
426-
handle_,
427-
search_params,
428-
idx,
429-
raft::make_device_matrix_view<const T, int64_t, raft::col_major>(
430-
search_queries.data(), params_.num_queries, params_.dim),
431-
indices,
432-
distances,
433-
cuvs::neighbors::filtering::none_sample_filter{});
431+
cuvs::neighbors::brute_force::search(
432+
handle_,
433+
search_params,
434+
idx,
435+
raft::make_device_matrix_view<const T, int64_t>(
436+
search_queries.data(), params_.num_queries, params_.dim),
437+
indices,
438+
distances,
439+
cuvs::neighbors::filtering::none_sample_filter{});
440+
} else {
441+
auto idx = cuvs::neighbors::brute_force::build(
442+
handle_,
443+
index_params,
444+
raft::make_device_matrix_view<const T, int64_t, raft::col_major>(
445+
database.data(), params_.num_db_vecs, params_.dim));
446+
447+
cuvs::neighbors::brute_force::search(
448+
handle_,
449+
search_params,
450+
idx,
451+
raft::make_device_matrix_view<const T, int64_t, raft::col_major>(
452+
search_queries.data(), params_.num_queries, params_.dim),
453+
indices,
454+
distances,
455+
cuvs::neighbors::filtering::none_sample_filter{});
456+
}
434457
}
435458

436459
ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(ref_indices_.data(),
@@ -480,42 +503,51 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam<RandomKNNInputs>
480503

481504
const std::vector<RandomKNNInputs> random_inputs = {
482505
// test each distance metric on a small-ish input, with row-major inputs
483-
{100, 256, 2, 65, cuvs::distance::DistanceType::L2Expanded, true},
484-
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, true},
485-
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true},
486-
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, true},
487-
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, true},
488-
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, true},
489-
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, true},
490-
{256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, true},
491-
{256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, true},
492-
{256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, true},
493-
{256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, true},
494-
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true},
495-
{256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, true},
506+
{100, 256, 2, 65, cuvs::distance::DistanceType::L2Expanded, true, false},
507+
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, true, false},
508+
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true, false},
509+
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, true, false},
510+
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, true, false},
511+
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, true, false},
512+
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, true, false},
513+
{256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, true, false},
514+
{256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, true, false},
515+
{256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, true, false},
516+
{256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, true, false},
517+
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true, false},
518+
{256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, true, false},
496519
// test each distance metric with col-major inputs
497-
{256, 512, 16, 7, cuvs::distance::DistanceType::L2Expanded, false},
498-
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, false},
499-
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false},
500-
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, false},
501-
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, false},
502-
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, false},
503-
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, false},
504-
{256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, false},
505-
{256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, false},
506-
{256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, false},
507-
{256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, false},
508-
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false},
509-
{256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, false},
520+
{256, 512, 16, 7, cuvs::distance::DistanceType::L2Expanded, false, false},
521+
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, false, false},
522+
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false, false},
523+
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, false, false},
524+
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, false, false},
525+
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, false, false},
526+
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, false, false},
527+
{256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, false, false},
528+
{256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, false, false},
529+
{256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, false, false},
530+
{256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, false, false},
531+
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false, false},
532+
{256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, false, false},
510533
// larger tests on different sized data / k values
511-
{10000, 40000, 32, 30, cuvs::distance::DistanceType::L2Expanded, false},
512-
{345, 1023, 16, 128, cuvs::distance::DistanceType::CosineExpanded, true},
513-
{789, 20516, 64, 256, cuvs::distance::DistanceType::L2SqrtExpanded, false},
514-
{1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, true},
515-
{1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, false},
516-
{1000, 5000, 128, 128, cuvs::distance::DistanceType::LpUnexpanded, true},
517-
{1000, 5000, 128, 128, cuvs::distance::DistanceType::L2SqrtExpanded, false},
518-
{1000, 5000, 128, 128, cuvs::distance::DistanceType::InnerProduct, false}};
534+
{10000, 40000, 32, 30, cuvs::distance::DistanceType::L2Expanded, false, false},
535+
{345, 1023, 16, 128, cuvs::distance::DistanceType::CosineExpanded, true, false},
536+
{789, 20516, 64, 256, cuvs::distance::DistanceType::L2SqrtExpanded, false, false},
537+
{1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, true, false},
538+
{1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, false, false},
539+
{1000, 5000, 128, 128, cuvs::distance::DistanceType::LpUnexpanded, true, false},
540+
{1000, 5000, 128, 128, cuvs::distance::DistanceType::L2SqrtExpanded, false, false},
541+
{1000, 5000, 128, 128, cuvs::distance::DistanceType::InnerProduct, false, false},
542+
// test with datasets on host memory
543+
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Expanded, true, true},
544+
{256, 512, 32, 16, cuvs::distance::DistanceType::L2Unexpanded, true, true},
545+
{256, 512, 8, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true, true},
546+
{256, 128, 32, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, true, true},
547+
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, true, true},
548+
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, true, true},
549+
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, true, true},
550+
{256, 512, 16, 7, cuvs::distance::DistanceType::L2Expanded, true, true}};
519551

520552
typedef RandomBruteForceKNNTest<float, float> RandomBruteForceKNNTestF;
521553
TEST_P(RandomBruteForceKNNTestF, BruteForce) { this->testBruteForce(); }

0 commit comments

Comments
 (0)