From 3df38ff0b456f06335c36c9dfc6fb8a22d7f4c44 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 22 May 2026 15:06:41 +0530 Subject: [PATCH 1/7] AIP-103: Add patch task state core API and support for expires_at in set API --- .../core_api/datamodels/task_state.py | 20 +++- .../openapi/v2-rest-api-generated.yaml | 109 +++++++++++++++++- .../core_api/routes/public/task_state.py | 86 +++++++++++--- .../airflow/ui/openapi-gen/queries/common.ts | 1 + .../airflow/ui/openapi-gen/queries/queries.ts | 30 ++++- .../ui/openapi-gen/requests/schemas.gen.ts | 40 ++++++- .../ui/openapi-gen/requests/services.gen.ts | 39 ++++++- .../ui/openapi-gen/requests/types.gen.ts | 50 ++++++++ .../core_api/routes/public/test_task_state.py | 53 +++++++++ .../airflowctl/api/datamodels/generated.py | 18 +++ 10 files changed, 427 insertions(+), 19 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py index 856de74a0877b..dc05631cf3c57 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py @@ -17,8 +17,9 @@ from __future__ import annotations from datetime import datetime +from typing import Literal -from pydantic import Field +from pydantic import AwareDatetime, Field from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel @@ -40,6 +41,21 @@ class TaskStateCollectionResponse(BaseModel): class TaskStateBody(StrictBaseModel): - """Request body for setting a task state value.""" + """ + Request body for setting a task state value. + + ``expires_at`` controls expiry: + + - ``"default"``: apply the configured ``[state_store] default_retention_days``. + - ``null``: never expire. + - aware datetime: expire at that time. + """ + + value: str = Field(max_length=65535) + expires_at: AwareDatetime | None | Literal["default"] = "default" + + +class TaskStatePatchBody(StrictBaseModel): + """Request body for patching only the value of an existing task state key.""" value: str = Field(max_length=65535) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 23f2622b13b3f..a2a8afa7b25c0 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -6150,6 +6150,81 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + patch: + tags: + - Task State + summary: Patch Task State + description: Update the value of an existing task state key. + operationId: patch_task_state + security: + - OAuth2PasswordBearer: [] + - HTTPBearer: [] + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + - name: key + in: path + required: true + schema: + type: string + title: Key + - name: map_index + in: query + required: false + schema: + type: integer + minimum: -1 + default: -1 + title: Map Index + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/TaskStatePatchBody' + responses: + '204': + description: Successful Response + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' delete: tags: - Task State @@ -15813,12 +15888,31 @@ components: type: string maxLength: 65535 title: Value + expires_at: + anyOf: + - type: string + format: date-time + - type: string + const: default + - type: 'null' + title: Expires At + default: default additionalProperties: false type: object required: - value title: TaskStateBody - description: Request body for setting a task state value. + description: 'Request body for setting a task state value. + + + ``expires_at`` controls expiry: + + + - ``"default"``: apply the configured ``[state_store] default_retention_days``. + + - ``null``: never expire. + + - aware datetime: expire at that time.' TaskStateCollectionResponse: properties: task_states: @@ -15835,6 +15929,19 @@ components: - total_entries title: TaskStateCollectionResponse description: All task state entries for a task instance. + TaskStatePatchBody: + properties: + value: + type: string + maxLength: 65535 + title: Value + additionalProperties: false + type: object + required: + - value + title: TaskStatePatchBody + description: Request body for patching only the value of an existing task state + key. TaskStateResponse: properties: key: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py index 138380232a8aa..4580c904c2643 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +from datetime import datetime, timedelta, timezone from typing import Annotated from fastapi import Depends, HTTPException, Query, status @@ -29,10 +30,12 @@ from airflow.api_fastapi.core_api.datamodels.task_state import ( TaskStateBody, TaskStateCollectionResponse, + TaskStatePatchBody, TaskStateResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.api_fastapi.core_api.security import requires_access_dag +from airflow.configuration import conf from airflow.models.task_state import TaskStateModel from airflow.models.taskinstance import TaskInstance as TI from airflow.state import get_state_backend @@ -47,6 +50,36 @@ def _get_scope(dag_id: str, dag_run_id: str, task_id: str, map_index: int) -> Ta return TaskScope(dag_id=dag_id, run_id=dag_run_id, task_id=task_id, map_index=map_index) +def _resolve_expires_at(expires_at: datetime | None | str) -> datetime | None: + """ + Resolve the expires_at value from the request body. + + - ``"default"``: apply configured default_retention_days + - ``None``: never expire + - datetime: use as-is + """ + if expires_at == "default": + days = conf.getint("state_store", "default_retention_days") + return datetime.now(tz=timezone.utc) + timedelta(days=days) + return None + + +def _require_ti(dag_id: str, dag_run_id: str, task_id: str, map_index: int, session: SessionDep) -> None: + ti_exists = session.scalar( + select(TI.task_id).where( + TI.dag_id == dag_id, + TI.run_id == dag_run_id, + TI.task_id == task_id, + TI.map_index == map_index, + ) + ) + if ti_exists is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Task instance not found for dag_id={dag_id!r}, run_id={dag_run_id!r}, task_id={task_id!r}, map_index={map_index}", + ) + + @task_state_router.get( "", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), @@ -147,24 +180,51 @@ def set_task_state( map_index: Annotated[int, Query(ge=-1)] = -1, ) -> None: """Set a task state value. Creates or overwrites the key.""" - ti_exists = session.scalar( - select(TI.task_id).where( - TI.dag_id == dag_id, - TI.run_id == dag_run_id, - TI.task_id == task_id, - TI.map_index == map_index, + _require_ti(dag_id, dag_run_id, task_id, map_index, session) + expires_at = _resolve_expires_at(body.expires_at) + scope = _get_scope(dag_id, dag_run_id, task_id, map_index) + try: + get_state_backend().set(scope, key, body.value, expires_at=expires_at, session=session) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + + +@task_state_router.patch( + "/{key:path}", + status_code=status.HTTP_204_NO_CONTENT, + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE))], +) +def patch_task_state( + dag_id: str, + dag_run_id: str, + task_id: str, + key: str, + body: TaskStatePatchBody, + session: SessionDep, + map_index: Annotated[int, Query(ge=-1)] = -1, +) -> None: + """Update the value of an existing task state key.""" + _require_ti(dag_id, dag_run_id, task_id, map_index, session) + + existing = session.execute( + select(TaskStateModel.expires_at).where( + TaskStateModel.dag_id == dag_id, + TaskStateModel.run_id == dag_run_id, + TaskStateModel.task_id == task_id, + TaskStateModel.map_index == map_index, + TaskStateModel.key == key, ) - ) - if ti_exists is None: + ).one_or_none() + + if existing is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Task instance not found for dag_id={dag_id!r}, run_id={dag_run_id!r}, task_id={task_id!r}, map_index={map_index}", + detail=f"Task state key {key!r} not found", ) + scope = _get_scope(dag_id, dag_run_id, task_id, map_index) - try: - get_state_backend().set(scope, key, body.value, session=session) - except ValueError as e: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + get_state_backend().set(scope, key, body.value, expires_at=existing.expires_at, session=session) @task_state_router.delete( diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts index f8e3dbe9af638..632d311157631 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts @@ -1067,6 +1067,7 @@ export type TaskInstanceServicePatchTaskInstanceDryRunMutationResult = Awaited>; export type PoolServicePatchPoolMutationResult = Awaited>; export type PoolServiceBulkPoolsMutationResult = Awaited>; +export type TaskStateServicePatchTaskStateMutationResult = Awaited>; export type XcomServiceUpdateXcomEntryMutationResult = Awaited>; export type VariableServicePatchVariableMutationResult = Awaited>; export type VariableServiceBulkVariablesMutationResult = Awaited>; diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts index 8c0976ec328e9..3885f773de7cc 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts @@ -2,7 +2,7 @@ import { UseMutationOptions, UseQueryOptions, useMutation, useQuery } from "@tanstack/react-query"; import { AssetService, AssetStateService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DeadlinesService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GanttService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PartitionedDagRunService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, TaskStateService, TeamsService, VariableService, VersionService, XcomService } from "../requests/services.gen"; -import { AssetStateBody, BackfillPostBody, BulkBody_BulkDAGRunBody_, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, MaterializeAssetBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TaskStateBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen"; +import { AssetStateBody, BackfillPostBody, BulkBody_BulkDAGRunBody_, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, MaterializeAssetBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TaskStateBody, TaskStatePatchBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen"; import * as Common from "./common"; /** * Get Assets @@ -2801,6 +2801,34 @@ export const usePoolServiceBulkPools = ({ mutationFn: ({ requestBody }) => PoolService.bulkPools({ requestBody }) as unknown as Promise, ...options }); /** +* Patch Task State +* Update the value of an existing task state key. +* @param data The data for the request. +* @param data.dagId +* @param data.dagRunId +* @param data.taskId +* @param data.key +* @param data.requestBody +* @param data.mapIndex +* @returns void Successful Response +* @throws ApiError +*/ +export const useTaskStateServicePatchTaskState = (options?: Omit, "mutationFn">) => useMutation({ mutationFn: ({ dagId, dagRunId, key, mapIndex, requestBody, taskId }) => TaskStateService.patchTaskState({ dagId, dagRunId, key, mapIndex, requestBody, taskId }) as unknown as Promise, ...options }); +/** * Update Xcom Entry * Update an existing XCom entry. * @param data The data for the request. diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index 4a4b95f183023..22c2333a49a64 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -6985,13 +6985,36 @@ export const $TaskStateBody = { type: 'string', maxLength: 65535, title: 'Value' + }, + expires_at: { + anyOf: [ + { + type: 'string', + format: 'date-time' + }, + { + type: 'string', + const: 'default' + }, + { + type: 'null' + } + ], + title: 'Expires At', + default: 'default' } }, additionalProperties: false, type: 'object', required: ['value'], title: 'TaskStateBody', - description: 'Request body for setting a task state value.' + description: `Request body for setting a task state value. + +\`\`expires_at\`\` controls expiry: + +- \`\`"default"\`\`: apply the configured \`\`[state_store] default_retention_days\`\`. +- \`\`null\`\`: never expire. +- aware datetime: expire at that time.` } as const; export const $TaskStateCollectionResponse = { @@ -7014,6 +7037,21 @@ export const $TaskStateCollectionResponse = { description: 'All task state entries for a task instance.' } as const; +export const $TaskStatePatchBody = { + properties: { + value: { + type: 'string', + maxLength: 65535, + title: 'Value' + } + }, + additionalProperties: false, + type: 'object', + required: ['value'], + title: 'TaskStatePatchBody', + description: 'Request body for patching only the value of an existing task state key.' +} as const; + export const $TaskStateResponse = { properties: { key: { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts index 7546818497787..ed1f662df0395 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3,7 +3,7 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, BulkDagRunsData, BulkDagRunsResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagRunStatsData, GetDagRunStatsResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskGroupInstancesData, PatchTaskGroupInstancesResponse, PatchTaskGroupInstancesDryRunData, PatchTaskGroupInstancesDryRunResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, ListAssetStatesData, ListAssetStatesResponse, ClearAssetStateData, ClearAssetStateResponse, GetAssetStateData, GetAssetStateResponse, SetAssetStateData, SetAssetStateResponse, DeleteAssetStateData, DeleteAssetStateResponse, ListTaskStatesData, ListTaskStatesResponse, ClearTaskStateData, ClearTaskStateResponse, GetTaskStateData, GetTaskStateResponse, SetTaskStateData, SetTaskStateResponse, DeleteTaskStateData, DeleteTaskStateResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDeadlinesData, GetDeadlinesResponse, GetDagDeadlineAlertsData, GetDagDeadlineAlertsResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen'; +import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, BulkDagRunsData, BulkDagRunsResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagRunStatsData, GetDagRunStatsResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskGroupInstancesData, PatchTaskGroupInstancesResponse, PatchTaskGroupInstancesDryRunData, PatchTaskGroupInstancesDryRunResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, ListAssetStatesData, ListAssetStatesResponse, ClearAssetStateData, ClearAssetStateResponse, GetAssetStateData, GetAssetStateResponse, SetAssetStateData, SetAssetStateResponse, DeleteAssetStateData, DeleteAssetStateResponse, ListTaskStatesData, ListTaskStatesResponse, ClearTaskStateData, ClearTaskStateResponse, GetTaskStateData, GetTaskStateResponse, SetTaskStateData, SetTaskStateResponse, PatchTaskStateData, PatchTaskStateResponse, DeleteTaskStateData, DeleteTaskStateResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDeadlinesData, GetDeadlinesResponse, GetDagDeadlineAlertsData, GetDagDeadlineAlertsResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen'; export class AssetService { /** @@ -3814,6 +3814,43 @@ export class TaskStateService { }); } + /** + * Patch Task State + * Update the value of an existing task state key. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.key + * @param data.requestBody + * @param data.mapIndex + * @returns void Successful Response + * @throws ApiError + */ + public static patchTaskState(data: PatchTaskStateData): CancelablePromise { + return __request(OpenAPI, { + method: 'PATCH', + url: '/api/v2/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/states/{key}', + path: { + dag_id: data.dagId, + dag_run_id: data.dagRunId, + task_id: data.taskId, + key: data.key + }, + query: { + map_index: data.mapIndex + }, + body: data.requestBody, + mediaType: 'application/json', + errors: { + 401: 'Unauthorized', + 403: 'Forbidden', + 404: 'Not Found', + 422: 'Validation Error' + } + }); + } + /** * Delete Task State * Delete a single task state key. No-op if the key does not exist. diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 70c9fa131c191..901300f2e4563 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -1713,9 +1713,16 @@ export type TaskResponse = { /** * Request body for setting a task state value. + * + * ``expires_at`` controls expiry: + * + * - ``"default"``: apply the configured ``[state_store] default_retention_days``. + * - ``null``: never expire. + * - aware datetime: expire at that time. */ export type TaskStateBody = { value: string; + expires_at?: string | "default" | null; }; /** @@ -1726,6 +1733,13 @@ export type TaskStateCollectionResponse = { total_entries: number; }; +/** + * Request body for patching only the value of an existing task state key. + */ +export type TaskStatePatchBody = { + value: string; +}; + /** * A single task state key/value pair with metadata. */ @@ -3966,6 +3980,17 @@ export type SetTaskStateData = { export type SetTaskStateResponse = void; +export type PatchTaskStateData = { + dagId: string; + dagRunId: string; + key: string; + mapIndex?: number; + requestBody: TaskStatePatchBody; + taskId: string; +}; + +export type PatchTaskStateResponse = void; + export type DeleteTaskStateData = { dagId: string; dagRunId: string; @@ -7264,6 +7289,31 @@ export type $OpenApiTs = { 422: HTTPValidationError; }; }; + patch: { + req: PatchTaskStateData; + res: { + /** + * Successful Response + */ + 204: void; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; delete: { req: DeleteTaskStateData; res: { diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py index c53212fa5e444..bdc854fe1b279 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py @@ -25,6 +25,7 @@ from airflow.providers.standard.operators.empty import EmptyOperator from airflow.utils.types import DagRunType +from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, clear_db_runs pytestmark = pytest.mark.db_test @@ -211,10 +212,62 @@ def test_key_with_slash_is_supported(self, test_client): assert response.status_code == 204 assert test_client.get(f"{BASE_URL}/workflow/step_1").json()["key"] == "workflow/step_1" + def test_new_key_default_retention_applies_config(self, test_client, time_machine): + time_machine.travel("2026-01-01T00:00:00+00:00", tick=False) + with conf_vars({("state_store", "default_retention_days"): "7"}): + test_client.put(f"{BASE_URL}/job_id", json={"value": "v"}) + + resp = test_client.get(f"{BASE_URL}/job_id").json() + assert resp["expires_at"] == "2026-01-08T00:00:00+00:00" + + def test_new_key_never_expiry(self, test_client): + """PUT with expires_at=null stores a key that never expires.""" + test_client.put(f"{BASE_URL}/job_id", json={"value": "v", "expires_at": None}) + assert test_client.get(f"{BASE_URL}/job_id").json()["expires_at"] is None + + def test_new_key_explicit_expiry(self, test_client, time_machine): + """PUT with an explicit datetime uses that as expires_at.""" + time_machine.travel("2026-01-01T00:00:00+00:00", tick=False) + target = "2026-01-31T00:00:00+00:00" + test_client.put(f"{BASE_URL}/job_id", json={"value": "v", "expires_at": target}) + assert test_client.get(f"{BASE_URL}/job_id").json()["expires_at"] == target + + def test_put_overwrites_expiry_on_existing_key(self, test_client, time_machine): + """PUT on an existing key replaces expires_at with whatever the body specifies.""" + time_machine.travel("2026-01-01T00:00:00+00:00", tick=False) + test_client.put(f"{BASE_URL}/job_id", json={"value": "v1", "expires_at": "2026-01-31T00:00:00+00:00"}) + + # second request but with null expires_at + test_client.put(f"{BASE_URL}/job_id", json={"value": "v2", "expires_at": None}) + + resp = test_client.get(f"{BASE_URL}/job_id").json() + assert resp["value"] == "v2" + assert resp["expires_at"] is None + def test_unauthorized_returns_401(self, unauthenticated_test_client): assert unauthenticated_test_client.put(f"{BASE_URL}/job_id", json={"value": "v"}).status_code == 401 +class TestPatchTaskState(TestTaskStateEndpoint): + def test_patch_updates_value(self, test_client): + _create_task_state(self._session, "job_id", "v1", self.dag_run) + self._session.commit() + + assert test_client.patch(f"{BASE_URL}/job_id", json={"value": "v2"}).status_code == 204 + assert test_client.get(f"{BASE_URL}/job_id").json()["value"] == "v2" + + def test_patch_missing_key_returns_404(self, test_client): + assert test_client.patch(f"{BASE_URL}/nonexistent", json={"value": "v"}).status_code == 404 + + def test_patch_empty_body_returns_422(self, test_client): + _create_task_state(self._session, "job_id", "v", self.dag_run) + self._session.commit() + assert test_client.patch(f"{BASE_URL}/job_id", json={}).status_code == 422 + + def test_unauthorized_returns_401(self, unauthenticated_test_client): + assert unauthenticated_test_client.patch(f"{BASE_URL}/job_id", json={"value": "v"}).status_code == 401 + + class TestDeleteTaskState(TestTaskStateEndpoint): def test_deletes_key(self, test_client): _create_task_state(self._session, "job_id", "spark_001", self.dag_run) diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index f05fa65cf56f8..1db7481201009 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -976,6 +976,24 @@ class TaskOutletAssetReference(BaseModel): class TaskStateBody(BaseModel): """ Request body for setting a task state value. + + ``expires_at`` controls expiry: + + - ``"default"``: apply the configured ``[state_store] default_retention_days``. + - ``null``: never expire. + - aware datetime: expire at that time. + """ + + model_config = ConfigDict( + extra="forbid", + ) + value: Annotated[str, Field(max_length=65535, title="Value")] + expires_at: Annotated[datetime | str | None, Field(title="Expires At")] = "default" + + +class TaskStatePatchBody(BaseModel): + """ + Request body for patching only the value of an existing task state key. """ model_config = ConfigDict( From 41adb36ab81224c34a4a36be7c837976ec617b95 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 22 May 2026 15:25:12 +0530 Subject: [PATCH 2/7] AIP-103: Add patch task state core API and support for expires_at in set API --- .../core_api/routes/public/test_task_state.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py index bdc854fe1b279..015d442c02136 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py @@ -213,12 +213,12 @@ def test_key_with_slash_is_supported(self, test_client): assert test_client.get(f"{BASE_URL}/workflow/step_1").json()["key"] == "workflow/step_1" def test_new_key_default_retention_applies_config(self, test_client, time_machine): - time_machine.travel("2026-01-01T00:00:00+00:00", tick=False) + time_machine.move_to("2026-01-01T00:00:00+00:00", tick=False) with conf_vars({("state_store", "default_retention_days"): "7"}): - test_client.put(f"{BASE_URL}/job_id", json={"value": "v"}) + test_client.put(f"{BASE_URL}/job_id", json={"value": "v", "expires_at": "default"}) resp = test_client.get(f"{BASE_URL}/job_id").json() - assert resp["expires_at"] == "2026-01-08T00:00:00+00:00" + assert resp["expires_at"] == "2026-01-08T00:00:00Z" def test_new_key_never_expiry(self, test_client): """PUT with expires_at=null stores a key that never expires.""" @@ -227,15 +227,15 @@ def test_new_key_never_expiry(self, test_client): def test_new_key_explicit_expiry(self, test_client, time_machine): """PUT with an explicit datetime uses that as expires_at.""" - time_machine.travel("2026-01-01T00:00:00+00:00", tick=False) - target = "2026-01-31T00:00:00+00:00" + time_machine.move_to("2026-01-01T00:00:00+00:00", tick=False) + target = "2026-01-31T00:00:00Z" test_client.put(f"{BASE_URL}/job_id", json={"value": "v", "expires_at": target}) assert test_client.get(f"{BASE_URL}/job_id").json()["expires_at"] == target def test_put_overwrites_expiry_on_existing_key(self, test_client, time_machine): """PUT on an existing key replaces expires_at with whatever the body specifies.""" - time_machine.travel("2026-01-01T00:00:00+00:00", tick=False) - test_client.put(f"{BASE_URL}/job_id", json={"value": "v1", "expires_at": "2026-01-31T00:00:00+00:00"}) + time_machine.move_to("2026-01-01T00:00:00+00:00", tick=False) + test_client.put(f"{BASE_URL}/job_id", json={"value": "v1", "expires_at": "2026-01-31T00:00:00Z"}) # second request but with null expires_at test_client.put(f"{BASE_URL}/job_id", json={"value": "v2", "expires_at": None}) From 4aadf9711e6cc598642593b72ab9364c0fc3bfc7 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Mon, 25 May 2026 11:55:19 +0530 Subject: [PATCH 3/7] handling comments from jason --- .../core_api/openapi/v2-rest-api-generated.yaml | 5 ++++- .../api_fastapi/core_api/routes/public/task_state.py | 8 ++++---- .../src/airflow/ui/openapi-gen/queries/queries.ts | 2 +- .../src/airflow/ui/openapi-gen/requests/services.gen.ts | 2 +- .../src/airflow/ui/openapi-gen/requests/types.gen.ts | 4 ++-- .../api_fastapi/core_api/routes/public/test_task_state.py | 2 +- 6 files changed, 13 insertions(+), 10 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index a2a8afa7b25c0..ff0e82c7658c7 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -6199,8 +6199,11 @@ paths: schema: $ref: '#/components/schemas/TaskStatePatchBody' responses: - '204': + '200': description: Successful Response + content: + application/json: + schema: {} '401': content: application/json: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py index 4580c904c2643..a0d700977c5fa 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py @@ -17,7 +17,7 @@ from __future__ import annotations from datetime import datetime, timedelta, timezone -from typing import Annotated +from typing import Annotated, Literal from fastapi import Depends, HTTPException, Query, status from sqlalchemy import select @@ -50,7 +50,7 @@ def _get_scope(dag_id: str, dag_run_id: str, task_id: str, map_index: int) -> Ta return TaskScope(dag_id=dag_id, run_id=dag_run_id, task_id=task_id, map_index=map_index) -def _resolve_expires_at(expires_at: datetime | None | str) -> datetime | None: +def _resolve_expires_at(expires_at: datetime | None | Literal["default"]) -> datetime | None: """ Resolve the expires_at value from the request body. @@ -61,7 +61,7 @@ def _resolve_expires_at(expires_at: datetime | None | str) -> datetime | None: if expires_at == "default": days = conf.getint("state_store", "default_retention_days") return datetime.now(tz=timezone.utc) + timedelta(days=days) - return None + return expires_at def _require_ti(dag_id: str, dag_run_id: str, task_id: str, map_index: int, session: SessionDep) -> None: @@ -191,7 +191,7 @@ def set_task_state( @task_state_router.patch( "/{key:path}", - status_code=status.HTTP_204_NO_CONTENT, + status_code=status.HTTP_200_OK, responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE))], ) diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts index 3885f773de7cc..01e3d78e69cd7 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts @@ -2810,7 +2810,7 @@ export const usePoolServiceBulkPools = (options?: Omit { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 901300f2e4563..bfb0d8b7405c1 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -3989,7 +3989,7 @@ export type PatchTaskStateData = { taskId: string; }; -export type PatchTaskStateResponse = void; +export type PatchTaskStateResponse = unknown; export type DeleteTaskStateData = { dagId: string; @@ -7295,7 +7295,7 @@ export type $OpenApiTs = { /** * Successful Response */ - 204: void; + 200: unknown; /** * Unauthorized */ diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py index 015d442c02136..ec901581513be 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py @@ -253,7 +253,7 @@ def test_patch_updates_value(self, test_client): _create_task_state(self._session, "job_id", "v1", self.dag_run) self._session.commit() - assert test_client.patch(f"{BASE_URL}/job_id", json={"value": "v2"}).status_code == 204 + assert test_client.patch(f"{BASE_URL}/job_id", json={"value": "v2"}).status_code == 200 assert test_client.get(f"{BASE_URL}/job_id").json()["value"] == "v2" def test_patch_missing_key_returns_404(self, test_client): From 73c8b583977bc97ffb55719f7dc515a18f28e92f Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 26 May 2026 17:07:12 +0530 Subject: [PATCH 4/7] meeting contract of execution API and core API --- .../core_api/datamodels/task_state.py | 19 ++++++++- .../openapi/v2-rest-api-generated.yaml | 4 +- .../core_api/routes/public/task_state.py | 5 ++- .../ui/openapi-gen/requests/schemas.gen.ts | 4 +- .../ui/openapi-gen/requests/types.gen.ts | 2 +- .../core_api/routes/public/test_task_state.py | 42 +++++++++++++++++++ 6 files changed, 66 insertions(+), 10 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py index dc05631cf3c57..54690dcdb6b88 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py @@ -16,13 +16,17 @@ # under the License. from __future__ import annotations +import json +import math from datetime import datetime from typing import Literal -from pydantic import AwareDatetime, Field +from pydantic import AwareDatetime, Field, JsonValue, field_validator from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel +_MAX_SERIALIZED_BYTES = 65535 + class TaskStateResponse(BaseModel): """A single task state key/value pair with metadata.""" @@ -58,4 +62,15 @@ class TaskStateBody(StrictBaseModel): class TaskStatePatchBody(StrictBaseModel): """Request body for patching only the value of an existing task state key.""" - value: str = Field(max_length=65535) + value: JsonValue + + @field_validator("value") + @classmethod + def value_is_json_representable(cls, v: JsonValue) -> JsonValue: + if v is None: + raise ValueError("value cannot be null") + if isinstance(v, float) and not math.isfinite(v): + raise ValueError("value must be a finite number; NaN and Inf are not JSON representable") + if len(json.dumps(v)) > _MAX_SERIALIZED_BYTES: + raise ValueError(f"value exceeds maximum serialized size of {_MAX_SERIALIZED_BYTES} bytes") + return v diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index ff0e82c7658c7..f2112508a2d8b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -15935,9 +15935,7 @@ components: TaskStatePatchBody: properties: value: - type: string - maxLength: 65535 - title: Value + $ref: '#/components/schemas/JsonValue' additionalProperties: false type: object required: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py index a0d700977c5fa..838e7579cbcd9 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import json from datetime import datetime, timedelta, timezone from typing import Annotated, Literal @@ -224,7 +225,9 @@ def patch_task_state( ) scope = _get_scope(dag_id, dag_run_id, task_id, map_index) - get_state_backend().set(scope, key, body.value, expires_at=existing.expires_at, session=session) + get_state_backend().set( + scope, key, json.dumps(body.value), expires_at=existing.expires_at, session=session + ) @task_state_router.delete( diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index 22c2333a49a64..bd93c074cae57 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -7040,9 +7040,7 @@ export const $TaskStateCollectionResponse = { export const $TaskStatePatchBody = { properties: { value: { - type: 'string', - maxLength: 65535, - title: 'Value' + '$ref': '#/components/schemas/JsonValue' } }, additionalProperties: false, diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index bfb0d8b7405c1..b797908a0e9b1 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -1737,7 +1737,7 @@ export type TaskStateCollectionResponse = { * Request body for patching only the value of an existing task state key. */ export type TaskStatePatchBody = { - value: string; + value: JsonValue; }; /** diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py index ec901581513be..9155855547dfc 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +import json + import pytest from sqlalchemy import select @@ -264,6 +266,46 @@ def test_patch_empty_body_returns_422(self, test_client): self._session.commit() assert test_client.patch(f"{BASE_URL}/job_id", json={}).status_code == 422 + def test_patch_null_value_returns_422(self, test_client): + _create_task_state(self._session, "job_id", "v", self.dag_run) + self._session.commit() + assert test_client.patch(f"{BASE_URL}/job_id", json={"value": None}).status_code == 422 + + @pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), float("-inf")]) + def test_patch_non_finite_float_returns_422(self, test_client, bad_float): + _create_task_state(self._session, "job_id", "v", self.dag_run) + self._session.commit() + with pytest.raises(ValueError, match="Out of range float values are not JSON compliant"): + test_client.patch( + f"{BASE_URL}/job_id", + content=json.dumps({"value": bad_float}, allow_nan=True).encode(), + headers={"Content-Type": "application/json"}, + ) + + @pytest.mark.parametrize( + ("value", "expected_db"), + [ + (42, "42"), + ("hello", '"hello"'), + ({"k": 1}, '{"k": 1}'), + ([1, 2], "[1, 2]"), + ], + ) + def test_patch_stores_json_encoded_value(self, test_client, value, expected_db): + _create_task_state(self._session, "job_id", "initial", self.dag_run) + self._session.commit() + test_client.patch(f"{BASE_URL}/job_id", json={"value": value}) + row = self._session.scalar( + select(TaskStateModel).where( + TaskStateModel.dag_id == DAG_ID, + TaskStateModel.run_id == RUN_ID, + TaskStateModel.task_id == TASK_ID, + TaskStateModel.key == "job_id", + ) + ) + self._session.refresh(row) + assert row.value == expected_db + def test_unauthorized_returns_401(self, unauthenticated_test_client): assert unauthenticated_test_client.patch(f"{BASE_URL}/job_id", json={"value": "v"}).status_code == 401 From 0393f6e121298aaf5292eb2d8cb835b99db43d59 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 27 May 2026 13:00:25 +0530 Subject: [PATCH 5/7] nan validation better for PATCH too --- .../api_fastapi/core_api/datamodels/task_state.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py index 54690dcdb6b88..7e2944cc12330 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py @@ -17,7 +17,6 @@ from __future__ import annotations import json -import math from datetime import datetime from typing import Literal @@ -69,8 +68,10 @@ class TaskStatePatchBody(StrictBaseModel): def value_is_json_representable(cls, v: JsonValue) -> JsonValue: if v is None: raise ValueError("value cannot be null") - if isinstance(v, float) and not math.isfinite(v): - raise ValueError("value must be a finite number; NaN and Inf are not JSON representable") - if len(json.dumps(v)) > _MAX_SERIALIZED_BYTES: + try: + serialized = json.dumps(v, allow_nan=False) + except ValueError: + raise ValueError("value contains non-finite numbers; NaN and Inf are not JSON representable") + if len(serialized) > _MAX_SERIALIZED_BYTES: raise ValueError(f"value exceeds maximum serialized size of {_MAX_SERIALIZED_BYTES} bytes") return v From 4abb3bffecdd349fd1d9d1e528de7378b94e566f Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 29 May 2026 10:27:34 +0530 Subject: [PATCH 6/7] fixing failing tests --- .../core_api/routes/public/test_task_state.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py index 9155855547dfc..a32f423e64295 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py @@ -256,7 +256,15 @@ def test_patch_updates_value(self, test_client): self._session.commit() assert test_client.patch(f"{BASE_URL}/job_id", json={"value": "v2"}).status_code == 200 - assert test_client.get(f"{BASE_URL}/job_id").json()["value"] == "v2" + row = self._session.scalar( + select(TaskStateModel).where( + TaskStateModel.dag_id == DAG_ID, + TaskStateModel.run_id == RUN_ID, + TaskStateModel.task_id == TASK_ID, + TaskStateModel.key == "job_id", + ) + ) + assert row.value == '"v2"' def test_patch_missing_key_returns_404(self, test_client): assert test_client.patch(f"{BASE_URL}/nonexistent", json={"value": "v"}).status_code == 404 From 6c127a6ac85ec9a3dcd41032dc8a5e6b82701a4c Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 29 May 2026 10:35:21 +0530 Subject: [PATCH 7/7] testing patch validator instead --- .../core_api/routes/public/test_task_state.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py index a32f423e64295..a3a5ff028aa68 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py @@ -16,12 +16,12 @@ # under the License. from __future__ import annotations -import json - import pytest +from pydantic import ValidationError from sqlalchemy import select from airflow._shared.timezones import timezone +from airflow.api_fastapi.core_api.datamodels.task_state import TaskStatePatchBody from airflow.models.dagrun import DagRun from airflow.models.task_state import TaskStateModel from airflow.providers.standard.operators.empty import EmptyOperator @@ -279,16 +279,10 @@ def test_patch_null_value_returns_422(self, test_client): self._session.commit() assert test_client.patch(f"{BASE_URL}/job_id", json={"value": None}).status_code == 422 - @pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), float("-inf")]) - def test_patch_non_finite_float_returns_422(self, test_client, bad_float): - _create_task_state(self._session, "job_id", "v", self.dag_run) - self._session.commit() - with pytest.raises(ValueError, match="Out of range float values are not JSON compliant"): - test_client.patch( - f"{BASE_URL}/job_id", - content=json.dumps({"value": bad_float}, allow_nan=True).encode(), - headers={"Content-Type": "application/json"}, - ) + @pytest.mark.parametrize("bad_value", [float("nan"), float("inf"), {"a": float("nan")}, [float("inf")]]) + def test_patch_non_finite_float_rejected_by_validator(self, bad_value): + with pytest.raises(ValidationError, match="non-finite"): + TaskStatePatchBody(value=bad_value) @pytest.mark.parametrize( ("value", "expected_db"),