Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.timeout import timeout
from airflow.utils.types import NOTSET
from airflow.utils.warnings import capture_with_reraise

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand All @@ -67,13 +68,23 @@


class FileLoadStat(NamedTuple):
"""Information about single file."""
"""
Information about single file.

:param file: Loaded file.
:param duration: Time spent on process file.
:param dag_num: Total number of DAGs loaded in this file.
:param task_num: Total number of Tasks loaded in this file.
:param dags: DAGs names loaded in this file.
:param warning_num: Total number of warnings captured from processing this file.
"""

file: str
duration: timedelta
dag_num: int
task_num: int
dags: str
warning_num: int


class DagBag(LoggingMixin):
Expand Down Expand Up @@ -139,6 +150,7 @@ def __init__(
# the file's last modified timestamp when we last read it
self.file_last_changed: dict[str, datetime] = {}
self.import_errors: dict[str, str] = {}
self.captured_warnings: dict[str, tuple[str, ...]] = {}
self.has_logged = False
self.read_dags_from_db = read_dags_from_db
# Only used by read_dags_from_db=True
Expand Down Expand Up @@ -314,10 +326,21 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
# Ensure we don't pick up anything else we didn't mean to
DagContext.autoregistered_dags.clear()

if filepath.endswith(".py") or not zipfile.is_zipfile(filepath):
mods = self._load_modules_from_file(filepath, safe_mode)
else:
mods = self._load_modules_from_zip(filepath, safe_mode)
self.captured_warnings.pop(filepath, None)
with capture_with_reraise() as captured_warnings:
if filepath.endswith(".py") or not zipfile.is_zipfile(filepath):
mods = self._load_modules_from_file(filepath, safe_mode)
else:
mods = self._load_modules_from_zip(filepath, safe_mode)

if captured_warnings:
formatted_warnings = []
for msg in captured_warnings:
category = msg.category.__name__
if (module := msg.category.__module__) != "builtins":
category = f"{module}.{category}"
formatted_warnings.append(f"{msg.filename}:{msg.lineno}: {category}: {msg.message}")
self.captured_warnings[filepath] = tuple(formatted_warnings)

found_dags = self._process_modules(filepath, mods, file_last_changed_on_disk)

Expand Down Expand Up @@ -566,6 +589,7 @@ def collect_dags(
dag_num=len(found_dags),
task_num=sum(len(dag.tasks) for dag in found_dags),
dags=str([dag.dag_id for dag in found_dags]),
warning_num=len(self.captured_warnings.get(filepath, [])),
)
)
except Exception as e:
Expand Down
40 changes: 40 additions & 0 deletions airflow/utils/warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import warnings
from collections.abc import Generator
from contextlib import contextmanager


@contextmanager
def capture_with_reraise() -> Generator[list[warnings.WarningMessage], None, None]:
"""Capture warnings in context and re-raise it on exit from the context manager."""
captured_warnings = []
try:
with warnings.catch_warnings(record=True) as captured_warnings:
yield captured_warnings
finally:
if captured_warnings:
for cw in captured_warnings:
warnings.warn_explicit(
message=cw.message,
category=cw.category,
filename=cw.filename,
lineno=cw.lineno,
source=cw.source,
)
50 changes: 50 additions & 0 deletions tests/dags/test_dag_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import warnings
from datetime import datetime

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG

DAG_ID = "test_dag_warnings"


class TestOperator(BaseOperator):
def __init__(self, *, parameter: str | None = None, deprecated_parameter: str | None = None, **kwargs):
super().__init__(**kwargs)
if deprecated_parameter:
warnings.warn("Deprecated Parameter", category=RemovedInAirflow3Warning, stacklevel=2)
parameter = deprecated_parameter
self.parameter = parameter

def execute(self, context):
return None


def some_warning():
warnings.warn("Some Warning", category=UserWarning, stacklevel=1)


with DAG(DAG_ID, start_date=datetime(2024, 1, 1), schedule=None):
TestOperator(task_id="test-task", parameter="foo")
TestOperator(task_id="test-task-deprecated", deprecated_parameter="bar")

some_warning()
Binary file added tests/dags/test_dag_warnings.zip
Binary file not shown.
49 changes: 48 additions & 1 deletion tests/models/test_dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pathlib
import sys
import textwrap
import warnings
import zipfile
from copy import deepcopy
from datetime import datetime, timedelta, timezone
Expand All @@ -35,7 +36,7 @@

import airflow.example_dags
from airflow import settings
from airflow.exceptions import SerializationError
from airflow.exceptions import RemovedInAirflow3Warning, SerializationError
from airflow.models.dag import DAG, DagModel
from airflow.models.dagbag import DagBag
from airflow.models.serialized_dag import SerializedDagModel
Expand Down Expand Up @@ -1145,3 +1146,49 @@ def test_dagbag_dag_collection(self):
# test that dagbag.dags is not empty if collect_dags is True
dagbag = DagBag(dag_folder=TEST_DAGS_FOLDER, include_examples=False)
assert dagbag.dags

