diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1de10325dd64b..2be1cba4bd28e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -249,13 +249,17 @@ repos: pass_filenames: false require_serial: true - id: check-partition-mapper-defaults-in-sync - name: Check FanOutMapper default mapper table stays in sync (core/SDK) + name: Check partition-mapper core/SDK sync (FanOutMapper table + SegmentWindow/FixedKeyMapper) entry: ./scripts/ci/prek/check_partition_mapper_defaults_in_sync.py language: python files: > (?x) ^airflow-core/src/airflow/partition_mappers/temporal\.py$| - ^task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal\.py$ + ^airflow-core/src/airflow/partition_mappers/window\.py$| + ^airflow-core/src/airflow/partition_mappers/fixed_key\.py$| + ^task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal\.py$| + ^task-sdk/src/airflow/sdk/definitions/partition_mappers/window\.py$| + ^task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key\.py$ pass_filenames: false require_serial: true - id: check-window-in-sync diff --git a/airflow-core/docs/authoring-and-scheduling/assets.rst b/airflow-core/docs/authoring-and-scheduling/assets.rst index 7bf609e9d9834..2bc06a3011721 100644 --- a/airflow-core/docs/authoring-and-scheduling/assets.rst +++ b/airflow-core/docs/authoring-and-scheduling/assets.rst @@ -565,6 +565,12 @@ downstream Dag partition key: passes the key through unchanged if valid. For example, ``AllowedKeyMapper(["us", "eu", "apac"])`` accepts only those region keys and rejects all others. +* ``FixedKeyMapper`` collapses every upstream key onto a fixed downstream key, + regardless of the upstream value. +* ``SegmentWindow`` declares a fixed categorical set of string keys (e.g. regions, + tenants) that constitute one downstream period; paired with ``FixedKeyMapper`` + inside a ``RollupMapper`` it holds the downstream run until every declared segment + has arrived (see :ref:`segment-rollup `). Example of per-asset mapper configuration and composite-key mapping: @@ -733,6 +739,83 @@ so the run is held indefinitely) and the fall-back day has twenty-five (the repe hour is dropped). Use a UTC-based upstream mapper for any rollup that crosses a DST boundary; see the ``DayWindow`` class docstring for the full discussion. +.. _segment-categorical-rollup: + +Segment (categorical) rollup +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 3.3.0 + +For categorical partitioning — regions, tenants, experiment variants — compose a +``RollupMapper`` from two primitives: + +* ``SegmentWindow(["us", "eu", "apac"])`` declares the fixed set of string keys + that constitute one downstream period; ``to_upstream`` returns the full set + regardless of the downstream anchor. +* ``FixedKeyMapper("all_regions")`` collapses every upstream key onto the single + downstream partition key ``"all_regions"``. + +The scheduler holds the downstream Dag run until every declared segment has arrived +from the upstream producer, then fires once. All the segment events accumulate into +one ``AssetPartitionDagRun``; the fired run's ``partition_key`` is the value passed +to ``FixedKeyMapper``. This composition only makes sense under ``WAIT_FOR_ALL`` +semantics (the default). + +.. code-block:: python + + from airflow.sdk import ( + DAG, + Asset, + FixedKeyMapper, + PartitionAtRuntime, + PartitionedAssetTimetable, + RollupMapper, + SegmentWindow, + asset, + task, + ) + + + @asset( + uri="file://incoming/player-stats/multi-region.csv", + schedule=PartitionAtRuntime(), + ) + def multi_region_player_stats(self, outlet_events): + # Emit one event per region in a single run. + outlet_events[self].add_partitions(["us", "eu", "apac"]) + + + # Consumer: fires once all three region partitions have arrived. + with DAG( + dag_id="segment_region_stats_rollup", + schedule=PartitionedAssetTimetable( + assets=Asset.ref(name="multi_region_player_stats"), + default_partition_mapper=RollupMapper( + upstream_mapper=FixedKeyMapper("all_regions"), + window=SegmentWindow(["us", "eu", "apac"]), + ), + ), + catchup=False, + ): + + @task + def aggregate_all_regions(dag_run=None): + # dag_run.partition_key is the downstream key once all segments arrive. + print(dag_run.partition_key) + + aggregate_all_regions() + +Construction validates both components: ``SegmentWindow`` raises ``ValueError`` for +an empty list, non-string items, or empty-string keys; duplicate entries are silently +deduplicated. ``FixedKeyMapper`` raises ``ValueError`` if its argument is not a +non-empty string. Pass a distinct ``FixedKeyMapper`` key when one consumer Dag rolls +up more than one asset, so each rollup uses a distinct bucket and they do not collide +on the same ``(target_dag_id, partition_key)``. + +For a segment set that must be computed at runtime, do not encode it here — evaluate +completeness in a consumer-side task instead (the scheduler must not run user code to +decide a partition set). + Setting partition keys at runtime ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/airflow-core/newsfragments/67716.feature.rst b/airflow-core/newsfragments/67716.feature.rst new file mode 100644 index 0000000000000..63daeda113ca1 --- /dev/null +++ b/airflow-core/newsfragments/67716.feature.rst @@ -0,0 +1 @@ +Add ``FixedKeyMapper`` and ``SegmentWindow`` for categorical asset-partition rollup. ``FixedKeyMapper`` collapses any upstream partition key onto a single fixed downstream key, and ``SegmentWindow`` enumerates a fixed categorical segment set the scheduler waits for. Composing ``RollupMapper(FixedKeyMapper(...), SegmentWindow(...))`` expresses a categorical rollup, mirroring the temporal rollup shape, and ``SegmentWindow`` also composes with ``FanOutMapper`` for categorical scatter. Both classes are importable from ``airflow.sdk``. diff --git a/airflow-core/src/airflow/example_dags/example_asset_partition.py b/airflow-core/src/airflow/example_dags/example_asset_partition.py index 775fea80ec314..65a078d3c3f9e 100644 --- a/airflow-core/src/airflow/example_dags/example_asset_partition.py +++ b/airflow-core/src/airflow/example_dags/example_asset_partition.py @@ -26,12 +26,14 @@ CronPartitionTimetable, DayWindow, FanOutMapper, + FixedKeyMapper, IdentityMapper, MonthWindow, PartitionAtRuntime, PartitionedAssetTimetable, ProductMapper, RollupMapper, + SegmentWindow, StartOfDayMapper, StartOfHourMapper, StartOfMonthMapper, @@ -409,3 +411,41 @@ def run_inference(dag_run=None): print(dag_run.partition_key) run_inference() + + +# --- Segment (categorical) rollup ------------------------------------------- +# ``multi_region_player_stats`` (defined above) emits one partition per region +# (``us``, ``eu``, ``apac``) from a single run. The Dag below holds a downstream +# run until every declared region key has arrived. + +with DAG( + dag_id="segment_region_stats_rollup", + schedule=PartitionedAssetTimetable( + assets=Asset.ref(name="multi_region_player_stats"), + default_partition_mapper=RollupMapper( + upstream_mapper=FixedKeyMapper("all_regions"), + window=SegmentWindow(["us", "eu", "apac"]), + ), + ), + catchup=False, + tags=["example", "player-stats", "rollup", "segment"], +): + """ + Categorical rollup: hold until all three region partitions arrive. + + ``RollupMapper(upstream_mapper=FixedKeyMapper("all_regions"), window=SegmentWindow([...]))`` + declares the fixed set of region keys required for one downstream run and collapses every + region key onto a single ``all_regions`` partition, so the three region events accumulate + into one downstream run. The run is held until ``us``, ``eu``, and ``apac`` have all + arrived from ``multi_region_player_stats``; partial arrivals remain pending in the + next-run-assets view so operators can track progress. + """ + + @task + def aggregate_all_regions(dag_run=None): + """Produce the cross-region summary once every region partition has arrived.""" + if TYPE_CHECKING: + assert dag_run + print(f"All region partitions received. Partition: {dag_run.partition_key}") + + aggregate_all_regions() diff --git a/airflow-core/src/airflow/partition_mappers/__init__.py b/airflow-core/src/airflow/partition_mappers/__init__.py index 9b66876e0df8a..f780603870736 100644 --- a/airflow-core/src/airflow/partition_mappers/__init__.py +++ b/airflow-core/src/airflow/partition_mappers/__init__.py @@ -19,6 +19,7 @@ from airflow.partition_mappers.allowed_key import AllowedKeyMapper from airflow.partition_mappers.base import PartitionMapper, RollupMapper from airflow.partition_mappers.chain import ChainMapper +from airflow.partition_mappers.fixed_key import FixedKeyMapper from airflow.partition_mappers.identity import IdentityMapper from airflow.partition_mappers.product import ProductMapper from airflow.partition_mappers.temporal import ( @@ -34,6 +35,7 @@ HourWindow, MonthWindow, QuarterWindow, + SegmentWindow, WeekWindow, Window, YearWindow, @@ -43,6 +45,7 @@ "AllowedKeyMapper", "ChainMapper", "DayWindow", + "FixedKeyMapper", "HourWindow", "IdentityMapper", "MonthWindow", @@ -50,6 +53,7 @@ "ProductMapper", "QuarterWindow", "RollupMapper", + "SegmentWindow", "StartOfDayMapper", "StartOfHourMapper", "StartOfMonthMapper", diff --git a/airflow-core/src/airflow/partition_mappers/fixed_key.py b/airflow-core/src/airflow/partition_mappers/fixed_key.py new file mode 100644 index 0000000000000..89388bfed12cd --- /dev/null +++ b/airflow-core/src/airflow/partition_mappers/fixed_key.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any + +import attrs + +from airflow.partition_mappers.base import PartitionMapper + + +@attrs.define +class FixedKeyMapper(PartitionMapper): + """ + Collapse every upstream partition key onto one fixed downstream key. + + Returns the same *downstream_key* for any upstream key passed to + ``to_downstream``. Does not override ``decode_downstream`` or + ``encode_upstream``, so it works with the string-based identity path and + satisfies :class:`~airflow.partition_mappers.base.RollupMapper`'s guard + when paired with :class:`~airflow.partition_mappers.window.SegmentWindow`. + + Typical use is as the ``upstream_mapper`` inside a categorical rollup:: + + RollupMapper( + upstream_mapper=FixedKeyMapper("all_regions"), + window=SegmentWindow(["us", "eu", "apac"]), + ) + + :param downstream_key: The fixed downstream partition key every upstream key + maps to. Must be a non-empty string. + :raises ValueError: if *downstream_key* is not a non-empty ``str``. + """ + + downstream_key: str = attrs.field() + + @downstream_key.validator + def _validate_downstream_key(self, attribute: attrs.Attribute, value: str) -> None: + if not isinstance(value, str) or value == "": + raise ValueError(f"FixedKeyMapper downstream_key must be a non-empty str; got {value!r}.") + + def to_downstream(self, key: str) -> str: + """Return the fixed downstream key regardless of *key*.""" + return self.downstream_key + + def serialize(self) -> dict[str, Any]: + return {"downstream_key": self.downstream_key} + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> FixedKeyMapper: + return cls(data["downstream_key"]) diff --git a/airflow-core/src/airflow/partition_mappers/window.py b/airflow-core/src/airflow/partition_mappers/window.py index 087b6180f2ade..ffcf0e85ab001 100644 --- a/airflow-core/src/airflow/partition_mappers/window.py +++ b/airflow-core/src/airflow/partition_mappers/window.py @@ -21,6 +21,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any, ClassVar +import attrs + if TYPE_CHECKING: from collections.abc import Callable, Iterable @@ -246,3 +248,63 @@ class YearWindow(Window): def to_upstream(self, period_start: datetime) -> Iterable[datetime]: _require_day_one(period_start, type(self)) return _build_directional_steps(period_start, 12, _shift_months, self.direction) + + +def _convert_segments(segments: Iterable[str]) -> frozenset[str]: + """ + Validate and convert *segments* to a ``frozenset[str]``. + + Validates each element for type and non-emptiness (with index reporting) + before collapsing into a frozenset, then checks the result is non-empty. + """ + validated: list[str] = [] + for i, item in enumerate(segments): + if not isinstance(item, str): + raise ValueError( + f"SegmentWindow segment keys must be str; got {type(item).__name__!r} at index {i}: {item!r}" + ) + if not item: + raise ValueError( + f"SegmentWindow segment keys must be non-empty; got an empty string at index {i}." + ) + validated.append(item) + result = frozenset(validated) + if not result: + raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.") + return result + + +@attrs.define +class SegmentWindow(Window): + """ + A fixed categorical set of string keys that constitute one downstream period. + + Paired with :class:`~airflow.partition_mappers.fixed_key.FixedKeyMapper` inside a + :class:`~airflow.partition_mappers.base.RollupMapper` to express a categorical + rollup: the scheduler holds the downstream run until every declared segment key + has arrived from the upstream producer, then fires once. + + ``to_upstream`` returns the complete segment set regardless of the downstream + anchor value — the anchor is intentionally ignored because all segments map onto + a single downstream partition key, not a time-based period. + + :param segments: Non-empty iterable of non-empty string segment keys. Duplicates + are silently de-duplicated. + :raises ValueError: if *segments* is empty, contains a non-``str`` element, or + contains an empty-string element. + """ + + expected_decoded_type: ClassVar[type] = str + + _segments: frozenset[str] = attrs.field(converter=_convert_segments) + + def to_upstream(self, decoded_downstream: Any) -> frozenset[str]: + """Return the full declared segment set, ignoring the downstream anchor.""" + return self._segments + + def serialize(self) -> dict[str, Any]: + return {"segments": sorted(self._segments)} + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SegmentWindow: + return cls(data["segments"]) diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py index f4ec081df1bef..0c2f030da8f78 100644 --- a/airflow-core/src/airflow/serialization/encoders.py +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -43,6 +43,7 @@ DeltaTriggerTimetable, EventsTimetable, FanOutMapper, + FixedKeyMapper, HourWindow, IdentityMapper, MonthWindow, @@ -51,6 +52,7 @@ ProductMapper, QuarterWindow, RollupMapper, + SegmentWindow, StartOfDayMapper, StartOfHourMapper, StartOfMonthMapper, @@ -437,6 +439,7 @@ def _(self, timetable: PartitionedAssetTimetable) -> dict[str, Any]: AllowedKeyMapper: "airflow.partition_mappers.allowed_key.AllowedKeyMapper", ChainMapper: "airflow.partition_mappers.chain.ChainMapper", FanOutMapper: "airflow.partition_mappers.temporal.FanOutMapper", + FixedKeyMapper: "airflow.partition_mappers.fixed_key.FixedKeyMapper", IdentityMapper: "airflow.partition_mappers.identity.IdentityMapper", ProductMapper: "airflow.partition_mappers.product.ProductMapper", RollupMapper: "airflow.partition_mappers.base.RollupMapper", @@ -464,6 +467,10 @@ def _(self, partition_mapper: ChainMapper) -> dict[str, Any]: def _(self, partition_mapper: IdentityMapper) -> dict[str, Any]: return {} + @serialize_partition_mapper.register + def _(self, partition_mapper: FixedKeyMapper) -> dict[str, Any]: + return {"downstream_key": partition_mapper.downstream_key} + @serialize_partition_mapper.register(StartOfHourMapper) @serialize_partition_mapper.register(StartOfDayMapper) @serialize_partition_mapper.register(StartOfWeekMapper) @@ -517,6 +524,7 @@ def _(self, partition_mapper: FanOutMapper) -> dict[str, Any]: WeekWindow: "airflow.partition_mappers.window.WeekWindow", MonthWindow: "airflow.partition_mappers.window.MonthWindow", QuarterWindow: "airflow.partition_mappers.window.QuarterWindow", + SegmentWindow: "airflow.partition_mappers.window.SegmentWindow", YearWindow: "airflow.partition_mappers.window.YearWindow", } @@ -538,6 +546,10 @@ def _( ) -> dict[str, Any]: return window.serialize() + @serialize_window.register + def _(self, window: SegmentWindow) -> dict[str, Any]: + return {"segments": sorted(window._segments)} + _serializer = _Serializer() diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 89383233fad02..d08628a51fbcd 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -104,9 +104,11 @@ AssetAlias, AssetWatcher, CronPartitionTimetable, + FixedKeyMapper, HourWindow, IdentityMapper, RollupMapper, + SegmentWindow, StartOfHourMapper, task, ) @@ -10273,6 +10275,74 @@ def test_partitioned_dag_run_rollup_holds_until_window_complete( assert partition_dags == {"rollup-consumer"} +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def test_partitioned_dag_run_segment_rollup_holds_until_all_segments_arrive( + dag_maker: DagMaker, + session: Session, +): + """ + A categorical (segment) rollup fires once every declared segment has arrived. + + ``RollupMapper(FixedKeyMapper("all_regions"), SegmentWindow([...]))`` collapses + each region key onto a single ``all_regions`` partition, so all three events + accumulate into one APDR, and holds the downstream run until ``us``, ``eu``, + and ``apac`` are all present. + """ + asset_1 = Asset(name="asset-1") + with dag_maker( + dag_id="segment-rollup-consumer", + schedule=PartitionedAssetTimetable( + assets=asset_1, + default_partition_mapper=RollupMapper( + upstream_mapper=FixedKeyMapper("all_regions"), + window=SegmentWindow(["us", "eu", "apac"]), + ), + ), + session=session, + ): + EmptyOperator(task_id="hi") + session.commit() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + + # First region arrives — only 1 / 3 segments, so the APDR must not fire. + # Every region collapses onto the single ``all_regions`` partition. + apdr = _produce_and_register_asset_event( + dag_id="segment-producer-us", + asset=asset_1, + partition_key="us", + session=session, + dag_maker=dag_maker, + expected_partition_key="all_regions", + ) + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + assert apdr.created_dag_run_id is None + assert partition_dags == set() + + # The remaining two regions arrive — once all three segments are present the + # rollup is satisfied and the APDR creates its Dag run on the next tick. All + # three events share the one ``all_regions`` APDR. + for region in ("eu", "apac"): + sibling = _produce_and_register_asset_event( + dag_id=f"segment-producer-{region}", + asset=asset_1, + partition_key=region, + session=session, + dag_maker=dag_maker, + expected_partition_key="all_regions", + ) + assert sibling.id == apdr.id + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + assert apdr.created_dag_run_id is not None + assert apdr.partition_key == "all_regions" + assert partition_dags == {"segment-rollup-consumer"} + + @pytest.mark.need_serialized_dag @pytest.mark.usefixtures("clear_asset_partition_rows") def test_partitioned_dag_run_rollup_treats_mapper_exception_as_not_satisfied( diff --git a/airflow-core/tests/unit/partition_mappers/test_fixed_key.py b/airflow-core/tests/unit/partition_mappers/test_fixed_key.py new file mode 100644 index 0000000000000..efdd5be331370 --- /dev/null +++ b/airflow-core/tests/unit/partition_mappers/test_fixed_key.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.partition_mappers.base import PartitionMapper, RollupMapper +from airflow.partition_mappers.fixed_key import FixedKeyMapper +from airflow.partition_mappers.window import DayWindow, SegmentWindow +from airflow.sdk import ( + FixedKeyMapper as SdkFixedKeyMapper, + RollupMapper as SdkRollupMapper, + SegmentWindow as SdkSegmentWindow, +) +from airflow.serialization.decoders import decode_partition_mapper +from airflow.serialization.encoders import encode_partition_mapper + + +class TestFixedKeyMapper: + @pytest.mark.parametrize("key", ["us", "eu", "apac", "anything-else"]) + def test_to_downstream_returns_constant_for_any_key(self, key): + assert FixedKeyMapper("all").to_downstream(key) == "all" + + def test_is_rollup_false(self): + # A bare FixedKeyMapper is not a rollup; rollup-ness comes from RollupMapper. + assert FixedKeyMapper("all").is_rollup is False + + def test_does_not_override_decode_encode(self): + m = FixedKeyMapper("all") + assert type(m).decode_downstream is PartitionMapper.decode_downstream + assert type(m).encode_upstream is PartitionMapper.encode_upstream + + @pytest.mark.parametrize( + ("downstream_key", "match"), + [ + pytest.param("", "non-empty str", id="empty-string"), + pytest.param(None, "non-empty str", id="none"), + pytest.param(1, "non-empty str", id="int"), + ], + ) + def test_rejects_invalid_downstream_key(self, downstream_key, match): + with pytest.raises(ValueError, match=match): + FixedKeyMapper(downstream_key) + + def test_requires_downstream_key(self): + with pytest.raises(TypeError): + FixedKeyMapper() + + def test_serialize_round_trip(self): + m = FixedKeyMapper("bucket") + restored = FixedKeyMapper.deserialize(m.serialize()) + assert isinstance(restored, FixedKeyMapper) + assert restored.downstream_key == "bucket" + + +class TestCategoricalRollupEquivalence: + """RollupMapper(FixedKeyMapper, SegmentWindow) behaves like old SegmentMapper.""" + + def setup_method(self): + self.m = RollupMapper( + upstream_mapper=FixedKeyMapper("all"), + window=SegmentWindow(["us", "eu", "apac"]), + ) + + def test_is_rollup_flag(self): + assert self.m.is_rollup is True + + def test_to_downstream_collapses_every_segment_onto_downstream_key(self): + # Full-sequence equality: every declared segment key maps to the constant key. + assert [self.m.to_downstream(s) for s in ("us", "eu", "apac")] == ["all", "all", "all"] + + @pytest.mark.parametrize("anchor", ["all", "anything"]) + def test_to_upstream_returns_full_set_ignoring_anchor(self, anchor): + assert self.m.to_upstream(anchor) == frozenset({"us", "eu", "apac"}) + + def test_core_encode_decode_round_trip(self): + restored = decode_partition_mapper(encode_partition_mapper(self.m)) + assert isinstance(restored, RollupMapper) + assert restored.is_rollup is True + assert restored.to_downstream("us") == "all" + assert restored.to_upstream("all") == frozenset({"us", "eu", "apac"}) + + def test_sdk_encode_decode_round_trip(self): + # User code authors with SDK classes; the scheduler serializes and deserializes + # into core classes. + sdk_mapper = SdkRollupMapper( + upstream_mapper=SdkFixedKeyMapper("all_regions"), + window=SdkSegmentWindow(["us", "eu", "apac"]), + ) + restored = decode_partition_mapper(encode_partition_mapper(sdk_mapper)) + assert isinstance(restored, RollupMapper) + assert restored.to_upstream("all_regions") == frozenset({"us", "eu", "apac"}) + + +class TestCategoricalRollupTypeGuard: + """Core-side RollupMapper guard: FixedKeyMapper(str) + SegmentWindow(str) must pass.""" + + def test_fixed_key_with_segment_window_does_not_raise(self): + # Core guard: FixedKeyMapper does not override decode_downstream, + # SegmentWindow.expected_decoded_type is str -> guard passes. + RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=SegmentWindow(["us", "eu"])) + + def test_str_mapper_with_datetime_window_raises(self): + # Core guard: FixedKeyMapper (no decode override) + DayWindow (datetime) -> raise. + with pytest.raises(TypeError, match="DayWindow expects decoded values of type 'datetime'"): + RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=DayWindow()) diff --git a/airflow-core/tests/unit/partition_mappers/test_window.py b/airflow-core/tests/unit/partition_mappers/test_window.py index ddfd44cbd76f7..b4956761c2ce7 100644 --- a/airflow-core/tests/unit/partition_mappers/test_window.py +++ b/airflow-core/tests/unit/partition_mappers/test_window.py @@ -34,6 +34,7 @@ HourWindow, MonthWindow, QuarterWindow, + SegmentWindow, WeekWindow, Window, YearWindow, @@ -465,6 +466,47 @@ def test_invalid_direction_raises_value_error(self, bad_value): WeekWindow(direction=bad_value) +class TestSegmentWindow: + def test_to_upstream_returns_full_set_ignoring_anchor(self): + w = SegmentWindow(["us", "eu", "apac"]) + result_a = frozenset(w.to_upstream("any-anchor")) + result_b = frozenset(w.to_upstream("different-anchor")) + assert result_a == frozenset({"us", "eu", "apac"}) + assert result_a == result_b + + def test_expected_decoded_type_is_str(self): + assert SegmentWindow.expected_decoded_type is str + + @pytest.mark.parametrize( + ("segments", "match"), + [ + pytest.param([], "at least one segment key", id="empty-list"), + pytest.param(iter([]), "at least one segment key", id="empty-iterator"), + pytest.param([1, "b"], "must be str", id="int-element"), + pytest.param([None, "b"], "must be str", id="none-element"), + pytest.param(["", "b"], "non-empty", id="empty-string-first"), + pytest.param(["a", ""], "non-empty", id="empty-string-second"), + ], + ) + def test_rejects_invalid_segments(self, segments, match): + with pytest.raises(ValueError, match=match): + SegmentWindow(segments) + + def test_deduplication(self): + w = SegmentWindow(["us", "us", "eu"]) + assert frozenset(w.to_upstream("any")) == frozenset({"us", "eu"}) + + def test_serialize_uses_sorted_order(self): + w = SegmentWindow(["z", "a", "m"]) + assert w.serialize() == {"segments": ["a", "m", "z"]} + + def test_deserialize_round_trip(self): + w = SegmentWindow(["us", "eu", "apac"]) + restored = SegmentWindow.deserialize(w.serialize()) + assert isinstance(restored, SegmentWindow) + assert frozenset(restored.to_upstream("any")) == frozenset({"us", "eu", "apac"}) + + class TestWindowSerializationGate: """``encode_window`` / ``decode_window`` must reject non-built-in Windows. diff --git a/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py b/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py index 0c1adea966940..24f7520f1b5a6 100755 --- a/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py +++ b/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py @@ -23,24 +23,25 @@ # ] # /// """ -Verify ``FanOutMapper.default_downstream_mapper_by_window_name`` stays in sync -between core and the Task SDK. +Verify partition-mapper definitions stay in sync between core and the Task SDK. -The default downstream-mapper table is defined twice — once in the core class -hierarchy and once in the SDK copy — because the two hierarchies are -independent (the SDK cannot import core) and the lookup is by ``Window`` class -*name*. Both copies must list the same ``Window`` name -> mapper class mapping, -otherwise a ``FanOutMapper`` resolves a different default depending on whether -it runs in Dag-author code (SDK) or after deserialization (core). +Checks two things: -This check parses the ``default_downstream_mapper_by_window_name`` class -attribute from both files via AST and asserts the two mappings are identical. +1. ``FanOutMapper.default_downstream_mapper_by_window_name`` — the default + downstream-mapper table is defined twice (core and SDK) because the two + hierarchies are independent. Both copies must list the same + ``Window`` name -> mapper class mapping. + +2. ``SegmentWindow`` and ``FixedKeyMapper`` — for each class, the field-name + set and the ``raise ValueError(...)`` message-template set must be identical + between core and the SDK. This catches wording drift (e.g. "non-empty" vs + "non-empty strings") that would otherwise silently diverge. Run from the repo root: uv run --project scripts python scripts/ci/prek/check_partition_mapper_defaults_in_sync.py -Exits 0 if the two tables match, 1 (with a diff) otherwise. +Exits 0 if everything matches, 1 (with a diff) otherwise. """ from __future__ import annotations @@ -65,6 +66,15 @@ AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / "partition_mappers" / "temporal.py" ) +CORE_WINDOW_FILE = AIRFLOW_CORE_SOURCES_PATH / "airflow" / "partition_mappers" / "window.py" +SDK_WINDOW_FILE = ( + AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / "partition_mappers" / "window.py" +) +CORE_FIXED_KEY_FILE = AIRFLOW_CORE_SOURCES_PATH / "airflow" / "partition_mappers" / "fixed_key.py" +SDK_FIXED_KEY_FILE = ( + AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / "partition_mappers" / "fixed_key.py" +) + def _find_attr_value(file_path: Path) -> ast.Dict: """Return the AST node assigned to ``FanOutMapper.default_downstream_mapper_by_window_name``.""" @@ -113,7 +123,200 @@ def _extract_mapping(file_path: Path) -> dict[str, str]: return mapping +def _joinedstr_to_template(node: ast.JoinedStr) -> str: + """ + Convert an f-string AST node to a template string. + + Literal text fragments are kept as-is; each interpolated expression + ``{...}`` is replaced by a ``{}`` placeholder. This lets us compare + f-string message templates between core and SDK without caring about the + exact expression inside the braces (e.g. ``{type(item).__name__!r}`` vs + a future refactoring). + """ + parts: list[str] = [] + for value in node.values: + if isinstance(value, ast.Constant): + parts.append(str(value.value)) + else: + parts.append("{}") + return "".join(parts) + + +def _extract_raise_messages(subtree: ast.AST) -> set[str]: + """ + Collect ``raise ValueError(...)`` message templates from an arbitrary AST subtree. + + For each ``raise ValueError(...)`` whose first argument is a plain string + constant, an f-string, or adjacent f-string/constant concatenation + (``BinOp(Add, ...)``), extracts the template: + + - ``ast.Constant`` → the literal string value. + - ``ast.JoinedStr`` (f-string) → literal fragments joined, interpolated + expressions replaced by ``{}``. + - ``ast.BinOp(Add, ...)`` → recursively concatenated from both sides. + """ + messages: set[str] = set() + + def _collect_bin(n: ast.expr) -> str: + if isinstance(n, ast.Constant) and isinstance(n.value, str): + return n.value + if isinstance(n, ast.JoinedStr): + return _joinedstr_to_template(n) + if isinstance(n, ast.BinOp) and isinstance(n.op, ast.Add): + return _collect_bin(n.left) + _collect_bin(n.right) + return "{}" + + for node in ast.walk(subtree): + if not isinstance(node, ast.Raise): + continue + exc = node.exc + if exc is None: + continue + if not ( + isinstance(exc, ast.Call) + and isinstance(exc.func, ast.Name) + and exc.func.id == "ValueError" + and exc.args + ): + continue + arg = exc.args[0] + if isinstance(arg, ast.Constant) and isinstance(arg.value, str): + messages.add(arg.value) + elif isinstance(arg, ast.JoinedStr): + messages.add(_joinedstr_to_template(arg)) + elif isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.Add): + messages.add(_collect_bin(arg)) + + return messages + + +def _collect_converter_names(class_node: ast.ClassDef) -> set[str]: + """ + Return function names referenced via ``converter=`` in attrs field definitions. + + Scans the top-level statements of *class_node* for annotated assignments of + the form ``: = attrs.field(converter=)`` (or + ``attr.field`` / bare ``field(...)``). Also handles unannotated + ``Assign`` nodes with an ``attrs.field(converter=)`` call on the RHS. + Only bare ``ast.Name`` references are collected; lambdas and attribute + lookups are ignored (they cannot resolve to a module-level function by + name). + """ + names: set[str] = set() + + def _check_call(call_node: ast.expr) -> None: + """If *call_node* is an attrs.field / attr.field / field() call, collect converter=.""" + if not isinstance(call_node, ast.Call): + return + func = call_node.func + is_field_call = ( + # attrs.field(...) or attr.field(...) + (isinstance(func, ast.Attribute) and func.attr == "field") + # bare field(...) + or (isinstance(func, ast.Name) and func.id == "field") + ) + if not is_field_call: + return + for kw in call_node.keywords: + if kw.arg == "converter" and isinstance(kw.value, ast.Name): + names.add(kw.value.id) + + for stmt in class_node.body: + if isinstance(stmt, ast.AnnAssign) and stmt.value is not None: + _check_call(stmt.value) + elif isinstance(stmt, ast.Assign): + for target_value in [stmt.value]: + _check_call(target_value) + + return names + + +def extract_class_error_messages(file_path: Path, class_name: str) -> set[str]: + """ + Return the set of ``raise ValueError(...)`` message templates for *class_name*. + + Scans two sources: + + 1. The class body itself (all depths) — covers validator methods and nested + helpers declared inside the class. + 2. Module-level functions referenced via ``converter=`` in + ``attrs.field(...)`` declarations inside the class body. These + converters live outside the class but are logically part of its + construction-time validation. + + For each ``raise ValueError(...)`` whose first argument is a plain string + constant or an f-string, extracts the template: + + - ``ast.Constant`` → the literal string value. + - ``ast.JoinedStr`` (f-string) → literal fragments joined, interpolated + expressions replaced by ``{}``. + """ + tree = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path)) + + target_class: ast.ClassDef | None = None + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + target_class = node + break + + if target_class is None: + return set() + + # 1. Collect messages from the class body. + messages = _extract_raise_messages(target_class) + + # 2. Follow converter= references to module-level functions. + converter_names = _collect_converter_names(target_class) + if converter_names: + # Only look at top-level (module-body) function definitions to avoid + # accidentally matching a same-named method inside another class. + for node in tree.body: + if isinstance(node, ast.FunctionDef) and node.name in converter_names: + messages |= _extract_raise_messages(node) + + return messages + + +def extract_class_field_names(file_path: Path, class_name: str) -> set[str]: + """ + Return the set of annotated attribute names declared in *class_name*'s body. + + Collects names from annotated assignments (``name: type = ...`` or + ``name: type``) at the top level of the class body. ClassVar annotations + are excluded because they are class-level constants, not instance fields. + """ + tree = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path)) + fields: set[str] = set() + + for node in ast.walk(tree): + if not (isinstance(node, ast.ClassDef) and node.name == class_name): + continue + for stmt in node.body: + if not isinstance(stmt, ast.AnnAssign): + continue + if not isinstance(stmt.target, ast.Name): + continue + # Skip ClassVar[...] annotations + ann = stmt.annotation + is_classvar = False + if isinstance(ann, ast.Subscript): + if isinstance(ann.value, ast.Name) and ann.value.id == "ClassVar": + is_classvar = True + elif isinstance(ann.value, ast.Attribute) and ann.value.attr == "ClassVar": + is_classvar = True + elif isinstance(ann, ast.Name) and ann.id == "ClassVar": + is_classvar = True + if is_classvar: + continue + fields.add(stmt.target.id) + break + + return fields + + def main() -> int: + failed = False + try: core_mapping = _extract_mapping(CORE_FILE) sdk_mapping = _extract_mapping(SDK_FILE) @@ -121,22 +324,56 @@ def main() -> int: console.print(f"[red]Could not read the default mapper table:[/red] {exc}") return 1 - if core_mapping == sdk_mapping: - return 0 - - console.print(f"[red]{CLASS_NAME}.{ATTR_NAME} is out of sync between core and the Task SDK.[/red]\n") - all_windows = sorted(core_mapping.keys() | sdk_mapping.keys()) - for window in all_windows: - core_val = core_mapping.get(window, "") - sdk_val = sdk_mapping.get(window, "") - marker = " " if core_val == sdk_val else "->" - color = "" if core_val == sdk_val else "[red]" - end = "" if core_val == sdk_val else "[/red]" - console.print(f"{color}{marker} {window}: core={core_val} sdk={sdk_val}{end}") - console.print( - f"\nMake both tables list the same window -> mapper entries:\n core: {CORE_FILE}\n sdk: {SDK_FILE}" - ) - return 1 + if core_mapping != sdk_mapping: + console.print(f"[red]{CLASS_NAME}.{ATTR_NAME} is out of sync between core and the Task SDK.[/red]\n") + all_windows = sorted(core_mapping.keys() | sdk_mapping.keys()) + for window in all_windows: + core_val = core_mapping.get(window, "") + sdk_val = sdk_mapping.get(window, "") + marker = " " if core_val == sdk_val else "->" + color = "" if core_val == sdk_val else "[red]" + end = "" if core_val == sdk_val else "[/red]" + console.print(f"{color}{marker} {window}: core={core_val} sdk={sdk_val}{end}") + console.print( + f"\nMake both tables list the same window -> mapper entries:\n core: {CORE_FILE}\n sdk: {SDK_FILE}" + ) + failed = True + + checks = [ + ("SegmentWindow", CORE_WINDOW_FILE, SDK_WINDOW_FILE), + ("FixedKeyMapper", CORE_FIXED_KEY_FILE, SDK_FIXED_KEY_FILE), + ] + + for class_name, core_file, sdk_file in checks: + core_fields = extract_class_field_names(core_file, class_name) + sdk_fields = extract_class_field_names(sdk_file, class_name) + if core_fields != sdk_fields: + console.print(f"[red]{class_name}: field names are out of sync between core and SDK.[/red]") + core_only = core_fields - sdk_fields + sdk_only = sdk_fields - core_fields + if core_only: + console.print(f" core-only fields: {sorted(core_only)}") + if sdk_only: + console.print(f" sdk-only fields: {sorted(sdk_only)}") + console.print(f" core: {core_file}\n sdk: {sdk_file}") + failed = True + + core_msgs = extract_class_error_messages(core_file, class_name) + sdk_msgs = extract_class_error_messages(sdk_file, class_name) + if core_msgs != sdk_msgs: + console.print( + f"[red]{class_name}: raise ValueError(...) message templates are out of sync.[/red]" + ) + core_only = core_msgs - sdk_msgs + sdk_only = sdk_msgs - core_msgs + if core_only: + console.print(f" core-only messages: {sorted(core_only)}") + if sdk_only: + console.print(f" sdk-only messages: {sorted(sdk_only)}") + console.print(f" core: {core_file}\n sdk: {sdk_file}") + failed = True + + return 1 if failed else 0 if __name__ == "__main__": diff --git a/scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py b/scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py new file mode 100644 index 0000000000000..aa1742d33a0a6 --- /dev/null +++ b/scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py @@ -0,0 +1,388 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import textwrap +from pathlib import Path + +import pytest +from check_partition_mapper_defaults_in_sync import ( + extract_class_error_messages, + extract_class_field_names, +) + + +class TestExtractClassErrorMessages: + @pytest.mark.parametrize( + ("raise_stmt", "expected"), + [ + pytest.param('raise ValueError("msg a")', {"msg a"}, id="plain-string"), + pytest.param('raise ValueError(f"msg b at {i}")', {"msg b at {}"}, id="fstring-template"), + ], + ) + def test_extracts_message(self, tmp_path: Path, raise_stmt: str, expected: set[str]): + """Plain strings are kept verbatim; f-string interpolations become {} placeholders.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent(f"""\ + class MyClass: + def validate(self, x, i): + {raise_stmt} + """) + ) + assert extract_class_error_messages(f, "MyClass") == expected + + def test_extracts_both_plain_and_fstring(self, tmp_path: Path): + """Covers the explicit test requirement: plain + f-string together.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + class MyClass: + def check(self, item, i): + raise ValueError("msg a") + + def check2(self, item, i): + raise ValueError(f"msg b at {i}") + """) + ) + result = extract_class_error_messages(f, "MyClass") + assert result == {"msg a", "msg b at {}"} + + def test_ignores_other_exception_types(self, tmp_path: Path): + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + class MyClass: + def check(self): + raise TypeError("not a value error") + """) + ) + result = extract_class_error_messages(f, "MyClass") + assert result == set() + + def test_returns_empty_for_missing_class(self, tmp_path: Path): + f = tmp_path / "code.py" + f.write_text("x = 1\n") + result = extract_class_error_messages(f, "Missing") + assert result == set() + + def test_nested_function_messages_included(self, tmp_path: Path): + """Messages raised inside nested helpers within the class body are collected.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + class MyClass: + def __init__(self, items): + def _check(i, item): + if not isinstance(item, str): + raise ValueError(f"must be str; got {type(item).__name__!r} at {i}") + for i, item in enumerate(items): + _check(i, item) + """) + ) + result = extract_class_error_messages(f, "MyClass") + assert result == {"must be str; got {} at {}"} + + def test_multipart_fstring_message(self, tmp_path: Path): + """f-string with multiple literal parts and multiple interpolations.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + class C: + def v(self, item, i): + raise ValueError( + f"Prefix segment keys must be str; " + f"got {type(item).__name__!r} at index {i}: {item!r}" + ) + """) + ) + result = extract_class_error_messages(f, "C") + # The two adjacent f-strings form a single BinOp(Add) node at AST level; + # the extractor concatenates them into one template. + assert "Prefix segment keys must be str; got {} at index {}: {}" in result + + +class TestExtractClassFieldNames: + def test_extracts_annotated_instance_fields(self, tmp_path: Path): + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + from typing import ClassVar + import attrs + + @attrs.define + class MyClass: + expected_decoded_type: ClassVar[type] = str + _segments: frozenset[str] = attrs.field() + """) + ) + result = extract_class_field_names(f, "MyClass") + # ClassVar is excluded; _segments is included + assert result == {"_segments"} + + def test_excludes_classvar_fields(self, tmp_path: Path): + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + from typing import ClassVar + + class MyClass: + flag: ClassVar[bool] = False + name: str + """) + ) + result = extract_class_field_names(f, "MyClass") + assert result == {"name"} + + def test_returns_empty_for_missing_class(self, tmp_path: Path): + f = tmp_path / "code.py" + f.write_text("x = 1\n") + result = extract_class_field_names(f, "Missing") + assert result == set() + + def test_private_field_name_preserved(self, tmp_path: Path): + """Leading underscore in field name is kept verbatim.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + import attrs + + @attrs.define + class Container: + _segments: frozenset[str] = attrs.field() + """) + ) + result = extract_class_field_names(f, "Container") + assert "_segments" in result + + +class TestInSyncPasses: + def test_in_sync_passes(self, tmp_path: Path): + """Two files with identical field names and messages compare equal.""" + code = textwrap.dedent("""\ + import attrs + + @attrs.define + class MyMapper: + downstream_key: str = attrs.field() + + def validate(self, x): + raise ValueError(f"must be non-empty str; got {x!r}.") + """) + core_file = tmp_path / "core.py" + sdk_file = tmp_path / "sdk.py" + core_file.write_text(code) + sdk_file.write_text(code) + + core_fields = extract_class_field_names(core_file, "MyMapper") + sdk_fields = extract_class_field_names(sdk_file, "MyMapper") + assert core_fields == sdk_fields + + core_msgs = extract_class_error_messages(core_file, "MyMapper") + sdk_msgs = extract_class_error_messages(sdk_file, "MyMapper") + assert core_msgs == sdk_msgs + + +class TestConverterPatternExtraction: + """Tests for the module-level converter= follow-through in extract_class_error_messages.""" + + def test_extracts_messages_from_module_level_converter(self, tmp_path: Path): + """Converter function outside the class body is followed and its messages are collected.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + import attrs + + def _my_convert(items): + for i, item in enumerate(items): + if not isinstance(item, str): + raise ValueError(f"must be str; got {type(item).__name__!r} at {i}") + if not items: + raise ValueError("must not be empty") + return frozenset(items) + + @attrs.define + class MyClass: + _data: frozenset = attrs.field(converter=_my_convert) + """) + ) + result = extract_class_error_messages(f, "MyClass") + assert "must be str; got {} at {}" in result + assert "must not be empty" in result + + def test_converter_fstring_template_placeholders(self, tmp_path: Path): + """f-string expressions in the converter become {} placeholders in the template.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + import attrs + + def _convert(items): + for i, item in enumerate(items): + raise ValueError(f"bad item {item!r} at index {i}") + return frozenset(items) + + @attrs.define + class Widget: + _items: frozenset = attrs.field(converter=_convert) + """) + ) + result = extract_class_error_messages(f, "Widget") + assert "bad item {} at index {}" in result + + def test_converter_messages_not_collected_for_unrelated_class(self, tmp_path: Path): + """The converter is only followed for the class that references it.""" + f = tmp_path / "code.py" + f.write_text( + textwrap.dedent("""\ + import attrs + + def _my_convert(items): + raise ValueError("converter error") + return frozenset(items) + + @attrs.define + class ClassA: + _data: frozenset = attrs.field(converter=_my_convert) + + @attrs.define + class ClassB: + _name: str = attrs.field() + """) + ) + # ClassA references the converter — should see the message + result_a = extract_class_error_messages(f, "ClassA") + assert "converter error" in result_a + + # ClassB does not reference the converter — should NOT see the message + result_b = extract_class_error_messages(f, "ClassB") + assert "converter error" not in result_b + + +class TestConverterDivergenceDetected: + """Verify that message drift in a module-level converter is caught by the extractor.""" + + def test_divergent_converter_message_detected(self, tmp_path: Path): + """Changing 'non-empty' to 'non-empty strings' in a converter is detected as drift.""" + core_code = textwrap.dedent("""\ + import attrs + + def _convert_segments(segments): + for i, item in enumerate(segments): + if not item: + raise ValueError(f"keys must be non-empty; got empty at {i}.") + return frozenset(segments) + + @attrs.define + class SegmentWindow: + _segments: frozenset = attrs.field(converter=_convert_segments) + """) + sdk_code_diverged = textwrap.dedent("""\ + import attrs + + def _convert_segments(segments): + for i, item in enumerate(segments): + if not item: + raise ValueError(f"keys must be non-empty strings; got empty at {i}.") + return frozenset(segments) + + @attrs.define + class SegmentWindow: + _segments: frozenset = attrs.field(converter=_convert_segments) + """) + core_file = tmp_path / "core.py" + sdk_file = tmp_path / "sdk.py" + core_file.write_text(core_code) + sdk_file.write_text(sdk_code_diverged) + + core_msgs = extract_class_error_messages(core_file, "SegmentWindow") + sdk_msgs = extract_class_error_messages(sdk_file, "SegmentWindow") + assert core_msgs != sdk_msgs + + +class TestDivergentMessageFails: + def test_divergent_message_detected(self, tmp_path: Path): + """Changing 'non-empty' to 'non-empty strings' on one side is detected.""" + core_code = textwrap.dedent("""\ + import attrs + + @attrs.define + class SegmentWindow: + _segments: frozenset[str] = attrs.field() + + def validate(self, item, i): + raise ValueError(f"keys must be non-empty; got empty at {i}.") + """) + sdk_code_diverged = textwrap.dedent("""\ + import attrs + + @attrs.define + class SegmentWindow: + _segments: frozenset[str] = attrs.field() + + def validate(self, item, i): + raise ValueError(f"keys must be non-empty strings; got empty at {i}.") + """) + core_file = tmp_path / "core.py" + sdk_file = tmp_path / "sdk.py" + core_file.write_text(core_code) + sdk_file.write_text(sdk_code_diverged) + + core_msgs = extract_class_error_messages(core_file, "SegmentWindow") + sdk_msgs = extract_class_error_messages(sdk_file, "SegmentWindow") + assert core_msgs != sdk_msgs + + @pytest.mark.parametrize( + ("core_msg", "sdk_msg"), + [ + pytest.param( + "keys must be non-empty; got empty at {}.", + "keys must be non-empty strings; got empty at {}.", + id="non-empty-vs-non-empty-strings", + ), + pytest.param( + "requires at least one key; got an empty iterable.", + "requires at least one key.", + id="different-constant-wording", + ), + ], + ) + def test_parametrized_divergence(self, tmp_path: Path, core_msg: str, sdk_msg: str): + def _make_file(path: Path, msg: str) -> None: + # Use a plain f-string if the message contains '{}', else a constant. + if "{}" in msg: + # Reconstruct as f-string source: replace {} with {i} + src_msg = msg.replace("{}", "{i}") + stmt = f'raise ValueError(f"{src_msg}")' + else: + stmt = f'raise ValueError("{msg}")' + path.write_text( + textwrap.dedent(f"""\ + class C: + def v(self, i): + {stmt} + """) + ) + + core_file = tmp_path / "core.py" + sdk_file = tmp_path / "sdk.py" + _make_file(core_file, core_msg) + _make_file(sdk_file, sdk_msg) + + core_msgs = extract_class_error_messages(core_file, "C") + sdk_msgs = extract_class_error_messages(sdk_file, "C") + assert core_msgs != sdk_msgs diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index 8d87897898057..cd9cec283c552 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -247,6 +247,8 @@ Partition Mapper .. autoapiclass:: airflow.sdk.FanOutMapper +.. autoapiclass:: airflow.sdk.FixedKeyMapper + Rollup Windows ~~~~~~~~~~~~~~ @@ -264,6 +266,8 @@ Rollup Windows .. autoapiclass:: airflow.sdk.YearWindow +.. autoapiclass:: airflow.sdk.SegmentWindow + I/O Helpers ----------- .. autoapiclass:: airflow.sdk.ObjectStoragePath diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index ce5b8ab2c2ec1..00d724cb42d5b 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -55,6 +55,7 @@ "EventsTimetable", "ExceptionRetryPolicy", "FanOutMapper", + "FixedKeyMapper", "HourWindow", "IdentityMapper", "Label", @@ -77,6 +78,7 @@ "RetryPolicy", "RetryRule", "RollupMapper", + "SegmentWindow", "SkipMixin", "SyncCallback", "StartOfDayMapper", @@ -154,6 +156,7 @@ from airflow.sdk.definitions.partition_mappers.allowed_key import AllowedKeyMapper from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper from airflow.sdk.definitions.partition_mappers.chain import ChainMapper + from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper from airflow.sdk.definitions.partition_mappers.temporal import ( @@ -170,6 +173,7 @@ HourWindow, MonthWindow, QuarterWindow, + SegmentWindow, WeekWindow, Window, YearWindow, @@ -244,6 +248,7 @@ "EventsTimetable": ".definitions.timetables.events", "ExceptionRetryPolicy": ".definitions.retry_policy", "FanOutMapper": ".definitions.partition_mappers.temporal", + "FixedKeyMapper": ".definitions.partition_mappers.fixed_key", "HourWindow": ".definitions.partition_mappers.window", "IdentityMapper": ".definitions.partition_mappers.identity", "Label": ".definitions.edges", @@ -266,6 +271,7 @@ "RetryRule": ".definitions.retry_policy", "RollupMapper": ".definitions.partition_mappers.base", "SecretCache": ".execution_time.cache", + "SegmentWindow": ".definitions.partition_mappers.window", "SkipMixin": ".bases.skipmixin", "SyncCallback": ".definitions.callback", "StartOfDayMapper": ".definitions.partition_mappers.temporal", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index 1bb975e2202dd..d6fa8bc2a0d1f 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -67,6 +67,7 @@ from airflow.sdk.definitions.param import Param as Param from airflow.sdk.definitions.partition_mappers.allowed_key import AllowedKeyMapper from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper from airflow.sdk.definitions.partition_mappers.chain import ChainMapper +from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper from airflow.sdk.definitions.partition_mappers.temporal import ( @@ -83,6 +84,7 @@ from airflow.sdk.definitions.partition_mappers.window import ( HourWindow, MonthWindow, QuarterWindow, + SegmentWindow, WeekWindow, Window, YearWindow, @@ -154,6 +156,7 @@ __all__ = [ "EventsTimetable", "ExceptionRetryPolicy", "FanOutMapper", + "FixedKeyMapper", "HourWindow", "IdentityMapper", "Label", @@ -175,6 +178,7 @@ __all__ = [ "ResumableJobMixin", "RollupMapper", "SecretCache", + "SegmentWindow", "SkipMixin", "StartOfDayMapper", "StartOfHourMapper", diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py new file mode 100644 index 0000000000000..bf2b07905a818 --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import attrs + +from airflow.sdk.definitions.partition_mappers.base import PartitionMapper + + +@attrs.define +class FixedKeyMapper(PartitionMapper): + """ + Collapse every upstream partition key onto one fixed downstream key. + + Authoring marker for the scheduler-side + :class:`airflow.partition_mappers.fixed_key.FixedKeyMapper`. Paired with + :class:`~airflow.sdk.definitions.partition_mappers.window.SegmentWindow` inside a + :class:`~airflow.sdk.definitions.partition_mappers.base.RollupMapper` to express a + categorical rollup. + + Construction validates *downstream_key* so Dag parse errors surface + immediately rather than deferring to scheduler deserialization. + + :param downstream_key: The fixed downstream partition key every upstream key + maps to. Must be a non-empty string. + :raises ValueError: if *downstream_key* is not a non-empty ``str``. + """ + + downstream_key: str = attrs.field() + + @downstream_key.validator + def _validate_downstream_key(self, attribute: attrs.Attribute, value: str) -> None: + if not isinstance(value, str) or value == "": + raise ValueError(f"FixedKeyMapper downstream_key must be a non-empty str; got {value!r}.") + + def to_downstream(self, key: str) -> str: + """Return the fixed downstream key regardless of *key*.""" + return self.downstream_key diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py index 4ea008ace6f79..f491d4279ec7a 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py @@ -22,7 +22,12 @@ from datetime import datetime from enum import Enum -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar + +import attrs + +if TYPE_CHECKING: + from collections.abc import Iterable class Window: @@ -114,3 +119,52 @@ class YearWindow(Window): """Twelve consecutive monthly keys making up one calendar year.""" expected_decoded_type: ClassVar[type] = datetime + + +def _convert_segments(segments: Iterable[str]) -> frozenset[str]: + """ + Validate and convert *segments* to a ``frozenset[str]``. + + Validates each element for type and non-emptiness (with index reporting) + before collapsing into a frozenset, then checks the result is non-empty. + """ + validated: list[str] = [] + for i, item in enumerate(segments): + if not isinstance(item, str): + raise ValueError( + f"SegmentWindow segment keys must be str; got {type(item).__name__!r} at index {i}: {item!r}" + ) + if not item: + raise ValueError( + f"SegmentWindow segment keys must be non-empty; got an empty string at index {i}." + ) + validated.append(item) + result = frozenset(validated) + if not result: + raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.") + return result + + +@attrs.define +class SegmentWindow(Window): + """ + A fixed categorical set of string keys that constitute one downstream period. + + Authoring marker for the scheduler-side + :class:`airflow.partition_mappers.window.SegmentWindow`. Paired with + :class:`~airflow.sdk.definitions.partition_mappers.fixed_key.FixedKeyMapper` inside a + :class:`~airflow.sdk.definitions.partition_mappers.base.RollupMapper` to express a + categorical rollup. + + Construction validates the segment list so Dag parse errors surface + immediately rather than deferring to scheduler deserialization. + + :param segments: Non-empty iterable of non-empty string segment keys. Duplicates + are silently de-duplicated. + :raises ValueError: if *segments* is empty, contains a non-``str`` element, or + contains an empty-string element. + """ + + expected_decoded_type: ClassVar[type] = str + + _segments: frozenset[str] = attrs.field(converter=_convert_segments) diff --git a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py index 8b902820da0d6..7114b65d6c4ba 100644 --- a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py +++ b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py @@ -22,12 +22,14 @@ import pytest from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper +from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper from airflow.sdk.definitions.partition_mappers.temporal import StartOfDayMapper from airflow.sdk.definitions.partition_mappers.window import ( DayWindow, HourWindow, MonthWindow, QuarterWindow, + SegmentWindow, WeekWindow, Window, YearWindow, @@ -103,3 +105,67 @@ class TestSdkWindowExpectedDecodedType: ) def test_temporal_windows_declare_datetime(self, window_cls): assert window_cls.expected_decoded_type is datetime + + +class TestSdkFixedKeyMapper: + """SDK-side FixedKeyMapper construction and validation.""" + + @pytest.mark.parametrize("key", ["us", "eu", "apac"]) + def test_to_downstream_returns_constant_for_any_key(self, key): + assert FixedKeyMapper("all_regions").to_downstream(key) == "all_regions" + + def test_is_rollup_false(self): + assert FixedKeyMapper("all").is_rollup is False + + @pytest.mark.parametrize( + ("downstream_key", "match"), + [ + pytest.param("", "non-empty str", id="empty-string"), + pytest.param(None, "non-empty str", id="none"), + pytest.param(1, "non-empty str", id="int"), + ], + ) + def test_rejects_invalid_downstream_key(self, downstream_key, match): + with pytest.raises(ValueError, match=match): + FixedKeyMapper(downstream_key) + + def test_requires_downstream_key(self): + with pytest.raises(TypeError): + FixedKeyMapper() + + +class TestSdkSegmentWindow: + """SDK-side SegmentWindow construction and validation mirrors the core implementation.""" + + def test_expected_decoded_type_is_str(self): + assert SegmentWindow.expected_decoded_type is str + + def test_deduplication(self): + w = SegmentWindow(["a", "b", "a"]) + assert w._segments == frozenset({"a", "b"}) + + @pytest.mark.parametrize( + ("segments", "match"), + [ + pytest.param([], "at least one segment key", id="empty-list"), + pytest.param([1, "b"], "must be str", id="int-element"), + pytest.param(["", "b"], "non-empty", id="empty-string"), + ], + ) + def test_rejects_invalid_segments(self, segments, match): + with pytest.raises(ValueError, match=match): + SegmentWindow(segments) + + +class TestSdkCategoricalRollupGuard: + """SDK-side RollupMapper guard mirrors core: str mapper + str window passes.""" + + def test_fixed_key_with_segment_window_does_not_raise(self): + # SDK guard: FixedKeyMapper.expected_decoded_type is str, + # SegmentWindow.expected_decoded_type is str -> guard passes. + RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=SegmentWindow(["us", "eu"])) + + def test_str_mapper_with_datetime_window_raises(self): + # SDK guard: FixedKeyMapper (str) + DayWindow (datetime) -> raise. + with pytest.raises(TypeError, match="DayWindow expects decoded values of type 'datetime'"): + RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=DayWindow())