From 789b903e1d33eaf22df9662a7c6f3a4bd0b7bd1f Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 13 Jun 2026 01:42:54 +0100 Subject: [PATCH] Fix LlamaIndexEmbeddingOperator returning vector=None for every chunk VectorStoreIndex attaches embeddings to model_copy() copies of the nodes it is given, never the originals, so reading node.embedding after index construction always returned None. Embed the original nodes explicitly with the same content VectorStoreIndex embeds (MetadataMode.EMBED) and only build the index when persist_dir is set; embed_nodes() inside the index skips pre-embedded nodes, so persisting does not re-call the embedding API. --- .../docs/operators/llamaindex_embedding.rst | 12 +- .../ai/operators/llamaindex_embedding.py | 47 ++++--- .../ai/operators/test_llamaindex_embedding.py | 130 +++++++++++++++--- 3 files changed, 151 insertions(+), 38 deletions(-) diff --git a/providers/common/ai/docs/operators/llamaindex_embedding.rst b/providers/common/ai/docs/operators/llamaindex_embedding.rst index 99125ac74bdde..894045b684c6d 100644 --- a/providers/common/ai/docs/operators/llamaindex_embedding.rst +++ b/providers/common/ai/docs/operators/llamaindex_embedding.rst @@ -25,10 +25,10 @@ LlamaIndex. Designed to feed the output of :class:`~airflow.providers.common.ai.operators.document_loader.DocumentLoaderOperator` into vector storage (pgvector, Pinecone, Weaviate, ...). -The operator passes the embedding model **directly** to -``VectorStoreIndex(..., embed_model=...)`` -- it does not mutate -LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the same -worker process don't race on shared model state. +The operator calls the embedding model **directly** (and passes it to +``VectorStoreIndex(..., embed_model=...)`` when persisting) -- it does not +mutate LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the +same worker process don't race on shared model state. Basic usage ----------- @@ -117,3 +117,7 @@ Returns a dict with:: ... ], } + +``vector`` is computed over the chunk's metadata-enriched content +(LlamaIndex's ``MetadataMode.EMBED``, the same content ``VectorStoreIndex`` +embeds), while ``text`` is the raw chunk text without metadata. diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py index d85e692100202..34cd441ef0565 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py @@ -44,10 +44,10 @@ class LlamaIndexEmbeddingOperator(BaseOperator): ``list[dict]`` with ``text`` and ``metadata`` keys; output includes the embedding vectors ready for downstream storage ingest. - The operator passes the embedding model **directly** to - ``VectorStoreIndex(..., embed_model=...)`` -- it does not mutate - LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the - same worker don't race on shared state. + The operator calls the embedding model **directly** (and passes it to + ``VectorStoreIndex(..., embed_model=...)`` when persisting) -- it does + not mutate LlamaIndex's global ``Settings`` singleton, so concurrent + tasks in the same worker don't race on shared state. :param documents: List of dicts with ``text`` and ``metadata`` keys, typically from ``DocumentLoaderOperator`` or a ``@task``. Templated, @@ -114,6 +114,7 @@ def execute(self, context: Context) -> dict[str, Any]: try: from llama_index.core import Document, VectorStoreIndex from llama_index.core.node_parser import SentenceSplitter + from llama_index.core.schema import MetadataMode except ImportError as e: raise AirflowOptionalProviderFeatureException(e) @@ -125,19 +126,32 @@ def execute(self, context: Context) -> dict[str, Any]: nodes = splitter.get_nodes_from_documents(llama_docs) self.log.info("Split %d documents into %d chunks", len(llama_docs), len(nodes)) - # ``VectorStoreIndex(...)`` populates each node's ``.embedding`` as a - # side effect of building the index; capture the index so the - # variable isn't discarded. - index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=False) + # ``VectorStoreIndex(...)`` never sets ``.embedding`` on the nodes it + # is given -- ``_get_node_with_embedding()`` attaches embeddings to + # ``model_copy()`` copies, so reading ``node.embedding`` afterwards + # always returns ``None`` (apache/airflow#68416). Embed the original + # nodes explicitly. + # ``MetadataMode.EMBED`` matches what ``embed_nodes()`` inside the + # index embeds (includes metadata, respects + # ``excluded_embed_metadata_keys``). + texts = [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes] + vectors = embed_model.get_text_embedding_batch(texts, show_progress=False) + for node, vector in zip(nodes, vectors, strict=True): + node.embedding = vector if self.persist_dir: + # The index is only needed for persistence. ``embed_nodes()`` + # inside ``VectorStoreIndex`` skips nodes whose ``.embedding`` is + # already set, so this reuses the vectors above instead of + # re-calling the embedding API. + index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=False) self._persist(index, self.persist_dir) # ``SentenceSplitter`` always returns ``TextNode`` instances, but the # base ``get_nodes_from_documents`` signature is typed as # ``list[BaseNode]`` (which has no ``.text``). Cast so mypy doesn't - # flag the ``.text`` access; ``node.embedding`` is populated by - # ``VectorStoreIndex`` for every node above. + # flag the ``.text`` access; ``node.embedding`` is populated by the + # pre-embed step above for every node. text_nodes = cast("list[TextNode]", nodes) chunks = [ { @@ -164,9 +178,9 @@ def _resolve_embed_model(self) -> BaseEmbedding: * ``None`` or ``str`` -- build an ``OpenAIEmbedding`` via ``LlamaIndexHook`` (the framework's documented ``default`` behaviour). - * Has ``get_text_embedding`` / ``_get_query_embedding`` -- treat as - a pre-built ``BaseEmbedding`` (duck-typed to avoid forcing a - ``llama_index`` import here). + * Has ``get_text_embedding_batch`` / ``_get_query_embedding`` -- + treat as a pre-built ``BaseEmbedding`` (duck-typed to avoid + forcing a ``llama_index`` import here). * Anything else -- ``TypeError`` with a clear pointer. """ if self.embed_model is None or isinstance(self.embed_model, str): @@ -179,10 +193,11 @@ def _resolve_embed_model(self) -> BaseEmbedding: ).get_embedding_model() # ``BaseEmbedding`` always exposes these two methods (see - # ``llama_index.core.base.embeddings.base``). Duck-typing avoids - # importing ``llama_index`` here and also catches the case where an + # ``llama_index.core.base.embeddings.base``); ``execute`` calls + # ``get_text_embedding_batch``. Duck-typing avoids importing + # ``llama_index`` here and also catches the case where an # unresolved ``XComArg`` slips through. - if hasattr(self.embed_model, "get_text_embedding") and hasattr( + if hasattr(self.embed_model, "get_text_embedding_batch") and hasattr( self.embed_model, "_get_query_embedding" ): return self.embed_model diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py index 43b44f87c9ff4..d3e6b2acfca34 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py @@ -20,6 +20,11 @@ import pytest +pytest.importorskip("llama_index.core") + +from llama_index.core import MockEmbedding +from llama_index.core.schema import MetadataMode + from airflow.providers.common.ai.operators.llamaindex_embedding import LlamaIndexEmbeddingOperator @@ -38,17 +43,23 @@ def _li(monkeypatch): return {"VectorStoreIndex": VectorStoreIndex, "SentenceSplitter": SentenceSplitter} -def _node(text: str = "chunk text", metadata: dict | None = None, vector=None): +def _node(text: str = "chunk text", metadata: dict | None = None): node = MagicMock() node.text = text node.metadata = metadata or {} - node.embedding = vector + node.embedding = None + node.get_content.return_value = text return node -def _byo_embedding(): - """Return a duck-typed ``BaseEmbedding`` stand-in (has the two methods the operator checks).""" - return MagicMock(name="MyBaseEmbedding", spec=["get_text_embedding", "_get_query_embedding"]) +def _byo_embedding(vectors: list[list[float]] | None = None): + """Return a duck-typed ``BaseEmbedding`` stand-in (has the methods the operator checks and calls).""" + embedding = MagicMock( + name="MyBaseEmbedding", + spec=["get_text_embedding_batch", "_get_query_embedding"], + ) + embedding.get_text_embedding_batch.return_value = [[0.0]] if vectors is None else vectors + return embedding class TestEmbeddingOperatorInit: @@ -72,8 +83,9 @@ class TestEmbeddingOperatorExecute: def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li): # `embed_model` as a string -> hook builds OpenAIEmbedding. _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ - _node(text="chunk a", vector=[0.1, 0.2]), + _node(text="chunk a"), ] + mock_get_embed.return_value = _byo_embedding(vectors=[[0.1, 0.2]]) op = LlamaIndexEmbeddingOperator( task_id="test", @@ -93,6 +105,7 @@ def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li): def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _li): # ``embed_conn_id`` overrides ``llm_conn_id`` for the embedding API. _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] + mock_hook_cls.return_value.get_embedding_model.return_value = _byo_embedding() op = LlamaIndexEmbeddingOperator( task_id="test", @@ -110,8 +123,9 @@ def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _li): ) def test_byo_embed_model_bypasses_hook(self, _li): - # `embed_model` is a non-string instance -> hook is bypassed. - byo = _byo_embedding() + # `embed_model` is a non-string instance -> hook is bypassed and the + # user's instance does the embedding. + byo = _byo_embedding(vectors=[[0.5]]) _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] op = LlamaIndexEmbeddingOperator( @@ -119,12 +133,10 @@ def test_byo_embed_model_bypasses_hook(self, _li): documents=[{"text": "doc"}], embed_model=byo, ) - op.execute(context=MagicMock()) + result = op.execute(context=MagicMock()) - # VectorStoreIndex called with the user's instance, not anything else. - _li["VectorStoreIndex"].assert_called_once() - kwargs = _li["VectorStoreIndex"].call_args.kwargs - assert kwargs["embed_model"] is byo + byo.get_text_embedding_batch.assert_called_once() + assert result["chunks"][0]["vector"] == [0.5] def test_invalid_embed_model_raises_typeerror(self, _li): # An object that's neither None/str nor duck-types as BaseEmbedding @@ -140,17 +152,16 @@ def test_invalid_embed_model_raises_typeerror(self, _li): with pytest.raises(TypeError, match="embed_model must be"): op.execute(context=MagicMock()) - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") - def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _li): + def test_chunks_carry_text_metadata_vector(self, _li): _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ - _node(text="x", metadata={"k": "v"}, vector=[1.0, 2.0]), - _node(text="y", metadata={"k": "v2"}, vector=[3.0, 4.0]), + _node(text="x", metadata={"k": "v"}), + _node(text="y", metadata={"k": "v2"}), ] op = LlamaIndexEmbeddingOperator( task_id="test", documents=[{"text": "doc"}], - embed_model="text-embedding-3-small", + embed_model=_byo_embedding(vectors=[[1.0, 2.0], [3.0, 4.0]]), ) result = op.execute(context=MagicMock()) @@ -159,8 +170,84 @@ def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _li): {"text": "y", "metadata": {"k": "v2"}, "vector": [3.0, 4.0]}, ] + def test_nodes_embedded_with_embed_metadata_mode(self, _li): + # llama-index's own ``embed_nodes()`` embeds + # ``node.get_content(metadata_mode=MetadataMode.EMBED)`` (includes + # metadata, respects ``excluded_embed_metadata_keys``). The pre-embed + # step must match, or the vectors silently change semantics. + node = _node(text="chunk a") + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] + byo = _byo_embedding() + + op = LlamaIndexEmbeddingOperator( + task_id="test", + documents=[{"text": "doc"}], + embed_model=byo, + ) + op.execute(context=MagicMock()) + + node.get_content.assert_called_once_with(metadata_mode=MetadataMode.EMBED) + byo.get_text_embedding_batch.assert_called_once() + assert byo.get_text_embedding_batch.call_args.args[0] == ["chunk a"] + + def test_index_only_built_when_persisting(self, _li): + # Without ``persist_dir`` the index would be built and immediately + # discarded; the vectors come from the pre-embed step. + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] + + op = LlamaIndexEmbeddingOperator( + task_id="test", + documents=[{"text": "doc"}], + embed_model=_byo_embedding(), + ) + op.execute(context=MagicMock()) + + _li["VectorStoreIndex"].assert_not_called() + + def test_vectors_populated_with_real_llama_index(self): + # Regression test for #68416: ``VectorStoreIndex`` attaches embeddings + # to node *copies* (``model_copy()`` in ``_get_node_with_embedding``), + # so reading ``node.embedding`` after index construction returns + # ``None``. Run the real llama-index code path with its offline + # ``MockEmbedding`` -- no mocks on the operator's internals. + op = LlamaIndexEmbeddingOperator( + task_id="test", + documents=[{"text": "hello world", "metadata": {"src": "a"}}], + embed_model=MockEmbedding(embed_dim=8), + ) + result = op.execute(context=MagicMock()) + + assert result["chunk_count"] >= 1 + assert all(chunk["vector"] is not None for chunk in result["chunks"]) + assert all(len(chunk["vector"]) == 8 for chunk in result["chunks"]) + class TestEmbeddingOperatorPersist: + def test_persist_path_embeds_each_chunk_once_with_real_llama_index(self, tmp_path): + # ``embed_nodes()`` inside ``VectorStoreIndex`` must skip the + # pre-embedded nodes -- if it re-embeds, every chunk pays the + # embedding API twice. Runs the real llama-index persist path. + embedded_texts: list[str] = [] + + class CountingMockEmbedding(MockEmbedding): + def _get_text_embedding(self, text: str) -> list[float]: + embedded_texts.append(text) + return super()._get_text_embedding(text) + + persist_dir = tmp_path / "idx" + op = LlamaIndexEmbeddingOperator( + task_id="test", + documents=[{"text": "hello world", "metadata": {"src": "a"}}], + embed_model=CountingMockEmbedding(embed_dim=8), + persist_dir=str(persist_dir), + ) + result = op.execute(context=MagicMock()) + + assert result["chunk_count"] == 1 + assert len(embedded_texts) == result["chunk_count"] + assert all(chunk["vector"] is not None for chunk in result["chunks"]) + assert any(persist_dir.iterdir()) + @patch("os.makedirs") @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") def test_local_persist_dir_calls_makedirs_and_storage_persist( @@ -168,6 +255,7 @@ def test_local_persist_dir_calls_makedirs_and_storage_persist( ): node = _node() _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] + mock_get_embed.return_value = _byo_embedding(vectors=[[0.1]]) index = _li["VectorStoreIndex"].return_value op = LlamaIndexEmbeddingOperator( @@ -180,6 +268,11 @@ def test_local_persist_dir_calls_makedirs_and_storage_persist( mock_makedirs.assert_called_once_with(str(tmp_path / "idx"), exist_ok=True) index.storage_context.persist.assert_called_once_with(persist_dir=str(tmp_path / "idx")) + # Nodes are already embedded when handed to the index (the + # no-double-embed behavior itself is pinned by + # ``test_persist_path_embeds_each_chunk_once_with_real_llama_index``). + nodes_arg = _li["VectorStoreIndex"].call_args.args[0] + assert nodes_arg[0].embedding == [0.1] @patch("airflow.sdk.ObjectStoragePath") @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") @@ -195,6 +288,7 @@ def test_cloud_uri_persist_dir_uses_object_storage_path(self, mock_get_embed, mo node = _node() _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] + mock_get_embed.return_value = _byo_embedding() index = _li["VectorStoreIndex"].return_value op = LlamaIndexEmbeddingOperator(