diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py index e6622f842e116..16d2143523078 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py @@ -18,8 +18,9 @@ import json from datetime import datetime +from typing import Literal -from pydantic import JsonValue, field_validator +from pydantic import AwareDatetime, JsonValue, field_validator from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel @@ -43,7 +44,35 @@ class TaskStateCollectionResponse(BaseModel): class TaskStateBody(StrictBaseModel): - """Request body for setting a task state value.""" + """ + Request body for setting a task state value. + + ``expires_at`` controls expiry: + + - ``"default"``: apply the configured ``[state_store] default_retention_days``. + - ``null``: never expire. + - aware datetime: expire at that time. + """ + + value: JsonValue + expires_at: AwareDatetime | None | Literal["default"] = "default" + + @field_validator("value") + @classmethod + def value_is_json_representable(cls, v: JsonValue) -> JsonValue: + if v is None: + raise ValueError("value cannot be null") + try: + serialized = json.dumps(v, allow_nan=False) + except ValueError: + raise ValueError("value contains non-finite numbers; NaN and Inf are not JSON representable") + if len(serialized) > _MAX_SERIALIZED_BYTES: + raise ValueError(f"value exceeds maximum serialized size of {_MAX_SERIALIZED_BYTES} bytes") + return v + + +class TaskStatePatchBody(StrictBaseModel): + """Request body for patching only the value of an existing task state key.""" value: JsonValue diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 86017f12a60f8..8c1097adbb1d8 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -6144,6 +6144,84 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + patch: + tags: + - Task State + summary: Patch Task State + description: Update the value of an existing task state key. + operationId: patch_task_state + security: + - OAuth2PasswordBearer: [] + - HTTPBearer: [] + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + - name: key + in: path + required: true + schema: + type: string + title: Key + - name: map_index + in: query + required: false + schema: + type: integer + minimum: -1 + default: -1 + title: Map Index + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/TaskStatePatchBody' + responses: + '200': + description: Successful Response + content: + application/json: + schema: {} + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' delete: tags: - Task State @@ -15814,12 +15892,31 @@ components: properties: value: $ref: '#/components/schemas/JsonValue' + expires_at: + anyOf: + - type: string + format: date-time + - type: string + const: default + - type: 'null' + title: Expires At + default: default additionalProperties: false type: object required: - value title: TaskStateBody - description: Request body for setting a task state value. + description: 'Request body for setting a task state value. + + + ``expires_at`` controls expiry: + + + - ``"default"``: apply the configured ``[state_store] default_retention_days``. + + - ``null``: never expire. + + - aware datetime: expire at that time.' TaskStateCollectionResponse: properties: task_states: @@ -15836,6 +15933,17 @@ components: - total_entries title: TaskStateCollectionResponse description: All task state entries for a task instance. + TaskStatePatchBody: + properties: + value: + $ref: '#/components/schemas/JsonValue' + additionalProperties: false + type: object + required: + - value + title: TaskStatePatchBody + description: Request body for patching only the value of an existing task state + key. TaskStateResponse: properties: key: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py index 31cc7272ddca6..3ca667336f486 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py @@ -17,7 +17,8 @@ from __future__ import annotations import json -from typing import Annotated +from datetime import datetime, timedelta, timezone +from typing import Annotated, Literal from fastapi import Depends, HTTPException, Query, status from sqlalchemy import select @@ -30,10 +31,12 @@ from airflow.api_fastapi.core_api.datamodels.task_state import ( TaskStateBody, TaskStateCollectionResponse, + TaskStatePatchBody, TaskStateResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.api_fastapi.core_api.security import requires_access_dag +from airflow.configuration import conf from airflow.models.task_state import TaskStateModel from airflow.models.taskinstance import TaskInstance as TI from airflow.state import get_state_backend @@ -48,6 +51,36 @@ def _get_scope(dag_id: str, dag_run_id: str, task_id: str, map_index: int) -> Ta return TaskScope(dag_id=dag_id, run_id=dag_run_id, task_id=task_id, map_index=map_index) +def _resolve_expires_at(expires_at: datetime | None | Literal["default"]) -> datetime | None: + """ + Resolve the expires_at value from the request body. + + - ``"default"``: apply configured default_retention_days + - ``None``: never expire + - datetime: use as-is + """ + if expires_at == "default": + days = conf.getint("state_store", "default_retention_days") + return datetime.now(tz=timezone.utc) + timedelta(days=days) + return expires_at + + +def _require_ti(dag_id: str, dag_run_id: str, task_id: str, map_index: int, session: SessionDep) -> None: + ti_exists = session.scalar( + select(TI.task_id).where( + TI.dag_id == dag_id, + TI.run_id == dag_run_id, + TI.task_id == task_id, + TI.map_index == map_index, + ) + ) + if ti_exists is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Task instance not found for dag_id={dag_id!r}, run_id={dag_run_id!r}, task_id={task_id!r}, map_index={map_index}", + ) + + @task_state_router.get( "", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), @@ -150,24 +183,53 @@ def set_task_state( map_index: Annotated[int, Query(ge=-1)] = -1, ) -> None: """Set a task state value. Creates or overwrites the key.""" - ti_exists = session.scalar( - select(TI.task_id).where( - TI.dag_id == dag_id, - TI.run_id == dag_run_id, - TI.task_id == task_id, - TI.map_index == map_index, + _require_ti(dag_id, dag_run_id, task_id, map_index, session) + expires_at = _resolve_expires_at(body.expires_at) + scope = _get_scope(dag_id, dag_run_id, task_id, map_index) + try: + get_state_backend().set(scope, key, json.dumps(body.value), expires_at=expires_at, session=session) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + + +@task_state_router.patch( + "/{key:path}", + status_code=status.HTTP_200_OK, + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE))], +) +def patch_task_state( + dag_id: str, + dag_run_id: str, + task_id: str, + key: str, + body: TaskStatePatchBody, + session: SessionDep, + map_index: Annotated[int, Query(ge=-1)] = -1, +) -> None: + """Update the value of an existing task state key.""" + _require_ti(dag_id, dag_run_id, task_id, map_index, session) + + existing = session.execute( + select(TaskStateModel.expires_at).where( + TaskStateModel.dag_id == dag_id, + TaskStateModel.run_id == dag_run_id, + TaskStateModel.task_id == task_id, + TaskStateModel.map_index == map_index, + TaskStateModel.key == key, ) - ) - if ti_exists is None: + ).one_or_none() + + if existing is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Task instance not found for dag_id={dag_id!r}, run_id={dag_run_id!r}, task_id={task_id!r}, map_index={map_index}", + detail=f"Task state key {key!r} not found", ) + scope = _get_scope(dag_id, dag_run_id, task_id, map_index) - try: - get_state_backend().set(scope, key, json.dumps(body.value), session=session) - except ValueError as e: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + get_state_backend().set( + scope, key, json.dumps(body.value), expires_at=existing.expires_at, session=session + ) @task_state_router.delete( diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts index f8e3dbe9af638..632d311157631 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts @@ -1067,6 +1067,7 @@ export type TaskInstanceServicePatchTaskInstanceDryRunMutationResult = Awaited>; export type PoolServicePatchPoolMutationResult = Awaited>; export type PoolServiceBulkPoolsMutationResult = Awaited>; +export type TaskStateServicePatchTaskStateMutationResult = Awaited>; export type XcomServiceUpdateXcomEntryMutationResult = Awaited>; export type VariableServicePatchVariableMutationResult = Awaited>; export type VariableServiceBulkVariablesMutationResult = Awaited>; diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts index 8c0976ec328e9..01e3d78e69cd7 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts @@ -2,7 +2,7 @@ import { UseMutationOptions, UseQueryOptions, useMutation, useQuery } from "@tanstack/react-query"; import { AssetService, AssetStateService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DeadlinesService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GanttService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PartitionedDagRunService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, TaskStateService, TeamsService, VariableService, VersionService, XcomService } from "../requests/services.gen"; -import { AssetStateBody, BackfillPostBody, BulkBody_BulkDAGRunBody_, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, MaterializeAssetBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TaskStateBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen"; +import { AssetStateBody, BackfillPostBody, BulkBody_BulkDAGRunBody_, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, MaterializeAssetBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TaskStateBody, TaskStatePatchBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen"; import * as Common from "./common"; /** * Get Assets @@ -2801,6 +2801,34 @@ export const usePoolServiceBulkPools = ({ mutationFn: ({ requestBody }) => PoolService.bulkPools({ requestBody }) as unknown as Promise, ...options }); /** +* Patch Task State +* Update the value of an existing task state key. +* @param data The data for the request. +* @param data.dagId +* @param data.dagRunId +* @param data.taskId +* @param data.key +* @param data.requestBody +* @param data.mapIndex +* @returns unknown Successful Response +* @throws ApiError +*/ +export const useTaskStateServicePatchTaskState = (options?: Omit, "mutationFn">) => useMutation({ mutationFn: ({ dagId, dagRunId, key, mapIndex, requestBody, taskId }) => TaskStateService.patchTaskState({ dagId, dagRunId, key, mapIndex, requestBody, taskId }) as unknown as Promise, ...options }); +/** * Update Xcom Entry * Update an existing XCom entry. * @param data The data for the request. diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index b00c1833c164b..742a18b09efc5 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -6992,13 +6992,36 @@ export const $TaskStateBody = { properties: { value: { '$ref': '#/components/schemas/JsonValue' + }, + expires_at: { + anyOf: [ + { + type: 'string', + format: 'date-time' + }, + { + type: 'string', + const: 'default' + }, + { + type: 'null' + } + ], + title: 'Expires At', + default: 'default' } }, additionalProperties: false, type: 'object', required: ['value'], title: 'TaskStateBody', - description: 'Request body for setting a task state value.' + description: `Request body for setting a task state value. + +\`\`expires_at\`\` controls expiry: + +- \`\`"default"\`\`: apply the configured \`\`[state_store] default_retention_days\`\`. +- \`\`null\`\`: never expire. +- aware datetime: expire at that time.` } as const; export const $TaskStateCollectionResponse = { @@ -7021,6 +7044,19 @@ export const $TaskStateCollectionResponse = { description: 'All task state entries for a task instance.' } as const; +export const $TaskStatePatchBody = { + properties: { + value: { + '$ref': '#/components/schemas/JsonValue' + } + }, + additionalProperties: false, + type: 'object', + required: ['value'], + title: 'TaskStatePatchBody', + description: 'Request body for patching only the value of an existing task state key.' +} as const; + export const $TaskStateResponse = { properties: { key: { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts index e65f313e1d93f..3e65983530399 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3,7 +3,7 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, BulkDagRunsData, BulkDagRunsResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagRunStatsData, GetDagRunStatsResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskGroupInstancesData, PatchTaskGroupInstancesResponse, PatchTaskGroupInstancesDryRunData, PatchTaskGroupInstancesDryRunResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, ListAssetStatesData, ListAssetStatesResponse, ClearAssetStateData, ClearAssetStateResponse, GetAssetStateData, GetAssetStateResponse, SetAssetStateData, SetAssetStateResponse, DeleteAssetStateData, DeleteAssetStateResponse, ListTaskStatesData, ListTaskStatesResponse, ClearTaskStateData, ClearTaskStateResponse, GetTaskStateData, GetTaskStateResponse, SetTaskStateData, SetTaskStateResponse, DeleteTaskStateData, DeleteTaskStateResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDeadlinesData, GetDeadlinesResponse, GetDagDeadlineAlertsData, GetDagDeadlineAlertsResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen'; +import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, BulkDagRunsData, BulkDagRunsResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagRunStatsData, GetDagRunStatsResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskGroupInstancesData, PatchTaskGroupInstancesResponse, PatchTaskGroupInstancesDryRunData, PatchTaskGroupInstancesDryRunResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, ListAssetStatesData, ListAssetStatesResponse, ClearAssetStateData, ClearAssetStateResponse, GetAssetStateData, GetAssetStateResponse, SetAssetStateData, SetAssetStateResponse, DeleteAssetStateData, DeleteAssetStateResponse, ListTaskStatesData, ListTaskStatesResponse, ClearTaskStateData, ClearTaskStateResponse, GetTaskStateData, GetTaskStateResponse, SetTaskStateData, SetTaskStateResponse, PatchTaskStateData, PatchTaskStateResponse, DeleteTaskStateData, DeleteTaskStateResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDeadlinesData, GetDeadlinesResponse, GetDagDeadlineAlertsData, GetDagDeadlineAlertsResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen'; export class AssetService { /** @@ -3813,6 +3813,43 @@ export class TaskStateService { }); } + /** + * Patch Task State + * Update the value of an existing task state key. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.key + * @param data.requestBody + * @param data.mapIndex + * @returns unknown Successful Response + * @throws ApiError + */ + public static patchTaskState(data: PatchTaskStateData): CancelablePromise { + return __request(OpenAPI, { + method: 'PATCH', + url: '/api/v2/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/states/{key}', + path: { + dag_id: data.dagId, + dag_run_id: data.dagRunId, + task_id: data.taskId, + key: data.key + }, + query: { + map_index: data.mapIndex + }, + body: data.requestBody, + mediaType: 'application/json', + errors: { + 401: 'Unauthorized', + 403: 'Forbidden', + 404: 'Not Found', + 422: 'Validation Error' + } + }); + } + /** * Delete Task State * Delete a single task state key. No-op if the key does not exist. diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index e674f27e8a30a..16929844a942e 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -1714,9 +1714,16 @@ export type TaskResponse = { /** * Request body for setting a task state value. + * + * ``expires_at`` controls expiry: + * + * - ``"default"``: apply the configured ``[state_store] default_retention_days``. + * - ``null``: never expire. + * - aware datetime: expire at that time. */ export type TaskStateBody = { value: JsonValue; + expires_at?: string | "default" | null; }; /** @@ -1727,6 +1734,13 @@ export type TaskStateCollectionResponse = { total_entries: number; }; +/** + * Request body for patching only the value of an existing task state key. + */ +export type TaskStatePatchBody = { + value: JsonValue; +}; + /** * A single task state key/value pair with metadata. */ @@ -3968,6 +3982,17 @@ export type SetTaskStateData = { export type SetTaskStateResponse = void; +export type PatchTaskStateData = { + dagId: string; + dagRunId: string; + key: string; + mapIndex?: number; + requestBody: TaskStatePatchBody; + taskId: string; +}; + +export type PatchTaskStateResponse = unknown; + export type DeleteTaskStateData = { dagId: string; dagRunId: string; @@ -7262,6 +7287,31 @@ export type $OpenApiTs = { 422: HTTPValidationError; }; }; + patch: { + req: PatchTaskStateData; + res: { + /** + * Successful Response + */ + 200: unknown; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; delete: { req: DeleteTaskStateData; res: { diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py index e96aca22cddd9..8481a6fc43b31 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py @@ -23,12 +23,13 @@ from sqlalchemy import select from airflow._shared.timezones import timezone -from airflow.api_fastapi.core_api.datamodels.task_state import TaskStateBody +from airflow.api_fastapi.core_api.datamodels.task_state import TaskStateBody, TaskStatePatchBody from airflow.models.dagrun import DagRun from airflow.models.task_state import TaskStateModel from airflow.providers.standard.operators.empty import EmptyOperator from airflow.utils.types import DagRunType +from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, clear_db_runs pytestmark = pytest.mark.db_test @@ -258,10 +259,104 @@ def test_key_with_slash_is_supported(self, test_client): assert response.status_code == 204 assert test_client.get(f"{BASE_URL}/workflow/step_1").json()["key"] == "workflow/step_1" + def test_new_key_default_retention_applies_config(self, test_client, time_machine): + time_machine.move_to("2026-01-01T00:00:00+00:00", tick=False) + with conf_vars({("state_store", "default_retention_days"): "7"}): + test_client.put(f"{BASE_URL}/job_id", json={"value": "v", "expires_at": "default"}) + + resp = test_client.get(f"{BASE_URL}/job_id").json() + assert resp["expires_at"] == "2026-01-08T00:00:00Z" + + def test_new_key_never_expiry(self, test_client): + """PUT with expires_at=null stores a key that never expires.""" + test_client.put(f"{BASE_URL}/job_id", json={"value": "v", "expires_at": None}) + assert test_client.get(f"{BASE_URL}/job_id").json()["expires_at"] is None + + def test_new_key_explicit_expiry(self, test_client, time_machine): + """PUT with an explicit datetime uses that as expires_at.""" + time_machine.move_to("2026-01-01T00:00:00+00:00", tick=False) + target = "2026-01-31T00:00:00Z" + test_client.put(f"{BASE_URL}/job_id", json={"value": "v", "expires_at": target}) + assert test_client.get(f"{BASE_URL}/job_id").json()["expires_at"] == target + + def test_put_overwrites_expiry_on_existing_key(self, test_client, time_machine): + """PUT on an existing key replaces expires_at with whatever the body specifies.""" + time_machine.move_to("2026-01-01T00:00:00+00:00", tick=False) + test_client.put(f"{BASE_URL}/job_id", json={"value": "v1", "expires_at": "2026-01-31T00:00:00Z"}) + + # second request but with null expires_at + test_client.put(f"{BASE_URL}/job_id", json={"value": "v2", "expires_at": None}) + + resp = test_client.get(f"{BASE_URL}/job_id").json() + assert resp["value"] == "v2" + assert resp["expires_at"] is None + def test_unauthorized_returns_401(self, unauthenticated_test_client): assert unauthenticated_test_client.put(f"{BASE_URL}/job_id", json={"value": "v"}).status_code == 401 +class TestPatchTaskState(TestTaskStateEndpoint): + def test_patch_updates_value(self, test_client): + _create_task_state(self._session, "job_id", "v1", self.dag_run) + self._session.commit() + + assert test_client.patch(f"{BASE_URL}/job_id", json={"value": "v2"}).status_code == 200 + row = self._session.scalar( + select(TaskStateModel).where( + TaskStateModel.dag_id == DAG_ID, + TaskStateModel.run_id == RUN_ID, + TaskStateModel.task_id == TASK_ID, + TaskStateModel.key == "job_id", + ) + ) + assert row.value == '"v2"' + + def test_patch_missing_key_returns_404(self, test_client): + assert test_client.patch(f"{BASE_URL}/nonexistent", json={"value": "v"}).status_code == 404 + + def test_patch_empty_body_returns_422(self, test_client): + _create_task_state(self._session, "job_id", "v", self.dag_run) + self._session.commit() + assert test_client.patch(f"{BASE_URL}/job_id", json={}).status_code == 422 + + def test_patch_null_value_returns_422(self, test_client): + _create_task_state(self._session, "job_id", "v", self.dag_run) + self._session.commit() + assert test_client.patch(f"{BASE_URL}/job_id", json={"value": None}).status_code == 422 + + @pytest.mark.parametrize("bad_value", [float("nan"), float("inf"), {"a": float("nan")}, [float("inf")]]) + def test_patch_non_finite_float_rejected_by_validator(self, bad_value): + with pytest.raises(ValidationError, match="non-finite"): + TaskStatePatchBody(value=bad_value) + + @pytest.mark.parametrize( + ("value", "expected_db"), + [ + (42, "42"), + ("hello", '"hello"'), + ({"k": 1}, '{"k": 1}'), + ([1, 2], "[1, 2]"), + ], + ) + def test_patch_stores_json_encoded_value(self, test_client, value, expected_db): + _create_task_state(self._session, "job_id", "initial", self.dag_run) + self._session.commit() + test_client.patch(f"{BASE_URL}/job_id", json={"value": value}) + row = self._session.scalar( + select(TaskStateModel).where( + TaskStateModel.dag_id == DAG_ID, + TaskStateModel.run_id == RUN_ID, + TaskStateModel.task_id == TASK_ID, + TaskStateModel.key == "job_id", + ) + ) + self._session.refresh(row) + assert row.value == expected_db + + def test_unauthorized_returns_401(self, unauthenticated_test_client): + assert unauthenticated_test_client.patch(f"{BASE_URL}/job_id", json={"value": "v"}).status_code == 401 + + class TestDeleteTaskState(TestTaskStateEndpoint): def test_deletes_key(self, test_client): _create_task_state(self._session, "job_id", "spark_001", self.dag_run) diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index 63444c465d94d..4ad4fd39cb907 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -956,6 +956,24 @@ class TaskOutletAssetReference(BaseModel): class TaskStateBody(BaseModel): """ Request body for setting a task state value. + + ``expires_at`` controls expiry: + + - ``"default"``: apply the configured ``[state_store] default_retention_days``. + - ``null``: never expire. + - aware datetime: expire at that time. + """ + + model_config = ConfigDict( + extra="forbid", + ) + value: JsonValue + expires_at: Annotated[datetime | str | None, Field(title="Expires At")] = "default" + + +class TaskStatePatchBody(BaseModel): + """ + Request body for patching only the value of an existing task state key. """ model_config = ConfigDict(