@pytest.mark.filterwarnings("default::airflow.exceptions.RemovedInAirflow3Warning")
def test_dabgag_captured_warnings(self):
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_dag_warnings.py")
dagbag = DagBag(dag_folder=dag_file, include_examples=False, collect_dags=False)
assert dag_file not in dagbag.captured_warnings

dagbag.collect_dags(dag_folder=dagbag.dag_folder, include_examples=False, only_if_updated=False)
assert len(dagbag.dag_ids) == 1
assert dag_file in dagbag.captured_warnings
captured_warnings = dagbag.captured_warnings[dag_file]
assert len(captured_warnings) == 2
assert dagbag.dagbag_stats[0].warning_num == 2

assert captured_warnings[0] == (
f"{dag_file}:48: airflow.exceptions.RemovedInAirflow3Warning: Deprecated Parameter"
)
assert captured_warnings[1] == f"{dag_file}:50: UserWarning: Some Warning"

with warnings.catch_warnings():
# Disable capture RemovedInAirflow3Warning, and it should be reflected in captured warnings
warnings.simplefilter("ignore", RemovedInAirflow3Warning)
dagbag.collect_dags(dag_folder=dagbag.dag_folder, include_examples=False, only_if_updated=False)
assert dag_file in dagbag.captured_warnings
assert len(dagbag.captured_warnings[dag_file]) == 1
assert dagbag.dagbag_stats[0].warning_num == 1

# Disable all warnings, no captured warnings expected
warnings.simplefilter("ignore")
dagbag.collect_dags(dag_folder=dagbag.dag_folder, include_examples=False, only_if_updated=False)
assert dag_file not in dagbag.captured_warnings
assert dagbag.dagbag_stats[0].warning_num == 0

@pytest.mark.filterwarnings("default::airflow.exceptions.RemovedInAirflow3Warning")
def test_dabgag_captured_warnings_zip(self):
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_dag_warnings.zip")
in_zip_dag_file = f"{dag_file}/test_dag_warnings.py"
dagbag = DagBag(dag_folder=dag_file, include_examples=False)
assert len(dagbag.dag_ids) == 1
assert dag_file in dagbag.captured_warnings
captured_warnings = dagbag.captured_warnings[dag_file]
assert len(captured_warnings) == 2
assert captured_warnings[0] == (
f"{in_zip_dag_file}:48: airflow.exceptions.RemovedInAirflow3Warning: Deprecated Parameter"
)
assert captured_warnings[1] == f"{in_zip_dag_file}:50: UserWarning: Some Warning"
79 changes: 79 additions & 0 deletions tests/utils/test_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import warnings

import pytest

from airflow.utils.warnings import capture_with_reraise


class TestCaptureWithReraise:
@staticmethod
def raise_warnings():
warnings.warn("Foo", UserWarning, stacklevel=2)
warnings.warn("Bar", UserWarning, stacklevel=2)
warnings.warn("Baz", UserWarning, stacklevel=2)

def test_capture_no_warnings(self):
with warnings.catch_warnings():
warnings.simplefilter("error")
with capture_with_reraise() as cw:
pass
assert cw == []

def test_capture_warnings(self):
with pytest.warns(UserWarning, match="(Foo|Bar|Baz)") as ctx:
with capture_with_reraise() as cw:
self.raise_warnings()
assert len(cw) == 3
assert len(ctx.list) == 3

def test_capture_warnings_with_parent_error_filter(self):
with warnings.catch_warnings(record=True) as records:
warnings.filterwarnings("error", message="Bar")
with capture_with_reraise() as cw:
with pytest.raises(UserWarning, match="Bar"):
self.raise_warnings()
assert len(cw) == 1
assert len(records) == 1

def test_capture_warnings_with_parent_ignore_filter(self):
with warnings.catch_warnings(record=True) as records:
warnings.filterwarnings("ignore", message="Baz")
with capture_with_reraise() as cw:
self.raise_warnings()
assert len(cw) == 2
assert len(records) == 2

def test_capture_warnings_with_filters(self):
with warnings.catch_warnings(record=True) as records:
with capture_with_reraise() as cw:
warnings.filterwarnings("ignore", message="Foo")
self.raise_warnings()
assert len(cw) == 2
assert len(records) == 2

def test_capture_warnings_with_error_filters(self):
with warnings.catch_warnings(record=True) as records:
with capture_with_reraise() as cw:
warnings.filterwarnings("error", message="Bar")
with pytest.raises(UserWarning, match="Bar"):
self.raise_warnings()
assert len(cw) == 1
assert len(records) == 1