From 494832ca35db116959050e9c01921939ebbd5b8e Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Wed, 24 Sep 2025 20:53:16 +0900 Subject: [PATCH 1/4] fix gc issue --- airflow-core/src/airflow/settings.py | 46 +++++++++++++++------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py index 5a87384487641..9ddb1f8e65456 100644 --- a/airflow-core/src/airflow/settings.py +++ b/airflow-core/src/airflow/settings.py @@ -370,35 +370,24 @@ def _configure_async_session() -> None: ) -def configure_orm(disable_connection_pool=False, pool_class=None): - """Configure ORM using SQLAlchemy.""" +def _configure_session(disable_connection_pool: bool, pool_class): + """(Re)create engine, NonScopedSession, Session using SQLAlchemy.""" from airflow._shared.secrets_masker import mask_secret - if _is_sqlite_db_path_relative(SQL_ALCHEMY_CONN): - from airflow.exceptions import AirflowConfigException - - raise AirflowConfigException( - f"Cannot use relative path: `{SQL_ALCHEMY_CONN}` to connect to sqlite. " - "Please use absolute path such as `sqlite:////tmp/airflow.db`." - ) + global NonScopedSession, Session, engine - global NonScopedSession - global Session - global engine + log.debug("Setting up DB connection pool (PID %s)", os.getpid()) if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true": # Skip DB initialization in unit tests, if DB tests are skipped Session = SkipDBTestsSession engine = None return - log.debug("Setting up DB connection pool (PID %s)", os.getpid()) + engine_args = prepare_engine_args(disable_connection_pool, pool_class) connect_args = _get_connect_args("sync") if SQL_ALCHEMY_CONN.startswith("sqlite"): - # FastAPI runs sync endpoints in a separate thread. SQLite does not allow - # to use objects created in another threads by default. Allowing that in test - # to so the `test` thread and the tested endpoints can use common objects. connect_args["check_same_thread"] = False engine = create_engine( @@ -407,7 +396,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None): **engine_args, future=True, ) - _configure_async_session() + mask_secret(engine.url.password) setup_event_handlers(engine) @@ -420,17 +409,32 @@ def configure_orm(disable_connection_pool=False, pool_class=None): autoflush=False, expire_on_commit=False, ) + NonScopedSession = _session_maker(engine) Session = scoped_session(NonScopedSession) + +def configure_orm(disable_connection_pool=False, pool_class=None): + """Configure ORM using SQLAlchemy.""" + if _is_sqlite_db_path_relative(SQL_ALCHEMY_CONN): + from airflow.exceptions import AirflowConfigException + + raise AirflowConfigException( + f"Cannot use relative path: `{SQL_ALCHEMY_CONN}` to connect to sqlite. " + "Please use absolute path such as `sqlite:////tmp/airflow.db`." + ) + + _configure_session(disable_connection_pool, pool_class) + _configure_async_session() + if register_at_fork := getattr(os, "register_at_fork", None): # https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork def clean_in_fork(): _globals = globals() - if engine := _globals.get("engine"): - engine.dispose(close=False) - if async_engine := _globals.get("async_engine"): - async_engine.sync_engine.dispose(close=False) + if _globals.get("engine"): + _configure_session(disable_connection_pool, pool_class) + if _globals.get("async_engine"): + _configure_async_session() # Won't work on Windows register_at_fork(after_in_child=clean_in_fork) From 0688a2a8d94088da526ee8d1cfbd449405080948 Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Sun, 19 Oct 2025 22:03:07 +0900 Subject: [PATCH 2/4] add test --- airflow-core/src/airflow/settings.py | 1 + .../tests/unit/core/test_db_pool_safety.py | 388 ++++++++++++++++++ 2 files changed, 389 insertions(+) create mode 100644 airflow-core/tests/unit/core/test_db_pool_safety.py diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py index 9ddb1f8e65456..605e2186f57be 100644 --- a/airflow-core/src/airflow/settings.py +++ b/airflow-core/src/airflow/settings.py @@ -416,6 +416,7 @@ def _configure_session(disable_connection_pool: bool, pool_class): def configure_orm(disable_connection_pool=False, pool_class=None): """Configure ORM using SQLAlchemy.""" + print(SQL_ALCHEMY_CONN) if _is_sqlite_db_path_relative(SQL_ALCHEMY_CONN): from airflow.exceptions import AirflowConfigException diff --git a/airflow-core/tests/unit/core/test_db_pool_safety.py b/airflow-core/tests/unit/core/test_db_pool_safety.py new file mode 100644 index 0000000000000..6c423705ce8ef --- /dev/null +++ b/airflow-core/tests/unit/core/test_db_pool_safety.py @@ -0,0 +1,388 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import weakref +import gc as gc_module + +import pytest +from sqlalchemy.exc import OperationalError + +from airflow import settings +from airflow.utils.session import create_session + + +@pytest.mark.backend("mysql") +class TestLocalTaskJobForkSafety: + """ + Test fork safety for LocalTaskJobRunner with MySQL backend + """ + def test_old_dispose_causes_parent_connection_loss(self): + """ + BEFORE FIX: Demonstrates the problem + Using dispose(close=False) in child causes parent connection to die + """ + # WeakRef로 Pool GC 추적 + gc_callback_called = [] + + # Airflow의 실제 engine 사용 + engine = settings.engine + pool = engine.pool + weak_pool = weakref.ref(pool, lambda ref: gc_callback_called.append(True)) + + if register_at_fork := getattr(os, "register_at_fork", None): + # https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork + def clean_in_fork(): + print("engine disposed") + engine.dispose(close=False) + + # Won't work on Windows + register_at_fork(after_in_child=clean_in_fork) + + with engine.connect() as conn: + thread_id = conn.execute("SELECT CONNECTION_ID()").scalar() + + # Fork + pid = os.fork() + if pid == 0: # Child + try: + gc_module.collect() + finally: + os._exit(0) + + # Parent + os.waitpid(pid, 0) + + # Verify GC happened + assert len(gc_callback_called) > 0, "Pool was garbage collected in child" + + # Verify connection is dead + with pytest.raises( + OperationalError, + match="MySQL server has gone away|2006|2013" + ): + conn.execute("SELECT 1") + +# def test_new_engine_creation_preserves_parent_connection(self): +# """ +# AFTER FIX: Demonstrates the solution +# Creating new engine in child preserves parent connection +# """ +# gc_callback_called = [] +# +# engine = settings.engine +# pool = engine.pool +# weak_pool = weakref.ref(pool, lambda ref: gc_callback_called.append(True)) +# +# with engine.connect() as conn: +# thread_id = conn.execute("SELECT CONNECTION_ID()").scalar() +# +# pid = os.fork() +# if pid == 0: # Child +# try: +# # NEW approach - create new engine +# from airflow.settings import configure_orm +# configure_orm() +# # Parent engine/pool은 건드리지 않음! +# finally: +# os._exit(0) +# +# os.waitpid(pid, 0) +# +# # Verify NO GC happened to parent pool +# assert len(gc_callback_called) == 0, "Pool was NOT garbage collected" +# +# # Verify connection is alive +# result = conn.execute("SELECT 1").scalar() +# assert result == 1 +# +# # Verify same MySQL thread_id +# current_id = conn.execute("SELECT CONNECTION_ID()").scalar() +# assert current_id == thread_id +# +# @pytest.mark.parametrize("approach,expect_failure", [ +# ("old_dispose", True), +# ("new_engine", False), +# ]) +# def test_fork_approaches_comparison(self, approach, expect_failure): +# """ +# Parameterized test comparing old vs new approach +# Shows both behaviors in one test +# """ +# engine = settings.engine +# +# with engine.connect() as conn: +# thread_id = conn.execute("SELECT CONNECTION_ID()").scalar() +# +# pid = os.fork() +# if pid == 0: # Child +# try: +# if approach == "old_dispose": +# # OLD: causes parent connection to die +# engine.dispose(close=False) +# gc_module.collect() +# else: # new_engine +# # NEW: preserves parent connection +# from airflow.settings import configure_orm +# configure_orm() +# finally: +# os._exit(0) +# +# os.waitpid(pid, 0) +# +# # Verify expected behavior +# if expect_failure: +# with pytest.raises(OperationalError): +# conn.execute("SELECT 1") +# else: +# result = conn.execute("SELECT 1").scalar() +# assert result == 1 +# current_id = conn.execute("SELECT CONNECTION_ID()").scalar() +# assert current_id == thread_id +# +# +# @pytest.mark.backend("mysql") +# class TestMySQLProcessListVerification: +# """ +# Verify connection persistence using MySQL's SHOW PROCESSLIST +# """ +# +# def test_connection_survives_in_processlist_with_new_approach(self): +# """ +# Verify parent connection persists in MySQL SHOW PROCESSLIST +# after child fork with new approach +# """ +# engine = settings.engine +# +# with create_session() as session: +# # Get connection ID +# thread_id = session.execute("SELECT CONNECTION_ID()").scalar() +# +# # Verify in processlist before fork +# result = session.execute( +# "SELECT COUNT(*) FROM information_schema.PROCESSLIST WHERE ID = :tid", +# {"tid": thread_id} +# ).scalar() +# assert result == 1, "Connection not found in PROCESSLIST before fork" +# +# # Fork with NEW approach +# pid = os.fork() +# if pid == 0: +# try: +# from airflow.settings import configure_orm +# configure_orm() +# finally: +# os._exit(0) +# +# os.waitpid(pid, 0) +# +# # Verify connection still in processlist after fork +# result = session.execute( +# "SELECT COUNT(*) FROM information_schema.PROCESSLIST WHERE ID = :tid", +# {"tid": thread_id} +# ).scalar() +# assert result == 1, "Connection disappeared from SHOW PROCESSLIST" +# +# def test_connection_disappears_from_processlist_with_old_approach(self): +# """ +# Demonstrate that old approach causes connection to disappear +# from SHOW PROCESSLIST +# """ +# engine = settings.engine +# +# with create_session() as session: +# thread_id = session.execute("SELECT CONNECTION_ID()").scalar() +# +# # Verify in processlist before fork +# result = session.execute( +# "SELECT COUNT(*) FROM information_schema.PROCESSLIST WHERE ID = :tid", +# {"tid": thread_id} +# ).scalar() +# assert result == 1 +# +# # Fork with OLD approach +# pid = os.fork() +# if pid == 0: +# try: +# engine.dispose(close=False) +# gc_module.collect() +# finally: +# os._exit(0) +# +# os.waitpid(pid, 0) +# +# # Connection should be gone (old bug) +# result = session.execute( +# "SELECT COUNT(*) FROM information_schema.PROCESSLIST WHERE ID = :tid", +# {"tid": thread_id} +# ).scalar() +# assert result == 0, "Connection should have disappeared (old bug)" +# +# +# @pytest.mark.backend("mysql") +# class TestFileDescriptorState: +# """ +# Test file descriptor state after fork +# """ +# +# def test_fd_remains_valid_with_new_approach(self): +# """ +# Verify parent's file descriptor remains valid +# """ +# import fcntl +# +# engine = settings.engine +# +# with engine.connect() as conn: +# # Get the underlying socket file descriptor +# fd = conn.connection.connection.fileno() +# +# # Verify fd is valid +# flags = fcntl.fcntl(fd, fcntl.F_GETFD) +# assert flags >= 0, "FD should be valid before fork" +# +# # Fork with NEW approach +# pid = os.fork() +# if pid == 0: +# try: +# from airflow.settings import configure_orm +# configure_orm() +# finally: +# os._exit(0) +# +# os.waitpid(pid, 0) +# +# # Verify fd is still valid in parent +# flags = fcntl.fcntl(fd, fcntl.F_GETFD) +# assert flags >= 0, "FD should still be valid after fork" +# +# # Verify can still use connection +# result = conn.execute("SELECT 1").scalar() +# assert result == 1 +# +# +# @pytest.mark.backend("postgres") +# class TestPostgreSQLBaseline: +# """ +# PostgreSQL doesn't have this issue (baseline comparison) +# """ +# +# def test_postgres_safe_with_old_approach(self): +# """ +# Demonstrate that PostgreSQL is safe even with old approach +# """ +# engine = settings.engine +# +# with engine.connect() as conn: +# # Query works before fork +# result = conn.execute("SELECT 1").scalar() +# assert result == 1 +# +# pid = os.fork() +# if pid == 0: +# try: +# # Even with old approach, PostgreSQL is safe +# engine.dispose(close=False) +# gc_module.collect() +# finally: +# os._exit(0) +# +# os.waitpid(pid, 0) +# +# # PostgreSQL connection should still work +# result = conn.execute("SELECT 1").scalar() +# assert result == 1 +# +# +# # Fixtures +# +# @pytest.fixture +# def clean_engine(): +# """ +# Ensure clean engine state for each test +# """ +# from airflow.settings import configure_orm +# +# # Setup +# configure_orm() +# engine = settings.engine +# +# yield engine +# +# # Teardown +# engine.dispose() +# +# +# @pytest.fixture +# def isolated_fork_test(): +# """ +# Fixture for isolated fork testing +# Returns helper functions for fork testing +# """ +# children = [] +# +# def fork_and_wait(child_func): +# """Fork, execute child_func, and wait""" +# pid = os.fork() +# if pid == 0: # Child +# try: +# child_func() +# finally: +# os._exit(0) +# else: # Parent +# children.append(pid) +# os.waitpid(pid, 0) +# +# yield fork_and_wait +# +# # Cleanup: wait for any remaining children +# for pid in children: +# try: +# os.waitpid(pid, os.WNOHANG) +# except: +# pass +# +# +# # Example usage with fixtures +# +# @pytest.mark.backend("mysql") +# def test_with_fixture(isolated_fork_test): +# """ +# Example test using the isolated_fork_test fixture +# """ +# engine = settings.engine +# +# with engine.connect() as conn: +# thread_id = conn.execute("SELECT CONNECTION_ID()").scalar() +# +# def child_work(): +# from airflow.settings import configure_orm +# configure_orm() +# +# # Use fixture helper +# isolated_fork_test(child_work) +# +# # Verify parent connection still works +# result = conn.execute("SELECT 1").scalar() +# assert result == 1 + + +# Markers for different test categories + +pytestmark = [ + pytest.mark.db_test, # Airflow의 DB 테스트 마커 +] From 24525c894f44e00fdcd8c5028892a29cab78e547 Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Mon, 20 Oct 2025 21:04:08 +0900 Subject: [PATCH 3/4] add test --- airflow-core/src/airflow/settings.py | 93 +++-- .../tests/unit/core/test_db_pool_safety.py | 388 ------------------ .../tests/unit/core/test_register_at_fork.py | 228 ++++++++++ 3 files changed, 276 insertions(+), 433 deletions(-) delete mode 100644 airflow-core/tests/unit/core/test_db_pool_safety.py create mode 100644 airflow-core/tests/unit/core/test_register_at_fork.py diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py index 605e2186f57be..e9b2729967923 100644 --- a/airflow-core/src/airflow/settings.py +++ b/airflow-core/src/airflow/settings.py @@ -370,70 +370,73 @@ def _configure_async_session() -> None: ) -def _configure_session(disable_connection_pool: bool, pool_class): - """(Re)create engine, NonScopedSession, Session using SQLAlchemy.""" - from airflow._shared.secrets_masker import mask_secret +def configure_orm(disable_connection_pool=False, pool_class=None): + """Configure ORM using SQLAlchemy.""" + if _is_sqlite_db_path_relative(SQL_ALCHEMY_CONN): + from airflow.exceptions import AirflowConfigException - global NonScopedSession, Session, engine + raise AirflowConfigException( + f"Cannot use relative path: `{SQL_ALCHEMY_CONN}` to connect to sqlite. " + "Please use absolute path such as `sqlite:////tmp/airflow.db`." + ) - log.debug("Setting up DB connection pool (PID %s)", os.getpid()) + def _configure_session(): + """(Re)create engine, NonScopedSession, Session using SQLAlchemy.""" + global NonScopedSession + global Session + global engine - if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true": - # Skip DB initialization in unit tests, if DB tests are skipped - Session = SkipDBTestsSession - engine = None - return + from airflow._shared.secrets_masker import mask_secret - engine_args = prepare_engine_args(disable_connection_pool, pool_class) + log.debug("Setting up DB connection pool (PID %s)", os.getpid()) - connect_args = _get_connect_args("sync") - if SQL_ALCHEMY_CONN.startswith("sqlite"): - connect_args["check_same_thread"] = False + if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true": + # Skip DB initialization in unit tests, if DB tests are skipped + Session = SkipDBTestsSession + engine = None + return - engine = create_engine( - SQL_ALCHEMY_CONN, - connect_args=connect_args, - **engine_args, - future=True, - ) + engine_args = prepare_engine_args(disable_connection_pool, pool_class) - mask_secret(engine.url.password) - setup_event_handlers(engine) + connect_args = _get_connect_args("sync") + if SQL_ALCHEMY_CONN.startswith("sqlite"): + connect_args["check_same_thread"] = False - if conf.has_option("database", "sql_alchemy_session_maker"): - _session_maker = conf.getimport("database", "sql_alchemy_session_maker") - else: - _session_maker = functools.partial( - sessionmaker, - autocommit=False, - autoflush=False, - expire_on_commit=False, + engine = create_engine( + SQL_ALCHEMY_CONN, + connect_args=connect_args, + **engine_args, + future=True, ) - NonScopedSession = _session_maker(engine) - Session = scoped_session(NonScopedSession) - + mask_secret(engine.url.password) + setup_event_handlers(engine) -def configure_orm(disable_connection_pool=False, pool_class=None): - """Configure ORM using SQLAlchemy.""" - print(SQL_ALCHEMY_CONN) - if _is_sqlite_db_path_relative(SQL_ALCHEMY_CONN): - from airflow.exceptions import AirflowConfigException + if conf.has_option("database", "sql_alchemy_session_maker"): + _session_maker = conf.getimport("database", "sql_alchemy_session_maker") + else: + _session_maker = functools.partial( + sessionmaker, + autocommit=False, + autoflush=False, + expire_on_commit=False, + ) - raise AirflowConfigException( - f"Cannot use relative path: `{SQL_ALCHEMY_CONN}` to connect to sqlite. " - "Please use absolute path such as `sqlite:////tmp/airflow.db`." - ) + NonScopedSession = _session_maker(engine) + Session = scoped_session(NonScopedSession) - _configure_session(disable_connection_pool, pool_class) + _configure_session() _configure_async_session() if register_at_fork := getattr(os, "register_at_fork", None): # https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork def clean_in_fork(): _globals = globals() - if _globals.get("engine"): - _configure_session(disable_connection_pool, pool_class) + if engine := _globals.get("engine"): + if engine.dialect.name == "mysql": + _configure_session() + else: + engine.dispose(close=False) if _globals.get("async_engine"): _configure_async_session() diff --git a/airflow-core/tests/unit/core/test_db_pool_safety.py b/airflow-core/tests/unit/core/test_db_pool_safety.py deleted file mode 100644 index 6c423705ce8ef..0000000000000 --- a/airflow-core/tests/unit/core/test_db_pool_safety.py +++ /dev/null @@ -1,388 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import os -import weakref -import gc as gc_module - -import pytest -from sqlalchemy.exc import OperationalError - -from airflow import settings -from airflow.utils.session import create_session - - -@pytest.mark.backend("mysql") -class TestLocalTaskJobForkSafety: - """ - Test fork safety for LocalTaskJobRunner with MySQL backend - """ - def test_old_dispose_causes_parent_connection_loss(self): - """ - BEFORE FIX: Demonstrates the problem - Using dispose(close=False) in child causes parent connection to die - """ - # WeakRef로 Pool GC 추적 - gc_callback_called = [] - - # Airflow의 실제 engine 사용 - engine = settings.engine - pool = engine.pool - weak_pool = weakref.ref(pool, lambda ref: gc_callback_called.append(True)) - - if register_at_fork := getattr(os, "register_at_fork", None): - # https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork - def clean_in_fork(): - print("engine disposed") - engine.dispose(close=False) - - # Won't work on Windows - register_at_fork(after_in_child=clean_in_fork) - - with engine.connect() as conn: - thread_id = conn.execute("SELECT CONNECTION_ID()").scalar() - - # Fork - pid = os.fork() - if pid == 0: # Child - try: - gc_module.collect() - finally: - os._exit(0) - - # Parent - os.waitpid(pid, 0) - - # Verify GC happened - assert len(gc_callback_called) > 0, "Pool was garbage collected in child" - - # Verify connection is dead - with pytest.raises( - OperationalError, - match="MySQL server has gone away|2006|2013" - ): - conn.execute("SELECT 1") - -# def test_new_engine_creation_preserves_parent_connection(self): -# """ -# AFTER FIX: Demonstrates the solution -# Creating new engine in child preserves parent connection -# """ -# gc_callback_called = [] -# -# engine = settings.engine -# pool = engine.pool -# weak_pool = weakref.ref(pool, lambda ref: gc_callback_called.append(True)) -# -# with engine.connect() as conn: -# thread_id = conn.execute("SELECT CONNECTION_ID()").scalar() -# -# pid = os.fork() -# if pid == 0: # Child -# try: -# # NEW approach - create new engine -# from airflow.settings import configure_orm -# configure_orm() -# # Parent engine/pool은 건드리지 않음! -# finally: -# os._exit(0) -# -# os.waitpid(pid, 0) -# -# # Verify NO GC happened to parent pool -# assert len(gc_callback_called) == 0, "Pool was NOT garbage collected" -# -# # Verify connection is alive -# result = conn.execute("SELECT 1").scalar() -# assert result == 1 -# -# # Verify same MySQL thread_id -# current_id = conn.execute("SELECT CONNECTION_ID()").scalar() -# assert current_id == thread_id -# -# @pytest.mark.parametrize("approach,expect_failure", [ -# ("old_dispose", True), -# ("new_engine", False), -# ]) -# def test_fork_approaches_comparison(self, approach, expect_failure): -# """ -# Parameterized test comparing old vs new approach -# Shows both behaviors in one test -# """ -# engine = settings.engine -# -# with engine.connect() as conn: -# thread_id = conn.execute("SELECT CONNECTION_ID()").scalar() -# -# pid = os.fork() -# if pid == 0: # Child -# try: -# if approach == "old_dispose": -# # OLD: causes parent connection to die -# engine.dispose(close=False) -# gc_module.collect() -# else: # new_engine -# # NEW: preserves parent connection -# from airflow.settings import configure_orm -# configure_orm() -# finally: -# os._exit(0) -# -# os.waitpid(pid, 0) -# -# # Verify expected behavior -# if expect_failure: -# with pytest.raises(OperationalError): -# conn.execute("SELECT 1") -# else: -# result = conn.execute("SELECT 1").scalar() -# assert result == 1 -# current_id = conn.execute("SELECT CONNECTION_ID()").scalar() -# assert current_id == thread_id -# -# -# @pytest.mark.backend("mysql") -# class TestMySQLProcessListVerification: -# """ -# Verify connection persistence using MySQL's SHOW PROCESSLIST -# """ -# -# def test_connection_survives_in_processlist_with_new_approach(self): -# """ -# Verify parent connection persists in MySQL SHOW PROCESSLIST -# after child fork with new approach -# """ -# engine = settings.engine -# -# with create_session() as session: -# # Get connection ID -# thread_id = session.execute("SELECT CONNECTION_ID()").scalar() -# -# # Verify in processlist before fork -# result = session.execute( -# "SELECT COUNT(*) FROM information_schema.PROCESSLIST WHERE ID = :tid", -# {"tid": thread_id} -# ).scalar() -# assert result == 1, "Connection not found in PROCESSLIST before fork" -# -# # Fork with NEW approach -# pid = os.fork() -# if pid == 0: -# try: -# from airflow.settings import configure_orm -# configure_orm() -# finally: -# os._exit(0) -# -# os.waitpid(pid, 0) -# -# # Verify connection still in processlist after fork -# result = session.execute( -# "SELECT COUNT(*) FROM information_schema.PROCESSLIST WHERE ID = :tid", -# {"tid": thread_id} -# ).scalar() -# assert result == 1, "Connection disappeared from SHOW PROCESSLIST" -# -# def test_connection_disappears_from_processlist_with_old_approach(self): -# """ -# Demonstrate that old approach causes connection to disappear -# from SHOW PROCESSLIST -# """ -# engine = settings.engine -# -# with create_session() as session: -# thread_id = session.execute("SELECT CONNECTION_ID()").scalar() -# -# # Verify in processlist before fork -# result = session.execute( -# "SELECT COUNT(*) FROM information_schema.PROCESSLIST WHERE ID = :tid", -# {"tid": thread_id} -# ).scalar() -# assert result == 1 -# -# # Fork with OLD approach -# pid = os.fork() -# if pid == 0: -# try: -# engine.dispose(close=False) -# gc_module.collect() -# finally: -# os._exit(0) -# -# os.waitpid(pid, 0) -# -# # Connection should be gone (old bug) -# result = session.execute( -# "SELECT COUNT(*) FROM information_schema.PROCESSLIST WHERE ID = :tid", -# {"tid": thread_id} -# ).scalar() -# assert result == 0, "Connection should have disappeared (old bug)" -# -# -# @pytest.mark.backend("mysql") -# class TestFileDescriptorState: -# """ -# Test file descriptor state after fork -# """ -# -# def test_fd_remains_valid_with_new_approach(self): -# """ -# Verify parent's file descriptor remains valid -# """ -# import fcntl -# -# engine = settings.engine -# -# with engine.connect() as conn: -# # Get the underlying socket file descriptor -# fd = conn.connection.connection.fileno() -# -# # Verify fd is valid -# flags = fcntl.fcntl(fd, fcntl.F_GETFD) -# assert flags >= 0, "FD should be valid before fork" -# -# # Fork with NEW approach -# pid = os.fork() -# if pid == 0: -# try: -# from airflow.settings import configure_orm -# configure_orm() -# finally: -# os._exit(0) -# -# os.waitpid(pid, 0) -# -# # Verify fd is still valid in parent -# flags = fcntl.fcntl(fd, fcntl.F_GETFD) -# assert flags >= 0, "FD should still be valid after fork" -# -# # Verify can still use connection -# result = conn.execute("SELECT 1").scalar() -# assert result == 1 -# -# -# @pytest.mark.backend("postgres") -# class TestPostgreSQLBaseline: -# """ -# PostgreSQL doesn't have this issue (baseline comparison) -# """ -# -# def test_postgres_safe_with_old_approach(self): -# """ -# Demonstrate that PostgreSQL is safe even with old approach -# """ -# engine = settings.engine -# -# with engine.connect() as conn: -# # Query works before fork -# result = conn.execute("SELECT 1").scalar() -# assert result == 1 -# -# pid = os.fork() -# if pid == 0: -# try: -# # Even with old approach, PostgreSQL is safe -# engine.dispose(close=False) -# gc_module.collect() -# finally: -# os._exit(0) -# -# os.waitpid(pid, 0) -# -# # PostgreSQL connection should still work -# result = conn.execute("SELECT 1").scalar() -# assert result == 1 -# -# -# # Fixtures -# -# @pytest.fixture -# def clean_engine(): -# """ -# Ensure clean engine state for each test -# """ -# from airflow.settings import configure_orm -# -# # Setup -# configure_orm() -# engine = settings.engine -# -# yield engine -# -# # Teardown -# engine.dispose() -# -# -# @pytest.fixture -# def isolated_fork_test(): -# """ -# Fixture for isolated fork testing -# Returns helper functions for fork testing -# """ -# children = [] -# -# def fork_and_wait(child_func): -# """Fork, execute child_func, and wait""" -# pid = os.fork() -# if pid == 0: # Child -# try: -# child_func() -# finally: -# os._exit(0) -# else: # Parent -# children.append(pid) -# os.waitpid(pid, 0) -# -# yield fork_and_wait -# -# # Cleanup: wait for any remaining children -# for pid in children: -# try: -# os.waitpid(pid, os.WNOHANG) -# except: -# pass -# -# -# # Example usage with fixtures -# -# @pytest.mark.backend("mysql") -# def test_with_fixture(isolated_fork_test): -# """ -# Example test using the isolated_fork_test fixture -# """ -# engine = settings.engine -# -# with engine.connect() as conn: -# thread_id = conn.execute("SELECT CONNECTION_ID()").scalar() -# -# def child_work(): -# from airflow.settings import configure_orm -# configure_orm() -# -# # Use fixture helper -# isolated_fork_test(child_work) -# -# # Verify parent connection still works -# result = conn.execute("SELECT 1").scalar() -# assert result == 1 - - -# Markers for different test categories - -pytestmark = [ - pytest.mark.db_test, # Airflow의 DB 테스트 마커 -] diff --git a/airflow-core/tests/unit/core/test_register_at_fork.py b/airflow-core/tests/unit/core/test_register_at_fork.py new file mode 100644 index 0000000000000..2387f64f2ab63 --- /dev/null +++ b/airflow-core/tests/unit/core/test_register_at_fork.py @@ -0,0 +1,228 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import gc as gc_module +import os +import weakref +from contextlib import contextmanager + +import pytest +from sqlalchemy import create_engine, event, text +from sqlalchemy.exc import OperationalError +from sqlalchemy.ext.asyncio import create_async_engine + +from airflow.settings import SQL_ALCHEMY_CONN, SQL_ALCHEMY_CONN_ASYNC + +pytestmark = [ + pytest.mark.db_test, +] + + +# Common helpers +def create_test_engine(): + """Create a test engine with standard configuration.""" + return create_engine( + SQL_ALCHEMY_CONN, + pool_size=5, + max_overflow=10, + pool_recycle=1800, + future=True, + ) + + +def create_test_async_engine(): + """Create a test async engine with standard configuration.""" + return create_async_engine( + SQL_ALCHEMY_CONN_ASYNC, + pool_size=5, + max_overflow=10, + pool_recycle=1800, + future=True, + ) + + +def get_connection_id_query(dialect_name): + return "SELECT CONNECTION_ID()" if dialect_name == "mysql" else "SELECT pg_backend_pid()" + + +def register_connection_finalizers(engine): + """Register weakref finalizers to track connection cleanup.""" + + @event.listens_for(engine, "connect") + def set_mysql_timezone(dbapi_connection, connection_record): + weakref.finalize(dbapi_connection, lambda: print(f"finalize dbapi_connection in {os.getpid()}")) + weakref.finalize(connection_record, lambda: print(f"finalize connection_record in {os.getpid()}")) + + +@contextmanager +def fork_process(): + """ + Context manager for forking a process which rigger garbage collection manually + """ + + pid = os.fork() + if pid == 0: + # Child process + try: + gc_module.collect() + finally: + os._exit(0) + + # Parent process + try: + yield pid + finally: + os.waitpid(pid, 0) + + +class TestLocalTaskJobForkSafety: + """ + Test fork safety for LocalTaskJobRunner with MySQL backend. + + These tests verify that database connections are properly handled + when forking processes, ensuring that parent process connections + remain valid after child process cleanup. + """ + + @pytest.mark.backend("mysql") + def test_dispose_breaks_parent_connection(self): + """ + Test that dispose(close=False) in child process breaks parent connection. + + This test demonstrates the bug: when a child process calls + engine.dispose(close=False), it invalidates the parent's connection + pool, causing the parent to lose its database (only MYSQL) connection. + + Expected result: Parent connection fails with OperationalError + """ + engine1 = create_test_engine() + register_connection_finalizers(engine1) + + if register_at_fork := getattr(os, "register_at_fork", None): + register_at_fork(after_in_child=lambda: engine1.dispose(close=False)) + + # Establish connection before fork + with engine1.connect() as conn: + conn.execute(text("SELECT CONNECTION_ID()")).scalar() + + with fork_process(): + pass + + # Verify parent connection is broken + with engine1.connect() as conn: + with pytest.raises(OperationalError, match="Lost connection to server during query"): + conn.execute(text("SELECT 1")) + + @pytest.mark.backend("mysql", "postgres") + @pytest.mark.asyncio + async def test_async_dispose_breaks_parent_connection(self): + """ + Test that sync_engine.dispose(close=False) breaks async parent connection. + + Similar to the sync version, this demonstrates that calling dispose() + on the underlying sync_engine in a child process breaks the parent's + async database connection. It affects both MySQL and postgres. + + Expected result: Parent connection hangs and times out + """ + async_engine1 = create_test_async_engine() + + if register_at_fork := getattr(os, "register_at_fork", None): + register_at_fork(after_in_child=lambda: async_engine1.sync_engine.dispose(close=False)) + + query = get_connection_id_query(async_engine1.sync_engine.dialect.name) + + async with async_engine1.connect() as conn: + conn_id = await conn.scalar(text(query)) + print(f"Connection ID: {conn_id}") + + with fork_process(): + pass + + async with async_engine1.connect() as conn: + with pytest.raises(asyncio.exceptions.TimeoutError): + await asyncio.wait_for(conn.execute(text(query)), timeout=5) + + @pytest.mark.backend("mysql", "postgres") + def test_parent_process_retains_same_connection_after_child_fork(self): + """ + Test the parent process maintains its original MySQL / postgres connection after forking a child process. + + This test verifies that: + 1. The parent process keeps the same connection ID before and after fork + 2. The engine object identity remains unchanged in the parent process + 3. Forking a child process (which triggers garbage collection) does not affect the parent's DB state + + This ensures that the fork cleanup mechanism (register_at_fork) only affects child processes + and does not inadvertently modify the parent process's database connections. + """ + from sqlalchemy import text + + from airflow.settings import engine + + query = get_connection_id_query(engine.dialect.name) + + with engine.connect() as conn: + before_cid = conn.execute(text(query)).scalar() + before_engine_id = id(engine) + + with fork_process(): + pass + + with engine.connect() as conn: + after_cid = conn.execute(text(query)).scalar() + after_engine_id = id(engine) + + assert before_cid == after_cid + assert before_engine_id == after_engine_id + + @pytest.mark.backend("mysql", "postgres") + @pytest.mark.asyncio + async def test_parent_process_retains_same_async_connection_after_child_fork(self): + """ + Test the parent process maintains its original MySQL / POSTGRES connection after forking a child process. + + This test verifies that: + 1. The parent process keeps the same connection ID before and after fork + 2. The engine object identity remains unchanged in the parent process + 3. Forking a child process (which triggers garbage collection) does not affect the parent's DB state + + This ensures that the fork cleanup mechanism (register_at_fork) only affects child processes + and does not inadvertently modify the parent process's database connections. + """ + from sqlalchemy import text + + from airflow.settings import async_engine + + query = get_connection_id_query(async_engine.sync_engine.dialect.name) + + async with async_engine.connect() as conn: + before_cid = await conn.scalar(text(query)) + before_engine_id = id(async_engine) + + with fork_process(): + pass + + async with async_engine.connect() as conn: + after_cid = await conn.scalar(text(query)) + after_engine_id = id(async_engine) + + assert before_cid == after_cid + assert before_engine_id == after_engine_id From 46c6a29276535925997cc5ae6d1da9d853085dff Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Tue, 21 Oct 2025 09:07:06 +0900 Subject: [PATCH 4/4] add test --- airflow-core/src/airflow/settings.py | 2 +- .../tests/unit/core/test_register_at_fork.py | 27 ++++++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py index e9b2729967923..3b4185ae0369b 100644 --- a/airflow-core/src/airflow/settings.py +++ b/airflow-core/src/airflow/settings.py @@ -433,7 +433,7 @@ def _configure_session(): def clean_in_fork(): _globals = globals() if engine := _globals.get("engine"): - if engine.dialect.name == "mysql": + if "mysql" in engine.dialect.name: _configure_session() else: engine.dispose(close=False) diff --git a/airflow-core/tests/unit/core/test_register_at_fork.py b/airflow-core/tests/unit/core/test_register_at_fork.py index 2387f64f2ab63..bc40e378b60a0 100644 --- a/airflow-core/tests/unit/core/test_register_at_fork.py +++ b/airflow-core/tests/unit/core/test_register_at_fork.py @@ -101,7 +101,7 @@ class TestLocalTaskJobForkSafety: remain valid after child process cleanup. """ - @pytest.mark.backend("mysql") + @pytest.mark.backend("mysql", "postgres") def test_dispose_breaks_parent_connection(self): """ Test that dispose(close=False) in child process breaks parent connection. @@ -110,25 +110,38 @@ def test_dispose_breaks_parent_connection(self): engine.dispose(close=False), it invalidates the parent's connection pool, causing the parent to lose its database (only MYSQL) connection. - Expected result: Parent connection fails with OperationalError + Expected result: + - Parent connection fails with OperationalError in MYSQL backend + - Don't modify the parent process's database connections in Postgres backend """ engine1 = create_test_engine() register_connection_finalizers(engine1) + query = get_connection_id_query(engine1.dialect.name) + if register_at_fork := getattr(os, "register_at_fork", None): register_at_fork(after_in_child=lambda: engine1.dispose(close=False)) # Establish connection before fork with engine1.connect() as conn: - conn.execute(text("SELECT CONNECTION_ID()")).scalar() + before_cid = conn.execute(text(query)).scalar() + before_engine_id = id(engine1) with fork_process(): pass - # Verify parent connection is broken - with engine1.connect() as conn: - with pytest.raises(OperationalError, match="Lost connection to server during query"): - conn.execute(text("SELECT 1")) + if engine1.dialect.name == "mysql": + # Verify parent connection is broken + with engine1.connect() as conn: + with pytest.raises(OperationalError, match="Lost connection to server during query"): + conn.execute(text("SELECT 1")) + else: + with engine1.connect() as conn: + after_cid = conn.execute(text(query)).scalar() + after_engine_id = id(engine1) + + assert before_cid == after_cid + assert before_engine_id == after_engine_id @pytest.mark.backend("mysql", "postgres") @pytest.mark.asyncio