2121#include < cuvs/selection/select_k.hpp>
2222
2323#include < cuvs/neighbors/brute_force.hpp>
24+ #include < raft/core/host_mdarray.hpp>
2425#include < raft/core/resource/cuda_stream.hpp>
2526#include < raft/linalg/transpose.cuh>
2627#include < raft/matrix/init.cuh>
@@ -210,14 +211,15 @@ struct RandomKNNInputs {
210211 int k;
211212 cuvs::distance::DistanceType metric;
212213 bool row_major;
214+ bool host_dataset;
213215};
214216
215217std::ostream& operator <<(std::ostream& os, const RandomKNNInputs& input)
216218{
217219 return os << " num_queries:" << input.num_queries << " num_vecs:" << input.num_db_vecs
218220 << " dim:" << input.dim << " k:" << input.k
219221 << " metric:" << cuvs::neighbors::print_metric{input.metric }
220- << " row_major:" << input.row_major ;
222+ << " row_major:" << input.row_major << " host_dataset: " << input. host_dataset ;
221223}
222224
223225template <typename T, typename DistT = T>
@@ -399,12 +401,15 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam<RandomKNNInputs>
399401
400402 cuvs::neighbors::brute_force::search_params search_params;
401403
402- if (params_.row_major ) {
403- auto idx =
404- cuvs::neighbors::brute_force::build (handle_,
405- index_params,
406- raft::make_device_matrix_view<const T, int64_t >(
407- database.data (), params_.num_db_vecs , params_.dim ));
404+ if (params_.host_dataset ) {
405+ // test building from a dataset in host memory
406+ auto host_database =
407+ raft::make_host_matrix<T, int64_t , raft::row_major>(params_.num_db_vecs , params_.dim );
408+ raft::copy (
409+ host_database.data_handle (), database.data (), params_.num_db_vecs * params_.dim , stream_);
410+
411+ auto idx = cuvs::neighbors::brute_force::build (
412+ handle_, index_params, raft::make_const_mdspan (host_database.view ()));
408413
409414 cuvs::neighbors::brute_force::search (
410415 handle_,
@@ -416,21 +421,39 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam<RandomKNNInputs>
416421 distances,
417422 cuvs::neighbors::filtering::none_sample_filter{});
418423 } else {
419- auto idx = cuvs::neighbors::brute_force::build (
420- handle_,
421- index_params,
422- raft::make_device_matrix_view<const T, int64_t , raft::col_major>(
423- database.data (), params_.num_db_vecs , params_.dim ));
424+ if (params_.row_major ) {
425+ auto idx =
426+ cuvs::neighbors::brute_force::build (handle_,
427+ index_params,
428+ raft::make_device_matrix_view<const T, int64_t >(
429+ database.data (), params_.num_db_vecs , params_.dim ));
424430
425- cuvs::neighbors::brute_force::search (
426- handle_,
427- search_params,
428- idx,
429- raft::make_device_matrix_view<const T, int64_t , raft::col_major>(
430- search_queries.data (), params_.num_queries , params_.dim ),
431- indices,
432- distances,
433- cuvs::neighbors::filtering::none_sample_filter{});
431+ cuvs::neighbors::brute_force::search (
432+ handle_,
433+ search_params,
434+ idx,
435+ raft::make_device_matrix_view<const T, int64_t >(
436+ search_queries.data (), params_.num_queries , params_.dim ),
437+ indices,
438+ distances,
439+ cuvs::neighbors::filtering::none_sample_filter{});
440+ } else {
441+ auto idx = cuvs::neighbors::brute_force::build (
442+ handle_,
443+ index_params,
444+ raft::make_device_matrix_view<const T, int64_t , raft::col_major>(
445+ database.data (), params_.num_db_vecs , params_.dim ));
446+
447+ cuvs::neighbors::brute_force::search (
448+ handle_,
449+ search_params,
450+ idx,
451+ raft::make_device_matrix_view<const T, int64_t , raft::col_major>(
452+ search_queries.data (), params_.num_queries , params_.dim ),
453+ indices,
454+ distances,
455+ cuvs::neighbors::filtering::none_sample_filter{});
456+ }
434457 }
435458
436459 ASSERT_TRUE (cuvs::neighbors::devArrMatchKnnPair (ref_indices_.data (),
@@ -480,42 +503,51 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam<RandomKNNInputs>
480503
481504const std::vector<RandomKNNInputs> random_inputs = {
482505 // test each distance metric on a small-ish input, with row-major inputs
483- {100 , 256 , 2 , 65 , cuvs::distance::DistanceType::L2Expanded, true },
484- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2Unexpanded, true },
485- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2SqrtExpanded, true },
486- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2SqrtUnexpanded, true },
487- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L1, true },
488- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::Linf, true },
489- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::InnerProduct, true },
490- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::CorrelationExpanded, true },
491- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::CosineExpanded, true },
492- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::LpUnexpanded, true },
493- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::JensenShannon, true },
494- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2SqrtExpanded, true },
495- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::Canberra, true },
506+ {100 , 256 , 2 , 65 , cuvs::distance::DistanceType::L2Expanded, true , false },
507+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2Unexpanded, true , false },
508+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2SqrtExpanded, true , false },
509+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2SqrtUnexpanded, true , false },
510+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L1, true , false },
511+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::Linf, true , false },
512+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::InnerProduct, true , false },
513+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::CorrelationExpanded, true , false },
514+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::CosineExpanded, true , false },
515+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::LpUnexpanded, true , false },
516+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::JensenShannon, true , false },
517+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2SqrtExpanded, true , false },
518+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::Canberra, true , false },
496519 // test each distance metric with col-major inputs
497- {256 , 512 , 16 , 7 , cuvs::distance::DistanceType::L2Expanded, false },
498- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2Unexpanded, false },
499- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2SqrtExpanded, false },
500- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2SqrtUnexpanded, false },
501- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L1, false },
502- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::Linf, false },
503- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::InnerProduct, false },
504- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::CorrelationExpanded, false },
505- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::CosineExpanded, false },
506- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::LpUnexpanded, false },
507- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::JensenShannon, false },
508- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2SqrtExpanded, false },
509- {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::Canberra, false },
520+ {256 , 512 , 16 , 7 , cuvs::distance::DistanceType::L2Expanded, false , false },
521+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2Unexpanded, false , false },
522+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2SqrtExpanded, false , false },
523+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2SqrtUnexpanded, false , false },
524+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L1, false , false },
525+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::Linf, false , false },
526+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::InnerProduct, false , false },
527+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::CorrelationExpanded, false , false },
528+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::CosineExpanded, false , false },
529+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::LpUnexpanded, false , false },
530+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::JensenShannon, false , false },
531+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2SqrtExpanded, false , false },
532+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::Canberra, false , false },
510533 // larger tests on different sized data / k values
511- {10000 , 40000 , 32 , 30 , cuvs::distance::DistanceType::L2Expanded, false },
512- {345 , 1023 , 16 , 128 , cuvs::distance::DistanceType::CosineExpanded, true },
513- {789 , 20516 , 64 , 256 , cuvs::distance::DistanceType::L2SqrtExpanded, false },
514- {1000 , 200000 , 128 , 128 , cuvs::distance::DistanceType::L2Expanded, true },
515- {1000 , 200000 , 128 , 128 , cuvs::distance::DistanceType::L2Expanded, false },
516- {1000 , 5000 , 128 , 128 , cuvs::distance::DistanceType::LpUnexpanded, true },
517- {1000 , 5000 , 128 , 128 , cuvs::distance::DistanceType::L2SqrtExpanded, false },
518- {1000 , 5000 , 128 , 128 , cuvs::distance::DistanceType::InnerProduct, false }};
534+ {10000 , 40000 , 32 , 30 , cuvs::distance::DistanceType::L2Expanded, false , false },
535+ {345 , 1023 , 16 , 128 , cuvs::distance::DistanceType::CosineExpanded, true , false },
536+ {789 , 20516 , 64 , 256 , cuvs::distance::DistanceType::L2SqrtExpanded, false , false },
537+ {1000 , 200000 , 128 , 128 , cuvs::distance::DistanceType::L2Expanded, true , false },
538+ {1000 , 200000 , 128 , 128 , cuvs::distance::DistanceType::L2Expanded, false , false },
539+ {1000 , 5000 , 128 , 128 , cuvs::distance::DistanceType::LpUnexpanded, true , false },
540+ {1000 , 5000 , 128 , 128 , cuvs::distance::DistanceType::L2SqrtExpanded, false , false },
541+ {1000 , 5000 , 128 , 128 , cuvs::distance::DistanceType::InnerProduct, false , false },
542+ // test with datasets on host memory
543+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L2Expanded, true , true },
544+ {256 , 512 , 32 , 16 , cuvs::distance::DistanceType::L2Unexpanded, true , true },
545+ {256 , 512 , 8 , 8 , cuvs::distance::DistanceType::L2SqrtExpanded, true , true },
546+ {256 , 128 , 32 , 8 , cuvs::distance::DistanceType::L2SqrtUnexpanded, true , true },
547+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::L1, true , true },
548+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::Linf, true , true },
549+ {256 , 512 , 16 , 8 , cuvs::distance::DistanceType::InnerProduct, true , true },
550+ {256 , 512 , 16 , 7 , cuvs::distance::DistanceType::L2Expanded, true , true }};
519551
520552typedef RandomBruteForceKNNTest<float , float > RandomBruteForceKNNTestF;
521553TEST_P (RandomBruteForceKNNTestF, BruteForce) { this ->testBruteForce (); }
0 commit comments