From 156370126af8acded9d362d33876d4b1e0d62ffd Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 6 May 2026 17:46:08 -0400 Subject: [PATCH 01/24] feat(_transforms): create private transforms package skeleton Empty package, no exports yet. Subsequent commits port modules from PR #3906's transforms package and rewire imports to the private name. --- src/zarr/core/_transforms/__init__.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 src/zarr/core/_transforms/__init__.py diff --git a/src/zarr/core/_transforms/__init__.py b/src/zarr/core/_transforms/__init__.py new file mode 100644 index 0000000000..e29a0ccf9b --- /dev/null +++ b/src/zarr/core/_transforms/__init__.py @@ -0,0 +1,19 @@ +"""Composable, lazy coordinate transforms for zarr array indexing. + +This package implements TensorStore-inspired index transforms. The core idea: +every indexing operation (slicing, fancy indexing, etc.) produces a coordinate +mapping from user space to storage space. These mappings compose lazily - no +I/O until you explicitly read or write. + +Private package: this module is not part of the public zarr API. The leading +underscore in the package name signals this. Importers outside this package +must be limited to other private zarr modules. + +Key types: + +- ``IndexDomain`` -- a rectangular region of integer coordinates +- ``IndexTransform`` -- maps input coordinates to storage coordinates +- ``ConstantMap``, ``DimensionMap``, ``ArrayMap`` -- the three ways a single + output dimension can depend on the input (see ``output_map.py``) +- ``compose`` -- chain two transforms into one +""" From 273eed9bd27bada4a6a06e6429a2de697f49b5fc Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 6 May 2026 17:48:40 -0400 Subject: [PATCH 02/24] feat(_transforms): port output_map module Defines ConstantMap, DimensionMap, ArrayMap, and the OutputIndexMap union. No in-package dependencies. 
--- src/zarr/core/_transforms/output_map.py | 83 ++++++++++++++++++++++++ tests/test_transforms/__init__.py | 0 tests/test_transforms/test_output_map.py | 56 ++++++++++++++++ 3 files changed, 139 insertions(+) create mode 100644 src/zarr/core/_transforms/output_map.py create mode 100644 tests/test_transforms/__init__.py create mode 100644 tests/test_transforms/test_output_map.py diff --git a/src/zarr/core/_transforms/output_map.py b/src/zarr/core/_transforms/output_map.py new file mode 100644 index 0000000000..5e17a0ae82 --- /dev/null +++ b/src/zarr/core/_transforms/output_map.py @@ -0,0 +1,83 @@ +"""Output index maps — three representations of a set of integer coordinates. + +An output index map describes, for one dimension of storage, which coordinates +an array access will touch. Conceptually it is a **set of integers**. Three +representations cover the cases that arise in practice: + +- ``ConstantMap(offset=5)`` — a singleton set: ``{5}`` +- ``DimensionMap(input_dimension=0, offset=3, stride=2)`` over input ``[0, 5)`` + — an arithmetic progression: ``{3, 5, 7, 9, 11}`` +- ``ArrayMap(index_array=[1, 5, 9])`` — an explicit enumeration: ``{1, 5, 9}`` + +Every output map supports two set-theoretic operations (defined on +``IndexTransform``, which provides the input domain context these maps lack): + +- **intersect** — restrict to coordinates within a range (e.g., a chunk). + ``{3, 5, 7, 9, 11} ∩ [4, 8) = {5, 7}`` +- **translate** — shift every coordinate by a constant (e.g., make chunk-local). + ``{5, 7} - 4 = {1, 3}`` + +These two operations are the foundation of chunk resolution: for each chunk, +intersect the map with the chunk's range, then translate to chunk-local +coordinates. 
+ +The three types exist because they trade off generality for efficiency: + +- ``ConstantMap``: O(1) storage, O(1) intersection +- ``DimensionMap``: O(1) storage, O(1) intersection (analytical) +- ``ArrayMap``: O(n) storage, O(n) intersection (must scan the array) + +Collapsing everything to ``ArrayMap`` would be correct but wasteful — a +billion-element slice would materialize a billion coordinates just to group +them by chunk, when ``DimensionMap`` does it with three integers. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np + import numpy.typing as npt + + +@dataclass(frozen=True, slots=True) +class ConstantMap: + """A singleton set: one storage coordinate. + + Represents ``{offset}``. Arises from integer indexing (e.g., ``arr[5]`` + fixes one dimension to coordinate 5). + """ + + offset: int = 0 + + +@dataclass(frozen=True, slots=True) +class DimensionMap: + """An arithmetic progression of storage coordinates. + + Represents ``{offset + stride * i : i in input_range}``, where the input + range comes from the enclosing ``IndexTransform``'s domain. Arises from + slice indexing (e.g., ``arr[2:10:3]`` gives offset=2, stride=3). + """ + + input_dimension: int + offset: int = 0 + stride: int = 1 + + +@dataclass(frozen=True, slots=True) +class ArrayMap: + """An explicit enumeration of storage coordinates. + + Represents ``{offset + stride * index_array[i] : i in input_range}``. + Arises from fancy indexing (e.g., ``arr[[1, 5, 9]]`` or boolean masks). 
+ """ + + index_array: npt.NDArray[np.intp] + offset: int = 0 + stride: int = 1 + + +OutputIndexMap = ConstantMap | DimensionMap | ArrayMap diff --git a/tests/test_transforms/__init__.py b/tests/test_transforms/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_transforms/test_output_map.py b/tests/test_transforms/test_output_map.py new file mode 100644 index 0000000000..358ea6ed6c --- /dev/null +++ b/tests/test_transforms/test_output_map.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import numpy as np + +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap + + +class TestConstantMap: + def test_construction(self) -> None: + m = ConstantMap(offset=42) + assert m.offset == 42 + + def test_default_offset(self) -> None: + m = ConstantMap() + assert m.offset == 0 + + def test_frozen(self) -> None: + m = ConstantMap(offset=5) + assert isinstance(m, ConstantMap) + + +class TestDimensionMap: + def test_construction(self) -> None: + m = DimensionMap(input_dimension=3, offset=5, stride=2) + assert m.input_dimension == 3 + assert m.offset == 5 + assert m.stride == 2 + + def test_defaults(self) -> None: + m = DimensionMap(input_dimension=0) + assert m.offset == 0 + assert m.stride == 1 + + def test_frozen(self) -> None: + m = DimensionMap(input_dimension=0) + assert isinstance(m, DimensionMap) + + +class TestArrayMap: + def test_construction(self) -> None: + arr = np.array([1, 3, 5], dtype=np.intp) + m = ArrayMap(index_array=arr, offset=10, stride=2) + assert m.offset == 10 + assert m.stride == 2 + np.testing.assert_array_equal(m.index_array, arr) + + def test_defaults(self) -> None: + arr = np.array([0, 1], dtype=np.intp) + m = ArrayMap(index_array=arr) + assert m.offset == 0 + assert m.stride == 1 + + def test_frozen(self) -> None: + arr = np.array([0], dtype=np.intp) + m = ArrayMap(index_array=arr) + assert isinstance(m, ArrayMap) From 29367e226c0e96b67b34c11d39eaf5b208a4a317 Mon Sep 17 00:00:00 
2001 From: Davis Vann Bennett Date: Wed, 6 May 2026 17:54:33 -0400 Subject: [PATCH 03/24] feat(_transforms): port domain module Defines IndexDomain (a rectangular integer-coordinate region) plus the _normalize_selection helper used by transform construction. --- src/zarr/core/_transforms/domain.py | 178 +++++++++++++++++++++++ tests/test_transforms/test_domain.py | 202 +++++++++++++++++++++++++++ 2 files changed, 380 insertions(+) create mode 100644 src/zarr/core/_transforms/domain.py create mode 100644 tests/test_transforms/test_domain.py diff --git a/src/zarr/core/_transforms/domain.py b/src/zarr/core/_transforms/domain.py new file mode 100644 index 0000000000..90bcc08ace --- /dev/null +++ b/src/zarr/core/_transforms/domain.py @@ -0,0 +1,178 @@ +"""Index domains — rectangular regions in N-dimensional integer space. + +An ``IndexDomain`` represents the set of valid coordinates for an array or +array view. It is the cartesian product of per-dimension integer ranges:: + + IndexDomain(inclusive_min=(2, 5), exclusive_max=(10, 20)) + # represents {(i, j) : 2 <= i < 10, 5 <= j < 20} + +Unlike NumPy, domains can have **non-zero origins**. After slicing +``arr[5:10]``, the result has origin 5 and shape 5 — coordinates 5 through +9 are valid. This follows the TensorStore convention. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class IndexDomain: + """A rectangular region in N-dimensional index space. + + The valid coordinates are the integers in + ``[inclusive_min[d], exclusive_max[d])`` for each dimension ``d``. + """ + + inclusive_min: tuple[int, ...] + exclusive_max: tuple[int, ...] + labels: tuple[str, ...] | None = None + + def __post_init__(self) -> None: + if len(self.inclusive_min) != len(self.exclusive_max): + raise ValueError( + f"inclusive_min and exclusive_max must have the same length. " + f"Got {len(self.inclusive_min)} and {len(self.exclusive_max)}." 
+ ) + for i, (lo, hi) in enumerate(zip(self.inclusive_min, self.exclusive_max, strict=True)): + if lo > hi: + raise ValueError( + f"inclusive_min must be <= exclusive_max for all dimensions. " + f"Dimension {i}: {lo} > {hi}" + ) + if self.labels is not None and len(self.labels) != len(self.inclusive_min): + raise ValueError( + f"labels must have the same length as dimensions. " + f"Got {len(self.labels)} labels for {len(self.inclusive_min)} dimensions." + ) + + @classmethod + def from_shape(cls, shape: tuple[int, ...]) -> IndexDomain: + """Create a domain with origin at zero.""" + return cls( + inclusive_min=(0,) * len(shape), + exclusive_max=shape, + ) + + @property + def ndim(self) -> int: + return len(self.inclusive_min) + + @property + def origin(self) -> tuple[int, ...]: + return self.inclusive_min + + @property + def shape(self) -> tuple[int, ...]: + return tuple(hi - lo for lo, hi in zip(self.inclusive_min, self.exclusive_max, strict=True)) + + def contains(self, index: tuple[int, ...]) -> bool: + if len(index) != self.ndim: + return False + return all( + lo <= idx < hi + for lo, hi, idx in zip(self.inclusive_min, self.exclusive_max, index, strict=True) + ) + + def contains_domain(self, other: IndexDomain) -> bool: + if other.ndim != self.ndim: + return False + return all( + self_lo <= other_lo and other_hi <= self_hi + for self_lo, self_hi, other_lo, other_hi in zip( + self.inclusive_min, + self.exclusive_max, + other.inclusive_min, + other.exclusive_max, + strict=True, + ) + ) + + def intersect(self, other: IndexDomain) -> IndexDomain | None: + if other.ndim != self.ndim: + raise ValueError( + f"Cannot intersect domains with different ranks: {self.ndim} vs {other.ndim}" + ) + new_min = tuple( + max(a, b) for a, b in zip(self.inclusive_min, other.inclusive_min, strict=True) + ) + new_max = tuple( + min(a, b) for a, b in zip(self.exclusive_max, other.exclusive_max, strict=True) + ) + if any(lo >= hi for lo, hi in zip(new_min, new_max, strict=True)): + return 
None + return IndexDomain(inclusive_min=new_min, exclusive_max=new_max) + + def translate(self, offset: tuple[int, ...]) -> IndexDomain: + if len(offset) != self.ndim: + raise ValueError( + f"Offset must have same length as domain dimensions. " + f"Domain has {self.ndim} dimensions, offset has {len(offset)}." + ) + new_min = tuple(lo + off for lo, off in zip(self.inclusive_min, offset, strict=True)) + new_max = tuple(hi + off for hi, off in zip(self.exclusive_max, offset, strict=True)) + return IndexDomain(inclusive_min=new_min, exclusive_max=new_max) + + def narrow(self, selection: Any) -> IndexDomain: + """Apply a basic selection and return a narrowed domain. + Indices are absolute coordinates. Integer indices produce length-1 extent. + Strided slices are not supported — use IndexTransform for strides. + """ + normalized = _normalize_selection(selection, self.ndim) + new_inclusive_min: list[int] = [] + new_exclusive_max: list[int] = [] + for dim_idx, (sel, dim_lo, dim_hi) in enumerate( + zip(normalized, self.inclusive_min, self.exclusive_max, strict=True) + ): + if isinstance(sel, int): + if sel < dim_lo or sel >= dim_hi: + raise IndexError( + f"index {sel} is out of bounds for dimension {dim_idx} " + f"with domain [{dim_lo}, {dim_hi})" + ) + new_inclusive_min.append(sel) + new_exclusive_max.append(sel + 1) + else: + start, stop, step = sel.start, sel.stop, sel.step + if step is not None and step != 1: + raise IndexError( + "IndexDomain.narrow only supports step=1 slices. " + f"Got step={step}. Use IndexTransform for strided access." 
+ ) + abs_start = dim_lo if start is None else start + abs_stop = dim_hi if stop is None else stop + abs_start = max(abs_start, dim_lo) + abs_stop = min(abs_stop, dim_hi) + abs_stop = max(abs_stop, abs_start) + new_inclusive_min.append(abs_start) + new_exclusive_max.append(abs_stop) + return IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + + +def _normalize_selection(selection: Any, ndim: int) -> tuple[int | slice, ...]: + """Normalize a basic selection to a tuple of ints/slices with length ndim.""" + if not isinstance(selection, tuple): + selection = (selection,) + result: list[int | slice] = [] + ellipsis_seen = False + for sel in selection: + if sel is Ellipsis: + if ellipsis_seen: + raise IndexError("an index can only have a single ellipsis ('...')") + ellipsis_seen = True + num_missing = ndim - (len(selection) - 1) + result.extend([slice(None)] * num_missing) + else: + result.append(sel) + while len(result) < ndim: + result.append(slice(None)) + if len(result) > ndim: + raise IndexError( + f"too many indices for array: array has {ndim} dimensions, " + f"but {len(result)} were indexed" + ) + return tuple(result) diff --git a/tests/test_transforms/test_domain.py b/tests/test_transforms/test_domain.py new file mode 100644 index 0000000000..58f3808d95 --- /dev/null +++ b/tests/test_transforms/test_domain.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import pytest + +from zarr.core._transforms.domain import IndexDomain + + +class TestIndexDomainConstruction: + def test_from_shape(self) -> None: + d = IndexDomain.from_shape((10, 20)) + assert d.inclusive_min == (0, 0) + assert d.exclusive_max == (10, 20) + assert d.ndim == 2 + assert d.origin == (0, 0) + assert d.shape == (10, 20) + + def test_from_shape_0d(self) -> None: + d = IndexDomain.from_shape(()) + assert d.ndim == 0 + assert d.shape == () + + def test_non_zero_origin(self) -> None: + d = IndexDomain(inclusive_min=(5, 10), exclusive_max=(15, 
30)) + assert d.origin == (5, 10) + assert d.shape == (10, 20) + assert d.ndim == 2 + + def test_validation_mismatched_lengths(self) -> None: + with pytest.raises(ValueError, match="same length"): + IndexDomain(inclusive_min=(0,), exclusive_max=(10, 20)) + + def test_validation_min_greater_than_max(self) -> None: + with pytest.raises(ValueError, match="inclusive_min must be <="): + IndexDomain(inclusive_min=(10,), exclusive_max=(5,)) + + def test_empty_domain(self) -> None: + d = IndexDomain(inclusive_min=(5,), exclusive_max=(5,)) + assert d.shape == (0,) + + def test_labels(self) -> None: + d = IndexDomain(inclusive_min=(0, 0), exclusive_max=(10, 20), labels=("x", "y")) + assert d.labels == ("x", "y") + + def test_labels_none(self) -> None: + d = IndexDomain.from_shape((10,)) + assert d.labels is None + + +class TestIndexDomainContains: + def test_contains_inside(self) -> None: + d = IndexDomain.from_shape((10, 20)) + assert d.contains((0, 0)) is True + assert d.contains((9, 19)) is True + assert d.contains((5, 10)) is True + + def test_contains_outside(self) -> None: + d = IndexDomain.from_shape((10, 20)) + assert d.contains((10, 0)) is False + assert d.contains((-1, 0)) is False + assert d.contains((0, 20)) is False + + def test_contains_non_zero_origin(self) -> None: + d = IndexDomain(inclusive_min=(5,), exclusive_max=(10,)) + assert d.contains((5,)) is True + assert d.contains((9,)) is True + assert d.contains((4,)) is False + assert d.contains((10,)) is False + + def test_contains_wrong_ndim(self) -> None: + d = IndexDomain.from_shape((10, 20)) + assert d.contains((5,)) is False + + def test_contains_domain_inside(self) -> None: + outer = IndexDomain.from_shape((10, 20)) + inner = IndexDomain(inclusive_min=(2, 3), exclusive_max=(8, 15)) + assert outer.contains_domain(inner) is True + + def test_contains_domain_outside(self) -> None: + outer = IndexDomain.from_shape((10, 20)) + inner = IndexDomain(inclusive_min=(2, 3), exclusive_max=(11, 15)) + assert 
outer.contains_domain(inner) is False + + def test_contains_domain_wrong_ndim(self) -> None: + outer = IndexDomain.from_shape((10, 20)) + inner = IndexDomain.from_shape((5,)) + assert outer.contains_domain(inner) is False + + +class TestIndexDomainIntersect: + def test_overlapping(self) -> None: + a = IndexDomain(inclusive_min=(0, 0), exclusive_max=(10, 10)) + b = IndexDomain(inclusive_min=(5, 5), exclusive_max=(15, 15)) + result = a.intersect(b) + assert result is not None + assert result.inclusive_min == (5, 5) + assert result.exclusive_max == (10, 10) + + def test_disjoint(self) -> None: + a = IndexDomain(inclusive_min=(0,), exclusive_max=(5,)) + b = IndexDomain(inclusive_min=(10,), exclusive_max=(15,)) + assert a.intersect(b) is None + + def test_touching_boundary(self) -> None: + a = IndexDomain(inclusive_min=(0,), exclusive_max=(5,)) + b = IndexDomain(inclusive_min=(5,), exclusive_max=(10,)) + assert a.intersect(b) is None + + def test_contained(self) -> None: + a = IndexDomain.from_shape((20,)) + b = IndexDomain(inclusive_min=(5,), exclusive_max=(10,)) + result = a.intersect(b) + assert result is not None + assert result.inclusive_min == (5,) + assert result.exclusive_max == (10,) + + def test_wrong_ndim(self) -> None: + a = IndexDomain.from_shape((10,)) + b = IndexDomain.from_shape((10, 20)) + with pytest.raises(ValueError, match="different ranks"): + a.intersect(b) + + +class TestIndexDomainTranslate: + def test_translate_positive(self) -> None: + d = IndexDomain.from_shape((10, 20)) + result = d.translate((5, 10)) + assert result.inclusive_min == (5, 10) + assert result.exclusive_max == (15, 30) + + def test_translate_negative(self) -> None: + d = IndexDomain(inclusive_min=(10, 20), exclusive_max=(30, 40)) + result = d.translate((-10, -20)) + assert result.inclusive_min == (0, 0) + assert result.exclusive_max == (20, 20) + + def test_translate_wrong_length(self) -> None: + d = IndexDomain.from_shape((10,)) + with pytest.raises(ValueError, match="same 
length"): + d.translate((1, 2)) + + +class TestIndexDomainNarrow: + def test_narrow_slice(self) -> None: + d = IndexDomain.from_shape((10, 20)) + result = d.narrow((slice(2, 8), slice(5, 15))) + assert result.inclusive_min == (2, 5) + assert result.exclusive_max == (8, 15) + + def test_narrow_int(self) -> None: + d = IndexDomain.from_shape((10, 20)) + result = d.narrow((3, slice(None))) + assert result.inclusive_min == (3, 0) + assert result.exclusive_max == (4, 20) + + def test_narrow_ellipsis(self) -> None: + d = IndexDomain.from_shape((10, 20, 30)) + result = d.narrow((slice(1, 5), ...)) + assert result.inclusive_min == (1, 0, 0) + assert result.exclusive_max == (5, 20, 30) + + def test_narrow_slice_none(self) -> None: + d = IndexDomain.from_shape((10,)) + result = d.narrow((slice(None),)) + assert result == d + + def test_narrow_non_zero_origin(self) -> None: + d = IndexDomain(inclusive_min=(10,), exclusive_max=(20,)) + result = d.narrow((slice(12, 18),)) + assert result.inclusive_min == (12,) + assert result.exclusive_max == (18,) + + def test_narrow_int_out_of_bounds(self) -> None: + d = IndexDomain.from_shape((10,)) + with pytest.raises(IndexError, match="out of bounds"): + d.narrow((10,)) + + def test_narrow_int_below_origin(self) -> None: + d = IndexDomain(inclusive_min=(5,), exclusive_max=(10,)) + with pytest.raises(IndexError, match="out of bounds"): + d.narrow((4,)) + + def test_narrow_clamps_to_domain(self) -> None: + d = IndexDomain.from_shape((10,)) + result = d.narrow((slice(-5, 100),)) + assert result.inclusive_min == (0,) + assert result.exclusive_max == (10,) + + def test_narrow_bare_slice(self) -> None: + d = IndexDomain.from_shape((10,)) + result = d.narrow(slice(2, 8)) + assert result.inclusive_min == (2,) + assert result.exclusive_max == (8,) + + def test_narrow_too_many_indices(self) -> None: + d = IndexDomain.from_shape((10,)) + with pytest.raises(IndexError, match="too many indices"): + d.narrow((1, 2)) + + def 
test_narrow_step_not_one(self) -> None: + d = IndexDomain.from_shape((10,)) + with pytest.raises(IndexError, match="step=1"): + d.narrow((slice(0, 10, 2),)) From 2e7b6466c94ee5382f3f1644170e71a12b7633c9 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 6 May 2026 18:02:17 -0400 Subject: [PATCH 04/24] feat(_transforms): port transform module Defines IndexTransform plus selection_to_transform and the per-mode application functions (_apply_basic_indexing, _apply_oindex, _apply_vindex). This is the core of the package. --- src/zarr/core/_transforms/transform.py | 932 ++++++++++++++++++++++++ tests/test_transforms/test_transform.py | 516 +++++++++++++ 2 files changed, 1448 insertions(+) create mode 100644 src/zarr/core/_transforms/transform.py create mode 100644 tests/test_transforms/test_transform.py diff --git a/src/zarr/core/_transforms/transform.py b/src/zarr/core/_transforms/transform.py new file mode 100644 index 0000000000..d8dbf11b52 --- /dev/null +++ b/src/zarr/core/_transforms/transform.py @@ -0,0 +1,932 @@ +"""Index transforms — composable, lazy coordinate mappings. + +An ``IndexTransform`` pairs an **input domain** (the coordinates a user sees) +with a tuple of **output maps** (the storage coordinates those inputs map to). +One output map per storage dimension. See ``output_map.py`` for the three +output map types. + +Key operations: + +- **Indexing** (``transform[2:8]``, ``.oindex[idx]``, ``.vindex[idx]``) — + produces a new transform with a narrower input domain and adjusted output + maps. No I/O occurs. This is how lazy slicing works. + +- **intersect(output_domain)** — restrict to storage coordinates within a + region. This is chunk resolution: "which of my coordinates fall in this + chunk?" + +- **translate(shift)** — shift all output coordinates. This makes coordinates + chunk-local: "express my coordinates relative to the chunk origin." + +- **compose(outer, inner)** — chain two transforms. See ``composition.py``. 
+ +The transform is the atomic unit that connects user-facing indexing to +chunk-level I/O. Every ``Array`` holds a transform (identity by default). +``Array.z[...]`` composes a new transform lazily. Reading resolves the +transform against the chunk grid via intersect + translate. +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any, Literal + +import numpy as np + +from zarr.core._transforms.domain import IndexDomain +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap + + +@dataclass(frozen=True, slots=True) +class IndexTransform: + """A composable mapping from input coordinates to storage coordinates. + + An ``IndexTransform`` has: + + - ``domain``: an ``IndexDomain`` describing the valid input coordinates + (the user-facing shape, possibly with non-zero origin). + - ``output``: a tuple of output maps (one per storage dimension), each + describing which storage coordinates the inputs touch. + + For a freshly opened array, the transform is the identity: input + coordinate ``i`` maps to storage coordinate ``i``. Indexing operations + compose new transforms without I/O. + """ + + domain: IndexDomain + output: tuple[OutputIndexMap, ...] 
+ + def __post_init__(self) -> None: + for i, m in enumerate(self.output): + if isinstance(m, DimensionMap): + if m.input_dimension < 0 or m.input_dimension >= self.domain.ndim: + raise ValueError( + f"output[{i}].input_dimension = {m.input_dimension} " + f"is out of range for input rank {self.domain.ndim}" + ) + elif isinstance(m, ArrayMap) and m.index_array.ndim > self.domain.ndim: + raise ValueError( + f"output[{i}].index_array has {m.index_array.ndim} dims " + f"but input domain has {self.domain.ndim} dims" + ) + + @property + def input_rank(self) -> int: + return self.domain.ndim + + @property + def output_rank(self) -> int: + return len(self.output) + + @classmethod + def identity(cls, domain: IndexDomain) -> IndexTransform: + output = tuple(DimensionMap(input_dimension=i) for i in range(domain.ndim)) + return cls(domain=domain, output=output) + + @classmethod + def from_shape(cls, shape: tuple[int, ...]) -> IndexTransform: + return cls.identity(IndexDomain.from_shape(shape)) + + @property + def selection_repr(self) -> str: + """Compact domain string, e.g. ``'{ [2, 8), [0, 10) }'``. + + Follows TensorStore's IndexDomain notation: each dimension shown + as ``[inclusive_min, exclusive_max)`` with stride annotation if not 1. + Constant (integer-indexed) dimensions show as a single value. + Array-indexed dimensions show the set of selected coordinates. 
+ """ + parts: list[str] = [] + for m in self.output: + if isinstance(m, ConstantMap): + parts.append(str(m.offset)) + elif isinstance(m, DimensionMap): + d = m.input_dimension + lo = self.domain.inclusive_min[d] + hi = self.domain.exclusive_max[d] + start = m.offset + m.stride * lo + stop = m.offset + m.stride * hi + if m.stride == 1: + parts.append(f"[{start}, {stop})") + else: + parts.append(f"[{start}, {stop}) step {m.stride}") + elif isinstance(m, ArrayMap): + storage = m.offset + m.stride * m.index_array + n = len(storage) + if n <= 5: + vals = ", ".join(str(int(v)) for v in storage.ravel()) + parts.append("{" + vals + "}") + else: + parts.append("{" + f"array({n})" + "}") + return "{ " + ", ".join(parts) + " }" + + def __repr__(self) -> str: + maps: list[str] = [] + for i, m in enumerate(self.output): + if isinstance(m, ConstantMap): + maps.append(f"out[{i}] = {m.offset}") + elif isinstance(m, DimensionMap): + maps.append(f"out[{i}] = {m.offset} + {m.stride} * in[{m.input_dimension}]") + elif isinstance(m, ArrayMap): + maps.append(f"out[{i}] = {m.offset} + {m.stride} * arr{m.index_array.shape}[in]") + maps_str = ", ".join(maps) + return f"IndexTransform(domain={self.domain}, {maps_str})" + + def intersect( + self, output_domain: IndexDomain + ) -> tuple[IndexTransform, np.ndarray[Any, np.dtype[np.intp]] | None] | None: + """Restrict this transform to storage coordinates within output_domain. + + Returns ``(restricted_transform, surviving_indices)`` or None if empty. + + ``surviving_indices`` is an integer array of which input positions + survived the intersection (for ArrayMap dimensions), or None if all + positions survived (ConstantMap/DimensionMap only). 
+ """ + return _intersect(self, output_domain) + + def translate(self, shift: tuple[int, ...]) -> IndexTransform: + """Shift all output coordinates by ``shift``.""" + if len(shift) != self.output_rank: + raise ValueError(f"shift must have length {self.output_rank}, got {len(shift)}") + new_output: list[OutputIndexMap] = [] + for m, s in zip(self.output, shift, strict=True): + if isinstance(m, ConstantMap): + new_output.append(ConstantMap(offset=m.offset + s)) + elif isinstance(m, DimensionMap): + new_output.append( + DimensionMap( + input_dimension=m.input_dimension, + offset=m.offset + s, + stride=m.stride, + ) + ) + elif isinstance(m, ArrayMap): + new_output.append( + ArrayMap( + index_array=m.index_array, + offset=m.offset + s, + stride=m.stride, + ) + ) + return IndexTransform(domain=self.domain, output=tuple(new_output)) + + def __getitem__(self, selection: Any) -> IndexTransform: + return _apply_basic_indexing(self, selection) + + @property + def oindex(self) -> _OIndexHelper: + return _OIndexHelper(self) + + @property + def vindex(self) -> _VIndexHelper: + return _VIndexHelper(self) + + +def _intersect( + transform: IndexTransform, output_domain: IndexDomain +) -> tuple[IndexTransform, np.ndarray[Any, np.dtype[np.intp]] | None] | None: + """Intersect a transform with an output domain (e.g., a chunk's bounds). + + For each output dimension, restrict to storage coordinates within + [output_domain.inclusive_min[d], output_domain.exclusive_max[d]). + + For orthogonal transforms (ConstantMap, DimensionMap, independent ArrayMaps), + each dimension is intersected independently and the input domain is narrowed. + + For vectorized transforms (correlated ArrayMaps), all array dimensions + must be checked simultaneously — a point survives only if ALL its + coordinates fall within the output domain. + + Returns None if the intersection is empty. 
+ """ + if output_domain.ndim != transform.output_rank: + raise ValueError( + f"output_domain rank ({output_domain.ndim}) != " + f"transform output rank ({transform.output_rank})" + ) + + # Check if we have correlated ArrayMaps (vectorized) + array_dims = [i for i, m in enumerate(transform.output) if isinstance(m, ArrayMap)] + if len(array_dims) >= 2: + return _intersect_vectorized(transform, output_domain, array_dims) + + # Orthogonal: intersect each output dimension independently + new_min = list(transform.domain.inclusive_min) + new_max = list(transform.domain.exclusive_max) + new_output: list[OutputIndexMap] = [] + surviving_indices: np.ndarray[Any, np.dtype[np.intp]] | None = None + + for out_dim, m in enumerate(transform.output): + lo = output_domain.inclusive_min[out_dim] + hi = output_domain.exclusive_max[out_dim] + + if isinstance(m, ConstantMap): + if lo <= m.offset < hi: + new_output.append(m) + else: + return None + + elif isinstance(m, DimensionMap): + d = m.input_dimension + input_lo = new_min[d] + input_hi = new_max[d] + if input_lo >= input_hi: + return None + + # Find input range that produces storage coords in [lo, hi) + if m.stride > 0: + new_input_lo = max(input_lo, math.ceil((lo - m.offset) / m.stride)) + new_input_hi = min(input_hi, math.ceil((hi - m.offset) / m.stride)) + elif m.stride < 0: + new_input_lo = max(input_lo, math.ceil((hi - 1 - m.offset) / m.stride)) + new_input_hi = min(input_hi, math.ceil((lo - 1 - m.offset) / m.stride)) + else: + if lo <= m.offset < hi: + new_input_lo, new_input_hi = input_lo, input_hi + else: + return None + + if new_input_lo >= new_input_hi: + return None + + new_min[d] = new_input_lo + new_max[d] = new_input_hi + new_output.append(m) + + elif isinstance(m, ArrayMap): + storage = m.offset + m.stride * m.index_array + mask = (storage >= lo) & (storage < hi) + if not np.any(mask): + return None + surviving_indices = np.nonzero(mask.ravel())[0].astype(np.intp) + filtered = 
m.index_array.ravel()[surviving_indices] + new_output.append( + ArrayMap( + index_array=filtered, + offset=m.offset, + stride=m.stride, + ) + ) + + new_domain = IndexDomain( + inclusive_min=tuple(new_min), + exclusive_max=tuple(new_max), + ) + result = IndexTransform(domain=new_domain, output=tuple(new_output)) + return (result, surviving_indices) + + +def _intersect_vectorized( + transform: IndexTransform, + output_domain: IndexDomain, + array_dims: list[int], +) -> tuple[IndexTransform, np.ndarray[Any, np.dtype[np.intp]] | None] | None: + """Intersect a vectorized transform with an output domain. + + All ArrayMap outputs are correlated — a point survives only if ALL its + storage coordinates fall within the output domain. + """ + # Compute storage coords per array dim and check bounds simultaneously + n_points: int | None = None + masks: list[np.ndarray[Any, np.dtype[np.bool_]]] = [] + + for out_dim in array_dims: + m = transform.output[out_dim] + assert isinstance(m, ArrayMap) + storage = m.offset + m.stride * m.index_array + lo = output_domain.inclusive_min[out_dim] + hi = output_domain.exclusive_max[out_dim] + masks.append((storage >= lo) & (storage < hi)) + if n_points is None: + n_points = storage.size + + # A point survives only if it's in-bounds on ALL array dims + combined_mask = masks[0] + for mask in masks[1:]: + combined_mask = combined_mask & mask + + if not np.any(combined_mask): + return None + + surviving = np.nonzero(combined_mask.ravel())[0].astype(np.intp) + + # Build new output maps + new_output: list[OutputIndexMap] = [] + for out_dim, m in enumerate(transform.output): + if isinstance(m, ArrayMap): + filtered = m.index_array.ravel()[surviving] + new_output.append( + ArrayMap( + index_array=filtered, + offset=m.offset, + stride=m.stride, + ) + ) + elif isinstance(m, ConstantMap): + lo = output_domain.inclusive_min[out_dim] + hi = output_domain.exclusive_max[out_dim] + if lo <= m.offset < hi: + new_output.append(m) + else: + return None + elif 
isinstance(m, DimensionMap): + new_output.append(m) + + new_domain = IndexDomain.from_shape((len(surviving),)) + result = IndexTransform(domain=new_domain, output=tuple(new_output)) + return (result, surviving) + + +def _normalize_basic_selection(selection: Any, ndim: int) -> tuple[int | slice | None, ...]: + """Normalize a selection to a tuple of int, slice, or None (newaxis), + expanding ellipsis and padding with slice(None) as needed. + """ + if not isinstance(selection, tuple): + selection = (selection,) + + # Count non-newaxis, non-ellipsis entries to determine how many real dims are addressed + n_newaxis = sum(1 for s in selection if s is None) + has_ellipsis = any(s is Ellipsis for s in selection) + n_real = len(selection) - n_newaxis - (1 if has_ellipsis else 0) + + if n_real > ndim: + raise IndexError( + f"too many indices for array: array has {ndim} dimensions, but {n_real} were indexed" + ) + + result: list[int | slice | None] = [] + ellipsis_seen = False + for sel in selection: + if sel is Ellipsis: + if ellipsis_seen: + raise IndexError("an index can only have a single ellipsis ('...')") + ellipsis_seen = True + num_missing = ndim - n_real + result.extend([slice(None)] * num_missing) + elif isinstance(sel, (int, np.integer)): + result.append(int(sel)) + elif isinstance(sel, slice) or sel is None: + result.append(sel) + else: + raise IndexError(f"unsupported selection type for basic indexing: {type(sel)!r}") + + # Pad remaining dimensions with slice(None) + while sum(1 for s in result if s is not None) < ndim: + result.append(slice(None)) + + return tuple(result) + + +def _reindex_array( + arr: np.ndarray[Any, np.dtype[np.intp]], + normalized: tuple[int | slice | None, ...], + domain: IndexDomain, +) -> np.ndarray[Any, np.dtype[np.intp]]: + """Apply basic indexing operations to an ArrayMap's index_array. + + The array's axes correspond to the transform's input dimensions (0-indexed + over the domain shape). 
When input dimensions are dropped (int), sliced, + or inserted (newaxis), the array must be updated accordingly. + """ + # Build a numpy indexing tuple: one entry per old input dimension + idx: list[Any] = [] + old_dim = 0 + newaxis_positions: list[int] = [] + result_axis = 0 + + for sel in normalized: + if sel is None: + newaxis_positions.append(result_axis) + result_axis += 1 + elif isinstance(sel, int): + if old_dim < arr.ndim: + # Convert absolute domain coordinate to 0-based array index + array_idx = sel - domain.inclusive_min[old_dim] + idx.append(array_idx) + old_dim += 1 + elif isinstance(sel, slice): + if old_dim < arr.ndim: + dim_size = domain.shape[old_dim] + # sel.indices gives 0-based start/stop/step for the array axis + start, stop, step = sel.indices(dim_size) + idx.append(slice(start, stop, step)) + old_dim += 1 + result_axis += 1 + + result = arr[tuple(idx)] if idx else arr + + for pos in newaxis_positions: + result = np.expand_dims(result, axis=pos) + + return np.asarray(result, dtype=np.intp) + + +def _reindex_array_oindex( + arr: np.ndarray[Any, np.dtype[np.intp]], + normalized: tuple[Any, ...] | list[Any], + domain: IndexDomain, +) -> np.ndarray[Any, np.dtype[np.intp]]: + """Apply oindex/vindex selection to an existing ArrayMap's index_array. + + Each old input dimension gets either an array (fancy index that axis) + or a slice applied to the corresponding array axis. 
+ """ + idx: list[Any] = [] + for old_dim, sel in enumerate(normalized): + if old_dim >= arr.ndim: + break + if isinstance(sel, np.ndarray): + idx.append(sel) + elif isinstance(sel, slice): + dim_size = domain.shape[old_dim] + start, stop, step = sel.indices(dim_size) + idx.append(slice(start, stop, step)) + else: + idx.append(slice(None)) + + result = arr[tuple(idx)] if idx else arr + return np.asarray(result, dtype=np.intp) + + +def _apply_basic_indexing(transform: IndexTransform, selection: Any) -> IndexTransform: + """Apply basic indexing (int, slice, ellipsis, newaxis) to an IndexTransform.""" + normalized = _normalize_basic_selection(selection, transform.domain.ndim) + + new_inclusive_min: list[int] = [] + new_exclusive_max: list[int] = [] + old_dim = 0 + new_dim_idx = 0 + old_to_new_dim: dict[int, int] = {} + dropped_dims: set[int] = set() + + # Per old-dim: the slice parameters (for computing new output maps) + dim_slice_params: dict[int, tuple[int, int, int]] = {} # old_dim -> (start, stop, step) + dim_int_val: dict[int, int] = {} # old_dim -> integer index value + + for sel in normalized: + if sel is None: + # newaxis: add a size-1 dimension + new_inclusive_min.append(0) + new_exclusive_max.append(1) + new_dim_idx += 1 + elif isinstance(sel, int): + # Integer index: drop this input dimension. + # Negative indices are literal coordinates (TensorStore convention), + # NOT "from the end" like NumPy. The Array layer handles conversion. 
+ lo = transform.domain.inclusive_min[old_dim] + hi = transform.domain.exclusive_max[old_dim] + idx = sel + if idx < lo or idx >= hi: + raise IndexError( + f"index {sel} is out of bounds for dimension {old_dim} with domain [{lo}, {hi})" + ) + dropped_dims.add(old_dim) + dim_int_val[old_dim] = idx + old_dim += 1 + elif isinstance(sel, slice): + lo = transform.domain.inclusive_min[old_dim] + hi = transform.domain.exclusive_max[old_dim] + dim_size = hi - lo + + # Resolve slice relative to the current domain (origin-based) + start, stop, step = sel.indices(dim_size) + # start, stop, step are now relative to a 0-based range of size dim_size + + if step <= 0: + raise IndexError("slice step must be positive") + + new_size = max(0, math.ceil((stop - start) / step)) + new_inclusive_min.append(0) + new_exclusive_max.append(new_size) + + # Absolute start in the original domain coordinates + abs_start = lo + start + dim_slice_params[old_dim] = (abs_start, stop, step) + old_to_new_dim[old_dim] = new_dim_idx + new_dim_idx += 1 + old_dim += 1 + + new_domain = IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + + # Now update output maps + new_output: list[OutputIndexMap] = [] + for m in transform.output: + if isinstance(m, ConstantMap): + new_output.append(m) + elif isinstance(m, DimensionMap): + d = m.input_dimension + if d in dropped_dims: + # Integer index: this output becomes constant + new_offset = m.offset + m.stride * dim_int_val[d] + new_output.append(ConstantMap(offset=new_offset)) + elif d in old_to_new_dim: + # Slice: update offset and stride + abs_start, _, step = dim_slice_params[d] + new_offset = m.offset + m.stride * abs_start + new_stride = m.stride * step + new_input_dim = old_to_new_dim[d] + new_output.append( + DimensionMap( + input_dimension=new_input_dim, offset=new_offset, stride=new_stride + ) + ) + else: + raise RuntimeError(f"unexpected: dimension {d} not handled") + elif isinstance(m, ArrayMap): + 
new_arr = _reindex_array(m.index_array, normalized, transform.domain) + new_output.append(ArrayMap(index_array=new_arr, offset=m.offset, stride=m.stride)) + + return IndexTransform(domain=new_domain, output=tuple(new_output)) + + +class _OIndexHelper: + """Helper that provides orthogonal (outer) indexing via ``transform.oindex[...]``.""" + + def __init__(self, transform: IndexTransform) -> None: + self._transform = transform + + def __getitem__(self, selection: Any) -> IndexTransform: + return _apply_oindex(self._transform, selection) + + +def _normalize_oindex_selection( + selection: Any, ndim: int +) -> tuple[np.ndarray[Any, np.dtype[np.intp]] | slice, ...]: + """Normalize an oindex selection: arrays, slices, booleans, integers.""" + if not isinstance(selection, tuple): + selection = (selection,) + + # Expand ellipsis + has_ellipsis = any(s is Ellipsis for s in selection) + n_ellipsis = 1 if has_ellipsis else 0 + n_real = len(selection) - n_ellipsis + + result: list[np.ndarray[Any, np.dtype[np.intp]] | slice] = [] + for sel in selection: + if sel is Ellipsis: + num_missing = ndim - n_real + result.extend([slice(None)] * num_missing) + elif isinstance(sel, np.ndarray) and sel.dtype == np.bool_: + # Boolean array -> integer indices + (indices,) = np.nonzero(sel) + result.append(indices.astype(np.intp)) + elif isinstance(sel, np.ndarray): + result.append(sel.astype(np.intp)) + elif isinstance(sel, slice): + result.append(sel) + elif isinstance(sel, (int, np.integer)): + # Convert integer scalars to 1-element arrays for orthogonal indexing + result.append(np.array([int(sel)], dtype=np.intp)) + elif isinstance(sel, (list, tuple)): + result.append(np.asarray(sel, dtype=np.intp)) + else: + result.append(sel) + + # Pad with slice(None) + while len(result) < ndim: + result.append(slice(None)) + + return tuple(result) + + +def _apply_oindex(transform: IndexTransform, selection: Any) -> IndexTransform: + """Apply orthogonal indexing to an IndexTransform. 
+ + Each index array is applied independently per dimension (outer product). + """ + normalized = _normalize_oindex_selection(selection, transform.domain.ndim) + + new_inclusive_min: list[int] = [] + new_exclusive_max: list[int] = [] + new_dim_idx = 0 + old_to_new_dim: dict[int, int] = {} + + # Info per old dim + dim_array: dict[int, np.ndarray[Any, np.dtype[np.intp]]] = {} + dim_slice_params: dict[int, tuple[int, int, int]] = {} + + for old_dim, sel in enumerate(normalized): + if isinstance(sel, np.ndarray): + dim_array[old_dim] = sel + new_inclusive_min.append(0) + new_exclusive_max.append(len(sel)) + old_to_new_dim[old_dim] = new_dim_idx + new_dim_idx += 1 + elif isinstance(sel, slice): + lo = transform.domain.inclusive_min[old_dim] + hi = transform.domain.exclusive_max[old_dim] + dim_size = hi - lo + start, stop, step = sel.indices(dim_size) + if step <= 0: + raise IndexError("slice step must be positive") + new_size = max(0, math.ceil((stop - start) / step)) + new_inclusive_min.append(0) + new_exclusive_max.append(new_size) + abs_start = lo + start + dim_slice_params[old_dim] = (abs_start, stop, step) + old_to_new_dim[old_dim] = new_dim_idx + new_dim_idx += 1 + + new_domain = IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + + new_output: list[OutputIndexMap] = [] + for m in transform.output: + if isinstance(m, ConstantMap): + new_output.append(m) + elif isinstance(m, DimensionMap): + d = m.input_dimension + if d in dim_array: + new_output.append( + ArrayMap( + index_array=dim_array[d], + offset=m.offset, + stride=m.stride, + ) + ) + elif d in dim_slice_params: + abs_start, _, step = dim_slice_params[d] + new_offset = m.offset + m.stride * abs_start + new_stride = m.stride * step + new_input_dim = old_to_new_dim[d] + new_output.append( + DimensionMap( + input_dimension=new_input_dim, offset=new_offset, stride=new_stride + ) + ) + else: + raise RuntimeError(f"unexpected: dimension {d} not handled") + elif 
isinstance(m, ArrayMap): + new_arr = _reindex_array_oindex(m.index_array, normalized, transform.domain) + new_output.append(ArrayMap(index_array=new_arr, offset=m.offset, stride=m.stride)) + + return IndexTransform(domain=new_domain, output=tuple(new_output)) + + +class _VIndexHelper: + """Helper that provides vectorized (fancy) indexing via ``transform.vindex[...]``.""" + + def __init__(self, transform: IndexTransform) -> None: + self._transform = transform + + def __getitem__(self, selection: Any) -> IndexTransform: + return _apply_vindex(self._transform, selection) + + +def _apply_vindex(transform: IndexTransform, selection: Any) -> IndexTransform: + """Apply vectorized indexing to an IndexTransform. + + All array indices are broadcast together. Broadcast dimensions are prepended, + followed by non-array (slice) dimensions. + """ + if not isinstance(selection, tuple): + selection = (selection,) + + # Expand ellipsis and count consumed dimensions + # Boolean arrays with ndim > 1 consume ndim dims + n_consumed = 0 + for s in selection: + if s is Ellipsis: + continue + if isinstance(s, np.ndarray) and s.dtype == np.bool_ and s.ndim > 1: + n_consumed += s.ndim + else: + n_consumed += 1 + ndim = transform.domain.ndim + + expanded: list[Any] = [] + for sel in selection: + if sel is Ellipsis: + num_missing = ndim - n_consumed + expanded.extend([slice(None)] * num_missing) + else: + expanded.append(sel) + # Count dimensions already consumed by expanded entries + n_expanded_dims = 0 + for sel in expanded: + if isinstance(sel, np.ndarray) and sel.dtype == np.bool_ and sel.ndim > 1: + n_expanded_dims += sel.ndim + else: + n_expanded_dims += 1 + while n_expanded_dims < ndim: + expanded.append(slice(None)) + n_expanded_dims += 1 + + # Convert booleans, lists, ints to integer arrays + processed: list[np.ndarray[Any, np.dtype[np.intp]] | slice] = [] + for sel in expanded: + if isinstance(sel, np.ndarray) and sel.dtype == np.bool_: + indices_tuple = np.nonzero(sel) + 
processed.extend(indices.astype(np.intp) for indices in indices_tuple) + elif isinstance(sel, np.ndarray): + processed.append(sel.astype(np.intp)) + elif isinstance(sel, (list, tuple)): + processed.append(np.asarray(sel, dtype=np.intp)) + elif isinstance(sel, (int, np.integer)): + processed.append(np.array([int(sel)], dtype=np.intp)) + else: + processed.append(sel) + + # Separate array dims and slice dims + array_dims: list[int] = [] + slice_dims: list[int] = [] + arrays: list[np.ndarray[Any, np.dtype[np.intp]]] = [] + + for i, sel in enumerate(processed): + if isinstance(sel, np.ndarray): + array_dims.append(i) + arrays.append(sel) + else: + slice_dims.append(i) + + # Broadcast all arrays together + broadcast_arrays: list[np.ndarray[Any, np.dtype[np.intp]]] + if arrays: + broadcast_arrays = list(np.broadcast_arrays(*arrays)) + broadcast_shape = broadcast_arrays[0].shape + else: + broadcast_arrays = [] + broadcast_shape = () + + # Build new domain: broadcast dims first, then slice dims + new_inclusive_min: list[int] = [] + new_exclusive_max: list[int] = [] + + # Broadcast dimensions + for s in broadcast_shape: + new_inclusive_min.append(0) + new_exclusive_max.append(s) + + # Slice dimensions + slice_dim_params: dict[int, tuple[int, int, int]] = {} + for old_dim in slice_dims: + sel = processed[old_dim] + assert isinstance(sel, slice) + lo = transform.domain.inclusive_min[old_dim] + hi = transform.domain.exclusive_max[old_dim] + dim_size = hi - lo + start, stop, step = sel.indices(dim_size) + if step <= 0: + raise IndexError("slice step must be positive") + new_size = max(0, math.ceil((stop - start) / step)) + new_inclusive_min.append(0) + new_exclusive_max.append(new_size) + abs_start = lo + start + slice_dim_params[old_dim] = (abs_start, stop, step) + + new_domain = IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + + # Build output maps + array_dim_to_broadcast: dict[int, np.ndarray[Any, np.dtype[np.intp]]] = {} 
+ for i, d in enumerate(array_dims): + array_dim_to_broadcast[d] = broadcast_arrays[i] + + # New dim index for slice dims starts after broadcast dims + n_broadcast_dims = len(broadcast_shape) + + new_output: list[OutputIndexMap] = [] + for m in transform.output: + if isinstance(m, ConstantMap): + new_output.append(m) + elif isinstance(m, DimensionMap): + d = m.input_dimension + if d in array_dim_to_broadcast: + new_output.append( + ArrayMap( + index_array=array_dim_to_broadcast[d], + offset=m.offset, + stride=m.stride, + ) + ) + else: + # Slice dim + abs_start, _, step = slice_dim_params[d] + new_offset = m.offset + m.stride * abs_start + new_stride = m.stride * step + new_input_dim = n_broadcast_dims + slice_dims.index(d) + new_output.append( + DimensionMap( + input_dimension=new_input_dim, offset=new_offset, stride=new_stride + ) + ) + elif isinstance(m, ArrayMap): + new_arr = _reindex_array_oindex(m.index_array, processed, transform.domain) + new_output.append(ArrayMap(index_array=new_arr, offset=m.offset, stride=m.stride)) + + return IndexTransform(domain=new_domain, output=tuple(new_output)) + + +def _normalize_negative_indices(selection: Any, shape: tuple[int, ...]) -> Any: + """Convert negative indices to positive ones using the array shape. + + Only normalizes integer and array-like index components; leaves + slices, Ellipsis, None, etc. untouched. + """ + if not isinstance(selection, tuple): + selection_tuple: tuple[Any, ...] 
= (selection,) + else: + selection_tuple = selection + + # Count real dimensions (non-None, non-Ellipsis) to map each entry to a shape dim + has_ellipsis = any(s is Ellipsis for s in selection_tuple) + n_non_newaxis = sum(1 for s in selection_tuple if s is not None and s is not Ellipsis) + n_ellipsis_dims = len(shape) - n_non_newaxis + (1 if has_ellipsis else 0) + + result: list[Any] = [] + dim = 0 + + for sel in selection_tuple: + if sel is Ellipsis: + result.append(sel) + dim += max(0, n_ellipsis_dims) + elif sel is None: + result.append(sel) + elif isinstance(sel, (int, np.integer)) and not isinstance(sel, bool): + idx = int(sel) + if idx < 0 and dim < len(shape): + idx = idx + shape[dim] + result.append(idx) + dim += 1 + elif isinstance(sel, np.ndarray) and sel.dtype != np.bool_: + arr = sel.copy() + if dim < len(shape): + arr = np.where(arr < 0, arr + shape[dim], arr) + result.append(arr) + dim += 1 + elif isinstance(sel, list): + # Convert lists to arrays with negative index normalization + arr = np.asarray(sel, dtype=np.intp) + if dim < len(shape): + arr = np.where(arr < 0, arr + shape[dim], arr) + result.append(arr) + dim += 1 + else: + # slice, bool array, or anything else: pass through + result.append(sel) + if sel is not None and sel is not Ellipsis: + dim += 1 + + if not isinstance(selection, tuple) and len(result) == 1: + return result[0] + return tuple(result) + + +def _validate_array_selection(selection: Any, shape: tuple[int, ...], mode: str) -> None: + """Validate array-based selections (orthogonal, vectorized). + + Rejects types that are not valid for coordinate/vectorized indexing. + Does not check bounds — the transform operations handle that. 
+ """ + items = selection if isinstance(selection, tuple) else (selection,) + for sel in items: + if sel is Ellipsis or isinstance(sel, (int, np.integer, slice)): + continue + if isinstance(sel, (list, np.ndarray)): + continue + raise IndexError(f"unsupported selection type for {mode} indexing: {type(sel)!r}") + + +def _validate_basic_selection(selection: Any) -> None: + """Validate that a selection only contains basic indexing types (int, slice, Ellipsis). + + Rejects None (newaxis), arrays, lists, floats, strings, etc. + """ + items = selection if isinstance(selection, tuple) else (selection,) + for s in items: + if s is Ellipsis or isinstance(s, (int, np.integer, slice)): + continue + raise IndexError(f"unsupported selection type for basic indexing: {type(s)!r}") + + +def selection_to_transform( + selection: Any, + transform: IndexTransform, + mode: Literal["basic", "orthogonal", "vectorized"], +) -> IndexTransform: + """Convert a user selection into a composed IndexTransform. + + Negative indices are treated as literal coordinates (TensorStore convention). + The caller (Array layer) is responsible for converting numpy-style negative + indices before calling this function. 
+ """ + if mode == "basic": + _validate_basic_selection(selection) + return transform[selection] + elif mode == "orthogonal": + _validate_array_selection(selection, transform.domain.shape, mode) + return transform.oindex[selection] + elif mode == "vectorized": + _validate_array_selection(selection, transform.domain.shape, mode) + return transform.vindex[selection] + else: + raise ValueError(f"Unknown mode: {mode!r}") diff --git a/tests/test_transforms/test_transform.py b/tests/test_transforms/test_transform.py new file mode 100644 index 0000000000..03f26cfb5d --- /dev/null +++ b/tests/test_transforms/test_transform.py @@ -0,0 +1,516 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from zarr.core._transforms.domain import IndexDomain +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap +from zarr.core._transforms.transform import IndexTransform, selection_to_transform + + +class TestIndexTransformConstruction: + def test_from_shape(self) -> None: + t = IndexTransform.from_shape((10, 20)) + assert t.input_rank == 2 + assert t.output_rank == 2 + assert t.domain.shape == (10, 20) + assert t.domain.origin == (0, 0) + for i, m in enumerate(t.output): + assert isinstance(m, DimensionMap) + assert m.input_dimension == i + assert m.offset == 0 + assert m.stride == 1 + + def test_identity(self) -> None: + domain = IndexDomain(inclusive_min=(5,), exclusive_max=(15,)) + t = IndexTransform.identity(domain) + assert t.input_rank == 1 + assert t.output_rank == 1 + assert t.domain == domain + assert isinstance(t.output[0], DimensionMap) + assert t.output[0].input_dimension == 0 + + def test_from_shape_0d(self) -> None: + t = IndexTransform.from_shape(()) + assert t.input_rank == 0 + assert t.output_rank == 0 + assert t.domain.shape == () + + def test_custom_output_maps(self) -> None: + domain = IndexDomain.from_shape((10,)) + maps = (ConstantMap(offset=42), DimensionMap(input_dimension=0, offset=5, stride=2)) + t = 
IndexTransform(domain=domain, output=maps) + assert t.input_rank == 1 + assert t.output_rank == 2 + + def test_validation_input_dimension_out_of_range(self) -> None: + domain = IndexDomain.from_shape((10,)) + maps = (DimensionMap(input_dimension=5),) + with pytest.raises(ValueError, match="input_dimension"): + IndexTransform(domain=domain, output=maps) + + +class TestIndexTransformBasicIndexing: + def test_slice_identity(self) -> None: + """slice(None) on identity transform is a no-op.""" + t = IndexTransform.from_shape((10, 20)) + result = t[slice(None), slice(None)] + assert result.domain.shape == (10, 20) + assert result.input_rank == 2 + assert result.output_rank == 2 + + def test_slice_narrows(self) -> None: + t = IndexTransform.from_shape((10, 20)) + result = t[2:8, 5:15] + assert result.domain.shape == (6, 10) + assert result.domain.origin == (0, 0) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == 2 + assert result.output[0].stride == 1 + assert result.output[0].input_dimension == 0 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].offset == 5 + assert result.output[1].input_dimension == 1 + + def test_strided_slice(self) -> None: + t = IndexTransform.from_shape((10,)) + result = t[::2] + assert result.domain.shape == (5,) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == 0 + assert result.output[0].stride == 2 + + def test_strided_slice_with_start(self) -> None: + t = IndexTransform.from_shape((10,)) + result = t[1:9:3] + # indices: 1, 4, 7 -> 3 elements + assert result.domain.shape == (3,) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == 1 + assert result.output[0].stride == 3 + + def test_int_drops_dimension(self) -> None: + t = IndexTransform.from_shape((10, 20)) + result = t[3] + assert result.input_rank == 1 + assert result.output_rank == 2 + assert isinstance(result.output[0], ConstantMap) + assert 
result.output[0].offset == 3 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].input_dimension == 0 + + def test_int_middle_dimension(self) -> None: + t = IndexTransform.from_shape((10, 20, 30)) + result = t[:, 5, :] + assert result.input_rank == 2 + assert result.output_rank == 3 + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].input_dimension == 0 + assert isinstance(result.output[1], ConstantMap) + assert result.output[1].offset == 5 + assert isinstance(result.output[2], DimensionMap) + assert result.output[2].input_dimension == 1 + + def test_ellipsis(self) -> None: + t = IndexTransform.from_shape((10, 20, 30)) + result = t[2:8, ...] + assert result.input_rank == 3 + assert result.domain.shape == (6, 20, 30) + + def test_newaxis(self) -> None: + t = IndexTransform.from_shape((10, 20)) + result = t[np.newaxis, :, :] + assert result.input_rank == 3 + assert result.domain.shape == (1, 10, 20) + assert result.output_rank == 2 + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].input_dimension == 1 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].input_dimension == 2 + + def test_int_out_of_bounds(self) -> None: + t = IndexTransform.from_shape((10,)) + with pytest.raises(IndexError): + t[10] + + def test_negative_int_is_literal(self) -> None: + """Negative indices are literal coordinates (TensorStore convention), + not 'from the end' like NumPy.""" + t = IndexTransform.from_shape((10,)) + with pytest.raises(IndexError): + t[-1] # -1 is out of bounds for domain [0, 10) + + def test_negative_int_valid_with_negative_origin(self) -> None: + """Negative index is valid if the domain includes negative coordinates.""" + domain = IndexDomain(inclusive_min=(-5,), exclusive_max=(5,)) + t = IndexTransform.identity(domain) + result = t[-3] + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == -3 + + def test_composition_of_slices(self) 
-> None: + """Slicing a sliced transform should compose offsets.""" + t = IndexTransform.from_shape((100,)) + result = t[10:50][5:20] + assert result.domain.shape == (15,) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == 15 + assert result.output[0].stride == 1 + + def test_composition_of_strides(self) -> None: + t = IndexTransform.from_shape((100,)) + result = t[::2][::3] + # t[::2] -> shape (50,), offset=0, stride=2 + # [::3] -> shape ceil(50/3)=17, offset=0, stride=2*3=6 + assert result.domain.shape == (17,) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].stride == 6 + + def test_bare_int(self) -> None: + """Non-tuple selection.""" + t = IndexTransform.from_shape((10, 20)) + result = t[3] + assert result.input_rank == 1 + + def test_bare_slice(self) -> None: + t = IndexTransform.from_shape((10, 20)) + result = t[2:8] + assert result.domain.shape == (6, 20) + + +class TestBasicIndexingOnArrayMaps: + """When a transform already has ArrayMap outputs, basic indexing must + apply the corresponding operation to the index_array's axes.""" + + def test_int_on_array_map_drops_axis(self) -> None: + """Integer index on a dimension referenced by an ArrayMap should + index into the array on that axis.""" + arr = np.array([[10, 20], [30, 40], [50, 60]], dtype=np.intp) + # 2D input domain (3, 2), one ArrayMap output + t = IndexTransform( + domain=IndexDomain.from_shape((3, 2)), + output=(ArrayMap(index_array=arr),), + ) + # Index with int on dim 0 -> pick row 1 -> arr[1, :] = [30, 40] + result = t[1] + assert result.input_rank == 1 + assert result.domain.shape == (2,) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, np.array([30, 40])) + + def test_slice_on_array_map(self) -> None: + """Slice on a dimension referenced by an ArrayMap should slice the array.""" + arr = np.array([10, 20, 30, 40, 50], dtype=np.intp) + t = IndexTransform( + 
domain=IndexDomain.from_shape((5,)), + output=(ArrayMap(index_array=arr),), + ) + result = t[1:4] + assert result.domain.shape == (3,) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, np.array([20, 30, 40])) + + def test_strided_slice_on_array_map(self) -> None: + """Strided slice on ArrayMap should stride the array.""" + arr = np.array([10, 20, 30, 40, 50], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(ArrayMap(index_array=arr),), + ) + result = t[::2] + assert result.domain.shape == (3,) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, np.array([10, 30, 50])) + + def test_newaxis_on_array_map(self) -> None: + """Newaxis should insert an axis in the index_array.""" + arr = np.array([10, 20, 30], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr),), + ) + result = t[np.newaxis, :] + assert result.input_rank == 2 + assert result.domain.shape == (1, 3) + assert isinstance(result.output[0], ArrayMap) + assert result.output[0].index_array.shape == (1, 3) + np.testing.assert_array_equal(result.output[0].index_array, np.array([[10, 20, 30]])) + + def test_int_drops_one_of_two_array_dims(self) -> None: + """2D array map, int on dim 0, slice on dim 1.""" + arr = np.array([[10, 20, 30], [40, 50, 60]], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((2, 3)), + output=(ArrayMap(index_array=arr),), + ) + result = t[0, 1:3] + assert result.input_rank == 1 + assert result.domain.shape == (2,) + assert isinstance(result.output[0], ArrayMap) + # arr[0, 1:3] = [20, 30] + np.testing.assert_array_equal(result.output[0].index_array, np.array([20, 30])) + + +class TestIndexTransformOindex: + def test_oindex_int_array(self) -> None: + t = IndexTransform.from_shape((10, 20)) + idx = np.array([1, 3, 5], dtype=np.intp) + result = t.oindex[idx, :] + 
assert result.input_rank == 2 + assert result.domain.shape == (3, 20) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, idx) + assert result.output[0].offset == 0 + assert result.output[0].stride == 1 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].input_dimension == 1 + + def test_oindex_bool_array(self) -> None: + t = IndexTransform.from_shape((5,)) + mask = np.array([True, False, True, False, True]) + result = t.oindex[mask] + assert result.domain.shape == (3,) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal( + result.output[0].index_array, np.array([0, 2, 4], dtype=np.intp) + ) + + def test_oindex_mixed(self) -> None: + t = IndexTransform.from_shape((10, 20)) + idx = np.array([2, 4], dtype=np.intp) + result = t.oindex[idx, 5:15] + assert result.input_rank == 2 + assert result.domain.shape == (2, 10) + assert isinstance(result.output[0], ArrayMap) + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].offset == 5 + + def test_oindex_multiple_arrays(self) -> None: + t = IndexTransform.from_shape((10, 20, 30)) + idx0 = np.array([1, 3], dtype=np.intp) + idx1 = np.array([5, 10, 15], dtype=np.intp) + result = t.oindex[idx0, :, idx1] + assert result.input_rank == 3 + assert result.domain.shape == (2, 20, 3) + assert isinstance(result.output[0], ArrayMap) + assert isinstance(result.output[1], DimensionMap) + assert isinstance(result.output[2], ArrayMap) + + +class TestIndexTransformVindex: + def test_vindex_single_array(self) -> None: + t = IndexTransform.from_shape((10,)) + idx = np.array([1, 3, 5], dtype=np.intp) + result = t.vindex[idx] + assert result.input_rank == 1 + assert result.domain.shape == (3,) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, idx) + + def test_vindex_broadcast(self) -> None: + t = IndexTransform.from_shape((10, 20)) + idx0 = 
np.array([[1, 2], [3, 4]], dtype=np.intp) + idx1 = np.array([[10, 11], [12, 13]], dtype=np.intp) + result = t.vindex[idx0, idx1] + assert result.input_rank == 2 + assert result.domain.shape == (2, 2) + assert isinstance(result.output[0], ArrayMap) + assert isinstance(result.output[1], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, idx0) + np.testing.assert_array_equal(result.output[1].index_array, idx1) + + def test_vindex_with_slice(self) -> None: + t = IndexTransform.from_shape((10, 20, 30)) + idx = np.array([1, 3, 5], dtype=np.intp) + result = t.vindex[idx, :, :] + assert result.input_rank == 3 + assert result.domain.shape == (3, 20, 30) + assert isinstance(result.output[0], ArrayMap) + + def test_vindex_bool_mask(self) -> None: + t = IndexTransform.from_shape((5,)) + mask = np.array([True, False, True, False, True]) + result = t.vindex[mask] + assert result.domain.shape == (3,) + assert isinstance(result.output[0], ArrayMap) + + def test_vindex_broadcast_different_shapes(self) -> None: + t = IndexTransform.from_shape((10, 20)) + idx0 = np.array([1, 2, 3], dtype=np.intp) + idx1 = np.array([[10], [11]], dtype=np.intp) + result = t.vindex[idx0, idx1] + assert result.input_rank == 2 + assert result.domain.shape == (2, 3) + + +class TestSelectionToTransform: + def test_basic_slice(self) -> None: + t = IndexTransform.from_shape((10, 20)) + result = selection_to_transform((slice(2, 8), slice(5, 15)), t, "basic") + assert result.domain.shape == (6, 10) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == 2 + + def test_basic_int(self) -> None: + t = IndexTransform.from_shape((10, 20)) + result = selection_to_transform((3, slice(None)), t, "basic") + assert result.input_rank == 1 + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 3 + + def test_basic_ellipsis(self) -> None: + t = IndexTransform.from_shape((10, 20)) + result = selection_to_transform(Ellipsis, t, "basic") + 
assert result.domain.shape == (10, 20) + + def test_orthogonal(self) -> None: + t = IndexTransform.from_shape((10, 20)) + idx = np.array([1, 3, 5], dtype=np.intp) + result = selection_to_transform((idx, slice(None)), t, "orthogonal") + assert result.domain.shape == (3, 20) + assert isinstance(result.output[0], ArrayMap) + + def test_vectorized(self) -> None: + t = IndexTransform.from_shape((10, 20)) + idx0 = np.array([1, 3], dtype=np.intp) + idx1 = np.array([5, 7], dtype=np.intp) + result = selection_to_transform((idx0, idx1), t, "vectorized") + assert result.domain.shape == (2,) + assert isinstance(result.output[0], ArrayMap) + assert isinstance(result.output[1], ArrayMap) + + def test_composition_with_non_identity(self) -> None: + """Indexing a sliced transform composes offsets.""" + t = IndexTransform.from_shape((100,))[10:50] + result = selection_to_transform(slice(5, 20), t, "basic") + assert result.domain.shape == (15,) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == 15 + + +class TestIndexTransformIntersect: + def test_constant_inside(self) -> None: + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), + ) + result = t.intersect(IndexDomain(inclusive_min=(0,), exclusive_max=(10,))) + assert result is not None + restricted, surviving = result + assert isinstance(restricted.output[0], ConstantMap) + assert restricted.output[0].offset == 5 + assert surviving is None + + def test_constant_outside(self) -> None: + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), + ) + result = t.intersect(IndexDomain(inclusive_min=(10,), exclusive_max=(20,))) + assert result is None + + def test_dimension_partial(self) -> None: + """DimensionMap over [0,10) intersected with [5,15) narrows input to [5,10).""" + t = IndexTransform.from_shape((10,)) + result = t.intersect(IndexDomain(inclusive_min=(5,), exclusive_max=(15,))) + assert result is not None + 
restricted, surviving = result + assert restricted.domain.inclusive_min == (5,) + assert restricted.domain.exclusive_max == (10,) + assert surviving is None + + def test_dimension_no_overlap(self) -> None: + t = IndexTransform.from_shape((10,)) + result = t.intersect(IndexDomain(inclusive_min=(20,), exclusive_max=(30,))) + assert result is None + + def test_dimension_strided(self) -> None: + """stride=2, offset=1 over [0,5): storage 1,3,5,7,9. Chunk [4,8).""" + t = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(DimensionMap(input_dimension=0, offset=1, stride=2),), + ) + result = t.intersect(IndexDomain(inclusive_min=(4,), exclusive_max=(8,))) + assert result is not None + restricted, _surviving = result + # input 2->5, input 3->7. Both in [4,8). + assert restricted.domain.inclusive_min == (2,) + assert restricted.domain.exclusive_max == (4,) + + def test_array_partial(self) -> None: + arr = np.array([3, 8, 15, 22], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((4,)), + output=(ArrayMap(index_array=arr),), + ) + result = t.intersect(IndexDomain(inclusive_min=(5,), exclusive_max=(20,))) + assert result is not None + restricted, surviving = result + assert isinstance(restricted.output[0], ArrayMap) + np.testing.assert_array_equal(restricted.output[0].index_array, np.array([8, 15])) + assert surviving is not None + np.testing.assert_array_equal(surviving, np.array([1, 2])) + + def test_array_none_inside(self) -> None: + arr = np.array([1, 2, 3], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr),), + ) + assert t.intersect(IndexDomain(inclusive_min=(10,), exclusive_max=(20,))) is None + + def test_2d_mixed(self) -> None: + """2D: ConstantMap on dim 0, DimensionMap on dim 1.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=( + ConstantMap(offset=5), + DimensionMap(input_dimension=0, offset=0, stride=1), + ), + ) + chunk = 
IndexDomain(inclusive_min=(0, 5), exclusive_max=(10, 15)) + result = t.intersect(chunk) + assert result is not None + restricted, _ = result + assert isinstance(restricted.output[0], ConstantMap) + assert restricted.output[0].offset == 5 + assert isinstance(restricted.output[1], DimensionMap) + assert restricted.domain.inclusive_min == (5,) + assert restricted.domain.exclusive_max == (10,) + + +class TestIndexTransformTranslate: + def test_translate_constant(self) -> None: + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), + ) + result = t.translate((-5,)) + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 0 + + def test_translate_dimension(self) -> None: + t = IndexTransform.from_shape((10,)) + result = t.translate((-3,)) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == -3 + assert result.output[0].stride == 1 + + def test_translate_array(self) -> None: + arr = np.array([5, 10], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((2,)), + output=(ArrayMap(index_array=arr, offset=3),), + ) + result = t.translate((-3,)) + assert isinstance(result.output[0], ArrayMap) + assert result.output[0].offset == 0 + np.testing.assert_array_equal(result.output[0].index_array, arr) + + def test_translate_2d(self) -> None: + t = IndexTransform.from_shape((10, 20)) + result = t.translate((-5, -10)) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == -5 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].offset == -10 From 34b0384f6eb0bb19def500fb3e766087c2fc6fa4 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 6 May 2026 18:07:26 -0400 Subject: [PATCH 05/24] feat(_transforms): port composition module Defines compose(outer, inner) -> IndexTransform and the per-output-map dispatch helpers. 
--- src/zarr/core/_transforms/composition.py | 113 +++++++++++++++ tests/test_transforms/test_composition.py | 166 ++++++++++++++++++++++ 2 files changed, 279 insertions(+) create mode 100644 src/zarr/core/_transforms/composition.py create mode 100644 tests/test_transforms/test_composition.py diff --git a/src/zarr/core/_transforms/composition.py b/src/zarr/core/_transforms/composition.py new file mode 100644 index 0000000000..c762897d5b --- /dev/null +++ b/src/zarr/core/_transforms/composition.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import numpy as np + +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap +from zarr.core._transforms.transform import IndexTransform + + +def compose(outer: IndexTransform, inner: IndexTransform) -> IndexTransform: + """Compose two IndexTransforms. + + ``outer`` maps user coords (rank m) to intermediate coords (rank n). + ``inner`` maps intermediate coords (rank n) to storage coords (rank p). + The result maps user coords (rank m) to storage coords (rank p). + + Precondition: ``outer.output_rank == inner.domain.ndim``. 
+ """ + if outer.output_rank != inner.domain.ndim: + raise ValueError( + f"outer output rank ({outer.output_rank}) must match inner input rank " + f"({inner.domain.ndim})" + ) + + result_output = [_compose_single(outer, inner_map) for inner_map in inner.output] + + return IndexTransform(domain=outer.domain, output=tuple(result_output)) + + +def _compose_single(outer: IndexTransform, inner_map: OutputIndexMap) -> OutputIndexMap: + """Compose a single inner output map with the full outer transform.""" + if isinstance(inner_map, ConstantMap): + return ConstantMap(offset=inner_map.offset) + + if isinstance(inner_map, DimensionMap): + return _compose_dimension(outer, inner_map) + + if isinstance(inner_map, ArrayMap): + return _compose_array(outer, inner_map) + + raise TypeError(f"Unknown output map type: {type(inner_map)}") # pragma: no cover + + +def _compose_dimension(outer: IndexTransform, inner_map: DimensionMap) -> OutputIndexMap: + """Compose when inner is a DimensionMap. + + storage = offset_i + stride_i * intermediate[dim_i] + where intermediate[dim_i] = outer.output[dim_i](user_input) + """ + dim_i = inner_map.input_dimension + offset_i = inner_map.offset + stride_i = inner_map.stride + outer_map = outer.output[dim_i] + + if isinstance(outer_map, ConstantMap): + return ConstantMap(offset=offset_i + stride_i * outer_map.offset) + + if isinstance(outer_map, DimensionMap): + return DimensionMap( + input_dimension=outer_map.input_dimension, + offset=offset_i + stride_i * outer_map.offset, + stride=stride_i * outer_map.stride, + ) + + if isinstance(outer_map, ArrayMap): + return ArrayMap( + index_array=outer_map.index_array, + offset=offset_i + stride_i * outer_map.offset, + stride=stride_i * outer_map.stride, + ) + + raise TypeError(f"Unknown output map type: {type(outer_map)}") # pragma: no cover + + +def _compose_array(outer: IndexTransform, inner_map: ArrayMap) -> OutputIndexMap: + """Compose when inner is an ArrayMap. 
+ + storage = offset_i + stride_i * arr_i[intermediate] + We need to evaluate arr_i at the intermediate coordinates produced by outer. + """ + arr_i = inner_map.index_array + offset_i = inner_map.offset + stride_i = inner_map.stride + + # Check if all outer outputs are constant + all_constant = all(isinstance(m, ConstantMap) for m in outer.output) + + if all_constant: + # Evaluate arr_i at the single constant point + idx = tuple(m.offset for m in outer.output if isinstance(m, ConstantMap)) + value = int(arr_i[idx]) + return ConstantMap(offset=offset_i + stride_i * value) + + # For 1D inner array with a single outer output (simple case) + if arr_i.ndim == 1 and len(outer.output) == 1: + outer_map = outer.output[0] + + if isinstance(outer_map, DimensionMap): + dim_size = outer.domain.shape[outer_map.input_dimension] + user_indices = np.arange(dim_size, dtype=np.intp) + intermediate_vals = outer_map.offset + outer_map.stride * user_indices + new_arr = arr_i[intermediate_vals] + return ArrayMap(index_array=new_arr, offset=offset_i, stride=stride_i) + + if isinstance(outer_map, ArrayMap): + intermediate_vals = outer_map.offset + outer_map.stride * outer_map.index_array + new_arr = arr_i[intermediate_vals] + return ArrayMap(index_array=new_arr, offset=offset_i, stride=stride_i) + + # General multi-dim case: not yet implemented + raise NotImplementedError( + "Composing a multi-dimensional inner array map with non-constant outer maps " + "is not yet supported." 
+ ) diff --git a/tests/test_transforms/test_composition.py b/tests/test_transforms/test_composition.py new file mode 100644 index 0000000000..b5060a7b9e --- /dev/null +++ b/tests/test_transforms/test_composition.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from zarr.core._transforms.composition import compose +from zarr.core._transforms.domain import IndexDomain +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap +from zarr.core._transforms.transform import IndexTransform + + +class TestComposeConstantInner: + """Inner = constant. Result is always constant.""" + + def test_constant_inner_any_outer(self) -> None: + outer = IndexTransform.from_shape((5,)) + inner = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(ConstantMap(offset=42),), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 42 + + +class TestComposeDimensionInner: + """Inner = DimensionMap.""" + + def test_dimension_inner_constant_outer(self) -> None: + outer = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=10, stride=3),), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 25 + + def test_dimension_inner_dimension_outer(self) -> None: + outer = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=5, stride=2),), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=10, stride=3),), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == 25 + assert result.output[0].stride == 6 + assert result.output[0].input_dimension 
== 0 + + def test_dimension_inner_array_outer(self) -> None: + arr = np.array([0, 2, 4], dtype=np.intp) + outer = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr, offset=5, stride=2),), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=10, stride=3),), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ArrayMap) + assert result.output[0].offset == 25 + assert result.output[0].stride == 6 + np.testing.assert_array_equal(result.output[0].index_array, arr) + + +class TestComposeArrayInner: + """Inner = ArrayMap.""" + + def test_array_inner_constant_outer(self) -> None: + inner_arr = np.array([10, 20, 30], dtype=np.intp) + outer = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(ConstantMap(offset=1),), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=inner_arr, offset=0, stride=1),), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 20 + + def test_array_inner_array_outer(self) -> None: + outer_arr = np.array([0, 2, 1], dtype=np.intp) + inner_arr = np.array([10, 20, 30], dtype=np.intp) + outer = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=outer_arr, offset=0, stride=1),), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=inner_arr, offset=0, stride=1),), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ArrayMap) + expected = np.array([10, 30, 20], dtype=np.intp) + np.testing.assert_array_equal(result.output[0].index_array, expected) + + +class TestComposeMultiDim: + def test_2d_identity_compose(self) -> None: + a = IndexTransform.from_shape((10, 20)) + b = IndexTransform.from_shape((10, 20)) + result = compose(a, b) + assert result.domain.shape == (10, 20) + for i in range(2): + m 
= result.output[i] + assert isinstance(m, DimensionMap) + assert m.input_dimension == i + assert m.offset == 0 + assert m.stride == 1 + + def test_mixed_map_types(self) -> None: + outer = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=( + ConstantMap(offset=5), + DimensionMap(input_dimension=0, offset=0, stride=1), + ), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((10, 10)), + output=( + DimensionMap(input_dimension=0, offset=2, stride=3), + DimensionMap(input_dimension=1, offset=0, stride=1), + ), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 17 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].input_dimension == 0 + assert result.output[1].offset == 0 + assert result.output[1].stride == 1 + + def test_rank_mismatch_raises(self) -> None: + outer = IndexTransform.from_shape((10,)) + inner = IndexTransform.from_shape((10, 20)) + with pytest.raises(ValueError, match="rank"): + compose(outer, inner) + + +class TestComposeChain: + def test_three_transforms(self) -> None: + a = IndexTransform.from_shape((100,)) + b = IndexTransform( + domain=IndexDomain.from_shape((100,)), + output=(DimensionMap(input_dimension=0, offset=10, stride=1),), + ) + c = IndexTransform( + domain=IndexDomain.from_shape((100,)), + output=(DimensionMap(input_dimension=0, offset=5, stride=2),), + ) + bc = compose(b, c) + abc = compose(a, bc) + assert isinstance(abc.output[0], DimensionMap) + assert abc.output[0].offset == 25 + assert abc.output[0].stride == 2 From 5a5ad23990677b73da3b8b2944782f871e047c55 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 6 May 2026 18:12:05 -0400 Subject: [PATCH 06/24] feat(_transforms): port chunk_resolution module Defines iter_chunk_transforms and sub_transform_to_selections, which bridge a transform to per-chunk selections. 
No callers in this PR; included so the package is internally complete and ready for future internal-rewiring work. --- src/zarr/core/_transforms/chunk_resolution.py | 207 ++++++++++++++++++ .../test_transforms/test_chunk_resolution.py | 181 +++++++++++++++ 2 files changed, 388 insertions(+) create mode 100644 src/zarr/core/_transforms/chunk_resolution.py create mode 100644 tests/test_transforms/test_chunk_resolution.py diff --git a/src/zarr/core/_transforms/chunk_resolution.py b/src/zarr/core/_transforms/chunk_resolution.py new file mode 100644 index 0000000000..56f2c7ac93 --- /dev/null +++ b/src/zarr/core/_transforms/chunk_resolution.py @@ -0,0 +1,207 @@ +"""Chunk resolution — mapping transforms to chunk-level I/O. + +Given an ``IndexTransform`` (which coordinates a user wants to access) and a +``ChunkGrid`` (how storage is divided into chunks), chunk resolution answers: + + For each chunk, which storage coordinates does this transform touch, + and where do those values land in the output buffer? + +The algorithm is: + +1. **Enumerate candidate chunks** — determine which chunks could possibly + be touched by the transform's output coordinate ranges. + +2. **Intersect** — for each candidate chunk, call + ``transform.intersect(chunk_domain)`` to restrict the transform to + coordinates within that chunk. If the intersection is empty, skip it. + +3. **Translate** — shift the restricted transform to chunk-local coordinates + via ``transform.translate(-chunk_origin)``. + +4. **Yield** — produce ``(chunk_coords, local_transform, surviving_indices)`` + triples that the codec pipeline consumes. + +``sub_transform_to_selections`` bridges from the transform representation +back to the raw ``(chunk_selection, out_selection, drop_axes)`` tuples that +the current codec pipeline expects. This bridge will go away when the codec +pipeline accepts transforms natively. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np + +from zarr.core._transforms.domain import IndexDomain +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap +from zarr.core._transforms.transform import IndexTransform + +if TYPE_CHECKING: + from collections.abc import Iterator + + from zarr.core.chunk_grids import ChunkGrid + +ChunkTransformResult = tuple[ + tuple[int, ...], + IndexTransform, + np.ndarray[Any, np.dtype[np.intp]] | None, +] + + +def iter_chunk_transforms( + transform: IndexTransform, + chunk_grid: ChunkGrid, +) -> Iterator[ChunkTransformResult]: + """Resolve a composed IndexTransform against a ChunkGrid. + + Yields ``(chunk_coords, sub_transform, out_indices)`` triples: + + - ``chunk_coords``: which chunk to access. + - ``sub_transform``: maps output buffer coords to chunk-local coords. + - ``out_indices``: for vectorized/array indexing, the output scatter + indices (integer array). ``None`` for basic/slice indexing. + """ + dim_grids = chunk_grid._dimensions + + # Enumerate all possible chunks via cartesian product of per-dim chunk ranges + # For each candidate chunk, intersect the transform with the chunk domain. + # The transform.intersect method handles both orthogonal and vectorized cases. 
+ chunk_ranges: list[range] = [] + for out_dim, m in enumerate(transform.output): + dg = dim_grids[out_dim] + if isinstance(m, ConstantMap): + # Single chunk + c = dg.index_to_chunk(m.offset) + chunk_ranges.append(range(c, c + 1)) + elif isinstance(m, DimensionMap): + d = m.input_dimension + dim_lo = transform.domain.inclusive_min[d] + dim_hi = transform.domain.exclusive_max[d] + if dim_lo >= dim_hi: + return # empty domain + if m.stride > 0: + s_min = m.offset + m.stride * dim_lo + s_max = m.offset + m.stride * (dim_hi - 1) + else: + s_min = m.offset + m.stride * (dim_hi - 1) + s_max = m.offset + m.stride * dim_lo + first = dg.index_to_chunk(s_min) + last = dg.index_to_chunk(s_max) + chunk_ranges.append(range(first, last + 1)) + elif isinstance(m, ArrayMap): + storage = m.offset + m.stride * m.index_array + flat = storage.ravel().astype(np.intp) + chunk_ids = dg.indices_to_chunks(flat) + first = int(chunk_ids.min()) + last = int(chunk_ids.max()) + chunk_ranges.append(range(first, last + 1)) + + import itertools + + for chunk_coords_tuple in itertools.product(*chunk_ranges): + chunk_coords = tuple(int(c) for c in chunk_coords_tuple) + + # Build the chunk domain in storage space + chunk_min: list[int] = [] + chunk_max: list[int] = [] + chunk_shift: list[int] = [] + for out_dim, c in enumerate(chunk_coords): + dg = dim_grids[out_dim] + c_start = dg.chunk_offset(c) + c_size = dg.chunk_size(c) + chunk_min.append(c_start) + chunk_max.append(c_start + c_size) + chunk_shift.append(-c_start) + + chunk_domain = IndexDomain( + inclusive_min=tuple(chunk_min), + exclusive_max=tuple(chunk_max), + ) + + # Intersect transform with chunk domain + result = transform.intersect(chunk_domain) + if result is None: + continue + + restricted, surviving = result + + # Translate to chunk-local coordinates + local = restricted.translate(tuple(chunk_shift)) + + yield (chunk_coords, local, surviving) + + +def sub_transform_to_selections( + sub_transform: IndexTransform, + out_indices: 
np.ndarray[Any, np.dtype[np.intp]] | None = None, +) -> tuple[ + tuple[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]], ...], + tuple[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]], ...], + tuple[int, ...], +]: + """Convert a chunk-local sub-transform to raw selections for the codec pipeline. + + Parameters + ---------- + sub_transform + A chunk-local IndexTransform (output maps already translated to + chunk-local coordinates). + out_indices + For vectorized indexing: the output scatter indices for this chunk. + None for orthogonal/basic indexing. + + Returns + ------- + tuple + ``(chunk_selection, out_selection, drop_axes)`` + """ + chunk_sel: list[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = [] + drop_axes: list[int] = [] + + for m in sub_transform.output: + if isinstance(m, ConstantMap): + chunk_sel.append(m.offset) + elif isinstance(m, DimensionMap): + dim_lo = sub_transform.domain.inclusive_min[m.input_dimension] + dim_hi = sub_transform.domain.exclusive_max[m.input_dimension] + start = m.offset + m.stride * dim_lo + stop = m.offset + m.stride * dim_hi + if m.stride < 0: + start, stop = stop + 1, start + 1 + chunk_sel.append(slice(start, stop, m.stride)) + elif isinstance(m, ArrayMap): + if m.offset == 0 and m.stride == 1: + chunk_sel.append(m.index_array) + else: + storage_coords = m.offset + m.stride * m.index_array + chunk_sel.append(storage_coords.astype(np.intp)) + + # Build out_sel: one entry per non-dropped output dim. 
+ out_sel: list[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = [] + + # Vectorized: multiple correlated ArrayMaps share one scatter index + is_vectorized = ( + out_indices is not None + and sum(1 for m in sub_transform.output if isinstance(m, ArrayMap)) >= 2 + ) + + if is_vectorized: + assert out_indices is not None + out_sel.append(out_indices) + else: + for m in sub_transform.output: + if isinstance(m, ConstantMap): + continue + if isinstance(m, DimensionMap): + lo = sub_transform.domain.inclusive_min[m.input_dimension] + hi = sub_transform.domain.exclusive_max[m.input_dimension] + out_sel.append(slice(lo, hi)) + elif isinstance(m, ArrayMap): + if out_indices is not None: + # Orthogonal ArrayMap: out_indices has the surviving positions + out_sel.append(out_indices) + else: + out_sel.append(slice(0, len(m.index_array))) + + return tuple(chunk_sel), tuple(out_sel), tuple(drop_axes) diff --git a/tests/test_transforms/test_chunk_resolution.py b/tests/test_transforms/test_chunk_resolution.py new file mode 100644 index 0000000000..ba27964028 --- /dev/null +++ b/tests/test_transforms/test_chunk_resolution.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import numpy as np + +from zarr.core._transforms.chunk_resolution import ( + iter_chunk_transforms, + sub_transform_to_selections, +) +from zarr.core._transforms.domain import IndexDomain +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap +from zarr.core._transforms.transform import IndexTransform +from zarr.core.chunk_grids import ChunkGrid, FixedDimension + + +class TestChunkResolutionIdentity: + def test_single_chunk(self) -> None: + """Array fits in one chunk.""" + t = IndexTransform.from_shape((10,)) + grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=10),)) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 1 + coords, sub_t, _ = results[0] + assert coords == (0,) + assert sub_t.domain.shape == (10,) + + def 
test_multiple_chunks_1d(self) -> None: + """1D array spanning 3 chunks.""" + t = IndexTransform.from_shape((30,)) + grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=30),)) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 3 + coords_list = [r[0] for r in results] + assert (0,) in coords_list + assert (1,) in coords_list + assert (2,) in coords_list + + def test_multiple_chunks_2d(self) -> None: + """2D array spanning 2x3 chunks.""" + t = IndexTransform.from_shape((20, 30)) + grid = ChunkGrid( + dimensions=( + FixedDimension(size=10, extent=20), + FixedDimension(size=10, extent=30), + ) + ) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 6 + coords_list = [r[0] for r in results] + assert (0, 0) in coords_list + assert (1, 2) in coords_list + + +class TestChunkResolutionSliced: + def test_slice_within_chunk(self) -> None: + """Slice that falls within a single chunk.""" + t = IndexTransform.from_shape((100,))[5:8] + grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=100),)) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 1 + coords, sub_t, _ = results[0] + assert coords == (0,) + assert isinstance(sub_t.output[0], DimensionMap) + assert sub_t.output[0].offset == 5 + + def test_slice_across_chunks(self) -> None: + """Slice that spans two chunks.""" + t = IndexTransform.from_shape((100,))[8:15] + grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=100),)) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 2 + coords_list = [r[0] for r in results] + assert (0,) in coords_list + assert (1,) in coords_list + + +class TestChunkResolutionConstant: + def test_integer_index(self) -> None: + """Integer index produces constant map — single chunk per constant dim.""" + t = IndexTransform.from_shape((100, 100))[25, :] + grid = ChunkGrid( + dimensions=( + FixedDimension(size=10, extent=100), + FixedDimension(size=10, extent=100), + ) + ) + results = 
list(iter_chunk_transforms(t, grid)) + assert len(results) == 10 + for coords, _, _ in results: + assert coords[0] == 2 + + +class TestChunkResolutionArray: + def test_array_index(self) -> None: + """Array index map — chunks determined by array values.""" + idx = np.array([5, 15, 25], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=idx),), + ) + grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=30),)) + results = list(iter_chunk_transforms(t, grid)) + coords_list = [r[0] for r in results] + assert (0,) in coords_list + assert (1,) in coords_list + assert (2,) in coords_list + + +class TestSubTransformToSelections: + def test_constant_map(self) -> None: + """ConstantMap produces an int selection; drop_axes stays empty (the int drops the axis).""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), + ) + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) + assert chunk_sel == (5,) + assert out_sel == () + assert drop_axes == () + + def test_dimension_map_stride_1(self) -> None: + """DimensionMap with stride=1 produces contiguous slice.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=3, stride=1),), + ) + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) + assert chunk_sel == (slice(3, 13, 1),) + assert out_sel == (slice(0, 10),) + assert drop_axes == () + + def test_dimension_map_strided(self) -> None: + """DimensionMap with stride>1 produces strided slice.""" + t = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(DimensionMap(input_dimension=0, offset=2, stride=3),), + ) + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) + assert chunk_sel == (slice(2, 17, 3),) + assert out_sel == (slice(0, 5),) + assert drop_axes == () + + def test_array_map(self) -> None: + """ArrayMap produces integer array selection.""" + arr = np.array([1, 5, 9], dtype=np.intp) + t = IndexTransform( +
domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr, offset=0, stride=1),), + ) + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) + assert isinstance(chunk_sel[0], np.ndarray) + np.testing.assert_array_equal(chunk_sel[0], arr) + # Without out_indices, out_sel falls back to a slice over the index array length + assert out_sel == (slice(0, 3),) + assert drop_axes == () + + def test_array_map_with_offset_stride(self) -> None: + """ArrayMap with offset and stride computes storage coords.""" + arr = np.array([0, 1, 2], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr, offset=10, stride=5),), + ) + chunk_sel, _out_sel, drop_axes = sub_transform_to_selections(t) + assert isinstance(chunk_sel[0], np.ndarray) + np.testing.assert_array_equal(chunk_sel[0], np.array([10, 15, 20])) + assert drop_axes == () + + def test_mixed_maps_2d(self) -> None: + """Mix of ConstantMap and DimensionMap.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=( + ConstantMap(offset=5), + DimensionMap(input_dimension=0, offset=0, stride=1), + ), + ) + chunk_sel, _out_sel, drop_axes = sub_transform_to_selections(t) + assert chunk_sel[0] == 5 + assert chunk_sel[1] == slice(0, 10, 1) + # drop_axes is empty — integer in chunk_sel naturally drops the dim via numpy + assert drop_axes == () From 5370f917891c8c93516ea2b6eb75aeb6745beb14 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 6 May 2026 18:41:07 -0400 Subject: [PATCH 07/24] feat(_transforms): expose package exports Re-export IndexDomain, IndexTransform, the OutputIndexMap variants, and compose from the package root. JSON helpers are intentionally excluded; chunk_resolution exports are intentionally not re-exported (no callers in this PR).
--- src/zarr/core/_transforms/__init__.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/zarr/core/_transforms/__init__.py b/src/zarr/core/_transforms/__init__.py index e29a0ccf9b..84db415f79 100644 --- a/src/zarr/core/_transforms/__init__.py +++ b/src/zarr/core/_transforms/__init__.py @@ -17,3 +17,23 @@ output dimension can depend on the input (see ``output_map.py``) - ``compose`` -- chain two transforms into one """ + +from zarr.core._transforms.composition import compose +from zarr.core._transforms.domain import IndexDomain +from zarr.core._transforms.output_map import ( + ArrayMap, + ConstantMap, + DimensionMap, + OutputIndexMap, +) +from zarr.core._transforms.transform import IndexTransform + +__all__ = [ + "ArrayMap", + "ConstantMap", + "DimensionMap", + "IndexDomain", + "IndexTransform", + "OutputIndexMap", + "compose", +] From fca29269df969031c48b7ee18e572e3b93985b10 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 6 May 2026 20:08:32 -0400 Subject: [PATCH 08/24] docs(_transforms): use markdown single-backticks; fix stale lazy accessor reference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The docstrings inherited RST-style double-backticks from the reference branch. Convert to markdown single-backticks throughout (zarr-python's docs are markdown-rendered, so the doubles render as literal `` text). Also drop the stale "Every Array holds a transform" / "Array.z[...]" comment from transform.py — those statements describe an earlier design where the transform lived on Array; the actual lazy view design parks the transform on a separate wrapper type. 
--- src/zarr/core/_transforms/__init__.py | 10 +++--- src/zarr/core/_transforms/chunk_resolution.py | 26 +++++++------- src/zarr/core/_transforms/composition.py | 6 ++-- src/zarr/core/_transforms/domain.py | 6 ++-- src/zarr/core/_transforms/output_map.py | 36 +++++++++---------- src/zarr/core/_transforms/transform.py | 35 +++++++++--------- 6 files changed, 59 insertions(+), 60 deletions(-) diff --git a/src/zarr/core/_transforms/__init__.py b/src/zarr/core/_transforms/__init__.py index 84db415f79..6ac4d85343 100644 --- a/src/zarr/core/_transforms/__init__.py +++ b/src/zarr/core/_transforms/__init__.py @@ -11,11 +11,11 @@ Key types: -- ``IndexDomain`` -- a rectangular region of integer coordinates -- ``IndexTransform`` -- maps input coordinates to storage coordinates -- ``ConstantMap``, ``DimensionMap``, ``ArrayMap`` -- the three ways a single - output dimension can depend on the input (see ``output_map.py``) -- ``compose`` -- chain two transforms into one +- `IndexDomain` -- a rectangular region of integer coordinates +- `IndexTransform` -- maps input coordinates to storage coordinates +- `ConstantMap`, `DimensionMap`, `ArrayMap` -- the three ways a single + output dimension can depend on the input (see `output_map.py`) +- `compose` -- chain two transforms into one """ from zarr.core._transforms.composition import compose diff --git a/src/zarr/core/_transforms/chunk_resolution.py b/src/zarr/core/_transforms/chunk_resolution.py index 56f2c7ac93..179e2ed2c8 100644 --- a/src/zarr/core/_transforms/chunk_resolution.py +++ b/src/zarr/core/_transforms/chunk_resolution.py @@ -1,7 +1,7 @@ """Chunk resolution — mapping transforms to chunk-level I/O. 
-Given an ``IndexTransform`` (which coordinates a user wants to access) and a -``ChunkGrid`` (how storage is divided into chunks), chunk resolution answers: +Given an `IndexTransform` (which coordinates a user wants to access) and a +`ChunkGrid` (how storage is divided into chunks), chunk resolution answers: For each chunk, which storage coordinates does this transform touch, and where do those values land in the output buffer? @@ -12,17 +12,17 @@ be touched by the transform's output coordinate ranges. 2. **Intersect** — for each candidate chunk, call - ``transform.intersect(chunk_domain)`` to restrict the transform to + `transform.intersect(chunk_domain)` to restrict the transform to coordinates within that chunk. If the intersection is empty, skip it. 3. **Translate** — shift the restricted transform to chunk-local coordinates - via ``transform.translate(-chunk_origin)``. + via `transform.translate(-chunk_origin)`. -4. **Yield** — produce ``(chunk_coords, local_transform, surviving_indices)`` +4. **Yield** — produce `(chunk_coords, local_transform, surviving_indices)` triples that the codec pipeline consumes. -``sub_transform_to_selections`` bridges from the transform representation -back to the raw ``(chunk_selection, out_selection, drop_axes)`` tuples that +`sub_transform_to_selections` bridges from the transform representation +back to the raw `(chunk_selection, out_selection, drop_axes)` tuples that the current codec pipeline expects. This bridge will go away when the codec pipeline accepts transforms natively. """ @@ -55,12 +55,12 @@ def iter_chunk_transforms( ) -> Iterator[ChunkTransformResult]: """Resolve a composed IndexTransform against a ChunkGrid. - Yields ``(chunk_coords, sub_transform, out_indices)`` triples: + Yields `(chunk_coords, sub_transform, out_indices)` triples: - - ``chunk_coords``: which chunk to access. - - ``sub_transform``: maps output buffer coords to chunk-local coords. 
- - ``out_indices``: for vectorized/array indexing, the output scatter - indices (integer array). ``None`` for basic/slice indexing. + - `chunk_coords`: which chunk to access. + - `sub_transform`: maps output buffer coords to chunk-local coords. + - `out_indices`: for vectorized/array indexing, the output scatter + indices (integer array). `None` for basic/slice indexing. """ dim_grids = chunk_grid._dimensions @@ -154,7 +154,7 @@ def sub_transform_to_selections( Returns ------- tuple - ``(chunk_selection, out_selection, drop_axes)`` + `(chunk_selection, out_selection, drop_axes)` """ chunk_sel: list[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = [] drop_axes: list[int] = [] diff --git a/src/zarr/core/_transforms/composition.py b/src/zarr/core/_transforms/composition.py index c762897d5b..40bba89e95 100644 --- a/src/zarr/core/_transforms/composition.py +++ b/src/zarr/core/_transforms/composition.py @@ -9,11 +9,11 @@ def compose(outer: IndexTransform, inner: IndexTransform) -> IndexTransform: """Compose two IndexTransforms. - ``outer`` maps user coords (rank m) to intermediate coords (rank n). - ``inner`` maps intermediate coords (rank n) to storage coords (rank p). + `outer` maps user coords (rank m) to intermediate coords (rank n). + `inner` maps intermediate coords (rank n) to storage coords (rank p). The result maps user coords (rank m) to storage coords (rank p). - Precondition: ``outer.output_rank == inner.domain.ndim``. + Precondition: `outer.output_rank == inner.domain.ndim`. """ if outer.output_rank != inner.domain.ndim: raise ValueError( diff --git a/src/zarr/core/_transforms/domain.py b/src/zarr/core/_transforms/domain.py index 90bcc08ace..e43ec30836 100644 --- a/src/zarr/core/_transforms/domain.py +++ b/src/zarr/core/_transforms/domain.py @@ -1,13 +1,13 @@ """Index domains — rectangular regions in N-dimensional integer space. 
-An ``IndexDomain`` represents the set of valid coordinates for an array or +An `IndexDomain` represents the set of valid coordinates for an array or array view. It is the cartesian product of per-dimension integer ranges:: IndexDomain(inclusive_min=(2, 5), exclusive_max=(10, 20)) # represents {(i, j) : 2 <= i < 10, 5 <= j < 20} Unlike NumPy, domains can have **non-zero origins**. After slicing -``arr[5:10]``, the result has origin 5 and shape 5 — coordinates 5 through +`arr[5:10]`, the result has origin 5 and shape 5 — coordinates 5 through 9 are valid. This follows the TensorStore convention. """ @@ -22,7 +22,7 @@ class IndexDomain: """A rectangular region in N-dimensional index space. The valid coordinates are the integers in - ``[inclusive_min[d], exclusive_max[d])`` for each dimension ``d``. + `[inclusive_min[d], exclusive_max[d])` for each dimension `d`. """ inclusive_min: tuple[int, ...] diff --git a/src/zarr/core/_transforms/output_map.py b/src/zarr/core/_transforms/output_map.py index 5e17a0ae82..f1b32aa95e 100644 --- a/src/zarr/core/_transforms/output_map.py +++ b/src/zarr/core/_transforms/output_map.py @@ -4,18 +4,18 @@ an array access will touch. Conceptually it is a **set of integers**. 
Three representations cover the cases that arise in practice: -- ``ConstantMap(offset=5)`` — a singleton set: ``{5}`` -- ``DimensionMap(input_dimension=0, offset=3, stride=2)`` over input ``[0, 5)`` - — an arithmetic progression: ``{3, 5, 7, 9, 11}`` -- ``ArrayMap(index_array=[1, 5, 9])`` — an explicit enumeration: ``{1, 5, 9}`` +- `ConstantMap(offset=5)` — a singleton set: `{5}` +- `DimensionMap(input_dimension=0, offset=3, stride=2)` over input `[0, 5)` + — an arithmetic progression: `{3, 5, 7, 9, 11}` +- `ArrayMap(index_array=[1, 5, 9])` — an explicit enumeration: `{1, 5, 9}` Every output map supports two set-theoretic operations (defined on -``IndexTransform``, which provides the input domain context these maps lack): +`IndexTransform`, which provides the input domain context these maps lack): - **intersect** — restrict to coordinates within a range (e.g., a chunk). - ``{3, 5, 7, 9, 11} ∩ [4, 8) = {5, 7}`` + `{3, 5, 7, 9, 11} ∩ [4, 8) = {5, 7}` - **translate** — shift every coordinate by a constant (e.g., make chunk-local). - ``{5, 7} - 4 = {1, 3}`` + `{5, 7} - 4 = {1, 3}` These two operations are the foundation of chunk resolution: for each chunk, intersect the map with the chunk's range, then translate to chunk-local @@ -23,13 +23,13 @@ The three types exist because they trade off generality for efficiency: -- ``ConstantMap``: O(1) storage, O(1) intersection -- ``DimensionMap``: O(1) storage, O(1) intersection (analytical) -- ``ArrayMap``: O(n) storage, O(n) intersection (must scan the array) +- `ConstantMap`: O(1) storage, O(1) intersection +- `DimensionMap`: O(1) storage, O(1) intersection (analytical) +- `ArrayMap`: O(n) storage, O(n) intersection (must scan the array) -Collapsing everything to ``ArrayMap`` would be correct but wasteful — a +Collapsing everything to `ArrayMap` would be correct but wasteful — a billion-element slice would materialize a billion coordinates just to group -them by chunk, when ``DimensionMap`` does it with three integers. 
+them by chunk, when `DimensionMap` does it with three integers. """ from __future__ import annotations @@ -46,7 +46,7 @@ class ConstantMap: """A singleton set: one storage coordinate. - Represents ``{offset}``. Arises from integer indexing (e.g., ``arr[5]`` + Represents `{offset}`. Arises from integer indexing (e.g., `arr[5]` fixes one dimension to coordinate 5). """ @@ -57,9 +57,9 @@ class ConstantMap: class DimensionMap: """An arithmetic progression of storage coordinates. - Represents ``{offset + stride * i : i in input_range}``, where the input - range comes from the enclosing ``IndexTransform``'s domain. Arises from - slice indexing (e.g., ``arr[2:10:3]`` gives offset=2, stride=3). + Represents `{offset + stride * i : i in input_range}`, where the input + range comes from the enclosing `IndexTransform`'s domain. Arises from + slice indexing (e.g., `arr[2:10:3]` gives offset=2, stride=3). """ input_dimension: int @@ -71,8 +71,8 @@ class DimensionMap: class ArrayMap: """An explicit enumeration of storage coordinates. - Represents ``{offset + stride * index_array[i] : i in input_range}``. - Arises from fancy indexing (e.g., ``arr[[1, 5, 9]]`` or boolean masks). + Represents `{offset + stride * index_array[i] : i in input_range}`. + Arises from fancy indexing (e.g., `arr[[1, 5, 9]]` or boolean masks). """ index_array: npt.NDArray[np.intp] diff --git a/src/zarr/core/_transforms/transform.py b/src/zarr/core/_transforms/transform.py index d8dbf11b52..3b160cb18d 100644 --- a/src/zarr/core/_transforms/transform.py +++ b/src/zarr/core/_transforms/transform.py @@ -1,13 +1,13 @@ """Index transforms — composable, lazy coordinate mappings. -An ``IndexTransform`` pairs an **input domain** (the coordinates a user sees) +An `IndexTransform` pairs an **input domain** (the coordinates a user sees) with a tuple of **output maps** (the storage coordinates those inputs map to). -One output map per storage dimension. 
See ``output_map.py`` for the three +One output map per storage dimension. See `output_map.py` for the three output map types. Key operations: -- **Indexing** (``transform[2:8]``, ``.oindex[idx]``, ``.vindex[idx]``) — +- **Indexing** (`transform[2:8]`, `.oindex[idx]`, `.vindex[idx]`) — produces a new transform with a narrower input domain and adjusted output maps. No I/O occurs. This is how lazy slicing works. @@ -18,12 +18,11 @@ - **translate(shift)** — shift all output coordinates. This makes coordinates chunk-local: "express my coordinates relative to the chunk origin." -- **compose(outer, inner)** — chain two transforms. See ``composition.py``. +- **compose(outer, inner)** — chain two transforms. See `composition.py`. The transform is the atomic unit that connects user-facing indexing to -chunk-level I/O. Every ``Array`` holds a transform (identity by default). -``Array.z[...]`` composes a new transform lazily. Reading resolves the -transform against the chunk grid via intersect + translate. +chunk-level I/O. Indexing into a lazy view composes a new transform; reading +resolves the transform against the chunk grid via intersect + translate. """ from __future__ import annotations @@ -42,15 +41,15 @@ class IndexTransform: """A composable mapping from input coordinates to storage coordinates. - An ``IndexTransform`` has: + An `IndexTransform` has: - - ``domain``: an ``IndexDomain`` describing the valid input coordinates + - `domain`: an `IndexDomain` describing the valid input coordinates (the user-facing shape, possibly with non-zero origin). - - ``output``: a tuple of output maps (one per storage dimension), each + - `output`: a tuple of output maps (one per storage dimension), each describing which storage coordinates the inputs touch. For a freshly opened array, the transform is the identity: input - coordinate ``i`` maps to storage coordinate ``i``. Indexing operations + coordinate `i` maps to storage coordinate `i`. 
Indexing operations compose new transforms without I/O. """ @@ -90,10 +89,10 @@ def from_shape(cls, shape: tuple[int, ...]) -> IndexTransform: @property def selection_repr(self) -> str: - """Compact domain string, e.g. ``'{ [2, 8), [0, 10) }'``. + """Compact domain string, e.g. `'{ [2, 8), [0, 10) }'`. Follows TensorStore's IndexDomain notation: each dimension shown - as ``[inclusive_min, exclusive_max)`` with stride annotation if not 1. + as `[inclusive_min, exclusive_max)` with stride annotation if not 1. Constant (integer-indexed) dimensions show as a single value. Array-indexed dimensions show the set of selected coordinates. """ @@ -138,16 +137,16 @@ def intersect( ) -> tuple[IndexTransform, np.ndarray[Any, np.dtype[np.intp]] | None] | None: """Restrict this transform to storage coordinates within output_domain. - Returns ``(restricted_transform, surviving_indices)`` or None if empty. + Returns `(restricted_transform, surviving_indices)` or None if empty. - ``surviving_indices`` is an integer array of which input positions + `surviving_indices` is an integer array of which input positions survived the intersection (for ArrayMap dimensions), or None if all positions survived (ConstantMap/DimensionMap only). 
""" return _intersect(self, output_domain) def translate(self, shift: tuple[int, ...]) -> IndexTransform: - """Shift all output coordinates by ``shift``.""" + """Shift all output coordinates by `shift`.""" if len(shift) != self.output_rank: raise ValueError(f"shift must have length {self.output_rank}, got {len(shift)}") new_output: list[OutputIndexMap] = [] @@ -545,7 +544,7 @@ def _apply_basic_indexing(transform: IndexTransform, selection: Any) -> IndexTra class _OIndexHelper: - """Helper that provides orthogonal (outer) indexing via ``transform.oindex[...]``.""" + """Helper that provides orthogonal (outer) indexing via `transform.oindex[...]`.""" def __init__(self, transform: IndexTransform) -> None: self._transform = transform @@ -671,7 +670,7 @@ def _apply_oindex(transform: IndexTransform, selection: Any) -> IndexTransform: class _VIndexHelper: - """Helper that provides vectorized (fancy) indexing via ``transform.vindex[...]``.""" + """Helper that provides vectorized (fancy) indexing via `transform.vindex[...]`.""" def __init__(self, transform: IndexTransform) -> None: self._transform = transform From cf56e5b10c3413d9107ebc299ce3567b0ec3e0b2 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 7 May 2026 09:28:25 -0400 Subject: [PATCH 09/24] test(_transforms): rewrite output_map tests in parametrized style Replace class-based test grouping with top-level parametrized functions. Each test covers either all success branches or all error branches of a single concern. Adds tests/test_transforms/conftest.py with project-style `Expect[TIn, TOut]` and `ExpectErr[TIn]` test-case helpers. The new mutation_errors test actually exercises FrozenInstanceError (the previous test_frozen variants only checked isinstance, which the @dataclass decorator guarantees regardless of frozen=True). 
--- tests/test_transforms/conftest.py | 22 ++++ tests/test_transforms/test_output_map.py | 134 +++++++++++++++-------- 2 files changed, 109 insertions(+), 47 deletions(-) create mode 100644 tests/test_transforms/conftest.py diff --git a/tests/test_transforms/conftest.py b/tests/test_transforms/conftest.py new file mode 100644 index 0000000000..12c9896c1c --- /dev/null +++ b/tests/test_transforms/conftest.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Expect[TIn, TOut]: + """Model an input and an expected output value for a test case.""" + + input: TIn + expected: TOut + id: str + + +@dataclass(frozen=True) +class ExpectErr[TIn]: + """Model an input and an expected error for a test case.""" + + input: TIn + msg: str + exception_cls: type[Exception] + id: str diff --git a/tests/test_transforms/test_output_map.py b/tests/test_transforms/test_output_map.py index 358ea6ed6c..5e1695172f 100644 --- a/tests/test_transforms/test_output_map.py +++ b/tests/test_transforms/test_output_map.py @@ -1,56 +1,96 @@ from __future__ import annotations +from dataclasses import FrozenInstanceError, asdict +from typing import Any + import numpy as np +import pytest +from tests.test_transforms.conftest import Expect, ExpectErr from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap -class TestConstantMap: - def test_construction(self) -> None: - m = ConstantMap(offset=42) - assert m.offset == 42 - - def test_default_offset(self) -> None: - m = ConstantMap() - assert m.offset == 0 - - def test_frozen(self) -> None: - m = ConstantMap(offset=5) - assert isinstance(m, ConstantMap) - - -class TestDimensionMap: - def test_construction(self) -> None: - m = DimensionMap(input_dimension=3, offset=5, stride=2) - assert m.input_dimension == 3 - assert m.offset == 5 - assert m.stride == 2 - - def test_defaults(self) -> None: - m = DimensionMap(input_dimension=0) - assert m.offset == 0 - assert 
m.stride == 1 - - def test_frozen(self) -> None: - m = DimensionMap(input_dimension=0) - assert isinstance(m, DimensionMap) - - -class TestArrayMap: - def test_construction(self) -> None: - arr = np.array([1, 3, 5], dtype=np.intp) - m = ArrayMap(index_array=arr, offset=10, stride=2) - assert m.offset == 10 - assert m.stride == 2 - np.testing.assert_array_equal(m.index_array, arr) +@pytest.mark.parametrize( + "case", + [ + Expect( + input=ConstantMap(offset=42), + expected={"offset": 42}, + id="ConstantMap-explicit-offset", + ), + Expect( + input=ConstantMap(), + expected={"offset": 0}, + id="ConstantMap-default-offset", + ), + Expect( + input=DimensionMap(input_dimension=3, offset=5, stride=2), + expected={"input_dimension": 3, "offset": 5, "stride": 2}, + id="DimensionMap-all-fields", + ), + Expect( + input=DimensionMap(input_dimension=0), + expected={"input_dimension": 0, "offset": 0, "stride": 1}, + id="DimensionMap-defaults", + ), + Expect( + input=ArrayMap(index_array=np.array([1, 3, 5], dtype=np.intp), offset=10, stride=2), + expected={ + "index_array": np.array([1, 3, 5], dtype=np.intp), + "offset": 10, + "stride": 2, + }, + id="ArrayMap-all-fields", + ), + Expect( + input=ArrayMap(index_array=np.array([0, 1], dtype=np.intp)), + expected={ + "index_array": np.array([0, 1], dtype=np.intp), + "offset": 0, + "stride": 1, + }, + id="ArrayMap-defaults", + ), + ], + ids=lambda c: c.id, +) +def test_construction_success(case: Expect[Any, dict[str, Any]]) -> None: + """Constructing each map type with explicit and default values yields the expected fields.""" + actual = asdict(case.input) + assert set(actual) == set(case.expected) + for field, expected_value in case.expected.items(): + if isinstance(expected_value, np.ndarray): + np.testing.assert_array_equal(actual[field], expected_value) + else: + assert actual[field] == expected_value - def test_defaults(self) -> None: - arr = np.array([0, 1], dtype=np.intp) - m = ArrayMap(index_array=arr) - assert m.offset == 0 - 
assert m.stride == 1 - def test_frozen(self) -> None: - arr = np.array([0], dtype=np.intp) - m = ArrayMap(index_array=arr) - assert isinstance(m, ArrayMap) +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(ConstantMap(offset=5), "offset", 99), + msg="cannot assign to field 'offset'", + exception_cls=FrozenInstanceError, + id="ConstantMap-frozen", + ), + ExpectErr( + input=(DimensionMap(input_dimension=0), "stride", 7), + msg="cannot assign to field 'stride'", + exception_cls=FrozenInstanceError, + id="DimensionMap-frozen", + ), + ExpectErr( + input=(ArrayMap(index_array=np.array([0], dtype=np.intp)), "offset", 1), + msg="cannot assign to field 'offset'", + exception_cls=FrozenInstanceError, + id="ArrayMap-frozen", + ), + ], + ids=lambda c: c.id, +) +def test_mutation_errors(case: ExpectErr[tuple[Any, str, Any]]) -> None: + """Attempting to mutate a frozen output map raises FrozenInstanceError.""" + obj, field, new_value = case.input + with pytest.raises(case.exception_cls, match=case.msg): + setattr(obj, field, new_value) From 6e54443508622a7f1ff77f6f975bf01d6c0f8136 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 7 May 2026 09:56:44 -0400 Subject: [PATCH 10/24] test(_transforms): rewrite domain tests in parametrized style MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit One success test and (where applicable) one error test per public function on IndexDomain. Adds direct tests for the non-trivial private helper _normalize_selection, including the previously-untested double-ellipsis branch. Coverage: 34 tests → 53 tests. 
--- tests/test_transforms/test_domain.py | 665 +++++++++++++++++++-------- 1 file changed, 467 insertions(+), 198 deletions(-) diff --git a/tests/test_transforms/test_domain.py b/tests/test_transforms/test_domain.py index 58f3808d95..0a168eadd5 100644 --- a/tests/test_transforms/test_domain.py +++ b/tests/test_transforms/test_domain.py @@ -1,202 +1,471 @@ from __future__ import annotations +from typing import Any + import pytest -from zarr.core._transforms.domain import IndexDomain - - -class TestIndexDomainConstruction: - def test_from_shape(self) -> None: - d = IndexDomain.from_shape((10, 20)) - assert d.inclusive_min == (0, 0) - assert d.exclusive_max == (10, 20) - assert d.ndim == 2 - assert d.origin == (0, 0) - assert d.shape == (10, 20) - - def test_from_shape_0d(self) -> None: - d = IndexDomain.from_shape(()) - assert d.ndim == 0 - assert d.shape == () - - def test_non_zero_origin(self) -> None: - d = IndexDomain(inclusive_min=(5, 10), exclusive_max=(15, 30)) - assert d.origin == (5, 10) - assert d.shape == (10, 20) - assert d.ndim == 2 - - def test_validation_mismatched_lengths(self) -> None: - with pytest.raises(ValueError, match="same length"): - IndexDomain(inclusive_min=(0,), exclusive_max=(10, 20)) - - def test_validation_min_greater_than_max(self) -> None: - with pytest.raises(ValueError, match="inclusive_min must be <="): - IndexDomain(inclusive_min=(10,), exclusive_max=(5,)) - - def test_empty_domain(self) -> None: - d = IndexDomain(inclusive_min=(5,), exclusive_max=(5,)) - assert d.shape == (0,) - - def test_labels(self) -> None: - d = IndexDomain(inclusive_min=(0, 0), exclusive_max=(10, 20), labels=("x", "y")) - assert d.labels == ("x", "y") - - def test_labels_none(self) -> None: - d = IndexDomain.from_shape((10,)) - assert d.labels is None - - -class TestIndexDomainContains: - def test_contains_inside(self) -> None: - d = IndexDomain.from_shape((10, 20)) - assert d.contains((0, 0)) is True - assert d.contains((9, 19)) is True - assert 
d.contains((5, 10)) is True - - def test_contains_outside(self) -> None: - d = IndexDomain.from_shape((10, 20)) - assert d.contains((10, 0)) is False - assert d.contains((-1, 0)) is False - assert d.contains((0, 20)) is False - - def test_contains_non_zero_origin(self) -> None: - d = IndexDomain(inclusive_min=(5,), exclusive_max=(10,)) - assert d.contains((5,)) is True - assert d.contains((9,)) is True - assert d.contains((4,)) is False - assert d.contains((10,)) is False - - def test_contains_wrong_ndim(self) -> None: - d = IndexDomain.from_shape((10, 20)) - assert d.contains((5,)) is False - - def test_contains_domain_inside(self) -> None: - outer = IndexDomain.from_shape((10, 20)) - inner = IndexDomain(inclusive_min=(2, 3), exclusive_max=(8, 15)) - assert outer.contains_domain(inner) is True - - def test_contains_domain_outside(self) -> None: - outer = IndexDomain.from_shape((10, 20)) - inner = IndexDomain(inclusive_min=(2, 3), exclusive_max=(11, 15)) - assert outer.contains_domain(inner) is False - - def test_contains_domain_wrong_ndim(self) -> None: - outer = IndexDomain.from_shape((10, 20)) - inner = IndexDomain.from_shape((5,)) - assert outer.contains_domain(inner) is False - - -class TestIndexDomainIntersect: - def test_overlapping(self) -> None: - a = IndexDomain(inclusive_min=(0, 0), exclusive_max=(10, 10)) - b = IndexDomain(inclusive_min=(5, 5), exclusive_max=(15, 15)) - result = a.intersect(b) - assert result is not None - assert result.inclusive_min == (5, 5) - assert result.exclusive_max == (10, 10) - - def test_disjoint(self) -> None: - a = IndexDomain(inclusive_min=(0,), exclusive_max=(5,)) - b = IndexDomain(inclusive_min=(10,), exclusive_max=(15,)) - assert a.intersect(b) is None - - def test_touching_boundary(self) -> None: - a = IndexDomain(inclusive_min=(0,), exclusive_max=(5,)) - b = IndexDomain(inclusive_min=(5,), exclusive_max=(10,)) - assert a.intersect(b) is None - - def test_contained(self) -> None: - a = IndexDomain.from_shape((20,)) - b 
= IndexDomain(inclusive_min=(5,), exclusive_max=(10,)) - result = a.intersect(b) - assert result is not None - assert result.inclusive_min == (5,) - assert result.exclusive_max == (10,) - - def test_wrong_ndim(self) -> None: - a = IndexDomain.from_shape((10,)) - b = IndexDomain.from_shape((10, 20)) - with pytest.raises(ValueError, match="different ranks"): - a.intersect(b) - - -class TestIndexDomainTranslate: - def test_translate_positive(self) -> None: - d = IndexDomain.from_shape((10, 20)) - result = d.translate((5, 10)) - assert result.inclusive_min == (5, 10) - assert result.exclusive_max == (15, 30) - - def test_translate_negative(self) -> None: - d = IndexDomain(inclusive_min=(10, 20), exclusive_max=(30, 40)) - result = d.translate((-10, -20)) - assert result.inclusive_min == (0, 0) - assert result.exclusive_max == (20, 20) - - def test_translate_wrong_length(self) -> None: - d = IndexDomain.from_shape((10,)) - with pytest.raises(ValueError, match="same length"): - d.translate((1, 2)) - - -class TestIndexDomainNarrow: - def test_narrow_slice(self) -> None: - d = IndexDomain.from_shape((10, 20)) - result = d.narrow((slice(2, 8), slice(5, 15))) - assert result.inclusive_min == (2, 5) - assert result.exclusive_max == (8, 15) - - def test_narrow_int(self) -> None: - d = IndexDomain.from_shape((10, 20)) - result = d.narrow((3, slice(None))) - assert result.inclusive_min == (3, 0) - assert result.exclusive_max == (4, 20) - - def test_narrow_ellipsis(self) -> None: - d = IndexDomain.from_shape((10, 20, 30)) - result = d.narrow((slice(1, 5), ...)) - assert result.inclusive_min == (1, 0, 0) - assert result.exclusive_max == (5, 20, 30) - - def test_narrow_slice_none(self) -> None: - d = IndexDomain.from_shape((10,)) - result = d.narrow((slice(None),)) - assert result == d - - def test_narrow_non_zero_origin(self) -> None: - d = IndexDomain(inclusive_min=(10,), exclusive_max=(20,)) - result = d.narrow((slice(12, 18),)) - assert result.inclusive_min == (12,) - assert 
result.exclusive_max == (18,) - - def test_narrow_int_out_of_bounds(self) -> None: - d = IndexDomain.from_shape((10,)) - with pytest.raises(IndexError, match="out of bounds"): - d.narrow((10,)) - - def test_narrow_int_below_origin(self) -> None: - d = IndexDomain(inclusive_min=(5,), exclusive_max=(10,)) - with pytest.raises(IndexError, match="out of bounds"): - d.narrow((4,)) - - def test_narrow_clamps_to_domain(self) -> None: - d = IndexDomain.from_shape((10,)) - result = d.narrow((slice(-5, 100),)) - assert result.inclusive_min == (0,) - assert result.exclusive_max == (10,) - - def test_narrow_bare_slice(self) -> None: - d = IndexDomain.from_shape((10,)) - result = d.narrow(slice(2, 8)) - assert result.inclusive_min == (2,) - assert result.exclusive_max == (8,) - - def test_narrow_too_many_indices(self) -> None: - d = IndexDomain.from_shape((10,)) - with pytest.raises(IndexError, match="too many indices"): - d.narrow((1, 2)) - - def test_narrow_step_not_one(self) -> None: - d = IndexDomain.from_shape((10,)) - with pytest.raises(IndexError, match="step=1"): - d.narrow((slice(0, 10, 2),)) +from tests.test_transforms.conftest import Expect, ExpectErr +from zarr.core._transforms.domain import IndexDomain, _normalize_selection + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input={"inclusive_min": (0, 0), "exclusive_max": (10, 20)}, + expected={"ndim": 2, "origin": (0, 0), "shape": (10, 20), "labels": None}, + id="2d-zero-origin", + ), + Expect( + input={"inclusive_min": (5, 10), "exclusive_max": (15, 30)}, + expected={"ndim": 2, "origin": (5, 10), "shape": (10, 20), "labels": None}, + id="2d-non-zero-origin", + ), + Expect( + input={"inclusive_min": (5,), "exclusive_max": (5,)}, + expected={"ndim": 1, "origin": (5,), "shape": (0,), "labels": None}, + id="1d-empty", + ), + Expect( + input={"inclusive_min": (), "exclusive_max": ()}, + expected={"ndim": 0, "origin": (), "shape": (), "labels": None}, + id="0d", + ), + Expect( + input={"inclusive_min": (0, 0), 
"exclusive_max": (10, 20), "labels": ("x", "y")}, + expected={"ndim": 2, "origin": (0, 0), "shape": (10, 20), "labels": ("x", "y")}, + id="2d-with-labels", + ), + ], + ids=lambda c: c.id, +) +def test_construction_success(case: Expect[dict[str, Any], dict[str, Any]]) -> None: + """IndexDomain construction yields the expected shape, origin, ndim, and labels.""" + d = IndexDomain(**case.input) + for prop, expected in case.expected.items(): + assert getattr(d, prop) == expected + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input={"inclusive_min": (0,), "exclusive_max": (10, 20)}, + msg="same length", + exception_cls=ValueError, + id="mismatched-min-max-lengths", + ), + ExpectErr( + input={"inclusive_min": (10,), "exclusive_max": (5,)}, + msg="inclusive_min must be <=", + exception_cls=ValueError, + id="min-greater-than-max", + ), + ExpectErr( + input={"inclusive_min": (0, 0), "exclusive_max": (10, 20), "labels": ("x",)}, + msg="labels must have the same length as dimensions", + exception_cls=ValueError, + id="labels-wrong-length", + ), + ], + ids=lambda c: c.id, +) +def test_construction_errors(case: ExpectErr[dict[str, Any]]) -> None: + """IndexDomain construction with invalid inputs raises ValueError.""" + with pytest.raises(case.exception_cls, match=case.msg): + IndexDomain(**case.input) + + +@pytest.mark.parametrize( + "case", + [ + Expect(input=(10, 20), expected=(2, (0, 0), (10, 20)), id="2d"), + Expect(input=(10,), expected=(1, (0,), (10,)), id="1d"), + Expect(input=(), expected=(0, (), ()), id="0d"), + ], + ids=lambda c: c.id, +) +def test_from_shape_success( + case: Expect[tuple[int, ...], tuple[int, tuple[int, ...], tuple[int, ...]]], +) -> None: + """IndexDomain.from_shape produces a zero-origin domain with the requested shape.""" + d = IndexDomain.from_shape(case.input) + expected_ndim, expected_origin, expected_shape = case.expected + assert d.ndim == expected_ndim + assert d.origin == expected_origin + assert d.shape == expected_shape + + 
+@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexDomain.from_shape((10, 20)), (0, 0)), + expected=True, + id="2d-corner-low", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), (9, 19)), + expected=True, + id="2d-corner-high", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), (5, 10)), + expected=True, + id="2d-interior", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), (10, 0)), + expected=False, + id="2d-outside-high", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), (-1, 0)), + expected=False, + id="2d-outside-low", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), (5,)), + expected=False, + id="wrong-ndim", + ), + Expect( + input=(IndexDomain(inclusive_min=(5,), exclusive_max=(10,)), (5,)), + expected=True, + id="non-zero-origin-low", + ), + Expect( + input=(IndexDomain(inclusive_min=(5,), exclusive_max=(10,)), (4,)), + expected=False, + id="non-zero-origin-below", + ), + ], + ids=lambda c: c.id, +) +def test_contains_success(case: Expect[tuple[IndexDomain, tuple[int, ...]], bool]) -> None: + """IndexDomain.contains returns True iff the index is within the domain.""" + domain, index = case.input + assert domain.contains(index) is case.expected + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=( + IndexDomain.from_shape((10, 20)), + IndexDomain(inclusive_min=(2, 3), exclusive_max=(8, 15)), + ), + expected=True, + id="strict-subset", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), IndexDomain.from_shape((10, 20))), + expected=True, + id="equal-domains", + ), + Expect( + input=( + IndexDomain.from_shape((10, 20)), + IndexDomain(inclusive_min=(2, 3), exclusive_max=(11, 15)), + ), + expected=False, + id="extends-past-max", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), IndexDomain.from_shape((5,))), + expected=False, + id="wrong-ndim", + ), + ], + ids=lambda c: c.id, +) +def test_contains_domain_success(case: Expect[tuple[IndexDomain, IndexDomain], bool]) -> None: + 
"""IndexDomain.contains_domain returns True iff `other` is fully contained.""" + outer, inner = case.input + assert outer.contains_domain(inner) is case.expected + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=( + IndexDomain(inclusive_min=(0, 0), exclusive_max=(10, 10)), + IndexDomain(inclusive_min=(5, 5), exclusive_max=(15, 15)), + ), + expected=IndexDomain(inclusive_min=(5, 5), exclusive_max=(10, 10)), + id="overlapping-2d", + ), + Expect( + input=( + IndexDomain.from_shape((20,)), + IndexDomain(inclusive_min=(5,), exclusive_max=(10,)), + ), + expected=IndexDomain(inclusive_min=(5,), exclusive_max=(10,)), + id="contained", + ), + Expect( + input=( + IndexDomain(inclusive_min=(0,), exclusive_max=(5,)), + IndexDomain(inclusive_min=(10,), exclusive_max=(15,)), + ), + expected=None, + id="disjoint", + ), + Expect( + input=( + IndexDomain(inclusive_min=(0,), exclusive_max=(5,)), + IndexDomain(inclusive_min=(5,), exclusive_max=(10,)), + ), + expected=None, + id="touching-boundary", + ), + ], + ids=lambda c: c.id, +) +def test_intersect_success( + case: Expect[tuple[IndexDomain, IndexDomain], IndexDomain | None], +) -> None: + """IndexDomain.intersect returns the intersection, or None when disjoint.""" + a, b = case.input + assert a.intersect(b) == case.expected + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexDomain.from_shape((10,)), IndexDomain.from_shape((10, 20))), + msg="different ranks", + exception_cls=ValueError, + id="rank-mismatch", + ), + ], + ids=lambda c: c.id, +) +def test_intersect_errors(case: ExpectErr[tuple[IndexDomain, IndexDomain]]) -> None: + """IndexDomain.intersect raises ValueError on rank mismatch.""" + a, b = case.input + with pytest.raises(case.exception_cls, match=case.msg): + a.intersect(b) + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexDomain.from_shape((10, 20)), (5, 10)), + expected=IndexDomain(inclusive_min=(5, 10), exclusive_max=(15, 30)), + id="positive-offset", + ), + 
Expect( + input=(IndexDomain(inclusive_min=(10, 20), exclusive_max=(30, 40)), (-10, -20)), + expected=IndexDomain(inclusive_min=(0, 0), exclusive_max=(20, 20)), + id="negative-offset", + ), + Expect( + input=(IndexDomain.from_shape((10,)), (0,)), + expected=IndexDomain.from_shape((10,)), + id="zero-offset", + ), + ], + ids=lambda c: c.id, +) +def test_translate_success( + case: Expect[tuple[IndexDomain, tuple[int, ...]], IndexDomain], +) -> None: + """IndexDomain.translate shifts every coordinate by the offset.""" + domain, offset = case.input + assert domain.translate(offset) == case.expected + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexDomain.from_shape((10,)), (1, 2)), + msg="same length", + exception_cls=ValueError, + id="offset-too-long", + ), + ExpectErr( + input=(IndexDomain.from_shape((10, 20)), (1,)), + msg="same length", + exception_cls=ValueError, + id="offset-too-short", + ), + ], + ids=lambda c: c.id, +) +def test_translate_errors(case: ExpectErr[tuple[IndexDomain, tuple[int, ...]]]) -> None: + """IndexDomain.translate raises when offset length differs from ndim.""" + domain, offset = case.input + with pytest.raises(case.exception_cls, match=case.msg): + domain.translate(offset) + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexDomain.from_shape((10, 20)), (slice(2, 8), slice(5, 15))), + expected=IndexDomain(inclusive_min=(2, 5), exclusive_max=(8, 15)), + id="2d-slices", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), (3, slice(None))), + expected=IndexDomain(inclusive_min=(3, 0), exclusive_max=(4, 20)), + id="int-and-slice", + ), + Expect( + input=(IndexDomain.from_shape((10, 20, 30)), (slice(1, 5), ...)), + expected=IndexDomain(inclusive_min=(1, 0, 0), exclusive_max=(5, 20, 30)), + id="ellipsis-fills-trailing", + ), + Expect( + input=(IndexDomain.from_shape((10,)), (slice(None),)), + expected=IndexDomain.from_shape((10,)), + id="slice-none-is-noop", + ), + Expect( + 
input=(IndexDomain(inclusive_min=(10,), exclusive_max=(20,)), (slice(12, 18),)), + expected=IndexDomain(inclusive_min=(12,), exclusive_max=(18,)), + id="non-zero-origin", + ), + Expect( + input=(IndexDomain.from_shape((10,)), (slice(-5, 100),)), + expected=IndexDomain(inclusive_min=(0,), exclusive_max=(10,)), + id="clamps-to-domain", + ), + Expect( + input=(IndexDomain.from_shape((10,)), slice(2, 8)), + expected=IndexDomain(inclusive_min=(2,), exclusive_max=(8,)), + id="bare-slice-is-wrapped", + ), + ], + ids=lambda c: c.id, +) +def test_narrow_success(case: Expect[tuple[IndexDomain, Any], IndexDomain]) -> None: + """IndexDomain.narrow applies basic indexing to produce a sub-domain.""" + domain, selection = case.input + assert domain.narrow(selection) == case.expected + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexDomain.from_shape((10,)), (10,)), + msg="out of bounds", + exception_cls=IndexError, + id="int-at-upper-bound", + ), + ExpectErr( + input=(IndexDomain(inclusive_min=(5,), exclusive_max=(10,)), (4,)), + msg="out of bounds", + exception_cls=IndexError, + id="int-below-origin", + ), + ExpectErr( + input=(IndexDomain.from_shape((10,)), (1, 2)), + msg="too many indices", + exception_cls=IndexError, + id="too-many-indices", + ), + ExpectErr( + input=(IndexDomain.from_shape((10,)), (slice(0, 10, 2),)), + msg="step=1", + exception_cls=IndexError, + id="non-unit-step", + ), + ], + ids=lambda c: c.id, +) +def test_narrow_errors(case: ExpectErr[tuple[IndexDomain, Any]]) -> None: + """IndexDomain.narrow raises IndexError on invalid selections.""" + domain, selection = case.input + with pytest.raises(case.exception_cls, match=case.msg): + domain.narrow(selection) + + +# --------------------------------------------------------------------------- +# Direct tests for the non-trivial private helper _normalize_selection. 
+# Public callers (`IndexDomain.narrow` and `selection_to_transform`) exercise +# most branches transitively, but the double-ellipsis guard only triggers on +# inputs no public caller currently constructs. Test it directly. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=((slice(2, 8), slice(5, 15)), 2), + expected=(slice(2, 8), slice(5, 15)), + id="explicit-slices", + ), + Expect( + input=((3, slice(None)), 2), + expected=(3, slice(None)), + id="int-and-slice", + ), + Expect( + input=((..., slice(0, 5)), 3), + expected=(slice(None), slice(None), slice(0, 5)), + id="leading-ellipsis-fills", + ), + Expect( + input=((slice(0, 5), ...), 3), + expected=(slice(0, 5), slice(None), slice(None)), + id="trailing-ellipsis-fills", + ), + Expect( + input=((slice(2, 8),), 3), + expected=(slice(2, 8), slice(None), slice(None)), + id="implicit-trailing-fills", + ), + Expect( + input=(slice(2, 8), 1), + expected=(slice(2, 8),), + id="bare-slice-is-wrapped", + ), + Expect( + input=(5, 1), + expected=(5,), + id="bare-int-is-wrapped", + ), + ], + ids=lambda c: c.id, +) +def test_normalize_selection_success( + case: Expect[tuple[Any, int], tuple[int | slice, ...]], +) -> None: + """_normalize_selection produces a length-ndim tuple of ints/slices.""" + selection, ndim = case.input + assert _normalize_selection(selection, ndim) == case.expected + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=((..., ..., slice(0, 5)), 3), + msg="single ellipsis", + exception_cls=IndexError, + id="double-ellipsis", + ), + ExpectErr( + input=((1, 2, 3), 2), + msg="too many indices", + exception_cls=IndexError, + id="too-many-indices", + ), + ], + ids=lambda c: c.id, +) +def test_normalize_selection_errors(case: ExpectErr[tuple[Any, int]]) -> None: + """_normalize_selection rejects double ellipsis and over-long selections.""" + selection, ndim = case.input + with pytest.raises(case.exception_cls, 
match=case.msg): + _normalize_selection(selection, ndim) From d91f59283fa6306139580c85c99c631165f5c3cc Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 7 May 2026 10:09:55 -0400 Subject: [PATCH 11/24] refactor(_transforms): drop unused _normalize_negative_indices MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The function bridged NumPy-style negative indices to TensorStore-style absolute coordinates. It was originally used by the eager-path rewiring in PR #3906 (4 callsites in src/zarr/core/array.py). This PR has no such rewiring — PR #2's lazy view materializes via the existing eager indexing path, which already handles negatives — so the helper has no callers in our shipping plan. Re-add when (and if) the internal-rewiring follow-up arrives. Removing it now keeps the package surface lean and the diff focused. --- src/zarr/core/_transforms/transform.py | 55 -------------------------- 1 file changed, 55 deletions(-) diff --git a/src/zarr/core/_transforms/transform.py b/src/zarr/core/_transforms/transform.py index 3b160cb18d..f8abfd9285 100644 --- a/src/zarr/core/_transforms/transform.py +++ b/src/zarr/core/_transforms/transform.py @@ -825,61 +825,6 @@ def _apply_vindex(transform: IndexTransform, selection: Any) -> IndexTransform: return IndexTransform(domain=new_domain, output=tuple(new_output)) -def _normalize_negative_indices(selection: Any, shape: tuple[int, ...]) -> Any: - """Convert negative indices to positive ones using the array shape. - - Only normalizes integer and array-like index components; leaves - slices, Ellipsis, None, etc. untouched. - """ - if not isinstance(selection, tuple): - selection_tuple: tuple[Any, ...] 
= (selection,) - else: - selection_tuple = selection - - # Count real dimensions (non-None, non-Ellipsis) to map each entry to a shape dim - has_ellipsis = any(s is Ellipsis for s in selection_tuple) - n_non_newaxis = sum(1 for s in selection_tuple if s is not None and s is not Ellipsis) - n_ellipsis_dims = len(shape) - n_non_newaxis + (1 if has_ellipsis else 0) - - result: list[Any] = [] - dim = 0 - - for sel in selection_tuple: - if sel is Ellipsis: - result.append(sel) - dim += max(0, n_ellipsis_dims) - elif sel is None: - result.append(sel) - elif isinstance(sel, (int, np.integer)) and not isinstance(sel, bool): - idx = int(sel) - if idx < 0 and dim < len(shape): - idx = idx + shape[dim] - result.append(idx) - dim += 1 - elif isinstance(sel, np.ndarray) and sel.dtype != np.bool_: - arr = sel.copy() - if dim < len(shape): - arr = np.where(arr < 0, arr + shape[dim], arr) - result.append(arr) - dim += 1 - elif isinstance(sel, list): - # Convert lists to arrays with negative index normalization - arr = np.asarray(sel, dtype=np.intp) - if dim < len(shape): - arr = np.where(arr < 0, arr + shape[dim], arr) - result.append(arr) - dim += 1 - else: - # slice, bool array, or anything else: pass through - result.append(sel) - if sel is not None and sel is not Ellipsis: - dim += 1 - - if not isinstance(selection, tuple) and len(result) == 1: - return result[0] - return tuple(result) - - def _validate_array_selection(selection: Any, shape: tuple[int, ...], mode: str) -> None: """Validate array-based selections (orthogonal, vectorized). 
From 6d18502b66b5979457afc024043fa97dad106f65 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 7 May 2026 10:10:20 -0400 Subject: [PATCH 12/24] test(_transforms): rewrite transform tests in parametrized style MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit One success test (and where applicable one error test) per public method on IndexTransform plus the public selection_to_transform function. Adds direct tests for _intersect_vectorized — the correlated multi-ArrayMap intersection path. Public `intersect` only reaches it when the transform has 2+ ArrayMap outputs; all the public-surface intersect cases use a single ArrayMap, so this branch was previously unreached by tests. --- tests/test_transforms/test_transform.py | 1212 +++++++++++++---------- 1 file changed, 704 insertions(+), 508 deletions(-) diff --git a/tests/test_transforms/test_transform.py b/tests/test_transforms/test_transform.py index 03f26cfb5d..ad8e9f925b 100644 --- a/tests/test_transforms/test_transform.py +++ b/tests/test_transforms/test_transform.py @@ -1,516 +1,712 @@ from __future__ import annotations +from typing import Any + import numpy as np import pytest +from tests.test_transforms.conftest import Expect, ExpectErr from zarr.core._transforms.domain import IndexDomain from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap -from zarr.core._transforms.transform import IndexTransform, selection_to_transform - - -class TestIndexTransformConstruction: - def test_from_shape(self) -> None: - t = IndexTransform.from_shape((10, 20)) - assert t.input_rank == 2 - assert t.output_rank == 2 - assert t.domain.shape == (10, 20) - assert t.domain.origin == (0, 0) - for i, m in enumerate(t.output): - assert isinstance(m, DimensionMap) - assert m.input_dimension == i - assert m.offset == 0 - assert m.stride == 1 - - def test_identity(self) -> None: - domain = IndexDomain(inclusive_min=(5,), exclusive_max=(15,)) - t = 
IndexTransform.identity(domain) - assert t.input_rank == 1 - assert t.output_rank == 1 - assert t.domain == domain - assert isinstance(t.output[0], DimensionMap) - assert t.output[0].input_dimension == 0 - - def test_from_shape_0d(self) -> None: - t = IndexTransform.from_shape(()) - assert t.input_rank == 0 - assert t.output_rank == 0 - assert t.domain.shape == () - - def test_custom_output_maps(self) -> None: - domain = IndexDomain.from_shape((10,)) - maps = (ConstantMap(offset=42), DimensionMap(input_dimension=0, offset=5, stride=2)) - t = IndexTransform(domain=domain, output=maps) - assert t.input_rank == 1 - assert t.output_rank == 2 - - def test_validation_input_dimension_out_of_range(self) -> None: - domain = IndexDomain.from_shape((10,)) - maps = (DimensionMap(input_dimension=5),) - with pytest.raises(ValueError, match="input_dimension"): - IndexTransform(domain=domain, output=maps) - - -class TestIndexTransformBasicIndexing: - def test_slice_identity(self) -> None: - """slice(None) on identity transform is a no-op.""" - t = IndexTransform.from_shape((10, 20)) - result = t[slice(None), slice(None)] - assert result.domain.shape == (10, 20) - assert result.input_rank == 2 - assert result.output_rank == 2 - - def test_slice_narrows(self) -> None: - t = IndexTransform.from_shape((10, 20)) - result = t[2:8, 5:15] - assert result.domain.shape == (6, 10) - assert result.domain.origin == (0, 0) - assert isinstance(result.output[0], DimensionMap) - assert result.output[0].offset == 2 - assert result.output[0].stride == 1 - assert result.output[0].input_dimension == 0 - assert isinstance(result.output[1], DimensionMap) - assert result.output[1].offset == 5 - assert result.output[1].input_dimension == 1 - - def test_strided_slice(self) -> None: - t = IndexTransform.from_shape((10,)) - result = t[::2] - assert result.domain.shape == (5,) - assert isinstance(result.output[0], DimensionMap) - assert result.output[0].offset == 0 - assert result.output[0].stride == 2 - - 
def test_strided_slice_with_start(self) -> None: - t = IndexTransform.from_shape((10,)) - result = t[1:9:3] - # indices: 1, 4, 7 -> 3 elements - assert result.domain.shape == (3,) - assert isinstance(result.output[0], DimensionMap) - assert result.output[0].offset == 1 - assert result.output[0].stride == 3 - - def test_int_drops_dimension(self) -> None: - t = IndexTransform.from_shape((10, 20)) - result = t[3] - assert result.input_rank == 1 - assert result.output_rank == 2 - assert isinstance(result.output[0], ConstantMap) - assert result.output[0].offset == 3 - assert isinstance(result.output[1], DimensionMap) - assert result.output[1].input_dimension == 0 - - def test_int_middle_dimension(self) -> None: - t = IndexTransform.from_shape((10, 20, 30)) - result = t[:, 5, :] - assert result.input_rank == 2 - assert result.output_rank == 3 - assert isinstance(result.output[0], DimensionMap) - assert result.output[0].input_dimension == 0 - assert isinstance(result.output[1], ConstantMap) - assert result.output[1].offset == 5 - assert isinstance(result.output[2], DimensionMap) - assert result.output[2].input_dimension == 1 - - def test_ellipsis(self) -> None: - t = IndexTransform.from_shape((10, 20, 30)) - result = t[2:8, ...] 
- assert result.input_rank == 3 - assert result.domain.shape == (6, 20, 30) - - def test_newaxis(self) -> None: - t = IndexTransform.from_shape((10, 20)) - result = t[np.newaxis, :, :] - assert result.input_rank == 3 - assert result.domain.shape == (1, 10, 20) - assert result.output_rank == 2 - assert isinstance(result.output[0], DimensionMap) - assert result.output[0].input_dimension == 1 - assert isinstance(result.output[1], DimensionMap) - assert result.output[1].input_dimension == 2 - - def test_int_out_of_bounds(self) -> None: - t = IndexTransform.from_shape((10,)) - with pytest.raises(IndexError): - t[10] - - def test_negative_int_is_literal(self) -> None: - """Negative indices are literal coordinates (TensorStore convention), - not 'from the end' like NumPy.""" - t = IndexTransform.from_shape((10,)) - with pytest.raises(IndexError): - t[-1] # -1 is out of bounds for domain [0, 10) - - def test_negative_int_valid_with_negative_origin(self) -> None: - """Negative index is valid if the domain includes negative coordinates.""" - domain = IndexDomain(inclusive_min=(-5,), exclusive_max=(5,)) - t = IndexTransform.identity(domain) - result = t[-3] - assert isinstance(result.output[0], ConstantMap) - assert result.output[0].offset == -3 - - def test_composition_of_slices(self) -> None: - """Slicing a sliced transform should compose offsets.""" - t = IndexTransform.from_shape((100,)) - result = t[10:50][5:20] - assert result.domain.shape == (15,) - assert isinstance(result.output[0], DimensionMap) - assert result.output[0].offset == 15 - assert result.output[0].stride == 1 - - def test_composition_of_strides(self) -> None: - t = IndexTransform.from_shape((100,)) - result = t[::2][::3] - # t[::2] -> shape (50,), offset=0, stride=2 - # [::3] -> shape ceil(50/3)=17, offset=0, stride=2*3=6 - assert result.domain.shape == (17,) - assert isinstance(result.output[0], DimensionMap) - assert result.output[0].stride == 6 - - def test_bare_int(self) -> None: - """Non-tuple 
selection.""" - t = IndexTransform.from_shape((10, 20)) - result = t[3] - assert result.input_rank == 1 - - def test_bare_slice(self) -> None: - t = IndexTransform.from_shape((10, 20)) - result = t[2:8] - assert result.domain.shape == (6, 20) - - -class TestBasicIndexingOnArrayMaps: - """When a transform already has ArrayMap outputs, basic indexing must - apply the corresponding operation to the index_array's axes.""" - - def test_int_on_array_map_drops_axis(self) -> None: - """Integer index on a dimension referenced by an ArrayMap should - index into the array on that axis.""" - arr = np.array([[10, 20], [30, 40], [50, 60]], dtype=np.intp) - # 2D input domain (3, 2), one ArrayMap output - t = IndexTransform( - domain=IndexDomain.from_shape((3, 2)), - output=(ArrayMap(index_array=arr),), - ) - # Index with int on dim 0 -> pick row 1 -> arr[1, :] = [30, 40] - result = t[1] - assert result.input_rank == 1 - assert result.domain.shape == (2,) - assert isinstance(result.output[0], ArrayMap) - np.testing.assert_array_equal(result.output[0].index_array, np.array([30, 40])) - - def test_slice_on_array_map(self) -> None: - """Slice on a dimension referenced by an ArrayMap should slice the array.""" - arr = np.array([10, 20, 30, 40, 50], dtype=np.intp) - t = IndexTransform( - domain=IndexDomain.from_shape((5,)), - output=(ArrayMap(index_array=arr),), - ) - result = t[1:4] - assert result.domain.shape == (3,) - assert isinstance(result.output[0], ArrayMap) - np.testing.assert_array_equal(result.output[0].index_array, np.array([20, 30, 40])) - - def test_strided_slice_on_array_map(self) -> None: - """Strided slice on ArrayMap should stride the array.""" - arr = np.array([10, 20, 30, 40, 50], dtype=np.intp) - t = IndexTransform( - domain=IndexDomain.from_shape((5,)), - output=(ArrayMap(index_array=arr),), - ) - result = t[::2] - assert result.domain.shape == (3,) - assert isinstance(result.output[0], ArrayMap) - np.testing.assert_array_equal(result.output[0].index_array, 
np.array([10, 30, 50])) - - def test_newaxis_on_array_map(self) -> None: - """Newaxis should insert an axis in the index_array.""" - arr = np.array([10, 20, 30], dtype=np.intp) - t = IndexTransform( - domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=arr),), - ) - result = t[np.newaxis, :] - assert result.input_rank == 2 - assert result.domain.shape == (1, 3) - assert isinstance(result.output[0], ArrayMap) - assert result.output[0].index_array.shape == (1, 3) - np.testing.assert_array_equal(result.output[0].index_array, np.array([[10, 20, 30]])) - - def test_int_drops_one_of_two_array_dims(self) -> None: - """2D array map, int on dim 0, slice on dim 1.""" - arr = np.array([[10, 20, 30], [40, 50, 60]], dtype=np.intp) - t = IndexTransform( - domain=IndexDomain.from_shape((2, 3)), - output=(ArrayMap(index_array=arr),), - ) - result = t[0, 1:3] - assert result.input_rank == 1 - assert result.domain.shape == (2,) - assert isinstance(result.output[0], ArrayMap) - # arr[0, 1:3] = [20, 30] - np.testing.assert_array_equal(result.output[0].index_array, np.array([20, 30])) - - -class TestIndexTransformOindex: - def test_oindex_int_array(self) -> None: - t = IndexTransform.from_shape((10, 20)) - idx = np.array([1, 3, 5], dtype=np.intp) - result = t.oindex[idx, :] - assert result.input_rank == 2 - assert result.domain.shape == (3, 20) - assert isinstance(result.output[0], ArrayMap) - np.testing.assert_array_equal(result.output[0].index_array, idx) - assert result.output[0].offset == 0 - assert result.output[0].stride == 1 - assert isinstance(result.output[1], DimensionMap) - assert result.output[1].input_dimension == 1 - - def test_oindex_bool_array(self) -> None: - t = IndexTransform.from_shape((5,)) - mask = np.array([True, False, True, False, True]) - result = t.oindex[mask] - assert result.domain.shape == (3,) - assert isinstance(result.output[0], ArrayMap) - np.testing.assert_array_equal( - result.output[0].index_array, np.array([0, 2, 4], dtype=np.intp) 
- ) - - def test_oindex_mixed(self) -> None: - t = IndexTransform.from_shape((10, 20)) - idx = np.array([2, 4], dtype=np.intp) - result = t.oindex[idx, 5:15] - assert result.input_rank == 2 - assert result.domain.shape == (2, 10) - assert isinstance(result.output[0], ArrayMap) - assert isinstance(result.output[1], DimensionMap) - assert result.output[1].offset == 5 - - def test_oindex_multiple_arrays(self) -> None: - t = IndexTransform.from_shape((10, 20, 30)) - idx0 = np.array([1, 3], dtype=np.intp) - idx1 = np.array([5, 10, 15], dtype=np.intp) - result = t.oindex[idx0, :, idx1] - assert result.input_rank == 3 - assert result.domain.shape == (2, 20, 3) - assert isinstance(result.output[0], ArrayMap) - assert isinstance(result.output[1], DimensionMap) - assert isinstance(result.output[2], ArrayMap) - - -class TestIndexTransformVindex: - def test_vindex_single_array(self) -> None: - t = IndexTransform.from_shape((10,)) - idx = np.array([1, 3, 5], dtype=np.intp) - result = t.vindex[idx] - assert result.input_rank == 1 - assert result.domain.shape == (3,) - assert isinstance(result.output[0], ArrayMap) - np.testing.assert_array_equal(result.output[0].index_array, idx) - - def test_vindex_broadcast(self) -> None: - t = IndexTransform.from_shape((10, 20)) - idx0 = np.array([[1, 2], [3, 4]], dtype=np.intp) - idx1 = np.array([[10, 11], [12, 13]], dtype=np.intp) - result = t.vindex[idx0, idx1] - assert result.input_rank == 2 - assert result.domain.shape == (2, 2) - assert isinstance(result.output[0], ArrayMap) - assert isinstance(result.output[1], ArrayMap) - np.testing.assert_array_equal(result.output[0].index_array, idx0) - np.testing.assert_array_equal(result.output[1].index_array, idx1) - - def test_vindex_with_slice(self) -> None: - t = IndexTransform.from_shape((10, 20, 30)) - idx = np.array([1, 3, 5], dtype=np.intp) - result = t.vindex[idx, :, :] - assert result.input_rank == 3 - assert result.domain.shape == (3, 20, 30) - assert isinstance(result.output[0], 
ArrayMap) - - def test_vindex_bool_mask(self) -> None: - t = IndexTransform.from_shape((5,)) - mask = np.array([True, False, True, False, True]) - result = t.vindex[mask] - assert result.domain.shape == (3,) - assert isinstance(result.output[0], ArrayMap) - - def test_vindex_broadcast_different_shapes(self) -> None: - t = IndexTransform.from_shape((10, 20)) - idx0 = np.array([1, 2, 3], dtype=np.intp) - idx1 = np.array([[10], [11]], dtype=np.intp) - result = t.vindex[idx0, idx1] - assert result.input_rank == 2 - assert result.domain.shape == (2, 3) - - -class TestSelectionToTransform: - def test_basic_slice(self) -> None: - t = IndexTransform.from_shape((10, 20)) - result = selection_to_transform((slice(2, 8), slice(5, 15)), t, "basic") - assert result.domain.shape == (6, 10) - assert isinstance(result.output[0], DimensionMap) - assert result.output[0].offset == 2 - - def test_basic_int(self) -> None: - t = IndexTransform.from_shape((10, 20)) - result = selection_to_transform((3, slice(None)), t, "basic") - assert result.input_rank == 1 - assert isinstance(result.output[0], ConstantMap) - assert result.output[0].offset == 3 - - def test_basic_ellipsis(self) -> None: - t = IndexTransform.from_shape((10, 20)) - result = selection_to_transform(Ellipsis, t, "basic") - assert result.domain.shape == (10, 20) - - def test_orthogonal(self) -> None: - t = IndexTransform.from_shape((10, 20)) - idx = np.array([1, 3, 5], dtype=np.intp) - result = selection_to_transform((idx, slice(None)), t, "orthogonal") - assert result.domain.shape == (3, 20) - assert isinstance(result.output[0], ArrayMap) - - def test_vectorized(self) -> None: - t = IndexTransform.from_shape((10, 20)) - idx0 = np.array([1, 3], dtype=np.intp) - idx1 = np.array([5, 7], dtype=np.intp) - result = selection_to_transform((idx0, idx1), t, "vectorized") - assert result.domain.shape == (2,) - assert isinstance(result.output[0], ArrayMap) - assert isinstance(result.output[1], ArrayMap) - - def 
test_composition_with_non_identity(self) -> None: - """Indexing a sliced transform composes offsets.""" - t = IndexTransform.from_shape((100,))[10:50] - result = selection_to_transform(slice(5, 20), t, "basic") - assert result.domain.shape == (15,) - assert isinstance(result.output[0], DimensionMap) - assert result.output[0].offset == 15 - - -class TestIndexTransformIntersect: - def test_constant_inside(self) -> None: - t = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=(ConstantMap(offset=5),), - ) - result = t.intersect(IndexDomain(inclusive_min=(0,), exclusive_max=(10,))) - assert result is not None - restricted, surviving = result - assert isinstance(restricted.output[0], ConstantMap) - assert restricted.output[0].offset == 5 - assert surviving is None - - def test_constant_outside(self) -> None: - t = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=(ConstantMap(offset=5),), - ) - result = t.intersect(IndexDomain(inclusive_min=(10,), exclusive_max=(20,))) - assert result is None - - def test_dimension_partial(self) -> None: - """DimensionMap over [0,10) intersected with [5,15) narrows input to [5,10).""" - t = IndexTransform.from_shape((10,)) - result = t.intersect(IndexDomain(inclusive_min=(5,), exclusive_max=(15,))) - assert result is not None - restricted, surviving = result - assert restricted.domain.inclusive_min == (5,) - assert restricted.domain.exclusive_max == (10,) - assert surviving is None - - def test_dimension_no_overlap(self) -> None: - t = IndexTransform.from_shape((10,)) - result = t.intersect(IndexDomain(inclusive_min=(20,), exclusive_max=(30,))) - assert result is None - - def test_dimension_strided(self) -> None: - """stride=2, offset=1 over [0,5): storage 1,3,5,7,9. 
Chunk [4,8).""" - t = IndexTransform( - domain=IndexDomain.from_shape((5,)), - output=(DimensionMap(input_dimension=0, offset=1, stride=2),), - ) - result = t.intersect(IndexDomain(inclusive_min=(4,), exclusive_max=(8,))) - assert result is not None - restricted, _surviving = result - # input 2->5, input 3->7. Both in [4,8). - assert restricted.domain.inclusive_min == (2,) - assert restricted.domain.exclusive_max == (4,) - - def test_array_partial(self) -> None: - arr = np.array([3, 8, 15, 22], dtype=np.intp) - t = IndexTransform( - domain=IndexDomain.from_shape((4,)), - output=(ArrayMap(index_array=arr),), - ) - result = t.intersect(IndexDomain(inclusive_min=(5,), exclusive_max=(20,))) - assert result is not None - restricted, surviving = result - assert isinstance(restricted.output[0], ArrayMap) - np.testing.assert_array_equal(restricted.output[0].index_array, np.array([8, 15])) - assert surviving is not None - np.testing.assert_array_equal(surviving, np.array([1, 2])) - - def test_array_none_inside(self) -> None: - arr = np.array([1, 2, 3], dtype=np.intp) - t = IndexTransform( - domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=arr),), - ) - assert t.intersect(IndexDomain(inclusive_min=(10,), exclusive_max=(20,))) is None - - def test_2d_mixed(self) -> None: - """2D: ConstantMap on dim 0, DimensionMap on dim 1.""" - t = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=( - ConstantMap(offset=5), - DimensionMap(input_dimension=0, offset=0, stride=1), +from zarr.core._transforms.transform import ( + IndexTransform, + _intersect_vectorized, + selection_to_transform, +) + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=IndexTransform.from_shape((10, 20)), + expected={"input_rank": 2, "output_rank": 2, "domain_shape": (10, 20)}, + 
id="from_shape-2d", + ), + Expect( + input=IndexTransform.from_shape(()), + expected={"input_rank": 0, "output_rank": 0, "domain_shape": ()}, + id="from_shape-0d", + ), + Expect( + input=IndexTransform.identity(IndexDomain(inclusive_min=(5,), exclusive_max=(15,))), + expected={"input_rank": 1, "output_rank": 1, "domain_shape": (10,)}, + id="identity-non-zero-origin", + ), + Expect( + input=IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=42), DimensionMap(input_dimension=0)), + ), + expected={"input_rank": 1, "output_rank": 2, "domain_shape": (10,)}, + id="custom-output-maps", + ), + ], + ids=lambda c: c.id, +) +def test_construction_success(case: Expect[IndexTransform, dict[str, Any]]) -> None: + """IndexTransform constructors yield the expected ranks and domain shape.""" + t = case.input + assert t.input_rank == case.expected["input_rank"] + assert t.output_rank == case.expected["output_rank"] + assert t.domain.shape == case.expected["domain_shape"] + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input={ + "domain": IndexDomain.from_shape((10,)), + "output": (DimensionMap(input_dimension=5),), + }, + msg="input_dimension", + exception_cls=ValueError, + id="dimension-map-out-of-range", + ), + ], + ids=lambda c: c.id, +) +def test_construction_errors(case: ExpectErr[dict[str, Any]]) -> None: + """IndexTransform construction with invalid output maps raises ValueError.""" + with pytest.raises(case.exception_cls, match=case.msg): + IndexTransform(**case.input) + + +# --------------------------------------------------------------------------- +# from_shape produces an identity transform whose output maps are DimensionMaps +# pointing at the corresponding input dim with offset=0, stride=1. 
+# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect(input=(10, 20), expected=2, id="2d"), + Expect(input=(7,), expected=1, id="1d"), + Expect(input=(), expected=0, id="0d"), + ], + ids=lambda c: c.id, +) +def test_from_shape_produces_identity_dimension_maps( + case: Expect[tuple[int, ...], int], +) -> None: + """IndexTransform.from_shape produces DimensionMaps that map each output dim + back to the corresponding input dim, with no offset and unit stride.""" + t = IndexTransform.from_shape(case.input) + assert len(t.output) == case.expected + for i, m in enumerate(t.output): + assert isinstance(m, DimensionMap) + assert m.input_dimension == i + assert m.offset == 0 + assert m.stride == 1 + + +# --------------------------------------------------------------------------- +# __getitem__ (basic indexing) +# +# Most successful branches are covered by selection_to_transform tests below; +# this set focuses on cases unique to the __getitem__ surface (composition, +# bare-int / bare-slice, ArrayMap interactions). 
+# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexTransform.from_shape((10, 20)), (slice(None), slice(None))), + expected={"shape": (10, 20), "input_rank": 2, "output_rank": 2}, + id="identity-slice", + ), + Expect( + input=(IndexTransform.from_shape((10, 20)), (slice(2, 8), slice(5, 15))), + expected={"shape": (6, 10), "input_rank": 2, "output_rank": 2}, + id="2d-narrowing-slices", + ), + Expect( + input=(IndexTransform.from_shape((10,)), slice(None, None, 2)), + expected={"shape": (5,), "input_rank": 1, "output_rank": 1}, + id="strided-slice", + ), + Expect( + input=(IndexTransform.from_shape((10,)), slice(1, 9, 3)), + expected={"shape": (3,), "input_rank": 1, "output_rank": 1}, + id="strided-slice-with-start", + ), + Expect( + input=(IndexTransform.from_shape((10, 20)), 3), + expected={"shape": (20,), "input_rank": 1, "output_rank": 2}, + id="bare-int-drops-leading-dim", + ), + Expect( + input=(IndexTransform.from_shape((10, 20, 30)), (slice(None), 5, slice(None))), + expected={"shape": (10, 30), "input_rank": 2, "output_rank": 3}, + id="int-drops-middle-dim", + ), + Expect( + input=(IndexTransform.from_shape((10, 20, 30)), (slice(2, 8), ...)), + expected={"shape": (6, 20, 30), "input_rank": 3, "output_rank": 3}, + id="ellipsis-fills-trailing", + ), + Expect( + input=(IndexTransform.from_shape((10, 20)), (np.newaxis, slice(None), slice(None))), + expected={"shape": (1, 10, 20), "input_rank": 3, "output_rank": 2}, + id="newaxis-prepends-axis", + ), + Expect( + input=(IndexTransform.from_shape((10, 20)), slice(2, 8)), + expected={"shape": (6, 20), "input_rank": 2, "output_rank": 2}, + id="bare-slice-implicitly-fills-trailing", + ), + ], + ids=lambda c: c.id, +) +def test_getitem_basic_success( + case: Expect[tuple[IndexTransform, Any], dict[str, Any]], +) -> None: + """IndexTransform.__getitem__ produces a sub-transform with the expected shape and rank.""" + 
transform, selection = case.input + result = transform[selection] + assert result.domain.shape == case.expected["shape"] + assert result.input_rank == case.expected["input_rank"] + assert result.output_rank == case.expected["output_rank"] + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexTransform.from_shape((10,)), 10), + msg="out of bounds", + exception_cls=IndexError, + id="int-at-upper-bound", + ), + ExpectErr( + input=(IndexTransform.from_shape((10,)), -1), + msg="out of bounds", + exception_cls=IndexError, + id="negative-int-out-of-domain", + ), + ], + ids=lambda c: c.id, +) +def test_getitem_basic_errors(case: ExpectErr[tuple[IndexTransform, Any]]) -> None: + """IndexTransform.__getitem__ rejects out-of-domain integer indices. + + Note: negative indices are LITERAL coordinates per TensorStore convention, + not wrap-around. arr[-1] on a domain [0, 10) is out of bounds, not arr[9]. + """ + transform, selection = case.input + with pytest.raises(case.exception_cls, match=case.msg): + transform[selection] + + +def test_getitem_negative_int_valid_with_negative_origin() -> None: + """A negative integer index is valid when the domain's origin is negative. + + Stand-alone test (not parametrized) because verifying the *literal-coordinate* + semantics is the whole point — the assertion on the resulting ConstantMap + offset is the load-bearing check, not the shape. 
+ """ + domain = IndexDomain(inclusive_min=(-5,), exclusive_max=(5,)) + t = IndexTransform.identity(domain) + result = t[-3] + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == -3 + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexTransform.from_shape((100,))[10:50], slice(5, 20)), + expected={"shape": (15,), "offset": 15, "stride": 1}, + id="composed-slices", + ), + Expect( + input=(IndexTransform.from_shape((100,))[::2], slice(None, None, 3)), + expected={"shape": (17,), "offset": 0, "stride": 6}, + id="composed-strides", + ), + ], + ids=lambda c: c.id, +) +def test_getitem_composition( + case: Expect[tuple[IndexTransform, Any], dict[str, Any]], +) -> None: + """Indexing a sliced transform composes offsets and strides on the DimensionMap.""" + transform, selection = case.input + result = transform[selection] + assert result.domain.shape == case.expected["shape"] + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == case.expected["offset"] + assert result.output[0].stride == case.expected["stride"] + + +# Indexing into a transform whose output is already an ArrayMap — basic +# operations (int/slice/stride/newaxis) must transform the index_array itself +# rather than building a new map. 
+_array_map_1d = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(ArrayMap(index_array=np.array([10, 20, 30, 40, 50], dtype=np.intp)),), +) +_array_map_2d_3x2 = IndexTransform( + domain=IndexDomain.from_shape((3, 2)), + output=(ArrayMap(index_array=np.array([[10, 20], [30, 40], [50, 60]], dtype=np.intp)),), +) +_array_map_2d_2x3 = IndexTransform( + domain=IndexDomain.from_shape((2, 3)), + output=(ArrayMap(index_array=np.array([[10, 20, 30], [40, 50, 60]], dtype=np.intp)),), +) + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(_array_map_2d_3x2, 1), + expected=np.array([30, 40], dtype=np.intp), + id="int-on-array-map-drops-axis", + ), + Expect( + input=(_array_map_1d, slice(1, 4)), + expected=np.array([20, 30, 40], dtype=np.intp), + id="slice-on-array-map", + ), + Expect( + input=(_array_map_1d, slice(None, None, 2)), + expected=np.array([10, 30, 50], dtype=np.intp), + id="strided-slice-on-array-map", + ), + Expect( + input=(_array_map_2d_2x3, (0, slice(1, 3))), + expected=np.array([20, 30], dtype=np.intp), + id="int-then-slice-on-2d-array-map", + ), + ], + ids=lambda c: c.id, +) +def test_getitem_on_array_map( + case: Expect[tuple[IndexTransform, Any], np.ndarray[Any, np.dtype[np.intp]]], +) -> None: + """Basic indexing on a transform whose output is an ArrayMap reshapes the index array.""" + transform, selection = case.input + result = transform[selection] + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, case.expected) + + +def test_getitem_newaxis_on_array_map() -> None: + """np.newaxis on an ArrayMap inserts a new axis in the index_array, not just the domain.""" + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=np.array([10, 20, 30], dtype=np.intp)),), + ) + result = t[np.newaxis, :] + assert result.input_rank == 2 + assert result.domain.shape == (1, 3) + assert isinstance(result.output[0], ArrayMap) + assert 
result.output[0].index_array.shape == (1, 3) + np.testing.assert_array_equal(result.output[0].index_array, np.array([[10, 20, 30]])) + + +# --------------------------------------------------------------------------- +# oindex (orthogonal indexing) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=( + IndexTransform.from_shape((10, 20)), + (np.array([1, 3, 5], dtype=np.intp), slice(None)), + ), + expected={"shape": (3, 20), "out0_kind": ArrayMap, "out1_kind": DimensionMap}, + id="int-array-and-slice", + ), + Expect( + input=(IndexTransform.from_shape((5,)), (np.array([True, False, True, False, True]),)), + expected={"shape": (3,), "out0_kind": ArrayMap, "out1_kind": None}, + id="bool-mask", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20)), + (np.array([2, 4], dtype=np.intp), slice(5, 15)), + ), + expected={"shape": (2, 10), "out0_kind": ArrayMap, "out1_kind": DimensionMap}, + id="array-and-narrowing-slice", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20, 30)), + ( + np.array([1, 3], dtype=np.intp), + slice(None), + np.array([5, 10, 15], dtype=np.intp), + ), + ), + expected={"shape": (2, 20, 3), "out0_kind": ArrayMap, "out1_kind": DimensionMap}, + id="three-dims-mixed", + ), + ], + ids=lambda c: c.id, +) +def test_oindex_success(case: Expect[tuple[IndexTransform, Any], dict[str, Any]]) -> None: + """IndexTransform.oindex combines array indices independently per dimension.""" + transform, selection = case.input + result = transform.oindex[selection] + assert result.domain.shape == case.expected["shape"] + assert isinstance(result.output[0], case.expected["out0_kind"]) + if case.expected["out1_kind"] is not None: + assert isinstance(result.output[1], case.expected["out1_kind"]) + + +# --------------------------------------------------------------------------- +# vindex (vectorized indexing) +# 
--------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexTransform.from_shape((10,)), np.array([1, 3, 5], dtype=np.intp)), + expected=(3,), + id="single-1d-array", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20)), + ( + np.array([[1, 2], [3, 4]], dtype=np.intp), + np.array([[10, 11], [12, 13]], dtype=np.intp), + ), + ), + expected=(2, 2), + id="two-2d-arrays-broadcast", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20, 30)), + (np.array([1, 3, 5], dtype=np.intp), slice(None), slice(None)), + ), + expected=(3, 20, 30), + id="array-with-trailing-slices", + ), + Expect( + input=(IndexTransform.from_shape((5,)), np.array([True, False, True, False, True])), + expected=(3,), + id="bool-mask", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20)), + (np.array([1, 2, 3], dtype=np.intp), np.array([[10], [11]], dtype=np.intp)), + ), + expected=(2, 3), + id="broadcast-different-shapes", + ), + ], + ids=lambda c: c.id, +) +def test_vindex_success(case: Expect[tuple[IndexTransform, Any], tuple[int, ...]]) -> None: + """IndexTransform.vindex broadcasts array indices and produces correlated ArrayMaps.""" + transform, selection = case.input + result = transform.vindex[selection] + assert result.domain.shape == case.expected + + +# --------------------------------------------------------------------------- +# selection_to_transform — the public dispatch front door for all three modes. +# Sanity check that each mode produces the expected output kind. 
+# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexTransform.from_shape((10, 20)), (slice(2, 8), slice(5, 15)), "basic"), + expected={"shape": (6, 10), "out0_kind": DimensionMap}, + id="basic-slices", + ), + Expect( + input=(IndexTransform.from_shape((10, 20)), (3, slice(None)), "basic"), + expected={"shape": (20,), "out0_kind": ConstantMap}, + id="basic-int-and-slice", + ), + Expect( + input=(IndexTransform.from_shape((10, 20)), Ellipsis, "basic"), + expected={"shape": (10, 20), "out0_kind": DimensionMap}, + id="basic-bare-ellipsis", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20)), + (np.array([1, 3, 5], dtype=np.intp), slice(None)), + "orthogonal", + ), + expected={"shape": (3, 20), "out0_kind": ArrayMap}, + id="orthogonal", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20)), + (np.array([1, 3], dtype=np.intp), np.array([5, 7], dtype=np.intp)), + "vectorized", ), - ) - chunk = IndexDomain(inclusive_min=(0, 5), exclusive_max=(10, 15)) - result = t.intersect(chunk) - assert result is not None - restricted, _ = result - assert isinstance(restricted.output[0], ConstantMap) - assert restricted.output[0].offset == 5 - assert isinstance(restricted.output[1], DimensionMap) - assert restricted.domain.inclusive_min == (5,) - assert restricted.domain.exclusive_max == (10,) - - -class TestIndexTransformTranslate: - def test_translate_constant(self) -> None: - t = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=(ConstantMap(offset=5),), - ) - result = t.translate((-5,)) - assert isinstance(result.output[0], ConstantMap) - assert result.output[0].offset == 0 - - def test_translate_dimension(self) -> None: - t = IndexTransform.from_shape((10,)) - result = t.translate((-3,)) - assert isinstance(result.output[0], DimensionMap) - assert result.output[0].offset == -3 - assert result.output[0].stride == 1 - - def test_translate_array(self) -> 
None: - arr = np.array([5, 10], dtype=np.intp) - t = IndexTransform( - domain=IndexDomain.from_shape((2,)), - output=(ArrayMap(index_array=arr, offset=3),), - ) - result = t.translate((-3,)) - assert isinstance(result.output[0], ArrayMap) - assert result.output[0].offset == 0 - np.testing.assert_array_equal(result.output[0].index_array, arr) - - def test_translate_2d(self) -> None: - t = IndexTransform.from_shape((10, 20)) - result = t.translate((-5, -10)) - assert isinstance(result.output[0], DimensionMap) - assert result.output[0].offset == -5 - assert isinstance(result.output[1], DimensionMap) - assert result.output[1].offset == -10 + expected={"shape": (2,), "out0_kind": ArrayMap}, + id="vectorized", + ), + Expect( + input=(IndexTransform.from_shape((100,))[10:50], slice(5, 20), "basic"), + expected={"shape": (15,), "out0_kind": DimensionMap}, + id="composes-with-non-identity-base", + ), + ], + ids=lambda c: c.id, +) +def test_selection_to_transform_success( + case: Expect[tuple[IndexTransform, Any, str], dict[str, Any]], +) -> None: + """selection_to_transform dispatches to basic/orthogonal/vectorized correctly.""" + transform, selection, mode = case.input + result = selection_to_transform(selection, transform, mode) + assert result.domain.shape == case.expected["shape"] + assert isinstance(result.output[0], case.expected["out0_kind"]) + + +# --------------------------------------------------------------------------- +# intersect — restrict an output domain. Returns (sub_transform, surviving) +# or None when the intersection is empty. 
+# --------------------------------------------------------------------------- + + +def test_intersect_constant_inside() -> None: + """A ConstantMap whose offset is inside the chunk survives unchanged.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), + ) + result = t.intersect(IndexDomain(inclusive_min=(0,), exclusive_max=(10,))) + assert result is not None + restricted, surviving = result + assert isinstance(restricted.output[0], ConstantMap) + assert restricted.output[0].offset == 5 + assert surviving is None + + +def test_intersect_constant_outside() -> None: + """A ConstantMap whose offset is outside the chunk yields None.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), + ) + assert t.intersect(IndexDomain(inclusive_min=(10,), exclusive_max=(20,))) is None + + +def test_intersect_dimension_partial() -> None: + """A DimensionMap whose storage-coord range partially overlaps the chunk + narrows the input domain to the surviving slice.""" + t = IndexTransform.from_shape((10,)) + result = t.intersect(IndexDomain(inclusive_min=(5,), exclusive_max=(15,))) + assert result is not None + restricted, surviving = result + assert restricted.domain.inclusive_min == (5,) + assert restricted.domain.exclusive_max == (10,) + assert surviving is None + + +def test_intersect_dimension_no_overlap() -> None: + """A DimensionMap whose storage-coord range does not overlap the chunk yields None.""" + t = IndexTransform.from_shape((10,)) + assert t.intersect(IndexDomain(inclusive_min=(20,), exclusive_max=(30,))) is None + + +def test_intersect_dimension_strided() -> None: + """Strided DimensionMap: storage = offset + stride * input. Only inputs that land + in the chunk survive.""" + # offset=1, stride=2, input [0,5): storage = {1, 3, 5, 7, 9}. Chunk [4, 8) -> {5, 7}. 
+ t = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(DimensionMap(input_dimension=0, offset=1, stride=2),), + ) + result = t.intersect(IndexDomain(inclusive_min=(4,), exclusive_max=(8,))) + assert result is not None + restricted, _ = result + assert restricted.domain.inclusive_min == (2,) + assert restricted.domain.exclusive_max == (4,) + + +def test_intersect_array_partial() -> None: + """An ArrayMap whose storage coords partially overlap the chunk yields a filtered ArrayMap + plus a `surviving` array of the input indices that survived.""" + arr = np.array([3, 8, 15, 22], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((4,)), + output=(ArrayMap(index_array=arr),), + ) + result = t.intersect(IndexDomain(inclusive_min=(5,), exclusive_max=(20,))) + assert result is not None + restricted, surviving = result + assert isinstance(restricted.output[0], ArrayMap) + np.testing.assert_array_equal(restricted.output[0].index_array, np.array([8, 15])) + assert surviving is not None + np.testing.assert_array_equal(surviving, np.array([1, 2])) + + +def test_intersect_array_disjoint() -> None: + """An ArrayMap whose storage coords are entirely outside the chunk yields None.""" + arr = np.array([1, 2, 3], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr),), + ) + assert t.intersect(IndexDomain(inclusive_min=(10,), exclusive_max=(20,))) is None + + +def test_intersect_2d_mixed_constant_and_dimension() -> None: + """2D output: ConstantMap on dim 0 (inside chunk), DimensionMap on dim 1 (overlaps chunk).""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=( + ConstantMap(offset=5), + DimensionMap(input_dimension=0, offset=0, stride=1), + ), + ) + chunk = IndexDomain(inclusive_min=(0, 5), exclusive_max=(10, 15)) + result = t.intersect(chunk) + assert result is not None + restricted, _ = result + assert isinstance(restricted.output[0], ConstantMap) + assert
restricted.output[0].offset == 5 + assert isinstance(restricted.output[1], DimensionMap) + assert restricted.domain.inclusive_min == (5,) + assert restricted.domain.exclusive_max == (10,) + + +# --------------------------------------------------------------------------- +# Direct tests for _intersect_vectorized. +# +# Public `intersect` only calls _intersect_vectorized when the transform has +# 2+ ArrayMap outputs (correlated indices). All public test cases use exactly +# one ArrayMap, so this branch is unreachable from public-surface tests. +# --------------------------------------------------------------------------- + + +def _vectorized_2d_array_map() -> IndexTransform: + """Helper: a vectorized transform over a (3,) input domain with two + correlated ArrayMaps. Storage coords: (1,10), (5,11), (9,12).""" + return IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=( + ArrayMap(index_array=np.array([1, 5, 9], dtype=np.intp)), + ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp)), + ), + ) + + +def test_intersect_vectorized_partial_survival() -> None: + """Two correlated ArrayMaps; only points where ALL coords are in-chunk survive.""" + t = _vectorized_2d_array_map() + chunk = IndexDomain(inclusive_min=(0, 10), exclusive_max=(8, 12)) + # Storage points (1,10), (5,11), (9,12). In-chunk: (1,10), (5,11). (9,12) is out of range on both dims (9 >= 8, 12 >= 12).
+ result = _intersect_vectorized(t, chunk, [0, 1]) + assert result is not None + restricted, surviving = result + assert isinstance(restricted.output[0], ArrayMap) + assert isinstance(restricted.output[1], ArrayMap) + np.testing.assert_array_equal(restricted.output[0].index_array, np.array([1, 5])) + np.testing.assert_array_equal(restricted.output[1].index_array, np.array([10, 11])) + assert surviving is not None + np.testing.assert_array_equal(surviving, np.array([0, 1])) + + +def test_intersect_vectorized_no_survival() -> None: + """If no point is in-chunk on all dims, returns None.""" + t = _vectorized_2d_array_map() + chunk = IndexDomain(inclusive_min=(20, 20), exclusive_max=(30, 30)) + assert _intersect_vectorized(t, chunk, [0, 1]) is None + + +def test_intersect_vectorized_with_constant_outside_drops_to_none() -> None: + """When a ConstantMap output is outside the chunk, the entire transform fails.""" + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=( + ArrayMap(index_array=np.array([1, 2, 3], dtype=np.intp)), + ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp)), + ConstantMap(offset=99), + ), + ) + chunk = IndexDomain(inclusive_min=(0, 0, 0), exclusive_max=(10, 20, 5)) + assert _intersect_vectorized(t, chunk, [0, 1]) is None + + +# --------------------------------------------------------------------------- +# translate — shift every coordinate by an offset. 
+# --------------------------------------------------------------------------- + +_translate_dimension_t = IndexTransform.from_shape((10,)) +_translate_array_t = IndexTransform( + domain=IndexDomain.from_shape((2,)), + output=(ArrayMap(index_array=np.array([5, 10], dtype=np.intp), offset=3),), +) +_translate_constant_t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), +) + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(_translate_constant_t, (-5,)), + expected={"out_kind": ConstantMap, "offset": 0}, + id="constant", + ), + Expect( + input=(_translate_dimension_t, (-3,)), + expected={"out_kind": DimensionMap, "offset": -3, "stride": 1}, + id="dimension", + ), + Expect( + input=(_translate_array_t, (-3,)), + expected={"out_kind": ArrayMap, "offset": 0}, + id="array", + ), + ], + ids=lambda c: c.id, +) +def test_translate_success( + case: Expect[tuple[IndexTransform, tuple[int, ...]], dict[str, Any]], +) -> None: + """IndexTransform.translate adjusts each output map's offset uniformly.""" + transform, shift = case.input + result = transform.translate(shift) + out0 = result.output[0] + assert isinstance(out0, case.expected["out_kind"]) + assert out0.offset == case.expected["offset"] + if "stride" in case.expected: + assert isinstance(out0, DimensionMap) + assert out0.stride == case.expected["stride"] + + +def test_translate_2d() -> None: + """A multi-dimensional translate shifts all output dims independently.""" + t = IndexTransform.from_shape((10, 20)) + result = t.translate((-5, -10)) + out0, out1 = result.output + assert isinstance(out0, DimensionMap) + assert out0.offset == -5 + assert isinstance(out1, DimensionMap) + assert out1.offset == -10 From c25bf7bbf0d3e3bbf22e74f44049e28544f43edc Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 7 May 2026 10:15:05 -0400 Subject: [PATCH 13/24] test(_transforms): rewrite composition tests in parametrized style Six (inner_kind, outer_kind) pairs for the 
compose dispatch matrix collapse from six per-case methods into one parametrized test. The non-trivial private helpers (_compose_single, _compose_dimension, _compose_array) are reached transitively via this matrix and need no direct tests. --- tests/test_transforms/test_composition.py | 349 ++++++++++++---------- 1 file changed, 193 insertions(+), 156 deletions(-) diff --git a/tests/test_transforms/test_composition.py b/tests/test_transforms/test_composition.py index b5060a7b9e..cbf9097fd2 100644 --- a/tests/test_transforms/test_composition.py +++ b/tests/test_transforms/test_composition.py @@ -1,166 +1,203 @@ from __future__ import annotations +from typing import Any + import numpy as np import pytest +from tests.test_transforms.conftest import Expect, ExpectErr from zarr.core._transforms.composition import compose from zarr.core._transforms.domain import IndexDomain from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap from zarr.core._transforms.transform import IndexTransform - -class TestComposeConstantInner: - """Inner = constant. 
Result is always constant.""" - - def test_constant_inner_any_outer(self) -> None: - outer = IndexTransform.from_shape((5,)) - inner = IndexTransform( - domain=IndexDomain.from_shape((5,)), - output=(ConstantMap(offset=42),), - ) - result = compose(outer, inner) - assert isinstance(result.output[0], ConstantMap) - assert result.output[0].offset == 42 - - -class TestComposeDimensionInner: - """Inner = DimensionMap.""" - - def test_dimension_inner_constant_outer(self) -> None: - outer = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=(ConstantMap(offset=5),), - ) - inner = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=(DimensionMap(input_dimension=0, offset=10, stride=3),), - ) - result = compose(outer, inner) - assert isinstance(result.output[0], ConstantMap) - assert result.output[0].offset == 25 - - def test_dimension_inner_dimension_outer(self) -> None: - outer = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=(DimensionMap(input_dimension=0, offset=5, stride=2),), - ) - inner = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=(DimensionMap(input_dimension=0, offset=10, stride=3),), - ) - result = compose(outer, inner) - assert isinstance(result.output[0], DimensionMap) - assert result.output[0].offset == 25 - assert result.output[0].stride == 6 - assert result.output[0].input_dimension == 0 - - def test_dimension_inner_array_outer(self) -> None: - arr = np.array([0, 2, 4], dtype=np.intp) - outer = IndexTransform( - domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=arr, offset=5, stride=2),), - ) - inner = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=(DimensionMap(input_dimension=0, offset=10, stride=3),), - ) - result = compose(outer, inner) - assert isinstance(result.output[0], ArrayMap) - assert result.output[0].offset == 25 - assert result.output[0].stride == 6 - np.testing.assert_array_equal(result.output[0].index_array, arr) - - -class 
TestComposeArrayInner: - """Inner = ArrayMap.""" - - def test_array_inner_constant_outer(self) -> None: - inner_arr = np.array([10, 20, 30], dtype=np.intp) - outer = IndexTransform( - domain=IndexDomain.from_shape((5,)), - output=(ConstantMap(offset=1),), - ) - inner = IndexTransform( - domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=inner_arr, offset=0, stride=1),), - ) - result = compose(outer, inner) - assert isinstance(result.output[0], ConstantMap) - assert result.output[0].offset == 20 - - def test_array_inner_array_outer(self) -> None: - outer_arr = np.array([0, 2, 1], dtype=np.intp) - inner_arr = np.array([10, 20, 30], dtype=np.intp) - outer = IndexTransform( - domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=outer_arr, offset=0, stride=1),), - ) - inner = IndexTransform( - domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=inner_arr, offset=0, stride=1),), - ) - result = compose(outer, inner) - assert isinstance(result.output[0], ArrayMap) - expected = np.array([10, 30, 20], dtype=np.intp) - np.testing.assert_array_equal(result.output[0].index_array, expected) - - -class TestComposeMultiDim: - def test_2d_identity_compose(self) -> None: - a = IndexTransform.from_shape((10, 20)) - b = IndexTransform.from_shape((10, 20)) - result = compose(a, b) - assert result.domain.shape == (10, 20) - for i in range(2): - m = result.output[i] - assert isinstance(m, DimensionMap) - assert m.input_dimension == i - assert m.offset == 0 - assert m.stride == 1 - - def test_mixed_map_types(self) -> None: - outer = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=( - ConstantMap(offset=5), - DimensionMap(input_dimension=0, offset=0, stride=1), - ), - ) - inner = IndexTransform( - domain=IndexDomain.from_shape((10, 10)), - output=( - DimensionMap(input_dimension=0, offset=2, stride=3), - DimensionMap(input_dimension=1, offset=0, stride=1), - ), - ) - result = compose(outer, inner) - assert 
isinstance(result.output[0], ConstantMap) - assert result.output[0].offset == 17 - assert isinstance(result.output[1], DimensionMap) - assert result.output[1].input_dimension == 0 - assert result.output[1].offset == 0 - assert result.output[1].stride == 1 - - def test_rank_mismatch_raises(self) -> None: - outer = IndexTransform.from_shape((10,)) - inner = IndexTransform.from_shape((10, 20)) - with pytest.raises(ValueError, match="rank"): - compose(outer, inner) - - -class TestComposeChain: - def test_three_transforms(self) -> None: - a = IndexTransform.from_shape((100,)) - b = IndexTransform( - domain=IndexDomain.from_shape((100,)), - output=(DimensionMap(input_dimension=0, offset=10, stride=1),), - ) - c = IndexTransform( - domain=IndexDomain.from_shape((100,)), - output=(DimensionMap(input_dimension=0, offset=5, stride=2),), - ) - bc = compose(b, c) - abc = compose(a, bc) - assert isinstance(abc.output[0], DimensionMap) - assert abc.output[0].offset == 25 - assert abc.output[0].stride == 2 +# Inner = ConstantMap: result is always ConstantMap regardless of outer. +_constant_inner = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(ConstantMap(offset=42),), +) +_identity_outer_5 = IndexTransform.from_shape((5,)) + +# Inner = DimensionMap with various outers. +_dimension_inner_0_10_3 = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=10, stride=3),), +) +_constant_outer_5 = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), +) +_dimension_outer_0_5_2 = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=5, stride=2),), +) +_array_outer_arr = np.array([0, 2, 4], dtype=np.intp) +_array_outer = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=_array_outer_arr, offset=5, stride=2),), +) + +# Inner = ArrayMap with various outers. 
+_array_inner_arr = np.array([10, 20, 30], dtype=np.intp) +_array_inner = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=_array_inner_arr, offset=0, stride=1),), +) +_constant_outer_1 = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(ConstantMap(offset=1),), +) +_array_outer_for_array_inner = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=np.array([0, 2, 1], dtype=np.intp), offset=0, stride=1),), +) + + +@pytest.mark.parametrize( + "case", + [ + # Inner = ConstantMap. Result is constant with the inner's offset, regardless of outer. + Expect( + input=(_identity_outer_5, _constant_inner), + expected={"kind": ConstantMap, "offset": 42}, + id="constant-inner-identity-outer", + ), + # Inner = DimensionMap. + Expect( + input=(_constant_outer_5, _dimension_inner_0_10_3), + expected={"kind": ConstantMap, "offset": 25}, + id="dimension-inner-constant-outer", + ), + Expect( + input=(_dimension_outer_0_5_2, _dimension_inner_0_10_3), + expected={ + "kind": DimensionMap, + "offset": 25, + "stride": 6, + "input_dimension": 0, + }, + id="dimension-inner-dimension-outer", + ), + Expect( + input=(_array_outer, _dimension_inner_0_10_3), + expected={ + "kind": ArrayMap, + "offset": 25, + "stride": 6, + "index_array": _array_outer_arr, + }, + id="dimension-inner-array-outer", + ), + # Inner = ArrayMap. 
+ Expect( + input=(_constant_outer_1, _array_inner), + expected={"kind": ConstantMap, "offset": 20}, + id="array-inner-constant-outer", + ), + Expect( + input=(_array_outer_for_array_inner, _array_inner), + expected={ + "kind": ArrayMap, + "offset": 0, + "stride": 1, + "index_array": np.array([10, 30, 20], dtype=np.intp), + }, + id="array-inner-array-outer", + ), + ], + ids=lambda c: c.id, +) +def test_compose_success( + case: Expect[tuple[IndexTransform, IndexTransform], dict[str, Any]], +) -> None: + """compose dispatches over (inner_kind, outer_kind) pairs and produces the expected result map.""" + outer, inner = case.input + result = compose(outer, inner) + assert len(result.output) == 1 + out0 = result.output[0] + assert isinstance(out0, case.expected["kind"]) + if "offset" in case.expected: + assert out0.offset == case.expected["offset"] + if "stride" in case.expected: + assert isinstance(out0, (DimensionMap, ArrayMap)) + assert out0.stride == case.expected["stride"] + if "input_dimension" in case.expected: + assert isinstance(out0, DimensionMap) + assert out0.input_dimension == case.expected["input_dimension"] + if "index_array" in case.expected: + assert isinstance(out0, ArrayMap) + np.testing.assert_array_equal(out0.index_array, case.expected["index_array"]) + + +def test_compose_2d_identity() -> None: + """Composing two identity 2D transforms yields a 2D identity.""" + a = IndexTransform.from_shape((10, 20)) + b = IndexTransform.from_shape((10, 20)) + result = compose(a, b) + assert result.domain.shape == (10, 20) + for i, m in enumerate(result.output): + assert isinstance(m, DimensionMap) + assert m.input_dimension == i + assert m.offset == 0 + assert m.stride == 1 + + +def test_compose_mixed_map_types() -> None: + """Outer has heterogeneous output maps; each composes independently with its inner image.""" + outer = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=( + ConstantMap(offset=5), + DimensionMap(input_dimension=0, offset=0, 
stride=1), + ), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((10, 10)), + output=( + DimensionMap(input_dimension=0, offset=2, stride=3), + DimensionMap(input_dimension=1, offset=0, stride=1), + ), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 17 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].input_dimension == 0 + assert result.output[1].offset == 0 + assert result.output[1].stride == 1 + + +def test_compose_chains_associatively() -> None: + """compose(a, compose(b, c)) yields the same offsets/strides as composing in order.""" + a = IndexTransform.from_shape((100,)) + b = IndexTransform( + domain=IndexDomain.from_shape((100,)), + output=(DimensionMap(input_dimension=0, offset=10, stride=1),), + ) + c = IndexTransform( + domain=IndexDomain.from_shape((100,)), + output=(DimensionMap(input_dimension=0, offset=5, stride=2),), + ) + abc = compose(a, compose(b, c)) + assert isinstance(abc.output[0], DimensionMap) + assert abc.output[0].offset == 25 + assert abc.output[0].stride == 2 + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexTransform.from_shape((10,)), IndexTransform.from_shape((10, 20))), + msg="rank", + exception_cls=ValueError, + id="outer-output-rank-vs-inner-input-rank-mismatch", + ), + ], + ids=lambda c: c.id, +) +def test_compose_errors(case: ExpectErr[tuple[IndexTransform, IndexTransform]]) -> None: + """compose rejects rank-mismatched outer/inner pairs.""" + outer, inner = case.input + with pytest.raises(case.exception_cls, match=case.msg): + compose(outer, inner) From 20df22cc9a0151a6a086fdad1103025c97c16446 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 7 May 2026 10:17:11 -0400 Subject: [PATCH 14/24] test(_transforms): rewrite chunk_resolution tests in parametrized style iter_chunk_transforms cases parametrized over (transform, grid) inputs; sub_transform_to_selections cases parametrized 
over the output-map-kind matrix. Adds 3 cases that exercise the previously-untested out_indices parameter on sub_transform_to_selections: orthogonal-array-with-out_indices, vectorized-with-out_indices, and the implicit fallback when out_indices is None. --- .../test_transforms/test_chunk_resolution.py | 406 +++++++++++------- 1 file changed, 241 insertions(+), 165 deletions(-) diff --git a/tests/test_transforms/test_chunk_resolution.py b/tests/test_transforms/test_chunk_resolution.py index ba27964028..ffc196daaa 100644 --- a/tests/test_transforms/test_chunk_resolution.py +++ b/tests/test_transforms/test_chunk_resolution.py @@ -1,7 +1,11 @@ from __future__ import annotations +from typing import Any + import numpy as np +import pytest +from tests.test_transforms.conftest import Expect from zarr.core._transforms.chunk_resolution import ( iter_chunk_transforms, sub_transform_to_selections, @@ -11,171 +15,243 @@ from zarr.core._transforms.transform import IndexTransform from zarr.core.chunk_grids import ChunkGrid, FixedDimension +# --------------------------------------------------------------------------- +# iter_chunk_transforms — for a transform composed against a ChunkGrid, yield +# (chunk_coords, sub_transform, out_indices) for each touched chunk. 
+# --------------------------------------------------------------------------- -class TestChunkResolutionIdentity: - def test_single_chunk(self) -> None: - """Array fits in one chunk.""" - t = IndexTransform.from_shape((10,)) - grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=10),)) - results = list(iter_chunk_transforms(t, grid)) - assert len(results) == 1 - coords, sub_t, _ = results[0] - assert coords == (0,) - assert sub_t.domain.shape == (10,) - - def test_multiple_chunks_1d(self) -> None: - """1D array spanning 3 chunks.""" - t = IndexTransform.from_shape((30,)) - grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=30),)) - results = list(iter_chunk_transforms(t, grid)) - assert len(results) == 3 - coords_list = [r[0] for r in results] - assert (0,) in coords_list - assert (1,) in coords_list - assert (2,) in coords_list - - def test_multiple_chunks_2d(self) -> None: - """2D array spanning 2x3 chunks.""" - t = IndexTransform.from_shape((20, 30)) - grid = ChunkGrid( - dimensions=( - FixedDimension(size=10, extent=20), - FixedDimension(size=10, extent=30), - ) - ) - results = list(iter_chunk_transforms(t, grid)) - assert len(results) == 6 - coords_list = [r[0] for r in results] - assert (0, 0) in coords_list - assert (1, 2) in coords_list - - -class TestChunkResolutionSliced: - def test_slice_within_chunk(self) -> None: - """Slice that falls within a single chunk.""" - t = IndexTransform.from_shape((100,))[5:8] - grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=100),)) - results = list(iter_chunk_transforms(t, grid)) - assert len(results) == 1 - coords, sub_t, _ = results[0] - assert coords == (0,) - assert isinstance(sub_t.output[0], DimensionMap) - assert sub_t.output[0].offset == 5 - - def test_slice_across_chunks(self) -> None: - """Slice that spans two chunks.""" - t = IndexTransform.from_shape((100,))[8:15] - grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=100),)) - results = list(iter_chunk_transforms(t, grid)) - 
assert len(results) == 2 - coords_list = [r[0] for r in results] - assert (0,) in coords_list - assert (1,) in coords_list - - -class TestChunkResolutionConstant: - def test_integer_index(self) -> None: - """Integer index produces constant map — single chunk per constant dim.""" - t = IndexTransform.from_shape((100, 100))[25, :] - grid = ChunkGrid( - dimensions=( - FixedDimension(size=10, extent=100), - FixedDimension(size=10, extent=100), - ) - ) - results = list(iter_chunk_transforms(t, grid)) - assert len(results) == 10 - for coords, _, _ in results: - assert coords[0] == 2 - - -class TestChunkResolutionArray: - def test_array_index(self) -> None: - """Array index map — chunks determined by array values.""" - idx = np.array([5, 15, 25], dtype=np.intp) - t = IndexTransform( - domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=idx),), - ) - grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=30),)) - results = list(iter_chunk_transforms(t, grid)) - coords_list = [r[0] for r in results] - assert (0,) in coords_list - assert (1,) in coords_list - assert (2,) in coords_list - - -class TestSubTransformToSelections: - def test_constant_map(self) -> None: - """ConstantMap produces int selection + drop axis.""" - t = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=(ConstantMap(offset=5),), - ) - chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) - assert chunk_sel == (5,) - assert out_sel == () - assert drop_axes == () - - def test_dimension_map_stride_1(self) -> None: - """DimensionMap with stride=1 produces contiguous slice.""" - t = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=(DimensionMap(input_dimension=0, offset=3, stride=1),), - ) - chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) - assert chunk_sel == (slice(3, 13, 1),) - assert out_sel == (slice(0, 10),) - assert drop_axes == () - - def test_dimension_map_strided(self) -> None: - """DimensionMap with stride>1 produces 
strided slice.""" - t = IndexTransform( - domain=IndexDomain.from_shape((5,)), - output=(DimensionMap(input_dimension=0, offset=2, stride=3),), - ) - chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) - assert chunk_sel == (slice(2, 17, 3),) - assert out_sel == (slice(0, 5),) - assert drop_axes == () - - def test_array_map(self) -> None: - """ArrayMap produces integer array selection.""" - arr = np.array([1, 5, 9], dtype=np.intp) - t = IndexTransform( - domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=arr, offset=0, stride=1),), - ) - chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) - assert isinstance(chunk_sel[0], np.ndarray) - np.testing.assert_array_equal(chunk_sel[0], arr) - # Without chunk_mask, out_sel falls back to domain-based slices - assert out_sel == (slice(0, 3),) - assert drop_axes == () - - def test_array_map_with_offset_stride(self) -> None: - """ArrayMap with offset and stride computes storage coords.""" - arr = np.array([0, 1, 2], dtype=np.intp) - t = IndexTransform( - domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=arr, offset=10, stride=5),), + +def _grid_1d(size: int, extent: int) -> ChunkGrid: + return ChunkGrid(dimensions=(FixedDimension(size=size, extent=extent),)) + + +def _grid_2d(size0: int, extent0: int, size1: int, extent1: int) -> ChunkGrid: + return ChunkGrid( + dimensions=( + FixedDimension(size=size0, extent=extent0), + FixedDimension(size=size1, extent=extent1), ) - chunk_sel, _out_sel, drop_axes = sub_transform_to_selections(t) - assert isinstance(chunk_sel[0], np.ndarray) - np.testing.assert_array_equal(chunk_sel[0], np.array([10, 15, 20])) - assert drop_axes == () - - def test_mixed_maps_2d(self) -> None: - """Mix of ConstantMap and DimensionMap.""" - t = IndexTransform( - domain=IndexDomain.from_shape((10,)), - output=( - ConstantMap(offset=5), - DimensionMap(input_dimension=0, offset=0, stride=1), + ) + + +@pytest.mark.parametrize( + "case", + [ + Expect( + 
input=(IndexTransform.from_shape((10,)), _grid_1d(10, 10)), + expected={"n_chunks": 1, "coords": [(0,)]}, + id="single-chunk-fits-array", + ), + Expect( + input=(IndexTransform.from_shape((30,)), _grid_1d(10, 30)), + expected={"n_chunks": 3, "coords": [(0,), (1,), (2,)]}, + id="three-chunks-1d", + ), + Expect( + input=(IndexTransform.from_shape((20, 30)), _grid_2d(10, 20, 10, 30)), + expected={ + "n_chunks": 6, + "coords": [(i, j) for i in (0, 1) for j in (0, 1, 2)], + }, + id="six-chunks-2x3", + ), + Expect( + input=(IndexTransform.from_shape((100,))[5:8], _grid_1d(10, 100)), + expected={"n_chunks": 1, "coords": [(0,)]}, + id="slice-within-chunk", + ), + Expect( + input=(IndexTransform.from_shape((100,))[8:15], _grid_1d(10, 100)), + expected={"n_chunks": 2, "coords": [(0,), (1,)]}, + id="slice-across-two-chunks", + ), + ], + ids=lambda c: c.id, +) +def test_iter_chunk_transforms_yields_expected_chunks( + case: Expect[tuple[IndexTransform, ChunkGrid], dict[str, Any]], +) -> None: + """iter_chunk_transforms enumerates all chunks intersected by the transform.""" + transform, grid = case.input + results = list(iter_chunk_transforms(transform, grid)) + assert len(results) == case.expected["n_chunks"] + coords_list = [r[0] for r in results] + for expected_coord in case.expected["coords"]: + assert expected_coord in coords_list + + +def test_iter_chunk_transforms_constant_map_picks_single_chunk_per_dim() -> None: + """An integer index produces a ConstantMap, fixing the chunk on that dim. + + arr[25, :] over a 10-element chunk grid: chunk index for storage 25 is 2, + so every chunk yielded has coords[0] == 2. 
The free dim (the slice) iterates.""" + t = IndexTransform.from_shape((100, 100))[25, :] + grid = _grid_2d(10, 100, 10, 100) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 10 + for coords, _, _ in results: + assert coords[0] == 2 + + +def test_iter_chunk_transforms_array_map_lists_chunks_for_array_entries() -> None: + """An ArrayMap yields chunks for each unique chunk-id of its index_array entries.""" + idx = np.array([5, 15, 25], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=idx),), + ) + results = list(iter_chunk_transforms(t, _grid_1d(10, 30))) + coords_list = [r[0] for r in results] + assert (0,) in coords_list + assert (1,) in coords_list + assert (2,) in coords_list + + +def test_iter_chunk_transforms_within_chunk_offset_is_local() -> None: + """The yielded sub-transform's output is in chunk-local coordinates, + so a slice arr[5:8] in chunk 0 yields offset=5 (the offset within the chunk).""" + t = IndexTransform.from_shape((100,))[5:8] + grid = _grid_1d(10, 100) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 1 + _, sub_t, _ = results[0] + assert isinstance(sub_t.output[0], DimensionMap) + assert sub_t.output[0].offset == 5 + + +# --------------------------------------------------------------------------- +# sub_transform_to_selections — convert a chunk-local sub-transform into +# (chunk_selection, out_selection, drop_axes) tuples for the codec pipeline. 
+# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), ), - ) - chunk_sel, _out_sel, drop_axes = sub_transform_to_selections(t) - assert chunk_sel[0] == 5 - assert chunk_sel[1] == slice(0, 10, 1) - # drop_axes is empty — integer in chunk_sel naturally drops the dim via numpy - assert drop_axes == () + expected={ + "chunk_sel": (5,), + "out_sel": (), + "drop_axes": (), + }, + id="constant-map-yields-int-selection-no-out", + ), + Expect( + input=IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=3, stride=1),), + ), + expected={ + "chunk_sel": (slice(3, 13, 1),), + "out_sel": (slice(0, 10),), + "drop_axes": (), + }, + id="dimension-map-stride-1-yields-contiguous-slice", + ), + Expect( + input=IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(DimensionMap(input_dimension=0, offset=2, stride=3),), + ), + expected={ + "chunk_sel": (slice(2, 17, 3),), + "out_sel": (slice(0, 5),), + "drop_axes": (), + }, + id="dimension-map-strided-yields-strided-slice", + ), + Expect( + input=IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=( + ConstantMap(offset=5), + DimensionMap(input_dimension=0, offset=0, stride=1), + ), + ), + expected={ + "chunk_sel_kinds": (int, slice), + "chunk_sel_values": (5, slice(0, 10, 1)), + "drop_axes": (), + }, + id="mixed-2d-constant-and-dimension", + ), + ], + ids=lambda c: c.id, +) +def test_sub_transform_to_selections_basic(case: Expect[IndexTransform, dict[str, Any]]) -> None: + """sub_transform_to_selections produces the expected (chunk_sel, out_sel, drop_axes) for non-array maps.""" + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(case.input) + if "chunk_sel" in case.expected: + assert chunk_sel == case.expected["chunk_sel"] + if "chunk_sel_kinds" in case.expected: + for got, 
expected_kind in zip(chunk_sel, case.expected["chunk_sel_kinds"], strict=True): + assert isinstance(got, expected_kind) + if "chunk_sel_values" in case.expected: + for got, expected_val in zip(chunk_sel, case.expected["chunk_sel_values"], strict=True): + assert got == expected_val + if "out_sel" in case.expected: + assert out_sel == case.expected["out_sel"] + assert drop_axes == case.expected["drop_axes"] + + +def test_sub_transform_to_selections_array_map_no_offset() -> None: + """An ArrayMap with offset=0, stride=1 produces the index_array itself as chunk_sel.""" + arr = np.array([1, 5, 9], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr, offset=0, stride=1),), + ) + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) + assert isinstance(chunk_sel[0], np.ndarray) + np.testing.assert_array_equal(chunk_sel[0], arr) + # Without out_indices, out_sel falls back to a domain-derived slice. + assert out_sel == (slice(0, 3),) + assert drop_axes == () + + +def test_sub_transform_to_selections_array_map_with_offset_stride() -> None: + """An ArrayMap with non-zero offset/stride is materialized into storage coords.""" + arr = np.array([0, 1, 2], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr, offset=10, stride=5),), + ) + chunk_sel, _out_sel, drop_axes = sub_transform_to_selections(t) + assert isinstance(chunk_sel[0], np.ndarray) + np.testing.assert_array_equal(chunk_sel[0], np.array([10, 15, 20])) + assert drop_axes == () + + +def test_sub_transform_to_selections_orthogonal_array_with_out_indices() -> None: + """When out_indices is supplied with a single ArrayMap (orthogonal mode), + out_sel uses the supplied scatter indices rather than a domain slice.""" + arr = np.array([1, 5, 9], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr),), + ) + out_indices = np.array([0, 2], 
dtype=np.intp) + _chunk_sel, out_sel, _drop_axes = sub_transform_to_selections(t, out_indices) + assert len(out_sel) == 1 + assert isinstance(out_sel[0], np.ndarray) + np.testing.assert_array_equal(out_sel[0], out_indices) + + +def test_sub_transform_to_selections_vectorized_with_out_indices() -> None: + """When out_indices is supplied with 2+ correlated ArrayMaps (vectorized mode), + out_sel collapses to a single shared scatter array.""" + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=( + ArrayMap(index_array=np.array([1, 5, 9], dtype=np.intp)), + ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp)), + ), + ) + out_indices = np.array([0, 1], dtype=np.intp) + _chunk_sel, out_sel, _drop_axes = sub_transform_to_selections(t, out_indices) + assert len(out_sel) == 1 + np.testing.assert_array_equal(out_sel[0], out_indices) From 2d659a1be09aad3d790b1272d52a99603394f45c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 7 May 2026 10:18:48 -0400 Subject: [PATCH 15/24] test(_transforms): fix mypy errors in cross-file mypy run When mypy runs over the rewritten test files together (rather than one at a time), it tightens up two type checks that were lax in isolation: - selection_to_transform's mode parameter is a Literal, not str. - assert_array_equal expects ndarray, but out_sel[0] is typed as slice | ndarray. Add an isinstance check. These were latent issues in the rewrites; they only surfaced when mypy saw all the test files at once. 
--- tests/test_transforms/test_chunk_resolution.py | 1 + tests/test_transforms/test_transform.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_transforms/test_chunk_resolution.py b/tests/test_transforms/test_chunk_resolution.py index ffc196daaa..367f8b8d32 100644 --- a/tests/test_transforms/test_chunk_resolution.py +++ b/tests/test_transforms/test_chunk_resolution.py @@ -254,4 +254,5 @@ def test_sub_transform_to_selections_vectorized_with_out_indices() -> None: out_indices = np.array([0, 1], dtype=np.intp) _chunk_sel, out_sel, _drop_axes = sub_transform_to_selections(t, out_indices) assert len(out_sel) == 1 + assert isinstance(out_sel[0], np.ndarray) np.testing.assert_array_equal(out_sel[0], out_indices) diff --git a/tests/test_transforms/test_transform.py b/tests/test_transforms/test_transform.py index ad8e9f925b..44a02b6a20 100644 --- a/tests/test_transforms/test_transform.py +++ b/tests/test_transforms/test_transform.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, Literal import numpy as np import pytest @@ -475,7 +475,9 @@ def test_vindex_success(case: Expect[tuple[IndexTransform, Any], tuple[int, ...] ids=lambda c: c.id, ) def test_selection_to_transform_success( - case: Expect[tuple[IndexTransform, Any, str], dict[str, Any]], + case: Expect[ + tuple[IndexTransform, Any, Literal["basic", "orthogonal", "vectorized"]], dict[str, Any] + ], ) -> None: """selection_to_transform dispatches to basic/orthogonal/vectorized correctly.""" transform, selection, mode = case.input From f7d05da68ee84a2d40139401fea6dbf062dba214 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 7 May 2026 10:40:58 -0400 Subject: [PATCH 16/24] docs(_transforms): convert remaining RST code block to markdown in domain.py Missed by the earlier polish commit (fca29269): the docstring at the top of domain.py used an RST `::` indented code block. 
Convert to a fenced markdown block to match the rest of the package's docstrings. --- src/zarr/core/_transforms/domain.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/zarr/core/_transforms/domain.py b/src/zarr/core/_transforms/domain.py index e43ec30836..e2f9175225 100644 --- a/src/zarr/core/_transforms/domain.py +++ b/src/zarr/core/_transforms/domain.py @@ -1,10 +1,12 @@ """Index domains — rectangular regions in N-dimensional integer space. An `IndexDomain` represents the set of valid coordinates for an array or -array view. It is the cartesian product of per-dimension integer ranges:: +array view. It is the cartesian product of per-dimension integer ranges: - IndexDomain(inclusive_min=(2, 5), exclusive_max=(10, 20)) - # represents {(i, j) : 2 <= i < 10, 5 <= j < 20} +```python +IndexDomain(inclusive_min=(2, 5), exclusive_max=(10, 20)) +# represents {(i, j) : 2 <= i < 10, 5 <= j < 20} +``` Unlike NumPy, domains can have **non-zero origins**. After slicing `arr[5:10]`, the result has origin 5 and shape 5 — coordinates 5 through From 895b1df9f0962b48cb3c74ac9624f02d1d9e6e22 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 7 May 2026 10:41:25 -0400 Subject: [PATCH 17/24] test(_transforms): add error tests for previously-untested public branches MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Code review surfaced six public functions/methods with reachable error branches and no test coverage. Add parametrized error tests (or standalone tests, where there is one trivial input) to close the gaps: - IndexTransform.oindex: negative slice step raises IndexError. - IndexTransform.vindex: negative slice step raises IndexError. - IndexTransform.intersect: rank-mismatched output_domain raises ValueError. - IndexTransform.translate: wrong-length shift raises ValueError. - selection_to_transform: unknown mode string raises ValueError. 
- compose: 2D ArrayMap inner with non-constant outer raises NotImplementedError. Test count: 145 → 152. --- tests/test_transforms/test_composition.py | 21 +++++- tests/test_transforms/test_transform.py | 82 +++++++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/tests/test_transforms/test_composition.py b/tests/test_transforms/test_composition.py index cbf9097fd2..6ec1c87214 100644 --- a/tests/test_transforms/test_composition.py +++ b/tests/test_transforms/test_composition.py @@ -193,11 +193,30 @@ def test_compose_chains_associatively() -> None: exception_cls=ValueError, id="outer-output-rank-vs-inner-input-rank-mismatch", ), + ExpectErr( + input=( + # Outer is a non-constant 2D identity transform. + IndexTransform.from_shape((3, 2)), + # Inner has a 2D ArrayMap. _compose_array's general multi-dim + # path raises NotImplementedError for this combination. + IndexTransform( + domain=IndexDomain.from_shape((3, 2)), + output=( + ArrayMap( + index_array=np.array([[10, 20], [30, 40], [50, 60]], dtype=np.intp) + ), + ), + ), + ), + msg="not yet supported", + exception_cls=NotImplementedError, + id="multi-d-array-inner-non-constant-outer", + ), ], ids=lambda c: c.id, ) def test_compose_errors(case: ExpectErr[tuple[IndexTransform, IndexTransform]]) -> None: - """compose rejects rank-mismatched outer/inner pairs.""" + """compose raises on rank mismatch and on the unsupported multi-d-array compose path.""" outer, inner = case.input with pytest.raises(case.exception_cls, match=case.msg): compose(outer, inner) diff --git a/tests/test_transforms/test_transform.py b/tests/test_transforms/test_transform.py index 44a02b6a20..b07f14839d 100644 --- a/tests/test_transforms/test_transform.py +++ b/tests/test_transforms/test_transform.py @@ -369,6 +369,25 @@ def test_oindex_success(case: Expect[tuple[IndexTransform, Any], dict[str, Any]] assert isinstance(result.output[1], case.expected["out1_kind"]) +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + 
input=(IndexTransform.from_shape((10,)), (slice(None, None, -1),)), + msg="slice step must be positive", + exception_cls=IndexError, + id="negative-slice-step", + ), + ], + ids=lambda c: c.id, +) +def test_oindex_errors(case: ExpectErr[tuple[IndexTransform, Any]]) -> None: + """IndexTransform.oindex rejects non-positive slice steps.""" + transform, selection = case.input + with pytest.raises(case.exception_cls, match=case.msg): + transform.oindex[selection] + + # --------------------------------------------------------------------------- # vindex (vectorized indexing) # --------------------------------------------------------------------------- @@ -424,6 +443,25 @@ def test_vindex_success(case: Expect[tuple[IndexTransform, Any], tuple[int, ...] assert result.domain.shape == case.expected +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexTransform.from_shape((10,)), (slice(None, None, -1),)), + msg="slice step must be positive", + exception_cls=IndexError, + id="negative-slice-step", + ), + ], + ids=lambda c: c.id, +) +def test_vindex_errors(case: ExpectErr[tuple[IndexTransform, Any]]) -> None: + """IndexTransform.vindex rejects non-positive slice steps.""" + transform, selection = case.input + with pytest.raises(case.exception_cls, match=case.msg): + transform.vindex[selection] + + # --------------------------------------------------------------------------- # selection_to_transform — the public dispatch front door for all three modes. # Sanity check that each mode produces the expected output kind. @@ -486,6 +524,17 @@ def test_selection_to_transform_success( assert isinstance(result.output[0], case.expected["out0_kind"]) +def test_selection_to_transform_unknown_mode_errors() -> None: + """selection_to_transform rejects unknown indexing modes. + + The `mode` parameter is typed as `Literal["basic", "orthogonal", "vectorized"]`, + so this test bypasses static type checking to exercise the runtime guard. 
+ """ + t = IndexTransform.from_shape((10,)) + with pytest.raises(ValueError, match="Unknown mode"): + selection_to_transform(slice(None), t, "diagonal") + + # --------------------------------------------------------------------------- # intersect — restrict an output domain. Returns (sub_transform, surviving) # or None when the intersection is empty. @@ -595,6 +644,14 @@ def test_intersect_2d_mixed_constant_and_dimension() -> None: assert restricted.domain.exclusive_max == (10,) +def test_intersect_rank_mismatch_errors() -> None: + """intersect rejects an output_domain whose rank differs from the transform's output rank.""" + t = IndexTransform.from_shape((10,)) # output rank 1 + chunk = IndexDomain.from_shape((10, 20)) # rank 2 + with pytest.raises(ValueError, match="output rank"): + t.intersect(chunk) + + # --------------------------------------------------------------------------- # Direct tests for _intersect_vectorized. # @@ -712,3 +769,28 @@ def test_translate_2d() -> None: assert out0.offset == -5 assert isinstance(out1, DimensionMap) assert out1.offset == -10 + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexTransform.from_shape((10, 20)), (1,)), + msg="shift must have length", + exception_cls=ValueError, + id="shift-too-short", + ), + ExpectErr( + input=(IndexTransform.from_shape((10,)), (1, 2)), + msg="shift must have length", + exception_cls=ValueError, + id="shift-too-long", + ), + ], + ids=lambda c: c.id, +) +def test_translate_errors(case: ExpectErr[tuple[IndexTransform, tuple[int, ...]]]) -> None: + """IndexTransform.translate rejects shifts whose length doesn't match output_rank.""" + transform, shift = case.input + with pytest.raises(case.exception_cls, match=case.msg): + transform.translate(shift) From edae196bcb3ec6db59f422fa2c2c26f1a5c14848 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 7 May 2026 10:55:18 -0400 Subject: [PATCH 18/24] test(_transforms): add hypothesis property test for compose associativity 
An IndexTransform models a function from input coords to output coords; function composition is associative by definition. This test verifies that the implementation preserves that algebraic property by sampling random affine triples (a, b, c) with compatible ranks and checking that compose(compose(a, b), c) evaluates the same as compose(a, compose(b, c)) at random points in a's domain. 200 hypothesis examples; passes cleanly. Restricted to DimensionMap + ConstantMap outputs (the affine subset). ArrayMap composition has implementation-level branching that depends on outer structure and would need a more careful generator to avoid the NotImplementedError path; the affine case is the algebraic core, the ArrayMap case is deferred to a follow-up. --- tests/test_transforms/test_composition.py | 122 ++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/tests/test_transforms/test_composition.py b/tests/test_transforms/test_composition.py index 6ec1c87214..7a9649d0b3 100644 --- a/tests/test_transforms/test_composition.py +++ b/tests/test_transforms/test_composition.py @@ -220,3 +220,125 @@ def test_compose_errors(case: ExpectErr[tuple[IndexTransform, IndexTransform]]) outer, inner = case.input with pytest.raises(case.exception_cls, match=case.msg): compose(outer, inner) + + +# --------------------------------------------------------------------------- +# Associativity property test. +# +# An IndexTransform models a function from input coords to output coords; +# function composition is associative by definition. Verify the implementation +# preserves that algebraic property by sampling random affine triples +# `(a, b, c)` with compatible ranks and checking that +# compose(compose(a, b), c) +# evaluates the same as +# compose(a, compose(b, c)) +# at randomly-chosen points in `a`'s domain. +# +# Restricted to DimensionMap + ConstantMap outputs (the affine subset). 
+# ArrayMap composition has implementation-level branching that depends on +# outer structure, and would need a more careful generator to avoid the +# NotImplementedError path; saved for a follow-up. +# --------------------------------------------------------------------------- + +pytest.importorskip("hypothesis") + +from hypothesis import assume, given, settings # noqa: E402 +from hypothesis import strategies as st # noqa: E402 + + +def _evaluate(transform: IndexTransform, user_coord: tuple[int, ...]) -> tuple[int, ...]: + """Evaluate a transform at a single input coordinate. + + Restricted to DimensionMap + ConstantMap outputs; `ArrayMap` is unsupported + here because the property test only generates affine triples. + """ + storage: list[int] = [] + for m in transform.output: + if isinstance(m, ConstantMap): + storage.append(m.offset) + elif isinstance(m, DimensionMap): + storage.append(m.offset + m.stride * user_coord[m.input_dimension]) + else: + raise TypeError(f"property test should not generate {type(m).__name__}; got {m!r}") + return tuple(storage) + + +def _affine_output_map(input_rank: int, draw: st.DrawFn) -> ConstantMap | DimensionMap: + """Generate one ConstantMap or DimensionMap output map. + + DimensionMap requires input_rank >= 1; falls back to ConstantMap otherwise. + Offsets and strides are kept small to avoid integer overflow during + repeated composition. + """ + if input_rank == 0: + return ConstantMap(offset=draw(st.integers(min_value=-10, max_value=10))) + kind = draw(st.sampled_from(["constant", "dimension"])) + if kind == "constant": + return ConstantMap(offset=draw(st.integers(min_value=-10, max_value=10))) + # stride must be non-zero; sample sign and magnitude separately. 
+ stride_mag = draw(st.integers(min_value=1, max_value=3)) + stride_sign = draw(st.sampled_from([-1, 1])) + return DimensionMap( + input_dimension=draw(st.integers(min_value=0, max_value=input_rank - 1)), + offset=draw(st.integers(min_value=-10, max_value=10)), + stride=stride_sign * stride_mag, + ) + + +@st.composite +def _affine_transform(draw: st.DrawFn, input_rank: int, output_rank: int) -> IndexTransform: + """Generate an affine IndexTransform with the requested ranks.""" + domain_shape = tuple(draw(st.integers(min_value=1, max_value=8)) for _ in range(input_rank)) + domain = IndexDomain.from_shape(domain_shape) + output = tuple(_affine_output_map(input_rank, draw) for _ in range(output_rank)) + return IndexTransform(domain=domain, output=output) + + +@st.composite +def _affine_triple( + draw: st.DrawFn, +) -> tuple[IndexTransform, IndexTransform, IndexTransform]: + """Generate three rank-compatible affine transforms (a, b, c).""" + m = draw(st.integers(min_value=1, max_value=3)) # a's input rank + n = draw(st.integers(min_value=1, max_value=3)) # a's output / b's input rank + p = draw(st.integers(min_value=1, max_value=3)) # b's output / c's input rank + q = draw(st.integers(min_value=1, max_value=3)) # c's output rank + a = draw(_affine_transform(input_rank=m, output_rank=n)) + b = draw(_affine_transform(input_rank=n, output_rank=p)) + c = draw(_affine_transform(input_rank=p, output_rank=q)) + return a, b, c + + +@settings(max_examples=200, deadline=None) +@given(triple=_affine_triple(), data=st.data()) +def test_compose_is_associative( + triple: tuple[IndexTransform, IndexTransform, IndexTransform], + data: st.DataObject, +) -> None: + """For affine transforms, compose(compose(a,b),c) and compose(a,compose(b,c)) + evaluate identically at every point in a's domain.""" + a, b, c = triple + left = compose(compose(a, b), c) + right = compose(a, compose(b, c)) + + # Sanity: both compositions agree on rank and domain. 
+ assert left.input_rank == right.input_rank + assert left.output_rank == right.output_rank + assert left.domain == right.domain + + # Sample a few points from a's domain and compare evaluations. + if a.input_rank == 0: + coord: tuple[int, ...] = () + else: + coord = tuple( + data.draw( + st.integers( + min_value=a.domain.inclusive_min[d], + max_value=a.domain.exclusive_max[d] - 1, + ) + ) + for d in range(a.input_rank) + ) + assume(a.domain.contains(coord)) + + assert _evaluate(left, coord) == _evaluate(right, coord) From 2a108500acda2c3cb257ac5ad21af5f582c054c4 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 7 May 2026 12:16:01 -0400 Subject: [PATCH 19/24] refactor(_transforms): make ArrayMap.input_dimensions explicit Previously, ArrayMap relied on an implicit convention to encode which input dims it was parameterized by: the array shape was matched against the input domain's shape by length, and "correlation" between two ArrayMaps (vectorized indexing) was inferred from "two ArrayMaps exist in the same transform." This refactor makes parameterization explicit: - ArrayMap gains `input_dimensions: tuple[int, ...]` as a required field. Each axis of the index_array corresponds to one entry in input_dimensions, in order. - IndexTransform.__post_init__ enforces that ArrayMap.index_array.shape equals the input domain's extent on input_dimensions. Out-of-range and duplicate input dims are also rejected. - Two ArrayMaps are correlated iff their input_dimensions overlap. The vectorized-indexing detection in `intersect` and `chunk_resolution` uses this rule rather than the brittle "count >= 2" heuristic. - _intersect_vectorized correctly preserves preserved (non-correlated) input dims when collapsing the correlated dims into a single surviving-points dim. - _intersect on an orthogonal ArrayMap now correctly shrinks the input dim it parameterizes when filtering. 
(Previous code mutated the array without updating the input domain, which would now fail __post_init__.) - _compose_array uses input_dimensions to determine parameterization uniformly; the previous arr.ndim-and-len(outer.output) heuristic is replaced with a clearer single-input-dim case. - _apply_basic_indexing rebuilds ArrayMap input_dimensions correctly for newaxis (preserves the array's parameterization, just shifts the referenced dim index). - _apply_oindex and _apply_vindex set input_dimensions on every newly- constructed ArrayMap. - Drop unused helpers _reindex_array and _reindex_array_oindex. - Test fixtures and assertions updated to use the new ArrayMap shape. Why: the implicit-correlation-by-position convention was a footgun for anyone constructing transforms directly. Making input_dimensions explicit lets the type system enforce shape and correlation invariants, simplifies the composition and intersect implementations, and produces a model that can be coherently explained. All 153 transforms tests pass (including the 200-example associativity property test); 150 baseline indexing tests unchanged. --- src/zarr/core/_transforms/chunk_resolution.py | 16 +- src/zarr/core/_transforms/composition.py | 58 ++-- src/zarr/core/_transforms/output_map.py | 31 +- src/zarr/core/_transforms/transform.py | 277 +++++++++++------- .../test_transforms/test_chunk_resolution.py | 12 +- tests/test_transforms/test_composition.py | 16 +- tests/test_transforms/test_output_map.py | 20 +- tests/test_transforms/test_transform.py | 65 +++- 8 files changed, 340 insertions(+), 155 deletions(-) diff --git a/src/zarr/core/_transforms/chunk_resolution.py b/src/zarr/core/_transforms/chunk_resolution.py index 179e2ed2c8..d9a0420e39 100644 --- a/src/zarr/core/_transforms/chunk_resolution.py +++ b/src/zarr/core/_transforms/chunk_resolution.py @@ -180,11 +180,17 @@ def sub_transform_to_selections( # Build out_sel: one entry per non-dropped output dim. 
out_sel: list[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = [] - # Vectorized: multiple correlated ArrayMaps share one scatter index - is_vectorized = ( - out_indices is not None - and sum(1 for m in sub_transform.output if isinstance(m, ArrayMap)) >= 2 - ) + # Vectorized: 2+ ArrayMaps that share at least one input dimension are + # correlated; they all index into a single shared scatter array. + is_vectorized = False + if out_indices is not None: + seen_input_dims: set[int] = set() + for m in sub_transform.output: + if isinstance(m, ArrayMap): + if seen_input_dims & set(m.input_dimensions): + is_vectorized = True + break + seen_input_dims.update(m.input_dimensions) if is_vectorized: assert out_indices is not None diff --git a/src/zarr/core/_transforms/composition.py b/src/zarr/core/_transforms/composition.py index 40bba89e95..86c05503e2 100644 --- a/src/zarr/core/_transforms/composition.py +++ b/src/zarr/core/_transforms/composition.py @@ -64,6 +64,7 @@ def _compose_dimension(outer: IndexTransform, inner_map: DimensionMap) -> Output if isinstance(outer_map, ArrayMap): return ArrayMap( index_array=outer_map.index_array, + input_dimensions=outer_map.input_dimensions, offset=offset_i + stride_i * outer_map.offset, stride=stride_i * outer_map.stride, ) @@ -74,39 +75,62 @@ def _compose_dimension(outer: IndexTransform, inner_map: DimensionMap) -> Output def _compose_array(outer: IndexTransform, inner_map: ArrayMap) -> OutputIndexMap: """Compose when inner is an ArrayMap. - storage = offset_i + stride_i * arr_i[intermediate] - We need to evaluate arr_i at the intermediate coordinates produced by outer. + storage = offset_i + stride_i * arr_i[intermediate[input_dimensions[0]], + intermediate[input_dimensions[1]], ...] + + For each axis k of arr_i, the corresponding intermediate dim is + inner_map.input_dimensions[k] = d. We need to evaluate arr_i over the + product of `outer.output[d]` for each such d. + + All-constant outer: collapse to a single ConstantMap. 
+ + Single 1-D inner array, single outer output: evaluate arr_i along the + one outer output's parameterization. """ arr_i = inner_map.index_array offset_i = inner_map.offset stride_i = inner_map.stride + in_dims_i = inner_map.input_dimensions - # Check if all outer outputs are constant - all_constant = all(isinstance(m, ConstantMap) for m in outer.output) - - if all_constant: - # Evaluate arr_i at the single constant point - idx = tuple(m.offset for m in outer.output if isinstance(m, ConstantMap)) + # All-constant outer: arr_i is evaluated at a single fixed point. + if all(isinstance(m, ConstantMap) for m in outer.output): + idx = tuple(outer.output[d].offset for d in in_dims_i) value = int(arr_i[idx]) return ConstantMap(offset=offset_i + stride_i * value) - # For 1D inner array with a single outer output (simple case) - if arr_i.ndim == 1 and len(outer.output) == 1: - outer_map = outer.output[0] + # 1-D inner array, single referenced outer output. + if len(in_dims_i) == 1: + dim_i = in_dims_i[0] + outer_map = outer.output[dim_i] if isinstance(outer_map, DimensionMap): - dim_size = outer.domain.shape[outer_map.input_dimension] - user_indices = np.arange(dim_size, dtype=np.intp) + # Evaluate arr_i at the outer DimensionMap's range. + input_d = outer_map.input_dimension + input_lo = outer.domain.inclusive_min[input_d] + input_hi = outer.domain.exclusive_max[input_d] + user_indices = np.arange(input_lo, input_hi, dtype=np.intp) intermediate_vals = outer_map.offset + outer_map.stride * user_indices new_arr = arr_i[intermediate_vals] - return ArrayMap(index_array=new_arr, offset=offset_i, stride=stride_i) + return ArrayMap( + index_array=new_arr, + input_dimensions=(input_d,), + offset=offset_i, + stride=stride_i, + ) if isinstance(outer_map, ArrayMap): + # Evaluate arr_i at outer's array values; new array inherits outer's + # parameterization. 
intermediate_vals = outer_map.offset + outer_map.stride * outer_map.index_array new_arr = arr_i[intermediate_vals] - return ArrayMap(index_array=new_arr, offset=offset_i, stride=stride_i) - - # General multi-dim case: not yet implemented + return ArrayMap( + index_array=new_arr, + input_dimensions=outer_map.input_dimensions, + offset=offset_i, + stride=stride_i, + ) + + # General multi-dim case: not yet implemented. raise NotImplementedError( "Composing a multi-dimensional inner array map with non-constant outer maps " "is not yet supported." diff --git a/src/zarr/core/_transforms/output_map.py b/src/zarr/core/_transforms/output_map.py index f1b32aa95e..c4b2dca6eb 100644 --- a/src/zarr/core/_transforms/output_map.py +++ b/src/zarr/core/_transforms/output_map.py @@ -1,13 +1,15 @@ """Output index maps — three representations of a set of integer coordinates. An output index map describes, for one dimension of storage, which coordinates -an array access will touch. Conceptually it is a **set of integers**. Three +an array access will touch. Conceptually it is a **set of integers** (1-D) +or a structured set of integers parameterized by some input dims. Three representations cover the cases that arise in practice: - `ConstantMap(offset=5)` — a singleton set: `{5}` - `DimensionMap(input_dimension=0, offset=3, stride=2)` over input `[0, 5)` — an arithmetic progression: `{3, 5, 7, 9, 11}` -- `ArrayMap(index_array=[1, 5, 9])` — an explicit enumeration: `{1, 5, 9}` +- `ArrayMap(index_array=[1, 5, 9], input_dimensions=(0,))` — an explicit + enumeration parameterized by input dim 0: `{1, 5, 9}` indexed by `i ∈ [0, 3)`. 
Every output map supports two set-theoretic operations (defined on `IndexTransform`, which provides the input domain context these maps lack): @@ -30,6 +32,14 @@ Collapsing everything to `ArrayMap` would be correct but wasteful — a billion-element slice would materialize a billion coordinates just to group them by chunk, when `DimensionMap` does it with three integers. + +Correlation between `ArrayMap`s is encoded by `input_dimensions`. Two +`ArrayMap`s in the same transform that share an input dim are correlated: +their values at the same input coordinate belong to the same storage point +(this is how vectorized indexing is represented). Two `ArrayMap`s with +disjoint `input_dimensions` are independent (orthogonal-style). The +type-level distinction prevents the older convention of inferring +correlation from array length and rank. """ from __future__ import annotations @@ -69,13 +79,24 @@ class DimensionMap: @dataclass(frozen=True, slots=True) class ArrayMap: - """An explicit enumeration of storage coordinates. + """An explicit enumeration of storage coordinates parameterized by input dims. + + Represents `{offset + stride * index_array[i_d0, i_d1, ...]}` where + `(i_d0, i_d1, ...)` ranges over the input coordinates on the dimensions + listed in `input_dimensions`. + + Shape contract (enforced by the enclosing `IndexTransform.__post_init__`): + `index_array.shape` equals the input domain's extent on the dimensions + in `input_dimensions`, in order. For example, if `input_dimensions=(0, 2)` + and the enclosing transform's domain is `(5, 3, 4)`, then + `index_array.shape == (5, 4)`. - Represents `{offset + stride * index_array[i] : i in input_range}`. - Arises from fancy indexing (e.g., `arr[[1, 5, 9]]` or boolean masks). + Arises from fancy indexing (e.g., `arr.oindex[[1, 5, 9]]`, boolean masks + via vindex, etc.). """ index_array: npt.NDArray[np.intp] + input_dimensions: tuple[int, ...] 
offset: int = 0 stride: int = 1 diff --git a/src/zarr/core/_transforms/transform.py b/src/zarr/core/_transforms/transform.py index f8abfd9285..330b523888 100644 --- a/src/zarr/core/_transforms/transform.py +++ b/src/zarr/core/_transforms/transform.py @@ -64,11 +64,29 @@ def __post_init__(self) -> None: f"output[{i}].input_dimension = {m.input_dimension} " f"is out of range for input rank {self.domain.ndim}" ) - elif isinstance(m, ArrayMap) and m.index_array.ndim > self.domain.ndim: - raise ValueError( - f"output[{i}].index_array has {m.index_array.ndim} dims " - f"but input domain has {self.domain.ndim} dims" - ) + elif isinstance(m, ArrayMap): + # Every input dim referenced must be in range. + for d in m.input_dimensions: + if d < 0 or d >= self.domain.ndim: + raise ValueError( + f"output[{i}].input_dimensions = {m.input_dimensions} " + f"references dimension {d}, out of range for input " + f"rank {self.domain.ndim}" + ) + # No duplicates allowed. + if len(set(m.input_dimensions)) != len(m.input_dimensions): + raise ValueError( + f"output[{i}].input_dimensions = {m.input_dimensions} " + f"contains duplicate dimensions" + ) + # index_array.shape must match the extents on input_dimensions. 
+ expected_shape = tuple(self.domain.shape[d] for d in m.input_dimensions) + if m.index_array.shape != expected_shape: + raise ValueError( + f"output[{i}].index_array.shape = {m.index_array.shape} " + f"does not match expected shape {expected_shape} for " + f"input_dimensions={m.input_dimensions}" + ) @property def input_rank(self) -> int: @@ -128,7 +146,10 @@ def __repr__(self) -> str: elif isinstance(m, DimensionMap): maps.append(f"out[{i}] = {m.offset} + {m.stride} * in[{m.input_dimension}]") elif isinstance(m, ArrayMap): - maps.append(f"out[{i}] = {m.offset} + {m.stride} * arr{m.index_array.shape}[in]") + in_dims = ",".join(f"in[{d}]" for d in m.input_dimensions) + maps.append( + f"out[{i}] = {m.offset} + {m.stride} * arr{m.index_array.shape}[{in_dims}]" + ) maps_str = ", ".join(maps) return f"IndexTransform(domain={self.domain}, {maps_str})" @@ -165,6 +186,7 @@ def translate(self, shift: tuple[int, ...]) -> IndexTransform: new_output.append( ArrayMap( index_array=m.index_array, + input_dimensions=m.input_dimensions, offset=m.offset + s, stride=m.stride, ) @@ -206,10 +228,21 @@ def _intersect( f"transform output rank ({transform.output_rank})" ) - # Check if we have correlated ArrayMaps (vectorized) - array_dims = [i for i, m in enumerate(transform.output) if isinstance(m, ArrayMap)] - if len(array_dims) >= 2: - return _intersect_vectorized(transform, output_domain, array_dims) + # Check if we have correlated ArrayMaps (vectorized). + # Two ArrayMaps are correlated iff they share at least one input dimension. 
+ array_output_dims = [i for i, m in enumerate(transform.output) if isinstance(m, ArrayMap)] + if len(array_output_dims) >= 2: + seen_input_dims: set[int] = set() + is_vectorized = False + for out_d in array_output_dims: + m = transform.output[out_d] + assert isinstance(m, ArrayMap) + if seen_input_dims & set(m.input_dimensions): + is_vectorized = True + break + seen_input_dims.update(m.input_dimensions) + if is_vectorized: + return _intersect_vectorized(transform, output_domain, array_output_dims) # Orthogonal: intersect each output dimension independently new_min = list(transform.domain.inclusive_min) @@ -259,11 +292,22 @@ def _intersect( mask = (storage >= lo) & (storage < hi) if not np.any(mask): return None + # Orthogonal ArrayMap: filter the array and shrink the input dim + # it parameterizes. Only the 1-D case is currently exercised; the + # multi-dim orthogonal ArrayMap path is rejected with a clear error. + if len(m.input_dimensions) != 1: + raise NotImplementedError( + "intersect on a multi-dimensional orthogonal ArrayMap is not yet supported" + ) + (input_d,) = m.input_dimensions surviving_indices = np.nonzero(mask.ravel())[0].astype(np.intp) filtered = m.index_array.ravel()[surviving_indices] + new_min[input_d] = 0 + new_max[input_d] = len(filtered) new_output.append( ArrayMap( index_array=filtered, + input_dimensions=(input_d,), offset=m.offset, stride=m.stride, ) @@ -280,28 +324,31 @@ def _intersect( def _intersect_vectorized( transform: IndexTransform, output_domain: IndexDomain, - array_dims: list[int], + array_output_dims: list[int], ) -> tuple[IndexTransform, np.ndarray[Any, np.dtype[np.intp]] | None] | None: """Intersect a vectorized transform with an output domain. - All ArrayMap outputs are correlated — a point survives only if ALL its - storage coordinates fall within the output domain. + All ArrayMap outputs in `array_output_dims` are correlated — a point + survives only if ALL its storage coordinates fall within the output + domain. 
The correlated ArrayMaps share `input_dimensions`; after + filtering, those shared input dims collapse into a single 1-D domain + `(len(surviving),)`. Any non-correlated input dims (used by + DimensionMap outputs on independent input dims) are preserved. """ - # Compute storage coords per array dim and check bounds simultaneously - n_points: int | None = None + # Compute storage coords per array dim and check bounds simultaneously. masks: list[np.ndarray[Any, np.dtype[np.bool_]]] = [] + correlated_input_dims: set[int] = set() - for out_dim in array_dims: + for out_dim in array_output_dims: m = transform.output[out_dim] assert isinstance(m, ArrayMap) storage = m.offset + m.stride * m.index_array lo = output_domain.inclusive_min[out_dim] hi = output_domain.exclusive_max[out_dim] masks.append((storage >= lo) & (storage < hi)) - if n_points is None: - n_points = storage.size + correlated_input_dims.update(m.input_dimensions) - # A point survives only if it's in-bounds on ALL array dims + # A point survives only if it's in-bounds on ALL array dims. combined_mask = masks[0] for mask in masks[1:]: combined_mask = combined_mask & mask @@ -311,7 +358,20 @@ def _intersect_vectorized( surviving = np.nonzero(combined_mask.ravel())[0].astype(np.intp) - # Build new output maps + # Build the new domain. The correlated input dims collapse into a single + # 1-D dim (the surviving-points dim, placed at index 0). Any input dims + # NOT consumed by the correlated ArrayMaps are preserved in their order. + preserved_input_dims = [ + d for d in range(transform.domain.ndim) if d not in correlated_input_dims + ] + new_inclusive_min = [0] + [transform.domain.inclusive_min[d] for d in preserved_input_dims] + new_exclusive_max = [len(surviving)] + [ + transform.domain.exclusive_max[d] for d in preserved_input_dims + ] + # old input dim -> new input dim + old_to_new: dict[int, int] = {d: i + 1 for i, d in enumerate(preserved_input_dims)} + + # Build new output maps. 
new_output: list[OutputIndexMap] = [] for out_dim, m in enumerate(transform.output): if isinstance(m, ArrayMap): @@ -319,6 +379,7 @@ def _intersect_vectorized( new_output.append( ArrayMap( index_array=filtered, + input_dimensions=(0,), offset=m.offset, stride=m.stride, ) @@ -331,9 +392,23 @@ def _intersect_vectorized( else: return None elif isinstance(m, DimensionMap): - new_output.append(m) + if m.input_dimension in correlated_input_dims: + raise NotImplementedError( + "vectorized intersect with a DimensionMap on a correlated " + "input dim is not supported" + ) + new_output.append( + DimensionMap( + input_dimension=old_to_new[m.input_dimension], + offset=m.offset, + stride=m.stride, + ) + ) - new_domain = IndexDomain.from_shape((len(surviving),)) + new_domain = IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) result = IndexTransform(domain=new_domain, output=tuple(new_output)) return (result, surviving) @@ -378,77 +453,6 @@ def _normalize_basic_selection(selection: Any, ndim: int) -> tuple[int | slice | return tuple(result) -def _reindex_array( - arr: np.ndarray[Any, np.dtype[np.intp]], - normalized: tuple[int | slice | None, ...], - domain: IndexDomain, -) -> np.ndarray[Any, np.dtype[np.intp]]: - """Apply basic indexing operations to an ArrayMap's index_array. - - The array's axes correspond to the transform's input dimensions (0-indexed - over the domain shape). When input dimensions are dropped (int), sliced, - or inserted (newaxis), the array must be updated accordingly. 
- """ - # Build a numpy indexing tuple: one entry per old input dimension - idx: list[Any] = [] - old_dim = 0 - newaxis_positions: list[int] = [] - result_axis = 0 - - for sel in normalized: - if sel is None: - newaxis_positions.append(result_axis) - result_axis += 1 - elif isinstance(sel, int): - if old_dim < arr.ndim: - # Convert absolute domain coordinate to 0-based array index - array_idx = sel - domain.inclusive_min[old_dim] - idx.append(array_idx) - old_dim += 1 - elif isinstance(sel, slice): - if old_dim < arr.ndim: - dim_size = domain.shape[old_dim] - # sel.indices gives 0-based start/stop/step for the array axis - start, stop, step = sel.indices(dim_size) - idx.append(slice(start, stop, step)) - old_dim += 1 - result_axis += 1 - - result = arr[tuple(idx)] if idx else arr - - for pos in newaxis_positions: - result = np.expand_dims(result, axis=pos) - - return np.asarray(result, dtype=np.intp) - - -def _reindex_array_oindex( - arr: np.ndarray[Any, np.dtype[np.intp]], - normalized: tuple[Any, ...] | list[Any], - domain: IndexDomain, -) -> np.ndarray[Any, np.dtype[np.intp]]: - """Apply oindex/vindex selection to an existing ArrayMap's index_array. - - Each old input dimension gets either an array (fancy index that axis) - or a slice applied to the corresponding array axis. 
- """ - idx: list[Any] = [] - for old_dim, sel in enumerate(normalized): - if old_dim >= arr.ndim: - break - if isinstance(sel, np.ndarray): - idx.append(sel) - elif isinstance(sel, slice): - dim_size = domain.shape[old_dim] - start, stop, step = sel.indices(dim_size) - idx.append(slice(start, stop, step)) - else: - idx.append(slice(None)) - - result = arr[tuple(idx)] if idx else arr - return np.asarray(result, dtype=np.intp) - - def _apply_basic_indexing(transform: IndexTransform, selection: Any) -> IndexTransform: """Apply basic indexing (int, slice, ellipsis, newaxis) to an IndexTransform.""" normalized = _normalize_basic_selection(selection, transform.domain.ndim) @@ -537,8 +541,42 @@ def _apply_basic_indexing(transform: IndexTransform, selection: Any) -> IndexTra else: raise RuntimeError(f"unexpected: dimension {d} not handled") elif isinstance(m, ArrayMap): - new_arr = _reindex_array(m.index_array, normalized, transform.domain) - new_output.append(ArrayMap(index_array=new_arr, offset=m.offset, stride=m.stride)) + # The array's axes are labeled by m.input_dimensions, in order. + # For each labeled axis: if the corresponding old input dim is + # dropped (int), select that one entry; if sliced, slice the axis; + # otherwise leave the axis intact. Newaxis insertions don't touch + # the array (they add new input dims not in input_dimensions). 
+ arr_idx: list[Any] = [] + new_input_dims: list[int] = [] + for axis_dim in m.input_dimensions: + if axis_dim in dropped_dims: + array_idx = dim_int_val[axis_dim] - transform.domain.inclusive_min[axis_dim] + arr_idx.append(array_idx) + elif axis_dim in old_to_new_dim: + abs_start, _, step = dim_slice_params[axis_dim] + array_start = abs_start - transform.domain.inclusive_min[axis_dim] + new_size = ( + new_exclusive_max[old_to_new_dim[axis_dim]] + - new_inclusive_min[old_to_new_dim[axis_dim]] + ) + array_stop = array_start + step * new_size + arr_idx.append(slice(array_start, array_stop, step)) + new_input_dims.append(old_to_new_dim[axis_dim]) + else: + raise RuntimeError( + f"unexpected: ArrayMap input_dim {axis_dim} not in " + "dropped_dims or old_to_new_dim" + ) + new_arr = m.index_array[tuple(arr_idx)] if arr_idx else m.index_array + new_arr = np.asarray(new_arr, dtype=np.intp) + new_output.append( + ArrayMap( + index_array=new_arr, + input_dimensions=tuple(new_input_dims), + offset=m.offset, + stride=m.stride, + ) + ) return IndexTransform(domain=new_domain, output=tuple(new_output)) @@ -646,6 +684,7 @@ def _apply_oindex(transform: IndexTransform, selection: Any) -> IndexTransform: new_output.append( ArrayMap( index_array=dim_array[d], + input_dimensions=(old_to_new_dim[d],), offset=m.offset, stride=m.stride, ) @@ -663,8 +702,39 @@ def _apply_oindex(transform: IndexTransform, selection: Any) -> IndexTransform: else: raise RuntimeError(f"unexpected: dimension {d} not handled") elif isinstance(m, ArrayMap): - new_arr = _reindex_array_oindex(m.index_array, normalized, transform.domain) - new_output.append(ArrayMap(index_array=new_arr, offset=m.offset, stride=m.stride)) + # Each axis of m.index_array corresponds to one entry in + # m.input_dimensions. For each such old input dim, oindex either + # picks specific entries (dim_array[d]) or slices the axis + # (dim_slice_params[d]). 
+ arr_idx: list[Any] = [] + for axis_dim in m.input_dimensions: + if axis_dim in dim_array: + arr_idx.append(dim_array[axis_dim]) + elif axis_dim in dim_slice_params: + abs_start, _, step = dim_slice_params[axis_dim] + array_start = abs_start - transform.domain.inclusive_min[axis_dim] + new_size = ( + new_exclusive_max[old_to_new_dim[axis_dim]] + - new_inclusive_min[old_to_new_dim[axis_dim]] + ) + array_stop = array_start + step * new_size + arr_idx.append(slice(array_start, array_stop, step)) + else: + raise RuntimeError( + f"unexpected: ArrayMap input_dim {axis_dim} not in " + "dim_array or dim_slice_params" + ) + new_arr = m.index_array[tuple(arr_idx)] if arr_idx else m.index_array + new_arr = np.asarray(new_arr, dtype=np.intp) + new_input_dims = tuple(old_to_new_dim[d] for d in m.input_dimensions) + new_output.append( + ArrayMap( + index_array=new_arr, + input_dimensions=new_input_dims, + offset=m.offset, + stride=m.stride, + ) + ) return IndexTransform(domain=new_domain, output=tuple(new_output)) @@ -793,6 +863,9 @@ def _apply_vindex(transform: IndexTransform, selection: Any) -> IndexTransform: # New dim index for slice dims starts after broadcast dims n_broadcast_dims = len(broadcast_shape) + # Broadcast dims are placed at input_dim positions [0, n_broadcast_dims). 
+ broadcast_input_dims = tuple(range(n_broadcast_dims)) + new_output: list[OutputIndexMap] = [] for m in transform.output: if isinstance(m, ConstantMap): @@ -803,6 +876,7 @@ def _apply_vindex(transform: IndexTransform, selection: Any) -> IndexTransform: new_output.append( ArrayMap( index_array=array_dim_to_broadcast[d], + input_dimensions=broadcast_input_dims, offset=m.offset, stride=m.stride, ) @@ -819,8 +893,13 @@ def _apply_vindex(transform: IndexTransform, selection: Any) -> IndexTransform: ) ) elif isinstance(m, ArrayMap): - new_arr = _reindex_array_oindex(m.index_array, processed, transform.domain) - new_output.append(ArrayMap(index_array=new_arr, offset=m.offset, stride=m.stride)) + # vindex on a transform that already has an ArrayMap output is not + # currently exercised. The semantics are subtle (broadcasting can + # reshape the array's parameterization) and require careful design; + # raise rather than produce wrong results. + raise NotImplementedError( + "vindex on a transform whose output is already an ArrayMap is not yet supported" + ) return IndexTransform(domain=new_domain, output=tuple(new_output)) diff --git a/tests/test_transforms/test_chunk_resolution.py b/tests/test_transforms/test_chunk_resolution.py index 367f8b8d32..9314549308 100644 --- a/tests/test_transforms/test_chunk_resolution.py +++ b/tests/test_transforms/test_chunk_resolution.py @@ -98,7 +98,7 @@ def test_iter_chunk_transforms_array_map_lists_chunks_for_array_entries() -> Non idx = np.array([5, 15, 25], dtype=np.intp) t = IndexTransform( domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=idx),), + output=(ArrayMap(index_array=idx, input_dimensions=(0,)),), ) results = list(iter_chunk_transforms(t, _grid_1d(10, 30))) coords_list = [r[0] for r in results] @@ -203,7 +203,7 @@ def test_sub_transform_to_selections_array_map_no_offset() -> None: arr = np.array([1, 5, 9], dtype=np.intp) t = IndexTransform( domain=IndexDomain.from_shape((3,)), - 
output=(ArrayMap(index_array=arr, offset=0, stride=1),), + output=(ArrayMap(index_array=arr, input_dimensions=(0,), offset=0, stride=1),), ) chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) assert isinstance(chunk_sel[0], np.ndarray) @@ -218,7 +218,7 @@ def test_sub_transform_to_selections_array_map_with_offset_stride() -> None: arr = np.array([0, 1, 2], dtype=np.intp) t = IndexTransform( domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=arr, offset=10, stride=5),), + output=(ArrayMap(index_array=arr, input_dimensions=(0,), offset=10, stride=5),), ) chunk_sel, _out_sel, drop_axes = sub_transform_to_selections(t) assert isinstance(chunk_sel[0], np.ndarray) @@ -232,7 +232,7 @@ def test_sub_transform_to_selections_orthogonal_array_with_out_indices() -> None arr = np.array([1, 5, 9], dtype=np.intp) t = IndexTransform( domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=arr),), + output=(ArrayMap(index_array=arr, input_dimensions=(0,)),), ) out_indices = np.array([0, 2], dtype=np.intp) _chunk_sel, out_sel, _drop_axes = sub_transform_to_selections(t, out_indices) @@ -247,8 +247,8 @@ def test_sub_transform_to_selections_vectorized_with_out_indices() -> None: t = IndexTransform( domain=IndexDomain.from_shape((3,)), output=( - ArrayMap(index_array=np.array([1, 5, 9], dtype=np.intp)), - ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp)), + ArrayMap(index_array=np.array([1, 5, 9], dtype=np.intp), input_dimensions=(0,)), + ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp), input_dimensions=(0,)), ), ) out_indices = np.array([0, 1], dtype=np.intp) diff --git a/tests/test_transforms/test_composition.py b/tests/test_transforms/test_composition.py index 7a9649d0b3..5cffe561d0 100644 --- a/tests/test_transforms/test_composition.py +++ b/tests/test_transforms/test_composition.py @@ -34,14 +34,14 @@ _array_outer_arr = np.array([0, 2, 4], dtype=np.intp) _array_outer = IndexTransform( domain=IndexDomain.from_shape((3,)), 
- output=(ArrayMap(index_array=_array_outer_arr, offset=5, stride=2),), + output=(ArrayMap(index_array=_array_outer_arr, input_dimensions=(0,), offset=5, stride=2),), ) # Inner = ArrayMap with various outers. _array_inner_arr = np.array([10, 20, 30], dtype=np.intp) _array_inner = IndexTransform( domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=_array_inner_arr, offset=0, stride=1),), + output=(ArrayMap(index_array=_array_inner_arr, input_dimensions=(0,), offset=0, stride=1),), ) _constant_outer_1 = IndexTransform( domain=IndexDomain.from_shape((5,)), @@ -49,7 +49,14 @@ ) _array_outer_for_array_inner = IndexTransform( domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=np.array([0, 2, 1], dtype=np.intp), offset=0, stride=1),), + output=( + ArrayMap( + index_array=np.array([0, 2, 1], dtype=np.intp), + input_dimensions=(0,), + offset=0, + stride=1, + ), + ), ) @@ -203,7 +210,8 @@ def test_compose_chains_associatively() -> None: domain=IndexDomain.from_shape((3, 2)), output=( ArrayMap( - index_array=np.array([[10, 20], [30, 40], [50, 60]], dtype=np.intp) + index_array=np.array([[10, 20], [30, 40], [50, 60]], dtype=np.intp), + input_dimensions=(0, 1), ), ), ), diff --git a/tests/test_transforms/test_output_map.py b/tests/test_transforms/test_output_map.py index 5e1695172f..0e62017b0a 100644 --- a/tests/test_transforms/test_output_map.py +++ b/tests/test_transforms/test_output_map.py @@ -34,18 +34,28 @@ id="DimensionMap-defaults", ), Expect( - input=ArrayMap(index_array=np.array([1, 3, 5], dtype=np.intp), offset=10, stride=2), + input=ArrayMap( + index_array=np.array([1, 3, 5], dtype=np.intp), + input_dimensions=(0,), + offset=10, + stride=2, + ), expected={ "index_array": np.array([1, 3, 5], dtype=np.intp), + "input_dimensions": (0,), "offset": 10, "stride": 2, }, id="ArrayMap-all-fields", ), Expect( - input=ArrayMap(index_array=np.array([0, 1], dtype=np.intp)), + input=ArrayMap( + index_array=np.array([0, 1], dtype=np.intp), + 
input_dimensions=(0,), + ), expected={ "index_array": np.array([0, 1], dtype=np.intp), + "input_dimensions": (0,), "offset": 0, "stride": 1, }, @@ -81,7 +91,11 @@ def test_construction_success(case: Expect[Any, dict[str, Any]]) -> None: id="DimensionMap-frozen", ), ExpectErr( - input=(ArrayMap(index_array=np.array([0], dtype=np.intp)), "offset", 1), + input=( + ArrayMap(index_array=np.array([0], dtype=np.intp), input_dimensions=(0,)), + "offset", + 1, + ), msg="cannot assign to field 'offset'", exception_cls=FrozenInstanceError, id="ArrayMap-frozen", diff --git a/tests/test_transforms/test_transform.py b/tests/test_transforms/test_transform.py index b07f14839d..f227d8aecf 100644 --- a/tests/test_transforms/test_transform.py +++ b/tests/test_transforms/test_transform.py @@ -253,15 +253,30 @@ def test_getitem_composition( # rather than building a new map. _array_map_1d = IndexTransform( domain=IndexDomain.from_shape((5,)), - output=(ArrayMap(index_array=np.array([10, 20, 30, 40, 50], dtype=np.intp)),), + output=( + ArrayMap( + index_array=np.array([10, 20, 30, 40, 50], dtype=np.intp), + input_dimensions=(0,), + ), + ), ) _array_map_2d_3x2 = IndexTransform( domain=IndexDomain.from_shape((3, 2)), - output=(ArrayMap(index_array=np.array([[10, 20], [30, 40], [50, 60]], dtype=np.intp)),), + output=( + ArrayMap( + index_array=np.array([[10, 20], [30, 40], [50, 60]], dtype=np.intp), + input_dimensions=(0, 1), + ), + ), ) _array_map_2d_2x3 = IndexTransform( domain=IndexDomain.from_shape((2, 3)), - output=(ArrayMap(index_array=np.array([[10, 20, 30], [40, 50, 60]], dtype=np.intp)),), + output=( + ArrayMap( + index_array=np.array([[10, 20, 30], [40, 50, 60]], dtype=np.intp), + input_dimensions=(0, 1), + ), + ), ) @@ -302,17 +317,26 @@ def test_getitem_on_array_map( def test_getitem_newaxis_on_array_map() -> None: - """np.newaxis on an ArrayMap inserts a new axis in the index_array, not just the domain.""" + """np.newaxis on an ArrayMap inserts a new input dim into the domain 
but + leaves the array's parameterization unchanged. The array's input_dimensions + just shifts to point at the new index of the old dim.""" t = IndexTransform( domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=np.array([10, 20, 30], dtype=np.intp)),), + output=( + ArrayMap( + index_array=np.array([10, 20, 30], dtype=np.intp), + input_dimensions=(0,), + ), + ), ) result = t[np.newaxis, :] assert result.input_rank == 2 assert result.domain.shape == (1, 3) assert isinstance(result.output[0], ArrayMap) - assert result.output[0].index_array.shape == (1, 3) - np.testing.assert_array_equal(result.output[0].index_array, np.array([[10, 20, 30]])) + # newaxis is at new dim 0; old dim 0 shifts to new dim 1. + assert result.output[0].input_dimensions == (1,) + assert result.output[0].index_array.shape == (3,) + np.testing.assert_array_equal(result.output[0].index_array, np.array([10, 20, 30])) # --------------------------------------------------------------------------- @@ -532,7 +556,7 @@ def test_selection_to_transform_unknown_mode_errors() -> None: """ t = IndexTransform.from_shape((10,)) with pytest.raises(ValueError, match="Unknown mode"): - selection_to_transform(slice(None), t, "diagonal") + selection_to_transform(slice(None), t, "diagonal") # type: ignore[arg-type] # --------------------------------------------------------------------------- @@ -603,7 +627,7 @@ def test_intersect_array_partial() -> None: arr = np.array([3, 8, 15, 22], dtype=np.intp) t = IndexTransform( domain=IndexDomain.from_shape((4,)), - output=(ArrayMap(index_array=arr),), + output=(ArrayMap(index_array=arr, input_dimensions=(0,)),), ) result = t.intersect(IndexDomain(inclusive_min=(5,), exclusive_max=(20,))) assert result is not None @@ -619,7 +643,7 @@ def test_intersect_array_disjoint() -> None: arr = np.array([1, 2, 3], dtype=np.intp) t = IndexTransform( domain=IndexDomain.from_shape((3,)), - output=(ArrayMap(index_array=arr),), + output=(ArrayMap(index_array=arr, 
input_dimensions=(0,)),), ) assert t.intersect(IndexDomain(inclusive_min=(10,), exclusive_max=(20,))) is None @@ -663,12 +687,15 @@ def test_intersect_rank_mismatch_errors() -> None: def _vectorized_2d_array_map() -> IndexTransform: """Helper: a vectorized transform over a (3,) input domain with two - correlated ArrayMaps. Storage coords: (1,10), (5,11), (9,12).""" + correlated ArrayMaps. Storage coords: (1,10), (5,11), (9,12). + + Both ArrayMaps share input_dimensions=(0,) — that's what makes them + correlated under the new design.""" return IndexTransform( domain=IndexDomain.from_shape((3,)), output=( - ArrayMap(index_array=np.array([1, 5, 9], dtype=np.intp)), - ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp)), + ArrayMap(index_array=np.array([1, 5, 9], dtype=np.intp), input_dimensions=(0,)), + ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp), input_dimensions=(0,)), ), ) @@ -701,8 +728,8 @@ def test_intersect_vectorized_with_constant_outside_drops_to_none() -> None: t = IndexTransform( domain=IndexDomain.from_shape((3,)), output=( - ArrayMap(index_array=np.array([1, 2, 3], dtype=np.intp)), - ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp)), + ArrayMap(index_array=np.array([1, 2, 3], dtype=np.intp), input_dimensions=(0,)), + ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp), input_dimensions=(0,)), ConstantMap(offset=99), ), ) @@ -717,7 +744,13 @@ def test_intersect_vectorized_with_constant_outside_drops_to_none() -> None: _translate_dimension_t = IndexTransform.from_shape((10,)) _translate_array_t = IndexTransform( domain=IndexDomain.from_shape((2,)), - output=(ArrayMap(index_array=np.array([5, 10], dtype=np.intp), offset=3),), + output=( + ArrayMap( + index_array=np.array([5, 10], dtype=np.intp), + input_dimensions=(0,), + offset=3, + ), + ), ) _translate_constant_t = IndexTransform( domain=IndexDomain.from_shape((10,)), From c4de02366e6e2268ea7004d904ca3a90a92c90d9 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: 
Fri, 8 May 2026 10:05:43 -0400 Subject: [PATCH 20/24] refactor(_transforms): hoist itertools import in chunk_resolution Function-level `import itertools` inside iter_chunk_transforms moved to module top. The TYPE_CHECKING-guarded imports are intentionally inside their guard and remain there. --- src/zarr/core/_transforms/chunk_resolution.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/zarr/core/_transforms/chunk_resolution.py b/src/zarr/core/_transforms/chunk_resolution.py index d9a0420e39..393bbbb21c 100644 --- a/src/zarr/core/_transforms/chunk_resolution.py +++ b/src/zarr/core/_transforms/chunk_resolution.py @@ -29,6 +29,7 @@ from __future__ import annotations +import itertools from typing import TYPE_CHECKING, Any import numpy as np @@ -97,8 +98,6 @@ def iter_chunk_transforms( last = int(chunk_ids.max()) chunk_ranges.append(range(first, last + 1)) - import itertools - for chunk_coords_tuple in itertools.product(*chunk_ranges): chunk_coords = tuple(int(c) for c in chunk_coords_tuple) From ee91e0b8d2753763cf10bd09361565a999829a09 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 8 May 2026 10:56:00 -0400 Subject: [PATCH 21/24] fix(_transforms): guard multi-dim ArrayMap case in oindex; address review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Code review on the post-refactor branch found one real bug and three minor follow-ups. All addressed: 1. **`_apply_oindex` on a multi-dim ArrayMap with 2+ axes selected by integer arrays** previously did `m.index_array[a, b]` which performs NumPy vectorized broadcasting (wrong) instead of orthogonal outer product (right). The shape check in `__post_init__` would catch the resulting wrong shape, but the user-facing error was opaque. Replace with an explicit `NotImplementedError`. The 1-D case (one axis) and the all-slices case (zero axes) are correct and remain unchanged. 
Adds three test cases: - test_oindex_on_1d_array_map_with_int_array (1-D path works) - test_oindex_on_2d_array_map_all_slices (multi-dim, all slices, works) - test_oindex_on_multi_dim_array_map_with_two_array_axes_errors (the now-guarded case raises NotImplementedError) 2. Reachability comments on the two remaining NotImplementedErrors in `_intersect` (multi-dim orthogonal ArrayMap) and `_intersect_vectorized` (DimensionMap on a correlated input dim). Both are reachable only via direct manual transform construction; the public oindex/vindex/basic API never produces these states. 3. Strengthen the hypothesis associativity property test to sample 5 coords per generated triple instead of 1. Negligible cost (still <0.5s), better probabilistic coverage of any coord-dependent bugs. Test count: 153 → 156. mypy clean. --- src/zarr/core/_transforms/transform.py | 22 ++++++++++-- tests/test_transforms/test_composition.py | 29 +++++++-------- tests/test_transforms/test_transform.py | 44 +++++++++++++++++++++++ 3 files changed, 79 insertions(+), 16 deletions(-) diff --git a/src/zarr/core/_transforms/transform.py b/src/zarr/core/_transforms/transform.py index 330b523888..e47a0faace 100644 --- a/src/zarr/core/_transforms/transform.py +++ b/src/zarr/core/_transforms/transform.py @@ -293,8 +293,10 @@ def _intersect( if not np.any(mask): return None # Orthogonal ArrayMap: filter the array and shrink the input dim - # it parameterizes. Only the 1-D case is currently exercised; the - # multi-dim orthogonal ArrayMap path is rejected with a clear error. + # it parameterizes. The 1-D case is the only one produced by the + # public oindex / vindex API; multi-dim orthogonal ArrayMaps would + # only arise via direct manual construction. Reject rather than + # silently produce an unsupported result. 
if len(m.input_dimensions) != 1: raise NotImplementedError( "intersect on a multi-dimensional orthogonal ArrayMap is not yet supported" @@ -392,6 +394,10 @@ def _intersect_vectorized( else: return None elif isinstance(m, DimensionMap): + # _apply_vindex never places a DimensionMap on a correlated (broadcast) + # input dim: the broadcast dims always become ArrayMap parameters and + # the slice dims become DimensionMaps on dims past the broadcast block. + # This guard is reachable only via direct manual transform construction. if m.input_dimension in correlated_input_dims: raise NotImplementedError( "vectorized intersect with a DimensionMap on a correlated " @@ -707,9 +713,11 @@ def _apply_oindex(transform: IndexTransform, selection: Any) -> IndexTransform: # picks specific entries (dim_array[d]) or slices the axis # (dim_slice_params[d]). arr_idx: list[Any] = [] + n_array_axes = 0 for axis_dim in m.input_dimensions: if axis_dim in dim_array: arr_idx.append(dim_array[axis_dim]) + n_array_axes += 1 elif axis_dim in dim_slice_params: abs_start, _, step = dim_slice_params[axis_dim] array_start = abs_start - transform.domain.inclusive_min[axis_dim] @@ -724,6 +732,16 @@ def _apply_oindex(transform: IndexTransform, selection: Any) -> IndexTransform: f"unexpected: ArrayMap input_dim {axis_dim} not in " "dim_array or dim_slice_params" ) + # Multi-dim ArrayMap with two or more axes selected by arrays needs + # `np.ix_`-style outer-product indexing to preserve oindex semantics + # (NumPy's `arr[a, b]` would broadcast a and b instead). Until that + # is implemented, refuse rather than silently produce wrong results. 
+ if n_array_axes >= 2: + raise NotImplementedError( + "oindex on a multi-dimensional ArrayMap with two or more " + "axes selected by integer/boolean arrays is not yet " + "supported" + ) new_arr = m.index_array[tuple(arr_idx)] if arr_idx else m.index_array new_arr = np.asarray(new_arr, dtype=np.intp) new_input_dims = tuple(old_to_new_dim[d] for d in m.input_dimensions) diff --git a/tests/test_transforms/test_composition.py b/tests/test_transforms/test_composition.py index 5cffe561d0..db01923f48 100644 --- a/tests/test_transforms/test_composition.py +++ b/tests/test_transforms/test_composition.py @@ -334,19 +334,20 @@ def test_compose_is_associative( assert left.output_rank == right.output_rank assert left.domain == right.domain - # Sample a few points from a's domain and compare evaluations. - if a.input_rank == 0: - coord: tuple[int, ...] = () - else: - coord = tuple( - data.draw( - st.integers( - min_value=a.domain.inclusive_min[d], - max_value=a.domain.exclusive_max[d] - 1, + # Sample several points from a's domain and compare evaluations at each. + # 5 coords per triple raises probabilistic coverage at negligible cost. + for _ in range(5): + if a.input_rank == 0: + coord: tuple[int, ...] 
= () + else: + coord = tuple( + data.draw( + st.integers( + min_value=a.domain.inclusive_min[d], + max_value=a.domain.exclusive_max[d] - 1, + ) ) + for d in range(a.input_rank) ) - for d in range(a.input_rank) - ) - assume(a.domain.contains(coord)) - - assert _evaluate(left, coord) == _evaluate(right, coord) + assume(a.domain.contains(coord)) + assert _evaluate(left, coord) == _evaluate(right, coord) diff --git a/tests/test_transforms/test_transform.py b/tests/test_transforms/test_transform.py index f227d8aecf..14f9eaff22 100644 --- a/tests/test_transforms/test_transform.py +++ b/tests/test_transforms/test_transform.py @@ -412,6 +412,50 @@ def test_oindex_errors(case: ExpectErr[tuple[IndexTransform, Any]]) -> None: transform.oindex[selection] +def test_oindex_on_1d_array_map_with_int_array() -> None: + """oindex on a transform with a 1-D ArrayMap output indexes that ArrayMap's + array along its single parameterizing input dim.""" + arr = np.array([10, 20, 30, 40, 50], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(ArrayMap(index_array=arr, input_dimensions=(0,)),), + ) + result = t.oindex[np.array([0, 2, 4], dtype=np.intp)] + assert result.input_rank == 1 + assert result.domain.shape == (3,) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, np.array([10, 30, 50])) + + +def test_oindex_on_2d_array_map_all_slices() -> None: + """oindex on a 2-D ArrayMap with slices on every axis is well-defined + (no axes selected by integer arrays).""" + arr = np.arange(12, dtype=np.intp).reshape(3, 4) + t = IndexTransform( + domain=IndexDomain.from_shape((3, 4)), + output=(ArrayMap(index_array=arr, input_dimensions=(0, 1)),), + ) + # Both axes sliced; no array indices. 
+ result = t.oindex[1:3, 0:3] + assert result.input_rank == 2 + assert result.domain.shape == (2, 3) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, arr[1:3, 0:3]) + + +def test_oindex_on_multi_dim_array_map_with_two_array_axes_errors() -> None: + """oindex on a multi-dim ArrayMap with two or more axes selected by + integer arrays needs np.ix_-style outer-product semantics. Until that + is implemented, raise NotImplementedError.""" + arr = np.arange(12, dtype=np.intp).reshape(3, 4) + t = IndexTransform( + domain=IndexDomain.from_shape((3, 4)), + output=(ArrayMap(index_array=arr, input_dimensions=(0, 1)),), + ) + with pytest.raises(NotImplementedError, match="multi-dimensional ArrayMap"): + t.oindex[np.array([0, 2], dtype=np.intp), np.array([1, 3], dtype=np.intp)] + + # --------------------------------------------------------------------------- # vindex (vectorized indexing) # --------------------------------------------------------------------------- From 43a9b98dc824aff76eaaa7bb7cb45ee6673e5e3e Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 8 May 2026 11:14:13 -0400 Subject: [PATCH 22/24] test(_transforms): close coverage gaps; reach 99% line+branch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Coverage was 83% before this commit. Audit identified meaningful gaps across three tiers: **Tier 1 - real gaps in production paths:** - selection_repr / __repr__: zero tests (covers all OutputIndexMap kinds). - IndexTransform.intersect() vectorized dispatch: tests called the internal _intersect_vectorized directly, never the public dispatch. - __post_init__ validation paths from the input_dimensions refactor: out-of-range input_dim, duplicate input_dimensions, shape mismatch. - _intersect_vectorized DimensionMap-on-non-correlated-dim path (the post-refactor remap of preserved input dims). - _apply_basic_indexing rejects negative slice step. 
- compose dim-outer / arr-inner case (the one untested cell of the composition matrix). - vindex on a transform that already has an ArrayMap output (now-explicit NotImplementedError). - _validate_basic_selection / _validate_array_selection reject unsupported types (e.g., float). - ConstantMap survives basic / oindex / vindex unchanged (each apply function had this branch but no test exercised it). - _normalize_basic_selection error paths: too-many-indices, double-ellipsis, unsupported type. **Tier 2 - dead negative-stride DimensionMap paths:** DimensionMap.stride was previously unconstrained, but the public API (slice.indices(), oindex/vindex normalize, _apply_basic_indexing) never produces a negative-stride DimensionMap, and composition preserves sign so multiplication of positives stays positive. Every negative-stride code path was unreachable. There were two options: test those paths or delete them. Chose deletion, plus a new positive-stride constraint in __post_init__. Removed: - chunk_resolution.iter_chunk_transforms negative-stride branch - chunk_resolution.sub_transform_to_selections negative-stride branch - transform._intersect DimensionMap negative-stride and zero-stride branches (zero was already unreachable; both deleted) Added: __post_init__ rejects DimensionMap with stride <= 0. **Tier 3 - defensive RuntimeErrors and unreachable guards:** Added `# pragma: no cover` to four defensive RuntimeErrors and two NotImplementedErrors that are reachable only via direct manual construction of malformed transforms (the public oindex/vindex/basic API never produces these states). Also marked the silent fallthrough in _normalize_oindex_selection / _apply_vindex parsing (the upstream _validate_array_selection rejects the unsupported types first). **Other test additions:** - iter_chunk_transforms over-estimate skip path (when intersect returns None for a candidate chunk in the chunk-range). - iter_chunk_transforms empty-domain path.
- oindex with ellipsis and trailing-dim implicit fill. - oindex/vindex with bare int and Python list selections. - vindex with ellipsis, 2-D bool mask (consumes 2 dims), trailing slice padding. Adjusted hypothesis associativity strategy to only generate positive strides (matching the new __post_init__ constraint). Test count: 156 → 191. Coverage: 83% → 99% (line+branch). The remaining 1% is partial-branch artifacts from loop fall-through and `else: continue` constructs that fire incidentally based on test case order; not real semantic gaps. --- src/zarr/core/_transforms/chunk_resolution.py | 12 +- src/zarr/core/_transforms/transform.py | 44 +- .../test_transforms/test_chunk_resolution.py | 33 ++ tests/test_transforms/test_composition.py | 36 +- tests/test_transforms/test_transform.py | 492 ++++++++++++++++++ 5 files changed, 584 insertions(+), 33 deletions(-) diff --git a/src/zarr/core/_transforms/chunk_resolution.py b/src/zarr/core/_transforms/chunk_resolution.py index 393bbbb21c..f1d3c80e67 100644 --- a/src/zarr/core/_transforms/chunk_resolution.py +++ b/src/zarr/core/_transforms/chunk_resolution.py @@ -81,12 +81,9 @@ def iter_chunk_transforms( dim_hi = transform.domain.exclusive_max[d] if dim_lo >= dim_hi: return # empty domain - if m.stride > 0: - s_min = m.offset + m.stride * dim_lo - s_max = m.offset + m.stride * (dim_hi - 1) - else: - s_min = m.offset + m.stride * (dim_hi - 1) - s_max = m.offset + m.stride * dim_lo + # DimensionMap.stride is always positive (enforced by __post_init__). + s_min = m.offset + m.stride * dim_lo + s_max = m.offset + m.stride * (dim_hi - 1) first = dg.index_to_chunk(s_min) last = dg.index_to_chunk(s_max) chunk_ranges.append(range(first, last + 1)) @@ -162,12 +159,11 @@ def sub_transform_to_selections( if isinstance(m, ConstantMap): chunk_sel.append(m.offset) elif isinstance(m, DimensionMap): + # DimensionMap.stride is always positive (enforced by __post_init__). 
dim_lo = sub_transform.domain.inclusive_min[m.input_dimension] dim_hi = sub_transform.domain.exclusive_max[m.input_dimension] start = m.offset + m.stride * dim_lo stop = m.offset + m.stride * dim_hi - if m.stride < 0: - start, stop = stop + 1, start + 1 chunk_sel.append(slice(start, stop, m.stride)) elif isinstance(m, ArrayMap): if m.offset == 0 and m.stride == 1: diff --git a/src/zarr/core/_transforms/transform.py b/src/zarr/core/_transforms/transform.py index e47a0faace..9a4233d0ad 100644 --- a/src/zarr/core/_transforms/transform.py +++ b/src/zarr/core/_transforms/transform.py @@ -64,6 +64,12 @@ def __post_init__(self) -> None: f"output[{i}].input_dimension = {m.input_dimension} " f"is out of range for input rank {self.domain.ndim}" ) + if m.stride <= 0: + raise ValueError( + f"output[{i}].stride = {m.stride} must be positive. " + "Negative-stride DimensionMaps are not supported; the " + "Array layer normalizes negative strides upstream." + ) elif isinstance(m, ArrayMap): # Every input dim referenced must be in range. for d in m.input_dimensions: @@ -267,18 +273,10 @@ def _intersect( if input_lo >= input_hi: return None - # Find input range that produces storage coords in [lo, hi) - if m.stride > 0: - new_input_lo = max(input_lo, math.ceil((lo - m.offset) / m.stride)) - new_input_hi = min(input_hi, math.ceil((hi - m.offset) / m.stride)) - elif m.stride < 0: - new_input_lo = max(input_lo, math.ceil((hi - 1 - m.offset) / m.stride)) - new_input_hi = min(input_hi, math.ceil((lo - 1 - m.offset) / m.stride)) - else: - if lo <= m.offset < hi: - new_input_lo, new_input_hi = input_lo, input_hi - else: - return None + # Find the input range that produces storage coords in [lo, hi). + # DimensionMap.stride is always positive (enforced by __post_init__). 
+ new_input_lo = max(input_lo, math.ceil((lo - m.offset) / m.stride)) + new_input_hi = min(input_hi, math.ceil((hi - m.offset) / m.stride)) if new_input_lo >= new_input_hi: return None @@ -297,7 +295,7 @@ def _intersect( # public oindex / vindex API; multi-dim orthogonal ArrayMaps would # only arise via direct manual construction. Reject rather than # silently produce an unsupported result. - if len(m.input_dimensions) != 1: + if len(m.input_dimensions) != 1: # pragma: no cover - public API never produces this raise NotImplementedError( "intersect on a multi-dimensional orthogonal ArrayMap is not yet supported" ) @@ -398,7 +396,9 @@ def _intersect_vectorized( # input dim: the broadcast dims always become ArrayMap parameters and # the slice dims become DimensionMaps on dims past the broadcast block. # This guard is reachable only via direct manual transform construction. - if m.input_dimension in correlated_input_dims: + if ( + m.input_dimension in correlated_input_dims + ): # pragma: no cover - public API never produces this raise NotImplementedError( "vectorized intersect with a DimensionMap on a correlated " "input dim is not supported" @@ -545,7 +545,9 @@ def _apply_basic_indexing(transform: IndexTransform, selection: Any) -> IndexTra ) ) else: - raise RuntimeError(f"unexpected: dimension {d} not handled") + raise RuntimeError( # pragma: no cover - defensive; unreachable for validated transforms + f"unexpected: dimension {d} not handled" + ) elif isinstance(m, ArrayMap): # The array's axes are labeled by m.input_dimensions, in order. 
# For each labeled axis: if the corresponding old input dim is @@ -568,7 +570,7 @@ def _apply_basic_indexing(transform: IndexTransform, selection: Any) -> IndexTra array_stop = array_start + step * new_size arr_idx.append(slice(array_start, array_stop, step)) new_input_dims.append(old_to_new_dim[axis_dim]) - else: + else: # pragma: no cover - defensive; unreachable for validated transforms raise RuntimeError( f"unexpected: ArrayMap input_dim {axis_dim} not in " "dropped_dims or old_to_new_dim" @@ -627,7 +629,7 @@ def _normalize_oindex_selection( result.append(np.array([int(sel)], dtype=np.intp)) elif isinstance(sel, (list, tuple)): result.append(np.asarray(sel, dtype=np.intp)) - else: + else: # pragma: no cover - upstream _validate_array_selection rejects other types result.append(sel) # Pad with slice(None) @@ -706,7 +708,9 @@ def _apply_oindex(transform: IndexTransform, selection: Any) -> IndexTransform: ) ) else: - raise RuntimeError(f"unexpected: dimension {d} not handled") + raise RuntimeError( # pragma: no cover - defensive; unreachable for validated transforms + f"unexpected: dimension {d} not handled" + ) elif isinstance(m, ArrayMap): # Each axis of m.index_array corresponds to one entry in # m.input_dimensions. 
For each such old input dim, oindex either @@ -727,7 +731,7 @@ def _apply_oindex(transform: IndexTransform, selection: Any) -> IndexTransform: ) array_stop = array_start + step * new_size arr_idx.append(slice(array_start, array_stop, step)) - else: + else: # pragma: no cover - defensive; unreachable for validated transforms raise RuntimeError( f"unexpected: ArrayMap input_dim {axis_dim} not in " "dim_array or dim_slice_params" @@ -818,7 +822,7 @@ def _apply_vindex(transform: IndexTransform, selection: Any) -> IndexTransform: processed.append(np.asarray(sel, dtype=np.intp)) elif isinstance(sel, (int, np.integer)): processed.append(np.array([int(sel)], dtype=np.intp)) - else: + else: # pragma: no cover - upstream _validate_array_selection rejects other types processed.append(sel) # Separate array dims and slice dims diff --git a/tests/test_transforms/test_chunk_resolution.py b/tests/test_transforms/test_chunk_resolution.py index 9314549308..b0fd929bbd 100644 --- a/tests/test_transforms/test_chunk_resolution.py +++ b/tests/test_transforms/test_chunk_resolution.py @@ -256,3 +256,36 @@ def test_sub_transform_to_selections_vectorized_with_out_indices() -> None: assert len(out_sel) == 1 assert isinstance(out_sel[0], np.ndarray) np.testing.assert_array_equal(out_sel[0], out_indices) + + +def test_iter_chunk_transforms_empty_domain() -> None: + """When the input domain is empty (some dim has zero extent), + iter_chunk_transforms yields nothing.""" + t = IndexTransform( + domain=IndexDomain(inclusive_min=(0,), exclusive_max=(0,)), + output=(DimensionMap(input_dimension=0, offset=0, stride=1),), + ) + grid = _grid_1d(10, 30) + results = list(iter_chunk_transforms(t, grid)) + assert results == [] + + +def test_iter_chunk_transforms_skips_chunks_that_intersect_returns_none() -> None: + """A strided DimensionMap can produce a chunk-range overestimate that + includes chunks the transform doesn't actually touch. 
iter_chunk_transforms + must skip those (the `if result is None: continue` branch).""" + # arr[::5] over a domain of size 30 yields storage coords [0, 5, 10, 15, 20, 25]. + # With chunk size 4, those land in chunks 0, 1, 2, 3, 5, 6 — chunk 4 (storage [16,20)) + # is in the chunk-range but contains no surviving storage coord (storage 20 is in chunk 5). + # Per-coordinate check (floor division by chunk size 4): 0//4=0, 5//4=1, 10//4=2, + # 15//4=3, 20//4=5, 25//4=6. Chunk 4 covers storage [16, 20); the nearest coords, + # 15 and 20, fall in chunks 3 and 5, so chunk 4 holds nothing. + # The chunk-range computed in iter_chunk_transforms is range(0, 7) -> 0..6 inclusive, + # so chunk 4 is iterated and intersect() returns None. + t = IndexTransform.from_shape((30,))[::5] + grid = _grid_1d(4, 30) + results = list(iter_chunk_transforms(t, grid)) + coords = sorted(r[0][0] for r in results) + # Every storage coord is hit exactly once; chunk 4 is NOT in the result. + assert 4 not in coords + assert sorted(coords) == [0, 1, 2, 3, 5, 6] diff --git a/tests/test_transforms/test_composition.py b/tests/test_transforms/test_composition.py index db01923f48..74ccffbe93 100644 --- a/tests/test_transforms/test_composition.py +++ b/tests/test_transforms/test_composition.py @@ -101,6 +101,34 @@ expected={"kind": ConstantMap, "offset": 20}, id="array-inner-constant-outer", ), + Expect( + input=( + # Outer: 1-D identity-ish, input domain (4,), DimensionMap with + # offset=1 stride=1. Intermediate produced: [1, 2, 3, 4]. + IndexTransform( + domain=IndexDomain.from_shape((4,)), + output=(DimensionMap(input_dimension=0, offset=1, stride=1),), + ), + # Inner: ArrayMap of length 5 on intermediate dim 0. + # arr[1..4] = [200, 300, 400, 500].
+ IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=( + ArrayMap( + index_array=np.array([100, 200, 300, 400, 500], dtype=np.intp), + input_dimensions=(0,), + ), + ), + ), + ), + expected={ + "kind": ArrayMap, + "offset": 0, + "stride": 1, + "index_array": np.array([200, 300, 400, 500], dtype=np.intp), + }, + id="array-inner-dimension-outer", + ), Expect( input=(_array_outer_for_array_inner, _array_inner), expected={ @@ -276,20 +304,18 @@ def _affine_output_map(input_rank: int, draw: st.DrawFn) -> ConstantMap | Dimens DimensionMap requires input_rank >= 1; falls back to ConstantMap otherwise. Offsets and strides are kept small to avoid integer overflow during - repeated composition. + repeated composition. Strides are positive (DimensionMap rejects + non-positive strides at construction). """ if input_rank == 0: return ConstantMap(offset=draw(st.integers(min_value=-10, max_value=10))) kind = draw(st.sampled_from(["constant", "dimension"])) if kind == "constant": return ConstantMap(offset=draw(st.integers(min_value=-10, max_value=10))) - # stride must be non-zero; sample sign and magnitude separately. 
- stride_mag = draw(st.integers(min_value=1, max_value=3)) - stride_sign = draw(st.sampled_from([-1, 1])) return DimensionMap( input_dimension=draw(st.integers(min_value=0, max_value=input_rank - 1)), offset=draw(st.integers(min_value=-10, max_value=10)), - stride=stride_sign * stride_mag, + stride=draw(st.integers(min_value=1, max_value=3)), ) diff --git a/tests/test_transforms/test_transform.py b/tests/test_transforms/test_transform.py index 14f9eaff22..bba15884c8 100644 --- a/tests/test_transforms/test_transform.py +++ b/tests/test_transforms/test_transform.py @@ -68,6 +68,52 @@ def test_construction_success(case: Expect[IndexTransform, dict[str, Any]]) -> N exception_cls=ValueError, id="dimension-map-out-of-range", ), + ExpectErr( + input={ + "domain": IndexDomain.from_shape((10,)), + "output": (DimensionMap(input_dimension=0, stride=0),), + }, + msg="must be positive", + exception_cls=ValueError, + id="dimension-map-zero-stride", + ), + ExpectErr( + input={ + "domain": IndexDomain.from_shape((10,)), + "output": (DimensionMap(input_dimension=0, stride=-1),), + }, + msg="must be positive", + exception_cls=ValueError, + id="dimension-map-negative-stride", + ), + ExpectErr( + input={ + "domain": IndexDomain.from_shape((5, 3)), + "output": ( + ArrayMap( + index_array=np.array([0, 1, 2], dtype=np.intp), + input_dimensions=(7,), + ), + ), + }, + msg="out of range", + exception_cls=ValueError, + id="array-map-input-dim-out-of-range", + ), + ExpectErr( + input={ + "domain": IndexDomain.from_shape((5, 3)), + "output": ( + ArrayMap( + index_array=np.zeros((5, 5), dtype=np.intp), + input_dimensions=(0, 0), + ), + ), + }, + msg="duplicate dimensions", + exception_cls=ValueError, + id="array-map-input-dims-duplicate", + ), ], ids=lambda c: c.id, ) @@ -871,3 +917,449 @@ def test_translate_errors(case: ExpectErr[tuple[IndexTransform, tuple[int, ...]] transform, shift = case.input with pytest.raises(case.exception_cls, match=case.msg): transform.translate(shift) + + +# 
--------------------------------------------------------------------------- +# selection_repr and __repr__: verify the human-readable strings cover each +# OutputIndexMap variant. +# --------------------------------------------------------------------------- + + +def test_selection_repr_covers_all_map_kinds() -> None: + """selection_repr produces a TensorStore-style domain string with one + entry per output dim, formatted differently for each map kind.""" + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=( + ConstantMap(offset=5), + DimensionMap(input_dimension=0, offset=2, stride=1), + DimensionMap(input_dimension=0, offset=0, stride=3), + ArrayMap( + index_array=np.array([1, 5, 9], dtype=np.intp), + input_dimensions=(0,), + ), + ), + ) + repr_str = t.selection_repr + assert "5" in repr_str # ConstantMap + assert "[2, 5)" in repr_str # DimensionMap stride=1 over input [0, 3) + assert "step 3" in repr_str # DimensionMap stride=3 + assert "{1, 5, 9}" in repr_str # ArrayMap (small) + + +def test_selection_repr_array_map_large() -> None: + """ArrayMaps with more than 5 elements show as `array(N)` rather than spelled out.""" + arr = np.arange(10, dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ArrayMap(index_array=arr, input_dimensions=(0,)),), + ) + assert "array(10)" in t.selection_repr + + +def test_repr_covers_all_map_kinds() -> None: + """__repr__ formats each output map with its kind-specific shape.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10, 5)), + output=( + ConstantMap(offset=7), + DimensionMap(input_dimension=0, offset=1, stride=2), + ArrayMap( + index_array=np.array([0, 1, 2, 3, 4], dtype=np.intp), + input_dimensions=(1,), + ), + ), + ) + s = repr(t) + assert "out[0] = 7" in s + assert "out[1] = 1 + 2 * in[0]" in s + assert "out[2] = 0 + 1 * arr(5,)[in[1]]" in s + + +# --------------------------------------------------------------------------- +# intersect() public dispatch: prior 
tests call _intersect_vectorized directly; +# the public IndexTransform.intersect() vectorized path was untested. +# --------------------------------------------------------------------------- + + +def test_intersect_dispatches_to_vectorized_when_arraymaps_correlated() -> None: + """IndexTransform.intersect() uses the vectorized path when 2+ ArrayMaps + share an input dimension. It uses the orthogonal path when ArrayMaps have + disjoint input dimensions.""" + # Correlated: both ArrayMaps share input_dimensions=(0,) on a 1-D domain. + t_correlated = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=( + ArrayMap( + index_array=np.array([1, 5, 9], dtype=np.intp), + input_dimensions=(0,), + ), + ArrayMap( + index_array=np.array([10, 11, 12], dtype=np.intp), + input_dimensions=(0,), + ), + ), + ) + chunk = IndexDomain(inclusive_min=(0, 10), exclusive_max=(8, 12)) + result = t_correlated.intersect(chunk) + assert result is not None + _, surviving = result + # Both points (1,10), (5,11) survive; (9,12) fails dim 1. + assert surviving is not None + np.testing.assert_array_equal(surviving, np.array([0, 1])) + + +# --------------------------------------------------------------------------- +# _intersect_vectorized with a DimensionMap output on a NON-correlated input +# dim: the post-refactor path that preserves the non-broadcast input dim. +# --------------------------------------------------------------------------- + + +def test_intersect_vectorized_preserves_non_correlated_dim() -> None: + """A vindex transform with a non-broadcast input dim produces a + DimensionMap on that dim. Intersecting must remap that DimensionMap's + input_dimension to the new domain (where the broadcast dim has been + collapsed to (len(surviving),) at index 0).""" + # Construct the transform that vindex-with-trailing-slice would produce: + # (broadcast_dim=3, slice_dim=20). output[0] = ArrayMap on broadcast, + # output[1] = DimensionMap on slice_dim. 
+ t = IndexTransform( + domain=IndexDomain.from_shape((3, 20)), + output=( + ArrayMap( + index_array=np.array([1, 5, 9], dtype=np.intp), + input_dimensions=(0,), + ), + DimensionMap(input_dimension=1, offset=0, stride=1), + ), + ) + # Two correlated outputs needed for vectorized path; add a second ArrayMap + # on the same broadcast dim. + t_with_two_arrays = IndexTransform( + domain=IndexDomain.from_shape((3, 20)), + output=( + ArrayMap( + index_array=np.array([1, 5, 9], dtype=np.intp), + input_dimensions=(0,), + ), + ArrayMap( + index_array=np.array([2, 6, 10], dtype=np.intp), + input_dimensions=(0,), + ), + DimensionMap(input_dimension=1, offset=0, stride=1), + ), + ) + chunk = IndexDomain(inclusive_min=(0, 0, 0), exclusive_max=(20, 20, 20)) + result = t_with_two_arrays.intersect(chunk) + assert result is not None + restricted, surviving = result + # Surviving points: all 3 (all storage coords in [0,20)). + assert surviving is not None + np.testing.assert_array_equal(surviving, np.array([0, 1, 2])) + # New domain: (3 surviving, 20 from preserved slice dim). + assert restricted.domain.shape == (3, 20) + # output[2] (the DimensionMap) should have its input_dimension remapped + # from old dim 1 to new dim 1 (broadcast dim is now new dim 0). + out_dim_map = restricted.output[2] + assert isinstance(out_dim_map, DimensionMap) + assert out_dim_map.input_dimension == 1 + # silence unused-var: t was an intermediate construction reference + assert t.output_rank == 2 + + +# --------------------------------------------------------------------------- +# _apply_basic_indexing rejects negative slice steps. 
+# --------------------------------------------------------------------------- + + +def test_basic_indexing_rejects_negative_slice_step() -> None: + t = IndexTransform.from_shape((10,)) + with pytest.raises(IndexError, match="slice step must be positive"): + t[slice(None, None, -1)] + + +# --------------------------------------------------------------------------- +# _apply_vindex on an existing ArrayMap output raises NotImplementedError. +# --------------------------------------------------------------------------- + + +def test_vindex_on_existing_arraymap_errors() -> None: + t = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=( + ArrayMap( + index_array=np.array([1, 2, 3, 4, 5], dtype=np.intp), + input_dimensions=(0,), + ), + ), + ) + with pytest.raises(NotImplementedError, match="ArrayMap"): + t.vindex[np.array([0, 2], dtype=np.intp)] + + +# --------------------------------------------------------------------------- +# selection_to_transform validation: reject unsupported selection types. 
+# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexTransform.from_shape((10,)), 1.5, "basic"), + msg="unsupported selection type", + exception_cls=IndexError, + id="basic-rejects-float", + ), + ExpectErr( + input=(IndexTransform.from_shape((10,)), 1.5, "orthogonal"), + msg="unsupported selection type", + exception_cls=IndexError, + id="orthogonal-rejects-float", + ), + ExpectErr( + input=(IndexTransform.from_shape((10,)), 1.5, "vectorized"), + msg="unsupported selection type", + exception_cls=IndexError, + id="vectorized-rejects-float", + ), + ], + ids=lambda c: c.id, +) +def test_selection_to_transform_rejects_unsupported_types( + case: ExpectErr[tuple[IndexTransform, Any, Literal["basic", "orthogonal", "vectorized"]]], +) -> None: + """selection_to_transform's validators reject types like float.""" + transform, selection, mode = case.input + with pytest.raises(case.exception_cls, match=case.msg): + selection_to_transform(selection, transform, mode) + + +# --------------------------------------------------------------------------- +# _apply_oindex parsing branches: bare int, list selection. 
+# --------------------------------------------------------------------------- + + +def test_oindex_bare_int_becomes_singleton_array() -> None: + """oindex[3] on a 1-D transform converts the int to a 1-element array, + producing an ArrayMap of length 1 (not a ConstantMap).""" + t = IndexTransform.from_shape((10,)) + result = t.oindex[3] + assert result.input_rank == 1 + assert result.domain.shape == (1,) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, np.array([3])) + + +def test_oindex_list_selection() -> None: + """oindex accepts a Python list and converts it to an integer array.""" + t = IndexTransform.from_shape((10,)) + result = t.oindex[[1, 3, 5]] + assert result.input_rank == 1 + assert result.domain.shape == (3,) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, np.array([1, 3, 5])) + + +# --------------------------------------------------------------------------- +# _apply_vindex parsing branches: ellipsis, 2D bool, list, bare int. +# --------------------------------------------------------------------------- + + +def test_vindex_ellipsis() -> None: + """vindex[...] is a no-op identity.""" + t = IndexTransform.from_shape((4, 5)) + result = t.vindex[...] + assert result.domain.shape == (4, 5) + + +def test_vindex_2d_bool_mask_consumes_two_dims() -> None: + """A 2-D bool mask in vindex consumes both dims of a 2-D domain and + expands into two correlated 1-D ArrayMaps.""" + t = IndexTransform.from_shape((3, 4)) + mask = np.array( + [[True, False, True, False], [False, True, False, True], [True, True, False, False]] + ) + result = t.vindex[mask] + # 6 True entries; broadcast shape (6,). 
+ assert result.domain.shape == (6,) + assert isinstance(result.output[0], ArrayMap) + assert isinstance(result.output[1], ArrayMap) + + +def test_vindex_list_selection() -> None: + """vindex accepts a Python list like oindex does.""" + t = IndexTransform.from_shape((10,)) + result = t.vindex[[1, 3, 5]] + assert result.domain.shape == (3,) + assert isinstance(result.output[0], ArrayMap) + + +def test_vindex_bare_int_becomes_singleton_array() -> None: + """vindex[3] on a 1-D transform produces an ArrayMap of length 1.""" + t = IndexTransform.from_shape((10,)) + result = t.vindex[3] + assert result.domain.shape == (1,) + assert isinstance(result.output[0], ArrayMap) + + +def test_vindex_with_fewer_selections_than_dims_pads_with_slice() -> None: + """vindex(arr) on a 2-D domain leaves trailing dims untouched (slice fill).""" + t = IndexTransform.from_shape((3, 5)) + result = t.vindex[np.array([0, 1], dtype=np.intp)] + # Broadcast dim (2,) prepended; trailing dim (5,) preserved. + assert result.domain.shape == (2, 5) + + +# --------------------------------------------------------------------------- +# ConstantMap survives basic / oindex / vindex unchanged. The tests above +# exercise these paths for DimensionMap-only transforms; these cover the +# `output[i] is ConstantMap` branch in each of the three apply functions. 
+# --------------------------------------------------------------------------- + + +def test_basic_indexing_preserves_constant_map() -> None: + """A ConstantMap output passes through basic indexing unchanged.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=42), DimensionMap(input_dimension=0)), + ) + result = t[2:8] + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 42 + + +def test_oindex_preserves_constant_map() -> None: + """A ConstantMap output passes through oindex unchanged.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=42), DimensionMap(input_dimension=0)), + ) + result = t.oindex[np.array([1, 3, 5], dtype=np.intp)] + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 42 + + +def test_vindex_preserves_constant_map() -> None: + """A ConstantMap output passes through vindex unchanged.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=42), DimensionMap(input_dimension=0)), + ) + result = t.vindex[np.array([1, 3, 5], dtype=np.intp)] + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 42 + + +def test_intersect_vectorized_constant_inside_chunk_passes() -> None: + """In _intersect_vectorized, a ConstantMap whose offset is inside the + chunk's range on its output dim is passed through. 
(The outside-chunk + case yields None and is already tested.)""" + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=( + ArrayMap(index_array=np.array([1, 2, 3], dtype=np.intp), input_dimensions=(0,)), + ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp), input_dimensions=(0,)), + ConstantMap(offset=5), + ), + ) + chunk = IndexDomain(inclusive_min=(0, 0, 0), exclusive_max=(10, 20, 10)) + result = t.intersect(chunk) + assert result is not None + restricted, _ = result + assert isinstance(restricted.output[2], ConstantMap) + assert restricted.output[2].offset == 5 + + +# --------------------------------------------------------------------------- +# Domain-level edge cases: empty-domain intersect, oindex with ellipsis or +# trailing-dim implicit fill. +# --------------------------------------------------------------------------- + + +def test_intersect_dimension_map_on_empty_domain_returns_none() -> None: + """When a DimensionMap's input dim is already empty (input_lo >= input_hi), + intersect returns None.""" + t = IndexTransform( + domain=IndexDomain(inclusive_min=(0,), exclusive_max=(0,)), + output=(DimensionMap(input_dimension=0, offset=0, stride=1),), + ) + assert t.intersect(IndexDomain.from_shape((10,))) is None + + +def test_oindex_with_ellipsis() -> None: + """oindex with ellipsis fills missing dims with slice(None).""" + t = IndexTransform.from_shape((4, 5, 6)) + result = t.oindex[np.array([0, 2], dtype=np.intp), ...] + # ellipsis fills dims 1 and 2 with slice(None); domain becomes (2, 5, 6). + assert result.domain.shape == (2, 5, 6) + + +def test_oindex_with_implicit_trailing_dim_fill() -> None: + """oindex with fewer entries than ndim pads trailing dims with slice(None).""" + t = IndexTransform.from_shape((4, 5, 6)) + result = t.oindex[np.array([0, 2], dtype=np.intp)] + # Only the first dim is selected; trailing dims pad with slice(None). 
+ assert result.domain.shape == (2, 5, 6) + + +# --------------------------------------------------------------------------- +# IndexTransform.__post_init__ shape mismatch error: not covered by +# test_construction_errors above. The __post_init__ check fires when an +# ArrayMap's shape != the domain shape on its input_dimensions; it is only hit +# implicitly by the multi-dim oindex test elsewhere, so test it explicitly here. +# --------------------------------------------------------------------------- + + +def test_construction_rejects_shape_mismatch() -> None: + """ArrayMap.index_array.shape must match the input domain's extents on + input_dimensions (in order).""" + with pytest.raises(ValueError, match="does not match expected shape"): + IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=( + ArrayMap( + index_array=np.array([1, 2, 3], dtype=np.intp), + input_dimensions=(0,), + ), + ), + ) + + +# --------------------------------------------------------------------------- +# _normalize_basic_selection error paths. 
+# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexTransform.from_shape((2,)), (1, 2, 3)), + msg="too many indices", + exception_cls=IndexError, + id="too-many-indices", + ), + ExpectErr( + input=(IndexTransform.from_shape((3, 3, 3)), (..., 0, ...)), + msg="single ellipsis", + exception_cls=IndexError, + id="double-ellipsis", + ), + ExpectErr( + input=(IndexTransform.from_shape((10,)), 1.5), + msg="unsupported selection type", + exception_cls=IndexError, + id="float-not-supported", + ), + ], + ids=lambda c: c.id, +) +def test_basic_indexing_rejects_malformed_selections( + case: ExpectErr[tuple[IndexTransform, Any]], +) -> None: + """_normalize_basic_selection error paths: too-many-indices, double-ellipsis, + and unsupported types like float.""" + transform, selection = case.input + with pytest.raises(case.exception_cls, match=case.msg): + transform[selection] From ea661d4f1ee3342b4f8dc7ce9edd1a5299c4ddf4 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 8 May 2026 11:22:48 -0400 Subject: [PATCH 23/24] docs(domain): add import in docstring --- src/zarr/core/_transforms/domain.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/zarr/core/_transforms/domain.py b/src/zarr/core/_transforms/domain.py index e2f9175225..18a21ae4e4 100644 --- a/src/zarr/core/_transforms/domain.py +++ b/src/zarr/core/_transforms/domain.py @@ -4,6 +4,8 @@ array view. 
It is the cartesian product of per-dimension integer ranges: ```python +from zarr.core._transforms.domain import IndexDomain + IndexDomain(inclusive_min=(2, 5), exclusive_max=(10, 20)) # represents {(i, j) : 2 <= i < 10, 5 <= j < 20} ``` From 542036f714f9c0f1c1769c2201408b9b8659c5ba Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 8 May 2026 11:29:46 -0400 Subject: [PATCH 24/24] test(_transforms): reach 100% line+branch coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two complementary changes: 1. Add tests that exercise real production paths the previous tests missed: - selection_repr / __repr__ / translate / basic / oindex / intersect output loops with an ArrayMap-then-DimensionMap transform (covers the ArrayMap branch's loop-continuation path). - intersect with two uncorrelated ArrayMaps (disjoint input_dimensions): exercises the orthogonal path's vectorized- detection no-correlation-found exit. - iter_chunk_transforms / sub_transform_to_selections with ArrayMap-followed-by-DimensionMap output (loop continuation). - sub_transform_to_selections with out_indices and a non-ArrayMap output preceding an ArrayMap (correlation-detection skip). - sub_transform_to_selections with two uncorrelated ArrayMaps under out_indices (orthogonal path with two arraymap entries). - compose: outer with mixed types (ConstantMap + DimensionMap), inner ArrayMap pointing at the ConstantMap dim. Exercises the path where outer.output[dim_i] is ConstantMap (not all-constant outer) and falls through to NotImplementedError. 2. 
Add `# pragma: no branch` to elifs that close exhaustive type unions: `elif isinstance(m, ArrayMap)` after ConstantMap and DimensionMap (OutputIndexMap = ConstantMap | DimensionMap | ArrayMap, so the False outcome is structurally unreachable), and `elif isinstance(sel, slice)` after the corresponding earlier elifs in _apply_basic_indexing and _apply_oindex (for sel ∈ int|slice|None and sel ∈ ndarray|slice respectively). 13 sites total (11 ArrayMap elifs, 2 slice elifs) — coverage tools can't detect type-union exhaustiveness. Test count: 191 → 202. Coverage: 99% → 100% (0 missed lines, 0 partial branches). --- src/zarr/core/_transforms/chunk_resolution.py | 8 +- src/zarr/core/_transforms/transform.py | 32 ++++-- .../test_transforms/test_chunk_resolution.py | 89 +++++++++++++++ tests/test_transforms/test_composition.py | 28 +++++ tests/test_transforms/test_transform.py | 107 ++++++++++++++++++ 5 files changed, 251 insertions(+), 13 deletions(-) diff --git a/src/zarr/core/_transforms/chunk_resolution.py b/src/zarr/core/_transforms/chunk_resolution.py index f1d3c80e67..bed247fcaf 100644 --- a/src/zarr/core/_transforms/chunk_resolution.py +++ b/src/zarr/core/_transforms/chunk_resolution.py @@ -87,7 +87,7 @@ def iter_chunk_transforms( first = dg.index_to_chunk(s_min) last = dg.index_to_chunk(s_max) chunk_ranges.append(range(first, last + 1)) - elif isinstance(m, ArrayMap): + elif isinstance(m, ArrayMap): # pragma: no branch - exhaustive over OutputIndexMap union storage = m.offset + m.stride * m.index_array flat = storage.ravel().astype(np.intp) chunk_ids = dg.indices_to_chunks(flat) @@ -165,7 +165,7 @@ def sub_transform_to_selections( start = m.offset + m.stride * dim_lo stop = m.offset + m.stride * dim_hi chunk_sel.append(slice(start, stop, m.stride)) - elif isinstance(m, ArrayMap): + elif isinstance(m, ArrayMap): # pragma: no branch - exhaustive over OutputIndexMap union if m.offset == 0 and m.stride == 1: chunk_sel.append(m.index_array) else: @@ -198,7 +198,9 @@ def sub_transform_to_selections( lo = 
sub_transform.domain.inclusive_min[m.input_dimension] hi = sub_transform.domain.exclusive_max[m.input_dimension] out_sel.append(slice(lo, hi)) - elif isinstance(m, ArrayMap): + elif isinstance( + m, ArrayMap + ): # pragma: no branch - exhaustive over OutputIndexMap union if out_indices is not None: # Orthogonal ArrayMap: out_indices has the surviving positions out_sel.append(out_indices) diff --git a/src/zarr/core/_transforms/transform.py b/src/zarr/core/_transforms/transform.py index 9a4233d0ad..b67a5400e6 100644 --- a/src/zarr/core/_transforms/transform.py +++ b/src/zarr/core/_transforms/transform.py @@ -70,7 +70,9 @@ def __post_init__(self) -> None: "Negative-stride DimensionMaps are not supported; the " "Array layer normalizes negative strides upstream." ) - elif isinstance(m, ArrayMap): + elif isinstance( + m, ArrayMap + ): # pragma: no branch - exhaustive over OutputIndexMap union # Every input dim referenced must be in range. 
for d in m.input_dimensions: if d < 0 or d >= self.domain.ndim: @@ -134,7 +136,9 @@ def selection_repr(self) -> str: parts.append(f"[{start}, {stop})") else: parts.append(f"[{start}, {stop}) step {m.stride}") - elif isinstance(m, ArrayMap): + elif isinstance( + m, ArrayMap + ): # pragma: no branch - exhaustive over OutputIndexMap union storage = m.offset + m.stride * m.index_array n = len(storage) if n <= 5: @@ -151,7 +155,9 @@ def __repr__(self) -> str: maps.append(f"out[{i}] = {m.offset}") elif isinstance(m, DimensionMap): maps.append(f"out[{i}] = {m.offset} + {m.stride} * in[{m.input_dimension}]") - elif isinstance(m, ArrayMap): + elif isinstance( + m, ArrayMap + ): # pragma: no branch - exhaustive over OutputIndexMap union in_dims = ",".join(f"in[{d}]" for d in m.input_dimensions) maps.append( f"out[{i}] = {m.offset} + {m.stride} * arr{m.index_array.shape}[{in_dims}]" @@ -188,7 +194,9 @@ def translate(self, shift: tuple[int, ...]) -> IndexTransform: stride=m.stride, ) ) - elif isinstance(m, ArrayMap): + elif isinstance( + m, ArrayMap + ): # pragma: no branch - exhaustive over OutputIndexMap union new_output.append( ArrayMap( index_array=m.index_array, @@ -285,7 +293,7 @@ def _intersect( new_max[d] = new_input_hi new_output.append(m) - elif isinstance(m, ArrayMap): + elif isinstance(m, ArrayMap): # pragma: no branch - exhaustive over OutputIndexMap union storage = m.offset + m.stride * m.index_array mask = (storage >= lo) & (storage < hi) if not np.any(mask): @@ -494,7 +502,9 @@ def _apply_basic_indexing(transform: IndexTransform, selection: Any) -> IndexTra dropped_dims.add(old_dim) dim_int_val[old_dim] = idx old_dim += 1 - elif isinstance( + sel, slice + ): # pragma: no branch - exhaustive over normalized's element type lo = 
transform.domain.inclusive_min[old_dim] hi = transform.domain.exclusive_max[old_dim] dim_size = hi - lo @@ -548,7 +558,7 @@ def _apply_basic_indexing(transform: IndexTransform, selection: Any) -> IndexTra raise RuntimeError( # pragma: no cover - defensive; unreachable for validated transforms f"unexpected: dimension {d} not handled" ) - elif isinstance(m, ArrayMap): + elif isinstance(m, ArrayMap): # pragma: no branch - exhaustive over OutputIndexMap union # The array's axes are labeled by m.input_dimensions, in order. # For each labeled axis: if the corresponding old input dim is # dropped (int), select that one entry; if sliced, slice the axis; @@ -662,7 +672,9 @@ def _apply_oindex(transform: IndexTransform, selection: Any) -> IndexTransform: new_exclusive_max.append(len(sel)) old_to_new_dim[old_dim] = new_dim_idx new_dim_idx += 1 - elif isinstance(sel, slice): + elif isinstance( + sel, slice + ): # pragma: no branch - exhaustive over normalized's element type lo = transform.domain.inclusive_min[old_dim] hi = transform.domain.exclusive_max[old_dim] dim_size = hi - lo @@ -711,7 +723,7 @@ def _apply_oindex(transform: IndexTransform, selection: Any) -> IndexTransform: raise RuntimeError( # pragma: no cover - defensive; unreachable for validated transforms f"unexpected: dimension {d} not handled" ) - elif isinstance(m, ArrayMap): + elif isinstance(m, ArrayMap): # pragma: no branch - exhaustive over OutputIndexMap union # Each axis of m.index_array corresponds to one entry in # m.input_dimensions. 
For each such old input dim, oindex either # picks specific entries (dim_array[d]) or slices the axis @@ -914,7 +926,7 @@ def _apply_vindex(transform: IndexTransform, selection: Any) -> IndexTransform: input_dimension=new_input_dim, offset=new_offset, stride=new_stride ) ) - elif isinstance(m, ArrayMap): + elif isinstance(m, ArrayMap): # pragma: no branch - exhaustive over OutputIndexMap union # vindex on a transform that already has an ArrayMap output is not # currently exercised. The semantics are subtle (broadcasting can # reshape the array's parameterization) and require careful design; diff --git a/tests/test_transforms/test_chunk_resolution.py b/tests/test_transforms/test_chunk_resolution.py index b0fd929bbd..f6af94daa4 100644 --- a/tests/test_transforms/test_chunk_resolution.py +++ b/tests/test_transforms/test_chunk_resolution.py @@ -270,6 +270,95 @@ def test_iter_chunk_transforms_empty_domain() -> None: assert results == [] +def test_iter_chunk_transforms_arraymap_followed_by_dimensionmap() -> None: + """An ArrayMap output followed by a DimensionMap output exercises the + ArrayMap branch's loop-continuation path in iter_chunk_transforms.""" + t = IndexTransform( + domain=IndexDomain.from_shape((3, 5)), + output=( + ArrayMap( + index_array=np.array([1, 5, 9], dtype=np.intp), + input_dimensions=(0,), + ), + DimensionMap(input_dimension=1, offset=0, stride=1), + ), + ) + grid = _grid_2d(10, 10, 10, 5) + results = list(iter_chunk_transforms(t, grid)) + # Sanity: at least one result is yielded. 
+ assert results + + +def test_sub_transform_to_selections_arraymap_followed_by_dimensionmap_orthogonal() -> None: + """An ArrayMap output followed by a DimensionMap output in non-vectorized + mode (out_indices=None) exercises the ArrayMap branch's loop-continuation + path in both the chunk_sel and out_sel construction loops.""" + t = IndexTransform( + domain=IndexDomain.from_shape((3, 5)), + output=( + ArrayMap( + index_array=np.array([1, 5, 9], dtype=np.intp), + input_dimensions=(0,), + ), + DimensionMap(input_dimension=1, offset=0, stride=1), + ), + ) + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) + # Two output dims, both with selections. + assert len(chunk_sel) == 2 + assert isinstance(chunk_sel[0], np.ndarray) # ArrayMap → array selection + assert isinstance(chunk_sel[1], slice) # DimensionMap → slice + assert len(out_sel) == 2 + assert drop_axes == () + + +def test_sub_transform_to_selections_with_out_indices_skips_non_arraymap_in_correlation_check() -> ( + None +): + """When out_indices is supplied and an output is NOT an ArrayMap, the + correlation-detection loop skips it (covers the `if isinstance(m, ArrayMap)` + False branch in vectorized detection).""" + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=( + DimensionMap(input_dimension=0, offset=0, stride=1), + ArrayMap( + index_array=np.array([1, 2, 3], dtype=np.intp), + input_dimensions=(0,), + ), + ), + ) + out_indices = np.array([0, 1], dtype=np.intp) + chunk_sel, out_sel, _drop_axes = sub_transform_to_selections(t, out_indices) + # Single ArrayMap → not vectorized; falls through to the orthogonal path. 
+ assert len(chunk_sel) == 2 + assert len(out_sel) == 2 + + +def test_sub_transform_to_selections_uncorrelated_arraymaps_with_out_indices() -> None: + """Two uncorrelated ArrayMaps (disjoint input_dimensions) plus out_indices + falls through to the non-vectorized branch (covers the for-loop early + exit when no correlation found).""" + t = IndexTransform( + domain=IndexDomain.from_shape((3, 4)), + output=( + ArrayMap( + index_array=np.array([0, 1, 2], dtype=np.intp), + input_dimensions=(0,), + ), + ArrayMap( + index_array=np.array([0, 1, 2, 3], dtype=np.intp), + input_dimensions=(1,), + ), + ), + ) + out_indices = np.array([0, 1], dtype=np.intp) + chunk_sel, out_sel, _drop_axes = sub_transform_to_selections(t, out_indices) + # Non-vectorized: each ArrayMap contributes its own out_sel entry. + assert len(chunk_sel) == 2 + assert len(out_sel) == 2 + + def test_iter_chunk_transforms_skips_chunks_that_intersect_returns_none() -> None: """A strided DimensionMap can produce a chunk-range overestimate that includes chunks the transform doesn't actually touch. iter_chunk_transforms diff --git a/tests/test_transforms/test_composition.py b/tests/test_transforms/test_composition.py index 74ccffbe93..3035c3a2a9 100644 --- a/tests/test_transforms/test_composition.py +++ b/tests/test_transforms/test_composition.py @@ -248,6 +248,34 @@ def test_compose_chains_associatively() -> None: exception_cls=NotImplementedError, id="multi-d-array-inner-non-constant-outer", ), + ExpectErr( + input=( + # Outer with mixed types: ConstantMap on dim 0, DimensionMap on dim 1. + # Outer is NOT all-constant, so the early-return path is skipped. + IndexTransform( + domain=IndexDomain.from_shape((4,)), + output=( + ConstantMap(offset=2), + DimensionMap(input_dimension=0, offset=0, stride=1), + ), + ), + # Inner: 1-D ArrayMap referencing outer's dim 0 (the ConstantMap). 
+ # _compose_array reaches the 1-D path; outer.output[0] is ConstantMap, + # which falls through both inner elifs to NotImplementedError. + IndexTransform( + domain=IndexDomain.from_shape((5, 4)), + output=( + ArrayMap( + index_array=np.array([10, 20, 30, 40, 50], dtype=np.intp), + input_dimensions=(0,), + ), + ), + ), + ), + msg="not yet supported", + exception_cls=NotImplementedError, + id="single-input-dim-points-at-constantmap-with-mixed-outer", + ), ], ids=lambda c: c.id, ) diff --git a/tests/test_transforms/test_transform.py b/tests/test_transforms/test_transform.py index bba15884c8..b615048d4a 100644 --- a/tests/test_transforms/test_transform.py +++ b/tests/test_transforms/test_transform.py @@ -1363,3 +1363,110 @@ def test_basic_indexing_rejects_malformed_selections( transform, selection = case.input with pytest.raises(case.exception_cls, match=case.msg): transform[selection] + + +# --------------------------------------------------------------------------- +# Transforms with ArrayMap NOT in the last output position. +# +# Several `for m in self.output:` loops in selection_repr, __repr__, basic / +# oindex / vindex apply functions, and _intersect's orthogonal path have an +# `elif isinstance(m, ArrayMap):` branch that, for branch coverage, needs to +# be exercised with an ArrayMap that is NOT the last output (i.e., the loop +# continues to a next iteration after the ArrayMap branch). The fixture below +# constructs a transform with ArrayMap-then-DimensionMap output ordering; +# the tests use it to hit those continuation branches. +# --------------------------------------------------------------------------- + + +def _arraymap_then_dimensionmap() -> IndexTransform: + """Helper: a 2-D-input transform whose first output is an ArrayMap and + whose second output is a DimensionMap. 
Ensures `for m in output` loops + encounter an ArrayMap with a next iteration available.""" + return IndexTransform( + domain=IndexDomain.from_shape((3, 5)), + output=( + ArrayMap( + index_array=np.array([1, 4, 9], dtype=np.intp), + input_dimensions=(0,), + ), + DimensionMap(input_dimension=1, offset=0, stride=1), + ), + ) + + +def test_selection_repr_with_arraymap_not_last() -> None: + """selection_repr output loop visits ArrayMap then continues.""" + t = _arraymap_then_dimensionmap() + s = t.selection_repr + assert "{1, 4, 9}" in s + assert "[0, 5)" in s + + +def test_repr_with_arraymap_not_last() -> None: + """__repr__ output loop visits ArrayMap then continues.""" + t = _arraymap_then_dimensionmap() + s = repr(t) + assert "out[0] = 0 + 1 * arr(3,)[in[0]]" in s + assert "out[1] = 0 + 1 * in[1]" in s + + +def test_translate_with_arraymap_not_last() -> None: + """IndexTransform.translate output loop visits ArrayMap then continues. + + The shift is applied to every output, so an ArrayMap-then-DimensionMap + transform produces a (translated ArrayMap, translated DimensionMap) + pair.""" + t = _arraymap_then_dimensionmap() + result = t.translate((10, 100)) + assert isinstance(result.output[0], ArrayMap) + assert result.output[0].offset == 10 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].offset == 100 + + +def test_basic_indexing_with_arraymap_not_last() -> None: + """_apply_basic_indexing output loop visits ArrayMap then continues.""" + t = _arraymap_then_dimensionmap() + result = t[:, 2:5] + assert isinstance(result.output[0], ArrayMap) + assert isinstance(result.output[1], DimensionMap) + + +def test_oindex_with_arraymap_not_last() -> None: + """_apply_oindex output loop visits ArrayMap then continues.""" + t = _arraymap_then_dimensionmap() + result = t.oindex[:, np.array([0, 2, 4], dtype=np.intp)] + # Two outputs preserved: the original ArrayMap (untouched on its + # parameterizing dim) and the new ArrayMap created from the 
DimensionMap. + assert isinstance(result.output[0], ArrayMap) + assert isinstance(result.output[1], ArrayMap) + + +def test_intersect_with_two_uncorrelated_arraymaps_uses_orthogonal_path() -> None: + """When 2+ ArrayMaps have disjoint input_dimensions (no shared input dim), + intersect detects no correlation and falls through to the orthogonal path, + NOT the vectorized path. Also exercises the `for m in output` orthogonal + loop visiting an ArrayMap that is not the last output.""" + # 2-D input domain (3, 4); two ArrayMaps with disjoint input_dimensions. + t = IndexTransform( + domain=IndexDomain.from_shape((3, 4)), + output=( + ArrayMap( + index_array=np.array([0, 5, 10], dtype=np.intp), + input_dimensions=(0,), + ), + ArrayMap( + index_array=np.array([20, 30, 40, 50], dtype=np.intp), + input_dimensions=(1,), + ), + ), + ) + # Chunk that includes everything. The orthogonal path filters each + # ArrayMap independently against its output dim's chunk range. + chunk = IndexDomain.from_shape((100, 100)) + result = t.intersect(chunk) + assert result is not None + restricted, _ = result + # Both outputs survive as ArrayMaps (orthogonal path preserves them). + assert isinstance(restricted.output[0], ArrayMap) + assert isinstance(restricted.output[1], ArrayMap)