Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,17 @@ repos:
pass_filenames: false
require_serial: true
- id: check-partition-mapper-defaults-in-sync
name: Check FanOutMapper default mapper table stays in sync (core/SDK)
name: Check partition-mapper core/SDK sync (FanOutMapper table + SegmentWindow/FixedKeyMapper)
entry: ./scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
language: python
files: >
(?x)
^airflow-core/src/airflow/partition_mappers/temporal\.py$|
^task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal\.py$
^airflow-core/src/airflow/partition_mappers/window\.py$|
^airflow-core/src/airflow/partition_mappers/fixed_key\.py$|
^task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal\.py$|
^task-sdk/src/airflow/sdk/definitions/partition_mappers/window\.py$|
^task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key\.py$
pass_filenames: false
require_serial: true
- id: check-window-in-sync
Expand Down
83 changes: 83 additions & 0 deletions airflow-core/docs/authoring-and-scheduling/assets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,12 @@ downstream Dag partition key:
passes the key through unchanged if valid.
For example, ``AllowedKeyMapper(["us", "eu", "apac"])`` accepts only those
region keys and rejects all others.
* ``FixedKeyMapper`` collapses every upstream key onto a fixed downstream key,
regardless of the upstream value.
* ``SegmentWindow`` declares a fixed categorical set of string keys (e.g. regions,
tenants) that constitute one downstream period; paired with ``FixedKeyMapper``
Comment thread
Lee-W marked this conversation as resolved.
inside a ``RollupMapper`` it holds the downstream run until every declared segment
has arrived (see :ref:`segment-rollup <segment-categorical-rollup>`).

Example of per-asset mapper configuration and composite-key mapping:

Expand Down Expand Up @@ -733,6 +739,83 @@ so the run is held indefinitely) and the fall-back day has twenty-five (the repe
hour is dropped). Use a UTC-based upstream mapper for any rollup that crosses a DST
boundary; see the ``DayWindow`` class docstring for the full discussion.

.. _segment-categorical-rollup:

Segment (categorical) rollup
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. versionadded:: 3.3.0

For categorical partitioning — regions, tenants, experiment variants — compose a
``RollupMapper`` from two primitives:

* ``SegmentWindow(["us", "eu", "apac"])`` declares the fixed set of string keys
that constitute one downstream period; ``to_upstream`` returns the full set
regardless of the downstream anchor.
* ``FixedKeyMapper("all_regions")`` collapses every upstream key onto the single
downstream partition key ``"all_regions"``.

The scheduler holds the downstream Dag run until every declared segment has arrived
from the upstream producer, then fires once. All the segment events accumulate into
one ``AssetPartitionDagRun``; the fired run's ``partition_key`` is the value passed
to ``FixedKeyMapper``. This composition only makes sense under ``WAIT_FOR_ALL``
semantics (the default).

.. code-block:: python

from airflow.sdk import (
DAG,
Asset,
FixedKeyMapper,
PartitionAtRuntime,
PartitionedAssetTimetable,
RollupMapper,
SegmentWindow,
asset,
task,
)


@asset(
uri="file://incoming/player-stats/multi-region.csv",
schedule=PartitionAtRuntime(),
)
def multi_region_player_stats(self, outlet_events):
# Emit one event per region in a single run.
outlet_events[self].add_partitions(["us", "eu", "apac"])


# Consumer: fires once all three region partitions have arrived.
with DAG(
dag_id="segment_region_stats_rollup",
schedule=PartitionedAssetTimetable(
assets=Asset.ref(name="multi_region_player_stats"),
default_partition_mapper=RollupMapper(
upstream_mapper=FixedKeyMapper("all_regions"),
window=SegmentWindow(["us", "eu", "apac"]),
),
),
catchup=False,
):

@task
def aggregate_all_regions(dag_run=None):
# dag_run.partition_key is the downstream key once all segments arrive.
print(dag_run.partition_key)

aggregate_all_regions()

Construction validates both components: ``SegmentWindow`` raises ``ValueError`` for
an empty list, non-string items, or empty-string keys; duplicate entries are silently
deduplicated. ``FixedKeyMapper`` raises ``ValueError`` if its argument is not a
non-empty string. Pass a distinct ``FixedKeyMapper`` key when one consumer Dag rolls
up more than one asset, so each rollup uses a distinct bucket and they do not collide
on the same ``(target_dag_id, partition_key)``.

