diff --git a/src/pipelines/weaver.py b/src/pipelines/weaver.py index 8138436..55d55cb 100644 --- a/src/pipelines/weaver.py +++ b/src/pipelines/weaver.py @@ -12,6 +12,8 @@ from __future__ import annotations +import asyncio +from functools import partial import logging from typing import Any, Callable, Dict, List, Optional @@ -126,7 +128,6 @@ async def flush_add_batch(): # Prepare data for batch add valid_ops = [] texts = [] - embeddings = [] metadatas = [] for op in add_batch_ops: @@ -139,41 +140,70 @@ async def flush_add_batch(): continue try: - emb = self.embed_fn(op.content) meta = {"user_id": user_id, "domain": domain.value} meta.update(_extract_structured_metadata(op.content)) valid_ops.append(op) texts.append(op.content) - embeddings.append(emb) metadatas.append(meta) except Exception as exc: - logger.error("Embedding generation failed for ADD: %s", exc) + logger.error("Metadata extraction failed for ADD: %s", exc) executed_ops.append(ExecutedOp( type=op.type, status=OpStatus.FAILED, content=op.content, error=str(exc) )) if valid_ops: - try: - ids = self.vector_store.add( - texts=texts, - embeddings=embeddings, - metadata=metadatas, - ) - # Map IDs back to ops - for op, new_id in zip(valid_ops, ids): - executed_ops.append(ExecutedOp( - type=op.type, status=OpStatus.SUCCESS, - content=op.content, new_id=new_id, - )) - except Exception as exc: - logger.error("Vector batch ADD failed: %s", exc) - for op in valid_ops: + loop = asyncio.get_running_loop() + + async def _embed(text: str) -> List[float]: + return await loop.run_in_executor(None, self.embed_fn, text) + + tasks = [_embed(text) for text in texts] + results = await asyncio.gather(*tasks, return_exceptions=True) + + successful_ops = [] + successful_texts = [] + successful_embeddings = [] + successful_metadatas = [] + + for op, text, meta, res in zip(valid_ops, texts, metadatas, results): + if isinstance(res, Exception): + logger.error("Embedding generation failed for ADD: %s", res) executed_ops.append(ExecutedOp( type=op.type, status=OpStatus.FAILED, - content=op.content, error=str(exc) + content=op.content, error=str(res) )) + else: + successful_ops.append(op) + successful_texts.append(text) + successful_embeddings.append(res) + successful_metadatas.append(meta) + + if successful_ops: + try: + ids = await loop.run_in_executor( + None, + partial( + self.vector_store.add, + texts=successful_texts, + embeddings=successful_embeddings, + metadata=successful_metadatas, + ) + ) + # Map IDs back to ops + for op, new_id in zip(successful_ops, ids): + executed_ops.append(ExecutedOp( + type=op.type, status=OpStatus.SUCCESS, + content=op.content, new_id=new_id, + )) + except Exception as exc: + logger.error("Vector batch ADD failed: %s", exc) + for op in successful_ops: + executed_ops.append(ExecutedOp( + type=op.type, status=OpStatus.FAILED, + content=op.content, error=str(exc) + )) add_batch_ops.clear() @@ -197,8 +227,9 @@ async def flush_delete_batch(): ids_to_delete.append(op.embedding_id) if valid_ops: + loop = asyncio.get_running_loop() try: - success = self.vector_store.delete(ids=ids_to_delete) + success = await loop.run_in_executor(None, partial(self.vector_store.delete, ids=ids_to_delete)) status = OpStatus.SUCCESS if success else OpStatus.FAILED for op in valid_ops: executed_ops.append(ExecutedOp(