Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

import json
from datetime import datetime
from typing import Literal

from pydantic import JsonValue, field_validator
from pydantic import AwareDatetime, JsonValue, field_validator

from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel

Expand All @@ -43,7 +44,35 @@ class TaskStateCollectionResponse(BaseModel):


class TaskStateBody(StrictBaseModel):
"""Request body for setting a task state value."""
"""
Request body for setting a task state value.

``expires_at`` controls expiry:

- ``"default"``: apply the configured ``[state_store] default_retention_days``.
- ``null``: never expire.
- aware datetime: expire at that time.
"""

value: JsonValue
expires_at: AwareDatetime | None | Literal["default"] = "default"

@field_validator("value")
@classmethod
def value_is_json_representable(cls, v: JsonValue) -> JsonValue:
if v is None:
raise ValueError("value cannot be null")
try:
serialized = json.dumps(v, allow_nan=False)
except ValueError:
raise ValueError("value contains non-finite numbers; NaN and Inf are not JSON representable")
if len(serialized) > _MAX_SERIALIZED_BYTES:
raise ValueError(f"value exceeds maximum serialized size of {_MAX_SERIALIZED_BYTES} bytes")
return v


class TaskStatePatchBody(StrictBaseModel):
"""Request body for patching only the value of an existing task state key."""

value: JsonValue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6144,6 +6144,84 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
patch:
tags:
- Task State
summary: Patch Task State
description: Update the value of an existing task state key.
operationId: patch_task_state
security:
- OAuth2PasswordBearer: []
- HTTPBearer: []
parameters:
- name: dag_id
in: path
required: true
schema:
type: string
title: Dag Id
- name: dag_run_id
in: path
required: true
schema:
type: string
title: Dag Run Id
- name: task_id
in: path
required: true
schema:
type: string
title: Task Id
- name: key
in: path
required: true
schema:
type: string
title: Key
- name: map_index
in: query
required: false
schema:
type: integer
minimum: -1
default: -1
title: Map Index
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/TaskStatePatchBody'
responses:
'200':
description: Successful Response
content:
application/json:
schema: {}
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
delete:
tags:
- Task State
Expand Down Expand Up @@ -15814,12 +15892,31 @@ components:
properties:
value:
$ref: '#/components/schemas/JsonValue'
expires_at:
anyOf:
- type: string
format: date-time
- type: string
const: default
- type: 'null'
title: Expires At
default: default
additionalProperties: false
type: object
required:
- value
title: TaskStateBody
description: Request body for setting a task state value.
description: 'Request body for setting a task state value.


``expires_at`` controls expiry:


- ``"default"``: apply the configured ``[state_store] default_retention_days``.

- ``null``: never expire.

- aware datetime: expire at that time.'
TaskStateCollectionResponse:
properties:
task_states:
Expand All @@ -15836,6 +15933,17 @@ components:
- total_entries
title: TaskStateCollectionResponse
description: All task state entries for a task instance.
TaskStatePatchBody:
properties:
value:
$ref: '#/components/schemas/JsonValue'
additionalProperties: false
type: object
required:
- value
title: TaskStatePatchBody
description: Request body for patching only the value of an existing task state
key.
TaskStateResponse:
properties:
key:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from __future__ import annotations

import json
from typing import Annotated
from datetime import datetime, timedelta, timezone
Comment thread
jason810496 marked this conversation as resolved.
from typing import Annotated, Literal

from fastapi import Depends, HTTPException, Query, status
from sqlalchemy import select
Expand All @@ -30,10 +31,12 @@
from airflow.api_fastapi.core_api.datamodels.task_state import (
TaskStateBody,
TaskStateCollectionResponse,
TaskStatePatchBody,
TaskStateResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_dag
from airflow.configuration import conf
from airflow.models.task_state import TaskStateModel
from airflow.models.taskinstance import TaskInstance as TI
from airflow.state import get_state_backend
Expand All @@ -48,6 +51,36 @@ def _get_scope(dag_id: str, dag_run_id: str, task_id: str, map_index: int) -> Ta
return TaskScope(dag_id=dag_id, run_id=dag_run_id, task_id=task_id, map_index=map_index)


def _resolve_expires_at(expires_at: datetime | None | Literal["default"]) -> datetime | None:
"""
Resolve the expires_at value from the request body.

- ``"default"``: apply configured default_retention_days
- ``None``: never expire
- datetime: use as-is
"""
if expires_at == "default":
days = conf.getint("state_store", "default_retention_days")
return datetime.now(tz=timezone.utc) + timedelta(days=days)
return expires_at


def _require_ti(dag_id: str, dag_run_id: str, task_id: str, map_index: int, session: SessionDep) -> None:
ti_exists = session.scalar(
select(TI.task_id).where(
TI.dag_id == dag_id,
TI.run_id == dag_run_id,
TI.task_id == task_id,
TI.map_index == map_index,
)
)
if ti_exists is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Task instance not found for dag_id={dag_id!r}, run_id={dag_run_id!r}, task_id={task_id!r}, map_index={map_index}",
)