For a segment set that must be computed at runtime, do not encode it here — evaluate
completeness in a consumer-side task instead (the scheduler must not run user code to
decide a partition set).

Setting partition keys at runtime
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
1 change: 1 addition & 0 deletions airflow-core/newsfragments/67716.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ``FixedKeyMapper`` and ``SegmentWindow`` for categorical asset-partition rollup. ``FixedKeyMapper`` collapses any upstream partition key onto a single fixed downstream key, and ``SegmentWindow`` enumerates a fixed categorical segment set the scheduler waits for. Composing ``RollupMapper(FixedKeyMapper(...), SegmentWindow(...))`` expresses a categorical rollup, mirroring the temporal rollup shape, and ``SegmentWindow`` also composes with ``FanOutMapper`` for categorical scatter. Both classes are importable from ``airflow.sdk``.
40 changes: 40 additions & 0 deletions airflow-core/src/airflow/example_dags/example_asset_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
CronPartitionTimetable,
DayWindow,
FanOutMapper,
FixedKeyMapper,
IdentityMapper,
MonthWindow,
PartitionAtRuntime,
PartitionedAssetTimetable,
ProductMapper,
RollupMapper,
SegmentWindow,
StartOfDayMapper,
StartOfHourMapper,
StartOfMonthMapper,
Expand Down Expand Up @@ -409,3 +411,41 @@ def run_inference(dag_run=None):
print(dag_run.partition_key)

run_inference()


# --- Segment (categorical) rollup -------------------------------------------
# ``multi_region_player_stats`` (defined above) emits one partition per region
# (``us``, ``eu``, ``apac``) from a single run. The Dag below holds a downstream
# run until every declared region key has arrived.

with DAG(
dag_id="segment_region_stats_rollup",
schedule=PartitionedAssetTimetable(
assets=Asset.ref(name="multi_region_player_stats"),
default_partition_mapper=RollupMapper(
upstream_mapper=FixedKeyMapper("all_regions"),
window=SegmentWindow(["us", "eu", "apac"]),
),
),
catchup=False,
tags=["example", "player-stats", "rollup", "segment"],
):
"""
Categorical rollup: hold until all three region partitions arrive.

``RollupMapper(upstream_mapper=FixedKeyMapper("all_regions"), window=SegmentWindow([...]))``
declares the fixed set of region keys required for one downstream run and collapses every
region key onto a single ``all_regions`` partition, so the three region events accumulate
into one downstream run. The run is held until ``us``, ``eu``, and ``apac`` have all
arrived from ``multi_region_player_stats``; partial arrivals remain pending in the
next-run-assets view so operators can track progress.
"""

@task
def aggregate_all_regions(dag_run=None):
"""Produce the cross-region summary once every region partition has arrived."""
if TYPE_CHECKING:
assert dag_run
print(f"All region partitions received. Partition: {dag_run.partition_key}")

