Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions devel-common/src/tests_common/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@
from airflow.models.dagrun import DagRun, DagRunType
from airflow.models.taskinstance import TaskInstance
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import Context
from airflow.sdk.api.datamodels._generated import IntermediateTIState, TerminalTIState
from airflow.sdk.bases.operator import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
from airflow.sdk.types import DagRunProtocol
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Self
from airflow.utils.state import DagRunState, TaskInstanceState
Expand Down Expand Up @@ -1996,6 +1998,15 @@ def msg(self) -> ToSupervisor | None: ...
@property
def error(self) -> BaseException | None: ...

@property
def ti(self) -> RuntimeTaskInstance: ...

@property
def dagrun(self) -> DagRunProtocol: ...

@property
def context(self) -> Context: ...

xcom: _XComHelperProtocol

def __call__(
Expand Down Expand Up @@ -2039,6 +2050,7 @@ def execute(self, context):
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails
from airflow.timetables.base import TimeRestriction
from airflow.utils import timezone

def _create_task_instance(
Expand All @@ -2058,6 +2070,7 @@ def _create_task_instance(
max_tries: int | None = None,
) -> RuntimeTaskInstance:
from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext
from airflow.utils.types import DagRunType

if not ti_id:
ti_id = uuid7()
Expand All @@ -2067,13 +2080,22 @@ def _create_task_instance(
task.dag = dag # type: ignore[assignment]
task = dag.task_dict[task.task_id]

data_interval_start = None
data_interval_end = None

if task.dag.timetable:
data_interval_start, data_interval_end = task.dag.timetable.infer_manual_data_interval(
run_after=logical_date # type: ignore
)
else:
data_interval_start = None
data_interval_end = None
if run_type == DagRunType.MANUAL:
data_interval_start, data_interval_end = task.dag.timetable.infer_manual_data_interval(
run_after=logical_date # type: ignore
)
else:
drinfo = task.dag.timetable.next_dagrun_info(
last_automated_data_interval=None,
restriction=TimeRestriction(earliest=None, latest=None, catchup=False),
)
if drinfo:
data_interval = drinfo.data_interval
data_interval_start, data_interval_end = data_interval.start, data_interval.end

dag_id = task.dag.dag_id
task_retries = task.retries or 0
Expand Down Expand Up @@ -2252,6 +2274,9 @@ def __init__(self, create_runtime_ti):
self._state = None
self._msg = None
self._error = None
self._ti = None
self._dagrun = None
self._context = None

@property
def state(self) -> IntermediateTIState | TerminalTIState:
Expand All @@ -2268,6 +2293,18 @@ def error(self) -> BaseException | None:
"""Get the error message if there was any."""
return self._error

@property
def ti(self) -> RuntimeTaskInstance:
return self._ti

@property
def dagrun(self) -> DagRunProtocol:
return self._dagrun

@property
def context(self) -> Context:
return self._context

def __call__(
self,
task: TaskSDKBaseOperator,
Expand Down Expand Up @@ -2315,6 +2352,9 @@ def __call__(
self._state = state
self._msg = msg
self._error = error
self._ti = ti
self._dagrun = context.get("dag_run")
self._context = context

return state, msg, error

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
import pendulum

from airflow.providers.standard.operators.branch import BaseBranchOperator
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils.types import DagRunType

if TYPE_CHECKING:
from airflow.models import DAG, DagRun
from airflow.timetables.base import DagRunInfo

try:
from airflow.sdk.definitions.context import Context
Expand All @@ -46,6 +48,10 @@ class LatestOnlyOperator(BaseBranchOperator):

Note that downstream tasks are never skipped if the given DAG_Run is
marked as externally triggered.

Note that when used with timetables that produce zero-length or point-in-time data intervals
(e.g., ``DeltaTriggerTimetable``), this operator assumes each run is the latest
and does not skip downstream tasks.
"""

ui_color = "#e9ffdb" # nyanza
Expand All @@ -58,8 +64,7 @@ def choose_branch(self, context: Context) -> str | Iterable[str]:
self.log.info("Manually triggered DAG_Run: allowing execution to proceed.")
return list(context["task"].get_direct_relative_ids(upstream=False))

dag: DAG = context["dag"] # type: ignore[assignment]
next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False)
next_info = self._get_next_run_info(context, dag_run)
now = pendulum.now("UTC")

if next_info is None:
Expand All @@ -74,6 +79,15 @@ def choose_branch(self, context: Context) -> str | Iterable[str]:
now,
)

if left_window == right_window:
self.log.info(
"Zero-length interval [%s, %s) from timetable (%s); treating current run as latest.",
left_window,
right_window,
self.dag.timetable.__class__,
)
return list(context["task"].get_direct_relative_ids(upstream=False))
Comment thread
kaxil marked this conversation as resolved.

if not left_window < now <= right_window:
self.log.info("Not latest execution, skipping downstream.")
# we return an empty list, thus the parent BaseBranchOperator
Expand All @@ -82,3 +96,21 @@ def choose_branch(self, context: Context) -> str | Iterable[str]:
else:
self.log.info("Latest, allowing execution to proceed.")
return list(context["task"].get_direct_relative_ids(upstream=False))

def _get_next_run_info(self, context: Context, dag_run: DagRun) -> DagRunInfo | None:
dag: DAG = context["dag"] # type: ignore[assignment]

if AIRFLOW_V_3_0_PLUS:
from airflow.timetables.base import DataInterval, TimeRestriction

time_restriction = TimeRestriction(earliest=None, latest=None, catchup=True)
current_interval = DataInterval(start=dag_run.data_interval_start, end=dag_run.data_interval_end)

next_info = dag.timetable.next_dagrun_info(
last_automated_data_interval=current_interval,
restriction=time_restriction,
)

else:
next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False)
return next_info
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import DAG
from airflow.timetables.trigger import DeltaTriggerTimetable
from airflow.utils.types import DagRunTriggeredByType

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -310,3 +312,23 @@ def test_not_skipping_external(self, dag_maker):
timezone.datetime(2016, 1, 1, 12): "success",
timezone.datetime(2016, 1, 2): "success",
}

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Only applicable to Airflow 3.0+")
def test_zero_length_interval_treated_as_latest(self, run_task):
"""Test that when the data_interval_start and data_interval_end are the same, the task is treated as latest."""
with DAG(
"test_dag",
schedule=DeltaTriggerTimetable(datetime.timedelta(hours=1)),
start_date=DEFAULT_DATE,
catchup=False,
):
latest_task = LatestOnlyOperator(task_id="latest")
downstream_task = EmptyOperator(task_id="downstream")
latest_task >> downstream_task

run_task(latest_task, run_type=DagRunType.SCHEDULED)

assert run_task.dagrun.data_interval_start == run_task.dagrun.data_interval_end

# The task will raise DownstreamTasksSkipped exception if it is not the latest run
assert run_task.state == State.SUCCESS
4 changes: 1 addition & 3 deletions task-sdk/src/airflow/sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Protocol, Union

import attrs

from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
Expand Down Expand Up @@ -109,7 +107,7 @@ def get_dr_count(
def get_dagrun_state(dag_id: str, run_id: str) -> str: ...


class OutletEventAccessorProtocol(Protocol, attrs.AttrsInstance):
class OutletEventAccessorProtocol(Protocol):
"""Protocol for managing access to a specific outlet event accessor."""

key: BaseAssetUniqueKey
Expand Down