Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions cpp/include/cuvs/neighbors/refine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,51 @@ void refine(raft::resources const& handle,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);

/**
* @brief Refine nearest neighbor search.
*
* Refinement is an operation that follows an approximate NN search. The approximate search has
* already selected n_candidates neighbor candidates for each query. We narrow it down to k
* neighbors. For each query, we calculate the exact distance between the query and its
* n_candidates neighbor candidate, and select the k nearest ones.
*
* The k nearest neighbors and distances are returned.
*
* Example usage
* @code{.cpp}
* using namespace cuvs::neighbors;
* // use default index parameters
* ivf_pq::index_params index_params;
* // create and fill the index from a [N, D] dataset
* auto index = ivf_pq::build(handle, index_params, dataset);
* // use default search parameters
* ivf_pq::search_params search_params;
* // search m = 4 * k nearest neighbours for each of the N queries
* ivf_pq::search(handle, search_params, index, queries, neighbor_candidates,
* out_dists_tmp);
* // refine it to the k nearest one
* refine(handle, dataset, queries, neighbor_candidates, out_indices, out_dists,
* index.metric());
* @endcode
*
*
* @param[in] handle the raft handle
* @param[in] dataset device matrix that stores the dataset [n_rows, dims]
* @param[in] queries device matrix of the queries [n_queris, dims]
* @param[in] neighbor_candidates indices of candidate vectors [n_queries, n_candidates], where
* n_candidates >= k
* @param[out] indices device matrix that stores the refined indices [n_queries, k]
* @param[out] distances device matrix that stores the refined distances [n_queries, k]
* @param[in] metric distance metric to use. Euclidean (L2) is used by default
*/
void refine(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<const uint32_t, int64_t, raft::row_major> neighbor_candidates,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> indices,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);

/**
* @brief Refine nearest neighbor search.
*
Expand Down
1 change: 1 addition & 0 deletions cpp/src/neighbors/ivf_flat_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ void index<T, IdxT>::check_consistency()
"inconsistent number of lists (clusters)");
}

template struct index<float, uint32_t>; // Used for refine function
template struct index<float, int64_t>;
template struct index<half, int64_t>;
template struct index<int8_t, int64_t>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,6 @@
}

instantiate_cuvs_neighbors_refine_d(int64_t, float, float, int64_t);
instantiate_cuvs_neighbors_refine_d(uint32_t, float, float, int64_t);

#undef instantiate_cuvs_neighbors_refine_d
13 changes: 7 additions & 6 deletions cpp/src/neighbors/refine/refine_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,13 @@ void refine_device(
cuvs::neighbors::ivf_flat::index<data_t, idx_t> refinement_index(
handle, cuvs::distance::DistanceType(metric), n_queries, false, true, dim);

cuvs::neighbors::ivf_flat::detail::fill_refinement_index(handle,
&refinement_index,
dataset.data_handle(),
neighbor_candidates.data_handle(),
n_queries,
n_candidates);
cuvs::neighbors::ivf_flat::detail::fill_refinement_index<data_t, idx_t>(
handle,
&refinement_index,
dataset.data_handle(),
neighbor_candidates.data_handle(),
(idx_t)n_queries,
(uint32_t)n_candidates);
uint32_t grid_dim_x = 1;

// the neighbor ids will be computed in uint32_t as offset
Expand Down
Loading