Skip to content
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,6 +1843,13 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st
creating_job_id=self.job.id,
session=session,
)
asset_events = session.scalars(
select(AssetEvent).where(
PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
PartitionedAssetKeyLog.asset_event_id == AssetEvent.id,
)
)
dag_run.consumed_asset_events.extend(asset_events)
session.flush()
apdr.created_dag_run_id = dag_run.id
session.flush()
Expand Down
64 changes: 27 additions & 37 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -8786,46 +8786,15 @@ def test_partitioned_dag_run_with_customized_mapper(
runner = SchedulerJobRunner(
job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False))
)

with dag_maker(dag_id="asset-event-producer", schedule=None, session=session) as dag:
EmptyOperator(task_id="hi", outlets=[asset_1])

dr = dag_maker.create_dagrun(partition_key="this-is-not-key-1-before-mapped", session=session)
[ti] = dr.get_task_instances(session=session)
session.commit()

serialized_outlets = dag.get_task("hi").outlets
with custom_partition_mapper_patch():
TaskInstance.register_asset_changes_in_db(
ti=ti,
task_outlets=[o.asprofile() for o in serialized_outlets],
outlet_events=[],
apdr = _produce_and_register_asset_event(
dag_id="asset-event-producer",
asset=asset_1,
partition_key="this-is-not-key-1-before-mapped",
session=session,
dag_maker=dag_maker,
expected_partition_key="key-1",
)
session.commit()

event = session.scalar(
select(AssetEvent).where(
AssetEvent.source_dag_id == dag.dag_id,
AssetEvent.source_run_id == dr.run_id,
)
)
assert event is not None
assert event.partition_key == "this-is-not-key-1-before-mapped"

apdr = session.scalar(
select(AssetPartitionDagRun)
.join(
PartitionedAssetKeyLog,
PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id,
)
.where(PartitionedAssetKeyLog.asset_event_id == event.id)
)
assert apdr is not None
assert apdr.created_dag_run_id is None
assert apdr.partition_key == "key-1"

with custom_partition_mapper_patch():
partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session)
session.refresh(apdr)
# Since asset event for Asset(name="asset-2") with key "key-1" has not yet been created,
Expand All @@ -8834,6 +8803,13 @@ def test_partitioned_dag_run_with_customized_mapper(
assert len(partition_dags) == 1
assert partition_dags == {"asset-event-consumer"}

dag_run = session.scalar(select(DagRun).where(DagRun.id == apdr.created_dag_run_id))
assert dag_run is not None
asset_event = dag_run.consumed_asset_events[0]
assert asset_event.source_task_id == "hi"
assert asset_event.source_dag_id == "asset-event-producer"
assert asset_event.source_run_id == "test"


@pytest.mark.need_serialized_dag
@pytest.mark.usefixtures("clear_asset_partition_rows")
Expand Down Expand Up @@ -8907,6 +8883,13 @@ def test_consumer_dag_listen_to_two_partitioned_asset(
assert len(partition_dags) == 1
assert partition_dags == {"asset-event-consumer"}

dag_run = session.scalar(select(DagRun).where(DagRun.id == apdr.created_dag_run_id))
assert dag_run is not None
for asset_event in dag_run.consumed_asset_events:
assert asset_event.source_task_id == "hi"
assert "asset-event-producer-" in asset_event.source_dag_id
assert asset_event.source_run_id == "test"


@pytest.mark.need_serialized_dag
@pytest.mark.usefixtures("clear_asset_partition_rows")
Expand Down Expand Up @@ -8971,3 +8954,10 @@ def test_consumer_dag_listen_to_two_partitioned_asset_with_key_1_mapper(
assert apdr.created_dag_run_id is not None
assert len(partition_dags) == 1
assert partition_dags == {"asset-event-consumer"}

dag_run = session.scalar(select(DagRun).where(DagRun.id == apdr.created_dag_run_id))
assert dag_run is not None
for asset_event in dag_run.consumed_asset_events:
assert asset_event.source_task_id == "hi"
assert "asset-event-producer-" in asset_event.source_dag_id
assert asset_event.source_run_id == "test"