Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 24 additions & 33 deletions python/fast_plaid/search/fast_plaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def compute_kmeans( # noqa: PLR0913

def search_on_device( # noqa: PLR0913
device: str,
queries_embeddings: torch.Tensor,
queries_embeddings: list[torch.Tensor],
batch_size: int,
n_full_scores: int,
top_k: int,
Expand Down Expand Up @@ -154,6 +154,17 @@ def search_on_device( # noqa: PLR0913
]


def cleanup_embeddings(embeddings: list[torch.Tensor] | torch.Tensor) -> list[torch.Tensor]:
if isinstance(embeddings, torch.Tensor):
embeddings = [
embeddings[i] for i in range(embeddings.shape[0])
]
return [
embedding.squeeze(0) if embedding.dim() == 3 else embedding
for embedding in embeddings
]


class FastPlaid:
"""A class for creating and searching a FastPlaid index.

Expand Down Expand Up @@ -288,15 +299,7 @@ def create( # noqa: PLR0913
Optional list of dictionaries containing metadata for each document.

"""
if isinstance(documents_embeddings, torch.Tensor):
documents_embeddings = [
documents_embeddings[i] for i in range(documents_embeddings.shape[0])
]

documents_embeddings = [
embedding.squeeze(0) if embedding.dim() == 3 else embedding
for embedding in documents_embeddings
]
documents_embeddings = cleanup_embeddings(documents_embeddings)
num_docs = len(documents_embeddings)

self._prepare_index_directory(index_path=self.index)
Expand Down Expand Up @@ -473,17 +476,8 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915
corresponding inner list.

"""
if isinstance(queries_embeddings, list):
queries_embeddings = torch.nn.utils.rnn.pad_sequence(
sequences=[
embedding[0] if embedding.dim() == 3 else embedding
for embedding in queries_embeddings
],
batch_first=True,
padding_value=0.0,
)

num_queries = queries_embeddings.shape[0]
queries_embeddings = cleanup_embeddings(queries_embeddings)
num_queries = len(queries_embeddings)

if subset is not None:
if isinstance(subset, int):
Expand Down Expand Up @@ -529,16 +523,14 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915
num_cpus = len(self.devices)

# Use torch.chunk to split the tensor into num_cpus
queries_embeddings_splits = torch.chunk(
input=queries_embeddings,
chunks=num_cpus,
dim=0,
)
queries_embeddings_splits = [
queries_embeddings[i:i + num_cpus] for i in range(0, num_queries, num_cpus)
Copy link

Copilot AI Oct 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The list comprehension creates chunks of size num_cpus starting at increments of num_cpus, which is incorrect. The step should match the chunk size to avoid overlap. Use: queries_embeddings[i*num_cpus:(i+1)*num_cpus] for i in range((num_queries + num_cpus - 1) // num_cpus)] or similar logic to properly partition the list.

Suggested change
queries_embeddings[i:i + num_cpus] for i in range(0, num_queries, num_cpus)
queries_embeddings[i*num_cpus:(i+1)*num_cpus] for i in range((num_queries + num_cpus - 1) // num_cpus)

Copilot uses AI. Check for mistakes.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The list comprehension seems correct to me. Don't know why Copilot thinks it's wrong

]

# Filter out empty chunks that torch.chunk might create
# if num_queries < num_cpus
non_empty_splits = [
split for split in queries_embeddings_splits if split.shape[0] > 0
split for split in queries_embeddings_splits if len(split) > 0
]
num_splits = len(non_empty_splits)

Expand All @@ -548,7 +540,7 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915
if subset is not None:
current_idx = 0
for split in non_empty_splits:
size = split.shape[0]
size = len(split)
subset_splits.append(subset[current_idx : current_idx + size]) # type: ignore
current_idx += size

Expand Down Expand Up @@ -600,16 +592,15 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915
subset=subset, # type: ignore
)

queries_embeddings_splits = torch.split(
tensor=queries_embeddings,
split_size_or_sections=len(self.devices),
)
queries_embeddings_splits = [
queries_embeddings[i:i + len(self.devices)] for i in range(0, num_queries, len(self.devices))
]
Comment on lines +597 to +600
Copy link

Copilot AI Oct 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The list comprehension creates chunks of size len(self.devices) starting at increments of len(self.devices), which is incorrect. The step should match the chunk size to avoid overlap. Use proper chunking logic to partition the list without overlap.

Copilot uses AI. Check for mistakes.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. I'm unsure what Copilot thinks is wrong


num_splits = len(queries_embeddings_splits)
if subset is not None:
current_idx = 0
for split in queries_embeddings_splits:
size = split.shape[0]
size = len(split)
subset_splits.append(subset[current_idx : current_idx + size]) # type: ignore
current_idx += size
else:
Expand Down
9 changes: 7 additions & 2 deletions rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ fn load_and_search(
index: String,
torch_path: String,
device: String,
queries_embeddings: PyTensor,
queries_embeddings: Vec<PyTensor>,
search_parameters: &SearchParameters,
show_progress: bool,
preload_index: bool,
Expand All @@ -397,9 +397,14 @@ fn load_and_search(
Ok(Arc::new(loaded_index))
}?;

let queries_embeddings: Vec<_> = queries_embeddings
.into_iter()
.map(|tensor| tensor.to_kind(Kind::Half))
.collect();

// Perform the search
let results = search_many(
&queries_embeddings.to_kind(Kind::Half),
&queries_embeddings,
&index,
search_parameters,
device_tch,
Expand Down
18 changes: 7 additions & 11 deletions rust/search/search.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use anyhow::{anyhow, bail, Result};
use anyhow::{anyhow, Result};
use indicatif::{ProgressBar, ProgressIterator};
use pyo3::prelude::*;
use serde::Serialize;
use tch::{Device, IndexOp, Kind, Tensor};
use tch::{Device, Kind, Tensor};

use crate::search::load::LoadedIndex;
use crate::search::padding::direct_pad_sequences;
Expand Down Expand Up @@ -165,22 +165,18 @@ impl SearchParameters {
/// A `Result` with a `Vec<QueryResult>`. Individual search failures result in an empty
/// `QueryResult` for that specific query, ensuring the operation doesn't halt.
pub fn search_many(
queries: &Tensor,
queries: &Vec<Tensor>,
index: &LoadedIndex,
params: &SearchParameters,
device: Device,
show_progress: bool,
subset: Option<Vec<Vec<i64>>>,
) -> Result<Vec<QueryResult>> {
let [num_queries, _, query_dim] = queries.size()[..] else {
bail!(
"Expected a 3D tensor for queries, but got shape {:?}",
queries.size()
);
};
let num_queries = queries.len();
let query_dim = queries[0].size()[queries[0].dim() - 1];

let search_closure = |query_index| {
let query_embedding = queries.i(query_index).to(device);
let search_closure = |query_index: usize| {
let query_embedding = &queries[query_index].to(device);

// Handle the per-query subset list
let query_subset = subset.as_ref().and_then(|s| s.get(query_index as usize));
Expand Down
Loading