diff --git a/airflow-core/newsfragments/68175.significant.rst b/airflow-core/newsfragments/68175.significant.rst new file mode 100644 index 0000000000000..4c9f7b4dffc18 --- /dev/null +++ b/airflow-core/newsfragments/68175.significant.rst @@ -0,0 +1,24 @@ +Airflow CLI commands are moving to talk to the API server + +The CLI is being migrated to reach Airflow through the API server (via the ``airflowctl`` +client) instead of the metadata database directly. Migrated so far: ``dags trigger``, +``dags delete``, ``pools`` (list/get/set/delete/import/export), and ``assets materialize``; +this fragment is updated as more commands migrate rather than adding new ones. + +These commands now require a reachable API server and mint a short-lived token in memory +(set ``AIRFLOW_CLI_TOKEN`` for auth managers that cannot mint locally, or for remote servers). +``airflow.api.client`` is removed — use ``airflow.cli.api_client.get_cli_api_client``. + +Each migrated command emits a ``RemovedInAirflow4Warning`` and will be removed in a future +Airflow release; use the equivalent ``airflowctl`` command instead. + +* Types of change + + * [ ] Dag changes + * [ ] Config changes + * [ ] API changes + * [ ] CLI changes + * [x] Behaviour changes + * [ ] Plugin changes + * [ ] Dependency changes + * [x] Code interface changes diff --git a/airflow-core/src/airflow/api/client/__init__.py b/airflow-core/src/airflow/api/client/__init__.py deleted file mode 100644 index f0d236b9019ab..0000000000000 --- a/airflow-core/src/airflow/api/client/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""API Client that allows interacting with Airflow API.""" - -from __future__ import annotations - -from airflow.api.client.local_client import Client - - -def get_current_api_client() -> Client: - return Client() diff --git a/airflow-core/src/airflow/api/client/local_client.py b/airflow-core/src/airflow/api/client/local_client.py deleted file mode 100644 index 057d6d99c7cfd..0000000000000 --- a/airflow-core/src/airflow/api/client/local_client.py +++ /dev/null @@ -1,107 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Local client API.""" - -from __future__ import annotations - -import httpx - -from airflow.api.common import delete_dag, trigger_dag -from airflow.exceptions import AirflowBadRequest, PoolNotFound -from airflow.models.pool import Pool -from airflow.utils.types import DagRunTriggeredByType - - -class Client: - """Local API client implementation.""" - - def __init__(self, auth=None, session: httpx.Client | None = None): - self._session: httpx.Client = session or httpx.Client() - if auth: - self._session.auth = auth - - def trigger_dag( - self, - dag_id, - run_id=None, - conf=None, - logical_date=None, - triggering_user_name=None, - replace_microseconds=True, - ) -> dict | None: - dag_run = trigger_dag.trigger_dag( - dag_id=dag_id, - triggered_by=DagRunTriggeredByType.CLI, - triggering_user_name=triggering_user_name, - run_id=run_id, - conf=conf, - logical_date=logical_date, - replace_microseconds=replace_microseconds, - ) - if dag_run: - return { - "conf": dag_run.conf, - "dag_id": dag_run.dag_id, - "dag_run_id": dag_run.run_id, - "data_interval_start": dag_run.data_interval_start, - "data_interval_end": dag_run.data_interval_end, - "end_date": dag_run.end_date, - "last_scheduling_decision": dag_run.last_scheduling_decision, - "logical_date": dag_run.logical_date, - "run_type": dag_run.run_type, - "start_date": dag_run.start_date, - "state": dag_run.state, - "triggering_user_name": dag_run.triggering_user_name, - } - return dag_run - - def delete_dag(self, dag_id): - count = delete_dag.delete_dag(dag_id) - return f"Removed {count} record(s)" - - def get_pool(self, name): - pool = Pool.get_pool(pool_name=name) - if not pool: - raise PoolNotFound(f"Pool {name} not found") - return pool.pool, pool.slots, pool.description, pool.include_deferred, pool.team_name - - def get_pools(self): - return [(p.pool, p.slots, p.description, p.include_deferred, p.team_name) for p in Pool.get_pools()] - - def create_pool(self, name, slots, description, include_deferred, team_name=None): - if not (name and name.strip()): - raise AirflowBadRequest("Pool name shouldn't be empty") - pool_name_length = Pool.pool.property.columns[0].type.length - if len(name) > pool_name_length: - raise AirflowBadRequest(f"Pool name cannot be more than {pool_name_length} characters") - try: - slots = int(slots) - except ValueError: - raise AirflowBadRequest(f"Invalid value for `slots`: {slots}") - pool = Pool.create_or_update_pool( - name=name, - slots=slots, - description=description, - include_deferred=include_deferred, - team_name=team_name, - ) - return pool.pool, pool.slots, pool.description, pool.team_name - - def delete_pool(self, name): - pool = Pool.delete_pool(name=name) - return pool.pool, pool.slots, pool.description diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py index 4ba08e5b447e5..7f62d8af27c33 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py +++ b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py @@ -180,6 +180,22 @@ def generate_jwt( self.serialize_user(user) ) + def get_cli_user(self) -> T: + """ + Return the user the local CLI acts as when calling the API server. + + The Airflow CLI mints a short-lived JWT for this user (via :meth:`generate_jwt`) + so it can talk to the API server without persisting any credentials. A generic + auth manager cannot know which user is authorized for local CLI access, so the + default raises. Auth managers that support local CLI usage should override this + to return an administrative user. Otherwise, operators must provide a token via + the ``AIRFLOW_CLI_TOKEN`` environment variable. + """ + raise NotImplementedError( + f"{type(self).__name__} does not support minting a local CLI token. " + "Set the AIRFLOW_CLI_TOKEN environment variable with a valid API token instead." + ) + @abstractmethod def get_url_login(self, **kwargs) -> str: """Return the login page url.""" diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py b/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py index 0559a388156d5..0deaadf40346b 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py +++ b/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py @@ -238,6 +238,9 @@ def deserialize_user(self, token: dict[str, Any]) -> SimpleAuthManagerUser: def serialize_user(self, user: SimpleAuthManagerUser) -> dict[str, Any]: return {"sub": user.username, "role": user.role, "teams": user.teams} + def get_cli_user(self) -> SimpleAuthManagerUser: + return SimpleAuthManagerUser(username="cli", role=SimpleAuthManagerRole.ADMIN.name) + def is_authorized_configuration( self, *, diff --git a/airflow-core/src/airflow/cli/api_client.py b/airflow-core/src/airflow/cli/api_client.py new file mode 100644 index 0000000000000..d1ff5e2ddd93c --- /dev/null +++ b/airflow-core/src/airflow/cli/api_client.py @@ -0,0 +1,129 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Provide the :mod:`airflowctl` HTTP API client to the local Airflow CLI. + +The local CLI talks to the API server through the same typed client that ``airflowctl`` +uses, but without the keyring-backed credential store. For each invocation it mints a +short-lived JWT **in memory** (via the active auth manager) and builds a client with it; +nothing is persisted. Set the ``AIRFLOW_CLI_TOKEN`` environment variable to supply a token +explicitly (required for auth managers whose tokens cannot be minted locally, such as +Keycloak, or when targeting a remote API server). +""" + +from __future__ import annotations + +import atexit +import os +from collections.abc import Callable +from functools import wraps +from typing import TYPE_CHECKING, TypeVar + +import httpx + +# Re-exported so command modules import the client surface from a single place. +from airflowctl.api.client import NEW_API_CLIENT, Client, ClientKind + +from airflow.configuration import conf +from airflow.typing_compat import ParamSpec + +if TYPE_CHECKING: + from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager + +__all__ = [ + "NEW_API_CLIENT", + "Client", + "ClientKind", + "get_cli_api_client", + "provide_api_client", +] + +PS = ParamSpec("PS") +RT = TypeVar("RT") + +# Validity of the in-memory CLI token. It only needs to outlive a single CLI command +# (including the client's request retries) and is never persisted or logged. +_CLI_TOKEN_VALID_FOR_SECONDS = 300 + +_api_client: Client | None = None + + +def _resolve_base_url() -> str: + """Resolve the API server base URL from configuration.""" + base_url = conf.get("api", "base_url", fallback=None) + if base_url: + return base_url + host = conf.get("api", "host", fallback="localhost") or "localhost" + port = conf.get("api", "port", fallback="8080") or "8080" + return f"http://{host}:{port}" + + +def _mint_cli_token() -> str: + """ + Return a token for the CLI to authenticate against the API server. + + Prefers an explicit ``AIRFLOW_CLI_TOKEN`` (the universal override), otherwise mints a + short-lived JWT through the active auth manager. The token lives only in this process. + """ + if token := os.environ.get("AIRFLOW_CLI_TOKEN"): + return token + + from airflow.api_fastapi.app import get_auth_manager, init_auth_manager + + # The CLI runs outside the API server, so the auth manager singleton is usually not + # initialized yet; initialize it on demand. ``init_auth_manager`` reuses the cached + # instance when one already exists, so this is safe to call here. + try: + auth_manager: BaseAuthManager = get_auth_manager() + except RuntimeError: + auth_manager = init_auth_manager() + return auth_manager.generate_jwt( + auth_manager.get_cli_user(), + expiration_time_in_seconds=_CLI_TOKEN_VALID_FOR_SECONDS, + ) + + +def get_cli_api_client() -> Client: + """Return the process-wide singleton airflowctl client for the local CLI.""" + global _api_client + if _api_client is None: + _api_client = Client( + base_url=_resolve_base_url(), + token=_mint_cli_token(), + kind=ClientKind.CLI, + limits=httpx.Limits(max_keepalive_connections=1, max_connections=1), + ) + atexit.register(_api_client.close) + return _api_client + + +def provide_api_client(func: Callable[PS, RT]) -> Callable[PS, RT]: + """ + Provide the CLI API client to the decorated command function. + + Injects ``api_client=get_cli_api_client()`` when the caller does not pass one. Tests + (or callers that already hold a client) pass ``api_client=`` explicitly to bypass it. + """ + + @wraps(func) + def wrapper(*args, **kwargs) -> RT: + if "api_client" not in kwargs: + kwargs["api_client"] = get_cli_api_client() + return func(*args, **kwargs) + + return wrapper diff --git a/airflow-core/src/airflow/cli/commands/asset_command.py b/airflow-core/src/airflow/cli/commands/asset_command.py index 15430b76488d6..29c1025958ab6 100644 --- a/airflow-core/src/airflow/cli/commands/asset_command.py +++ b/airflow-core/src/airflow/cli/commands/asset_command.py @@ -17,21 +17,17 @@ from __future__ import annotations -import logging import typing from sqlalchemy import select -from airflow.api.common.trigger_dag import trigger_dag from airflow.api_fastapi.core_api.datamodels.assets import AssetAliasResponse, AssetResponse -from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse +from airflow.cli.api_client import NEW_API_CLIENT, Client, provide_api_client from airflow.cli.simple_table import AirflowConsole -from airflow.exceptions import AirflowConfigException -from airflow.models.asset import AssetAliasModel, AssetModel, TaskOutletAssetReference +from airflow.cli.utils import deprecated_for_airflowctl +from airflow.models.asset import AssetAliasModel, AssetModel from airflow.utils import cli as cli_utils -from airflow.utils.platform import getuser from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.types import DagRunTriggeredByType, DagRunType if typing.TYPE_CHECKING: from typing import Any @@ -40,8 +36,6 @@ from airflow.api_fastapi.core_api.base import BaseModel -log = logging.getLogger(__name__) - def _list_asset_aliases(args, *, session: Session) -> tuple[Any, type[BaseModel]]: aliases = session.scalars(select(AssetAliasModel).order_by(AssetAliasModel.name)) @@ -49,7 +43,13 @@ def _list_asset_aliases(args, *, session: Session) -> tuple[Any, type[BaseModel] def _list_assets(args, *, session: Session) -> tuple[Any, type[BaseModel]]: - assets = session.scalars(select(AssetModel).order_by(AssetModel.name)) + assets = session.scalars(select(AssetModel).order_by(AssetModel.name)).all() + for asset in assets: + for watcher in asset.watchers: + # ``AssetWatcherModel`` has no ``created_date`` column; like the public API + # serializer, derive it from the watcher's trigger so ``AssetResponse`` validation + # succeeds. Set on the instance so ``model_validate`` reads it via ``from_attributes``. + watcher.created_date = watcher.trigger.created_date return assets, AssetResponse @@ -124,48 +124,38 @@ def asset_details(args, *, session: Session = NEW_SESSION) -> None: @cli_utils.action_cli -@provide_session -def asset_materialize(args, *, session: Session = NEW_SESSION) -> None: +@deprecated_for_airflowctl("airflowctl assets materialize") +@provide_api_client +def asset_materialize(args, api_client: Client = NEW_API_CLIENT) -> None: """ Materialize the specified asset. This is done by finding the DAG with the asset defined as outlet, and create - a run for that DAG. + a run for that DAG. Resolving the DAG and creating the run is handled by the API + server; the asset is identified here by its name and/or URI. """ if not args.name and not args.uri: raise SystemExit("Either --name or --uri is required") - stmt = select(TaskOutletAssetReference.dag_id).join(TaskOutletAssetReference.asset) select_message_parts = [] if args.name: - stmt = stmt.where(AssetModel.name == args.name) select_message_parts.append(f"name {args.name}") if args.uri: - stmt = stmt.where(AssetModel.uri == args.uri) select_message_parts.append(f"URI {args.uri}") - dag_id_it = iter(session.scalars(stmt.group_by(TaskOutletAssetReference.dag_id).limit(2))) select_message = " and ".join(select_message_parts) - if (dag_id := next(dag_id_it, None)) is None: + matches = [ + asset + for asset in api_client.assets.list().assets + if (not args.name or asset.name == args.name) and (not args.uri or asset.uri == args.uri) + ] + if not matches: raise SystemExit(f"Asset with {select_message} does not exist.") - if next(dag_id_it, None) is not None: - raise SystemExit(f"More than one DAG materializes asset with {select_message}.") - - try: - user = getuser() - except AirflowConfigException as e: - log.warning("Failed to get user name from os: %s, not setting the triggering user", e) - user = None - dagrun = trigger_dag( - dag_id=dag_id, - triggered_by=DagRunTriggeredByType.CLI, - run_type=DagRunType.ASSET_MATERIALIZATION, - triggering_user_name=user, - session=session, - ) - if dagrun is not None: - data = [DAGRunResponse.model_validate(dagrun).model_dump(mode="json")] - else: - data = [] + if len(matches) > 1: + raise SystemExit(f"More than one asset exists with {select_message}.") - AirflowConsole().print_as(data=data, output=args.output) + dag_run = api_client.assets.materialize(asset_id=str(matches[0].id)) + AirflowConsole().print_as( + data=[dag_run.model_dump(mode="json")], + output=args.output, + ) diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py b/airflow-core/src/airflow/cli/commands/dag_command.py index d9d8388055af9..1d5e423757680 100644 --- a/airflow-core/src/airflow/cli/commands/dag_command.py +++ b/airflow-core/src/airflow/cli/commands/dag_command.py @@ -33,14 +33,15 @@ from sqlalchemy import func, select from airflow._shared.timezones import timezone -from airflow.api.client import get_current_api_client +from airflow.api_fastapi.core_api.datamodels.dag_run import TriggerDAGRunPostBody from airflow.api_fastapi.core_api.datamodels.dags import DAGResponse +from airflow.cli.api_client import NEW_API_CLIENT, Client, provide_api_client from airflow.cli.simple_table import AirflowConsole -from airflow.cli.utils import fetch_dag_run_from_run_id_or_logical_date_string +from airflow.cli.utils import deprecated_for_airflowctl, fetch_dag_run_from_run_id_or_logical_date_string from airflow.dag_processing.bundles.base import unpack_bundle_version from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.dag_processing.dagbag import BundleDagBag, DagBag, sync_bag_to_db -from airflow.exceptions import AirflowConfigException, AirflowException +from airflow.exceptions import AirflowException from airflow.jobs.job import Job from airflow.models import DagModel, DagRun, TaskInstance from airflow.models.errors import ParseImportError @@ -55,7 +56,6 @@ ) from airflow.utils.dot_renderer import render_dag, render_dag_dependencies from airflow.utils.helpers import ask_yesno -from airflow.utils.platform import getuser from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.state import DagRunState @@ -77,47 +77,41 @@ @cli_utils.action_cli +@deprecated_for_airflowctl("airflowctl dags trigger") @providers_configuration_loaded -def dag_trigger(args) -> None: +@provide_api_client +def dag_trigger(args, api_client: Client = NEW_API_CLIENT) -> None: """Create a dag run for the specified dag.""" - api_client = get_current_api_client() - try: - user = getuser() - except AirflowConfigException as e: - log.warning("Failed to get user name from os: %s, not setting the triggering user", e) - user = None - try: - message = api_client.trigger_dag( - dag_id=args.dag_id, - run_id=args.run_id, - conf=args.conf, - logical_date=args.logical_date, - triggering_user_name=user, - replace_microseconds=args.replace_microseconds, - ) - AirflowConsole().print_as( - data=[message] if message is not None else [], - output=args.output, - ) - except OSError as err: - raise AirflowException(err) + run_conf = json.loads(args.conf) if args.conf is not None else None + if run_conf is not None and not isinstance(run_conf, dict): + raise ValueError("DagRun conf must be a JSON object or null") + # The core_api request models are the source of truth; they are wire-compatible with + # the airflowctl client's generated models (the API server uses populate_by_name). + trigger_body = TriggerDAGRunPostBody( + dag_run_id=args.run_id, + conf=run_conf, + logical_date=args.logical_date, + ) + dag_run = api_client.dags.trigger(dag_id=args.dag_id, trigger_dag_run=trigger_body) # type: ignore[arg-type] + AirflowConsole().print_as( + data=[dag_run.model_dump(mode="json")], + output=args.output, + ) @cli_utils.action_cli +@deprecated_for_airflowctl("airflowctl dags delete") @providers_configuration_loaded -def dag_delete(args) -> None: +@provide_api_client +def dag_delete(args, api_client: Client = NEW_API_CLIENT) -> None: """Delete all DB records related to the specified dag.""" - api_client = get_current_api_client() if ( args.yes or input("This will drop all existing records related to the specified DAG. Proceed? (y/n)").upper() == "Y" ): - try: - message = api_client.delete_dag(dag_id=args.dag_id) - print(message) - except OSError as err: - raise AirflowException(err) + api_client.dags.delete(dag_id=args.dag_id) + print(f"Removed DAG {args.dag_id}") else: print("Cancelled") diff --git a/airflow-core/src/airflow/cli/commands/pool_command.py b/airflow-core/src/airflow/cli/commands/pool_command.py index c2e624d23a6aa..a51351f73bdc8 100644 --- a/airflow-core/src/airflow/cli/commands/pool_command.py +++ b/airflow-core/src/airflow/cli/commands/pool_command.py @@ -23,9 +23,12 @@ import os from json import JSONDecodeError -from airflow.api.client import get_current_api_client +from airflowctl.api.operations import ServerResponseError + +from airflow.api_fastapi.core_api.datamodels.pools import PoolBody +from airflow.cli.api_client import NEW_API_CLIENT, Client, provide_api_client from airflow.cli.simple_table import AirflowConsole -from airflow.exceptions import PoolNotFound +from airflow.cli.utils import deprecated_for_airflowctl from airflow.utils import cli as cli_utils from airflow.utils.cli import suppress_logs_and_warning from airflow.utils.providers_configuration_loader import providers_configuration_loaded @@ -36,66 +39,78 @@ def _show_pools(pools, output): data=pools, output=output, mapper=lambda x: { - "pool": x[0], - "slots": x[1], - "description": x[2], - "include_deferred": x[3], - "team_name": x[4], + "pool": x.name, + "slots": x.slots, + "description": x.description, + "include_deferred": x.include_deferred, + "team_name": x.team_name, }, ) +@deprecated_for_airflowctl("airflowctl pools list") @suppress_logs_and_warning @providers_configuration_loaded -def pool_list(args): +@provide_api_client +def pool_list(args, api_client: Client = NEW_API_CLIENT): """Display info of all the pools.""" - api_client = get_current_api_client() - pools = api_client.get_pools() + pools = api_client.pools.list().pools _show_pools(pools=pools, output=args.output) +@deprecated_for_airflowctl("airflowctl pools get") @suppress_logs_and_warning @providers_configuration_loaded -def pool_get(args): +@provide_api_client +def pool_get(args, api_client: Client = NEW_API_CLIENT): """Display pool info by a given name.""" - api_client = get_current_api_client() try: - pools = [api_client.get_pool(name=args.pool)] + pools = [api_client.pools.get(pool_name=args.pool)] _show_pools(pools=pools, output=args.output) - except PoolNotFound: - raise SystemExit(f"Pool {args.pool} does not exist") + except ServerResponseError as e: + if e.response.status_code == 404: + raise SystemExit(f"Pool {args.pool} does not exist") + raise @cli_utils.action_cli +@deprecated_for_airflowctl("airflowctl pools create") @suppress_logs_and_warning @providers_configuration_loaded -def pool_set(args): +@provide_api_client +def pool_set(args, api_client: Client = NEW_API_CLIENT): """Create new pool with a given name and slots.""" - api_client = get_current_api_client() - api_client.create_pool( + # core_api PoolBody is the source of truth and is wire-compatible with the airflowctl + # client's generated model (the API server uses populate_by_name). + pool_body = PoolBody( name=args.pool, slots=args.slots, description=args.description, include_deferred=args.include_deferred, team_name=args.team_name, ) + api_client.pools.create(pool=pool_body) # type: ignore[arg-type] print(f"Pool {args.pool} created") @cli_utils.action_cli +@deprecated_for_airflowctl("airflowctl pools delete") @suppress_logs_and_warning @providers_configuration_loaded -def pool_delete(args): +@provide_api_client +def pool_delete(args, api_client: Client = NEW_API_CLIENT): """Delete pool by a given name.""" - api_client = get_current_api_client() try: - api_client.delete_pool(name=args.pool) + api_client.pools.delete(pool=args.pool) print(f"Pool {args.pool} deleted") - except PoolNotFound: - raise SystemExit(f"Pool {args.pool} does not exist") + except ServerResponseError as e: + if e.response.status_code == 404: + raise SystemExit(f"Pool {args.pool} does not exist") + raise @cli_utils.action_cli +@deprecated_for_airflowctl("airflowctl pools import") @suppress_logs_and_warning @providers_configuration_loaded def pool_import(args): @@ -108,6 +123,7 @@ def pool_import(args): print(f"Uploaded {len(pools)} pool(s)") +@deprecated_for_airflowctl("airflowctl pools export") @providers_configuration_loaded def pool_export(args): """Export all the pools to the file.""" @@ -115,10 +131,9 @@ def pool_export(args): print(f"Exported {len(pools)} pools to {args.file}") -def pool_import_helper(filepath): +@provide_api_client +def pool_import_helper(filepath, api_client: Client = NEW_API_CLIENT): """Help import pools from the json file.""" - api_client = get_current_api_client() - with open(filepath) as poolfile: data = poolfile.read() try: @@ -129,34 +144,33 @@ def pool_import_helper(filepath): failed = [] for k, v in pools_json.items(): if isinstance(v, dict) and "slots" in v and "description" in v: - pools.append( - api_client.create_pool( - name=k, - slots=v["slots"], - description=v["description"], - include_deferred=v.get("include_deferred", False), - team_name=v.get("team_name"), - ) + pool_body = PoolBody( + name=k, + slots=v["slots"], + description=v["description"], + include_deferred=v.get("include_deferred", False), + team_name=v.get("team_name"), ) + pools.append(api_client.pools.create(pool=pool_body)) # type: ignore[arg-type] else: failed.append(k) return pools, failed -def pool_export_helper(filepath): +@provide_api_client +def pool_export_helper(filepath, api_client: Client = NEW_API_CLIENT): """Help export all the pools to the json file.""" - api_client = get_current_api_client() pool_dict = {} - pools = api_client.get_pools() + pools = api_client.pools.list().pools for pool in pools: entry = { - "slots": pool[1], - "description": pool[2], - "include_deferred": pool[3], + "slots": pool.slots, + "description": pool.description, + "include_deferred": pool.include_deferred, } - if pool[4] is not None: - entry["team_name"] = pool[4] - pool_dict[pool[0]] = entry + if pool.team_name is not None: + entry["team_name"] = pool.team_name + pool_dict[pool.name] = entry with open(filepath, "w") as poolfile: poolfile.write(json.dumps(pool_dict, sort_keys=True, indent=4)) return pools diff --git a/airflow-core/src/airflow/cli/utils.py b/airflow-core/src/airflow/cli/utils.py index 870f045071b7d..a19d761c30485 100644 --- a/airflow-core/src/airflow/cli/utils.py +++ b/airflow-core/src/airflow/cli/utils.py @@ -17,8 +17,13 @@ from __future__ import annotations +import functools import sys -from typing import TYPE_CHECKING +import warnings +from collections.abc import Callable +from typing import TYPE_CHECKING, TypeVar + +from airflow.exceptions import RemovedInAirflow4Warning # Placeholder for masking sensitive values in CLI output SENSITIVE_PLACEHOLDER = "***" @@ -32,6 +37,35 @@ from airflow.models.dagrun import DagRun +F = TypeVar("F", bound=Callable[..., object]) + + +def deprecated_for_airflowctl(replacement: str) -> Callable[[F], F]: + """ + Mark an ``airflow`` CLI command as deprecated in favour of an ``airflowctl`` equivalent. + + These commands now reach Airflow through the API server via the ``airflowctl`` client. They + are kept for backwards compatibility but will be removed in a future Airflow release; users + should switch to ``airflowctl`` directly. + + :param replacement: The equivalent ``airflowctl`` command, e.g. ``airflowctl dags trigger``. + """ + + def decorator(func: F) -> F: + @functools.wraps(func) + def wrapper(*args, **kwargs): + warnings.warn( + f"This `airflow` CLI command is deprecated and will be removed in a future " + f"Airflow release. Use `{replacement}` instead.", + RemovedInAirflow4Warning, + stacklevel=2, + ) + return func(*args, **kwargs) + + return wrapper # type: ignore[return-value] + + return decorator + class CliConflictError(Exception): """Error for when CLI commands are defined twice by different sources.""" diff --git a/airflow-core/tests/unit/cli/commands/test_asset_command.py b/airflow-core/tests/unit/cli/commands/test_asset_command.py index 6efd8293534f3..d30e36eb04d97 100644 --- a/airflow-core/tests/unit/cli/commands/test_asset_command.py +++ b/airflow-core/tests/unit/cli/commands/test_asset_command.py @@ -21,7 +21,7 @@ import json import os import typing -from unittest import mock +from types import SimpleNamespace import pytest @@ -37,7 +37,9 @@ pytestmark = [pytest.mark.db_test] -@pytest.fixture(scope="module", autouse=True) +# Not autouse: only the DB-backed tests below request it, so the mocked (non-DB) +# ``assets materialize`` tests stay free of any database access. +@pytest.fixture(scope="module") def prepare_examples(): with conf_vars({("core", "load_examples"): "True"}): parse_and_sync_to_db(os.devnull) @@ -46,17 +48,12 @@ def prepare_examples(): clear_db_dags() -@pytest.fixture(autouse=True) -def clear_runs(): - clear_db_runs() - - @pytest.fixture(scope="module") def parser() -> ArgumentParser: return cli_parser.get_parser() -def test_cli_assets_list(parser: ArgumentParser, stdout_capture) -> None: +def test_cli_assets_list(prepare_examples, parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "list", "--output=json"]) with stdout_capture as capture: asset_command.asset_list(args) @@ -67,7 +64,7 @@ def test_cli_assets_list(parser: ArgumentParser, stdout_capture) -> None: assert any(asset["uri"] == "s3://dag1/output_1.txt" for asset in asset_list), asset_list -def test_cli_assets_alias_list(parser: ArgumentParser, stdout_capture) -> None: +def test_cli_assets_alias_list(prepare_examples, parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "list", "--alias", "--output=json"]) with stdout_capture as capture: asset_command.asset_list(args) @@ -78,7 +75,7 @@ def test_cli_assets_alias_list(parser: ArgumentParser, stdout_capture) -> None: assert any(alias["name"] == "example-alias" for alias in alias_list), alias_list -def test_cli_assets_details(parser: ArgumentParser, stdout_capture) -> None: +def test_cli_assets_details(prepare_examples, parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "details", "--name=asset1_producer", "--output=json"]) with stdout_capture as capture: asset_command.asset_details(args) @@ -107,7 +104,7 @@ def test_cli_assets_details(parser: ArgumentParser, stdout_capture) -> None: } -def test_cli_assets_alias_details(parser: ArgumentParser, stdout_capture) -> None: +def test_cli_assets_alias_details(prepare_examples, parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "details", "--alias", "--name=example-alias", "--output=json"]) with stdout_capture as capture: asset_command.asset_details(args) @@ -124,85 +121,46 @@ def test_cli_assets_alias_details(parser: ArgumentParser, stdout_capture) -> Non } -@mock.patch("airflow.api_fastapi.core_api.datamodels.dag_versions.hasattr") -def test_cli_assets_materialize(mock_hasattr, parser: ArgumentParser, stdout_capture) -> None: - mock_hasattr.return_value = False - args = parser.parse_args(["assets", "materialize", "--name=asset1_producer", "--output=json"]) - with stdout_capture as capture: - asset_command.asset_materialize(args) - - output = capture.getvalue() - - # Check if output is empty first - assert output, "No output captured from asset_materialize command" - - run_list = json.loads(output) - assert len(run_list) == 1 - - # No good way to statically compare these. - undeterministic: dict = { - "dag_run_id": None, - "dag_versions": [], - "data_interval_end": None, - "data_interval_start": None, - "logical_date": None, - "queued_at": None, - "run_after": "2025-02-12T19:27:59.066046Z", - } - - assert run_list[0] | undeterministic == undeterministic | { - "conf": {}, - "bundle_version": None, - "dag_display_name": "asset1_producer", - "dag_id": "asset1_producer", - "end_date": None, - "duration": None, - "last_scheduling_decision": None, - "note": None, - "partition_key": None, - "run_type": "asset_materialization", - "start_date": None, - "state": "queued", - "triggered_by": "cli", - "triggering_user_name": "root", - "run_after": "2025-02-12T19:27:59.066046Z", - } - - -def test_cli_assets_materialize_with_view_url_template(parser: ArgumentParser, stdout_capture) -> None: - args = parser.parse_args(["assets", "materialize", "--name=asset1_producer", "--output=json"]) - with stdout_capture as capture: - asset_command.asset_materialize(args) - - output = capture.getvalue() - run_list = json.loads(output) - assert len(run_list) == 1 - - # No good way to statically compare these. - undeterministic: dict = { - "dag_run_id": None, - "dag_versions": [], - "data_interval_end": None, - "data_interval_start": None, - "logical_date": None, - "queued_at": None, - "run_after": "2025-02-12T19:27:59.066046Z", - } - - assert run_list[0] | undeterministic == undeterministic | { - "conf": {}, - "bundle_version": None, - "dag_display_name": "asset1_producer", - "dag_id": "asset1_producer", - "end_date": None, - "duration": None, - "last_scheduling_decision": None, - "note": None, - "partition_key": None, - "run_type": "asset_materialization", - "start_date": None, - "state": "queued", - "triggered_by": "cli", - "triggering_user_name": "root", - "run_after": "2025-02-12T19:27:59.066046Z", - } +@pytest.mark.non_db_test_override +class TestCliAssetsMaterialize: + """`assets materialize` goes through the airflowctl client; mocked here (no DB/server).""" + + def test_materialize(self, parser: ArgumentParser, mock_cli_api_client, stdout_capture) -> None: + mock_cli_api_client.assets.list.return_value.assets = [ + SimpleNamespace(id=7, name="asset1_producer", uri="s3://bucket/asset1_producer"), + SimpleNamespace(id=8, name="other", uri="s3://bucket/other"), + ] + mock_cli_api_client.assets.materialize.return_value.model_dump.return_value = { + "dag_id": "asset1_producer", + "run_type": "asset_materialization", + "state": "queued", + } + args = parser.parse_args(["assets", "materialize", "--name=asset1_producer", "--output=json"]) + with stdout_capture as capture: + asset_command.asset_materialize(args) + + run_list = json.loads(capture.getvalue()) + assert len(run_list) == 1 + assert run_list[0]["dag_id"] == "asset1_producer" + # The asset is resolved to its id and materialization is delegated to the API server. + mock_cli_api_client.assets.materialize.assert_called_once_with(asset_id="7") + + def test_materialize_requires_name_or_uri(self, parser: ArgumentParser, mock_cli_api_client) -> None: + with pytest.raises(SystemExit, match="Either --name or --uri is required"): + asset_command.asset_materialize(parser.parse_args(["assets", "materialize"])) + mock_cli_api_client.assets.materialize.assert_not_called() + + def test_materialize_missing(self, parser: ArgumentParser, mock_cli_api_client) -> None: + mock_cli_api_client.assets.list.return_value.assets = [] + with pytest.raises(SystemExit, match="Asset with name nope does not exist"): + asset_command.asset_materialize(parser.parse_args(["assets", "materialize", "--name=nope"])) + mock_cli_api_client.assets.materialize.assert_not_called() + + def test_materialize_ambiguous(self, parser: ArgumentParser, mock_cli_api_client) -> None: + mock_cli_api_client.assets.list.return_value.assets = [ + SimpleNamespace(id=1, name="dup", uri="s3://a"), + SimpleNamespace(id=2, name="dup", uri="s3://b"), + ] + with pytest.raises(SystemExit, match="More than one asset exists with name dup"): + asset_command.asset_materialize(parser.parse_args(["assets", "materialize", "--name=dup"])) + mock_cli_api_client.assets.materialize.assert_not_called() diff --git a/airflow-core/tests/unit/cli/commands/test_command_deprecations.py b/airflow-core/tests/unit/cli/commands/test_command_deprecations.py new file mode 100644 index 0000000000000..b4eb6840c9069 --- /dev/null +++ b/airflow-core/tests/unit/cli/commands/test_command_deprecations.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Single source of truth for the ``airflow`` CLI commands deprecated in favour of ``airflowctl``. + +Every command decorated with ``deprecated_for_airflowctl`` must have one entry below. When a new +command is migrated and deprecated, add a row to ``DEPRECATED_CLI_COMMANDS`` -- the test then +verifies it emits ``RemovedInAirflow4Warning`` pointing at the right ``airflowctl`` command. +""" + +from __future__ import annotations + +import contextlib +import re + +import pytest + +from airflow.cli.commands import asset_command, dag_command, pool_command +from airflow.exceptions import RemovedInAirflow4Warning + +# (command callable, argv to parse, expected airflowctl replacement named in the warning) +DEPRECATED_CLI_COMMANDS = [ + (dag_command.dag_trigger, ["dags", "trigger", "example_dag", "--run-id=x"], "airflowctl dags trigger"), + (dag_command.dag_delete, ["dags", "delete", "example_dag", "--yes"], "airflowctl dags delete"), + (pool_command.pool_list, ["pools", "list"], "airflowctl pools list"), + (pool_command.pool_get, ["pools", "get", "foo"], "airflowctl pools get"), + (pool_command.pool_set, ["pools", "set", "foo", "1", "desc"], "airflowctl pools create"), + (pool_command.pool_delete, ["pools", "delete", "foo"], "airflowctl pools delete"), + (pool_command.pool_import, ["pools", "import", "/nonexistent.json"], "airflowctl pools import"), + ( + pool_command.pool_export, + ["pools", "export", "/tmp/airflow_pools_export.json"], + "airflowctl pools export", + ), + ( + asset_command.asset_materialize, + ["assets", "materialize", "--name=foo"], + "airflowctl assets materialize", + ), +] + + +@pytest.mark.parametrize( + ("command", "argv", "replacement"), + DEPRECATED_CLI_COMMANDS, + ids=[argv[0] + "-" + argv[1] for _, argv, _ in DEPRECATED_CLI_COMMANDS], +) +def test_deprecated_cli_command_points_to_airflowctl(command, argv, replacement, parser, mock_cli_api_client): + """Each migrated command warns it will become an alias for its ``airflowctl`` counterpart. + + We only assert the deprecation warning fires (and names the right replacement); the command + body itself is exercised by the per-command test modules, so any error it raises against the + bare mocked client is irrelevant here and suppressed. + """ + with pytest.warns(RemovedInAirflow4Warning, match=re.escape(replacement)): + with contextlib.suppress(Exception, SystemExit): + command(parser.parse_args(argv)) diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py b/airflow-core/tests/unit/cli/commands/test_dag_command.py index 6f76034a4218d..2dafed8f15ee9 100644 --- a/airflow-core/tests/unit/cli/commands/test_dag_command.py +++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py @@ -25,13 +25,14 @@ from unittest import mock from unittest.mock import MagicMock +import httpx import msgspec import pendulum import pytest import time_machine -from sqlalchemy import func, select +from airflowctl.api.operations import ServerResponseError +from sqlalchemy import select -from airflow import settings from airflow._shared.timezones import timezone from airflow.cli import cli_parser from airflow.cli.commands import dag_command @@ -485,21 +486,19 @@ def test_cli_list_import_errors(self, get_test_dag, configure_testing_dag_bundle assert str(path_to_parse) in log_output assert "[0 100 * * *] is not acceptable, out of range" in log_output - def test_cli_list_dag_runs(self): - dag_command.dag_trigger( - self.parser.parse_args( - [ - "dags", - "trigger", - "example_bash_operator", - ] - ) - ) + def test_cli_list_dag_runs(self, dag_maker): + # Seed a run directly in the DB; ``dags trigger`` now goes through the API server + # (airflowctl client) and cannot be used as an in-process fixture here. + with dag_maker("test_list_dag_runs", start_date=DEFAULT_DATE, serialized=True): + EmptyOperator(task_id="t1") + dag_maker.create_dagrun(state=DagRunState.SUCCESS, logical_date=DEFAULT_DATE) + dag_maker.sync_dagbag_to_db() + args = self.parser.parse_args( [ "dags", "list-runs", - "example_bash_operator", + "test_list_dag_runs", "--no-backfill", "--start-date", DEFAULT_DATE.isoformat(), @@ -592,206 +591,6 @@ def test_pausing_already_paused_dag_do_not_error(self, stdout_capture): out = temp_stdout.splitlines()[-1] assert out == "No unpaused DAGs were found" - def test_trigger_dag(self): - dag_command.dag_trigger( - self.parser.parse_args( - [ - "dags", - "trigger", - "example_bash_operator", - "--run-id=test_trigger_dag", - '--conf={"foo": "bar"}', - ], - ), - ) - with create_session() as session: - dagrun = session.scalars(select(DagRun).where(DagRun.run_id == "test_trigger_dag")).one() - - assert dagrun, "DagRun not created" - assert dagrun.run_type == DagRunType.MANUAL - assert dagrun.conf == {"foo": "bar"} - - # logical_date is None as it's not provided - assert dagrun.logical_date is None - - # data_interval is None as logical_date is None - assert dagrun.data_interval_start is None - assert dagrun.data_interval_end is None - - def test_trigger_dag_empty_object_conf(self): - dag_command.dag_trigger( - self.parser.parse_args( - [ - "dags", - "trigger", - "example_bash_operator", - "--run-id=test_trigger_dag_empty_object_conf", - "--conf={}", - ], - ), - ) - with create_session() as session: - dagrun = session.scalars( - select(DagRun).where(DagRun.run_id == "test_trigger_dag_empty_object_conf") - ).one() - - assert dagrun.conf == {} - - def test_trigger_dag_json_null_conf(self): - dag_command.dag_trigger( - self.parser.parse_args( - [ - "dags", - "trigger", - "example_bash_operator", - "--run-id=test_trigger_dag_json_null_conf", - "--conf=null", - ], - ), - ) - with create_session() as session: - dagrun = session.scalars( - select(DagRun).where(DagRun.run_id == "test_trigger_dag_json_null_conf") - ).one() - - assert dagrun.conf == {} - - def test_trigger_dag_with_microseconds(self): - dag_command.dag_trigger( - self.parser.parse_args( - [ - "dags", - "trigger", - "example_bash_operator", - "--run-id=test_trigger_dag_with_micro", - "--logical-date=2021-06-04T09:00:00.000001+08:00", - "--no-replace-microseconds", - ], - ) - ) - - with create_session() as session: - dagrun = session.scalars( - select(DagRun).where(DagRun.run_id == "test_trigger_dag_with_micro") - ).one() - - assert dagrun, "DagRun not created" - assert dagrun.run_type == DagRunType.MANUAL - assert dagrun.logical_date.isoformat(timespec="microseconds") == "2021-06-04T01:00:00.000001+00:00" - - @pytest.mark.parametrize("conf", ["NOT JSON", ""]) - def test_trigger_dag_invalid_conf(self, conf): - with pytest.raises(ValueError, match=r"Expecting value: line \d+ column \d+ \(char \d+\)"): - dag_command.dag_trigger( - self.parser.parse_args( - [ - "dags", - "trigger", - "example_bash_operator", - "--run-id", - "trigger_dag_xxx", - "--conf", - conf, - ] - ), - ) - - @pytest.mark.parametrize("conf", ["[]", '"str"', "1", "false"]) - def test_trigger_dag_rejects_non_object_conf(self, conf): - with pytest.raises(ValueError, match="DagRun conf must be a JSON object or null"): - dag_command.dag_trigger( - self.parser.parse_args( - [ - "dags", - "trigger", - "example_bash_operator", - "--run-id", - "trigger_dag_xxx", - "--conf", - conf, - ] - ), - ) - - def test_trigger_dag_output_as_json(self, stdout_capture): - args = self.parser.parse_args( - [ - "dags", - "trigger", - "example_bash_operator", - "--run-id", - "trigger_dag_xxx", - "--conf", - '{"conf1": "val1", "conf2": "val2"}', - "--output=json", - ] - ) - with stdout_capture as temp_stdout: - dag_command.dag_trigger(args) - # get the last line from the logs ignoring all logging lines - out = temp_stdout.getvalue().strip().splitlines()[-1] - parsed_out = json.loads(out) - - assert len(parsed_out) == 1 - assert parsed_out[0]["dag_id"] == "example_bash_operator" - assert parsed_out[0]["dag_run_id"] == "trigger_dag_xxx" - assert parsed_out[0]["conf"] == {"conf1": "val1", "conf2": "val2"} - - def test_delete_dag(self): - DM = DagModel - key = "my_dag_id" - session = settings.Session() - session.add(DM(dag_id=key, bundle_name="dags-folder")) - session.commit() - dag_command.dag_delete(self.parser.parse_args(["dags", "delete", key, "--yes"])) - assert session.scalar(select(func.count()).select_from(DM).where(DM.dag_id == key)) == 0 - with pytest.raises(AirflowException): - dag_command.dag_delete( - self.parser.parse_args(["dags", "delete", "does_not_exist_dag", "--yes"]), - ) - - def test_dag_delete_when_backfill_and_dagrun_exist(self): - # Test to check that the DAG should be deleted even if - # there are backfill records associated with it. - from airflow.models.backfill import Backfill - - DM = DagModel - key = "my_dag_id" - session = settings.Session() - session.add(DM(dag_id=key, bundle_name="dags-folder")) - _backfill = Backfill(dag_id=key, from_date=DEFAULT_DATE, to_date=DEFAULT_DATE + timedelta(days=1)) - session.add(_backfill) - # To create the backfill_id in DagRun - session.flush() - session.add( - DagRun( - dag_id=key, - run_id="backfill__" + key, - state=DagRunState.SUCCESS, - run_type="backfill", - backfill_id=_backfill.id, - ) - ) - session.commit() - dag_command.dag_delete(self.parser.parse_args(["dags", "delete", key, "--yes"])) - assert session.scalar(select(func.count()).select_from(DM).where(DM.dag_id == key)) == 0 - with pytest.raises(AirflowException): - dag_command.dag_delete( - self.parser.parse_args(["dags", "delete", "does_not_exist_dag", "--yes"]), - ) - - def test_delete_dag_existing_file(self, tmp_path): - # Test to check that the DAG should be deleted even if - # the file containing it is not deleted - path = tmp_path / "testfile" - DM = DagModel - key = "my_dag_id" - session = settings.Session() - session.add(DM(dag_id=key, bundle_name="dags-folder", fileloc=os.fspath(path))) - session.commit() - dag_command.dag_delete(self.parser.parse_args(["dags", "delete", key, "--yes"])) - assert session.scalar(select(func.count()).select_from(DM).where(DM.dag_id == key)) == 0 - def test_cli_list_jobs(self): args = self.parser.parse_args(["dags", "list-jobs"]) dag_command.dag_list_jobs(args) @@ -1991,3 +1790,142 @@ def test_is_backfillable(self, schedule, allowed_run_types, expected): ) dag_details = dag_command._get_dagbag_dag_details(dag) assert dag_details["is_backfillable"] is expected + + +def _server_error(status_code: int) -> ServerResponseError: + request = httpx.Request("DELETE", "http://testserver/api/v2/dags/foo") + response = httpx.Response(status_code, request=request, json={"detail": "boom"}) + return ServerResponseError(message="boom", request=request, response=response) + + +@pytest.mark.non_db_test_override +class TestCliDagsApiClientCommands: + """Dag CLI commands that talk to the API server through the airflowctl client. + + These are unit tests: the airflowctl client is mocked so no API server (or + database) is required. + """ + + @classmethod + def setup_class(cls): + cls.parser = cli_parser.get_parser() + + @pytest.fixture(autouse=True) + def _default_trigger_response(self, mock_cli_api_client): + """Give the mocked ``dags.trigger`` a dict response so ``print_as`` can render it.""" + mock_cli_api_client.dags.trigger.return_value.model_dump.return_value = { + "dag_id": "example_bash_operator", + "dag_run_id": "test_run", + } + + def test_trigger_dag(self, mock_cli_api_client): + dag_command.dag_trigger( + self.parser.parse_args( + [ + "dags", + "trigger", + "example_bash_operator", + "--run-id=test_trigger_dag", + '--conf={"foo": "bar"}', + ] + ), + ) + mock_cli_api_client.dags.trigger.assert_called_once() + call = mock_cli_api_client.dags.trigger.call_args + assert call.kwargs["dag_id"] == "example_bash_operator" + body = call.kwargs["trigger_dag_run"] + assert body.dag_run_id == "test_trigger_dag" + assert body.conf == {"foo": "bar"} + # logical_date is None as it's not provided + assert body.logical_date is None + + def test_trigger_dag_empty_object_conf(self, mock_cli_api_client): + dag_command.dag_trigger( + self.parser.parse_args( + ["dags", "trigger", "example_bash_operator", "--run-id=empty_conf", "--conf={}"] + ), + ) + body = mock_cli_api_client.dags.trigger.call_args.kwargs["trigger_dag_run"] + assert body.conf == {} + + def test_trigger_dag_json_null_conf(self, mock_cli_api_client): + dag_command.dag_trigger( + self.parser.parse_args( + ["dags", "trigger", "example_bash_operator", "--run-id=null_conf", "--conf=null"] + ), + ) + # ``null`` conf resolves to None on the client; the API server coerces it to {}. + body = mock_cli_api_client.dags.trigger.call_args.kwargs["trigger_dag_run"] + assert body.conf is None + + def test_trigger_dag_with_microseconds(self, mock_cli_api_client): + dag_command.dag_trigger( + self.parser.parse_args( + [ + "dags", + "trigger", + "example_bash_operator", + "--run-id=micro", + "--logical-date=2021-06-04T09:00:00.000001+08:00", + ] + ) + ) + body = mock_cli_api_client.dags.trigger.call_args.kwargs["trigger_dag_run"] + assert body.logical_date.isoformat(timespec="microseconds") == "2021-06-04T09:00:00.000001+08:00" + + @pytest.mark.parametrize("conf", ["NOT JSON", ""]) + def test_trigger_dag_invalid_conf(self, mock_cli_api_client, conf): + with pytest.raises(ValueError, match=r"Expecting value: line \d+ column \d+ \(char \d+\)"): + dag_command.dag_trigger( + self.parser.parse_args( + ["dags", "trigger", "example_bash_operator", "--run-id", "xxx", "--conf", conf] + ), + ) + mock_cli_api_client.dags.trigger.assert_not_called() + + @pytest.mark.parametrize("conf", ["[]", '"str"', "1", "false"]) + def test_trigger_dag_rejects_non_object_conf(self, mock_cli_api_client, conf): + with pytest.raises(ValueError, match="DagRun conf must be a JSON object or null"): + dag_command.dag_trigger( + self.parser.parse_args( + ["dags", "trigger", "example_bash_operator", "--run-id", "xxx", "--conf", conf] + ), + ) + mock_cli_api_client.dags.trigger.assert_not_called() + + def test_trigger_dag_output_as_json(self, mock_cli_api_client, stdout_capture): + mock_cli_api_client.dags.trigger.return_value.model_dump.return_value = { + "dag_id": "example_bash_operator", + "dag_run_id": "trigger_dag_xxx", + "conf": {"conf1": "val1", "conf2": "val2"}, + } + args = self.parser.parse_args( + [ + "dags", + "trigger", + "example_bash_operator", + "--run-id", + "trigger_dag_xxx", + "--conf", + '{"conf1": "val1", "conf2": "val2"}', + "--output=json", + ] + ) + with stdout_capture as temp_stdout: + dag_command.dag_trigger(args) + out = temp_stdout.getvalue().strip().splitlines()[-1] + parsed_out = json.loads(out) + + assert len(parsed_out) == 1 + assert parsed_out[0]["dag_id"] == "example_bash_operator" + assert parsed_out[0]["dag_run_id"] == "trigger_dag_xxx" + assert parsed_out[0]["conf"] == {"conf1": "val1", "conf2": "val2"} + + def test_delete_dag(self, mock_cli_api_client): + dag_command.dag_delete(self.parser.parse_args(["dags", "delete", "my_dag_id", "--yes"])) + mock_cli_api_client.dags.delete.assert_called_once_with(dag_id="my_dag_id") + + def test_delete_dag_missing(self, mock_cli_api_client): + mock_cli_api_client.dags.delete.side_effect = _server_error(404) + with pytest.raises(ServerResponseError): + dag_command.dag_delete(self.parser.parse_args(["dags", "delete", "does_not_exist_dag", "--yes"])) diff --git a/airflow-core/tests/unit/cli/commands/test_pool_command.py b/airflow-core/tests/unit/cli/commands/test_pool_command.py index ad6951567fbf6..28e2d761812e3 100644 --- a/airflow-core/tests/unit/cli/commands/test_pool_command.py +++ b/airflow-core/tests/unit/cli/commands/test_pool_command.py @@ -18,281 +18,235 @@ from __future__ import annotations import json +from types import SimpleNamespace +import httpx import pytest -from sqlalchemy import delete, func, select +from airflowctl.api.operations import ServerResponseError -from airflow import models, settings from airflow.cli import cli_parser from airflow.cli.commands import pool_command -from airflow.models import Pool -from airflow.settings import Session -from airflow.utils.db import add_default_pool_if_not_exists -pytestmark = pytest.mark.db_test +from tests_common.test_utils.config import conf_vars + + +def _pool(name, slots, description="", include_deferred=False, team_name=None): + """Build a stand-in for the airflowctl ``PoolResponse`` returned by the client.""" + return SimpleNamespace( + name=name, + slots=slots, + description=description, + include_deferred=include_deferred, + team_name=team_name, + ) + + +def _server_error(status_code: int) -> ServerResponseError: + request = httpx.Request("GET", "http://testserver/api/v2/pools/foo") + response = httpx.Response(status_code, request=request, json={"detail": "boom"}) + return ServerResponseError(message="boom", request=request, response=response) class TestCliPools: @classmethod def setup_class(cls): - cls.dagbag = models.DagBag() cls.parser = cli_parser.get_parser() - settings.configure_orm() - cls.session = Session - cls._cleanup() - - def tearDown(self): - self._cleanup() - - @staticmethod - def _cleanup(session=None): - if session is None: - session = Session() - session.execute(delete(Pool).where(Pool.pool != Pool.DEFAULT_POOL_NAME)) - session.commit() - add_default_pool_if_not_exists() - session.close() - - def test_pool_list(self, stdout_capture): - pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) + + def test_pool_list(self, mock_cli_api_client, stdout_capture): + mock_cli_api_client.pools.list.return_value.pools = [_pool("foo", 1, "test")] with stdout_capture as stdout: pool_command.pool_list(self.parser.parse_args(["pools", "list"])) assert "foo" in stdout.getvalue() + mock_cli_api_client.pools.list.assert_called_once() - def test_pool_list_with_args(self): + def test_pool_list_with_args(self, mock_cli_api_client): + mock_cli_api_client.pools.list.return_value.pools = [_pool("foo", 1, "test")] pool_command.pool_list(self.parser.parse_args(["pools", "list", "--output", "json"])) - def test_pool_create(self): + def test_pool_create(self, mock_cli_api_client): pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) - assert self.session.scalar(select(func.count()).select_from(Pool)) == 2 - def test_pool_update_deferred(self): - pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) - assert self.session.scalar(select(Pool).where(Pool.pool == "foo")).include_deferred is False + mock_cli_api_client.pools.create.assert_called_once() + body = mock_cli_api_client.pools.create.call_args.kwargs["pool"] + # core_api PoolBody exposes the name via the ``pool`` attribute (alias ``name``). + assert body.pool == "foo" + assert body.slots == 1 + assert body.description == "test" + assert body.include_deferred is False + def test_pool_create_include_deferred(self, mock_cli_api_client): pool_command.pool_set( self.parser.parse_args(["pools", "set", "foo", "1", "test", "--include-deferred"]) ) - assert self.session.scalar(select(Pool).where(Pool.pool == "foo")).include_deferred is True - pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) - assert self.session.scalar(select(Pool).where(Pool.pool == "foo")).include_deferred is False + body = mock_cli_api_client.pools.create.call_args.kwargs["pool"] + assert body.include_deferred is True - def test_pool_get(self): - pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) - pool_command.pool_get(self.parser.parse_args(["pools", "get", "foo"])) + def test_pool_get(self, mock_cli_api_client, stdout_capture): + mock_cli_api_client.pools.get.return_value = _pool("foo", 1, "test") + with stdout_capture as stdout: + pool_command.pool_get(self.parser.parse_args(["pools", "get", "foo"])) - def test_pool_delete(self): - pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) + assert "foo" in stdout.getvalue() + mock_cli_api_client.pools.get.assert_called_once_with(pool_name="foo") + + def test_pool_get_missing(self, mock_cli_api_client): + mock_cli_api_client.pools.get.side_effect = _server_error(404) + with pytest.raises(SystemExit, match="Pool foo does not exist"): + pool_command.pool_get(self.parser.parse_args(["pools", "get", "foo"])) + + def test_pool_get_other_error_reraised(self, mock_cli_api_client): + mock_cli_api_client.pools.get.side_effect = _server_error(500) + with pytest.raises(ServerResponseError): + pool_command.pool_get(self.parser.parse_args(["pools", "get", "foo"])) + + def test_pool_delete(self, mock_cli_api_client): pool_command.pool_delete(self.parser.parse_args(["pools", "delete", "foo"])) - assert self.session.scalar(select(func.count()).select_from(Pool)) == 1 + mock_cli_api_client.pools.delete.assert_called_once_with(pool="foo") + + def test_pool_delete_missing(self, mock_cli_api_client): + mock_cli_api_client.pools.delete.side_effect = _server_error(404) + with pytest.raises(SystemExit, match="Pool foo does not exist"): + pool_command.pool_delete(self.parser.parse_args(["pools", "delete", "foo"])) - def test_pool_import_nonexistent(self): + def test_pool_import_nonexistent(self, mock_cli_api_client): with pytest.raises(SystemExit): pool_command.pool_import(self.parser.parse_args(["pools", "import", "nonexistent.json"])) - def test_pool_import_invalid_json(self, tmp_path): + def test_pool_import_invalid_json(self, mock_cli_api_client, tmp_path): invalid_pool_import_file_path = tmp_path / "pools_import_invalid.json" - with open(invalid_pool_import_file_path, mode="w") as file: - file.write("not valid json") + invalid_pool_import_file_path.write_text("not valid json") with pytest.raises(SystemExit): pool_command.pool_import( self.parser.parse_args(["pools", "import", str(invalid_pool_import_file_path)]) ) - def test_pool_import_invalid_pools(self, tmp_path): + def test_pool_import_invalid_pools(self, mock_cli_api_client, tmp_path): invalid_pool_import_file_path = tmp_path / "pools_import_invalid.json" + # Missing ``slots`` makes the entry invalid. pool_config_input = {"foo": {"description": "foo_test", "include_deferred": False}} - with open(invalid_pool_import_file_path, mode="w") as file: - json.dump(pool_config_input, file) + invalid_pool_import_file_path.write_text(json.dumps(pool_config_input)) with pytest.raises(SystemExit): pool_command.pool_import( self.parser.parse_args(["pools", "import", str(invalid_pool_import_file_path)]) ) - def test_pool_import_backwards_compatibility(self, tmp_path): + def test_pool_import(self, mock_cli_api_client, tmp_path): pool_import_file_path = tmp_path / "pools_import.json" pool_config_input = { - # JSON before version 2.7.0 does not contain `include_deferred` - "foo": {"description": "foo_test", "slots": 1}, + "foo": {"description": "foo_test", "slots": 1, "include_deferred": True}, + # JSON before version 2.7.0 does not contain ``include_deferred``. + "bar": {"description": "bar_test", "slots": 2}, } - with open(pool_import_file_path, mode="w") as file: - json.dump(pool_config_input, file) + pool_import_file_path.write_text(json.dumps(pool_config_input)) pool_command.pool_import(self.parser.parse_args(["pools", "import", str(pool_import_file_path)])) - assert self.session.scalar(select(Pool).where(Pool.pool == "foo")).include_deferred is False - - def test_pool_import_export(self, tmp_path): - pool_import_file_path = tmp_path / "pools_import.json" - pool_export_file_path = tmp_path / "pools_export.json" - pool_config_input = { - "foo": {"description": "foo_test", "slots": 1, "include_deferred": True}, - "default_pool": { - "description": "Default pool", - "slots": 128, - "include_deferred": False, - }, - "baz": {"description": "baz_test", "slots": 2, "include_deferred": False}, + assert mock_cli_api_client.pools.create.call_count == 2 + bodies = { + call.kwargs["pool"].pool: call.kwargs["pool"] + for call in mock_cli_api_client.pools.create.call_args_list } - with open(pool_import_file_path, mode="w") as file: - json.dump(pool_config_input, file) + assert bodies["foo"].include_deferred is True + # Missing ``include_deferred`` defaults to False (backwards compatibility). + assert bodies["bar"].include_deferred is False - # Import json - pool_command.pool_import(self.parser.parse_args(["pools", "import", str(pool_import_file_path)])) + def test_pool_export(self, mock_cli_api_client, tmp_path): + pool_export_file_path = tmp_path / "pools_export.json" + mock_cli_api_client.pools.list.return_value.pools = [ + _pool("foo", 1, "foo_test", include_deferred=True), + _pool("baz", 2, "baz_test", include_deferred=False), + ] - # Export json pool_command.pool_export(self.parser.parse_args(["pools", "export", str(pool_export_file_path)])) - with open(pool_export_file_path) as file: - pool_config_output = json.load(file) - assert pool_config_input == pool_config_output, "Input and output pool files are not same" - - def test_pool_set_with_team_name(self): - """Test that pool_set with --team-name assigns the pool to the team when multi_team is enabled.""" - from airflow.models.team import Team + exported = json.loads(pool_export_file_path.read_text()) + assert exported == { + "foo": {"slots": 1, "description": "foo_test", "include_deferred": True}, + "baz": {"slots": 2, "description": "baz_test", "include_deferred": False}, + } - from tests_common.test_utils.config import conf_vars + def test_pool_set_with_team_name(self, mock_cli_api_client): + """``--team-name`` is forwarded to the airflowctl client when multi_team is enabled.""" + with conf_vars({("core", "multi_team"): "True"}): + pool_command.pool_set( + self.parser.parse_args( + ["pools", "set", "team_pool", "5", "team pool", "--team-name", "test_team"] + ) + ) - # Create the team first - team = Team(name="test_team") - self.session.add(team) - self.session.commit() + body = mock_cli_api_client.pools.create.call_args.kwargs["pool"] + assert body.team_name == "test_team" - try: - with conf_vars({("core", "multi_team"): "True"}): + def test_pool_set_team_name_rejected_when_multi_team_disabled(self, mock_cli_api_client): + """``PoolBody`` rejects a team_name (client-side) when multi_team is disabled.""" + with conf_vars({("core", "multi_team"): "False"}): + with pytest.raises(ValueError, match="team_name cannot be set when multi_team mode is disabled"): pool_command.pool_set( self.parser.parse_args( ["pools", "set", "team_pool", "5", "team pool", "--team-name", "test_team"] ) ) + mock_cli_api_client.pools.create.assert_not_called() - pool = self.session.scalar(select(Pool).where(Pool.pool == "team_pool")) - assert pool is not None - assert pool.team_name == "test_team" - assert pool.slots == 5 - finally: - self.session.execute(delete(Pool).where(Pool.pool == "team_pool")) - self.session.execute(delete(Team).where(Team.name == "test_team")) - self.session.commit() - - def test_pool_set_team_name_rejected_when_multi_team_disabled(self): - """Test that pool_set with --team-name raises when multi_team is disabled.""" - from airflow.models.team import Team - - from tests_common.test_utils.config import conf_vars - - team = Team(name="test_team") - self.session.add(team) - self.session.commit() - - try: - with conf_vars({("core", "multi_team"): "False"}): - with pytest.raises( - ValueError, match="team_name cannot be set when multi_team mode is disabled" - ): - pool_command.pool_set( - self.parser.parse_args( - ["pools", "set", "team_pool", "5", "team pool", "--team-name", "test_team"] - ) - ) - finally: - self.session.execute(delete(Pool).where(Pool.pool == "team_pool")) - self.session.execute(delete(Team).where(Team.name == "test_team")) - self.session.commit() - - def test_pool_set_without_team_name(self): - """Test that pool_set without --team-name leaves team_name as None.""" + def test_pool_set_without_team_name(self, mock_cli_api_client): + """Without ``--team-name`` the forwarded body has ``team_name`` as None.""" pool_command.pool_set(self.parser.parse_args(["pools", "set", "no_team_pool", "3", "no team"])) - pool = self.session.scalar(select(Pool).where(Pool.pool == "no_team_pool")) - assert pool is not None - assert pool.team_name is None - - def test_pool_import_export_with_team_name(self, tmp_path): - """Test that import/export round-trips the team_name field.""" - from airflow.models.team import Team - - from tests_common.test_utils.config import conf_vars - - team = Team(name="import_team") - self.session.add(team) - self.session.commit() + body = mock_cli_api_client.pools.create.call_args.kwargs["pool"] + assert body.team_name is None + def test_pool_import_forwards_team_name(self, mock_cli_api_client, tmp_path): + """Import forwards each pool's ``team_name`` (or None) to the airflowctl client.""" pool_import_file_path = tmp_path / "pools_import_team.json" - pool_export_file_path = tmp_path / "pools_export_team.json" - pool_config_input = { - "team_pool_a": { - "slots": 10, - "description": "team pool", - "include_deferred": False, - "team_name": "import_team", - }, - "global_pool": { - "slots": 5, - "description": "global pool", - "include_deferred": False, - }, - } - - with open(pool_import_file_path, mode="w") as file: - json.dump(pool_config_input, file) - - try: - with conf_vars({("core", "multi_team"): "True"}): - pool_command.pool_import( - self.parser.parse_args(["pools", "import", str(pool_import_file_path)]) - ) - - # Verify team assignment - pool = self.session.scalar(select(Pool).where(Pool.pool == "team_pool_a")) - assert pool is not None - assert pool.team_name == "import_team" - - global_pool = self.session.scalar(select(Pool).where(Pool.pool == "global_pool")) - assert global_pool is not None - assert global_pool.team_name is None + pool_import_file_path.write_text( + json.dumps( + { + "team_pool_a": { + "slots": 10, + "description": "team pool", + "include_deferred": False, + "team_name": "import_team", + }, + "global_pool": {"slots": 5, "description": "global pool", "include_deferred": False}, + } + ) + ) - # Export and verify - pool_command.pool_export(self.parser.parse_args(["pools", "export", str(pool_export_file_path)])) + with conf_vars({("core", "multi_team"): "True"}): + pool_command.pool_import(self.parser.parse_args(["pools", "import", str(pool_import_file_path)])) - with open(pool_export_file_path) as file: - pool_config_output = json.load(file) + bodies = { + call.kwargs["pool"].pool: call.kwargs["pool"] + for call in mock_cli_api_client.pools.create.call_args_list + } + assert bodies["team_pool_a"].team_name == "import_team" + assert bodies["global_pool"].team_name is None - assert pool_config_output["team_pool_a"]["team_name"] == "import_team" - assert "team_name" not in pool_config_output["global_pool"] - finally: - self.session.execute(delete(Pool).where(Pool.pool.in_(["team_pool_a", "global_pool"]))) - self.session.execute(delete(Team).where(Team.name == "import_team")) - self.session.commit() + def test_pool_export_includes_team_name(self, mock_cli_api_client, tmp_path): + """Export writes ``team_name`` only for pools that have one.""" + pool_export_file_path = tmp_path / "pools_export_team.json" + mock_cli_api_client.pools.list.return_value.pools = [ + _pool("team_pool_a", 10, "team pool", team_name="import_team"), + _pool("global_pool", 5, "global pool"), + ] - def test_pool_list_shows_team_name(self, stdout_capture): - """Test that pool list output includes the team_name column.""" - from airflow.models.team import Team + pool_command.pool_export(self.parser.parse_args(["pools", "export", str(pool_export_file_path)])) - from tests_common.test_utils.config import conf_vars + exported = json.loads(pool_export_file_path.read_text()) + assert exported["team_pool_a"]["team_name"] == "import_team" + assert "team_name" not in exported["global_pool"] - team = Team(name="list_team") - self.session.add(team) - self.session.commit() + def test_pool_list_shows_team_name(self, mock_cli_api_client, stdout_capture): + """Pool list output includes the team_name column.""" + mock_cli_api_client.pools.list.return_value.pools = [ + _pool("list_pool", 5, "desc", team_name="list_team") + ] - try: - with conf_vars({("core", "multi_team"): "True"}): - pool_command.pool_set( - self.parser.parse_args( - ["pools", "set", "list_pool", "5", "desc", "--team-name", "list_team"] - ) - ) - - with stdout_capture as stdout: - pool_command.pool_list(self.parser.parse_args(["pools", "list"])) + with stdout_capture as stdout: + pool_command.pool_list(self.parser.parse_args(["pools", "list"])) - output = stdout.getvalue() - assert "list_team" in output - finally: - self.session.execute(delete(Pool).where(Pool.pool == "list_pool")) - self.session.execute(delete(Team).where(Team.name == "list_team")) - self.session.commit() + assert "list_team" in stdout.getvalue() diff --git a/airflow-core/tests/unit/cli/conftest.py b/airflow-core/tests/unit/cli/conftest.py index 7676a103b5363..d9d2ae341eb51 100644 --- a/airflow-core/tests/unit/cli/conftest.py +++ b/airflow-core/tests/unit/cli/conftest.py @@ -18,6 +18,7 @@ from __future__ import annotations import sys +from unittest import mock import pytest @@ -68,6 +69,25 @@ def parser(): # log messages +@pytest.fixture +def mock_cli_api_client(): + """Mock the CLI airflowctl client and neutralize ``action_cli``'s DB touch points. + + CLI commands that go through the airflowctl client only need the mocked client; the + ``@action_cli`` audit logging and log-template sync would otherwise open a database + session. Patching them lets these command tests run without a database or API server. + """ + client = mock.MagicMock() + with ( + mock.patch("airflow.cli.api_client.get_cli_api_client", return_value=client), + mock.patch("airflow.utils.cli_action_loggers.on_pre_execution"), + mock.patch("airflow.utils.cli_action_loggers.on_post_execution"), + mock.patch("airflow.utils.db.synchronize_log_template"), + mock.patch("airflow.utils.db.check_and_run_migrations"), + ): + yield client + + @pytest.fixture def stdout_capture(request): """Fixture that captures stdout only.""" diff --git a/airflow-core/tests/unit/cli/test_api_client.py b/airflow-core/tests/unit/cli/test_api_client.py new file mode 100644 index 0000000000000..3ef813cebda43 --- /dev/null +++ b/airflow-core/tests/unit/cli/test_api_client.py @@ -0,0 +1,140 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager +from airflow.api_fastapi.auth.managers.simple.simple_auth_manager import SimpleAuthManager +from airflow.cli import api_client as cli_api_client + +from tests_common.test_utils.config import conf_vars + + +@pytest.fixture(autouse=True) +def _reset_singleton(): + """Reset the process-wide client singleton around each test.""" + cli_api_client._api_client = None + yield + cli_api_client._api_client = None + + +class TestResolveBaseUrl: + @conf_vars({("api", "base_url"): "https://airflow.example.com:9999"}) + def test_explicit_base_url(self): + assert cli_api_client._resolve_base_url() == "https://airflow.example.com:9999" + + @conf_vars({("api", "base_url"): "", ("api", "host"): "myhost", ("api", "port"): "1234"}) + def test_host_port_fallback(self): + assert cli_api_client._resolve_base_url() == "http://myhost:1234" + + +class TestMintCliToken: + def test_uses_env_token(self, monkeypatch): + monkeypatch.setenv("AIRFLOW_CLI_TOKEN", "tok-123") + with mock.patch("airflow.api_fastapi.app.get_auth_manager") as get_auth_manager: + assert cli_api_client._mint_cli_token() == "tok-123" + # The auth manager is never consulted when a token is supplied explicitly. + get_auth_manager.assert_not_called() + + def test_mints_via_auth_manager(self, monkeypatch): + monkeypatch.delenv("AIRFLOW_CLI_TOKEN", raising=False) + auth_manager = mock.MagicMock() + auth_manager.get_cli_user.return_value = "cli-user" + auth_manager.generate_jwt.return_value = "signed-jwt" + with mock.patch("airflow.api_fastapi.app.get_auth_manager", return_value=auth_manager): + assert cli_api_client._mint_cli_token() == "signed-jwt" + + auth_manager.generate_jwt.assert_called_once() + assert auth_manager.generate_jwt.call_args.args[0] == "cli-user" + # Token must be short-lived. + assert auth_manager.generate_jwt.call_args.kwargs["expiration_time_in_seconds"] > 0 + + def test_initializes_auth_manager_when_not_initialized(self, monkeypatch): + # In the CLI the auth manager singleton is usually not initialized yet, so + # ``get_auth_manager`` raises and we must initialize it on demand. + monkeypatch.delenv("AIRFLOW_CLI_TOKEN", raising=False) + auth_manager = mock.MagicMock() + auth_manager.get_cli_user.return_value = "cli-user" + auth_manager.generate_jwt.return_value = "signed-jwt" + with ( + mock.patch( + "airflow.api_fastapi.app.get_auth_manager", + side_effect=RuntimeError("Auth Manager has not been initialized yet."), + ), + mock.patch( + "airflow.api_fastapi.app.init_auth_manager", return_value=auth_manager + ) as init_auth_manager, + ): + assert cli_api_client._mint_cli_token() == "signed-jwt" + + init_auth_manager.assert_called_once() + auth_manager.generate_jwt.assert_called_once() + + +class TestGetCliApiClient: + def test_builds_singleton(self): + with ( + mock.patch.object(cli_api_client, "_resolve_base_url", return_value="http://h:8080"), + mock.patch.object(cli_api_client, "_mint_cli_token", return_value="tok"), + mock.patch.object(cli_api_client, "Client") as client_cls, + ): + first = cli_api_client.get_cli_api_client() + second = cli_api_client.get_cli_api_client() + + assert first is second + client_cls.assert_called_once() + kwargs = client_cls.call_args.kwargs + assert kwargs["base_url"] == "http://h:8080" + assert kwargs["token"] == "tok" + assert kwargs["kind"] == cli_api_client.ClientKind.CLI + + +class TestProvideApiClient: + def test_injects_when_missing(self): + with mock.patch.object(cli_api_client, "get_cli_api_client", return_value="CLIENT"): + + @cli_api_client.provide_api_client + def command(args, api_client=None): + return api_client + + assert command("args") == "CLIENT" + + def test_uses_explicit_client(self): + with mock.patch.object(cli_api_client, "get_cli_api_client") as get_client: + + @cli_api_client.provide_api_client + def command(args, api_client=None): + return api_client + + assert command("args", api_client="EXPLICIT") == "EXPLICIT" + get_client.assert_not_called() + + +class TestGetCliUser: + def test_base_default_raises(self): + # The generic auth manager cannot know which user is authorized for the CLI. + with pytest.raises(NotImplementedError, match="AIRFLOW_CLI_TOKEN"): + BaseAuthManager.get_cli_user(mock.Mock()) + + def test_simple_auth_manager_returns_admin(self): + user = SimpleAuthManager.get_cli_user(mock.Mock()) + assert user.get_id() == "cli" + assert user.role == "ADMIN" diff --git a/airflow-core/tests/unit/cli/test_utils.py b/airflow-core/tests/unit/cli/test_utils.py new file mode 100644 index 0000000000000..4fb137ad48201 --- /dev/null +++ b/airflow-core/tests/unit/cli/test_utils.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.cli.utils import deprecated_for_airflowctl +from airflow.exceptions import RemovedInAirflow4Warning + + +class TestDeprecatedForAirflowctl: + def test_emits_warning_naming_replacement(self): + @deprecated_for_airflowctl("airflowctl dags trigger") + def command(args): + return "result" + + with pytest.warns(RemovedInAirflow4Warning, match="airflowctl dags trigger"): + result = command(args=None) + + # The wrapped command still runs and returns its value. + assert result == "result" + + def test_passes_through_args_and_preserves_metadata(self): + @deprecated_for_airflowctl("airflowctl pools create") + def command(a, b, *, c): + """Original docstring.""" + return (a, b, c) + + with pytest.warns(RemovedInAirflow4Warning): + assert command(1, 2, c=3) == (1, 2, 3) + + assert command.__name__ == "command" + assert command.__doc__ == "Original docstring." diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index 20b3be74846c0..e02a1a4cf81ad 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -300,6 +300,30 @@ def _fetch_user() -> User: def serialize_user(self, user: User) -> dict[str, Any]: return {"sub": str(user.id)} + def get_cli_user(self) -> User: + """ + Return an existing ``Admin`` user for the local CLI to mint a token for. + + The Airflow CLI mints a short-lived, in-memory JWT for this user so it can talk to + the API server. FAB tokens reference a real database user, so we reuse an existing + ``Admin`` user rather than fabricating one. If none exists, the operator must + create one or provide a token via the ``AIRFLOW_CLI_TOKEN`` environment variable. + """ + from airflow.utils.session import create_session + + with create_session() as session: + user = session.scalars(select(User).join(User.roles).where(Role.name == "Admin").limit(1)).first() + if user is None: + raise AirflowConfigException( + "No user with the 'Admin' role exists in the FAB database. Create one " + "(e.g. `airflow fab create-user --role Admin ...`) or set the " + "AIRFLOW_CLI_TOKEN environment variable with a valid API token." + ) + # Detach so attributes stay accessible after the session closes (and is not + # expired on commit) while the CLI serializes the user to mint the token. + session.expunge(user) + return user + def is_logged_in(self) -> bool: """Return whether the user is logged in.""" user = self.get_user() diff --git a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py index ab3cb8ee26429..ba647b6dd8612 100644 --- a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py +++ b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py @@ -1298,3 +1298,29 @@ def test_session_remove_generic_error_does_not_propagate(self, mock_create_app): mock_session.remove.assert_called() mock_log.warning.assert_called() assert response is not None + + +class TestFabGetCliUser: + """``get_cli_user`` reuses an existing ``Admin`` user for the local CLI token.""" + + @mock.patch("airflow.utils.session.create_session") + def test_returns_admin_user(self, mock_create_session, auth_manager): + admin_user = MagicMock() + session = MagicMock() + session.scalars.return_value.first.return_value = admin_user + mock_create_session.return_value.__enter__.return_value = session + + result = auth_manager.get_cli_user() + + assert result is admin_user + # The user is detached so its attributes survive the session closing. + session.expunge.assert_called_once_with(admin_user) + + @mock.patch("airflow.utils.session.create_session") + def test_raises_when_no_admin_user(self, mock_create_session, auth_manager): + session = MagicMock() + session.scalars.return_value.first.return_value = None + mock_create_session.return_value.__enter__.return_value = session + + with pytest.raises(AirflowConfigException, match="Admin"): + auth_manager.get_cli_user() diff --git a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py index a8cd683ef46ba..bc9ebc305c306 100644 --- a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py +++ b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py @@ -34,7 +34,7 @@ from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager -from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowConfigException, AirflowProviderDeprecationWarning try: from airflow.api_fastapi.auth.managers.base_auth_manager import ExtendedResourceMethod @@ -141,6 +141,34 @@ def serialize_user(self, user: KeycloakAuthManagerUser) -> dict[str, Any]: "refresh_token": user.refresh_token, } + def get_cli_user(self) -> KeycloakAuthManagerUser: + """ + Return a service-account user for the local CLI to mint a token for. + + Keycloak tokens are issued by the external Keycloak server, so they cannot be + forged locally. The Keycloak client is already configured for Airflow to talk to + Keycloak, so we reuse it to obtain a service-account token through the + ``client_credentials`` flow. The service account's effective permissions are + governed by the Keycloak deployment. If the client credentials are not usable, the + operator must provide a token via the ``AIRFLOW_CLI_TOKEN`` environment variable. + """ + try: + tokens = self.get_keycloak_client().token(grant_type="client_credentials") + except Exception as e: + raise AirflowConfigException( + "Could not obtain a Keycloak service-account token for the CLI via the " + "client_credentials flow. Set the AIRFLOW_CLI_TOKEN environment variable " + f"with a valid API token instead. Original error: {e}" + ) from e + return KeycloakAuthManagerUser( + user_id="airflow-cli", + name="airflow-cli", + access_token=tokens["access_token"], + # No refresh token is issued for the client_credentials flow (RFC 6749 §4.4.3), + # which marks this as a service account in refresh_user/refresh_tokens. + refresh_token=tokens.get("refresh_token"), + ) + def get_url_login(self, **kwargs) -> str: base_url = conf.get("api", "base_url", fallback="/") return urljoin(base_url, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login") diff --git a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py index 9c8ed9dd5b610..e4b6a5c294de8 100644 --- a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py +++ b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py @@ -45,7 +45,7 @@ else: TeamDetails = None # type: ignore[assignment,misc] from airflow.api_fastapi.common.types import MenuItem -from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowConfigException, AirflowProviderDeprecationWarning try: from airflow.providers.common.compat.sdk import AirflowException @@ -121,6 +121,25 @@ def _clear_filter_cache(): class TestKeycloakAuthManager: + @patch.object(KeycloakAuthManager, "get_keycloak_client") + def test_get_cli_user(self, mock_get_keycloak_client, auth_manager): + # client_credentials (service account) flow returns an access token and no refresh token. + mock_get_keycloak_client.return_value.token.return_value = {"access_token": "svc-token"} + + user = auth_manager.get_cli_user() + + assert user.get_id() == "airflow-cli" + assert user.access_token == "svc-token" + assert user.refresh_token is None + mock_get_keycloak_client.return_value.token.assert_called_once_with(grant_type="client_credentials") + + @patch.object(KeycloakAuthManager, "get_keycloak_client") + def test_get_cli_user_raises_when_credentials_unusable(self, mock_get_keycloak_client, auth_manager): + mock_get_keycloak_client.return_value.token.side_effect = Exception("boom") + + with pytest.raises(AirflowConfigException, match="AIRFLOW_CLI_TOKEN"): + auth_manager.get_cli_user() + def test_deserialize_user(self, auth_manager): result = auth_manager.deserialize_user( { diff --git a/scripts/ci/prek/known_airflow_exceptions.txt b/scripts/ci/prek/known_airflow_exceptions.txt index 93aa6557c8332..04c6e9534f07f 100644 --- a/scripts/ci/prek/known_airflow_exceptions.txt +++ b/scripts/ci/prek/known_airflow_exceptions.txt @@ -1,7 +1,7 @@ airflow-core/src/airflow/api/common/delete_dag.py::1 airflow-core/src/airflow/api_fastapi/core_api/app.py::1 airflow-core/src/airflow/cli/cli_parser.py::1 -airflow-core/src/airflow/cli/commands/dag_command.py::3 +airflow-core/src/airflow/cli/commands/dag_command.py::1 airflow-core/src/airflow/cli/commands/db_command.py::1 airflow-core/src/airflow/config_templates/airflow_local_settings.py::1 airflow-core/src/airflow/dag_processing/dagbag.py::1