From 863f24fa066835765e73d35c6bc0b333918ae536 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 26 Mar 2024 14:26:47 +0800 Subject: [PATCH 1/5] Implement context accessor for DatasetEvent extra --- airflow/datasets/__init__.py | 11 ++++- airflow/models/taskinstance.py | 14 ++++-- airflow/utils/context.py | 34 ++++++++++++++ airflow/utils/context.pyi | 6 +++ .../authoring-and-scheduling/datasets.rst | 31 +++++++++++-- tests/models/test_taskinstance.py | 46 ++++++++++++++++++- tests/operators/test_python.py | 1 + 7 files changed, 132 insertions(+), 11 deletions(-) diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py index 2507c69d01b43..d20d3b578e508 100644 --- a/airflow/datasets/__init__.py +++ b/airflow/datasets/__init__.py @@ -42,7 +42,14 @@ def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | N return ProvidersManager().dataset_uri_handlers.get(scheme) -def _sanitize_uri(uri: str) -> str: +def sanitize_uri(uri: str) -> str: + """Sanitize a dataset URI. + + This checks for URI validity, and normalizes the URI if needed. A fully + normalized URI is returned. + + :meta private: + """ if not uri: raise ValueError("Dataset URI cannot be empty") if uri.isspace(): @@ -110,7 +117,7 @@ class Dataset(os.PathLike, BaseDatasetEventInput): """A representation of data dependencies between workflows.""" uri: str = attr.field( - converter=_sanitize_uri, + converter=sanitize_uri, validator=[attr.validators.min_len(1), attr.validators.max_len(3000)], ) extra: dict[str, Any] | None = None diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 8bb9947327d3a..c9bd2ce617154 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -104,7 +104,13 @@ from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS from airflow.utils import timezone -from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor, context_merge +from airflow.utils.context import ( + ConnectionAccessor, + Context, + DatasetEventAccessors, + VariableAccessor, + context_merge, +) from airflow.utils.email import send_email from airflow.utils.helpers import prune_dict, render_template_to_string from airflow.utils.log.logging_mixin import LoggingMixin @@ -766,6 +772,7 @@ def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydanti "dag_run": dag_run, "data_interval_end": timezone.coerce_datetime(data_interval.end), "data_interval_start": timezone.coerce_datetime(data_interval.start), + "dataset_events": DatasetEventAccessors(), "ds": ds, "ds_nodash": ds_nodash, "execution_date": logical_date, @@ -2569,7 +2576,7 @@ def _run_raw_task( session.add(Log(self.state, self)) session.merge(self).task = self.task if self.state == TaskInstanceState.SUCCESS: - self._register_dataset_changes(session=session) + self._register_dataset_changes(events=context["dataset_events"], session=session) session.commit() if self.state == TaskInstanceState.SUCCESS: @@ -2579,7 +2586,7 @@ def _run_raw_task( return None - def _register_dataset_changes(self, *, session: Session) -> None: + def _register_dataset_changes(self, *, events: DatasetEventAccessors, session: Session) -> None: if TYPE_CHECKING: assert self.task @@ -2590,6 +2597,7 @@ def _register_dataset_changes(self, *, session: Session) -> None: dataset_manager.register_dataset_change( task_instance=self, dataset=obj, + extra=events[obj].extra, session=session, ) diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 3501ca7dbc22a..033b7aa39d3ba 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -36,8 +36,10 @@ ValuesView, ) +import attrs import lazy_object_proxy +from airflow.datasets import Dataset, sanitize_uri from airflow.exceptions import RemovedInAirflow3Warning from airflow.utils.types import NOTSET @@ -54,6 +56,7 @@ "dag_run", "data_interval_end", "data_interval_start", + "dataset_events", "ds", "ds_nodash", "execution_date", @@ -146,6 +149,37 @@ def get(self, key: str, default_conn: Any = None) -> Any: return default_conn +@attrs.define() +class DatasetEventAccessor: + """Wrapper to access a DatasetEvent instance in template.""" + + extra: dict[str, Any] + + +class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]): + """Lazy mapping of dataset event accessors.""" + + def __init__(self) -> None: + self._dict: dict[str, DatasetEventAccessor] = {} + + def __iter__(self) -> Iterator[str]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) + + def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor: + if isinstance(key, str): + uri = sanitize_uri(key) + elif isinstance(key, Dataset): + uri = key.uri + else: + return NotImplemented + if uri not in self._dict: + self._dict[uri] = DatasetEventAccessor({}) + return self._dict[uri] + + class AirflowContextDeprecationWarning(RemovedInAirflow3Warning): """Warn for usage of deprecated context variables in a task.""" diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index eb08201248173..c26e7bfa25d01 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -55,6 +55,11 @@ class VariableAccessor: class ConnectionAccessor: def get(self, key: str, default_conn: Any = None) -> Any: ... +class DatasetEventAccessor: + extra: dict[str, Any] + +class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]): ... + # NOTE: Please keep this in sync with the following: # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py # * Table in docs/apache-airflow/templates-ref.rst @@ -65,6 +70,7 @@ class Context(TypedDict, total=False): dag_run: DagRun | DagRunPydantic data_interval_end: DateTime data_interval_start: DateTime + dataset_events: DatasetEventAccessors ds: str ds_nodash: str exception: BaseException | str | None diff --git a/docs/apache-airflow/authoring-and-scheduling/datasets.rst b/docs/apache-airflow/authoring-and-scheduling/datasets.rst index 1102420dd4656..5324a11bbc7d0 100644 --- a/docs/apache-airflow/authoring-and-scheduling/datasets.rst +++ b/docs/apache-airflow/authoring-and-scheduling/datasets.rst @@ -99,8 +99,8 @@ The identifier does not have to be absolute; it can be a scheme-less, relative U Non-absolute identifiers are considered plain strings that do not carry any semantic meanings to Airflow. -Extra information ------------------ +Extra information on Dataset +---------------------------- If needed, an extra dictionary can be included in a Dataset: @@ -111,7 +111,7 @@ If needed, an extra dictionary can be included in a Dataset: extra={"team": "trainees"}, ) -This extra information does not affect a dataset's identity. This means a DAG will be triggered by a dataset with an identical URI, even if the extra dict is different: +This can be used to supply custom description to the dataset, such as who has ownership to the target file, or what the file is for. The extra information does not affect a dataset's identity. This means a DAG will be triggered by a dataset with an identical URI, even if the extra dict is different: .. code-block:: python @@ -224,6 +224,29 @@ If one dataset is updated multiple times before all consumed datasets have been } +Attaching extra information to an emitting Dataset Event +-------------------------------------------------------- + +.. versionadded:: 2.10.0 + +A task with a dataset outlet can optionally attach extra information before it emits a dataset event. This is different +from `Extra information on Dataset`_. Extra information on a dataset statically describes the entity pointed to by the dataset URI; extra information on the *dataset event* instead should be used to annotate the triggering data change, such as how many rows in the database are changed by the update, or the date range covered by it. + +The easiest way to attach extra information to the dataset event is by accessing ``dataset_events`` in a task's execution context: + +.. code-block:: python + + example_s3_dataset = Dataset("s3://dataset/example.csv") + + + @task(outlets=[example_s3_dataset]) + def write_to_s3(*, dataset_events): + df = ... # Get a Pandas DataFrame to write. + # Write df to dataset... + dataset_events[example_s3_dataset].extras = {"row_count": len(df)} + +This can also be done in classic operators by either subclassing the operator and overriding ``execute``, or by supplying a pre- or post-execution function. + Fetching information from a Triggering Dataset Event ---------------------------------------------------- @@ -234,7 +257,7 @@ Example: .. code-block:: python - example_snowflake_dataset = Dataset("snowflake://my_db.my_schema.my_table") + example_snowflake_dataset = Dataset("snowflake://my_db/my_schema/my_table") with DAG(dag_id="load_snowflake_data", schedule="@hourly", ...): SQLExecuteQueryOperator( diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 120856dbbfd24..7a28cbee5ccfe 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -35,6 +35,7 @@ import pendulum import pytest import time_machine +from sqlalchemy import select from airflow import settings from airflow.decorators import task, task_group @@ -2281,7 +2282,7 @@ def raise_an_exception(placeholder: int): task_instance.run() assert task_instance.current_state() == TaskInstanceState.SUCCESS - def test_outlet_datasets_skipped(self, create_task_instance): + def test_outlet_datasets_skipped(self): """ Verify that when we have an outlet dataset on a task, and the task is skipped, a DatasetDagRunQueue is not logged, and a DatasetEvent is @@ -2311,7 +2312,48 @@ def test_outlet_datasets_skipped(self, create_task_instance): # check that no dataset events were generated assert session.query(DatasetEvent).count() == 0 - def test_changing_of_dataset_when_ddrq_is_already_populated(self, dag_maker, session): + def test_outlet_dataset_extra(self, dag_maker, session): + from airflow.datasets import Dataset + + with dag_maker(schedule=None, session=session): + + @task(outlets=Dataset("test_outlet_dataset_extra")) + def write(*, dataset_events): + dataset_events["test_outlet_dataset_extra"].extra = {"foo": "bar"} + + write() + + dr: DagRun = dag_maker.create_dagrun() + dr.get_task_instance("write").run(session=session) + + event = session.scalars(select(DatasetEvent)).one() + assert event.source_dag_id == dr.dag_id + assert event.source_run_id == dr.run_id + assert event.source_task_id == "write" + assert event.extra == {"foo": "bar"} + + def test_outlet_dataset_extra_ignore_different(self, dag_maker, session): + from airflow.datasets import Dataset + + with dag_maker(schedule=None, session=session): + + @task(outlets=Dataset("test_outlet_dataset_extra")) + def write(*, dataset_events): + dataset_events["test_outlet_dataset_extra"].extra = {"one": 1} + dataset_events["different_uri"].extra = {"foo": "bar"} # Will be silently dropped. + + write() + + dr: DagRun = dag_maker.create_dagrun() + dr.get_task_instance("write").run(session=session) + + event = session.scalars(select(DatasetEvent)).one() + assert event.source_dag_id == dr.dag_id + assert event.source_run_id == dr.run_id + assert event.source_task_id == "write" + assert event.extra == {"one": 1} + + def test_changing_of_dataset_when_ddrq_is_already_populated(self, dag_maker): """ Test that when a task that produces dataset has ran, that changing the consumer dag dataset will not cause primary key blank-out diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py index 578302a83646e..b8876f97ec676 100644 --- a/tests/operators/test_python.py +++ b/tests/operators/test_python.py @@ -834,6 +834,7 @@ def test_virtualenv_serializable_context_fields(self, create_task_instance): "ti", "var", # Accessor for Variable; var->json and var->value. "conn", # Accessor for Connection. + "dataset_events", # Accessor for DatasetEvent. ] ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None) From 3c71429134e9d1e12640753c1f651012510f1b47 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 28 Mar 2024 18:01:59 +0800 Subject: [PATCH 2/5] Ad dataset_events to documentation --- docs/apache-airflow/templates-ref.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/apache-airflow/templates-ref.rst b/docs/apache-airflow/templates-ref.rst index dd05fcc831c70..4d3014268b44e 100644 --- a/docs/apache-airflow/templates-ref.rst +++ b/docs/apache-airflow/templates-ref.rst @@ -74,6 +74,8 @@ Variable Type Description ``{{ var.value }}`` Airflow variables. See `Airflow Variables in Templates`_ below. ``{{ var.json }}`` Airflow variables. See `Airflow Variables in Templates`_ below. ``{{ conn }}`` Airflow connections. See `Airflow Connections in Templates`_ below. +``{{ dataset_events }}`` dict[str, ...] | Accessors to attach information to dataset events that will be emitted by the current task. + | See :doc:`Datasets `. Added in version 2.10. ``{{ task_instance_key_str }}`` str | A unique, human-readable key to the task instance. The format is | ``{dag_id}__{task_id}__{ds_nodash}``. ``{{ conf }}`` AirflowConfigParser | The full configuration object representing the content of your From 84983d7b36384490c8a2461a46b91d6285be2600 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 28 Mar 2024 18:14:53 +0800 Subject: [PATCH 3/5] Test using post_execute to access dataset_events --- tests/models/test_taskinstance.py | 43 +++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 7a28cbee5ccfe..e6187311429c9 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2315,22 +2315,43 @@ def test_outlet_datasets_skipped(self): def test_outlet_dataset_extra(self, dag_maker, session): from airflow.datasets import Dataset - with dag_maker(schedule=None, session=session): + with dag_maker(schedule=None, session=session) as dag: - @task(outlets=Dataset("test_outlet_dataset_extra")) - def write(*, dataset_events): - dataset_events["test_outlet_dataset_extra"].extra = {"foo": "bar"} + @task(outlets=Dataset("test_outlet_dataset_extra_1")) + def write1(*, dataset_events): + dataset_events["test_outlet_dataset_extra_1"].extra = {"foo": "bar"} - write() + write1() + + def _write2_post_execute(context, _): + context["dataset_events"]["test_outlet_dataset_extra_2"].extra = {"x": 1} + + BashOperator( + task_id="write2", + bash_command=":", + outlets=Dataset("test_outlet_dataset_extra_2"), + post_execute=_write2_post_execute, + ) dr: DagRun = dag_maker.create_dagrun() - dr.get_task_instance("write").run(session=session) + for ti in dr.get_task_instances(session=session): + ti.refresh_from_task(dag.get_task(ti.task_id)) + ti.run(session=session) - event = session.scalars(select(DatasetEvent)).one() - assert event.source_dag_id == dr.dag_id - assert event.source_run_id == dr.run_id - assert event.source_task_id == "write" - assert event.extra == {"foo": "bar"} + events = dict(iter(session.execute(select(DatasetEvent.source_task_id, DatasetEvent)))) + assert set(events) == {"write1", "write2"} + + assert events["write1"].source_dag_id == dr.dag_id + assert events["write1"].source_run_id == dr.run_id + assert events["write1"].source_task_id == "write1" + assert events["write1"].dataset.uri == "test_outlet_dataset_extra_1" + assert events["write1"].extra == {"foo": "bar"} + + assert events["write2"].source_dag_id == dr.dag_id + assert events["write2"].source_run_id == dr.run_id + assert events["write2"].source_task_id == "write2" + assert events["write2"].dataset.uri == "test_outlet_dataset_extra_2" + assert events["write2"].extra == {"x": 1} def test_outlet_dataset_extra_ignore_different(self, dag_maker, session): from airflow.datasets import Dataset From 2996db33e67d2038f3c095cda5d0696ca665d9c2 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 28 Mar 2024 18:21:55 +0800 Subject: [PATCH 4/5] Enable context variable pre-commit check --- .pre-commit-config.yaml | 5 + contributing-docs/08_static_code_checks.rst | 2 + .../doc/images/output_static-checks.svg | 114 +++++++++--------- .../doc/images/output_static-checks.txt | 2 +- .../src/airflow_breeze/pre_commit_ids.py | 1 + .../pre_commit_template_context_key_sync.py | 0 6 files changed, 68 insertions(+), 56 deletions(-) mode change 100644 => 100755 scripts/ci/pre_commit/pre_commit_template_context_key_sync.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1829f9b200492..2f347b1c8851c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -629,6 +629,11 @@ repos: entry: ./scripts/ci/pre_commit/pre_commit_sync_init_decorator.py pass_filenames: false files: ^airflow/models/dag\.py$|^airflow/(?:decorators|utils)/task_group\.py$ + - id: check-template-context-variable-in-sync + name: Check all template context variable references are in sync + language: python + entry: ./scripts/ci/pre_commit/pre_commit_template_context_key_sync.py + files: ^airflow/models/taskinstance\.py$|^airflow/utils/context\.pyi?$|^docs/apache-airflow/templates-ref\.rst$ - id: check-base-operator-usage language: pygrep name: Check BaseOperator core imports diff --git a/contributing-docs/08_static_code_checks.rst b/contributing-docs/08_static_code_checks.rst index 0b331bf3e95b4..c7be51b6a78fe 100644 --- a/contributing-docs/08_static_code_checks.rst +++ b/contributing-docs/08_static_code_checks.rst @@ -222,6 +222,8 @@ require Breeze Docker image to be built locally. +-----------------------------------------------------------+--------------------------------------------------------------+---------+ | check-system-tests-tocs | Check that system tests is properly added | | +-----------------------------------------------------------+--------------------------------------------------------------+---------+ +| check-template-context-variable-in-sync | Check all template context variable references are in sync | | ++-----------------------------------------------------------+--------------------------------------------------------------+---------+ | check-tests-in-the-right-folders | Check if tests are in the right folders | | +-----------------------------------------------------------+--------------------------------------------------------------+---------+ | check-tests-unittest-testcase | Check that unit tests do not inherit from unittest.TestCase | | diff --git a/dev/breeze/doc/images/output_static-checks.svg b/dev/breeze/doc/images/output_static-checks.svg index 679db3dfeec89..a709f8070c9d9 100644 --- a/dev/breeze/doc/images/output_static-checks.svg +++ b/dev/breeze/doc/images/output_static-checks.svg @@ -1,4 +1,4 @@ - +