Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,25 @@ cuvsError_t cuvsIvfFlatIndexCreate(cuvsIvfFlatIndex_t* index);
* @param[in] index cuvsIvfFlatIndex_t to de-allocate
*/
cuvsError_t cuvsIvfFlatIndexDestroy(cuvsIvfFlatIndex_t index);

/** Get the number of clusters/inverted lists */
uint32_t cuvsIvfFlatIndexGetNLists(cuvsIvfFlatIndex_t index);

/** Get the dimensionality of the data */
uint32_t cuvsIvfFlatIndexGetDim(cuvsIvfFlatIndex_t index);

/**
* @brief Get the cluster centers corresponding to the lists [n_lists, dim]
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] index cuvsIvfFlatIndex_t Built Ivf-Flat Index
* @param[out] centers Preallocated array on host or device memory to store output, [n_lists, dim]
* @return cuvsError_t
*/
cuvsError_t cuvsIvfFlatIndexGetCenters(cuvsResources_t res,
cuvsIvfFlatIndex_t index,
DLManagedTensor* centers);

/**
* @}
*/
Expand Down
9 changes: 5 additions & 4 deletions cpp/include/cuvs/neighbors/ivf_pq.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,16 @@ cuvsError_t cuvsIvfPqIndexDestroy(cuvsIvfPqIndex_t index);
/** Get the number of clusters/inverted lists */
uint32_t cuvsIvfPqIndexGetNLists(cuvsIvfPqIndex_t index);

/** Get the dimensionality of the cluster centers */
uint32_t cuvsIvfPqIndexGetDimExt(cuvsIvfPqIndex_t index);
/** Get the dimensionality */
uint32_t cuvsIvfPqIndexGetDim(cuvsIvfPqIndex_t index);

