diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 08be80b262..3329c7d633 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,6 +21,7 @@ repos: - flake8-mutable==1.2.0 - flake8-rst-docstrings==0.3.0 - flake8-docstrings==1.7.0 + - flake8-pyproject==1.2.3 - repo: https://github.com/MiraGeoscience/pre-commit-hooks rev: v1.1.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index 52010b8733..50b9af3677 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,7 +136,7 @@ exclude_also = [ directory = "coverage_html_report" [tool.black] -required-version = '24.3.0' +required-version = '25.1.0' target-version = ['py38', 'py39', 'py310', 'py311'] [tool.flake8] diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index be7cf83eea..55b44b76c2 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -1,5 +1,40 @@ +import numpy as np from ....potential_fields.magnetics import Simulation3DIntegral as Sim -from ...simulation import getJtJdiag +from ....utils import sdiag, mkvc + + +def dask_getJtJdiag(self, m, W=None, f=None): + """ + Return the diagonal of JtJ + """ + + self.model = m + + self.model = m + if W is None: + W = np.ones(self.Jmatrix.shape[0]) + else: + W = W.diagonal() + + if getattr(self, "_gtg_diagonal", None) is None: + if not self.is_amplitude_data: + diag = np.asarray(np.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix)) + else: + ampDeriv = self.ampDeriv + J = ( + ampDeriv[0, :, None] * self.Jmatrix[::3] + + ampDeriv[1, :, None] * self.Jmatrix[1::3] + + ampDeriv[2, :, None] * self.Jmatrix[2::3] + ) + diag = ((W[:, None] * J) ** 2).sum(axis=0).compute() + self._gtg_diagonal = diag + else: + diag = self._gtg_diagonal + + return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0)) + + +Sim.getJtJdiag = dask_getJtJdiag @property @@ -14,5 +49,4 @@ def G(self): Sim._delete_on_model_update = [] -Sim.getJtJdiag = getJtJdiag Sim.G = G diff --git a/simpeg/directives/__init__.py b/simpeg/directives/__init__.py index a22ef6566b..737425f76f 100644 --- a/simpeg/directives/__init__.py +++ b/simpeg/directives/__init__.py @@ -118,15 +118,24 @@ ScalingMultipleDataMisfits_ByEig, JointScalingSchedule, UpdateSensitivityWeights, - VectorInversion, + Update_IRLS, + ScaleMisfitMultipliers, +) + +from ._save_geoh5 import ( + BaseSaveGeoH5, SaveDataGeoH5, SaveLogFilesGeoH5, SaveModelGeoH5, SavePropertyGroup, SaveSensitivityGeoH5, - Update_IRLS, +) + +from ._regularization import UpdateIRLS, SphericalUnitsWeights + +from ._vector_models import ( + VectorInversion, ProjectSphericalBounds, - ScaleMisfitMultipliers, ) from .pgi_directives import ( @@ -135,8 +144,6 @@ PGI_AddMrefInSmooth, ) -from ._regularization import UpdateIRLS, SphericalUnitsWeights - from .sim_directives import ( SimilarityMeasureInversionDirective, SimilarityMeasureSaveOutputEveryIteration, diff --git a/simpeg/directives/_save_geoh5.py b/simpeg/directives/_save_geoh5.py new file mode 100644 index 0000000000..f4ea2b134b --- /dev/null +++ b/simpeg/directives/_save_geoh5.py @@ -0,0 +1,459 @@ +import re +from abc import ABC, abstractmethod +from datetime import datetime +from pathlib import Path + +import numpy as np + +from .directives import InversionDirective +from simpeg.maps import IdentityMap + +from geoh5py.groups.property_group import GroupTypeEnum +from geoh5py.groups import PropertyGroup, UIJsonGroup +from geoh5py.objects import ObjectBase +from geoh5py.ui_json.utils import fetch_active_workspace + + +def compute_JtJdiags(data_misfit, m): + if hasattr(data_misfit, "getJtJdiag"): + return data_misfit.getJtJdiag(m) + else: + jtj_diags = [] + for dmisfit in data_misfit.objfcts: + jtj_diags.append(dmisfit.getJtJdiag(m)) + + jtj_diag = np.zeros_like(jtj_diags[0]) + for multiplier, diag in zip(data_misfit.multipliers, jtj_diags): + jtj_diag += multiplier * diag + + return np.asarray(jtj_diag) + + +class BaseSaveGeoH5(InversionDirective, ABC): + """ + Base class for saving inversion results to a geoh5 file + """ + + def __init__( + self, + h5_object, + dmisfit=None, + label: str | None = None, + channels: list[str] = ("",), + components: list[str] = ("",), + association: str | None = None, + **kwargs, + ): + self.label = label + self.channels = channels + self.components = components + self.h5_object = h5_object + + if association is not None: + self.association = association + + super().__init__( + inversion=None, dmisfit=dmisfit, reg=None, verbose=False, **kwargs + ) + + def initialize(self): + self.write(0) + + def endIter(self): + self.write(self.opt.iter) + + def get_names( + self, component: str, channel: str, iteration: int + ) -> tuple[str, str]: + """ + Format the data and property_group name. + """ + base_name = f"Iteration_{iteration}" + if len(component) > 0: + base_name += f"_{component}" + + channel_name = base_name + if channel: + channel_name += f"_{channel}" + + if self.label is not None: + channel_name += f"_{self.label}" + base_name += f"_{self.label}" + + return channel_name, base_name + + @abstractmethod + def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noqa + """ + Save the components of the inversion. + """ + + @property + def label(self): + return self._label + + @label.setter + def label(self, value: str | None): + if not isinstance(value, str | type(None)): + raise TypeError("'label' must be a string or None") + + self._label = value + + @property + def h5_object(self): + return self._h5_object + + @h5_object.setter + def h5_object(self, entity: ObjectBase): + if not isinstance(entity, ObjectBase | UIJsonGroup): + raise TypeError( + f"Input entity should be of type {ObjectBase}. {type(entity)} provided" + ) + + self._h5_object = entity.uid + self._geoh5 = entity.workspace + + if getattr(entity, "n_cells", None) is not None: + self.association = "CELL" + else: + self.association = "VERTEX" + + @property + def association(self): + return self._association + + @association.setter + def association(self, value): + if not value.upper() in ["CELL", "VERTEX"]: + raise ValueError( + f"'association must be one of 'CELL', 'VERTEX'. {value} provided" + ) + + self._association = value.upper() + + +class SaveArrayGeoH5(BaseSaveGeoH5, ABC): + """ + Saves array-based inversion results (model, data) to a geoh5 file. + + Parameters + ---------- + + transforms: List of transformations applied to the values before save. + sorting: Special re-indexing of the vector values. + reshape: Re-ordering applied to the data before slicing. + """ + + _attribute_type = None + + def __init__( + self, + h5_object, + transforms: list | tuple = (), + reshape=None, + sorting=None, + **kwargs, + ): + self.data_type = {} + self.transforms = transforms + self.sorting = sorting + self.reshape = reshape + + super().__init__(h5_object, **kwargs) + + @property + def reshape(self): + """ + Reshape function + """ + if getattr(self, "_reshape", None) is None: + self._reshape = lambda x: x.reshape( + (len(self.channels), len(self.components), -1), order="F" + ) + + return self._reshape + + @reshape.setter + def reshape(self, fun): + self._reshape = fun + + @property + def transforms(self): + return self._transforms + + @transforms.setter + def transforms(self, funcs: list | tuple): + if not isinstance(funcs, list | tuple): + funcs = [funcs] + + for fun in funcs: + if not any( + [isinstance(fun, (IdentityMap, np.ndarray, float)), callable(fun)] + ): + raise TypeError( + "Input transformation must be of type" + + "SimPEG.maps, numpy.ndarray or callable function" + ) + + self._transforms = funcs + + def stack_channels(self, dpred: list): + """ + Regroup channel values along rows. + """ + if isinstance(dpred, np.ndarray): + return self.reshape(dpred) + + return self.reshape(np.hstack(dpred)) + + def apply_transformations(self, prop: np.ndarray) -> np.ndarray: + """ + Re-order the values and apply transformations. + """ + prop = prop.flatten() + for fun in self.transforms: + if isinstance(fun, (IdentityMap, np.ndarray, float)): + prop = fun * prop + else: + prop = fun(prop) + + if prop.ndim == 2: + prop = prop.T.flatten() + + prop = prop.reshape((len(self.channels), len(self.components), -1)) + + return prop + + @abstractmethod + def get_values(self, values: list[np.ndarray] | None): + """ + Get values for the inversion depending on the output type. + """ + + def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noqa + """ + Sort, transform and store data per components and channels. + """ + prop = self.get_values(values) + + # Apply transformations + prop = self.apply_transformations(prop) + + # Save results + with fetch_active_workspace(self._geoh5, mode="r+") as w_s: + h5_object = w_s.get_entity(self.h5_object)[0] + for cc, component in enumerate(self.components): + if component not in self.data_type: + self.data_type[component] = {} + + for ii, channel in enumerate(self.channels): + values = prop[ii, cc, :] + + if self.sorting is not None: + values = values[self.sorting] + + channel_name, base_name = self.get_names( + component, channel, iteration + ) + + data = h5_object.add_data( + { + channel_name: { + "association": self.association, + "values": values, + } + } + ) + # Re-assign the data type + if channel not in self.data_type[component].keys(): + self.data_type[component][channel] = data.entity_type + type_name = f"{self._attribute_type}_{component}" + if channel: + type_name += f"_{channel}" + data.entity_type.name = type_name + else: + data.entity_type = w_s.find_type( + self.data_type[component][channel].uid, + type(self.data_type[component][channel]), + ) + + +class SaveModelGeoH5(SaveArrayGeoH5): + """ + Save the model at the current iteration to a geoh5 file. + """ + + _attribute_type = "model" + + def get_values(self, values: list[np.ndarray] | None): + if values is None: + values = self.invProb.model + + return values + + +class SaveSensitivityGeoH5(SaveArrayGeoH5): + """ + Save the model at the current iteration to a geoh5 file. + """ + + _attribute_type = "sensitivities" + + def __init__(self, h5_object, dmisfit=None, **kwargs): + if dmisfit is None: + raise ValueError( + "To save sensitivities, the data misfit object must be provided." + ) + super().__init__(h5_object, dmisfit=dmisfit, **kwargs) + + def get_values(self, values: list[np.ndarray] | None): + if values is None: + values = compute_JtJdiags(self.dmisfit, self.invProb.model) + + return values + + +class SaveDataGeoH5(SaveArrayGeoH5): + """ + Save the model at the current iteration to a geoh5 file. + """ + + _attribute_type = "predicted" + + def __init__(self, h5_object, joint_index: list[int] | None = None, **kwargs): + self.joint_index = joint_index + + super().__init__(h5_object, **kwargs) + + def get_values(self, values: list[np.ndarray] | None): + + if values is not None: + prop = self.stack_channels(values) + + else: + dpred = getattr(self.invProb, "dpred", None) + if dpred is None: + dpred = self.invProb.get_dpred(self.invProb.model) + self.invProb.dpred = dpred + + if self.joint_index is not None: + dpred = [dpred[ind] for ind in self.joint_index] + + prop = self.stack_channels(dpred) + + return prop + + @property + def joint_index(self): + """ + Index for joint inversions defining the element in the list of predicted data. + """ + return self._joint_index + + @joint_index.setter + def joint_index(self, value: list[int] | None): + if not isinstance(value, list | type(None)): + raise TypeError("Input 'joint_index' should be a list of int") + + self._joint_index = value + + +class SaveLogFilesGeoH5(BaseSaveGeoH5): + + def write(self, iteration: int, **_): + dirpath = Path(self._geoh5.h5file).parent + filepath = dirpath / "SimPEG.out" + + if iteration == 0: + with open(filepath, "w", encoding="utf-8") as f: + f.write("iteration beta phi_d phi_m time\n") + log = [] + with open(dirpath / "SimPEG.log", "r", encoding="utf-8") as file: + iteration = 0 + for line in file: + val = re.findall( + "[+\-]?(?:0|[1-9]\d*)(?:\.\d*)?(?:[eE][+\-]?\d+)", line # noqa + ) + if len(val) == 5: + log.append(val[:-2]) + iteration += 1 + + if len(log) > 0: + with open(filepath, "a", encoding="utf-8") as file: + date_time = datetime.now().strftime("%b-%d-%Y:%H:%M:%S") + file.write(f"{iteration-1} " + " ".join(log[-1]) + f" {date_time}\n") + + self.save_log() + + def save_log(self): + """ + Save iteration metrics to comments. + """ + dirpath = Path(self._geoh5.h5file).parent + + with fetch_active_workspace(self._geoh5, mode="r+") as w_s: + h5_object = w_s.get_entity(self.h5_object)[0] + + for file in ["SimPEG.out", "SimPEG.log", "ChiFactors.log"]: + filepath = dirpath / file + + if not filepath.is_file(): + continue + + with open(filepath, "rb") as f: + raw_file = f.read() + + file_entity = h5_object.get_entity(file)[0] + if file_entity is None: + file_entity = h5_object.add_file(filepath) + + file_entity.file_bytes = raw_file + + +class SavePropertyGroup(BaseSaveGeoH5): + """ + Save the model as a property group in the geoh5 file + """ + + def __init__( + self, + h5_object, + group_type: GroupTypeEnum = GroupTypeEnum.MULTI, + **kwargs, + ): + self.group_type = group_type + + super().__init__(h5_object, **kwargs) + + def write(self, iteration: int, **_): + """ + Save the model to the geoh5 file + """ + with fetch_active_workspace(self._geoh5, mode="r+") as w_s: + h5_object = w_s.get_entity(self.h5_object)[0] + + for component in self.components: + properties = [] + for channel in self.channels: + + channel_name, base_name = self.get_names( + component, channel, iteration + ) + child = [ + child + for child in h5_object.children + if channel_name in child.name + ][0] + + if child is not None: + properties.append(child) + + if len(properties) == 0: + return + + PropertyGroup( + parent=h5_object, + name=base_name, + properties=properties, + property_group_type=self.group_type, + ) diff --git a/simpeg/directives/_vector_models.py b/simpeg/directives/_vector_models.py new file mode 100644 index 0000000000..4124932df2 --- /dev/null +++ b/simpeg/directives/_vector_models.py @@ -0,0 +1,233 @@ +import numpy as np + +from . import ( + BaseSaveGeoH5, + InversionDirective, + SaveModelGeoH5, + SphericalUnitsWeights, + Update_IRLS, + UpdateIRLS, + UpdateSensitivityWeights, +) +from ..maps import SphericalSystem, Wires +from ..meta.simulation import MetaSimulation +from ..objective_function import ComboObjectiveFunction +from ..regularization import CrossGradient +from ..utils.mat_utils import cartesian2amplitude_dip_azimuth +from ..utils import set_kwargs, spherical2cartesian, cartesian2spherical + + +class ProjectSphericalBounds(InversionDirective): + r""" + Trick for spherical coordinate system. + Project \theta and \phi angles back to [-\pi,\pi] using + back and forth conversion. + spherical->cartesian->spherical + """ + + def __init__(self, mapping: Wires, **kwargs): + if not isinstance(mapping, Wires): + raise TypeError("mapping must be a Wires object") + + if len(mapping.maps) != 3: + raise ValueError("mapping must have 3 maps, one per vector component.") + + self.indices = mapping.deriv(None).indices + super().__init__(**kwargs) + + def initialize(self): + self.update() + + def endIter(self): + self.update() + + def update(self): + """ + Update the model and the simulation + """ + x = self.invProb.model + m = self._reproject(x) + phi_m_last = [] + for reg in self.reg.objfcts: + reg.model = self.invProb.model + phi_m_last += [reg(self.invProb.model)] + + self.invProb.phi_m_last = phi_m_last + self.invProb.model = m + self.opt.xc = self.invProb.model + + for misfit in self.dmisfit.objfcts: + misfit.simulation.model = m + + def _reproject(self, m): + """ + Round trip conversion to reproject the model. + """ + vec = m[self.indices] + xyz = spherical2cartesian(vec.reshape((-1, 3), order="F")) + vec = cartesian2spherical(xyz.reshape((-1, 3), order="F")) + + m[self.indices] = vec + return m + + +class VectorInversion(InversionDirective): + """ + Control a vector inversion from Cartesian to spherical coordinates. + """ + + chifact_target = 1.0 + reference_model = None + mode = "cartesian" + inversion_type = "mvis" + norms = [] + alphas = [] + cartesian_model = None + mappings = [] + regularization = [] + + def __init__( + self, simulations: list, regularizations: ComboObjectiveFunction, **kwargs + ): + self.reference_angles = (False, False, False) + self.simulations = simulations + self.regularizations = regularizations + + set_kwargs(self, **kwargs) + + @property + def target(self): + if getattr(self, "_target", None) is None: + nD = 0 + for survey in self.survey: + nD += survey.nD + + self._target = nD * self.chifact_target + + return self._target + + @target.setter + def target(self, val): + self._target = val + + def initialize(self): + for reg in self.reg.objfcts: + reg.model = self.invProb.model + + self.reference_model = reg.reference_model + + for dmisfit in self.dmisfit.objfcts: + if getattr(dmisfit.simulation, "coordinate_system", None) is not None: + dmisfit.simulation.coordinate_system = self.mode + + def endIter(self): + if ( + self.invProb.phi_d < self.target + ) and self.mode == "cartesian": # and self.inversion_type == 'mvis': + print("Switching MVI to spherical coordinates") + self.mode = "spherical" + self.cartesian_model = self.invProb.model + model = self.invProb.model.copy() + vec_model = [] + vec_ref = [] + indices = [] + mappings = [] + for reg in self.regularizations.objfcts: + mappings.append(reg.mapping) + vec_model.append(reg.mapping * model) + vec_ref.append(reg.mapping * reg.reference_model) + mapping = reg.mapping.deriv(np.zeros(reg.mapping.shape[1])) + indices.append(mapping.indices) + + indices = np.hstack(indices) + nC = mapping.shape[0] + vec_model = cartesian2spherical(np.vstack(vec_model).T) + vec_ref = cartesian2spherical(np.vstack(vec_ref).T).flatten() + model[indices] = vec_model.flatten() + + angle_map = [] + for ind, (reg_fun, ref_angle) in enumerate( + zip(self.regularizations.objfcts, self.reference_angles) + ): + reg_fun.model = model + reg_fun.reference_model[indices] = vec_ref + + if ind > 0: + if not ref_angle: + reg_fun.alpha_s = 0 + + reg_fun.eps_q = np.pi + reg_fun.units = "radian" + angle_map.append(reg_fun.mapping) + else: + reg_fun.units = "amplitude" + + # Change units of cross-gradient on angles + multipliers = [] + for mult, reg in self.reg: + if isinstance(reg, CrossGradient): + units = [] + for _, wire in reg.wire_map.maps: + if wire in angle_map: + units.append("radian") + mult = 0 # TODO Make this optional + else: + units.append("metric") + + reg.units = units + + multipliers.append(mult) + + self.reg.multipliers = multipliers + self.invProb.beta *= 2 + self.invProb.model = model + self.opt.xc = model + self.opt.lower[indices] = np.kron( + np.asarray([0, -np.inf, -np.inf]), np.ones(nC) + ) + self.opt.upper[indices[nC:]] = np.inf + + for simulation in self.simulations: + if isinstance(simulation, MetaSimulation): + for sim in simulation.simulations: + sim.chiMap = SphericalSystem() * sim.chiMap + else: + simulation.chiMap = SphericalSystem() * simulation.chiMap + + # Add and update directives + for directive in self.inversion.directiveList.dList: + if ( + isinstance(directive, SaveModelGeoH5) + and cartesian2amplitude_dip_azimuth in directive.transforms + ): + transforms = [] + + for fun in directive.transforms: + if fun is cartesian2amplitude_dip_azimuth: + transforms += [spherical2cartesian] + transforms += [fun] + + directive.transforms = transforms + + elif isinstance(directive, Update_IRLS | UpdateIRLS): + directive.sphericalDomain = True + directive.model = model + directive.coolingFactor = 1.5 + + elif isinstance(directive, UpdateSensitivityWeights): + directive.every_iteration = True + + spherical_units = SphericalUnitsWeights( + amplitude=self.regularizations.objfcts[0].mapping, + angles=self.regularizations.objfcts[1:], + ) + projections = [(comp, mapping) for comp, mapping in zip("xyz", mappings)] + directiveList = [ + ProjectSphericalBounds(Wires(*projections)), + spherical_units, + ] + self.inversion.directiveList.dList + self.inversion.directiveList = directiveList + + for directive in directiveList: + if not isinstance(directive, BaseSaveGeoH5): + directive.endIter() diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 4db080e652..be63899278 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -1,7 +1,5 @@ from __future__ import annotations # needed to use type operands in Python 3.8 -from abc import ABC, abstractmethod - from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING @@ -12,12 +10,12 @@ import warnings import os import scipy.sparse as sp -from ..meta.simulation import MetaSimulation + from ..typing import RandomSeed from ..data_misfit import BaseDataMisfit from ..objective_function import ComboObjectiveFunction -from ..maps import IdentityMap, SphericalSystem, Wires +from ..maps import IdentityMap, Wires from ..regularization import ( WeightedLeastSquares, @@ -30,21 +28,17 @@ SmoothnessFirstOrder, SparseSmoothness, BaseSimilarityMeasure, - CrossGradient, ) from ..utils import ( mkvc, set_kwargs, sdiag, estimate_diagonal, - spherical2cartesian, - cartesian2spherical, Zero, eigenvalue_by_power_iteration, validate_string, ) -from simpeg.utils.mat_utils import cartesian2amplitude_dip_azimuth from ..utils.code_utils import ( deprecate_class, @@ -54,10 +48,6 @@ validate_float, validate_ndarray_with_shape, ) -from geoh5py.groups.property_group import GroupTypeEnum -from geoh5py.groups import PropertyGroup, UIJsonGroup -from geoh5py.objects import ObjectBase -from geoh5py.ui_json.utils import fetch_active_workspace def compute_JtJdiags(data_misfit, m): @@ -3022,637 +3012,6 @@ def validate(self, directiveList): return True -class ProjectSphericalBounds(InversionDirective): - r""" - Trick for spherical coordinate system. - Project \theta and \phi angles back to [-\pi,\pi] using - back and forth conversion. - spherical->cartesian->spherical - """ - - def initialize(self): - x = self.invProb.model - # Convert to cartesian than back to avoid over rotation - nC = int(len(x) / 3) - xyz = spherical2cartesian(x.reshape((nC, 3), order="F")) - m = cartesian2spherical(xyz.reshape((nC, 3), order="F")) - self.invProb.model = m - self.opt.xc = m - - for misfit in self.dmisfit: - if getattr(misfit, "model_map", None) is not None: - misfit.simulation.model = misfit.model_map @ m - else: - misfit.simulation.model = m - - def endIter(self): - for misfit in self.dmisfit.objfcts: - if ( - hasattr(misfit.simulation, "model_type") - and misfit.simulation.model_type == "vector" - ): - mapping = misfit.model_map.deriv(np.zeros(misfit.model_map.shape[1])) - indices = ( - mapping.indices - ) # np.array(np.sum(mapping, axis=0)).flatten() > 0 - nC = int(len(indices) / 3) - vec = self.invProb.model[indices] - # Convert to cartesian than back to avoid over rotation - xyz = spherical2cartesian(vec.reshape((nC, 3), order="F")) - vec = cartesian2spherical(xyz.reshape((nC, 3), order="F")) - self.invProb.model[indices] = vec - - phi_m_last = [] - for reg in self.reg.objfcts: - reg.model = self.invProb.model - phi_m_last += [reg(self.invProb.model)] - - self.invProb.phi_m_last = phi_m_last - self.opt.xc = self.invProb.model - - for misfit in self.dmisfit.objfcts: - if getattr(misfit, "model_map", None) is not None: - misfit.simulation.model = misfit.model_map @ self.invProb.model - else: - misfit.simulation.model = self.invProb.model - - -class BaseSaveGeoH5(InversionDirective, ABC): - """ - Base class for saving inversion results to a geoh5 file - """ - - def __init__( - self, - h5_object, - dmisfit=None, - label: str | None = None, - channels: list[str] = ("",), - components: list[str] = ("",), - association: str | None = None, - **kwargs, - ): - self.label = label - self.channels = channels - self.components = components - self.h5_object = h5_object - - if association is not None: - self.association = association - - super().__init__( - inversion=None, dmisfit=dmisfit, reg=None, verbose=False, **kwargs - ) - - def initialize(self): - self.write(0) - - def endIter(self): - self.write(self.opt.iter) - - def get_names( - self, component: str, channel: str, iteration: int - ) -> tuple[str, str]: - """ - Format the data and property_group name. - """ - base_name = f"Iteration_{iteration}" - if len(component) > 0: - base_name += f"_{component}" - - channel_name = base_name - if channel: - channel_name += f"_{channel}" - - if self.label is not None: - channel_name += f"_{self.label}" - base_name += f"_{self.label}" - - return channel_name, base_name - - @abstractmethod - def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noqa - """ - Save the components of the inversion. - """ - - @property - def label(self): - return self._label - - @label.setter - def label(self, value: str | None): - if not isinstance(value, str | type(None)): - raise TypeError("'label' must be a string or None") - - self._label = value - - @property - def h5_object(self): - return self._h5_object - - @h5_object.setter - def h5_object(self, entity: ObjectBase): - if not isinstance(entity, ObjectBase | UIJsonGroup): - raise TypeError( - f"Input entity should be of type {ObjectBase}. {type(entity)} provided" - ) - - self._h5_object = entity.uid - self._geoh5 = entity.workspace - - if getattr(entity, "n_cells", None) is not None: - self.association = "CELL" - else: - self.association = "VERTEX" - - @property - def association(self): - return self._association - - @association.setter - def association(self, value): - if not value.upper() in ["CELL", "VERTEX"]: - raise ValueError( - f"'association must be one of 'CELL', 'VERTEX'. {value} provided" - ) - - self._association = value.upper() - - -class SaveArrayGeoH5(BaseSaveGeoH5, ABC): - """ - Saves array-based inversion results (model, data) to a geoh5 file. - - Parameters - ---------- - - transforms: List of transformations applied to the values before save. - sorting: Special re-indexing of the vector values. - reshape: Re-ordering applied to the data before slicing. - """ - - _attribute_type = None - - def __init__( - self, - h5_object, - transforms: list | tuple = (), - reshape=None, - sorting=None, - **kwargs, - ): - self.data_type = {} - self.transforms = transforms - self.sorting = sorting - self.reshape = reshape - - super().__init__(h5_object, **kwargs) - - @property - def reshape(self): - """ - Reshape function - """ - if getattr(self, "_reshape", None) is None: - self._reshape = lambda x: x.reshape( - (len(self.channels), len(self.components), -1), order="F" - ) - - return self._reshape - - @reshape.setter - def reshape(self, fun): - self._reshape = fun - - @property - def transforms(self): - return self._transforms - - @transforms.setter - def transforms(self, funcs: list | tuple): - if not isinstance(funcs, list | tuple): - funcs = [funcs] - - for fun in funcs: - if not any( - [isinstance(fun, (IdentityMap, np.ndarray, float)), callable(fun)] - ): - raise TypeError( - "Input transformation must be of type" - + "SimPEG.maps, numpy.ndarray or callable function" - ) - - self._transforms = funcs - - def stack_channels(self, dpred: list): - """ - Regroup channel values along rows. - """ - if isinstance(dpred, np.ndarray): - return self.reshape(dpred) - - return self.reshape(np.hstack(dpred)) - - def apply_transformations(self, prop: np.ndarray) -> np.ndarray: - """ - Re-order the values and apply transformations. - """ - prop = prop.flatten() - for fun in self.transforms: - if isinstance(fun, (IdentityMap, np.ndarray, float)): - prop = fun * prop - else: - prop = fun(prop) - - if prop.ndim == 2: - prop = prop.T.flatten() - - prop = prop.reshape((len(self.channels), len(self.components), -1)) - - return prop - - @abstractmethod - def get_values(self, values: list[np.ndarray] | None): - """ - Get values for the inversion depending on the output type. - """ - - def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noqa - """ - Sort, transform and store data per components and channels. - """ - prop = self.get_values(values) - - # Apply transformations - prop = self.apply_transformations(prop) - - # Save results - with fetch_active_workspace(self._geoh5, mode="r+") as w_s: - h5_object = w_s.get_entity(self.h5_object)[0] - for cc, component in enumerate(self.components): - if component not in self.data_type: - self.data_type[component] = {} - - for ii, channel in enumerate(self.channels): - values = prop[ii, cc, :] - - if self.sorting is not None: - values = values[self.sorting] - - channel_name, base_name = self.get_names( - component, channel, iteration - ) - - data = h5_object.add_data( - { - channel_name: { - "association": self.association, - "values": values, - } - } - ) - # Re-assign the data type - if channel not in self.data_type[component].keys(): - self.data_type[component][channel] = data.entity_type - type_name = f"{self._attribute_type}_{component}" - if channel: - type_name += f"_{channel}" - data.entity_type.name = type_name - else: - data.entity_type = w_s.find_type( - self.data_type[component][channel].uid, - type(self.data_type[component][channel]), - ) - - -class SaveModelGeoH5(SaveArrayGeoH5): - """ - Save the model at the current iteration to a geoh5 file. - """ - - _attribute_type = "model" - - def get_values(self, values: list[np.ndarray] | None): - if values is None: - values = self.invProb.model - - return values - - -class SaveSensitivityGeoH5(SaveArrayGeoH5): - """ - Save the model at the current iteration to a geoh5 file. - """ - - _attribute_type = "sensitivities" - - def __init__(self, h5_object, dmisfit=None, **kwargs): - if dmisfit is None: - raise ValueError( - "To save sensitivities, the data misfit object must be provided." - ) - super().__init__(h5_object, dmisfit=dmisfit, **kwargs) - - def get_values(self, values: list[np.ndarray] | None): - if values is None: - values = compute_JtJdiags(self.dmisfit, self.invProb.model) - - return values - - -class SaveDataGeoH5(SaveArrayGeoH5): - """ - Save the model at the current iteration to a geoh5 file. - """ - - _attribute_type = "predicted" - - def __init__(self, h5_object, joint_index: list[int] | None = None, **kwargs): - self.joint_index = joint_index - - super().__init__(h5_object, **kwargs) - - def get_values(self, values: list[np.ndarray] | None): - - if values is not None: - prop = self.stack_channels(values) - - else: - dpred = getattr(self.invProb, "dpred", None) - if dpred is None: - dpred = self.invProb.get_dpred(self.invProb.model) - self.invProb.dpred = dpred - - if self.joint_index is not None: - dpred = [dpred[ind] for ind in self.joint_index] - - prop = self.stack_channels(dpred) - - return prop - - @property - def joint_index(self): - """ - Index for joint inversions defining the element in the list of predicted data. - """ - return self._joint_index - - @joint_index.setter - def joint_index(self, value: list[int] | None): - if not isinstance(value, list | type(None)): - raise TypeError("Input 'joint_index' should be a list of int") - - self._joint_index = value - - -class SaveLogFilesGeoH5(BaseSaveGeoH5): - - def write(self, iteration: int, **_): - dirpath = Path(self._geoh5.h5file).parent - filepath = dirpath / "SimPEG.out" - - if iteration == 0: - with open(filepath, "w", encoding="utf-8") as f: - f.write("iteration beta phi_d phi_m time\n") - - with open(filepath, "a", encoding="utf-8") as f: - date_time = datetime.now().strftime("%b-%d-%Y:%H:%M:%S") - f.write( - f"{iteration} {self.invProb.beta:.3e} {self.invProb.phi_d:.3e} " - f"{self.invProb.phi_m:.3e} {date_time}\n" - ) - - self.save_log() - - def save_log(self): - """ - Save iteration metrics to comments. - """ - dirpath = Path(self._geoh5.h5file).parent - - with fetch_active_workspace(self._geoh5, mode="r+") as w_s: - h5_object = w_s.get_entity(self.h5_object)[0] - - for file in ["SimPEG.out", "SimPEG.log", "ChiFactors.log"]: - filepath = dirpath / file - - if not filepath.is_file(): - continue - - with open(filepath, "rb") as f: - raw_file = f.read() - - file_entity = h5_object.get_entity(file)[0] - if file_entity is None: - file_entity = h5_object.add_file(filepath) - - file_entity.file_bytes = raw_file - - -class SavePropertyGroup(BaseSaveGeoH5): - """ - Save the model as a property group in the geoh5 file - """ - - def __init__( - self, - h5_object, - group_type: GroupTypeEnum = GroupTypeEnum.MULTI, - **kwargs, - ): - self.group_type = group_type - - super().__init__(h5_object, **kwargs) - - def write(self, iteration: int, **_): - """ - Save the model to the geoh5 file - """ - with fetch_active_workspace(self._geoh5, mode="r+") as w_s: - h5_object = w_s.get_entity(self.h5_object)[0] - - for component in self.components: - properties = [] - for channel in self.channels: - - channel_name, base_name = self.get_names( - component, channel, iteration - ) - child = [ - child - for child in h5_object.children - if channel_name in child.name - ][0] - - if child is not None: - properties.append(child) - - if len(properties) == 0: - return - - PropertyGroup( - parent=h5_object, - name=base_name, - properties=properties, - property_group_type=self.group_type, - ) - - -class VectorInversion(InversionDirective): - """ - Control a vector inversion from Cartesian to spherical coordinates. - """ - - chifact_target = 1.0 - reference_model = None - mode = "cartesian" - inversion_type = "mvis" - norms = [] - alphas = [] - cartesian_model = None - mappings = [] - regularization = [] - - def __init__( - self, simulations: list, regularizations: ComboObjectiveFunction, **kwargs - ): - self.reference_angles = (False, False, False) - self.simulations = simulations - self.regularizations = regularizations - - set_kwargs(self, **kwargs) - - @property - def target(self): - if getattr(self, "_target", None) is None: - nD = 0 - for survey in self.survey: - nD += survey.nD - - self._target = nD * self.chifact_target - - return self._target - - @target.setter - def target(self, val): - self._target = val - - def initialize(self): - for reg in self.reg.objfcts: - reg.model = self.invProb.model - - self.reference_model = reg.reference_model - - for dmisfit in self.dmisfit.objfcts: - if getattr(dmisfit.simulation, "coordinate_system", None) is not None: - dmisfit.simulation.coordinate_system = self.mode - - def endIter(self): - if ( - self.invProb.phi_d < self.target - ) and self.mode == "cartesian": # and self.inversion_type == 'mvis': - print("Switching MVI to spherical coordinates") - self.mode = "spherical" - self.cartesian_model = self.invProb.model - model = self.invProb.model - vec_model = [] - vec_ref = [] - indices = [] - for reg in self.regularizations.objfcts: - vec_model.append(reg.mapping * model) - vec_ref.append(reg.mapping * reg.reference_model) - mapping = reg.mapping.deriv(np.zeros(reg.mapping.shape[1])) - indices.append(mapping.indices) - - indices = np.hstack(indices) - nC = mapping.shape[0] - vec_model = cartesian2spherical(np.vstack(vec_model).T) - vec_ref = cartesian2spherical(np.vstack(vec_ref).T).flatten() - model[indices] = vec_model.flatten() - - angle_map = [] - for ind, (reg_fun, ref_angle) in enumerate( - zip(self.regularizations.objfcts, self.reference_angles) - ): - reg_fun.model = model - reg_fun.reference_model[indices] = vec_ref - - if ind > 0: - if not ref_angle: - reg_fun.alpha_s = 0 - - reg_fun.eps_q = np.pi - reg_fun.units = "radian" - angle_map.append(reg_fun.mapping) - else: - reg_fun.units = "amplitude" - - # Change units of cross-gradient on angles - multipliers = [] - for mult, reg in self.reg: - if isinstance(reg, CrossGradient): - units = [] - for _, wire in reg.wire_map.maps: - if wire in angle_map: - units.append("radian") - mult = 0 # TODO Make this optional - else: - units.append("metric") - - reg.units = units - - multipliers.append(mult) - - self.reg.multipliers = multipliers - self.invProb.beta *= 2 - self.invProb.model = model - self.opt.xc = model - self.opt.lower[indices] = np.kron( - np.asarray([0, -np.inf, -np.inf]), np.ones(nC) - ) - self.opt.upper[indices[nC:]] = np.inf - - for simulation in self.simulations: - if isinstance(simulation, MetaSimulation): - for sim in simulation.simulations: - sim.chiMap = SphericalSystem() * sim.chiMap - else: - simulation.chiMap = SphericalSystem() * simulation.chiMap - - # Add and update directives - for directive in self.inversion.directiveList.dList: - if ( - isinstance(directive, SaveModelGeoH5) - and cartesian2amplitude_dip_azimuth in directive.transforms - ): - transforms = [] - - for fun in directive.transforms: - if fun is cartesian2amplitude_dip_azimuth: - transforms += [spherical2cartesian] - transforms += [fun] - - directive.transforms = transforms - - elif isinstance(directive, Update_IRLS): - directive.sphericalDomain = True - directive.model = model - directive.coolingFactor = 1.5 - - elif isinstance(directive, UpdateSensitivityWeights): - directive.every_iteration = True - - directiveList = [ - ProjectSphericalBounds() - ] + self.inversion.directiveList.dList - self.inversion.directiveList = directiveList - - for directive in directiveList: - if not isinstance(directive, BaseSaveGeoH5): - directive.endIter() - - class ScaleMisfitMultipliers(InversionDirective): """ Scale the misfits by the relative chi-factors of multiple misfit functions.