Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, *args, **kwargs):
self.task_publish_max_retries = self.conf.getint(
"kubernetes_executor", "task_publish_max_retries", fallback=0
)
self.completed: set[KubernetesResults] = set()
self.completed: dict[tuple[str, str], KubernetesResults] = {}
self.create_pods_after: datetime | None = None

def _list_pods(self, query_kwargs):
Expand Down Expand Up @@ -300,8 +300,18 @@ def sync(self) -> None:
finally:
self.result_queue.task_done()

for result in self.completed:
if self.completed:
still_pending: dict[tuple[str, str], KubernetesResults] = {}
for pod_key, result in self.completed.items():
try:
self._change_state(result)
except Exception:
self.log.exception(
"Exception when attempting to change state of adopted completed pod %s, will retry.",
result,
)
still_pending[pod_key] = result
self.completed = still_pending

from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils import ResourceVersion

Expand Down Expand Up @@ -813,15 +823,15 @@ def _adopt_completed_pods(self, kube_client: client.CoreV1Api) -> None:
continue

ti_id = annotations_to_key(pod.metadata.annotations)
self.completed.add(
KubernetesResults(
key=ti_id,
state="completed",
pod_name=pod.metadata.name,
namespace=pod.metadata.namespace,
resource_version=pod.metadata.resource_version,
failure_details=None,
)
pod_name = pod.metadata.name
namespace = pod.metadata.namespace
self.completed[(namespace, pod_name)] = KubernetesResults(
key=ti_id,
state="completed",
pod_name=pod_name,
namespace=namespace,
resource_version=pod.metadata.resource_version,
failure_details=None,
)

def _flush_task_queue(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,7 @@ def get_annotations(pod_name):
],
any_order=True,
)
assert {k8s_res.key for k8s_res in executor.completed} == expected_running_ti_keys
assert {k8s_res.key for k8s_res in executor.completed.values()} == expected_running_ti_keys

@mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor.DynamicClient")
@mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
Expand Down Expand Up @@ -1536,6 +1536,88 @@ def test_alive_other_scheduler_job_ids_does_not_detach_caller_session(self, sess
"_alive_other_scheduler_job_ids closed/detached the caller's scoped session"
)

@pytest.mark.db_test
@mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.KubernetesJobWatcher")
@mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
@mock.patch(
    "airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.AirflowKubernetesScheduler.delete_pod"
)
def test_sync_processes_completed_pods_once(
    self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher
):
    """
    Adopted completed pods must not be re-deleted for every result-queue item.

    Seeds ``executor.completed`` with one adopted pod and the result queue with two
    results, runs a single ``sync()``, and checks that ``delete_pod`` was invoked
    exactly once per pod and that ``executor.completed`` was emptied afterwards.
    """
    executor = self.kubernetes_executor
    executor.start()
    try:
        completed_key = TaskInstanceKey(dag_id="dag", task_id="completed", run_id="run_id", try_number=1)
        queue_key = TaskInstanceKey(dag_id="dag", task_id="queued", run_id="run_id", try_number=1)
        # ``completed`` is keyed by (namespace, pod_name).
        executor.completed = {
            ("default", "completed-pod"): KubernetesResults(
                completed_key,
                "completed",
                "completed-pod",
                "default",
                "1",
                None,
            )
        }
        # Two queued results: if the adopted pod were handled once per queue item,
        # the delete count below would be higher than three.
        executor.result_queue.put(KubernetesResults(queue_key, None, "queue-pod", "default", "2", None))
        executor.result_queue.put(KubernetesResults(queue_key, None, "queue-pod-2", "default", "3", None))

        executor.sync()

        # One deletion for the adopted pod plus one for each of the two queued results.
        assert mock_delete_pod.call_count == 3
        assert executor.completed == {}
    finally:
        # Always shut the executor down, even when an assertion above fails.
        executor.end()

@pytest.mark.db_test
@mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.KubernetesJobWatcher")
@mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
@mock.patch(
    "airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.AirflowKubernetesScheduler"
)
def test_sync_processes_completed_pods_once_without_deletion(
    self, mock_kubescheduler, mock_get_kube_client, mock_kubernetes_job_watcher
):
    """
    Adopted completed pods must not be re-patched for every result-queue item.

    With ``delete_worker_pods`` turned off, a single ``sync()`` over one adopted pod
    and two queued results must call ``patch_pod_executor_done`` once per pod, never
    call ``delete_pod``, and leave ``executor.completed`` empty.
    """
    scheduler_instance = mock_kubescheduler.return_value
    mock_delete_pod = scheduler_instance.delete_pod
    mock_patch_pod = scheduler_instance.patch_pod_executor_done

    executor = self.kubernetes_executor
    executor.kube_config.delete_worker_pods = False
    executor.start()
    try:
        adopted_ti_key = TaskInstanceKey(dag_id="dag", task_id="completed", run_id="run_id", try_number=1)
        queued_ti_key = TaskInstanceKey(dag_id="dag", task_id="queued", run_id="run_id", try_number=1)

        # ``completed`` is keyed by (namespace, pod_name).
        adopted_result = KubernetesResults(adopted_ti_key, "completed", "completed-pod", "default", "1", None)
        executor.completed = {("default", "completed-pod"): adopted_result}

        # Enqueue two results so the adopted pod would be handled repeatedly if it
        # were processed once per queue item.
        for queued_pod_name, resource_version in (("queue-pod", "2"), ("queue-pod-2", "3")):
            executor.result_queue.put(
                KubernetesResults(queued_ti_key, None, queued_pod_name, "default", resource_version, None)
            )

        executor.sync()

        mock_delete_pod.assert_not_called()
        # One patch for the adopted pod plus one for each of the two queued results.
        assert mock_patch_pod.call_count == 3
        expected_patch_calls = [
            mock.call(pod_name=expected_pod_name, namespace="default")
            for expected_pod_name in ("completed-pod", "queue-pod", "queue-pod-2")
        ]
        mock_patch_pod.assert_has_calls(expected_patch_calls, any_order=True)
        assert executor.completed == {}
    finally:
        # Always shut the executor down, even when an assertion above fails.
        executor.end()

@mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
def test_not_adopt_unassigned_task(self, mock_kube_client):
"""
Expand Down
Loading