-
Notifications
You must be signed in to change notification settings - Fork 18
Enable variable-length queries #25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -115,7 +115,7 @@ def compute_kmeans( # noqa: PLR0913 | |
|
|
||
| def search_on_device( # noqa: PLR0913 | ||
| device: str, | ||
| queries_embeddings: torch.Tensor, | ||
| queries_embeddings: list[torch.Tensor], | ||
| batch_size: int, | ||
| n_full_scores: int, | ||
| top_k: int, | ||
|
|
@@ -154,6 +154,17 @@ def search_on_device( # noqa: PLR0913 | |
| ] | ||
|
|
||
|
|
||
| def cleanup_embeddings(embeddings: list[torch.Tensor] | torch.Tensor) -> list[torch.Tensor]: | ||
| if isinstance(embeddings, torch.Tensor): | ||
| embeddings = [ | ||
| embeddings[i] for i in range(embeddings.shape[0]) | ||
| ] | ||
| return [ | ||
| embedding.squeeze(0) if embedding.dim() == 3 else embedding | ||
| for embedding in embeddings | ||
| ] | ||
|
|
||
|
|
||
| class FastPlaid: | ||
| """A class for creating and searching a FastPlaid index. | ||
|
|
||
|
|
@@ -288,15 +299,7 @@ def create( # noqa: PLR0913 | |
| Optional list of dictionaries containing metadata for each document. | ||
|
|
||
| """ | ||
| if isinstance(documents_embeddings, torch.Tensor): | ||
| documents_embeddings = [ | ||
| documents_embeddings[i] for i in range(documents_embeddings.shape[0]) | ||
| ] | ||
|
|
||
| documents_embeddings = [ | ||
| embedding.squeeze(0) if embedding.dim() == 3 else embedding | ||
| for embedding in documents_embeddings | ||
| ] | ||
| documents_embeddings = cleanup_embeddings(documents_embeddings) | ||
| num_docs = len(documents_embeddings) | ||
|
|
||
| self._prepare_index_directory(index_path=self.index) | ||
|
|
@@ -473,17 +476,8 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915 | |
| corresponding inner list. | ||
|
|
||
| """ | ||
| if isinstance(queries_embeddings, list): | ||
| queries_embeddings = torch.nn.utils.rnn.pad_sequence( | ||
| sequences=[ | ||
| embedding[0] if embedding.dim() == 3 else embedding | ||
| for embedding in queries_embeddings | ||
| ], | ||
| batch_first=True, | ||
| padding_value=0.0, | ||
| ) | ||
|
|
||
| num_queries = queries_embeddings.shape[0] | ||
| queries_embeddings = cleanup_embeddings(queries_embeddings) | ||
| num_queries = len(queries_embeddings) | ||
|
|
||
| if subset is not None: | ||
| if isinstance(subset, int): | ||
|
|
@@ -529,16 +523,14 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915 | |
| num_cpus = len(self.devices) | ||
|
|
||
| # Use torch.chunk to split the tensor into num_cpus | ||
| queries_embeddings_splits = torch.chunk( | ||
| input=queries_embeddings, | ||
| chunks=num_cpus, | ||
| dim=0, | ||
| ) | ||
| queries_embeddings_splits = [ | ||
| queries_embeddings[i:i + num_cpus] for i in range(0, num_queries, num_cpus) | ||
| ] | ||
|
|
||
| # Filter out empty chunks that torch.chunk might create | ||
| # if num_queries < num_cpus | ||
| non_empty_splits = [ | ||
| split for split in queries_embeddings_splits if split.shape[0] > 0 | ||
| split for split in queries_embeddings_splits if len(split) > 0 | ||
| ] | ||
| num_splits = len(non_empty_splits) | ||
|
|
||
|
|
@@ -548,7 +540,7 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915 | |
| if subset is not None: | ||
| current_idx = 0 | ||
| for split in non_empty_splits: | ||
| size = split.shape[0] | ||
| size = len(split) | ||
| subset_splits.append(subset[current_idx : current_idx + size]) # type: ignore | ||
| current_idx += size | ||
|
|
||
|
|
@@ -600,16 +592,15 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915 | |
| subset=subset, # type: ignore | ||
| ) | ||
|
|
||
| queries_embeddings_splits = torch.split( | ||
| tensor=queries_embeddings, | ||
| split_size_or_sections=len(self.devices), | ||
| ) | ||
| queries_embeddings_splits = [ | ||
| queries_embeddings[i:i + len(self.devices)] for i in range(0, num_queries, len(self.devices)) | ||
| ] | ||
|
Comment on lines
+597
to
+600
|
||
|
|
||
| num_splits = len(queries_embeddings_splits) | ||
| if subset is not None: | ||
| current_idx = 0 | ||
| for split in queries_embeddings_splits: | ||
| size = split.shape[0] | ||
| size = len(split) | ||
| subset_splits.append(subset[current_idx : current_idx + size]) # type: ignore | ||
| current_idx += size | ||
| else: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The list comprehension creates chunks of size
num_cpusstarting at increments ofnum_cpus, which is incorrect. The step should match the chunk size to avoid overlap. Use:queries_embeddings[i*num_cpus:(i+1)*num_cpus] for i in range((num_queries + num_cpus - 1) // num_cpus)]or similar logic to properly partition the list.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The list comprehension seems correct to me. Don't know why Copilot thinks it's wrong