diff --git a/docs/developing/grass.py b/docs/developing/grass.py index e3238082..7470cf11 100644 --- a/docs/developing/grass.py +++ b/docs/developing/grass.py @@ -2,7 +2,6 @@ import enum from collections.abc import Hashable, Sequence from functools import cached_property -from typing import Optional import numpy import xarray @@ -33,7 +32,7 @@ class Grass(DimensionConvention[GrassGridKind, GrassIndex]): default_grid_kind = GrassGridKind.field @classmethod - def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]: + def check_dataset(cls, dataset: xarray.Dataset) -> int | None: # A Grass dataset is recognised by the 'Conventions' global attribute if dataset.attrs['Conventions'] == 'Grass 1.0': return Specificity.HIGH diff --git a/docs/roles.py b/docs/roles.py index c29279e3..dcae1549 100644 --- a/docs/roles.py +++ b/docs/roles.py @@ -1,6 +1,6 @@ import re -from collections.abc import Iterable -from typing import Callable, cast +from collections.abc import Callable, Iterable +from typing import cast import yaml from docutils import nodes, utils @@ -81,7 +81,7 @@ class Citation(Directive): def load_citation_file(self) -> dict: citation_file = self.options['citation_file'] - with open(citation_file, 'r') as f: + with open(citation_file) as f: return cast(dict, yaml.load(f, yaml.Loader)) def run(self) -> list[nodes.Node]: diff --git a/scripts/min_deps_check.py b/scripts/min_deps_check.py index 1d6fbc48..86a5a794 100755 --- a/scripts/min_deps_check.py +++ b/scripts/min_deps_check.py @@ -55,7 +55,7 @@ def parse_requirements( Yield (package name, major version, minor version, patch version) """ - for line_number, line in enumerate(open(fname, 'r'), start=1): + for line_number, line in enumerate(open(fname), start=1): if '#' in line: line = line[:line.index('#')] line = line.strip() diff --git a/scripts/release.py b/scripts/release.py index 57682339..b9211941 100644 --- a/scripts/release.py +++ b/scripts/release.py @@ -6,7 +6,6 @@ import shlex import subprocess import sys -from typing import Optional PROJECT = pathlib.Path(__file__).parent.parent @@ -29,7 +28,7 @@ def main( - args: Optional[list[str]] = None, + args: list[str] | None = None, ) -> None: parser = argparse.ArgumentParser() add_options(parser) @@ -222,7 +221,7 @@ def output(*args: str) -> bytes: def yn( prompt: str, - default: Optional[bool] = None, + default: bool | None = None, ) -> bool: examples = {True: '[Yn]', False: '[yN]', None: '[yn]'}[default] prompt = f'{prompt.strip()} {examples} ' diff --git a/src/emsarray/cli/command.py b/src/emsarray/cli/command.py index cec483b0..ef774ad1 100644 --- a/src/emsarray/cli/command.py +++ b/src/emsarray/cli/command.py @@ -1,6 +1,5 @@ import abc import argparse -from typing import Optional from emsarray.cli import utils @@ -24,11 +23,11 @@ def name(self) -> str: #: A short description of what this subcommand does, #: shown as part of the usage message for the base command. - help: Optional[str] = None + help: str | None = None #: A longer description of what this subcommand does, #: shown as part of the usage message for this subcommand. - description: Optional[str] = None + description: str | None = None def add_parser(self, subparsers: argparse._SubParsersAction) -> None: parser = subparsers.add_parser( diff --git a/src/emsarray/cli/commands/export_geometry.py b/src/emsarray/cli/commands/export_geometry.py index 4a0a860c..21d9a9c9 100644 --- a/src/emsarray/cli/commands/export_geometry.py +++ b/src/emsarray/cli/commands/export_geometry.py @@ -1,7 +1,7 @@ import argparse import logging +from collections.abc import Callable from pathlib import Path -from typing import Callable import xarray diff --git a/src/emsarray/cli/commands/plot.py b/src/emsarray/cli/commands/plot.py index 5266f44c..ec824a8a 100644 --- a/src/emsarray/cli/commands/plot.py +++ b/src/emsarray/cli/commands/plot.py @@ -1,8 +1,9 @@ import argparse import functools import logging +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable, Optional, TypeVar +from typing import Any, TypeVar import emsarray from emsarray.cli import BaseCommand, CommandException @@ -28,7 +29,7 @@ def __init__( dest: str, *, value_type: Callable = str, - default: Optional[dict[str, Any]] = None, + default: dict[str, Any] | None = None, **kwargs: Any, ): if default is None: @@ -42,7 +43,7 @@ def __call__( parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, - option_string: Optional[str] = None, + option_string: str | None = None, ) -> None: super().__call__ holder = getattr(namespace, self.dest, {}) diff --git a/src/emsarray/cli/utils.py b/src/emsarray/cli/utils.py index 446dcdf8..34c3e285 100644 --- a/src/emsarray/cli/utils.py +++ b/src/emsarray/cli/utils.py @@ -9,10 +9,10 @@ import re import sys import textwrap -from collections.abc import Iterator +from collections.abc import Callable, Iterator from functools import wraps from pathlib import Path -from typing import Callable, Optional, Protocol +from typing import Protocol from shapely.geometry import box, shape from shapely.geometry.base import BaseGeometry @@ -30,7 +30,7 @@ class MainCallable(Protocol): def __call__( self, - argv: Optional[list[str]] = None, + argv: list[str] | None = None, handle_errors: bool = True, ) -> None: ... @@ -112,7 +112,7 @@ def decorator( ) -> MainCallable: @wraps(fn) def wrapper( - argv: Optional[list[str]] = None, + argv: list[str] | None = None, handle_errors: bool = True, ) -> None: parser = argparse.ArgumentParser( @@ -172,7 +172,7 @@ def nice_console_errors() -> Iterator: class DoubleNewlineDescriptionFormatter(argparse.HelpFormatter): def _fill_text(self, text: str, width: int, indent: str) -> str: - fill_text = super(DoubleNewlineDescriptionFormatter, self)._fill_text + fill_text = super()._fill_text return '\n\n'.join( fill_text(paragraph, width, indent) diff --git a/src/emsarray/compat/shapely.py b/src/emsarray/compat/shapely.py index 1fe8d2c2..d5cf4295 100644 --- a/src/emsarray/compat/shapely.py +++ b/src/emsarray/compat/shapely.py @@ -1,6 +1,6 @@ import warnings from collections.abc import Iterable -from typing import Generic, TypeVar, Union, cast +from typing import Generic, TypeVar, cast import numpy import shapely @@ -31,7 +31,7 @@ class SpatialIndex(Generic[T]): def __init__( self, - items: Union[numpy.ndarray, Iterable[tuple[BaseGeometry, T]]], + items: numpy.ndarray | Iterable[tuple[BaseGeometry, T]], ): self.items = numpy.array(items, dtype=self.dtype) diff --git a/src/emsarray/conventions/_base.py b/src/emsarray/conventions/_base.py index a0ba6f82..d43d5eb7 100644 --- a/src/emsarray/conventions/_base.py +++ b/src/emsarray/conventions/_base.py @@ -3,11 +3,9 @@ import enum import logging import warnings -from collections.abc import Hashable, Iterable, Sequence +from collections.abc import Callable, Hashable, Iterable, Sequence from functools import cached_property -from typing import ( - TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union, cast -) +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast import numpy import xarray @@ -166,7 +164,7 @@ def check_validity(cls, dataset: xarray.Dataset) -> None: @classmethod @abc.abstractmethod - def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]: + def check_dataset(cls, dataset: xarray.Dataset) -> int | None: """ Check if a dataset uses this convention. @@ -582,7 +580,7 @@ def wind_index( self, linear_index: int, *, - grid_kind: Optional[GridKind] = None, + grid_kind: GridKind | None = None, ) -> Index: """Convert a linear index to a conventnion native index. @@ -635,7 +633,7 @@ def wind_index( def unravel_index( self, linear_index: int, - grid_kind: Optional[GridKind] = None, + grid_kind: GridKind | None = None, ) -> Index: """An alias for :meth:`Convention.wind_index()`. @@ -772,7 +770,7 @@ def ravel( self, data_array: xarray.DataArray, *, - linear_dimension: Optional[Hashable] = None, + linear_dimension: Hashable | None = None, ) -> xarray.DataArray: """ Flatten the surface dimensions of a :class:`~xarray.DataArray`, @@ -815,9 +813,9 @@ def wind( self, data_array: xarray.DataArray, *, - grid_kind: Optional[GridKind] = None, - axis: Optional[int] = None, - linear_dimension: Optional[Hashable] = None, + grid_kind: GridKind | None = None, + axis: int | None = None, + linear_dimension: Hashable | None = None, ) -> xarray.DataArray: """ Wind a flattened :class:`~xarray.DataArray` @@ -935,9 +933,9 @@ def data_crs(self) -> 'CRS': def plot_on_figure( self, figure: 'Figure', - scalar: Optional[DataArrayOrName] = None, - vector: Optional[tuple[DataArrayOrName, DataArrayOrName]] = None, - title: Optional[str] = None, + scalar: DataArrayOrName | None = None, + vector: tuple[DataArrayOrName, DataArrayOrName] | None = None, + title: str | None = None, **kwargs: Any, ) -> None: """Plot values for a :class:`~xarray.DataArray` @@ -1015,10 +1013,10 @@ def plot(self, *args: Any, **kwargs: Any) -> None: def animate_on_figure( self, figure: 'Figure', - scalar: Optional[DataArrayOrName] = None, - vector: Optional[tuple[DataArrayOrName, DataArrayOrName]] = None, - coordinate: Optional[DataArrayOrName] = None, - title: Optional[Union[str, Callable[[Any], str]]] = None, + scalar: DataArrayOrName | None = None, + vector: tuple[DataArrayOrName, DataArrayOrName] | None = None, + coordinate: DataArrayOrName | None = None, + title: str | Callable[[Any], str] | None = None, **kwargs: Any, ) -> 'FuncAnimation': """ @@ -1115,7 +1113,7 @@ def animate_on_figure( @utils.timed_func def make_poly_collection( self, - data_array: Optional[DataArrayOrName] = None, + data_array: DataArrayOrName | None = None, **kwargs: Any, ) -> 'PolyCollection': """ @@ -1192,7 +1190,7 @@ def make_poly_collection( def make_patch_collection( self, - data_array: Optional[DataArrayOrName] = None, + data_array: DataArrayOrName | None = None, **kwargs: Any, ) -> 'PolyCollection': warnings.warn( @@ -1206,8 +1204,8 @@ def make_patch_collection( def make_quiver( self, axes: 'Axes', - u: Optional[DataArrayOrName] = None, - v: Optional[DataArrayOrName] = None, + u: DataArrayOrName | None = None, + v: DataArrayOrName | None = None, **kwargs: Any, ) -> 'Quiver': """ @@ -1238,7 +1236,7 @@ def make_quiver( # sometimes preferring to fill them in later, # so `u` and `v` are optional. # If they are not provided, we set default quiver values of `numpy.nan`. - values: Union[tuple[numpy.ndarray, numpy.ndarray], tuple[float, float]] + values: tuple[numpy.ndarray, numpy.ndarray] | tuple[float, float] values = numpy.nan, numpy.nan if u is not None and v is not None: @@ -1331,7 +1329,7 @@ def mask(self) -> numpy.ndarray: return cast(numpy.ndarray, mask) @cached_property - def geometry(self) -> Union[Polygon, MultiPolygon]: + def geometry(self) -> Polygon | MultiPolygon: """ A :class:`shapely.Polygon` or :class:`shapely.MultiPolygon` that represents the geometry of the entire dataset. @@ -1438,7 +1436,7 @@ def spatial_index(self) -> SpatialIndex[SpatialIndexItem[Index]]: def get_index_for_point( self, point: Point, - ) -> Optional[SpatialIndexItem[Index]]: + ) -> SpatialIndexItem[Index] | None: """ Find the index for a :class:`~shapely.Point` in the dataset. @@ -1761,8 +1759,8 @@ def ocean_floor(self) -> xarray.Dataset: def normalize_depth_variables( self, *, - positive_down: Optional[bool] = None, - deep_to_shallow: Optional[bool] = None, + positive_down: bool | None = None, + deep_to_shallow: bool | None = None, ) -> xarray.Dataset: """An alias for :func:`emsarray.operations.depth.normalize_depth_variables`""" return depth.normalize_depth_variables( @@ -1895,7 +1893,7 @@ def wind_index( self, linear_index: int, *, - grid_kind: Optional[GridKind] = None, + grid_kind: GridKind | None = None, ) -> Index: if grid_kind is None: grid_kind = self.default_grid_kind @@ -1907,7 +1905,7 @@ def ravel( self, data_array: xarray.DataArray, *, - linear_dimension: Optional[Hashable] = None, + linear_dimension: Hashable | None = None, ) -> xarray.DataArray: kind = self.get_grid_kind(data_array) dimensions = self.grid_dimensions[kind] @@ -1919,9 +1917,9 @@ def wind( self, data_array: xarray.DataArray, *, - grid_kind: Optional[GridKind] = None, - axis: Optional[int] = None, - linear_dimension: Optional[Hashable] = None, + grid_kind: GridKind | None = None, + axis: int | None = None, + linear_dimension: Hashable | None = None, ) -> xarray.DataArray: if axis is not None: linear_dimension = data_array.dims[axis] diff --git a/src/emsarray/conventions/_registry.py b/src/emsarray/conventions/_registry.py index c70d5d14..dfbb65ac 100644 --- a/src/emsarray/conventions/_registry.py +++ b/src/emsarray/conventions/_registry.py @@ -1,12 +1,10 @@ import logging -import sys import warnings from collections.abc import Iterable from contextlib import suppress from functools import cached_property from importlib import metadata from itertools import chain -from typing import Optional import xarray @@ -94,7 +92,7 @@ def match_conventions(self, dataset: xarray.Dataset) -> list[tuple[type[Conventi matches.append((convention, match)) return sorted(matches, key=lambda m: m[1], reverse=True) - def guess_convention(self, dataset: xarray.Dataset) -> Optional[type[Convention]]: + def guess_convention(self, dataset: xarray.Dataset) -> type[Convention] | None: """ Guess the correct :class:`.Convention` implementation for a dataset. """ @@ -109,7 +107,7 @@ def guess_convention(self, dataset: xarray.Dataset) -> Optional[type[Convention] registry = ConventionRegistry() -def get_dataset_convention(dataset: xarray.Dataset) -> Optional[type[Convention]]: +def get_dataset_convention(dataset: xarray.Dataset) -> type[Convention] | None: """Find the most appropriate Convention subclass for this dataset. Parameters @@ -146,10 +144,7 @@ def entry_point_conventions() -> Iterable[type[Convention]]: ('emsarray.formats', True), ] for group, deprecated in groups: - if sys.version_info >= (3, 10): - entry_points = metadata.entry_points(group=group) - else: - entry_points = metadata.entry_points().get(group, []) + entry_points = metadata.entry_points(group=group) for entry_point in entry_points: if deprecated: diff --git a/src/emsarray/conventions/arakawa_c.py b/src/emsarray/conventions/arakawa_c.py index 1ea25d1d..150ed6b0 100644 --- a/src/emsarray/conventions/arakawa_c.py +++ b/src/emsarray/conventions/arakawa_c.py @@ -10,7 +10,7 @@ import logging from collections.abc import Hashable, Sequence from functools import cached_property -from typing import Optional, cast +from typing import cast import numpy import xarray @@ -168,7 +168,7 @@ def __init__( self, dataset: xarray.Dataset, *, - coordinate_names: Optional[dict[Hashable, tuple[Hashable, Hashable]]] = None, + coordinate_names: dict[Hashable, tuple[Hashable, Hashable]] | None = None, ): super().__init__(dataset) @@ -192,7 +192,7 @@ def __init__( ) @classmethod - def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]: + def check_dataset(cls, dataset: xarray.Dataset) -> int | None: if not hasattr(cls, 'coordinate_names'): return None @@ -330,7 +330,7 @@ def apply_clip_mask(self, clip_mask: xarray.Dataset, work_dir: Pathish) -> xarra def c_mask_from_centres( face_mask: numpy.ndarray, dimensions: ArakawaCDimensions, - coords: Optional[DatasetCoordinates] = None, + coords: DatasetCoordinates | None = None, ) -> xarray.Dataset: """ Create a mask for a SHOC standard file given a mask array for the cell diff --git a/src/emsarray/conventions/grid.py b/src/emsarray/conventions/grid.py index ae0b278e..d45fd231 100644 --- a/src/emsarray/conventions/grid.py +++ b/src/emsarray/conventions/grid.py @@ -9,7 +9,7 @@ from collections.abc import Hashable, Sequence from contextlib import suppress from functools import cached_property -from typing import Generic, Optional, TypeVar, cast +from typing import Generic, TypeVar, cast import numpy import xarray @@ -51,8 +51,8 @@ class CFGridTopology(abc.ABC): def __init__( self, dataset: xarray.Dataset, - longitude: Optional[Hashable] = None, - latitude: Optional[Hashable] = None, + longitude: Hashable | None = None, + latitude: Hashable | None = None, ): """ Construct a new :class:`CFGridTopology` instance for a dataset. @@ -200,9 +200,9 @@ def __init__( self, dataset: xarray.Dataset, *, - latitude: Optional[Hashable] = None, - longitude: Optional[Hashable] = None, - topology: Optional[Topology] = None, + latitude: Hashable | None = None, + longitude: Hashable | None = None, + topology: Topology | None = None, ) -> None: """ Construct a new :class:`CFGrid` instance. @@ -270,7 +270,7 @@ def get_all_geometry_names(self) -> list[Hashable]: self.topology.latitude_name, ] - bounds_names: list[Optional[Hashable]] = [ + bounds_names: list[Hashable | None] = [ self.topology.longitude.attrs.get('bounds', None), self.topology.latitude.attrs.get('bounds', None), ] @@ -387,7 +387,7 @@ class CFGrid1D(CFGrid[CFGrid1DTopology]): topology_class = CFGrid1DTopology @classmethod - def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]: + def check_dataset(cls, dataset: xarray.Dataset) -> int | None: """ A dataset is a 1D CF grid if it has one dimensional latitude and longitude coordinate variables. @@ -543,7 +543,7 @@ class CFGrid2D(CFGrid[CFGrid2DTopology]): topology_class = CFGrid2DTopology @classmethod - def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]: + def check_dataset(cls, dataset: xarray.Dataset) -> int | None: """ A dataset is a 2D CF grid if it has two dimensional latitude and longitude coordinate variables. diff --git a/src/emsarray/conventions/shoc.py b/src/emsarray/conventions/shoc.py index 1483499e..cb205711 100644 --- a/src/emsarray/conventions/shoc.py +++ b/src/emsarray/conventions/shoc.py @@ -17,7 +17,6 @@ import logging from collections.abc import Hashable from functools import cached_property -from typing import Optional import xarray @@ -108,7 +107,7 @@ def topology(self) -> CFGrid2DTopology: return CFGrid2DTopology(self.dataset, latitude=latitude, longitude=longitude) @classmethod - def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]: + def check_dataset(cls, dataset: xarray.Dataset) -> int | None: if 'ems_version' not in dataset.attrs: return None if not set(dataset.dims).issuperset(cls._dimensions): diff --git a/src/emsarray/conventions/ugrid.py b/src/emsarray/conventions/ugrid.py index a81fb631..0d74bd4b 100644 --- a/src/emsarray/conventions/ugrid.py +++ b/src/emsarray/conventions/ugrid.py @@ -14,7 +14,7 @@ from contextlib import suppress from dataclasses import dataclass from functools import cached_property -from typing import Any, Optional, cast +from typing import Any, cast import numpy import shapely @@ -353,7 +353,7 @@ class Mesh2DTopology: #: The name of the mesh topology variable. Optional. If not provided, the #: mesh topology dummy variable will be found by checking the ``cf_role`` #: attribute. - topology_key: Optional[Hashable] = None + topology_key: Hashable | None = None #: The default dtype to use for index data arrays. Hard coded to ``int32``, #: which should be sufficient for all datasets. ``int16`` is too small for @@ -444,7 +444,7 @@ def node_y(self) -> xarray.DataArray: return self.dataset.data_vars[self._node_coordinates[1]] @property - def edge_x(self) -> Optional[xarray.DataArray]: + def edge_x(self) -> xarray.DataArray | None: """Data array of characteristic edge X / longitude coordinates. Optional.""" try: return self.dataset.data_vars[self._edge_coordinates[0]] @@ -452,7 +452,7 @@ def edge_x(self) -> Optional[xarray.DataArray]: return None @property - def edge_y(self) -> Optional[xarray.DataArray]: + def edge_y(self) -> xarray.DataArray | None: """Data array of characteristic edge y / latitude coordinates. Optional.""" try: return self.dataset.data_vars[self._edge_coordinates[1]] @@ -460,7 +460,7 @@ def edge_y(self) -> Optional[xarray.DataArray]: return None @property - def face_x(self) -> Optional[xarray.DataArray]: + def face_x(self) -> xarray.DataArray | None: """Data array of characteristic face x / longitude coordinates. Optional.""" try: return self.dataset.data_vars[self._face_coordinates[0]] @@ -468,7 +468,7 @@ def face_x(self) -> Optional[xarray.DataArray]: return None @property - def face_y(self) -> Optional[xarray.DataArray]: + def face_y(self) -> xarray.DataArray | None: """Data array of characteristic face y / latitude coordinates. Optional.""" try: return self.dataset.data_vars[self._face_coordinates[1]] @@ -1024,7 +1024,7 @@ class UGrid(DimensionConvention[UGridKind, UGridIndex]): default_grid_kind = UGridKind.face @classmethod - def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]: + def check_dataset(cls, dataset: xarray.Dataset) -> int | None: """ A UGrid dataset needs a global attribute of Conventions = 'UGRID/...', and a variable with attribute cf_role = 'mesh_topology' diff --git a/src/emsarray/formats.py b/src/emsarray/formats.py index 0b2bb8a4..e5efd92d 100644 --- a/src/emsarray/formats.py +++ b/src/emsarray/formats.py @@ -1,6 +1,6 @@ import warnings from functools import wraps -from typing import Any, Optional +from typing import Any import xarray @@ -17,7 +17,7 @@ def _warn_old_new(old: str, new: str, **kwargs: Any) -> None: @wraps(get_dataset_convention) -def get_file_format(dataset: xarray.Dataset, **kwargs: Any) -> Optional[type[Convention]]: +def get_file_format(dataset: xarray.Dataset, **kwargs: Any) -> type[Convention] | None: _warn_old_new( old="emsarray.formats.get_file_format", new="emsarray.conventions.get_dataset_convention", diff --git a/src/emsarray/nco.py b/src/emsarray/nco.py index 04c5ca56..3f767ed7 100644 --- a/src/emsarray/nco.py +++ b/src/emsarray/nco.py @@ -11,9 +11,9 @@ import subprocess from collections.abc import Sequence from pathlib import Path -from typing import Any, Optional, Union +from typing import Any -Pathish = Union[Path, str] +from emsarray.types import Pathish def _check_call(cmd: Sequence[str], stdin: Any = subprocess.DEVNULL, **kwargs: Any) -> None: @@ -23,7 +23,7 @@ def _check_call(cmd: Sequence[str], stdin: Any = subprocess.DEVNULL, **kwargs: A def ncrcat( input_files: Sequence[Pathish], output_file: Pathish, - flags: Optional[str] = None, + flags: str | None = None, history: bool = False, ) -> None: """Concatenates a set of netCDF files together using `ncrcat`.""" diff --git a/src/emsarray/operations/depth.py b/src/emsarray/operations/depth.py index 3ab97651..4473cf3f 100644 --- a/src/emsarray/operations/depth.py +++ b/src/emsarray/operations/depth.py @@ -4,8 +4,8 @@ """ import warnings from collections import defaultdict -from collections.abc import Hashable -from typing import Iterable, Optional, cast +from collections.abc import Hashable, Iterable +from typing import cast import numpy import xarray @@ -18,7 +18,7 @@ def ocean_floor( dataset: xarray.Dataset, depth_coordinates: Iterable[DataArrayOrName], *, - non_spatial_variables: Optional[Iterable[DataArrayOrName]] = None, + non_spatial_variables: Iterable[DataArrayOrName] | None = None, ) -> xarray.Dataset: """Make a new :class:`xarray.Dataset` reduced along the given depth coordinates to only contain values along the ocean floor. @@ -200,8 +200,8 @@ def normalize_depth_variables( dataset: xarray.Dataset, depth_coordinates: Iterable[DataArrayOrName], *, - positive_down: Optional[bool] = None, - deep_to_shallow: Optional[bool] = None, + positive_down: bool | None = None, + deep_to_shallow: bool | None = None, ) -> xarray.Dataset: """ Some datasets represent depth as a positive variable, some as negative. diff --git a/src/emsarray/operations/geometry.py b/src/emsarray/operations/geometry.py index 1034646a..2ef472dd 100644 --- a/src/emsarray/operations/geometry.py +++ b/src/emsarray/operations/geometry.py @@ -7,7 +7,7 @@ import pathlib from collections.abc import Generator, Iterable, Iterator from contextlib import contextmanager -from typing import IO, Any, Optional, TypeVar, Union +from typing import IO, Any, TypeVar import geojson import shapefile @@ -108,7 +108,7 @@ def write_geojson( @contextmanager -def _maybe_open(path_or_file: Union[Pathish, IO], mode: str) -> Generator[IO, None, None]: +def _maybe_open(path_or_file: Pathish | IO, mode: str) -> Generator[IO, None, None]: """ Given either a path to a file or an open file handle, return an open file handle wrapped in a context manager. @@ -128,12 +128,12 @@ def _maybe_open(path_or_file: Union[Pathish, IO], mode: str) -> Generator[IO, No def write_shapefile( dataset: xarray.Dataset, - target: Optional[Pathish] = None, + target: Pathish | None = None, *, - shp: Optional[Union[Pathish, IO]] = None, - shx: Optional[Union[Pathish, IO]] = None, - dbf: Optional[Union[Pathish, IO]] = None, - prj: Optional[Union[Pathish, IO]] = None, + shp: Pathish | IO | None = None, + shx: Pathish | IO | None = None, + dbf: Pathish | IO | None = None, + prj: Pathish | IO | None = None, **kwargs: Any, ) -> None: """ diff --git a/src/emsarray/plot.py b/src/emsarray/plot.py index 6cac2b28..348658d1 100644 --- a/src/emsarray/plot.py +++ b/src/emsarray/plot.py @@ -1,6 +1,6 @@ import importlib.metadata -from collections.abc import Iterable -from typing import Any, Callable, Literal, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any, Literal import numpy import packaging.version @@ -224,7 +224,7 @@ def polygons_to_collection( def make_plot_title( dataset: xarray.Dataset, data_array: xarray.DataArray, -) -> Optional[str]: +) -> str | None: """ Make a suitable plot title for a variable. This will attempt to find a name for the variable by looking through the attributes. @@ -268,11 +268,11 @@ def plot_on_figure( figure: Figure, convention: 'conventions.Convention', *, - scalar: Optional[xarray.DataArray] = None, - vector: Optional[tuple[xarray.DataArray, xarray.DataArray]] = None, - title: Optional[str] = None, - projection: Optional[cartopy.crs.Projection] = None, - landmarks: Optional[Iterable[Landmark]] = None, + scalar: xarray.DataArray | None = None, + vector: tuple[xarray.DataArray, xarray.DataArray] | None = None, + title: str | None = None, + projection: cartopy.crs.Projection | None = None, + landmarks: Iterable[Landmark] | None = None, gridlines: bool = True, coast: bool = True, ) -> None: @@ -360,15 +360,15 @@ def animate_on_figure( convention: 'conventions.Convention', *, coordinate: xarray.DataArray, - scalar: Optional[xarray.DataArray] = None, - vector: Optional[tuple[xarray.DataArray, xarray.DataArray]] = None, - title: Optional[Union[str, Callable[[Any], str]]] = None, - projection: Optional[cartopy.crs.Projection] = None, - landmarks: Optional[Iterable[Landmark]] = None, + scalar: xarray.DataArray | None = None, + vector: tuple[xarray.DataArray, xarray.DataArray] | None = None, + title: str | Callable[[Any], str] | None = None, + projection: cartopy.crs.Projection | None = None, + landmarks: Iterable[Landmark] | None = None, gridlines: bool = True, coast: bool = True, interval: int = 1000, - repeat: Union[bool, Literal['cycle', 'bounce']] = True, + repeat: bool | Literal['cycle', 'bounce'] = True, ) -> animation.FuncAnimation: """ Plot a :class:`xarray.DataArray` diff --git a/src/emsarray/transect.py b/src/emsarray/transect.py index e3d92fee..9bba29b9 100644 --- a/src/emsarray/transect.py +++ b/src/emsarray/transect.py @@ -1,7 +1,7 @@ import dataclasses -from collections.abc import Iterable +from collections.abc import Callable, Iterable from functools import cached_property -from typing import Any, Callable, Generic, Optional, Union, cast +from typing import Any, Generic, cast import cfunits import numpy @@ -117,7 +117,7 @@ def __init__( self, dataset: xarray.Dataset, line: shapely.LineString, - depth: Optional[DataArrayOrName] = None, + depth: DataArrayOrName | None = None, ): self.dataset = dataset self.convention = dataset.ems @@ -223,7 +223,7 @@ def transect_dataset(self) -> xarray.Dataset: def _set_up_axis(self, variable: xarray.DataArray) -> tuple[str, Formatter]: title = str(variable.attrs.get('long_name')) - units: Optional[str] = variable.attrs.get('units') + units: str | None = variable.attrs.get('units') if units is not None: # Use cfunits to normalize the units to their short symbol form. @@ -237,7 +237,7 @@ def _set_up_axis(self, variable: xarray.DataArray) -> tuple[str, Formatter]: def _crs_for_point( self, point: shapely.Point, - globe: Optional[crs.Globe] = None, + globe: crs.Globe | None = None, ) -> crs.Projection: return crs.AzimuthalEquidistant( central_longitude=point.x, central_latitude=point.y, globe=globe) @@ -557,13 +557,13 @@ def plot_on_figure( figure: Figure, data_array: xarray.DataArray, *, - title: Optional[str] = None, + title: str | None = None, trim_nans: bool = True, clamp_to_surface: bool = True, - bathymetry: Optional[xarray.DataArray] = None, - cmap: Union[str, Colormap] = 'jet', + bathymetry: xarray.DataArray | None = None, + cmap: str | Colormap = 'jet', ocean_floor_colour: str = 'black', - landmarks: Optional[list[Landmark]] = None, + landmarks: list[Landmark] | None = None, ) -> None: """ Plot the data array along this transect. @@ -617,14 +617,14 @@ def animate_on_figure( figure: Figure, data_array: xarray.DataArray, *, - title: Optional[Union[str, Callable[[Any], str]]] = None, + title: str | Callable[[Any], str] | None = None, trim_nans: bool = True, clamp_to_surface: bool = True, - bathymetry: Optional[xarray.DataArray] = None, - cmap: Union[str, Colormap] = 'jet', + bathymetry: xarray.DataArray | None = None, + cmap: str | Colormap = 'jet', ocean_floor_colour: str = 'black', - landmarks: Optional[list[Landmark]] = None, - coordinate: Optional[xarray.DataArray] = None, + landmarks: list[Landmark] | None = None, + coordinate: xarray.DataArray | None = None, interval: int = 200, ) -> animation.FuncAnimation: """ @@ -706,13 +706,13 @@ def _plot_on_figure( figure: Figure, data_array: xarray.DataArray, *, - title: Optional[str] = None, + title: str | None = None, trim_nans: bool = True, clamp_to_surface: bool = True, - bathymetry: Optional[xarray.DataArray] = None, - cmap: Union[str, Colormap] = 'jet', + bathymetry: xarray.DataArray | None = None, + cmap: str | Colormap = 'jet', ocean_floor_colour: str = 'black', - landmarks: Optional[list[Landmark]] = None, + landmarks: list[Landmark] | None = None, ) -> tuple[Axes, PolyCollection, xarray.DataArray]: """ Construct the axes and PolyCollections on a plot, diff --git a/src/emsarray/utils.py b/src/emsarray/utils.py index 63fa15fa..12c05a1b 100644 --- a/src/emsarray/utils.py +++ b/src/emsarray/utils.py @@ -15,10 +15,10 @@ import time import warnings from collections.abc import ( - Hashable, Iterable, Mapping, MutableMapping, Sequence + Callable, Hashable, Iterable, Mapping, MutableMapping, Sequence ) from types import TracebackType -from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast +from typing import Any, Literal, TypeVar, cast import cftime import netCDF4 @@ -58,10 +58,10 @@ def __enter__(self) -> 'PerfTimer': def __exit__( self, - exc_type: Optional[type[_Exception]], - exc_value: Optional[_Exception], + exc_type: type[_Exception] | None, + exc_value: _Exception | None, traceback: TracebackType - ) -> Optional[bool]: + ) -> bool | None: self._stop = time.perf_counter() self.running = False return None @@ -113,7 +113,7 @@ def wrapper(*args: Any, **kwargs: Any) -> _T: def to_netcdf_with_fixes( dataset: xarray.Dataset, path: Pathish, - time_variable: Optional[DataArrayOrName] = None, + time_variable: DataArrayOrName | None = None, **kwargs: Any, ) -> None: """Saves a :class:`xarray.Dataset` to a netCDF4 file, @@ -151,7 +151,7 @@ def to_netcdf_with_fixes( fix_time_units_for_ems(path, data_array_to_name(dataset, time_variable)) -def format_time_units_for_ems(units: str, calendar: Optional[str] = DEFAULT_CALENDAR) -> str: +def format_time_units_for_ems(units: str, calendar: str | None = DEFAULT_CALENDAR) -> str: """ Reformat a given time unit string to an EMS-compatible string. ``xarray`` will always format time unit strings using ISO8601 strings with @@ -238,14 +238,14 @@ def fix_time_units_for_ems( dataset.sync() -def _get_variables(dataset_or_array: Union[xarray.Dataset, xarray.DataArray]) -> list[xarray.Variable]: +def _get_variables(dataset_or_array: xarray.Dataset | xarray.DataArray) -> list[xarray.Variable]: if isinstance(dataset_or_array, xarray.Dataset): return list(dataset_or_array.variables.values()) else: return [dataset_or_array.variable] -def disable_default_fill_value(dataset_or_array: Union[xarray.Dataset, xarray.DataArray]) -> None: +def disable_default_fill_value(dataset_or_array: xarray.Dataset | xarray.DataArray) -> None: """ Update all variables on this dataset or data array and disable the automatic ``_FillValue`` :mod:`xarray` sets. An automatic fill value can spoil @@ -428,7 +428,7 @@ def check_data_array_dimensions_match( dataset: xarray.Dataset, data_array: xarray.DataArray, *, - dimensions: Optional[Sequence[Hashable]] = None, + dimensions: Sequence[Hashable] | None = None, ) -> None: """ Check that the dimensions of a :class:`xarray.DataArray` @@ -522,7 +522,7 @@ def move_dimensions_to_end( def ravel_dimensions( data_array: xarray.DataArray, dimensions: list[Hashable], - linear_dimension: Optional[Hashable] = None, + linear_dimension: Hashable | None = None, ) -> xarray.DataArray: """ Flatten the given dimensions of a :class:`~xarray.DataArray`. @@ -688,7 +688,7 @@ def __init__(self, extra: str) -> None: def requires_extra( extra: str, - import_error: Optional[ImportError], + import_error: ImportError | None, exception_class: type[RequiresExtraException] = RequiresExtraException, ) -> Callable[[_T], _T]: if import_error is None: @@ -705,7 +705,7 @@ def error(*args: Any, **kwargs: Any) -> Any: def make_polygons_with_holes( points: numpy.ndarray, *, - out: Optional[numpy.ndarray] = None, + out: numpy.ndarray | None = None, ) -> numpy.ndarray: """ Make a :class:`numpy.ndarray` of :class:`shapely.Polygon` from an array of (n, m, 2) points. diff --git a/tests/conventions/test_base.py b/tests/conventions/test_base.py index 43b908e0..489cb4a5 100644 --- a/tests/conventions/test_base.py +++ b/tests/conventions/test_base.py @@ -3,7 +3,6 @@ import pathlib from collections.abc import Hashable from functools import cached_property -from typing import Optional import numpy import pandas @@ -40,7 +39,7 @@ class SimpleConvention(Convention[SimpleGridKind, SimpleGridIndex]): default_grid_kind = SimpleGridKind.face @classmethod - def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]: + def check_dataset(cls, dataset: xarray.Dataset) -> int | None: return None @cached_property @@ -64,7 +63,7 @@ def wind_index( self, index: int, *, - grid_kind: Optional[SimpleGridKind] = None, + grid_kind: SimpleGridKind | None = None, ) -> SimpleGridIndex: y, x = map(int, numpy.unravel_index(index, self.shape)) return SimpleGridIndex(y, x) @@ -79,7 +78,7 @@ def ravel( self, data_array: xarray.DataArray, *, - linear_dimension: Optional[Hashable] = None, + linear_dimension: Hashable | None = None, ) -> xarray.DataArray: self.get_grid_kind(data_array) return utils.ravel_dimensions( @@ -90,9 +89,9 @@ def wind( self, data_array: xarray.DataArray, *, - grid_kind: Optional[SimpleGridKind] = None, - axis: Optional[int] = None, - linear_dimension: Optional[Hashable] = None, + grid_kind: SimpleGridKind | None = None, + axis: int | None = None, + linear_dimension: Hashable | None = None, ) -> xarray.DataArray: if axis is not None: linear_dimension = data_array.dims[axis] @@ -312,9 +311,9 @@ def test_strtree(): convention = SimpleConvention(dataset) line = LineString([(-1, -1), (1.5, 1.5), (1.5, 2.5), (3.9, 3.9)]) - expected_intersections = set(convention.ravel_index(index) for index in [ + expected_intersections = {convention.ravel_index(index) for index in [ SimpleGridIndex(1, 1), SimpleGridIndex(2, 1), SimpleGridIndex(2, 2), - SimpleGridIndex(3, 2), SimpleGridIndex(3, 3)]) + SimpleGridIndex(3, 2), SimpleGridIndex(3, 3)]} # Query the spatial index items = convention.strtree.query(line, predicate='intersects') diff --git a/tests/conventions/test_registry.py b/tests/conventions/test_registry.py index cd6be356..8b61196d 100644 --- a/tests/conventions/test_registry.py +++ b/tests/conventions/test_registry.py @@ -1,7 +1,6 @@ """ Test convention class registration by entry points or manual registration. """ -import sys from importlib import metadata import pytest @@ -70,12 +69,8 @@ def monkeypatch_entrypoint( ] for group, entries in entry_points.items() } - if sys.version_info >= (3, 10): - def mocked(group: str) -> list[metadata.EntryPoint]: - return _entry_points.get(group, []) - else: - def mocked() -> list[metadata.EntryPoint]: - return _entry_points + def mocked(group: str) -> list[metadata.EntryPoint]: + return _entry_points.get(group, []) monkeypatch.setattr(metadata, 'entry_points', mocked) return entry_points diff --git a/tests/datasets/make_examples.py b/tests/datasets/make_examples.py index d427bf3f..e415888d 100644 --- a/tests/datasets/make_examples.py +++ b/tests/datasets/make_examples.py @@ -7,7 +7,7 @@ import datetime import functools import pathlib -from typing import Callable +from collections.abc import Callable import netCDF4 import numpy @@ -186,12 +186,12 @@ def make_ugrid_mesh2d(out: pathlib.Path) -> None: faces = [polygon.intersection(envelope) for polygon in voronoi.geoms] # Get the unique vertices of the faces - nodes = numpy.array(list(set( + nodes = numpy.array(list({ p for polygon in faces for p in polygon.exterior.coords - ))) + })) # A map between {point: index} - node_indices = dict((tuple(p), i) for i, p in enumerate(nodes)) + node_indices = {tuple(p): i for i, p in enumerate(nodes)} # Number of vertices nnodes = len(node_indices) # Maximum vertex count for any face diff --git a/tests/operations/depth/test_normalize_depth_variables.py b/tests/operations/depth/test_normalize_depth_variables.py index f415e095..1f7b6828 100644 --- a/tests/operations/depth/test_normalize_depth_variables.py +++ b/tests/operations/depth/test_normalize_depth_variables.py @@ -1,5 +1,3 @@ -from typing import Optional - import numpy import pytest import xarray @@ -35,8 +33,8 @@ def test_normalize_depth_variable( input_positive: str, input_deep_to_shallow: bool, set_positive: bool, - positive_down: Optional[bool], - deep_to_shallow: Optional[bool], + positive_down: bool | None, + deep_to_shallow: bool | None, recwarn, ): # Some datasets have a coordinate with the same dimension name diff --git a/tests/utils.py b/tests/utils.py index fd1083b2..23b92ed5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,7 +5,7 @@ import warnings from collections.abc import Hashable from functools import cached_property -from typing import Any, Optional +from typing import Any import numpy import pytest @@ -42,7 +42,7 @@ def box(minx, miny, maxx, maxy) -> shapely.Polygon: ]) -def reduce_axes(arr: numpy.ndarray, axes: Optional[tuple[bool, ...]] = None) -> numpy.ndarray: +def reduce_axes(arr: numpy.ndarray, axes: tuple[bool, ...] | None = None) -> numpy.ndarray: """ Reduce the size of an array by one on an axis-by-axis basis. If an axis is reduced, neigbouring values are averaged together @@ -138,7 +138,7 @@ def __init__( self, *, j: int, i: int, - face_mask: Optional[numpy.ndarray] = None, + face_mask: numpy.ndarray | None = None, include_bounds: bool = False, ): self.j_size = j