-
Notifications
You must be signed in to change notification settings - Fork 160
[WIP][Java] Exposing CAGRA graph #1102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,10 @@ | |
| import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraDeserialize; | ||
| import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraIndexCreate; | ||
| import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraIndexDestroy; | ||
| import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraIndexFromGraph; | ||
| import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraIndexGetGraph; | ||
| import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraIndexGetGraphDegree; | ||
| import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraIndexGetSize; | ||
| import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraIndex_t; | ||
| import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraMerge; | ||
| import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraSearch; | ||
|
|
@@ -46,6 +50,7 @@ | |
| import com.nvidia.cuvs.CagraIndex; | ||
| import com.nvidia.cuvs.CagraIndexParams; | ||
| import com.nvidia.cuvs.CagraIndexParams.CagraGraphBuildAlgo; | ||
| import com.nvidia.cuvs.CagraIndexParams.CuvsDistanceType; | ||
| import com.nvidia.cuvs.CagraMergeParams; | ||
| import com.nvidia.cuvs.CagraQuery; | ||
| import com.nvidia.cuvs.CagraSearchParams; | ||
|
|
@@ -133,6 +138,24 @@ private CagraIndexImpl(IndexReference indexReference, CuVSResourcesImpl resource | |
| this.destroyed = false; | ||
| } | ||
|
|
||
| /** | ||
| * Creates an index from a pre-built kNN graph and the dataset it refers to. | ||
| * | ||
| * <p>{@code graph}, {@code distanceType} and {@code dataset} are rejected with a | ||
| * {@link NullPointerException} when null; {@code resources} is stored without a null check. | ||
| * The index reference is produced immediately by delegating to {@code buildFromGraph}. | ||
| * | ||
| * @param graph the kNN graph as a 2D int array | ||
| * @param distanceType the distance metric to use | ||
| * @param dataset the dataset corresponding to the graph; it is cast to {@code DatasetImpl} | ||
| * @param resources an instance of {@link CuVSResources} | ||
| */ | ||
| private CagraIndexImpl( | ||
| int[][] graph, CuvsDistanceType distanceType, Dataset dataset, CuVSResourcesImpl resources) { | ||
| Objects.requireNonNull(graph); | ||
| Objects.requireNonNull(distanceType); | ||
| Objects.requireNonNull(dataset); | ||
| this.resources = resources; | ||
| // Only evaluated when assertions are enabled (-ea); otherwise the cast below is the only check. | ||
| assert dataset instanceof DatasetImpl; | ||
| this.cagraIndexReference = buildFromGraph(graph, distanceType, (DatasetImpl) dataset); | ||
| } | ||
|
|
||
| private void checkNotDestroyed() { | ||
| if (destroyed) { | ||
| throw new IllegalStateException("destroyed"); | ||
|
|
@@ -151,6 +174,8 @@ public void destroyIndex() throws Throwable { | |
| if (cagraIndexReference.dataset != null) { | ||
| cagraIndexReference.dataset.close(); | ||
| } | ||
| // Free graph memory if needed | ||
| cagraIndexReference.freeGraphMemory(); | ||
| } finally { | ||
| destroyed = true; | ||
| } | ||
|
|
@@ -212,6 +237,79 @@ private IndexReference build(CagraIndexParams indexParameters, DatasetImpl datas | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Invokes the native cuvsCagraIndexFromGraph function via the Panama API to build the | ||
| * {@link CagraIndex} from a graph and dataset. | ||
| * | ||
| * <p>The graph is flattened row by row into a native int buffer, copied into a buffer obtained | ||
| * from cuvsRMMAlloc, wrapped in a tensor and passed to the native call together with a tensor | ||
| * over the dataset's memory segment. | ||
| * | ||
| * @param graph the kNN graph; {@code graph.length} must equal {@code dataset.size()} | ||
| * @param distanceType the distance metric; its {@code value} field is passed to the native call | ||
| * @param dataset the dataset the graph refers to | ||
| * @return an instance of {@link IndexReference} that holds the pointer to the | ||
| * index, the dataset, and the graph buffer pointer with its size in bytes | ||
| * @throws IllegalArgumentException if the number of graph rows differs from the dataset size | ||
| */ | ||
| private IndexReference buildFromGraph( | ||
| int[][] graph, CuvsDistanceType distanceType, DatasetImpl dataset) { | ||
| // NOTE(review): localArena is opened here but every allocation below uses resources.getArena(). | ||
| try (var localArena = Arena.ofConfined()) { | ||
| long rows = dataset.size(); | ||
| long cols = dataset.dimensions(); | ||
| // The degree is read from the first row only; the lengths of the other rows are not checked. | ||
| long graphDegree = graph.length > 0 ? graph[0].length : 0; | ||
|
|
||
| if (graph.length != rows) { | ||
| throw new IllegalArgumentException("Graph rows must match dataset size"); | ||
| } | ||
|
|
||
| MemorySegment dataSeg = dataset.asMemorySegment(); | ||
| long cuvsRes = resources.getHandle(); | ||
|
|
||
| // Prepare dataset tensor | ||
| // NOTE(review): the numeric arguments are defined by prepareTensor, which is not visible here. | ||
| long[] datasetShape = {rows, cols}; | ||
| MemorySegment datasetTensor = | ||
| prepareTensor(resources.getArena(), dataSeg, datasetShape, 2, 32, 2, 2, 1); | ||
|
|
||
| // Arena used for the staging buffer, the pointer holder and the graph tensor | ||
| Arena arena = resources.getArena(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same, the resources arena will go away |
||
| long graphElements = rows * graphDegree; | ||
| SequenceLayout graphSequenceLayout = MemoryLayout.sequenceLayout(graphElements, C_INT); | ||
| MemorySegment graphMemorySegment = arena.allocate(graphSequenceLayout); | ||
|
|
||
| // Copy graph data to memory segment, row-major: element (i, j) goes to index i * graphDegree + j | ||
| for (int i = 0; i < rows; i++) { | ||
| for (int j = 0; j < graphDegree; j++) { | ||
| graphMemorySegment.setAtIndex(C_INT, (long) i * graphDegree + j, graph[i][j]); | ||
| } | ||
| } | ||
|
|
||
| // Allocate device memory for the graph; graphD is the slot the allocated pointer is read from | ||
| MemorySegment graphD = arena.allocate(C_POINTER); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's interesting you are copying things to device memory here, but I don't think this is needed: cuvsCagraIndexFromGraph will work with host memory and do the copy itself (and maybe a tiny bit more efficiently) |
||
| long graphBytes = C_INT_BYTE_SIZE * graphElements; | ||
|
|
||
| int returnValue = cuvsRMMAlloc(cuvsRes, graphD, graphBytes); | ||
| checkCuVSError(returnValue, "cuvsRMMAlloc"); | ||
|
|
||
| // NOTE(review): nothing in this method frees graphDP if one of the later checks throws. | ||
| MemorySegment graphDP = graphD.get(C_POINTER, 0); | ||
|
|
||
| // Copy graph from host to device | ||
| cudaMemcpy(graphDP, graphMemorySegment, graphBytes, HOST_TO_DEVICE); | ||
|
|
||
| // Prepare graph tensor | ||
| long[] graphShape = {rows, graphDegree}; | ||
| MemorySegment graphTensor = prepareTensor(arena, graphDP, graphShape, 1, 32, 2, 2, 1); | ||
|
|
||
| var index = createCagraIndex(); | ||
|
|
||
| returnValue = cuvsStreamSync(cuvsRes); | ||
| checkCuVSError(returnValue, "cuvsStreamSync"); | ||
|
|
||
| returnValue = | ||
| cuvsCagraIndexFromGraph(cuvsRes, distanceType.value, graphTensor, datasetTensor, index); | ||
| checkCuVSError(returnValue, "cuvsCagraIndexFromGraph"); | ||
|
|
||
| returnValue = cuvsStreamSync(cuvsRes); | ||
| checkCuVSError(returnValue, "cuvsStreamSync"); | ||
|
|
||
| // Keep graph memory alive - freed when index is destroyed (pointer and size go to IndexReference) | ||
|
|
||
| return new IndexReference(index, dataset, graphDP, graphBytes, resources); | ||
| } | ||
| } | ||
|
|
||
| private static MemorySegment createCagraIndex() { | ||
| try (var localArena = Arena.ofConfined()) { | ||
| MemorySegment indexPtrPtr = localArena.allocate(cuvsCagraIndex_t); | ||
|
|
@@ -489,6 +587,94 @@ public CuVSResourcesImpl getCuVSResources() { | |
| return resources; | ||
| } | ||
|
|
||
| /** | ||
| * Gets the kNN graph from the index. | ||
| * | ||
| * <p>Queries the index size and graph degree, has the native side write the graph into a | ||
| * buffer obtained from cuvsRMMAlloc, copies that buffer into a native host segment and converts | ||
| * it into a Java array. | ||
| * | ||
| * @return the graph as an {@code int[size][graphDegree]} array | ||
| * @throws IllegalStateException if the index has already been destroyed | ||
| */ | ||
| @Override | ||
| public int[][] getGraph() throws Throwable { | ||
| checkNotDestroyed(); | ||
|
|
||
| try (var localArena = Arena.ofConfined()) { | ||
| // Get index size and graph degree; both out-parameters are read back as C_INT | ||
| MemorySegment sizePtr = localArena.allocate(C_INT); | ||
| MemorySegment graphDegreePtr = localArena.allocate(C_INT); | ||
|
|
||
| int returnValue = cuvsCagraIndexGetSize(cagraIndexReference.getMemorySegment(), sizePtr); | ||
| checkCuVSError(returnValue, "cuvsCagraIndexGetSize"); | ||
|
|
||
| returnValue = | ||
| cuvsCagraIndexGetGraphDegree(cagraIndexReference.getMemorySegment(), graphDegreePtr); | ||
| checkCuVSError(returnValue, "cuvsCagraIndexGetGraphDegree"); | ||
|
|
||
| int size = sizePtr.get(C_INT, 0); | ||
| int graphDegree = graphDegreePtr.get(C_INT, 0); | ||
|
|
||
| // Allocate memory for the graph (widened to long before multiplying) | ||
| // NOTE(review): the buffers below come from the resources arena, not from localArena. | ||
| long graphElements = (long) size * graphDegree; | ||
| Arena arena = resources.getArena(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, this can be local |
||
| SequenceLayout graphSequenceLayout = MemoryLayout.sequenceLayout(graphElements, C_INT); | ||
| MemorySegment graphMemorySegment = arena.allocate(graphSequenceLayout); | ||
|
|
||
| // Allocate device memory for the graph; graphD is the slot the allocated pointer is read from | ||
| MemorySegment graphD = arena.allocate(C_POINTER); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, I think |
||
| long graphBytes = C_INT_BYTE_SIZE * graphElements; | ||
| long cuvsRes = resources.getHandle(); | ||
|
|
||
| returnValue = cuvsRMMAlloc(cuvsRes, graphD, graphBytes); | ||
| checkCuVSError(returnValue, "cuvsRMMAlloc"); | ||
|
|
||
| // NOTE(review): graphDP is only freed on the success path below, not if a later check throws. | ||
| MemorySegment graphDP = graphD.get(C_POINTER, 0); | ||
|
|
||
| // Prepare the tensor for the graph | ||
| long[] graphShape = {size, graphDegree}; | ||
| MemorySegment graphTensor = prepareTensor(arena, graphDP, graphShape, 1, 32, 2, 2, 1); | ||
|
|
||
| // Get the graph from the index | ||
| returnValue = | ||
| cuvsCagraIndexGetGraph(cuvsRes, cagraIndexReference.getMemorySegment(), graphTensor); | ||
| checkCuVSError(returnValue, "cuvsCagraIndexGetGraph"); | ||
|
|
||
| returnValue = cuvsStreamSync(cuvsRes); | ||
| checkCuVSError(returnValue, "cuvsStreamSync"); | ||
|
|
||
| // Copy the graph from device to host | ||
| cudaMemcpy(graphMemorySegment, graphDP, graphBytes, INFER_DIRECTION); | ||
|
|
||
| // Free device memory | ||
| returnValue = cuvsRMMFree(cuvsRes, graphDP, graphBytes); | ||
| checkCuVSError(returnValue, "cuvsRMMFree"); | ||
|
|
||
| // Convert to 2D int array; element (i, j) is read from flat index i * graphDegree + j | ||
| int[][] graph = new int[size][graphDegree]; | ||
| for (int i = 0; i < size; i++) { | ||
| for (int j = 0; j < graphDegree; j++) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: this can be done "by-row" more efficiently with |
||
| graph[i][j] = graphMemorySegment.getAtIndex(C_INT, (long) i * graphDegree + j); | ||
| } | ||
| } | ||
|
|
||
| return graph; | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Gets the dataset used to build the index. | ||
| * | ||
| * @return the dataset held by the index reference | ||
| * @throws IllegalStateException if the index has already been destroyed | ||
| * @throws UnsupportedOperationException if the index reference holds no dataset | ||
| */ | ||
| @Override | ||
| public Dataset getDataset() throws Throwable { | ||
| checkNotDestroyed(); | ||
|
|
||
| // If we have the dataset stored locally, return it (the same instance, not a copy) | ||
| if (cagraIndexReference.dataset != null) { | ||
| return cagraIndexReference.dataset; | ||
| } | ||
|
|
||
| // For indexes created from deserialization or merging, the dataset is not stored locally | ||
| throw new UnsupportedOperationException( | ||
| "Dataset is not available for this index. " | ||
| + "Dataset is only available for indexes created directly from a dataset, " | ||
| + "not for indexes loaded from serialization or created through merging."); | ||
| } | ||
|
|
||
| /** | ||
| * Allocates the configured index parameters in the MemorySegment. | ||
| */ | ||
|
|
@@ -687,6 +873,8 @@ public static class Builder implements CagraIndex.Builder { | |
| private CagraIndexParams cagraIndexParams; | ||
| private final CuVSResourcesImpl cuvsResources; | ||
| private InputStream inputStream; | ||
| private int[][] graph; | ||
| private CuvsDistanceType distanceType; | ||
|
|
||
| public Builder(CuVSResourcesImpl cuvsResources) { | ||
| this.cuvsResources = cuvsResources; | ||
|
|
@@ -716,10 +904,19 @@ public Builder withIndexParams(CagraIndexParams cagraIndexParameters) { | |
| return this; | ||
| } | ||
|
|
||
| /** | ||
| * Records a pre-built kNN graph and distance metric on this builder. | ||
| * | ||
| * <p>Neither argument is validated here; they are only stored for later use. | ||
| * | ||
| * @param graph the kNN graph as a 2D int array | ||
| * @param distanceType the distance metric to use | ||
| * @return this builder | ||
| */ | ||
| @Override | ||
| public Builder from(int[][] graph, CuvsDistanceType distanceType) { | ||
| this.graph = graph; | ||
| this.distanceType = distanceType; | ||
| return this; | ||
| } | ||
|
|
||
| @Override | ||
| public CagraIndexImpl build() throws Throwable { | ||
| if (inputStream != null) { | ||
| return new CagraIndexImpl(inputStream, cuvsResources); | ||
| } else if (graph != null && distanceType != null) { | ||
| return new CagraIndexImpl(graph, distanceType, dataset, cuvsResources); | ||
| } else { | ||
| return new CagraIndexImpl(cagraIndexParams, dataset, cuvsResources); | ||
| } | ||
|
|
@@ -733,6 +930,9 @@ public static class IndexReference { | |
|
|
||
| private final MemorySegment memorySegment; | ||
| private final Dataset dataset; | ||
| private final MemorySegment graphDevicePointer; | ||
| private final long graphBytes; | ||
| private final CuVSResourcesImpl resources; | ||
|
|
||
| /** | ||
| * Constructs CagraIndexReference with an instance of MemorySegment passed as a | ||
|
|
@@ -748,6 +948,25 @@ public static class IndexReference { | |
| private IndexReference(MemorySegment indexMemorySegment, Dataset dataset) { | ||
| this.memorySegment = indexMemorySegment; | ||
| this.dataset = dataset; | ||
| // No graph buffer is tracked by a reference created through this constructor. | ||
| this.graphDevicePointer = null; | ||
| this.graphBytes = 0; | ||
| this.resources = null; | ||
| } | ||
|
|
||
| /** | ||
| * Constructor for graph-based index, keeps graph memory alive. | ||
| * | ||
| * @param indexMemorySegment the segment holding the index pointer | ||
| * @param dataset the dataset associated with the index | ||
| * @param graphDevicePointer pointer to the graph buffer, later passed to cuvsRMMFree | ||
| * @param graphBytes size of the graph buffer in bytes | ||
| * @param resources the resources whose handle is used when freeing the graph buffer | ||
| */ | ||
| private IndexReference( | ||
| MemorySegment indexMemorySegment, | ||
| Dataset dataset, | ||
| MemorySegment graphDevicePointer, | ||
| long graphBytes, | ||
| CuVSResourcesImpl resources) { | ||
| this.memorySegment = indexMemorySegment; | ||
| this.dataset = dataset; | ||
| this.graphDevicePointer = graphDevicePointer; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this can be avoided if you don't allocate GPU graph memory yourself, keeping things easier/tidier (but better double check) |
||
| this.graphBytes = graphBytes; | ||
| this.resources = resources; | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -758,5 +977,25 @@ private IndexReference(MemorySegment indexMemorySegment, Dataset dataset) { | |
| protected MemorySegment getMemorySegment() { | ||
| // Returns the segment passed to the constructor as-is. | ||
| return memorySegment; | ||
| } | ||
|
|
||
| /** | ||
| * Gets the dataset associated with this index. | ||
| * | ||
| * <p>Returns the stored reference itself; no copy is made. | ||
| * | ||
| * @return the dataset, or null if not available | ||
| */ | ||
| protected Dataset getDataset() { | ||
| return dataset; | ||
| } | ||
|
|
||
| /** | ||
| * Frees graph memory if needed. | ||
| * | ||
| * <p>Calls cuvsRMMFree with the stored pointer and byte count, but only when both the graph | ||
| * pointer and the resources are non-null; otherwise this is a no-op. | ||
| * NOTE(review): nothing in this method records that the buffer was released, so whether a | ||
| * second call is safe depends on the caller - confirm. | ||
| */ | ||
| protected void freeGraphMemory() throws Throwable { | ||
| if (graphDevicePointer != null && resources != null) { | ||
| long cuvsRes = resources.getHandle(); | ||
| int returnValue = cuvsRMMFree(cuvsRes, graphDevicePointer, graphBytes); | ||
| checkCuVSError(returnValue, "cuvsRMMFree graph memory"); | ||
| } | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can/should use the localArena here