diff --git a/imod/mf6/model.py b/imod/mf6/model.py index d0ff13abb..25aa2a0cb 100644 --- a/imod/mf6/model.py +++ b/imod/mf6/model.py @@ -21,6 +21,7 @@ from imod.mf6.interfaces.imodel import IModel from imod.mf6.package import Package from imod.mf6.statusinfo import NestedStatusInfo, StatusInfo, StatusInfoBase +from imod.mf6.utilities.mask import _mask_all_packages from imod.mf6.utilities.regrid import ( _regrid_like, ) @@ -500,12 +501,8 @@ def mask_all_packages( idomain-like integer array. 1 sets cells to active, 0 sets cells to inactive, -1 sets cells to vertical passthrough """ - if any([coord not in ["x", "y", "layer", "mesh2d_nFaces", "dx", "dy"] for coord in mask.coords]): - raise ValueError("unexpected coordinate dimension in masking domain") - for pkgname, pkg in self.items(): - self[pkgname] = pkg.mask(mask) - self.purge_empty_packages() + _mask_all_packages(self, mask) def purge_empty_packages(self, model_name: Optional[str] = "") -> None: """ diff --git a/imod/mf6/package.py b/imod/mf6/package.py index 324302364..c0e76b47e 100644 --- a/imod/mf6/package.py +++ b/imod/mf6/package.py @@ -1,7 +1,6 @@ from __future__ import annotations import abc -import numbers import pathlib from collections import defaultdict from typing import Any, Mapping, Optional, Tuple, Union @@ -11,20 +10,16 @@ import numpy as np import xarray as xr import xugrid as xu -from xarray.core.utils import is_scalar import imod -from imod.mf6.auxiliary_variables import ( - expand_transient_auxiliary_variables, - get_variable_names, - remove_expanded_auxiliary_variables_from_dataset, -) +from imod.mf6.auxiliary_variables import get_variable_names from imod.mf6.interfaces.ipackage import IPackage from imod.mf6.pkgbase import ( EXCHANGE_PACKAGES, TRANSPORT_PACKAGES, PackageBase, ) +from imod.mf6.utilities.mask import _mask from imod.mf6.utilities.regrid import ( RegridderType, _regrid_like, @@ -540,59 +535,7 @@ def mask(self, mask: GridDataArray) -> Any: The package with part masked. """ - masked = {} - if hasattr(self,"auxiliary_data_fields"): - remove_expanded_auxiliary_variables_from_dataset(self) - for var in self.dataset.data_vars.keys(): - if self._skip_masking_variable(var, self.dataset[var]): - masked[var] = self.dataset[var] - else: - masked[var] = self._mask_spatial_var(var, mask) - if hasattr(self,"auxiliary_data_fields"): - expand_transient_auxiliary_variables(self) - return type(self)(**masked) - - def _skip_masking_variable(self, var: str, da: GridDataArray)->bool: - if self._skip_masking_dataarray(var) or len(da.dims) == 0 or set(da.coords).issubset(["layer"]): - return True - if is_scalar(da.values[()]): - return True - spatial_dims = ["x", "y", "mesh2d_nFaces", "layer"] - if not np.any( [coord in spatial_dims for coord in da.coords]): - return True - return False - - def _mask_spatial_var(self, var: str, mask: GridDataArray)->GridDataArray: - da = self.dataset[var] - array_mask = self._adjust_mask_for_unlayered_data(da, mask) - - if issubclass(da.dtype.type, numbers.Integral): - if var == "idomain": - return da.where(array_mask > 0, other=array_mask) - else: - return da.where(array_mask > 0, other=0) - elif issubclass(da.dtype.type, numbers.Real): - return da.where(array_mask > 0) - else: - raise TypeError( - f"Expected dtype float or integer. Received instead: {da.dtype}" - ) - - def _adjust_mask_for_unlayered_data(self, da: GridDataArray, mask: GridDataArray)->GridDataArray: - ''' - Some arrays are not layered while the mask is layered (for example the - top array in dis or disv packaged). In that case we use the top layer of - the mask to perform the masking. If layer is not a dataset dimension, - but still a dataset coordinate, we limit the mask to the relevant layer - coordinate(s). - ''' - array_mask = mask - if "layer" in da.coords and "layer" not in da.dims: - array_mask = mask.sel(layer=da.coords["layer"]) - if "layer" not in da.coords and "layer" in array_mask.coords: - array_mask = mask.isel(layer=0) - - return array_mask + return _mask(self, mask) def regrid_like( diff --git a/imod/mf6/simulation.py b/imod/mf6/simulation.py index 5dc9faf33..c91d6302c 100644 --- a/imod/mf6/simulation.py +++ b/imod/mf6/simulation.py @@ -37,15 +37,14 @@ from imod.mf6.package import Package from imod.mf6.ssm import SourceSinkMixing from imod.mf6.statusinfo import NestedStatusInfo +from imod.mf6.utilities.mask import _mask_all_models from imod.mf6.utilities.regrid import _regrid_like from imod.mf6.write_context import WriteContext from imod.schemata import ValidationError from imod.typing import GridDataArray, GridDataset from imod.typing.grid import ( concat, - get_spatial_dimension_names, is_equal, - is_same_domain, is_unstructured, merge_partitions, ) @@ -1229,7 +1228,8 @@ def is_split(self) -> bool: def has_one_flow_model(self) -> bool: flow_models = self.get_models_of_type("gwf6") - return len(flow_models) == 1 + return len(flow_models) == 1 + def mask_all_models( self, mask: GridDataArray, @@ -1248,21 +1248,4 @@ def mask_all_models( idomain-like integer array. 1 sets cells to active, 0 sets cells to inactive, -1 sets cells to vertical passthrough """ - spatial_dims = get_spatial_dimension_names(mask) - if any([coord not in spatial_dims for coord in mask.coords]): - raise ValueError("unexpected coordinate dimension in masking domain") - - - if self.is_split(): - raise ValueError("masking can only be applied to simulations that have not been split. Apply masking before splitting.") - - flowmodels =list(self.get_models_of_type("gwf6").keys()) - transportmodels = list(self.get_models_of_type("gwt6").keys()) - modelnames = flowmodels + transportmodels - - - for name in modelnames: - if is_same_domain(self[name].domain, mask): - self[name].mask_all_packages(mask) - else: - raise ValueError("masking can only be applied to simulations when all the models in the simulation use the same grid.") \ No newline at end of file + _mask_all_models(self, mask) \ No newline at end of file diff --git a/imod/mf6/utilities/mask.py b/imod/mf6/utilities/mask.py new file mode 100644 index 000000000..16f44e06c --- /dev/null +++ b/imod/mf6/utilities/mask.py @@ -0,0 +1,111 @@ + +import numbers + +import numpy as np +from xarray.core.utils import is_scalar + +from imod.mf6.auxiliary_variables import ( + expand_transient_auxiliary_variables, + remove_expanded_auxiliary_variables_from_dataset, +) +from imod.mf6.interfaces.imodel import IModel +from imod.mf6.interfaces.ipackage import IPackage +from imod.mf6.interfaces.isimulation import ISimulation +from imod.typing.grid import GridDataArray, get_spatial_dimension_names, is_same_domain + + +def _mask_all_models( + simulation: ISimulation, + mask: GridDataArray, + ): + spatial_dims = get_spatial_dimension_names(mask) + if any([coord not in spatial_dims for coord in mask.coords]): + raise ValueError("unexpected coordinate dimension in masking domain") + + + if simulation.is_split(): + raise ValueError("masking can only be applied to simulations that have not been split. Apply masking before splitting.") + + flowmodels =list(simulation.get_models_of_type("gwf6").keys()) + transportmodels = list(simulation.get_models_of_type("gwt6").keys()) + modelnames = flowmodels + transportmodels + + + for name in modelnames: + if is_same_domain(simulation[name].domain, mask): + simulation[name].mask_all_packages(mask) + else: + raise ValueError("masking can only be applied to simulations when all the models in the simulation use the same grid.") + + +def _mask_all_packages( + model: IModel, + mask: GridDataArray, +): + spatial_dimension_names = get_spatial_dimension_names(mask) + if any([coord not in spatial_dimension_names for coord in mask.coords]): + raise ValueError("unexpected coordinate dimension in masking domain") + + for pkgname, pkg in model.items(): + model[pkgname] = pkg.mask(mask) + model.purge_empty_packages() + + +def _mask(package: IPackage, mask: GridDataArray) -> IPackage: + masked = {} + if len(package.auxiliary_data_fields) > 0: + remove_expanded_auxiliary_variables_from_dataset(package) + for var in package.dataset.data_vars.keys(): + if _skip_masking_variable(package, var, package.dataset[var]): + masked[var] = package.dataset[var] + else: + masked[var] = _mask_spatial_var(package, var, mask) + if len(package.auxiliary_data_fields) > 0: + expand_transient_auxiliary_variables(package) + return type(package)(**masked) + + +def _skip_masking_variable(package: IPackage, var: str, da: GridDataArray)->bool: + if package._skip_masking_dataarray(var) or len(da.dims) == 0 or set(da.coords).issubset(["layer"]): + return True + if is_scalar(da.values[()]): + return True + spatial_dims = ["x", "y", "mesh2d_nFaces", "layer"] + if not np.any( [coord in spatial_dims for coord in da.coords]): + return True + return False + + + + +def _mask_spatial_var(self, var: str, mask: GridDataArray)->GridDataArray: + da = self.dataset[var] + array_mask = _adjust_mask_for_unlayered_data(da, mask) + + if issubclass(da.dtype.type, numbers.Integral): + if var == "idomain": + return da.where(array_mask > 0, other=array_mask) + else: + return da.where(array_mask > 0, other=0) + elif issubclass(da.dtype.type, numbers.Real): + return da.where(array_mask > 0) + else: + raise TypeError( + f"Expected dtype float or integer. Received instead: {da.dtype}" + ) + +def _adjust_mask_for_unlayered_data(da: GridDataArray, mask: GridDataArray)->GridDataArray: + ''' + Some arrays are not layered while the mask is layered (for example the + top array in dis or disv packaged). In that case we use the top layer of + the mask to perform the masking. If layer is not a dataset dimension, + but still a dataset coordinate, we limit the mask to the relevant layer + coordinate(s). + ''' + array_mask = mask + if "layer" in da.coords and "layer" not in da.dims: + array_mask = mask.sel(layer=da.coords["layer"]) + if "layer" not in da.coords and "layer" in array_mask.coords: + array_mask = mask.isel(layer=0) + + return array_mask \ No newline at end of file