diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index 0dcc01537cb78..03362c3c1687d 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -114,7 +114,7 @@ def __init__(self, *args, **kwargs): self.task_publish_max_retries = self.conf.getint( "kubernetes_executor", "task_publish_max_retries", fallback=0 ) - self.completed: set[KubernetesResults] = set() + self.completed: dict[tuple[str, str], KubernetesResults] = {} self.create_pods_after: datetime | None = None def _list_pods(self, query_kwargs): @@ -300,8 +300,18 @@ def sync(self) -> None: finally: self.result_queue.task_done() - for result in self.completed: + if self.completed: + still_pending: dict[tuple[str, str], KubernetesResults] = {} + for pod_key, result in self.completed.items(): + try: self._change_state(result) + except Exception: + self.log.exception( + "Exception when attempting to change state of adopted completed pod %s, will retry.", + result, + ) + still_pending[pod_key] = result + self.completed = still_pending from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils import ResourceVersion @@ -813,15 +823,15 @@ def _adopt_completed_pods(self, kube_client: client.CoreV1Api) -> None: continue ti_id = annotations_to_key(pod.metadata.annotations) - self.completed.add( - KubernetesResults( - key=ti_id, - state="completed", - pod_name=pod.metadata.name, - namespace=pod.metadata.namespace, - resource_version=pod.metadata.resource_version, - failure_details=None, - ) + pod_name = pod.metadata.name + namespace = pod.metadata.namespace + self.completed[(namespace, pod_name)] = KubernetesResults( + key=ti_id, + state="completed", + pod_name=pod_name, + namespace=namespace, + resource_version=pod.metadata.resource_version, + failure_details=None, ) def _flush_task_queue(self) -> None: diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py index bc1c2a97f55c7..c82b97af85978 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py @@ -1397,7 +1397,7 @@ def get_annotations(pod_name): ], any_order=True, ) - assert {k8s_res.key for k8s_res in executor.completed} == expected_running_ti_keys + assert {k8s_res.key for k8s_res in executor.completed.values()} == expected_running_ti_keys @mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor.DynamicClient") @mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") @@ -1536,6 +1536,88 @@ def test_alive_other_scheduler_job_ids_does_not_detach_caller_session(self, sess "_alive_other_scheduler_job_ids closed/detached the caller's scoped session" ) + @pytest.mark.db_test + @mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + @mock.patch( + "airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.AirflowKubernetesScheduler.delete_pod" + ) + def test_sync_processes_completed_pods_once( + self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher + ): + """Adopted completed pods must not be re-deleted for every result-queue item.""" + executor = self.kubernetes_executor + executor.start() + try: + completed_key = TaskInstanceKey(dag_id="dag", task_id="completed", run_id="run_id", try_number=1) + queue_key = TaskInstanceKey(dag_id="dag", task_id="queued", run_id="run_id", try_number=1) + executor.completed = { + ("default", "completed-pod"): KubernetesResults( + completed_key, + "completed", + "completed-pod", + "default", + "1", + None, + ) + } + executor.result_queue.put(KubernetesResults(queue_key, None, "queue-pod", "default", "2", None)) + executor.result_queue.put(KubernetesResults(queue_key, None, "queue-pod-2", "default", "3", None)) + + executor.sync() + + assert mock_delete_pod.call_count == 3 + assert executor.completed == {} + finally: + executor.end() + + @pytest.mark.db_test + @mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + @mock.patch( + "airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.AirflowKubernetesScheduler" + ) + def test_sync_processes_completed_pods_once_without_deletion( + self, mock_kubescheduler, mock_get_kube_client, mock_kubernetes_job_watcher + ): + """Adopted completed pods must not be re-patched for every result-queue item.""" + mock_delete_pod = mock_kubescheduler.return_value.delete_pod + mock_patch_pod = mock_kubescheduler.return_value.patch_pod_executor_done + executor = self.kubernetes_executor + executor.kube_config.delete_worker_pods = False + executor.start() + try: + completed_key = TaskInstanceKey(dag_id="dag", task_id="completed", run_id="run_id", try_number=1) + queue_key = TaskInstanceKey(dag_id="dag", task_id="queued", run_id="run_id", try_number=1) + executor.completed = { + ("default", "completed-pod"): KubernetesResults( + completed_key, + "completed", + "completed-pod", + "default", + "1", + None, + ) + } + executor.result_queue.put(KubernetesResults(queue_key, None, "queue-pod", "default", "2", None)) + executor.result_queue.put(KubernetesResults(queue_key, None, "queue-pod-2", "default", "3", None)) + + executor.sync() + + mock_delete_pod.assert_not_called() + assert mock_patch_pod.call_count == 3 + mock_patch_pod.assert_has_calls( + [ + mock.call(pod_name="completed-pod", namespace="default"), + mock.call(pod_name="queue-pod", namespace="default"), + mock.call(pod_name="queue-pod-2", namespace="default"), + ], + any_order=True, + ) + assert executor.completed == {} + finally: + executor.end() + @mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") def test_not_adopt_unassigned_task(self, mock_kube_client): """