Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions cpp/src/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ void build_knn_graph(
* // build KNN graph not using `cagra::build_knn_graph`
* // build(knn_graph, dataset, ...);
* // sort graph index
* sort_knn_graph(res, dataset.view(), knn_graph.view());
* sort_knn_graph(res, build_params.metric, dataset.view(), knn_graph.view());
* // optimize graph
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
Expand All @@ -184,6 +184,7 @@ void build_knn_graph(
* @tparam IdxT type of the dataset vector indices
*
* @param[in] res raft resources
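* @param[in] metric distance metric used to compute the node-to-neighbor distances the graph is sorted by (InnerProduct, CosineExpanded, or L2Expanded)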
* @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
* @param[in,out] knn_graph a matrix view (host or device) of the input knn graph [n_rows,
* knn_graph_degree]
Expand All @@ -197,6 +198,7 @@ template <
raft::host_device_accessor<std::experimental::default_accessor<IdxT>, raft::memory_type::host>>
void sort_knn_graph(
raft::resources const& res,
cuvs::distance::DistanceType metric,
raft::mdspan<const DataT, raft::matrix_extent<int64_t>, raft::row_major, d_accessor> dataset,
raft::mdspan<IdxT, raft::matrix_extent<int64_t>, raft::row_major, g_accessor> knn_graph)
{
Expand All @@ -215,7 +217,7 @@ void sort_knn_graph(
raft::mdspan<const DataT, raft::matrix_extent<int64_t>, raft::row_major, d_accessor>(
dataset.data_handle(), dataset.extent(0), dataset.extent(1));

cagra::detail::graph::sort_knn_graph(res, dataset_internal, knn_graph_internal);
cagra::detail::graph::sort_knn_graph(res, metric, dataset_internal, knn_graph_internal);
}

/**
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@ void build_knn_graph(
nn_descent_idx.graph().extent(0),
nn_descent_idx.graph().extent(1));

cuvs::neighbors::cagra::detail::graph::sort_knn_graph(res, dataset, knn_graph_internal);
cuvs::neighbors::cagra::detail::graph::sort_knn_graph(
res, build_params.metric, dataset, knn_graph_internal);
}

template <
Expand Down
62 changes: 51 additions & 11 deletions cpp/src/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size,
const uint32_t dataset_dim,
IdxT* const knn_graph, // [graph_chunk_size, graph_degree]
const uint32_t graph_size,
const uint32_t graph_degree)
const uint32_t graph_degree,
const cuvs::distance::DistanceType metric)
{
const IdxT srcNode = (blockDim.x * blockIdx.x + threadIdx.x) / raft::WarpSize;
if (srcNode >= graph_size) { return; }
Expand All @@ -91,19 +92,46 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size,
// Compute distance from a src node to its neighbors
for (int k = 0; k < graph_degree; k++) {
const IdxT dstNode = knn_graph[k + static_cast<uint64_t>(graph_degree) * srcNode];
float dist = 0.0;
for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) {
float diff = cuvs::spatial::knn::detail::utils::mapping<float>{}(
dataset[d + static_cast<uint64_t>(dataset_dim) * srcNode]) -
cuvs::spatial::knn::detail::utils::mapping<float>{}(
dataset[d + static_cast<uint64_t>(dataset_dim) * dstNode]);
dist += diff * diff;
float dist = 0;
float norm2_dst = 0;
if (metric == cuvs::distance::DistanceType::InnerProduct ||
metric == cuvs::distance::DistanceType::CosineExpanded) {
for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) {
auto elem_b = cuvs::spatial::knn::detail::utils::mapping<float>{}(
dataset[d + static_cast<uint64_t>(dataset_dim) * dstNode]);
dist -= cuvs::spatial::knn::detail::utils::mapping<float>{}(
dataset[d + static_cast<uint64_t>(dataset_dim) * srcNode]) *
elem_b;

if (metric == cuvs::distance::DistanceType::CosineExpanded) {
norm2_dst += elem_b * elem_b;
}
}
} else {
// L2Expanded
for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) {
float diff = cuvs::spatial::knn::detail::utils::mapping<float>{}(
dataset[d + static_cast<uint64_t>(dataset_dim) * srcNode]) -
cuvs::spatial::knn::detail::utils::mapping<float>{}(
dataset[d + static_cast<uint64_t>(dataset_dim) * dstNode]);
dist += diff * diff;
}
}
dist += __shfl_xor_sync(0xffffffff, dist, 1);
dist += __shfl_xor_sync(0xffffffff, dist, 2);
dist += __shfl_xor_sync(0xffffffff, dist, 4);
dist += __shfl_xor_sync(0xffffffff, dist, 8);
dist += __shfl_xor_sync(0xffffffff, dist, 16);

if (metric == cuvs::distance::DistanceType::CosineExpanded) {
norm2_dst += __shfl_xor_sync(0xffffffff, norm2_dst, 1);
norm2_dst += __shfl_xor_sync(0xffffffff, norm2_dst, 2);
norm2_dst += __shfl_xor_sync(0xffffffff, norm2_dst, 4);
norm2_dst += __shfl_xor_sync(0xffffffff, norm2_dst, 8);
norm2_dst += __shfl_xor_sync(0xffffffff, norm2_dst, 16);
if (lane_id == (k % raft::WarpSize)) { dist /= sqrt(norm2_dst); }
}

if (lane_id == (k % raft::WarpSize)) {
my_keys[k / raft::WarpSize] = dist;
my_vals[k / raft::WarpSize] = dstNode;
Expand Down Expand Up @@ -471,11 +499,17 @@ template <
raft::host_device_accessor<std::experimental::default_accessor<IdxT>, raft::memory_type::host>>
void sort_knn_graph(
raft::resources const& res,
const cuvs::distance::DistanceType metric,
raft::mdspan<const DataT, raft::matrix_extent<int64_t>, raft::row_major, d_accessor> dataset,
raft::mdspan<IdxT, raft::matrix_extent<int64_t>, raft::row_major, g_accessor> knn_graph)
{
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
"dataset size is expected to have the same number of graph index size");
RAFT_EXPECTS(
metric == cuvs::distance::DistanceType::InnerProduct ||
metric == cuvs::distance::DistanceType::CosineExpanded ||
metric == cuvs::distance::DistanceType::L2Expanded,
"Unsupported metric. Only InnerProduct, CosineExpanded, and L2Expanded are supported");
const uint64_t dataset_size = dataset.extent(0);
const uint64_t dataset_dim = dataset.extent(1);
const DataT* dataset_ptr = dataset.data_handle();
Expand Down Expand Up @@ -507,8 +541,13 @@ void sort_knn_graph(
graph_size * input_graph_degree,
raft::resource::get_cuda_stream(res));

void (*kernel_sort)(
const DataT* const, const IdxT, const uint32_t, IdxT* const, const uint32_t, const uint32_t);
void (*kernel_sort)(const DataT* const,
const IdxT,
const uint32_t,
IdxT* const,
const uint32_t,
const uint32_t,
const cuvs::distance::DistanceType);
if (input_graph_degree <= 32) {
constexpr int numElementsPerThread = 1;
kernel_sort = kern_sort<DataT, IdxT, numElementsPerThread>;
Expand Down Expand Up @@ -545,7 +584,8 @@ void sort_knn_graph(
dataset_dim,
d_input_graph.data_handle(),
graph_size,
input_graph_degree);
input_graph_degree,
metric);
raft::resource::sync_stream(res);
RAFT_LOG_DEBUG(".");
raft::copy(input_graph_ptr,
Expand Down