/**
* @brief Get the cluster centers corresponding to the lists in the original space
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] index cuvsIvfPqIndex_t Built NN-Descent index
* @param[out] centers Preallocated array on host memory to store output, [n_lists, dim_ext]
* @param[in] index cuvsIvfPqIndex_t Built Ivf-Pq index
* @param[out] centers Preallocated array on host or device memory to store output,
* dimensions [n_lists, dim]
* @return cuvsError_t
*/
cuvsError_t cuvsIvfPqIndexGetCenters(cuvsResources_t res,
Expand Down
4 changes: 4 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2861,6 +2861,10 @@ void extract_centers(raft::resources const& res,
const index<int64_t>& index,
raft::device_matrix_view<float, uint32_t, raft::row_major> cluster_centers);

/** @copydoc extract_centers */
void extract_centers(raft::resources const& res,
const index<int64_t>& index,
raft::host_matrix_view<float, uint32_t, raft::row_major> cluster_centers);
/**
* @brief Helper exposing the re-computation of list sizes and related arrays if IVF lists have been
* modified externally.
Expand Down
95 changes: 95 additions & 0 deletions cpp/src/neighbors/ivf_flat_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/core/serialize.hpp>
#include <raft/util/cudart_utils.hpp>

#include <cuvs/core/c_api.h>
#include <cuvs/core/exceptions.hpp>
Expand Down Expand Up @@ -139,6 +140,35 @@ void _extend(cuvsResources_t res,

cuvs::neighbors::ivf_flat::extend(*res_ptr, vectors_mds, indices_mds, index_ptr);
}

template <typename output_mdspan_type, typename T, typename IdxT>
void _get_centers(cuvsResources_t res, cuvsIvfFlatIndex index, DLManagedTensor* centers)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_flat::index<T, IdxT>*>(index.addr);
auto dst = cuvs::core::from_dlpack<output_mdspan_type>(centers);
auto src = index_ptr->centers();
Comment on lines +147 to +150
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would there be value in making any of these const?

I don't have familiarity with the code yet, so I haven't grokked the semantics of raft::copy. Apologies, if this is noise.


RAFT_EXPECTS(src.extent(0) == dst.extent(0), "Output centers has incorrect number of rows");
RAFT_EXPECTS(src.extent(1) == dst.extent(1), "Output centers has incorrect number of cols");

raft::copy(dst.data_handle(),
src.data_handle(),
dst.extent(0) * dst.extent(1),
raft::resource::get_cuda_stream(*res_ptr));
}

template <typename T, typename IdxT>
void get_centers(cuvsResources_t res, cuvsIvfFlatIndex index, DLManagedTensor* centers)
{
if (cuvs::core::is_dlpack_device_compatible(centers->dl_tensor)) {
using output_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>;
_get_centers<output_mdspan_type, T, IdxT>(res, index, centers);
} else {
using output_mdspan_type = raft::host_matrix_view<float, int64_t, raft::row_major>;
_get_centers<output_mdspan_type, T, IdxT>(res, index, centers);
}
}
} // namespace

extern "C" cuvsError_t cuvsIvfFlatIndexCreate(cuvsIvfFlatIndex_t* index)
Expand Down Expand Up @@ -351,3 +381,68 @@ extern "C" cuvsError_t cuvsIvfFlatExtend(cuvsResources_t res,
}
});
}

extern "C" uint32_t cuvsIvfFlatIndexGetNLists(cuvsIvfFlatIndex_t index)
{
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
auto index_ptr =
reinterpret_cast<cuvs::neighbors::ivf_flat::index<float, int64_t>*>(index->addr);
return index_ptr->n_lists();
} else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) {
auto index_ptr =
reinterpret_cast<cuvs::neighbors::ivf_flat::index<half, int64_t>*>(index->addr);
return index_ptr->n_lists();
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
auto index_ptr =
reinterpret_cast<cuvs::neighbors::ivf_flat::index<int8_t, int64_t>*>(index->addr);
return index_ptr->n_lists();
} else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
auto index_ptr =
reinterpret_cast<cuvs::neighbors::ivf_flat::index<uint8_t, int64_t>*>(index->addr);
return index_ptr->n_lists();
Comment on lines +387 to +402
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a recurring pattern in the code. One wonders if libcudf's type-dispatch pattern might be of value here.

We might consider exploring at a later date.

} else {
return 0;
}
}

extern "C" uint32_t cuvsIvfFlatIndexGetDim(cuvsIvfFlatIndex_t index)
{
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
auto index_ptr =
reinterpret_cast<cuvs::neighbors::ivf_flat::index<float, int64_t>*>(index->addr);
return index_ptr->dim();
} else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) {
auto index_ptr =
reinterpret_cast<cuvs::neighbors::ivf_flat::index<half, int64_t>*>(index->addr);
return index_ptr->dim();
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
auto index_ptr =
reinterpret_cast<cuvs::neighbors::ivf_flat::index<int8_t, int64_t>*>(index->addr);
return index_ptr->dim();
} else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
auto index_ptr =
reinterpret_cast<cuvs::neighbors::ivf_flat::index<uint8_t, int64_t>*>(index->addr);
return index_ptr->dim();
} else {
return 0;
}
}

extern "C" cuvsError_t cuvsIvfFlatIndexGetCenters(cuvsResources_t res,
cuvsIvfFlatIndex_t index,
DLManagedTensor* centers)
{
return cuvs::core::translate_exceptions([=] {
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
get_centers<float, int64_t>(res, *index, centers);
} else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) {
get_centers<half, int64_t>(res, *index, centers);
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
get_centers<int8_t, int64_t>(res, *index, centers);
} else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
get_centers<uint8_t, int64_t>(res, *index, centers);
} else {
RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits);
}
});
}
22 changes: 22 additions & 0 deletions cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1904,4 +1904,26 @@ void extend(
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
n_rows);
}

template <typename output_mdspan_type>
inline void extract_centers(raft::resources const& res,
const cuvs::neighbors::ivf_pq::index<int64_t>& index,
output_mdspan_type cluster_centers)
{
RAFT_EXPECTS(cluster_centers.extent(0) == index.n_lists(),
"Number of rows in the output buffer for cluster centers must be equal to the "
"number of IVF lists");
RAFT_EXPECTS(
cluster_centers.extent(1) == index.dim(),
"Number of columns in the output buffer for cluster centers and index dim are different");
auto stream = raft::resource::get_cuda_stream(res);
RAFT_CUDA_TRY(cudaMemcpy2DAsync(cluster_centers.data_handle(),
sizeof(float) * index.dim(),
index.centers().data_handle(),
sizeof(float) * index.dim_ext(),
sizeof(float) * index.dim(),
index.n_lists(),
cudaMemcpyDefault,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL cudaMemcpyDefault. I didn't know/realize that the direction could be inferred.

stream));
}
} // namespace cuvs::neighbors::ivf_pq::detail
23 changes: 8 additions & 15 deletions cpp/src/neighbors/ivf_pq/ivf_pq_build_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -294,21 +294,14 @@ void extract_centers(raft::resources const& res,
const cuvs::neighbors::ivf_pq::index<int64_t>& index,
raft::device_matrix_view<float, uint32_t, raft::row_major> cluster_centers)
{
RAFT_EXPECTS(cluster_centers.extent(0) == index.n_lists(),
"Number of rows in the output buffer for cluster centers must be equal to the "
"number of IVF lists");
RAFT_EXPECTS(
cluster_centers.extent(1) == index.dim(),
"Number of columns in the output buffer for cluster centers and index dim are different");
auto stream = raft::resource::get_cuda_stream(res);
RAFT_CUDA_TRY(cudaMemcpy2DAsync(cluster_centers.data_handle(),
sizeof(float) * index.dim(),
index.centers().data_handle(),
sizeof(float) * index.dim_ext(),
sizeof(float) * index.dim(),
index.n_lists(),
cudaMemcpyDefault,
stream));
detail::extract_centers(res, index, cluster_centers);
}

void extract_centers(raft::resources const& res,
const cuvs::neighbors::ivf_pq::index<int64_t>& index,
raft::host_matrix_view<float, uint32_t, raft::row_major> cluster_centers)
{
detail::extract_centers(res, index, cluster_centers);
}

void recompute_internal_state(const raft::resources& res, index<int64_t>* index)
Expand Down
14 changes: 3 additions & 11 deletions cpp/src/neighbors/ivf_pq_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,8 @@ void _get_centers(cuvsResources_t res, cuvsIvfPqIndex index, DLManagedTensor* ce
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr);
auto dst = cuvs::core::from_dlpack<output_mdspan_type>(centers);
auto src = index_ptr->centers();

RAFT_EXPECTS(src.extent(0) == dst.extent(0), "Output centers has incorrect number of rows");
RAFT_EXPECTS(src.extent(1) == dst.extent(1), "Output centers has incorrect number of cols");

cudaMemcpyAsync(dst.data_handle(),
src.data_handle(),
dst.extent(0) * dst.extent(1) * sizeof(float),
cudaMemcpyDefault,
raft::resource::get_cuda_stream(*res_ptr));
cuvs::neighbors::ivf_pq::helpers::extract_centers(*res_ptr, *index_ptr, dst);
}
} // namespace

Expand Down Expand Up @@ -337,10 +329,10 @@ extern "C" uint32_t cuvsIvfPqIndexGetNLists(cuvsIvfPqIndex_t index)
return index_ptr->n_lists();
}

extern "C" uint32_t cuvsIvfPqIndexGetDimExt(cuvsIvfPqIndex_t index)
extern "C" uint32_t cuvsIvfPqIndexGetDim(cuvsIvfPqIndex_t index)
{
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_pq::index<int64_t>*>(index->addr);
return index_ptr->dim_ext();
return index_ptr->dim();
}

extern "C" cuvsError_t cuvsIvfPqIndexGetCenters(cuvsResources_t res,
Expand Down
10 changes: 5 additions & 5 deletions cpp/src/neighbors/nn_descent_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/core/serialize.hpp>
#include <raft/util/cudart_utils.hpp>

#include <cuvs/core/c_api.h>
#include <cuvs/core/exceptions.hpp>
Expand Down Expand Up @@ -85,11 +86,10 @@ void _get_graph(cuvsResources_t res, cuvsNNDescentIndex_t index, DLManagedTensor
RAFT_EXPECTS(src.extent(0) == dst.extent(0), "Output graph has incorrect number of rows");
RAFT_EXPECTS(src.extent(1) == dst.extent(1), "Output graph has incorrect number of cols");

cudaMemcpyAsync(dst.data_handle(),
src.data_handle(),
dst.extent(0) * dst.extent(1) * sizeof(uint32_t),
cudaMemcpyDefault,
raft::resource::get_cuda_stream(*res_ptr));
raft::copy(dst.data_handle(),
src.data_handle(),
dst.extent(0) * dst.extent(1),
raft::resource::get_cuda_stream(*res_ptr));
} else {
RAFT_FAIL("Unsupported nn-descent index dtype: %d and bits: %d", dtype.code, dtype.bits);
}
Expand Down
8 changes: 8 additions & 0 deletions python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ cdef extern from "cuvs/neighbors/ivf_flat.h" nogil:

cuvsError_t cuvsIvfFlatIndexDestroy(cuvsIvfFlatIndex_t index)

uint32_t cuvsIvfFlatIndexGetNLists(cuvsIvfFlatIndex_t index)

uint32_t cuvsIvfFlatIndexGetDim(cuvsIvfFlatIndex_t index)

cuvsError_t cuvsIvfFlatIndexGetCenters(cuvsResources_t res,
cuvsIvfFlatIndex_t index,
DLManagedTensor * centers)

cuvsError_t cuvsIvfFlatBuild(cuvsResources_t res,
cuvsIvfFlatIndexParams* params,
DLManagedTensor* dataset,
Expand Down
29 changes: 29 additions & 0 deletions python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,35 @@ cdef class Index:
def __repr__(self):
return "Index(type=IvfFlat)"

@property
def n_lists(self):
""" The number of inverted lists (clusters) """
return cuvsIvfFlatIndexGetNLists(self.index)

@property
def dim(self):
""" dimensionality of the cluster centers """
return cuvsIvfFlatIndexGetDim(self.index)

@property
def centers(self):
""" Get the cluster centers corresponding to the lists in the
original space """
return self._get_centers()

@auto_sync_resources
def _get_centers(self, resources=None):
if not self.trained:
raise ValueError("Index needs to be built before getting centers")

cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

output = np.empty((self.n_lists, self.dim), dtype='float32')
ai = wrap_array(output)
cdef cydlpack.DLManagedTensor* output_dlpack = cydlpack.dlpack_c(ai)
check_cuvs(cuvsIvfFlatIndexGetCenters(res, self.index, output_dlpack))
return output


@auto_sync_resources
def build(IndexParams index_params, dataset, resources=None):
Expand Down
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ cdef extern from "cuvs/neighbors/ivf_pq.h" nogil:

uint32_t cuvsIvfPqIndexGetNLists(cuvsIvfPqIndex_t index)

uint32_t cuvsIvfPqIndexGetDimExt(cuvsIvfPqIndex_t index)
uint32_t cuvsIvfPqIndexGetDim(cuvsIvfPqIndex_t index)

cuvsError_t cuvsIvfPqIndexGetCenters(cuvsResources_t res,
cuvsIvfPqIndex_t index,
Expand Down
6 changes: 3 additions & 3 deletions python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ cdef class Index:
return cuvsIvfPqIndexGetNLists(self.index)

@property
def dim_ext(self):
def dim(self):
""" dimensionality of the cluster centers """
return cuvsIvfPqIndexGetDimExt(self.index)
return cuvsIvfPqIndexGetDim(self.index)

@property
def centers(self):
Expand All @@ -261,7 +261,7 @@ cdef class Index:

cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

output = np.empty((self.n_lists, self.dim_ext), dtype='float32')
output = np.empty((self.n_lists, self.dim), dtype='float32')
ai = wrap_array(output)
cdef cydlpack.DLManagedTensor* output_dlpack = cydlpack.dlpack_c(ai)
check_cuvs(cuvsIvfPqIndexGetCenters(res, self.index, output_dlpack))
Expand Down
4 changes: 4 additions & 0 deletions python/cuvs/cuvs/tests/test_ivf_flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def run_ivf_flat_build_search_test(
recall = calc_recall(out_idx, skl_idx)
assert recall > 0.7

centers = index.centers
assert centers.shape[0] == build_params.n_lists
assert centers.shape[1] == n_cols


@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.parametrize("dtype", [np.float32])
Expand Down