From 7920321687ac09bcd1013d952c468548e174d77a Mon Sep 17 00:00:00 2001 From: Dev-iL <6509619+Dev-iL@users.noreply.github.com> Date: Thu, 26 Mar 2026 16:58:02 +0200 Subject: [PATCH] SQLA: Replace the deprecated `lazy="noload"` with `lazy="raise"` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SQLAlchemy 2.1 deprecated the `noload` lazy loading strategy (sqlalchemy/sqlalchemy#11045). `noload` silently returns `None`/empty collections — essentially incorrect results — and will be removed in a future release. This PR replaces all 5 occurrences of `lazy="noload"` with `lazy="raise"`, which raises `InvalidRequestError` if the relationship is accessed without an explicit eager load (e.g. `joinedload`). All affected relationships are already properly loaded via `joinedload()` wherever they're accessed, so this is a safe drop-in that also catches missing eager loads at development time instead of silently returning `None`. Two callers needed fixes to work correctly with `lazy="raise"`: - `TaskInstance.rendered_task_instance_fields` and `TaskInstance.hitl_detail` needed `passive_deletes=True` to tell SQLAlchemy to rely on the DB-level `ON DELETE CASCADE` rather than attempting ORM-level cascade processing (which would fail since FK columns are also PK columns on RTIF, and `lazy="raise"` prevents the ORM from loading the collection to clear them). - The HITL API endpoints needed `joinedload(TI.rendered_task_instance_fields)` added to their queries, since `TaskInstanceResponse` accesses `rendered_task_instance_fields` during Pydantic serialization. The `get_hitl_detail` endpoint also needed an explicit `model_validate()` call so serialization happens while the session is still active. **Changed models:** - `Log.task_instance` - `TaskInstance.rendered_task_instance_fields` - `TaskInstance.hitl_detail` - `TaskInstanceHistory.hitl_detail` - `XComModel.task` related: #61229 --- .../api_fastapi/common/db/task_instances.py | 5 ++- .../core_api/routes/public/hitl.py | 8 +++- .../core_api/routes/public/task_instances.py | 37 +++++++++++++++++-- .../api_fastapi/core_api/routes/ui/dags.py | 4 ++ airflow-core/src/airflow/models/log.py | 2 +- .../src/airflow/models/taskinstance.py | 6 ++- .../src/airflow/models/taskinstancehistory.py | 2 +- airflow-core/src/airflow/models/xcom.py | 2 +- airflow-core/tests/unit/models/test_log.py | 22 +++++++++++ .../tests/unit/models/test_taskinstance.py | 23 ++++++++++++ airflow-core/tests/unit/models/test_xcom.py | 20 ++++++++++ 11 files changed, 119 insertions(+), 12 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py index 615e1d260e058..37ff83d406c01 100644 --- a/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py @@ -52,5 +52,8 @@ def eager_load_TI_and_TIH_for_validation( contains_eager(orm_model.dag_version).options(joinedload(DagVersion.bundle)), ) if orm_model is TaskInstance: - query = query.options(joinedload(orm_model.task_instance_note)) + query = query.options( + joinedload(orm_model.task_instance_note), + joinedload(orm_model.rendered_task_instance_fields), + ) return query diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/hitl.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/hitl.py index 55222782806fb..125781ed65486 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/hitl.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/hitl.py @@ -84,6 +84,9 @@ def _get_task_instance_with_hitl_detail( try_number: int | None = None, ) -> TI | TIH: def _query(orm_object: Base) -> TI | TIH | None: + options = [joinedload(orm_object.hitl_detail)] + if orm_object is TI: + options.append(joinedload(TI.rendered_task_instance_fields)) query = ( select(orm_object) .where( @@ -92,7 +95,7 @@ def _query(orm_object: Base) -> TI | TIH | None: orm_object.task_id == task_id, orm_object.map_index == map_index, ) - .options(joinedload(orm_object.hitl_detail)) + .options(*options) ) if try_number is not None: @@ -213,7 +216,7 @@ def get_hitl_detail( map_index=map_index, try_number=None, ) - return task_instance.hitl_detail + return HITLDetail.model_validate(task_instance.hitl_detail) @task_instances_hitl_router.get( @@ -304,6 +307,7 @@ def get_hitl_details( joinedload(TI.dag_run).joinedload(DagRun.dag_model), joinedload(TI.task_instance_note), joinedload(TI.dag_version).joinedload(DagVersion.bundle), + joinedload(TI.rendered_task_instance_fields), ), ) ) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py index 9488c0b7591fb..1af56b565f309 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -17,6 +17,7 @@ from __future__ import annotations +from collections.abc import Sequence from typing import Annotated, Literal, cast import structlog @@ -637,10 +638,6 @@ def get_task_instances_batch( limit=limit, session=session, ) - task_instance_select = task_instance_select.options( - joinedload(TI.rendered_task_instance_fields), - ) - task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( @@ -802,6 +799,7 @@ def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"] *((t, m) for t, m in mapped_tasks_tuples if t not in normal_task_ids), ] + task_instances: Sequence[TI] if dag_run_id is not None and not (past or future): # Use run_id-based clearing when we have a specific dag_run_id and not using past/future task_instances = dag.clear( @@ -845,6 +843,21 @@ def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"] user=user, ) + # Eagerly load rendered_task_instance_fields for serialization (lazy='raise' prevents lazy access). + # dag.clear() returns TIs without this relationship loaded; re-query with joinedload. + # populate_existing=True ensures the joinedload updates TIs already in the identity map. + if task_instances: + task_instances = ( + session.scalars( + select(TI) + .options(joinedload(TI.rendered_task_instance_fields)) + .where(TI.id.in_([ti.id for ti in task_instances])) + .execution_options(populate_existing=True) + ) + .unique() + .all() + ) + return TaskInstanceCollectionResponse( task_instances=[TaskInstanceResponse.model_validate(ti) for ti in task_instances], total_entries=len(task_instances), @@ -878,6 +891,7 @@ def patch_task_instance_dry_run( update_mask: list[str] | None = Query(None), ) -> TaskInstanceCollectionResponse: """Update a task instance dry_run mode.""" + tis: Sequence[TI] dag, tis, data = _patch_ti_validate_request( dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask ) @@ -899,6 +913,21 @@ def patch_task_instance_dry_run( or [] ) + # Eagerly load rendered_task_instance_fields for serialization (lazy='raise' prevents lazy access). + # dag.set_task_instance_state() may return TIs without this relationship loaded. + # populate_existing=True ensures the joinedload updates TIs already in the identity map. + if tis: + tis = ( + session.scalars( + select(TI) + .options(joinedload(TI.rendered_task_instance_fields)) + .where(TI.id.in_([ti.id for ti in tis])) + .execution_options(populate_existing=True) + ) + .unique() + .all() + ) + return TaskInstanceCollectionResponse( task_instances=[ TaskInstanceResponse.model_validate( diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/dags.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/dags.py index 41da300e5d8a5..772beb465fd55 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/dags.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/dags.py @@ -22,6 +22,7 @@ from fastapi import Depends, status from sqlalchemy import and_, func, select +from sqlalchemy.orm import defaultload from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import ( @@ -219,6 +220,9 @@ def get_dags( HITLDetail, ) .join(TaskInstance, HITLDetail.ti_id == TaskInstance.id) + .options( + defaultload(HITLDetail.task_instance).joinedload(TaskInstance.rendered_task_instance_fields) + ) .where( HITLDetail.responded_at.is_(None), TaskInstance.state == TaskInstanceState.DEFERRED, diff --git a/airflow-core/src/airflow/models/log.py b/airflow-core/src/airflow/models/log.py index eb2725068918d..8eb1d4bbcd2fb 100644 --- a/airflow-core/src/airflow/models/log.py +++ b/airflow-core/src/airflow/models/log.py @@ -63,7 +63,7 @@ class Log(Base): viewonly=True, foreign_keys=[dag_id, task_id, run_id, map_index], primaryjoin="and_(Log.dag_id == TaskInstance.dag_id, Log.task_id == TaskInstance.task_id, Log.run_id == TaskInstance.run_id, Log.map_index == TaskInstance.map_index)", - lazy="noload", + lazy="raise", ) __table_args__ = ( diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 4c2137a5343cf..6fca3af52b7bd 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -614,8 +614,10 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload): trigger = relationship("Trigger", uselist=False, back_populates="task_instance") triggerer_job = association_proxy("trigger", "triggerer_job") dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True) - rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False) - hitl_detail = relationship("HITLDetail", lazy="noload", uselist=False) + rendered_task_instance_fields = relationship( + "RenderedTaskInstanceFields", lazy="raise", uselist=False, passive_deletes=True + ) + hitl_detail = relationship("HITLDetail", lazy="raise", uselist=False, passive_deletes=True) run_after = association_proxy("dag_run", "run_after") logical_date = association_proxy("dag_run", "logical_date") diff --git a/airflow-core/src/airflow/models/taskinstancehistory.py b/airflow-core/src/airflow/models/taskinstancehistory.py index df7b0a2876f0e..ee192624123ca 100644 --- a/airflow-core/src/airflow/models/taskinstancehistory.py +++ b/airflow-core/src/airflow/models/taskinstancehistory.py @@ -131,7 +131,7 @@ class TaskInstanceHistory(Base): foreign_keys=[run_id, dag_id], ) - hitl_detail = relationship("HITLDetailHistory", lazy="noload", uselist=False) + hitl_detail = relationship("HITLDetailHistory", lazy="raise", uselist=False) def __init__( self, diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py index 0f848bc60bf4b..d474d521046c0 100644 --- a/airflow-core/src/airflow/models/xcom.py +++ b/airflow-core/src/airflow/models/xcom.py @@ -106,7 +106,7 @@ class XComModel(TaskInstanceDependencies): task = relationship( "TaskInstance", viewonly=True, - lazy="noload", + lazy="raise", ) @classmethod diff --git a/airflow-core/tests/unit/models/test_log.py b/airflow-core/tests/unit/models/test_log.py index d049be9a0d165..24961e3b00c33 100644 --- a/airflow-core/tests/unit/models/test_log.py +++ b/airflow-core/tests/unit/models/test_log.py @@ -19,6 +19,7 @@ import pytest from sqlalchemy import select +from sqlalchemy.exc import InvalidRequestError from sqlalchemy.orm import joinedload from airflow.models.log import Log @@ -29,6 +30,27 @@ class TestLogTaskInstanceReproduction: + def test_log_task_instance_raises_without_joinedload(self, dag_maker, session): + """Accessing Log.task_instance without joinedload should raise.""" + with dag_maker("dag_raise_test", session=session): + EmptyOperator(task_id="task_1") + + dr = dag_maker.create_dagrun() + ti = dr.get_task_instance("task_1") + session.merge(ti) + session.commit() + + log = Log(event="test_event", task_instance=ti) + session.add(log) + session.commit() + + session.expire_all() + stmt = select(Log).where(Log.id == log.id) + loaded_log = session.scalar(stmt) + + with pytest.raises(InvalidRequestError): + loaded_log.task_instance + def test_log_task_instance_join_correctness(self, dag_maker, session): # Create dag_1 with a task with dag_maker("dag_1", session=session): diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 3fa09106ade7f..80ca037a4b66c 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -2750,6 +2750,29 @@ def test_defer_task_with_trigger_timeout(create_task_instance): assert abs((ti.trigger_timeout - expected_timeout).total_seconds()) < 5 +class TestTaskInstanceRelationships: + @pytest.mark.parametrize( + "attr", + ["rendered_task_instance_fields", "hitl_detail"], + ) + def test_noload_relationships_raise_without_joinedload(self, dag_maker, session, attr): + """Accessing lazy='raise' relationships without joinedload should raise.""" + from sqlalchemy.exc import InvalidRequestError + + with dag_maker("test_dag", session=session): + EmptyOperator(task_id="task_1") + + dr = dag_maker.create_dagrun() + ti = dr.get_task_instance("task_1") + session.merge(ti) + session.commit() + + loaded_ti = session.scalar(select(TaskInstance).where(TaskInstance.id == ti.id)) + + with pytest.raises(InvalidRequestError): + getattr(loaded_ti, attr) + + class TestTaskInstanceRecordTaskMapXComPush: """Test TI.xcom_push() correctly records return values for task-mapping.""" diff --git a/airflow-core/tests/unit/models/test_xcom.py b/airflow-core/tests/unit/models/test_xcom.py index c703e9312c102..740dc204fa840 100644 --- a/airflow-core/tests/unit/models/test_xcom.py +++ b/airflow-core/tests/unit/models/test_xcom.py @@ -126,6 +126,26 @@ def task_instances(session, task_instance): return task_instance, ti2 # ti2 will be cleaned up automatically with the DAG run. +class TestXComModelRelationships: + def test_xcom_task_raises_without_joinedload(self, task_instance, session): + """Accessing XComModel.task without joinedload should raise.""" + from sqlalchemy.exc import InvalidRequestError + + XComModel.set( + key="test_key", + value="test_value", + dag_id=task_instance.dag_id, + task_id=task_instance.task_id, + run_id=task_instance.run_id, + session=session, + ) + + xcom = session.scalar(select(XComModel).where(XComModel.task_id == task_instance.task_id)) + + with pytest.raises(InvalidRequestError): + xcom.task + + class TestXCom: @conf_vars({("core", "xcom_backend"): "unit.models.test_xcom.CustomXCom"}) def test_resolve_xcom_class(self):