From 0fb27af54cd14d018f63a5eec772d8f3aecdba74 Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Tue, 21 Apr 2026 20:38:15 +0200 Subject: [PATCH 01/17] fix frame mismatch in triggerer --- .../src/airflow/jobs/triggerer_job_runner.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 45a6c7ddc0d4c..222827085a33d 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -877,7 +877,7 @@ def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: from asgiref.sync import async_to_sync with self._thread_lock: - return async_to_sync(self.asend)(msg) + return async_to_sync(self._asend)(msg) async def _aread_frame(self): try: @@ -899,14 +899,20 @@ async def _aget_response(self, expect_id: int) -> ToTriggerRunner | None: raise RuntimeError(f"Response read out of order! Got {frame.id=}, {expect_id=}") return self._from_frame(frame) - async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: + async def _asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) bytes = frame.as_bytes() + self._async_writer.write(bytes) + return await self._aget_response(frame.id) + async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: async with self._async_lock: - self._async_writer.write(bytes) - - return await self._aget_response(frame.id) + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._thread_lock.acquire) + try: + return await self._asend(msg) + finally: + self._thread_lock.release() class TriggerRunner: From 65804f39a484d4f531eb1074a41007642d3590d2 Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 14:19:51 +0200 Subject: [PATCH 02/17] added context managers --- .../src/airflow/sdk/execution_time/comms.py | 65 ++++++++++++------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 87c7881333ad4..a9d5d2800725b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -58,6 +58,7 @@ from socket import socket from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload from uuid import UUID +from contextlib import asynccontextmanager, contextmanager import attrs import msgspec @@ -193,6 +194,28 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): # Async lock for async operations _async_lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False) + @contextmanager + def _lock_sync(self): + """Acquire the thread lock for synchronous operations.""" + with self._thread_lock: + yield + + @asynccontextmanager + async def _lock_async(self): + """ + Acquire both the async lock and the thread lock for asynchronous operations. + + The thread lock is acquired via a thread executor so the event loop + is not blocked while waiting for other holders. + """ + async with self._async_lock: + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._thread_lock.acquire) + try: + yield + finally: + self._thread_lock.release() + def send(self, msg: SendMsgType) -> ReceiveMsgType | None: """Send a request to the parent and block until the response is received.""" frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) @@ -200,7 +223,7 @@ def send(self, msg: SendMsgType) -> ReceiveMsgType | None: # We must make sure sockets aren't intermixed between sync and async calls, # thus we need a dual locking mechanism to ensure that. - with self._thread_lock: + with self._lock_sync(): self.socket.sendall(frame_bytes) if isinstance(msg, ResendLoggingFD): if recv_fds is None: @@ -227,30 +250,26 @@ async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None: frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) frame_bytes = frame.as_bytes() - async with self._async_lock: + async with self._lock_async(): # Acquire the threading lock without blocking the event loop loop = asyncio.get_running_loop() - await loop.run_in_executor(None, self._thread_lock.acquire) - try: - # Async write to socket - await loop.sock_sendall(self.socket, frame_bytes) - - if isinstance(msg, ResendLoggingFD): - if recv_fds is None: - return None - # Blocking read in a thread - frame, fds = await asyncio.to_thread(self._read_frame, maxfds=1) - resp = self._from_frame(frame) - if TYPE_CHECKING: - assert isinstance(resp, SentFDs) - resp.fds = fds - return resp # type: ignore[return-value] - - # Normal blocking read in a thread - frame = await asyncio.to_thread(self._read_frame) - return self._from_frame(frame) - finally: - self._thread_lock.release() + # Async write to socket + await loop.sock_sendall(self.socket, frame_bytes) + + if isinstance(msg, ResendLoggingFD): + if recv_fds is None: + return None + # Blocking read in a thread + frame, fds = await asyncio.to_thread(self._read_frame, maxfds=1) + resp = self._from_frame(frame) + if TYPE_CHECKING: + assert isinstance(resp, SentFDs) + resp.fds = fds + return resp # type: ignore[return-value] + + # Normal blocking read in a thread + frame = await asyncio.to_thread(self._read_frame) + return self._from_frame(frame) @overload def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ... From c673e12d30c62e6f17b0d16f52a341b174622a42 Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 14:31:56 +0200 Subject: [PATCH 03/17] tests --- .../task_sdk/execution_time/test_comms.py | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index 5c6d88439250c..fd40720a1cadc 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -17,6 +17,7 @@ from __future__ import annotations +import asyncio import threading import uuid from socket import socketpair @@ -29,6 +30,7 @@ from airflow.sdk.execution_time.comms import ( BundleInfo, CommsDecoder, + GetVariable, MaskSecret, StartupDetails, VariableResult, @@ -193,3 +195,87 @@ def send_and_store(idx): assert errors[idx] is None, f"Thread {idx} error: {errors[idx]}" assert results[idx].key == f"key{idx}", f"Out-of-order or missing response for thread {idx}" assert results[idx].value == f"value{idx}", f"Incorrect value for thread {idx}" + + @pytest.mark.asyncio + async def test_asend_basic(self): + """Verify a single async request‑response cycle via asend.""" + r, w = socketpair() + r.setblocking(False) # <-- required for asyncio + w.setblocking(False) + decoder = CommsDecoder(socket=r, log=structlog.get_logger()) + + async def server(): + loop = asyncio.get_running_loop() + # read length + len_bytes = await loop.sock_recv(w, 4) + length = int.from_bytes(len_bytes, "big") + data = bytearray() + while len(data) < length: + chunk = await loop.sock_recv(w, length - len(data)) + if not chunk: + break + data.extend(chunk) + req = decoder.resp_decoder.decode(data) + # build a VariableResult matching the request + key = req.body["key"] + resp = {"type": "VariableResult", "key": key, "value": f"value_{key}"} + resp_frame = _ResponseFrame(req.id, resp, None) + resp_bytes = msgspec.msgpack.encode(resp_frame) + await loop.sock_sendall(w, len(resp_bytes).to_bytes(4, "big") + resp_bytes) + w.close() # signal EOF for clean shutdown + + server_task = asyncio.create_task(server()) + + msg = GetVariable(key="basic_key") + result = await decoder.asend(msg) + + assert isinstance(result, VariableResult) + assert result.key == "basic_key" + assert result.value == "value_basic_key" + + await server_task + r.close() # clean up the read end + + @pytest.mark.asyncio + async def test_asend_concurrent_safety(self): + """Multiple concurrent asend calls must not interleave and must each receive the correct response.""" + r, w = socketpair() + r.setblocking(False) # <-- required for asyncio + w.setblocking(False) + decoder = CommsDecoder(socket=r, log=structlog.get_logger()) + num_requests = 5 + + async def server(): + loop = asyncio.get_running_loop() + for _ in range(num_requests): + # read one frame + len_bytes = await loop.sock_recv(w, 4) + length = int.from_bytes(len_bytes, "big") + data = bytearray() + while len(data) < length: + chunk = await loop.sock_recv(w, length - len(data)) + if not chunk: + break + data.extend(chunk) + req = decoder.resp_decoder.decode(data) + key = req.body["key"] + resp = {"type": "VariableResult", "key": key, "value": f"value_{key}"} + resp_frame = _ResponseFrame(req.id, resp, None) + resp_bytes = msgspec.msgpack.encode(resp_frame) + await loop.sock_sendall(w, len(resp_bytes).to_bytes(4, "big") + resp_bytes) + w.close() + + server_task = asyncio.create_task(server()) + + async def make_request(idx: int) -> VariableResult: + msg = GetVariable(key=f"key{idx}") + return await decoder.asend(msg) + + # Start all requests concurrently + results = await asyncio.gather(*[make_request(i) for i in range(num_requests)]) + await server_task + r.close() + + for idx, result in enumerate(results): + assert result.key == f"key{idx}" + assert result.value == f"value_key{idx}" From 48aa6181b6836ac80b25dfa2883074355c4f002c Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 15:00:30 +0200 Subject: [PATCH 04/17] use context managers from parent class --- .../src/airflow/jobs/triggerer_job_runner.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 222827085a33d..bf1c612ac7a82 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -870,15 +870,19 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]): def _read_frame(self): from asgiref.sync import async_to_sync - with self._thread_lock: + with self._lock_sync(): return async_to_sync(self._aread_frame)() def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: from asgiref.sync import async_to_sync - with self._thread_lock: + with self._lock_sync(): return async_to_sync(self._asend)(msg) + async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: + async with self._lock_async(): + return await self._asend(msg) + async def _aread_frame(self): try: len_bytes = await self._async_reader.readexactly(4) @@ -905,16 +909,6 @@ async def _asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: self._async_writer.write(bytes) return await self._aget_response(frame.id) - async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: - async with self._async_lock: - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, self._thread_lock.acquire) - try: - return await self._asend(msg) - finally: - self._thread_lock.release() - - class TriggerRunner: """ Runtime environment for all triggers. From 20adea341c9a7bef3e56e76bb33c1198a25d1366 Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 15:44:54 +0200 Subject: [PATCH 05/17] cleanup --- .../task_sdk/execution_time/test_comms.py | 56 ++++++------------- 1 file changed, 18 insertions(+), 38 deletions(-) diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index fd40720a1cadc..a11746332a1ea 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -196,35 +196,35 @@ def send_and_store(idx): assert results[idx].key == f"key{idx}", f"Out-of-order or missing response for thread {idx}" assert results[idx].value == f"value{idx}", f"Incorrect value for thread {idx}" - @pytest.mark.asyncio - async def test_asend_basic(self): - """Verify a single async request‑response cycle via asend.""" - r, w = socketpair() - r.setblocking(False) # <-- required for asyncio - w.setblocking(False) - decoder = CommsDecoder(socket=r, log=structlog.get_logger()) - - async def server(): - loop = asyncio.get_running_loop() - # read length - len_bytes = await loop.sock_recv(w, 4) + async def _variable_server_side(w_sock, num_requests: int, decoder: CommsDecoder) -> None: + """Handle ``num_requests`` GetVariable frames, responding with VariableResult.""" + loop = asyncio.get_running_loop() + for _ in range(num_requests): + len_bytes = await loop.sock_recv(w_sock, 4) length = int.from_bytes(len_bytes, "big") data = bytearray() while len(data) < length: - chunk = await loop.sock_recv(w, length - len(data)) + chunk = await loop.sock_recv(w_sock, length - len(data)) if not chunk: break data.extend(chunk) req = decoder.resp_decoder.decode(data) - # build a VariableResult matching the request key = req.body["key"] resp = {"type": "VariableResult", "key": key, "value": f"value_{key}"} resp_frame = _ResponseFrame(req.id, resp, None) resp_bytes = msgspec.msgpack.encode(resp_frame) - await loop.sock_sendall(w, len(resp_bytes).to_bytes(4, "big") + resp_bytes) - w.close() # signal EOF for clean shutdown + await loop.sock_sendall(w_sock, len(resp_bytes).to_bytes(4, "big") + resp_bytes) + w_sock.close() + + @pytest.mark.asyncio + async def test_asend_basic(self): + """Verify a single async request‑response cycle via asend.""" + r, w = socketpair() + r.setblocking(False) # <-- required for asyncio + w.setblocking(False) + decoder = CommsDecoder(socket=r, log=structlog.get_logger()) - server_task = asyncio.create_task(server()) + server_task = asyncio.create_task(TestCommsDecoder._variable_server_side(w, 1, decoder)) msg = GetVariable(key="basic_key") result = await decoder.asend(msg) @@ -245,27 +245,7 @@ async def test_asend_concurrent_safety(self): decoder = CommsDecoder(socket=r, log=structlog.get_logger()) num_requests = 5 - async def server(): - loop = asyncio.get_running_loop() - for _ in range(num_requests): - # read one frame - len_bytes = await loop.sock_recv(w, 4) - length = int.from_bytes(len_bytes, "big") - data = bytearray() - while len(data) < length: - chunk = await loop.sock_recv(w, length - len(data)) - if not chunk: - break - data.extend(chunk) - req = decoder.resp_decoder.decode(data) - key = req.body["key"] - resp = {"type": "VariableResult", "key": key, "value": f"value_{key}"} - resp_frame = _ResponseFrame(req.id, resp, None) - resp_bytes = msgspec.msgpack.encode(resp_frame) - await loop.sock_sendall(w, len(resp_bytes).to_bytes(4, "big") + resp_bytes) - w.close() - - server_task = asyncio.create_task(server()) + server_task = asyncio.create_task(TestCommsDecoder._variable_server_side(w, num_requests, decoder)) async def make_request(idx: int) -> VariableResult: msg = GetVariable(key=f"key{idx}") From 52a3ec425c9b1f30a4a0e34d49d24fada30695fb Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 16:58:42 +0200 Subject: [PATCH 06/17] tests --- .../tests/unit/jobs/test_triggerer_job.py | 135 +++++++++++++++++- 1 file changed, 133 insertions(+), 2 deletions(-) diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 08c960fd33fe6..abbc928635095 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -20,13 +20,16 @@ import asyncio import datetime import itertools +import msgspec import os import selectors +import structlog import time +import threading import typing import uuid from collections.abc import AsyncIterator -from socket import socket +from socket import socket, socketpair from typing import TYPE_CHECKING, Any from unittest import mock from unittest.mock import ANY, AsyncMock, MagicMock, patch @@ -67,7 +70,7 @@ from airflow.providers.standard.triggers.file import FileDeleteTrigger from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.sdk import DAG, BaseHook, BaseOperator -from airflow.sdk.execution_time.comms import ToSupervisor, ToTask +from airflow.sdk.execution_time.comms import ToSupervisor, ToTask, _RequestFrame, _ResponseFrame from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.triggers.testing import FailureTrigger, SuccessTrigger @@ -1538,3 +1541,131 @@ def test_make_trigger_span_sets_only_trigger_name_without_ti(self): assert attrs["airflow.trigger.name"] == "OnlyTrigger" assert "airflow.dag_id" not in attrs assert "airflow.task_id" not in attrs + +class TestTriggerCommsDecoder: + """Tests for the low‑level TriggerCommsDecoder socket communication.""" + + @pytest.mark.asyncio + async def test_recv_trigger_message(self): + r, w = socketpair() + r.setblocking(False) + w.setblocking(False) + + # Create asyncio reader/writer for the decoder side (using socket r) + reader, writer = await asyncio.open_connection(sock=r) + + # Pass the socket explicitly to avoid ENOTSOCK (fd 0 is not a socket) + decoder = TriggerCommsDecoder( + socket=r, + async_reader=reader, + async_writer=writer, + log=None, + ) + + # Prepare a TriggerStateSync response frame, send it on the other socket (w) + sync = messages.TriggerStateSync(to_create=[], to_cancel=set()) + frame = _RequestFrame(id=0, body=sync.model_dump()) + data = msgspec.msgpack.encode(frame) + loop = asyncio.get_running_loop() + await loop.sock_sendall(w, len(data).to_bytes(4, "big") + data) + + # Use async internal methods to avoid AsyncToSync in an async context + resp_frame = await decoder._aread_frame() + msg = decoder._from_frame(resp_frame) + assert isinstance(msg, messages.TriggerStateSync) + assert msg.to_create == [] + assert msg.to_cancel == set() + + writer.close() + await writer.wait_closed() + + async def _run_trigger_comms_server( + self, w_sock: socket, num_requests: int, decoder: TriggerCommsDecoder + ) -> None: + loop = asyncio.get_running_loop() + for _ in range(num_requests): + len_bytes = await loop.sock_recv(w_sock, 4) + length = int.from_bytes(len_bytes, "big") + data = bytearray() + while len(data) < length: + chunk = await loop.sock_recv(w_sock, length - len(data)) + if not chunk: + break + data.extend(chunk) + + req = decoder.resp_decoder.decode(data) # This is a _RequestFrame + # Deserialize the TriggerStateChanges request body + changes = messages.TriggerStateChanges(**req.body) + + # Prepare the TriggerStateSync response + resp = messages.TriggerStateSync(to_create=[], to_cancel=set()) + resp_frame = _RequestFrame(id=req.id, body=resp.model_dump()) + resp_bytes = msgspec.msgpack.encode(resp_frame) + await loop.sock_sendall(w_sock, len(resp_bytes).to_bytes(4, "big") + resp_bytes) + w_sock.close() + + @pytest.mark.asyncio + async def test_asend_basic(self): + r, w = socketpair() + r.setblocking(False) + w.setblocking(False) + + reader, writer = await asyncio.open_connection(sock=r) + decoder = TriggerCommsDecoder( + socket=r, + async_reader=reader, + async_writer=writer, + log=structlog.get_logger(), + ) + + server_task = asyncio.create_task( + self._run_trigger_comms_server(w, 1, decoder) + ) + + msg = messages.TriggerStateChanges(events=None, finished=None, failures=None) + result = await decoder.asend(msg) + + assert isinstance(result, messages.TriggerStateSync) + assert result.to_create == [] + assert result.to_cancel == set() + + await server_task + writer.close() + await writer.wait_closed() + r.close() + + @pytest.mark.asyncio + async def test_asend_concurrent_safety(self): + r, w = socketpair() + r.setblocking(False) + w.setblocking(False) + + reader, writer = await asyncio.open_connection(sock=r) + decoder = TriggerCommsDecoder( + socket=r, + async_reader=reader, + async_writer=writer, + log=structlog.get_logger(), + ) + + num_requests = 5 + server_task = asyncio.create_task( + self._run_trigger_comms_server(w, num_requests, decoder) + ) + + async def make_request(idx: int): + msg = messages.TriggerStateChanges( + events=[(idx, TriggerEvent(payload={"index": idx}))], + finished=None, + failures=None, + ) + return await decoder.asend(msg) + + results = await asyncio.gather(*(make_request(i) for i in range(num_requests))) + await server_task + writer.close() + await writer.wait_closed() + r.close() + + for idx, result in enumerate(results): + assert isinstance(result, messages.TriggerStateSync) From d1ec416840566b53344e1b48ca9d062921351663 Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 17:19:57 +0200 Subject: [PATCH 07/17] cleanup --- .../src/airflow/jobs/triggerer_job_runner.py | 3 ++- .../tests/unit/jobs/test_triggerer_job.py | 18 +++++++----------- .../src/airflow/sdk/execution_time/comms.py | 2 +- .../task_sdk/execution_time/test_comms.py | 4 ++-- 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index bf1c612ac7a82..0f189e357b306 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -870,7 +870,7 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]): def _read_frame(self): from asgiref.sync import async_to_sync - with self._lock_sync(): + with self._lock_sync(): return async_to_sync(self._aread_frame)() def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: @@ -909,6 +909,7 @@ async def _asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: self._async_writer.write(bytes) return await self._aget_response(frame.id) + class TriggerRunner: """ Runtime environment for all triggers. diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index abbc928635095..a0374682a2ebc 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -20,12 +20,9 @@ import asyncio import datetime import itertools -import msgspec import os import selectors -import structlog import time -import threading import typing import uuid from collections.abc import AsyncIterator @@ -34,8 +31,10 @@ from unittest import mock from unittest.mock import ANY, AsyncMock, MagicMock, patch +import msgspec import pendulum import pytest +import structlog from asgiref.sync import sync_to_async from opentelemetry import trace as otel_trace from opentelemetry.sdk.trace import TracerProvider @@ -70,7 +69,7 @@ from airflow.providers.standard.triggers.file import FileDeleteTrigger from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.sdk import DAG, BaseHook, BaseOperator -from airflow.sdk.execution_time.comms import ToSupervisor, ToTask, _RequestFrame, _ResponseFrame +from airflow.sdk.execution_time.comms import ToSupervisor, ToTask, _RequestFrame from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.triggers.testing import FailureTrigger, SuccessTrigger @@ -1542,6 +1541,7 @@ def test_make_trigger_span_sets_only_trigger_name_without_ti(self): assert "airflow.dag_id" not in attrs assert "airflow.task_id" not in attrs + class TestTriggerCommsDecoder: """Tests for the low‑level TriggerCommsDecoder socket communication.""" @@ -1593,7 +1593,7 @@ async def _run_trigger_comms_server( break data.extend(chunk) - req = decoder.resp_decoder.decode(data) # This is a _RequestFrame + req = decoder.resp_decoder.decode(data) # This is a _RequestFrame # Deserialize the TriggerStateChanges request body changes = messages.TriggerStateChanges(**req.body) @@ -1618,9 +1618,7 @@ async def test_asend_basic(self): log=structlog.get_logger(), ) - server_task = asyncio.create_task( - self._run_trigger_comms_server(w, 1, decoder) - ) + server_task = asyncio.create_task(self._run_trigger_comms_server(w, 1, decoder)) msg = messages.TriggerStateChanges(events=None, finished=None, failures=None) result = await decoder.asend(msg) @@ -1649,9 +1647,7 @@ async def test_asend_concurrent_safety(self): ) num_requests = 5 - server_task = asyncio.create_task( - self._run_trigger_comms_server(w, num_requests, decoder) - ) + server_task = asyncio.create_task(self._run_trigger_comms_server(w, num_requests, decoder)) async def make_request(idx: int): msg = messages.TriggerStateChanges( diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index a9d5d2800725b..284b014d792ec 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -52,13 +52,13 @@ import itertools import threading from collections.abc import Iterator +from contextlib import asynccontextmanager, contextmanager from datetime import datetime from functools import cached_property from pathlib import Path from socket import socket from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload from uuid import UUID -from contextlib import asynccontextmanager, contextmanager import attrs import msgspec diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index a11746332a1ea..6d95e51d0f144 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -220,7 +220,7 @@ async def _variable_server_side(w_sock, num_requests: int, decoder: CommsDecoder async def test_asend_basic(self): """Verify a single async request‑response cycle via asend.""" r, w = socketpair() - r.setblocking(False) # <-- required for asyncio + r.setblocking(False) # <-- required for asyncio w.setblocking(False) decoder = CommsDecoder(socket=r, log=structlog.get_logger()) @@ -240,7 +240,7 @@ async def test_asend_basic(self): async def test_asend_concurrent_safety(self): """Multiple concurrent asend calls must not interleave and must each receive the correct response.""" r, w = socketpair() - r.setblocking(False) # <-- required for asyncio + r.setblocking(False) # <-- required for asyncio w.setblocking(False) decoder = CommsDecoder(socket=r, log=structlog.get_logger()) num_requests = 5 From eb4f42fd1be7c686fea114d1af3dfa5ea5ee118b Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 17:57:22 +0200 Subject: [PATCH 08/17] protect frame counter --- task-sdk/src/airflow/sdk/execution_time/comms.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 284b014d792ec..7603ba6b464c0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -218,12 +218,13 @@ async def _lock_async(self): def send(self, msg: SendMsgType) -> ReceiveMsgType | None: """Send a request to the parent and block until the response is received.""" - frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) - frame_bytes = frame.as_bytes() # We must make sure sockets aren't intermixed between sync and async calls, # thus we need a dual locking mechanism to ensure that. with self._lock_sync(): + frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) + frame_bytes = frame.as_bytes() + self.socket.sendall(frame_bytes) if isinstance(msg, ResendLoggingFD): if recv_fds is None: @@ -247,13 +248,13 @@ async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None: Uses async lock for coroutine safety and thread lock for socket safety. """ - frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) - frame_bytes = frame.as_bytes() async with self._lock_async(): - # Acquire the threading lock without blocking the event loop - loop = asyncio.get_running_loop() + frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) + frame_bytes = frame.as_bytes() + # Async write to socket + loop = asyncio.get_running_loop() await loop.sock_sendall(self.socket, frame_bytes) if isinstance(msg, ResendLoggingFD): From dfc38e0ba6737d4b1dfa7e400a4a0ffed4c2c91f Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 17:58:02 +0200 Subject: [PATCH 09/17] cleanup --- task-sdk/src/airflow/sdk/execution_time/comms.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 7603ba6b464c0..f41b838dc71dd 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -218,7 +218,6 @@ async def _lock_async(self): def send(self, msg: SendMsgType) -> ReceiveMsgType | None: """Send a request to the parent and block until the response is received.""" - # We must make sure sockets aren't intermixed between sync and async calls, # thus we need a dual locking mechanism to ensure that. with self._lock_sync(): @@ -248,7 +247,6 @@ async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None: Uses async lock for coroutine safety and thread lock for socket safety. """ - async with self._lock_async(): frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) frame_bytes = frame.as_bytes() From 16cd660b491596e5f240939f1fe9eeeaa7d539ec Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 18:16:18 +0200 Subject: [PATCH 10/17] cleanup --- airflow-core/tests/unit/jobs/test_triggerer_job.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index a0374682a2ebc..a4eeb33c8f484 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1547,6 +1547,7 @@ class TestTriggerCommsDecoder: @pytest.mark.asyncio async def test_recv_trigger_message(self): + """Verify low-level receiving of a TriggerStateSync message through the TriggerCommsDecoder.""" r, w = socketpair() r.setblocking(False) w.setblocking(False) @@ -1582,6 +1583,7 @@ async def test_recv_trigger_message(self): async def _run_trigger_comms_server( self, w_sock: socket, num_requests: int, decoder: TriggerCommsDecoder ) -> None: + """Mock server that receives ``TriggerStateChanges`` requests and replies with ``TriggerStateSync`` responses.""" loop = asyncio.get_running_loop() for _ in range(num_requests): len_bytes = await loop.sock_recv(w_sock, 4) @@ -1606,6 +1608,7 @@ async def _run_trigger_comms_server( @pytest.mark.asyncio async def test_asend_basic(self): + """Test basic async send of a ``TriggerStateChanges`` message and receiving of a ``TriggerStateSync`` response.""" r, w = socketpair() r.setblocking(False) w.setblocking(False) @@ -1634,6 +1637,7 @@ async def test_asend_basic(self): @pytest.mark.asyncio async def test_asend_concurrent_safety(self): + """Ensure that multiple concurrent ``asend()`` calls to the ``TriggerCommsDecoder`` are serialised correctly and each receives its proper response.""" r, w = socketpair() r.setblocking(False) w.setblocking(False) From fff65ffdb094e79daea65e95eb1d9cd1db0965ee Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 18:20:08 +0200 Subject: [PATCH 11/17] cleanup --- airflow-core/tests/unit/jobs/test_triggerer_job.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index a4eeb33c8f484..b6b5ada2b1c2c 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1597,7 +1597,7 @@ async def _run_trigger_comms_server( req = decoder.resp_decoder.decode(data) # This is a _RequestFrame # Deserialize the TriggerStateChanges request body - changes = messages.TriggerStateChanges(**req.body) + messages.TriggerStateChanges(**req.body) # Prepare the TriggerStateSync response resp = messages.TriggerStateSync(to_create=[], to_cancel=set()) @@ -1667,5 +1667,5 @@ async def make_request(idx: int): await writer.wait_closed() r.close() - for idx, result in enumerate(results): + for _idx, result in enumerate(results): assert isinstance(result, messages.TriggerStateSync) From 2c38e3b0a6c5cded93aa10034c1a803c0ad09d83 Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 18:24:50 +0200 Subject: [PATCH 12/17] cleanup --- task-sdk/tests/task_sdk/execution_time/test_comms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index 6d95e51d0f144..f2b79d74dc5aa 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -196,6 +196,7 @@ def send_and_store(idx): assert results[idx].key == f"key{idx}", f"Out-of-order or missing response for thread {idx}" assert results[idx].value == f"value{idx}", f"Incorrect value for thread {idx}" + @staticmethod async def _variable_server_side(w_sock, num_requests: int, decoder: CommsDecoder) -> None: """Handle ``num_requests`` GetVariable frames, responding with VariableResult.""" loop = asyncio.get_running_loop() From 4aeb33b3bc51cea18ebfd0b01df59c42b81b4e0e Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 18:25:57 +0200 Subject: [PATCH 13/17] cleanup --- task-sdk/tests/task_sdk/execution_time/test_comms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index f2b79d74dc5aa..d9e0ba483f86e 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -210,6 +210,7 @@ async def _variable_server_side(w_sock, num_requests: int, decoder: CommsDecoder break data.extend(chunk) req = decoder.resp_decoder.decode(data) + assert req.body is not None key = req.body["key"] resp = {"type": "VariableResult", "key": key, "value": f"value_{key}"} resp_frame = _ResponseFrame(req.id, resp, None) From d4afcc9d7753d5dacb8f16662dd117c4400d45a9 Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 18:49:48 +0200 Subject: [PATCH 14/17] cleanup --- task-sdk/tests/task_sdk/execution_time/test_comms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index d9e0ba483f86e..360ea1ef58ff9 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -242,7 +242,7 @@ async def test_asend_basic(self): async def test_asend_concurrent_safety(self): """Multiple concurrent asend calls must not interleave and must each receive the correct response.""" r, w = socketpair() - r.setblocking(False) # <-- required for asyncio + r.setblocking(False) w.setblocking(False) decoder = CommsDecoder(socket=r, log=structlog.get_logger()) num_requests = 5 From 2fedd4b89061010d0cbca59a5046f69ba23931b4 Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 18:50:40 +0200 Subject: [PATCH 15/17] cleanup --- task-sdk/tests/task_sdk/execution_time/test_comms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index 360ea1ef58ff9..d146b6b954d3e 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -222,7 +222,7 @@ async def _variable_server_side(w_sock, num_requests: int, decoder: CommsDecoder async def test_asend_basic(self): """Verify a single async request‑response cycle via asend.""" r, w = socketpair() - r.setblocking(False) # <-- required for asyncio + r.setblocking(False) w.setblocking(False) decoder = CommsDecoder(socket=r, log=structlog.get_logger()) From ef871ae40ec0284cc151e842cf36430ad4b36d2e Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 19:38:15 +0200 Subject: [PATCH 16/17] cleanup --- airflow-core/tests/unit/jobs/test_triggerer_job.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index b6b5ada2b1c2c..f99b4bb4f760f 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1597,6 +1597,7 @@ async def _run_trigger_comms_server( req = decoder.resp_decoder.decode(data) # This is a _RequestFrame # Deserialize the TriggerStateChanges request body + assert req.body is not None, "Expected a non-None body in the test frame" messages.TriggerStateChanges(**req.body) # Prepare the TriggerStateSync response From 46999be41c0c2f833d1c838f6b1edb07a53aa96f Mon Sep 17 00:00:00 2001 From: Renat Sagutdinov Date: Sun, 26 Apr 2026 19:42:16 +0200 Subject: [PATCH 17/17] cleanup --- task-sdk/tests/task_sdk/execution_time/test_comms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index d146b6b954d3e..94b924913a301 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -210,7 +210,7 @@ async def _variable_server_side(w_sock, num_requests: int, decoder: CommsDecoder break data.extend(chunk) req = decoder.resp_decoder.decode(data) - assert req.body is not None + assert req.body is not None, "Expected a non-None body in the test frame" key = req.body["key"] resp = {"type": "VariableResult", "key": key, "value": f"value_{key}"} resp_frame = _ResponseFrame(req.id, resp, None)