Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions simpeg/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import simpeg.dask.electromagnetics.static.induced_polarization.simulation
import simpeg.dask.electromagnetics.static.induced_polarization.simulation_2d
import simpeg.dask.electromagnetics.time_domain.simulation
import simpeg.dask.electromagnetics.time_domain.simulation_1d
import simpeg.dask.potential_fields.base
import simpeg.dask.potential_fields.gravity.simulation
import simpeg.dask.potential_fields.magnetics.simulation
Expand Down
24 changes: 24 additions & 0 deletions simpeg/dask/electromagnetics/time_domain/simulation_1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from ....electromagnetics.time_domain.simulation_1d import Simulation1DLayered as Sim

from ...simulation import getJtJdiag, Jvec, Jtvec

Sim._delete_on_model_update = ["_Jmatrix", "_jtjdiag", "_J"]


@property
def Jmatrix(self):
"""
Sensitivity matrix stored on disk
Return the diagonal of JtJ
"""
if getattr(self, "_Jmatrix", None) is None:
Jmat = self.getJ(self.model)
self._Jmatrix = Jmat["ds"]

return self._Jmatrix


Sim.getJtJdiag = getJtJdiag
Sim.Jvec = Jvec
Sim.Jtvec = Jtvec
Sim.Jmatrix = Jmatrix
13 changes: 10 additions & 3 deletions simpeg/directives/_save_geoh5.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path

import numpy as np

from scipy.sparse import csc_matrix, csr_matrix
from .directives import InversionDirective
from simpeg.maps import IdentityMap

Expand Down Expand Up @@ -188,7 +188,12 @@ def transforms(self, funcs: list | tuple):

for fun in funcs:
if not any(
[isinstance(fun, (IdentityMap, np.ndarray, float)), callable(fun)]
[
isinstance(
fun, (IdentityMap, np.ndarray, csr_matrix, csc_matrix, float)
),
callable(fun),
]
):
raise TypeError(
"Input transformation must be of type"
Expand All @@ -212,7 +217,9 @@ def apply_transformations(self, prop: np.ndarray) -> np.ndarray:
"""
prop = prop.flatten()
for fun in self.transforms:
if isinstance(fun, (IdentityMap, np.ndarray, float)):
if isinstance(
fun, (IdentityMap, np.ndarray, csr_matrix, csc_matrix, float)
):
prop = fun * prop
else:
prop = fun(prop)
Expand Down