1717
1818#include " cuvs_ann_bench_utils.h"
1919#include " cuvs_cagra_wrapper.h"
20- #include < cuvs/neighbors/mg .hpp>
21- #include < raft/core/resource/nccl_clique .hpp>
20+ #include < cuvs/neighbors/cagra .hpp>
21+ #include < raft/core/device_resources_snmg .hpp>
2222
2323namespace cuvs ::bench {
2424using namespace cuvs ::neighbors;
@@ -33,21 +33,20 @@ class cuvs_mg_cagra : public algo<T>, public algo_gpu {
3333 using algo<T>::dim_;
3434
3535 struct build_param : public cuvs ::bench::cuvs_cagra<T, IdxT>::build_param {
36- cuvs::neighbors::mg:: distribution_mode mode;
36+ cuvs::neighbors::distribution_mode mode;
3737 };
3838
3939 struct search_param : public cuvs ::bench::cuvs_cagra<T, IdxT>::search_param {
40- cuvs::neighbors::mg:: sharded_merge_mode merge_mode;
40+ cuvs::neighbors::sharded_merge_mode merge_mode;
4141 };
4242
4343 cuvs_mg_cagra (Metric metric, int dim, const build_param& param, int concurrent_searches = 1 )
44- : algo<T>(metric, dim), index_params_(param)
44+ : algo<T>(metric, dim), index_params_(param), clique_()
4545 {
4646 index_params_.cagra_params .metric = parse_metric_type (metric);
4747 index_params_.ivf_pq_build_params ->metric = parse_metric_type (metric);
4848
49- // init nccl clique outside as to not affect benchmark
50- const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique (handle_);
49+ clique_.set_memory_pool (80 );
5150 }
5251
5352 void build (const T* dataset, size_t nrow) final ;
@@ -69,7 +68,7 @@ class cuvs_mg_cagra : public algo<T>, public algo_gpu {
6968
7069 [[nodiscard]] auto get_sync_stream () const noexcept -> cudaStream_t override
7170 {
72- auto stream = raft::resource::get_cuda_stream (handle_ );
71+ auto stream = raft::resource::get_cuda_stream (clique_ );
7372 return stream;
7473 }
7574
@@ -87,11 +86,11 @@ class cuvs_mg_cagra : public algo<T>, public algo_gpu {
8786 std::unique_ptr<algo<T>> copy () override ;
8887
8988 private:
90- raft::device_resources handle_ ;
89+ raft::device_resources_snmg clique_ ;
9190 float refine_ratio_;
9291 build_param index_params_;
93- cuvs::neighbors::mg::search_params <cagra::search_params> search_params_;
94- std::shared_ptr<cuvs::neighbors::mg::index <cuvs::neighbors::cagra::index<T, IdxT>, T, IdxT>>
92+ cuvs::neighbors::mg_search_params <cagra::search_params> search_params_;
93+ std::shared_ptr<cuvs::neighbors::mg_index <cuvs::neighbors::cagra::index<T, IdxT>, T, IdxT>>
9594 index_;
9695};
9796
@@ -100,14 +99,14 @@ void cuvs_mg_cagra<T, IdxT>::build(const T* dataset, size_t nrow)
10099{
101100 auto dataset_extents = raft::make_extents<IdxT>(nrow, dim_);
102101 index_params_.prepare_build_params (dataset_extents);
103- cuvs::neighbors::mg::index_params <cagra::index_params> build_params = index_params_.cagra_params ;
104- build_params.mode = index_params_.mode ;
102+ cuvs::neighbors::mg_index_params <cagra::index_params> build_params = index_params_.cagra_params ;
103+ build_params.mode = index_params_.mode ;
105104
106105 auto dataset_view =
107106 raft::make_host_matrix_view<const T, int64_t , raft::row_major>(dataset, nrow, dim_);
108- auto idx = cuvs::neighbors::mg ::build (handle_ , build_params, dataset_view);
107+ auto idx = cuvs::neighbors::cagra ::build (clique_ , build_params, dataset_view);
109108 index_ =
110- std::make_shared<cuvs::neighbors::mg::index <cuvs::neighbors::cagra::index<T, IdxT>, T, IdxT>>(
109+ std::make_shared<cuvs::neighbors::mg_index <cuvs::neighbors::cagra::index<T, IdxT>, T, IdxT>>(
111110 std::move (idx));
112111}
113112
@@ -118,8 +117,7 @@ void cuvs_mg_cagra<T, IdxT>::set_search_param(const search_param_base& param,
118117 const void * filter_bitset)
119118{
120119 if (filter_bitset != nullptr ) { throw std::runtime_error (" Filtering is not supported yet." ); }
121- auto sp = dynamic_cast <const search_param&>(param);
122- // search_params_ = static_cast<mg::search_params<cagra::search_params>>(sp.p);
120+ auto sp = dynamic_cast <const search_param&>(param);
123121 cagra::search_params* search_params_ptr_ = static_cast <cagra::search_params*>(&search_params_);
124122 *search_params_ptr_ = sp.p ;
125123 search_params_.merge_mode = sp.merge_mode ;
@@ -134,15 +132,15 @@ void cuvs_mg_cagra<T, IdxT>::set_search_dataset(const T* dataset, size_t nrow)
134132template <typename T, typename IdxT>
135133void cuvs_mg_cagra<T, IdxT>::save(const std::string& file) const
136134{
137- cuvs::neighbors::mg ::serialize (handle_ , *index_, file);
135+ cuvs::neighbors::cagra ::serialize (clique_ , *index_, file);
138136}
139137
140138template <typename T, typename IdxT>
141139void cuvs_mg_cagra<T, IdxT>::load(const std::string& file)
142140{
143141 index_ =
144- std::make_shared<cuvs::neighbors::mg::index <cuvs::neighbors::cagra::index<T, IdxT>, T, IdxT>>(
145- std::move (cuvs::neighbors::mg::deserialize_cagra <T, IdxT>(handle_ , file)));
142+ std::make_shared<cuvs::neighbors::mg_index <cuvs::neighbors::cagra::index<T, IdxT>, T, IdxT>>(
143+ std::move (cuvs::neighbors::cagra::deserialize <T, IdxT>(clique_ , file)));
146144}
147145
148146template <typename T, typename IdxT>
@@ -165,8 +163,8 @@ void cuvs_mg_cagra<T, IdxT>::search_base(
165163 auto distances_view =
166164 raft::make_host_matrix_view<float , int64_t , raft::row_major>(distances, batch_size, k);
167165
168- cuvs::neighbors::mg ::search (
169- handle_ , *index_, search_params_, queries_view, neighbors_view, distances_view);
166+ cuvs::neighbors::cagra ::search (
167+ clique_ , *index_, search_params_, queries_view, neighbors_view, distances_view);
170168}
171169
172170template <typename T, typename IdxT>
0 commit comments