diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py index 5e1b38305b0d5..a01f05ae82de6 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +from datetime import datetime, timezone from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, cast @@ -29,7 +30,7 @@ from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import CustomObjectLauncher from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN, PodGenerator -from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager +from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager, PodPhase from airflow.providers.common.compat.sdk import AirflowException from airflow.utils.helpers import prune_dict @@ -235,6 +236,14 @@ def template_body(self): return self.manage_template_specs() def find_spark_job(self, context, exclude_checked: bool = True): + """ + Find an existing Spark driver pod for this task instance. + + The pod is identified using Airflow task context labels. If multiple + driver pods match the same labels (which can occur if cleanup did not + run after an abrupt failure), a single pod is selected deterministically + for reattachment, preferring a Running driver pod when present. 
+ """ label_selector = ( self._build_find_pod_label_selector(context, exclude_checked=exclude_checked) + ",spark-role=driver" @@ -242,8 +251,25 @@ def find_spark_job(self, context, exclude_checked: bool = True): pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items pod = None - if len(pod_list) > 1: # and self.reattach_on_restart: - raise AirflowException(f"More than one pod running with labels: {label_selector}") + if len(pod_list) > 1: + # When multiple pods match the same labels, select one deterministically, + # preferring a Running pod, then creation time, with name as a tie-breaker. + pod = max( + pod_list, + key=lambda p: ( + p.status.phase == PodPhase.RUNNING, + p.metadata.creation_timestamp or datetime.min.replace(tzinfo=timezone.utc), + p.metadata.name or "", + ), + ) + self.log.warning( + "Found %d Spark driver pods matching labels %s; " + "selecting pod %s for reattachment based on status and creation time.", + len(pod_list), + label_selector, + pod.metadata.name, + ) + if len(pod_list) == 1: pod = pod_list[0] self.log.info( diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py index bf573ba51b37b..1e7f5c1da23d4 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -37,6 +37,7 @@ from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger +from airflow.providers.cncf.kubernetes.utils.pod_manager import PodPhase from airflow.providers.common.compat.sdk import TaskDeferred from airflow.utils import timezone from airflow.utils.types import DagRunType 
@@ -944,6 +945,170 @@ def test_reattach_on_restart_with_task_context_labels(
 
         mock_create_namespaced_crd.assert_not_called()
 
+    def test_find_spark_job_picks_running_pod(
+        self,
+        mock_is_in_cluster,
+        mock_parent_execute,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        """
+        Verifies that find_spark_job picks a Running Spark driver pod over a newer non-Running pod.
+        """
+
+        task_name = "test_find_spark_job_picks_running_pod"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            reattach_on_restart=True,
+        )
+        context = create_context(op)
+
+        # Running pod is older, so it can only be selected because of its phase.
+        running_pod = mock.MagicMock()
+        running_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        running_pod.metadata.name = "spark-driver-running"
+        running_pod.metadata.labels = {"try_number": "1"}
+        running_pod.status.phase = "Running"
+
+        # Pending pod is newer; it would win on creation time if phase were not preferred first.
+        pending_pod = mock.MagicMock()
+        pending_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc)
+        pending_pod.metadata.name = "spark-driver-pending"
+        pending_pod.metadata.labels = {"try_number": "1"}
+        pending_pod.status.phase = "Pending"
+
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [
+            running_pod,
+            pending_pod,
+        ]
+
+        returned_pod = op.find_spark_job(context)
+
+        assert returned_pod is running_pod
+
+    def test_find_spark_job_picks_latest_pod(
+        self,
+        mock_is_in_cluster,
+        mock_parent_execute,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        """
+        Verifies that find_spark_job selects the most recently created Spark driver pod
+        when multiple candidate driver pods are present and status does not disambiguate.
+        """
+
+        task_name = "test_find_spark_job_picks_latest_pod"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            reattach_on_restart=True,
+        )
+        context = create_context(op)
+
+        # Older pod that should be ignored.
+        old_mock_pod = mock.MagicMock()
+        old_mock_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        old_mock_pod.metadata.name = "spark-driver-old"
+        old_mock_pod.status.phase = PodPhase.RUNNING
+
+        # Newer pod that should be picked up.
+        new_mock_pod = mock.MagicMock()
+        new_mock_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc)
+        new_mock_pod.metadata.name = "spark-driver-new"
+        new_mock_pod.status.phase = PodPhase.RUNNING
+
+        # Same try_number to simulate abrupt failure scenarios (e.g. scheduler crash)
+        # where cleanup did not occur and multiple pods share identical labels.
+        old_mock_pod.metadata.labels = {"try_number": "1"}
+        new_mock_pod.metadata.labels = {"try_number": "1"}
+
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [old_mock_pod, new_mock_pod]
+
+        returned_pod = op.find_spark_job(context)
+
+        assert returned_pod is new_mock_pod
+
+    def test_find_spark_job_tiebreaks_by_name(
+        self,
+        mock_is_in_cluster,
+        mock_parent_execute,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        """
+        Verifies that find_spark_job uses pod name as a deterministic tie-breaker
+        when multiple running Spark driver pods share the same creation_timestamp.
+        """
+
+        task_name = "test_find_spark_job_tiebreaks_by_name"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            reattach_on_restart=True,
+        )
+        context = create_context(op)
+
+        # Use identical creation timestamps to force name-based tie-breaking.
+        ts = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+
+        # Pod with lexicographically smaller name should not be selected.
+        invalid_mock_pod = mock.MagicMock()
+        invalid_mock_pod.metadata.creation_timestamp = ts
+        invalid_mock_pod.metadata.name = "spark-driver-abc"
+        invalid_mock_pod.metadata.labels = {"try_number": "1"}
+        invalid_mock_pod.status.phase = PodPhase.RUNNING
+
+        # Pod with lexicographically greater name should be selected.
+        valid_mock_pod = mock.MagicMock()
+        valid_mock_pod.metadata.creation_timestamp = ts
+        valid_mock_pod.metadata.name = "spark-driver-xyz"
+        valid_mock_pod.metadata.labels = {"try_number": "1"}
+        valid_mock_pod.status.phase = PodPhase.RUNNING
+
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [invalid_mock_pod, valid_mock_pod]
+
+        returned_pod = op.find_spark_job(context)
+
+        assert returned_pod is valid_mock_pod
+
     @pytest.mark.asyncio
     def test_execute_deferrable(
         self,