diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 1902c985be771..1afea71c00167 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -58,6 +58,7 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.timeout import timeout from airflow.utils.types import NOTSET +from airflow.utils.warnings import capture_with_reraise if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -67,13 +68,23 @@ class FileLoadStat(NamedTuple): - """Information about single file.""" + """ + Information about a single file. + + :param file: Loaded file. + :param duration: Time spent processing this file. + :param dag_num: Total number of DAGs loaded in this file. + :param task_num: Total number of Tasks loaded in this file. + :param dags: IDs of the DAGs loaded in this file. + :param warning_num: Total number of warnings captured from processing this file. + """ file: str duration: timedelta dag_num: int task_num: int dags: str + warning_num: int class DagBag(LoggingMixin): @@ -139,6 +150,7 @@ def __init__( # the file's last modified timestamp when we last read it self.file_last_changed: dict[str, datetime] = {} self.import_errors: dict[str, str] = {} + self.captured_warnings: dict[str, tuple[str, ...]] = {} self.has_logged = False self.read_dags_from_db = read_dags_from_db # Only used by read_dags_from_db=True @@ -314,10 +326,21 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): # Ensure we don't pick up anything else we didn't mean to DagContext.autoregistered_dags.clear() - if filepath.endswith(".py") or not zipfile.is_zipfile(filepath): - mods = self._load_modules_from_file(filepath, safe_mode) - else: - mods = self._load_modules_from_zip(filepath, safe_mode) + self.captured_warnings.pop(filepath, None) + with capture_with_reraise() as captured_warnings: + if filepath.endswith(".py") or not zipfile.is_zipfile(filepath): + mods = self._load_modules_from_file(filepath, safe_mode) + else: + mods = 
self._load_modules_from_zip(filepath, safe_mode) + + if captured_warnings: + formatted_warnings = [] + for msg in captured_warnings: + category = msg.category.__name__ + if (module := msg.category.__module__) != "builtins": + category = f"{module}.{category}" + formatted_warnings.append(f"{msg.filename}:{msg.lineno}: {category}: {msg.message}") + self.captured_warnings[filepath] = tuple(formatted_warnings) found_dags = self._process_modules(filepath, mods, file_last_changed_on_disk) @@ -566,6 +589,7 @@ def collect_dags( dag_num=len(found_dags), task_num=sum(len(dag.tasks) for dag in found_dags), dags=str([dag.dag_id for dag in found_dags]), + warning_num=len(self.captured_warnings.get(filepath, [])), ) ) except Exception as e: diff --git a/airflow/utils/warnings.py b/airflow/utils/warnings.py new file mode 100644 index 0000000000000..bcff4c06fa286 --- /dev/null +++ b/airflow/utils/warnings.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import warnings +from collections.abc import Generator +from contextlib import contextmanager + + +@contextmanager +def capture_with_reraise() -> Generator[list[warnings.WarningMessage], None, None]: + """Capture warnings in context and re-raise them on exit from the context manager.""" + captured_warnings = [] + try: + with warnings.catch_warnings(record=True) as captured_warnings: + yield captured_warnings + finally: + if captured_warnings: + for cw in captured_warnings: + warnings.warn_explicit( + message=cw.message, + category=cw.category, + filename=cw.filename, + lineno=cw.lineno, + source=cw.source, + ) diff --git a/tests/dags/test_dag_warnings.py b/tests/dags/test_dag_warnings.py new file mode 100644 index 0000000000000..3eb3e4ccf7d8a --- /dev/null +++ b/tests/dags/test_dag_warnings.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import warnings +from datetime import datetime + +from airflow.exceptions import RemovedInAirflow3Warning +from airflow.models.baseoperator import BaseOperator +from airflow.models.dag import DAG + +DAG_ID = "test_dag_warnings" + + +class TestOperator(BaseOperator): + def __init__(self, *, parameter: str | None = None, deprecated_parameter: str | None = None, **kwargs): + super().__init__(**kwargs) + if deprecated_parameter: + warnings.warn("Deprecated Parameter", category=RemovedInAirflow3Warning, stacklevel=2) + parameter = deprecated_parameter + self.parameter = parameter + + def execute(self, context): + return None + + +def some_warning(): + warnings.warn("Some Warning", category=UserWarning, stacklevel=1) + + +with DAG(DAG_ID, start_date=datetime(2024, 1, 1), schedule=None): + TestOperator(task_id="test-task", parameter="foo") + TestOperator(task_id="test-task-deprecated", deprecated_parameter="bar") + +some_warning() diff --git a/tests/dags/test_dag_warnings.zip b/tests/dags/test_dag_warnings.zip new file mode 100644 index 0000000000000..eb35a415ffa08 Binary files /dev/null and b/tests/dags/test_dag_warnings.zip differ diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index 95baefff2c11a..3623aa38a554e 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -22,6 +22,7 @@ import pathlib import sys import textwrap +import warnings import zipfile from copy import deepcopy from datetime import datetime, timedelta, timezone @@ -35,7 +36,7 @@ import airflow.example_dags from airflow import settings -from airflow.exceptions import SerializationError +from airflow.exceptions import RemovedInAirflow3Warning, SerializationError from airflow.models.dag import DAG, DagModel from airflow.models.dagbag import DagBag from airflow.models.serialized_dag import SerializedDagModel @@ -1145,3 +1146,49 @@ def test_dagbag_dag_collection(self): # test that dagbag.dags is not empty if collect_dags 
is True dagbag = DagBag(dag_folder=TEST_DAGS_FOLDER, include_examples=False) assert dagbag.dags + + @pytest.mark.filterwarnings("default::airflow.exceptions.RemovedInAirflow3Warning") + def test_dabgag_captured_warnings(self): + dag_file = os.path.join(TEST_DAGS_FOLDER, "test_dag_warnings.py") + dagbag = DagBag(dag_folder=dag_file, include_examples=False, collect_dags=False) + assert dag_file not in dagbag.captured_warnings + + dagbag.collect_dags(dag_folder=dagbag.dag_folder, include_examples=False, only_if_updated=False) + assert len(dagbag.dag_ids) == 1 + assert dag_file in dagbag.captured_warnings + captured_warnings = dagbag.captured_warnings[dag_file] + assert len(captured_warnings) == 2 + assert dagbag.dagbag_stats[0].warning_num == 2 + + assert captured_warnings[0] == ( + f"{dag_file}:48: airflow.exceptions.RemovedInAirflow3Warning: Deprecated Parameter" + ) + assert captured_warnings[1] == f"{dag_file}:50: UserWarning: Some Warning" + + with warnings.catch_warnings(): + # Ignore RemovedInAirflow3Warning; it should then be absent from the captured warnings + warnings.simplefilter("ignore", RemovedInAirflow3Warning) + dagbag.collect_dags(dag_folder=dagbag.dag_folder, include_examples=False, only_if_updated=False) + assert dag_file in dagbag.captured_warnings + assert len(dagbag.captured_warnings[dag_file]) == 1 + assert dagbag.dagbag_stats[0].warning_num == 1 + + # Disable all warnings, no captured warnings expected + warnings.simplefilter("ignore") + dagbag.collect_dags(dag_folder=dagbag.dag_folder, include_examples=False, only_if_updated=False) + assert dag_file not in dagbag.captured_warnings + assert dagbag.dagbag_stats[0].warning_num == 0 + + @pytest.mark.filterwarnings("default::airflow.exceptions.RemovedInAirflow3Warning") + def test_dabgag_captured_warnings_zip(self): + dag_file = os.path.join(TEST_DAGS_FOLDER, "test_dag_warnings.zip") + in_zip_dag_file = f"{dag_file}/test_dag_warnings.py" + dagbag = DagBag(dag_folder=dag_file, 
include_examples=False) + assert len(dagbag.dag_ids) == 1 + assert dag_file in dagbag.captured_warnings + captured_warnings = dagbag.captured_warnings[dag_file] + assert len(captured_warnings) == 2 + assert captured_warnings[0] == ( + f"{in_zip_dag_file}:48: airflow.exceptions.RemovedInAirflow3Warning: Deprecated Parameter" + ) + assert captured_warnings[1] == f"{in_zip_dag_file}:50: UserWarning: Some Warning" diff --git a/tests/utils/test_warnings.py b/tests/utils/test_warnings.py new file mode 100644 index 0000000000000..666e565d796e2 --- /dev/null +++ b/tests/utils/test_warnings.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import warnings + +import pytest + +from airflow.utils.warnings import capture_with_reraise + + +class TestCaptureWithReraise: + @staticmethod + def raise_warnings(): + warnings.warn("Foo", UserWarning, stacklevel=2) + warnings.warn("Bar", UserWarning, stacklevel=2) + warnings.warn("Baz", UserWarning, stacklevel=2) + + def test_capture_no_warnings(self): + with warnings.catch_warnings(): + warnings.simplefilter("error") + with capture_with_reraise() as cw: + pass + assert cw == [] + + def test_capture_warnings(self): + with pytest.warns(UserWarning, match="(Foo|Bar|Baz)") as ctx: + with capture_with_reraise() as cw: + self.raise_warnings() + assert len(cw) == 3 + assert len(ctx.list) == 3 + + def test_capture_warnings_with_parent_error_filter(self): + with warnings.catch_warnings(record=True) as records: + warnings.filterwarnings("error", message="Bar") + with capture_with_reraise() as cw: + with pytest.raises(UserWarning, match="Bar"): + self.raise_warnings() + assert len(cw) == 1 + assert len(records) == 1 + + def test_capture_warnings_with_parent_ignore_filter(self): + with warnings.catch_warnings(record=True) as records: + warnings.filterwarnings("ignore", message="Baz") + with capture_with_reraise() as cw: + self.raise_warnings() + assert len(cw) == 2 + assert len(records) == 2 + + def test_capture_warnings_with_filters(self): + with warnings.catch_warnings(record=True) as records: + with capture_with_reraise() as cw: + warnings.filterwarnings("ignore", message="Foo") + self.raise_warnings() + assert len(cw) == 2 + assert len(records) == 2 + + def test_capture_warnings_with_error_filters(self): + with warnings.catch_warnings(record=True) as records: + with capture_with_reraise() as cw: + warnings.filterwarnings("error", message="Bar") + with pytest.raises(UserWarning, match="Bar"): + self.raise_warnings() + assert len(cw) == 1 + assert len(records) == 1