diff --git a/airflow/cli/commands/fastapi_api_command.py b/airflow/cli/commands/fastapi_api_command.py index d50d454347a73..c11da959ce910 100644 --- a/airflow/cli/commands/fastapi_api_command.py +++ b/airflow/cli/commands/fastapi_api_command.py @@ -36,6 +36,7 @@ from airflow import settings from airflow.cli.commands.daemon_utils import run_command_with_daemon_option from airflow.cli.commands.webserver_command import GunicornMonitor +from airflow.exceptions import AirflowConfigException from airflow.utils import cli as cli_utils from airflow.utils.cli import setup_locations from airflow.utils.providers_configuration_loader import providers_configuration_loaded @@ -124,6 +125,10 @@ def fastapi_api(args): "python:airflow.api_fastapi.gunicorn_config", ] + ssl_cert, ssl_key = _get_ssl_cert_and_key_filepaths(args) + if ssl_cert and ssl_key: + run_args += ["--certfile", ssl_cert, "--keyfile", ssl_key] + if args.access_logformat and args.access_logformat.strip(): run_args += ["--access-logformat", str(args.access_logformat)] @@ -199,3 +204,23 @@ def start_and_monitor_gunicorn(args): should_setup_logging=True, pid_file=monitor_pid_file, ) + + +def _get_ssl_cert_and_key_filepaths(cli_arguments) -> tuple[str | None, str | None]: + error_template_1 = "Need both, have provided {} but not {}" + error_template_2 = "SSL related file does not exist {}" + + ssl_cert, ssl_key = cli_arguments.ssl_cert, cli_arguments.ssl_key + if ssl_cert and ssl_key: + if not os.path.isfile(ssl_cert): + raise AirflowConfigException(error_template_2.format(ssl_cert)) + if not os.path.isfile(ssl_key): + raise AirflowConfigException(error_template_2.format(ssl_key)) + + return (ssl_cert, ssl_key) + elif ssl_cert: + raise AirflowConfigException(error_template_1.format("SSL certificate", "SSL key")) + elif ssl_key: + raise AirflowConfigException(error_template_1.format("SSL key", "SSL certificate")) + + return (None, None) diff --git a/tests/cli/commands/test_fastapi_api_command.py b/tests/cli/commands/test_fastapi_api_command.py index 529c67f5ed821..1c1af3342faf6 100644 --- a/tests/cli/commands/test_fastapi_api_command.py +++ b/tests/cli/commands/test_fastapi_api_command.py @@ -27,13 +27,14 @@ from rich.console import Console from airflow.cli.commands import fastapi_api_command +from airflow.exceptions import AirflowConfigException from tests.cli.commands._common_cli_classes import _CommonCLIGunicornTestClass console = Console(width=400, color_system="standard") @pytest.mark.db_test -class TestCliInternalAPI(_CommonCLIGunicornTestClass): +class TestCliFastAPI(_CommonCLIGunicornTestClass): main_process_regexp = r"airflow fastapi-api" @pytest.mark.execution_timeout(210) @@ -46,7 +47,7 @@ def test_cli_fastapi_api_background(self, tmp_path): stderr = parent_path / "airflow-fastapi-api.err" logfile = parent_path / "airflow-fastapi-api.log" try: - # Run internal-api as daemon in background. Note that the wait method is not called. + # Run fastapi-api as daemon in background. Note that the wait method is not called. console.print("[magenta]Starting airflow fastapi-api --daemon") env = os.environ.copy() proc = subprocess.Popen( @@ -123,7 +124,9 @@ def test_cli_fastapi_api_debug(self, app): close_fds=True, ) - def test_cli_fastapi_api_args(self): + def test_cli_fastapi_api_args(self, ssl_cert_and_key): + cert_path, key_path = ssl_cert_and_key + with mock.patch("subprocess.Popen") as Popen, mock.patch.object( fastapi_api_command, "GunicornMonitor" ): @@ -134,6 +137,10 @@ def test_cli_fastapi_api_args(self): "custom_log_format", "--pid", "/tmp/x.pid", + "--ssl-cert", + str(cert_path), + "--ssl-key", + str(key_path), ] ) fastapi_api_command.fastapi_api(args) @@ -161,6 +168,10 @@ def test_cli_fastapi_api_args(self): "-", "--config", "python:airflow.api_fastapi.gunicorn_config", + "--certfile", + str(cert_path), + "--keyfile", + str(key_path), "--access-logformat", "custom_log_format", "airflow.api_fastapi.app:cached_app()", @@ -168,3 +179,31 @@ def test_cli_fastapi_api_args(self): ], close_fds=True, ) + + @pytest.mark.parametrize( + "ssl_arguments, error_pattern", + [ + (["--ssl-cert", "_.crt", "--ssl-key", "_.key"], "does not exist _.crt"), + (["--ssl-cert", "_.crt"], "Need both.*certificate.*key"), + (["--ssl-key", "_.key"], "Need both.*key.*certificate"), + ], + ) + def test_get_ssl_cert_and_key_filepaths_with_incorrect_usage(self, ssl_arguments, error_pattern): + args = self.parser.parse_args(["fastapi-api"] + ssl_arguments) + with pytest.raises(AirflowConfigException, match=error_pattern): + fastapi_api_command._get_ssl_cert_and_key_filepaths(args) + + def test_get_ssl_cert_and_key_filepaths_with_correct_usage(self, ssl_cert_and_key): + cert_path, key_path = ssl_cert_and_key + + args = self.parser.parse_args( + ["fastapi-api"] + ["--ssl-cert", str(cert_path), "--ssl-key", str(key_path)] + ) + assert fastapi_api_command._get_ssl_cert_and_key_filepaths(args) == (str(cert_path), str(key_path)) + + @pytest.fixture + def ssl_cert_and_key(self, tmp_path): + cert_path, key_path = tmp_path / "_.crt", tmp_path / "_.key" + cert_path.touch() + key_path.touch() + return cert_path, key_path