aggregate_all_regions()
4 changes: 4 additions & 0 deletions airflow-core/src/airflow/partition_mappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from airflow.partition_mappers.allowed_key import AllowedKeyMapper
from airflow.partition_mappers.base import PartitionMapper, RollupMapper
from airflow.partition_mappers.chain import ChainMapper
from airflow.partition_mappers.fixed_key import FixedKeyMapper
from airflow.partition_mappers.identity import IdentityMapper
from airflow.partition_mappers.product import ProductMapper
from airflow.partition_mappers.temporal import (
Expand All @@ -34,6 +35,7 @@
HourWindow,
MonthWindow,
QuarterWindow,
SegmentWindow,
WeekWindow,
Window,
YearWindow,
Expand All @@ -43,13 +45,15 @@
"AllowedKeyMapper",
"ChainMapper",
"DayWindow",
"FixedKeyMapper",
"HourWindow",
"IdentityMapper",
"MonthWindow",
"PartitionMapper",
"ProductMapper",
"QuarterWindow",
"RollupMapper",
"SegmentWindow",
"StartOfDayMapper",
"StartOfHourMapper",
"StartOfMonthMapper",
Expand Down
65 changes: 65 additions & 0 deletions airflow-core/src/airflow/partition_mappers/fixed_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Any

import attrs

from airflow.partition_mappers.base import PartitionMapper


@attrs.define
class FixedKeyMapper(PartitionMapper):
"""
Collapse every upstream partition key onto one fixed downstream key.

Returns the same *downstream_key* for any upstream key passed to
``to_downstream``. Does not override ``decode_downstream`` or
``encode_upstream``, so it works with the string-based identity path and
satisfies :class:`~airflow.partition_mappers.base.RollupMapper`'s guard
when paired with :class:`~airflow.partition_mappers.window.SegmentWindow`.

Typical use is as the ``upstream_mapper`` inside a categorical rollup::

RollupMapper(
upstream_mapper=FixedKeyMapper("all_regions"),
window=SegmentWindow(["us", "eu", "apac"]),
)

:param downstream_key: The fixed downstream partition key every upstream key
maps to. Must be a non-empty string.
:raises ValueError: if *downstream_key* is not a non-empty ``str``.
"""

downstream_key: str = attrs.field()

@downstream_key.validator
def _validate_downstream_key(self, attribute: attrs.Attribute, value: str) -> None:
if not isinstance(value, str) or value == "":
raise ValueError(f"FixedKeyMapper downstream_key must be a non-empty str; got {value!r}.")

def to_downstream(self, key: str) -> str:
"""Return the fixed downstream key regardless of *key*."""
return self.downstream_key

def serialize(self) -> dict[str, Any]:
return {"downstream_key": self.downstream_key}

@classmethod
def deserialize(cls, data: dict[str, Any]) -> FixedKeyMapper:
return cls(data["downstream_key"])
62 changes: 62 additions & 0 deletions airflow-core/src/airflow/partition_mappers/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar

import attrs

if TYPE_CHECKING:
from collections.abc import Callable, Iterable

Expand Down Expand Up @@ -246,3 +248,63 @@ class YearWindow(Window):
def to_upstream(self, period_start: datetime) -> Iterable[datetime]:
_require_day_one(period_start, type(self))
return _build_directional_steps(period_start, 12, _shift_months, self.direction)


def _convert_segments(segments: Iterable[str]) -> frozenset[str]:
"""
Validate and convert *segments* to a ``frozenset[str]``.

Validates each element for type and non-emptiness (with index reporting)
before collapsing into a frozenset, then checks the result is non-empty.
"""
validated: list[str] = []
for i, item in enumerate(segments):
if not isinstance(item, str):
raise ValueError(
f"SegmentWindow segment keys must be str; got {type(item).__name__!r} at index {i}: {item!r}"
)
if not item:
raise ValueError(
f"SegmentWindow segment keys must be non-empty; got an empty string at index {i}."
)
validated.append(item)
result = frozenset(validated)
if not result:
raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.")
return result


@attrs.define
class SegmentWindow(Window):
"""
A fixed categorical set of string keys that constitute one downstream period.

Paired with :class:`~airflow.partition_mappers.fixed_key.FixedKeyMapper` inside a
:class:`~airflow.partition_mappers.base.RollupMapper` to express a categorical
rollup: the scheduler holds the downstream run until every declared segment key
has arrived from the upstream producer, then fires once.

``to_upstream`` returns the complete segment set regardless of the downstream
anchor value — the anchor is intentionally ignored because all segments map onto
a single downstream partition key, not a time-based period.

:param segments: Non-empty iterable of non-empty string segment keys. Duplicates
are silently de-duplicated.
:raises ValueError: if *segments* is empty, contains a non-``str`` element, or
contains an empty-string element.
"""

expected_decoded_type: ClassVar[type] = str

_segments: frozenset[str] = attrs.field(converter=_convert_segments)

def to_upstream(self, decoded_downstream: Any) -> frozenset[str]:
"""Return the full declared segment set, ignoring the downstream anchor."""
return self._segments

def serialize(self) -> dict[str, Any]:
return {"segments": sorted(self._segments)}

@classmethod
def deserialize(cls, data: dict[str, Any]) -> SegmentWindow:
return cls(data["segments"])
Loading
Loading