Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from collections.abc import Collection, Mapping, MutableMapping, Sequence
from concurrent.futures import ProcessPoolExecutor
from functools import cache
from importlib import import_module
from typing import TYPE_CHECKING, Any

from celery import Celery, states as celery_states
Expand Down Expand Up @@ -124,7 +125,10 @@ def create_celery_app(team_conf: ExecutorConf | AirflowConfigParser) -> Celery:
:param team_conf: ExecutorConf instance with team-specific configuration, or global conf
:return: Celery app instance
"""
from airflow.providers.celery.executors.default_celery import get_default_celery_config
from airflow.providers.celery.executors.default_celery import (
DEFAULT_CELERY_CONFIG,
get_default_celery_config,
)

celery_app_name = team_conf.get("celery", "CELERY_APP_NAME")

Expand All @@ -138,6 +142,15 @@ def create_celery_app(team_conf: ExecutorConf | AirflowConfigParser) -> Celery:

config = get_default_celery_config(team_conf)

# Apply user-provided celery_config_options on top of team config.
# Skip if it resolves to DEFAULT_CELERY_CONFIG (built from global conf, not team-aware).
configured_path = team_conf.get("celery", "celery_config_options", fallback=None)
if configured_path:
module_path, _, attr_name = configured_path.rpartition(".")
user_config = getattr(import_module(module_path), attr_name)
if user_config is not DEFAULT_CELERY_CONFIG and isinstance(user_config, dict):
config.update(user_config)

celery_app = Celery(celery_app_name, config_source=config)

# Register tasks with this app
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def get_default_celery_config(team_conf) -> dict[str, Any]:

try:
if celery_ssl_active:
if broker_url and "amqp://" in broker_url:
if broker_url and re.search(r"amqps?://", broker_url):
broker_use_ssl = {
"keyfile": team_conf.get("celery", "SSL_KEY"),
"certfile": team_conf.get("celery", "SSL_CERT"),
Expand Down
147 changes: 147 additions & 0 deletions providers/celery/tests/unit/celery/executors/test_celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@

if AIRFLOW_V_3_0_PLUS:
from airflow.models.dag_version import DagVersion
if AIRFLOW_V_3_2_PLUS:
from airflow.executors.base_executor import ExecutorConf
if AIRFLOW_V_3_1_PLUS:
from airflow.sdk import BaseOperator, timezone
else:
Expand Down Expand Up @@ -814,3 +816,148 @@ def test_execute_workload_ignores_already_running_task():
"""
with pytest.raises(Ignore):
execute_workload_unwrapped(workload_json)


class TestAmqpsSslConfig:
"""Tests for amqps:// broker URL SSL configuration (Fix for substring match bug)."""

@conf_vars(
{
("celery", "BROKER_URL"): "amqps://guest:guest@rabbitmq:5671//",
("celery", "SSL_ACTIVE"): "True",
("celery", "SSL_KEY"): "/path/to/key.pem",
("celery", "SSL_CERT"): "/path/to/cert.pem",
("celery", "SSL_CACERT"): "/path/to/ca.pem",
}
)
def test_amqps_broker_url_builds_ssl_config(self):
"""Test that amqps:// broker URLs correctly build broker_use_ssl with AMQP param names."""
import importlib
import ssl

importlib.reload(default_celery)

config = default_celery.DEFAULT_CELERY_CONFIG
assert "broker_use_ssl" in config, "broker_use_ssl should be set for amqps:// URLs"
broker_ssl = config["broker_use_ssl"]
assert broker_ssl["keyfile"] == "/path/to/key.pem"
assert broker_ssl["certfile"] == "/path/to/cert.pem"
assert broker_ssl["ca_certs"] == "/path/to/ca.pem"
assert broker_ssl["cert_reqs"] == ssl.CERT_REQUIRED
# Must NOT have ssl_ prefixed keys (those are for Redis)
assert "ssl_keyfile" not in broker_ssl
assert "ssl_certfile" not in broker_ssl

@conf_vars(
{
("celery", "BROKER_URL"): "amqp://guest:guest@rabbitmq:5672//",
("celery", "SSL_ACTIVE"): "True",
("celery", "SSL_KEY"): "/path/to/key.pem",
("celery", "SSL_CERT"): "/path/to/cert.pem",
("celery", "SSL_CACERT"): "/path/to/ca.pem",
}
)
def test_amqp_broker_url_still_builds_ssl_config(self):
"""Test that amqp:// (non-TLS) broker URLs still build SSL config correctly (no regression)."""
import importlib
import ssl

