diff --git a/pyproject.toml b/pyproject.toml index 5c965ed27a..65a70f6633 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "matplotlib", "numpy>=1.20", "pandas", - "pymatsolver>=0.2", + "pymatsolver>=0.2, <0.3.0", "scikit-learn>=1.2", "scipy>=1.8.0", ] diff --git a/simpeg/directives/__init__.py b/simpeg/directives/__init__.py index 45c97db0f4..cbc9a6ab6e 100644 --- a/simpeg/directives/__init__.py +++ b/simpeg/directives/__init__.py @@ -118,7 +118,11 @@ JointScalingSchedule, UpdateSensitivityWeights, VectorInversion, - SaveIterationsGeoH5, + SaveDataGeoH5, + SaveLogFilesGeoH5, + SaveModelGeoH5, + SavePropertyGroup, + SaveSensitivityGeoH5, ProjectSphericalBounds, ScaleMisfitMultipliers, ) diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 5cdaba6152..ac95c20c32 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -1,5 +1,5 @@ from __future__ import annotations # needed to use type operands in Python 3.8 - +from abc import ABC, abstractmethod from pathlib import Path from datetime import datetime @@ -49,6 +49,8 @@ validate_ndarray_with_shape, ) +from geoh5py.groups.property_group import GroupTypeEnum +from geoh5py.groups import PropertyGroup, UIJsonGroup from geoh5py.objects import ObjectBase from geoh5py.ui_json.utils import fetch_active_workspace @@ -2982,49 +2984,173 @@ def endIter(self): misfit.simulation.model = self.invProb.model -class SaveIterationsGeoH5(InversionDirective): +class BaseSaveGeoH5(InversionDirective, ABC): """ - Saves inversion results to a geoh5 file + Base class for saving inversion results to a geoh5 file """ def __init__( - self, h5_object, dmisfit=None, attribute_type: str = "model", **kwargs + self, + h5_object, + dmisfit=None, + label: str | None = None, + channels: list[str] = ("",), + components: list[str] = ("",), + association: str | None = None, + **kwargs, ): - self.data_type = {} - self._association = None - self.attribute_type = attribute_type - self._label = None - self.channels = [""] - self.components = [""] - self._transforms: list = [] - self.save_objective_function = False - self.sorting = None - self._reshape = None + self.label = label + self.channels = channels + self.components = components self.h5_object = h5_object - self._joint_index = None - if attribute_type == "sensitivities" and dmisfit is None: - raise ValueError( - "To save sensitivities, the data misfit object must be provided." - ) + if association is not None: + self.association = association super().__init__( inversion=None, dmisfit=dmisfit, reg=None, verbose=False, **kwargs ) def initialize(self): - self.save_components(0) - - if self.save_objective_function: - self.write_update(0) - self.save_log() + self.write(0) def endIter(self): - self.save_components(self.opt.iter) + self.write(self.opt.iter) + + def get_names( + self, component: str, channel: str, iteration: int + ) -> tuple[str, str]: + """ + Format the data and property_group name. + """ + base_name = f"Iteration_{iteration}" + if len(component) > 0: + base_name += f"_{component}" + + channel_name = base_name + if channel: + channel_name += f"_{channel}" + + if self.label is not None: + channel_name += f"_{self.label}" + base_name += f"_{self.label}" + + return channel_name, base_name + + @abstractmethod + def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noqa + """ + Save the components of the inversion. + """ + + @property + def label(self): + return self._label + + @label.setter + def label(self, value: str | None): + if not isinstance(value, str | type(None)): + raise TypeError("'label' must be a string or None") + + self._label = value + + @property + def h5_object(self): + return self._h5_object + + @h5_object.setter + def h5_object(self, entity: ObjectBase): + if not isinstance(entity, ObjectBase | UIJsonGroup): + raise TypeError( + f"Input entity should be of type {ObjectBase}. {type(entity)} provided" + ) + + self._h5_object = entity.uid + self._geoh5 = entity.workspace + + if getattr(entity, "n_cells", None) is not None: + self.association = "CELL" + else: + self.association = "VERTEX" - if self.save_objective_function: - self.write_update(self.opt.iter) - self.save_log() + @property + def association(self): + return self._association + + @association.setter + def association(self, value): + if not value.upper() in ["CELL", "VERTEX"]: + raise ValueError( + f"'association must be one of 'CELL', 'VERTEX'. {value} provided" + ) + + self._association = value.upper() + + +class SaveArrayGeoH5(BaseSaveGeoH5, ABC): + """ + Saves array-based inversion results (model, data) to a geoh5 file. + + Parameters + ---------- + + transforms: List of transformations applied to the values before save. + sorting: Special re-indexing of the vector values. + reshape: Re-ordering applied to the data before slicing. + """ + + _attribute_type = None + + def __init__( + self, + h5_object, + transforms: list | tuple = (), + reshape=None, + sorting=None, + **kwargs, + ): + self.data_type = {} + self.transforms = transforms + self.sorting = sorting + self.reshape = reshape + + super().__init__(h5_object, **kwargs) + + @property + def reshape(self): + """ + Reshape function + """ + if getattr(self, "_reshape", None) is None: + self._reshape = lambda x: x.reshape( + (len(self.channels), len(self.components), -1), order="F" + ) + + return self._reshape + + @reshape.setter + def reshape(self, fun): + self._reshape = fun + + @property + def transforms(self): + return self._transforms + + @transforms.setter + def transforms(self, funcs: list | tuple): + if not isinstance(funcs, list | tuple): + funcs = [funcs] + + for fun in funcs: + if not any( + [isinstance(fun, (IdentityMap, np.ndarray, float)), callable(fun)] + ): + raise TypeError( + "Input transformation must be of type" + + "SimPEG.maps, numpy.ndarray or callable function" + ) + + self._transforms = funcs def stack_channels(self, dpred: list): """ @@ -3053,54 +3179,13 @@ def apply_transformations(self, prop: np.ndarray) -> np.ndarray: return prop - def get_names( - self, component: str, channel: str, iteration: int - ) -> tuple[str, str]: - """ - Format the data and property_group name. - """ - base_name = f"Iteration_{iteration}" - if len(component) > 0: - base_name += f"_{component}" - - channel_name = base_name - if channel: - channel_name += f"_{channel}" - - if self.label is not None: - channel_name += f"_{self.label}" - base_name += f"_{self.label}" - - return channel_name, base_name - + @abstractmethod def get_values(self, values: list[np.ndarray] | None): """ Get values for the inversion depending on the output type. """ - prop = self.invProb.model - if values is not None: - prop = self.stack_channels(values) - elif self.attribute_type == "predicted": - dpred = getattr(self.invProb, "dpred", None) - if dpred is None: - dpred = self.invProb.get_dpred(self.invProb.model) - self.invProb.dpred = dpred - - if self.joint_index is not None: - dpred = [dpred[ind] for ind in self.joint_index] - - prop = self.stack_channels(dpred) - elif self.attribute_type == "sensitivities": - - prop = np.zeros_like(self.invProb.model) - for fun in self.dmisfit.objfcts: - prop += fun.getJtJdiag(self.invProb.model) - - return prop - def save_components( # flake8: noqa - self, iteration: int, values: list[np.ndarray] = None - ): + def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noqa """ Sort, transform and store data per components and channels. """ @@ -3134,9 +3219,10 @@ def save_components( # flake8: noqa } } ) + # Re-assign the data type if channel not in self.data_type[component].keys(): self.data_type[component][channel] = data.entity_type - type_name = f"{self.attribute_type}_{component}" + type_name = f"{self._attribute_type}_{component}" if channel: type_name += f"_{channel}" data.entity_type.name = type_name @@ -3146,13 +3232,92 @@ def save_components( # flake8: noqa type(self.data_type[component][channel]), ) - if len(self.channels) > 1 and self.attribute_type == "predicted": - h5_object.add_data_to_group(data, base_name) - def write_update(self, iteration: int): +class SaveModelGeoH5(SaveArrayGeoH5): + """ + Save the model at the current iteration to a geoh5 file. + """ + + _attribute_type = "model" + + def get_values(self, values: list[np.ndarray] | None): + if values is None: + values = self.invProb.model + + return values + + +class SaveSensitivityGeoH5(SaveArrayGeoH5): + """ + Save the model at the current iteration to a geoh5 file. + """ + + _attribute_type = "sensitivities" + + def __init__(self, h5_object, dmisfit=None, **kwargs): + if dmisfit is None: + raise ValueError( + "To save sensitivities, the data misfit object must be provided." + ) + super().__init__(h5_object, dmisfit=dmisfit, **kwargs) + + def get_values(self, values: list[np.ndarray] | None): + if values is None: + values = np.zeros_like(self.invProb.model) + for fun in self.dmisfit.objfcts: + values += fun.getJtJdiag(self.invProb.model) + + return values + + +class SaveDataGeoH5(SaveArrayGeoH5): + """ + Save the model at the current iteration to a geoh5 file. + """ + + _attribute_type = "predicted" + + def __init__(self, h5_object, joint_index: list[int] | None = None, **kwargs): + self.joint_index = joint_index + + super().__init__(h5_object, **kwargs) + + def get_values(self, values: list[np.ndarray] | None): + + if values is not None: + prop = self.stack_channels(values) + + else: + dpred = getattr(self.invProb, "dpred", None) + if dpred is None: + dpred = self.invProb.get_dpred(self.invProb.model) + self.invProb.dpred = dpred + + if self.joint_index is not None: + dpred = [dpred[ind] for ind in self.joint_index] + + prop = self.stack_channels(dpred) + + return prop + + @property + def joint_index(self): """ - Write update to file. + Index for joint inversions defining the element in the list of predicted data. """ + return self._joint_index + + @joint_index.setter + def joint_index(self, value: list[int] | None): + if not isinstance(value, list | type(None)): + raise TypeError("Input 'joint_index' should be a list of int") + + self._joint_index = value + + +class SaveLogFilesGeoH5(BaseSaveGeoH5): + + def write(self, iteration: int, **_): dirpath = Path(self._geoh5.h5file).parent filepath = dirpath / "SimPEG.out" @@ -3167,6 +3332,8 @@ def write_update(self, iteration: int): f"{self.invProb.phi_m:.3e} {date_time}\n" ) + self.save_log() + def save_log(self): """ Save iteration metrics to comments. @@ -3185,105 +3352,61 @@ def save_log(self): with open(filepath, "rb") as f: raw_file = f.read() - if h5_object.parent.get_entity(file)[0] is not None: - file_entity = h5_object.parent.get_entity(file)[0] - else: - file_entity = h5_object.parent.add_file(filepath) + file_entity = h5_object.get_entity(file)[0] + if file_entity is None: + file_entity = h5_object.add_file(filepath) file_entity.file_bytes = raw_file - @property - def joint_index(self): - """ - Index for joint inversions defining the element in the list of predicted data. - """ - return self._joint_index - - @joint_index.setter - def joint_index(self, value: list[int]): - if not isinstance(value, list): - raise TypeError("Input 'joint_index' should be a list of int") - self._joint_index = value - - @property - def label(self): - return self._label +class SavePropertyGroup(BaseSaveGeoH5): + """ + Save the model as a property group in the geoh5 file + """ - @label.setter - def label(self, value: str): - assert isinstance(value, str), "'label' must be a string" + def __init__( + self, + h5_object, + group_type: GroupTypeEnum = GroupTypeEnum.MULTI, + **kwargs, + ): + self.group_type = group_type - self._label = value + super().__init__(h5_object, **kwargs) - @property - def reshape(self): + def write(self, iteration: int, **_): """ - Reshape function + Save the model to the geoh5 file """ - if getattr(self, "_reshape", None) is None: - self._reshape = lambda x: x.reshape( - (len(self.channels), len(self.components), -1), order="F" - ) - - return self._reshape - - @reshape.setter - def reshape(self, fun): - self._reshape = fun + with fetch_active_workspace(self._geoh5, mode="r+") as w_s: + h5_object = w_s.get_entity(self.h5_object)[0] - @property - def transforms(self): - return self._transforms + for component in self.components: + properties = [] + for channel in self.channels: - @transforms.setter - def transforms(self, funcs: list): - if not isinstance(funcs, list): - funcs = [funcs] - - for fun in funcs: - if not any( - [isinstance(fun, (IdentityMap, np.ndarray, float)), callable(fun)] - ): - raise TypeError( - "Input transformation must be of type" - + "SimPEG.maps, numpy.ndarray or callable function" + channel_name, base_name = self.get_names( + component, channel, iteration + ) + child = [ + child + for child in h5_object.children + if channel_name in child.name + ][0] + + if child is not None: + properties.append(child) + + if len(properties) == 0: + return + + PropertyGroup( + parent=h5_object, + name=base_name, + properties=properties, + property_group_type=self.group_type, ) - self._transforms = funcs - - @property - def h5_object(self): - return self._h5_object - - @h5_object.setter - def h5_object(self, entity: ObjectBase): - if not isinstance(entity, ObjectBase): - raise TypeError( - f"Input entity should be of type {ObjectBase}. {type(entity)} provided" - ) - - self._h5_object = entity.uid - self._geoh5 = entity.workspace - - if getattr(entity, "n_cells", None) is not None: - self.association = "CELL" - else: - self.association = "VERTEX" - - @property - def association(self): - return self._association - - @association.setter - def association(self, value): - if not value.upper() in ["CELL", "VERTEX"]: - raise ValueError( - f"'association must be one of 'CELL', 'VERTEX'. {value} provided" - ) - - self._association = value.upper() - class VectorInversion(InversionDirective): """ @@ -3404,18 +3527,18 @@ def endIter(self): # Add and update directives for directive in self.inversion.directiveList.dList: - if isinstance(directive, SaveIterationsGeoH5): + if ( + isinstance(directive, SaveModelGeoH5) + and cartesian2amplitude_dip_azimuth in directive.transforms + ): transforms = [] - if ( - directive.attribute_type == "model" - and cartesian2amplitude_dip_azimuth in directive.transforms - ): - for fun in directive.transforms: - if fun is cartesian2amplitude_dip_azimuth: - transforms += [spherical2cartesian] - transforms += [fun] - directive.transforms = transforms + for fun in directive.transforms: + if fun is cartesian2amplitude_dip_azimuth: + transforms += [spherical2cartesian] + transforms += [fun] + + directive.transforms = transforms elif isinstance(directive, Update_IRLS): directive.sphericalDomain = True @@ -3431,7 +3554,7 @@ def endIter(self): self.inversion.directiveList = directiveList for directive in directiveList: - if not isinstance(directive, SaveIterationsGeoH5): + if not isinstance(directive, BaseSaveGeoH5): directive.endIter()