Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
ConnectionAccessor,
Context,
DatasetEventAccessors,
InletEventsAccessors,
VariableAccessor,
context_get_dataset_events,
context_merge,
Expand Down Expand Up @@ -804,6 +805,7 @@ def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydanti
"execution_date": logical_date,
"expanded_ti_count": expanded_ti_count,
"inlets": task.inlets,
"inlet_events": InletEventsAccessors(task.inlets, session=session),
"logical_date": logical_date,
"macros": macros,
"map_index_template": task.map_index_template,
Expand Down
88 changes: 88 additions & 0 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
KeysView,
Mapping,
MutableMapping,
Sequence,
SupportsIndex,
ValuesView,
overload,
)

import attrs
Expand All @@ -44,7 +46,11 @@
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import Select

from airflow.models.baseoperator import BaseOperator
from airflow.models.dataset import DatasetEvent

# NOTE: Please keep this in sync with the following:
# * Context in airflow/utils/context.pyi.
Expand All @@ -63,6 +69,7 @@
"expanded_ti_count",
"exception",
"inlets",
"inlet_events",
"logical_date",
"macros",
"map_index_template",
Expand Down Expand Up @@ -174,6 +181,87 @@ def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor:
return self._dict[uri]


@attrs.define()
class InletEventsAccessor(Sequence["DatasetEvent"]):
    """Lazy sequence to access inlet dataset events.

    Every operation issues its own query through the stored session against
    the events of the dataset identified by ``_uri``; nothing is cached on
    the accessor. Events are ordered by ``DatasetEvent.timestamp`` ascending,
    so index ``0`` is the earliest event and ``-1`` the latest.

    :meta private:
    """

    _uri: str
    _session: Session

    def _get_select_stmt(self, *, reverse: bool = False) -> Select:
        """Build the SELECT for this dataset's events, ordered by timestamp.

        :param reverse: Order by timestamp descending instead of ascending.
        """
        # Imported lazily, matching the rest of this class, so the models are
        # only loaded when an accessor is actually used.
        from sqlalchemy import select

        from airflow.models.dataset import DatasetEvent, DatasetModel

        stmt = select(DatasetEvent).join(DatasetEvent.dataset).where(DatasetModel.uri == self._uri)
        if reverse:
            return stmt.order_by(DatasetEvent.timestamp.desc())
        return stmt.order_by(DatasetEvent.timestamp.asc())

    def __reversed__(self) -> Iterator[DatasetEvent]:
        """Iterate over the events from the latest to the earliest."""
        # ``scalars`` yields one ORM object per row. ``scalar`` would return a
        # single object (or None), which cannot be passed to ``iter``.
        return iter(self._session.scalars(self._get_select_stmt(reverse=True)))

    def __iter__(self) -> Iterator[DatasetEvent]:
        """Iterate over the events from the earliest to the latest."""
        return iter(self._session.scalars(self._get_select_stmt()))

    @overload
    def __getitem__(self, key: int) -> DatasetEvent: ...

    @overload
    def __getitem__(self, key: slice) -> Sequence[DatasetEvent]: ...

    def __getitem__(self, key: int | slice) -> DatasetEvent | Sequence[DatasetEvent]:
        """Fetch a single event by position.

        :param key: Zero-based index; negative values count from the latest
            event backwards.
        :raises ValueError: If *key* is not an integer (slices are declared in
            the overloads but not implemented).
        :raises IndexError: If no event exists at the requested position.
        """
        if not isinstance(key, int):
            raise ValueError("non-index access is not supported")
        if key >= 0:
            stmt = self._get_select_stmt().offset(key)
        else:
            # Negative index: walk the reversed ordering, so -1 maps to
            # offset 0, -2 to offset 1, and so on.
            stmt = self._get_select_stmt(reverse=True).offset(-1 - key)
        if (event := self._session.scalar(stmt.limit(1))) is not None:
            return event
        raise IndexError(key)

    def __len__(self) -> int:
        """Return the number of events recorded for the dataset."""
        from sqlalchemy import func, select

        # A Select must be turned into a subquery explicitly before it can be
        # used as a FROM clause. The ordering is irrelevant for counting, so
        # it is cancelled with ``order_by(None)`` first.
        events = self._get_select_stmt().order_by(None).subquery()
        return self._session.scalar(select(func.count()).select_from(events))


