Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, _send_error_email_notification
from airflow.sdk.log import mask_secret
from airflow.serialization.serialized_objects import DagSerialization, LazyDeserializedDAG
from airflow.utils.dag_version_inflation_checker import check_dag_file_stability
from airflow.utils.file import iter_airflow_imports
Expand Down Expand Up @@ -608,6 +609,10 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
if isinstance(conn, ConnectionResponse):
if conn.password:
Comment thread
leeyspaul marked this conversation as resolved.
mask_secret(conn.password)
if conn.extra:
mask_secret(conn.extra)
conn_result = ConnectionResult.from_conn_response(conn)
resp = conn_result
dump_opts = {"exclude_unset": True, "by_alias": True}
Expand All @@ -616,6 +621,8 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
if isinstance(var, VariableResponse):
if var.value:
Comment thread
leeyspaul marked this conversation as resolved.
mask_secret(var.value, var.key)
var_result = VariableResult.from_variable_response(var)
resp = var_result
dump_opts = {"exclude_unset": True}
Expand Down Expand Up @@ -666,8 +673,6 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int
resp = XComSequenceSliceResult.from_response(xcoms)
elif isinstance(msg, MaskSecret):
# Use sdk masker in dag processor and triggerer because those use the task sdk machinery
from airflow.sdk.log import mask_secret

mask_secret(msg.value, msg.name)
elif isinstance(msg, GetTICount):
resp = self.client.task_instances.get_count(
Expand Down
75 changes: 74 additions & 1 deletion airflow-core/tests/unit/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@
from airflow.models import DagRun
from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.client import Client
from airflow.sdk.api.datamodels._generated import DagRunState
from airflow.sdk.api.datamodels._generated import ConnectionResponse, DagRunState, VariableResponse
from airflow.sdk.execution_time import comms
from airflow.sdk.execution_time.comms import (
GetConnection,
GetTaskStates,
GetTICount,
GetVariable,
GetXCom,
GetXComSequenceSlice,
TaskStatesResult,
Expand Down Expand Up @@ -2056,3 +2058,74 @@ def test_create_log_forwarder_rewrites_task_prefix_to_dag_processor(self, proc):
with patch.object(WatchedSubprocess, "_create_log_forwarder") as mock_base:
proc._create_log_forwarder((), "task.stdout")
mock_base.assert_called_once_with((), "dag_processor.stdout", logging.INFO)

def test_handle_request_get_connection_masks_password_and_extra(self, proc):
proc.client.connections.get.return_value = ConnectionResponse(
conn_id="test_conn",
conn_type="mysql",
password="super-secret-password",
extra='{"api_key":"super-secret-extra"}',
)

with (
patch("airflow.dag_processing.processor.mask_secret") as mock_mask_secret,
patch.object(DagFileProcessorProcess, "send_msg", autospec=True) as mock_send_msg,
):
proc._handle_request(
GetConnection(conn_id="test_conn"),
structlog.get_logger(),
req_id=123,
)

proc.client.connections.get.assert_called_once_with("test_conn")
mock_mask_secret.assert_any_call("super-secret-password")
mock_mask_secret.assert_any_call('{"api_key":"super-secret-extra"}')
assert mock_mask_secret.call_count == 2

mock_send_msg.assert_called_once()
_, args, kwargs = mock_send_msg.mock_calls[0]
assert args[0] is proc
msg = args[1]
assert kwargs["request_id"] == 123
assert kwargs["error"] is None
assert kwargs["exclude_unset"] is True
assert kwargs["by_alias"] is True
assert msg.model_dump(by_alias=True, exclude_unset=True) == {
"conn_id": "test_conn",
"conn_type": "mysql",
"password": "super-secret-password",
"extra": '{"api_key":"super-secret-extra"}',
"type": "ConnectionResult",
}

def test_handle_request_get_variable_masks_value_with_key(self, proc):
proc.client.variables.get.return_value = VariableResponse(
key="test_key",
value="super-secret-value",
)

with (
patch("airflow.dag_processing.processor.mask_secret") as mock_mask_secret,
patch.object(DagFileProcessorProcess, "send_msg", autospec=True) as mock_send_msg,
):
proc._handle_request(
GetVariable(key="test_key"),
structlog.get_logger(),
req_id=456,
)

proc.client.variables.get.assert_called_once_with("test_key")
mock_mask_secret.assert_called_once_with("super-secret-value", "test_key")

mock_send_msg.assert_called_once()
_, args, kwargs = mock_send_msg.mock_calls[0]
assert args[0] is proc
msg = args[1]
assert kwargs["request_id"] == 456
assert kwargs["error"] is None
assert kwargs["exclude_unset"] is True
assert msg.model_dump(exclude_unset=True) == {
"key": "test_key",
"value": "super-secret-value",
"type": "VariableResult",
}
Loading