-
Notifications
You must be signed in to change notification settings - Fork 160
Expose ivf-flat centers to python/c #888
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
a8b06e3
6ad385a
6f72740
2cfbef8
1914251
6a864c1
7080e26
ed474d1
ed1670e
71355fe
936b45e
1e26626
cce8a75
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
| #include <raft/core/mdspan_types.hpp> | ||
| #include <raft/core/resources.hpp> | ||
| #include <raft/core/serialize.hpp> | ||
| #include <raft/util/cudart_utils.hpp> | ||
|
|
||
| #include <cuvs/core/c_api.h> | ||
| #include <cuvs/core/exceptions.hpp> | ||
|
|
@@ -139,6 +140,35 @@ void _extend(cuvsResources_t res, | |
|
|
||
| cuvs::neighbors::ivf_flat::extend(*res_ptr, vectors_mds, indices_mds, index_ptr); | ||
| } | ||
|
|
||
| template <typename output_mdspan_type, typename T, typename IdxT> | ||
| void _get_centers(cuvsResources_t res, cuvsIvfFlatIndex index, DLManagedTensor* centers) | ||
| { | ||
| auto res_ptr = reinterpret_cast<raft::resources*>(res); | ||
| auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_flat::index<T, IdxT>*>(index.addr); | ||
| auto dst = cuvs::core::from_dlpack<output_mdspan_type>(centers); | ||
| auto src = index_ptr->centers(); | ||
|
Comment on lines
+147
to
+150
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would there be value in making any of these I don't have familiarity with the code yet, so I haven't grokked the semantics of |
||
|
|
||
| RAFT_EXPECTS(src.extent(0) == dst.extent(0), "Output centers has incorrect number of rows"); | ||
| RAFT_EXPECTS(src.extent(1) == dst.extent(1), "Output centers has incorrect number of cols"); | ||
|
|
||
| raft::copy(dst.data_handle(), | ||
| src.data_handle(), | ||
| dst.extent(0) * dst.extent(1), | ||
| raft::resource::get_cuda_stream(*res_ptr)); | ||
| } | ||
|
|
||
| template <typename T, typename IdxT> | ||
| void get_centers(cuvsResources_t res, cuvsIvfFlatIndex index, DLManagedTensor* centers) | ||
| { | ||
| if (cuvs::core::is_dlpack_device_compatible(centers->dl_tensor)) { | ||
| using output_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>; | ||
| _get_centers<output_mdspan_type, T, IdxT>(res, index, centers); | ||
| } else { | ||
| using output_mdspan_type = raft::host_matrix_view<float, int64_t, raft::row_major>; | ||
| _get_centers<output_mdspan_type, T, IdxT>(res, index, centers); | ||
| } | ||
| } | ||
| } // namespace | ||
|
|
||
| extern "C" cuvsError_t cuvsIvfFlatIndexCreate(cuvsIvfFlatIndex_t* index) | ||
|
|
@@ -351,3 +381,68 @@ extern "C" cuvsError_t cuvsIvfFlatExtend(cuvsResources_t res, | |
| } | ||
| }); | ||
| } | ||
|
|
||
| extern "C" uint32_t cuvsIvfFlatIndexGetNLists(cuvsIvfFlatIndex_t index) | ||
| { | ||
| if (index->dtype.code == kDLFloat && index->dtype.bits == 32) { | ||
| auto index_ptr = | ||
| reinterpret_cast<cuvs::neighbors::ivf_flat::index<float, int64_t>*>(index->addr); | ||
| return index_ptr->n_lists(); | ||
| } else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) { | ||
| auto index_ptr = | ||
| reinterpret_cast<cuvs::neighbors::ivf_flat::index<half, int64_t>*>(index->addr); | ||
| return index_ptr->n_lists(); | ||
| } else if (index->dtype.code == kDLInt && index->dtype.bits == 8) { | ||
| auto index_ptr = | ||
| reinterpret_cast<cuvs::neighbors::ivf_flat::index<int8_t, int64_t>*>(index->addr); | ||
| return index_ptr->n_lists(); | ||
| } else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) { | ||
| auto index_ptr = | ||
| reinterpret_cast<cuvs::neighbors::ivf_flat::index<uint8_t, int64_t>*>(index->addr); | ||
| return index_ptr->n_lists(); | ||
|
Comment on lines
+387
to
+402
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a recurring pattern in the code. One wonders if We might consider exploring at a later date. |
||
| } else { | ||
| return 0; | ||
| } | ||
| } | ||
|
|
||
| extern "C" uint32_t cuvsIvfFlatIndexGetDim(cuvsIvfFlatIndex_t index) | ||
| { | ||
| if (index->dtype.code == kDLFloat && index->dtype.bits == 32) { | ||
| auto index_ptr = | ||
| reinterpret_cast<cuvs::neighbors::ivf_flat::index<float, int64_t>*>(index->addr); | ||
| return index_ptr->dim(); | ||
| } else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) { | ||
| auto index_ptr = | ||
| reinterpret_cast<cuvs::neighbors::ivf_flat::index<half, int64_t>*>(index->addr); | ||
| return index_ptr->dim(); | ||
| } else if (index->dtype.code == kDLInt && index->dtype.bits == 8) { | ||
| auto index_ptr = | ||
| reinterpret_cast<cuvs::neighbors::ivf_flat::index<int8_t, int64_t>*>(index->addr); | ||
| return index_ptr->dim(); | ||
| } else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) { | ||
| auto index_ptr = | ||
| reinterpret_cast<cuvs::neighbors::ivf_flat::index<uint8_t, int64_t>*>(index->addr); | ||
| return index_ptr->dim(); | ||
| } else { | ||
| return 0; | ||
| } | ||
| } | ||
|
|
||
| extern "C" cuvsError_t cuvsIvfFlatIndexGetCenters(cuvsResources_t res, | ||
| cuvsIvfFlatIndex_t index, | ||
| DLManagedTensor* centers) | ||
| { | ||
| return cuvs::core::translate_exceptions([=] { | ||
| if (index->dtype.code == kDLFloat && index->dtype.bits == 32) { | ||
| get_centers<float, int64_t>(res, *index, centers); | ||
| } else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) { | ||
| get_centers<half, int64_t>(res, *index, centers); | ||
| } else if (index->dtype.code == kDLInt && index->dtype.bits == 8) { | ||
| get_centers<int8_t, int64_t>(res, *index, centers); | ||
| } else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) { | ||
| get_centers<uint8_t, int64_t>(res, *index, centers); | ||
| } else { | ||
| RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits); | ||
| } | ||
| }); | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.