Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,12 @@ Override four serialization hooks from :class:`~airflow.sdk.state.BaseStoreBacke

* ``serialize_task_store_to_ref``: called by ``TaskStoreAccessor.set()`` before the value is sent to the Execution API; return a compact reference string (e.g. an S3 key) to be stored in the database instead of the raw value.
* ``deserialize_task_store_from_ref``: called by ``TaskStoreAccessor.get()`` after retrieving the reference from the backend; return the actual value.
* ``serialize_asset_store_to_ref``: same as the task variant but for asset store; receives the asset name or URI as ``asset_ref``.
* ``serialize_asset_store_to_ref``: same as the task variant but for asset store; receives the asset scope as ``scope`` (an :class:`~airflow.sdk.state.AssetScope` with ``name`` and/or ``uri``).
* ``deserialize_asset_store_from_ref``: called by ``AssetStoreAccessor.get()`` to resolve the stored reference back to the actual value.

.. important::

**References must be deterministic.** Given the same inputs (``ti_id`` + ``key`` for task store; ``asset_ref`` + ``key`` for asset store), the serialization method must always return the same reference string. Do not embed timestamps, random UUIDs, or any other non-deterministic component in the reference path.
**References must be deterministic.** Given the same inputs (``scope`` + ``key``), the serialization method must always return the same reference string. Do not embed timestamps, random UUIDs, or any other non-deterministic component in the reference path.

When a key is deleted or cleared, Airflow clears the database reference *first*, then calls the backend's ``delete()`` or ``clear()`` method. If backend cleanup fails after the DB row is gone, the external object is orphaned. Because the reference is deterministic, a subsequent ``set()`` for the same key will overwrite the orphaned object, making the failure recoverable. A non-deterministic reference would leave the external object permanently orphaned with no way to locate it.

Expand All @@ -178,26 +178,27 @@ Example skeleton:

class S3StateBackend(BaseStoreBackend):

def _task_ref(self, ti_id: str, key: str) -> str:
return f"airflow/task-store/{ti_id}/{key}"
def _task_ref(self, scope: TaskScope, key: str) -> str:
return f"airflow/task-store/{scope.dag_id}/{scope.run_id}/{scope.task_id}/{scope.map_index}/{key}"

def _asset_ref(self, asset_ref: str, key: str) -> str:
def _asset_ref(self, scope: AssetScope, key: str) -> str:
import hashlib

safe = hashlib.sha256(asset_ref.encode()).hexdigest()[:16]
asset_identifier = scope.name or scope.uri or ""
safe = hashlib.sha256(asset_identifier.encode()).hexdigest()[:16]
return f"airflow/asset-store/{safe}/{key}"

def serialize_task_store_to_ref(self, *, value: JsonValue, key: str, ti_id: str) -> str:
s3_key = self._task_ref(ti_id, key)
def serialize_task_store_to_ref(self, *, value: JsonValue, key: str, scope: TaskScope) -> str:
s3_key = self._task_ref(scope, key)
s3_client.put_object(Bucket=BUCKET, Key=s3_key, Body=json.dumps(value).encode())
return s3_key

def deserialize_task_store_from_ref(self, stored: str) -> JsonValue:
s3_object = s3_client.get_object(Bucket=BUCKET, Key=stored)
return json.loads(s3_object["Body"].read().decode())

def serialize_asset_store_to_ref(self, *, value: JsonValue, key: str, asset_ref: str) -> str:
s3_key = self._asset_ref(asset_ref, key)
def serialize_asset_store_to_ref(self, *, value: JsonValue, key: str, scope: AssetScope) -> str:
s3_key = self._asset_ref(scope, key)
s3_client.put_object(Bucket=BUCKET, Key=s3_key, Body=json.dumps(value).encode())
return s3_key

Expand Down
11 changes: 4 additions & 7 deletions shared/state/src/airflow_shared/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def cleanup(self) -> None:
``[state_store] default_retention_days``) and deciding what to delete.
"""

def serialize_task_store_to_ref(self, *, value: JsonValue, key: str, ti_id: str) -> str:
def serialize_task_store_to_ref(self, *, value: JsonValue, key: str, scope: TaskScope) -> str:
"""
Serialize a task store value before it is sent to the execution API for db persistence.

Expand All @@ -260,7 +260,7 @@ def serialize_task_store_to_ref(self, *, value: JsonValue, key: str, ti_id: str)
that wrapper before passing ``stored`` to ``deserialize_task_store_from_ref()``. Do not
wrap the reference yourself.

The returned reference must be deterministic — given the same ``ti_id`` and ``key`` it
The returned reference must be deterministic — given the same ``scope`` and ``key`` it
must always return the same string. Do not use timestamps or random UUIDs as part of
the reference, otherwise ``delete()``/``clear()`` cannot reconstruct it and the external
object will be orphaned. By default, it JSON dumps the value and returns a JSON string.
Expand All @@ -277,7 +277,7 @@ def deserialize_task_store_from_ref(self, stored: str) -> JsonValue:
"""
return json.loads(stored)

def serialize_asset_store_to_ref(self, *, value: JsonValue, key: str, asset_ref: str) -> str:
def serialize_asset_store_to_ref(self, *, value: JsonValue, key: str, scope: AssetScope) -> str:
"""
Serialize an asset store value before it is sent to the Execution API for db persistence.

Expand All @@ -290,10 +290,7 @@ def serialize_asset_store_to_ref(self, *, value: JsonValue, key: str, asset_ref:
that wrapper before passing ``stored`` to ``deserialize_asset_store_from_ref()``. Do not
wrap the reference yourself.

``asset_ref`` is either the asset name or URI, depending on how the accessor was
constructed. It may be a URI string if the task inlet was declared as ``AssetUriRef``.

The returned reference must be deterministic — given the same ``asset_ref`` and ``key`` it
The returned reference must be deterministic — given the same ``scope`` and ``key`` it
must always return the same string. Do not use timestamps or random UUIDs as part of
the reference, otherwise ``delete()``/``clear()`` cannot reconstruct it and the external
object will be orphaned. By default, it JSON dumps the value and returns a JSON string.
Expand Down
40 changes: 22 additions & 18 deletions shared/state/tests/state/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import pytest

from airflow_shared.state import AssetScope, BaseStoreBackend, StoreScope
from airflow_shared.state import AssetScope, BaseStoreBackend, StoreScope, TaskScope


class TestAssetScope:
Expand Down Expand Up @@ -88,20 +88,22 @@ def test_abstract_methods_cover_full_interface(self):

def test_task_store_serialize_deserialize_round_trip(self, backend):
original = "app_1234"
serialized = backend.serialize_task_store_to_ref(value=original, key="job_id", ti_id="abc-123")
scope = TaskScope(dag_id="d", run_id="r", task_id="t", map_index=-1)
serialized = backend.serialize_task_store_to_ref(value=original, key="job_id", scope=scope)
deserialized = backend.deserialize_task_store_from_ref(serialized)
assert deserialized == original

def test_task_store_serialize_deserialize_typed_values(self, backend):
"""Default backend passes typed values through unchanged (custom backends handle storage)."""
scope = TaskScope(dag_id="d", run_id="r", task_id="t", map_index=-1)
assert (
backend.deserialize_task_store_from_ref(
backend.serialize_task_store_to_ref(value=42, key="count", ti_id="abc-123")
backend.serialize_task_store_to_ref(value=42, key="count", scope=scope)
)
== 42
)
assert backend.deserialize_task_store_from_ref(
backend.serialize_task_store_to_ref(value={"status": "ok"}, key="result", ti_id="abc-123")
backend.serialize_task_store_to_ref(value={"status": "ok"}, key="result", scope=scope)
) == {"status": "ok"}

def test_custom_backend_overrides_task_store_ser_deser(self):
Expand All @@ -115,38 +117,39 @@ async def aset(self, scope, key, value): ...
async def adelete(self, scope, key): ...
async def aclear(self, scope, *, all_map_indices=False): ...

def serialize_task_store_to_ref(self, *, value, key, ti_id):
return f"s3://bucket/{ti_id}/{key}"
def serialize_task_store_to_ref(self, *, value, key, scope: TaskScope):
return f"s3://bucket/{scope.dag_id}/{scope.task_id}/{key}"

def deserialize_task_store_from_ref(self, stored):
return f"fetched:{stored}"

b = MyBackend()
assert b.serialize_task_store_to_ref(value="app_1234", key="job_id", ti_id="abc-123") == (
"s3://bucket/abc-123/job_id"
scope = TaskScope(dag_id="my_dag", run_id="r", task_id="my_task", map_index=-1)
assert b.serialize_task_store_to_ref(value="app_1234", key="job_id", scope=scope) == (
"s3://bucket/my_dag/my_task/job_id"
)
assert (
b.deserialize_task_store_from_ref("s3://bucket/abc-123/job_id")
== "fetched:s3://bucket/abc-123/job_id"
b.deserialize_task_store_from_ref("s3://bucket/my_dag/my_task/job_id")
== "fetched:s3://bucket/my_dag/my_task/job_id"
)

def test_asset_store_serialize_deserialize_round_trip(self, backend):
original = "2026-05-01"
serialized = backend.serialize_asset_store_to_ref(
value="2026-05-01", key="watermark", asset_ref="my_asset"
)
scope = AssetScope(name="my_asset")
serialized = backend.serialize_asset_store_to_ref(value="2026-05-01", key="watermark", scope=scope)
deserialized = backend.deserialize_asset_store_from_ref(serialized)
assert deserialized == original

def test_asset_store_serialize_deserialize_typed_values(self, backend):
scope = AssetScope(name="my_asset")
assert (
backend.deserialize_asset_store_from_ref(
backend.serialize_asset_store_to_ref(value=5, key="total_runs", asset_ref="my_asset")
backend.serialize_asset_store_to_ref(value=5, key="total_runs", scope=scope)
)
== 5
)
assert backend.deserialize_asset_store_from_ref(
backend.serialize_asset_store_to_ref(value={"rows": 1234}, key="last_run", asset_ref="my_asset")
backend.serialize_asset_store_to_ref(value={"rows": 1234}, key="last_run", scope=scope)
) == {"rows": 1234}

def test_custom_backend_overrides_asset_store_ser_deser(self):
Expand All @@ -160,14 +163,15 @@ async def aset(self, scope, key, value): ...
async def adelete(self, scope, key): ...
async def aclear(self, scope, *, all_map_indices=False): ...

def serialize_asset_store_to_ref(self, *, value, key, asset_ref):
return f"s3://bucket/assets/{asset_ref}/{key}"
def serialize_asset_store_to_ref(self, *, value, key, scope: AssetScope):
return f"s3://bucket/assets/{scope.name}/{key}"

def deserialize_asset_store_from_ref(self, stored):
return f"resolved:{stored}"

b = MyBackend()
assert b.serialize_asset_store_to_ref(value="2026-05-01", key="watermark", asset_ref="my_asset") == (
scope = AssetScope(name="my_asset")
assert b.serialize_asset_store_to_ref(value="2026-05-01", key="watermark", scope=scope) == (
"s3://bucket/assets/my_asset/watermark"
)
assert (
Expand Down
7 changes: 4 additions & 3 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import attrs
import structlog

from airflow.sdk._shared.state import AssetScope
from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT
from airflow.sdk.definitions._internal.types import NOTSET
Expand Down Expand Up @@ -593,7 +594,7 @@ def set(self, key: str, value: JsonValue, *, retention: timedelta | None = None)
backend = _get_worker_state_store_backend()
stored: JsonValue = value
if backend is not None:
ref: str = backend.serialize_task_store_to_ref(value=value, key=key, ti_id=str(self._ti_id))
ref: str = backend.serialize_task_store_to_ref(value=value, key=key, scope=self._scope)
# wrap the value with a marker to indicate that it's stored externally, and include the ref to the external storage
stored = _wrap_external_ref(ref)

Expand Down Expand Up @@ -715,10 +716,10 @@ def set(self, key: str, value: JsonValue) -> None:
# if custom backend is configured, store the value on the custom backend, and return the reference
# to the stored value to store in the DB
backend = _get_worker_state_store_backend()
asset_ref = self._name or self._uri or ""
stored: JsonValue = value
if backend is not None:
ref = backend.serialize_asset_store_to_ref(value=value, key=key, asset_ref=asset_ref)
scope = AssetScope(name=self._name, uri=self._uri)
ref = backend.serialize_asset_store_to_ref(value=value, key=key, scope=scope)
Comment thread
amoghrajesh marked this conversation as resolved.
stored = _wrap_external_ref(ref)

msg: ToSupervisor
Expand Down
14 changes: 8 additions & 6 deletions task-sdk/tests/task_sdk/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1619,8 +1619,8 @@ def __init__(self):
self._actual_key_value_store: dict[str, str] = {} # key -> actual value
self.reference: dict[str, str] = {} # key -> stored ref (mem:// URI)

def serialize_task_store_to_ref(self, *, value, key: str, ti_id: str) -> str:
ref = f"mem://{ti_id}/{key}"
def serialize_task_store_to_ref(self, *, value, key: str, scope) -> str:
ref = f"mem://{scope.dag_id}/{scope.run_id}/{scope.task_id}/{scope.map_index}/{key}"
self._actual_key_value_store[key] = value
self.reference[key] = ref
return ref
Expand All @@ -1629,8 +1629,8 @@ def deserialize_task_store_from_ref(self, stored: str) -> JsonValue:
key = stored.rsplit("/", 1)[-1]
return self._actual_key_value_store.get(key, stored)

def serialize_asset_store_to_ref(self, *, value, key: str, asset_ref: str) -> str:
ref = f"mem://{asset_ref}/{key}"
def serialize_asset_store_to_ref(self, *, value, key: str, scope) -> str:
ref = f"mem://{scope.name or scope.uri}/{key}"
self._actual_key_value_store[key] = value
self.reference[key] = ref
return ref
Expand Down Expand Up @@ -1672,7 +1672,7 @@ def backend(self):
def test_set_returns_reference_to_storage(self, mock_supervisor_comms, backend, time_machine):
"""set() stores actual value in backend and sends mem:// reference via comms."""
mock_supervisor_comms.send.return_value = OKResponse(ok=True)
expected_ref = f"mem://{self.TI_ID}/job_id"
expected_ref = f"mem://{self.SCOPE.dag_id}/{self.SCOPE.run_id}/{self.SCOPE.task_id}/{self.SCOPE.map_index}/job_id"

frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
time_machine.move_to(frozen_dt, tick=False)
Expand All @@ -1693,7 +1693,9 @@ def test_set_returns_reference_to_storage(self, mock_supervisor_comms, backend,

def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend):
"""get() fetches mem:// reference from DB, resolves it to actual value via backend."""
ref = _wrap_external_ref(f"mem://{self.TI_ID}/job_id")
ref = _wrap_external_ref(
f"mem://{self.SCOPE.dag_id}/{self.SCOPE.run_id}/{self.SCOPE.task_id}/{self.SCOPE.map_index}/job_id"
)
backend._actual_key_value_store["job_id"] = "app_001"
mock_supervisor_comms.send.return_value = TaskStoreResult(value=ref)

Expand Down
14 changes: 10 additions & 4 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
timezone,
)
from airflow.sdk._shared.observability.metrics.base_stats_logger import StatsLogger
from airflow.sdk._shared.state import TaskScope
from airflow.sdk._shared.state import AssetScope, TaskScope
from airflow.sdk.api.datamodels._generated import (
AssetProfile,
AssetResponse,
Expand Down Expand Up @@ -5817,7 +5817,7 @@ def execute(self, context):
run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

mock_backend.serialize_asset_store_to_ref.assert_called_once_with(
value="2026-05-01", key="watermark", asset_ref="my_asset"
value="2026-05-01", key="watermark", scope=AssetScope(name="my_asset", uri=None)
)
mock_supervisor_comms.send.assert_any_call(
SetAssetStoreByName(
Expand All @@ -5843,7 +5843,13 @@ def execute(self, context):
mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect

mock_backend = mock.MagicMock()
ref = f"mem://{runtime_ti.id}/job_id"
scope = TaskScope(
dag_id=runtime_ti.dag_id,
run_id=runtime_ti.run_id,
task_id=runtime_ti.task_id,
map_index=runtime_ti.map_index,
)
ref = f"mem://{scope.dag_id}/{scope.run_id}/{scope.task_id}/{scope.map_index}/job_id"
mock_backend.serialize_task_store_to_ref.return_value = ref

with mock.patch(
Expand All @@ -5852,7 +5858,7 @@ def execute(self, context):
run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

mock_backend.serialize_task_store_to_ref.assert_called_once_with(
value="app_001", key="job_id", ti_id=str(runtime_ti.id)
value="app_001", key="job_id", scope=scope
)
mock_supervisor_comms.send.assert_any_call(
SetTaskStore(
Expand Down
Loading