Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from __future__ import annotations

import contextlib
import json
import zlib
from collections import defaultdict
Expand Down Expand Up @@ -363,6 +364,28 @@ def _sort_serialized_dag_dict(serialized_dag: Any):
return serialized_dag


@contextlib.contextmanager
def _begin_nested_transaction(conn):
    """
    Yield a connection whose writes can be rolled back independently.

    On SQLite, a savepoint is opened on ``conn`` via ``conn.begin_nested()``
    and ``conn`` itself is yielded; the savepoint is committed when the
    ``with`` body finishes normally and rolled back when it raises.

    On every other backend, a new connection is opened via
    ``conn.engine.begin()`` and yielded instead, so callers must issue their
    writes on the yielded connection rather than on ``conn``.

    :param conn: connection whose ``dialect.name`` selects the strategy.
    :raises Exception: any exception raised inside the ``with`` body is
        re-raised after the SQLite savepoint has been rolled back.
    """
    if conn.dialect.name != "sqlite":
        with conn.engine.begin() as new_conn:
            yield new_conn
        return

    # Acquire the savepoint *before* entering the ``try``: if ``begin_nested()``
    # itself fails there is nothing to roll back, and the handler below must
    # not touch an unbound ``savepoint`` (which would mask the real error
    # behind an ``UnboundLocalError``).
    savepoint = conn.begin_nested()
    try:
        yield conn
    except Exception:
        savepoint.rollback()
        raise
    else:
        # Only reached when the ``with`` body completed without raising.
        savepoint.commit()


def migrate_existing_deadline_alert_data_from_serialized_dag() -> None:
"""Extract DeadlineAlert data from serialized Dag data and populate deadline_alert table."""
if context.is_offline_mode():
Expand Down Expand Up @@ -411,8 +434,7 @@ def migrate_existing_deadline_alert_data_from_serialized_dag() -> None:
"""),
{"last_dag_id": last_dag_id, "batch_size": BATCH_SIZE},
)

batch_results = sorted(list(result), key=lambda r: r.dag_id)
batch_results = sorted(result, key=lambda r: r.dag_id)
else:
result = conn.execute(
sa.text("""
Expand All @@ -425,7 +447,6 @@ def migrate_existing_deadline_alert_data_from_serialized_dag() -> None:
"""),
{"last_dag_id": last_dag_id, "batch_size": BATCH_SIZE},
)

batch_results = list(result)
if not batch_results:
break
Expand All @@ -436,130 +457,122 @@ def migrate_existing_deadline_alert_data_from_serialized_dag() -> None:
processed_dags.append(dag_id)
last_dag_id = dag_id

# Create a savepoint for this Dag to allow rollback on error.
savepoint = conn.begin_nested()

# Validation that does not need a DB connection.
try:
dag_data = get_dag_data(data, data_compressed)

if dag_deadline := dag_data[DAG_KEY][DEADLINE_KEY]:
dags_with_deadlines.add(dag_id)
deadline_alerts = dag_deadline if isinstance(dag_deadline, list) else [dag_deadline]

migrated_alert_ids = []

for serialized_alert in deadline_alerts:
if isinstance(serialized_alert, dict):
try:
alert_data = serialized_alert.get(Encoding.VAR, serialized_alert)

if not DEADLINE_ALERT_REQUIRED_FIELDS.issubset(alert_data):
dags_with_errors[dag_id].append(
f"Invalid DeadlineAlert structure: {serialized_alert}"
)
continue

reference_data = json.dumps(alert_data[REFERENCE_KEY], sort_keys=True)
interval_data = float(alert_data.get(INTERVAL_KEY))
callback_data = json.dumps(alert_data[CALLBACK_KEY], sort_keys=True)
deadline_alert_id = str(uuid6.uuid7())

conn.execute(
sa.text("""
INSERT INTO deadline_alert (
id,
created_at,
serialized_dag_id,
reference,
interval,
callback_def,
name,
description)
VALUES (
:id,
:created_at,
:serialized_dag_id,
:reference,
:interval,
:callback_def,
NULL,
NULL)
"""),
{
"id": deadline_alert_id,
"created_at": created_at or timezone.utcnow(),
"serialized_dag_id": serialized_dag_id,
"reference": reference_data,
"interval": interval_data,
"callback_def": callback_data,
},
)

if not validate_written_data(
conn, deadline_alert_id, reference_data, interval_data, callback_data
):
dags_with_errors[dag_id].append(
f"Invalid DeadlineAlert data: {serialized_alert}"
)
continue

migrated_alert_ids.append(deadline_alert_id)
migrated_alerts_count += 1

conn.execute(
sa.text("""
UPDATE deadline
SET deadline_alert_id = :alert_id
WHERE dagrun_id IN (
SELECT dr.id
FROM dag_run dr
JOIN serialized_dag sd ON dr.dag_id = sd.dag_id
WHERE sd.id = :serialized_dag_id)
AND deadline_alert_id IS NULL
dag_deadline = get_dag_data(data, data_compressed)[DAG_KEY].get(DEADLINE_KEY)
except (json.JSONDecodeError, KeyError, TypeError) as e:
dags_with_errors[dag_id].append(f"Could not process serialized Dag: {e}")
continue
if not dag_deadline:
continue

