From fee20dc273aba3f342679db0d17059ff320b07e2 Mon Sep 17 00:00:00 2001 From: Vikram Koka Date: Mon, 18 May 2026 16:17:37 +0100 Subject: [PATCH 1/7] Add LlamaIndex operators to common.ai provider MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Adds LlamaIndexHook to bridge Airflow connections to LlamaIndex's Settings singleton. Reuses the pydanticai connection type, supports separate embedding and LLM connections. - Adds EmbeddingOperator to chunk documents and produce embedding vectors via LlamaIndex's SentenceSplitter. Input is list[dict(text, metadata)] (same shape as DocumentLoaderOperator output), output includes chunks with vectors ready for downstream vector store ingest operators (pgvector, Pinecone, Weaviate). - Adds RetrievalOperator to load a persisted LlamaIndex index and perform similarity search. Output is scored chunks ready for synthesis via LLMOperator. Design notes All LlamaIndex imports are lazy (inside execute() / method bodies), so modules parse without llama-index installed. The hook currently hardcodes OpenAI embedding/LLM providers; a follow-up PR will refactor to use BaseAIHook for provider-agnostic model resolution when it lands. 
What's included ┌─────────────────────────────────────────┬──────────────────────────────────────────┐ │ File │ Purpose │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ hooks/llamaindex.py │ Hook (~110 lines) │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ operators/llamaindex_embedding.py │ EmbeddingOperator (~110 lines) │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ operators/llamaindex_retrieval.py │ RetrievalOperator (~90 lines) │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ tests/.../test_llamaindex.py │ 12 hook tests │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ tests/.../test_llamaindex_embedding.py │ 10 operator tests │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ tests/.../test_llamaindex_retrieval.py │ 8 operator tests │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ docs/hooks/llamaindex.rst │ Hook docs │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ docs/operators/llamaindex_embedding.rst │ EmbeddingOperator docs │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ docs/operators/llamaindex_retrieval.rst │ RetrievalOperator docs │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ provider.yaml │ Integration, hook, operator registration │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ docs/index.rst │ LlamaIndex Hook in Guides toctree │ ├─────────────────────────────────────────┼──────────────────────────────────────────┤ │ docs/operators/index.rst │ Chooser table rows │ └─────────────────────────────────────────┴──────────────────────────────────────────┘ Test plan - uv run --project 
providers/common/ai pytest providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py -xvs (12 tests) - uv run --project providers/common/ai pytest providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py -xvs (18 tests) - Hook: init defaults, separate embed_conn_id, connection kwargs extraction, embedding model, LLM, Settings configuration - EmbeddingOperator: output shape, chunking, index persistence, vector inclusion/omission, splitter params - RetrievalOperator: output shape, chunk keys, top_k forwarding, multiple results, storage context --- Was generative AI tooling used to co-author this PR? - Yes — Claude Code (Opus 4.6) Generated-by: Claude Code (Opus 4.6) following https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#gen-ai-assisted-contributions --- providers/common/ai/docs/hooks/llamaindex.rst | 87 ++++++++ providers/common/ai/docs/operators/index.rst | 6 + .../docs/operators/llamaindex_embedding.rst | 135 ++++++++++++ .../docs/operators/llamaindex_retrieval.rst | 108 ++++++++++ providers/common/ai/provider.yaml | 11 + .../providers/common/ai/hooks/llamaindex.py | 110 ++++++++++ .../ai/operators/llamaindex_embedding.py | 109 ++++++++++ .../ai/operators/llamaindex_retrieval.py | 96 +++++++++ .../unit/common/ai/hooks/test_llamaindex.py | 196 +++++++++++++++++ .../ai/operators/test_llamaindex_embedding.py | 202 ++++++++++++++++++ .../ai/operators/test_llamaindex_retrieval.py | 199 +++++++++++++++++ 11 files changed, 1259 insertions(+) create mode 100644 providers/common/ai/docs/hooks/llamaindex.rst create mode 100644 providers/common/ai/docs/operators/llamaindex_embedding.rst create mode 100644 providers/common/ai/docs/operators/llamaindex_retrieval.rst create mode 100644 providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py create mode 100644 
providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py create mode 100644 providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py create mode 100644 providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py create mode 100644 providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py create mode 100644 providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py diff --git a/providers/common/ai/docs/hooks/llamaindex.rst b/providers/common/ai/docs/hooks/llamaindex.rst new file mode 100644 index 0000000000000..ff942a2f65bd4 --- /dev/null +++ b/providers/common/ai/docs/hooks/llamaindex.rst @@ -0,0 +1,87 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/hook:llamaindex: + +``LlamaIndexHook`` +================== + +Use :class:`~airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook` to +bridge Airflow connections to LlamaIndex's ``Settings`` singleton. The hook +reuses the ``pydanticai`` connection type, so users configure a single +connection for both pydantic-ai operators and LlamaIndex operators. + +.. 
seealso:: + :ref:`Connection configuration ` + +What It Does +------------ + +The hook resolves API keys and base URLs from Airflow connections and uses +them to configure LlamaIndex's embedding models, LLMs, and global settings. +This eliminates manual ``Settings.embed_model = ...`` boilerplate in every +task that uses LlamaIndex. + +Configuration +------------- + +``LlamaIndexHook`` reuses the ``pydanticai`` connection type. Set the API key +in the **Password** field and optionally a custom endpoint in the **Host** +field. + +Separate Embedding and LLM Connections +-------------------------------------- + +RAG pipelines often use different providers for embeddings and chat. The hook +supports an optional ``embed_conn_id`` parameter that defaults to the main +``llm_conn_id``: + +.. code-block:: python + + from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook + + hook = LlamaIndexHook( + llm_conn_id="openai_default", + embed_conn_id="embedding_provider", + embed_model="text-embedding-3-large", + llm_model="gpt-4o", + ) + hook.configure_settings() + +Parameters +---------- + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``llm_conn_id`` + - ``pydanticai_default`` + - Airflow connection ID for the LLM/embedding provider. + * - ``embed_conn_id`` + - Same as ``llm_conn_id`` + - Separate connection for embeddings (optional). + * - ``embed_model`` + - ``text-embedding-3-small`` + - Embedding model name. + * - ``llm_model`` + - ``None`` + - LLM model name. Required for ``get_llm()`` and ``configure_settings()`` + LLM setup. diff --git a/providers/common/ai/docs/operators/index.rst b/providers/common/ai/docs/operators/index.rst index dec108990eee2..bc0b36bd9e8fa 100644 --- a/providers/common/ai/docs/operators/index.rst +++ b/providers/common/ai/docs/operators/index.rst @@ -49,6 +49,12 @@ to pick the one that fits your use case: * - Parse files (PDF, DOCX, CSV, etc.) 
into document dicts for embedding - :class:`~airflow.providers.common.ai.operators.document_loader.DocumentLoaderOperator` - *(no decorator)* + * - Chunk documents and produce embedding vectors + - :class:`~airflow.providers.common.ai.operators.llamaindex_embedding.EmbeddingOperator` + - *(no decorator)* + * - Retrieve relevant chunks from a vector index + - :class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.RetrievalOperator` + - *(no decorator)* **LLMOperator / @task.llm** — stateless, single-turn calls. Use this for classification, summarization, extraction, or any prompt that produces one response. Supports structured output diff --git a/providers/common/ai/docs/operators/llamaindex_embedding.rst b/providers/common/ai/docs/operators/llamaindex_embedding.rst new file mode 100644 index 0000000000000..2a32d056dc24a --- /dev/null +++ b/providers/common/ai/docs/operators/llamaindex_embedding.rst @@ -0,0 +1,135 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/operator:llamaindex_embedding: + +``EmbeddingOperator`` +===================== + +Use :class:`~airflow.providers.common.ai.operators.llamaindex_embedding.EmbeddingOperator` +to chunk documents and produce embedding vectors using LlamaIndex. 
This operator +bridges document loading (Airflow provider hooks returning text) and vector +storage (pgvector, Pinecone, Weaviate ingest operators). + +Basic Usage +----------- + +Provide a list of documents with ``text`` and ``metadata`` keys. The operator +chunks the documents, embeds them, and returns the results: + +.. code-block:: python + + from airflow.providers.common.ai.operators.llamaindex_embedding import EmbeddingOperator + + embed = EmbeddingOperator( + task_id="embed_docs", + documents=[ + {"text": "Airflow is a workflow orchestration platform.", "metadata": {"source": "docs"}}, + {"text": "LlamaIndex is a data framework for LLM applications.", "metadata": {"source": "docs"}}, + ], + llm_conn_id="openai_default", + ) + +Connection Configuration +------------------------ + +The operator uses :class:`~airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook` +internally. Configure your embedding API credentials via the ``pydanticai`` +connection type. + +.. seealso:: + :ref:`Connection configuration ` + +Chunking Parameters +------------------- + +Control how documents are split into chunks before embedding: + +.. code-block:: python + + embed = EmbeddingOperator( + task_id="embed_docs", + documents=documents, + llm_conn_id="openai_default", + chunk_size=256, + chunk_overlap=25, + ) + +Index Persistence +----------------- + +Set ``persist_dir`` to save the LlamaIndex index for later retrieval via +:class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.RetrievalOperator`: + +.. code-block:: python + + embed = EmbeddingOperator( + task_id="embed_docs", + documents=documents, + llm_conn_id="openai_default", + persist_dir="/opt/airflow/data/my_index", + ) + +Output Shape +------------ + +The operator returns a dict: + +.. code-block:: python + + { + "document_count": 2, + "chunk_count": 5, + "persist_dir": "/opt/airflow/data/my_index", + "chunks": [ + {"text": "chunk text", "metadata": {"source": "docs"}, "vector": [0.1, ...]}, + ... 
+ ], + } + +Each chunk includes ``text``, ``metadata``, and optionally ``vector`` (the +embedding array). The ``chunks`` list is ready for downstream consumption by +vector store ingest operators. + +Parameters +---------- + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``documents`` + - (required) + - List of dicts with ``text`` and ``metadata`` keys. + * - ``llm_conn_id`` + - ``pydanticai_default`` + - Airflow connection ID for the embedding API. + * - ``embed_model`` + - ``text-embedding-3-small`` + - Embedding model name. + * - ``chunk_size`` + - ``512`` + - Chunk size for the sentence splitter. + * - ``chunk_overlap`` + - ``50`` + - Overlap between chunks. + * - ``persist_dir`` + - ``None`` + - Directory path to persist the index. diff --git a/providers/common/ai/docs/operators/llamaindex_retrieval.rst b/providers/common/ai/docs/operators/llamaindex_retrieval.rst new file mode 100644 index 0000000000000..ff238744cf196 --- /dev/null +++ b/providers/common/ai/docs/operators/llamaindex_retrieval.rst @@ -0,0 +1,108 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. 
_howto/operator:llamaindex_retrieval: + +``RetrievalOperator`` +===================== + +Use :class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.RetrievalOperator` +to retrieve relevant document chunks from a persisted LlamaIndex index. The +operator performs similarity search against the provided query and returns +results ready for downstream synthesis via ``LLMOperator``. + +Basic Usage +----------- + +Provide a query string and the path to a previously persisted index: + +.. code-block:: python + + from airflow.providers.common.ai.operators.llamaindex_retrieval import RetrievalOperator + + retrieve = RetrievalOperator( + task_id="retrieve_context", + query="What are Airflow's key features?", + index_persist_dir="/opt/airflow/data/my_index", + llm_conn_id="openai_default", + ) + +Query Templating +---------------- + +The ``query`` field supports Jinja templating, so it can be set dynamically +from upstream task output or Dag run configuration: + +.. code-block:: python + + retrieve = RetrievalOperator( + task_id="retrieve_context", + query="{{ dag_run.conf['question'] }}", + index_persist_dir="/opt/airflow/data/my_index", + llm_conn_id="openai_default", + top_k=10, + ) + +Output Shape +------------ + +The operator returns a dict: + +.. code-block:: python + + { + "question": "What are Airflow's key features?", + "chunks": [ + { + "text": "Airflow provides ...", + "score": 0.95, + "metadata": {"source": "overview.txt"}, + "source": "node-abc123", + }, + ... + ], + } + +Each chunk includes ``text``, ``score`` (similarity), ``metadata``, and +``source`` (the LlamaIndex node ID). This output pairs naturally with +``LLMOperator`` for RAG synthesis using Jinja templates. + +Parameters +---------- + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``query`` + - (required) + - The query string to search for. Supports Jinja templating. 
+ * - ``index_persist_dir`` + - (required) + - Path to the persisted LlamaIndex index directory. + * - ``llm_conn_id`` + - ``pydanticai_default`` + - Airflow connection ID for the embedding API. + * - ``embed_model`` + - ``text-embedding-3-small`` + - Embedding model name. + * - ``top_k`` + - ``5`` + - Number of top results to retrieve. diff --git a/providers/common/ai/provider.yaml b/providers/common/ai/provider.yaml index 92d826ffb0a89..57dbdf86e1122 100644 --- a/providers/common/ai/provider.yaml +++ b/providers/common/ai/provider.yaml @@ -53,6 +53,12 @@ integrations: - integration-name: LangChain external-doc-url: https://python.langchain.com/ tags: [ai] + - integration-name: LlamaIndex + external-doc-url: https://docs.llamaindex.ai/ + how-to-guide: + - /docs/apache-airflow-providers-common-ai/operators/llamaindex_embedding.rst + - /docs/apache-airflow-providers-common-ai/operators/llamaindex_retrieval.rst + tags: [ai] hooks: - integration-name: Pydantic AI @@ -64,6 +70,9 @@ hooks: - integration-name: LangChain python-modules: - airflow.providers.common.ai.hooks.langchain + - integration-name: LlamaIndex + python-modules: + - airflow.providers.common.ai.hooks.llamaindex plugins: - name: hitl_review @@ -365,6 +374,8 @@ operators: - airflow.providers.common.ai.operators.llm_sql - airflow.providers.common.ai.operators.llm_schema_compare - airflow.providers.common.ai.operators.document_loader + - airflow.providers.common.ai.operators.llamaindex_embedding + - airflow.providers.common.ai.operators.llamaindex_retrieval task-decorators: - class-name: airflow.providers.common.ai.decorators.agent.agent_task diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py b/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py new file mode 100644 index 0000000000000..7c3272c4cdb68 --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation 
(ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for LlamaIndex integration with Airflow connections.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.compat.sdk import BaseHook + +if TYPE_CHECKING: + from llama_index.core.base.embeddings.base import BaseEmbedding + from llama_index.core.llms.llm import LLM + + +class LlamaIndexHook(BaseHook): + """ + Bridge Airflow connections to LlamaIndex's Settings singleton. + + Reuses the ``pydanticai`` connection type so users configure a single + connection for both pydantic-ai operators and LlamaIndex operators. + + :param llm_conn_id: Airflow connection ID for the LLM/embedding provider. + :param embed_conn_id: Separate connection for embeddings. Defaults to + ``llm_conn_id`` when not provided. + :param embed_model: Embedding model name (e.g. ``text-embedding-3-small``). + :param llm_model: LLM model name (e.g. ``gpt-4o``). Only needed when + configuring ``Settings.llm``. 
+ """ + + conn_name_attr = "llm_conn_id" + default_conn_name = "pydanticai_default" + conn_type = "pydanticai" + hook_name = "LlamaIndex" + + def __init__( + self, + llm_conn_id: str = "pydanticai_default", + embed_conn_id: str | None = None, + embed_model: str = "text-embedding-3-small", + llm_model: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.llm_conn_id = llm_conn_id + self.embed_conn_id = embed_conn_id or llm_conn_id + self.embed_model = embed_model + self.llm_model = llm_model + + def _resolve_connection_kwargs(self, conn_id: str) -> dict[str, Any]: + """Extract API key and base URL from an Airflow connection.""" + conn = self.get_connection(conn_id) + kwargs: dict[str, Any] = {} + if conn.password: + kwargs["api_key"] = conn.password + if conn.host: + kwargs["api_base"] = conn.host + return kwargs + + def get_embedding_model(self) -> BaseEmbedding: + """ + Return a LlamaIndex embedding model configured from the Airflow connection. + + Uses ``embed_conn_id`` (falls back to ``llm_conn_id``) for credentials. + """ + from llama_index.embeddings.openai import OpenAIEmbedding + + conn_kwargs = self._resolve_connection_kwargs(self.embed_conn_id) + return OpenAIEmbedding(model=self.embed_model, **conn_kwargs) + + def get_llm(self) -> LLM: + """ + Return a LlamaIndex LLM configured from the Airflow connection. + + Requires ``llm_model`` to be set on the hook. + """ + if not self.llm_model: + raise ValueError("llm_model must be set to use get_llm()") + + from llama_index.llms.openai import OpenAI + + conn_kwargs = self._resolve_connection_kwargs(self.llm_conn_id) + return OpenAI(model=self.llm_model, **conn_kwargs) + + def configure_settings(self) -> None: + """ + Configure LlamaIndex's global Settings with models from Airflow connections. + + Sets ``Settings.embed_model`` always, and ``Settings.llm`` when + ``llm_model`` is provided. 
+ """ + from llama_index.core import Settings + + Settings.embed_model = self.get_embedding_model() + if self.llm_model: + Settings.llm = self.get_llm() diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py new file mode 100644 index 0000000000000..acbbc46dabbc4 --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Operator for document chunking and embedding via LlamaIndex.""" + +from __future__ import annotations + +import os +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.compat.sdk import BaseOperator + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class EmbeddingOperator(BaseOperator): + """ + Chunk documents and produce embedding vectors using LlamaIndex. + + Bridges document loading (Airflow provider hooks returning text) and + vector storage (pgvector, Pinecone, Weaviate ingest operators). 
Input + is ``list[dict]`` with ``text`` and ``metadata`` keys; output includes + the embedding vectors ready for downstream storage. + + :param documents: List of dicts with ``text`` and ``metadata`` keys, + typically from ``DocumentLoaderOperator`` or a ``@task``. + :param llm_conn_id: Airflow connection ID for the embedding API. + :param embed_model: Embedding model name (default: ``text-embedding-3-small``). + :param chunk_size: Chunk size for the sentence splitter (default: 512). + :param chunk_overlap: Overlap between chunks (default: 50). + :param persist_dir: Optional directory path to persist the LlamaIndex + index for later retrieval. + """ + + template_fields: Sequence[str] = ("documents", "llm_conn_id", "persist_dir") + + def __init__( + self, + *, + documents: list[dict[str, Any]], + llm_conn_id: str = "pydanticai_default", + embed_model: str = "text-embedding-3-small", + chunk_size: int = 512, + chunk_overlap: int = 50, + persist_dir: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.documents = documents + self.llm_conn_id = llm_conn_id + self.embed_model = embed_model + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.persist_dir = persist_dir + + def execute(self, context: Context) -> dict[str, Any]: + from llama_index.core import Document, StorageContext, VectorStoreIndex + from llama_index.core.node_parser import SentenceSplitter + + from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook + + hook = LlamaIndexHook(llm_conn_id=self.llm_conn_id, embed_model=self.embed_model) + hook.configure_settings() + + llama_docs = [Document(text=doc["text"], metadata=doc.get("metadata", {})) for doc in self.documents] + + splitter = SentenceSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) + nodes = splitter.get_nodes_from_documents(llama_docs) + self.log.info("Split %d documents into %d chunks", len(llama_docs), len(nodes)) + + storage_context = 
StorageContext.from_defaults() + VectorStoreIndex(nodes, storage_context=storage_context, show_progress=False) + + if self.persist_dir: + os.makedirs(self.persist_dir, exist_ok=True) + storage_context.persist(persist_dir=self.persist_dir) + self.log.info("Index persisted to %s", self.persist_dir) + + chunks = [] + for node in nodes: + chunk: dict[str, Any] = { + "text": node.text, + "metadata": node.metadata, + } + if node.embedding: + chunk["vector"] = node.embedding + chunks.append(chunk) + + return { + "document_count": len(llama_docs), + "chunk_count": len(nodes), + "persist_dir": self.persist_dir, + "chunks": chunks, + } diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py new file mode 100644 index 0000000000000..6089f7a4c628d --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Operator for semantic retrieval via a persisted LlamaIndex index.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.compat.sdk import BaseOperator + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class RetrievalOperator(BaseOperator): + """ + Retrieve relevant document chunks from a persisted LlamaIndex index. + + Loads a previously persisted vector store index and performs similarity + search against the provided query. The output is a list of chunks with + text, score, metadata, and source information ready for downstream + synthesis via ``LLMOperator``. + + :param query: The query string to search for. Supports Jinja templating. + :param index_persist_dir: Path to the persisted LlamaIndex index directory. + :param llm_conn_id: Airflow connection ID for the embedding API + (needed to embed the query vector). + :param embed_model: Embedding model name (default: ``text-embedding-3-small``). + :param top_k: Number of top results to retrieve (default: 5). 
+ """ + + template_fields: Sequence[str] = ("query", "index_persist_dir", "llm_conn_id") + + def __init__( + self, + *, + query: str, + index_persist_dir: str, + llm_conn_id: str = "pydanticai_default", + embed_model: str = "text-embedding-3-small", + top_k: int = 5, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.query = query + self.index_persist_dir = index_persist_dir + self.llm_conn_id = llm_conn_id + self.embed_model = embed_model + self.top_k = top_k + + def execute(self, context: Context) -> dict[str, Any]: + from llama_index.core import StorageContext, load_index_from_storage + + from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook + + hook = LlamaIndexHook(llm_conn_id=self.llm_conn_id, embed_model=self.embed_model) + hook.configure_settings() + + storage_context = StorageContext.from_defaults(persist_dir=self.index_persist_dir) + index = load_index_from_storage(storage_context) + + retriever = index.as_retriever(similarity_top_k=self.top_k) + results = retriever.retrieve(self.query) + self.log.info("Retrieved %d chunks for query: %s", len(results), self.query[:100]) + + chunks = [] + for node_with_score in results: + node = node_with_score.node + chunks.append( + { + "text": node.get_content(), + "score": node_with_score.score, + "metadata": node.metadata, + "source": node.node_id, + } + ) + + return { + "question": self.query, + "chunks": chunks, + } diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py b/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py new file mode 100644 index 0000000000000..3b119e3e5b439 --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py @@ -0,0 +1,196 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook + + +class TestLlamaIndexHookInit: + def test_default_params(self): + hook = LlamaIndexHook() + assert hook.llm_conn_id == "pydanticai_default" + assert hook.embed_conn_id == "pydanticai_default" + assert hook.embed_model == "text-embedding-3-small" + assert hook.llm_model is None + + def test_separate_embed_conn_id(self): + hook = LlamaIndexHook(llm_conn_id="llm_conn", embed_conn_id="embed_conn") + assert hook.llm_conn_id == "llm_conn" + assert hook.embed_conn_id == "embed_conn" + + def test_embed_conn_defaults_to_llm_conn(self): + hook = LlamaIndexHook(llm_conn_id="my_conn") + assert hook.embed_conn_id == "my_conn" + + +class TestResolveConnectionKwargs: + @patch.object(LlamaIndexHook, "get_connection") + def test_extracts_password_as_api_key(self, mock_get_conn): + mock_conn = MagicMock() + mock_conn.password = "sk-test-key" + mock_conn.host = "" + mock_get_conn.return_value = mock_conn + + hook = LlamaIndexHook() + result = hook._resolve_connection_kwargs("test_conn") + + assert result == {"api_key": "sk-test-key"} + + @patch.object(LlamaIndexHook, "get_connection") + def test_extracts_host_as_api_base(self, mock_get_conn): + mock_conn = MagicMock() + mock_conn.password = "" + mock_conn.host = "https://custom.api.com" + 
mock_get_conn.return_value = mock_conn + + hook = LlamaIndexHook() + result = hook._resolve_connection_kwargs("test_conn") + + assert result == {"api_base": "https://custom.api.com"} + + @patch.object(LlamaIndexHook, "get_connection") + def test_both_password_and_host(self, mock_get_conn): + mock_conn = MagicMock() + mock_conn.password = "sk-key" + mock_conn.host = "https://api.example.com" + mock_get_conn.return_value = mock_conn + + hook = LlamaIndexHook() + result = hook._resolve_connection_kwargs("test_conn") + + assert result == {"api_key": "sk-key", "api_base": "https://api.example.com"} + + @patch.object(LlamaIndexHook, "get_connection") + def test_empty_fields_return_empty_dict(self, mock_get_conn): + mock_conn = MagicMock() + mock_conn.password = "" + mock_conn.host = "" + mock_get_conn.return_value = mock_conn + + hook = LlamaIndexHook() + result = hook._resolve_connection_kwargs("test_conn") + + assert result == {} + + +def _make_mock_openai_embedding_module(): + mock_module = MagicMock() + mock_cls = MagicMock() + mock_module.OpenAIEmbedding = mock_cls + return mock_module, mock_cls + + +def _make_mock_openai_llm_module(): + mock_module = MagicMock() + mock_cls = MagicMock() + mock_module.OpenAI = mock_cls + return mock_module, mock_cls + + +class TestGetEmbeddingModel: + @patch.object(LlamaIndexHook, "get_connection") + def test_returns_openai_embedding(self, mock_get_conn): + mock_conn = MagicMock() + mock_conn.password = "sk-test" + mock_conn.host = "" + mock_get_conn.return_value = mock_conn + + mock_embed_module, mock_embed_cls = _make_mock_openai_embedding_module() + + hook = LlamaIndexHook(embed_model="text-embedding-3-large") + with patch.dict("sys.modules", {"llama_index.embeddings.openai": mock_embed_module}): + result = hook.get_embedding_model() + + mock_embed_cls.assert_called_once_with(model="text-embedding-3-large", api_key="sk-test") + assert result == mock_embed_cls.return_value + + +class TestGetLLM: + def 
test_raises_without_llm_model(self): + hook = LlamaIndexHook() + with pytest.raises(ValueError, match="llm_model must be set"): + hook.get_llm() + + @patch.object(LlamaIndexHook, "get_connection") + def test_returns_openai_llm(self, mock_get_conn): + mock_conn = MagicMock() + mock_conn.password = "sk-test" + mock_conn.host = "" + mock_get_conn.return_value = mock_conn + + mock_llm_module, mock_llm_cls = _make_mock_openai_llm_module() + + hook = LlamaIndexHook(llm_model="gpt-4o") + with patch.dict("sys.modules", {"llama_index.llms.openai": mock_llm_module}): + result = hook.get_llm() + + mock_llm_cls.assert_called_once_with(model="gpt-4o", api_key="sk-test") + assert result == mock_llm_cls.return_value + + +class TestConfigureSettings: + @patch.object(LlamaIndexHook, "get_connection") + def test_sets_embed_model(self, mock_get_conn): + mock_conn = MagicMock() + mock_conn.password = "sk-test" + mock_conn.host = "" + mock_get_conn.return_value = mock_conn + + mock_embed_module, mock_embed_cls = _make_mock_openai_embedding_module() + mock_settings_module = MagicMock() + + hook = LlamaIndexHook() + with patch.dict( + "sys.modules", + { + "llama_index.embeddings.openai": mock_embed_module, + "llama_index": MagicMock(), + "llama_index.core": mock_settings_module, + }, + ): + hook.configure_settings() + + assert mock_settings_module.Settings.embed_model == mock_embed_cls.return_value + + @patch.object(LlamaIndexHook, "get_connection") + def test_sets_llm_when_model_provided(self, mock_get_conn): + mock_conn = MagicMock() + mock_conn.password = "sk-test" + mock_conn.host = "" + mock_get_conn.return_value = mock_conn + + mock_embed_module, _ = _make_mock_openai_embedding_module() + mock_llm_module, mock_llm_cls = _make_mock_openai_llm_module() + mock_settings_module = MagicMock() + + hook = LlamaIndexHook(llm_model="gpt-4o") + with patch.dict( + "sys.modules", + { + "llama_index.embeddings.openai": mock_embed_module, + "llama_index.llms.openai": mock_llm_module, + 
"llama_index": MagicMock(), + "llama_index.core": mock_settings_module, + }, + ): + hook.configure_settings() + + assert mock_settings_module.Settings.llm == mock_llm_cls.return_value diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py new file mode 100644 index 0000000000000..ee3e8c51a562d --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py @@ -0,0 +1,202 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from airflow.providers.common.ai.operators.llamaindex_embedding import EmbeddingOperator + + +def _make_mock_node(text="chunk text", metadata=None, embedding=None): + node = MagicMock() + node.text = text + node.metadata = metadata or {} + node.embedding = embedding + return node + + +def _make_mock_llamaindex_modules(nodes=None): + """Create mock llama_index modules for sys.modules injection.""" + if nodes is None: + nodes = [_make_mock_node()] + + mock_core = MagicMock() + mock_core.Document = MagicMock(side_effect=lambda text, metadata: MagicMock(text=text, metadata=metadata)) + mock_core.StorageContext.from_defaults.return_value = MagicMock() + mock_core.VectorStoreIndex = MagicMock() + + mock_node_parser = MagicMock() + mock_splitter = MagicMock() + mock_splitter.get_nodes_from_documents.return_value = nodes + mock_node_parser.SentenceSplitter.return_value = mock_splitter + + return ( + { + "llama_index": MagicMock(), + "llama_index.core": mock_core, + "llama_index.core.node_parser": mock_node_parser, + "llama_index.embeddings": MagicMock(), + "llama_index.embeddings.openai": MagicMock(), + }, + mock_core, + mock_splitter, + ) + + +class TestEmbeddingOperator: + def test_template_fields(self): + expected = {"documents", "llm_conn_id", "persist_dir"} + assert set(EmbeddingOperator.template_fields) == expected + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_execute_returns_expected_shape(self, mock_hook_cls): + docs = [{"text": "Hello world", "metadata": {"source": "test"}}] + nodes = [_make_mock_node(text="Hello world", metadata={"source": "test"})] + mock_modules, mock_core, mock_splitter = _make_mock_llamaindex_modules(nodes) + + op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn") + + with patch.dict("sys.modules", mock_modules): + result = op.execute(context=MagicMock()) + + assert "document_count" 
in result + assert "chunk_count" in result + assert "persist_dir" in result + assert "chunks" in result + assert result["document_count"] == 1 + assert result["chunk_count"] == 1 + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_chunking_node_count(self, mock_hook_cls): + docs = [{"text": "A long document " * 100, "metadata": {}}] + nodes = [_make_mock_node(text=f"chunk {i}") for i in range(5)] + mock_modules, mock_core, mock_splitter = _make_mock_llamaindex_modules(nodes) + + op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn") + + with patch.dict("sys.modules", mock_modules): + result = op.execute(context=MagicMock()) + + assert result["chunk_count"] == 5 + assert len(result["chunks"]) == 5 + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_persist_dir_creates_and_persists(self, mock_hook_cls, tmp_path): + docs = [{"text": "test", "metadata": {}}] + persist_dir = str(tmp_path / "index_storage") + mock_modules, mock_core, _ = _make_mock_llamaindex_modules() + mock_storage_ctx = mock_core.StorageContext.from_defaults.return_value + + op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn", persist_dir=persist_dir) + + with patch.dict("sys.modules", mock_modules): + op.execute(context=MagicMock()) + + mock_storage_ctx.persist.assert_called_once_with(persist_dir=persist_dir) + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_no_persist_when_none(self, mock_hook_cls): + docs = [{"text": "test", "metadata": {}}] + mock_modules, mock_core, _ = _make_mock_llamaindex_modules() + mock_storage_ctx = mock_core.StorageContext.from_defaults.return_value + + op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn") + + with patch.dict("sys.modules", mock_modules): + op.execute(context=MagicMock()) + + mock_storage_ctx.persist.assert_not_called() + + 
@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_chunks_have_text_and_metadata(self, mock_hook_cls): + docs = [{"text": "test", "metadata": {"src": "a"}}] + nodes = [_make_mock_node(text="chunk1", metadata={"src": "a"})] + mock_modules, _, _ = _make_mock_llamaindex_modules(nodes) + + op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn") + + with patch.dict("sys.modules", mock_modules): + result = op.execute(context=MagicMock()) + + chunk = result["chunks"][0] + assert "text" in chunk + assert "metadata" in chunk + assert chunk["text"] == "chunk1" + assert chunk["metadata"] == {"src": "a"} + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_chunks_include_vector_when_present(self, mock_hook_cls): + docs = [{"text": "test", "metadata": {}}] + nodes = [_make_mock_node(text="chunk1", embedding=[0.1, 0.2, 0.3])] + mock_modules, _, _ = _make_mock_llamaindex_modules(nodes) + + op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn") + + with patch.dict("sys.modules", mock_modules): + result = op.execute(context=MagicMock()) + + assert result["chunks"][0]["vector"] == [0.1, 0.2, 0.3] + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_chunks_omit_vector_when_not_present(self, mock_hook_cls): + docs = [{"text": "test", "metadata": {}}] + nodes = [_make_mock_node(text="chunk1", embedding=None)] + mock_modules, _, _ = _make_mock_llamaindex_modules(nodes) + + op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn") + + with patch.dict("sys.modules", mock_modules): + result = op.execute(context=MagicMock()) + + assert "vector" not in result["chunks"][0] + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_hook_configured_with_params(self, mock_hook_cls): + docs = [{"text": "test", "metadata": {}}] + mock_modules, _, _ = 
_make_mock_llamaindex_modules() + + op = EmbeddingOperator( + task_id="test", + documents=docs, + llm_conn_id="custom_conn", + embed_model="text-embedding-ada-002", + ) + + with patch.dict("sys.modules", mock_modules): + op.execute(context=MagicMock()) + + mock_hook_cls.assert_called_once_with(llm_conn_id="custom_conn", embed_model="text-embedding-ada-002") + mock_hook_cls.return_value.configure_settings.assert_called_once() + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_splitter_params_forwarded(self, mock_hook_cls): + docs = [{"text": "test", "metadata": {}}] + mock_modules, _, _ = _make_mock_llamaindex_modules() + mock_node_parser = mock_modules["llama_index.core.node_parser"] + + op = EmbeddingOperator( + task_id="test", + documents=docs, + llm_conn_id="my_conn", + chunk_size=256, + chunk_overlap=25, + ) + + with patch.dict("sys.modules", mock_modules): + op.execute(context=MagicMock()) + + mock_node_parser.SentenceSplitter.assert_called_once_with(chunk_size=256, chunk_overlap=25) diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py new file mode 100644 index 0000000000000..0c85e86c214dc --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from airflow.providers.common.ai.operators.llamaindex_retrieval import RetrievalOperator + + +def _make_mock_node_with_score(text="chunk text", score=0.9, metadata=None, node_id="node-1"): + node = MagicMock() + node.get_content.return_value = text + node.metadata = metadata or {} + node.node_id = node_id + + node_with_score = MagicMock() + node_with_score.node = node + node_with_score.score = score + return node_with_score + + +def _make_mock_llamaindex_modules(retrieval_results=None): + """Create mock llama_index modules for sys.modules injection.""" + if retrieval_results is None: + retrieval_results = [_make_mock_node_with_score()] + + mock_core = MagicMock() + mock_index = MagicMock() + mock_retriever = MagicMock() + mock_retriever.retrieve.return_value = retrieval_results + mock_index.as_retriever.return_value = mock_retriever + mock_core.load_index_from_storage.return_value = mock_index + + return ( + { + "llama_index": MagicMock(), + "llama_index.core": mock_core, + "llama_index.embeddings": MagicMock(), + "llama_index.embeddings.openai": MagicMock(), + }, + mock_core, + mock_index, + mock_retriever, + ) + + +class TestRetrievalOperator: + def test_template_fields(self): + expected = {"query", "index_persist_dir", "llm_conn_id"} + assert set(RetrievalOperator.template_fields) == expected + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_execute_returns_expected_shape(self, mock_hook_cls): + results = 
[_make_mock_node_with_score(text="relevant chunk", score=0.95)] + mock_modules, mock_core, _, _ = _make_mock_llamaindex_modules(results) + + op = RetrievalOperator( + task_id="test", + query="What is Airflow?", + index_persist_dir="/tmp/index", + llm_conn_id="my_conn", + ) + + with patch.dict("sys.modules", mock_modules): + result = op.execute(context=MagicMock()) + + assert "question" in result + assert "chunks" in result + assert result["question"] == "What is Airflow?" + assert len(result["chunks"]) == 1 + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_chunks_have_required_keys(self, mock_hook_cls): + results = [ + _make_mock_node_with_score( + text="chunk text", score=0.8, metadata={"file": "doc.txt"}, node_id="abc-123" + ) + ] + mock_modules, _, _, _ = _make_mock_llamaindex_modules(results) + + op = RetrievalOperator( + task_id="test", + query="test query", + index_persist_dir="/tmp/index", + llm_conn_id="my_conn", + ) + + with patch.dict("sys.modules", mock_modules): + result = op.execute(context=MagicMock()) + + chunk = result["chunks"][0] + assert chunk["text"] == "chunk text" + assert chunk["score"] == 0.8 + assert chunk["metadata"] == {"file": "doc.txt"} + assert chunk["source"] == "abc-123" + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_top_k_forwarded_to_retriever(self, mock_hook_cls): + mock_modules, _, mock_index, _ = _make_mock_llamaindex_modules([]) + + op = RetrievalOperator( + task_id="test", + query="test", + index_persist_dir="/tmp/index", + llm_conn_id="my_conn", + top_k=10, + ) + + with patch.dict("sys.modules", mock_modules): + op.execute(context=MagicMock()) + + mock_index.as_retriever.assert_called_once_with(similarity_top_k=10) + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_query_value_in_output(self, mock_hook_cls): + mock_modules, _, _, _ = _make_mock_llamaindex_modules([]) + + op = 
RetrievalOperator( + task_id="test", + query="How does Airflow scheduling work?", + index_persist_dir="/tmp/index", + llm_conn_id="my_conn", + ) + + with patch.dict("sys.modules", mock_modules): + result = op.execute(context=MagicMock()) + + assert result["question"] == "How does Airflow scheduling work?" + assert result["chunks"] == [] + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_multiple_results_returned(self, mock_hook_cls): + results = [ + _make_mock_node_with_score(text=f"chunk {i}", score=0.9 - i * 0.1, node_id=f"node-{i}") + for i in range(3) + ] + mock_modules, _, _, _ = _make_mock_llamaindex_modules(results) + + op = RetrievalOperator( + task_id="test", + query="test", + index_persist_dir="/tmp/index", + llm_conn_id="my_conn", + ) + + with patch.dict("sys.modules", mock_modules): + result = op.execute(context=MagicMock()) + + assert len(result["chunks"]) == 3 + assert result["chunks"][0]["text"] == "chunk 0" + assert result["chunks"][2]["text"] == "chunk 2" + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_hook_configured_with_params(self, mock_hook_cls): + mock_modules, _, _, _ = _make_mock_llamaindex_modules([]) + + op = RetrievalOperator( + task_id="test", + query="test", + index_persist_dir="/tmp/index", + llm_conn_id="custom_conn", + embed_model="text-embedding-ada-002", + ) + + with patch.dict("sys.modules", mock_modules): + op.execute(context=MagicMock()) + + mock_hook_cls.assert_called_once_with(llm_conn_id="custom_conn", embed_model="text-embedding-ada-002") + mock_hook_cls.return_value.configure_settings.assert_called_once() + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) + def test_persist_dir_passed_to_storage_context(self, mock_hook_cls): + mock_modules, mock_core, _, _ = _make_mock_llamaindex_modules([]) + + op = RetrievalOperator( + task_id="test", + query="test", + 
index_persist_dir="/data/my_index", + llm_conn_id="my_conn", + ) + + with patch.dict("sys.modules", mock_modules): + op.execute(context=MagicMock()) + + mock_core.StorageContext.from_defaults.assert_called_once_with(persist_dir="/data/my_index") From a6e176f2c594dd93aef8529368701d3341213677 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 21 May 2026 02:21:58 +0100 Subject: [PATCH 2/7] Refactor LlamaIndex hook + operators: no Settings mutation, BYO models, cloud URIs Same playbook as #67192 (LangChain) and #67120 (DocumentLoader) plus three LlamaIndex-specific architectural fixes: Critical fixes - Stop mutating LlamaIndex's global ``Settings`` singleton. The previous ``LlamaIndexHook.configure_settings()`` wrote ``Settings.embed_model`` / ``Settings.llm`` process-wide, which leaks across concurrent tasks in the same worker. Replaced with per-call ``embed_model=`` / ``llm=`` parameters on ``VectorStoreIndex(...)`` and ``load_index_from_storage(...)``. - Own ``llamaindex`` connection type instead of squatting on ``pydanticai``. Mirrors the LangChain / CrewAI fix. - Remove ``documents`` from ``EmbeddingOperator.template_fields``. ``list[dict]`` doesn't survive Jinja stringification, and worse, a user document containing literal ``{{ var.value.api_key }}`` would leak secrets into the embedding store. Bind via ``loader.output`` instead. BYO embedding/LLM for non-OpenAI vendors - LlamaIndex doesn't ship an ``init_chat_model`` / ``init_embedding_model`` equivalent (verified in ``llama_index.core.embeddings.utils.resolve_embed_model`` -- only ``"default"`` / ``"local"`` / ``"clip:"`` dispatch). The hook therefore covers OpenAI (matching LlamaIndex's own ``resolve_embed_model("default")`` behaviour) and operators accept a pre-built ``BaseEmbedding`` / ``LLM`` instance to bypass the hook for Cohere / Bedrock / Vertex / HuggingFace / etc. 
Cloud-URI persistence - ``EmbeddingOperator.persist_dir`` and ``RetrievalOperator.index_persist_dir`` accept storage URIs (``s3://``, ``gs://``, ``azure://``) resolved via ``ObjectStoragePath`` and fsspec, matching the merged ``DocumentLoaderOperator`` pattern. Hook plumbing playbook (mirrors LangChain / CrewAI / DocumentLoader) - ``conn_type = "llamaindex"`` + new ``connection-types`` entry in ``provider.yaml`` with ``embed_model`` / ``llm_model`` conn-fields. - ``default_conn_name`` resolves at runtime via ``llm_conn_id: str | None = None``. - ``_resolve_model`` honours ``conn.extra_dejson`` for parity with the sibling hooks (swallows ``JSONDecodeError``, applies secret masking). - ``get_ui_field_behaviour`` added. - ``[llamaindex]`` extra in ``pyproject.toml`` pinning ``llama-index-core``, ``llama-index-embeddings-openai``, ``llama-index-llms-openai`` (enough to back the hook's default OpenAI return values). Same in the ``dev`` group. Misc operator/test fixes - Wrap lazy ``llama_index`` imports with ``AirflowOptionalProviderFeatureException`` so missing extras surface cleanly. - ``RetrievalOperator`` returns ``{"query": ..., "chunks": [...]}`` (was ``"question"``) and ``chunks[*].node_id`` (was the misleading ``"source"`` key). - ``RetrievalOperator`` raises ``FileNotFoundError`` with a "did you run EmbeddingOperator first?" hint when ``index_persist_dir`` is missing. - All three test files get an autouse fixture stubbing ``llama_index.*`` in ``sys.modules`` so ``@patch`` resolves without ``llama-index-*`` packages installed in CI's non-DB test env (mirrors apache/airflow#67237). - New ``example_llamaindex_hook.py`` with ``[START howto_*]`` markers for the docs to ``exampleinclude``. 
--- providers/common/ai/docs/hooks/index.rst | 6 + providers/common/ai/docs/hooks/llamaindex.rst | 126 ++++---- .../docs/operators/llamaindex_embedding.rst | 160 +++++----- .../docs/operators/llamaindex_retrieval.rst | 135 +++++---- providers/common/ai/provider.yaml | 34 +++ providers/common/ai/pyproject.toml | 8 + .../example_dags/example_llamaindex_hook.py | 147 ++++++++++ .../providers/common/ai/get_provider_info.py | 40 +++ .../providers/common/ai/hooks/llamaindex.py | 171 ++++++++--- .../ai/operators/llamaindex_embedding.py | 139 ++++++--- .../ai/operators/llamaindex_retrieval.py | 135 ++++++--- .../unit/common/ai/hooks/test_llamaindex.py | 271 +++++++++-------- .../ai/operators/test_llamaindex_embedding.py | 275 ++++++++---------- .../ai/operators/test_llamaindex_retrieval.py | 255 ++++++++-------- 14 files changed, 1159 insertions(+), 743 deletions(-) create mode 100644 providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py diff --git a/providers/common/ai/docs/hooks/index.rst b/providers/common/ai/docs/hooks/index.rst index 3d05ba8edae13..9426280c9eb36 100644 --- a/providers/common/ai/docs/hooks/index.rst +++ b/providers/common/ai/docs/hooks/index.rst @@ -40,6 +40,12 @@ Choosing a hook - Direct LangChain access for tasks that compose ``Runnable``\\s, use the LangChain agent surface, or need LangChain-native chat / embedding model objects. Independent of the pydantic-ai-backed operators. + * - :class:`~airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook` + - Backs the LlamaIndex ``EmbeddingOperator`` and ``RetrievalOperator``. + Returns LlamaIndex-native ``BaseEmbedding`` / ``LLM`` objects (OpenAI + by default). For non-OpenAI vendors, pass a pre-built + ``BaseEmbedding`` / ``LLM`` instance straight to the operator and + bypass the hook. 
Hook guides ----------- diff --git a/providers/common/ai/docs/hooks/llamaindex.rst b/providers/common/ai/docs/hooks/llamaindex.rst index ff942a2f65bd4..02255b72f6cf6 100644 --- a/providers/common/ai/docs/hooks/llamaindex.rst +++ b/providers/common/ai/docs/hooks/llamaindex.rst @@ -21,67 +21,95 @@ ================== Use :class:`~airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook` to -bridge Airflow connections to LlamaIndex's ``Settings`` singleton. The hook -reuses the ``pydanticai`` connection type, so users configure a single -connection for both pydantic-ai operators and LlamaIndex operators. - -.. seealso:: - :ref:`Connection configuration ` - -What It Does ------------- - -The hook resolves API keys and base URLs from Airflow connections and uses -them to configure LlamaIndex's embedding models, LLMs, and global settings. -This eliminates manual ``Settings.embed_model = ...`` boilerplate in every -task that uses LlamaIndex. - -Configuration -------------- - -``LlamaIndexHook`` reuses the ``pydanticai`` connection type. Set the API key -in the **Password** field and optionally a custom endpoint in the **Host** -field. - -Separate Embedding and LLM Connections --------------------------------------- - -RAG pipelines often use different providers for embeddings and chat. The hook -supports an optional ``embed_conn_id`` parameter that defaults to the main -``llm_conn_id``: - -.. code-block:: python - - from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook - - hook = LlamaIndexHook( - llm_conn_id="openai_default", - embed_conn_id="embedding_provider", - embed_model="text-embedding-3-large", - llm_model="gpt-4o", - ) - hook.configure_settings() +bridge an Airflow connection to `LlamaIndex `__ +chat and embedding models. 
The hook reads credentials (API key, optional +base URL) from a connection of type ``llamaindex`` and returns native +LlamaIndex objects ready to pass to ``VectorStoreIndex(..., embed_model=...)``, +``load_index_from_storage(..., embed_model=...)``, or +``index.as_retriever(..., llm=...)``. + +The hook deliberately does **not** mutate LlamaIndex's global ``Settings`` +singleton. Operators pass the resolved model directly to LlamaIndex +constructors, so concurrent tasks in the same worker don't race on shared +state. + +OpenAI by default, BYO for other vendors +---------------------------------------- + +LlamaIndex does not ship a universal ``init_chat_model`` / +``init_embedding_model`` equivalent (each vendor is a separate package +under ``llama-index-llms-*`` / ``llama-index-embeddings-*`` with its own +constructor kwargs). The hook therefore covers the OpenAI-compatible +surface that matches LlamaIndex's own ``resolve_embed_model("default")`` +behaviour: + +- ``hook.get_embedding_model()`` returns an ``OpenAIEmbedding`` configured + from the connection. +- ``hook.get_llm()`` returns an ``OpenAI`` LLM configured from the + connection. + +For other vendors (Cohere, Bedrock, Vertex AI, HuggingFace, ...), +instantiate the LlamaIndex class directly in a ``@task`` and pass it to +the operator's ``embed_model=`` / ``llm=`` parameter -- both +:class:`~airflow.providers.common.ai.operators.llamaindex_embedding.EmbeddingOperator` +and +:class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.RetrievalOperator` +accept a pre-built ``BaseEmbedding`` / ``LLM`` instance and bypass the +hook: + +.. 
exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py + :language: python + :start-after: [START howto_hook_llamaindex_byo_embed_model] + :end-before: [END howto_hook_llamaindex_byo_embed_model] + +Install the per-vendor LlamaIndex integration package separately: +``pip install llama-index-embeddings-cohere``, ``...-bedrock``, +``...-huggingface``, ``llama-index-llms-anthropic``, etc. + +Connection Configuration +------------------------ + +The hook reads credentials from the Airflow connection of type ``llamaindex``: + +- **password** -- API key (passed as ``api_key`` to ``OpenAIEmbedding`` / + ``OpenAI``). +- **host** -- Optional base URL (passed as ``api_base``; useful for custom + OpenAI-compatible endpoints, Ollama, vLLM). +- **extra** JSON -- + ``{"embed_model": "text-embedding-3-small", "llm_model": "gpt-4o"}`` -- + default model identifiers stored on the connection. Parameters ---------- .. list-table:: :header-rows: 1 - :widths: 25 15 60 + :widths: 25 25 50 * - Parameter - Default - Description * - ``llm_conn_id`` - - ``pydanticai_default`` + - ``llamaindex_default`` - Airflow connection ID for the LLM/embedding provider. * - ``embed_conn_id`` - - Same as ``llm_conn_id`` - - Separate connection for embeddings (optional). + - ``None`` (falls back to ``llm_conn_id``) + - Optional separate Airflow connection ID for the embedding provider. * - ``embed_model`` - - ``text-embedding-3-small`` - - Embedding model name. + - ``None`` (falls back to ``extra["embed_model"]``) + - Embedding model name, e.g. ``text-embedding-3-small``. * - ``llm_model`` - - ``None`` - - LLM model name. Required for ``get_llm()`` and ``configure_settings()`` - LLM setup. + - ``None`` (falls back to ``extra["llm_model"]``) + - LLM model name, e.g. ``gpt-4o``. Required when calling ``get_llm()``. 
+ +Dependencies +------------ + +Install the ``llamaindex`` extra:: + + pip install apache-airflow-providers-common-ai[llamaindex] + +That extra installs ``llama-index-core``, ``llama-index-embeddings-openai``, +and ``llama-index-llms-openai`` -- enough to back the hook's default +OpenAI return values. For other LlamaIndex vendor packages, install +their integration package separately. diff --git a/providers/common/ai/docs/operators/llamaindex_embedding.rst b/providers/common/ai/docs/operators/llamaindex_embedding.rst index 2a32d056dc24a..0a2c47029d0c1 100644 --- a/providers/common/ai/docs/operators/llamaindex_embedding.rst +++ b/providers/common/ai/docs/operators/llamaindex_embedding.rst @@ -17,119 +17,99 @@ .. _howto/operator:llamaindex_embedding: -``EmbeddingOperator`` -===================== +LlamaIndex ``EmbeddingOperator`` +================================ -Use :class:`~airflow.providers.common.ai.operators.llamaindex_embedding.EmbeddingOperator` -to chunk documents and produce embedding vectors using LlamaIndex. This operator -bridges document loading (Airflow provider hooks returning text) and vector -storage (pgvector, Pinecone, Weaviate ingest operators). +Chunk a ``list[dict]`` of documents and produce embedding vectors using +LlamaIndex. Designed to feed the output of +:class:`~airflow.providers.common.ai.operators.document_loader.DocumentLoaderOperator` +into vector storage (pgvector, Pinecone, Weaviate, ...). -Basic Usage ------------ - -Provide a list of documents with ``text`` and ``metadata`` keys. The operator -chunks the documents, embeds them, and returns the results: - -.. 
code-block:: python - - from airflow.providers.common.ai.operators.llamaindex_embedding import EmbeddingOperator - - embed = EmbeddingOperator( - task_id="embed_docs", - documents=[ - {"text": "Airflow is a workflow orchestration platform.", "metadata": {"source": "docs"}}, - {"text": "LlamaIndex is a data framework for LLM applications.", "metadata": {"source": "docs"}}, - ], - llm_conn_id="openai_default", - ) - -Connection Configuration ------------------------- +The operator passes the embedding model **directly** to +``VectorStoreIndex(..., embed_model=...)`` -- it does not mutate +LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the same +worker process don't race on shared model state. -The operator uses :class:`~airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook` -internally. Configure your embedding API credentials via the ``pydanticai`` -connection type. - -.. seealso:: - :ref:`Connection configuration ` - -Chunking Parameters -------------------- - -Control how documents are split into chunks before embedding: - -.. code-block:: python - - embed = EmbeddingOperator( - task_id="embed_docs", - documents=documents, - llm_conn_id="openai_default", - chunk_size=256, - chunk_overlap=25, - ) +Basic usage +----------- -Index Persistence ------------------ +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py + :language: python + :start-after: [START howto_hook_llamaindex_embed] + :end-before: [END howto_hook_llamaindex_embed] -Set ``persist_dir`` to save the LlamaIndex index for later retrieval via -:class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.RetrievalOperator`: +The ``documents`` parameter binds to ``loader.output`` (XCom direct), **not** +via Jinja -- ``list[dict]`` doesn't survive Jinja stringification, so the +parameter is intentionally not in ``template_fields``. -.. 
code-block:: python +Bring-your-own embedding model +------------------------------ - embed = EmbeddingOperator( - task_id="embed_docs", - documents=documents, - llm_conn_id="openai_default", - persist_dir="/opt/airflow/data/my_index", - ) +LlamaIndex doesn't ship a universal embedding-model initializer, so the +operator's ``embed_model`` parameter accepts either: -Output Shape ------------- +* a string model name (e.g. ``"text-embedding-3-small"``) -- the operator + constructs an ``OpenAIEmbedding`` via + :class:`~airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook` + using ``llm_conn_id``, or +* a pre-built ``BaseEmbedding`` instance -- bypass the hook entirely. Use + this for Cohere, Bedrock, Vertex, HuggingFace, etc.: -The operator returns a dict: +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py + :language: python + :start-after: [START howto_hook_llamaindex_byo_embed_model] + :end-before: [END howto_hook_llamaindex_byo_embed_model] -.. code-block:: python +Persisting to cloud storage +--------------------------- - { - "document_count": 2, - "chunk_count": 5, - "persist_dir": "/opt/airflow/data/my_index", - "chunks": [ - {"text": "chunk text", "metadata": {"source": "docs"}, "vector": [0.1, ...]}, - ... - ], - } +``persist_dir`` accepts local paths and storage URIs (``s3://``, ``gs://``, +``azure://``, ``file://``) resolved via +:class:`~airflow.sdk.ObjectStoragePath`. Pass ``persist_conn_id`` to +point at the Airflow connection that holds the cloud credentials: -Each chunk includes ``text``, ``metadata``, and optionally ``vector`` (the -embedding array). The ``chunks`` list is ready for downstream consumption by -vector store ingest operators. +.. 
exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py + :language: python + :start-after: [START howto_hook_llamaindex_cloud_persist] + :end-before: [END howto_hook_llamaindex_cloud_persist] Parameters ---------- .. list-table:: :header-rows: 1 - :widths: 25 15 60 + :widths: 25 75 * - Parameter - - Default - Description * - ``documents`` - - (required) - - List of dicts with ``text`` and ``metadata`` keys. - * - ``llm_conn_id`` - - ``pydanticai_default`` - - Airflow connection ID for the embedding API. + - ``list[dict]`` with ``text`` / ``metadata`` keys. Bind via + ``loader.output``; **not** templated. * - ``embed_model`` - - ``text-embedding-3-small`` - - Embedding model name. + - String model name OR pre-built ``BaseEmbedding`` instance. + * - ``llm_conn_id`` + - Airflow connection ID used when ``embed_model`` is a string or + omitted (default ``llamaindex_default``). * - ``chunk_size`` - - ``512`` - - Chunk size for the sentence splitter. + - Sentence-splitter chunk size (default 512). * - ``chunk_overlap`` - - ``50`` - - Overlap between chunks. + - Overlap between chunks (default 50). * - ``persist_dir`` - - ``None`` - - Directory path to persist the index. + - Local path or storage URI to persist the LlamaIndex index. + * - ``persist_conn_id`` + - Cloud credentials connection ID for ``persist_dir`` URIs. + +Output +------ + +Returns a dict with:: + + { + "document_count": int, + "chunk_count": int, + "persist_dir": str | None, + "chunks": [ + {"text": str, "metadata": dict, "vector": list[float]}, + ... + ], + } diff --git a/providers/common/ai/docs/operators/llamaindex_retrieval.rst b/providers/common/ai/docs/operators/llamaindex_retrieval.rst index ff238744cf196..13ec739d2ab13 100644 --- a/providers/common/ai/docs/operators/llamaindex_retrieval.rst +++ b/providers/common/ai/docs/operators/llamaindex_retrieval.rst @@ -17,92 +17,89 @@ ..
_howto/operator:llamaindex_retrieval: -``RetrievalOperator`` -===================== - -Use :class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.RetrievalOperator` -to retrieve relevant document chunks from a persisted LlamaIndex index. The -operator performs similarity search against the provided query and returns -results ready for downstream synthesis via ``LLMOperator``. - -Basic Usage +LlamaIndex ``RetrievalOperator`` +================================ + +Load a persisted LlamaIndex index and run similarity search. Designed to +sit between +:class:`~airflow.providers.common.ai.operators.llamaindex_embedding.EmbeddingOperator` +(which builds the index) and +:class:`~airflow.providers.common.ai.operators.llm.LLMOperator` (which +synthesises an answer from the retrieved chunks). + +Passes the embedding model **directly** to +``load_index_from_storage(..., embed_model=...)`` -- no LlamaIndex +``Settings`` mutation. The embedding model must match the one used when +the index was originally built. + +Basic usage ----------- -Provide a query string and the path to a previously persisted index: - -.. code-block:: python - - from airflow.providers.common.ai.operators.llamaindex_retrieval import RetrievalOperator - - retrieve = RetrievalOperator( - task_id="retrieve_context", - query="What are Airflow's key features?", - index_persist_dir="/opt/airflow/data/my_index", - llm_conn_id="openai_default", - ) - -Query Templating ----------------- - -The ``query`` field supports Jinja templating, so it can be set dynamically -from upstream task output or Dag run configuration: +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py + :language: python + :start-after: [START howto_hook_llamaindex_retrieve] + :end-before: [END howto_hook_llamaindex_retrieve] -.. code-block:: python +``query`` is templated, so DAG-run params, XCom, and Variables all flow +through cleanly. 
- retrieve = RetrievalOperator( - task_id="retrieve_context", - query="{{ dag_run.conf['question'] }}", - index_persist_dir="/opt/airflow/data/my_index", - llm_conn_id="openai_default", - top_k=10, - ) +Cloud-persisted indexes +----------------------- -Output Shape ------------- +``index_persist_dir`` accepts the same local-path-or-URI shape as +``EmbeddingOperator.persist_dir``. Pass ``persist_conn_id`` to point at +the Airflow connection that holds cloud credentials. The operator raises +``FileNotFoundError`` with a clear "did you run EmbeddingOperator first?" +message when the path is missing. -The operator returns a dict: - -.. code-block:: python - - { - "question": "What are Airflow's key features?", - "chunks": [ - { - "text": "Airflow provides ...", - "score": 0.95, - "metadata": {"source": "overview.txt"}, - "source": "node-abc123", - }, - ... - ], - } +Bring-your-own embedding model +------------------------------ -Each chunk includes ``text``, ``score`` (similarity), ``metadata``, and -``source`` (the LlamaIndex node ID). This output pairs naturally with -``LLMOperator`` for RAG synthesis using Jinja templates. +Same shape as ``EmbeddingOperator``: ``embed_model`` accepts either a +string model name (OpenAI via the hook) or a pre-built ``BaseEmbedding`` +instance for non-OpenAI vendors. See the BYO example in +:doc:`llamaindex_embedding`. Parameters ---------- .. list-table:: :header-rows: 1 - :widths: 25 15 60 + :widths: 25 75 * - Parameter - - Default - Description * - ``query`` - - (required) - - The query string to search for. Supports Jinja templating. + - The query string. Templated. * - ``index_persist_dir`` - - (required) - - Path to the persisted LlamaIndex index directory. - * - ``llm_conn_id`` - - ``pydanticai_default`` - - Airflow connection ID for the embedding API. + - Local path or storage URI pointing at the persisted index. + Templated. + * - ``persist_conn_id`` + - Cloud credentials connection ID for ``index_persist_dir`` URIs. 
+ Templated. * - ``embed_model`` - - ``text-embedding-3-small`` - - Embedding model name. + - String model name OR pre-built ``BaseEmbedding`` instance. Must + match the model used when the index was built. + * - ``llm_conn_id`` + - Airflow connection ID used when ``embed_model`` is a string + (default ``llamaindex_default``). * - ``top_k`` - - ``5`` - - Number of top results to retrieve. + - Number of top similarity results to return (default 5). + +Output +------ + +Returns a dict with:: + + { + "query": str, + "chunks": [ + { + "text": str, + "score": float, + "metadata": dict, + "node_id": str, + }, + ... + ], + } diff --git a/providers/common/ai/provider.yaml b/providers/common/ai/provider.yaml index 57dbdf86e1122..d95ff608857dd 100644 --- a/providers/common/ai/provider.yaml +++ b/providers/common/ai/provider.yaml @@ -363,6 +363,40 @@ connection-types: type: - string - 'null' + - hook-class-name: airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook + hook-name: "LlamaIndex" + connection-type: llamaindex + ui-field-behaviour: + hidden-fields: + - schema + - port + - login + relabeling: + password: API Key + placeholders: + host: "https://api.openai.com/v1 (optional, for custom endpoints / Ollama)" + extra: '{"embed_model": "text-embedding-3-small", "llm_model": "gpt-4o"}' + conn-fields: + embed_model: + label: Embedding Model + description: > + Default LlamaIndex embedding model name (e.g. + text-embedding-3-small). The OpenAI default; for other vendors + pass a pre-built BaseEmbedding instance to the operator. + schema: + type: + - string + - 'null' + llm_model: + label: LLM Model + description: > + Default LlamaIndex LLM model name (e.g. gpt-4o). The OpenAI + default; for other vendors pass a pre-built LLM instance to + the operator. 
+ schema: + type: + - string + - 'null' operators: - integration-name: Common AI diff --git a/providers/common/ai/pyproject.toml b/providers/common/ai/pyproject.toml index 18833cd64bc1a..c67f16b319bb7 100644 --- a/providers/common/ai/pyproject.toml +++ b/providers/common/ai/pyproject.toml @@ -98,6 +98,11 @@ dependencies = [ "langchain" = [ "langchain>=1.0.0", ] +"llamaindex" = [ + "llama-index-core>=0.13.0", + "llama-index-embeddings-openai>=0.6.0", + "llama-index-llms-openai>=0.6.0", +] "pdf" = ["pypdf>=4.0.0"] "docx" = ["python-docx>=1.0.0"] @@ -114,6 +119,9 @@ dev = [ "pydantic-ai-slim[mcp]", "apache-airflow-providers-common-sql[datafusion]", "langchain>=1.0.0", + "llama-index-core>=0.13.0", + "llama-index-embeddings-openai>=0.6.0", + "llama-index-llms-openai>=0.6.0", ] # To build docs: diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py new file mode 100644 index 0000000000000..617662c431e55 --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Example DAGs demonstrating LlamaIndexHook + LlamaIndex operator usage. + +Each DAG covers a single pattern. The docs reference these via +``.. exampleinclude::`` so the runnable snippets stay in sync. +""" + +from __future__ import annotations + +from airflow.providers.common.ai.operators.document_loader import DocumentLoaderOperator +from airflow.providers.common.ai.operators.llamaindex_embedding import EmbeddingOperator +from airflow.providers.common.ai.operators.llamaindex_retrieval import RetrievalOperator +from airflow.providers.common.compat.sdk import dag, task + + +# [START howto_hook_llamaindex_embed] +@dag(schedule=None) +def example_llamaindex_embed(): + """Chunk + embed a directory of documents and persist the index locally.""" + + load = DocumentLoaderOperator( + task_id="load", + source_path="/opt/airflow/data/library/**/*", + file_extensions=[".pdf", ".md", ".txt"], + ) + + embed = EmbeddingOperator( + task_id="embed", + documents=load.output, # XCom direct -- never via Jinja (list[dict]) + embed_model="text-embedding-3-small", + llm_conn_id="llamaindex_default", + chunk_size=512, + chunk_overlap=50, + persist_dir="/opt/airflow/data/library_index", + ) + + load >> embed + + +# [END howto_hook_llamaindex_embed] + +example_llamaindex_embed() + + +# [START howto_hook_llamaindex_retrieve] +@dag(schedule=None) +def example_llamaindex_retrieve(): + """Load a persisted index and run similarity search.""" + + retrieve = RetrievalOperator( + task_id="retrieve", + query="{{ params.query }}", + index_persist_dir="/opt/airflow/data/library_index", + embed_model="text-embedding-3-small", + llm_conn_id="llamaindex_default", + top_k=5, + ) + + retrieve + + +# [END howto_hook_llamaindex_retrieve] + +example_llamaindex_retrieve() + + +# [START howto_hook_llamaindex_cloud_persist] +@dag(schedule=None) +def example_llamaindex_cloud_persist(): + """Persist the index directly to S3 -- no separate upload step.""" + + load = DocumentLoaderOperator( + task_id="load", + 
source_path="s3://my-bucket/library/", + source_conn_id="aws_default", + file_extensions=[".pdf"], + ) + + embed = EmbeddingOperator( + task_id="embed", + documents=load.output, + embed_model="text-embedding-3-small", + llm_conn_id="llamaindex_default", + persist_dir="s3://my-bucket/indexes/library/", + persist_conn_id="aws_default", + ) + + load >> embed + + +# [END howto_hook_llamaindex_cloud_persist] + +example_llamaindex_cloud_persist() + + +# [START howto_hook_llamaindex_byo_embed_model] +@dag(schedule=None) +def example_llamaindex_byo_embed_model(): + """Use a non-OpenAI embedding by instantiating the LlamaIndex class directly. + + LlamaIndex doesn't ship a universal init helper, so the operator accepts + a pre-built ``BaseEmbedding`` instance and bypasses the hook entirely. + Install the matching extra: + ``pip install llama-index-embeddings-cohere``. + """ + + @task + def build_cohere_embedder(): + from llama_index.embeddings.cohere import CohereEmbedding + + from airflow.providers.common.compat.sdk import BaseHook + + conn = BaseHook.get_connection("cohere_default") + return CohereEmbedding(model_name="embed-english-v3.0", cohere_api_key=conn.password) + + @task + def empty_doc_list() -> list[dict]: + return [{"text": "Cohere demo content", "metadata": {}}] + + embed = EmbeddingOperator( + task_id="embed", + documents=empty_doc_list(), + embed_model=build_cohere_embedder(), + persist_dir="/opt/airflow/data/cohere_index", + ) + + embed + + +# [END howto_hook_llamaindex_byo_embed_model] + +example_llamaindex_byo_embed_model() diff --git a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py index d87733bb5ffa0..a3642e4895b79 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py +++ b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py @@ -56,6 +56,15 @@ def get_provider_info(): "external-doc-url": 
"https://python.langchain.com/", "tags": ["ai"], }, + { + "integration-name": "LlamaIndex", + "external-doc-url": "https://docs.llamaindex.ai/", + "how-to-guide": [ + "/docs/apache-airflow-providers-common-ai/operators/llamaindex_embedding.rst", + "/docs/apache-airflow-providers-common-ai/operators/llamaindex_retrieval.rst", + ], + "tags": ["ai"], + }, ], "hooks": [ { @@ -67,6 +76,10 @@ def get_provider_info(): "integration-name": "LangChain", "python-modules": ["airflow.providers.common.ai.hooks.langchain"], }, + { + "integration-name": "LlamaIndex", + "python-modules": ["airflow.providers.common.ai.hooks.llamaindex"], + }, ], "plugins": [ { @@ -288,6 +301,31 @@ def get_provider_info(): }, }, }, + { + "hook-class-name": "airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", + "hook-name": "LlamaIndex", + "connection-type": "llamaindex", + "ui-field-behaviour": { + "hidden-fields": ["schema", "port", "login"], + "relabeling": {"password": "API Key"}, + "placeholders": { + "host": "https://api.openai.com/v1 (optional, for custom endpoints / Ollama)", + "extra": '{"embed_model": "text-embedding-3-small", "llm_model": "gpt-4o"}', + }, + }, + "conn-fields": { + "embed_model": { + "label": "Embedding Model", + "description": "Default LlamaIndex embedding model name (e.g. text-embedding-3-small). The OpenAI default; for other vendors pass a pre-built BaseEmbedding instance to the operator.\n", + "schema": {"type": ["string", "null"]}, + }, + "llm_model": { + "label": "LLM Model", + "description": "Default LlamaIndex LLM model name (e.g. gpt-4o). 
The OpenAI default; for other vendors pass a pre-built LLM instance to the operator.\n", + "schema": {"type": ["string", "null"]}, + }, + }, + }, ], "operators": [ { @@ -300,6 +338,8 @@ def get_provider_info(): "airflow.providers.common.ai.operators.llm_sql", "airflow.providers.common.ai.operators.llm_schema_compare", "airflow.providers.common.ai.operators.document_loader", + "airflow.providers.common.ai.operators.llamaindex_embedding", + "airflow.providers.common.ai.operators.llamaindex_retrieval", ], } ], diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py b/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py index 7c3272c4cdb68..68b8325cf05b7 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py +++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py @@ -20,7 +20,10 @@ from typing import TYPE_CHECKING, Any -from airflow.providers.common.compat.sdk import BaseHook +from airflow.providers.common.compat.sdk import ( + AirflowOptionalProviderFeatureException, + BaseHook, +) if TYPE_CHECKING: from llama_index.core.base.embeddings.base import BaseEmbedding @@ -29,41 +32,108 @@ class LlamaIndexHook(BaseHook): """ - Bridge Airflow connections to LlamaIndex's Settings singleton. - - Reuses the ``pydanticai`` connection type so users configure a single - connection for both pydantic-ai operators and LlamaIndex operators. - - :param llm_conn_id: Airflow connection ID for the LLM/embedding provider. - :param embed_conn_id: Separate connection for embeddings. Defaults to - ``llm_conn_id`` when not provided. - :param embed_model: Embedding model name (e.g. ``text-embedding-3-small``). - :param llm_model: LLM model name (e.g. ``gpt-4o``). Only needed when - configuring ``Settings.llm``. + Bridge an Airflow connection to LlamaIndex chat and embedding models. 
+ + The hook resolves credentials (API key, optional API base URL) from the + Airflow connection and returns native LlamaIndex objects ready to pass + to ``VectorStoreIndex(..., embed_model=...)``, + ``load_index_from_storage(..., embed_model=...)``, or + ``index.as_retriever(..., llm=...)``. + + LlamaIndex does not ship a universal ``init_chat_model`` / + ``init_embedding_model`` equivalent (each vendor is a separate package + under ``llama-index-llms-*`` / ``llama-index-embeddings-*`` with its own + constructor kwargs). The hook therefore covers the OpenAI-compatible + surface that matches LlamaIndex's own ``resolve_embed_model("default")`` + behaviour. For other vendors (Cohere, Bedrock, Vertex, HuggingFace, ...) + instantiate the LlamaIndex class directly in your ``@task`` and pass it + to the operator's ``embed_model=`` / ``llm=`` parameter -- both + ``EmbeddingOperator`` and ``RetrievalOperator`` accept a pre-built + ``BaseEmbedding`` / ``LLM`` instance and bypass the hook in that case. + + .. note:: + + The hook deliberately does **not** mutate LlamaIndex's global + ``Settings`` singleton. Operators pass the resolved model directly + to LlamaIndex constructors so concurrent tasks in the same worker + don't race on shared state. + + Connection fields: + + * **password**: API key passed as ``api_key=``. + * **host**: Optional base URL passed as ``api_base=`` (custom endpoints, + Ollama, vLLM). + * **extra** JSON: ``{"embed_model": "text-embedding-3-small", + "llm_model": "gpt-4o"}`` -- default model identifiers stored on the + connection. + + :param llm_conn_id: Airflow connection ID for the LLM provider. Falls + back to :attr:`default_conn_name` (``"llamaindex_default"``) when + not provided. + :param embed_conn_id: Optional separate Airflow connection ID for the + embedding provider. Falls back to ``llm_conn_id`` when not set. + :param embed_model: Embedding model name (e.g. + ``"text-embedding-3-small"``). 
Overrides ``extra["embed_model"]`` + on the connection. + :param llm_model: LLM model name (e.g. ``"gpt-4o"``). Overrides + ``extra["llm_model"]`` on the connection. Required when calling + :meth:`get_llm`. """ conn_name_attr = "llm_conn_id" - default_conn_name = "pydanticai_default" - conn_type = "pydanticai" + default_conn_name = "llamaindex_default" + conn_type = "llamaindex" hook_name = "LlamaIndex" def __init__( self, - llm_conn_id: str = "pydanticai_default", + llm_conn_id: str | None = None, embed_conn_id: str | None = None, - embed_model: str = "text-embedding-3-small", + embed_model: str | None = None, llm_model: str | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.llm_conn_id = llm_conn_id - self.embed_conn_id = embed_conn_id or llm_conn_id + # Resolve at runtime so a future per-vendor subclass with its own + # ``default_conn_name`` is honoured. + self.llm_conn_id = llm_conn_id if llm_conn_id is not None else self.default_conn_name + self.embed_conn_id = embed_conn_id if embed_conn_id is not None else self.llm_conn_id self.embed_model = embed_model self.llm_model = llm_model - def _resolve_connection_kwargs(self, conn_id: str) -> dict[str, Any]: - """Extract API key and base URL from an Airflow connection.""" - conn = self.get_connection(conn_id) + @staticmethod + def get_ui_field_behaviour() -> dict[str, Any]: + """Return custom field behaviour for the Airflow connection form.""" + return { + "hidden_fields": ["schema", "port", "login"], + "relabeling": {"password": "API Key"}, + "placeholders": { + "host": "https://api.openai.com/v1 (optional, for custom endpoints / Ollama)", + "extra": '{"embed_model": "text-embedding-3-small", "llm_model": "gpt-4o"}', + }, + } + + @staticmethod + def _resolve_model( + conn_extra: dict[str, Any], + *, + constructor_value: str | None, + extra_key: str, + kind: str, + ) -> str: + """Resolve a model identifier from the constructor arg or connection extra.""" + model_id = constructor_value or 
conn_extra.get(extra_key) + if not model_id: + raise ValueError( + f"No {kind} model identifier set. Pass {extra_key}= to the hook " + f'constructor or set extra={{"{extra_key}": "model-name"}} on ' + "the connection." + ) + return model_id + + @staticmethod + def _connection_kwargs(conn: Any) -> dict[str, Any]: + """Return shared OpenAI-style kwargs (api_key, api_base) from the connection.""" kwargs: dict[str, Any] = {} if conn.password: kwargs["api_key"] = conn.password @@ -76,35 +146,44 @@ def get_embedding_model(self) -> BaseEmbedding: Return a LlamaIndex embedding model configured from the Airflow connection. Uses ``embed_conn_id`` (falls back to ``llm_conn_id``) for credentials. + Returns an ``OpenAIEmbedding`` instance; for other vendors, + instantiate the LlamaIndex class directly and pass it to the + operator's ``embed_model=`` parameter. """ - from llama_index.embeddings.openai import OpenAIEmbedding - - conn_kwargs = self._resolve_connection_kwargs(self.embed_conn_id) - return OpenAIEmbedding(model=self.embed_model, **conn_kwargs) + # Lazy: llama-index is an optional extra; importing at module level + # would break common.ai for users who haven't installed ``[llamaindex]``. + try: + from llama_index.embeddings.openai import OpenAIEmbedding + except ImportError as e: + raise AirflowOptionalProviderFeatureException(e) + + conn = self.get_connection(self.embed_conn_id) + model_id = self._resolve_model( + conn.extra_dejson, + constructor_value=self.embed_model, + extra_key="embed_model", + kind="embedding", + ) + return OpenAIEmbedding(model=model_id, **self._connection_kwargs(conn)) def get_llm(self) -> LLM: """ Return a LlamaIndex LLM configured from the Airflow connection. - Requires ``llm_model`` to be set on the hook. 
- """ - if not self.llm_model: - raise ValueError("llm_model must be set to use get_llm()") - - from llama_index.llms.openai import OpenAI - - conn_kwargs = self._resolve_connection_kwargs(self.llm_conn_id) - return OpenAI(model=self.llm_model, **conn_kwargs) - - def configure_settings(self) -> None: + Returns an ``OpenAI`` LLM instance; for other vendors, instantiate + the LlamaIndex class directly and pass it to the operator's ``llm=`` + parameter. """ - Configure LlamaIndex's global Settings with models from Airflow connections. - - Sets ``Settings.embed_model`` always, and ``Settings.llm`` when - ``llm_model`` is provided. - """ - from llama_index.core import Settings - - Settings.embed_model = self.get_embedding_model() - if self.llm_model: - Settings.llm = self.get_llm() + try: + from llama_index.llms.openai import OpenAI + except ImportError as e: + raise AirflowOptionalProviderFeatureException(e) + + conn = self.get_connection(self.llm_conn_id) + model_id = self._resolve_model( + conn.extra_dejson, + constructor_value=self.llm_model, + extra_key="llm_model", + kind="llm", + ) + return OpenAI(model=model_id, **self._connection_kwargs(conn)) diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py index acbbc46dabbc4..098fa01fa871b 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py @@ -18,88 +18,122 @@ from __future__ import annotations -import os from collections.abc import Sequence from typing import TYPE_CHECKING, Any -from airflow.providers.common.compat.sdk import BaseOperator +from airflow.providers.common.compat.sdk import ( + AirflowOptionalProviderFeatureException, + BaseOperator, +) if TYPE_CHECKING: from airflow.sdk import Context + from llama_index.core.base.embeddings.base import 
BaseEmbedding class EmbeddingOperator(BaseOperator): """ Chunk documents and produce embedding vectors using LlamaIndex. - Bridges document loading (Airflow provider hooks returning text) and - vector storage (pgvector, Pinecone, Weaviate ingest operators). Input - is ``list[dict]`` with ``text`` and ``metadata`` keys; output includes - the embedding vectors ready for downstream storage. + Bridges document loading (e.g. + :class:`~airflow.providers.common.ai.operators.document_loader.DocumentLoaderOperator` + output) and vector storage (pgvector, Pinecone, Weaviate, ...). Input is + ``list[dict]`` with ``text`` and ``metadata`` keys; output includes the + embedding vectors ready for downstream storage ingest. + + The operator passes the embedding model **directly** to + ``VectorStoreIndex(..., embed_model=...)`` -- it does not mutate + LlamaIndex's global ``Settings`` singleton, so concurrent tasks in the + same worker don't race on shared state. :param documents: List of dicts with ``text`` and ``metadata`` keys, - typically from ``DocumentLoaderOperator`` or a ``@task``. - :param llm_conn_id: Airflow connection ID for the embedding API. - :param embed_model: Embedding model name (default: ``text-embedding-3-small``). - :param chunk_size: Chunk size for the sentence splitter (default: 512). - :param chunk_overlap: Overlap between chunks (default: 50). - :param persist_dir: Optional directory path to persist the LlamaIndex - index for later retrieval. + typically from ``DocumentLoaderOperator`` or a ``@task``. Bind via + ``my_loader.output`` (XCom direct), **not** via Jinja -- ``list[dict]`` + does not survive Jinja stringification. + :param embed_model: Either: + + * a string model name (e.g. ``"text-embedding-3-small"``) -- the + operator constructs an :class:`~.LlamaIndexHook`-backed + ``OpenAIEmbedding`` from ``llm_conn_id``, or + * a pre-built ``BaseEmbedding`` instance -- bypass the hook + entirely for non-OpenAI vendors (e.g. 
+ ``CohereEmbedding(...)``, ``BedrockEmbedding(...)``). + + :param llm_conn_id: Airflow connection ID for the embedding API. Used + only when ``embed_model`` is a string (or omitted entirely, falling + back to ``extra["embed_model"]`` on the connection). + :param chunk_size: Chunk size for the sentence splitter. + :param chunk_overlap: Overlap between chunks. + :param persist_dir: Optional path to persist the index. Accepts local + paths and storage URIs (``s3://``, ``gs://``, ...) resolved via + :class:`~airflow.sdk.ObjectStoragePath`. + :param persist_conn_id: Airflow connection ID for cloud-storage + credentials when ``persist_dir`` is a URI. """ - template_fields: Sequence[str] = ("documents", "llm_conn_id", "persist_dir") + template_fields: Sequence[str] = ( + "llm_conn_id", + "persist_dir", + "persist_conn_id", + ) def __init__( self, *, documents: list[dict[str, Any]], - llm_conn_id: str = "pydanticai_default", - embed_model: str = "text-embedding-3-small", + embed_model: str | BaseEmbedding | None = None, + llm_conn_id: str = "llamaindex_default", chunk_size: int = 512, chunk_overlap: int = 50, persist_dir: str | None = None, + persist_conn_id: str | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.documents = documents - self.llm_conn_id = llm_conn_id self.embed_model = embed_model + self.llm_conn_id = llm_conn_id self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap self.persist_dir = persist_dir + self.persist_conn_id = persist_conn_id def execute(self, context: Context) -> dict[str, Any]: - from llama_index.core import Document, StorageContext, VectorStoreIndex - from llama_index.core.node_parser import SentenceSplitter + try: + from llama_index.core import Document, StorageContext, VectorStoreIndex + from llama_index.core.node_parser import SentenceSplitter + except ImportError as e: + raise AirflowOptionalProviderFeatureException(e) - from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook + embed_model 
= self._resolve_embed_model() - hook = LlamaIndexHook(llm_conn_id=self.llm_conn_id, embed_model=self.embed_model) - hook.configure_settings() - - llama_docs = [Document(text=doc["text"], metadata=doc.get("metadata", {})) for doc in self.documents] + llama_docs = [ + Document(text=doc["text"], metadata=doc.get("metadata", {})) for doc in self.documents + ] splitter = SentenceSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) nodes = splitter.get_nodes_from_documents(llama_docs) self.log.info("Split %d documents into %d chunks", len(llama_docs), len(nodes)) - storage_context = StorageContext.from_defaults() - VectorStoreIndex(nodes, storage_context=storage_context, show_progress=False) + # ``VectorStoreIndex(...)`` populates each node's ``.embedding`` as a + # side effect of building the index; capture the index so the + # variable isn't discarded (also lets future enhancements query it + # before persistence). + index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=False) if self.persist_dir: - os.makedirs(self.persist_dir, exist_ok=True) - storage_context.persist(persist_dir=self.persist_dir) - self.log.info("Index persisted to %s", self.persist_dir) + self._persist(index) - chunks = [] - for node in nodes: - chunk: dict[str, Any] = { + chunks = [ + { "text": node.text, "metadata": node.metadata, + # ``node.embedding`` is populated by ``VectorStoreIndex`` for + # every node since we forced an in-memory build above. + "vector": node.embedding, } - if node.embedding: - chunk["vector"] = node.embedding - chunks.append(chunk) + for node in nodes + ] return { "document_count": len(llama_docs), @@ -107,3 +141,38 @@ def execute(self, context: Context) -> dict[str, Any]: "persist_dir": self.persist_dir, "chunks": chunks, } + + def _resolve_embed_model(self) -> BaseEmbedding: + """ + Return a ready-to-use ``BaseEmbedding``. 
+ + If ``embed_model`` is a string or ``None``, build one via + ``LlamaIndexHook`` (OpenAI from the configured Airflow connection). + Anything else is treated as a pre-built ``BaseEmbedding`` instance + (user brought their own) and returned as-is. Avoids + ``isinstance(.., BaseEmbedding)`` so the check doesn't trigger an + otherwise-unnecessary ``llama_index`` import. + """ + if self.embed_model is None or isinstance(self.embed_model, str): + from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook + + return LlamaIndexHook( + llm_conn_id=self.llm_conn_id, + embed_model=self.embed_model, + ).get_embedding_model() + return self.embed_model + + def _persist(self, index: Any) -> None: + """Persist the index to ``persist_dir``; cloud URIs go through ObjectStoragePath.""" + if "://" in self.persist_dir: # type: ignore[operator] + from airflow.sdk import ObjectStoragePath + + target = ObjectStoragePath(self.persist_dir, conn_id=self.persist_conn_id) + target.mkdir(parents=True, exist_ok=True) + index.storage_context.persist(persist_dir=str(target), fs=target.fs) + else: + import os + + os.makedirs(self.persist_dir, exist_ok=True) # type: ignore[arg-type] + index.storage_context.persist(persist_dir=self.persist_dir) + self.log.info("Index persisted to %s", self.persist_dir) diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py index 6089f7a4c628d..86050b4e7e91d 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py @@ -21,76 +21,139 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any -from airflow.providers.common.compat.sdk import BaseOperator +from airflow.providers.common.compat.sdk import ( + AirflowOptionalProviderFeatureException, + BaseOperator, +) if 
TYPE_CHECKING: from airflow.sdk import Context + from llama_index.core.base.embeddings.base import BaseEmbedding class RetrievalOperator(BaseOperator): """ Retrieve relevant document chunks from a persisted LlamaIndex index. - Loads a previously persisted vector store index and performs similarity - search against the provided query. The output is a list of chunks with - text, score, metadata, and source information ready for downstream - synthesis via ``LLMOperator``. - - :param query: The query string to search for. Supports Jinja templating. - :param index_persist_dir: Path to the persisted LlamaIndex index directory. - :param llm_conn_id: Airflow connection ID for the embedding API - (needed to embed the query vector). - :param embed_model: Embedding model name (default: ``text-embedding-3-small``). - :param top_k: Number of top results to retrieve (default: 5). + Loads a previously persisted vector store index (from + ``EmbeddingOperator(persist_dir=...)``) and performs similarity search + against the provided query. Output is a list of chunks with text, + score, metadata, and node id, ready for downstream synthesis via + :class:`~airflow.providers.common.ai.operators.llm.LLMOperator`. + + Passes the embedding model **directly** to + ``load_index_from_storage(..., embed_model=...)`` -- no LlamaIndex + ``Settings`` mutation, so concurrent tasks in the same worker don't + race on shared state. + + :param query: The query string. Supports Jinja templating. + :param index_persist_dir: Local path or storage URI (``s3://``, + ``gs://``, ...) pointing at the persisted LlamaIndex index. + Resolved via :class:`~airflow.sdk.ObjectStoragePath` when a URI + scheme is present. + :param persist_conn_id: Airflow connection ID for cloud-storage + credentials when ``index_persist_dir`` is a URI. + :param embed_model: Either: + + * a string model name (e.g. 
``"text-embedding-3-small"``) -- the + operator constructs an :class:`~.LlamaIndexHook`-backed + ``OpenAIEmbedding`` from ``llm_conn_id``, or + * a pre-built ``BaseEmbedding`` instance -- bypass the hook for + non-OpenAI vendors. Must match the embedding model used when + the index was originally built. + + :param llm_conn_id: Airflow connection ID for the embedding API. Used + only when ``embed_model`` is a string (or omitted entirely). + :param top_k: Number of top results to retrieve. """ - template_fields: Sequence[str] = ("query", "index_persist_dir", "llm_conn_id") + template_fields: Sequence[str] = ( + "query", + "index_persist_dir", + "persist_conn_id", + "llm_conn_id", + ) def __init__( self, *, query: str, index_persist_dir: str, - llm_conn_id: str = "pydanticai_default", - embed_model: str = "text-embedding-3-small", + persist_conn_id: str | None = None, + embed_model: str | BaseEmbedding | None = None, + llm_conn_id: str = "llamaindex_default", top_k: int = 5, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.query = query self.index_persist_dir = index_persist_dir - self.llm_conn_id = llm_conn_id + self.persist_conn_id = persist_conn_id self.embed_model = embed_model + self.llm_conn_id = llm_conn_id self.top_k = top_k def execute(self, context: Context) -> dict[str, Any]: - from llama_index.core import StorageContext, load_index_from_storage - - from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook + try: + from llama_index.core import StorageContext, load_index_from_storage + except ImportError as e: + raise AirflowOptionalProviderFeatureException(e) - hook = LlamaIndexHook(llm_conn_id=self.llm_conn_id, embed_model=self.embed_model) - hook.configure_settings() - - storage_context = StorageContext.from_defaults(persist_dir=self.index_persist_dir) - index = load_index_from_storage(storage_context) + embed_model = self._resolve_embed_model() + storage_context = self._open_storage_context(StorageContext) + index = 
load_index_from_storage(storage_context, embed_model=embed_model) retriever = index.as_retriever(similarity_top_k=self.top_k) results = retriever.retrieve(self.query) self.log.info("Retrieved %d chunks for query: %s", len(results), self.query[:100]) - chunks = [] - for node_with_score in results: - node = node_with_score.node - chunks.append( - { - "text": node.get_content(), - "score": node_with_score.score, - "metadata": node.metadata, - "source": node.node_id, - } - ) + chunks = [ + { + "text": node_with_score.node.get_content(), + "score": node_with_score.score, + "metadata": node_with_score.node.metadata, + "node_id": node_with_score.node.node_id, + } + for node_with_score in results + ] return { - "question": self.query, + "query": self.query, "chunks": chunks, } + + def _resolve_embed_model(self) -> BaseEmbedding: + """String / ``None`` -> hook; anything else -> pre-built instance.""" + if self.embed_model is None or isinstance(self.embed_model, str): + from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook + + return LlamaIndexHook( + llm_conn_id=self.llm_conn_id, + embed_model=self.embed_model, + ).get_embedding_model() + return self.embed_model + + def _open_storage_context(self, storage_context_cls: Any) -> Any: + """Open a ``StorageContext`` from a local path or storage URI.""" + if "://" in self.index_persist_dir: + from airflow.sdk import ObjectStoragePath + + source = ObjectStoragePath(self.index_persist_dir, conn_id=self.persist_conn_id) + if not source.is_dir(): + raise FileNotFoundError( + f"Persisted LlamaIndex index not found at '{self.index_persist_dir}'. " + "Did you run EmbeddingOperator with the same persist_dir first?" + ) + return storage_context_cls.from_defaults( + persist_dir=str(source), + fs=source.fs, + ) + + from pathlib import Path + + if not Path(self.index_persist_dir).is_dir(): + raise FileNotFoundError( + f"Persisted LlamaIndex index not found at '{self.index_persist_dir}'. 
" + "Did you run EmbeddingOperator with the same persist_dir first?" + ) + return storage_context_cls.from_defaults(persist_dir=self.index_persist_dir) diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py b/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py index 3b119e3e5b439..0a4ae14d4271a 100644 --- a/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py +++ b/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import sys from unittest.mock import MagicMock, patch import pytest @@ -23,174 +24,172 @@ from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook +@pytest.fixture(autouse=True) +def _stub_llama_index_modules(): + """Stub ``llama_index.*`` in sys.modules so @patch can resolve targets + without the real package installed (mirrors the langchain pattern from + apache/airflow#67237). + """ + li = MagicMock() + mocks = { + "llama_index": li, + "llama_index.core": li.core, + "llama_index.core.base": li.core.base, + "llama_index.core.base.embeddings": li.core.base.embeddings, + "llama_index.core.base.embeddings.base": li.core.base.embeddings.base, + "llama_index.core.llms": li.core.llms, + "llama_index.core.llms.llm": li.core.llms.llm, + "llama_index.embeddings": li.embeddings, + "llama_index.embeddings.openai": li.embeddings.openai, + "llama_index.llms": li.llms, + "llama_index.llms.openai": li.llms.openai, + } + with patch.dict(sys.modules, mocks): + yield + + +def _conn(password: str = "", host: str = "", extra: dict | None = None) -> MagicMock: + mock_conn = MagicMock() + mock_conn.password = password + mock_conn.host = host + mock_conn.extra_dejson = extra or {} + return mock_conn + + class TestLlamaIndexHookInit: def test_default_params(self): hook = LlamaIndexHook() - assert hook.llm_conn_id == "pydanticai_default" - assert hook.embed_conn_id == "pydanticai_default" - assert hook.embed_model == 
"text-embedding-3-small" + assert hook.llm_conn_id == "llamaindex_default" + assert hook.embed_conn_id == "llamaindex_default" + assert hook.embed_model is None assert hook.llm_model is None - def test_separate_embed_conn_id(self): - hook = LlamaIndexHook(llm_conn_id="llm_conn", embed_conn_id="embed_conn") - assert hook.llm_conn_id == "llm_conn" - assert hook.embed_conn_id == "embed_conn" - - def test_embed_conn_defaults_to_llm_conn(self): + def test_embed_conn_falls_back_to_llm_conn(self): hook = LlamaIndexHook(llm_conn_id="my_conn") assert hook.embed_conn_id == "my_conn" + def test_explicit_separate_conns_and_models(self): + hook = LlamaIndexHook( + llm_conn_id="chat_conn", + embed_conn_id="embed_conn", + embed_model="text-embedding-3-large", + llm_model="gpt-4o", + ) + assert hook.llm_conn_id == "chat_conn" + assert hook.embed_conn_id == "embed_conn" + assert hook.embed_model == "text-embedding-3-large" + assert hook.llm_model == "gpt-4o" + + def test_conn_type_is_llamaindex(self): + assert LlamaIndexHook.conn_type == "llamaindex" + assert LlamaIndexHook.default_conn_name == "llamaindex_default" + assert LlamaIndexHook.conn_name_attr == "llm_conn_id" + assert LlamaIndexHook.hook_name == "LlamaIndex" + + +class TestGetUiFieldBehaviour: + def test_shape(self): + behaviour = LlamaIndexHook.get_ui_field_behaviour() + assert behaviour["hidden_fields"] == ["schema", "port", "login"] + assert behaviour["relabeling"] == {"password": "API Key"} + assert "host" in behaviour["placeholders"] + assert "embed_model" in behaviour["placeholders"]["extra"] + assert "llm_model" in behaviour["placeholders"]["extra"] + + +class TestResolveModel: + def test_constructor_wins_over_extra(self): + result = LlamaIndexHook._resolve_model( + {"embed_model": "old"}, + constructor_value="new", + extra_key="embed_model", + kind="embedding", + ) + assert result == "new" + + def test_falls_back_to_extra(self): + result = LlamaIndexHook._resolve_model( + {"embed_model": "from-extra"}, + 
constructor_value=None, + extra_key="embed_model", + kind="embedding", + ) + assert result == "from-extra" + + def test_raises_when_neither_set(self): + with pytest.raises(ValueError, match="No embedding model identifier set"): + LlamaIndexHook._resolve_model( + {}, + constructor_value=None, + extra_key="embed_model", + kind="embedding", + ) -class TestResolveConnectionKwargs: - @patch.object(LlamaIndexHook, "get_connection") - def test_extracts_password_as_api_key(self, mock_get_conn): - mock_conn = MagicMock() - mock_conn.password = "sk-test-key" - mock_conn.host = "" - mock_get_conn.return_value = mock_conn - - hook = LlamaIndexHook() - result = hook._resolve_connection_kwargs("test_conn") - - assert result == {"api_key": "sk-test-key"} +class TestGetEmbeddingModel: + @patch("llama_index.embeddings.openai.OpenAIEmbedding") @patch.object(LlamaIndexHook, "get_connection") - def test_extracts_host_as_api_base(self, mock_get_conn): - mock_conn = MagicMock() - mock_conn.password = "" - mock_conn.host = "https://custom.api.com" - mock_get_conn.return_value = mock_conn + def test_dispatches_with_api_key(self, mock_get_conn, mock_cls): + mock_get_conn.return_value = _conn(password="sk-test") + hook = LlamaIndexHook(embed_model="text-embedding-3-small") - hook = LlamaIndexHook() - result = hook._resolve_connection_kwargs("test_conn") + result = hook.get_embedding_model() - assert result == {"api_base": "https://custom.api.com"} + mock_get_conn.assert_called_once_with("llamaindex_default") + mock_cls.assert_called_once_with(model="text-embedding-3-small", api_key="sk-test") + assert result is mock_cls.return_value + @patch("llama_index.embeddings.openai.OpenAIEmbedding") @patch.object(LlamaIndexHook, "get_connection") - def test_both_password_and_host(self, mock_get_conn): - mock_conn = MagicMock() - mock_conn.password = "sk-key" - mock_conn.host = "https://api.example.com" - mock_get_conn.return_value = mock_conn + def test_dispatches_with_api_base(self, mock_get_conn, 
mock_cls): + mock_get_conn.return_value = _conn(password="sk-test", host="http://localhost:11434/v1") + hook = LlamaIndexHook(embed_model="text-embedding-3-small") - hook = LlamaIndexHook() - result = hook._resolve_connection_kwargs("test_conn") + hook.get_embedding_model() - assert result == {"api_key": "sk-key", "api_base": "https://api.example.com"} + mock_cls.assert_called_once_with( + model="text-embedding-3-small", + api_key="sk-test", + api_base="http://localhost:11434/v1", + ) + @patch("llama_index.embeddings.openai.OpenAIEmbedding") @patch.object(LlamaIndexHook, "get_connection") - def test_empty_fields_return_empty_dict(self, mock_get_conn): - mock_conn = MagicMock() - mock_conn.password = "" - mock_conn.host = "" - mock_get_conn.return_value = mock_conn - + def test_resolves_model_from_extra(self, mock_get_conn, mock_cls): + mock_get_conn.return_value = _conn( + password="sk-test", extra={"embed_model": "text-embedding-3-large"} + ) hook = LlamaIndexHook() - result = hook._resolve_connection_kwargs("test_conn") - - assert result == {} - - -def _make_mock_openai_embedding_module(): - mock_module = MagicMock() - mock_cls = MagicMock() - mock_module.OpenAIEmbedding = mock_cls - return mock_module, mock_cls - -def _make_mock_openai_llm_module(): - mock_module = MagicMock() - mock_cls = MagicMock() - mock_module.OpenAI = mock_cls - return mock_module, mock_cls + hook.get_embedding_model() + mock_cls.assert_called_once_with(model="text-embedding-3-large", api_key="sk-test") -class TestGetEmbeddingModel: @patch.object(LlamaIndexHook, "get_connection") - def test_returns_openai_embedding(self, mock_get_conn): - mock_conn = MagicMock() - mock_conn.password = "sk-test" - mock_conn.host = "" - mock_get_conn.return_value = mock_conn - - mock_embed_module, mock_embed_cls = _make_mock_openai_embedding_module() - - hook = LlamaIndexHook(embed_model="text-embedding-3-large") - with patch.dict("sys.modules", {"llama_index.embeddings.openai": mock_embed_module}): - result 
= hook.get_embedding_model() - - mock_embed_cls.assert_called_once_with(model="text-embedding-3-large", api_key="sk-test") - assert result == mock_embed_cls.return_value - - -class TestGetLLM: - def test_raises_without_llm_model(self): + def test_raises_when_no_model_anywhere(self, mock_get_conn): + mock_get_conn.return_value = _conn(password="sk-test") hook = LlamaIndexHook() - with pytest.raises(ValueError, match="llm_model must be set"): - hook.get_llm() - @patch.object(LlamaIndexHook, "get_connection") - def test_returns_openai_llm(self, mock_get_conn): - mock_conn = MagicMock() - mock_conn.password = "sk-test" - mock_conn.host = "" - mock_get_conn.return_value = mock_conn + with pytest.raises(ValueError, match="No embedding model identifier set"): + hook.get_embedding_model() - mock_llm_module, mock_llm_cls = _make_mock_openai_llm_module() +class TestGetLlm: + @patch("llama_index.llms.openai.OpenAI") + @patch.object(LlamaIndexHook, "get_connection") + def test_dispatches_with_api_key(self, mock_get_conn, mock_cls): + mock_get_conn.return_value = _conn(password="sk-test") hook = LlamaIndexHook(llm_model="gpt-4o") - with patch.dict("sys.modules", {"llama_index.llms.openai": mock_llm_module}): - result = hook.get_llm() - mock_llm_cls.assert_called_once_with(model="gpt-4o", api_key="sk-test") - assert result == mock_llm_cls.return_value + result = hook.get_llm() + mock_cls.assert_called_once_with(model="gpt-4o", api_key="sk-test") + assert result is mock_cls.return_value -class TestConfigureSettings: @patch.object(LlamaIndexHook, "get_connection") - def test_sets_embed_model(self, mock_get_conn): - mock_conn = MagicMock() - mock_conn.password = "sk-test" - mock_conn.host = "" - mock_get_conn.return_value = mock_conn - - mock_embed_module, mock_embed_cls = _make_mock_openai_embedding_module() - mock_settings_module = MagicMock() - + def test_raises_when_no_llm_model(self, mock_get_conn): + mock_get_conn.return_value = _conn(password="sk-test") hook = 
LlamaIndexHook() - with patch.dict( - "sys.modules", - { - "llama_index.embeddings.openai": mock_embed_module, - "llama_index": MagicMock(), - "llama_index.core": mock_settings_module, - }, - ): - hook.configure_settings() - - assert mock_settings_module.Settings.embed_model == mock_embed_cls.return_value - - @patch.object(LlamaIndexHook, "get_connection") - def test_sets_llm_when_model_provided(self, mock_get_conn): - mock_conn = MagicMock() - mock_conn.password = "sk-test" - mock_conn.host = "" - mock_get_conn.return_value = mock_conn - - mock_embed_module, _ = _make_mock_openai_embedding_module() - mock_llm_module, mock_llm_cls = _make_mock_openai_llm_module() - mock_settings_module = MagicMock() - hook = LlamaIndexHook(llm_model="gpt-4o") - with patch.dict( - "sys.modules", - { - "llama_index.embeddings.openai": mock_embed_module, - "llama_index.llms.openai": mock_llm_module, - "llama_index": MagicMock(), - "llama_index.core": mock_settings_module, - }, - ): - hook.configure_settings() - - assert mock_settings_module.Settings.llm == mock_llm_cls.return_value + with pytest.raises(ValueError, match="No llm model identifier set"): + hook.get_llm() diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py index ee3e8c51a562d..0077b075a252d 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py @@ -16,187 +16,170 @@ # under the License. 
from __future__ import annotations +import sys from unittest.mock import MagicMock, patch -from airflow.providers.common.ai.operators.llamaindex_embedding import EmbeddingOperator +import pytest +from airflow.providers.common.ai.operators.llamaindex_embedding import EmbeddingOperator -def _make_mock_node(text="chunk text", metadata=None, embedding=None): - node = MagicMock() - node.text = text - node.metadata = metadata or {} - node.embedding = embedding - return node +@pytest.fixture +def _stub_li(monkeypatch): + """Patch the few ``llama_index`` symbols the operator imports.""" + Document = MagicMock(name="Document") + StorageContext = MagicMock(name="StorageContext") + VectorStoreIndex = MagicMock(name="VectorStoreIndex") + SentenceSplitter = MagicMock(name="SentenceSplitter") -def _make_mock_llamaindex_modules(nodes=None): - """Create mock llama_index modules for sys.modules injection.""" - if nodes is None: - nodes = [_make_mock_node()] - - mock_core = MagicMock() - mock_core.Document = MagicMock(side_effect=lambda text, metadata: MagicMock(text=text, metadata=metadata)) - mock_core.StorageContext.from_defaults.return_value = MagicMock() - mock_core.VectorStoreIndex = MagicMock() - - mock_node_parser = MagicMock() - mock_splitter = MagicMock() - mock_splitter.get_nodes_from_documents.return_value = nodes - mock_node_parser.SentenceSplitter.return_value = mock_splitter - - return ( - { - "llama_index": MagicMock(), - "llama_index.core": mock_core, - "llama_index.core.node_parser": mock_node_parser, - "llama_index.embeddings": MagicMock(), - "llama_index.embeddings.openai": MagicMock(), - }, - mock_core, - mock_splitter, + li_core = MagicMock( + Document=Document, + StorageContext=StorageContext, + VectorStoreIndex=VectorStoreIndex, ) + li_core_np = MagicMock(SentenceSplitter=SentenceSplitter) + monkeypatch.setitem(sys.modules, "llama_index", MagicMock()) + monkeypatch.setitem(sys.modules, "llama_index.core", li_core) + monkeypatch.setitem(sys.modules, 
"llama_index.core.node_parser", li_core_np) -class TestEmbeddingOperator: - def test_template_fields(self): - expected = {"documents", "llm_conn_id", "persist_dir"} - assert set(EmbeddingOperator.template_fields) == expected - - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_execute_returns_expected_shape(self, mock_hook_cls): - docs = [{"text": "Hello world", "metadata": {"source": "test"}}] - nodes = [_make_mock_node(text="Hello world", metadata={"source": "test"})] - mock_modules, mock_core, mock_splitter = _make_mock_llamaindex_modules(nodes) - - op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn") - - with patch.dict("sys.modules", mock_modules): - result = op.execute(context=MagicMock()) - - assert "document_count" in result - assert "chunk_count" in result - assert "persist_dir" in result - assert "chunks" in result - assert result["document_count"] == 1 - assert result["chunk_count"] == 1 - - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_chunking_node_count(self, mock_hook_cls): - docs = [{"text": "A long document " * 100, "metadata": {}}] - nodes = [_make_mock_node(text=f"chunk {i}") for i in range(5)] - mock_modules, mock_core, mock_splitter = _make_mock_llamaindex_modules(nodes) - - op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn") - - with patch.dict("sys.modules", mock_modules): - result = op.execute(context=MagicMock()) - - assert result["chunk_count"] == 5 - assert len(result["chunks"]) == 5 - - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_persist_dir_creates_and_persists(self, mock_hook_cls, tmp_path): - docs = [{"text": "test", "metadata": {}}] - persist_dir = str(tmp_path / "index_storage") - mock_modules, mock_core, _ = _make_mock_llamaindex_modules() - mock_storage_ctx = mock_core.StorageContext.from_defaults.return_value - - op = 
EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn", persist_dir=persist_dir) - - with patch.dict("sys.modules", mock_modules): - op.execute(context=MagicMock()) - - mock_storage_ctx.persist.assert_called_once_with(persist_dir=persist_dir) + return { + "Document": Document, + "StorageContext": StorageContext, + "VectorStoreIndex": VectorStoreIndex, + "SentenceSplitter": SentenceSplitter, + } - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_no_persist_when_none(self, mock_hook_cls): - docs = [{"text": "test", "metadata": {}}] - mock_modules, mock_core, _ = _make_mock_llamaindex_modules() - mock_storage_ctx = mock_core.StorageContext.from_defaults.return_value - op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn") +def _node(text: str = "chunk text", metadata: dict | None = None, vector=None): + node = MagicMock() + node.text = text + node.metadata = metadata or {} + node.embedding = vector + return node - with patch.dict("sys.modules", mock_modules): - op.execute(context=MagicMock()) - mock_storage_ctx.persist.assert_not_called() +class TestEmbeddingOperatorInit: + def test_documents_not_templated(self): + # ``documents`` is ``list[dict]`` -- Jinja stringification would + # break it. Explicitly out of template_fields. 
+ assert "documents" not in EmbeddingOperator.template_fields - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_chunks_have_text_and_metadata(self, mock_hook_cls): - docs = [{"text": "test", "metadata": {"src": "a"}}] - nodes = [_make_mock_node(text="chunk1", metadata={"src": "a"})] - mock_modules, _, _ = _make_mock_llamaindex_modules(nodes) + def test_templated_fields(self): + assert set(EmbeddingOperator.template_fields) == { + "llm_conn_id", + "persist_dir", + "persist_conn_id", + } - op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn") - with patch.dict("sys.modules", mock_modules): - result = op.execute(context=MagicMock()) +class TestEmbeddingOperatorExecute: + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") + def test_string_embed_model_goes_through_hook(self, mock_get_embed, _stub_li): + # `embed_model` as a string -> hook builds OpenAIEmbedding. + _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ + _node(text="chunk a", vector=[0.1, 0.2]), + ] - chunk = result["chunks"][0] - assert "text" in chunk - assert "metadata" in chunk - assert chunk["text"] == "chunk1" - assert chunk["metadata"] == {"src": "a"} + op = EmbeddingOperator( + task_id="test", + documents=[{"text": "doc", "metadata": {"src": "x"}}], + embed_model="text-embedding-3-small", + llm_conn_id="my_conn", + ) + result = op.execute(context=MagicMock()) - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_chunks_include_vector_when_present(self, mock_hook_cls): - docs = [{"text": "test", "metadata": {}}] - nodes = [_make_mock_node(text="chunk1", embedding=[0.1, 0.2, 0.3])] - mock_modules, _, _ = _make_mock_llamaindex_modules(nodes) + mock_get_embed.assert_called_once() + assert result["document_count"] == 1 + assert result["chunk_count"] == 1 + assert result["chunks"][0]["text"] == "chunk a" + assert 
result["chunks"][0]["vector"] == [0.1, 0.2] - op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn") + def test_byo_embed_model_bypasses_hook(self, _stub_li): + # `embed_model` is a non-string instance -> hook is bypassed. + byo = MagicMock(name="MyBaseEmbedding") + _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ + _node() + ] - with patch.dict("sys.modules", mock_modules): - result = op.execute(context=MagicMock()) + op = EmbeddingOperator( + task_id="test", + documents=[{"text": "doc"}], + embed_model=byo, + ) + op.execute(context=MagicMock()) - assert result["chunks"][0]["vector"] == [0.1, 0.2, 0.3] + # VectorStoreIndex called with the user's instance, not anything else. + _stub_li["VectorStoreIndex"].assert_called_once() + kwargs = _stub_li["VectorStoreIndex"].call_args.kwargs + assert kwargs["embed_model"] is byo - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_chunks_omit_vector_when_not_present(self, mock_hook_cls): - docs = [{"text": "test", "metadata": {}}] - nodes = [_make_mock_node(text="chunk1", embedding=None)] - mock_modules, _, _ = _make_mock_llamaindex_modules(nodes) + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") + def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _stub_li): + _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ + _node(text="x", metadata={"k": "v"}, vector=[1.0, 2.0]), + _node(text="y", metadata={"k": "v2"}, vector=[3.0, 4.0]), + ] - op = EmbeddingOperator(task_id="test", documents=docs, llm_conn_id="my_conn") + op = EmbeddingOperator( + task_id="test", + documents=[{"text": "doc"}], + embed_model="text-embedding-3-small", + ) + result = op.execute(context=MagicMock()) - with patch.dict("sys.modules", mock_modules): - result = op.execute(context=MagicMock()) + assert result["chunks"] == [ + {"text": "x", "metadata": {"k": "v"}, 
"vector": [1.0, 2.0]}, + {"text": "y", "metadata": {"k": "v2"}, "vector": [3.0, 4.0]}, + ] - assert "vector" not in result["chunks"][0] - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_hook_configured_with_params(self, mock_hook_cls): - docs = [{"text": "test", "metadata": {}}] - mock_modules, _, _ = _make_mock_llamaindex_modules() +class TestEmbeddingOperatorPersist: + @patch("os.makedirs") + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") + def test_local_persist_dir_calls_makedirs_and_storage_persist( + self, mock_get_embed, mock_makedirs, _stub_li, tmp_path + ): + node = _node() + _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] + index = _stub_li["VectorStoreIndex"].return_value op = EmbeddingOperator( task_id="test", - documents=docs, - llm_conn_id="custom_conn", - embed_model="text-embedding-ada-002", + documents=[{"text": "doc"}], + embed_model="text-embedding-3-small", + persist_dir=str(tmp_path / "idx"), ) + op.execute(context=MagicMock()) - with patch.dict("sys.modules", mock_modules): - op.execute(context=MagicMock()) + mock_makedirs.assert_called_once_with(str(tmp_path / "idx"), exist_ok=True) + index.storage_context.persist.assert_called_once_with(persist_dir=str(tmp_path / "idx")) - mock_hook_cls.assert_called_once_with(llm_conn_id="custom_conn", embed_model="text-embedding-ada-002") - mock_hook_cls.return_value.configure_settings.assert_called_once() + @patch("airflow.sdk.ObjectStoragePath") + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") + def test_cloud_uri_persist_dir_uses_object_storage_path( + self, mock_get_embed, mock_osp_cls, _stub_li + ): + target = MagicMock() + target.__str__.return_value = "s3://bucket/idx/" + target.fs = MagicMock(name="s3fs") + mock_osp_cls.return_value = target - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) 
- def test_splitter_params_forwarded(self, mock_hook_cls): - docs = [{"text": "test", "metadata": {}}] - mock_modules, _, _ = _make_mock_llamaindex_modules() - mock_node_parser = mock_modules["llama_index.core.node_parser"] + node = _node() + _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] + index = _stub_li["VectorStoreIndex"].return_value op = EmbeddingOperator( task_id="test", - documents=docs, - llm_conn_id="my_conn", - chunk_size=256, - chunk_overlap=25, + documents=[{"text": "doc"}], + embed_model="text-embedding-3-small", + persist_dir="s3://bucket/idx/", + persist_conn_id="aws_default", ) + op.execute(context=MagicMock()) - with patch.dict("sys.modules", mock_modules): - op.execute(context=MagicMock()) - - mock_node_parser.SentenceSplitter.assert_called_once_with(chunk_size=256, chunk_overlap=25) + mock_osp_cls.assert_called_once_with("s3://bucket/idx/", conn_id="aws_default") + target.mkdir.assert_called_once_with(parents=True, exist_ok=True) + index.storage_context.persist.assert_called_once_with( + persist_dir="s3://bucket/idx/", fs=target.fs + ) diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py index 0c85e86c214dc..6291aab189566 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py @@ -16,184 +16,167 @@ # under the License. 
from __future__ import annotations +import sys from unittest.mock import MagicMock, patch +import pytest + from airflow.providers.common.ai.operators.llamaindex_retrieval import RetrievalOperator -def _make_mock_node_with_score(text="chunk text", score=0.9, metadata=None, node_id="node-1"): - node = MagicMock() - node.get_content.return_value = text - node.metadata = metadata or {} - node.node_id = node_id +@pytest.fixture +def _stub_li(monkeypatch): + """Patch the few ``llama_index`` symbols the retrieval operator imports.""" + StorageContext = MagicMock(name="StorageContext") + load_index_from_storage = MagicMock(name="load_index_from_storage") - node_with_score = MagicMock() - node_with_score.node = node - node_with_score.score = score - return node_with_score - - -def _make_mock_llamaindex_modules(retrieval_results=None): - """Create mock llama_index modules for sys.modules injection.""" - if retrieval_results is None: - retrieval_results = [_make_mock_node_with_score()] - - mock_core = MagicMock() - mock_index = MagicMock() - mock_retriever = MagicMock() - mock_retriever.retrieve.return_value = retrieval_results - mock_index.as_retriever.return_value = mock_retriever - mock_core.load_index_from_storage.return_value = mock_index - - return ( - { - "llama_index": MagicMock(), - "llama_index.core": mock_core, - "llama_index.embeddings": MagicMock(), - "llama_index.embeddings.openai": MagicMock(), - }, - mock_core, - mock_index, - mock_retriever, + li_core = MagicMock( + StorageContext=StorageContext, + load_index_from_storage=load_index_from_storage, ) + monkeypatch.setitem(sys.modules, "llama_index", MagicMock()) + monkeypatch.setitem(sys.modules, "llama_index.core", li_core) -class TestRetrievalOperator: - def test_template_fields(self): - expected = {"query", "index_persist_dir", "llm_conn_id"} - assert set(RetrievalOperator.template_fields) == expected + return { + "StorageContext": StorageContext, + "load_index_from_storage": load_index_from_storage, + } - 
@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_execute_returns_expected_shape(self, mock_hook_cls): - results = [_make_mock_node_with_score(text="relevant chunk", score=0.95)] - mock_modules, mock_core, _, _ = _make_mock_llamaindex_modules(results) - op = RetrievalOperator( - task_id="test", - query="What is Airflow?", - index_persist_dir="/tmp/index", - llm_conn_id="my_conn", - ) - - with patch.dict("sys.modules", mock_modules): - result = op.execute(context=MagicMock()) - - assert "question" in result - assert "chunks" in result - assert result["question"] == "What is Airflow?" - assert len(result["chunks"]) == 1 - - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_chunks_have_required_keys(self, mock_hook_cls): - results = [ - _make_mock_node_with_score( - text="chunk text", score=0.8, metadata={"file": "doc.txt"}, node_id="abc-123" - ) +def _scored_node(text: str, score: float, metadata: dict | None = None, node_id: str = "n"): + node = MagicMock() + node.get_content.return_value = text + node.metadata = metadata or {} + node.node_id = node_id + wrapped = MagicMock() + wrapped.node = node + wrapped.score = score + return wrapped + + +class TestRetrievalOperatorOutput: + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") + def test_chunk_shape(self, mock_get_embed, _stub_li, tmp_path): + # Make the persist_dir existence check pass. 
+ (tmp_path / "idx").mkdir() + + index = _stub_li["load_index_from_storage"].return_value + retriever = index.as_retriever.return_value + retriever.retrieve.return_value = [ + _scored_node("chunk a", 0.91, {"src": "x"}, "node-a"), + _scored_node("chunk b", 0.85, {"src": "y"}, "node-b"), ] - mock_modules, _, _, _ = _make_mock_llamaindex_modules(results) op = RetrievalOperator( task_id="test", - query="test query", - index_persist_dir="/tmp/index", - llm_conn_id="my_conn", + query="what is airflow", + index_persist_dir=str(tmp_path / "idx"), + embed_model="text-embedding-3-small", ) - - with patch.dict("sys.modules", mock_modules): - result = op.execute(context=MagicMock()) - - chunk = result["chunks"][0] - assert chunk["text"] == "chunk text" - assert chunk["score"] == 0.8 - assert chunk["metadata"] == {"file": "doc.txt"} - assert chunk["source"] == "abc-123" - - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_top_k_forwarded_to_retriever(self, mock_hook_cls): - mock_modules, _, mock_index, _ = _make_mock_llamaindex_modules([]) + result = op.execute(context=MagicMock()) + + assert result == { + "query": "what is airflow", + "chunks": [ + {"text": "chunk a", "score": 0.91, "metadata": {"src": "x"}, "node_id": "node-a"}, + {"text": "chunk b", "score": 0.85, "metadata": {"src": "y"}, "node_id": "node-b"}, + ], + } + # The retrieval-time embedding model is passed directly (no Settings mutation). 
+ _stub_li["load_index_from_storage"].assert_called_once() + kwargs = _stub_li["load_index_from_storage"].call_args.kwargs + assert "embed_model" in kwargs + index.as_retriever.assert_called_once_with(similarity_top_k=5) + + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") + def test_top_k_forwarded(self, mock_get_embed, _stub_li, tmp_path): + (tmp_path / "idx").mkdir() + index = _stub_li["load_index_from_storage"].return_value + index.as_retriever.return_value.retrieve.return_value = [] op = RetrievalOperator( task_id="test", - query="test", - index_persist_dir="/tmp/index", - llm_conn_id="my_conn", - top_k=10, + query="q", + index_persist_dir=str(tmp_path / "idx"), + embed_model="text-embedding-3-small", + top_k=12, ) + op.execute(context=MagicMock()) - with patch.dict("sys.modules", mock_modules): - op.execute(context=MagicMock()) - - mock_index.as_retriever.assert_called_once_with(similarity_top_k=10) + index.as_retriever.assert_called_once_with(similarity_top_k=12) - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_query_value_in_output(self, mock_hook_cls): - mock_modules, _, _, _ = _make_mock_llamaindex_modules([]) + def test_byo_embed_model_bypasses_hook(self, _stub_li, tmp_path): + (tmp_path / "idx").mkdir() + byo = MagicMock(name="MyBaseEmbedding") + index = _stub_li["load_index_from_storage"].return_value + index.as_retriever.return_value.retrieve.return_value = [] op = RetrievalOperator( task_id="test", - query="How does Airflow scheduling work?", - index_persist_dir="/tmp/index", - llm_conn_id="my_conn", + query="q", + index_persist_dir=str(tmp_path / "idx"), + embed_model=byo, ) + op.execute(context=MagicMock()) - with patch.dict("sys.modules", mock_modules): - result = op.execute(context=MagicMock()) - - assert result["question"] == "How does Airflow scheduling work?" 
- assert result["chunks"] == [] + kwargs = _stub_li["load_index_from_storage"].call_args.kwargs + assert kwargs["embed_model"] is byo - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_multiple_results_returned(self, mock_hook_cls): - results = [ - _make_mock_node_with_score(text=f"chunk {i}", score=0.9 - i * 0.1, node_id=f"node-{i}") - for i in range(3) - ] - mock_modules, _, _, _ = _make_mock_llamaindex_modules(results) +class TestRetrievalOperatorMissingIndex: + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") + def test_local_missing_dir_raises_with_hint(self, mock_get_embed, _stub_li, tmp_path): op = RetrievalOperator( task_id="test", - query="test", - index_persist_dir="/tmp/index", - llm_conn_id="my_conn", + query="q", + index_persist_dir=str(tmp_path / "no_such_dir"), + embed_model="text-embedding-3-small", ) + with pytest.raises(FileNotFoundError, match="EmbeddingOperator"): + op.execute(context=MagicMock()) - with patch.dict("sys.modules", mock_modules): - result = op.execute(context=MagicMock()) - - assert len(result["chunks"]) == 3 - assert result["chunks"][0]["text"] == "chunk 0" - assert result["chunks"][2]["text"] == "chunk 2" - - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_hook_configured_with_params(self, mock_hook_cls): - mock_modules, _, _, _ = _make_mock_llamaindex_modules([]) + @patch("airflow.sdk.ObjectStoragePath") + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") + def test_cloud_missing_uri_raises_with_hint(self, mock_get_embed, mock_osp_cls, _stub_li): + missing = MagicMock() + missing.is_dir.return_value = False + mock_osp_cls.return_value = missing op = RetrievalOperator( task_id="test", - query="test", - index_persist_dir="/tmp/index", - llm_conn_id="custom_conn", - embed_model="text-embedding-ada-002", + query="q", + 
index_persist_dir="s3://bucket/missing/", + embed_model="text-embedding-3-small", ) - - with patch.dict("sys.modules", mock_modules): + with pytest.raises(FileNotFoundError, match="EmbeddingOperator"): op.execute(context=MagicMock()) - mock_hook_cls.assert_called_once_with(llm_conn_id="custom_conn", embed_model="text-embedding-ada-002") - mock_hook_cls.return_value.configure_settings.assert_called_once() - @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook", autospec=True) - def test_persist_dir_passed_to_storage_context(self, mock_hook_cls): - mock_modules, mock_core, _, _ = _make_mock_llamaindex_modules([]) +class TestRetrievalOperatorCloudURI: + @patch("airflow.sdk.ObjectStoragePath") + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") + def test_cloud_uri_opens_storage_with_fs(self, mock_get_embed, mock_osp_cls, _stub_li): + target = MagicMock() + target.is_dir.return_value = True + target.__str__.return_value = "s3://bucket/idx/" + target.fs = MagicMock(name="s3fs") + mock_osp_cls.return_value = target + + index = _stub_li["load_index_from_storage"].return_value + index.as_retriever.return_value.retrieve.return_value = [] op = RetrievalOperator( task_id="test", - query="test", - index_persist_dir="/data/my_index", - llm_conn_id="my_conn", + query="q", + index_persist_dir="s3://bucket/idx/", + persist_conn_id="aws_default", + embed_model="text-embedding-3-small", ) + op.execute(context=MagicMock()) - with patch.dict("sys.modules", mock_modules): - op.execute(context=MagicMock()) - - mock_core.StorageContext.from_defaults.assert_called_once_with(persist_dir="/data/my_index") + mock_osp_cls.assert_called_once_with("s3://bucket/idx/", conn_id="aws_default") + _stub_li["StorageContext"].from_defaults.assert_called_once_with( + persist_dir="s3://bucket/idx/", + fs=target.fs, + ) From 82fa97ba0ddc06496eeabd16bace099b05f99297 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 21 May 2026 02:31:10 +0100 
Subject: [PATCH 3/7] Rename LlamaIndex operators with framework prefix; fold in #67189 RAG examples Per Kaxil's review r3267387604: ``RetrievalOperator`` / ``EmbeddingOperator`` are too generic in the common.ai namespace -- they risk colliding when other frameworks add their own embedding/retrieval operators. Renamed both with the LlamaIndex prefix: - ``EmbeddingOperator`` -> ``LlamaIndexEmbeddingOperator`` - ``RetrievalOperator`` -> ``LlamaIndexRetrievalOperator`` Renames applied across the two operator modules, three docs RSTs, the two test files, both example DAGs, and the cross-refs in ``docs/operators/index.rst``, ``docs/hooks/llamaindex.rst``, ``docs/operators/document_loader.rst``, and ``docs/hooks/index.rst``. Folds in #67189 (``example_llamaindex_rag.py``) which would otherwise sit blocked waiting for this PR to merge. Rewritten for the new API: - Uses the renamed classes - Drops ``documents="{{ ti.xcom_pull(...) }}"`` Jinja templating (template_fields removed; bind via ``loader.output`` direct) - Switches LlamaIndex operators to ``llamaindex_default`` conn (was ``pydanticai_default``); the synthesis-step ``LLMOperator`` keeps ``pydanticai_default`` because it's pydantic-ai-backed (different framework, intentional split documented in the module docstring) - Adds explicit ``embed_model="text-embedding-3-small"`` to every embedding/retrieval call (new operator validation requires it) - Fixes the string-reference task chains (``load >> "build_index"`` -> ``load >> build_index``) which weren't valid task dependencies Closes #67189. 
--- providers/common/ai/docs/hooks/index.rst | 3 +- providers/common/ai/docs/hooks/llamaindex.rst | 4 +- .../ai/docs/operators/document_loader.rst | 6 +- providers/common/ai/docs/operators/index.rst | 4 +- .../docs/operators/llamaindex_embedding.rst | 2 +- .../docs/operators/llamaindex_retrieval.rst | 10 +- .../example_dags/example_llamaindex_hook.py | 12 +- .../ai/example_dags/example_llamaindex_rag.py | 236 ++++++++++++++++++ .../providers/common/ai/hooks/llamaindex.py | 2 +- .../ai/operators/llamaindex_embedding.py | 2 +- .../ai/operators/llamaindex_retrieval.py | 8 +- .../ai/operators/test_llamaindex_embedding.py | 16 +- .../ai/operators/test_llamaindex_retrieval.py | 18 +- 13 files changed, 280 insertions(+), 43 deletions(-) create mode 100644 providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_rag.py diff --git a/providers/common/ai/docs/hooks/index.rst b/providers/common/ai/docs/hooks/index.rst index 9426280c9eb36..2786cd1b0ced5 100644 --- a/providers/common/ai/docs/hooks/index.rst +++ b/providers/common/ai/docs/hooks/index.rst @@ -41,7 +41,8 @@ Choosing a hook LangChain agent surface, or need LangChain-native chat / embedding model objects. Independent of the pydantic-ai-backed operators. * - :class:`~airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook` - - Backs the LlamaIndex ``EmbeddingOperator`` and ``RetrievalOperator``. + - Backs the LlamaIndex ``LlamaIndexEmbeddingOperator`` and + ``LlamaIndexRetrievalOperator``. Returns LlamaIndex-native ``BaseEmbedding`` / ``LLM`` objects (OpenAI by default). 
For non-OpenAI vendors, pass a pre-built ``BaseEmbedding`` / ``LLM`` instance straight to the operator and diff --git a/providers/common/ai/docs/hooks/llamaindex.rst b/providers/common/ai/docs/hooks/llamaindex.rst index 02255b72f6cf6..2bbd779ed56d1 100644 --- a/providers/common/ai/docs/hooks/llamaindex.rst +++ b/providers/common/ai/docs/hooks/llamaindex.rst @@ -51,9 +51,9 @@ behaviour: For other vendors (Cohere, Bedrock, Vertex AI, HuggingFace, ...), instantiate the LlamaIndex class directly in a ``@task`` and pass it to the operator's ``embed_model=`` / ``llm=`` parameter -- both -:class:`~airflow.providers.common.ai.operators.llamaindex_embedding.EmbeddingOperator` +:class:`~airflow.providers.common.ai.operators.llamaindex_embedding.LlamaIndexEmbeddingOperator` and -:class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.RetrievalOperator` +:class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.LlamaIndexRetrievalOperator` accept a pre-built ``BaseEmbedding`` / ``LLM`` instance and bypass the hook: diff --git a/providers/common/ai/docs/operators/document_loader.rst b/providers/common/ai/docs/operators/document_loader.rst index 8a836c37d120e..2aa32e6594dd2 100644 --- a/providers/common/ai/docs/operators/document_loader.rst +++ b/providers/common/ai/docs/operators/document_loader.rst @@ -146,7 +146,7 @@ No chunking The operator parses files into documents; it does **not** split them into fixed-size chunks. The right chunking strategy depends on the embedding model and is intentionally left to a downstream text-splitter or embedding -operator (LlamaIndex's ``EmbeddingOperator``, LangChain's text splitters, +operator (LlamaIndex's ``LlamaIndexEmbeddingOperator``, LangChain's text splitters, ...). Format coverage roadmap @@ -172,7 +172,7 @@ Composing with downstream embedding operators --------------------------------------------- The output format (``list[dict(text, metadata)]``) is designed to feed -directly into embedding operators. 
With LlamaIndex's ``EmbeddingOperator``: +directly into embedding operators. With LlamaIndex's ``LlamaIndexEmbeddingOperator``: .. code-block:: python @@ -181,7 +181,7 @@ directly into embedding operators. With LlamaIndex's ``EmbeddingOperator``: source_path="/data/docs/*.pdf", ) - embed = EmbeddingOperator( + embed = LlamaIndexEmbeddingOperator( task_id="embed", documents="{{ ti.xcom_pull(task_ids='load') }}", llm_conn_id="openai_default", diff --git a/providers/common/ai/docs/operators/index.rst b/providers/common/ai/docs/operators/index.rst index bc0b36bd9e8fa..7eaf414fe3a63 100644 --- a/providers/common/ai/docs/operators/index.rst +++ b/providers/common/ai/docs/operators/index.rst @@ -50,10 +50,10 @@ to pick the one that fits your use case: - :class:`~airflow.providers.common.ai.operators.document_loader.DocumentLoaderOperator` - *(no decorator)* * - Chunk documents and produce embedding vectors - - :class:`~airflow.providers.common.ai.operators.llamaindex_embedding.EmbeddingOperator` + - :class:`~airflow.providers.common.ai.operators.llamaindex_embedding.LlamaIndexEmbeddingOperator` - *(no decorator)* * - Retrieve relevant chunks from a vector index - - :class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.RetrievalOperator` + - :class:`~airflow.providers.common.ai.operators.llamaindex_retrieval.LlamaIndexRetrievalOperator` - *(no decorator)* **LLMOperator / @task.llm** — stateless, single-turn calls. Use this for classification, diff --git a/providers/common/ai/docs/operators/llamaindex_embedding.rst b/providers/common/ai/docs/operators/llamaindex_embedding.rst index 0a2c47029d0c1..0f9440eaec2c3 100644 --- a/providers/common/ai/docs/operators/llamaindex_embedding.rst +++ b/providers/common/ai/docs/operators/llamaindex_embedding.rst @@ -17,7 +17,7 @@ .. 
_howto/operator:llamaindex_embedding: -LlamaIndex ``EmbeddingOperator`` -================================ +LlamaIndex ``LlamaIndexEmbeddingOperator`` +========================================== Chunk a ``list[dict]`` of documents and produce embedding vectors using diff --git a/providers/common/ai/docs/operators/llamaindex_retrieval.rst b/providers/common/ai/docs/operators/llamaindex_retrieval.rst index 13ec739d2ab13..6ee92553c094c 100644 --- a/providers/common/ai/docs/operators/llamaindex_retrieval.rst +++ b/providers/common/ai/docs/operators/llamaindex_retrieval.rst @@ -17,12 +17,12 @@ .. _howto/operator:llamaindex_retrieval: -LlamaIndex ``RetrievalOperator`` -================================ +LlamaIndex ``LlamaIndexRetrievalOperator`` +========================================== Load a persisted LlamaIndex index and run similarity search. Designed to sit between -:class:`~airflow.providers.common.ai.operators.llamaindex_embedding.EmbeddingOperator` +:class:`~airflow.providers.common.ai.operators.llamaindex_embedding.LlamaIndexEmbeddingOperator` (which builds the index) and :class:`~airflow.providers.common.ai.operators.llm.LLMOperator` (which synthesises an answer from the retrieved chunks). @@ -47,15 +47,15 @@ Cloud-persisted indexes ----------------------- ``index_persist_dir`` accepts the same local-path-or-URI shape as -``EmbeddingOperator.persist_dir``. Pass ``persist_conn_id`` to point at +``LlamaIndexEmbeddingOperator.persist_dir``. Pass ``persist_conn_id`` to point at the Airflow connection that holds cloud credentials. The operator raises -``FileNotFoundError`` with a clear "did you run EmbeddingOperator first?" +``FileNotFoundError`` with a clear "did you run LlamaIndexEmbeddingOperator first?" message when the path is missing. Bring-your-own embedding model ------------------------------ -Same shape as ``LlamaIndexEmbeddingOperator``: ``embed_model`` accepts either a +Same shape as ``LlamaIndexEmbeddingOperator``: ``embed_model`` accepts either a string model name (OpenAI via the hook) or a pre-built ``BaseEmbedding`` instance for non-OpenAI vendors.
See the BYO example in :doc:`llamaindex_embedding`. diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py index 617662c431e55..f089ee02b4d08 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py +++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_hook.py @@ -23,8 +23,8 @@ from __future__ import annotations from airflow.providers.common.ai.operators.document_loader import DocumentLoaderOperator -from airflow.providers.common.ai.operators.llamaindex_embedding import EmbeddingOperator -from airflow.providers.common.ai.operators.llamaindex_retrieval import RetrievalOperator +from airflow.providers.common.ai.operators.llamaindex_embedding import LlamaIndexEmbeddingOperator +from airflow.providers.common.ai.operators.llamaindex_retrieval import LlamaIndexRetrievalOperator from airflow.providers.common.compat.sdk import dag, task @@ -39,7 +39,7 @@ def example_llamaindex_embed(): file_extensions=[".pdf", ".md", ".txt"], ) - embed = EmbeddingOperator( + embed = LlamaIndexEmbeddingOperator( task_id="embed", documents=load.output, # XCom direct -- never via Jinja (list[dict]) embed_model="text-embedding-3-small", @@ -62,7 +62,7 @@ def example_llamaindex_embed(): def example_llamaindex_retrieve(): """Load a persisted index and run similarity search.""" - retrieve = RetrievalOperator( + retrieve = LlamaIndexRetrievalOperator( task_id="retrieve", query="{{ params.query }}", index_persist_dir="/opt/airflow/data/library_index", @@ -91,7 +91,7 @@ def example_llamaindex_cloud_persist(): file_extensions=[".pdf"], ) - embed = EmbeddingOperator( + embed = LlamaIndexEmbeddingOperator( task_id="embed", documents=load.output, embed_model="text-embedding-3-small", @@ -132,7 +132,7 @@ def build_cohere_embedder(): def empty_doc_list() -> list[dict]: return 
[{"text": "Cohere demo content", "metadata": {}}] - embed = EmbeddingOperator( + embed = LlamaIndexEmbeddingOperator( task_id="embed", documents=empty_doc_list(), embed_model=build_cohere_embedder(), diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_rag.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_rag.py new file mode 100644 index 0000000000000..501d69594cc5e --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_rag.py @@ -0,0 +1,236 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example DAGs demonstrating RAG pipelines with LlamaIndex operators. + +Three patterns: + +1. Full RAG pipeline -- load -> embed -> retrieve -> answer in one DAG. +2. Separate index/query DAGs -- production-shaped split (scheduled + indexing job + on-demand query DAG). +3. Multi-source RAG -- combine multiple loaders with source metadata. + +The ``LLMOperator`` synthesis step uses a ``pydanticai_default`` connection +because :class:`~airflow.providers.common.ai.operators.llm.LLMOperator` is +pydantic-ai-backed; the LlamaIndex operators use ``llamaindex_default``. 
+The two connection types are intentional -- they back different frameworks. +""" + +from __future__ import annotations + +from airflow.providers.common.ai.operators.document_loader import DocumentLoaderOperator +from airflow.providers.common.ai.operators.llamaindex_embedding import LlamaIndexEmbeddingOperator +from airflow.providers.common.ai.operators.llamaindex_retrieval import LlamaIndexRetrievalOperator +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.compat.sdk import dag, task + +# --------------------------------------------------------------------------- +# 1. Full RAG pipeline: load -> embed -> retrieve -> answer +# --------------------------------------------------------------------------- + + +# [START howto_llamaindex_rag_pipeline] +@dag(schedule=None) +def example_llamaindex_rag_pipeline(): + """End-to-end RAG pipeline in a single DAG. + + 1. Parse local text files into document dicts. + 2. Chunk and embed the documents, persisting the index to disk. + 3. Retrieve relevant chunks for a user question. + 4. Synthesize an answer using the retrieved context. 
+ """ + load = DocumentLoaderOperator( + task_id="load_docs", + source_path="/opt/airflow/data/knowledge_base/", + file_extensions=[".txt", ".md", ".pdf"], + ) + + embed = LlamaIndexEmbeddingOperator( + task_id="embed_docs", + documents=load.output, # XCom direct -- not Jinja (list[dict] doesn't survive stringification) + embed_model="text-embedding-3-small", + llm_conn_id="llamaindex_default", + chunk_size=512, + chunk_overlap=50, + persist_dir="/opt/airflow/data/indexes/kb_index", + ) + + retrieve = LlamaIndexRetrievalOperator( + task_id="retrieve", + query="What are the main components of Apache Airflow?", + index_persist_dir="/opt/airflow/data/indexes/kb_index", + embed_model="text-embedding-3-small", + llm_conn_id="llamaindex_default", + top_k=5, + ) + + @task + def format_context(retrieval_result: dict) -> str: + chunks = retrieval_result["chunks"] + return "\n\n---\n\n".join(chunk["text"] for chunk in chunks) + + context = format_context(retrieve.output) + + answer = LLMOperator( + task_id="answer", + prompt=( + "Using the context below, answer the question: " + "What are the main components of Apache Airflow?\n\n" + "Context:\n{{ ti.xcom_pull(task_ids='format_context') }}" + ), + llm_conn_id="pydanticai_default", + system_prompt="Answer based only on the provided context. Cite sources when possible.", + ) + + embed >> retrieve >> context >> answer + + +# [END howto_llamaindex_rag_pipeline] + +example_llamaindex_rag_pipeline() + + +# --------------------------------------------------------------------------- +# 2. Production-shaped split: scheduled indexing + on-demand query +# --------------------------------------------------------------------------- + + +# [START howto_llamaindex_index_dag] +@dag(schedule="@weekly") +def example_llamaindex_index_pdf(): + """Weekly indexing DAG -- keep the vector index fresh as PDFs arrive. + + The companion query DAG (below) reads the persisted index on demand. 
+ """ + load = DocumentLoaderOperator( + task_id="load_pdfs", + source_path="/opt/airflow/data/reports/*.pdf", + ) + + build_index = LlamaIndexEmbeddingOperator( + task_id="build_index", + documents=load.output, + embed_model="text-embedding-3-small", + llm_conn_id="llamaindex_default", + chunk_size=1024, + chunk_overlap=100, + persist_dir="/opt/airflow/data/indexes/reports_index", + ) + + load >> build_index + + +# [END howto_llamaindex_index_dag] + +example_llamaindex_index_pdf() + + +# [START howto_llamaindex_query_dag] +@dag( + schedule=None, + params={"question": "Summarize the key findings from the latest quarterly report."}, +) +def example_llamaindex_query(): + """On-demand query DAG -- retrieve from a pre-built index and synthesize. + + Trigger manually or via API with a ``question`` parameter. + """ + retrieve = LlamaIndexRetrievalOperator( + task_id="retrieve", + query="{{ params.question }}", + index_persist_dir="/opt/airflow/data/indexes/reports_index", + embed_model="text-embedding-3-small", + llm_conn_id="llamaindex_default", + top_k=5, + ) + + @task + def format_context(retrieval_result: dict) -> str: + chunks = retrieval_result["chunks"] + numbered = [f"[{i + 1}] {chunk['text']}" for i, chunk in enumerate(chunks)] + return "\n\n".join(numbered) + + context = format_context(retrieve.output) + + synthesize = LLMOperator( + task_id="synthesize", + prompt=( + "Question: {{ params.question }}\n\n" + "Relevant excerpts:\n{{ ti.xcom_pull(task_ids='format_context') }}\n\n" + "Provide a detailed answer with references to the excerpt numbers." + ), + llm_conn_id="pydanticai_default", + system_prompt=( + "You are a research assistant. Answer the question using only the " + "provided excerpts. Reference excerpt numbers in square brackets." + ), + ) + + context >> synthesize + + +# [END howto_llamaindex_query_dag] + +example_llamaindex_query() + + +# --------------------------------------------------------------------------- +# 3. 
Multi-source RAG: combine CSV product data with text documentation +# --------------------------------------------------------------------------- + + +# [START howto_llamaindex_multi_source] +@dag(schedule=None) +def example_llamaindex_multi_source(): + """Combine multiple loaders with source-tagging metadata. + + Shows how ``DocumentLoaderOperator`` handles different file formats and + how ``metadata_fields`` tags documents by source for filtered retrieval + downstream. + """ + load_products = DocumentLoaderOperator( + task_id="load_products", + source_path="/opt/airflow/data/products.csv", + metadata_fields={"source": "product_catalog", "department": "engineering"}, + ) + + load_docs = DocumentLoaderOperator( + task_id="load_docs", + source_path="/opt/airflow/data/documentation/", + file_extensions=[".md", ".txt"], + metadata_fields={"source": "documentation"}, + ) + + @task + def merge_documents(products: list[dict], docs: list[dict]) -> list[dict]: + return products + docs + + merged = merge_documents(load_products.output, load_docs.output) + + embed_all = LlamaIndexEmbeddingOperator( + task_id="embed_all", + documents=merged, + embed_model="text-embedding-3-small", + llm_conn_id="llamaindex_default", + persist_dir="/opt/airflow/data/indexes/multi_source_index", + ) + + embed_all + + +# [END howto_llamaindex_multi_source] + +example_llamaindex_multi_source() diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py b/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py index 68b8325cf05b7..05e002d86425e 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py +++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/llamaindex.py @@ -48,7 +48,7 @@ class LlamaIndexHook(BaseHook): behaviour. For other vendors (Cohere, Bedrock, Vertex, HuggingFace, ...) 
instantiate the LlamaIndex class directly in your ``@task`` and pass it to the operator's ``embed_model=`` / ``llm=`` parameter -- both - ``EmbeddingOperator`` and ``RetrievalOperator`` accept a pre-built + ``LlamaIndexEmbeddingOperator`` and ``LlamaIndexRetrievalOperator`` accept a pre-built ``BaseEmbedding`` / ``LLM`` instance and bypass the hook in that case. .. note:: diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py index 098fa01fa871b..98b3612cbe0c0 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py @@ -31,7 +31,7 @@ from llama_index.core.base.embeddings.base import BaseEmbedding -class EmbeddingOperator(BaseOperator): +class LlamaIndexEmbeddingOperator(BaseOperator): """ Chunk documents and produce embedding vectors using LlamaIndex. diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py index 86050b4e7e91d..40e0ddaefc01e 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py @@ -31,12 +31,12 @@ from llama_index.core.base.embeddings.base import BaseEmbedding -class RetrievalOperator(BaseOperator): +class LlamaIndexRetrievalOperator(BaseOperator): """ Retrieve relevant document chunks from a persisted LlamaIndex index. Loads a previously persisted vector store index (from - ``EmbeddingOperator(persist_dir=...)``) and performs similarity search + ``LlamaIndexEmbeddingOperator(persist_dir=...)``) and performs similarity search against the provided query. 
Output is a list of chunks with text, score, metadata, and node id, ready for downstream synthesis via :class:`~airflow.providers.common.ai.operators.llm.LLMOperator`. @@ -142,7 +142,7 @@ def _open_storage_context(self, storage_context_cls: Any) -> Any: if not source.is_dir(): raise FileNotFoundError( f"Persisted LlamaIndex index not found at '{self.index_persist_dir}'. " - "Did you run EmbeddingOperator with the same persist_dir first?" + "Did you run LlamaIndexEmbeddingOperator with the same persist_dir first?" ) return storage_context_cls.from_defaults( persist_dir=str(source), @@ -154,6 +154,6 @@ def _open_storage_context(self, storage_context_cls: Any) -> Any: if not Path(self.index_persist_dir).is_dir(): raise FileNotFoundError( f"Persisted LlamaIndex index not found at '{self.index_persist_dir}'. " - "Did you run EmbeddingOperator with the same persist_dir first?" + "Did you run LlamaIndexEmbeddingOperator with the same persist_dir first?" ) return storage_context_cls.from_defaults(persist_dir=self.index_persist_dir) diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py index 0077b075a252d..8e4af71ced00b 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py @@ -21,7 +21,7 @@ import pytest -from airflow.providers.common.ai.operators.llamaindex_embedding import EmbeddingOperator +from airflow.providers.common.ai.operators.llamaindex_embedding import LlamaIndexEmbeddingOperator @pytest.fixture @@ -63,10 +63,10 @@ class TestEmbeddingOperatorInit: def test_documents_not_templated(self): # ``documents`` is ``list[dict]`` -- Jinja stringification would # break it. Explicitly out of template_fields. 
- assert "documents" not in EmbeddingOperator.template_fields + assert "documents" not in LlamaIndexEmbeddingOperator.template_fields def test_templated_fields(self): - assert set(EmbeddingOperator.template_fields) == { + assert set(LlamaIndexEmbeddingOperator.template_fields) == { "llm_conn_id", "persist_dir", "persist_conn_id", @@ -81,7 +81,7 @@ def test_string_embed_model_goes_through_hook(self, mock_get_embed, _stub_li): _node(text="chunk a", vector=[0.1, 0.2]), ] - op = EmbeddingOperator( + op = LlamaIndexEmbeddingOperator( task_id="test", documents=[{"text": "doc", "metadata": {"src": "x"}}], embed_model="text-embedding-3-small", @@ -102,7 +102,7 @@ def test_byo_embed_model_bypasses_hook(self, _stub_li): _node() ] - op = EmbeddingOperator( + op = LlamaIndexEmbeddingOperator( task_id="test", documents=[{"text": "doc"}], embed_model=byo, @@ -121,7 +121,7 @@ def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _stub_li): _node(text="y", metadata={"k": "v2"}, vector=[3.0, 4.0]), ] - op = EmbeddingOperator( + op = LlamaIndexEmbeddingOperator( task_id="test", documents=[{"text": "doc"}], embed_model="text-embedding-3-small", @@ -144,7 +144,7 @@ def test_local_persist_dir_calls_makedirs_and_storage_persist( _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] index = _stub_li["VectorStoreIndex"].return_value - op = EmbeddingOperator( + op = LlamaIndexEmbeddingOperator( task_id="test", documents=[{"text": "doc"}], embed_model="text-embedding-3-small", @@ -169,7 +169,7 @@ def test_cloud_uri_persist_dir_uses_object_storage_path( _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] index = _stub_li["VectorStoreIndex"].return_value - op = EmbeddingOperator( + op = LlamaIndexEmbeddingOperator( task_id="test", documents=[{"text": "doc"}], embed_model="text-embedding-3-small", diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py 
b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py index 6291aab189566..6cd1c8387ad54 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py @@ -21,7 +21,7 @@ import pytest -from airflow.providers.common.ai.operators.llamaindex_retrieval import RetrievalOperator +from airflow.providers.common.ai.operators.llamaindex_retrieval import LlamaIndexRetrievalOperator @pytest.fixture @@ -68,7 +68,7 @@ def test_chunk_shape(self, mock_get_embed, _stub_li, tmp_path): _scored_node("chunk b", 0.85, {"src": "y"}, "node-b"), ] - op = RetrievalOperator( + op = LlamaIndexRetrievalOperator( task_id="test", query="what is airflow", index_persist_dir=str(tmp_path / "idx"), @@ -95,7 +95,7 @@ def test_top_k_forwarded(self, mock_get_embed, _stub_li, tmp_path): index = _stub_li["load_index_from_storage"].return_value index.as_retriever.return_value.retrieve.return_value = [] - op = RetrievalOperator( + op = LlamaIndexRetrievalOperator( task_id="test", query="q", index_persist_dir=str(tmp_path / "idx"), @@ -112,7 +112,7 @@ def test_byo_embed_model_bypasses_hook(self, _stub_li, tmp_path): index = _stub_li["load_index_from_storage"].return_value index.as_retriever.return_value.retrieve.return_value = [] - op = RetrievalOperator( + op = LlamaIndexRetrievalOperator( task_id="test", query="q", index_persist_dir=str(tmp_path / "idx"), @@ -127,13 +127,13 @@ def test_byo_embed_model_bypasses_hook(self, _stub_li, tmp_path): class TestRetrievalOperatorMissingIndex: @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") def test_local_missing_dir_raises_with_hint(self, mock_get_embed, _stub_li, tmp_path): - op = RetrievalOperator( + op = LlamaIndexRetrievalOperator( task_id="test", query="q", index_persist_dir=str(tmp_path / "no_such_dir"), embed_model="text-embedding-3-small", ) - with 
pytest.raises(FileNotFoundError, match="EmbeddingOperator"): + with pytest.raises(FileNotFoundError, match="LlamaIndexEmbeddingOperator"): op.execute(context=MagicMock()) @patch("airflow.sdk.ObjectStoragePath") @@ -143,13 +143,13 @@ def test_cloud_missing_uri_raises_with_hint(self, mock_get_embed, mock_osp_cls, missing.is_dir.return_value = False mock_osp_cls.return_value = missing - op = RetrievalOperator( + op = LlamaIndexRetrievalOperator( task_id="test", query="q", index_persist_dir="s3://bucket/missing/", embed_model="text-embedding-3-small", ) - with pytest.raises(FileNotFoundError, match="EmbeddingOperator"): + with pytest.raises(FileNotFoundError, match="LlamaIndexEmbeddingOperator"): op.execute(context=MagicMock()) @@ -166,7 +166,7 @@ def test_cloud_uri_opens_storage_with_fs(self, mock_get_embed, mock_osp_cls, _st index = _stub_li["load_index_from_storage"].return_value index.as_retriever.return_value.retrieve.return_value = [] - op = RetrievalOperator( + op = LlamaIndexRetrievalOperator( task_id="test", query="q", index_persist_dir="s3://bucket/idx/", From 0a377097df9e0c45eafd98e6f93194a270d8d2b3 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 21 May 2026 02:59:23 +0100 Subject: [PATCH 4/7] Address code-review findings on LlamaIndex operators - Fix ObjectStoragePath conn_id mangling: pass raw URI to LlamaIndex persist_dir= and supply target.fs separately. str(target) returns s3://<conn_id>@<bucket>/..., which fsspec misinterprets. - Add documents / embed_model / embed_conn_id to template_fields so XComArg resolution fires. The previous "list[dict] doesn't survive stringification" rationale was wrong; Templater unwraps resolvables before Jinja. - Default llm_conn_id to None on both operators; LlamaIndexHook resolves to default_conn_name at runtime. Hard-coding "llamaindex_default" undid the hook's careful runtime resolution. - Add embed_conn_id pass-through for separate embedding credentials. 
- Replace isinstance(str) duck-typing with hasattr-based BaseEmbedding check; raise TypeError with a clear pointer instead of letting an unresolved XComArg or random object explode later. - Hoist 'import os' and 'from pathlib import Path' to module top. - Pad RST title underlines and refresh docs/tests to match the new surface. --- .../docs/operators/llamaindex_embedding.rst | 22 +++--- .../docs/operators/llamaindex_retrieval.rst | 12 ++-- .../ai/example_dags/example_llamaindex_rag.py | 2 +- .../ai/operators/llamaindex_embedding.py | 70 +++++++++++++------ .../ai/operators/llamaindex_retrieval.py | 57 ++++++++++++--- .../ai/operators/test_llamaindex_embedding.py | 62 +++++++++++++--- .../ai/operators/test_llamaindex_retrieval.py | 62 +++++++++++++++- 7 files changed, 233 insertions(+), 54 deletions(-) diff --git a/providers/common/ai/docs/operators/llamaindex_embedding.rst b/providers/common/ai/docs/operators/llamaindex_embedding.rst index 0f9440eaec2c3..99125ac74bdde 100644 --- a/providers/common/ai/docs/operators/llamaindex_embedding.rst +++ b/providers/common/ai/docs/operators/llamaindex_embedding.rst @@ -18,7 +18,7 @@ .. _howto/operator:llamaindex_embedding: LlamaIndex ``LlamaIndexEmbeddingOperator`` -================================ +========================================== Chunk a ``list[dict]`` of documents and produce embedding vectors using LlamaIndex. Designed to feed the output of @@ -38,9 +38,8 @@ Basic usage :start-after: [START howto_hook_llamaindex_embed] :end-before: [END howto_hook_llamaindex_embed] -The ``documents`` parameter binds to ``loader.output`` (XCom direct), **not** -via Jinja -- ``list[dict]`` doesn't survive Jinja stringification, so the -parameter is intentionally not in ``template_fields``. +``documents`` is templated, so ``loader.output`` (XCom direct) is resolved +to a native ``list[dict]`` before ``execute`` runs. 
Bring-your-own embedding model ------------------------------ @@ -51,7 +50,7 @@ operator's ``embed_model`` parameter accepts either: * a string model name (e.g. ``"text-embedding-3-small"``) -- the operator constructs an ``OpenAIEmbedding`` via :class:`~airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook` - using ``llm_conn_id``, or + using ``llm_conn_id`` / ``embed_conn_id``, or * a pre-built ``BaseEmbedding`` instance -- bypass the hook entirely. Use this for Cohere, Bedrock, Vertex, HuggingFace, etc.: @@ -83,13 +82,18 @@ Parameters * - Parameter - Description * - ``documents`` - - ``list[dict]`` with ``text`` / ``metadata`` keys. Bind via - ``loader.output``; **not** templated. + - ``list[dict]`` with ``text`` / ``metadata`` keys. Templated, so + binding ``loader.output`` resolves to the native list before + execute. * - ``embed_model`` - String model name OR pre-built ``BaseEmbedding`` instance. * - ``llm_conn_id`` - - Airflow connection ID used when ``embed_model`` is a string - (default ``llamaindex_default``). + - Airflow connection ID used when ``embed_model`` is a string. Falls + back to ``LlamaIndexHook.default_conn_name`` (``llamaindex_default``) + when ``None``. + * - ``embed_conn_id`` + - Optional separate connection ID for the embedding provider. Falls + back to ``llm_conn_id`` when ``None``. * - ``chunk_size`` - Sentence-splitter chunk size (default 512). * - ``chunk_overlap`` diff --git a/providers/common/ai/docs/operators/llamaindex_retrieval.rst b/providers/common/ai/docs/operators/llamaindex_retrieval.rst index 6ee92553c094c..6e0793604abab 100644 --- a/providers/common/ai/docs/operators/llamaindex_retrieval.rst +++ b/providers/common/ai/docs/operators/llamaindex_retrieval.rst @@ -18,7 +18,7 @@ .. _howto/operator:llamaindex_retrieval: LlamaIndex ``LlamaIndexRetrievalOperator`` -================================ +========================================== Load a persisted LlamaIndex index and run similarity search. 
Designed to sit between @@ -79,10 +79,14 @@ Parameters Templated. * - ``embed_model`` - String model name OR pre-built ``BaseEmbedding`` instance. Must - match the model used when the index was built. + match the model used when the index was built. Templated. * - ``llm_conn_id`` - - Airflow connection ID used when ``embed_model`` is a string - (default ``llamaindex_default``). + - Airflow connection ID used when ``embed_model`` is a string. Falls + back to ``LlamaIndexHook.default_conn_name`` (``llamaindex_default``) + when ``None``. + * - ``embed_conn_id`` + - Optional separate connection ID for the embedding provider. Falls + back to ``llm_conn_id`` when ``None``. * - ``top_k`` - Number of top similarity results to return (default 5). diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_rag.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_rag.py index 501d69594cc5e..6c044b965f4d7 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_rag.py +++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llamaindex_rag.py @@ -60,7 +60,7 @@ def example_llamaindex_rag_pipeline(): embed = LlamaIndexEmbeddingOperator( task_id="embed_docs", - documents=load.output, # XCom direct -- not Jinja (list[dict] doesn't survive stringification) + documents=load.output, embed_model="text-embedding-3-small", llm_conn_id="llamaindex_default", chunk_size=512, diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py index 98b3612cbe0c0..ab86207b596f4 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py @@ -18,6 +18,7 @@ from __future__ import annotations +import os from 
collections.abc import Sequence from typing import TYPE_CHECKING, Any @@ -47,21 +48,25 @@ class LlamaIndexEmbeddingOperator(BaseOperator): same worker don't race on shared state. :param documents: List of dicts with ``text`` and ``metadata`` keys, - typically from ``DocumentLoaderOperator`` or a ``@task``. Bind via - ``my_loader.output`` (XCom direct), **not** via Jinja -- ``list[dict]`` - does not survive Jinja stringification. + typically from ``DocumentLoaderOperator`` or a ``@task``. Templated, + so binding via ``my_loader.output`` (XCom direct) resolves to the + native ``list[dict]`` before ``execute`` runs. :param embed_model: Either: * a string model name (e.g. ``"text-embedding-3-small"``) -- the operator constructs an :class:`~.LlamaIndexHook`-backed - ``OpenAIEmbedding`` from ``llm_conn_id``, or + ``OpenAIEmbedding`` from ``llm_conn_id`` / ``embed_conn_id``, or * a pre-built ``BaseEmbedding`` instance -- bypass the hook entirely for non-OpenAI vendors (e.g. ``CohereEmbedding(...)``, ``BedrockEmbedding(...)``). - :param llm_conn_id: Airflow connection ID for the embedding API. Used - only when ``embed_model`` is a string (or omitted entirely, falling - back to ``extra["embed_model"]`` on the connection). + Templated, so it works with both literal strings and ``@task`` + output that builds a custom embedder. + + :param llm_conn_id: Airflow connection ID for the embedding API. Falls + back to :attr:`LlamaIndexHook.default_conn_name` when ``None``. + :param embed_conn_id: Optional separate Airflow connection ID for the + embedding provider. Falls back to ``llm_conn_id`` when ``None``. :param chunk_size: Chunk size for the sentence splitter. :param chunk_overlap: Overlap between chunks. :param persist_dir: Optional path to persist the index. 
Accepts local @@ -72,7 +77,10 @@ class LlamaIndexEmbeddingOperator(BaseOperator): """ template_fields: Sequence[str] = ( + "documents", + "embed_model", "llm_conn_id", + "embed_conn_id", "persist_dir", "persist_conn_id", ) @@ -82,7 +90,8 @@ def __init__( *, documents: list[dict[str, Any]], embed_model: str | BaseEmbedding | None = None, - llm_conn_id: str = "llamaindex_default", + llm_conn_id: str | None = None, + embed_conn_id: str | None = None, chunk_size: int = 512, chunk_overlap: int = 50, persist_dir: str | None = None, @@ -93,6 +102,7 @@ def __init__( self.documents = documents self.embed_model = embed_model self.llm_conn_id = llm_conn_id + self.embed_conn_id = embed_conn_id self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap self.persist_dir = persist_dir @@ -100,7 +110,7 @@ def __init__( def execute(self, context: Context) -> dict[str, Any]: try: - from llama_index.core import Document, StorageContext, VectorStoreIndex + from llama_index.core import Document, VectorStoreIndex from llama_index.core.node_parser import SentenceSplitter except ImportError as e: raise AirflowOptionalProviderFeatureException(e) @@ -117,8 +127,7 @@ def execute(self, context: Context) -> dict[str, Any]: # ``VectorStoreIndex(...)`` populates each node's ``.embedding`` as a # side effect of building the index; capture the index so the - # variable isn't discarded (also lets future enhancements query it - # before persistence). + # variable isn't discarded. index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=False) if self.persist_dir: @@ -146,21 +155,38 @@ def _resolve_embed_model(self) -> BaseEmbedding: """ Return a ready-to-use ``BaseEmbedding``. - If ``embed_model`` is a string or ``None``, build one via - ``LlamaIndexHook`` (OpenAI from the configured Airflow connection). - Anything else is treated as a pre-built ``BaseEmbedding`` instance - (user brought their own) and returned as-is. 
Avoids - ``isinstance(.., BaseEmbedding)`` so the check doesn't trigger an - otherwise-unnecessary ``llama_index`` import. + Three cases: + + * ``None`` or ``str`` -- build an ``OpenAIEmbedding`` via + ``LlamaIndexHook`` (the framework's documented ``default`` + behaviour). + * Has ``get_text_embedding`` / ``_get_query_embedding`` -- treat as + a pre-built ``BaseEmbedding`` (duck-typed to avoid forcing a + ``llama_index`` import here). + * Anything else -- ``TypeError`` with a clear pointer. """ if self.embed_model is None or isinstance(self.embed_model, str): from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook return LlamaIndexHook( llm_conn_id=self.llm_conn_id, + embed_conn_id=self.embed_conn_id, embed_model=self.embed_model, ).get_embedding_model() - return self.embed_model + + # ``BaseEmbedding`` always exposes these two methods (see + # ``llama_index.core.base.embeddings.base``). Duck-typing avoids + # importing ``llama_index`` here and also catches the case where an + # unresolved ``XComArg`` slips through. + if hasattr(self.embed_model, "get_text_embedding") and hasattr( + self.embed_model, "_get_query_embedding" + ): + return self.embed_model + + raise TypeError( + "embed_model must be a string model name, a LlamaIndex " + f"``BaseEmbedding`` instance, or None. Got {type(self.embed_model).__name__!r}." + ) def _persist(self, index: Any) -> None: """Persist the index to ``persist_dir``; cloud URIs go through ObjectStoragePath.""" @@ -169,10 +195,12 @@ def _persist(self, index: Any) -> None: target = ObjectStoragePath(self.persist_dir, conn_id=self.persist_conn_id) target.mkdir(parents=True, exist_ok=True) - index.storage_context.persist(persist_dir=str(target), fs=target.fs) + # ``str(target)`` returns ``s3://@/...`` when + # ``conn_id`` is set (see ``task-sdk/.../io/path.py``), which + # fsspec misinterprets. Pass the raw user URI as the path string + # and the authenticated filesystem separately. 
+ index.storage_context.persist(persist_dir=self.persist_dir, fs=target.fs) else: - import os - os.makedirs(self.persist_dir, exist_ok=True) # type: ignore[arg-type] index.storage_context.persist(persist_dir=self.persist_dir) self.log.info("Index persisted to %s", self.persist_dir) diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py index 40e0ddaefc01e..881d20c3d3d09 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py @@ -19,6 +19,7 @@ from __future__ import annotations from collections.abc import Sequence +from pathlib import Path from typing import TYPE_CHECKING, Any from airflow.providers.common.compat.sdk import ( @@ -57,13 +58,19 @@ class LlamaIndexRetrievalOperator(BaseOperator): * a string model name (e.g. ``"text-embedding-3-small"``) -- the operator constructs an :class:`~.LlamaIndexHook`-backed - ``OpenAIEmbedding`` from ``llm_conn_id``, or + ``OpenAIEmbedding`` from ``llm_conn_id`` / ``embed_conn_id``, or * a pre-built ``BaseEmbedding`` instance -- bypass the hook for non-OpenAI vendors. Must match the embedding model used when the index was originally built. - :param llm_conn_id: Airflow connection ID for the embedding API. Used - only when ``embed_model`` is a string (or omitted entirely). + Templated, so it works with both literal strings and ``@task`` + output that builds a custom embedder. + + :param llm_conn_id: Airflow connection ID for the embedding API. Falls + back to :attr:`LlamaIndexHook.default_conn_name` when ``None``. + Used only when ``embed_model`` is a string (or omitted entirely). + :param embed_conn_id: Optional separate Airflow connection ID for the + embedding provider. Falls back to ``llm_conn_id`` when ``None``. :param top_k: Number of top results to retrieve. 
""" @@ -71,7 +78,9 @@ class LlamaIndexRetrievalOperator(BaseOperator): "query", "index_persist_dir", "persist_conn_id", + "embed_model", "llm_conn_id", + "embed_conn_id", ) def __init__( @@ -81,7 +90,8 @@ def __init__( index_persist_dir: str, persist_conn_id: str | None = None, embed_model: str | BaseEmbedding | None = None, - llm_conn_id: str = "llamaindex_default", + llm_conn_id: str | None = None, + embed_conn_id: str | None = None, top_k: int = 5, **kwargs: Any, ) -> None: @@ -91,6 +101,7 @@ def __init__( self.persist_conn_id = persist_conn_id self.embed_model = embed_model self.llm_conn_id = llm_conn_id + self.embed_conn_id = embed_conn_id self.top_k = top_k def execute(self, context: Context) -> dict[str, Any]: @@ -123,15 +134,41 @@ def execute(self, context: Context) -> dict[str, Any]: } def _resolve_embed_model(self) -> BaseEmbedding: - """String / ``None`` -> hook; anything else -> pre-built instance.""" + """ + Return a ready-to-use ``BaseEmbedding``. + + Three cases: + + * ``None`` or ``str`` -- build an ``OpenAIEmbedding`` via + ``LlamaIndexHook`` (the framework's documented ``default`` + behaviour). + * Has ``get_text_embedding`` / ``_get_query_embedding`` -- treat as + a pre-built ``BaseEmbedding`` (duck-typed to avoid forcing a + ``llama_index`` import here). + * Anything else -- ``TypeError`` with a clear pointer. + """ if self.embed_model is None or isinstance(self.embed_model, str): from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook return LlamaIndexHook( llm_conn_id=self.llm_conn_id, + embed_conn_id=self.embed_conn_id, embed_model=self.embed_model, ).get_embedding_model() - return self.embed_model + + # ``BaseEmbedding`` always exposes these two methods (see + # ``llama_index.core.base.embeddings.base``). Duck-typing avoids + # importing ``llama_index`` here and also catches the case where an + # unresolved ``XComArg`` slips through. 
+ if hasattr(self.embed_model, "get_text_embedding") and hasattr( + self.embed_model, "_get_query_embedding" + ): + return self.embed_model + + raise TypeError( + "embed_model must be a string model name, a LlamaIndex " + f"``BaseEmbedding`` instance, or None. Got {type(self.embed_model).__name__!r}." + ) def _open_storage_context(self, storage_context_cls: Any) -> Any: """Open a ``StorageContext`` from a local path or storage URI.""" @@ -144,13 +181,15 @@ def _open_storage_context(self, storage_context_cls: Any) -> Any: f"Persisted LlamaIndex index not found at '{self.index_persist_dir}'. " "Did you run LlamaIndexEmbeddingOperator with the same persist_dir first?" ) + # ``str(source)`` returns ``s3://@/...`` when + # ``conn_id`` is set (see ``task-sdk/.../io/path.py``), which + # fsspec misinterprets. Pass the raw user URI as the path string + # and the authenticated filesystem separately. return storage_context_cls.from_defaults( - persist_dir=str(source), + persist_dir=self.index_persist_dir, fs=source.fs, ) - from pathlib import Path - if not Path(self.index_persist_dir).is_dir(): raise FileNotFoundError( f"Persisted LlamaIndex index not found at '{self.index_persist_dir}'. " diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py index 8e4af71ced00b..2c84e5f3da47a 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py @@ -59,15 +59,23 @@ def _node(text: str = "chunk text", metadata: dict | None = None, vector=None): return node -class TestEmbeddingOperatorInit: - def test_documents_not_templated(self): - # ``documents`` is ``list[dict]`` -- Jinja stringification would - # break it. Explicitly out of template_fields. 
- assert "documents" not in LlamaIndexEmbeddingOperator.template_fields +def _byo_embedding(): + """Return a duck-typed ``BaseEmbedding`` stand-in (has the two methods the operator checks).""" + instance = MagicMock(name="MyBaseEmbedding", spec=["get_text_embedding", "_get_query_embedding"]) + return instance + - def test_templated_fields(self): +class TestEmbeddingOperatorInit: + def test_template_fields(self): + # ``documents`` must be templated so ``loader.output`` (XComArg) is + # resolved before execute. The earlier rationale that "list[dict] + # doesn't survive Jinja stringification" was wrong -- Templater + # unwraps resolvables before Jinja runs. assert set(LlamaIndexEmbeddingOperator.template_fields) == { + "documents", + "embed_model", "llm_conn_id", + "embed_conn_id", "persist_dir", "persist_conn_id", } @@ -95,9 +103,29 @@ def test_string_embed_model_goes_through_hook(self, mock_get_embed, _stub_li): assert result["chunks"][0]["text"] == "chunk a" assert result["chunks"][0]["vector"] == [0.1, 0.2] + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook") + def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _stub_li): + # ``embed_conn_id`` overrides ``llm_conn_id`` for the embedding API. + _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] + + op = LlamaIndexEmbeddingOperator( + task_id="test", + documents=[{"text": "doc"}], + embed_model="text-embedding-3-small", + llm_conn_id="my_llm_conn", + embed_conn_id="my_embed_conn", + ) + op.execute(context=MagicMock()) + + mock_hook_cls.assert_called_once_with( + llm_conn_id="my_llm_conn", + embed_conn_id="my_embed_conn", + embed_model="text-embedding-3-small", + ) + def test_byo_embed_model_bypasses_hook(self, _stub_li): # `embed_model` is a non-string instance -> hook is bypassed. 
- byo = MagicMock(name="MyBaseEmbedding") + byo = _byo_embedding() _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ _node() ] @@ -114,6 +142,20 @@ def test_byo_embed_model_bypasses_hook(self, _stub_li): kwargs = _stub_li["VectorStoreIndex"].call_args.kwargs assert kwargs["embed_model"] is byo + def test_invalid_embed_model_raises_typeerror(self, _stub_li): + # An object that's neither None/str nor duck-types as BaseEmbedding + # (e.g. an unresolved XComArg or random user input) raises TypeError + # with a clear pointer rather than a cryptic downstream error. + _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] + + op = LlamaIndexEmbeddingOperator( + task_id="test", + documents=[{"text": "doc"}], + embed_model=12345, # type: ignore[arg-type] + ) + with pytest.raises(TypeError, match="embed_model must be"): + op.execute(context=MagicMock()) + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _stub_li): _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ @@ -160,8 +202,12 @@ def test_local_persist_dir_calls_makedirs_and_storage_persist( def test_cloud_uri_persist_dir_uses_object_storage_path( self, mock_get_embed, mock_osp_cls, _stub_li ): + # ``ObjectStoragePath.__str__`` returns ``://@/...`` + # when ``conn_id`` is set, which fsspec misinterprets. The operator must + # pass the **raw** user URI to ``persist_dir=`` and supply + # ``fs=target.fs`` for credentials. Asserting against the raw URI here + # catches a regression where ``str(target)`` is used instead. 
target = MagicMock() - target.__str__.return_value = "s3://bucket/idx/" target.fs = MagicMock(name="s3fs") mock_osp_cls.return_value = target diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py index 6cd1c8387ad54..5b126799a06a9 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py @@ -55,6 +55,23 @@ def _scored_node(text: str, score: float, metadata: dict | None = None, node_id: return wrapped +def _byo_embedding(): + """Return a duck-typed ``BaseEmbedding`` stand-in.""" + return MagicMock(name="MyBaseEmbedding", spec=["get_text_embedding", "_get_query_embedding"]) + + +class TestRetrievalOperatorInit: + def test_template_fields(self): + assert set(LlamaIndexRetrievalOperator.template_fields) == { + "query", + "index_persist_dir", + "persist_conn_id", + "embed_model", + "llm_conn_id", + "embed_conn_id", + } + + class TestRetrievalOperatorOutput: @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") def test_chunk_shape(self, mock_get_embed, _stub_li, tmp_path): @@ -106,9 +123,32 @@ def test_top_k_forwarded(self, mock_get_embed, _stub_li, tmp_path): index.as_retriever.assert_called_once_with(similarity_top_k=12) + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook") + def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _stub_li, tmp_path): + # ``embed_conn_id`` overrides ``llm_conn_id`` for the embedding API. 
+ (tmp_path / "idx").mkdir() + index = _stub_li["load_index_from_storage"].return_value + index.as_retriever.return_value.retrieve.return_value = [] + + op = LlamaIndexRetrievalOperator( + task_id="test", + query="q", + index_persist_dir=str(tmp_path / "idx"), + embed_model="text-embedding-3-small", + llm_conn_id="my_llm_conn", + embed_conn_id="my_embed_conn", + ) + op.execute(context=MagicMock()) + + mock_hook_cls.assert_called_once_with( + llm_conn_id="my_llm_conn", + embed_conn_id="my_embed_conn", + embed_model="text-embedding-3-small", + ) + def test_byo_embed_model_bypasses_hook(self, _stub_li, tmp_path): (tmp_path / "idx").mkdir() - byo = MagicMock(name="MyBaseEmbedding") + byo = _byo_embedding() index = _stub_li["load_index_from_storage"].return_value index.as_retriever.return_value.retrieve.return_value = [] @@ -123,6 +163,20 @@ def test_byo_embed_model_bypasses_hook(self, _stub_li, tmp_path): kwargs = _stub_li["load_index_from_storage"].call_args.kwargs assert kwargs["embed_model"] is byo + def test_invalid_embed_model_raises_typeerror(self, _stub_li, tmp_path): + # An object that's neither None/str nor duck-types as BaseEmbedding + # raises TypeError with a clear pointer. 
+ (tmp_path / "idx").mkdir() + + op = LlamaIndexRetrievalOperator( + task_id="test", + query="q", + index_persist_dir=str(tmp_path / "idx"), + embed_model=12345, # type: ignore[arg-type] + ) + with pytest.raises(TypeError, match="embed_model must be"): + op.execute(context=MagicMock()) + class TestRetrievalOperatorMissingIndex: @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") @@ -157,9 +211,13 @@ class TestRetrievalOperatorCloudURI: @patch("airflow.sdk.ObjectStoragePath") @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") def test_cloud_uri_opens_storage_with_fs(self, mock_get_embed, mock_osp_cls, _stub_li): + # ``ObjectStoragePath.__str__`` returns ``://@/...`` + # when ``conn_id`` is set, which fsspec misinterprets. The operator must + # pass the **raw** user URI to ``persist_dir=`` and supply + # ``fs=target.fs`` for credentials. Asserting against the raw URI here + # catches a regression where ``str(target)`` is used instead. target = MagicMock() target.is_dir.return_value = True - target.__str__.return_value = "s3://bucket/idx/" target.fs = MagicMock(name="s3fs") mock_osp_cls.return_value = target From d3654a9c3c88f2a541a635320e1bcc59236f51a4 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 21 May 2026 03:03:23 +0100 Subject: [PATCH 5/7] Fix mypy on LlamaIndex embedding operator - Pass persist_dir as a typed str arg to _persist so the existing None-narrowing # type: ignore comments can go away. - Cast SentenceSplitter nodes to list[TextNode] for the .text access: the splitter only ever returns TextNode, but the base get_nodes_from_documents signature is typed as list[BaseNode]. 
--- .../ai/operators/llamaindex_embedding.py | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py index ab86207b596f4..54ab47ecf461f 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py @@ -20,7 +20,7 @@ import os from collections.abc import Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from airflow.providers.common.compat.sdk import ( AirflowOptionalProviderFeatureException, @@ -30,6 +30,7 @@ if TYPE_CHECKING: from airflow.sdk import Context from llama_index.core.base.embeddings.base import BaseEmbedding + from llama_index.core.schema import TextNode class LlamaIndexEmbeddingOperator(BaseOperator): @@ -131,17 +132,21 @@ def execute(self, context: Context) -> dict[str, Any]: index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=False) if self.persist_dir: - self._persist(index) - + self._persist(index, self.persist_dir) + + # ``SentenceSplitter`` always returns ``TextNode`` instances, but the + # base ``get_nodes_from_documents`` signature is typed as + # ``list[BaseNode]`` (which has no ``.text``). Cast so mypy doesn't + # flag the ``.text`` access; ``node.embedding`` is populated by + # ``VectorStoreIndex`` for every node above. + text_nodes = cast("list[TextNode]", nodes) chunks = [ { "text": node.text, "metadata": node.metadata, - # ``node.embedding`` is populated by ``VectorStoreIndex`` for - # every node since we forced an in-memory build above. "vector": node.embedding, } - for node in nodes + for node in text_nodes ] return { @@ -188,19 +193,19 @@ def _resolve_embed_model(self) -> BaseEmbedding: f"``BaseEmbedding`` instance, or None. 
Got {type(self.embed_model).__name__!r}." ) - def _persist(self, index: Any) -> None: + def _persist(self, index: Any, persist_dir: str) -> None: """Persist the index to ``persist_dir``; cloud URIs go through ObjectStoragePath.""" - if "://" in self.persist_dir: # type: ignore[operator] + if "://" in persist_dir: from airflow.sdk import ObjectStoragePath - target = ObjectStoragePath(self.persist_dir, conn_id=self.persist_conn_id) + target = ObjectStoragePath(persist_dir, conn_id=self.persist_conn_id) target.mkdir(parents=True, exist_ok=True) # ``str(target)`` returns ``s3://@/...`` when # ``conn_id`` is set (see ``task-sdk/.../io/path.py``), which # fsspec misinterprets. Pass the raw user URI as the path string # and the authenticated filesystem separately. - index.storage_context.persist(persist_dir=self.persist_dir, fs=target.fs) + index.storage_context.persist(persist_dir=persist_dir, fs=target.fs) else: - os.makedirs(self.persist_dir, exist_ok=True) # type: ignore[arg-type] - index.storage_context.persist(persist_dir=self.persist_dir) - self.log.info("Index persisted to %s", self.persist_dir) + os.makedirs(persist_dir, exist_ok=True) + index.storage_context.persist(persist_dir=persist_dir) + self.log.info("Index persisted to %s", persist_dir) From 29abed04feb7af079404576a0ea9a31dc52c9a09 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 21 May 2026 03:19:01 +0100 Subject: [PATCH 6/7] Install llama-index in tests instead of stubbing sys.modules llama-index-core / -embeddings-openai / -llms-openai were declared in the common.ai provider's dev dependency group but missing from uv.lock, so CI never actually installed them. The tests papered over that by faking out llama_index.* in sys.modules with MagicMocks. Refresh uv.lock so the packages get installed, then drop the sys.modules manipulation: - test_llamaindex.py: remove the autouse _stub_llama_index_modules fixture entirely; @patch resolves against the real modules. 
- test_llamaindex_embedding.py / test_llamaindex_retrieval.py: replace the _stub_li fixture (sys.modules setitem) with a smaller _li fixture that uses monkeypatch.setattr against real llama_index.core symbols. --- .../unit/common/ai/hooks/test_llamaindex.py | 25 -- .../ai/operators/test_llamaindex_embedding.py | 74 ++--- .../ai/operators/test_llamaindex_retrieval.py | 56 ++-- uv.lock | 297 +++++++++++++++++- 4 files changed, 350 insertions(+), 102 deletions(-) diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py b/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py index 0a4ae14d4271a..9d6e71790b3e0 100644 --- a/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py +++ b/providers/common/ai/tests/unit/common/ai/hooks/test_llamaindex.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import sys from unittest.mock import MagicMock, patch import pytest @@ -24,30 +23,6 @@ from airflow.providers.common.ai.hooks.llamaindex import LlamaIndexHook -@pytest.fixture(autouse=True) -def _stub_llama_index_modules(): - """Stub ``llama_index.*`` in sys.modules so @patch can resolve targets - without the real package installed (mirrors the langchain pattern from - apache/airflow#67237). 
- """ - li = MagicMock() - mocks = { - "llama_index": li, - "llama_index.core": li.core, - "llama_index.core.base": li.core.base, - "llama_index.core.base.embeddings": li.core.base.embeddings, - "llama_index.core.base.embeddings.base": li.core.base.embeddings.base, - "llama_index.core.llms": li.core.llms, - "llama_index.core.llms.llm": li.core.llms.llm, - "llama_index.embeddings": li.embeddings, - "llama_index.embeddings.openai": li.embeddings.openai, - "llama_index.llms": li.llms, - "llama_index.llms.openai": li.llms.openai, - } - with patch.dict(sys.modules, mocks): - yield - - def _conn(password: str = "", host: str = "", extra: dict | None = None) -> MagicMock: mock_conn = MagicMock() mock_conn.password = password diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py index 2c84e5f3da47a..b5b06da6b9fda 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import sys from unittest.mock import MagicMock, patch import pytest @@ -25,30 +24,18 @@ @pytest.fixture -def _stub_li(monkeypatch): - """Patch the few ``llama_index`` symbols the operator imports.""" - Document = MagicMock(name="Document") - StorageContext = MagicMock(name="StorageContext") +def _li(monkeypatch): + """Patch the two LlamaIndex constructors the operator uses inside execute(). + + ``llama_index`` (core + openai embeddings) is a real test dependency + declared in ``providers/common/ai/pyproject.toml``'s dev group, so + ``@patch("llama_index.core.X")`` resolves against the real module. 
+ """ VectorStoreIndex = MagicMock(name="VectorStoreIndex") SentenceSplitter = MagicMock(name="SentenceSplitter") - - li_core = MagicMock( - Document=Document, - StorageContext=StorageContext, - VectorStoreIndex=VectorStoreIndex, - ) - li_core_np = MagicMock(SentenceSplitter=SentenceSplitter) - - monkeypatch.setitem(sys.modules, "llama_index", MagicMock()) - monkeypatch.setitem(sys.modules, "llama_index.core", li_core) - monkeypatch.setitem(sys.modules, "llama_index.core.node_parser", li_core_np) - - return { - "Document": Document, - "StorageContext": StorageContext, - "VectorStoreIndex": VectorStoreIndex, - "SentenceSplitter": SentenceSplitter, - } + monkeypatch.setattr("llama_index.core.VectorStoreIndex", VectorStoreIndex) + monkeypatch.setattr("llama_index.core.node_parser.SentenceSplitter", SentenceSplitter) + return {"VectorStoreIndex": VectorStoreIndex, "SentenceSplitter": SentenceSplitter} def _node(text: str = "chunk text", metadata: dict | None = None, vector=None): @@ -61,8 +48,7 @@ def _node(text: str = "chunk text", metadata: dict | None = None, vector=None): def _byo_embedding(): """Return a duck-typed ``BaseEmbedding`` stand-in (has the two methods the operator checks).""" - instance = MagicMock(name="MyBaseEmbedding", spec=["get_text_embedding", "_get_query_embedding"]) - return instance + return MagicMock(name="MyBaseEmbedding", spec=["get_text_embedding", "_get_query_embedding"]) class TestEmbeddingOperatorInit: @@ -83,9 +69,9 @@ def test_template_fields(self): class TestEmbeddingOperatorExecute: @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") - def test_string_embed_model_goes_through_hook(self, mock_get_embed, _stub_li): + def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li): # `embed_model` as a string -> hook builds OpenAIEmbedding. 
- _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ _node(text="chunk a", vector=[0.1, 0.2]), ] @@ -104,9 +90,9 @@ def test_string_embed_model_goes_through_hook(self, mock_get_embed, _stub_li): assert result["chunks"][0]["vector"] == [0.1, 0.2] @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook") - def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _stub_li): + def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _li): # ``embed_conn_id`` overrides ``llm_conn_id`` for the embedding API. - _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] op = LlamaIndexEmbeddingOperator( task_id="test", @@ -123,12 +109,10 @@ def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _stub_li embed_model="text-embedding-3-small", ) - def test_byo_embed_model_bypasses_hook(self, _stub_li): + def test_byo_embed_model_bypasses_hook(self, _li): # `embed_model` is a non-string instance -> hook is bypassed. byo = _byo_embedding() - _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ - _node() - ] + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] op = LlamaIndexEmbeddingOperator( task_id="test", @@ -138,15 +122,15 @@ def test_byo_embed_model_bypasses_hook(self, _stub_li): op.execute(context=MagicMock()) # VectorStoreIndex called with the user's instance, not anything else. 
- _stub_li["VectorStoreIndex"].assert_called_once() - kwargs = _stub_li["VectorStoreIndex"].call_args.kwargs + _li["VectorStoreIndex"].assert_called_once() + kwargs = _li["VectorStoreIndex"].call_args.kwargs assert kwargs["embed_model"] is byo - def test_invalid_embed_model_raises_typeerror(self, _stub_li): + def test_invalid_embed_model_raises_typeerror(self, _li): # An object that's neither None/str nor duck-types as BaseEmbedding # (e.g. an unresolved XComArg or random user input) raises TypeError # with a clear pointer rather than a cryptic downstream error. - _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] op = LlamaIndexEmbeddingOperator( task_id="test", @@ -157,8 +141,8 @@ def test_invalid_embed_model_raises_typeerror(self, _stub_li): op.execute(context=MagicMock()) @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") - def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _stub_li): - _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ + def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _li): + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ _node(text="x", metadata={"k": "v"}, vector=[1.0, 2.0]), _node(text="y", metadata={"k": "v2"}, vector=[3.0, 4.0]), ] @@ -180,11 +164,11 @@ class TestEmbeddingOperatorPersist: @patch("os.makedirs") @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") def test_local_persist_dir_calls_makedirs_and_storage_persist( - self, mock_get_embed, mock_makedirs, _stub_li, tmp_path + self, mock_get_embed, mock_makedirs, _li, tmp_path ): node = _node() - _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] - index = _stub_li["VectorStoreIndex"].return_value + 
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] + index = _li["VectorStoreIndex"].return_value op = LlamaIndexEmbeddingOperator( task_id="test", @@ -200,7 +184,7 @@ def test_local_persist_dir_calls_makedirs_and_storage_persist( @patch("airflow.sdk.ObjectStoragePath") @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") def test_cloud_uri_persist_dir_uses_object_storage_path( - self, mock_get_embed, mock_osp_cls, _stub_li + self, mock_get_embed, mock_osp_cls, _li ): # ``ObjectStoragePath.__str__`` returns ``://@/...`` # when ``conn_id`` is set, which fsspec misinterprets. The operator must @@ -212,8 +196,8 @@ def test_cloud_uri_persist_dir_uses_object_storage_path( mock_osp_cls.return_value = target node = _node() - _stub_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] - index = _stub_li["VectorStoreIndex"].return_value + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] + index = _li["VectorStoreIndex"].return_value op = LlamaIndexEmbeddingOperator( task_id="test", diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py index 5b126799a06a9..58e2bb751c734 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_retrieval.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import sys from unittest.mock import MagicMock, patch import pytest @@ -25,19 +24,18 @@ @pytest.fixture -def _stub_li(monkeypatch): - """Patch the few ``llama_index`` symbols the retrieval operator imports.""" +def _li(monkeypatch): + """Patch the two LlamaIndex symbols the retrieval operator uses inside execute(). 
+ + ``llama_index`` (core + openai embeddings) is a real test dependency + declared in ``providers/common/ai/pyproject.toml``'s dev group, so + ``monkeypatch.setattr("llama_index.core.X", ...)`` resolves against the + real module. + """ StorageContext = MagicMock(name="StorageContext") load_index_from_storage = MagicMock(name="load_index_from_storage") - - li_core = MagicMock( - StorageContext=StorageContext, - load_index_from_storage=load_index_from_storage, - ) - - monkeypatch.setitem(sys.modules, "llama_index", MagicMock()) - monkeypatch.setitem(sys.modules, "llama_index.core", li_core) - + monkeypatch.setattr("llama_index.core.StorageContext", StorageContext) + monkeypatch.setattr("llama_index.core.load_index_from_storage", load_index_from_storage) return { "StorageContext": StorageContext, "load_index_from_storage": load_index_from_storage, @@ -74,11 +72,11 @@ def test_template_fields(self): class TestRetrievalOperatorOutput: @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") - def test_chunk_shape(self, mock_get_embed, _stub_li, tmp_path): + def test_chunk_shape(self, mock_get_embed, _li, tmp_path): # Make the persist_dir existence check pass. (tmp_path / "idx").mkdir() - index = _stub_li["load_index_from_storage"].return_value + index = _li["load_index_from_storage"].return_value retriever = index.as_retriever.return_value retriever.retrieve.return_value = [ _scored_node("chunk a", 0.91, {"src": "x"}, "node-a"), @@ -101,15 +99,15 @@ def test_chunk_shape(self, mock_get_embed, _stub_li, tmp_path): ], } # The retrieval-time embedding model is passed directly (no Settings mutation). 
- _stub_li["load_index_from_storage"].assert_called_once() - kwargs = _stub_li["load_index_from_storage"].call_args.kwargs + _li["load_index_from_storage"].assert_called_once() + kwargs = _li["load_index_from_storage"].call_args.kwargs assert "embed_model" in kwargs index.as_retriever.assert_called_once_with(similarity_top_k=5) @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") - def test_top_k_forwarded(self, mock_get_embed, _stub_li, tmp_path): + def test_top_k_forwarded(self, mock_get_embed, _li, tmp_path): (tmp_path / "idx").mkdir() - index = _stub_li["load_index_from_storage"].return_value + index = _li["load_index_from_storage"].return_value index.as_retriever.return_value.retrieve.return_value = [] op = LlamaIndexRetrievalOperator( @@ -124,10 +122,10 @@ def test_top_k_forwarded(self, mock_get_embed, _stub_li, tmp_path): index.as_retriever.assert_called_once_with(similarity_top_k=12) @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook") - def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _stub_li, tmp_path): + def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _li, tmp_path): # ``embed_conn_id`` overrides ``llm_conn_id`` for the embedding API. 
(tmp_path / "idx").mkdir() - index = _stub_li["load_index_from_storage"].return_value + index = _li["load_index_from_storage"].return_value index.as_retriever.return_value.retrieve.return_value = [] op = LlamaIndexRetrievalOperator( @@ -146,10 +144,10 @@ def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _stub_li embed_model="text-embedding-3-small", ) - def test_byo_embed_model_bypasses_hook(self, _stub_li, tmp_path): + def test_byo_embed_model_bypasses_hook(self, _li, tmp_path): (tmp_path / "idx").mkdir() byo = _byo_embedding() - index = _stub_li["load_index_from_storage"].return_value + index = _li["load_index_from_storage"].return_value index.as_retriever.return_value.retrieve.return_value = [] op = LlamaIndexRetrievalOperator( @@ -160,10 +158,10 @@ def test_byo_embed_model_bypasses_hook(self, _stub_li, tmp_path): ) op.execute(context=MagicMock()) - kwargs = _stub_li["load_index_from_storage"].call_args.kwargs + kwargs = _li["load_index_from_storage"].call_args.kwargs assert kwargs["embed_model"] is byo - def test_invalid_embed_model_raises_typeerror(self, _stub_li, tmp_path): + def test_invalid_embed_model_raises_typeerror(self, _li, tmp_path): # An object that's neither None/str nor duck-types as BaseEmbedding # raises TypeError with a clear pointer. 
(tmp_path / "idx").mkdir() @@ -180,7 +178,7 @@ def test_invalid_embed_model_raises_typeerror(self, _stub_li, tmp_path): class TestRetrievalOperatorMissingIndex: @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") - def test_local_missing_dir_raises_with_hint(self, mock_get_embed, _stub_li, tmp_path): + def test_local_missing_dir_raises_with_hint(self, mock_get_embed, _li, tmp_path): op = LlamaIndexRetrievalOperator( task_id="test", query="q", @@ -192,7 +190,7 @@ def test_local_missing_dir_raises_with_hint(self, mock_get_embed, _stub_li, tmp_ @patch("airflow.sdk.ObjectStoragePath") @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") - def test_cloud_missing_uri_raises_with_hint(self, mock_get_embed, mock_osp_cls, _stub_li): + def test_cloud_missing_uri_raises_with_hint(self, mock_get_embed, mock_osp_cls, _li): missing = MagicMock() missing.is_dir.return_value = False mock_osp_cls.return_value = missing @@ -210,7 +208,7 @@ def test_cloud_missing_uri_raises_with_hint(self, mock_get_embed, mock_osp_cls, class TestRetrievalOperatorCloudURI: @patch("airflow.sdk.ObjectStoragePath") @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") - def test_cloud_uri_opens_storage_with_fs(self, mock_get_embed, mock_osp_cls, _stub_li): + def test_cloud_uri_opens_storage_with_fs(self, mock_get_embed, mock_osp_cls, _li): # ``ObjectStoragePath.__str__`` returns ``://@/...`` # when ``conn_id`` is set, which fsspec misinterprets. 
The operator must # pass the **raw** user URI to ``persist_dir=`` and supply @@ -221,7 +219,7 @@ def test_cloud_uri_opens_storage_with_fs(self, mock_get_embed, mock_osp_cls, _st target.fs = MagicMock(name="s3fs") mock_osp_cls.return_value = target - index = _stub_li["load_index_from_storage"].return_value + index = _li["load_index_from_storage"].return_value index.as_retriever.return_value.retrieve.return_value = [] op = LlamaIndexRetrievalOperator( @@ -234,7 +232,7 @@ def test_cloud_uri_opens_storage_with_fs(self, mock_get_embed, mock_osp_cls, _st op.execute(context=MagicMock()) mock_osp_cls.assert_called_once_with("s3://bucket/idx/", conn_id="aws_default") - _stub_li["StorageContext"].from_defaults.assert_called_once_with( + _li["StorageContext"].from_defaults.assert_called_once_with( persist_dir="s3://bucket/idx/", fs=target.fs, ) diff --git a/uv.lock b/uv.lock index ab9ae7beb1078..7766a2d3cf466 100644 --- a/uv.lock +++ b/uv.lock @@ -4258,6 +4258,11 @@ google = [ langchain = [ { name = "langchain" }, ] +llamaindex = [ + { name = "llama-index-core" }, + { name = "llama-index-embeddings-openai" }, + { name = "llama-index-llms-openai" }, +] mcp = [ { name = "pydantic-ai-slim", extra = ["mcp"] }, ] @@ -4284,6 +4289,9 @@ dev = [ { name = "apache-airflow-providers-standard" }, { name = "apache-airflow-task-sdk" }, { name = "langchain" }, + { name = "llama-index-core" }, + { name = "llama-index-embeddings-openai" }, + { name = "llama-index-llms-openai" }, { name = "pydantic-ai-slim", extra = ["mcp"] }, { name = "sqlglot" }, ] @@ -4301,6 +4309,9 @@ requires-dist = [ { name = "fastavro", marker = "python_full_version >= '3.14' and extra == 'avro'", specifier = ">=1.12.1" }, { name = "fastavro", marker = "python_full_version < '3.14' and extra == 'avro'", specifier = ">=1.10.0" }, { name = "langchain", marker = "extra == 'langchain'", specifier = ">=1.0.0" }, + { name = "llama-index-core", marker = "extra == 'llamaindex'", specifier = ">=0.13.0" }, + { name = 
"llama-index-embeddings-openai", marker = "extra == 'llamaindex'", specifier = ">=0.6.0" }, + { name = "llama-index-llms-openai", marker = "extra == 'llamaindex'", specifier = ">=0.6.0" }, { name = "pyarrow", marker = "python_full_version >= '3.14' and extra == 'parquet'", specifier = ">=22.0.0" }, { name = "pyarrow", marker = "python_full_version < '3.14' and extra == 'parquet'", specifier = ">=18.0.0" }, { name = "pydantic-ai-slim", specifier = ">=1.34.0" }, @@ -4313,7 +4324,7 @@ requires-dist = [ { name = "python-docx", marker = "extra == 'docx'", specifier = ">=1.0.0" }, { name = "sqlglot", marker = "extra == 'sql'", specifier = ">=30.0.0" }, ] -provides-extras = ["anthropic", "bedrock", "google", "openai", "mcp", "avro", "parquet", "sql", "common-sql", "langchain", "pdf", "docx"] +provides-extras = ["anthropic", "bedrock", "google", "openai", "mcp", "avro", "parquet", "sql", "common-sql", "langchain", "llamaindex", "pdf", "docx"] [package.metadata.requires-dev] dev = [ @@ -4325,6 +4336,9 @@ dev = [ { name = "apache-airflow-providers-standard", editable = "providers/standard" }, { name = "apache-airflow-task-sdk", editable = "task-sdk" }, { name = "langchain", specifier = ">=1.0.0" }, + { name = "llama-index-core", specifier = ">=0.13.0" }, + { name = "llama-index-embeddings-openai", specifier = ">=0.6.0" }, + { name = "llama-index-llms-openai", specifier = ">=0.6.0" }, { name = "pydantic-ai-slim", extras = ["mcp"] }, { name = "sqlglot", specifier = ">=30.0.0" }, ] @@ -9636,6 +9650,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/aa/31/759d077aa680555e17c9d2bb09edf4c3428d895fe5d35a8df67684401b84/backports_zstd-1.5.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:6172dcdd664ef243e55a35e6b45f1c866767c61043f0ddcd908abd14df662065", size = 300853, upload-time = "2026-05-11T19:54:23.1Z" }, ] +[[package]] +name = "banks" +version = "2.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "deprecated" }, + { name = 
"filetype" }, + { name = "griffe" }, + { name = "jinja2" }, + { name = "platformdirs" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bd/51/08fb68d23f4b0f6256fe85dc86e9576941550f890b079352fba719e07b39/banks-2.4.2.tar.gz", hash = "sha256:cda6013bd377ea7b701933578bfb9370fc21ad70bc13cedfc3f5cb2c034ca3dc", size = 188633, upload-time = "2026-04-27T12:15:22.021Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/b6/8dc5477681b782e2f99de703e7a99828883364b9e03a60d3e2c47053d56a/banks-2.4.2-py3-none-any.whl", hash = "sha256:5fe407cc48c101f3e13d1cf732b83b8246003337612f13c0705d2e81f6faffb7", size = 35050, upload-time = "2026-04-27T12:15:20.785Z" }, +] + [[package]] name = "bcrypt" version = "5.0.0" @@ -11029,6 +11060,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/77/dc8c558f7593132cf8fefec57c4f60c83b16941c574ac5f619abb3ae7933/dill-0.4.1-py3-none-any.whl", hash = "sha256:1e1ce33e978ae97fcfcff5638477032b801c46c7c65cf717f95fbc2248f79a9d", size = 120019, upload-time = "2026-01-19T02:36:55.663Z" }, ] +[[package]] +name = "dirtyjson" +version = "1.0.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/db/04/d24f6e645ad82ba0ef092fa17d9ef7a21953781663648a01c9371d9e8e98/dirtyjson-1.0.8.tar.gz", hash = "sha256:90ca4a18f3ff30ce849d100dcf4a003953c79d3a2348ef056f1d9c22231a25fd", size = 30782, upload-time = "2022-11-28T23:32:33.319Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/69/1bcf70f81de1b4a9f21b3a62ec0c83bdff991c88d6cc2267d02408457e88/dirtyjson-1.0.8-py3-none-any.whl", hash = "sha256:125e27248435a58acace26d5c2c4c11a1c0de0a9c5124c5a94ba78e517d74f53", size = 25197, upload-time = "2022-11-28T23:32:31.219Z" }, +] + [[package]] name = "distlib" version = "0.4.0" @@ -11551,6 +11591,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/81/47/dd9a212ef6e343a6857485ffe25bba537304f1913bdbed446a23f7f592e1/filelock-3.29.0-py3-none-any.whl", hash = "sha256:96f5f6344709aa1572bbf631c640e4ebeeb519e08da902c39a001882f30ac258", size = 39812, upload-time = "2026-04-19T15:39:08.752Z" }, ] +[[package]] +name = "filetype" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/29/745f7d30d47fe0f251d3ad3dc2978a23141917661998763bebb6da007eb1/filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb", size = 998020, upload-time = "2022-11-02T17:34:04.141Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/79/1b8fa1bb3568781e84c9200f951c735f3f157429f44be0495da55894d620/filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25", size = 19970, upload-time = "2022-11-02T17:34:01.425Z" }, +] + [[package]] name = "flask" version = "3.1.3" @@ -13134,6 +13183,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/55/f6adf83dd74563aca7721d456b1d33d7656448e29cc79a6aede3bb6ffa5b/gremlinpython-3.8.1-py3-none-any.whl", hash = "sha256:2e8136f9ea8cd771f9cc6f86f4ce73130595aed414a363534e1a4e18bfa81427", size = 75457, upload-time = "2026-04-07T00:22:18.776Z" }, ] +[[package]] +name = "griffe" +version = "2.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "griffecli" }, + { name = "griffelib" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4a/49/eb6d2935e27883af92c930ed40cc4c69bcd32c402be43b8ca4ab20510f67/griffe-2.0.2.tar.gz", hash = "sha256:c5d56326d159f274492e9bf93a9895cec101155d944caa66d0fc4e0c13751b92", size = 293757, upload-time = "2026-03-27T11:34:52.205Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/c0/2bb018eecf9a83c68db9cd9fffd9dab25f102ad30ed869451046e46d1187/griffe-2.0.2-py3-none-any.whl", hash = 
"sha256:2b31816460aee1996af26050a1fc6927a2e5936486856707f55508e4c9b5960b", size = 5141, upload-time = "2026-03-27T11:34:47.721Z" }, +] + +[[package]] +name = "griffecli" +version = "2.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama" }, + { name = "griffelib" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/79/e0/6a7d661d71bb043656a109b91d84a42b5342752542074ec83b16a6eb97f0/griffecli-2.0.2.tar.gz", hash = "sha256:40a1ad4181fc39685d025e119ae2c5b669acdc1f19b705fb9bf971f4e6f6dffb", size = 56281, upload-time = "2026-03-27T11:34:50.087Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/e8/90d93356c88ac34c20cb5edffca68138df55ca9bbd1a06eccfbcec8fdbe5/griffecli-2.0.2-py3-none-any.whl", hash = "sha256:0d44d39e59afa81e288a3e1c3bf352cc4fa537483326ac06b8bb6a51fd8303a0", size = 9500, upload-time = "2026-03-27T11:34:48.81Z" }, +] + [[package]] name = "griffelib" version = "2.0.2" @@ -15035,6 +15110,100 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/6c/5327667e6dbe9e98cbfbd4261c8e91386a52e38f41419575854248bbab6a/litellm-1.82.6-py3-none-any.whl", hash = "sha256:164a3ef3e19f309e3cabc199bef3d2045212712fefdfa25fc7f75884a5b5b205", size = 15591595, upload-time = "2026-03-22T06:35:56.795Z" }, ] +[[package]] +name = "llama-index-core" +version = "0.14.22" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "aiosqlite" }, + { name = "banks" }, + { name = "dataclasses-json" }, + { name = "deprecated" }, + { name = "dirtyjson" }, + { name = "filetype" }, + { name = "fsspec" }, + { name = "httpx" }, + { name = "llama-index-workflows" }, + { name = "nest-asyncio" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = 
"nltk" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "pillow" }, + { name = "platformdirs" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "setuptools" }, + { name = "sqlalchemy", extra = ["asyncio"] }, + { name = "tenacity" }, + { name = "tiktoken" }, + { name = "tinytag" }, + { name = "tqdm" }, + { name = "typing-extensions" }, + { name = "typing-inspect" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/96/7f/94a4b940ef0d069840df0fd6d361a2aa832a2dd73b4cecdf86e8f8c353c8/llama_index_core-0.14.22.tar.gz", hash = "sha256:1384410f89bdbd32349aab444ef4f5c828c338787bc65bd1ffd8e86dfb44ac41", size = 11584786, upload-time = "2026-05-14T20:21:37.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/15/e1a26d8d56aa55fa07587a3e9c7e85294d2df5af6c2229193019bc549ef6/llama_index_core-0.14.22-py3-none-any.whl", hash = "sha256:9cfffde46fd5b7937101e1c0c9bb5c21bd7ff8c8a56937810b87ba3542f31225", size = 11920774, upload-time = "2026-05-14T20:21:40.409Z" }, +] + +[[package]] +name = "llama-index-embeddings-openai" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "llama-index-core" }, + { name = "openai" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/52/eb56a4887501651fb17400f7f571c1878109ff698efbe0bbac9165a5603d/llama_index_embeddings_openai-0.6.0.tar.gz", hash = "sha256:eb3e6606be81cb89125073e23c97c0a6119dabb4827adbd14697c2029ad73f29", size = 7629, upload-time = "2026-03-12T20:21:27.234Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/d1/4bb0b80f4057903110060f617ef519197194b3ff5dd6153d850c8f5676fa/llama_index_embeddings_openai-0.6.0-py3-none-any.whl", hash = 
"sha256:039bb1007ad4267e25ddb89a206dfdab862bfb87d58da4271a3919e4f9df4d61", size = 7666, upload-time = "2026-03-12T20:21:28.079Z" }, +] + +[[package]] +name = "llama-index-instrumentation" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "deprecated" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4e/d0/671b23ccff255c9bce132a84ffd5a6f4541ceefdeab9c1786b08c9722f2e/llama_index_instrumentation-0.5.0.tar.gz", hash = "sha256:eeb724648b25d149de882a5ac9e21c5acb1ce780da214bda2b075341af29ad8e", size = 43831, upload-time = "2026-03-12T20:17:06.742Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/45/6dcaccef44e541ffa138e4b45e33e0d40ab2a7d845338483954fcf77bc75/llama_index_instrumentation-0.5.0-py3-none-any.whl", hash = "sha256:aaab83cddd9dd434278891012d8995f47a3bc7ed1736a371db90965348c56a21", size = 16444, upload-time = "2026-03-12T20:17:05.957Z" }, +] + +[[package]] +name = "llama-index-llms-openai" +version = "0.7.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "llama-index-core" }, + { name = "openai" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/d5/2de9c05f1f1d21eb678a6044c59e943063e70099ac39b8b6f835e6e39511/llama_index_llms_openai-0.7.8.tar.gz", hash = "sha256:3352aed617ee5b7aefeb12719609ff84b4b590a1f49aa1e2e9c383d67ea88b0e", size = 27539, upload-time = "2026-05-08T20:02:09.42Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/49/4250108a76f4f7622109ecb9c57861829f508aba0ffdc502b27134378505/llama_index_llms_openai-0.7.8-py3-none-any.whl", hash = "sha256:967aac1f4ceff99185b2cc425c2757d4fefaf3fac0a35ace247f87a212a29359", size = 28617, upload-time = "2026-05-08T20:02:10.583Z" }, +] + +[[package]] +name = "llama-index-workflows" +version = "2.20.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "llama-index-instrumentation" }, + { name = "pydantic" }, + { name = 
"typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c4/ec/05f3db99a2e6e252e3939e7751cad2fb1322dc6d32f4cf5c795cf7ddcad3/llama_index_workflows-2.20.0.tar.gz", hash = "sha256:df2760fea9e100c97a4e919d255461e344413acac4382d17d8217337806e4772", size = 97410, upload-time = "2026-04-24T14:54:41.524Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/5f/385231406d777cb4b608fd8ebe3577dbd90962770717181e6b91b44fb1b8/llama_index_workflows-2.20.0-py3-none-any.whl", hash = "sha256:36f6b6ace77f837d9907078aea7e830251afe96a58daecff5ed090c88c55095d", size = 121238, upload-time = "2026-04-24T14:54:40.455Z" }, +] + [[package]] name = "lockfile" version = "0.12.2" @@ -16438,6 +16607,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/80/7c/19cd0671d1ba2762fb388fc149697d20d0568ccfeef833b11280a619e526/nh3-0.3.5-cp38-abi3-win_arm64.whl", hash = "sha256:8f85285700a18e9f3fc5bff41fe573fa84f81542ef13b48a89f9fecca0474d3b", size = 611069, upload-time = "2026-04-25T10:44:14.934Z" }, ] +[[package]] +name = "nltk" +version = "3.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "joblib" }, + { name = "regex" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/74/a1/b3b4adf15585a5bc4c357adde150c01ebeeb642173ded4d871e89468767c/nltk-3.9.4.tar.gz", hash = "sha256:ed03bc098a40481310320808b2db712d95d13ca65b27372f8a403949c8b523d0", size = 2946864, upload-time = "2026-03-24T06:13:40.641Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl", hash = "sha256:f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f", size = 1552087, upload-time = "2026-03-24T06:13:38.47Z" }, +] + [[package]] name = "nodeenv" version = "1.10.0" @@ -17522,6 +17706,104 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/5a/26/6cee8a1ce8c43625ec561aff19df07f9776b7525d9002c86bceb3e0ac970/pgvector-0.4.2-py3-none-any.whl", hash = "sha256:549d45f7a18593783d5eec609ea1684a724ba8405c4cb182a0b2b08aeff04e08", size = 27441, upload-time = "2025-12-05T01:07:16.536Z" }, ] +[[package]] +name = "pillow" +version = "12.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/21/c2bcdd5906101a30244eaffc1b6e6ce71a31bd0742a01eb89e660ebfac2d/pillow-12.2.0.tar.gz", hash = "sha256:a830b1a40919539d07806aa58e1b114df53ddd43213d9c8b75847eee6c0182b5", size = 46987819, upload-time = "2026-04-01T14:46:17.687Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/aa/d0b28e1c811cd4d5f5c2bfe2e022292bd255ae5744a3b9ac7d6c8f72dd75/pillow-12.2.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:a4e8f36e677d3336f35089648c8955c51c6d386a13cf6ee9c189c5f5bd713a9f", size = 5354355, upload-time = "2026-04-01T14:42:15.402Z" }, + { url = "https://files.pythonhosted.org/packages/27/8e/1d5b39b8ae2bd7650d0c7b6abb9602d16043ead9ebbfef4bc4047454da2a/pillow-12.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e589959f10d9824d39b350472b92f0ce3b443c0a3442ebf41c40cb8361c5b97", size = 4695871, upload-time = "2026-04-01T14:42:18.234Z" }, + { url = "https://files.pythonhosted.org/packages/f0/c5/dcb7a6ca6b7d3be41a76958e90018d56c8462166b3ef223150360850c8da/pillow-12.2.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a52edc8bfff4429aaabdf4d9ee0daadbbf8562364f940937b941f87a4290f5ff", size = 6269734, upload-time = "2026-04-01T14:42:20.608Z" }, + { url = "https://files.pythonhosted.org/packages/ea/f1/aa1bb13b2f4eba914e9637893c73f2af8e48d7d4023b9d3750d4c5eb2d0c/pillow-12.2.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:975385f4776fafde056abb318f612ef6285b10a1f12b8570f3647ad0d74b48ec", size = 8076080, upload-time = "2026-04-01T14:42:23.095Z" }, + { url = 
"https://files.pythonhosted.org/packages/a1/2a/8c79d6a53169937784604a8ae8d77e45888c41537f7f6f65ed1f407fe66d/pillow-12.2.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bd9c0c7a0c681a347b3194c500cb1e6ca9cab053ea4d82a5cf45b6b754560136", size = 6382236, upload-time = "2026-04-01T14:42:25.82Z" }, + { url = "https://files.pythonhosted.org/packages/b5/42/bbcb6051030e1e421d103ce7a8ecadf837aa2f39b8f82ef1a8d37c3d4ebc/pillow-12.2.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:88d387ff40b3ff7c274947ed3125dedf5262ec6919d83946753b5f3d7c67ea4c", size = 7070220, upload-time = "2026-04-01T14:42:28.68Z" }, + { url = "https://files.pythonhosted.org/packages/3f/e1/c2a7d6dd8cfa6b231227da096fd2d58754bab3603b9d73bf609d3c18b64f/pillow-12.2.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:51c4167c34b0d8ba05b547a3bb23578d0ba17b80a5593f93bd8ecb123dd336a3", size = 6493124, upload-time = "2026-04-01T14:42:31.579Z" }, + { url = "https://files.pythonhosted.org/packages/5f/41/7c8617da5d32e1d2f026e509484fdb6f3ad7efaef1749a0c1928adbb099e/pillow-12.2.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:34c0d99ecccea270c04882cb3b86e7b57296079c9a4aff88cb3b33563d95afaa", size = 7194324, upload-time = "2026-04-01T14:42:34.615Z" }, + { url = "https://files.pythonhosted.org/packages/2d/de/a777627e19fd6d62f84070ee1521adde5eeda4855b5cf60fe0b149118bca/pillow-12.2.0-cp310-cp310-win32.whl", hash = "sha256:b85f66ae9eb53e860a873b858b789217ba505e5e405a24b85c0464822fe88032", size = 6376363, upload-time = "2026-04-01T14:42:37.19Z" }, + { url = "https://files.pythonhosted.org/packages/e7/34/fc4cb5204896465842767b96d250c08410f01f2f28afc43b257de842eed5/pillow-12.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:673aa32138f3e7531ccdbca7b3901dba9b70940a19ccecc6a37c77d5fdeb05b5", size = 7083523, upload-time = "2026-04-01T14:42:39.62Z" }, + { url = 
"https://files.pythonhosted.org/packages/2d/a0/32852d36bc7709f14dc3f64f929a275e958ad8c19a6deba9610d458e28b3/pillow-12.2.0-cp310-cp310-win_arm64.whl", hash = "sha256:3e080565d8d7c671db5802eedfb438e5565ffa40115216eabb8cd52d0ecce024", size = 2463318, upload-time = "2026-04-01T14:42:42.063Z" }, + { url = "https://files.pythonhosted.org/packages/68/e1/748f5663efe6edcfc4e74b2b93edfb9b8b99b67f21a854c3ae416500a2d9/pillow-12.2.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:8be29e59487a79f173507c30ddf57e733a357f67881430449bb32614075a40ab", size = 5354347, upload-time = "2026-04-01T14:42:44.255Z" }, + { url = "https://files.pythonhosted.org/packages/47/a1/d5ff69e747374c33a3b53b9f98cca7889fce1fd03d79cdc4e1bccc6c5a87/pillow-12.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:71cde9a1e1551df7d34a25462fc60325e8a11a82cc2e2f54578e5e9a1e153d65", size = 4695873, upload-time = "2026-04-01T14:42:46.452Z" }, + { url = "https://files.pythonhosted.org/packages/df/21/e3fbdf54408a973c7f7f89a23b2cb97a7ef30c61ab4142af31eee6aebc88/pillow-12.2.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f490f9368b6fc026f021db16d7ec2fbf7d89e2edb42e8ec09d2c60505f5729c7", size = 6280168, upload-time = "2026-04-01T14:42:49.228Z" }, + { url = "https://files.pythonhosted.org/packages/d3/f1/00b7278c7dd52b17ad4329153748f87b6756ec195ff786c2bdf12518337d/pillow-12.2.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8bd7903a5f2a4545f6fd5935c90058b89d30045568985a71c79f5fd6edf9b91e", size = 8088188, upload-time = "2026-04-01T14:42:51.735Z" }, + { url = "https://files.pythonhosted.org/packages/ad/cf/220a5994ef1b10e70e85748b75649d77d506499352be135a4989c957b701/pillow-12.2.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3997232e10d2920a68d25191392e3a4487d8183039e1c74c2297f00ed1c50705", size = 6394401, upload-time = "2026-04-01T14:42:54.343Z" }, + { url = 
"https://files.pythonhosted.org/packages/e9/bd/e51a61b1054f09437acfbc2ff9106c30d1eb76bc1453d428399946781253/pillow-12.2.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e74473c875d78b8e9d5da2a70f7099549f9eb37ded4e2f6a463e60125bccd176", size = 7079655, upload-time = "2026-04-01T14:42:56.954Z" }, + { url = "https://files.pythonhosted.org/packages/6b/3d/45132c57d5fb4b5744567c3817026480ac7fc3ce5d4c47902bc0e7f6f853/pillow-12.2.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:56a3f9c60a13133a98ecff6197af34d7824de9b7b38c3654861a725c970c197b", size = 6503105, upload-time = "2026-04-01T14:42:59.847Z" }, + { url = "https://files.pythonhosted.org/packages/7d/2e/9df2fc1e82097b1df3dce58dc43286aa01068e918c07574711fcc53e6fb4/pillow-12.2.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:90e6f81de50ad6b534cab6e5aef77ff6e37722b2f5d908686f4a5c9eba17a909", size = 7203402, upload-time = "2026-04-01T14:43:02.664Z" }, + { url = "https://files.pythonhosted.org/packages/bd/2e/2941e42858ebb67e50ae741473de81c2984e6eff7b397017623c676e2e8d/pillow-12.2.0-cp311-cp311-win32.whl", hash = "sha256:8c984051042858021a54926eb597d6ee3012393ce9c181814115df4c60b9a808", size = 6378149, upload-time = "2026-04-01T14:43:05.274Z" }, + { url = "https://files.pythonhosted.org/packages/69/42/836b6f3cd7f3e5fa10a1f1a5420447c17966044c8fbf589cc0452d5502db/pillow-12.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:6e6b2a0c538fc200b38ff9eb6628228b77908c319a005815f2dde585a0664b60", size = 7082626, upload-time = "2026-04-01T14:43:08.557Z" }, + { url = "https://files.pythonhosted.org/packages/c2/88/549194b5d6f1f494b485e493edc6693c0a16f4ada488e5bd974ed1f42fad/pillow-12.2.0-cp311-cp311-win_arm64.whl", hash = "sha256:9a8a34cc89c67a65ea7437ce257cea81a9dad65b29805f3ecee8c8fe8ff25ffe", size = 2463531, upload-time = "2026-04-01T14:43:10.743Z" }, + { url = 
"https://files.pythonhosted.org/packages/58/be/7482c8a5ebebbc6470b3eb791812fff7d5e0216c2be3827b30b8bb6603ed/pillow-12.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2d192a155bbcec180f8564f693e6fd9bccff5a7af9b32e2e4bf8c9c69dbad6b5", size = 5308279, upload-time = "2026-04-01T14:43:13.246Z" }, + { url = "https://files.pythonhosted.org/packages/d8/95/0a351b9289c2b5cbde0bacd4a83ebc44023e835490a727b2a3bd60ddc0f4/pillow-12.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f3f40b3c5a968281fd507d519e444c35f0ff171237f4fdde090dd60699458421", size = 4695490, upload-time = "2026-04-01T14:43:15.584Z" }, + { url = "https://files.pythonhosted.org/packages/de/af/4e8e6869cbed569d43c416fad3dc4ecb944cb5d9492defaed89ddd6fe871/pillow-12.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:03e7e372d5240cc23e9f07deca4d775c0817bffc641b01e9c3af208dbd300987", size = 6284462, upload-time = "2026-04-01T14:43:18.268Z" }, + { url = "https://files.pythonhosted.org/packages/e9/9e/c05e19657fd57841e476be1ab46c4d501bffbadbafdc31a6d665f8b737b6/pillow-12.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b86024e52a1b269467a802258c25521e6d742349d760728092e1bc2d135b4d76", size = 8094744, upload-time = "2026-04-01T14:43:20.716Z" }, + { url = "https://files.pythonhosted.org/packages/2b/54/1789c455ed10176066b6e7e6da1b01e50e36f94ba584dc68d9eebfe9156d/pillow-12.2.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7371b48c4fa448d20d2714c9a1f775a81155050d383333e0a6c15b1123dda005", size = 6398371, upload-time = "2026-04-01T14:43:23.443Z" }, + { url = "https://files.pythonhosted.org/packages/43/e3/fdc657359e919462369869f1c9f0e973f353f9a9ee295a39b1fea8ee1a77/pillow-12.2.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:62f5409336adb0663b7caa0da5c7d9e7bdbaae9ce761d34669420c2a801b2780", size = 7087215, upload-time = "2026-04-01T14:43:26.758Z" }, + { url = 
"https://files.pythonhosted.org/packages/8b/f8/2f6825e441d5b1959d2ca5adec984210f1ec086435b0ed5f52c19b3b8a6e/pillow-12.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:01afa7cf67f74f09523699b4e88c73fb55c13346d212a59a2db1f86b0a63e8c5", size = 6509783, upload-time = "2026-04-01T14:43:29.56Z" }, + { url = "https://files.pythonhosted.org/packages/67/f9/029a27095ad20f854f9dba026b3ea6428548316e057e6fc3545409e86651/pillow-12.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc3d34d4a8fbec3e88a79b92e5465e0f9b842b628675850d860b8bd300b159f5", size = 7212112, upload-time = "2026-04-01T14:43:32.091Z" }, + { url = "https://files.pythonhosted.org/packages/be/42/025cfe05d1be22dbfdb4f264fe9de1ccda83f66e4fc3aac94748e784af04/pillow-12.2.0-cp312-cp312-win32.whl", hash = "sha256:58f62cc0f00fd29e64b29f4fd923ffdb3859c9f9e6105bfc37ba1d08994e8940", size = 6378489, upload-time = "2026-04-01T14:43:34.601Z" }, + { url = "https://files.pythonhosted.org/packages/5d/7b/25a221d2c761c6a8ae21bfa3874988ff2583e19cf8a27bf2fee358df7942/pillow-12.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7f84204dee22a783350679a0333981df803dac21a0190d706a50475e361c93f5", size = 7084129, upload-time = "2026-04-01T14:43:37.213Z" }, + { url = "https://files.pythonhosted.org/packages/10/e1/542a474affab20fd4a0f1836cb234e8493519da6b76899e30bcc5d990b8b/pillow-12.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:af73337013e0b3b46f175e79492d96845b16126ddf79c438d7ea7ff27783a414", size = 2463612, upload-time = "2026-04-01T14:43:39.421Z" }, + { url = "https://files.pythonhosted.org/packages/4a/01/53d10cf0dbad820a8db274d259a37ba50b88b24768ddccec07355382d5ad/pillow-12.2.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:8297651f5b5679c19968abefd6bb84d95fe30ef712eb1b2d9b2d31ca61267f4c", size = 4100837, upload-time = "2026-04-01T14:43:41.506Z" }, + { url = 
"https://files.pythonhosted.org/packages/0f/98/f3a6657ecb698c937f6c76ee564882945f29b79bad496abcba0e84659ec5/pillow-12.2.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:50d8520da2a6ce0af445fa6d648c4273c3eeefbc32d7ce049f22e8b5c3daecc2", size = 4176528, upload-time = "2026-04-01T14:43:43.773Z" }, + { url = "https://files.pythonhosted.org/packages/69/bc/8986948f05e3ea490b8442ea1c1d4d990b24a7e43d8a51b2c7d8b1dced36/pillow-12.2.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:766cef22385fa1091258ad7e6216792b156dc16d8d3fa607e7545b2b72061f1c", size = 3640401, upload-time = "2026-04-01T14:43:45.87Z" }, + { url = "https://files.pythonhosted.org/packages/34/46/6c717baadcd62bc8ed51d238d521ab651eaa74838291bda1f86fe1f864c9/pillow-12.2.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5d2fd0fa6b5d9d1de415060363433f28da8b1526c1c129020435e186794b3795", size = 5308094, upload-time = "2026-04-01T14:43:48.438Z" }, + { url = "https://files.pythonhosted.org/packages/71/43/905a14a8b17fdb1ccb58d282454490662d2cb89a6bfec26af6d3520da5ec/pillow-12.2.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:56b25336f502b6ed02e889f4ece894a72612fe885889a6e8c4c80239ff6e5f5f", size = 4695402, upload-time = "2026-04-01T14:43:51.292Z" }, + { url = "https://files.pythonhosted.org/packages/73/dd/42107efcb777b16fa0393317eac58f5b5cf30e8392e266e76e51cff28c3d/pillow-12.2.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f1c943e96e85df3d3478f7b691f229887e143f81fedab9b20205349ab04d73ed", size = 6280005, upload-time = "2026-04-01T14:43:54.242Z" }, + { url = "https://files.pythonhosted.org/packages/a8/68/b93e09e5e8549019e61acf49f65b1a8530765a7f812c77a7461bca7e4494/pillow-12.2.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:03f6fab9219220f041c74aeaa2939ff0062bd5c364ba9ce037197f4c6d498cd9", size = 8090669, upload-time = "2026-04-01T14:43:57.335Z" }, + { url = 
"https://files.pythonhosted.org/packages/4b/6e/3ccb54ce8ec4ddd1accd2d89004308b7b0b21c4ac3d20fa70af4760a4330/pillow-12.2.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5cdfebd752ec52bf5bb4e35d9c64b40826bc5b40a13df7c3cda20a2c03a0f5ed", size = 6395194, upload-time = "2026-04-01T14:43:59.864Z" }, + { url = "https://files.pythonhosted.org/packages/67/ee/21d4e8536afd1a328f01b359b4d3997b291ffd35a237c877b331c1c3b71c/pillow-12.2.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eedf4b74eda2b5a4b2b2fb4c006d6295df3bf29e459e198c90ea48e130dc75c3", size = 7082423, upload-time = "2026-04-01T14:44:02.74Z" }, + { url = "https://files.pythonhosted.org/packages/78/5f/e9f86ab0146464e8c133fe85df987ed9e77e08b29d8d35f9f9f4d6f917ba/pillow-12.2.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:00a2865911330191c0b818c59103b58a5e697cae67042366970a6b6f1b20b7f9", size = 6505667, upload-time = "2026-04-01T14:44:05.381Z" }, + { url = "https://files.pythonhosted.org/packages/ed/1e/409007f56a2fdce61584fd3acbc2bbc259857d555196cedcadc68c015c82/pillow-12.2.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1e1757442ed87f4912397c6d35a0db6a7b52592156014706f17658ff58bbf795", size = 7208580, upload-time = "2026-04-01T14:44:08.39Z" }, + { url = "https://files.pythonhosted.org/packages/23/c4/7349421080b12fb35414607b8871e9534546c128a11965fd4a7002ccfbee/pillow-12.2.0-cp313-cp313-win32.whl", hash = "sha256:144748b3af2d1b358d41286056d0003f47cb339b8c43a9ea42f5fea4d8c66b6e", size = 6375896, upload-time = "2026-04-01T14:44:11.197Z" }, + { url = "https://files.pythonhosted.org/packages/3f/82/8a3739a5e470b3c6cbb1d21d315800d8e16bff503d1f16b03a4ec3212786/pillow-12.2.0-cp313-cp313-win_amd64.whl", hash = "sha256:390ede346628ccc626e5730107cde16c42d3836b89662a115a921f28440e6a3b", size = 7081266, upload-time = "2026-04-01T14:44:13.947Z" }, + { url = 
"https://files.pythonhosted.org/packages/c3/25/f968f618a062574294592f668218f8af564830ccebdd1fa6200f598e65c5/pillow-12.2.0-cp313-cp313-win_arm64.whl", hash = "sha256:8023abc91fba39036dbce14a7d6535632f99c0b857807cbbbf21ecc9f4717f06", size = 2463508, upload-time = "2026-04-01T14:44:16.312Z" }, + { url = "https://files.pythonhosted.org/packages/4d/a4/b342930964e3cb4dce5038ae34b0eab4653334995336cd486c5a8c25a00c/pillow-12.2.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:042db20a421b9bafecc4b84a8b6e444686bd9d836c7fd24542db3e7df7baad9b", size = 5309927, upload-time = "2026-04-01T14:44:18.89Z" }, + { url = "https://files.pythonhosted.org/packages/9f/de/23198e0a65a9cf06123f5435a5d95cea62a635697f8f03d134d3f3a96151/pillow-12.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:dd025009355c926a84a612fecf58bb315a3f6814b17ead51a8e48d3823d9087f", size = 4698624, upload-time = "2026-04-01T14:44:21.115Z" }, + { url = "https://files.pythonhosted.org/packages/01/a6/1265e977f17d93ea37aa28aa81bad4fa597933879fac2520d24e021c8da3/pillow-12.2.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:88ddbc66737e277852913bd1e07c150cc7bb124539f94c4e2df5344494e0a612", size = 6321252, upload-time = "2026-04-01T14:44:23.663Z" }, + { url = "https://files.pythonhosted.org/packages/3c/83/5982eb4a285967baa70340320be9f88e57665a387e3a53a7f0db8231a0cd/pillow-12.2.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d362d1878f00c142b7e1a16e6e5e780f02be8195123f164edf7eddd911eefe7c", size = 8126550, upload-time = "2026-04-01T14:44:26.772Z" }, + { url = "https://files.pythonhosted.org/packages/4e/48/6ffc514adce69f6050d0753b1a18fd920fce8cac87620d5a31231b04bfc5/pillow-12.2.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2c727a6d53cb0018aadd8018c2b938376af27914a68a492f59dfcaca650d5eea", size = 6433114, upload-time = "2026-04-01T14:44:29.615Z" }, + { url = 
"https://files.pythonhosted.org/packages/36/a3/f9a77144231fb8d40ee27107b4463e205fa4677e2ca2548e14da5cf18dce/pillow-12.2.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:efd8c21c98c5cc60653bcb311bef2ce0401642b7ce9d09e03a7da87c878289d4", size = 7115667, upload-time = "2026-04-01T14:44:32.773Z" }, + { url = "https://files.pythonhosted.org/packages/c1/fc/ac4ee3041e7d5a565e1c4fd72a113f03b6394cc72ab7089d27608f8aaccb/pillow-12.2.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9f08483a632889536b8139663db60f6724bfcb443c96f1b18855860d7d5c0fd4", size = 6538966, upload-time = "2026-04-01T14:44:35.252Z" }, + { url = "https://files.pythonhosted.org/packages/c0/a8/27fb307055087f3668f6d0a8ccb636e7431d56ed0750e07a60547b1e083e/pillow-12.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:dac8d77255a37e81a2efcbd1fc05f1c15ee82200e6c240d7e127e25e365c39ea", size = 7238241, upload-time = "2026-04-01T14:44:37.875Z" }, + { url = "https://files.pythonhosted.org/packages/ad/4b/926ab182c07fccae9fcb120043464e1ff1564775ec8864f21a0ebce6ac25/pillow-12.2.0-cp313-cp313t-win32.whl", hash = "sha256:ee3120ae9dff32f121610bb08e4313be87e03efeadfc6c0d18f89127e24d0c24", size = 6379592, upload-time = "2026-04-01T14:44:40.336Z" }, + { url = "https://files.pythonhosted.org/packages/c2/c4/f9e476451a098181b30050cc4c9a3556b64c02cf6497ea421ac047e89e4b/pillow-12.2.0-cp313-cp313t-win_amd64.whl", hash = "sha256:325ca0528c6788d2a6c3d40e3568639398137346c3d6e66bb61db96b96511c98", size = 7085542, upload-time = "2026-04-01T14:44:43.251Z" }, + { url = "https://files.pythonhosted.org/packages/00/a4/285f12aeacbe2d6dc36c407dfbbe9e96d4a80b0fb710a337f6d2ad978c75/pillow-12.2.0-cp313-cp313t-win_arm64.whl", hash = "sha256:2e5a76d03a6c6dcef67edabda7a52494afa4035021a79c8558e14af25313d453", size = 2465765, upload-time = "2026-04-01T14:44:45.996Z" }, + { url = 
"https://files.pythonhosted.org/packages/bf/98/4595daa2365416a86cb0d495248a393dfc84e96d62ad080c8546256cb9c0/pillow-12.2.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:3adc9215e8be0448ed6e814966ecf3d9952f0ea40eb14e89a102b87f450660d8", size = 4100848, upload-time = "2026-04-01T14:44:48.48Z" }, + { url = "https://files.pythonhosted.org/packages/0b/79/40184d464cf89f6663e18dfcf7ca21aae2491fff1a16127681bf1fa9b8cf/pillow-12.2.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:6a9adfc6d24b10f89588096364cc726174118c62130c817c2837c60cf08a392b", size = 4176515, upload-time = "2026-04-01T14:44:51.353Z" }, + { url = "https://files.pythonhosted.org/packages/b0/63/703f86fd4c422a9cf722833670f4f71418fb116b2853ff7da722ea43f184/pillow-12.2.0-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:6a6e67ea2e6feda684ed370f9a1c52e7a243631c025ba42149a2cc5934dec295", size = 3640159, upload-time = "2026-04-01T14:44:53.588Z" }, + { url = "https://files.pythonhosted.org/packages/71/e0/fb22f797187d0be2270f83500aab851536101b254bfa1eae10795709d283/pillow-12.2.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:2bb4a8d594eacdfc59d9e5ad972aa8afdd48d584ffd5f13a937a664c3e7db0ed", size = 5312185, upload-time = "2026-04-01T14:44:56.039Z" }, + { url = "https://files.pythonhosted.org/packages/ba/8c/1a9e46228571de18f8e28f16fabdfc20212a5d019f3e3303452b3f0a580d/pillow-12.2.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:80b2da48193b2f33ed0c32c38140f9d3186583ce7d516526d462645fd98660ae", size = 4695386, upload-time = "2026-04-01T14:44:58.663Z" }, + { url = "https://files.pythonhosted.org/packages/70/62/98f6b7f0c88b9addd0e87c217ded307b36be024d4ff8869a812b241d1345/pillow-12.2.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22db17c68434de69d8ecfc2fe821569195c0c373b25cccb9cbdacf2c6e53c601", size = 6280384, upload-time = "2026-04-01T14:45:01.5Z" }, + { url = 
"https://files.pythonhosted.org/packages/5e/03/688747d2e91cfbe0e64f316cd2e8005698f76ada3130d0194664174fa5de/pillow-12.2.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7b14cc0106cd9aecda615dd6903840a058b4700fcb817687d0ee4fc8b6e389be", size = 8091599, upload-time = "2026-04-01T14:45:04.5Z" }, + { url = "https://files.pythonhosted.org/packages/f6/35/577e22b936fcdd66537329b33af0b4ccfefaeabd8aec04b266528cddb33c/pillow-12.2.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cbeb542b2ebc6fcdacabf8aca8c1a97c9b3ad3927d46b8723f9d4f033288a0f", size = 6396021, upload-time = "2026-04-01T14:45:07.117Z" }, + { url = "https://files.pythonhosted.org/packages/11/8d/d2532ad2a603ca2b93ad9f5135732124e57811d0168155852f37fbce2458/pillow-12.2.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4bfd07bc812fbd20395212969e41931001fd59eb55a60658b0e5710872e95286", size = 7083360, upload-time = "2026-04-01T14:45:09.763Z" }, + { url = "https://files.pythonhosted.org/packages/5e/26/d325f9f56c7e039034897e7380e9cc202b1e368bfd04d4cbe6a441f02885/pillow-12.2.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:9aba9a17b623ef750a4d11b742cbafffeb48a869821252b30ee21b5e91392c50", size = 6507628, upload-time = "2026-04-01T14:45:12.378Z" }, + { url = "https://files.pythonhosted.org/packages/5f/f7/769d5632ffb0988f1c5e7660b3e731e30f7f8ec4318e94d0a5d674eb65a4/pillow-12.2.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:deede7c263feb25dba4e82ea23058a235dcc2fe1f6021025dc71f2b618e26104", size = 7209321, upload-time = "2026-04-01T14:45:15.122Z" }, + { url = "https://files.pythonhosted.org/packages/6a/7a/c253e3c645cd47f1aceea6a8bacdba9991bf45bb7dfe927f7c893e89c93c/pillow-12.2.0-cp314-cp314-win32.whl", hash = "sha256:632ff19b2778e43162304d50da0181ce24ac5bb8180122cbe1bf4673428328c7", size = 6479723, upload-time = "2026-04-01T14:45:17.797Z" }, + { url = 
"https://files.pythonhosted.org/packages/cd/8b/601e6566b957ca50e28725cb6c355c59c2c8609751efbecd980db44e0349/pillow-12.2.0-cp314-cp314-win_amd64.whl", hash = "sha256:4e6c62e9d237e9b65fac06857d511e90d8461a32adcc1b9065ea0c0fa3a28150", size = 7217400, upload-time = "2026-04-01T14:45:20.529Z" }, + { url = "https://files.pythonhosted.org/packages/d6/94/220e46c73065c3e2951bb91c11a1fb636c8c9ad427ac3ce7d7f3359b9b2f/pillow-12.2.0-cp314-cp314-win_arm64.whl", hash = "sha256:b1c1fbd8a5a1af3412a0810d060a78b5136ec0836c8a4ef9aa11807f2a22f4e1", size = 2554835, upload-time = "2026-04-01T14:45:23.162Z" }, + { url = "https://files.pythonhosted.org/packages/b6/ab/1b426a3974cb0e7da5c29ccff4807871d48110933a57207b5a676cccc155/pillow-12.2.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:57850958fe9c751670e49b2cecf6294acc99e562531f4bd317fa5ddee2068463", size = 5314225, upload-time = "2026-04-01T14:45:25.637Z" }, + { url = "https://files.pythonhosted.org/packages/19/1e/dce46f371be2438eecfee2a1960ee2a243bbe5e961890146d2dee1ff0f12/pillow-12.2.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:d5d38f1411c0ed9f97bcb49b7bd59b6b7c314e0e27420e34d99d844b9ce3b6f3", size = 4698541, upload-time = "2026-04-01T14:45:28.355Z" }, + { url = "https://files.pythonhosted.org/packages/55/c3/7fbecf70adb3a0c33b77a300dc52e424dc22ad8cdc06557a2e49523b703d/pillow-12.2.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c0a9f29ca8e79f09de89293f82fc9b0270bb4af1d58bc98f540cc4aedf03166", size = 6322251, upload-time = "2026-04-01T14:45:30.924Z" }, + { url = "https://files.pythonhosted.org/packages/1c/3c/7fbc17cfb7e4fe0ef1642e0abc17fc6c94c9f7a16be41498e12e2ba60408/pillow-12.2.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1610dd6c61621ae1cf811bef44d77e149ce3f7b95afe66a4512f8c59f25d9ebe", size = 8127807, upload-time = "2026-04-01T14:45:33.908Z" }, + { url = 
"https://files.pythonhosted.org/packages/ff/c3/a8ae14d6defd2e448493ff512fae903b1e9bd40b72efb6ec55ce0048c8ce/pillow-12.2.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a34329707af4f73cf1782a36cd2289c0368880654a2c11f027bcee9052d35dd", size = 6433935, upload-time = "2026-04-01T14:45:36.623Z" }, + { url = "https://files.pythonhosted.org/packages/6e/32/2880fb3a074847ac159d8f902cb43278a61e85f681661e7419e6596803ed/pillow-12.2.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8e9c4f5b3c546fa3458a29ab22646c1c6c787ea8f5ef51300e5a60300736905e", size = 7116720, upload-time = "2026-04-01T14:45:39.258Z" }, + { url = "https://files.pythonhosted.org/packages/46/87/495cc9c30e0129501643f24d320076f4cc54f718341df18cc70ec94c44e1/pillow-12.2.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:fb043ee2f06b41473269765c2feae53fc2e2fbf96e5e22ca94fb5ad677856f06", size = 6540498, upload-time = "2026-04-01T14:45:41.879Z" }, + { url = "https://files.pythonhosted.org/packages/18/53/773f5edca692009d883a72211b60fdaf8871cbef075eaa9d577f0a2f989e/pillow-12.2.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:f278f034eb75b4e8a13a54a876cc4a5ab39173d2cdd93a638e1b467fc545ac43", size = 7239413, upload-time = "2026-04-01T14:45:44.705Z" }, + { url = "https://files.pythonhosted.org/packages/c9/e4/4b64a97d71b2a83158134abbb2f5bd3f8a2ea691361282f010998f339ec7/pillow-12.2.0-cp314-cp314t-win32.whl", hash = "sha256:6bb77b2dcb06b20f9f4b4a8454caa581cd4dd0643a08bacf821216a16d9c8354", size = 6482084, upload-time = "2026-04-01T14:45:47.568Z" }, + { url = "https://files.pythonhosted.org/packages/ba/13/306d275efd3a3453f72114b7431c877d10b1154014c1ebbedd067770d629/pillow-12.2.0-cp314-cp314t-win_amd64.whl", hash = "sha256:6562ace0d3fb5f20ed7290f1f929cae41b25ae29528f2af1722966a0a02e2aa1", size = 7225152, upload-time = "2026-04-01T14:45:50.032Z" }, + { url = 
"https://files.pythonhosted.org/packages/ff/6e/cf826fae916b8658848d7b9f38d88da6396895c676e8086fc0988073aaf8/pillow-12.2.0-cp314-cp314t-win_arm64.whl", hash = "sha256:aa88ccfe4e32d362816319ed727a004423aab09c5cea43c01a4b435643fa34eb", size = 2556579, upload-time = "2026-04-01T14:45:52.529Z" }, + { url = "https://files.pythonhosted.org/packages/4e/b7/2437044fb910f499610356d1352e3423753c98e34f915252aafecc64889f/pillow-12.2.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0538bd5e05efec03ae613fd89c4ce0368ecd2ba239cc25b9f9be7ed426b0af1f", size = 5273969, upload-time = "2026-04-01T14:45:55.538Z" }, + { url = "https://files.pythonhosted.org/packages/f6/f4/8316e31de11b780f4ac08ef3654a75555e624a98db1056ecb2122d008d5a/pillow-12.2.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:394167b21da716608eac917c60aa9b969421b5dcbbe02ae7f013e7b85811c69d", size = 4659674, upload-time = "2026-04-01T14:45:58.093Z" }, + { url = "https://files.pythonhosted.org/packages/d4/37/664fca7201f8bb2aa1d20e2c3d5564a62e6ae5111741966c8319ca802361/pillow-12.2.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5d04bfa02cc2d23b497d1e90a0f927070043f6cbf303e738300532379a4b4e0f", size = 5288479, upload-time = "2026-04-01T14:46:01.141Z" }, + { url = "https://files.pythonhosted.org/packages/49/62/5b0ed78fce87346be7a5cfcfaaad91f6a1f98c26f86bdbafa2066c647ef6/pillow-12.2.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0c838a5125cee37e68edec915651521191cef1e6aa336b855f495766e77a366e", size = 7032230, upload-time = "2026-04-01T14:46:03.874Z" }, + { url = "https://files.pythonhosted.org/packages/c3/28/ec0fc38107fc32536908034e990c47914c57cd7c5a3ece4d8d8f7ffd7e27/pillow-12.2.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4a6c9fa44005fa37a91ebfc95d081e8079757d2e904b27103f4f5fa6f0bf78c0", size = 5355404, upload-time = "2026-04-01T14:46:06.33Z" }, + { url = 
"https://files.pythonhosted.org/packages/5e/8b/51b0eddcfa2180d60e41f06bd6d0a62202b20b59c68f5a132e615b75aecf/pillow-12.2.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:25373b66e0dd5905ed63fa3cae13c82fbddf3079f2c8bf15c6fb6a35586324c1", size = 6002215, upload-time = "2026-04-01T14:46:08.83Z" }, + { url = "https://files.pythonhosted.org/packages/bc/60/5382c03e1970de634027cee8e1b7d39776b778b81812aaf45b694dfe9e28/pillow-12.2.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:bfa9c230d2fe991bed5318a5f119bd6780cda2915cca595393649fc118ab895e", size = 7080946, upload-time = "2026-04-01T14:46:11.734Z" }, +] + [[package]] name = "pinecone" version = "9.0.0" @@ -20746,8 +21028,8 @@ name = "secretstorage" version = "3.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cryptography", marker = "python_full_version >= '3.14' or platform_machine != 'arm64' or sys_platform != 'darwin'" }, - { name = "jeepney", marker = "python_full_version >= '3.14' or platform_machine != 'arm64' or sys_platform != 'darwin'" }, + { name = "cryptography", marker = "(python_full_version >= '3.14' and sys_platform == 'darwin') or (python_full_version < '3.15' and sys_platform == 'emscripten') or (python_full_version < '3.15' and sys_platform == 'win32') or (platform_machine != 'arm64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "jeepney", marker = "(python_full_version >= '3.14' and sys_platform == 'darwin') or (python_full_version < '3.15' and sys_platform == 'emscripten') or (python_full_version < '3.15' and sys_platform == 'win32') or (platform_machine != 'arm64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/1c/03/e834bcd866f2f8a49a85eaff47340affa3bfa391ee9912a952a1faa68c7b/secretstorage-3.5.0.tar.gz", 
hash = "sha256:f04b8e4689cbce351744d5537bf6b1329c6fc68f91fa666f60a380edddcd11be", size = 19884, upload-time = "2025-11-23T19:02:53.191Z" } wheels = [ @@ -22082,6 +22364,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289", size = 26610, upload-time = "2024-10-24T14:58:28.029Z" }, ] +[[package]] +name = "tinytag" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/59/8a8cb2331e2602b53e4dc06960f57d1387a2b18e7efd24e5f9cb60ea4925/tinytag-2.2.1.tar.gz", hash = "sha256:e6d06610ebe7cd66fd07be2d3b9495914ab32654a5e47657bb8cd44c2484523c", size = 38214, upload-time = "2026-03-15T18:48:01.11Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/34/d50e338631baaf65ec5396e70085e5de0b52b24b28db1ffbc1c6e82190dc/tinytag-2.2.1-py3-none-any.whl", hash = "sha256:ed8b1e6d25367937e3321e054f4974f9abfde1a3e0a538824c87da377130c2b6", size = 32927, upload-time = "2026-03-15T18:47:59.613Z" }, +] + [[package]] name = "tokenizers" version = "0.23.1" From 3b31d957d1276e82addebeabdae8e822851f8257 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 21 May 2026 04:06:21 +0100 Subject: [PATCH 7/7] Apply ruff lint/format fixes --- .../providers/common/ai/operators/llamaindex_embedding.py | 7 +++---- .../providers/common/ai/operators/llamaindex_retrieval.py | 3 ++- .../unit/common/ai/operators/test_llamaindex_embedding.py | 8 ++------ 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py index 54ab47ecf461f..d85e692100202 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py +++ 
b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py @@ -28,10 +28,11 @@ ) if TYPE_CHECKING: - from airflow.sdk import Context from llama_index.core.base.embeddings.base import BaseEmbedding from llama_index.core.schema import TextNode + from airflow.sdk import Context + class LlamaIndexEmbeddingOperator(BaseOperator): """ @@ -118,9 +119,7 @@ def execute(self, context: Context) -> dict[str, Any]: embed_model = self._resolve_embed_model() - llama_docs = [ - Document(text=doc["text"], metadata=doc.get("metadata", {})) for doc in self.documents - ] + llama_docs = [Document(text=doc["text"], metadata=doc.get("metadata", {})) for doc in self.documents] splitter = SentenceSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) nodes = splitter.get_nodes_from_documents(llama_docs) diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py index 881d20c3d3d09..9725cdace48d1 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_retrieval.py @@ -28,9 +28,10 @@ ) if TYPE_CHECKING: - from airflow.sdk import Context from llama_index.core.base.embeddings.base import BaseEmbedding + from airflow.sdk import Context + class LlamaIndexRetrievalOperator(BaseOperator): """ diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py index b5b06da6b9fda..43b44f87c9ff4 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py @@ -183,9 +183,7 @@ def test_local_persist_dir_calls_makedirs_and_storage_persist( @patch("airflow.sdk.ObjectStoragePath") 
@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") - def test_cloud_uri_persist_dir_uses_object_storage_path( - self, mock_get_embed, mock_osp_cls, _li - ): + def test_cloud_uri_persist_dir_uses_object_storage_path(self, mock_get_embed, mock_osp_cls, _li): # ``ObjectStoragePath.__str__`` returns ``://@/...`` # when ``conn_id`` is set, which fsspec misinterprets. The operator must # pass the **raw** user URI to ``persist_dir=`` and supply @@ -210,6 +208,4 @@ def test_cloud_uri_persist_dir_uses_object_storage_path( mock_osp_cls.assert_called_once_with("s3://bucket/idx/", conn_id="aws_default") target.mkdir.assert_called_once_with(parents=True, exist_ok=True) - index.storage_context.persist.assert_called_once_with( - persist_dir="s3://bucket/idx/", fs=target.fs - ) + index.storage_context.persist.assert_called_once_with(persist_dir="s3://bucket/idx/", fs=target.fs)