Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 22 additions & 52 deletions cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ class cuvs_cagra : public algo<T>, public algo_gpu {
int batch_size,
int k,
algo_base::index_type* neighbors,
float* distances,
IdxT* neighbors_idx_t) const;
float* distances) const;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
Expand Down Expand Up @@ -181,7 +180,8 @@ class cuvs_cagra : public algo<T>, public algo_gpu {
std::shared_ptr<raft::device_matrix<T, int64_t, raft::row_major>> dataset_;
std::shared_ptr<raft::device_matrix_view<const T, int64_t, raft::row_major>> input_dataset_v_;

std::shared_ptr<cuvs::neighbors::dynamic_batching::index<T, IdxT>> dynamic_batcher_;
std::shared_ptr<cuvs::neighbors::dynamic_batching::index<T, algo_base::index_type>>
dynamic_batcher_;
cuvs::neighbors::dynamic_batching::search_params dynamic_batcher_sp_{};
int64_t dynamic_batching_max_batch_size_;
size_t dynamic_batching_n_queues_;
Expand Down Expand Up @@ -292,16 +292,18 @@ void cuvs_cagra<T, IdxT>::set_search_param(const search_param_base& param,
// dynamic batching
if (sp.dynamic_batching) {
if (!dynamic_batcher_ || needs_dynamic_batcher_update) {
dynamic_batcher_ = std::make_shared<cuvs::neighbors::dynamic_batching::index<T, IdxT>>(
handle_,
cuvs::neighbors::dynamic_batching::index_params{{},
sp.dynamic_batching_k,
sp.dynamic_batching_max_batch_size,
sp.dynamic_batching_n_queues,
sp.dynamic_batching_conservative_dispatch},
*index_,
search_params_,
filter_.get());
dynamic_batcher_ =
std::make_shared<cuvs::neighbors::dynamic_batching::index<T, algo_base::index_type>>(
handle_,
cuvs::neighbors::dynamic_batching::index_params{
{},
sp.dynamic_batching_k,
sp.dynamic_batching_max_batch_size,
sp.dynamic_batching_n_queues,
sp.dynamic_batching_conservative_dispatch},
*index_,
search_params_,
filter_.get());
}
dynamic_batcher_sp_.dispatch_timeout_ms = sp.dynamic_batching_dispatch_timeout_ms;
} else {
Expand Down Expand Up @@ -359,24 +361,16 @@ std::unique_ptr<algo<T>> cuvs_cagra<T, IdxT>::copy()
}

template <typename T, typename IdxT>
void cuvs_cagra<T, IdxT>::search_base(const T* queries,
int batch_size,
int k,
algo_base::index_type* neighbors,
float* distances,
IdxT* neighbors_idx_t) const
void cuvs_cagra<T, IdxT>::search_base(
const T* queries, int batch_size, int k, algo_base::index_type* neighbors, float* distances) const
{
static_assert(std::is_integral_v<algo_base::index_type>);
static_assert(std::is_integral_v<IdxT>);

if constexpr (sizeof(IdxT) == sizeof(algo_base::index_type)) {
neighbors_idx_t = reinterpret_cast<IdxT*>(neighbors);
}

auto queries_view =
raft::make_device_matrix_view<const T, int64_t>(queries, batch_size, dimension_);
auto neighbors_view =
raft::make_device_matrix_view<IdxT, int64_t>(neighbors_idx_t, batch_size, k);
raft::make_device_matrix_view<algo_base::index_type, int64_t>(neighbors, batch_size, k);
auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, batch_size, k);

if (dynamic_batcher_) {
Expand All @@ -390,26 +384,6 @@ void cuvs_cagra<T, IdxT>::search_base(const T* queries,
cuvs::neighbors::cagra::search(
handle_, search_params_, *index_, queries_view, neighbors_view, distances_view, *filter_);
}

if constexpr (sizeof(IdxT) != sizeof(algo_base::index_type)) {
if (raft::get_device_for_address(neighbors) < 0 &&
raft::get_device_for_address(neighbors_idx_t) < 0) {
// Both pointers on the host, let's use host-side mapping
if (uses_stream()) {
// Need to wait for GPU to finish filling source
raft::resource::sync_stream(handle_);
}
for (int i = 0; i < batch_size * k; i++) {
neighbors[i] = algo_base::index_type(neighbors_idx_t[i]);
}
} else {
raft::linalg::unaryOp(neighbors,
neighbors_idx_t,
batch_size * k,
raft::cast_op<algo_base::index_type>(),
raft::resource::get_cuda_stream(handle_));
}
}
}

template <typename T, typename IdxT>
Expand All @@ -418,7 +392,6 @@ void cuvs_cagra<T, IdxT>::search(
{
static_assert(std::is_integral_v<algo_base::index_type>);
static_assert(std::is_integral_v<IdxT>);
constexpr bool kNeedsIoMapping = sizeof(IdxT) != sizeof(algo_base::index_type);

auto k0 = static_cast<size_t>(refine_ratio_ * k);
const bool disable_refinement = k0 <= static_cast<size_t>(k);
Expand All @@ -434,23 +407,20 @@ void cuvs_cagra<T, IdxT>::search(
dynamic_batching_max_batch_size_, batch_size) *
dynamic_batching_n_queues_
: 1;
auto tmp_buf_size = ((disable_refinement ? 0 : (sizeof(float) + sizeof(algo_base::index_type))) +
(kNeedsIoMapping ? sizeof(IdxT) : 0)) *
batch_size * k0;
auto tmp_buf_size =
((disable_refinement ? 0 : (sizeof(float) + sizeof(algo_base::index_type)))) * batch_size * k0;
auto& tmp_buf = get_tmp_buffer_from_global_pool(tmp_buf_size * max_dyn_grouping);
thread_local static int64_t group_id = 0;
auto* candidates_ptr = reinterpret_cast<algo_base::index_type*>(
reinterpret_cast<uint8_t*>(tmp_buf.data(mem_type)) + tmp_buf_size * group_id);
group_id = (group_id + 1) % max_dyn_grouping;
auto* candidate_dists_ptr =
reinterpret_cast<float*>(candidates_ptr + (disable_refinement ? 0 : batch_size * k0));
auto* neighbors_idx_t =
reinterpret_cast<IdxT*>(candidate_dists_ptr + (disable_refinement ? 0 : batch_size * k0));

if (disable_refinement) {
search_base(queries, batch_size, k, neighbors, distances, neighbors_idx_t);
search_base(queries, batch_size, k, neighbors, distances);
} else {
search_base(queries, batch_size, k0, candidates_ptr, candidate_dists_ptr, neighbors_idx_t);
search_base(queries, batch_size, k0, candidates_ptr, candidate_dists_ptr);

if (mem_type == MemoryType::kHostPinned && uses_stream()) {
// If the algorithm uses a stream to synchronize (non-persistent kernel), but the data is in
Expand Down
1 change: 1 addition & 0 deletions cpp/include/cuvs/neighbors/cagra.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ cuvsError_t cuvsCagraExtend(cuvsResources_t res,
* c. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
* d. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
* 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32`
* or `kDLDataType.code == kDLInt` and `kDLDataType.bits = 64`
* 3. `distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
*
* @code {.c}
Expand Down
106 changes: 104 additions & 2 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ static_assert(std::is_aggregate_v<search_params>);
* The index stores the dataset and a kNN graph in device memory.
*
* @tparam T data element type
* @tparam IdxT type of the vector indices (represent dataset.extent(0))
* @tparam IdxT the data type used to store the neighbor indices in the search graph.
* It must be large enough to represent values up to dataset.extent(0).
*
*/
template <typename T, typename IdxT>
Expand Down Expand Up @@ -1116,7 +1117,7 @@ void extend(
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] idx cagra index
* @param[in] index cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
Expand Down Expand Up @@ -1210,6 +1211,107 @@ void search(raft::resources const& res,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the constructed index.
*
* This overload takes an index over `float` data whose graph indices are `uint32_t` and
* writes the resulting neighbor indices into an `int64_t` output matrix.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index.dim()]
* @param[out] neighbors a device matrix view to the `int64_t` indices of the neighbors in the
* source dataset [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors
* [n_queries, k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/

void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::index<float, uint32_t>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the constructed index.
*
* This overload takes an index over `half` data whose graph indices are `uint32_t` and
* writes the resulting neighbor indices into an `int64_t` output matrix.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index.dim()]
* @param[out] neighbors a device matrix view to the `int64_t` indices of the neighbors in the
* source dataset [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors
* [n_queries, k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::index<half, uint32_t>& index,
raft::device_matrix_view<const half, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the constructed index.
*
* This overload takes an index over `int8_t` data whose graph indices are `uint32_t` and
* writes the resulting neighbor indices into an `int64_t` output matrix.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index.dim()]
* @param[out] neighbors a device matrix view to the `int64_t` indices of the neighbors in the
* source dataset [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors
* [n_queries, k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index,
raft::device_matrix_view<const int8_t, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the constructed index.
*
* This overload takes an index over `uint8_t` data whose graph indices are `uint32_t` and
* writes the resulting neighbor indices into an `int64_t` output matrix.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index.dim()]
* @param[out] neighbors a device matrix view to the `int64_t` indices of the neighbors in the
* source dataset [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors
* [n_queries, k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index,
raft::device_matrix_view<const uint8_t, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @}
*/
Expand Down
29 changes: 10 additions & 19 deletions cpp/src/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,10 @@ index<T, IdxT> build(
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices
* @tparam IdxT type of the indices in the CAGRA graph
* @tparam CagraSampleFilterT Device filter function, with the signature
* `(uint32_t query_ix, uint32_t sample_ix) -> bool`
* @tparam OutputIdxT type of the returned indices
*
* @param[in] res raft resources
* @param[in] params configure the search
Expand All @@ -297,12 +298,12 @@ index<T, IdxT> build(
* k]
* @param[in] sample_filter a device filter function that greenlights samples for a given query
*/
template <typename T, typename IdxT, typename CagraSampleFilterT>
template <typename T, typename IdxT, typename CagraSampleFilterT, typename OutputIdxT = IdxT>
void search_with_filtering(raft::resources const& res,
const search_params& params,
const index<T, IdxT>& idx,
raft::device_matrix_view<const T, int64_t, raft::row_major> queries,
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<OutputIdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
{
Expand All @@ -315,26 +316,16 @@ void search_with_filtering(raft::resources const& res,
RAFT_EXPECTS(queries.extent(1) == idx.dim(),
"Number of query dimensions should equal number of dimensions in the index.");

using internal_IdxT = typename std::make_unsigned<IdxT>::type;
auto queries_internal = raft::make_device_matrix_view<const T, int64_t, raft::row_major>(
queries.data_handle(), queries.extent(0), queries.extent(1));
auto neighbors_internal = raft::make_device_matrix_view<internal_IdxT, int64_t, raft::row_major>(
reinterpret_cast<internal_IdxT*>(neighbors.data_handle()),
neighbors.extent(0),
neighbors.extent(1));
auto distances_internal = raft::make_device_matrix_view<float, int64_t, raft::row_major>(
distances.data_handle(), distances.extent(0), distances.extent(1));

return cagra::detail::search_main<T, internal_IdxT, CagraSampleFilterT, IdxT>(
res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter);
return cagra::detail::search_main<T, OutputIdxT, CagraSampleFilterT, IdxT>(
res, params, idx, queries, neighbors, distances, sample_filter);
}

template <typename T, typename IdxT>
template <typename T, typename IdxT, typename OutputIdxT = IdxT>
void search(raft::resources const& res,
const search_params& params,
const index<T, IdxT>& idx,
raft::device_matrix_view<const T, int64_t, raft::row_major> queries,
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<OutputIdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter_ref)
{
Expand All @@ -344,7 +335,7 @@ void search(raft::resources const& res,
search_params params_copy = params;
if (params.filtering_rate < 0.0) { params_copy.filtering_rate = 0.0; }
auto sample_filter_copy = sample_filter;
return search_with_filtering<T, IdxT, none_filter_type>(
return search_with_filtering<T, IdxT, none_filter_type, OutputIdxT>(
res, params_copy, idx, queries, neighbors, distances, sample_filter_copy);
return;
} catch (const std::bad_cast&) {
Expand All @@ -364,7 +355,7 @@ void search(raft::resources const& res,
std::min(std::max(filtering_rate, min_filtering_rate), max_filtering_rate);
}
auto sample_filter_copy = sample_filter;
return search_with_filtering<T, IdxT, decltype(sample_filter_copy)>(
return search_with_filtering<T, IdxT, decltype(sample_filter_copy), OutputIdxT>(
res, params_copy, idx, queries, neighbors, distances, sample_filter_copy);
} catch (const std::bad_cast&) {
RAFT_FAIL("Unsupported sample filter type");
Expand Down
Loading