diff --git a/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py b/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py index 924cbde90d598..a00bbd9354d1b 100644 --- a/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py +++ b/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py @@ -96,7 +96,15 @@ def date_param(): "dags update example_bash_operator --no-is-paused", # Dag Run commands "dagrun list --dag-id example_bash_operator --state success --limit=1", - # XCom commands - need a Dag run with completed tasks + # Tasks commands + 'tasks list example_bash_operator "manual__{date_param}"', + 'tasks get example_bash_operator "manual__{date_param}" runme_0', + "tasks clear example_bash_operator --dry-run", + # runme_0 completes as "success" once the triggered run finishes, so updating it + # to "success" is rejected with 409 "already in success state". Use "failed" to + # exercise a real state transition (valid states: success, failed, skipped). + 'tasks update example_bash_operator "manual__{date_param}" runme_0 --new-state=failed', + # XCom commands - need a DAG run with completed tasks 'xcom add example_bash_operator "manual__{date_param}" runme_0 {xcom_key} \'{{"test": "value"}}\'', 'xcom get example_bash_operator "manual__{date_param}" runme_0 {xcom_key}', 'xcom list example_bash_operator "manual__{date_param}" runme_0', diff --git a/airflow-ctl/docs/images/command_hashes.txt b/airflow-ctl/docs/images/command_hashes.txt index 36eec17c2eb9f..4458f44858e55 100644 --- a/airflow-ctl/docs/images/command_hashes.txt +++ b/airflow-ctl/docs/images/command_hashes.txt @@ -1,4 +1,4 @@ -main:27a22c00dcf32e7a1a4f06672dc8e3c8 +main:164bc97843d5be583c0b48f7a34dc8c8 assets:6419e20452692f577c4c6f570b74be0c auth:d79e9c7d00c432bdbcbc2a86e2e32053 backfill:74c8737b0a62a86ed3605fa9e6165874 @@ -12,4 +12,5 @@ providers:34502fe09dc0b8b0a13e7e46efdffda6 variables:f8fc76d3d398b2780f4e97f7cd816646 version:31f4efdf8de0dbaaa4fac71ff7efecc3 plugins:4864fd8f356704bd2b3cd1aec3567e35 +tasks:7ab24cac521242b6b6012e2bcd317831 auth login:9fe2bb1dd5c602beea2eefb33a2b20a8 diff --git a/airflow-ctl/docs/images/output_main.svg b/airflow-ctl/docs/images/output_main.svg index f586877bce8eb..f087581896312 100644 --- a/airflow-ctl/docs/images/output_main.svg +++ b/airflow-ctl/docs/images/output_main.svg @@ -1,4 +1,4 @@ - + - - + + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + + - + - + - - Usage:airflowctl [-hGROUP_OR_COMMAND... - -Positional Arguments: -GROUP_OR_COMMAND - -    Groups -assetsPerform Assets operations -authManage authentication for CLI. Either pass token from -environment variable/parameter or pass username and -password. -backfillPerform Backfill operations -configPerform Config operations -connectionsPerform Connections operations -dagrunPerform DagRun operations -dagsPerform Dags operations -jobsPerform Jobs operations -pluginsPerform Plugins operations -poolsPerform Pools operations -providersPerform Providers operations -variablesPerform Variables operations -xcomPerform XCom operations - -    Commands: -versionShow version information - -Options: --h--helpshow this help message and exit + + Usage:airflowctl [-hGROUP_OR_COMMAND... + +Positional Arguments: +GROUP_OR_COMMAND + +    Groups +assetsPerform Assets operations +authManage authentication for CLI. Either pass token from +environment variable/parameter or pass username and +password. +backfillPerform Backfill operations +configPerform Config operations +connectionsPerform Connections operations +dagrunPerform DagRun operations +dagsPerform Dags operations +jobsPerform Jobs operations +pluginsPerform Plugins operations +poolsPerform Pools operations +providersPerform Providers operations +tasksPerform Tasks operations +variablesPerform Variables operations +xcomPerform XCom operations + +    Commands: +versionShow version information + +Options: +-h--helpshow this help message and exit diff --git a/airflow-ctl/docs/images/output_tasks.svg b/airflow-ctl/docs/images/output_tasks.svg new file mode 100644 index 0000000000000..b5760191492c2 --- /dev/null +++ b/airflow-ctl/docs/images/output_tasks.svg @@ -0,0 +1,117 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Usage:airflowctl tasks [-hCOMMAND... + +Perform Tasks operations + +Positional Arguments: +COMMAND +clearClear task instance state, optionally filtering by task IDs  +or run scope +getRetrieve a task instance by Dag ID, run ID, and task ID +listList all task instances for a given Dag run +updateUpdate task instance state or note; use --map-index to target +a single mapped instance + +Options: +-h--helpshow this help message and exit + + + + diff --git a/airflow-ctl/src/airflowctl/api/client.py b/airflow-ctl/src/airflowctl/api/client.py index de2cec4ea5f33..278de48b4bbcc 100644 --- a/airflow-ctl/src/airflowctl/api/client.py +++ b/airflow-ctl/src/airflowctl/api/client.py @@ -57,6 +57,7 @@ PoolsOperations, ProvidersOperations, ServerResponseError, + TasksOperations, VariablesOperations, VersionOperations, XComOperations, @@ -474,6 +475,12 @@ def plugins(self): """Operations related to plugins.""" return PluginsOperations(self) + @lru_cache() # type: ignore[prop-decorator] + @property + def tasks(self): + """Operations related to tasks.""" + return TasksOperations(self) + # API Client Decorator for CLI Actions @contextlib.contextmanager diff --git a/airflow-ctl/src/airflowctl/api/operations.py b/airflow-ctl/src/airflowctl/api/operations.py index ad0de6d008b19..ffdb76cd6da89 100644 --- a/airflow-ctl/src/airflowctl/api/operations.py +++ b/airflow-ctl/src/airflowctl/api/operations.py @@ -39,6 +39,7 @@ BulkBodyPoolBody, BulkBodyVariableBody, BulkResponse, + ClearTaskInstancesBody, Config, ConnectionBody, ConnectionCollectionResponse, @@ -59,6 +60,7 @@ ImportErrorCollectionResponse, ImportErrorResponse, JobCollectionResponse, + PatchTaskInstanceBody, PluginCollectionResponse, PluginImportErrorCollectionResponse, PoolBody, @@ -68,6 +70,8 @@ ProviderCollectionResponse, QueuedEventCollectionResponse, QueuedEventResponse, + TaskInstanceCollectionResponse, + TaskInstanceResponse, TriggerDAGRunPostBody, VariableBody, VariableCollectionResponse, @@ -733,10 +737,16 @@ def delete(self, pool: str) -> str | ServerResponseError: def update(self, pool_body: PoolPatchBody) -> PoolResponse | ServerResponseError: """Update a pool.""" - try: - self.response = self.client.patch( - f"pools/{pool_body.pool}", json=pool_body.model_dump(mode="json") - ) + # Workaround: the server's PATCH handler validates the partial body + # against ``BasePool`` (see airflow-core/.../services/public/pools.py) + # which requires ``include_deferred``. Omitting it fails with + # "Field required", sending ``null`` fails with "bool_type". Always + # send ``include_deferred`` (defaulting to False when unset) so PATCH + # requests are accepted until the server switches to a partial validator. + body = pool_body.model_dump(mode="json", exclude_none=True) + body.setdefault("include_deferred", False) + try: + self.response = self.client.patch(f"pools/{pool_body.pool}", json=body) return PoolResponse.model_validate_json(self.response.content) except ServerResponseError as e: raise e @@ -950,3 +960,45 @@ def list_import_errors(self) -> PluginImportErrorCollectionResponse | ServerResp return PluginImportErrorCollectionResponse.model_validate_json(self.response.content) except ServerResponseError as e: raise e + + +class TasksOperations(BaseOperations): + """Tasks operations.""" + + def get(self, dag_id: str, dag_run_id: str, task_id: str) -> TaskInstanceResponse: + """Get a task instance.""" + self.response = self.client.get(f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}") + return TaskInstanceResponse.model_validate_json(self.response.content) + + def list(self, dag_id: str, dag_run_id: str) -> TaskInstanceCollectionResponse | ServerResponseError: + """List task instances.""" + return super().execute_list( + path=f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances", + data_model=TaskInstanceCollectionResponse, + ) + + def clear( + self, dag_id: str, body: ClearTaskInstancesBody + ) -> TaskInstanceCollectionResponse | ServerResponseError: + """Clear task instances.""" + self.response = self.client.post( + f"dags/{dag_id}/clearTaskInstances", + json=body.model_dump(mode="json", exclude_unset=True), + ) + return TaskInstanceCollectionResponse.model_validate_json(self.response.content) + + def update( + self, + dag_id: str, + dag_run_id: str, + task_id: str, + body: PatchTaskInstanceBody, + map_index: int | None = None, + ) -> TaskInstanceCollectionResponse: + """Update task instance state. When map_index is given, only that mapped instance is affected.""" + if map_index is not None: + path = f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}" + else: + path = f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}" + self.response = self.client.patch(path, json=body.model_dump(mode="json", exclude_unset=True)) + return TaskInstanceCollectionResponse.model_validate_json(self.response.content) diff --git a/airflow-ctl/src/airflowctl/ctl/cli_config.py b/airflow-ctl/src/airflowctl/ctl/cli_config.py index d0901e795ef48..6b63708e3c26e 100755 --- a/airflow-ctl/src/airflowctl/ctl/cli_config.py +++ b/airflow-ctl/src/airflowctl/ctl/cli_config.py @@ -26,6 +26,8 @@ import inspect import os import sys +import types as builtin_types +import typing from argparse import Namespace from collections.abc import Callable, Iterable from enum import Enum @@ -52,6 +54,17 @@ BUILD_DOCS = "BUILDING_AIRFLOW_DOCS" in os.environ +def _is_list_annotation(annotation: Any) -> bool: + """Check whether a Pydantic field annotation is a list type (including Optional[list[...]]).""" + origin = typing.get_origin(annotation) + if origin is list: + return True + # Handle both typing.Union (Optional[list[...]]) and PEP-604 X | Y (types.UnionType) + if origin is typing.Union or isinstance(annotation, builtin_types.UnionType): + return any(_is_list_annotation(arg) for arg in typing.get_args(annotation) if arg is not type(None)) + return False + + def lazy_load_command(import_path: str) -> Callable: """Create a lazy loader for command.""" _, _, name = import_path.rpartition(".") @@ -399,7 +412,17 @@ def __init__(self, file_path: str | Path | None = None): # Exclude parameters that are not needed for CLI from datamodels self.excluded_parameters = ["schema_"] # This list is used to determine if the command/operation needs to output data - self.output_command_list = ["list", "get", "create", "delete", "update", "trigger", "add", "edit"] + self.output_command_list = [ + "list", + "get", + "create", + "delete", + "update", + "trigger", + "add", + "edit", + "clear", + ] self.exclude_operation_names = ["LoginOperations", "VersionOperations", "BaseOperations"] self.exclude_method_names = [ "error", @@ -411,6 +434,8 @@ def __init__(self, file_path: str | Path | None = None): ] self.excluded_output_keys = [ "total_entries", + "next_cursor", + "previous_cursor", ] def _inspect_operations(self) -> None: @@ -587,7 +612,7 @@ def _create_arg_for_non_primitive_type( arg_type=self._python_type_from_string(field_type.annotation), arg_action=argparse.BooleanOptionalAction if field_type.annotation is bool else None, # type: ignore arg_help=f"{field} for {parameter_key} operation", - arg_default=False if field_type.annotation is bool else None, + arg_default=None, ) ) else: @@ -602,7 +627,7 @@ def _create_arg_for_non_primitive_type( arg_type=self._python_type_from_string(annotation), arg_action=argparse.BooleanOptionalAction if annotation is bool else None, # type: ignore arg_help=f"{field} for {parameter_key} operation", - arg_default=False if annotation is bool else None, + arg_default=None, ) ) return commands @@ -717,10 +742,19 @@ def _get_func(args: Namespace, api_operation: dict, api_client: Client = NEW_API datamodel_param_name = parameter_key if expanded_parameter in self.excluded_parameters: continue - if expanded_parameter in args_dict.keys(): + if ( + expanded_parameter in args_dict.keys() + and args_dict[expanded_parameter] is not None + ): + val = args_dict[expanded_parameter] + if isinstance(val, str) and expanded_parameter in datamodel.model_fields: + if _is_list_annotation( + datamodel.model_fields[expanded_parameter].annotation + ): + val = [v.strip() for v in val.split(",") if v.strip()] method_params[parameter_key][ self._sanitize_method_param_key(expanded_parameter) - ] = args_dict[expanded_parameter] + ] = val if datamodel: if datamodel_param_name: diff --git a/airflow-ctl/src/airflowctl/ctl/help_texts.yaml b/airflow-ctl/src/airflowctl/ctl/help_texts.yaml index d0ebdcfeaebdd..f1575bab1703b 100644 --- a/airflow-ctl/src/airflowctl/ctl/help_texts.yaml +++ b/airflow-ctl/src/airflowctl/ctl/help_texts.yaml @@ -102,3 +102,9 @@ xcom: plugins: list: "List all installed Airflow plugins" list-import-errors: "List all plugin import errors" + +tasks: + get: "Retrieve a task instance by Dag ID, run ID, and task ID" + list: "List all task instances for a given Dag run" + clear: "Clear task instance state, optionally filtering by task IDs or run scope" + update: "Update task instance state or note; use --map-index to target a single mapped instance" diff --git a/airflow-ctl/tests/airflow_ctl/api/test_operations.py b/airflow-ctl/tests/airflow_ctl/api/test_operations.py index 53dfb3217175b..93f9423ca4fe6 100644 --- a/airflow-ctl/tests/airflow_ctl/api/test_operations.py +++ b/airflow-ctl/tests/airflow_ctl/api/test_operations.py @@ -48,6 +48,7 @@ BulkCreateActionPoolBody, BulkCreateActionVariableBody, BulkResponse, + ClearTaskInstancesBody, Config, ConfigOption, ConfigSection, @@ -79,18 +80,23 @@ ImportErrorResponse, JobCollectionResponse, JobResponse, + PatchTaskInstanceBody, PluginCollectionResponse, PluginImportErrorCollectionResponse, PluginImportErrorResponse, PluginResponse, PoolBody, PoolCollectionResponse, + PoolPatchBody, PoolResponse, ProviderCollectionResponse, ProviderResponse, QueuedEventCollectionResponse, QueuedEventResponse, ReprocessBehavior, + TaskInstanceCollectionResponse, + TaskInstanceResponse, + TaskInstanceState, TriggerDAGRunPostBody, VariableBody, VariableCollectionResponse, @@ -1458,6 +1464,40 @@ def handle_request(request: httpx.Request) -> httpx.Response: response = client.pools.bulk(pools=self.pools_bulk_body) assert response == self.pool_bulk_response + def test_update_defaults_unset_include_deferred_to_false(self): + """Unset include_deferred must default to False to satisfy the server's PATCH validator.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == f"/api/v2/pools/{self.pool_name}" + assert json.loads(request.content.decode()) == { + "pool": self.pool_name, + "slots": 10, + "include_deferred": False, + } + return httpx.Response(200, json=json.loads(self.pool_response.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.pools.update(pool_body=PoolPatchBody(pool=self.pool_name, slots=10)) + assert response == self.pool_response + + def test_update_preserves_explicit_include_deferred(self): + """Explicit include_deferred value from the user must be preserved, not overwritten.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == f"/api/v2/pools/{self.pool_name}" + assert json.loads(request.content.decode()) == { + "pool": self.pool_name, + "slots": 10, + "include_deferred": True, + } + return httpx.Response(200, json=json.loads(self.pool_response.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.pools.update( + pool_body=PoolPatchBody(pool=self.pool_name, slots=10, include_deferred=True) + ) + assert response == self.pool_response + def test_delete(self): def handle_request(request: httpx.Request) -> httpx.Response: assert request.url.path == f"/api/v2/pools/{self.pool_name}" @@ -1978,10 +2018,10 @@ class TestPluginsOperations: ) def test_list(self): - """Test listing plugins""" + """Test listing plugins.""" def handle_request(request: httpx.Request) -> httpx.Response: - assert request.url.path == ("/api/v2/plugins") + assert request.url.path == "/api/v2/plugins" return httpx.Response(200, json=json.loads(self.plugin_collection_response.model_dump_json())) client = make_api_client(transport=httpx.MockTransport(handle_request)) @@ -1989,7 +2029,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert response == self.plugin_collection_response def test_list_import_errors(self): - """Test listing plugin import errors""" + """Test listing plugin import errors.""" def handle_request(request: httpx.Request) -> httpx.Response: assert request.url.path == "/api/v2/plugins/importErrors" @@ -2000,3 +2040,193 @@ def handle_request(request: httpx.Request) -> httpx.Response: client = make_api_client(transport=httpx.MockTransport(handle_request)) response = client.plugins.list_import_errors() assert response == self.plugin_import_error_collection_response + + +class TestTasksOperations: + """Test suite for Tasks operations.""" + + dag_id: str = "test_dag" + dag_run_id: str = "manual__2025-01-24T00:00:00+00:00" + task_id: str = "test_task" + + task_instance_response = TaskInstanceResponse( + id=uuid.uuid4(), + task_id=task_id, + dag_id=dag_id, + dag_run_id=dag_run_id, + map_index=-1, + logical_date=datetime.datetime(2025, 1, 24, 0, 0, 0), + run_after=datetime.datetime(2025, 1, 24, 0, 0, 0), + start_date=datetime.datetime(2025, 1, 24, 0, 0, 1), + end_date=datetime.datetime(2025, 1, 24, 0, 0, 10), + duration=9.0, + state=TaskInstanceState.SUCCESS, + try_number=1, + max_tries=0, + task_display_name=task_id, + dag_display_name=dag_id, + hostname="hostname", + unixname="airflow", + pool="default_pool", + pool_slots=1, + queue="default", + priority_weight=1, + operator="EmptyOperator", + executor_config="{}", + note=None, + ) + + task_instance_collection_response = TaskInstanceCollectionResponse( + task_instances=[task_instance_response], + total_entries=1, + ) + + def test_get(self): + """Test fetching a single task instance.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_id}" + ) + return httpx.Response(200, json=json.loads(self.task_instance_response.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.tasks.get( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + ) + assert response == self.task_instance_response + + def test_list(self): + """Test listing task instances for a DAG run.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == (f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances") + return httpx.Response( + 200, json=json.loads(self.task_instance_collection_response.model_dump_json()) + ) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.tasks.list( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + ) + assert response == self.task_instance_collection_response + + def test_clear(self): + """Test clearing task instances with default options.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == f"/api/v2/dags/{self.dag_id}/clearTaskInstances" + request_body = json.loads(request.content) + assert request_body["dry_run"] is True + return httpx.Response( + 200, json=json.loads(self.task_instance_collection_response.model_dump_json()) + ) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + body = ClearTaskInstancesBody(dry_run=True) + response = client.tasks.clear( + dag_id=self.dag_id, + body=body, + ) + assert response == self.task_instance_collection_response + + def test_clear_with_options(self): + """Test clearing task instances with specific options.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == f"/api/v2/dags/{self.dag_id}/clearTaskInstances" + request_body = json.loads(request.content) + assert request_body["dry_run"] is False + assert request_body["only_failed"] is True + assert request_body["task_ids"] == [self.task_id] + assert request_body["dag_run_id"] == self.dag_run_id + return httpx.Response( + 200, json=json.loads(self.task_instance_collection_response.model_dump_json()) + ) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + body = ClearTaskInstancesBody( + dry_run=False, + only_failed=True, + task_ids=[self.task_id], + dag_run_id=self.dag_run_id, + ) + response = client.tasks.clear( + dag_id=self.dag_id, + body=body, + ) + assert response == self.task_instance_collection_response + + def test_update(self): + """Test updating a task instance state — API always returns a collection.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_id}" + ) + request_body = json.loads(request.content) + assert request_body["new_state"] == TaskInstanceState.FAILED.value + return httpx.Response( + 200, json=json.loads(self.task_instance_collection_response.model_dump_json()) + ) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + body = PatchTaskInstanceBody(new_state=TaskInstanceState.FAILED) + response = client.tasks.update( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + body=body, + ) + assert response == self.task_instance_collection_response + + def test_update_with_note(self): + """Test updating a task instance with a note only (no new_state).""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_id}" + ) + request_body = json.loads(request.content) + assert request_body["note"] == "Manually marked as success" + assert "new_state" not in request_body + return httpx.Response( + 200, json=json.loads(self.task_instance_collection_response.model_dump_json()) + ) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + body = PatchTaskInstanceBody(note="Manually marked as success") + response = client.tasks.update( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + body=body, + ) + assert response == self.task_instance_collection_response + + def test_update_with_map_index(self): + """Test that map_index routes to the indexed endpoint, scoping the update to one instance.""" + map_index = 0 + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}" + f"/taskInstances/{self.task_id}/{map_index}" + ) + return httpx.Response( + 200, json=json.loads(self.task_instance_collection_response.model_dump_json()) + ) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + body = PatchTaskInstanceBody(new_state=TaskInstanceState.SUCCESS) + response = client.tasks.update( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + body=body, + map_index=map_index, + ) + assert response == self.task_instance_collection_response diff --git a/airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py b/airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py index 945b5abb83378..e64a560a66032 100644 --- a/airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py +++ b/airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py @@ -94,7 +94,7 @@ def test_args_create(): { "help": "run_backwards for backfill operation", "action": BooleanOptionalAction, - "default": False, + "default": None, "type": bool, "dest": None, }, @@ -321,6 +321,41 @@ def list(self, is_alive: bool | None = None) -> JobCollectionResponse | ServerRe assert is_alive_arg.kwargs["default"] is None assert is_alive_arg.kwargs["type"] is bool + def test_command_factory_body_bool_field_defaults_to_none(self, tmp_path): + """Bool fields expanded from a Pydantic body must default to None, not False. + + Otherwise the dispatcher's ``is not None`` filter passes an unset flag + through as ``False``, silently overriding API-side defaults that are + ``True`` (e.g. ``ClearTaskInstancesBody.only_failed``). + """ + temp_file = self._save_temp_operations_py( + tmp_path=tmp_path, + file_content=""" + class TasksOperations(BaseOperations): + def clear(self, dag_id: str, body: ClearTaskInstancesBody): + self.response = self.client.post( + f"dags/{dag_id}/clearTaskInstances", + json=body.model_dump(mode="json"), + ) + return self.response + """, + ) + + command_factory = CommandFactory(file_path=str(temp_file)) + clear_args: list = [] + for generated_group_command in command_factory.group_commands: + if generated_group_command.name != "tasks": + continue + for sub_command in generated_group_command.subcommands: + if sub_command.name == "clear": + clear_args = list(sub_command.args) + break + + for flag in ("--dry-run", "--only-failed", "--reset-dag-runs", "--run-on-latest-version"): + arg = next(a for a in clear_args if a.flags == (flag,)) + assert arg.kwargs["action"] == BooleanOptionalAction, flag + assert arg.kwargs["default"] is None, flag + def test_command_factory_required_primitive_param_is_positional(self, tmp_path): """Required primitive parameters (no default, not Optional) become positional arguments. diff --git a/scripts/in_container/run_capture_airflowctl_help.py b/scripts/in_container/run_capture_airflowctl_help.py index 9529dbe04390c..00001dc97048d 100644 --- a/scripts/in_container/run_capture_airflowctl_help.py +++ b/scripts/in_container/run_capture_airflowctl_help.py @@ -48,6 +48,7 @@ "variables", "version", "plugins", + "tasks", ] SUBCOMMANDS = [