From 97a7aea91a59105ad6afa4bccb1eb53414cce719 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 15 Oct 2024 16:07:08 -0700 Subject: [PATCH 1/7] Start splitting write directive --- simpeg/directives/__init__.py | 1 + simpeg/directives/directives.py | 303 +++++++++++++++++++------------- 2 files changed, 178 insertions(+), 126 deletions(-) diff --git a/simpeg/directives/__init__.py b/simpeg/directives/__init__.py index 45c97db0f4..f9449f3a4d 100644 --- a/simpeg/directives/__init__.py +++ b/simpeg/directives/__init__.py @@ -119,6 +119,7 @@ UpdateSensitivityWeights, VectorInversion, SaveIterationsGeoH5, + SaveLogFilesGeoH5, ProjectSphericalBounds, ScaleMisfitMultipliers, ) diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 5cdaba6152..a1d6032a34 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -1,5 +1,5 @@ from __future__ import annotations # needed to use type operands in Python 3.8 - +from abc import ABC, abstractmethod from pathlib import Path from datetime import datetime @@ -49,6 +49,8 @@ validate_ndarray_with_shape, ) +from geoh5py.groups.property_group import GroupTypeEnum +from geoh5py.groups import UIJsonGroup from geoh5py.objects import ObjectBase from geoh5py.ui_json.utils import fetch_active_workspace @@ -2982,22 +2984,23 @@ def endIter(self): misfit.simulation.model = self.invProb.model -class SaveIterationsGeoH5(InversionDirective): - """ - Saves inversion results to a geoh5 file - """ - +class BaseSaveGeoH5(InversionDirective, ABC): def __init__( - self, h5_object, dmisfit=None, attribute_type: str = "model", **kwargs + self, + h5_object, + dmisfit=None, + attribute_type: str = "model", + channels: list[str] = ("",), + components: list[str] = ("",), + **kwargs ): self.data_type = {} self._association = None self.attribute_type = attribute_type self._label = None - self.channels = [""] - self.components = [""] + self.channels = channels + self.components = components self._transforms: list = [] - self.save_objective_function = False self.sorting = None self._reshape = None self.h5_object = h5_object @@ -3013,18 +3016,136 @@ def __init__( ) def initialize(self): - self.save_components(0) - - if self.save_objective_function: - self.write_update(0) - self.save_log() + self.write(0) def endIter(self): - self.save_components(self.opt.iter) + self.write(self.opt.iter) + + def get_names( + self, component: str, channel: str, iteration: int + ) -> tuple[str, str]: + """ + Format the data and property_group name. + """ + base_name = f"Iteration_{iteration}" + if len(component) > 0: + base_name += f"_{component}" + + channel_name = base_name + if channel: + channel_name += f"_{channel}" + + if self.label is not None: + channel_name += f"_{self.label}" + base_name += f"_{self.label}" + + return channel_name, base_name + + @abstractmethod + def write( # flake8: noqa + self, iteration: int, values: list[np.ndarray] = None + ): + """ + Save the components of the inversion. + """ + + @property + def joint_index(self): + """ + Index for joint inversions defining the element in the list of predicted data. + """ + return self._joint_index + + @joint_index.setter + def joint_index(self, value: list[int]): + if not isinstance(value, list): + raise TypeError("Input 'joint_index' should be a list of int") + + self._joint_index = value + + @property + def label(self): + return self._label + + @label.setter + def label(self, value: str): + assert isinstance(value, str), "'label' must be a string" + + self._label = value + + @property + def reshape(self): + """ + Reshape function + """ + if getattr(self, "_reshape", None) is None: + self._reshape = lambda x: x.reshape( + (len(self.channels), len(self.components), -1), order="F" + ) + + return self._reshape + + @reshape.setter + def reshape(self, fun): + self._reshape = fun + + @property + def transforms(self): + return self._transforms + + @transforms.setter + def transforms(self, funcs: list): + if not isinstance(funcs, list): + funcs = [funcs] + + for fun in funcs: + if not any( + [isinstance(fun, (IdentityMap, np.ndarray, float)), callable(fun)] + ): + raise TypeError( + "Input transformation must be of type" + + "SimPEG.maps, numpy.ndarray or callable function" + ) + + self._transforms = funcs + + @property + def h5_object(self): + return self._h5_object + + @h5_object.setter + def h5_object(self, entity: ObjectBase): + if not isinstance(entity, ObjectBase | UIJsonGroup): + raise TypeError( + f"Input entity should be of type {ObjectBase}. {type(entity)} provided" + ) + + self._h5_object = entity.uid + self._geoh5 = entity.workspace + + if getattr(entity, "n_cells", None) is not None: + self.association = "CELL" + else: + self.association = "VERTEX" + + @property + def association(self): + return self._association + + @association.setter + def association(self, value): + if not value.upper() in ["CELL", "VERTEX"]: + raise ValueError( + f"'association must be one of 'CELL', 'VERTEX'. {value} provided" + ) + + self._association = value.upper() + - if self.save_objective_function: - self.write_update(self.opt.iter) - self.save_log() +class SaveIterationsGeoH5(BaseSaveGeoH5): + """ + Saves inversion results to a geoh5 file + """ def stack_channels(self, dpred: list): """ @@ -3053,26 +3174,6 @@ def apply_transformations(self, prop: np.ndarray) -> np.ndarray: return prop - def get_names( - self, component: str, channel: str, iteration: int - ) -> tuple[str, str]: - """ - Format the data and property_group name. - """ - base_name = f"Iteration_{iteration}" - if len(component) > 0: - base_name += f"_{component}" - - channel_name = base_name - if channel: - channel_name += f"_{channel}" - - if self.label is not None: - channel_name += f"_{self.label}" - base_name += f"_{self.label}" - - return channel_name, base_name - def get_values(self, values: list[np.ndarray] | None): """ Get values for the inversion depending on the output type. @@ -3098,7 +3199,7 @@ def get_values(self, values: list[np.ndarray] | None): return prop - def save_components( # flake8: noqa + def write( # flake8: noqa self, iteration: int, values: list[np.ndarray] = None ): """ @@ -3149,10 +3250,10 @@ def save_components( # flake8: noqa if len(self.channels) > 1 and self.attribute_type == "predicted": h5_object.add_data_to_group(data, base_name) - def write_update(self, iteration: int): - """ - Write update to file. - """ + +class SaveLogFilesGeoH5(BaseSaveGeoH5): + + def write(self, iteration): dirpath = Path(self._geoh5.h5file).parent filepath = dirpath / "SimPEG.out" @@ -3167,6 +3268,8 @@ def write_update(self, iteration: int): f"{self.invProb.phi_m:.3e} {date_time}\n" ) + self.save_log() + def save_log(self): """ Save iteration metrics to comments. @@ -3186,103 +3289,51 @@ def save_log(self): raw_file = f.read() if h5_object.parent.get_entity(file)[0] is not None: - file_entity = h5_object.parent.get_entity(file)[0] + file_entity = h5_object.get_entity(file)[0] else: - file_entity = h5_object.parent.add_file(filepath) + file_entity = h5_object.add_file(filepath) file_entity.file_bytes = raw_file - @property - def joint_index(self): - """ - Index for joint inversions defining the element in the list of predicted data. - """ - return self._joint_index - @joint_index.setter - def joint_index(self, value: list[int]): - if not isinstance(value, list): - raise TypeError("Input 'joint_index' should be a list of int") - - self._joint_index = value +class SavePropertyGroup(BaseSaveGeoH5): + """ + Save the model as a property group in the geoh5 file + """ - @property - def label(self): - return self._label + def __init__( + self, + h5_object, + dmisfit=None, + attribute_type: str = "property_group", + group_type: GroupTypeEnum = "Dip direction & dip", + **kwargs + ): + self.group_type = group_type - @label.setter - def label(self, value: str): - assert isinstance(value, str), "'label' must be a string" + super().__init__(h5_object, dmisfit=dmisfit, attribute_type=attribute_type, **kwargs) - self._label = value - @property - def reshape(self): + def write(self, iteration: int): """ - Reshape function + Save the model to the geoh5 file """ - if getattr(self, "_reshape", None) is None: - self._reshape = lambda x: x.reshape( - (len(self.channels), len(self.components), -1), order="F" - ) - - return self._reshape - - @reshape.setter - def reshape(self, fun): - self._reshape = fun - - @property - def transforms(self): - return self._transforms - - @transforms.setter - def transforms(self, funcs: list): - if not isinstance(funcs, list): - funcs = [funcs] - - for fun in funcs: - if not any( - [isinstance(fun, (IdentityMap, np.ndarray, float)), callable(fun)] - ): - raise TypeError( - "Input transformation must be of type" - + "SimPEG.maps, numpy.ndarray or callable function" - ) - - self._transforms = funcs - - @property - def h5_object(self): - return self._h5_object - - @h5_object.setter - def h5_object(self, entity: ObjectBase): - if not isinstance(entity, ObjectBase): - raise TypeError( - f"Input entity should be of type {ObjectBase}. {type(entity)} provided" - ) - - self._h5_object = entity.uid - self._geoh5 = entity.workspace + with fetch_active_workspace(self._geoh5, mode="r+") as w_s: + h5_object = w_s.get_entity(self.h5_object)[0] + properties = [] + for channel in self.channels: + for component in self.components: + channel_name, base_name = self.get_names( + component, channel, iteration + ) + child = h5_object.get_entity(channel_name)[0] - if getattr(entity, "n_cells", None) is not None: - self.association = "CELL" - else: - self.association = "VERTEX" + if child is not None: + properties.append(child) - @property - def association(self): - return self._association - @association.setter - def association(self, value): - if not value.upper() in ["CELL", "VERTEX"]: - raise ValueError( - f"'association must be one of 'CELL', 'VERTEX'. {value} provided" - ) + group = PropertyGroup(parent=h5_object, name=base_name, properties=properties, property_group_type=self.group_type) - self._association = value.upper() class VectorInversion(InversionDirective): From 894529baccaf12da9cfd0e96310b62624eef4b2c Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 16 Oct 2024 12:47:31 -0700 Subject: [PATCH 2/7] Continue split --- simpeg/directives/__init__.py | 4 +- simpeg/directives/directives.py | 154 +++++++++++++++++++------------- 2 files changed, 94 insertions(+), 64 deletions(-) diff --git a/simpeg/directives/__init__.py b/simpeg/directives/__init__.py index f9449f3a4d..1b781f785e 100644 --- a/simpeg/directives/__init__.py +++ b/simpeg/directives/__init__.py @@ -118,7 +118,9 @@ JointScalingSchedule, UpdateSensitivityWeights, VectorInversion, - SaveIterationsGeoH5, + SaveModelGeoH5, + SaveDataGeoH5, + SaveSensitivityGeoH5, SaveLogFilesGeoH5, ProjectSphericalBounds, ScaleMisfitMultipliers, diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index a1d6032a34..3b062cb24a 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -3142,11 +3142,10 @@ def association(self, value): self._association = value.upper() -class SaveIterationsGeoH5(BaseSaveGeoH5): +class SaveIterationsGeoH5(BaseSaveGeoH5, ABC): """ Saves inversion results to a geoh5 file """ - def stack_channels(self, dpred: list): """ Regroup channel values along rows. @@ -3174,30 +3173,11 @@ def apply_transformations(self, prop: np.ndarray) -> np.ndarray: return prop + @abstractmethod def get_values(self, values: list[np.ndarray] | None): """ Get values for the inversion depending on the output type. """ - prop = self.invProb.model - if values is not None: - prop = self.stack_channels(values) - elif self.attribute_type == "predicted": - dpred = getattr(self.invProb, "dpred", None) - if dpred is None: - dpred = self.invProb.get_dpred(self.invProb.model) - self.invProb.dpred = dpred - - if self.joint_index is not None: - dpred = [dpred[ind] for ind in self.joint_index] - - prop = self.stack_channels(dpred) - elif self.attribute_type == "sensitivities": - - prop = np.zeros_like(self.invProb.model) - for fun in self.dmisfit.objfcts: - prop += fun.getJtJdiag(self.invProb.model) - - return prop def write( # flake8: noqa self, iteration: int, values: list[np.ndarray] = None @@ -3235,6 +3215,7 @@ def write( # flake8: noqa } } ) + # Re-assign the data type if channel not in self.data_type[component].keys(): self.data_type[component][channel] = data.entity_type type_name = f"{self.attribute_type}_{component}" @@ -3251,9 +3232,57 @@ def write( # flake8: noqa h5_object.add_data_to_group(data, base_name) +class SaveModelGeoH5(SaveIterationsGeoH5): + """ + Save the model at the current iteration to a geoh5 file. + """ + def get_values(self, values: list[np.ndarray] | None): + if values is None: + values = self.invProb.model + + return values + + +class SaveSensitivityGeoH5(SaveIterationsGeoH5): + """ + Save the model at the current iteration to a geoh5 file. + """ + def get_values(self, values: list[np.ndarray] | None): + if values is None: + values = np.zeros_like(self.invProb.model) + for fun in self.dmisfit.objfcts: + values += fun.getJtJdiag(self.invProb.model) + + return values + + +class SaveDataGeoH5(SaveIterationsGeoH5): + """ + Save the model at the current iteration to a geoh5 file. + """ + + def get_values(self, values: list[np.ndarray] | None): + + if values is not None: + prop = self.stack_channels(values) + + else: + dpred = getattr(self.invProb, "dpred", None) + if dpred is None: + dpred = self.invProb.get_dpred(self.invProb.model) + self.invProb.dpred = dpred + + if self.joint_index is not None: + dpred = [dpred[ind] for ind in self.joint_index] + + prop = self.stack_channels(dpred) + + return prop + + class SaveLogFilesGeoH5(BaseSaveGeoH5): - def write(self, iteration): + def write(self, iteration: int, **_): dirpath = Path(self._geoh5.h5file).parent filepath = dirpath / "SimPEG.out" @@ -3288,51 +3317,50 @@ def save_log(self): with open(filepath, "rb") as f: raw_file = f.read() - if h5_object.parent.get_entity(file)[0] is not None: - file_entity = h5_object.get_entity(file)[0] - else: + file_entity = h5_object.get_entity(file)[0] + if file_entity is None: file_entity = h5_object.add_file(filepath) file_entity.file_bytes = raw_file -class SavePropertyGroup(BaseSaveGeoH5): - """ - Save the model as a property group in the geoh5 file - """ - - def __init__( - self, - h5_object, - dmisfit=None, - attribute_type: str = "property_group", - group_type: GroupTypeEnum = "Dip direction & dip", - **kwargs - ): - self.group_type = group_type - - super().__init__(h5_object, dmisfit=dmisfit, attribute_type=attribute_type, **kwargs) - - - def write(self, iteration: int): - """ - Save the model to the geoh5 file - """ - with fetch_active_workspace(self._geoh5, mode="r+") as w_s: - h5_object = w_s.get_entity(self.h5_object)[0] - properties = [] - for channel in self.channels: - for component in self.components: - channel_name, base_name = self.get_names( - component, channel, iteration - ) - child = h5_object.get_entity(channel_name)[0] - - if child is not None: - properties.append(child) - - - group = PropertyGroup(parent=h5_object, name=base_name, properties=properties, property_group_type=self.group_type) +# class SavePropertyGroup(BaseSaveGeoH5): +# """ +# Save the model as a property group in the geoh5 file +# """ +# +# def __init__( +# self, +# h5_object, +# dmisfit=None, +# attribute_type: str = "property_group", +# group_type: GroupTypeEnum = "Dip direction & dip", +# **kwargs +# ): +# self.group_type = group_type +# +# super().__init__(h5_object, dmisfit=dmisfit, attribute_type=attribute_type, **kwargs) +# +# +# def write(self, iteration: int): +# """ +# Save the model to the geoh5 file +# """ +# with fetch_active_workspace(self._geoh5, mode="r+") as w_s: +# h5_object = w_s.get_entity(self.h5_object)[0] +# properties = [] +# for channel in self.channels: +# for component in self.components: +# channel_name, base_name = self.get_names( +# component, channel, iteration +# ) +# child = h5_object.get_entity(channel_name)[0] +# +# if child is not None: +# properties.append(child) +# +# +# group = PropertyGroup(parent=h5_object, name=base_name, properties=properties, property_group_type=self.group_type) From eff126e1f1e71523e44509618134d1f9f8753900 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 16 Oct 2024 14:25:58 -0700 Subject: [PATCH 3/7] Implement save property group directive --- simpeg/directives/__init__.py | 5 +- simpeg/directives/directives.py | 100 ++++++++++++++++++-------------- 2 files changed, 58 insertions(+), 47 deletions(-) diff --git a/simpeg/directives/__init__.py b/simpeg/directives/__init__.py index 1b781f785e..cbc9a6ab6e 100644 --- a/simpeg/directives/__init__.py +++ b/simpeg/directives/__init__.py @@ -118,10 +118,11 @@ JointScalingSchedule, UpdateSensitivityWeights, VectorInversion, - SaveModelGeoH5, SaveDataGeoH5, - SaveSensitivityGeoH5, SaveLogFilesGeoH5, + SaveModelGeoH5, + SavePropertyGroup, + SaveSensitivityGeoH5, ProjectSphericalBounds, ScaleMisfitMultipliers, ) diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 3b062cb24a..281dde1a81 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -50,7 +50,7 @@ ) from geoh5py.groups.property_group import GroupTypeEnum -from geoh5py.groups import UIJsonGroup +from geoh5py.groups import PropertyGroup, UIJsonGroup from geoh5py.objects import ObjectBase from geoh5py.ui_json.utils import fetch_active_workspace @@ -2992,7 +2992,7 @@ def __init__( attribute_type: str = "model", channels: list[str] = ("",), components: list[str] = ("",), - **kwargs + **kwargs, ): self.data_type = {} self._association = None @@ -3042,9 +3042,7 @@ def get_names( return channel_name, base_name @abstractmethod - def write( # flake8: noqa - self, iteration: int, values: list[np.ndarray] = None - ): + def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noqa """ Save the components of the inversion. """ @@ -3146,6 +3144,7 @@ class SaveIterationsGeoH5(BaseSaveGeoH5, ABC): """ Saves inversion results to a geoh5 file """ + def stack_channels(self, dpred: list): """ Regroup channel values along rows. @@ -3179,9 +3178,7 @@ def get_values(self, values: list[np.ndarray] | None): Get values for the inversion depending on the output type. """ - def write( # flake8: noqa - self, iteration: int, values: list[np.ndarray] = None - ): + def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noqa """ Sort, transform and store data per components and channels. """ @@ -3236,6 +3233,7 @@ class SaveModelGeoH5(SaveIterationsGeoH5): """ Save the model at the current iteration to a geoh5 file. """ + def get_values(self, values: list[np.ndarray] | None): if values is None: values = self.invProb.model @@ -3247,6 +3245,7 @@ class SaveSensitivityGeoH5(SaveIterationsGeoH5): """ Save the model at the current iteration to a geoh5 file. """ + def get_values(self, values: list[np.ndarray] | None): if values is None: values = np.zeros_like(self.invProb.model) @@ -3324,44 +3323,55 @@ def save_log(self): file_entity.file_bytes = raw_file -# class SavePropertyGroup(BaseSaveGeoH5): -# """ -# Save the model as a property group in the geoh5 file -# """ -# -# def __init__( -# self, -# h5_object, -# dmisfit=None, -# attribute_type: str = "property_group", -# group_type: GroupTypeEnum = "Dip direction & dip", -# **kwargs -# ): -# self.group_type = group_type -# -# super().__init__(h5_object, dmisfit=dmisfit, attribute_type=attribute_type, **kwargs) -# -# -# def write(self, iteration: int): -# """ -# Save the model to the geoh5 file -# """ -# with fetch_active_workspace(self._geoh5, mode="r+") as w_s: -# h5_object = w_s.get_entity(self.h5_object)[0] -# properties = [] -# for channel in self.channels: -# for component in self.components: -# channel_name, base_name = self.get_names( -# component, channel, iteration -# ) -# child = h5_object.get_entity(channel_name)[0] -# -# if child is not None: -# properties.append(child) -# -# -# group = PropertyGroup(parent=h5_object, name=base_name, properties=properties, property_group_type=self.group_type) +class SavePropertyGroup(BaseSaveGeoH5): + """ + Save the model as a property group in the geoh5 file + """ + + def __init__( + self, + h5_object, + dmisfit=None, + attribute_type: str = "property_group", + group_type: GroupTypeEnum = "Dip direction & dip", + **kwargs, + ): + self.group_type = group_type + super().__init__( + h5_object, dmisfit=dmisfit, attribute_type=attribute_type, **kwargs + ) + + def write(self, iteration: int, **_): + """ + Save the model to the geoh5 file + """ + with fetch_active_workspace(self._geoh5, mode="r+") as w_s: + h5_object = w_s.get_entity(self.h5_object)[0] + properties = [] + for channel in self.channels: + for component in self.components: + channel_name, base_name = self.get_names( + component, channel, iteration + ) + child = [ + child + for child in h5_object.children + if channel_name in child.name + ][0] + + if child is not None: + properties.append(child) + + if len(properties) == 0: + return + + PropertyGroup( + parent=h5_object, + name=base_name, + properties=properties, + property_group_type=self.group_type, + ) class VectorInversion(InversionDirective): From fba6934e3b63ca3daca946e0397f1a5796942c9c Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 16 Oct 2024 15:51:46 -0700 Subject: [PATCH 4/7] Further cleanups --- simpeg/directives/directives.py | 216 ++++++++++++++++++-------------- 1 file changed, 124 insertions(+), 92 deletions(-) diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 281dde1a81..ac361cf7a6 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -2985,31 +2985,27 @@ def endIter(self): class BaseSaveGeoH5(InversionDirective, ABC): + """ + Base class for saving inversion results to a geoh5 file + """ + def __init__( self, h5_object, dmisfit=None, - attribute_type: str = "model", + label: str | None = None, channels: list[str] = ("",), components: list[str] = ("",), + association: str | None = None, **kwargs, ): - self.data_type = {} - self._association = None - self.attribute_type = attribute_type - self._label = None + self.label = label self.channels = channels self.components = components - self._transforms: list = [] - self.sorting = None - self._reshape = None self.h5_object = h5_object - self._joint_index = None - if attribute_type == "sensitivities" and dmisfit is None: - raise ValueError( - "To save sensitivities, the data misfit object must be provided." - ) + if association is not None: + self.association = association super().__init__( inversion=None, dmisfit=dmisfit, reg=None, verbose=False, **kwargs @@ -3047,66 +3043,17 @@ def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noq Save the components of the inversion. """ - @property - def joint_index(self): - """ - Index for joint inversions defining the element in the list of predicted data. - """ - return self._joint_index - - @joint_index.setter - def joint_index(self, value: list[int]): - if not isinstance(value, list): - raise TypeError("Input 'joint_index' should be a list of int") - - self._joint_index = value - @property def label(self): return self._label @label.setter - def label(self, value: str): - assert isinstance(value, str), "'label' must be a string" + def label(self, value: str | None): + if not isinstance(value, str | type(None)): + raise TypeError("'label' must be a string or None") self._label = value - @property - def reshape(self): - """ - Reshape function - """ - if getattr(self, "_reshape", None) is None: - self._reshape = lambda x: x.reshape( - (len(self.channels), len(self.components), -1), order="F" - ) - - return self._reshape - - @reshape.setter - def reshape(self, fun): - self._reshape = fun - - @property - def transforms(self): - return self._transforms - - @transforms.setter - def transforms(self, funcs: list): - if not isinstance(funcs, list): - funcs = [funcs] - - for fun in funcs: - if not any( - [isinstance(fun, (IdentityMap, np.ndarray, float)), callable(fun)] - ): - raise TypeError( - "Input transformation must be of type" - + "SimPEG.maps, numpy.ndarray or callable function" - ) - - self._transforms = funcs - @property def h5_object(self): return self._h5_object @@ -3140,11 +3087,71 @@ def association(self, value): self._association = value.upper() -class SaveIterationsGeoH5(BaseSaveGeoH5, ABC): +class SaveArrayGeoH5(BaseSaveGeoH5, ABC): """ - Saves inversion results to a geoh5 file + Saves array-based inversion results (model, data) to a geoh5 file. + + Parameters + ---------- + + transforms: List of transformations applied to the values before save. + sorting: Special re-indexing of the vector values. + reshape: Re-ordering applied to the data before slicing. """ + _attribute_type = None + + def __init__( + self, + h5_object, + transforms: list | tuple = (), + reshape=None, + sorting=None, + **kwargs, + ): + self.data_type = {} + self.transforms = transforms + self.sorting = sorting + self.reshape = reshape + + super().__init__(h5_object, **kwargs) + + @property + def reshape(self): + """ + Reshape function + """ + if getattr(self, "_reshape", None) is None: + self._reshape = lambda x: x.reshape( + (len(self.channels), len(self.components), -1), order="F" + ) + + return self._reshape + + @reshape.setter + def reshape(self, fun): + self._reshape = fun + + @property + def transforms(self): + return self._transforms + + @transforms.setter + def transforms(self, funcs: list | tuple): + if not isinstance(funcs, list | tuple): + funcs = [funcs] + + for fun in funcs: + if not any( + [isinstance(fun, (IdentityMap, np.ndarray, float)), callable(fun)] + ): + raise TypeError( + "Input transformation must be of type" + + "SimPEG.maps, numpy.ndarray or callable function" + ) + + self._transforms = funcs + def stack_channels(self, dpred: list): """ Regroup channel values along rows. @@ -3215,7 +3222,7 @@ def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noq # Re-assign the data type if channel not in self.data_type[component].keys(): self.data_type[component][channel] = data.entity_type - type_name = f"{self.attribute_type}_{component}" + type_name = f"{self._attribute_type}_{component}" if channel: type_name += f"_{channel}" data.entity_type.name = type_name @@ -3225,15 +3232,14 @@ def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noq type(self.data_type[component][channel]), ) - if len(self.channels) > 1 and self.attribute_type == "predicted": - h5_object.add_data_to_group(data, base_name) - -class SaveModelGeoH5(SaveIterationsGeoH5): +class SaveModelGeoH5(SaveArrayGeoH5): """ Save the model at the current iteration to a geoh5 file. """ + _attribute_type = "model" + def get_values(self, values: list[np.ndarray] | None): if values is None: values = self.invProb.model @@ -3241,11 +3247,20 @@ def get_values(self, values: list[np.ndarray] | None): return values -class SaveSensitivityGeoH5(SaveIterationsGeoH5): +class SaveSensitivityGeoH5(SaveArrayGeoH5): """ Save the model at the current iteration to a geoh5 file. """ + _attribute_type = "sensitivities" + + def __init__(self, h5_object, dmisfit=None, **kwargs): + if dmisfit is None: + raise ValueError( + "To save sensitivities, the data misfit object must be provided." + ) + super().__init__(h5_object, dmisfit=dmisfit, **kwargs) + def get_values(self, values: list[np.ndarray] | None): if values is None: values = np.zeros_like(self.invProb.model) @@ -3255,11 +3270,18 @@ def get_values(self, values: list[np.ndarray] | None): return values -class SaveDataGeoH5(SaveIterationsGeoH5): +class SaveDataGeoH5(SaveArrayGeoH5): """ Save the model at the current iteration to a geoh5 file. """ + _attribute_type = "predicted" + + def __init__(self, h5_object, joint_index: list[int] | None = None, **kwargs): + self.joint_index = joint_index + + super().__init__(h5_object, **kwargs) + def get_values(self, values: list[np.ndarray] | None): if values is not None: @@ -3278,6 +3300,20 @@ def get_values(self, values: list[np.ndarray] | None): return prop + @property + def joint_index(self): + """ + Index for joint inversions defining the element in the list of predicted data. + """ + return self._joint_index + + @joint_index.setter + def joint_index(self, value: list[int] | None): + if not isinstance(value, list | type(None)): + raise TypeError("Input 'joint_index' should be a list of int") + + self._joint_index = value + class SaveLogFilesGeoH5(BaseSaveGeoH5): @@ -3331,16 +3367,12 @@ class SavePropertyGroup(BaseSaveGeoH5): def __init__( self, h5_object, - dmisfit=None, - attribute_type: str = "property_group", group_type: GroupTypeEnum = "Dip direction & dip", **kwargs, ): self.group_type = group_type - super().__init__( - h5_object, dmisfit=dmisfit, attribute_type=attribute_type, **kwargs - ) + super().__init__(h5_object, **kwargs) def write(self, iteration: int, **_): """ @@ -3349,8 +3381,8 @@ def write(self, iteration: int, **_): with fetch_active_workspace(self._geoh5, mode="r+") as w_s: h5_object = w_s.get_entity(self.h5_object)[0] properties = [] - for channel in self.channels: - for component in self.components: + for component in self.components: + for channel in self.channels: channel_name, base_name = self.get_names( component, channel, iteration ) @@ -3493,18 +3525,18 @@ def endIter(self): # Add and update directives for directive in self.inversion.directiveList.dList: - if isinstance(directive, SaveIterationsGeoH5): + if ( + isinstance(directive, SaveModelGeoH5) + and cartesian2amplitude_dip_azimuth in directive.transforms + ): transforms = [] - if ( - directive.attribute_type == "model" - and cartesian2amplitude_dip_azimuth in directive.transforms - ): - for fun in directive.transforms: - if fun is cartesian2amplitude_dip_azimuth: - transforms += [spherical2cartesian] - transforms += [fun] - directive.transforms = transforms + for fun in directive.transforms: + if fun is cartesian2amplitude_dip_azimuth: + transforms += [spherical2cartesian] + transforms += [fun] + + directive.transforms = transforms elif isinstance(directive, Update_IRLS): directive.sphericalDomain = True @@ -3520,7 +3552,7 @@ def endIter(self): self.inversion.directiveList = directiveList for directive in directiveList: - if not isinstance(directive, SaveIterationsGeoH5): + if not isinstance(directive, BaseSaveGeoH5): directive.endIter() From 99f4787fa7d7bdacf47c1bbc896a1599c205c6ab Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 16 Oct 2024 16:13:45 -0700 Subject: [PATCH 5/7] Fix group creation --- simpeg/directives/directives.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index ac361cf7a6..ac95c20c32 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -3367,7 +3367,7 @@ class SavePropertyGroup(BaseSaveGeoH5): def __init__( self, h5_object, - group_type: GroupTypeEnum = "Dip direction & dip", + group_type: GroupTypeEnum = GroupTypeEnum.MULTI, **kwargs, ): self.group_type = group_type @@ -3380,9 +3380,11 @@ def write(self, iteration: int, **_): """ with fetch_active_workspace(self._geoh5, mode="r+") as w_s: h5_object = w_s.get_entity(self.h5_object)[0] - properties = [] + for component in self.components: + properties = [] for channel in self.channels: + channel_name, base_name = self.get_names( component, channel, iteration ) @@ -3395,15 +3397,15 @@ def write(self, iteration: int, **_): if child is not None: properties.append(child) - if len(properties) == 0: - return + if len(properties) == 0: + return - PropertyGroup( - parent=h5_object, - name=base_name, - properties=properties, - property_group_type=self.group_type, - ) + PropertyGroup( + parent=h5_object, + name=base_name, + properties=properties, + property_group_type=self.group_type, + ) class VectorInversion(InversionDirective): From ba618807abb3ce606321f8b6d852ed8c8f5c3947 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 17 Oct 2024 09:15:03 -0700 Subject: [PATCH 6/7] Bump pymatsolver --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5c965ed27a..246f77537c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "matplotlib", "numpy>=1.20", "pandas", - "pymatsolver>=0.2", + "pymatsolver>=0.3", "scikit-learn>=1.2", "scipy>=1.8.0", ] From e179b8e4ac769022773b898e174b3c1e0a80e681 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 17 Oct 2024 09:17:45 -0700 Subject: [PATCH 7/7] Bring back down pymatsolver --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 246f77537c..65a70f6633 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "matplotlib", "numpy>=1.20", "pandas", - "pymatsolver>=0.3", + "pymatsolver>=0.2, <0.3.0", "scikit-learn>=1.2", "scipy>=1.8.0", ]