diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index f678f056db2c2..98cf2333a46de 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -2692,7 +2692,7 @@ def test_upgrade_to_newer_dependencies( { "docs-list-as-string": "amazon apache.drill apache.druid apache.hive apache.iceberg " "apache.impala apache.pinot clickhousedb common.ai common.compat common.sql databricks elasticsearch " - "exasol google jdbc microsoft.mssql mysql odbc openlineage " + "exasol google informatica jdbc microsoft.mssql mysql odbc openlineage " "oracle pgvector postgres presto slack snowflake sqlite teradata trino vertica ydb", }, id="Common SQL provider package python files changed", diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 67a7e1fa01b2c..a108634d229db 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1386,6 +1386,7 @@ renderer renderers renewer reparse +reparsed replicaSet repo repos diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 969a4e55f590d..bcbd5925437c2 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1136,9 +1136,11 @@ "deps": [ "apache-airflow-providers-common-compat>=1.12.0", "apache-airflow-providers-http>=4.13.2", - "apache-airflow>=3.0.0" + "apache-airflow>=3.0.0", + "sqlglot>=30.0.0" ], "devel-deps": [ + "sqlglot>=30.0.0", "uuid6>=2024.7.10" ], "plugins": [ @@ -1149,6 +1151,7 @@ ], "cross-providers-deps": [ "common.compat", + "common.sql", "http" ], "excluded-python-versions": [], diff --git a/providers/common/ai/src/airflow/providers/common/ai/utils/dataquality/__init__.py b/providers/common/ai/src/airflow/providers/common/ai/utils/dataquality/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/utils/dataquality/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/informatica/README.rst b/providers/informatica/README.rst index 0f78b5043342c..8c5bf20a04438 100644 --- a/providers/informatica/README.rst +++ b/providers/informatica/README.rst @@ -56,6 +56,7 @@ PIP package Version required ``apache-airflow`` ``>=3.0.0`` ``apache-airflow-providers-common-compat`` ``>=1.12.0`` ``apache-airflow-providers-http`` ``>=4.13.2`` +``sqlglot`` ``>=30.0.0`` ========================================== ================== Cross provider package dependencies diff --git a/providers/informatica/dev/docker-compose.yml b/providers/informatica/dev/docker-compose.yml new file mode 100644 index 0000000000000..fb8c68c09fb25 --- /dev/null +++ b/providers/informatica/dev/docker-compose.yml @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- + +# Usage: +# docker-compose up -d +# +# Services: +# postgres : Source/target database for lineage test DAGs (port 55433) +# informatica_sim : Informatica EDC simulator (port 8082) +# +# Airflow connection for the simulator (add via UI or env var): +# Conn ID : informatica_edc_default +# Conn Type : HTTP +# Host : informatica_sim (inside Breeze) or localhost (outside) +# Port : 8082 +# Schema : http +# +# Airflow connection for Postgres (add via UI or env var): +# Conn ID : postgres_lineage +# Conn Type : Postgres +# Host : informatica-lineage-postgres (inside Breeze) or localhost (outside) +# Port : 5433 (host-side mapping) +# Schema : lineage_demo +# Login : airflow +# Password : airflow + +services: + postgres: + image: postgres:16 + container_name: informatica-lineage-postgres + restart: unless-stopped + environment: + POSTGRES_DB: lineage_demo + POSTGRES_USER: airflow + POSTGRES_PASSWORD: airflow + TZ: UTC + ports: + - "55433:5432" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U airflow -d lineage_demo"] + interval: 5s + timeout: 5s + retries: 10 + volumes: + - ./init:/docker-entrypoint-initdb.d:ro + - informatica-lineage-postgres-data:/var/lib/postgresql/data + networks: + - default + - breeze + + informatica_sim: + build: + context: ./informatica_simulator + dockerfile: Dockerfile + container_name: informatica-edc-sim + restart: unless-stopped + ports: + - "8082:8082" + healthcheck: + test: ["CMD-SHELL", "curl -sf http://localhost:8082/access || exit 1"] + interval: 10s + timeout: 5s + retries: 6 + networks: + - default + - breeze + +volumes: + informatica-lineage-postgres-data: + +networks: + breeze: + external: true + name: ${BREEZE_DOCKER_NETWORK:-breeze_default} diff --git a/providers/informatica/dev/informatica_simulator/Dockerfile b/providers/informatica/dev/informatica_simulator/Dockerfile new file mode 100644 index 0000000000000..958ab06f099e0 --- /dev/null +++ b/providers/informatica/dev/informatica_simulator/Dockerfile @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Get the image +FROM python:3.10.10 + +# Set the working directory +WORKDIR /code + +# Copy the requirements package +COPY ./requirements.txt /code/requirements.txt + +# Download the python package requirements in build process +RUN pip install --no-cache-dir --upgrade -r requirements.txt + +# Copy the application +COPY ./app /code/app + +# Expose port 8082 +EXPOSE 8082 + +# Start the server when the container is launched +CMD ["uvicorn", "app.main:app","--proxy-headers", "--host", "0.0.0.0", "--port", "8082"] diff --git a/providers/informatica/dev/informatica_simulator/app/main.py b/providers/informatica/dev/informatica_simulator/app/main.py new file mode 100644 index 0000000000000..8e8687c83f006 --- /dev/null +++ b/providers/informatica/dev/informatica_simulator/app/main.py @@ -0,0 +1,372 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Informatica EDC Simulator. + +A lightweight FastAPI application that mimics the Informatica Enterprise Data +Catalog (EDC) REST API. It is intended for local development and Breeze-based +integration testing of the Airflow Informatica provider. + +Implemented endpoints +--------------------- +GET /access + Health-check — returns {"status": "ok"}. + +GET /access/2/catalog/data/search + Simulates the EDC catalog search used by InformaticaEDCHook to resolve + URIs. Accepts the same query-string parameters as the real API and + returns a hit whose ``id`` matches the ``fq`` filter when supplied. + +GET /access/2/catalog/data/objects/{object_id} + Returns a minimal catalog object. The object is created on first access + and stored in the in-memory registry. + +PATCH /access/1/catalog/data/objects + Accepts the lineage-link payload sent by + ``InformaticaEDCHook.create_lineage_link`` and records the relationship + in memory. + +GET /lineage + Development-only — returns all recorded lineage links. + +DELETE /lineage + Clears all stored lineage links and catalog objects. +""" + +from __future__ import annotations + +import re +import urllib.parse +import uuid +from typing import Any + +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +app = FastAPI(title="Informatica EDC Simulator", version="1.0.0") + +# --------------------------------------------------------------------------- +# In-memory stores +# --------------------------------------------------------------------------- + +_catalog: dict[str, dict[str, Any]] = { + "TEST_PSTGRS://mydb/public/orders": { + "id": "TEST_PSTGRS://mydb/public/orders", + "name": "orders", + "type": "com.infa.ldm.relational.Table", + "core": { + "name": "orders", + "classType": "com.infa.ldm.relational.Table", + }, + }, + "TEST_PSTGRS://mydb/public/customers": { + "id": "TEST_PSTGRS://mydb/public/customers", + "name": "customers", + "type": "com.infa.ldm.relational.Table", + "core": { + "name": "customers", + "classType": "com.infa.ldm.relational.Table", + }, + }, + "TEST_PSTGRS://mydb/public/products": { + "id": "TEST_PSTGRS://mydb/public/products", + "name": "products", + "type": "com.infa.ldm.relational.Table", + "core": { + "name": "products", + "classType": "com.infa.ldm.relational.Table", + }, + }, + "TEST_PSTGRS://mydb/public/order_summary": { + "id": "TEST_PSTGRS://mydb/public/order_summary", + "name": "order_summary", + "type": "com.infa.ldm.relational.Table", + "core": { + "name": "order_summary", + "classType": "com.infa.ldm.relational.Table", + }, + }, + "TEST_PSTGRS://mydb/public/customer_ltv": { + "id": "TEST_PSTGRS://mydb/public/customer_ltv", + "name": "customer_ltv", + "type": "com.infa.ldm.relational.Table", + "core": { + "name": "customer_ltv", + "classType": "com.infa.ldm.relational.Table", + }, + }, + "TEST_PSTGRS://mydb/public/customer_segment_snapshot": { + "id": "TEST_PSTGRS://mydb/public/customer_segment_snapshot", + "name": "customer_segment_snapshot", + "type": "com.infa.ldm.relational.Table", + "core": { + "name": "customer_segment_snapshot", + "classType": "com.infa.ldm.relational.Table", + }, + }, +} +_lineage_links: list[dict[str, str]] = [] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _decode_id(encoded_id: str) -> str: + """Reverse the tilde-encoding used by InformaticaEDCHook._encode_id.""" + + def replace_tilde(m: re.Match) -> str: + return bytes.fromhex(m.group(1)).decode("utf-8") + + decoded = re.sub(r"~([0-9a-f]+)~", replace_tilde, encoded_id) + return urllib.parse.unquote(decoded) + + +def _get_or_create_object(object_id: str) -> dict[str, Any]: + if object_id not in _catalog: + _catalog[object_id] = { + "id": object_id, + "name": object_id.split("/")[-1], + "type": "com.infa.ldm.relational.Table", + "core": { + "name": object_id.split("/")[-1], + "classType": "com.infa.ldm.relational.Table", + }, + } + return _catalog[object_id] + + +def _parse_class_types(filter_expr: str) -> set[str]: + """Extract class types from expressions like 'core.classType:A OR core.classType:B'.""" + class_types: set[str] = set() + for part in filter_expr.split(" OR "): + candidate = part.strip() + prefix = "core.classType:" + if candidate.startswith(prefix): + class_type = candidate[len(prefix) :].strip().strip('"') + if class_type: + class_types.add(class_type) + return class_types + + +def _extract_filter_value(filter_expr: str, key: str) -> str | None: + """Return value from a simple filter expression like 'core.name:orders'.""" + prefix = f"{key}:" + if not filter_expr.startswith(prefix): + return None + return filter_expr[len(prefix) :].strip().strip('"') + + +# --------------------------------------------------------------------------- +# Health-check +# --------------------------------------------------------------------------- + + +@app.get("/access") +def health_check(): + """EDC connectivity check used by InformaticaEDCHook.""" + return {"status": "ok"} + + +# --------------------------------------------------------------------------- +# Catalog search (GET /access/2/catalog/data/search) +# --------------------------------------------------------------------------- + + +@app.get("/access/2/catalog/data/search") +def edc_catalog_search( + request: Request, + q: str | None = None, + fq: list[str] | None = None, + fl: str = "core.name", + defaultFacets: bool = True, + disableSemanticSearch: bool = False, + enableLegacySearch: bool = False, + facet: bool = False, + highlight: bool = False, + includeRefObjects: bool = False, +): + """ + Simulate EDC catalog search. + + Supports the fq filters used by InformaticaEDCHook: + - core.externalId: + - core.classType: OR core.classType: + - core.name: + + Multiple fq filters are combined with AND semantics. + """ + class_type_filters: set[str] = set() + core_name_filter: str | None = None + external_id_filter: str | None = None + + raw_fq_list = request.query_params.getlist("fq") + fq_filters = raw_fq_list if raw_fq_list else (fq or []) + + for filter_expr in fq_filters: + external_id = _extract_filter_value(filter_expr, "core.externalId") + if external_id is not None: + external_id_filter = external_id + continue + + core_name = _extract_filter_value(filter_expr, "core.name") + if core_name is not None: + core_name_filter = core_name + continue + + class_type_filters.update(_parse_class_types(filter_expr)) + + if external_id_filter: + obj = _get_or_create_object(external_id_filter) + hits = [obj] + else: + hits = list(_catalog.values()) + + if class_type_filters: + hits = [obj for obj in hits if obj.get("core", {}).get("classType") in class_type_filters] + + # Treat empty core.name filter as wildcard to mimic permissive EDC behavior. + if core_name_filter: + hits = [obj for obj in hits if obj.get("core", {}).get("name") == core_name_filter] + + return { + "hits": hits, + "facets": [], + "count": len(hits), + "queryTime": 1, + } + + +# --------------------------------------------------------------------------- +# Fetch object by ID (GET /access/2/catalog/data/objects/{object_id}) +# --------------------------------------------------------------------------- + + +@app.get("/access/2/catalog/data/objects/{object_id:path}") +def get_catalog_object(object_id: str): + """ + Return a catalog object by its (possibly tilde-encoded) ID. + + InformaticaEDCHook calls this endpoint after searching; the response must + contain at least ``{"id": "", ...}`` for lineage to proceed. + """ + decoded_id = _decode_id(object_id) + return _get_or_create_object(decoded_id) + + +# --------------------------------------------------------------------------- +# Create/update lineage (PATCH /access/1/catalog/data/objects) +# --------------------------------------------------------------------------- + + +@app.patch("/access/1/catalog/data/objects") +async def update_catalog_objects(request: Request): + """ + Accept a lineage-link payload from InformaticaEDCHook.create_lineage_link. + + Stores the relationship in memory and returns a minimal success envelope. + """ + try: + body: dict[str, Any] = await request.json() + except Exception: + raise HTTPException(status_code=400, detail="Invalid JSON body") + + updates = body.get("updates", []) + for update in updates: + target_id = update.get("id", "") + for link in update.get("newSourceLinks", []): + source_id = link.get("objectId", "") + association = link.get("associationId", "core.DataSetDataFlow") + _lineage_links.append( + { + "source": source_id, + "target": target_id, + "association": association, + } + ) + + return JSONResponse( + status_code=200, + content={"status": "success", "linksCreated": len(updates)}, + ) + + +# --------------------------------------------------------------------------- +# Inspection endpoints (dev-only) +# --------------------------------------------------------------------------- + + +@app.get("/lineage") +def list_lineage(): + """Return all recorded lineage links (dev helper).""" + return {"links": _lineage_links, "total": len(_lineage_links)} + + +@app.delete("/lineage") +def clear_lineage(): + """Clear all in-memory state (useful between test runs).""" + _lineage_links.clear() + _catalog.clear() + return {"status": "cleared"} + + +@app.get("/catalog") +def list_catalog(): + """Return all known catalog objects (dev helper).""" + return {"objects": list(_catalog.values()), "total": len(_catalog)} + + +# --------------------------------------------------------------------------- +# Legacy / compatibility endpoints kept from original simulator +# --------------------------------------------------------------------------- + + +class _Table(BaseModel): + name: str + columns: list[str] + + +class _Resource(BaseModel): + name: str + type: str + tables: list[_Table] = [] + + +_resources_db: dict[str, dict] = {} + + +@app.get("/resources") +def list_resources(): + return list(_resources_db.values()) + + +@app.post("/resources") +def create_resource(resource: _Resource): + resource_id = str(uuid.uuid4()) + _resources_db[resource_id] = resource.model_dump() + return {"resource_id": resource_id} + + +@app.post("/resources/{resource_id}/tables") +def add_table(resource_id: str, table: _Table): + if resource_id not in _resources_db: + raise HTTPException(status_code=404, detail="Resource not found") + _resources_db[resource_id]["tables"].append(table.model_dump()) + return table diff --git a/providers/informatica/dev/informatica_simulator/requirements.txt b/providers/informatica/dev/informatica_simulator/requirements.txt new file mode 100644 index 0000000000000..11eb36395c071 --- /dev/null +++ b/providers/informatica/dev/informatica_simulator/requirements.txt @@ -0,0 +1,8 @@ +aiohttp==3.12.14 +fastapi==0.95.1 +requests==2.32.4 +python-multipart +uvicorn==0.21.1 +pyarrow==14.0.1 +python-dateutil==2.8.2 +paramiko==3.4.0 diff --git a/providers/informatica/dev/init/001_schema_and_seed.sql b/providers/informatica/dev/init/001_schema_and_seed.sql new file mode 100644 index 0000000000000..9bda9ef374de5 --- /dev/null +++ b/providers/informatica/dev/init/001_schema_and_seed.sql @@ -0,0 +1,152 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +-- ============================================================ +-- Lineage demo schema +-- Source tables: customers, products, orders +-- Target tables: order_summary, customer_ltv +-- ============================================================ + +SET TIME ZONE 'UTC'; + +-- ------------------------------------------------------------ +-- Source tables +-- ------------------------------------------------------------ + +CREATE TABLE customers ( + customer_id BIGSERIAL PRIMARY KEY, + first_name TEXT NOT NULL, + last_name TEXT NOT NULL, + email TEXT NOT NULL UNIQUE, + country TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE products ( + product_id BIGSERIAL PRIMARY KEY, + sku TEXT NOT NULL UNIQUE, + product_name TEXT NOT NULL, + category TEXT NOT NULL, + unit_price NUMERIC(12,2) NOT NULL +); + +CREATE TABLE orders ( + order_id BIGSERIAL PRIMARY KEY, + customer_id BIGINT NOT NULL REFERENCES customers(customer_id), + product_id BIGINT NOT NULL REFERENCES products(product_id), + quantity INT NOT NULL DEFAULT 1, + unit_price NUMERIC(12,2) NOT NULL, + order_date DATE NOT NULL, + status TEXT NOT NULL DEFAULT 'completed' +); + +CREATE TABLE order_items ( + item_id BIGSERIAL PRIMARY KEY, + order_id BIGINT NOT NULL REFERENCES orders(order_id), + product_id BIGINT NOT NULL REFERENCES products(product_id), + quantity INT NOT NULL, + line_total NUMERIC(12,2) NOT NULL +); + +-- ------------------------------------------------------------ +-- Target / aggregation tables (populated by DAG tasks) +-- ------------------------------------------------------------ + +CREATE TABLE order_summary ( + order_id BIGINT PRIMARY KEY, + customer_email TEXT NOT NULL, + product_name TEXT NOT NULL, + category TEXT NOT NULL, + quantity INT NOT NULL, + total_amount NUMERIC(12,2) NOT NULL, + order_date DATE NOT NULL, + country TEXT NOT NULL, + processed_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE customer_ltv ( + customer_id BIGINT PRIMARY KEY, + email TEXT NOT NULL, + full_name TEXT NOT NULL, + country TEXT NOT NULL, + total_orders BIGINT NOT NULL DEFAULT 0, + total_spent NUMERIC(14,2) NOT NULL DEFAULT 0, + avg_order_value NUMERIC(14,2), + first_order_date DATE, + last_order_date DATE, + calculated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- ------------------------------------------------------------ +-- Seed data +-- ------------------------------------------------------------ + +INSERT INTO customers (first_name, last_name, email, country) VALUES + ('Alice', 'Johnson', 'alice@example.com', 'US'), + ('Bob', 'Smith', 'bob@example.com', 'UK'), + ('Carlos', 'Garcia', 'carlos@example.com', 'ES'), + ('Diana', 'Lee', 'diana@example.com', 'DE'), + ('Eve', 'Martinez', 'eve@example.com', 'FR'), + ('Frank', 'Brown', 'frank@example.com', 'US'), + ('Grace', 'Wilson', 'grace@example.com', 'CA'), + ('Hiro', 'Tanaka', 'hiro@example.com', 'JP'); + +INSERT INTO products (sku, product_name, category, unit_price) VALUES + ('LAPTOP-001', 'Pro Laptop 15"', 'Electronics', 1299.99), + ('PHONE-001', 'Smart Phone X', 'Electronics', 799.99), + ('DESK-001', 'Standing Desk', 'Furniture', 499.99), + ('CHAIR-001', 'Ergonomic Chair', 'Furniture', 349.99), + ('MONITOR-001', 'UltraWide 34"', 'Electronics', 699.99), + ('HEADSET-001', 'Noise-Cancel HS', 'Electronics', 249.99), + ('KEYBOARD-001','Mech Keyboard', 'Electronics', 149.99), + ('MOUSE-001', 'Wireless Mouse', 'Electronics', 69.99); + +INSERT INTO orders (customer_id, product_id, quantity, unit_price, order_date) VALUES + (1, 1, 1, 1299.99, '2024-01-05'), + (1, 6, 1, 249.99, '2024-01-10'), + (2, 2, 1, 799.99, '2024-01-12'), + (2, 7, 2, 149.99, '2024-02-01'), + (3, 3, 1, 499.99, '2024-02-14'), + (3, 4, 1, 349.99, '2024-02-14'), + (4, 5, 2, 699.99, '2024-03-03'), + (5, 1, 1, 1299.99, '2024-03-15'), + (5, 8, 1, 69.99, '2024-03-15'), + (6, 2, 1, 799.99, '2024-04-01'), + (6, 6, 1, 249.99, '2024-04-05'), + (7, 3, 2, 499.99, '2024-04-20'), + (8, 5, 1, 699.99, '2024-05-01'), + (8, 7, 1, 149.99, '2024-05-10'), + (1, 8, 3, 69.99, '2024-05-15'); + +INSERT INTO order_items (order_id, product_id, quantity, line_total) VALUES + (1, 1, 1, 1299.99), + (2, 6, 1, 249.99), + (3, 2, 1, 799.99), + (4, 7, 2, 299.98), + (5, 3, 1, 499.99), + (6, 4, 1, 349.99), + (7, 5, 2, 1399.98), + (8, 1, 1, 1299.99), + (9, 8, 1, 69.99), + (10, 2, 1, 799.99), + (11, 6, 1, 249.99), + (12, 3, 2, 999.98), + (13, 5, 1, 699.99), + (14, 7, 1, 149.99), + (15, 8, 3, 209.97); diff --git a/providers/informatica/docs/changelog.rst b/providers/informatica/docs/changelog.rst index 0afd84f5bfd62..40da64b0ef876 100644 --- a/providers/informatica/docs/changelog.rst +++ b/providers/informatica/docs/changelog.rst @@ -28,6 +28,19 @@ Changelog ========= +0.2.0 +..... + +Features +~~~~~~~~ + +* ``Add automatic SQL lineage detection via sqlglot for SQL operators`` +* ``Add per-task and per-DAG selective lineage control with disable/enable helpers`` +* ``Refactor lineage validation into shared logic: best-effort listener resolution plus optional strict pre_execute validation`` +* ``Add disabled_for_operators, auto_lineage_enabled, and request_timeout config options`` +* ``Add example DAG demonstrating auto, manual, disabled, and operator-excluded lineage modes`` + + 0.1.4 ..... diff --git a/providers/informatica/docs/guides/api.rst b/providers/informatica/docs/guides/api.rst index 917b7c5395783..86f022b0fa69a 100644 --- a/providers/informatica/docs/guides/api.rst +++ b/providers/informatica/docs/guides/api.rst @@ -101,7 +101,7 @@ Plugins **InformaticaProviderPlugin** ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The ``InformaticaProviderPlugin`` registers event listeners that monitor Airflow task lifecycle events (start, success, failure) and trigger lineage extraction and EDC API calls. This plugin is loaded automatically when the provider is installed and enabled. +The ``InformaticaProviderPlugin`` registers event listeners that monitor Airflow task lifecycle events (running, success, failure) and trigger lineage extraction and EDC API calls. Listener resolution is best-effort: resolution exceptions are logged, and task execution is not failed by listener hooks. Use ``pre_execute=validate_informatica_lineage`` on an operator when strict pre-execution failure is required. No manual instantiation is required. The plugin works transparently with any task that defines inlets and outlets. diff --git a/providers/informatica/docs/guides/configuration.rst b/providers/informatica/docs/guides/configuration.rst index 007772f57da7a..b4a3ab7febe70 100644 --- a/providers/informatica/docs/guides/configuration.rst +++ b/providers/informatica/docs/guides/configuration.rst @@ -50,14 +50,85 @@ Add to your ``airflow.cfg``: listener_disabled = False # The connection ID to use when no connection ID is provided default_conn_id = informatica_edc_default + # Enable automatic SQL lineage detection (parses the sql attribute of operators) + auto_lineage_enabled = True + # Semicolon-separated fully-qualified class names of operators to exclude from lineage + disabled_for_operators = + # HTTP request timeout in seconds for EDC API calls + request_timeout = 30 -Provider Configuration ----------------------- +``auto_lineage_enabled`` +~~~~~~~~~~~~~~~~~~~~~~~~ -The provider configuration is defined in ``get_provider_info.py`` and includes: +When ``True`` (default), the provider inspects each task's ``sql`` attribute before +execution, parses it with `sqlglot `_, resolves the discovered +tables against the Informatica catalog, and creates lineage links on task success. -- ``listener_disabled``: Boolean flag to disable sending events without uninstalling the provider -- ``default_conn_id``: Default connection ID for Informatica EDC +Set to ``False`` to rely exclusively on manually declared ``inlets`` and ``outlets``. + +``disabled_for_operators`` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A semicolon-separated list of fully-qualified Python class names. Operators whose +class matches an entry in this list are excluded entirely from lineage processing — +both automatic and manual inlets/outlets are ignored. + +Example: + +.. code-block:: ini + + [informatica] + disabled_for_operators = airflow.providers.standard.operators.bash.BashOperator;airflow.providers.standard.operators.python.PythonOperator + +``request_timeout`` +~~~~~~~~~~~~~~~~~~~ + +Timeout in seconds applied to every HTTP request made to the EDC REST API. +Increase this value for slow or high-latency networks. + +Strict Pre-execute Validation +----------------------------- + +Listener hooks are best-effort by default. If lineage objects cannot be resolved, +the listener logs a warning and task execution continues. + +To fail a task before ``execute()`` when lineage resolution fails, set +``pre_execute=validate_informatica_lineage`` on the operator: + +.. code-block:: python + + from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator + from airflow.providers.informatica.lineage.validation import validate_informatica_lineage + + task = SQLExecuteQueryOperator( + task_id="transform", + conn_id="postgres_default", + sql="INSERT INTO dst SELECT * FROM src", + pre_execute=validate_informatica_lineage, + ) + +Per-task Selective Lineage +-------------------------- + +You can disable or re-enable automatic lineage on individual tasks (or entire DAGs) at +DAG definition time using the helper functions in ``airflow.providers.informatica.lineage``: + +.. code-block:: python + + from airflow.providers.informatica.lineage import ( + disable_informatica_lineage, + enable_informatica_lineage, + ) + + with DAG("my_dag", ...) as dag: + task_a = SomeSQLOperator(task_id="task_a", sql="SELECT * FROM orders", ...) + task_b = SomeSQLOperator(task_id="task_b", sql="SELECT * FROM customers", ...) + + # Disable auto-lineage for task_a only + disable_informatica_lineage(task_a) + + # Disable auto-lineage for all tasks in a DAG + disable_informatica_lineage(dag) SSL and Security ---------------- diff --git a/providers/informatica/docs/guides/usage.rst b/providers/informatica/docs/guides/usage.rst index 956bfaf5916ac..273e49fd8dc36 100644 --- a/providers/informatica/docs/guides/usage.rst +++ b/providers/informatica/docs/guides/usage.rst @@ -26,13 +26,18 @@ How It Works The Informatica plugin automatically detects tasks with lineage support and sends inlet/outlet information to Informatica EDC when tasks succeed. No additional configuration is required beyond defining inlets and outlets in your tasks. +Lineage resolution in listener hooks is best-effort by default: resolution errors are logged as warnings and task execution continues. For strict behavior that fails a task before ``execute()`` when lineage cannot be resolved, set ``pre_execute=validate_informatica_lineage`` on that operator. + Key Features ------------ -- **Automatic Lineage Detection**: Plugin automatically detects tasks with lineage support -- **EDC Integration**: Native REST API integration with Informatica Enterprise Data Catalog -- **Transparent Operation**: No code changes required beyond inlet/outlet definitions -- **Error Handling**: Robust error handling for API failures and invalid objects +- **Manual Lineage**: Explicitly declare inlets and outlets using EDC object URIs. By default, resolution is attempted in listeners and warnings are logged if objects cannot be resolved. +- **Automatic SQL Lineage**: When ``auto_lineage_enabled = True`` (the default), the provider parses the ``sql`` attribute of SQL operators, resolves detected tables in the Informatica catalog, and creates lineage links automatically. Supported SQL dialects include PostgreSQL, MySQL, Snowflake, BigQuery, Databricks, Redshift, SQLite, Oracle, Trino, Presto, Hive, Spark, and MSSQL. +- **Lineage Priority**: Manual inlets/outlets always take precedence over automatic SQL lineage. If a task has any inlets or outlets defined, SQL parsing is skipped entirely. +- **Per-task Control**: Disable or re-enable automatic lineage per task or per DAG using :func:`~airflow.providers.informatica.lineage.disable_informatica_lineage` and :func:`~airflow.providers.informatica.lineage.enable_informatica_lineage`. +- **Operator Exclusion**: Exclude entire operator classes via ``disabled_for_operators`` in ``airflow.cfg``. +- **Optional Fail-fast Validation**: For tasks that must enforce lineage integrity, use ``pre_execute=validate_informatica_lineage`` so unresolvable URIs or tables fail the task *before* execution begins. +- **EDC Integration**: Native REST API integration with Informatica Enterprise Data Catalog. - **Configurable**: Extensive configuration options for different environments Architecture @@ -83,82 +88,126 @@ Quick Start The provider automatically handles lineage extraction when tasks succeed. -Example DAG ------------ +Automatic SQL Lineage +--------------------- + +When ``auto_lineage_enabled = True`` (the default), the provider automatically detects SQL +operators and creates lineage without any explicit ``inlets``/``outlets`` declarations. + +.. code-block:: python + + from airflow import DAG + from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator + from datetime import datetime + + with DAG("my_sql_dag", start_date=datetime(2024, 1, 1), schedule=None) as dag: + transform = SQLExecuteQueryOperator( + task_id="transform", + conn_id="postgres_default", + sql="INSERT INTO summary SELECT region, SUM(amount) FROM sales GROUP BY region", + ) + +The provider parses the SQL, finds ``sales`` as the source and ``summary`` as the target, +resolves both against the Informatica catalog, and creates the lineage link on task success. + +The SQL dialect is inferred automatically from the connection ID string (e.g., a connection +ID containing ``postgres`` maps to the PostgreSQL dialect, ``snowflake`` to Snowflake, etc.). + +.. note:: + + **SQL parsing is powered by** `sqlglot `__ **and is + subject to its parsing capabilities.** + + sqlglot covers a wide range of standard SQL constructs across the supported dialects, but + certain complex or dialect-specific patterns may not be parsed correctly. Examples of + queries that can produce incomplete or incorrect table extraction include: + + - Deeply nested or recursive CTEs (``WITH RECURSIVE``) + - Vendor-specific procedural extensions (e.g., ``PL/pgSQL`` ``EXECUTE``, T-SQL dynamic SQL) + - Queries built via dynamic string concatenation or stored procedures + - Non-standard or proprietary syntax not yet supported by sqlglot + + When the parser cannot reliably identify source or target tables, no automatic lineage is + created for that statement and a debug-level log entry is emitted. For such cases, fall + back to **manual lineage** by explicitly declaring ``inlets`` and ``outlets`` on the task, + which bypasses SQL parsing entirely and gives you full control over the lineage graph. + +Manual Lineage +-------------- + +Define inlets and outlets explicitly using EDC object URIs. These always take priority over +automatic SQL lineage. .. code-block:: python from airflow import DAG from airflow.providers.standard.operators.python import PythonOperator + from airflow.sdk import Asset from datetime import datetime - def my_python_task(**kwargs): - print("Hello Informatica Lineage!") + def my_python_task(**kwargs): ... - with DAG( - dag_id="example_informatica_lineage_dag", - start_date=datetime(2024, 1, 1), - schedule=None, - catchup=False, - ) as dag: - python_task = PythonOperator( - task_id="my_python_task", + with DAG("my_dag", start_date=datetime(2024, 1, 1), schedule=None) as dag: + task = PythonOperator( + task_id="transform", python_callable=my_python_task, - inlets=[{"dataset_uri": "edc://object/source_table_abc123"}], - outlets=[{"dataset_uri": "edc://object/target_table_xyz789"}], + inlets=[Asset("edc://object/source_table_abc123")], + outlets=[Asset("edc://object/target_table_xyz789")], ) -When this task succeeds, the provider automatically creates a lineage link between the source and target objects in EDC. +When this task succeeds, the provider creates a lineage link between the source and target +objects in EDC. -Hooks ------ +By default, unresolvable URIs are logged as warnings by listener hooks and do not block task +execution. To fail the task before ``execute()`` when lineage resolution fails, set +``pre_execute=validate_informatica_lineage``: -InformaticaEDCHook -^^^^^^^^^^^^^^^^^^ +.. code-block:: python -The hook provides low-level access to Informatica EDC API. + from airflow.providers.informatica.lineage.validation import validate_informatica_lineage -.. code-block:: python + task = PythonOperator( + task_id="transform", + python_callable=my_python_task, + inlets=[Asset("edc://object/source_table_abc123")], + outlets=[Asset("edc://object/target_table_xyz789")], + pre_execute=validate_informatica_lineage, + ) - from airflow.providers.informatica.hooks.edc import InformaticaEDCHook +Selective Lineage Control +------------------------- - hook = InformaticaEDCHook(informatica_edc_conn_id="my_connection") - object_data = hook.get_object("edc://object/table_123") - result = hook.create_lineage_link("source_id", "target_id") +Use the helpers in :mod:`airflow.providers.informatica.lineage` to disable or re-enable +automatic lineage on individual tasks or entire DAGs: -Plugins and Listeners ---------------------- +.. code-block:: python -The ``InformaticaProviderPlugin`` automatically registers listeners that: + from airflow.providers.informatica.lineage import ( + disable_informatica_lineage, + enable_informatica_lineage, + ) -- Monitor task success events -- Extract inlet/outlet information from tasks -- Resolve object IDs using EDC API -- Create lineage links between resolved objects + with DAG("my_dag", ...) as dag: + task_a = SomeSQLOperator(task_id="task_a", sql="SELECT * FROM orders", ...) + task_b = SomeSQLOperator(task_id="task_b", sql="SELECT * FROM customers", ...) -No manual intervention is required. The plugin works transparently with any task that defines inlets and outlets. + # Disable auto-lineage for task_a only + disable_informatica_lineage(task_a) + + # Or disable for all tasks in the DAG + disable_informatica_lineage(dag) + +These helpers have no effect on manually declared inlets and outlets. Supported Inlet/Outlet Formats ------------------------------- +------------------------------- Inlets and outlets can be defined as: +- ``Asset`` objects: ``Asset("edc://object/table_name")`` (recommended — DAG-serialization safe) - String URIs: ``"edc://object/table_name"`` - Dictionary with dataset_uri: ``{"dataset_uri": "edc://object/table_name"}`` -The plugin automatically handles both formats and resolves them to EDC object IDs. - - -Support -------- - -- **Documentation**: See the guides section for detailed usage and configuration -- **Issues**: Report bugs on the Apache Airflow GitHub repository -- **Community**: Join the Airflow community for discussions and support - -License -------- - -Licensed under the Apache License, Version 2.0. See LICENSE file for details. +All formats are resolved via the EDC ``GET /access/2/catalog/data/objects/{id}`` endpoint. diff --git a/providers/informatica/docs/index.rst b/providers/informatica/docs/index.rst index d860002fb8006..456ae17ddd4d4 100644 --- a/providers/informatica/docs/index.rst +++ b/providers/informatica/docs/index.rst @@ -140,6 +140,7 @@ PIP package Version required ``apache-airflow`` ``>=3.0.0`` ``apache-airflow-providers-common-compat`` ``>=1.12.0`` ``apache-airflow-providers-http`` ``>=4.13.2`` +``sqlglot`` ``>=30.0.0`` ========================================== ================== Cross provider package dependencies @@ -159,6 +160,7 @@ You can install such cross-provider dependencies when installing from PyPI. For Dependent package Extra ================================================================================================================== ================= `apache-airflow-providers-common-compat `_ ``common.compat`` +`apache-airflow-providers-common-sql `_ ``common.sql`` `apache-airflow-providers-http `_ ``http`` ================================================================================================================== ================= diff --git a/providers/informatica/provider.yaml b/providers/informatica/provider.yaml index 28202169f166e..735bfa6792299 100644 --- a/providers/informatica/provider.yaml +++ b/providers/informatica/provider.yaml @@ -76,3 +76,32 @@ config: example: "informatica_edc_default" default: "informatica_edc_default" version_added: ~ + auto_lineage_enabled: + description: | + Enable automatic SQL lineage detection for operators with a ``sql`` attribute. + When set to ``true``, the provider parses each task's SQL at pre-execution time, + resolves the referenced tables against the Informatica catalog, and creates lineage + links on task success. Set to ``false`` to rely solely on manually declared inlets + and outlets. + type: boolean + example: ~ + default: "True" + version_added: ~ + disabled_for_operators: + description: | + Exclude specific operator classes from Informatica lineage tracking by providing + a semicolon-separated list of fully-qualified class names. Operators listed here + will have neither manual nor automatic lineage processed. + type: string + example: >- + airflow.providers.standard.operators.bash.BashOperator; + airflow.providers.standard.operators.python.PythonOperator + default: "" + version_added: ~ + request_timeout: + description: | + Timeout in seconds for HTTP requests made to the Informatica EDC REST API. + type: integer + example: "30" + default: "30" + version_added: ~ diff --git a/providers/informatica/pyproject.toml b/providers/informatica/pyproject.toml index 36c7d0fb406a7..ddc370d214a0f 100644 --- a/providers/informatica/pyproject.toml +++ b/providers/informatica/pyproject.toml @@ -61,7 +61,8 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=3.0.0", "apache-airflow-providers-common-compat>=1.12.0", - "apache-airflow-providers-http>=4.13.2" + "apache-airflow-providers-http>=4.13.2", + "sqlglot>=30.0.0" ] # The optional dependencies should be modified in place in the generated file @@ -70,6 +71,9 @@ dependencies = [ "common.compat" = [ "apache-airflow-providers-common-compat" ] +"common.sql" = [ + "apache-airflow-providers-common-sql" +] [dependency-groups] dev = [ @@ -77,8 +81,10 @@ dev = [ "apache-airflow-task-sdk", "apache-airflow-devel-common", "apache-airflow-providers-common-compat", + "apache-airflow-providers-common-sql", "apache-airflow-providers-http", # Additional devel dependencies (do not remove this line and add extra development dependencies) + "sqlglot>=30.0.0", "uuid6>=2024.7.10" ] diff --git a/providers/informatica/src/airflow/providers/informatica/conf.py b/providers/informatica/src/airflow/providers/informatica/conf.py new file mode 100644 index 0000000000000..456e252afe957 --- /dev/null +++ b/providers/informatica/src/airflow/providers/informatica/conf.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os + +# Disable caching inside tests so config can be freely mocked per test. +if os.getenv("PYTEST_VERSION"): + + def _no_cache(func): + return func + + cache = _no_cache +else: + from functools import lru_cache + + cache = lru_cache() + +from airflow.providers.common.compat.sdk import conf + +_CONFIG_SECTION = "informatica" + + +@cache +def disabled_operators() -> set[str]: + """Return FQCNs listed in ``[informatica] disabled_for_operators``.""" + option = conf.get(_CONFIG_SECTION, "disabled_for_operators", fallback="") + return set(op.strip() for op in option.split(";") if op.strip()) + + +@cache +def auto_lineage_enabled() -> bool: + """Return True when ``[informatica] auto_lineage_enabled`` is set.""" + return conf.getboolean(_CONFIG_SECTION, "auto_lineage_enabled", fallback=True) + + +@cache +def listener_disabled() -> bool: + """Return True when ``[informatica] listener_disabled`` is set.""" + return conf.getboolean(_CONFIG_SECTION, "listener_disabled", fallback=False) + + +def is_operator_disabled(operator: object) -> bool: + """Return True when the operator's fully-qualified class name is disabled.""" + op_class = type(operator) + fqcn = f"{op_class.__module__}.{op_class.__name__}" + return fqcn in disabled_operators() diff --git a/providers/informatica/src/airflow/providers/informatica/example_dags/__init__.py b/providers/informatica/src/airflow/providers/informatica/example_dags/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/informatica/src/airflow/providers/informatica/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/informatica/src/airflow/providers/informatica/example_dags/example_informatica_lineage.py b/providers/informatica/src/airflow/providers/informatica/example_dags/example_informatica_lineage.py new file mode 100644 index 0000000000000..83388a824e2d2 --- /dev/null +++ b/providers/informatica/src/airflow/providers/informatica/example_dags/example_informatica_lineage.py @@ -0,0 +1,353 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example DAG for Informatica EDC lineage. + +This DAG demonstrates automatic and manual lineage reporting to Informatica +EDC using the ``apache-airflow-providers-informatica`` provider. + +Prerequisites +------------- +1. Start the local dev stack:: + + cd providers/informatica/dev + docker-compose up -d + +2. Create the Airflow connections (UI → Admin → Connections): + + **informatica_edc_default** + - Connection Type : HTTP + - Host : localhost (or ``informatica_sim`` inside Breeze) + - Port : 8082 + - Schema : http + + **postgres_lineage** + - Connection Type : Postgres + - Host : localhost (or ``informatica-lineage-postgres`` inside Breeze) + - Port : 55433 (host mapping, 5432 inside Breeze network) + - Schema : lineage_demo + - Login : airflow + - Password : airflow + +3. Enable auto-lineage in ``airflow.cfg`` (or via env var):: + + [informatica] + auto_lineage_enabled = True + +What this DAG demonstrates +--------------------------- +* **Auto-lineage (task: build_order_summary)** + A multi-join SELECT → INSERT. The provider parses the SQL with ``sqlglot``, + extracts ``orders``, ``customers``, ``products`` as inlets and + ``order_summary`` as the outlet, then reports the relationship to EDC. + +* **Manual lineage (task: compute_customer_ltv)** + Inlets / outlets are set explicitly on the task via ``task.inlets`` and + ``task.outlets``. Auto-detection is skipped when manual lineage is present. + +* **Lineage disabled per-task (task: truncate_staging)** + ``disable_informatica_lineage(truncate_staging)`` is called after task + creation, so the provider skips EDC reporting entirely for this task. + +* **Lineage disabled via operator class (task: log_run_metadata)** + The operator's fully-qualified class name is added to + ``[informatica] disabled_for_operators`` in the example config block + (commented out here; uncomment to test). + +* **Strict pre-execution validation (task: compute_customer_ltv)** + ``pre_execute=validate_informatica_lineage`` is set on the operator so + that the task fails **before** ``execute()`` if any inlet/outlet URI + cannot be resolved in the EDC catalog. Without ``pre_execute``, the + listener logs a warning but does not block execution. + +* **Generic transfer with explicit target table (task: build_customer_segment_snapshot_generic)** + A complex ``CTE``-based ``SELECT`` is executed by ``GenericTransfer`` and loaded + into ``destination_table='customer_segment_snapshot'``. + +Inspecting results +------------------ +After a successful DAG run, check the simulator's in-memory lineage store:: + + curl http://localhost:8082/lineage +""" + +from __future__ import annotations + +import logging +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.common.sql.operators.generic_transfer import GenericTransfer +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator +from airflow.providers.informatica.lineage import disable_informatica_lineage +from airflow.providers.informatica.lineage.validation import validate_informatica_lineage +from airflow.sdk import Asset + +task_logger = logging.getLogger("airflow.task") +task_logger.setLevel(logging.DEBUG) + +# --------------------------------------------------------------------------- +# Airflow connection IDs used by all tasks +# --------------------------------------------------------------------------- + +_PG_CONN = "postgres_lineage" + +# --------------------------------------------------------------------------- +# DAG definition +# --------------------------------------------------------------------------- + +with DAG( + dag_id="example_informatica_lineage", + schedule=None, + start_date=datetime(2024, 1, 1), + catchup=False, + tags=["example", "informatica", "lineage"], + doc_md=__doc__, +) as dag: + # ------------------------------------------------------------------ + # Task 1 — AUTO-LINEAGE + # sqlglot detects: + # inlets : orders, customers, products + # outlet : order_summary + # ------------------------------------------------------------------ + build_order_summary = SQLExecuteQueryOperator( + task_id="build_order_summary", + conn_id=_PG_CONN, + sql=""" + INSERT INTO order_summary ( + order_id, + customer_email, + product_name, + category, + quantity, + total_amount, + order_date, + country + ) + SELECT + o.order_id, + c.email AS customer_email, + p.product_name, + p.category, + o.quantity, + o.quantity * o.unit_price AS total_amount, + o.order_date, + c.country + FROM orders o + JOIN customers c ON c.customer_id = o.customer_id + JOIN products p ON p.product_id = o.product_id + WHERE o.status = 'completed' + ON CONFLICT (order_id) DO UPDATE + SET processed_at = NOW() + """, + ) + + # ------------------------------------------------------------------ + # Task 2 — MANUAL LINEAGE + STRICT PRE-EXECUTE VALIDATION + # Inlets/outlets are declared explicitly; auto-detection is skipped. + # pre_execute=validate_informatica_lineage fails the task *before* + # execute() if any URI cannot be resolved in the EDC catalog. + # ------------------------------------------------------------------ + compute_customer_ltv = SQLExecuteQueryOperator( + task_id="compute_customer_ltv", + conn_id=_PG_CONN, + pre_execute=validate_informatica_lineage, + sql=""" + INSERT INTO customer_ltv ( + customer_id, + email, + full_name, + country, + total_orders, + total_spent, + avg_order_value, + first_order_date, + last_order_date + ) + SELECT + c.customer_id, + c.email, + c.first_name || ' ' || c.last_name AS full_name, + c.country, + COUNT(o.order_id) AS total_orders, + SUM(o.quantity * o.unit_price) AS total_spent, + AVG(o.quantity * o.unit_price) AS avg_order_value, + MIN(o.order_date) AS first_order_date, + MAX(o.order_date) AS last_order_date + FROM customers c + LEFT JOIN orders o ON o.customer_id = c.customer_id + GROUP BY c.customer_id, c.email, c.first_name, c.last_name, c.country + ON CONFLICT (customer_id) DO UPDATE + SET total_orders = EXCLUDED.total_orders, + total_spent = EXCLUDED.total_spent, + avg_order_value = EXCLUDED.avg_order_value, + first_order_date = EXCLUDED.first_order_date, + last_order_date = EXCLUDED.last_order_date, + calculated_at = NOW() + """, + # Declare inlets/outlets manually — EDC URI format: edc://object/ + inlets=[ + Asset("TEST_PSTGRS://mydb/public/customers"), + Asset("TEST_PSTGRS://mydb/public/orders"), + ], + outlets=[ + Asset("TEST_PSTGRS://mydb/public/customer_ltv"), + ], + ) + + # ------------------------------------------------------------------ + # Task 3 — GENERIC TRANSFER (COMPLEX CTE + EXPLICIT TARGET TABLE) + # Auto-lineage reads source tables from sql and uses destination_table + # as explicit outlet target for GenericTransfer. + # ------------------------------------------------------------------ + build_customer_segment_snapshot_generic = GenericTransfer( + task_id="build_customer_segment_snapshot_generic", + source_conn_id=_PG_CONN, + destination_conn_id=_PG_CONN, + destination_table="customer_segment_snapshot", + preoperator=""" + CREATE TABLE IF NOT EXISTS customer_segment_snapshot ( + customer_id BIGINT, + email TEXT, + country TEXT, + total_orders BIGINT, + total_revenue NUMERIC(14, 2), + avg_order_value NUMERIC(14, 2), + last_order_date DATE, + top_category TEXT, + segment TEXT + ); + TRUNCATE TABLE customer_segment_snapshot; + """, + sql=""" + WITH order_base AS ( + SELECT + o.customer_id, + o.order_id, + o.order_date, + o.quantity, + o.unit_price, + (o.quantity * o.unit_price) AS gross_amount, + p.category, + p.product_name + FROM orders o + JOIN products p ON p.product_id = o.product_id + WHERE o.status = 'completed' + ), + customer_rollup AS ( + SELECT + ob.customer_id, + COUNT(DISTINCT ob.order_id) AS total_orders, + SUM(ob.gross_amount) AS total_revenue, + AVG(ob.gross_amount) AS avg_order_value, + MAX(ob.order_date) AS last_order_date + FROM order_base ob + GROUP BY ob.customer_id + ), + category_mix AS ( + SELECT + ob.customer_id, + ob.category, + SUM(ob.gross_amount) AS category_revenue, + ROW_NUMBER() OVER ( + PARTITION BY ob.customer_id + ORDER BY SUM(ob.gross_amount) DESC, ob.category + ) AS category_rank + FROM order_base ob + GROUP BY ob.customer_id, ob.category + ), + top_category AS ( + SELECT + cm.customer_id, + cm.category AS top_category + FROM category_mix cm + WHERE cm.category_rank = 1 + ), + segmented AS ( + SELECT + c.customer_id, + c.email, + c.country, + cr.total_orders, + cr.total_revenue, + cr.avg_order_value, + cr.last_order_date, + tc.top_category, + NTILE(4) OVER (ORDER BY cr.total_revenue DESC NULLS LAST) AS revenue_quartile + FROM customer_rollup cr + JOIN customers c ON c.customer_id = cr.customer_id + LEFT JOIN top_category tc ON tc.customer_id = cr.customer_id + ) + SELECT + customer_id, + email, + country, + total_orders, + total_revenue, + avg_order_value, + last_order_date, + top_category, + CASE + WHEN revenue_quartile = 1 THEN 'platinum' + WHEN revenue_quartile = 2 THEN 'gold' + WHEN revenue_quartile = 3 THEN 'silver' + ELSE 'bronze' + END AS segment + FROM segmented + WHERE total_orders >= 1 + """, + ) + + # ------------------------------------------------------------------ + # Task 4 — LINEAGE DISABLED PER-TASK + # This task touches staging data; we don't want it reported to EDC. + # disable_informatica_lineage() sets a special Param that the listener + # checks at runtime. + # ------------------------------------------------------------------ + truncate_staging = SQLExecuteQueryOperator( + task_id="truncate_staging", + conn_id=_PG_CONN, + sql="TRUNCATE TABLE order_summary", + ) + # Disable EDC lineage for this task only + disable_informatica_lineage(truncate_staging) + + # ------------------------------------------------------------------ + # Task 5 — SIMPLE READ-ONLY QUERY (auto-lineage: inlets only) + # No INSERT/CREATE/MERGE → no outlet will be detected automatically. + # ------------------------------------------------------------------ + verify_counts = SQLExecuteQueryOperator( + task_id="verify_counts", + conn_id=_PG_CONN, + sql=""" + SELECT + (SELECT COUNT(*) FROM order_summary) AS summary_rows, + (SELECT COUNT(*) FROM customer_ltv) AS ltv_rows, + (SELECT COUNT(*) FROM customer_segment_snapshot) AS segment_rows + """, + ) + + # ------------------------------------------------------------------ + # Pipeline order + # ------------------------------------------------------------------ + ( + truncate_staging + >> build_order_summary + >> compute_customer_ltv + >> build_customer_segment_snapshot_generic + >> verify_counts + ) diff --git a/providers/informatica/src/airflow/providers/informatica/extractors/informatica.py b/providers/informatica/src/airflow/providers/informatica/extractors/informatica.py index f2debbba70d19..4bb2a42dc7815 100644 --- a/providers/informatica/src/airflow/providers/informatica/extractors/informatica.py +++ b/providers/informatica/src/airflow/providers/informatica/extractors/informatica.py @@ -29,36 +29,74 @@ class InformaticaLineageExtractor(LoggingMixin): """Extracts lineage information from Informatica EDC and converts to Airflow Assets.""" def __init__(self, edc_hook: InformaticaEDCHook) -> None: - """ - Initialize InformaticaLineageExtractor. - - Args: - edc_hook (InformaticaEDCHook): Hook for Informatica EDC API connection. - """ super().__init__() self.edc_hook = edc_hook def get_object(self, object_id: str) -> dict[str, Any]: - """ - Return Informatica catalog object by id via EDC hook. - - Args: - object_id (str): Informatica object id. - - Returns: - dict[str, Any]: Informatica catalog object. - """ + """Return Informatica catalog object by id via EDC hook.""" return self.edc_hook.get_object(object_id) def create_lineage_link(self, source_object_id: str, target_object_id: str) -> dict[str, Any]: + """Create a lineage link between source and target objects via EDC hook.""" + return self.edc_hook.create_lineage_link(source_object_id, target_object_id) + + def find_object_id( + self, + catalog_name: str, + database_name: str, + table_name: str, + ) -> str | None: """ - Create a lineage link between source and target objects via EDC hook. + Find Informatica catalog object id by catalog, database, and table name. - Args: - source_object_id (str): Source Informatica object id. - target_object_id (str): Target Informatica object id. + Resolves the EDC object identifier by searching for the table by name, + then validating it belongs to the expected schema/database hierarchy. - Returns: - dict[str, Any]: Result of lineage link creation. + :param catalog_name: Name of the catalog (used to disambiguate schemas). + :param database_name: Name of the database / schema. + :param table_name: Name of the table or view. + :return: Informatica catalog object id if found, else None. """ - return self.edc_hook.create_lineage_link(source_object_id, target_object_id) + table_obj = None + + tables = self.edc_hook.search_table(table_name) + if not tables.get("hits"): + self.log.debug("No EDC objects found for table %r", table_name) + return None + + for table in tables["hits"]: + if len(tables["hits"]) == 1: + table_obj = self.edc_hook.get_object(table["id"]) + break + + schemas = self.edc_hook.search_schema(database_name) + for schema in schemas.get("hits", []): + if not str(table["id"]).startswith(schema["id"]): + continue + + if len(schemas["hits"]) == 1: + table_obj = self.edc_hook.get_object(table["id"]) + break + + databases = self.edc_hook.search_database(catalog_name) + for database in databases.get("hits", []): + if str(schema["id"]).startswith(database["id"]): + table_obj = self.edc_hook.get_object(table["id"]) + break + + if table_obj: + break + + if table_obj: + break + + if table_obj: + return table_obj["id"] + + self.log.debug( + "Could not resolve EDC object for catalog=%r, database=%r, table=%r", + catalog_name, + database_name, + table_name, + ) + return None diff --git a/providers/informatica/src/airflow/providers/informatica/get_provider_info.py b/providers/informatica/src/airflow/providers/informatica/get_provider_info.py index 661f03814185d..d3c2926b641b3 100644 --- a/providers/informatica/src/airflow/providers/informatica/get_provider_info.py +++ b/providers/informatica/src/airflow/providers/informatica/get_provider_info.py @@ -68,6 +68,27 @@ def get_provider_info(): "default": "informatica_edc_default", "version_added": None, }, + "auto_lineage_enabled": { + "description": "Enable automatic SQL lineage detection for operators with a ``sql`` attribute.\nWhen set to ``true``, the provider parses each task's SQL at pre-execution time,\nresolves the referenced tables against the Informatica catalog, and creates lineage\nlinks on task success. Set to ``false`` to rely solely on manually declared inlets\nand outlets.\n", + "type": "boolean", + "example": None, + "default": "True", + "version_added": None, + }, + "disabled_for_operators": { + "description": "Exclude specific operator classes from Informatica lineage tracking by providing\na semicolon-separated list of fully-qualified class names. Operators listed here\nwill have neither manual nor automatic lineage processed.\n", + "type": "string", + "example": "airflow.providers.standard.operators.bash.BashOperator; airflow.providers.standard.operators.python.PythonOperator", + "default": "", + "version_added": None, + }, + "request_timeout": { + "description": "Timeout in seconds for HTTP requests made to the Informatica EDC REST API.\n", + "type": "integer", + "example": "30", + "default": "30", + "version_added": None, + }, }, } }, diff --git a/providers/informatica/src/airflow/providers/informatica/hooks/edc.py b/providers/informatica/src/airflow/providers/informatica/hooks/edc.py index c6cac88f7591d..44f50031f76f9 100644 --- a/providers/informatica/src/airflow/providers/informatica/hooks/edc.py +++ b/providers/informatica/src/airflow/providers/informatica/hooks/edc.py @@ -22,6 +22,7 @@ from collections.abc import Mapping, MutableMapping from dataclasses import dataclass from typing import TYPE_CHECKING, Any +from urllib.parse import urlencode from requests.exceptions import RequestException @@ -171,17 +172,8 @@ def _request( f"Informatica EDC request to {endpoint} returned {response.status_code}: {message}" ) - def _encode_id(self, object_id, tilde=False): - """ - Encode an ID to be safe. Return String. - - Parameters - ---------- - object_id : String - ID of object - tilde : Boolean, optional (default=False) - Whether to encode with a tilde or percent sign. - """ + def _encode_id(self, object_id: str, tilde: bool = False) -> str: + """Encode an object ID for safe use in EDC URLs using tilde-prefix or percent encoding.""" if ":___" in object_id: object_id = object_id.replace(":___", "://") @@ -190,12 +182,14 @@ def _encode_id(self, object_id, tilde=False): id_lst = list(object_id) idx = 0 - while regex.search(object_id, idx) is not None: - idx = regex.search(object_id, idx).span()[1] + match = regex.search(object_id, idx) + while match is not None: + idx = match.span()[1] if tilde: id_lst[idx - 1] = "~" + str(bytes(id_lst[idx - 1], "utf-8").hex()) + "~" else: id_lst[idx - 1] = "%" + str(bytes(id_lst[idx - 1], "utf-8").hex()) + match = regex.search(object_id, idx) return "".join(id_lst) @@ -209,6 +203,55 @@ def get_object(self, object_id: str, include_ref_objects: bool = False) -> dict[ response = self._request("GET", url) return response.json() + def _search(self, **fq_parts: str) -> dict: + """Execute a catalog data search with the given ``fq`` filter parts.""" + params: list[tuple[str, str]] = [ + ("defaultFacets", "true"), + ("disableSemanticSearch", "false"), + ("enableLegacySearch", "false"), + ("facet", "false"), + ("fl", "core.name"), + ("highlight", "false"), + ("includeRefObjects", "false"), + ] + for key, value in fq_parts.items(): + params.append(("fq", f"{key}:{value}")) + + query_string = urlencode(params) + response = self._request("GET", f"/access/2/catalog/data/search?{query_string}") + return response.json() + + def search_database(self, database_name: str) -> dict: + """Search for a relational Database or DatabaseServer object by name.""" + result = self._search( + **{"core.classType": "com.infa.ldm.relational.Database", "core.name": database_name} + ) + if result.get("hits"): + return result + return self._search( + **{"core.classType": "com.infa.ldm.relational.DatabaseServer", "core.name": database_name} + ) + + def search_schema(self, schema_name: str) -> dict: + """Search for a relational Schema or DatabaseSchema object by name.""" + result = self._search( + **{"core.classType": "com.infa.ldm.relational.Schema", "core.name": schema_name} + ) + if result.get("hits"): + return result + return self._search( + **{"core.classType": "com.infa.ldm.relational.DatabaseSchema", "core.name": schema_name} + ) + + def search_table(self, table_name: str) -> dict: + """Search for a relational Table or View object by name.""" + return self._search( + **{ + "core.classType": "com.infa.ldm.relational.Table OR core.classType:com.infa.ldm.relational.View", + "core.name": table_name, + } + ) + def create_lineage_link(self, source_object_id: str, target_object_id: str) -> dict[str, Any]: """Create a lineage relationship between source and target objects.""" if source_object_id == target_object_id: diff --git a/providers/informatica/src/airflow/providers/informatica/lineage/__init__.py b/providers/informatica/src/airflow/providers/informatica/lineage/__init__.py new file mode 100644 index 0000000000000..cb1880573c6db --- /dev/null +++ b/providers/informatica/src/airflow/providers/informatica/lineage/__init__.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.informatica.lineage.selective import ( + disable_informatica_lineage, + enable_informatica_lineage, +) + +__all__ = [ + "disable_informatica_lineage", + "enable_informatica_lineage", +] diff --git a/providers/informatica/src/airflow/providers/informatica/lineage/resolver.py b/providers/informatica/src/airflow/providers/informatica/lineage/resolver.py new file mode 100644 index 0000000000000..36e6cb1ff47e6 --- /dev/null +++ b/providers/informatica/src/airflow/providers/informatica/lineage/resolver.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import Any + +from airflow.providers.informatica.lineage.sql_parser import TableRef, parse_sql_tables + +log = logging.getLogger(__name__) + +try: + from airflow.providers.common.sql.operators.sql import BaseSQLOperator as _BaseSQLOperator + + _HAS_BASE_SQL_OPERATOR = True +except ImportError: + _BaseSQLOperator = None # type: ignore[assignment, misc] + _HAS_BASE_SQL_OPERATOR = False + +# Operator attribute names scanned in order to locate a connection ID. +# conn_id_field (BaseSQLOperator) is tried first; this list is the fallback. +_CONN_ID_ATTRS: tuple[str, ...] = ( + "conn_id", + "source_conn_id", + "mysql_conn_id", + "postgres_conn_id", + "mssql_conn_id", + "oracle_conn_id", + "sqlite_conn_id", + "snowflake_conn_id", + "databricks_conn_id", + "exasol_conn_id", + "hiveserver2_conn_id", +) + +# Keyword fragments found in a conn_id string mapped to sqlglot dialect names. +_CONN_TYPE_TO_DIALECT: dict[str, str] = { + "postgres": "postgres", + "redshift": "redshift", + "mysql": "mysql", + "mssql": "tsql", + "snowflake": "snowflake", + "bigquery": "bigquery", + "databricks": "databricks", + "sqlite": "sqlite", + "oracle": "oracle", + "trino": "trino", + "presto": "presto", + "hive": "hive", + "spark": "spark", +} + +# Operator attribute names checked as explicit write-target table when SQL +# parsing yields no targets (e.g. GenericTransfer, HiveToMySqlOperator). +_TARGET_TABLE_ATTRS: tuple[str, ...] = ( + "destination_table", + "mysql_table", + "hive_table", + "target_table", +) + + +class BaseLineageResolver(ABC): + """Base class for operator lineage resolvers.""" + + @abstractmethod + def resolve(self, task: Any) -> tuple[list[TableRef], list[TableRef]] | None: + """Return ``(source_refs, target_refs)`` or ``None`` if the resolver does not apply.""" + + +class SQLLineageResolver(BaseLineageResolver): + """ + Resolves lineage for any operator that exposes a ``sql`` attribute. + + Detection is tiered: + + - Tier 1: operators inheriting from ``BaseSQLOperator`` — ``conn_id_field`` + points to the right connection attribute. + - Tier 2: operators with a ``sql`` attribute but no ``BaseSQLOperator`` + base (e.g. ``GenericTransfer``, ``BaseSQLToGCSOperator``) — dialect is + inferred from the first recognizable connection ID string found. + + Returns ``None`` when there is no SQL, when Jinja templates are detected, + or when parsing produces no table references. + """ + + def resolve(self, task: Any) -> tuple[list[TableRef], list[TableRef]] | None: + sql = getattr(task, "sql", None) + if not sql: + return None + dialect = _infer_dialect(task) + default_database: str | None = getattr(task, "database", None) + sources, targets = parse_sql_tables(sql, dialect=dialect) + if not targets: + for attr in _TARGET_TABLE_ATTRS: + table_name = getattr(task, attr, None) + if table_name and isinstance(table_name, str): + targets.append(TableRef(table=table_name)) + break + + if not sources and not targets: + return None + + # Fill in default_database for refs that have none set + if default_database: + sources = [TableRef(t.table, t.schema, t.database or default_database) for t in sources] + targets = [TableRef(t.table, t.schema, t.database or default_database) for t in targets] + + return sources, targets + + +def _infer_dialect(task: Any) -> str | None: + conn_id_field = getattr(task, "conn_id_field", None) + if conn_id_field: + conn_id = getattr(task, conn_id_field, None) + if conn_id and isinstance(conn_id, str): + result = _dialect_from_conn_id_str(conn_id) + if result: + return result + + for attr in _CONN_ID_ATTRS: + conn_id = getattr(task, attr, None) + if conn_id and isinstance(conn_id, str): + result = _dialect_from_conn_id_str(conn_id) + if result: + return result + + return None + + +def _dialect_from_conn_id_str(conn_id: str) -> str | None: + conn_id_lower = conn_id.lower() + for keyword, dialect in _CONN_TYPE_TO_DIALECT.items(): + if keyword in conn_id_lower: + return dialect + return None + + +_SQL_RESOLVER = SQLLineageResolver() + + +def get_resolver(task: Any) -> BaseLineageResolver | None: + """Return a resolver for *task*, or ``None`` when no resolver applies.""" + if getattr(task, "sql", None): + return _SQL_RESOLVER + return None diff --git a/providers/informatica/src/airflow/providers/informatica/lineage/selective.py b/providers/informatica/src/airflow/providers/informatica/lineage/selective.py new file mode 100644 index 0000000000000..06ae87d4abcc9 --- /dev/null +++ b/providers/informatica/src/airflow/providers/informatica/lineage/selective.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.common.compat.sdk import DAG, Param, XComArg + +_DISABLE_PARAM_NAME = "_informatica_auto_lineage" +_ENABLE_PARAM = Param(True, const=True) +_DISABLE_PARAM = Param(False, const=False) + + +def disable_informatica_lineage(obj): + """ + Mark a task (or all tasks in a DAG) to skip automatic lineage detection. + + Has no effect on manually declared inlets and outlets. + """ + if isinstance(obj, XComArg): + disable_informatica_lineage(obj.operator) + return obj + if isinstance(obj, DAG): + for task in obj.task_dict.values(): + disable_informatica_lineage(task) + else: + obj.params[_DISABLE_PARAM_NAME] = _DISABLE_PARAM + return obj + + +def enable_informatica_lineage(obj): + """ + Re-enable automatic lineage detection on a task (or all tasks in a DAG). + + Only needed to reverse a prior ``disable_informatica_lineage`` call. + """ + if isinstance(obj, XComArg): + enable_informatica_lineage(obj.operator) + return obj + if isinstance(obj, DAG): + for task in obj.task_dict.values(): + enable_informatica_lineage(task) + else: + obj.params[_DISABLE_PARAM_NAME] = _ENABLE_PARAM + return obj + + +def is_task_auto_lineage_disabled(task) -> bool: + """Return True when auto lineage has been explicitly disabled on this task.""" + params = getattr(task, "params", None) + if params is None: + return False + value = params.get(_DISABLE_PARAM_NAME) + if isinstance(value, Param): + return value.resolve(suppress_exception=True) is False + return value is False diff --git a/providers/informatica/src/airflow/providers/informatica/lineage/sql_parser.py b/providers/informatica/src/airflow/providers/informatica/lineage/sql_parser.py new file mode 100644 index 0000000000000..f0e55e091bc85 --- /dev/null +++ b/providers/informatica/src/airflow/providers/informatica/lineage/sql_parser.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from dataclasses import dataclass + +import sqlglot +import sqlglot.expressions as exp + +log = logging.getLogger(__name__) + +_JINJA_MARKERS = ("{{", "{%") + + +@dataclass +class TableRef: + """Represents a parsed table reference with optional schema and database qualifiers.""" + + table: str + schema: str | None = None + database: str | None = None + + +def parse_sql_tables( + sql: str | list[str], + dialect: str | None = None, +) -> tuple[list[TableRef], list[TableRef]]: + """ + Parse SQL and return ``(source_tables, target_tables)``. + + Source tables are those read by FROM/JOIN clauses. + Target tables are those written by INSERT INTO, CREATE TABLE AS, or MERGE INTO. + Returns empty lists when SQL cannot be parsed instead of raising. + """ + statements = [sql] if isinstance(sql, str) else sql + sources: list[TableRef] = [] + targets: list[TableRef] = [] + + for idx, stmt in enumerate(statements): + if not isinstance(stmt, str) or not stmt.strip(): + continue + if any(marker in stmt for marker in _JINJA_MARKERS): + log.debug( + "SQL statement %d contains unrendered Jinja templates; skipping lineage extraction.", idx + ) + continue + try: + for parsed in sqlglot.parse(stmt, dialect=dialect, error_level=sqlglot.ErrorLevel.WARN): + if parsed is None: + continue + if not isinstance(parsed, exp.Expression): + continue + stmt_sources, stmt_targets = _extract_tables(parsed) + sources.extend(stmt_sources) + targets.extend(stmt_targets) + except Exception: + log.debug("Failed to parse SQL statement %d", idx, exc_info=True) + + return _dedup(sources), _dedup(targets) + + +def _extract_tables(parsed: exp.Expression) -> tuple[list[TableRef], list[TableRef]]: + cte_names: set[str] = set() + with_node = parsed.find(exp.With) + if with_node: + for cte in with_node.find_all(exp.CTE): + if cte.alias: + cte_names.add(cte.alias.lower()) + + target_node_id: int | None = None + if isinstance(parsed, (exp.Insert, exp.Create, exp.Merge)): + write_target = _get_write_target(parsed) + if write_target is not None: + target_node_id = id(write_target) + + sources: list[TableRef] = [] + targets: list[TableRef] = [] + + for table in parsed.find_all(exp.Table): + if table.name.lower() in cte_names: + continue + ref = TableRef( + table=table.name, + schema=table.db or None, + database=table.catalog or None, + ) + if id(table) == target_node_id: + targets.append(ref) + else: + sources.append(ref) + + return sources, targets + + +def _get_write_target(node: exp.Expression) -> exp.Table | None: + if isinstance(node, (exp.Insert, exp.Merge)): + candidate = node.this + elif isinstance(node, exp.Create): + candidate = node.this + else: + return None + + # INSERT INTO target(col1, col2) ... is represented as Schema(Table(...), ...) + if isinstance(candidate, exp.Schema): + candidate = candidate.this + + return candidate if isinstance(candidate, exp.Table) else None + + +def _dedup(refs: list[TableRef]) -> list[TableRef]: + seen: set[tuple[str | None, str | None, str]] = set() + result: list[TableRef] = [] + for ref in refs: + key = (ref.database, ref.schema, ref.table) + if key not in seen: + seen.add(key) + result.append(ref) + return result diff --git a/providers/informatica/src/airflow/providers/informatica/lineage/validation.py b/providers/informatica/src/airflow/providers/informatica/lineage/validation.py new file mode 100644 index 0000000000000..3750d1ed7725c --- /dev/null +++ b/providers/informatica/src/airflow/providers/informatica/lineage/validation.py @@ -0,0 +1,298 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Shared lineage validation logic for the Informatica provider. + +This module provides functions that resolve inlet/outlet URIs and table +references against Informatica EDC. The resolution logic is intentionally +separated from the listener so it can also be used as an operator +``pre_execute`` hook — where a raised exception *does* fail the task. + +Usage as ``pre_execute``:: + + from airflow.providers.informatica.lineage.validation import validate_informatica_lineage + + SQLExecuteQueryOperator( + task_id="my_task", + conn_id="my_conn", + sql="INSERT INTO dst SELECT * FROM src", + pre_execute=validate_informatica_lineage, + ) + +When passed as ``pre_execute``, any +:class:`InformaticaLineageResolutionError` propagates through the task +runner and fails the task *before* ``execute()`` is called. + +The listener calls the same functions but wraps them in a ``try/except`` +so resolution errors are logged as warnings instead of failing the task +(listener exceptions are swallowed by the Airflow task runner). +""" + +from __future__ import annotations + +import logging +from typing import Any + +from airflow.providers.informatica.conf import auto_lineage_enabled, is_operator_disabled +from airflow.providers.informatica.extractors import InformaticaLineageExtractor +from airflow.providers.informatica.hooks.edc import InformaticaEDCError +from airflow.providers.informatica.lineage.resolver import get_resolver +from airflow.providers.informatica.lineage.selective import is_task_auto_lineage_disabled + +log = logging.getLogger(__name__) + +# Maximum number of entries kept in the pre-execute result cache. This is a +# safety bound — under normal operation each entry is consumed (popped) by the +# listener's ``on_task_instance_running`` hook shortly after it is stored. +# Orphaned entries (e.g. task killed between pre_execute and listener hook) are +# evicted in FIFO order once the limit is reached. +_PRE_EXECUTE_CACHE_MAX = 1024 + +# Cache shared between ``validate_informatica_lineage`` (pre_execute) and the +# listener. Keyed by (dag_id, run_id, task_id, map_index, try_number). +# Access only via ``pop_pre_execute_result`` / ``_store_pre_execute_result``. +_pre_execute_cache: dict[tuple, tuple[list[tuple[str, str]], list[tuple[str, str]]]] = {} + + +def pop_pre_execute_result( + key: tuple, +) -> tuple[list[tuple[str, str]], list[tuple[str, str]]] | None: + """Remove and return a cached pre-execute result, or ``None`` if absent.""" + return _pre_execute_cache.pop(key, None) + + +def _store_pre_execute_result( + key: tuple, + value: tuple[list[tuple[str, str]], list[tuple[str, str]]], +) -> None: + """Store a pre-execute result, evicting oldest entries when the cache is full.""" + if len(_pre_execute_cache) >= _PRE_EXECUTE_CACHE_MAX: + oldest_key = next(iter(_pre_execute_cache)) + _pre_execute_cache.pop(oldest_key, None) + log.debug("Pre-execute cache full, evicted key %s", oldest_key) + _pre_execute_cache[key] = value + + +# Lazy singleton for the default hook used by ``validate_informatica_lineage``. +_default_hook: InformaticaLineageExtractor | None = None + + +def _get_default_hook() -> InformaticaLineageExtractor: + """Return (and lazily create) a shared hook instance for pre-execute validation.""" + global _default_hook + if _default_hook is None: + from airflow.providers.informatica.hooks.edc import InformaticaEDCHook + + _default_hook = InformaticaLineageExtractor(edc_hook=InformaticaEDCHook()) + return _default_hook + + +class InformaticaLineageResolutionError(RuntimeError): + """Raised when an EDC object cannot be resolved for a lineage URI.""" + + +def resolve_uri_to_object_id(hook: InformaticaLineageExtractor, uri: str) -> str: + """ + Resolve an EDC lineage URI to an Informatica catalog object ID. + + Manual lineage entries are treated as concrete object identifiers/uris. + They are validated directly via ``get_object`` instead of being reparsed + and looked up again with ``find_object_id``. + + :raises InformaticaLineageResolutionError: When the URI cannot be resolved. + """ + try: + obj = hook.get_object(uri) + except InformaticaEDCError as exc: + raise InformaticaLineageResolutionError( + f"Failed to resolve EDC object for URI {uri!r}: {exc}" + ) from exc + + object_id = obj.get("id") if isinstance(obj, dict) else None + if not object_id: + raise InformaticaLineageResolutionError( + f"Could not resolve EDC object for URI {uri!r}. " + "Ensure the object exists in the Informatica catalog." + ) + log.debug("Resolved URI %r to EDC object_id=%s", uri, object_id) + return object_id + + +def resolve_uris( + hook: InformaticaLineageExtractor, + items: list, + role: str, + task_id: str, +) -> list[tuple[str, str]]: + """ + Resolve URI items to ``(uri, edc_object_id)`` tuples. + + :raises InformaticaLineageResolutionError: On the first URI that cannot + be resolved. + """ + result: list[tuple[str, str]] = [] + for item in items: + if isinstance(item, dict) and "dataset_uri" in item: + uri = item["dataset_uri"] + elif isinstance(item, str): + uri = item + elif hasattr(item, "uri") and isinstance(item.uri, str): + uri = item.uri + else: + raise InformaticaLineageResolutionError( + f"Invalid {role} entry for task {task_id!r}: expected a URI string, " + f"dict with 'dataset_uri', or an Asset object, got {type(item).__name__!r}." + ) + object_id = resolve_uri_to_object_id(hook, uri) + result.append((uri, object_id)) + return result + + +def resolve_table_refs( + hook: InformaticaLineageExtractor, + refs: list, + task_id: str, +) -> list[tuple[str, str]]: + """ + Resolve TableRef objects to ``(table_label, edc_object_id)`` tuples. + + Calls ``find_object_id`` which searches EDC by table name and narrows by + schema/database when multiple results are returned. + + :raises InformaticaLineageResolutionError: On the first unresolvable table. + """ + result: list[tuple[str, str]] = [] + for ref in refs: + catalog = ref.database or "" + schema = ref.schema or "" + table = ref.table + try: + object_id = hook.find_object_id(catalog, schema, table) + except InformaticaEDCError as e: + raise InformaticaLineageResolutionError( + f"EDC error while resolving table {table!r} " + f"(catalog={catalog!r}, schema={schema!r}) for task {task_id!r}: {e}" + ) from e + if not object_id: + raise InformaticaLineageResolutionError( + f"Could not resolve EDC object for table {table!r} " + f"(catalog={catalog!r}, schema={schema!r}) in task {task_id!r}. " + "Ensure the table is registered in the Informatica catalog." + ) + result.append((f"{catalog}/{schema}/{table}", object_id)) + return result + + +def resolve_informatica_lineage( + task: Any, + task_id: str, + hook: InformaticaLineageExtractor | None = None, +) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]: + """ + Resolve all inlet/outlet URIs or auto-detected tables for *task*. + + :returns: ``(valid_inlets, valid_outlets)`` — each a list of + ``(uri_or_label, edc_object_id)`` tuples. + :raises InformaticaLineageResolutionError: When any URI or table cannot + be resolved in the Informatica catalog. + """ + if hook is None: + hook = _get_default_hook() + + if is_operator_disabled(task): + log.debug("Lineage disabled for operator %s - skipping", type(task).__name__) + return [], [] + + inlets = getattr(task, "inlets", []) + outlets = getattr(task, "outlets", []) + + if inlets or outlets: + valid_inlets = resolve_uris(hook, inlets, "inlet", task_id) + valid_outlets = resolve_uris(hook, outlets, "outlet", task_id) + elif auto_lineage_enabled() and not is_task_auto_lineage_disabled(task): + resolver = get_resolver(task) + if resolver is not None: + result = resolver.resolve(task) + if result is not None: + source_refs, target_refs = result + log.info( + "Auto-lineage detected %d source(s) and %d target(s) for task %s", + len(source_refs), + len(target_refs), + task_id, + ) + valid_inlets = resolve_table_refs(hook, source_refs, task_id) + valid_outlets = resolve_table_refs(hook, target_refs, task_id) + else: + valid_inlets, valid_outlets = [], [] + else: + valid_inlets, valid_outlets = [], [] + else: + valid_inlets, valid_outlets = [], [] + + return valid_inlets, valid_outlets + + +def validate_informatica_lineage(context: Any) -> None: + """ + Pre-execute hook that validates Informatica lineage before task execution. + + Pass this function as ``pre_execute`` on any operator to fail the task + when inlet/outlet URIs cannot be resolved in the Informatica catalog:: + + SQLExecuteQueryOperator( + task_id="my_task", + conn_id="my_conn", + sql="INSERT INTO dst SELECT * FROM src", + pre_execute=validate_informatica_lineage, + ) + + Resolved pairs are cached so the listener's ``on_task_instance_success`` + can create lineage links without making a second round of EDC calls. + + :raises InformaticaLineageResolutionError: When any URI or table cannot + be resolved. + """ + ti = context.get("task_instance") or context.get("ti") + if ti is None: + return + + task = getattr(ti, "task", None) + if task is None: + return + + task_id = getattr(ti, "task_id", "unknown") + + valid_inlets, valid_outlets = resolve_informatica_lineage(task, task_id) + + # Store in cache so the listener can pick it up. + dag_id = getattr(ti, "dag_id", getattr(task, "dag_id", None)) + cache_key = ( + dag_id, + getattr(ti, "run_id", None), + task_id, + getattr(ti, "map_index", -1), + getattr(ti, "try_number", None), + ) + _store_pre_execute_result(cache_key, (valid_inlets, valid_outlets)) + + log.info( + "Pre-execute validation complete for task %s: %d inlet(s), %d outlet(s) resolved", + task_id, + len(valid_inlets), + len(valid_outlets), + ) diff --git a/providers/informatica/src/airflow/providers/informatica/plugins/informatica.py b/providers/informatica/src/airflow/providers/informatica/plugins/informatica.py index 863fa0a800cf3..e02826870995a 100644 --- a/providers/informatica/src/airflow/providers/informatica/plugins/informatica.py +++ b/providers/informatica/src/airflow/providers/informatica/plugins/informatica.py @@ -17,9 +17,9 @@ from __future__ import annotations from airflow.plugins_manager import AirflowPlugin -from airflow.providers.common.compat.sdk import conf +from airflow.providers.informatica.conf import listener_disabled -is_disabled = conf.getboolean("informatica", "listener_disabled", fallback=False) +is_disabled = listener_disabled() # Conditional imports - only load expensive dependencies when plugin is enabled if not is_disabled: from airflow.providers.common.compat.sdk import HookLineageReader diff --git a/providers/informatica/src/airflow/providers/informatica/plugins/listener.py b/providers/informatica/src/airflow/providers/informatica/plugins/listener.py index 6e3dd1a5c2fc3..783226573b55c 100644 --- a/providers/informatica/src/airflow/providers/informatica/plugins/listener.py +++ b/providers/informatica/src/airflow/providers/informatica/plugins/listener.py @@ -21,7 +21,13 @@ from airflow.listeners import hookimpl from airflow.providers.informatica.extractors import InformaticaLineageExtractor -from airflow.providers.informatica.hooks.edc import InformaticaEDCHook +from airflow.providers.informatica.hooks.edc import InformaticaEDCError +from airflow.providers.informatica.lineage.validation import ( + InformaticaLineageResolutionError, + pop_pre_execute_result, + resolve_informatica_lineage, + resolve_uri_to_object_id, +) if TYPE_CHECKING: from airflow.models import TaskInstance @@ -29,108 +35,138 @@ _informatica_listener: InformaticaListener | None = None +# Re-export for backward compatibility. +__all__ = [ + "InformaticaLineageResolutionError", + "InformaticaListener", + "get_informatica_listener", +] + +# Backward-compatible alias. +_resolve_uri_to_object_id = resolve_uri_to_object_id + class InformaticaListener: """Informatica listener sends events on task instance state changes to Informatica EDC for lineage tracking.""" def __init__(self): - self._executor = None self.log = logging.getLogger(__name__) + from airflow.providers.informatica.hooks.edc import InformaticaEDCHook + self.hook = InformaticaLineageExtractor(edc_hook=InformaticaEDCHook()) - # self.extractor_manager = ExtractorManager() + # Cache: _cache_key(ti) -> (valid_inlets, valid_outlets) + # Populated by on_task_instance_running (pre-validation), consumed by + # on_task_instance_success and cleared by on_task_instance_failed. + self._resolved_cache: dict[tuple, tuple[list[tuple[str, str]], list[tuple[str, str]]]] = {} + + @staticmethod + def _cache_key(task_instance: TaskInstance) -> tuple: + dag_id = getattr(task_instance, "dag_id", None) + if dag_id is None: + task = getattr(task_instance, "task", None) + dag_id = getattr(task, "dag_id", None) + return ( + dag_id, + getattr(task_instance, "run_id", None), + task_instance.task_id, + getattr(task_instance, "map_index", -1), + getattr(task_instance, "try_number", None), + ) @hookimpl def on_task_instance_success( self, previous_state: TaskInstanceState, task_instance: TaskInstance, *args, **kwargs ): - self._handle_lineage(task_instance, state="success") + key = self._cache_key(task_instance) + cached = self._resolved_cache.pop(key, None) + if cached is None: + # Running hook was skipped (e.g. operator disabled) - nothing to do. + return + valid_inlets, valid_outlets = cached + self._create_lineage_links(valid_inlets, valid_outlets, task_instance.task_id) @hookimpl def on_task_instance_failed( self, previous_state: TaskInstanceState, task_instance: TaskInstance, *args, **kwargs ): - self._handle_lineage(task_instance, state="failed") + # Clean up cache entry so stale entries do not accumulate. + self._resolved_cache.pop(self._cache_key(task_instance), None) @hookimpl def on_task_instance_running( self, previous_state: TaskInstanceState, task_instance: TaskInstance, *args, **kwargs ): - self._handle_lineage(task_instance, state="running") - - def _handle_lineage(self, task_instance: TaskInstance, state: str): """ - Handle lineage resolution for inlets and outlets. + Best-effort pre-resolution of inlet/outlet URIs before task execution. - For each inlet and outlet, resolve Informatica EDC object IDs using getObject. - If valid, collect and create lineage links between all valid inlets and outlets. + Resolved pairs are cached so ``on_task_instance_success`` can create + lineage links without making a second round of EDC calls. + + .. note:: + + Exceptions raised inside listener hooks are caught and logged by + the Airflow task runner — they do **not** fail the task. To fail + the task when lineage URIs cannot be resolved, use + :func:`~airflow.providers.informatica.lineage.validation.validate_informatica_lineage` + as a ``pre_execute`` hook on the operator instead. """ task = getattr(task_instance, "task", None) if not task: - self.log.debug("No task found for TaskInstance %s", task_instance) return - inlets = getattr(task, "inlets", getattr(task_instance, "inlets", [])) - outlets = getattr(task, "outlets", getattr(task_instance, "outlets", [])) - valid_inlets = [] # List of tuples: (uri, object_id) - valid_outlets = [] - - self.log.info("[InformaticaLineageListener] Task: %s State: %s", task_instance.task_id, state) + task_id = task_instance.task_id + + # If validate_informatica_lineage was already called via pre_execute, + # reuse its cached result instead of calling EDC again. + key = self._cache_key(task_instance) + pre_exec_result = pop_pre_execute_result(key) + if pre_exec_result is not None: + self._resolved_cache[key] = pre_exec_result + self.log.debug( + "Reusing pre_execute cache for task %s: %d inlet(s), %d outlet(s)", + task_id, + len(pre_exec_result[0]), + len(pre_exec_result[1]), + ) + return - if state != "success": - self.log.info("[InformaticaLineageListener] Skipping lineage handling for state: %s", state) + self.log.debug("Pre-resolving lineage for task %s", task_id) + + try: + valid_inlets, valid_outlets = resolve_informatica_lineage(task, task_id, hook=self.hook) + except (InformaticaLineageResolutionError, InformaticaEDCError): + self.log.warning( + "Could not pre-resolve lineage for task %s - " + "lineage links will not be created on success. " + "To fail the task on resolution errors, use " + "pre_execute=validate_informatica_lineage on the operator.", + task_id, + exc_info=True, + ) return - for inlet in inlets: - inlet_uri = None - if isinstance(inlet, dict) and "dataset_uri" in inlet: - inlet_uri = inlet["dataset_uri"] - elif isinstance(inlet, str): - inlet_uri = inlet - else: - self.log.error("Inlet is not a string or dict with 'dataset_uri': %s", inlet) - continue - self.log.info("[InformaticaLineageListener] Inlet: %s and type: %s", inlet_uri, type(inlet)) - try: - obj = self.hook.get_object(inlet_uri) - if obj and "id" in obj and obj["id"]: - valid_inlets.append((inlet_uri, obj["id"])) - except Exception as e: - self.log.exception("Failed to resolve inlet %s: %s", inlet_uri, e) - - for outlet in outlets: - outlet_uri = None - if isinstance(outlet, dict) and "dataset_uri" in outlet: - outlet_uri = outlet["dataset_uri"] - elif isinstance(outlet, str): - outlet_uri = outlet - else: - self.log.error("Outlet is not a string or dict with 'dataset_uri': %s", outlet) - continue - self.log.info("[InformaticaLineageListener] Outlet: %s", outlet_uri) - try: - obj = self.hook.get_object(outlet_uri) - if obj and "id" in obj and obj["id"]: - valid_outlets.append((outlet_uri, obj["id"])) - except Exception as e: - self.log.warning("Failed to resolve outlet %s: %s", outlet_uri, e) - - # Create lineage links between all valid inlet and outlet object IDs - for inlet_uri, inlet_id in valid_inlets: - for outlet_uri, outlet_id in valid_outlets: + self._resolved_cache[key] = (valid_inlets, valid_outlets) + self.log.info( + "Pre-resolution complete for task %s: %d inlet(s), %d outlet(s) resolved", + task_id, + len(valid_inlets), + len(valid_outlets), + ) + + def _create_lineage_links( + self, + valid_inlets: list[tuple[str, str]], + valid_outlets: list[tuple[str, str]], + task_id: str, + ) -> None: + """Create EDC lineage links between all resolved inlet and outlet object IDs.""" + for _inlet_uri, inlet_id in valid_inlets: + for _outlet_uri, outlet_id in valid_outlets: try: - self.log.info( - "[InformaticaLineageListener] Creating lineage link: %s (%s) -> %s (%s)", - inlet_uri, - inlet_id, - outlet_uri, - outlet_id, - ) - result = self.hook.create_lineage_link(inlet_id, outlet_id) - self.log.info("Lineage link created: %s -> %s | Result: %s", inlet_id, outlet_id, result) - except Exception as e: - self.log.exception( - "Failed to create lineage link from %s to %s: %s", inlet_id, outlet_id, e - ) + self.hook.create_lineage_link(inlet_id, outlet_id) + self.log.info("Lineage link created: %s -> %s", inlet_id, outlet_id) + except InformaticaEDCError: + self.log.exception("Failed to create lineage link from %s to %s", inlet_id, outlet_id) def get_informatica_listener() -> InformaticaListener: diff --git a/providers/informatica/tests/unit/informatica/hooks/test_edc.py b/providers/informatica/tests/unit/informatica/hooks/test_edc.py index 9fd0da906c4ab..c5f128c797c59 100644 --- a/providers/informatica/tests/unit/informatica/hooks/test_edc.py +++ b/providers/informatica/tests/unit/informatica/hooks/test_edc.py @@ -20,7 +20,10 @@ from unittest.mock import MagicMock, patch import pytest +from requests import Response, Session +from requests.exceptions import SSLError +from airflow.models import Connection from airflow.providers.informatica.hooks.edc import InformaticaEDCError, InformaticaEDCHook @@ -32,7 +35,7 @@ def hook(): @patch("airflow.providers.informatica.hooks.edc.HttpHook.get_connection") def test_config_property_and_build_connection_config(mock_get_connection, hook): """Test config property and _build_connection_config method.""" - mock_conn = MagicMock() + mock_conn = MagicMock(spec=Connection) mock_conn.host = "testhost" mock_conn.schema = "https" mock_conn.port = 443 @@ -61,7 +64,7 @@ def test_config_property_and_build_connection_config(mock_get_connection, hook): @patch("airflow.providers.informatica.hooks.edc.HttpHook.get_conn") def test_get_conn_headers_and_verify(mock_get_conn, mock_get_connection, hook): """Test get_conn sets headers and verify.""" - mock_conn = MagicMock() + mock_conn = MagicMock(spec=Connection) mock_conn.host = "testhost" mock_conn.schema = "https" mock_conn.port = 443 @@ -69,7 +72,7 @@ def test_get_conn_headers_and_verify(mock_get_conn, mock_get_connection, hook): mock_conn.password = "pass" mock_conn.extra_dejson = {"verify_ssl": True} mock_get_connection.return_value = mock_conn - mock_session = MagicMock() + mock_session = MagicMock(spec=Session) mock_session.headers = {} mock_get_conn.return_value = mock_session session = hook.get_conn() @@ -91,8 +94,8 @@ def test_build_url(hook): @patch("airflow.providers.informatica.hooks.edc.InformaticaEDCHook.get_conn") def test_request_success_and_error(mock_get_conn, hook): """Test _request method for success and error cases.""" - mock_session = MagicMock() - mock_response = MagicMock() + mock_session = MagicMock(spec=Session) + mock_response = MagicMock(spec=Response) mock_response.ok = True mock_response.status_code = 200 mock_response.text = "" @@ -114,6 +117,20 @@ def test_request_success_and_error(mock_get_conn, hook): hook._request("GET", "endpoint") +@patch("airflow.providers.informatica.hooks.edc.InformaticaEDCHook.get_conn") +def test_request_ssl_error_raises_without_http_fallback(mock_get_conn, hook): + """Test _request raises on SSL errors and does not retry over HTTP.""" + mock_session = MagicMock(spec=Session) + mock_session.request.side_effect = SSLError("[SSL: WRONG_VERSION_NUMBER] wrong version number") + mock_get_conn.return_value = mock_session + hook._config = MagicMock(base_url="https://informatica_sim:8082", request_timeout=10) + + with pytest.raises(InformaticaEDCError, match="Failed to call Informatica EDC endpoint /access"): + hook._request("GET", "/access") + + assert mock_session.request.call_count == 1 + + def test_encode_id(hook): """Test _encode_id method for tilde and percent encoding.""" # ID with unsafe chars @@ -154,3 +171,67 @@ def test_create_lineage_link(mock_request, hook): def test_close_session(hook): """Test close_session does nothing (no-op).""" assert hook.close_session() is None + + +@patch("airflow.providers.informatica.hooks.edc.InformaticaEDCHook._search") +def test_search_database_prefers_database_hits(mock_search, hook): + """Database hit should be returned without fallback search.""" + mock_search.return_value = {"hits": [{"id": "db_1"}]} + + result = hook.search_database("my_db") + + assert result == {"hits": [{"id": "db_1"}]} + mock_search.assert_called_once_with( + **{"core.classType": "com.infa.ldm.relational.Database", "core.name": "my_db"} + ) + + +@patch("airflow.providers.informatica.hooks.edc.InformaticaEDCHook._search") +def test_search_database_falls_back_to_database_server(mock_search, hook): + """DatabaseServer fallback should be used when Database search has no hits.""" + mock_search.side_effect = [{"hits": []}, {"hits": [{"id": "db_server_1"}]}] + + result = hook.search_database("my_db") + + assert result == {"hits": [{"id": "db_server_1"}]} + assert mock_search.call_count == 2 + + +@patch("airflow.providers.informatica.hooks.edc.InformaticaEDCHook._search") +def test_search_schema_prefers_schema_hits(mock_search, hook): + """Schema hit should be returned without fallback search.""" + mock_search.return_value = {"hits": [{"id": "schema_1"}]} + + result = hook.search_schema("public") + + assert result == {"hits": [{"id": "schema_1"}]} + mock_search.assert_called_once_with( + **{"core.classType": "com.infa.ldm.relational.Schema", "core.name": "public"} + ) + + +@patch("airflow.providers.informatica.hooks.edc.InformaticaEDCHook._search") +def test_search_schema_falls_back_to_database_schema(mock_search, hook): + """DatabaseSchema fallback should be used when Schema search has no hits.""" + mock_search.side_effect = [{"hits": []}, {"hits": [{"id": "db_schema_1"}]}] + + result = hook.search_schema("public") + + assert result == {"hits": [{"id": "db_schema_1"}]} + assert mock_search.call_count == 2 + + +@patch("airflow.providers.informatica.hooks.edc.InformaticaEDCHook._search") +def test_search_table_queries_table_and_view(mock_search, hook): + """Table search should query both Table and View class types.""" + mock_search.return_value = {"hits": [{"id": "table_or_view_1"}]} + + result = hook.search_table("orders") + + assert result == {"hits": [{"id": "table_or_view_1"}]} + mock_search.assert_called_once_with( + **{ + "core.classType": "com.infa.ldm.relational.Table OR core.classType:com.infa.ldm.relational.View", + "core.name": "orders", + } + ) diff --git a/providers/informatica/tests/unit/informatica/lineage/__init__.py b/providers/informatica/tests/unit/informatica/lineage/__init__.py new file mode 100644 index 0000000000000..5966d6b1d5261 --- /dev/null +++ b/providers/informatica/tests/unit/informatica/lineage/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/providers/informatica/tests/unit/informatica/lineage/test_resolver.py b/providers/informatica/tests/unit/informatica/lineage/test_resolver.py new file mode 100644 index 0000000000000..f30ed70ad6d2a --- /dev/null +++ b/providers/informatica/tests/unit/informatica/lineage/test_resolver.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.informatica.lineage.resolver import ( + SQLLineageResolver, + _dialect_from_conn_id_str, + _infer_dialect, + get_resolver, +) + + +class _FakeTask: + """Minimal task stand-in with configurable attributes.""" + + def __init__(self, **kwargs): + for key, val in kwargs.items(): + setattr(self, key, val) + + +class TestGetResolver: + def test_returns_sql_resolver_when_sql_present(self): + task = _FakeTask(sql="SELECT 1") + assert get_resolver(task) is not None + + def test_returns_none_when_no_sql(self): + task = _FakeTask() + assert get_resolver(task) is None + + def test_returns_none_when_sql_is_empty_string(self): + task = _FakeTask(sql="") + assert get_resolver(task) is None + + +class TestDialectInference: + @pytest.mark.parametrize( + ("conn_id", "expected"), + [ + ("postgres_default", "postgres"), + ("my_snowflake_conn", "snowflake"), + ("mysql_prod", "mysql"), + ("redshift_warehouse", "redshift"), + ("databricks_cluster", "databricks"), + ("hive_metastore", "hive"), + ("unknown_connection", None), + ], + ) + def test_dialect_from_conn_id_str(self, conn_id, expected): + assert _dialect_from_conn_id_str(conn_id) == expected + + def test_infer_uses_conn_id_field_first(self): + task = _FakeTask(conn_id_field="snowflake_conn_id", snowflake_conn_id="my_snowflake") + assert _infer_dialect(task) == "snowflake" + + def test_infer_falls_back_to_conn_id(self): + task = _FakeTask(conn_id="postgres_default") + assert _infer_dialect(task) == "postgres" + + def test_infer_falls_back_to_source_conn_id(self): + task = _FakeTask(source_conn_id="mysql_source") + assert _infer_dialect(task) == "mysql" + + def test_infer_returns_none_when_no_conn_id(self): + task = _FakeTask(sql="SELECT 1") + assert _infer_dialect(task) is None + + +class TestSQLLineageResolver: + def test_simple_select_produces_inlets_only(self): + task = _FakeTask(sql="SELECT * FROM orders", conn_id="postgres_default") + resolver = SQLLineageResolver() + result = resolver.resolve(task) + assert result is not None + inlets, outlets = result + assert len(inlets) == 1 + assert inlets[0].table == "orders" + assert inlets[0].schema is None + assert outlets == [] + + def test_insert_into_select_produces_inlets_and_outlets(self): + task = _FakeTask( + sql="INSERT INTO summary SELECT region, amount FROM sales", + conn_id="postgres_default", + ) + resolver = SQLLineageResolver() + result = resolver.resolve(task) + assert result is not None + inlets, outlets = result + assert any(t.table == "sales" for t in inlets) + assert any(t.table == "summary" for t in outlets) + + def test_destination_table_attr_supplements_missing_targets(self): + task = _FakeTask( + sql="SELECT * FROM source_table", + conn_id="mysql_default", + destination_table="dest_table", + ) + resolver = SQLLineageResolver() + result = resolver.resolve(task) + assert result is not None + _, outlets = result + assert any(t.table == "dest_table" for t in outlets) + + def test_no_sql_returns_none(self): + task = _FakeTask() + resolver = SQLLineageResolver() + assert resolver.resolve(task) is None + + def test_jinja_sql_returns_none(self): + task = _FakeTask(sql="SELECT * FROM {{ params.table }}", conn_id="postgres_default") + resolver = SQLLineageResolver() + assert resolver.resolve(task) is None + + def test_schema_qualified_table_in_uri(self): + task = _FakeTask(sql="SELECT * FROM public.users", conn_id="postgres_default") + resolver = SQLLineageResolver() + result = resolver.resolve(task) + assert result is not None + inlets, _ = result + ref = next(t for t in inlets if t.table == "users") + assert ref.schema == "public" + + def test_database_applied_as_default(self): + task = _FakeTask( + sql="SELECT * FROM orders", + conn_id="postgres_default", + database="mydb", + ) + resolver = SQLLineageResolver() + result = resolver.resolve(task) + assert result is not None + inlets, _ = result + ref = next(t for t in inlets if t.table == "orders") + assert ref.database == "mydb" + + def test_no_tables_in_sql_returns_none(self): + task = _FakeTask(sql="SELECT 1 + 1", conn_id="postgres_default") + resolver = SQLLineageResolver() + assert resolver.resolve(task) is None diff --git a/providers/informatica/tests/unit/informatica/lineage/test_selective.py b/providers/informatica/tests/unit/informatica/lineage/test_selective.py new file mode 100644 index 0000000000000..fe40e69a9a04f --- /dev/null +++ b/providers/informatica/tests/unit/informatica/lineage/test_selective.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow.providers.common.compat.sdk import DAG +from airflow.providers.informatica.lineage.selective import ( + disable_informatica_lineage, + enable_informatica_lineage, + is_task_auto_lineage_disabled, +) +from airflow.providers.standard.operators.empty import EmptyOperator + + +class _TaskWithPlainParams: + def __init__(self): + self.params = {} + + +def test_disable_and_enable_for_plain_dict_params(): + task = _TaskWithPlainParams() + + disable_informatica_lineage(task) + assert is_task_auto_lineage_disabled(task) is True + + enable_informatica_lineage(task) + assert is_task_auto_lineage_disabled(task) is False + + +def test_is_task_auto_lineage_disabled_returns_false_when_params_missing_or_none(): + class _NoParamsTask: + pass + + class _NoneParamsTask: + params = None + + assert is_task_auto_lineage_disabled(_NoParamsTask()) is False + assert is_task_auto_lineage_disabled(_NoneParamsTask()) is False + + +def test_disable_and_enable_work_for_dag_tasks(): + dag = DAG(dag_id="test_selective_dag", start_date=datetime(2024, 1, 1), schedule=None) + task_a = EmptyOperator(task_id="task_a", dag=dag) + task_b = EmptyOperator(task_id="task_b", dag=dag) + + disable_informatica_lineage(dag) + assert is_task_auto_lineage_disabled(task_a) is True + assert is_task_auto_lineage_disabled(task_b) is True + + enable_informatica_lineage(dag) + assert is_task_auto_lineage_disabled(task_a) is False + assert is_task_auto_lineage_disabled(task_b) is False + + +def test_disable_with_xcomarg_returns_same_xcomarg_and_marks_underlying_task(): + dag = DAG(dag_id="test_selective_xcomarg", start_date=datetime(2024, 1, 1), schedule=None) + task = EmptyOperator(task_id="task", dag=dag) + + xcomarg = task.output + result = disable_informatica_lineage(xcomarg) + + assert result is xcomarg + assert is_task_auto_lineage_disabled(task) is True diff --git a/providers/informatica/tests/unit/informatica/lineage/test_sql_parser.py b/providers/informatica/tests/unit/informatica/lineage/test_sql_parser.py new file mode 100644 index 0000000000000..fe4ff0853a268 --- /dev/null +++ b/providers/informatica/tests/unit/informatica/lineage/test_sql_parser.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.informatica.lineage.sql_parser import parse_sql_tables + + +class TestParseSqlTables: + def test_simple_select(self): + sources, targets = parse_sql_tables("SELECT * FROM orders") + assert len(sources) == 1 + assert sources[0].table == "orders" + assert targets == [] + + def test_select_with_join(self): + sql = "SELECT o.id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id" + sources, targets = parse_sql_tables(sql) + table_names = {r.table for r in sources} + assert "orders" in table_names + assert "customers" in table_names + assert targets == [] + + def test_insert_into_select(self): + sql = "INSERT INTO summary SELECT region, SUM(amount) FROM sales GROUP BY region" + sources, targets = parse_sql_tables(sql) + assert len(targets) == 1 + assert targets[0].table == "summary" + assert any(r.table == "sales" for r in sources) + + def test_insert_into_select_multi_source(self): + sql = "INSERT INTO dest SELECT * FROM src1 JOIN src2 ON src1.id = src2.id" + sources, targets = parse_sql_tables(sql) + assert targets[0].table == "dest" + source_names = {r.table for r in sources} + assert "src1" in source_names + assert "src2" in source_names + + def test_insert_with_column_list_and_on_conflict(self): + sql = """ + INSERT INTO order_summary (order_id, total_amount) + SELECT o.order_id, o.quantity * o.unit_price + FROM orders o + ON CONFLICT (order_id) DO UPDATE + SET total_amount = EXCLUDED.total_amount + """ + sources, targets = parse_sql_tables(sql, dialect="postgres") + target_names = {r.table for r in targets} + source_names = {r.table for r in sources} + assert "order_summary" in target_names + assert "orders" in source_names + assert "order_summary" not in source_names + + def test_create_table_as_select(self): + sql = "CREATE TABLE report AS SELECT * FROM raw_data" + sources, targets = parse_sql_tables(sql) + assert targets[0].table == "report" + assert sources[0].table == "raw_data" + + def test_schema_qualified_table(self): + sources, _ = parse_sql_tables("SELECT * FROM public.users") + assert sources[0].table == "users" + assert sources[0].schema == "public" + + def test_database_qualified_table(self): + sources, _ = parse_sql_tables("SELECT * FROM mydb.public.users", dialect="postgres") + assert sources[0].table == "users" + assert sources[0].schema == "public" + assert sources[0].database == "mydb" + + def test_cte_name_excluded_from_sources(self): + sql = """ + WITH recent AS (SELECT * FROM events WHERE ts > '2024-01-01') + SELECT * FROM recent + """ + sources, _ = parse_sql_tables(sql) + table_names = {r.table for r in sources} + assert "recent" not in table_names + assert "events" in table_names + + def test_list_of_statements(self): + statements = [ + "INSERT INTO a SELECT * FROM b", + "INSERT INTO c SELECT * FROM d", + ] + sources, targets = parse_sql_tables(statements) + target_names = {r.table for r in targets} + source_names = {r.table for r in sources} + assert "a" in target_names + assert "c" in target_names + assert "b" in source_names + assert "d" in source_names + + def test_dedup_repeated_table(self): + sql = "SELECT * FROM t JOIN t ON t.a = t.b" + sources, _ = parse_sql_tables(sql) + assert len(sources) == 1 + + def test_empty_sql_returns_empty(self): + sources, targets = parse_sql_tables("") + assert sources == [] + assert targets == [] + + def test_whitespace_only_sql_returns_empty(self): + sources, targets = parse_sql_tables(" ") + assert sources == [] + assert targets == [] + + def test_jinja_template_skipped(self): + sources, targets = parse_sql_tables("SELECT * FROM {{ params.table }}") + assert sources == [] + assert targets == [] + + def test_invalid_sql_returns_empty_not_raises(self): + sources, targets = parse_sql_tables("THIS IS NOT SQL @@@@") + assert isinstance(sources, list) + assert isinstance(targets, list) + + @pytest.mark.parametrize("dialect", ["postgres", "mysql", "snowflake"]) + def test_dialect_variants_parse_without_error(self, dialect): + sql = "SELECT id, name FROM customers WHERE active = 1" + sources, _ = parse_sql_tables(sql, dialect=dialect) + assert sources[0].table == "customers" + + def test_merge_into(self): + sql = """ + MERGE INTO target t + USING source s ON t.id = s.id + WHEN MATCHED THEN UPDATE SET t.val = s.val + """ + sources, targets = parse_sql_tables(sql, dialect="tsql") + assert any(r.table == "target" for r in targets) + assert any(r.table == "source" for r in sources) diff --git a/providers/informatica/tests/unit/informatica/lineage/test_validation.py b/providers/informatica/tests/unit/informatica/lineage/test_validation.py new file mode 100644 index 0000000000000..d27bf46da11f6 --- /dev/null +++ b/providers/informatica/tests/unit/informatica/lineage/test_validation.py @@ -0,0 +1,314 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.providers.informatica.extractors.informatica import InformaticaLineageExtractor +from airflow.providers.informatica.hooks.edc import InformaticaEDCError +from airflow.providers.informatica.lineage.sql_parser import TableRef +from airflow.providers.informatica.lineage.validation import ( + _PRE_EXECUTE_CACHE_MAX, + InformaticaLineageResolutionError, + _pre_execute_cache, + _store_pre_execute_result, + pop_pre_execute_result, + resolve_informatica_lineage, + resolve_table_refs, + resolve_uri_to_object_id, + resolve_uris, + validate_informatica_lineage, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _DummyTask: + def __init__(self, inlets=None, outlets=None, dag_id="dag"): + self.inlets = inlets or [] + self.outlets = outlets or [] + self.params = {} + self.dag_id = dag_id + + +class _DummyTaskInstance: + def __init__(self, task, task_id="dummy", run_id="run1", map_index=-1, dag_id=None, try_number=None): + self.task = task + self.task_id = task_id + self.run_id = run_id + self.map_index = map_index + self.dag_id = dag_id if dag_id is not None else getattr(task, "dag_id", None) + self.try_number = try_number + + +class _SQLTask: + def __init__(self, sql, conn_id="postgres_default"): + self.sql = sql + self.conn_id = conn_id + self.inlets = [] + self.outlets = [] + self.params = {} + + +@pytest.fixture +def hook(): + return MagicMock(spec=InformaticaLineageExtractor) + + +# --------------------------------------------------------------------------- +# resolve_uri_to_object_id +# --------------------------------------------------------------------------- + + +def test_resolve_uri_to_object_id_success(hook): + hook.get_object.return_value = {"id": "EDC://db/schema/tbl"} + result = resolve_uri_to_object_id(hook, "EDC://db/schema/tbl") + assert result == "EDC://db/schema/tbl" + hook.get_object.assert_called_once_with("EDC://db/schema/tbl") + + +def test_resolve_uri_to_object_id_raises_on_edc_error(hook): + hook.get_object.side_effect = InformaticaEDCError("timeout") + with pytest.raises(InformaticaLineageResolutionError, match="Failed to resolve"): + resolve_uri_to_object_id(hook, "bad://uri") + + +def test_resolve_uri_to_object_id_raises_when_no_id(hook): + hook.get_object.return_value = {"name": "tbl"} + with pytest.raises(InformaticaLineageResolutionError, match="Could not resolve"): + resolve_uri_to_object_id(hook, "EDC://db/schema/tbl") + + +# --------------------------------------------------------------------------- +# resolve_uris +# --------------------------------------------------------------------------- + + +def test_resolve_uris_string_items(hook): + hook.get_object.side_effect = lambda uri: {"id": uri} + result = resolve_uris(hook, ["uri1", "uri2"], "inlet", "t1") + assert result == [("uri1", "uri1"), ("uri2", "uri2")] + + +def test_resolve_uris_dict_items(hook): + hook.get_object.side_effect = lambda uri: {"id": uri} + result = resolve_uris(hook, [{"dataset_uri": "u1"}], "outlet", "t1") + assert result == [("u1", "u1")] + + +def test_resolve_uris_asset_object(hook): + hook.get_object.side_effect = lambda uri: {"id": uri} + + class FakeAsset: + uri = "asset://uri" + + result = resolve_uris(hook, [FakeAsset()], "inlet", "t1") + assert result == [("asset://uri", "asset://uri")] + + +def test_resolve_uris_invalid_type_raises(hook): + with pytest.raises(InformaticaLineageResolutionError, match="Invalid inlet entry"): + resolve_uris(hook, [123], "inlet", "t1") + + +def test_resolve_uris_propagates_resolution_error(hook): + hook.get_object.side_effect = InformaticaEDCError("fail") + with pytest.raises(InformaticaLineageResolutionError): + resolve_uris(hook, ["bad"], "inlet", "t1") + + +# --------------------------------------------------------------------------- +# resolve_table_refs +# --------------------------------------------------------------------------- + + +def test_resolve_table_refs_success(hook): + hook.find_object_id.return_value = "EDC://obj" + refs = [TableRef(table="tbl", schema="sch", database="cat")] + result = resolve_table_refs(hook, refs, "t1") + assert result == [("cat/sch/tbl", "EDC://obj")] + + +def test_resolve_table_refs_raises_when_not_found(hook): + hook.find_object_id.return_value = None + refs = [TableRef(table="missing")] + with pytest.raises(InformaticaLineageResolutionError, match="Could not resolve"): + resolve_table_refs(hook, refs, "t1") + + +def test_resolve_table_refs_wraps_edc_error(hook): + hook.find_object_id.side_effect = InformaticaEDCError("conn refused") + refs = [TableRef(table="tbl")] + with pytest.raises(InformaticaLineageResolutionError, match="EDC error"): + resolve_table_refs(hook, refs, "t1") + + +# --------------------------------------------------------------------------- +# resolve_informatica_lineage +# --------------------------------------------------------------------------- + + +@patch("airflow.providers.informatica.lineage.validation.is_operator_disabled", return_value=True) +def test_resolve_lineage_disabled_operator(mock_disabled): + task = _DummyTask(inlets=["in1"]) + result = resolve_informatica_lineage(task, "t1", hook=MagicMock()) + assert result == ([], []) + + +@patch("airflow.providers.informatica.lineage.validation.resolve_uri_to_object_id") +@patch("airflow.providers.informatica.lineage.validation.is_operator_disabled", return_value=False) +def test_resolve_lineage_manual_inlets(mock_disabled, mock_resolve): + mock_resolve.side_effect = lambda hook, uri: uri + hook = MagicMock(spec=InformaticaLineageExtractor) + task = _DummyTask(inlets=["in1"], outlets=["out1"]) + + valid_inlets, valid_outlets = resolve_informatica_lineage(task, "t1", hook=hook) + assert valid_inlets == [("in1", "in1")] + assert valid_outlets == [("out1", "out1")] + + +@patch("airflow.providers.informatica.lineage.validation.is_operator_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.auto_lineage_enabled", return_value=True) +@patch("airflow.providers.informatica.lineage.validation.is_task_auto_lineage_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.get_resolver") +def test_resolve_lineage_auto(mock_get_resolver, mock_task_dis, mock_auto, mock_op_dis): + mock_resolver = MagicMock() + mock_resolver.resolve.return_value = ([TableRef(table="src")], [TableRef(table="dst")]) + mock_get_resolver.return_value = mock_resolver + hook = MagicMock(spec=InformaticaLineageExtractor) + hook.find_object_id.side_effect = lambda c, s, t: f"EDC://{t}" + + task = _SQLTask(sql="INSERT INTO dst SELECT * FROM src") + valid_inlets, valid_outlets = resolve_informatica_lineage(task, "t1", hook=hook) + + assert valid_inlets == [("//src", "EDC://src")] + assert valid_outlets == [("//dst", "EDC://dst")] + + +@patch("airflow.providers.informatica.lineage.validation.is_operator_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.auto_lineage_enabled", return_value=False) +def test_resolve_lineage_no_inlets_no_auto(mock_auto, mock_op_dis): + task = _DummyTask() + result = resolve_informatica_lineage(task, "t1", hook=MagicMock()) + assert result == ([], []) + + +# --------------------------------------------------------------------------- +# validate_informatica_lineage (pre_execute callable) +# --------------------------------------------------------------------------- + + +@patch("airflow.providers.informatica.lineage.validation.resolve_informatica_lineage") +def test_validate_pre_execute_success(mock_resolve): + mock_resolve.return_value = ([("in1", "id_in")], [("out1", "id_out")]) + task = _DummyTask(inlets=["in1"], outlets=["out1"]) + ti = _DummyTaskInstance(task, task_id="t1", run_id="r1", try_number=1) + context = {"task_instance": ti} + + validate_informatica_lineage(context) + + mock_resolve.assert_called_once_with(task, "t1") + key = ("dag", "r1", "t1", -1, 1) + result = pop_pre_execute_result(key) + assert result == ([("in1", "id_in")], [("out1", "id_out")]) + + +@patch("airflow.providers.informatica.lineage.validation.resolve_informatica_lineage") +def test_validate_pre_execute_raises_on_error(mock_resolve): + mock_resolve.side_effect = InformaticaLineageResolutionError("not found") + task = _DummyTask(inlets=["bad"]) + ti = _DummyTaskInstance(task) + context = {"task_instance": ti} + + with pytest.raises(InformaticaLineageResolutionError, match="not found"): + validate_informatica_lineage(context) + + +def test_validate_pre_execute_noop_when_no_ti(): + validate_informatica_lineage({}) + + +def test_validate_pre_execute_noop_when_no_task(): + ti = _DummyTaskInstance(task=None, task_id="t1") + ti.task = None + validate_informatica_lineage({"task_instance": ti}) + + +# --------------------------------------------------------------------------- +# Cache management helpers +# --------------------------------------------------------------------------- + + +def test_pop_pre_execute_result_returns_none_for_missing_key(): + result = pop_pre_execute_result(("no", "such", "key")) + assert result is None + + +@patch("airflow.providers.informatica.lineage.validation.resolve_informatica_lineage") +def test_validate_pre_execute_overwrites_duplicate_key(mock_resolve): + """A second validate call for the same TI overwrites the cached result.""" + task = _DummyTask(inlets=["in1"], outlets=["out1"]) + ti = _DummyTaskInstance(task, task_id="t1", run_id="r1", try_number=1) + context = {"task_instance": ti} + key = ("dag", "r1", "t1", -1, 1) + + mock_resolve.return_value = ([("in1", "id1")], [("out1", "id2")]) + validate_informatica_lineage(context) + + mock_resolve.return_value = ([("in_new", "id3")], [("out_new", "id4")]) + validate_informatica_lineage(context) + + result = pop_pre_execute_result(key) + assert result == ([("in_new", "id3")], [("out_new", "id4")]) + + +def test_store_pre_execute_result_evicts_oldest_when_full(): + """Cache evicts the oldest entry when _PRE_EXECUTE_CACHE_MAX is reached.""" + _pre_execute_cache.clear() + try: + for i in range(_PRE_EXECUTE_CACHE_MAX): + _store_pre_execute_result(("dag", "run", f"task_{i}", -1, 1), ([], [])) + + assert len(_pre_execute_cache) == _PRE_EXECUTE_CACHE_MAX + + # One more should evict task_0 + _store_pre_execute_result(("dag", "run", "task_overflow", -1, 1), ([], [])) + assert len(_pre_execute_cache) == _PRE_EXECUTE_CACHE_MAX + assert ("dag", "run", "task_0", -1, 1) not in _pre_execute_cache + assert ("dag", "run", "task_overflow", -1, 1) in _pre_execute_cache + finally: + _pre_execute_cache.clear() + + +@patch("airflow.providers.informatica.lineage.validation.resolve_informatica_lineage") +def test_validate_pre_execute_uses_ti_key(mock_resolve): + """Context with 'ti' key (alternative to 'task_instance').""" + mock_resolve.return_value = ([], []) + task = _DummyTask() + ti = _DummyTaskInstance(task, task_id="alt", run_id="r2", try_number=2) + context = {"ti": ti} + + validate_informatica_lineage(context) + + mock_resolve.assert_called_once_with(task, "alt") + key = ("dag", "r2", "alt", -1, 2) + assert key in _pre_execute_cache + _pre_execute_cache.pop(key) diff --git a/providers/informatica/tests/unit/informatica/plugins/test_listener.py b/providers/informatica/tests/unit/informatica/plugins/test_listener.py index 46bc02ae20627..3c42e7f090913 100644 --- a/providers/informatica/tests/unit/informatica/plugins/test_listener.py +++ b/providers/informatica/tests/unit/informatica/plugins/test_listener.py @@ -16,95 +16,457 @@ # under the License. from __future__ import annotations -from unittest.mock import MagicMock +import logging +from unittest.mock import MagicMock, patch import pytest +from airflow.providers.informatica.extractors.informatica import InformaticaLineageExtractor from airflow.providers.informatica.hooks.edc import InformaticaEDCError -from airflow.providers.informatica.plugins.listener import InformaticaListener +from airflow.providers.informatica.lineage.sql_parser import TableRef +from airflow.providers.informatica.lineage.validation import ( + InformaticaLineageResolutionError, + _store_pre_execute_result, + pop_pre_execute_result, + resolve_uri_to_object_id, +) +from airflow.providers.informatica.plugins.listener import ( + InformaticaListener, +) class DummyTask: - def __init__(self, inlets=None, outlets=None): + def __init__(self, inlets=None, outlets=None, dag_id="dag"): self.inlets = inlets or [] self.outlets = outlets or [] + self.params = {} + self.dag_id = dag_id class DummyTaskInstance: - def __init__(self, task, task_id="dummy"): + def __init__(self, task, task_id="dummy", run_id="run1", map_index=-1, dag_id=None, try_number=None): self.task = task self.task_id = task_id + self.run_id = run_id + self.map_index = map_index + self.dag_id = dag_id if dag_id is not None else getattr(task, "dag_id", None) + self.try_number = try_number @pytest.fixture def listener(): - informatica_listener = InformaticaListener() - informatica_listener.hook = MagicMock() - informatica_listener.log = MagicMock() - return informatica_listener + lsnr = InformaticaListener() + lsnr.hook = MagicMock(spec=InformaticaLineageExtractor) + lsnr.log = MagicMock(spec=logging.Logger) + return lsnr -def test_handle_lineage_success_str(listener): - listener.hook.get_object.side_effect = lambda x: {"id": x} +# --------------------------------------------------------------------------- +# Manual lineage — running hook pre-validates, success hook creates links +# --------------------------------------------------------------------------- + + +@patch("airflow.providers.informatica.lineage.validation.resolve_uri_to_object_id") +def test_running_then_success_creates_link_str(mock_resolve, listener): + """String URIs: running resolves, success creates links.""" + mock_resolve.side_effect = lambda hook, uri: uri listener.hook.create_lineage_link.return_value = {"metadata": {}} task = DummyTask(inlets=["in1"], outlets=["out1"]) ti = DummyTaskInstance(task) - listener._handle_lineage(ti, state="success") - listener.hook.get_object.assert_any_call("in1") - listener.hook.get_object.assert_any_call("out1") + + listener.on_task_instance_running(previous_state=None, task_instance=ti) + mock_resolve.assert_any_call(listener.hook, "in1") + mock_resolve.assert_any_call(listener.hook, "out1") + listener.hook.create_lineage_link.assert_not_called() # not yet + + listener.on_task_instance_success(previous_state=None, task_instance=ti) listener.hook.create_lineage_link.assert_called_once_with("in1", "out1") -def test_handle_lineage_success_dict(listener): - listener.hook.get_object.side_effect = lambda x: {"id": x} +@patch("airflow.providers.informatica.lineage.validation.resolve_uri_to_object_id") +def test_running_then_success_creates_link_dict(mock_resolve, listener): + """dict(dataset_uri=…) inlets/outlets: running resolves, success creates links.""" + mock_resolve.side_effect = lambda hook, uri: uri listener.hook.create_lineage_link.return_value = {"metadata": {}} task = DummyTask(inlets=[{"dataset_uri": "in1"}], outlets=[{"dataset_uri": "out1"}]) ti = DummyTaskInstance(task) - listener._handle_lineage(ti, state="success") - listener.hook.get_object.assert_any_call("in1") - listener.hook.get_object.assert_any_call("out1") + + listener.on_task_instance_running(previous_state=None, task_instance=ti) + mock_resolve.assert_any_call(listener.hook, "in1") + mock_resolve.assert_any_call(listener.hook, "out1") + + listener.on_task_instance_success(previous_state=None, task_instance=ti) listener.hook.create_lineage_link.assert_called_once_with("in1", "out1") -def test_handle_lineage_skips_missing_objectid(listener): - listener.hook.get_object.return_value = {} +@patch("airflow.providers.informatica.lineage.validation.resolve_uri_to_object_id") +def test_running_logs_warning_when_uri_not_found(mock_resolve, listener): + """If any URI cannot be resolved, running hook logs a warning and does NOT raise.""" + mock_resolve.side_effect = InformaticaLineageResolutionError("not found") task = DummyTask(inlets=["in1"], outlets=["out1"]) ti = DummyTaskInstance(task) - listener._handle_lineage(ti, state="success") + + # Must not raise — listener is best-effort. + listener.on_task_instance_running(previous_state=None, task_instance=ti) + + listener.log.warning.assert_called() + assert "Could not pre-resolve lineage" in str(listener.log.warning.call_args) + # No cache entry → success hook is a no-op. + assert listener._cache_key(ti) not in listener._resolved_cache listener.hook.create_lineage_link.assert_not_called() -def test_handle_lineage_edc_error_on_inlet(listener): - listener.hook.get_object.side_effect = [InformaticaEDCError("fail"), {"id": "out1"}] +def test_running_logs_warning_on_invalid_inlet_type(listener): + """Non-string, non-dict inlet logs a warning; listener does NOT raise.""" + task = DummyTask(inlets=[123], outlets=["out1"]) + ti = DummyTaskInstance(task) + + # Must not raise — listener is best-effort. + listener.on_task_instance_running(previous_state=None, task_instance=ti) + + listener.log.warning.assert_called() + assert listener._cache_key(ti) not in listener._resolved_cache + listener.hook.create_lineage_link.assert_not_called() + + +def test_success_without_prior_running_is_noop(listener): + """If running hook was never called (no cache entry), success is a no-op.""" task = DummyTask(inlets=["in1"], outlets=["out1"]) ti = DummyTaskInstance(task) - listener._handle_lineage(ti, state="success") + + listener.on_task_instance_success(previous_state=None, task_instance=ti) listener.hook.create_lineage_link.assert_not_called() -def test_handle_lineage_non_success_state(listener): +def test_failed_clears_cache(listener): + """on_task_instance_failed removes the cache entry to avoid stale state.""" task = DummyTask(inlets=["in1"], outlets=["out1"]) ti = DummyTaskInstance(task) - listener._handle_lineage(ti, state="failed") - listener.hook.get_object.assert_not_called() + key = listener._cache_key(ti) + listener._resolved_cache[key] = ([("in1", "id_in")], [("out1", "id_out")]) + + listener.on_task_instance_failed(previous_state=None, task_instance=ti) + + assert key not in listener._resolved_cache listener.hook.create_lineage_link.assert_not_called() -def test_handle_lineage_link_creation_error_logs(listener): - listener.hook.get_object.side_effect = lambda x: {"id": x} - listener.hook.create_lineage_link.side_effect = Exception("fail link") - listener.log = MagicMock() +def test_manual_uri_resolution_uses_get_object_directly(listener): + """Manual lineage URIs should be validated via get_object, not find_object_id.""" + listener.hook.get_object.return_value = {"id": "TEST_PSTGRS://mydb/public/customers"} + + object_id = resolve_uri_to_object_id(listener.hook, "TEST_PSTGRS://mydb/public/customers") + + assert object_id == "TEST_PSTGRS://mydb/public/customers" + listener.hook.get_object.assert_called_once_with("TEST_PSTGRS://mydb/public/customers") + + +def test_manual_uri_resolution_raises_when_get_object_returns_no_id(listener): + """Manual lineage URI resolution fails when get_object returns no object id.""" + listener.hook.get_object.return_value = {"name": "customers"} + + with pytest.raises(InformaticaLineageResolutionError, match="Could not resolve EDC object for URI"): + resolve_uri_to_object_id(listener.hook, "TEST_PSTGRS://mydb/public/customers") + + +@patch("airflow.providers.informatica.lineage.validation.resolve_uri_to_object_id") +def test_link_creation_error_is_logged_not_raised(mock_resolve, listener): + """Errors during link creation are logged but do not propagate.""" + mock_resolve.side_effect = lambda hook, uri: uri + listener.hook.create_lineage_link.side_effect = InformaticaEDCError("network fail") task = DummyTask(inlets=["in1"], outlets=["out1"]) ti = DummyTaskInstance(task) - listener._handle_lineage(ti, state="success") + + listener.on_task_instance_running(previous_state=None, task_instance=ti) + listener.on_task_instance_success(previous_state=None, task_instance=ti) # must not raise calls = listener.log.exception.call_args_list assert any("Failed to create lineage link from" in str(call) for call, *_ in calls) -def test_handle_lineage_inlet_outlet_type_error(listener): - task = DummyTask(inlets=[123], outlets=[None]) +@patch("airflow.providers.informatica.lineage.validation.resolve_uri_to_object_id") +def test_cache_consumed_exactly_once(mock_resolve, listener): + """Cache entry is popped by success; a second success call is a no-op.""" + mock_resolve.side_effect = lambda hook, uri: uri + listener.hook.create_lineage_link.return_value = {} + task = DummyTask(inlets=["in1"], outlets=["out1"]) ti = DummyTaskInstance(task) - listener._handle_lineage(ti, state="success") - listener.hook.get_object.assert_not_called() + + listener.on_task_instance_running(previous_state=None, task_instance=ti) + listener.on_task_instance_success(previous_state=None, task_instance=ti) + listener.hook.create_lineage_link.reset_mock() + + listener.on_task_instance_success(previous_state=None, task_instance=ti) # no cache listener.hook.create_lineage_link.assert_not_called() + + +@patch("airflow.providers.informatica.lineage.validation.resolve_uri_to_object_id") +def test_cache_key_includes_dag_id_and_try_number(mock_resolve, listener): + """Cache key should keep entries distinct across DAGs and retries.""" + mock_resolve.side_effect = lambda hook, uri: uri + listener.hook.create_lineage_link.return_value = {} + + dag_a_task = DummyTask(inlets=["in_a"], outlets=["out_a"], dag_id="dag_a") + dag_b_task = DummyTask(inlets=["in_b"], outlets=["out_b"], dag_id="dag_b") + + ti_dag_a_try_1 = DummyTaskInstance( + dag_a_task, + task_id="shared_task", + run_id="shared_run", + map_index=0, + try_number=1, + ) + ti_dag_b_try_1 = DummyTaskInstance( + dag_b_task, + task_id="shared_task", + run_id="shared_run", + map_index=0, + try_number=1, + ) + ti_dag_a_try_2 = DummyTaskInstance( + dag_a_task, + task_id="shared_task", + run_id="shared_run", + map_index=0, + try_number=2, + ) + + listener.on_task_instance_running(previous_state=None, task_instance=ti_dag_a_try_1) + listener.on_task_instance_running(previous_state=None, task_instance=ti_dag_b_try_1) + listener.on_task_instance_running(previous_state=None, task_instance=ti_dag_a_try_2) + + assert len(listener._resolved_cache) == 3 + + listener.on_task_instance_success(previous_state=None, task_instance=ti_dag_a_try_1) + assert len(listener._resolved_cache) == 2 + + listener.on_task_instance_success(previous_state=None, task_instance=ti_dag_b_try_1) + assert len(listener._resolved_cache) == 1 + + listener.on_task_instance_success(previous_state=None, task_instance=ti_dag_a_try_2) + assert len(listener._resolved_cache) == 0 + assert listener.hook.create_lineage_link.call_count == 3 + + +# --------------------------------------------------------------------------- +# Auto-lineage tests +# --------------------------------------------------------------------------- + + +class _SQLTask: + """Task with a sql attribute and no manual inlets/outlets.""" + + def __init__(self, sql, conn_id="postgres_default"): + self.sql = sql + self.conn_id = conn_id + self.inlets = [] + self.outlets = [] + self.params = {} + + +class _SQLTaskInstance: + def __init__(self, task, task_id="sql_task", run_id="run1", map_index=-1, dag_id=None, try_number=None): + self.task = task + self.task_id = task_id + self.run_id = run_id + self.map_index = map_index + self.dag_id = dag_id + self.try_number = try_number + + +@pytest.fixture +def auto_listener(): + lsnr = InformaticaListener() + lsnr.hook = MagicMock(spec_set=["find_object_id", "create_lineage_link"]) + lsnr.log = MagicMock(spec=logging.Logger) + return lsnr + + +@patch("airflow.providers.informatica.lineage.validation.is_operator_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.auto_lineage_enabled", return_value=True) +@patch("airflow.providers.informatica.lineage.validation.is_task_auto_lineage_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.get_resolver") +def test_auto_lineage_fires_when_no_manual_inlets( + mock_get_resolver, mock_task_disabled, mock_auto_enabled, mock_op_disabled, auto_listener +): + mock_resolver = MagicMock() + mock_resolver.resolve.return_value = ( + [TableRef(table="src")], + [TableRef(table="dst")], + ) + mock_get_resolver.return_value = mock_resolver + auto_listener.hook.find_object_id.side_effect = lambda cat, schema, tbl: f"DB://schema/{tbl}" + auto_listener.hook.create_lineage_link.return_value = {} + + task = _SQLTask(sql="INSERT INTO dst SELECT * FROM src") + ti = _SQLTaskInstance(task) + + auto_listener.on_task_instance_running(previous_state=None, task_instance=ti) + auto_listener.hook.create_lineage_link.assert_not_called() + + auto_listener.on_task_instance_success(previous_state=None, task_instance=ti) + auto_listener.hook.create_lineage_link.assert_called_once_with("DB://schema/src", "DB://schema/dst") + + +@patch("airflow.providers.informatica.lineage.validation.is_operator_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.auto_lineage_enabled", return_value=True) +@patch("airflow.providers.informatica.lineage.validation.get_resolver") +def test_manual_inlets_take_priority_over_auto_lineage( + mock_get_resolver, mock_auto_enabled, mock_op_disabled, auto_listener +): + auto_listener.hook.create_lineage_link.return_value = {} + + with patch( + "airflow.providers.informatica.lineage.validation.resolve_uri_to_object_id", + side_effect=lambda hook, uri: uri, + ): + task = DummyTask(inlets=["edc://manual_in"], outlets=["edc://manual_out"]) + ti = DummyTaskInstance(task) + auto_listener.on_task_instance_running(previous_state=None, task_instance=ti) + auto_listener.on_task_instance_success(previous_state=None, task_instance=ti) + + mock_get_resolver.assert_not_called() + auto_listener.hook.create_lineage_link.assert_called_once_with("edc://manual_in", "edc://manual_out") + + +@patch("airflow.providers.informatica.lineage.validation.is_operator_disabled", return_value=True) +def test_disabled_operator_skips_all_lineage(mock_op_disabled, auto_listener): + task = DummyTask(inlets=["edc://in"], outlets=["edc://out"]) + ti = DummyTaskInstance(task) + + auto_listener.on_task_instance_running(previous_state=None, task_instance=ti) + auto_listener.on_task_instance_success(previous_state=None, task_instance=ti) + + auto_listener.hook.find_object_id.assert_not_called() + auto_listener.hook.create_lineage_link.assert_not_called() + + +@patch("airflow.providers.informatica.lineage.validation.is_operator_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.auto_lineage_enabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.get_resolver") +def test_auto_lineage_disabled_globally_skips_resolver( + mock_get_resolver, mock_auto_enabled, mock_op_disabled, auto_listener +): + task = _SQLTask(sql="SELECT * FROM t") + ti = _SQLTaskInstance(task) + + auto_listener.on_task_instance_running(previous_state=None, task_instance=ti) + + mock_get_resolver.assert_not_called() + + +@patch("airflow.providers.informatica.lineage.validation.is_operator_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.auto_lineage_enabled", return_value=True) +@patch("airflow.providers.informatica.lineage.validation.is_task_auto_lineage_disabled", return_value=True) +@patch("airflow.providers.informatica.lineage.validation.get_resolver") +def test_per_task_disable_skips_resolver( + mock_get_resolver, mock_task_disabled, mock_auto_enabled, mock_op_disabled, auto_listener +): + task = _SQLTask(sql="SELECT * FROM t") + ti = _SQLTaskInstance(task) + + auto_listener.on_task_instance_running(previous_state=None, task_instance=ti) + + mock_get_resolver.assert_not_called() + + +@patch("airflow.providers.informatica.lineage.validation.is_operator_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.auto_lineage_enabled", return_value=True) +@patch("airflow.providers.informatica.lineage.validation.is_task_auto_lineage_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.get_resolver") +def test_auto_lineage_no_lineage_when_resolver_returns_none( + mock_get_resolver, mock_task_disabled, mock_auto_enabled, mock_op_disabled, auto_listener +): + mock_resolver = MagicMock() + mock_resolver.resolve.return_value = None + mock_get_resolver.return_value = mock_resolver + + task = _SQLTask(sql="SELECT 1") + ti = _SQLTaskInstance(task) + + auto_listener.on_task_instance_running(previous_state=None, task_instance=ti) + auto_listener.on_task_instance_success(previous_state=None, task_instance=ti) + + auto_listener.hook.create_lineage_link.assert_not_called() + + +@patch("airflow.providers.informatica.lineage.validation.is_operator_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.auto_lineage_enabled", return_value=True) +@patch("airflow.providers.informatica.lineage.validation.is_task_auto_lineage_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.get_resolver") +def test_auto_lineage_logs_warning_when_table_not_found( + mock_get_resolver, mock_task_disabled, mock_auto_enabled, mock_op_disabled, auto_listener +): + """If auto-resolved table is not in EDC, running hook logs a warning; does NOT raise.""" + mock_resolver = MagicMock() + mock_resolver.resolve.return_value = ([TableRef(table="missing")], [TableRef(table="dst")]) + mock_get_resolver.return_value = mock_resolver + auto_listener.hook.find_object_id.return_value = None # not found + + task = _SQLTask(sql="INSERT INTO dst SELECT * FROM missing") + ti = _SQLTaskInstance(task) + + # Must not raise — listener is best-effort. + auto_listener.on_task_instance_running(previous_state=None, task_instance=ti) + + auto_listener.log.warning.assert_called() + auto_listener.hook.create_lineage_link.assert_not_called() + + +@patch("airflow.providers.informatica.lineage.validation.is_operator_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.auto_lineage_enabled", return_value=True) +@patch("airflow.providers.informatica.lineage.validation.is_task_auto_lineage_disabled", return_value=False) +@patch("airflow.providers.informatica.lineage.validation.get_resolver") +def test_auto_lineage_logs_warning_when_edc_error( + mock_get_resolver, mock_task_disabled, mock_auto_enabled, mock_op_disabled, auto_listener +): + """EDC errors during auto-lineage resolution are logged as warnings; listener does NOT raise.""" + mock_resolver = MagicMock() + mock_resolver.resolve.return_value = ([TableRef(table="src")], [TableRef(table="dst")]) + mock_get_resolver.return_value = mock_resolver + auto_listener.hook.find_object_id.side_effect = InformaticaEDCError("connection refused") + + task = _SQLTask(sql="INSERT INTO dst SELECT * FROM src") + ti = _SQLTaskInstance(task) + + # Must not raise — listener is best-effort. + auto_listener.on_task_instance_running(previous_state=None, task_instance=ti) + + auto_listener.log.warning.assert_called() + auto_listener.hook.create_lineage_link.assert_not_called() + + +# --------------------------------------------------------------------------- +# pre_execute cache integration +# --------------------------------------------------------------------------- + + +def test_running_reuses_pre_execute_cache(listener): + """If validate_informatica_lineage cached a result, running hook reuses it.""" + task = DummyTask(inlets=["in1"], outlets=["out1"]) + ti = DummyTaskInstance(task, task_id="t1", run_id="r1", try_number=1) + key = listener._cache_key(ti) + + _store_pre_execute_result(key, ([("/in1", "id_in")], [("/out1", "id_out")])) + + listener.on_task_instance_running(previous_state=None, task_instance=ti) + + assert key in listener._resolved_cache + assert listener._resolved_cache[key] == ([("/in1", "id_in")], [("/out1", "id_out")]) + assert pop_pre_execute_result(key) is None # consumed by running hook + + +def test_failed_clears_orphaned_pre_execute_cache(listener): + """on_task_instance_failed does not leave stale pre_execute cache entries.""" + task = DummyTask(inlets=["in1"], outlets=["out1"]) + ti = DummyTaskInstance(task, task_id="t1", run_id="r1", try_number=1) + key = listener._cache_key(ti) + + # Simulate pre_execute caching a result that running hook consumed, + # then the task failing. The resolved_cache entry should be removed. + _store_pre_execute_result(key, ([("/in1", "id_in")], [("/out1", "id_out")])) + listener.on_task_instance_running(previous_state=None, task_instance=ti) + assert key in listener._resolved_cache + + listener.on_task_instance_failed(previous_state=None, task_instance=ti) + assert key not in listener._resolved_cache diff --git a/providers/informatica/tests/unit/informatica/test_conf.py b/providers/informatica/tests/unit/informatica/test_conf.py new file mode 100644 index 0000000000000..9bbdb3e7bf2e4 --- /dev/null +++ b/providers/informatica/tests/unit/informatica/test_conf.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import patch + +from airflow.providers.informatica.conf import auto_lineage_enabled, disabled_operators, is_operator_disabled + + +class SomeOperator: + pass + + +class TestDisabledOperators: + def test_empty_config_returns_empty_set(self): + with patch("airflow.providers.informatica.conf.conf") as mock_conf: + mock_conf.get.return_value = "" + assert disabled_operators() == set() + + def test_single_fqcn(self): + with patch("airflow.providers.informatica.conf.conf") as mock_conf: + mock_conf.get.return_value = "airflow.operators.bash.BashOperator" + assert disabled_operators() == {"airflow.operators.bash.BashOperator"} + + def test_multiple_fqcns_semicolon_separated(self): + with patch("airflow.providers.informatica.conf.conf") as mock_conf: + mock_conf.get.return_value = ( + "airflow.operators.bash.BashOperator;airflow.operators.python.PythonOperator" + ) + result = disabled_operators() + assert "airflow.operators.bash.BashOperator" in result + assert "airflow.operators.python.PythonOperator" in result + + def test_whitespace_around_entries_is_stripped(self): + with patch("airflow.providers.informatica.conf.conf") as mock_conf: + mock_conf.get.return_value = " foo.Bar ; baz.Qux " + assert disabled_operators() == {"foo.Bar", "baz.Qux"} + + def test_trailing_semicolon_ignored(self): + with patch("airflow.providers.informatica.conf.conf") as mock_conf: + mock_conf.get.return_value = "foo.Bar;" + assert disabled_operators() == {"foo.Bar"} + + +class TestAutoLineageEnabled: + def test_defaults_to_false(self): + with patch("airflow.providers.informatica.conf.conf") as mock_conf: + mock_conf.getboolean.return_value = False + assert auto_lineage_enabled() is False + + def test_returns_true_when_configured(self): + with patch("airflow.providers.informatica.conf.conf") as mock_conf: + mock_conf.getboolean.return_value = True + assert auto_lineage_enabled() is True + + +class TestIsOperatorDisabled: + def test_matching_fqcn_returns_true(self): + fqcn = f"{SomeOperator.__module__}.{SomeOperator.__name__}" + with patch("airflow.providers.informatica.conf.conf") as mock_conf: + mock_conf.get.return_value = fqcn + assert is_operator_disabled(SomeOperator()) is True + + def test_non_matching_fqcn_returns_false(self): + with patch("airflow.providers.informatica.conf.conf") as mock_conf: + mock_conf.get.return_value = "other.Operator" + assert is_operator_disabled(SomeOperator()) is False + + def test_empty_disabled_list_returns_false(self): + with patch("airflow.providers.informatica.conf.conf") as mock_conf: + mock_conf.get.return_value = "" + assert is_operator_disabled(SomeOperator()) is False