diff --git a/airflow/cli/commands/remote_commands/task_command.py b/airflow/cli/commands/remote_commands/task_command.py
index b1f9182e4c9c3..8bba6ebd89173 100644
--- a/airflow/cli/commands/remote_commands/task_command.py
+++ b/airflow/cli/commands/remote_commands/task_command.py
@@ -45,8 +45,8 @@
from airflow.models import TaskInstance
from airflow.models.dag import DAG, _run_inline_trigger
from airflow.models.dagrun import DagRun
-from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskReturnCode
+from airflow.sdk.definitions.param import ParamsDict
from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
diff --git a/airflow/example_dags/example_params_trigger_ui.py b/airflow/example_dags/example_params_trigger_ui.py
index e47ceae556501..ece4056764567 100644
--- a/airflow/example_dags/example_params_trigger_ui.py
+++ b/airflow/example_dags/example_params_trigger_ui.py
@@ -27,7 +27,7 @@
from airflow.decorators import task
from airflow.models.dag import DAG
-from airflow.models.param import Param, ParamsDict
+from airflow.sdk import Param, ParamsDict
from airflow.utils.trigger_rule import TriggerRule
# [START params_trigger]
diff --git a/airflow/example_dags/example_params_ui_tutorial.py b/airflow/example_dags/example_params_ui_tutorial.py
index b64e777bed144..0bf9994c95c70 100644
--- a/airflow/example_dags/example_params_ui_tutorial.py
+++ b/airflow/example_dags/example_params_ui_tutorial.py
@@ -29,7 +29,7 @@
from airflow.decorators import task
from airflow.models.dag import DAG
-from airflow.models.param import Param, ParamsDict
+from airflow.sdk import Param, ParamsDict
with (
DAG(
diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py
index ae0fa3040e181..6bd3883b139af 100644
--- a/airflow/models/__init__.py
+++ b/airflow/models/__init__.py
@@ -99,7 +99,7 @@ def __getattr__(name):
"Log": "airflow.models.log",
"MappedOperator": "airflow.models.mappedoperator",
"Operator": "airflow.models.operator",
- "Param": "airflow.models.param",
+ "Param": "airflow.sdk.definitions.param",
"Pool": "airflow.models.pool",
"RenderedTaskInstanceFields": "airflow.models.renderedtifields",
"SkipMixin": "airflow.models.skipmixin",
@@ -128,7 +128,6 @@ def __getattr__(name):
from airflow.models.log import Log
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
- from airflow.models.param import Param
from airflow.models.pool import Pool
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.skipmixin import SkipMixin
@@ -138,3 +137,4 @@ def __getattr__(name):
from airflow.models.trigger import Trigger
from airflow.models.variable import Variable
from airflow.models.xcom import XCom
+ from airflow.sdk.definitions.param import Param
diff --git a/airflow/models/param.py b/airflow/models/param.py
index cd3ccec26a48a..01886f6e585ab 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -14,340 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-import contextlib
-import copy
-import json
-import logging
-from collections.abc import ItemsView, Iterable, MutableMapping, ValuesView
-from typing import TYPE_CHECKING, Any, ClassVar
-
-from airflow.exceptions import AirflowException, ParamValidationError
-from airflow.sdk.definitions._internal.mixins import ResolveMixin
-from airflow.utils.types import NOTSET, ArgNotSet
-
-if TYPE_CHECKING:
- from airflow.sdk.definitions.context import Context
- from airflow.sdk.definitions.dag import DAG
- from airflow.sdk.types import Operator
-
-logger = logging.getLogger(__name__)
-
-
-class Param:
- """
- Class to hold the default value of a Param and rule set to do the validations.
-
- Without the rule set it always validates and returns the default value.
-
- :param default: The value this Param object holds
- :param description: Optional help text for the Param
- :param schema: The validation schema of the Param, if not given then all kwargs except
- default & description will form the schema
- """
-
- __version__: ClassVar[int] = 1
-
- CLASS_IDENTIFIER = "__class"
-
- def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
- if default is not NOTSET:
- self._check_json(default)
- self.value = default
- self.description = description
- self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs
-
- def __copy__(self) -> Param:
- return Param(self.value, self.description, schema=self.schema)
-
- @staticmethod
- def _check_json(value):
- try:
- json.dumps(value)
- except Exception:
- raise ParamValidationError(
- "All provided parameters must be json-serializable. "
- f"The value '{value}' is not serializable."
- )
-
- def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
- """
- Run the validations and returns the Param's final value.
-
- May raise ValueError on failed validations, or TypeError
- if no value is passed and no value already exists.
- We first check that value is json-serializable; if not, warn.
- In future release we will require the value to be json-serializable.
-
- :param value: The value to be updated for the Param
- :param suppress_exception: To raise an exception or not when the validations fails.
- If true and validations fails, the return value would be None.
- """
- import jsonschema
- from jsonschema import FormatChecker
- from jsonschema.exceptions import ValidationError
-
- if value is not NOTSET:
- self._check_json(value)
- final_val = self.value if value is NOTSET else value
- if isinstance(final_val, ArgNotSet):
- if suppress_exception:
- return None
- raise ParamValidationError("No value passed and Param has no default value")
- try:
- jsonschema.validate(final_val, self.schema, format_checker=FormatChecker())
- except ValidationError as err:
- if suppress_exception:
- return None
- raise ParamValidationError(err) from None
- self.value = final_val
- return final_val
-
- def dump(self) -> dict:
- """Dump the Param as a dictionary."""
- out_dict: dict[str, str | None] = {
- self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"
- }
- out_dict.update(self.__dict__)
- # Ensure that not set is translated to None
- if self.value is NOTSET:
- out_dict["value"] = None
- return out_dict
-
- @property
- def has_value(self) -> bool:
- return self.value is not NOTSET and self.value is not None
-
- def serialize(self) -> dict:
- return {"value": self.value, "description": self.description, "schema": self.schema}
-
- @staticmethod
- def deserialize(data: dict[str, Any], version: int) -> Param:
- if version > Param.__version__:
- raise TypeError("serialized version > class version")
-
- return Param(default=data["value"], description=data["description"], schema=data["schema"])
-
-
-class ParamsDict(MutableMapping[str, Any]):
- """
- Class to hold all params for dags or tasks.
-
- All the keys are strictly string and values are converted into Param's object
- if they are not already. This class is to replace param's dictionary implicitly
- and ideally not needed to be used directly.
-
-
- :param dict_obj: A dict or dict like object to init ParamsDict
- :param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict
- """
-
- __version__: ClassVar[int] = 1
- __slots__ = ["__dict", "suppress_exception"]
-
- def __init__(self, dict_obj: MutableMapping | None = None, suppress_exception: bool = False):
- params_dict: dict[str, Param] = {}
- dict_obj = dict_obj or {}
- for k, v in dict_obj.items():
- if not isinstance(v, Param):
- params_dict[k] = Param(v)
- else:
- params_dict[k] = v
- self.__dict = params_dict
- self.suppress_exception = suppress_exception
-
- def __bool__(self) -> bool:
- return bool(self.__dict)
-
- def __eq__(self, other: Any) -> bool:
- if isinstance(other, ParamsDict):
- return self.dump() == other.dump()
- if isinstance(other, dict):
- return self.dump() == other
- return NotImplemented
-
- def __copy__(self) -> ParamsDict:
- return ParamsDict(self.__dict, self.suppress_exception)
-
- def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
- return ParamsDict(copy.deepcopy(self.__dict, memo), self.suppress_exception)
-
- def __contains__(self, o: object) -> bool:
- return o in self.__dict
- def __len__(self) -> int:
- return len(self.__dict)
+"""Re exporting the new param module from Task SDK for backward compatibility."""
- def __delitem__(self, v: str) -> None:
- del self.__dict[v]
-
- def __iter__(self):
- return iter(self.__dict)
-
- def __repr__(self):
- return repr(self.dump())
-
- def __setitem__(self, key: str, value: Any) -> None:
- """
- Override for dictionary's ``setitem`` method to ensure all values are of Param's type only.
-
- :param key: A key which needs to be inserted or updated in the dict
- :param value: A value which needs to be set against the key. It could be of any
- type but will be converted and stored as a Param object eventually.
- """
- if isinstance(value, Param):
- param = value
- elif key in self.__dict:
- param = self.__dict[key]
- try:
- param.resolve(value=value, suppress_exception=self.suppress_exception)
- except ParamValidationError as ve:
- raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None
- else:
- # if the key isn't there already and if the value isn't of Param type create a new Param object
- param = Param(value)
-
- self.__dict[key] = param
-
- def __getitem__(self, key: str) -> Any:
- """
- Override for dictionary's ``getitem`` method to call the resolve method after fetching the key.
-
- :param key: The key to fetch
- """
- param = self.__dict[key]
- return param.resolve(suppress_exception=self.suppress_exception)
-
- def get_param(self, key: str) -> Param:
- """Get the internal :class:`.Param` object for this key."""
- return self.__dict[key]
-
- def items(self):
- return ItemsView(self.__dict)
-
- def values(self):
- return ValuesView(self.__dict)
-
- def update(self, *args, **kwargs) -> None:
- if len(args) == 1 and not kwargs and isinstance(args[0], ParamsDict):
- return super().update(args[0].__dict)
- super().update(*args, **kwargs)
-
- def dump(self) -> dict[str, Any]:
- """Dump the ParamsDict object as a dictionary, while suppressing exceptions."""
- return {k: v.resolve(suppress_exception=True) for k, v in self.items()}
-
- def validate(self) -> dict[str, Any]:
- """Validate & returns all the Params object stored in the dictionary."""
- resolved_dict = {}
- try:
- for k, v in self.items():
- resolved_dict[k] = v.resolve(suppress_exception=self.suppress_exception)
- except ParamValidationError as ve:
- raise ParamValidationError(f"Invalid input for param {k}: {ve}") from None
-
- return resolved_dict
-
- def serialize(self) -> dict[str, Any]:
- return self.dump()
-
- @staticmethod
- def deserialize(data: dict, version: int) -> ParamsDict:
- if version > ParamsDict.__version__:
- raise TypeError("serialized version > class version")
-
- return ParamsDict(data)
-
-
-class DagParam(ResolveMixin):
- """
- DAG run parameter reference.
-
- This binds a simple Param object to a name within a DAG instance, so that it
- can be resolved during the runtime via the ``{{ context }}`` dictionary. The
- ideal use case of this class is to implicitly convert args passed to a
- method decorated by ``@dag``.
-
- It can be used to parameterize a DAG. You can overwrite its value by setting
- it on conf when you trigger your DagRun.
-
- This can also be used in templates by accessing ``{{ context.params }}``.
-
- **Example**:
-
- with DAG(...) as dag:
- EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))
-
- :param current_dag: Dag being used for parameter.
- :param name: key value which is used to set the parameter
- :param default: Default value used if no parameter was set.
- """
-
- def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
- if default is not NOTSET:
- current_dag.params[name] = default
- self._name = name
- self._default = default
- self.current_dag = current_dag
-
- def iter_references(self) -> Iterable[tuple[Operator, str]]:
- return ()
-
- def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
- """Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
- with contextlib.suppress(KeyError):
- if context["dag_run"].conf:
- return context["dag_run"].conf[self._name]
- if self._default is not NOTSET:
- return self._default
- with contextlib.suppress(KeyError):
- return context["params"][self._name]
- raise AirflowException(f"No value could be resolved for parameter {self._name}")
-
- def serialize(self) -> dict:
- """Serialize the DagParam object into a dictionary."""
- return {
- "dag_id": self.current_dag.dag_id,
- "name": self._name,
- "default": self._default,
- }
-
- @classmethod
- def deserialize(cls, data: dict, dags: dict) -> DagParam:
- """
- Deserializes the dictionary back into a DagParam object.
-
- :param data: The serialized representation of the DagParam.
- :param dags: A dictionary of available DAGs to look up the DAG.
- """
- dag_id = data["dag_id"]
- # Retrieve the current DAG from the provided DAGs dictionary
- current_dag = dags.get(dag_id)
- if not current_dag:
- raise ValueError(f"DAG with id {dag_id} not found.")
-
- return cls(current_dag=current_dag, name=data["name"], default=data["default"])
-
-
-def process_params(
- dag: DAG,
- task: Operator,
- dagrun_conf: dict[str, Any] | None,
- *,
- suppress_exception: bool,
-) -> dict[str, Any]:
- """Merge, validate params, and convert them into a simple dict."""
- from airflow.configuration import conf
+from __future__ import annotations
- dagrun_conf = dagrun_conf or {}
+from airflow.sdk.definitions.param import Param, ParamsDict
- params = ParamsDict(suppress_exception=suppress_exception)
- with contextlib.suppress(AttributeError):
- params.update(dag.params)
- if task.params:
- params.update(task.params)
- if conf.getboolean("core", "dag_run_conf_overrides_params") and dagrun_conf:
- logger.debug("Updating task params (%s) with DagRun.conf (%s)", params, dagrun_conf)
- params.update(dagrun_conf)
- return params.validate()
+__all__ = ["Param", "ParamsDict"]
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index f1aa3a8236e9c..74130fa2151b3 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -98,7 +98,6 @@
from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
from airflow.models.dagbag import DagBag
from airflow.models.log import Log
-from airflow.models.param import process_params
from airflow.models.renderedtifields import get_serialized_template_fields
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.taskmap import TaskMap
@@ -108,6 +107,7 @@
from airflow.sdk.api.datamodels._generated import AssetProfile
from airflow.sdk.definitions._internal.templater import SandboxedEnvironment
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
+from airflow.sdk.definitions.param import process_params
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index b7e08a45aed74..9c5f43c0c0b5e 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -47,7 +47,6 @@
create_expand_input,
get_map_type_key,
)
-from airflow.models.param import Param, ParamsDict
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers_manager import ProvidersManager
@@ -64,6 +63,7 @@
)
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
+from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
from airflow.sdk.definitions.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
@@ -985,7 +985,7 @@ def _serialize_params_dict(cls, params: ParamsDict | dict) -> list[tuple[str, di
class_identity = f"{v.__module__}.{v.__class__.__name__}"
except AttributeError:
class_identity = ""
- if class_identity == "airflow.models.param.Param":
+ if class_identity == "airflow.sdk.definitions.param.Param":
serialized_params.append((k, cls._serialize_param(v)))
else:
# Auto-box other values into Params object like it is done by DAG parsing as well
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 6ed1399fe63f6..0415542c6ca8c 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -303,7 +303,7 @@ def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
:meta private:
"""
- from airflow.models.param import process_params
+ from airflow.sdk.definitions.param import process_params
context["task"] = context["ti"].task = task
context["params"] = process_params(
diff --git a/docs/apache-airflow/core-concepts/params.rst b/docs/apache-airflow/core-concepts/params.rst
index b54026ccb22fc..8c1c98cd1724c 100644
--- a/docs/apache-airflow/core-concepts/params.rst
+++ b/docs/apache-airflow/core-concepts/params.rst
@@ -22,8 +22,8 @@ Params
Params enable you to provide runtime configuration to tasks. You can configure default Params in your DAG
code and supply additional Params, or overwrite Param values, at runtime when you trigger a DAG.
-:class:`~airflow.models.param.Param` values are validated with JSON Schema. For scheduled DAG runs,
-default :class:`~airflow.models.param.Param` values are used.
+:class:`~airflow.sdk.definitions.param.Param` values are validated with JSON Schema. For scheduled DAG runs,
+default :class:`~airflow.sdk.definitions.param.Param` values are used.
Also defined Params are used to render a nice UI when triggering manually.
When you trigger a DAG manually, you can modify its Params before the dagrun starts.
@@ -33,14 +33,14 @@ DAG-level Params
----------------
To add Params to a :class:`~airflow.models.dag.DAG`, initialize it with the ``params`` kwarg.
-Use a dictionary that maps Param names to either a :class:`~airflow.models.param.Param` or an object indicating the parameter's default value.
+Use a dictionary that maps Param names to either a :class:`~airflow.sdk.definitions.param.Param` or an object indicating the parameter's default value.
.. code-block::
:emphasize-lines: 7-10
from airflow import DAG
from airflow.decorators import task
- from airflow.models.param import Param
+ from airflow.sdk import Param
with DAG(
"the_dag",
@@ -127,7 +127,7 @@ You can change this by setting ``render_template_as_native_obj=True`` while init
):
-This way, the :class:`~airflow.models.param.Param`'s type is respected when it's provided to your task:
+This way, the :class:`~airflow.sdk.definitions.param.Param`'s type is respected when it's provided to your task:
.. code-block::
@@ -160,7 +160,7 @@ Another way to access your param is via a task's ``context`` kwarg.
JSON Schema Validation
----------------------
-:class:`~airflow.models.param.Param` makes use of `JSON Schema `_, so you can use the full JSON Schema specifications mentioned at https://json-schema.org/draft/2020-12/json-schema-validation.html to define ``Param`` objects.
+:class:`~airflow.sdk.definitions.param.Param` makes use of `JSON Schema `_, so you can use the full JSON Schema specifications mentioned at https://json-schema.org/draft/2020-12/json-schema-validation.html to define ``Param`` objects.
.. code-block::
@@ -195,8 +195,8 @@ JSON Schema Validation
at time of trigger.
.. note::
- As of now, for security reasons, one can not use :class:`~airflow.models.param.Param` objects derived out of custom classes. We are
- planning to have a registration system for custom :class:`~airflow.models.param.Param` classes, just like we've for Operator ExtraLinks.
+ As of now, for security reasons, one can not use :class:`~airflow.sdk.definitions.param.Param` objects derived out of custom classes. We are
+ planning to have a registration system for custom :class:`~airflow.sdk.definitions.param.Param` classes, just like we've for Operator ExtraLinks.
Use Params to Provide a Trigger UI Form
---------------------------------------
@@ -207,21 +207,21 @@ Use Params to Provide a Trigger UI Form
This form is provided when a user clicks on the "Trigger DAG" button.
The Trigger UI Form is rendered based on the pre-defined DAG Params. If the DAG has no params defined, the trigger form is skipped.
-The form elements can be defined with the :class:`~airflow.models.param.Param` class and attributes define how a form field is displayed.
+The form elements can be defined with the :class:`~airflow.sdk.definitions.param.Param` class and attributes define how a form field is displayed.
The following features are supported in the Trigger UI Form:
-- Direct scalar values (boolean, int, string, lists, dicts) from top-level DAG params are auto-boxed into :class:`~airflow.models.param.Param` objects.
+- Direct scalar values (boolean, int, string, lists, dicts) from top-level DAG params are auto-boxed into :class:`~airflow.sdk.definitions.param.Param` objects.
From the native Python data type the ``type`` attribute is auto detected. So these simple types render to a corresponding field type.
The name of the parameter is used as label and no further validation is made, all values are treated as optional.
-- If you use the :class:`~airflow.models.param.Param` class as definition of the parameter value, the following attributes can be added:
+- If you use the :class:`~airflow.sdk.definitions.param.Param` class as definition of the parameter value, the following attributes can be added:
- - The :class:`~airflow.models.param.Param` attribute ``title`` is used to render the form field label of the entry box.
+ - The :class:`~airflow.sdk.definitions.param.Param` attribute ``title`` is used to render the form field label of the entry box.
If no ``title`` is defined the parameter name/key is used instead.
- - The :class:`~airflow.models.param.Param` attribute ``description`` is rendered below an entry field as help text in gray color.
+ - The :class:`~airflow.sdk.definitions.param.Param` attribute ``description`` is rendered below an entry field as help text in gray color.
If you want to provide special formatting or links you need to use the Param attribute
``description_md``. See tutorial DAG :ref:`Params UI example DAG ` for an example.
- - The :class:`~airflow.models.param.Param` attribute ``type`` influences how a field is rendered. The following types are supported:
+ - The :class:`~airflow.sdk.definitions.param.Param` attribute ``type`` influences how a field is rendered. The following types are supported:
.. list-table::
:header-rows: 1
diff --git a/docs/apache-airflow/public-airflow-interface.rst b/docs/apache-airflow/public-airflow-interface.rst
index 2853c6fbe2e1b..d1ac63eb5b3a9 100644
--- a/docs/apache-airflow/public-airflow-interface.rst
+++ b/docs/apache-airflow/public-airflow-interface.rst
@@ -62,7 +62,7 @@ DAGs
The DAG is Airflow's core entity that represents a recurring workflow. You can create a DAG by
instantiating the :class:`~airflow.models.dag.DAG` class in your DAG file. You can also instantiate
them via :class:`~airflow.models.dagbag.DagBag` class that reads DAGs from a file or a folder. DAGs
-can also have parameters specified via :class:`~airflow.models.param.Param` class.
+can also have parameters specified via :class:`~airflow.sdk.definitions.param.Param` class.
Airflow has a set of example DAGs that you can use to learn how to write DAGs
diff --git a/providers/edge/src/airflow/providers/edge/example_dags/integration_test.py b/providers/edge/src/airflow/providers/edge/example_dags/integration_test.py
index 418164832576d..777a85ef2dd99 100644
--- a/providers/edge/src/airflow/providers/edge/example_dags/integration_test.py
+++ b/providers/edge/src/airflow/providers/edge/example_dags/integration_test.py
@@ -30,10 +30,10 @@
from airflow.exceptions import AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.models.dag import DAG
-from airflow.models.param import Param
from airflow.models.variable import Variable
from airflow.operators.empty import EmptyOperator
from airflow.providers.common.compat.standard.operators import PythonOperator
+from airflow.sdk import Param
from airflow.utils.trigger_rule import TriggerRule
try:
diff --git a/providers/edge/src/airflow/providers/edge/example_dags/win_notepad.py b/providers/edge/src/airflow/providers/edge/example_dags/win_notepad.py
index da50fff8b96ee..a3b229a28f722 100644
--- a/providers/edge/src/airflow/providers/edge/example_dags/win_notepad.py
+++ b/providers/edge/src/airflow/providers/edge/example_dags/win_notepad.py
@@ -34,7 +34,7 @@
from airflow.models import BaseOperator
from airflow.models.dag import DAG
-from airflow.models.param import Param
+from airflow.sdk import Param
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git a/providers/edge/src/airflow/providers/edge/example_dags/win_test.py b/providers/edge/src/airflow/providers/edge/example_dags/win_test.py
index 3a730009d50c3..630092180b590 100644
--- a/providers/edge/src/airflow/providers/edge/example_dags/win_test.py
+++ b/providers/edge/src/airflow/providers/edge/example_dags/win_test.py
@@ -37,9 +37,9 @@
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
from airflow.models.dag import DAG
-from airflow.models.param import Param
from airflow.models.variable import Variable
from airflow.operators.empty import EmptyOperator
+from airflow.sdk import Param
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import ArgNotSet
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py
index e745d3d655bdc..d0aecf20f92a6 100644
--- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py
+++ b/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py
@@ -22,8 +22,8 @@
from airflow.models.dag import DagModel
from airflow.models.dagrun import DagRun
-from airflow.models.param import Param
from airflow.providers.fab.www.security import permissions
+from airflow.sdk.definitions.param import Param
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import DagRunState
diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py
index b8d6b6609dba7..d8968fd416f11 100644
--- a/task_sdk/src/airflow/sdk/__init__.py
+++ b/task_sdk/src/airflow/sdk/__init__.py
@@ -48,6 +48,8 @@
__lazy_imports: dict[str, str] = {
"BaseOperator": ".definitions.baseoperator",
"Connection": ".definitions.connection",
+ "Param": ".definitions.param",
+ "ParamsDict": ".definitions.param",
"DAG": ".definitions.dag",
"EdgeModifier": ".definitions.edges",
"Label": ".definitions.edges",
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
index 1f1d90883240b..579cc94b3ce34 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
@@ -31,9 +31,9 @@
from sqlalchemy.orm import Session
from airflow.io.path import ObjectStoragePath
- from airflow.models.param import ParamsDict
from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey
from airflow.sdk.definitions.dag import DAG, DagStateChangeCallback, ScheduleArg
+ from airflow.sdk.definitions.param import ParamsDict
from airflow.serialization.dag_dependency import DagDependency
from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index e7ecec69411ba..14d67656008e5 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -33,7 +33,6 @@
import attrs
-from airflow.models.param import ParamsDict
from airflow.sdk.definitions._internal.abstractoperator import (
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
DEFAULT_OWNER,
@@ -54,6 +53,7 @@
from airflow.sdk.definitions._internal.node import validate_key
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, validate_instance_args
from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs
+from airflow.sdk.definitions.param import ParamsDict
from airflow.task.priority_strategy import (
PriorityWeightStrategy,
airflow_priority_weight_strategies,
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index cd5217c8111d0..7882222d5d72d 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -51,12 +51,12 @@
ParamValidationError,
TaskNotFound,
)
-from airflow.models.param import DagParam, ParamsDict
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
from airflow.sdk.definitions._internal.types import NOTSET
from airflow.sdk.definitions.asset import AssetAll, BaseAsset
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.context import Context
+from airflow.sdk.definitions.param import DagParam, ParamsDict
from airflow.timetables.base import Timetable
from airflow.timetables.simple import (
AssetTriggeredTimetable,
diff --git a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py
index 0fc0a7fa1896a..136400534243f 100644
--- a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -72,10 +72,10 @@
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
- from airflow.models.param import ParamsDict
from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import DAG
+ from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.types import Operator
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.context import Context
diff --git a/task_sdk/src/airflow/sdk/definitions/param.py b/task_sdk/src/airflow/sdk/definitions/param.py
new file mode 100644
index 0000000000000..cd3ccec26a48a
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/definitions/param.py
@@ -0,0 +1,353 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import contextlib
+import copy
+import json
+import logging
+from collections.abc import ItemsView, Iterable, MutableMapping, ValuesView
+from typing import TYPE_CHECKING, Any, ClassVar
+
+from airflow.exceptions import AirflowException, ParamValidationError
+from airflow.sdk.definitions._internal.mixins import ResolveMixin
+from airflow.utils.types import NOTSET, ArgNotSet
+
+if TYPE_CHECKING:
+ from airflow.sdk.definitions.context import Context
+ from airflow.sdk.definitions.dag import DAG
+ from airflow.sdk.types import Operator
+
+logger = logging.getLogger(__name__)
+
+
+class Param:
+ """
+ Class to hold the default value of a Param and rule set to do the validations.
+
+ Without the rule set it always validates and returns the default value.
+
+ :param default: The value this Param object holds
+ :param description: Optional help text for the Param
+ :param schema: The validation schema of the Param, if not given then all kwargs except
+ default & description will form the schema
+ """
+
+ __version__: ClassVar[int] = 1
+
+ CLASS_IDENTIFIER = "__class"
+
+ def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
+ if default is not NOTSET:
+ self._check_json(default)
+ self.value = default
+ self.description = description
+ self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs
+
+ def __copy__(self) -> Param:
+ return Param(self.value, self.description, schema=self.schema)
+
+ @staticmethod
+ def _check_json(value):
+ try:
+ json.dumps(value)
+ except Exception:
+ raise ParamValidationError(
+ "All provided parameters must be json-serializable. "
+ f"The value '{value}' is not serializable."
+ )
+
+ def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
+ """
+ Run the validations and returns the Param's final value.
+
+ May raise ValueError on failed validations, or TypeError
+ if no value is passed and no value already exists.
+ We first check that value is json-serializable; if not, warn.
+ In future release we will require the value to be json-serializable.
+
+ :param value: The value to be updated for the Param
+ :param suppress_exception: To raise an exception or not when the validations fails.
+ If true and validations fails, the return value would be None.
+ """
+ import jsonschema
+ from jsonschema import FormatChecker
+ from jsonschema.exceptions import ValidationError
+
+ if value is not NOTSET:
+ self._check_json(value)
+ final_val = self.value if value is NOTSET else value
+ if isinstance(final_val, ArgNotSet):
+ if suppress_exception:
+ return None
+ raise ParamValidationError("No value passed and Param has no default value")
+ try:
+ jsonschema.validate(final_val, self.schema, format_checker=FormatChecker())
+ except ValidationError as err:
+ if suppress_exception:
+ return None
+ raise ParamValidationError(err) from None
+ self.value = final_val
+ return final_val
+
+ def dump(self) -> dict:
+ """Dump the Param as a dictionary."""
+ out_dict: dict[str, str | None] = {
+ self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"
+ }
+ out_dict.update(self.__dict__)
+ # Ensure that not set is translated to None
+ if self.value is NOTSET:
+ out_dict["value"] = None
+ return out_dict
+
+ @property
+ def has_value(self) -> bool:
+ return self.value is not NOTSET and self.value is not None
+
+ def serialize(self) -> dict:
+ return {"value": self.value, "description": self.description, "schema": self.schema}
+
+ @staticmethod
+ def deserialize(data: dict[str, Any], version: int) -> Param:
+ if version > Param.__version__:
+ raise TypeError("serialized version > class version")
+
+ return Param(default=data["value"], description=data["description"], schema=data["schema"])
+
+
+class ParamsDict(MutableMapping[str, Any]):
+ """
+ Class to hold all params for dags or tasks.
+
+ All the keys are strictly string and values are converted into Param's object
+ if they are not already. This class is to replace param's dictionary implicitly
+ and ideally not needed to be used directly.
+
+
+ :param dict_obj: A dict or dict like object to init ParamsDict
+ :param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict
+ """
+
+ __version__: ClassVar[int] = 1
+ __slots__ = ["__dict", "suppress_exception"]
+
+ def __init__(self, dict_obj: MutableMapping | None = None, suppress_exception: bool = False):
+ params_dict: dict[str, Param] = {}
+ dict_obj = dict_obj or {}
+ for k, v in dict_obj.items():
+ if not isinstance(v, Param):
+ params_dict[k] = Param(v)
+ else:
+ params_dict[k] = v
+ self.__dict = params_dict
+ self.suppress_exception = suppress_exception
+
+ def __bool__(self) -> bool:
+ return bool(self.__dict)
+
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, ParamsDict):
+ return self.dump() == other.dump()
+ if isinstance(other, dict):
+ return self.dump() == other
+ return NotImplemented
+
+ def __copy__(self) -> ParamsDict:
+ return ParamsDict(self.__dict, self.suppress_exception)
+
+ def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
+ return ParamsDict(copy.deepcopy(self.__dict, memo), self.suppress_exception)
+
+ def __contains__(self, o: object) -> bool:
+ return o in self.__dict
+
+ def __len__(self) -> int:
+ return len(self.__dict)
+
+ def __delitem__(self, v: str) -> None:
+ del self.__dict[v]
+
+ def __iter__(self):
+ return iter(self.__dict)
+
+ def __repr__(self):
+ return repr(self.dump())
+
+ def __setitem__(self, key: str, value: Any) -> None:
+ """
+ Override for dictionary's ``setitem`` method to ensure all values are of Param's type only.
+
+ :param key: A key which needs to be inserted or updated in the dict
+ :param value: A value which needs to be set against the key. It could be of any
+ type but will be converted and stored as a Param object eventually.
+ """
+ if isinstance(value, Param):
+ param = value
+ elif key in self.__dict:
+ param = self.__dict[key]
+ try:
+ param.resolve(value=value, suppress_exception=self.suppress_exception)
+ except ParamValidationError as ve:
+ raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None
+ else:
+ # if the key isn't there already and if the value isn't of Param type create a new Param object
+ param = Param(value)
+
+ self.__dict[key] = param
+
+ def __getitem__(self, key: str) -> Any:
+ """
+ Override for dictionary's ``getitem`` method to call the resolve method after fetching the key.
+
+ :param key: The key to fetch
+ """
+ param = self.__dict[key]
+ return param.resolve(suppress_exception=self.suppress_exception)
+
+ def get_param(self, key: str) -> Param:
+ """Get the internal :class:`.Param` object for this key."""
+ return self.__dict[key]
+
+ def items(self):
+ return ItemsView(self.__dict)
+
+ def values(self):
+ return ValuesView(self.__dict)
+
+ def update(self, *args, **kwargs) -> None:
+ if len(args) == 1 and not kwargs and isinstance(args[0], ParamsDict):
+ return super().update(args[0].__dict)
+ super().update(*args, **kwargs)
+
+ def dump(self) -> dict[str, Any]:
+ """Dump the ParamsDict object as a dictionary, while suppressing exceptions."""
+ return {k: v.resolve(suppress_exception=True) for k, v in self.items()}
+
+ def validate(self) -> dict[str, Any]:
+ """Validate & returns all the Params object stored in the dictionary."""
+ resolved_dict = {}
+ try:
+ for k, v in self.items():
+ resolved_dict[k] = v.resolve(suppress_exception=self.suppress_exception)
+ except ParamValidationError as ve:
+ raise ParamValidationError(f"Invalid input for param {k}: {ve}") from None
+
+ return resolved_dict
+
+ def serialize(self) -> dict[str, Any]:
+ return self.dump()
+
+ @staticmethod
+ def deserialize(data: dict, version: int) -> ParamsDict:
+ if version > ParamsDict.__version__:
+ raise TypeError("serialized version > class version")
+
+ return ParamsDict(data)
+
+
+class DagParam(ResolveMixin):
+ """
+ DAG run parameter reference.
+
+ This binds a simple Param object to a name within a DAG instance, so that it
+ can be resolved during the runtime via the ``{{ context }}`` dictionary. The
+ ideal use case of this class is to implicitly convert args passed to a
+ method decorated by ``@dag``.
+
+ It can be used to parameterize a DAG. You can overwrite its value by setting
+ it on conf when you trigger your DagRun.
+
+ This can also be used in templates by accessing ``{{ context.params }}``.
+
+ **Example**:
+
+ with DAG(...) as dag:
+ EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))
+
+ :param current_dag: Dag being used for parameter.
+ :param name: key value which is used to set the parameter
+ :param default: Default value used if no parameter was set.
+ """
+
+ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
+ if default is not NOTSET:
+ current_dag.params[name] = default
+ self._name = name
+ self._default = default
+ self.current_dag = current_dag
+
+ def iter_references(self) -> Iterable[tuple[Operator, str]]:
+ return ()
+
+ def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
+ """Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
+ with contextlib.suppress(KeyError):
+ if context["dag_run"].conf:
+ return context["dag_run"].conf[self._name]
+ if self._default is not NOTSET:
+ return self._default
+ with contextlib.suppress(KeyError):
+ return context["params"][self._name]
+ raise AirflowException(f"No value could be resolved for parameter {self._name}")
+
+ def serialize(self) -> dict:
+ """Serialize the DagParam object into a dictionary."""
+ return {
+ "dag_id": self.current_dag.dag_id,
+ "name": self._name,
+ "default": self._default,
+ }
+
+ @classmethod
+ def deserialize(cls, data: dict, dags: dict) -> DagParam:
+ """
+ Deserializes the dictionary back into a DagParam object.
+
+ :param data: The serialized representation of the DagParam.
+ :param dags: A dictionary of available DAGs to look up the DAG.
+ """
+ dag_id = data["dag_id"]
+ # Retrieve the current DAG from the provided DAGs dictionary
+ current_dag = dags.get(dag_id)
+ if not current_dag:
+ raise ValueError(f"DAG with id {dag_id} not found.")
+
+ return cls(current_dag=current_dag, name=data["name"], default=data["default"])
+
+
+def process_params(
+ dag: DAG,
+ task: Operator,
+ dagrun_conf: dict[str, Any] | None,
+ *,
+ suppress_exception: bool,
+) -> dict[str, Any]:
+ """Merge, validate params, and convert them into a simple dict."""
+ from airflow.configuration import conf
+
+ dagrun_conf = dagrun_conf or {}
+
+ params = ParamsDict(suppress_exception=suppress_exception)
+ with contextlib.suppress(AttributeError):
+ params.update(dag.params)
+ if task.params:
+ params.update(task.params)
+ if conf.getboolean("core", "dag_run_conf_overrides_params") and dagrun_conf:
+ logger.debug("Updating task params (%s) with DagRun.conf (%s)", params, dagrun_conf)
+ params.update(dagrun_conf)
+ return params.validate()
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index c2d2c51b630be..33ad6c9229270 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -37,6 +37,7 @@
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef
from airflow.sdk.definitions.baseoperator import BaseOperator
+from airflow.sdk.definitions.param import process_params
from airflow.sdk.execution_time.comms import (
DeferTask,
GetXCom,
@@ -84,6 +85,16 @@ def get_template_context(self) -> Context:
# TODO: Move this to `airflow.sdk.execution_time.context`
# once we port the entire context logic from airflow/utils/context.py ?
+ dag_run_conf = None
+ if (
+ self._ti_context_from_server
+ and self._ti_context_from_server.dag_run
+ and self._ti_context_from_server.dag_run.conf
+ ):
+ dag_run_conf = self._ti_context_from_server.dag_run.conf
+
+ validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)
+
# TODO: Assess if we need to it through airflow.utils.timezone.coerce_datetime()
context: Context = {
# From the Task Execution interface
@@ -100,7 +111,7 @@ def get_template_context(self) -> Context:
"outlet_events": OutletEventAccessors(),
# "inlet_events": InletEventsAccessors(task.inlets, session=session),
"macros": MacrosAccessor(),
- # "params": validated_params,
+ "params": validated_params,
# TODO: Make this go through Public API longer term.
# "test_mode": task_instance.test_mode,
# "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py
index e24f6e397d3e5..cc4bc4f96148a 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/conftest.py
@@ -184,6 +184,7 @@ def _make_context(
data_interval_end: str | datetime = "2024-12-01T01:00:00Z",
start_date: str | datetime = "2024-12-01T01:00:00Z",
run_type: str = "manual",
+ conf=None,
) -> TIRunContext:
return TIRunContext(
dag_run=DagRun(
@@ -194,6 +195,7 @@ def _make_context(
data_interval_end=data_interval_end, # type: ignore
start_date=start_date, # type: ignore
run_type=run_type, # type: ignore
+ conf=conf,
),
max_tries=0,
)
diff --git a/task_sdk/tests/definitions/test_dag.py b/task_sdk/tests/definitions/test_dag.py
index f0e634f19b667..e6baeabe98dee 100644
--- a/task_sdk/tests/definitions/test_dag.py
+++ b/task_sdk/tests/definitions/test_dag.py
@@ -23,9 +23,9 @@
import pytest
from airflow.exceptions import DuplicateTaskIdFound
-from airflow.models.param import Param, ParamsDict
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import DAG, dag as dag_decorator
+from airflow.sdk.definitions.param import Param, ParamsDict
DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
diff --git a/task_sdk/tests/definitions/test_mappedoperator.py b/task_sdk/tests/definitions/test_mappedoperator.py
index aba7523b5ad39..eeb79f31b4d47 100644
--- a/task_sdk/tests/definitions/test_mappedoperator.py
+++ b/task_sdk/tests/definitions/test_mappedoperator.py
@@ -22,10 +22,10 @@
import pendulum
import pytest
-from airflow.models.param import ParamsDict
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.mappedoperator import MappedOperator
+from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.definitions.xcom_arg import XComArg
from airflow.utils.trigger_rule import TriggerRule
diff --git a/task_sdk/tests/definitions/test_param.py b/task_sdk/tests/definitions/test_param.py
new file mode 100644
index 0000000000000..93e863222ef87
--- /dev/null
+++ b/task_sdk/tests/definitions/test_param.py
@@ -0,0 +1,308 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from contextlib import nullcontext
+
+import pytest
+
+from airflow.exceptions import ParamValidationError
+from airflow.sdk.definitions.param import Param, ParamsDict
+from airflow.serialization.serialized_objects import BaseSerialization
+
+
+class TestParam:
+ def test_param_without_schema(self):
+ p = Param("test")
+ assert p.resolve() == "test"
+
+ p.value = 10
+ assert p.resolve() == 10
+
+ def test_null_param(self):
+ p = Param()
+ with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
+ p.resolve()
+ assert p.resolve(None) is None
+ assert p.dump()["value"] is None
+ assert not p.has_value
+
+ p = Param(None)
+ assert p.resolve() is None
+ assert p.resolve(None) is None
+ assert p.dump()["value"] is None
+ assert not p.has_value
+
+ p = Param(None, type="null")
+ assert p.resolve() is None
+ assert p.resolve(None) is None
+ assert p.dump()["value"] is None
+ assert not p.has_value
+ with pytest.raises(ParamValidationError):
+ p.resolve("test")
+
+ def test_string_param(self):
+ p = Param("test", type="string")
+ assert p.resolve() == "test"
+
+ p = Param("test")
+ assert p.resolve() == "test"
+
+ p = Param("10.0.0.0", type="string", format="ipv4")
+ assert p.resolve() == "10.0.0.0"
+
+ p = Param(type="string")
+ with pytest.raises(ParamValidationError):
+ p.resolve(None)
+ with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
+ p.resolve()
+
+ @pytest.mark.parametrize(
+ "dt",
+ [
+ pytest.param("2022-01-02T03:04:05.678901Z", id="microseconds-zed-timezone"),
+ pytest.param("2022-01-02T03:04:05.678Z", id="milliseconds-zed-timezone"),
+ pytest.param("2022-01-02T03:04:05+00:00", id="seconds-00-00-timezone"),
+ pytest.param("2022-01-02T03:04:05+04:00", id="seconds-custom-timezone"),
+ ],
+ )
+ def test_string_rfc3339_datetime_format(self, dt):
+ """Test valid rfc3339 datetime."""
+ assert Param(dt, type="string", format="date-time").resolve() == dt
+
+ @pytest.mark.parametrize(
+ "dt",
+ [
+ pytest.param("2022-01-02", id="date"),
+ pytest.param("03:04:05", id="time"),
+ pytest.param("Thu, 04 Mar 2021 05:06:07 GMT", id="rfc2822-datetime"),
+ ],
+ )
+ def test_string_datetime_invalid_format(self, dt):
+ """Test invalid iso8601 and rfc3339 datetime format."""
+ with pytest.raises(ParamValidationError, match="is not a 'date-time'"):
+ Param(dt, type="string", format="date-time").resolve()
+
+ def test_string_time_format(self):
+ """Test string time format."""
+ assert Param("03:04:05", type="string", format="time").resolve() == "03:04:05"
+
+ error_pattern = "is not a 'time'"
+ with pytest.raises(ParamValidationError, match=error_pattern):
+ Param("03:04:05.06", type="string", format="time").resolve()
+
+ with pytest.raises(ParamValidationError, match=error_pattern):
+ Param("03:04", type="string", format="time").resolve()
+
+ with pytest.raises(ParamValidationError, match=error_pattern):
+ Param("24:00:00", type="string", format="time").resolve()
+
+ @pytest.mark.parametrize(
+ "date_string",
+ [
+ "2021-01-01",
+ ],
+ )
+ def test_string_date_format(self, date_string):
+ """Test string date format."""
+ assert Param(date_string, type="string", format="date").resolve() == date_string
+
+ # Note that 20120503 behaved differently in 3.11.3 Official python image. It was validated as a date
+ # there but it started to fail again in 3.11.4 released on 2023-07-05.
+ @pytest.mark.parametrize(
+ "date_string",
+ [
+ "01/01/2021",
+ "21 May 1975",
+ "20120503",
+ ],
+ )
+ def test_string_date_format_error(self, date_string):
+ """Test string date format failures."""
+ with pytest.raises(ParamValidationError, match="is not a 'date'"):
+ Param(date_string, type="string", format="date").resolve()
+
+ def test_int_param(self):
+ p = Param(5)
+ assert p.resolve() == 5
+
+ p = Param(type="integer", minimum=0, maximum=10)
+ assert p.resolve(value=5) == 5
+
+ with pytest.raises(ParamValidationError):
+ p.resolve(value=20)
+
+ def test_number_param(self):
+ p = Param(42, type="number")
+ assert p.resolve() == 42
+
+ p = Param(1.2, type="number")
+ assert p.resolve() == 1.2
+
+ p = Param("42", type="number")
+ with pytest.raises(ParamValidationError):
+ p.resolve()
+
+ def test_list_param(self):
+ p = Param([1, 2], type="array")
+ assert p.resolve() == [1, 2]
+
+ def test_dict_param(self):
+ p = Param({"a": 1, "b": 2}, type="object")
+ assert p.resolve() == {"a": 1, "b": 2}
+
+ def test_composite_param(self):
+ p = Param(type=["string", "number"])
+ assert p.resolve(value="abc") == "abc"
+ assert p.resolve(value=5.0) == 5.0
+
+ def test_param_with_description(self):
+ p = Param(10, description="Sample description")
+ assert p.description == "Sample description"
+
+ def test_suppress_exception(self):
+ p = Param("abc", type="string", minLength=2, maxLength=4)
+ assert p.resolve() == "abc"
+
+ p.value = "long_string"
+ assert p.resolve(suppress_exception=True) is None
+
+ def test_explicit_schema(self):
+ p = Param("abc", schema={type: "string"})
+ assert p.resolve() == "abc"
+
+ def test_custom_param(self):
+ class S3Param(Param):
+ def __init__(self, path: str):
+ schema = {"type": "string", "pattern": r"s3:\/\/(.+?)\/(.+)"}
+ super().__init__(default=path, schema=schema)
+
+ p = S3Param("s3://my_bucket/my_path")
+ assert p.resolve() == "s3://my_bucket/my_path"
+
+ p = S3Param("file://not_valid/s3_path")
+ with pytest.raises(ParamValidationError):
+ p.resolve()
+
+ def test_value_saved(self):
+ p = Param("hello", type="string")
+ assert p.resolve("world") == "world"
+ assert p.resolve() == "world"
+
+ def test_dump(self):
+ p = Param("hello", description="world", type="string", minLength=2)
+ dump = p.dump()
+ assert dump["__class"] == "airflow.sdk.definitions.param.Param"
+ assert dump["value"] == "hello"
+ assert dump["description"] == "world"
+ assert dump["schema"] == {"type": "string", "minLength": 2}
+
+ @pytest.mark.parametrize(
+ "param",
+ [
+ Param("my value", description="hello", schema={"type": "string"}),
+ Param("my value", description="hello"),
+ Param(None, description=None),
+ Param([True], type="array", items={"type": "boolean"}),
+ Param(),
+ ],
+ )
+ def test_param_serialization(self, param: Param):
+ """
+ Test to make sure that native Param objects can be correctly serialized
+ """
+
+ serializer = BaseSerialization()
+ serialized_param = serializer.serialize(param)
+ restored_param: Param = serializer.deserialize(serialized_param)
+
+ assert restored_param.value == param.value
+ assert isinstance(restored_param, Param)
+ assert restored_param.description == param.description
+ assert restored_param.schema == param.schema
+
+ @pytest.mark.parametrize(
+ "default, should_raise",
+ [
+ pytest.param({0, 1, 2}, True, id="default-non-JSON-serializable"),
+ pytest.param(None, False, id="default-None"), # Param init should not warn
+ pytest.param({"b": 1}, False, id="default-JSON-serializable"), # Param init should not warn
+ ],
+ )
+ def test_param_json_validation(self, default, should_raise):
+ exception_msg = "All provided parameters must be json-serializable"
+ cm = pytest.raises(ParamValidationError, match=exception_msg) if should_raise else nullcontext()
+ with cm:
+ p = Param(default=default)
+ if not should_raise:
+ p.resolve() # when resolved with NOTSET, should not warn.
+ p.resolve(value={"a": 1}) # when resolved with JSON-serializable, should not warn.
+ with pytest.raises(ParamValidationError, match=exception_msg):
+ p.resolve(value={1, 2, 3}) # when resolved with not JSON-serializable, should warn.
+
+
+class TestParamsDict:
+ def test_params_dict(self):
+ # Init with a simple dictionary
+ pd = ParamsDict(dict_obj={"key": "value"})
+ assert isinstance(pd.get_param("key"), Param)
+ assert pd["key"] == "value"
+ assert pd.suppress_exception is False
+
+ # Init with a dict which contains Param objects
+ pd2 = ParamsDict({"key": Param("value", type="string")}, suppress_exception=True)
+ assert isinstance(pd2.get_param("key"), Param)
+ assert pd2["key"] == "value"
+ assert pd2.suppress_exception is True
+
+ # Init with another object of another ParamsDict
+ pd3 = ParamsDict(pd2)
+ assert isinstance(pd3.get_param("key"), Param)
+ assert pd3["key"] == "value"
+ assert pd3.suppress_exception is False # as it's not a deepcopy of pd2
+
+ # Dump the ParamsDict
+ assert pd.dump() == {"key": "value"}
+ assert pd2.dump() == {"key": "value"}
+ assert pd3.dump() == {"key": "value"}
+
+ # Validate the ParamsDict
+ plain_dict = pd.validate()
+ assert isinstance(plain_dict, dict)
+ pd2.validate()
+ pd3.validate()
+
+ # Update the ParamsDict
+ with pytest.raises(ParamValidationError, match=r"Invalid input for param key: 1 is not"):
+ pd3["key"] = 1
+
+ # Should not raise an error as suppress_exception is True
+ pd2["key"] = 1
+ pd2.validate()
+
+ def test_update(self):
+ pd = ParamsDict({"key": Param("value", type="string")})
+
+ pd.update({"key": "a"})
+ internal_value = pd.get_param("key")
+ assert isinstance(internal_value, Param)
+ with pytest.raises(ParamValidationError, match=r"Invalid input for param key: 1 is not"):
+ pd.update({"key": 1})
+
+ def test_repr(self):
+ pd = ParamsDict({"key": Param("value", type="string")})
+ assert repr(pd) == "{'key': 'value'}"
diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py
index 832f2b60ca351..ac0c21246c1ce 100644
--- a/task_sdk/tests/execution_time/conftest.py
+++ b/task_sdk/tests/execution_time/conftest.py
@@ -71,6 +71,8 @@ def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTas
from airflow.utils import timezone
dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3))
+ if what.ti_context.dag_run.conf:
+ dag.params = what.ti_context.dag_run.conf # type: ignore[assignment]
task.dag = dag
t = dag.task_dict[task.task_id]
ti = RuntimeTaskInstance.model_construct(
@@ -120,6 +122,7 @@ def _create_task_instance(
start_date: str | datetime = "2024-12-01T01:00:00Z",
run_type: str = "manual",
try_number: int = 1,
+ conf=None,
ti_id=None,
) -> RuntimeTaskInstance:
if not ti_id:
@@ -133,6 +136,7 @@ def _create_task_instance(
data_interval_end=data_interval_end,
start_date=start_date,
run_type=run_type,
+ conf=conf,
)
startup_details = StartupDetails(
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index 250b7e765a0cf..574d3a8617311 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -672,6 +672,7 @@ def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_
# Verify the context keys and values
assert context == {
+ "params": {},
"var": {
"json": VariableAccessor(deserialize_json=True),
"value": VariableAccessor(deserialize_json=False),
@@ -712,6 +713,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s
context = runtime_ti.get_template_context()
assert context == {
+ "params": {},
"var": {
"json": VariableAccessor(deserialize_json=True),
"value": VariableAccessor(deserialize_json=False),
@@ -944,6 +946,36 @@ def execute(self, context):
),
)
+ def test_get_param_from_context(
+ self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti
+ ):
+ """Test that a params can be retrieved from context."""
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ value = context["params"]
+ print("The dag params are", value)
+
+ task = CustomOperator(task_id="print-params")
+ runtime_ti = create_runtime_ti(
+ dag_id="basic_param_dag",
+ task=task,
+ conf={
+ "x": 3,
+ "text": "Hello World!",
+ "flag": False,
+ "a_simple_list": ["one", "two", "three", "actually one value is made per line"],
+ },
+ )
+ run(runtime_ti, log=mock.MagicMock())
+
+ assert runtime_ti.task.dag.params == {
+ "x": 3,
+ "text": "Hello World!",
+ "flag": False,
+ "a_simple_list": ["one", "two", "three", "actually one value is made per line"],
+ }
+
class TestXComAfterTaskExecution:
@pytest.mark.parametrize(
@@ -1053,3 +1085,85 @@ def execute(self, context):
assert str(exc_info.value) == (
f"Returned dictionary keys must be strings when using multiple_outputs, found 2 ({int}) instead"
)
+
+
+class TestDagParamRuntime:
+ def test_dag_param_resolves_from_task(self, create_runtime_ti, mock_supervisor_comms, time_machine):
+ """Test dagparam resolves on operator execution"""
+ instant = timezone.datetime(2024, 12, 3, 10, 0)
+ time_machine.move_to(instant, tick=False)
+
+ dag = DAG(dag_id="dag_with_dag_params", start_date=timezone.datetime(2024, 12, 3))
+ dag.param("value", default="NOTSET")
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ assert dag.params["value"] == "NOTSET"
+
+ task = CustomOperator(task_id="task_with_dag_params")
+ runtime_ti = create_runtime_ti(task=task, dag_id="dag_with_dag_params")
+
+ run(runtime_ti, log=mock.MagicMock())
+
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ msg=SucceedTask(
+ state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
+ ),
+ log=mock.ANY,
+ )
+
+ def test_dag_param_dag_overwrite(self, create_runtime_ti, mock_supervisor_comms, time_machine):
+ """Test dag param is overwritten from dagrun config"""
+ instant = timezone.datetime(2024, 12, 3, 10, 0)
+ time_machine.move_to(instant, tick=False)
+
+ dag = DAG(dag_id="dag_with_dag_params_overwrite", start_date=timezone.datetime(2024, 12, 3))
+ dag.param("value", default="NOTSET")
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ # important to use self.dag here
+ assert self.dag.params["value"] == "new_value"
+
+ # asserting on the default value when not set in dag run
+ assert dag.params["value"] == "NOTSET"
+ task = CustomOperator(task_id="task_with_dag_params_overwrite")
+
+ # we reparse the dag here, and if conf passed, added as params
+ runtime_ti = create_runtime_ti(
+ task=task, dag_id="dag_with_dag_params_overwrite", conf={"value": "new_value"}
+ )
+ run(runtime_ti, log=mock.MagicMock())
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ msg=SucceedTask(
+ state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
+ ),
+ log=mock.ANY,
+ )
+
+ def test_dag_param_dag_default(self, create_runtime_ti, mock_supervisor_comms, time_machine):
+ """ "Test dag param is retrieved from default config"""
+ instant = timezone.datetime(2024, 12, 3, 10, 0)
+ time_machine.move_to(instant, tick=False)
+
+ dag = DAG(
+ dag_id="dag_with_dag_params_default",
+ start_date=timezone.datetime(2024, 12, 3),
+ params={"value": "test"},
+ )
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ assert dag.params["value"] == "test"
+
+ assert dag.params["value"] == "test"
+ task = CustomOperator(task_id="task_with_dag_params_default")
+ runtime_ti = create_runtime_ti(task=task, dag_id="dag_with_dag_params_default")
+
+ run(runtime_ti, log=mock.MagicMock())
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ msg=SucceedTask(
+ state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
+ ),
+ log=mock.ANY,
+ )
diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py
index 93bdb600cb80a..79190a3664709 100644
--- a/tests/api_connexion/endpoints/test_dag_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_endpoint.py
@@ -320,7 +320,7 @@ def test_should_respond_200(self, url_safe_serializer):
"owners": [],
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"description": None,
"schema": {},
"value": 1,
@@ -380,7 +380,7 @@ def test_should_respond_200_with_asset_expression(self, url_safe_serializer):
"owners": [],
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"description": None,
"schema": {},
"value": 1,
@@ -533,7 +533,7 @@ def test_should_respond_200_serialized(self, url_safe_serializer, testing_dag_bu
"owners": [],
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"description": None,
"schema": {},
"value": 1,
@@ -591,7 +591,7 @@ def test_should_respond_200_serialized(self, url_safe_serializer, testing_dag_bu
"owners": [],
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"description": None,
"schema": {},
"value": 1,
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index 9558dd4fd256a..dc3073a475c3d 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -30,9 +30,9 @@
from airflow.models.asset import AssetEvent, AssetModel
from airflow.models.dag import DAG, DagModel
from airflow.models.dagrun import DagRun
-from airflow.models.param import Param
from airflow.operators.empty import EmptyOperator
from airflow.sdk.definitions.asset import Asset
+from airflow.sdk.definitions.param import Param
from airflow.utils import timezone
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import DagRunState, State
diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py
index 826b912ddc989..874c5d0508547 100644
--- a/tests/api_connexion/endpoints/test_task_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_endpoint.py
@@ -124,7 +124,7 @@ def test_should_respond_200(self):
"owner": "airflow",
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": "bar",
"description": None,
"schema": {},
@@ -207,7 +207,7 @@ def test_unscheduled_task(self):
"owner": "airflow",
"params": {
"is_unscheduled": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": True,
"description": None,
"schema": {},
@@ -271,7 +271,7 @@ def test_should_respond_200_serialized(self, testing_dag_bundle):
"owner": "airflow",
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": "bar",
"description": None,
"schema": {},
@@ -348,7 +348,7 @@ def test_should_respond_200(self):
"owner": "airflow",
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": "bar",
"description": None,
"schema": {},
@@ -508,7 +508,7 @@ def test_get_unscheduled_tasks(self):
"owner": "airflow",
"params": {
"is_unscheduled": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": True,
"description": None,
"schema": {},
diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py
index d6438045249aa..800b512f993bc 100644
--- a/tests/api_connexion/schemas/test_dag_schema.py
+++ b/tests/api_connexion/schemas/test_dag_schema.py
@@ -167,7 +167,7 @@ def test_serialize_test_dag_detail_schema(url_safe_serializer):
"owners": [],
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": 1,
"description": None,
"schema": {},
@@ -229,7 +229,7 @@ def test_serialize_test_dag_with_asset_schedule_detail_schema(url_safe_serialize
"owners": [],
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": 1,
"description": None,
"schema": {},
diff --git a/tests/api_connexion/schemas/test_task_schema.py b/tests/api_connexion/schemas/test_task_schema.py
index 5748529b864af..eee51c3aac73a 100644
--- a/tests/api_connexion/schemas/test_task_schema.py
+++ b/tests/api_connexion/schemas/test_task_schema.py
@@ -86,7 +86,7 @@ def test_serialize(self):
"owner": "airflow",
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": "bar",
"description": None,
"schema": {},
diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
index b316b0119dd19..3d70f4dbf29f8 100644
--- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py
+++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
@@ -27,9 +27,9 @@
from airflow.listeners.listener import get_listener_manager
from airflow.models import DagModel, DagRun
from airflow.models.asset import AssetEvent, AssetModel
-from airflow.models.param import Param
from airflow.operators.empty import EmptyOperator
from airflow.sdk.definitions.asset import Asset
+from airflow.sdk.definitions.param import Param
from airflow.utils import timezone
from airflow.utils.session import provide_session
from airflow.utils.state import DagRunState, State
diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py
index 748baae71a413..8ef4a82613775 100644
--- a/tests/api_fastapi/core_api/routes/public/test_dags.py
+++ b/tests/api_fastapi/core_api/routes/public/test_dags.py
@@ -377,7 +377,7 @@ def test_dag_details(
"owners": ["airflow"],
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"description": None,
"schema": {},
"value": 1,
diff --git a/tests/api_fastapi/core_api/routes/public/test_tasks.py b/tests/api_fastapi/core_api/routes/public/test_tasks.py
index 2c00a9e96a7b5..b2e7671365690 100644
--- a/tests/api_fastapi/core_api/routes/public/test_tasks.py
+++ b/tests/api_fastapi/core_api/routes/public/test_tasks.py
@@ -103,7 +103,7 @@ def test_should_respond_200(self, test_client):
"owner": "airflow",
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": "bar",
"description": None,
"schema": {},
@@ -185,7 +185,7 @@ def test_unscheduled_task(self, test_client):
"owner": "airflow",
"params": {
"is_unscheduled": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": True,
"description": None,
"schema": {},
@@ -248,7 +248,7 @@ def test_should_respond_200_serialized(self, test_client, testing_dag_bundle):
"owner": "airflow",
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": "bar",
"description": None,
"schema": {},
@@ -313,7 +313,7 @@ def test_should_respond_200(self, test_client):
"owner": "airflow",
"params": {
"foo": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": "bar",
"description": None,
"schema": {},
@@ -469,7 +469,7 @@ def test_get_unscheduled_tasks(self, test_client):
"owner": "airflow",
"params": {
"is_unscheduled": {
- "__class": "airflow.models.param.Param",
+ "__class": "airflow.sdk.definitions.param.Param",
"value": True,
"description": None,
"schema": {},
diff --git a/tests/dags/test_invalid_param.py b/tests/dags/test_invalid_param.py
index fb0d3c854d12d..547fc7c11253d 100644
--- a/tests/dags/test_invalid_param.py
+++ b/tests/dags/test_invalid_param.py
@@ -19,8 +19,8 @@
from datetime import datetime
from airflow.models.dag import DAG
-from airflow.models.param import Param
from airflow.providers.standard.operators.python import PythonOperator
+from airflow.sdk.definitions.param import Param
with DAG(
"test_invalid_param",
diff --git a/tests/dags/test_invalid_param2.py b/tests/dags/test_invalid_param2.py
index 69ffda442301d..5678f46090c89 100644
--- a/tests/dags/test_invalid_param2.py
+++ b/tests/dags/test_invalid_param2.py
@@ -19,8 +19,8 @@
from datetime import datetime
from airflow import DAG
-from airflow.models.param import Param
from airflow.providers.standard.operators.python import PythonOperator
+from airflow.sdk.definitions.param import Param
with DAG(
"test_invalid_param2",
diff --git a/tests/dags/test_invalid_param3.py b/tests/dags/test_invalid_param3.py
index a8017a3402b66..ea3bfa202a319 100644
--- a/tests/dags/test_invalid_param3.py
+++ b/tests/dags/test_invalid_param3.py
@@ -19,8 +19,8 @@
from datetime import datetime
from airflow import DAG
-from airflow.models.param import Param
from airflow.providers.standard.operators.python import PythonOperator
+from airflow.sdk.definitions.param import Param
with DAG(
"test_invalid_param3",
diff --git a/tests/dags/test_invalid_param4.py b/tests/dags/test_invalid_param4.py
index bbfc7e970c51c..0156072ba11cf 100644
--- a/tests/dags/test_invalid_param4.py
+++ b/tests/dags/test_invalid_param4.py
@@ -19,8 +19,8 @@
from datetime import datetime
from airflow import DAG
-from airflow.models.param import Param
from airflow.providers.standard.operators.python import PythonOperator
+from airflow.sdk.definitions.param import Param
with DAG(
"test_invalid_param4",
diff --git a/tests/dags/test_valid_param.py b/tests/dags/test_valid_param.py
index afa0f98ce21d5..ddb858a9acc1e 100644
--- a/tests/dags/test_valid_param.py
+++ b/tests/dags/test_valid_param.py
@@ -19,8 +19,8 @@
from datetime import datetime
from airflow import DAG
-from airflow.models.param import Param
from airflow.providers.standard.operators.python import PythonOperator
+from airflow.sdk.definitions.param import Param
with DAG(
"test_valid_param",
diff --git a/tests/dags/test_valid_param2.py b/tests/dags/test_valid_param2.py
index d59d6278c3a71..ee6920bd92ee7 100644
--- a/tests/dags/test_valid_param2.py
+++ b/tests/dags/test_valid_param2.py
@@ -19,8 +19,8 @@
from datetime import datetime
from airflow import DAG
-from airflow.models.param import Param
from airflow.providers.standard.operators.python import PythonOperator
+from airflow.sdk.definitions.param import Param
with DAG(
"test_valid_param2",
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index bafddd532738f..2112f7adae6d0 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -63,7 +63,6 @@
)
from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import DagRun
-from airflow.models.param import DagParam, Param
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance as TI
from airflow.operators.empty import EmptyOperator
@@ -73,6 +72,7 @@
from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext
from airflow.sdk.definitions._internal.templater import NativeEnvironment, SandboxedEnvironment
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny
+from airflow.sdk.definitions.param import DagParam, Param
from airflow.security import permissions
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.timetables.simple import (
@@ -150,7 +150,7 @@ def _create_dagrun(
triggered_by_kwargs: dict = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
run_id = dag.timetable.generate_run_id(
run_type=run_type,
- logical_date=logical_date,
+ logical_date=logical_date, # type: ignore
data_interval=data_interval,
)
return dag.create_dagrun(
diff --git a/tests/models/test_param.py b/tests/models/test_param.py
index 152419db2fa4d..77cf96eda2226 100644
--- a/tests/models/test_param.py
+++ b/tests/models/test_param.py
@@ -22,278 +22,13 @@
from airflow.decorators import task
from airflow.exceptions import ParamValidationError
-from airflow.models.param import Param, ParamsDict
-from airflow.serialization.serialized_objects import BaseSerialization
+from airflow.sdk.definitions.param import Param
from airflow.utils import timezone
from airflow.utils.types import DagRunType
from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom
-class TestParam:
- def test_param_without_schema(self):
- p = Param("test")
- assert p.resolve() == "test"
-
- p.value = 10
- assert p.resolve() == 10
-
- def test_null_param(self):
- p = Param()
- with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
- p.resolve()
- assert p.resolve(None) is None
- assert p.dump()["value"] is None
- assert not p.has_value
-
- p = Param(None)
- assert p.resolve() is None
- assert p.resolve(None) is None
- assert p.dump()["value"] is None
- assert not p.has_value
-
- p = Param(None, type="null")
- assert p.resolve() is None
- assert p.resolve(None) is None
- assert p.dump()["value"] is None
- assert not p.has_value
- with pytest.raises(ParamValidationError):
- p.resolve("test")
-
- def test_string_param(self):
- p = Param("test", type="string")
- assert p.resolve() == "test"
-
- p = Param("test")
- assert p.resolve() == "test"
-
- p = Param("10.0.0.0", type="string", format="ipv4")
- assert p.resolve() == "10.0.0.0"
-
- p = Param(type="string")
- with pytest.raises(ParamValidationError):
- p.resolve(None)
- with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
- p.resolve()
-
- @pytest.mark.parametrize(
- "dt",
- [
- pytest.param("2022-01-02T03:04:05.678901Z", id="microseconds-zed-timezone"),
- pytest.param("2022-01-02T03:04:05.678Z", id="milliseconds-zed-timezone"),
- pytest.param("2022-01-02T03:04:05+00:00", id="seconds-00-00-timezone"),
- pytest.param("2022-01-02T03:04:05+04:00", id="seconds-custom-timezone"),
- ],
- )
- def test_string_rfc3339_datetime_format(self, dt):
- """Test valid rfc3339 datetime."""
- assert Param(dt, type="string", format="date-time").resolve() == dt
-
- @pytest.mark.parametrize(
- "dt",
- [
- pytest.param("2022-01-02", id="date"),
- pytest.param("03:04:05", id="time"),
- pytest.param("Thu, 04 Mar 2021 05:06:07 GMT", id="rfc2822-datetime"),
- ],
- )
- def test_string_datetime_invalid_format(self, dt):
- """Test invalid iso8601 and rfc3339 datetime format."""
- with pytest.raises(ParamValidationError, match="is not a 'date-time'"):
- Param(dt, type="string", format="date-time").resolve()
-
- def test_string_time_format(self):
- """Test string time format."""
- assert Param("03:04:05", type="string", format="time").resolve() == "03:04:05"
-
- error_pattern = "is not a 'time'"
- with pytest.raises(ParamValidationError, match=error_pattern):
- Param("03:04:05.06", type="string", format="time").resolve()
-
- with pytest.raises(ParamValidationError, match=error_pattern):
- Param("03:04", type="string", format="time").resolve()
-
- with pytest.raises(ParamValidationError, match=error_pattern):
- Param("24:00:00", type="string", format="time").resolve()
-
- @pytest.mark.parametrize(
- "date_string",
- [
- "2021-01-01",
- ],
- )
- def test_string_date_format(self, date_string):
- """Test string date format."""
- assert Param(date_string, type="string", format="date").resolve() == date_string
-
- # Note that 20120503 behaved differently in 3.11.3 Official python image. It was validated as a date
- # there but it started to fail again in 3.11.4 released on 2023-07-05.
- @pytest.mark.parametrize(
- "date_string",
- [
- "01/01/2021",
- "21 May 1975",
- "20120503",
- ],
- )
- def test_string_date_format_error(self, date_string):
- """Test string date format failures."""
- with pytest.raises(ParamValidationError, match="is not a 'date'"):
- Param(date_string, type="string", format="date").resolve()
-
- def test_int_param(self):
- p = Param(5)
- assert p.resolve() == 5
-
- p = Param(type="integer", minimum=0, maximum=10)
- assert p.resolve(value=5) == 5
-
- with pytest.raises(ParamValidationError):
- p.resolve(value=20)
-
- def test_number_param(self):
- p = Param(42, type="number")
- assert p.resolve() == 42
-
- p = Param(1.2, type="number")
- assert p.resolve() == 1.2
-
- p = Param("42", type="number")
- with pytest.raises(ParamValidationError):
- p.resolve()
-
- def test_list_param(self):
- p = Param([1, 2], type="array")
- assert p.resolve() == [1, 2]
-
- def test_dict_param(self):
- p = Param({"a": 1, "b": 2}, type="object")
- assert p.resolve() == {"a": 1, "b": 2}
-
- def test_composite_param(self):
- p = Param(type=["string", "number"])
- assert p.resolve(value="abc") == "abc"
- assert p.resolve(value=5.0) == 5.0
-
- def test_param_with_description(self):
- p = Param(10, description="Sample description")
- assert p.description == "Sample description"
-
- def test_suppress_exception(self):
- p = Param("abc", type="string", minLength=2, maxLength=4)
- assert p.resolve() == "abc"
-
- p.value = "long_string"
- assert p.resolve(suppress_exception=True) is None
-
- def test_explicit_schema(self):
- p = Param("abc", schema={type: "string"})
- assert p.resolve() == "abc"
-
- def test_custom_param(self):
- class S3Param(Param):
- def __init__(self, path: str):
- schema = {"type": "string", "pattern": r"s3:\/\/(.+?)\/(.+)"}
- super().__init__(default=path, schema=schema)
-
- p = S3Param("s3://my_bucket/my_path")
- assert p.resolve() == "s3://my_bucket/my_path"
-
- p = S3Param("file://not_valid/s3_path")
- with pytest.raises(ParamValidationError):
- p.resolve()
-
- def test_value_saved(self):
- p = Param("hello", type="string")
- assert p.resolve("world") == "world"
- assert p.resolve() == "world"
-
- def test_dump(self):
- p = Param("hello", description="world", type="string", minLength=2)
- dump = p.dump()
- assert dump["__class"] == "airflow.models.param.Param"
- assert dump["value"] == "hello"
- assert dump["description"] == "world"
- assert dump["schema"] == {"type": "string", "minLength": 2}
-
- @pytest.mark.parametrize(
- "param",
- [
- Param("my value", description="hello", schema={"type": "string"}),
- Param("my value", description="hello"),
- Param(None, description=None),
- Param([True], type="array", items={"type": "boolean"}),
- Param(),
- ],
- )
- def test_param_serialization(self, param: Param):
- """
- Test to make sure that native Param objects can be correctly serialized
- """
-
- serializer = BaseSerialization()
- serialized_param = serializer.serialize(param)
- restored_param: Param = serializer.deserialize(serialized_param)
-
- assert restored_param.value == param.value
- assert isinstance(restored_param, Param)
- assert restored_param.description == param.description
- assert restored_param.schema == param.schema
-
-
-class TestParamsDict:
- def test_params_dict(self):
- # Init with a simple dictionary
- pd = ParamsDict(dict_obj={"key": "value"})
- assert isinstance(pd.get_param("key"), Param)
- assert pd["key"] == "value"
- assert pd.suppress_exception is False
-
- # Init with a dict which contains Param objects
- pd2 = ParamsDict({"key": Param("value", type="string")}, suppress_exception=True)
- assert isinstance(pd2.get_param("key"), Param)
- assert pd2["key"] == "value"
- assert pd2.suppress_exception is True
-
- # Init with another object of another ParamsDict
- pd3 = ParamsDict(pd2)
- assert isinstance(pd3.get_param("key"), Param)
- assert pd3["key"] == "value"
- assert pd3.suppress_exception is False # as it's not a deepcopy of pd2
-
- # Dump the ParamsDict
- assert pd.dump() == {"key": "value"}
- assert pd2.dump() == {"key": "value"}
- assert pd3.dump() == {"key": "value"}
-
- # Validate the ParamsDict
- plain_dict = pd.validate()
- assert isinstance(plain_dict, dict)
- pd2.validate()
- pd3.validate()
-
- # Update the ParamsDict
- with pytest.raises(ParamValidationError, match=r"Invalid input for param key: 1 is not"):
- pd3["key"] = 1
-
- # Should not raise an error as suppress_exception is True
- pd2["key"] = 1
- pd2.validate()
-
- def test_update(self):
- pd = ParamsDict({"key": Param("value", type="string")})
-
- pd.update({"key": "a"})
- internal_value = pd.get_param("key")
- assert isinstance(internal_value, Param)
- with pytest.raises(ParamValidationError, match=r"Invalid input for param key: 1 is not"):
- pd.update({"key": 1})
-
- def test_repr(self):
- pd = ParamsDict({"key": Param("value", type="string")})
- assert repr(pd) == "{'key': 'value'}"
-
-
class TestDagParamRuntime:
VALUE = 42
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index e01daac205ad4..faa157ac4d1e8 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -60,7 +60,6 @@
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
from airflow.models.expandinput import EXPAND_INPUT_EMPTY, NotFullyPopulated
-from airflow.models.param import process_params
from airflow.models.pool import Pool
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.serialized_dag import SerializedDagModel
@@ -81,6 +80,7 @@
from airflow.providers.standard.operators.python import PythonOperator
from airflow.providers.standard.sensors.python import PythonSensor
from airflow.sdk.definitions.asset import Asset, AssetAlias
+from airflow.sdk.definitions.param import process_params
from airflow.sensors.base import BaseSensorOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
from airflow.stats import Stats
diff --git a/tests/serialization/serializers/test_serializers.py b/tests/serialization/serializers/test_serializers.py
index 5936a95b23d6d..f3afdbbf769cc 100644
--- a/tests/serialization/serializers/test_serializers.py
+++ b/tests/serialization/serializers/test_serializers.py
@@ -31,7 +31,7 @@
from pendulum import DateTime
from pendulum.tz.timezone import FixedTimezone, Timezone
-from airflow.models.param import Param, ParamsDict
+from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.serialization.serde import DATA, deserialize, serialize
PENDULUM3 = version.parse(metadata.version("pendulum")).major == 3
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 84a63674e5119..773be6951fb53 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -63,13 +63,13 @@
from airflow.models.dagbag import DagBag
from airflow.models.expandinput import EXPAND_INPUT_EMPTY
from airflow.models.mappedoperator import MappedOperator
-from airflow.models.param import Param, ParamsDict
from airflow.models.xcom import XCom
from airflow.operators.empty import EmptyOperator
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.sensors.bash import BashSensor
from airflow.sdk.definitions.asset import Asset
+from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.security import permissions
from airflow.serialization.enums import Encoding
from airflow.serialization.json_schema import load_dag_schema_dict
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 06bb477becdf4..ca6cb78a62794 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -37,12 +37,12 @@
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
-from airflow.models.param import Param
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey
+from airflow.sdk.definitions.param import Param
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.serialized_objects import BaseSerialization
diff --git a/tests/www/views/test_views_trigger_dag.py b/tests/www/views/test_views_trigger_dag.py
index 17d0b687b8572..c4136520d7f9b 100644
--- a/tests/www/views/test_views_trigger_dag.py
+++ b/tests/www/views/test_views_trigger_dag.py
@@ -25,8 +25,8 @@
import pytest
from airflow.models import DagBag, DagRun
-from airflow.models.param import Param
from airflow.operators.empty import EmptyOperator
+from airflow.sdk.definitions.param import Param
from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.json import WebEncoder