From c0fce0ddba2e4c466cf0ee7eebbcb2dfd34e03a2 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 24 Jan 2024 15:22:43 -0800 Subject: [PATCH 01/23] Enable conditional logic for datasets --- airflow/models/dag.py | 94 ++++++++----- airflow/models/dataset.py | 49 ++++++- airflow/serialization/enums.py | 2 + airflow/serialization/schema.json | 36 ++++- airflow/serialization/serialized_objects.py | 17 ++- tests/datasets/test_dataset.py | 133 ++++++++++++++++++ tests/serialization/test_dag_serialization.py | 1 - 7 files changed, 292 insertions(+), 40 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index dd4356865724c..8f4b1f4be5188 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -18,7 +18,6 @@ from __future__ import annotations import asyncio -import collections import copy import functools import itertools @@ -31,7 +30,7 @@ import traceback import warnings import weakref -from collections import deque +from collections import abc, defaultdict, deque from contextlib import ExitStack from datetime import datetime, timedelta from inspect import signature @@ -99,6 +98,13 @@ from airflow.models.dagcode import DagCode from airflow.models.dagpickle import DagPickle from airflow.models.dagrun import RUN_ID_REGEX, DagRun +from airflow.models.dataset import ( + DatasetAll, + DatasetAny, + DatasetBooleanCondition, + DatasetDagRunQueue, + DatasetModel, +) from airflow.models.param import DagParam, ParamsDict from airflow.models.taskinstance import ( Context, @@ -462,7 +468,7 @@ def __init__( on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, doc_md: str | None = None, - params: collections.abc.MutableMapping | None = None, + params: abc.MutableMapping | None = None, access_control: dict | None = None, is_paused_upon_creation: bool | None = None, jinja_environment_kwargs: dict | None = None, @@ -580,14 +586,15 @@ def __init__( self.timetable: Timetable self.schedule_interval: ScheduleInterval - self.dataset_triggers: Collection[Dataset] = [] - + self.dataset_triggers: DatasetBooleanCondition | None = None + if isinstance(schedule, (DatasetAll, DatasetAny)): + self.dataset_triggers = schedule if isinstance(schedule, Collection) and not isinstance(schedule, str): from airflow.datasets import Dataset if not all(isinstance(x, Dataset) for x in schedule): raise ValueError("All elements in 'schedule' should be datasets") - self.dataset_triggers = list(schedule) + self.dataset_triggers = DatasetAll(*schedule) elif isinstance(schedule, Timetable): timetable = schedule elif schedule is not NOTSET: @@ -3156,8 +3163,8 @@ def bulk_write_to_db( TaskOutletDatasetReference, ) - dag_references = collections.defaultdict(set) - outlet_references = collections.defaultdict(set) + dag_references = defaultdict(set) + outlet_references = defaultdict(set) # We can't use a set here as we want to preserve order outlet_datasets: dict[DatasetModel, None] = {} input_datasets: dict[DatasetModel, None] = {} @@ -3168,12 +3175,13 @@ def bulk_write_to_db( # later we'll persist them to the database. for dag in dags: curr_orm_dag = existing_dags.get(dag.dag_id) - if not dag.dataset_triggers: + if dag.dataset_triggers is None: if curr_orm_dag and curr_orm_dag.schedule_dataset_references: curr_orm_dag.schedule_dataset_references = [] - for dataset in dag.dataset_triggers: - dag_references[dag.dag_id].add(dataset.uri) - input_datasets[DatasetModel.from_public(dataset)] = None + else: + for dataset in dag.dataset_triggers.all_datasets().values(): + dag_references[dag.dag_id].add(dataset.uri) + input_datasets[DatasetModel.from_public(dataset)] = None curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references for task in dag.tasks: dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)] @@ -3229,7 +3237,7 @@ def bulk_write_to_db( for obj in dag_refs_stored - dag_refs_needed: session.delete(obj) - existing_task_outlet_refs_dict = collections.defaultdict(set) + existing_task_outlet_refs_dict = defaultdict(set) for dag_id, orm_dag in existing_dags.items(): for todr in orm_dag.task_outlet_dataset_references: existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr) @@ -3512,7 +3520,7 @@ def __repr__(self): @classmethod def get_all(cls, session) -> dict[str, dict[str, str]]: - dag_links: dict = collections.defaultdict(dict) + dag_links: dict = defaultdict(dict) for obj in session.scalars(select(cls)): dag_links[obj.dag_id].update({obj.owner: obj.link}) return dag_links @@ -3781,23 +3789,43 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[ you should ensure that any scheduling decisions are made in a single transaction -- as soon as the transaction is committed it will be unlocked. """ - from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ - - # these dag ids are triggered by datasets, and they are ready to go. - dataset_triggered_dag_info = { - x.dag_id: (x.first_queued_time, x.last_queued_time) - for x in session.execute( - select( - DagScheduleDatasetReference.dag_id, - func.max(DDRQ.created_at).label("last_queued_time"), - func.min(DDRQ.created_at).label("first_queued_time"), - ) - .join(DagScheduleDatasetReference.queue_records, isouter=True) - .group_by(DagScheduleDatasetReference.dag_id) - .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0))) - ) - } - dataset_triggered_dag_ids = set(dataset_triggered_dag_info) + from airflow.models.serialized_dag import SerializedDagModel + + def dag_ready(dag_id: str, cond: DatasetBooleanCondition, statuses: dict) -> bool | None: + # if dag was serialized before 2.9 and we *just* upgraded, + # we may be dealing with old version. In that case, + # just wait for the dag to be reserialized. + try: + return cond.evaluate(statuses) + except AttributeError: + log.warning("dag '%s' has old serialization; skipping dag run creation.", dag_id) + return None + + # this loads all the DDRQ records.... may need to limit num dags + all_records = session.scalars(select(DatasetDagRunQueue)).all() + by_dag = defaultdict(list) + for r in all_records: + by_dag[r.target_dag_id].append(r) + del all_records + dag_statuses = {} + for dag_id, records in by_dag.items(): + dag_statuses[dag_id] = {x.dataset.uri: True for x in records} + ser_dags = session.scalars( + select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys())) + ).all() + for ser_dag in ser_dags: + dag_id = ser_dag.dag_id + statuses = dag_statuses[dag_id] + if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses): + del by_dag[dag_id] + del dag_statuses[dag_id] + del dag_statuses + dataset_triggered_dag_info = {} + for dag_id, records in by_dag.items(): + times = sorted(x.created_at for x in records) + dataset_triggered_dag_info[dag_id] = (times[0], times[-1]) + del by_dag + dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys()) if dataset_triggered_dag_ids: exclusion_list = set( session.scalars( @@ -3908,7 +3936,7 @@ def dag( on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, doc_md: str | None = None, - params: collections.abc.MutableMapping | None = None, + params: abc.MutableMapping | None = None, access_control: dict | None = None, is_paused_upon_creation: bool | None = None, jinja_environment_kwargs: dict | None = None, @@ -4030,7 +4058,7 @@ class DagContext: """ - _context_managed_dags: collections.deque[DAG] = deque() + _context_managed_dags: deque[DAG] = deque() autoregistered_dags: set[tuple[DAG, ModuleType]] = set() current_autoregister_module_name: str | None = None diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index d9dd8e4bb52e4..0082e3ca4c456 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +from typing import Callable from urllib.parse import urlsplit import sqlalchemy_jsonfield @@ -208,7 +209,7 @@ class DatasetDagRunQueue(Base): dataset_id = Column(Integer, primary_key=True, nullable=False) target_dag_id = Column(StringID(), primary_key=True, nullable=False) created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) - + dataset = relationship("DatasetModel", viewonly=True) __tablename__ = "dataset_dag_run_queue" __table_args__ = ( PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey"), @@ -336,3 +337,49 @@ def __repr__(self) -> str: ]: args.append(f"{attr}={getattr(self, attr)!r}") return f"{self.__class__.__name__}({', '.join(args)})" + + +class DatasetBooleanCondition: + """ + Base class for boolean logic for dataset triggers. + + :meta private: + """ + + agg_func: Callable + + def __init__(self, *objects): + self.objects = objects + + def evaluate(self, statuses: dict[str, bool]): + return self.agg_func(self.eval_one(x, statuses) for x in self.objects) + + def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses): + if isinstance(obj, Dataset): + return statuses.get(obj.uri, False) + return obj.evaluate(statuses=statuses) + + def all_datasets(self) -> dict[str, Dataset]: + uris = {} + for x in self.objects: + if isinstance(x, Dataset): + if x.uri not in uris: + uris[x.uri] = x + else: + # keep the first instance + for k, v in x.all_datasets().items(): + if k not in uris: + uris[k] = v + return uris + + +class DatasetAny(DatasetBooleanCondition): + """Use to combine datasets schedule references in an "and" relationship.""" + + agg_func = any + + +class DatasetAll(DatasetBooleanCondition): + """Use to combine datasets schedule references in an "or" relationship.""" + + agg_func = all diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index 4f95c849c8fb2..963dec580ef5c 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -50,6 +50,8 @@ class DagAttributeTypes(str, Enum): PARAM = "param" XCOM_REF = "xcomref" DATASET = "dataset" + DATASET_ANY = "dataset_any" + DATASET_ALL = "dataset_all" SIMPLE_TASK_INSTANCE = "simple_task_instance" BASE_JOB = "Job" TASK_INSTANCE = "task_instance" diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index ae7121fd14442..71ee0c80063c0 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -81,6 +81,36 @@ ], "additionalProperties": false }, + "typed_dataset_cond": { + "type": "object", + "properties": { + "__type": { + "anyOf": [{ + "type": "string", + "constant": "dataset_or" + }, + { + "type": "string", + "constant": "dataset_and" + } + ] + }, + "__var": { + "type": "array", + "items": { + "anyOf": [ + {"$ref": "#/definitions/typed_dataset"}, + { "$ref": "#/definitions/typed_dataset_cond"} + ] + } + } + }, + "required": [ + "__type", + "__var" + ], + "additionalProperties": false + }, "dict": { "description": "A python dictionary containing values of any type", "type": "object" @@ -119,9 +149,9 @@ ] }, "dataset_triggers": { - "type": "array", - "items": { "$ref": "#/definitions/typed_dataset" } - }, + "$ref": "#/definitions/typed_dataset_cond" + +}, "owner_links": { "type": "object" }, "timetable": { "type": "object", diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 7adddbab10b00..34f79bf8fd76a 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -42,6 +42,7 @@ from airflow.models.connection import Connection from airflow.models.dag import DAG, DagModel, create_timetable from airflow.models.dagrun import DagRun +from airflow.models.dataset import DatasetAll, DatasetAny from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key from airflow.models.mappedoperator import MappedOperator from airflow.models.param import Param, ParamsDict @@ -404,6 +405,8 @@ def serialize_to_json( serialized_object[key] = cls.serialize(value) elif key == "timetable" and value is not None: serialized_object[key] = encode_timetable(value) + elif key == "dataset_triggers": + serialized_object[key] = cls.serialize(value) else: value = cls.serialize(value) if isinstance(value, dict) and Encoding.TYPE in value: @@ -497,6 +500,10 @@ def serialize( return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF) elif isinstance(var, Dataset): return cls._encode({"uri": var.uri, "extra": var.extra}, type_=DAT.DATASET) + elif isinstance(var, DatasetAll): + return cls._encode([cls.serialize(x) for x in var.objects], type_=DAT.DATASET_ALL) + elif isinstance(var, DatasetAny): + return cls._encode([cls.serialize(x) for x in var.objects], type_=DAT.DATASET_ANY) elif isinstance(var, SimpleTaskInstance): return cls._encode( cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models), @@ -587,6 +594,10 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any: return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG. elif type_ == DAT.DATASET: return Dataset(**var) + elif type_ == DAT.DATASET_ANY: + return DatasetAny(*(cls.deserialize(x) for x in var)) + elif type_ == DAT.DATASET_ALL: + return DatasetAll(*(cls.deserialize(x) for x in var)) elif type_ == DAT.SIMPLE_TASK_INSTANCE: return SimpleTaskInstance(**cls.deserialize(var)) elif type_ == DAT.CONNECTION: @@ -763,12 +774,14 @@ def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]: """Detect dependencies set directly on the DAG object.""" if not dag: return - for x in dag.dataset_triggers: + if not dag.dataset_triggers: + return + for uri in dag.dataset_triggers.all_datasets().keys(): yield DagDependency( source="dataset", target=dag.dag_id, dependency_type="dataset", - dependency_id=x.uri, + dependency_id=uri, ) diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index 9e9ca9951315d..47d3829cea1f6 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -18,11 +18,17 @@ from __future__ import annotations import os +from collections import defaultdict import pytest +from sqlalchemy.sql import select from airflow.datasets import Dataset +from airflow.models import DagModel +from airflow.models.dataset import DatasetAll, DatasetAny, DatasetDagRunQueue, DatasetModel +from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator +from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG @pytest.mark.parametrize( @@ -54,3 +60,130 @@ def test_fspath(): uri = "s3://example_dataset" dataset = Dataset(uri=uri) assert os.fspath(dataset) == uri + + +@pytest.mark.parametrize( + "input", + [ + (True, True, True), + (True, True, False), + (True, False, True), + (True, False, False), + (False, False, True), + (False, False, False), + (False, True, True), + (False, True, False), + ], +) +@pytest.mark.parametrize("scenario", ["any", "all"]) +def test_dataset_cond(input, scenario): + if scenario == "any": + expected = any(input) + class_ = DatasetAny + else: + expected = all(input) + class_ = DatasetAll + d1 = Dataset(uri="s3://abc/123") + d2 = Dataset(uri="s3://abc/124") + d3 = Dataset(uri="s3://abc/125") + d_cond = class_(d1, d2, d3) + ser_d_cond = BaseSerialization.serialize(d_cond) + deser_d_cond = BaseSerialization.deserialize(ser_d_cond) + statuses = {d1.uri: input[0], d2.uri: input[1], d3.uri: input[2]} + assert d_cond.evaluate(statuses) == expected + assert deser_d_cond.evaluate(statuses) == expected + + +@pytest.mark.parametrize( + "input, expected", + [ + ((False, True, True), False), + ((True, True, True), True), + ((True, False, True), True), + ((True, False, False), False), + ], +) +def test_dataset_cond_nested(input, expected): + d1 = Dataset(uri="s3://abc/123") + d2 = Dataset(uri="s3://abc/124") + d3 = Dataset(uri="s3://abc/125") + sub_cond = DatasetAny(d2, d3) + d_cond = DatasetAll(d1, sub_cond) + statuses = {d1.uri: input[0], d2.uri: input[1], d3.uri: input[2]} + ser_d_cond = BaseSerialization.serialize(d_cond) + deser_d_cond = BaseSerialization.deserialize(ser_d_cond) + assert d_cond.evaluate(statuses) == expected + assert deser_d_cond.evaluate(statuses) == expected + + +def test_this(session, dag_maker): + d1 = Dataset(uri="hello1") + d1.uri + dm1 = DatasetModel(uri=d1.uri) + d2 = Dataset(uri="hello2") + dm2 = DatasetModel(uri=d2.uri) + session.add(dm1) + session.add(dm2) + session.commit() + session.query(DagModel).all() + d1.uri + with dag_maker(schedule=DatasetAny(d1, d2)) as dag: + op = EmptyOperator(task_id="hello") + dag.dataset_triggers + dm1.id + ddrq = DatasetDagRunQueue(dataset_id=dm1.id, target_dag_id=dag.dag_id) + session.add(ddrq) + assert isinstance(dag.dataset_triggers, DatasetAny) + SerializedDAG.serialize_to_json(dag, SerializedDAG._decorated_fields) + SerializedDAG.serialize(dag.dataset_triggers).values() + dtr = SerializedDAG.to_dict(dag)["dag"]["dataset_triggers"] + assert isinstance(dtr, dict) + deser_dtr = SerializedDAG.deserialize(dtr) + assert isinstance(deser_dtr, DatasetAny) + assert deser_dtr.objects == dag.dataset_triggers.objects + SerializedDagModel.write_dag(dag) + session.commit() + with dag_maker(dag_id="dag2") as dag2: + op = EmptyOperator(task_id="hello2") + + # here we start the scheduling logic + records = session.scalars(select(DatasetDagRunQueue)).all() + dag_statuses = defaultdict(dict) + ddrq_times = defaultdict(list) + for ddrq in records: + dag_statuses[ddrq.target_dag_id][ddrq.dataset.uri] = True + ddrq_times[ddrq.target_dag_id].append(ddrq.created_at) + dataset_triggered_dag_info = {dag_id: (min(times), max(times)) for dag_id, times in ddrq_times.items()} + ser_dags = session.execute( + select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys())) + ).all() + for (ser_dag,) in ser_dags: + print(ser_dag) + statuses = dag_statuses[ser_dag.dag_id] + ser_dag.dag.dataset_triggers.evaluate(statuses) + + +@pytest.fixture(autouse=True) +def clear_datasets(): + from tests.test_utils.db import clear_db_datasets + + clear_db_datasets() + + +def test_this2(session, dag_maker): + d1 = Dataset(uri="hello1") + d1.uri + dm1 = DatasetModel(uri=d1.uri) + d2 = Dataset(uri="hello2") + dm2 = DatasetModel(uri=d2.uri) + session.add(dm1) + session.add(dm2) + session.commit() + session.query(DagModel).all() + d1.uri + with dag_maker(schedule=DatasetAny(d1, DatasetAll(d2, d1))) as dag: + op = EmptyOperator(task_id="hello") + dag.dataset_triggers + SerializedDAG.serialize_to_json(dag, SerializedDAG._decorated_fields) + SerializedDAG.serialize(dag.dataset_triggers).values() + dtr = SerializedDAG.to_dict(dag)["dag"]["dataset_triggers"] diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 8a122592fd5e0..6b3f3bcef8151 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -212,7 +212,6 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i }, ], "schedule_interval": {"__type": "timedelta", "__var": 86400.0}, - "dataset_triggers": [], "timezone": "UTC", "_access_control": { "__type": "dict", From 63aa9a2dc5b388c464bd2b76290df8393d153410 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 31 Jan 2024 15:39:24 -0800 Subject: [PATCH 02/23] fix static checks --- tests/datasets/test_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index 47d3829cea1f6..6da8134ef6296 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -128,7 +128,7 @@ def test_this(session, dag_maker): session.query(DagModel).all() d1.uri with dag_maker(schedule=DatasetAny(d1, d2)) as dag: - op = EmptyOperator(task_id="hello") + EmptyOperator(task_id="hello") dag.dataset_triggers dm1.id ddrq = DatasetDagRunQueue(dataset_id=dm1.id, target_dag_id=dag.dag_id) @@ -143,8 +143,8 @@ def test_this(session, dag_maker): assert deser_dtr.objects == dag.dataset_triggers.objects SerializedDagModel.write_dag(dag) session.commit() - with dag_maker(dag_id="dag2") as dag2: - op = EmptyOperator(task_id="hello2") + with dag_maker(dag_id="dag2"): + EmptyOperator(task_id="hello2") # here we start the scheduling logic records = session.scalars(select(DatasetDagRunQueue)).all() @@ -153,7 +153,7 @@ def test_this(session, dag_maker): for ddrq in records: dag_statuses[ddrq.target_dag_id][ddrq.dataset.uri] = True ddrq_times[ddrq.target_dag_id].append(ddrq.created_at) - dataset_triggered_dag_info = {dag_id: (min(times), max(times)) for dag_id, times in ddrq_times.items()} + # dataset_triggered_dag_info = {dag_id: (min(times), max(times)) for dag_id, times in ddrq_times.items()} ser_dags = session.execute( select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys())) ).all() @@ -182,8 +182,8 @@ def test_this2(session, dag_maker): session.query(DagModel).all() d1.uri with dag_maker(schedule=DatasetAny(d1, DatasetAll(d2, d1))) as dag: - op = EmptyOperator(task_id="hello") + EmptyOperator(task_id="hello") dag.dataset_triggers SerializedDAG.serialize_to_json(dag, SerializedDAG._decorated_fields) SerializedDAG.serialize(dag.dataset_triggers).values() - dtr = SerializedDAG.to_dict(dag)["dag"]["dataset_triggers"] + SerializedDAG.to_dict(dag)["dag"]["dataset_triggers"] From ca878575307b2d6bfc629b4e2e6b8b627db09807 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 1 Feb 2024 19:27:03 -0800 Subject: [PATCH 03/23] update after rebase onto dataset-or-timetable change --- airflow/models/dag.py | 8 +++++--- airflow/timetables/datasets.py | 29 ++++++++++++++--------------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 8f4b1f4be5188..237759010ac22 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -600,12 +600,14 @@ def __init__( elif schedule is not NOTSET: schedule_interval = schedule - if self.dataset_triggers: + if isinstance(schedule, DatasetOrTimeSchedule): + self.timetable = schedule + self.dataset_triggers = self.timetable.datasets + self.schedule_interval = self.timetable.summary + elif self.dataset_triggers: self.timetable = DatasetTriggeredTimetable() self.schedule_interval = self.timetable.summary elif timetable: - if isinstance(timetable, DatasetOrTimeSchedule): - self.dataset_triggers = timetable.datasets self.timetable = timetable self.schedule_interval = self.timetable.summary else: diff --git a/airflow/timetables/datasets.py b/airflow/timetables/datasets.py index 4904c64e9cd16..935a852e93abd 100644 --- a/airflow/timetables/datasets.py +++ b/airflow/timetables/datasets.py @@ -17,28 +17,31 @@ from __future__ import annotations -import collections.abc import typing -import attrs - -from airflow.datasets import Dataset from airflow.exceptions import AirflowTimetableInvalid +from airflow.models.dataset import DatasetAll, DatasetBooleanCondition from airflow.timetables.simple import DatasetTriggeredTimetable as DatasetTriggeredSchedule from airflow.utils.types import DagRunType if typing.TYPE_CHECKING: + from collections.abc import Collection + import pendulum + from airflow.datasets import Dataset from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable class DatasetOrTimeSchedule(DatasetTriggeredSchedule): """Combine time-based scheduling with event-based scheduling.""" - def __init__(self, timetable: Timetable, datasets: collections.abc.Collection[Dataset]) -> None: + def __init__(self, timetable: Timetable, datasets: Collection[Dataset] | DatasetBooleanCondition) -> None: self.timetable = timetable - self.datasets = datasets + if isinstance(datasets, DatasetBooleanCondition): + self.datasets = datasets + else: + self.datasets = DatasetAll(*datasets) self.description = f"Triggered by datasets or {timetable.description}" self.periodic = timetable.periodic @@ -52,24 +55,20 @@ def deserialize(cls, data: dict[str, typing.Any]) -> Timetable: from airflow.serialization.serialized_objects import decode_timetable return cls( - timetable=decode_timetable(data["timetable"]), datasets=[Dataset(**d) for d in data["datasets"]] + timetable=decode_timetable(data["timetable"]), + datasets=[], # don't need the datasets after deserialization ) def serialize(self) -> dict[str, typing.Any]: from airflow.serialization.serialized_objects import encode_timetable - return { - "timetable": encode_timetable(self.timetable), - "datasets": [attrs.asdict(e) for e in self.datasets], - } + return {"timetable": encode_timetable(self.timetable)} def validate(self) -> None: if isinstance(self.timetable, DatasetTriggeredSchedule): raise AirflowTimetableInvalid("cannot nest dataset timetables") - if not isinstance(self.datasets, collections.abc.Collection) or not all( - isinstance(d, Dataset) for d in self.datasets - ): - raise AirflowTimetableInvalid("all elements in 'event' must be datasets") + if not isinstance(self.datasets, DatasetBooleanCondition): + raise AirflowTimetableInvalid("all elements in 'datasets' must be datasets") @property def summary(self) -> str: From c51caa25316fafb19ce630a3b73b6ee92c5010b3 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 6 Feb 2024 02:01:52 +0530 Subject: [PATCH 04/23] Mock get_next_workday in AfterWorkdayTimetable --- tests/cli/commands/test_dag_command.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index 0df2c36f7d449..ca47309721b6a 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -392,16 +392,24 @@ def test_cli_backfill_depends_on_past_backwards(self, mock_run): disable_retry=False, ) - @mock.patch("workday.AfterWorkdayTimetable") + @mock.patch("workday.AfterWorkdayTimetable.get_next_workday") @mock.patch("airflow.models.taskinstance.TaskInstance.dry_run") @mock.patch("airflow.cli.commands.dag_command.DagRun") - def test_backfill_with_custom_timetable(self, mock_dagrun, mock_dry_run, mock_AfterWorkdayTimetable): + def test_backfill_with_custom_timetable(self, mock_dagrun, mock_dry_run, mock_get_next_workday): """ when calling `dags backfill` on dag with custom timetable, the DagRun object should be created with data_intervals. """ + start_date = DEFAULT_DATE + timedelta(days=1) end_date = start_date + timedelta(days=1) + workdays = [ + start_date, + start_date + timedelta(days=1), + start_date + timedelta(days=2), + ] + mock_get_next_workday.side_effect = workdays + cli_args = self.parser.parse_args( [ "dags", From a511556ff7611e2bc540bcc461323df23329dfde Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 6 Feb 2024 20:08:48 +0530 Subject: [PATCH 05/23] Fix the tests for dataset models --- tests/datasets/test_dataset.py | 256 +++++++++++++++++++++------------ 1 file changed, 166 insertions(+), 90 deletions(-) diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index 6da8134ef6296..bcb3f57094b1c 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -24,7 +24,6 @@ from sqlalchemy.sql import select from airflow.datasets import Dataset -from airflow.models import DagModel from airflow.models.dataset import DatasetAll, DatasetAny, DatasetDagRunQueue, DatasetModel from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator @@ -62,105 +61,152 @@ def test_fspath(): assert os.fspath(dataset) == uri +@pytest.mark.db_test @pytest.mark.parametrize( - "input", + "inputs, scenario, expected", [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, False, True), - (False, False, False), - (False, True, True), - (False, True, False), + # Scenarios for DatasetAny + ((True, True, True), "any", True), + ((True, True, False), "any", True), + ((True, False, True), "any", True), + ((True, False, False), "any", True), + ((False, False, True), "any", True), + ((False, True, False), "any", True), + ((False, True, True), "any", True), + ((False, False, False), "any", False), + # Scenarios for DatasetAll + ((True, True, True), "all", True), + ((True, True, False), "all", False), + ((True, False, True), "all", False), + ((True, False, False), "all", False), + ((False, False, True), "all", False), + ((False, True, False), "all", False), + ((False, True, True), "all", False), + ((False, False, False), "all", False), ], ) -@pytest.mark.parametrize("scenario", ["any", "all"]) -def test_dataset_cond(input, scenario): - if scenario == "any": - expected = any(input) - class_ = DatasetAny - else: - expected = all(input) - class_ = DatasetAll - d1 = Dataset(uri="s3://abc/123") - d2 = Dataset(uri="s3://abc/124") - d3 = Dataset(uri="s3://abc/125") - d_cond = class_(d1, d2, d3) - ser_d_cond = BaseSerialization.serialize(d_cond) - deser_d_cond = BaseSerialization.deserialize(ser_d_cond) - statuses = {d1.uri: input[0], d2.uri: input[1], d3.uri: input[2]} - assert d_cond.evaluate(statuses) == expected - assert deser_d_cond.evaluate(statuses) == expected +def test_dataset_logical_conditions_evaluation_and_serialization(inputs, scenario, expected): + class_ = DatasetAny if scenario == "any" else DatasetAll + datasets = [Dataset(uri=f"s3://abc/{i}") for i in range(123, 126)] + condition = class_(*datasets) + statuses = {dataset.uri: status for dataset, status in zip(datasets, inputs)} + assert ( + condition.evaluate(statuses) == expected + ), f"Condition evaluation failed for inputs {inputs} and scenario '{scenario}'" + # Serialize and deserialize the condition to test persistence + serialized = BaseSerialization.serialize(condition) + deserialized = BaseSerialization.deserialize(serialized) + assert deserialized.evaluate(statuses) == expected, "Serialization round-trip failed" + + +@pytest.mark.db_test @pytest.mark.parametrize( - "input, expected", + "status_values, expected_evaluation", [ - ((False, True, True), False), - ((True, True, True), True), - ((True, False, True), True), - ((True, False, False), False), + ((False, True, True), False), # DatasetAll requires all conditions to be True, but d1 is False + ((True, True, True), True), # All conditions are True + ((True, False, True), True), # d1 is True, and DatasetAny condition (d2 or d3 being True) is met + ((True, False, False), False), # d1 is True, but neither d2 nor d3 meet the DatasetAny condition ], ) -def test_dataset_cond_nested(input, expected): +def test_nested_dataset_conditions_with_serialization(status_values, expected_evaluation): + # Define datasets d1 = Dataset(uri="s3://abc/123") d2 = Dataset(uri="s3://abc/124") d3 = Dataset(uri="s3://abc/125") - sub_cond = DatasetAny(d2, d3) - d_cond = DatasetAll(d1, sub_cond) - statuses = {d1.uri: input[0], d2.uri: input[1], d3.uri: input[2]} - ser_d_cond = BaseSerialization.serialize(d_cond) - deser_d_cond = BaseSerialization.deserialize(ser_d_cond) - assert d_cond.evaluate(statuses) == expected - assert deser_d_cond.evaluate(statuses) == expected + # Create a nested condition: DatasetAll with d1 and DatasetAny with d2 and d3 + nested_condition = DatasetAll(d1, DatasetAny(d2, d3)) -def test_this(session, dag_maker): - d1 = Dataset(uri="hello1") - d1.uri - dm1 = DatasetModel(uri=d1.uri) - d2 = Dataset(uri="hello2") - dm2 = DatasetModel(uri=d2.uri) - session.add(dm1) - session.add(dm2) + statuses = { + d1.uri: status_values[0], + d2.uri: status_values[1], + d3.uri: status_values[2], + } + + assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch" + + serialized_condition = BaseSerialization.serialize(nested_condition) + deserialized_condition = BaseSerialization.deserialize(serialized_condition) + + assert ( + deserialized_condition.evaluate(statuses) == expected_evaluation + ), "Post-serialization evaluation mismatch" + + +@pytest.fixture +def create_test_datasets(session): + """Fixture to create test datasets and corresponding models.""" + datasets = [Dataset(uri=f"hello{i}") for i in range(1, 3)] + for dataset in datasets: + session.add(DatasetModel(uri=dataset.uri)) session.commit() - session.query(DagModel).all() - d1.uri - with dag_maker(schedule=DatasetAny(d1, d2)) as dag: + return datasets + + +@pytest.mark.db_test +def test_dataset_trigger_setup_and_serialization(session, dag_maker, create_test_datasets): + datasets = create_test_datasets + + # Create DAG with dataset triggers + with dag_maker(schedule=DatasetAny(*datasets)) as dag: EmptyOperator(task_id="hello") - dag.dataset_triggers - dm1.id - ddrq = DatasetDagRunQueue(dataset_id=dm1.id, target_dag_id=dag.dag_id) - session.add(ddrq) - assert isinstance(dag.dataset_triggers, DatasetAny) - SerializedDAG.serialize_to_json(dag, SerializedDAG._decorated_fields) - SerializedDAG.serialize(dag.dataset_triggers).values() - dtr = SerializedDAG.to_dict(dag)["dag"]["dataset_triggers"] - assert isinstance(dtr, dict) - deser_dtr = SerializedDAG.deserialize(dtr) - assert isinstance(deser_dtr, DatasetAny) - assert deser_dtr.objects == dag.dataset_triggers.objects - SerializedDagModel.write_dag(dag) + + # Verify dataset triggers are set up correctly + assert isinstance( + dag.dataset_triggers, DatasetAny + ), "DAG dataset triggers should be an instance of DatasetAny" + + # Serialize and deserialize DAG dataset triggers + serialized_trigger = SerializedDAG.serialize(dag.dataset_triggers) + deserialized_trigger = SerializedDAG.deserialize(serialized_trigger) + + # Verify serialization and deserialization integrity + assert isinstance( + deserialized_trigger, DatasetAny + ), "Deserialized trigger should maintain type DatasetAny" + assert ( + deserialized_trigger.objects == dag.dataset_triggers.objects + ), "Deserialized trigger objects should match original" + + +@pytest.mark.db_test +def test_dataset_dag_run_queue_processing(session, dag_maker, create_test_datasets): + datasets = create_test_datasets + dataset_models = session.query(DatasetModel).all() + + with dag_maker(schedule=DatasetAny(*datasets)) as dag: + EmptyOperator(task_id="hello") + + # Add DatasetDagRunQueue entries to simulate dataset event processing + for dm in dataset_models: + session.add(DatasetDagRunQueue(dataset_id=dm.id, target_dag_id=dag.dag_id)) session.commit() - with dag_maker(dag_id="dag2"): - EmptyOperator(task_id="hello2") - # here we start the scheduling logic + # Fetch and evaluate dataset triggers for all DAGs affected by dataset events records = session.scalars(select(DatasetDagRunQueue)).all() - dag_statuses = defaultdict(dict) - ddrq_times = defaultdict(list) - for ddrq in records: - dag_statuses[ddrq.target_dag_id][ddrq.dataset.uri] = True - ddrq_times[ddrq.target_dag_id].append(ddrq.created_at) - # dataset_triggered_dag_info = {dag_id: (min(times), max(times)) for dag_id, times in ddrq_times.items()} - ser_dags = session.execute( + dag_statuses = defaultdict(lambda: defaultdict(bool)) + for record in records: + dag_statuses[record.target_dag_id][record.dataset.uri] = True + + serialized_dags = session.execute( select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys())) - ).all() - for (ser_dag,) in ser_dags: - print(ser_dag) - statuses = dag_statuses[ser_dag.dag_id] - ser_dag.dag.dataset_triggers.evaluate(statuses) + ).fetchall() + + for (serialized_dag,) in serialized_dags: + dag = SerializedDAG.deserialize(serialized_dag.data) + for dataset_uri, status in dag_statuses[dag.dag_id].items(): + assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG trigger evaluation failed" + + +@pytest.mark.db_test +@pytest.mark.usefixtures("create_test_datasets") +def test_additional_dag_with_no_triggers(dag_maker): + # Create an additional DAG to ensure it's not affected by dataset triggers + with dag_maker(dag_id="dag2"): + EmptyOperator(task_id="hello2") @pytest.fixture(autouse=True) @@ -170,20 +216,50 @@ def clear_datasets(): clear_db_datasets() -def test_this2(session, dag_maker): +@pytest.fixture +def setup_datasets_and_models(session): + """Fixture to create datasets and corresponding models.""" + # Create Dataset instances d1 = Dataset(uri="hello1") - d1.uri - dm1 = DatasetModel(uri=d1.uri) d2 = Dataset(uri="hello2") + + # Create and add DatasetModel instances to the session + dm1 = DatasetModel(uri=d1.uri) dm2 = DatasetModel(uri=d2.uri) - session.add(dm1) - session.add(dm2) + session.add_all([dm1, dm2]) session.commit() - session.query(DagModel).all() - d1.uri + + return d1, d2 + + +@pytest.mark.db_test +def test_dag_with_complex_dataset_triggers(session, dag_maker, setup_datasets_and_models): + d1, d2 = setup_datasets_and_models + + # Setup a DAG with complex dataset triggers (DatasetAny with DatasetAll) with dag_maker(schedule=DatasetAny(d1, DatasetAll(d2, d1))) as dag: EmptyOperator(task_id="hello") - dag.dataset_triggers - SerializedDAG.serialize_to_json(dag, SerializedDAG._decorated_fields) - SerializedDAG.serialize(dag.dataset_triggers).values() - SerializedDAG.to_dict(dag)["dag"]["dataset_triggers"] + + assert isinstance( + dag.dataset_triggers, DatasetAny + ), "DAG's dataset trigger should be an instance of DatasetAny" + assert any( + isinstance(trigger, DatasetAll) for trigger in dag.dataset_triggers.objects + ), "DAG's dataset trigger should include DatasetAll" + + serialized_triggers = SerializedDAG.serialize(dag.dataset_triggers) + + deserialized_triggers = SerializedDAG.deserialize(serialized_triggers) + + assert isinstance( + deserialized_triggers, DatasetAny + ), "Deserialized triggers should be an instance of DatasetAny" + assert any( + isinstance(trigger, DatasetAll) for trigger in deserialized_triggers.objects + ), "Deserialized triggers should include DatasetAll" + + serialized_dag_dict = SerializedDAG.to_dict(dag)["dag"] + assert "dataset_triggers" in serialized_dag_dict, "Serialized DAG should contain 'dataset_triggers'" + assert isinstance( + serialized_dag_dict["dataset_triggers"], dict + ), "Serialized 'dataset_triggers' should be a dict" From d86cf1ab5911c56d689bf84ded52f2eb99080eea Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 6 Feb 2024 23:42:54 +0530 Subject: [PATCH 06/23] Remove clear_dataset fixture --- tests/datasets/test_dataset.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index bcb3f57094b1c..ca0f312b3e0f2 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -209,13 +209,6 @@ def test_additional_dag_with_no_triggers(dag_maker): EmptyOperator(task_id="hello2") -@pytest.fixture(autouse=True) -def clear_datasets(): - from tests.test_utils.db import clear_db_datasets - - clear_db_datasets() - - @pytest.fixture def setup_datasets_and_models(session): """Fixture to create datasets and corresponding models.""" From 8d44b4cad0a81af48865b292fc8bf869c4bf05ee Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 6 Feb 2024 23:56:38 +0530 Subject: [PATCH 07/23] fix the test for dataset serialization --- tests/timetables/test_datasets_timetable.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/timetables/test_datasets_timetable.py b/tests/timetables/test_datasets_timetable.py index 8e293888cabe5..ce58c42a6bc11 100644 --- a/tests/timetables/test_datasets_timetable.py +++ b/tests/timetables/test_datasets_timetable.py @@ -127,7 +127,6 @@ def test_serialization(dataset_timetable: DatasetOrTimeSchedule, monkeypatch: An serialized = dataset_timetable.serialize() assert serialized == { "timetable": "mock_serialized_timetable", - "datasets": [{"uri": "test_dataset", "extra": None}], } From 45d696947b70bb3d6d191423f264a25e342f8cf5 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Wed, 7 Feb 2024 11:26:55 +0530 Subject: [PATCH 08/23] Add clear dataset fixture --- tests/datasets/test_dataset.py | 11 +++++++++++ tests/serialization/test_dag_serialization.py | 1 + 2 files changed, 12 insertions(+) diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index ca0f312b3e0f2..f19f95d0c7007 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -30,6 +30,7 @@ from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG +@pytest.mark.db_test @pytest.mark.parametrize( ["uri"], [ @@ -45,16 +46,19 @@ def test_invalid_uris(uri): Dataset(uri=uri) +@pytest.mark.db_test def test_uri_with_scheme(): dataset = Dataset(uri="s3://example_dataset") EmptyOperator(task_id="task1", outlets=[dataset]) +@pytest.mark.db_test def test_uri_without_scheme(): dataset = Dataset(uri="example_dataset") EmptyOperator(task_id="task1", outlets=[dataset]) +@pytest.mark.db_test def test_fspath(): uri = "s3://example_dataset" dataset = Dataset(uri=uri) @@ -256,3 +260,10 @@ def test_dag_with_complex_dataset_triggers(session, dag_maker, setup_datasets_an assert isinstance( serialized_dag_dict["dataset_triggers"], dict ), "Serialized 'dataset_triggers' should be a dict" + + +@pytest.fixture(autouse=True) +def clear_datasets(): + from tests.test_utils.db import clear_db_datasets + + clear_db_datasets() diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 6b3f3bcef8151..cc906d6f3b9cb 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -549,6 +549,7 @@ def validate_deserialized_dag(self, serialized_dag, dag): "_task_group", "params", "_processor_dags_folder", + "dataset_triggers", } fields_to_check = dag.get_serialized_fields() - exclusion_list for field in fields_to_check: From c5bee763ae994c31206f2bf4d760531181492b46 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Thu, 8 Feb 2024 01:26:25 +0530 Subject: [PATCH 09/23] Add example DAGs --- airflow/example_dags/example_datasets.py | 33 ++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/airflow/example_dags/example_datasets.py b/airflow/example_dags/example_datasets.py index 9dfaaf0c34de0..2d613ccd71515 100644 --- a/airflow/example_dags/example_datasets.py +++ b/airflow/example_dags/example_datasets.py @@ -42,6 +42,7 @@ from airflow.datasets import Dataset from airflow.models.dag import DAG +from airflow.models.dataset import DatasetAll, DatasetAny from airflow.operators.bash import BashOperator from airflow.timetables.datasets import DatasetOrTimeSchedule from airflow.timetables.trigger import CronTriggerTimetable @@ -50,6 +51,7 @@ dag1_dataset = Dataset("s3://dag1/output_1.txt", extra={"hi": "bye"}) # [END dataset_def] dag2_dataset = Dataset("s3://dag2/output_1.txt", extra={"hi": "bye"}) +dag3_dataset = Dataset("s3://dag3/output_3.txt", extra={"hi": "bye"}) with DAG( dag_id="dataset_produces_1", @@ -145,3 +147,34 @@ task_id="consuming_dataset_time_based", bash_command="sleep 5", ) + +with DAG( + dag_id="consume_1_and_2_with_dataset_expressions", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=DatasetAll(dag1_dataset, dag2_dataset), +) as dag5: + BashOperator( + outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], + task_id="consume_1_and_2_with_dataset_expressions", + bash_command="sleep 5", + ) +with DAG( + dag_id="consume_1_or_2_with_dataset_expressions", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=DatasetAny(dag1_dataset, dag2_dataset), +) as dag6: + BashOperator( + outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], + task_id="consume_1_or_2_with_dataset_expressions", + bash_command="sleep 5", + ) +with DAG( + dag_id="consume_1_or_-2_and_3_with_dataset_expressions", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=DatasetAny(dag1_dataset, DatasetAll(dag2_dataset, dag3_dataset)), +) as dag7: + BashOperator( + outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], + task_id="consume_1_or_-2_and_3_with_dataset_expressions", + bash_command="sleep 5", + ) From 341a91af8c6800fb9807ae0397a5ec8b60a0c4b9 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Thu, 8 Feb 2024 11:14:17 +0530 Subject: [PATCH 10/23] Add strict in class serialization and remove example DAGs --- airflow/example_dags/example_datasets.py | 32 --------------------- airflow/serialization/serialized_objects.py | 16 +++++++++-- 2 files changed, 14 insertions(+), 34 deletions(-) diff --git a/airflow/example_dags/example_datasets.py b/airflow/example_dags/example_datasets.py index 2d613ccd71515..c09d0961c29d9 100644 --- a/airflow/example_dags/example_datasets.py +++ b/airflow/example_dags/example_datasets.py @@ -42,7 +42,6 @@ from airflow.datasets import Dataset from airflow.models.dag import DAG -from airflow.models.dataset import DatasetAll, DatasetAny from airflow.operators.bash import BashOperator from airflow.timetables.datasets import DatasetOrTimeSchedule from airflow.timetables.trigger import CronTriggerTimetable @@ -147,34 +146,3 @@ task_id="consuming_dataset_time_based", bash_command="sleep 5", ) - -with DAG( - dag_id="consume_1_and_2_with_dataset_expressions", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=DatasetAll(dag1_dataset, dag2_dataset), -) as dag5: - BashOperator( - outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], - task_id="consume_1_and_2_with_dataset_expressions", - bash_command="sleep 5", - ) -with DAG( - dag_id="consume_1_or_2_with_dataset_expressions", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=DatasetAny(dag1_dataset, dag2_dataset), -) as dag6: - BashOperator( - outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], - task_id="consume_1_or_2_with_dataset_expressions", - bash_command="sleep 5", - ) -with DAG( - dag_id="consume_1_or_-2_and_3_with_dataset_expressions", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=DatasetAny(dag1_dataset, DatasetAll(dag2_dataset, dag3_dataset)), -) as dag7: - BashOperator( - outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], - task_id="consume_1_or_-2_and_3_with_dataset_expressions", - bash_command="sleep 5", - ) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 34f79bf8fd76a..5e6073233e273 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -501,9 +501,21 @@ def serialize( elif isinstance(var, Dataset): return cls._encode({"uri": var.uri, "extra": var.extra}, type_=DAT.DATASET) elif isinstance(var, DatasetAll): - return cls._encode([cls.serialize(x) for x in var.objects], type_=DAT.DATASET_ALL) + return cls._encode( + [ + cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models) + for x in var.objects + ], + type_=DAT.DATASET_ALL, + ) elif isinstance(var, DatasetAny): - return cls._encode([cls.serialize(x) for x in var.objects], type_=DAT.DATASET_ANY) + return cls._encode( + [ + cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models) + for x in var.objects + ], + type_=DAT.DATASET_ANY, + ) elif isinstance(var, SimpleTaskInstance): return cls._encode( cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models), From fdb6bc6132c28aff0e9b98b352761a14d4a03d77 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 13 Feb 2024 11:27:45 -0800 Subject: [PATCH 11/23] Query "existing dag runs" one at a time --- airflow/jobs/scheduler_job_runner.py | 114 ++++++++++++++------------- 1 file changed, 58 insertions(+), 56 deletions(-) diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 32cc9f5a634ab..5dd2a65e27be8 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -31,7 +31,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Iterator -from sqlalchemy import and_, delete, func, not_, or_, select, text, update +from sqlalchemy import and_, delete, exists, func, literal_column, not_, or_, select, text, update from sqlalchemy.exc import OperationalError from sqlalchemy.orm import lazyload, load_only, make_transient, selectinload from sqlalchemy.sql import expression @@ -1215,13 +1215,6 @@ def _create_dag_runs_dataset_triggered( dag_id: timezone.coerce_datetime(last_time) for dag_id, (_, last_time) in dataset_triggered_dag_info.items() } - existing_dagruns: set[tuple[str, timezone.DateTime]] = set( - session.execute( - select(DagRun.dag_id, DagRun.execution_date).where( - tuple_in_condition((DagRun.dag_id, DagRun.execution_date), exec_dates.items()) - ) - ) - ) for dag_model in dag_models: dag = self.dagbag.get_dag(dag_model.dag_id, session=session) @@ -1247,59 +1240,68 @@ def _create_dag_runs_dataset_triggered( # create a new one. This is so that in the next Scheduling loop we try to create new runs # instead of falling in a loop of Integrity Error. exec_date = exec_dates[dag.dag_id] - if (dag.dag_id, exec_date) not in existing_dagruns: - previous_dag_run = session.scalar( - select(DagRun) - .where( - DagRun.dag_id == dag.dag_id, - DagRun.execution_date < exec_date, - DagRun.run_type == DagRunType.DATASET_TRIGGERED, - ) - .order_by(DagRun.execution_date.desc()) + curr_date_query = select( + exists( + select(literal_column("1")) + .where(DagRun.dag_id == dag.dag_id, DagRun.execution_date == exec_date) .limit(1) ) - dataset_event_filters = [ - DagScheduleDatasetReference.dag_id == dag.dag_id, - DatasetEvent.timestamp <= exec_date, - ] - if previous_dag_run: - dataset_event_filters.append(DatasetEvent.timestamp > previous_dag_run.execution_date) - - dataset_events = session.scalars( - select(DatasetEvent) - .join( - DagScheduleDatasetReference, - DatasetEvent.dataset_id == DagScheduleDatasetReference.dataset_id, - ) - .join(DatasetEvent.source_dag_run) - .where(*dataset_event_filters) - ).all() - - data_interval = dag.timetable.data_interval_for_events(exec_date, dataset_events) - run_id = dag.timetable.generate_run_id( - run_type=DagRunType.DATASET_TRIGGERED, - logical_date=exec_date, - data_interval=data_interval, - session=session, - events=dataset_events, - ) + ) + if session.scalar(curr_date_query): # dag already exists + continue - dag_run = dag.create_dagrun( - run_id=run_id, - run_type=DagRunType.DATASET_TRIGGERED, - execution_date=exec_date, - data_interval=data_interval, - state=DagRunState.QUEUED, - external_trigger=False, - session=session, - dag_hash=dag_hash, - creating_job_id=self.job.id, + prev_exec_date = session.scalar( + select(DagRun.execution_date) + .where( + DagRun.dag_id == dag.dag_id, + DagRun.execution_date < exec_date, + DagRun.run_type == DagRunType.DATASET_TRIGGERED, ) - Stats.incr("dataset.triggered_dagruns") - dag_run.consumed_dataset_events.extend(dataset_events) - session.execute( - delete(DatasetDagRunQueue).where(DatasetDagRunQueue.target_dag_id == dag_run.dag_id) + .order_by(DagRun.execution_date.desc()) + .limit(1) + ) + dataset_event_filters = [ + DagScheduleDatasetReference.dag_id == dag.dag_id, + DatasetEvent.timestamp <= exec_date, + ] + if prev_exec_date: + dataset_event_filters.append(DatasetEvent.timestamp > prev_exec_date) + + dataset_events = session.scalars( + select(DatasetEvent) + .join( + DagScheduleDatasetReference, + DatasetEvent.dataset_id == DagScheduleDatasetReference.dataset_id, ) + .join(DatasetEvent.source_dag_run) + .where(*dataset_event_filters) + ).all() + + data_interval = dag.timetable.data_interval_for_events(exec_date, dataset_events) + run_id = dag.timetable.generate_run_id( + run_type=DagRunType.DATASET_TRIGGERED, + logical_date=exec_date, + data_interval=data_interval, + session=session, + events=dataset_events, + ) + + dag_run = dag.create_dagrun( + run_id=run_id, + run_type=DagRunType.DATASET_TRIGGERED, + execution_date=exec_date, + data_interval=data_interval, + state=DagRunState.QUEUED, + external_trigger=False, + session=session, + dag_hash=dag_hash, + creating_job_id=self.job.id, + ) + Stats.incr("dataset.triggered_dagruns") + dag_run.consumed_dataset_events.extend(dataset_events) + session.execute( + delete(DatasetDagRunQueue).where(DatasetDagRunQueue.target_dag_id == dag_run.dag_id) + ) def _should_update_dag_next_dagruns( self, From 6ab9d8f627f74b7737329eac02b08b5a68e3744d Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 20 Feb 2024 15:25:58 -0800 Subject: [PATCH 12/23] Revert "Query "existing dag runs" one at a time" This reverts commit 2a2a195244b4e40882cd0f758b356d21167a5299. --- airflow/jobs/scheduler_job_runner.py | 114 +++++++++++++-------------- 1 file changed, 56 insertions(+), 58 deletions(-) diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 5dd2a65e27be8..32cc9f5a634ab 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -31,7 +31,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Iterator -from sqlalchemy import and_, delete, exists, func, literal_column, not_, or_, select, text, update +from sqlalchemy import and_, delete, func, not_, or_, select, text, update from sqlalchemy.exc import OperationalError from sqlalchemy.orm import lazyload, load_only, make_transient, selectinload from sqlalchemy.sql import expression @@ -1215,6 +1215,13 @@ def _create_dag_runs_dataset_triggered( dag_id: timezone.coerce_datetime(last_time) for dag_id, (_, last_time) in dataset_triggered_dag_info.items() } + existing_dagruns: set[tuple[str, timezone.DateTime]] = set( + session.execute( + select(DagRun.dag_id, DagRun.execution_date).where( + tuple_in_condition((DagRun.dag_id, DagRun.execution_date), exec_dates.items()) + ) + ) + ) for dag_model in dag_models: dag = self.dagbag.get_dag(dag_model.dag_id, session=session) @@ -1240,68 +1247,59 @@ def _create_dag_runs_dataset_triggered( # create a new one. This is so that in the next Scheduling loop we try to create new runs # instead of falling in a loop of Integrity Error. exec_date = exec_dates[dag.dag_id] - curr_date_query = select( - exists( - select(literal_column("1")) - .where(DagRun.dag_id == dag.dag_id, DagRun.execution_date == exec_date) + if (dag.dag_id, exec_date) not in existing_dagruns: + previous_dag_run = session.scalar( + select(DagRun) + .where( + DagRun.dag_id == dag.dag_id, + DagRun.execution_date < exec_date, + DagRun.run_type == DagRunType.DATASET_TRIGGERED, + ) + .order_by(DagRun.execution_date.desc()) .limit(1) ) - ) - if session.scalar(curr_date_query): # dag already exists - continue + dataset_event_filters = [ + DagScheduleDatasetReference.dag_id == dag.dag_id, + DatasetEvent.timestamp <= exec_date, + ] + if previous_dag_run: + dataset_event_filters.append(DatasetEvent.timestamp > previous_dag_run.execution_date) + + dataset_events = session.scalars( + select(DatasetEvent) + .join( + DagScheduleDatasetReference, + DatasetEvent.dataset_id == DagScheduleDatasetReference.dataset_id, + ) + .join(DatasetEvent.source_dag_run) + .where(*dataset_event_filters) + ).all() + + data_interval = dag.timetable.data_interval_for_events(exec_date, dataset_events) + run_id = dag.timetable.generate_run_id( + run_type=DagRunType.DATASET_TRIGGERED, + logical_date=exec_date, + data_interval=data_interval, + session=session, + events=dataset_events, + ) - prev_exec_date = session.scalar( - select(DagRun.execution_date) - .where( - DagRun.dag_id == dag.dag_id, - DagRun.execution_date < exec_date, - DagRun.run_type == DagRunType.DATASET_TRIGGERED, + dag_run = dag.create_dagrun( + run_id=run_id, + run_type=DagRunType.DATASET_TRIGGERED, + execution_date=exec_date, + data_interval=data_interval, + state=DagRunState.QUEUED, + external_trigger=False, + session=session, + dag_hash=dag_hash, + creating_job_id=self.job.id, ) - .order_by(DagRun.execution_date.desc()) - .limit(1) - ) - dataset_event_filters = [ - DagScheduleDatasetReference.dag_id == dag.dag_id, - DatasetEvent.timestamp <= exec_date, - ] - if prev_exec_date: - dataset_event_filters.append(DatasetEvent.timestamp > prev_exec_date) - - dataset_events = session.scalars( - select(DatasetEvent) - .join( - DagScheduleDatasetReference, - DatasetEvent.dataset_id == DagScheduleDatasetReference.dataset_id, + Stats.incr("dataset.triggered_dagruns") + dag_run.consumed_dataset_events.extend(dataset_events) + session.execute( + delete(DatasetDagRunQueue).where(DatasetDagRunQueue.target_dag_id == dag_run.dag_id) ) - .join(DatasetEvent.source_dag_run) - .where(*dataset_event_filters) - ).all() - - data_interval = dag.timetable.data_interval_for_events(exec_date, dataset_events) - run_id = dag.timetable.generate_run_id( - run_type=DagRunType.DATASET_TRIGGERED, - logical_date=exec_date, - data_interval=data_interval, - session=session, - events=dataset_events, - ) - - dag_run = dag.create_dagrun( - run_id=run_id, - run_type=DagRunType.DATASET_TRIGGERED, - execution_date=exec_date, - data_interval=data_interval, - state=DagRunState.QUEUED, - external_trigger=False, - session=session, - dag_hash=dag_hash, - creating_job_id=self.job.id, - ) - Stats.incr("dataset.triggered_dagruns") - dag_run.consumed_dataset_events.extend(dataset_events) - session.execute( - delete(DatasetDagRunQueue).where(DatasetDagRunQueue.target_dag_id == dag_run.dag_id) - ) def _should_update_dag_next_dagruns( self, From df77b9328da45ffe9c7d64cd6b15c91c8e58c2ec Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:25:57 -0800 Subject: [PATCH 13/23] enable serialization check of dataset_triggers attr --- tests/serialization/test_dag_serialization.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index cc906d6f3b9cb..2adc956b6f4d0 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -60,6 +60,7 @@ from airflow.serialization.enums import Encoding from airflow.serialization.json_schema import load_dag_schema_dict from airflow.serialization.serialized_objects import ( + BaseSerialization, DagDependency, DependencyDetector, SerializedBaseOperator, @@ -549,13 +550,18 @@ def validate_deserialized_dag(self, serialized_dag, dag): "_task_group", "params", "_processor_dags_folder", + } + compare_serialization_list = { "dataset_triggers", } fields_to_check = dag.get_serialized_fields() - exclusion_list for field in fields_to_check: - assert getattr(serialized_dag, field) == getattr( - dag, field - ), f"{dag.dag_id}.{field} does not match" + actual = getattr(serialized_dag, field) + expected = getattr(dag, field) + if field in compare_serialization_list: + actual = BaseSerialization.serialize(actual) + expected = BaseSerialization.serialize(expected) + assert actual == expected, f"{dag.dag_id}.{field} does not match" # _processor_dags_folder is only populated at serialization time # it's only used when relying on serialized dag to determine a dag's relative path assert dag._processor_dags_folder is None From d068f578c1def263db4bc22e3e63cfc60f28f6bb Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 21 Feb 2024 09:27:30 -0800 Subject: [PATCH 14/23] Update airflow/models/dataset.py Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- airflow/models/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index 0082e3ca4c456..220c1dcc4fd7b 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -351,7 +351,7 @@ class DatasetBooleanCondition: def __init__(self, *objects): self.objects = objects - def evaluate(self, statuses: dict[str, bool]): + def evaluate(self, statuses: dict[str, bool]) -> bool: return self.agg_func(self.eval_one(x, statuses) for x in self.objects) def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses): From 9d204ad1b6bca149044252d8d19367a2d3737bc1 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 21 Feb 2024 09:27:36 -0800 Subject: [PATCH 15/23] Update airflow/models/dataset.py Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- airflow/models/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index 220c1dcc4fd7b..9a7d719d96052 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -354,7 +354,7 @@ def __init__(self, *objects): def evaluate(self, statuses: dict[str, bool]) -> bool: return self.agg_func(self.eval_one(x, statuses) for x in self.objects) - def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses): + def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses) -> bool: if isinstance(obj, Dataset): return statuses.get(obj.uri, False) return obj.evaluate(statuses=statuses) From 9b5bac7d6d9707384ae6c5e402f7ebaa68cf7553 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 21 Feb 2024 09:27:48 -0800 Subject: [PATCH 16/23] Update airflow/models/dataset.py Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- airflow/models/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index 9a7d719d96052..e0bb2c4553d5f 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -348,7 +348,7 @@ class DatasetBooleanCondition: agg_func: Callable - def __init__(self, *objects): + def __init__(self, *objects) -> None: self.objects = objects def evaluate(self, statuses: dict[str, bool]) -> bool: From 71f6ebad5bb552d55f430fdcffc533ee635671ce Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 21 Feb 2024 09:46:59 -0800 Subject: [PATCH 17/23] remove unnecessary db test mark --- tests/datasets/test_dataset.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index f19f95d0c7007..047104a7d6692 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -30,7 +30,6 @@ from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG -@pytest.mark.db_test @pytest.mark.parametrize( ["uri"], [ @@ -46,26 +45,22 @@ def test_invalid_uris(uri): Dataset(uri=uri) -@pytest.mark.db_test def test_uri_with_scheme(): dataset = Dataset(uri="s3://example_dataset") EmptyOperator(task_id="task1", outlets=[dataset]) -@pytest.mark.db_test def test_uri_without_scheme(): dataset = Dataset(uri="example_dataset") EmptyOperator(task_id="task1", outlets=[dataset]) -@pytest.mark.db_test def test_fspath(): uri = "s3://example_dataset" dataset = Dataset(uri=uri) assert os.fspath(dataset) == uri -@pytest.mark.db_test @pytest.mark.parametrize( "inputs, scenario, expected", [ @@ -105,7 +100,6 @@ def test_dataset_logical_conditions_evaluation_and_serialization(inputs, scenari assert deserialized.evaluate(statuses) == expected, "Serialization round-trip failed" -@pytest.mark.db_test @pytest.mark.parametrize( "status_values, expected_evaluation", [ @@ -149,8 +143,8 @@ def create_test_datasets(session): session.commit() return datasets - @pytest.mark.db_test +@pytest.mark.usefixtures("clear_datasets") def test_dataset_trigger_setup_and_serialization(session, dag_maker, create_test_datasets): datasets = create_test_datasets @@ -177,7 +171,8 @@ def test_dataset_trigger_setup_and_serialization(session, dag_maker, create_test @pytest.mark.db_test -def test_dataset_dag_run_queue_processing(session, dag_maker, create_test_datasets): +@pytest.mark.usefixtures("clear_datasets") +def test_dataset_dag_run_queue_processing(session, clear_datasets, dag_maker, create_test_datasets): datasets = create_test_datasets dataset_models = session.query(DatasetModel).all() @@ -205,14 +200,6 @@ def test_dataset_dag_run_queue_processing(session, dag_maker, create_test_datase assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG trigger evaluation failed" -@pytest.mark.db_test -@pytest.mark.usefixtures("create_test_datasets") -def test_additional_dag_with_no_triggers(dag_maker): - # Create an additional DAG to ensure it's not affected by dataset triggers - with dag_maker(dag_id="dag2"): - EmptyOperator(task_id="hello2") - - @pytest.fixture def setup_datasets_and_models(session): """Fixture to create datasets and corresponding models.""" @@ -230,6 +217,7 @@ def setup_datasets_and_models(session): @pytest.mark.db_test +@pytest.mark.usefixtures("clear_datasets") def test_dag_with_complex_dataset_triggers(session, dag_maker, setup_datasets_and_models): d1, d2 = setup_datasets_and_models @@ -262,7 +250,7 @@ def test_dag_with_complex_dataset_triggers(session, dag_maker, setup_datasets_an ), "Serialized 'dataset_triggers' should be a dict" -@pytest.fixture(autouse=True) +@pytest.fixture def clear_datasets(): from tests.test_utils.db import clear_db_datasets From 4721b6af017da01d63337d9b951ba49c1396bae7 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 21 Feb 2024 09:49:43 -0800 Subject: [PATCH 18/23] remove unnec add --- airflow/example_dags/example_datasets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/example_dags/example_datasets.py b/airflow/example_dags/example_datasets.py index c09d0961c29d9..9dfaaf0c34de0 100644 --- a/airflow/example_dags/example_datasets.py +++ b/airflow/example_dags/example_datasets.py @@ -50,7 +50,6 @@ dag1_dataset = Dataset("s3://dag1/output_1.txt", extra={"hi": "bye"}) # [END dataset_def] dag2_dataset = Dataset("s3://dag2/output_1.txt", extra={"hi": "bye"}) -dag3_dataset = Dataset("s3://dag3/output_3.txt", extra={"hi": "bye"}) with DAG( dag_id="dataset_produces_1", From 4df31d4f4c9c887cb80ab60b7f69fa5641cebe9d Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 21 Feb 2024 09:59:36 -0800 Subject: [PATCH 19/23] Apply suggestions from code review Co-authored-by: Wei Lee Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- airflow/models/dataset.py | 2 +- tests/datasets/test_dataset.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index e0bb2c4553d5f..b24915fe61554 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -346,7 +346,7 @@ class DatasetBooleanCondition: :meta private: """ - agg_func: Callable + agg_func: Callable[[Iterable], bool] def __init__(self, *objects) -> None: self.objects = objects diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index 047104a7d6692..fe37f110a3980 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -255,3 +255,5 @@ def clear_datasets(): from tests.test_utils.db import clear_db_datasets clear_db_datasets() + yield + clear_db_datasets() From 10909dc30511b5629ea1c942543e5edc52af5eb3 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:02:14 -0800 Subject: [PATCH 20/23] better comment --- airflow/timetables/datasets.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/airflow/timetables/datasets.py b/airflow/timetables/datasets.py index 935a852e93abd..c755df964ee4d 100644 --- a/airflow/timetables/datasets.py +++ b/airflow/timetables/datasets.py @@ -56,7 +56,10 @@ def deserialize(cls, data: dict[str, typing.Any]) -> Timetable: return cls( timetable=decode_timetable(data["timetable"]), - datasets=[], # don't need the datasets after deserialization + # don't need the datasets after deserialization + # they are already stored on dataset_triggers attr on DAG + # and this is what scheduler looks at + datasets=[], ) def serialize(self) -> dict[str, typing.Any]: From 28db3057333eda17399a461ff5cdd768df6d365f Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:03:58 -0800 Subject: [PATCH 21/23] remove unnec fixture --- tests/datasets/test_dataset.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index fe37f110a3980..217c0c3857d03 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -143,6 +143,7 @@ def create_test_datasets(session): session.commit() return datasets + @pytest.mark.db_test @pytest.mark.usefixtures("clear_datasets") def test_dataset_trigger_setup_and_serialization(session, dag_maker, create_test_datasets): @@ -200,9 +201,9 @@ def test_dataset_dag_run_queue_processing(session, clear_datasets, dag_maker, cr assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG trigger evaluation failed" -@pytest.fixture -def setup_datasets_and_models(session): - """Fixture to create datasets and corresponding models.""" +@pytest.mark.db_test +@pytest.mark.usefixtures("clear_datasets") +def test_dag_with_complex_dataset_triggers(session, dag_maker): # Create Dataset instances d1 = Dataset(uri="hello1") d2 = Dataset(uri="hello2") @@ -213,14 +214,6 @@ def setup_datasets_and_models(session): session.add_all([dm1, dm2]) session.commit() - return d1, d2 - - -@pytest.mark.db_test -@pytest.mark.usefixtures("clear_datasets") -def test_dag_with_complex_dataset_triggers(session, dag_maker, setup_datasets_and_models): - d1, d2 = setup_datasets_and_models - # Setup a DAG with complex dataset triggers (DatasetAny with DatasetAll) with dag_maker(schedule=DatasetAny(d1, DatasetAll(d2, d1))) as dag: EmptyOperator(task_id="hello") From 01ccbc558b589a70a9708fba2717f767307675f7 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:05:05 -0800 Subject: [PATCH 22/23] relocate fixture --- tests/datasets/test_dataset.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index 217c0c3857d03..dfc8b82ba1596 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -30,6 +30,15 @@ from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG +@pytest.fixture +def clear_datasets(): + from tests.test_utils.db import clear_db_datasets + + clear_db_datasets() + yield + clear_db_datasets() + + @pytest.mark.parametrize( ["uri"], [ @@ -241,12 +250,3 @@ def test_dag_with_complex_dataset_triggers(session, dag_maker): assert isinstance( serialized_dag_dict["dataset_triggers"], dict ), "Serialized 'dataset_triggers' should be a dict" - - -@pytest.fixture -def clear_datasets(): - from tests.test_utils.db import clear_db_datasets - - clear_db_datasets() - yield - clear_db_datasets() From ffe97e8f1ccbfeede903496828a56a58838c6d1b Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:46:16 -0800 Subject: [PATCH 23/23] fix missing import --- airflow/models/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index b24915fe61554..bf28777358786 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Callable +from typing import Callable, Iterable from urllib.parse import urlsplit import sqlalchemy_jsonfield