From 05483b6be22375daf198fce89fba5ba60883ef90 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Fri, 27 Oct 2023 01:45:02 +0200 Subject: [PATCH 01/11] Add a public interface for custom weight_rule implementation --- airflow/models/abstractoperator.py | 18 +----- airflow/models/baseoperator.py | 9 +-- airflow/models/mappedoperator.py | 2 + airflow/task/priority_strategy.py | 90 ++++++++++++++++++++++++++++++ tests/models/test_dag.py | 17 ++++++ 5 files changed, 114 insertions(+), 22 deletions(-) create mode 100644 airflow/task/priority_strategy.py diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index aa79555557c2b..d58da84494be9 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -53,6 +53,7 @@ from airflow.models.mappedoperator import MappedOperator from airflow.models.operator import Operator from airflow.models.taskinstance import TaskInstance + from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.utils.task_group import TaskGroup DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner") @@ -99,6 +100,7 @@ class AbstractOperator(Templater, DAGNode): weight_rule: str priority_weight: int + _weight_strategy: PriorityWeightStrategy # Defines the operator level extra links. operator_extra_links: Collection[BaseOperatorLink] @@ -397,21 +399,7 @@ def priority_weight_total(self) -> int: - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks - WeightRule.UPSTREAM - adds priority weight of all upstream tasks """ - if self.weight_rule == WeightRule.ABSOLUTE: - return self.priority_weight - elif self.weight_rule == WeightRule.DOWNSTREAM: - upstream = False - elif self.weight_rule == WeightRule.UPSTREAM: - upstream = True - else: - upstream = False - dag = self.get_dag() - if dag is None: - return self.priority_weight - return self.priority_weight + sum( - dag.task_dict[task_id].priority_weight - for task_id in self.get_flat_relative_ids(upstream=upstream) - ) + return self._weight_strategy.get_weight(self) @cached_property def operator_extra_link_dict(self) -> dict[str, Any]: diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 0b05e6f6d8a44..0fbc195b5cabb 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -78,6 +78,7 @@ from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.taskmixin import DependencyMixin from airflow.serialization.enums import DagAttributeTypes +from airflow.task.priority_strategy import get_priority_weight_strategy from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep @@ -92,7 +93,6 @@ from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import NOTSET -from airflow.utils.weight_rule import WeightRule from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: @@ -899,13 +899,8 @@ def __init__( f"received '{type(priority_weight)}'." ) self.priority_weight = priority_weight - if not WeightRule.is_valid(weight_rule): - raise AirflowException( - f"The weight_rule must be one of " - f"{WeightRule.all_weight_rules},'{dag.dag_id if dag else ''}.{task_id}'; " - f"received '{weight_rule}'." - ) self.weight_rule = weight_rule + self._weight_strategy = get_priority_weight_strategy(weight_rule) self.resources = coerce_resources(resources) if task_concurrency and not max_active_tis_per_dag: # TODO: Remove in Airflow 3.0 diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index bba45968f01f3..a7787d95c3ce8 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -48,6 +48,7 @@ ) from airflow.models.pool import Pool from airflow.serialization.enums import DagAttributeTypes +from airflow.task.priority_strategy import get_priority_weight_strategy from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded from airflow.typing_compat import Literal from airflow.utils.context import context_update_for_unmapped @@ -328,6 +329,7 @@ def __attrs_post_init__(self): f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task " f"{self.task_id!r}." ) + self._weight_strategy = get_priority_weight_strategy(str(self.weight_rule)) @classmethod @cache diff --git a/airflow/task/priority_strategy.py b/airflow/task/priority_strategy.py new file mode 100644 index 0000000000000..e2668e0ce14b1 --- /dev/null +++ b/airflow/task/priority_strategy.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Priority weight strategies for task scheduling.""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from airflow.exceptions import AirflowException +from airflow.utils.module_loading import import_string + +if TYPE_CHECKING: + from airflow.models.abstractoperator import AbstractOperator + + +class PriorityWeightStrategy(ABC): + """Priority weight strategy interface.""" + + @abstractmethod + def get_weight(self, task: AbstractOperator): + """Get the priority weight of a task.""" + ... + + +class AbsolutePriorityWeightStrategy(PriorityWeightStrategy): + """Priority weight strategy that uses the task's priority weight directly.""" + + def get_weight(self, task: AbstractOperator): + return task.priority_weight + + +class DownstreamPriorityWeightStrategy(PriorityWeightStrategy): + """Priority weight strategy that uses the sum of the priority weights of all downstream tasks.""" + + def get_weight(self, task: AbstractOperator): + dag = task.get_dag() + if dag is None: + return task.priority_weight + return task.priority_weight + sum( + dag.task_dict[task_id].priority_weight for task_id in task.get_flat_relative_ids(upstream=False) + ) + + +class UpstreamPriorityWeightStrategy(PriorityWeightStrategy): + """Priority weight strategy that uses the sum of the priority weights of all upstream tasks.""" + + def get_weight(self, task: AbstractOperator): + dag = task.get_dag() + if dag is None: + return task.priority_weight + return task.priority_weight + sum( + dag.task_dict[task_id].priority_weight for task_id in task.get_flat_relative_ids(upstream=True) + ) + + +_airflow_priority_weight_strategies = { + "absolute": AbsolutePriorityWeightStrategy(), + "downstream": DownstreamPriorityWeightStrategy(), + "upstream": UpstreamPriorityWeightStrategy(), +} + + +def get_priority_weight_strategy(strategy_name: str) -> PriorityWeightStrategy: + """Get a priority weight strategy by name or class path.""" + if strategy_name not in _airflow_priority_weight_strategies: + try: + priority_strategy_class = import_string(strategy_name) + if not issubclass(priority_strategy_class, PriorityWeightStrategy): + raise AirflowException( + f"Priority strategy {priority_strategy_class} is not a subclass of PriorityWeightStrategy" + ) + _airflow_priority_weight_strategies[strategy_name] = priority_strategy_class() + except ImportError: + raise AirflowException(f"Unknown priority strategy {strategy_name}") + return _airflow_priority_weight_strategies[strategy_name] diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index d063077fa25dc..ada4f6836c01c 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -29,6 +29,7 @@ from datetime import timedelta from io import StringIO from pathlib import Path +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import patch @@ -69,6 +70,7 @@ from airflow.operators.python import PythonOperator from airflow.operators.subdag import SubDagOperator from airflow.security import permissions +from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.templates import NativeEnvironment, SandboxedEnvironment from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import ( @@ -93,6 +95,9 @@ from tests.test_utils.mapping import expand_mapped_task from tests.test_utils.timetables import cron_timetable, delta_timetable +if TYPE_CHECKING: + from airflow.models.abstractoperator import AbstractOperator + TEST_DATE = datetime_tz(2015, 1, 2, 0, 0) repo_root = Path(__file__).parents[2] @@ -114,6 +119,11 @@ def clear_datasets(): clear_db_datasets() +class TestPriorityWeightStrategy(PriorityWeightStrategy): + def get_weight(self, task: AbstractOperator): + return 99 + + class TestDag: def setup_method(self) -> None: clear_db_runs() @@ -428,6 +438,13 @@ def test_dag_task_invalid_weight_rule(self): with pytest.raises(AirflowException): EmptyOperator(task_id="should_fail", weight_rule="no rule") + def test_dag_task_custom_weight_strategy(self): + with DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}): + task = EmptyOperator( + task_id="empty_task", weight_rule="tests.models.test_dag.TestPriorityWeightStrategy" + ) + assert task.priority_weight_total == 99 + def test_get_num_task_instances(self): test_dag_id = "test_get_num_task_instances_dag" test_task_id = "task_1" From 44f513fbacb2f911e7f2d7c5613cf86866ca7f5f Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sun, 29 Oct 2023 00:20:22 +0200 Subject: [PATCH 02/11] Remove _weight_strategy attribute --- airflow/models/abstractoperator.py | 5 ++--- airflow/models/baseoperator.py | 3 ++- airflow/models/mappedoperator.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index d58da84494be9..70a9ca73a8b66 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -29,6 +29,7 @@ from airflow.exceptions import AirflowException from airflow.models.expandinput import NotFullyPopulated from airflow.models.taskmixin import DAGNode, DependencyMixin +from airflow.task.priority_strategy import get_priority_weight_strategy from airflow.template.templater import Templater from airflow.utils.context import Context from airflow.utils.db import exists_query @@ -53,7 +54,6 @@ from airflow.models.mappedoperator import MappedOperator from airflow.models.operator import Operator from airflow.models.taskinstance import TaskInstance - from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.utils.task_group import TaskGroup DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner") @@ -100,7 +100,6 @@ class AbstractOperator(Templater, DAGNode): weight_rule: str priority_weight: int - _weight_strategy: PriorityWeightStrategy # Defines the operator level extra links. operator_extra_links: Collection[BaseOperatorLink] @@ -399,7 +398,7 @@ def priority_weight_total(self) -> int: - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks - WeightRule.UPSTREAM - adds priority weight of all upstream tasks """ - return self._weight_strategy.get_weight(self) + return get_priority_weight_strategy(self.weight_rule).get_weight(self) @cached_property def operator_extra_link_dict(self) -> dict[str, Any]: diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 0fbc195b5cabb..7bde01feb63d5 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -900,7 +900,8 @@ def __init__( ) self.priority_weight = priority_weight self.weight_rule = weight_rule - self._weight_strategy = get_priority_weight_strategy(weight_rule) + # validate the priority weight strategy + get_priority_weight_strategy(weight_rule) self.resources = coerce_resources(resources) if task_concurrency and not max_active_tis_per_dag: # TODO: Remove in Airflow 3.0 diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index a7787d95c3ce8..316f88d11d7b2 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -329,7 +329,8 @@ def __attrs_post_init__(self): f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task " f"{self.task_id!r}." ) - self._weight_strategy = get_priority_weight_strategy(str(self.weight_rule)) + # validate the priority weight strategy + get_priority_weight_strategy(str(self.weight_rule)) @classmethod @cache From 701c9a2202eac1c25f3e03d85fb423e087866340 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sun, 29 Oct 2023 23:10:29 +0100 Subject: [PATCH 03/11] Move priority weight calculation to TI to support advanced strategies --- airflow/config_templates/config.yml | 7 + airflow/executors/base_executor.py | 2 +- airflow/executors/debug_executor.py | 2 +- ...0_add_priority_weight_strategy_to_task_.py | 48 + airflow/models/abstractoperator.py | 28 +- airflow/models/baseoperator.py | 18 +- airflow/models/mappedoperator.py | 5 + airflow/models/taskinstance.py | 26 +- .../apache/spark/hooks/spark_jdbc_script.py | 7 +- .../serialization/pydantic/taskinstance.py | 1 + airflow/task/priority_strategy.py | 29 +- airflow/utils/db.py | 2 +- docs/apache-airflow/img/airflow_erd.sha256 | 2 +- docs/apache-airflow/img/airflow_erd.svg | 2503 +++++++++-------- docs/apache-airflow/migrations-ref.rst | 4 +- tests/models/test_dag.py | 13 +- 16 files changed, 1414 insertions(+), 1283 deletions(-) create mode 100644 airflow/migrations/versions/0132_2_8_0_add_priority_weight_strategy_to_task_.py diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index ffb9dce073ad5..398afa593f437 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -309,6 +309,13 @@ core: type: string example: ~ default: "downstream" + default_task_priority_weight_strategy: + description: | + The strategy used for the effective total priority weight of the task + version_added: 2.8.0 + type: string + example: ~ + default: ~ default_task_execution_timeout: description: | The default task execution_timeout value for the operators. Expected an integer value to diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 2791c938a4f87..babfe8e9038c0 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -184,7 +184,7 @@ def queue_task_instance( self.queue_command( task_instance, command_list_to_run, - priority=task_instance.task.priority_weight_total, + priority=task_instance.priority_weight, queue=task_instance.task.queue, ) diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py index be2b657b7556e..b601c2b7c926f 100644 --- a/airflow/executors/debug_executor.py +++ b/airflow/executors/debug_executor.py @@ -109,7 +109,7 @@ def queue_task_instance( self.queue_command( task_instance, [str(task_instance)], # Just for better logging, it's not used anywhere - priority=task_instance.task.priority_weight_total, + priority=task_instance.priority_weight, queue=task_instance.task.queue, ) # Save params for TaskInstance._run_raw_task diff --git a/airflow/migrations/versions/0132_2_8_0_add_priority_weight_strategy_to_task_.py b/airflow/migrations/versions/0132_2_8_0_add_priority_weight_strategy_to_task_.py new file mode 100644 index 0000000000000..8b3d30ba7613a --- /dev/null +++ b/airflow/migrations/versions/0132_2_8_0_add_priority_weight_strategy_to_task_.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""add priority_weight_strategy to task_instance + +Revision ID: 624ecf3b6a5e +Revises: bd5dfbe21f88 +Create Date: 2023-10-29 02:01:34.774596 + +""" + +import sqlalchemy as sa +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "624ecf3b6a5e" +down_revision = "bd5dfbe21f88" +branch_labels = None +depends_on = None +airflow_version = "2.8.0" + + +def upgrade(): + """Apply add priority_weight_strategy to task_instance""" + with op.batch_alter_table("task_instance") as batch_op: + batch_op.add_column(sa.Column("priority_weight_strategy", sa.String(length=1000))) + + +def downgrade(): + """Unapply add priority_weight_strategy to task_instance""" + with op.batch_alter_table("task_instance") as batch_op: + batch_op.drop_column("priority_weight_strategy") diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index 70a9ca73a8b66..cc077b8b863b6 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -19,6 +19,7 @@ import datetime import inspect +import warnings from functools import cached_property from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence @@ -29,7 +30,6 @@ from airflow.exceptions import AirflowException from airflow.models.expandinput import NotFullyPopulated from airflow.models.taskmixin import DAGNode, DependencyMixin -from airflow.task.priority_strategy import get_priority_weight_strategy from airflow.template.templater import Templater from airflow.utils.context import Context from airflow.utils.db import exists_query @@ -73,6 +73,9 @@ DEFAULT_WEIGHT_RULE: WeightRule = WeightRule( conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM) ) +DEFAULT_PRIORITY_WEIGHT_STRATEGY: str | None = conf.get( + "core", "default_task_priority_weight_strategy", fallback=None +) DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta( "core", "default_task_execution_timeout" @@ -99,6 +102,7 @@ class AbstractOperator(Templater, DAGNode): operator_class: type[BaseOperator] | dict[str, Any] weight_rule: str + priority_weight_strategy: str priority_weight: int # Defines the operator level extra links. @@ -398,7 +402,27 @@ def priority_weight_total(self) -> int: - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks - WeightRule.UPSTREAM - adds priority weight of all upstream tasks """ - return get_priority_weight_strategy(self.weight_rule).get_weight(self) + warnings.warn( + "Accessing `priority_weight_total` from AbstractOperator instance is deprecated." + " Please use `priority_weight` from task instance instead.", + DeprecationWarning, + stacklevel=2, + ) + if self.weight_rule == WeightRule.ABSOLUTE: + return self.priority_weight + elif self.weight_rule == WeightRule.DOWNSTREAM: + upstream = False + elif self.weight_rule == WeightRule.UPSTREAM: + upstream = True + else: + upstream = False + dag = self.get_dag() + if dag is None: + return self.priority_weight + return self.priority_weight + sum( + dag.task_dict[task_id].priority_weight + for task_id in self.get_flat_relative_ids(upstream=upstream) + ) @cached_property def operator_extra_link_dict(self) -> dict[str, Any]: diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 7bde01feb63d5..ce860bb53a54f 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -63,6 +63,7 @@ DEFAULT_OWNER, DEFAULT_POOL_SLOTS, DEFAULT_PRIORITY_WEIGHT, + DEFAULT_PRIORITY_WEIGHT_STRATEGY, DEFAULT_QUEUE, DEFAULT_RETRIES, DEFAULT_RETRY_DELAY, @@ -211,6 +212,7 @@ def partial(**kwargs): "retry_exponential_backoff": False, "priority_weight": DEFAULT_PRIORITY_WEIGHT, "weight_rule": DEFAULT_WEIGHT_RULE, + "weight_strategy": DEFAULT_PRIORITY_WEIGHT_STRATEGY, "inlets": [], "outlets": [], } @@ -244,6 +246,7 @@ def partial( retry_exponential_backoff: bool | ArgNotSet = NOTSET, priority_weight: int | ArgNotSet = NOTSET, weight_rule: str | ArgNotSet = NOTSET, + priority_weight_strategy: str | ArgNotSet = NOTSET, sla: timedelta | None | ArgNotSet = NOTSET, max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET, @@ -306,6 +309,7 @@ def partial( "retry_exponential_backoff": retry_exponential_backoff, "priority_weight": priority_weight, "weight_rule": weight_rule, + "priority_weight_strategy": priority_weight_strategy, "sla": sla, "max_active_tis_per_dag": max_active_tis_per_dag, "max_active_tis_per_dagrun": max_active_tis_per_dagrun, @@ -568,6 +572,7 @@ class derived from this one results in the creation of a task object, significantly speeding up the task creation process as for very large DAGs. Options can be set as string or using the constants defined in the static class ``airflow.utils.WeightRule`` + :param priority_weight_strategy: TODO: add description :param queue: which queue to target when running this job. Not all executors implement queue management, the CeleryExecutor does support targeting specific queues. @@ -754,6 +759,7 @@ def __init__( default_args: dict | None = None, priority_weight: int = DEFAULT_PRIORITY_WEIGHT, weight_rule: str = DEFAULT_WEIGHT_RULE, + priority_weight_strategy: str | None = DEFAULT_PRIORITY_WEIGHT_STRATEGY, queue: str = DEFAULT_QUEUE, pool: str | None = None, pool_slots: int = DEFAULT_POOL_SLOTS, @@ -900,8 +906,18 @@ def __init__( ) self.priority_weight = priority_weight self.weight_rule = weight_rule + self.priority_weight_strategy: str + if not priority_weight_strategy: + warnings.warn( + "weight_rule is deprecated. Please use `priority_weight_strategy` instead.", + DeprecationWarning, + stacklevel=2, + ) + self.priority_weight_strategy = weight_rule + else: + self.priority_weight_strategy = priority_weight_strategy # validate the priority weight strategy - get_priority_weight_strategy(weight_rule) + get_priority_weight_strategy(self.priority_weight_strategy) self.resources = coerce_resources(resources) if task_concurrency and not max_active_tis_per_dag: # TODO: Remove in Airflow 3.0 diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 316f88d11d7b2..21797f6b476fe 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -32,6 +32,7 @@ DEFAULT_OWNER, DEFAULT_POOL_SLOTS, DEFAULT_PRIORITY_WEIGHT, + DEFAULT_PRIORITY_WEIGHT_STRATEGY, DEFAULT_QUEUE, DEFAULT_RETRIES, DEFAULT_RETRY_DELAY, @@ -476,6 +477,10 @@ def priority_weight(self) -> int: # type: ignore[override] def weight_rule(self) -> int: # type: ignore[override] return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) + @property + def priority_weight_strategy(self) -> str: # type: ignore[override] + return self.partial_kwargs.get("priority_weight_strategy", DEFAULT_PRIORITY_WEIGHT_STRATEGY) + @property def sla(self) -> datetime.timedelta | None: return self.partial_kwargs.get("sla") diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 816ca3bf6a3ec..6b096b7d363b4 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -98,6 +98,7 @@ from airflow.plugins_manager import integrate_macros_plugins from airflow.sentry import Sentry from airflow.stats import Stats +from airflow.task.priority_strategy import get_priority_weight_strategy from airflow.templates import SandboxedEnvironment from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS @@ -130,7 +131,6 @@ _CURRENT_CONTEXT: list[Context] = [] log = logging.getLogger(__name__) - if TYPE_CHECKING: from datetime import datetime from pathlib import PurePath @@ -158,7 +158,6 @@ else: from sqlalchemy.ext.hybrid import hybrid_property - PAST_DEPENDS_MET = "past_depends_met" @@ -486,6 +485,7 @@ def _refresh_from_db( task_instance.pool_slots = ti.pool_slots or 1 task_instance.queue = ti.queue task_instance.priority_weight = ti.priority_weight + task_instance.priority_weight_strategy = ti.priority_weight_strategy task_instance.operator = ti.operator task_instance.custom_operator_name = ti.custom_operator_name task_instance.queued_dttm = ti.queued_dttm @@ -873,7 +873,13 @@ def _refresh_from_task( task_instance.queue = task.queue task_instance.pool = pool_override or task.pool task_instance.pool_slots = task.pool_slots - task_instance.priority_weight = task.priority_weight_total + with contextlib.suppress(Exception): + # This method is called from the different places, and sometimes the TI is not fully initialized + task_instance.priority_weight = get_priority_weight_strategy( + task.priority_weight_strategy or str(task.weight_rule) + ).get_weight( + task_instance # type: ignore + ) task_instance.run_as_user = task.run_as_user # Do not set max_tries to task.retries here because max_tries is a cumulative # value that needs to be stored in the db. @@ -1208,6 +1214,7 @@ class TaskInstance(Base, LoggingMixin): pool_slots = Column(Integer, default=1, nullable=False) queue = Column(String(256)) priority_weight = Column(Integer) + priority_weight_strategy = Column(String(1000)) operator = Column(String(1000)) custom_operator_name = Column(String(1000)) queued_dttm = Column(UtcDateTime) @@ -1376,6 +1383,9 @@ def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any :meta private: """ + priority_weight = get_priority_weight_strategy( + task.priority_weight_strategy or str(task.weight_rule) + ).get_weight(TaskInstance(task=task, run_id=run_id, map_index=map_index)) return { "dag_id": task.dag_id, "task_id": task.task_id, @@ -1386,7 +1396,8 @@ def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any "queue": task.queue, "pool": task.pool, "pool_slots": task.pool_slots, - "priority_weight": task.priority_weight_total, + "priority_weight": priority_weight, + "priority_weight_strategy": task.priority_weight_strategy or task.weight_rule, "run_as_user": task.run_as_user, "max_tries": task.retries, "executor_config": task.executor_config, @@ -1444,6 +1455,10 @@ def operator_name(self) -> str | None: """@property: use a more friendly display name for the operator, if set.""" return self.custom_operator_name or self.operator + # @property + # def priority_weight_total(self) -> int: + # return get_priority_weight_strategy(self.priority_weight_strategy).get_weight(self) + def command_as_list( self, mark_success=False, @@ -3351,6 +3366,7 @@ def __init__( key: TaskInstanceKey, run_as_user: str | None = None, priority_weight: int | None = None, + priority_weight_strategy: str | None = None, ): self.dag_id = dag_id self.task_id = task_id @@ -3364,6 +3380,7 @@ def __init__( self.run_as_user = run_as_user self.pool = pool self.priority_weight = priority_weight + self.priority_weight_strategy = priority_weight_strategy self.queue = queue self.key = key @@ -3404,6 +3421,7 @@ def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance: key=ti.key, run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None, priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None, + priority_weight_strategy=ti.priority_weight_strategy, ) @classmethod diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py index d431782929e2e..63aa7d0799a89 100644 --- a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py +++ b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py @@ -18,9 +18,10 @@ from __future__ import annotations import argparse -from typing import Any +from typing import TYPE_CHECKING, Any -from pyspark.sql import SparkSession +if TYPE_CHECKING: + from pyspark.sql import SparkSession SPARK_WRITE_TO_JDBC: str = "spark_to_jdbc" SPARK_READ_FROM_JDBC: str = "jdbc_to_spark" @@ -146,6 +147,8 @@ def _parse_arguments(args: list[str] | None = None) -> Any: def _create_spark_session(arguments: Any) -> SparkSession: + from pyspark.sql import SparkSession + return SparkSession.builder.appName(arguments.name).enableHiveSupport().getOrCreate() diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index 7e86c7f51f0c8..ea74220edceed 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -84,6 +84,7 @@ class TaskInstancePydantic(BaseModelPydantic): pool_slots: int queue: str priority_weight: Optional[int] + priority_weight_strategy: Optional[str] operator: str custom_operator_name: Optional[str] queued_dttm: Optional[str] diff --git a/airflow/task/priority_strategy.py b/airflow/task/priority_strategy.py index e2668e0ce14b1..6e061ad7069f6 100644 --- a/airflow/task/priority_strategy.py +++ b/airflow/task/priority_strategy.py @@ -25,14 +25,14 @@ from airflow.utils.module_loading import import_string if TYPE_CHECKING: - from airflow.models.abstractoperator import AbstractOperator + from airflow.models.taskinstance import TaskInstance class PriorityWeightStrategy(ABC): """Priority weight strategy interface.""" @abstractmethod - def get_weight(self, task: AbstractOperator): + def get_weight(self, ti: TaskInstance): """Get the priority weight of a task.""" ... @@ -40,31 +40,32 @@ def get_weight(self, task: AbstractOperator): class AbsolutePriorityWeightStrategy(PriorityWeightStrategy): """Priority weight strategy that uses the task's priority weight directly.""" - def get_weight(self, task: AbstractOperator): - return task.priority_weight + def get_weight(self, ti: TaskInstance): + return ti.task.priority_weight class DownstreamPriorityWeightStrategy(PriorityWeightStrategy): """Priority weight strategy that uses the sum of the priority weights of all downstream tasks.""" - def get_weight(self, task: AbstractOperator): - dag = task.get_dag() + def get_weight(self, ti: TaskInstance): + dag = ti.task.get_dag() if dag is None: - return task.priority_weight - return task.priority_weight + sum( - dag.task_dict[task_id].priority_weight for task_id in task.get_flat_relative_ids(upstream=False) + return ti.task.priority_weight + return ti.task.priority_weight + sum( + dag.task_dict[task_id].priority_weight + for task_id in ti.task.get_flat_relative_ids(upstream=False) ) class UpstreamPriorityWeightStrategy(PriorityWeightStrategy): """Priority weight strategy that uses the sum of the priority weights of all upstream tasks.""" - def get_weight(self, task: AbstractOperator): - dag = task.get_dag() + def get_weight(self, ti: TaskInstance): + dag = ti.task.get_dag() if dag is None: - return task.priority_weight - return task.priority_weight + sum( - dag.task_dict[task_id].priority_weight for task_id in task.get_flat_relative_ids(upstream=True) + return ti.task.priority_weight + return ti.task.priority_weight + sum( + dag.task_dict[task_id].priority_weight for task_id in ti.task.get_flat_relative_ids(upstream=True) ) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index b87edce827c1b..a67d640347f6d 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -88,7 +88,7 @@ "2.6.0": "98ae134e6fff", "2.6.2": "c804e5c76e3e", "2.7.0": "405de8318b3a", - "2.8.0": "bd5dfbe21f88", + "2.8.0": "624ecf3b6a5e", } diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index 301abf8a84311..b2d9dbf5a2e14 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -8229a936107bee851d6a39c791b842b11f295ffa308b18106e45298a50871493 \ No newline at end of file +4739d87664d779f93e39b09ca6e5e662d72f1fa88857d8b6e44d2f2557656753 \ No newline at end of file diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg index a0cfc1866cbec..402bf123e5272 100644 --- a/docs/apache-airflow/img/airflow_erd.svg +++ b/docs/apache-airflow/img/airflow_erd.svg @@ -4,1632 +4,1635 @@ - - + + %3 - + ab_permission - -ab_permission - -id - [INTEGER] - NOT NULL - -name - [VARCHAR(100)] - NOT NULL + +ab_permission + +id + [INTEGER] + NOT NULL + +name + [VARCHAR(100)] + NOT NULL ab_permission_view - -ab_permission_view - -id - [INTEGER] - NOT NULL - -permission_id - [INTEGER] - -view_menu_id - [INTEGER] + +ab_permission_view + +id + [INTEGER] + NOT NULL + +permission_id + [INTEGER] + +view_menu_id + [INTEGER] ab_permission--ab_permission_view - -0..N -{0,1} + +0..N +{0,1} ab_permission_view_role - -ab_permission_view_role - -id - [INTEGER] - NOT NULL - -permission_view_id - [INTEGER] - -role_id - [INTEGER] + +ab_permission_view_role + +id + [INTEGER] + NOT NULL + +permission_view_id + [INTEGER] + +role_id + [INTEGER] ab_permission_view--ab_permission_view_role - -0..N -{0,1} + +0..N +{0,1} ab_view_menu - -ab_view_menu - -id - [INTEGER] - NOT NULL - -name - [VARCHAR(250)] - NOT NULL + +ab_view_menu + +id + [INTEGER] + NOT NULL + +name + [VARCHAR(250)] + NOT NULL ab_view_menu--ab_permission_view - -0..N -{0,1} + +0..N +{0,1} ab_role - -ab_role - -id - [INTEGER] - NOT NULL - -name - [VARCHAR(64)] - NOT NULL + +ab_role + +id + [INTEGER] + NOT NULL + +name + [VARCHAR(64)] + NOT NULL ab_role--ab_permission_view_role - -0..N -{0,1} + +0..N +{0,1} ab_user_role - -ab_user_role - -id - [INTEGER] - NOT NULL - -role_id - [INTEGER] - -user_id - [INTEGER] + +ab_user_role + +id + [INTEGER] + NOT NULL + +role_id + [INTEGER] + +user_id + [INTEGER] ab_role--ab_user_role - -0..N -{0,1} + +0..N +{0,1} ab_register_user - -ab_register_user - -id - [INTEGER] - NOT NULL - -email - [VARCHAR(512)] - NOT NULL - -first_name - [VARCHAR(256)] - NOT NULL - -last_name - [VARCHAR(256)] - NOT NULL - -password - [VARCHAR(256)] - -registration_date - [DATETIME] - -registration_hash - [VARCHAR(256)] - -username - [VARCHAR(512)] - NOT NULL + +ab_register_user + +id + [INTEGER] + NOT NULL + +email + [VARCHAR(512)] + NOT NULL + +first_name + [VARCHAR(256)] + NOT NULL + +last_name + [VARCHAR(256)] + NOT NULL + +password + [VARCHAR(256)] + +registration_date + [DATETIME] + +registration_hash + [VARCHAR(256)] + +username + [VARCHAR(512)] + NOT NULL ab_user - -ab_user - -id - [INTEGER] - NOT NULL - -active - [BOOLEAN] - -changed_by_fk - [INTEGER] - -changed_on - [DATETIME] - -created_by_fk - [INTEGER] - -created_on - [DATETIME] - -email - [VARCHAR(512)] - NOT NULL - -fail_login_count - [INTEGER] - -first_name - [VARCHAR(256)] - NOT NULL - -last_login - [DATETIME] - -last_name - [VARCHAR(256)] - NOT NULL - -login_count - [INTEGER] - -password - [VARCHAR(256)] - -username - [VARCHAR(512)] - NOT NULL + +ab_user + +id + [INTEGER] + NOT NULL + +active + [BOOLEAN] + +changed_by_fk + [INTEGER] + +changed_on + [DATETIME] + +created_by_fk + [INTEGER] + +created_on + [DATETIME] + +email + [VARCHAR(512)] + NOT NULL + +fail_login_count + [INTEGER] + +first_name + [VARCHAR(256)] + NOT NULL + +last_login + [DATETIME] + +last_name + [VARCHAR(256)] + NOT NULL + +login_count + [INTEGER] + +password + [VARCHAR(256)] + +username + [VARCHAR(512)] + NOT NULL ab_user--ab_user_role - -0..N -{0,1} + +0..N +{0,1} ab_user--ab_user - -0..N -{0,1} + +0..N +{0,1} ab_user--ab_user - -0..N -{0,1} + +0..N +{0,1} dag_run_note - -dag_run_note - -dag_run_id - [INTEGER] - NOT NULL - -content - [VARCHAR(1000)] - -created_at - [TIMESTAMP] - NOT NULL - -updated_at - [TIMESTAMP] - NOT NULL - -user_id - [INTEGER] + +dag_run_note + +dag_run_id + [INTEGER] + NOT NULL + +content + [VARCHAR(1000)] + +created_at + [TIMESTAMP] + NOT NULL + +updated_at + [TIMESTAMP] + NOT NULL + +user_id + [INTEGER] ab_user--dag_run_note - -0..N -{0,1} + +0..N +{0,1} task_instance_note - -task_instance_note - -dag_id - [VARCHAR(250)] - NOT NULL - -map_index - [INTEGER] - NOT NULL - -run_id - [VARCHAR(250)] - NOT NULL - -task_id - [VARCHAR(250)] - NOT NULL - -content - [VARCHAR(1000)] - -created_at - [TIMESTAMP] - NOT NULL - -updated_at - [TIMESTAMP] - NOT NULL - -user_id - [INTEGER] + +task_instance_note + +dag_id + [VARCHAR(250)] + NOT NULL + +map_index + [INTEGER] + NOT NULL + +run_id + [VARCHAR(250)] + NOT NULL + +task_id + [VARCHAR(250)] + NOT NULL + +content + [VARCHAR(1000)] + +created_at + [TIMESTAMP] + NOT NULL + +updated_at + [TIMESTAMP] + NOT NULL + +user_id + [INTEGER] ab_user--task_instance_note - -0..N -{0,1} + +0..N +{0,1} alembic_version - -alembic_version - -version_num - [VARCHAR(32)] - NOT NULL + +alembic_version + +version_num + [VARCHAR(32)] + NOT NULL callback_request - -callback_request - -id - [INTEGER] - NOT NULL - -callback_data - [JSON] - NOT NULL - -callback_type - [VARCHAR(20)] - NOT NULL - -created_at - [TIMESTAMP] - NOT NULL - -priority_weight - [INTEGER] - NOT NULL - -processor_subdir - [VARCHAR(2000)] + +callback_request + +id + [INTEGER] + NOT NULL + +callback_data + [JSON] + NOT NULL + +callback_type + [VARCHAR(20)] + NOT NULL + +created_at + [TIMESTAMP] + NOT NULL + +priority_weight + [INTEGER] + NOT NULL + +processor_subdir + [VARCHAR(2000)] connection - -connection - -id - [INTEGER] - NOT NULL - -conn_id - [VARCHAR(250)] - NOT NULL - -conn_type - [VARCHAR(500)] - NOT NULL - -description - [VARCHAR(5000)] - -extra - [TEXT] - -host - [VARCHAR(500)] - -is_encrypted - [BOOLEAN] - -is_extra_encrypted - [BOOLEAN] - -login - [TEXT] - -password - [TEXT] - -port - [INTEGER] - -schema - [VARCHAR(500)] + +connection + +id + [INTEGER] + NOT NULL + +conn_id + [VARCHAR(250)] + NOT NULL + +conn_type + [VARCHAR(500)] + NOT NULL + +description + [VARCHAR(5000)] + +extra + [TEXT] + +host + [VARCHAR(500)] + +is_encrypted + [BOOLEAN] + +is_extra_encrypted + [BOOLEAN] + +login + [TEXT] + +password + [TEXT] + +port + [INTEGER] + +schema + [VARCHAR(500)] dag - -dag - -dag_id - [VARCHAR(250)] - NOT NULL - -default_view - [VARCHAR(25)] - -description - [TEXT] - -fileloc - [VARCHAR(2000)] - -has_import_errors - [BOOLEAN] - -has_task_concurrency_limits - [BOOLEAN] - NOT NULL - -is_active - [BOOLEAN] - -is_paused - [BOOLEAN] - -is_subdag - [BOOLEAN] - -last_expired - [TIMESTAMP] - -last_parsed_time - [TIMESTAMP] - -last_pickled - [TIMESTAMP] - -max_active_runs - [INTEGER] - -max_active_tasks - [INTEGER] - NOT NULL - -next_dagrun - [TIMESTAMP] - -next_dagrun_create_after - [TIMESTAMP] - -next_dagrun_data_interval_end - [TIMESTAMP] - -next_dagrun_data_interval_start - [TIMESTAMP] - -owners - [VARCHAR(2000)] - -pickle_id - [INTEGER] - -processor_subdir - [VARCHAR(2000)] - -root_dag_id - [VARCHAR(250)] - -schedule_interval - [TEXT] - -scheduler_lock - [BOOLEAN] - -timetable_description - [VARCHAR(1000)] + +dag + +dag_id + [VARCHAR(250)] + NOT NULL + +default_view + [VARCHAR(25)] + +description + [TEXT] + +fileloc + [VARCHAR(2000)] + +has_import_errors + [BOOLEAN] + +has_task_concurrency_limits + [BOOLEAN] + NOT NULL + +is_active + [BOOLEAN] + +is_paused + [BOOLEAN] + +is_subdag + [BOOLEAN] + +last_expired + [TIMESTAMP] + +last_parsed_time + [TIMESTAMP] + +last_pickled + [TIMESTAMP] + +max_active_runs + [INTEGER] + +max_active_tasks + [INTEGER] + NOT NULL + +next_dagrun + [TIMESTAMP] + +next_dagrun_create_after + [TIMESTAMP] + +next_dagrun_data_interval_end + [TIMESTAMP] + +next_dagrun_data_interval_start + [TIMESTAMP] + +owners + [VARCHAR(2000)] + +pickle_id + [INTEGER] + +processor_subdir + [VARCHAR(2000)] + +root_dag_id + [VARCHAR(250)] + +schedule_interval + [TEXT] + +scheduler_lock + [BOOLEAN] + +timetable_description + [VARCHAR(1000)] dag_owner_attributes - -dag_owner_attributes - -dag_id - [VARCHAR(250)] - NOT NULL - -owner - [VARCHAR(500)] - NOT NULL - -link - [VARCHAR(500)] - NOT NULL + +dag_owner_attributes + +dag_id + [VARCHAR(250)] + NOT NULL + +owner + [VARCHAR(500)] + NOT NULL + +link + [VARCHAR(500)] + NOT NULL dag--dag_owner_attributes - -1 -1 + +1 +1 dag_schedule_dataset_reference - -dag_schedule_dataset_reference - -dag_id - [VARCHAR(250)] - NOT NULL - -dataset_id - [INTEGER] - NOT NULL - -created_at - [TIMESTAMP] - NOT NULL - -updated_at - [TIMESTAMP] - NOT NULL + +dag_schedule_dataset_reference + +dag_id + [VARCHAR(250)] + NOT NULL + +dataset_id + [INTEGER] + NOT NULL + +created_at + [TIMESTAMP] + NOT NULL + +updated_at + [TIMESTAMP] + NOT NULL dag--dag_schedule_dataset_reference - -1 -1 + +1 +1 dag_tag - -dag_tag - -dag_id - [VARCHAR(250)] - NOT NULL - -name - [VARCHAR(100)] - NOT NULL + +dag_tag + +dag_id + [VARCHAR(250)] + NOT NULL + +name + [VARCHAR(100)] + NOT NULL dag--dag_tag - -1 -1 + +1 +1 dag_warning - -dag_warning - -dag_id - [VARCHAR(250)] - NOT NULL - -warning_type - [VARCHAR(50)] - NOT NULL - -message - [TEXT] - NOT NULL - -timestamp - [TIMESTAMP] - NOT NULL + +dag_warning + +dag_id + [VARCHAR(250)] + NOT NULL + +warning_type + [VARCHAR(50)] + NOT NULL + +message + [TEXT] + NOT NULL + +timestamp + [TIMESTAMP] + NOT NULL dag--dag_warning - -1 -1 + +1 +1 dataset_dag_run_queue - -dataset_dag_run_queue - -dataset_id - [INTEGER] - NOT NULL - -target_dag_id - [VARCHAR(250)] - NOT NULL - -created_at - [TIMESTAMP] - NOT NULL + +dataset_dag_run_queue + +dataset_id + [INTEGER] + NOT NULL + +target_dag_id + [VARCHAR(250)] + NOT NULL + +created_at + [TIMESTAMP] + NOT NULL dag--dataset_dag_run_queue - -1 -1 + +1 +1 task_outlet_dataset_reference - -task_outlet_dataset_reference - -dag_id - [VARCHAR(250)] - NOT NULL - -dataset_id - [INTEGER] - NOT NULL - -task_id - [VARCHAR(250)] - NOT NULL - -created_at - [TIMESTAMP] - NOT NULL - -updated_at - [TIMESTAMP] - NOT NULL + +task_outlet_dataset_reference + +dag_id + [VARCHAR(250)] + NOT NULL + +dataset_id + [INTEGER] + NOT NULL + +task_id + [VARCHAR(250)] + NOT NULL + +created_at + [TIMESTAMP] + NOT NULL + +updated_at + [TIMESTAMP] + NOT NULL dag--task_outlet_dataset_reference - -1 -1 + +1 +1 dag_code - -dag_code - -fileloc_hash - [BIGINT] - NOT NULL - -fileloc - [VARCHAR(2000)] - NOT NULL - -last_updated - [TIMESTAMP] - NOT NULL - -source_code - [TEXT] - NOT NULL + +dag_code + +fileloc_hash + [BIGINT] + NOT NULL + +fileloc + [VARCHAR(2000)] + NOT NULL + +last_updated + [TIMESTAMP] + NOT NULL + +source_code + [TEXT] + NOT NULL dag_pickle - -dag_pickle - -id - [INTEGER] - NOT NULL - -created_dttm - [TIMESTAMP] - -pickle - [BLOB] - -pickle_hash - [BIGINT] + +dag_pickle + +id + [INTEGER] + NOT NULL + +created_dttm + [TIMESTAMP] + +pickle + [BLOB] + +pickle_hash + [BIGINT] dag_run - -dag_run - -id - [INTEGER] - NOT NULL - -clear_number - [INTEGER] - NOT NULL - -conf - [BLOB] - -creating_job_id - [INTEGER] - -dag_hash - [VARCHAR(32)] - -dag_id - [VARCHAR(250)] - NOT NULL - -data_interval_end - [TIMESTAMP] - -data_interval_start - [TIMESTAMP] - -end_date - [TIMESTAMP] - -execution_date - [TIMESTAMP] - NOT NULL - -external_trigger - [BOOLEAN] - -last_scheduling_decision - [TIMESTAMP] - -log_template_id - [INTEGER] - -queued_at - [TIMESTAMP] - -run_id - [VARCHAR(250)] - NOT NULL - -run_type - [VARCHAR(50)] - NOT NULL - -start_date - [TIMESTAMP] - -state - [VARCHAR(50)] - -updated_at - [TIMESTAMP] + +dag_run + +id + [INTEGER] + NOT NULL + +clear_number + [INTEGER] + NOT NULL + +conf + [BLOB] + +creating_job_id + [INTEGER] + +dag_hash + [VARCHAR(32)] + +dag_id + [VARCHAR(250)] + NOT NULL + +data_interval_end + [TIMESTAMP] + +data_interval_start + [TIMESTAMP] + +end_date + [TIMESTAMP] + +execution_date + [TIMESTAMP] + NOT NULL + +external_trigger + [BOOLEAN] + +last_scheduling_decision + [TIMESTAMP] + +log_template_id + [INTEGER] + +queued_at + [TIMESTAMP] + +run_id + [VARCHAR(250)] + NOT NULL + +run_type + [VARCHAR(50)] + NOT NULL + +start_date + [TIMESTAMP] + +state + [VARCHAR(50)] + +updated_at + [TIMESTAMP] dag_run--dag_run_note - -1 -1 + +1 +1 dagrun_dataset_event - -dagrun_dataset_event - -dag_run_id - [INTEGER] - NOT NULL - -event_id - [INTEGER] - NOT NULL + +dagrun_dataset_event + +dag_run_id + [INTEGER] + NOT NULL + +event_id + [INTEGER] + NOT NULL dag_run--dagrun_dataset_event - -1 -1 + +1 +1 task_instance - -task_instance - -dag_id - [VARCHAR(250)] - NOT NULL - -map_index - [INTEGER] - NOT NULL - -run_id - [VARCHAR(250)] - NOT NULL - -task_id - [VARCHAR(250)] - NOT NULL - -custom_operator_name - [VARCHAR(1000)] - -duration - [FLOAT] - -end_date - [TIMESTAMP] - -executor_config - [BLOB] - -external_executor_id - [VARCHAR(250)] - -hostname - [VARCHAR(1000)] - -job_id - [INTEGER] - -max_tries - [INTEGER] - -next_kwargs - [JSON] - -next_method - [VARCHAR(1000)] - -operator - [VARCHAR(1000)] - -pid - [INTEGER] - -pool - [VARCHAR(256)] - NOT NULL - -pool_slots - [INTEGER] - NOT NULL - -priority_weight - [INTEGER] - -queue - [VARCHAR(256)] - -queued_by_job_id - [INTEGER] - -queued_dttm - [TIMESTAMP] - -start_date - [TIMESTAMP] - -state - [VARCHAR(20)] - -trigger_id - [INTEGER] - -trigger_timeout - [DATETIME] - -try_number - [INTEGER] - -unixname - [VARCHAR(1000)] - -updated_at - [TIMESTAMP] + +task_instance + +dag_id + [VARCHAR(250)] + NOT NULL + +map_index + [INTEGER] + NOT NULL + +run_id + [VARCHAR(250)] + NOT NULL + +task_id + [VARCHAR(250)] + NOT NULL + +custom_operator_name + [VARCHAR(1000)] + +duration + [FLOAT] + +end_date + [TIMESTAMP] + +executor_config + [BLOB] + +external_executor_id + [VARCHAR(250)] + +hostname + [VARCHAR(1000)] + +job_id + [INTEGER] + +max_tries + [INTEGER] + +next_kwargs + [JSON] + +next_method + [VARCHAR(1000)] + +operator + [VARCHAR(1000)] + +pid + [INTEGER] + +pool + [VARCHAR(256)] + NOT NULL + +pool_slots + [INTEGER] + NOT NULL + +priority_weight + [INTEGER] + +priority_weight_strategy + [VARCHAR(1000)] + +queue + [VARCHAR(256)] + +queued_by_job_id + [INTEGER] + +queued_dttm + [TIMESTAMP] + +start_date + [TIMESTAMP] + +state + [VARCHAR(20)] + +trigger_id + [INTEGER] + +trigger_timeout + [DATETIME] + +try_number + [INTEGER] + +unixname + [VARCHAR(1000)] + +updated_at + [TIMESTAMP] dag_run--task_instance - -1 -1 + +1 +1 dag_run--task_instance - -1 -1 + +1 +1 task_reschedule - -task_reschedule - -id - [INTEGER] - NOT NULL - -dag_id - [VARCHAR(250)] - NOT NULL - -duration - [INTEGER] - NOT NULL - -end_date - [TIMESTAMP] - NOT NULL - -map_index - [INTEGER] - NOT NULL - -reschedule_date - [TIMESTAMP] - NOT NULL - -run_id - [VARCHAR(250)] - NOT NULL - -start_date - [TIMESTAMP] - NOT NULL - -task_id - [VARCHAR(250)] - NOT NULL - -try_number - [INTEGER] - NOT NULL + +task_reschedule + +id + [INTEGER] + NOT NULL + +dag_id + [VARCHAR(250)] + NOT NULL + +duration + [INTEGER] + NOT NULL + +end_date + [TIMESTAMP] + NOT NULL + +map_index + [INTEGER] + NOT NULL + +reschedule_date + [TIMESTAMP] + NOT NULL + +run_id + [VARCHAR(250)] + NOT NULL + +start_date + [TIMESTAMP] + NOT NULL + +task_id + [VARCHAR(250)] + NOT NULL + +try_number + [INTEGER] + NOT NULL dag_run--task_reschedule - -0..N -1 + +0..N +1 dag_run--task_reschedule - -0..N -1 + +0..N +1 task_instance--task_instance_note - -1 -1 + +1 +1 task_instance--task_instance_note - -1 -1 + +1 +1 task_instance--task_instance_note - -1 -1 + +1 +1 task_instance--task_instance_note - -1 -1 + +1 +1 task_instance--task_reschedule - -0..N -1 + +0..N +1 task_instance--task_reschedule - -0..N -1 + +0..N +1 task_instance--task_reschedule - -0..N -1 + +0..N +1 task_instance--task_reschedule - -0..N -1 + +0..N +1 rendered_task_instance_fields - -rendered_task_instance_fields - -dag_id - [VARCHAR(250)] - NOT NULL - -map_index - [INTEGER] - NOT NULL - -run_id - [VARCHAR(250)] - NOT NULL - -task_id - [VARCHAR(250)] - NOT NULL - -k8s_pod_yaml - [JSON] - -rendered_fields - [JSON] - NOT NULL + +rendered_task_instance_fields + +dag_id + [VARCHAR(250)] + NOT NULL + +map_index + [INTEGER] + NOT NULL + +run_id + [VARCHAR(250)] + NOT NULL + +task_id + [VARCHAR(250)] + NOT NULL + +k8s_pod_yaml + [JSON] + +rendered_fields + [JSON] + NOT NULL task_instance--rendered_task_instance_fields - -1 -1 + +1 +1 task_instance--rendered_task_instance_fields - -1 -1 + +1 +1 task_instance--rendered_task_instance_fields - -1 -1 + +1 +1 task_instance--rendered_task_instance_fields - -1 -1 + +1 +1 task_fail - -task_fail - -id - [INTEGER] - NOT NULL - -dag_id - [VARCHAR(250)] - NOT NULL - -duration - [INTEGER] - -end_date - [TIMESTAMP] - -map_index - [INTEGER] - NOT NULL - -run_id - [VARCHAR(250)] - NOT NULL - -start_date - [TIMESTAMP] - -task_id - [VARCHAR(250)] - NOT NULL + +task_fail + +id + [INTEGER] + NOT NULL + +dag_id + [VARCHAR(250)] + NOT NULL + +duration + [INTEGER] + +end_date + [TIMESTAMP] + +map_index + [INTEGER] + NOT NULL + +run_id + [VARCHAR(250)] + NOT NULL + +start_date + [TIMESTAMP] + +task_id + [VARCHAR(250)] + NOT NULL task_instance--task_fail - -0..N -1 + +0..N +1 task_instance--task_fail - -0..N -1 + +0..N +1 task_instance--task_fail - -0..N -1 + +0..N +1 task_instance--task_fail - -0..N -1 + +0..N +1 task_map - -task_map - -dag_id - [VARCHAR(250)] - NOT NULL - -map_index - [INTEGER] - NOT NULL - -run_id - [VARCHAR(250)] - NOT NULL - -task_id - [VARCHAR(250)] - NOT NULL - -keys - [JSON] - -length - [INTEGER] - NOT NULL + +task_map + +dag_id + [VARCHAR(250)] + NOT NULL + +map_index + [INTEGER] + NOT NULL + +run_id + [VARCHAR(250)] + NOT NULL + +task_id + [VARCHAR(250)] + NOT NULL + +keys + [JSON] + +length + [INTEGER] + NOT NULL task_instance--task_map - -1 -1 + +1 +1 task_instance--task_map - -1 -1 + +1 +1 task_instance--task_map - -1 -1 + +1 +1 task_instance--task_map - -1 -1 + +1 +1 xcom - -xcom - -dag_run_id - [INTEGER] - NOT NULL - -key - [VARCHAR(512)] - NOT NULL - -map_index - [INTEGER] - NOT NULL - -task_id - [VARCHAR(250)] - NOT NULL - -dag_id - [VARCHAR(250)] - NOT NULL - -run_id - [VARCHAR(250)] - NOT NULL - -timestamp - [TIMESTAMP] - NOT NULL - -value - [BLOB] + +xcom + +dag_run_id + [INTEGER] + NOT NULL + +key + [VARCHAR(512)] + NOT NULL + +map_index + [INTEGER] + NOT NULL + +task_id + [VARCHAR(250)] + NOT NULL + +dag_id + [VARCHAR(250)] + NOT NULL + +run_id + [VARCHAR(250)] + NOT NULL + +timestamp + [TIMESTAMP] + NOT NULL + +value + [BLOB] task_instance--xcom - -1 -1 + +0..N +1 task_instance--xcom - -0..N -1 + +1 +1 task_instance--xcom - -1 -1 + +0..N +1 task_instance--xcom - -0..N -1 + +1 +1 log_template - -log_template - -id - [INTEGER] - NOT NULL - -created_at - [TIMESTAMP] - NOT NULL - -elasticsearch_id - [TEXT] - NOT NULL - -filename - [TEXT] - NOT NULL + +log_template + +id + [INTEGER] + NOT NULL + +created_at + [TIMESTAMP] + NOT NULL + +elasticsearch_id + [TEXT] + NOT NULL + +filename + [TEXT] + NOT NULL log_template--dag_run - -0..N -{0,1} + +0..N +{0,1} dataset - -dataset - -id - [INTEGER] - NOT NULL - -created_at - [TIMESTAMP] - NOT NULL - -extra - [JSON] - NOT NULL - -is_orphaned - [BOOLEAN] - NOT NULL - -updated_at - [TIMESTAMP] - NOT NULL - -uri - [VARCHAR(3000)] - NOT NULL + +dataset + +id + [INTEGER] + NOT NULL + +created_at + [TIMESTAMP] + NOT NULL + +extra + [JSON] + NOT NULL + +is_orphaned + [BOOLEAN] + NOT NULL + +updated_at + [TIMESTAMP] + NOT NULL + +uri + [VARCHAR(3000)] + NOT NULL dataset--dag_schedule_dataset_reference - -1 -1 + +1 +1 dataset--dataset_dag_run_queue - -1 -1 + +1 +1 dataset--task_outlet_dataset_reference - -1 -1 + +1 +1 dataset_event - -dataset_event - -id - [INTEGER] - NOT NULL - -dataset_id - [INTEGER] - NOT NULL - -extra - [JSON] - NOT NULL - -source_dag_id - [VARCHAR(250)] - -source_map_index - [INTEGER] - -source_run_id - [VARCHAR(250)] - -source_task_id - [VARCHAR(250)] - -timestamp - [TIMESTAMP] - NOT NULL + +dataset_event + +id + [INTEGER] + NOT NULL + +dataset_id + [INTEGER] + NOT NULL + +extra + [JSON] + NOT NULL + +source_dag_id + [VARCHAR(250)] + +source_map_index + [INTEGER] + +source_run_id + [VARCHAR(250)] + +source_task_id + [VARCHAR(250)] + +timestamp + [TIMESTAMP] + NOT NULL dataset_event--dagrun_dataset_event - -1 -1 + +1 +1 import_error - -import_error - -id - [INTEGER] - NOT NULL - -filename - [VARCHAR(1024)] - -stacktrace - [TEXT] - -timestamp - [TIMESTAMP] + +import_error + +id + [INTEGER] + NOT NULL + +filename + [VARCHAR(1024)] + +stacktrace + [TEXT] + +timestamp + [TIMESTAMP] job - -job - -id - [INTEGER] - NOT NULL - -dag_id - [VARCHAR(250)] - -end_date - [TIMESTAMP] - -executor_class - [VARCHAR(500)] - -hostname - [VARCHAR(500)] - -job_type - [VARCHAR(30)] - -latest_heartbeat - [TIMESTAMP] - -start_date - [TIMESTAMP] - -state - [VARCHAR(20)] - -unixname - [VARCHAR(1000)] + +job + +id + [INTEGER] + NOT NULL + +dag_id + [VARCHAR(250)] + +end_date + [TIMESTAMP] + +executor_class + [VARCHAR(500)] + +hostname + [VARCHAR(500)] + +job_type + [VARCHAR(30)] + +latest_heartbeat + [TIMESTAMP] + +start_date + [TIMESTAMP] + +state + [VARCHAR(20)] + +unixname + [VARCHAR(1000)] log - -log - -id - [INTEGER] - NOT NULL - -dag_id - [VARCHAR(250)] - -dttm - [TIMESTAMP] - -event - [VARCHAR(30)] - -execution_date - [TIMESTAMP] - -extra - [TEXT] - -map_index - [INTEGER] - -owner - [VARCHAR(500)] - -owner_display_name - [VARCHAR(500)] - -task_id - [VARCHAR(250)] + +log + +id + [INTEGER] + NOT NULL + +dag_id + [VARCHAR(250)] + +dttm + [TIMESTAMP] + +event + [VARCHAR(30)] + +execution_date + [TIMESTAMP] + +extra + [TEXT] + +map_index + [INTEGER] + +owner + [VARCHAR(500)] + +owner_display_name + [VARCHAR(500)] + +task_id + [VARCHAR(250)] trigger - -trigger - -id - [INTEGER] - NOT NULL - -classpath - [VARCHAR(1000)] - NOT NULL - -created_date - [TIMESTAMP] - NOT NULL - -kwargs - [JSON] - NOT NULL - -triggerer_id - [INTEGER] + +trigger + +id + [INTEGER] + NOT NULL + +classpath + [VARCHAR(1000)] + NOT NULL + +created_date + [TIMESTAMP] + NOT NULL + +kwargs + [JSON] + NOT NULL + +triggerer_id + [INTEGER] trigger--task_instance - -0..N -{0,1} + +0..N +{0,1} serialized_dag - -serialized_dag - -dag_id - [VARCHAR(250)] - NOT NULL - -dag_hash - [VARCHAR(32)] - NOT NULL - -data - [JSON] - -data_compressed - [BLOB] - -fileloc - [VARCHAR(2000)] - NOT NULL - -fileloc_hash - [BIGINT] - NOT NULL - -last_updated - [TIMESTAMP] - NOT NULL - -processor_subdir - [VARCHAR(2000)] + +serialized_dag + +dag_id + [VARCHAR(250)] + NOT NULL + +dag_hash + [VARCHAR(32)] + NOT NULL + +data + [JSON] + +data_compressed + [BLOB] + +fileloc + [VARCHAR(2000)] + NOT NULL + +fileloc_hash + [BIGINT] + NOT NULL + +last_updated + [TIMESTAMP] + NOT NULL + +processor_subdir + [VARCHAR(2000)] session - -session - -id - [INTEGER] - NOT NULL - -data - [BLOB] - -expiry - [DATETIME] - -session_id - [VARCHAR(255)] + +session + +id + [INTEGER] + NOT NULL + +data + [BLOB] + +expiry + [DATETIME] + +session_id + [VARCHAR(255)] sla_miss - -sla_miss - -dag_id - [VARCHAR(250)] - NOT NULL - -execution_date - [TIMESTAMP] - NOT NULL - -task_id - [VARCHAR(250)] - NOT NULL - -description - [TEXT] - -email_sent - [BOOLEAN] - -notification_sent - [BOOLEAN] - -timestamp - [TIMESTAMP] + +sla_miss + +dag_id + [VARCHAR(250)] + NOT NULL + +execution_date + [TIMESTAMP] + NOT NULL + +task_id + [VARCHAR(250)] + NOT NULL + +description + [TEXT] + +email_sent + [BOOLEAN] + +notification_sent + [BOOLEAN] + +timestamp + [TIMESTAMP] slot_pool - -slot_pool - -id - [INTEGER] - NOT NULL - -description - [TEXT] - -include_deferred - [BOOLEAN] - NOT NULL - -pool - [VARCHAR(256)] - -slots - [INTEGER] + +slot_pool + +id + [INTEGER] + NOT NULL + +description + [TEXT] + +include_deferred + [BOOLEAN] + NOT NULL + +pool + [VARCHAR(256)] + +slots + [INTEGER] variable - -variable - -id - [INTEGER] - NOT NULL - -description - [TEXT] - -is_encrypted - [BOOLEAN] - -key - [VARCHAR(250)] - -val - [TEXT] + +variable + +id + [INTEGER] + NOT NULL + +description + [TEXT] + +is_encrypted + [BOOLEAN] + +key + [VARCHAR(250)] + +val + [TEXT] diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index 72f467f6fbc57..af23ce50ca0f6 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=================================+===================+===================+==============================================================+ -| ``bd5dfbe21f88`` (head) | ``f7bf2a57d0a6`` | ``2.8.0`` | Make connection login/password TEXT | +| ``624ecf3b6a5e`` (head) | ``bd5dfbe21f88`` | ``2.8.0`` | add priority_weight_strategy to task_instance | ++---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ +| ``bd5dfbe21f88`` | ``f7bf2a57d0a6`` | ``2.8.0`` | Make connection login/password TEXT | +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ | ``f7bf2a57d0a6`` | ``375a816bbbf4`` | ``2.8.0`` | Add owner_display_name to (Audit) Log table | +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index ada4f6836c01c..995bc939f4d48 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -96,7 +96,7 @@ from tests.test_utils.timetables import cron_timetable, delta_timetable if TYPE_CHECKING: - from airflow.models.abstractoperator import AbstractOperator + from airflow.models.taskinstance import TaskInstance TEST_DATE = datetime_tz(2015, 1, 2, 0, 0) @@ -120,7 +120,7 @@ def clear_datasets(): class TestPriorityWeightStrategy(PriorityWeightStrategy): - def get_weight(self, task: AbstractOperator): + def get_weight(self, ti: TaskInstance): return 99 @@ -439,11 +439,14 @@ def test_dag_task_invalid_weight_rule(self): EmptyOperator(task_id="should_fail", weight_rule="no rule") def test_dag_task_custom_weight_strategy(self): - with DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}): + with DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) as dag: task = EmptyOperator( - task_id="empty_task", weight_rule="tests.models.test_dag.TestPriorityWeightStrategy" + task_id="empty_task", + priority_weight_strategy="tests.models.test_dag.TestPriorityWeightStrategy", ) - assert task.priority_weight_total == 99 + dr = dag.create_dagrun(state=None, run_id="test", execution_date=DEFAULT_DATE) + ti = dr.get_task_instance(task.task_id) + assert ti.priority_weight == 99 def test_get_num_task_instances(self): test_dag_id = "test_get_num_task_instances_dag" From 01025831cbfde44a0c805a4ab1b35c991cb0a1b4 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sun, 29 Oct 2023 23:49:24 +0100 Subject: [PATCH 04/11] Fix loading the var from mapped operators and simplify loading it from task --- airflow/models/mappedoperator.py | 7 +++++-- airflow/models/taskinstance.py | 10 +++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 21797f6b476fe..669e5bab4309d 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -474,12 +474,15 @@ def priority_weight(self) -> int: # type: ignore[override] return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT) @property - def weight_rule(self) -> int: # type: ignore[override] + def weight_rule(self) -> str: # type: ignore[override] return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) @property def priority_weight_strategy(self) -> str: # type: ignore[override] - return self.partial_kwargs.get("priority_weight_strategy", DEFAULT_PRIORITY_WEIGHT_STRATEGY) + return ( + self.partial_kwargs.get("priority_weight_strategy", DEFAULT_PRIORITY_WEIGHT_STRATEGY) + or self.weight_rule + ) @property def sla(self) -> datetime.timedelta | None: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 6b096b7d363b4..9fe535cca4857 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -876,7 +876,7 @@ def _refresh_from_task( with contextlib.suppress(Exception): # This method is called from the different places, and sometimes the TI is not fully initialized task_instance.priority_weight = get_priority_weight_strategy( - task.priority_weight_strategy or str(task.weight_rule) + task.priority_weight_strategy ).get_weight( task_instance # type: ignore ) @@ -1383,9 +1383,9 @@ def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any :meta private: """ - priority_weight = get_priority_weight_strategy( - task.priority_weight_strategy or str(task.weight_rule) - ).get_weight(TaskInstance(task=task, run_id=run_id, map_index=map_index)) + priority_weight = get_priority_weight_strategy(task.priority_weight_strategy).get_weight( + TaskInstance(task=task, run_id=run_id, map_index=map_index) + ) return { "dag_id": task.dag_id, "task_id": task.task_id, @@ -1397,7 +1397,7 @@ def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any "pool": task.pool, "pool_slots": task.pool_slots, "priority_weight": priority_weight, - "priority_weight_strategy": task.priority_weight_strategy or task.weight_rule, + "priority_weight_strategy": task.priority_weight_strategy, "run_as_user": task.run_as_user, "max_tries": task.retries, "executor_config": task.executor_config, From 813391958556011f58e5728ac521f408b5ff67e1 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Wed, 1 Nov 2023 22:31:39 +0100 Subject: [PATCH 05/11] Update default value and deprecated the other one --- airflow/config_templates/config.yml | 8 ++++++-- airflow/models/abstractoperator.py | 13 ++++++++----- airflow/models/baseoperator.py | 24 +++++++++++++----------- airflow/models/mappedoperator.py | 11 ++++++----- airflow/models/taskinstance.py | 4 ---- airflow/utils/weight_rule.py | 6 +++++- tests/www/views/test_views_tasks.py | 7 +++++++ 7 files changed, 45 insertions(+), 28 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 398afa593f437..6613edfb8fe96 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -306,16 +306,20 @@ core: description: | The weighting method used for the effective total priority weight of the task version_added: 2.2.0 + version_deprecated: 2.8.0 + deprecation_reason: | + This option is deprecated and will be removed in Airflow 3.0. + Please use ``default_task_priority_weight_strategy`` instead. type: string example: ~ - default: "downstream" + default: ~ default_task_priority_weight_strategy: description: | The strategy used for the effective total priority weight of the task version_added: 2.8.0 type: string example: ~ - default: ~ + default: "downstream" default_task_execution_timeout: description: | The default task execution_timeout value for the operators. Expected an integer value to diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index cc077b8b863b6..14de7a4717299 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -70,11 +70,14 @@ ) MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60) -DEFAULT_WEIGHT_RULE: WeightRule = WeightRule( - conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM) +DEFAULT_WEIGHT_RULE: WeightRule | None = ( + WeightRule(conf.get("core", "default_task_weight_rule", fallback=None)) + if conf.get("core", "default_task_weight_rule", fallback=None) + else None ) -DEFAULT_PRIORITY_WEIGHT_STRATEGY: str | None = conf.get( - "core", "default_task_priority_weight_strategy", fallback=None + +DEFAULT_PRIORITY_WEIGHT_STRATEGY: str = conf.get( + "core", "default_task_priority_weight_strategy", fallback=WeightRule.DOWNSTREAM ) DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta( @@ -101,7 +104,7 @@ class AbstractOperator(Templater, DAGNode): operator_class: type[BaseOperator] | dict[str, Any] - weight_rule: str + weight_rule: str | None priority_weight_strategy: str priority_weight: int diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index ce860bb53a54f..789a3acd0b26a 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -212,7 +212,7 @@ def partial(**kwargs): "retry_exponential_backoff": False, "priority_weight": DEFAULT_PRIORITY_WEIGHT, "weight_rule": DEFAULT_WEIGHT_RULE, - "weight_strategy": DEFAULT_PRIORITY_WEIGHT_STRATEGY, + "priority_weight_strategy": DEFAULT_PRIORITY_WEIGHT_STRATEGY, "inlets": [], "outlets": [], } @@ -550,9 +550,9 @@ class derived from this one results in the creation of a task object, This allows the executor to trigger higher priority tasks before others when things get backed up. Set priority_weight as a higher number for more important tasks. - :param weight_rule: weighting method used for the effective total - priority weight of the task. Options are: - ``{ downstream | upstream | absolute }`` default is ``downstream`` + :param weight_rule: Deprecated field, please use ``priority_weight_strategy`` instead. + weighting method used for the effective total priority weight of the task. Options are: + ``{ downstream | upstream | absolute }`` default is ``None`` When set to ``downstream`` the effective weight of the task is the aggregate sum of all downstream descendants. As a result, upstream tasks will have higher weight and will be scheduled more aggressively @@ -572,7 +572,11 @@ class derived from this one results in the creation of a task object, significantly speeding up the task creation process as for very large DAGs. Options can be set as string or using the constants defined in the static class ``airflow.utils.WeightRule`` - :param priority_weight_strategy: TODO: add description + :param priority_weight_strategy: weighting method used for the effective total priority weight + of the task. You can provide one of the following options: + ``{ downstream | upstream | absolute }`` or the path to a custom + strategy class that extends ``airflow.task.priority_strategy.PriorityWeightStrategy``. + Default is ``downstream``. :param queue: which queue to target when running this job. Not all executors implement queue management, the CeleryExecutor does support targeting specific queues. @@ -758,8 +762,8 @@ def __init__( params: collections.abc.MutableMapping | None = None, default_args: dict | None = None, priority_weight: int = DEFAULT_PRIORITY_WEIGHT, - weight_rule: str = DEFAULT_WEIGHT_RULE, - priority_weight_strategy: str | None = DEFAULT_PRIORITY_WEIGHT_STRATEGY, + weight_rule: str | None = DEFAULT_WEIGHT_RULE, + priority_weight_strategy: str = DEFAULT_PRIORITY_WEIGHT_STRATEGY, queue: str = DEFAULT_QUEUE, pool: str | None = None, pool_slots: int = DEFAULT_POOL_SLOTS, @@ -906,16 +910,14 @@ def __init__( ) self.priority_weight = priority_weight self.weight_rule = weight_rule - self.priority_weight_strategy: str - if not priority_weight_strategy: + self.priority_weight_strategy = priority_weight_strategy + if weight_rule: warnings.warn( "weight_rule is deprecated. Please use `priority_weight_strategy` instead.", DeprecationWarning, stacklevel=2, ) self.priority_weight_strategy = weight_rule - else: - self.priority_weight_strategy = priority_weight_strategy # validate the priority weight strategy get_priority_weight_strategy(self.priority_weight_strategy) self.resources = coerce_resources(resources) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 669e5bab4309d..2883ea71cd2d9 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -331,7 +331,7 @@ def __attrs_post_init__(self): f"{self.task_id!r}." ) # validate the priority weight strategy - get_priority_weight_strategy(str(self.weight_rule)) + get_priority_weight_strategy(self.priority_weight_strategy) @classmethod @cache @@ -474,14 +474,15 @@ def priority_weight(self) -> int: # type: ignore[override] return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT) @property - def weight_rule(self) -> str: # type: ignore[override] - return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) + def weight_rule(self) -> str | None: # type: ignore[override] + return self.partial_kwargs.get("weight_rule") or DEFAULT_WEIGHT_RULE @property def priority_weight_strategy(self) -> str: # type: ignore[override] return ( - self.partial_kwargs.get("priority_weight_strategy", DEFAULT_PRIORITY_WEIGHT_STRATEGY) - or self.weight_rule + self.weight_rule # for backward compatibility + or self.partial_kwargs.get("priority_weight_strategy") + or DEFAULT_PRIORITY_WEIGHT_STRATEGY ) @property diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 9fe535cca4857..c128b241af4af 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1455,10 +1455,6 @@ def operator_name(self) -> str | None: """@property: use a more friendly display name for the operator, if set.""" return self.custom_operator_name or self.operator - # @property - # def priority_weight_total(self) -> int: - # return get_priority_weight_strategy(self.priority_weight_strategy).get_weight(self) - def command_as_list( self, mark_success=False, diff --git a/airflow/utils/weight_rule.py b/airflow/utils/weight_rule.py index f65f2fa77e1af..dd6c554c673d7 100644 --- a/airflow/utils/weight_rule.py +++ b/airflow/utils/weight_rule.py @@ -23,7 +23,11 @@ class WeightRule(str, Enum): - """Weight rules.""" + """ + Weight rules. + + This class is deprecated and will be removed in Airflow 3 + """ DOWNSTREAM = "downstream" UPSTREAM = "upstream" diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index fbe99afeb3eea..06de0241e6614 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -1132,6 +1132,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 2, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, @@ -1164,6 +1165,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 2, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, @@ -1196,6 +1198,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 1, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, @@ -1228,6 +1231,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 3, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, @@ -1260,6 +1264,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 3, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, @@ -1292,6 +1297,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 3, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, @@ -1324,6 +1330,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 2, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, From f4fd2821bc9c11c7fc3f76c1569ec190238ca2ae Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Thu, 2 Nov 2023 00:22:56 +0100 Subject: [PATCH 06/11] Update task endpoint API spec --- airflow/api_connexion/openapi/v1.yaml | 7 +++++++ airflow/api_connexion/schemas/task_schema.py | 1 + airflow/www/static/js/types/api-generated.ts | 10 +++++++-- .../endpoints/test_task_endpoint.py | 21 ++++++++++++------- 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index e4ae9c776f1c3..7e91520333aee 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -3744,6 +3744,8 @@ components: readOnly: true weight_rule: $ref: '#/components/schemas/WeightRule' + priority_weight_strategy: + $ref: '#/components/schemas/PriorityWeightStrategy' ui_color: $ref: '#/components/schemas/Color' ui_fgcolor: @@ -4784,11 +4786,16 @@ components: WeightRule: description: Weight rule. type: string + nullable: true enum: - downstream - upstream - absolute + PriorityWeightStrategy: + description: Priority weight strategy. + type: string + HealthStatus: description: Health status type: string diff --git a/airflow/api_connexion/schemas/task_schema.py b/airflow/api_connexion/schemas/task_schema.py index ac1b465bb25b0..cd8ccdfd3b966 100644 --- a/airflow/api_connexion/schemas/task_schema.py +++ b/airflow/api_connexion/schemas/task_schema.py @@ -57,6 +57,7 @@ class TaskSchema(Schema): retry_exponential_backoff = fields.Boolean(dump_only=True) priority_weight = fields.Number(dump_only=True) weight_rule = WeightRuleField(dump_only=True) + priority_weight_strategy = fields.String(dump_only=True) ui_color = ColorField(dump_only=True) ui_fgcolor = ColorField(dump_only=True) template_fields = fields.List(fields.String(), dump_only=True) diff --git a/airflow/www/static/js/types/api-generated.ts b/airflow/www/static/js/types/api-generated.ts index 607a173a6ba5b..dd7ba3071e694 100644 --- a/airflow/www/static/js/types/api-generated.ts +++ b/airflow/www/static/js/types/api-generated.ts @@ -1561,6 +1561,7 @@ export interface components { retry_exponential_backoff?: boolean; priority_weight?: number; weight_rule?: components["schemas"]["WeightRule"]; + priority_weight_strategy?: components["schemas"]["PriorityWeightStrategy"]; ui_color?: components["schemas"]["Color"]; ui_fgcolor?: components["schemas"]["Color"]; template_fields?: string[]; @@ -2234,9 +2235,11 @@ export interface components { | "always"; /** * @description Weight rule. - * @enum {string} + * @enum {string|null} */ - WeightRule: "downstream" | "upstream" | "absolute"; + WeightRule: ("downstream" | "upstream" | "absolute") | null; + /** @description Priority weight strategy. */ + PriorityWeightStrategy: string; /** * @description Health status * @enum {string|null} @@ -4952,6 +4955,9 @@ export type TriggerRule = CamelCasedPropertiesDeep< export type WeightRule = CamelCasedPropertiesDeep< components["schemas"]["WeightRule"] >; +export type PriorityWeightStrategy = CamelCasedPropertiesDeep< + components["schemas"]["PriorityWeightStrategy"] +>; export type HealthStatus = CamelCasedPropertiesDeep< components["schemas"]["HealthStatus"] >; diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index b8ef8dc0cf650..d2b717bfc093c 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -123,6 +123,7 @@ def test_should_respond_200(self): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -134,7 +135,7 @@ def test_should_respond_200(self): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, } response = self.client.get( @@ -158,6 +159,7 @@ def test_mapped_task(self): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "microseconds": 0, "seconds": 300}, @@ -169,7 +171,7 @@ def test_mapped_task(self): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, } response = self.client.get( f"/api/v1/dags/{self.mapped_dag_id}/tasks/{self.mapped_task_id}", @@ -209,6 +211,7 @@ def test_should_respond_200_serialized(self): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -220,7 +223,7 @@ def test_should_respond_200_serialized(self): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, } response = self.client.get( @@ -284,6 +287,7 @@ def test_should_respond_200(self): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -295,7 +299,7 @@ def test_should_respond_200(self): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, }, { @@ -314,6 +318,7 @@ def test_should_respond_200(self): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -325,7 +330,7 @@ def test_should_respond_200(self): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, }, ], @@ -354,6 +359,7 @@ def test_get_tasks_mapped(self): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "microseconds": 0, "seconds": 300}, @@ -365,7 +371,7 @@ def test_get_tasks_mapped(self): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, }, { "class_ref": { @@ -383,6 +389,7 @@ def test_get_tasks_mapped(self): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -394,7 +401,7 @@ def test_get_tasks_mapped(self): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, }, ], From a12d361611e4001a0ea320fab3c7947ca37b846e Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Sun, 26 Nov 2023 16:29:22 +0200 Subject: [PATCH 07/11] fix tests --- tests/api_connexion/schemas/test_task_schema.py | 3 ++- tests/models/test_taskinstance.py | 1 + tests/serialization/test_dag_serialization.py | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/api_connexion/schemas/test_task_schema.py b/tests/api_connexion/schemas/test_task_schema.py index 54403ebbf0bf2..742567867e289 100644 --- a/tests/api_connexion/schemas/test_task_schema.py +++ b/tests/api_connexion/schemas/test_task_schema.py @@ -93,6 +93,7 @@ def test_serialize(self): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -104,7 +105,7 @@ def test_serialize(self): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, } ], diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 27ce80df1ab76..a1c4281285c46 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3093,6 +3093,7 @@ def test_refresh_from_db(self, create_task_instance): "pool_slots": 25, "queue": "some_queue_id", "priority_weight": 123, + "priority_weight_strategy": "downstream", "operator": "some_custom_operator", "custom_operator_name": "some_custom_operator", "queued_dttm": run_date + datetime.timedelta(hours=1), diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 30407eb945933..3c0ce045eec94 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1243,6 +1243,7 @@ def test_no_new_fields_added_to_base_operator(self): "pool": "default_pool", "pool_slots": 1, "priority_weight": 1, + "priority_weight_strategy": "downstream", "queue": "default", "resources": None, "retries": 0, @@ -1254,7 +1255,7 @@ def test_no_new_fields_added_to_base_operator(self): "trigger_rule": "all_success", "wait_for_downstream": False, "wait_for_past_depends_before_skipping": False, - "weight_rule": "downstream", + "weight_rule": None, }, """ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! From 5e77bcd7347cac78ebe04b0872367dd01b57a887 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Sun, 26 Nov 2023 17:25:26 +0200 Subject: [PATCH 08/11] Update docs and add dag example --- .../example_priority_weight_strategy.py | 69 +++++++++++++++++++ .../priority-weight.rst | 12 ++-- 2 files changed, 76 insertions(+), 5 deletions(-) create mode 100644 airflow/example_dags/example_priority_weight_strategy.py diff --git a/airflow/example_dags/example_priority_weight_strategy.py b/airflow/example_dags/example_priority_weight_strategy.py new file mode 100644 index 0000000000000..5575d74a371f9 --- /dev/null +++ b/airflow/example_dags/example_priority_weight_strategy.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example DAG demonstrating the usage of a custom PriorityWeightStrategy class.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pendulum + +from airflow.models.dag import DAG +from airflow.operators.python import PythonOperator +from airflow.task.priority_strategy import PriorityWeightStrategy + +if TYPE_CHECKING: + from airflow.models import TaskInstance + + +def success_on_third_attempt(ti: TaskInstance, **context): + if ti.try_number < 3: + raise Exception("Not yet") + + +class DecreasingPriorityStrategy(PriorityWeightStrategy): + """A priority weight strategy that decreases the priority weight with each attempt.""" + + def get_weight(self, ti: TaskInstance): + return max(3 - ti._try_number + 1, 1) + + +with DAG( + dag_id="example_priority_weight_strategy", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + catchup=False, + schedule="@daily", + tags=["example"], + default_args={ + "retries": 3, + "retry_delay": pendulum.duration(seconds=10), + }, +) as dag: + fixed_weight_task = PythonOperator( + task_id="fixed_weight_task", + python_callable=success_on_third_attempt, + priority_weight_strategy="downstream", + ) + + decreasing_weight_task = PythonOperator( + task_id="decreasing_weight_task", + python_callable=success_on_third_attempt, + priority_weight_strategy=( + "airflow.example_dags.example_priority_weight_strategy.DecreasingPriorityStrategy" + ), + ) diff --git a/docs/apache-airflow/administration-and-deployment/priority-weight.rst b/docs/apache-airflow/administration-and-deployment/priority-weight.rst index 87a9288ddcbbe..3e064123af2ba 100644 --- a/docs/apache-airflow/administration-and-deployment/priority-weight.rst +++ b/docs/apache-airflow/administration-and-deployment/priority-weight.rst @@ -22,12 +22,9 @@ Priority Weights ``priority_weight`` defines priorities in the executor queue. The default ``priority_weight`` is ``1``, and can be bumped to any integer. Moreover, each task has a true ``priority_weight`` that is calculated based on its -``weight_rule`` which defines weighting method used for the effective total priority weight of the task. +``priority_weight_strategy`` which defines weighting method used for the effective total priority weight of the task. -By default, Airflow's weighting method is ``downstream``. You can find other weighting methods in -:class:`airflow.utils.WeightRule`. - -There are three weighting methods. +Airflow has three weighting strategies: - downstream @@ -57,5 +54,10 @@ There are three weighting methods. significantly speeding up the task creation process as for very large DAGs +You can also implement your own weighting strategy by extending the class +:class:`~airflow.task.priority_strategy.PriorityWeightStrategy` and overriding the method +:meth:`~airflow.task.priority_strategy.PriorityWeightStrategy.get_weight`, the providing the path of your class +to the ``priority_weight_strategy`` parameter. + The ``priority_weight`` parameter can be used in conjunction with :ref:`concepts:pool`. From 1149fb1457b801dab546a6b70b83bc7d9f1379bc Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Sun, 26 Nov 2023 18:48:14 +0200 Subject: [PATCH 09/11] Fix serialization test --- tests/api_connexion/schemas/test_task_schema.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/api_connexion/schemas/test_task_schema.py b/tests/api_connexion/schemas/test_task_schema.py index 742567867e289..f76fa439e83f7 100644 --- a/tests/api_connexion/schemas/test_task_schema.py +++ b/tests/api_connexion/schemas/test_task_schema.py @@ -46,6 +46,7 @@ def test_serialize(self): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -57,7 +58,7 @@ def test_serialize(self): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, } assert expected == result From 7476c3504414e36c1f9db4e3c913c32cf9c5f700 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Sun, 26 Nov 2023 19:27:17 +0200 Subject: [PATCH 10/11] revert change in spark provider --- airflow/providers/apache/spark/hooks/spark_jdbc_script.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py index 63aa7d0799a89..d431782929e2e 100644 --- a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py +++ b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py @@ -18,10 +18,9 @@ from __future__ import annotations import argparse -from typing import TYPE_CHECKING, Any +from typing import Any -if TYPE_CHECKING: - from pyspark.sql import SparkSession +from pyspark.sql import SparkSession SPARK_WRITE_TO_JDBC: str = "spark_to_jdbc" SPARK_READ_FROM_JDBC: str = "jdbc_to_spark" @@ -147,8 +146,6 @@ def _parse_arguments(args: list[str] | None = None) -> Any: def _create_spark_session(arguments: Any) -> SparkSession: - from pyspark.sql import SparkSession - return SparkSession.builder.appName(arguments.name).enableHiveSupport().getOrCreate() From a0f28fe38f7116000cf014d198e92f7eab91e909 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Sun, 26 Nov 2023 20:07:09 +0200 Subject: [PATCH 11/11] Update unit tests --- tests/models/test_baseoperator.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index fb46fd39c79d5..28b76f3684e1e 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -784,12 +784,20 @@ def test_replace_dummy_trigger_rule(self, rule): def test_weight_rule_default(self): op = BaseOperator(task_id="test_task") - assert WeightRule.DOWNSTREAM == op.weight_rule + assert op.weight_rule is None - def test_weight_rule_override(self): + def test_priority_weight_strategy_default(self): + op = BaseOperator(task_id="test_task") + assert op.priority_weight_strategy == "downstream" + + def test_deprecated_weight_rule_override(self): op = BaseOperator(task_id="test_task", weight_rule="upstream") assert WeightRule.UPSTREAM == op.weight_rule + def test_priority_weight_strategy_override(self): + op = BaseOperator(task_id="test_task", priority_weight_strategy="upstream") + assert op.priority_weight_strategy == "upstream" + # ensure the default logging config is used for this test, no matter what ran before @pytest.mark.usefixtures("reset_logging_config") def test_logging_propogated_by_default(self, caplog):