Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from __future__ import annotations

from kubernetes.client import ApiClient, models as k8s
from kubernetes.client import ApiClient, Configuration, models as k8s

from airflow.providers.common.compat.sdk import AirflowException

Expand All @@ -36,7 +36,7 @@ def _convert_from_dict(obj, new_class):
if isinstance(obj, new_class):
return obj
if isinstance(obj, dict):
api_client = ApiClient()
api_client = ApiClient(configuration=Configuration())
return api_client._ApiClient__deserialize_model(obj, new_class)
raise AirflowException(f"Expected dict or {new_class}, got {type(obj)}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal

from kubernetes.client import BatchV1Api, models as k8s
from kubernetes.client import BatchV1Api, Configuration, models as k8s
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException

Expand Down Expand Up @@ -378,7 +378,7 @@ def deserialize_job_template_file(path: str) -> k8s.V1Job:
job = None
log.warning("Template file %s does not exist", path)

api_client = ApiClient()
api_client = ApiClient(configuration=Configuration())
return api_client._ApiClient__deserialize_model(job, k8s.V1Job)

def on_kill(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from typing import TYPE_CHECKING

from dateutil import parser
from kubernetes.client import models as k8s
from kubernetes.client import Configuration, models as k8s
from kubernetes.client.api_client import ApiClient

from airflow.exceptions import (
Expand Down Expand Up @@ -568,10 +568,15 @@ def deserialize_model_dict(pod_dict: dict | None) -> k8s.V1Pod:
``_ApiClient__deserialize_model`` from the kubernetes client.
This issue is tracked here: https://github.com/kubernetes-client/python/issues/977.

A fresh ``Configuration`` is passed so that neither the pod nor any nested model captures the
process-global in-cluster ``Configuration``. In-cluster, that global carries a
``refresh_api_key_hook`` local closure which ``pickle`` cannot serialize, and which would
otherwise break pickling a ``pod_override`` onto the KubernetesExecutor multiprocessing queue.

:param pod_dict: Serialized dict of k8s.V1Pod object
:return: De-serialized k8s.V1Pod
"""
api_client = ApiClient()
api_client = ApiClient(configuration=Configuration())
return api_client._ApiClient__deserialize_model(pod_dict, k8s.V1Pod)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
# under the License.
from __future__ import annotations

import pickle
from unittest.mock import Mock, patch

import pytest
from kubernetes.client import models as k8s
from kubernetes.client import Configuration, models as k8s

from airflow.providers.cncf.kubernetes.backcompat.backwards_compat_converters import (
_convert_from_dict,
Expand Down Expand Up @@ -102,6 +103,29 @@ def test_convert_from_dict_with_invalid_type():
assert str(exc_info.value) == "Expected dict or <class 'unittest.mock.Mock'>, got <class 'str'>"


def test_convert_from_dict_is_picklable_in_cluster(monkeypatch):
    """Models built by ``_convert_from_dict`` must stay picklable when running in-cluster.

    In-cluster, the kubernetes client's process-global default ``Configuration`` carries a
    ``refresh_api_key_hook`` that is a local closure, which ``pickle`` cannot serialize.
    The converter is expected to deserialize through a fresh ``Configuration`` so that neither
    the resulting model nor its nested objects hold a reference to that hook.
    """

    def _local_hook(config):
        # A function defined inside another function cannot be pickled, which is
        # what makes it a stand-in for the in-cluster refresh hook.
        return None

    # Simulate the in-cluster state: a global default Configuration with the hook installed.
    in_cluster_default = Configuration()
    in_cluster_default.refresh_api_key_hook = _local_hook
    monkeypatch.setattr(Configuration, "_default", in_cluster_default, raising=False)

    volume = _convert_from_dict({"name": "vol", "emptyDir": {}}, k8s.V1Volume)

    assert isinstance(volume, k8s.V1Volume)
    # Raises if the unpicklable hook leaked into the model graph.
    pickle.dumps(volume)

    volume_config = volume.local_vars_configuration
    assert volume_config.refresh_api_key_hook is None
    empty_dir_config = volume.empty_dir.local_vars_configuration
    assert empty_dir_config.refresh_api_key_hook is None


# testcase of convert_volume() function
@patch("airflow.providers.cncf.kubernetes.backcompat.backwards_compat_converters._convert_kube_model_object")
def test_convert_volume_normal_value(mock_convert_kube_model_object):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import pickle
import random
import re
import string
Expand All @@ -24,7 +25,7 @@

import pendulum
import pytest
from kubernetes.client import ApiClient, models as k8s
from kubernetes.client import ApiClient, Configuration, models as k8s
from kubernetes.client.rest import ApiException

from airflow.exceptions import AirflowProviderDeprecationWarning
Expand Down Expand Up @@ -200,6 +201,42 @@ def test_backoff_limit_correctly_set(self, clean_dags_dagruns_and_dagbundles):
job = k.build_job_request_obj(create_context(k))
assert job.spec.backoff_limit == 6

def test_deserialize_job_template_file_is_picklable_in_cluster(self, tmp_path, monkeypatch):
    """A job deserialized from a template file must not capture the in-cluster Configuration.

    In-cluster, the kubernetes client installs a process-global default ``Configuration`` whose
    ``refresh_api_key_hook`` is an unpicklable local closure. ``deserialize_job_template_file`` must
    deserialize through a fresh ``Configuration`` so the job (and every nested model) stays picklable.

    :param tmp_path: pytest fixture providing a temporary directory for the template file
    :param monkeypatch: pytest fixture used to install the simulated in-cluster default
    """

    def _refresh_api_key(config):
        # Local function on purpose: pickle cannot serialize it, like the real in-cluster hook.
        return None

    # Simulate the in-cluster process-global default Configuration.
    dirty = Configuration()
    dirty.refresh_api_key_hook = _refresh_api_key
    monkeypatch.setattr(Configuration, "_default", dirty, raising=False)

    # The YAML nesting is significant: the assertions below navigate
    # ``job.spec.template.spec.containers[0]``, so ``containers`` has to sit under the pod
    # template's ``spec`` rather than directly under the Job ``spec``.
    template = tmp_path / "job.yaml"
    template.write_text(
        "apiVersion: batch/v1\n"
        "kind: Job\n"
        "metadata:\n"
        "  name: test-job\n"
        "spec:\n"
        "  template:\n"
        "    spec:\n"
        "      containers:\n"
        "        - name: base\n"
        "          image: airflow:3\n"
    )

    job = KubernetesJobOperator.deserialize_job_template_file(template.as_posix())

    assert isinstance(job, k8s.V1Job)
    # Raises if the unpicklable hook leaked into the job or any nested model.
    pickle.dumps(job)
    assert job.local_vars_configuration.refresh_api_key_hook is None
    assert job.spec.template.spec.containers[0].local_vars_configuration.refresh_api_key_hook is None

def test_completion_mode_correctly_set(self, clean_dags_dagruns_and_dagbundles):
k = KubernetesJobOperator(
task_id="task",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
# under the License.
from __future__ import annotations

import pickle
import re
from unittest import mock
from unittest.mock import MagicMock

import pendulum
import pytest
from dateutil import parser
from kubernetes.client import ApiClient, models as k8s
from kubernetes.client import ApiClient, Configuration, models as k8s

from airflow import __version__
from airflow.exceptions import AirflowConfigException
Expand Down Expand Up @@ -698,6 +699,33 @@ def test_deserialize_non_existent_model_file(self, caplog, tmp_path):
assert len(caplog.records) == 1
assert "non_existent.yaml does not exist" in caplog.text

def test_deserialize_model_dict_is_picklable_in_cluster(self, monkeypatch):
    """Pods built by ``deserialize_model_dict`` must stay picklable when running in-cluster.

    In-cluster, the kubernetes client's process-global default ``Configuration`` carries a
    ``refresh_api_key_hook`` that is a local closure, which ``pickle`` cannot serialize.
    ``deserialize_model_dict`` is expected to go through a fresh ``Configuration`` so the pod
    and every nested model can be pickled onto the KubernetesExecutor multiprocessing queue.
    """

    def _local_hook(config):
        # Defined locally so that pickle cannot serialize it, like the in-cluster hook.
        return None

    # Simulate the in-cluster state: a global default Configuration with the hook installed.
    in_cluster_default = Configuration()
    in_cluster_default.refresh_api_key_hook = _local_hook
    monkeypatch.setattr(Configuration, "_default", in_cluster_default, raising=False)

    base_container = {"name": "base", "image": "airflow:3"}
    serialized_pod = {
        "metadata": {"name": "test-pod"},
        "spec": {"containers": [base_container]},
    }
    pod = PodGenerator.deserialize_model_dict(serialized_pod)

    assert isinstance(pod, k8s.V1Pod)
    # Raises if the unpicklable hook leaked into the pod or any nested model.
    pickle.dumps(pod)

    pod_config = pod.local_vars_configuration
    assert pod_config.refresh_api_key_hook is None
    container_config = pod.spec.containers[0].local_vars_configuration
    assert container_config.refresh_api_key_hook is None

@pytest.mark.parametrize(
"input",
(
Expand Down
Loading