diff --git a/airflow-core/newsfragments/68175.significant.rst b/airflow-core/newsfragments/68175.significant.rst deleted file mode 100644 index 4c9f7b4dffc18..0000000000000 --- a/airflow-core/newsfragments/68175.significant.rst +++ /dev/null @@ -1,24 +0,0 @@ -Airflow CLI commands are moving to talk to the API server - -The CLI is being migrated to reach Airflow through the API server (via the ``airflowctl`` -client) instead of the metadata database directly. Migrated so far: ``dags trigger``, -``dags delete``, ``pools`` (list/get/set/delete/import/export), and ``assets materialize``; -this fragment is updated as more commands migrate rather than adding new ones. - -These commands now require a reachable API server and mint a short-lived token in memory -(set ``AIRFLOW_CLI_TOKEN`` for auth managers that cannot mint locally, or for remote servers). -``airflow.api.client`` is removed — use ``airflow.cli.api_client.get_cli_api_client``. - -Each migrated command emits a ``RemovedInAirflow4Warning`` and will be removed in a future -Airflow release; use the equivalent ``airflowctl`` command instead. - -* Types of change - - * [ ] Dag changes - * [ ] Config changes - * [ ] API changes - * [ ] CLI changes - * [x] Behaviour changes - * [ ] Plugin changes - * [ ] Dependency changes - * [x] Code interface changes diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml index 31a88734c88fd..f126c18bd3a5d 100644 --- a/airflow-core/pyproject.toml +++ b/airflow-core/pyproject.toml @@ -154,7 +154,6 @@ dependencies = [ "universal-pathlib>=0.3.8", "uuid6>=2024.7.10", "apache-airflow-task-sdk<1.4.0,>=1.3.0", - "apache-airflow-ctl<0.1.6,>=0.1.5", # pre-installed providers "apache-airflow-providers-common-compat>=1.7.4", "apache-airflow-providers-common-io>=1.6.3", @@ -328,7 +327,6 @@ required-version = ">=0.11.8" [tool.uv.sources] apache-airflow-core = {workspace = true} -apache-airflow-ctl = {workspace = true} apache-airflow-devel-common = { workspace = true } [tool.airflow] diff --git a/airflow-e2e-tests/tests/airflow_e2e_tests/basic_tests/test_airflowctl_imports.py b/airflow-core/src/airflow/api/client/__init__.py similarity index 62% rename from airflow-e2e-tests/tests/airflow_e2e_tests/basic_tests/test_airflowctl_imports.py rename to airflow-core/src/airflow/api/client/__init__.py index f7957f1a576eb..f0d236b9019ab 100644 --- a/airflow-e2e-tests/tests/airflow_e2e_tests/basic_tests/test_airflowctl_imports.py +++ b/airflow-core/src/airflow/api/client/__init__.py @@ -1,3 +1,4 @@ +# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -14,25 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""API Client that allows interacting with Airflow API.""" from __future__ import annotations -import subprocess -import sys +from airflow.api.client.local_client import Client -def test_airflowctl_is_importable(): - # checks if airflowctl imports correctly - result = subprocess.run( - [ - sys.executable, - "-c", - "import airflowctl; print('airflowctl imported successfully')", - ], - capture_output=True, - text=True, - check=False, - ) - assert result.returncode == 0, ( - f"airflowctl import failed!\nstdout: {result.stdout}\nstderr: {result.stderr}" - ) +def get_current_api_client() -> Client: + return Client() diff --git a/airflow-core/src/airflow/api/client/local_client.py b/airflow-core/src/airflow/api/client/local_client.py new file mode 100644 index 0000000000000..057d6d99c7cfd --- /dev/null +++ b/airflow-core/src/airflow/api/client/local_client.py @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Local client API.""" + +from __future__ import annotations + +import httpx + +from airflow.api.common import delete_dag, trigger_dag +from airflow.exceptions import AirflowBadRequest, PoolNotFound +from airflow.models.pool import Pool +from airflow.utils.types import DagRunTriggeredByType + + +class Client: + """Local API client implementation.""" + + def __init__(self, auth=None, session: httpx.Client | None = None): + self._session: httpx.Client = session or httpx.Client() + if auth: + self._session.auth = auth + + def trigger_dag( + self, + dag_id, + run_id=None, + conf=None, + logical_date=None, + triggering_user_name=None, + replace_microseconds=True, + ) -> dict | None: + dag_run = trigger_dag.trigger_dag( + dag_id=dag_id, + triggered_by=DagRunTriggeredByType.CLI, + triggering_user_name=triggering_user_name, + run_id=run_id, + conf=conf, + logical_date=logical_date, + replace_microseconds=replace_microseconds, + ) + if dag_run: + return { + "conf": dag_run.conf, + "dag_id": dag_run.dag_id, + "dag_run_id": dag_run.run_id, + "data_interval_start": dag_run.data_interval_start, + "data_interval_end": dag_run.data_interval_end, + "end_date": dag_run.end_date, + "last_scheduling_decision": dag_run.last_scheduling_decision, + "logical_date": dag_run.logical_date, + "run_type": dag_run.run_type, + "start_date": dag_run.start_date, + "state": dag_run.state, + "triggering_user_name": dag_run.triggering_user_name, + } + return dag_run + + def delete_dag(self, dag_id): + count = delete_dag.delete_dag(dag_id) + return f"Removed {count} record(s)" + + def get_pool(self, name): + pool = Pool.get_pool(pool_name=name) + if not pool: + raise PoolNotFound(f"Pool {name} not found") + return pool.pool, pool.slots, pool.description, pool.include_deferred, pool.team_name + + def get_pools(self): + return [(p.pool, p.slots, p.description, p.include_deferred, p.team_name) for p in Pool.get_pools()] + + def create_pool(self, name, slots, description, include_deferred, team_name=None): + if not (name and name.strip()): + raise AirflowBadRequest("Pool name shouldn't be empty") + pool_name_length = Pool.pool.property.columns[0].type.length + if len(name) > pool_name_length: + raise AirflowBadRequest(f"Pool name cannot be more than {pool_name_length} characters") + try: + slots = int(slots) + except ValueError: + raise AirflowBadRequest(f"Invalid value for `slots`: {slots}") + pool = Pool.create_or_update_pool( + name=name, + slots=slots, + description=description, + include_deferred=include_deferred, + team_name=team_name, + ) + return pool.pool, pool.slots, pool.description, pool.team_name + + def delete_pool(self, name): + pool = Pool.delete_pool(name=name) + return pool.pool, pool.slots, pool.description diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py index 7f62d8af27c33..4ba08e5b447e5 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py +++ b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py @@ -180,22 +180,6 @@ def generate_jwt( self.serialize_user(user) ) - def get_cli_user(self) -> T: - """ - Return the user the local CLI acts as when calling the API server. - - The Airflow CLI mints a short-lived JWT for this user (via :meth:`generate_jwt`) - so it can talk to the API server without persisting any credentials. A generic - auth manager cannot know which user is authorized for local CLI access, so the - default raises. Auth managers that support local CLI usage should override this - to return an administrative user. Otherwise, operators must provide a token via - the ``AIRFLOW_CLI_TOKEN`` environment variable. - """ - raise NotImplementedError( - f"{type(self).__name__} does not support minting a local CLI token. " - "Set the AIRFLOW_CLI_TOKEN environment variable with a valid API token instead." - ) - @abstractmethod def get_url_login(self, **kwargs) -> str: """Return the login page url.""" diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py b/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py index 0deaadf40346b..0559a388156d5 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py +++ b/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py @@ -238,9 +238,6 @@ def deserialize_user(self, token: dict[str, Any]) -> SimpleAuthManagerUser: def serialize_user(self, user: SimpleAuthManagerUser) -> dict[str, Any]: return {"sub": user.username, "role": user.role, "teams": user.teams} - def get_cli_user(self) -> SimpleAuthManagerUser: - return SimpleAuthManagerUser(username="cli", role=SimpleAuthManagerRole.ADMIN.name) - def is_authorized_configuration( self, *, diff --git a/airflow-core/src/airflow/cli/api_client.py b/airflow-core/src/airflow/cli/api_client.py deleted file mode 100644 index d1ff5e2ddd93c..0000000000000 --- a/airflow-core/src/airflow/cli/api_client.py +++ /dev/null @@ -1,129 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Provide the :mod:`airflowctl` HTTP API client to the local Airflow CLI. - -The local CLI talks to the API server through the same typed client that ``airflowctl`` -uses, but without the keyring-backed credential store. For each invocation it mints a -short-lived JWT **in memory** (via the active auth manager) and builds a client with it; -nothing is persisted. Set the ``AIRFLOW_CLI_TOKEN`` environment variable to supply a token -explicitly (required for auth managers whose tokens cannot be minted locally, such as -Keycloak, or when targeting a remote API server). -""" - -from __future__ import annotations - -import atexit -import os -from collections.abc import Callable -from functools import wraps -from typing import TYPE_CHECKING, TypeVar - -import httpx - -# Re-exported so command modules import the client surface from a single place. -from airflowctl.api.client import NEW_API_CLIENT, Client, ClientKind - -from airflow.configuration import conf -from airflow.typing_compat import ParamSpec - -if TYPE_CHECKING: - from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager - -__all__ = [ - "NEW_API_CLIENT", - "Client", - "ClientKind", - "get_cli_api_client", - "provide_api_client", -] - -PS = ParamSpec("PS") -RT = TypeVar("RT") - -# Validity of the in-memory CLI token. It only needs to outlive a single CLI command -# (including the client's request retries) and is never persisted or logged. -_CLI_TOKEN_VALID_FOR_SECONDS = 300 - -_api_client: Client | None = None - - -def _resolve_base_url() -> str: - """Resolve the API server base URL from configuration.""" - base_url = conf.get("api", "base_url", fallback=None) - if base_url: - return base_url - host = conf.get("api", "host", fallback="localhost") or "localhost" - port = conf.get("api", "port", fallback="8080") or "8080" - return f"http://{host}:{port}" - - -def _mint_cli_token() -> str: - """ - Return a token for the CLI to authenticate against the API server. - - Prefers an explicit ``AIRFLOW_CLI_TOKEN`` (the universal override), otherwise mints a - short-lived JWT through the active auth manager. The token lives only in this process. - """ - if token := os.environ.get("AIRFLOW_CLI_TOKEN"): - return token - - from airflow.api_fastapi.app import get_auth_manager, init_auth_manager - - # The CLI runs outside the API server, so the auth manager singleton is usually not - # initialized yet; initialize it on demand. ``init_auth_manager`` reuses the cached - # instance when one already exists, so this is safe to call here. - try: - auth_manager: BaseAuthManager = get_auth_manager() - except RuntimeError: - auth_manager = init_auth_manager() - return auth_manager.generate_jwt( - auth_manager.get_cli_user(), - expiration_time_in_seconds=_CLI_TOKEN_VALID_FOR_SECONDS, - ) - - -def get_cli_api_client() -> Client: - """Return the process-wide singleton airflowctl client for the local CLI.""" - global _api_client - if _api_client is None: - _api_client = Client( - base_url=_resolve_base_url(), - token=_mint_cli_token(), - kind=ClientKind.CLI, - limits=httpx.Limits(max_keepalive_connections=1, max_connections=1), - ) - atexit.register(_api_client.close) - return _api_client - - -def provide_api_client(func: Callable[PS, RT]) -> Callable[PS, RT]: - """ - Provide the CLI API client to the decorated command function. - - Injects ``api_client=get_cli_api_client()`` when the caller does not pass one. Tests - (or callers that already hold a client) pass ``api_client=`` explicitly to bypass it. - """ - - @wraps(func) - def wrapper(*args, **kwargs) -> RT: - if "api_client" not in kwargs: - kwargs["api_client"] = get_cli_api_client() - return func(*args, **kwargs) - - return wrapper diff --git a/airflow-core/src/airflow/cli/commands/asset_command.py b/airflow-core/src/airflow/cli/commands/asset_command.py index 29c1025958ab6..3897e42b4fd85 100644 --- a/airflow-core/src/airflow/cli/commands/asset_command.py +++ b/airflow-core/src/airflow/cli/commands/asset_command.py @@ -17,17 +17,22 @@ from __future__ import annotations +import logging import typing from sqlalchemy import select +from airflow.api.common.trigger_dag import trigger_dag from airflow.api_fastapi.core_api.datamodels.assets import AssetAliasResponse, AssetResponse -from airflow.cli.api_client import NEW_API_CLIENT, Client, provide_api_client +from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse from airflow.cli.simple_table import AirflowConsole from airflow.cli.utils import deprecated_for_airflowctl -from airflow.models.asset import AssetAliasModel, AssetModel +from airflow.exceptions import AirflowConfigException +from airflow.models.asset import AssetAliasModel, AssetModel, TaskOutletAssetReference from airflow.utils import cli as cli_utils +from airflow.utils.platform import getuser from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.types import DagRunTriggeredByType, DagRunType if typing.TYPE_CHECKING: from typing import Any @@ -36,6 +41,8 @@ from airflow.api_fastapi.core_api.base import BaseModel +log = logging.getLogger(__name__) + def _list_asset_aliases(args, *, session: Session) -> tuple[Any, type[BaseModel]]: aliases = session.scalars(select(AssetAliasModel).order_by(AssetAliasModel.name)) @@ -43,13 +50,7 @@ def _list_asset_aliases(args, *, session: Session) -> tuple[Any, type[BaseModel] def _list_assets(args, *, session: Session) -> tuple[Any, type[BaseModel]]: - assets = session.scalars(select(AssetModel).order_by(AssetModel.name)).all() - for asset in assets: - for watcher in asset.watchers: - # ``AssetWatcherModel`` has no ``created_date`` column; like the public API - # serializer, derive it from the watcher's trigger so ``AssetResponse`` validation - # succeeds. Set on the instance so ``model_validate`` reads it via ``from_attributes``. - watcher.created_date = watcher.trigger.created_date + assets = session.scalars(select(AssetModel).order_by(AssetModel.name)) return assets, AssetResponse @@ -123,39 +124,50 @@ def asset_details(args, *, session: Session = NEW_SESSION) -> None: AirflowConsole().print_as(data=data, output=args.output) -@cli_utils.action_cli @deprecated_for_airflowctl("airflowctl assets materialize") -@provide_api_client -def asset_materialize(args, api_client: Client = NEW_API_CLIENT) -> None: +@cli_utils.action_cli +@provide_session +def asset_materialize(args, *, session: Session = NEW_SESSION) -> None: """ Materialize the specified asset. This is done by finding the DAG with the asset defined as outlet, and create - a run for that DAG. Resolving the DAG and creating the run is handled by the API - server; the asset is identified here by its name and/or URI. + a run for that DAG. """ if not args.name and not args.uri: raise SystemExit("Either --name or --uri is required") + stmt = select(TaskOutletAssetReference.dag_id).join(TaskOutletAssetReference.asset) select_message_parts = [] if args.name: + stmt = stmt.where(AssetModel.name == args.name) select_message_parts.append(f"name {args.name}") if args.uri: + stmt = stmt.where(AssetModel.uri == args.uri) select_message_parts.append(f"URI {args.uri}") + dag_id_it = iter(session.scalars(stmt.group_by(TaskOutletAssetReference.dag_id).limit(2))) select_message = " and ".join(select_message_parts) - matches = [ - asset - for asset in api_client.assets.list().assets - if (not args.name or asset.name == args.name) and (not args.uri or asset.uri == args.uri) - ] - if not matches: + if (dag_id := next(dag_id_it, None)) is None: raise SystemExit(f"Asset with {select_message} does not exist.") - if len(matches) > 1: - raise SystemExit(f"More than one asset exists with {select_message}.") - - dag_run = api_client.assets.materialize(asset_id=str(matches[0].id)) - AirflowConsole().print_as( - data=[dag_run.model_dump(mode="json")], - output=args.output, + if next(dag_id_it, None) is not None: + raise SystemExit(f"More than one DAG materializes asset with {select_message}.") + + try: + user = getuser() + except AirflowConfigException as e: + log.warning("Failed to get user name from os: %s, not setting the triggering user", e) + user = None + dagrun = trigger_dag( + dag_id=dag_id, + triggered_by=DagRunTriggeredByType.CLI, + run_type=DagRunType.ASSET_MATERIALIZATION, + triggering_user_name=user, + session=session, ) + if dagrun is not None: + data = [DAGRunResponse.model_validate(dagrun).model_dump(mode="json")] + else: + data = [] + + AirflowConsole().print_as(data=data, output=args.output) diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py b/airflow-core/src/airflow/cli/commands/dag_command.py index 39bdc7a04708b..756048af14d5e 100644 --- a/airflow-core/src/airflow/cli/commands/dag_command.py +++ b/airflow-core/src/airflow/cli/commands/dag_command.py @@ -33,15 +33,14 @@ from sqlalchemy import func, select from airflow._shared.timezones import timezone -from airflow.api_fastapi.core_api.datamodels.dag_run import TriggerDAGRunPostBody +from airflow.api.client import get_current_api_client from airflow.api_fastapi.core_api.datamodels.dags import DAGResponse -from airflow.cli.api_client import NEW_API_CLIENT, Client, provide_api_client from airflow.cli.simple_table import AirflowConsole from airflow.cli.utils import deprecated_for_airflowctl, fetch_dag_run_from_run_id_or_logical_date_string from airflow.dag_processing.bundles.base import unpack_bundle_version from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.dag_processing.dagbag import BundleDagBag, DagBag, sync_bag_to_db -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowConfigException, AirflowException from airflow.jobs.job import Job from airflow.models import DagModel, DagRun, TaskInstance from airflow.models.errors import ParseImportError @@ -57,6 +56,7 @@ ) from airflow.utils.dot_renderer import render_dag, render_dag_dependencies from airflow.utils.helpers import ask_yesno, chunks +from airflow.utils.platform import getuser from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.state import DagRunState, TaskInstanceState @@ -80,42 +80,50 @@ _RUN_CHUNK_SIZE = 500 -@cli_utils.action_cli @deprecated_for_airflowctl("airflowctl dags trigger") +@cli_utils.action_cli @providers_configuration_loaded -@provide_api_client -def dag_trigger(args, api_client: Client = NEW_API_CLIENT) -> None: +def dag_trigger(args) -> None: """Create a dag run for the specified dag.""" - run_conf = json.loads(args.conf) if args.conf is not None else None - if run_conf is not None and not isinstance(run_conf, dict): - raise ValueError("DagRun conf must be a JSON object or null") - # The core_api request models are the source of truth; they are wire-compatible with - # the airflowctl client's generated models (the API server uses populate_by_name). - trigger_body = TriggerDAGRunPostBody( - dag_run_id=args.run_id, - conf=run_conf, - logical_date=args.logical_date, - ) - dag_run = api_client.dags.trigger(dag_id=args.dag_id, trigger_dag_run=trigger_body) # type: ignore[arg-type] - AirflowConsole().print_as( - data=[dag_run.model_dump(mode="json")], - output=args.output, - ) + api_client = get_current_api_client() + try: + user = getuser() + except AirflowConfigException as e: + log.warning("Failed to get user name from os: %s, not setting the triggering user", e) + user = None + try: + message = api_client.trigger_dag( + dag_id=args.dag_id, + run_id=args.run_id, + conf=args.conf, + logical_date=args.logical_date, + triggering_user_name=user, + replace_microseconds=args.replace_microseconds, + ) + AirflowConsole().print_as( + data=[message] if message is not None else [], + output=args.output, + ) + except OSError as err: + raise AirflowException(err) -@cli_utils.action_cli @deprecated_for_airflowctl("airflowctl dags delete") +@cli_utils.action_cli @providers_configuration_loaded -@provide_api_client -def dag_delete(args, api_client: Client = NEW_API_CLIENT) -> None: +def dag_delete(args) -> None: """Delete all DB records related to the specified dag.""" + api_client = get_current_api_client() if ( args.yes or input("This will drop all existing records related to the specified DAG. Proceed? (y/n)").upper() == "Y" ): - api_client.dags.delete(dag_id=args.dag_id) - print(f"Removed DAG {args.dag_id}") + try: + message = api_client.delete_dag(dag_id=args.dag_id) + print(message) + except OSError as err: + raise AirflowException(err) else: print("Cancelled") diff --git a/airflow-core/src/airflow/cli/commands/pool_command.py b/airflow-core/src/airflow/cli/commands/pool_command.py index a51351f73bdc8..0d1f087e37702 100644 --- a/airflow-core/src/airflow/cli/commands/pool_command.py +++ b/airflow-core/src/airflow/cli/commands/pool_command.py @@ -23,12 +23,10 @@ import os from json import JSONDecodeError -from airflowctl.api.operations import ServerResponseError - -from airflow.api_fastapi.core_api.datamodels.pools import PoolBody -from airflow.cli.api_client import NEW_API_CLIENT, Client, provide_api_client +from airflow.api.client import get_current_api_client from airflow.cli.simple_table import AirflowConsole from airflow.cli.utils import deprecated_for_airflowctl +from airflow.exceptions import PoolNotFound from airflow.utils import cli as cli_utils from airflow.utils.cli import suppress_logs_and_warning from airflow.utils.providers_configuration_loader import providers_configuration_loaded @@ -39,11 +37,11 @@ def _show_pools(pools, output): data=pools, output=output, mapper=lambda x: { - "pool": x.name, - "slots": x.slots, - "description": x.description, - "include_deferred": x.include_deferred, - "team_name": x.team_name, + "pool": x[0], + "slots": x[1], + "description": x[2], + "include_deferred": x[3], + "team_name": x[4], }, ) @@ -51,66 +49,59 @@ def _show_pools(pools, output): @deprecated_for_airflowctl("airflowctl pools list") @suppress_logs_and_warning @providers_configuration_loaded -@provide_api_client -def pool_list(args, api_client: Client = NEW_API_CLIENT): +def pool_list(args): """Display info of all the pools.""" - pools = api_client.pools.list().pools + api_client = get_current_api_client() + pools = api_client.get_pools() _show_pools(pools=pools, output=args.output) @deprecated_for_airflowctl("airflowctl pools get") @suppress_logs_and_warning @providers_configuration_loaded -@provide_api_client -def pool_get(args, api_client: Client = NEW_API_CLIENT): +def pool_get(args): """Display pool info by a given name.""" + api_client = get_current_api_client() try: - pools = [api_client.pools.get(pool_name=args.pool)] + pools = [api_client.get_pool(name=args.pool)] _show_pools(pools=pools, output=args.output) - except ServerResponseError as e: - if e.response.status_code == 404: - raise SystemExit(f"Pool {args.pool} does not exist") - raise + except PoolNotFound: + raise SystemExit(f"Pool {args.pool} does not exist") -@cli_utils.action_cli @deprecated_for_airflowctl("airflowctl pools create") +@cli_utils.action_cli @suppress_logs_and_warning @providers_configuration_loaded -@provide_api_client -def pool_set(args, api_client: Client = NEW_API_CLIENT): +def pool_set(args): """Create new pool with a given name and slots.""" - # core_api PoolBody is the source of truth and is wire-compatible with the airflowctl - # client's generated model (the API server uses populate_by_name). - pool_body = PoolBody( + api_client = get_current_api_client() + api_client.create_pool( name=args.pool, slots=args.slots, description=args.description, include_deferred=args.include_deferred, team_name=args.team_name, ) - api_client.pools.create(pool=pool_body) # type: ignore[arg-type] print(f"Pool {args.pool} created") -@cli_utils.action_cli @deprecated_for_airflowctl("airflowctl pools delete") +@cli_utils.action_cli @suppress_logs_and_warning @providers_configuration_loaded -@provide_api_client -def pool_delete(args, api_client: Client = NEW_API_CLIENT): +def pool_delete(args): """Delete pool by a given name.""" + api_client = get_current_api_client() try: - api_client.pools.delete(pool=args.pool) + api_client.delete_pool(name=args.pool) print(f"Pool {args.pool} deleted") - except ServerResponseError as e: - if e.response.status_code == 404: - raise SystemExit(f"Pool {args.pool} does not exist") - raise + except PoolNotFound: + raise SystemExit(f"Pool {args.pool} does not exist") -@cli_utils.action_cli @deprecated_for_airflowctl("airflowctl pools import") +@cli_utils.action_cli @suppress_logs_and_warning @providers_configuration_loaded def pool_import(args): @@ -131,9 +122,10 @@ def pool_export(args): print(f"Exported {len(pools)} pools to {args.file}") -@provide_api_client -def pool_import_helper(filepath, api_client: Client = NEW_API_CLIENT): +def pool_import_helper(filepath): """Help import pools from the json file.""" + api_client = get_current_api_client() + with open(filepath) as poolfile: data = poolfile.read() try: @@ -144,33 +136,34 @@ def pool_import_helper(filepath, api_client: Client = NEW_API_CLIENT): failed = [] for k, v in pools_json.items(): if isinstance(v, dict) and "slots" in v and "description" in v: - pool_body = PoolBody( - name=k, - slots=v["slots"], - description=v["description"], - include_deferred=v.get("include_deferred", False), - team_name=v.get("team_name"), + pools.append( + api_client.create_pool( + name=k, + slots=v["slots"], + description=v["description"], + include_deferred=v.get("include_deferred", False), + team_name=v.get("team_name"), + ) ) - pools.append(api_client.pools.create(pool=pool_body)) # type: ignore[arg-type] else: failed.append(k) return pools, failed -@provide_api_client -def pool_export_helper(filepath, api_client: Client = NEW_API_CLIENT): +def pool_export_helper(filepath): """Help export all the pools to the json file.""" + api_client = get_current_api_client() pool_dict = {} - pools = api_client.pools.list().pools + pools = api_client.get_pools() for pool in pools: entry = { - "slots": pool.slots, - "description": pool.description, - "include_deferred": pool.include_deferred, + "slots": pool[1], + "description": pool[2], + "include_deferred": pool[3], } - if pool.team_name is not None: - entry["team_name"] = pool.team_name - pool_dict[pool.name] = entry + if pool[4] is not None: + entry["team_name"] = pool[4] + pool_dict[pool[0]] = entry with open(filepath, "w") as poolfile: poolfile.write(json.dumps(pool_dict, sort_keys=True, indent=4)) return pools diff --git a/airflow-core/src/airflow/cli/utils.py b/airflow-core/src/airflow/cli/utils.py index a19d761c30485..35a62c9df9135 100644 --- a/airflow-core/src/airflow/cli/utils.py +++ b/airflow-core/src/airflow/cli/utils.py @@ -17,14 +17,10 @@ from __future__ import annotations -import functools import sys -import warnings from collections.abc import Callable from typing import TYPE_CHECKING, TypeVar -from airflow.exceptions import RemovedInAirflow4Warning - # Placeholder for masking sensitive values in CLI output SENSITIVE_PLACEHOLDER = "***" @@ -44,25 +40,22 @@ def deprecated_for_airflowctl(replacement: str) -> Callable[[F], F]: """ Mark an ``airflow`` CLI command as deprecated in favour of an ``airflowctl`` equivalent. - These commands now reach Airflow through the API server via the ``airflowctl`` client. They - are kept for backwards compatibility but will be removed in a future Airflow release; users - should switch to ``airflowctl`` directly. + The command keeps its existing implementation and stays in the ``airflow`` CLI as a supported + entry point, so it emits **no user-facing deprecation warning** at runtime. The intent is to + point future development at ``airflowctl``: the equivalent ``airflowctl`` command is recorded + for maintainers only, on the ``_migrated_to_airflowctl`` attribute (the migration registry test + in ``test_command_deprecations.py`` reads it). The decorator at the command's definition site is + the developer-facing trace -- it is source-only and never rendered to users. + + See ``contributing-docs/27_cli_implementation_guide.rst`` for the CLI / ``airflowctl`` + development guidance. :param replacement: The equivalent ``airflowctl`` command, e.g. ``airflowctl dags trigger``. """ def decorator(func: F) -> F: - @functools.wraps(func) - def wrapper(*args, **kwargs): - warnings.warn( - f"This `airflow` CLI command is deprecated and will be removed in a future " - f"Airflow release. Use `{replacement}` instead.", - RemovedInAirflow4Warning, - stacklevel=2, - ) - return func(*args, **kwargs) - - return wrapper # type: ignore[return-value] + func._migrated_to_airflowctl = replacement # type: ignore[attr-defined] + return func return decorator diff --git a/airflow-core/tests/unit/cli/commands/test_asset_command.py b/airflow-core/tests/unit/cli/commands/test_asset_command.py index d30e36eb04d97..a7329a96251d2 100644 --- a/airflow-core/tests/unit/cli/commands/test_asset_command.py +++ b/airflow-core/tests/unit/cli/commands/test_asset_command.py @@ -21,7 +21,7 @@ import json import os import typing -from types import SimpleNamespace +from unittest import mock import pytest @@ -37,9 +37,7 @@ pytestmark = [pytest.mark.db_test] -# Not autouse: only the DB-backed tests below request it, so the mocked (non-DB) -# ``assets materialize`` tests stay free of any database access. -@pytest.fixture(scope="module") +@pytest.fixture(scope="module", autouse=True) def prepare_examples(): with conf_vars({("core", "load_examples"): "True"}): parse_and_sync_to_db(os.devnull) @@ -48,12 +46,17 @@ def prepare_examples(): clear_db_dags() +@pytest.fixture(autouse=True) +def clear_runs(): + clear_db_runs() + + @pytest.fixture(scope="module") def parser() -> ArgumentParser: return cli_parser.get_parser() -def test_cli_assets_list(prepare_examples, parser: ArgumentParser, stdout_capture) -> None: +def test_cli_assets_list(parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "list", "--output=json"]) with stdout_capture as capture: asset_command.asset_list(args) @@ -64,7 +67,7 @@ def test_cli_assets_list(prepare_examples, parser: ArgumentParser, stdout_captur assert any(asset["uri"] == "s3://dag1/output_1.txt" for asset in asset_list), asset_list -def test_cli_assets_alias_list(prepare_examples, parser: ArgumentParser, stdout_capture) -> None: +def test_cli_assets_alias_list(parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "list", "--alias", "--output=json"]) with stdout_capture as capture: asset_command.asset_list(args) @@ -75,7 +78,7 @@ def test_cli_assets_alias_list(prepare_examples, parser: ArgumentParser, stdout_ assert any(alias["name"] == "example-alias" for alias in alias_list), alias_list -def test_cli_assets_details(prepare_examples, parser: ArgumentParser, stdout_capture) -> None: +def test_cli_assets_details(parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "details", "--name=asset1_producer", "--output=json"]) with stdout_capture as capture: asset_command.asset_details(args) @@ -104,7 +107,7 @@ def test_cli_assets_details(prepare_examples, parser: ArgumentParser, stdout_cap } -def test_cli_assets_alias_details(prepare_examples, parser: ArgumentParser, stdout_capture) -> None: +def test_cli_assets_alias_details(parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "details", "--alias", "--name=example-alias", "--output=json"]) with stdout_capture as capture: asset_command.asset_details(args) @@ -121,46 +124,87 @@ def test_cli_assets_alias_details(prepare_examples, parser: ArgumentParser, stdo } -@pytest.mark.non_db_test_override -class TestCliAssetsMaterialize: - """`assets materialize` goes through the airflowctl client; mocked here (no DB/server).""" - - def test_materialize(self, parser: ArgumentParser, mock_cli_api_client, stdout_capture) -> None: - mock_cli_api_client.assets.list.return_value.assets = [ - SimpleNamespace(id=7, name="asset1_producer", uri="s3://bucket/asset1_producer"), - SimpleNamespace(id=8, name="other", uri="s3://bucket/other"), - ] - mock_cli_api_client.assets.materialize.return_value.model_dump.return_value = { - "dag_id": "asset1_producer", - "run_type": "asset_materialization", - "state": "queued", - } - args = parser.parse_args(["assets", "materialize", "--name=asset1_producer", "--output=json"]) - with stdout_capture as capture: - asset_command.asset_materialize(args) - - run_list = json.loads(capture.getvalue()) - assert len(run_list) == 1 - assert run_list[0]["dag_id"] == "asset1_producer" - # The asset is resolved to its id and materialization is delegated to the API server. - mock_cli_api_client.assets.materialize.assert_called_once_with(asset_id="7") - - def test_materialize_requires_name_or_uri(self, parser: ArgumentParser, mock_cli_api_client) -> None: - with pytest.raises(SystemExit, match="Either --name or --uri is required"): - asset_command.asset_materialize(parser.parse_args(["assets", "materialize"])) - mock_cli_api_client.assets.materialize.assert_not_called() - - def test_materialize_missing(self, parser: ArgumentParser, mock_cli_api_client) -> None: - mock_cli_api_client.assets.list.return_value.assets = [] - with pytest.raises(SystemExit, match="Asset with name nope does not exist"): - asset_command.asset_materialize(parser.parse_args(["assets", "materialize", "--name=nope"])) - mock_cli_api_client.assets.materialize.assert_not_called() - - def test_materialize_ambiguous(self, parser: ArgumentParser, mock_cli_api_client) -> None: - mock_cli_api_client.assets.list.return_value.assets = [ - SimpleNamespace(id=1, name="dup", uri="s3://a"), - SimpleNamespace(id=2, name="dup", uri="s3://b"), - ] - with pytest.raises(SystemExit, match="More than one asset exists with name dup"): - asset_command.asset_materialize(parser.parse_args(["assets", "materialize", "--name=dup"])) - mock_cli_api_client.assets.materialize.assert_not_called() +@mock.patch("airflow.api_fastapi.core_api.datamodels.dag_versions.hasattr") +def test_cli_assets_materialize(mock_hasattr, parser: ArgumentParser, stdout_capture) -> None: + mock_hasattr.return_value = False + args = parser.parse_args(["assets", "materialize", "--name=asset1_producer", "--output=json"]) + with stdout_capture as capture: + asset_command.asset_materialize(args) + + output = capture.getvalue() + + # Check if output is empty first + assert output, "No output captured from asset_materialize command" + + run_list = json.loads(output) + assert len(run_list) == 1 + + # No good way to statically compare these. + undeterministic: dict = { + "dag_run_id": None, + "dag_versions": [], + "data_interval_end": None, + "data_interval_start": None, + "logical_date": None, + "queued_at": None, + "run_after": "2025-02-12T19:27:59.066046Z", + } + + assert run_list[0] | undeterministic == undeterministic | { + "conf": {}, + "bundle_version": None, + "dag_display_name": "asset1_producer", + "dag_id": "asset1_producer", + "end_date": None, + "duration": None, + "last_scheduling_decision": None, + "note": None, + "partition_date": None, + "partition_key": None, + "run_type": "asset_materialization", + "start_date": None, + "state": "queued", + "triggered_by": "cli", + "triggering_user_name": "root", + "run_after": "2025-02-12T19:27:59.066046Z", + } + + +def test_cli_assets_materialize_with_view_url_template(parser: ArgumentParser, stdout_capture) -> None: + args = parser.parse_args(["assets", "materialize", "--name=asset1_producer", "--output=json"]) + with stdout_capture as capture: + asset_command.asset_materialize(args) + + output = capture.getvalue() + run_list = json.loads(output) + assert len(run_list) == 1 + + # No good way to statically compare these. + undeterministic: dict = { + "dag_run_id": None, + "dag_versions": [], + "data_interval_end": None, + "data_interval_start": None, + "logical_date": None, + "queued_at": None, + "run_after": "2025-02-12T19:27:59.066046Z", + } + + assert run_list[0] | undeterministic == undeterministic | { + "conf": {}, + "bundle_version": None, + "dag_display_name": "asset1_producer", + "dag_id": "asset1_producer", + "end_date": None, + "duration": None, + "last_scheduling_decision": None, + "note": None, + "partition_date": None, + "partition_key": None, + "run_type": "asset_materialization", + "start_date": None, + "state": "queued", + "triggered_by": "cli", + "triggering_user_name": "root", + "run_after": "2025-02-12T19:27:59.066046Z", + } diff --git a/airflow-core/tests/unit/cli/commands/test_command_deprecations.py b/airflow-core/tests/unit/cli/commands/test_command_deprecations.py index b4eb6840c9069..9300219fe5a1e 100644 --- a/airflow-core/tests/unit/cli/commands/test_command_deprecations.py +++ b/airflow-core/tests/unit/cli/commands/test_command_deprecations.py @@ -18,55 +18,45 @@ """ Single source of truth for the ``airflow`` CLI commands deprecated in favour of ``airflowctl``. -Every command decorated with ``deprecated_for_airflowctl`` must have one entry below. When a new -command is migrated and deprecated, add a row to ``DEPRECATED_CLI_COMMANDS`` -- the test then -verifies it emits ``RemovedInAirflow4Warning`` pointing at the right ``airflowctl`` command. +Every command decorated with ``deprecated_for_airflowctl`` must have one entry below. When a +command is deprecated, add a row to ``MIGRATED_CLI_COMMANDS`` -- the test then verifies the decorator +recorded the right ``airflowctl`` replacement for maintainers. The commands stay in the ``airflow`` +CLI as supported entry points, so they emit no user-facing deprecation warning; they are simply no +longer developed here -- new work belongs in ``airflowctl``. See +``contributing-docs/27_cli_implementation_guide.rst`` for the CLI / ``airflowctl`` direction. """ from __future__ import annotations -import contextlib -import re - import pytest from airflow.cli.commands import asset_command, dag_command, pool_command -from airflow.exceptions import RemovedInAirflow4Warning -# (command callable, argv to parse, expected airflowctl replacement named in the warning) -DEPRECATED_CLI_COMMANDS = [ - (dag_command.dag_trigger, ["dags", "trigger", "example_dag", "--run-id=x"], "airflowctl dags trigger"), - (dag_command.dag_delete, ["dags", "delete", "example_dag", "--yes"], "airflowctl dags delete"), - (pool_command.pool_list, ["pools", "list"], "airflowctl pools list"), - (pool_command.pool_get, ["pools", "get", "foo"], "airflowctl pools get"), - (pool_command.pool_set, ["pools", "set", "foo", "1", "desc"], "airflowctl pools create"), - (pool_command.pool_delete, ["pools", "delete", "foo"], "airflowctl pools delete"), - (pool_command.pool_import, ["pools", "import", "/nonexistent.json"], "airflowctl pools import"), - ( - pool_command.pool_export, - ["pools", "export", "/tmp/airflow_pools_export.json"], - "airflowctl pools export", - ), - ( - asset_command.asset_materialize, - ["assets", "materialize", "--name=foo"], - "airflowctl assets materialize", - ), +# (command callable, expected airflowctl replacement recorded by the decorator) +MIGRATED_CLI_COMMANDS = [ + (dag_command.dag_trigger, "airflowctl dags trigger"), + (dag_command.dag_delete, "airflowctl dags delete"), + (pool_command.pool_list, "airflowctl pools list"), + (pool_command.pool_get, "airflowctl pools get"), + (pool_command.pool_set, "airflowctl pools create"), + (pool_command.pool_delete, "airflowctl pools delete"), + (pool_command.pool_import, "airflowctl pools import"), + (pool_command.pool_export, "airflowctl pools export"), + (asset_command.asset_materialize, "airflowctl assets materialize"), ] @pytest.mark.parametrize( - ("command", "argv", "replacement"), - DEPRECATED_CLI_COMMANDS, - ids=[argv[0] + "-" + argv[1] for _, argv, _ in DEPRECATED_CLI_COMMANDS], + ("command", "replacement"), + MIGRATED_CLI_COMMANDS, + ids=[replacement for _, replacement in MIGRATED_CLI_COMMANDS], ) -def test_deprecated_cli_command_points_to_airflowctl(command, argv, replacement, parser, mock_cli_api_client): - """Each migrated command warns it will become an alias for its ``airflowctl`` counterpart. +def test_migrated_cli_command_records_airflowctl_replacement(command, replacement): + """Each migrated command records its ``airflowctl`` counterpart for maintainers. - We only assert the deprecation warning fires (and names the right replacement); the command - body itself is exercised by the per-command test modules, so any error it raises against the - bare mocked client is irrelevant here and suppressed. + The marker is the maintainer-facing trace of the migration; users see no runtime deprecation + warning. The command body itself is exercised by the per-command test modules. + ``functools.wraps`` on the outer ``action_cli`` decorator propagates the attribute up to the + command object imported here. """ - with pytest.warns(RemovedInAirflow4Warning, match=re.escape(replacement)): - with contextlib.suppress(Exception, SystemExit): - command(parser.parse_args(argv)) + assert getattr(command, "_migrated_to_airflowctl", None) == replacement diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py b/airflow-core/tests/unit/cli/commands/test_dag_command.py index 8ba0a64156a22..059ce1d2f2608 100644 --- a/airflow-core/tests/unit/cli/commands/test_dag_command.py +++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py @@ -25,14 +25,13 @@ from unittest import mock from unittest.mock import MagicMock -import httpx import msgspec import pendulum import pytest import time_machine -from airflowctl.api.operations import ServerResponseError -from sqlalchemy import select +from sqlalchemy import func, select +from airflow import settings from airflow._shared.timezones import timezone from airflow.cli import cli_parser from airflow.cli.commands import dag_command @@ -486,19 +485,21 @@ def test_cli_list_import_errors(self, get_test_dag, configure_testing_dag_bundle assert str(path_to_parse) in log_output assert "[0 100 * * *] is not acceptable, out of range" in log_output - def test_cli_list_dag_runs(self, dag_maker): - # Seed a run directly in the DB; ``dags trigger`` now goes through the API server - # (airflowctl client) and cannot be used as an in-process fixture here. - with dag_maker("test_list_dag_runs", start_date=DEFAULT_DATE, serialized=True): - EmptyOperator(task_id="t1") - dag_maker.create_dagrun(state=DagRunState.SUCCESS, logical_date=DEFAULT_DATE) - dag_maker.sync_dagbag_to_db() - + def test_cli_list_dag_runs(self): + dag_command.dag_trigger( + self.parser.parse_args( + [ + "dags", + "trigger", + "example_bash_operator", + ] + ) + ) args = self.parser.parse_args( [ "dags", "list-runs", - "test_list_dag_runs", + "example_bash_operator", "--no-backfill", "--start-date", DEFAULT_DATE.isoformat(), @@ -591,6 +592,206 @@ def test_pausing_already_paused_dag_do_not_error(self, stdout_capture): out = temp_stdout.splitlines()[-1] assert out == "No unpaused DAGs were found" + def test_trigger_dag(self): + dag_command.dag_trigger( + self.parser.parse_args( + [ + "dags", + "trigger", + "example_bash_operator", + "--run-id=test_trigger_dag", + '--conf={"foo": "bar"}', + ], + ), + ) + with create_session() as session: + dagrun = session.scalars(select(DagRun).where(DagRun.run_id == "test_trigger_dag")).one() + + assert dagrun, "DagRun not created" + assert dagrun.run_type == DagRunType.MANUAL + assert dagrun.conf == {"foo": "bar"} + + # logical_date is None as it's not provided + assert dagrun.logical_date is None + + # data_interval is None as logical_date is None + assert dagrun.data_interval_start is None + assert dagrun.data_interval_end is None + + def test_trigger_dag_empty_object_conf(self): + dag_command.dag_trigger( + self.parser.parse_args( + [ + "dags", + "trigger", + "example_bash_operator", + "--run-id=test_trigger_dag_empty_object_conf", + "--conf={}", + ], + ), + ) + with create_session() as session: + dagrun = session.scalars( + select(DagRun).where(DagRun.run_id == "test_trigger_dag_empty_object_conf") + ).one() + + assert dagrun.conf == {} + + def test_trigger_dag_json_null_conf(self): + dag_command.dag_trigger( + self.parser.parse_args( + [ + "dags", + "trigger", + "example_bash_operator", + "--run-id=test_trigger_dag_json_null_conf", + "--conf=null", + ], + ), + ) + with create_session() as session: + dagrun = session.scalars( + select(DagRun).where(DagRun.run_id == "test_trigger_dag_json_null_conf") + ).one() + + assert dagrun.conf == {} + + def test_trigger_dag_with_microseconds(self): + dag_command.dag_trigger( + self.parser.parse_args( + [ + "dags", + "trigger", + "example_bash_operator", + "--run-id=test_trigger_dag_with_micro", + "--logical-date=2021-06-04T09:00:00.000001+08:00", + "--no-replace-microseconds", + ], + ) + ) + + with create_session() as session: + dagrun = session.scalars( + select(DagRun).where(DagRun.run_id == "test_trigger_dag_with_micro") + ).one() + + assert dagrun, "DagRun not created" + assert dagrun.run_type == DagRunType.MANUAL + assert dagrun.logical_date.isoformat(timespec="microseconds") == "2021-06-04T01:00:00.000001+00:00" + + @pytest.mark.parametrize("conf", ["NOT JSON", ""]) + def test_trigger_dag_invalid_conf(self, conf): + with pytest.raises(ValueError, match=r"Expecting value: line \d+ column \d+ \(char \d+\)"): + dag_command.dag_trigger( + self.parser.parse_args( + [ + "dags", + "trigger", + "example_bash_operator", + "--run-id", + "trigger_dag_xxx", + "--conf", + conf, + ] + ), + ) + + @pytest.mark.parametrize("conf", ["[]", '"str"', "1", "false"]) + def test_trigger_dag_rejects_non_object_conf(self, conf): + with pytest.raises(ValueError, match="DagRun conf must be a JSON object or null"): + dag_command.dag_trigger( + self.parser.parse_args( + [ + "dags", + "trigger", + "example_bash_operator", + "--run-id", + "trigger_dag_xxx", + "--conf", + conf, + ] + ), + ) + + def test_trigger_dag_output_as_json(self, stdout_capture): + args = self.parser.parse_args( + [ + "dags", + "trigger", + "example_bash_operator", + "--run-id", + "trigger_dag_xxx", + "--conf", + '{"conf1": "val1", "conf2": "val2"}', + "--output=json", + ] + ) + with stdout_capture as temp_stdout: + dag_command.dag_trigger(args) + # get the last line from the logs ignoring all logging lines + out = temp_stdout.getvalue().strip().splitlines()[-1] + parsed_out = json.loads(out) + + assert len(parsed_out) == 1 + assert parsed_out[0]["dag_id"] == "example_bash_operator" + assert parsed_out[0]["dag_run_id"] == "trigger_dag_xxx" + assert parsed_out[0]["conf"] == {"conf1": "val1", "conf2": "val2"} + + def test_delete_dag(self): + DM = DagModel + key = "my_dag_id" + session = settings.Session() + session.add(DM(dag_id=key, bundle_name="dags-folder")) + session.commit() + dag_command.dag_delete(self.parser.parse_args(["dags", "delete", key, "--yes"])) + assert session.scalar(select(func.count()).select_from(DM).where(DM.dag_id == key)) == 0 + with pytest.raises(AirflowException): + dag_command.dag_delete( + self.parser.parse_args(["dags", "delete", "does_not_exist_dag", "--yes"]), + ) + + def test_dag_delete_when_backfill_and_dagrun_exist(self): + # Test to check that the DAG should be deleted even if + # there are backfill records associated with it. + from airflow.models.backfill import Backfill + + DM = DagModel + key = "my_dag_id" + session = settings.Session() + session.add(DM(dag_id=key, bundle_name="dags-folder")) + _backfill = Backfill(dag_id=key, from_date=DEFAULT_DATE, to_date=DEFAULT_DATE + timedelta(days=1)) + session.add(_backfill) + # To create the backfill_id in DagRun + session.flush() + session.add( + DagRun( + dag_id=key, + run_id="backfill__" + key, + state=DagRunState.SUCCESS, + run_type="backfill", + backfill_id=_backfill.id, + ) + ) + session.commit() + dag_command.dag_delete(self.parser.parse_args(["dags", "delete", key, "--yes"])) + assert session.scalar(select(func.count()).select_from(DM).where(DM.dag_id == key)) == 0 + with pytest.raises(AirflowException): + dag_command.dag_delete( + self.parser.parse_args(["dags", "delete", "does_not_exist_dag", "--yes"]), + ) + + def test_delete_dag_existing_file(self, tmp_path): + # Test to check that the DAG should be deleted even if + # the file containing it is not deleted + path = tmp_path / "testfile" + DM = DagModel + key = "my_dag_id" + session = settings.Session() + session.add(DM(dag_id=key, bundle_name="dags-folder", fileloc=os.fspath(path))) + session.commit() + dag_command.dag_delete(self.parser.parse_args(["dags", "delete", key, "--yes"])) + assert session.scalar(select(func.count()).select_from(DM).where(DM.dag_id == key)) == 0 + def test_cli_list_jobs(self): args = self.parser.parse_args(["dags", "list-jobs"]) dag_command.dag_list_jobs(args) @@ -1908,142 +2109,3 @@ def test_is_backfillable(self, schedule, allowed_run_types, expected): ) dag_details = dag_command._get_dagbag_dag_details(dag) assert dag_details["is_backfillable"] is expected - - -def _server_error(status_code: int) -> ServerResponseError: - request = httpx.Request("DELETE", "http://testserver/api/v2/dags/foo") - response = httpx.Response(status_code, request=request, json={"detail": "boom"}) - return ServerResponseError(message="boom", request=request, response=response) - - -@pytest.mark.non_db_test_override -class TestCliDagsApiClientCommands: - """Dag CLI commands that talk to the API server through the airflowctl client. - - These are unit tests: the airflowctl client is mocked so no API server (or - database) is required. - """ - - @classmethod - def setup_class(cls): - cls.parser = cli_parser.get_parser() - - @pytest.fixture(autouse=True) - def _default_trigger_response(self, mock_cli_api_client): - """Give the mocked ``dags.trigger`` a dict response so ``print_as`` can render it.""" - mock_cli_api_client.dags.trigger.return_value.model_dump.return_value = { - "dag_id": "example_bash_operator", - "dag_run_id": "test_run", - } - - def test_trigger_dag(self, mock_cli_api_client): - dag_command.dag_trigger( - self.parser.parse_args( - [ - "dags", - "trigger", - "example_bash_operator", - "--run-id=test_trigger_dag", - '--conf={"foo": "bar"}', - ] - ), - ) - mock_cli_api_client.dags.trigger.assert_called_once() - call = mock_cli_api_client.dags.trigger.call_args - assert call.kwargs["dag_id"] == "example_bash_operator" - body = call.kwargs["trigger_dag_run"] - assert body.dag_run_id == "test_trigger_dag" - assert body.conf == {"foo": "bar"} - # logical_date is None as it's not provided - assert body.logical_date is None - - def test_trigger_dag_empty_object_conf(self, mock_cli_api_client): - dag_command.dag_trigger( - self.parser.parse_args( - ["dags", "trigger", "example_bash_operator", "--run-id=empty_conf", "--conf={}"] - ), - ) - body = mock_cli_api_client.dags.trigger.call_args.kwargs["trigger_dag_run"] - assert body.conf == {} - - def test_trigger_dag_json_null_conf(self, mock_cli_api_client): - dag_command.dag_trigger( - self.parser.parse_args( - ["dags", "trigger", "example_bash_operator", "--run-id=null_conf", "--conf=null"] - ), - ) - # ``null`` conf resolves to None on the client; the API server coerces it to {}. - body = mock_cli_api_client.dags.trigger.call_args.kwargs["trigger_dag_run"] - assert body.conf is None - - def test_trigger_dag_with_microseconds(self, mock_cli_api_client): - dag_command.dag_trigger( - self.parser.parse_args( - [ - "dags", - "trigger", - "example_bash_operator", - "--run-id=micro", - "--logical-date=2021-06-04T09:00:00.000001+08:00", - ] - ) - ) - body = mock_cli_api_client.dags.trigger.call_args.kwargs["trigger_dag_run"] - assert body.logical_date.isoformat(timespec="microseconds") == "2021-06-04T09:00:00.000001+08:00" - - @pytest.mark.parametrize("conf", ["NOT JSON", ""]) - def test_trigger_dag_invalid_conf(self, mock_cli_api_client, conf): - with pytest.raises(ValueError, match=r"Expecting value: line \d+ column \d+ \(char \d+\)"): - dag_command.dag_trigger( - self.parser.parse_args( - ["dags", "trigger", "example_bash_operator", "--run-id", "xxx", "--conf", conf] - ), - ) - mock_cli_api_client.dags.trigger.assert_not_called() - - @pytest.mark.parametrize("conf", ["[]", '"str"', "1", "false"]) - def test_trigger_dag_rejects_non_object_conf(self, mock_cli_api_client, conf): - with pytest.raises(ValueError, match="DagRun conf must be a JSON object or null"): - dag_command.dag_trigger( - self.parser.parse_args( - ["dags", "trigger", "example_bash_operator", "--run-id", "xxx", "--conf", conf] - ), - ) - mock_cli_api_client.dags.trigger.assert_not_called() - - def test_trigger_dag_output_as_json(self, mock_cli_api_client, stdout_capture): - mock_cli_api_client.dags.trigger.return_value.model_dump.return_value = { - "dag_id": "example_bash_operator", - "dag_run_id": "trigger_dag_xxx", - "conf": {"conf1": "val1", "conf2": "val2"}, - } - args = self.parser.parse_args( - [ - "dags", - "trigger", - "example_bash_operator", - "--run-id", - "trigger_dag_xxx", - "--conf", - '{"conf1": "val1", "conf2": "val2"}', - "--output=json", - ] - ) - with stdout_capture as temp_stdout: - dag_command.dag_trigger(args) - out = temp_stdout.getvalue().strip().splitlines()[-1] - parsed_out = json.loads(out) - - assert len(parsed_out) == 1 - assert parsed_out[0]["dag_id"] == "example_bash_operator" - assert parsed_out[0]["dag_run_id"] == "trigger_dag_xxx" - assert parsed_out[0]["conf"] == {"conf1": "val1", "conf2": "val2"} - - def test_delete_dag(self, mock_cli_api_client): - dag_command.dag_delete(self.parser.parse_args(["dags", "delete", "my_dag_id", "--yes"])) - mock_cli_api_client.dags.delete.assert_called_once_with(dag_id="my_dag_id") - - def test_delete_dag_missing(self, mock_cli_api_client): - mock_cli_api_client.dags.delete.side_effect = _server_error(404) - with pytest.raises(ServerResponseError): - dag_command.dag_delete(self.parser.parse_args(["dags", "delete", "does_not_exist_dag", "--yes"])) diff --git a/airflow-core/tests/unit/cli/commands/test_pool_command.py b/airflow-core/tests/unit/cli/commands/test_pool_command.py index 28e2d761812e3..ad6951567fbf6 100644 --- a/airflow-core/tests/unit/cli/commands/test_pool_command.py +++ b/airflow-core/tests/unit/cli/commands/test_pool_command.py @@ -18,235 +18,281 @@ from __future__ import annotations import json -from types import SimpleNamespace -import httpx import pytest -from airflowctl.api.operations import ServerResponseError +from sqlalchemy import delete, func, select +from airflow import models, settings from airflow.cli import cli_parser from airflow.cli.commands import pool_command +from airflow.models import Pool +from airflow.settings import Session +from airflow.utils.db import add_default_pool_if_not_exists -from tests_common.test_utils.config import conf_vars - - -def _pool(name, slots, description="", include_deferred=False, team_name=None): - """Build a stand-in for the airflowctl ``PoolResponse`` returned by the client.""" - return SimpleNamespace( - name=name, - slots=slots, - description=description, - include_deferred=include_deferred, - team_name=team_name, - ) - - -def _server_error(status_code: int) -> ServerResponseError: - request = httpx.Request("GET", "http://testserver/api/v2/pools/foo") - response = httpx.Response(status_code, request=request, json={"detail": "boom"}) - return ServerResponseError(message="boom", request=request, response=response) +pytestmark = pytest.mark.db_test class TestCliPools: @classmethod def setup_class(cls): + cls.dagbag = models.DagBag() cls.parser = cli_parser.get_parser() - - def test_pool_list(self, mock_cli_api_client, stdout_capture): - mock_cli_api_client.pools.list.return_value.pools = [_pool("foo", 1, "test")] + settings.configure_orm() + cls.session = Session + cls._cleanup() + + def tearDown(self): + self._cleanup() + + @staticmethod + def _cleanup(session=None): + if session is None: + session = Session() + session.execute(delete(Pool).where(Pool.pool != Pool.DEFAULT_POOL_NAME)) + session.commit() + add_default_pool_if_not_exists() + session.close() + + def test_pool_list(self, stdout_capture): + pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) with stdout_capture as stdout: pool_command.pool_list(self.parser.parse_args(["pools", "list"])) assert "foo" in stdout.getvalue() - mock_cli_api_client.pools.list.assert_called_once() - def test_pool_list_with_args(self, mock_cli_api_client): - mock_cli_api_client.pools.list.return_value.pools = [_pool("foo", 1, "test")] + def test_pool_list_with_args(self): pool_command.pool_list(self.parser.parse_args(["pools", "list", "--output", "json"])) - def test_pool_create(self, mock_cli_api_client): + def test_pool_create(self): pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) + assert self.session.scalar(select(func.count()).select_from(Pool)) == 2 - mock_cli_api_client.pools.create.assert_called_once() - body = mock_cli_api_client.pools.create.call_args.kwargs["pool"] - # core_api PoolBody exposes the name via the ``pool`` attribute (alias ``name``). - assert body.pool == "foo" - assert body.slots == 1 - assert body.description == "test" - assert body.include_deferred is False + def test_pool_update_deferred(self): + pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) + assert self.session.scalar(select(Pool).where(Pool.pool == "foo")).include_deferred is False - def test_pool_create_include_deferred(self, mock_cli_api_client): pool_command.pool_set( self.parser.parse_args(["pools", "set", "foo", "1", "test", "--include-deferred"]) ) + assert self.session.scalar(select(Pool).where(Pool.pool == "foo")).include_deferred is True - body = mock_cli_api_client.pools.create.call_args.kwargs["pool"] - assert body.include_deferred is True - - def test_pool_get(self, mock_cli_api_client, stdout_capture): - mock_cli_api_client.pools.get.return_value = _pool("foo", 1, "test") - with stdout_capture as stdout: - pool_command.pool_get(self.parser.parse_args(["pools", "get", "foo"])) - - assert "foo" in stdout.getvalue() - mock_cli_api_client.pools.get.assert_called_once_with(pool_name="foo") - - def test_pool_get_missing(self, mock_cli_api_client): - mock_cli_api_client.pools.get.side_effect = _server_error(404) - with pytest.raises(SystemExit, match="Pool foo does not exist"): - pool_command.pool_get(self.parser.parse_args(["pools", "get", "foo"])) + pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) + assert self.session.scalar(select(Pool).where(Pool.pool == "foo")).include_deferred is False - def test_pool_get_other_error_reraised(self, mock_cli_api_client): - mock_cli_api_client.pools.get.side_effect = _server_error(500) - with pytest.raises(ServerResponseError): - pool_command.pool_get(self.parser.parse_args(["pools", "get", "foo"])) + def test_pool_get(self): + pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) + pool_command.pool_get(self.parser.parse_args(["pools", "get", "foo"])) - def test_pool_delete(self, mock_cli_api_client): + def test_pool_delete(self): + pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) pool_command.pool_delete(self.parser.parse_args(["pools", "delete", "foo"])) - mock_cli_api_client.pools.delete.assert_called_once_with(pool="foo") - - def test_pool_delete_missing(self, mock_cli_api_client): - mock_cli_api_client.pools.delete.side_effect = _server_error(404) - with pytest.raises(SystemExit, match="Pool foo does not exist"): - pool_command.pool_delete(self.parser.parse_args(["pools", "delete", "foo"])) + assert self.session.scalar(select(func.count()).select_from(Pool)) == 1 - def test_pool_import_nonexistent(self, mock_cli_api_client): + def test_pool_import_nonexistent(self): with pytest.raises(SystemExit): pool_command.pool_import(self.parser.parse_args(["pools", "import", "nonexistent.json"])) - def test_pool_import_invalid_json(self, mock_cli_api_client, tmp_path): + def test_pool_import_invalid_json(self, tmp_path): invalid_pool_import_file_path = tmp_path / "pools_import_invalid.json" - invalid_pool_import_file_path.write_text("not valid json") + with open(invalid_pool_import_file_path, mode="w") as file: + file.write("not valid json") with pytest.raises(SystemExit): pool_command.pool_import( self.parser.parse_args(["pools", "import", str(invalid_pool_import_file_path)]) ) - def test_pool_import_invalid_pools(self, mock_cli_api_client, tmp_path): + def test_pool_import_invalid_pools(self, tmp_path): invalid_pool_import_file_path = tmp_path / "pools_import_invalid.json" - # Missing ``slots`` makes the entry invalid. pool_config_input = {"foo": {"description": "foo_test", "include_deferred": False}} - invalid_pool_import_file_path.write_text(json.dumps(pool_config_input)) + with open(invalid_pool_import_file_path, mode="w") as file: + json.dump(pool_config_input, file) with pytest.raises(SystemExit): pool_command.pool_import( self.parser.parse_args(["pools", "import", str(invalid_pool_import_file_path)]) ) - def test_pool_import(self, mock_cli_api_client, tmp_path): + def test_pool_import_backwards_compatibility(self, tmp_path): pool_import_file_path = tmp_path / "pools_import.json" pool_config_input = { - "foo": {"description": "foo_test", "slots": 1, "include_deferred": True}, - # JSON before version 2.7.0 does not contain ``include_deferred``. - "bar": {"description": "bar_test", "slots": 2}, + # JSON before version 2.7.0 does not contain `include_deferred` + "foo": {"description": "foo_test", "slots": 1}, } - pool_import_file_path.write_text(json.dumps(pool_config_input)) + with open(pool_import_file_path, mode="w") as file: + json.dump(pool_config_input, file) pool_command.pool_import(self.parser.parse_args(["pools", "import", str(pool_import_file_path)])) - assert mock_cli_api_client.pools.create.call_count == 2 - bodies = { - call.kwargs["pool"].pool: call.kwargs["pool"] - for call in mock_cli_api_client.pools.create.call_args_list - } - assert bodies["foo"].include_deferred is True - # Missing ``include_deferred`` defaults to False (backwards compatibility). - assert bodies["bar"].include_deferred is False + assert self.session.scalar(select(Pool).where(Pool.pool == "foo")).include_deferred is False - def test_pool_export(self, mock_cli_api_client, tmp_path): + def test_pool_import_export(self, tmp_path): + pool_import_file_path = tmp_path / "pools_import.json" pool_export_file_path = tmp_path / "pools_export.json" - mock_cli_api_client.pools.list.return_value.pools = [ - _pool("foo", 1, "foo_test", include_deferred=True), - _pool("baz", 2, "baz_test", include_deferred=False), - ] + pool_config_input = { + "foo": {"description": "foo_test", "slots": 1, "include_deferred": True}, + "default_pool": { + "description": "Default pool", + "slots": 128, + "include_deferred": False, + }, + "baz": {"description": "baz_test", "slots": 2, "include_deferred": False}, + } + with open(pool_import_file_path, mode="w") as file: + json.dump(pool_config_input, file) + + # Import json + pool_command.pool_import(self.parser.parse_args(["pools", "import", str(pool_import_file_path)])) + # Export json pool_command.pool_export(self.parser.parse_args(["pools", "export", str(pool_export_file_path)])) - exported = json.loads(pool_export_file_path.read_text()) - assert exported == { - "foo": {"slots": 1, "description": "foo_test", "include_deferred": True}, - "baz": {"slots": 2, "description": "baz_test", "include_deferred": False}, - } + with open(pool_export_file_path) as file: + pool_config_output = json.load(file) + assert pool_config_input == pool_config_output, "Input and output pool files are not same" - def test_pool_set_with_team_name(self, mock_cli_api_client): - """``--team-name`` is forwarded to the airflowctl client when multi_team is enabled.""" - with conf_vars({("core", "multi_team"): "True"}): - pool_command.pool_set( - self.parser.parse_args( - ["pools", "set", "team_pool", "5", "team pool", "--team-name", "test_team"] - ) - ) + def test_pool_set_with_team_name(self): + """Test that pool_set with --team-name assigns the pool to the team when multi_team is enabled.""" + from airflow.models.team import Team - body = mock_cli_api_client.pools.create.call_args.kwargs["pool"] - assert body.team_name == "test_team" + from tests_common.test_utils.config import conf_vars - def test_pool_set_team_name_rejected_when_multi_team_disabled(self, mock_cli_api_client): - """``PoolBody`` rejects a team_name (client-side) when multi_team is disabled.""" - with conf_vars({("core", "multi_team"): "False"}): - with pytest.raises(ValueError, match="team_name cannot be set when multi_team mode is disabled"): + # Create the team first + team = Team(name="test_team") + self.session.add(team) + self.session.commit() + + try: + with conf_vars({("core", "multi_team"): "True"}): pool_command.pool_set( self.parser.parse_args( ["pools", "set", "team_pool", "5", "team pool", "--team-name", "test_team"] ) ) - mock_cli_api_client.pools.create.assert_not_called() - def test_pool_set_without_team_name(self, mock_cli_api_client): - """Without ``--team-name`` the forwarded body has ``team_name`` as None.""" + pool = self.session.scalar(select(Pool).where(Pool.pool == "team_pool")) + assert pool is not None + assert pool.team_name == "test_team" + assert pool.slots == 5 + finally: + self.session.execute(delete(Pool).where(Pool.pool == "team_pool")) + self.session.execute(delete(Team).where(Team.name == "test_team")) + self.session.commit() + + def test_pool_set_team_name_rejected_when_multi_team_disabled(self): + """Test that pool_set with --team-name raises when multi_team is disabled.""" + from airflow.models.team import Team + + from tests_common.test_utils.config import conf_vars + + team = Team(name="test_team") + self.session.add(team) + self.session.commit() + + try: + with conf_vars({("core", "multi_team"): "False"}): + with pytest.raises( + ValueError, match="team_name cannot be set when multi_team mode is disabled" + ): + pool_command.pool_set( + self.parser.parse_args( + ["pools", "set", "team_pool", "5", "team pool", "--team-name", "test_team"] + ) + ) + finally: + self.session.execute(delete(Pool).where(Pool.pool == "team_pool")) + self.session.execute(delete(Team).where(Team.name == "test_team")) + self.session.commit() + + def test_pool_set_without_team_name(self): + """Test that pool_set without --team-name leaves team_name as None.""" pool_command.pool_set(self.parser.parse_args(["pools", "set", "no_team_pool", "3", "no team"])) - body = mock_cli_api_client.pools.create.call_args.kwargs["pool"] - assert body.team_name is None + pool = self.session.scalar(select(Pool).where(Pool.pool == "no_team_pool")) + assert pool is not None + assert pool.team_name is None - def test_pool_import_forwards_team_name(self, mock_cli_api_client, tmp_path): - """Import forwards each pool's ``team_name`` (or None) to the airflowctl client.""" - pool_import_file_path = tmp_path / "pools_import_team.json" - pool_import_file_path.write_text( - json.dumps( - { - "team_pool_a": { - "slots": 10, - "description": "team pool", - "include_deferred": False, - "team_name": "import_team", - }, - "global_pool": {"slots": 5, "description": "global pool", "include_deferred": False}, - } - ) - ) + def test_pool_import_export_with_team_name(self, tmp_path): + """Test that import/export round-trips the team_name field.""" + from airflow.models.team import Team - with conf_vars({("core", "multi_team"): "True"}): - pool_command.pool_import(self.parser.parse_args(["pools", "import", str(pool_import_file_path)])) + from tests_common.test_utils.config import conf_vars - bodies = { - call.kwargs["pool"].pool: call.kwargs["pool"] - for call in mock_cli_api_client.pools.create.call_args_list - } - assert bodies["team_pool_a"].team_name == "import_team" - assert bodies["global_pool"].team_name is None + team = Team(name="import_team") + self.session.add(team) + self.session.commit() - def test_pool_export_includes_team_name(self, mock_cli_api_client, tmp_path): - """Export writes ``team_name`` only for pools that have one.""" + pool_import_file_path = tmp_path / "pools_import_team.json" pool_export_file_path = tmp_path / "pools_export_team.json" - mock_cli_api_client.pools.list.return_value.pools = [ - _pool("team_pool_a", 10, "team pool", team_name="import_team"), - _pool("global_pool", 5, "global pool"), - ] + pool_config_input = { + "team_pool_a": { + "slots": 10, + "description": "team pool", + "include_deferred": False, + "team_name": "import_team", + }, + "global_pool": { + "slots": 5, + "description": "global pool", + "include_deferred": False, + }, + } - pool_command.pool_export(self.parser.parse_args(["pools", "export", str(pool_export_file_path)])) + with open(pool_import_file_path, mode="w") as file: + json.dump(pool_config_input, file) - exported = json.loads(pool_export_file_path.read_text()) - assert exported["team_pool_a"]["team_name"] == "import_team" - assert "team_name" not in exported["global_pool"] + try: + with conf_vars({("core", "multi_team"): "True"}): + pool_command.pool_import( + self.parser.parse_args(["pools", "import", str(pool_import_file_path)]) + ) - def test_pool_list_shows_team_name(self, mock_cli_api_client, stdout_capture): - """Pool list output includes the team_name column.""" - mock_cli_api_client.pools.list.return_value.pools = [ - _pool("list_pool", 5, "desc", team_name="list_team") - ] + # Verify team assignment + pool = self.session.scalar(select(Pool).where(Pool.pool == "team_pool_a")) + assert pool is not None + assert pool.team_name == "import_team" - with stdout_capture as stdout: - pool_command.pool_list(self.parser.parse_args(["pools", "list"])) + global_pool = self.session.scalar(select(Pool).where(Pool.pool == "global_pool")) + assert global_pool is not None + assert global_pool.team_name is None + + # Export and verify + pool_command.pool_export(self.parser.parse_args(["pools", "export", str(pool_export_file_path)])) + + with open(pool_export_file_path) as file: + pool_config_output = json.load(file) + + assert pool_config_output["team_pool_a"]["team_name"] == "import_team" + assert "team_name" not in pool_config_output["global_pool"] + finally: + self.session.execute(delete(Pool).where(Pool.pool.in_(["team_pool_a", "global_pool"]))) + self.session.execute(delete(Team).where(Team.name == "import_team")) + self.session.commit() + + def test_pool_list_shows_team_name(self, stdout_capture): + """Test that pool list output includes the team_name column.""" + from airflow.models.team import Team + + from tests_common.test_utils.config import conf_vars + + team = Team(name="list_team") + self.session.add(team) + self.session.commit() + + try: + with conf_vars({("core", "multi_team"): "True"}): + pool_command.pool_set( + self.parser.parse_args( + ["pools", "set", "list_pool", "5", "desc", "--team-name", "list_team"] + ) + ) + + with stdout_capture as stdout: + pool_command.pool_list(self.parser.parse_args(["pools", "list"])) - assert "list_team" in stdout.getvalue() + output = stdout.getvalue() + assert "list_team" in output + finally: + self.session.execute(delete(Pool).where(Pool.pool == "list_pool")) + self.session.execute(delete(Team).where(Team.name == "list_team")) + self.session.commit() diff --git a/airflow-core/tests/unit/cli/conftest.py b/airflow-core/tests/unit/cli/conftest.py index d9d2ae341eb51..7676a103b5363 100644 --- a/airflow-core/tests/unit/cli/conftest.py +++ b/airflow-core/tests/unit/cli/conftest.py @@ -18,7 +18,6 @@ from __future__ import annotations import sys -from unittest import mock import pytest @@ -69,25 +68,6 @@ def parser(): # log messages -@pytest.fixture -def mock_cli_api_client(): - """Mock the CLI airflowctl client and neutralize ``action_cli``'s DB touch points. - - CLI commands that go through the airflowctl client only need the mocked client; the - ``@action_cli`` audit logging and log-template sync would otherwise open a database - session. Patching them lets these command tests run without a database or API server. - """ - client = mock.MagicMock() - with ( - mock.patch("airflow.cli.api_client.get_cli_api_client", return_value=client), - mock.patch("airflow.utils.cli_action_loggers.on_pre_execution"), - mock.patch("airflow.utils.cli_action_loggers.on_post_execution"), - mock.patch("airflow.utils.db.synchronize_log_template"), - mock.patch("airflow.utils.db.check_and_run_migrations"), - ): - yield client - - @pytest.fixture def stdout_capture(request): """Fixture that captures stdout only.""" diff --git a/airflow-core/tests/unit/cli/test_api_client.py b/airflow-core/tests/unit/cli/test_api_client.py deleted file mode 100644 index 3ef813cebda43..0000000000000 --- a/airflow-core/tests/unit/cli/test_api_client.py +++ /dev/null @@ -1,140 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from unittest import mock - -import pytest - -from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager -from airflow.api_fastapi.auth.managers.simple.simple_auth_manager import SimpleAuthManager -from airflow.cli import api_client as cli_api_client - -from tests_common.test_utils.config import conf_vars - - -@pytest.fixture(autouse=True) -def _reset_singleton(): - """Reset the process-wide client singleton around each test.""" - cli_api_client._api_client = None - yield - cli_api_client._api_client = None - - -class TestResolveBaseUrl: - @conf_vars({("api", "base_url"): "https://airflow.example.com:9999"}) - def test_explicit_base_url(self): - assert cli_api_client._resolve_base_url() == "https://airflow.example.com:9999" - - @conf_vars({("api", "base_url"): "", ("api", "host"): "myhost", ("api", "port"): "1234"}) - def test_host_port_fallback(self): - assert cli_api_client._resolve_base_url() == "http://myhost:1234" - - -class TestMintCliToken: - def test_uses_env_token(self, monkeypatch): - monkeypatch.setenv("AIRFLOW_CLI_TOKEN", "tok-123") - with mock.patch("airflow.api_fastapi.app.get_auth_manager") as get_auth_manager: - assert cli_api_client._mint_cli_token() == "tok-123" - # The auth manager is never consulted when a token is supplied explicitly. - get_auth_manager.assert_not_called() - - def test_mints_via_auth_manager(self, monkeypatch): - monkeypatch.delenv("AIRFLOW_CLI_TOKEN", raising=False) - auth_manager = mock.MagicMock() - auth_manager.get_cli_user.return_value = "cli-user" - auth_manager.generate_jwt.return_value = "signed-jwt" - with mock.patch("airflow.api_fastapi.app.get_auth_manager", return_value=auth_manager): - assert cli_api_client._mint_cli_token() == "signed-jwt" - - auth_manager.generate_jwt.assert_called_once() - assert auth_manager.generate_jwt.call_args.args[0] == "cli-user" - # Token must be short-lived. - assert auth_manager.generate_jwt.call_args.kwargs["expiration_time_in_seconds"] > 0 - - def test_initializes_auth_manager_when_not_initialized(self, monkeypatch): - # In the CLI the auth manager singleton is usually not initialized yet, so - # ``get_auth_manager`` raises and we must initialize it on demand. - monkeypatch.delenv("AIRFLOW_CLI_TOKEN", raising=False) - auth_manager = mock.MagicMock() - auth_manager.get_cli_user.return_value = "cli-user" - auth_manager.generate_jwt.return_value = "signed-jwt" - with ( - mock.patch( - "airflow.api_fastapi.app.get_auth_manager", - side_effect=RuntimeError("Auth Manager has not been initialized yet."), - ), - mock.patch( - "airflow.api_fastapi.app.init_auth_manager", return_value=auth_manager - ) as init_auth_manager, - ): - assert cli_api_client._mint_cli_token() == "signed-jwt" - - init_auth_manager.assert_called_once() - auth_manager.generate_jwt.assert_called_once() - - -class TestGetCliApiClient: - def test_builds_singleton(self): - with ( - mock.patch.object(cli_api_client, "_resolve_base_url", return_value="http://h:8080"), - mock.patch.object(cli_api_client, "_mint_cli_token", return_value="tok"), - mock.patch.object(cli_api_client, "Client") as client_cls, - ): - first = cli_api_client.get_cli_api_client() - second = cli_api_client.get_cli_api_client() - - assert first is second - client_cls.assert_called_once() - kwargs = client_cls.call_args.kwargs - assert kwargs["base_url"] == "http://h:8080" - assert kwargs["token"] == "tok" - assert kwargs["kind"] == cli_api_client.ClientKind.CLI - - -class TestProvideApiClient: - def test_injects_when_missing(self): - with mock.patch.object(cli_api_client, "get_cli_api_client", return_value="CLIENT"): - - @cli_api_client.provide_api_client - def command(args, api_client=None): - return api_client - - assert command("args") == "CLIENT" - - def test_uses_explicit_client(self): - with mock.patch.object(cli_api_client, "get_cli_api_client") as get_client: - - @cli_api_client.provide_api_client - def command(args, api_client=None): - return api_client - - assert command("args", api_client="EXPLICIT") == "EXPLICIT" - get_client.assert_not_called() - - -class TestGetCliUser: - def test_base_default_raises(self): - # The generic auth manager cannot know which user is authorized for the CLI. - with pytest.raises(NotImplementedError, match="AIRFLOW_CLI_TOKEN"): - BaseAuthManager.get_cli_user(mock.Mock()) - - def test_simple_auth_manager_returns_admin(self): - user = SimpleAuthManager.get_cli_user(mock.Mock()) - assert user.get_id() == "cli" - assert user.role == "ADMIN" diff --git a/airflow-core/tests/unit/cli/test_utils.py b/airflow-core/tests/unit/cli/test_utils.py index 4fb137ad48201..f98a9ef6b193d 100644 --- a/airflow-core/tests/unit/cli/test_utils.py +++ b/airflow-core/tests/unit/cli/test_utils.py @@ -17,32 +17,34 @@ # under the License. from __future__ import annotations -import pytest +import warnings from airflow.cli.utils import deprecated_for_airflowctl -from airflow.exceptions import RemovedInAirflow4Warning class TestDeprecatedForAirflowctl: - def test_emits_warning_naming_replacement(self): + def test_records_replacement_without_emitting_a_user_warning(self): @deprecated_for_airflowctl("airflowctl dags trigger") def command(args): return "result" - with pytest.warns(RemovedInAirflow4Warning, match="airflowctl dags trigger"): + # Calling the command emits nothing to users (any warning would become an error here). + with warnings.catch_warnings(): + warnings.simplefilter("error") result = command(args=None) - # The wrapped command still runs and returns its value. assert result == "result" + # The replacement is recorded for maintainers, not shown to users. + assert command._migrated_to_airflowctl == "airflowctl dags trigger" - def test_passes_through_args_and_preserves_metadata(self): + def test_passes_through_args_and_leaves_function_untouched(self): @deprecated_for_airflowctl("airflowctl pools create") def command(a, b, *, c): """Original docstring.""" return (a, b, c) - with pytest.warns(RemovedInAirflow4Warning): - assert command(1, 2, c=3) == (1, 2, 3) - + assert command(1, 2, c=3) == (1, 2, 3) + # The decorator returns the original function untouched apart from the metadata it records. assert command.__name__ == "command" assert command.__doc__ == "Original docstring." + assert command._migrated_to_airflowctl == "airflowctl pools create" diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index 8ea13d44a7ee2..b6a79d032ae3c 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -301,30 +301,6 @@ def _fetch_user() -> User: def serialize_user(self, user: User) -> dict[str, Any]: return {"sub": str(user.id)} - def get_cli_user(self) -> User: - """ - Return an existing ``Admin`` user for the local CLI to mint a token for. - - The Airflow CLI mints a short-lived, in-memory JWT for this user so it can talk to - the API server. FAB tokens reference a real database user, so we reuse an existing - ``Admin`` user rather than fabricating one. If none exists, the operator must - create one or provide a token via the ``AIRFLOW_CLI_TOKEN`` environment variable. - """ - from airflow.utils.session import create_session - - with create_session() as session: - user = session.scalars(select(User).join(User.roles).where(Role.name == "Admin").limit(1)).first() - if user is None: - raise AirflowConfigException( - "No user with the 'Admin' role exists in the FAB database. Create one " - "(e.g. `airflow fab create-user --role Admin ...`) or set the " - "AIRFLOW_CLI_TOKEN environment variable with a valid API token." - ) - # Detach so attributes stay accessible after the session closes (and is not - # expired on commit) while the CLI serializes the user to mint the token. - session.expunge(user) - return user - def is_logged_in(self) -> bool: """Return whether the user is logged in.""" user = self.get_user() diff --git a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py index 5b6d27cecafc7..62f9856701290 100644 --- a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py +++ b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py @@ -1332,29 +1332,3 @@ def test_session_remove_generic_error_does_not_propagate(self, mock_create_app): mock_session.remove.assert_called() mock_log.warning.assert_called() assert response is not None - - -class TestFabGetCliUser: - """``get_cli_user`` reuses an existing ``Admin`` user for the local CLI token.""" - - @mock.patch("airflow.utils.session.create_session") - def test_returns_admin_user(self, mock_create_session, auth_manager): - admin_user = MagicMock() - session = MagicMock() - session.scalars.return_value.first.return_value = admin_user - mock_create_session.return_value.__enter__.return_value = session - - result = auth_manager.get_cli_user() - - assert result is admin_user - # The user is detached so its attributes survive the session closing. - session.expunge.assert_called_once_with(admin_user) - - @mock.patch("airflow.utils.session.create_session") - def test_raises_when_no_admin_user(self, mock_create_session, auth_manager): - session = MagicMock() - session.scalars.return_value.first.return_value = None - mock_create_session.return_value.__enter__.return_value = session - - with pytest.raises(AirflowConfigException, match="Admin"): - auth_manager.get_cli_user() diff --git a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py index bc9ebc305c306..a8cd683ef46ba 100644 --- a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py +++ b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py @@ -34,7 +34,7 @@ from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager -from airflow.exceptions import AirflowConfigException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowProviderDeprecationWarning try: from airflow.api_fastapi.auth.managers.base_auth_manager import ExtendedResourceMethod @@ -141,34 +141,6 @@ def serialize_user(self, user: KeycloakAuthManagerUser) -> dict[str, Any]: "refresh_token": user.refresh_token, } - def get_cli_user(self) -> KeycloakAuthManagerUser: - """ - Return a service-account user for the local CLI to mint a token for. - - Keycloak tokens are issued by the external Keycloak server, so they cannot be - forged locally. The Keycloak client is already configured for Airflow to talk to - Keycloak, so we reuse it to obtain a service-account token through the - ``client_credentials`` flow. The service account's effective permissions are - governed by the Keycloak deployment. If the client credentials are not usable, the - operator must provide a token via the ``AIRFLOW_CLI_TOKEN`` environment variable. - """ - try: - tokens = self.get_keycloak_client().token(grant_type="client_credentials") - except Exception as e: - raise AirflowConfigException( - "Could not obtain a Keycloak service-account token for the CLI via the " - "client_credentials flow. Set the AIRFLOW_CLI_TOKEN environment variable " - f"with a valid API token instead. Original error: {e}" - ) from e - return KeycloakAuthManagerUser( - user_id="airflow-cli", - name="airflow-cli", - access_token=tokens["access_token"], - # No refresh token is issued for the client_credentials flow (RFC 6749 §4.4.3), - # which marks this as a service account in refresh_user/refresh_tokens. - refresh_token=tokens.get("refresh_token"), - ) - def get_url_login(self, **kwargs) -> str: base_url = conf.get("api", "base_url", fallback="/") return urljoin(base_url, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login") diff --git a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py index e4b6a5c294de8..9c8ed9dd5b610 100644 --- a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py +++ b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py @@ -45,7 +45,7 @@ else: TeamDetails = None # type: ignore[assignment,misc] from airflow.api_fastapi.common.types import MenuItem -from airflow.exceptions import AirflowConfigException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowProviderDeprecationWarning try: from airflow.providers.common.compat.sdk import AirflowException @@ -121,25 +121,6 @@ def _clear_filter_cache(): class TestKeycloakAuthManager: - @patch.object(KeycloakAuthManager, "get_keycloak_client") - def test_get_cli_user(self, mock_get_keycloak_client, auth_manager): - # client_credentials (service account) flow returns an access token and no refresh token. - mock_get_keycloak_client.return_value.token.return_value = {"access_token": "svc-token"} - - user = auth_manager.get_cli_user() - - assert user.get_id() == "airflow-cli" - assert user.access_token == "svc-token" - assert user.refresh_token is None - mock_get_keycloak_client.return_value.token.assert_called_once_with(grant_type="client_credentials") - - @patch.object(KeycloakAuthManager, "get_keycloak_client") - def test_get_cli_user_raises_when_credentials_unusable(self, mock_get_keycloak_client, auth_manager): - mock_get_keycloak_client.return_value.token.side_effect = Exception("boom") - - with pytest.raises(AirflowConfigException, match="AIRFLOW_CLI_TOKEN"): - auth_manager.get_cli_user() - def test_deserialize_user(self, auth_manager): result = auth_manager.deserialize_user( { diff --git a/scripts/ci/prek/known_airflow_exceptions.txt b/scripts/ci/prek/known_airflow_exceptions.txt index 262f5d6ce547f..fec615e061f2b 100644 --- a/scripts/ci/prek/known_airflow_exceptions.txt +++ b/scripts/ci/prek/known_airflow_exceptions.txt @@ -1,7 +1,7 @@ airflow-core/src/airflow/api/common/delete_dag.py::1 airflow-core/src/airflow/api_fastapi/core_api/app.py::1 airflow-core/src/airflow/cli/cli_parser.py::1 -airflow-core/src/airflow/cli/commands/dag_command.py::1 +airflow-core/src/airflow/cli/commands/dag_command.py::3 airflow-core/src/airflow/cli/commands/db_command.py::1 airflow-core/src/airflow/config_templates/airflow_local_settings.py::1 airflow-core/src/airflow/dag_processing/dagbag.py::1 diff --git a/uv.lock b/uv.lock index 865064ec1fac4..709e8ee883e4d 100644 --- a/uv.lock +++ b/uv.lock @@ -1917,7 +1917,6 @@ dependencies = [ { name = "a2wsgi" }, { name = "aiosqlite" }, { name = "alembic" }, - { name = "apache-airflow-ctl" }, { name = "apache-airflow-providers-common-compat" }, { name = "apache-airflow-providers-common-io" }, { name = "apache-airflow-providers-common-sql" }, @@ -2046,7 +2045,6 @@ requires-dist = [ { name = "aiosqlite", specifier = ">=0.20.0,<0.22.0" }, { name = "alembic", specifier = ">=1.13.1,<2.0" }, { name = "apache-airflow-core", extras = ["graphviz", "gunicorn", "kerberos", "otel", "statsd"], marker = "extra == 'all'", editable = "airflow-core" }, - { name = "apache-airflow-ctl", editable = "airflow-ctl" }, { name = "apache-airflow-providers-common-compat", editable = "providers/common/compat" }, { name = "apache-airflow-providers-common-io", editable = "providers/common/io" }, { name = "apache-airflow-providers-common-sql", editable = "providers/common/sql" },