@attrs.define(init=False)
class InletEventsAccessors(Mapping[str, InletEventsAccessor]):
    """Mapping-style entry point to the events of a task's inlet datasets.

    Constructed from a task's ``inlets`` list and a database session. A lookup
    by URI, by ``Dataset`` object, or by position in ``inlets`` returns an
    :class:`InletEventsAccessor` bound to that dataset's URI and to the same
    session.

    NOTE(review): ``__iter__`` and ``__len__`` operate on the raw ``inlets``
    list, so they yield/count the inlet objects themselves (including any that
    are not ``Dataset`` instances) rather than URI strings -- confirm this is
    intended.

    :meta private:
    """

    _inlets: list[Any]
    _datasets: dict[str, Dataset]
    _session: Session

    def __init__(self, inlets: list, *, session: Session) -> None:
        self._session = session
        self._inlets = inlets
        # Index the Dataset inlets by URI; non-Dataset inlets are left out,
        # and a later inlet replaces an earlier one with the same URI.
        by_uri = {}
        for candidate in inlets:
            if isinstance(candidate, Dataset):
                by_uri[candidate.uri] = candidate
        self._datasets = by_uri

    def __iter__(self) -> Iterator[str]:
        return iter(self._inlets)

    def __len__(self) -> int:
        return len(self._inlets)

    def __getitem__(self, key: int | str | Dataset) -> InletEventsAccessor:
        if not isinstance(key, int):
            # URI string or Dataset: unknown URIs surface as KeyError.
            uri = self._datasets[coerce_to_uri(key)].uri
            return InletEventsAccessor(uri, session=self._session)
        # Positional access into ``inlets``; it's easier for trivial cases.
        inlet = self._inlets[key]
        if isinstance(inlet, Dataset):
            return InletEventsAccessor(inlet.uri, session=self._session)
        # The slot exists but does not hold a Dataset.
        raise IndexError(key)


class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
"""Warn for usage of deprecated context variables in a task."""

Expand Down
17 changes: 16 additions & 1 deletion airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
# declare "these are defined, but don't error if others are accessed" someday.
from __future__ import annotations

from typing import Any, Collection, Container, Iterable, Iterator, Mapping, overload
from typing import Any, Collection, Container, Iterable, Iterator, Mapping, Sequence, overload

from pendulum import DateTime
from sqlalchemy.orm import Session

from airflow.configuration import AirflowConfigParser
from airflow.datasets import Dataset
Expand Down Expand Up @@ -65,6 +66,19 @@ class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]):
def __len__(self) -> int: ...
def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor: ...

class InletEventsAccessor(Sequence[DatasetEvent]):
    # Type stub for the sequence of DatasetEvent objects exposed for one
    # inlet dataset. Only __getitem__ and __len__ are declared here; the
    # remaining Sequence methods come from the base class.
    @overload
    def __getitem__(self, key: int) -> DatasetEvent: ...
    @overload
    def __getitem__(self, key: slice) -> Sequence[DatasetEvent]: ...
    def __len__(self) -> int: ...

class InletEventsAccessors(Mapping[str, InletEventsAccessor]):
    # Type stub for the ``inlet_events`` context value: a mapping that hands
    # out one InletEventsAccessor per inlet. Lookups accept a position (int),
    # a URI (str) or a Dataset.
    def __init__(self, inlets: list, *, session: Session) -> None: ...
    def __iter__(self) -> Iterator[str]: ...
    def __len__(self) -> int: ...
    def __getitem__(self, key: int | str | Dataset) -> InletEventsAccessor: ...

