From 82cb7d7c935fe11e78ad12d114954d543a3f660c Mon Sep 17 00:00:00 2001 From: Maksim Yermakou Date: Fri, 15 Dec 2023 15:12:45 +0000 Subject: [PATCH 1/8] Add create_job hook for Kubernetes --- .../cncf/kubernetes/hooks/kubernetes.py | 28 +- .../cncf/kubernetes/operators/job.py | 363 ++++++++++++++++++ .../cncf/kubernetes/utils/k8s_yaml_manager.py | 236 ++++++++++++ 3 files changed, 626 insertions(+), 1 deletion(-) create mode 100644 airflow/providers/cncf/kubernetes/operators/job.py create mode 100644 airflow/providers/cncf/kubernetes/utils/k8s_yaml_manager.py diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index e6e1054f10b30..cbe01dace683b 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -37,7 +37,7 @@ from airflow.utils import yaml if TYPE_CHECKING: - from kubernetes.client.models import V1Deployment, V1Pod + from kubernetes.client.models import V1Deployment, V1Job, V1Pod LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..." 
@@ -290,6 +290,10 @@ def apps_v1_client(self) -> client.AppsV1Api: def custom_object_client(self) -> client.CustomObjectsApi: return client.CustomObjectsApi(api_client=self.api_client) + @cached_property + def batch_v1_client(self) -> client.BatchV1Api: + return client.BatchV1Api(api_client=self.api_client) + def create_custom_object( self, group: str, version: str, plural: str, body: str | dict, namespace: str | None = None ): @@ -472,6 +476,28 @@ def get_deployment_status( except Exception as exc: raise exc + def create_job( + self, + job: V1Job, + **kwargs, + ) -> V1Job: + """Run Job""" + sanitized_job = self.batch_v1_client.api_client.sanitize_for_serialization(job) + json_job = json.dumps(sanitized_job, indent=2) + + self.log.debug("Job Creation Request: \n%s", json_job) + try: + resp = self.batch_v1_client.create_namespaced_job( + body=sanitized_job, namespace=job.metadata.namespace, **kwargs + ) + self.log.debug("Job Creation Response: %s", resp) + except Exception as e: + self.log.exception( + "Exception when attempting to create Namespaced Job: %s", str(json_job).replace("\n", " ") + ) + raise e + return resp + def _get_bool(val) -> bool | None: """Convert val to bool if can be done with certainty; if we cannot infer intention we return None.""" diff --git a/airflow/providers/cncf/kubernetes/operators/job.py b/airflow/providers/cncf/kubernetes/operators/job.py new file mode 100644 index 0000000000000..0710d9f665007 --- /dev/null +++ b/airflow/providers/cncf/kubernetes/operators/job.py @@ -0,0 +1,363 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Executes a Kubernetes Job.""" +from __future__ import annotations + +import logging +import os +import re +from functools import cached_property +from typing import TYPE_CHECKING + +from kubernetes.client import models as k8s +from kubernetes.client.api_client import ApiClient + +from airflow.models import BaseOperator +from airflow.providers.cncf.kubernetes.backcompat.backwards_compat_converters import ( + convert_affinity, + convert_env_vars, + convert_image_pull_secrets, + convert_port, + convert_toleration, + convert_volume, + convert_volume_mount, +) +from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook +from airflow.providers.cncf.kubernetes.utils.k8s_yaml_manager import ( + reconcile_jobs, +) +from airflow.utils import yaml +from airflow.utils.helpers import validate_key + +if TYPE_CHECKING: + from airflow.utils.context import Context + +log = logging.getLogger(__name__) + + +class KubernetesJobOperator(BaseOperator): + """ + Executes a Kubernetes Job + + + """ + + def __init__( + self, + *, + kubernetes_conn_id: str | None = KubernetesHook.default_conn_name, + job_namespace: str | None = None, + pod_namespace: str | None = None, + image: str | None = None, + job_name: str | None = None, + pod_name: str | None = None, + cmds: list[str] | None = None, + arguments: list[str] | None = None, + ports: list[k8s.V1ContainerPort] | None = None, + volume_mounts: list[k8s.V1VolumeMount] | None = None, + volumes: list[k8s.V1Volume] | None = None, + env_vars: list[k8s.V1EnvVar] | dict[str, str] | None = None, + env_from: 
list[k8s.V1EnvFromSource] | None = None, + in_cluster: bool | None = None, + cluster_context: str | None = None, + job_labels: dict | None = None, + pod_labels: dict | None = None, + image_pull_policy: str | None = None, + job_annotations: dict | None = None, + pod_annotations: dict | None = None, + container_resources: k8s.V1ResourceRequirements | None = None, + affinity: k8s.V1Affinity | None = None, + config_file: str | None = None, + node_selector: dict | None = None, + image_pull_secrets: list[k8s.V1LocalObjectReference] | None = None, + service_account_name: str | None = None, + hostnetwork: bool = False, + host_aliases: list[k8s.V1HostAlias] | None = None, + tolerations: list[k8s.V1Toleration] | None = None, + security_context: k8s.V1PodSecurityContext | dict | None = None, + container_security_context: k8s.V1SecurityContext | dict | None = None, + dnspolicy: str | None = None, + dns_config: k8s.V1PodDNSConfig | None = None, + hostname: str | None = None, + subdomain: str | None = None, + schedulername: str | None = None, + init_containers: list[k8s.V1Container] | None = None, + priority_class_name: str | None = None, + base_container_name: str | None = None, + termination_message_policy: str = "File", + job_active_deadline_seconds: int | None = None, + pod_active_deadline_seconds: int | None = None, + job_template_file: str | None = None, + full_job_spec: k8s.V1Job | None = None, + backoff_limit: int | None = None, + completion_mode: str | None = None, + completions: int | None = None, + manual_selector: bool | None = None, + parallelism: int | None = None, + selector: k8s.V1LabelSelector | None = None, + suspend: bool | None = None, + ttl_seconds_after_finished: int | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.kubernetes_conn_id = kubernetes_conn_id + self.in_cluster = in_cluster + self.cluster_context = cluster_context + self.config_file = config_file + self.job_template_file = job_template_file + self.full_job_spec = 
full_job_spec + self.job_request_obj: k8s.V1Job | None = None + self.job: k8s.V1Job | None = None + self.image = image + self.job_namespace = job_namespace + self.pod_namespace = pod_namespace + self.cmds = cmds or [] + self.arguments = arguments or [] + self.job_labels = job_labels or {} + self.pod_labels = pod_labels or {} + self.env_vars = convert_env_vars(env_vars) if env_vars else [] + self.env_from = env_from or [] + self.ports = [convert_port(p) for p in ports] if ports else [] + self.volume_mounts = [convert_volume_mount(v) for v in volume_mounts] if volume_mounts else [] + self.volumes = [convert_volume(volume) for volume in volumes] if volumes else [] + self.image_pull_policy = image_pull_policy + self.node_selector = node_selector or {} + self.job_annotations = job_annotations or {} + self.pod_annotations = pod_annotations or {} + self.affinity = convert_affinity(affinity) if affinity else {} + self.container_resources = container_resources + self.config_file = config_file + self.image_pull_secrets = convert_image_pull_secrets(image_pull_secrets) if image_pull_secrets else [] + self.service_account_name = service_account_name + self.hostnetwork = hostnetwork + self.host_aliases = host_aliases + self.tolerations = ( + [convert_toleration(toleration) for toleration in tolerations] if tolerations else [] + ) + self.security_context = security_context or {} + self.container_security_context = container_security_context + self.dnspolicy = dnspolicy + self.dns_config = dns_config + self.hostname = hostname + self.subdomain = subdomain + self.schedulername = schedulername + self.init_containers = init_containers or [] + self.priority_class_name = priority_class_name + self.job_name = self._set_name(job_name) + self.pod_name = self._set_name(pod_name) + self.pod_request_obj: k8s.V1Pod | None = None + self.pod: k8s.V1Pod | None = None + self.base_container_name = base_container_name or "base" + self.remote_pod: k8s.V1Pod | None = None + 
self.termination_message_policy = termination_message_policy + self.job_active_deadline_seconds = job_active_deadline_seconds + self.pod_active_deadline_seconds = pod_active_deadline_seconds + self.backoff_limit = backoff_limit + self.completion_mode = completion_mode + self.completions = completions + self.manual_selector = manual_selector + self.parallelism = parallelism + self.selector = selector + self.suspend = suspend + self.ttl_seconds_after_finished = ttl_seconds_after_finished + + @cached_property + def _incluster_namespace(self): + from pathlib import Path + + path = Path("/var/run/secrets/kubernetes.io/serviceaccount/namespace") + return path.exists() and path.read_text() or None + + @cached_property + def hook(self) -> KubernetesHook: + hook = KubernetesHook( + conn_id=self.kubernetes_conn_id, + in_cluster=self.in_cluster, + config_file=self.config_file, + cluster_context=self.cluster_context, + ) + return hook + + def create_job(self, job_request_obj: k8s.V1Job) -> k8s.V1Job: + self.log.debug("Starting job:\n%s", yaml.safe_dump(job_request_obj.to_dict())) + self.hook.create_job(job=job_request_obj) + + return job_request_obj + + def execute(self, context: Context): + self.job_request_obj = self.build_job_request_obj(context) + self.job = self.create_job( # must set `self.job` for `on_kill` + job_request_obj=self.job_request_obj + ) + + @staticmethod + def deserialize_job_template_file(path: str) -> k8s.V1Job: + """ + Generate a Job from a file. + + Unfortunately we need access to the private method + ``_ApiClient__deserialize_model`` from the kubernetes client. + This issue is tracked here: https://github.com/kubernetes-client/python/issues/977. 
+ + :param path: Path to the file + :return: a kubernetes.client.models.V1Job + """ + if os.path.exists(path): + with open(path) as stream: + job = yaml.safe_load(stream) + else: + job = None + log.warning("Template file %s does not exist", path) + + api_client = ApiClient() + return api_client._ApiClient__deserialize_model(job, k8s.V1Job) + + @staticmethod + def _set_name(name: str | None) -> str | None: + if name is not None: + validate_key(name, max_length=220) + return re.sub(r"[^a-z0-9-]+", "-", name.lower()) + return None + + def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job: + """ + Return V1Job object based on job template file, full job spec, and other operator parameters. + + The V1Job attributes are derived (in order of precedence) from operator params, full job spec, job + template file. + """ + self.log.debug("Creating job for KubernetesJobOperator task %s", self.task_id) + if self.job_template_file: + self.log.debug("Job template file found, will parse for base job") + job_template = self.deserialize_job_template_file(self.job_template_file) + if self.full_job_spec: + # looks one more time in the future + job_template = reconcile_jobs(job_template, self.full_job_spec) + elif self.full_job_spec: + job_template = self.full_job_spec + else: + job_template = k8s.V1Job(metadata=k8s.V1ObjectMeta()) + + pod_template = k8s.V1PodTemplateSpec( + metadata=k8s.V1ObjectMeta( + namespace=self.pod_namespace, + labels=self.pod_labels, + name=self.pod_name, + annotations=self.pod_annotations, + ), + spec=k8s.V1PodSpec( + node_selector=self.node_selector, + affinity=self.affinity, + tolerations=self.tolerations, + init_containers=self.init_containers, + host_aliases=self.host_aliases, + containers=[ + k8s.V1Container( + image=self.image, + name=self.base_container_name, + command=self.cmds, + ports=self.ports, + image_pull_policy=self.image_pull_policy, + resources=self.container_resources, + volume_mounts=self.volume_mounts, + 
args=self.arguments, + env=self.env_vars, + env_from=self.env_from, + security_context=self.container_security_context, + termination_message_policy=self.termination_message_policy, + ) + ], + image_pull_secrets=self.image_pull_secrets, + service_account_name=self.service_account_name, + host_network=self.hostnetwork, + hostname=self.hostname, + subdomain=self.subdomain, + security_context=self.security_context, + dns_policy=self.dnspolicy, + dns_config=self.dns_config, + scheduler_name=self.schedulername, + restart_policy="Never", + priority_class_name=self.priority_class_name, + volumes=self.volumes, + active_deadline_seconds=self.pod_active_deadline_seconds, + ), + ) + + job = k8s.V1Job( + api_version="v1", + kind="Job", + metadata=k8s.V1ObjectMeta( + namespace=self.job_namespace, + labels=self.job_labels, + name=self.job_name, + annotations=self.job_annotations, + ), + spec=k8s.V1JobSpec( + active_deadline_seconds=self.job_active_deadline_seconds, + backoff_limit=self.backoff_limit, + completion_mode=self.completion_mode, + completions=self.completions, + manual_selector=self.manual_selector, + parallelism=self.parallelism, + selector=self.selector, + suspend=self.suspend, + template=pod_template, + ttl_seconds_after_finished=self.ttl_seconds_after_finished, + ), + ) + + job = reconcile_jobs(job_template, job) + + # if not pod.metadata.name: + # pod.metadata.name = _create_pod_id( + # task_id=self.task_id, unique=self.random_name_suffix, max_length=80 + # ) + # elif self.random_name_suffix: + # # user has supplied pod name, we're just adding suffix + # pod.metadata.name = _add_pod_suffix(pod_name=pod.metadata.name) + + if not job.metadata.namespace: + hook_namespace = self.hook.get_namespace() + job_namespace = self.job_namespace or hook_namespace or self._incluster_namespace or "default" + job.metadata.namespace = job_namespace + + # for secret in self.secrets: + # self.log.debug("Adding secret to task %s", self.task_id) + # pod = secret.attach_to_pod(pod) + # 
if self.do_xcom_push: + # self.log.debug("Adding xcom sidecar to task %s", self.task_id) + # pod = xcom_sidecar.add_xcom_sidecar( + # pod, + # sidecar_container_image=self.hook.get_xcom_sidecar_container_image(), + # sidecar_container_resources=self.hook.get_xcom_sidecar_container_resources(), + # ) + + # labels = self._get_ti_pod_labels(context) + # self.log.info("Building pod %s with labels: %s", pod.metadata.name, labels) + + # # Merge Pod Identifying labels with labels passed to operator + # pod.metadata.labels.update(labels) + # # Add Airflow Version to the label + # # And a label to identify that pod is launched by KubernetesPodOperator + # pod.metadata.labels.update( + # { + # "airflow_version": airflow_version.replace("+", "-"), + # "airflow_kpo_in_cluster": str(self.hook.is_in_cluster), + # } + # ) + # pod_mutation_hook(pod) + return job diff --git a/airflow/providers/cncf/kubernetes/utils/k8s_yaml_manager.py b/airflow/providers/cncf/kubernetes/utils/k8s_yaml_manager.py new file mode 100644 index 0000000000000..9d243f846c04d --- /dev/null +++ b/airflow/providers/cncf/kubernetes/utils/k8s_yaml_manager.py @@ -0,0 +1,236 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" +""" +from __future__ import annotations + +import copy + +from kubernetes.client import models as k8s + + +def reconcile_pods(base_pod: k8s.V1Pod, client_pod: k8s.V1Pod | None) -> k8s.V1Pod: + """ + Merge Kubernetes Pod objects. + + :param base_pod: has the base attributes which are overwritten if they exist + in the client pod and remain if they do not exist in the client_pod + :param client_pod: the pod that the client wants to create. + :return: the merged pods + + This can't be done recursively as certain fields are overwritten and some are concatenated. + """ + if client_pod is None: + return base_pod + + client_pod_cp = copy.deepcopy(client_pod) + client_pod_cp.spec = reconcile_pod_specs(base_pod.spec, client_pod_cp.spec) + client_pod_cp.metadata = reconcile_metadata(base_pod.metadata, client_pod_cp.metadata) + client_pod_cp = merge_objects(base_pod, client_pod_cp) + + return client_pod_cp + + +def reconcile_pod_specs( + base_spec: k8s.V1PodSpec | None, client_spec: k8s.V1PodSpec | None +) -> k8s.V1PodSpec | None: + """ + Merge Kubernetes PodSpec objects. + + :param base_spec: has the base attributes which are overwritten if they exist + in the client_spec and remain if they do not exist in the client_spec + :param client_spec: the spec that the client wants to create. + :return: the merged specs + """ + if base_spec and not client_spec: + return base_spec + if not base_spec and client_spec: + return client_spec + elif client_spec and base_spec: + client_spec.containers = reconcile_containers(base_spec.containers, client_spec.containers) + merged_spec = extend_object_field(base_spec, client_spec, "init_containers") + merged_spec = extend_object_field(base_spec, merged_spec, "volumes") + return merge_objects(base_spec, merged_spec) + + return None + + +def reconcile_jobs(base_job: k8s.V1Job, client_job: k8s.V1Job | None) -> k8s.V1Job: + """ + Merge Kubernetes Job objects. 
+ + :param base_job: has the base attributes which are overwritten if they exist + in the client job and remain if they do not exist in the client_job + :param client_job: the job that the client wants to create. + :return: the merged jobs + + This can't be done recursively as certain fields are overwritten and some are concatenated. + """ + if client_job is None: + return base_job + + client_job_cp = copy.deepcopy(client_job) + client_job_cp.spec = reconcile_job_specs(base_job.spec, client_job_cp.spec) + client_job_cp.metadata = reconcile_metadata(base_job.metadata, client_job_cp.metadata) + client_job_cp = merge_objects(base_job, client_job_cp) + + return client_job_cp + + +def reconcile_job_specs( + base_spec: k8s.V1JobSpec | None, client_spec: k8s.V1JobSpec | None +) -> k8s.V1JobSpec | None: + """ + Merge Kubernetes JobSpec objects. + + :param base_spec: has the base attributes which are overwritten if they exist + in the client_spec and remain if they do not exist in the client_spec + :param client_spec: the spec that the client wants to create. + :return: the merged specs + """ + if base_spec and not client_spec: + return base_spec + if not base_spec and client_spec: + return client_spec + elif client_spec and base_spec: + client_spec.containers = reconcile_containers(base_spec.containers, client_spec.containers) + merged_spec = extend_object_field(base_spec, client_spec, "init_containers") + merged_spec = extend_object_field(base_spec, merged_spec, "volumes") + return merge_objects(base_spec, merged_spec) + + return None + + +def reconcile_metadata(base_meta, client_meta): + """ + Merge Kubernetes Metadata objects. + + :param base_meta: has the base attributes which are overwritten if they exist + in the client_meta and remain if they do not exist in the client_meta + :param client_meta: the spec that the client wants to create. 
+ :return: the merged specs + """ + if base_meta and not client_meta: + return base_meta + if not base_meta and client_meta: + return client_meta + elif client_meta and base_meta: + client_meta.labels = merge_objects(base_meta.labels, client_meta.labels) + client_meta.annotations = merge_objects(base_meta.annotations, client_meta.annotations) + extend_object_field(base_meta, client_meta, "managed_fields") + extend_object_field(base_meta, client_meta, "finalizers") + extend_object_field(base_meta, client_meta, "owner_references") + return merge_objects(base_meta, client_meta) + + return None + + +def reconcile_containers( + base_containers: list[k8s.V1Container], client_containers: list[k8s.V1Container] +) -> list[k8s.V1Container]: + """ + Merge Kubernetes Container objects. + + :param base_containers: has the base attributes which are overwritten if they exist + in the client_containers and remain if they do not exist in the client_containers + :param client_containers: the containers that the client wants to create. + :return: the merged containers + + This runs recursively over the list of containers. + """ + if not base_containers: + return client_containers + if not client_containers: + return base_containers + + client_container = client_containers[0] + base_container = base_containers[0] + client_container = extend_object_field(base_container, client_container, "volume_mounts") + client_container = extend_object_field(base_container, client_container, "env") + client_container = extend_object_field(base_container, client_container, "env_from") + client_container = extend_object_field(base_container, client_container, "ports") + client_container = extend_object_field(base_container, client_container, "volume_devices") + client_container = merge_objects(base_container, client_container) + + return [ + client_container, + *reconcile_containers(base_containers[1:], client_containers[1:]), + ] + + +def merge_objects(base_obj, client_obj): + """ + Merge objects. 
+ + :param base_obj: has the base attributes which are overwritten if they exist + in the client_obj and remain if they do not exist in the client_obj + :param client_obj: the object that the client wants to create. + :return: the merged objects + """ + if not base_obj: + return client_obj + if not client_obj: + return base_obj + + client_obj_cp = copy.deepcopy(client_obj) + + if isinstance(base_obj, dict) and isinstance(client_obj_cp, dict): + base_obj_cp = copy.deepcopy(base_obj) + base_obj_cp.update(client_obj_cp) + return base_obj_cp + + for base_key in base_obj.to_dict(): + base_val = getattr(base_obj, base_key, None) + if not getattr(client_obj, base_key, None) and base_val: + if not isinstance(client_obj_cp, dict): + setattr(client_obj_cp, base_key, base_val) + else: + client_obj_cp[base_key] = base_val + return client_obj_cp + + +def extend_object_field(base_obj, client_obj, field_name): + """ + Add field values to existing objects. + + :param base_obj: an object which has a property `field_name` that is a list + :param client_obj: an object which has a property `field_name` that is a list. + A copy of this object is returned with `field_name` modified + :param field_name: the name of the list field + :return: the client_obj with the property `field_name` being the two properties appended + """ + client_obj_cp = copy.deepcopy(client_obj) + base_obj_field = getattr(base_obj, field_name, None) + client_obj_field = getattr(client_obj, field_name, None) + + if (not isinstance(base_obj_field, list) and base_obj_field is not None) or ( + not isinstance(client_obj_field, list) and client_obj_field is not None + ): + raise ValueError( + f"The chosen field must be a list. Got {type(base_obj_field)} base_object_field " + f"and {type(client_obj_field)} client_object_field." 
+ ) + + if not base_obj_field: + return client_obj_cp + if not client_obj_field: + setattr(client_obj_cp, field_name, base_obj_field) + return client_obj_cp + + appended_fields = base_obj_field + client_obj_field + setattr(client_obj_cp, field_name, appended_fields) + return client_obj_cp From f0430eb38db9a9d19eccb363a6d49da062182dfe Mon Sep 17 00:00:00 2001 From: Maksim Yermakou Date: Wed, 17 Jan 2024 16:05:41 +0000 Subject: [PATCH 2/8] Create GKEStartJobOperator and KubernetesJobOperator --- .../kubernetes/kubernetes_helper_functions.py | 58 +++ .../cncf/kubernetes/operators/job.py | 271 +++-------- .../cncf/kubernetes/utils/k8s_yaml_manager.py | 16 +- .../google/cloud/hooks/kubernetes_engine.py | 56 +++ .../google/cloud/links/kubernetes_engine.py | 29 ++ .../cloud/operators/kubernetes_engine.py | 122 ++++- airflow/providers/google/provider.yaml | 1 + .../operators.rst | 33 ++ .../operators/cloud/kubernetes_engine.rst | 19 + .../cncf/kubernetes/operators/test_job.py | 440 ++++++++++++++++++ .../cloud/hooks/test_kubernetes_engine.py | 47 ++ .../cloud/operators/test_kubernetes_engine.py | 117 +++++ .../cncf/kubernetes/example_kubernetes_job.py | 57 +++ .../example_kubernetes_engine_job.py | 87 ++++ 14 files changed, 1148 insertions(+), 205 deletions(-) create mode 100644 tests/providers/cncf/kubernetes/operators/test_job.py create mode 100644 tests/system/providers/cncf/kubernetes/example_kubernetes_job.py create mode 100644 tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py diff --git a/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py b/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py index c9d94a2a9a5cf..62dc351bbd1a5 100644 --- a/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py +++ b/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py @@ -19,6 +19,7 @@ import logging import secrets import string +import warnings from typing import TYPE_CHECKING import pendulum @@ 
-26,6 +27,7 @@ from airflow.compat.functools import cache from airflow.configuration import conf +from airflow.exceptions import AirflowProviderDeprecationWarning if TYPE_CHECKING: from airflow.models.taskinstancekey import TaskInstanceKey @@ -45,6 +47,18 @@ def rand_str(num): return "".join(secrets.choice(alphanum_lower) for _ in range(num)) +def add_unique_suffix(*, name: str, rand_len: int = 8, max_len: int = POD_NAME_MAX_LENGTH) -> str: + """Add random string to pod or job name while staying under max length. + + :param name: name of the pod or job + :param rand_len: length of the random string to append + :param max_len: maximum length of the pod or job name + :meta private: + """ + suffix = "-" + rand_str(rand_len) + return name[: max_len - len(suffix)].strip("-.") + suffix + + def add_pod_suffix(*, pod_name: str, rand_len: int = 8, max_len: int = POD_NAME_MAX_LENGTH) -> str: """Add random string to pod name while staying under max length. @@ -53,10 +67,48 @@ def add_pod_suffix(*, pod_name: str, rand_len: int = 8, max_len: int = POD_NAME_ :param max_len: maximum length of the pod name :meta private: """ + warnings.warn( + "This function is deprecated. Please use `add_unique_suffix`.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + suffix = "-" + rand_str(rand_len) return pod_name[: max_len - len(suffix)].strip("-.") + suffix +def create_unique_id( + dag_id: str | None = None, + task_id: str | None = None, + *, + max_length: int = POD_NAME_MAX_LENGTH, + unique: bool = True, +) -> str: + """ + Generate unique pod or job ID given a dag_id and / or task_id. 
+ + :param dag_id: DAG ID + :param task_id: Task ID + :param max_length: max number of characters + :param unique: whether a random string suffix should be added + :return: A valid identifier for a kubernetes pod or job name + """ + if not (dag_id or task_id): + raise ValueError("Must supply either dag_id or task_id.") + name = "" + if dag_id: + name += dag_id + if task_id: + if name: + name += "-" + name += task_id + base_name = slugify(name, lowercase=True)[:max_length].strip(".-") + if unique: + return add_pod_suffix(pod_name=base_name, rand_len=8, max_len=max_length) + else: + return base_name + + def create_pod_id( dag_id: str | None = None, task_id: str | None = None, @@ -73,6 +125,12 @@ def create_pod_id( :param unique: whether a random string suffix should be added :return: A valid identifier for a kubernetes pod name """ + warnings.warn( + "This function is deprecated. Please use `create_unique_id`.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + if not (dag_id or task_id): raise ValueError("Must supply either dag_id or task_id.") name = "" diff --git a/airflow/providers/cncf/kubernetes/operators/job.py b/airflow/providers/cncf/kubernetes/operators/job.py index 0710d9f665007..ffe9776c4f9d2 100644 --- a/airflow/providers/cncf/kubernetes/operators/job.py +++ b/airflow/providers/cncf/kubernetes/operators/job.py @@ -19,29 +19,22 @@ import logging import os -import re from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Sequence -from kubernetes.client import models as k8s +from kubernetes.client import BatchV1Api, models as k8s from kubernetes.client.api_client import ApiClient -from airflow.models import BaseOperator -from airflow.providers.cncf.kubernetes.backcompat.backwards_compat_converters import ( - convert_affinity, - convert_env_vars, - convert_image_pull_secrets, - convert_port, - convert_toleration, - convert_volume, - convert_volume_mount, -) from 
airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook +from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import ( + add_unique_suffix, + create_unique_id, +) +from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.utils.k8s_yaml_manager import ( reconcile_jobs, ) from airflow.utils import yaml -from airflow.utils.helpers import validate_key if TYPE_CHECKING: from airflow.utils.context import Context @@ -49,58 +42,37 @@ log = logging.getLogger(__name__) -class KubernetesJobOperator(BaseOperator): +class KubernetesJobOperator(KubernetesPodOperator): """ - Executes a Kubernetes Job - - + Executes a Kubernetes Job. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:KubernetesJobOperator` + + .. note:: + If you use `Google Kubernetes Engine <https://cloud.google.com/kubernetes-engine/>`__ + and Airflow is not running in the same cluster, consider using + :class:`~airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartJobOperator`, which + simplifies the authorization process. + + :param job_template_file: path to job template file (templated) + :param full_job_spec: The complete JobSpec + :param backoff_limit: Specifies the number of retries before marking this job failed. Defaults to 6 + :param completion_mode: CompletionMode specifies how Pod completions are tracked. It can be `NonIndexed` (default) or `Indexed`. + :param completions: Specifies the desired number of successfully finished pods the job should be run with. + :param manual_selector: manualSelector controls generation of pod labels and pod selectors. + :param parallelism: Specifies the maximum desired number of pods the job should run at any given time. + :param selector: The selector of this V1JobSpec. + :param suspend: Suspend specifies whether the Job controller should create Pods or not. 
+ :param ttl_seconds_after_finished: ttlSecondsAfterFinished limits the lifetime of a Job that has finished execution (either Complete or Failed). """ + template_fields: Sequence[str] = tuple({"job_template_file"} | set(KubernetesPodOperator.template_fields)) + def __init__( self, *, - kubernetes_conn_id: str | None = KubernetesHook.default_conn_name, - job_namespace: str | None = None, - pod_namespace: str | None = None, - image: str | None = None, - job_name: str | None = None, - pod_name: str | None = None, - cmds: list[str] | None = None, - arguments: list[str] | None = None, - ports: list[k8s.V1ContainerPort] | None = None, - volume_mounts: list[k8s.V1VolumeMount] | None = None, - volumes: list[k8s.V1Volume] | None = None, - env_vars: list[k8s.V1EnvVar] | dict[str, str] | None = None, - env_from: list[k8s.V1EnvFromSource] | None = None, - in_cluster: bool | None = None, - cluster_context: str | None = None, - job_labels: dict | None = None, - pod_labels: dict | None = None, - image_pull_policy: str | None = None, - job_annotations: dict | None = None, - pod_annotations: dict | None = None, - container_resources: k8s.V1ResourceRequirements | None = None, - affinity: k8s.V1Affinity | None = None, - config_file: str | None = None, - node_selector: dict | None = None, - image_pull_secrets: list[k8s.V1LocalObjectReference] | None = None, - service_account_name: str | None = None, - hostnetwork: bool = False, - host_aliases: list[k8s.V1HostAlias] | None = None, - tolerations: list[k8s.V1Toleration] | None = None, - security_context: k8s.V1PodSecurityContext | dict | None = None, - container_security_context: k8s.V1SecurityContext | dict | None = None, - dnspolicy: str | None = None, - dns_config: k8s.V1PodDNSConfig | None = None, - hostname: str | None = None, - subdomain: str | None = None, - schedulername: str | None = None, - init_containers: list[k8s.V1Container] | None = None, - priority_class_name: str | None = None, - base_container_name: str | None = None, - 
termination_message_policy: str = "File", - job_active_deadline_seconds: int | None = None, - pod_active_deadline_seconds: int | None = None, job_template_file: str | None = None, full_job_spec: k8s.V1Job | None = None, backoff_limit: int | None = None, @@ -114,58 +86,10 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self.kubernetes_conn_id = kubernetes_conn_id - self.in_cluster = in_cluster - self.cluster_context = cluster_context - self.config_file = config_file self.job_template_file = job_template_file self.full_job_spec = full_job_spec self.job_request_obj: k8s.V1Job | None = None self.job: k8s.V1Job | None = None - self.image = image - self.job_namespace = job_namespace - self.pod_namespace = pod_namespace - self.cmds = cmds or [] - self.arguments = arguments or [] - self.job_labels = job_labels or {} - self.pod_labels = pod_labels or {} - self.env_vars = convert_env_vars(env_vars) if env_vars else [] - self.env_from = env_from or [] - self.ports = [convert_port(p) for p in ports] if ports else [] - self.volume_mounts = [convert_volume_mount(v) for v in volume_mounts] if volume_mounts else [] - self.volumes = [convert_volume(volume) for volume in volumes] if volumes else [] - self.image_pull_policy = image_pull_policy - self.node_selector = node_selector or {} - self.job_annotations = job_annotations or {} - self.pod_annotations = pod_annotations or {} - self.affinity = convert_affinity(affinity) if affinity else {} - self.container_resources = container_resources - self.config_file = config_file - self.image_pull_secrets = convert_image_pull_secrets(image_pull_secrets) if image_pull_secrets else [] - self.service_account_name = service_account_name - self.hostnetwork = hostnetwork - self.host_aliases = host_aliases - self.tolerations = ( - [convert_toleration(toleration) for toleration in tolerations] if tolerations else [] - ) - self.security_context = security_context or {} - self.container_security_context = container_security_context 
- self.dnspolicy = dnspolicy - self.dns_config = dns_config - self.hostname = hostname - self.subdomain = subdomain - self.schedulername = schedulername - self.init_containers = init_containers or [] - self.priority_class_name = priority_class_name - self.job_name = self._set_name(job_name) - self.pod_name = self._set_name(pod_name) - self.pod_request_obj: k8s.V1Pod | None = None - self.pod: k8s.V1Pod | None = None - self.base_container_name = base_container_name or "base" - self.remote_pod: k8s.V1Pod | None = None - self.termination_message_policy = termination_message_policy - self.job_active_deadline_seconds = job_active_deadline_seconds - self.pod_active_deadline_seconds = pod_active_deadline_seconds self.backoff_limit = backoff_limit self.completion_mode = completion_mode self.completions = completions @@ -192,6 +116,10 @@ def hook(self) -> KubernetesHook: ) return hook + @cached_property + def client(self) -> BatchV1Api: + return self.hook.batch_v1_client + def create_job(self, job_request_obj: k8s.V1Job) -> k8s.V1Job: self.log.debug("Starting job:\n%s", yaml.safe_dump(job_request_obj.to_dict())) self.hook.create_job(job=job_request_obj) @@ -204,6 +132,10 @@ def execute(self, context: Context): job_request_obj=self.job_request_obj ) + ti = context["ti"] + ti.xcom_push(key="job_name", value=self.job.metadata.name) + ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace) + @staticmethod def deserialize_job_template_file(path: str) -> k8s.V1Job: """ @@ -226,12 +158,16 @@ def deserialize_job_template_file(path: str) -> k8s.V1Job: api_client = ApiClient() return api_client._ApiClient__deserialize_model(job, k8s.V1Job) - @staticmethod - def _set_name(name: str | None) -> str | None: - if name is not None: - validate_key(name, max_length=220) - return re.sub(r"[^a-z0-9-]+", "-", name.lower()) - return None + def on_kill(self) -> None: + if self.job: + job = self.job + kwargs = { + "name": job.metadata.name, + "namespace": job.metadata.namespace, + } + 
if self.termination_grace_period is not None: + kwargs.update(grace_period_seconds=self.termination_grace_period) + self.client.delete_namespaced_job(**kwargs) def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job: """ @@ -245,69 +181,29 @@ def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job: self.log.debug("Job template file found, will parse for base job") job_template = self.deserialize_job_template_file(self.job_template_file) if self.full_job_spec: - # looks one more time in the future job_template = reconcile_jobs(job_template, self.full_job_spec) elif self.full_job_spec: job_template = self.full_job_spec else: job_template = k8s.V1Job(metadata=k8s.V1ObjectMeta()) - pod_template = k8s.V1PodTemplateSpec( - metadata=k8s.V1ObjectMeta( - namespace=self.pod_namespace, - labels=self.pod_labels, - name=self.pod_name, - annotations=self.pod_annotations, - ), - spec=k8s.V1PodSpec( - node_selector=self.node_selector, - affinity=self.affinity, - tolerations=self.tolerations, - init_containers=self.init_containers, - host_aliases=self.host_aliases, - containers=[ - k8s.V1Container( - image=self.image, - name=self.base_container_name, - command=self.cmds, - ports=self.ports, - image_pull_policy=self.image_pull_policy, - resources=self.container_resources, - volume_mounts=self.volume_mounts, - args=self.arguments, - env=self.env_vars, - env_from=self.env_from, - security_context=self.container_security_context, - termination_message_policy=self.termination_message_policy, - ) - ], - image_pull_secrets=self.image_pull_secrets, - service_account_name=self.service_account_name, - host_network=self.hostnetwork, - hostname=self.hostname, - subdomain=self.subdomain, - security_context=self.security_context, - dns_policy=self.dnspolicy, - dns_config=self.dns_config, - scheduler_name=self.schedulername, - restart_policy="Never", - priority_class_name=self.priority_class_name, - volumes=self.volumes, - 
active_deadline_seconds=self.pod_active_deadline_seconds, - ), + pod_template = super().build_pod_request_obj(context) + pod_template_spec = k8s.V1PodTemplateSpec( + metadata=pod_template.metadata, + spec=pod_template.spec, ) job = k8s.V1Job( - api_version="v1", + api_version="batch/v1", kind="Job", metadata=k8s.V1ObjectMeta( - namespace=self.job_namespace, - labels=self.job_labels, - name=self.job_name, - annotations=self.job_annotations, + namespace=self.namespace, + labels=self.labels, + name=self.name, + annotations=self.annotations, ), spec=k8s.V1JobSpec( - active_deadline_seconds=self.job_active_deadline_seconds, + active_deadline_seconds=self.active_deadline_seconds, backoff_limit=self.backoff_limit, completion_mode=self.completion_mode, completions=self.completions, @@ -315,49 +211,28 @@ def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job: parallelism=self.parallelism, selector=self.selector, suspend=self.suspend, - template=pod_template, + template=pod_template_spec, ttl_seconds_after_finished=self.ttl_seconds_after_finished, ), ) job = reconcile_jobs(job_template, job) - # if not pod.metadata.name: - # pod.metadata.name = _create_pod_id( - # task_id=self.task_id, unique=self.random_name_suffix, max_length=80 - # ) - # elif self.random_name_suffix: - # # user has supplied pod name, we're just adding suffix - # pod.metadata.name = _add_pod_suffix(pod_name=pod.metadata.name) + if not job.metadata.name: + job.metadata.name = create_unique_id( + task_id=self.task_id, unique=self.random_name_suffix, max_length=80 + ) + elif self.random_name_suffix: + # user has supplied job name, we're just adding suffix + job.metadata.name = add_unique_suffix(name=job.metadata.name) + + job.metadata.name = f"job-{job.metadata.name}" if not job.metadata.namespace: hook_namespace = self.hook.get_namespace() - job_namespace = self.job_namespace or hook_namespace or self._incluster_namespace or "default" + job_namespace = self.namespace or hook_namespace 
or self._incluster_namespace or "default" job.metadata.namespace = job_namespace - # for secret in self.secrets: - # self.log.debug("Adding secret to task %s", self.task_id) - # pod = secret.attach_to_pod(pod) - # if self.do_xcom_push: - # self.log.debug("Adding xcom sidecar to task %s", self.task_id) - # pod = xcom_sidecar.add_xcom_sidecar( - # pod, - # sidecar_container_image=self.hook.get_xcom_sidecar_container_image(), - # sidecar_container_resources=self.hook.get_xcom_sidecar_container_resources(), - # ) - - # labels = self._get_ti_pod_labels(context) - # self.log.info("Building pod %s with labels: %s", pod.metadata.name, labels) - - # # Merge Pod Identifying labels with labels passed to operator - # pod.metadata.labels.update(labels) - # # Add Airflow Version to the label - # # And a label to identify that pod is launched by KubernetesPodOperator - # pod.metadata.labels.update( - # { - # "airflow_version": airflow_version.replace("+", "-"), - # "airflow_kpo_in_cluster": str(self.hook.is_in_cluster), - # } - # ) - # pod_mutation_hook(pod) + self.log.info("Building job %s ", job.metadata.name) + return job diff --git a/airflow/providers/cncf/kubernetes/utils/k8s_yaml_manager.py b/airflow/providers/cncf/kubernetes/utils/k8s_yaml_manager.py index 9d243f846c04d..63bd9d68fc341 100644 --- a/airflow/providers/cncf/kubernetes/utils/k8s_yaml_manager.py +++ b/airflow/providers/cncf/kubernetes/utils/k8s_yaml_manager.py @@ -15,12 +15,17 @@ # specific language governing permissions and limitations # under the License. """ +K8s YAML Manager. + +This module provides functions for working with K8s YAML. 
""" from __future__ import annotations import copy +from typing import TYPE_CHECKING -from kubernetes.client import models as k8s +if TYPE_CHECKING: + from kubernetes.client import models as k8s def reconcile_pods(base_pod: k8s.V1Pod, client_pod: k8s.V1Pod | None) -> k8s.V1Pod: @@ -107,10 +112,11 @@ def reconcile_job_specs( if not base_spec and client_spec: return client_spec elif client_spec and base_spec: - client_spec.containers = reconcile_containers(base_spec.containers, client_spec.containers) - merged_spec = extend_object_field(base_spec, client_spec, "init_containers") - merged_spec = extend_object_field(base_spec, merged_spec, "volumes") - return merge_objects(base_spec, merged_spec) + client_spec.template.spec = reconcile_pod_specs(base_spec.template.spec, client_spec.template.spec) + client_spec.template.metadata = reconcile_metadata( + base_spec.template.metadata, client_spec.template.metadata + ) + return merge_objects(base_spec, client_spec) return None diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py index 0878fa6b6f729..4fa97372f6a3f 100644 --- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py +++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py @@ -588,6 +588,62 @@ def get_pod(self, name: str, namespace: str) -> V1Pod: ) +class GKEJobHook(GoogleBaseHook, KubernetesHook): + """Google Kubernetes Engine Job APIs.""" + + def __init__( + self, + cluster_url: str, + ssl_ca_cert: str, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._cluster_url = cluster_url + self._ssl_ca_cert = ssl_ca_cert + + @cached_property + def api_client(self) -> client.ApiClient: + return self.get_conn() + + @cached_property + def core_v1_client(self) -> client.CoreV1Api: + return client.CoreV1Api(self.api_client) + + @cached_property + def batch_v1_client(self) -> client.BatchV1Api: + return client.BatchV1Api(self.api_client) + + def get_conn(self) -> 
client.ApiClient: + configuration = self._get_config() + configuration.refresh_api_key_hook = self._refresh_api_key_hook + return client.ApiClient(configuration) + + def _refresh_api_key_hook(self, configuration: client.configuration.Configuration): + configuration.api_key = {"authorization": self._get_token(self.get_credentials())} + + def _get_config(self) -> client.configuration.Configuration: + configuration = client.Configuration( + host=self._cluster_url, + api_key_prefix={"authorization": "Bearer"}, + api_key={"authorization": self._get_token(self.get_credentials())}, + ) + configuration.ssl_ca_cert = FileOrData( + { + "certificate-authority-data": self._ssl_ca_cert, + }, + file_key_name="certificate-authority", + ).as_file() + return configuration + + @staticmethod + def _get_token(creds: google.auth.credentials.Credentials) -> str: + if creds.token is None or creds.expired: + auth_req = google_requests.Request() + creds.refresh(auth_req) + return creds.token + + class GKEPodAsyncHook(GoogleBaseAsyncHook): """Google Kubernetes Engine pods APIs asynchronously. 
diff --git a/airflow/providers/google/cloud/links/kubernetes_engine.py b/airflow/providers/google/cloud/links/kubernetes_engine.py index 0703e2eb2cb34..ba59d02b55222 100644 --- a/airflow/providers/google/cloud/links/kubernetes_engine.py +++ b/airflow/providers/google/cloud/links/kubernetes_engine.py @@ -34,6 +34,10 @@ KUBERNETES_BASE_LINK + "/pod/{location}/{cluster_name}/{namespace}/{pod_name}/details?project={project_id}" ) +KUBERNETES_JOB_LINK = ( + KUBERNETES_BASE_LINK + + "/job/{location}/{cluster_name}/{namespace}/{job_name}/details?project={project_id}" +) class KubernetesEngineClusterLink(BaseGoogleLink): @@ -82,3 +86,28 @@ def persist( "project_id": task_instance.project_id, }, ) + + +class KubernetesEngineJobLink(BaseGoogleLink): + """Helper class for constructing Kubernetes Engine Job Link.""" + + name = "Kubernetes Job" + key = "kubernetes_job_conf" + format_str = KUBERNETES_JOB_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + ): + task_instance.xcom_push( + context=context, + key=KubernetesEngineJobLink.key, + value={ + "location": task_instance.location, + "cluster_name": task_instance.cluster_name, + "namespace": task_instance.job.metadata.namespace, + "job_name": task_instance.job.metadata.name, + "project_id": task_instance.project_id, + }, + ) diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py index e063bc7bdc773..8142a4ed61691 100644 --- a/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -32,11 +32,18 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.utils.pod_manager import 
OnFinishAction -from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEDeploymentHook, GKEHook, GKEPodHook +from airflow.providers.google.cloud.hooks.kubernetes_engine import ( + GKEDeploymentHook, + GKEHook, + GKEJobHook, + GKEPodHook, +) from airflow.providers.google.cloud.links.kubernetes_engine import ( KubernetesEngineClusterLink, + KubernetesEngineJobLink, KubernetesEnginePodLink, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator @@ -44,7 +51,7 @@ from airflow.utils.timezone import utcnow if TYPE_CHECKING: - from kubernetes.client.models import V1Pod + from kubernetes.client.models import V1Job, V1Pod from airflow.utils.context import Context @@ -780,3 +787,114 @@ def execute_complete(self, context: Context, event: dict, **kwargs): self._ssl_ca_cert = kwargs["ssl_ca_cert"] return super().execute_complete(context, event, **kwargs) + + +class GKEStartJobOperator(KubernetesJobOperator): + """ + Executes a Kubernetes job in the specified Google Kubernetes Engine cluster. + + This Operator assumes that the system has gcloud installed and has configured a + connection id with a service account. + + The **minimum** parameters required to create a job are + ``task_id``, ``project_id``, ``location``, ``cluster_name``, ``name``, + ``namespace``, and ``image``. + + .. seealso:: + For more detail about Kubernetes Engine authentication have a look at the reference: + https://cloud.google.com/kubernetes-engine/docs/how-to/cluster-access-for-kubectl#internal_ip + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GKEStartJobOperator` + + :param location: The name of the Google Kubernetes Engine zone or region in which the + cluster resides, e.g. 'us-central1-a' + :param cluster_name: The name of the Google Kubernetes Engine cluster + :param use_internal_ip: Use the internal IP address as the endpoint. 
+ :param project_id: The Google Developers Console project id + :param gcp_conn_id: The Google cloud connection id to use. This allows for + users to specify a service account. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param location: The location param is region name. + """ + + template_fields: Sequence[str] = tuple( + {"project_id", "location", "cluster_name"} | set(KubernetesJobOperator.template_fields) + ) + operator_extra_links = (KubernetesEngineJobLink(),) + + def __init__( + self, + *, + location: str, + cluster_name: str, + use_internal_ip: bool = False, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.cluster_name = cluster_name + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.use_internal_ip = use_internal_ip + + self.job: V1Job | None = None + self._ssl_ca_cert: str | None = None + self._cluster_url: str | None = None + + if self.gcp_conn_id is None: + raise AirflowException( + "The gcp_conn_id parameter has become required. If you want to use Application Default " + "Credentials (ADC) strategy for authorization, create an empty connection " + "called `google_cloud_default`.", + ) + # There is no need to manage the kube_config file, as it will be generated automatically. 
+ # All Kubernetes parameters (except config_file) are also valid for the GKEStartJobOperator. + if self.config_file: + raise AirflowException("config_file is not an allowed parameter for the GKEStartJobOperator.") + + @cached_property + def cluster_hook(self) -> GKEHook: + return GKEHook( + gcp_conn_id=self.gcp_conn_id, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + @cached_property + def hook(self) -> GKEJobHook: + if self._cluster_url is None or self._ssl_ca_cert is None: + raise AttributeError( + "Cluster url and ssl_ca_cert should be defined before using self.hook method. " + "Try to use self.get_kube_creds method", + ) + + hook = GKEJobHook( + gcp_conn_id=self.gcp_conn_id, + cluster_url=self._cluster_url, + ssl_ca_cert=self._ssl_ca_cert, + ) + return hook + + def execute(self, context: Context): + """Executes process of creating Job.""" + self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( + cluster_name=self.cluster_name, + project_id=self.project_id, + use_internal_ip=self.use_internal_ip, + cluster_hook=self.cluster_hook, + ).fetch_cluster_info() + + return super().execute(context) diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 96dd7a6f3aa3b..c8e1f81f6a97c 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -1203,6 +1203,7 @@ extra-links: - airflow.providers.google.cloud.links.stackdriver.StackdriverPoliciesLink - airflow.providers.google.cloud.links.kubernetes_engine.KubernetesEngineClusterLink - airflow.providers.google.cloud.links.kubernetes_engine.KubernetesEnginePodLink + - airflow.providers.google.cloud.links.kubernetes_engine.KubernetesEngineJobLink - airflow.providers.google.cloud.links.pubsub.PubSubSubscriptionLink - airflow.providers.google.cloud.links.pubsub.PubSubTopicLink - airflow.providers.google.cloud.links.cloud_memorystore.MemcachedInstanceDetailsLink diff --git 
a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst index b2e5f15f393cd..020149c403e87 100644 --- a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst +++ b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst @@ -586,3 +586,36 @@ For further information, look at: * `Kubernetes Documentation `__ * `Spark-on-k8s-operator Documentation - User guide `__ * `Spark-on-k8s-operator Documentation - API `__ + + +.. _howto/operator:kubernetesjoboperator: + +KubernetesJobOperator +===================== + +The :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator` allows +you to create and run Jobs on a Kubernetes cluster. + +.. note:: + If you use a managed Kubernetes service, consider using a specialized KPO operator, as it simplifies the Kubernetes authorization process: + + - :ref:`GKEStartJobOperator <howto/operator:GKEStartJobOperator>` operator for `Google Kubernetes Engine `__. + +.. note:: + The :doc:`Kubernetes executor ` is **not** required to use this operator. + +How does this operator work? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator` uses the +Kubernetes API to launch a job in a Kubernetes cluster. The operator uses the Kube Python Client to generate a Kubernetes API +request that dynamically launches this Job. +Users can specify a kubeconfig file using the ``config_file`` parameter, otherwise the operator will default +to ``~/.kube/config``. It also allows users to supply a template YAML file using the ``job_template_file`` parameter. + +.. 
exampleinclude:: /../../tests/system/providers/cncf/kubernetes/example_kubernetes_job.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_k8s_job] + :end-before: [END howto_operator_k8s_job] + +More information about Jobs can be found here: `Kubernetes Job Documentation `__ diff --git a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst index 7663ce48118ff..1c50f7a037cd7 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst @@ -192,6 +192,25 @@ lot less resources wasted on idle Operators or Sensors: :start-after: [START howto_operator_gke_start_pod_xcom_async] :end-before: [END howto_operator_gke_start_pod_xcom_async] +Run a Job on a GKE cluster +"""""""""""""""""""""""""" + +There are two operators available in order to run a job on a GKE cluster: + +* :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator` +* :class:`~airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartJobOperator` + +``GKEStartJobOperator`` extends ``KubernetesJobOperator`` to provide authorization using Google Cloud credentials. +There is no need to manage the ``kube_config`` file, as it will be generated automatically. +All Kubernetes parameters (except ``config_file``) are also valid for the ``GKEStartJobOperator``. +For more information on ``KubernetesJobOperator``, please look at the :ref:`howto/operator:KubernetesJobOperator` guide. + +.. 
exampleinclude:: /../../tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_gke_start_job] + :end-before: [END howto_operator_gke_start_job] + Reference ^^^^^^^^^ diff --git a/tests/providers/cncf/kubernetes/operators/test_job.py b/tests/providers/cncf/kubernetes/operators/test_job.py new file mode 100644 index 0000000000000..e3e1e335ce1f8 --- /dev/null +++ b/tests/providers/cncf/kubernetes/operators/test_job.py @@ -0,0 +1,440 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import re + +import pendulum +import pytest +from kubernetes.client import ApiClient, models as k8s + +from airflow.models import DAG, DagModel, DagRun, TaskInstance +from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator +from airflow.utils import timezone +from airflow.utils.session import create_session +from airflow.utils.types import DagRunType + +DEFAULT_DATE = timezone.datetime(2016, 1, 1, 1, 0, 0) + + +def create_context(task, persist_to_db=False, map_index=None): + if task.has_dag(): + dag = task.dag + else: + dag = DAG(dag_id="dag", start_date=pendulum.now()) + dag.add_task(task) + dag_run = DagRun( + run_id=DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE), + run_type=DagRunType.MANUAL, + dag_id=dag.dag_id, + ) + task_instance = TaskInstance(task=task, run_id=dag_run.run_id) + task_instance.dag_run = dag_run + if map_index is not None: + task_instance.map_index = map_index + if persist_to_db: + with create_session() as session: + session.add(DagModel(dag_id=dag.dag_id)) + session.add(dag_run) + session.add(task_instance) + session.commit() + return { + "dag": dag, + "ts": DEFAULT_DATE.isoformat(), + "task": task, + "ti": task_instance, + "task_instance": task_instance, + "run_id": "test", + } + + +@pytest.mark.execution_timeout(300) +class TestKubernetesJobOperator: + def test_templates(self, create_task_instance_of_operator): + dag_id = "TestKubernetesJobOperator" + ti = create_task_instance_of_operator( + KubernetesJobOperator, + dag_id=dag_id, + task_id="task-id", + namespace="{{ dag.dag_id }}", + container_resources=k8s.V1ResourceRequirements( + requests={"memory": "{{ dag.dag_id }}", "cpu": "{{ dag.dag_id }}"}, + limits={"memory": "{{ dag.dag_id }}", "cpu": "{{ dag.dag_id }}"}, + ), + volume_mounts=[ + k8s.V1VolumeMount( + name="{{ dag.dag_id }}", + mount_path="mount_path", + sub_path="{{ dag.dag_id }}", + ) + ], + job_template_file="{{ dag.dag_id }}", + config_file="{{ 
dag.dag_id }}", + labels="{{ dag.dag_id }}", + env_vars=["{{ dag.dag_id }}"], + arguments="{{ dag.dag_id }}", + cmds="{{ dag.dag_id }}", + image="{{ dag.dag_id }}", + annotations={"dag-id": "{{ dag.dag_id }}"}, + ) + + rendered = ti.render_templates() + + assert dag_id == rendered.container_resources.limits["memory"] + assert dag_id == rendered.container_resources.limits["cpu"] + assert dag_id == rendered.container_resources.requests["memory"] + assert dag_id == rendered.container_resources.requests["cpu"] + assert dag_id == rendered.volume_mounts[0].name + assert dag_id == rendered.volume_mounts[0].sub_path + assert dag_id == ti.task.image + assert dag_id == ti.task.cmds + assert dag_id == ti.task.namespace + assert dag_id == ti.task.config_file + assert dag_id == ti.task.labels + assert dag_id == ti.task.job_template_file + assert dag_id == ti.task.arguments + assert dag_id == ti.task.env_vars[0] + assert dag_id == rendered.annotations["dag-id"] + + def sanitize_for_serialization(self, obj): + return ApiClient().sanitize_for_serialization(obj) + + def test_backoff_limit_correctly_set(self): + k = KubernetesJobOperator( + task_id="task", + backoff_limit=6, + ) + job = k.build_job_request_obj(create_context(k)) + assert job.spec.backoff_limit == 6 + + def test_completion_mode_correctly_set(self): + k = KubernetesJobOperator( + task_id="task", + completion_mode="NonIndexed", + ) + job = k.build_job_request_obj(create_context(k)) + assert job.spec.completion_mode == "NonIndexed" + + def test_completions_correctly_set(self): + k = KubernetesJobOperator( + task_id="task", + completions=1, + ) + job = k.build_job_request_obj(create_context(k)) + assert job.spec.completions == 1 + + def test_manual_selector_correctly_set(self): + k = KubernetesJobOperator( + task_id="task", + manual_selector=False, + ) + job = k.build_job_request_obj(create_context(k)) + assert job.spec.manual_selector is False + + def test_parallelism_correctly_set(self): + k = KubernetesJobOperator( + 
task_id="task", + parallelism=2, + ) + job = k.build_job_request_obj(create_context(k)) + assert job.spec.parallelism == 2 + + def test_selector(self): + selector = k8s.V1LabelSelector( + match_expressions=[], + match_labels={"foo": "bar", "hello": "airflow"}, + ) + + k = KubernetesJobOperator( + task_id="task", + selector=selector, + ) + + job = k.build_job_request_obj(create_context(k)) + assert isinstance(job.spec.selector, k8s.V1LabelSelector) + assert job.spec.selector == selector + + def test_suspend_correctly_set(self): + k = KubernetesJobOperator( + task_id="task", + suspend=True, + ) + job = k.build_job_request_obj(create_context(k)) + assert job.spec.suspend is True + + def test_ttl_seconds_after_finished_correctly_set(self): + k = KubernetesJobOperator(task_id="task", ttl_seconds_after_finished=5) + job = k.build_job_request_obj(create_context(k)) + assert job.spec.ttl_seconds_after_finished == 5 + + @pytest.mark.parametrize("randomize", [True, False]) + def test_provided_job_name(self, randomize): + name_base = "test" + k = KubernetesJobOperator( + name=name_base, + random_name_suffix=randomize, + task_id="task", + ) + context = create_context(k) + job = k.build_job_request_obj(context) + + if randomize: + assert job.metadata.name.startswith(f"job-{name_base}") + assert job.metadata.name != f"job-{name_base}" + else: + assert job.metadata.name == f"job-{name_base}" + + @pytest.fixture + def job_spec(self): + return k8s.V1Job( + metadata=k8s.V1ObjectMeta(name="hello", labels={"foo": "bar"}, namespace="jobspecnamespace"), + spec=k8s.V1JobSpec( + template=k8s.V1PodTemplateSpec( + metadata=k8s.V1ObjectMeta( + name="world", labels={"foo": "bar"}, namespace="podspecnamespace" + ), + spec=k8s.V1PodSpec( + containers=[ + k8s.V1Container( + name="base", + image="ubuntu:16.04", + command=["something"], + ) + ] + ), + ) + ), + ) + + @pytest.mark.parametrize(("randomize_name",), ([True], [False])) + def test_full_job_spec(self, randomize_name, job_spec): + 
job_spec_name_base = job_spec.metadata.name + + k = KubernetesJobOperator( + task_id="task", + random_name_suffix=randomize_name, + full_job_spec=job_spec, + ) + context = create_context(k) + job = k.build_job_request_obj(context) + + if randomize_name: + assert job.metadata.name.startswith(f"job-{job_spec_name_base}") + assert job.metadata.name != f"job-{job_spec_name_base}" + else: + assert job.metadata.name == f"job-{job_spec_name_base}" + assert job.metadata.namespace == job_spec.metadata.namespace + assert job.spec.template.spec.containers[0].image == job_spec.spec.template.spec.containers[0].image + assert ( + job.spec.template.spec.containers[0].command == job_spec.spec.template.spec.containers[0].command + ) + assert job.metadata.labels == {"foo": "bar"} + + @pytest.mark.parametrize(("randomize_name",), ([True], [False])) + def test_full_job_spec_kwargs(self, randomize_name, job_spec): + # kwargs take precedence, however + image = "some.custom.image:andtag" + name_base = "world" + k = KubernetesJobOperator( + task_id="task", + random_name_suffix=randomize_name, + full_job_spec=job_spec, + name=name_base, + image=image, + labels={"hello": "world"}, + ) + job = k.build_job_request_obj(create_context(k)) + + # make sure the kwargs takes precedence (and that name is randomized when expected) + if randomize_name: + assert job.metadata.name.startswith(f"job-{name_base}") + assert job.metadata.name != f"job-{name_base}" + else: + assert job.metadata.name == f"job-{name_base}" + assert job.spec.template.spec.containers[0].image == image + assert job.metadata.labels == { + "foo": "bar", + "hello": "world", + } + + @pytest.fixture + def job_template_file(self, tmp_path): + job_template_yaml = """ + apiVersion: batch/v1 + kind: Job + metadata: + name: hello + namespace: templatenamespace + labels: + foo: bar + spec: + ttlSecondsAfterFinished: 60 + parallelism: 3 + completions: 3 + suspend: true + template: + spec: + serviceAccountName: foo + affinity: + nodeAffinity: 
+ requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: kubernetes.io/role + operator: In + values: + - foo + - bar + preferredDuringSchedulingIgnoredDuringExecution: + - weight: 1 + preference: + matchExpressions: + - key: kubernetes.io/role + operator: In + values: + - foo + - bar + containers: + - name: base + image: ubuntu:16.04 + imagePullPolicy: Always + command: + - something + """ + + tpl_file = tmp_path / "template.yaml" + tpl_file.write_text(job_template_yaml) + + yield tpl_file + + @pytest.mark.parametrize(("randomize_name",), ([True], [False])) + def test_job_template_file(self, randomize_name, job_template_file): + k = KubernetesJobOperator( + task_id="task", + random_name_suffix=randomize_name, + job_template_file=job_template_file, + ) + job = k.build_job_request_obj(create_context(k)) + + if randomize_name: + assert job.metadata.name.startswith("job-hello") + assert job.metadata.name != "job-hello" + else: + assert job.metadata.name == "job-hello" + assert job.metadata.labels == {"foo": "bar"} + assert job.metadata.namespace == "templatenamespace" + assert job.spec.template.spec.containers[0].image == "ubuntu:16.04" + assert job.spec.template.spec.containers[0].image_pull_policy == "Always" + assert job.spec.template.spec.containers[0].command == ["something"] + assert job.spec.template.spec.service_account_name == "foo" + affinity = { + "node_affinity": { + "preferred_during_scheduling_ignored_during_execution": [ + { + "preference": { + "match_expressions": [ + {"key": "kubernetes.io/role", "operator": "In", "values": ["foo", "bar"]} + ], + "match_fields": None, + }, + "weight": 1, + } + ], + "required_during_scheduling_ignored_during_execution": { + "node_selector_terms": [ + { + "match_expressions": [ + {"key": "kubernetes.io/role", "operator": "In", "values": ["foo", "bar"]} + ], + "match_fields": None, + } + ] + }, + }, + "pod_affinity": None, + "pod_anti_affinity": None, + } + + assert 
job.spec.template.spec.affinity.to_dict() == affinity + + @pytest.mark.parametrize(("randomize_name",), ([True], [False])) + def test_job_template_file_kwargs_override(self, randomize_name, job_template_file): + # kwargs take precedence, however + image = "some.custom.image:andtag" + name_base = "world" + k = KubernetesJobOperator( + task_id="task", + job_template_file=job_template_file, + name=name_base, + random_name_suffix=randomize_name, + image=image, + labels={"hello": "world"}, + ) + job = k.build_job_request_obj(create_context(k)) + + # make sure the kwargs takes precedence (and that name is randomized when expected) + if randomize_name: + assert job.metadata.name.startswith(f"job-{name_base}") + assert job.metadata.name != f"job-{name_base}" + else: + assert job.metadata.name == f"job-{name_base}" + assert job.spec.template.spec.containers[0].image == image + assert job.metadata.labels == { + "foo": "bar", + "hello": "world", + } + + def test_task_id_as_name(self): + k = KubernetesJobOperator( + task_id=".hi.-_09HI", + random_name_suffix=False, + ) + job = k.build_job_request_obj({}) + assert job.metadata.name == "job-hi-09hi" + + def test_task_id_as_name_with_suffix(self): + k = KubernetesJobOperator( + task_id=".hi.-_09HI", + random_name_suffix=True, + ) + job = k.build_job_request_obj({}) + expected = "job-hi-09hi" + assert job.metadata.name[: len(expected)] == expected + assert re.match(rf"{expected}-[a-z0-9]{{8}}", job.metadata.name) is not None + + def test_task_id_as_name_with_suffix_very_long(self): + k = KubernetesJobOperator( + task_id="a" * 250, + random_name_suffix=True, + ) + job = k.build_job_request_obj({}) + assert ( + re.match( + r"job-a{71}-[a-z0-9]{8}", + job.metadata.name, + ) + is not None + ) + + def test_task_id_as_name_dag_id_is_ignored(self): + dag = DAG(dag_id="this_is_a_dag_name", start_date=pendulum.now()) + k = KubernetesJobOperator( + task_id="a_very_reasonable_task_name", + dag=dag, + ) + job = k.build_job_request_obj({}) + 
assert re.match(r"job-a-very-reasonable-task-name-[a-z0-9-]+", job.metadata.name) is not None diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py index 06cbea84bc5fe..e7ca2a9744111 100644 --- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py @@ -31,6 +31,7 @@ GKEAsyncHook, GKEDeploymentHook, GKEHook, + GKEJobHook, GKEPodAsyncHook, GKEPodHook, ) @@ -696,3 +697,49 @@ def test_disable_tcp_keepalive( api_conn = gke_hook.get_conn() assert mock_enable.called is expected assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) + + +class TestGKEJobHook: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.gke_hook = GKEJobHook(gcp_conn_id="test", ssl_ca_cert=None, cluster_url=None) + self.gke_hook._client = mock.Mock() + + def refresh_token(request): + self.credentials.token = "New" + + self.credentials = mock.MagicMock() + self.credentials.token = "Old" + self.credentials.expired = False + self.credentials.refresh = refresh_token + + @mock.patch(GKE_STRING.format("google_requests.Request")) + def test_get_connection_update_hook_with_invalid_token(self, mock_request): + self.gke_hook._get_config = self._get_config + self.gke_hook.get_credentials = self._get_credentials + self.gke_hook.get_credentials().expired = True + the_client: kubernetes.client.ApiClient = self.gke_hook.get_conn() + + the_client.configuration.refresh_api_key_hook(the_client.configuration) + + assert self.gke_hook.get_credentials().token == "New" + + @mock.patch(GKE_STRING.format("google_requests.Request")) + def test_get_connection_update_hook_with_valid_token(self, mock_request): + self.gke_hook._get_config = self._get_config + self.gke_hook.get_credentials = self._get_credentials + self.gke_hook.get_credentials().expired = False + 
the_client: kubernetes.client.ApiClient = self.gke_hook.get_conn() + + the_client.configuration.refresh_api_key_hook(the_client.configuration) + + assert self.gke_hook.get_credentials().token == "Old" + + def _get_config(self): + return kubernetes.client.configuration.Configuration() + + def _get_credentials(self): + return self.credentials + diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py index 30a5247a68a67..db1dfe1d5dd36 100644 --- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py @@ -28,12 +28,14 @@ from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import Connection +from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction from airflow.providers.google.cloud.operators.kubernetes_engine import ( GKECreateClusterOperator, GKEDeleteClusterOperator, GKEStartKueueInsideClusterOperator, + GKEStartJobOperator, GKEStartPodOperator, ) from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEStartPodTrigger @@ -68,11 +70,13 @@ GKE_POD_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEPodHook" GKE_DEPLOYMENT_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEDeploymentHook" KUB_OPERATOR_EXEC = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute" +KUB_JOB_OPERATOR_EXEC = "airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator.execute" TEMP_FILE = "tempfile.NamedTemporaryFile" GKE_OP_PATH = "airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator" GKE_CREATE_CLUSTER_PATH = ( "airflow.providers.google.cloud.operators.kubernetes_engine.GKECreateClusterOperator" ) +GKE_JOB_OP_PATH = 
"airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartJobOperator" GKE_CLUSTER_AUTH_DETAILS_PATH = ( "airflow.providers.google.cloud.operators.kubernetes_engine.GKEClusterAuthDetails" ) @@ -639,3 +643,116 @@ def test_async_create_pod_should_execute_successfully( self.gke_op.execute(context=mock.MagicMock()) fetch_cluster_info_mock.assert_called_once() assert isinstance(exc.value.trigger, GKEStartPodTrigger) + + +class TestGKEStartJobOperator: + def setup_method(self): + self.gke_op = GKEStartJobOperator( + project_id=TEST_GCP_PROJECT_ID, + location=PROJECT_LOCATION, + cluster_name=CLUSTER_NAME, + task_id=PROJECT_TASK_ID, + name=TASK_NAME, + namespace=NAMESPACE, + image=IMAGE, + ) + self.gke_op.job = mock.MagicMock( + name=TASK_NAME, + namespace=NAMESPACE, + ) + + def test_template_fields(self): + assert set(KubernetesJobOperator.template_fields).issubset(GKEStartJobOperator.template_fields) + + @mock.patch.dict(os.environ, {}) + @mock.patch(KUB_JOB_OPERATOR_EXEC) + @mock.patch(TEMP_FILE) + @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") + def test_execute(self, fetch_cluster_info_mock, file_mock, exec_mock): + fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) + self.gke_op.execute(context=mock.MagicMock()) + fetch_cluster_info_mock.assert_called_once() + + def test_config_file_throws_error(self): + with pytest.raises(AirflowException): + GKEStartJobOperator( + project_id=TEST_GCP_PROJECT_ID, + location=PROJECT_LOCATION, + cluster_name=CLUSTER_NAME, + task_id=PROJECT_TASK_ID, + name=TASK_NAME, + namespace=NAMESPACE, + image=IMAGE, + config_file="/path/to/alternative/kubeconfig", + ) + + @mock.patch.dict(os.environ, {}) + @mock.patch( + "airflow.hooks.base.BaseHook.get_connections", + return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], + ) + @mock.patch(KUB_JOB_OPERATOR_EXEC) + @mock.patch(TEMP_FILE) + @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") + def 
test_execute_with_impersonation_service_account( + self, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock + ): + fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) + self.gke_op.impersonation_chain = "test_account@example.com" + self.gke_op.execute(context=mock.MagicMock()) + fetch_cluster_info_mock.assert_called_once() + + @mock.patch.dict(os.environ, {}) + @mock.patch( + "airflow.hooks.base.BaseHook.get_connections", + return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], + ) + @mock.patch(KUB_JOB_OPERATOR_EXEC) + @mock.patch(TEMP_FILE) + @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") + def test_execute_with_impersonation_service_chain_one_element( + self, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock + ): + fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) + self.gke_op.impersonation_chain = ["test_account@example.com"] + self.gke_op.execute(context=mock.MagicMock()) + + fetch_cluster_info_mock.assert_called_once() + + @pytest.mark.db_test + def test_default_gcp_conn_id(self): + gke_op = GKEStartJobOperator( + project_id=TEST_GCP_PROJECT_ID, + location=PROJECT_LOCATION, + cluster_name=CLUSTER_NAME, + task_id=PROJECT_TASK_ID, + name=TASK_NAME, + namespace=NAMESPACE, + image=IMAGE, + ) + gke_op._cluster_url = CLUSTER_URL + gke_op._ssl_ca_cert = SSL_CA_CERT + hook = gke_op.hook + + assert hook.gcp_conn_id == "google_cloud_default" + + @mock.patch( + "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection", + return_value=Connection(conn_id="test_conn"), + ) + def test_gcp_conn_id(self, get_con_mock): + gke_op = GKEStartJobOperator( + project_id=TEST_GCP_PROJECT_ID, + location=PROJECT_LOCATION, + cluster_name=CLUSTER_NAME, + task_id=PROJECT_TASK_ID, + name=TASK_NAME, + namespace=NAMESPACE, + image=IMAGE, + gcp_conn_id="test_conn", + ) + gke_op._cluster_url = CLUSTER_URL + gke_op._ssl_ca_cert = SSL_CA_CERT + hook = gke_op.hook 
+ + assert hook.gcp_conn_id == "test_conn" diff --git a/tests/system/providers/cncf/kubernetes/example_kubernetes_job.py b/tests/system/providers/cncf/kubernetes/example_kubernetes_job.py new file mode 100644 index 0000000000000..8801f5ddaf6a6 --- /dev/null +++ b/tests/system/providers/cncf/kubernetes/example_kubernetes_job.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example dag for using the KubernetesJobOperator. 
+""" +from __future__ import annotations + +import os +from datetime import datetime + +from airflow import DAG +from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "example_kubernetes_job_operator" + +with DAG( + dag_id=DAG_ID, + schedule=None, + start_date=datetime(2021, 1, 1), + tags=["example", "kubernetes"], +) as dag: + # [START howto_operator_k8s_job] + k8s_job = KubernetesJobOperator( + task_id="job-task", + namespace="default", + image="perl:5.34.0", + cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"], + name="test-pi", + ) + # [END howto_operator_k8s_job] + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py new file mode 100644 index 0000000000000..e1acd576e7fed --- /dev/null +++ b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for Google Kubernetes Engine. +""" +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.kubernetes_engine import ( + GKECreateClusterOperator, + GKEDeleteClusterOperator, + GKEStartJobOperator, +) + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +DAG_ID = "kubernetes_engine_job" +GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") + +GCP_LOCATION = "europe-north1-a" +CLUSTER_NAME = f"cluster-name-test-build-{ENV_ID}" +CLUSTER = {"name": CLUSTER_NAME, "initial_node_count": 1} + +with DAG( + DAG_ID, + schedule="@once", # Override to match your needs + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example"], +) as dag: + create_cluster = GKECreateClusterOperator( + task_id="create_cluster", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + body=CLUSTER, + ) + + # [START howto_operator_gke_start_job] + job_task = GKEStartJobOperator( + task_id="job_task", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + cluster_name=CLUSTER_NAME, + namespace="default", + image="perl:5.34.0", + cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"], + name="test-pi", + ) + # [END howto_operator_gke_start_job] + + delete_cluster = GKEDeleteClusterOperator( + task_id="delete_cluster", + name=CLUSTER_NAME, + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + ) + + create_cluster >> job_task >> delete_cluster + + from tests.system.utils.watcher import watcher + + # This test needs 
watcher in order to properly mark success/failure + # when "teardown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) From 5ee99b843561b63a05c7512f0438f1dce5f2fa28 Mon Sep 17 00:00:00 2001 From: Maksim Yermakou Date: Thu, 18 Jan 2024 16:29:43 +0000 Subject: [PATCH 3/8] Remove k8s_yaml_manager.py file --- .../cncf/kubernetes/hooks/kubernetes.py | 6 +- .../cncf/kubernetes/operators/job.py | 58 ++++- .../cncf/kubernetes/utils/k8s_yaml_manager.py | 242 ------------------ .../operators/cloud/kubernetes_engine.rst | 2 + docs/spelling_wordlist.txt | 3 + 5 files changed, 63 insertions(+), 248 deletions(-) delete mode 100644 airflow/providers/cncf/kubernetes/utils/k8s_yaml_manager.py diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index cbe01dace683b..64053b92a15c2 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -481,7 +481,11 @@ def create_job( job: V1Job, **kwargs, ) -> V1Job: - """Run Job""" + """ + Run Job. 
+ + :param job: A kubernetes Job object + """ sanitized_job = self.batch_v1_client.api_client.sanitize_for_serialization(job) json_job = json.dumps(sanitized_job, indent=2) diff --git a/airflow/providers/cncf/kubernetes/operators/job.py b/airflow/providers/cncf/kubernetes/operators/job.py index ffe9776c4f9d2..4f155205e4f74 100644 --- a/airflow/providers/cncf/kubernetes/operators/job.py +++ b/airflow/providers/cncf/kubernetes/operators/job.py @@ -17,6 +17,7 @@ """Executes a Kubernetes Job.""" from __future__ import annotations +import copy import logging import os from functools import cached_property @@ -31,9 +32,7 @@ create_unique_id, ) from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator -from airflow.providers.cncf.kubernetes.utils.k8s_yaml_manager import ( - reconcile_jobs, -) +from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, merge_objects from airflow.utils import yaml if TYPE_CHECKING: @@ -181,7 +180,7 @@ def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job: self.log.debug("Job template file found, will parse for base job") job_template = self.deserialize_job_template_file(self.job_template_file) if self.full_job_spec: - job_template = reconcile_jobs(job_template, self.full_job_spec) + job_template = self.reconcile_jobs(job_template, self.full_job_spec) elif self.full_job_spec: job_template = self.full_job_spec else: @@ -216,7 +215,7 @@ def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job: ), ) - job = reconcile_jobs(job_template, job) + job = self.reconcile_jobs(job_template, job) if not job.metadata.name: job.metadata.name = create_unique_id( @@ -236,3 +235,52 @@ def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job: self.log.info("Building job %s ", job.metadata.name) return job + + @staticmethod + def reconcile_jobs(base_job: k8s.V1Job, client_job: k8s.V1Job | None) -> k8s.V1Job: + """ + Merge Kubernetes Job objects. 
+ + :param base_job: has the base attributes which are overwritten if they exist + in the client job and remain if they do not exist in the client_job + :param client_job: the job that the client wants to create. + :return: the merged jobs + + This can't be done recursively as certain fields are overwritten and some are concatenated. + """ + if client_job is None: + return base_job + + client_job_cp = copy.deepcopy(client_job) + client_job_cp.spec = KubernetesJobOperator.reconcile_job_specs(base_job.spec, client_job_cp.spec) + client_job_cp.metadata = PodGenerator.reconcile_metadata(base_job.metadata, client_job_cp.metadata) + client_job_cp = merge_objects(base_job, client_job_cp) + + return client_job_cp + + @staticmethod + def reconcile_job_specs( + base_spec: k8s.V1JobSpec | None, client_spec: k8s.V1JobSpec | None + ) -> k8s.V1JobSpec | None: + """ + Merge Kubernetes JobSpec objects. + + :param base_spec: has the base attributes which are overwritten if they exist + in the client_spec and remain if they do not exist in the client_spec + :param client_spec: the spec that the client wants to create. 
+ :return: the merged specs + """ + if base_spec and not client_spec: + return base_spec + if not base_spec and client_spec: + return client_spec + elif client_spec and base_spec: + client_spec.template.spec = PodGenerator.reconcile_specs( + base_spec.template.spec, client_spec.template.spec + ) + client_spec.template.metadata = PodGenerator.reconcile_metadata( + base_spec.template.metadata, client_spec.template.metadata + ) + return merge_objects(base_spec, client_spec) + + return None diff --git a/airflow/providers/cncf/kubernetes/utils/k8s_yaml_manager.py b/airflow/providers/cncf/kubernetes/utils/k8s_yaml_manager.py deleted file mode 100644 index 63bd9d68fc341..0000000000000 --- a/airflow/providers/cncf/kubernetes/utils/k8s_yaml_manager.py +++ /dev/null @@ -1,242 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -K8s YAML Manager. - -This module provides a functions for working with K8s yaml. -""" -from __future__ import annotations - -import copy -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from kubernetes.client import models as k8s - - -def reconcile_pods(base_pod: k8s.V1Pod, client_pod: k8s.V1Pod | None) -> k8s.V1Pod: - """ - Merge Kubernetes Pod objects. 
- - :param base_pod: has the base attributes which are overwritten if they exist - in the client pod and remain if they do not exist in the client_pod - :param client_pod: the pod that the client wants to create. - :return: the merged pods - - This can't be done recursively as certain fields are overwritten and some are concatenated. - """ - if client_pod is None: - return base_pod - - client_pod_cp = copy.deepcopy(client_pod) - client_pod_cp.spec = reconcile_pod_specs(base_pod.spec, client_pod_cp.spec) - client_pod_cp.metadata = reconcile_metadata(base_pod.metadata, client_pod_cp.metadata) - client_pod_cp = merge_objects(base_pod, client_pod_cp) - - return client_pod_cp - - -def reconcile_pod_specs( - base_spec: k8s.V1PodSpec | None, client_spec: k8s.V1PodSpec | None -) -> k8s.V1PodSpec | None: - """ - Merge Kubernetes PodSpec objects. - - :param base_spec: has the base attributes which are overwritten if they exist - in the client_spec and remain if they do not exist in the client_spec - :param client_spec: the spec that the client wants to create. - :return: the merged specs - """ - if base_spec and not client_spec: - return base_spec - if not base_spec and client_spec: - return client_spec - elif client_spec and base_spec: - client_spec.containers = reconcile_containers(base_spec.containers, client_spec.containers) - merged_spec = extend_object_field(base_spec, client_spec, "init_containers") - merged_spec = extend_object_field(base_spec, merged_spec, "volumes") - return merge_objects(base_spec, merged_spec) - - return None - - -def reconcile_jobs(base_job: k8s.V1Job, client_job: k8s.V1Job | None) -> k8s.V1Job: - """ - Merge Kubernetes Job objects. - - :param base_job: has the base attributes which are overwritten if they exist - in the client job and remain if they do not exist in the client_job - :param client_job: the job that the client wants to create. 
- :return: the merged jobs - - This can't be done recursively as certain fields are overwritten and some are concatenated. - """ - if client_job is None: - return base_job - - client_job_cp = copy.deepcopy(client_job) - client_job_cp.spec = reconcile_job_specs(base_job.spec, client_job_cp.spec) - client_job_cp.metadata = reconcile_metadata(base_job.metadata, client_job_cp.metadata) - client_job_cp = merge_objects(base_job, client_job_cp) - - return client_job_cp - - -def reconcile_job_specs( - base_spec: k8s.V1JobSpec | None, client_spec: k8s.V1JobSpec | None -) -> k8s.V1JobSpec | None: - """ - Merge Kubernetes JobSpec objects. - - :param base_spec: has the base attributes which are overwritten if they exist - in the client_spec and remain if they do not exist in the client_spec - :param client_spec: the spec that the client wants to create. - :return: the merged specs - """ - if base_spec and not client_spec: - return base_spec - if not base_spec and client_spec: - return client_spec - elif client_spec and base_spec: - client_spec.template.spec = reconcile_pod_specs(base_spec.template.spec, client_spec.template.spec) - client_spec.template.metadata = reconcile_metadata( - base_spec.template.metadata, client_spec.template.metadata - ) - return merge_objects(base_spec, client_spec) - - return None - - -def reconcile_metadata(base_meta, client_meta): - """ - Merge Kubernetes Metadata objects. - - :param base_meta: has the base attributes which are overwritten if they exist - in the client_meta and remain if they do not exist in the client_meta - :param client_meta: the spec that the client wants to create. 
- :return: the merged specs - """ - if base_meta and not client_meta: - return base_meta - if not base_meta and client_meta: - return client_meta - elif client_meta and base_meta: - client_meta.labels = merge_objects(base_meta.labels, client_meta.labels) - client_meta.annotations = merge_objects(base_meta.annotations, client_meta.annotations) - extend_object_field(base_meta, client_meta, "managed_fields") - extend_object_field(base_meta, client_meta, "finalizers") - extend_object_field(base_meta, client_meta, "owner_references") - return merge_objects(base_meta, client_meta) - - return None - - -def reconcile_containers( - base_containers: list[k8s.V1Container], client_containers: list[k8s.V1Container] -) -> list[k8s.V1Container]: - """ - Merge Kubernetes Container objects. - - :param base_containers: has the base attributes which are overwritten if they exist - in the client_containers and remain if they do not exist in the client_containers - :param client_containers: the containers that the client wants to create. - :return: the merged containers - - The runs recursively over the list of containers. - """ - if not base_containers: - return client_containers - if not client_containers: - return base_containers - - client_container = client_containers[0] - base_container = base_containers[0] - client_container = extend_object_field(base_container, client_container, "volume_mounts") - client_container = extend_object_field(base_container, client_container, "env") - client_container = extend_object_field(base_container, client_container, "env_from") - client_container = extend_object_field(base_container, client_container, "ports") - client_container = extend_object_field(base_container, client_container, "volume_devices") - client_container = merge_objects(base_container, client_container) - - return [ - client_container, - *reconcile_containers(base_containers[1:], client_containers[1:]), - ] - - -def merge_objects(base_obj, client_obj): - """ - Merge objects. 
- - :param base_obj: has the base attributes which are overwritten if they exist - in the client_obj and remain if they do not exist in the client_obj - :param client_obj: the object that the client wants to create. - :return: the merged objects - """ - if not base_obj: - return client_obj - if not client_obj: - return base_obj - - client_obj_cp = copy.deepcopy(client_obj) - - if isinstance(base_obj, dict) and isinstance(client_obj_cp, dict): - base_obj_cp = copy.deepcopy(base_obj) - base_obj_cp.update(client_obj_cp) - return base_obj_cp - - for base_key in base_obj.to_dict(): - base_val = getattr(base_obj, base_key, None) - if not getattr(client_obj, base_key, None) and base_val: - if not isinstance(client_obj_cp, dict): - setattr(client_obj_cp, base_key, base_val) - else: - client_obj_cp[base_key] = base_val - return client_obj_cp - - -def extend_object_field(base_obj, client_obj, field_name): - """ - Add field values to existing objects. - - :param base_obj: an object which has a property `field_name` that is a list - :param client_obj: an object which has a property `field_name` that is a list. - A copy of this object is returned with `field_name` modified - :param field_name: the name of the list field - :return: the client_obj with the property `field_name` being the two properties appended - """ - client_obj_cp = copy.deepcopy(client_obj) - base_obj_field = getattr(base_obj, field_name, None) - client_obj_field = getattr(client_obj, field_name, None) - - if (not isinstance(base_obj_field, list) and base_obj_field is not None) or ( - not isinstance(client_obj_field, list) and client_obj_field is not None - ): - raise ValueError( - f"The chosen field must be a list. Got {type(base_obj_field)} base_object_field " - f"and {type(client_obj_field)} client_object_field." 
- ) - - if not base_obj_field: - return client_obj_cp - if not client_obj_field: - setattr(client_obj_cp, field_name, base_obj_field) - return client_obj_cp - - appended_fields = base_obj_field + client_obj_field - setattr(client_obj_cp, field_name, appended_fields) - return client_obj_cp diff --git a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst index 1c50f7a037cd7..e1089646d329c 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst @@ -192,6 +192,8 @@ lot less resources wasted on idle Operators or Sensors: :start-after: [START howto_operator_gke_start_pod_xcom_async] :end-before: [END howto_operator_gke_start_pod_xcom_async] +.. _howto/operator:GKEStartJobOperator: + Run a Job on a GKE cluster """""""""""""""""""""""""" diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index f04918bf3945a..a46c344851a47 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -860,6 +860,7 @@ jobflow jobId jobName JobRunning +JobSpec JobStatus jobtracker JobTrigger @@ -970,6 +971,7 @@ mailto makedirs makedsn Makefile +manualSelector mapred Mapreduce mapreduce @@ -1648,6 +1650,7 @@ truthy tsql tsv ttl +ttlSecondsAfterFinished TTY Tunables tunables From 4631cd57521006d6f7c751024dd2e75836809c5a Mon Sep 17 00:00:00 2001 From: Maksim Yermakou Date: Fri, 19 Jan 2024 11:38:38 +0000 Subject: [PATCH 4/8] Fix static checks --- airflow/providers/cncf/kubernetes/provider.yaml | 1 + .../operators/cloud/kubernetes_engine.rst | 1 + tests/providers/cncf/kubernetes/operators/test_job.py | 11 +++++++++++ .../google/cloud/operators/test_kubernetes_engine.py | 9 ++++++--- 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/provider.yaml b/airflow/providers/cncf/kubernetes/provider.yaml index 
80b9e512024d9..cf480963f178e 100644 --- a/airflow/providers/cncf/kubernetes/provider.yaml +++ b/airflow/providers/cncf/kubernetes/provider.yaml @@ -118,6 +118,7 @@ operators: - airflow.providers.cncf.kubernetes.operators.pod - airflow.providers.cncf.kubernetes.operators.spark_kubernetes - airflow.providers.cncf.kubernetes.operators.resource + - airflow.providers.cncf.kubernetes.operators.job sensors: - integration-name: Kubernetes diff --git a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst index e1089646d329c..b2893f12909c6 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst @@ -192,6 +192,7 @@ lot less resources wasted on idle Operators or Sensors: :start-after: [START howto_operator_gke_start_pod_xcom_async] :end-before: [END howto_operator_gke_start_pod_xcom_async] + .. 
_howto/operator:GKEStartJobOperator: Run a Job on a GKE cluster diff --git a/tests/providers/cncf/kubernetes/operators/test_job.py b/tests/providers/cncf/kubernetes/operators/test_job.py index e3e1e335ce1f8..2d0d4708eb40e 100644 --- a/tests/providers/cncf/kubernetes/operators/test_job.py +++ b/tests/providers/cncf/kubernetes/operators/test_job.py @@ -17,6 +17,7 @@ from __future__ import annotations import re +from unittest.mock import patch import pendulum import pytest @@ -29,6 +30,7 @@ from airflow.utils.types import DagRunType DEFAULT_DATE = timezone.datetime(2016, 1, 1, 1, 0, 0) +HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.job.KubernetesHook" def create_context(task, persist_to_db=False, map_index=None): @@ -64,6 +66,15 @@ def create_context(task, persist_to_db=False, map_index=None): @pytest.mark.execution_timeout(300) class TestKubernetesJobOperator: + @pytest.fixture(autouse=True) + def setup_tests(self): + self._default_client_patch = patch(f"{HOOK_CLASS}._get_default_client") + self._default_client_mock = self._default_client_patch.start() + + yield + + patch.stopall() + def test_templates(self, create_task_instance_of_operator): dag_id = "TestKubernetesJobOperator" ti = create_task_instance_of_operator( diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py index db1dfe1d5dd36..8267629a2316e 100644 --- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py @@ -668,7 +668,8 @@ def test_template_fields(self): @mock.patch(KUB_JOB_OPERATOR_EXEC) @mock.patch(TEMP_FILE) @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - def test_execute(self, fetch_cluster_info_mock, file_mock, exec_mock): + @mock.patch(GKE_HOOK_PATH) + def test_execute(self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock): fetch_cluster_info_mock.return_value = (CLUSTER_URL, 
SSL_CA_CERT) self.gke_op.execute(context=mock.MagicMock()) fetch_cluster_info_mock.assert_called_once() @@ -694,8 +695,9 @@ def test_config_file_throws_error(self): @mock.patch(KUB_JOB_OPERATOR_EXEC) @mock.patch(TEMP_FILE) @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") + @mock.patch(GKE_HOOK_PATH) def test_execute_with_impersonation_service_account( - self, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock + self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock ): fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) self.gke_op.impersonation_chain = "test_account@example.com" @@ -710,8 +712,9 @@ def test_execute_with_impersonation_service_account( @mock.patch(KUB_JOB_OPERATOR_EXEC) @mock.patch(TEMP_FILE) @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") + @mock.patch(GKE_HOOK_PATH) def test_execute_with_impersonation_service_chain_one_element( - self, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock + self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock ): fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) self.gke_op.impersonation_chain = ["test_account@example.com"] From d2f7446b5a8552fd73fe98fe534b54f6830e3064 Mon Sep 17 00:00:00 2001 From: Maksim Yermakou Date: Fri, 26 Jan 2024 14:46:57 +0000 Subject: [PATCH 5/8] Fix unit tests --- tests/providers/cncf/kubernetes/operators/test_job.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/providers/cncf/kubernetes/operators/test_job.py b/tests/providers/cncf/kubernetes/operators/test_job.py index 2d0d4708eb40e..5f4efe3c9657a 100644 --- a/tests/providers/cncf/kubernetes/operators/test_job.py +++ b/tests/providers/cncf/kubernetes/operators/test_job.py @@ -75,6 +75,7 @@ def setup_tests(self): patch.stopall() + @pytest.mark.db_test def test_templates(self, create_task_instance_of_operator): dag_id = "TestKubernetesJobOperator" ti = create_task_instance_of_operator( From 
a5f7b1a3b2f08503fbda33b3aa3418cba36b6dff Mon Sep 17 00:00:00 2001 From: Maksim Yermakou Date: Mon, 29 Jan 2024 12:55:41 +0000 Subject: [PATCH 6/8] Update docs --- docs/apache-airflow-providers-cncf-kubernetes/operators.rst | 4 ++-- .../operators/cloud/kubernetes_engine.rst | 1 - tests/providers/google/cloud/hooks/test_kubernetes_engine.py | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst index 020149c403e87..af0cc41087597 100644 --- a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst +++ b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst @@ -597,9 +597,9 @@ The :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperat you to create and run Jobs on a Kubernetes cluster. .. note:: - If you use a managed Kubernetes consider using a specialize KPO operator as it simplifies the Kubernetes authorization process : + If you use a managed Kubernetes, consider using a specialized KJO operator as it simplifies the Kubernetes authorization process: - - :ref:`GKEStartJobOperator ` operator for `Google Kubernetes Engine `__. + - ``GKEStartJobOperator`` operator for `Google Kubernetes Engine `__. .. note:: The :doc:`Kubernetes executor ` is **not** required to use this operator. diff --git a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst index b2893f12909c6..3b6e3cb316271 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst @@ -206,7 +206,6 @@ There are two operators available in order to run a job on a GKE cluster: ``GKEStartJobOperator`` extends ``KubernetesJobOperator`` to provide authorization using Google Cloud credentials.
There is no need to manage the ``kube_config`` file, as it will be generated automatically. All Kubernetes parameters (except ``config_file``) are also valid for the ``GKEStartJobOperator``. -For more information on ``KubernetesJobOperator``, please look at: :ref:`howto/operator:KubernetesJobOperator` guide. .. exampleinclude:: /../../tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py :language: python diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py index e7ca2a9744111..f00b7d4efd2b6 100644 --- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py @@ -742,4 +742,3 @@ def _get_config(self): def _get_credentials(self): return self.credentials - From 961380c88e84fd06930d9d8e840b441865adc15c Mon Sep 17 00:00:00 2001 From: Maksim Yermakou Date: Tue, 6 Feb 2024 14:31:18 +0000 Subject: [PATCH 7/8] Update documentation for KubernetesJobOperator --- .../operators.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst index af0cc41087597..fcfec977e7b45 100644 --- a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst +++ b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst @@ -618,4 +618,13 @@ to ``~/.kube/config``. It also allows users to supply a template YAML file using :start-after: [START howto_operator_k8s_job] :end-before: [END howto_operator_k8s_job] +Difference between ``KubernetesPodOperator`` and ``KubernetesJobOperator`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator` is an operator for creating Jobs.
+A Job creates one or more Pods and will continue to retry execution of the Pods until a specified number of them successfully terminate. +As Pods successfully complete, the Job tracks the successful completions. When a specified number of successful completions is reached, the Job is complete. +Users can limit how many times a Job retries execution using configuration parameters like ``activeDeadlineSeconds`` and ``backoffLimit``. +Instead of a ``template`` parameter for Pod creation, this operator uses :class:`~airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator`. +This means that users can use all parameters from :class:`~airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator` in :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator`. + More information about the Jobs here: `Kubernetes Job Documentation `__ From 8166ea9318a7a4e30ce7f99342d003540eefdd20 Mon Sep 17 00:00:00 2001 From: Maksim Yermakou Date: Thu, 15 Feb 2024 15:27:19 +0000 Subject: [PATCH 8/8] Fix static checks --- airflow/providers/google/cloud/operators/kubernetes_engine.py | 2 +- .../providers/google/cloud/operators/test_kubernetes_engine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py index 8142a4ed61691..204107c0235f2 100644 --- a/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -889,7 +889,7 @@ def hook(self) -> GKEJobHook: return hook def execute(self, context: Context): - """Executes process of creating Job.""" + """Execute process of creating Job.""" self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( cluster_name=self.cluster_name, project_id=self.project_id, diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py
b/tests/providers/google/cloud/operators/test_kubernetes_engine.py index 8267629a2316e..f610803a3463f 100644 --- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py @@ -34,8 +34,8 @@ from airflow.providers.google.cloud.operators.kubernetes_engine import ( GKECreateClusterOperator, GKEDeleteClusterOperator, - GKEStartKueueInsideClusterOperator, GKEStartJobOperator, + GKEStartKueueInsideClusterOperator, GKEStartPodOperator, ) from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEStartPodTrigger