diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py index e2db6935fc204..17b957aaca396 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py @@ -440,10 +440,10 @@ OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES = { "spark.openlineage.parentJobName": "dag_id.task_id", "spark.openlineage.parentJobNamespace": "default", - "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + "spark.openlineage.parentRunId": "11111111-1111-1111-1111-111111111111", "spark.openlineage.rootParentJobName": "dag_id", "spark.openlineage.rootParentJobNamespace": "default", - "spark.openlineage.rootParentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + "spark.openlineage.rootParentRunId": "22222222-2222-2222-2222-222222222222", } @@ -1430,13 +1430,15 @@ def test_dataproc_operator_execute_async_done_before_defer(self, mock_submit_job op.execute(context=self.mock_context) assert not mock_defer.called - @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_dag_run_ol_run_id") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_task_instance_ol_run_id") @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_openlineage_parent_job_info_injection( - self, mock_hook, mock_ol_accessible, mock_static_uuid + self, mock_hook, mock_ol_accessible, task_ol_run_id, dag_ol_run_id ): - mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267" + task_ol_run_id.return_value = "11111111-1111-1111-1111-111111111111" + dag_ol_run_id.return_value = "22222222-2222-2222-2222-222222222222" job_config = { "placement": {"cluster_name": CLUSTER_NAME}, "pyspark_job": { @@ -1456,10 +1458,10 @@ def test_execute_openlineage_parent_job_info_injection( "spark.openlineage.transport.type": "console", "spark.openlineage.parentJobName": "dag_id.task_id", "spark.openlineage.parentJobNamespace": "default", - "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + "spark.openlineage.parentRunId": "11111111-1111-1111-1111-111111111111", "spark.openlineage.rootParentJobName": "dag_id", "spark.openlineage.rootParentJobNamespace": "default", - "spark.openlineage.rootParentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + "spark.openlineage.rootParentRunId": "22222222-2222-2222-2222-222222222222", }, }, } @@ -1499,15 +1501,17 @@ def test_execute_openlineage_parent_job_info_injection( metadata=METADATA, ) - @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_dag_run_ol_run_id") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_task_instance_ol_run_id") @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener") @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_openlineage_http_transport_info_injection( - self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid + self, mock_hook, mock_ol_accessible, mock_ol_listener, task_ol_run_id, dag_ol_run_id ): mock_ol_accessible.return_value = True - mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267" + task_ol_run_id.return_value = "11111111-1111-1111-1111-111111111111" + dag_ol_run_id.return_value = "22222222-2222-2222-2222-222222222222" fake_listener = mock.MagicMock() mock_ol_listener.return_value = fake_listener fake_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( @@ -1556,15 +1560,17 @@ def test_execute_openlineage_http_transport_info_injection( metadata=METADATA, ) - @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_dag_run_ol_run_id") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_task_instance_ol_run_id") @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener") @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_openlineage_all_info_injection( - self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid + self, mock_hook, mock_ol_accessible, mock_ol_listener, task_ol_run_id, dag_ol_run_id ): mock_ol_accessible.return_value = True - mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267" + task_ol_run_id.return_value = "11111111-1111-1111-1111-111111111111" + dag_ol_run_id.return_value = "22222222-2222-2222-2222-222222222222" fake_listener = mock.MagicMock() mock_ol_listener.return_value = fake_listener fake_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( @@ -2591,14 +2597,16 @@ def test_wait_for_operation_on_execute(self, mock_hook): ) mock_op.return_value.result.assert_not_called() - @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_dag_run_ol_run_id") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_task_instance_ol_run_id") @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_openlineage_parent_job_info_injection( - self, mock_hook, mock_ol_accessible, mock_static_uuid + self, mock_hook, mock_ol_accessible, task_ol_run_id, dag_ol_run_id ): mock_ol_accessible.return_value = True - mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267" + task_ol_run_id.return_value = "11111111-1111-1111-1111-111111111111" + dag_ol_run_id.return_value = "22222222-2222-2222-2222-222222222222" template = { **WORKFLOW_TEMPLATE, "jobs": [ @@ -2643,10 +2651,10 @@ def test_execute_openlineage_parent_job_info_injection( "spark.sql.shuffle.partitions": "1", "spark.openlineage.parentJobName": "dag_id.task_id", "spark.openlineage.parentJobNamespace": "default", - "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + "spark.openlineage.parentRunId": "11111111-1111-1111-1111-111111111111", "spark.openlineage.rootParentJobName": "dag_id", "spark.openlineage.rootParentJobNamespace": "default", - "spark.openlineage.rootParentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + "spark.openlineage.rootParentRunId": "22222222-2222-2222-2222-222222222222", }, }, }, @@ -2784,15 +2792,17 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_acces metadata=METADATA, ) - @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_dag_run_ol_run_id") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_task_instance_ol_run_id") @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener") @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_openlineage_transport_info_injection( - self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid + self, mock_hook, mock_ol_accessible, mock_ol_listener, task_ol_run_id, dag_ol_run_id ): mock_ol_accessible.return_value = True - mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267" + task_ol_run_id.return_value = "11111111-1111-1111-1111-111111111111" + dag_ol_run_id.return_value = "22222222-2222-2222-2222-222222222222" fake_listener = mock.MagicMock() mock_ol_listener.return_value = fake_listener fake_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( @@ -2892,15 +2902,17 @@ def test_execute_openlineage_transport_info_injection( metadata=METADATA, ) - @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_dag_run_ol_run_id") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_task_instance_ol_run_id") @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener") @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_openlineage_all_info_injection( - self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid + self, mock_hook, mock_ol_accessible, mock_ol_listener, task_ol_run_id, dag_ol_run_id ): mock_ol_accessible.return_value = True - mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267" + task_ol_run_id.return_value = "11111111-1111-1111-1111-111111111111" + dag_ol_run_id.return_value = "22222222-2222-2222-2222-222222222222" fake_listener = mock.MagicMock() mock_ol_listener.return_value = fake_listener fake_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( @@ -3419,7 +3431,8 @@ def test_execute_batch_already_exists_cancelled(self, mock_hook, mock_log): mock_log.info.assert_any_call("Batch with given id already exists.") @mock.patch.object(DataprocCreateBatchOperator, "log", new_callable=mock.MagicMock) - @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_dag_run_ol_run_id") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_task_instance_ol_run_id") @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @mock.patch(DATAPROC_PATH.format("Batch.to_dict")) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @@ -3428,11 +3441,13 @@ def test_execute_openlineage_parent_job_info_injection( mock_hook, to_dict_mock, mock_ol_accessible, - mock_static_uuid, + task_ol_run_id, + dag_ol_run_id, mock_log, ): mock_ol_accessible.return_value = True - mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267" + task_ol_run_id.return_value = "11111111-1111-1111-1111-111111111111" + dag_ol_run_id.return_value = "22222222-2222-2222-2222-222222222222" expected_batch = { **BATCH, "labels": EXPECTED_LABELS, @@ -3474,16 +3489,25 @@ def test_execute_openlineage_parent_job_info_injection( mock_log.info.assert_any_call("Batch job %s completed.\nDriver logs: %s", BATCH_ID, logs_link) @mock.patch.object(DataprocCreateBatchOperator, "log", new_callable=mock.MagicMock) - @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_dag_run_ol_run_id") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_task_instance_ol_run_id") @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener") @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @mock.patch(DATAPROC_PATH.format("Batch.to_dict")) @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_openlineage_transport_info_injection( - self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, mock_static_uuid, mock_log + self, + mock_hook, + to_dict_mock, + mock_ol_accessible, + mock_ol_listener, + task_ol_run_id, + dag_ol_run_id, + mock_log, ): mock_ol_accessible.return_value = True - mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267" + task_ol_run_id.return_value = "11111111-1111-1111-1111-111111111111" + dag_ol_run_id.return_value = "22222222-2222-2222-2222-222222222222" fake_listener = mock.MagicMock() mock_ol_listener.return_value = fake_listener fake_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( @@ -3533,16 +3557,18 @@ def test_execute_openlineage_transport_info_injection( logs_link, ) - @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_dag_run_ol_run_id") + @mock.patch("airflow.providers.openlineage.plugins.adapter.build_task_instance_ol_run_id") @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener") @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @mock.patch(DATAPROC_PATH.format("Batch.to_dict")) @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_openlineage_all_info_injection( - self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, mock_static_uuid + self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, task_ol_run_id, dag_ol_run_id ): mock_ol_accessible.return_value = True - mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267" + task_ol_run_id.return_value = "11111111-1111-1111-1111-111111111111" + dag_ol_run_id.return_value = "22222222-2222-2222-2222-222222222222" fake_listener = mock.MagicMock() mock_ol_listener.return_value = fake_listener fake_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py index 1ed20b70cc284..eaca60bd80488 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py @@ -34,14 +34,16 @@ ownership_job, tags_job, ) -from openlineage.client.uuid import generate_static_uuid from airflow.providers.common.compat.sdk import Stats, conf as airflow_conf from airflow.providers.openlineage import __version__ as OPENLINEAGE_PROVIDER_VERSION, conf from airflow.providers.openlineage.utils.utils import ( OpenLineageRedactor, + build_dag_run_ol_run_id, + build_task_instance_ol_run_id, get_airflow_debug_facet, get_airflow_state_run_facet, + get_dag_job_dependency_facet, get_processing_engine_facet, ) from airflow.utils.log.logging_mixin import LoggingMixin @@ -123,12 +125,7 @@ def _read_yaml_config(path: str) -> dict | None: @staticmethod def build_dag_run_id(dag_id: str, logical_date: datetime, clear_number: int) -> str: - return str( - generate_static_uuid( - instant=logical_date, - data=f"{conf.namespace()}.{dag_id}.{clear_number}".encode(), - ) - ) + return build_dag_run_ol_run_id(dag_id=dag_id, logical_date=logical_date, clear_number=clear_number) @staticmethod def build_task_instance_run_id( @@ -138,11 +135,12 @@ def build_task_instance_run_id( logical_date: datetime, map_index: int, ): - return str( - generate_static_uuid( - instant=logical_date, - data=f"{conf.namespace()}.{dag_id}.{task_id}.{try_number}.{map_index}".encode(), - ) + return build_task_instance_ol_run_id( + dag_id=dag_id, + task_id=task_id, + try_number=try_number, + logical_date=logical_date, + map_index=map_index, ) def emit(self, event: RunEvent): @@ -365,6 +363,7 @@ def fail_task( def dag_started( self, dag_id: str, + run_id: str, logical_date: datetime, start_date: datetime, nominal_start_time: str | None, @@ -374,10 +373,14 @@ def dag_started( run_facets: dict[str, RunFacet], clear_number: int, job_description: str | None, + is_asset_triggered: bool, job_description_type: str | None = None, job_facets: dict[str, JobFacet] | None = None, # Custom job facets ): try: + job_dependency_facet = {} + if is_asset_triggered: + job_dependency_facet = get_dag_job_dependency_facet(dag_id=dag_id, dag_run_id=run_id) event = RunEvent( eventType=RunState.START, eventTime=start_date.isoformat(), @@ -396,7 +399,7 @@ def dag_started( ), nominal_start_time=nominal_start_time, nominal_end_time=nominal_end_time, - run_facets={**run_facets, **get_airflow_debug_facet()}, + run_facets={**run_facets, **get_airflow_debug_facet(), **job_dependency_facet}, ), inputs=[], outputs=[], @@ -424,9 +427,13 @@ def dag_success( owners: list[str] | None, run_facets: dict[str, RunFacet], job_description: str | None, + is_asset_triggered: bool, job_description_type: str | None = None, ): try: + job_dependency_facet = {} + if is_asset_triggered: + job_dependency_facet = get_dag_job_dependency_facet(dag_id=dag_id, dag_run_id=run_id) event = RunEvent( eventType=RunState.COMPLETE, eventTime=end_date.isoformat(), @@ -446,6 +453,7 @@ def dag_success( nominal_end_time=nominal_end_time, run_facets={ **get_airflow_state_run_facet(dag_id, run_id, task_ids, dag_run_state), + **job_dependency_facet, **get_airflow_debug_facet(), **run_facets, }, @@ -477,9 +485,13 @@ def dag_failed( msg: str, run_facets: dict[str, RunFacet], job_description: str | None, + is_asset_triggered: bool, job_description_type: str | None = None, ): try: + job_dependency_facet = {} + if is_asset_triggered: + job_dependency_facet = get_dag_job_dependency_facet(dag_id=dag_id, dag_run_id=run_id) event = RunEvent( eventType=RunState.FAIL, eventTime=end_date.isoformat(), @@ -502,6 +514,7 @@ def dag_failed( message=msg, programmingLanguage="python" ), **get_airflow_state_run_facet(dag_id, run_id, task_ids, dag_run_state), + **job_dependency_facet, **get_airflow_debug_facet(), **run_facets, }, diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py index cb90f80d19879..a7bb56621d85d 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py @@ -47,6 +47,7 @@ get_task_documentation, get_task_parent_run_facet, get_user_provided_run_facets, + is_dag_run_asset_triggered, is_operator_disabled, is_selective_lineage_enabled, print_warning, @@ -669,6 +670,7 @@ def on_dag_run_running(self, dag_run: DagRun, msg: str) -> None: self.submit_callable( self.adapter.dag_started, dag_id=dag_run.dag_id, + run_id=dag_run.run_id, logical_date=date, start_date=dag_run.start_date, nominal_start_time=data_interval_start, @@ -685,6 +687,7 @@ def on_dag_run_running(self, dag_run: DagRun, msg: str) -> None: **get_airflow_dag_run_facet(dag_run), **get_dag_parent_run_facet(getattr(dag_run, "conf", {})), }, + is_asset_triggered=is_dag_run_asset_triggered(dag_run), ) except BaseException as e: self.log.warning("OpenLineage received exception in method on_dag_run_running", exc_info=e) @@ -736,6 +739,7 @@ def on_dag_run_success(self, dag_run: DagRun, msg: str) -> None: **get_airflow_dag_run_facet(dag_run), **get_dag_parent_run_facet(getattr(dag_run, "conf", {})), }, + is_asset_triggered=is_dag_run_asset_triggered(dag_run), ) except BaseException as e: self.log.warning("OpenLineage received exception in method on_dag_run_success", exc_info=e) @@ -788,6 +792,7 @@ def on_dag_run_failed(self, dag_run: DagRun, msg: str) -> None: **get_airflow_dag_run_facet(dag_run), **get_dag_parent_run_facet(getattr(dag_run, "conf", {})), }, + is_asset_triggered=is_dag_run_asset_triggered(dag_run), ) except BaseException as e: self.log.warning("OpenLineage received exception in method on_dag_run_failed", exc_info=e) diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index 91ccfcd13b1d7..ba45eac13dea7 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -27,13 +27,12 @@ from typing import TYPE_CHECKING, Any import attrs -from openlineage.client.facet_v2 import parent_run +from openlineage.client.facet_v2 import job_dependencies_run, parent_run from openlineage.client.utils import RedactMixin +from openlineage.client.uuid import generate_static_uuid from airflow import __version__ as AIRFLOW_VERSION from airflow.exceptions import AirflowOptionalProviderFeatureException - -# TODO: move this maybe to Airflow's logic? from airflow.models import DagRun, TaskInstance, TaskReschedule from airflow.providers.common.compat.assets import Asset from airflow.providers.common.compat.module_loading import import_string @@ -75,6 +74,7 @@ from openlineage.client.event_v2 import Dataset as OpenLineageDataset from openlineage.client.facet_v2 import RunFacet, processing_engine_run + from airflow.models.asset import AssetEvent from airflow.sdk.execution_time.secrets_masker import ( Redactable, Redacted, @@ -726,18 +726,21 @@ class DagRunInfo(InfoJsonEncodable): """Defines encoding DagRun object to JSON.""" includes = [ + "clear_number", "conf", "dag_id", - "data_interval_start", "data_interval_end", - "external_trigger", # Removed in Airflow 3, use run_type instead + "data_interval_start", + "end_date", "execution_date", # Airflow 2 + "external_trigger", # Removed in Airflow 3, use run_type instead "logical_date", # Airflow 3 "run_after", # Airflow 3 "run_id", "run_type", "start_date", - "end_date", + "triggered_by", + "triggering_user_name", # Airflow 3 ] casts = { @@ -1000,6 +1003,358 @@ def get_task_duration(ti): } +def is_dag_run_asset_triggered( + dag_run: DagRun, +): + """Return whether the given DAG run was triggered by an asset.""" + if AIRFLOW_V_3_0_PLUS: + from airflow.utils.types import DagRunTriggeredByType + + return dag_run.triggered_by == DagRunTriggeredByType.ASSET + + # AF 2 Path + from airflow.models.dagrun import DagRunType + + return dag_run.run_type == DagRunType.DATASET_TRIGGERED # type: ignore[attr-defined] # This attr is available on AF2, but mypy can't see it + + +def build_task_instance_ol_run_id( + dag_id: str, + task_id: str, + try_number: int, + logical_date: datetime.datetime, + map_index: int, +): + """ + Generate a deterministic OpenLineage run ID for a task instance. + + Args: + dag_id: The DAG identifier. + task_id: The task identifier. + try_number: The task try number. + logical_date: The logical execution date from dagrun. + map_index: The task map index. + + Returns: + A deterministic OpenLineage run ID for the task instance. + """ + return str( + generate_static_uuid( + instant=logical_date, + data=f"{conf.namespace()}.{dag_id}.{task_id}.{try_number}.{map_index}".encode(), + ) + ) + + +def is_valid_uuid(uuid_string: str | None) -> bool: + """Validate that a string is a valid UUID format.""" + if uuid_string is None: + return False + try: + from uuid import UUID + + UUID(uuid_string) + return True + except (ValueError, TypeError): + return False + + +def build_dag_run_ol_run_id(dag_id: str, logical_date: datetime.datetime, clear_number: int) -> str: + """ + Generate a deterministic OpenLineage run ID for a DAG run. + + Args: + dag_id: The DAG identifier. + logical_date: The logical execution date. + clear_number: The DAG run clear number. + + Returns: + A deterministic OpenLineage run ID for the DAG run. + """ + return str( + generate_static_uuid( + instant=logical_date, + data=f"{conf.namespace()}.{dag_id}.{clear_number}".encode(), + ) + ) + + +def _get_eagerly_loaded_dagrun_consumed_asset_events(dag_id: str, dag_run_id: str) -> list[AssetEvent]: + """ + Retrieve consumed asset events for a DagRun with relationships eagerly loaded. + + Downstream code accesses source_task_instance, source_dag_run, and asset on each AssetEvent. + These relationships are lazy-loaded by default, which could cause N+1 query problem + (2 + 3*N queries for N events). Using `joinedload` fetches everything in a single query. + The returned AssetEvent objects have all needed relationships pre-populated in memory, + so they can be safely used after the session is closed. + + Returns: + AssetEvent objects with populated relationships, or empty list if DagRun not found. + """ + # This should only be used on scheduler, so DB access is allowed + from sqlalchemy import select + from sqlalchemy.orm import joinedload + + from airflow.utils.session import create_session + + if AIRFLOW_V_3_0_PLUS: + from airflow.models.asset import AssetEvent + + options = ( + joinedload(DagRun.consumed_asset_events).joinedload(AssetEvent.source_dag_run), + joinedload(DagRun.consumed_asset_events).joinedload(AssetEvent.source_task_instance), + joinedload(DagRun.consumed_asset_events).joinedload(AssetEvent.asset), + ) + + else: # AF2 path + from airflow.models.dataset import DatasetEvent + + options = ( + joinedload(DagRun.consumed_dataset_events).joinedload(DatasetEvent.source_dag_run), + joinedload(DagRun.consumed_dataset_events).joinedload(DatasetEvent.source_task_instance), + joinedload(DagRun.consumed_dataset_events).joinedload(DatasetEvent.dataset), + ) + + with create_session() as session: + dag_run_with_events = session.scalar( + select(DagRun).where(DagRun.dag_id == dag_id).where(DagRun.run_id == dag_run_id).options(*options) + ) + + if not dag_run_with_events: + return [] + + if AIRFLOW_V_3_0_PLUS: + events = dag_run_with_events.consumed_asset_events + else: # AF2 path + events = dag_run_with_events.consumed_dataset_events + + return events + + +def _extract_ol_info_from_asset_event(asset_event: AssetEvent) -> dict[str, str] | None: + """ + Extract OpenLineage job information from an AssetEvent. + + Information is gathered from multiple potential sources, checked in priority + order: + 1. TaskInstance (primary): Provides the most complete and reliable context. + 2. AssetEvent source fields (fallback): Offers basic `dag_id.task_id` metadata. + 3. `asset_event.extra["openlineage"]` (last resort): May include user provided OpenLineage details. + + Args: + asset_event: The AssetEvent from which to extract job information. + + Returns: + A dictionary containing `job_name`, `job_namespace`, and optionally + `run_id`, or `None` if insufficient information is available. + """ + # First check for TaskInstance + if ti := asset_event.source_task_instance: + result = { + "job_name": get_job_name(ti), + "job_namespace": conf.namespace(), + } + source_dr = asset_event.source_dag_run + if source_dr: + logical_date = source_dr.logical_date # Get logical date from DagRun for OL run_id generation + if AIRFLOW_V_3_0_PLUS and logical_date is None: + logical_date = source_dr.run_after + if logical_date is not None: + result["run_id"] = build_task_instance_ol_run_id( + dag_id=ti.dag_id, + task_id=ti.task_id, + try_number=ti.try_number, + logical_date=logical_date, + map_index=ti.map_index, + ) + return result + + # Then, check AssetEvent source_* fields + if asset_event.source_dag_id and asset_event.source_task_id: + return { + "job_name": f"{asset_event.source_dag_id}.{asset_event.source_task_id}", + "job_namespace": conf.namespace(), + # run_id cannot be constructed from these fields alone + } + + # Lastly, check asset_event.extra["openlineage"] + if asset_event.extra: + ol_info_from_extra = asset_event.extra.get("openlineage") + if isinstance(ol_info_from_extra, dict): + job_name = ol_info_from_extra.get("parentJobName") + job_namespace = ol_info_from_extra.get("parentJobNamespace") + run_id = ol_info_from_extra.get("parentRunId") + + if job_name and job_namespace: + result = { + "job_name": str(job_name), + "job_namespace": str(job_namespace), + } + if run_id: + if not is_valid_uuid(str(run_id)): + log.warning( + "Invalid runId in AssetEvent.extra; ignoring value. event_id=%s, run_id=%s", + asset_event.id, + run_id, + ) + else: + result["run_id"] = str(run_id) + return result + return None + + +def _get_ol_job_dependencies_from_asset_events(events: list[AssetEvent]) -> list[dict[str, Any]]: + """ + Extract and deduplicate OpenLineage job dependencies from asset events. + + This function processes a list of asset events, extracts OpenLineage dependency information + from all relevant sources, and deduplicates the results based on the tuple (job_namespace, job_name, run_id) + to prevent emitting duplicate dependencies. Multiple asset events from the same job but different + source runs/assets are aggregated into a single dependency entry with all source information preserved. + + Args: + events: List of AssetEvent objects to process. + + Returns: + A list of deduplicated dictionaries containing OpenLineage job dependency information. + Each dictionary includes job_name, job_namespace, optional run_id, and an asset_events + list containing source information from all aggregated events. + """ + # Use a dictionary keyed by (namespace, job_name, run_id) to deduplicate + # Multiple asset events from the same task instance should only create one dependency + deduplicated_jobs: dict[tuple[str, str, str | None], dict[str, Any]] = {} + + for asset_event in events: + # Extract OpenLineage information + ol_info = _extract_ol_info_from_asset_event(asset_event) + + # Skip if we don't have minimum required info (job_name and namespace) + if not ol_info: + log.debug( + "Insufficient OpenLineage information, skipping asset event: %s", + str(asset_event), + ) + continue + + # Create deduplication key: (namespace, job_name, run_id) + # We deduplicate on job identity (namespace + name + run_id), not on source dag_run_id + # Multiple asset events from the same job but different source runs/assets are aggregated + dedup_key = ( + ol_info["job_namespace"], + ol_info["job_name"], + ol_info.get("run_id"), + ) + + # Collect source information for this asset event + source_info = { + "dag_run_id": asset_event.source_run_id, + "asset_event_id": asset_event.id, + "asset_event_extra": asset_event.extra or None, + "asset_id": asset_event.asset_id if AIRFLOW_V_3_0_PLUS else asset_event.dataset_id, + "asset_uri": asset_event.uri, + "partition_key": getattr(asset_event, "partition_key", None), + } + + if dedup_key not in deduplicated_jobs: + # First occurrence: create the job entry with initial source info + deduplicated_jobs[dedup_key] = {**ol_info, "asset_events": [source_info]} + else: + # Already seen: append source info to existing entry + deduplicated_jobs[dedup_key]["asset_events"].append(source_info) + + result = list(deduplicated_jobs.values()) + return result + + +def _build_job_dependency_facet( + dag_id: str, dag_run_id: str +) -> dict[str, job_dependencies_run.JobDependenciesRunFacet]: + """ + Build the JobDependenciesRunFacet for a DagRun. + + Args: + dag_id: The DAG identifier. + dag_run_id: The DagRun identifier. + + Returns: + A dictionary containing the JobDependenciesRunFacet, or an empty dictionary. + """ + log.info( + "Building OpenLineage JobDependenciesRunFacet for DagRun(dag_id=%s, run_id=%s).", + dag_id, + dag_run_id, + ) + events = _get_eagerly_loaded_dagrun_consumed_asset_events(dag_id, dag_run_id) + + if not events: + log.info("DagRun %s/%s has no consumed asset events", dag_id, dag_run_id) + return {} + + ol_dependencies = _get_ol_job_dependencies_from_asset_events(events=events) + + if not ol_dependencies: + log.info( + "No OpenLineage job dependencies generated from asset events consumed by DagRun %s/%s.", + dag_id, + dag_run_id, + ) + return {} + + upstream_dependencies = [] + for job in ol_dependencies: + job_identifier = job_dependencies_run.JobIdentifier( + namespace=job["job_namespace"], + name=job["job_name"], + ) + + run_identifier = None + if job.get("run_id"): + run_identifier = job_dependencies_run.RunIdentifier(runId=job["run_id"]) + + job_dependency = job_dependencies_run.JobDependency( + job=job_identifier, + run=run_identifier, + dependency_type="IMPLICIT_ASSET_DEPENDENCY", + ).with_additional_properties(airflow={"asset_events": job.get("asset_events")}) # type: ignore[arg-type] # Fixed in OL client 1.42, waiting for release + + upstream_dependencies.append(job_dependency) + + return { + "jobDependencies": job_dependencies_run.JobDependenciesRunFacet( + upstream=upstream_dependencies, + ) + } + + +def get_dag_job_dependency_facet( + dag_id: str, dag_run_id: str +) -> dict[str, job_dependencies_run.JobDependenciesRunFacet]: + """ + Safely retrieve the asset-triggered job dependency facet for a DagRun. + + This function collects information about the asset events that triggered the specified DagRun, + including details about the originating DAG runs and task instances. If the DagRun was not triggered + by assets, or if any error occurs during lookup or processing, the function logs the error and returns + an empty dictionary. This guarantees that facet generation never raises exceptions and does not + interfere with event emission processes. + + Args: + dag_id: The DAG identifier. + dag_run_id: The DagRun identifier. + + Returns: + A dictionary with JobDependenciesRunFacet, or an empty dictionary + if the DagRun was not asset-triggered or if an error occurs. + """ + try: + return _build_job_dependency_facet(dag_id=dag_id, dag_run_id=dag_run_id) + except Exception as e: + log.warning("Failed to build JobDependenciesRunFacet for DagRun %s/%s: %s.", dag_id, dag_run_id, e) + log.debug("Exception details:", exc_info=True) + return {} + + def _get_tasks_details(dag: DAG | SerializedDAG) -> dict: tasks = { single_task.task_id: { diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_adapter.py b/providers/openlineage/tests/unit/openlineage/plugins/test_adapter.py index ee22424bc7826..22bfb73c20c71 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_adapter.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_adapter.py @@ -662,10 +662,10 @@ def test_emit_failed_event_with_additional_information(mock_stats_incr, mock_sta @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) -@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid") +@mock.patch("airflow.providers.openlineage.plugins.adapter.build_dag_run_ol_run_id") @mock.patch("airflow.providers.openlineage.plugins.adapter.Stats.timer") @mock.patch("airflow.providers.openlineage.plugins.adapter.Stats.incr") -def test_emit_dag_started_event(mock_stats_incr, mock_stats_timer, generate_static_uuid, mock_debug_mode): +def test_emit_dag_started_event(mock_stats_incr, mock_stats_timer, build_ol_id, mock_debug_mode): random_uuid = "9d3b14f7-de91-40b6-aeef-e887e2c7673e" client = MagicMock() adapter = OpenLineageAdapter(client) @@ -703,7 +703,7 @@ def test_emit_dag_started_event(mock_stats_incr, mock_stats_timer, generate_stat data_interval=(event_time, event_time), ) dag_run.dag = dag - generate_static_uuid.return_value = random_uuid + build_ol_id.return_value = random_uuid job_facets = {**get_airflow_job_facet(dag_run)} @@ -736,6 +736,7 @@ def test_emit_dag_started_event(mock_stats_incr, mock_stats_timer, generate_stat ) adapter.dag_started( dag_id=dag_id, + run_id=run_id, start_date=event_time, logical_date=event_time, clear_number=0, @@ -745,6 +746,7 @@ def test_emit_dag_started_event(mock_stats_incr, mock_stats_timer, generate_stat job_description=dag.description, job_description_type="text/plain", tags=["tag1", "tag2"], + is_asset_triggered=False, run_facets={ "parent": parent_run.ParentRunFacet( run=parent_run.Run(runId=random_uuid), @@ -834,11 +836,11 @@ def test_emit_dag_started_event(mock_stats_incr, mock_stats_timer, generate_stat @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) @mock.patch.object(DagRun, "fetch_task_instances") -@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid") +@mock.patch("airflow.providers.openlineage.plugins.adapter.build_dag_run_ol_run_id") @mock.patch("airflow.providers.openlineage.plugins.adapter.Stats.timer") @mock.patch("airflow.providers.openlineage.plugins.adapter.Stats.incr") def test_emit_dag_complete_event( - mock_stats_incr, mock_stats_timer, generate_static_uuid, mocked_fetch_tis, mock_debug_mode + mock_stats_incr, mock_stats_timer, build_ol_id, mocked_fetch_tis, mock_debug_mode ): random_uuid = "9d3b14f7-de91-40b6-aeef-e887e2c7673e" client = MagicMock() @@ -896,7 +898,7 @@ def test_emit_dag_complete_event( ti2.end_date = datetime.datetime(2022, 1, 1, 0, 14, 0) mocked_fetch_tis.return_value = [ti0, ti1, ti2] - generate_static_uuid.return_value = random_uuid + build_ol_id.return_value = random_uuid adapter.dag_success( dag_id=dag_id, @@ -912,6 +914,7 @@ def test_emit_dag_complete_event( job_description_type="text/plain", nominal_start_time=datetime.datetime(2022, 1, 1).isoformat(), nominal_end_time=datetime.datetime(2022, 1, 1).isoformat(), + is_asset_triggered=False, run_facets={ "parent": parent_run.ParentRunFacet( run=parent_run.Run(runId=random_uuid), @@ -1000,11 +1003,11 @@ def test_emit_dag_complete_event( @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) @mock.patch.object(DagRun, "fetch_task_instances") -@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid") +@mock.patch("airflow.providers.openlineage.plugins.adapter.build_dag_run_ol_run_id") @mock.patch("airflow.providers.openlineage.plugins.adapter.Stats.timer") @mock.patch("airflow.providers.openlineage.plugins.adapter.Stats.incr") def test_emit_dag_failed_event( - mock_stats_incr, mock_stats_timer, generate_static_uuid, mocked_fetch_tis, mock_debug_mode + mock_stats_incr, mock_stats_timer, build_ol_id, mocked_fetch_tis, mock_debug_mode ): random_uuid = "9d3b14f7-de91-40b6-aeef-e887e2c7673e" client = MagicMock() @@ -1062,7 +1065,7 @@ def test_emit_dag_failed_event( mocked_fetch_tis.return_value = [ti0, ti1, ti2] - generate_static_uuid.return_value = random_uuid + build_ol_id.return_value = random_uuid adapter.dag_failed( dag_id=dag_id, @@ -1090,6 +1093,7 @@ def test_emit_dag_failed_event( ), "airflowDagRun": AirflowDagRunFacet(dag={"description": "dag desc"}, dagRun=dag_run), }, + is_asset_triggered=False, ) client.emit.assert_called_once_with( diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py index 6bbffd2cac28a..cb3eaee82184b 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py @@ -42,14 +42,19 @@ TaskInfo, TaskInfoComplete, TaskInstanceInfo, + _extract_ol_info_from_asset_event, + _get_ol_job_dependencies_from_asset_events, _get_openlineage_data_from_dagrun_conf, _get_task_groups_details, _get_tasks_details, _truncate_string_to_byte_size, + build_dag_run_ol_run_id, + build_task_instance_ol_run_id, get_airflow_dag_run_facet, get_airflow_job_facet, get_airflow_state_run_facet, get_dag_documentation, + get_dag_job_dependency_facet, get_dag_parent_run_facet, get_fully_qualified_class_name, get_job_name, @@ -60,6 +65,8 @@ get_task_documentation, get_task_parent_run_facet, get_user_provided_run_facets, + is_dag_run_asset_triggered, + is_valid_uuid, ) from airflow.providers.standard.operators.empty import EmptyOperator from airflow.timetables.events import EventsTimetable @@ -168,6 +175,8 @@ def test_get_airflow_dag_run_facet(): dagrun_mock.run_after = datetime.datetime(2024, 6, 1, 1, 2, 4, tzinfo=datetime.timezone.utc) dagrun_mock.start_date = datetime.datetime(2024, 6, 1, 1, 2, 4, tzinfo=datetime.timezone.utc) dagrun_mock.end_date = datetime.datetime(2024, 6, 1, 1, 2, 14, 34172, tzinfo=datetime.timezone.utc) + dagrun_mock.triggering_user_name = "user1" + dagrun_mock.triggered_by = "something" dagrun_mock.dag_versions = [ MagicMock( bundle_name="bundle_name", @@ -197,6 +206,7 @@ def test_get_airflow_dag_run_facet(): dag=expected_dag_info, dagRun={ "conf": {}, + "clear_number": 0, "dag_id": "dag", "data_interval_start": "2024-06-01T01:02:03+00:00", "data_interval_end": "2024-06-01T02:03:04+00:00", @@ -213,6 +223,8 @@ def test_get_airflow_dag_run_facet(): "dag_bundle_version": "bundle_version", "dag_version_id": "version_id", "dag_version_number": "version_number", + "triggering_user_name": "user1", + "triggered_by": "something", }, ) } @@ -1908,7 +1920,7 @@ def test_dag_info_schedule_dataset_or_time_schedule(self): } -@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Airflow < 3.0 tests") +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Airflow 2 tests") class TestDagInfoAirflow210: def test_dag_info_schedule_single_dataset_directly(self): dag = DAG( @@ -2294,10 +2306,12 @@ def test_dagrun_info_af3(mocked_dag_versions): ) assert dagrun.dag_versions == [dv1, dv2] dagrun.end_date = date + datetime.timedelta(seconds=74, microseconds=546) + dagrun.triggering_user_name = "my_user" result = DagRunInfo(dagrun) assert dict(result) == { "conf": {"a": 1}, + "clear_number": 0, "dag_id": "dag_id", "data_interval_end": "2024-06-01T00:00:00+00:00", "data_interval_start": "2024-06-01T00:00:00+00:00", @@ -2312,6 +2326,8 @@ def test_dagrun_info_af3(mocked_dag_versions): "dag_bundle_version": "bundle_version", "dag_version_id": "version_id", "dag_version_number": "version_number", + "triggered_by": DagRunTriggeredByType.UI, + "triggering_user_name": "my_user", } @@ -2338,6 +2354,7 @@ def test_dagrun_info_af2(): result = DagRunInfo(dagrun) assert dict(result) == { "conf": {"a": 1}, + "clear_number": 0, "dag_id": "dag_id", "data_interval_end": "2024-06-01T00:00:00+00:00", "data_interval_start": "2024-06-01T00:00:00+00:00", @@ -2824,3 +2841,708 @@ def test_task_with_none_timestamps_fallback_to_zero(self, dag_maker): ) assert result["airflowState"].tasksDuration["terminated_task"] == 0.0 + + +@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 3 specific test") +def test_is_dag_run_asset_triggered_af3(): + """Test is_dag_run_asset_triggered for Airflow 3.""" + from airflow.models.dagrun import DagRunTriggeredByType + + dag_run = MagicMock(triggered_by=DagRunTriggeredByType.ASSET) + + assert is_dag_run_asset_triggered(dag_run) is True + + dag_run.triggered_by = DagRunTriggeredByType.TIMETABLE + assert is_dag_run_asset_triggered(dag_run) is False + + +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Airflow 2 specific test") +def test_is_dag_run_asset_triggered_af2(): + """Test is_dag_run_asset_triggered for Airflow 2.""" + from airflow.models.dagrun import DagRunType + + dag_run = MagicMock(run_type=DagRunType.DATASET_TRIGGERED) + + assert is_dag_run_asset_triggered(dag_run) is True + + dag_run.run_type = DagRunType.MANUAL + assert is_dag_run_asset_triggered(dag_run) is False + + +def test_build_task_instance_ol_run_id(): + """Test deterministic UUID generation for task instance.""" + logical_date = datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + run_id = build_task_instance_ol_run_id( + dag_id="test_dag", + task_id="test_task", + try_number=1, + logical_date=logical_date, + map_index=0, + ) + + assert run_id == "018cc4e5-2200-7b27-b511-a7a14aa0662a" + + # Should be deterministic - same inputs produce same output + run_id2 = build_task_instance_ol_run_id( + dag_id="test_dag", + task_id="test_task", + try_number=1, + logical_date=logical_date, + map_index=0, + ) + assert run_id == run_id2 + + # Different inputs should produce different outputs + run_id3 = build_task_instance_ol_run_id( + dag_id="test_dag", + task_id="test_task", + try_number=2, # Different try_number + logical_date=logical_date, + map_index=0, + ) + assert run_id != run_id3 + + +def test_build_dag_run_ol_run_id(): + """Test deterministic UUID generation for DAG run.""" + logical_date = datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + run_id = build_dag_run_ol_run_id( + dag_id="test_dag", + logical_date=logical_date, + clear_number=0, + ) + assert run_id == "018cc4e5-2200-725f-8091-596ad71712b2" + + # Should be deterministic - same inputs produce same output + run_id2 = build_dag_run_ol_run_id( + dag_id="test_dag", + logical_date=logical_date, + clear_number=0, + ) + assert run_id == run_id2 + + # Different inputs should produce different outputs + run_id3 = build_dag_run_ol_run_id( + dag_id="test_dag", + logical_date=logical_date, + clear_number=1, # Different clear_number + ) + assert run_id != run_id3 + + +def test_validate_uuid_valid(): + """Test validation of valid UUID strings.""" + valid_uuids = [ + "550e8400-e29b-41d4-a716-446655440000", + "6ba7b810-9dad-11d1-80b4-00c04fd430c8", + "00000000-0000-0000-0000-000000000000", + ] + for uuid_str in valid_uuids: + assert is_valid_uuid(uuid_str) is True + + +def test_validate_uuid_invalid(): + """Test validation of invalid UUID strings.""" + invalid_uuids = [ + "not-a-uuid", + "550e8400-e29b-41d4-a716", # Too short + "550e8400-e29b-41d4-a716-446655440000-extra", # Too long + "550e8400-e29b-41d4-a716-44665544000g", # Invalid character + "", + "123", + None, + ] + for uuid_str in invalid_uuids: + assert is_valid_uuid(uuid_str) is False + + +class TestExtractOlInfoFromAssetEvent: + """Tests for _extract_ol_info_from_asset_event function.""" + + def test_extract_ol_info_from_task_instance(self): + """Test extraction from TaskInstance (priority 1).""" + logical_date = datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + + # Mock TaskInstance - using MagicMock without spec to avoid SQLAlchemy mapper inspection + ti = MagicMock() + ti.dag_id = "source_dag" + ti.task_id = "source_task" + ti.try_number = 1 + ti.map_index = 0 + + # Mock DagRun + source_dr = MagicMock() + source_dr.logical_date = logical_date + source_dr.run_after = None + + # Mock AssetEvent + asset_event = MagicMock() + asset_event.source_task_instance = ti + asset_event.source_dag_run = source_dr + asset_event.source_dag_id = None + asset_event.source_task_id = None + asset_event.extra = {} + + result = _extract_ol_info_from_asset_event(asset_event) + + expected_run_id = build_task_instance_ol_run_id( + dag_id="source_dag", + task_id="source_task", + try_number=1, + logical_date=logical_date, + map_index=0, + ) + assert result == { + "job_name": "source_dag.source_task", + "job_namespace": namespace(), + "run_id": expected_run_id, + } + + def test_extract_ol_info_from_task_instance_no_logical_date(self): + """Test extraction from TaskInstance without logical_date.""" + # Mock TaskInstance + ti = MagicMock() + ti.dag_id = "source_dag" + ti.task_id = "source_task" + ti.try_number = 1 + ti.map_index = 0 + + # Mock DagRun with None logical_date + source_dr = MagicMock() + source_dr.logical_date = None + source_dr.run_after = None + + # Mock AssetEvent + asset_event = MagicMock() + asset_event.source_task_instance = ti + asset_event.source_dag_run = source_dr + asset_event.source_dag_id = None + asset_event.source_task_id = None + asset_event.extra = {} + + result = _extract_ol_info_from_asset_event(asset_event) + + # run_id should not be included if logical_date is None + assert result == { + "job_name": "source_dag.source_task", + "job_namespace": namespace(), + } + + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 3 specific test") + def test_extract_ol_info_from_task_instance_run_after_fallback(self): + """Test extraction from TaskInstance with run_after fallback (AF3).""" + run_after = datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + + # Mock TaskInstance + ti = MagicMock() + ti.dag_id = "source_dag" + ti.task_id = "source_task" + ti.try_number = 1 + ti.map_index = 0 + + # Mock DagRun with None logical_date but run_after set + source_dr = MagicMock() + source_dr.logical_date = None + source_dr.run_after = run_after + + # Mock AssetEvent + asset_event = MagicMock() + asset_event.source_task_instance = ti + asset_event.source_dag_run = source_dr + asset_event.source_dag_id = None + asset_event.source_task_id = None + asset_event.extra = {} + + result = _extract_ol_info_from_asset_event(asset_event) + + # Should use run_after as fallback for logical_date + expected_run_id = build_task_instance_ol_run_id( + dag_id="source_dag", + task_id="source_task", + try_number=1, + logical_date=run_after, + map_index=0, + ) + assert result == { + "job_name": "source_dag.source_task", + "job_namespace": namespace(), + "run_id": expected_run_id, + } + + def test_extract_ol_info_from_source_fields(self): + """Test extraction from AssetEvent source fields (priority 2).""" + # Mock AssetEvent without TaskInstance + asset_event = MagicMock() + asset_event.source_task_instance = None + asset_event.source_dag_run = None + asset_event.source_dag_id = "source_dag" + asset_event.source_task_id = "source_task" + asset_event.extra = {} + + result = _extract_ol_info_from_asset_event(asset_event) + + # run_id cannot be constructed from source fields alone + assert result == { + "job_name": "source_dag.source_task", + "job_namespace": namespace(), + } + + def test_extract_ol_info_from_extra(self): + """Test extraction from asset_event.extra (priority 3).""" + # Mock AssetEvent + asset_event = MagicMock() + asset_event.source_task_instance = None + asset_event.source_dag_run = None + asset_event.source_dag_id = None + asset_event.source_task_id = None + asset_event.extra = { + "openlineage": { + "parentJobName": "extra_job", + "parentJobNamespace": "extra_namespace", + "parentRunId": "550e8400-e29b-41d4-a716-446655440000", + } + } + + result = _extract_ol_info_from_asset_event(asset_event) + + assert result == { + "job_name": "extra_job", + "job_namespace": "extra_namespace", + "run_id": "550e8400-e29b-41d4-a716-446655440000", + } + + def test_extract_ol_info_from_extra_no_run_id(self): + """Test extraction from asset_event.extra without run_id.""" + # Mock AssetEvent + asset_event = MagicMock() + asset_event.source_task_instance = None + asset_event.source_dag_run = None + asset_event.source_dag_id = None + asset_event.source_task_id = None + asset_event.extra = { + "openlineage": { + "parentJobName": "extra_job", + "parentJobNamespace": "extra_namespace", + } + } + + result = _extract_ol_info_from_asset_event(asset_event) + + assert result == { + "job_name": "extra_job", + "job_namespace": "extra_namespace", + } + + def test_extract_ol_info_from_extra_no_job_name(self): + """Test extraction from asset_event.extra without job_name.""" + # Mock AssetEvent + asset_event = MagicMock() + asset_event.source_task_instance = None + asset_event.source_dag_run = None + asset_event.source_dag_id = None + asset_event.source_task_id = None + asset_event.extra = { + "openlineage": { + "parentRunId": "550e8400-e29b-41d4-a716-446655440000", + "parentJobNamespace": "extra_namespace", + } + } + + result = _extract_ol_info_from_asset_event(asset_event) + + assert result is None + + def test_extract_ol_info_insufficient_info(self): + """Test extraction when no information is available.""" + # Mock AssetEvent with no information + asset_event = MagicMock() + asset_event.source_task_instance = None + asset_event.source_dag_run = None + asset_event.source_dag_id = None + asset_event.source_task_id = None + asset_event.extra = {} + + result = _extract_ol_info_from_asset_event(asset_event) + + assert result is None + + +class TestGetOlJobDependenciesFromAssetEvents: + """Tests for _get_ol_job_dependencies_from_asset_events function.""" + + def test_get_ol_job_dependencies_no_events(self): + """Test when no events are provided.""" + result = _get_ol_job_dependencies_from_asset_events([]) + + assert result == [] + + @patch("airflow.providers.openlineage.utils.utils._extract_ol_info_from_asset_event") + def test_get_ol_job_dependencies_with_events(self, mock_extract): + """Test extraction and deduplication of asset events.""" + # Mock asset events + asset_event1 = MagicMock() + asset_event1.id = 1 + asset_event1.source_run_id = "run1" + asset_event1.asset_id = 101 + asset_event1.dataset_id = 101 + asset_event1.uri = "s3://bucket/file1" + asset_event1.extra = {} + asset_event1.partition_key = None + + asset_event2 = MagicMock() + asset_event2.id = 2 + asset_event2.source_run_id = "run2" + asset_event2.asset_id = 102 + asset_event2.dataset_id = 102 + asset_event2.uri = "s3://bucket/file2" + asset_event2.extra = {} + asset_event2.partition_key = None + + # Mock extraction results + mock_extract.side_effect = [ + { + "job_name": "dag1.task1", + "job_namespace": "namespace", + "run_id": "550e8400-e29b-41d4-a716-446655440000", + }, + { + "job_name": "dag2.task2", + "job_namespace": "namespace", + "run_id": "550e8400-e29b-41d4-a716-446655440001", + }, + ] + + result = _get_ol_job_dependencies_from_asset_events([asset_event1, asset_event2]) + + assert result == [ + { + "job_name": "dag1.task1", + "job_namespace": "namespace", + "run_id": "550e8400-e29b-41d4-a716-446655440000", + "asset_events": [ + { + "dag_run_id": "run1", + "asset_event_id": 1, + "asset_event_extra": None, + "asset_id": 101, + "asset_uri": "s3://bucket/file1", + "partition_key": None, + } + ], + }, + { + "job_name": "dag2.task2", + "job_namespace": "namespace", + "run_id": "550e8400-e29b-41d4-a716-446655440001", + "asset_events": [ + { + "dag_run_id": "run2", + "asset_event_id": 2, + "asset_event_extra": None, + "asset_id": 102, + "asset_uri": "s3://bucket/file2", + "partition_key": None, + } + ], + }, + ] + + @patch("airflow.providers.openlineage.utils.utils._extract_ol_info_from_asset_event") + def test_get_ol_job_dependencies_deduplication(self, mock_extract): + """Test deduplication of duplicate asset events.""" + # Mock asset events + asset_event1 = MagicMock() + asset_event1.id = 1 + asset_event1.source_run_id = "run1" + asset_event1.asset_id = 101 + asset_event1.dataset_id = 101 + asset_event1.uri = "s3://bucket/file1" + asset_event1.extra = {} + asset_event1.partition_key = None + + asset_event2 = MagicMock() + asset_event2.id = 2 + asset_event2.source_run_id = "run2" + asset_event2.asset_id = 102 + asset_event2.dataset_id = 102 + asset_event2.uri = "s3://bucket/file2" + asset_event2.extra = {} + asset_event2.partition_key = None + + # Mock extraction results - same job/run (should be deduplicated) + same_info = { + "job_name": "dag1.task1", + "job_namespace": "namespace", + } + mock_extract.side_effect = [same_info, same_info] + + result = _get_ol_job_dependencies_from_asset_events([asset_event1, asset_event2]) + + # Should be deduplicated to one entry with both events aggregated + assert result == [ + { + "job_name": "dag1.task1", + "job_namespace": "namespace", + "asset_events": [ + { + "dag_run_id": "run1", + "asset_event_id": 1, + "asset_event_extra": None, + "asset_id": 101, + "asset_uri": "s3://bucket/file1", + "partition_key": None, + }, + { + "dag_run_id": "run2", + "asset_event_id": 2, + "asset_event_extra": None, + "asset_id": 102, + "asset_uri": "s3://bucket/file2", + "partition_key": None, + }, + ], + } + ] + + @patch("airflow.providers.openlineage.utils.utils._extract_ol_info_from_asset_event") + def test_get_ol_job_dependencies_insufficient_info(self, mock_extract): + """Test handling when extraction returns None.""" + # Mock asset event + asset_event = MagicMock() + asset_event.id = 1 + + # Mock extraction returning None + mock_extract.return_value = None + + result = _get_ol_job_dependencies_from_asset_events([asset_event]) + + assert result == [] + + +class TestGetDagJobDependencyFacet: + """Tests for get_dag_job_dependency_facet function. + + These tests mock only the DB-accessing function (_get_eagerly_loaded_dagrun_consumed_asset_events) + to test the full flow of facet generation including event processing and facet building. + """ + + @patch("airflow.providers.openlineage.utils.utils._get_eagerly_loaded_dagrun_consumed_asset_events") + def test_get_dag_job_dependency_facet_no_events(self, mock_get_events): + """Test when no asset events are found.""" + mock_get_events.return_value = [] + + result = get_dag_job_dependency_facet("test_dag", "test_run_id") + + assert result == {} + mock_get_events.assert_called_once_with("test_dag", "test_run_id") + + @patch("airflow.providers.openlineage.utils.utils._get_eagerly_loaded_dagrun_consumed_asset_events") + def test_get_dag_job_dependency_facet_exception_handling(self, mock_get_events): + """Test exception handling in get_dag_job_dependency_facet.""" + mock_get_events.side_effect = Exception("Database error") + + result = get_dag_job_dependency_facet("test_dag", "test_run_id") + + assert result == {} + + @patch("airflow.providers.openlineage.utils.utils._get_eagerly_loaded_dagrun_consumed_asset_events") + def test_get_dag_job_dependency_facet_insufficient_info_skipped(self, mock_get_events): + """Test that events with insufficient info are skipped.""" + # Create an event with no usable information + asset_event = MagicMock() + asset_event.source_task_instance = None + asset_event.source_dag_run = None + asset_event.source_dag_id = None + asset_event.source_task_id = None + asset_event.extra = {} + asset_event.id = 1 + asset_event.source_run_id = None + asset_event.asset_id = 101 + asset_event.dataset_id = 101 + asset_event.uri = "s3://bucket/file" + asset_event.partition_key = None + + mock_get_events.return_value = [asset_event] + + result = get_dag_job_dependency_facet("test_dag", "test_run_id") + + assert result == {} + + @patch("airflow.providers.openlineage.utils.utils._get_eagerly_loaded_dagrun_consumed_asset_events") + def test_get_dag_job_dependency_facet_with_events(self, mock_get_events): + """Test facet generation with asset events - tests full flow.""" + logical_date = datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + + # Create mock asset events with source TaskInstance (priority 1 source) + ti1 = MagicMock() + ti1.dag_id = "source_dag1" + ti1.task_id = "source_task1" + ti1.try_number = 1 + ti1.map_index = 0 + + source_dr1 = MagicMock() + source_dr1.logical_date = logical_date + source_dr1.run_after = None + + asset_event1 = MagicMock() + asset_event1.source_task_instance = ti1 + asset_event1.source_dag_run = source_dr1 + asset_event1.source_dag_id = None + asset_event1.source_task_id = None + asset_event1.extra = {} + asset_event1.id = 1 + asset_event1.source_run_id = "run1" + asset_event1.asset_id = 101 + asset_event1.dataset_id = 101 + asset_event1.uri = "s3://bucket/file1" + asset_event1.partition_key = None + + # Second event with source fields (priority 2 source, no run_id) + asset_event2 = MagicMock() + asset_event2.source_task_instance = None + asset_event2.source_dag_run = None + asset_event2.source_dag_id = "source_dag2" + asset_event2.source_task_id = "source_task2" + asset_event2.extra = {} + asset_event2.id = 2 + asset_event2.source_run_id = "run2" + asset_event2.asset_id = 102 + asset_event2.dataset_id = 102 + asset_event2.uri = "s3://bucket/file2" + asset_event2.partition_key = None + + mock_get_events.return_value = [asset_event1, asset_event2] + + result = get_dag_job_dependency_facet("test_dag", "test_run_id") + + # Verify result structure + assert len(result) == 1 + facet = result["jobDependencies"] + assert len(facet.upstream) == 2 + assert len(facet.downstream) == 0 + + # Verify first dependency (from TaskInstance source, has run_id) + dep1 = facet.upstream[0] + assert dep1.job.namespace == namespace() + assert dep1.job.name == "source_dag1.source_task1" + expected_run_id = build_task_instance_ol_run_id( + dag_id="source_dag1", + task_id="source_task1", + try_number=1, + logical_date=logical_date, + map_index=0, + ) + assert dep1.run.runId == expected_run_id + assert dep1.dependency_type == "IMPLICIT_ASSET_DEPENDENCY" + assert dep1.airflow["asset_events"] == [ + { + "dag_run_id": "run1", + "asset_event_id": 1, + "asset_event_extra": None, + "asset_id": 101, + "asset_uri": "s3://bucket/file1", + "partition_key": None, + } + ] + + # Verify second dependency (from source fields, no run_id) + dep2 = facet.upstream[1] + assert dep2.job.namespace == namespace() + assert dep2.job.name == "source_dag2.source_task2" + assert dep2.run is None + assert dep2.dependency_type == "IMPLICIT_ASSET_DEPENDENCY" + assert dep2.airflow["asset_events"] == [ + { + "dag_run_id": "run2", + "asset_event_id": 2, + "asset_event_extra": None, + "asset_id": 102, + "asset_uri": "s3://bucket/file2", + "partition_key": None, + } + ] + + @patch("airflow.providers.openlineage.utils.utils._get_eagerly_loaded_dagrun_consumed_asset_events") + def test_get_dag_job_dependency_facet_deduplication(self, mock_get_events): + """Test that duplicate asset events from same job/run are deduplicated.""" + logical_date = datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + + # Create two events from the same source TI (should be deduplicated) + ti = MagicMock() + ti.dag_id = "source_dag" + ti.task_id = "source_task" + ti.try_number = 1 + ti.map_index = 0 + + source_dr = MagicMock() + source_dr.logical_date = logical_date + source_dr.run_after = None + + asset_event1 = MagicMock() + asset_event1.source_task_instance = ti + asset_event1.source_dag_run = source_dr + asset_event1.source_dag_id = None + asset_event1.source_task_id = None + asset_event1.extra = {} + asset_event1.id = 1 + asset_event1.source_run_id = "run1" + asset_event1.asset_id = 101 + asset_event1.dataset_id = 101 + asset_event1.uri = "s3://bucket/file1" + asset_event1.partition_key = None + + asset_event2 = MagicMock() + asset_event2.source_task_instance = ti # Same TI + asset_event2.source_dag_run = source_dr # Same DR + asset_event2.source_dag_id = None + asset_event2.source_task_id = None + asset_event2.extra = {} + asset_event2.id = 2 + asset_event2.source_run_id = "run1" + asset_event2.asset_id = 102 # Different asset + asset_event2.dataset_id = 102 # Different asset + asset_event2.uri = "s3://bucket/file2" + asset_event2.partition_key = None + + mock_get_events.return_value = [asset_event1, asset_event2] + + result = get_dag_job_dependency_facet("test_dag", "test_run_id") + + assert len(result) == 1 + facet = result["jobDependencies"] + assert len(facet.upstream) == 1 + assert len(facet.downstream) == 0 + + # Verify the single deduplicated dependency + dep = facet.upstream[0] + assert dep.job.namespace == namespace() + assert dep.job.name == "source_dag.source_task" + expected_run_id = build_task_instance_ol_run_id( + dag_id="source_dag", + task_id="source_task", + try_number=1, + logical_date=logical_date, + map_index=0, + ) + assert dep.run.runId == expected_run_id + assert dep.dependency_type == "IMPLICIT_ASSET_DEPENDENCY" + + # Both asset events should be aggregated into single dependency + assert dep.airflow["asset_events"] == [ + { + "dag_run_id": "run1", + "asset_event_id": 1, + "asset_event_extra": None, + "asset_id": 101, + "asset_uri": "s3://bucket/file1", + "partition_key": None, + }, + { + "dag_run_id": "run1", + "asset_event_id": 2, + "asset_event_extra": None, + "asset_id": 102, + "asset_uri": "s3://bucket/file2", + "partition_key": None, + }, + ]