Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions providers/common/ai/docs/operators/llamaindex_embedding.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ LlamaIndex. Designed to feed the output of
:class:`~airflow.providers.common.ai.operators.document_loader.DocumentLoaderOperator`
into vector storage (pgvector, Pinecone, Weaviate, ...).

The operator passes the embedding model **directly** to
``VectorStoreIndex(..., embed_model=...)`` -- it does not mutate
LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the same
worker process don't race on shared model state.
The operator calls the embedding model **directly** (and passes it to
``VectorStoreIndex(..., embed_model=...)`` when persisting) -- it does not
mutate LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the
same worker process don't race on shared model state.

Basic usage
-----------
Expand Down Expand Up @@ -117,3 +117,7 @@ Returns a dict with::
...
],
}

``vector`` is computed over the chunk's metadata-enriched content
(LlamaIndex's ``MetadataMode.EMBED``, the same content ``VectorStoreIndex``
embeds), while ``text`` is the raw chunk text without metadata.
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ class LlamaIndexEmbeddingOperator(BaseOperator):
``list[dict]`` with ``text`` and ``metadata`` keys; output includes the
embedding vectors ready for downstream storage ingest.

The operator passes the embedding model **directly** to
``VectorStoreIndex(..., embed_model=...)`` -- it does not mutate
LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the
same worker don't race on shared state.
The operator calls the embedding model **directly** (and passes it to
``VectorStoreIndex(..., embed_model=...)`` when persisting) -- it does
not mutate LlamaIndex's global ``Settings`` singleton, so concurrent
tasks in the same worker don't race on shared state.

:param documents: List of dicts with ``text`` and ``metadata`` keys,
typically from ``DocumentLoaderOperator`` or a ``@task``. Templated,
Expand Down Expand Up @@ -114,6 +114,7 @@ def execute(self, context: Context) -> dict[str, Any]:
try:
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import MetadataMode
except ImportError as e:
raise AirflowOptionalProviderFeatureException(e)

Expand All @@ -125,19 +126,32 @@ def execute(self, context: Context) -> dict[str, Any]:
nodes = splitter.get_nodes_from_documents(llama_docs)
self.log.info("Split %d documents into %d chunks", len(llama_docs), len(nodes))

# ``VectorStoreIndex(...)`` populates each node's ``.embedding`` as a
# side effect of building the index; capture the index so the
# variable isn't discarded.
index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=False)
# ``VectorStoreIndex(...)`` never sets ``.embedding`` on the nodes it
# is given -- ``_get_node_with_embedding()`` attaches embeddings to
# ``model_copy()`` copies, so reading ``node.embedding`` afterwards
# always returns ``None`` (apache/airflow#68416). Embed the original
# nodes explicitly.
# ``MetadataMode.EMBED`` matches what ``embed_nodes()`` inside the
# index embeds (includes metadata, respects
# ``excluded_embed_metadata_keys``).
texts = [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes]
vectors = embed_model.get_text_embedding_batch(texts, show_progress=False)
for node, vector in zip(nodes, vectors, strict=True):
node.embedding = vector

if self.persist_dir:
# The index is only needed for persistence. ``embed_nodes()``
# inside ``VectorStoreIndex`` skips nodes whose ``.embedding`` is
# already set, so this reuses the vectors above instead of
# re-calling the embedding API.
index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=False)
self._persist(index, self.persist_dir)

