From 40fe70d1207ba7d8e6e2ab9eb3242f361b8bae07 Mon Sep 17 00:00:00 2001 From: bujjibabukatta Date: Tue, 9 Jun 2026 13:25:54 +0530 Subject: [PATCH 1/4] Add deferrable=True support to SparkSubmitOperator (issue #67168) --- providers/apache/spark/provider.yaml | 5 + providers/apache/spark/pyproject.toml | 1 + .../apache/spark/operators/spark_submit.py | 51 ++++ .../apache/spark/triggers/__init__.py | 16 ++ .../apache/spark/triggers/spark_submit.py | 110 +++++++++ .../spark/operators/test_spark_submit.py | 153 ++++++++++++ .../unit/apache/spark/triggers/__init__.py | 16 ++ .../spark/triggers/test_spark_submit.py | 227 ++++++++++++++++++ 8 files changed, 579 insertions(+) create mode 100644 providers/apache/spark/src/airflow/providers/apache/spark/triggers/__init__.py create mode 100644 providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py create mode 100644 providers/apache/spark/tests/unit/apache/spark/triggers/__init__.py create mode 100644 providers/apache/spark/tests/unit/apache/spark/triggers/test_spark_submit.py diff --git a/providers/apache/spark/provider.yaml b/providers/apache/spark/provider.yaml index 2a57351b875c3..73b4da667e809 100644 --- a/providers/apache/spark/provider.yaml +++ b/providers/apache/spark/provider.yaml @@ -241,3 +241,8 @@ connection-types: task-decorators: - class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task name: pyspark + +triggers: + - integration-name: Apache Spark + python-modules: + - airflow.providers.apache.spark.triggers.spark_submit diff --git a/providers/apache/spark/pyproject.toml b/providers/apache/spark/pyproject.toml index 411ed168d5e16..a916eda2faffd 100644 --- a/providers/apache/spark/pyproject.toml +++ b/providers/apache/spark/pyproject.toml @@ -61,6 +61,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", + "aiohttp>=3.9.0", "pyspark-client>=4.0.0", "grpcio-status>=1.67.0", "requests>=2.32.0", diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 3ac4870f313fd..0ea0e09105980 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -134,6 +134,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator): omitted, Kerberos-enabled Spark connections with both ``keytab`` and ``principal`` configured use ``requests-kerberos`` automatically. Defaults to ``None`` (no auth for non-Kerberos connections). + :param deferrable: If ``True``, submits the job then defers to + ``SparkDriverTrigger``; the worker slot is freed while the trigger + polls the Spark REST API. On crash the trigger is re-created from + its serialised state (no reconnect needed). On user-clear, execute() + runs again and a fresh job is submitted. + If ``False`` (default), the sync ``ResumableJobMixin`` path is used. """ # Generic key used across all Spark deployment modes (standalone driver ID, @@ -197,6 +203,7 @@ def __init__( track_driver_via_k8s_api: bool = False, yarn_track_via_rm_api: bool = False, yarn_rm_auth: AuthBase | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), openlineage_inject_parent_job_info: bool = conf.getboolean( "openlineage", "spark_inject_parent_job_info", fallback=False ), @@ -244,6 +251,7 @@ def __init__( self.reconnect_on_retry = reconnect_on_retry self._track_driver_via_k8s_api = track_driver_via_k8s_api + self.deferrable = deferrable self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info self._openlineage_inject_transport_info = openlineage_inject_transport_info @@ -261,6 +269,19 @@ def execute(self, context: Context) -> None: hook = self._hook if self._track_driver_via_k8s_api: hook._validate_track_driver_via_k8s_api_config() + if self.deferrable: + driver_id = self.submit_job(context) + master_urls = self._build_master_rest_urls() + from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger + self.defer( + trigger=SparkDriverTrigger( + driver_id=driver_id, + master_urls=master_urls, + poll_interval=self._status_poll_interval, + ), + method_name="execute_complete", + ) + return # unreachable after defer(); keeps type checkers happy if hook._should_track_driver_status: if self.reconnect_on_retry: return self.execute_resumable(context) @@ -275,6 +296,36 @@ def execute(self, context: Context) -> None: hook._poll_k8s_driver_via_api() return hook.submit(self.application) + + def execute_complete(self, context: Context, event: dict) -> None: + """ + Handle the result emitted by SparkDriverTrigger. + Called by Airflow when the trigger fires after deferrable=True execution. + Raises AirflowException if the driver did not finish successfully. + """ + from airflow.providers.common.compat.sdk import AirflowException + driver_state = event.get("driver_state", "UNKNOWN") + driver_id = event.get("driver_id", "unknown") + message = event.get("message", "") + if event.get("status") != "success": + raise AirflowException( + f"Spark driver {driver_id} did not finish successfully " + f"(state={driver_state}): {message}" + ) + self.log.info("Spark driver %s finished successfully (state=%s)", driver_id, driver_state) + def _build_master_rest_urls(self) -> list[str]: + """ + Build Spark master REST API base URLs for SparkDriverTrigger. + Supports HA (comma-separated master URL) and respects rest_scheme / + rest_port connection extras (same logic as get_job_status). + """ + if self._hook is None: + self._hook = self._get_hook() + scheme = self._hook._connection.get("rest_scheme", "http") + rest_port = self._hook._connection.get("rest_port", 6066) + master_hosts = self._hook._connection["master"].replace("spark://", "").split(",") + return [f"{scheme}://{m.strip().split(':')[0]}:{rest_port}" for m in master_hosts] + def submit_job(self, context: Context) -> str: if self._hook is None: diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/triggers/__init__.py b/providers/apache/spark/src/airflow/providers/apache/spark/triggers/__init__.py new file mode 100644 index 0000000000000..d216be4ddc949 --- /dev/null +++ b/providers/apache/spark/src/airflow/providers/apache/spark/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. \ No newline at end of file diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py new file mode 100644 index 0000000000000..569e949e7b764 --- /dev/null +++ b/providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio + +from collections.abc import AsyncIterator +from typing import Any + +import aiohttp + +from airflow.triggers.base import BaseTrigger, TriggerEvent + +_SPARK_TERMINAL_STATES = frozenset({"FINISHED", "FAILED", "KILLED", "ERROR"}) +_SPARK_ACTIVE_STATES = frozenset({"SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN"}) + + +class SparkDriverTrigger(BaseTrigger): + """ + Async trigger that polls the Spark standalone REST API until the driver + reaches a terminal state. Used when SparkSubmitOperator runs with deferrable=True. + :param driver_id: Spark driver submission ID returned by spark-submit --rest. + :param master_urls: List of Spark master REST base URLs e.g. ["http://spark-master:6066"]. + :param poll_interval: Seconds between REST API polls. Defaults to 10. + """ + + def __init__( + self, + driver_id: str, + master_urls: list[str], + poll_interval: int = 10, + ) -> None: + super().__init__() + self.driver_id = driver_id + self.master_urls = master_urls + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.apache.spark.triggers.spark_submit.SparkDriverTrigger", + { + "driver_id": self.driver_id, + "master_urls": self.master_urls, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Poll Spark REST API until driver reaches a terminal state.""" + while True: + status = await self._poll_driver_status() + if status is None: + yield TriggerEvent( + { + "status": "error", + "driver_id": self.driver_id, + "message": "All Spark masters unreachable", + } + ) + return + self.log.info("SparkDriverTrigger: driver=%s status=%s", self.driver_id, status) + upper = status.upper() + if upper in _SPARK_TERMINAL_STATES: + success = upper == "FINISHED" + yield TriggerEvent( + { + "status": "success" if success else "error", + "driver_id": self.driver_id, + "driver_state": upper, + "message": f"Driver {self.driver_id} reached state {upper}", + } + ) + return + await asyncio.sleep(self.poll_interval) + + async def _poll_driver_status(self) -> str | None: + """Try each master URL; return driverState str or None if all fail.""" + for url in self.master_urls: + status_url = f"{url.rstrip('/')}/v1/submissions/status/{self.driver_id}" + try: + async with aiohttp.ClientSession() as session: + async with session.get(status_url, timeout=aiohttp.ClientTimeout(total=30)) as resp: + resp.raise_for_status() + data = await resp.json() + if not data.get("success"): + self.log.warning( + "Spark REST API returned success=false for %s: %s", + self.driver_id, + data.get("message", "unknown"), + ) + return "UNKNOWN" + return data["driverState"] + except Exception as exc: + self.log.warning("Could not reach Spark master at %s: %s", url, exc) + continue + return None \ No newline at end of file diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 95ad9f5142a57..5fd533168e3cc 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -773,3 +773,156 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self): hook.submit.assert_called_once_with("test.jar") hook._poll_k8s_driver_via_api.assert_not_called() + + +class TestSparkSubmitOperatorDeferrable: + """Tests for SparkSubmitOperator deferrable=True mode.""" + + def setup_method(self): + args = {"owner": "airflow", "start_date": DEFAULT_DATE} + self.dag = DAG("test_deferrable_dag", schedule=None, default_args=args) + + def _make_operator(self, **kwargs): + return SparkSubmitOperator( + task_id="test_deferrable", + dag=self.dag, + application="test.jar", + deferrable=True, + **kwargs, + ) + + def _make_hook(self): + hook = MagicMock() + hook._should_track_driver_status = False + hook._should_track_driver_via_k8s_api.return_value = False + hook._connection = {"master": "spark://myhost:7077"} + hook.submit.return_value = "driver-001" + return hook + + def test_deferrable_defaults_to_false(self): + """deferrable must default to False — existing behaviour unchanged.""" + op = SparkSubmitOperator(task_id="t", dag=self.dag, application="app.jar") + assert op.deferrable is False + + def test_deferrable_stored_on_operator(self): + """deferrable=True must be stored as self.deferrable.""" + op = self._make_operator() + assert op.deferrable is True + + def test_execute_calls_defer_when_deferrable_true(self): + """execute() must call self.defer() when deferrable=True.""" + op = self._make_operator() + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-001"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://myhost:6066"]), \ + mock.patch.object(op, "defer") as mock_defer: + op.execute(context={}) + + mock_defer.assert_called_once() + call_kwargs = mock_defer.call_args.kwargs + assert call_kwargs["method_name"] == "execute_complete" + + def test_execute_passes_correct_args_to_trigger(self): + """execute() must pass driver_id and master_urls to SparkDriverTrigger.""" + from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger + + op = self._make_operator(status_poll_interval=15) + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-xyz"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://m1:6066"]), \ + mock.patch.object(op, "defer") as mock_defer: + op.execute(context={}) + + trigger = mock_defer.call_args.kwargs["trigger"] + assert isinstance(trigger, SparkDriverTrigger) + assert trigger.driver_id == "driver-xyz" + assert trigger.master_urls == ["http://m1:6066"] + assert trigger.poll_interval == 15 + + def test_execute_does_not_call_hook_submit_directly(self): + """execute() in deferrable mode must use submit_job(), not hook.submit().""" + op = self._make_operator() + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-001"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://myhost:6066"]), \ + mock.patch.object(op, "defer"): + op.execute(context={}) + + hook.submit.assert_not_called() + + def test_execute_complete_succeeds_on_success_event(self): + """execute_complete() must not raise when status=success.""" + op = self._make_operator() + event = { + "status": "success", + "driver_id": "driver-001", + "driver_state": "FINISHED", + "message": "Driver reached FINISHED", + } + op.execute_complete(context={}, event=event) # must not raise + + def test_execute_complete_raises_on_error_event(self): + """execute_complete() must raise AirflowException when status=error.""" + from airflow.providers.common.compat.sdk import AirflowException + + op = self._make_operator() + event = { + "status": "error", + "driver_id": "driver-001", + "driver_state": "FAILED", + "message": "Driver reached FAILED", + } + with pytest.raises(AirflowException, match="driver-001"): + op.execute_complete(context={}, event=event) + + def test_build_master_rest_urls_single_master(self): + """_build_master_rest_urls must return correct URL for a single master.""" + op = self._make_operator() + hook = self._make_hook() + hook._connection = { + "master": "spark://myhost:7077", + "rest_scheme": "http", + "rest_port": 6066, + } + op._hook = hook + + urls = op._build_master_rest_urls() + + assert urls == ["http://myhost:6066"] + + def test_build_master_rest_urls_ha_multiple_masters(self): + """_build_master_rest_urls must return a URL per master in HA mode.""" + op = self._make_operator() + hook = self._make_hook() + hook._connection = { + "master": "spark://m1:7077,m2:7077", + "rest_scheme": "https", + "rest_port": 6066, + } + op._hook = hook + + urls = op._build_master_rest_urls() + + assert urls == ["https://m1:6066", "https://m2:6066"] + + def test_deferrable_false_uses_sync_path(self): + """deferrable=False must fall through to the existing sync path (no defer call).""" + op = SparkSubmitOperator( + task_id="sync", dag=self.dag, application="app.jar", deferrable=False + ) + hook = MagicMock() + hook._should_track_driver_status = False + hook._should_track_driver_via_k8s_api.return_value = False + op._hook = hook + + with mock.patch.object(op, "defer") as mock_defer: + op.execute(context={}) + + mock_defer.assert_not_called() + hook.submit.assert_called_once_with("app.jar") diff --git a/providers/apache/spark/tests/unit/apache/spark/triggers/__init__.py b/providers/apache/spark/tests/unit/apache/spark/triggers/__init__.py new file mode 100644 index 0000000000000..d216be4ddc949 --- /dev/null +++ b/providers/apache/spark/tests/unit/apache/spark/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. \ No newline at end of file diff --git a/providers/apache/spark/tests/unit/apache/spark/triggers/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/triggers/test_spark_submit.py new file mode 100644 index 0000000000000..58d8a25da7b8e --- /dev/null +++ b/providers/apache/spark/tests/unit/apache/spark/triggers/test_spark_submit.py @@ -0,0 +1,227 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger +from airflow.triggers.base import TriggerEvent + + +class TestSparkDriverTrigger: + """Tests for SparkDriverTrigger.""" + + def _make_trigger(self, driver_id="driver-001", master_urls=None, poll_interval=1): + return SparkDriverTrigger( + driver_id=driver_id, + master_urls=master_urls or ["http://spark-master:6066"], + poll_interval=poll_interval, + ) + + # ── serialize ────────────────────────────────────────────────────── + + def test_serialize_roundtrip(self): + """serialize() must return the correct classpath and all constructor args.""" + trigger = self._make_trigger( + driver_id="driver-abc", + master_urls=["http://m1:6066", "http://m2:6066"], + poll_interval=5, + ) + classpath, kwargs = trigger.serialize() + + assert classpath == "airflow.providers.apache.spark.triggers.spark_submit.SparkDriverTrigger" + assert kwargs == { + "driver_id": "driver-abc", + "master_urls": ["http://m1:6066", "http://m2:6066"], + "poll_interval": 5, + } + + def test_serialize_can_reconstruct_trigger(self): + """A trigger reconstructed from serialize() is identical to the original.""" + original = self._make_trigger(driver_id="driver-xyz", poll_interval=30) + _, kwargs = original.serialize() + reconstructed = SparkDriverTrigger(**kwargs) + + assert reconstructed.driver_id == original.driver_id + assert reconstructed.master_urls == original.master_urls + assert reconstructed.poll_interval == original.poll_interval + + # ── run — terminal states ────────────────────────────────────────── + + @pytest.mark.asyncio + @pytest.mark.parametrize("state", ["FINISHED", "FAILED", "KILLED", "ERROR"]) + async def test_run_yields_event_on_terminal_state(self, state): + """run() must yield a TriggerEvent when driver reaches a terminal state.""" + trigger = self._make_trigger() + trigger._poll_driver_status = mock.AsyncMock(return_value=state) + + events = [] + async for event in trigger.run(): + events.append(event) + + assert len(events) == 1 + assert isinstance(events[0], TriggerEvent) + assert events[0].payload["driver_id"] == "driver-001" + assert events[0].payload["driver_state"] == state.upper() + + @pytest.mark.asyncio + async def test_run_success_on_finished(self): + """FINISHED state must produce status=success.""" + trigger = self._make_trigger() + trigger._poll_driver_status = mock.AsyncMock(return_value="FINISHED") + + events = [] + async for event in trigger.run(): + events.append(event) + + assert events[0].payload["status"] == "success" + + @pytest.mark.asyncio + @pytest.mark.parametrize("state", ["FAILED", "KILLED", "ERROR"]) + async def test_run_error_on_non_finished_terminal(self, state): + """Non-FINISHED terminal states must produce status=error.""" + trigger = self._make_trigger() + trigger._poll_driver_status = mock.AsyncMock(return_value=state) + + events = [] + async for event in trigger.run(): + events.append(event) + + assert events[0].payload["status"] == "error" + + # ── run — active states loop ─────────────────────────────────────── + + @pytest.mark.asyncio + async def test_run_polls_until_terminal(self): + """run() must keep polling while driver is in an active state.""" + trigger = self._make_trigger(poll_interval=0) + trigger._poll_driver_status = mock.AsyncMock( + side_effect=["RUNNING", "RUNNING", "FINISHED"] + ) + + events = [] + async for event in trigger.run(): + events.append(event) + + assert trigger._poll_driver_status.call_count == 3 + assert events[0].payload["status"] == "success" + + # ── run — all masters unreachable ────────────────────────────────── + + @pytest.mark.asyncio + async def test_run_yields_error_when_all_masters_unreachable(self): + """run() must yield an error event when _poll_driver_status returns None.""" + trigger = self._make_trigger() + trigger._poll_driver_status = mock.AsyncMock(return_value=None) + + events = [] + async for event in trigger.run(): + events.append(event) + + assert len(events) == 1 + assert events[0].payload["status"] == "error" + assert "unreachable" in events[0].payload["message"].lower() + + # ── _poll_driver_status ──────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_poll_returns_driver_state(self): + """_poll_driver_status must return driverState from a successful REST response.""" + trigger = self._make_trigger() + mock_resp = mock.AsyncMock() + mock_resp.raise_for_status = mock.MagicMock() + mock_resp.json = mock.AsyncMock(return_value={"success": True, "driverState": "RUNNING"}) + mock_resp.__aenter__ = mock.AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = mock.AsyncMock(return_value=False) + + mock_session = mock.AsyncMock() + mock_session.get.return_value = mock_resp + mock_session.__aenter__ = mock.AsyncMock(return_value=mock_session) + mock_session.__aexit__ = mock.AsyncMock(return_value=False) + + with mock.patch("aiohttp.ClientSession", return_value=mock_session): + result = await trigger._poll_driver_status() + + assert result == "RUNNING" + + @pytest.mark.asyncio + async def test_poll_returns_unknown_on_success_false(self): + """When success=false, _poll_driver_status must return UNKNOWN (HA failover).""" + trigger = self._make_trigger() + mock_resp = mock.AsyncMock() + mock_resp.raise_for_status = mock.MagicMock() + mock_resp.json = mock.AsyncMock(return_value={"success": False, "message": "not found"}) + mock_resp.__aenter__ = mock.AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = mock.AsyncMock(return_value=False) + + mock_session = mock.AsyncMock() + mock_session.get.return_value = mock_resp + mock_session.__aenter__ = mock.AsyncMock(return_value=mock_session) + mock_session.__aexit__ = mock.AsyncMock(return_value=False) + + with mock.patch("aiohttp.ClientSession", return_value=mock_session): + result = await trigger._poll_driver_status() + + assert result == "UNKNOWN" + + @pytest.mark.asyncio + async def test_poll_tries_next_master_on_exception(self): + """_poll_driver_status must try the next master URL when one raises.""" + trigger = self._make_trigger(master_urls=["http://m1:6066", "http://m2:6066"]) + + call_count = 0 + + async def fake_get(url, **kwargs): + nonlocal call_count + call_count += 1 + if "m1" in url: + raise ConnectionError("m1 down") + mock_resp = mock.AsyncMock() + mock_resp.raise_for_status = mock.MagicMock() + mock_resp.json = mock.AsyncMock(return_value={"success": True, "driverState": "RUNNING"}) + mock_resp.__aenter__ = mock.AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = mock.AsyncMock(return_value=False) + return mock_resp + + mock_session = mock.AsyncMock() + mock_session.get.side_effect = fake_get + mock_session.__aenter__ = mock.AsyncMock(return_value=mock_session) + mock_session.__aexit__ = mock.AsyncMock(return_value=False) + + with mock.patch("aiohttp.ClientSession", return_value=mock_session): + result = await trigger._poll_driver_status() + + assert result == "RUNNING" + assert call_count == 2 + + @pytest.mark.asyncio + async def test_poll_returns_none_when_all_masters_fail(self): + """_poll_driver_status must return None when all master URLs raise.""" + trigger = self._make_trigger(master_urls=["http://m1:6066", "http://m2:6066"]) + + mock_session = mock.AsyncMock() + mock_session.get.side_effect = ConnectionError("all down") + mock_session.__aenter__ = mock.AsyncMock(return_value=mock_session) + mock_session.__aexit__ = mock.AsyncMock(return_value=False) + + with mock.patch("aiohttp.ClientSession", return_value=mock_session): + result = await trigger._poll_driver_status() + + assert result is None \ No newline at end of file From f4def2c1994e7281879afb9b5651c21b9668dc27 Mon Sep 17 00:00:00 2001 From: bujjibabukatta Date: Wed, 10 Jun 2026 23:48:59 +0530 Subject: [PATCH 2/4] fix(spark): address review feedback on deferrable SparkSubmitOperator --- .../apache/spark/operators/spark_submit.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 892ed560f13ec..4aec7c9a3446e 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -29,6 +29,7 @@ inject_transport_information_into_spark_properties, ) from airflow.providers.common.compat.sdk import BaseOperator, conf +from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger try: from airflow.sdk import ResumableJobMixin @@ -134,12 +135,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator): omitted, Kerberos-enabled Spark connections with both ``keytab`` and ``principal`` configured use ``requests-kerberos`` automatically. Defaults to ``None`` (no auth for non-Kerberos connections). - :param deferrable: If ``True``, submits the job then defers to - ``SparkDriverTrigger``; the worker slot is freed while the trigger - polls the Spark REST API. On crash the trigger is re-created from - its serialised state (no reconnect needed). On user-clear, execute() - runs again and a fresh job is submitted. - If ``False`` (default), the sync ``ResumableJobMixin`` path is used. + :param deferrable: Run operator in deferrable mode. """ # Generic key used across all Spark deployment modes (standalone driver ID, @@ -272,7 +268,6 @@ def execute(self, context: Context) -> None: if self.deferrable: driver_id = self.submit_job(context) master_urls = self._build_master_rest_urls() - from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger self.defer( trigger=SparkDriverTrigger( driver_id=driver_id, @@ -312,21 +307,22 @@ def execute(self, context: Context) -> None: hook.submit(self.application) def execute_complete(self, context: Context, event: dict) -> None: - """ - Handle the result emitted by SparkDriverTrigger. - Called by Airflow when the trigger fires after deferrable=True execution. - Raises AirflowException if the driver did not finish successfully. - """ - from airflow.providers.common.compat.sdk import AirflowException - driver_state = event.get("driver_state", "UNKNOWN") - driver_id = event.get("driver_id", "unknown") - message = event.get("message", "") - if event.get("status") != "success": - raise AirflowException( + """Handle the result emitted by SparkDriverTrigger.""" + status = event.get("status") + driver_id = event.get("driver_id") + if status is None: + raise RuntimeError(f"Malformed trigger event: {event}") + if driver_id is None: + raise RuntimeError(f"Malformed trigger event: {event}") + if status != "success": + driver_state = event.get("driver_state", "") + message = event.get("message", "") + raise RuntimeError( f"Spark driver {driver_id} did not finish successfully " f"(state={driver_state}): {message}" ) - self.log.info("Spark driver %s finished successfully (state=%s)", driver_id, driver_state) + self.log.info("Spark driver %s finished successfully", driver_id) + def _build_master_rest_urls(self) -> list[str]: """ Build Spark master REST API base URLs for SparkDriverTrigger. @@ -338,7 +334,11 @@ def _build_master_rest_urls(self) -> list[str]: scheme = self._hook._connection.get("rest_scheme", "http") rest_port = self._hook._connection.get("rest_port", 6066) master_hosts = self._hook._connection["master"].replace("spark://", "").split(",") - return [f"{scheme}://{m.strip().split(':')[0]}:{rest_port}" for m in master_hosts] + urls = [] + for host in master_hosts: + hostname = host.strip().split(":")[0] + urls.append(f"{scheme}://{hostname}:{rest_port}") + return urls def submit_job(self, context: Context) -> str: From f3d63352e8f2d5d9a532d68ca219e2da05a3b4b6 Mon Sep 17 00:00:00 2001 From: bujjibabukatta Date: Thu, 11 Jun 2026 01:26:50 +0530 Subject: [PATCH 3/4] fix(spark): address review feedback on deferrable SparkSubmitOperator --- .../airflow/providers/apache/spark/operators/spark_submit.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 4aec7c9a3446e..8ea84e3e0f527 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -24,15 +24,15 @@ from tenacity import retry, stop_after_attempt, wait_fixed from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook +from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger from airflow.providers.common.compat.openlineage.utils.spark import ( inject_parent_job_information_into_spark_properties, inject_transport_information_into_spark_properties, ) from airflow.providers.common.compat.sdk import BaseOperator, conf -from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger try: - from airflow.sdk import ResumableJobMixin + from airflow.sdk.bases.resumablemixin import ResumableJobMixin except ImportError: # Airflow 2 compat. # ResumableJobMixin does not exist in Airflow 2, so we need to add a stub to make it @@ -340,7 +340,6 @@ def _build_master_rest_urls(self) -> list[str]: urls.append(f"{scheme}://{hostname}:{rest_port}") return urls - def submit_job(self, context: Context) -> str: if self._hook is None: self._hook = self._get_hook() From febcaad71d81f4e7e1b83617fb9e98b9d9bab6cc Mon Sep 17 00:00:00 2001 From: bujjibabukatta Date: Fri, 12 Jun 2026 10:49:49 +0530 Subject: [PATCH 4/4] fix(spark): resolve all static check failures and address reviewer feedback --- generated/provider_dependencies.json | 1 + providers/apache/spark/provider.yaml | 8 ++--- providers/apache/spark/pyproject.toml | 2 +- .../apache/spark/operators/spark_submit.py | 7 +++-- .../apache/spark/triggers/__init__.py | 2 +- .../apache/spark/triggers/spark_submit.py | 26 ++++++++-------- .../spark/operators/test_spark_submit.py | 25 ++-------------- .../unit/apache/spark/triggers/__init__.py | 2 +- .../spark/triggers/test_spark_submit.py | 30 ++++++++++--------- 9 files changed, 44 insertions(+), 59 deletions(-) diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 296bba3c24ab3..d497be48dc251 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -370,6 +370,7 @@ "apache.spark": { "deps": [ "apache-airflow-providers-common-compat>=1.12.0", + "aiohttp>=3.14.0", "apache-airflow>=2.11.0", "grpcio-status>=1.67.0", "pyspark-client>=4.0.0", diff --git a/providers/apache/spark/provider.yaml b/providers/apache/spark/provider.yaml index 73b4da667e809..4dffe05af42fc 100644 --- a/providers/apache/spark/provider.yaml +++ b/providers/apache/spark/provider.yaml @@ -238,11 +238,11 @@ connection-types: - string - 'null' -task-decorators: - - class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task - name: pyspark - triggers: - integration-name: Apache Spark python-modules: - airflow.providers.apache.spark.triggers.spark_submit + +task-decorators: + - class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task + name: pyspark diff --git a/providers/apache/spark/pyproject.toml b/providers/apache/spark/pyproject.toml index a916eda2faffd..800a272a5d0fd 100644 --- a/providers/apache/spark/pyproject.toml +++ b/providers/apache/spark/pyproject.toml @@ -61,7 +61,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - "aiohttp>=3.9.0", + "aiohttp>=3.14.0", "pyspark-client>=4.0.0", "grpcio-status>=1.67.0", "requests>=2.32.0", diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 8ea84e3e0f527..1cc04e00f4855 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -32,7 +32,7 @@ from airflow.providers.common.compat.sdk import BaseOperator, conf try: - from airflow.sdk.bases.resumablemixin import ResumableJobMixin + from airflow.sdk import ResumableJobMixin except ImportError: # Airflow 2 compat. # ResumableJobMixin does not exist in Airflow 2, so we need to add a stub to make it @@ -305,7 +305,7 @@ def execute(self, context: Context) -> None: self.poll_until_complete(driver_id, context) return self.get_job_result(driver_id, context) hook.submit(self.application) - + def execute_complete(self, context: Context, event: dict) -> None: """Handle the result emitted by SparkDriverTrigger.""" status = event.get("status") @@ -322,10 +322,11 @@ def execute_complete(self, context: Context, event: dict) -> None: f"(state={driver_state}): {message}" ) self.log.info("Spark driver %s finished successfully", driver_id) - + def _build_master_rest_urls(self) -> list[str]: """ Build Spark master REST API base URLs for SparkDriverTrigger. + Supports HA (comma-separated master URL) and respects rest_scheme / rest_port connection extras (same logic as get_job_status). """ diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/triggers/__init__.py b/providers/apache/spark/src/airflow/providers/apache/spark/triggers/__init__.py index d216be4ddc949..13a83393a9124 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/triggers/__init__.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/triggers/__init__.py @@ -13,4 +13,4 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations -# under the License. \ No newline at end of file +# under the License. diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py index 569e949e7b764..26263b544a45b 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py @@ -25,19 +25,19 @@ from airflow.triggers.base import BaseTrigger, TriggerEvent -_SPARK_TERMINAL_STATES = frozenset({"FINISHED", "FAILED", "KILLED", "ERROR"}) -_SPARK_ACTIVE_STATES = frozenset({"SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN"}) +_SPARK_TERMINAL_STATES = {"FINISHED", "FAILED", "KILLED", "ERROR"} class SparkDriverTrigger(BaseTrigger): """ Async trigger that polls the Spark standalone REST API until the driver reaches a terminal state. Used when SparkSubmitOperator runs with deferrable=True. + :param driver_id: Spark driver submission ID returned by spark-submit --rest. :param master_urls: List of Spark master REST base URLs e.g. ["http://spark-master:6066"]. :param poll_interval: Seconds between REST API polls. Defaults to 10. """ - + def __init__( self, driver_id: str, @@ -58,7 +58,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "poll_interval": self.poll_interval, }, ) - + async def run(self) -> AsyncIterator[TriggerEvent]: """Poll Spark REST API until driver reaches a terminal state.""" while True: @@ -73,20 +73,20 @@ async def run(self) -> AsyncIterator[TriggerEvent]: ) return self.log.info("SparkDriverTrigger: driver=%s status=%s", self.driver_id, status) - upper = status.upper() - if upper in _SPARK_TERMINAL_STATES: - success = upper == "FINISHED" + normalized_status = status.upper() + if normalized_status in _SPARK_TERMINAL_STATES: + success = normalized_status == "FINISHED" yield TriggerEvent( { "status": "success" if success else "error", "driver_id": self.driver_id, - "driver_state": upper, - "message": f"Driver {self.driver_id} reached state {upper}", + "driver_state": normalized_status, + "message": f"Driver {self.driver_id} reached state {normalized_status}", } ) return await asyncio.sleep(self.poll_interval) - + async def _poll_driver_status(self) -> str | None: """Try each master URL; return driverState str or None if all fail.""" for url in self.master_urls: @@ -102,9 +102,9 @@ async def _poll_driver_status(self) -> str | None: self.driver_id, data.get("message", "unknown"), ) - return "UNKNOWN" + continue return data["driverState"] - except Exception as exc: + except aiohttp.ClientError as exc: self.log.warning("Could not reach Spark master at %s: %s", url, exc) continue - return None \ No newline at end of file + return None diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 1d2ef06624f8f..4d3bae4e6c8e7 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -920,7 +920,7 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self): hook.submit.assert_called_once_with("test.jar") hook._poll_k8s_driver_via_api.assert_not_called() - + class TestSparkSubmitOperatorDeferrable: """Tests for SparkSubmitOperator deferrable=True mode.""" @@ -948,9 +948,7 @@ def _make_hook(self): def test_deferrable_defaults_to_false(self): """deferrable must default to False — existing behaviour unchanged.""" - op = SparkSubmitOperator(task_id="t", dag=self.dag, application="app.jar") - assert op.deferrable is False - + def test_deferrable_stored_on_operator(self): """deferrable=True must be stored as self.deferrable.""" op = self._make_operator() @@ -1016,7 +1014,6 @@ def test_execute_complete_succeeds_on_success_event(self): def test_execute_complete_raises_on_error_event(self): """execute_complete() must raise AirflowException when status=error.""" - from airflow.providers.common.compat.sdk import AirflowException op = self._make_operator() event = { @@ -1025,7 +1022,7 @@ def test_execute_complete_raises_on_error_event(self): "driver_state": "FAILED", "message": "Driver reached FAILED", } - with pytest.raises(AirflowException, match="driver-001"): + with pytest.raises(RuntimeError, match="driver-001"): op.execute_complete(context={}, event=event) def test_build_master_rest_urls_single_master(self): @@ -1057,19 +1054,3 @@ def test_build_master_rest_urls_ha_multiple_masters(self): urls = op._build_master_rest_urls() assert urls == ["https://m1:6066", "https://m2:6066"] - - def test_deferrable_false_uses_sync_path(self): - """deferrable=False must fall through to the existing sync path (no defer call).""" - op = SparkSubmitOperator( - task_id="sync", dag=self.dag, application="app.jar", deferrable=False - ) - hook = MagicMock() - hook._should_track_driver_status = False - hook._should_track_driver_via_k8s_api.return_value = False - op._hook = hook - - with mock.patch.object(op, "defer") as mock_defer: - op.execute(context={}) - - mock_defer.assert_not_called() - hook.submit.assert_called_once_with("app.jar") diff --git a/providers/apache/spark/tests/unit/apache/spark/triggers/__init__.py b/providers/apache/spark/tests/unit/apache/spark/triggers/__init__.py index d216be4ddc949..13a83393a9124 100644 --- a/providers/apache/spark/tests/unit/apache/spark/triggers/__init__.py +++ b/providers/apache/spark/tests/unit/apache/spark/triggers/__init__.py @@ -13,4 +13,4 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations -# under the License. \ No newline at end of file +# under the License. diff --git a/providers/apache/spark/tests/unit/apache/spark/triggers/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/triggers/test_spark_submit.py index 58d8a25da7b8e..d22dd0c150bc7 100644 --- a/providers/apache/spark/tests/unit/apache/spark/triggers/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/triggers/test_spark_submit.py @@ -162,24 +162,26 @@ async def test_poll_returns_driver_state(self): assert result == "RUNNING" @pytest.mark.asyncio - async def test_poll_returns_unknown_on_success_false(self): - """When success=false, _poll_driver_status must return UNKNOWN (HA failover).""" - trigger = self._make_trigger() - mock_resp = mock.AsyncMock() - mock_resp.raise_for_status = mock.MagicMock() - mock_resp.json = mock.AsyncMock(return_value={"success": False, "message": "not found"}) - mock_resp.__aenter__ = mock.AsyncMock(return_value=mock_resp) - mock_resp.__aexit__ = mock.AsyncMock(return_value=False) - + async def test_poll_skips_master_on_success_false(self): + """When success=false from one master, must try the next master URL.""" + trigger = self._make_trigger(master_urls=["http://m1:6066", "http://m2:6066"]) + bad_resp = mock.AsyncMock() + bad_resp.raise_for_status = mock.MagicMock() + bad_resp.json = mock.AsyncMock(return_value={"success": False, "message": "not found"}) + bad_resp.__aenter__ = mock.AsyncMock(return_value=bad_resp) + bad_resp.__aexit__ = mock.AsyncMock(return_value=False) + good_resp = mock.AsyncMock() + good_resp.raise_for_status = mock.MagicMock() + good_resp.json = mock.AsyncMock(return_value={"success": True, "driverState": "RUNNING"}) + good_resp.__aenter__ = mock.AsyncMock(return_value=good_resp) + good_resp.__aexit__ = mock.AsyncMock(return_value=False) mock_session = mock.AsyncMock() - mock_session.get.return_value = mock_resp + mock_session.get.side_effect = [bad_resp, good_resp] mock_session.__aenter__ = mock.AsyncMock(return_value=mock_session) mock_session.__aexit__ = mock.AsyncMock(return_value=False) - with mock.patch("aiohttp.ClientSession", return_value=mock_session): result = await trigger._poll_driver_status() - - assert result == "UNKNOWN" + assert result == "RUNNING" @pytest.mark.asyncio async def test_poll_tries_next_master_on_exception(self): @@ -224,4 +226,4 @@ async def test_poll_returns_none_when_all_masters_fail(self): with mock.patch("aiohttp.ClientSession", return_value=mock_session): result = await trigger._poll_driver_status() - assert result is None \ No newline at end of file + assert result is None