importlib.reload(default_celery)

config = default_celery.DEFAULT_CELERY_CONFIG
assert "broker_use_ssl" in config
broker_ssl = config["broker_use_ssl"]
assert broker_ssl["keyfile"] == "/path/to/key.pem"
assert broker_ssl["cert_reqs"] == ssl.CERT_REQUIRED

@conf_vars(
{
("celery", "BROKER_URL"): "amqps://guest:guest@rabbitmq:5671//",
("celery", "SSL_ACTIVE"): "False",
}
)
def test_amqps_broker_url_no_ssl_when_inactive(self):
"""Test that amqps:// broker URLs don't get SSL config when SSL_ACTIVE is False."""
import importlib

importlib.reload(default_celery)

config = default_celery.DEFAULT_CELERY_CONFIG
assert "broker_use_ssl" not in config


class TestCreateCeleryAppTeamIsolation:
"""Tests for create_celery_app() multi-team config isolation."""

@pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecutorConf requires Airflow 3.2+")
def test_custom_celery_config_options_applied(self):
"""User-provided celery_config_options (non-default) should be merged into team config."""
custom_config = {"worker_concurrency": 42, "broker_url": "redis://custom:6379/0"}
custom_path = "my_custom_module.CELERY_CONFIG"

team_conf = ExecutorConf(team_name="team_alpha")
original_get = team_conf.get

def mock_get(section, key, **kwargs):
if section == "celery" and key == "celery_config_options":
return custom_path
return original_get(section, key, **kwargs)

mock_module = mock.MagicMock()
mock_module.CELERY_CONFIG = custom_config

with (
mock.patch.object(team_conf, "get", side_effect=mock_get),
mock.patch.object(celery_executor_utils, "import_module", return_value=mock_module),
):
celery_app = celery_executor_utils.create_celery_app(team_conf)
assert celery_app.conf.worker_concurrency == 42
assert celery_app.conf.broker_url == "redis://custom:6379/0"

Comment thread
dandanseo123 marked this conversation as resolved.
def test_default_celery_config_options_skipped_via_identity_check(self):
"""When celery_config_options resolves to DEFAULT_CELERY_CONFIG (same object),
it must be skipped — re-applying it would overwrite team-specific config
since DEFAULT_CELERY_CONFIG is built from global conf."""
original_get = conf.get
# Path just needs a dot for rpartition and attr name matching DEFAULT_CELERY_CONFIG.
# import_module is mocked to return default_celery module regardless of path.
celery_config_path = "any.module.DEFAULT_CELERY_CONFIG"

def mock_get(section, key, **kwargs):
if section == "celery" and key == "celery_config_options":
return celery_config_path
return original_get(section, key, **kwargs)

with (
mock.patch.object(conf, "get", side_effect=mock_get),
mock.patch.object(celery_executor_utils, "import_module") as mock_import,
):
mock_import.return_value = default_celery
celery_app = celery_executor_utils.create_celery_app(conf)
# import_module called (path is non-None), but override skipped (same object)
mock_import.assert_called_once()
default_config = default_celery.get_default_celery_config(conf)
assert celery_app.conf.broker_url == default_config["broker_url"]

@pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecutorConf requires Airflow 3.2+")
def test_team_specific_broker_not_overwritten(self):
"""Team-specific BROKER_URL set via ExecutorConf must survive create_celery_app()."""
team_conf = ExecutorConf(team_name="team_alpha")

original_get = team_conf.get

def mock_team_get(section, key, **kwargs):
if section == "celery" and key == "BROKER_URL":
return "amqps://team-alpha-rabbit:5671//"
return original_get(section, key, **kwargs)

with mock.patch.object(team_conf, "get", side_effect=mock_team_get):
celery_app = celery_executor_utils.create_celery_app(team_conf)
assert celery_app.conf.broker_url == "amqps://team-alpha-rabbit:5671//"

@pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecutorConf requires Airflow 3.2+")
def test_team_app_name_includes_team_name(self):
"""Each team gets a unique Celery app name for broker isolation."""
team_conf = ExecutorConf(team_name="team_beta")
celery_app = celery_executor_utils.create_celery_app(team_conf)
assert "team_beta" in celery_app.main
Loading