# ``SentenceSplitter`` always returns ``TextNode`` instances, but the
# base ``get_nodes_from_documents`` signature is typed as
# ``list[BaseNode]`` (which has no ``.text``). Cast so mypy doesn't
# flag the ``.text`` access; ``node.embedding`` is populated by
# ``VectorStoreIndex`` for every node above.
# flag the ``.text`` access; ``node.embedding`` is populated by the
# pre-embed step above for every node.
text_nodes = cast("list[TextNode]", nodes)
chunks = [
{
Expand All @@ -164,9 +178,9 @@ def _resolve_embed_model(self) -> BaseEmbedding:
* ``None`` or ``str`` -- build an ``OpenAIEmbedding`` via
``LlamaIndexHook`` (the framework's documented ``default``
behaviour).
* Has ``get_text_embedding`` / ``_get_query_embedding`` -- treat as
a pre-built ``BaseEmbedding`` (duck-typed to avoid forcing a
``llama_index`` import here).
* Has ``get_text_embedding_batch`` / ``_get_query_embedding`` --
treat as a pre-built ``BaseEmbedding`` (duck-typed to avoid
forcing a ``llama_index`` import here).
* Anything else -- ``TypeError`` with a clear pointer.
"""
if self.embed_model is None or isinstance(self.embed_model, str):
Expand All @@ -179,10 +193,11 @@ def _resolve_embed_model(self) -> BaseEmbedding:
).get_embedding_model()

# ``BaseEmbedding`` always exposes these two methods (see
# ``llama_index.core.base.embeddings.base``). Duck-typing avoids
# importing ``llama_index`` here and also catches the case where an
# ``llama_index.core.base.embeddings.base``); ``execute`` calls
# ``get_text_embedding_batch``. Duck-typing avoids importing
# ``llama_index`` here and also catches the case where an
# unresolved ``XComArg`` slips through.
if hasattr(self.embed_model, "get_text_embedding") and hasattr(
if hasattr(self.embed_model, "get_text_embedding_batch") and hasattr(
self.embed_model, "_get_query_embedding"
):
return self.embed_model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@

import pytest

pytest.importorskip("llama_index.core")

from llama_index.core import MockEmbedding
from llama_index.core.schema import MetadataMode

from airflow.providers.common.ai.operators.llamaindex_embedding import LlamaIndexEmbeddingOperator


Expand All @@ -38,17 +43,23 @@ def _li(monkeypatch):
return {"VectorStoreIndex": VectorStoreIndex, "SentenceSplitter": SentenceSplitter}


def _node(text: str = "chunk text", metadata: dict | None = None, vector=None):
def _node(text: str = "chunk text", metadata: dict | None = None):
node = MagicMock()
node.text = text
node.metadata = metadata or {}
node.embedding = vector
node.embedding = None
node.get_content.return_value = text
return node


def _byo_embedding():
"""Return a duck-typed ``BaseEmbedding`` stand-in (has the two methods the operator checks)."""
return MagicMock(name="MyBaseEmbedding", spec=["get_text_embedding", "_get_query_embedding"])
def _byo_embedding(vectors: list[list[float]] | None = None):
"""Return a duck-typed ``BaseEmbedding`` stand-in (has the methods the operator checks and calls)."""
embedding = MagicMock(
name="MyBaseEmbedding",
spec=["get_text_embedding_batch", "_get_query_embedding"],
)
embedding.get_text_embedding_batch.return_value = [[0.0]] if vectors is None else vectors
return embedding


class TestEmbeddingOperatorInit:
Expand All @@ -72,8 +83,9 @@ class TestEmbeddingOperatorExecute:
def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li):
# `embed_model` as a string -> hook builds OpenAIEmbedding.
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [
_node(text="chunk a", vector=[0.1, 0.2]),
_node(text="chunk a"),
]
mock_get_embed.return_value = _byo_embedding(vectors=[[0.1, 0.2]])

op = LlamaIndexEmbeddingOperator(
task_id="test",
Expand All @@ -93,6 +105,7 @@ def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li):
def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _li):
# ``embed_conn_id`` overrides ``llm_conn_id`` for the embedding API.
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()]
mock_hook_cls.return_value.get_embedding_model.return_value = _byo_embedding()

op = LlamaIndexEmbeddingOperator(
task_id="test",
Expand All @@ -110,21 +123,20 @@ def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _li):
)

def test_byo_embed_model_bypasses_hook(self, _li):
# `embed_model` is a non-string instance -> hook is bypassed.
byo = _byo_embedding()
# `embed_model` is a non-string instance -> hook is bypassed and the
# user's instance does the embedding.
byo = _byo_embedding(vectors=[[0.5]])
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()]

op = LlamaIndexEmbeddingOperator(
task_id="test",
documents=[{"text": "doc"}],
embed_model=byo,
)
op.execute(context=MagicMock())
result = op.execute(context=MagicMock())

# VectorStoreIndex called with the user's instance, not anything else.
_li["VectorStoreIndex"].assert_called_once()
kwargs = _li["VectorStoreIndex"].call_args.kwargs
assert kwargs["embed_model"] is byo
byo.get_text_embedding_batch.assert_called_once()
assert result["chunks"][0]["vector"] == [0.5]

def test_invalid_embed_model_raises_typeerror(self, _li):
# An object that's neither None/str nor duck-types as BaseEmbedding
Expand All @@ -140,17 +152,16 @@ def test_invalid_embed_model_raises_typeerror(self, _li):
with pytest.raises(TypeError, match="embed_model must be"):
op.execute(context=MagicMock())

@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model")
def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _li):
def test_chunks_carry_text_metadata_vector(self, _li):
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [
_node(text="x", metadata={"k": "v"}, vector=[1.0, 2.0]),
_node(text="y", metadata={"k": "v2"}, vector=[3.0, 4.0]),
_node(text="x", metadata={"k": "v"}),
_node(text="y", metadata={"k": "v2"}),
]

op = LlamaIndexEmbeddingOperator(
task_id="test",
documents=[{"text": "doc"}],
embed_model="text-embedding-3-small",
embed_model=_byo_embedding(vectors=[[1.0, 2.0], [3.0, 4.0]]),
)
result = op.execute(context=MagicMock())

Expand All @@ -159,15 +170,92 @@ def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _li):
{"text": "y", "metadata": {"k": "v2"}, "vector": [3.0, 4.0]},
]

def test_nodes_embedded_with_embed_metadata_mode(self, _li):
# llama-index's own ``embed_nodes()`` embeds
# ``node.get_content(metadata_mode=MetadataMode.EMBED)`` (includes
# metadata, respects ``excluded_embed_metadata_keys``). The pre-embed
# step must match, or the vectors silently change semantics.
node = _node(text="chunk a")
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node]
byo = _byo_embedding()

op = LlamaIndexEmbeddingOperator(
task_id="test",
documents=[{"text": "doc"}],
embed_model=byo,
)
op.execute(context=MagicMock())

node.get_content.assert_called_once_with(metadata_mode=MetadataMode.EMBED)
byo.get_text_embedding_batch.assert_called_once()
assert byo.get_text_embedding_batch.call_args.args[0] == ["chunk a"]

def test_index_only_built_when_persisting(self, _li):
# Without ``persist_dir`` the index would be built and immediately
# discarded; the vectors come from the pre-embed step.
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()]

op = LlamaIndexEmbeddingOperator(
task_id="test",
documents=[{"text": "doc"}],
embed_model=_byo_embedding(),
)
op.execute(context=MagicMock())

_li["VectorStoreIndex"].assert_not_called()

def test_vectors_populated_with_real_llama_index(self):
# Regression test for #68416: ``VectorStoreIndex`` attaches embeddings
# to node *copies* (``model_copy()`` in ``_get_node_with_embedding``),
# so reading ``node.embedding`` after index construction returns
# ``None``. Run the real llama-index code path with its offline
# ``MockEmbedding`` -- no mocks on the operator's internals.
op = LlamaIndexEmbeddingOperator(
task_id="test",
documents=[{"text": "hello world", "metadata": {"src": "a"}}],
embed_model=MockEmbedding(embed_dim=8),
)
result = op.execute(context=MagicMock())

assert result["chunk_count"] >= 1
assert all(chunk["vector"] is not None for chunk in result["chunks"])
assert all(len(chunk["vector"]) == 8 for chunk in result["chunks"])


class TestEmbeddingOperatorPersist:
def test_persist_path_embeds_each_chunk_once_with_real_llama_index(self, tmp_path):
# ``embed_nodes()`` inside ``VectorStoreIndex`` must skip the
# pre-embedded nodes -- if it re-embeds, every chunk pays the
# embedding API twice. Runs the real llama-index persist path.
embedded_texts: list[str] = []

class CountingMockEmbedding(MockEmbedding):
def _get_text_embedding(self, text: str) -> list[float]:
embedded_texts.append(text)
return super()._get_text_embedding(text)

persist_dir = tmp_path / "idx"
op = LlamaIndexEmbeddingOperator(
task_id="test",
documents=[{"text": "hello world", "metadata": {"src": "a"}}],
embed_model=CountingMockEmbedding(embed_dim=8),
persist_dir=str(persist_dir),
)
result = op.execute(context=MagicMock())

assert result["chunk_count"] == 1
assert len(embedded_texts) == result["chunk_count"]
assert all(chunk["vector"] is not None for chunk in result["chunks"])
assert any(persist_dir.iterdir())

@patch("os.makedirs")
@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model")
def test_local_persist_dir_calls_makedirs_and_storage_persist(
self, mock_get_embed, mock_makedirs, _li, tmp_path
):
node = _node()
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node]
mock_get_embed.return_value = _byo_embedding(vectors=[[0.1]])
index = _li["VectorStoreIndex"].return_value

op = LlamaIndexEmbeddingOperator(
Expand All @@ -180,6 +268,11 @@ def test_local_persist_dir_calls_makedirs_and_storage_persist(

mock_makedirs.assert_called_once_with(str(tmp_path / "idx"), exist_ok=True)
index.storage_context.persist.assert_called_once_with(persist_dir=str(tmp_path / "idx"))
# Nodes are already embedded when handed to the index (the
# no-double-embed behavior itself is pinned by
# ``test_persist_path_embeds_each_chunk_once_with_real_llama_index``).
nodes_arg = _li["VectorStoreIndex"].call_args.args[0]
assert nodes_arg[0].embedding == [0.1]

@patch("airflow.sdk.ObjectStoragePath")
@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model")
Expand All @@ -195,6 +288,7 @@ def test_cloud_uri_persist_dir_uses_object_storage_path(self, mock_get_embed, mo

node = _node()
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node]
mock_get_embed.return_value = _byo_embedding()
index = _li["VectorStoreIndex"].return_value

op = LlamaIndexEmbeddingOperator(
Expand Down
Loading