diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 074d83585bd9a..9de1650b334a8 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -88,6 +88,7 @@ class DagFileStat(NamedTuple): last_finish_time: datetime | None last_duration: timedelta | None run_count: int + last_num_of_db_queries: int class DagParsingSignal(enum.Enum): @@ -351,7 +352,12 @@ class DagFileProcessorManager(LoggingMixin): """ DEFAULT_FILE_STAT = DagFileStat( - num_dags=0, import_errors=0, last_finish_time=None, last_duration=None, run_count=0 + num_dags=0, + import_errors=0, + last_finish_time=None, + last_duration=None, + run_count=0, + last_num_of_db_queries=0, ) def __init__( @@ -850,7 +856,18 @@ def _log_file_processing_stats(self, known_file_paths): # Last Runtime: If the process ran before, how long did it take to # finish in seconds # Last Run: When the file finished processing in the previous run. - headers = ["File Path", "PID", "Runtime", "# DAGs", "# Errors", "Last Runtime", "Last Run"] + # Last # of DB Queries: The number of queries performed to the + # Airflow database during last parsing of the file. + headers = [ + "File Path", + "PID", + "Runtime", + "# DAGs", + "# Errors", + "Last Runtime", + "Last Run", + "Last # of DB Queries", + ] rows = [] now = timezone.utcnow() @@ -866,14 +883,35 @@ def _log_file_processing_stats(self, known_file_paths): if last_run: seconds_ago = (now - last_run).total_seconds() Stats.gauge(f"dag_processing.last_run.seconds_ago.{file_name}", seconds_ago) + last_num_of_db_queries = self.get_last_num_of_db_queries(file_path) - rows.append((file_path, processor_pid, runtime, num_dags, num_errors, last_runtime, last_run)) + rows.append( + ( + file_path, + processor_pid, + runtime, + num_dags, + num_errors, + last_runtime, + last_run, + last_num_of_db_queries, + ) + ) # Sort by longest last runtime. (Can't sort None values in python3) rows.sort(key=lambda x: x[5] or 0.0, reverse=True) formatted_rows = [] - for file_path, pid, runtime, num_dags, num_errors, last_runtime, last_run in rows: + for ( + file_path, + pid, + runtime, + num_dags, + num_errors, + last_runtime, + last_run, + last_num_of_db_queries, + ) in rows: formatted_rows.append( ( file_path, @@ -883,6 +921,7 @@ def _log_file_processing_stats(self, known_file_paths): num_errors, f"{last_runtime:.2f}s" if last_runtime else None, last_run.strftime("%Y-%m-%dT%H:%M:%S") if last_run else None, + last_num_of_db_queries, ) ) log_str = ( @@ -946,6 +985,17 @@ def get_last_error_count(self, file_path) -> int | None: stat = self._file_stats.get(file_path) return stat.import_errors if stat else None + def get_last_num_of_db_queries(self, file_path) -> int | None: + """ + Retrieve the number of queries performed to the Airflow database during last parsing of the file. + + :param file_path: the path to the file that was processed + :return: the number of queries performed to the Airflow database during last parsing of the file, + or None if the file was never processed. + """ + stat = self._file_stats.get(file_path) + return stat.last_num_of_db_queries if stat else None + def get_last_finish_time(self, file_path) -> datetime | None: """ Retrieve the last completion time for processing a specific path. @@ -1031,13 +1081,14 @@ def _collect_results_from_processor(self, processor) -> None: last_finish_time = timezone.utcnow() if processor.result is not None: - num_dags, count_import_errors = processor.result + num_dags, count_import_errors, last_num_of_db_queries = processor.result else: self.log.error( "Processor for %s exited with return code %s.", processor.file_path, processor.exit_code ) count_import_errors = -1 num_dags = 0 + last_num_of_db_queries = 0 last_duration = last_finish_time - processor.start_time stat = DagFileStat( @@ -1046,6 +1097,7 @@ def _collect_results_from_processor(self, processor) -> None: last_finish_time=last_finish_time, last_duration=last_duration, run_count=self.get_run_count(processor.file_path) + 1, + last_num_of_db_queries=last_num_of_db_queries, ) self._file_stats[processor.file_path] = stat file_name = Path(processor.file_path).stem @@ -1242,6 +1294,7 @@ def _kill_timed_out_processors(self): last_finish_time=now, last_duration=duration, run_count=self.get_run_count(file_path) + 1, + last_num_of_db_queries=0, ) self._file_stats[processor.file_path] = stat diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 8df64c9f1eb3e..ceb0476b8a382 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Iterable, Iterator from setproctitle import setproctitle -from sqlalchemy import delete, func, or_, select +from sqlalchemy import delete, event, func, or_, select from airflow import settings from airflow.api_internal.internal_api_call import internal_api_call @@ -99,7 +99,7 @@ def __init__( # The process that was launched to process the given . self._process: multiprocessing.process.BaseProcess | None = None # The result of DagFileProcessor.process_file(file_path). - self._result: tuple[int, int] | None = None + self._result: tuple[int, int, int] | None = None # Whether the process is done running. self._done = False # When the process started. @@ -162,7 +162,7 @@ def _handle_dag_file_processing(): log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path) dag_file_processor = DagFileProcessor(dag_ids=dag_ids, dag_directory=dag_directory, log=log) - result: tuple[int, int] = dag_file_processor.process_file( + result: tuple[int, int, int] = dag_file_processor.process_file( file_path=file_path, pickle_dags=pickle_dags, callback_requests=callback_requests, @@ -350,7 +350,7 @@ def done(self) -> bool: return False @property - def result(self) -> tuple[int, int] | None: + def result(self) -> tuple[int, int, int] | None: """Result of running ``DagFileProcessor.process_file()``.""" if not self.done: raise AirflowException("Tried to get the result before it's done!") @@ -415,6 +415,7 @@ def __init__(self, dag_ids: list[str] | None, dag_directory: str, log: logging.L self._log = log self._dag_directory = dag_directory self.dag_warnings: set[tuple[str, str]] = set() + self._last_num_of_db_queries = 0 @classmethod @internal_api_call @@ -815,7 +816,7 @@ def process_file( callback_requests: list[CallbackRequest], pickle_dags: bool = False, session: Session = NEW_SESSION, - ) -> tuple[int, int]: + ) -> tuple[int, int, int]: """ Process a Python file containing Airflow DAGs. @@ -833,16 +834,20 @@ def process_file( :param pickle_dags: whether serialize the DAGs found in the file and save them to the db :param session: Sqlalchemy ORM Session - :return: number of dags found, count of import errors + :return: number of dags found, count of import errors, last number of db queries """ self.log.info("Processing file %s for tasks to queue", file_path) + @event.listens_for(session, "do_orm_execute") + def _count_db_queries(orm_execute_state): + self._last_num_of_db_queries += 1 + try: dagbag = DagFileProcessor._get_dagbag(file_path) except Exception: self.log.exception("Failed at reloading the DAG file %s", file_path) Stats.incr("dag_file_refresh_error", 1, 1, tags={"file_path": file_path}) - return 0, 0 + return 0, 0, self._last_num_of_db_queries if dagbag.dags: self.log.info("DAG(s) %s retrieved from %s", ", ".join(map(repr, dagbag.dags)), file_path) @@ -859,7 +864,7 @@ def process_file( # parse error we still need to progress the state of TIs, # otherwise they might be stuck in queued/running for ever! self.execute_callbacks_without_dag(callback_requests, session) - return 0, len(dagbag.import_errors) + return 0, len(dagbag.import_errors), self._last_num_of_db_queries self.execute_callbacks(dagbag, callback_requests, session) session.commit() @@ -889,7 +894,7 @@ def process_file( except Exception: self.log.exception("Error logging DAG warnings.") - return len(dagbag.dags), len(dagbag.import_errors) + return len(dagbag.dags), len(dagbag.import_errors), self._last_num_of_db_queries @staticmethod @internal_api_call diff --git a/tests/dag_processing/test_job_runner.py b/tests/dag_processing/test_job_runner.py index 8e2bfbde623fa..b937a29ce2745 100644 --- a/tests/dag_processing/test_job_runner.py +++ b/tests/dag_processing/test_job_runner.py @@ -81,7 +81,7 @@ def __init__(self, file_path, pickle_dags, dag_ids, dag_directory, callbacks): writable.send("abc") writable.close() self._waitable_handle = readable - self._result = 0, 0 + self._result = 0, 0, 0 def start(self): pass @@ -270,7 +270,7 @@ def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self): mock_processor.terminate.side_effect = None manager.processor._processors["missing_file.txt"] = mock_processor - manager.processor._file_stats["missing_file.txt"] = DagFileStat(0, 0, None, None, 0) + manager.processor._file_stats["missing_file.txt"] = DagFileStat(0, 0, None, None, 0, 0) manager.processor.set_file_paths(["abc.txt"]) assert manager.processor._processors == {} @@ -549,7 +549,7 @@ def test_recently_modified_file_is_parsed_with_mtime_mode( # let's say the DAG was just parsed 10 seconds before the Freezed time last_finish_time = freezed_base_time - timedelta(seconds=10) manager.processor._file_stats = { - "file_1.py": DagFileStat(1, 0, last_finish_time, timedelta(seconds=1.0), 1), + "file_1.py": DagFileStat(1, 0, last_finish_time, timedelta(seconds=1.0), 1, 1), } with time_machine.travel(freezed_base_time): manager.processor.set_file_paths(dag_files) @@ -651,6 +651,7 @@ def test_scan_stale_dags(self): last_finish_time=timezone.utcnow() + timedelta(hours=1), last_duration=1, run_count=1, + last_num_of_db_queries=1, ) manager.processor._file_paths = [test_dag_path] manager.processor._file_stats[test_dag_path] = stat @@ -733,6 +734,7 @@ def test_scan_stale_dags_standalone_mode(self): last_finish_time=timezone.utcnow() + timedelta(hours=1), last_duration=1, run_count=1, + last_num_of_db_queries=1, ) manager.processor._file_paths = [test_dag_path] manager.processor._file_stats[test_dag_path] = stat diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index b79095994ab0e..124a13ff11573 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -19,6 +19,7 @@ import datetime import os +import pathlib import sys from unittest import mock from unittest.mock import MagicMock, patch @@ -39,6 +40,7 @@ from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.types import DagRunType +from tests.test_utils.asserts import assert_queries_count from tests.test_utils.compat import ParseImportError from tests.test_utils.config import conf_vars, env_vars from tests.test_utils.db import ( @@ -67,6 +69,7 @@ # Filename to be used for dags that are created in an ad-hoc manner and can be removed/ # created at runtime TEMP_DAG_FILENAME = "temp_dag.py" +TEST_DAG_FOLDER = pathlib.Path(__file__).parents[1].resolve() / "dags" @pytest.fixture(scope="class") @@ -1008,6 +1011,17 @@ def test_nullbyte_exception_handling_when_preimporting_airflow(self, mock_contex ) processor.start() + def test_counter_for_last_num_of_db_queries(self): + dag_filepath = TEST_DAG_FOLDER / "test_dag_for_db_queries_counter.py" + + with create_session() as session: + with assert_queries_count( + expected_count=94, + margin=10, + session=session, + ): + self._process_file(dag_filepath, TEST_DAG_FOLDER, session) + class TestProcessorAgent: @pytest.fixture(autouse=True) diff --git a/tests/dags/test_dag_for_db_queries_counter.py b/tests/dags/test_dag_for_db_queries_counter.py new file mode 100644 index 0000000000000..36bdaaa05c83b --- /dev/null +++ b/tests/dags/test_dag_for_db_queries_counter.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime + +from airflow.models.dag import DAG +from airflow.operators.empty import EmptyOperator + +args = {"owner": "airflow", "retries": 3, "start_date": datetime.datetime(2022, 1, 1)} + + +def create_dag(suffix): + dag = DAG( + dag_id=f"test_for_db_queries_counter__{suffix}", + default_args=args, + schedule="0 0 * * *", + dagrun_timeout=datetime.timedelta(minutes=60), + ) + + with dag: + EmptyOperator(task_id="test_task") + return dag + + +# 26 queries for parsing file with one DAG, 17 queries more for each new dag. +# As a result 94 queries for this file with 5 DAGs. +for i in range(0, 5): + globals()[f"dag_{i}"] = create_dag(f"dag_{i}") diff --git a/tests/test_utils/asserts.py b/tests/test_utils/asserts.py index d06bb454de4d6..56bf8cc1fee57 100644 --- a/tests/test_utils/asserts.py +++ b/tests/test_utils/asserts.py @@ -22,13 +22,16 @@ import traceback from collections import Counter from contextlib import contextmanager -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple from sqlalchemy import event # Long import to not create a copy of the reference, but to refer to one place. import airflow.settings +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + log = logging.getLogger(__name__) @@ -91,17 +94,30 @@ class CountQueries: not be included. """ - def __init__(self, *, stacklevel: int = 1, stacklevel_from_module: str | None = None): + def __init__( + self, + *, + stacklevel: int = 1, + stacklevel_from_module: str | None = None, + session: Session | None = None, + ): self.result: Counter[str] = Counter() self.stacklevel = stacklevel self.stacklevel_from_module = stacklevel_from_module + self.session = session def __enter__(self): - event.listen(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute) + if self.session: + event.listen(self.session, "do_orm_execute", self.after_cursor_execute) + else: + event.listen(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute) return self.result def __exit__(self, type_, value, tb): - event.remove(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute) + if self.session: + event.remove(self.session, "do_orm_execute", self.after_cursor_execute) + else: + event.remove(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute) log.debug("Queries count: %d", sum(self.result.values())) def after_cursor_execute(self, *args, **kwargs): @@ -125,6 +141,7 @@ def assert_queries_count( margin: int = 0, stacklevel: int = 5, stacklevel_from_module: str | None = None, + session: Session | None = None, ): """ Asserts that the number of queries is as expected with the margin applied @@ -136,7 +153,9 @@ def assert_queries_count( :param stacklevel: limits the output stack trace to that numbers of frame :param stacklevel_from_module: Filter stack trace from specific module """ - with count_queries(stacklevel=stacklevel, stacklevel_from_module=stacklevel_from_module) as result: + with count_queries( + stacklevel=stacklevel, stacklevel_from_module=stacklevel_from_module, session=session + ) as result: yield None count = sum(result.values())