diff --git a/docs/developing/grass.py b/docs/developing/grass.py index 19377858..97499bb3 100644 --- a/docs/developing/grass.py +++ b/docs/developing/grass.py @@ -1,7 +1,8 @@ # > imports import enum +from collections.abc import Hashable, Sequence from functools import cached_property -from typing import Dict, Hashable, Optional, Sequence, Tuple +from typing import Optional import numpy import xarray @@ -19,7 +20,7 @@ class GrassGridKind(enum.Enum): fence = 'fence' -GrassIndex = Tuple[GrassGridKind, Sequence[int]] +GrassIndex = tuple[GrassGridKind, Sequence[int]] class Grass(DimensionConvention[GrassGridKind, GrassIndex]): @@ -37,14 +38,14 @@ def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]: return Specificity.HIGH return None - def unpack_index(self, index: GrassIndex) -> Tuple[GrassGridKind, Sequence[int]]: + def unpack_index(self, index: GrassIndex) -> tuple[GrassGridKind, Sequence[int]]: return index[0], list(index[1]) def pack_index(self, grid_kind: GrassGridKind, indices: Sequence[int]) -> GrassIndex: return (grid_kind, list(indices)) @cached_property - def grid_dimensions(self) -> Dict[GrassGridKind, Sequence[Hashable]]: + def grid_dimensions(self) -> dict[GrassGridKind, Sequence[Hashable]]: return { GrassGridKind.field: ['warp', 'weft'], GrassGridKind.fence: ['post'], diff --git a/docs/releases/development.rst b/docs/releases/development.rst index 7175a6f5..3fa1d91b 100644 --- a/docs/releases/development.rst +++ b/docs/releases/development.rst @@ -11,6 +11,9 @@ Next release (in development) * Drop dependency on importlib_metadata. This was only required to support Python 3.8, which was dropped in a previous release (:issue:`122`, :pr:`125`). -* Fix an error with `ShocSimple.get_all_depth_names()` +* Fix an error with ``ShocSimple.get_all_depth_names()`` when the dataset had no depth coordinates (:issue:`123`, :pr:`126`). +* Use `PEP 585 generic type annotations `_ + and stop using `PEP 563 postponed annotation evaluation `_ + (:issue:`109`, :pr:`127`). diff --git a/docs/roles.py b/docs/roles.py index 82b51dd5..c29279e3 100644 --- a/docs/roles.py +++ b/docs/roles.py @@ -1,5 +1,6 @@ import re -from typing import Callable, Iterable, List, Tuple, cast +from collections.abc import Iterable +from typing import Callable, cast import yaml from docutils import nodes, utils @@ -30,7 +31,7 @@ def role_fn( inliner: Inliner, options: dict = {}, content: list = [], - ) -> Tuple[list, list]: + ) -> tuple[list, list]: match = GITHUB_FULL_REF.match(utils.unescape(text)) if match is not None: repo = match.group('repo') @@ -83,7 +84,7 @@ def load_citation_file(self) -> dict: with open(citation_file, 'r') as f: return cast(dict, yaml.load(f, yaml.Loader)) - def run(self) -> List[nodes.Node]: + def run(self) -> list[nodes.Node]: if self.options['format'] == 'apa': return self.run_apa() elif self.options['format'] == 'biblatex': @@ -91,7 +92,7 @@ def run(self) -> List[nodes.Node]: else: raise ValueError("Unknown format") - def run_apa(self) -> List[nodes.Node]: + def run_apa(self) -> list[nodes.Node]: citation = self.load_citation_file() names = self.comma_ampersand_join(map(self.apa_name, citation['authors'])) year = citation['date-released'].year @@ -118,7 +119,7 @@ def comma_ampersand_join(self, items: Iterable[str]) -> str: return items[0] return '{}, & {}'.format(', '.join(items[:-1]), items[-1]) - def run_biblatex(self) -> List[nodes.Node]: + def run_biblatex(self) -> list[nodes.Node]: citation = self.load_citation_file() year = citation['date-released'].year diff --git a/scripts/release.py b/scripts/release.py index 741c91ce..57682339 100644 --- a/scripts/release.py +++ b/scripts/release.py @@ -6,7 +6,7 @@ import shlex import subprocess import sys -from typing import List, Optional +from typing import Optional PROJECT = pathlib.Path(__file__).parent.parent @@ -29,7 +29,7 @@ def main( - args: Optional[List[str]] = None, + args: Optional[list[str]] = None, ) -> None: parser = argparse.ArgumentParser() add_options(parser) diff --git a/src/emsarray/accessors.py b/src/emsarray/accessors.py index 4544cc9b..643d9242 100644 --- a/src/emsarray/accessors.py +++ b/src/emsarray/accessors.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import logging import xarray diff --git a/src/emsarray/cli/__init__.py b/src/emsarray/cli/__init__.py index 24f97d09..30d19197 100644 --- a/src/emsarray/cli/__init__.py +++ b/src/emsarray/cli/__init__.py @@ -6,7 +6,7 @@ import argparse import importlib import pkgutil -from typing import Iterable, Type +from collections.abc import Iterable import emsarray @@ -40,7 +40,7 @@ def main(options: argparse.Namespace) -> None: options.func(options) -def _find_all_commands() -> Iterable[Type[BaseCommand]]: +def _find_all_commands() -> Iterable[type[BaseCommand]]: for moduleinfo in pkgutil.iter_modules(commands.__path__): if moduleinfo.name.startswith('_'): continue diff --git a/src/emsarray/cli/commands/clip.py b/src/emsarray/cli/commands/clip.py index bb04028a..1d65b00c 100644 --- a/src/emsarray/cli/commands/clip.py +++ b/src/emsarray/cli/commands/clip.py @@ -3,7 +3,6 @@ import logging import tempfile from pathlib import Path -from typing import ContextManager import emsarray from emsarray.cli import BaseCommand @@ -39,7 +38,7 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None: )) def handle(self, options: argparse.Namespace) -> None: - work_context: ContextManager[Pathish] + work_context: contextlib.AbstractContextManager[Pathish] if options.work_dir: work_context = contextlib.nullcontext(options.work_dir) else: diff --git a/src/emsarray/cli/commands/export_geometry.py b/src/emsarray/cli/commands/export_geometry.py index 97ef02d1..4a0a860c 100644 --- a/src/emsarray/cli/commands/export_geometry.py +++ b/src/emsarray/cli/commands/export_geometry.py @@ -1,7 +1,7 @@ import argparse import logging from pathlib import Path -from typing import Callable, Dict +from typing import Callable import xarray @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) Writer = Callable[[xarray.Dataset, Pathish], None] -format_writers: Dict[str, Writer] = { +format_writers: dict[str, Writer] = { 'geojson': geometry.write_geojson, 'shapefile': geometry.write_shapefile, 'wkt': geometry.write_wkt, diff --git a/src/emsarray/cli/commands/plot.py b/src/emsarray/cli/commands/plot.py index 37ed9a77..f0c342e7 100644 --- a/src/emsarray/cli/commands/plot.py +++ b/src/emsarray/cli/commands/plot.py @@ -2,7 +2,7 @@ import functools import logging from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Text, TypeVar +from typing import Any, Callable, Optional, TypeVar import emsarray from emsarray.cli import BaseCommand, CommandException @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -def key_value(arg: str, value_type: Callable = str) -> Dict[str, T]: +def key_value(arg: str, value_type: Callable = str) -> dict[str, T]: try: name, value = arg.split("=", 2) except ValueError: @@ -24,11 +24,11 @@ def key_value(arg: str, value_type: Callable = str) -> Dict[str, T]: class UpdateDict(argparse.Action): def __init__( self, - option_strings: List[str], + option_strings: list[str], dest: str, *, value_type: Callable = str, - default: Optional[Dict[str, Any]] = None, + default: Optional[dict[str, Any]] = None, **kwargs: Any, ): if default is None: @@ -42,7 +42,7 @@ def __call__( parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, - option_string: Optional[Text] = None, + option_string: Optional[str] = None, ) -> None: super().__call__ holder = getattr(namespace, self.dest, {}) diff --git a/src/emsarray/cli/utils.py b/src/emsarray/cli/utils.py index 72485e83..446dcdf8 100644 --- a/src/emsarray/cli/utils.py +++ b/src/emsarray/cli/utils.py @@ -9,9 +9,10 @@ import re import sys import textwrap +from collections.abc import Iterator from functools import wraps from pathlib import Path -from typing import Callable, Iterator, List, Optional, Protocol +from typing import Callable, Optional, Protocol from shapely.geometry import box, shape from shapely.geometry.base import BaseGeometry @@ -29,7 +30,7 @@ class MainCallable(Protocol): def __call__( self, - argv: Optional[List[str]] = None, + argv: Optional[list[str]] = None, handle_errors: bool = True, ) -> None: ... @@ -90,7 +91,7 @@ def main(options: argparse.Namespace) -> None: .. code-block:: python @nice_console_errors() - def main(argv: Optional[List[str]]) -> None: + def main(argv: Optional[list[str]]) -> None: parser = argparse.ArgumentParser() add_verbosity_group(parser) command_line_flags(parser) @@ -111,7 +112,7 @@ def decorator( ) -> MainCallable: @wraps(fn) def wrapper( - argv: Optional[List[str]] = None, + argv: Optional[list[str]] = None, handle_errors: bool = True, ) -> None: parser = argparse.ArgumentParser( diff --git a/src/emsarray/compat/shapely.py b/src/emsarray/compat/shapely.py index 273130fc..1fe8d2c2 100644 --- a/src/emsarray/compat/shapely.py +++ b/src/emsarray/compat/shapely.py @@ -1,5 +1,6 @@ import warnings -from typing import Generic, Iterable, Tuple, TypeVar, Union, cast +from collections.abc import Iterable +from typing import Generic, TypeVar, Union, cast import numpy import shapely @@ -30,7 +31,7 @@ class SpatialIndex(Generic[T]): def __init__( self, - items: Union[numpy.ndarray, Iterable[Tuple[BaseGeometry, T]]], + items: Union[numpy.ndarray, Iterable[tuple[BaseGeometry, T]]], ): self.items = numpy.array(items, dtype=self.dtype) diff --git a/src/emsarray/conventions/_base.py b/src/emsarray/conventions/_base.py index 61be5e1b..b3b3939d 100644 --- a/src/emsarray/conventions/_base.py +++ b/src/emsarray/conventions/_base.py @@ -1,14 +1,12 @@ -from __future__ import annotations - import abc import dataclasses import enum import logging import warnings +from collections.abc import Hashable, Iterable, Sequence from functools import cached_property from typing import ( - TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Generic, Hashable, Iterable, - List, Optional, Sequence, Tuple, TypeVar, Union, cast + TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union, cast ) import numpy @@ -399,7 +397,7 @@ def get_depth_name(self) -> Hashable: except IndexError: raise NoSuchCoordinateError("Could not find depth coordinate in dataset") - def get_all_depth_names(self) -> List[Hashable]: + def get_all_depth_names(self) -> list[Hashable]: """Get the names of all depth layers. Some datasets include both a depth layer centre, and the depth layer 'edges'. @@ -595,7 +593,7 @@ def unravel_index( @property @abc.abstractmethod - def grid_kinds(self) -> FrozenSet[GridKind]: + def grid_kinds(self) -> frozenset[GridKind]: """ All of the :data:`grid kinds <.GridKind>` this dataset includes. """ @@ -612,7 +610,7 @@ def default_grid_kind(self) -> GridKind: @property @abc.abstractmethod - def grid_size(self) -> Dict[GridKind, int]: + def grid_size(self) -> dict[GridKind, int]: """The linear size of each grid kind.""" pass @@ -662,7 +660,7 @@ def get_grid_kind(self, data_array: xarray.DataArray) -> GridKind: def get_grid_kind_and_size( self, data_array: xarray.DataArray, - ) -> Tuple[GridKind, int]: + ) -> tuple[GridKind, int]: """ Determines the relevant index kind and the extent of the linear index space for this data array. @@ -869,7 +867,7 @@ def make_linear(self, data_array: xarray.DataArray) -> xarray.DataArray: @cached_property # type: ignore @_requires_plot - def data_crs(self) -> CRS: + def data_crs(self) -> 'CRS': """ The coordinate reference system that coordinates in this dataset are defined in. @@ -883,9 +881,9 @@ def data_crs(self) -> CRS: @_requires_plot def plot_on_figure( self, - figure: Figure, + figure: 'Figure', scalar: Optional[DataArrayOrName] = None, - vector: Optional[Tuple[DataArrayOrName, DataArrayOrName]] = None, + vector: Optional[tuple[DataArrayOrName, DataArrayOrName]] = None, title: Optional[str] = None, **kwargs: Any, ) -> None: @@ -963,13 +961,13 @@ def plot(self, *args: Any, **kwargs: Any) -> None: @_requires_plot def animate_on_figure( self, - figure: Figure, + figure: 'Figure', scalar: Optional[DataArrayOrName] = None, - vector: Optional[Tuple[DataArrayOrName, DataArrayOrName]] = None, + vector: Optional[tuple[DataArrayOrName, DataArrayOrName]] = None, coordinate: Optional[DataArrayOrName] = None, title: Optional[Union[str, Callable[[Any], str]]] = None, **kwargs: Any, - ) -> FuncAnimation: + ) -> 'FuncAnimation': """ Make an animated plot of a data array. @@ -1066,7 +1064,7 @@ def make_poly_collection( self, data_array: Optional[DataArrayOrName] = None, **kwargs: Any, - ) -> PolyCollection: + ) -> 'PolyCollection': """ Make a :class:`~matplotlib.collections.PolyCollection` from the geometry of this :class:`~xarray.Dataset`. @@ -1143,7 +1141,7 @@ def make_patch_collection( self, data_array: Optional[DataArrayOrName] = None, **kwargs: Any, - ) -> PolyCollection: + ) -> 'PolyCollection': warnings.warn( "Convention.make_patch_collection has been renamed to " "Convention.make_poly_collection, and now returns a PolyCollection", @@ -1154,11 +1152,11 @@ def make_patch_collection( @_requires_plot def make_quiver( self, - axes: Axes, + axes: 'Axes', u: Optional[DataArrayOrName] = None, v: Optional[DataArrayOrName] = None, **kwargs: Any, - ) -> Quiver: + ) -> 'Quiver': """ Make a :class:`matplotlib.quiver.Quiver` instance to plot vector data. @@ -1187,7 +1185,7 @@ def make_quiver( # sometimes preferring to fill them in later, # so `u` and `v` are optional. # If they are not provided, we set default quiver values of `numpy.nan`. - values: Union[Tuple[numpy.ndarray, numpy.ndarray], Tuple[float, float]] + values: Union[tuple[numpy.ndarray, numpy.ndarray], tuple[float, float]] values = numpy.nan, numpy.nan if u is not None and v is not None: @@ -1422,7 +1420,7 @@ def get_index_for_point( return None @abc.abstractmethod - def selector_for_index(self, index: Index) -> Dict[Hashable, int]: + def selector_for_index(self, index: Index) -> dict[Hashable, int]: """ Convert a convention native index into a selector that can be passed to :meth:`Dataset.isel `. @@ -1516,7 +1514,7 @@ def select_point(self, point: Point) -> xarray.Dataset: return self.select_index(index.index) @abc.abstractmethod - def get_all_geometry_names(self) -> List[Hashable]: + def get_all_geometry_names(self) -> list[Hashable]: """ Return a list of the names of all geometry variables used by this convention. @@ -1742,7 +1740,7 @@ class DimensionConvention(Convention[GridKind, Index]): @property @abc.abstractmethod - def grid_dimensions(self) -> Dict[GridKind, Sequence[Hashable]]: + def grid_dimensions(self) -> dict[GridKind, Sequence[Hashable]]: """ The dimensions associated with a particular grid kind. @@ -1760,7 +1758,7 @@ def grid_dimensions(self) -> Dict[GridKind, Sequence[Hashable]]: pass @property - def grid_shape(self) -> Dict[GridKind, Sequence[int]]: + def grid_shape(self) -> dict[GridKind, Sequence[int]]: """ The :attr:`shape ` of each grid kind. @@ -1775,7 +1773,7 @@ def grid_shape(self) -> Dict[GridKind, Sequence[int]]: } @property - def grid_size(self) -> Dict[GridKind, int]: + def grid_size(self) -> dict[GridKind, int]: return { grid_kind: int(numpy.prod(shape)) for grid_kind, shape in self.grid_shape.items() @@ -1806,7 +1804,7 @@ def _get_data_array(self, data_array: DataArrayOrName) -> xarray.DataArray: return self.dataset[data_array] @abc.abstractmethod - def unpack_index(self, index: Index) -> Tuple[GridKind, Sequence[int]]: + def unpack_index(self, index: Index) -> tuple[GridKind, Sequence[int]]: """Convert a native index in to a grid kind and dimension indices. Parameters @@ -1902,7 +1900,7 @@ def wind( dimensions=dimensions, sizes=sizes, linear_dimension=linear_dimension) - def selector_for_index(self, index: Index) -> Dict[Hashable, int]: + def selector_for_index(self, index: Index) -> dict[Hashable, int]: grid_kind, indices = self.unpack_index(index) dimensions = self.grid_dimensions[grid_kind] return dict(zip(dimensions, indices)) diff --git a/src/emsarray/conventions/_registry.py b/src/emsarray/conventions/_registry.py index 607627d2..c70d5d14 100644 --- a/src/emsarray/conventions/_registry.py +++ b/src/emsarray/conventions/_registry.py @@ -1,13 +1,12 @@ -from __future__ import annotations - import logging import sys import warnings +from collections.abc import Iterable from contextlib import suppress from functools import cached_property from importlib import metadata from itertools import chain -from typing import Iterable, List, Optional, Tuple, Type +from typing import Optional import xarray @@ -17,13 +16,13 @@ class ConventionRegistry: - registered_conventions: List[Type[Convention]] + registered_conventions: list[type[Convention]] def __init__(self) -> None: self.registered_conventions = [] @cached_property - def conventions(self) -> Iterable[Type[Convention]]: + def conventions(self) -> Iterable[type[Convention]]: """ A list of all the registered Convention subclasses. This includes those registered via entry points @@ -51,7 +50,7 @@ def conventions(self) -> Iterable[Type[Convention]]: return conventions @cached_property - def entry_point_conventions(self) -> List[Type[Convention]]: + def entry_point_conventions(self) -> list[type[Convention]]: """ Find all conventions registered via the ``emsarray.conventions`` entry point. This list is cached. @@ -64,7 +63,7 @@ def entry_point_conventions(self) -> List[Type[Convention]]: """ return list(entry_point_conventions()) - def add_convention(self, convention: Type[Convention]) -> None: + def add_convention(self, convention: type[Convention]) -> None: """Register a Convention subclass with this registry. Datasets will be checked against this Convention when guessing file types. """ @@ -72,7 +71,7 @@ def add_convention(self, convention: Type[Convention]) -> None: del self.conventions self.registered_conventions.append(convention) - def match_conventions(self, dataset: xarray.Dataset) -> List[Tuple[Type[Convention], int]]: + def match_conventions(self, dataset: xarray.Dataset) -> list[tuple[type[Convention], int]]: """ Get all :class:`~.Convention` implementations that support this dataset. @@ -88,14 +87,14 @@ def match_conventions(self, dataset: xarray.Dataset) -> List[Tuple[Type[Conventi A higher specificity means a better match. The list of matches will be ordered from most to least specific. """ - matches: List[Tuple[Type[Convention], int]] = [] + matches: list[tuple[type[Convention], int]] = [] for convention in self.conventions: match = convention.check_dataset(dataset) if match is not None: matches.append((convention, match)) return sorted(matches, key=lambda m: m[1], reverse=True) - def guess_convention(self, dataset: xarray.Dataset) -> Optional[Type[Convention]]: + def guess_convention(self, dataset: xarray.Dataset) -> Optional[type[Convention]]: """ Guess the correct :class:`.Convention` implementation for a dataset. """ @@ -110,7 +109,7 @@ def guess_convention(self, dataset: xarray.Dataset) -> Optional[Type[Convention] registry = ConventionRegistry() -def get_dataset_convention(dataset: xarray.Dataset) -> Optional[Type[Convention]]: +def get_dataset_convention(dataset: xarray.Dataset) -> Optional[type[Convention]]: """Find the most appropriate Convention subclass for this dataset. Parameters @@ -136,7 +135,7 @@ def get_dataset_convention(dataset: xarray.Dataset) -> Optional[Type[Convention] return registry.guess_convention(dataset) -def entry_point_conventions() -> Iterable[Type[Convention]]: +def entry_point_conventions() -> Iterable[type[Convention]]: """ Finds conventions registered using entry points """ @@ -175,7 +174,7 @@ def entry_point_conventions() -> Iterable[Type[Convention]]: seen.add(obj) -def register_convention(convention: Type[Convention]) -> Type[Convention]: +def register_convention(convention: type[Convention]) -> type[Convention]: """ Register a Convention subclass, used for guessing file types. Can be used as a decorator. diff --git a/src/emsarray/conventions/arakawa_c.py b/src/emsarray/conventions/arakawa_c.py index 141c4f46..1ea25d1d 100644 --- a/src/emsarray/conventions/arakawa_c.py +++ b/src/emsarray/conventions/arakawa_c.py @@ -6,12 +6,11 @@ `Arakawa grids `_ on Wikipedia """ -from __future__ import annotations - import enum import logging +from collections.abc import Hashable, Sequence from functools import cached_property -from typing import Dict, Hashable, List, Optional, Sequence, Tuple, cast +from typing import Optional, cast import numpy import xarray @@ -67,7 +66,7 @@ def i_dimension(self) -> Hashable: return self.latitude.dims[1] @cached_property - def shape(self) -> Tuple[int, int]: + def shape(self) -> tuple[int, int]: """The shape of this grid, as a tuple of ``(j, i)``.""" return ( self.dataset.sizes[self.j_dimension], @@ -118,7 +117,7 @@ class ArakawaCGridKind(str, enum.Enum): #: :meta hide-value: node = 'node' - def __call__(self, j: int, i: int) -> ArakawaCIndex: + def __call__(self, j: int, i: int) -> 'ArakawaCIndex': return (self, j, i) @@ -126,9 +125,9 @@ def __call__(self, j: int, i: int) -> ArakawaCIndex: #: is a tuple with three elements: ``(kind, j, i).`` #: #: :meta hide-value: -ArakawaCIndex = Tuple[ArakawaCGridKind, int, int] -ArakawaCCoordinates = Dict[ArakawaCGridKind, Tuple[Hashable, Hashable]] -ArakawaCDimensions = Dict[ArakawaCGridKind, Tuple[Hashable, Hashable]] +ArakawaCIndex = tuple[ArakawaCGridKind, int, int] +ArakawaCCoordinates = dict[ArakawaCGridKind, tuple[Hashable, Hashable]] +ArakawaCDimensions = dict[ArakawaCGridKind, tuple[Hashable, Hashable]] class ArakawaC(DimensionConvention[ArakawaCGridKind, ArakawaCIndex]): @@ -169,7 +168,7 @@ def __init__( self, dataset: xarray.Dataset, *, - coordinate_names: Optional[Dict[Hashable, Tuple[Hashable, Hashable]]] = None, + coordinate_names: Optional[dict[Hashable, tuple[Hashable, Hashable]]] = None, ): super().__init__(dataset) @@ -206,7 +205,7 @@ def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]: return None @cached_property - def _topology_for_grid_kind(self) -> Dict[ArakawaCGridKind, ArakawaCGridTopology]: + def _topology_for_grid_kind(self) -> dict[ArakawaCGridKind, ArakawaCGridTopology]: return { kind: ArakawaCGridTopology( self.dataset, @@ -249,13 +248,13 @@ def node(self) -> ArakawaCGridTopology: return self._topology_for_grid_kind[ArakawaCGridKind.node] @cached_property - def grid_dimensions(self) -> Dict[ArakawaCGridKind, Sequence[Hashable]]: + def grid_dimensions(self) -> dict[ArakawaCGridKind, Sequence[Hashable]]: return { - kind: cast(Tuple[Hashable, Hashable], self.dataset[coordinates[0]].dims) + kind: cast(tuple[Hashable, Hashable], self.dataset[coordinates[0]].dims) for kind, coordinates in self.coordinate_names.items() } - def unpack_index(self, index: ArakawaCIndex) -> Tuple[ArakawaCGridKind, Sequence[int]]: + def unpack_index(self, index: ArakawaCIndex) -> tuple[ArakawaCGridKind, Sequence[int]]: return index[0], index[1:] def pack_index(self, grid_kind: ArakawaCGridKind, indices: Sequence[int]) -> ArakawaCIndex: @@ -285,7 +284,7 @@ def face_centres(self) -> numpy.ndarray: )) return cast(numpy.ndarray, centres) - def get_all_geometry_names(self) -> List[Hashable]: + def get_all_geometry_names(self) -> list[Hashable]: return [ self.face.longitude.name, self.face.latitude.name, @@ -320,7 +319,7 @@ def make_clip_mask( # Complete the rest of the mask grid_dimensions = cast( - Dict[ArakawaCGridKind, Tuple[Hashable, Hashable]], + dict[ArakawaCGridKind, tuple[Hashable, Hashable]], self.grid_dimensions) return c_mask_from_centres(face_mask, grid_dimensions, self.dataset.coords) diff --git a/src/emsarray/conventions/grid.py b/src/emsarray/conventions/grid.py index a5aa0e13..ae0b278e 100644 --- a/src/emsarray/conventions/grid.py +++ b/src/emsarray/conventions/grid.py @@ -2,18 +2,14 @@ Datasets following the CF conventions with gridded datasets. Both 1D coordinates and 2D coordinates are supported. """ -from __future__ import annotations - import abc import enum import itertools import warnings +from collections.abc import Hashable, Sequence from contextlib import suppress from functools import cached_property -from typing import ( - Dict, Generic, Hashable, List, Optional, Sequence, Tuple, Type, TypeVar, - cast -) +from typing import Generic, Optional, TypeVar, cast import numpy import xarray @@ -33,7 +29,7 @@ class CFGridKind(str, enum.Enum): #: A two-tuple of ``(y, x)``. -CFGridIndex = Tuple[int, int] +CFGridIndex = tuple[int, int] CF_LATITUDE_UNITS = { @@ -171,7 +167,7 @@ def x_dimension(self) -> Hashable: pass @cached_property - def shape(self) -> Tuple[int, int]: + def shape(self) -> tuple[int, int]: """The shape of this grid, as a tuple of ``(y, x)``.""" sizes = self.dataset.sizes return (sizes[self.y_dimension], sizes[self.x_dimension]) @@ -198,7 +194,7 @@ class CFGrid(Generic[Topology], DimensionConvention[CFGridKind, CFGridIndex]): grid_kinds = frozenset(CFGridKind) default_grid_kind = CFGridKind.face - topology_class: Type[Topology] + topology_class: type[Topology] def __init__( self, @@ -255,18 +251,18 @@ def bounds(self) -> Bounds: return (min_x, min_y, max_x, max_y) @cached_property - def grid_dimensions(self) -> Dict[CFGridKind, Sequence[Hashable]]: + def grid_dimensions(self) -> dict[CFGridKind, Sequence[Hashable]]: return { CFGridKind.face: [self.topology.y_dimension, self.topology.x_dimension], } - def unpack_index(self, index: CFGridIndex) -> Tuple[CFGridKind, Sequence[int]]: + def unpack_index(self, index: CFGridIndex) -> tuple[CFGridKind, Sequence[int]]: return CFGridKind.face, index def pack_index(self, grid_kind: CFGridKind, indices: Sequence[int]) -> CFGridIndex: return cast(CFGridIndex, indices) - def get_all_geometry_names(self) -> List[Hashable]: + def get_all_geometry_names(self) -> list[Hashable]: # Grid datasets contain latitude and longitude variables # plus optional bounds variables. names = [ @@ -274,7 +270,7 @@ def get_all_geometry_names(self) -> List[Hashable]: self.topology.latitude_name, ] - bounds_names: List[Optional[Hashable]] = [ + bounds_names: list[Optional[Hashable]] = [ self.topology.longitude.attrs.get('bounds', None), self.topology.latitude.attrs.get('bounds', None), ] diff --git a/src/emsarray/conventions/shoc.py b/src/emsarray/conventions/shoc.py index 69c9e15a..5cf574e1 100644 --- a/src/emsarray/conventions/shoc.py +++ b/src/emsarray/conventions/shoc.py @@ -14,11 +14,10 @@ -------- `SHOC documentation `_ """ -from __future__ import annotations - import logging +from collections.abc import Hashable from functools import cached_property -from typing import Hashable, List, Optional, Tuple +from typing import Optional import xarray @@ -56,7 +55,7 @@ def get_depth_name(self) -> Hashable: f"SHOC dataset did not have expected depth coordinate {name!r}") return name - def get_all_depth_names(self) -> List[Hashable]: + def get_all_depth_names(self) -> list[Hashable]: return [ name for name in ['z_centre', 'z_grid'] if name in self.dataset.variables] @@ -81,7 +80,7 @@ class ShocSimple(CFGrid2D): The latitude and longitude coordinate variables are named ``j`` and ``i``. Edge and node dimensions are dropped. """ - _dimensions: Tuple[Hashable, Hashable] = ('j', 'i') + _dimensions: tuple[Hashable, Hashable] = ('j', 'i') @cached_property def topology(self) -> CFGrid2DTopology: @@ -124,7 +123,7 @@ def get_depth_name(self) -> Hashable: f"SHOC dataset did not have expected depth coordinate {name!r}") return name - def get_all_depth_names(self) -> List[Hashable]: + def get_all_depth_names(self) -> list[Hashable]: name = 'zc' if name in self.dataset.variables: return [name] diff --git a/src/emsarray/conventions/ugrid.py b/src/emsarray/conventions/ugrid.py index 3a83e11a..a81fb631 100644 --- a/src/emsarray/conventions/ugrid.py +++ b/src/emsarray/conventions/ugrid.py @@ -5,20 +5,16 @@ -------- `UGRID conventions `_ """ -from __future__ import annotations - import enum import logging import pathlib import warnings from collections import defaultdict +from collections.abc import Hashable, Iterable, Mapping, Sequence from contextlib import suppress from dataclasses import dataclass from functools import cached_property -from typing import ( - Any, Dict, FrozenSet, Hashable, Iterable, List, Mapping, Optional, - Sequence, Set, Tuple, cast -) +from typing import Any, Optional, cast import numpy import shapely @@ -36,12 +32,15 @@ logger = logging.getLogger(__name__) -def _split_coord(attr: str) -> Tuple[str, str]: +def _split_coord(attr: str) -> tuple[str, str]: x, y = attr.split(None, 1) return (x, y) -def buffer_faces(face_indices: numpy.ndarray, topology: Mesh2DTopology) -> numpy.ndarray: +def buffer_faces( + face_indices: numpy.ndarray, + topology: 'Mesh2DTopology', +) -> numpy.ndarray: """ When clipping a dataset to a region, including a buffer of extra faces around the included faces is desired. Given an array of face indices, @@ -70,7 +69,10 @@ def buffer_faces(face_indices: numpy.ndarray, topology: Mesh2DTopology) -> numpy return cast(numpy.ndarray, numpy.fromiter(included_faces, dtype=topology.sensible_dtype)) -def mask_from_face_indices(face_indices: numpy.ndarray, topology: Mesh2DTopology) -> xarray.Dataset: +def mask_from_face_indices( + face_indices: numpy.ndarray, + topology: 'Mesh2DTopology', +) -> xarray.Dataset: """ Make a mask dataset from a list of face indices. This mask can later be applied using :meth:`~.Convention.apply_clip_mask`. @@ -386,7 +388,7 @@ def mesh_variable(self) -> xarray.DataArray: raise ValueError("No mesh variable found") @property - def mesh_attributes(self) -> Dict[Hashable, str]: + def mesh_attributes(self) -> dict[Hashable, str]: """ Get the mesh topology attributes from the dummy variable with the attribute ``cf_role`` of ``"mesh_topology"``. @@ -420,15 +422,15 @@ def sensible_fill_value(self) -> int: return int('9' * (len(str(max_count)) + 1)) @cached_property - def _node_coordinates(self) -> Tuple[Hashable, Hashable]: + def _node_coordinates(self) -> tuple[Hashable, Hashable]: return _split_coord(self.mesh_attributes['node_coordinates']) @cached_property - def _edge_coordinates(self) -> Tuple[Hashable, Hashable]: + def _edge_coordinates(self) -> tuple[Hashable, Hashable]: return _split_coord(self.mesh_attributes['edge_coordinates']) @cached_property - def _face_coordinates(self) -> Tuple[Hashable, Hashable]: + def _face_coordinates(self) -> tuple[Hashable, Hashable]: return _split_coord(self.mesh_attributes['face_coordinates']) @property @@ -598,7 +600,7 @@ def make_edge_node_array(self) -> numpy.ndarray: # once for each face. To de-duplicate this, edges are built up using # this dict-of-sets, where the dict index is the node with the # lower index, and the set is the node indices of the other end. - low_highs: Dict[int, Set[int]] = defaultdict(set) + low_highs: dict[int, set[int]] = defaultdict(set) for face_index, node_pairs in self._face_and_node_pair_iter(): for pair in node_pairs: @@ -857,7 +859,7 @@ def make_face_face_array(self) -> numpy.ndarray: return face_face - def _face_and_node_pair_iter(self) -> Iterable[Tuple[int, List[Tuple[int, int]]]]: + def _face_and_node_pair_iter(self) -> Iterable[tuple[int, list[tuple[int, int]]]]: """ An iterator returning a tuple of ``(face_index, edges)``, where ``edges`` is a list of ``(node_index, node_index)`` tuples @@ -870,7 +872,7 @@ def _face_and_node_pair_iter(self) -> Iterable[Tuple[int, List[Tuple[int, int]]] yield face_index, list(utils.pairwise(node_indices)) @cached_property - def dimension_for_grid_kind(self) -> Dict[UGridKind, Hashable]: + def dimension_for_grid_kind(self) -> dict['UGridKind', Hashable]: """ Get the dimension names for each of the grid types in this dataset. """ @@ -1012,7 +1014,7 @@ class UGridKind(str, enum.Enum): #: UGRID indices are always single integers, for all index kinds. -UGridIndex = Tuple[UGridKind, int] +UGridIndex = tuple[UGridKind, int] class UGrid(DimensionConvention[UGridKind, UGridIndex]): @@ -1051,8 +1053,8 @@ def topology(self) -> Mesh2DTopology: return Mesh2DTopology(self.dataset) @cached_property - def grid_dimensions(self) -> Dict[UGridKind, Sequence[Hashable]]: - dimensions: Dict[UGridKind, Sequence[Hashable]] = { + def grid_dimensions(self) -> dict[UGridKind, Sequence[Hashable]]: + dimensions: dict[UGridKind, Sequence[Hashable]] = { UGridKind.node: [self.topology.node_dimension], UGridKind.face: [self.topology.face_dimension], } @@ -1060,14 +1062,14 @@ def grid_dimensions(self) -> Dict[UGridKind, Sequence[Hashable]]: dimensions[UGridKind.edge] = [self.topology.edge_dimension] return dimensions - def unpack_index(self, index: UGridIndex) -> Tuple[UGridKind, Sequence[int]]: + def unpack_index(self, index: UGridIndex) -> tuple[UGridKind, Sequence[int]]: return index[0], index[1:] def pack_index(self, grid_kind: UGridKind, indices: Sequence[int]) -> UGridIndex: return (grid_kind, indices[0]) @cached_property - def grid_kinds(self) -> FrozenSet[UGridKind]: + def grid_kinds(self) -> frozenset[UGridKind]: items = [UGridKind.face, UGridKind.node] # The edge dimension is optional, not all UGRID datasets define it if self.topology.has_edge_dimension: @@ -1088,7 +1090,7 @@ def polygons(self) -> numpy.ndarray: # `shapely.polygons` will make polygons with the same number of vertices. # UGRID polygons have arbitrary numbers of vertices. # Group polygons by how many vertices they have, then make them in bulk. - polygons_of_size: Mapping[int, Dict[int, numpy.ndarray]] = defaultdict(dict) + polygons_of_size: Mapping[int, dict[int, numpy.ndarray]] = defaultdict(dict) for index, row in enumerate(face_node): vertices = row.compressed() polygons_of_size[vertices.size][index] = numpy.c_[node_x[vertices], node_y[vertices]] @@ -1176,7 +1178,7 @@ def apply_clip_mask(self, clip_mask: xarray.Dataset, work_dir: Pathish) -> xarra # Collect all the topology variables here. These need special handling, # compared to data variables. The mesh variable can be reused without # any changes. - topology_variables: List[xarray.DataArray] = [topology.mesh_variable] + topology_variables: list[xarray.DataArray] = [topology.mesh_variable] # This is the fill value used in the mask. new_fill_value = clip_mask.data_vars['new_node_index'].encoding['_FillValue'] @@ -1240,7 +1242,7 @@ def integer_indices(data_array: xarray.DataArray) -> numpy.ndarray: del topology_variables logger.debug("Slicing data variables...") - dimension_masks: Dict[Hashable, numpy.ndarray] = { + dimension_masks: dict[Hashable, numpy.ndarray] = { topology.node_dimension: ~numpy.ma.getmask(new_node_indices), topology.face_dimension: ~numpy.ma.getmask(new_face_indices), } @@ -1287,7 +1289,7 @@ def integer_indices(data_array: xarray.DataArray) -> numpy.ndarray: new_dataset = xarray.open_mfdataset(mfdataset_paths, lock=False) return utils.dataset_like(dataset, new_dataset) - def get_all_geometry_names(self) -> List[Hashable]: + def get_all_geometry_names(self) -> list[Hashable]: topology = self.topology names = [ diff --git a/src/emsarray/formats.py b/src/emsarray/formats.py index f88e170a..0b2bb8a4 100644 --- a/src/emsarray/formats.py +++ b/src/emsarray/formats.py @@ -1,6 +1,6 @@ import warnings from functools import wraps -from typing import Any, Optional, Type +from typing import Any, Optional import xarray @@ -17,7 +17,7 @@ def _warn_old_new(old: str, new: str, **kwargs: Any) -> None: @wraps(get_dataset_convention) -def get_file_format(dataset: xarray.Dataset, **kwargs: Any) -> Optional[Type[Convention]]: +def get_file_format(dataset: xarray.Dataset, **kwargs: Any) -> Optional[type[Convention]]: _warn_old_new( old="emsarray.formats.get_file_format", new="emsarray.conventions.get_dataset_convention", diff --git a/src/emsarray/masking.py b/src/emsarray/masking.py index 58103ee3..9b0f1964 100644 --- a/src/emsarray/masking.py +++ b/src/emsarray/masking.py @@ -3,14 +3,13 @@ Masks are used when clipping datasets to a smaller geographic subset, such as :meth:`.Convention.clip`. """ -from __future__ import annotations - import functools import itertools import logging import operator import pathlib -from typing import Any, Dict, Hashable, List, cast +from collections.abc import Hashable +from typing import Any, cast import numpy import xarray @@ -64,7 +63,7 @@ def mask_grid_dataset( mask = mask.isel(bounds) dataset = dataset.isel(bounds) - mfdataset_names: List[pathlib.Path] = [] + mfdataset_names: list[pathlib.Path] = [] logger.info("Applying masks...") # This is done variable-by-variable, as trying to do it to the entire @@ -205,7 +204,7 @@ def find_fill_value(data_array: xarray.DataArray) -> Any: raise ValueError("No appropriate fill value found") -def calculate_grid_mask_bounds(mask: xarray.Dataset) -> Dict[Hashable, slice]: +def calculate_grid_mask_bounds(mask: xarray.Dataset) -> dict[Hashable, slice]: """ Calculate the included bounds of a mask dataset for each dimension. @@ -247,7 +246,7 @@ def calculate_grid_mask_bounds(mask: xarray.Dataset) -> Dict[Hashable, slice]: return bounds -def smear_mask(arr: numpy.ndarray, pad_axes: List[bool]) -> numpy.ndarray: +def smear_mask(arr: numpy.ndarray, pad_axes: list[bool]) -> numpy.ndarray: """ Take a boolean numpy array and a list indicating which axes to smear along. Return a new array, expanded along the axes, with the boolean values diff --git a/src/emsarray/nco.py b/src/emsarray/nco.py index c50536ef..04c5ca56 100644 --- a/src/emsarray/nco.py +++ b/src/emsarray/nco.py @@ -9,8 +9,9 @@ :class:`xarray.Dataset` instances. """ import subprocess +from collections.abc import Sequence from pathlib import Path -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Union Pathish = Union[Path, str] diff --git a/src/emsarray/operations/depth.py b/src/emsarray/operations/depth.py index 2a8c24f9..7c90eb1d 100644 --- a/src/emsarray/operations/depth.py +++ b/src/emsarray/operations/depth.py @@ -4,7 +4,8 @@ """ import warnings from collections import defaultdict -from typing import Dict, FrozenSet, Hashable, List, Optional, cast +from collections.abc import Hashable +from typing import Optional, cast import numpy import xarray @@ -14,9 +15,9 @@ def ocean_floor( dataset: xarray.Dataset, - depth_variables: List[Hashable], + depth_variables: list[Hashable], *, - non_spatial_variables: Optional[List[Hashable]] = None, + non_spatial_variables: Optional[list[Hashable]] = None, ) -> xarray.Dataset: """Make a new :class:`xarray.Dataset` reduced along the given depth coordinates to only contain values along the ocean floor. @@ -117,7 +118,7 @@ def ocean_floor( non_spatial_dimensions = utils.dimensions_from_coords(dataset, non_spatial_variables) for depth_dimension in sorted(depth_dimensions, key=hash): - dimension_sets: Dict[FrozenSet[Hashable], List[Hashable]] = defaultdict(list) + dimension_sets: dict[frozenset[Hashable], list[Hashable]] = defaultdict(list) for name, variable in dataset.data_vars.items(): if depth_dimension not in variable.dims: continue # Skip data variables without this depth dimension @@ -195,7 +196,7 @@ def _find_ocean_floor_indices( def normalize_depth_variables( dataset: xarray.Dataset, - depth_variables: List[Hashable], + depth_variables: list[Hashable], *, positive_down: Optional[bool] = None, deep_to_shallow: Optional[bool] = None, diff --git a/src/emsarray/operations/geometry.py b/src/emsarray/operations/geometry.py index db77ebd7..1034646a 100644 --- a/src/emsarray/operations/geometry.py +++ b/src/emsarray/operations/geometry.py @@ -5,10 +5,9 @@ import json import os import pathlib +from collections.abc import Generator, Iterable, Iterator from contextlib import contextmanager -from typing import ( - IO, Any, Generator, Iterable, Iterator, Optional, TypeVar, Union -) +from typing import IO, Any, Optional, TypeVar, Union import geojson import shapefile diff --git a/src/emsarray/operations/point_extraction.py b/src/emsarray/operations/point_extraction.py index af89d273..58e24bab 100644 --- a/src/emsarray/operations/point_extraction.py +++ b/src/emsarray/operations/point_extraction.py @@ -14,7 +14,8 @@ :ref:`emsarray extract-points` is a command line interface to :func:`.extract_dataframe`. """ import dataclasses -from typing import Any, Hashable, List, Literal, Tuple +from collections.abc import Hashable +from typing import Any, Literal import numpy import pandas @@ -35,7 +36,7 @@ class NonIntersectingPoints(ValueError): indices: numpy.ndarray #: The non-intersecting points - points: List[shapely.Point] + points: list[shapely.Point] def __post_init__(self) -> None: super().__init__(f"{self.points[0].wkt} does not intersect the dataset geometry") @@ -55,7 +56,7 @@ def _dataframe_to_dataset( def extract_points( dataset: xarray.Dataset, - points: List[shapely.Point], + points: list[shapely.Point], *, point_dimension: Hashable = 'point', missing_points: Literal['error', 'drop'] = 'error', @@ -128,7 +129,7 @@ def extract_points( def extract_dataframe( dataset: xarray.Dataset, dataframe: pandas.DataFrame, - coordinate_columns: Tuple[str, str], + coordinate_columns: tuple[str, str], *, point_dimension: Hashable = 'point', missing_points: Literal['error', 'drop', 'fill'] = 'error', diff --git a/src/emsarray/operations/triangulate.py b/src/emsarray/operations/triangulate.py index 62cce393..d13c0d38 100644 --- a/src/emsarray/operations/triangulate.py +++ b/src/emsarray/operations/triangulate.py @@ -1,18 +1,18 @@ """ Operations for making a triangular mesh out of the polygons of a dataset. """ -from typing import List, Tuple, cast +from typing import cast import xarray from shapely.geometry import LineString, MultiPoint, Polygon -Vertex = Tuple[float, float] -Triangle = Tuple[int, int, int] +Vertex = tuple[float, float] +Triangle = tuple[int, int, int] def triangulate_dataset( dataset: xarray.Dataset, -) -> Tuple[List[Vertex], List[Triangle], List[int]]: +) -> tuple[list[Vertex], list[Triangle], list[int]]: """ Triangulate the polygon cells of a dataset @@ -89,7 +89,7 @@ def triangulate_dataset( # Getting all the vertices is easy - extract them from the polygons. # By going through a set, this will deduplicate the vertices. # Back to a list and we have a stable order - vertices: List[Vertex] = list({ + vertices: list[Vertex] = list({ vertex for polygon in polygons if polygon is not None @@ -114,13 +114,13 @@ def triangulate_dataset( for polygon, dataset_index in polygons_with_index for triangle_coords in _triangulate_polygon(polygon) ) - triangles: List[Triangle] = [tri for tri, index in triangles_with_index] # type: ignore + triangles: list[Triangle] = [tri for tri, index in triangles_with_index] # type: ignore indices = [index for tri, index in triangles_with_index] return (vertices, triangles, indices) -def _triangulate_polygon(polygon: Polygon) -> List[Tuple[Vertex, Vertex, Vertex]]: +def _triangulate_polygon(polygon: Polygon) -> list[tuple[Vertex, Vertex, Vertex]]: """ Triangulate a polygon. @@ -163,7 +163,7 @@ def _triangulate_polygon(polygon: Polygon) -> List[Tuple[Vertex, Vertex, Vertex] # Maintain a consistent winding order polygon = polygon.normalize() - triangles: List[Tuple[Vertex, Vertex, Vertex]] = [] + triangles: list[tuple[Vertex, Vertex, Vertex]] = [] # Note that shapely polygons with n vertices will be closed, and thus have # n+1 coordinates. We trim that superfluous coordinate off in the next line while len(polygon.exterior.coords) > 4: @@ -195,6 +195,6 @@ def _triangulate_polygon(polygon: Polygon) -> List[Tuple[Vertex, Vertex, Vertex] # The trimmed polygon is now a triangle. Add it to the list and we are done! triangles.append(cast( - Tuple[Vertex, Vertex, Vertex], + tuple[Vertex, Vertex, Vertex], tuple(map(tuple, polygon.exterior.coords[:-1])))) return triangles diff --git a/src/emsarray/plot.py b/src/emsarray/plot.py index 9255e93f..560afd00 100644 --- a/src/emsarray/plot.py +++ b/src/emsarray/plot.py @@ -1,20 +1,14 @@ -from __future__ import annotations - -from typing import ( - TYPE_CHECKING, Any, Callable, Iterable, List, Literal, Optional, Tuple, - Union -) +from collections.abc import Iterable +from typing import Any, Callable, Literal, Optional, Union import numpy import xarray +from emsarray import conventions from emsarray.exceptions import NoSuchCoordinateError from emsarray.types import Landmark from emsarray.utils import requires_extra -if TYPE_CHECKING: - from .conventions import Convention - try: import cartopy.crs from cartopy.feature import GSHHSFeature @@ -119,7 +113,7 @@ def add_landmarks( dataset = emsarray.tutorial.open_dataset('gbr4') - # Set up the figure + # set up the figure figure = pyplot.figure() axes = figure.add_subplot(projection=dataset.ems.data_crs) axes.set_title("Sea surface temperature around Mackay") @@ -160,7 +154,7 @@ def add_landmarks( text.set_path_effects([outline]) -def bounds_to_extent(bounds: Tuple[float, float, float, float]) -> List[float]: +def bounds_to_extent(bounds: tuple[float, float, float, float]) -> list[float]: """ Convert a Shapely bounds tuple to a matplotlib extents. @@ -265,10 +259,10 @@ def make_plot_title( @_requires_plot def plot_on_figure( figure: Figure, - convention: Convention, + convention: 'conventions.Convention', *, scalar: Optional[xarray.DataArray] = None, - vector: Optional[Tuple[xarray.DataArray, xarray.DataArray]] = None, + vector: Optional[tuple[xarray.DataArray, xarray.DataArray]] = None, title: Optional[str] = None, projection: Optional[cartopy.crs.Projection] = None, landmarks: Optional[Iterable[Landmark]] = None, @@ -356,11 +350,11 @@ def plot_on_figure( @_requires_plot def animate_on_figure( figure: Figure, - convention: Convention, + convention: 'conventions.Convention', *, coordinate: xarray.DataArray, scalar: Optional[xarray.DataArray] = None, - vector: Optional[Tuple[xarray.DataArray, xarray.DataArray]] = None, + vector: Optional[tuple[xarray.DataArray, xarray.DataArray]] = None, title: Optional[Union[str, Callable[[Any], str]]] = None, projection: Optional[cartopy.crs.Projection] = None, landmarks: Optional[Iterable[Landmark]] = None, @@ -491,7 +485,7 @@ def animate_on_figure( coordinate_callable = title def animate(index: int) -> Iterable[Artist]: - changes: List[Artist] = [] + changes: list[Artist] = [] coordinate_value = coordinate.values[index] axes.title.set_text(coordinate_callable(coordinate_value)) changes.append(axes.title) diff --git a/src/emsarray/state.py b/src/emsarray/state.py index 068bfb58..82ff0ebf 100644 --- a/src/emsarray/state.py +++ b/src/emsarray/state.py @@ -1,15 +1,12 @@ """ Dataclass for containing state required for emsarray """ -from __future__ import annotations - import dataclasses -from typing import TYPE_CHECKING, Final, Optional, cast +from typing import Final, Optional, cast import xarray -if TYPE_CHECKING: - from emsarray.conventions._base import Convention +from emsarray import conventions @dataclasses.dataclass @@ -20,7 +17,7 @@ class State: to avoid convention autodetection. """ dataset: xarray.Dataset - convention: Optional[Convention] = None + convention: Optional['conventions.Convention'] = None accessor_name: Final[str] = "_emsarray_state" @@ -32,7 +29,7 @@ def get(cls, dataset: xarray.Dataset) -> "State": """ return cast(State, getattr(dataset, State.accessor_name)) - def bind_convention(self, convention: Convention) -> None: + def bind_convention(self, convention: 'conventions.Convention') -> None: """ Bind a Convention instance to this Dataset. If the Dataset is already bound, an error is raised. diff --git a/src/emsarray/transect.py b/src/emsarray/transect.py index ee2840e9..22075227 100644 --- a/src/emsarray/transect.py +++ b/src/emsarray/transect.py @@ -1,10 +1,7 @@ -from __future__ import annotations - import dataclasses +from collections.abc import Hashable, Iterable from functools import cached_property -from typing import ( - Any, Callable, Generic, Hashable, Iterable, List, Optional, Tuple, Union -) +from typing import Any, Callable, Generic, Optional, Union import cfunits import numpy @@ -223,7 +220,7 @@ def transect_dataset(self) -> xarray.Dataset: }, ) - def _set_up_axis(self, variable: xarray.DataArray) -> Tuple[str, Formatter]: + def _set_up_axis(self, variable: xarray.DataArray) -> tuple[str, Formatter]: title = str(variable.attrs.get('long_name')) units: Optional[str] = variable.attrs.get('units') @@ -247,7 +244,7 @@ def _crs_for_point( @cached_property def points( self, - ) -> List[TransectPoint]: + ) -> list[TransectPoint]: """ A list of :class:`TransectPoints `, one for each point in the transect :attr:`.line`. @@ -283,7 +280,7 @@ def points( return points @cached_property - def segments(self) -> List[TransectSegment[Index]]: + def segments(self) -> list[TransectSegment[Index]]: """ A list of :class:`.TransectSegmens` for each intersecting segment of the transect line and the dataset geometry. Segments are listed in order from the start of the line to the end of the line. @@ -305,7 +302,7 @@ def segments(self) -> List[TransectSegment[Index]]: shapely.Point(intersection.coords[0]), shapely.Point(intersection.coords[-1]) ] - projections: Iterable[Tuple[shapely.Point, float]] = ( + projections: Iterable[tuple[shapely.Point, float]] = ( (point, self.distance_along_line(point)) for point in points) start, end = sorted(projections, key=lambda pair: pair[1]) @@ -326,7 +323,7 @@ def segments(self) -> List[TransectSegment[Index]]: def _intersect_polygon( self, polygon: shapely.Polygon, - ) -> List[shapely.LineString]: + ) -> list[shapely.LineString]: """ Intersect a cell of the dataset geometry with the transect line, and return a list of all LineString segments of the intersection. @@ -518,7 +515,7 @@ def prepare_data_array_for_transect(self, data_array: xarray.DataArray) -> xarra return data_array - def _find_depth_bounds(self, data_array: xarray.DataArray) -> Tuple[int, int]: + def _find_depth_bounds(self, data_array: xarray.DataArray) -> tuple[int, int]: """ Find the shallowest and deepest layers of the data array where there is at least one value per depth. @@ -565,7 +562,7 @@ def plot_on_figure( bathymetry: Optional[xarray.DataArray] = None, cmap: Union[str, Colormap] = 'jet', ocean_floor_colour: str = 'black', - landmarks: Optional[List[Landmark]] = None, + landmarks: Optional[list[Landmark]] = None, ) -> None: """ Plot the data array along this transect. @@ -625,7 +622,7 @@ def animate_on_figure( bathymetry: Optional[xarray.DataArray] = None, cmap: Union[str, Colormap] = 'jet', ocean_floor_colour: str = 'black', - landmarks: Optional[List[Landmark]] = None, + landmarks: Optional[list[Landmark]] = None, coordinate: Optional[xarray.DataArray] = None, interval: int = 200, ) -> animation.FuncAnimation: @@ -682,7 +679,7 @@ def animate_on_figure( ) def animate(index: int) -> Iterable[Artist]: - changes: List[Artist] = [] + changes: list[Artist] = [] coordinate_value = coordinate.values[index] axes.set_title(coordinate_callable(coordinate_value)) @@ -714,8 +711,8 @@ def _plot_on_figure( bathymetry: Optional[xarray.DataArray] = None, cmap: Union[str, Colormap] = 'jet', ocean_floor_colour: str = 'black', - landmarks: Optional[List[Landmark]] = None, - ) -> Tuple[Axes, PolyCollection, xarray.DataArray]: + landmarks: Optional[list[Landmark]] = None, + ) -> tuple[Axes, PolyCollection, xarray.DataArray]: """ Construct the axes and PolyCollections on a plot, and reformat the data array to the correct shape for plotting. diff --git a/src/emsarray/types.py b/src/emsarray/types.py index abc51ace..e6324861 100644 --- a/src/emsarray/types.py +++ b/src/emsarray/types.py @@ -3,7 +3,7 @@ """ import os -from typing import Tuple, Union +from typing import Union import shapely @@ -12,8 +12,8 @@ #: Bounds of a geometry or of an area. #: Components are ordered as (min x, min y, max x, max y). -Bounds = Tuple[float, float, float, float] +Bounds = tuple[float, float, float, float] #: A landmark for a plot. #: This is a tuple of the landmark name and and its location. -Landmark = Tuple[str, shapely.Point] +Landmark = tuple[str, shapely.Point] diff --git a/src/emsarray/utils.py b/src/emsarray/utils.py index fd816e45..293d38b9 100644 --- a/src/emsarray/utils.py +++ b/src/emsarray/utils.py @@ -7,8 +7,6 @@ -------- :mod:`emsarray.operations` """ -from __future__ import annotations - import datetime import functools import itertools @@ -16,11 +14,11 @@ import textwrap import time import warnings -from types import TracebackType -from typing import ( - Any, Callable, Hashable, Iterable, List, Literal, Mapping, MutableMapping, - Optional, Sequence, Tuple, Type, TypeVar, Union, cast +from collections.abc import ( + Hashable, Iterable, Mapping, MutableMapping, Sequence ) +from types import TracebackType +from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast import cftime import netCDF4 @@ -51,7 +49,7 @@ class PerfTimer: def __init__(self) -> None: self.running = False - def __enter__(self) -> PerfTimer: + def __enter__(self) -> 'PerfTimer': if self.running: raise RuntimeError("Timer is already running") self.running = True @@ -60,7 +58,7 @@ def __enter__(self) -> PerfTimer: def __exit__( self, - exc_type: Optional[Type[_Exception]], + exc_type: Optional[type[_Exception]], exc_value: Optional[_Exception], traceback: TracebackType ) -> Optional[bool]: @@ -240,7 +238,7 @@ def fix_time_units_for_ems( dataset.sync() -def _get_variables(dataset_or_array: Union[xarray.Dataset, xarray.DataArray]) -> List[xarray.Variable]: +def _get_variables(dataset_or_array: Union[xarray.Dataset, xarray.DataArray]) -> list[xarray.Variable]: if isinstance(dataset_or_array, xarray.Dataset): return list(dataset_or_array.variables.values()) else: @@ -378,7 +376,7 @@ def extract_vars( return dataset.drop_vars(drop_vars) -def pairwise(iterable: Iterable[_T]) -> Iterable[Tuple[_T, _T]]: +def pairwise(iterable: Iterable[_T]) -> Iterable[tuple[_T, _T]]: """ Iterate over values in an iterator in pairs. @@ -397,8 +395,8 @@ def pairwise(iterable: Iterable[_T]) -> Iterable[Tuple[_T, _T]]: def dimensions_from_coords( dataset: xarray.Dataset, - coordinate_names: List[Hashable], -) -> List[Hashable]: + coordinate_names: list[Hashable], +) -> list[Hashable]: """ Get the names of the dimensions for a set of coordinates. @@ -481,7 +479,7 @@ def check_data_array_dimensions_match( def move_dimensions_to_end( data_array: xarray.DataArray, - dimensions: List[Hashable], + dimensions: list[Hashable], ) -> xarray.DataArray: """ Transpose the dimensions of a :class:`xarray.DataArray` @@ -524,7 +522,7 @@ def move_dimensions_to_end( def ravel_dimensions( data_array: xarray.DataArray, - dimensions: List[Hashable], + dimensions: list[Hashable], linear_dimension: Optional[Hashable] = None, ) -> xarray.DataArray: """ @@ -692,7 +690,7 @@ def __init__(self, extra: str) -> None: def requires_extra( extra: str, import_error: Optional[ImportError], - exception_class: Type[RequiresExtraException] = RequiresExtraException, + exception_class: type[RequiresExtraException] = RequiresExtraException, ) -> Callable[[_T], _T]: if import_error is None: return lambda fn: fn @@ -741,7 +739,7 @@ def make_polygons_with_holes( return out -def deprecated(message: str, category: Type[Warning] = DeprecationWarning) -> Callable: +def deprecated(message: str, category: type[Warning] = DeprecationWarning) -> Callable: def decorator(fn: Callable) -> Callable: @functools.wraps(fn) def wrapped(*args: Any, **kwargs: Any) -> Any: @@ -751,5 +749,5 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: return decorator -def splice_tuple(t: Tuple, index: int, values: Sequence) -> Tuple: +def splice_tuple(t: tuple, index: int, values: Sequence) -> tuple: return t[:index] + tuple(values) + t[index:][1:] diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index 0cf3605d..6fa3a635 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -1,9 +1,7 @@ -from __future__ import annotations - import argparse import json from pathlib import Path -from typing import Any, List +from typing import Any import geojson import pytest @@ -70,7 +68,7 @@ def test_nice_console_errors_uncaught_exception(caplog: pytest.LogCaptureFixture (['--silent'], 0), ], ) -def test_add_verbosity_group(args: List[str], expected: int) -> None: +def test_add_verbosity_group(args: list[str], expected: int) -> None: parser = argparse.ArgumentParser() utils.add_verbosity_group(parser) options = parser.parse_args(args) diff --git a/tests/conventions/test_base.py b/tests/conventions/test_base.py index 723a4a54..14b30282 100644 --- a/tests/conventions/test_base.py +++ b/tests/conventions/test_base.py @@ -1,10 +1,9 @@ -from __future__ import annotations - import dataclasses import enum import pathlib +from collections.abc import Hashable from functools import cached_property -from typing import Dict, Hashable, List, Optional, Tuple +from typing import Optional import numpy import pandas @@ -51,12 +50,12 @@ def _get_data_array(self, data_array_or_name) -> xarray.DataArray: return data_array_or_name @cached_property - def shape(self) -> Tuple[int, int]: + def shape(self) -> tuple[int, int]: y, x = map(int, self.dataset['botz'].shape) return (y, x) @cached_property - def grid_size(self) -> Dict[SimpleGridKind, int]: + def grid_size(self) -> dict[SimpleGridKind, int]: return {SimpleGridKind.face: int(numpy.prod(self.shape))} def get_grid_kind(self, data_array: xarray.DataArray) -> SimpleGridKind: @@ -64,7 +63,7 @@ def get_grid_kind(self, data_array: xarray.DataArray) -> SimpleGridKind: return SimpleGridKind.face raise ValueError("Unknown grid type") - def get_all_geometry_names(self) -> List[Hashable]: + def get_all_geometry_names(self) -> list[Hashable]: return ['x', 'y'] def wind_index( @@ -79,7 +78,7 @@ def wind_index( def ravel_index(self, indices: SimpleGridIndex) -> int: return int(numpy.ravel_multi_index((indices.y, indices.x), self.shape)) - def selector_for_index(self, index: SimpleGridIndex) -> Dict[Hashable, int]: + def selector_for_index(self, index: SimpleGridIndex) -> dict[Hashable, int]: return {'x': index.x, 'y': index.y} def ravel( diff --git a/tests/conventions/test_cfgrid1d.py b/tests/conventions/test_cfgrid1d.py index da32bba3..24cc9797 100644 --- a/tests/conventions/test_cfgrid1d.py +++ b/tests/conventions/test_cfgrid1d.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import json import pathlib diff --git a/tests/conventions/test_cfgrid2d.py b/tests/conventions/test_cfgrid2d.py index 5c81c385..8e55f244 100644 --- a/tests/conventions/test_cfgrid2d.py +++ b/tests/conventions/test_cfgrid2d.py @@ -5,12 +5,9 @@ Instead of writing two identical test suites, the SHOC simple convention is used to test both. """ -from __future__ import annotations - import itertools import json import pathlib -from typing import Type import numpy import pandas @@ -38,7 +35,7 @@ def make_dataset( j_size: int, i_size: int, time_size: int = 4, - grid_type: Type[ShocGridGenerator] = DiagonalShocGrid, + grid_type: type[ShocGridGenerator] = DiagonalShocGrid, corner_size: int = 0, include_bounds: bool = False, ) -> xarray.Dataset: diff --git a/tests/conventions/test_registry.py b/tests/conventions/test_registry.py index 0ce672d0..cd6be356 100644 --- a/tests/conventions/test_registry.py +++ b/tests/conventions/test_registry.py @@ -3,7 +3,6 @@ """ import sys from importlib import metadata -from typing import Dict, List, Tuple import pytest @@ -62,7 +61,7 @@ class Foo: def monkeypatch_entrypoint( monkeypatch, - entry_points: Dict[str, List[Tuple[str, str]]], + entry_points: dict[str, list[tuple[str, str]]], ): _entry_points = { group: [ @@ -72,10 +71,10 @@ def monkeypatch_entrypoint( } if sys.version_info >= (3, 10): - def mocked(group: str) -> List[metadata.EntryPoint]: + def mocked(group: str) -> list[metadata.EntryPoint]: return _entry_points.get(group, []) else: - def mocked() -> List[metadata.EntryPoint]: + def mocked() -> list[metadata.EntryPoint]: return _entry_points monkeypatch.setattr(metadata, 'entry_points', mocked) diff --git a/tests/conventions/test_shoc_standard.py b/tests/conventions/test_shoc_standard.py index ac63da78..29484586 100644 --- a/tests/conventions/test_shoc_standard.py +++ b/tests/conventions/test_shoc_standard.py @@ -1,9 +1,6 @@ -from __future__ import annotations - import itertools import json import pathlib -from typing import Type import numpy import pandas @@ -28,7 +25,7 @@ def make_dataset( j_size: int, i_size: int, time_size: int = 4, - grid_type: Type[ShocGridGenerator] = DiagonalShocGrid, + grid_type: type[ShocGridGenerator] = DiagonalShocGrid, corner_size: int = 0, ) -> xarray.Dataset: """ diff --git a/tests/conventions/test_ugrid.py b/tests/conventions/test_ugrid.py index 79e411c9..c6241643 100644 --- a/tests/conventions/test_ugrid.py +++ b/tests/conventions/test_ugrid.py @@ -1,9 +1,6 @@ -from __future__ import annotations - import json import pathlib import warnings -from typing import Tuple import geojson import numpy @@ -26,7 +23,7 @@ from tests.utils import assert_property_not_cached, filter_warning -def make_faces(width: int, height, fill_value: int) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: +def make_faces(width: int, height, fill_value: int) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: triangle_nodes = sum(range(width + 2)) square_rows = height square_columns = width diff --git a/tests/masking/test_mask_dataset.py b/tests/masking/test_mask_dataset.py index d9723537..5fa33c8a 100644 --- a/tests/masking/test_mask_dataset.py +++ b/tests/masking/test_mask_dataset.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import pathlib import netCDF4 diff --git a/tests/masking/test_utils.py b/tests/masking/test_utils.py index 10868fec..1198d72f 100644 --- a/tests/masking/test_utils.py +++ b/tests/masking/test_utils.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import pathlib import netCDF4 diff --git a/tests/operations/triangulate/test_triangulate_dataset.py b/tests/operations/triangulate/test_triangulate_dataset.py index a21e85aa..16fb5c7a 100644 --- a/tests/operations/triangulate/test_triangulate_dataset.py +++ b/tests/operations/triangulate/test_triangulate_dataset.py @@ -1,6 +1,5 @@ from collections import defaultdict from functools import reduce -from typing import List, Tuple import numpy import pytest @@ -86,9 +85,9 @@ def test_triangulate_dataset_ugrid(datasets): def check_triangulation( dataset: xarray.Dataset, - vertices: List[Tuple[float, float]], - triangles: List[Tuple[int, int, int]], - cell_indices: List[int], + vertices: list[tuple[float, float]], + triangles: list[tuple[int, int, int]], + cell_indices: list[int], ): """ Check the triangulation of a dataset by reconstructing all polygons. diff --git a/tests/utils.py b/tests/utils.py index b8f07984..80a97326 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,11 +1,10 @@ -from __future__ import annotations - import abc import contextlib import itertools import warnings +from collections.abc import Hashable from functools import cached_property -from typing import Any, Dict, Hashable, List, Optional, Tuple +from typing import Any, Optional import numpy import shapely @@ -40,7 +39,7 @@ def box(minx, miny, maxx, maxy) -> shapely.Polygon: ]) -def reduce_axes(arr: numpy.ndarray, axes: Optional[Tuple[bool, ...]] = None) -> numpy.ndarray: +def reduce_axes(arr: numpy.ndarray, axes: Optional[tuple[bool, ...]] = None) -> numpy.ndarray: """ Reduce the size of an array by one on an axis-by-axis basis. If an axis is reduced, neigbouring values are averaged together @@ -55,7 +54,7 @@ def reduce_axes(arr: numpy.ndarray, axes: Optional[Tuple[bool, ...]] = None) -> return numpy.mean([arr[tuple(p)] for p in itertools.product(*axes_slices)], axis=0) # type: ignore -def mask_from_strings(mask_strings: List[str]) -> numpy.ndarray: +def mask_from_strings(mask_strings: list[str]) -> numpy.ndarray: """ Make a boolean mask array from a list of strings: @@ -76,7 +75,7 @@ def __init__(self, *, k: int): self.k_size = k @property - def standard_vars(self) -> Dict[Hashable, xarray.DataArray]: + def standard_vars(self) -> dict[Hashable, xarray.DataArray]: return { "z_grid": xarray.DataArray( data=self.z_grid, @@ -99,7 +98,7 @@ def standard_vars(self) -> Dict[Hashable, xarray.DataArray]: } @property - def simple_coords(self) -> Dict[Hashable, xarray.DataArray]: + def simple_coords(self) -> dict[Hashable, xarray.DataArray]: return { "zc": xarray.DataArray( data=self.z_centre, @@ -169,7 +168,7 @@ def simple_mask(self) -> xarray.Dataset: }) @property - def standard_vars(self) -> Dict[Hashable, xarray.DataArray]: + def standard_vars(self) -> dict[Hashable, xarray.DataArray]: return { "x_grid": xarray.DataArray( data=self.x_grid, @@ -254,7 +253,7 @@ def standard_vars(self) -> Dict[Hashable, xarray.DataArray]: } @property - def simple_vars(self) -> Dict[str, xarray.DataArray]: + def simple_vars(self) -> dict[str, xarray.DataArray]: simple_vars = {} if self.include_bounds: simple_vars.update({ @@ -280,7 +279,7 @@ def simple_vars(self) -> Dict[str, xarray.DataArray]: return simple_vars @property - def simple_coords(self) -> Dict[Hashable, xarray.DataArray]: + def simple_coords(self) -> dict[Hashable, xarray.DataArray]: return { "longitude": xarray.DataArray( data=self.x_centre,