diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 4c4a0bed39b93..fcac1925d90fc 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -23,13 +23,24 @@ from typing import TYPE_CHECKING, Any, Callable from uuid import uuid4 -from flask import Response - +from flask import Response, request +from itsdangerous import BadSignature +from jwt import ( + ExpiredSignatureError, + ImmatureSignatureError, + InvalidAudienceError, + InvalidIssuedAtError, + InvalidSignatureError, +) + +from airflow.api_connexion.exceptions import PermissionDenied +from airflow.configuration import conf from airflow.jobs.job import Job, most_recent_job from airflow.models.taskinstance import _record_task_map_for_downstreams from airflow.models.xcom_arg import _get_task_map_length from airflow.sensors.base import _orig_start_date from airflow.serialization.serialized_objects import BaseSerialization +from airflow.utils.jwt_signer import JWTSigner from airflow.utils.session import create_session if TYPE_CHECKING: @@ -142,6 +153,38 @@ def log_and_build_error_response(message, status): def internal_airflow_api(body: dict[str, Any]) -> APIResponse: """Handle Internal API /internal_api/v1/rpcapi endpoint.""" + auth = request.headers.get("Authorization", "") + signer = JWTSigner( + secret_key=conf.get("core", "internal_api_secret_key"), + expiration_time_in_seconds=conf.getint("core", "internal_api_clock_grace", fallback=30), + audience="api", + ) + try: + payload = signer.verify_token(auth) + signed_method = payload.get("method") + if not signed_method or signed_method != body.get("method"): + raise BadSignature("Invalid method in token authorization.") + except BadSignature: + raise PermissionDenied("Bad Signature. Please use only the tokens provided by the API.") + except InvalidAudienceError: + raise PermissionDenied("Invalid audience for the request", exc_info=True) + except InvalidSignatureError: + raise PermissionDenied("The signature of the request was wrong", exc_info=True) + except ImmatureSignatureError: + raise PermissionDenied("The signature of the request was sent from the future", exc_info=True) + except ExpiredSignatureError: + raise PermissionDenied( + "The signature of the request has expired. Make sure that all components " + "in your system have synchronized clocks.", + ) + except InvalidIssuedAtError: + raise PermissionDenied( + "The request was issues in the future. Make sure that all components " + "in your system have synchronized clocks.", + ) + except Exception: + raise PermissionDenied("Unable to authenticate API via token.") + log.debug("Got request") json_rpc = body.get("jsonrpc") if json_rpc != "2.0": diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py index 2da451c15537e..07bd0ec5fedd9 100644 --- a/airflow/api_internal/internal_api_call.py +++ b/airflow/api_internal/internal_api_call.py @@ -32,6 +32,7 @@ from airflow.exceptions import AirflowConfigException, AirflowException from airflow.settings import _ENABLE_AIP_44 from airflow.typing_compat import ParamSpec +from airflow.utils.jwt_signer import JWTSigner PS = ParamSpec("PS") RT = TypeVar("RT") @@ -117,9 +118,6 @@ def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]: See [AIP-44](https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-44+Airflow+Internal+API) for more information . """ - headers = { - "Content-Type": "application/json", - } from requests.exceptions import ConnectionError @tenacity.retry( @@ -129,6 +127,15 @@ def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]: before_sleep=tenacity.before_log(logger, logging.WARNING), ) def make_jsonrpc_request(method_name: str, params_json: str) -> bytes: + signer = JWTSigner( + secret_key=conf.get("core", "internal_api_secret_key"), + expiration_time_in_seconds=conf.getint("core", "internal_api_clock_grace", fallback=30), + audience="api", + ) + headers = { + "Content-Type": "application/json", + "Authorization": signer.generate_signed_token({"method": method_name}), + } data = {"jsonrpc": "2.0", "method": method_name, "params": params_json} internal_api_endpoint = InternalApiConfig.get_internal_api_endpoint() response = requests.post(url=internal_api_endpoint, data=json.dumps(data), headers=headers) diff --git a/airflow/api_internal/openapi/internal_api_v1.yaml b/airflow/api_internal/openapi/internal_api_v1.yaml index 3edacfbc23471..15995a954a683 100644 --- a/airflow/api_internal/openapi/internal_api_v1.yaml +++ b/airflow/api_internal/openapi/internal_api_v1.yaml @@ -36,7 +36,6 @@ servers: paths: "/rpcapi": post: - operationId: rpcapi deprecated: false x-openapi-router-controller: airflow.api_internal.endpoints.rpc_api_endpoint operationId: internal_airflow_api diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 1cd39166ba084..7087c42011116 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -513,6 +513,19 @@ core: type: string default: ~ example: 'http://localhost:8080' + internal_api_secret_key: + description: | + Secret key used to authenticate internal API clients to core. It should be as random as possible. + However, when running more than 1 instances of webserver / internal API services, make sure all + of them use the same ``secret_key`` otherwise calls will fail on authentication. + The authentication token generated using the secret key has a short expiry time though - make + sure that time on ALL the machines that you run airflow components on is synchronized + (for example using ntpd) otherwise you might get "forbidden" errors when the logs are accessed. + version_added: 2.10.0 + type: string + sensitive: true + example: ~ + default: "{SECRET_KEY}" test_connection: description: | The ability to allow testing connections across Airflow UI, API and CLI. diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py index 64ea733d39c58..a453c8f3674fe 100644 --- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py +++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py @@ -22,6 +22,8 @@ import pytest +from airflow.api_connexion.exceptions import PermissionDenied +from airflow.configuration import conf from airflow.models.baseoperator import BaseOperator from airflow.models.connection import Connection from airflow.models.taskinstance import TaskInstance @@ -29,6 +31,7 @@ from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from airflow.serialization.serialized_objects import BaseSerialization from airflow.settings import _ENABLE_AIP_44 +from airflow.utils.jwt_signer import JWTSigner from airflow.utils.state import State from airflow.www import app from tests.test_utils.config import conf_vars @@ -82,6 +85,14 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator: } yield mock_initialize_method_map + @pytest.fixture + def signer(self) -> JWTSigner: + return JWTSigner( + secret_key=conf.get("core", "internal_api_secret_key"), + expiration_time_in_seconds=conf.getint("core", "internal_api_clock_grace", fallback=30), + audience="api", + ) + @pytest.mark.parametrize( "input_params, method_result, result_cmp_func, method_params", [ @@ -108,9 +119,12 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator: ), ], ) - def test_method(self, input_params, method_result, result_cmp_func, method_params): + def test_method(self, input_params, method_result, result_cmp_func, method_params, signer: JWTSigner): mock_test_method.return_value = method_result - + headers = { + "Content-Type": "application/json", + "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}), + } input_data = { "jsonrpc": "2.0", "method": TEST_METHOD_NAME, @@ -118,7 +132,7 @@ def test_method(self, input_params, method_result, result_cmp_func, method_param } response = self.client.post( "/internal_api/v1/rpcapi", - headers={"Content-Type": "application/json"}, + headers=headers, data=json.dumps(input_data), ) assert response.status_code == 200 @@ -131,33 +145,67 @@ def test_method(self, input_params, method_result, result_cmp_func, method_param mock_test_method.assert_called_once_with(**method_params, session=mock.ANY) - def test_method_with_exception(self): + def test_method_with_exception(self, signer: JWTSigner): + headers = { + "Content-Type": "application/json", + "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}), + } mock_test_method.side_effect = ValueError("Error!!!") data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": {}} - response = self.client.post( - "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) - ) + response = self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data)) assert response.status_code == 500 assert response.data, b"Error executing method: test_method." mock_test_method.assert_called_once() - def test_unknown_method(self): - data = {"jsonrpc": "2.0", "method": "i-bet-it-does-not-exist", "params": {}} + def test_unknown_method(self, signer: JWTSigner): + UNKNOWN_METHOD = "i-bet-it-does-not-exist" + headers = { + "Content-Type": "application/json", + "Authorization": signer.generate_signed_token({"method": UNKNOWN_METHOD}), + } + data = {"jsonrpc": "2.0", "method": UNKNOWN_METHOD, "params": {}} - response = self.client.post( - "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) - ) + response = self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data)) assert response.status_code == 400 assert response.data.startswith(b"Unrecognized method: i-bet-it-does-not-exist.") mock_test_method.assert_not_called() - def test_invalid_jsonrpc(self): + def test_invalid_jsonrpc(self, signer: JWTSigner): + headers = { + "Content-Type": "application/json", + "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}), + } data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} - response = self.client.post( - "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) - ) + response = self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data)) assert response.status_code == 400 assert response.data.startswith(b"Expected jsonrpc 2.0 request.") mock_test_method.assert_not_called() + + def test_missing_token(self): + mock_test_method.return_value = None + + input_data = { + "jsonrpc": "2.0", + "method": TEST_METHOD_NAME, + "params": {}, + } + with pytest.raises(PermissionDenied, match="Unable to authenticate API via token."): + self.client.post( + "/internal_api/v1/rpcapi", + headers={"Content-Type": "application/json"}, + data=json.dumps(input_data), + ) + + def test_invalid_token(self, signer: JWTSigner): + headers = { + "Content-Type": "application/json", + "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}), + } + data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} + + with pytest.raises( + PermissionDenied, match="Bad Signature. Please use only the tokens provided by the API." + ): + self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data)) diff --git a/tests/api_internal/test_internal_api_call.py b/tests/api_internal/test_internal_api_call.py index 9ac061181fef7..896e88d77e824 100644 --- a/tests/api_internal/test_internal_api_call.py +++ b/tests/api_internal/test_internal_api_call.py @@ -138,11 +138,12 @@ def test_remote_call(self, mock_requests): "params": BaseSerialization.serialize({}), } ) - mock_requests.post.assert_called_once_with( - url="http://localhost:8888/internal_api/v1/rpcapi", - data=expected_data, - headers={"Content-Type": "application/json"}, - ) + mock_requests.post.assert_called_once() + call_kwargs: dict = mock_requests.post.call_args.kwargs + assert call_kwargs["url"] == "http://localhost:8888/internal_api/v1/rpcapi" + assert call_kwargs["data"] == expected_data + assert call_kwargs["headers"]["Content-Type"] == "application/json" + assert "Authorization" in call_kwargs["headers"] @conf_vars( { @@ -192,11 +193,12 @@ def test_remote_call_with_params(self, mock_requests): ), } ) - mock_requests.post.assert_called_once_with( - url="http://localhost:8888/internal_api/v1/rpcapi", - data=expected_data, - headers={"Content-Type": "application/json"}, - ) + mock_requests.post.assert_called_once() + call_kwargs: dict = mock_requests.post.call_args.kwargs + assert call_kwargs["url"] == "http://localhost:8888/internal_api/v1/rpcapi" + assert call_kwargs["data"] == expected_data + assert call_kwargs["headers"]["Content-Type"] == "application/json" + assert "Authorization" in call_kwargs["headers"] @conf_vars( { @@ -228,11 +230,12 @@ def test_remote_classmethod_call_with_params(self, mock_requests): ), } ) - mock_requests.post.assert_called_once_with( - url="http://localhost:8888/internal_api/v1/rpcapi", - data=expected_data, - headers={"Content-Type": "application/json"}, - ) + mock_requests.post.assert_called_once() + call_kwargs: dict = mock_requests.post.call_args.kwargs + assert call_kwargs["url"] == "http://localhost:8888/internal_api/v1/rpcapi" + assert call_kwargs["data"] == expected_data + assert call_kwargs["headers"]["Content-Type"] == "application/json" + assert "Authorization" in call_kwargs["headers"] @conf_vars( { @@ -261,8 +264,9 @@ def test_remote_call_with_serialized_model(self, mock_requests): "params": BaseSerialization.serialize({"ti": ti}, use_pydantic_models=True), } ) - mock_requests.post.assert_called_once_with( - url="http://localhost:8888/internal_api/v1/rpcapi", - data=expected_data, - headers={"Content-Type": "application/json"}, - ) + mock_requests.post.assert_called_once() + call_kwargs: dict = mock_requests.post.call_args.kwargs + assert call_kwargs["url"] == "http://localhost:8888/internal_api/v1/rpcapi" + assert call_kwargs["data"] == expected_data + assert call_kwargs["headers"]["Content-Type"] == "application/json" + assert "Authorization" in call_kwargs["headers"] diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py index ded83ca23336f..e6ea725db1112 100644 --- a/tests/core/test_configuration.py +++ b/tests/core/test_configuration.py @@ -1623,6 +1623,7 @@ def test_sensitive_values(): sensitive_values = { ("database", "sql_alchemy_conn"), ("core", "fernet_key"), + ("core", "internal_api_secret_key"), ("smtp", "smtp_password"), ("webserver", "secret_key"), ("secrets", "backend_kwargs"),