diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 2f0a416f47ff3..33cb0fe6ea5dd 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -46,10 +46,12 @@ from airflow.models.dagrun import DagRun, DagRunType from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator + from airflow.sdk import Context from airflow.sdk.api.datamodels._generated import IntermediateTIState, TerminalTIState from airflow.sdk.bases.operator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + from airflow.sdk.types import DagRunProtocol from airflow.timetables.base import DataInterval from airflow.typing_compat import Self from airflow.utils.state import DagRunState, TaskInstanceState @@ -1996,6 +1998,15 @@ def msg(self) -> ToSupervisor | None: ... @property def error(self) -> BaseException | None: ... + @property + def ti(self) -> RuntimeTaskInstance: ... + + @property + def dagrun(self) -> DagRunProtocol: ... + + @property + def context(self) -> Context: ... + xcom: _XComHelperProtocol def __call__( @@ -2039,6 +2050,7 @@ def execute(self, context): from airflow.sdk.api.datamodels._generated import TaskInstance from airflow.sdk.definitions.dag import DAG from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails + from airflow.timetables.base import TimeRestriction from airflow.utils import timezone def _create_task_instance( @@ -2058,6 +2070,7 @@ def _create_task_instance( max_tries: int | None = None, ) -> RuntimeTaskInstance: from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext + from airflow.utils.types import DagRunType if not ti_id: ti_id = uuid7() @@ -2067,13 +2080,22 @@ def _create_task_instance( task.dag = dag # type: ignore[assignment] task = dag.task_dict[task.task_id] + data_interval_start = None + data_interval_end = None + if task.dag.timetable: - data_interval_start, data_interval_end = task.dag.timetable.infer_manual_data_interval( - run_after=logical_date # type: ignore - ) - else: - data_interval_start = None - data_interval_end = None + if run_type == DagRunType.MANUAL: + data_interval_start, data_interval_end = task.dag.timetable.infer_manual_data_interval( + run_after=logical_date # type: ignore + ) + else: + drinfo = task.dag.timetable.next_dagrun_info( + last_automated_data_interval=None, + restriction=TimeRestriction(earliest=None, latest=None, catchup=False), + ) + if drinfo: + data_interval = drinfo.data_interval + data_interval_start, data_interval_end = data_interval.start, data_interval.end dag_id = task.dag.dag_id task_retries = task.retries or 0 @@ -2252,6 +2274,9 @@ def __init__(self, create_runtime_ti): self._state = None self._msg = None self._error = None + self._ti = None + self._dagrun = None + self._context = None @property def state(self) -> IntermediateTIState | TerminalTIState: @@ -2268,6 +2293,18 @@ def error(self) -> BaseException | None: """Get the error message if there was any.""" return self._error + @property + def ti(self) -> RuntimeTaskInstance: + return self._ti + + @property + def dagrun(self) -> DagRunProtocol: + return self._dagrun + + @property + def context(self) -> Context: + return self._context + def __call__( self, task: TaskSDKBaseOperator, @@ -2315,6 +2352,9 @@ def __call__( self._state = state self._msg = msg self._error = error + self._ti = ti + self._dagrun = context.get("dag_run") + self._context = context return state, msg, error diff --git a/providers/standard/src/airflow/providers/standard/operators/latest_only.py b/providers/standard/src/airflow/providers/standard/operators/latest_only.py index 930d5947563d6..d7f4c636ebfff 100644 --- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py +++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py @@ -25,10 +25,12 @@ import pendulum from airflow.providers.standard.operators.branch import BaseBranchOperator +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.types import DagRunType if TYPE_CHECKING: from airflow.models import DAG, DagRun + from airflow.timetables.base import DagRunInfo try: from airflow.sdk.definitions.context import Context @@ -46,6 +48,10 @@ class LatestOnlyOperator(BaseBranchOperator): Note that downstream tasks are never skipped if the given DAG_Run is marked as externally triggered. + + Note that when used with timetables that produce zero-length or point-in-time data intervals + (e.g., ``DeltaTriggerTimetable``), this operator assumes each run is the latest + and does not skip downstream tasks. """ ui_color = "#e9ffdb" # nyanza @@ -58,8 +64,7 @@ def choose_branch(self, context: Context) -> str | Iterable[str]: self.log.info("Manually triggered DAG_Run: allowing execution to proceed.") return list(context["task"].get_direct_relative_ids(upstream=False)) - dag: DAG = context["dag"] # type: ignore[assignment] - next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False) + next_info = self._get_next_run_info(context, dag_run) now = pendulum.now("UTC") if next_info is None: @@ -74,6 +79,15 @@ def choose_branch(self, context: Context) -> str | Iterable[str]: now, ) + if left_window == right_window: + self.log.info( + "Zero-length interval [%s, %s) from timetable (%s); treating current run as latest.", + left_window, + right_window, + self.dag.timetable.__class__, + ) + return list(context["task"].get_direct_relative_ids(upstream=False)) + if not left_window < now <= right_window: self.log.info("Not latest execution, skipping downstream.") # we return an empty list, thus the parent BaseBranchOperator @@ -82,3 +96,21 @@ def choose_branch(self, context: Context) -> str | Iterable[str]: else: self.log.info("Latest, allowing execution to proceed.") return list(context["task"].get_direct_relative_ids(upstream=False)) + + def _get_next_run_info(self, context: Context, dag_run: DagRun) -> DagRunInfo | None: + dag: DAG = context["dag"] # type: ignore[assignment] + + if AIRFLOW_V_3_0_PLUS: + from airflow.timetables.base import DataInterval, TimeRestriction + + time_restriction = TimeRestriction(earliest=None, latest=None, catchup=True) + current_interval = DataInterval(start=dag_run.data_interval_start, end=dag_run.data_interval_end) + + next_info = dag.timetable.next_dagrun_info( + last_automated_data_interval=current_interval, + restriction=time_restriction, + ) + + else: + next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False) + return next_info diff --git a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py index e5f0e842fe405..b976f41fa9a2e 100644 --- a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py +++ b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py @@ -36,6 +36,8 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import DAG + from airflow.timetables.trigger import DeltaTriggerTimetable from airflow.utils.types import DagRunTriggeredByType pytestmark = pytest.mark.db_test @@ -310,3 +312,23 @@ def test_not_skipping_external(self, dag_maker): timezone.datetime(2016, 1, 1, 12): "success", timezone.datetime(2016, 1, 2): "success", } + + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Only applicable to Airflow 3.0+") + def test_zero_length_interval_treated_as_latest(self, run_task): + """Test that when the data_interval_start and data_interval_end are the same, the task is treated as latest.""" + with DAG( + "test_dag", + schedule=DeltaTriggerTimetable(datetime.timedelta(hours=1)), + start_date=DEFAULT_DATE, + catchup=False, + ): + latest_task = LatestOnlyOperator(task_id="latest") + downstream_task = EmptyOperator(task_id="downstream") + latest_task >> downstream_task + + run_task(latest_task, run_type=DagRunType.SCHEDULED) + + assert run_task.dagrun.data_interval_start == run_task.dagrun.data_interval_end + + # The task will raise DownstreamTasksSkipped exception if it is not the latest run + assert run_task.state == State.SUCCESS diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index b3ecc7be91a13..a7749ccd2308b 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -21,8 +21,6 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Protocol, Union -import attrs - from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet if TYPE_CHECKING: @@ -109,7 +107,7 @@ def get_dr_count( def get_dagrun_state(dag_id: str, run_id: str) -> str: ... -class OutletEventAccessorProtocol(Protocol, attrs.AttrsInstance): +class OutletEventAccessorProtocol(Protocol): """Protocol for managing access to a specific outlet event accessor.""" key: BaseAssetUniqueKey