From c9585c56e28e353880636d969f6d93081b95e3be Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 27 Mar 2024 10:58:06 -0700 Subject: [PATCH 1/2] Remove select_column option in TaskInstance.get_task_instance Fundamentally what's going on here is we need a TaskInstance object instead of a Row object when sending over the wire in RPC call. But the full story on this one is actually somewhat complicated. It was back in 2.2.0 in #25312 when we converted to query with the column attrs instead of the TI object (#28900 only refactored this logic into a function). The reason was to avoid locking the dag_run table since TI newly had a dag_run relationship attr. Now, this causes a problem with AIP-44 because the RPC api does not know how to serialize a Row object. This PR switches back to querying a TaskInstance object, but avoids locking dag_run by using lazy_load option. Meanwhile, since try_number is a horrible attribute (which gives you a different answer depending on the state), we have to switch it back to look at the underlying private attr instead of the public accesor. --- airflow/models/taskinstance.py | 24 +++++++++++------------- tests/models/test_taskinstance.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index c9bd2ce617154..b7ecd52cb3e02 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -60,7 +60,7 @@ ) from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.mutable import MutableDict -from sqlalchemy.orm import reconstructor, relationship +from sqlalchemy.orm import lazyload, reconstructor, relationship from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value from sqlalchemy.sql.expression import case, select @@ -521,7 +521,6 @@ def _refresh_from_db( task_id=task_instance.task_id, run_id=task_instance.run_id, map_index=task_instance.map_index, - select_columns=True, lock_for_update=lock_for_update, session=session, ) @@ -532,8 +531,7 @@ def _refresh_from_db( task_instance.end_date = ti.end_date task_instance.duration = ti.duration task_instance.state = ti.state - # Since we selected columns, not the object, this is the raw value - task_instance.try_number = ti.try_number + task_instance.try_number = ti._try_number task_instance.max_tries = ti.max_tries task_instance.hostname = ti.hostname task_instance.unixname = ti.unixname @@ -911,7 +909,7 @@ def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic): :meta private: """ - if task_instance.state == TaskInstanceState.RUNNING.RUNNING: + if task_instance.state == TaskInstanceState.RUNNING: return task_instance._try_number return task_instance._try_number + 1 @@ -1792,18 +1790,18 @@ def get_task_instance( run_id: str, task_id: str, map_index: int, - select_columns: bool = False, lock_for_update: bool = False, session: Session = NEW_SESSION, ) -> TaskInstance | TaskInstancePydantic | None: query = ( - session.query(*TaskInstance.__table__.columns) if select_columns else session.query(TaskInstance) - ) - query = query.filter_by( - dag_id=dag_id, - run_id=run_id, - task_id=task_id, - map_index=map_index, + session.query(TaskInstance) + .options(lazyload("dag_run")) # lazy load dag run to avoid locking it + .filter_by( + dag_id=dag_id, + run_id=run_id, + task_id=task_id, + map_index=map_index, + ) ) if lock_for_update: diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index e6187311429c9..77d9680135877 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -4561,3 +4561,16 @@ def test_taskinstance_with_note(create_task_instance, session): assert session.query(TaskInstance).filter_by(**filter_kwargs).one_or_none() is None assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None + + +def test__refresh_from_db_should_not_increment_try_number(dag_maker, session): + with dag_maker(): + BashOperator(task_id="hello", bash_command="hi") + dag_maker.create_dagrun(state="success") + ti = session.scalar(select(TaskInstance)) + assert ti.task_id == "hello" # just to confirm... + assert ti.try_number == 1 # starts out as 1 + ti.refresh_from_db() + assert ti.try_number == 1 # stays 1 + ti.refresh_from_db() + assert ti.try_number == 1 # stays 1 From c817d5454f149746c595f06e5baced40771c3aec Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 2 Apr 2024 08:46:21 -0700 Subject: [PATCH 2/2] comment --- airflow/models/taskinstance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index b7ecd52cb3e02..443ed4b5d2127 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -531,7 +531,7 @@ def _refresh_from_db( task_instance.end_date = ti.end_date task_instance.duration = ti.duration task_instance.state = ti.state - task_instance.try_number = ti._try_number + task_instance.try_number = ti._try_number # private attr to get value unaltered by accessor task_instance.max_tries = ti.max_tries task_instance.hostname = ti.hostname task_instance.unixname = ti.unixname