From 6f3c044cd5e7199ea87c4694b72cf91af88885c7 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 2 Oct 2025 12:12:08 -0700 Subject: [PATCH 1/8] Partially implement multi-asset dependencies Only works if the partitions align, and we don't support rollup. We reuse the boolean evaluator logic. Basically, we can figure out the statuses of each asset dep by evaluating the key log records. This gives us essentially our statuses dictionary. But should we allow this or ban it? --- .../src/airflow/jobs/scheduler_job_runner.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index eac612d8c1e18..1ce147369093c 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -42,6 +42,7 @@ from airflow._shared.observability.metrics.stats import Stats from airflow._shared.timezones import timezone from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun as DRDataModel, TIRunContext +from airflow.assets.evaluation import AssetEvaluator from airflow.callbacks.callback_requests import ( DagCallbackRequest, DagRunContext, @@ -65,6 +66,7 @@ AssetWatcherModel, DagScheduleAssetAliasReference, DagScheduleAssetReference, + PartitionedAssetKeyLog, TaskInletAssetReference, TaskOutletAssetReference, ) @@ -82,6 +84,7 @@ from airflow.models.team import Team from airflow.models.trigger import TRIGGER_FAIL_REPR, Trigger, TriggerFailureReason from airflow.observability.trace import DebugTrace, Trace, add_debug_span +from airflow.sdk.definitions.asset import AssetUniqueKey, BaseAsset from airflow.serialization.definitions.notset import NOTSET from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.timetables.simple import AssetTriggeredTimetable @@ -1686,6 +1689,18 @@ def _do_scheduling(self, session: 
Session) -> int: def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[str]: partition_dag_ids: set[str] = set() + evaluator = AssetEvaluator(session) + + def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool]) -> bool | None: + try: + return evaluator.run(cond, statuses) + except AttributeError: + # if dag was serialized before 2.9 and we *just* upgraded, + # we may be dealing with old version. In that case, + # just wait for the dag to be reserialized. + self.log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id) + return None + apdrs: Iterable[AssetPartitionDagRun] = session.scalars( select(AssetPartitionDagRun).where(AssetPartitionDagRun.created_dag_run_id.is_(None)) ) @@ -1698,6 +1713,23 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id) continue + key_logs = session.scalars( + select(PartitionedAssetKeyLog).where( + PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id + ) + ) + assets = session.scalars( + select(AssetModel).where(AssetModel.id.in_(x.asset_id for x in key_logs)) + ) + statuses = {AssetUniqueKey.from_asset(a): True for a in assets} + # todo: AIP-76 so, this basically works when we only require one partition from each asset to be there + # but, we ultimately need rollup ability + # that is, we need to ensure that whenever it is many -> one partitions, then we need to ensure + # that all the required keys are there + # one way to do this would be just to figure out what the count should be + if not dag_ready(dag.dag_id, cond=dag.timetable.asset_condition, statuses=statuses): + continue + run_after = timezone.utcnow() dag_run = dag.create_dagrun( run_id=DagRun.generate_run_id( From 4c4df217664b41962d63d252d9e099d81c5e9499 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 15 Jan 2026 17:57:57 +0800 Subject: [PATCH 2/8] fix: Use Serialized Asset in core 
airflow --- .../src/airflow/jobs/scheduler_job_runner.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 1ce147369093c..3d877ff8bde01 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -84,7 +84,7 @@ from airflow.models.team import Team from airflow.models.trigger import TRIGGER_FAIL_REPR, Trigger, TriggerFailureReason from airflow.observability.trace import DebugTrace, Trace, add_debug_span -from airflow.sdk.definitions.asset import AssetUniqueKey, BaseAsset +from airflow.serialization.definitions.assets import SerializedAssetUniqueKey from airflow.serialization.definitions.notset import NOTSET from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.timetables.simple import AssetTriggeredTimetable @@ -115,6 +115,7 @@ from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName from airflow.models.taskinstance import TaskInstanceKey + from airflow.serialization.definitions.assets import SerializedAssetBase from airflow.serialization.definitions.dag import SerializedDAG from airflow.utils.sqlalchemy import CommitProhibitorGuard @@ -1691,14 +1692,16 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st partition_dag_ids: set[str] = set() evaluator = AssetEvaluator(session) - def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool]) -> bool | None: + def dag_ready( + dag_id: str, cond: SerializedAssetBase, statuses: dict[SerializedAssetUniqueKey, bool] + ) -> bool | None: try: return evaluator.run(cond, statuses) except AttributeError: - # if dag was serialized before 2.9 and we *just* upgraded, - # we may be dealing with old version. In that case, - # just wait for the dag to be reserialized. 
- self.log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id) + # if Dag was serialized before 2.9 and we *just* upgraded, + # we may be dealing with old version. In that case, + # just wait for the Dag to be reserialized. + self.log.warning("Dag '%s' has old serialization; skipping Dag run creation.", dag_id) return None apdrs: Iterable[AssetPartitionDagRun] = session.scalars( @@ -1707,9 +1710,7 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool] for apdr in apdrs: if TYPE_CHECKING: assert apdr.target_dag_id - partition_dag_ids.add(apdr.target_dag_id) - dag = _get_current_dag(dag_id=apdr.target_dag_id, session=session) - if not dag: + if not (dag := _get_current_dag(dag_id=apdr.target_dag_id, session=session)): self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id) continue @@ -1718,10 +1719,12 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool] PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id ) ) - assets = session.scalars( + asset_models = session.scalars( select(AssetModel).where(AssetModel.id.in_(x.asset_id for x in key_logs)) ) - statuses = {AssetUniqueKey.from_asset(a): True for a in assets} + statuses: dict[SerializedAssetUniqueKey, bool] = { + SerializedAssetUniqueKey.from_asset(a): True for a in asset_models + } # todo: AIP-76 so, this basically works when we only require one partition from each asset to be there # but, we ultimately need rollup ability # that is, we need to ensure that whenever it is many -> one partitions, then we need to ensure @@ -1730,6 +1733,7 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool] if not dag_ready(dag.dag_id, cond=dag.timetable.asset_condition, statuses=statuses): continue + partition_dag_ids.add(apdr.target_dag_id) run_after = timezone.utcnow() dag_run = dag.create_dagrun( run_id=DagRun.generate_run_id( From 4752aa5624b4f4a07db7dd7fee11e85650a98c10 Mon Sep 17 
00:00:00 2001 From: Wei Lee Date: Tue, 20 Jan 2026 21:29:02 +0800 Subject: [PATCH 3/8] refactor: remove unnecessary backward compatible code for asset evaluator since partition is a new feature, we won't have a partition aware Dag that's serialized before 3.2 --- .../src/airflow/jobs/scheduler_job_runner.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 3d877ff8bde01..a94698817aae6 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -115,7 +115,6 @@ from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName from airflow.models.taskinstance import TaskInstanceKey - from airflow.serialization.definitions.assets import SerializedAssetBase from airflow.serialization.definitions.dag import SerializedDAG from airflow.utils.sqlalchemy import CommitProhibitorGuard @@ -1690,20 +1689,8 @@ def _do_scheduling(self, session: Session) -> int: def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[str]: partition_dag_ids: set[str] = set() - evaluator = AssetEvaluator(session) - - def dag_ready( - dag_id: str, cond: SerializedAssetBase, statuses: dict[SerializedAssetUniqueKey, bool] - ) -> bool | None: - try: - return evaluator.run(cond, statuses) - except AttributeError: - # if Dag was serialized before 2.9 and we *just* upgraded, - # we may be dealing with old version. In that case, - # just wait for the Dag to be reserialized. 
- self.log.warning("Dag '%s' has old serialization; skipping Dag run creation.", dag_id) - return None + evaluator = AssetEvaluator(session) apdrs: Iterable[AssetPartitionDagRun] = session.scalars( select(AssetPartitionDagRun).where(AssetPartitionDagRun.created_dag_run_id.is_(None)) ) @@ -1730,7 +1717,7 @@ def dag_ready( # that is, we need to ensure that whenever it is many -> one partitions, then we need to ensure # that all the required keys are there # one way to do this would be just to figure out what the count should be - if not dag_ready(dag.dag_id, cond=dag.timetable.asset_condition, statuses=statuses): + if not evaluator.run(dag.dag_id, cond=dag.timetable.asset_condition, statuses=statuses): continue partition_dag_ids.add(apdr.target_dag_id) From 6cfe7dc518a725045a6fd4e5cac7635f9973cc3a Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 20 Jan 2026 21:43:05 +0800 Subject: [PATCH 4/8] perf: make an existing query a subquery --- airflow-core/src/airflow/jobs/scheduler_job_runner.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index a94698817aae6..4732778ac489b 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -1701,13 +1701,11 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id) continue - key_logs = session.scalars( - select(PartitionedAssetKeyLog).where( - PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id - ) + pakl_subquery = select(PartitionedAssetKeyLog.asset_id).where( + PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id ) asset_models = session.scalars( - select(AssetModel).where(AssetModel.id.in_(x.asset_id for x in key_logs)) + select(AssetModel).where(AssetModel.id.in_(pakl_subquery)), ) statuses: 
dict[SerializedAssetUniqueKey, bool] = { SerializedAssetUniqueKey.from_asset(a): True for a in asset_models @@ -1717,7 +1715,7 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st # that is, we need to ensure that whenever it is many -> one partitions, then we need to ensure # that all the required keys are there # one way to do this would be just to figure out what the count should be - if not evaluator.run(dag.dag_id, cond=dag.timetable.asset_condition, statuses=statuses): + if not evaluator.run(dag.timetable.asset_condition, statuses=statuses): continue partition_dag_ids.add(apdr.target_dag_id) From f5b9d9d8d42b4489809f5f053d54ea0e2136ea4a Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 20 Jan 2026 22:23:19 +0800 Subject: [PATCH 5/8] test: add test case test_when_dag_run_has_partition_and_downstreams_listening_then_tables_populated_multiple --- .../tests/unit/jobs/test_scheduler_job.py | 115 ++++++++++++++++++ 1 file changed, 115 insertions(+) diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 28ea1d418518d..6999b5a3b4179 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -98,6 +98,7 @@ from tests_common.test_utils.config import conf_vars, env_vars from tests_common.test_utils.dag import create_scheduler_dag, sync_dag_to_db, sync_dags_to_db from tests_common.test_utils.db import ( + clear_db_apdr, clear_db_assets, clear_db_backfills, clear_db_callbacks, @@ -106,6 +107,7 @@ clear_db_deadline, clear_db_import_errors, clear_db_jobs, + clear_db_pakl, clear_db_pools, clear_db_runs, clear_db_teams, @@ -8318,3 +8320,116 @@ def _find_registered_custom_partition_mapper(s): assert apdr.created_dag_run_id is not None assert len(partition_dags) == 1 assert partition_dags == {"asset-event-consumer"} + + +@pytest.mark.need_serialized_dag +def 
test_consumer_dag_listen_to_two_partitioned_asset(dag_maker: DagMaker, session: Session): + from airlfow.sdk import IdentityMapper + + clear_db_apdr() + clear_db_pakl() + + asset_1 = Asset(name="asset-1") + asset_2 = Asset(name="asset-2") + + # Consumer Dag "asset-event-consumer" + with dag_maker( + dag_id="asset-event-consumer", + schedule=PartitionedAssetTimetable( + assets=asset_1 & asset_2, + partition_mapper=IdentityMapper(), + ), + session=session, + ): + EmptyOperator(task_id="hi") + session.commit() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False)) + ) + + def produce_and_register_asset_event( + *, + dag_id: str, + asset: Asset, + partition_key: str, + ) -> AssetPartitionDagRun: + with dag_maker(dag_id=dag_id, schedule=None, session=session) as dag: + EmptyOperator(task_id="hi", outlets=[asset]) + + dr = dag_maker.create_dagrun(partition_key=partition_key, session=session) + [ti] = dr.get_task_instances(session=session) + session.commit() + + serialized_outlets = dag.get_task("hi").outlets + TaskInstance.register_asset_changes_in_db( + ti=ti, + task_outlets=[o.asprofile() for o in serialized_outlets], + outlet_events=[], + session=session, + ) + session.commit() + + event = session.scalar( + select(AssetEvent).where( + AssetEvent.source_dag_id == dag.dag_id, + AssetEvent.source_run_id == dr.run_id, + ) + ) + assert event is not None + assert event.partition_key == partition_key + + apdr = session.scalar( + select(AssetPartitionDagRun) + .join( + PartitionedAssetKeyLog, + PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id, + ) + .where(PartitionedAssetKeyLog.asset_event_id == event.id) + ) + assert apdr is not None + assert apdr.created_dag_run_id is None + assert apdr.partition_key == partition_key + + return apdr + + # Check whether we are ready to create Dag run for "asset-event-consumer" + apdr = produce_and_register_asset_event( + 
dag_id="asset-event-producer-1", + asset=asset_1, + partition_key="key-1", + ) + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + # Since asset event for Asset(name="asset-2") with key "key-1" has not yet been created, + # no Dag run will be created + assert apdr.created_dag_run_id is None + assert len(partition_dags) == 0 + assert partition_dags == set() + + apdr = produce_and_register_asset_event( + dag_id="asset-event-producer-2", + asset=asset_2, + partition_key="key-2", + ) + # Since asset event for Asset(name="asset-2") with key "key-1" has not yet been created, + # (the one created was Asset(name="asset-2") with key "key-2") + # no Dag run will be created + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + assert apdr.created_dag_run_id is None + assert len(partition_dags) == 0 + assert partition_dags == set() + + apdr = produce_and_register_asset_event( + dag_id="asset-event-producer-3", + asset=asset_2, + partition_key="key-1", + ) + # Now the asset event for Asset(name="asset-2") with key "key-1" is created, + # the Dag run should be created + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + assert apdr.created_dag_run_id is not None + assert len(partition_dags) == 1 + assert partition_dags == {"asset-event-consumer"} From aa609709ff755654d05aa4f52131ec641cf210c1 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 21 Jan 2026 17:34:21 +0800 Subject: [PATCH 6/8] test: add test case test_consumer_dag_listen_to_two_partitioned_asset --- .../src/airflow/jobs/scheduler_job_runner.py | 18 +++++++++++------- .../tests/unit/jobs/test_scheduler_job.py | 7 ++++--- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 4732778ac489b..da1fd1a96f959 100644 --- 
a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -1691,21 +1691,25 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st partition_dag_ids: set[str] = set() evaluator = AssetEvaluator(session) - apdrs: Iterable[AssetPartitionDagRun] = session.scalars( + for apdr in session.scalars( select(AssetPartitionDagRun).where(AssetPartitionDagRun.created_dag_run_id.is_(None)) - ) - for apdr in apdrs: + ): if TYPE_CHECKING: assert apdr.target_dag_id + if not (dag := _get_current_dag(dag_id=apdr.target_dag_id, session=session)): self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id) continue - pakl_subquery = select(PartitionedAssetKeyLog.asset_id).where( - PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id - ) asset_models = session.scalars( - select(AssetModel).where(AssetModel.id.in_(pakl_subquery)), + select(AssetModel).where( + exists( + select(1).where( + PartitionedAssetKeyLog.asset_id == AssetModel.id, + PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id, + ) + ) + ) ) statuses: dict[SerializedAssetUniqueKey, bool] = { SerializedAssetUniqueKey.from_asset(a): True for a in asset_models diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 6999b5a3b4179..6a816668aac60 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -81,7 +81,7 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.triggers.file import FileDeleteTrigger -from airflow.sdk import DAG, Asset, AssetAlias, AssetWatcher, task +from airflow.sdk import DAG, Asset, AssetAlias, AssetWatcher, IdentityMapper, task from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback from 
airflow.sdk.definitions.timetables.assets import PartitionedAssetTimetable from airflow.serialization.definitions.dag import SerializedDAG @@ -8324,8 +8324,6 @@ def _find_registered_custom_partition_mapper(s): @pytest.mark.need_serialized_dag def test_consumer_dag_listen_to_two_partitioned_asset(dag_maker: DagMaker, session: Session): - from airlfow.sdk import IdentityMapper - clear_db_apdr() clear_db_pakl() @@ -8344,6 +8342,7 @@ def test_consumer_dag_listen_to_two_partitioned_asset(dag_maker: DagMaker, sessi EmptyOperator(task_id="hi") session.commit() + # Check whehter we are ready to create Dag run for "asset-event-consumer" runner = SchedulerJobRunner( job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False)) ) @@ -8417,6 +8416,8 @@ def produce_and_register_asset_event( # no Dag run will be created partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) session.refresh(apdr) + # Since asset event for Asset(name="asset-2") with key "key-1" has not yet been created, + # no Dag run will be created assert apdr.created_dag_run_id is None assert len(partition_dags) == 0 assert partition_dags == set() From e1af1fa85d47ce4e00300444d8914e5aa58806b6 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 21 Jan 2026 17:46:47 +0800 Subject: [PATCH 7/8] refactor: simplify test_consumer_dag_listen_to_two_partitioned_asset --- airflow-core/tests/unit/jobs/test_scheduler_job.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 6a816668aac60..85ba450769ae2 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -8342,11 +8342,6 @@ def test_consumer_dag_listen_to_two_partitioned_asset(dag_maker: DagMaker, sessi EmptyOperator(task_id="hi") session.commit() - # Check whehter we are ready to create Dag run for "asset-event-consumer" 
- runner = SchedulerJobRunner( - job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False)) - ) - def produce_and_register_asset_event( *, dag_id: str, @@ -8392,6 +8387,10 @@ def produce_and_register_asset_event( return apdr + # Check whether we are ready to create Dag run for "asset-event-consumer" + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False)) + ) # Check whether we are ready to create Dag run for "asset-event-consumer" apdr = produce_and_register_asset_event( dag_id="asset-event-producer-1", From be72914b447f2f2df96e0db350fb7bc1c82152f1 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 27 Jan 2026 15:36:32 +0800 Subject: [PATCH 8/8] test: improve test cases --- .../tests/unit/jobs/test_scheduler_job.py | 333 ++++++++++++------ 1 file changed, 221 insertions(+), 112 deletions(-) diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 85ba450769ae2..3d9b01710e0c0 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -22,7 +22,8 @@ import logging import os from collections import Counter, deque -from collections.abc import Generator, Iterable +from collections.abc import Callable, Generator, Iterable, Iterator +from contextlib import ExitStack from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING @@ -78,6 +79,7 @@ from airflow.models.team import Team from airflow.models.trigger import Trigger from airflow.observability.trace import Trace +from airflow.partition_mapper.base import PartitionMapper as CorePartitionMapper from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.triggers.file import FileDeleteTrigger @@ -8230,35 +8232,115 @@ def test_mark_backfills_completed(dag_maker, session): 
assert b.completed_at.timestamp() > 0 -@pytest.mark.need_serialized_dag -def test_partitioned_dag_run_with_customized_mapper(dag_maker: DagMaker, session: Session): - from airflow.partition_mapper.base import PartitionMapper as CorePartitionMapper +class Key1Mapper(CorePartitionMapper): + """Partition Mapper that returns only key-1 as downstream key""" - class Key1Mapper(CorePartitionMapper): - def to_downstream(self, key: str) -> str: - return "key-1" + def to_downstream(self, key: str) -> str: + return "key-1" - def to_upstream(self, key: str) -> Iterable[str]: - yield key + def to_upstream(self, key: str) -> Iterable[str]: + yield key - def _find_registered_custom_partition_mapper(s): - if s == qualname(Key1Mapper): - return Key1Mapper - raise ValueError(f"unexpected class {s!r}") - asset_1 = Asset(name="asset-1") +def _find_registered_custom_partition_mapper(import_string: str) -> type[CorePartitionMapper]: + if import_string == qualname(Key1Mapper): + return Key1Mapper + raise ValueError(f"unexpected class {import_string!r}") - # Consumer Dag "asset-event-consumer" - with ( - mock.patch( + +@pytest.fixture +def custom_partition_mapper_patch() -> Callable[[], ExitStack]: + def _patch() -> ExitStack: + stack = ExitStack() + for mock_target in [ "airflow.serialization.encoders.find_registered_custom_partition_mapper", - _find_registered_custom_partition_mapper, - ), - mock.patch( "airflow.serialization.decoders.find_registered_custom_partition_mapper", - _find_registered_custom_partition_mapper, - ), - ): + ]: + stack.enter_context( + mock.patch( + mock_target, + _find_registered_custom_partition_mapper, + ) + ) + return stack + + return _patch + + +@pytest.fixture +def clear_asset_partition_rows() -> Iterator: + clear_db_apdr() + clear_db_pakl() + + yield + + clear_db_apdr() + clear_db_pakl() + + +def _produce_and_register_asset_event( + *, + dag_id: str, + asset: Asset, + partition_key: str, + session: Session, + dag_maker: DagMaker, + expected_partition_key: 
str | None = None, +) -> AssetPartitionDagRun: + if expected_partition_key is None: + expected_partition_key = partition_key + + with dag_maker(dag_id=dag_id, schedule=None, session=session) as dag: + EmptyOperator(task_id="hi", outlets=[asset]) + + dr = dag_maker.create_dagrun(partition_key=partition_key, session=session) + [ti] = dr.get_task_instances(session=session) + session.commit() + + serialized_outlets = dag.get_task("hi").outlets + TaskInstance.register_asset_changes_in_db( + ti=ti, + task_outlets=[o.asprofile() for o in serialized_outlets], + outlet_events=[], + session=session, + ) + session.commit() + + event = session.scalar( + select(AssetEvent).where( + AssetEvent.source_dag_id == dag.dag_id, + AssetEvent.source_run_id == dr.run_id, + ) + ) + assert event is not None + assert event.partition_key == partition_key + + apdr = session.scalar( + select(AssetPartitionDagRun) + .join( + PartitionedAssetKeyLog, + PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id, + ) + .where(PartitionedAssetKeyLog.asset_event_id == event.id) + ) + assert apdr is not None + assert apdr.created_dag_run_id is None + assert apdr.partition_key == expected_partition_key + + return apdr + + +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def test_partitioned_dag_run_with_customized_mapper( + dag_maker: DagMaker, + session: Session, + custom_partition_mapper_patch: Callable[[], ExitStack], +): + asset_1 = Asset(name="asset-1") + + # Consumer Dag "asset-event-consumer" + with custom_partition_mapper_patch(): with dag_maker( dag_id="asset-event-consumer", schedule=PartitionedAssetTimetable( @@ -8272,61 +8354,64 @@ def _find_registered_custom_partition_mapper(s): EmptyOperator(task_id="hi") session.commit() - runner = SchedulerJobRunner( - job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False)) - ) + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type, 
executor=MockExecutor(do_update=False)) + ) - with dag_maker(dag_id="asset-event-producer", schedule=None, session=session) as dag: - EmptyOperator(task_id="hi", outlets=[asset_1]) + with dag_maker(dag_id="asset-event-producer", schedule=None, session=session) as dag: + EmptyOperator(task_id="hi", outlets=[asset_1]) - dr = dag_maker.create_dagrun(partition_key="this-is-not-key-1-before-mapped", session=session) - [ti] = dr.get_task_instances(session=session) - session.commit() + dr = dag_maker.create_dagrun(partition_key="this-is-not-key-1-before-mapped", session=session) + [ti] = dr.get_task_instances(session=session) + session.commit() - serialized_outlets = dag.get_task("hi").outlets + serialized_outlets = dag.get_task("hi").outlets + with custom_partition_mapper_patch(): TaskInstance.register_asset_changes_in_db( ti=ti, task_outlets=[o.asprofile() for o in serialized_outlets], outlet_events=[], session=session, ) - session.commit() + session.commit() - event = session.scalar( - select(AssetEvent).where( - AssetEvent.source_dag_id == dag.dag_id, - AssetEvent.source_run_id == dr.run_id, - ) + event = session.scalar( + select(AssetEvent).where( + AssetEvent.source_dag_id == dag.dag_id, + AssetEvent.source_run_id == dr.run_id, ) - assert event is not None - assert event.partition_key == "this-is-not-key-1-before-mapped" + ) + assert event is not None + assert event.partition_key == "this-is-not-key-1-before-mapped" - apdr = session.scalar( - select(AssetPartitionDagRun) - .join( - PartitionedAssetKeyLog, - PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id, - ) - .where(PartitionedAssetKeyLog.asset_event_id == event.id) + apdr = session.scalar( + select(AssetPartitionDagRun) + .join( + PartitionedAssetKeyLog, + PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id, ) - assert apdr is not None - assert apdr.created_dag_run_id is None - assert apdr.partition_key == "key-1" + .where(PartitionedAssetKeyLog.asset_event_id 
== event.id) + ) + assert apdr is not None + assert apdr.created_dag_run_id is None + assert apdr.partition_key == "key-1" + with custom_partition_mapper_patch(): partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) - session.refresh(apdr) - # Since asset event for Asset(name="asset-2") with key "key-1" has not yet been created, - # no Dag run will be created - assert apdr.created_dag_run_id is not None - assert len(partition_dags) == 1 - assert partition_dags == {"asset-event-consumer"} + session.refresh(apdr) + # The asset event for Asset(name="asset-1") was mapped to key "key-1" by Key1Mapper and it is + # the only asset the consumer Dag listens to, so the Dag run should be created + assert apdr.created_dag_run_id is not None + assert len(partition_dags) == 1 + assert partition_dags == {"asset-event-consumer"} @pytest.mark.need_serialized_dag -def test_consumer_dag_listen_to_two_partitioned_asset(dag_maker: DagMaker, session: Session): - clear_db_apdr() - clear_db_pakl() - +@pytest.mark.usefixtures("clear_asset_partition_rows") +def test_consumer_dag_listen_to_two_partitioned_asset( + dag_maker: DagMaker, + session: Session, +): asset_1 = Asset(name="asset-1") asset_2 = Asset(name="asset-2") @@ -8342,60 +8427,17 @@ def test_consumer_dag_listen_to_two_partitioned_asset(dag_maker: DagMaker, sessi EmptyOperator(task_id="hi") session.commit() - def produce_and_register_asset_event( - *, - dag_id: str, - asset: Asset, - partition_key: str, - ) -> AssetPartitionDagRun: - with dag_maker(dag_id=dag_id, schedule=None, session=session) as dag: - EmptyOperator(task_id="hi", outlets=[asset]) - - dr = dag_maker.create_dagrun(partition_key=partition_key, session=session) - [ti] = dr.get_task_instances(session=session) - session.commit() - - serialized_outlets = dag.get_task("hi").outlets - TaskInstance.register_asset_changes_in_db( - ti=ti, - task_outlets=[o.asprofile() for o in serialized_outlets], - outlet_events=[], - session=session, - ) - session.commit() - - event = session.scalar( - 
select(AssetEvent).where( - AssetEvent.source_dag_id == dag.dag_id, - AssetEvent.source_run_id == dr.run_id, - ) - ) - assert event is not None - assert event.partition_key == partition_key - - apdr = session.scalar( - select(AssetPartitionDagRun) - .join( - PartitionedAssetKeyLog, - PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id, - ) - .where(PartitionedAssetKeyLog.asset_event_id == event.id) - ) - assert apdr is not None - assert apdr.created_dag_run_id is None - assert apdr.partition_key == partition_key - - return apdr - # Check whether we are ready to create Dag run for "asset-event-consumer" runner = SchedulerJobRunner( job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False)) ) - # Check whether we are ready to create Dag run for "asset-event-consumer" - apdr = produce_and_register_asset_event( + + apdr = _produce_and_register_asset_event( dag_id="asset-event-producer-1", asset=asset_1, partition_key="key-1", + session=session, + dag_maker=dag_maker, ) partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) session.refresh(apdr) @@ -8405,26 +8447,28 @@ def produce_and_register_asset_event( assert len(partition_dags) == 0 assert partition_dags == set() - apdr = produce_and_register_asset_event( + apdr = _produce_and_register_asset_event( dag_id="asset-event-producer-2", asset=asset_2, partition_key="key-2", + session=session, + dag_maker=dag_maker, ) - # Since asset event for Asset(name="asset-2") with key "key-1" has not yet been created, - # (the one created was Asset(name="asset-2") with key "key-2") - # no Dag run will be created partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) session.refresh(apdr) # Since asset event for Asset(name="asset-2") with key "key-1" has not yet been created, + # (the one created was Asset(name="asset-2") with key "key-2") # no Dag run will be created assert apdr.created_dag_run_id is None assert 
len(partition_dags) == 0 assert partition_dags == set() - apdr = produce_and_register_asset_event( + apdr = _produce_and_register_asset_event( dag_id="asset-event-producer-3", asset=asset_2, partition_key="key-1", + session=session, + dag_maker=dag_maker, ) # Now the asset event for Asset(name="asset-2") with key "key-1" is created, # the Dag run should be created @@ -8433,3 +8477,68 @@ def produce_and_register_asset_event( assert apdr.created_dag_run_id is not None assert len(partition_dags) == 1 assert partition_dags == {"asset-event-consumer"} + + +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def test_consumer_dag_listen_to_two_partitioned_asset_with_key_1_mapper( + dag_maker: DagMaker, + session: Session, + custom_partition_mapper_patch: Callable[[], ExitStack], +): + asset_1 = Asset(name="asset-1") + asset_2 = Asset(name="asset-2") + + # Consumer Dag "asset-event-consumer" + with custom_partition_mapper_patch(): + with dag_maker( + dag_id="asset-event-consumer", + schedule=PartitionedAssetTimetable( + assets=asset_1 & asset_2, + # TODO: (GH-57694) this partition mapper interface will be moved into asset as per-asset mapper + # and the type mismatch will be handled there + partition_mapper=Key1Mapper(), # type: ignore[arg-type] + ), + session=session, + ): + EmptyOperator(task_id="hi") + session.commit() + + # Check whether we are ready to create Dag run for "asset-event-consumer" + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False)) + ) + + with custom_partition_mapper_patch(): + apdr = _produce_and_register_asset_event( + dag_id="asset-event-producer-1", + asset=asset_1, + partition_key="key-2", + session=session, + dag_maker=dag_maker, + expected_partition_key="key-1", + ) + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + # Since asset event for Asset(name="asset-2") with key "key-1" has 
not yet been created, + # no Dag run will be created + assert apdr.created_dag_run_id is None + assert len(partition_dags) == 0 + assert partition_dags == set() + + with custom_partition_mapper_patch(): + apdr = _produce_and_register_asset_event( + dag_id="asset-event-producer-2", + asset=asset_2, + partition_key="key-3", + session=session, + dag_maker=dag_maker, + expected_partition_key="key-1", + ) + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + # Even though the original key passed was key-3, the consumer Dag uses Key1Mapper to transform anything + # into key-1. Thus, the criteria is met and a Dag run should be created + assert apdr.created_dag_run_id is not None + assert len(partition_dags) == 1 + assert partition_dags == {"asset-event-consumer"}