Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,8 @@ def eager_load_TI_and_TIH_for_validation(
contains_eager(orm_model.dag_version).options(joinedload(DagVersion.bundle)),
)
if orm_model is TaskInstance:
query = query.options(joinedload(orm_model.task_instance_note))
query = query.options(
joinedload(orm_model.task_instance_note),
joinedload(orm_model.rendered_task_instance_fields),
)
return query
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def _get_task_instance_with_hitl_detail(
try_number: int | None = None,
) -> TI | TIH:
def _query(orm_object: Base) -> TI | TIH | None:
options = [joinedload(orm_object.hitl_detail)]
if orm_object is TI:
options.append(joinedload(TI.rendered_task_instance_fields))
query = (
select(orm_object)
.where(
Expand All @@ -92,7 +95,7 @@ def _query(orm_object: Base) -> TI | TIH | None:
orm_object.task_id == task_id,
orm_object.map_index == map_index,
)
.options(joinedload(orm_object.hitl_detail))
.options(*options)
)

if try_number is not None:
Expand Down Expand Up @@ -213,7 +216,7 @@ def get_hitl_detail(
map_index=map_index,
try_number=None,
)
return task_instance.hitl_detail
return HITLDetail.model_validate(task_instance.hitl_detail)


@task_instances_hitl_router.get(
Expand Down Expand Up @@ -304,6 +307,7 @@ def get_hitl_details(
joinedload(TI.dag_run).joinedload(DagRun.dag_model),
joinedload(TI.task_instance_note),
joinedload(TI.dag_version).joinedload(DagVersion.bundle),
joinedload(TI.rendered_task_instance_fields),
),
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

from collections.abc import Sequence
from typing import Annotated, Literal, cast

import structlog
Expand Down Expand Up @@ -637,10 +638,6 @@ def get_task_instances_batch(
limit=limit,
session=session,
)
task_instance_select = task_instance_select.options(
joinedload(TI.rendered_task_instance_fields),
)

task_instances = session.scalars(task_instance_select)

return TaskInstanceCollectionResponse(
Expand Down Expand Up @@ -802,6 +799,7 @@ def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"]
*((t, m) for t, m in mapped_tasks_tuples if t not in normal_task_ids),
]

task_instances: Sequence[TI]
if dag_run_id is not None and not (past or future):
# Use run_id-based clearing when we have a specific dag_run_id and not using past/future
task_instances = dag.clear(
Expand Down Expand Up @@ -845,6 +843,21 @@ def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"]
user=user,
)

# Eagerly load rendered_task_instance_fields for serialization (lazy='raise' prevents lazy access).
# dag.clear() returns TIs without this relationship loaded; re-query with joinedload.
# populate_existing=True ensures the joinedload updates TIs already in the identity map.
if task_instances:
task_instances = (
session.scalars(
select(TI)
.options(joinedload(TI.rendered_task_instance_fields))
.where(TI.id.in_([ti.id for ti in task_instances]))
.execution_options(populate_existing=True)
)
.unique()
.all()
)

return TaskInstanceCollectionResponse(
task_instances=[TaskInstanceResponse.model_validate(ti) for ti in task_instances],
total_entries=len(task_instances),
Expand Down Expand Up @@ -878,6 +891,7 @@ def patch_task_instance_dry_run(
update_mask: list[str] | None = Query(None),
) -> TaskInstanceCollectionResponse:
"""Update a task instance dry_run mode."""
tis: Sequence[TI]
dag, tis, data = _patch_ti_validate_request(
dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask
)
Expand All @@ -899,6 +913,21 @@ def patch_task_instance_dry_run(
or []
)

# Eagerly load rendered_task_instance_fields for serialization (lazy='raise' prevents lazy access).
# dag.set_task_instance_state() may return TIs without this relationship loaded.
# populate_existing=True ensures the joinedload updates TIs already in the identity map.
if tis:
tis = (
session.scalars(
select(TI)
.options(joinedload(TI.rendered_task_instance_fields))
.where(TI.id.in_([ti.id for ti in tis]))
.execution_options(populate_existing=True)
)
.unique()
.all()
)

return TaskInstanceCollectionResponse(
task_instances=[
TaskInstanceResponse.model_validate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from fastapi import Depends, status
from sqlalchemy import and_, func, select
from sqlalchemy.orm import defaultload

from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity
from airflow.api_fastapi.common.db.common import (
Expand Down Expand Up @@ -219,6 +220,9 @@ def get_dags(
HITLDetail,
)
.join(TaskInstance, HITLDetail.ti_id == TaskInstance.id)
.options(
defaultload(HITLDetail.task_instance).joinedload(TaskInstance.rendered_task_instance_fields)
)
.where(
HITLDetail.responded_at.is_(None),
TaskInstance.state == TaskInstanceState.DEFERRED,
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/models/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Log(Base):
viewonly=True,
foreign_keys=[dag_id, task_id, run_id, map_index],
primaryjoin="and_(Log.dag_id == TaskInstance.dag_id, Log.task_id == TaskInstance.task_id, Log.run_id == TaskInstance.run_id, Log.map_index == TaskInstance.map_index)",
lazy="noload",
lazy="raise",
)

__table_args__ = (
Expand Down
6 changes: 4 additions & 2 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,10 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
trigger = relationship("Trigger", uselist=False, back_populates="task_instance")
triggerer_job = association_proxy("trigger", "triggerer_job")
dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True)
rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False)
hitl_detail = relationship("HITLDetail", lazy="noload", uselist=False)
rendered_task_instance_fields = relationship(
"RenderedTaskInstanceFields", lazy="raise", uselist=False, passive_deletes=True
)
hitl_detail = relationship("HITLDetail", lazy="raise", uselist=False, passive_deletes=True)

run_after = association_proxy("dag_run", "run_after")
logical_date = association_proxy("dag_run", "logical_date")
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/models/taskinstancehistory.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class TaskInstanceHistory(Base):
foreign_keys=[run_id, dag_id],
)

hitl_detail = relationship("HITLDetailHistory", lazy="noload", uselist=False)
hitl_detail = relationship("HITLDetailHistory", lazy="raise", uselist=False)

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class XComModel(TaskInstanceDependencies):
task = relationship(
"TaskInstance",
viewonly=True,
lazy="noload",
lazy="raise",
)

@classmethod
Expand Down
22 changes: 22 additions & 0 deletions airflow-core/tests/unit/models/test_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import pytest
from sqlalchemy import select
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import joinedload

from airflow.models.log import Log
Expand All @@ -29,6 +30,27 @@


class TestLogTaskInstanceReproduction:
def test_log_task_instance_raises_without_joinedload(self, dag_maker, session):
"""Accessing Log.task_instance without joinedload should raise."""
with dag_maker("dag_raise_test", session=session):
EmptyOperator(task_id="task_1")

dr = dag_maker.create_dagrun()
ti = dr.get_task_instance("task_1")
session.merge(ti)
session.commit()

log = Log(event="test_event", task_instance=ti)
session.add(log)
session.commit()

session.expire_all()
stmt = select(Log).where(Log.id == log.id)
loaded_log = session.scalar(stmt)

with pytest.raises(InvalidRequestError):
loaded_log.task_instance

def test_log_task_instance_join_correctness(self, dag_maker, session):
# Create dag_1 with a task
with dag_maker("dag_1", session=session):
Expand Down
23 changes: 23 additions & 0 deletions airflow-core/tests/unit/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2750,6 +2750,29 @@ def test_defer_task_with_trigger_timeout(create_task_instance):
assert abs((ti.trigger_timeout - expected_timeout).total_seconds()) < 5


class TestTaskInstanceRelationships:
@pytest.mark.parametrize(
"attr",
["rendered_task_instance_fields", "hitl_detail"],
)
def test_noload_relationships_raise_without_joinedload(self, dag_maker, session, attr):
"""Accessing lazy='raise' relationships without joinedload should raise."""
from sqlalchemy.exc import InvalidRequestError

with dag_maker("test_dag", session=session):
EmptyOperator(task_id="task_1")

dr = dag_maker.create_dagrun()
ti = dr.get_task_instance("task_1")
session.merge(ti)
session.commit()

loaded_ti = session.scalar(select(TaskInstance).where(TaskInstance.id == ti.id))

with pytest.raises(InvalidRequestError):
getattr(loaded_ti, attr)


class TestTaskInstanceRecordTaskMapXComPush:
"""Test TI.xcom_push() correctly records return values for task-mapping."""

Expand Down
20 changes: 20 additions & 0 deletions airflow-core/tests/unit/models/test_xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,26 @@ def task_instances(session, task_instance):
return task_instance, ti2 # ti2 will be cleaned up automatically with the DAG run.


class TestXComModelRelationships:
def test_xcom_task_raises_without_joinedload(self, task_instance, session):
"""Accessing XComModel.task without joinedload should raise."""
from sqlalchemy.exc import InvalidRequestError

XComModel.set(
key="test_key",
value="test_value",
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)

xcom = session.scalar(select(XComModel).where(XComModel.task_id == task_instance.task_id))

with pytest.raises(InvalidRequestError):
xcom.task


class TestXCom:
@conf_vars({("core", "xcom_backend"): "unit.models.test_xcom.CustomXCom"})
def test_resolve_xcom_class(self):
Expand Down
Loading