dags_with_deadlines.add(dag_id)
deadline_alerts = dag_deadline if isinstance(dag_deadline, list) else [dag_deadline]

def _migrate_dag_deadlines(dag_conn: Connection) -> Iterable[str]:
for serialized_alert in deadline_alerts:
if not isinstance(serialized_alert, dict):
continue
try:
alert_data = serialized_alert.get(Encoding.VAR, serialized_alert)

if not DEADLINE_ALERT_REQUIRED_FIELDS.issubset(alert_data):
dags_with_errors[dag_id].append(
f"Invalid DeadlineAlert structure: {serialized_alert}"
)
continue

reference_data = json.dumps(alert_data[REFERENCE_KEY], sort_keys=True)
interval_data = float(alert_data.get(INTERVAL_KEY))
callback_data = json.dumps(alert_data[CALLBACK_KEY], sort_keys=True)
deadline_alert_id = str(uuid6.uuid7())

dag_conn.execute(
sa.text("""
INSERT INTO deadline_alert (
id,
created_at,
serialized_dag_id,
reference,
interval,
callback_def,
name,
description)
VALUES (
:id,
:created_at,
:serialized_dag_id,
:reference,
:interval,
:callback_def,
NULL,
NULL)
"""),
{"alert_id": deadline_alert_id, "serialized_dag_id": serialized_dag_id},
)
except Exception as e:
dags_with_errors[dag_id].append(f"Failed to process {serialized_alert}: {e}")
continue
{
"id": deadline_alert_id,
"created_at": created_at or timezone.utcnow(),
"serialized_dag_id": serialized_dag_id,
"reference": reference_data,
"interval": interval_data,
"callback_def": callback_data,
},
)

if not validate_written_data(
dag_conn, deadline_alert_id, reference_data, interval_data, callback_data
):
dags_with_errors[dag_id].append(f"Invalid DeadlineAlert data: {serialized_alert}")
continue

yield deadline_alert_id
dag_conn.execute(
sa.text("""
UPDATE deadline
SET deadline_alert_id = :alert_id
WHERE dagrun_id IN (
SELECT dr.id
FROM dag_run dr
JOIN serialized_dag sd ON dr.dag_id = sd.dag_id
WHERE sd.id = :serialized_dag_id)
AND deadline_alert_id IS NULL
"""),
{"alert_id": deadline_alert_id, "serialized_dag_id": serialized_dag_id},
)
except Exception as e:
dags_with_errors[dag_id].append(f"Failed to process {serialized_alert}: {e}")
continue

try:
with _begin_nested_transaction(conn) as dag_conn:
migrated_alert_ids = list(_migrate_dag_deadlines(dag_conn))
if migrated_alert_ids:
uuid_strings = [str(uuid_id) for uuid_id in migrated_alert_ids]
update_dag_deadline_field(conn, serialized_dag_id, uuid_strings, dialect)