# NOTE: Please keep this in sync with the following:
# * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
# * Table in docs/apache-airflow/templates-ref.rst
Expand All @@ -82,6 +96,7 @@ class Context(TypedDict, total=False):
execution_date: DateTime
expanded_ti_count: int | None
inlets: list
inlet_events: InletEventsAccessors
logical_date: DateTime
macros: Any
map_index_template: str
Expand Down
1 change: 1 addition & 0 deletions docs/apache-airflow/templates-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Variable Type Description
``{{ prev_end_date_success }}`` `pendulum.DateTime`_ End date from prior successful :class:`~airflow.models.dagrun.DagRun` (if available).
| ``None``
``{{ inlets }}`` list List of inlets declared on the task.
``{{ inlet_events }}`` dict[str, ...] Access past events of inlet datasets. See :doc:`Datasets <authoring-and-scheduling/datasets>`. Added in version 2.10.
``{{ outlets }}`` list List of outlets declared on the task.
``{{ dag }}`` DAG The currently running :class:`~airflow.models.dag.DAG`. You can read more about DAGs in :doc:`DAGs <core-concepts/dags>`.
``{{ task }}`` BaseOperator | The currently running :class:`~airflow.models.baseoperator.BaseOperator`. You can read more about Tasks in :doc:`core-concepts/operators`
Expand Down
53 changes: 53 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2426,6 +2426,59 @@ def _write2_post_execute(context, result):
assert events["write2"].dataset.uri == "test_outlet_dataset_extra_2"
assert events["write2"].extra == {"x": 1}

def test_inlet_dataset_extra(self, dag_maker, session):
    """Check that a task can read past events of its inlet dataset via ``inlet_events``.

    Three writer tasks each attach ``{"from": <task_id>}`` as extra on the
    same dataset; a downstream reader then inspects the events by positive
    index, negative index, unknown key and out-of-range index.
    """
    from airflow.datasets import Dataset

    # Flipped by the reader task so the test can prove its body actually ran
    # (assertions inside a task that never executes would pass vacuously).
    read_task_evaluated = False

    with dag_maker(schedule=None, session=session):

        @task(outlets=Dataset("test_inlet_dataset_extra"))
        def write(*, ti, dataset_events):
            # Tag the emitted event with the id of the task that produced it.
            dataset_events["test_inlet_dataset_extra"].extra = {"from": ti.task_id}

        @task(inlets=Dataset("test_inlet_dataset_extra"))
        def read(*, inlet_events):
            # Index 1 is the second event, expected to come from "write2".
            second_event = inlet_events["test_inlet_dataset_extra"][1]
            assert second_event.uri == "test_inlet_dataset_extra"
            assert second_event.extra == {"from": "write2"}

            # Index -1 is the most recent event, expected to come from "write3".
            last_event = inlet_events["test_inlet_dataset_extra"][-1]
            assert last_event.uri == "test_inlet_dataset_extra"
            assert last_event.extra == {"from": "write3"}

            # A URI that is not among the task's inlets is a KeyError.
            with pytest.raises(KeyError):
                inlet_events["does_not_exist"]
            # Only three events exist, so position 5 is out of range.
            with pytest.raises(IndexError):
                inlet_events["test_inlet_dataset_extra"][5]

            # TODO: Support slices.

            nonlocal read_task_evaluated
            read_task_evaluated = True

        [
            write.override(task_id="write1")(),
            write.override(task_id="write2")(),
            write.override(task_id="write3")(),
        ] >> read()

    dr: DagRun = dag_maker.create_dagrun()

    # Run "write1", "write2", and "write3" (in this order).
    # Sorting by task_id makes the event order deterministic, which the
    # positional assertions in ``read`` rely on.
    decision = dr.task_instance_scheduling_decisions(session=session)
    for ti in sorted(decision.schedulable_tis, key=operator.attrgetter("task_id")):
        ti.run(session=session)

    # Run "read".
    decision = dr.task_instance_scheduling_decisions(session=session)
    for ti in decision.schedulable_tis:
        ti.run(session=session)

    # Should be done.
    assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis
    assert read_task_evaluated

def test_changing_of_dataset_when_ddrq_is_already_populated(self, dag_maker):
"""
Test that when a task that produces dataset has ran, that changing the consumer
Expand Down
3 changes: 2 additions & 1 deletion tests/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,8 @@ def test_virtualenv_serializable_context_fields(self, create_task_instance):
"ti",
"var", # Accessor for Variable; var->json and var->value.
"conn", # Accessor for Connection.
"dataset_events", # Accessor for DatasetEvent.
"dataset_events", # Accessor for outlet DatasetEvent.
"inlet_events", # Accessor for inlet DatasetEvent.
]

ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None)
Expand Down