diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index ee59d3f5156a3..8c31d68efe40f 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -1843,6 +1843,13 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st creating_job_id=self.job.id, session=session, ) + asset_events = session.scalars( + select(AssetEvent).where( + PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id, + PartitionedAssetKeyLog.asset_event_id == AssetEvent.id, + ) + ) + dag_run.consumed_asset_events.extend(asset_events) session.flush() apdr.created_dag_run_id = dag_run.id session.flush() diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 8f8a686ddf5ed..f59768cb4afcd 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -8786,46 +8786,15 @@ def test_partitioned_dag_run_with_customized_mapper( runner = SchedulerJobRunner( job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False)) ) - - with dag_maker(dag_id="asset-event-producer", schedule=None, session=session) as dag: - EmptyOperator(task_id="hi", outlets=[asset_1]) - - dr = dag_maker.create_dagrun(partition_key="this-is-not-key-1-before-mapped", session=session) - [ti] = dr.get_task_instances(session=session) - session.commit() - - serialized_outlets = dag.get_task("hi").outlets with custom_partition_mapper_patch(): - TaskInstance.register_asset_changes_in_db( - ti=ti, - task_outlets=[o.asprofile() for o in serialized_outlets], - outlet_events=[], + apdr = _produce_and_register_asset_event( + dag_id="asset-event-producer", + asset=asset_1, + partition_key="this-is-not-key-1-before-mapped", session=session, + dag_maker=dag_maker, + expected_partition_key="key-1", ) - session.commit() - - event = session.scalar( - select(AssetEvent).where( - AssetEvent.source_dag_id == dag.dag_id, - AssetEvent.source_run_id == dr.run_id, - ) - ) - assert event is not None - assert event.partition_key == "this-is-not-key-1-before-mapped" - - apdr = session.scalar( - select(AssetPartitionDagRun) - .join( - PartitionedAssetKeyLog, - PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id, - ) - .where(PartitionedAssetKeyLog.asset_event_id == event.id) - ) - assert apdr is not None - assert apdr.created_dag_run_id is None - assert apdr.partition_key == "key-1" - - with custom_partition_mapper_patch(): partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) session.refresh(apdr) # Since asset event for Asset(name="asset-2") with key "key-1" has not yet been created, @@ -8834,6 +8803,13 @@ def test_partitioned_dag_run_with_customized_mapper( assert len(partition_dags) == 1 assert partition_dags == {"asset-event-consumer"} + dag_run = session.scalar(select(DagRun).where(DagRun.id == apdr.created_dag_run_id)) + assert dag_run is not None + asset_event = dag_run.consumed_asset_events[0] + assert asset_event.source_task_id == "hi" + assert asset_event.source_dag_id == "asset-event-producer" + assert asset_event.source_run_id == "test" + @pytest.mark.need_serialized_dag @pytest.mark.usefixtures("clear_asset_partition_rows") @@ -8907,6 +8883,13 @@ def test_consumer_dag_listen_to_two_partitioned_asset( assert len(partition_dags) == 1 assert partition_dags == {"asset-event-consumer"} + dag_run = session.scalar(select(DagRun).where(DagRun.id == apdr.created_dag_run_id)) + assert dag_run is not None + for asset_event in dag_run.consumed_asset_events: + assert asset_event.source_task_id == "hi" + assert "asset-event-producer-" in asset_event.source_dag_id + assert asset_event.source_run_id == "test" + @pytest.mark.need_serialized_dag @pytest.mark.usefixtures("clear_asset_partition_rows") @@ -8971,3 +8954,10 @@ def test_consumer_dag_listen_to_two_partitioned_asset_with_key_1_mapper( assert apdr.created_dag_run_id is not None assert len(partition_dags) == 1 assert partition_dags == {"asset-event-consumer"} + + dag_run = session.scalar(select(DagRun).where(DagRun.id == apdr.created_dag_run_id)) + assert dag_run is not None + for asset_event in dag_run.consumed_asset_events: + assert asset_event.source_task_id == "hi" + assert "asset-event-producer-" in asset_event.source_dag_id + assert asset_event.source_run_id == "test"