Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import java.util.Arrays;
import java.util.BitSet;
import java.util.List;
import java.util.function.LongToIntFunction;

/**
* BruteForceQuery holds the query vectors to be used while invoking search.
Expand All @@ -27,25 +27,25 @@
*/
public class BruteForceQuery {

private List<Integer> mapping;
private float[][] queryVectors;
private BitSet[] prefilters;
private final LongToIntFunction mapping;
private final float[][] queryVectors;
private final BitSet[] prefilters;
private int numDocs = -1;
private int topK;
private final int topK;

/**
* Constructs an instance of {@link BruteForceQuery} using queryVectors,
* mapping, and topK.
*
* @param queryVectors 2D float query vector array
* @param mapping an instance of ID mapping
* @param mapping a function mapping ordinals (neighbor IDs) to custom user IDs
* @param topK the top k results to return
* @param prefilters the prefilters data to use while searching the BRUTEFORCE
* index
* @param numDocs Maximum number of bits in each prefilter, representing the number of documents in this index.
* Used only when prefilter(s) is/are passed.
*/
public BruteForceQuery(float[][] queryVectors, List<Integer> mapping, int topK, BitSet[] prefilters, int numDocs) {
public BruteForceQuery(float[][] queryVectors, LongToIntFunction mapping, int topK, BitSet[] prefilters, int numDocs) {
this.queryVectors = queryVectors;
this.mapping = mapping;
this.topK = topK;
Expand All @@ -63,11 +63,9 @@ public float[][] getQueryVectors() {
}

/**
* Gets the passed map instance.
*
* @return a map of ID mappings
* Gets the function mapping ordinals (neighbor IDs) to custom user IDs
*/
public List<Integer> getMapping() {
public LongToIntFunction getMapping() {
return mapping;
}

Expand Down Expand Up @@ -112,7 +110,7 @@ public static class Builder {
private float[][] queryVectors;
private BitSet[] prefilters;
private int numDocs;
private List<Integer> mapping;
private LongToIntFunction mapping = SearchResults.IDENTITY_MAPPING;
private int topK = 2;

/**
Expand All @@ -127,12 +125,12 @@ public Builder withQueryVectors(float[][] queryVectors) {
}

/**
* Sets the instance of mapping to be used for ID mapping.
* Sets the function used to map ordinals (neighbor IDs) to custom user IDs
*
* @param mapping the ID mapping instance
* @param mapping a function mapping ordinals (neighbor IDs) to custom user IDs
* @return an instance of this Builder
*/
public Builder withMapping(List<Integer> mapping) {
public Builder withMapping(LongToIntFunction mapping) {
this.mapping = mapping;
return this;
}
Expand Down
31 changes: 15 additions & 16 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.BitSet;
import java.util.function.LongToIntFunction;

/**
* CagraQuery holds the CagraSearchParams and the query vectors to be used while
Expand All @@ -28,12 +29,12 @@
*/
public class CagraQuery {

private CagraSearchParams cagraSearchParameters;
private List<Integer> mapping;
private float[][] queryVectors;
private int topK;
private BitSet prefilter;
private int numDocs;
private final CagraSearchParams cagraSearchParameters;
private final LongToIntFunction mapping;
private final float[][] queryVectors;
private final int topK;
private final BitSet prefilter;
private final int numDocs;

/**
* Constructs an instance of {@link CagraQuery} using cagraSearchParameters,
Expand All @@ -42,12 +43,12 @@ public class CagraQuery {
* @param cagraSearchParameters an instance of {@link CagraSearchParams} holding
* the search parameters
* @param queryVectors 2D float query vector array
* @param mapping an instance of ID mapping
* @param mapping a function mapping ordinals (neighbor IDs) to custom user IDs
* @param topK the top k results to return
* @param prefilter A single BitSet to use as filter while searching the CAGRA index
* @param numDocs Total number of dataset vectors; used to align the prefilter correctly
*/
public CagraQuery(CagraSearchParams cagraSearchParameters, float[][] queryVectors, List<Integer> mapping, int topK, BitSet prefilter, int numDocs) {
public CagraQuery(CagraSearchParams cagraSearchParameters, float[][] queryVectors, LongToIntFunction mapping, int topK, BitSet prefilter, int numDocs) {
super();
this.cagraSearchParameters = cagraSearchParameters;
this.queryVectors = queryVectors;
Expand Down Expand Up @@ -76,11 +77,9 @@ public float[][] getQueryVectors() {
}

/**
* Gets the passed map instance.
*
* @return a map of ID mappings
* Gets the function mapping ordinals (neighbor IDs) to custom user IDs
*/
public List<Integer> getMapping() {
public LongToIntFunction getMapping() {
return mapping;
}

Expand Down Expand Up @@ -124,7 +123,7 @@ public static class Builder {

private CagraSearchParams cagraSearchParams;
private float[][] queryVectors;
private List<Integer> mapping;
private LongToIntFunction mapping = SearchResults.IDENTITY_MAPPING;
private int topK = 2;
private BitSet prefilter;
private int numDocs;
Expand Down Expand Up @@ -159,12 +158,12 @@ public Builder withQueryVectors(float[][] queryVectors) {
}

/**
* Sets the instance of mapping to be used for ID mapping.
* Sets the function used to map ordinals (neighbor IDs) to custom user IDs
*
* @param mapping the ID mapping instance
* @param mapping a function mapping ordinals (neighbor IDs) to custom user IDs
* @return an instance of this Builder
*/
public Builder withMapping(List<Integer> mapping) {
public Builder withMapping(LongToIntFunction mapping) {
this.mapping = mapping;
return this;
}
Expand Down
28 changes: 13 additions & 15 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package com.nvidia.cuvs;

import java.util.Arrays;
import java.util.List;
import java.util.function.LongToIntFunction;

/**
* HnswQuery holds the query vectors to be used while invoking search on the
Expand All @@ -27,21 +27,21 @@
*/
public class HnswQuery {

private HnswSearchParams hnswSearchParams;
private List<Integer> mapping;
private float[][] queryVectors;
private int topK;
private final HnswSearchParams hnswSearchParams;
private final LongToIntFunction mapping;
private final float[][] queryVectors;
private final int topK;

/**
* Constructs an instance of {@link HnswQuery} using queryVectors, mapping, and
* topK.
*
* @param hnswSearchParams the search parameters to use
* @param queryVectors 2D float query vector array
* @param mapping an instance of ID mapping
* @param mapping a function mapping ordinals (neighbor IDs) to custom user IDs
* @param topK the top k results to return
*/
private HnswQuery(HnswSearchParams hnswSearchParams, float[][] queryVectors, List<Integer> mapping, int topK) {
private HnswQuery(HnswSearchParams hnswSearchParams, float[][] queryVectors, LongToIntFunction mapping, int topK) {
this.hnswSearchParams = hnswSearchParams;
this.queryVectors = queryVectors;
this.mapping = mapping;
Expand All @@ -67,11 +67,9 @@ public float[][] getQueryVectors() {
}

/**
* Gets the passed map instance.
*
* @return a map of ID mappings
* Gets the function mapping ordinals (neighbor IDs) to custom user IDs
*/
public List<Integer> getMapping() {
public LongToIntFunction getMapping() {
return mapping;
}

Expand All @@ -96,7 +94,7 @@ public static class Builder {

private HnswSearchParams hnswSearchParams;
private float[][] queryVectors;
private List<Integer> mapping;
private LongToIntFunction mapping = SearchResults.IDENTITY_MAPPING;
private int topK = 2;

/**
Expand All @@ -123,12 +121,12 @@ public Builder withQueryVectors(float[][] queryVectors) {
}

/**
* Sets the instance of mapping to be used for ID mapping.
* Sets the function used to map ordinals (neighbor IDs) to custom user IDs
*
* @param mapping the ID mapping instance
* @param mapping a function mapping ordinals (neighbor IDs) to custom user IDs
* @return an instance of this Builder
*/
public Builder withMapping(List<Integer> mapping) {
public Builder withMapping(LongToIntFunction mapping) {
this.mapping = mapping;
return this;
}
Expand Down
15 changes: 15 additions & 0 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/SearchResults.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,24 @@

import java.util.List;
import java.util.Map;
import java.util.function.LongToIntFunction;

public interface SearchResults {

/**
* The default identity mapping from neighbour IDs to user-defined IDs: the neighbour ID
* is returned unchanged, narrowed from {@code long} to {@code int} with a plain cast.
*/
LongToIntFunction IDENTITY_MAPPING = l -> (int) l;

/**
 * Creates a mapping function backed by a positional list of custom user IDs.
 *
 * <p>The returned function treats its input ordinal as an index into {@code mappingAsList}
 * and returns the element stored at that position. The list is not copied; every lookup
 * reads the list at call time. An ordinal that is negative or not smaller than the list
 * size makes the returned function throw {@link IndexOutOfBoundsException}.
 *
 * @param mappingAsList a positional list of custom user IDs; must not be {@code null}
 * @return a function that maps the input ordinal to a custom user ID, using the input as an
 *         index in the list
 * @throws NullPointerException if {@code mappingAsList} is {@code null}
 */
static LongToIntFunction mappingsFromList(List<Integer> mappingAsList) {
  // Fail at creation time instead of on the first lookup, far away from the faulty call site.
  java.util.Objects.requireNonNull(mappingAsList, "mappingAsList");
  return l -> {
    // Range-check the long value BEFORE narrowing: a bare (int) cast silently wraps ordinals
    // outside the int range onto an unrelated (and possibly valid) list index.
    if (l < 0 || l >= mappingAsList.size()) {
      throw new IndexOutOfBoundsException(
          "Ordinal " + l + " out of bounds for mapping list of size " + mappingAsList.size());
    }
    return mappingAsList.get((int) l);
  };
}

/**
* Gets a list of results as a map of neighbor IDs to distances.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.LongToIntFunction;

/**
* SearchResult encapsulates the logic for reading and holding search results.
Expand All @@ -32,7 +32,7 @@
public class BruteForceSearchResults extends SearchResultsImpl {

protected BruteForceSearchResults(SequenceLayout neighboursSequenceLayout, SequenceLayout distancesSequenceLayout,
MemorySegment neighboursMemorySegment, MemorySegment distancesMemorySegment, int topK, List<Integer> mapping,
MemorySegment neighboursMemorySegment, MemorySegment distancesMemorySegment, int topK, LongToIntFunction mapping,
long numberOfQueries) {
super(neighboursSequenceLayout, distancesSequenceLayout, neighboursMemorySegment, distancesMemorySegment, topK,
mapping, numberOfQueries);
Expand All @@ -49,7 +49,7 @@ protected void readResultMemorySegments() {
for (long i = 0; i < topK * numberOfQueries; i++) {
long id = (long) neighboursVarHandle.get(neighboursMemorySegment, 0L, i);
float dst = (float) distancesVarHandle.get(distancesMemorySegment, 0L, i);
intermediateResultMap.put(mapping != null ? mapping.get((int) id) : (int) id, dst);
intermediateResultMap.put(mapping.applyAsInt(id), dst);
count += 1;
if (count == topK) {
results.add(intermediateResultMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ private IndexReference build() throws Throwable {
public SearchResults search(CagraQuery query) throws Throwable {
try (var localArena = Arena.ofConfined()) {
checkNotDestroyed();
int topK = query.getMapping() != null ? Math.min(query.getMapping().size(), query.getTopK()) : query.getTopK();
int topK = query.getTopK();
long numQueries = query.getQueryVectors().length;
long numBlocks = topK * numQueries;
int vectorDimension = numQueries > 0 ? query.getQueryVectors()[0].length : 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.LongToIntFunction;

import com.nvidia.cuvs.internal.common.SearchResultsImpl;

Expand All @@ -32,7 +32,7 @@
public class CagraSearchResults extends SearchResultsImpl {

protected CagraSearchResults(SequenceLayout neighboursSequenceLayout, SequenceLayout distancesSequenceLayout,
MemorySegment neighboursMemorySegment, MemorySegment distancesMemorySegment, int topK, List<Integer> mapping,
MemorySegment neighboursMemorySegment, MemorySegment distancesMemorySegment, int topK, LongToIntFunction mapping,
long numberOfQueries) {
super(neighboursSequenceLayout, distancesSequenceLayout, neighboursMemorySegment, distancesMemorySegment, topK,
mapping, numberOfQueries);
Expand All @@ -47,10 +47,10 @@ protected void readResultMemorySegments() {
Map<Integer, Float> intermediateResultMap = new LinkedHashMap<Integer, Float>();
int count = 0;
for (long i = 0; i < topK * numberOfQueries; i++) {
int id = (int) neighboursVarHandle.get(neighboursMemorySegment, 0L, i);
long id = (long) neighboursVarHandle.get(neighboursMemorySegment, 0L, i);
float dst = (float) distancesVarHandle.get(distancesMemorySegment, 0L, i);
if (id != Integer.MAX_VALUE) {
intermediateResultMap.put(mapping != null ? mapping.get(id) : id, dst);
intermediateResultMap.put(mapping.applyAsInt(id), dst);
}
count += 1;
if (count == topK) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.LongToIntFunction;

/**
* SearchResult encapsulates the logic for reading and holding search results.
Expand All @@ -32,7 +32,7 @@
public class HnswSearchResults extends SearchResultsImpl {

protected HnswSearchResults(SequenceLayout neighboursSequenceLayout, SequenceLayout distancesSequenceLayout,
MemorySegment neighboursMemorySegment, MemorySegment distancesMemorySegment, int topK, List<Integer> mapping,
MemorySegment neighboursMemorySegment, MemorySegment distancesMemorySegment, int topK, LongToIntFunction mapping,
long numberOfQueries) {
super(neighboursSequenceLayout, distancesSequenceLayout, neighboursMemorySegment, distancesMemorySegment, topK,
mapping, numberOfQueries);
Expand All @@ -49,7 +49,7 @@ protected void readResultMemorySegments() {
for (long i = 0; i < topK * numberOfQueries; i++) {
long id = (long) neighboursVarHandle.get(neighboursMemorySegment, 0L, i);
float dst = (float) distancesVarHandle.get(distancesMemorySegment, 0L, i);
intermediateResultMap.put(mapping != null ? mapping.get((int) id) : (int) id, dst);
intermediateResultMap.put(mapping.applyAsInt(id), dst);
count += 1;
if (count == topK) {
results.add(intermediateResultMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.LongToIntFunction;

public abstract class SearchResultsImpl implements SearchResults {

protected final List<Map<Integer, Float>> results;
protected final List<Integer> mapping; // TODO: Is this performant in a user application?
protected final LongToIntFunction mapping;

protected final MemorySegment neighboursMemorySegment;
protected final MemorySegment distancesMemorySegment;
Expand All @@ -23,7 +24,7 @@ public abstract class SearchResultsImpl implements SearchResults {
protected final VarHandle distancesVarHandle;

protected SearchResultsImpl(SequenceLayout neighboursSequenceLayout, SequenceLayout distancesSequenceLayout,
MemorySegment neighboursMemorySegment, MemorySegment distancesMemorySegment, int topK, List<Integer> mapping,
MemorySegment neighboursMemorySegment, MemorySegment distancesMemorySegment, int topK, LongToIntFunction mapping,
long numberOfQueries) {
this.topK = topK;
this.numberOfQueries = numberOfQueries;
Expand Down
Loading