diff --git a/docs/api/msw.rst b/docs/api/msw.rst index ca1a857a2..a14f60f59 100644 --- a/docs/api/msw.rst +++ b/docs/api/msw.rst @@ -14,6 +14,7 @@ Model objects & methods MetaSwapModel.dump MetaSwapModel.from_imod5_data MetaSwapModel.regrid_like + MetaSwapModel.mask_all_packages MetaSwapModel.clip_box MetaSwapModel.get_pkgkey diff --git a/imod/msw/model.py b/imod/msw/model.py index 5d5d0081e..a1dcf8b12 100644 --- a/imod/msw/model.py +++ b/imod/msw/model.py @@ -23,9 +23,9 @@ from imod.common.utilities.version import prepend_content_with_version_info from imod.mf6.dis import StructuredDiscretization from imod.mf6.mf6_wel_adapter import Mf6Wel +from imod.msw import GridData from imod.msw.copy_files import FileCopier from imod.msw.coupler_mapping import CouplerMapping -from imod.msw.grid_data import GridData from imod.msw.idf_mapping import IdfMapping from imod.msw.infiltration import Infiltration from imod.msw.initial_conditions import ( @@ -52,7 +52,11 @@ from imod.msw.utilities.imod5_converter import ( has_active_scaling_factor, ) -from imod.msw.utilities.mask import mask_and_broadcast_cap_data +from imod.msw.utilities.mask import ( + MetaSwapActive, + mask_and_broadcast_cap_data, + mask_and_broadcast_pkg_data, +) from imod.msw.utilities.parse import read_para_sim from imod.msw.vegetation import AnnualCropFactors from imod.typing import GridDataArray, Imod5DataDict @@ -511,6 +515,57 @@ def regrid_like( return regridded_model + def mask_all_packages( + self, + msw_active: MetaSwapActive, + # ignore_time_purge_empty: bool = False, + ): + """ + This function applies a mask to all packages in a model. The mask must + be presented as a MetaSwap Active object, which contains idomain-like integers. + The mask is applied to all packages in the model, and the values in the mask determine which cells are active and which are inactive. The mask is applied to all packages, regardless of whether they have a subunit dimension or not. + + Parameters + ---------- + msw_active: MetaSwapActive, dictionary of xr.DataArray + idomain-like integers. >0 sets cells to active, 0 sets cells to inactive, + all: applies to all packages without a subunit dimension + subunit: applies to all packages with a subunit dimension on a per-subunit basis + (mask has a subunit dimension) + + Example + ------- + >>> mask_per_subunit = xr.DataArray( + >>> np.array( + >>> [ + >>> [[0, 0, 0], [0, 1, 1], [0, 0, 0]], + >>> [[1, 1, 1], [0, 1, 1], [0, 0, 0]], + >>> ] + >>> ).astype(bool), + >>> dims=("subunit", "y", "x"), + >>> coords = { + >>> "x" : [1.0, 2.0, 3.0], + >>> "y" : [3.0, 2.0, 1.0], + >>> "dx" : 1.0, + >>> "dy" : 1.0, + >>> "subunit" : [0, 1] + >>> } + >>> ) + >>> mask_all = mask_per_subunit.any(dim="subunit") + >>> msw_active = MetaSwapActive(mask_all, mask_per_subunit) + >>> msw_model.mask_all_packages(msw_active) + """ + + for pkg in self.values(): + if "x" in pkg.dataset.dims and "y" in pkg.dataset.dims: + data_dict = { + key: pkg.dataset[key] for key in pkg.dataset.data_vars.keys() + } + masked_data = mask_and_broadcast_pkg_data(pkg, data_dict, msw_active) + for key, data in masked_data.items(): + pkg.dataset[key] = data + return + def clip_box( self, time_min: Optional[cftime.datetime | np.datetime64 | str] = None, diff --git a/imod/tests/test_msw/test_model.py b/imod/tests/test_msw/test_model.py index 771fed5ba..80e0da17f 100644 --- a/imod/tests/test_msw/test_model.py +++ b/imod/tests/test_msw/test_model.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import cast +import numpy as np import pytest import xarray as xr from numpy.testing import assert_almost_equal, assert_equal @@ -12,6 +13,7 @@ from imod.msw.meteo_grid import MeteoGridCopy from imod.msw.meteo_mapping import MeteoMapping from imod.msw.model import DEFAULT_SETTINGS, MetaSwapModel +from imod.msw.utilities.mask import MetaSwapActive from imod.msw.utilities.parse import read_para_sim from imod.typing import GridDataArray, Imod5DataDict from imod.typing.grid import zeros_like @@ -19,6 +21,29 @@ from imod.util.spatial import empty_2d +@pytest.fixture(scope="function") +def mask_fixture() -> MetaSwapActive: + mask_per_subunit = xr.DataArray( + np.array( + [ + [[0, 0, 0], [0, 1, 1], [0, 0, 0]], + [[1, 1, 1], [0, 1, 1], [0, 0, 0]], + ] + ).astype(bool), + dims=("subunit", "y", "x"), + coords={ + "x": [1.0, 2.0, 3.0], + "y": [3.0, 2.0, 1.0], + "dx": 1.0, + "dy": 1.0, + "subunit": [0, 1], + }, + ) + mask_all = mask_per_subunit.any(dim="subunit") + + return MetaSwapActive(mask_all, mask_per_subunit) + + def roundtrip(msw_model, tmpdir_factory, name, engine): # TODO: look at the values? tmp_path = tmpdir_factory.mktemp(name) @@ -47,6 +72,25 @@ def test_msw_pkgdump_zarrzip(msw_model, tmpdir_factory): roundtrip(msw_model, tmpdir_factory, name="testmodel", engine="zarr.zip") +def test_msw_mask_all(msw_model, tmpdir_factory, mask_fixture): + # Apply the mask to all packages in the model + msw_model.mask_all_packages(mask_fixture) + + # Check that the mask has been applied correctly to each package + for pkgname, pkg in msw_model.items(): + if isinstance(pkg, msw.meteo_mapping.PrecipitationMapping): + continue # Skip PrecipitationMapping package for this test + if isinstance(pkg, msw.meteo_mapping.EvapotranspirationMapping): + continue # Skip EvapotranspirationMapping package for this test + for var in pkg.dataset.data_vars: + da = pkg.dataset[var] + if "y" in da.dims and "x" in da.dims: + assert ( + da.where(mask_fixture).equals(da) + or da.where(~mask_fixture).isnull().all() + ) + + def test_msw_model_write(msw_model, coupled_mf6_model, coupled_mf6wel, tmp_path): mf6_dis = coupled_mf6_model["GWF_1"]["dis"]