-
Notifications
You must be signed in to change notification settings - Fork 17.3k
Add deferrable=True support to SparkSubmitOperator (issue #67168) #68277
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
40fe70d
37051f7
958428e
f4def2c
f3d6335
90d4104
70b9478
9f85e43
febcaad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
|
|
||
| from collections.abc import AsyncIterator | ||
| from typing import Any | ||
|
|
||
| import aiohttp | ||
|
|
||
| from airflow.triggers.base import BaseTrigger, TriggerEvent | ||
|
|
||
# Driver states the trigger treats as terminal: polling stops as soon as the
# Spark master reports one of them. Only "FINISHED" is reported as success.
# A frozenset keeps this shared module-level constant from being mutated.
_SPARK_TERMINAL_STATES = frozenset({"FINISHED", "FAILED", "KILLED", "ERROR"})
|
|
||
|
|
||
class SparkDriverTrigger(BaseTrigger):
    """
    Poll the Spark standalone REST API until the driver reaches a terminal state.

    Used when SparkSubmitOperator runs with deferrable=True.

    :param driver_id: Spark driver submission ID returned by spark-submit --rest.
    :param master_urls: List of Spark master REST base URLs e.g. ["http://spark-master:6066"].
    :param poll_interval: Seconds between REST API polls. Defaults to 10.
    """

    def __init__(
        self,
        driver_id: str,
        master_urls: list[str],
        poll_interval: int = 10,
    ) -> None:
        super().__init__()
        self.driver_id = driver_id
        self.master_urls = master_urls
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return the trigger's import path and the kwargs needed to rebuild it."""
        return (
            "airflow.providers.apache.spark.triggers.spark_submit.SparkDriverTrigger",
            {
                "driver_id": self.driver_id,
                "master_urls": self.master_urls,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Poll the Spark REST API until the driver reaches a terminal state.

        Yields exactly one event and then stops. The event's ``status`` is
        ``"success"`` only when the driver state is ``FINISHED``; every other
        terminal state, or a poll in which no master returned a usable
        answer, produces ``status == "error"``.
        """
        while True:
            status = await self._poll_driver_status()
            if status is None:
                # No master produced a usable driver state in this round.
                yield TriggerEvent(
                    {
                        "status": "error",
                        "driver_id": self.driver_id,
                        "message": "All Spark masters unreachable",
                    }
                )
                return
            self.log.info("SparkDriverTrigger: driver=%s status=%s", self.driver_id, status)
            normalized_status = status.upper()
            if normalized_status in _SPARK_TERMINAL_STATES:
                success = normalized_status == "FINISHED"
                yield TriggerEvent(
                    {
                        "status": "success" if success else "error",
                        "driver_id": self.driver_id,
                        "driver_state": normalized_status,
                        "message": f"Driver {self.driver_id} reached state {normalized_status}",
                    }
                )
                return
            await asyncio.sleep(self.poll_interval)

    async def _poll_driver_status(self) -> str | None:
        """
        Ask each master URL in turn for the driver state.

        :return: The ``driverState`` string from the first master that answers
            with a usable payload, or ``None`` if no master does.
        """
        request_timeout = aiohttp.ClientTimeout(total=30)
        # One session per poll round, shared by every master, instead of a
        # fresh session (and connector) for each URL.
        async with aiohttp.ClientSession() as session:
            for url in self.master_urls:
                status_url = f"{url.rstrip('/')}/v1/submissions/status/{self.driver_id}"
                try:
                    async with session.get(status_url, timeout=request_timeout) as resp:
                        resp.raise_for_status()
                        data = await resp.json()
                except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
                    # The total-timeout expiry surfaces as asyncio.TimeoutError,
                    # which is not an aiohttp.ClientError; without catching it
                    # a slow master would crash the trigger instead of falling
                    # through to the next one.
                    self.log.warning("Could not reach Spark master at %s: %s", url, exc)
                    continue
                except ValueError as exc:
                    # Body was not valid JSON (json.JSONDecodeError is a ValueError).
                    self.log.warning("Spark master at %s returned an unparsable response: %s", url, exc)
                    continue

                if not isinstance(data, dict):
                    self.log.warning("Spark master at %s returned an unexpected payload: %r", url, data)
                    continue
                if not data.get("success"):
                    self.log.warning(
                        "Spark REST API returned success=false for %s: %s",
                        self.driver_id,
                        data.get("message", "unknown"),
                    )
                    continue

                driver_state = data.get("driverState")
                if not isinstance(driver_state, str):
                    # A "successful" answer without a state is unusable; try the
                    # next master rather than raising KeyError/AttributeError.
                    self.log.warning(
                        "Spark master at %s returned no driverState for %s", url, self.driver_id
                    )
                    continue
                return driver_state
        return None
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -920,3 +920,137 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self): | |
|
|
||
| hook.submit.assert_called_once_with("test.jar") | ||
| hook._poll_k8s_driver_via_api.assert_not_called() | ||
|
|
||
|
|
||
| class TestSparkSubmitOperatorDeferrable: | ||
| """Tests for SparkSubmitOperator deferrable=True mode.""" | ||
|
|
||
def setup_method(self):
    """Create a fresh, schedule-less DAG before each test in this class."""
    default_args = dict(owner="airflow", start_date=DEFAULT_DATE)
    self.dag = DAG("test_deferrable_dag", schedule=None, default_args=default_args)
|
|
||
def _make_operator(self, **kwargs):
    """Build a deferrable SparkSubmitOperator on the test DAG; extra kwargs are forwarded."""
    fixed_kwargs = {
        "task_id": "test_deferrable",
        "dag": self.dag,
        "application": "test.jar",
        "deferrable": True,
    }
    return SparkSubmitOperator(**fixed_kwargs, **kwargs)
|
|
||
def _make_hook(self):
    """Return a MagicMock hook with driver tracking disabled and submit() returning "driver-001"."""
    fake_hook = MagicMock()
    fake_hook.submit.return_value = "driver-001"
    fake_hook._connection = {"master": "spark://myhost:7077"}
    fake_hook._should_track_driver_via_k8s_api.return_value = False
    fake_hook._should_track_driver_status = False
    return fake_hook
|
|
||
| def test_deferrable_defaults_to_false(self): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test is not needed, as you are testing native Python behaviour. |
||
| """deferrable must default to False — existing behaviour unchanged.""" | ||
|
|
||
| def test_deferrable_stored_on_operator(self): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here. Not needed. |
||
| """deferrable=True must be stored as self.deferrable.""" | ||
| op = self._make_operator() | ||
| assert op.deferrable is True | ||
|
|
||
def test_execute_calls_defer_when_deferrable_true(self):
    """With deferrable=True, execute() must defer once, resuming in execute_complete."""
    operator = self._make_operator()
    operator._hook = self._make_hook()

    # Stub out submission and URL discovery so only the defer() call is observed.
    submit_patch = mock.patch.object(operator, "submit_job", return_value="driver-001")
    urls_patch = mock.patch.object(
        operator, "_build_master_rest_urls", return_value=["http://myhost:6066"]
    )
    with submit_patch:
        with urls_patch:
            with mock.patch.object(operator, "defer") as defer_mock:
                operator.execute(context={})

    defer_mock.assert_called_once()
    assert defer_mock.call_args.kwargs["method_name"] == "execute_complete"
|
|
||
def test_execute_passes_correct_args_to_trigger(self):
    """The trigger handed to defer() must carry the driver id, master URLs and poll interval."""
    from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger

    operator = self._make_operator(status_poll_interval=15)
    operator._hook = self._make_hook()

    submit_patch = mock.patch.object(operator, "submit_job", return_value="driver-xyz")
    urls_patch = mock.patch.object(
        operator, "_build_master_rest_urls", return_value=["http://m1:6066"]
    )
    with submit_patch:
        with urls_patch:
            with mock.patch.object(operator, "defer") as defer_mock:
                operator.execute(context={})

    passed_trigger = defer_mock.call_args.kwargs["trigger"]
    assert isinstance(passed_trigger, SparkDriverTrigger)
    assert passed_trigger.driver_id == "driver-xyz"
    assert passed_trigger.master_urls == ["http://m1:6066"]
    assert passed_trigger.poll_interval == 15
|
|
||
def test_execute_does_not_call_hook_submit_directly(self):
    """In deferrable mode execute() must go through submit_job(), never hook.submit()."""
    fake_hook = self._make_hook()
    operator = self._make_operator()
    operator._hook = fake_hook

    submit_patch = mock.patch.object(operator, "submit_job", return_value="driver-001")
    urls_patch = mock.patch.object(
        operator, "_build_master_rest_urls", return_value=["http://myhost:6066"]
    )
    with submit_patch:
        with urls_patch:
            with mock.patch.object(operator, "defer"):
                operator.execute(context={})

    fake_hook.submit.assert_not_called()
|
|
||
def test_execute_complete_succeeds_on_success_event(self):
    """A status=success event must pass through execute_complete() without raising."""
    success_event = dict(
        status="success",
        driver_id="driver-001",
        driver_state="FINISHED",
        message="Driver reached FINISHED",
    )
    operator = self._make_operator()
    # The call itself is the assertion: any exception fails the test.
    operator.execute_complete(context={}, event=success_event)
|
|
||
def test_execute_complete_raises_on_error_event(self):
    """execute_complete() must raise RuntimeError naming the driver when status=error."""
    op = self._make_operator()
    event = {
        "status": "error",
        "driver_id": "driver-001",
        "driver_state": "FAILED",
        "message": "Driver reached FAILED",
    }
    # The exception text is expected to mention the driver id, hence the match.
    with pytest.raises(RuntimeError, match="driver-001"):
        op.execute_complete(context={}, event=event)
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add another test for malformed event payloads i.e. missing |
||
def test_build_master_rest_urls_single_master(self):
    """A single-master connection must yield exactly one REST URL."""
    fake_hook = self._make_hook()
    fake_hook._connection = dict(
        master="spark://myhost:7077",
        rest_scheme="http",
        rest_port=6066,
    )
    operator = self._make_operator()
    operator._hook = fake_hook

    assert operator._build_master_rest_urls() == ["http://myhost:6066"]
|
|
||
def test_build_master_rest_urls_ha_multiple_masters(self):
    """A comma-separated HA master string must yield one REST URL per master, in order."""
    fake_hook = self._make_hook()
    fake_hook._connection = dict(
        master="spark://m1:7077,m2:7077",
        rest_scheme="https",
        rest_port=6066,
    )
    operator = self._make_operator()
    operator._hook = fake_hook

    expected_urls = ["https://m1:6066", "https://m2:6066"]
    assert operator._build_master_rest_urls() == expected_urls
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should come after hooks.