diff --git a/airflow-core/src/airflow/migrations/versions/0101_3_2_0_ui_improvements_for_deadlines.py b/airflow-core/src/airflow/migrations/versions/0101_3_2_0_ui_improvements_for_deadlines.py index 7bd5081810507..0233f48c9a9e6 100644 --- a/airflow-core/src/airflow/migrations/versions/0101_3_2_0_ui_improvements_for_deadlines.py +++ b/airflow-core/src/airflow/migrations/versions/0101_3_2_0_ui_improvements_for_deadlines.py @@ -30,6 +30,7 @@ from __future__ import annotations +import contextlib import json import zlib from collections import defaultdict @@ -363,6 +364,28 @@ def _sort_serialized_dag_dict(serialized_dag: Any): return serialized_dag +@contextlib.contextmanager +def _begin_nested_transaction(conn): + """ + Create a nested transaction. + + On SQLite, uses ``conn.begin_nested()`` with commit/rollback. + On other backends, opens a new connection via ``conn.engine.begin()`` + and yields it so callers use the new connection for writes. + """ + if conn.dialect.name != "sqlite": + with conn.engine.begin() as new_conn: + yield new_conn + return + try: + savepoint = conn.begin_nested() + yield conn + except Exception: + savepoint.rollback() + raise + savepoint.commit() + + def migrate_existing_deadline_alert_data_from_serialized_dag() -> None: """Extract DeadlineAlert data from serialized Dag data and populate deadline_alert table.""" if context.is_offline_mode(): @@ -411,8 +434,7 @@ def migrate_existing_deadline_alert_data_from_serialized_dag() -> None: """), {"last_dag_id": last_dag_id, "batch_size": BATCH_SIZE}, ) - - batch_results = sorted(list(result), key=lambda r: r.dag_id) + batch_results = sorted(result, key=lambda r: r.dag_id) else: result = conn.execute( sa.text(""" @@ -425,7 +447,6 @@ def migrate_existing_deadline_alert_data_from_serialized_dag() -> None: """), {"last_dag_id": last_dag_id, "batch_size": BATCH_SIZE}, ) - batch_results = list(result) if not batch_results: break @@ -436,100 +457,98 @@ def migrate_existing_deadline_alert_data_from_serialized_dag() -> None: processed_dags.append(dag_id) last_dag_id = dag_id - # Create a savepoint for this Dag to allow rollback on error. - savepoint = conn.begin_nested() - + # Validation that does not need a DB connection. try: - dag_data = get_dag_data(data, data_compressed) - - if dag_deadline := dag_data[DAG_KEY][DEADLINE_KEY]: - dags_with_deadlines.add(dag_id) - deadline_alerts = dag_deadline if isinstance(dag_deadline, list) else [dag_deadline] - - migrated_alert_ids = [] - - for serialized_alert in deadline_alerts: - if isinstance(serialized_alert, dict): - try: - alert_data = serialized_alert.get(Encoding.VAR, serialized_alert) - - if not DEADLINE_ALERT_REQUIRED_FIELDS.issubset(alert_data): - dags_with_errors[dag_id].append( - f"Invalid DeadlineAlert structure: {serialized_alert}" - ) - continue - - reference_data = json.dumps(alert_data[REFERENCE_KEY], sort_keys=True) - interval_data = float(alert_data.get(INTERVAL_KEY)) - callback_data = json.dumps(alert_data[CALLBACK_KEY], sort_keys=True) - deadline_alert_id = str(uuid6.uuid7()) - - conn.execute( - sa.text(""" - INSERT INTO deadline_alert ( - id, - created_at, - serialized_dag_id, - reference, - interval, - callback_def, - name, - description) - VALUES ( - :id, - :created_at, - :serialized_dag_id, - :reference, - :interval, - :callback_def, - NULL, - NULL) - """), - { - "id": deadline_alert_id, - "created_at": created_at or timezone.utcnow(), - "serialized_dag_id": serialized_dag_id, - "reference": reference_data, - "interval": interval_data, - "callback_def": callback_data, - }, - ) - - if not validate_written_data( - conn, deadline_alert_id, reference_data, interval_data, callback_data - ): - dags_with_errors[dag_id].append( - f"Invalid DeadlineAlert data: {serialized_alert}" - ) - continue - - migrated_alert_ids.append(deadline_alert_id) - migrated_alerts_count += 1 - - conn.execute( - sa.text(""" - UPDATE deadline - SET deadline_alert_id = :alert_id - WHERE dagrun_id IN ( - SELECT dr.id - FROM dag_run dr - JOIN serialized_dag sd ON dr.dag_id = sd.dag_id - WHERE sd.id = :serialized_dag_id) - AND deadline_alert_id IS NULL + dag_deadline = get_dag_data(data, data_compressed)[DAG_KEY].get(DEADLINE_KEY) + except (json.JSONDecodeError, KeyError, TypeError) as e: + dags_with_errors[dag_id].append(f"Could not process serialized Dag: {e}") + continue + if not dag_deadline: + continue + + dags_with_deadlines.add(dag_id) + deadline_alerts = dag_deadline if isinstance(dag_deadline, list) else [dag_deadline] + + def _migrate_dag_deadlines(dag_conn: Connection) -> Iterable[str]: + for serialized_alert in deadline_alerts: + if not isinstance(serialized_alert, dict): + continue + try: + alert_data = serialized_alert.get(Encoding.VAR, serialized_alert) + + if not DEADLINE_ALERT_REQUIRED_FIELDS.issubset(alert_data): + dags_with_errors[dag_id].append( + f"Invalid DeadlineAlert structure: {serialized_alert}" + ) + continue + + reference_data = json.dumps(alert_data[REFERENCE_KEY], sort_keys=True) + interval_data = float(alert_data.get(INTERVAL_KEY)) + callback_data = json.dumps(alert_data[CALLBACK_KEY], sort_keys=True) + deadline_alert_id = str(uuid6.uuid7()) + + dag_conn.execute( + sa.text(""" + INSERT INTO deadline_alert ( + id, + created_at, + serialized_dag_id, + reference, + interval, + callback_def, + name, + description) + VALUES ( + :id, + :created_at, + :serialized_dag_id, + :reference, + :interval, + :callback_def, + NULL, + NULL) """), - {"alert_id": deadline_alert_id, "serialized_dag_id": serialized_dag_id}, - ) - except Exception as e: - dags_with_errors[dag_id].append(f"Failed to process {serialized_alert}: {e}") - continue + { + "id": deadline_alert_id, + "created_at": created_at or timezone.utcnow(), + "serialized_dag_id": serialized_dag_id, + "reference": reference_data, + "interval": interval_data, + "callback_def": callback_data, + }, + ) + + if not validate_written_data( + dag_conn, deadline_alert_id, reference_data, interval_data, callback_data + ): + dags_with_errors[dag_id].append(f"Invalid DeadlineAlert data: {serialized_alert}") + continue + + yield deadline_alert_id + dag_conn.execute( + sa.text(""" + UPDATE deadline + SET deadline_alert_id = :alert_id + WHERE dagrun_id IN ( + SELECT dr.id + FROM dag_run dr + JOIN serialized_dag sd ON dr.dag_id = sd.dag_id + WHERE sd.id = :serialized_dag_id) + AND deadline_alert_id IS NULL + """), + {"alert_id": deadline_alert_id, "serialized_dag_id": serialized_dag_id}, + ) + except Exception as e: + dags_with_errors[dag_id].append(f"Failed to process {serialized_alert}: {e}") + continue + try: + with _begin_nested_transaction(conn) as dag_conn: + migrated_alert_ids = list(_migrate_dag_deadlines(dag_conn)) if migrated_alert_ids: uuid_strings = [str(uuid_id) for uuid_id in migrated_alert_ids] - update_dag_deadline_field(conn, serialized_dag_id, uuid_strings, dialect) - - # Recalculate and update the dag_hash after modifying the deadline data to ensure - # it matches what write_dag() will compute later and avoid re-serialization. - updated_result = conn.execute( + update_dag_deadline_field(dag_conn, serialized_dag_id, uuid_strings, dialect) + updated_result = dag_conn.execute( sa.text( "SELECT data, data_compressed " "FROM serialized_dag " @@ -537,15 +556,12 @@ def migrate_existing_deadline_alert_data_from_serialized_dag() -> None: ), {"serialized_dag_id": serialized_dag_id}, ).fetchone() - if updated_result: updated_dag_data = get_dag_data( updated_result.data, updated_result.data_compressed ) - # Import here to avoid a circular dependency issue new_hash = hash_dag(updated_dag_data) - - conn.execute( + dag_conn.execute( sa.text( "UPDATE serialized_dag " "SET dag_hash = :new_hash " @@ -553,13 +569,10 @@ def migrate_existing_deadline_alert_data_from_serialized_dag() -> None: ), {"new_hash": new_hash, "serialized_dag_id": serialized_dag_id}, ) - - # Commit the savepoint if everything succeeded for this Dag. - savepoint.commit() - + migrated_alerts_count += len(migrated_alert_ids) except (json.JSONDecodeError, KeyError, TypeError) as e: - dags_with_errors[dag_id].append(f"Could not process serialized Dag: {e}") - savepoint.rollback() + log.exception("Could not migrate deadline for dag %s", dag_id) + dags_with_errors[dag_id].append(f"Could not migrate deadline: {e}") log.info("Batch complete", batch_num=batch_num, total_batches=total_batches) @@ -652,60 +665,64 @@ def migrate_deadline_alert_data_back_to_serialized_dag() -> None: processed_dags.append(dag_id) last_dag_id = dag_id - # Create a savepoint for this Dag to allow rollback on error. - savepoint = conn.begin_nested() - + # Validation that does not need a DB connection. try: dag_data = get_dag_data(data, data_compressed) - deadline_uuids = dag_data[DAG_KEY][DEADLINE_KEY] - - if not isinstance(deadline_uuids, list) or not deadline_uuids: - continue - - if not all(isinstance(uuid_val, str) for uuid_val in deadline_uuids): - log.warning("Dag has non-string deadline values, skipping", dag_id=dag_id) - continue - - dags_with_deadlines.add(dag_id) - restored_deadline_objects = [] - - alert_result = conn.execute( - sa.select( - deadline_alert_table.c.reference, - deadline_alert_table.c.interval, - deadline_alert_table.c.callback_def, - ).where(deadline_alert_table.c.serialized_dag_id == sa.bindparam("serialized_dag_id")), - {"serialized_dag_id": serialized_dag_id}, - ).fetchall() - - if not alert_result: - dags_with_errors[dag_id].append( - f"Could not find deadline_alert for serialized_dag {serialized_dag_id}" - ) - continue - - for alert in alert_result: - deadline_object = { - Encoding.TYPE: ENCODING_TYPE, - Encoding.VAR: { - REFERENCE_KEY: alert.reference, - INTERVAL_KEY: float(alert.interval), - CALLBACK_KEY: alert.callback_def, - }, - } - restored_deadline_objects.append(deadline_object) - restored_alerts_count += 1 - - # Replace the UUID array with the restored objects. - if restored_deadline_objects: - update_dag_deadline_field(conn, serialized_dag_id, restored_deadline_objects, dialect) - - # Commit the savepoint if everything succeeded for this Dag. - savepoint.commit() + except (json.JSONDecodeError, KeyError, TypeError): + continue + deadline_uuids = ( + dag_data.get(DAG_KEY, {}).get(DEADLINE_KEY) + if isinstance(dag_data.get(DAG_KEY), dict) + else None + ) + + if not isinstance(deadline_uuids, list) or not deadline_uuids: + continue + if not all(isinstance(uuid_val, str) for uuid_val in deadline_uuids): + log.warning("Dag has non-string deadline values, skipping", dag_id=dag_id) + continue + + dags_with_deadlines.add(dag_id) + + try: + with _begin_nested_transaction(conn) as dag_conn: + alert_result = dag_conn.execute( + sa.select( + deadline_alert_table.c.reference, + deadline_alert_table.c.interval, + deadline_alert_table.c.callback_def, + ).where( + deadline_alert_table.c.serialized_dag_id == sa.bindparam("serialized_dag_id") + ), + {"serialized_dag_id": serialized_dag_id}, + ).fetchall() + + if not alert_result: + dags_with_errors[dag_id].append( + f"Could not find deadline_alert for serialized_dag {serialized_dag_id}" + ) + continue + + restored_deadline_objects = [] + for alert in alert_result: + deadline_object = { + Encoding.TYPE: ENCODING_TYPE, + Encoding.VAR: { + REFERENCE_KEY: alert.reference, + INTERVAL_KEY: float(alert.interval), + CALLBACK_KEY: alert.callback_def, + }, + } + restored_deadline_objects.append(deadline_object) + restored_alerts_count += 1 + if restored_deadline_objects: + update_dag_deadline_field( + dag_conn, serialized_dag_id, restored_deadline_objects, dialect + ) except Exception as e: + log.exception("Could not restore deadline for dag %s", dag_id) dags_with_errors[dag_id].append(f"Could not restore deadline: {e}") - savepoint.rollback() log.info("Batch complete", batch_num=batch_num, total_batches=total_batches)