# Recalculate and update the dag_hash after modifying the deadline data to ensure
# it matches what write_dag() will compute later and avoid re-serialization.
updated_result = conn.execute(
update_dag_deadline_field(dag_conn, serialized_dag_id, uuid_strings, dialect)
updated_result = dag_conn.execute(
sa.text(
"SELECT data, data_compressed "
"FROM serialized_dag "
"WHERE id = :serialized_dag_id"
),
{"serialized_dag_id": serialized_dag_id},
).fetchone()

if updated_result:
updated_dag_data = get_dag_data(
updated_result.data, updated_result.data_compressed
)
# Import here to avoid a circular dependency issue
new_hash = hash_dag(updated_dag_data)

conn.execute(
dag_conn.execute(
sa.text(
"UPDATE serialized_dag "
"SET dag_hash = :new_hash "
"WHERE id = :serialized_dag_id"
),
{"new_hash": new_hash, "serialized_dag_id": serialized_dag_id},
)

# Commit the savepoint if everything succeeded for this Dag.
savepoint.commit()

migrated_alerts_count += len(migrated_alert_ids)
except (json.JSONDecodeError, KeyError, TypeError) as e:
dags_with_errors[dag_id].append(f"Could not process serialized Dag: {e}")
savepoint.rollback()
log.exception("Could not migrate deadline for dag %s", dag_id)
dags_with_errors[dag_id].append(f"Could not migrate deadline: {e}")

log.info("Batch complete", batch_num=batch_num, total_batches=total_batches)

Expand Down Expand Up @@ -652,60 +665,64 @@ def migrate_deadline_alert_data_back_to_serialized_dag() -> None:
processed_dags.append(dag_id)
last_dag_id = dag_id

# Create a savepoint for this Dag to allow rollback on error.
savepoint = conn.begin_nested()

# Validation that does not need a DB connection.
try:
dag_data = get_dag_data(data, data_compressed)
deadline_uuids = dag_data[DAG_KEY][DEADLINE_KEY]

if not isinstance(deadline_uuids, list) or not deadline_uuids:
continue

if not all(isinstance(uuid_val, str) for uuid_val in deadline_uuids):
log.warning("Dag has non-string deadline values, skipping", dag_id=dag_id)
continue

dags_with_deadlines.add(dag_id)
restored_deadline_objects = []

alert_result = conn.execute(
sa.select(
deadline_alert_table.c.reference,
deadline_alert_table.c.interval,
deadline_alert_table.c.callback_def,
).where(deadline_alert_table.c.serialized_dag_id == sa.bindparam("serialized_dag_id")),
{"serialized_dag_id": serialized_dag_id},
).fetchall()

if not alert_result:
dags_with_errors[dag_id].append(
f"Could not find deadline_alert for serialized_dag {serialized_dag_id}"
)
continue

for alert in alert_result:
deadline_object = {
Encoding.TYPE: ENCODING_TYPE,
Encoding.VAR: {
REFERENCE_KEY: alert.reference,
INTERVAL_KEY: float(alert.interval),
CALLBACK_KEY: alert.callback_def,
},
}
restored_deadline_objects.append(deadline_object)
restored_alerts_count += 1

# Replace the UUID array with the restored objects.
if restored_deadline_objects:
update_dag_deadline_field(conn, serialized_dag_id, restored_deadline_objects, dialect)

# Commit the savepoint if everything succeeded for this Dag.
savepoint.commit()
except (json.JSONDecodeError, KeyError, TypeError):
continue
deadline_uuids = (
dag_data.get(DAG_KEY, {}).get(DEADLINE_KEY)
if isinstance(dag_data.get(DAG_KEY), dict)
else None
)

if not isinstance(deadline_uuids, list) or not deadline_uuids:
continue

if not all(isinstance(uuid_val, str) for uuid_val in deadline_uuids):
log.warning("Dag has non-string deadline values, skipping", dag_id=dag_id)
continue

dags_with_deadlines.add(dag_id)

try:
with _begin_nested_transaction(conn) as dag_conn:
alert_result = dag_conn.execute(
sa.select(
deadline_alert_table.c.reference,
deadline_alert_table.c.interval,
deadline_alert_table.c.callback_def,
).where(
deadline_alert_table.c.serialized_dag_id == sa.bindparam("serialized_dag_id")
),
{"serialized_dag_id": serialized_dag_id},
).fetchall()

if not alert_result:
dags_with_errors[dag_id].append(
f"Could not find deadline_alert for serialized_dag {serialized_dag_id}"
)
continue

restored_deadline_objects = []
for alert in alert_result:
deadline_object = {
Encoding.TYPE: ENCODING_TYPE,
Encoding.VAR: {
REFERENCE_KEY: alert.reference,
INTERVAL_KEY: float(alert.interval),
CALLBACK_KEY: alert.callback_def,
},
}
restored_deadline_objects.append(deadline_object)
restored_alerts_count += 1
if restored_deadline_objects:
update_dag_deadline_field(
dag_conn, serialized_dag_id, restored_deadline_objects, dialect
)
except Exception as e:
log.exception("Could not restore deadline for dag %s", dag_id)
dags_with_errors[dag_id].append(f"Could not restore deadline: {e}")
savepoint.rollback()

log.info("Batch complete", batch_num=batch_num, total_batches=total_batches)

Expand Down
Loading