Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@
"apache.spark": {
"deps": [
"apache-airflow-providers-common-compat>=1.12.0",
"aiohttp>=3.14.0",
"apache-airflow>=2.11.0",
"grpcio-status>=1.67.0",
"pyspark-client>=4.0.0",
Expand Down
5 changes: 5 additions & 0 deletions providers/apache/spark/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ connection-types:
- string
- 'null'

triggers:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should come after hooks.

- integration-name: Apache Spark
python-modules:
- airflow.providers.apache.spark.triggers.spark_submit

task-decorators:
- class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task
name: pyspark
1 change: 1 addition & 0 deletions providers/apache/spark/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ requires-python = ">=3.10"
dependencies = [
"apache-airflow>=2.11.0",
"apache-airflow-providers-common-compat>=1.12.0",
"aiohttp>=3.14.0",
"pyspark-client>=4.0.0",
"grpcio-status>=1.67.0",
"requests>=2.32.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tenacity import retry, stop_after_attempt, wait_fixed

from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger
from airflow.providers.common.compat.openlineage.utils.spark import (
inject_parent_job_information_into_spark_properties,
inject_transport_information_into_spark_properties,
Expand Down Expand Up @@ -134,6 +135,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
omitted, Kerberos-enabled Spark connections with both ``keytab`` and
``principal`` configured use ``requests-kerberos`` automatically.
Defaults to ``None`` (no auth for non-Kerberos connections).
:param deferrable: Run operator in deferrable mode.
"""

# Generic key used across all Spark deployment modes (standalone driver ID,
Expand Down Expand Up @@ -197,6 +199,7 @@ def __init__(
track_driver_via_k8s_api: bool = False,
yarn_track_via_rm_api: bool = False,
yarn_rm_auth: AuthBase | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
openlineage_inject_parent_job_info: bool = conf.getboolean(
"openlineage", "spark_inject_parent_job_info", fallback=False
),
Expand Down Expand Up @@ -244,6 +247,7 @@ def __init__(

self.reconnect_on_retry = reconnect_on_retry
self._track_driver_via_k8s_api = track_driver_via_k8s_api
self.deferrable = deferrable
self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
self._openlineage_inject_transport_info = openlineage_inject_transport_info

Expand All @@ -261,6 +265,18 @@ def execute(self, context: Context) -> None:
hook = self._hook
if self._track_driver_via_k8s_api:
hook._validate_track_driver_via_k8s_api_config()
if self.deferrable:
driver_id = self.submit_job(context)
master_urls = self._build_master_rest_urls()
self.defer(
trigger=SparkDriverTrigger(
driver_id=driver_id,
master_urls=master_urls,
poll_interval=self._status_poll_interval,
),
method_name="execute_complete",
)
return # unreachable after defer(); keeps type checkers happy
if hook._should_track_driver_status:
if self.reconnect_on_retry:
return self.execute_resumable(context)
Expand Down Expand Up @@ -290,6 +306,41 @@ def execute(self, context: Context) -> None:
return self.get_job_result(driver_id, context)
hook.submit(self.application)

def execute_complete(self, context: Context, event: dict | None = None) -> None:
    """
    Resume after deferral and handle the event emitted by SparkDriverTrigger.

    :param context: Airflow task context (unused here).
    :param event: Trigger payload. Must carry ``status`` and ``driver_id``;
        ``driver_state`` and ``message`` are optional and only used to build
        the failure message.
    :raises RuntimeError: If the event is missing, lacks ``status`` or
        ``driver_id``, or reports any status other than ``"success"``.
    """
    # A missing or empty payload is reported as malformed instead of failing
    # with an AttributeError on ``None.get``.
    if not event:
        raise RuntimeError(f"Malformed trigger event: {event}")
    status = event.get("status")
    driver_id = event.get("driver_id")
    if status is None or driver_id is None:
        raise RuntimeError(f"Malformed trigger event: {event}")
    if status != "success":
        driver_state = event.get("driver_state", "")
        message = event.get("message", "")
        raise RuntimeError(
            f"Spark driver {driver_id} did not finish successfully "
            f"(state={driver_state}): {message}"
        )
    self.log.info("Spark driver %s finished successfully", driver_id)

def _build_master_rest_urls(self) -> list[str]:
    """
    Build Spark master REST API base URLs for SparkDriverTrigger.

    The connection's ``master`` value may list several masters separated by
    commas (HA). Each entry's own port is discarded and replaced by the
    ``rest_port`` connection extra (default 6066); the URL scheme comes from
    the ``rest_scheme`` extra (default ``http``).

    :return: One ``scheme://host:rest_port`` URL per non-empty master entry,
        in the order they appear in the connection.
    """
    if self._hook is None:
        self._hook = self._get_hook()
    connection = self._hook._connection
    scheme = connection.get("rest_scheme", "http")
    rest_port = connection.get("rest_port", 6066)
    urls: list[str] = []
    for entry in connection["master"].replace("spark://", "").split(","):
        entry = entry.strip()
        if not entry:
            # A stray or trailing comma would otherwise yield "scheme://:port".
            continue
        if entry.startswith("["):
            # Bracketed IPv6 literal: keep everything up to the closing bracket
            # so the colons inside the address are not mistaken for a port.
            closing = entry.find("]")
            hostname = entry[: closing + 1] if closing != -1 else entry
        else:
            hostname = entry.split(":")[0]
        urls.append(f"{scheme}://{hostname}:{rest_port}")
    return urls

def submit_job(self, context: Context) -> str:
if self._hook is None:
self._hook = self._get_hook()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio

from collections.abc import AsyncIterator
from typing import Any

import aiohttp

from airflow.triggers.base import BaseTrigger, TriggerEvent

_SPARK_TERMINAL_STATES = {"FINISHED", "FAILED", "KILLED", "ERROR"}


class SparkDriverTrigger(BaseTrigger):
    """
    Poll the Spark standalone REST API until a driver reaches a terminal state.

    Used when SparkSubmitOperator runs with ``deferrable=True``. On every poll the
    configured masters are queried in order and the first usable ``driverState``
    is taken.

    :param driver_id: Spark driver submission ID returned by spark-submit --rest.
    :param master_urls: List of Spark master REST base URLs e.g. ["http://spark-master:6066"].
    :param poll_interval: Seconds between REST API polls. Defaults to 10.
    """

    def __init__(
        self,
        driver_id: str,
        master_urls: list[str],
        poll_interval: int = 10,
    ) -> None:
        super().__init__()
        self.driver_id = driver_id
        self.master_urls = master_urls
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return the import path and constructor kwargs used to recreate this trigger."""
        return (
            "airflow.providers.apache.spark.triggers.spark_submit.SparkDriverTrigger",
            {
                "driver_id": self.driver_id,
                "master_urls": self.master_urls,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Poll Spark REST API until driver reaches a terminal state.

        Yields exactly one event and then stops. ``status`` is ``"success"`` only
        when the driver state is ``FINISHED``; any other terminal state, or a poll
        in which no master returned a state, yields ``status="error"``.
        """
        while True:
            status = await self._poll_driver_status()
            if status is None:
                yield TriggerEvent(
                    {
                        "status": "error",
                        "driver_id": self.driver_id,
                        "message": "All Spark masters unreachable",
                    }
                )
                return
            self.log.info("SparkDriverTrigger: driver=%s status=%s", self.driver_id, status)
            normalized_status = status.upper()
            if normalized_status in _SPARK_TERMINAL_STATES:
                success = normalized_status == "FINISHED"
                yield TriggerEvent(
                    {
                        "status": "success" if success else "error",
                        "driver_id": self.driver_id,
                        "driver_state": normalized_status,
                        "message": f"Driver {self.driver_id} reached state {normalized_status}",
                    }
                )
                return
            await asyncio.sleep(self.poll_interval)

    async def _poll_driver_status(self) -> str | None:
        """
        Try each master URL in order and return the first usable driver state.

        :return: The ``driverState`` string, or ``None`` if every master failed,
            answered with ``success`` unset, or omitted ``driverState``.
        """
        request_timeout = aiohttp.ClientTimeout(total=30)
        # One session per poll is shared by all masters rather than opened per URL.
        async with aiohttp.ClientSession() as session:
            for url in self.master_urls:
                status_url = f"{url.rstrip('/')}/v1/submissions/status/{self.driver_id}"
                try:
                    async with session.get(status_url, timeout=request_timeout) as resp:
                        resp.raise_for_status()
                        data = await resp.json()
                except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as exc:
                    # The total timeout surfaces as asyncio.TimeoutError, which is not
                    # an aiohttp.ClientError, and an invalid JSON body as ValueError.
                    # Either must fail over to the next master, not crash the trigger.
                    self.log.warning("Could not reach Spark master at %s: %s", url, exc)
                    continue
                if not data.get("success"):
                    self.log.warning(
                        "Spark REST API returned success=false for %s: %s",
                        self.driver_id,
                        data.get("message", "unknown"),
                    )
                    continue
                driver_state = data.get("driverState")
                if not isinstance(driver_state, str):
                    # A reply without a state cannot be classified; try the next master.
                    self.log.warning(
                        "Spark master at %s returned no driverState for %s", url, self.driver_id
                    )
                    continue
                return driver_state
        return None
Original file line number Diff line number Diff line change
Expand Up @@ -920,3 +920,137 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self):

hook.submit.assert_called_once_with("test.jar")
hook._poll_k8s_driver_via_api.assert_not_called()


class TestSparkSubmitOperatorDeferrable:
"""Tests for SparkSubmitOperator deferrable=True mode."""

def setup_method(self):
    """Create the schedule-less DAG that every test in this class attaches its operator to."""
    self.dag = DAG(
        "test_deferrable_dag",
        schedule=None,
        default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
    )

def _make_operator(self, **kwargs):
    """Return a SparkSubmitOperator with ``deferrable=True`` on the test DAG; extra kwargs are forwarded."""
    operator = SparkSubmitOperator(
        task_id="test_deferrable", dag=self.dag, application="test.jar", deferrable=True, **kwargs
    )
    return operator

def _make_hook(self):
    """Return a MagicMock hook with driver tracking disabled and a single-master connection."""
    fake_hook = MagicMock()
    fake_hook.submit.return_value = "driver-001"
    fake_hook._connection = {"master": "spark://myhost:7077"}
    fake_hook._should_track_driver_via_k8s_api.return_value = False
    fake_hook._should_track_driver_status = False
    return fake_hook

def test_deferrable_defaults_to_false(self):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is not needed as you are testing native python behaviour.

"""deferrable must default to False — existing behaviour unchanged."""

def test_deferrable_stored_on_operator(self):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. Not needed.

"""deferrable=True must be stored as self.deferrable."""
op = self._make_operator()
assert op.deferrable is True

def test_execute_calls_defer_when_deferrable_true(self):
    """In deferrable mode execute() defers exactly once and resumes in execute_complete."""
    operator = self._make_operator()
    operator._hook = self._make_hook()

    with (
        mock.patch.object(operator, "submit_job", return_value="driver-001"),
        mock.patch.object(operator, "_build_master_rest_urls", return_value=["http://myhost:6066"]),
        mock.patch.object(operator, "defer") as defer_mock,
    ):
        operator.execute(context={})

    defer_mock.assert_called_once()
    assert defer_mock.call_args.kwargs["method_name"] == "execute_complete"

def test_execute_passes_correct_args_to_trigger(self):
    """The trigger handed to defer() carries the driver id, master URLs and poll interval."""
    from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger

    operator = self._make_operator(status_poll_interval=15)
    operator._hook = self._make_hook()

    with (
        mock.patch.object(operator, "submit_job", return_value="driver-xyz"),
        mock.patch.object(operator, "_build_master_rest_urls", return_value=["http://m1:6066"]),
        mock.patch.object(operator, "defer") as defer_mock,
    ):
        operator.execute(context={})

    passed_trigger = defer_mock.call_args.kwargs["trigger"]
    assert isinstance(passed_trigger, SparkDriverTrigger)
    assert passed_trigger.driver_id == "driver-xyz"
    assert passed_trigger.master_urls == ["http://m1:6066"]
    assert passed_trigger.poll_interval == 15

def test_execute_does_not_call_hook_submit_directly(self):
    """Deferrable execute() goes through submit_job() and never calls hook.submit() itself."""
    operator = self._make_operator()
    fake_hook = self._make_hook()
    operator._hook = fake_hook

    with (
        mock.patch.object(operator, "submit_job", return_value="driver-001"),
        mock.patch.object(operator, "_build_master_rest_urls", return_value=["http://myhost:6066"]),
        mock.patch.object(operator, "defer"),
    ):
        operator.execute(context={})

    fake_hook.submit.assert_not_called()

def test_execute_complete_succeeds_on_success_event(self):
    """A status=success event passes through execute_complete() without raising."""
    operator = self._make_operator()
    success_event = dict(
        status="success",
        driver_id="driver-001",
        driver_state="FINISHED",
        message="Driver reached FINISHED",
    )
    # Passing means no exception escapes the call below.
    operator.execute_complete(context={}, event=success_event)

def test_execute_complete_raises_on_error_event(self):
    """execute_complete() must raise RuntimeError naming the driver when status=error."""
    op = self._make_operator()
    event = {
        "status": "error",
        "driver_id": "driver-001",
        "driver_state": "FAILED",
        "message": "Driver reached FAILED",
    }
    # The driver id must appear in the message so the failure is traceable.
    with pytest.raises(RuntimeError, match="driver-001"):
        op.execute_complete(context={}, event=event)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add another test for malformed event payloads, i.e. events that are missing driver_id or status?

def test_build_master_rest_urls_single_master(self):
    """A single spark:// master yields one REST URL using the configured scheme and port."""
    operator = self._make_operator()
    fake_hook = self._make_hook()
    fake_hook._connection = dict(master="spark://myhost:7077", rest_scheme="http", rest_port=6066)
    operator._hook = fake_hook

    assert operator._build_master_rest_urls() == ["http://myhost:6066"]

def test_build_master_rest_urls_ha_multiple_masters(self):
    """A comma-separated HA master list yields one REST URL per master, in order."""
    operator = self._make_operator()
    fake_hook = self._make_hook()
    fake_hook._connection = dict(master="spark://m1:7077,m2:7077", rest_scheme="https", rest_port=6066)
    operator._hook = fake_hook

    expected_urls = ["https://m1:6066", "https://m2:6066"]
    assert operator._build_master_rest_urls() == expected_urls
Loading
Loading