diff --git a/providers/common/ai/docs/operators/llamaindex_embedding.rst b/providers/common/ai/docs/operators/llamaindex_embedding.rst index 99125ac74bdde..894045b684c6d 100644 --- a/providers/common/ai/docs/operators/llamaindex_embedding.rst +++ b/providers/common/ai/docs/operators/llamaindex_embedding.rst @@ -25,10 +25,10 @@ LlamaIndex. Designed to feed the output of :class:`~airflow.providers.common.ai.operators.document_loader.DocumentLoaderOperator` into vector storage (pgvector, Pinecone, Weaviate, ...). -The operator passes the embedding model **directly** to -``VectorStoreIndex(..., embed_model=...)`` -- it does not mutate -LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the same -worker process don't race on shared model state. +The operator calls the embedding model **directly** (and passes it to +``VectorStoreIndex(..., embed_model=...)`` when persisting) -- it does not +mutate LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the +same worker process don't race on shared model state. Basic usage ----------- @@ -117,3 +117,7 @@ Returns a dict with:: ... ], } + +``vector`` is computed over the chunk's metadata-enriched content +(LlamaIndex's ``MetadataMode.EMBED``, the same content ``VectorStoreIndex`` +embeds), while ``text`` is the raw chunk text without metadata. diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py index d85e692100202..34cd441ef0565 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py @@ -44,10 +44,10 @@ class LlamaIndexEmbeddingOperator(BaseOperator): ``list[dict]`` with ``text`` and ``metadata`` keys; output includes the embedding vectors ready for downstream storage ingest. - The operator passes the embedding model **directly** to - ``VectorStoreIndex(..., embed_model=...)`` -- it does not mutate - LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the - same worker don't race on shared state. + The operator calls the embedding model **directly** (and passes it to + ``VectorStoreIndex(..., embed_model=...)`` when persisting) -- it does + not mutate LlamaIndex's global ``Settings`` singleton, so concurrent + tasks in the same worker don't race on shared state. :param documents: List of dicts with ``text`` and ``metadata`` keys, typically from ``DocumentLoaderOperator`` or a ``@task``. Templated, @@ -114,6 +114,7 @@ def execute(self, context: Context) -> dict[str, Any]: try: from llama_index.core import Document, VectorStoreIndex from llama_index.core.node_parser import SentenceSplitter + from llama_index.core.schema import MetadataMode except ImportError as e: raise AirflowOptionalProviderFeatureException(e) @@ -125,19 +126,32 @@ def execute(self, context: Context) -> dict[str, Any]: nodes = splitter.get_nodes_from_documents(llama_docs) self.log.info("Split %d documents into %d chunks", len(llama_docs), len(nodes)) - # ``VectorStoreIndex(...)`` populates each node's ``.embedding`` as a - # side effect of building the index; capture the index so the - # variable isn't discarded. - index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=False) + # ``VectorStoreIndex(...)`` never sets ``.embedding`` on the nodes it + # is given -- ``_get_node_with_embedding()`` attaches embeddings to + # ``model_copy()`` copies, so reading ``node.embedding`` afterwards + # always returns ``None`` (apache/airflow#68416). Embed the original + # nodes explicitly. + # ``MetadataMode.EMBED`` matches what ``embed_nodes()`` inside the + # index embeds (includes metadata, respects + # ``excluded_embed_metadata_keys``). + texts = [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes] + vectors = embed_model.get_text_embedding_batch(texts, show_progress=False) + for node, vector in zip(nodes, vectors, strict=True): + node.embedding = vector if self.persist_dir: + # The index is only needed for persistence. ``embed_nodes()`` + # inside ``VectorStoreIndex`` skips nodes whose ``.embedding`` is + # already set, so this reuses the vectors above instead of + # re-calling the embedding API. + index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=False) self._persist(index, self.persist_dir) # ``SentenceSplitter`` always returns ``TextNode`` instances, but the # base ``get_nodes_from_documents`` signature is typed as # ``list[BaseNode]`` (which has no ``.text``). Cast so mypy doesn't - # flag the ``.text`` access; ``node.embedding`` is populated by - # ``VectorStoreIndex`` for every node above. + # flag the ``.text`` access; ``node.embedding`` is populated by the + # pre-embed step above for every node. text_nodes = cast("list[TextNode]", nodes) chunks = [ { @@ -164,9 +178,9 @@ def _resolve_embed_model(self) -> BaseEmbedding: * ``None`` or ``str`` -- build an ``OpenAIEmbedding`` via ``LlamaIndexHook`` (the framework's documented ``default`` behaviour). - * Has ``get_text_embedding`` / ``_get_query_embedding`` -- treat as - a pre-built ``BaseEmbedding`` (duck-typed to avoid forcing a - ``llama_index`` import here). + * Has ``get_text_embedding_batch`` / ``_get_query_embedding`` -- + treat as a pre-built ``BaseEmbedding`` (duck-typed to avoid + forcing a ``llama_index`` import here). * Anything else -- ``TypeError`` with a clear pointer. """ if self.embed_model is None or isinstance(self.embed_model, str): @@ -179,10 +193,11 @@ def _resolve_embed_model(self) -> BaseEmbedding: ).get_embedding_model() # ``BaseEmbedding`` always exposes these two methods (see - # ``llama_index.core.base.embeddings.base``). Duck-typing avoids - # importing ``llama_index`` here and also catches the case where an + # ``llama_index.core.base.embeddings.base``); ``execute`` calls + # ``get_text_embedding_batch``. Duck-typing avoids importing + # ``llama_index`` here and also catches the case where an # unresolved ``XComArg`` slips through. - if hasattr(self.embed_model, "get_text_embedding") and hasattr( + if hasattr(self.embed_model, "get_text_embedding_batch") and hasattr( self.embed_model, "_get_query_embedding" ): return self.embed_model diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py index 43b44f87c9ff4..d3e6b2acfca34 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py @@ -20,6 +20,11 @@ import pytest +pytest.importorskip("llama_index.core") + +from llama_index.core import MockEmbedding +from llama_index.core.schema import MetadataMode + from airflow.providers.common.ai.operators.llamaindex_embedding import LlamaIndexEmbeddingOperator @@ -38,17 +43,23 @@ def _li(monkeypatch): return {"VectorStoreIndex": VectorStoreIndex, "SentenceSplitter": SentenceSplitter} -def _node(text: str = "chunk text", metadata: dict | None = None, vector=None): +def _node(text: str = "chunk text", metadata: dict | None = None): node = MagicMock() node.text = text node.metadata = metadata or {} - node.embedding = vector + node.embedding = None + node.get_content.return_value = text return node -def _byo_embedding(): - """Return a duck-typed ``BaseEmbedding`` stand-in (has the two methods the operator checks).""" - return MagicMock(name="MyBaseEmbedding", spec=["get_text_embedding", "_get_query_embedding"]) +def _byo_embedding(vectors: list[list[float]] | None = None): + """Return a duck-typed ``BaseEmbedding`` stand-in (has the methods the operator checks and calls).""" + embedding = MagicMock( + name="MyBaseEmbedding", + spec=["get_text_embedding_batch", "_get_query_embedding"], + ) + embedding.get_text_embedding_batch.return_value = [[0.0]] if vectors is None else vectors + return embedding class TestEmbeddingOperatorInit: @@ -72,8 +83,9 @@ class TestEmbeddingOperatorExecute: def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li): # `embed_model` as a string -> hook builds OpenAIEmbedding. _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ - _node(text="chunk a", vector=[0.1, 0.2]), + _node(text="chunk a"), ] + mock_get_embed.return_value = _byo_embedding(vectors=[[0.1, 0.2]]) op = LlamaIndexEmbeddingOperator( task_id="test", @@ -93,6 +105,7 @@ def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li): def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _li): # ``embed_conn_id`` overrides ``llm_conn_id`` for the embedding API. _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] + mock_hook_cls.return_value.get_embedding_model.return_value = _byo_embedding() op = LlamaIndexEmbeddingOperator( task_id="test", @@ -110,8 +123,9 @@ def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _li): ) def test_byo_embed_model_bypasses_hook(self, _li): - # `embed_model` is a non-string instance -> hook is bypassed. - byo = _byo_embedding() + # `embed_model` is a non-string instance -> hook is bypassed and the + # user's instance does the embedding. + byo = _byo_embedding(vectors=[[0.5]]) _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] op = LlamaIndexEmbeddingOperator( @@ -119,12 +133,10 @@ def test_byo_embed_model_bypasses_hook(self, _li): documents=[{"text": "doc"}], embed_model=byo, ) - op.execute(context=MagicMock()) + result = op.execute(context=MagicMock()) - # VectorStoreIndex called with the user's instance, not anything else. - _li["VectorStoreIndex"].assert_called_once() - kwargs = _li["VectorStoreIndex"].call_args.kwargs - assert kwargs["embed_model"] is byo + byo.get_text_embedding_batch.assert_called_once() + assert result["chunks"][0]["vector"] == [0.5] def test_invalid_embed_model_raises_typeerror(self, _li): # An object that's neither None/str nor duck-types as BaseEmbedding @@ -140,17 +152,16 @@ def test_invalid_embed_model_raises_typeerror(self, _li): with pytest.raises(TypeError, match="embed_model must be"): op.execute(context=MagicMock()) - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") - def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _li): + def test_chunks_carry_text_metadata_vector(self, _li): _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ - _node(text="x", metadata={"k": "v"}, vector=[1.0, 2.0]), - _node(text="y", metadata={"k": "v2"}, vector=[3.0, 4.0]), + _node(text="x", metadata={"k": "v"}), + _node(text="y", metadata={"k": "v2"}), ] op = LlamaIndexEmbeddingOperator( task_id="test", documents=[{"text": "doc"}], - embed_model="text-embedding-3-small", + embed_model=_byo_embedding(vectors=[[1.0, 2.0], [3.0, 4.0]]), ) result = op.execute(context=MagicMock()) @@ -159,8 +170,84 @@ def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _li): {"text": "y", "metadata": {"k": "v2"}, "vector": [3.0, 4.0]}, ] + def test_nodes_embedded_with_embed_metadata_mode(self, _li): + # llama-index's own ``embed_nodes()`` embeds + # ``node.get_content(metadata_mode=MetadataMode.EMBED)`` (includes + # metadata, respects ``excluded_embed_metadata_keys``). The pre-embed + # step must match, or the vectors silently change semantics. + node = _node(text="chunk a") + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] + byo = _byo_embedding() + + op = LlamaIndexEmbeddingOperator( + task_id="test", + documents=[{"text": "doc"}], + embed_model=byo, + ) + op.execute(context=MagicMock()) + + node.get_content.assert_called_once_with(metadata_mode=MetadataMode.EMBED) + byo.get_text_embedding_batch.assert_called_once() + assert byo.get_text_embedding_batch.call_args.args[0] == ["chunk a"] + + def test_index_only_built_when_persisting(self, _li): + # Without ``persist_dir`` the index would be built and immediately + # discarded; the vectors come from the pre-embed step. + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] + + op = LlamaIndexEmbeddingOperator( + task_id="test", + documents=[{"text": "doc"}], + embed_model=_byo_embedding(), + ) + op.execute(context=MagicMock()) + + _li["VectorStoreIndex"].assert_not_called() + + def test_vectors_populated_with_real_llama_index(self): + # Regression test for #68416: ``VectorStoreIndex`` attaches embeddings + # to node *copies* (``model_copy()`` in ``_get_node_with_embedding``), + # so reading ``node.embedding`` after index construction returns + # ``None``. Run the real llama-index code path with its offline + # ``MockEmbedding`` -- no mocks on the operator's internals. + op = LlamaIndexEmbeddingOperator( + task_id="test", + documents=[{"text": "hello world", "metadata": {"src": "a"}}], + embed_model=MockEmbedding(embed_dim=8), + ) + result = op.execute(context=MagicMock()) + + assert result["chunk_count"] >= 1 + assert all(chunk["vector"] is not None for chunk in result["chunks"]) + assert all(len(chunk["vector"]) == 8 for chunk in result["chunks"]) + class TestEmbeddingOperatorPersist: + def test_persist_path_embeds_each_chunk_once_with_real_llama_index(self, tmp_path): + # ``embed_nodes()`` inside ``VectorStoreIndex`` must skip the + # pre-embedded nodes -- if it re-embeds, every chunk pays the + # embedding API twice. Runs the real llama-index persist path. + embedded_texts: list[str] = [] + + class CountingMockEmbedding(MockEmbedding): + def _get_text_embedding(self, text: str) -> list[float]: + embedded_texts.append(text) + return super()._get_text_embedding(text) + + persist_dir = tmp_path / "idx" + op = LlamaIndexEmbeddingOperator( + task_id="test", + documents=[{"text": "hello world", "metadata": {"src": "a"}}], + embed_model=CountingMockEmbedding(embed_dim=8), + persist_dir=str(persist_dir), + ) + result = op.execute(context=MagicMock()) + + assert result["chunk_count"] == 1 + assert len(embedded_texts) == result["chunk_count"] + assert all(chunk["vector"] is not None for chunk in result["chunks"]) + assert any(persist_dir.iterdir()) + @patch("os.makedirs") @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") def test_local_persist_dir_calls_makedirs_and_storage_persist( @@ -168,6 +255,7 @@ def test_local_persist_dir_calls_makedirs_and_storage_persist( ): node = _node() _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] + mock_get_embed.return_value = _byo_embedding(vectors=[[0.1]]) index = _li["VectorStoreIndex"].return_value op = LlamaIndexEmbeddingOperator( @@ -180,6 +268,11 @@ def test_local_persist_dir_calls_makedirs_and_storage_persist( mock_makedirs.assert_called_once_with(str(tmp_path / "idx"), exist_ok=True) index.storage_context.persist.assert_called_once_with(persist_dir=str(tmp_path / "idx")) + # Nodes are already embedded when handed to the index (the + # no-double-embed behavior itself is pinned by + # ``test_persist_path_embeds_each_chunk_once_with_real_llama_index``). + nodes_arg = _li["VectorStoreIndex"].call_args.args[0] + assert nodes_arg[0].embedding == [0.1] @patch("airflow.sdk.ObjectStoragePath") @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") @@ -195,6 +288,7 @@ def test_cloud_uri_persist_dir_uses_object_storage_path(self, mock_get_embed, mo node = _node() _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] + mock_get_embed.return_value = _byo_embedding() index = _li["VectorStoreIndex"].return_value op = LlamaIndexEmbeddingOperator(