diff --git a/meta.yaml b/meta.yaml index 351a5dbd7b..e620ab6589 100644 --- a/meta.yaml +++ b/meta.yaml @@ -1,5 +1,5 @@ {% set name = "mira-simpeg" %} -{% set version = "0.21.2.1rc1" %} +{% set version = "0.21.2.1b3" %} package: name: {{ name|lower }} diff --git a/simpeg/dask/__init__.py b/simpeg/dask/__init__.py index f5a00b7334..89fee4fcd9 100644 --- a/simpeg/dask/__init__.py +++ b/simpeg/dask/__init__.py @@ -10,7 +10,6 @@ import simpeg.dask.potential_fields.gravity.simulation import simpeg.dask.potential_fields.magnetics.simulation import simpeg.dask.simulation - import simpeg.dask.data_misfit import simpeg.dask.inverse_problem import simpeg.dask.objective_function diff --git a/simpeg/dask/data_misfit.py b/simpeg/dask/data_misfit.py deleted file mode 100644 index f01d646248..0000000000 --- a/simpeg/dask/data_misfit.py +++ /dev/null @@ -1,96 +0,0 @@ -import numpy as np - -from ..data_misfit import L2DataMisfit -from ..fields import Fields -from ..utils import mkvc -from .utils import compute -import dask.array as da -from scipy.sparse import csr_matrix as csr -from dask import delayed - - -def dask_call(self, m, f=None): - """ - Distributed :obj:`simpeg.data_misfit.L2DataMisfit.__call__` - """ - R = self.W * self.residual(m, f=f) - phi_d = da.dot(R, R) - if not isinstance(phi_d, np.ndarray): - return compute(self, phi_d) - return phi_d - - -L2DataMisfit.__call__ = dask_call - - -def dask_deriv(self, m, f=None): - """ - Distributed :obj:`simpeg.data_misfit.L2DataMisfit.deriv` - """ - mapping_deriv = self.model_map.deriv(m) - if getattr(self, "model_map", None) is not None: - m = self.model_map @ m - - wtw_d = self.W.diagonal() ** 2.0 * self.residual(m, f=f) - Jtvec = compute(self, self.simulation.Jtvec(m, wtw_d)) - - if getattr(self, "model_map", None) is not None: - Jtjvec_dmudm = delayed(csr.dot)(Jtvec, mapping_deriv) - h_vec = da.from_delayed( - Jtjvec_dmudm, dtype=float, shape=[mapping_deriv.shape[1]] - ) - if not isinstance(h_vec, np.ndarray): - return compute(self, h_vec) - return h_vec - - if not isinstance(Jtvec, np.ndarray): - return compute(self, Jtvec) - return Jtvec - - -L2DataMisfit.deriv = dask_deriv - - -def dask_deriv2(self, m, v, f=None): - """ - Distributed :obj:`simpeg.data_misfit.L2DataMisfit.deriv2` - """ - mapping_deriv = self.model_map.deriv(m) - if getattr(self, "model_map", None) is not None: - m = self.model_map @ m - v = mapping_deriv @ v - - jvec = compute(self, self.simulation.Jvec(m, v)) - w_jvec = self.W.diagonal() ** 2.0 * jvec - jtwjvec = compute(self, self.simulation.Jtvec(m, w_jvec)) - - if getattr(self, "model_map", None) is not None: - Jtjvec_dmudm = delayed(csr.dot)(jtwjvec, mapping_deriv) - h_vec = da.from_delayed( - Jtjvec_dmudm, dtype=float, shape=[mapping_deriv.shape[1]] - ) - if not isinstance(h_vec, np.ndarray): - return compute(self, h_vec) - return h_vec - - if not isinstance(jtwjvec, np.ndarray): - return compute(self, jtwjvec) - return jtwjvec - - -L2DataMisfit.deriv2 = dask_deriv2 - - -def dask_residual(self, m, f=None): - if self.data is None: - raise Exception("data must be set before a residual can be calculated.") - - if isinstance(f, Fields) or f is None: - return self.simulation.residual(m, self.data.dobs, f=f) - elif f.shape == self.data.dobs.shape: - return mkvc(f - self.data.dobs) - else: - raise Exception(f"Attribute f must be or type {Fields}, numpy.array or None.") - - -L2DataMisfit.residual = dask_residual diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index eb8527ed8a..09a50e61a2 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -1,56 +1,144 @@ +import gc + from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim from ....utils import Zero +from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix import numpy as np import scipy.sparse as sp -from multiprocessing import cpu_count -from dask import array, compute, delayed -# from dask.distributed import get_client, Client, performance_report -from simpeg.dask.simulation import dask_Jvec, dask_Jtvec, dask_getJtJdiag +from dask import array, compute, delayed +from dask.distributed import get_client from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary import zarr -from tqdm import tqdm -Sim.sensitivity_path = "./sensitivity/" -Sim.gtgdiag = None -Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec -Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] +def receivers_eval(block, mesh, fields): + data = [] + for source, _, receiver in block: + data.append(receiver.eval(source, mesh, fields).flatten()) + return np.hstack(data) -@delayed -def source_evaluation(simulation, sources): + +def source_eval(simulation, sources, indices): s_m, s_e = [], [] - for source in sources: - sm, se = source.eval(simulation) + for ind in indices: + sm, se = sources[ind].eval(simulation) s_m.append(sm) s_e.append(se) return s_m, s_e -def dask_getSourceTerm(self, freq, source=None): +def receiver_derivs(survey, mesh, fields, blocks): + field_derivatives = [] + for address in blocks: + source = survey.source_list[address[0][0]] + receiver = source.receiver_list[address[0][1]] + + if isinstance(source, PlanewaveXYPrimary): + v = np.eye(receiver.nD, dtype=float) + else: + v = sp.csr_matrix(np.ones(receiver.nD), dtype=float) + + # Assume the derivatives in terms of model are Zero (seems to always be case) + dfduT, _ = receiver.evalDeriv( + source, mesh, fields, v=v[:, address[1][0]], adjoint=True + ) + field_derivatives.append(dfduT) + + return field_derivatives + + +def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address): + """ + Evaluate the sensitivities for the block or data + """ + + if Ainv_deriv_u.ndim == 1: + deriv_columns = Ainv_deriv_u[:, np.newaxis] + else: + deriv_columns = Ainv_deriv_u[:, deriv_indices] + + n_receivers = address[1][2] + source = simulation.survey.source_list[address[0][0]] + + if isinstance(source, PlanewaveXYPrimary): + source_fields = fields + n_cols = 2 + else: + source_fields = fields[:, address[0][0]] + n_cols = 1 + + n_cols *= n_receivers + + dA_dmT = simulation.getADeriv( + source.frequency, + source_fields, + deriv_columns, + adjoint=True, + ) + + dRHS_dmT = simulation.getRHSDeriv( + source.frequency, + source, + deriv_columns, + adjoint=True, + ) + + du_dmT = -dA_dmT + if not isinstance(dRHS_dmT, Zero): + du_dmT += dRHS_dmT + if not isinstance(deriv_m, Zero): + du_dmT += deriv_m + + return np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T + + +def getSourceTerm(self, freq, source=None): """ Assemble the source term. This ensures that the RHS is a vector / array of the correct size """ + if source is None: + + try: + client = get_client() + sim = client.scatter(self, workers=self.worker) + except ValueError: + client = None + sim = self + source_list = self.survey.get_sources_by_frequency(freq) - source_block = np.array_split(source_list, cpu_count()) + source_blocks = np.array_split( + np.arange(len(source_list)), self.n_threads(client=client) + ) + + if client: + source_list = client.scatter(source_list, workers=self.worker) block_compute = [] - for block in source_block: + + for block in source_blocks: if len(block) == 0: continue - block_compute.append(source_evaluation(self, block)) + if client: + block_compute.append( + client.submit( + source_eval, sim, source_list, block, workers=self.worker + ) + ) + else: + block_compute.append(source_eval(sim, source_list, block)) + + if client: + block_compute = client.gather(block_compute) - blocks = compute(block_compute)[0] s_m, s_e = [], [] - for block in blocks: + for block in block_compute: if block[0]: s_m += block[0] s_e += block[1] @@ -72,85 +160,17 @@ def dask_getSourceTerm(self, freq, source=None): s_e = np.vstack(s_e) if s_e.shape[0] < s_e.shape[1]: s_e = s_e.T - return s_m, s_e - - -Sim.getSourceTerm = dask_getSourceTerm - - -@delayed -def evaluate_receivers(block, mesh, fields): - data = [] - for source, _, receiver in block: - data.append(receiver.eval(source, mesh, fields).flatten()) - - return np.hstack(data) - - -def dask_dpred(self, m=None, f=None, compute_J=False): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). - - .. math:: - - d_\\text{pred} = P(f(m)) - - Where P is a projection of the fields onto the data space. - """ - if self.survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) - - if f is None: - if m is None: - m = self.model - f = self.fields(m, return_Ainv=compute_J) - all_receivers = [] - - for ind, src in enumerate(self.survey.source_list): - for rx in src.receiver_list: - all_receivers.append((src, ind, rx)) - - receiver_blocks = np.array_split(np.asarray(all_receivers), cpu_count()) - rows = [] - mesh = delayed(self.mesh) - for block in receiver_blocks: - n_data = np.sum([rec.nD for _, _, rec in block]) - if n_data == 0: - continue - - rows.append( - array.from_delayed( - evaluate_receivers(block, mesh, f), - dtype=np.float64, - shape=(n_data,), - ) - ) - - data = compute(array.hstack(rows))[0] - - if compute_J and self._Jmatrix is None: - Jmatrix = self.compute_J(f=f) - return data, Jmatrix - - return data - - -Sim.dpred = dask_dpred -Sim.field_derivs = None + return s_m, s_e def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m + if getattr(self, "_stashed_fields", None) is not None and not return_Ainv: + return self._stashed_fields + f = self.fieldsPair(self) Ainv = {} for freq in self.survey.frequencies: @@ -160,33 +180,29 @@ def fields(self, m=None, return_Ainv=False): u = Ainv_solve * rhs sources = self.survey.get_sources_by_frequency(freq) f[sources, self._solutionType] = u + Ainv[freq] = Ainv_solve - if return_Ainv: - Ainv[freq] = Ainv_solve - else: - Ainv_solve.clean() + self._stashed_fields = f if return_Ainv: - self.Ainv = Ainv - + return f, Ainv return f -Sim.fields = fields - +def compute_J(self, m, f=None): + self.model = m -def compute_J(self, f=None): if f is None: - f = self.fields(self.model, return_Ainv=True) + f, Ainv = self.fields(m=m, return_Ainv=True) - if len(self.Ainv) > 1: + if len(Ainv) > 1: raise NotImplementedError( "Current implementation of parallelization assumes a single frequency per simulation. " "Consider creating one misfit per frequency." ) - A_i = list(self.Ainv.values())[0] - m_size = self.model.size + A_i = list(Ainv.values())[0] + m_size = m.size if self.store_sensitivities == "disk": Jmatrix = zarr.open( @@ -202,56 +218,105 @@ def compute_J(self, f=None): blocks = get_parallel_blocks( self.survey.source_list, compute_row_size, optimize=False ) - fields_array = delayed(f[:, self._solutionType]) - fields = delayed(f) - survey = delayed(self.survey) - mesh = delayed(self.mesh) + fields_array = f[:, self._solutionType] blocks_receiver_derivs = [] - for block in blocks: - blocks_receiver_derivs.append( - receiver_derivs( - survey, - mesh, - fields, - block, + try: + client = get_client() + worker = self.worker + except ValueError: + client = None + worker = None + + if client: + fields_array = client.scatter(f[:, self._solutionType], workers=worker) + fields = client.scatter(f, workers=worker) + survey = client.scatter(self.survey, workers=worker) + mesh = client.scatter(self.mesh, workers=worker) + simulation = client.scatter(self, workers=worker) + for block in blocks: + blocks_receiver_derivs.append( + client.submit( + receiver_derivs, + survey, + mesh, + fields, + block, + workers=worker, + ) + ) + else: + fields_array = delayed(f[:, self._solutionType]) + fields = delayed(f) + survey = delayed(self.survey) + mesh = delayed(self.mesh) + simulation = delayed(self) + delayed_derivs = delayed(receiver_derivs) + for block in blocks: + blocks_receiver_derivs.append( + delayed_derivs( + survey, + mesh, + fields, + block, + ) ) - ) - - # with Client(processes=False) as client: - # with performance_report(filename="dask-report.html"): # Dask process for all derivatives - blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] + if client: + blocks_receiver_derivs = client.gather(blocks_receiver_derivs) + else: + blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] - for block_derivs_chunks, addresses_chunks in tqdm( - zip(blocks_receiver_derivs, blocks), - ncols=len(blocks_receiver_derivs), - desc=f"Sensitivities at {list(self.Ainv)[0]} Hz", + for block_derivs_chunks, addresses_chunks in zip( + blocks_receiver_derivs, blocks, strict=True ): Jmatrix = parallel_block_compute( - self, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks + simulation, + m, + Jmatrix, + block_derivs_chunks, + A_i, + fields_array, + addresses_chunks, + client, + worker, + store_sensitivities=self.store_sensitivities, ) - for A in self.Ainv.values(): + for A in Ainv.values(): A.clean() + del Ainv + gc.collect() if self.store_sensitivities == "disk": del Jmatrix - return array.from_zarr(self.sensitivity_path) - else: - return Jmatrix - + Jmatrix = array.from_zarr(self.sensitivity_path) -Sim.compute_J = compute_J + return Jmatrix def parallel_block_compute( - self, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses + simulation, + m, + Jmatrix, + blocks_receiver_derivs, + A_i, + fields_array, + addresses, + client, + worker=None, + store_sensitivities="disk", ): - m_size = self.model.size + m_size = m.size block_stack = sp.hstack(blocks_receiver_derivs).toarray() - ATinvdf_duT = delayed(A_i * block_stack) + + ATinvdf_duT = A_i * block_stack + + if client: + ATinvdf_duT = client.scatter(ATinvdf_duT, workers=worker) + else: + ATinvdf_duT = delayed(ATinvdf_duT) count = 0 rows = [] block_delayed = [] @@ -259,96 +324,64 @@ def parallel_block_compute( for address, dfduT in zip(addresses, blocks_receiver_derivs): n_cols = dfduT.shape[1] n_rows = address[1][2] - block_delayed.append( - array.from_delayed( - eval_block( - self, + + if client: + block_delayed.append( + client.submit( + eval_block, + simulation, ATinvdf_duT, np.arange(count, count + n_cols), Zero(), fields_array, address, - ), - dtype=np.float32, - shape=(n_rows, m_size), + workers=worker, + ) + ) + else: + delayed_eval = delayed(eval_block) + block_delayed.append( + array.from_delayed( + delayed_eval( + simulation, + ATinvdf_duT, + np.arange(count, count + n_cols), + Zero(), + fields_array, + address, + ), + dtype=np.float32, + shape=(n_rows, m_size), + ) ) - ) count += n_cols rows += address[1][1].tolist() indices = np.hstack(rows) - if self.store_sensitivities == "disk": + if client: + block_delayed = client.gather(block_delayed) + block = np.vstack(block_delayed) + else: + block = compute(array.vstack(block_delayed))[0] + + if store_sensitivities == "disk": Jmatrix.set_orthogonal_selection( (indices, slice(None)), - compute(array.vstack(block_delayed))[0], + block, ) else: # Dask process to compute row and store - Jmatrix[indices, :] = compute(array.vstack(block_delayed))[0] + Jmatrix[indices, :] = block return Jmatrix -@delayed -def receiver_derivs(survey, mesh, fields, blocks): - field_derivatives = [] - for address in blocks: - source = survey.source_list[address[0][0]] - receiver = source.receiver_list[address[0][1]] - - if isinstance(source, PlanewaveXYPrimary): - v = np.eye(receiver.nD, dtype=float) - else: - v = sp.csr_matrix(np.ones(receiver.nD), dtype=float) - - # Assume the derivatives in terms of model are Zero (seems to always be case) - dfduT, _ = receiver.evalDeriv( - source, mesh, fields, v=v[:, address[1][0]], adjoint=True - ) - field_derivatives.append(dfduT) - - return field_derivatives - - -@delayed -def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address): - """ - Evaluate the sensitivities for the block or data - """ - if Ainv_deriv_u.ndim == 1: - deriv_columns = Ainv_deriv_u[:, np.newaxis] - else: - deriv_columns = Ainv_deriv_u[:, deriv_indices] - - n_receivers = address[1][2] - source = simulation.survey.source_list[address[0][0]] - - if isinstance(source, PlanewaveXYPrimary): - source_fields = fields - n_cols = 2 - else: - source_fields = fields[:, address[0][0]] - n_cols = 1 - - n_cols *= n_receivers - - dA_dmT = simulation.getADeriv( - source.frequency, - source_fields, - deriv_columns, - adjoint=True, - ) - dRHS_dmT = simulation.getRHSDeriv( - source.frequency, - source, - deriv_columns, - adjoint=True, - ) - du_dmT = -dA_dmT - if not isinstance(dRHS_dmT, Zero): - du_dmT += dRHS_dmT - if not isinstance(deriv_m, Zero): - du_dmT += deriv_m - - return np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T +Sim.parallel_block_compute = parallel_block_compute +Sim.compute_J = compute_J +Sim.getJtJdiag = getJtJdiag +Sim.Jvec = Jvec +Sim.Jtvec = Jtvec +Sim.Jmatrix = Jmatrix +Sim.fields = fields +Sim.getSourceTerm = getSourceTerm diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py index 81f2db5a0b..28f3c12205 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py @@ -2,6 +2,12 @@ BaseIPSimulation as Sim, ) +from ..resistivity.simulation import ( + compute_J, + getSourceTerm, +) + + from .....data import Data import dask.array as da from dask.distributed import Future @@ -10,18 +16,8 @@ numcodecs.blosc.use_threads = False -Sim.sensitivity_path = "./sensitivity/" - -from ..resistivity.simulation import ( - compute_J, - dask_getSourceTerm, -) - -Sim.compute_J = compute_J -Sim.getSourceTerm = dask_getSourceTerm - -def dask_fields(self, m=None, return_Ainv=False): +def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m @@ -30,7 +26,7 @@ def dask_fields(self, m=None, return_Ainv=False): RHS = self.getRHS() f = self.fieldsPair(self) - f[:, self._solutionType] = Ainv * RHS + f[:, self._solutionType] = Ainv * np.asarray(RHS.todense()) if self._scale is None: scale = Data(self.survey, np.ones(self.survey.nD)) @@ -44,16 +40,13 @@ def dask_fields(self, m=None, return_Ainv=False): scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f) self._scale = scale.dobs + self._stashed_fields = f if return_Ainv: - self.Ainv = Ainv - + return f, Ainv return f -Sim.fields = dask_fields - - -def dask_dpred(self, m=None, f=None, compute_J=False): +def dpred(self, m=None, f=None): r""" dpred(m, f=None) Create the projected data from a model. @@ -72,51 +65,32 @@ def dask_dpred(self, m=None, f=None, compute_J=False): "data. Please set the survey for the simulation: " "simulation.survey = survey" ) - if self._Jmatrix is None or self._scale is None: - if m is None: - m = self.model - f = self.fields(m, return_Ainv=True) - self._Jmatrix = self.compute_J(f=f) data = self.Jvec(m, m) - if compute_J: - return np.asarray(data), self._Jmatrix - return np.asarray(data) -Sim.dpred = dask_dpred - - -def dask_getJtJdiag(self, m, W=None): +def getJtJdiag(self, m, W=None, f=None): """ Return the diagonal of JtJ """ self.model = m if getattr(self, "_jtjdiag", None) is None: - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish if W is None: W = self._scale * np.ones(self.nD) else: W = (self._scale * W.diagonal()) ** 2.0 - diag = da.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) - - if isinstance(diag, da.Array): - diag = np.asarray(diag.compute()) + diag = np.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) self._jtjdiag = diag return self._jtjdiag -Sim.getJtJdiag = dask_getJtJdiag - - -def dask_Jvec(self, m, v, f=None): +def Jvec(self, m, v, f=None): """ Compute sensitivity matrix (J) and vector (v) product. """ @@ -131,10 +105,7 @@ def dask_Jvec(self, m, v, f=None): return self._scale.astype(np.float32) * da.dot(self.Jmatrix, v).astype(np.float32) -Sim.Jvec = dask_Jvec - - -def dask_Jtvec(self, m, v, f=None): +def Jtvec(self, m, v, f=None): """ Compute adjoint sensitivity matrix (J^T) and vector (v) product. """ @@ -146,7 +117,13 @@ def dask_Jtvec(self, m, v, f=None): if isinstance(self.Jmatrix, Future): self.Jmatrix # Wait to finish - return da.dot(v * self._scale, self.Jmatrix).astype(np.float32) + return da.dot((v * self._scale).astype(np.float32), self.Jmatrix).astype(np.float32) -Sim.Jtvec = dask_Jtvec +Sim.compute_J = compute_J +Sim.getSourceTerm = getSourceTerm +Sim.Jtvec = Jtvec +Sim.Jvec = Jvec +Sim.getJtJdiag = getJtJdiag +Sim.dpred = dpred +Sim.fields = fields diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py index 963fc12451..9d9b4b2657 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py @@ -6,22 +6,11 @@ import numcodecs numcodecs.blosc.use_threads = False +from .simulation import getJtJdiag, Jvec, Jtvec, dpred +from ..resistivity.simulation_2d import compute_J, getSourceTerm -Sim.sensitivity_path = "./sensitivity/" -from .simulation import dask_getJtJdiag, dask_Jvec, dask_Jtvec, dask_dpred - -from ..resistivity.simulation_2d import compute_J, dask_getSourceTerm - -Sim.compute_J = compute_J -Sim.getSourceTerm = dask_getSourceTerm -Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec -Sim.dpred = dask_dpred - - -def dask_fields(self, m=None, return_Ainv=False): +def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m @@ -50,10 +39,16 @@ def dask_fields(self, m=None, return_Ainv=False): scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f_fwd) self._scale = scale.dobs + self._stashed_fields = f if return_Ainv: - self.Ainv = Ainv - + return f, Ainv return f -Sim.fields = dask_fields +Sim.getJtJdiag = getJtJdiag +Sim.Jvec = Jvec +Sim.Jtvec = Jtvec +Sim.dpred = dpred +Sim.fields = fields +Sim.compute_J = compute_J +Sim.getSourceTerm = getSourceTerm diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index 292cbfa206..3605218046 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -1,8 +1,12 @@ -from simpeg.dask.simulation import dask_dpred, dask_Jvec, dask_Jtvec, dask_getJtJdiag -from .....electromagnetics.static.resistivity.simulation import BaseDCSimulation as Sim +from .....electromagnetics.static.resistivity.simulation import Simulation3DNodal as Sim + +from ....simulation import getJtJdiag, Jvec, Jtvec, Jmatrix + from .....utils import Zero +from dask.distributed import get_client import dask.array as da import numpy as np +from scipy import sparse as sp import zarr import numcodecs @@ -13,41 +17,32 @@ numcodecs.blosc.use_threads = False -Sim.sensitivity_path = "./sensitivity/" - -Sim.dpred = dask_dpred -Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec -Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] - -def dask_fields(self, m=None, return_Ainv=False): +def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m + if getattr(self, "_stashed_fields", None) is not None and not return_Ainv: + return self._stashed_fields + A = self.getA() Ainv = self.solver(A, **self.solver_opts) RHS = self.getRHS() f = self.fieldsPair(self) - f[:, self._solutionType] = Ainv * RHS + f[:, self._solutionType] = Ainv * np.asarray(RHS.todense()) + self._stashed_fields = f if return_Ainv: - self.Ainv = Ainv - + return f, Ainv return f -Sim.fields = dask_fields - - -def compute_J(self, f=None): +def compute_J(self, m, f=None): - if f is None: - f = self.fields(self.model, return_Ainv=True) + f, Ainv = self.fields(m=m, return_Ainv=True) - m_size = self.model.size + m_size = m.size row_chunks = int( np.ceil( float(self.survey.nD) @@ -72,7 +67,7 @@ def compute_J(self, f=None): for rx in source.receiver_list: - if rx.orientation is not None: + if getattr(rx, "orientation", None) is not None: projected_grid = f._GLoc(rx.projField) + rx.orientation else: projected_grid = f._GLoc(rx.projField) @@ -87,7 +82,7 @@ def compute_J(self, f=None): df_duT, df_dmT = df_duTFun( source, None, PTv[:, start:end], adjoint=True ) - ATinvdf_duT = self.Ainv * df_duT + ATinvdf_duT = Ainv * df_duT dA_dmT = self.getADeriv(u_source, ATinvdf_duT, adjoint=True) dRHS_dmT = self.getRHSDeriv(source, ATinvdf_duT, adjoint=True) du_dmT = -dA_dmT @@ -131,19 +126,29 @@ def compute_J(self, f=None): else: Jmatrix[count : self.survey.nD, :] = blocks.astype(np.float32) - self.Ainv.clean() + Ainv.clean() if self.store_sensitivities == "disk": del Jmatrix - return da.from_zarr(self.sensitivity_path + "J.zarr") + self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") else: - return Jmatrix + self._Jmatrix = Jmatrix + return self._Jmatrix -Sim.compute_J = compute_J +def source_eval(simulation, sources, indices): + """ + Evaluate the source term for the given source and index + """ + blocks = [] + for ind in indices: + blocks.append(sources[ind].eval(simulation)) + + return sp.csr_matrix(np.vstack(blocks).T) -def dask_getSourceTerm(self): + +def getSourceTerm(self): """ Evaluates the sources, and puts them in matrix form :rtype: tuple @@ -153,25 +158,39 @@ def dask_getSourceTerm(self): if getattr(self, "_q", None) is None: if self._mini_survey is not None: - Srcs = self._mini_survey.source_list + source_list = self._mini_survey.source_list else: - Srcs = self.survey.source_list - - if self._formulation == "EB": - n = self.mesh.nN - # return NotImplementedError - - elif self._formulation == "HJ": - n = self.mesh.nC - - q = np.zeros((n, len(Srcs)), order="F") + source_list = self.survey.source_list + + indices = np.arange(len(source_list)) + try: + + client = get_client() + sim = client.scatter(self, workers=self.worker) + future_list = client.scatter(source_list, workers=self.worker) + indices = np.array_split(indices, self.n_threads(client=client)) + blocks = [] + for ind in indices: + blocks.append( + client.submit( + source_eval, sim, future_list, ind, workers=self.worker + ) + ) - for i, source in enumerate(Srcs): - q[:, i] = source.eval(self) + blocks = sp.hstack(client.gather(blocks)) + except ValueError: + blocks = source_eval(self, source_list, indices) - self._q = q + self._q = blocks return self._q -Sim.getSourceTerm = dask_getSourceTerm +Sim.getSourceTerm = getSourceTerm +Sim.fields = fields +Sim.compute_J = compute_J + +Sim.getJtJdiag = getJtJdiag +Sim.Jvec = Jvec +Sim.Jtvec = Jtvec +Sim.Jmatrix = Jmatrix diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py index 08b5ba08de..bb430ec2c1 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py @@ -1,7 +1,8 @@ from .....electromagnetics.static.resistivity.simulation_2d import ( - BaseDCSimulation2D as Sim, + Simulation2DNodal as Sim, ) -from .simulation import dask_getJtJdiag, dask_Jvec, dask_Jtvec +from ....simulation import getJtJdiag, Jvec, Jtvec, Jmatrix + import dask.array as da import numpy as np import zarr @@ -9,18 +10,14 @@ numcodecs.blosc.use_threads = False -Sim.sensitivity_path = "./sensitivity/" - -Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec -Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] - -def dask_fields(self, m=None, return_Ainv=False): +def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m + if getattr(self, "_stashed_fields", None) is not None and not return_Ainv: + return self._stashed_fields + kys = self._quad_points f = self.fieldsPair(self) f._quad_weights = self._quad_weights @@ -33,23 +30,19 @@ def dask_fields(self, m=None, return_Ainv=False): RHS = self.getRHS(ky) f[:, self._solutionType, iky] = Ainv[iky] * RHS + self._stashed_fields = f if return_Ainv: - self.Ainv = Ainv - + return f, Ainv return f -Sim.fields = dask_fields - - -def compute_J(self, f=None): +def compute_J(self, m, f=None): kys = self._quad_points weights = self._quad_weights - if f is None: - f = self.fields(self.model, return_Ainv=True) + f, Ainv = self.fields(m, return_Ainv=True) - m_size = self.model.size + m_size = m.size row_chunks = int( np.ceil( float(self.survey.nD) @@ -79,7 +72,7 @@ def compute_J(self, f=None): for i_src, source in enumerate(self.survey.source_list): for rx in source.receiver_list: - if rx.orientation is not None: + if getattr(rx, "orientation", None) is not None: projected_grid = f._GLoc(rx.projField) + rx.orientation else: projected_grid = f._GLoc(rx.projField) @@ -95,7 +88,7 @@ def compute_J(self, f=None): u_ky = f[:, self._solutionType, iky] u_source = u_ky[:, i_src] - ATinvdf_duT = self.Ainv[iky] * PTv[:, start:end] + ATinvdf_duT = Ainv[iky] * PTv[:, start:end] dA_dmT = self.getADeriv(ky, u_source, ATinvdf_duT, adjoint=True) du_dmT = -weights[iky] * dA_dmT block += du_dmT.T.reshape((-1, m_size)) @@ -131,19 +124,18 @@ def compute_J(self, f=None): Jmatrix[count : self.survey.nD, :] = blocks.astype(np.float32) for iky, _ in enumerate(kys): - self.Ainv[iky].clean() + Ainv[iky].clean() if self.store_sensitivities == "disk": del Jmatrix - return da.from_zarr(self.sensitivity_path + "J.zarr") + self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") else: - return Jmatrix + self._Jmatrix = Jmatrix - -Sim.compute_J = compute_J + return self._Jmatrix -def dask_dpred(self, m=None, f=None, compute_J=False): +def dpred(self, m=None, f=None): r""" dpred(m, f=None) Create the projected data from a model. @@ -172,7 +164,7 @@ def dask_dpred(self, m=None, f=None, compute_J=False): if f is None: if m is None: m = self.model - f = self.fields(m, return_Ainv=compute_J) + f = self.fields(m) temp = np.empty(survey.nD) count = 0 @@ -182,17 +174,10 @@ def dask_dpred(self, m=None, f=None, compute_J=False): temp[count : count + len(d)] = d count += len(d) - if compute_J: - Jmatrix = self.compute_J(f=f) - return self._mini_survey_data(temp), Jmatrix - return self._mini_survey_data(temp) -Sim.dpred = dask_dpred - - -def dask_getSourceTerm(self, _): +def getSourceTerm(self, _): """ Evaluates the sources, and puts them in matrix form :rtype: tuple @@ -223,4 +208,12 @@ def dask_getSourceTerm(self, _): return self._q -Sim.getSourceTerm = dask_getSourceTerm +Sim.fields = fields +Sim.compute_J = compute_J +Sim.dpred = dpred +Sim.getSourceTerm = getSourceTerm + +Sim.getJtJdiag = getJtJdiag +Sim.Jvec = Jvec +Sim.Jtvec = Jtvec +Sim.Jmatrix = Jmatrix diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 569d9f6ef1..fe68e234f1 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -2,100 +2,26 @@ import dask.array import os from ....electromagnetics.time_domain.simulation import BaseTDEMSimulation as Sim + from ....utils import Zero -from simpeg.fields import TimeFields -from multiprocessing import cpu_count +from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix + import numpy as np import scipy.sparse as sp from dask import array, delayed +from dask.distributed import get_client -from simpeg.dask.simulation import dask_Jvec, dask_Jtvec, dask_getJtJdiag from simpeg.dask.utils import get_parallel_blocks from simpeg.utils import mkvc -from time import time -from tqdm import tqdm - -Sim.sensitivity_path = "./sensitivity/" -Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec -Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] - - -@delayed -def field_projection(field_array, src_list, array_ind, time_ind, func): - fieldI = field_array[:, :, array_ind] - if fieldI.shape[0] == fieldI.size: - fieldI = mkvc(fieldI, 2) - new_array = func(fieldI, src_list, time_ind) - if new_array.ndim == 1: - new_array = new_array[:, np.newaxis, np.newaxis] - elif new_array.ndim == 2: - new_array = new_array[:, :, np.newaxis] - - return new_array - - -def _getField(self, name, ind, src_list): - srcInd, timeInd = ind - - if name in self._fields: - out = self._fields[name][:, srcInd, timeInd] - else: - # Aliased fields - alias, loc, func = self.aliasFields[name] - if isinstance(func, str): - assert hasattr(self, func), ( - "The alias field function is a string, but it does " - "not exist in the Fields class." - ) - func = getattr(self, func) - pointerFields = self._fields[alias][:, srcInd, timeInd] - pointerShape = self._correctShape(alias, ind) - pointerFields = pointerFields.reshape(pointerShape, order="F") - - # First try to return the function as three arguments (without timeInd) - if timeInd == slice(None, None, None): - try: - # assume it will take care of integrating over all times - return func(pointerFields, srcInd) - except TypeError: - pass - - timeII = np.arange(self.simulation.nT + 1)[timeInd] - if not isinstance(src_list, list): - src_list = [src_list] - - if timeII.size == 1: - pointerShapeDeflated = self._correctShape(alias, ind, deflate=True) - pointerFields = pointerFields.reshape(pointerShapeDeflated, order="F") - out = func(pointerFields, src_list, timeII) - else: # loop over the time steps - arrays = [] - - for i, TIND_i in enumerate(timeII): # Need to parallelize this - arrays.append( - array.from_delayed( - field_projection(pointerFields, src_list, i, TIND_i, func), - dtype=np.float32, - shape=(pointerShape[0], pointerShape[1], 1), - ) - ) - - out = array.dstack(arrays).compute() - - shape = self._correctShape(name, ind, deflate=True) - return out.reshape(shape, order="F") - - -TimeFields._getField = _getField - def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m + if getattr(self, "_stashed_fields", None) is not None and not return_Ainv: + return self._stashed_fields + f = self.fieldsPair(self) f[:, self._fieldType + "Solution", 0] = self.getInitialFields() Ainv = {} @@ -107,49 +33,75 @@ def fields(self, m=None, return_Ainv=False): Asubdiag = self.getAsubdiag(tInd) rhs = -Asubdiag * f[:, (self._fieldType + "Solution"), tInd] - if ( np.abs(self.survey.source_list[0].waveform.eval(self.times[tInd + 1])) > 1e-8 ): rhs += self.getRHS(tInd + 1) - sol = Ainv[dt] * rhs + sol = Ainv[dt] * np.asarray(rhs) f[:, self._fieldType + "Solution", tInd + 1] = sol + self._stashed_fields = f if return_Ainv: - self.Ainv = Ainv - + return f, Ainv return f -Sim.fields = fields - - -@delayed -def source_evaluation(simulation, sources, time_channel): - s_m, s_e = [], [] - for source in sources: - sm, se = source.eval(simulation, time_channel) - s_m.append(sm) - s_e.append(se) - - return s_m, s_e - - -def dask_getSourceTerm(self, tInd): +def getSourceTerm(self, tInd): """ Assemble the source term. This ensures that the RHS is a vector / array of the correct size """ + if ( + getattr(self, "_stashed_sources", None) is not None + and tInd in self._stashed_sources + ): + return self._stashed_sources[tInd] + elif getattr(self, "_stashed_sources", None) is None: + self._stashed_sources = {} + + try: + client = get_client() + sim = client.scatter(self, workers=self.worker) + except ValueError: + client = None + sim = self + source_list = self.survey.source_list - source_block = np.array_split(source_list, cpu_count()) + source_block = np.array_split( + np.arange(len(source_list)), self.n_threads(client=client) + ) + + if client: + sim = client.scatter(self, workers=self.worker) + source_list = client.scatter(source_list, workers=self.worker) + else: + delayed_source_eval = delayed(source_evaluation) + sim = self block_compute = [] for block in source_block: - block_compute.append(source_evaluation(self, block, self.times[tInd])) + if client: + block_compute.append( + client.submit( + source_evaluation, + sim, + block, + self.times[tInd], + source_list, + workers=self.worker, + ) + ) + else: + block_compute.append( + delayed_source_eval(self, block, self.times[tInd], source_list) + ) - blocks = dask.compute(block_compute)[0] + if client: + blocks = client.gather(block_compute) + else: + blocks = dask.compute(block_compute)[0] s_m, s_e = [], [] for block in blocks: @@ -158,166 +110,251 @@ def dask_getSourceTerm(self, tInd): s_e.append(block[1]) if isinstance(s_m[0][0], Zero): - return Zero(), np.vstack(s_e).T - - return np.vstack(s_m).T, np.vstack(s_e).T - - -Sim.getSourceTerm = dask_getSourceTerm - - -@delayed -def evaluate_receivers(block, mesh, time_mesh, fields, fields_array): - data = [] - for _, ind, receiver in block: - Ps = receiver.getSpatialP(mesh, fields) - Pt = receiver.getTimeP(time_mesh, fields) - vector = (Pt * (Ps * fields_array[:, ind, :]).T).flatten() - - data.append(vector) - - return np.hstack(data) + self._stashed_sources[tInd] = Zero(), sp.csr_matrix(np.vstack(s_e).T) + else: + self._stashed_sources[tInd] = sp.csr_matrix(np.vstack(s_m).T), sp.csr_matrix( + np.vstack(s_e).T + ) + return self._stashed_sources[tInd] -def dask_dpred(self, m=None, f=None, compute_J=False): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). - .. math:: +def compute_J(self, m, f=None): + """ + Compute the rows for the sensitivity matrix. + """ + if f is None: + f, Ainv = self.fields(m=m, return_Ainv=True) - d_\\text{pred} = P(f(m)) + try: + client = get_client() + except ValueError: + client = None - Where P is a projection of the fields onto the data space. - """ - if self.survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" + ftype = self._fieldType + "Solution" + sens_name = self.sensitivity_path[:-5] + if self.store_sensitivities == "disk": + rows = array.zeros( + (self.survey.nD, m.size), + chunks=(self.max_chunk_size, m.size), + dtype=np.float32, ) + Jmatrix = array.to_zarr( + rows, + os.path.join(sens_name + "_1.zarr"), + compute=True, + return_stored=True, + overwrite=True, + ) + else: + Jmatrix = np.zeros((self.survey.nD, m.size), dtype=np.float64) - if f is None: - if m is None: - m = self.model - f = self.fields(m, return_Ainv=compute_J) - - rows = [] - receiver_projection = self.survey.source_list[0].receiver_list[0].projField - fields_array = f[:, receiver_projection, :] + simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 + data_times = self.survey.source_list[0].receiver_list[0].times + compute_row_size = np.ceil(self.max_chunk_size / (m.shape[0] * 8.0 * 1e-6)) + blocks = get_parallel_blocks( + self.survey.source_list, + compute_row_size, + thread_count=self.n_threads(client=client), + ) + fields_array = f[:, ftype, :] if len(self.survey.source_list) == 1: fields_array = fields_array[:, np.newaxis, :] - all_receivers = [] + times_field_derivs, Jmatrix = compute_field_derivs( + self, f, blocks, Jmatrix, fields_array.shape, client + ) - for ind, src in enumerate(self.survey.source_list): - for rx in src.receiver_list: - all_receivers.append((src, ind, rx)) + ATinv_df_duT_v = [[] for _ in blocks] - receiver_blocks = np.array_split(all_receivers, cpu_count()) + if client: + fields_array = client.scatter(fields_array, workers=self.worker) + sim = client.scatter(self, workers=self.worker) + else: + delayed_compute_rows = delayed(compute_rows) + sim = self + for tInd, dt in zip(reversed(range(self.nT)), reversed(self.time_steps)): + AdiagTinv = Ainv[dt] + j_row_updates = [] + time_mask = data_times > simulation_times[tInd] - for block in receiver_blocks: - n_data = np.sum([rec.nD for _, _, rec in block]) - if n_data == 0: + if not np.any(time_mask): continue - rows.append( - array.from_delayed( - evaluate_receivers(block, self.mesh, self.time_mesh, f, fields_array), - dtype=np.float64, - shape=(n_data,), + for ind, (block, field_deriv) in enumerate( + zip(blocks, times_field_derivs[tInd + 1], strict=True) + ): + ATinv_df_duT_v[ind] = get_field_deriv_block( + self, + block, + field_deriv, + tInd, + AdiagTinv, + ATinv_df_duT_v[ind], + time_mask, + client, ) - ) - data = array.hstack(rows).compute() + if len(block) == 0: + continue - if compute_J and self._Jmatrix is None: - Jmatrix = self.compute_J(f=f) - return data, Jmatrix + if client: + field_derivatives = client.scatter( + ATinv_df_duT_v[ind], workers=self.worker + ) + j_row_updates.append( + client.submit( + compute_rows, + sim, + tInd, + block, + field_derivatives, + fields_array, + time_mask, + workers=self.worker, + ) + ) + else: + j_row_updates.append( + array.from_delayed( + delayed_compute_rows( + sim, + tInd, + block, + ATinv_df_duT_v[ind], + fields_array, + time_mask, + ), + dtype=np.float32, + shape=( + np.sum([len(chunk[1][0]) for chunk in block]), + m.size, + ), + ) + ) - return data + if client: + j_row_updates = np.vstack(client.gather(j_row_updates)) + else: + j_row_updates = array.vstack(j_row_updates).compute() + if self.store_sensitivities == "disk": + sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" + array.to_zarr( + Jmatrix + j_row_updates, + sens_name, + compute=True, + overwrite=True, + ) + Jmatrix = array.from_zarr(sens_name) + else: + Jmatrix += j_row_updates -Sim.dpred = dask_dpred -Sim.field_derivs = None + for A in Ainv.values(): + A.clean() + if self.store_sensitivities == "ram": + self._Jmatrix = np.asarray(Jmatrix) -@delayed -def delayed_block_deriv( - n_times, chunks, field_len, source_list, mesh, time_mesh, fields, shape -): - """Compute derivatives for sources and receivers in a block""" - df_duT = [] - j_updates = [] + self._Jmatrix = Jmatrix - for indices, arrays in chunks: - j_update = 0.0 - source = source_list[indices[0]] - receiver = source.receiver_list[indices[1]] + return self._Jmatrix - spatialP = receiver.getSpatialP(mesh, fields) - timeP = receiver.getTimeP(time_mesh, fields) - derivative_fun = getattr(fields, "_{}Deriv".format(receiver.projField), None) - time_derivs = [] - for time_index in range(n_times + 1): - if len(timeP[:, time_index].data) == 0: - time_derivs.append( - sp.csr_matrix((field_len, len(arrays[0])), dtype=np.float32) - ) - j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) - continue +def field_projection(field_array, src_list, array_ind, time_ind, func): + fieldI = field_array[:, :, array_ind] + if fieldI.shape[0] == fieldI.size: + fieldI = mkvc(fieldI, 2) + new_array = func(fieldI, src_list, time_ind) + if new_array.ndim == 1: + new_array = new_array[:, np.newaxis, np.newaxis] + elif new_array.ndim == 2: + new_array = new_array[:, :, np.newaxis] - projection = sp.kron(timeP[:, time_index], spatialP, format="csr") - cur = derivative_fun( - time_index, - source, - None, - projection.T, - adjoint=True, - ) + return new_array - time_derivs.append(cur[0][:, arrays[0]]) - if not isinstance(cur[1], Zero): - j_update += cur[1].T - else: - j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) +def source_evaluation(simulation, indices, time_channel, sources): + s_m, s_e = [], [] + for ind in indices: + sm, se = sources[ind].eval(simulation, time_channel) + s_m.append(sm) + s_e.append(se) - j_updates.append(j_update) - df_duT.append(time_derivs) + return s_m, s_e - return df_duT, j_updates + +def evaluate_receivers(block, mesh, time_mesh, fields, fields_array): + data = [] + for _, ind, receiver in block: + Ps = receiver.getSpatialP(mesh, fields) + Pt = receiver.getTimeP(time_mesh, fields) + vector = (Pt * (Ps * fields_array[:, ind, :]).T).flatten() + + data.append(vector) + + return np.hstack(data) -def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): +def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape, client): """ Compute the derivative of the fields """ delayed_chunks = [] + + if client: + mesh = client.scatter(self.mesh, workers=self.worker) + time_mesh = client.scatter(self.time_mesh, workers=self.worker) + fields = client.scatter(fields, workers=self.worker) + source_list = client.scatter(self.survey.source_list, workers=self.worker) + else: + mesh = self.mesh + time_mesh = self.time_mesh + delayed_block_deriv = delayed(block_deriv) + source_list = self.survey.source_list + for chunks in blocks: if len(chunks) == 0: continue - delayed_block = delayed_block_deriv( - simulation.nT, - chunks, - fields_shape[0], - simulation.survey.source_list, - simulation.mesh, - simulation.time_mesh, - fields, - simulation.model.size, - ) - delayed_chunks.append(delayed_block) + if client: + delayed_chunks.append( + client.submit( + block_deriv, + self.nT, + chunks, + fields_shape[0], + source_list, + mesh, + time_mesh, + fields, + self.model.size, + workers=self.worker, + ) + ) + else: + delayed_chunks.append( + delayed_block_deriv( + self.nT, + chunks, + fields_shape[0], + source_list, + self.mesh, + self.time_mesh, + fields, + self.model.size, + ) + ) + + if client: + result = client.gather(delayed_chunks) + else: + result = dask.compute(delayed_chunks)[0] - result = dask.compute(delayed_chunks)[0] df_duT = [ [[[] for _ in block] for block in blocks if len(block) > 0] - for _ in range(simulation.nT + 1) + for _ in range(self.nT + 1) ] j_updates = [] @@ -331,67 +368,39 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): if len(j_updates.data) > 0: Jmatrix += j_updates - if simulation.store_sensitivities == "disk": - sens_name = simulation.sensitivity_path[:-5] + f"_{time() % 2}.zarr" + if self.store_sensitivities == "disk": + sens_name = self.sensitivity_path[:-5] + f"_{time() % 2}.zarr" array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) Jmatrix = array.from_zarr(sens_name) return df_duT, Jmatrix -@delayed -def deriv_block( - s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, field_derivs, tInd -): - if (s_id, r_id, b_id) not in ATinv_df_duT_v: - # last timestep (first to be solved) - stacked_block = field_derivs.toarray()[:, local_ind] - - else: - stacked_block = np.asarray( - field_derivs[:, local_ind] - - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] - ) - - return stacked_block - - -def update_deriv_blocks(address, indices, derivatives, solve, shape): - if address not in derivatives: - deriv_array = np.zeros(shape) - else: - deriv_array = derivatives[address] - - if address in indices: - columns, local_ind = indices[address] - if solve is not None: - deriv_array[:, local_ind] = solve[:, columns] - - derivatives[address] = deriv_array - - def get_field_deriv_block( - simulation, + self, block: list, field_derivs: list, tInd: int, AdiagTinv, - ATinv_df_duT_v: dict, + ATinv_df_duT_v, time_mask, + client, ): """ Stack the blocks of field derivatives for a given timestep and call the direct solver. """ stacked_blocks = [] - indices = {} + if len(ATinv_df_duT_v) == 0: + ATinv_df_duT_v = [[] for _ in block] + indices = [] count = 0 Asubdiag = None - if tInd < simulation.nT - 1: - Asubdiag = simulation.getAsubdiag(tInd + 1) + if tInd < self.nT - 1: + Asubdiag = self.getAsubdiag(tInd + 1) - for ((s_id, r_id, b_id), (rx_ind, _, shape)), field_deriv in zip( - block, field_derivs + for (_, (rx_ind, _, shape)), field_deriv, ATinv_chunk in zip( + block, field_derivs, ATinv_df_duT_v ): # Cut out early data time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind] @@ -400,49 +409,110 @@ def get_field_deriv_block( if len(local_ind) < 1: continue - indices[(s_id, r_id, b_id)] = ( - np.arange(count, count + len(local_ind)), - local_ind, + indices.append( + (np.arange(count, count + len(local_ind)), local_ind), ) count += len(local_ind) - deriv_comp = deriv_block( - s_id, - r_id, - b_id, - ATinv_df_duT_v, - Asubdiag, - local_ind, - field_deriv, - tInd, - ) - stacked_blocks.append( - array.from_delayed( - deriv_comp, - dtype=float, - shape=( - field_deriv.shape[0], - len(local_ind), - ), + + if len(ATinv_chunk) == 0: + # last timestep (first to be solved) + stacked_block = field_deriv.toarray()[:, local_ind] + + else: + stacked_block = np.asarray( + field_deriv[:, local_ind] - Asubdiag.T * ATinv_chunk[:, local_ind] ) - ) + + stacked_blocks.append(stacked_block) + if len(stacked_blocks) > 0: - blocks = array.hstack(stacked_blocks).compute() + blocks = np.hstack(stacked_blocks) + solve = (AdiagTinv * blocks).reshape(blocks.shape) else: solve = None - for (address, arrays), field_deriv in zip(block, field_derivs): - shape = ( - field_deriv.shape[0], - len(arrays[0]), - ) + updated_ATinv_df_duT_v = [] + for (_, arrays), field_deriv, ATinv_chunk, (columns, local_ind) in zip( + block, field_derivs, ATinv_df_duT_v, indices, strict=True + ): + + if len(ATinv_chunk) == 0: + shape = ( + field_deriv.shape[0], + len(arrays[0]), + ) + ATinv_chunk = np.zeros(shape, dtype=np.float32) + + if solve is None: + continue + + ATinv_chunk[:, local_ind] = solve[:, columns] + updated_ATinv_df_duT_v.append(ATinv_chunk) + + return updated_ATinv_df_duT_v + + +def block_deriv( + n_times, chunks, field_len, source_list, mesh, time_mesh, fields, shape +): + """Compute derivatives for sources and receivers in a block""" + df_duT = [] + j_updates = [] + + for indices, arrays in chunks: + j_update = 0.0 + source = source_list[indices[0]] + receiver = source.receiver_list[indices[1]] + + spatialP = receiver.getSpatialP(mesh, fields) + timeP = receiver.getTimeP(time_mesh, fields) + + derivative_fun = getattr(fields, "_{}Deriv".format(receiver.projField), None) + time_derivs = [] + for time_index in range(n_times + 1): + if len(timeP[:, time_index].data) == 0: + time_derivs.append( + sp.csr_matrix((field_len, len(arrays[0])), dtype=np.float32) + ) + j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) + continue + + projection = sp.kron(timeP[:, time_index], spatialP, format="csr") + cur = derivative_fun( + time_index, + source, + None, + projection.T, + adjoint=True, + ) + + time_derivs.append(cur[0][:, arrays[0]]) - update_deriv_blocks(address, indices, ATinv_df_duT_v, solve, shape) + if not isinstance(cur[1], Zero): + j_update += cur[1].T + else: + j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) - return ATinv_df_duT_v + j_updates.append(j_update) + df_duT.append(time_derivs) + + return df_duT, j_updates + + +def deriv_block(ATinv_df_duT_v, Asubdiag, local_ind, field_derivs): + if len(ATinv_df_duT_v) == 0: + # last timestep (first to be solved) + stacked_block = field_derivs.toarray()[:, local_ind] + + else: + stacked_block = np.asarray( + field_derivs[:, local_ind] - Asubdiag.T * ATinv_df_duT_v[:, local_ind] + ) + + return stacked_block -@delayed def compute_rows( simulation, tInd, @@ -456,7 +526,7 @@ def compute_rows( """ rows = [] - for address, ind_array in chunks: + for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v): src = simulation.survey.source_list[address[0]] time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] local_ind = np.arange(len(ind_array[0]))[time_check] @@ -468,7 +538,6 @@ def compute_rows( rows.append(row_block) continue - field_derivs = ATinv_df_duT_v[address] dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( tInd, fields[:, address[0], tInd], @@ -496,98 +565,10 @@ def compute_rows( return np.vstack(rows) -def compute_J(self, f=None): - """ - Compute the rows for the sensitivity matrix. - """ - if f is None: - f = self.fields(self.model, return_Ainv=True) - - ftype = self._fieldType + "Solution" - sens_name = self.sensitivity_path[:-5] - if self.store_sensitivities == "disk": - rows = array.zeros( - (self.survey.nD, self.model.size), - chunks=(self.max_chunk_size, self.model.size), - dtype=np.float32, - ) - Jmatrix = array.to_zarr( - rows, - os.path.join(sens_name + "_1.zarr"), - compute=True, - return_stored=True, - overwrite=True, - ) - else: - Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float64) - - simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 - data_times = self.survey.source_list[0].receiver_list[0].times - compute_row_size = np.ceil(self.max_chunk_size / (self.model.shape[0] * 8.0 * 1e-6)) - blocks = get_parallel_blocks(self.survey.source_list, compute_row_size) - fields_array = f[:, ftype, :] - - if len(self.survey.source_list) == 1: - fields_array = fields_array[:, np.newaxis, :] - - times_field_derivs, Jmatrix = compute_field_derivs( - self, f, blocks, Jmatrix, fields_array.shape - ) - - ATinv_df_duT_v = {} - for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): - AdiagTinv = self.Ainv[dt] - j_row_updates = [] - time_mask = data_times > simulation_times[tInd] - - if not np.any(time_mask): - continue - - for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): - ATinv_df_duT_v = get_field_deriv_block( - self, block, field_deriv, tInd, AdiagTinv, ATinv_df_duT_v, time_mask - ) - - if len(block) == 0: - continue - - j_row_updates.append( - array.from_delayed( - compute_rows( - self, - tInd, - block, - ATinv_df_duT_v, - fields_array, - time_mask, - ), - dtype=np.float32, - shape=( - np.sum([len(chunk[1][0]) for chunk in block]), - self.model.size, - ), - ) - ) - - if self.store_sensitivities == "disk": - sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" - array.to_zarr( - Jmatrix + array.vstack(j_row_updates), - sens_name, - compute=True, - overwrite=True, - ) - Jmatrix = array.from_zarr(sens_name) - else: - Jmatrix += array.vstack(j_row_updates).compute() - - for A in self.Ainv.values(): - A.clean() - - if self.store_sensitivities == "ram": - return np.asarray(Jmatrix) - - return Jmatrix - - +Sim.fields = fields +Sim.getSourceTerm = getSourceTerm Sim.compute_J = compute_J +Sim.getJtJdiag = getJtJdiag +Sim.Jvec = Jvec +Sim.Jtvec = Jtvec +Sim.Jmatrix = Jmatrix diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 9c0d8d058d..aeb2da9878 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -1,130 +1,24 @@ from ..inverse_problem import BaseInvProblem import numpy as np -from time import time -from datetime import timedelta -from dask.distributed import Future, get_client -import dask.array as da + +from .objective_function import DaskComboMisfits from scipy.sparse.linalg import LinearOperator from ..regularization import WeightedLeastSquares, Sparse -from ..data_misfit import BaseDataMisfit -from ..objective_function import BaseObjectiveFunction, ComboObjectiveFunction - - -def dask_getFields(self, m, store=False, deleteWarmstart=True): - f = None - - # try: - # client = get_client() - # fields = lambda f, x, workers: client.compute(f(x), workers=workers) - # except: - # fields = lambda f, x: f(x) - - for mtest, u_ofmtest in self.warmstart: - if m is mtest: - f = u_ofmtest - if self.debug: - print("InvProb is Warm Starting!") - break - - if f is None: - if isinstance(self.dmisfit, BaseDataMisfit): - if self.dmisfit.model_map is not None: - vec = self.dmisfit.model_map @ m - else: - vec = m - - f = fields(self.dmisfit.simulation.fields, vec) - - elif isinstance(self.dmisfit, BaseObjectiveFunction): - f = [] - for objfct in self.dmisfit.objfcts: - if hasattr(objfct, "simulation"): - if objfct.model_map is not None: - vec = objfct.model_map @ m - else: - vec = m - - f += [fields(objfct.simulation.fields, vec, objfct.workers)] - else: - f += [] - - if isinstance(f, Future) or isinstance(f[0], Future): - f = client.gather(f) - - if deleteWarmstart: - self.warmstart = [] - if store: - self.warmstart += [(m, f)] +from ..objective_function import ComboObjectiveFunction +from simpeg.utils import call_hooks +from simpeg.version import __version__ as simpeg_version - return f - -BaseInvProblem.getFields = dask_getFields - - -def get_dpred(self, m, f=None, compute_J=False): +def get_dpred(self, m, f=None): dpreds = [] - if isinstance(self.dmisfit, BaseDataMisfit): - return self.dmisfit.simulation.dpred(m) - elif isinstance(self.dmisfit, BaseObjectiveFunction): - for i, objfct in enumerate(self.dmisfit.objfcts): - if hasattr(objfct, "simulation"): - if getattr(objfct, "model_map", None) is not None: - vec = objfct.model_map @ m - else: - vec = m - - compute_sensitivities = compute_J and ( - objfct.simulation._Jmatrix is None - ) - - if compute_sensitivities and i == 0: - print("Computing forward & sensitivities") - - if objfct.workers is not None: - client = get_client() - future = client.compute( - objfct.simulation.dpred(vec, compute_J=compute_sensitivities), - workers=objfct.workers, - ) - else: - # For locals, the future is now - ct = time() - - future = objfct.simulation.dpred( - vec, compute_J=compute_sensitivities - ) - - if compute_sensitivities: - runtime = time() - ct - total = len(self.dmisfit.objfcts) - - message = f"{i+1} of {total} in {timedelta(seconds=runtime)}. " - if (total - i - 1) > 0: - message += ( - f"ETA -> {timedelta(seconds=(total - i - 1) * runtime)}" - ) - print(message) - - dpreds += [future] + if isinstance(self.dmisfit, DaskComboMisfits): + return self.dmisfit.get_dpred(m, f=f) - else: - dpreds += [] + for objfct in self.dmisfit.objfcts: + dpred = objfct.simulation.dpred(m, f=f) + dpreds += [np.asarray(dpred)] - if isinstance(dpreds[0], Future): - client = get_client() - dpreds = client.gather(dpreds) - - preds = [] - if isinstance(dpreds[0], tuple): # Jmatrix was computed - for future, objfct in zip(dpreds, self.dmisfit.objfcts): - preds += [future[0]] - objfct.simulation._Jmatrix = future[1] - return preds - - else: - dpreds = da.compute(dpreds)[0] return dpreds @@ -134,15 +28,18 @@ def get_dpred(self, m, f=None, compute_J=False): def dask_evalFunction(self, m, return_g=True, return_H=True): """evalFunction(m, return_g=True, return_H=True)""" self.model = m - self.dpred = self.get_dpred(m, compute_J=return_H) + self.dpred = self.get_dpred(m) + residuals = [] - phi_d = 0 - for (_, objfct), pred in zip(self.dmisfit, self.dpred): - residual = objfct.W * (objfct.data.dobs - pred) - phi_d += np.vdot(residual, residual) + if isinstance(self.dmisfit, DaskComboMisfits): + residuals = self.dmisfit.residuals(m) + else: + for (_, objfct), pred in zip(self.dmisfit, self.dpred): + residuals.append(objfct.W * (objfct.data.dobs - pred)) - phi_d = np.asarray(phi_d) - # print(self.dpred[0]) + phi_d = 0.0 + for residual in residuals: + phi_d += np.vdot(residual, residual) reg2Deriv = [] if isinstance(self.reg, ComboObjectiveFunction): @@ -195,11 +92,8 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): out = (phi,) if return_g: - phi_dDeriv = self.dmisfit.deriv(m, f=self.dpred) - # if hasattr(self.reg.objfcts[0], "space") and self.reg.objfcts[0].space == "spherical": + phi_dDeriv = self.dmisfit.deriv(m) phi_mDeriv = self.reg.deriv(m) - # else: - # phi_mDeriv = np.sum([reg2Deriv * obj.f_m for reg2Deriv, obj in zip(self.reg2Deriv, self.reg.objfcts)], axis=0) g = np.asarray(phi_dDeriv) + self.beta * phi_mDeriv out += (g,) @@ -209,7 +103,6 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): def H_fun(v): phi_d2Deriv = self.dmisfit.deriv2(m, v) phi_m2Deriv = self.reg2Deriv * v - H = phi_d2Deriv + self.beta * phi_m2Deriv return H @@ -220,3 +113,32 @@ def H_fun(v): BaseInvProblem.evalFunction = dask_evalFunction + + +@call_hooks("startup") +def startup(self, m0): + """startup(m0) + + Called when inversion is first starting. + """ + if self.debug: + print("Calling InvProblem.startup") + + if self.print_version: + print(f"\nRunning inversion with SimPEG v{simpeg_version}") + + for fct in self.reg.objfcts: + if ( + hasattr(fct, "reference_model") + and getattr(fct, "reference_model", None) is None + ): + print("simpeg.InvProblem will set Regularization.reference_model to m0.") + fct.reference_model = m0 + + self.phi_d = np.nan + self.phi_m = np.nan + + self.model = m0 + + +BaseInvProblem.startup = startup diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 72864df5bc..cf5f70fa0f 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -1,167 +1,413 @@ from ..objective_function import ComboObjectiveFunction, BaseObjectiveFunction -import dask.array as da - import numpy as np -from dask.distributed import Future, get_client, Client +from dask.distributed import Client from ..data_misfit import L2DataMisfit -BaseObjectiveFunction._workers = None - +from simpeg.utils import validate_list_of_types -@property -def client(self): - if getattr(self, "_client", None) is None: - self._client = get_client() - return self._client +def _calc_fields(objfct, _): + return objfct.simulation.fields(m=objfct.simulation.model) -@client.setter -def client(self, client): - assert isinstance(client, Client) - self._client = client +def _calc_dpred(objfct, _): + return objfct.simulation.dpred(m=objfct.simulation.model) -BaseObjectiveFunction.client = client +def _calc_residual(objfct, _): + return objfct.W * ( + objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model) + ) -@property -def workers(self): - return self._workers +def _deriv(objfct, multiplier, _): + return multiplier * objfct.deriv(objfct.simulation.model) -@workers.setter -def workers(self, workers): - self._workers = workers +def _deriv2(objfct, multiplier, _, v): + return multiplier * objfct.deriv2(objfct.simulation.model, v) -BaseObjectiveFunction.workers = workers +def _store_model(objfct, model): + objfct.simulation.model = model -def dask_call(self, m, f=None): - fcts = [] - multipliers = [] - for i, phi in enumerate(self): - multiplier, objfct = phi - if multiplier == 0.0: # don't evaluate the fct - continue - else: +def _get_jtj_diag(objfct, _): + jtj = objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W) + return jtj.flatten() - if f is not None and objfct._has_fields: - fct = objfct(m, f=f[i]) - else: - fct = objfct(m) - if isinstance(fct, Future): - future = self.client.compute( - self.client.submit(da.multiply, multiplier, fct).result() - ) - fcts += [future] - else: - fcts += [fct] +def _validate_type_or_future_of_type( + property_name, + objects, + obj_type, + client, + workers: list[str] | None = None, + return_workers=False, +): - multipliers += [multiplier] + if workers is None: + workers = [ + (worker.worker_address,) for worker in client.cluster.workers.values() + ] - if isinstance(fcts[0], Future): - phi = self.client.submit( - da.sum, self.client.submit(da.vstack, fcts), axis=0 - ).result() - return phi - else: - return np.sum(np.r_[multipliers][:, None] * np.vstack(fcts), axis=0).squeeze() + objects = validate_list_of_types( + property_name, objects, obj_type, ensure_unique=True + ) + workload = [[]] + count = 0 + for obj in objects: + if count == len(workers): + count = 0 + workload.append([]) + obj.simulation.simulations[0].worker = workers[count] + future = client.scatter([obj], workers=workers[count])[0] + if hasattr(obj, "name"): + future.name = obj.name -ComboObjectiveFunction.__call__ = dask_call - - -def dask_deriv(self, m, f=None): - """ - First derivative of the composite objective function is the sum of the - derivatives of each objective function in the list, weighted by their - respective multplier. + workload[-1].append(future) + count += 1 - :param numpy.ndarray m: model - :param SimPEG.Fields f: Fields object (if applicable) - """ + futures = [] + for work in workload: - g = [] - multipliers = [] - for i, phi in enumerate(self): - multiplier, objfct = phi - if multiplier == 0.0: # don't evaluate the fct - continue - else: - - if f is not None and isinstance(objfct, L2DataMisfit): - fct = objfct.deriv(m, f=f[i]) - else: - fct = objfct.deriv(m) - - if isinstance(fct, Future): - future = self.client.compute( - self.client.submit(da.multiply, multiplier, fct) + for obj, worker in zip(work, workers): + futures.append( + client.submit( + lambda v: not isinstance(v, obj_type), obj, workers=worker ) - g += [future] - else: - g += [fct] - - multipliers += [multiplier] - - if isinstance(g[0], Future): - big_future = self.client.submit( - da.sum, self.client.submit(da.vstack, g), axis=0 - ).result() - return self.client.compute(big_future).result() + ) + is_not_obj = np.array(client.gather(futures)) + if np.any(is_not_obj): + raise TypeError(f"{property_name} futures must be an instance of {obj_type}") + if return_workers: + return workload, workers else: - return np.sum(np.r_[multipliers][:, None] * np.vstack(g), axis=0).squeeze() - + return workload -ComboObjectiveFunction.deriv = dask_deriv - -def dask_deriv2(self, m, v=None, f=None): +class DaskComboMisfits(ComboObjectiveFunction): """ - Second derivative of the composite objective function is the sum of the - second derivatives of each objective function in the list, weighted by - their respective multplier. - - :param numpy.ndarray m: model - :param numpy.ndarray v: vector we are multiplying by - :param SimPEG.Fields f: Fields object (if applicable) + A composite objective function for distributed computing. """ - H = [] - multipliers = [] - for phi in self: - multiplier, objfct = phi - if multiplier == 0.0: # don't evaluate the fct - continue - else: - fct = objfct.deriv2(m, v) - - if isinstance(fct, Future): - future = self.client.submit(da.multiply, multiplier, fct) - H += [future] - else: - H += [fct] - - multipliers += [multiplier] - - if isinstance(H[0], Future): - big_future = self.client.submit( - da.sum, self.client.submit(da.vstack, H), axis=0 - ).result() - - return np.asarray(big_future) - - else: - phi_deriv2 = 0 - for multiplier, h in zip(multipliers, H): - phi_deriv2 += multiplier * h - - return phi_deriv2 + def __init__( + self, + objfcts: list[BaseObjectiveFunction], + multipliers=None, + client: Client | None = None, + workers: list[str] | None = None, + **kwargs, + ): + self._model: np.ndarray | None = None + self.client = client + self.workers = workers + + super().__init__(objfcts=objfcts, multipliers=multipliers, **kwargs) + + def __call__(self, m, f=None): + self.model = m + client = self.client + m_future = self._m_as_future + + values = [] + count = 0 + for futures in self._futures: + for objfct, worker in zip(futures, self._workers, strict=True): + + if self.multipliers[count] == 0.0: + continue + + values.append( + client.submit( + _calc_objective, + objfct, + self.multipliers[count], + m_future, + workers=worker, + ) + ) + count += 1 + + values = self.client.gather(values) + return np.sum(values) + + @property + def client(self): + """ + Get the dask.distributed.Client instance. + """ + return self._client + + @client.setter + def client(self, client): + if not isinstance(client, Client): + raise TypeError("client must be a dask.distributed.Client") + + self._client = client + + @property + def workers(self): + """ + List of worker addresses + """ + return self._workers + + @workers.setter + def workers(self, workers): + if not isinstance(workers, list | type(None)): + raise TypeError("workers must be a list of strings") + + self._workers = workers + + def deriv(self, m, f=None): + """ + First derivative of the composite objective function is the sum of the + derivatives of each objective function in the list, weighted by their + respective multplier. + + :param numpy.ndarray m: model + :param SimPEG.Fields f: Fields object (if applicable) + """ + self.model = m + client = self.client + m_future = self._m_as_future + + derivs = 0.0 + count = 0 + + for futures in self._futures: + future_deriv = [] + for objfct, worker in zip(futures, self._workers): + if self.multipliers[count] == 0.0: # don't evaluate the fct + continue + + future_deriv.append( + client.submit( + _deriv, + objfct, + self.multipliers[count], + m_future, + workers=worker, + ) + ) + count += 1 + future_deriv = client.gather(future_deriv) + + derivs += np.sum(future_deriv, axis=0) + + return derivs + + def deriv2(self, m, v=None, f=None): + """ + Second derivative of the composite objective function is the sum of the + second derivatives of each objective function in the list, weighted by + their respective multplier. + + :param numpy.ndarray m: model + :param numpy.ndarray v: vector we are multiplying by + :param SimPEG.Fields f: Fields object (if applicable) + """ + self.model = m + client = self.client + m_future = self._m_as_future + [v_future] = client.scatter([v], broadcast=True) + + derivs = 0.0 + count = 0 + + for futures in self._futures: + + future_derivs = [] + for objfct, worker in zip(futures, self._workers): + if self.multipliers[count] == 0.0: # don't evaluate the fct + continue + + future_derivs.append( + client.submit( + _deriv2, + objfct, + self.multipliers[count], + m_future, + v_future, + # field, + workers=worker, + ) + ) + count += 1 + + future_derivs = self.client.gather(future_derivs) + derivs += np.sum(future_derivs, axis=0) + + return derivs + + def get_dpred(self, m, f=None): + """ + Request calculation of predicted data from all simulations. + """ + self.model = m + + client = self.client + m_future = self._m_as_future + dpred = [] + + for futures in self._futures: + future_preds = [] + for objfct, worker in zip(futures, self._workers): + future_preds.append( + client.submit( + _calc_dpred, + objfct, + m_future, + workers=worker, + ) + ) + dpred += client.gather(future_preds) + + return dpred + + def getJtJdiag(self, m, f=None): + """ + Request calculation of the diagonal of JtJ from all simulations. + """ + self.model = m + m_future = self._m_as_future + if getattr(self, "_jtjdiag", None) is None: + + jtj_diag = 0.0 + client = self.client + + for futures in self._futures: + work = [] + + for objfct, worker in zip(futures, self._workers): + work.append( + client.submit( + _get_jtj_diag, + objfct, + m_future, + workers=worker, + ) + ) + + work = client.gather(work) + jtj_diag += np.sum(work, axis=0) + + self._jtjdiag = jtj_diag + + return self._jtjdiag + + def fields(self, m): + """ + Request calculation of fields from all simulations. + + Store list of futures for fields in self._stashed_fields. + """ + self.model = m + client = self.client + m_future = self._m_as_future + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields + # The above should pass the model to all the internal simulations. + f = [] + + for futures in self._futures: + f.append([]) + for objfct, worker in zip(futures, self._workers): + f[-1].append( + client.submit( + _calc_fields, + objfct, + m_future, + workers=worker, + ) + ) + self._stashed_fields = f + return f + + @property + def model(self): + return self._model + + @model.setter + def model(self, value): + # Only send the model to the internal simulations if it was updated. + if ( + isinstance(value, np.ndarray) + and isinstance(self.model, np.ndarray) + and np.allclose(value, self.model) + ): + return + + self._stashed_fields = None + self._jtjdiag = None + + client = self.client + [self._m_as_future] = client.scatter([value], broadcast=True) + + stores = [] + for futures in self._futures: + for objfct, worker in zip(futures, self._workers): + stores.append( + client.submit( + _store_model, + objfct, + self._m_as_future, + workers=worker, + ) + ) + self.client.gather(stores) # blocking call to ensure all models were stored + self._model = value + + @property + def objfcts(self): + """ + List of objective functions associated with the data misfit. + """ + return self._objfcts + + @objfcts.setter + def objfcts(self, objfcts): + client = self.client + + futures, workers = _validate_type_or_future_of_type( + "objfcts", + objfcts, + L2DataMisfit, + client, + workers=self.workers, + return_workers=True, + ) + + self._objfcts = objfcts + self._futures = futures + self._workers = workers + + def residuals(self, m, f=None): + """ + Compute the residual for the data misfit. + """ + self.model = m + + client = self.client + m_future = self._m_as_future + residuals = [] + + for futures in self._futures: + future_residuals = [] + for objfct, worker in zip(futures, self._workers): + future_residuals.append( + client.submit( + _calc_residual, + objfct, + m_future, + workers=worker, + ) + ) + residuals += client.gather(future_residuals) -ComboObjectiveFunction.deriv2 = dask_deriv2 + return residuals diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 92424ad8b4..87c6cd4219 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -1,11 +1,14 @@ import numpy as np + from ...potential_fields.base import BasePFSimulation as Sim +from dask.distributed import get_client import os from dask import delayed, array, config from dask.diagnostics import ProgressBar from ..utils import compute_chunk_sizes -Sim._chunk_format = "row" + +_chunk_format = "row" @property @@ -21,10 +24,7 @@ def chunk_format(self, other): self._chunk_format = other -Sim.chunk_format = chunk_format - - -def dask_dpred(self, m=None, f=None, compute_J=False): +def dpred(self, m=None, f=None): if m is not None: self.model = m if f is not None: @@ -32,31 +32,75 @@ def dask_dpred(self, m=None, f=None, compute_J=False): return self.fields(self.model) -Sim.dpred = dask_dpred +def residual(self, m, dobs, f=None): + return self.dpred(m, f=f) - dobs -def dask_residual(self, m, dobs, f=None): - return self.dpred(m, f=f) - dobs +def block_compute(sim, rows, components): + block = [] + for row in rows: + block.append(sim.evaluate_integral(row, components)) + if sim.store_sensitivities == "forward_only": + return np.hstack(block) -Sim.residual = dask_residual + return np.vstack(block) -def dask_linear_operator(self): +def linear_operator(self): forward_only = self.store_sensitivities == "forward_only" - row = delayed(self.evaluate_integral, pure=True) n_cells = self.nC if getattr(self, "model_type", None) == "vector": n_cells *= 3 - rows = [ - array.from_delayed( - row(receiver_location, components), - dtype=self.sensitivity_dtype, - shape=(len(components),) if forward_only else (len(components), n_cells), - ) - for receiver_location, components in self.survey._location_component_iterator() - ] + n_components = len(self.survey.components) + n_blocks = np.ceil( + (n_cells * n_components * self.survey.receiver_locations.shape[0] * 8.0 * 1e-6) + / self.max_chunk_size + ) + block_split = np.array_split(self.survey.receiver_locations, n_blocks) + + try: + client = get_client() + except ValueError: + client = None + + if client: + sim = client.scatter(self, workers=self.worker) + else: + delayed_compute = delayed(block_compute) + + rows = [] + for block in block_split: + if client: + rows.append( + client.submit( + block_compute, + sim, + block, + self.survey.components, + workers=self.worker, + ) + ) + else: + chunk = delayed_compute(self, block, self.survey.components) + rows.append( + array.from_delayed( + chunk, + dtype=self.sensitivity_dtype, + shape=( + (len(block) * n_components,) + if forward_only + else (len(block) * n_components, n_cells) + ), + ) + ) + + if client: + if forward_only: + return np.hstack(client.gather(rows)) + return np.vstack(client.gather(rows)) + if forward_only: stack = array.concatenate(rows) else: @@ -100,22 +144,33 @@ def dask_linear_operator(self): kernel = array.to_zarr( stack, sens_name, compute=True, return_stored=True, overwrite=True ) - elif forward_only: - with ProgressBar(): - print("Forward calculation: ") - kernel = stack.compute() - else: - with ProgressBar(): - print("Computing sensitivities to local ram") - kernel = stack.persist() + + with ProgressBar(): + kernel = stack.compute() return kernel -Sim.linear_operator = dask_linear_operator +def compute_J(self, _, f=None): + return self.linear_operator() -def compute_J(self): - return self.linear_operator() +@property +def Jmatrix(self): + if getattr(self, "_Jmatrix", None) is None: + self._Jmatrix = self.compute_J(self.model) + return self._Jmatrix +@Jmatrix.setter +def Jmatrix(self, value): + self._Jmatrix = value + + +Sim.clean_on_model_update = [] +Sim._chunk_format = _chunk_format +Sim.chunk_format = chunk_format +Sim.dpred = dpred +Sim.residual = residual +Sim.linear_operator = linear_operator Sim.compute_J = compute_J +Sim.Jmatrix = Jmatrix diff --git a/simpeg/dask/potential_fields/gravity/simulation.py b/simpeg/dask/potential_fields/gravity/simulation.py index 780e37057a..4c8d39271c 100644 --- a/simpeg/dask/potential_fields/gravity/simulation.py +++ b/simpeg/dask/potential_fields/gravity/simulation.py @@ -1,25 +1,18 @@ -import numpy as np from ....potential_fields.gravity import Simulation3DIntegral as Sim -from ....utils import sdiag, mkvc +from ...simulation import getJtJdiag -def dask_getJtJdiag(self, m, W=None, f=None): +@property +def G(self): """ - Return the diagonal of JtJ + Gravity forward operator """ + if getattr(self, "_G", None) is None: + self._G = self.Jmatrix - self.model = m + return self._G - if W is None: - W = np.ones(self.nD) - else: - W = W.diagonal() - if getattr(self, "_gtg_diagonal", None) is None: - diag = ((W[:, None] * self.Jmatrix) ** 2).sum(axis=0).compute() - self._gtg_diagonal = diag - else: - diag = self._gtg_diagonal - return mkvc((sdiag(np.sqrt(diag)) @ self.rhoDeriv).power(2).sum(axis=0)) - -Sim.getJtJdiag = dask_getJtJdiag +Sim.clean_on_model_update = [] +Sim.getJtJdiag = getJtJdiag +Sim.G = G diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index 5682066d2f..cf3303215b 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -1,35 +1,18 @@ -import numpy as np from ....potential_fields.magnetics import Simulation3DIntegral as Sim -from ....utils import sdiag, mkvc +from ...simulation import getJtJdiag -def dask_getJtJdiag(self, m, W=None, f=None): +@property +def G(self): """ - Return the diagonal of JtJ + Gravity forward operator """ + if getattr(self, "_G", None) is None: + self._G = self.Jmatrix - self.model = m + return self._G - if W is None: - W = np.ones(self.nD) - else: - W = W.diagonal() - if getattr(self, "_gtg_diagonal", None) is None: - if not self.is_amplitude_data: - diag = ((W[:, None] * self.Jmatrix) ** 2).sum(axis=0).compute() - else: - ampDeriv = self.ampDeriv - J = ( - ampDeriv[0, :, None] * self.Jmatrix[::3] - + ampDeriv[1, :, None] * self.Jmatrix[1::3] - + ampDeriv[2, :, None] * self.Jmatrix[2::3] - ) - diag = ((W[:, None] * J) ** 2).sum(axis=0).compute() - self._gtg_diagonal = diag - else: - diag = self._gtg_diagonal - return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0)) - - -Sim.getJtJdiag = dask_getJtJdiag +Sim.clean_on_model_update = [] +Sim.getJtJdiag = getJtJdiag +Sim.G = G diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index 2e94bfcc4a..d04cfb0a32 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -1,13 +1,13 @@ from ..simulation import BaseSimulation as Sim -from dask.distributed import get_client, Future -from dask import array, delayed -import multiprocessing -import warnings -from ..data import SyntheticData + +from dask import array import numpy as np -from .utils import compute +from multiprocessing import cpu_count +Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] +Sim.sensitivity_path = "./sensitivity/" Sim._max_ram = 16 +Sim._max_chunk_size = 128 @property @@ -25,8 +25,6 @@ def max_ram(self, other): Sim.max_ram = max_ram -Sim._max_chunk_size = 128 - @property def max_chunk_size(self): @@ -44,91 +42,28 @@ def max_chunk_size(self, other): Sim.max_chunk_size = max_chunk_size -@property -def n_cpu(self): - """Number of cpu's available.""" - if getattr(self, "_n_cpu", None) is None: - self._n_cpu = int(multiprocessing.cpu_count()) - return self._n_cpu - - -@n_cpu.setter -def n_cpu(self, other): - if other <= 0: - raise ValueError("n_cpu must be greater than 0") - self._n_cpu = other - - -Sim.n_cpu = n_cpu - - -def make_synthetic_data( - self, m, relative_error=0.05, noise_floor=0.0, f=None, add_noise=False, **kwargs -): +def getJtJdiag(self, m, W=None, f=None): """ - Make synthetic data given a model, and a standard deviation. - :param numpy.ndarray m: geophysical model - :param numpy.ndarray relative_error: standard deviation - :param numpy.ndarray noise_floor: noise floor - :param numpy.ndarray f: fields for the given model (if pre-calculated) + Return the diagonal of JtJ """ + if getattr(self, "_jtjdiag", None) is None: + self.model = m + if W is None: + W = np.ones(self.Jmatrix.shape[0]) + else: + W = W.diagonal() - std = kwargs.pop("std", None) - if std is not None: - warnings.warn( - "The std parameter will be deprecated in SimPEG 0.15.0. " - "Please use relative_error.", - DeprecationWarning, - stacklevel=2, + self._jtj_diag = np.asarray( + np.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) ) - relative_error = std - - dpred = self.dpred(m, f=f) - - if not isinstance(dpred, np.ndarray): - dpred = compute(self, dpred) - if isinstance(dpred, Future): - client = get_client() - dpred = client.gather(dpred) - - dclean = np.asarray(dpred) - - if add_noise is True: - std = relative_error * abs(dclean) + noise_floor - noise = std * np.random.randn(*dclean.shape) - dobs = dclean + noise - else: - dobs = dclean - - return SyntheticData( - survey=self.survey, - dobs=dobs, - dclean=dclean, - relative_error=relative_error, - noise_floor=noise_floor, - ) - -Sim.make_synthetic_data = make_synthetic_data + return self._jtj_diag -@property -def workers(self): - if getattr(self, "_workers", None) is None: - self._workers = None - - return self._workers - - -@workers.setter -def workers(self, workers): - self._workers = workers +Sim.getJtJdiag = getJtJdiag -Sim.workers = workers - - -def dask_Jvec(self, m, v): +def Jvec(self, m, v, **_): """ Compute sensitivity matrix (J) and vector (v) product. """ @@ -137,16 +72,13 @@ def dask_Jvec(self, m, v): if isinstance(self.Jmatrix, np.ndarray): return self.Jmatrix @ v.astype(np.float32) - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish - - return array.dot(self.Jmatrix, v).astype(np.float32) + return array.dot(self.Jmatrix, v.astype(np.float32)) -Sim.Jvec = dask_Jvec +Sim.Jvec = Jvec -def dask_Jtvec(self, m, v): +def Jtvec(self, m, v, **_): """ Compute adjoint sensitivity matrix (J^T) and vector (v) product. """ @@ -155,40 +87,20 @@ def dask_Jtvec(self, m, v): if isinstance(self.Jmatrix, np.ndarray): return self.Jmatrix.T @ v.astype(np.float32) - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish + return array.dot(v.astype(np.float32), self.Jmatrix) - return array.dot(v, self.Jmatrix).astype(np.float32) - -Sim.Jtvec = dask_Jtvec +Sim.Jtvec = Jtvec @property def Jmatrix(self): """ Sensitivity matrix stored on disk + Return the diagonal of JtJ """ if getattr(self, "_Jmatrix", None) is None: - if self.workers is None: - self._Jmatrix = self.compute_J() - self._G = self._Jmatrix - else: - client = get_client() # Assumes a Client already exists - - if self.store_sensitivities == "ram": - self._Jmatrix = client.persist( - delayed(self.compute_J)(), workers=self.workers - ) - else: - self._Jmatrix = client.compute( - delayed(self.compute_J)(), workers=self.workers - ) - - elif isinstance(self._Jmatrix, Future): - self._Jmatrix.result() - if self.store_sensitivities == "disk": - self._Jmatrix = array.from_zarr(self.sensitivity_path + "J.zarr") + self._Jmatrix = self.compute_J(self.model) return self._Jmatrix @@ -196,18 +108,38 @@ def Jmatrix(self): Sim.Jmatrix = Jmatrix -def dask_dpred(self, m=None, f=None, compute_J=False): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). +def n_threads(self, client=None): + """ + Number of threads used by Dask + """ + if getattr(self, "_n_threads", None) is None: + if client: + self._n_threads = client.nthreads()[self.worker[0]] + else: + self._n_threads = cpu_count() + + return self._n_threads + - .. math:: +Sim.n_threads = n_threads - d_\\text{pred} = P(f(m)) - Where P is a projection of the fields onto the data space. +# TODO: Make dpred parallel +def dpred(self, m=None, f=None): + r"""Predicted data for the model provided. + + Parameters + ---------- + m : (n_param,) numpy.ndarray + The model parameters. + f : simpeg.fields.Fields, optional + If provided, will be used to compute the predicted data + without recalculating the fields. + + Returns + ------- + (n_data, ) numpy.ndarray + The predicted data vector. """ if self.survey is None: raise AttributeError( @@ -219,54 +151,11 @@ def dask_dpred(self, m=None, f=None, compute_J=False): if f is None: if m is None: m = self.model - f = self.fields(m, return_Ainv=compute_J) - def evaluate_receiver(source, receiver, mesh, fields): - return receiver.eval(source, mesh, fields).flatten() + f = self.fields(m) - row = delayed(evaluate_receiver, pure=True) - rows = [] + data = Data(self.survey) for src in self.survey.source_list: for rx in src.receiver_list: - rows.append( - array.from_delayed( - row(src, rx, self.mesh, f), - dtype=np.float32, - shape=(rx.nD,), - ) - ) - - data = array.hstack(rows).compute() - - if compute_J and self._Jmatrix is None: - Jmatrix = self.compute_J(f=f) - return data, Jmatrix - - return data - - -Sim.dpred = dask_dpred - - -def dask_getJtJdiag(self, m, W=None): - """ - Return the diagonal of JtJ - """ - self.model = m - if getattr(self, "_jtjdiag", None) is None: - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish - - if W is None: - W = np.ones(self.nD) - else: - W = W.diagonal() ** 2.0 - - diag = array.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) - - if isinstance(diag, array.Array): - diag = np.asarray(diag.compute()) - - self._jtjdiag = diag - - return self._jtjdiag + data[src, rx] = rx.eval(src, self.mesh, f) + return mkvc(data) diff --git a/simpeg/dask/utils.py b/simpeg/dask/utils.py index ad292dc9a2..10188b7bc0 100644 --- a/simpeg/dask/utils.py +++ b/simpeg/dask/utils.py @@ -1,5 +1,4 @@ import numpy as np -from dask.distributed import get_client from multiprocessing import cpu_count @@ -27,20 +26,9 @@ def compute_chunk_sizes(M, N, target_chunk_size): return rowChunk, colChunk -def compute(self, job): - """ - Compute dask job for either dask array or client. - """ - if isinstance(job, np.ndarray): - return job - try: - client = get_client() - return client.compute(job, workers=self.workers) - except ValueError: - return job.compute() - - -def get_parallel_blocks(source_list: list, data_block_size, optimize=True) -> list: +def get_parallel_blocks( + source_list: list, data_block_size, optimize=True, thread_count=64 +) -> list: """ Get the blocks of sources and receivers to be computed in parallel. @@ -64,7 +52,7 @@ def get_parallel_blocks(source_list: list, data_block_size, optimize=True) -> li chunk_size = len(chunk) # Condition to start a new block - if (row_count + chunk_size) > (data_block_size * cpu_count()): + if (row_count + chunk_size) > (data_block_size * thread_count): row_count = 0 block_count += 1 blocks.append([]) @@ -83,7 +71,7 @@ def get_parallel_blocks(source_list: list, data_block_size, optimize=True) -> li row_count += chunk_size # Re-split over cpu_count if too few blocks - if len(blocks) < cpu_count() and optimize: + if len(blocks) < thread_count and optimize: flatten_blocks = [] for block in blocks: flatten_blocks += block diff --git a/simpeg/data_misfit.py b/simpeg/data_misfit.py index b796f78c21..ef8273b36f 100644 --- a/simpeg/data_misfit.py +++ b/simpeg/data_misfit.py @@ -1,5 +1,5 @@ import numpy as np -from .utils import Counter, mkvc, sdiag, timeIt, Identity, validate_type +from .utils import Counter, sdiag, timeIt, Identity, validate_type from .data import Data from .simulation import BaseSimulation from .objective_function import L2ObjectiveFunction @@ -359,16 +359,6 @@ def getJtJdiag(self, m): + "Cannot form the sensitivity explicitly" ) - mapping_deriv = self.model_map.deriv(m) - - if self.model_map is not None: - m = mapping_deriv @ m - jtjdiag = self.simulation.getJtJdiag(m, W=self.W) - if self.model_map is not None: - jtjdiag = mkvc( - (sdiag(np.sqrt(jtjdiag)) @ mapping_deriv).power(2).sum(axis=0) - ) - return jtjdiag diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 1b3a19dcc6..1873eaea88 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -8,7 +8,7 @@ import warnings import os import scipy.sparse as sp - +from ..meta.simulation import MetaSimulation from ..typing import RandomSeed from ..data_misfit import BaseDataMisfit @@ -48,13 +48,27 @@ validate_float, validate_ndarray_with_shape, ) - from geoh5py.groups.property_group import GroupTypeEnum from geoh5py.groups import PropertyGroup, UIJsonGroup from geoh5py.objects import ObjectBase from geoh5py.ui_json.utils import fetch_active_workspace +def compute_JtJdiags(data_misfit, m): + if hasattr(data_misfit, "getJtJdiag"): + return data_misfit.getJtJdiag(m) + else: + jtj_diags = [] + for dmisfit in data_misfit.objfcts: + jtj_diags.append(dmisfit.getJtJdiag(m)) + + jtj_diag = np.zeros_like(jtj_diags[0]) + for multiplier, diag in zip(data_misfit.multipliers, jtj_diags): + jtj_diag += multiplier * diag + + return np.asarray(jtj_diag) + + class InversionDirective: """Base inversion directive class. @@ -2466,18 +2480,7 @@ def initialize(self): if not isinstance(rdg, Zero): regDiag += multiplier * rdg.diagonal() - JtJdiag = np.zeros_like(self.invProb.model) - for sim, (multiplier, dmisfit) in zip(self.simulation, self.dmisfit): - if getattr(sim, "getJtJdiag", None) is None: - assert getattr(sim, "getJ", None) is not None, ( - "Simulation does not have a getJ attribute." - + "Cannot form the sensitivity explicitly" - ) - JtJdiag += multiplier * np.sum( - np.power((dmisfit.W * sim.getJ(m)), 2), axis=0 - ) - else: - JtJdiag += multiplier * dmisfit.getJtJdiag(m) + JtJdiag = compute_JtJdiags(self.dmisfit, self.invProb.model) diagA = JtJdiag + self.invProb.beta * regDiag diagA[diagA != 0] = diagA[diagA != 0] ** -1.0 @@ -2498,18 +2501,7 @@ def endIter(self): # Check if regularization has a projection regDiag += multiplier * reg.deriv2(m).diagonal() - JtJdiag = np.zeros_like(self.invProb.model) - for sim, (multiplier, dmisfit) in zip(self.simulation, self.dmisfit): - if getattr(sim, "getJtJdiag", None) is None: - assert getattr(sim, "getJ", None) is not None, ( - "Simulation does not have a getJ attribute." - + "Cannot form the sensitivity explicitly" - ) - JtJdiag += multiplier * np.sum( - np.power((dmisfit.W * sim.getJ(m)), 2), axis=0 - ) - else: - JtJdiag += multiplier * dmisfit.getJtJdiag(m) + JtJdiag = compute_JtJdiags(self.dmisfit, m) diagA = JtJdiag + self.invProb.beta * regDiag diagA[diagA != 0] = diagA[diagA != 0] ** -1.0 @@ -2833,21 +2825,7 @@ def endIter(self): def update(self): """Update sensitivity weights""" - jtj_diag = np.zeros_like(self.invProb.model) - m = self.invProb.model - - for sim, (multiplier, dmisfit) in zip(self.simulation, self.dmisfit): - if getattr(sim, "getJtJdiag", None) is None: - if getattr(sim, "getJ", None) is None: - raise AttributeError( - "Simulation does not have a getJ attribute." - + "Cannot form the sensitivity explicitly" - ) - jtj_diag += multiplier * mkvc( - np.sum((dmisfit.W * sim.getJ(m)) ** 2.0, axis=0) - ) - else: - jtj_diag += multiplier * dmisfit.getJtJdiag(m) + jtj_diag = compute_JtJdiags(self.dmisfit, self.invProb.model) # Compute and sum root-mean squared sensitivities for all objective functions wr = np.zeros_like(self.invProb.model) @@ -3263,9 +3241,7 @@ def __init__(self, h5_object, dmisfit=None, **kwargs): def get_values(self, values: list[np.ndarray] | None): if values is None: - values = np.zeros_like(self.invProb.model) - for fun in self.dmisfit.objfcts: - values += fun.getJtJdiag(self.invProb.model) + values = compute_JtJdiags(self.dmisfit, self.invProb.model) return values @@ -3523,7 +3499,11 @@ def endIter(self): self.opt.upper[indices[nC:]] = np.inf for simulation in self.simulations: - simulation.chiMap = SphericalSystem() * simulation.chiMap + if isinstance(simulation, MetaSimulation): + for sim in simulation.simulations: + sim.chiMap = SphericalSystem() * sim.chiMap + else: + simulation.chiMap = SphericalSystem() * simulation.chiMap # Add and update directives for directive in self.inversion.directiveList.dList: diff --git a/simpeg/electromagnetics/frequency_domain/receivers.py b/simpeg/electromagnetics/frequency_domain/receivers.py index b424135a20..a9eab28f3e 100644 --- a/simpeg/electromagnetics/frequency_domain/receivers.py +++ b/simpeg/electromagnetics/frequency_domain/receivers.py @@ -153,8 +153,8 @@ def getP(self, mesh, projected_grid): scipy.sparse.csr_matrix P, the interpolation matrix """ - if (mesh, projected_grid) in self._Ps: - return self._Ps[(mesh, projected_grid)] + if getattr(self, "spatialP", None) is not None: + return self.spatialP P = Zero() for strength, comp in zip(self.orientation, ["x", "y", "z"]): @@ -164,7 +164,7 @@ def getP(self, mesh, projected_grid): ) if self.storeProjections: - self._Ps[(mesh, projected_grid)] = P + self.spatialP = P return P def eval(self, src, mesh, f): # noqa: A003 diff --git a/simpeg/electromagnetics/natural_source/receivers.py b/simpeg/electromagnetics/natural_source/receivers.py index 930fff879e..75498271f5 100644 --- a/simpeg/electromagnetics/natural_source/receivers.py +++ b/simpeg/electromagnetics/natural_source/receivers.py @@ -175,8 +175,8 @@ def getP(self, mesh, projected_grid, field="e"): if mesh.dim < 3: return super().getP(mesh, projected_grid) - if (mesh, projected_grid, field) in self._Ps: - return self._Ps[(mesh, projected_grid, field)] + if (mesh.n_cells, projected_grid, field) in self._Ps: + return self._Ps[(mesh.n_cells, projected_grid, field)] if field == "e": locs = self.locations_e @@ -184,7 +184,7 @@ def getP(self, mesh, projected_grid, field="e"): locs = self.locations_h P = mesh.get_interpolation_matrix(locs, projected_grid) if self.storeProjections: - self._Ps[(mesh, projected_grid, field)] = P + self._Ps[(mesh.n_cells, projected_grid, field)] = P return P def _eval_impedance(self, src, mesh, f): diff --git a/simpeg/electromagnetics/static/resistivity/receivers.py b/simpeg/electromagnetics/static/resistivity/receivers.py index 53c2614bb7..3607a4049b 100644 --- a/simpeg/electromagnetics/static/resistivity/receivers.py +++ b/simpeg/electromagnetics/static/resistivity/receivers.py @@ -30,6 +30,7 @@ def __init__( projField="phi", **kwargs, ): + self.spatialP = None super(BaseRx, self).__init__(locations=locations, **kwargs) self.orientation = orientation @@ -410,15 +411,15 @@ def getP(self, mesh, projected_grid, transpose=False): P, the interpolation matrix """ - if mesh in self._Ps: - return self._Ps[mesh] + if getattr(self, "spatialP", None) is not None: + return self.spatialP P0 = mesh.get_interpolation_matrix(self.locations[0], projected_grid) P1 = mesh.get_interpolation_matrix(self.locations[1], projected_grid) P = P0 - P1 if self.storeProjections: - self._Ps[mesh] = P + self.spatialP = P if transpose: P = P.toarray().T @@ -489,12 +490,12 @@ def getP(self, mesh, projected_grid): P, the interpolation matrix """ - if mesh in self._Ps: - return self._Ps[mesh] + if getattr(self, "spatialP", None) is not None: + return self.spatialP P = mesh.get_interpolation_matrix(self.locations, projected_grid) if self.storeProjections: - self._Ps[mesh] = P + self.spatialP = P return P diff --git a/simpeg/electromagnetics/static/resistivity/simulation.py b/simpeg/electromagnetics/static/resistivity/simulation.py index 4e65a8d07a..bdd09e1666 100644 --- a/simpeg/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/electromagnetics/static/resistivity/simulation.py @@ -606,7 +606,7 @@ def setBC(self): # TODO: Implement Zhang et al. (1995) r_vec = boundary_faces - source_point - r = np.linalg.norm(r_vec, axis=-1) + r = np.linalg.norm(r_vec, axis=-1) + 1e-12 r_hat = r_vec / r[:, None] r_dot_n = np.einsum("ij,ij->i", r_hat, boundary_normals) diff --git a/simpeg/electromagnetics/time_domain/receivers.py b/simpeg/electromagnetics/time_domain/receivers.py index 98e4a5053e..dc2cbb0255 100644 --- a/simpeg/electromagnetics/time_domain/receivers.py +++ b/simpeg/electromagnetics/time_domain/receivers.py @@ -148,15 +148,15 @@ def getP(self, mesh, time_mesh, f): ----- Projection matrices are stored as a dictionary (mesh, time_mesh) if storeProjections is True """ - if (mesh, time_mesh) in self._Ps: - return self._Ps[(mesh, time_mesh)] + if (mesh.n_cells, time_mesh.n_cells) in self._Ps: + return self._Ps[(mesh.n_cells, time_mesh.n_cells)] Ps = self.getSpatialP(mesh, f) Pt = self.getTimeP(time_mesh, f) P = sp.kron(Pt, Ps) if self.storeProjections: - self._Ps[(mesh, time_mesh)] = P + self._Ps[(mesh.n_cells, time_mesh.n_cells)] = P return P diff --git a/simpeg/inverse_problem.py b/simpeg/inverse_problem.py index e554c95cee..59617296c0 100644 --- a/simpeg/inverse_problem.py +++ b/simpeg/inverse_problem.py @@ -283,14 +283,14 @@ def getFields(self, m, store=False, deleteWarmstart=True): return f - def get_dpred(self, m, f): + def get_dpred(self, m, f=None): dpred = [] for i, objfct in enumerate(self.dmisfit.objfcts): if hasattr(objfct, "simulation"): - dpred += [objfct.simulation.dpred(m, f=f[i])] + dpred += [objfct.simulation.dpred(m, f=f if f is None else f[i])] else: dpred += [] - return np.hstack(dpred) + return dpred @timeIt def evalFunction(self, m, return_g=True, return_H=True): diff --git a/simpeg/meta/__init__.py b/simpeg/meta/__init__.py index 3dca694298..7c58eeb2f8 100644 --- a/simpeg/meta/__init__.py +++ b/simpeg/meta/__init__.py @@ -78,6 +78,7 @@ try: from .dask_sim import ( DaskMetaSimulation, + DaskMetaSimulationExplicit, DaskSumMetaSimulation, DaskRepeatedSimulation, ) diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index bddf091920..df3a714db1 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -1,11 +1,13 @@ import numpy as np from simpeg.simulation import BaseSimulation + from simpeg.survey import BaseSurvey from simpeg.maps import IdentityMap from simpeg.utils import validate_list_of_types, validate_type from simpeg.props import HasModel import itertools + from dask.distributed import Client from dask.distributed import Future from .simulation import MetaSimulation, SumMetaSimulation @@ -81,6 +83,7 @@ def _validate_type_or_future_of_type( objects = validate_list_of_types( property_name, objects, obj_type, ensure_unique=True ) + if workers is None: objects = client.scatter(objects) else: @@ -110,6 +113,7 @@ def _validate_type_or_future_of_type( warnings.warn( f"{property_name} {i} is not on the expected worker.", stacklevel=2 ) + # obj = client.submit(_set_worker, obj, worker) # Ensure this runs on the expected worker futures = [] @@ -150,8 +154,11 @@ class DaskMetaSimulation(MetaSimulation): The dask client to use for communication. """ + clean_on_model_update = ["_jtjdiag", "_stashed_fields"] + def __init__(self, simulations, mappings, client): self._client = validate_type("client", client, Client, cast=False) + super().__init__(simulations, mappings) def _make_survey(self): @@ -177,6 +184,7 @@ def simulations(self): @simulations.setter def simulations(self, value): client = self.client + simulations, workers = _validate_type_or_future_of_type( "simulations", value, BaseSimulation, client, return_workers=True ) @@ -247,7 +255,7 @@ def check_mapping(mapping, sim, model_len): raise ValueError("All mappings must have the same input length") if np.any(error_checks == 2): raise ValueError( - f"Simulations and mappings at indices {np.where(error_checks==2)}" + f"Simulations and mappings at indices {np.where(error_checks == 2)}" f" are inconsistent." ) @@ -314,6 +322,8 @@ def fields(self, m): self.model = m client = self.client m_future = self._m_as_future + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields # The above should pass the model to all the internal simulations. f = [] for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers): @@ -327,6 +337,7 @@ def fields(self, m): workers=worker, ) ) + self._stashed_fields = f return f def dpred(self, m=None, f=None): diff --git a/simpeg/meta/simulation.py b/simpeg/meta/simulation.py index ae9846475a..5ede0799c5 100644 --- a/simpeg/meta/simulation.py +++ b/simpeg/meta/simulation.py @@ -89,10 +89,10 @@ class MetaSimulation(BaseSimulation): _repeat_sim = False def __init__(self, simulations, mappings): - warnings.warn( - "The MetaSimulation class is a work in progress and might change in the future", - stacklevel=2, - ) + # warnings.warn( + # "The MetaSimulation class is a work in progress and might change in the future", + # stacklevel=2, + # ) self.simulations = simulations self.mappings = mappings self.model = None @@ -216,7 +216,7 @@ def fields(self, m): f.append(sim.fields(sim.model)) return f - def dpred(self, m=None, f=None): + def dpred(self, m=None, f=None, **kwargs): if f is None: if m is None: m = self.model @@ -225,7 +225,7 @@ def dpred(self, m=None, f=None): for mapping, sim, field in zip(self.mappings, self.simulations, f): if self._repeat_sim: sim.model = mapping * self.model - d_pred.append(sim.dpred(m=sim.model, f=field)) + d_pred.append(sim.dpred(m=sim.model, f=field, **kwargs)) return np.concatenate(d_pred) def Jvec(self, m, v, f=None): @@ -307,15 +307,13 @@ def getJtJdiag(self, m, W=None, f=None): # (i.e. projections, multipliers, etc.). # It is usually close within a scaling factor for others, whose accuracy is controlled # by how diagonally dominant JtJ is. - if f is None: - f = self.fields(m) - for i, (mapping, sim, field) in enumerate( - zip(self.mappings, self.simulations, f) - ): + for i, (mapping, sim) in enumerate(zip(self.mappings, self.simulations)): if self._repeat_sim: sim.model = mapping * self.model sim_w = sp.diags(W[self._data_offsets[i] : self._data_offsets[i + 1]]) - sim_jtj = sp.diags(np.sqrt(sim.getJtJdiag(sim.model, sim_w, f=field))) + sim_jtj = sp.diags( + np.sqrt(np.asarray(sim.getJtJdiag(sim.model, sim_w))) + ) m_deriv = mapping.deriv(self.model) jtj_diag += np.asarray( (sim_jtj @ m_deriv).power(2).sum(axis=0) diff --git a/simpeg/potential_fields/gravity/simulation.py b/simpeg/potential_fields/gravity/simulation.py index 1596b15ec2..7fcf032cd1 100644 --- a/simpeg/potential_fields/gravity/simulation.py +++ b/simpeg/potential_fields/gravity/simulation.py @@ -145,7 +145,7 @@ def __init__( self._sensitivity_gravity = _sensitivity_gravity_serial self._forward_gravity = _forward_gravity_serial - def fields(self, m): + def fields(self, m=None): """ Forward model the gravity field of the mesh on the receivers in the survey @@ -160,16 +160,18 @@ def fields(self, m): Gravity fields generated by the given model on every receiver location. """ - self.model = m + if m is not None: + self.model = m + if self.store_sensitivities == "forward_only": # Compute the linear operation without forming the full dense G if self.engine == "choclo": fields = self._forward(self.rho) else: - fields = mkvc(self.linear_operator()) + fields = self.linear_operator() else: fields = self.G @ (self.rho).astype(self.sensitivity_dtype, copy=False) - return np.asarray(fields) + return fields def getJtJdiag(self, m, W=None, f=None): """ @@ -201,7 +203,7 @@ def Jvec(self, m, v, f=None): Sensitivity times a vector """ dmu_dm_v = self.rhoDeriv @ v - return self.G @ dmu_dm_v.astype(self.sensitivity_dtype, copy=False) + return np.asarray(self.G @ dmu_dm_v.astype(self.sensitivity_dtype, copy=False)) def Jtvec(self, m, v, f=None): """ diff --git a/simpeg/potential_fields/magnetics/simulation.py b/simpeg/potential_fields/magnetics/simulation.py index f960e63810..228fc1de2e 100644 --- a/simpeg/potential_fields/magnetics/simulation.py +++ b/simpeg/potential_fields/magnetics/simulation.py @@ -178,9 +178,10 @@ def M(self, M): M = np.asarray(M) self._M = M.reshape((self.nC, 3)) - def fields(self, model): - self.model = model - # model = self.chiMap * model + def fields(self, m=None): + if m is not None: + self.model = m + if self.store_sensitivities == "forward_only": if self.engine == "choclo": fields = self._forward(self.chi) @@ -198,12 +199,14 @@ def fields(self, model): @property def G(self): + """ + Gravity forward operator + """ if getattr(self, "_G", None) is None: if self.engine == "choclo": self._G = self._sensitivity_matrix() else: self._G = self.linear_operator() - return self._G modelType = deprecate_property( @@ -243,8 +246,7 @@ def getJtJdiag(self, m, W=None, f=None): if getattr(self, "_gtg_diagonal", None) is None: diag = np.zeros(self.Jmatrix.shape[1]) if not self.is_amplitude_data: - for i in range(len(W)): - diag += W[i] * (self.Jmatrix[i] * self.Jmatrix[i]) + diag = np.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) else: ampDeriv = self.ampDeriv Gx = self.Jmatrix[::3]