From facee8aa2c1b89e3b192a00ab3dd0db9d9ba58ac Mon Sep 17 00:00:00 2001 From: pierrejeambrun Date: Fri, 23 Jan 2026 14:11:32 +0100 Subject: [PATCH] Revert "Prevent Triggerer from crashing when a trigger event isn't serializable (#60152)" This reverts commit b7d1c41e61a41fa6919223f31a73bb817c64317f. --- .../src/airflow/jobs/triggerer_job_runner.py | 91 ++++++------------- .../tests/unit/jobs/test_triggerer_job.py | 28 +----- 2 files changed, 31 insertions(+), 88 deletions(-) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index cc78e21245aef..c229dddc7fd20 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -70,13 +70,12 @@ UpdateHITLDetail, VariableResult, XComResult, - _new_encoder, _RequestFrame, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader -from airflow.triggers.base import BaseEventTrigger, BaseTrigger, DiscrimatedTriggerEvent, TriggerEvent from airflow.stats import Stats from airflow.traces.tracer import DebugTrace, Trace, add_debug_span +from airflow.triggers import base as events from airflow.utils.helpers import log_filename_template_renderer from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string @@ -204,7 +203,7 @@ class TriggerStateChanges(BaseModel): type: Literal["TriggerStateChanges"] = "TriggerStateChanges" events: Annotated[ - list[tuple[int, DiscrimatedTriggerEvent]] | None, + list[tuple[int, events.DiscrimatedTriggerEvent]] | None, # We have to specify a default here, as otherwise Pydantic struggles to deal with the discriminated # union :shrug: Field(default=None), @@ -356,7 +355,7 @@ class TriggerRunnerSupervisor(WatchedSubprocess): creating_triggers: deque[workloads.RunTrigger] = attrs.field(factory=deque, init=False) # Outbound queue of events - events: deque[tuple[int, TriggerEvent]] = attrs.field(factory=deque, init=False) + events: deque[tuple[int, events.TriggerEvent]] = attrs.field(factory=deque, init=False) # Outbound queue of failed triggers failed_triggers: deque[tuple[int, list[str] | None]] = attrs.field(factory=deque, init=False) @@ -759,6 +758,8 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]): factory=lambda: TypeAdapter(ToTriggerRunner), repr=False ) + _lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False) + def _read_frame(self): from asgiref.sync import async_to_sync @@ -793,7 +794,7 @@ async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) bytes = frame.as_bytes() - async with self._async_lock: + async with self._lock: self._async_writer.write(bytes) return await self._aget_response(frame.id) @@ -820,7 +821,7 @@ class TriggerRunner: to_cancel: deque[int] # Outbound queue of events - events: deque[tuple[int, TriggerEvent]] + events: deque[tuple[int, events.TriggerEvent]] # Outbound queue of failed triggers failed_triggers: deque[tuple[int, BaseException | None]] @@ -970,7 +971,7 @@ async def create_triggers(self): "task": asyncio.create_task( self.run_trigger(trigger_id, trigger_instance, workload.timeout_after), name=trigger_name ), - "is_watcher": isinstance(trigger_instance, BaseEventTrigger), + "is_watcher": isinstance(trigger_instance, events.BaseEventTrigger), "name": trigger_name, "events": 0, } @@ -1016,7 +1017,7 @@ async def cleanup_finished_triggers(self) -> list[int]: saved_exc = e else: # See if they foolishly returned a TriggerEvent - if isinstance(result, TriggerEvent): + if isinstance(result, events.TriggerEvent): self.log.error( "Trigger returned a TriggerEvent rather than yielding it", trigger=details["name"], @@ -1036,78 +1037,46 @@ async def cleanup_finished_triggers(self) -> list[int]: await asyncio.sleep(0) return finished_ids - def process_trigger_events(self, finished_ids: list[int]) -> messages.TriggerStateChanges: + async def sync_state_to_supervisor(self, finished_ids: list[int]): # Copy out of our deques in threadsafe manner to sync state with parent - events_to_send: list[tuple[int, DiscrimatedTriggerEvent]] = [] - failures_to_send: list[tuple[int, list[str] | None]] = [] - + events_to_send = [] while self.events: - trigger_id, trigger_event = self.events.popleft() - events_to_send.append((trigger_id, trigger_event)) + data = self.events.popleft() + events_to_send.append(data) + failures_to_send = [] while self.failed_triggers: - trigger_id, exc = self.failed_triggers.popleft() + id, exc = self.failed_triggers.popleft() tb = format_exception(type(exc), exc, exc.__traceback__) if exc else None - failures_to_send.append((trigger_id, tb)) + failures_to_send.append((id, tb)) - return messages.TriggerStateChanges( - events=events_to_send if events_to_send else None, - finished=finished_ids if finished_ids else None, - failures=failures_to_send if failures_to_send else None, + msg = messages.TriggerStateChanges( + events=events_to_send, finished=finished_ids, failures=failures_to_send ) - def sanitize_trigger_events(self, msg: messages.TriggerStateChanges) -> messages.TriggerStateChanges: - req_encoder = _new_encoder() - events_to_send: list[tuple[int, DiscrimatedTriggerEvent]] = [] + if not events_to_send: + msg.events = None - if msg.events: - for trigger_id, trigger_event in msg.events: - try: - req_encoder.encode(trigger_event) - except Exception as e: - logger.error( - "Trigger %s returned non-serializable result %r. Cancelling trigger.", - trigger_id, - trigger_event, - ) - self.failed_triggers.append((trigger_id, e)) - else: - events_to_send.append((trigger_id, trigger_event)) - - return messages.TriggerStateChanges( - events=events_to_send if events_to_send else None, - finished=msg.finished, - failures=msg.failures, - ) + if not failures_to_send: + msg.failures = None - async def sync_state_to_supervisor(self, finished_ids: list[int]) -> None: - msg = self.process_trigger_events(finished_ids=finished_ids) + if not finished_ids: + msg.finished = None # Tell the monitor that we've finished triggers so it can update things try: - resp = await self.asend(msg) - except NotImplementedError: - # A non-serializable trigger event was detected, remove it and fail associated trigger - resp = await self.asend(self.sanitize_trigger_events(msg)) - - if resp: - self.to_create.extend(resp.to_create) - self.to_cancel.extend(resp.to_cancel) - - async def asend(self, msg: messages.TriggerStateChanges) -> messages.TriggerStateSync | None: - try: - response = await self.comms_decoder.asend(msg) - - if not isinstance(response, messages.TriggerStateSync): - raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}") - - return response + resp = await self.comms_decoder.asend(msg) except asyncio.IncompleteReadError: if task := asyncio.current_task(): task.cancel("EOF - shutting down") return raise + if not isinstance(resp, messages.TriggerStateSync): + raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}") + self.to_create.extend(resp.to_create) + self.to_cancel.extend(resp.to_cancel) + async def block_watchdog(self): """ Watchdog loop that detects blocking (badly-written) triggers. diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 116b3c115df3a..3181afa540c67 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -371,7 +371,7 @@ async def test_invalid_trigger(self, supervisor_builder): trigger_runner = TriggerRunner() trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder) trigger_runner.comms_decoder.asend.return_value = messages.TriggerStateSync( - to_create=[], to_cancel=set() + to_create=[], to_cancel=[] ) trigger_runner.to_create.append(workload) @@ -438,32 +438,6 @@ async def test_trigger_kwargs_serialization_cleanup(self, session): trigger_instance.cancel() await runner.cleanup_finished_triggers() - @pytest.mark.asyncio - @patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True) - async def test_sync_state_to_supervisor(self, supervisor_builder): - trigger_runner = TriggerRunner() - trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder) - trigger_runner.events.append((1, TriggerEvent(payload={"status": "SUCCESS"}))) - trigger_runner.events.append((2, TriggerEvent(payload={"status": "FAILED"}))) - trigger_runner.events.append((3, TriggerEvent(payload={"status": "SUCCESS", "data": object()}))) - - async def asend_side_effect(msg): - if msg.events and len(msg.events) == 3: - raise NotImplementedError("Simulate non-serializable event") - return messages.TriggerStateSync(to_create=[], to_cancel=set()) - - trigger_runner.comms_decoder.asend.side_effect = asend_side_effect - - await trigger_runner.sync_state_to_supervisor(finished_ids=[]) - - assert trigger_runner.comms_decoder.asend.call_count == 2 - - first_call = trigger_runner.comms_decoder.asend.call_args_list[0].args[0] - second_call = trigger_runner.comms_decoder.asend.call_args_list[1].args[0] - - assert len(first_call.events) == 3 - assert len(second_call.events) == 2 - @pytest.mark.asyncio async def test_trigger_create_race_condition_38599(session, supervisor_builder, testing_dag_bundle):