From 93ada665d1b9af7956d5f2b3313071c8848d4c0b Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 10 May 2025 05:11:39 +0530 Subject: [PATCH] Fix bug with in-process request handling for `dag.test` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes a subtle bug where `SUPERVISOR_COMMS` was incorrectly set during supervisor-side request handling in `InProcessSupervisorComms.send_request()`. When calling the supervisor’s `_handle_request()` (e.g. in response to a connection or variable fetch), some internal logic (like `Variable.get()` or `Connection.get_connection_from_secrets()`) would incorrectly detect that it was running in a Task SDK execution context — because `SUPERVISOR_COMMS` was still set. This led to recursive calls into the SDK flow (e.g., calling `SUPERVISOR_COMMS.lock`) **while already holding the lock**, resulting in `AttributeError: 'NoneType' object has no attribute 'lock'` or deadlocks. - The fix ensures `SUPERVISOR_COMMS` is temporarily unset while handling the request. - This prevents Task SDK context detection logic from activating during supervisor API handling. - The `set_supervisor_comms(None)` context manager is now explicitly used within `send_request()` to guard the call to `_handle_request()`. - Unit tests for `set_supervisor_comms()` covering all override/restore edge cases - A roundtrip test that verifies `send_request()` triggers `_handle_request()`, which in turn uses `send_msg()` to queue a response retrievable via `get_message()` This fixes real bugs encountered when using `dag.test()` in system tests that rely on connections or variables: - Tasks attempting to fetch a connection during execution caused the supervisor to recurse into its own comms path - This led to incorrect error handling (`500 Internal Server Error`) and test failures - Based on debugging a failure in the `example_athena` system test under `dag.test()` - Prevents regressions for other DAGs/tasks that rely on connection or variable fetches inside Task SDK runtime --- .../airflow/sdk/execution_time/supervisor.py | 17 +++- .../execution_time/test_supervisor.py | 92 ++++++++++++++++++- 2 files changed, 103 insertions(+), 6 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 0e7b7f54cd1f5..90ab13305bc72 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1320,15 +1320,22 @@ def set_supervisor_comms(temp_comms): """ from airflow.sdk.execution_time import task_runner - old = getattr(task_runner, "SUPERVISOR_COMMS", None) - task_runner.SUPERVISOR_COMMS = temp_comms + sentinel = object() + old = getattr(task_runner, "SUPERVISOR_COMMS", sentinel) + + if temp_comms is not None: + task_runner.SUPERVISOR_COMMS = temp_comms + elif old is not sentinel: + delattr(task_runner, "SUPERVISOR_COMMS") + try: yield finally: - if old is not None: - task_runner.SUPERVISOR_COMMS = old + if old is sentinel: + if hasattr(task_runner, "SUPERVISOR_COMMS"): + delattr(task_runner, "SUPERVISOR_COMMS") else: - delattr(task_runner, "SUPERVISOR_COMMS") + task_runner.SUPERVISOR_COMMS = old def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult: diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 5690f9b418ec6..1e6aec6bed06d 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -51,6 +51,7 @@ TaskInstanceState, ) from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType +from airflow.sdk.execution_time import task_runner from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, @@ -92,7 +93,15 @@ XComResult, ) from airflow.sdk.execution_time.secrets_masker import SecretsMasker -from airflow.sdk.execution_time.supervisor import BUFFER_SIZE, ActivitySubprocess, mkpipe, supervise +from airflow.sdk.execution_time.supervisor import ( + BUFFER_SIZE, + ActivitySubprocess, + InProcessSupervisorComms, + InProcessTestSupervisor, + mkpipe, + set_supervisor_comms, + supervise, +) from airflow.sdk.execution_time.task_runner import CommsDecoder from airflow.utils import timezone, timezone as tz @@ -1600,3 +1609,84 @@ def test_handle_requests_api_server_error(self, watched_subprocess, mocker): "message": str(error), "detail": error.response.json(), } + + +class TestSetSupervisorComms: + class DummyComms: + pass + + @pytest.fixture(autouse=True) + def cleanup_supervisor_comms(self): + # Ensure clean state before/after test + if hasattr(task_runner, "SUPERVISOR_COMMS"): + delattr(task_runner, "SUPERVISOR_COMMS") + yield + if hasattr(task_runner, "SUPERVISOR_COMMS"): + delattr(task_runner, "SUPERVISOR_COMMS") + + def test_set_supervisor_comms_overrides_and_restores(self): + task_runner.SUPERVISOR_COMMS = self.DummyComms() + original = task_runner.SUPERVISOR_COMMS + replacement = self.DummyComms() + + with set_supervisor_comms(replacement): + assert task_runner.SUPERVISOR_COMMS is replacement + assert task_runner.SUPERVISOR_COMMS is original + + def test_set_supervisor_comms_sets_temporarily_when_not_set(self): + assert not hasattr(task_runner, "SUPERVISOR_COMMS") + replacement = self.DummyComms() + + with set_supervisor_comms(replacement): + assert task_runner.SUPERVISOR_COMMS is replacement + assert not hasattr(task_runner, "SUPERVISOR_COMMS") + + def test_set_supervisor_comms_unsets_temporarily_when_not_set(self): + assert not hasattr(task_runner, "SUPERVISOR_COMMS") + + # This will delete an attribute that isn't set, and restore it likewise + with set_supervisor_comms(None): + assert not hasattr(task_runner, "SUPERVISOR_COMMS") + + assert not hasattr(task_runner, "SUPERVISOR_COMMS") + + +class TestInProcessTestSupervisor: + def test_inprocess_supervisor_comms_roundtrip(self): + """ + Test that InProcessSupervisorComms correctly sends a message to the supervisor, + and that the supervisor's response is received via the message queue. + + This verifies the end-to-end communication flow: + - send_request() dispatches a message to the supervisor + - the supervisor handles the request and appends a response via send_msg() + - get_message() returns the enqueued response + + This test mocks the supervisor's `_handle_request()` method to simulate + a simple echo-style response, avoiding full task execution. + """ + + class MinimalSupervisor(InProcessTestSupervisor): + def _handle_request(self, msg, log): + resp = VariableResult(key=msg.key, value="value") + self.send_msg(resp) + + supervisor = MinimalSupervisor( + id="test", + pid=123, + requests_fd=-1, + process=MagicMock(), + process_log=MagicMock(), + client=MagicMock(), + ) + comms = InProcessSupervisorComms(supervisor=supervisor) + supervisor.comms = comms + + test_msg = GetVariable(key="test_key") + + comms.send_request(log=MagicMock(), msg=test_msg) + + # Ensure we got back what we expect + response = comms.get_message() + assert isinstance(response, VariableResult) + assert response.value == "value"