From c117b61d7f99501916fae2bf23572834e08c3ae5 Mon Sep 17 00:00:00 2001 From: hseo36 Date: Sat, 28 Mar 2026 20:10:12 -0400 Subject: [PATCH 1/3] Fix broker_use_ssl not applied for amqps:// broker URLs --- .../celery/executors/celery_executor_utils.py | 6 ++ .../celery/executors/default_celery.py | 2 +- .../celery/executors/test_celery_executor.py | 98 +++++++++++++++++++ 3 files changed, 105 insertions(+), 1 deletion(-) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 6ac9ce1902974..722a0bfa33e88 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -138,6 +138,12 @@ def create_celery_app(team_conf: ExecutorConf | AirflowConfigParser) -> Celery: config = get_default_celery_config(team_conf) + # Apply celery_config_options override if explicitly configured + if conf.has_option("celery", "celery_config_options"): + user_config = conf.getimport("celery", "celery_config_options") + if isinstance(user_config, dict): + config.update(user_config) + celery_app = Celery(celery_app_name, config_source=config) # Register tasks with this app diff --git a/providers/celery/src/airflow/providers/celery/executors/default_celery.py b/providers/celery/src/airflow/providers/celery/executors/default_celery.py index 52ef77a15ac61..f0ef8185d1ecd 100644 --- a/providers/celery/src/airflow/providers/celery/executors/default_celery.py +++ b/providers/celery/src/airflow/providers/celery/executors/default_celery.py @@ -141,7 +141,7 @@ def get_default_celery_config(team_conf) -> dict[str, Any]: try: if celery_ssl_active: - if broker_url and "amqp://" in broker_url: + if broker_url and re.search(r"amqps?://", broker_url): broker_use_ssl = { "keyfile": team_conf.get("celery", "SSL_KEY"), "certfile": team_conf.get("celery", "SSL_CERT"), diff --git a/providers/celery/tests/unit/celery/executors/test_celery_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_executor.py index ff2c146f82874..6083bc5c49904 100644 --- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py +++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py @@ -814,3 +814,101 @@ def test_execute_workload_ignores_already_running_task(): """ with pytest.raises(Ignore): execute_workload_unwrapped(workload_json) + + +class TestAmqpsSslConfig: + """Tests for amqps:// broker URL SSL configuration (Fix for substring match bug).""" + + @conf_vars( + { + ("celery", "BROKER_URL"): "amqps://guest:guest@rabbitmq:5671//", + ("celery", "SSL_ACTIVE"): "True", + ("celery", "SSL_KEY"): "/path/to/key.pem", + ("celery", "SSL_CERT"): "/path/to/cert.pem", + ("celery", "SSL_CACERT"): "/path/to/ca.pem", + } + ) + def test_amqps_broker_url_builds_ssl_config(self): + """Test that amqps:// broker URLs correctly build broker_use_ssl with AMQP param names.""" + import importlib + import ssl + + importlib.reload(default_celery) + + config = default_celery.DEFAULT_CELERY_CONFIG + assert "broker_use_ssl" in config, "broker_use_ssl should be set for amqps:// URLs" + broker_ssl = config["broker_use_ssl"] + assert broker_ssl["keyfile"] == "/path/to/key.pem" + assert broker_ssl["certfile"] == "/path/to/cert.pem" + assert broker_ssl["ca_certs"] == "/path/to/ca.pem" + assert broker_ssl["cert_reqs"] == ssl.CERT_REQUIRED + # Must NOT have ssl_ prefixed keys (those are for Redis) + assert "ssl_keyfile" not in broker_ssl + assert "ssl_certfile" not in broker_ssl + + @conf_vars( + { + ("celery", "BROKER_URL"): "amqp://guest:guest@rabbitmq:5672//", + ("celery", "SSL_ACTIVE"): "True", + ("celery", "SSL_KEY"): "/path/to/key.pem", + ("celery", "SSL_CERT"): "/path/to/cert.pem", + ("celery", "SSL_CACERT"): "/path/to/ca.pem", + } + ) + def test_amqp_broker_url_still_builds_ssl_config(self): + """Test that amqp:// (non-TLS) broker URLs still build SSL config correctly (no regression).""" + import importlib + import ssl + + importlib.reload(default_celery) + + config = default_celery.DEFAULT_CELERY_CONFIG + assert "broker_use_ssl" in config + broker_ssl = config["broker_use_ssl"] + assert broker_ssl["keyfile"] == "/path/to/key.pem" + assert broker_ssl["cert_reqs"] == ssl.CERT_REQUIRED + + @conf_vars( + { + ("celery", "BROKER_URL"): "amqps://guest:guest@rabbitmq:5671//", + ("celery", "SSL_ACTIVE"): "False", + } + ) + def test_amqps_broker_url_no_ssl_when_inactive(self): + """Test that amqps:// broker URLs don't get SSL config when SSL_ACTIVE is False.""" + import importlib + + importlib.reload(default_celery) + + config = default_celery.DEFAULT_CELERY_CONFIG + assert "broker_use_ssl" not in config + + +class TestCeleryConfigOptionsOverride: + """Tests for celery_config_options being applied in create_celery_app().""" + + def test_celery_config_options_applied_in_create_celery_app(self): + """Test that celery_config_options overrides are merged into create_celery_app() config.""" + custom_config = {"worker_concurrency": 42, "broker_url": "redis://custom:6379/0"} + + original_has_option = conf.has_option + + def mock_has_option(section, key, **kwargs): + if section == "celery" and key == "celery_config_options": + return True + return original_has_option(section, key, **kwargs) + + with ( + mock.patch.object(conf, "has_option", side_effect=mock_has_option), + mock.patch.object(conf, "getimport", return_value=custom_config), + ): + celery_app = celery_executor_utils.create_celery_app(conf) + # The custom config should override defaults + assert celery_app.conf.worker_concurrency == 42 + assert celery_app.conf.broker_url == "redis://custom:6379/0" + + def test_create_celery_app_works_without_celery_config_options(self): + """Test that create_celery_app() works when celery_config_options is not set.""" + # Should not raise — uses defaults from get_default_celery_config() + celery_app = celery_executor_utils.create_celery_app(conf) + assert celery_app is not None From b89f67111255dbf04a44951334423e8264f0db0c Mon Sep 17 00:00:00 2001 From: hseo36 Date: Thu, 2 Apr 2026 12:24:20 -0400 Subject: [PATCH 2/3] addressed copilot comments Co-Authored-By: Claude Opus 4.6 (1M context) --- .../celery/executors/celery_executor_utils.py | 17 ++-- .../celery/executors/test_celery_executor.py | 83 +++++++++++++++---- 2 files changed, 78 insertions(+), 22 deletions(-) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 722a0bfa33e88..ebf9859f91d04 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -31,6 +31,7 @@ import sys import traceback from collections.abc import Collection, Mapping, MutableMapping, Sequence +from importlib import import_module from concurrent.futures import ProcessPoolExecutor from functools import cache from typing import TYPE_CHECKING, Any @@ -124,7 +125,10 @@ def create_celery_app(team_conf: ExecutorConf | AirflowConfigParser) -> Celery: :param team_conf: ExecutorConf instance with team-specific configuration, or global conf :return: Celery app instance """ - from airflow.providers.celery.executors.default_celery import get_default_celery_config + from airflow.providers.celery.executors.default_celery import ( + DEFAULT_CELERY_CONFIG, + get_default_celery_config, + ) celery_app_name = team_conf.get("celery", "CELERY_APP_NAME") @@ -138,10 +142,13 @@ def create_celery_app(team_conf: ExecutorConf | AirflowConfigParser) -> Celery: config = get_default_celery_config(team_conf) - # Apply celery_config_options override if explicitly configured - if conf.has_option("celery", "celery_config_options"): - user_config = conf.getimport("celery", "celery_config_options") - if isinstance(user_config, dict): + # Apply user-provided celery_config_options on top of team config. + # Skip if it resolves to DEFAULT_CELERY_CONFIG (built from global conf, not team-aware). + configured_path = team_conf.get("celery", "celery_config_options", fallback=None) + if configured_path: + module_path, _, attr_name = configured_path.rpartition(".") + user_config = getattr(import_module(module_path), attr_name) + if user_config is not DEFAULT_CELERY_CONFIG and isinstance(user_config, dict): config.update(user_config) celery_app = Celery(celery_app_name, config_source=config) diff --git a/providers/celery/tests/unit/celery/executors/test_celery_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_executor.py index 6083bc5c49904..bc668bef50861 100644 --- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py +++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py @@ -54,6 +54,8 @@ if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion +if AIRFLOW_V_3_2_PLUS: + from airflow.executors.base_executor import ExecutorConf if AIRFLOW_V_3_1_PLUS: from airflow.sdk import BaseOperator, timezone else: @@ -884,31 +886,78 @@ def test_amqps_broker_url_no_ssl_when_inactive(self): assert "broker_use_ssl" not in config -class TestCeleryConfigOptionsOverride: - """Tests for celery_config_options being applied in create_celery_app().""" +class TestCreateCeleryAppTeamIsolation: + """Tests for create_celery_app() multi-team config isolation.""" - def test_celery_config_options_applied_in_create_celery_app(self): - """Test that celery_config_options overrides are merged into create_celery_app() config.""" + @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecutorConf requires Airflow 3.2+") + def test_custom_celery_config_options_applied(self): + """User-provided celery_config_options (non-default) should be merged into team config.""" custom_config = {"worker_concurrency": 42, "broker_url": "redis://custom:6379/0"} + custom_path = "my_custom_module.CELERY_CONFIG" - original_has_option = conf.has_option + team_conf = ExecutorConf(team_name="team_alpha") + original_get = team_conf.get - def mock_has_option(section, key, **kwargs): + def mock_get(section, key, **kwargs): if section == "celery" and key == "celery_config_options": - return True - return original_has_option(section, key, **kwargs) + return custom_path + return original_get(section, key, **kwargs) + + mock_module = mock.MagicMock() + mock_module.CELERY_CONFIG = custom_config with ( - mock.patch.object(conf, "has_option", side_effect=mock_has_option), - mock.patch.object(conf, "getimport", return_value=custom_config), + mock.patch.object(team_conf, "get", side_effect=mock_get), + mock.patch.object(celery_executor_utils, "import_module", return_value=mock_module), ): - celery_app = celery_executor_utils.create_celery_app(conf) - # The custom config should override defaults + celery_app = celery_executor_utils.create_celery_app(team_conf) assert celery_app.conf.worker_concurrency == 42 assert celery_app.conf.broker_url == "redis://custom:6379/0" - def test_create_celery_app_works_without_celery_config_options(self): - """Test that create_celery_app() works when celery_config_options is not set.""" - # Should not raise — uses defaults from get_default_celery_config() - celery_app = celery_executor_utils.create_celery_app(conf) - assert celery_app is not None + def test_default_celery_config_options_skipped_via_identity_check(self): + """When celery_config_options resolves to DEFAULT_CELERY_CONFIG (same object), + it must be skipped — re-applying it would overwrite team-specific config + since DEFAULT_CELERY_CONFIG is built from global conf.""" + original_get = conf.get + # Path just needs a dot for rpartition and attr name matching DEFAULT_CELERY_CONFIG. + # import_module is mocked to return default_celery module regardless of path. + celery_config_path = "any.module.DEFAULT_CELERY_CONFIG" + + def mock_get(section, key, **kwargs): + if section == "celery" and key == "celery_config_options": + return celery_config_path + return original_get(section, key, **kwargs) + + with ( + mock.patch.object(conf, "get", side_effect=mock_get), + mock.patch.object(celery_executor_utils, "import_module") as mock_import, + ): + mock_import.return_value = default_celery + celery_app = celery_executor_utils.create_celery_app(conf) + # import_module called (path is non-None), but override skipped (same object) + mock_import.assert_called_once() + default_config = default_celery.get_default_celery_config(conf) + assert celery_app.conf.broker_url == default_config["broker_url"] + + @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecutorConf requires Airflow 3.2+") + def test_team_specific_broker_not_overwritten(self): + """Team-specific BROKER_URL set via ExecutorConf must survive create_celery_app().""" + team_conf = ExecutorConf(team_name="team_alpha") + + original_get = team_conf.get + + def mock_team_get(section, key, **kwargs): + if section == "celery" and key == "BROKER_URL": + return "amqps://team-alpha-rabbit:5671//" + return original_get(section, key, **kwargs) + + with mock.patch.object(team_conf, "get", side_effect=mock_team_get): + celery_app = celery_executor_utils.create_celery_app(team_conf) + assert celery_app.conf.broker_url == "amqps://team-alpha-rabbit:5671//" + + @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecutorConf requires Airflow 3.2+") + def test_team_app_name_includes_team_name(self): + """Each team gets a unique Celery app name for broker isolation.""" + team_conf = ExecutorConf(team_name="team_beta") + celery_app = celery_executor_utils.create_celery_app(team_conf) + assert "team_beta" in celery_app.main From 1d117bb3dd7e891a8327455a0897f04156d29744 Mon Sep 17 00:00:00 2001 From: hseo36 Date: Fri, 3 Apr 2026 07:41:49 -0400 Subject: [PATCH 3/3] fix import sort order --- .../airflow/providers/celery/executors/celery_executor_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index ebf9859f91d04..699052b470ed7 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -31,9 +31,9 @@ import sys import traceback from collections.abc import Collection, Mapping, MutableMapping, Sequence -from importlib import import_module from concurrent.futures import ProcessPoolExecutor from functools import cache +from importlib import import_module from typing import TYPE_CHECKING, Any from celery import Celery, states as celery_states