diff --git a/airflow-core/docs/howto/deadline-alerts.rst b/airflow-core/docs/howto/deadline-alerts.rst index 64f39c0244050..e36908009a0f1 100644 --- a/airflow-core/docs/howto/deadline-alerts.rst +++ b/airflow-core/docs/howto/deadline-alerts.rst @@ -42,7 +42,7 @@ Creating a Deadline Alert Creating a Deadline Alert requires three mandatory parameters: * Reference: When to start counting from -* Interval: How far before or after the reference point to trigger the alert +* Interval: How far before or after the reference point to trigger the alert (either a timedelta or a dynamic interval such as VariableInterval) * Callback: A Callback object which contains a path to a callable and optional kwargs to pass to it if the deadline is exceeded Here is how Deadlines are calculated: diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst index 82f32c8a2fdd9..32fbc76feb015 100644 --- a/airflow-core/docs/migrations-ref.rst +++ b/airflow-core/docs/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``acc215baed80`` (head) | ``a1b2c3d4e5f6`` | ``3.3.0`` | Add team_name to trigger table. | +| ``8812eb67b63c`` (head) | ``acc215baed80`` | ``3.3.0`` | Change Deadline interval to JSON. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``acc215baed80`` | ``a1b2c3d4e5f6`` | ``3.3.0`` | Add team_name to trigger table. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``a1b2c3d4e5f6`` | ``a7f3b2c1d4e5`` | ``3.3.0`` | Add version_data to dag_version. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/airflow-core/newsfragments/64751.feature.rst b/airflow-core/newsfragments/64751.feature.rst new file mode 100644 index 0000000000000..41d647f143d2b --- /dev/null +++ b/airflow-core/newsfragments/64751.feature.rst @@ -0,0 +1 @@ +Allow DeadlineAlert intervals to be dynamically resolved at Deadline evaluation using objects such as VariableInterval. diff --git a/airflow-core/src/airflow/migrations/versions/0116_3_3_0_add_team_name_to_trigger_table.py b/airflow-core/src/airflow/migrations/versions/0116_3_3_0_add_team_name_to_trigger_table.py index 503f1b58d7c0e..e217cbeb45ce4 100644 --- a/airflow-core/src/airflow/migrations/versions/0116_3_3_0_add_team_name_to_trigger_table.py +++ b/airflow-core/src/airflow/migrations/versions/0116_3_3_0_add_team_name_to_trigger_table.py @@ -22,6 +22,7 @@ Revision ID: acc215baed80 Revises: a1b2c3d4e5f6 Create Date: 2026-05-21 21:38:00.122692 + """ from __future__ import annotations diff --git a/airflow-core/src/airflow/migrations/versions/0117_3_3_0_change_deadline_interval_to_json.py b/airflow-core/src/airflow/migrations/versions/0117_3_3_0_change_deadline_interval_to_json.py new file mode 100644 index 0000000000000..04e5a35aa3143 --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0117_3_3_0_change_deadline_interval_to_json.py @@ -0,0 +1,305 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Change Deadline interval to JSON. + +Revision ID: 8812eb67b63c +Revises: acc215baed80 +Create Date: 2026-05-28 17:36:56.837243 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import context, op + +# revision identifiers, used by Alembic. +revision = "8812eb67b63c" +down_revision = "acc215baed80" +branch_labels = None +depends_on = None +airflow_version = "3.3.0" + + +def upgrade(): + """Apply change deadline interval to JSON.""" + conn = op.get_bind() + dialect = conn.dialect.name + + if context.is_offline_mode(): + print( + """ + Manual conversion required: + + PostgreSQL: + + Step 1: Convert column type. + ALTER TABLE deadline_alert + ALTER COLUMN interval TYPE JSONB + USING to_json(interval); + + Step 2: Convert values. + UPDATE deadline_alert + SET interval = json_build_object( + '__classname__', 'datetime.timedelta', + '__version__', 2, + '__data__', (interval::text)::float + ) + WHERE jsonb_typeof(interval::jsonb) = 'number'; + + MySQL: + + Step 1: Convert column type. + ALTER TABLE deadline_alert MODIFY COLUMN `interval` JSON; + + Step 2: Convert values + UPDATE deadline_alert + SET `interval` = JSON_OBJECT( + '__classname__', 'datetime.timedelta', + '__version__', 2, + '__data__', `interval` + ); + + SQLite: + + UPDATE deadline_alert + SET interval = + '{"__classname__":"datetime.timedelta","__version__":2,"__data__":' + || CAST(interval AS TEXT) || '}'; + """ + ) + return + + with op.batch_alter_table("deadline_alert") as batch_op: + if dialect == "postgresql": + batch_op.alter_column( + "interval", + existing_type=sa.FLOAT(), + type_=sa.JSON(), + postgresql_using="to_json(interval)", + existing_nullable=False, + ) + else: + batch_op.alter_column( + "interval", + existing_type=sa.FLOAT(), + type_=sa.JSON(), + existing_nullable=False, + ) + + if dialect == "postgresql": + op.execute(""" + UPDATE deadline_alert + SET interval = json_build_object( + '__classname__', 'datetime.timedelta', + '__version__', 2, + '__data__', (interval::text)::float + ) + WHERE jsonb_typeof(interval::jsonb) = 'number' + """) + + elif dialect == "mysql": + op.execute(""" + UPDATE deadline_alert + SET `interval` = JSON_OBJECT( + '__classname__', 'datetime.timedelta', + '__version__', 2, + '__data__', `interval` + ) + """) + + else: + op.execute(""" + UPDATE deadline_alert + SET interval = + '{"__classname__":"datetime.timedelta","__version__":' + || '2' || + ',"__data__":' || CAST(interval AS TEXT) || '}' + """) + + +def downgrade(): + """Revert deadline interval back to float.""" + conn = op.get_bind() + dialect = conn.dialect.name + + if context.is_offline_mode(): + print( + """ + Manual downgrade required: + + PostgreSQL: + + Step 1: Convert values. + UPDATE deadline_alert + SET interval = + CASE + WHEN jsonb_typeof(interval::jsonb) = 'number' + THEN interval + WHEN (interval::jsonb)->>'__classname__' = 'datetime.timedelta' + THEN to_json((interval->>'__data__')::double precision) + ELSE NULL + END; + + Step 2: Convert column type. + ALTER TABLE deadline_alert + ALTER COLUMN interval TYPE DOUBLE PRECISION + USING ( + CASE + WHEN jsonb_typeof(interval::jsonb) = 'number' + THEN interval::text::double precision + WHEN (interval::jsonb)->>'__classname__' = 'datetime.timedelta' + THEN (interval->>'__data__')::double precision + ELSE NULL + END + ); + + MySQL: + + Step 1: Convert values + UPDATE deadline_alert + SET `interval` = + CASE + WHEN JSON_EXTRACT(`interval`, '$.__data__') IS NOT NULL + THEN CAST(JSON_EXTRACT(`interval`, '$.__data__') AS DOUBLE) + WHEN JSON_EXTRACT(`interval`, '$.__classname__') IS NULL + THEN CAST(`interval` AS DOUBLE) + ELSE NULL + END; + + Step 2: Convert column type + ALTER TABLE deadline_alert + MODIFY COLUMN `interval` DOUBLE; + + SQLite: + + Step 1: Convert values + UPDATE deadline_alert + SET interval = + CASE + WHEN json_extract(interval, '$.__data__') IS NOT NULL + THEN CAST(json_extract(interval, '$.__data__') AS REAL) + WHEN json_extract(interval, '$.__classname__') IS NULL + THEN CAST(interval AS REAL) + ELSE NULL + END; + + Step 2: SQLite does not support ALTER COLUMN TYPE. + Recreate the table with interval as REAL and copy data. + """ + ) + return + + if dialect == "postgresql": + op.execute(""" + UPDATE deadline_alert + SET interval = + CASE + WHEN jsonb_typeof(interval::jsonb) = 'number' + THEN interval + WHEN (interval::jsonb)->>'__classname__' = 'datetime.timedelta' + THEN to_json((interval->>'__data__')::double precision) + ELSE NULL + END + """) + + elif dialect == "mysql": + op.execute(""" + UPDATE deadline_alert + SET `interval` = + CASE + WHEN JSON_EXTRACT(`interval`, '$.__data__') IS NOT NULL + THEN CAST(JSON_EXTRACT(`interval`, '$.__data__') AS DOUBLE) + WHEN JSON_EXTRACT(`interval`, '$.__classname__') IS NULL + THEN CAST(`interval` AS DOUBLE) + ELSE NULL + END + """) + + # Serialized VariableInterval objects do not contain a numeric "__data__" field + # and therefore cannot be converted back to a float representation. + # During downgrade, only timedelta-style serialized values are converted. + # Other serialized interval types (e.g. VariableInterval) will cast as null. + else: + # Detect availability of SQLite JSON functions (JSON1 extension). + json_functions_available = False + try: + conn.execute(sa.text("SELECT json_extract('{\"a\":1}', '$.a')")).fetchone() + json_functions_available = True + except Exception: + print("SQLite JSON functions not available, using string parsing as fallback.") + + if json_functions_available: + op.execute(""" + UPDATE deadline_alert + SET interval = + CASE + WHEN json_extract(interval, '$.__data__') IS NOT NULL + THEN CAST(json_extract(interval, '$.__data__') AS REAL) + WHEN json_extract(interval, '$.__classname__') IS NULL + THEN CAST(interval AS REAL) + ELSE NULL + END + """) + else: + # NOTE: This is a best-effort fallback for environments without JSON1. + # It assumes a stable JSON format and may not work for all serialized values. + op.execute(""" + UPDATE deadline_alert + SET interval = + CASE + WHEN instr(interval, '__data__') > 0 + THEN CAST( + substr( + interval, + instr(interval, '__data__') + + instr(substr(interval, instr(interval, '__data__')), ':') + ) AS FLOAT + ) + WHEN instr(interval, '__classname__') = 0 + THEN CAST(interval AS FLOAT) + ELSE NULL + END + """) + + with op.batch_alter_table("deadline_alert") as batch_op: + if dialect == "postgresql": + batch_op.alter_column( + "interval", + existing_type=sa.JSON(), + type_=sa.FLOAT(), + postgresql_using=""" + CASE + WHEN jsonb_typeof(interval::jsonb) = 'number' + THEN interval::text::double precision + WHEN (interval::jsonb)->>'__classname__' = 'datetime.timedelta' + THEN (interval->>'__data__')::double precision + ELSE NULL + END + """, + existing_nullable=False, + ) + else: + batch_op.alter_column( + "interval", + existing_type=sa.JSON(), + type_=sa.FLOAT(), + existing_nullable=False, + ) diff --git a/airflow-core/src/airflow/models/deadline_alert.py b/airflow-core/src/airflow/models/deadline_alert.py index 0b8a8eba9b1b1..d9b6590c0f7f1 100644 --- a/airflow-core/src/airflow/models/deadline_alert.py +++ b/airflow-core/src/airflow/models/deadline_alert.py @@ -21,7 +21,7 @@ from uuid import UUID import uuid6 -from sqlalchemy import JSON, Float, ForeignKey, String, Text, Uuid, select +from sqlalchemy import JSON, ForeignKey, String, Text, Uuid, select from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import Mapped, mapped_column @@ -50,13 +50,22 @@ class DeadlineAlert(Base): name: Mapped[str | None] = mapped_column(String(250), nullable=True) description: Mapped[str | None] = mapped_column(Text, nullable=True) reference: Mapped[dict] = mapped_column(JSON, nullable=False) - interval: Mapped[float] = mapped_column(Float, nullable=False) + interval: Mapped[dict] = mapped_column(JSON, nullable=False) callback_def: Mapped[dict] = mapped_column(JSON, nullable=False) def __repr__(self): - interval_seconds = int(self.interval) - if interval_seconds >= 3600: + interval_seconds = None + + if isinstance(self.interval, (int, float)): + interval_seconds = int(self.interval) + + elif isinstance(self.interval, datetime.timedelta): + interval_seconds = int(self.interval.total_seconds()) + + if interval_seconds is None: + interval_display = "dynamic" + elif interval_seconds >= 3600: interval_display = f"{interval_seconds // 3600}h" elif interval_seconds >= 60: interval_display = f"{interval_seconds // 60}m" diff --git a/airflow-core/src/airflow/serialization/decoders.py b/airflow-core/src/airflow/serialization/decoders.py index 22683ef5d612b..b36b7a8a52499 100644 --- a/airflow-core/src/airflow/serialization/decoders.py +++ b/airflow-core/src/airflow/serialization/decoders.py @@ -156,6 +156,7 @@ def decode_deadline_alert(encoded_data: dict): :meta private: """ + from airflow.sdk.definitions.deadline import VariableInterval from airflow.sdk.serde import deserialize data = encoded_data.get(Encoding.VAR, encoded_data) @@ -163,9 +164,30 @@ def decode_deadline_alert(encoded_data: dict): reference_data = data[DeadlineAlertFields.REFERENCE] reference = decode_deadline_reference(reference_data) + raw_interval = data[DeadlineAlertFields.INTERVAL] + + if raw_interval is None: + raise ValueError( + "DeadlineAlert interval is missing. This can happen after downgrading " + "from a version that supports VariableInterval. Downgrade is not fully reversible." + ) + + interval: datetime.timedelta | VariableInterval + + # Backward compatibility: previously interval was stored as total_seconds() (float/int). + # Handle numeric values by converting to timedelta. + if isinstance(raw_interval, (int, float)): + interval = datetime.timedelta(seconds=raw_interval) + else: + deserialized = deserialize(raw_interval) + if isinstance(deserialized, (datetime.timedelta, VariableInterval)): + interval = deserialized + else: + raise TypeError(f"Invalid interval type: {type(deserialized).__name__}") + return SerializedDeadlineAlert( reference=reference, - interval=datetime.timedelta(seconds=data[DeadlineAlertFields.INTERVAL]), + interval=interval, callback=deserialize(data[DeadlineAlertFields.CALLBACK]), name=data.get(DeadlineAlertFields.NAME), ) diff --git a/airflow-core/src/airflow/serialization/definitions/dag.py b/airflow-core/src/airflow/serialization/definitions/dag.py index d20b9c22c039c..6fb7bf083cf76 100644 --- a/airflow-core/src/airflow/serialization/definitions/dag.py +++ b/airflow-core/src/airflow/serialization/definitions/dag.py @@ -41,6 +41,7 @@ from airflow.models.deadline_alert import DeadlineAlert as DeadlineAlertModel from airflow.models.taskinstancekey import TaskInstanceKey from airflow.models.tasklog import LogTemplate +from airflow.sdk.definitions.deadline import VariableInterval from airflow.serialization.decoders import decode_deadline_alert from airflow.serialization.definitions.deadline import DeadlineAlertFields, SerializedReferenceModels from airflow.serialization.definitions.param import SerializedParamsDict @@ -653,10 +654,15 @@ def _process_dagrun_deadline_alerts( } ) + interval = deserialized_deadline_alert.interval + + if isinstance(interval, VariableInterval): + interval = interval.resolve() + if isinstance(deserialized_deadline_alert.reference, SerializedReferenceModels.TYPES.DAGRUN): deadline_time = deserialized_deadline_alert.reference.evaluate_with( session=session, - interval=deserialized_deadline_alert.interval, + interval=interval, # TODO : Pretty sure we can drop these last two; verify after testing is complete dag_id=self.dag_id, run_id=orm_dagrun.run_id, diff --git a/airflow-core/src/airflow/serialization/definitions/deadline.py b/airflow-core/src/airflow/serialization/definitions/deadline.py index 58eaa46e6f721..89e231cba24dc 100644 --- a/airflow-core/src/airflow/serialization/definitions/deadline.py +++ b/airflow-core/src/airflow/serialization/definitions/deadline.py @@ -38,6 +38,8 @@ from sqlalchemy import ColumnElement from sqlalchemy.orm import Session + from airflow.sdk.definitions.deadline import VariableInterval + logger = logging.getLogger(__name__) @@ -366,6 +368,6 @@ class SerializedDeadlineAlert: """Serialized representation of a deadline alert.""" reference: SerializedReferenceModels.SerializedBaseDeadlineReference - interval: timedelta + interval: timedelta | VariableInterval callback: Any name: str | None = None diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py index e97dcff26237c..b9caea4cc3722 100644 --- a/airflow-core/src/airflow/serialization/encoders.py +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -226,7 +226,7 @@ def encode_deadline_alert(d: DeadlineAlert | SerializedDeadlineAlert) -> dict[st return { "name": d.name, "reference": encode_deadline_reference(d.reference), - "interval": d.interval.total_seconds(), + "interval": serialize(d.interval), "callback": serialize(d.callback), } diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 00d512909dc5a..0d3bf5ef1aa32 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -116,7 +116,7 @@ class MappedClassProtocol(Protocol): "3.1.0": "cc92b33c6709", "3.1.8": "509b94a1042d", "3.2.0": "1d6611b6ab7c", - "3.3.0": "acc215baed80", + "3.3.0": "8812eb67b63c", } # Prefix used to identify tables holding data moved during migration. diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index f9751684e35cf..0a15c9ab09df5 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -2229,7 +2229,10 @@ def test_dag_with_multiple_deadlines(self, testing_dag_bundle, session): ).all() assert len(stored_alerts) == expected_num_deadlines - intervals = sorted([alert.interval for alert in stored_alerts]) + intervals = sorted( + alert.interval["__data__"] if isinstance(alert.interval, dict) else alert.interval + for alert in stored_alerts + ) assert intervals == [300.0, 600.0, 3600.0] # Now create a dagrun and verify deadlines are created diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 5f555ba69d030..7ac2514cbf72d 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -61,7 +61,9 @@ from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator from airflow.sdk import DAG, BaseOperator, get_current_context, setup, task, task_group, teardown from airflow.sdk.definitions.callback import AsyncCallback -from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference +from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference, VariableInterval +from airflow.sdk.definitions.variable import Variable +from airflow.sdk.exceptions import AirflowRuntimeError from airflow.serialization.definitions.deadline import SerializedReferenceModels from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.settings import get_policy_plugin_manager @@ -1326,17 +1328,28 @@ def test_dag_run_dag_versions_with_null_created_dag_version(self, dag_maker, ses assert isinstance(dag_run.dag_versions, list) assert len(dag_run.dag_versions) == 0 + @pytest.mark.parametrize( + "interval", + [ + datetime.timedelta(hours=1), + VariableInterval("my_key"), + ], + ) + @mock.patch.object(Variable, "get") @mock.patch.object(Deadline, "prune_deadlines") - def test_dagrun_success_deadline(self, _, session, deadline_test_dag): + def test_dagrun_success_deadline(self, _, mock_get, interval, session, deadline_test_dag): def on_success_callable(context): assert context["dag_run"].dag_id == "test_dag" future_date = datetime.datetime.now() + datetime.timedelta(days=365) + # First value used during resolution + mock_get.return_value = "5" + scheduler_dag = deadline_test_dag( deadline=DeadlineAlert( reference=DeadlineReference.FIXED_DATETIME(future_date), - interval=datetime.timedelta(hours=1), + interval=interval, callback=AsyncCallback(empty_callback_for_deadline), ), on_success_callback=on_success_callable, @@ -1441,6 +1454,73 @@ def test_dagrun_success_handles_empty_deadline_list(self, mock_prune, dag_maker, mock_prune.assert_not_called() assert dag_run.state == DagRunState.SUCCESS + @mock.patch.object(Variable, "get") + @mock.patch.object(Deadline, "prune_deadlines") + def test_dagrun_deadline_variable_interval_stable(self, _, mock_get, session, deadline_test_dag): + future_date = datetime.datetime.now() + datetime.timedelta(days=365) + + # First value used during resolution. + mock_get.return_value = "60" + + scheduler_dag = deadline_test_dag( + deadline=DeadlineAlert( + reference=DeadlineReference.FIXED_DATETIME(future_date), + interval=VariableInterval("my_key"), + callback=AsyncCallback(empty_callback_for_deadline), + ), + ) + + dag_run = self.create_dag_run( + dag=scheduler_dag, + task_states={"task_1": TaskInstanceState.SUCCESS, "task_2": TaskInstanceState.SUCCESS}, + session=session, + ) + dag_run.dag = scheduler_dag + + # First update resolve interval to "5". + dag_run.update_state(session=session) + + deadline = session.execute(select(Deadline)).scalars().one_or_none() + first_deadline_time = deadline.deadline_time + + # Change Variable value after resolution. + mock_get.return_value = "120" + + # Run again (This should not change existing deadline). + dag_run.update_state(session=session) + + deadline = session.execute(select(Deadline)).scalars().one_or_none() + assert deadline.deadline_time == first_deadline_time + + @mock.patch.object(Deadline, "prune_deadlines") + def test_dagrun_deadline_variable_interval_missing_variable_fails(self, _, session, deadline_test_dag): + + mock_err = mock.Mock() + mock_err.error.value = "MISSING_DEADLINE" + mock_err.detail = "missing deadline" + + with mock.patch.object( + Variable, + "get", + side_effect=AirflowRuntimeError(mock_err), + ): + future_date = datetime.datetime.now() + datetime.timedelta(days=365) + + scheduler_dag = deadline_test_dag( + deadline=DeadlineAlert( + reference=DeadlineReference.FIXED_DATETIME(future_date), + interval=VariableInterval("missing_key"), + callback=AsyncCallback(empty_callback_for_deadline), + ), + ) + + with pytest.raises(ValueError, match="not found"): + self.create_dag_run( + dag=scheduler_dag, + task_states={"task_1": TaskInstanceState.SUCCESS}, + session=session, + ) + @pytest.mark.parametrize( ("run_type", "expected_tis"), diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py index 54438f8e82fc9..765b2adf206ae 100644 --- a/airflow-core/tests/unit/models/test_serialized_dag.py +++ b/airflow-core/tests/unit/models/test_serialized_dag.py @@ -849,7 +849,7 @@ def test_deadline_interval_change_triggers_new_serdag(self, testing_dag_bundle, # There should be a second serdag with a new hash and the new interval. assert new_serdag_count == 2 assert new_serdag.dag_hash != orig_serdag.dag_hash - assert new_alert.interval == 600.0 + assert new_alert.interval["__data__"] == 600.0 def test_deadline_name_change_updates_db_and_returns_true(self, testing_dag_bundle, session): """Name-only deadline change: UUID reused, DB row updated, write_dag returns True.""" diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py index 1767209ca7350..6d558628e7431 100644 --- a/airflow-core/tests/unit/serialization/test_serialized_objects.py +++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py @@ -513,6 +513,22 @@ def test_serialize_deserialize_deadline_alert(reference): assert deserialized.callback == original.callback +def test_deserialize_deadline_alert_none_interval_raises(): + valid = DeadlineAlert( + reference=DeadlineReference.DAGRUN_QUEUED_AT, + interval=timedelta(hours=1), + callback=AsyncCallback(TEST_CALLBACK_PATH, kwargs=TEST_CALLBACK_KWARGS), + ) + + serialized = BaseSerialization.serialize(valid) + + # Inject downgrade corruption. + serialized[Encoding.VAR][DeadlineAlertFields.INTERVAL] = None + + with pytest.raises(ValueError, match="interval"): + BaseSerialization.deserialize(serialized) + + @pytest.mark.parametrize( "conn_uri", [ diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py b/task-sdk/src/airflow/sdk/definitions/deadline.py index 2ab20d056b9d5..a9da3dd3d3ea2 100644 --- a/task-sdk/src/airflow/sdk/definitions/deadline.py +++ b/task-sdk/src/airflow/sdk/definitions/deadline.py @@ -22,7 +22,11 @@ from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any +import attrs + from airflow.sdk.definitions.callback import AsyncCallback, Callback, SyncCallback +from airflow.sdk.definitions.variable import Variable +from airflow.sdk.exceptions import AirflowRuntimeError if TYPE_CHECKING: from collections.abc import Callable @@ -143,7 +147,7 @@ class DeadlineAlert: def __init__( self, reference: DeadlineReferenceType, - interval: timedelta, + interval: timedelta | VariableInterval, callback: Callback, name: str | None = None, ): @@ -342,3 +346,58 @@ def decorator( return reference_class return decorator + + +@attrs.define(frozen=True) +class VariableInterval: + """ + Interval backed by an Airflow Variable. + + This allows DeadlineAlert intervals to be configured dynamically using + Airflow Variables. The variable value is interpreted as seconds and + converted into a ``timedelta`` object. + + ------ + Usage: + ------ + + .. code-block:: python + + from airflow.sdk import DAG, DeadlineAlert, DeadlineReference, AsyncCallback + + DAG( + dag_id="dag_with_variable_interval", + deadline=DeadlineAlert( + reference=DeadlineReference.DAGRUN_QUEUED_AT, + interval=VariableInterval("deadline_seconds"), + callback=AsyncCallback(my_callback), + ), + ) + + ------ + Notes: + ------ + * Resolution occurs when deadlines are evaluated (during DagRun creation). + * Changes to the Variable affect only newly parsed DAGs and future DagRuns. + * Existing deadlines are not retroactively updated. + """ + + key: str + + def resolve(self) -> timedelta: + try: + value = Variable.get(self.key) + except AirflowRuntimeError as e: + raise ValueError(f"VariableInterval '{self.key}' not found") from e + + try: + seconds = int(value) + except (TypeError, ValueError) as e: + raise ValueError( + f"VariableInterval '{self.key}' must be an integer (seconds), got: {value!r}" + ) from e + + if seconds <= 0: + raise ValueError(f"VariableInterval '{self.key}' must be > 0, got: {seconds}") + + return timedelta(seconds=seconds) diff --git a/task-sdk/tests/task_sdk/definitions/test_deadline.py b/task-sdk/tests/task_sdk/definitions/test_deadline.py index 8e9e816b30705..b104980e4c986 100644 --- a/task-sdk/tests/task_sdk/definitions/test_deadline.py +++ b/task-sdk/tests/task_sdk/definitions/test_deadline.py @@ -17,12 +17,15 @@ from __future__ import annotations from datetime import datetime, timedelta +from unittest import mock import pytest from task_sdk.definitions.test_callback import TEST_CALLBACK_KWARGS, TEST_CALLBACK_PATH, UNIMPORTABLE_DOT_PATH from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback -from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference +from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference, VariableInterval +from airflow.sdk.definitions.variable import Variable +from airflow.sdk.exceptions import AirflowRuntimeError DAG_ID = "dag_id_1" RUN_ID = 1 @@ -162,3 +165,50 @@ def test_deadline_alert_rejects_invalid_callback(self): interval=timedelta(hours=1), callback="not_a_callback", # type: ignore ) + + +class TestVariableInterval: + @pytest.mark.parametrize( + ("value", "expected"), + [ + ("3", timedelta(seconds=3)), + ("10", timedelta(seconds=10)), + ("05", timedelta(seconds=5)), # leading zero + ], + ) + def test_resolve_valid(self, mocker, value, expected): + mocker.patch.object(Variable, "get", return_value=value) + + interval = VariableInterval(key="test_interval") + + assert interval.resolve() == expected + + @pytest.mark.parametrize( + ("value", "raise_runtime", "match"), + [ + (None, True, "not found"), + ("abc", False, "must be an integer"), + ("", False, "must be an integer"), + ("0", False, "must be > 0"), + ("-5", False, "must be > 0"), + ], + ) + def test_resolve_invalid(self, mocker, value, raise_runtime, match): + + if raise_runtime: + mock_err = mock.Mock() + mock_err.error.value = "MISSING" + mock_err.detail = "missing" + + mocker.patch.object( + Variable, + "get", + side_effect=AirflowRuntimeError(mock_err), + ) + else: + mocker.patch.object(Variable, "get", return_value=value) + + interval = VariableInterval(key="test_interval") + + with pytest.raises(ValueError, match=match): + interval.resolve()