@task_state_router.get(
"",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
Expand Down Expand Up @@ -150,24 +183,53 @@ def set_task_state(
map_index: Annotated[int, Query(ge=-1)] = -1,
) -> None:
"""Set a task state value. Creates or overwrites the key."""
ti_exists = session.scalar(
select(TI.task_id).where(
TI.dag_id == dag_id,
TI.run_id == dag_run_id,
TI.task_id == task_id,
TI.map_index == map_index,
_require_ti(dag_id, dag_run_id, task_id, map_index, session)
expires_at = _resolve_expires_at(body.expires_at)
scope = _get_scope(dag_id, dag_run_id, task_id, map_index)
try:
get_state_backend().set(scope, key, json.dumps(body.value), expires_at=expires_at, session=session)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e


@task_state_router.patch(
"/{key:path}",
status_code=status.HTTP_200_OK,
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE))],
)
def patch_task_state(
dag_id: str,
dag_run_id: str,
task_id: str,
key: str,
body: TaskStatePatchBody,
session: SessionDep,
map_index: Annotated[int, Query(ge=-1)] = -1,
) -> None:
"""Update the value of an existing task state key."""
_require_ti(dag_id, dag_run_id, task_id, map_index, session)

existing = session.execute(
select(TaskStateModel.expires_at).where(
TaskStateModel.dag_id == dag_id,
TaskStateModel.run_id == dag_run_id,
TaskStateModel.task_id == task_id,
TaskStateModel.map_index == map_index,
TaskStateModel.key == key,
)
)
if ti_exists is None:
).one_or_none()

if existing is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Task instance not found for dag_id={dag_id!r}, run_id={dag_run_id!r}, task_id={task_id!r}, map_index={map_index}",
detail=f"Task state key {key!r} not found",
)

scope = _get_scope(dag_id, dag_run_id, task_id, map_index)
try:
get_state_backend().set(scope, key, json.dumps(body.value), session=session)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
get_state_backend().set(
scope, key, json.dumps(body.value), expires_at=existing.expires_at, session=session
)


@task_state_router.delete(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,7 @@ export type TaskInstanceServicePatchTaskInstanceDryRunMutationResult = Awaited<R
export type TaskInstanceServiceUpdateHitlDetailMutationResult = Awaited<ReturnType<typeof TaskInstanceService.updateHitlDetail>>;
export type PoolServicePatchPoolMutationResult = Awaited<ReturnType<typeof PoolService.patchPool>>;
export type PoolServiceBulkPoolsMutationResult = Awaited<ReturnType<typeof PoolService.bulkPools>>;
export type TaskStateServicePatchTaskStateMutationResult = Awaited<ReturnType<typeof TaskStateService.patchTaskState>>;
export type XcomServiceUpdateXcomEntryMutationResult = Awaited<ReturnType<typeof XcomService.updateXcomEntry>>;
export type VariableServicePatchVariableMutationResult = Awaited<ReturnType<typeof VariableService.patchVariable>>;
export type VariableServiceBulkVariablesMutationResult = Awaited<ReturnType<typeof VariableService.bulkVariables>>;
Expand Down
30 changes: 29 additions & 1 deletion airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import { UseMutationOptions, UseQueryOptions, useMutation, useQuery } from "@tanstack/react-query";
import { AssetService, AssetStateService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DeadlinesService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GanttService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PartitionedDagRunService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, TaskStateService, TeamsService, VariableService, VersionService, XcomService } from "../requests/services.gen";
import { AssetStateBody, BackfillPostBody, BulkBody_BulkDAGRunBody_, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, MaterializeAssetBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TaskStateBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen";
import { AssetStateBody, BackfillPostBody, BulkBody_BulkDAGRunBody_, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, MaterializeAssetBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TaskStateBody, TaskStatePatchBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen";
import * as Common from "./common";
/**
* Get Assets
Expand Down Expand Up @@ -2801,6 +2801,34 @@ export const usePoolServiceBulkPools = <TData = Common.PoolServiceBulkPoolsMutat
requestBody: BulkBody_PoolBody_;
}, TContext>({ mutationFn: ({ requestBody }) => PoolService.bulkPools({ requestBody }) as unknown as Promise<TData>, ...options });
/**
* Patch Task State
* Update the value of an existing task state key.
* @param data The data for the request.
* @param data.dagId
* @param data.dagRunId
* @param data.taskId
* @param data.key
* @param data.requestBody
* @param data.mapIndex
* @returns unknown Successful Response
* @throws ApiError
*/
export const useTaskStateServicePatchTaskState = <TData = Common.TaskStateServicePatchTaskStateMutationResult, TError = unknown, TContext = unknown>(options?: Omit<UseMutationOptions<TData, TError, {
dagId: string;
dagRunId: string;
key: string;
mapIndex?: number;
requestBody: TaskStatePatchBody;
taskId: string;
}, TContext>, "mutationFn">) => useMutation<TData, TError, {
dagId: string;
dagRunId: string;
key: string;
mapIndex?: number;
requestBody: TaskStatePatchBody;
taskId: string;
}, TContext>({ mutationFn: ({ dagId, dagRunId, key, mapIndex, requestBody, taskId }) => TaskStateService.patchTaskState({ dagId, dagRunId, key, mapIndex, requestBody, taskId }) as unknown as Promise<TData>, ...options });
/**
* Update Xcom Entry
* Update an existing XCom entry.
* @param data The data for the request.
Expand Down
Loading
Loading