diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 296bba3c24ab3..d497be48dc251 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -370,6 +370,7 @@ "apache.spark": { "deps": [ "apache-airflow-providers-common-compat>=1.12.0", + "aiohttp>=3.14.0", "apache-airflow>=2.11.0", "grpcio-status>=1.67.0", "pyspark-client>=4.0.0", diff --git a/providers/apache/spark/provider.yaml b/providers/apache/spark/provider.yaml index 2a57351b875c3..4dffe05af42fc 100644 --- a/providers/apache/spark/provider.yaml +++ b/providers/apache/spark/provider.yaml @@ -238,6 +238,11 @@ connection-types: - string - 'null' +triggers: + - integration-name: Apache Spark + python-modules: + - airflow.providers.apache.spark.triggers.spark_submit + task-decorators: - class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task name: pyspark diff --git a/providers/apache/spark/pyproject.toml b/providers/apache/spark/pyproject.toml index 411ed168d5e16..800a272a5d0fd 100644 --- a/providers/apache/spark/pyproject.toml +++ b/providers/apache/spark/pyproject.toml @@ -61,6 +61,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", + "aiohttp>=3.14.0", "pyspark-client>=4.0.0", "grpcio-status>=1.67.0", "requests>=2.32.0", diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index ac9b550409ff0..1cc04e00f4855 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -24,6 +24,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook +from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger from airflow.providers.common.compat.openlineage.utils.spark import ( inject_parent_job_information_into_spark_properties, inject_transport_information_into_spark_properties, @@ -134,6 +135,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator): omitted, Kerberos-enabled Spark connections with both ``keytab`` and ``principal`` configured use ``requests-kerberos`` automatically. Defaults to ``None`` (no auth for non-Kerberos connections). + :param deferrable: Run operator in deferrable mode. """ # Generic key used across all Spark deployment modes (standalone driver ID, @@ -197,6 +199,7 @@ def __init__( track_driver_via_k8s_api: bool = False, yarn_track_via_rm_api: bool = False, yarn_rm_auth: AuthBase | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), openlineage_inject_parent_job_info: bool = conf.getboolean( "openlineage", "spark_inject_parent_job_info", fallback=False ), @@ -244,6 +247,7 @@ def __init__( self.reconnect_on_retry = reconnect_on_retry self._track_driver_via_k8s_api = track_driver_via_k8s_api + self.deferrable = deferrable self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info self._openlineage_inject_transport_info = openlineage_inject_transport_info @@ -261,6 +265,18 @@ def execute(self, context: Context) -> None: hook = self._hook if self._track_driver_via_k8s_api: hook._validate_track_driver_via_k8s_api_config() + if self.deferrable: + driver_id = self.submit_job(context) + master_urls = self._build_master_rest_urls() + self.defer( + trigger=SparkDriverTrigger( + driver_id=driver_id, + master_urls=master_urls, + poll_interval=self._status_poll_interval, + ), + method_name="execute_complete", + ) + return # unreachable after defer(); keeps type checkers happy if hook._should_track_driver_status: if self.reconnect_on_retry: return self.execute_resumable(context) @@ -290,6 +306,41 @@ def execute(self, context: Context) -> None: return self.get_job_result(driver_id, context) hook.submit(self.application) + def execute_complete(self, context: Context, event: dict) -> None: + """Handle the result emitted by SparkDriverTrigger.""" + status = event.get("status") + driver_id = event.get("driver_id") + if status is None: + raise RuntimeError(f"Malformed trigger event: {event}") + if driver_id is None: + raise RuntimeError(f"Malformed trigger event: {event}") + if status != "success": + driver_state = event.get("driver_state", "") + message = event.get("message", "") + raise RuntimeError( + f"Spark driver {driver_id} did not finish successfully " + f"(state={driver_state}): {message}" + ) + self.log.info("Spark driver %s finished successfully", driver_id) + + def _build_master_rest_urls(self) -> list[str]: + """ + Build Spark master REST API base URLs for SparkDriverTrigger. + + Supports HA (comma-separated master URL) and respects rest_scheme / + rest_port connection extras (same logic as get_job_status). + """ + if self._hook is None: + self._hook = self._get_hook() + scheme = self._hook._connection.get("rest_scheme", "http") + rest_port = self._hook._connection.get("rest_port", 6066) + master_hosts = self._hook._connection["master"].replace("spark://", "").split(",") + urls = [] + for host in master_hosts: + hostname = host.strip().split(":")[0] + urls.append(f"{scheme}://{hostname}:{rest_port}") + return urls + def submit_job(self, context: Context) -> str: if self._hook is None: self._hook = self._get_hook() diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/triggers/__init__.py b/providers/apache/spark/src/airflow/providers/apache/spark/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/apache/spark/src/airflow/providers/apache/spark/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py new file mode 100644 index 0000000000000..26263b544a45b --- /dev/null +++ b/providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio + +from collections.abc import AsyncIterator +from typing import Any + +import aiohttp + +from airflow.triggers.base import BaseTrigger, TriggerEvent + +_SPARK_TERMINAL_STATES = {"FINISHED", "FAILED", "KILLED", "ERROR"} + + +class SparkDriverTrigger(BaseTrigger): + """ + Async trigger that polls the Spark standalone REST API until the driver + reaches a terminal state. Used when SparkSubmitOperator runs with deferrable=True. + + :param driver_id: Spark driver submission ID returned by spark-submit --rest. + :param master_urls: List of Spark master REST base URLs e.g. ["http://spark-master:6066"]. + :param poll_interval: Seconds between REST API polls. Defaults to 10. + """ + + def __init__( + self, + driver_id: str, + master_urls: list[str], + poll_interval: int = 10, + ) -> None: + super().__init__() + self.driver_id = driver_id + self.master_urls = master_urls + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.apache.spark.triggers.spark_submit.SparkDriverTrigger", + { + "driver_id": self.driver_id, + "master_urls": self.master_urls, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Poll Spark REST API until driver reaches a terminal state.""" + while True: + status = await self._poll_driver_status() + if status is None: + yield TriggerEvent( + { + "status": "error", + "driver_id": self.driver_id, + "message": "All Spark masters unreachable", + } + ) + return + self.log.info("SparkDriverTrigger: driver=%s status=%s", self.driver_id, status) + normalized_status = status.upper() + if normalized_status in _SPARK_TERMINAL_STATES: + success = normalized_status == "FINISHED" + yield TriggerEvent( + { + "status": "success" if success else "error", + "driver_id": self.driver_id, + "driver_state": normalized_status, + "message": f"Driver {self.driver_id} reached state {normalized_status}", + } + ) + return + await asyncio.sleep(self.poll_interval) + + async def _poll_driver_status(self) -> str | None: + """Try each master URL; return driverState str or None if all fail.""" + for url in self.master_urls: + status_url = f"{url.rstrip('/')}/v1/submissions/status/{self.driver_id}" + try: + async with aiohttp.ClientSession() as session: + async with session.get(status_url, timeout=aiohttp.ClientTimeout(total=30)) as resp: + resp.raise_for_status() + data = await resp.json() + if not data.get("success"): + self.log.warning( + "Spark REST API returned success=false for %s: %s", + self.driver_id, + data.get("message", "unknown"), + ) + continue + return data["driverState"] + except aiohttp.ClientError as exc: + self.log.warning("Could not reach Spark master at %s: %s", url, exc) + continue + return None diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 56b2ad3a409e5..4d3bae4e6c8e7 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -920,3 +920,137 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self): hook.submit.assert_called_once_with("test.jar") hook._poll_k8s_driver_via_api.assert_not_called() + + +class TestSparkSubmitOperatorDeferrable: + """Tests for SparkSubmitOperator deferrable=True mode.""" + + def setup_method(self): + args = {"owner": "airflow", "start_date": DEFAULT_DATE} + self.dag = DAG("test_deferrable_dag", schedule=None, default_args=args) + + def _make_operator(self, **kwargs): + return SparkSubmitOperator( + task_id="test_deferrable", + dag=self.dag, + application="test.jar", + deferrable=True, + **kwargs, + ) + + def _make_hook(self): + hook = MagicMock() + hook._should_track_driver_status = False + hook._should_track_driver_via_k8s_api.return_value = False + hook._connection = {"master": "spark://myhost:7077"} + hook.submit.return_value = "driver-001" + return hook + + def test_deferrable_defaults_to_false(self): + """deferrable must default to False — existing behaviour unchanged.""" + + def test_deferrable_stored_on_operator(self): + """deferrable=True must be stored as self.deferrable.""" + op = self._make_operator() + assert op.deferrable is True + + def test_execute_calls_defer_when_deferrable_true(self): + """execute() must call self.defer() when deferrable=True.""" + op = self._make_operator() + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-001"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://myhost:6066"]), \ + mock.patch.object(op, "defer") as mock_defer: + op.execute(context={}) + + mock_defer.assert_called_once() + call_kwargs = mock_defer.call_args.kwargs + assert call_kwargs["method_name"] == "execute_complete" + + def test_execute_passes_correct_args_to_trigger(self): + """execute() must pass driver_id and master_urls to SparkDriverTrigger.""" + from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger + + op = self._make_operator(status_poll_interval=15) + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-xyz"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://m1:6066"]), \ + mock.patch.object(op, "defer") as mock_defer: + op.execute(context={}) + + trigger = mock_defer.call_args.kwargs["trigger"] + assert isinstance(trigger, SparkDriverTrigger) + assert trigger.driver_id == "driver-xyz" + assert trigger.master_urls == ["http://m1:6066"] + assert trigger.poll_interval == 15 + + def test_execute_does_not_call_hook_submit_directly(self): + """execute() in deferrable mode must use submit_job(), not hook.submit().""" + op = self._make_operator() + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-001"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://myhost:6066"]), \ + mock.patch.object(op, "defer"): + op.execute(context={}) + + hook.submit.assert_not_called() + + def test_execute_complete_succeeds_on_success_event(self): + """execute_complete() must not raise when status=success.""" + op = self._make_operator() + event = { + "status": "success", + "driver_id": "driver-001", + "driver_state": "FINISHED", + "message": "Driver reached FINISHED", + } + op.execute_complete(context={}, event=event) # must not raise + + def test_execute_complete_raises_on_error_event(self): + """execute_complete() must raise AirflowException when status=error.""" + + op = self._make_operator() + event = { + "status": "error", + "driver_id": "driver-001", + "driver_state": "FAILED", + "message": "Driver reached FAILED", + } + with pytest.raises(RuntimeError, match="driver-001"): + op.execute_complete(context={}, event=event) + + def test_build_master_rest_urls_single_master(self): + """_build_master_rest_urls must return correct URL for a single master.""" + op = self._make_operator() + hook = self._make_hook() + hook._connection = { + "master": "spark://myhost:7077", + "rest_scheme": "http", + "rest_port": 6066, + } + op._hook = hook + + urls = op._build_master_rest_urls() + + assert urls == ["http://myhost:6066"] + + def test_build_master_rest_urls_ha_multiple_masters(self): + """_build_master_rest_urls must return a URL per master in HA mode.""" + op = self._make_operator() + hook = self._make_hook() + hook._connection = { + "master": "spark://m1:7077,m2:7077", + "rest_scheme": "https", + "rest_port": 6066, + } + op._hook = hook + + urls = op._build_master_rest_urls() + + assert urls == ["https://m1:6066", "https://m2:6066"] diff --git a/providers/apache/spark/tests/unit/apache/spark/triggers/__init__.py b/providers/apache/spark/tests/unit/apache/spark/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/apache/spark/tests/unit/apache/spark/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/apache/spark/tests/unit/apache/spark/triggers/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/triggers/test_spark_submit.py new file mode 100644 index 0000000000000..d22dd0c150bc7 --- /dev/null +++ b/providers/apache/spark/tests/unit/apache/spark/triggers/test_spark_submit.py @@ -0,0 +1,229 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger +from airflow.triggers.base import TriggerEvent + + +class TestSparkDriverTrigger: + """Tests for SparkDriverTrigger.""" + + def _make_trigger(self, driver_id="driver-001", master_urls=None, poll_interval=1): + return SparkDriverTrigger( + driver_id=driver_id, + master_urls=master_urls or ["http://spark-master:6066"], + poll_interval=poll_interval, + ) + + # ── serialize ────────────────────────────────────────────────────── + + def test_serialize_roundtrip(self): + """serialize() must return the correct classpath and all constructor args.""" + trigger = self._make_trigger( + driver_id="driver-abc", + master_urls=["http://m1:6066", "http://m2:6066"], + poll_interval=5, + ) + classpath, kwargs = trigger.serialize() + + assert classpath == "airflow.providers.apache.spark.triggers.spark_submit.SparkDriverTrigger" + assert kwargs == { + "driver_id": "driver-abc", + "master_urls": ["http://m1:6066", "http://m2:6066"], + "poll_interval": 5, + } + + def test_serialize_can_reconstruct_trigger(self): + """A trigger reconstructed from serialize() is identical to the original.""" + original = self._make_trigger(driver_id="driver-xyz", poll_interval=30) + _, kwargs = original.serialize() + reconstructed = SparkDriverTrigger(**kwargs) + + assert reconstructed.driver_id == original.driver_id + assert reconstructed.master_urls == original.master_urls + assert reconstructed.poll_interval == original.poll_interval + + # ── run — terminal states ────────────────────────────────────────── + + @pytest.mark.asyncio + @pytest.mark.parametrize("state", ["FINISHED", "FAILED", "KILLED", "ERROR"]) + async def test_run_yields_event_on_terminal_state(self, state): + """run() must yield a TriggerEvent when driver reaches a terminal state.""" + trigger = self._make_trigger() + trigger._poll_driver_status = mock.AsyncMock(return_value=state) + + events = [] + async for event in trigger.run(): + events.append(event) + + assert len(events) == 1 + assert isinstance(events[0], TriggerEvent) + assert events[0].payload["driver_id"] == "driver-001" + assert events[0].payload["driver_state"] == state.upper() + + @pytest.mark.asyncio + async def test_run_success_on_finished(self): + """FINISHED state must produce status=success.""" + trigger = self._make_trigger() + trigger._poll_driver_status = mock.AsyncMock(return_value="FINISHED") + + events = [] + async for event in trigger.run(): + events.append(event) + + assert events[0].payload["status"] == "success" + + @pytest.mark.asyncio + @pytest.mark.parametrize("state", ["FAILED", "KILLED", "ERROR"]) + async def test_run_error_on_non_finished_terminal(self, state): + """Non-FINISHED terminal states must produce status=error.""" + trigger = self._make_trigger() + trigger._poll_driver_status = mock.AsyncMock(return_value=state) + + events = [] + async for event in trigger.run(): + events.append(event) + + assert events[0].payload["status"] == "error" + + # ── run — active states loop ─────────────────────────────────────── + + @pytest.mark.asyncio + async def test_run_polls_until_terminal(self): + """run() must keep polling while driver is in an active state.""" + trigger = self._make_trigger(poll_interval=0) + trigger._poll_driver_status = mock.AsyncMock( + side_effect=["RUNNING", "RUNNING", "FINISHED"] + ) + + events = [] + async for event in trigger.run(): + events.append(event) + + assert trigger._poll_driver_status.call_count == 3 + assert events[0].payload["status"] == "success" + + # ── run — all masters unreachable ────────────────────────────────── + + @pytest.mark.asyncio + async def test_run_yields_error_when_all_masters_unreachable(self): + """run() must yield an error event when _poll_driver_status returns None.""" + trigger = self._make_trigger() + trigger._poll_driver_status = mock.AsyncMock(return_value=None) + + events = [] + async for event in trigger.run(): + events.append(event) + + assert len(events) == 1 + assert events[0].payload["status"] == "error" + assert "unreachable" in events[0].payload["message"].lower() + + # ── _poll_driver_status ──────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_poll_returns_driver_state(self): + """_poll_driver_status must return driverState from a successful REST response.""" + trigger = self._make_trigger() + mock_resp = mock.AsyncMock() + mock_resp.raise_for_status = mock.MagicMock() + mock_resp.json = mock.AsyncMock(return_value={"success": True, "driverState": "RUNNING"}) + mock_resp.__aenter__ = mock.AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = mock.AsyncMock(return_value=False) + + mock_session = mock.AsyncMock() + mock_session.get.return_value = mock_resp + mock_session.__aenter__ = mock.AsyncMock(return_value=mock_session) + mock_session.__aexit__ = mock.AsyncMock(return_value=False) + + with mock.patch("aiohttp.ClientSession", return_value=mock_session): + result = await trigger._poll_driver_status() + + assert result == "RUNNING" + + @pytest.mark.asyncio + async def test_poll_skips_master_on_success_false(self): + """When success=false from one master, must try the next master URL.""" + trigger = self._make_trigger(master_urls=["http://m1:6066", "http://m2:6066"]) + bad_resp = mock.AsyncMock() + bad_resp.raise_for_status = mock.MagicMock() + bad_resp.json = mock.AsyncMock(return_value={"success": False, "message": "not found"}) + bad_resp.__aenter__ = mock.AsyncMock(return_value=bad_resp) + bad_resp.__aexit__ = mock.AsyncMock(return_value=False) + good_resp = mock.AsyncMock() + good_resp.raise_for_status = mock.MagicMock() + good_resp.json = mock.AsyncMock(return_value={"success": True, "driverState": "RUNNING"}) + good_resp.__aenter__ = mock.AsyncMock(return_value=good_resp) + good_resp.__aexit__ = mock.AsyncMock(return_value=False) + mock_session = mock.AsyncMock() + mock_session.get.side_effect = [bad_resp, good_resp] + mock_session.__aenter__ = mock.AsyncMock(return_value=mock_session) + mock_session.__aexit__ = mock.AsyncMock(return_value=False) + with mock.patch("aiohttp.ClientSession", return_value=mock_session): + result = await trigger._poll_driver_status() + assert result == "RUNNING" + + @pytest.mark.asyncio + async def test_poll_tries_next_master_on_exception(self): + """_poll_driver_status must try the next master URL when one raises.""" + trigger = self._make_trigger(master_urls=["http://m1:6066", "http://m2:6066"]) + + call_count = 0 + + async def fake_get(url, **kwargs): + nonlocal call_count + call_count += 1 + if "m1" in url: + raise ConnectionError("m1 down") + mock_resp = mock.AsyncMock() + mock_resp.raise_for_status = mock.MagicMock() + mock_resp.json = mock.AsyncMock(return_value={"success": True, "driverState": "RUNNING"}) + mock_resp.__aenter__ = mock.AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = mock.AsyncMock(return_value=False) + return mock_resp + + mock_session = mock.AsyncMock() + mock_session.get.side_effect = fake_get + mock_session.__aenter__ = mock.AsyncMock(return_value=mock_session) + mock_session.__aexit__ = mock.AsyncMock(return_value=False) + + with mock.patch("aiohttp.ClientSession", return_value=mock_session): + result = await trigger._poll_driver_status() + + assert result == "RUNNING" + assert call_count == 2 + + @pytest.mark.asyncio + async def test_poll_returns_none_when_all_masters_fail(self): + """_poll_driver_status must return None when all master URLs raise.""" + trigger = self._make_trigger(master_urls=["http://m1:6066", "http://m2:6066"]) + + mock_session = mock.AsyncMock() + mock_session.get.side_effect = ConnectionError("all down") + mock_session.__aenter__ = mock.AsyncMock(return_value=mock_session) + mock_session.__aexit__ = mock.AsyncMock(return_value=False) + + with mock.patch("aiohttp.ClientSession", return_value=mock_session): + result = await trigger._poll_driver_status() + + assert result is None