diff --git a/simpeg/dask/__init__.py b/simpeg/dask/__init__.py index 89fee4fcd9..f8aae06c9c 100644 --- a/simpeg/dask/__init__.py +++ b/simpeg/dask/__init__.py @@ -6,6 +6,7 @@ import simpeg.dask.electromagnetics.static.induced_polarization.simulation import simpeg.dask.electromagnetics.static.induced_polarization.simulation_2d import simpeg.dask.electromagnetics.time_domain.simulation + import simpeg.dask.electromagnetics.time_domain.simulation_1d import simpeg.dask.potential_fields.base import simpeg.dask.potential_fields.gravity.simulation import simpeg.dask.potential_fields.magnetics.simulation diff --git a/simpeg/dask/electromagnetics/time_domain/simulation_1d.py b/simpeg/dask/electromagnetics/time_domain/simulation_1d.py new file mode 100644 index 0000000000..8eeb6058d2 --- /dev/null +++ b/simpeg/dask/electromagnetics/time_domain/simulation_1d.py @@ -0,0 +1,24 @@ +from ....electromagnetics.time_domain.simulation_1d import Simulation1DLayered as Sim + +from ...simulation import getJtJdiag, Jvec, Jtvec + +Sim._delete_on_model_update = ["_Jmatrix", "_jtjdiag", "_J"] + + +@property +def Jmatrix(self): + """ + Sensitivity matrix stored on disk + Return the diagonal of JtJ + """ + if getattr(self, "_Jmatrix", None) is None: + Jmat = self.getJ(self.model) + self._Jmatrix = Jmat["ds"] + + return self._Jmatrix + + +Sim.getJtJdiag = getJtJdiag +Sim.Jvec = Jvec +Sim.Jtvec = Jtvec +Sim.Jmatrix = Jmatrix diff --git a/simpeg/directives/_save_geoh5.py b/simpeg/directives/_save_geoh5.py index f4ea2b134b..89c4575442 100644 --- a/simpeg/directives/_save_geoh5.py +++ b/simpeg/directives/_save_geoh5.py @@ -4,7 +4,7 @@ from pathlib import Path import numpy as np - +from scipy.sparse import csc_matrix, csr_matrix from .directives import InversionDirective from simpeg.maps import IdentityMap @@ -188,7 +188,12 @@ def transforms(self, funcs: list | tuple): for fun in funcs: if not any( - [isinstance(fun, (IdentityMap, np.ndarray, float)), callable(fun)] + [ + isinstance( + fun, (IdentityMap, np.ndarray, csr_matrix, csc_matrix, float) + ), + callable(fun), + ] ): raise TypeError( "Input transformation must be of type" @@ -212,7 +217,9 @@ def apply_transformations(self, prop: np.ndarray) -> np.ndarray: """ prop = prop.flatten() for fun in self.transforms: - if isinstance(fun, (IdentityMap, np.ndarray, float)): + if isinstance( + fun, (IdentityMap, np.ndarray, csr_matrix, csc_matrix, float) + ): prop = fun * prop else: prop = fun(prop)