From e5ceca96ab7ad03e5a9ca1b2c5d8e4c4d5f72868 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Sat, 6 Jun 2026 12:45:06 +0800 Subject: [PATCH 1/4] feat(partition-mappers): categorical rollup via FixedKeyMapper + SegmentWindow Express categorical asset-partition rollup by composing RollupMapper(FixedKeyMapper(...), SegmentWindow(...)): FixedKeyMapper routes every upstream segment event onto one downstream bucket, and SegmentWindow declares the segment set the scheduler waits for. Categorical and temporal rollups now share one composition model, and SegmentWindow also composes with FanOutMapper for categorical scatter. - Add FixedKeyMapper (collapse any upstream key onto a fixed downstream key) and SegmentWindow (enumerate a fixed categorical segment set), on both the SDK authoring side and the scheduler runtime side. --- .../docs/authoring-and-scheduling/assets.rst | 83 ++++++++++++ .../example_dags/example_asset_partition.py | 40 ++++++ .../src/airflow/partition_mappers/__init__.py | 4 + .../airflow/partition_mappers/fixed_key.py | 62 +++++++++ .../src/airflow/partition_mappers/window.py | 49 +++++++ .../src/airflow/serialization/encoders.py | 12 ++ .../tests/unit/jobs/test_scheduler_job.py | 70 ++++++++++ .../unit/partition_mappers/test_fixed_key.py | 123 ++++++++++++++++++ .../unit/partition_mappers/test_window.py | 42 ++++++ task-sdk/docs/api.rst | 4 + task-sdk/src/airflow/sdk/__init__.py | 6 + task-sdk/src/airflow/sdk/__init__.pyi | 4 + .../partition_mappers/fixed_key.py | 49 +++++++ .../definitions/partition_mappers/window.py | 43 +++++- .../definitions/test_partition_mappers.py | 68 ++++++++++ 15 files changed, 658 insertions(+), 1 deletion(-) create mode 100644 airflow-core/src/airflow/partition_mappers/fixed_key.py create mode 100644 airflow-core/tests/unit/partition_mappers/test_fixed_key.py create mode 100644 task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py diff --git a/airflow-core/docs/authoring-and-scheduling/assets.rst 
b/airflow-core/docs/authoring-and-scheduling/assets.rst index 7bf609e9d9834..2bc06a3011721 100644 --- a/airflow-core/docs/authoring-and-scheduling/assets.rst +++ b/airflow-core/docs/authoring-and-scheduling/assets.rst @@ -565,6 +565,12 @@ downstream Dag partition key: passes the key through unchanged if valid. For example, ``AllowedKeyMapper(["us", "eu", "apac"])`` accepts only those region keys and rejects all others. +* ``FixedKeyMapper`` collapses every upstream key onto a fixed downstream key, + regardless of the upstream value. +* ``SegmentWindow`` declares a fixed categorical set of string keys (e.g. regions, + tenants) that constitute one downstream period; paired with ``FixedKeyMapper`` + inside a ``RollupMapper`` it holds the downstream run until every declared segment + has arrived (see :ref:`segment-rollup <segment-categorical-rollup>`). Example of per-asset mapper configuration and composite-key mapping: @@ -733,6 +739,83 @@ so the run is held indefinitely) and the fall-back day has twenty-five (the repe hour is dropped). Use a UTC-based upstream mapper for any rollup that crosses a DST boundary; see the ``DayWindow`` class docstring for the full discussion. +.. _segment-categorical-rollup: + +Segment (categorical) rollup +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 3.3.0 + +For categorical partitioning — regions, tenants, experiment variants — compose a +``RollupMapper`` from two primitives: + +* ``SegmentWindow(["us", "eu", "apac"])`` declares the fixed set of string keys + that constitute one downstream period; ``to_upstream`` returns the full set + regardless of the downstream anchor. +* ``FixedKeyMapper("all_regions")`` collapses every upstream key onto the single + downstream partition key ``"all_regions"``. + +The scheduler holds the downstream Dag run until every declared segment has arrived +from the upstream producer, then fires once. 
All the segment events accumulate into +one ``AssetPartitionDagRun``; the fired run's ``partition_key`` is the value passed +to ``FixedKeyMapper``. This composition only makes sense under ``WAIT_FOR_ALL`` +semantics (the default). + +.. code-block:: python + + from airflow.sdk import ( + DAG, + Asset, + FixedKeyMapper, + PartitionAtRuntime, + PartitionedAssetTimetable, + RollupMapper, + SegmentWindow, + asset, + task, + ) + + + @asset( + uri="file://incoming/player-stats/multi-region.csv", + schedule=PartitionAtRuntime(), + ) + def multi_region_player_stats(self, outlet_events): + # Emit one event per region in a single run. + outlet_events[self].add_partitions(["us", "eu", "apac"]) + + + # Consumer: fires once all three region partitions have arrived. + with DAG( + dag_id="segment_region_stats_rollup", + schedule=PartitionedAssetTimetable( + assets=Asset.ref(name="multi_region_player_stats"), + default_partition_mapper=RollupMapper( + upstream_mapper=FixedKeyMapper("all_regions"), + window=SegmentWindow(["us", "eu", "apac"]), + ), + ), + catchup=False, + ): + + @task + def aggregate_all_regions(dag_run=None): + # dag_run.partition_key is the downstream key once all segments arrive. + print(dag_run.partition_key) + + aggregate_all_regions() + +Construction validates both components: ``SegmentWindow`` raises ``ValueError`` for +an empty list, non-string items, or empty-string keys; duplicate entries are silently +deduplicated. ``FixedKeyMapper`` raises ``ValueError`` if its argument is not a +non-empty string. Pass a distinct ``FixedKeyMapper`` key when one consumer Dag rolls +up more than one asset, so each rollup uses a distinct bucket and they do not collide +on the same ``(target_dag_id, partition_key)``. + +For a segment set that must be computed at runtime, do not encode it here — evaluate +completeness in a consumer-side task instead (the scheduler must not run user code to +decide a partition set). 
+ Setting partition keys at runtime ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/airflow-core/src/airflow/example_dags/example_asset_partition.py b/airflow-core/src/airflow/example_dags/example_asset_partition.py index 775fea80ec314..65a078d3c3f9e 100644 --- a/airflow-core/src/airflow/example_dags/example_asset_partition.py +++ b/airflow-core/src/airflow/example_dags/example_asset_partition.py @@ -26,12 +26,14 @@ CronPartitionTimetable, DayWindow, FanOutMapper, + FixedKeyMapper, IdentityMapper, MonthWindow, PartitionAtRuntime, PartitionedAssetTimetable, ProductMapper, RollupMapper, + SegmentWindow, StartOfDayMapper, StartOfHourMapper, StartOfMonthMapper, @@ -409,3 +411,41 @@ def run_inference(dag_run=None): print(dag_run.partition_key) run_inference() + + +# --- Segment (categorical) rollup ------------------------------------------- +# ``multi_region_player_stats`` (defined above) emits one partition per region +# (``us``, ``eu``, ``apac``) from a single run. The Dag below holds a downstream +# run until every declared region key has arrived. + +with DAG( + dag_id="segment_region_stats_rollup", + schedule=PartitionedAssetTimetable( + assets=Asset.ref(name="multi_region_player_stats"), + default_partition_mapper=RollupMapper( + upstream_mapper=FixedKeyMapper("all_regions"), + window=SegmentWindow(["us", "eu", "apac"]), + ), + ), + catchup=False, + tags=["example", "player-stats", "rollup", "segment"], +): + """ + Categorical rollup: hold until all three region partitions arrive. + + ``RollupMapper(upstream_mapper=FixedKeyMapper("all_regions"), window=SegmentWindow([...]))`` + declares the fixed set of region keys required for one downstream run and collapses every + region key onto a single ``all_regions`` partition, so the three region events accumulate + into one downstream run. 
The run is held until ``us``, ``eu``, and ``apac`` have all + arrived from ``multi_region_player_stats``; partial arrivals remain pending in the + next-run-assets view so operators can track progress. + """ + + @task + def aggregate_all_regions(dag_run=None): + """Produce the cross-region summary once every region partition has arrived.""" + if TYPE_CHECKING: + assert dag_run + print(f"All region partitions received. Partition: {dag_run.partition_key}") + + aggregate_all_regions() diff --git a/airflow-core/src/airflow/partition_mappers/__init__.py b/airflow-core/src/airflow/partition_mappers/__init__.py index 9b66876e0df8a..f780603870736 100644 --- a/airflow-core/src/airflow/partition_mappers/__init__.py +++ b/airflow-core/src/airflow/partition_mappers/__init__.py @@ -19,6 +19,7 @@ from airflow.partition_mappers.allowed_key import AllowedKeyMapper from airflow.partition_mappers.base import PartitionMapper, RollupMapper from airflow.partition_mappers.chain import ChainMapper +from airflow.partition_mappers.fixed_key import FixedKeyMapper from airflow.partition_mappers.identity import IdentityMapper from airflow.partition_mappers.product import ProductMapper from airflow.partition_mappers.temporal import ( @@ -34,6 +35,7 @@ HourWindow, MonthWindow, QuarterWindow, + SegmentWindow, WeekWindow, Window, YearWindow, @@ -43,6 +45,7 @@ "AllowedKeyMapper", "ChainMapper", "DayWindow", + "FixedKeyMapper", "HourWindow", "IdentityMapper", "MonthWindow", @@ -50,6 +53,7 @@ "ProductMapper", "QuarterWindow", "RollupMapper", + "SegmentWindow", "StartOfDayMapper", "StartOfHourMapper", "StartOfMonthMapper", diff --git a/airflow-core/src/airflow/partition_mappers/fixed_key.py b/airflow-core/src/airflow/partition_mappers/fixed_key.py new file mode 100644 index 0000000000000..91c19d5102e5e --- /dev/null +++ b/airflow-core/src/airflow/partition_mappers/fixed_key.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any + +from airflow.partition_mappers.base import PartitionMapper + + +class FixedKeyMapper(PartitionMapper): + """ + Collapse every upstream partition key onto one fixed downstream key. + + Returns the same *downstream_key* for any upstream key passed to + ``to_downstream``. Does not override ``decode_downstream`` or + ``encode_upstream``, so it works with the string-based identity path and + satisfies :class:`~airflow.partition_mappers.base.RollupMapper`'s guard + when paired with :class:`~airflow.partition_mappers.window.SegmentWindow`. + + Typical use is as the ``upstream_mapper`` inside a categorical rollup:: + + RollupMapper( + upstream_mapper=FixedKeyMapper("all_regions"), + window=SegmentWindow(["us", "eu", "apac"]), + ) + + :param downstream_key: The fixed downstream partition key every upstream key + maps to. Must be a non-empty string. + :raises ValueError: if *downstream_key* is not a non-empty ``str``. + """ + + def __init__(self, downstream_key: str) -> None: + if not isinstance(downstream_key, str) or downstream_key == "": + raise ValueError( + f"FixedKeyMapper downstream_key must be a non-empty str; got {downstream_key!r}." 
+ ) + self.downstream_key = downstream_key + + def to_downstream(self, key: str) -> str: + """Return the fixed downstream key regardless of *key*.""" + return self.downstream_key + + def serialize(self) -> dict[str, Any]: + return {"downstream_key": self.downstream_key} + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> FixedKeyMapper: + return cls(data["downstream_key"]) diff --git a/airflow-core/src/airflow/partition_mappers/window.py b/airflow-core/src/airflow/partition_mappers/window.py index 087b6180f2ade..b2b8ad901499b 100644 --- a/airflow-core/src/airflow/partition_mappers/window.py +++ b/airflow-core/src/airflow/partition_mappers/window.py @@ -246,3 +246,52 @@ class YearWindow(Window): def to_upstream(self, period_start: datetime) -> Iterable[datetime]: _require_day_one(period_start, type(self)) return _build_directional_steps(period_start, 12, _shift_months, self.direction) + + +class SegmentWindow(Window): + """ + A fixed categorical set of string keys that constitute one downstream period. + + Paired with :class:`~airflow.partition_mappers.fixed_key.FixedKeyMapper` inside a + :class:`~airflow.partition_mappers.base.RollupMapper` to express a categorical + rollup: the scheduler holds the downstream run until every declared segment key + has arrived from the upstream producer, then fires once. + + ``to_upstream`` returns the complete segment set regardless of the downstream + anchor value — the anchor is intentionally ignored because all segments map onto + a single downstream partition key, not a time-based period. + + :param segments: Non-empty iterable of non-empty string segment keys. Duplicates + are silently de-duplicated. + :raises ValueError: if *segments* is empty, contains a non-``str`` element, or + contains an empty-string element. 
+ """ + + expected_decoded_type: ClassVar[type] = str + + def __init__(self, segments: Iterable[str]) -> None: + collected: list[str] = list(segments) + if not collected: + raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.") + for i, item in enumerate(collected): + if not isinstance(item, str): + raise ValueError( + f"SegmentWindow segment keys must be str; " + f"got {type(item).__name__!r} at index {i}: {item!r}" + ) + if item == "": + raise ValueError( + f"SegmentWindow segment keys must be non-empty strings; got an empty string at index {i}." + ) + self._segments: frozenset[str] = frozenset(collected) + + def to_upstream(self, decoded_downstream: Any) -> frozenset[str]: + """Return the full declared segment set, ignoring the downstream anchor.""" + return frozenset(self._segments) + + def serialize(self) -> dict[str, Any]: + return {"segments": sorted(self._segments)} + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SegmentWindow: + return cls(data["segments"]) diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py index f4ec081df1bef..0c2f030da8f78 100644 --- a/airflow-core/src/airflow/serialization/encoders.py +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -43,6 +43,7 @@ DeltaTriggerTimetable, EventsTimetable, FanOutMapper, + FixedKeyMapper, HourWindow, IdentityMapper, MonthWindow, @@ -51,6 +52,7 @@ ProductMapper, QuarterWindow, RollupMapper, + SegmentWindow, StartOfDayMapper, StartOfHourMapper, StartOfMonthMapper, @@ -437,6 +439,7 @@ def _(self, timetable: PartitionedAssetTimetable) -> dict[str, Any]: AllowedKeyMapper: "airflow.partition_mappers.allowed_key.AllowedKeyMapper", ChainMapper: "airflow.partition_mappers.chain.ChainMapper", FanOutMapper: "airflow.partition_mappers.temporal.FanOutMapper", + FixedKeyMapper: "airflow.partition_mappers.fixed_key.FixedKeyMapper", IdentityMapper: 
"airflow.partition_mappers.identity.IdentityMapper", ProductMapper: "airflow.partition_mappers.product.ProductMapper", RollupMapper: "airflow.partition_mappers.base.RollupMapper", @@ -464,6 +467,10 @@ def _(self, partition_mapper: ChainMapper) -> dict[str, Any]: def _(self, partition_mapper: IdentityMapper) -> dict[str, Any]: return {} + @serialize_partition_mapper.register + def _(self, partition_mapper: FixedKeyMapper) -> dict[str, Any]: + return {"downstream_key": partition_mapper.downstream_key} + @serialize_partition_mapper.register(StartOfHourMapper) @serialize_partition_mapper.register(StartOfDayMapper) @serialize_partition_mapper.register(StartOfWeekMapper) @@ -517,6 +524,7 @@ def _(self, partition_mapper: FanOutMapper) -> dict[str, Any]: WeekWindow: "airflow.partition_mappers.window.WeekWindow", MonthWindow: "airflow.partition_mappers.window.MonthWindow", QuarterWindow: "airflow.partition_mappers.window.QuarterWindow", + SegmentWindow: "airflow.partition_mappers.window.SegmentWindow", YearWindow: "airflow.partition_mappers.window.YearWindow", } @@ -538,6 +546,10 @@ def _( ) -> dict[str, Any]: return window.serialize() + @serialize_window.register + def _(self, window: SegmentWindow) -> dict[str, Any]: + return {"segments": sorted(window._segments)} + _serializer = _Serializer() diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 89383233fad02..d08628a51fbcd 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -104,9 +104,11 @@ AssetAlias, AssetWatcher, CronPartitionTimetable, + FixedKeyMapper, HourWindow, IdentityMapper, RollupMapper, + SegmentWindow, StartOfHourMapper, task, ) @@ -10273,6 +10275,74 @@ def test_partitioned_dag_run_rollup_holds_until_window_complete( assert partition_dags == {"rollup-consumer"} +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def 
test_partitioned_dag_run_segment_rollup_holds_until_all_segments_arrive( + dag_maker: DagMaker, + session: Session, +): + """ + A categorical (segment) rollup fires once every declared segment has arrived. + + ``RollupMapper(FixedKeyMapper("all_regions"), SegmentWindow([...]))`` collapses + each region key onto a single ``all_regions`` partition, so all three events + accumulate into one APDR, and holds the downstream run until ``us``, ``eu``, + and ``apac`` are all present. + """ + asset_1 = Asset(name="asset-1") + with dag_maker( + dag_id="segment-rollup-consumer", + schedule=PartitionedAssetTimetable( + assets=asset_1, + default_partition_mapper=RollupMapper( + upstream_mapper=FixedKeyMapper("all_regions"), + window=SegmentWindow(["us", "eu", "apac"]), + ), + ), + session=session, + ): + EmptyOperator(task_id="hi") + session.commit() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + + # First region arrives — only 1 / 3 segments, so the APDR must not fire. + # Every region collapses onto the single ``all_regions`` partition. + apdr = _produce_and_register_asset_event( + dag_id="segment-producer-us", + asset=asset_1, + partition_key="us", + session=session, + dag_maker=dag_maker, + expected_partition_key="all_regions", + ) + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + assert apdr.created_dag_run_id is None + assert partition_dags == set() + + # The remaining two regions arrive — once all three segments are present the + # rollup is satisfied and the APDR creates its Dag run on the next tick. All + # three events share the one ``all_regions`` APDR. 
+ for region in ("eu", "apac"): + sibling = _produce_and_register_asset_event( + dag_id=f"segment-producer-{region}", + asset=asset_1, + partition_key=region, + session=session, + dag_maker=dag_maker, + expected_partition_key="all_regions", + ) + assert sibling.id == apdr.id + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + assert apdr.created_dag_run_id is not None + assert apdr.partition_key == "all_regions" + assert partition_dags == {"segment-rollup-consumer"} + + @pytest.mark.need_serialized_dag @pytest.mark.usefixtures("clear_asset_partition_rows") def test_partitioned_dag_run_rollup_treats_mapper_exception_as_not_satisfied( diff --git a/airflow-core/tests/unit/partition_mappers/test_fixed_key.py b/airflow-core/tests/unit/partition_mappers/test_fixed_key.py new file mode 100644 index 0000000000000..800e96a4e3aba --- /dev/null +++ b/airflow-core/tests/unit/partition_mappers/test_fixed_key.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest + +from airflow.partition_mappers.base import PartitionMapper, RollupMapper +from airflow.partition_mappers.fixed_key import FixedKeyMapper +from airflow.partition_mappers.window import DayWindow, SegmentWindow +from airflow.sdk import ( + FixedKeyMapper as SdkFixedKeyMapper, + RollupMapper as SdkRollupMapper, + SegmentWindow as SdkSegmentWindow, +) +from airflow.serialization.decoders import decode_partition_mapper +from airflow.serialization.encoders import encode_partition_mapper + + +class TestFixedKeyMapper: + def test_to_downstream_returns_constant_for_any_key(self): + m = FixedKeyMapper("all") + assert m.to_downstream("us") == "all" + assert m.to_downstream("eu") == "all" + assert m.to_downstream("apac") == "all" + assert m.to_downstream("anything-else") == "all" + + def test_is_rollup_false(self): + # A bare FixedKeyMapper is not a rollup; rollup-ness comes from RollupMapper. + assert FixedKeyMapper("all").is_rollup is False + + def test_does_not_override_decode_encode(self): + m = FixedKeyMapper("all") + assert type(m).decode_downstream is PartitionMapper.decode_downstream + assert type(m).encode_upstream is PartitionMapper.encode_upstream + + @pytest.mark.parametrize( + ("downstream_key", "match"), + [ + pytest.param("", "non-empty str", id="empty-string"), + pytest.param(None, "non-empty str", id="none"), + pytest.param(1, "non-empty str", id="int"), + ], + ) + def test_rejects_invalid_downstream_key(self, downstream_key, match): + with pytest.raises(ValueError, match=match): + FixedKeyMapper(downstream_key) + + def test_requires_downstream_key(self): + with pytest.raises(TypeError): + FixedKeyMapper() + + def test_serialize_round_trip(self): + m = FixedKeyMapper("bucket") + restored = FixedKeyMapper.deserialize(m.serialize()) + assert isinstance(restored, FixedKeyMapper) + assert restored.downstream_key == "bucket" + + +class TestCategoricalRollupEquivalence: + """RollupMapper(FixedKeyMapper, 
SegmentWindow) behaves like old SegmentMapper.""" + + def setup_method(self): + self.m = RollupMapper( + upstream_mapper=FixedKeyMapper("all"), + window=SegmentWindow(["us", "eu", "apac"]), + ) + + def test_is_rollup_flag(self): + assert self.m.is_rollup is True + + def test_to_downstream_collapses_every_segment_onto_downstream_key(self): + # Full-sequence equality: every declared segment key maps to the constant key. + assert [self.m.to_downstream(s) for s in ("us", "eu", "apac")] == ["all", "all", "all"] + + def test_to_upstream_returns_full_set_ignoring_anchor(self): + assert self.m.to_upstream("all") == frozenset({"us", "eu", "apac"}) + assert self.m.to_upstream("anything") == frozenset({"us", "eu", "apac"}) + + def test_core_encode_decode_round_trip(self): + restored = decode_partition_mapper(encode_partition_mapper(self.m)) + assert isinstance(restored, RollupMapper) + assert restored.is_rollup is True + assert restored.to_downstream("us") == "all" + assert restored.to_upstream("all") == frozenset({"us", "eu", "apac"}) + + def test_sdk_encode_decode_round_trip(self): + # User code authors with SDK classes; the scheduler serializes and deserializes + # into core classes. + sdk_mapper = SdkRollupMapper( + upstream_mapper=SdkFixedKeyMapper("all_regions"), + window=SdkSegmentWindow(["us", "eu", "apac"]), + ) + restored = decode_partition_mapper(encode_partition_mapper(sdk_mapper)) + assert isinstance(restored, RollupMapper) + assert restored.to_upstream("all_regions") == frozenset({"us", "eu", "apac"}) + + +class TestCategoricalRollupTypeGuard: + """Core-side RollupMapper guard: FixedKeyMapper(str) + SegmentWindow(str) must pass.""" + + def test_fixed_key_with_segment_window_does_not_raise(self): + # Core guard: FixedKeyMapper does not override decode_downstream, + # SegmentWindow.expected_decoded_type is str -> guard passes. 
+ RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=SegmentWindow(["us", "eu"])) + + def test_str_mapper_with_datetime_window_raises(self): + # Core guard: FixedKeyMapper (no decode override) + DayWindow (datetime) -> raise. + with pytest.raises(TypeError, match="DayWindow expects decoded values of type 'datetime'"): + RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=DayWindow()) diff --git a/airflow-core/tests/unit/partition_mappers/test_window.py b/airflow-core/tests/unit/partition_mappers/test_window.py index ddfd44cbd76f7..f4aaf01b69e54 100644 --- a/airflow-core/tests/unit/partition_mappers/test_window.py +++ b/airflow-core/tests/unit/partition_mappers/test_window.py @@ -34,6 +34,7 @@ HourWindow, MonthWindow, QuarterWindow, + SegmentWindow, WeekWindow, Window, YearWindow, @@ -465,6 +466,47 @@ def test_invalid_direction_raises_value_error(self, bad_value): WeekWindow(direction=bad_value) +class TestSegmentWindow: + def test_to_upstream_returns_full_set_ignoring_anchor(self): + w = SegmentWindow(["us", "eu", "apac"]) + result_a = frozenset(w.to_upstream("any-anchor")) + result_b = frozenset(w.to_upstream("different-anchor")) + assert result_a == frozenset({"us", "eu", "apac"}) + assert result_a == result_b + + def test_expected_decoded_type_is_str(self): + assert SegmentWindow.expected_decoded_type is str + + @pytest.mark.parametrize( + ("segments", "match"), + [ + pytest.param([], "at least one segment key", id="empty-list"), + pytest.param(iter([]), "at least one segment key", id="empty-iterator"), + pytest.param([1, "b"], "must be str", id="int-element"), + pytest.param([None, "b"], "must be str", id="none-element"), + pytest.param(["", "b"], "non-empty strings", id="empty-string-first"), + pytest.param(["a", ""], "non-empty strings", id="empty-string-second"), + ], + ) + def test_rejects_invalid_segments(self, segments, match): + with pytest.raises(ValueError, match=match): + SegmentWindow(segments) + + def test_deduplication(self): + w = 
SegmentWindow(["us", "us", "eu"]) + assert frozenset(w.to_upstream("any")) == frozenset({"us", "eu"}) + + def test_serialize_uses_sorted_order(self): + w = SegmentWindow(["z", "a", "m"]) + assert w.serialize() == {"segments": ["a", "m", "z"]} + + def test_deserialize_round_trip(self): + w = SegmentWindow(["us", "eu", "apac"]) + restored = SegmentWindow.deserialize(w.serialize()) + assert isinstance(restored, SegmentWindow) + assert frozenset(restored.to_upstream("any")) == frozenset({"us", "eu", "apac"}) + + class TestWindowSerializationGate: """``encode_window`` / ``decode_window`` must reject non-built-in Windows. diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index 8d87897898057..cd9cec283c552 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -247,6 +247,8 @@ Partition Mapper .. autoapiclass:: airflow.sdk.FanOutMapper +.. autoapiclass:: airflow.sdk.FixedKeyMapper + Rollup Windows ~~~~~~~~~~~~~~ @@ -264,6 +266,8 @@ Rollup Windows .. autoapiclass:: airflow.sdk.YearWindow +.. autoapiclass:: airflow.sdk.SegmentWindow + I/O Helpers ----------- .. 
autoapiclass:: airflow.sdk.ObjectStoragePath diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index ce5b8ab2c2ec1..00d724cb42d5b 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -55,6 +55,7 @@ "EventsTimetable", "ExceptionRetryPolicy", "FanOutMapper", + "FixedKeyMapper", "HourWindow", "IdentityMapper", "Label", @@ -77,6 +78,7 @@ "RetryPolicy", "RetryRule", "RollupMapper", + "SegmentWindow", "SkipMixin", "SyncCallback", "StartOfDayMapper", @@ -154,6 +156,7 @@ from airflow.sdk.definitions.partition_mappers.allowed_key import AllowedKeyMapper from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper from airflow.sdk.definitions.partition_mappers.chain import ChainMapper + from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper from airflow.sdk.definitions.partition_mappers.temporal import ( @@ -170,6 +173,7 @@ HourWindow, MonthWindow, QuarterWindow, + SegmentWindow, WeekWindow, Window, YearWindow, @@ -244,6 +248,7 @@ "EventsTimetable": ".definitions.timetables.events", "ExceptionRetryPolicy": ".definitions.retry_policy", "FanOutMapper": ".definitions.partition_mappers.temporal", + "FixedKeyMapper": ".definitions.partition_mappers.fixed_key", "HourWindow": ".definitions.partition_mappers.window", "IdentityMapper": ".definitions.partition_mappers.identity", "Label": ".definitions.edges", @@ -266,6 +271,7 @@ "RetryRule": ".definitions.retry_policy", "RollupMapper": ".definitions.partition_mappers.base", "SecretCache": ".execution_time.cache", + "SegmentWindow": ".definitions.partition_mappers.window", "SkipMixin": ".bases.skipmixin", "SyncCallback": ".definitions.callback", "StartOfDayMapper": ".definitions.partition_mappers.temporal", diff --git 
a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index 1bb975e2202dd..d6fa8bc2a0d1f 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -67,6 +67,7 @@ from airflow.sdk.definitions.param import Param as Param from airflow.sdk.definitions.partition_mappers.allowed_key import AllowedKeyMapper from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper from airflow.sdk.definitions.partition_mappers.chain import ChainMapper +from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper from airflow.sdk.definitions.partition_mappers.temporal import ( @@ -83,6 +84,7 @@ from airflow.sdk.definitions.partition_mappers.window import ( HourWindow, MonthWindow, QuarterWindow, + SegmentWindow, WeekWindow, Window, YearWindow, @@ -154,6 +156,7 @@ __all__ = [ "EventsTimetable", "ExceptionRetryPolicy", "FanOutMapper", + "FixedKeyMapper", "HourWindow", "IdentityMapper", "Label", @@ -175,6 +178,7 @@ __all__ = [ "ResumableJobMixin", "RollupMapper", "SecretCache", + "SegmentWindow", "SkipMixin", "StartOfDayMapper", "StartOfHourMapper", diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py new file mode 100644 index 0000000000000..24a1212e1076b --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.sdk.definitions.partition_mappers.base import PartitionMapper + + +class FixedKeyMapper(PartitionMapper): + """ + Collapse every upstream partition key onto one fixed downstream key. + + Authoring marker for the scheduler-side + :class:`airflow.partition_mappers.fixed_key.FixedKeyMapper`. Paired with + :class:`~airflow.sdk.definitions.partition_mappers.window.SegmentWindow` inside a + :class:`~airflow.sdk.definitions.partition_mappers.base.RollupMapper` to express a + categorical rollup. + + Construction validates *downstream_key* so Dag parse errors surface + immediately rather than deferring to scheduler deserialization. + + :param downstream_key: The fixed downstream partition key every upstream key + maps to. Must be a non-empty string. + :raises ValueError: if *downstream_key* is not a non-empty ``str``. + """ + + def __init__(self, downstream_key: str) -> None: + if not isinstance(downstream_key, str) or downstream_key == "": + raise ValueError( + f"FixedKeyMapper downstream_key must be a non-empty str; got {downstream_key!r}." 
+ ) + self.downstream_key = downstream_key + + def to_downstream(self, key: str) -> str: + """Return the fixed downstream key regardless of *key*.""" + return self.downstream_key diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py index 4ea008ace6f79..9a02f46b873ad 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py @@ -22,7 +22,10 @@ from datetime import datetime from enum import Enum -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar + +if TYPE_CHECKING: + from collections.abc import Iterable class Window: @@ -114,3 +117,41 @@ class YearWindow(Window): """Twelve consecutive monthly keys making up one calendar year.""" expected_decoded_type: ClassVar[type] = datetime + + +class SegmentWindow(Window): + """ + A fixed categorical set of string keys that constitute one downstream period. + + Authoring marker for the scheduler-side + :class:`airflow.partition_mappers.window.SegmentWindow`. Paired with + :class:`~airflow.sdk.definitions.partition_mappers.fixed_key.FixedKeyMapper` inside a + :class:`~airflow.sdk.definitions.partition_mappers.base.RollupMapper` to express a + categorical rollup. + + Construction validates the segment list so Dag parse errors surface + immediately rather than deferring to scheduler deserialization. + + :param segments: Non-empty iterable of non-empty string segment keys. Duplicates + are silently de-duplicated. + :raises ValueError: if *segments* is empty, contains a non-``str`` element, or + contains an empty-string element. 
+ """ + + expected_decoded_type: ClassVar[type] = str + + def __init__(self, segments: Iterable[str]) -> None: + collected: list[str] = list(segments) + if not collected: + raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.") + for i, item in enumerate(collected): + if not isinstance(item, str): + raise ValueError( + f"SegmentWindow segment keys must be str; " + f"got {type(item).__name__!r} at index {i}: {item!r}" + ) + if item == "": + raise ValueError( + f"SegmentWindow segment keys must be non-empty strings; got an empty string at index {i}." + ) + self._segments: frozenset[str] = frozenset(collected) diff --git a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py index 8b902820da0d6..0f6c1e1a7a475 100644 --- a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py +++ b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py @@ -22,12 +22,14 @@ import pytest from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper +from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper from airflow.sdk.definitions.partition_mappers.temporal import StartOfDayMapper from airflow.sdk.definitions.partition_mappers.window import ( DayWindow, HourWindow, MonthWindow, QuarterWindow, + SegmentWindow, WeekWindow, Window, YearWindow, @@ -103,3 +105,69 @@ class TestSdkWindowExpectedDecodedType: ) def test_temporal_windows_declare_datetime(self, window_cls): assert window_cls.expected_decoded_type is datetime + + +class TestSdkFixedKeyMapper: + """SDK-side FixedKeyMapper construction and validation.""" + + def test_to_downstream_returns_constant_for_any_key(self): + m = FixedKeyMapper("all_regions") + assert m.to_downstream("us") == "all_regions" + assert m.to_downstream("eu") == "all_regions" + assert m.to_downstream("apac") == "all_regions" + + def test_is_rollup_false(self): + assert 
FixedKeyMapper("all").is_rollup is False + + @pytest.mark.parametrize( + ("downstream_key", "match"), + [ + pytest.param("", "non-empty str", id="empty-string"), + pytest.param(None, "non-empty str", id="none"), + pytest.param(1, "non-empty str", id="int"), + ], + ) + def test_rejects_invalid_downstream_key(self, downstream_key, match): + with pytest.raises(ValueError, match=match): + FixedKeyMapper(downstream_key) + + def test_requires_downstream_key(self): + with pytest.raises(TypeError): + FixedKeyMapper() + + +class TestSdkSegmentWindow: + """SDK-side SegmentWindow construction and validation mirrors the core implementation.""" + + def test_expected_decoded_type_is_str(self): + assert SegmentWindow.expected_decoded_type is str + + def test_deduplication(self): + w = SegmentWindow(["a", "b", "a"]) + assert w._segments == frozenset({"a", "b"}) + + @pytest.mark.parametrize( + ("segments", "match"), + [ + pytest.param([], "at least one segment key", id="empty-list"), + pytest.param([1, "b"], "must be str", id="int-element"), + pytest.param(["", "b"], "non-empty strings", id="empty-string"), + ], + ) + def test_rejects_invalid_segments(self, segments, match): + with pytest.raises(ValueError, match=match): + SegmentWindow(segments) + + +class TestSdkCategoricalRollupGuard: + """SDK-side RollupMapper guard mirrors core: str mapper + str window passes.""" + + def test_fixed_key_with_segment_window_does_not_raise(self): + # SDK guard: FixedKeyMapper.expected_decoded_type is str, + # SegmentWindow.expected_decoded_type is str -> guard passes. + RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=SegmentWindow(["us", "eu"])) + + def test_str_mapper_with_datetime_window_raises(self): + # SDK guard: FixedKeyMapper (str) + DayWindow (datetime) -> raise. 
+ with pytest.raises(TypeError, match="DayWindow expects decoded values of type 'datetime'"): + RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=DayWindow()) From 553f7f4a2d2067298bc157569c4e990afa711971 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 9 Jun 2026 09:02:49 +0800 Subject: [PATCH 2/4] docs: add newsfragments --- airflow-core/newsfragments/67716.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 airflow-core/newsfragments/67716.feature.rst diff --git a/airflow-core/newsfragments/67716.feature.rst b/airflow-core/newsfragments/67716.feature.rst new file mode 100644 index 0000000000000..63daeda113ca1 --- /dev/null +++ b/airflow-core/newsfragments/67716.feature.rst @@ -0,0 +1 @@ +Add ``FixedKeyMapper`` and ``SegmentWindow`` for categorical asset-partition rollup. ``FixedKeyMapper`` collapses any upstream partition key onto a single fixed downstream key, and ``SegmentWindow`` enumerates a fixed categorical segment set the scheduler waits for. Composing ``RollupMapper(FixedKeyMapper(...), SegmentWindow(...))`` expresses a categorical rollup, mirroring the temporal rollup shape, and ``SegmentWindow`` also composes with ``FanOutMapper`` for categorical scatter. Both classes are importable from ``airflow.sdk``. From 2e3035fb81370ad5ffd94ba2a02da9eafbf7a03e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 9 Jun 2026 15:09:42 +0800 Subject: [PATCH 3/4] refactor(partition-mappers): tidy SegmentWindow validation and drop redundant copy - to_upstream returns the stored frozenset directly instead of re-wrapping it - validate segments through a generator into frozenset, checking emptiness after the cast - reword empty-segment error to 'must be non-empty' Addresses review on PR #67716; SDK-side parity tracked in the follow-up commit. 
--- .../src/airflow/partition_mappers/window.py | 17 +++++++++-------- .../tests/unit/partition_mappers/test_window.py | 4 ++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/airflow-core/src/airflow/partition_mappers/window.py b/airflow-core/src/airflow/partition_mappers/window.py index b2b8ad901499b..ecaa2dc1d0180 100644 --- a/airflow-core/src/airflow/partition_mappers/window.py +++ b/airflow-core/src/airflow/partition_mappers/window.py @@ -270,24 +270,25 @@ class SegmentWindow(Window): expected_decoded_type: ClassVar[type] = str def __init__(self, segments: Iterable[str]) -> None: - collected: list[str] = list(segments) - if not collected: - raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.") - for i, item in enumerate(collected): + def _check_one(i: int, item: str) -> str: if not isinstance(item, str): raise ValueError( f"SegmentWindow segment keys must be str; " f"got {type(item).__name__!r} at index {i}: {item!r}" ) - if item == "": + if not item: raise ValueError( - f"SegmentWindow segment keys must be non-empty strings; got an empty string at index {i}." + f"SegmentWindow segment keys must be non-empty; got an empty string at index {i}." 
) - self._segments: frozenset[str] = frozenset(collected) + return item + + self._segments: frozenset[str] = frozenset(_check_one(i, item) for i, item in enumerate(segments)) + if not self._segments: + raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.") def to_upstream(self, decoded_downstream: Any) -> frozenset[str]: """Return the full declared segment set, ignoring the downstream anchor.""" - return frozenset(self._segments) + return self._segments def serialize(self) -> dict[str, Any]: return {"segments": sorted(self._segments)} diff --git a/airflow-core/tests/unit/partition_mappers/test_window.py b/airflow-core/tests/unit/partition_mappers/test_window.py index f4aaf01b69e54..b4956761c2ce7 100644 --- a/airflow-core/tests/unit/partition_mappers/test_window.py +++ b/airflow-core/tests/unit/partition_mappers/test_window.py @@ -484,8 +484,8 @@ def test_expected_decoded_type_is_str(self): pytest.param(iter([]), "at least one segment key", id="empty-iterator"), pytest.param([1, "b"], "must be str", id="int-element"), pytest.param([None, "b"], "must be str", id="none-element"), - pytest.param(["", "b"], "non-empty strings", id="empty-string-first"), - pytest.param(["a", ""], "non-empty strings", id="empty-string-second"), + pytest.param(["", "b"], "non-empty", id="empty-string-first"), + pytest.param(["a", ""], "non-empty", id="empty-string-second"), ], ) def test_rejects_invalid_segments(self, segments, match): From 6f2a8850c7b8098541a06f9d19c07e88d9e9928e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 9 Jun 2026 15:56:03 +0800 Subject: [PATCH 4/4] =?UTF-8?q?refactor(partition-mappers):=20define=20Seg?= =?UTF-8?q?mentWindow/FixedKeyMapper=20with=20attrs=20and=20guard=20core?= =?UTF-8?q?=E2=86=94SDK=20message=20drift?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - core and SDK SegmentWindow/FixedKeyMapper now use @attrs.define with custom validators, preserving the same error 
messages and the _segments/downstream_key attribute names serialization relies on - align the SDK empty-segment error wording with core ("must be non-empty") - extend check-partition-mapper-defaults-in-sync to compare each class's field names and raise-message templates across core and SDK (following attrs converter functions), so validation drift fails CI Addresses review on PR #67716. --- .pre-commit-config.yaml | 8 +- .../airflow/partition_mappers/fixed_key.py | 15 +- .../src/airflow/partition_mappers/window.py | 44 +- .../unit/partition_mappers/test_fixed_key.py | 15 +- ...check_partition_mapper_defaults_in_sync.py | 291 +++++++++++-- ...check_partition_mapper_defaults_in_sync.py | 388 ++++++++++++++++++ .../partition_mappers/fixed_key.py | 15 +- .../definitions/partition_mappers/window.py | 43 +- .../definitions/test_partition_mappers.py | 10 +- 9 files changed, 742 insertions(+), 87 deletions(-) create mode 100644 scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1de10325dd64b..2be1cba4bd28e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -249,13 +249,17 @@ repos: pass_filenames: false require_serial: true - id: check-partition-mapper-defaults-in-sync - name: Check FanOutMapper default mapper table stays in sync (core/SDK) + name: Check partition-mapper core/SDK sync (FanOutMapper table + SegmentWindow/FixedKeyMapper) entry: ./scripts/ci/prek/check_partition_mapper_defaults_in_sync.py language: python files: > (?x) ^airflow-core/src/airflow/partition_mappers/temporal\.py$| - ^task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal\.py$ + ^airflow-core/src/airflow/partition_mappers/window\.py$| + ^airflow-core/src/airflow/partition_mappers/fixed_key\.py$| + ^task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal\.py$| + ^task-sdk/src/airflow/sdk/definitions/partition_mappers/window\.py$| + 
^task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key\.py$ pass_filenames: false require_serial: true - id: check-window-in-sync diff --git a/airflow-core/src/airflow/partition_mappers/fixed_key.py b/airflow-core/src/airflow/partition_mappers/fixed_key.py index 91c19d5102e5e..89388bfed12cd 100644 --- a/airflow-core/src/airflow/partition_mappers/fixed_key.py +++ b/airflow-core/src/airflow/partition_mappers/fixed_key.py @@ -18,9 +18,12 @@ from typing import Any +import attrs + from airflow.partition_mappers.base import PartitionMapper +@attrs.define class FixedKeyMapper(PartitionMapper): """ Collapse every upstream partition key onto one fixed downstream key. @@ -43,12 +46,12 @@ class FixedKeyMapper(PartitionMapper): :raises ValueError: if *downstream_key* is not a non-empty ``str``. """ - def __init__(self, downstream_key: str) -> None: - if not isinstance(downstream_key, str) or downstream_key == "": - raise ValueError( - f"FixedKeyMapper downstream_key must be a non-empty str; got {downstream_key!r}." 
- ) - self.downstream_key = downstream_key + downstream_key: str = attrs.field() + + @downstream_key.validator + def _validate_downstream_key(self, attribute: attrs.Attribute, value: str) -> None: + if not isinstance(value, str) or value == "": + raise ValueError(f"FixedKeyMapper downstream_key must be a non-empty str; got {value!r}.") def to_downstream(self, key: str) -> str: """Return the fixed downstream key regardless of *key*.""" diff --git a/airflow-core/src/airflow/partition_mappers/window.py b/airflow-core/src/airflow/partition_mappers/window.py index ecaa2dc1d0180..ffcf0e85ab001 100644 --- a/airflow-core/src/airflow/partition_mappers/window.py +++ b/airflow-core/src/airflow/partition_mappers/window.py @@ -21,6 +21,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any, ClassVar +import attrs + if TYPE_CHECKING: from collections.abc import Callable, Iterable @@ -248,6 +250,31 @@ def to_upstream(self, period_start: datetime) -> Iterable[datetime]: return _build_directional_steps(period_start, 12, _shift_months, self.direction) +def _convert_segments(segments: Iterable[str]) -> frozenset[str]: + """ + Validate and convert *segments* to a ``frozenset[str]``. + + Validates each element for type and non-emptiness (with index reporting) + before collapsing into a frozenset, then checks the result is non-empty. + """ + validated: list[str] = [] + for i, item in enumerate(segments): + if not isinstance(item, str): + raise ValueError( + f"SegmentWindow segment keys must be str; got {type(item).__name__!r} at index {i}: {item!r}" + ) + if not item: + raise ValueError( + f"SegmentWindow segment keys must be non-empty; got an empty string at index {i}." 
+ ) + validated.append(item) + result = frozenset(validated) + if not result: + raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.") + return result + + +@attrs.define class SegmentWindow(Window): """ A fixed categorical set of string keys that constitute one downstream period. @@ -269,22 +296,7 @@ class SegmentWindow(Window): expected_decoded_type: ClassVar[type] = str - def __init__(self, segments: Iterable[str]) -> None: - def _check_one(i: int, item: str) -> str: - if not isinstance(item, str): - raise ValueError( - f"SegmentWindow segment keys must be str; " - f"got {type(item).__name__!r} at index {i}: {item!r}" - ) - if not item: - raise ValueError( - f"SegmentWindow segment keys must be non-empty; got an empty string at index {i}." - ) - return item - - self._segments: frozenset[str] = frozenset(_check_one(i, item) for i, item in enumerate(segments)) - if not self._segments: - raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.") + _segments: frozenset[str] = attrs.field(converter=_convert_segments) def to_upstream(self, decoded_downstream: Any) -> frozenset[str]: """Return the full declared segment set, ignoring the downstream anchor.""" diff --git a/airflow-core/tests/unit/partition_mappers/test_fixed_key.py b/airflow-core/tests/unit/partition_mappers/test_fixed_key.py index 800e96a4e3aba..efdd5be331370 100644 --- a/airflow-core/tests/unit/partition_mappers/test_fixed_key.py +++ b/airflow-core/tests/unit/partition_mappers/test_fixed_key.py @@ -31,12 +31,9 @@ class TestFixedKeyMapper: - def test_to_downstream_returns_constant_for_any_key(self): - m = FixedKeyMapper("all") - assert m.to_downstream("us") == "all" - assert m.to_downstream("eu") == "all" - assert m.to_downstream("apac") == "all" - assert m.to_downstream("anything-else") == "all" + @pytest.mark.parametrize("key", ["us", "eu", "apac", "anything-else"]) + def test_to_downstream_returns_constant_for_any_key(self, key): + 
assert FixedKeyMapper("all").to_downstream(key) == "all" def test_is_rollup_false(self): # A bare FixedKeyMapper is not a rollup; rollup-ness comes from RollupMapper. @@ -86,9 +83,9 @@ def test_to_downstream_collapses_every_segment_onto_downstream_key(self): # Full-sequence equality: every declared segment key maps to the constant key. assert [self.m.to_downstream(s) for s in ("us", "eu", "apac")] == ["all", "all", "all"] - def test_to_upstream_returns_full_set_ignoring_anchor(self): - assert self.m.to_upstream("all") == frozenset({"us", "eu", "apac"}) - assert self.m.to_upstream("anything") == frozenset({"us", "eu", "apac"}) + @pytest.mark.parametrize("anchor", ["all", "anything"]) + def test_to_upstream_returns_full_set_ignoring_anchor(self, anchor): + assert self.m.to_upstream(anchor) == frozenset({"us", "eu", "apac"}) def test_core_encode_decode_round_trip(self): restored = decode_partition_mapper(encode_partition_mapper(self.m)) diff --git a/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py b/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py index 0c1adea966940..24f7520f1b5a6 100755 --- a/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py +++ b/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py @@ -23,24 +23,25 @@ # ] # /// """ -Verify ``FanOutMapper.default_downstream_mapper_by_window_name`` stays in sync -between core and the Task SDK. +Verify partition-mapper definitions stay in sync between core and the Task SDK. -The default downstream-mapper table is defined twice — once in the core class -hierarchy and once in the SDK copy — because the two hierarchies are -independent (the SDK cannot import core) and the lookup is by ``Window`` class -*name*. Both copies must list the same ``Window`` name -> mapper class mapping, -otherwise a ``FanOutMapper`` resolves a different default depending on whether -it runs in Dag-author code (SDK) or after deserialization (core). 
+Checks two things: -This check parses the ``default_downstream_mapper_by_window_name`` class -attribute from both files via AST and asserts the two mappings are identical. +1. ``FanOutMapper.default_downstream_mapper_by_window_name`` — the default + downstream-mapper table is defined twice (core and SDK) because the two + hierarchies are independent. Both copies must list the same + ``Window`` name -> mapper class mapping. + +2. ``SegmentWindow`` and ``FixedKeyMapper`` — for each class, the field-name + set and the ``raise ValueError(...)`` message-template set must be identical + between core and the SDK. This catches wording drift (e.g. "non-empty" vs + "non-empty strings") that would otherwise silently diverge. Run from the repo root: uv run --project scripts python scripts/ci/prek/check_partition_mapper_defaults_in_sync.py -Exits 0 if the two tables match, 1 (with a diff) otherwise. +Exits 0 if everything matches, 1 (with a diff) otherwise. """ from __future__ import annotations @@ -65,6 +66,15 @@ AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / "partition_mappers" / "temporal.py" ) +CORE_WINDOW_FILE = AIRFLOW_CORE_SOURCES_PATH / "airflow" / "partition_mappers" / "window.py" +SDK_WINDOW_FILE = ( + AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / "partition_mappers" / "window.py" +) +CORE_FIXED_KEY_FILE = AIRFLOW_CORE_SOURCES_PATH / "airflow" / "partition_mappers" / "fixed_key.py" +SDK_FIXED_KEY_FILE = ( + AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / "partition_mappers" / "fixed_key.py" +) + def _find_attr_value(file_path: Path) -> ast.Dict: """Return the AST node assigned to ``FanOutMapper.default_downstream_mapper_by_window_name``.""" @@ -113,7 +123,200 @@ def _extract_mapping(file_path: Path) -> dict[str, str]: return mapping +def _joinedstr_to_template(node: ast.JoinedStr) -> str: + """ + Convert an f-string AST node to a template string. 
+ + Literal text fragments are kept as-is; each interpolated expression + ``{...}`` is replaced by a ``{}`` placeholder. This lets us compare + f-string message templates between core and SDK without caring about the + exact expression inside the braces (e.g. ``{type(item).__name__!r}`` vs + a future refactoring). + """ + parts: list[str] = [] + for value in node.values: + if isinstance(value, ast.Constant): + parts.append(str(value.value)) + else: + parts.append("{}") + return "".join(parts) + + +def _extract_raise_messages(subtree: ast.AST) -> set[str]: + """ + Collect ``raise ValueError(...)`` message templates from an arbitrary AST subtree. + + For each ``raise ValueError(...)`` whose first argument is a plain string + constant, an f-string, or an explicit ``+`` concatenation of those + (``BinOp(Add, ...)``), extracts the template: + + - ``ast.Constant`` → the literal string value. + - ``ast.JoinedStr`` (f-string) → literal fragments joined, interpolated + expressions replaced by ``{}``. + - ``ast.BinOp(Add, ...)`` → recursively concatenated from both sides.
+ """ + messages: set[str] = set() + + def _collect_bin(n: ast.expr) -> str: + if isinstance(n, ast.Constant) and isinstance(n.value, str): + return n.value + if isinstance(n, ast.JoinedStr): + return _joinedstr_to_template(n) + if isinstance(n, ast.BinOp) and isinstance(n.op, ast.Add): + return _collect_bin(n.left) + _collect_bin(n.right) + return "{}" + + for node in ast.walk(subtree): + if not isinstance(node, ast.Raise): + continue + exc = node.exc + if exc is None: + continue + if not ( + isinstance(exc, ast.Call) + and isinstance(exc.func, ast.Name) + and exc.func.id == "ValueError" + and exc.args + ): + continue + arg = exc.args[0] + if isinstance(arg, ast.Constant) and isinstance(arg.value, str): + messages.add(arg.value) + elif isinstance(arg, ast.JoinedStr): + messages.add(_joinedstr_to_template(arg)) + elif isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.Add): + messages.add(_collect_bin(arg)) + + return messages + + +def _collect_converter_names(class_node: ast.ClassDef) -> set[str]: + """ + Return function names referenced via ``converter=`` in attrs field definitions. + + Scans the top-level statements of *class_node* for annotated assignments of + the form ``NAME: TYPE = attrs.field(converter=FUNC)`` (or + ``attr.field`` / bare ``field(...)``). Also handles unannotated + ``Assign`` nodes with an ``attrs.field(converter=FUNC)`` call on the RHS. + Only bare ``ast.Name`` references are collected; lambdas and attribute + lookups are ignored (they cannot resolve to a module-level function by + name). + """ + names: set[str] = set() + + def _check_call(call_node: ast.expr) -> None: + """If *call_node* is an attrs.field / attr.field / field() call, collect converter=.""" + if not isinstance(call_node, ast.Call): + return + func = call_node.func + is_field_call = ( + # attrs.field(...) or attr.field(...) + (isinstance(func, ast.Attribute) and func.attr == "field") + # bare field(...)
+ or (isinstance(func, ast.Name) and func.id == "field") + ) + if not is_field_call: + return + for kw in call_node.keywords: + if kw.arg == "converter" and isinstance(kw.value, ast.Name): + names.add(kw.value.id) + + for stmt in class_node.body: + if isinstance(stmt, ast.AnnAssign) and stmt.value is not None: + _check_call(stmt.value) + elif isinstance(stmt, ast.Assign): + for target_value in [stmt.value]: + _check_call(target_value) + + return names + + +def extract_class_error_messages(file_path: Path, class_name: str) -> set[str]: + """ + Return the set of ``raise ValueError(...)`` message templates for *class_name*. + + Scans two sources: + + 1. The class body itself (all depths) — covers validator methods and nested + helpers declared inside the class. + 2. Module-level functions referenced via ``converter=`` in + ``attrs.field(...)`` declarations inside the class body. These + converters live outside the class but are logically part of its + construction-time validation. + + For each ``raise ValueError(...)`` whose first argument is a plain string + constant or an f-string, extracts the template: + + - ``ast.Constant`` → the literal string value. + - ``ast.JoinedStr`` (f-string) → literal fragments joined, interpolated + expressions replaced by ``{}``. + """ + tree = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path)) + + target_class: ast.ClassDef | None = None + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + target_class = node + break + + if target_class is None: + return set() + + # 1. Collect messages from the class body. + messages = _extract_raise_messages(target_class) + + # 2. Follow converter= references to module-level functions. + converter_names = _collect_converter_names(target_class) + if converter_names: + # Only look at top-level (module-body) function definitions to avoid + # accidentally matching a same-named method inside another class. 
+ for node in tree.body: + if isinstance(node, ast.FunctionDef) and node.name in converter_names: + messages |= _extract_raise_messages(node) + + return messages + + +def extract_class_field_names(file_path: Path, class_name: str) -> set[str]: + """ + Return the set of annotated attribute names declared in *class_name*'s body. + + Collects names from annotated assignments (``name: type = ...`` or + ``name: type``) at the top level of the class body. ClassVar annotations + are excluded because they are class-level constants, not instance fields. + """ + tree = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path)) + fields: set[str] = set() + + for node in ast.walk(tree): + if not (isinstance(node, ast.ClassDef) and node.name == class_name): + continue + for stmt in node.body: + if not isinstance(stmt, ast.AnnAssign): + continue + if not isinstance(stmt.target, ast.Name): + continue + # Skip ClassVar[...] annotations + ann = stmt.annotation + is_classvar = False + if isinstance(ann, ast.Subscript): + if isinstance(ann.value, ast.Name) and ann.value.id == "ClassVar": + is_classvar = True + elif isinstance(ann.value, ast.Attribute) and ann.value.attr == "ClassVar": + is_classvar = True + elif isinstance(ann, ast.Name) and ann.id == "ClassVar": + is_classvar = True + if is_classvar: + continue + fields.add(stmt.target.id) + break + + return fields + + def main() -> int: + failed = False + try: core_mapping = _extract_mapping(CORE_FILE) sdk_mapping = _extract_mapping(SDK_FILE) @@ -121,22 +324,56 @@ def main() -> int: console.print(f"[red]Could not read the default mapper table:[/red] {exc}") return 1 - if core_mapping == sdk_mapping: - return 0 - - console.print(f"[red]{CLASS_NAME}.{ATTR_NAME} is out of sync between core and the Task SDK.[/red]\n") - all_windows = sorted(core_mapping.keys() | sdk_mapping.keys()) - for window in all_windows: - core_val = core_mapping.get(window, "") - sdk_val = sdk_mapping.get(window, "") - marker = " " if core_val == 
sdk_val else "->" - color = "" if core_val == sdk_val else "[red]" - end = "" if core_val == sdk_val else "[/red]" - console.print(f"{color}{marker} {window}: core={core_val} sdk={sdk_val}{end}") - console.print( - f"\nMake both tables list the same window -> mapper entries:\n core: {CORE_FILE}\n sdk: {SDK_FILE}" - ) - return 1 + if core_mapping != sdk_mapping: + console.print(f"[red]{CLASS_NAME}.{ATTR_NAME} is out of sync between core and the Task SDK.[/red]\n") + all_windows = sorted(core_mapping.keys() | sdk_mapping.keys()) + for window in all_windows: + core_val = core_mapping.get(window, "") + sdk_val = sdk_mapping.get(window, "") + marker = " " if core_val == sdk_val else "->" + color = "" if core_val == sdk_val else "[red]" + end = "" if core_val == sdk_val else "[/red]" + console.print(f"{color}{marker} {window}: core={core_val} sdk={sdk_val}{end}") + console.print( + f"\nMake both tables list the same window -> mapper entries:\n core: {CORE_FILE}\n sdk: {SDK_FILE}" + ) + failed = True + + checks = [ + ("SegmentWindow", CORE_WINDOW_FILE, SDK_WINDOW_FILE), + ("FixedKeyMapper", CORE_FIXED_KEY_FILE, SDK_FIXED_KEY_FILE), + ] + + for class_name, core_file, sdk_file in checks: + core_fields = extract_class_field_names(core_file, class_name) + sdk_fields = extract_class_field_names(sdk_file, class_name) + if core_fields != sdk_fields: + console.print(f"[red]{class_name}: field names are out of sync between core and SDK.[/red]") + core_only = core_fields - sdk_fields + sdk_only = sdk_fields - core_fields + if core_only: + console.print(f" core-only fields: {sorted(core_only)}") + if sdk_only: + console.print(f" sdk-only fields: {sorted(sdk_only)}") + console.print(f" core: {core_file}\n sdk: {sdk_file}") + failed = True + + core_msgs = extract_class_error_messages(core_file, class_name) + sdk_msgs = extract_class_error_messages(sdk_file, class_name) + if core_msgs != sdk_msgs: + console.print( + f"[red]{class_name}: raise ValueError(...) 
message templates are out of sync.[/red]" + ) + core_only = core_msgs - sdk_msgs + sdk_only = sdk_msgs - core_msgs + if core_only: + console.print(f" core-only messages: {sorted(core_only)}") + if sdk_only: + console.print(f" sdk-only messages: {sorted(sdk_only)}") + console.print(f" core: {core_file}\n sdk: {sdk_file}") + failed = True + + return 1 if failed else 0 if __name__ == "__main__": diff --git a/scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py b/scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py new file mode 100644 index 0000000000000..aa1742d33a0a6 --- /dev/null +++ b/scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py @@ -0,0 +1,388 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import textwrap +from pathlib import Path + +import pytest +from check_partition_mapper_defaults_in_sync import ( + extract_class_error_messages, + extract_class_field_names, +) + + +class TestExtractClassErrorMessages: + @pytest.mark.parametrize( + ("raise_stmt", "expected"), + [ + pytest.param('raise ValueError("msg a")', {"msg a"}, id="plain-string"), + pytest.param('raise ValueError(f"msg b at {i}")', {"msg b at {}"}, id="fstring-template"), + ], + ) + def test_extracts_message(self, tmp_path: Path, raise_stmt: str, expected: set[str]): + """Plain strings are kept verbatim; f-string interpolations become {} placeholders.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent(f"""\ + class MyClass: + def validate(self, x, i): + {raise_stmt} + """) + ) + assert extract_class_error_messages(f, "MyClass") == expected + + def test_extracts_both_plain_and_fstring(self, tmp_path: Path): + """Covers the explicit test requirement: plain + f-string together.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + class MyClass: + def check(self, item, i): + raise ValueError("msg a") + + def check2(self, item, i): + raise ValueError(f"msg b at {i}") + """) + ) + result = extract_class_error_messages(f, "MyClass") + assert result == {"msg a", "msg b at {}"} + + def test_ignores_other_exception_types(self, tmp_path: Path): + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + class MyClass: + def check(self): + raise TypeError("not a value error") + """) + ) + result = extract_class_error_messages(f, "MyClass") + assert result == set() + + def test_returns_empty_for_missing_class(self, tmp_path: Path): + f = tmp_path / "code.py" + f.write_text("x = 1\n") + result = extract_class_error_messages(f, "Missing") + assert result == set() + + def test_nested_function_messages_included(self, tmp_path: Path): + """Messages raised inside nested helpers within the class body are collected.""" + f = tmp_path /
"code.py" + f.write_text( + textwrap.dedent("""\ + class MyClass: + def __init__(self, items): + def _check(i, item): + if not isinstance(item, str): + raise ValueError(f"must be str; got {type(item).__name__!r} at {i}") + for i, item in enumerate(items): + _check(i, item) + """) + ) + result = extract_class_error_messages(f, "MyClass") + assert result == {"must be str; got {} at {}"} + + def test_multipart_fstring_message(self, tmp_path: Path): + """f-string with multiple literal parts and multiple interpolations.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + class C: + def v(self, item, i): + raise ValueError( + f"Prefix segment keys must be str; " + f"got {type(item).__name__!r} at index {i}: {item!r}" + ) + """) + ) + result = extract_class_error_messages(f, "C") + # The two adjacent f-string literals are implicitly concatenated by the parser + # into a single JoinedStr node, which the extractor renders as one template. + assert "Prefix segment keys must be str; got {} at index {}: {}" in result + + +class TestExtractClassFieldNames: + def test_extracts_annotated_instance_fields(self, tmp_path: Path): + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + from typing import ClassVar + import attrs + + @attrs.define + class MyClass: + expected_decoded_type: ClassVar[type] = str + _segments: frozenset[str] = attrs.field() + """) + ) + result = extract_class_field_names(f, "MyClass") + # ClassVar is excluded; _segments is included + assert result == {"_segments"} + + def test_excludes_classvar_fields(self, tmp_path: Path): + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + from typing import ClassVar + + class MyClass: + flag: ClassVar[bool] = False + name: str + """) + ) + result = extract_class_field_names(f, "MyClass") + assert result ==
set() + + def test_private_field_name_preserved(self, tmp_path: Path): + """Leading underscore in field name is kept verbatim.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + import attrs + + @attrs.define + class Container: + _segments: frozenset[str] = attrs.field() + """) + ) + result = extract_class_field_names(f, "Container") + assert "_segments" in result + + +class TestInSyncPasses: + def test_in_sync_passes(self, tmp_path: Path): + """Two files with identical field names and messages compare equal.""" + code = textwrap.dedent("""\ + import attrs + + @attrs.define + class MyMapper: + downstream_key: str = attrs.field() + + def validate(self, x): + raise ValueError(f"must be non-empty str; got {x!r}.") + """) + core_file = tmp_path / "core.py" + sdk_file = tmp_path / "sdk.py" + core_file.write_text(code) + sdk_file.write_text(code) + + core_fields = extract_class_field_names(core_file, "MyMapper") + sdk_fields = extract_class_field_names(sdk_file, "MyMapper") + assert core_fields == sdk_fields + + core_msgs = extract_class_error_messages(core_file, "MyMapper") + sdk_msgs = extract_class_error_messages(sdk_file, "MyMapper") + assert core_msgs == sdk_msgs + + +class TestConverterPatternExtraction: + """Tests for the module-level converter= follow-through in extract_class_error_messages.""" + + def test_extracts_messages_from_module_level_converter(self, tmp_path: Path): + """Converter function outside the class body is followed and its messages are collected.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + import attrs + + def _my_convert(items): + for i, item in enumerate(items): + if not isinstance(item, str): + raise ValueError(f"must be str; got {type(item).__name__!r} at {i}") + if not items: + raise ValueError("must not be empty") + return frozenset(items) + + @attrs.define + class MyClass: + _data: frozenset = attrs.field(converter=_my_convert) + """) + ) + result = extract_class_error_messages(f, "MyClass") 
+ assert "must be str; got {} at {}" in result + assert "must not be empty" in result + + def test_converter_fstring_template_placeholders(self, tmp_path: Path): + """f-string expressions in the converter become {} placeholders in the template.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + import attrs + + def _convert(items): + for i, item in enumerate(items): + raise ValueError(f"bad item {item!r} at index {i}") + return frozenset(items) + + @attrs.define + class Widget: + _items: frozenset = attrs.field(converter=_convert) + """) + ) + result = extract_class_error_messages(f, "Widget") + assert "bad item {} at index {}" in result + + def test_converter_messages_not_collected_for_unrelated_class(self, tmp_path: Path): + """The converter is only followed for the class that references it.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + import attrs + + def _my_convert(items): + raise ValueError("converter error") + return frozenset(items) + + @attrs.define + class ClassA: + _data: frozenset = attrs.field(converter=_my_convert) + + @attrs.define + class ClassB: + _name: str = attrs.field() + """) + ) + # ClassA references the converter — should see the message + result_a = extract_class_error_messages(f, "ClassA") + assert "converter error" in result_a + + # ClassB does not reference the converter — should NOT see the message + result_b = extract_class_error_messages(f, "ClassB") + assert "converter error" not in result_b + + +class TestConverterDivergenceDetected: + """Verify that message drift in a module-level converter is caught by the extractor.""" + + def test_divergent_converter_message_detected(self, tmp_path: Path): + """Changing 'non-empty' to 'non-empty strings' in a converter is detected as drift.""" + core_code = textwrap.dedent("""\ + import attrs + + def _convert_segments(segments): + for i, item in enumerate(segments): + if not item: + raise ValueError(f"keys must be non-empty; got empty at {i}.") + 
return frozenset(segments) + + @attrs.define + class SegmentWindow: + _segments: frozenset = attrs.field(converter=_convert_segments) + """) + sdk_code_diverged = textwrap.dedent("""\ + import attrs + + def _convert_segments(segments): + for i, item in enumerate(segments): + if not item: + raise ValueError(f"keys must be non-empty strings; got empty at {i}.") + return frozenset(segments) + + @attrs.define + class SegmentWindow: + _segments: frozenset = attrs.field(converter=_convert_segments) + """) + core_file = tmp_path / "core.py" + sdk_file = tmp_path / "sdk.py" + core_file.write_text(core_code) + sdk_file.write_text(sdk_code_diverged) + + core_msgs = extract_class_error_messages(core_file, "SegmentWindow") + sdk_msgs = extract_class_error_messages(sdk_file, "SegmentWindow") + assert core_msgs != sdk_msgs + + +class TestDivergentMessageFails: + def test_divergent_message_detected(self, tmp_path: Path): + """Changing 'non-empty' to 'non-empty strings' on one side is detected.""" + core_code = textwrap.dedent("""\ + import attrs + + @attrs.define + class SegmentWindow: + _segments: frozenset[str] = attrs.field() + + def validate(self, item, i): + raise ValueError(f"keys must be non-empty; got empty at {i}.") + """) + sdk_code_diverged = textwrap.dedent("""\ + import attrs + + @attrs.define + class SegmentWindow: + _segments: frozenset[str] = attrs.field() + + def validate(self, item, i): + raise ValueError(f"keys must be non-empty strings; got empty at {i}.") + """) + core_file = tmp_path / "core.py" + sdk_file = tmp_path / "sdk.py" + core_file.write_text(core_code) + sdk_file.write_text(sdk_code_diverged) + + core_msgs = extract_class_error_messages(core_file, "SegmentWindow") + sdk_msgs = extract_class_error_messages(sdk_file, "SegmentWindow") + assert core_msgs != sdk_msgs + + @pytest.mark.parametrize( + ("core_msg", "sdk_msg"), + [ + pytest.param( + "keys must be non-empty; got empty at {}.", + "keys must be non-empty strings; got empty at {}.", + 
id="non-empty-vs-non-empty-strings", + ), + pytest.param( + "requires at least one key; got an empty iterable.", + "requires at least one key.", + id="different-constant-wording", + ), + ], + ) + def test_parametrized_divergence(self, tmp_path: Path, core_msg: str, sdk_msg: str): + def _make_file(path: Path, msg: str) -> None: + # Use a plain f-string if the message contains '{}', else a constant. + if "{}" in msg: + # Reconstruct as f-string source: replace {} with {i} + src_msg = msg.replace("{}", "{i}") + stmt = f'raise ValueError(f"{src_msg}")' + else: + stmt = f'raise ValueError("{msg}")' + path.write_text( + textwrap.dedent(f"""\ + class C: + def v(self, i): + {stmt} + """) + ) + + core_file = tmp_path / "core.py" + sdk_file = tmp_path / "sdk.py" + _make_file(core_file, core_msg) + _make_file(sdk_file, sdk_msg) + + core_msgs = extract_class_error_messages(core_file, "C") + sdk_msgs = extract_class_error_messages(sdk_file, "C") + assert core_msgs != sdk_msgs diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py index 24a1212e1076b..bf2b07905a818 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py @@ -16,9 +16,12 @@ # under the License. from __future__ import annotations +import attrs + from airflow.sdk.definitions.partition_mappers.base import PartitionMapper +@attrs.define class FixedKeyMapper(PartitionMapper): """ Collapse every upstream partition key onto one fixed downstream key. @@ -37,12 +40,12 @@ class FixedKeyMapper(PartitionMapper): :raises ValueError: if *downstream_key* is not a non-empty ``str``. """ - def __init__(self, downstream_key: str) -> None: - if not isinstance(downstream_key, str) or downstream_key == "": - raise ValueError( - f"FixedKeyMapper downstream_key must be a non-empty str; got {downstream_key!r}." 
- ) - self.downstream_key = downstream_key + downstream_key: str = attrs.field() + + @downstream_key.validator + def _validate_downstream_key(self, attribute: attrs.Attribute, value: str) -> None: + if not isinstance(value, str) or value == "": + raise ValueError(f"FixedKeyMapper downstream_key must be a non-empty str; got {value!r}.") def to_downstream(self, key: str) -> str: """Return the fixed downstream key regardless of *key*.""" diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py index 9a02f46b873ad..f491d4279ec7a 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py @@ -24,6 +24,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any, ClassVar +import attrs + if TYPE_CHECKING: from collections.abc import Iterable @@ -119,6 +121,31 @@ class YearWindow(Window): expected_decoded_type: ClassVar[type] = datetime +def _convert_segments(segments: Iterable[str]) -> frozenset[str]: + """ + Validate and convert *segments* to a ``frozenset[str]``. + + Validates each element for type and non-emptiness (with index reporting) + before collapsing into a frozenset, then checks the result is non-empty. + """ + validated: list[str] = [] + for i, item in enumerate(segments): + if not isinstance(item, str): + raise ValueError( + f"SegmentWindow segment keys must be str; got {type(item).__name__!r} at index {i}: {item!r}" + ) + if not item: + raise ValueError( + f"SegmentWindow segment keys must be non-empty; got an empty string at index {i}." + ) + validated.append(item) + result = frozenset(validated) + if not result: + raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.") + return result + + +@attrs.define class SegmentWindow(Window): """ A fixed categorical set of string keys that constitute one downstream period. 
@@ -140,18 +167,4 @@ class SegmentWindow(Window): expected_decoded_type: ClassVar[type] = str - def __init__(self, segments: Iterable[str]) -> None: - collected: list[str] = list(segments) - if not collected: - raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.") - for i, item in enumerate(collected): - if not isinstance(item, str): - raise ValueError( - f"SegmentWindow segment keys must be str; " - f"got {type(item).__name__!r} at index {i}: {item!r}" - ) - if item == "": - raise ValueError( - f"SegmentWindow segment keys must be non-empty strings; got an empty string at index {i}." - ) - self._segments: frozenset[str] = frozenset(collected) + _segments: frozenset[str] = attrs.field(converter=_convert_segments) diff --git a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py index 0f6c1e1a7a475..7114b65d6c4ba 100644 --- a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py +++ b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py @@ -110,11 +110,9 @@ def test_temporal_windows_declare_datetime(self, window_cls): class TestSdkFixedKeyMapper: """SDK-side FixedKeyMapper construction and validation.""" - def test_to_downstream_returns_constant_for_any_key(self): - m = FixedKeyMapper("all_regions") - assert m.to_downstream("us") == "all_regions" - assert m.to_downstream("eu") == "all_regions" - assert m.to_downstream("apac") == "all_regions" + @pytest.mark.parametrize("key", ["us", "eu", "apac"]) + def test_to_downstream_returns_constant_for_any_key(self, key): + assert FixedKeyMapper("all_regions").to_downstream(key) == "all_regions" def test_is_rollup_false(self): assert FixedKeyMapper("all").is_rollup is False @@ -151,7 +149,7 @@ def test_deduplication(self): [ pytest.param([], "at least one segment key", id="empty-list"), pytest.param([1, "b"], "must be str", id="int-element"), - pytest.param(["", "b"], "non-empty strings", 
id="empty-string"), + pytest.param(["", "b"], "non-empty", id="empty-string"), ], ) def test_rejects_invalid_segments(self, segments, match):