Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 52 additions & 21 deletions src/pipelines/weaver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from __future__ import annotations

import asyncio
from functools import partial
import logging
from typing import Any, Callable, Dict, List, Optional

Expand Down Expand Up @@ -126,7 +128,6 @@ async def flush_add_batch():
# Prepare data for batch add
valid_ops = []
texts = []
embeddings = []
metadatas = []

for op in add_batch_ops:
Expand All @@ -139,41 +140,70 @@ async def flush_add_batch():
continue

try:
emb = self.embed_fn(op.content)
meta = {"user_id": user_id, "domain": domain.value}
meta.update(_extract_structured_metadata(op.content))

valid_ops.append(op)
texts.append(op.content)
embeddings.append(emb)
metadatas.append(meta)
except Exception as exc:
logger.error("Embedding generation failed for ADD: %s", exc)
logger.error("Metadata extraction failed for ADD: %s", exc)
executed_ops.append(ExecutedOp(
type=op.type, status=OpStatus.FAILED,
content=op.content, error=str(exc)
))

if valid_ops:
try:
ids = self.vector_store.add(
texts=texts,
embeddings=embeddings,
metadata=metadatas,
)
# Map IDs back to ops
for op, new_id in zip(valid_ops, ids):
executed_ops.append(ExecutedOp(
type=op.type, status=OpStatus.SUCCESS,
content=op.content, new_id=new_id,
))
except Exception as exc:
logger.error("Vector batch ADD failed: %s", exc)
for op in valid_ops:
loop = asyncio.get_running_loop()

async def _embed(text: str) -> List[float]:
return await loop.run_in_executor(None, self.embed_fn, text)

tasks = [_embed(text) for text in texts]
results = await asyncio.gather(*tasks, return_exceptions=True)

successful_ops = []
successful_texts = []
successful_embeddings = []
successful_metadatas = []

for op, text, meta, res in zip(valid_ops, texts, metadatas, results):
if isinstance(res, Exception):
logger.error("Embedding generation failed for ADD: %s", res)
executed_ops.append(ExecutedOp(
type=op.type, status=OpStatus.FAILED,
content=op.content, error=str(exc)
content=op.content, error=str(res)
))
else:
successful_ops.append(op)
successful_texts.append(text)
successful_embeddings.append(res)
successful_metadatas.append(meta)

if successful_ops:
try:
ids = await loop.run_in_executor(
None,
partial(
self.vector_store.add,
texts=successful_texts,
embeddings=successful_embeddings,
metadata=successful_metadatas,
)
)
# Map IDs back to ops
for op, new_id in zip(successful_ops, ids):
executed_ops.append(ExecutedOp(
type=op.type, status=OpStatus.SUCCESS,
content=op.content, new_id=new_id,
))
except Exception as exc:
logger.error("Vector batch ADD failed: %s", exc)
for op in successful_ops:
executed_ops.append(ExecutedOp(
type=op.type, status=OpStatus.FAILED,
content=op.content, error=str(exc)
))

add_batch_ops.clear()

Expand All @@ -197,8 +227,9 @@ async def flush_delete_batch():
ids_to_delete.append(op.embedding_id)

if valid_ops:
loop = asyncio.get_running_loop()
try:
success = self.vector_store.delete(ids=ids_to_delete)
success = await loop.run_in_executor(None, partial(self.vector_store.delete, ids=ids_to_delete))
status = OpStatus.SUCCESS if success else OpStatus.FAILED
for op in valid_ops:
executed_ops.append(ExecutedOp(
Expand Down