From 917d2538c978b033d82f09183f082d60d5f4bccc Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 3 Dec 2024 19:59:33 -0800 Subject: [PATCH 01/84] Allow kwargs in metasim dpred --- simpeg/meta/simulation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/simpeg/meta/simulation.py b/simpeg/meta/simulation.py index ae9846475a..fd1bfd87d0 100644 --- a/simpeg/meta/simulation.py +++ b/simpeg/meta/simulation.py @@ -216,7 +216,7 @@ def fields(self, m): f.append(sim.fields(sim.model)) return f - def dpred(self, m=None, f=None): + def dpred(self, m=None, f=None, **kwargs): if f is None: if m is None: m = self.model @@ -225,7 +225,7 @@ def dpred(self, m=None, f=None): for mapping, sim, field in zip(self.mappings, self.simulations, f): if self._repeat_sim: sim.model = mapping * self.model - d_pred.append(sim.dpred(m=sim.model, f=field)) + d_pred.append(sim.dpred(m=sim.model, f=field, **kwargs)) return np.concatenate(d_pred) def Jvec(self, m, v, f=None): From 88750976e0848a661b34d0ce2bdac6ee1e06a4d5 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 4 Dec 2024 09:27:37 -0800 Subject: [PATCH 02/84] Start refactoring dask classes --- simpeg/dask/data_misfit.py | 8 ++++---- simpeg/dask/inverse_problem.py | 4 +--- simpeg/data_misfit.py | 16 ++++++++-------- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/simpeg/dask/data_misfit.py b/simpeg/dask/data_misfit.py index f01d646248..d278e56ae1 100644 --- a/simpeg/dask/data_misfit.py +++ b/simpeg/dask/data_misfit.py @@ -20,7 +20,7 @@ def dask_call(self, m, f=None): return phi_d -L2DataMisfit.__call__ = dask_call +# L2DataMisfit.__call__ = dask_call def dask_deriv(self, m, f=None): @@ -48,7 +48,7 @@ def dask_deriv(self, m, f=None): return Jtvec -L2DataMisfit.deriv = dask_deriv +# L2DataMisfit.deriv = dask_deriv def dask_deriv2(self, m, v, f=None): @@ -78,7 +78,7 @@ def dask_deriv2(self, m, v, f=None): return jtwjvec -L2DataMisfit.deriv2 = dask_deriv2 +# L2DataMisfit.deriv2 = dask_deriv2 def dask_residual(self, m, f=None): @@ -93,4 +93,4 @@ def dask_residual(self, m, f=None): raise Exception(f"Attribute f must be or type {Fields}, numpy.array or None.") -L2DataMisfit.residual = dask_residual +# L2DataMisfit.residual = dask_residual diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 9c0d8d058d..6f5a8620bc 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -75,9 +75,7 @@ def get_dpred(self, m, f=None, compute_J=False): else: vec = m - compute_sensitivities = compute_J and ( - objfct.simulation._Jmatrix is None - ) + compute_sensitivities = compute_J if compute_sensitivities and i == 0: print("Computing forward & sensitivities") diff --git a/simpeg/data_misfit.py b/simpeg/data_misfit.py index b796f78c21..05dabcbb41 100644 --- a/simpeg/data_misfit.py +++ b/simpeg/data_misfit.py @@ -359,16 +359,16 @@ def getJtJdiag(self, m): + "Cannot form the sensitivity explicitly" ) - mapping_deriv = self.model_map.deriv(m) - - if self.model_map is not None: - m = mapping_deriv @ m + # mapping_deriv = self.model_map.deriv(m) + # + # if self.model_map is not None: + # m = mapping_deriv @ m jtjdiag = self.simulation.getJtJdiag(m, W=self.W) - if self.model_map is not None: - jtjdiag = mkvc( - (sdiag(np.sqrt(jtjdiag)) @ mapping_deriv).power(2).sum(axis=0) - ) + # if self.model_map is not None: + # jtjdiag = mkvc( + # (sdiag(np.sqrt(jtjdiag)) @ mapping_deriv).power(2).sum(axis=0) + # ) return jtjdiag From 958ffb25e1d0fa603a6aee738d4399e2026c30ec Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 4 Dec 2024 10:01:41 -0800 Subject: [PATCH 03/84] Fix bad logic for re-compute (cherry picked from commit d1c7aa823e71d1f59f6e66cef8a0b6c11f8a35d6) --- simpeg/dask/potential_fields/base.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 3c5d05aedd..92424ad8b4 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -93,13 +93,13 @@ def dask_linear_operator(self): # Check that loaded kernel matches supplied data and mesh print("Zarr file detected with same shape and chunksize ... re-loading") return kernel - else: - print("Writing Zarr file to disk") - with ProgressBar(): - print("Saving kernel to zarr: " + sens_name) - kernel = array.to_zarr( - stack, sens_name, compute=True, return_stored=True, overwrite=True - ) + + print("Writing Zarr file to disk") + with ProgressBar(): + print("Saving kernel to zarr: " + sens_name) + kernel = array.to_zarr( + stack, sens_name, compute=True, return_stored=True, overwrite=True + ) elif forward_only: with ProgressBar(): print("Forward calculation: ") From 8d28d373b20953cc1b958c24fee480a7b61e1ee0 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 4 Dec 2024 15:19:05 -0800 Subject: [PATCH 04/84] First full run --- simpeg/dask/inverse_problem.py | 6 +++--- simpeg/dask/objective_function.py | 6 +++--- simpeg/inverse_problem.py | 6 +++--- simpeg/meta/simulation.py | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 6f5a8620bc..88eaf85aa2 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -59,7 +59,7 @@ def dask_getFields(self, m, store=False, deleteWarmstart=True): return f -BaseInvProblem.getFields = dask_getFields +# BaseInvProblem.getFields = dask_getFields def get_dpred(self, m, f=None, compute_J=False): @@ -126,7 +126,7 @@ def get_dpred(self, m, f=None, compute_J=False): return dpreds -BaseInvProblem.get_dpred = get_dpred +# BaseInvProblem.get_dpred = get_dpred def dask_evalFunction(self, m, return_g=True, return_H=True): @@ -136,7 +136,7 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): phi_d = 0 for (_, objfct), pred in zip(self.dmisfit, self.dpred): - residual = objfct.W * (objfct.data.dobs - pred) + residual = objfct.W * objfct.residual(m, pred) phi_d += np.vdot(residual, residual) phi_d = np.asarray(phi_d) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 72864df5bc..f88274a0a8 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -72,7 +72,7 @@ def dask_call(self, m, f=None): return np.sum(np.r_[multipliers][:, None] * np.vstack(fcts), axis=0).squeeze() -ComboObjectiveFunction.__call__ = dask_call +# ComboObjectiveFunction.__call__ = dask_call def dask_deriv(self, m, f=None): @@ -118,7 +118,7 @@ def dask_deriv(self, m, f=None): return np.sum(np.r_[multipliers][:, None] * np.vstack(g), axis=0).squeeze() -ComboObjectiveFunction.deriv = dask_deriv +# ComboObjectiveFunction.deriv = dask_deriv def dask_deriv2(self, m, v=None, f=None): @@ -164,4 +164,4 @@ def dask_deriv2(self, m, v=None, f=None): return phi_deriv2 -ComboObjectiveFunction.deriv2 = dask_deriv2 +# ComboObjectiveFunction.deriv2 = dask_deriv2 diff --git a/simpeg/inverse_problem.py b/simpeg/inverse_problem.py index e554c95cee..da60f5b2a3 100644 --- a/simpeg/inverse_problem.py +++ b/simpeg/inverse_problem.py @@ -283,14 +283,14 @@ def getFields(self, m, store=False, deleteWarmstart=True): return f - def get_dpred(self, m, f): + def get_dpred(self, m, f=None, compute_J=False): dpred = [] for i, objfct in enumerate(self.dmisfit.objfcts): if hasattr(objfct, "simulation"): - dpred += [objfct.simulation.dpred(m, f=f[i])] + dpred += [objfct.simulation.dpred(m, f=f if f is None else f[i], compute_J=compute_J)] else: dpred += [] - return np.hstack(dpred) + return dpred @timeIt def evalFunction(self, m, return_g=True, return_H=True): diff --git a/simpeg/meta/simulation.py b/simpeg/meta/simulation.py index fd1bfd87d0..f3937d4617 100644 --- a/simpeg/meta/simulation.py +++ b/simpeg/meta/simulation.py @@ -226,7 +226,7 @@ def dpred(self, m=None, f=None, **kwargs): if self._repeat_sim: sim.model = mapping * self.model d_pred.append(sim.dpred(m=sim.model, f=field, **kwargs)) - return np.concatenate(d_pred) + return d_pred def Jvec(self, m, v, f=None): self.model = m From fc4c72b8d06affedbb54a727ddcd23c3afee8a99 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 9 Dec 2024 14:03:19 -0800 Subject: [PATCH 05/84] Remove compute J out of dpred calls --- .../frequency_domain/simulation.py | 27 +++++++------------ .../static/induced_polarization/simulation.py | 15 +++-------- .../induced_polarization/simulation_2d.py | 5 ++-- .../static/resistivity/simulation.py | 13 ++++----- .../static/resistivity/simulation_2d.py | 21 +++++++-------- .../time_domain/simulation.py | 21 +++++++-------- simpeg/dask/inverse_problem.py | 2 +- simpeg/dask/potential_fields/base.py | 2 +- 8 files changed, 42 insertions(+), 64 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index eb8527ed8a..a8efe97824 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -87,7 +87,7 @@ def evaluate_receivers(block, mesh, fields): return np.hstack(data) -def dask_dpred(self, m=None, f=None, compute_J=False): +def dask_dpred(self, m=None, f=None): r""" dpred(m, f=None) Create the projected data from a model. @@ -110,7 +110,7 @@ def dask_dpred(self, m=None, f=None, compute_J=False): if f is None: if m is None: m = self.model - f = self.fields(m, return_Ainv=compute_J) + f = self.fields(m) all_receivers = [] @@ -136,10 +136,6 @@ def dask_dpred(self, m=None, f=None, compute_J=False): data = compute(array.hstack(rows))[0] - if compute_J and self._Jmatrix is None: - Jmatrix = self.compute_J(f=f) - return data, Jmatrix - return data @@ -147,7 +143,7 @@ def dask_dpred(self, m=None, f=None, compute_J=False): Sim.field_derivs = None -def fields(self, m=None, return_Ainv=False): +def fields(self, m=None): if m is not None: self.model = m @@ -160,14 +156,9 @@ def fields(self, m=None, return_Ainv=False): u = Ainv_solve * rhs sources = self.survey.get_sources_by_frequency(freq) f[sources, self._solutionType] = u + Ainv[freq] = Ainv_solve - if return_Ainv: - Ainv[freq] = Ainv_solve - else: - Ainv_solve.clean() - - if return_Ainv: - self.Ainv = Ainv + self.Ainv = Ainv return f @@ -177,7 +168,7 @@ def fields(self, m=None, return_Ainv=False): def compute_J(self, f=None): if f is None: - f = self.fields(self.model, return_Ainv=True) + f = self.fields(self.model) if len(self.Ainv) > 1: raise NotImplementedError( @@ -238,9 +229,11 @@ def compute_J(self, f=None): if self.store_sensitivities == "disk": del Jmatrix - return array.from_zarr(self.sensitivity_path) + self._Jmatrix = array.from_zarr(self.sensitivity_path) else: - return Jmatrix + self._Jmatrix = Jmatrix + + return self._Jmatrix Sim.compute_J = compute_J diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py index 81f2db5a0b..85430d9007 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py @@ -21,7 +21,7 @@ Sim.getSourceTerm = dask_getSourceTerm -def dask_fields(self, m=None, return_Ainv=False): +def dask_fields(self, m=None): if m is not None: self.model = m @@ -44,8 +44,7 @@ def dask_fields(self, m=None, return_Ainv=False): scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f) self._scale = scale.dobs - if return_Ainv: - self.Ainv = Ainv + self.Ainv = Ainv return f @@ -53,7 +52,7 @@ def dask_fields(self, m=None, return_Ainv=False): Sim.fields = dask_fields -def dask_dpred(self, m=None, f=None, compute_J=False): +def dask_dpred(self, m=None, f=None): r""" dpred(m, f=None) Create the projected data from a model. @@ -72,17 +71,9 @@ def dask_dpred(self, m=None, f=None, compute_J=False): "data. Please set the survey for the simulation: " "simulation.survey = survey" ) - if self._Jmatrix is None or self._scale is None: - if m is None: - m = self.model - f = self.fields(m, return_Ainv=True) - self._Jmatrix = self.compute_J(f=f) data = self.Jvec(m, m) - if compute_J: - return np.asarray(data), self._Jmatrix - return np.asarray(data) diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py index 963fc12451..b1f91fe1f4 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py @@ -21,7 +21,7 @@ Sim.dpred = dask_dpred -def dask_fields(self, m=None, return_Ainv=False): +def dask_fields(self, m=None): if m is not None: self.model = m @@ -50,8 +50,7 @@ def dask_fields(self, m=None, return_Ainv=False): scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f_fwd) self._scale = scale.dobs - if return_Ainv: - self.Ainv = Ainv + self.Ainv = Ainv return f diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index 292cbfa206..0083bada82 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -22,7 +22,7 @@ Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] -def dask_fields(self, m=None, return_Ainv=False): +def dask_fields(self, m=None): if m is not None: self.model = m @@ -33,8 +33,7 @@ def dask_fields(self, m=None, return_Ainv=False): f = self.fieldsPair(self) f[:, self._solutionType] = Ainv * RHS - if return_Ainv: - self.Ainv = Ainv + self.Ainv = Ainv return f @@ -45,7 +44,7 @@ def dask_fields(self, m=None, return_Ainv=False): def compute_J(self, f=None): if f is None: - f = self.fields(self.model, return_Ainv=True) + f = self.fields(self.model) m_size = self.model.size row_chunks = int( @@ -135,9 +134,11 @@ def compute_J(self, f=None): if self.store_sensitivities == "disk": del Jmatrix - return da.from_zarr(self.sensitivity_path + "J.zarr") + self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") else: - return Jmatrix + self._Jmatrix = Jmatrix + + return self._Jmatrix Sim.compute_J = compute_J diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py index 08b5ba08de..fd995a5970 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py @@ -17,7 +17,7 @@ Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] -def dask_fields(self, m=None, return_Ainv=False): +def dask_fields(self, m=None): if m is not None: self.model = m @@ -33,8 +33,7 @@ def dask_fields(self, m=None, return_Ainv=False): RHS = self.getRHS(ky) f[:, self._solutionType, iky] = Ainv[iky] * RHS - if return_Ainv: - self.Ainv = Ainv + self.Ainv = Ainv return f @@ -47,7 +46,7 @@ def compute_J(self, f=None): weights = self._quad_weights if f is None: - f = self.fields(self.model, return_Ainv=True) + f = self.fields(self.model) m_size = self.model.size row_chunks = int( @@ -135,15 +134,17 @@ def compute_J(self, f=None): if self.store_sensitivities == "disk": del Jmatrix - return da.from_zarr(self.sensitivity_path + "J.zarr") + self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") else: - return Jmatrix + self._Jmatrix = Jmatrix + + return self._Jmatrix Sim.compute_J = compute_J -def dask_dpred(self, m=None, f=None, compute_J=False): +def dask_dpred(self, m=None, f=None): r""" dpred(m, f=None) Create the projected data from a model. @@ -172,7 +173,7 @@ def dask_dpred(self, m=None, f=None, compute_J=False): if f is None: if m is None: m = self.model - f = self.fields(m, return_Ainv=compute_J) + f = self.fields(m) temp = np.empty(survey.nD) count = 0 @@ -182,10 +183,6 @@ def dask_dpred(self, m=None, f=None, compute_J=False): temp[count : count + len(d)] = d count += len(d) - if compute_J: - Jmatrix = self.compute_J(f=f) - return self._mini_survey_data(temp), Jmatrix - return self._mini_survey_data(temp) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 569d9f6ef1..ba95e88460 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -92,7 +92,7 @@ def _getField(self, name, ind, src_list): TimeFields._getField = _getField -def fields(self, m=None, return_Ainv=False): +def fields(self, m=None): if m is not None: self.model = m @@ -117,8 +117,7 @@ def fields(self, m=None, return_Ainv=False): sol = Ainv[dt] * rhs f[:, self._fieldType + "Solution", tInd + 1] = sol - if return_Ainv: - self.Ainv = Ainv + self.Ainv = Ainv return f @@ -179,7 +178,7 @@ def evaluate_receivers(block, mesh, time_mesh, fields, fields_array): return np.hstack(data) -def dask_dpred(self, m=None, f=None, compute_J=False): +def dask_dpred(self, m=None, f=None): r""" dpred(m, f=None) Create the projected data from a model. @@ -202,7 +201,7 @@ def dask_dpred(self, m=None, f=None, compute_J=False): if f is None: if m is None: m = self.model - f = self.fields(m, return_Ainv=compute_J) + f = self.fields(m) rows = [] receiver_projection = self.survey.source_list[0].receiver_list[0].projField @@ -234,10 +233,6 @@ def dask_dpred(self, m=None, f=None, compute_J=False): data = array.hstack(rows).compute() - if compute_J and self._Jmatrix is None: - Jmatrix = self.compute_J(f=f) - return data, Jmatrix - return data @@ -501,7 +496,7 @@ def compute_J(self, f=None): Compute the rows for the sensitivity matrix. """ if f is None: - f = self.fields(self.model, return_Ainv=True) + f = self.fields(self.model) ftype = self._fieldType + "Solution" sens_name = self.sensitivity_path[:-5] @@ -585,9 +580,11 @@ def compute_J(self, f=None): A.clean() if self.store_sensitivities == "ram": - return np.asarray(Jmatrix) + self._Jmatrix = np.asarray(Jmatrix) + + self._Jmatrix = Jmatrix - return Jmatrix + return self._Jmatrix Sim.compute_J = compute_J diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 88eaf85aa2..e773f2ffd3 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -126,7 +126,7 @@ def get_dpred(self, m, f=None, compute_J=False): return dpreds -# BaseInvProblem.get_dpred = get_dpred +BaseInvProblem.get_dpred = get_dpred def dask_evalFunction(self, m, return_g=True, return_H=True): diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 92424ad8b4..d8be1ba930 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -24,7 +24,7 @@ def chunk_format(self, other): Sim.chunk_format = chunk_format -def dask_dpred(self, m=None, f=None, compute_J=False): +def dask_dpred(self, m=None, f=None): if m is not None: self.model = m if f is not None: From 54739982a60a21ce166c4ea15bdbe74a193a1f66 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 9 Dec 2024 14:06:18 -0800 Subject: [PATCH 06/84] Add compute_J to meta sims --- simpeg/inverse_problem.py | 4 ++-- simpeg/meta/dask_sim.py | 16 ++++++++++++++++ simpeg/meta/simulation.py | 13 +++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/simpeg/inverse_problem.py b/simpeg/inverse_problem.py index da60f5b2a3..59617296c0 100644 --- a/simpeg/inverse_problem.py +++ b/simpeg/inverse_problem.py @@ -283,11 +283,11 @@ def getFields(self, m, store=False, deleteWarmstart=True): return f - def get_dpred(self, m, f=None, compute_J=False): + def get_dpred(self, m, f=None): dpred = [] for i, objfct in enumerate(self.dmisfit.objfcts): if hasattr(objfct, "simulation"): - dpred += [objfct.simulation.dpred(m, f=f if f is None else f[i], compute_J=compute_J)] + dpred += [objfct.simulation.dpred(m, f=f if f is None else f[i])] else: dpred += [] return dpred diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index bddf091920..21da9f1658 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -436,6 +436,22 @@ def getJtJdiag(self, m, W=None, f=None): return self._jtjdiag + def compute_J(self, m, f=None): + self.model = m + if f is None: + f = self.fields(m) + J = [] + client = self.client + for sim, worker, field in zip(self.simulations, self._workers, f): + J.append( + client.submit( + sim.compute_J, + field, + workers=worker, + ) + ) + return self.client.gather(J) + class DaskSumMetaSimulation(DaskMetaSimulation, SumMetaSimulation): """A dask distributed version of :class:`.SumMetaSimulation`. diff --git a/simpeg/meta/simulation.py b/simpeg/meta/simulation.py index f3937d4617..6130a22f96 100644 --- a/simpeg/meta/simulation.py +++ b/simpeg/meta/simulation.py @@ -324,6 +324,19 @@ def getJtJdiag(self, m, W=None, f=None): return self._jtjdiag + def compute_J(self, m, f=None): + self.model = m + if f is None: + f = self.fields(m) + J = [] + for sim, field in zip(self.simulations, f): + J.append( + sim.compute_J, + field, + workers=worker, + ) + return J + @property def deleteTheseOnModelUpdate(self): return super().deleteTheseOnModelUpdate + ["_jtjdiag"] From 269a36795307654ddb03241f4b99cce31e0f0e84 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 9 Dec 2024 14:45:36 -0800 Subject: [PATCH 07/84] Clean ups --- simpeg/dask/inverse_problem.py | 122 +++--------------- simpeg/dask/potential_fields/base.py | 2 +- .../dask/potential_fields/gravity/__init__.py | 0 .../potential_fields/gravity/simulation.py | 25 ---- .../potential_fields/magnetics/simulation.py | 6 +- simpeg/meta/simulation.py | 4 +- 6 files changed, 22 insertions(+), 137 deletions(-) delete mode 100644 simpeg/dask/potential_fields/gravity/__init__.py delete mode 100644 simpeg/dask/potential_fields/gravity/simulation.py diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index e773f2ffd3..7b28836f2f 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -1,128 +1,40 @@ from ..inverse_problem import BaseInvProblem import numpy as np -from time import time -from datetime import timedelta + + from dask.distributed import Future, get_client import dask.array as da from scipy.sparse.linalg import LinearOperator from ..regularization import WeightedLeastSquares, Sparse -from ..data_misfit import BaseDataMisfit -from ..objective_function import BaseObjectiveFunction, ComboObjectiveFunction - - -def dask_getFields(self, m, store=False, deleteWarmstart=True): - f = None - - # try: - # client = get_client() - # fields = lambda f, x, workers: client.compute(f(x), workers=workers) - # except: - # fields = lambda f, x: f(x) - - for mtest, u_ofmtest in self.warmstart: - if m is mtest: - f = u_ofmtest - if self.debug: - print("InvProb is Warm Starting!") - break - - if f is None: - if isinstance(self.dmisfit, BaseDataMisfit): - if self.dmisfit.model_map is not None: - vec = self.dmisfit.model_map @ m - else: - vec = m - - f = fields(self.dmisfit.simulation.fields, vec) - elif isinstance(self.dmisfit, BaseObjectiveFunction): - f = [] - for objfct in self.dmisfit.objfcts: - if hasattr(objfct, "simulation"): - if objfct.model_map is not None: - vec = objfct.model_map @ m - else: - vec = m +from ..objective_function import ComboObjectiveFunction - f += [fields(objfct.simulation.fields, vec, objfct.workers)] - else: - f += [] - if isinstance(f, Future) or isinstance(f[0], Future): - f = client.gather(f) - - if deleteWarmstart: - self.warmstart = [] - if store: - self.warmstart += [(m, f)] - - return f +def get_dpred(self, m, f=None, compute_J=False): + dpreds = [] + for i, objfct in enumerate(self.dmisfit.objfcts): -# BaseInvProblem.getFields = dask_getFields + if compute_J and i == 0: + print("Computing forward & sensitivities") + if f is not None: + fields = f[i] + else: + fields = objfct.simulation.fields(m) -def get_dpred(self, m, f=None, compute_J=False): - dpreds = [] + future = objfct.simulation.dpred(m, f=fields) - if isinstance(self.dmisfit, BaseDataMisfit): - return self.dmisfit.simulation.dpred(m) - elif isinstance(self.dmisfit, BaseObjectiveFunction): - for i, objfct in enumerate(self.dmisfit.objfcts): - if hasattr(objfct, "simulation"): - if getattr(objfct, "model_map", None) is not None: - vec = objfct.model_map @ m - else: - vec = m - - compute_sensitivities = compute_J - - if compute_sensitivities and i == 0: - print("Computing forward & sensitivities") - - if objfct.workers is not None: - client = get_client() - future = client.compute( - objfct.simulation.dpred(vec, compute_J=compute_sensitivities), - workers=objfct.workers, - ) - else: - # For locals, the future is now - ct = time() - - future = objfct.simulation.dpred( - vec, compute_J=compute_sensitivities - ) - - if compute_sensitivities: - runtime = time() - ct - total = len(self.dmisfit.objfcts) - - message = f"{i+1} of {total} in {timedelta(seconds=runtime)}. " - if (total - i - 1) > 0: - message += ( - f"ETA -> {timedelta(seconds=(total - i - 1) * runtime)}" - ) - print(message) - - dpreds += [future] + if compute_J: + objfct.simulation.compute_J(m, f=fields) - else: - dpreds += [] + dpreds += [future] if isinstance(dpreds[0], Future): client = get_client() dpreds = client.gather(dpreds) - preds = [] - if isinstance(dpreds[0], tuple): # Jmatrix was computed - for future, objfct in zip(dpreds, self.dmisfit.objfcts): - preds += [future[0]] - objfct.simulation._Jmatrix = future[1] - return preds - - else: - dpreds = da.compute(dpreds)[0] + dpreds = da.compute(dpreds)[0] return dpreds diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index d8be1ba930..3645bd613a 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -114,7 +114,7 @@ def dask_linear_operator(self): Sim.linear_operator = dask_linear_operator -def compute_J(self): +def compute_J(self, _): return self.linear_operator() diff --git a/simpeg/dask/potential_fields/gravity/__init__.py b/simpeg/dask/potential_fields/gravity/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/simpeg/dask/potential_fields/gravity/simulation.py b/simpeg/dask/potential_fields/gravity/simulation.py deleted file mode 100644 index 780e37057a..0000000000 --- a/simpeg/dask/potential_fields/gravity/simulation.py +++ /dev/null @@ -1,25 +0,0 @@ -import numpy as np -from ....potential_fields.gravity import Simulation3DIntegral as Sim -from ....utils import sdiag, mkvc - - -def dask_getJtJdiag(self, m, W=None, f=None): - """ - Return the diagonal of JtJ - """ - - self.model = m - - if W is None: - W = np.ones(self.nD) - else: - W = W.diagonal() - if getattr(self, "_gtg_diagonal", None) is None: - diag = ((W[:, None] * self.Jmatrix) ** 2).sum(axis=0).compute() - self._gtg_diagonal = diag - else: - diag = self._gtg_diagonal - return mkvc((sdiag(np.sqrt(diag)) @ self.rhoDeriv).power(2).sum(axis=0)) - - -Sim.getJtJdiag = dask_getJtJdiag diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index 5682066d2f..0444eecfa6 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -14,7 +14,7 @@ def dask_getJtJdiag(self, m, W=None, f=None): W = np.ones(self.nD) else: W = W.diagonal() - if getattr(self, "_gtg_diagonal", None) is None: + if getattr(self, "_jtj_diag", None) is None: if not self.is_amplitude_data: diag = ((W[:, None] * self.Jmatrix) ** 2).sum(axis=0).compute() else: @@ -25,9 +25,9 @@ def dask_getJtJdiag(self, m, W=None, f=None): + ampDeriv[2, :, None] * self.Jmatrix[2::3] ) diag = ((W[:, None] * J) ** 2).sum(axis=0).compute() - self._gtg_diagonal = diag + self._jtj_diag = diag else: - diag = self._gtg_diagonal + diag = self._jtj_diag return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0)) diff --git a/simpeg/meta/simulation.py b/simpeg/meta/simulation.py index 6130a22f96..3bae3c6228 100644 --- a/simpeg/meta/simulation.py +++ b/simpeg/meta/simulation.py @@ -331,9 +331,7 @@ def compute_J(self, m, f=None): J = [] for sim, field in zip(self.simulations, f): J.append( - sim.compute_J, - field, - workers=worker, + sim.compute_J(field), ) return J From e0ecc6849a9c650033ea599f79ee2a0be6b413e2 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 9 Dec 2024 15:11:16 -0800 Subject: [PATCH 08/84] Remove duplicate dask Jtvec methods --- .../frequency_domain/simulation.py | 5 +- .../induced_polarization/simulation_2d.py | 4 +- .../static/resistivity/simulation.py | 5 +- .../static/resistivity/simulation_2d.py | 4 +- .../time_domain/simulation.py | 4 +- simpeg/dask/potential_fields/base.py | 4 + .../dask/potential_fields/gravity/__init__.py | 0 .../potential_fields/gravity/simulation.py | 25 ++ simpeg/dask/simulation.py | 357 +++++++++--------- 9 files changed, 209 insertions(+), 199 deletions(-) create mode 100644 simpeg/dask/potential_fields/gravity/__init__.py create mode 100644 simpeg/dask/potential_fields/gravity/simulation.py diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index a8efe97824..a841e20f57 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -6,7 +6,7 @@ from dask import array, compute, delayed # from dask.distributed import get_client, Client, performance_report -from simpeg.dask.simulation import dask_Jvec, dask_Jtvec, dask_getJtJdiag +from simpeg.dask.simulation import dask_getJtJdiag from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary import zarr @@ -16,8 +16,7 @@ Sim.gtgdiag = None Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec + Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py index b1f91fe1f4..63f68068c0 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py @@ -9,15 +9,13 @@ Sim.sensitivity_path = "./sensitivity/" -from .simulation import dask_getJtJdiag, dask_Jvec, dask_Jtvec, dask_dpred +from .simulation import dask_getJtJdiag, dask_dpred from ..resistivity.simulation_2d import compute_J, dask_getSourceTerm Sim.compute_J = compute_J Sim.getSourceTerm = dask_getSourceTerm Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec Sim.dpred = dask_dpred diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index 0083bada82..7a7e598114 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -1,4 +1,4 @@ -from simpeg.dask.simulation import dask_dpred, dask_Jvec, dask_Jtvec, dask_getJtJdiag +from simpeg.dask.simulation import dask_dpred, dask_getJtJdiag from .....electromagnetics.static.resistivity.simulation import BaseDCSimulation as Sim from .....utils import Zero import dask.array as da @@ -17,8 +17,7 @@ Sim.dpred = dask_dpred Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec + Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py index fd995a5970..057f853caf 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py @@ -1,7 +1,7 @@ from .....electromagnetics.static.resistivity.simulation_2d import ( BaseDCSimulation2D as Sim, ) -from .simulation import dask_getJtJdiag, dask_Jvec, dask_Jtvec +from .simulation import dask_getJtJdiag import dask.array as da import numpy as np import zarr @@ -12,8 +12,6 @@ Sim.sensitivity_path = "./sensitivity/" Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index ba95e88460..0a6cf07543 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -9,7 +9,7 @@ import scipy.sparse as sp from dask import array, delayed -from simpeg.dask.simulation import dask_Jvec, dask_Jtvec, dask_getJtJdiag +from simpeg.dask.simulation import dask_getJtJdiag from simpeg.dask.utils import get_parallel_blocks from simpeg.utils import mkvc @@ -18,8 +18,6 @@ Sim.sensitivity_path = "./sensitivity/" Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 3645bd613a..10fe9bb2e5 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -5,6 +5,10 @@ from dask.diagnostics import ProgressBar from ..utils import compute_chunk_sizes + +from simpeg.dask.simulation import dask_getJtJdiag + +Sim.getJtJdiag = dask_getJtJdiag Sim._chunk_format = "row" diff --git a/simpeg/dask/potential_fields/gravity/__init__.py b/simpeg/dask/potential_fields/gravity/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/simpeg/dask/potential_fields/gravity/simulation.py b/simpeg/dask/potential_fields/gravity/simulation.py new file mode 100644 index 0000000000..780e37057a --- /dev/null +++ b/simpeg/dask/potential_fields/gravity/simulation.py @@ -0,0 +1,25 @@ +import numpy as np +from ....potential_fields.gravity import Simulation3DIntegral as Sim +from ....utils import sdiag, mkvc + + +def dask_getJtJdiag(self, m, W=None, f=None): + """ + Return the diagonal of JtJ + """ + + self.model = m + + if W is None: + W = np.ones(self.nD) + else: + W = W.diagonal() + if getattr(self, "_gtg_diagonal", None) is None: + diag = ((W[:, None] * self.Jmatrix) ** 2).sum(axis=0).compute() + self._gtg_diagonal = diag + else: + diag = self._gtg_diagonal + return mkvc((sdiag(np.sqrt(diag)) @ self.rhoDeriv).power(2).sum(axis=0)) + + +Sim.getJtJdiag = dask_getJtJdiag diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index 2e94bfcc4a..e19f9c572c 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -1,11 +1,9 @@ from ..simulation import BaseSimulation as Sim -from dask.distributed import get_client, Future + from dask import array, delayed -import multiprocessing -import warnings -from ..data import SyntheticData + import numpy as np -from .utils import compute + Sim._max_ram = 16 @@ -44,159 +42,178 @@ def max_chunk_size(self, other): Sim.max_chunk_size = max_chunk_size -@property -def n_cpu(self): - """Number of cpu's available.""" - if getattr(self, "_n_cpu", None) is None: - self._n_cpu = int(multiprocessing.cpu_count()) - return self._n_cpu - - -@n_cpu.setter -def n_cpu(self, other): - if other <= 0: - raise ValueError("n_cpu must be greater than 0") - self._n_cpu = other - - -Sim.n_cpu = n_cpu - - -def make_synthetic_data( - self, m, relative_error=0.05, noise_floor=0.0, f=None, add_noise=False, **kwargs -): - """ - Make synthetic data given a model, and a standard deviation. - :param numpy.ndarray m: geophysical model - :param numpy.ndarray relative_error: standard deviation - :param numpy.ndarray noise_floor: noise floor - :param numpy.ndarray f: fields for the given model (if pre-calculated) - """ - - std = kwargs.pop("std", None) - if std is not None: - warnings.warn( - "The std parameter will be deprecated in SimPEG 0.15.0. " - "Please use relative_error.", - DeprecationWarning, - stacklevel=2, - ) - relative_error = std - - dpred = self.dpred(m, f=f) - - if not isinstance(dpred, np.ndarray): - dpred = compute(self, dpred) - if isinstance(dpred, Future): - client = get_client() - dpred = client.gather(dpred) - - dclean = np.asarray(dpred) - - if add_noise is True: - std = relative_error * abs(dclean) + noise_floor - noise = std * np.random.randn(*dclean.shape) - dobs = dclean + noise - else: - dobs = dclean - - return SyntheticData( - survey=self.survey, - dobs=dobs, - dclean=dclean, - relative_error=relative_error, - noise_floor=noise_floor, - ) - - -Sim.make_synthetic_data = make_synthetic_data - - -@property -def workers(self): - if getattr(self, "_workers", None) is None: - self._workers = None - - return self._workers - - -@workers.setter -def workers(self, workers): - self._workers = workers - - -Sim.workers = workers - - -def dask_Jvec(self, m, v): +def dask_getJtJdiag(self, m, W=None, f=None): """ - Compute sensitivity matrix (J) and vector (v) product. - """ - self.model = m - - if isinstance(self.Jmatrix, np.ndarray): - return self.Jmatrix @ v.astype(np.float32) - - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish - - return array.dot(self.Jmatrix, v).astype(np.float32) - - -Sim.Jvec = dask_Jvec - - -def dask_Jtvec(self, m, v): - """ - Compute adjoint sensitivity matrix (J^T) and vector (v) product. - """ - self.model = m - - if isinstance(self.Jmatrix, np.ndarray): - return self.Jmatrix.T @ v.astype(np.float32) - - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish - - return array.dot(v, self.Jmatrix).astype(np.float32) - - -Sim.Jtvec = dask_Jtvec - - -@property -def Jmatrix(self): - """ - Sensitivity matrix stored on disk + Return the diagonal of JtJ """ - if getattr(self, "_Jmatrix", None) is None: - if self.workers is None: - self._Jmatrix = self.compute_J() - self._G = self._Jmatrix + if self._jtj_diag is None: + if self._Jmatrix is None: + self.compute_J(m, f=f) + # Need to check if multiplying weights makes sense + if W is None: + W = np.ones(self.Jmatrix.shape[0]) else: - client = get_client() # Assumes a Client already exists - - if self.store_sensitivities == "ram": - self._Jmatrix = client.persist( - delayed(self.compute_J)(), workers=self.workers - ) - else: - self._Jmatrix = client.compute( - delayed(self.compute_J)(), workers=self.workers - ) - - elif isinstance(self._Jmatrix, Future): - self._Jmatrix.result() - if self.store_sensitivities == "disk": - self._Jmatrix = array.from_zarr(self.sensitivity_path + "J.zarr") - - return self._Jmatrix - - -Sim.Jmatrix = Jmatrix - - -def dask_dpred(self, m=None, f=None, compute_J=False): + W = W.diagonal() + + self._jtj_diag = da.einsum("i,ij,ij->j", W**2, J, J) + + return self._jtj_diag + + +Sim.getJtJdiag = dask_getJtJdiag + +# @property +# def n_cpu(self): +# """Number of cpu's available.""" +# if getattr(self, "_n_cpu", None) is None: +# self._n_cpu = int(multiprocessing.cpu_count()) +# return self._n_cpu +# +# +# @n_cpu.setter +# def n_cpu(self, other): +# if other <= 0: +# raise ValueError("n_cpu must be greater than 0") +# self._n_cpu = other +# +# +# Sim.n_cpu = n_cpu +# +# def make_synthetic_data( +# self, m, relative_error=0.05, noise_floor=0.0, f=None, add_noise=False, **kwargs +# ): +# """ +# Make synthetic data given a model, and a standard deviation. +# :param numpy.ndarray m: geophysical model +# :param numpy.ndarray relative_error: standard deviation +# :param numpy.ndarray noise_floor: noise floor +# :param numpy.ndarray f: fields for the given model (if pre-calculated) +# """ +# +# std = kwargs.pop("std", None) +# if std is not None: +# warnings.warn( +# "The std parameter will be deprecated in SimPEG 0.15.0. " +# "Please use relative_error.", +# DeprecationWarning, +# stacklevel=2, +# ) +# relative_error = std +# +# dpred = self.dpred(m, f=f) +# +# if not isinstance(dpred, np.ndarray): +# dpred = compute(self, dpred) +# if isinstance(dpred, Future): +# client = get_client() +# dpred = client.gather(dpred) +# +# dclean = np.asarray(dpred) +# +# if add_noise is True: +# std = relative_error * abs(dclean) + noise_floor +# noise = std * np.random.randn(*dclean.shape) +# dobs = dclean + noise +# else: +# dobs = dclean +# +# return SyntheticData( +# survey=self.survey, +# dobs=dobs, +# dclean=dclean, +# relative_error=relative_error, +# noise_floor=noise_floor, +# ) +# +# +# Sim.make_synthetic_data = make_synthetic_data +# +# +# @property +# def workers(self): +# if getattr(self, "_workers", None) is None: +# self._workers = None +# +# return self._workers +# +# +# @workers.setter +# def workers(self, workers): +# self._workers = workers +# +# +# Sim.workers = workers +# +# +# def dask_Jvec(self, m, v): +# """ +# Compute sensitivity matrix (J) and vector (v) product. +# """ +# self.model = m +# +# if isinstance(self.Jmatrix, np.ndarray): +# return self.Jmatrix @ v.astype(np.float32) +# +# if isinstance(self.Jmatrix, Future): +# self.Jmatrix # Wait to finish +# +# return array.dot(self.Jmatrix, v).astype(np.float32) +# +# +# Sim.Jvec = dask_Jvec +# +# +# def dask_Jtvec(self, m, v): +# """ +# Compute adjoint sensitivity matrix (J^T) and vector (v) product. +# """ +# self.model = m +# +# if isinstance(self.Jmatrix, np.ndarray): +# return self.Jmatrix.T @ v.astype(np.float32) +# +# if isinstance(self.Jmatrix, Future): +# self.Jmatrix # Wait to finish +# +# return array.dot(v, self.Jmatrix).astype(np.float32) +# +# +# Sim.Jtvec = dask_Jtvec + +# +# @property +# def Jmatrix(self): +# """ +# Sensitivity matrix stored on disk +# """ +# if getattr(self, "_Jmatrix", None) is None: +# if self.workers is None: +# self._Jmatrix = self.compute_J() +# self._G = self._Jmatrix +# else: +# client = get_client() # Assumes a Client already exists +# +# if self.store_sensitivities == "ram": +# self._Jmatrix = client.persist( +# delayed(self.compute_J)(), workers=self.workers +# ) +# else: +# self._Jmatrix = client.compute( +# delayed(self.compute_J)(), workers=self.workers +# ) +# +# elif isinstance(self._Jmatrix, Future): +# self._Jmatrix.result() +# if self.store_sensitivities == "disk": +# self._Jmatrix = array.from_zarr(self.sensitivity_path + "J.zarr") +# +# return self._Jmatrix +# +# +# Sim.Jmatrix = Jmatrix + + +def dask_dpred(self, m=None, f=None): r""" dpred(m, f=None) Create the projected data from a model. @@ -219,7 +236,7 @@ def dask_dpred(self, m=None, f=None, compute_J=False): if f is None: if m is None: m = self.model - f = self.fields(m, return_Ainv=compute_J) + f = self.fields(m) def evaluate_receiver(source, receiver, mesh, fields): return receiver.eval(source, mesh, fields).flatten() @@ -238,35 +255,7 @@ def evaluate_receiver(source, receiver, mesh, fields): data = array.hstack(rows).compute() - if compute_J and self._Jmatrix is None: - Jmatrix = self.compute_J(f=f) - return data, Jmatrix - return data Sim.dpred = dask_dpred - - -def dask_getJtJdiag(self, m, W=None): - """ - Return the diagonal of JtJ - """ - self.model = m - if getattr(self, "_jtjdiag", None) is None: - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish - - if W is None: - W = np.ones(self.nD) - else: - W = W.diagonal() ** 2.0 - - diag = array.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) - - if isinstance(diag, array.Array): - diag = np.asarray(diag.compute()) - - self._jtjdiag = diag - - return self._jtjdiag From d0229d0fd8469ce0e28634cf5cb9193899a62b85 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 9 Dec 2024 18:23:32 -0800 Subject: [PATCH 09/84] Work in progress --- .../frequency_domain/simulation.py | 17 +- .../static/resistivity/simulation.py | 10 +- .../static/resistivity/simulation_2d.py | 10 +- .../time_domain/simulation.py | 19 +- simpeg/dask/inverse_problem.py | 12 +- simpeg/dask/potential_fields/base.py | 10 + simpeg/dask/simulation.py | 197 ++++-------------- simpeg/meta/simulation.py | 8 +- 8 files changed, 98 insertions(+), 185 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index a841e20f57..53be5a45e0 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -5,8 +5,7 @@ from multiprocessing import cpu_count from dask import array, compute, delayed -# from dask.distributed import get_client, Client, performance_report -from simpeg.dask.simulation import dask_getJtJdiag +from simpeg.dask.simulation import dask_getJtJdiag, dask_Jvec, dask_Jtvec from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary import zarr @@ -16,6 +15,8 @@ Sim.gtgdiag = None Sim.getJtJdiag = dask_getJtJdiag +Sim.Jvec = dask_Jvec +Sim.Jtvec = dask_Jtvec Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] @@ -165,9 +166,9 @@ def fields(self, m=None): Sim.fields = fields -def compute_J(self, f=None): +def compute_J(self, m, f=None): if f is None: - f = self.fields(self.model) + f = self.fields(m) if len(self.Ainv) > 1: raise NotImplementedError( @@ -176,7 +177,7 @@ def compute_J(self, f=None): ) A_i = list(self.Ainv.values())[0] - m_size = self.model.size + m_size = m.size if self.store_sensitivities == "disk": Jmatrix = zarr.open( @@ -220,7 +221,7 @@ def compute_J(self, f=None): desc=f"Sensitivities at {list(self.Ainv)[0]} Hz", ): Jmatrix = parallel_block_compute( - self, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks + self, m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks ) for A in self.Ainv.values(): @@ -239,9 +240,9 @@ def compute_J(self, f=None): def parallel_block_compute( - self, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses + self, m, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses ): - m_size = self.model.size + m_size = m.size block_stack = sp.hstack(blocks_receiver_derivs).toarray() ATinvdf_duT = delayed(A_i * block_stack) count = 0 diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index 7a7e598114..2f974624df 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -1,4 +1,4 @@ -from simpeg.dask.simulation import dask_dpred, dask_getJtJdiag +from simpeg.dask.simulation import dask_dpred, dask_getJtJdiag, dask_Jvec, dask_Jtvec from .....electromagnetics.static.resistivity.simulation import BaseDCSimulation as Sim from .....utils import Zero import dask.array as da @@ -17,6 +17,8 @@ Sim.dpred = dask_dpred Sim.getJtJdiag = dask_getJtJdiag +Sim.Jvec = dask_Jvec +Sim.Jtvec = dask_Jtvec Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] @@ -40,12 +42,12 @@ def dask_fields(self, m=None): Sim.fields = dask_fields -def compute_J(self, f=None): +def compute_J(self, m, f=None): if f is None: - f = self.fields(self.model) + f = self.fields(m) - m_size = self.model.size + m_size = m.size row_chunks = int( np.ceil( float(self.survey.nD) diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py index 057f853caf..ce3432379d 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py @@ -1,7 +1,7 @@ from .....electromagnetics.static.resistivity.simulation_2d import ( BaseDCSimulation2D as Sim, ) -from .simulation import dask_getJtJdiag +from .simulation import dask_getJtJdiag, dask_Jvec, dask_Jtvec import dask.array as da import numpy as np import zarr @@ -12,6 +12,8 @@ Sim.sensitivity_path = "./sensitivity/" Sim.getJtJdiag = dask_getJtJdiag +Sim.Jvec = dask_Jvec +Sim.Jtvec = dask_Jtvec Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] @@ -39,14 +41,14 @@ def dask_fields(self, m=None): Sim.fields = dask_fields -def compute_J(self, f=None): +def compute_J(self, m, f=None): kys = self._quad_points weights = self._quad_weights if f is None: - f = self.fields(self.model) + f = self.fields(m) - m_size = self.model.size + m_size = m.size row_chunks = int( np.ceil( float(self.survey.nD) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 0a6cf07543..40e31e42ea 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -9,7 +9,7 @@ import scipy.sparse as sp from dask import array, delayed -from simpeg.dask.simulation import dask_getJtJdiag +from simpeg.dask.simulation import dask_getJtJdiag, dask_Jvec, dask_Jtvec from simpeg.dask.utils import get_parallel_blocks from simpeg.utils import mkvc @@ -18,6 +18,9 @@ Sim.sensitivity_path = "./sensitivity/" Sim.getJtJdiag = dask_getJtJdiag +Sim.Jvec = dask_Jvec +Sim.Jtvec = dask_Jtvec + Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] @@ -489,19 +492,19 @@ def compute_rows( return np.vstack(rows) -def compute_J(self, f=None): +def compute_J(self, m, f=None): """ Compute the rows for the sensitivity matrix. """ if f is None: - f = self.fields(self.model) + f = self.fields(m) ftype = self._fieldType + "Solution" sens_name = self.sensitivity_path[:-5] if self.store_sensitivities == "disk": rows = array.zeros( - (self.survey.nD, self.model.size), - chunks=(self.max_chunk_size, self.model.size), + (self.survey.nD, m.size), + chunks=(self.max_chunk_size, m.size), dtype=np.float32, ) Jmatrix = array.to_zarr( @@ -512,11 +515,11 @@ def compute_J(self, f=None): overwrite=True, ) else: - Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float64) + Jmatrix = np.zeros((self.survey.nD, m.size), dtype=np.float64) simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 data_times = self.survey.source_list[0].receiver_list[0].times - compute_row_size = np.ceil(self.max_chunk_size / (self.model.shape[0] * 8.0 * 1e-6)) + compute_row_size = np.ceil(self.max_chunk_size / (m.shape[0] * 8.0 * 1e-6)) blocks = get_parallel_blocks(self.survey.source_list, compute_row_size) fields_array = f[:, ftype, :] @@ -557,7 +560,7 @@ def compute_J(self, f=None): dtype=np.float32, shape=( np.sum([len(chunk[1][0]) for chunk in block]), - self.model.size, + m.size, ), ) ) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 7b28836f2f..e0145300b6 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -44,11 +44,17 @@ def get_dpred(self, m, f=None, compute_J=False): def dask_evalFunction(self, m, return_g=True, return_H=True): """evalFunction(m, return_g=True, return_H=True)""" self.model = m - self.dpred = self.get_dpred(m, compute_J=return_H) + + # Store fields if doing a line-search + fields = self.getFields(m, store=(return_g is False and return_H is False)) + + # if isinstance(self.dmisfit, BaseDataMisfit): + phi_d = self.dmisfit(m, f=fields) + self.dpred = self.get_dpred(m, f=fields, compute_J=return_H) phi_d = 0 for (_, objfct), pred in zip(self.dmisfit, self.dpred): - residual = objfct.W * objfct.residual(m, pred) + residual = objfct.W * (objfct.data.dobs - pred) phi_d += np.vdot(residual, residual) phi_d = np.asarray(phi_d) @@ -105,7 +111,7 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): out = (phi,) if return_g: - phi_dDeriv = self.dmisfit.deriv(m, f=self.dpred) + phi_dDeriv = self.dmisfit.deriv(m, f=fields) # if hasattr(self.reg.objfcts[0], "space") and self.reg.objfcts[0].space == "spherical": phi_mDeriv = self.reg.deriv(m) # else: diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 10fe9bb2e5..42ee1c1328 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -123,3 +123,13 @@ def compute_J(self, _): Sim.compute_J = compute_J + + +@property +def Jmatrix(self): + if getattr(self, "_Jmatrix", None) is None: + self._Jmatrix = self.linear_operator() + return self._Jmatrix + + +Sim.Jmatrix = Jmatrix diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index e19f9c572c..2f48276597 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -46,172 +46,20 @@ def dask_getJtJdiag(self, m, W=None, f=None): """ Return the diagonal of JtJ """ - if self._jtj_diag is None: - if self._Jmatrix is None: - self.compute_J(m, f=f) - # Need to check if multiplying weights makes sense + if getattr(self, "_jtjdiag", None) is None: + self.model = m if W is None: W = np.ones(self.Jmatrix.shape[0]) else: W = W.diagonal() - self._jtj_diag = da.einsum("i,ij,ij->j", W**2, J, J) + self._jtj_diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) return self._jtj_diag Sim.getJtJdiag = dask_getJtJdiag -# @property -# def n_cpu(self): -# """Number of cpu's available.""" -# if getattr(self, "_n_cpu", None) is None: -# self._n_cpu = int(multiprocessing.cpu_count()) -# return self._n_cpu -# -# -# @n_cpu.setter -# def n_cpu(self, other): -# if other <= 0: -# raise ValueError("n_cpu must be greater than 0") -# self._n_cpu = other -# -# -# Sim.n_cpu = n_cpu -# -# def make_synthetic_data( -# self, m, relative_error=0.05, noise_floor=0.0, f=None, add_noise=False, **kwargs -# ): -# """ -# Make synthetic data given a model, and a standard deviation. -# :param numpy.ndarray m: geophysical model -# :param numpy.ndarray relative_error: standard deviation -# :param numpy.ndarray noise_floor: noise floor -# :param numpy.ndarray f: fields for the given model (if pre-calculated) -# """ -# -# std = kwargs.pop("std", None) -# if std is not None: -# warnings.warn( -# "The std parameter will be deprecated in SimPEG 0.15.0. " -# "Please use relative_error.", -# DeprecationWarning, -# stacklevel=2, -# ) -# relative_error = std -# -# dpred = self.dpred(m, f=f) -# -# if not isinstance(dpred, np.ndarray): -# dpred = compute(self, dpred) -# if isinstance(dpred, Future): -# client = get_client() -# dpred = client.gather(dpred) -# -# dclean = np.asarray(dpred) -# -# if add_noise is True: -# std = relative_error * abs(dclean) + noise_floor -# noise = std * np.random.randn(*dclean.shape) -# dobs = dclean + noise -# else: -# dobs = dclean -# -# return SyntheticData( -# survey=self.survey, -# dobs=dobs, -# dclean=dclean, -# relative_error=relative_error, -# noise_floor=noise_floor, -# ) -# -# -# Sim.make_synthetic_data = make_synthetic_data -# -# -# @property -# def workers(self): -# if getattr(self, "_workers", None) is None: -# self._workers = None -# -# return self._workers -# -# -# @workers.setter -# def workers(self, workers): -# self._workers = workers -# -# -# Sim.workers = workers -# -# -# def dask_Jvec(self, m, v): -# """ -# Compute sensitivity matrix (J) and vector (v) product. -# """ -# self.model = m -# -# if isinstance(self.Jmatrix, np.ndarray): -# return self.Jmatrix @ v.astype(np.float32) -# -# if isinstance(self.Jmatrix, Future): -# self.Jmatrix # Wait to finish -# -# return array.dot(self.Jmatrix, v).astype(np.float32) -# -# -# Sim.Jvec = dask_Jvec -# -# -# def dask_Jtvec(self, m, v): -# """ -# Compute adjoint sensitivity matrix (J^T) and vector (v) product. -# """ -# self.model = m -# -# if isinstance(self.Jmatrix, np.ndarray): -# return self.Jmatrix.T @ v.astype(np.float32) -# -# if isinstance(self.Jmatrix, Future): -# self.Jmatrix # Wait to finish -# -# return array.dot(v, self.Jmatrix).astype(np.float32) -# -# -# Sim.Jtvec = dask_Jtvec - -# -# @property -# def Jmatrix(self): -# """ -# Sensitivity matrix stored on disk -# """ -# if getattr(self, "_Jmatrix", None) is None: -# if self.workers is None: -# self._Jmatrix = self.compute_J() -# self._G = self._Jmatrix -# else: -# client = get_client() # Assumes a Client already exists -# -# if self.store_sensitivities == "ram": -# self._Jmatrix = client.persist( -# delayed(self.compute_J)(), workers=self.workers -# ) -# else: -# self._Jmatrix = client.compute( -# delayed(self.compute_J)(), workers=self.workers -# ) -# -# elif isinstance(self._Jmatrix, Future): -# self._Jmatrix.result() -# if self.store_sensitivities == "disk": -# self._Jmatrix = array.from_zarr(self.sensitivity_path + "J.zarr") -# -# return self._Jmatrix -# -# -# Sim.Jmatrix = Jmatrix - def dask_dpred(self, m=None, f=None): r""" @@ -259,3 +107,42 @@ def evaluate_receiver(source, receiver, mesh, fields): Sim.dpred = dask_dpred + + +def dask_Jvec(self, m, v, **_): + """ + Compute sensitivity matrix (J) and vector (v) product. + """ + self.model = m + + if isinstance(self.Jmatrix, np.ndarray): + return self.Jmatrix @ v.astype(np.float32) + + return array.dot(self.Jmatrix, v).astype(np.float32) + + +def dask_Jtvec(self, m, v, **_): + """ + Compute adjoint sensitivity matrix (J^T) and vector (v) product. + """ + self.model = m + + if isinstance(self.Jmatrix, np.ndarray): + return self.Jmatrix.T @ v.astype(np.float32) + + return array.dot(v, self.Jmatrix).astype(np.float32) + + +@property +def Jmatrix(self): + """ + Sensitivity matrix stored on disk + Return the diagonal of JtJ + """ + if getattr(self, "_Jmatrix", None) is None: + self._Jmatrix = self.compute_J(self.model) + + return self._Jmatrix + + +Sim.Jmatrix = Jmatrix diff --git a/simpeg/meta/simulation.py b/simpeg/meta/simulation.py index 3bae3c6228..27a99588cd 100644 --- a/simpeg/meta/simulation.py +++ b/simpeg/meta/simulation.py @@ -226,7 +226,7 @@ def dpred(self, m=None, f=None, **kwargs): if self._repeat_sim: sim.model = mapping * self.model d_pred.append(sim.dpred(m=sim.model, f=field, **kwargs)) - return d_pred + return np.concatenate(d_pred) def Jvec(self, m, v, f=None): self.model = m @@ -315,7 +315,9 @@ def getJtJdiag(self, m, W=None, f=None): if self._repeat_sim: sim.model = mapping * self.model sim_w = sp.diags(W[self._data_offsets[i] : self._data_offsets[i + 1]]) - sim_jtj = sp.diags(np.sqrt(sim.getJtJdiag(sim.model, sim_w, f=field))) + sim_jtj = sp.diags( + np.sqrt(np.asarray(sim.getJtJdiag(sim.model, sim_w, f=field))) + ) m_deriv = mapping.deriv(self.model) jtj_diag += np.asarray( (sim_jtj @ m_deriv).power(2).sum(axis=0) @@ -331,7 +333,7 @@ def compute_J(self, m, f=None): J = [] for sim, field in zip(self.simulations, f): J.append( - sim.compute_J(field), + sim.compute_J(m, field), ) return J From e4cb2ba011b5e0cf0fd04ba900bbab0d84d8054f Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 10 Dec 2024 11:49:30 -0800 Subject: [PATCH 10/84] Skip Jmatrix if not None --- simpeg/dask/inverse_problem.py | 2 -- simpeg/meta/simulation.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index e0145300b6..82de892568 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -3,7 +3,6 @@ from dask.distributed import Future, get_client -import dask.array as da from scipy.sparse.linalg import LinearOperator from ..regularization import WeightedLeastSquares, Sparse @@ -34,7 +33,6 @@ def get_dpred(self, m, f=None, compute_J=False): client = get_client() dpreds = client.gather(dpreds) - dpreds = da.compute(dpreds)[0] return dpreds diff --git a/simpeg/meta/simulation.py b/simpeg/meta/simulation.py index 27a99588cd..75bbd47111 100644 --- a/simpeg/meta/simulation.py +++ b/simpeg/meta/simulation.py @@ -332,6 +332,8 @@ def compute_J(self, m, f=None): f = self.fields(m) J = [] for sim, field in zip(self.simulations, f): + if getattr(sim, "_Jmatrix", None) is not None: + continue J.append( sim.compute_J(m, field), ) From db3f85ad92198736566c4d5efdcf047baa4f2c34 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 10 Dec 2024 13:22:12 -0800 Subject: [PATCH 11/84] Stash fields --- .../dask/electromagnetics/frequency_domain/simulation.py | 9 ++++++++- .../electromagnetics/static/resistivity/simulation.py | 7 ++++++- .../electromagnetics/static/resistivity/simulation_2d.py | 6 +++++- simpeg/dask/electromagnetics/time_domain/simulation.py | 7 +++++-- simpeg/dask/inverse_problem.py | 3 --- 5 files changed, 24 insertions(+), 8 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 53be5a45e0..615a063d73 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -18,7 +18,7 @@ Sim.Jvec = dask_Jvec Sim.Jtvec = dask_Jtvec -Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] +Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] @delayed @@ -147,6 +147,9 @@ def fields(self, m=None): if m is not None: self.model = m + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields + f = self.fieldsPair(self) Ainv = {} for freq in self.survey.frequencies: @@ -160,6 +163,8 @@ def fields(self, m=None): self.Ainv = Ainv + self._stashed_fields = f + return f @@ -167,6 +172,8 @@ def fields(self, m=None): def compute_J(self, m, f=None): + self.model = m + if f is None: f = self.fields(m) diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index 2f974624df..521bc8b34b 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -20,13 +20,16 @@ Sim.Jvec = dask_Jvec Sim.Jtvec = dask_Jtvec -Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] +Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] def dask_fields(self, m=None): if m is not None: self.model = m + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields + A = self.getA() Ainv = self.solver(A, **self.solver_opts) RHS = self.getRHS() @@ -36,6 +39,8 @@ def dask_fields(self, m=None): self.Ainv = Ainv + self._stashed_fields = f + return f diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py index ce3432379d..e71a7f1b24 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py @@ -14,13 +14,16 @@ Sim.getJtJdiag = dask_getJtJdiag Sim.Jvec = dask_Jvec Sim.Jtvec = dask_Jtvec -Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] +Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] def dask_fields(self, m=None): if m is not None: self.model = m + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields + kys = self._quad_points f = self.fieldsPair(self) f._quad_weights = self._quad_weights @@ -35,6 +38,7 @@ def dask_fields(self, m=None): self.Ainv = Ainv + self._stashed_fields = f return f diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 40e31e42ea..0800a3cac4 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -21,7 +21,7 @@ Sim.Jvec = dask_Jvec Sim.Jtvec = dask_Jtvec -Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] +Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] @delayed @@ -97,6 +97,9 @@ def fields(self, m=None): if m is not None: self.model = m + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields + f = self.fieldsPair(self) f[:, self._fieldType + "Solution", 0] = self.getInitialFields() Ainv = {} @@ -119,7 +122,7 @@ def fields(self, m=None): f[:, self._fieldType + "Solution", tInd + 1] = sol self.Ainv = Ainv - + self._stashed_fields = f return f diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 82de892568..a8dd98f6e0 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -14,9 +14,6 @@ def get_dpred(self, m, f=None, compute_J=False): for i, objfct in enumerate(self.dmisfit.objfcts): - if compute_J and i == 0: - print("Computing forward & sensitivities") - if f is not None: fields = f[i] else: From 025781744dff46af9f0862542ee1677aaa710de3 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 10 Dec 2024 15:35:04 -0800 Subject: [PATCH 12/84] Remove dask misfit module --- simpeg/dask/__init__.py | 1 - simpeg/dask/data_misfit.py | 96 -------------------------------------- 2 files changed, 97 deletions(-) delete mode 100644 simpeg/dask/data_misfit.py diff --git a/simpeg/dask/__init__.py b/simpeg/dask/__init__.py index f5a00b7334..89fee4fcd9 100644 --- a/simpeg/dask/__init__.py +++ b/simpeg/dask/__init__.py @@ -10,7 +10,6 @@ import simpeg.dask.potential_fields.gravity.simulation import simpeg.dask.potential_fields.magnetics.simulation import simpeg.dask.simulation - import simpeg.dask.data_misfit import simpeg.dask.inverse_problem import simpeg.dask.objective_function diff --git a/simpeg/dask/data_misfit.py b/simpeg/dask/data_misfit.py deleted file mode 100644 index d278e56ae1..0000000000 --- a/simpeg/dask/data_misfit.py +++ /dev/null @@ -1,96 +0,0 @@ -import numpy as np - -from ..data_misfit import L2DataMisfit -from ..fields import Fields -from ..utils import mkvc -from .utils import compute -import dask.array as da -from scipy.sparse import csr_matrix as csr -from dask import delayed - - -def dask_call(self, m, f=None): - """ - Distributed :obj:`simpeg.data_misfit.L2DataMisfit.__call__` - """ - R = self.W * self.residual(m, f=f) - phi_d = da.dot(R, R) - if not isinstance(phi_d, np.ndarray): - return compute(self, phi_d) - return phi_d - - -# L2DataMisfit.__call__ = dask_call - - -def dask_deriv(self, m, f=None): - """ - Distributed :obj:`simpeg.data_misfit.L2DataMisfit.deriv` - """ - mapping_deriv = self.model_map.deriv(m) - if getattr(self, "model_map", None) is not None: - m = self.model_map @ m - - wtw_d = self.W.diagonal() ** 2.0 * self.residual(m, f=f) - Jtvec = compute(self, self.simulation.Jtvec(m, wtw_d)) - - if getattr(self, "model_map", None) is not None: - Jtjvec_dmudm = delayed(csr.dot)(Jtvec, mapping_deriv) - h_vec = da.from_delayed( - Jtjvec_dmudm, dtype=float, shape=[mapping_deriv.shape[1]] - ) - if not isinstance(h_vec, np.ndarray): - return compute(self, h_vec) - return h_vec - - if not isinstance(Jtvec, np.ndarray): - return compute(self, Jtvec) - return Jtvec - - -# L2DataMisfit.deriv = dask_deriv - - -def dask_deriv2(self, m, v, f=None): - """ - Distributed :obj:`simpeg.data_misfit.L2DataMisfit.deriv2` - """ - mapping_deriv = self.model_map.deriv(m) - if getattr(self, "model_map", None) is not None: - m = self.model_map @ m - v = mapping_deriv @ v - - jvec = compute(self, self.simulation.Jvec(m, v)) - w_jvec = self.W.diagonal() ** 2.0 * jvec - jtwjvec = compute(self, self.simulation.Jtvec(m, w_jvec)) - - if getattr(self, "model_map", None) is not None: - Jtjvec_dmudm = delayed(csr.dot)(jtwjvec, mapping_deriv) - h_vec = da.from_delayed( - Jtjvec_dmudm, dtype=float, shape=[mapping_deriv.shape[1]] - ) - if not isinstance(h_vec, np.ndarray): - return compute(self, h_vec) - return h_vec - - if not isinstance(jtwjvec, np.ndarray): - return compute(self, jtwjvec) - return jtwjvec - - -# L2DataMisfit.deriv2 = dask_deriv2 - - -def dask_residual(self, m, f=None): - if self.data is None: - raise Exception("data must be set before a residual can be calculated.") - - if isinstance(f, Fields) or f is None: - return self.simulation.residual(m, self.data.dobs, f=f) - elif f.shape == self.data.dobs.shape: - return mkvc(f - self.data.dobs) - else: - raise Exception(f"Attribute f must be or type {Fields}, numpy.array or None.") - - -# L2DataMisfit.residual = dask_residual From 6b1213fbe0c540e75e8a5427cd1652e3cbc877f1 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 11 Dec 2024 13:24:38 -0800 Subject: [PATCH 13/84] remove gather from meta.dask_sim, bring back dask dmisfit --- simpeg/dask/__init__.py | 1 + simpeg/dask/data_misfit.py | 73 +++++++ simpeg/dask/inverse_problem.py | 2 +- simpeg/dask/objective_function.py | 21 +- .../potential_fields/gravity/simulation.py | 35 +-- .../potential_fields/magnetics/simulation.py | 53 ++--- simpeg/dask/simulation.py | 201 ++++++++++++------ simpeg/directives/directives.py | 64 ++---- simpeg/meta/dask_sim.py | 38 +++- simpeg/potential_fields/gravity/simulation.py | 2 + 10 files changed, 312 insertions(+), 178 deletions(-) create mode 100644 simpeg/dask/data_misfit.py diff --git a/simpeg/dask/__init__.py b/simpeg/dask/__init__.py index 89fee4fcd9..960be9da43 100644 --- a/simpeg/dask/__init__.py +++ b/simpeg/dask/__init__.py @@ -1,5 +1,6 @@ try: import simpeg.dask.simulation + import simpeg.dask.data_misfit import simpeg.dask.electromagnetics.frequency_domain.simulation import simpeg.dask.electromagnetics.static.resistivity.simulation import simpeg.dask.electromagnetics.static.resistivity.simulation_2d diff --git a/simpeg/dask/data_misfit.py b/simpeg/dask/data_misfit.py new file mode 100644 index 0000000000..86934f4918 --- /dev/null +++ b/simpeg/dask/data_misfit.py @@ -0,0 +1,73 @@ +import numpy as np + +from ..data_misfit import L2DataMisfit + +from ..utils import mkvc + +from dask.distributed import get_client, Future + + +def _data_residual(dpred, dobs): + return mkvc(dpred) - dobs + + +def _misfit(residual, W): + vec = W * residual + return np.dot(vec, vec) + + +def dask_call(self, m, f=None): + """ + Distributed :obj:`simpeg.data_misfit.L2DataMisfit.__call__` + """ + dpred = self.simulation.dpred(m, f=f) + + if isinstance(dpred, Future): + client = get_client() + residuals = client.submit(_data_residual, dpred, self.data.dobs) + phi_d = client.submit(_misfit, residuals, self.W) + else: + residuals = _data_residual(dpred, self.data.dobs) + phi_d = _misfit(residuals, self.W) + + return phi_d + + +L2DataMisfit.__call__ = dask_call + + +def dask_deriv(self, m, f=None): + """ + Distributed :obj:`simpeg.data_misfit.L2DataMisfit.deriv` + """ + wtw_d = self.W.diagonal() ** 2.0 * self.residual(m, f=f) + Jtvec = self, self.simulation.Jtvec(m, wtw_d) + + return Jtvec + + +L2DataMisfit.deriv = dask_deriv + + +def _stack_futures(futures, W): + return W * np.concatenate(futures) + + +def dask_deriv2(self, m, v, f=None): + """ + Distributed :obj:`simpeg.data_misfit.L2DataMisfit.deriv2` + """ + jvec = self.simulation.Jvec(m, v) + if isinstance(jvec, Future): + client = get_client() + w_jvec = client.submit(_stack_futures, jvec, self.W.diagonal() ** 2.0) + + else: + w_jvec = self.W.diagonal() ** 2.0 * jvec + + jtwjvec = self.simulation.Jtvec(m, w_jvec) + + return jtwjvec + + +L2DataMisfit.deriv2 = dask_deriv2 diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index a8dd98f6e0..0b5a2d6531 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -28,7 +28,7 @@ def get_dpred(self, m, f=None, compute_J=False): if isinstance(dpreds[0], Future): client = get_client() - dpreds = client.gather(dpreds) + dpreds = np.concatenate(client.gather(dpreds)).flatten() return dpreds diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index f88274a0a8..2ad57299ff 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -6,8 +6,6 @@ from dask.distributed import Future, get_client, Client from ..data_misfit import L2DataMisfit -BaseObjectiveFunction._workers = None - @property def client(self): @@ -26,19 +24,6 @@ def client(self, client): BaseObjectiveFunction.client = client -@property -def workers(self): - return self._workers - - -@workers.setter -def workers(self, workers): - self._workers = workers - - -BaseObjectiveFunction.workers = workers - - def dask_call(self, m, f=None): fcts = [] multipliers = [] @@ -72,7 +57,7 @@ def dask_call(self, m, f=None): return np.sum(np.r_[multipliers][:, None] * np.vstack(fcts), axis=0).squeeze() -# ComboObjectiveFunction.__call__ = dask_call +ComboObjectiveFunction.__call__ = dask_call def dask_deriv(self, m, f=None): @@ -118,7 +103,7 @@ def dask_deriv(self, m, f=None): return np.sum(np.r_[multipliers][:, None] * np.vstack(g), axis=0).squeeze() -# ComboObjectiveFunction.deriv = dask_deriv +ComboObjectiveFunction.deriv = dask_deriv def dask_deriv2(self, m, v=None, f=None): @@ -164,4 +149,4 @@ def dask_deriv2(self, m, v=None, f=None): return phi_deriv2 -# ComboObjectiveFunction.deriv2 = dask_deriv2 +ComboObjectiveFunction.deriv2 = dask_deriv2 diff --git a/simpeg/dask/potential_fields/gravity/simulation.py b/simpeg/dask/potential_fields/gravity/simulation.py index 780e37057a..b380579ce0 100644 --- a/simpeg/dask/potential_fields/gravity/simulation.py +++ b/simpeg/dask/potential_fields/gravity/simulation.py @@ -1,25 +1,32 @@ import numpy as np +from dask import array from ....potential_fields.gravity import Simulation3DIntegral as Sim +from ...simulation import BaseSimulation from ....utils import sdiag, mkvc -def dask_getJtJdiag(self, m, W=None, f=None): +class Simulation3DIntegral(BaseSimulation, Sim): """ - Return the diagonal of JtJ + Overload the Simulation3DIntegral class to use Dask """ - self.model = m + def getJtJdiag(self, m, W=None, f=None): + """ + Return the diagonal of JtJ + """ - if W is None: - W = np.ones(self.nD) - else: - W = W.diagonal() - if getattr(self, "_gtg_diagonal", None) is None: - diag = ((W[:, None] * self.Jmatrix) ** 2).sum(axis=0).compute() - self._gtg_diagonal = diag - else: - diag = self._gtg_diagonal - return mkvc((sdiag(np.sqrt(diag)) @ self.rhoDeriv).power(2).sum(axis=0)) + self.model = m + if W is None: + W = np.ones(self.Jmatrix.shape[0]) + else: + W = W.diagonal() + if getattr(self, "_gtg_diagonal", None) is None: + diag = array.einsum( + "i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix + ).compute() + self._gtg_diagonal = diag + else: + diag = self._gtg_diagonal -Sim.getJtJdiag = dask_getJtJdiag + return mkvc((sdiag(np.sqrt(diag)) @ self.rhoDeriv).power(2).sum(axis=0)) diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index 0444eecfa6..2dff14ae10 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -1,35 +1,40 @@ import numpy as np from ....potential_fields.magnetics import Simulation3DIntegral as Sim from ....utils import sdiag, mkvc +from ..base import Jmatrix +Sim.Jmatrix = Jmatrix -def dask_getJtJdiag(self, m, W=None, f=None): + +class Simulation3DIntegral(Sim): """ - Return the diagonal of JtJ + Overwrite the dask_getJtJdiag method """ - self.model = m - - if W is None: - W = np.ones(self.nD) - else: - W = W.diagonal() - if getattr(self, "_jtj_diag", None) is None: - if not self.is_amplitude_data: - diag = ((W[:, None] * self.Jmatrix) ** 2).sum(axis=0).compute() - else: - ampDeriv = self.ampDeriv - J = ( - ampDeriv[0, :, None] * self.Jmatrix[::3] - + ampDeriv[1, :, None] * self.Jmatrix[1::3] - + ampDeriv[2, :, None] * self.Jmatrix[2::3] - ) - diag = ((W[:, None] * J) ** 2).sum(axis=0).compute() - self._jtj_diag = diag - else: - diag = self._jtj_diag + def getJtJdiag(self, m, W=None, f=None): + """ + Return the diagonal of JtJ + """ - return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0)) + self.model = m + if W is None: + W = np.ones(self.nD) + else: + W = W.diagonal() + if getattr(self, "_jtj_diag", None) is None: + if not self.is_amplitude_data: + diag = ((W[:, None] * self.Jmatrix) ** 2).sum(axis=0).compute() + else: + ampDeriv = self.ampDeriv + J = ( + ampDeriv[0, :, None] * self.Jmatrix[::3] + + ampDeriv[1, :, None] * self.Jmatrix[1::3] + + ampDeriv[2, :, None] * self.Jmatrix[2::3] + ) + diag = ((W[:, None] * J) ** 2).sum(axis=0).compute() + self._jtj_diag = diag + else: + diag = self._jtj_diag -Sim.getJtJdiag = dask_getJtJdiag + return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0)) diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index 2f48276597..4ce775d4dc 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -5,41 +5,153 @@ import numpy as np -Sim._max_ram = 16 +class BaseSimulation(Sim): + """ + Base class for SimPEG simulations + """ + _max_ram = 16 + _max_chunk_size = 128 + + @property + def max_ram(self): + "Maximum ram in (Gb)" + return self._max_ram + + @max_ram.setter + def max_ram(self, other): + if other <= 0: + raise ValueError("max_ram must be greater than 0") + self._max_ram = other + + @property + def max_chunk_size(self): + "Largest chunk size (Mb) used by Dask" + return self._max_chunk_size + + @max_chunk_size.setter + def max_chunk_size(self, other): + if other <= 0: + raise ValueError("max_chunk_size must be greater than 0") + self._max_chunk_size = other + + def getJtJdiag(self, m, W=None, f=None): + """ + Return the diagonal of JtJ + """ + if getattr(self, "_jtjdiag", None) is None: + self.model = m + if W is None: + W = np.ones(self.Jmatrix.shape[0]) + else: + W = W.diagonal() + + self._jtj_diag = array.einsum( + "i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix + ) -@property -def max_ram(self): - "Maximum ram in (Gb)" - return self._max_ram + return self._jtj_diag + + # def dpred(self, m=None, f=None): + # r""" + # dpred(m, f=None) + # Create the projected data from a model. + # The fields, f, (if provided) will be used for the predicted data + # instead of recalculating the fields (which may be expensive!). + # + # .. math:: + # + # d_\\text{pred} = P(f(m)) + # + # Where P is a projection of the fields onto the data space. + # """ + # if self.survey is None: + # raise AttributeError( + # "The survey has not yet been set and is required to compute " + # "data. Please set the survey for the simulation: " + # "simulation.survey = survey" + # ) + # + # if f is None: + # if m is None: + # m = self.model + # f = self.fields(m) + # + # def evaluate_receiver(source, receiver, mesh, fields): + # return receiver.eval(source, mesh, fields).flatten() + # + # row = delayed(evaluate_receiver, pure=True) + # rows = [] + # for src in self.survey.source_list: + # for rx in src.receiver_list: + # rows.append( + # array.from_delayed( + # row(src, rx, self.mesh, f), + # dtype=np.float32, + # shape=(rx.nD,), + # ) + # ) + # + # data = array.hstack(rows) + # + # return data + + def Jvec(self, m, v, **_): + """ + Compute sensitivity matrix (J) and vector (v) product. + """ + self.model = m + if isinstance(self.Jmatrix, np.ndarray): + return self.Jmatrix @ v.astype(np.float32) -@max_ram.setter -def max_ram(self, other): - if other <= 0: - raise ValueError("max_ram must be greater than 0") - self._max_ram = other + return array.dot(self.Jmatrix, v.astype(np.float32)) + def Jtvec(self, m, v, **_): + """ + Compute adjoint sensitivity matrix (J^T) and vector (v) product. + """ + self.model = m -Sim.max_ram = max_ram + if isinstance(self.Jmatrix, np.ndarray): + return self.Jmatrix.T @ v.astype(np.float32) -Sim._max_chunk_size = 128 + return array.dot(v.astype(np.float32), self.Jmatrix) + @property + def Jmatrix(self): + """ + Sensitivity matrix stored on disk + Return the diagonal of JtJ + """ + if getattr(self, "_Jmatrix", None) is None: + self._Jmatrix = self.compute_J(self.model) -@property -def max_chunk_size(self): - "Largest chunk size (Mb) used by Dask" - return self._max_chunk_size + return self._Jmatrix -@max_chunk_size.setter -def max_chunk_size(self, other): - if other <= 0: - raise ValueError("max_chunk_size must be greater than 0") - self._max_chunk_size = other +def dask_Jvec(self, m, v, **_): + """ + Compute sensitivity matrix (J) and vector (v) product. + """ + self.model = m + if isinstance(self.Jmatrix, np.ndarray): + return self.Jmatrix @ v.astype(np.float32) -Sim.max_chunk_size = max_chunk_size + return array.dot(self.Jmatrix, v).astype(np.float32) + + +def dask_Jtvec(self, m, v, **_): + """ + Compute adjoint sensitivity matrix (J^T) and vector (v) product. + """ + self.model = m + + if isinstance(self.Jmatrix, np.ndarray): + return self.Jmatrix.T @ v.astype(np.float32) + + return array.dot(v, self.Jmatrix).astype(np.float32) def dask_getJtJdiag(self, m, W=None, f=None): @@ -58,9 +170,6 @@ def dask_getJtJdiag(self, m, W=None, f=None): return self._jtj_diag -Sim.getJtJdiag = dask_getJtJdiag - - def dask_dpred(self, m=None, f=None): r""" dpred(m, f=None) @@ -104,45 +213,3 @@ def evaluate_receiver(source, receiver, mesh, fields): data = array.hstack(rows).compute() return data - - -Sim.dpred = dask_dpred - - -def dask_Jvec(self, m, v, **_): - """ - Compute sensitivity matrix (J) and vector (v) product. - """ - self.model = m - - if isinstance(self.Jmatrix, np.ndarray): - return self.Jmatrix @ v.astype(np.float32) - - return array.dot(self.Jmatrix, v).astype(np.float32) - - -def dask_Jtvec(self, m, v, **_): - """ - Compute adjoint sensitivity matrix (J^T) and vector (v) product. - """ - self.model = m - - if isinstance(self.Jmatrix, np.ndarray): - return self.Jmatrix.T @ v.astype(np.float32) - - return array.dot(v, self.Jmatrix).astype(np.float32) - - -@property -def Jmatrix(self): - """ - Sensitivity matrix stored on disk - Return the diagonal of JtJ - """ - if getattr(self, "_Jmatrix", None) is None: - self._Jmatrix = self.compute_J(self.model) - - return self._Jmatrix - - -Sim.Jmatrix = Jmatrix diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 1b3a19dcc6..c144dded7c 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -48,13 +48,29 @@ validate_float, validate_ndarray_with_shape, ) - +from dask.distributed import get_client, Future from geoh5py.groups.property_group import GroupTypeEnum from geoh5py.groups import PropertyGroup, UIJsonGroup from geoh5py.objects import ObjectBase from geoh5py.ui_json.utils import fetch_active_workspace +def compute_JtJdiags(data_misfit, m): + jtj_diags = [] + for dmisfit in data_misfit.objfcts: + jtj_diags.append(dmisfit.getJtJdiag(m)) + + if isinstance(jtj_diags[0], Future): + client = get_client() + jtj_diags = client.gather(jtj_diags) + + jtj_diag = np.zeros_like(jtj_diags[0]) + for multiplier, diag in zip(data_misfit.multipliers, jtj_diags): + jtj_diag += multiplier * diag + + return jtj_diag + + class InversionDirective: """Base inversion directive class. @@ -2466,18 +2482,7 @@ def initialize(self): if not isinstance(rdg, Zero): regDiag += multiplier * rdg.diagonal() - JtJdiag = np.zeros_like(self.invProb.model) - for sim, (multiplier, dmisfit) in zip(self.simulation, self.dmisfit): - if getattr(sim, "getJtJdiag", None) is None: - assert getattr(sim, "getJ", None) is not None, ( - "Simulation does not have a getJ attribute." - + "Cannot form the sensitivity explicitly" - ) - JtJdiag += multiplier * np.sum( - np.power((dmisfit.W * sim.getJ(m)), 2), axis=0 - ) - else: - JtJdiag += multiplier * dmisfit.getJtJdiag(m) + JtJdiag = compute_JtJdiags(self.dmisfit, self.invProb.model) diagA = JtJdiag + self.invProb.beta * regDiag diagA[diagA != 0] = diagA[diagA != 0] ** -1.0 @@ -2498,18 +2503,7 @@ def endIter(self): # Check if regularization has a projection regDiag += multiplier * reg.deriv2(m).diagonal() - JtJdiag = np.zeros_like(self.invProb.model) - for sim, (multiplier, dmisfit) in zip(self.simulation, self.dmisfit): - if getattr(sim, "getJtJdiag", None) is None: - assert getattr(sim, "getJ", None) is not None, ( - "Simulation does not have a getJ attribute." - + "Cannot form the sensitivity explicitly" - ) - JtJdiag += multiplier * np.sum( - np.power((dmisfit.W * sim.getJ(m)), 2), axis=0 - ) - else: - JtJdiag += multiplier * dmisfit.getJtJdiag(m) + JtJdiag = compute_JtJdiags(self.dmisfit, m) diagA = JtJdiag + self.invProb.beta * regDiag diagA[diagA != 0] = diagA[diagA != 0] ** -1.0 @@ -2833,21 +2827,7 @@ def endIter(self): def update(self): """Update sensitivity weights""" - jtj_diag = np.zeros_like(self.invProb.model) - m = self.invProb.model - - for sim, (multiplier, dmisfit) in zip(self.simulation, self.dmisfit): - if getattr(sim, "getJtJdiag", None) is None: - if getattr(sim, "getJ", None) is None: - raise AttributeError( - "Simulation does not have a getJ attribute." - + "Cannot form the sensitivity explicitly" - ) - jtj_diag += multiplier * mkvc( - np.sum((dmisfit.W * sim.getJ(m)) ** 2.0, axis=0) - ) - else: - jtj_diag += multiplier * dmisfit.getJtJdiag(m) + jtj_diag = compute_JtJdiags(self.dmisfit, self.invProb.model) # Compute and sum root-mean squared sensitivities for all objective functions wr = np.zeros_like(self.invProb.model) @@ -3263,9 +3243,7 @@ def __init__(self, h5_object, dmisfit=None, **kwargs): def get_values(self, values: list[np.ndarray] | None): if values is None: - values = np.zeros_like(self.invProb.model) - for fun in self.dmisfit.objfcts: - values += fun.getJtJdiag(self.invProb.model) + values = compute_JtJdiags(self.dmisfit, self.invProb.model) return values diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index 21da9f1658..c888d9c330 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -8,6 +8,7 @@ import itertools from dask.distributed import Client from dask.distributed import Future +from dask import array from .simulation import MetaSimulation, SumMetaSimulation import scipy.sparse as sp from operator import add @@ -32,20 +33,28 @@ def _calc_dpred(mapping, sim, model, field, apply_map=False): return sim.dpred(m=sim.model, f=field) +def _compute_J(sim, m, field): + if getattr(sim, "_Jmatrix", None) is not None: + return sim.Jmatrix + + return sim.compute_J(m, field) + + def _j_vec_op(mapping, sim, model, field, v, apply_map=False): + # return array.from_array(np.zeros(100)) sim_v = mapping.deriv(model) @ v if apply_map: - return sim.Jvec(mapping @ model, sim_v, f=field) + return array.compute(sim.Jvec(mapping @ model, sim_v, f=field)) else: - return sim.Jvec(sim.model, sim_v, f=field) + return array.compute(sim.Jvec(sim.model, sim_v, f=field)) -def _jt_vec_op(mapping, sim, model, field, v, apply_map=False): +def _jt_vec_op(mapping, sim, model, field, v, start, end, apply_map=False): if apply_map: - jtv = sim.Jtvec(mapping @ model, v, f=field) + jtv = sim.Jtvec(mapping @ model, v[start:end], f=field) else: - jtv = sim.Jtvec(sim.model, v, f=field) - return mapping.deriv(model).T @ jtv + jtv = sim.Jtvec(sim.model, v[start:end], f=field) + return mapping.deriv(model).T @ array.compute(jtv)[0] def _get_jtj_diag(mapping, sim, model, field, w, apply_map=False): @@ -65,7 +74,7 @@ def _reduce(client, operation, items): if len(items) % 2 == 1: new_reduce[-1] = client.submit(operation, new_reduce[-1], items[-1]) items = new_reduce - return client.gather(items[0]) + return items[0] def _validate_type_or_future_of_type( @@ -351,7 +360,7 @@ def dpred(self, m=None, f=None): workers=worker, ) ) - return np.concatenate(client.gather(dpred)) + return _reduce(client, array.hstack, dpred) def Jvec(self, m, v, f=None): self.model = m @@ -376,7 +385,7 @@ def Jvec(self, m, v, f=None): workers=worker, ) ) - return np.concatenate(self.client.gather(j_vec)) + return _reduce(client, array.hstack, j_vec) def Jtvec(self, m, v, f=None): self.model = m @@ -395,7 +404,9 @@ def Jtvec(self, m, v, f=None): sim, m_future, field, - v[self._data_offsets[i] : self._data_offsets[i + 1]], + v, + self._data_offsets[i], + self._data_offsets[i + 1], self._repeat_sim, workers=worker, ) @@ -420,6 +431,8 @@ def getJtJdiag(self, m, W=None, f=None): zip(self.mappings, self.simulations, self._workers, f) ): sim_w = W[self._data_offsets[i] : self._data_offsets[i + 1]] + # s = client.gather(sim) + # ff = client.gather(field) jtj_diag.append( client.submit( _get_jtj_diag, @@ -431,6 +444,7 @@ def getJtJdiag(self, m, W=None, f=None): self._repeat_sim, workers=worker, ) + # s.getJtJdiag(self.model, sim_w, f=ff) ) self._jtjdiag = _reduce(client, add, jtj_diag) @@ -445,7 +459,9 @@ def compute_J(self, m, f=None): for sim, worker, field in zip(self.simulations, self._workers, f): J.append( client.submit( - sim.compute_J, + _compute_J, + sim, + m, field, workers=worker, ) diff --git a/simpeg/potential_fields/gravity/simulation.py b/simpeg/potential_fields/gravity/simulation.py index 1596b15ec2..0cd0a95652 100644 --- a/simpeg/potential_fields/gravity/simulation.py +++ b/simpeg/potential_fields/gravity/simulation.py @@ -220,6 +220,8 @@ def G(self): self._G = self._sensitivity_matrix() else: self._G = self.linear_operator() + + self._Jmatrix = self._G return self._G @property From 04e937b14f7876ba4ab0a747a79b4eb95947848b Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 12 Dec 2024 10:01:16 -0800 Subject: [PATCH 14/84] Implement dask misfit residual and deriv --- simpeg/dask/data_misfit.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/simpeg/dask/data_misfit.py b/simpeg/dask/data_misfit.py index 86934f4918..f251155163 100644 --- a/simpeg/dask/data_misfit.py +++ b/simpeg/dask/data_misfit.py @@ -20,14 +20,12 @@ def dask_call(self, m, f=None): """ Distributed :obj:`simpeg.data_misfit.L2DataMisfit.__call__` """ - dpred = self.simulation.dpred(m, f=f) + residuals = self.residual(m, f=f) - if isinstance(dpred, Future): + if isinstance(residuals, Future): client = get_client() - residuals = client.submit(_data_residual, dpred, self.data.dobs) phi_d = client.submit(_misfit, residuals, self.W) else: - residuals = _data_residual(dpred, self.data.dobs) phi_d = _misfit(residuals, self.W) return phi_d @@ -36,12 +34,34 @@ def dask_call(self, m, f=None): L2DataMisfit.__call__ = dask_call +def dask_residual(self, m, f=None): + dpred = self.simulation.dpred(m, f=f) + + if isinstance(dpred, Future): + client = get_client() + residuals = client.submit(_data_residual, dpred, self.data.dobs) + else: + residuals = _data_residual(dpred, self.data.dobs) + + return residuals + + +L2DataMisfit.residual = dask_residual + + def dask_deriv(self, m, f=None): """ Distributed :obj:`simpeg.data_misfit.L2DataMisfit.deriv` """ - wtw_d = self.W.diagonal() ** 2.0 * self.residual(m, f=f) - Jtvec = self, self.simulation.Jtvec(m, wtw_d) + residuals = self.residual(m, f=f) + + if isinstance(residuals, Future): + client = get_client() + wtw_d = client.submit(_stack_futures, residuals, self.W.diagonal() ** 2.0) + else: + wtw_d = self.W.diagonal() ** 2.0 * residuals + + Jtvec = self.simulation.Jtvec(m, wtw_d) return Jtvec @@ -50,7 +70,7 @@ def dask_deriv(self, m, f=None): def _stack_futures(futures, W): - return W * np.concatenate(futures) + return W * np.hstack(futures).flatten() def dask_deriv2(self, m, v, f=None): From 93da36d3ad75c376cfce0c03e276f02514d7353b Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 12 Dec 2024 10:19:20 -0800 Subject: [PATCH 15/84] Don't stack dpred, full grav distributed run --- simpeg/dask/inverse_problem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 0b5a2d6531..a8dd98f6e0 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -28,7 +28,7 @@ def get_dpred(self, m, f=None, compute_J=False): if isinstance(dpreds[0], Future): client = get_client() - dpreds = np.concatenate(client.gather(dpreds)).flatten() + dpreds = client.gather(dpreds) return dpreds From 1af2bf8d2f2c230647b01bdf0eae989e223eea61 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 12 Dec 2024 15:35:59 -0800 Subject: [PATCH 16/84] implement rest of dask classes --- .../frequency_domain/simulation.py | 523 ++++++++------- .../static/induced_polarization/simulation.py | 189 +++--- .../induced_polarization/simulation_2d.py | 70 +- .../static/resistivity/simulation.py | 257 ++++---- .../static/resistivity/simulation_2d.py | 339 +++++----- .../time_domain/simulation.py | 600 +++++++++--------- simpeg/dask/inverse_problem.py | 7 +- simpeg/dask/potential_fields/base.py | 229 +++---- .../potential_fields/gravity/simulation.py | 4 +- .../potential_fields/magnetics/simulation.py | 17 +- simpeg/dask/simulation.py | 133 +--- simpeg/directives/directives.py | 8 +- simpeg/meta/dask_sim.py | 27 +- simpeg/meta/simulation.py | 13 - simpeg/potential_fields/gravity/simulation.py | 2 +- 15 files changed, 1086 insertions(+), 1332 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 615a063d73..f1f75497e1 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -1,24 +1,30 @@ from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim +from ...simulation import BaseSimulation from ....utils import Zero import numpy as np import scipy.sparse as sp from multiprocessing import cpu_count from dask import array, compute, delayed -from simpeg.dask.simulation import dask_getJtJdiag, dask_Jvec, dask_Jtvec from simpeg.dask.utils import get_parallel_blocks +from simpeg.electromagnetics.frequency_domain.simulation import ( + Simulation3DMagneticFluxDensity as MagFlux, +) from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary +from simpeg.electromagnetics.natural_source.simulation import ( + Simulation3DPrimarySecondary as NSPrimarySecondary, +) import zarr from tqdm import tqdm -Sim.sensitivity_path = "./sensitivity/" -Sim.gtgdiag = None -Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec +@delayed +def evaluate_receivers(block, mesh, fields): + data = [] + for source, _, receiver in block: + data.append(receiver.eval(source, mesh, fields).flatten()) -Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] + return np.hstack(data) @delayed @@ -32,264 +38,6 @@ def source_evaluation(simulation, sources): return s_m, s_e -def dask_getSourceTerm(self, freq, source=None): - """ - Assemble the source term. This ensures that the RHS is a vector / array - of the correct size - """ - if source is None: - source_list = self.survey.get_sources_by_frequency(freq) - source_block = np.array_split(source_list, cpu_count()) - - block_compute = [] - for block in source_block: - if len(block) == 0: - continue - - block_compute.append(source_evaluation(self, block)) - - blocks = compute(block_compute)[0] - s_m, s_e = [], [] - for block in blocks: - if block[0]: - s_m += block[0] - s_e += block[1] - - else: - sm, se = source.eval(self) - s_m, s_e = [sm], [se] - - if isinstance(s_m[0][0], Zero): # Assume the rest is all Zero - s_m = Zero() - else: - s_m = np.vstack(s_m) - if s_m.shape[0] < s_m.shape[1]: - s_m = s_m.T - - if isinstance(s_e[0][0], Zero): # Assume the rest is all Zero - s_e = Zero() - else: - s_e = np.vstack(s_e) - if s_e.shape[0] < s_e.shape[1]: - s_e = s_e.T - return s_m, s_e - - -Sim.getSourceTerm = dask_getSourceTerm - - -@delayed -def evaluate_receivers(block, mesh, fields): - data = [] - for source, _, receiver in block: - data.append(receiver.eval(source, mesh, fields).flatten()) - - return np.hstack(data) - - -def dask_dpred(self, m=None, f=None): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). - - .. math:: - - d_\\text{pred} = P(f(m)) - - Where P is a projection of the fields onto the data space. - """ - if self.survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) - - if f is None: - if m is None: - m = self.model - f = self.fields(m) - - all_receivers = [] - - for ind, src in enumerate(self.survey.source_list): - for rx in src.receiver_list: - all_receivers.append((src, ind, rx)) - - receiver_blocks = np.array_split(np.asarray(all_receivers), cpu_count()) - rows = [] - mesh = delayed(self.mesh) - for block in receiver_blocks: - n_data = np.sum([rec.nD for _, _, rec in block]) - if n_data == 0: - continue - - rows.append( - array.from_delayed( - evaluate_receivers(block, mesh, f), - dtype=np.float64, - shape=(n_data,), - ) - ) - - data = compute(array.hstack(rows))[0] - - return data - - -Sim.dpred = dask_dpred -Sim.field_derivs = None - - -def fields(self, m=None): - if m is not None: - self.model = m - - if getattr(self, "_stashed_fields", None) is not None: - return self._stashed_fields - - f = self.fieldsPair(self) - Ainv = {} - for freq in self.survey.frequencies: - A = self.getA(freq) - rhs = self.getRHS(freq) - Ainv_solve = self.solver(sp.csr_matrix(A), **self.solver_opts) - u = Ainv_solve * rhs - sources = self.survey.get_sources_by_frequency(freq) - f[sources, self._solutionType] = u - Ainv[freq] = Ainv_solve - - self.Ainv = Ainv - - self._stashed_fields = f - - return f - - -Sim.fields = fields - - -def compute_J(self, m, f=None): - self.model = m - - if f is None: - f = self.fields(m) - - if len(self.Ainv) > 1: - raise NotImplementedError( - "Current implementation of parallelization assumes a single frequency per simulation. " - "Consider creating one misfit per frequency." - ) - - A_i = list(self.Ainv.values())[0] - m_size = m.size - - if self.store_sensitivities == "disk": - Jmatrix = zarr.open( - self.sensitivity_path, - mode="w", - shape=(self.survey.nD, m_size), - chunks=(self.max_chunk_size, m_size), - ) - else: - Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) - - compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6)) - blocks = get_parallel_blocks( - self.survey.source_list, compute_row_size, optimize=False - ) - fields_array = delayed(f[:, self._solutionType]) - fields = delayed(f) - survey = delayed(self.survey) - mesh = delayed(self.mesh) - blocks_receiver_derivs = [] - - for block in blocks: - blocks_receiver_derivs.append( - receiver_derivs( - survey, - mesh, - fields, - block, - ) - ) - - # with Client(processes=False) as client: - # with performance_report(filename="dask-report.html"): - - # Dask process for all derivatives - blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] - - for block_derivs_chunks, addresses_chunks in tqdm( - zip(blocks_receiver_derivs, blocks), - ncols=len(blocks_receiver_derivs), - desc=f"Sensitivities at {list(self.Ainv)[0]} Hz", - ): - Jmatrix = parallel_block_compute( - self, m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks - ) - - for A in self.Ainv.values(): - A.clean() - - if self.store_sensitivities == "disk": - del Jmatrix - self._Jmatrix = array.from_zarr(self.sensitivity_path) - else: - self._Jmatrix = Jmatrix - - return self._Jmatrix - - -Sim.compute_J = compute_J - - -def parallel_block_compute( - self, m, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses -): - m_size = m.size - block_stack = sp.hstack(blocks_receiver_derivs).toarray() - ATinvdf_duT = delayed(A_i * block_stack) - count = 0 - rows = [] - block_delayed = [] - - for address, dfduT in zip(addresses, blocks_receiver_derivs): - n_cols = dfduT.shape[1] - n_rows = address[1][2] - block_delayed.append( - array.from_delayed( - eval_block( - self, - ATinvdf_duT, - np.arange(count, count + n_cols), - Zero(), - fields_array, - address, - ), - dtype=np.float32, - shape=(n_rows, m_size), - ) - ) - count += n_cols - rows += address[1][1].tolist() - - indices = np.hstack(rows) - - if self.store_sensitivities == "disk": - Jmatrix.set_orthogonal_selection( - (indices, slice(None)), - compute(array.vstack(block_delayed))[0], - ) - else: - # Dask process to compute row and store - Jmatrix[indices, :] = compute(array.vstack(block_delayed))[0] - - return Jmatrix - - @delayed def receiver_derivs(survey, mesh, fields, blocks): field_derivatives = [] @@ -352,3 +100,248 @@ def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address du_dmT += deriv_m return np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T + + +class BaseFDEMSimulation(BaseSimulation, Sim): + sensitivity_path = "./sensitivity/" + clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] + + def getSourceTerm(self, freq, source=None): + """ + Assemble the source term. This ensures that the RHS is a vector / array + of the correct size + """ + if source is None: + source_list = self.survey.get_sources_by_frequency(freq) + source_block = np.array_split(source_list, cpu_count()) + + block_compute = [] + for block in source_block: + if len(block) == 0: + continue + + block_compute.append(source_evaluation(self, block)) + + blocks = compute(block_compute)[0] + s_m, s_e = [], [] + for block in blocks: + if block[0]: + s_m += block[0] + s_e += block[1] + + else: + sm, se = source.eval(self) + s_m, s_e = [sm], [se] + + if isinstance(s_m[0][0], Zero): # Assume the rest is all Zero + s_m = Zero() + else: + s_m = np.vstack(s_m) + if s_m.shape[0] < s_m.shape[1]: + s_m = s_m.T + + if isinstance(s_e[0][0], Zero): # Assume the rest is all Zero + s_e = Zero() + else: + s_e = np.vstack(s_e) + if s_e.shape[0] < s_e.shape[1]: + s_e = s_e.T + return s_m, s_e + + def dask_dpred(self, m=None, f=None): + r""" + dpred(m, f=None) + Create the projected data from a model. + The fields, f, (if provided) will be used for the predicted data + instead of recalculating the fields (which may be expensive!). + + .. math:: + + d_\\text{pred} = P(f(m)) + + Where P is a projection of the fields onto the data space. + """ + if self.survey is None: + raise AttributeError( + "The survey has not yet been set and is required to compute " + "data. Please set the survey for the simulation: " + "simulation.survey = survey" + ) + + if f is None: + if m is None: + m = self.model + f = self.fields(m) + + all_receivers = [] + + for ind, src in enumerate(self.survey.source_list): + for rx in src.receiver_list: + all_receivers.append((src, ind, rx)) + + receiver_blocks = np.array_split(np.asarray(all_receivers), cpu_count()) + rows = [] + mesh = delayed(self.mesh) + for block in receiver_blocks: + n_data = np.sum([rec.nD for _, _, rec in block]) + if n_data == 0: + continue + + rows.append( + array.from_delayed( + evaluate_receivers(block, mesh, f), + dtype=np.float64, + shape=(n_data,), + ) + ) + + data = compute(array.hstack(rows))[0] + + return data + + def fields(self, m=None): + if m is not None: + self.model = m + + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields + + f = self.fieldsPair(self) + Ainv = {} + for freq in self.survey.frequencies: + A = self.getA(freq) + rhs = self.getRHS(freq) + Ainv_solve = self.solver(sp.csr_matrix(A), **self.solver_opts) + u = Ainv_solve * rhs + sources = self.survey.get_sources_by_frequency(freq) + f[sources, self._solutionType] = u + Ainv[freq] = Ainv_solve + + self.Ainv = Ainv + + self._stashed_fields = f + + return f + + def compute_J(self, m, f=None): + self.model = m + + if f is None: + f = self.fields(m) + + if len(self.Ainv) > 1: + raise NotImplementedError( + "Current implementation of parallelization assumes a single frequency per simulation. " + "Consider creating one misfit per frequency." + ) + + A_i = list(self.Ainv.values())[0] + m_size = m.size + + if self.store_sensitivities == "disk": + Jmatrix = zarr.open( + self.sensitivity_path, + mode="w", + shape=(self.survey.nD, m_size), + chunks=(self.max_chunk_size, m_size), + ) + else: + Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) + + compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6)) + blocks = get_parallel_blocks( + self.survey.source_list, compute_row_size, optimize=False + ) + fields_array = delayed(f[:, self._solutionType]) + fields = delayed(f) + survey = delayed(self.survey) + mesh = delayed(self.mesh) + blocks_receiver_derivs = [] + + for block in blocks: + blocks_receiver_derivs.append( + receiver_derivs( + survey, + mesh, + fields, + block, + ) + ) + + # Dask process for all derivatives + blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] + + for block_derivs_chunks, addresses_chunks in tqdm( + zip(blocks_receiver_derivs, blocks), + ncols=len(blocks_receiver_derivs), + desc=f"Sensitivities at {list(self.Ainv)[0]} Hz", + ): + Jmatrix = self.parallel_block_compute( + m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks + ) + + for A in self.Ainv.values(): + A.clean() + + if self.store_sensitivities == "disk": + del Jmatrix + self._Jmatrix = array.from_zarr(self.sensitivity_path) + else: + self._Jmatrix = Jmatrix + + return self._Jmatrix + + def parallel_block_compute( + self, m, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses + ): + m_size = m.size + block_stack = sp.hstack(blocks_receiver_derivs).toarray() + ATinvdf_duT = delayed(A_i * block_stack) + count = 0 + rows = [] + block_delayed = [] + + for address, dfduT in zip(addresses, blocks_receiver_derivs): + n_cols = dfduT.shape[1] + n_rows = address[1][2] + block_delayed.append( + array.from_delayed( + eval_block( + self, + ATinvdf_duT, + np.arange(count, count + n_cols), + Zero(), + fields_array, + address, + ), + dtype=np.float32, + shape=(n_rows, m_size), + ) + ) + count += n_cols + rows += address[1][1].tolist() + + indices = np.hstack(rows) + + if self.store_sensitivities == "disk": + Jmatrix.set_orthogonal_selection( + (indices, slice(None)), + compute(array.vstack(block_delayed))[0], + ) + else: + # Dask process to compute row and store + Jmatrix[indices, :] = compute(array.vstack(block_delayed))[0] + + return Jmatrix + + +class Simulation3DMagneticFluxDensity(MagFlux, BaseFDEMSimulation): + """ + Overload the Simulation3DMagneticFluxDensity class to provide the necessary functionality + """ + + +class Simulation3DPrimarySecondary(NSPrimarySecondary, BaseFDEMSimulation): + """ + Overload the Simulation3DPrimarySecondary class to provide the necessary functionality + """ diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py index 85430d9007..2acd63a822 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py @@ -1,7 +1,7 @@ from .....electromagnetics.static.induced_polarization.simulation import ( - BaseIPSimulation as Sim, + Simulation3DNodal as Sim, ) - +from ....simulation import BaseSimulation from .....data import Data import dask.array as da from dask.distributed import Future @@ -10,134 +10,119 @@ numcodecs.blosc.use_threads = False -Sim.sensitivity_path = "./sensitivity/" +from ..resistivity.simulation import Simulation3DNodal as SimulationDC3D -from ..resistivity.simulation import ( - compute_J, - dask_getSourceTerm, -) -Sim.compute_J = compute_J -Sim.getSourceTerm = dask_getSourceTerm +class Simulation3DNodal(BaseSimulation, Sim): + def fields(self, m=None): + if m is not None: + self.model = m -def dask_fields(self, m=None): - if m is not None: - self.model = m + A = self.getA() + Ainv = self.solver(A, **self.solver_opts) + RHS = self.getRHS() - A = self.getA() - Ainv = self.solver(A, **self.solver_opts) - RHS = self.getRHS() + f = self.fieldsPair(self) + f[:, self._solutionType] = Ainv * RHS - f = self.fieldsPair(self) - f[:, self._solutionType] = Ainv * RHS + if self._scale is None: + scale = Data(self.survey, np.ones(self.survey.nD)) + # loop through receivers to check if they need to set the _dc_voltage + for src in self.survey.source_list: + for rx in src.receiver_list: + if ( + rx.data_type == "apparent_chargeability" + or self._data_type == "apparent_chargeability" + ): + scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f) + self._scale = scale.dobs - if self._scale is None: - scale = Data(self.survey, np.ones(self.survey.nD)) - # loop through receivers to check if they need to set the _dc_voltage - for src in self.survey.source_list: - for rx in src.receiver_list: - if ( - rx.data_type == "apparent_chargeability" - or self._data_type == "apparent_chargeability" - ): - scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f) - self._scale = scale.dobs + self.Ainv = Ainv - self.Ainv = Ainv + return f - return f + def dpred(self, m=None, f=None): + r""" + dpred(m, f=None) + Create the projected data from a model. + The fields, f, (if provided) will be used for the predicted data + instead of recalculating the fields (which may be expensive!). + .. math:: -Sim.fields = dask_fields + d_\\text{pred} = P(f(m)) + Where P is a projection of the fields onto the data space. + """ + if self.survey is None: + raise AttributeError( + "The survey has not yet been set and is required to compute " + "data. Please set the survey for the simulation: " + "simulation.survey = survey" + ) -def dask_dpred(self, m=None, f=None): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). + data = self.Jvec(m, m) - .. math:: + return np.asarray(data) - d_\\text{pred} = P(f(m)) + def getJtJdiag(self, m, W=None): + """ + Return the diagonal of JtJ + """ + self.model = m + if getattr(self, "_jtjdiag", None) is None: + if isinstance(self.Jmatrix, Future): + self.Jmatrix # Wait to finish - Where P is a projection of the fields onto the data space. - """ - if self.survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) + if W is None: + W = self._scale * np.ones(self.nD) + else: + W = (self._scale * W.diagonal()) ** 2.0 + + diag = da.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) - data = self.Jvec(m, m) + if isinstance(diag, da.Array): + diag = np.asarray(diag.compute()) - return np.asarray(data) + self._jtjdiag = diag + return self._jtjdiag -Sim.dpred = dask_dpred + def Jvec(self, m, v, f=None): + """ + Compute sensitivity matrix (J) and vector (v) product. + """ + self.model = m + if isinstance(self.Jmatrix, np.ndarray): + return self._scale.astype(np.float32) * ( + self.Jmatrix @ v.astype(np.float32) + ) -def dask_getJtJdiag(self, m, W=None): - """ - Return the diagonal of JtJ - """ - self.model = m - if getattr(self, "_jtjdiag", None) is None: if isinstance(self.Jmatrix, Future): self.Jmatrix # Wait to finish - if W is None: - W = self._scale * np.ones(self.nD) - else: - W = (self._scale * W.diagonal()) ** 2.0 - - diag = da.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) - - if isinstance(diag, da.Array): - diag = np.asarray(diag.compute()) - - self._jtjdiag = diag - - return self._jtjdiag - - -Sim.getJtJdiag = dask_getJtJdiag - - -def dask_Jvec(self, m, v, f=None): - """ - Compute sensitivity matrix (J) and vector (v) product. - """ - self.model = m - - if isinstance(self.Jmatrix, np.ndarray): - return self._scale.astype(np.float32) * (self.Jmatrix @ v.astype(np.float32)) - - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish - - return self._scale.astype(np.float32) * da.dot(self.Jmatrix, v).astype(np.float32) - - -Sim.Jvec = dask_Jvec - + return self._scale.astype(np.float32) * da.dot(self.Jmatrix, v).astype( + np.float32 + ) -def dask_Jtvec(self, m, v, f=None): - """ - Compute adjoint sensitivity matrix (J^T) and vector (v) product. - """ - self.model = m + def Jtvec(self, m, v, f=None): + """ + Compute adjoint sensitivity matrix (J^T) and vector (v) product. + """ + self.model = m - if isinstance(self.Jmatrix, np.ndarray): - return (self._scale * v.astype(np.float32)).astype(np.float32) @ self.Jmatrix + if isinstance(self.Jmatrix, np.ndarray): + return (self._scale * v.astype(np.float32)).astype( + np.float32 + ) @ self.Jmatrix - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish + if isinstance(self.Jmatrix, Future): + self.Jmatrix # Wait to finish - return da.dot(v * self._scale, self.Jmatrix).astype(np.float32) + return da.dot(v * self._scale, self.Jmatrix).astype(np.float32) -Sim.Jtvec = dask_Jtvec +Simulation3DNodal.compute_J = SimulationDC3D.compute_J +Simulation3DNodal.getSourceTerm = SimulationDC3D.getSourceTerm diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py index 63f68068c0..aa9afade78 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py @@ -1,56 +1,54 @@ from .....electromagnetics.static.induced_polarization.simulation import ( Simulation2DNodal as Sim, ) +from ....simulation import BaseSimulation from .....data import Data import numpy as np import numcodecs numcodecs.blosc.use_threads = False -Sim.sensitivity_path = "./sensitivity/" +from ..resistivity.simulation_2d import Simulation2DNodal as SimulationDC2D -from .simulation import dask_getJtJdiag, dask_dpred -from ..resistivity.simulation_2d import compute_J, dask_getSourceTerm +class Simulation2DNodal(BaseSimulation, Sim): + """ + Overloaded Simulation2DNodal to include the dask methods + """ -Sim.compute_J = compute_J -Sim.getSourceTerm = dask_getSourceTerm -Sim.getJtJdiag = dask_getJtJdiag -Sim.dpred = dask_dpred + def fields(self, m=None): + if m is not None: + self.model = m + kys = self._quad_points + f = self.fieldsPair(self) + f._quad_weights = self._quad_weights -def dask_fields(self, m=None): - if m is not None: - self.model = m + Ainv = {} + for iky, ky in enumerate(kys): + A = self.getA(ky) + Ainv[iky] = self.solver(A, **self.solver_opts) - kys = self._quad_points - f = self.fieldsPair(self) - f._quad_weights = self._quad_weights + RHS = self.getRHS(ky) + f[:, self._solutionType, iky] = Ainv[iky] * RHS - Ainv = {} - for iky, ky in enumerate(kys): - A = self.getA(ky) - Ainv[iky] = self.solver(A, **self.solver_opts) + if self._scale is None: + scale = Data(self.survey, np.ones(self.survey.nD)) + f_fwd = self.fields_to_space(f) + # loop through receievers to check if they need to set the _dc_voltage + for src in self.survey.source_list: + for rx in src.receiver_list: + if ( + rx.data_type == "apparent_chargeability" + or self._data_type == "apparent_chargeability" + ): + scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f_fwd) + self._scale = scale.dobs - RHS = self.getRHS(ky) - f[:, self._solutionType, iky] = Ainv[iky] * RHS + self.Ainv = Ainv - if self._scale is None: - scale = Data(self.survey, np.ones(self.survey.nD)) - f_fwd = self.fields_to_space(f) - # loop through receievers to check if they need to set the _dc_voltage - for src in self.survey.source_list: - for rx in src.receiver_list: - if ( - rx.data_type == "apparent_chargeability" - or self._data_type == "apparent_chargeability" - ): - scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f_fwd) - self._scale = scale.dobs + return f - self.Ainv = Ainv - return f - - -Sim.fields = dask_fields +Simulation2DNodal.compute_J = SimulationDC2D.compute_J +Simulation2DNodal.getSourceTerm = SimulationDC2D.getSourceTerm diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index 521bc8b34b..8b43f07644 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -1,5 +1,5 @@ -from simpeg.dask.simulation import dask_dpred, dask_getJtJdiag, dask_Jvec, dask_Jtvec -from .....electromagnetics.static.resistivity.simulation import BaseDCSimulation as Sim +from .....electromagnetics.static.resistivity.simulation import Simulation3DNodal as Sim +from ....simulation import BaseSimulation from .....utils import Zero import dask.array as da import numpy as np @@ -13,172 +13,159 @@ numcodecs.blosc.use_threads = False -Sim.sensitivity_path = "./sensitivity/" -Sim.dpred = dask_dpred -Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec - -Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] - - -def dask_fields(self, m=None): - if m is not None: - self.model = m - - if getattr(self, "_stashed_fields", None) is not None: - return self._stashed_fields - - A = self.getA() - Ainv = self.solver(A, **self.solver_opts) - RHS = self.getRHS() +class Simulation3DNodal(BaseSimulation, Sim): + """ + Overload of the Simulation3DNodal to include the dask operations + """ - f = self.fieldsPair(self) - f[:, self._solutionType] = Ainv * RHS + def fields(self, m=None): + if m is not None: + self.model = m - self.Ainv = Ainv + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields - self._stashed_fields = f + A = self.getA() + Ainv = self.solver(A, **self.solver_opts) + RHS = self.getRHS() - return f + f = self.fieldsPair(self) + f[:, self._solutionType] = Ainv * RHS + self.Ainv = Ainv -Sim.fields = dask_fields + self._stashed_fields = f + return f -def compute_J(self, m, f=None): + def compute_J(self, m, f=None): - if f is None: - f = self.fields(m) + if f is None: + f = self.fields(m) - m_size = m.size - row_chunks = int( - np.ceil( - float(self.survey.nD) - / np.ceil(float(m_size) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size) - ) - ) - - if self.store_sensitivities == "disk": - Jmatrix = zarr.open( - self.sensitivity_path + "J.zarr", - mode="w", - shape=(self.survey.nD, m_size), - chunks=(row_chunks, m_size), + m_size = m.size + row_chunks = int( + np.ceil( + float(self.survey.nD) + / np.ceil( + float(m_size) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size + ) + ) ) - else: - Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) - - blocks = [] - count = 0 - for source in self.survey.source_list: - u_source = f[source, self._solutionType] - for rx in source.receiver_list: + if self.store_sensitivities == "disk": + Jmatrix = zarr.open( + self.sensitivity_path + "J.zarr", + mode="w", + shape=(self.survey.nD, m_size), + chunks=(row_chunks, m_size), + ) + else: + Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) - if rx.orientation is not None: - projected_grid = f._GLoc(rx.projField) + rx.orientation - else: - projected_grid = f._GLoc(rx.projField) + blocks = [] + count = 0 + for source in self.survey.source_list: + u_source = f[source, self._solutionType] - PTv = rx.getP(self.mesh, projected_grid).toarray().T + for rx in source.receiver_list: - for dd in range(int(np.ceil(PTv.shape[1] / row_chunks))): - start, end = dd * row_chunks, np.min( - [(dd + 1) * row_chunks, PTv.shape[1]] - ) - df_duTFun = getattr(f, "_{0!s}Deriv".format(rx.projField), None) - df_duT, df_dmT = df_duTFun( - source, None, PTv[:, start:end], adjoint=True - ) - ATinvdf_duT = self.Ainv * df_duT - dA_dmT = self.getADeriv(u_source, ATinvdf_duT, adjoint=True) - dRHS_dmT = self.getRHSDeriv(source, ATinvdf_duT, adjoint=True) - du_dmT = -dA_dmT - if not isinstance(dRHS_dmT, Zero): - du_dmT += dRHS_dmT - if not isinstance(df_dmT, Zero): - du_dmT += df_dmT - - # - du_dmT = du_dmT.T.reshape((-1, m_size)) - - if len(blocks) == 0: - blocks = du_dmT + if rx.orientation is not None: + projected_grid = f._GLoc(rx.projField) + rx.orientation else: - blocks = np.vstack([blocks, du_dmT]) - - while blocks.shape[0] >= row_chunks: - - if self.store_sensitivities == "disk": - Jmatrix.set_orthogonal_selection( - (np.arange(count, count + row_chunks), slice(None)), - blocks[:row_chunks, :].astype(np.float32), - ) + projected_grid = f._GLoc(rx.projField) + + PTv = rx.getP(self.mesh, projected_grid).toarray().T + + for dd in range(int(np.ceil(PTv.shape[1] / row_chunks))): + start, end = dd * row_chunks, np.min( + [(dd + 1) * row_chunks, PTv.shape[1]] + ) + df_duTFun = getattr(f, "_{0!s}Deriv".format(rx.projField), None) + df_duT, df_dmT = df_duTFun( + source, None, PTv[:, start:end], adjoint=True + ) + ATinvdf_duT = self.Ainv * df_duT + dA_dmT = self.getADeriv(u_source, ATinvdf_duT, adjoint=True) + dRHS_dmT = self.getRHSDeriv(source, ATinvdf_duT, adjoint=True) + du_dmT = -dA_dmT + if not isinstance(dRHS_dmT, Zero): + du_dmT += dRHS_dmT + if not isinstance(df_dmT, Zero): + du_dmT += df_dmT + + # + du_dmT = du_dmT.T.reshape((-1, m_size)) + + if len(blocks) == 0: + blocks = du_dmT else: - Jmatrix[count : count + row_chunks, :] = blocks[ - :row_chunks, : - ].astype(np.float32) - - blocks = blocks[row_chunks:, :].astype(np.float32) - count += row_chunks - - del df_duT, ATinvdf_duT, dA_dmT, dRHS_dmT, du_dmT + blocks = np.vstack([blocks, du_dmT]) - if len(blocks) != 0: - - if self.store_sensitivities == "disk": - Jmatrix.set_orthogonal_selection( - (np.arange(count, self.survey.nD), slice(None)), - blocks.astype(np.float32), - ) - else: - Jmatrix[count : self.survey.nD, :] = blocks.astype(np.float32) + while blocks.shape[0] >= row_chunks: - self.Ainv.clean() + if self.store_sensitivities == "disk": + Jmatrix.set_orthogonal_selection( + (np.arange(count, count + row_chunks), slice(None)), + blocks[:row_chunks, :].astype(np.float32), + ) + else: + Jmatrix[count : count + row_chunks, :] = blocks[ + :row_chunks, : + ].astype(np.float32) - if self.store_sensitivities == "disk": - del Jmatrix - self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") - else: - self._Jmatrix = Jmatrix + blocks = blocks[row_chunks:, :].astype(np.float32) + count += row_chunks - return self._Jmatrix + del df_duT, ATinvdf_duT, dA_dmT, dRHS_dmT, du_dmT + if len(blocks) != 0: -Sim.compute_J = compute_J + if self.store_sensitivities == "disk": + Jmatrix.set_orthogonal_selection( + (np.arange(count, self.survey.nD), slice(None)), + blocks.astype(np.float32), + ) + else: + Jmatrix[count : self.survey.nD, :] = blocks.astype(np.float32) + self.Ainv.clean() -def dask_getSourceTerm(self): - """ - Evaluates the sources, and puts them in matrix form - :rtype: tuple - :return: q (nC or nN, nSrc) - """ + if self.store_sensitivities == "disk": + del Jmatrix + self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") + else: + self._Jmatrix = Jmatrix - if getattr(self, "_q", None) is None: + return self._Jmatrix - if self._mini_survey is not None: - Srcs = self._mini_survey.source_list - else: - Srcs = self.survey.source_list + def getSourceTerm(self): + """ + Evaluates the sources, and puts them in matrix form + :rtype: tuple + :return: q (nC or nN, nSrc) + """ - if self._formulation == "EB": - n = self.mesh.nN - # return NotImplementedError + if getattr(self, "_q", None) is None: - elif self._formulation == "HJ": - n = self.mesh.nC + if self._mini_survey is not None: + Srcs = self._mini_survey.source_list + else: + Srcs = self.survey.source_list - q = np.zeros((n, len(Srcs)), order="F") + if self._formulation == "EB": + n = self.mesh.nN + # return NotImplementedError - for i, source in enumerate(Srcs): - q[:, i] = source.eval(self) + elif self._formulation == "HJ": + n = self.mesh.nC - self._q = q + q = np.zeros((n, len(Srcs)), order="F") - return self._q + for i, source in enumerate(Srcs): + q[:, i] = source.eval(self) + self._q = q -Sim.getSourceTerm = dask_getSourceTerm + return self._q diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py index e71a7f1b24..b2d1546b74 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py @@ -1,7 +1,7 @@ from .....electromagnetics.static.resistivity.simulation_2d import ( - BaseDCSimulation2D as Sim, + Simulation2DNodal as Sim, ) -from .simulation import dask_getJtJdiag, dask_Jvec, dask_Jtvec +from ....simulation import BaseSimulation import dask.array as da import numpy as np import zarr @@ -9,219 +9,204 @@ numcodecs.blosc.use_threads = False -Sim.sensitivity_path = "./sensitivity/" -Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec -Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] - - -def dask_fields(self, m=None): - if m is not None: - self.model = m - - if getattr(self, "_stashed_fields", None) is not None: - return self._stashed_fields +class Simulation2DNodal(BaseSimulation, Sim): + """ + Overload of the Simulation3DNodal to include the dask operations + """ - kys = self._quad_points - f = self.fieldsPair(self) - f._quad_weights = self._quad_weights + clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] - Ainv = {} - for iky, ky in enumerate(kys): - A = self.getA(ky) - Ainv[iky] = self.solver(A, **self.solver_opts) + def fields(self, m=None): + if m is not None: + self.model = m - RHS = self.getRHS(ky) - f[:, self._solutionType, iky] = Ainv[iky] * RHS + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields - self.Ainv = Ainv + kys = self._quad_points + f = self.fieldsPair(self) + f._quad_weights = self._quad_weights - self._stashed_fields = f - return f + Ainv = {} + for iky, ky in enumerate(kys): + A = self.getA(ky) + Ainv[iky] = self.solver(A, **self.solver_opts) + RHS = self.getRHS(ky) + f[:, self._solutionType, iky] = Ainv[iky] * RHS -Sim.fields = dask_fields + self.Ainv = Ainv + self._stashed_fields = f + return f -def compute_J(self, m, f=None): - kys = self._quad_points - weights = self._quad_weights + def compute_J(self, m, f=None): + kys = self._quad_points + weights = self._quad_weights - if f is None: - f = self.fields(m) + if f is None: + f = self.fields(m) - m_size = m.size - row_chunks = int( - np.ceil( - float(self.survey.nD) - / np.ceil( - float(m_size) - * self.survey.nD - * len(kys) - * 8.0 - * 1e-6 - / self.max_chunk_size + m_size = m.size + row_chunks = int( + np.ceil( + float(self.survey.nD) + / np.ceil( + float(m_size) + * self.survey.nD + * len(kys) + * 8.0 + * 1e-6 + / self.max_chunk_size + ) ) ) - ) - if self.store_sensitivities == "disk": - Jmatrix = zarr.open( - self.sensitivity_path + "J.zarr", - mode="w", - shape=(self.survey.nD, m_size), - chunks=(row_chunks, m_size), - ) - else: - Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) - - blocks = [] - count = 0 - - for i_src, source in enumerate(self.survey.source_list): - for rx in source.receiver_list: + if self.store_sensitivities == "disk": + Jmatrix = zarr.open( + self.sensitivity_path + "J.zarr", + mode="w", + shape=(self.survey.nD, m_size), + chunks=(row_chunks, m_size), + ) + else: + Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) - if rx.orientation is not None: - projected_grid = f._GLoc(rx.projField) + rx.orientation - else: - projected_grid = f._GLoc(rx.projField) + blocks = [] + count = 0 - PTv = rx.getP(self.mesh, projected_grid).toarray().T + for i_src, source in enumerate(self.survey.source_list): + for rx in source.receiver_list: - for dd in range(int(np.ceil(PTv.shape[1] / row_chunks))): - start, end = dd * row_chunks, np.min( - [(dd + 1) * row_chunks, PTv.shape[1]] - ) - block = np.zeros((end - start, m_size)) - for iky, ky in enumerate(kys): - - u_ky = f[:, self._solutionType, iky] - u_source = u_ky[:, i_src] - ATinvdf_duT = self.Ainv[iky] * PTv[:, start:end] - dA_dmT = self.getADeriv(ky, u_source, ATinvdf_duT, adjoint=True) - du_dmT = -weights[iky] * dA_dmT - block += du_dmT.T.reshape((-1, m_size)) - - if len(blocks) == 0: - blocks = block + if rx.orientation is not None: + projected_grid = f._GLoc(rx.projField) + rx.orientation else: - blocks = np.vstack([blocks, block]) - - while blocks.shape[0] >= row_chunks: - if self.store_sensitivities == "disk": - Jmatrix.set_orthogonal_selection( - (np.arange(count, count + row_chunks), slice(None)), - blocks[:row_chunks, :].astype(np.float32), - ) + projected_grid = f._GLoc(rx.projField) + + PTv = rx.getP(self.mesh, projected_grid).toarray().T + + for dd in range(int(np.ceil(PTv.shape[1] / row_chunks))): + start, end = dd * row_chunks, np.min( + [(dd + 1) * row_chunks, PTv.shape[1]] + ) + block = np.zeros((end - start, m_size)) + for iky, ky in enumerate(kys): + + u_ky = f[:, self._solutionType, iky] + u_source = u_ky[:, i_src] + ATinvdf_duT = self.Ainv[iky] * PTv[:, start:end] + dA_dmT = self.getADeriv(ky, u_source, ATinvdf_duT, adjoint=True) + du_dmT = -weights[iky] * dA_dmT + block += du_dmT.T.reshape((-1, m_size)) + + if len(blocks) == 0: + blocks = block else: - Jmatrix[count : count + row_chunks, :] = blocks[ - :row_chunks, : - ].astype(np.float32) - - blocks = blocks[row_chunks:, :].astype(np.float32) - count += row_chunks + blocks = np.vstack([blocks, block]) + + while blocks.shape[0] >= row_chunks: + if self.store_sensitivities == "disk": + Jmatrix.set_orthogonal_selection( + (np.arange(count, count + row_chunks), slice(None)), + blocks[:row_chunks, :].astype(np.float32), + ) + else: + Jmatrix[count : count + row_chunks, :] = blocks[ + :row_chunks, : + ].astype(np.float32) + + blocks = blocks[row_chunks:, :].astype(np.float32) + count += row_chunks + + del ATinvdf_duT, dA_dmT, block + + if len(blocks) != 0: + if self.store_sensitivities == "disk": + Jmatrix.set_orthogonal_selection( + (np.arange(count, self.survey.nD), slice(None)), + blocks.astype(np.float32), + ) + else: + Jmatrix[count : self.survey.nD, :] = blocks.astype(np.float32) - del ATinvdf_duT, dA_dmT, block + for iky, _ in enumerate(kys): + self.Ainv[iky].clean() - if len(blocks) != 0: if self.store_sensitivities == "disk": - Jmatrix.set_orthogonal_selection( - (np.arange(count, self.survey.nD), slice(None)), - blocks.astype(np.float32), - ) + del Jmatrix + self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") else: - Jmatrix[count : self.survey.nD, :] = blocks.astype(np.float32) - - for iky, _ in enumerate(kys): - self.Ainv[iky].clean() + self._Jmatrix = Jmatrix - if self.store_sensitivities == "disk": - del Jmatrix - self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") - else: - self._Jmatrix = Jmatrix + return self._Jmatrix - return self._Jmatrix + def dpred(self, m=None, f=None): + r""" + dpred(m, f=None) + Create the projected data from a model. + The fields, f, (if provided) will be used for the predicted data + instead of recalculating the fields (which may be expensive!). + .. math:: -Sim.compute_J = compute_J + d_\\text{pred} = P(f(m)) + Where P is a projection of the fields onto the data space. + """ + weights = self._quad_weights + if self._mini_survey is not None: + survey = self._mini_survey + else: + survey = self.survey -def dask_dpred(self, m=None, f=None): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). - - .. math:: - - d_\\text{pred} = P(f(m)) - - Where P is a projection of the fields onto the data space. - """ - weights = self._quad_weights - if self._mini_survey is not None: - survey = self._mini_survey - else: - survey = self.survey - - if survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) - - if f is None: - if m is None: - m = self.model - f = self.fields(m) - - temp = np.empty(survey.nD) - count = 0 - for src in survey.source_list: - for rx in src.receiver_list: - d = rx.eval(src, self.mesh, f).dot(weights) - temp[count : count + len(d)] = d - count += len(d) - - return self._mini_survey_data(temp) - + if survey is None: + raise AttributeError( + "The survey has not yet been set and is required to compute " + "data. Please set the survey for the simulation: " + "simulation.survey = survey" + ) -Sim.dpred = dask_dpred + if f is None: + if m is None: + m = self.model + f = self.fields(m) + temp = np.empty(survey.nD) + count = 0 + for src in survey.source_list: + for rx in src.receiver_list: + d = rx.eval(src, self.mesh, f).dot(weights) + temp[count : count + len(d)] = d + count += len(d) -def dask_getSourceTerm(self, _): - """ - Evaluates the sources, and puts them in matrix form - :rtype: tuple - :return: q (nC or nN, nSrc) - """ + return self._mini_survey_data(temp) - if getattr(self, "_q", None) is None: + def getSourceTerm(self, _): + """ + Evaluates the sources, and puts them in matrix form + :rtype: tuple + :return: q (nC or nN, nSrc) + """ - if self._mini_survey is not None: - Srcs = self._mini_survey.source_list - else: - Srcs = self.survey.source_list + if getattr(self, "_q", None) is None: - if self._formulation == "EB": - n = self.mesh.nN - # return NotImplementedError - - elif self._formulation == "HJ": - n = self.mesh.nC + if self._mini_survey is not None: + Srcs = self._mini_survey.source_list + else: + Srcs = self.survey.source_list - q = np.zeros((n, len(Srcs)), order="F") + if self._formulation == "EB": + n = self.mesh.nN + # return NotImplementedError - for i, source in enumerate(Srcs): - q[:, i] = source.eval(self) + elif self._formulation == "HJ": + n = self.mesh.nC - self._q = q + q = np.zeros((n, len(Srcs)), order="F") - return self._q + for i, source in enumerate(Srcs): + q[:, i] = source.eval(self) + self._q = q -Sim.getSourceTerm = dask_getSourceTerm + return self._q diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 0800a3cac4..bc632bd632 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -2,6 +2,7 @@ import dask.array import os from ....electromagnetics.time_domain.simulation import BaseTDEMSimulation as Sim +from ...simulation import BaseSimulation from ....utils import Zero from simpeg.fields import TimeFields from multiprocessing import cpu_count @@ -9,33 +10,234 @@ import scipy.sparse as sp from dask import array, delayed -from simpeg.dask.simulation import dask_getJtJdiag, dask_Jvec, dask_Jtvec +from simpeg.electromagnetics.time_domain.simulation import ( + Simulation3DMagneticFluxDensity as MagneticFlux, +) from simpeg.dask.utils import get_parallel_blocks from simpeg.utils import mkvc from time import time from tqdm import tqdm -Sim.sensitivity_path = "./sensitivity/" -Sim.getJtJdiag = dask_getJtJdiag -Sim.Jvec = dask_Jvec -Sim.Jtvec = dask_Jtvec -Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] +class BaseTDEMSimulation(BaseSimulation, Sim): + def fields(self, m=None): + if m is not None: + self.model = m -@delayed -def field_projection(field_array, src_list, array_ind, time_ind, func): - fieldI = field_array[:, :, array_ind] - if fieldI.shape[0] == fieldI.size: - fieldI = mkvc(fieldI, 2) - new_array = func(fieldI, src_list, time_ind) - if new_array.ndim == 1: - new_array = new_array[:, np.newaxis, np.newaxis] - elif new_array.ndim == 2: - new_array = new_array[:, :, np.newaxis] + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields - return new_array + f = self.fieldsPair(self) + f[:, self._fieldType + "Solution", 0] = self.getInitialFields() + Ainv = {} + + for tInd, dt in enumerate(self.time_steps): + if dt not in Ainv: + A = self.getAdiag(tInd) + Ainv[dt] = self.solver(sp.csr_matrix(A), **self.solver_opts) + + Asubdiag = self.getAsubdiag(tInd) + rhs = -Asubdiag * f[:, (self._fieldType + "Solution"), tInd] + + if ( + np.abs(self.survey.source_list[0].waveform.eval(self.times[tInd + 1])) + > 1e-8 + ): + rhs += self.getRHS(tInd + 1) + + sol = Ainv[dt] * rhs + f[:, self._fieldType + "Solution", tInd + 1] = sol + + self.Ainv = Ainv + self._stashed_fields = f + return f + + def getSourceTerm(self, tInd): + """ + Assemble the source term. This ensures that the RHS is a vector / array + of the correct size + """ + source_list = self.survey.source_list + source_block = np.array_split(source_list, cpu_count()) + + block_compute = [] + for block in source_block: + block_compute.append(source_evaluation(self, block, self.times[tInd])) + + blocks = dask.compute(block_compute)[0] + + s_m, s_e = [], [] + for block in blocks: + if block[0]: + s_m.append(block[0]) + s_e.append(block[1]) + + if isinstance(s_m[0][0], Zero): + return Zero(), np.vstack(s_e).T + + return np.vstack(s_m).T, np.vstack(s_e).T + + def dpred(self, m=None, f=None): + r""" + dpred(m, f=None) + Create the projected data from a model. + The fields, f, (if provided) will be used for the predicted data + instead of recalculating the fields (which may be expensive!). + + .. math:: + + d_\\text{pred} = P(f(m)) + + Where P is a projection of the fields onto the data space. + """ + if self.survey is None: + raise AttributeError( + "The survey has not yet been set and is required to compute " + "data. Please set the survey for the simulation: " + "simulation.survey = survey" + ) + + if f is None: + if m is None: + m = self.model + f = self.fields(m) + + rows = [] + receiver_projection = self.survey.source_list[0].receiver_list[0].projField + fields_array = f[:, receiver_projection, :] + + if len(self.survey.source_list) == 1: + fields_array = fields_array[:, np.newaxis, :] + + all_receivers = [] + + for ind, src in enumerate(self.survey.source_list): + for rx in src.receiver_list: + all_receivers.append((src, ind, rx)) + + receiver_blocks = np.array_split(all_receivers, cpu_count()) + + for block in receiver_blocks: + n_data = np.sum([rec.nD for _, _, rec in block]) + if n_data == 0: + continue + + rows.append( + array.from_delayed( + evaluate_receivers( + block, self.mesh, self.time_mesh, f, fields_array + ), + dtype=np.float64, + shape=(n_data,), + ) + ) + + data = array.hstack(rows).compute() + + return data + + def compute_J(self, m, f=None): + """ + Compute the rows for the sensitivity matrix. + """ + if f is None: + f = self.fields(m) + + ftype = self._fieldType + "Solution" + sens_name = self.sensitivity_path[:-5] + if self.store_sensitivities == "disk": + rows = array.zeros( + (self.survey.nD, m.size), + chunks=(self.max_chunk_size, m.size), + dtype=np.float32, + ) + Jmatrix = array.to_zarr( + rows, + os.path.join(sens_name + "_1.zarr"), + compute=True, + return_stored=True, + overwrite=True, + ) + else: + Jmatrix = np.zeros((self.survey.nD, m.size), dtype=np.float64) + + simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 + data_times = self.survey.source_list[0].receiver_list[0].times + compute_row_size = np.ceil(self.max_chunk_size / (m.shape[0] * 8.0 * 1e-6)) + blocks = get_parallel_blocks(self.survey.source_list, compute_row_size) + fields_array = f[:, ftype, :] + + if len(self.survey.source_list) == 1: + fields_array = fields_array[:, np.newaxis, :] + + times_field_derivs, Jmatrix = compute_field_derivs( + self, f, blocks, Jmatrix, fields_array.shape + ) + + ATinv_df_duT_v = {} + for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): + AdiagTinv = self.Ainv[dt] + j_row_updates = [] + time_mask = data_times > simulation_times[tInd] + + if not np.any(time_mask): + continue + + for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): + ATinv_df_duT_v = get_field_deriv_block( + self, block, field_deriv, tInd, AdiagTinv, ATinv_df_duT_v, time_mask + ) + + if len(block) == 0: + continue + + j_row_updates.append( + array.from_delayed( + compute_rows( + self, + tInd, + block, + ATinv_df_duT_v, + fields_array, + time_mask, + ), + dtype=np.float32, + shape=( + np.sum([len(chunk[1][0]) for chunk in block]), + m.size, + ), + ) + ) + + if self.store_sensitivities == "disk": + sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" + array.to_zarr( + Jmatrix + array.vstack(j_row_updates), + sens_name, + compute=True, + overwrite=True, + ) + Jmatrix = array.from_zarr(sens_name) + else: + Jmatrix += array.vstack(j_row_updates).compute() + + for A in self.Ainv.values(): + A.clean() + + if self.store_sensitivities == "ram": + self._Jmatrix = np.asarray(Jmatrix) + + self._Jmatrix = Jmatrix + + return self._Jmatrix + + +class Simulation3DMagneticFluxDensity(MagneticFlux, BaseTDEMSimulation): + """ + Overload the Simulation3DMagneticFluxDensity class to use Dask + """ def _getField(self, name, ind, src_list): @@ -93,40 +295,18 @@ def _getField(self, name, ind, src_list): TimeFields._getField = _getField -def fields(self, m=None): - if m is not None: - self.model = m - - if getattr(self, "_stashed_fields", None) is not None: - return self._stashed_fields - - f = self.fieldsPair(self) - f[:, self._fieldType + "Solution", 0] = self.getInitialFields() - Ainv = {} - - for tInd, dt in enumerate(self.time_steps): - if dt not in Ainv: - A = self.getAdiag(tInd) - Ainv[dt] = self.solver(sp.csr_matrix(A), **self.solver_opts) - - Asubdiag = self.getAsubdiag(tInd) - rhs = -Asubdiag * f[:, (self._fieldType + "Solution"), tInd] - - if ( - np.abs(self.survey.source_list[0].waveform.eval(self.times[tInd + 1])) - > 1e-8 - ): - rhs += self.getRHS(tInd + 1) - - sol = Ainv[dt] * rhs - f[:, self._fieldType + "Solution", tInd + 1] = sol - - self.Ainv = Ainv - self._stashed_fields = f - return f - +@delayed +def field_projection(field_array, src_list, array_ind, time_ind, func): + fieldI = field_array[:, :, array_ind] + if fieldI.shape[0] == fieldI.size: + fieldI = mkvc(fieldI, 2) + new_array = func(fieldI, src_list, time_ind) + if new_array.ndim == 1: + new_array = new_array[:, np.newaxis, np.newaxis] + elif new_array.ndim == 2: + new_array = new_array[:, :, np.newaxis] -Sim.fields = fields + return new_array @delayed @@ -140,35 +320,6 @@ def source_evaluation(simulation, sources, time_channel): return s_m, s_e -def dask_getSourceTerm(self, tInd): - """ - Assemble the source term. This ensures that the RHS is a vector / array - of the correct size - """ - source_list = self.survey.source_list - source_block = np.array_split(source_list, cpu_count()) - - block_compute = [] - for block in source_block: - block_compute.append(source_evaluation(self, block, self.times[tInd])) - - blocks = dask.compute(block_compute)[0] - - s_m, s_e = [], [] - for block in blocks: - if block[0]: - s_m.append(block[0]) - s_e.append(block[1]) - - if isinstance(s_m[0][0], Zero): - return Zero(), np.vstack(s_e).T - - return np.vstack(s_m).T, np.vstack(s_e).T - - -Sim.getSourceTerm = dask_getSourceTerm - - @delayed def evaluate_receivers(block, mesh, time_mesh, fields, fields_array): data = [] @@ -182,116 +333,6 @@ def evaluate_receivers(block, mesh, time_mesh, fields, fields_array): return np.hstack(data) -def dask_dpred(self, m=None, f=None): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). - - .. math:: - - d_\\text{pred} = P(f(m)) - - Where P is a projection of the fields onto the data space. - """ - if self.survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) - - if f is None: - if m is None: - m = self.model - f = self.fields(m) - - rows = [] - receiver_projection = self.survey.source_list[0].receiver_list[0].projField - fields_array = f[:, receiver_projection, :] - - if len(self.survey.source_list) == 1: - fields_array = fields_array[:, np.newaxis, :] - - all_receivers = [] - - for ind, src in enumerate(self.survey.source_list): - for rx in src.receiver_list: - all_receivers.append((src, ind, rx)) - - receiver_blocks = np.array_split(all_receivers, cpu_count()) - - for block in receiver_blocks: - n_data = np.sum([rec.nD for _, _, rec in block]) - if n_data == 0: - continue - - rows.append( - array.from_delayed( - evaluate_receivers(block, self.mesh, self.time_mesh, f, fields_array), - dtype=np.float64, - shape=(n_data,), - ) - ) - - data = array.hstack(rows).compute() - - return data - - -Sim.dpred = dask_dpred -Sim.field_derivs = None - - -@delayed -def delayed_block_deriv( - n_times, chunks, field_len, source_list, mesh, time_mesh, fields, shape -): - """Compute derivatives for sources and receivers in a block""" - df_duT = [] - j_updates = [] - - for indices, arrays in chunks: - j_update = 0.0 - source = source_list[indices[0]] - receiver = source.receiver_list[indices[1]] - - spatialP = receiver.getSpatialP(mesh, fields) - timeP = receiver.getTimeP(time_mesh, fields) - - derivative_fun = getattr(fields, "_{}Deriv".format(receiver.projField), None) - time_derivs = [] - for time_index in range(n_times + 1): - if len(timeP[:, time_index].data) == 0: - time_derivs.append( - sp.csr_matrix((field_len, len(arrays[0])), dtype=np.float32) - ) - j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) - continue - - projection = sp.kron(timeP[:, time_index], spatialP, format="csr") - cur = derivative_fun( - time_index, - source, - None, - projection.T, - adjoint=True, - ) - - time_derivs.append(cur[0][:, arrays[0]]) - - if not isinstance(cur[1], Zero): - j_update += cur[1].T - else: - j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) - - j_updates.append(j_update) - df_duT.append(time_derivs) - - return df_duT, j_updates - - def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): """ Compute the derivative of the fields @@ -338,23 +379,6 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): return df_duT, Jmatrix -@delayed -def deriv_block( - s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, field_derivs, tInd -): - if (s_id, r_id, b_id) not in ATinv_df_duT_v: - # last timestep (first to be solved) - stacked_block = field_derivs.toarray()[:, local_ind] - - else: - stacked_block = np.asarray( - field_derivs[:, local_ind] - - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] - ) - - return stacked_block - - def update_deriv_blocks(address, indices, derivatives, solve, shape): if address not in derivatives: deriv_array = np.zeros(shape) @@ -441,6 +465,71 @@ def get_field_deriv_block( return ATinv_df_duT_v +@delayed +def delayed_block_deriv( + n_times, chunks, field_len, source_list, mesh, time_mesh, fields, shape +): + """Compute derivatives for sources and receivers in a block""" + df_duT = [] + j_updates = [] + + for indices, arrays in chunks: + j_update = 0.0 + source = source_list[indices[0]] + receiver = source.receiver_list[indices[1]] + + spatialP = receiver.getSpatialP(mesh, fields) + timeP = receiver.getTimeP(time_mesh, fields) + + derivative_fun = getattr(fields, "_{}Deriv".format(receiver.projField), None) + time_derivs = [] + for time_index in range(n_times + 1): + if len(timeP[:, time_index].data) == 0: + time_derivs.append( + sp.csr_matrix((field_len, len(arrays[0])), dtype=np.float32) + ) + j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) + continue + + projection = sp.kron(timeP[:, time_index], spatialP, format="csr") + cur = derivative_fun( + time_index, + source, + None, + projection.T, + adjoint=True, + ) + + time_derivs.append(cur[0][:, arrays[0]]) + + if not isinstance(cur[1], Zero): + j_update += cur[1].T + else: + j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) + + j_updates.append(j_update) + df_duT.append(time_derivs) + + return df_duT, j_updates + + +@delayed +def deriv_block( + s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, field_derivs, tInd +): + if (s_id, r_id, b_id) not in ATinv_df_duT_v: + # last timestep (first to be solved) + stacked_block = field_derivs.toarray()[:, local_ind] + + else: + stacked_block = np.asarray( + field_derivs[:, local_ind] + - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] + ) + + return stacked_block + + @delayed def compute_rows( simulation, @@ -493,102 +582,3 @@ def compute_rows( rows.append(row_block) return np.vstack(rows) - - -def compute_J(self, m, f=None): - """ - Compute the rows for the sensitivity matrix. - """ - if f is None: - f = self.fields(m) - - ftype = self._fieldType + "Solution" - sens_name = self.sensitivity_path[:-5] - if self.store_sensitivities == "disk": - rows = array.zeros( - (self.survey.nD, m.size), - chunks=(self.max_chunk_size, m.size), - dtype=np.float32, - ) - Jmatrix = array.to_zarr( - rows, - os.path.join(sens_name + "_1.zarr"), - compute=True, - return_stored=True, - overwrite=True, - ) - else: - Jmatrix = np.zeros((self.survey.nD, m.size), dtype=np.float64) - - simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 - data_times = self.survey.source_list[0].receiver_list[0].times - compute_row_size = np.ceil(self.max_chunk_size / (m.shape[0] * 8.0 * 1e-6)) - blocks = get_parallel_blocks(self.survey.source_list, compute_row_size) - fields_array = f[:, ftype, :] - - if len(self.survey.source_list) == 1: - fields_array = fields_array[:, np.newaxis, :] - - times_field_derivs, Jmatrix = compute_field_derivs( - self, f, blocks, Jmatrix, fields_array.shape - ) - - ATinv_df_duT_v = {} - for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): - AdiagTinv = self.Ainv[dt] - j_row_updates = [] - time_mask = data_times > simulation_times[tInd] - - if not np.any(time_mask): - continue - - for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): - ATinv_df_duT_v = get_field_deriv_block( - self, block, field_deriv, tInd, AdiagTinv, ATinv_df_duT_v, time_mask - ) - - if len(block) == 0: - continue - - j_row_updates.append( - array.from_delayed( - compute_rows( - self, - tInd, - block, - ATinv_df_duT_v, - fields_array, - time_mask, - ), - dtype=np.float32, - shape=( - np.sum([len(chunk[1][0]) for chunk in block]), - m.size, - ), - ) - ) - - if self.store_sensitivities == "disk": - sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" - array.to_zarr( - Jmatrix + array.vstack(j_row_updates), - sens_name, - compute=True, - overwrite=True, - ) - Jmatrix = array.from_zarr(sens_name) - else: - Jmatrix += array.vstack(j_row_updates).compute() - - for A in self.Ainv.values(): - A.clean() - - if self.store_sensitivities == "ram": - self._Jmatrix = np.asarray(Jmatrix) - - self._Jmatrix = Jmatrix - - return self._Jmatrix - - -Sim.compute_J = compute_J diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index a8dd98f6e0..7272443e41 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -9,7 +9,7 @@ from ..objective_function import ComboObjectiveFunction -def get_dpred(self, m, f=None, compute_J=False): +def get_dpred(self, m, f=None): dpreds = [] for i, objfct in enumerate(self.dmisfit.objfcts): @@ -21,9 +21,6 @@ def get_dpred(self, m, f=None, compute_J=False): future = objfct.simulation.dpred(m, f=fields) - if compute_J: - objfct.simulation.compute_J(m, f=fields) - dpreds += [future] if isinstance(dpreds[0], Future): @@ -45,7 +42,7 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): # if isinstance(self.dmisfit, BaseDataMisfit): phi_d = self.dmisfit(m, f=fields) - self.dpred = self.get_dpred(m, f=fields, compute_J=return_H) + self.dpred = self.get_dpred(m, f=fields) phi_d = 0 for (_, objfct), pred in zip(self.dmisfit, self.dpred): diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 42ee1c1328..5237571fcf 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -1,135 +1,114 @@ import numpy as np from ...potential_fields.base import BasePFSimulation as Sim +from ..simulation import BaseSimulation import os from dask import delayed, array, config from dask.diagnostics import ProgressBar from ..utils import compute_chunk_sizes -from simpeg.dask.simulation import dask_getJtJdiag - -Sim.getJtJdiag = dask_getJtJdiag -Sim._chunk_format = "row" - - -@property -def chunk_format(self): - "Apply memory chunks along rows of G, either 'equal', 'row', or 'auto'" - return self._chunk_format - - -@chunk_format.setter -def chunk_format(self, other): - if other not in ["equal", "row", "auto"]: - raise ValueError("Chunk format must be 'equal', 'row', or 'auto'") - self._chunk_format = other - - -Sim.chunk_format = chunk_format - - -def dask_dpred(self, m=None, f=None): - if m is not None: - self.model = m - if f is not None: - return f - return self.fields(self.model) - - -Sim.dpred = dask_dpred - - -def dask_residual(self, m, dobs, f=None): - return self.dpred(m, f=f) - dobs - - -Sim.residual = dask_residual - - -def dask_linear_operator(self): - forward_only = self.store_sensitivities == "forward_only" - row = delayed(self.evaluate_integral, pure=True) - n_cells = self.nC - if getattr(self, "model_type", None) == "vector": - n_cells *= 3 - - rows = [ - array.from_delayed( - row(receiver_location, components), - dtype=self.sensitivity_dtype, - shape=(len(components),) if forward_only else (len(components), n_cells), - ) - for receiver_location, components in self.survey._location_component_iterator() - ] - if forward_only: - stack = array.concatenate(rows) - else: - stack = array.vstack(rows) - # Chunking options - if self.chunk_format == "row": - config.set({"array.chunk-size": f"{self.max_chunk_size}MiB"}) - # Autochunking by rows is faster and more memory efficient for - # very large problems sensitivty and forward calculations - stack = stack.rechunk({0: "auto", 1: -1}) - elif self.chunk_format == "equal": - # Manual chunks for equal number of blocks along rows and columns. - # Optimal for Jvec and Jtvec operations - row_chunk, col_chunk = compute_chunk_sizes( - *stack.shape, self.max_chunk_size +class BasePFSimulation(BaseSimulation, Sim): + + _chunk_format = "row" + + @property + def chunk_format(self): + "Apply memory chunks along rows of G, either 'equal', 'row', or 'auto'" + return self._chunk_format + + @chunk_format.setter + def chunk_format(self, other): + if other not in ["equal", "row", "auto"]: + raise ValueError("Chunk format must be 'equal', 'row', or 'auto'") + self._chunk_format = other + + def dpred(self, m=None, f=None): + if m is not None: + self.model = m + if f is not None: + return f + return self.fields(self.model) + + def residual(self, m, dobs, f=None): + return self.dpred(m, f=f) - dobs + + def linear_operator(self): + forward_only = self.store_sensitivities == "forward_only" + row = delayed(self.evaluate_integral, pure=True) + n_cells = self.nC + if getattr(self, "model_type", None) == "vector": + n_cells *= 3 + + rows = [ + array.from_delayed( + row(receiver_location, components), + dtype=self.sensitivity_dtype, + shape=( + (len(components),) if forward_only else (len(components), n_cells) + ), ) - stack = stack.rechunk((row_chunk, col_chunk)) + for receiver_location, components in self.survey._location_component_iterator() + ] + if forward_only: + stack = array.concatenate(rows) else: - # Auto chunking by columns is faster for Inversions - config.set({"array.chunk-size": f"{self.max_chunk_size}MiB"}) - stack = stack.rechunk({0: -1, 1: "auto"}) - - if self.store_sensitivities == "disk": - sens_name = os.path.join(self.sensitivity_path, "sensitivity.zarr") - if os.path.exists(sens_name): - kernel = array.from_zarr(sens_name) - if np.all( - np.r_[ - np.any(np.r_[kernel.chunks[0]] == stack.chunks[0]), - np.any(np.r_[kernel.chunks[1]] == stack.chunks[1]), - np.r_[kernel.shape] == np.r_[stack.shape], - ] - ): - # Check that loaded kernel matches supplied data and mesh - print("Zarr file detected with same shape and chunksize ... re-loading") - return kernel - - print("Writing Zarr file to disk") - with ProgressBar(): - print("Saving kernel to zarr: " + sens_name) - kernel = array.to_zarr( - stack, sens_name, compute=True, return_stored=True, overwrite=True - ) - elif forward_only: - with ProgressBar(): - print("Forward calculation: ") - kernel = stack.compute() - else: - with ProgressBar(): - print("Computing sensitivities to local ram") - kernel = stack.persist() - return kernel - - -Sim.linear_operator = dask_linear_operator - - -def compute_J(self, _): - return self.linear_operator() - - -Sim.compute_J = compute_J - - -@property -def Jmatrix(self): - if getattr(self, "_Jmatrix", None) is None: - self._Jmatrix = self.linear_operator() - return self._Jmatrix - - -Sim.Jmatrix = Jmatrix + stack = array.vstack(rows) + # Chunking options + if self.chunk_format == "row": + config.set({"array.chunk-size": f"{self.max_chunk_size}MiB"}) + # Autochunking by rows is faster and more memory efficient for + # very large problems sensitivty and forward calculations + stack = stack.rechunk({0: "auto", 1: -1}) + elif self.chunk_format == "equal": + # Manual chunks for equal number of blocks along rows and columns. + # Optimal for Jvec and Jtvec operations + row_chunk, col_chunk = compute_chunk_sizes( + *stack.shape, self.max_chunk_size + ) + stack = stack.rechunk((row_chunk, col_chunk)) + else: + # Auto chunking by columns is faster for Inversions + config.set({"array.chunk-size": f"{self.max_chunk_size}MiB"}) + stack = stack.rechunk({0: -1, 1: "auto"}) + + if self.store_sensitivities == "disk": + sens_name = os.path.join(self.sensitivity_path, "sensitivity.zarr") + if os.path.exists(sens_name): + kernel = array.from_zarr(sens_name) + if np.all( + np.r_[ + np.any(np.r_[kernel.chunks[0]] == stack.chunks[0]), + np.any(np.r_[kernel.chunks[1]] == stack.chunks[1]), + np.r_[kernel.shape] == np.r_[stack.shape], + ] + ): + # Check that loaded kernel matches supplied data and mesh + print( + "Zarr file detected with same shape and chunksize ... re-loading" + ) + return kernel + + print("Writing Zarr file to disk") + with ProgressBar(): + print("Saving kernel to zarr: " + sens_name) + kernel = array.to_zarr( + stack, sens_name, compute=False, return_stored=True, overwrite=True + ) + elif forward_only: + # with ProgressBar(): + # print("Forward calculation: ") + kernel = stack # .compute() + else: + # with ProgressBar(): + # print("Computing sensitivities to local ram") + kernel = stack # .persist() + return kernel + + def compute_J(self, _): + return self.linear_operator() + + @property + def Jmatrix(self): + if getattr(self, "_Jmatrix", None) is None: + self._Jmatrix = self.linear_operator() + return self._Jmatrix diff --git a/simpeg/dask/potential_fields/gravity/simulation.py b/simpeg/dask/potential_fields/gravity/simulation.py index b380579ce0..a5a12ce135 100644 --- a/simpeg/dask/potential_fields/gravity/simulation.py +++ b/simpeg/dask/potential_fields/gravity/simulation.py @@ -1,11 +1,11 @@ import numpy as np from dask import array from ....potential_fields.gravity import Simulation3DIntegral as Sim -from ...simulation import BaseSimulation +from ..base import BasePFSimulation from ....utils import sdiag, mkvc -class Simulation3DIntegral(BaseSimulation, Sim): +class Simulation3DIntegral(BasePFSimulation, Sim): """ Overload the Simulation3DIntegral class to use Dask """ diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index 2dff14ae10..9ca8d203dc 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -1,12 +1,11 @@ import numpy as np +from dask import array from ....potential_fields.magnetics import Simulation3DIntegral as Sim +from ..base import BasePFSimulation from ....utils import sdiag, mkvc -from ..base import Jmatrix -Sim.Jmatrix = Jmatrix - -class Simulation3DIntegral(Sim): +class Simulation3DIntegral(Sim, BasePFSimulation): """ Overwrite the dask_getJtJdiag method """ @@ -22,9 +21,11 @@ def getJtJdiag(self, m, W=None, f=None): W = np.ones(self.nD) else: W = W.diagonal() - if getattr(self, "_jtj_diag", None) is None: + if getattr(self, "_gtg_diagonal", None) is None: if not self.is_amplitude_data: - diag = ((W[:, None] * self.Jmatrix) ** 2).sum(axis=0).compute() + diag = array.einsum( + "i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix + ).compute() else: ampDeriv = self.ampDeriv J = ( @@ -33,8 +34,8 @@ def getJtJdiag(self, m, W=None, f=None): + ampDeriv[2, :, None] * self.Jmatrix[2::3] ) diag = ((W[:, None] * J) ** 2).sum(axis=0).compute() - self._jtj_diag = diag + self._gtg_diagonal = diag else: - diag = self._jtj_diag + diag = self._gtg_diagonal return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0)) diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index 4ce775d4dc..e1677dc8f2 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -1,6 +1,6 @@ from ..simulation import BaseSimulation as Sim -from dask import array, delayed +from dask import array import numpy as np @@ -10,6 +10,8 @@ class BaseSimulation(Sim): Base class for SimPEG simulations """ + clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] + sensitivity_path = "./sensitivity/" _max_ram = 16 _max_chunk_size = 128 @@ -52,50 +54,6 @@ def getJtJdiag(self, m, W=None, f=None): return self._jtj_diag - # def dpred(self, m=None, f=None): - # r""" - # dpred(m, f=None) - # Create the projected data from a model. - # The fields, f, (if provided) will be used for the predicted data - # instead of recalculating the fields (which may be expensive!). - # - # .. math:: - # - # d_\\text{pred} = P(f(m)) - # - # Where P is a projection of the fields onto the data space. - # """ - # if self.survey is None: - # raise AttributeError( - # "The survey has not yet been set and is required to compute " - # "data. Please set the survey for the simulation: " - # "simulation.survey = survey" - # ) - # - # if f is None: - # if m is None: - # m = self.model - # f = self.fields(m) - # - # def evaluate_receiver(source, receiver, mesh, fields): - # return receiver.eval(source, mesh, fields).flatten() - # - # row = delayed(evaluate_receiver, pure=True) - # rows = [] - # for src in self.survey.source_list: - # for rx in src.receiver_list: - # rows.append( - # array.from_delayed( - # row(src, rx, self.mesh, f), - # dtype=np.float32, - # shape=(rx.nD,), - # ) - # ) - # - # data = array.hstack(rows) - # - # return data - def Jvec(self, m, v, **_): """ Compute sensitivity matrix (J) and vector (v) product. @@ -128,88 +86,3 @@ def Jmatrix(self): self._Jmatrix = self.compute_J(self.model) return self._Jmatrix - - -def dask_Jvec(self, m, v, **_): - """ - Compute sensitivity matrix (J) and vector (v) product. - """ - self.model = m - - if isinstance(self.Jmatrix, np.ndarray): - return self.Jmatrix @ v.astype(np.float32) - - return array.dot(self.Jmatrix, v).astype(np.float32) - - -def dask_Jtvec(self, m, v, **_): - """ - Compute adjoint sensitivity matrix (J^T) and vector (v) product. - """ - self.model = m - - if isinstance(self.Jmatrix, np.ndarray): - return self.Jmatrix.T @ v.astype(np.float32) - - return array.dot(v, self.Jmatrix).astype(np.float32) - - -def dask_getJtJdiag(self, m, W=None, f=None): - """ - Return the diagonal of JtJ - """ - if getattr(self, "_jtjdiag", None) is None: - self.model = m - if W is None: - W = np.ones(self.Jmatrix.shape[0]) - else: - W = W.diagonal() - - self._jtj_diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) - - return self._jtj_diag - - -def dask_dpred(self, m=None, f=None): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). - - .. math:: - - d_\\text{pred} = P(f(m)) - - Where P is a projection of the fields onto the data space. - """ - if self.survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) - - if f is None: - if m is None: - m = self.model - f = self.fields(m) - - def evaluate_receiver(source, receiver, mesh, fields): - return receiver.eval(source, mesh, fields).flatten() - - row = delayed(evaluate_receiver, pure=True) - rows = [] - for src in self.survey.source_list: - for rx in src.receiver_list: - rows.append( - array.from_delayed( - row(src, rx, self.mesh, f), - dtype=np.float32, - shape=(rx.nD,), - ) - ) - - data = array.hstack(rows).compute() - - return data diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index c144dded7c..8041d5ef48 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -8,7 +8,7 @@ import warnings import os import scipy.sparse as sp - +from ..meta.simulation import MetaSimulation from ..typing import RandomSeed from ..data_misfit import BaseDataMisfit @@ -3501,7 +3501,11 @@ def endIter(self): self.opt.upper[indices[nC:]] = np.inf for simulation in self.simulations: - simulation.chiMap = SphericalSystem() * simulation.chiMap + if isinstance(simulation, MetaSimulation): + for sim in simulation.simulations: + sim.chiMap = SphericalSystem() * sim.chiMap + else: + simulation.chiMap = SphericalSystem() * simulation.chiMap # Add and update directives for directive in self.inversion.directiveList.dList: diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index c888d9c330..a41b1463d6 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -33,13 +33,6 @@ def _calc_dpred(mapping, sim, model, field, apply_map=False): return sim.dpred(m=sim.model, f=field) -def _compute_J(sim, m, field): - if getattr(sim, "_Jmatrix", None) is not None: - return sim.Jmatrix - - return sim.compute_J(m, field) - - def _j_vec_op(mapping, sim, model, field, v, apply_map=False): # return array.from_array(np.zeros(100)) sim_v = mapping.deriv(model) @ v @@ -63,7 +56,7 @@ def _get_jtj_diag(mapping, sim, model, field, w, apply_map=False): jtj = sim.getJtJdiag(mapping @ model, w, f=field) else: jtj = sim.getJtJdiag(sim.model, w, f=field) - sim_jtj = sp.diags(np.sqrt(jtj)) + sim_jtj = sp.diags(np.sqrt(np.asarray(jtj))) m_deriv = mapping.deriv(model) return np.asarray((sim_jtj @ m_deriv).power(2).sum(axis=0)).flatten() @@ -450,24 +443,6 @@ def getJtJdiag(self, m, W=None, f=None): return self._jtjdiag - def compute_J(self, m, f=None): - self.model = m - if f is None: - f = self.fields(m) - J = [] - client = self.client - for sim, worker, field in zip(self.simulations, self._workers, f): - J.append( - client.submit( - _compute_J, - sim, - m, - field, - workers=worker, - ) - ) - return self.client.gather(J) - class DaskSumMetaSimulation(DaskMetaSimulation, SumMetaSimulation): """A dask distributed version of :class:`.SumMetaSimulation`. diff --git a/simpeg/meta/simulation.py b/simpeg/meta/simulation.py index 75bbd47111..5cdbb4eb9c 100644 --- a/simpeg/meta/simulation.py +++ b/simpeg/meta/simulation.py @@ -326,19 +326,6 @@ def getJtJdiag(self, m, W=None, f=None): return self._jtjdiag - def compute_J(self, m, f=None): - self.model = m - if f is None: - f = self.fields(m) - J = [] - for sim, field in zip(self.simulations, f): - if getattr(sim, "_Jmatrix", None) is not None: - continue - J.append( - sim.compute_J(m, field), - ) - return J - @property def deleteTheseOnModelUpdate(self): return super().deleteTheseOnModelUpdate + ["_jtjdiag"] diff --git a/simpeg/potential_fields/gravity/simulation.py b/simpeg/potential_fields/gravity/simulation.py index 0cd0a95652..6643dc1f67 100644 --- a/simpeg/potential_fields/gravity/simulation.py +++ b/simpeg/potential_fields/gravity/simulation.py @@ -166,7 +166,7 @@ def fields(self, m): if self.engine == "choclo": fields = self._forward(self.rho) else: - fields = mkvc(self.linear_operator()) + fields = self.linear_operator() else: fields = self.G @ (self.rho).astype(self.sensitivity_dtype, copy=False) return np.asarray(fields) From e3330353f69b625b60d74a056a2e75c227f2c0f8 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 13 Dec 2024 10:31:58 -0800 Subject: [PATCH 17/84] Start moving compute outside of delay --- simpeg/dask/inverse_problem.py | 10 +++++++--- .../dask/potential_fields/gravity/simulation.py | 16 ++++++++++------ simpeg/directives/directives.py | 2 +- simpeg/meta/dask_sim.py | 6 ++++-- simpeg/potential_fields/gravity/simulation.py | 2 +- 5 files changed, 23 insertions(+), 13 deletions(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 7272443e41..417d85aeea 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -2,7 +2,7 @@ import numpy as np -from dask.distributed import Future, get_client +from dask.distributed import get_client from scipy.sparse.linalg import LinearOperator from ..regularization import WeightedLeastSquares, Sparse @@ -23,11 +23,15 @@ def get_dpred(self, m, f=None): dpreds += [future] - if isinstance(dpreds[0], Future): + try: client = get_client() + except ValueError: + client = None + + if client is not None: dpreds = client.gather(dpreds) - return dpreds + return np.asarray(dpreds) BaseInvProblem.get_dpred = get_dpred diff --git a/simpeg/dask/potential_fields/gravity/simulation.py b/simpeg/dask/potential_fields/gravity/simulation.py index a5a12ce135..934245abdc 100644 --- a/simpeg/dask/potential_fields/gravity/simulation.py +++ b/simpeg/dask/potential_fields/gravity/simulation.py @@ -1,8 +1,8 @@ import numpy as np -from dask import array +from dask import array, delayed from ....potential_fields.gravity import Simulation3DIntegral as Sim from ..base import BasePFSimulation -from ....utils import sdiag, mkvc +from scipy.sparse import csr_matrix as csr class Simulation3DIntegral(BasePFSimulation, Sim): @@ -22,11 +22,15 @@ def getJtJdiag(self, m, W=None, f=None): W = W.diagonal() if getattr(self, "_gtg_diagonal", None) is None: - diag = array.einsum( - "i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix - ).compute() + diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) self._gtg_diagonal = diag else: diag = self._gtg_diagonal - return mkvc((sdiag(np.sqrt(diag)) @ self.rhoDeriv).power(2).sum(axis=0)) + mapping_deriv = self.rhoDeriv.tocsr().T.power(2) + dmudm_jtvec = delayed(csr.dot)(mapping_deriv, diag) + jtjdiag = array.from_delayed( + dmudm_jtvec, dtype=np.float32, shape=[mapping_deriv.shape[1]] + ) + + return jtjdiag diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 8041d5ef48..077e10bebe 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -68,7 +68,7 @@ def compute_JtJdiags(data_misfit, m): for multiplier, diag in zip(data_misfit.multipliers, jtj_diags): jtj_diag += multiplier * diag - return jtj_diag + return np.asarray(jtj_diag) class InversionDirective: diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index a41b1463d6..68d10ee6ed 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -37,9 +37,9 @@ def _j_vec_op(mapping, sim, model, field, v, apply_map=False): # return array.from_array(np.zeros(100)) sim_v = mapping.deriv(model) @ v if apply_map: - return array.compute(sim.Jvec(mapping @ model, sim_v, f=field)) + return sim.Jvec(mapping @ model, sim_v, f=field) else: - return array.compute(sim.Jvec(sim.model, sim_v, f=field)) + return sim.Jvec(sim.model, sim_v, f=field) def _jt_vec_op(mapping, sim, model, field, v, start, end, apply_map=False): @@ -47,6 +47,8 @@ def _jt_vec_op(mapping, sim, model, field, v, start, end, apply_map=False): jtv = sim.Jtvec(mapping @ model, v[start:end], f=field) else: jtv = sim.Jtvec(sim.model, v[start:end], f=field) + + # Need to delay this operation until the future is computed return mapping.deriv(model).T @ array.compute(jtv)[0] diff --git a/simpeg/potential_fields/gravity/simulation.py b/simpeg/potential_fields/gravity/simulation.py index 6643dc1f67..a2fc162e03 100644 --- a/simpeg/potential_fields/gravity/simulation.py +++ b/simpeg/potential_fields/gravity/simulation.py @@ -169,7 +169,7 @@ def fields(self, m): fields = self.linear_operator() else: fields = self.G @ (self.rho).astype(self.sensitivity_dtype, copy=False) - return np.asarray(fields) + return fields def getJtJdiag(self, m, W=None, f=None): """ From 5e738d1431005558a9451c3ed927c5a9bd58d7f8 Mon Sep 17 00:00:00 2001 From: domfournier Date: Sat, 14 Dec 2024 13:03:59 -0800 Subject: [PATCH 18/84] Remove future for fields, assume stashed on simulation --- simpeg/dask/data_misfit.py | 19 +++++++---- simpeg/dask/inverse_problem.py | 37 +++++----------------- simpeg/dask/simulation.py | 1 + simpeg/dask/utils.py | 14 -------- simpeg/meta/dask_sim.py | 58 ++++++++++++++++------------------ 5 files changed, 50 insertions(+), 79 deletions(-) diff --git a/simpeg/dask/data_misfit.py b/simpeg/dask/data_misfit.py index f251155163..0cb8ae97f1 100644 --- a/simpeg/dask/data_misfit.py +++ b/simpeg/dask/data_misfit.py @@ -2,13 +2,11 @@ from ..data_misfit import L2DataMisfit -from ..utils import mkvc - from dask.distributed import get_client, Future def _data_residual(dpred, dobs): - return mkvc(dpred) - dobs + return dpred - dobs def _misfit(residual, W): @@ -57,7 +55,13 @@ def dask_deriv(self, m, f=None): if isinstance(residuals, Future): client = get_client() - wtw_d = client.submit(_stack_futures, residuals, self.W.diagonal() ** 2.0) + who = client.who_has(residuals) + wtw_d = client.submit( + _stack_futures, + residuals, + self.W.diagonal() ** 2.0, + workers=who[residuals.key], + ) else: wtw_d = self.W.diagonal() ** 2.0 * residuals @@ -70,7 +74,7 @@ def dask_deriv(self, m, f=None): def _stack_futures(futures, W): - return W * np.hstack(futures).flatten() + return W * futures def dask_deriv2(self, m, v, f=None): @@ -80,7 +84,10 @@ def dask_deriv2(self, m, v, f=None): jvec = self.simulation.Jvec(m, v) if isinstance(jvec, Future): client = get_client() - w_jvec = client.submit(_stack_futures, jvec, self.W.diagonal() ** 2.0) + who = client.who_has(jvec) + w_jvec = client.submit( + _stack_futures, jvec, self.W.diagonal() ** 2.0, workers=who[jvec.key] + ) else: w_jvec = self.W.diagonal() ** 2.0 * jvec diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 417d85aeea..94bdea857b 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -12,26 +12,17 @@ def get_dpred(self, m, f=None): dpreds = [] - for i, objfct in enumerate(self.dmisfit.objfcts): - - if f is not None: - fields = f[i] - else: - fields = objfct.simulation.fields(m) - - future = objfct.simulation.dpred(m, f=fields) - - dpreds += [future] + for objfct in self.dmisfit.objfcts: + dpred = objfct.simulation.dpred(m) + dpreds += [dpred] try: client = get_client() - except ValueError: - client = None - - if client is not None: dpreds = client.gather(dpreds) + except ValueError: + pass - return np.asarray(dpreds) + return dpreds BaseInvProblem.get_dpred = get_dpred @@ -41,21 +32,13 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): """evalFunction(m, return_g=True, return_H=True)""" self.model = m - # Store fields if doing a line-search - fields = self.getFields(m, store=(return_g is False and return_H is False)) - - # if isinstance(self.dmisfit, BaseDataMisfit): - phi_d = self.dmisfit(m, f=fields) - self.dpred = self.get_dpred(m, f=fields) + self.dpred = self.get_dpred(m) phi_d = 0 for (_, objfct), pred in zip(self.dmisfit, self.dpred): residual = objfct.W * (objfct.data.dobs - pred) phi_d += np.vdot(residual, residual) - phi_d = np.asarray(phi_d) - # print(self.dpred[0]) - reg2Deriv = [] if isinstance(self.reg, ComboObjectiveFunction): for constant, objfct in self.reg: @@ -107,11 +90,8 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): out = (phi,) if return_g: - phi_dDeriv = self.dmisfit.deriv(m, f=fields) - # if hasattr(self.reg.objfcts[0], "space") and self.reg.objfcts[0].space == "spherical": + phi_dDeriv = self.dmisfit.deriv(m) phi_mDeriv = self.reg.deriv(m) - # else: - # phi_mDeriv = np.sum([reg2Deriv * obj.f_m for reg2Deriv, obj in zip(self.reg2Deriv, self.reg.objfcts)], axis=0) g = np.asarray(phi_dDeriv) + self.beta * phi_mDeriv out += (g,) @@ -121,7 +101,6 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): def H_fun(v): phi_d2Deriv = self.dmisfit.deriv2(m, v) phi_m2Deriv = self.reg2Deriv * v - H = phi_d2Deriv + self.beta * phi_m2Deriv return H diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index e1677dc8f2..06145cbaa9 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -84,5 +84,6 @@ def Jmatrix(self): """ if getattr(self, "_Jmatrix", None) is None: self._Jmatrix = self.compute_J(self.model) + self._stashed_fields = None return self._Jmatrix diff --git a/simpeg/dask/utils.py b/simpeg/dask/utils.py index ad292dc9a2..d2bf546220 100644 --- a/simpeg/dask/utils.py +++ b/simpeg/dask/utils.py @@ -1,5 +1,4 @@ import numpy as np -from dask.distributed import get_client from multiprocessing import cpu_count @@ -27,19 +26,6 @@ def compute_chunk_sizes(M, N, target_chunk_size): return rowChunk, colChunk -def compute(self, job): - """ - Compute dask job for either dask array or client. - """ - if isinstance(job, np.ndarray): - return job - try: - client = get_client() - return client.compute(job, workers=self.workers) - except ValueError: - return job.compute() - - def get_parallel_blocks(source_list: list, data_block_size, optimize=True) -> list: """ Get the blocks of sources and receivers to be computed in parallel. diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index 68d10ee6ed..218c9ee647 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -26,38 +26,38 @@ def _calc_fields(mapping, sim, model, apply_map=False): return sim.fields(m=sim.model) -def _calc_dpred(mapping, sim, model, field, apply_map=False): +def _calc_dpred(mapping, sim, model, fields, apply_map=False): if apply_map and model is not None: - return sim.dpred(m=mapping @ model) + return array.compute(sim.dpred(m=mapping @ model, f=fields))[0] else: - return sim.dpred(m=sim.model, f=field) + return array.compute(sim.dpred(m=sim.model, f=fields))[0] -def _j_vec_op(mapping, sim, model, field, v, apply_map=False): +def _j_vec_op(mapping, sim, model, v, apply_map=False): # return array.from_array(np.zeros(100)) sim_v = mapping.deriv(model) @ v if apply_map: - return sim.Jvec(mapping @ model, sim_v, f=field) + return array.compute(sim.Jvec(mapping @ model, sim_v))[0] else: - return sim.Jvec(sim.model, sim_v, f=field) + return array.compute(sim.Jvec(sim.model, sim_v))[0] -def _jt_vec_op(mapping, sim, model, field, v, start, end, apply_map=False): +def _jt_vec_op(mapping, sim, model, v, start, end, apply_map=False): if apply_map: - jtv = sim.Jtvec(mapping @ model, v[start:end], f=field) + jtv = sim.Jtvec(mapping @ model, v[start:end]) else: - jtv = sim.Jtvec(sim.model, v[start:end], f=field) + jtv = sim.Jtvec(sim.model, v[start:end]) # Need to delay this operation until the future is computed return mapping.deriv(model).T @ array.compute(jtv)[0] -def _get_jtj_diag(mapping, sim, model, field, w, apply_map=False): +def _get_jtj_diag(mapping, sim, model, w, apply_map=False): w = sp.diags(w) if apply_map: - jtj = sim.getJtJdiag(mapping @ model, w, f=field) + jtj = sim.getJtJdiag(mapping @ model, w) else: - jtj = sim.getJtJdiag(sim.model, w, f=field) + jtj = sim.getJtJdiag(sim.model, w) sim_jtj = sp.diags(np.sqrt(np.asarray(jtj))) m_deriv = mapping.deriv(model) return np.asarray((sim_jtj @ m_deriv).power(2).sum(axis=0)).flatten() @@ -355,42 +355,40 @@ def dpred(self, m=None, f=None): workers=worker, ) ) - return _reduce(client, array.hstack, dpred) + return _reduce(client, np.concatenate, dpred) def Jvec(self, m, v, f=None): self.model = m m_future = self._m_as_future - if f is None: - f = self.fields(m) + # if f is None: + # f = self.fields(m) client = self.client [v_future] = client.scatter([v], broadcast=True) j_vec = [] - for mapping, sim, worker, field in zip( - self.mappings, self.simulations, self._workers, f - ): + for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers): j_vec.append( client.submit( _j_vec_op, mapping, sim, m_future, - field, + # field, v_future, self._repeat_sim, workers=worker, ) ) - return _reduce(client, array.hstack, j_vec) + return _reduce(client, np.concatenate, j_vec) def Jtvec(self, m, v, f=None): self.model = m m_future = self._m_as_future - if f is None: - f = self.fields(m) + # if f is None: + # f = self.fields(m) jt_vec = [] client = self.client - for i, (mapping, sim, worker, field) in enumerate( - zip(self.mappings, self.simulations, self._workers, f) + for i, (mapping, sim, worker) in enumerate( + zip(self.mappings, self.simulations, self._workers) ): jt_vec.append( client.submit( @@ -398,7 +396,7 @@ def Jtvec(self, m, v, f=None): mapping, sim, m_future, - field, + # field, v, self._data_offsets[i], self._data_offsets[i + 1], @@ -420,10 +418,10 @@ def getJtJdiag(self, m, W=None, f=None): W = W.diagonal() jtj_diag = [] client = self.client - if f is None: - f = self.fields(m) - for i, (mapping, sim, worker, field) in enumerate( - zip(self.mappings, self.simulations, self._workers, f) + # if f is None: + # f = self.fields(m) + for i, (mapping, sim, worker) in enumerate( + zip(self.mappings, self.simulations, self._workers) ): sim_w = W[self._data_offsets[i] : self._data_offsets[i + 1]] # s = client.gather(sim) @@ -434,7 +432,7 @@ def getJtJdiag(self, m, W=None, f=None): mapping, sim, m_future, - field, + # field, sim_w, self._repeat_sim, workers=worker, From 2cd396e446c517a78685f34b1527a4bb1de336bf Mon Sep 17 00:00:00 2001 From: domfournier Date: Sat, 14 Dec 2024 22:02:46 -0800 Subject: [PATCH 19/84] Fix double compute of sensitivities --- simpeg/potential_fields/gravity/simulation.py | 4 ++++ simpeg/potential_fields/magnetics/simulation.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/simpeg/potential_fields/gravity/simulation.py b/simpeg/potential_fields/gravity/simulation.py index a2fc162e03..789f9405f5 100644 --- a/simpeg/potential_fields/gravity/simulation.py +++ b/simpeg/potential_fields/gravity/simulation.py @@ -216,6 +216,10 @@ def G(self): Gravity forward operator """ if getattr(self, "_G", None) is None: + if self._Jmatrix is not None: + self._G = self._Jmatrix + return self._G + if self.engine == "choclo": self._G = self._sensitivity_matrix() else: diff --git a/simpeg/potential_fields/magnetics/simulation.py b/simpeg/potential_fields/magnetics/simulation.py index f960e63810..9a162cc78c 100644 --- a/simpeg/potential_fields/magnetics/simulation.py +++ b/simpeg/potential_fields/magnetics/simulation.py @@ -199,11 +199,17 @@ def fields(self, model): @property def G(self): if getattr(self, "_G", None) is None: + if self._Jmatrix is not None: + self._G = self._Jmatrix + return self._G + if self.engine == "choclo": self._G = self._sensitivity_matrix() else: self._G = self.linear_operator() + self._Jmatrix = self._G + return self._G modelType = deprecate_property( From b51553875f7b45eed8030241c36a2485e008e5e4 Mon Sep 17 00:00:00 2001 From: domfournier Date: Sat, 14 Dec 2024 23:27:09 -0800 Subject: [PATCH 20/84] Revert back to overload dask methods --- .../frequency_domain/simulation.py | 407 +++++++++--------- .../static/induced_polarization/simulation.py | 181 ++++---- .../induced_polarization/simulation_2d.py | 91 ++-- .../static/resistivity/simulation.py | 242 +++++------ .../static/resistivity/simulation_2d.py | 326 +++++++------- .../time_domain/simulation.py | 350 ++++++++------- simpeg/dask/potential_fields/base.py | 218 +++++----- .../potential_fields/gravity/simulation.py | 52 ++- .../potential_fields/magnetics/simulation.py | 56 ++- simpeg/dask/simulation.py | 162 +++---- 10 files changed, 1051 insertions(+), 1034 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index f1f75497e1..bb666c169d 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -1,5 +1,5 @@ from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim -from ...simulation import BaseSimulation + from ....utils import Zero import numpy as np import scipy.sparse as sp @@ -7,13 +7,9 @@ from dask import array, compute, delayed from simpeg.dask.utils import get_parallel_blocks -from simpeg.electromagnetics.frequency_domain.simulation import ( - Simulation3DMagneticFluxDensity as MagFlux, -) + from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary -from simpeg.electromagnetics.natural_source.simulation import ( - Simulation3DPrimarySecondary as NSPrimarySecondary, -) + import zarr from tqdm import tqdm @@ -102,246 +98,241 @@ def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address return np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T -class BaseFDEMSimulation(BaseSimulation, Sim): - sensitivity_path = "./sensitivity/" - clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] +def getSourceTerm(self, freq, source=None): + """ + Assemble the source term. This ensures that the RHS is a vector / array + of the correct size + """ + if source is None: + source_list = self.survey.get_sources_by_frequency(freq) + source_block = np.array_split(source_list, cpu_count()) - def getSourceTerm(self, freq, source=None): - """ - Assemble the source term. This ensures that the RHS is a vector / array - of the correct size - """ - if source is None: - source_list = self.survey.get_sources_by_frequency(freq) - source_block = np.array_split(source_list, cpu_count()) + block_compute = [] + for block in source_block: + if len(block) == 0: + continue - block_compute = [] - for block in source_block: - if len(block) == 0: - continue + block_compute.append(source_evaluation(self, block)) - block_compute.append(source_evaluation(self, block)) + blocks = compute(block_compute)[0] + s_m, s_e = [], [] + for block in blocks: + if block[0]: + s_m += block[0] + s_e += block[1] - blocks = compute(block_compute)[0] - s_m, s_e = [], [] - for block in blocks: - if block[0]: - s_m += block[0] - s_e += block[1] + else: + sm, se = source.eval(self) + s_m, s_e = [sm], [se] - else: - sm, se = source.eval(self) - s_m, s_e = [sm], [se] + if isinstance(s_m[0][0], Zero): # Assume the rest is all Zero + s_m = Zero() + else: + s_m = np.vstack(s_m) + if s_m.shape[0] < s_m.shape[1]: + s_m = s_m.T - if isinstance(s_m[0][0], Zero): # Assume the rest is all Zero - s_m = Zero() - else: - s_m = np.vstack(s_m) - if s_m.shape[0] < s_m.shape[1]: - s_m = s_m.T + if isinstance(s_e[0][0], Zero): # Assume the rest is all Zero + s_e = Zero() + else: + s_e = np.vstack(s_e) + if s_e.shape[0] < s_e.shape[1]: + s_e = s_e.T + return s_m, s_e - if isinstance(s_e[0][0], Zero): # Assume the rest is all Zero - s_e = Zero() - else: - s_e = np.vstack(s_e) - if s_e.shape[0] < s_e.shape[1]: - s_e = s_e.T - return s_m, s_e - - def dask_dpred(self, m=None, f=None): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). - - .. math:: - - d_\\text{pred} = P(f(m)) - - Where P is a projection of the fields onto the data space. - """ - if self.survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) - if f is None: - if m is None: - m = self.model - f = self.fields(m) +def dpred(self, m=None, f=None): + r""" + dpred(m, f=None) + Create the projected data from a model. + The fields, f, (if provided) will be used for the predicted data + instead of recalculating the fields (which may be expensive!). - all_receivers = [] + .. math:: - for ind, src in enumerate(self.survey.source_list): - for rx in src.receiver_list: - all_receivers.append((src, ind, rx)) + d_\\text{pred} = P(f(m)) - receiver_blocks = np.array_split(np.asarray(all_receivers), cpu_count()) - rows = [] - mesh = delayed(self.mesh) - for block in receiver_blocks: - n_data = np.sum([rec.nD for _, _, rec in block]) - if n_data == 0: - continue + Where P is a projection of the fields onto the data space. + """ + if self.survey is None: + raise AttributeError( + "The survey has not yet been set and is required to compute " + "data. Please set the survey for the simulation: " + "simulation.survey = survey" + ) - rows.append( - array.from_delayed( - evaluate_receivers(block, mesh, f), - dtype=np.float64, - shape=(n_data,), - ) + if f is None: + if m is None: + m = self.model + f = self.fields(m) + + all_receivers = [] + + for ind, src in enumerate(self.survey.source_list): + for rx in src.receiver_list: + all_receivers.append((src, ind, rx)) + + receiver_blocks = np.array_split(np.asarray(all_receivers), cpu_count()) + rows = [] + mesh = delayed(self.mesh) + for block in receiver_blocks: + n_data = np.sum([rec.nD for _, _, rec in block]) + if n_data == 0: + continue + + rows.append( + array.from_delayed( + evaluate_receivers(block, mesh, f), + dtype=np.float64, + shape=(n_data,), ) + ) - data = compute(array.hstack(rows))[0] + data = compute(array.hstack(rows))[0] - return data + return data - def fields(self, m=None): - if m is not None: - self.model = m - if getattr(self, "_stashed_fields", None) is not None: - return self._stashed_fields +def fields(self, m=None): + if m is not None: + self.model = m - f = self.fieldsPair(self) - Ainv = {} - for freq in self.survey.frequencies: - A = self.getA(freq) - rhs = self.getRHS(freq) - Ainv_solve = self.solver(sp.csr_matrix(A), **self.solver_opts) - u = Ainv_solve * rhs - sources = self.survey.get_sources_by_frequency(freq) - f[sources, self._solutionType] = u - Ainv[freq] = Ainv_solve + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields - self.Ainv = Ainv + f = self.fieldsPair(self) + Ainv = {} + for freq in self.survey.frequencies: + A = self.getA(freq) + rhs = self.getRHS(freq) + Ainv_solve = self.solver(sp.csr_matrix(A), **self.solver_opts) + u = Ainv_solve * rhs + sources = self.survey.get_sources_by_frequency(freq) + f[sources, self._solutionType] = u + Ainv[freq] = Ainv_solve - self._stashed_fields = f + self.Ainv = Ainv - return f + self._stashed_fields = f - def compute_J(self, m, f=None): - self.model = m + return f - if f is None: - f = self.fields(m) - if len(self.Ainv) > 1: - raise NotImplementedError( - "Current implementation of parallelization assumes a single frequency per simulation. " - "Consider creating one misfit per frequency." - ) +def compute_J(self, m, f=None): + self.model = m - A_i = list(self.Ainv.values())[0] - m_size = m.size + if f is None: + f = self.fields(m) - if self.store_sensitivities == "disk": - Jmatrix = zarr.open( - self.sensitivity_path, - mode="w", - shape=(self.survey.nD, m_size), - chunks=(self.max_chunk_size, m_size), - ) - else: - Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) - - compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6)) - blocks = get_parallel_blocks( - self.survey.source_list, compute_row_size, optimize=False + if len(self.Ainv) > 1: + raise NotImplementedError( + "Current implementation of parallelization assumes a single frequency per simulation. " + "Consider creating one misfit per frequency." ) - fields_array = delayed(f[:, self._solutionType]) - fields = delayed(f) - survey = delayed(self.survey) - mesh = delayed(self.mesh) - blocks_receiver_derivs = [] - for block in blocks: - blocks_receiver_derivs.append( - receiver_derivs( - survey, - mesh, - fields, - block, - ) - ) + A_i = list(self.Ainv.values())[0] + m_size = m.size - # Dask process for all derivatives - blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] + if self.store_sensitivities == "disk": + Jmatrix = zarr.open( + self.sensitivity_path, + mode="w", + shape=(self.survey.nD, m_size), + chunks=(self.max_chunk_size, m_size), + ) + else: + Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) - for block_derivs_chunks, addresses_chunks in tqdm( - zip(blocks_receiver_derivs, blocks), - ncols=len(blocks_receiver_derivs), - desc=f"Sensitivities at {list(self.Ainv)[0]} Hz", - ): - Jmatrix = self.parallel_block_compute( - m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks + compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6)) + blocks = get_parallel_blocks( + self.survey.source_list, compute_row_size, optimize=False + ) + fields_array = delayed(f[:, self._solutionType]) + fields = delayed(f) + survey = delayed(self.survey) + mesh = delayed(self.mesh) + blocks_receiver_derivs = [] + + for block in blocks: + blocks_receiver_derivs.append( + receiver_derivs( + survey, + mesh, + fields, + block, ) + ) - for A in self.Ainv.values(): - A.clean() - - if self.store_sensitivities == "disk": - del Jmatrix - self._Jmatrix = array.from_zarr(self.sensitivity_path) - else: - self._Jmatrix = Jmatrix - - return self._Jmatrix + # Dask process for all derivatives + blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] - def parallel_block_compute( - self, m, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses + for block_derivs_chunks, addresses_chunks in tqdm( + zip(blocks_receiver_derivs, blocks), + ncols=len(blocks_receiver_derivs), + desc=f"Sensitivities at {list(self.Ainv)[0]} Hz", ): - m_size = m.size - block_stack = sp.hstack(blocks_receiver_derivs).toarray() - ATinvdf_duT = delayed(A_i * block_stack) - count = 0 - rows = [] - block_delayed = [] - - for address, dfduT in zip(addresses, blocks_receiver_derivs): - n_cols = dfduT.shape[1] - n_rows = address[1][2] - block_delayed.append( - array.from_delayed( - eval_block( - self, - ATinvdf_duT, - np.arange(count, count + n_cols), - Zero(), - fields_array, - address, - ), - dtype=np.float32, - shape=(n_rows, m_size), - ) - ) - count += n_cols - rows += address[1][1].tolist() + Jmatrix = self.parallel_block_compute( + m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks + ) - indices = np.hstack(rows) + for A in self.Ainv.values(): + A.clean() - if self.store_sensitivities == "disk": - Jmatrix.set_orthogonal_selection( - (indices, slice(None)), - compute(array.vstack(block_delayed))[0], + if self.store_sensitivities == "disk": + del Jmatrix + self._Jmatrix = array.from_zarr(self.sensitivity_path) + else: + self._Jmatrix = Jmatrix + + return self._Jmatrix + + +def parallel_block_compute( + self, m, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses +): + m_size = m.size + block_stack = sp.hstack(blocks_receiver_derivs).toarray() + ATinvdf_duT = delayed(A_i * block_stack) + count = 0 + rows = [] + block_delayed = [] + + for address, dfduT in zip(addresses, blocks_receiver_derivs): + n_cols = dfduT.shape[1] + n_rows = address[1][2] + block_delayed.append( + array.from_delayed( + eval_block( + self, + ATinvdf_duT, + np.arange(count, count + n_cols), + Zero(), + fields_array, + address, + ), + dtype=np.float32, + shape=(n_rows, m_size), ) - else: - # Dask process to compute row and store - Jmatrix[indices, :] = compute(array.vstack(block_delayed))[0] + ) + count += n_cols + rows += address[1][1].tolist() - return Jmatrix + indices = np.hstack(rows) + if self.store_sensitivities == "disk": + Jmatrix.set_orthogonal_selection( + (indices, slice(None)), + compute(array.vstack(block_delayed))[0], + ) + else: + # Dask process to compute row and store + Jmatrix[indices, :] = compute(array.vstack(block_delayed))[0] -class Simulation3DMagneticFluxDensity(MagFlux, BaseFDEMSimulation): - """ - Overload the Simulation3DMagneticFluxDensity class to provide the necessary functionality - """ + return Jmatrix -class Simulation3DPrimarySecondary(NSPrimarySecondary, BaseFDEMSimulation): - """ - Overload the Simulation3DPrimarySecondary class to provide the necessary functionality - """ +Sim.parallel_block_compute = parallel_block_compute +Sim.compute_J = compute_J +Sim.fields = fields +Sim.dpred = dpred +Sim.getSourceTerm = getSourceTerm diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py index 2acd63a822..beee5f7338 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py @@ -1,7 +1,13 @@ from .....electromagnetics.static.induced_polarization.simulation import ( - Simulation3DNodal as Sim, + BaseIPSimulation as Sim, ) -from ....simulation import BaseSimulation + +from ..resistivity.simulation import ( + compute_J, + getSourceTerm, +) + + from .....data import Data import dask.array as da from dask.distributed import Future @@ -10,119 +16,118 @@ numcodecs.blosc.use_threads = False -from ..resistivity.simulation import Simulation3DNodal as SimulationDC3D +def fields(self, m=None): + if m is not None: + self.model = m -class Simulation3DNodal(BaseSimulation, Sim): + A = self.getA() + Ainv = self.solver(A, **self.solver_opts) + RHS = self.getRHS() - def fields(self, m=None): - if m is not None: - self.model = m + f = self.fieldsPair(self) + f[:, self._solutionType] = Ainv * RHS - A = self.getA() - Ainv = self.solver(A, **self.solver_opts) - RHS = self.getRHS() + if self._scale is None: + scale = Data(self.survey, np.ones(self.survey.nD)) + # loop through receivers to check if they need to set the _dc_voltage + for src in self.survey.source_list: + for rx in src.receiver_list: + if ( + rx.data_type == "apparent_chargeability" + or self._data_type == "apparent_chargeability" + ): + scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f) + self._scale = scale.dobs - f = self.fieldsPair(self) - f[:, self._solutionType] = Ainv * RHS + self.Ainv = Ainv - if self._scale is None: - scale = Data(self.survey, np.ones(self.survey.nD)) - # loop through receivers to check if they need to set the _dc_voltage - for src in self.survey.source_list: - for rx in src.receiver_list: - if ( - rx.data_type == "apparent_chargeability" - or self._data_type == "apparent_chargeability" - ): - scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f) - self._scale = scale.dobs + return f - self.Ainv = Ainv - return f +def dpred(self, m=None, f=None): + r""" + dpred(m, f=None) + Create the projected data from a model. + The fields, f, (if provided) will be used for the predicted data + instead of recalculating the fields (which may be expensive!). - def dpred(self, m=None, f=None): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). + .. math:: - .. math:: + d_\\text{pred} = P(f(m)) - d_\\text{pred} = P(f(m)) + Where P is a projection of the fields onto the data space. + """ + if self.survey is None: + raise AttributeError( + "The survey has not yet been set and is required to compute " + "data. Please set the survey for the simulation: " + "simulation.survey = survey" + ) - Where P is a projection of the fields onto the data space. - """ - if self.survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) + data = self.Jvec(m, m) - data = self.Jvec(m, m) + return np.asarray(data) - return np.asarray(data) - def getJtJdiag(self, m, W=None): - """ - Return the diagonal of JtJ - """ - self.model = m - if getattr(self, "_jtjdiag", None) is None: - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish +def getJtJdiag(self, m, W=None): + """ + Return the diagonal of JtJ + """ + self.model = m + if getattr(self, "_jtjdiag", None) is None: + if isinstance(self.Jmatrix, Future): + self.Jmatrix # Wait to finish - if W is None: - W = self._scale * np.ones(self.nD) - else: - W = (self._scale * W.diagonal()) ** 2.0 + if W is None: + W = self._scale * np.ones(self.nD) + else: + W = (self._scale * W.diagonal()) ** 2.0 - diag = da.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) + diag = da.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) - if isinstance(diag, da.Array): - diag = np.asarray(diag.compute()) + if isinstance(diag, da.Array): + diag = np.asarray(diag.compute()) - self._jtjdiag = diag + self._jtjdiag = diag - return self._jtjdiag + return self._jtjdiag - def Jvec(self, m, v, f=None): - """ - Compute sensitivity matrix (J) and vector (v) product. - """ - self.model = m - if isinstance(self.Jmatrix, np.ndarray): - return self._scale.astype(np.float32) * ( - self.Jmatrix @ v.astype(np.float32) - ) +def Jvec(self, m, v, f=None): + """ + Compute sensitivity matrix (J) and vector (v) product. + """ + self.model = m - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish + if isinstance(self.Jmatrix, np.ndarray): + return self._scale.astype(np.float32) * (self.Jmatrix @ v.astype(np.float32)) - return self._scale.astype(np.float32) * da.dot(self.Jmatrix, v).astype( - np.float32 - ) + if isinstance(self.Jmatrix, Future): + self.Jmatrix # Wait to finish - def Jtvec(self, m, v, f=None): - """ - Compute adjoint sensitivity matrix (J^T) and vector (v) product. - """ - self.model = m + return self._scale.astype(np.float32) * da.dot(self.Jmatrix, v).astype(np.float32) - if isinstance(self.Jmatrix, np.ndarray): - return (self._scale * v.astype(np.float32)).astype( - np.float32 - ) @ self.Jmatrix - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish +def Jtvec(self, m, v, f=None): + """ + Compute adjoint sensitivity matrix (J^T) and vector (v) product. + """ + self.model = m + + if isinstance(self.Jmatrix, np.ndarray): + return (self._scale * v.astype(np.float32)).astype(np.float32) @ self.Jmatrix + + if isinstance(self.Jmatrix, Future): + self.Jmatrix # Wait to finish - return da.dot(v * self._scale, self.Jmatrix).astype(np.float32) + return da.dot(v * self._scale, self.Jmatrix).astype(np.float32) -Simulation3DNodal.compute_J = SimulationDC3D.compute_J -Simulation3DNodal.getSourceTerm = SimulationDC3D.getSourceTerm +Sim.compute_J = compute_J +Sim.getSourceTerm = getSourceTerm +Sim.Jtvec = Jtvec +Sim.Jvec = Jvec +Sim.getJtJdiag = getJtJdiag +Sim.dpred = dpred +Sim.fields = fields diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py index aa9afade78..0a535b2af6 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py @@ -1,54 +1,53 @@ from .....electromagnetics.static.induced_polarization.simulation import ( Simulation2DNodal as Sim, ) -from ....simulation import BaseSimulation from .....data import Data import numpy as np import numcodecs numcodecs.blosc.use_threads = False - -from ..resistivity.simulation_2d import Simulation2DNodal as SimulationDC2D - - -class Simulation2DNodal(BaseSimulation, Sim): - """ - Overloaded Simulation2DNodal to include the dask methods - """ - - def fields(self, m=None): - if m is not None: - self.model = m - - kys = self._quad_points - f = self.fieldsPair(self) - f._quad_weights = self._quad_weights - - Ainv = {} - for iky, ky in enumerate(kys): - A = self.getA(ky) - Ainv[iky] = self.solver(A, **self.solver_opts) - - RHS = self.getRHS(ky) - f[:, self._solutionType, iky] = Ainv[iky] * RHS - - if self._scale is None: - scale = Data(self.survey, np.ones(self.survey.nD)) - f_fwd = self.fields_to_space(f) - # loop through receievers to check if they need to set the _dc_voltage - for src in self.survey.source_list: - for rx in src.receiver_list: - if ( - rx.data_type == "apparent_chargeability" - or self._data_type == "apparent_chargeability" - ): - scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f_fwd) - self._scale = scale.dobs - - self.Ainv = Ainv - - return f - - -Simulation2DNodal.compute_J = SimulationDC2D.compute_J -Simulation2DNodal.getSourceTerm = SimulationDC2D.getSourceTerm +from .simulation import getJtJdiag, Jvec, Jtvec, dpred +from ..resistivity.simulation_2d import compute_J, getSourceTerm + + +def fields(self, m=None): + if m is not None: + self.model = m + + kys = self._quad_points + f = self.fieldsPair(self) + f._quad_weights = self._quad_weights + + Ainv = {} + for iky, ky in enumerate(kys): + A = self.getA(ky) + Ainv[iky] = self.solver(A, **self.solver_opts) + + RHS = self.getRHS(ky) + f[:, self._solutionType, iky] = Ainv[iky] * RHS + + if self._scale is None: + scale = Data(self.survey, np.ones(self.survey.nD)) + f_fwd = self.fields_to_space(f) + # loop through receievers to check if they need to set the _dc_voltage + for src in self.survey.source_list: + for rx in src.receiver_list: + if ( + rx.data_type == "apparent_chargeability" + or self._data_type == "apparent_chargeability" + ): + scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f_fwd) + self._scale = scale.dobs + + self.Ainv = Ainv + + return f + + +Sim.getJtJdiag = getJtJdiag +Sim.Jvec = Jvec +Sim.Jtvec = Jtvec +Sim.dpred = dpred +Sim.fields = fields +Sim.compute_J = compute_J +Sim.getSourceTerm = getSourceTerm diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index 8b43f07644..bf8e11a674 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -1,5 +1,5 @@ from .....electromagnetics.static.resistivity.simulation import Simulation3DNodal as Sim -from ....simulation import BaseSimulation + from .....utils import Zero import dask.array as da import numpy as np @@ -14,158 +14,158 @@ numcodecs.blosc.use_threads = False -class Simulation3DNodal(BaseSimulation, Sim): - """ - Overload of the Simulation3DNodal to include the dask operations - """ +def fields(self, m=None): + if m is not None: + self.model = m - def fields(self, m=None): - if m is not None: - self.model = m + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields - if getattr(self, "_stashed_fields", None) is not None: - return self._stashed_fields + A = self.getA() + Ainv = self.solver(A, **self.solver_opts) + RHS = self.getRHS() - A = self.getA() - Ainv = self.solver(A, **self.solver_opts) - RHS = self.getRHS() + f = self.fieldsPair(self) + f[:, self._solutionType] = Ainv * RHS - f = self.fieldsPair(self) - f[:, self._solutionType] = Ainv * RHS + self.Ainv = Ainv - self.Ainv = Ainv + self._stashed_fields = f - self._stashed_fields = f + return f - return f - def compute_J(self, m, f=None): +def compute_J(self, m, f=None): - if f is None: - f = self.fields(m) + if f is None: + f = self.fields(m) - m_size = m.size - row_chunks = int( - np.ceil( - float(self.survey.nD) - / np.ceil( - float(m_size) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size - ) - ) + m_size = m.size + row_chunks = int( + np.ceil( + float(self.survey.nD) + / np.ceil(float(m_size) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size) + ) + ) + + if self.store_sensitivities == "disk": + Jmatrix = zarr.open( + self.sensitivity_path + "J.zarr", + mode="w", + shape=(self.survey.nD, m_size), + chunks=(row_chunks, m_size), ) + else: + Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) - if self.store_sensitivities == "disk": - Jmatrix = zarr.open( - self.sensitivity_path + "J.zarr", - mode="w", - shape=(self.survey.nD, m_size), - chunks=(row_chunks, m_size), - ) - else: - Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) + blocks = [] + count = 0 + for source in self.survey.source_list: + u_source = f[source, self._solutionType] - blocks = [] - count = 0 - for source in self.survey.source_list: - u_source = f[source, self._solutionType] + for rx in source.receiver_list: + + if rx.orientation is not None: + projected_grid = f._GLoc(rx.projField) + rx.orientation + else: + projected_grid = f._GLoc(rx.projField) - for rx in source.receiver_list: + PTv = rx.getP(self.mesh, projected_grid).toarray().T - if rx.orientation is not None: - projected_grid = f._GLoc(rx.projField) + rx.orientation + for dd in range(int(np.ceil(PTv.shape[1] / row_chunks))): + start, end = dd * row_chunks, np.min( + [(dd + 1) * row_chunks, PTv.shape[1]] + ) + df_duTFun = getattr(f, "_{0!s}Deriv".format(rx.projField), None) + df_duT, df_dmT = df_duTFun( + source, None, PTv[:, start:end], adjoint=True + ) + ATinvdf_duT = self.Ainv * df_duT + dA_dmT = self.getADeriv(u_source, ATinvdf_duT, adjoint=True) + dRHS_dmT = self.getRHSDeriv(source, ATinvdf_duT, adjoint=True) + du_dmT = -dA_dmT + if not isinstance(dRHS_dmT, Zero): + du_dmT += dRHS_dmT + if not isinstance(df_dmT, Zero): + du_dmT += df_dmT + + # + du_dmT = du_dmT.T.reshape((-1, m_size)) + + if len(blocks) == 0: + blocks = du_dmT else: - projected_grid = f._GLoc(rx.projField) - - PTv = rx.getP(self.mesh, projected_grid).toarray().T - - for dd in range(int(np.ceil(PTv.shape[1] / row_chunks))): - start, end = dd * row_chunks, np.min( - [(dd + 1) * row_chunks, PTv.shape[1]] - ) - df_duTFun = getattr(f, "_{0!s}Deriv".format(rx.projField), None) - df_duT, df_dmT = df_duTFun( - source, None, PTv[:, start:end], adjoint=True - ) - ATinvdf_duT = self.Ainv * df_duT - dA_dmT = self.getADeriv(u_source, ATinvdf_duT, adjoint=True) - dRHS_dmT = self.getRHSDeriv(source, ATinvdf_duT, adjoint=True) - du_dmT = -dA_dmT - if not isinstance(dRHS_dmT, Zero): - du_dmT += dRHS_dmT - if not isinstance(df_dmT, Zero): - du_dmT += df_dmT - - # - du_dmT = du_dmT.T.reshape((-1, m_size)) - - if len(blocks) == 0: - blocks = du_dmT + blocks = np.vstack([blocks, du_dmT]) + + while blocks.shape[0] >= row_chunks: + + if self.store_sensitivities == "disk": + Jmatrix.set_orthogonal_selection( + (np.arange(count, count + row_chunks), slice(None)), + blocks[:row_chunks, :].astype(np.float32), + ) else: - blocks = np.vstack([blocks, du_dmT]) + Jmatrix[count : count + row_chunks, :] = blocks[ + :row_chunks, : + ].astype(np.float32) - while blocks.shape[0] >= row_chunks: + blocks = blocks[row_chunks:, :].astype(np.float32) + count += row_chunks - if self.store_sensitivities == "disk": - Jmatrix.set_orthogonal_selection( - (np.arange(count, count + row_chunks), slice(None)), - blocks[:row_chunks, :].astype(np.float32), - ) - else: - Jmatrix[count : count + row_chunks, :] = blocks[ - :row_chunks, : - ].astype(np.float32) + del df_duT, ATinvdf_duT, dA_dmT, dRHS_dmT, du_dmT - blocks = blocks[row_chunks:, :].astype(np.float32) - count += row_chunks + if len(blocks) != 0: - del df_duT, ATinvdf_duT, dA_dmT, dRHS_dmT, du_dmT + if self.store_sensitivities == "disk": + Jmatrix.set_orthogonal_selection( + (np.arange(count, self.survey.nD), slice(None)), + blocks.astype(np.float32), + ) + else: + Jmatrix[count : self.survey.nD, :] = blocks.astype(np.float32) - if len(blocks) != 0: + self.Ainv.clean() - if self.store_sensitivities == "disk": - Jmatrix.set_orthogonal_selection( - (np.arange(count, self.survey.nD), slice(None)), - blocks.astype(np.float32), - ) - else: - Jmatrix[count : self.survey.nD, :] = blocks.astype(np.float32) + if self.store_sensitivities == "disk": + del Jmatrix + self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") + else: + self._Jmatrix = Jmatrix - self.Ainv.clean() + return self._Jmatrix - if self.store_sensitivities == "disk": - del Jmatrix - self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") - else: - self._Jmatrix = Jmatrix - return self._Jmatrix +def getSourceTerm(self): + """ + Evaluates the sources, and puts them in matrix form + :rtype: tuple + :return: q (nC or nN, nSrc) + """ + + if getattr(self, "_q", None) is None: - def getSourceTerm(self): - """ - Evaluates the sources, and puts them in matrix form - :rtype: tuple - :return: q (nC or nN, nSrc) - """ + if self._mini_survey is not None: + Srcs = self._mini_survey.source_list + else: + Srcs = self.survey.source_list - if getattr(self, "_q", None) is None: + if self._formulation == "EB": + n = self.mesh.nN + # return NotImplementedError - if self._mini_survey is not None: - Srcs = self._mini_survey.source_list - else: - Srcs = self.survey.source_list + elif self._formulation == "HJ": + n = self.mesh.nC - if self._formulation == "EB": - n = self.mesh.nN - # return NotImplementedError + q = np.zeros((n, len(Srcs)), order="F") - elif self._formulation == "HJ": - n = self.mesh.nC + for i, source in enumerate(Srcs): + q[:, i] = source.eval(self) - q = np.zeros((n, len(Srcs)), order="F") + self._q = q - for i, source in enumerate(Srcs): - q[:, i] = source.eval(self) + return self._q - self._q = q - return self._q +Sim.getSourceTerm = getSourceTerm +Sim.fields = fields +Sim.compute_J = compute_J diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py index b2d1546b74..bb5bd0e6c6 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py @@ -1,7 +1,7 @@ from .....electromagnetics.static.resistivity.simulation_2d import ( Simulation2DNodal as Sim, ) -from ....simulation import BaseSimulation + import dask.array as da import numpy as np import zarr @@ -10,203 +10,205 @@ numcodecs.blosc.use_threads = False -class Simulation2DNodal(BaseSimulation, Sim): - """ - Overload of the Simulation3DNodal to include the dask operations - """ - - clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] +def fields(self, m=None): + if m is not None: + self.model = m - def fields(self, m=None): - if m is not None: - self.model = m + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields - if getattr(self, "_stashed_fields", None) is not None: - return self._stashed_fields + kys = self._quad_points + f = self.fieldsPair(self) + f._quad_weights = self._quad_weights - kys = self._quad_points - f = self.fieldsPair(self) - f._quad_weights = self._quad_weights + Ainv = {} + for iky, ky in enumerate(kys): + A = self.getA(ky) + Ainv[iky] = self.solver(A, **self.solver_opts) - Ainv = {} - for iky, ky in enumerate(kys): - A = self.getA(ky) - Ainv[iky] = self.solver(A, **self.solver_opts) + RHS = self.getRHS(ky) + f[:, self._solutionType, iky] = Ainv[iky] * RHS - RHS = self.getRHS(ky) - f[:, self._solutionType, iky] = Ainv[iky] * RHS + self.Ainv = Ainv - self.Ainv = Ainv + self._stashed_fields = f + return f - self._stashed_fields = f - return f - def compute_J(self, m, f=None): - kys = self._quad_points - weights = self._quad_weights +def compute_J(self, m, f=None): + kys = self._quad_points + weights = self._quad_weights - if f is None: - f = self.fields(m) + if f is None: + f = self.fields(m) - m_size = m.size - row_chunks = int( - np.ceil( - float(self.survey.nD) - / np.ceil( - float(m_size) - * self.survey.nD - * len(kys) - * 8.0 - * 1e-6 - / self.max_chunk_size - ) + m_size = m.size + row_chunks = int( + np.ceil( + float(self.survey.nD) + / np.ceil( + float(m_size) + * self.survey.nD + * len(kys) + * 8.0 + * 1e-6 + / self.max_chunk_size ) ) - if self.store_sensitivities == "disk": - Jmatrix = zarr.open( - self.sensitivity_path + "J.zarr", - mode="w", - shape=(self.survey.nD, m_size), - chunks=(row_chunks, m_size), - ) - else: - Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) + ) + if self.store_sensitivities == "disk": + Jmatrix = zarr.open( + self.sensitivity_path + "J.zarr", + mode="w", + shape=(self.survey.nD, m_size), + chunks=(row_chunks, m_size), + ) + else: + Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) + + blocks = [] + count = 0 - blocks = [] - count = 0 + for i_src, source in enumerate(self.survey.source_list): + for rx in source.receiver_list: + + if rx.orientation is not None: + projected_grid = f._GLoc(rx.projField) + rx.orientation + else: + projected_grid = f._GLoc(rx.projField) - for i_src, source in enumerate(self.survey.source_list): - for rx in source.receiver_list: + PTv = rx.getP(self.mesh, projected_grid).toarray().T - if rx.orientation is not None: - projected_grid = f._GLoc(rx.projField) + rx.orientation + for dd in range(int(np.ceil(PTv.shape[1] / row_chunks))): + start, end = dd * row_chunks, np.min( + [(dd + 1) * row_chunks, PTv.shape[1]] + ) + block = np.zeros((end - start, m_size)) + for iky, ky in enumerate(kys): + + u_ky = f[:, self._solutionType, iky] + u_source = u_ky[:, i_src] + ATinvdf_duT = self.Ainv[iky] * PTv[:, start:end] + dA_dmT = self.getADeriv(ky, u_source, ATinvdf_duT, adjoint=True) + du_dmT = -weights[iky] * dA_dmT + block += du_dmT.T.reshape((-1, m_size)) + + if len(blocks) == 0: + blocks = block else: - projected_grid = f._GLoc(rx.projField) - - PTv = rx.getP(self.mesh, projected_grid).toarray().T - - for dd in range(int(np.ceil(PTv.shape[1] / row_chunks))): - start, end = dd * row_chunks, np.min( - [(dd + 1) * row_chunks, PTv.shape[1]] - ) - block = np.zeros((end - start, m_size)) - for iky, ky in enumerate(kys): - - u_ky = f[:, self._solutionType, iky] - u_source = u_ky[:, i_src] - ATinvdf_duT = self.Ainv[iky] * PTv[:, start:end] - dA_dmT = self.getADeriv(ky, u_source, ATinvdf_duT, adjoint=True) - du_dmT = -weights[iky] * dA_dmT - block += du_dmT.T.reshape((-1, m_size)) - - if len(blocks) == 0: - blocks = block + blocks = np.vstack([blocks, block]) + + while blocks.shape[0] >= row_chunks: + if self.store_sensitivities == "disk": + Jmatrix.set_orthogonal_selection( + (np.arange(count, count + row_chunks), slice(None)), + blocks[:row_chunks, :].astype(np.float32), + ) else: - blocks = np.vstack([blocks, block]) - - while blocks.shape[0] >= row_chunks: - if self.store_sensitivities == "disk": - Jmatrix.set_orthogonal_selection( - (np.arange(count, count + row_chunks), slice(None)), - blocks[:row_chunks, :].astype(np.float32), - ) - else: - Jmatrix[count : count + row_chunks, :] = blocks[ - :row_chunks, : - ].astype(np.float32) - - blocks = blocks[row_chunks:, :].astype(np.float32) - count += row_chunks - - del ATinvdf_duT, dA_dmT, block - - if len(blocks) != 0: - if self.store_sensitivities == "disk": - Jmatrix.set_orthogonal_selection( - (np.arange(count, self.survey.nD), slice(None)), - blocks.astype(np.float32), - ) - else: - Jmatrix[count : self.survey.nD, :] = blocks.astype(np.float32) + Jmatrix[count : count + row_chunks, :] = blocks[ + :row_chunks, : + ].astype(np.float32) + + blocks = blocks[row_chunks:, :].astype(np.float32) + count += row_chunks - for iky, _ in enumerate(kys): - self.Ainv[iky].clean() + del ATinvdf_duT, dA_dmT, block + if len(blocks) != 0: if self.store_sensitivities == "disk": - del Jmatrix - self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") + Jmatrix.set_orthogonal_selection( + (np.arange(count, self.survey.nD), slice(None)), + blocks.astype(np.float32), + ) else: - self._Jmatrix = Jmatrix + Jmatrix[count : self.survey.nD, :] = blocks.astype(np.float32) - return self._Jmatrix + for iky, _ in enumerate(kys): + self.Ainv[iky].clean() - def dpred(self, m=None, f=None): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). + if self.store_sensitivities == "disk": + del Jmatrix + self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr") + else: + self._Jmatrix = Jmatrix - .. math:: + return self._Jmatrix - d_\\text{pred} = P(f(m)) - Where P is a projection of the fields onto the data space. - """ - weights = self._quad_weights - if self._mini_survey is not None: - survey = self._mini_survey - else: - survey = self.survey +def dpred(self, m=None, f=None): + r""" + dpred(m, f=None) + Create the projected data from a model. + The fields, f, (if provided) will be used for the predicted data + instead of recalculating the fields (which may be expensive!). - if survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) + .. math:: - if f is None: - if m is None: - m = self.model - f = self.fields(m) + d_\\text{pred} = P(f(m)) - temp = np.empty(survey.nD) - count = 0 - for src in survey.source_list: - for rx in src.receiver_list: - d = rx.eval(src, self.mesh, f).dot(weights) - temp[count : count + len(d)] = d - count += len(d) + Where P is a projection of the fields onto the data space. + """ + weights = self._quad_weights + if self._mini_survey is not None: + survey = self._mini_survey + else: + survey = self.survey + + if survey is None: + raise AttributeError( + "The survey has not yet been set and is required to compute " + "data. Please set the survey for the simulation: " + "simulation.survey = survey" + ) - return self._mini_survey_data(temp) + if f is None: + if m is None: + m = self.model + f = self.fields(m) - def getSourceTerm(self, _): - """ - Evaluates the sources, and puts them in matrix form - :rtype: tuple - :return: q (nC or nN, nSrc) - """ + temp = np.empty(survey.nD) + count = 0 + for src in survey.source_list: + for rx in src.receiver_list: + d = rx.eval(src, self.mesh, f).dot(weights) + temp[count : count + len(d)] = d + count += len(d) - if getattr(self, "_q", None) is None: + return self._mini_survey_data(temp) - if self._mini_survey is not None: - Srcs = self._mini_survey.source_list - else: - Srcs = self.survey.source_list - if self._formulation == "EB": - n = self.mesh.nN - # return NotImplementedError +def getSourceTerm(self, _): + """ + Evaluates the sources, and puts them in matrix form + :rtype: tuple + :return: q (nC or nN, nSrc) + """ + + if getattr(self, "_q", None) is None: + + if self._mini_survey is not None: + Srcs = self._mini_survey.source_list + else: + Srcs = self.survey.source_list + + if self._formulation == "EB": + n = self.mesh.nN + # return NotImplementedError + + elif self._formulation == "HJ": + n = self.mesh.nC + + q = np.zeros((n, len(Srcs)), order="F") - elif self._formulation == "HJ": - n = self.mesh.nC + for i, source in enumerate(Srcs): + q[:, i] = source.eval(self) - q = np.zeros((n, len(Srcs)), order="F") + self._q = q - for i, source in enumerate(Srcs): - q[:, i] = source.eval(self) + return self._q - self._q = q - return self._q +Sim.fields = fields +Sim.compute_J = compute_J +Sim.dpred = dpred +Sim.getSourceTerm = getSourceTerm diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index bc632bd632..93cda31d99 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -2,7 +2,7 @@ import dask.array import os from ....electromagnetics.time_domain.simulation import BaseTDEMSimulation as Sim -from ...simulation import BaseSimulation + from ....utils import Zero from simpeg.fields import TimeFields from multiprocessing import cpu_count @@ -10,9 +10,6 @@ import scipy.sparse as sp from dask import array, delayed -from simpeg.electromagnetics.time_domain.simulation import ( - Simulation3DMagneticFluxDensity as MagneticFlux, -) from simpeg.dask.utils import get_parallel_blocks from simpeg.utils import mkvc @@ -20,224 +17,217 @@ from tqdm import tqdm -class BaseTDEMSimulation(BaseSimulation, Sim): +def fields(self, m=None): + if m is not None: + self.model = m - def fields(self, m=None): - if m is not None: - self.model = m + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields - if getattr(self, "_stashed_fields", None) is not None: - return self._stashed_fields + f = self.fieldsPair(self) + f[:, self._fieldType + "Solution", 0] = self.getInitialFields() + Ainv = {} - f = self.fieldsPair(self) - f[:, self._fieldType + "Solution", 0] = self.getInitialFields() - Ainv = {} + for tInd, dt in enumerate(self.time_steps): + if dt not in Ainv: + A = self.getAdiag(tInd) + Ainv[dt] = self.solver(sp.csr_matrix(A), **self.solver_opts) - for tInd, dt in enumerate(self.time_steps): - if dt not in Ainv: - A = self.getAdiag(tInd) - Ainv[dt] = self.solver(sp.csr_matrix(A), **self.solver_opts) + Asubdiag = self.getAsubdiag(tInd) + rhs = -Asubdiag * f[:, (self._fieldType + "Solution"), tInd] - Asubdiag = self.getAsubdiag(tInd) - rhs = -Asubdiag * f[:, (self._fieldType + "Solution"), tInd] + if ( + np.abs(self.survey.source_list[0].waveform.eval(self.times[tInd + 1])) + > 1e-8 + ): + rhs += self.getRHS(tInd + 1) - if ( - np.abs(self.survey.source_list[0].waveform.eval(self.times[tInd + 1])) - > 1e-8 - ): - rhs += self.getRHS(tInd + 1) + sol = Ainv[dt] * rhs + f[:, self._fieldType + "Solution", tInd + 1] = sol - sol = Ainv[dt] * rhs - f[:, self._fieldType + "Solution", tInd + 1] = sol + self.Ainv = Ainv + self._stashed_fields = f + return f - self.Ainv = Ainv - self._stashed_fields = f - return f - def getSourceTerm(self, tInd): - """ - Assemble the source term. This ensures that the RHS is a vector / array - of the correct size - """ - source_list = self.survey.source_list - source_block = np.array_split(source_list, cpu_count()) +def getSourceTerm(self, tInd): + """ + Assemble the source term. This ensures that the RHS is a vector / array + of the correct size + """ + source_list = self.survey.source_list + source_block = np.array_split(source_list, cpu_count()) - block_compute = [] - for block in source_block: - block_compute.append(source_evaluation(self, block, self.times[tInd])) + block_compute = [] + for block in source_block: + block_compute.append(source_evaluation(self, block, self.times[tInd])) - blocks = dask.compute(block_compute)[0] + blocks = dask.compute(block_compute)[0] - s_m, s_e = [], [] - for block in blocks: - if block[0]: - s_m.append(block[0]) - s_e.append(block[1]) + s_m, s_e = [], [] + for block in blocks: + if block[0]: + s_m.append(block[0]) + s_e.append(block[1]) - if isinstance(s_m[0][0], Zero): - return Zero(), np.vstack(s_e).T + if isinstance(s_m[0][0], Zero): + return Zero(), np.vstack(s_e).T - return np.vstack(s_m).T, np.vstack(s_e).T + return np.vstack(s_m).T, np.vstack(s_e).T - def dpred(self, m=None, f=None): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). - .. math:: +def dpred(self, m=None, f=None): + r""" + dpred(m, f=None) + Create the projected data from a model. + The fields, f, (if provided) will be used for the predicted data + instead of recalculating the fields (which may be expensive!). - d_\\text{pred} = P(f(m)) + .. math:: - Where P is a projection of the fields onto the data space. - """ - if self.survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) + d_\\text{pred} = P(f(m)) - if f is None: - if m is None: - m = self.model - f = self.fields(m) + Where P is a projection of the fields onto the data space. + """ + if self.survey is None: + raise AttributeError( + "The survey has not yet been set and is required to compute " + "data. Please set the survey for the simulation: " + "simulation.survey = survey" + ) - rows = [] - receiver_projection = self.survey.source_list[0].receiver_list[0].projField - fields_array = f[:, receiver_projection, :] + if f is None: + if m is None: + m = self.model + f = self.fields(m) - if len(self.survey.source_list) == 1: - fields_array = fields_array[:, np.newaxis, :] + rows = [] + receiver_projection = self.survey.source_list[0].receiver_list[0].projField + fields_array = f[:, receiver_projection, :] - all_receivers = [] + if len(self.survey.source_list) == 1: + fields_array = fields_array[:, np.newaxis, :] - for ind, src in enumerate(self.survey.source_list): - for rx in src.receiver_list: - all_receivers.append((src, ind, rx)) + all_receivers = [] - receiver_blocks = np.array_split(all_receivers, cpu_count()) + for ind, src in enumerate(self.survey.source_list): + for rx in src.receiver_list: + all_receivers.append((src, ind, rx)) - for block in receiver_blocks: - n_data = np.sum([rec.nD for _, _, rec in block]) - if n_data == 0: - continue + receiver_blocks = np.array_split(all_receivers, cpu_count()) - rows.append( - array.from_delayed( - evaluate_receivers( - block, self.mesh, self.time_mesh, f, fields_array - ), - dtype=np.float64, - shape=(n_data,), - ) - ) + for block in receiver_blocks: + n_data = np.sum([rec.nD for _, _, rec in block]) + if n_data == 0: + continue - data = array.hstack(rows).compute() + rows.append( + array.from_delayed( + evaluate_receivers(block, self.mesh, self.time_mesh, f, fields_array), + dtype=np.float64, + shape=(n_data,), + ) + ) - return data + data = array.hstack(rows).compute() - def compute_J(self, m, f=None): - """ - Compute the rows for the sensitivity matrix. - """ - if f is None: - f = self.fields(m) + return data - ftype = self._fieldType + "Solution" - sens_name = self.sensitivity_path[:-5] - if self.store_sensitivities == "disk": - rows = array.zeros( - (self.survey.nD, m.size), - chunks=(self.max_chunk_size, m.size), - dtype=np.float32, - ) - Jmatrix = array.to_zarr( - rows, - os.path.join(sens_name + "_1.zarr"), - compute=True, - return_stored=True, - overwrite=True, - ) - else: - Jmatrix = np.zeros((self.survey.nD, m.size), dtype=np.float64) - simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 - data_times = self.survey.source_list[0].receiver_list[0].times - compute_row_size = np.ceil(self.max_chunk_size / (m.shape[0] * 8.0 * 1e-6)) - blocks = get_parallel_blocks(self.survey.source_list, compute_row_size) - fields_array = f[:, ftype, :] +def compute_J(self, m, f=None): + """ + Compute the rows for the sensitivity matrix. + """ + if f is None: + f = self.fields(m) + + ftype = self._fieldType + "Solution" + sens_name = self.sensitivity_path[:-5] + if self.store_sensitivities == "disk": + rows = array.zeros( + (self.survey.nD, m.size), + chunks=(self.max_chunk_size, m.size), + dtype=np.float32, + ) + Jmatrix = array.to_zarr( + rows, + os.path.join(sens_name + "_1.zarr"), + compute=True, + return_stored=True, + overwrite=True, + ) + else: + Jmatrix = np.zeros((self.survey.nD, m.size), dtype=np.float64) - if len(self.survey.source_list) == 1: - fields_array = fields_array[:, np.newaxis, :] + simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 + data_times = self.survey.source_list[0].receiver_list[0].times + compute_row_size = np.ceil(self.max_chunk_size / (m.shape[0] * 8.0 * 1e-6)) + blocks = get_parallel_blocks(self.survey.source_list, compute_row_size) + fields_array = f[:, ftype, :] - times_field_derivs, Jmatrix = compute_field_derivs( - self, f, blocks, Jmatrix, fields_array.shape - ) + if len(self.survey.source_list) == 1: + fields_array = fields_array[:, np.newaxis, :] - ATinv_df_duT_v = {} - for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): - AdiagTinv = self.Ainv[dt] - j_row_updates = [] - time_mask = data_times > simulation_times[tInd] + times_field_derivs, Jmatrix = compute_field_derivs( + self, f, blocks, Jmatrix, fields_array.shape + ) - if not np.any(time_mask): - continue + ATinv_df_duT_v = {} + for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): + AdiagTinv = self.Ainv[dt] + j_row_updates = [] + time_mask = data_times > simulation_times[tInd] - for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): - ATinv_df_duT_v = get_field_deriv_block( - self, block, field_deriv, tInd, AdiagTinv, ATinv_df_duT_v, time_mask - ) + if not np.any(time_mask): + continue - if len(block) == 0: - continue + for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): + ATinv_df_duT_v = get_field_deriv_block( + self, block, field_deriv, tInd, AdiagTinv, ATinv_df_duT_v, time_mask + ) - j_row_updates.append( - array.from_delayed( - compute_rows( - self, - tInd, - block, - ATinv_df_duT_v, - fields_array, - time_mask, - ), - dtype=np.float32, - shape=( - np.sum([len(chunk[1][0]) for chunk in block]), - m.size, - ), - ) - ) + if len(block) == 0: + continue - if self.store_sensitivities == "disk": - sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" - array.to_zarr( - Jmatrix + array.vstack(j_row_updates), - sens_name, - compute=True, - overwrite=True, + j_row_updates.append( + array.from_delayed( + compute_rows( + self, + tInd, + block, + ATinv_df_duT_v, + fields_array, + time_mask, + ), + dtype=np.float32, + shape=( + np.sum([len(chunk[1][0]) for chunk in block]), + m.size, + ), ) - Jmatrix = array.from_zarr(sens_name) - else: - Jmatrix += array.vstack(j_row_updates).compute() - - for A in self.Ainv.values(): - A.clean() + ) - if self.store_sensitivities == "ram": - self._Jmatrix = np.asarray(Jmatrix) + if self.store_sensitivities == "disk": + sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" + array.to_zarr( + Jmatrix + array.vstack(j_row_updates), + sens_name, + compute=True, + overwrite=True, + ) + Jmatrix = array.from_zarr(sens_name) + else: + Jmatrix += array.vstack(j_row_updates).compute() - self._Jmatrix = Jmatrix + for A in self.Ainv.values(): + A.clean() - return self._Jmatrix + if self.store_sensitivities == "ram": + self._Jmatrix = np.asarray(Jmatrix) + self._Jmatrix = Jmatrix -class Simulation3DMagneticFluxDensity(MagneticFlux, BaseTDEMSimulation): - """ - Overload the Simulation3DMagneticFluxDensity class to use Dask - """ + return self._Jmatrix def _getField(self, name, ind, src_list): @@ -582,3 +572,9 @@ def compute_rows( rows.append(row_block) return np.vstack(rows) + + +Sim.fields = fields +Sim.getSourceTerm = getSourceTerm +Sim.dpred = dpred +Sim.compute_J = compute_J diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 5237571fcf..7953fe24c3 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -1,114 +1,124 @@ import numpy as np from ...potential_fields.base import BasePFSimulation as Sim -from ..simulation import BaseSimulation + import os from dask import delayed, array, config from dask.diagnostics import ProgressBar from ..utils import compute_chunk_sizes -class BasePFSimulation(BaseSimulation, Sim): - - _chunk_format = "row" - - @property - def chunk_format(self): - "Apply memory chunks along rows of G, either 'equal', 'row', or 'auto'" - return self._chunk_format - - @chunk_format.setter - def chunk_format(self, other): - if other not in ["equal", "row", "auto"]: - raise ValueError("Chunk format must be 'equal', 'row', or 'auto'") - self._chunk_format = other - - def dpred(self, m=None, f=None): - if m is not None: - self.model = m - if f is not None: - return f - return self.fields(self.model) - - def residual(self, m, dobs, f=None): - return self.dpred(m, f=f) - dobs - - def linear_operator(self): - forward_only = self.store_sensitivities == "forward_only" - row = delayed(self.evaluate_integral, pure=True) - n_cells = self.nC - if getattr(self, "model_type", None) == "vector": - n_cells *= 3 - - rows = [ - array.from_delayed( - row(receiver_location, components), - dtype=self.sensitivity_dtype, - shape=( - (len(components),) if forward_only else (len(components), n_cells) - ), +_chunk_format = "row" + + +@property +def chunk_format(self): + "Apply memory chunks along rows of G, either 'equal', 'row', or 'auto'" + return self._chunk_format + + +@chunk_format.setter +def chunk_format(self, other): + if other not in ["equal", "row", "auto"]: + raise ValueError("Chunk format must be 'equal', 'row', or 'auto'") + self._chunk_format = other + + +def dpred(self, m=None, f=None): + if m is not None: + self.model = m + if f is not None: + return f + return self.fields(self.model) + + +def residual(self, m, dobs, f=None): + return self.dpred(m, f=f) - dobs + + +def linear_operator(self): + forward_only = self.store_sensitivities == "forward_only" + row = delayed(self.evaluate_integral, pure=True) + n_cells = self.nC + if getattr(self, "model_type", None) == "vector": + n_cells *= 3 + + rows = [ + array.from_delayed( + row(receiver_location, components), + dtype=self.sensitivity_dtype, + shape=((len(components),) if forward_only else (len(components), n_cells)), + ) + for receiver_location, components in self.survey._location_component_iterator() + ] + if forward_only: + stack = array.concatenate(rows) + else: + stack = array.vstack(rows) + # Chunking options + if self.chunk_format == "row": + config.set({"array.chunk-size": f"{self.max_chunk_size}MiB"}) + # Autochunking by rows is faster and more memory efficient for + # very large problems sensitivty and forward calculations + stack = stack.rechunk({0: "auto", 1: -1}) + elif self.chunk_format == "equal": + # Manual chunks for equal number of blocks along rows and columns. + # Optimal for Jvec and Jtvec operations + row_chunk, col_chunk = compute_chunk_sizes( + *stack.shape, self.max_chunk_size ) - for receiver_location, components in self.survey._location_component_iterator() - ] - if forward_only: - stack = array.concatenate(rows) - else: - stack = array.vstack(rows) - # Chunking options - if self.chunk_format == "row": - config.set({"array.chunk-size": f"{self.max_chunk_size}MiB"}) - # Autochunking by rows is faster and more memory efficient for - # very large problems sensitivty and forward calculations - stack = stack.rechunk({0: "auto", 1: -1}) - elif self.chunk_format == "equal": - # Manual chunks for equal number of blocks along rows and columns. - # Optimal for Jvec and Jtvec operations - row_chunk, col_chunk = compute_chunk_sizes( - *stack.shape, self.max_chunk_size - ) - stack = stack.rechunk((row_chunk, col_chunk)) - else: - # Auto chunking by columns is faster for Inversions - config.set({"array.chunk-size": f"{self.max_chunk_size}MiB"}) - stack = stack.rechunk({0: -1, 1: "auto"}) - - if self.store_sensitivities == "disk": - sens_name = os.path.join(self.sensitivity_path, "sensitivity.zarr") - if os.path.exists(sens_name): - kernel = array.from_zarr(sens_name) - if np.all( - np.r_[ - np.any(np.r_[kernel.chunks[0]] == stack.chunks[0]), - np.any(np.r_[kernel.chunks[1]] == stack.chunks[1]), - np.r_[kernel.shape] == np.r_[stack.shape], - ] - ): - # Check that loaded kernel matches supplied data and mesh - print( - "Zarr file detected with same shape and chunksize ... re-loading" - ) - return kernel - - print("Writing Zarr file to disk") - with ProgressBar(): - print("Saving kernel to zarr: " + sens_name) - kernel = array.to_zarr( - stack, sens_name, compute=False, return_stored=True, overwrite=True - ) - elif forward_only: - # with ProgressBar(): - # print("Forward calculation: ") - kernel = stack # .compute() + stack = stack.rechunk((row_chunk, col_chunk)) else: - # with ProgressBar(): - # print("Computing sensitivities to local ram") - kernel = stack # .persist() - return kernel - - def compute_J(self, _): - return self.linear_operator() - - @property - def Jmatrix(self): - if getattr(self, "_Jmatrix", None) is None: - self._Jmatrix = self.linear_operator() - return self._Jmatrix + # Auto chunking by columns is faster for Inversions + config.set({"array.chunk-size": f"{self.max_chunk_size}MiB"}) + stack = stack.rechunk({0: -1, 1: "auto"}) + + if self.store_sensitivities == "disk": + sens_name = os.path.join(self.sensitivity_path, "sensitivity.zarr") + if os.path.exists(sens_name): + kernel = array.from_zarr(sens_name) + if np.all( + np.r_[ + np.any(np.r_[kernel.chunks[0]] == stack.chunks[0]), + np.any(np.r_[kernel.chunks[1]] == stack.chunks[1]), + np.r_[kernel.shape] == np.r_[stack.shape], + ] + ): + # Check that loaded kernel matches supplied data and mesh + print("Zarr file detected with same shape and chunksize ... re-loading") + return kernel + + print("Writing Zarr file to disk") + with ProgressBar(): + print("Saving kernel to zarr: " + sens_name) + kernel = array.to_zarr( + stack, sens_name, compute=False, return_stored=True, overwrite=True + ) + elif forward_only: + # with ProgressBar(): + # print("Forward calculation: ") + kernel = stack # .compute() + else: + # with ProgressBar(): + # print("Computing sensitivities to local ram") + kernel = stack # .persist() + return kernel + + +def compute_J(self, _): + return self.linear_operator() + + +@property +def Jmatrix(self): + if getattr(self, "_Jmatrix", None) is None: + self._Jmatrix = self.linear_operator() + return self._Jmatrix + + +Sim._chunk_format = _chunk_format +Sim.chunk_format = chunk_format +Sim.dpred = dpred +Sim.residual = residual +Sim.linear_operator = linear_operator +Sim.compute_J = compute_J +Sim.Jmatrix = Jmatrix diff --git a/simpeg/dask/potential_fields/gravity/simulation.py b/simpeg/dask/potential_fields/gravity/simulation.py index 934245abdc..c0fdc8706c 100644 --- a/simpeg/dask/potential_fields/gravity/simulation.py +++ b/simpeg/dask/potential_fields/gravity/simulation.py @@ -1,36 +1,34 @@ import numpy as np from dask import array, delayed from ....potential_fields.gravity import Simulation3DIntegral as Sim -from ..base import BasePFSimulation + from scipy.sparse import csr_matrix as csr -class Simulation3DIntegral(BasePFSimulation, Sim): +def getJtJdiag(self, m, W=None, f=None): """ - Overload the Simulation3DIntegral class to use Dask + Return the diagonal of JtJ """ - def getJtJdiag(self, m, W=None, f=None): - """ - Return the diagonal of JtJ - """ - - self.model = m - if W is None: - W = np.ones(self.Jmatrix.shape[0]) - else: - W = W.diagonal() - - if getattr(self, "_gtg_diagonal", None) is None: - diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) - self._gtg_diagonal = diag - else: - diag = self._gtg_diagonal - - mapping_deriv = self.rhoDeriv.tocsr().T.power(2) - dmudm_jtvec = delayed(csr.dot)(mapping_deriv, diag) - jtjdiag = array.from_delayed( - dmudm_jtvec, dtype=np.float32, shape=[mapping_deriv.shape[1]] - ) - - return jtjdiag + self.model = m + if W is None: + W = np.ones(self.Jmatrix.shape[0]) + else: + W = W.diagonal() + + if getattr(self, "_gtg_diagonal", None) is None: + diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) + self._gtg_diagonal = diag + else: + diag = self._gtg_diagonal + + mapping_deriv = self.rhoDeriv.tocsr().T.power(2) + dmudm_jtvec = delayed(csr.dot)(mapping_deriv, diag) + jtjdiag = array.from_delayed( + dmudm_jtvec, dtype=np.float32, shape=[mapping_deriv.shape[1]] + ) + + return jtjdiag + + +Sim.getJtJdiag = getJtJdiag diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index 9ca8d203dc..d6bb02694b 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -1,41 +1,39 @@ import numpy as np from dask import array from ....potential_fields.magnetics import Simulation3DIntegral as Sim -from ..base import BasePFSimulation + from ....utils import sdiag, mkvc -class Simulation3DIntegral(Sim, BasePFSimulation): +def getJtJdiag(self, m, W=None, f=None): """ - Overwrite the dask_getJtJdiag method + Return the diagonal of JtJ """ - def getJtJdiag(self, m, W=None, f=None): - """ - Return the diagonal of JtJ - """ - - self.model = m + self.model = m - if W is None: - W = np.ones(self.nD) - else: - W = W.diagonal() - if getattr(self, "_gtg_diagonal", None) is None: - if not self.is_amplitude_data: - diag = array.einsum( - "i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix - ).compute() - else: - ampDeriv = self.ampDeriv - J = ( - ampDeriv[0, :, None] * self.Jmatrix[::3] - + ampDeriv[1, :, None] * self.Jmatrix[1::3] - + ampDeriv[2, :, None] * self.Jmatrix[2::3] - ) - diag = ((W[:, None] * J) ** 2).sum(axis=0).compute() - self._gtg_diagonal = diag + if W is None: + W = np.ones(self.nD) + else: + W = W.diagonal() + if getattr(self, "_gtg_diagonal", None) is None: + if not self.is_amplitude_data: + diag = array.einsum( + "i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix + ).compute() else: - diag = self._gtg_diagonal + ampDeriv = self.ampDeriv + J = ( + ampDeriv[0, :, None] * self.Jmatrix[::3] + + ampDeriv[1, :, None] * self.Jmatrix[1::3] + + ampDeriv[2, :, None] * self.Jmatrix[2::3] + ) + diag = ((W[:, None] * J) ** 2).sum(axis=0).compute() + self._gtg_diagonal = diag + else: + diag = self._gtg_diagonal + + return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0)) + - return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0)) +Sim.getJtJdiag = getJtJdiag diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index 06145cbaa9..ba65009364 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -4,86 +4,104 @@ import numpy as np +Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] +Sim.sensitivity_path = "./sensitivity/" +Sim._max_ram = 16 +Sim._max_chunk_size = 128 -class BaseSimulation(Sim): + +@property +def max_ram(self): + "Maximum ram in (Gb)" + return self._max_ram + + +@max_ram.setter +def max_ram(self, other): + if other <= 0: + raise ValueError("max_ram must be greater than 0") + self._max_ram = other + + +Sim.max_ram = max_ram + + +@property +def max_chunk_size(self): + "Largest chunk size (Mb) used by Dask" + return self._max_chunk_size + + +@max_chunk_size.setter +def max_chunk_size(self, other): + if other <= 0: + raise ValueError("max_chunk_size must be greater than 0") + self._max_chunk_size = other + + +Sim.max_chunk_size = max_chunk_size + + +def getJtJdiag(self, m, W=None, f=None): """ - Base class for SimPEG simulations + Return the diagonal of JtJ """ - - clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] - sensitivity_path = "./sensitivity/" - _max_ram = 16 - _max_chunk_size = 128 - - @property - def max_ram(self): - "Maximum ram in (Gb)" - return self._max_ram - - @max_ram.setter - def max_ram(self, other): - if other <= 0: - raise ValueError("max_ram must be greater than 0") - self._max_ram = other - - @property - def max_chunk_size(self): - "Largest chunk size (Mb) used by Dask" - return self._max_chunk_size - - @max_chunk_size.setter - def max_chunk_size(self, other): - if other <= 0: - raise ValueError("max_chunk_size must be greater than 0") - self._max_chunk_size = other - - def getJtJdiag(self, m, W=None, f=None): - """ - Return the diagonal of JtJ - """ - if getattr(self, "_jtjdiag", None) is None: - self.model = m - if W is None: - W = np.ones(self.Jmatrix.shape[0]) - else: - W = W.diagonal() - - self._jtj_diag = array.einsum( - "i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix - ) - - return self._jtj_diag - - def Jvec(self, m, v, **_): - """ - Compute sensitivity matrix (J) and vector (v) product. - """ + if getattr(self, "_jtjdiag", None) is None: self.model = m + if W is None: + W = np.ones(self.Jmatrix.shape[0]) + else: + W = W.diagonal() - if isinstance(self.Jmatrix, np.ndarray): - return self.Jmatrix @ v.astype(np.float32) + self._jtj_diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) - return array.dot(self.Jmatrix, v.astype(np.float32)) + return self._jtj_diag - def Jtvec(self, m, v, **_): - """ - Compute adjoint sensitivity matrix (J^T) and vector (v) product. - """ - self.model = m - if isinstance(self.Jmatrix, np.ndarray): - return self.Jmatrix.T @ v.astype(np.float32) +Sim.getJtJdiag = getJtJdiag + + +def Jvec(self, m, v, **_): + """ + Compute sensitivity matrix (J) and vector (v) product. + """ + self.model = m + + if isinstance(self.Jmatrix, np.ndarray): + return self.Jmatrix @ v.astype(np.float32) + + return array.dot(self.Jmatrix, v.astype(np.float32)) + + +Sim.Jvec = Jvec + + +def Jtvec(self, m, v, **_): + """ + Compute adjoint sensitivity matrix (J^T) and vector (v) product. + """ + self.model = m + + if isinstance(self.Jmatrix, np.ndarray): + return self.Jmatrix.T @ v.astype(np.float32) + + return array.dot(v.astype(np.float32), self.Jmatrix) + + +Sim.Jtvec = Jtvec + + +@property +def Jmatrix(self): + """ + Sensitivity matrix stored on disk + Return the diagonal of JtJ + """ + if getattr(self, "_Jmatrix", None) is None: + self._Jmatrix = self.compute_J(self.model) + self._stashed_fields = None - return array.dot(v.astype(np.float32), self.Jmatrix) + return self._Jmatrix - @property - def Jmatrix(self): - """ - Sensitivity matrix stored on disk - Return the diagonal of JtJ - """ - if getattr(self, "_Jmatrix", None) is None: - self._Jmatrix = self.compute_J(self.model) - self._stashed_fields = None - return self._Jmatrix +Sim.Jmatrix = Jmatrix From 00a0f0265baf6c82e314955c3d432c95eb98bfa6 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 17 Dec 2024 12:59:31 -0800 Subject: [PATCH 21/84] Simplify PF compute Jmatrix --- simpeg/dask/potential_fields/base.py | 32 +++++++++++----------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 7953fe24c3..0841796b15 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -3,7 +3,7 @@ import os from dask import delayed, array, config -from dask.diagnostics import ProgressBar + from ..utils import compute_chunk_sizes @@ -87,34 +87,28 @@ def linear_operator(self): print("Zarr file detected with same shape and chunksize ... re-loading") return kernel - print("Writing Zarr file to disk") - with ProgressBar(): - print("Saving kernel to zarr: " + sens_name) - kernel = array.to_zarr( - stack, sens_name, compute=False, return_stored=True, overwrite=True - ) - elif forward_only: - # with ProgressBar(): - # print("Forward calculation: ") - kernel = stack # .compute() - else: - # with ProgressBar(): - # print("Computing sensitivities to local ram") - kernel = stack # .persist() - return kernel + return array.to_zarr( + stack, sens_name, compute=False, return_stored=True, overwrite=True + ) + return stack -def compute_J(self, _): - return self.linear_operator() +def compute_J(self, _, f=None): + return self.linear_operator().persist() @property def Jmatrix(self): if getattr(self, "_Jmatrix", None) is None: - self._Jmatrix = self.linear_operator() + self._Jmatrix = self.compute_J() return self._Jmatrix +@Jmatrix.setter +def Jmatrix(self, value): + self._Jmatrix = value + + Sim._chunk_format = _chunk_format Sim.chunk_format = chunk_format Sim.dpred = dpred From 238f304585d6779551c7c7f0f0a38b207dd14146 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 17 Dec 2024 14:05:20 -0800 Subject: [PATCH 22/84] Handle G matrix on dask class --- simpeg/dask/inverse_problem.py | 10 +-- simpeg/dask/potential_fields/base.py | 14 +++- simpeg/meta/dask_sim.py | 75 +++++++++---------- simpeg/potential_fields/gravity/simulation.py | 8 +- .../potential_fields/magnetics/simulation.py | 10 +-- 5 files changed, 56 insertions(+), 61 deletions(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 94bdea857b..5baa675e2b 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -1,8 +1,7 @@ from ..inverse_problem import BaseInvProblem import numpy as np - -from dask.distributed import get_client +from dask.distributed import get_client, Future from scipy.sparse.linalg import LinearOperator from ..regularization import WeightedLeastSquares, Sparse @@ -16,11 +15,12 @@ def get_dpred(self, m, f=None): dpred = objfct.simulation.dpred(m) dpreds += [dpred] - try: + if isinstance(dpreds[0], Future): client = get_client() dpreds = client.gather(dpreds) - except ValueError: - pass + else: + for i, dpred in enumerate(dpreds): + dpreds[i] = np.asarray(dpred) return dpreds diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 0841796b15..3db0222c09 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -93,6 +93,17 @@ def linear_operator(self): return stack +@property +def G(self): + """ + Gravity forward operator + """ + if getattr(self, "_G", None) is None: + self._G = self.Jmatrix + + return self._G + + def compute_J(self, _, f=None): return self.linear_operator().persist() @@ -100,7 +111,7 @@ def compute_J(self, _, f=None): @property def Jmatrix(self): if getattr(self, "_Jmatrix", None) is None: - self._Jmatrix = self.compute_J() + self._Jmatrix = self.compute_J(self.model) return self._Jmatrix @@ -109,6 +120,7 @@ def Jmatrix(self, value): self._Jmatrix = value +Sim.G = G Sim._chunk_format = _chunk_format Sim.chunk_format = chunk_format Sim.dpred = dpred diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index 218c9ee647..bddf091920 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -8,7 +8,6 @@ import itertools from dask.distributed import Client from dask.distributed import Future -from dask import array from .simulation import MetaSimulation, SumMetaSimulation import scipy.sparse as sp from operator import add @@ -26,39 +25,36 @@ def _calc_fields(mapping, sim, model, apply_map=False): return sim.fields(m=sim.model) -def _calc_dpred(mapping, sim, model, fields, apply_map=False): +def _calc_dpred(mapping, sim, model, field, apply_map=False): if apply_map and model is not None: - return array.compute(sim.dpred(m=mapping @ model, f=fields))[0] + return sim.dpred(m=mapping @ model) else: - return array.compute(sim.dpred(m=sim.model, f=fields))[0] + return sim.dpred(m=sim.model, f=field) -def _j_vec_op(mapping, sim, model, v, apply_map=False): - # return array.from_array(np.zeros(100)) +def _j_vec_op(mapping, sim, model, field, v, apply_map=False): sim_v = mapping.deriv(model) @ v if apply_map: - return array.compute(sim.Jvec(mapping @ model, sim_v))[0] + return sim.Jvec(mapping @ model, sim_v, f=field) else: - return array.compute(sim.Jvec(sim.model, sim_v))[0] + return sim.Jvec(sim.model, sim_v, f=field) -def _jt_vec_op(mapping, sim, model, v, start, end, apply_map=False): +def _jt_vec_op(mapping, sim, model, field, v, apply_map=False): if apply_map: - jtv = sim.Jtvec(mapping @ model, v[start:end]) + jtv = sim.Jtvec(mapping @ model, v, f=field) else: - jtv = sim.Jtvec(sim.model, v[start:end]) + jtv = sim.Jtvec(sim.model, v, f=field) + return mapping.deriv(model).T @ jtv - # Need to delay this operation until the future is computed - return mapping.deriv(model).T @ array.compute(jtv)[0] - -def _get_jtj_diag(mapping, sim, model, w, apply_map=False): +def _get_jtj_diag(mapping, sim, model, field, w, apply_map=False): w = sp.diags(w) if apply_map: - jtj = sim.getJtJdiag(mapping @ model, w) + jtj = sim.getJtJdiag(mapping @ model, w, f=field) else: - jtj = sim.getJtJdiag(sim.model, w) - sim_jtj = sp.diags(np.sqrt(np.asarray(jtj))) + jtj = sim.getJtJdiag(sim.model, w, f=field) + sim_jtj = sp.diags(np.sqrt(jtj)) m_deriv = mapping.deriv(model) return np.asarray((sim_jtj @ m_deriv).power(2).sum(axis=0)).flatten() @@ -69,7 +65,7 @@ def _reduce(client, operation, items): if len(items) % 2 == 1: new_reduce[-1] = client.submit(operation, new_reduce[-1], items[-1]) items = new_reduce - return items[0] + return client.gather(items[0]) def _validate_type_or_future_of_type( @@ -355,40 +351,42 @@ def dpred(self, m=None, f=None): workers=worker, ) ) - return _reduce(client, np.concatenate, dpred) + return np.concatenate(client.gather(dpred)) def Jvec(self, m, v, f=None): self.model = m m_future = self._m_as_future - # if f is None: - # f = self.fields(m) + if f is None: + f = self.fields(m) client = self.client [v_future] = client.scatter([v], broadcast=True) j_vec = [] - for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers): + for mapping, sim, worker, field in zip( + self.mappings, self.simulations, self._workers, f + ): j_vec.append( client.submit( _j_vec_op, mapping, sim, m_future, - # field, + field, v_future, self._repeat_sim, workers=worker, ) ) - return _reduce(client, np.concatenate, j_vec) + return np.concatenate(self.client.gather(j_vec)) def Jtvec(self, m, v, f=None): self.model = m m_future = self._m_as_future - # if f is None: - # f = self.fields(m) + if f is None: + f = self.fields(m) jt_vec = [] client = self.client - for i, (mapping, sim, worker) in enumerate( - zip(self.mappings, self.simulations, self._workers) + for i, (mapping, sim, worker, field) in enumerate( + zip(self.mappings, self.simulations, self._workers, f) ): jt_vec.append( client.submit( @@ -396,10 +394,8 @@ def Jtvec(self, m, v, f=None): mapping, sim, m_future, - # field, - v, - self._data_offsets[i], - self._data_offsets[i + 1], + field, + v[self._data_offsets[i] : self._data_offsets[i + 1]], self._repeat_sim, workers=worker, ) @@ -418,26 +414,23 @@ def getJtJdiag(self, m, W=None, f=None): W = W.diagonal() jtj_diag = [] client = self.client - # if f is None: - # f = self.fields(m) - for i, (mapping, sim, worker) in enumerate( - zip(self.mappings, self.simulations, self._workers) + if f is None: + f = self.fields(m) + for i, (mapping, sim, worker, field) in enumerate( + zip(self.mappings, self.simulations, self._workers, f) ): sim_w = W[self._data_offsets[i] : self._data_offsets[i + 1]] - # s = client.gather(sim) - # ff = client.gather(field) jtj_diag.append( client.submit( _get_jtj_diag, mapping, sim, m_future, - # field, + field, sim_w, self._repeat_sim, workers=worker, ) - # s.getJtJdiag(self.model, sim_w, f=ff) ) self._jtjdiag = _reduce(client, add, jtj_diag) diff --git a/simpeg/potential_fields/gravity/simulation.py b/simpeg/potential_fields/gravity/simulation.py index 789f9405f5..2c8ac62701 100644 --- a/simpeg/potential_fields/gravity/simulation.py +++ b/simpeg/potential_fields/gravity/simulation.py @@ -201,7 +201,7 @@ def Jvec(self, m, v, f=None): Sensitivity times a vector """ dmu_dm_v = self.rhoDeriv @ v - return self.G @ dmu_dm_v.astype(self.sensitivity_dtype, copy=False) + return np.asarray(self.G @ dmu_dm_v.astype(self.sensitivity_dtype, copy=False)) def Jtvec(self, m, v, f=None): """ @@ -216,16 +216,10 @@ def G(self): Gravity forward operator """ if getattr(self, "_G", None) is None: - if self._Jmatrix is not None: - self._G = self._Jmatrix - return self._G - if self.engine == "choclo": self._G = self._sensitivity_matrix() else: self._G = self.linear_operator() - - self._Jmatrix = self._G return self._G @property diff --git a/simpeg/potential_fields/magnetics/simulation.py b/simpeg/potential_fields/magnetics/simulation.py index 9a162cc78c..610b933682 100644 --- a/simpeg/potential_fields/magnetics/simulation.py +++ b/simpeg/potential_fields/magnetics/simulation.py @@ -198,18 +198,14 @@ def fields(self, model): @property def G(self): + """ + Gravity forward operator + """ if getattr(self, "_G", None) is None: - if self._Jmatrix is not None: - self._G = self._Jmatrix - return self._G - if self.engine == "choclo": self._G = self._sensitivity_matrix() else: self._G = self.linear_operator() - - self._Jmatrix = self._G - return self._G modelType = deprecate_property( From 38a3b3c87f61651576575a062b32f437a23d9d30 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 18 Dec 2024 09:30:22 -0800 Subject: [PATCH 23/84] Explicit re-assignement of G on grav and mag dask classes --- simpeg/dask/potential_fields/gravity/simulation.py | 3 ++- simpeg/dask/potential_fields/magnetics/simulation.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/simpeg/dask/potential_fields/gravity/simulation.py b/simpeg/dask/potential_fields/gravity/simulation.py index c0fdc8706c..88bc1a3404 100644 --- a/simpeg/dask/potential_fields/gravity/simulation.py +++ b/simpeg/dask/potential_fields/gravity/simulation.py @@ -1,7 +1,7 @@ import numpy as np from dask import array, delayed from ....potential_fields.gravity import Simulation3DIntegral as Sim - +from ..base import G from scipy.sparse import csr_matrix as csr @@ -32,3 +32,4 @@ def getJtJdiag(self, m, W=None, f=None): Sim.getJtJdiag = getJtJdiag +Sim.G = G diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index d6bb02694b..1d0e21cf9b 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -1,7 +1,7 @@ import numpy as np from dask import array from ....potential_fields.magnetics import Simulation3DIntegral as Sim - +from ..base import G from ....utils import sdiag, mkvc @@ -37,3 +37,4 @@ def getJtJdiag(self, m, W=None, f=None): Sim.getJtJdiag = getJtJdiag +Sim.G = G From 02507f04f3082621963601c86547db426956b854 Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 22 Dec 2024 10:27:03 -0800 Subject: [PATCH 24/84] Continue work --- .../frequency_domain/simulation.py | 36 +++++--- simpeg/dask/potential_fields/base.py | 6 +- .../potential_fields/gravity/simulation.py | 31 +------ .../potential_fields/magnetics/simulation.py | 1 + simpeg/dask/simulation.py | 4 +- simpeg/meta/__init__.py | 1 + simpeg/meta/dask_sim.py | 91 +++++++++++++++++++ simpeg/meta/simulation.py | 8 +- 8 files changed, 126 insertions(+), 52 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index bb666c169d..0b105b0e0b 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -112,9 +112,11 @@ def getSourceTerm(self, freq, source=None): if len(block) == 0: continue - block_compute.append(source_evaluation(self, block)) + block_compute.append( + self.client.submit(source_evaluation, self, block, workers=self.worker) + ) - blocks = compute(block_compute)[0] + blocks = self.client.gather(block_compute) s_m, s_e = [], [] for block in blocks: if block[0]: @@ -197,8 +199,8 @@ def fields(self, m=None): if m is not None: self.model = m - if getattr(self, "_stashed_fields", None) is not None: - return self._stashed_fields + # if getattr(self, "_stashed_fields", None) is not None: + # return self._stashed_fields f = self.fieldsPair(self) Ainv = {} @@ -211,9 +213,9 @@ def fields(self, m=None): f[sources, self._solutionType] = u Ainv[freq] = Ainv_solve - self.Ainv = Ainv - - self._stashed_fields = f + # Ainv = Ainv + # + # self._stashed_fields = f return f @@ -224,13 +226,19 @@ def compute_J(self, m, f=None): if f is None: f = self.fields(m) - if len(self.Ainv) > 1: + Ainv = {} + for freq in self.survey.frequencies: + A = self.getA(freq) + Ainv_solve = self.solver(sp.csr_matrix(A), **self.solver_opts) + Ainv[freq] = Ainv_solve + + if len(Ainv) > 1: raise NotImplementedError( "Current implementation of parallelization assumes a single frequency per simulation. " "Consider creating one misfit per frequency." ) - A_i = list(self.Ainv.values())[0] + A_i = list(Ainv.values())[0] m_size = m.size if self.store_sensitivities == "disk": @@ -269,22 +277,20 @@ def compute_J(self, m, f=None): for block_derivs_chunks, addresses_chunks in tqdm( zip(blocks_receiver_derivs, blocks), ncols=len(blocks_receiver_derivs), - desc=f"Sensitivities at {list(self.Ainv)[0]} Hz", + desc=f"Sensitivities at {list(Ainv)[0]} Hz", ): Jmatrix = self.parallel_block_compute( m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks ) - for A in self.Ainv.values(): + for A in Ainv.values(): A.clean() if self.store_sensitivities == "disk": del Jmatrix - self._Jmatrix = array.from_zarr(self.sensitivity_path) - else: - self._Jmatrix = Jmatrix + Jmatrix = array.from_zarr(self.sensitivity_path) - return self._Jmatrix + return Jmatrix def parallel_block_compute( diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 3db0222c09..06c521c91b 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -88,9 +88,9 @@ def linear_operator(self): return kernel return array.to_zarr( - stack, sens_name, compute=False, return_stored=True, overwrite=True + stack, sens_name, compute=True, return_stored=True, overwrite=True ) - return stack + return stack.compute() @property @@ -105,7 +105,7 @@ def G(self): def compute_J(self, _, f=None): - return self.linear_operator().persist() + return self.linear_operator() @property diff --git a/simpeg/dask/potential_fields/gravity/simulation.py b/simpeg/dask/potential_fields/gravity/simulation.py index 88bc1a3404..99fbea1c7b 100644 --- a/simpeg/dask/potential_fields/gravity/simulation.py +++ b/simpeg/dask/potential_fields/gravity/simulation.py @@ -1,35 +1,8 @@ -import numpy as np -from dask import array, delayed from ....potential_fields.gravity import Simulation3DIntegral as Sim from ..base import G -from scipy.sparse import csr_matrix as csr - - -def getJtJdiag(self, m, W=None, f=None): - """ - Return the diagonal of JtJ - """ - - self.model = m - if W is None: - W = np.ones(self.Jmatrix.shape[0]) - else: - W = W.diagonal() - - if getattr(self, "_gtg_diagonal", None) is None: - diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) - self._gtg_diagonal = diag - else: - diag = self._gtg_diagonal - - mapping_deriv = self.rhoDeriv.tocsr().T.power(2) - dmudm_jtvec = delayed(csr.dot)(mapping_deriv, diag) - jtjdiag = array.from_delayed( - dmudm_jtvec, dtype=np.float32, shape=[mapping_deriv.shape[1]] - ) - - return jtjdiag +from ...simulation import getJtJdiag +Sim.clean_on_model_update = [] Sim.getJtJdiag = getJtJdiag Sim.G = G diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index 1d0e21cf9b..770320c2e4 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -36,5 +36,6 @@ def getJtJdiag(self, m, W=None, f=None): return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0)) +Sim.clean_on_model_update = [] Sim.getJtJdiag = getJtJdiag Sim.G = G diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index ba65009364..56ea168a28 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -53,7 +53,9 @@ def getJtJdiag(self, m, W=None, f=None): else: W = W.diagonal() - self._jtj_diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) + self._jtj_diag = np.asarray( + np.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) + ) return self._jtj_diag diff --git a/simpeg/meta/__init__.py b/simpeg/meta/__init__.py index 3dca694298..7c58eeb2f8 100644 --- a/simpeg/meta/__init__.py +++ b/simpeg/meta/__init__.py @@ -78,6 +78,7 @@ try: from .dask_sim import ( DaskMetaSimulation, + DaskMetaSimulationExplicit, DaskSumMetaSimulation, DaskRepeatedSimulation, ) diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index bddf091920..86459b12f2 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -1,11 +1,13 @@ import numpy as np from simpeg.simulation import BaseSimulation + from simpeg.survey import BaseSurvey from simpeg.maps import IdentityMap from simpeg.utils import validate_list_of_types, validate_type from simpeg.props import HasModel import itertools + from dask.distributed import Client from dask.distributed import Future from .simulation import MetaSimulation, SumMetaSimulation @@ -152,6 +154,7 @@ class DaskMetaSimulation(MetaSimulation): def __init__(self, simulations, mappings, client): self._client = validate_type("client", client, Client, cast=False) + self._concrete_simulations = None super().__init__(simulations, mappings) def _make_survey(self): @@ -177,6 +180,7 @@ def simulations(self): @simulations.setter def simulations(self, value): client = self.client + self._concrete_simulations = client.gather(value) simulations, workers = _validate_type_or_future_of_type( "simulations", value, BaseSimulation, client, return_workers=True ) @@ -437,6 +441,93 @@ def getJtJdiag(self, m, W=None, f=None): return self._jtjdiag +def _compute_j(sim, model): + sim.model = model + jmatrix = getattr(sim, "_Jmatrix", None) + + if jmatrix is None: + jmatrix = sim.compute_J(model) + + return jmatrix + + +def set_jmatrix(sim, jmatrix): + sim._Jmatrix = jmatrix + return sim + + +class DaskMetaSimulationExplicit(DaskMetaSimulation): + clean_on_model_update = ["_Jmatrix", "_stashed_fields"] + + def fields(self, m): + self.model = m + + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields + + client = self.client + m_future = self._m_as_future + # The above should pass the model to all the internal simulations. + f = [] + simulations = [] + for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers): + jmatrix = client.submit( + _compute_j, + sim, + m_future, + workers=worker, + ) + sim = client.submit(set_jmatrix, sim, jmatrix, workers=worker) + f.append( + client.submit( + _calc_fields, + mapping, + sim, + m_future, + self._repeat_sim, + workers=worker, + ) + ) + simulations.append(sim) + + self._stashed_fields = f + self.simulations = simulations + return f + + def getJtJdiag(self, m, W=None, f=None): + self.model = m + m_future = self._m_as_future + if getattr(self, "_jtjdiag", None) is None: + if W is None: + W = np.ones(self.survey.nD) + else: + W = W.diagonal() + jtj_diag = [] + client = self.client + if f is None: + f = self.fields(m) + for i, (mapping, sim, worker, field) in enumerate( + zip(self.mappings, self.simulations, self._workers, f) + ): + sim_w = W[self._data_offsets[i] : self._data_offsets[i + 1]] + + jtj_diag.append( + client.submit( + _get_jtj_diag, + mapping, + sim, + m_future, + field, + sim_w, + self._repeat_sim, + workers=worker, + ) + ) + self._jtjdiag = _reduce(client, add, jtj_diag) + + return self._jtjdiag + + class DaskSumMetaSimulation(DaskMetaSimulation, SumMetaSimulation): """A dask distributed version of :class:`.SumMetaSimulation`. diff --git a/simpeg/meta/simulation.py b/simpeg/meta/simulation.py index 5cdbb4eb9c..7c52e5fa8a 100644 --- a/simpeg/meta/simulation.py +++ b/simpeg/meta/simulation.py @@ -89,10 +89,10 @@ class MetaSimulation(BaseSimulation): _repeat_sim = False def __init__(self, simulations, mappings): - warnings.warn( - "The MetaSimulation class is a work in progress and might change in the future", - stacklevel=2, - ) + # warnings.warn( + # "The MetaSimulation class is a work in progress and might change in the future", + # stacklevel=2, + # ) self.simulations = simulations self.mappings = mappings self.model = None From c469bf98ea25a8df30a0c55e466926322ad7d2e8 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 27 Dec 2024 08:37:01 -0800 Subject: [PATCH 25/84] Block compute rows for PF --- simpeg/dask/potential_fields/base.py | 30 +++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 06c521c91b..031e86b24f 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -3,7 +3,6 @@ import os from dask import delayed, array, config - from ..utils import compute_chunk_sizes @@ -35,20 +34,41 @@ def residual(self, m, dobs, f=None): return self.dpred(m, f=f) - dobs +@delayed +def block_compute(sim, rows, components): + block = [] + for row in rows: + block.append(sim.evaluate_integral(row, components)) + + if sim.store_sensitivities == "forward_only": + return np.hstack(block) + + return np.vstack(block) + + def linear_operator(self): forward_only = self.store_sensitivities == "forward_only" - row = delayed(self.evaluate_integral, pure=True) n_cells = self.nC if getattr(self, "model_type", None) == "vector": n_cells *= 3 + n_components = len(self.survey.components) + n_blocks = int( + (n_cells * n_components * self.survey.receiver_locations.shape[0] * 8.0 * 1e-6) + / self.max_chunk_size + ) + block_split = np.array_split(self.survey.receiver_locations, n_blocks) rows = [ array.from_delayed( - row(receiver_location, components), + block_compute(self, block, self.survey.components), dtype=self.sensitivity_dtype, - shape=((len(components),) if forward_only else (len(components), n_cells)), + shape=( + (len(block) * n_components,) + if forward_only + else (len(block) * n_components, n_cells) + ), ) - for receiver_location, components in self.survey._location_component_iterator() + for block in block_split ] if forward_only: stack = array.concatenate(rows) From ceda37878b80ee2991de659a5eba2e4090ad18cd Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 27 Dec 2024 08:38:09 -0800 Subject: [PATCH 26/84] Add back client property on simulation --- simpeg/dask/simulation.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index 56ea168a28..535bb8d38c 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -1,7 +1,7 @@ from ..simulation import BaseSimulation as Sim from dask import array - +from dask.distributed import get_client import numpy as np Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] @@ -42,6 +42,20 @@ def max_chunk_size(self, other): Sim.max_chunk_size = max_chunk_size +@property +def client(self): + if getattr(self, "_client", None) is None: + try: + self._client = get_client() + except ValueError: + self._client = False + + return self._client + + +Sim.client = client + + def getJtJdiag(self, m, W=None, f=None): """ Return the diagonal of JtJ From a4728288e1c01bf865c170fa0df44e7db0fd5f1b Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 27 Dec 2024 08:59:30 -0800 Subject: [PATCH 27/84] Comment explicit comp of J --- simpeg/dask/potential_fields/base.py | 3 ++- .../potential_fields/magnetics/simulation.py | 2 +- simpeg/meta/dask_sim.py | 20 ++++++++++++------- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 031e86b24f..9717415615 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -53,7 +53,7 @@ def linear_operator(self): n_cells *= 3 n_components = len(self.survey.components) - n_blocks = int( + n_blocks = np.ceil( (n_cells * n_components * self.survey.receiver_locations.shape[0] * 8.0 * 1e-6) / self.max_chunk_size ) @@ -140,6 +140,7 @@ def Jmatrix(self, value): self._Jmatrix = value +Sim.clean_on_model_update = [] Sim.G = G Sim._chunk_format = _chunk_format Sim.chunk_format = chunk_format diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index 770320c2e4..e19ca3f4d5 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -28,7 +28,7 @@ def getJtJdiag(self, m, W=None, f=None): + ampDeriv[1, :, None] * self.Jmatrix[1::3] + ampDeriv[2, :, None] * self.Jmatrix[2::3] ) - diag = ((W[:, None] * J) ** 2).sum(axis=0).compute() + diag = array.einsum("i,ij,ij->j", W**2, J, J).compute() self._gtg_diagonal = diag else: diag = self._gtg_diagonal diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index 86459b12f2..cbb1582d5e 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -70,6 +70,11 @@ def _reduce(client, operation, items): return client.gather(items[0]) +def _set_worker(obj, worker): + obj.worker = worker + return obj + + def _validate_type_or_future_of_type( property_name, objects, @@ -112,6 +117,7 @@ def _validate_type_or_future_of_type( warnings.warn( f"{property_name} {i} is not on the expected worker.", stacklevel=2 ) + obj = client.submit(_set_worker, obj, worker) # Ensure this runs on the expected worker futures = [] @@ -471,13 +477,13 @@ def fields(self, m): f = [] simulations = [] for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers): - jmatrix = client.submit( - _compute_j, - sim, - m_future, - workers=worker, - ) - sim = client.submit(set_jmatrix, sim, jmatrix, workers=worker) + # jmatrix = client.submit( + # _compute_j, + # sim, + # m_future, + # workers=worker, + # ) + # sim = client.submit(set_jmatrix, sim, jmatrix, workers=worker) f.append( client.submit( _calc_fields, From de9ee29f5470dad5c9041431a92c8978ad99c057 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 27 Dec 2024 10:16:01 -0800 Subject: [PATCH 28/84] Fix naming in get fields --- simpeg/potential_fields/gravity/simulation.py | 6 ++++-- simpeg/potential_fields/magnetics/simulation.py | 7 ++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/simpeg/potential_fields/gravity/simulation.py b/simpeg/potential_fields/gravity/simulation.py index 2c8ac62701..7fcf032cd1 100644 --- a/simpeg/potential_fields/gravity/simulation.py +++ b/simpeg/potential_fields/gravity/simulation.py @@ -145,7 +145,7 @@ def __init__( self._sensitivity_gravity = _sensitivity_gravity_serial self._forward_gravity = _forward_gravity_serial - def fields(self, m): + def fields(self, m=None): """ Forward model the gravity field of the mesh on the receivers in the survey @@ -160,7 +160,9 @@ def fields(self, m): Gravity fields generated by the given model on every receiver location. """ - self.model = m + if m is not None: + self.model = m + if self.store_sensitivities == "forward_only": # Compute the linear operation without forming the full dense G if self.engine == "choclo": diff --git a/simpeg/potential_fields/magnetics/simulation.py b/simpeg/potential_fields/magnetics/simulation.py index 610b933682..05449aded5 100644 --- a/simpeg/potential_fields/magnetics/simulation.py +++ b/simpeg/potential_fields/magnetics/simulation.py @@ -178,9 +178,10 @@ def M(self, M): M = np.asarray(M) self._M = M.reshape((self.nC, 3)) - def fields(self, model): - self.model = model - # model = self.chiMap * model + def fields(self, m=None): + if m is not None: + self.model = m + if self.store_sensitivities == "forward_only": if self.engine == "choclo": fields = self._forward(self.chi) From 14e7d2b22818241ab6a6840123061738678ef706 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 27 Dec 2024 21:37:11 -0800 Subject: [PATCH 29/84] Comment out Explicit version. Clean ups --- simpeg/dask/simulation.py | 1 - simpeg/data_misfit.py | 12 +-- simpeg/meta/dask_sim.py | 176 ++++++++++++++++++++------------------ 3 files changed, 92 insertions(+), 97 deletions(-) diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index 535bb8d38c..ca915ed62e 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -115,7 +115,6 @@ def Jmatrix(self): """ if getattr(self, "_Jmatrix", None) is None: self._Jmatrix = self.compute_J(self.model) - self._stashed_fields = None return self._Jmatrix diff --git a/simpeg/data_misfit.py b/simpeg/data_misfit.py index 05dabcbb41..ef8273b36f 100644 --- a/simpeg/data_misfit.py +++ b/simpeg/data_misfit.py @@ -1,5 +1,5 @@ import numpy as np -from .utils import Counter, mkvc, sdiag, timeIt, Identity, validate_type +from .utils import Counter, sdiag, timeIt, Identity, validate_type from .data import Data from .simulation import BaseSimulation from .objective_function import L2ObjectiveFunction @@ -359,16 +359,6 @@ def getJtJdiag(self, m): + "Cannot form the sensitivity explicitly" ) - # mapping_deriv = self.model_map.deriv(m) - # - # if self.model_map is not None: - # m = mapping_deriv @ m - jtjdiag = self.simulation.getJtJdiag(m, W=self.W) - # if self.model_map is not None: - # jtjdiag = mkvc( - # (sdiag(np.sqrt(jtjdiag)) @ mapping_deriv).power(2).sum(axis=0) - # ) - return jtjdiag diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index cbb1582d5e..a30bdc4d71 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -158,6 +158,8 @@ class DaskMetaSimulation(MetaSimulation): The dask client to use for communication. """ + clean_on_model_update = ["_jtjdiag", "_stashed_fields"] + def __init__(self, simulations, mappings, client): self._client = validate_type("client", client, Client, cast=False) self._concrete_simulations = None @@ -324,6 +326,8 @@ def fields(self, m): self.model = m client = self.client m_future = self._m_as_future + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields # The above should pass the model to all the internal simulations. f = [] for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers): @@ -337,6 +341,7 @@ def fields(self, m): workers=worker, ) ) + self._stashed_fields = f return f def dpred(self, m=None, f=None): @@ -447,91 +452,92 @@ def getJtJdiag(self, m, W=None, f=None): return self._jtjdiag -def _compute_j(sim, model): - sim.model = model - jmatrix = getattr(sim, "_Jmatrix", None) - - if jmatrix is None: - jmatrix = sim.compute_J(model) - - return jmatrix - - -def set_jmatrix(sim, jmatrix): - sim._Jmatrix = jmatrix - return sim - - -class DaskMetaSimulationExplicit(DaskMetaSimulation): - clean_on_model_update = ["_Jmatrix", "_stashed_fields"] - - def fields(self, m): - self.model = m - - if getattr(self, "_stashed_fields", None) is not None: - return self._stashed_fields - - client = self.client - m_future = self._m_as_future - # The above should pass the model to all the internal simulations. - f = [] - simulations = [] - for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers): - # jmatrix = client.submit( - # _compute_j, - # sim, - # m_future, - # workers=worker, - # ) - # sim = client.submit(set_jmatrix, sim, jmatrix, workers=worker) - f.append( - client.submit( - _calc_fields, - mapping, - sim, - m_future, - self._repeat_sim, - workers=worker, - ) - ) - simulations.append(sim) - - self._stashed_fields = f - self.simulations = simulations - return f - - def getJtJdiag(self, m, W=None, f=None): - self.model = m - m_future = self._m_as_future - if getattr(self, "_jtjdiag", None) is None: - if W is None: - W = np.ones(self.survey.nD) - else: - W = W.diagonal() - jtj_diag = [] - client = self.client - if f is None: - f = self.fields(m) - for i, (mapping, sim, worker, field) in enumerate( - zip(self.mappings, self.simulations, self._workers, f) - ): - sim_w = W[self._data_offsets[i] : self._data_offsets[i + 1]] - - jtj_diag.append( - client.submit( - _get_jtj_diag, - mapping, - sim, - m_future, - field, - sim_w, - self._repeat_sim, - workers=worker, - ) - ) - self._jtjdiag = _reduce(client, add, jtj_diag) - - return self._jtjdiag +# +# def _compute_j(sim, model): +# sim.model = model +# jmatrix = getattr(sim, "_Jmatrix", None) +# +# if jmatrix is None: +# jmatrix = sim.compute_J(model) +# +# return jmatrix +# +# +# def set_jmatrix(sim, jmatrix): +# sim._Jmatrix = jmatrix +# return sim + + +# class DaskMetaSimulationExplicit(DaskMetaSimulation): +# clean_on_model_update = ["_Jmatrix", "_stashed_fields"] +# +# def fields(self, m): +# self.model = m +# +# if getattr(self, "_stashed_fields", None) is not None: +# return self._stashed_fields +# +# client = self.client +# m_future = self._m_as_future +# # The above should pass the model to all the internal simulations. +# f = [] +# simulations = [] +# for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers): +# # jmatrix = client.submit( +# # _compute_j, +# # sim, +# # m_future, +# # workers=worker, +# # ) +# # sim = client.submit(set_jmatrix, sim, jmatrix, workers=worker) +# f.append( +# client.submit( +# _calc_fields, +# mapping, +# sim, +# m_future, +# self._repeat_sim, +# workers=worker, +# ) +# ) +# simulations.append(sim) +# +# self._stashed_fields = f +# # self.simulations = simulations +# return f +# +# def getJtJdiag(self, m, W=None, f=None): +# self.model = m +# m_future = self._m_as_future +# if getattr(self, "_jtjdiag", None) is None: +# if W is None: +# W = np.ones(self.survey.nD) +# else: +# W = W.diagonal() +# jtj_diag = [] +# client = self.client +# if f is None: +# f = self.fields(m) +# for i, (mapping, sim, worker, field) in enumerate( +# zip(self.mappings, self.simulations, self._workers, f) +# ): +# sim_w = W[self._data_offsets[i] : self._data_offsets[i + 1]] +# +# jtj_diag.append( +# client.submit( +# _get_jtj_diag, +# mapping, +# sim, +# m_future, +# field, +# sim_w, +# self._repeat_sim, +# workers=worker, +# ) +# ) +# self._jtjdiag = _reduce(client, add, jtj_diag) +# +# return self._jtjdiag class DaskSumMetaSimulation(DaskMetaSimulation, SumMetaSimulation): From 374a1de21735673fd207f3a33bf8a676858b511d Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 6 Jan 2025 13:06:20 -0800 Subject: [PATCH 30/84] Implement DaskComboMisfit --- simpeg/dask/inverse_problem.py | 29 +- simpeg/dask/objective_function.py | 511 ++++++++++++++++++++++-------- simpeg/directives/directives.py | 22 +- simpeg/meta/dask_sim.py | 4 +- 4 files changed, 409 insertions(+), 157 deletions(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 5baa675e2b..543d409e8c 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -1,26 +1,21 @@ from ..inverse_problem import BaseInvProblem import numpy as np -from dask.distributed import get_client, Future +from .objective_function import DaskComboMisfits from scipy.sparse.linalg import LinearOperator from ..regularization import WeightedLeastSquares, Sparse - from ..objective_function import ComboObjectiveFunction def get_dpred(self, m, f=None): dpreds = [] - for objfct in self.dmisfit.objfcts: - dpred = objfct.simulation.dpred(m) - dpreds += [dpred] + if isinstance(self.dmisfit, DaskComboMisfits): + return self.dmisfit.get_dpred(m, f=f) - if isinstance(dpreds[0], Future): - client = get_client() - dpreds = client.gather(dpreds) - else: - for i, dpred in enumerate(dpreds): - dpreds[i] = np.asarray(dpred) + for objfct in self.dmisfit.objfcts: + dpred = objfct.simulation.dpred(m, f=f) + dpreds += [np.asarray(dpred)] return dpreds @@ -34,9 +29,15 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): self.dpred = self.get_dpred(m) - phi_d = 0 - for (_, objfct), pred in zip(self.dmisfit, self.dpred): - residual = objfct.W * (objfct.data.dobs - pred) + residuals = [] + if isinstance(self.dmisfit, DaskComboMisfits): + residuals = self.dmisfit.residuals(m) + else: + for (_, objfct), pred in zip(self.dmisfit, self.dpred): + residuals.append(objfct.W * (objfct.data.dobs - pred)) + + phi_d = 0.0 + for residual in residuals: phi_d += np.vdot(residual, residual) reg2Deriv = [] diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 2ad57299ff..89e137fc63 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -1,152 +1,405 @@ from ..objective_function import ComboObjectiveFunction, BaseObjectiveFunction -import dask.array as da - import numpy as np -from dask.distributed import Future, get_client, Client +import scipy.sparse as sp from ..data_misfit import L2DataMisfit +from simpeg.maps import IdentityMap +from simpeg.meta.dask_sim import _validate_type_or_future_of_type, _reduce - -@property -def client(self): - if getattr(self, "_client", None) is None: - self._client = get_client() - - return self._client +from operator import add -@client.setter -def client(self, client): - assert isinstance(client, Client) - self._client = client +def _calc_fields(objfct, model): + return objfct.simulation.fields(m=objfct.simulation.model) -BaseObjectiveFunction.client = client +def _calc_dpred(objfct, model, field): + return objfct.simulation.dpred(m=objfct.simulation.model, f=field) -def dask_call(self, m, f=None): - fcts = [] - multipliers = [] - for i, phi in enumerate(self): - multiplier, objfct = phi - if multiplier == 0.0: # don't evaluate the fct - continue - else: +def _calc_residual(objfct, model, field): + return objfct.W * ( + objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model, f=field) + ) - if f is not None and objfct._has_fields: - fct = objfct(m, f=f[i]) - else: - fct = objfct(m) - - if isinstance(fct, Future): - future = self.client.compute( - self.client.submit(da.multiply, multiplier, fct).result() - ) - fcts += [future] - else: - fcts += [fct] - multipliers += [multiplier] - - if isinstance(fcts[0], Future): - phi = self.client.submit( - da.sum, self.client.submit(da.vstack, fcts), axis=0 - ).result() - return phi +def _deriv(objfct, multiplier, mapping, model, fields): + if fields is not None and objfct.has_fields: + return ( + 2 + * multiplier + * mapping.deriv(model).T + @ objfct.deriv(objfct.simulation.model, f=fields) + ) else: - return np.sum(np.r_[multipliers][:, None] * np.vstack(fcts), axis=0).squeeze() - - -ComboObjectiveFunction.__call__ = dask_call - - -def dask_deriv(self, m, f=None): - """ - First derivative of the composite objective function is the sum of the - derivatives of each objective function in the list, weighted by their - respective multplier. - - :param numpy.ndarray m: model - :param SimPEG.Fields f: Fields object (if applicable) - """ - - g = [] - multipliers = [] - for i, phi in enumerate(self): - multiplier, objfct = phi - if multiplier == 0.0: # don't evaluate the fct - continue - else: - - if f is not None and isinstance(objfct, L2DataMisfit): - fct = objfct.deriv(m, f=f[i]) - else: - fct = objfct.deriv(m) - - if isinstance(fct, Future): - future = self.client.compute( - self.client.submit(da.multiply, multiplier, fct) - ) - g += [future] - else: - g += [fct] - - multipliers += [multiplier] + return ( + 2 + * multiplier + * mapping.deriv(model).T + @ objfct.deriv(objfct.simulation.model) + ) + + +def _deriv2(objfct, multiplier, mapping, model, v, fields): + sim_v = mapping.deriv(model) @ v + if fields is not None and objfct.has_fields: + return ( + 2 + * multiplier + * mapping.deriv(model).T + @ objfct.deriv2(objfct.simulation.model, sim_v, f=fields) + ) + else: + return ( + 2 + * multiplier + * mapping.deriv(model).T + @ objfct.deriv2(objfct.simulation.model, sim_v) + ) - if isinstance(g[0], Future): - big_future = self.client.submit( - da.sum, self.client.submit(da.vstack, g), axis=0 - ).result() - return self.client.compute(big_future).result() - else: - return np.sum(np.r_[multipliers][:, None] * np.vstack(g), axis=0).squeeze() +def _store_model(mapping, objfct, model): + objfct.simulation.model = mapping * model -ComboObjectiveFunction.deriv = dask_deriv +def _get_jtj_diag(mapping, objfct, model, field): + jtj = objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W, f=field) + sim_jtj = sp.diags(np.sqrt(jtj)) + m_deriv = mapping.deriv(model) + return np.asarray((sim_jtj @ m_deriv).power(2).sum(axis=0)).flatten() -def dask_deriv2(self, m, v=None, f=None): +class DaskComboMisfits(ComboObjectiveFunction): """ - Second derivative of the composite objective function is the sum of the - second derivatives of each objective function in the list, weighted by - their respective multplier. - - :param numpy.ndarray m: model - :param numpy.ndarray v: vector we are multiplying by - :param SimPEG.Fields f: Fields object (if applicable) + A composite objective function for distributed computing. """ - H = [] - multipliers = [] - for phi in self: - multiplier, objfct = phi - if multiplier == 0.0: # don't evaluate the fct - continue - else: - fct = objfct.deriv2(m, v) - - if isinstance(fct, Future): - future = self.client.submit(da.multiply, multiplier, fct) - H += [future] - else: - H += [fct] - - multipliers += [multiplier] - - if isinstance(H[0], Future): - big_future = self.client.submit( - da.sum, self.client.submit(da.vstack, H), axis=0 - ).result() - - return np.asarray(big_future) - - else: - phi_deriv2 = 0 - for multiplier, h in zip(multipliers, H): - phi_deriv2 += multiplier * h - - return phi_deriv2 - - -ComboObjectiveFunction.deriv2 = dask_deriv2 + def __init__( + self, + objfcts: list[BaseObjectiveFunction], + mappings: list[IdentityMap], + multipliers=None, + client: Client | None = None, + **kwargs, + ): + self._model: np.ndarray | None = None + self.client = client + + super().__init__(objfcts=objfcts, multipliers=multipliers, **kwargs) + + self.mappings = mappings + self._repeat_sim = False # Flag to indicate if the simulation is repeated + + def __call__(self, m, f=None): + self.model = m + client = self.client + m_future = self._m_as_future + + if f is None: + f = self.fields(m) + + values = [] + for phi, field, worker in zip(self, f, self._workers): + multiplier, objfct = phi + if multiplier == 0.0: # don't evaluate the fct + continue + + values.append( + client.submit( + _calc_objective, objfct, multiplier, m_future, field, workers=worker + ) + ) + + return _reduce(client, add, values) + + @property + def client(self): + """ + Get the dask.distributed.Client instance. + """ + return self._client + + @client.setter + def client(self, client): + if not isinstance(client, Client): + raise TypeError("client must be a dask.distributed.Client") + + self._client = client + + def deriv(self, m, f=None): + """ + First derivative of the composite objective function is the sum of the + derivatives of each objective function in the list, weighted by their + respective multplier. + + :param numpy.ndarray m: model + :param SimPEG.Fields f: Fields object (if applicable) + """ + self.model = m + client = self.client + m_future = self._m_as_future + + if f is None: + f = self.fields(m) + + derivs = [] + for multiplier, objfct, mapping, field, worker in zip( + self.multipliers, self._futures, self.mappings, f, self._workers + ): + if multiplier == 0.0: # don't evaluate the fct + continue + + derivs.append( + client.submit( + _deriv, objfct, multiplier, mapping, m_future, field, workers=worker + ) + ) + + return _reduce(client, add, derivs) + + def deriv2(self, m, v=None, f=None): + """ + Second derivative of the composite objective function is the sum of the + second derivatives of each objective function in the list, weighted by + their respective multplier. + + :param numpy.ndarray m: model + :param numpy.ndarray v: vector we are multiplying by + :param SimPEG.Fields f: Fields object (if applicable) + """ + self.model = m + client = self.client + m_future = self._m_as_future + [v_future] = client.scatter([v], broadcast=True) + + if f is None: + f = self.fields(m) + + derivs = [] + for multiplier, objfct, mapping, field, worker in zip( + self.multipliers, self._futures, self.mappings, f, self._workers + ): + if multiplier == 0.0: # don't evaluate the fct + continue + + derivs.append( + client.submit( + _deriv2, + objfct, + multiplier, + mapping, + m_future, + v_future, + field, + workers=worker, + ) + ) + + return _reduce(client, add, derivs) + + def get_dpred(self, m, f=None): + self.model = m + + if f is None: + f = self.fields(m) + + client = self.client + m_future = self._m_as_future + dpred = [] + for objfct, worker, field in zip(self._futures, self._workers, f): + dpred.append( + client.submit( + _calc_dpred, + objfct, + m_future, + field, + workers=worker, + ) + ) + return client.gather(dpred) + + def getJtJdiag(self, m, f=None): + self.model = m + m_future = self._m_as_future + if getattr(self, "_jtjdiag", None) is None: + + jtj_diag = [] + client = self.client + if f is None: + f = self.fields(m) + for mapping, objfct, worker, field in zip( + self.mappings, self._futures, self._workers, f + ): + jtj_diag.append( + client.submit( + _get_jtj_diag, + mapping, + objfct, + m_future, + field, + workers=worker, + ) + ) + self._jtjdiag = _reduce(client, add, jtj_diag) + + return self._jtjdiag + + def fields(self, m): + self.model = m + client = self.client + m_future = self._m_as_future + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields + # The above should pass the model to all the internal simulations. + f = [] + for objfct, worker in zip(self._futures, self._workers): + f.append( + client.submit( + _calc_fields, + objfct, + m_future, + workers=worker, + ) + ) + self._stashed_fields = f + return f + + @property + def mappings(self): + """The future mappings paired to each data misfit. + + Every mapping should accept the same length model, and output + a model that is consistent with the simulation. + + Returns + ------- + (n_sim) list of distributed.Future simpeg.maps.IdentityMap + """ + return self._mappings + + @mappings.setter + def mappings(self, value): + client = self.client + + workers = self._workers + if len(value) != len(self.objfcts): + raise ValueError( + "Must provide the same number of mappings and simulations." + ) + mappings = _validate_type_or_future_of_type( + "mappings", value, IdentityMap, client, workers=workers + ) + + # validate mapping shapes and simulation shapes + model_len = client.submit(lambda v: v.shape[1], mappings[0]).result() + + def check_mapping(mapping, objfct, model_len): + if mapping.shape[1] != model_len: + # Bad mapping model length + return 1 + map_out_shape = mapping.shape[0] + for name in objfct.simulation._act_map_names: + sim_mapping = getattr(objfct.simulation, name) + sim_in_shape = sim_mapping.shape[1] + if ( + map_out_shape != "*" + and sim_in_shape != "*" + and sim_in_shape != map_out_shape + ): + # Inconsistent simulation input and mapping output + return 2 + # All good + return 0 + + error_checks = [] + for mapping, objfct, worker in zip(mappings, self._futures, workers): + # if it was a repeat objfct, this should cause the simulation to be transfered + # to each worker. + error_checks.append( + client.submit(check_mapping, mapping, objfct, model_len, workers=worker) + ) + error_checks = np.asarray(client.gather(error_checks)) + + if np.any(error_checks == 1): + raise ValueError("All mappings must have the same input length") + if np.any(error_checks == 2): + raise ValueError( + f"Simulations and mappings at indices {np.where(error_checks == 2)}" + f" are inconsistent." + ) + + self._mappings = mappings + + @property + def model(self): + return self._model + + @model.setter + def model(self, value): + # Only send the model to the internal simulations if it was updated. + if value is self.model: + return + + client = self.client + [self._m_as_future] = client.scatter([value], broadcast=True) + + futures = [] + for mapping, objfct, worker in zip(self.mappings, self._futures, self._workers): + futures.append( + client.submit( + _store_model, + mapping, + objfct, + self._m_as_future, + workers=worker, + ) + ) + self.client.gather(futures) # blocking call to ensure all models were stored + + @property + def objfcts(self): + return self._objfcts + + @objfcts.setter + def objfcts(self, objfcts): + client = self.client + + futures, workers = _validate_type_or_future_of_type( + "objfcts", objfcts, L2DataMisfit, client, return_workers=True + ) + for objfct, future in zip(objfcts, futures): + if hasattr(objfct, "name"): + future.name = objfct.name + + self._objfcts = objfcts + self._futures = futures + self._workers = workers + + def residuals(self, m, f=None): + """ + Compute the residual for the data misfit. + """ + self.model = m + if f is None: + f = self.fields(m) + client = self.client + m_future = self._m_as_future + residuals = [] + for objfct, worker, field in zip(self._futures, self._workers, f): + residuals.append( + client.submit( + _calc_residual, + objfct, + m_future, + field, + workers=worker, + ) + ) + return client.gather(residuals) + + @property + def workers(self): + """ + Get the list of dask.distributed.workers associated with the objective functions. + """ + return self._workers diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 077e10bebe..1873eaea88 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -48,7 +48,6 @@ validate_float, validate_ndarray_with_shape, ) -from dask.distributed import get_client, Future from geoh5py.groups.property_group import GroupTypeEnum from geoh5py.groups import PropertyGroup, UIJsonGroup from geoh5py.objects import ObjectBase @@ -56,17 +55,16 @@ def compute_JtJdiags(data_misfit, m): - jtj_diags = [] - for dmisfit in data_misfit.objfcts: - jtj_diags.append(dmisfit.getJtJdiag(m)) - - if isinstance(jtj_diags[0], Future): - client = get_client() - jtj_diags = client.gather(jtj_diags) - - jtj_diag = np.zeros_like(jtj_diags[0]) - for multiplier, diag in zip(data_misfit.multipliers, jtj_diags): - jtj_diag += multiplier * diag + if hasattr(data_misfit, "getJtJdiag"): + return data_misfit.getJtJdiag(m) + else: + jtj_diags = [] + for dmisfit in data_misfit.objfcts: + jtj_diags.append(dmisfit.getJtJdiag(m)) + + jtj_diag = np.zeros_like(jtj_diags[0]) + for multiplier, diag in zip(data_misfit.multipliers, jtj_diags): + jtj_diag += multiplier * diag return np.asarray(jtj_diag) diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index a30bdc4d71..79a494f158 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -162,7 +162,7 @@ class DaskMetaSimulation(MetaSimulation): def __init__(self, simulations, mappings, client): self._client = validate_type("client", client, Client, cast=False) - self._concrete_simulations = None + super().__init__(simulations, mappings) def _make_survey(self): @@ -188,7 +188,7 @@ def simulations(self): @simulations.setter def simulations(self, value): client = self.client - self._concrete_simulations = client.gather(value) + simulations, workers = _validate_type_or_future_of_type( "simulations", value, BaseSimulation, client, return_workers=True ) From 7bf43e2db520553fc00341a1f388445030fc007f Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 6 Jan 2025 14:32:23 -0800 Subject: [PATCH 31/84] Rely on MetaSim for handling of mapping --- simpeg/dask/objective_function.py | 138 +++++------------------------- 1 file changed, 21 insertions(+), 117 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 89e137fc63..c501de42df 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -1,9 +1,8 @@ from ..objective_function import ComboObjectiveFunction, BaseObjectiveFunction import numpy as np -import scipy.sparse as sp +from dask.distributed import Client from ..data_misfit import L2DataMisfit -from simpeg.maps import IdentityMap from simpeg.meta.dask_sim import _validate_type_or_future_of_type, _reduce from operator import add @@ -23,50 +22,27 @@ def _calc_residual(objfct, model, field): ) -def _deriv(objfct, multiplier, mapping, model, fields): +def _deriv(objfct, multiplier, model, fields): if fields is not None and objfct.has_fields: - return ( - 2 - * multiplier - * mapping.deriv(model).T - @ objfct.deriv(objfct.simulation.model, f=fields) - ) + return 2 * multiplier * objfct.deriv(objfct.simulation.model, f=fields) else: - return ( - 2 - * multiplier - * mapping.deriv(model).T - @ objfct.deriv(objfct.simulation.model) - ) + return 2 * multiplier * objfct.deriv(objfct.simulation.model) -def _deriv2(objfct, multiplier, mapping, model, v, fields): - sim_v = mapping.deriv(model) @ v +def _deriv2(objfct, multiplier, model, v, fields): if fields is not None and objfct.has_fields: - return ( - 2 - * multiplier - * mapping.deriv(model).T - @ objfct.deriv2(objfct.simulation.model, sim_v, f=fields) - ) + return 2 * multiplier * objfct.deriv2(objfct.simulation.model, v, f=fields) else: - return ( - 2 - * multiplier - * mapping.deriv(model).T - @ objfct.deriv2(objfct.simulation.model, sim_v) - ) + return 2 * multiplier * objfct.deriv2(objfct.simulation.model, v) -def _store_model(mapping, objfct, model): - objfct.simulation.model = mapping * model +def _store_model(objfct, model): + objfct.simulation.model = model -def _get_jtj_diag(mapping, objfct, model, field): +def _get_jtj_diag(objfct, model, field): jtj = objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W, f=field) - sim_jtj = sp.diags(np.sqrt(jtj)) - m_deriv = mapping.deriv(model) - return np.asarray((sim_jtj @ m_deriv).power(2).sum(axis=0)).flatten() + return jtj.flatten() class DaskComboMisfits(ComboObjectiveFunction): @@ -77,7 +53,6 @@ class DaskComboMisfits(ComboObjectiveFunction): def __init__( self, objfcts: list[BaseObjectiveFunction], - mappings: list[IdentityMap], multipliers=None, client: Client | None = None, **kwargs, @@ -87,9 +62,6 @@ def __init__( super().__init__(objfcts=objfcts, multipliers=multipliers, **kwargs) - self.mappings = mappings - self._repeat_sim = False # Flag to indicate if the simulation is repeated - def __call__(self, m, f=None): self.model = m client = self.client @@ -143,15 +115,15 @@ def deriv(self, m, f=None): f = self.fields(m) derivs = [] - for multiplier, objfct, mapping, field, worker in zip( - self.multipliers, self._futures, self.mappings, f, self._workers + for multiplier, objfct, field, worker in zip( + self.multipliers, self._futures, f, self._workers ): if multiplier == 0.0: # don't evaluate the fct continue derivs.append( client.submit( - _deriv, objfct, multiplier, mapping, m_future, field, workers=worker + _deriv, objfct, multiplier, m_future, field, workers=worker ) ) @@ -176,8 +148,8 @@ def deriv2(self, m, v=None, f=None): f = self.fields(m) derivs = [] - for multiplier, objfct, mapping, field, worker in zip( - self.multipliers, self._futures, self.mappings, f, self._workers + for multiplier, objfct, field, worker in zip( + self.multipliers, self._futures, f, self._workers ): if multiplier == 0.0: # don't evaluate the fct continue @@ -187,7 +159,6 @@ def deriv2(self, m, v=None, f=None): _deriv2, objfct, multiplier, - mapping, m_future, v_future, field, @@ -227,13 +198,10 @@ def getJtJdiag(self, m, f=None): client = self.client if f is None: f = self.fields(m) - for mapping, objfct, worker, field in zip( - self.mappings, self._futures, self._workers, f - ): + for objfct, worker, field in zip(self._futures, self._workers, f): jtj_diag.append( client.submit( _get_jtj_diag, - mapping, objfct, m_future, field, @@ -264,72 +232,6 @@ def fields(self, m): self._stashed_fields = f return f - @property - def mappings(self): - """The future mappings paired to each data misfit. - - Every mapping should accept the same length model, and output - a model that is consistent with the simulation. - - Returns - ------- - (n_sim) list of distributed.Future simpeg.maps.IdentityMap - """ - return self._mappings - - @mappings.setter - def mappings(self, value): - client = self.client - - workers = self._workers - if len(value) != len(self.objfcts): - raise ValueError( - "Must provide the same number of mappings and simulations." - ) - mappings = _validate_type_or_future_of_type( - "mappings", value, IdentityMap, client, workers=workers - ) - - # validate mapping shapes and simulation shapes - model_len = client.submit(lambda v: v.shape[1], mappings[0]).result() - - def check_mapping(mapping, objfct, model_len): - if mapping.shape[1] != model_len: - # Bad mapping model length - return 1 - map_out_shape = mapping.shape[0] - for name in objfct.simulation._act_map_names: - sim_mapping = getattr(objfct.simulation, name) - sim_in_shape = sim_mapping.shape[1] - if ( - map_out_shape != "*" - and sim_in_shape != "*" - and sim_in_shape != map_out_shape - ): - # Inconsistent simulation input and mapping output - return 2 - # All good - return 0 - - error_checks = [] - for mapping, objfct, worker in zip(mappings, self._futures, workers): - # if it was a repeat objfct, this should cause the simulation to be transfered - # to each worker. - error_checks.append( - client.submit(check_mapping, mapping, objfct, model_len, workers=worker) - ) - error_checks = np.asarray(client.gather(error_checks)) - - if np.any(error_checks == 1): - raise ValueError("All mappings must have the same input length") - if np.any(error_checks == 2): - raise ValueError( - f"Simulations and mappings at indices {np.where(error_checks == 2)}" - f" are inconsistent." - ) - - self._mappings = mappings - @property def model(self): return self._model @@ -340,15 +242,17 @@ def model(self, value): if value is self.model: return + self._stashed_fields = None + self._jtjdiag = None + client = self.client [self._m_as_future] = client.scatter([value], broadcast=True) futures = [] - for mapping, objfct, worker in zip(self.mappings, self._futures, self._workers): + for objfct, worker in zip(self._futures, self._workers): futures.append( client.submit( _store_model, - mapping, objfct, self._m_as_future, workers=worker, From b2ad007670d23e063f04dbb76821ec24d0292b36 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 6 Jan 2025 14:39:25 -0800 Subject: [PATCH 32/84] Remove unused dask.data_misfit module --- simpeg/dask/__init__.py | 1 - simpeg/dask/data_misfit.py | 100 ------------------------------------- 2 files changed, 101 deletions(-) delete mode 100644 simpeg/dask/data_misfit.py diff --git a/simpeg/dask/__init__.py b/simpeg/dask/__init__.py index 960be9da43..89fee4fcd9 100644 --- a/simpeg/dask/__init__.py +++ b/simpeg/dask/__init__.py @@ -1,6 +1,5 @@ try: import simpeg.dask.simulation - import simpeg.dask.data_misfit import simpeg.dask.electromagnetics.frequency_domain.simulation import simpeg.dask.electromagnetics.static.resistivity.simulation import simpeg.dask.electromagnetics.static.resistivity.simulation_2d diff --git a/simpeg/dask/data_misfit.py b/simpeg/dask/data_misfit.py deleted file mode 100644 index 0cb8ae97f1..0000000000 --- a/simpeg/dask/data_misfit.py +++ /dev/null @@ -1,100 +0,0 @@ -import numpy as np - -from ..data_misfit import L2DataMisfit - -from dask.distributed import get_client, Future - - -def _data_residual(dpred, dobs): - return dpred - dobs - - -def _misfit(residual, W): - vec = W * residual - return np.dot(vec, vec) - - -def dask_call(self, m, f=None): - """ - Distributed :obj:`simpeg.data_misfit.L2DataMisfit.__call__` - """ - residuals = self.residual(m, f=f) - - if isinstance(residuals, Future): - client = get_client() - phi_d = client.submit(_misfit, residuals, self.W) - else: - phi_d = _misfit(residuals, self.W) - - return phi_d - - -L2DataMisfit.__call__ = dask_call - - -def dask_residual(self, m, f=None): - dpred = self.simulation.dpred(m, f=f) - - if isinstance(dpred, Future): - client = get_client() - residuals = client.submit(_data_residual, dpred, self.data.dobs) - else: - residuals = _data_residual(dpred, self.data.dobs) - - return residuals - - -L2DataMisfit.residual = dask_residual - - -def dask_deriv(self, m, f=None): - """ - Distributed :obj:`simpeg.data_misfit.L2DataMisfit.deriv` - """ - residuals = self.residual(m, f=f) - - if isinstance(residuals, Future): - client = get_client() - who = client.who_has(residuals) - wtw_d = client.submit( - _stack_futures, - residuals, - self.W.diagonal() ** 2.0, - workers=who[residuals.key], - ) - else: - wtw_d = self.W.diagonal() ** 2.0 * residuals - - Jtvec = self.simulation.Jtvec(m, wtw_d) - - return Jtvec - - -L2DataMisfit.deriv = dask_deriv - - -def _stack_futures(futures, W): - return W * futures - - -def dask_deriv2(self, m, v, f=None): - """ - Distributed :obj:`simpeg.data_misfit.L2DataMisfit.deriv2` - """ - jvec = self.simulation.Jvec(m, v) - if isinstance(jvec, Future): - client = get_client() - who = client.who_has(jvec) - w_jvec = client.submit( - _stack_futures, jvec, self.W.diagonal() ** 2.0, workers=who[jvec.key] - ) - - else: - w_jvec = self.W.diagonal() ** 2.0 * jvec - - jtwjvec = self.simulation.Jtvec(m, w_jvec) - - return jtwjvec - - -L2DataMisfit.deriv2 = dask_deriv2 From 704289b6b43349bbe629a86dc45b831638bcd122 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 7 Jan 2025 09:17:16 -0800 Subject: [PATCH 33/84] Remove 2 multiplier --- simpeg/dask/objective_function.py | 10 +++++----- .../dask/potential_fields/magnetics/simulation.py | 14 ++++++-------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index c501de42df..47ecf0326f 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -24,16 +24,16 @@ def _calc_residual(objfct, model, field): def _deriv(objfct, multiplier, model, fields): if fields is not None and objfct.has_fields: - return 2 * multiplier * objfct.deriv(objfct.simulation.model, f=fields) + return multiplier * objfct.deriv(objfct.simulation.model, f=fields) else: - return 2 * multiplier * objfct.deriv(objfct.simulation.model) + return multiplier * objfct.deriv(objfct.simulation.model) def _deriv2(objfct, multiplier, model, v, fields): if fields is not None and objfct.has_fields: - return 2 * multiplier * objfct.deriv2(objfct.simulation.model, v, f=fields) + return multiplier * objfct.deriv2(objfct.simulation.model, v, f=fields) else: - return 2 * multiplier * objfct.deriv2(objfct.simulation.model, v) + return multiplier * objfct.deriv2(objfct.simulation.model, v) def _store_model(objfct, model): @@ -229,7 +229,7 @@ def fields(self, m): workers=worker, ) ) - self._stashed_fields = f + self._stashed_fields = client.compute(f) return f @property diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index e19ca3f4d5..09c558f38d 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -18,9 +18,7 @@ def getJtJdiag(self, m, W=None, f=None): W = W.diagonal() if getattr(self, "_gtg_diagonal", None) is None: if not self.is_amplitude_data: - diag = array.einsum( - "i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix - ).compute() + diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) else: ampDeriv = self.ampDeriv J = ( @@ -28,12 +26,12 @@ def getJtJdiag(self, m, W=None, f=None): + ampDeriv[1, :, None] * self.Jmatrix[1::3] + ampDeriv[2, :, None] * self.Jmatrix[2::3] ) - diag = array.einsum("i,ij,ij->j", W**2, J, J).compute() - self._gtg_diagonal = diag - else: - diag = self._gtg_diagonal + diag = array.einsum("i,ij,ij->j", W**2, J, J) + self._gtg_diagonal = np.asarray(diag) - return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0)) + return mkvc( + (sdiag(np.sqrt(self._gtg_diagonal)) @ self.chiDeriv).power(2).sum(axis=0) + ) Sim.clean_on_model_update = [] From 3cf2f4d5f78404d69eb1dcbfd9fdbb09b5818d66 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 8 Jan 2025 14:24:29 -0800 Subject: [PATCH 34/84] Fix for mag. Best so far --- simpeg/dask/objective_function.py | 11 ++++--- .../potential_fields/magnetics/simulation.py | 33 +------------------ .../potential_fields/magnetics/simulation.py | 3 +- 3 files changed, 9 insertions(+), 38 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 47ecf0326f..caa1b16f76 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -126,8 +126,8 @@ def deriv(self, m, f=None): _deriv, objfct, multiplier, m_future, field, workers=worker ) ) - - return _reduce(client, add, derivs) + derivs = _reduce(client, add, derivs) + return derivs def deriv2(self, m, v=None, f=None): """ @@ -166,7 +166,9 @@ def deriv2(self, m, v=None, f=None): ) ) - return _reduce(client, add, derivs) + derivs = _reduce(client, add, derivs) + + return derivs def get_dpred(self, m, f=None): self.model = m @@ -229,7 +231,7 @@ def fields(self, m): workers=worker, ) ) - self._stashed_fields = client.compute(f) + self._stashed_fields = f return f @property @@ -259,6 +261,7 @@ def model(self, value): ) ) self.client.gather(futures) # blocking call to ensure all models were stored + self._model = value @property def objfcts(self): diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index 09c558f38d..a39ef9ef9d 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -1,37 +1,6 @@ -import numpy as np -from dask import array from ....potential_fields.magnetics import Simulation3DIntegral as Sim from ..base import G -from ....utils import sdiag, mkvc - - -def getJtJdiag(self, m, W=None, f=None): - """ - Return the diagonal of JtJ - """ - - self.model = m - - if W is None: - W = np.ones(self.nD) - else: - W = W.diagonal() - if getattr(self, "_gtg_diagonal", None) is None: - if not self.is_amplitude_data: - diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) - else: - ampDeriv = self.ampDeriv - J = ( - ampDeriv[0, :, None] * self.Jmatrix[::3] - + ampDeriv[1, :, None] * self.Jmatrix[1::3] - + ampDeriv[2, :, None] * self.Jmatrix[2::3] - ) - diag = array.einsum("i,ij,ij->j", W**2, J, J) - self._gtg_diagonal = np.asarray(diag) - - return mkvc( - (sdiag(np.sqrt(self._gtg_diagonal)) @ self.chiDeriv).power(2).sum(axis=0) - ) +from ...simulation import getJtJdiag Sim.clean_on_model_update = [] diff --git a/simpeg/potential_fields/magnetics/simulation.py b/simpeg/potential_fields/magnetics/simulation.py index 05449aded5..228fc1de2e 100644 --- a/simpeg/potential_fields/magnetics/simulation.py +++ b/simpeg/potential_fields/magnetics/simulation.py @@ -246,8 +246,7 @@ def getJtJdiag(self, m, W=None, f=None): if getattr(self, "_gtg_diagonal", None) is None: diag = np.zeros(self.Jmatrix.shape[1]) if not self.is_amplitude_data: - for i in range(len(W)): - diag += W[i] * (self.Jmatrix[i] * self.Jmatrix[i]) + diag = np.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) else: ampDeriv = self.ampDeriv Gx = self.Jmatrix[::3] From e6470794de95a6995d4e9ae66f284fa2a1959040 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 9 Jan 2025 14:11:58 -0800 Subject: [PATCH 35/84] Outfit frequency simulations with client option --- .../frequency_domain/simulation.py | 179 +++++++++++------- 1 file changed, 113 insertions(+), 66 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 0b105b0e0b..d9d43297e7 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -1,20 +1,15 @@ from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim - from ....utils import Zero +from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix import numpy as np import scipy.sparse as sp from multiprocessing import cpu_count from dask import array, compute, delayed - from simpeg.dask.utils import get_parallel_blocks - from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary - import zarr -from tqdm import tqdm -@delayed def evaluate_receivers(block, mesh, fields): data = [] for source, _, receiver in block: @@ -23,7 +18,6 @@ def evaluate_receivers(block, mesh, fields): return np.hstack(data) -@delayed def source_evaluation(simulation, sources): s_m, s_e = [], [] for source in sources: @@ -34,7 +28,6 @@ def source_evaluation(simulation, sources): return s_m, s_e -@delayed def receiver_derivs(survey, mesh, fields, blocks): field_derivatives = [] for address in blocks: @@ -55,7 +48,6 @@ def receiver_derivs(survey, mesh, fields, blocks): return field_derivatives -@delayed def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address): """ Evaluate the sensitivities for the block or data @@ -108,15 +100,23 @@ def getSourceTerm(self, freq, source=None): source_block = np.array_split(source_list, cpu_count()) block_compute = [] + + if self.client: + sim = self.client.scatter(self) + for block in source_block: if len(block) == 0: continue - block_compute.append( - self.client.submit(source_evaluation, self, block, workers=self.worker) - ) + if self.client: + block_compute.append(self.client.submit(source_evaluation, sim, block)) + else: + block_compute.append(delayed(source_evaluation)(self, block)) - blocks = self.client.gather(block_compute) + if self.client: + blocks = self.client.gather(block_compute) + else: + blocks = compute(block_compute)[0] s_m, s_e = [], [] for block in blocks: if block[0]: @@ -174,23 +174,32 @@ def dpred(self, m=None, f=None): for rx in src.receiver_list: all_receivers.append((src, ind, rx)) + if self.client: + f = self.client.scatter(f) + mesh = self.client.scatter(self.mesh) + receiver_blocks = np.array_split(np.asarray(all_receivers), cpu_count()) rows = [] - mesh = delayed(self.mesh) for block in receiver_blocks: n_data = np.sum([rec.nD for _, _, rec in block]) if n_data == 0: continue - rows.append( - array.from_delayed( - evaluate_receivers(block, mesh, f), - dtype=np.float64, - shape=(n_data,), + if self.client: + rows.append(self.client.submit(evaluate_receivers, block, mesh, f)) + else: + rows.append( + array.from_delayed( + delayed(evaluate_receivers, block, mesh, f), + dtype=np.float64, + shape=(n_data,), + ) ) - ) - data = compute(array.hstack(rows))[0] + if self.client: + data = np.hstack(self.client.gather(rows)) + else: + data = compute(array.hstack(rows))[0] return data @@ -199,8 +208,8 @@ def fields(self, m=None): if m is not None: self.model = m - # if getattr(self, "_stashed_fields", None) is not None: - # return self._stashed_fields + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields f = self.fieldsPair(self) Ainv = {} @@ -213,9 +222,9 @@ def fields(self, m=None): f[sources, self._solutionType] = u Ainv[freq] = Ainv_solve - # Ainv = Ainv - # - # self._stashed_fields = f + self.Ainv = Ainv + + self._stashed_fields = f return f @@ -226,19 +235,13 @@ def compute_J(self, m, f=None): if f is None: f = self.fields(m) - Ainv = {} - for freq in self.survey.frequencies: - A = self.getA(freq) - Ainv_solve = self.solver(sp.csr_matrix(A), **self.solver_opts) - Ainv[freq] = Ainv_solve - - if len(Ainv) > 1: + if len(self.Ainv) > 1: raise NotImplementedError( "Current implementation of parallelization assumes a single frequency per simulation. " "Consider creating one misfit per frequency." ) - A_i = list(Ainv.values())[0] + A_i = list(self.Ainv.values())[0] m_size = m.size if self.store_sensitivities == "disk": @@ -255,37 +258,51 @@ def compute_J(self, m, f=None): blocks = get_parallel_blocks( self.survey.source_list, compute_row_size, optimize=False ) - fields_array = delayed(f[:, self._solutionType]) - fields = delayed(f) - survey = delayed(self.survey) - mesh = delayed(self.mesh) - blocks_receiver_derivs = [] - - for block in blocks: - blocks_receiver_derivs.append( - receiver_derivs( - survey, - mesh, - fields, - block, - ) + if self.client: + fields_array = self.client.scatter(f[:, self._solutionType]) + fields = self.client.scatter(f) + survey = self.client.scatter(self.survey) + mesh = self.client.scatter(self.mesh) + blocks_receiver_derivs = self.client.map( + receiver_derivs, + [survey] * len(blocks), + [mesh] * len(blocks), + [fields] * len(blocks), + blocks, ) + else: + fields_array = delayed(f[:, self._solutionType]) + fields = delayed(f) + survey = delayed(self.survey) + mesh = delayed(self.mesh) + blocks_receiver_derivs = [] + delayed_derivs = delayed(receiver_derivs) + for block in blocks: + blocks_receiver_derivs.append( + delayed_derivs( + survey, + mesh, + fields, + block, + ) + ) # Dask process for all derivatives - blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] + if self.client: + blocks_receiver_derivs = self.client.gather(blocks_receiver_derivs) + else: + blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] - for block_derivs_chunks, addresses_chunks in tqdm( - zip(blocks_receiver_derivs, blocks), - ncols=len(blocks_receiver_derivs), - desc=f"Sensitivities at {list(Ainv)[0]} Hz", - ): + for block_derivs_chunks, addresses_chunks in zip(blocks_receiver_derivs, blocks): Jmatrix = self.parallel_block_compute( m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks ) - for A in Ainv.values(): + for A in self.Ainv.values(): A.clean() + del self.Ainv + if self.store_sensitivities == "disk": del Jmatrix Jmatrix = array.from_zarr(self.sensitivity_path) @@ -298,7 +315,13 @@ def parallel_block_compute( ): m_size = m.size block_stack = sp.hstack(blocks_receiver_derivs).toarray() - ATinvdf_duT = delayed(A_i * block_stack) + + ATinvdf_duT = A_i * block_stack + if self.client: + ATinvdf_duT = self.client.scatter(ATinvdf_duT) + sim = self.client.scatter(self) + else: + ATinvdf_duT = delayed(ATinvdf_duT) count = 0 rows = [] block_delayed = [] @@ -306,39 +329,63 @@ def parallel_block_compute( for address, dfduT in zip(addresses, blocks_receiver_derivs): n_cols = dfduT.shape[1] n_rows = address[1][2] - block_delayed.append( - array.from_delayed( - eval_block( - self, + + if self.client: + block_delayed.append( + self.client.submit( + eval_block, + sim, ATinvdf_duT, np.arange(count, count + n_cols), Zero(), fields_array, address, - ), - dtype=np.float32, - shape=(n_rows, m_size), + ) + ) + else: + delayed_eval = delayed(eval_block) + block_delayed.append( + array.from_delayed( + delayed_eval( + self, + ATinvdf_duT, + np.arange(count, count + n_cols), + Zero(), + fields_array, + address, + ), + dtype=np.float32, + shape=(n_rows, m_size), + ) ) - ) count += n_cols rows += address[1][1].tolist() indices = np.hstack(rows) + if self.client: + block = np.vstack(self.client.gather(block_delayed)) + else: + block = compute(array.vstack(block_delayed))[0] + if self.store_sensitivities == "disk": Jmatrix.set_orthogonal_selection( (indices, slice(None)), - compute(array.vstack(block_delayed))[0], + block, ) else: # Dask process to compute row and store - Jmatrix[indices, :] = compute(array.vstack(block_delayed))[0] + Jmatrix[indices, :] = block return Jmatrix Sim.parallel_block_compute = parallel_block_compute Sim.compute_J = compute_J +Sim.getJtJdiag = getJtJdiag +Sim.Jvec = Jvec +Sim.Jtvec = Jtvec +Sim.Jmatrix = Jmatrix Sim.fields = fields Sim.dpred = dpred Sim.getSourceTerm = getSourceTerm From 935d62d1f59eae78f4dba0ce6df414d7eaa859f2 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 10 Jan 2025 08:34:57 -0800 Subject: [PATCH 36/84] Temp remove of overloaded dask dpred for FEM --- .../frequency_domain/simulation.py | 44 ++++++++++--------- .../time_domain/simulation.py | 6 +++ simpeg/dask/objective_function.py | 6 ++- 3 files changed, 34 insertions(+), 22 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index d9d43297e7..611fe9e15b 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -168,33 +168,35 @@ def dpred(self, m=None, f=None): m = self.model f = self.fields(m) - all_receivers = [] - - for ind, src in enumerate(self.survey.source_list): - for rx in src.receiver_list: - all_receivers.append((src, ind, rx)) - if self.client: f = self.client.scatter(f) mesh = self.client.scatter(self.mesh) + else: + mesh = delayed(self.mesh) + delayed_block_eval = delayed(evaluate_receivers) - receiver_blocks = np.array_split(np.asarray(all_receivers), cpu_count()) rows = [] - for block in receiver_blocks: - n_data = np.sum([rec.nD for _, _, rec in block]) - if n_data == 0: - continue + for ind, src in enumerate(self.survey.source_list): + for rx in src.receiver_list: + block = [(src, ind, rx)] - if self.client: - rows.append(self.client.submit(evaluate_receivers, block, mesh, f)) - else: - rows.append( - array.from_delayed( - delayed(evaluate_receivers, block, mesh, f), - dtype=np.float64, - shape=(n_data,), + # receiver_blocks = np.array_split(np.asarray(all_receivers), cpu_count()) + # rows = [] + # for block in receiver_blocks: + # n_data = np.sum([rec.nD for _, _, rec in block]) + if rx.nD == 0: + continue + + if self.client: + rows.append(self.client.submit(evaluate_receivers, block, mesh, f)) + else: + rows.append( + array.from_delayed( + delayed_block_eval(block, mesh, f), + dtype=np.float64, + shape=(rx.nD,), + ) ) - ) if self.client: data = np.hstack(self.client.gather(rows)) @@ -387,5 +389,5 @@ def parallel_block_compute( Sim.Jtvec = Jtvec Sim.Jmatrix = Jmatrix Sim.fields = fields -Sim.dpred = dpred +# Sim.dpred = dpred Sim.getSourceTerm = getSourceTerm diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 93cda31d99..e21f80bddd 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -4,6 +4,7 @@ from ....electromagnetics.time_domain.simulation import BaseTDEMSimulation as Sim from ....utils import Zero +from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix from simpeg.fields import TimeFields from multiprocessing import cpu_count import numpy as np @@ -575,6 +576,11 @@ def compute_rows( Sim.fields = fields +Sim.getJtJdiag = getJtJdiag Sim.getSourceTerm = getSourceTerm Sim.dpred = dpred Sim.compute_J = compute_J +Sim.getJtJdiag = getJtJdiag +Sim.Jvec = Jvec +Sim.Jtvec = Jtvec +Sim.Jmatrix = Jmatrix diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index caa1b16f76..c07c4fa046 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -241,7 +241,11 @@ def model(self): @model.setter def model(self, value): # Only send the model to the internal simulations if it was updated. - if value is self.model: + if ( + isinstance(value, np.ndarray) + and isinstance(self.model, np.ndarray) + and np.allclose(value, self.model) + ): return self._stashed_fields = None From f9532976883449d213918eab05373ff08bb5b417 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 10 Jan 2025 15:32:09 -0800 Subject: [PATCH 37/84] Outfit time domain simulation with client run --- .../frequency_domain/simulation.py | 4 +- .../time_domain/simulation.py | 376 ++++++++++++------ simpeg/dask/utils.py | 8 +- 3 files changed, 255 insertions(+), 133 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 611fe9e15b..3e3f5c4393 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -1,3 +1,5 @@ +import gc + from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim from ....utils import Zero from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix @@ -304,7 +306,7 @@ def compute_J(self, m, f=None): A.clean() del self.Ainv - + gc.collect() if self.store_sensitivities == "disk": del Jmatrix Jmatrix = array.from_zarr(self.sensitivity_path) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index e21f80bddd..8a50032ee8 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -4,7 +4,7 @@ from ....electromagnetics.time_domain.simulation import BaseTDEMSimulation as Sim from ....utils import Zero -from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix +from ...simulation import client, getJtJdiag, Jvec, Jtvec, Jmatrix from simpeg.fields import TimeFields from multiprocessing import cpu_count import numpy as np @@ -15,7 +15,6 @@ from simpeg.utils import mkvc from time import time -from tqdm import tqdm def fields(self, m=None): @@ -60,10 +59,29 @@ def getSourceTerm(self, tInd): source_block = np.array_split(source_list, cpu_count()) block_compute = [] + + if self.client: + sim = self.client.scatter(self) + else: + delayed_source_eval = delayed(source_evaluation) + for block in source_block: - block_compute.append(source_evaluation(self, block, self.times[tInd])) + if self.client: + block_compute.append( + self.client.submit( + source_evaluation, + sim, + block, + self.times[tInd], + ) + ) + else: + block_compute.append(delayed_source_eval(self, block, self.times[tInd])) - blocks = dask.compute(block_compute)[0] + if self.client: + blocks = self.client.gather(block_compute) + else: + blocks = dask.compute(block_compute)[0] s_m, s_e = [], [] for block in blocks: @@ -77,62 +95,62 @@ def getSourceTerm(self, tInd): return np.vstack(s_m).T, np.vstack(s_e).T -def dpred(self, m=None, f=None): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). - - .. math:: - - d_\\text{pred} = P(f(m)) - - Where P is a projection of the fields onto the data space. - """ - if self.survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) - - if f is None: - if m is None: - m = self.model - f = self.fields(m) - - rows = [] - receiver_projection = self.survey.source_list[0].receiver_list[0].projField - fields_array = f[:, receiver_projection, :] - - if len(self.survey.source_list) == 1: - fields_array = fields_array[:, np.newaxis, :] - - all_receivers = [] - - for ind, src in enumerate(self.survey.source_list): - for rx in src.receiver_list: - all_receivers.append((src, ind, rx)) - - receiver_blocks = np.array_split(all_receivers, cpu_count()) - - for block in receiver_blocks: - n_data = np.sum([rec.nD for _, _, rec in block]) - if n_data == 0: - continue - - rows.append( - array.from_delayed( - evaluate_receivers(block, self.mesh, self.time_mesh, f, fields_array), - dtype=np.float64, - shape=(n_data,), - ) - ) - - data = array.hstack(rows).compute() - - return data +# def dpred(self, m=None, f=None): +# r""" +# dpred(m, f=None) +# Create the projected data from a model. +# The fields, f, (if provided) will be used for the predicted data +# instead of recalculating the fields (which may be expensive!). +# +# .. math:: +# +# d_\\text{pred} = P(f(m)) +# +# Where P is a projection of the fields onto the data space. +# """ +# if self.survey is None: +# raise AttributeError( +# "The survey has not yet been set and is required to compute " +# "data. Please set the survey for the simulation: " +# "simulation.survey = survey" +# ) +# +# if f is None: +# if m is None: +# m = self.model +# f = self.fields(m) +# +# rows = [] +# receiver_projection = self.survey.source_list[0].receiver_list[0].projField +# fields_array = f[:, receiver_projection, :] +# +# if len(self.survey.source_list) == 1: +# fields_array = fields_array[:, np.newaxis, :] +# +# all_receivers = [] +# +# for ind, src in enumerate(self.survey.source_list): +# for rx in src.receiver_list: +# all_receivers.append((src, ind, rx)) +# +# receiver_blocks = np.array_split(all_receivers, cpu_count()) +# +# for block in receiver_blocks: +# n_data = np.sum([rec.nD for _, _, rec in block]) +# if n_data == 0: +# continue +# +# rows.append( +# array.from_delayed( +# evaluate_receivers(block, self.mesh, self.time_mesh, f, fields_array), +# dtype=np.float64, +# shape=(n_data,), +# ) +# ) +# +# data = array.hstack(rows).compute() +# +# return data def compute_J(self, m, f=None): @@ -174,7 +192,14 @@ def compute_J(self, m, f=None): ) ATinv_df_duT_v = {} - for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): + + if self.client: + fields_array = self.client.scatter(fields_array) + sim = self.client.scatter(self) + else: + delayed_compute_rows = delayed(compute_rows) + + for tInd, dt in zip(reversed(range(self.nT)), reversed(self.time_steps)): AdiagTinv = self.Ainv[dt] j_row_updates = [] time_mask = data_times > simulation_times[tInd] @@ -190,35 +215,54 @@ def compute_J(self, m, f=None): if len(block) == 0: continue - j_row_updates.append( - array.from_delayed( - compute_rows( - self, + if self.client: + j_row_updates.append( + self.client.submit( + compute_rows, + sim, tInd, block, ATinv_df_duT_v, fields_array, time_mask, - ), - dtype=np.float32, - shape=( - np.sum([len(chunk[1][0]) for chunk in block]), - m.size, - ), + ) ) - ) + else: + j_row_updates.append( + array.from_delayed( + delayed_compute_rows( + self, + tInd, + block, + ATinv_df_duT_v, + fields_array, + time_mask, + ), + dtype=np.float32, + shape=( + np.sum([len(chunk[1][0]) for chunk in block]), + m.size, + ), + ) + ) + + if self.client: + j_row_updates = np.vstack(self.client.gather(j_row_updates)) + + else: + j_row_updates = array.vstack(j_row_updates).compute() if self.store_sensitivities == "disk": sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" array.to_zarr( - Jmatrix + array.vstack(j_row_updates), + Jmatrix + j_row_updates, sens_name, compute=True, overwrite=True, ) Jmatrix = array.from_zarr(sens_name) else: - Jmatrix += array.vstack(j_row_updates).compute() + Jmatrix += j_row_updates for A in self.Ainv.values(): A.clean() @@ -268,25 +312,51 @@ def _getField(self, name, ind, src_list): else: # loop over the time steps arrays = [] + if self.client: + pointerFields = self.client.scatter(pointerFields) + src_list = self.client.scatter(src_list) + func = self.client.scatter(func) + else: + delayed_field_comp = delayed(field_projection) + for i, TIND_i in enumerate(timeII): # Need to parallelize this - arrays.append( - array.from_delayed( - field_projection(pointerFields, src_list, i, TIND_i, func), - dtype=np.float32, - shape=(pointerShape[0], pointerShape[1], 1), + + if self.client: + arrays.append( + self.client.submit( + field_projection, + pointerFields, + src_list, + i, + TIND_i, + func, + ) + ) + else: + arrays.append( + array.from_delayed( + delayed_field_comp( + pointerFields, src_list, i, TIND_i, func + ), + dtype=np.float32, + shape=(pointerShape[0], pointerShape[1], 1), + ) ) - ) - out = array.dstack(arrays).compute() + if self.client: + arrays = self.client.gather(arrays) + out = np.dstack(arrays) + else: + out = array.dstack(arrays).compute() shape = self._correctShape(name, ind, deflate=True) return out.reshape(shape, order="F") -TimeFields._getField = _getField +# TimeFields._getField = _getField +TimeFields.client = client -@delayed def field_projection(field_array, src_list, array_ind, time_ind, func): fieldI = field_array[:, :, array_ind] if fieldI.shape[0] == fieldI.size: @@ -300,7 +370,6 @@ def field_projection(field_array, src_list, array_ind, time_ind, func): return new_array -@delayed def source_evaluation(simulation, sources, time_channel): s_m, s_e = [], [] for source in sources: @@ -311,7 +380,6 @@ def source_evaluation(simulation, sources, time_channel): return s_m, s_e -@delayed def evaluate_receivers(block, mesh, time_mesh, fields, fields_array): data = [] for _, ind, receiver in block: @@ -324,31 +392,61 @@ def evaluate_receivers(block, mesh, time_mesh, fields, fields_array): return np.hstack(data) -def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): +def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape): """ Compute the derivative of the fields """ delayed_chunks = [] + + if self.client: + mesh = self.client.scatter(self.mesh) + time_mesh = self.client.scatter(self.time_mesh) + fields = self.client.scatter(fields) + else: + mesh = self.mesh + time_mesh = self.time_mesh + delayed_block_deriv = delayed(block_deriv) + for chunks in blocks: if len(chunks) == 0: continue - delayed_block = delayed_block_deriv( - simulation.nT, - chunks, - fields_shape[0], - simulation.survey.source_list, - simulation.mesh, - simulation.time_mesh, - fields, - simulation.model.size, - ) - delayed_chunks.append(delayed_block) + if self.client: + delayed_chunks.append( + self.client.submit( + block_deriv, + self.nT, + chunks, + fields_shape[0], + self.survey.source_list, + mesh, + time_mesh, + fields, + self.model.size, + ) + ) + else: + delayed_chunks.append( + delayed_block_deriv( + self.nT, + chunks, + fields_shape[0], + self.survey.source_list, + self.mesh, + self.time_mesh, + fields, + self.model.size, + ) + ) + + if self.client: + result = self.client.gather(delayed_chunks) + else: + result = dask.compute(delayed_chunks)[0] - result = dask.compute(delayed_chunks)[0] df_duT = [ [[[] for _ in block] for block in blocks if len(block) > 0] - for _ in range(simulation.nT + 1) + for _ in range(self.nT + 1) ] j_updates = [] @@ -362,8 +460,8 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): if len(j_updates.data) > 0: Jmatrix += j_updates - if simulation.store_sensitivities == "disk": - sens_name = simulation.sensitivity_path[:-5] + f"_{time() % 2}.zarr" + if self.store_sensitivities == "disk": + sens_name = self.sensitivity_path[:-5] + f"_{time() % 2}.zarr" array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) Jmatrix = array.from_zarr(sens_name) @@ -385,7 +483,7 @@ def update_deriv_blocks(address, indices, derivatives, solve, shape): def get_field_deriv_block( - simulation, + self, block: list, field_derivs: list, tInd: int, @@ -401,9 +499,10 @@ def get_field_deriv_block( count = 0 Asubdiag = None - if tInd < simulation.nT - 1: - Asubdiag = simulation.getAsubdiag(tInd + 1) + if tInd < self.nT - 1: + Asubdiag = self.getAsubdiag(tInd + 1) + delayed_deriv = delayed(deriv_block) for ((s_id, r_id, b_id), (rx_ind, _, shape)), field_deriv in zip( block, field_derivs ): @@ -419,28 +518,49 @@ def get_field_deriv_block( local_ind, ) count += len(local_ind) - deriv_comp = deriv_block( - s_id, - r_id, - b_id, - ATinv_df_duT_v, - Asubdiag, - local_ind, - field_deriv, - tInd, - ) - stacked_blocks.append( - array.from_delayed( - deriv_comp, - dtype=float, - shape=( - field_deriv.shape[0], - len(local_ind), - ), + + if self.client: + stacked_blocks.append( + self.client.submit( + deriv_block, + s_id, + r_id, + b_id, + ATinv_df_duT_v, + Asubdiag, + local_ind, + field_deriv, + tInd, + ) + ) + else: + deriv_comp = delayed_deriv( + s_id, + r_id, + b_id, + ATinv_df_duT_v, + Asubdiag, + local_ind, + field_deriv, + tInd, + ) + stacked_blocks.append( + array.from_delayed( + deriv_comp, + dtype=float, + shape=( + field_deriv.shape[0], + len(local_ind), + ), + ) ) - ) if len(stacked_blocks) > 0: - blocks = array.hstack(stacked_blocks).compute() + + if self.client: + blocks = np.hstack(self.client.gather(stacked_blocks)) + else: + blocks = array.hstack(stacked_blocks).compute() + solve = (AdiagTinv * blocks).reshape(blocks.shape) else: solve = None @@ -456,8 +576,7 @@ def get_field_deriv_block( return ATinv_df_duT_v -@delayed -def delayed_block_deriv( +def block_deriv( n_times, chunks, field_len, source_list, mesh, time_mesh, fields, shape ): """Compute derivatives for sources and receivers in a block""" @@ -504,7 +623,6 @@ def delayed_block_deriv( return df_duT, j_updates -@delayed def deriv_block( s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, field_derivs, tInd ): @@ -521,7 +639,6 @@ def deriv_block( return stacked_block -@delayed def compute_rows( simulation, tInd, @@ -575,10 +692,11 @@ def compute_rows( return np.vstack(rows) +Sim.client = client Sim.fields = fields Sim.getJtJdiag = getJtJdiag Sim.getSourceTerm = getSourceTerm -Sim.dpred = dpred +# Sim.dpred = dpred Sim.compute_J = compute_J Sim.getJtJdiag = getJtJdiag Sim.Jvec = Jvec diff --git a/simpeg/dask/utils.py b/simpeg/dask/utils.py index d2bf546220..10188b7bc0 100644 --- a/simpeg/dask/utils.py +++ b/simpeg/dask/utils.py @@ -26,7 +26,9 @@ def compute_chunk_sizes(M, N, target_chunk_size): return rowChunk, colChunk -def get_parallel_blocks(source_list: list, data_block_size, optimize=True) -> list: +def get_parallel_blocks( + source_list: list, data_block_size, optimize=True, thread_count=64 +) -> list: """ Get the blocks of sources and receivers to be computed in parallel. @@ -50,7 +52,7 @@ def get_parallel_blocks(source_list: list, data_block_size, optimize=True) -> li chunk_size = len(chunk) # Condition to start a new block - if (row_count + chunk_size) > (data_block_size * cpu_count()): + if (row_count + chunk_size) > (data_block_size * thread_count): row_count = 0 block_count += 1 blocks.append([]) @@ -69,7 +71,7 @@ def get_parallel_blocks(source_list: list, data_block_size, optimize=True) -> li row_count += chunk_size # Re-split over cpu_count if too few blocks - if len(blocks) < cpu_count() and optimize: + if len(blocks) < thread_count and optimize: flatten_blocks = [] for block in blocks: flatten_blocks += block From 96c3a5fef149e26076151f139e88ac9a594bdf3f Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 12 Jan 2025 14:56:46 -0800 Subject: [PATCH 38/84] Update ip --- .../static/induced_polarization/simulation.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py index beee5f7338..4fbf0d46a5 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py @@ -76,18 +76,13 @@ def getJtJdiag(self, m, W=None): """ self.model = m if getattr(self, "_jtjdiag", None) is None: - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish if W is None: W = self._scale * np.ones(self.nD) else: W = (self._scale * W.diagonal()) ** 2.0 - diag = da.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) - - if isinstance(diag, da.Array): - diag = np.asarray(diag.compute()) + diag = np.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) self._jtjdiag = diag @@ -121,7 +116,7 @@ def Jtvec(self, m, v, f=None): if isinstance(self.Jmatrix, Future): self.Jmatrix # Wait to finish - return da.dot(v * self._scale, self.Jmatrix).astype(np.float32) + return da.dot((v * self._scale).astype(np.float32), self.Jmatrix).astype(np.float32) Sim.compute_J = compute_J From e882cf11c45b1afe2c7d0077ba595eb2cf049e74 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 13 Jan 2025 15:08:43 -0800 Subject: [PATCH 39/84] Remove parallel compute of RHS --- .../frequency_domain/simulation.py | 60 +++++++++++-------- .../static/induced_polarization/simulation.py | 2 +- simpeg/dask/objective_function.py | 30 +++++++--- 3 files changed, 57 insertions(+), 35 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 3e3f5c4393..8243560d25 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -5,7 +5,8 @@ from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix import numpy as np import scipy.sparse as sp -from multiprocessing import cpu_count + +# from multiprocessing import cpu_count from dask import array, compute, delayed from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary @@ -98,32 +99,39 @@ def getSourceTerm(self, freq, source=None): of the correct size """ if source is None: - source_list = self.survey.get_sources_by_frequency(freq) - source_block = np.array_split(source_list, cpu_count()) - - block_compute = [] - - if self.client: - sim = self.client.scatter(self) - - for block in source_block: - if len(block) == 0: - continue - - if self.client: - block_compute.append(self.client.submit(source_evaluation, sim, block)) - else: - block_compute.append(delayed(source_evaluation)(self, block)) - - if self.client: - blocks = self.client.gather(block_compute) - else: - blocks = compute(block_compute)[0] + # if self.client: + # n_splits = int(self.client.cluster.scheduler.total_nthreads / len(self.client.cluster.scheduler.workers)) + # else: + # n_splits = cpu_count() + # + # source_list = self.survey.get_sources_by_frequency(freq) + # source_block = np.array_split(source_list, n_splits) + # + # block_compute = [] + # + # if self.client: + # sim = self.client.scatter(self) + # source_block = self.client.scatter(source_block) + # + # for block in source_block: + # if self.client: + # block_compute.append(self.client.submit(source_evaluation, sim, block)) + # else: + # block_compute.append(delayed(source_evaluation)(self, block)) + # + # if self.client: + # blocks = self.client.gather(block_compute) + # else: + # blocks = compute(block_compute)[0] s_m, s_e = [], [] - for block in blocks: - if block[0]: - s_m += block[0] - s_e += block[1] + # for block in blocks: + # if block[0]: + for source in self.survey.get_sources_by_frequency(freq): + sm, se = source.eval(self) + s_m.append(sm) + s_e.append(se) + # s_m += block[0] + # s_e += block[1] else: sm, se = source.eval(self) diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py index 4fbf0d46a5..8768265021 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py @@ -70,7 +70,7 @@ def dpred(self, m=None, f=None): return np.asarray(data) -def getJtJdiag(self, m, W=None): +def getJtJdiag(self, m, W=None, f=None): """ Return the diagonal of JtJ """ diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index c07c4fa046..a52608a8f3 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -55,10 +55,12 @@ def __init__( objfcts: list[BaseObjectiveFunction], multipliers=None, client: Client | None = None, + workers: list[str] | None = None, **kwargs, ): self._model: np.ndarray | None = None self.client = client + self.workers = workers super().__init__(objfcts=objfcts, multipliers=multipliers, **kwargs) @@ -98,6 +100,20 @@ def client(self, client): self._client = client + @property + def workers(self): + """ + List of worker addresses + """ + return self._workers + + @workers.setter + def workers(self, workers): + if not isinstance(workers, list | type(None)): + raise TypeError("workers must be a list of strings") + + self._workers = workers + def deriv(self, m, f=None): """ First derivative of the composite objective function is the sum of the @@ -276,7 +292,12 @@ def objfcts(self, objfcts): client = self.client futures, workers = _validate_type_or_future_of_type( - "objfcts", objfcts, L2DataMisfit, client, return_workers=True + "objfcts", + objfcts, + L2DataMisfit, + client, + workers=self.workers, + return_workers=True, ) for objfct, future in zip(objfcts, futures): if hasattr(objfct, "name"): @@ -307,10 +328,3 @@ def residuals(self, m, f=None): ) ) return client.gather(residuals) - - @property - def workers(self): - """ - Get the list of dask.distributed.workers associated with the objective functions. - """ - return self._workers From e85eb830cc4419a215b47cc6c4b39bcffbcb99e5 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 13 Jan 2025 15:44:17 -0800 Subject: [PATCH 40/84] Remove inner parallelization if client --- .../frequency_domain/simulation.py | 35 +++++++------------ 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 8243560d25..c0f8a9043f 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -270,24 +270,19 @@ def compute_J(self, m, f=None): blocks = get_parallel_blocks( self.survey.source_list, compute_row_size, optimize=False ) + fields_array = f[:, self._solutionType] + blocks_receiver_derivs = [] if self.client: - fields_array = self.client.scatter(f[:, self._solutionType]) - fields = self.client.scatter(f) - survey = self.client.scatter(self.survey) - mesh = self.client.scatter(self.mesh) - blocks_receiver_derivs = self.client.map( - receiver_derivs, - [survey] * len(blocks), - [mesh] * len(blocks), - [fields] * len(blocks), - blocks, - ) + for block in blocks: + blocks_receiver_derivs.append( + receiver_derivs(self.survey, self.mesh, f, block) + ) else: fields_array = delayed(f[:, self._solutionType]) fields = delayed(f) survey = delayed(self.survey) mesh = delayed(self.mesh) - blocks_receiver_derivs = [] + delayed_derivs = delayed(receiver_derivs) for block in blocks: blocks_receiver_derivs.append( @@ -300,9 +295,7 @@ def compute_J(self, m, f=None): ) # Dask process for all derivatives - if self.client: - blocks_receiver_derivs = self.client.gather(blocks_receiver_derivs) - else: + if not self.client: blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] for block_derivs_chunks, addresses_chunks in zip(blocks_receiver_derivs, blocks): @@ -329,10 +322,7 @@ def parallel_block_compute( block_stack = sp.hstack(blocks_receiver_derivs).toarray() ATinvdf_duT = A_i * block_stack - if self.client: - ATinvdf_duT = self.client.scatter(ATinvdf_duT) - sim = self.client.scatter(self) - else: + if not self.client: ATinvdf_duT = delayed(ATinvdf_duT) count = 0 rows = [] @@ -344,9 +334,8 @@ def parallel_block_compute( if self.client: block_delayed.append( - self.client.submit( - eval_block, - sim, + eval_block( + self, ATinvdf_duT, np.arange(count, count + n_cols), Zero(), @@ -376,7 +365,7 @@ def parallel_block_compute( indices = np.hstack(rows) if self.client: - block = np.vstack(self.client.gather(block_delayed)) + block = np.vstack(block_delayed) else: block = compute(array.vstack(block_delayed))[0] From 6adc81f771cff1731078f00e5d068d5d7929384d Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 14 Jan 2025 09:53:09 -0800 Subject: [PATCH 41/84] Never store Ainv on FEM --- .../frequency_domain/simulation.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index c0f8a9043f..df53f104e8 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -234,7 +234,7 @@ def fields(self, m=None): f[sources, self._solutionType] = u Ainv[freq] = Ainv_solve - self.Ainv = Ainv + # self.Ainv = Ainv self._stashed_fields = f @@ -247,13 +247,18 @@ def compute_J(self, m, f=None): if f is None: f = self.fields(m) - if len(self.Ainv) > 1: + Ainv = {} + for freq in self.survey.frequencies: + A = self.getA(freq) + Ainv[freq] = self.solver(sp.csr_matrix(A), **self.solver_opts) + + if len(Ainv) > 1: raise NotImplementedError( "Current implementation of parallelization assumes a single frequency per simulation. " "Consider creating one misfit per frequency." ) - A_i = list(self.Ainv.values())[0] + A_i = list(Ainv.values())[0] m_size = m.size if self.store_sensitivities == "disk": @@ -303,10 +308,10 @@ def compute_J(self, m, f=None): m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks ) - for A in self.Ainv.values(): + for A in Ainv.values(): A.clean() - del self.Ainv + del Ainv gc.collect() if self.store_sensitivities == "disk": del Jmatrix From 3e3e8ad7018380fe65124cccec79c3923a706821 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 07:46:54 -0800 Subject: [PATCH 42/84] Don't pass fields and use internal --- .../frequency_domain/simulation.py | 15 ++---- simpeg/dask/inverse_problem.py | 31 +++++++++++ simpeg/dask/objective_function.py | 54 +++++++++---------- .../potential_fields/gravity/simulation.py | 12 ++++- .../potential_fields/magnetics/simulation.py | 12 ++++- simpeg/meta/dask_sim.py | 10 ++-- 6 files changed, 90 insertions(+), 44 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index df53f104e8..c0f8a9043f 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -234,7 +234,7 @@ def fields(self, m=None): f[sources, self._solutionType] = u Ainv[freq] = Ainv_solve - # self.Ainv = Ainv + self.Ainv = Ainv self._stashed_fields = f @@ -247,18 +247,13 @@ def compute_J(self, m, f=None): if f is None: f = self.fields(m) - Ainv = {} - for freq in self.survey.frequencies: - A = self.getA(freq) - Ainv[freq] = self.solver(sp.csr_matrix(A), **self.solver_opts) - - if len(Ainv) > 1: + if len(self.Ainv) > 1: raise NotImplementedError( "Current implementation of parallelization assumes a single frequency per simulation. " "Consider creating one misfit per frequency." ) - A_i = list(Ainv.values())[0] + A_i = list(self.Ainv.values())[0] m_size = m.size if self.store_sensitivities == "disk": @@ -308,10 +303,10 @@ def compute_J(self, m, f=None): m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks ) - for A in Ainv.values(): + for A in self.Ainv.values(): A.clean() - del Ainv + del self.Ainv gc.collect() if self.store_sensitivities == "disk": del Jmatrix diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 543d409e8c..80013e3b07 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -5,6 +5,8 @@ from scipy.sparse.linalg import LinearOperator from ..regularization import WeightedLeastSquares, Sparse from ..objective_function import ComboObjectiveFunction +from simpeg.utils import call_hooks +from simpeg.version import __version__ as simpeg_version def get_dpred(self, m, f=None): @@ -112,3 +114,32 @@ def H_fun(v): BaseInvProblem.evalFunction = dask_evalFunction + + +@call_hooks("startup") +def startup(self, m0): + """startup(m0) + + Called when inversion is first starting. + """ + if self.debug: + print("Calling InvProblem.startup") + + if self.print_version: + print(f"\nRunning inversion with SimPEG v{simpeg_version}") + + for fct in self.reg.objfcts: + if ( + hasattr(fct, "reference_model") + and getattr(fct, "reference_model", None) is None + ): + print("simpeg.InvProblem will set Regularization.reference_model to m0.") + fct.reference_model = m0 + + self.phi_d = np.nan + self.phi_m = np.nan + + self.model = m0 + + +BaseInvProblem.startup = startup diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index a52608a8f3..6fb263b299 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -22,26 +22,26 @@ def _calc_residual(objfct, model, field): ) -def _deriv(objfct, multiplier, model, fields): - if fields is not None and objfct.has_fields: - return multiplier * objfct.deriv(objfct.simulation.model, f=fields) - else: - return multiplier * objfct.deriv(objfct.simulation.model) +def _deriv(objfct, multiplier, model): + # if fields is not None and objfct.has_fields: + # return multiplier * objfct.deriv(objfct.simulation.model) + # else: + return multiplier * objfct.deriv(objfct.simulation.model) -def _deriv2(objfct, multiplier, model, v, fields): - if fields is not None and objfct.has_fields: - return multiplier * objfct.deriv2(objfct.simulation.model, v, f=fields) - else: - return multiplier * objfct.deriv2(objfct.simulation.model, v) +def _deriv2(objfct, multiplier, model, v): + # if fields is not None and objfct.has_fields: + # return multiplier * objfct.deriv2(objfct.simulation.model, v) + # else: + return multiplier * objfct.deriv2(objfct.simulation.model, v) def _store_model(objfct, model): objfct.simulation.model = model -def _get_jtj_diag(objfct, model, field): - jtj = objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W, f=field) +def _get_jtj_diag(objfct, model): + jtj = objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W) return jtj.flatten() @@ -127,20 +127,18 @@ def deriv(self, m, f=None): client = self.client m_future = self._m_as_future - if f is None: - f = self.fields(m) + # if f is None: + # f = self.fields(m) derivs = [] - for multiplier, objfct, field, worker in zip( - self.multipliers, self._futures, f, self._workers + for multiplier, objfct, worker in zip( + self.multipliers, self._futures, self._workers ): if multiplier == 0.0: # don't evaluate the fct continue derivs.append( - client.submit( - _deriv, objfct, multiplier, m_future, field, workers=worker - ) + client.submit(_deriv, objfct, multiplier, m_future, workers=worker) ) derivs = _reduce(client, add, derivs) return derivs @@ -160,12 +158,12 @@ def deriv2(self, m, v=None, f=None): m_future = self._m_as_future [v_future] = client.scatter([v], broadcast=True) - if f is None: - f = self.fields(m) + # if f is None: + # f = self.fields(m) derivs = [] - for multiplier, objfct, field, worker in zip( - self.multipliers, self._futures, f, self._workers + for multiplier, objfct, worker in zip( + self.multipliers, self._futures, self._workers ): if multiplier == 0.0: # don't evaluate the fct continue @@ -177,7 +175,7 @@ def deriv2(self, m, v=None, f=None): multiplier, m_future, v_future, - field, + # field, workers=worker, ) ) @@ -214,15 +212,15 @@ def getJtJdiag(self, m, f=None): jtj_diag = [] client = self.client - if f is None: - f = self.fields(m) - for objfct, worker, field in zip(self._futures, self._workers, f): + # if f is None: + # f = self.fields(m) + for objfct, worker in zip(self._futures, self._workers): jtj_diag.append( client.submit( _get_jtj_diag, objfct, m_future, - field, + # field, workers=worker, ) ) diff --git a/simpeg/dask/potential_fields/gravity/simulation.py b/simpeg/dask/potential_fields/gravity/simulation.py index 99fbea1c7b..4c8d39271c 100644 --- a/simpeg/dask/potential_fields/gravity/simulation.py +++ b/simpeg/dask/potential_fields/gravity/simulation.py @@ -1,8 +1,18 @@ from ....potential_fields.gravity import Simulation3DIntegral as Sim -from ..base import G from ...simulation import getJtJdiag +@property +def G(self): + """ + Gravity forward operator + """ + if getattr(self, "_G", None) is None: + self._G = self.Jmatrix + + return self._G + + Sim.clean_on_model_update = [] Sim.getJtJdiag = getJtJdiag Sim.G = G diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index a39ef9ef9d..cf3303215b 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -1,8 +1,18 @@ from ....potential_fields.magnetics import Simulation3DIntegral as Sim -from ..base import G from ...simulation import getJtJdiag +@property +def G(self): + """ + Gravity forward operator + """ + if getattr(self, "_G", None) is None: + self._G = self.Jmatrix + + return self._G + + Sim.clean_on_model_update = [] Sim.getJtJdiag = getJtJdiag Sim.G = G diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index 79a494f158..53389e43a2 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -70,9 +70,9 @@ def _reduce(client, operation, items): return client.gather(items[0]) -def _set_worker(obj, worker): - obj.worker = worker - return obj +# def _set_worker(obj, worker): +# obj.worker = worker +# return obj def _validate_type_or_future_of_type( @@ -85,9 +85,11 @@ def _validate_type_or_future_of_type( ): try: # validate as a list of things that need to be sent. + # workers = [(worker.worker_address,) for worker in client.cluster.workers.values()] objects = validate_list_of_types( property_name, objects, obj_type, ensure_unique=True ) + # objects[0].simulation.simulations[0].worker = workers[0] if workers is None: objects = client.scatter(objects) else: @@ -117,7 +119,7 @@ def _validate_type_or_future_of_type( warnings.warn( f"{property_name} {i} is not on the expected worker.", stacklevel=2 ) - obj = client.submit(_set_worker, obj, worker) + # obj = client.submit(_set_worker, obj, worker) # Ensure this runs on the expected worker futures = [] From 8a8c968acd76d3f9dd6a5d8807c9b409f36bd8d2 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 10:38:11 -0800 Subject: [PATCH 43/84] Use client inside base potential fields integral --- simpeg/dask/potential_fields/base.py | 63 +++++++++++++++++----------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 9717415615..e2767ce81a 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -1,4 +1,5 @@ import numpy as np + from ...potential_fields.base import BasePFSimulation as Sim import os @@ -34,7 +35,6 @@ def residual(self, m, dobs, f=None): return self.dpred(m, f=f) - dobs -@delayed def block_compute(sim, rows, components): block = [] for row in rows: @@ -58,18 +58,43 @@ def linear_operator(self): / self.max_chunk_size ) block_split = np.array_split(self.survey.receiver_locations, n_blocks) - rows = [ - array.from_delayed( - block_compute(self, block, self.survey.components), - dtype=self.sensitivity_dtype, - shape=( - (len(block) * n_components,) - if forward_only - else (len(block) * n_components, n_cells) - ), - ) - for block in block_split - ] + + if self.client: + sim = self.client.scatter(self, workers=self.worker) + else: + delayed_compute = delayed(block_compute) + + rows = [] + for block in block_split: + if self.client: + rows.append( + self.client.submit( + block_compute, + sim, + block, + self.survey.components, + workers=self.worker, + ) + ) + else: + chunk = delayed_compute(self, block, self.survey.components) + rows.append( + array.from_delayed( + chunk, + dtype=self.sensitivity_dtype, + shape=( + (len(block) * n_components,) + if forward_only + else (len(block) * n_components, n_cells) + ), + ) + ) + + if self.client: + if forward_only: + return np.hstack(self.client.gather(rows)) + return np.vstack(self.client.gather(rows)) + if forward_only: stack = array.concatenate(rows) else: @@ -113,17 +138,6 @@ def linear_operator(self): return stack.compute() -@property -def G(self): - """ - Gravity forward operator - """ - if getattr(self, "_G", None) is None: - self._G = self.Jmatrix - - return self._G - - def compute_J(self, _, f=None): return self.linear_operator() @@ -141,7 +155,6 @@ def Jmatrix(self, value): Sim.clean_on_model_update = [] -Sim.G = G Sim._chunk_format = _chunk_format Sim.chunk_format = chunk_format Sim.dpred = dpred From 3e1ec4a5e81a7be2ae0ed7e4bcfd0abf91305314 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 10:38:48 -0800 Subject: [PATCH 44/84] Run futures in blocks --- simpeg/dask/objective_function.py | 245 +++++++++++++++++++++--------- 1 file changed, 170 insertions(+), 75 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 6fb263b299..581ec48c6c 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -3,8 +3,8 @@ import numpy as np from dask.distributed import Client from ..data_misfit import L2DataMisfit -from simpeg.meta.dask_sim import _validate_type_or_future_of_type, _reduce - +from simpeg.meta.dask_sim import _reduce +from simpeg.utils import validate_list_of_types from operator import add @@ -45,6 +45,83 @@ def _get_jtj_diag(objfct, model): return jtj.flatten() +def _validate_type_or_future_of_type( + property_name, + objects, + obj_type, + client, + workers=None, + return_workers=False, +): + # try: + # # validate as a list of things that need to be sent. + workers = [(worker.worker_address,) for worker in client.cluster.workers.values()] + objects = validate_list_of_types( + property_name, objects, obj_type, ensure_unique=True + ) + workload = [[]] + count = 0 + for obj in objects: + if count == len(workers): + count = 0 + workload.append([]) + obj.simulation.simulations[0].worker = workers[count] + future = client.scatter([obj], workers=workers[count])[0] + + if hasattr(obj, "name"): + future.name = obj.name + + workload[-1].append(future) + count += 1 + + # objects[0].simulation.simulations[0].worker = workers[0] + # if workers is None: + # objects = client.scatter(objects) + # else: + # tmp = [] + # for obj, worker in zip(objects, workers): + # tmp.append(client.scatter([obj], workers=worker)[0]) + # objects = tmp + # except TypeError: + # pass + # ensure list of futures + # objects = validate_list_of_types( + # property_name, + # objects, + # Future, + # ) + # Figure out where everything lives + + # who = client.who_has(workload) + # # if workers is None: + # # workers = [] + # for ii, worker in enumerate(who.values()): + # if worker != workers[ii % len(workers)]: + # warnings.warn( + # f"{property_name} {i} is not on the expected worker.", stacklevel=2 + # ) + # # obj = client.submit(_set_worker, obj, worker) + + # Ensure this runs on the expected worker + futures = [] + for work in workload: + + for obj, worker in zip(work, workers): + futures.append( + client.submit( + lambda v: not isinstance(v, obj_type), obj, workers=worker + ) + ) + is_not_obj = np.array(client.gather(futures)) + if np.any(is_not_obj): + raise TypeError(f"{property_name} futures must be an instance of {obj_type}") + + if return_workers: + return workload, workers + else: + return workload + + class DaskComboMisfits(ComboObjectiveFunction): """ A composite objective function for distributed computing. @@ -131,15 +208,23 @@ def deriv(self, m, f=None): # f = self.fields(m) derivs = [] - for multiplier, objfct, worker in zip( - self.multipliers, self._futures, self._workers - ): - if multiplier == 0.0: # don't evaluate the fct - continue + count = 0 + for futures in self._futures: + for objfct, worker in zip(futures, self._workers): + if self.multipliers[count] == 0.0: # don't evaluate the fct + continue + + derivs.append( + client.submit( + _deriv, + objfct, + self.multipliers[count], + m_future, + workers=worker, + ) + ) + count += 1 - derivs.append( - client.submit(_deriv, objfct, multiplier, m_future, workers=worker) - ) derivs = _reduce(client, add, derivs) return derivs @@ -162,23 +247,24 @@ def deriv2(self, m, v=None, f=None): # f = self.fields(m) derivs = [] - for multiplier, objfct, worker in zip( - self.multipliers, self._futures, self._workers - ): - if multiplier == 0.0: # don't evaluate the fct - continue + count = 0 + for futures in self._futures: + for objfct, worker in zip(futures, self._workers): + if self.multipliers[count] == 0.0: # don't evaluate the fct + continue - derivs.append( - client.submit( - _deriv2, - objfct, - multiplier, - m_future, - v_future, - # field, - workers=worker, + derivs.append( + client.submit( + _deriv2, + objfct, + self.multipliers[count], + m_future, + v_future, + # field, + workers=worker, + ) ) - ) + count += 1 derivs = _reduce(client, add, derivs) @@ -193,16 +279,17 @@ def get_dpred(self, m, f=None): client = self.client m_future = self._m_as_future dpred = [] - for objfct, worker, field in zip(self._futures, self._workers, f): - dpred.append( - client.submit( - _calc_dpred, - objfct, - m_future, - field, - workers=worker, + for futures, fields in zip(self._futures, f): + for objfct, worker, field in zip(futures, self._workers, fields): + dpred.append( + client.submit( + _calc_dpred, + objfct, + m_future, + field, + workers=worker, + ) ) - ) return client.gather(dpred) def getJtJdiag(self, m, f=None): @@ -210,21 +297,25 @@ def getJtJdiag(self, m, f=None): m_future = self._m_as_future if getattr(self, "_jtjdiag", None) is None: - jtj_diag = [] + jtj_diag = 0.0 client = self.client # if f is None: # f = self.fields(m) - for objfct, worker in zip(self._futures, self._workers): - jtj_diag.append( - client.submit( - _get_jtj_diag, - objfct, - m_future, - # field, - workers=worker, + for futures in self._futures: + work = [] + for objfct, worker in zip(futures, self._workers): + work.append( + client.submit( + _get_jtj_diag, + objfct, + m_future, + # field, + workers=worker, + ) ) - ) - self._jtjdiag = _reduce(client, add, jtj_diag) + jtj_diag += _reduce(client, add, work) + + self._jtjdiag = jtj_diag return self._jtjdiag @@ -236,15 +327,17 @@ def fields(self, m): return self._stashed_fields # The above should pass the model to all the internal simulations. f = [] - for objfct, worker in zip(self._futures, self._workers): - f.append( - client.submit( - _calc_fields, - objfct, - m_future, - workers=worker, + for futures in self._futures: + f.append([]) + for objfct, worker in zip(futures, self._workers): + f[-1].append( + client.submit( + _calc_fields, + objfct, + m_future, + workers=worker, + ) ) - ) self._stashed_fields = f return f @@ -268,17 +361,18 @@ def model(self, value): client = self.client [self._m_as_future] = client.scatter([value], broadcast=True) - futures = [] - for objfct, worker in zip(self._futures, self._workers): - futures.append( - client.submit( - _store_model, - objfct, - self._m_as_future, - workers=worker, + stores = [] + for futures in self._futures: + for objfct, worker in zip(futures, self._workers): + stores.append( + client.submit( + _store_model, + objfct, + self._m_as_future, + workers=worker, + ) ) - ) - self.client.gather(futures) # blocking call to ensure all models were stored + self.client.gather(stores) # blocking call to ensure all models were stored self._model = value @property @@ -297,9 +391,9 @@ def objfcts(self, objfcts): workers=self.workers, return_workers=True, ) - for objfct, future in zip(objfcts, futures): - if hasattr(objfct, "name"): - future.name = objfct.name + # for objfct, future in zip(objfcts, futures): + # if hasattr(objfct, "name"): + # future.name = objfct.name self._objfcts = objfcts self._futures = futures @@ -315,14 +409,15 @@ def residuals(self, m, f=None): client = self.client m_future = self._m_as_future residuals = [] - for objfct, worker, field in zip(self._futures, self._workers, f): - residuals.append( - client.submit( - _calc_residual, - objfct, - m_future, - field, - workers=worker, + for futures, fields in zip(self._futures, f): + for objfct, worker, field in zip(futures, self._workers, fields): + residuals.append( + client.submit( + _calc_residual, + objfct, + m_future, + field, + workers=worker, + ) ) - ) return client.gather(residuals) From bab728f84547f6749a077fefb1ad63cf673b8b27 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 13:45:51 -0800 Subject: [PATCH 45/84] Bring back internal parallelization of FEM --- .../frequency_domain/simulation.py | 160 ++++++++++-------- simpeg/dask/simulation.py | 17 ++ 2 files changed, 109 insertions(+), 68 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index c0f8a9043f..42364794d5 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -6,14 +6,13 @@ import numpy as np import scipy.sparse as sp -# from multiprocessing import cpu_count from dask import array, compute, delayed from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary import zarr -def evaluate_receivers(block, mesh, fields): +def receivers_eval(block, mesh, fields): data = [] for source, _, receiver in block: data.append(receiver.eval(source, mesh, fields).flatten()) @@ -21,7 +20,7 @@ def evaluate_receivers(block, mesh, fields): return np.hstack(data) -def source_evaluation(simulation, sources): +def source_eval(simulation, sources): s_m, s_e = [], [] for source in sources: sm, se = source.eval(simulation) @@ -99,39 +98,39 @@ def getSourceTerm(self, freq, source=None): of the correct size """ if source is None: - # if self.client: - # n_splits = int(self.client.cluster.scheduler.total_nthreads / len(self.client.cluster.scheduler.workers)) - # else: - # n_splits = cpu_count() - # - # source_list = self.survey.get_sources_by_frequency(freq) - # source_block = np.array_split(source_list, n_splits) - # - # block_compute = [] - # - # if self.client: - # sim = self.client.scatter(self) - # source_block = self.client.scatter(source_block) - # - # for block in source_block: - # if self.client: - # block_compute.append(self.client.submit(source_evaluation, sim, block)) - # else: - # block_compute.append(delayed(source_evaluation)(self, block)) - # - # if self.client: - # blocks = self.client.gather(block_compute) - # else: - # blocks = compute(block_compute)[0] + + source_list = self.survey.get_sources_by_frequency(freq) + source_block = np.array_split(source_list, self.n_threads) + + block_compute = [] + + if self.client: + sim = self.client.scatter(self, workers=self.worker) + source_block = self.client.scatter(source_block, workers=self.worker) + else: + delayed_source_eval = delayed(source_eval) + + for block in source_block: + if self.client: + block_compute.append( + self.client.submit(source_eval, sim, block, workers=self.worker) + ) + else: + block_compute.append(delayed_source_eval(self, block)) + + if self.client: + blocks = self.client.gather(block_compute) + else: + blocks = compute(block_compute)[0] s_m, s_e = [], [] - # for block in blocks: - # if block[0]: - for source in self.survey.get_sources_by_frequency(freq): - sm, se = source.eval(self) - s_m.append(sm) - s_e.append(se) - # s_m += block[0] - # s_e += block[1] + for block in blocks: + if block[0]: + # for source in self.survey.get_sources_by_frequency(freq): + # sm, se = source.eval(self) + # s_m.append(sm) + # s_e.append(se) + s_m += block[0] + s_e += block[1] else: sm, se = source.eval(self) @@ -153,7 +152,7 @@ def getSourceTerm(self, freq, source=None): return s_m, s_e -def dpred(self, m=None, f=None): +def dpred(self, m=None, f=None, compute_J=False): r""" dpred(m, f=None) Create the projected data from a model. @@ -178,42 +177,46 @@ def dpred(self, m=None, f=None): m = self.model f = self.fields(m) - if self.client: - f = self.client.scatter(f) - mesh = self.client.scatter(self.mesh) - else: - mesh = delayed(self.mesh) - delayed_block_eval = delayed(evaluate_receivers) + all_receivers = [] - rows = [] for ind, src in enumerate(self.survey.source_list): for rx in src.receiver_list: - block = [(src, ind, rx)] + all_receivers.append((src, ind, rx)) - # receiver_blocks = np.array_split(np.asarray(all_receivers), cpu_count()) - # rows = [] - # for block in receiver_blocks: - # n_data = np.sum([rec.nD for _, _, rec in block]) - if rx.nD == 0: - continue + receiver_blocks = np.array_split(np.asarray(all_receivers), self.n_threads) + rows = [] - if self.client: - rows.append(self.client.submit(evaluate_receivers, block, mesh, f)) - else: - rows.append( - array.from_delayed( - delayed_block_eval(block, mesh, f), - dtype=np.float64, - shape=(rx.nD,), - ) + if self.client: + f = self.client.scatter(f, workers=self.worker) + mesh = self.client.scatter(self.mesh, workers=self.worker) + else: + delayed_receivers_eval = delayed(receivers_eval) + mesh = delayed(self.mesh) + + for block in receiver_blocks: + n_data = np.sum([rec.nD for _, _, rec in block]) + if n_data == 0: + continue + + if self.client: + rows.append( + self.client.submit(receivers_eval, block, mesh, f, workers=self.worker) + ) + else: + rows.append( + array.from_delayed( + delayed_receivers_eval(block, mesh, f), + dtype=np.float64, + shape=(n_data,), ) + ) if self.client: - data = np.hstack(self.client.gather(rows)) + rows = np.hstack(self.client.gather(rows)) else: - data = compute(array.hstack(rows))[0] + rows = compute(array.hstack(rows))[0] - return data + return rows def fields(self, m=None): @@ -273,9 +276,22 @@ def compute_J(self, m, f=None): fields_array = f[:, self._solutionType] blocks_receiver_derivs = [] if self.client: + fields_array = self.client.scatter( + f[:, self._solutionType], workers=self.worker + ) + fields = self.client.scatter(f, workers=self.worker) + survey = self.client.scatter(self.survey, workers=self.worker) + mesh = self.client.scatter(self.mesh, workers=self.worker) for block in blocks: blocks_receiver_derivs.append( - receiver_derivs(self.survey, self.mesh, f, block) + self.client.submit( + receiver_derivs, + survey, + mesh, + fields, + block, + workers=self.worker, + ) ) else: fields_array = delayed(f[:, self._solutionType]) @@ -295,7 +311,9 @@ def compute_J(self, m, f=None): ) # Dask process for all derivatives - if not self.client: + if self.client: + blocks_receiver_derivs = self.client.gather(blocks_receiver_derivs) + else: blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] for block_derivs_chunks, addresses_chunks in zip(blocks_receiver_derivs, blocks): @@ -322,7 +340,9 @@ def parallel_block_compute( block_stack = sp.hstack(blocks_receiver_derivs).toarray() ATinvdf_duT = A_i * block_stack - if not self.client: + if self.client: + ATinvdf_duT = self.client.scatter(ATinvdf_duT, workers=self.worker) + else: ATinvdf_duT = delayed(ATinvdf_duT) count = 0 rows = [] @@ -333,14 +353,17 @@ def parallel_block_compute( n_rows = address[1][2] if self.client: + sim = self.client.scatter(self, workers=self.worker) block_delayed.append( - eval_block( - self, + self.client.submit( + eval_block, + sim, ATinvdf_duT, np.arange(count, count + n_cols), Zero(), fields_array, address, + workers=self.worker, ) ) else: @@ -365,6 +388,7 @@ def parallel_block_compute( indices = np.hstack(rows) if self.client: + block_delayed = self.client.gather(block_delayed) block = np.vstack(block_delayed) else: block = compute(array.vstack(block_delayed))[0] @@ -388,5 +412,5 @@ def parallel_block_compute( Sim.Jtvec = Jtvec Sim.Jmatrix = Jmatrix Sim.fields = fields -# Sim.dpred = dpred +Sim.dpred = dpred Sim.getSourceTerm = getSourceTerm diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index ca915ed62e..7c4b2dbaab 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -120,3 +120,20 @@ def Jmatrix(self): Sim.Jmatrix = Jmatrix + + +@property +def n_threads(self): + """ + Number of threads used by Dask + """ + if getattr(self, "_n_threads", None) is None: + if self.client: + self._n_threads = self.client.nthreads()[self.worker[0]] + else: + self._n_threads = cpu_count() + + return self._n_threads + + +Sim.n_threads = n_threads From 0b072beb114bca72dcf69a5646253fb82451ec2a Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 14:28:09 -0800 Subject: [PATCH 46/84] Change reduce to plain gather and sum --- simpeg/dask/objective_function.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 581ec48c6c..3a7a3f2717 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -3,9 +3,8 @@ import numpy as np from dask.distributed import Client from ..data_misfit import L2DataMisfit -from simpeg.meta.dask_sim import _reduce + from simpeg.utils import validate_list_of_types -from operator import add def _calc_fields(objfct, model): @@ -161,7 +160,8 @@ def __call__(self, m, f=None): ) ) - return _reduce(client, add, values) + values = self.client.gather(values) + return np.sum(values) @property def client(self): @@ -225,8 +225,8 @@ def deriv(self, m, f=None): ) count += 1 - derivs = _reduce(client, add, derivs) - return derivs + derivs = self.client.gather(derivs) + return np.sum(derivs, axis=0) def deriv2(self, m, v=None, f=None): """ @@ -266,7 +266,8 @@ def deriv2(self, m, v=None, f=None): ) count += 1 - derivs = _reduce(client, add, derivs) + derivs = self.client.gather(derivs) + derivs = np.sum(derivs, axis=0) return derivs @@ -313,7 +314,8 @@ def getJtJdiag(self, m, f=None): workers=worker, ) ) - jtj_diag += _reduce(client, add, work) + work = client.gather(work) + jtj_diag += np.sum(work, axis=0) self._jtjdiag = jtj_diag From 936f9af13d641abc31a0d7771065f49705dd9ee5 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 14:32:03 -0800 Subject: [PATCH 47/84] Add temporary debug prints (to be reverted) --- simpeg/dask/inverse_problem.py | 3 +++ simpeg/directives/directives.py | 1 + 2 files changed, 4 insertions(+) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 80013e3b07..a43ee6127a 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -29,9 +29,11 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): """evalFunction(m, return_g=True, return_H=True)""" self.model = m + print("Computing dpred") self.dpred = self.get_dpred(m) residuals = [] + print("Computing residuals") if isinstance(self.dmisfit, DaskComboMisfits): residuals = self.dmisfit.residuals(m) else: @@ -93,6 +95,7 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): out = (phi,) if return_g: + print("Computing gradient") phi_dDeriv = self.dmisfit.deriv(m) phi_mDeriv = self.reg.deriv(m) diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 1873eaea88..358597f135 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -360,6 +360,7 @@ def call(self, ruleType): '", "'.join(directives) ) for r in self.dList: + print(f"Running directive {r}") getattr(r, ruleType)() def validate(self): From b501e16a28371700fd1e035859d705f61f42b675 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 15:13:42 -0800 Subject: [PATCH 48/84] Remove fields calls --- simpeg/dask/objective_function.py | 52 ++++++++++++++----------------- 1 file changed, 23 insertions(+), 29 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 3a7a3f2717..b3fcca7df9 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -11,13 +11,13 @@ def _calc_fields(objfct, model): return objfct.simulation.fields(m=objfct.simulation.model) -def _calc_dpred(objfct, model, field): - return objfct.simulation.dpred(m=objfct.simulation.model, f=field) +def _calc_dpred(objfct, model): + return objfct.simulation.dpred(m=objfct.simulation.model) -def _calc_residual(objfct, model, field): +def _calc_residual(objfct, model): return objfct.W * ( - objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model, f=field) + objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model) ) @@ -145,20 +145,23 @@ def __call__(self, m, f=None): client = self.client m_future = self._m_as_future - if f is None: - f = self.fields(m) - values = [] - for phi, field, worker in zip(self, f, self._workers): - multiplier, objfct = phi - if multiplier == 0.0: # don't evaluate the fct - continue + count = 0 + for futures in self._futures: + for objfct, worker in zip(futures, self._workers): - values.append( - client.submit( - _calc_objective, objfct, multiplier, m_future, field, workers=worker + if self.multipliers[count] == 0.0: + continue + + values.append( + client.submit( + _calc_objective, + objfct, + self.multipliers[count], + m_future, + workers=worker, + ) ) - ) values = self.client.gather(values) return np.sum(values) @@ -274,20 +277,16 @@ def deriv2(self, m, v=None, f=None): def get_dpred(self, m, f=None): self.model = m - if f is None: - f = self.fields(m) - client = self.client m_future = self._m_as_future dpred = [] - for futures, fields in zip(self._futures, f): - for objfct, worker, field in zip(futures, self._workers, fields): + for futures in self._futures: + for objfct, worker in zip(futures, self._workers): dpred.append( client.submit( _calc_dpred, objfct, m_future, - field, workers=worker, ) ) @@ -393,9 +392,6 @@ def objfcts(self, objfcts): workers=self.workers, return_workers=True, ) - # for objfct, future in zip(objfcts, futures): - # if hasattr(objfct, "name"): - # future.name = objfct.name self._objfcts = objfcts self._futures = futures @@ -406,19 +402,17 @@ def residuals(self, m, f=None): Compute the residual for the data misfit. """ self.model = m - if f is None: - f = self.fields(m) + client = self.client m_future = self._m_as_future residuals = [] - for futures, fields in zip(self._futures, f): - for objfct, worker, field in zip(futures, self._workers, fields): + for futures in self._futures: + for objfct, worker in zip(futures, self._workers): residuals.append( client.submit( _calc_residual, objfct, m_future, - field, workers=worker, ) ) From 16d0f4ceae53d6ce6b39c83144211bc2941326a9 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 16:15:13 -0800 Subject: [PATCH 49/84] More prints around save geoh5 --- simpeg/directives/directives.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 358597f135..cdc3c55d34 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -3176,6 +3176,7 @@ def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noq # Save results with fetch_active_workspace(self._geoh5, mode="r+") as w_s: h5_object = w_s.get_entity(self.h5_object)[0] + print("Saving to file") for cc, component in enumerate(self.components): if component not in self.data_type: self.data_type[component] = {} @@ -3267,6 +3268,7 @@ def get_values(self, values: list[np.ndarray] | None): else: dpred = getattr(self.invProb, "dpred", None) if dpred is None: + print("Computing dpred") dpred = self.invProb.get_dpred(self.invProb.model) self.invProb.dpred = dpred From a2727e14fa7f7611f1445167c4c380521e4664c1 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 16:49:19 -0800 Subject: [PATCH 50/84] Skip dpred for residuals --- simpeg/dask/inverse_problem.py | 2 +- simpeg/dask/objective_function.py | 15 ++++++--------- simpeg/directives/directives.py | 2 +- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index a43ee6127a..96b511c335 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -35,7 +35,7 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): residuals = [] print("Computing residuals") if isinstance(self.dmisfit, DaskComboMisfits): - residuals = self.dmisfit.residuals(m) + residuals = self.dmisfit.residuals(m, self.dpred) else: for (_, objfct), pred in zip(self.dmisfit, self.dpred): residuals.append(objfct.W * (objfct.data.dobs - pred)) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index b3fcca7df9..ade93538d9 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -15,10 +15,8 @@ def _calc_dpred(objfct, model): return objfct.simulation.dpred(m=objfct.simulation.model) -def _calc_residual(objfct, model): - return objfct.W * ( - objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model) - ) +def _calc_residual(objfct, dpred): + return objfct.W * (objfct.data.dobs - dpred) def _deriv(objfct, multiplier, model): @@ -280,6 +278,7 @@ def get_dpred(self, m, f=None): client = self.client m_future = self._m_as_future dpred = [] + print("in dpred") for futures in self._futures: for objfct, worker in zip(futures, self._workers): dpred.append( @@ -397,22 +396,20 @@ def objfcts(self, objfcts): self._futures = futures self._workers = workers - def residuals(self, m, f=None): + def residuals(self, m, dpreds, f=None): """ Compute the residual for the data misfit. """ self.model = m - client = self.client - m_future = self._m_as_future residuals = [] for futures in self._futures: - for objfct, worker in zip(futures, self._workers): + for objfct, worker, dpred in zip(futures, self._workers, dpreds): residuals.append( client.submit( _calc_residual, objfct, - m_future, + dpred, workers=worker, ) ) diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index cdc3c55d34..2dbfb747a1 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -3271,7 +3271,7 @@ def get_values(self, values: list[np.ndarray] | None): print("Computing dpred") dpred = self.invProb.get_dpred(self.invProb.model) self.invProb.dpred = dpred - + print("Done") if self.joint_index is not None: dpred = [dpred[ind] for ind in self.joint_index] From 580bda9021f69654026ff3301578857f38b571bf Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 17:15:01 -0800 Subject: [PATCH 51/84] Unhook special dpred --- simpeg/dask/electromagnetics/frequency_domain/simulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 42364794d5..55bb4de8f4 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -412,5 +412,5 @@ def parallel_block_compute( Sim.Jtvec = Jtvec Sim.Jmatrix = Jmatrix Sim.fields = fields -Sim.dpred = dpred +# Sim.dpred = dpred Sim.getSourceTerm = getSourceTerm From 74777e569834475ae523a92a637c37fb8a866373 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 18:35:40 -0800 Subject: [PATCH 52/84] Revert "Skip dpred for residuals" This reverts commit a2727e14fa7f7611f1445167c4c380521e4664c1. --- simpeg/dask/inverse_problem.py | 2 +- simpeg/dask/objective_function.py | 15 +++++++++------ simpeg/directives/directives.py | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 96b511c335..a43ee6127a 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -35,7 +35,7 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): residuals = [] print("Computing residuals") if isinstance(self.dmisfit, DaskComboMisfits): - residuals = self.dmisfit.residuals(m, self.dpred) + residuals = self.dmisfit.residuals(m) else: for (_, objfct), pred in zip(self.dmisfit, self.dpred): residuals.append(objfct.W * (objfct.data.dobs - pred)) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index ade93538d9..b3fcca7df9 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -15,8 +15,10 @@ def _calc_dpred(objfct, model): return objfct.simulation.dpred(m=objfct.simulation.model) -def _calc_residual(objfct, dpred): - return objfct.W * (objfct.data.dobs - dpred) +def _calc_residual(objfct, model): + return objfct.W * ( + objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model) + ) def _deriv(objfct, multiplier, model): @@ -278,7 +280,6 @@ def get_dpred(self, m, f=None): client = self.client m_future = self._m_as_future dpred = [] - print("in dpred") for futures in self._futures: for objfct, worker in zip(futures, self._workers): dpred.append( @@ -396,20 +397,22 @@ def objfcts(self, objfcts): self._futures = futures self._workers = workers - def residuals(self, m, dpreds, f=None): + def residuals(self, m, f=None): """ Compute the residual for the data misfit. """ self.model = m + client = self.client + m_future = self._m_as_future residuals = [] for futures in self._futures: - for objfct, worker, dpred in zip(futures, self._workers, dpreds): + for objfct, worker in zip(futures, self._workers): residuals.append( client.submit( _calc_residual, objfct, - dpred, + m_future, workers=worker, ) ) diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 2dbfb747a1..cdc3c55d34 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -3271,7 +3271,7 @@ def get_values(self, values: list[np.ndarray] | None): print("Computing dpred") dpred = self.invProb.get_dpred(self.invProb.model) self.invProb.dpred = dpred - print("Done") + if self.joint_index is not None: dpred = [dpred[ind] for ind in self.joint_index] From 1f6256588e928a012980a20c2e6356831814374c Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 21:36:32 -0800 Subject: [PATCH 53/84] Batch compute dpreds --- simpeg/dask/objective_function.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index b3fcca7df9..8619ac75ab 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -281,8 +281,9 @@ def get_dpred(self, m, f=None): m_future = self._m_as_future dpred = [] for futures in self._futures: + future_preds = [] for objfct, worker in zip(futures, self._workers): - dpred.append( + future_preds.append( client.submit( _calc_dpred, objfct, @@ -290,7 +291,9 @@ def get_dpred(self, m, f=None): workers=worker, ) ) - return client.gather(dpred) + dpred += client.gather(future_preds) + + return dpred def getJtJdiag(self, m, f=None): self.model = m From 92e531d907f8c61c4d7c73808b0c2bf946ac4802 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 16 Jan 2025 23:00:28 -0800 Subject: [PATCH 54/84] Remove test prints --- simpeg/dask/inverse_problem.py | 6 +----- simpeg/directives/directives.py | 3 --- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index a43ee6127a..aeb2da9878 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -28,12 +28,9 @@ def get_dpred(self, m, f=None): def dask_evalFunction(self, m, return_g=True, return_H=True): """evalFunction(m, return_g=True, return_H=True)""" self.model = m - - print("Computing dpred") self.dpred = self.get_dpred(m) - residuals = [] - print("Computing residuals") + if isinstance(self.dmisfit, DaskComboMisfits): residuals = self.dmisfit.residuals(m) else: @@ -95,7 +92,6 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): out = (phi,) if return_g: - print("Computing gradient") phi_dDeriv = self.dmisfit.deriv(m) phi_mDeriv = self.reg.deriv(m) diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index cdc3c55d34..1873eaea88 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -360,7 +360,6 @@ def call(self, ruleType): '", "'.join(directives) ) for r in self.dList: - print(f"Running directive {r}") getattr(r, ruleType)() def validate(self): @@ -3176,7 +3175,6 @@ def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noq # Save results with fetch_active_workspace(self._geoh5, mode="r+") as w_s: h5_object = w_s.get_entity(self.h5_object)[0] - print("Saving to file") for cc, component in enumerate(self.components): if component not in self.data_type: self.data_type[component] = {} @@ -3268,7 +3266,6 @@ def get_values(self, values: list[np.ndarray] | None): else: dpred = getattr(self.invProb, "dpred", None) if dpred is None: - print("Computing dpred") dpred = self.invProb.get_dpred(self.invProb.model) self.invProb.dpred = dpred From 4130256cdb649e503f35ba64a08d12e618b43bb9 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 17 Jan 2025 07:31:12 -0800 Subject: [PATCH 55/84] Send residuals in batch too --- simpeg/dask/objective_function.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 8619ac75ab..8a8c0b7638 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -246,9 +246,6 @@ def deriv2(self, m, v=None, f=None): m_future = self._m_as_future [v_future] = client.scatter([v], broadcast=True) - # if f is None: - # f = self.fields(m) - derivs = [] count = 0 for futures in self._futures: @@ -410,8 +407,9 @@ def residuals(self, m, f=None): m_future = self._m_as_future residuals = [] for futures in self._futures: + future_residuals = [] for objfct, worker in zip(futures, self._workers): - residuals.append( + future_residuals.append( client.submit( _calc_residual, objfct, @@ -419,4 +417,6 @@ def residuals(self, m, f=None): workers=worker, ) ) - return client.gather(residuals) + residuals += client.gather(future_residuals) + + return residuals From e7275750823cd5f01f54c1290854d09b6823e43e Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 19 Jan 2025 09:41:13 -0800 Subject: [PATCH 56/84] Only store factorization on compute_J call --- .../frequency_domain/simulation.py | 16 +++++++++------- simpeg/meta/simulation.py | 10 ++++------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 55bb4de8f4..01e5136a99 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -219,7 +219,7 @@ def dpred(self, m=None, f=None, compute_J=False): return rows -def fields(self, m=None): +def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m @@ -237,10 +237,12 @@ def fields(self, m=None): f[sources, self._solutionType] = u Ainv[freq] = Ainv_solve - self.Ainv = Ainv + # self.Ainv = Ainv self._stashed_fields = f + if return_Ainv: + return f, Ainv return f @@ -248,15 +250,15 @@ def compute_J(self, m, f=None): self.model = m if f is None: - f = self.fields(m) + f, Ainv = self.fields(m=m, return_Ainv=True) - if len(self.Ainv) > 1: + if len(Ainv) > 1: raise NotImplementedError( "Current implementation of parallelization assumes a single frequency per simulation. " "Consider creating one misfit per frequency." ) - A_i = list(self.Ainv.values())[0] + A_i = list(Ainv.values())[0] m_size = m.size if self.store_sensitivities == "disk": @@ -321,10 +323,10 @@ def compute_J(self, m, f=None): m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks ) - for A in self.Ainv.values(): + for A in Ainv.values(): A.clean() - del self.Ainv + del Ainv gc.collect() if self.store_sensitivities == "disk": del Jmatrix diff --git a/simpeg/meta/simulation.py b/simpeg/meta/simulation.py index 7c52e5fa8a..f0ddd27c96 100644 --- a/simpeg/meta/simulation.py +++ b/simpeg/meta/simulation.py @@ -307,16 +307,14 @@ def getJtJdiag(self, m, W=None, f=None): # (i.e. projections, multipliers, etc.). # It is usually close within a scaling factor for others, whose accuracy is controlled # by how diagonally dominant JtJ is. - if f is None: - f = self.fields(m) - for i, (mapping, sim, field) in enumerate( - zip(self.mappings, self.simulations, f) - ): + # if f is None: + # f = self.fields(m) + for i, (mapping, sim) in enumerate(zip(self.mappings, self.simulations)): if self._repeat_sim: sim.model = mapping * self.model sim_w = sp.diags(W[self._data_offsets[i] : self._data_offsets[i + 1]]) sim_jtj = sp.diags( - np.sqrt(np.asarray(sim.getJtJdiag(sim.model, sim_w, f=field))) + np.sqrt(np.asarray(sim.getJtJdiag(sim.model, sim_w))) ) m_deriv = mapping.deriv(self.model) jtj_diag += np.asarray( From 91d40c5429aa34b33fe3ba22dbf45ac97ad1d3e7 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 20 Jan 2025 10:27:21 -0800 Subject: [PATCH 57/84] Add prints for debug --- .../frequency_domain/simulation.py | 2 +- simpeg/dask/objective_function.py | 43 ++++++++++++++----- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 01e5136a99..e6a10aef4e 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -223,7 +223,7 @@ def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m - if getattr(self, "_stashed_fields", None) is not None: + if getattr(self, "_stashed_fields", None) is not None and not return_Ainv: return self._stashed_fields f = self.fieldsPair(self) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 8a8c0b7638..a375e35647 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -3,9 +3,19 @@ import numpy as np from dask.distributed import Client from ..data_misfit import L2DataMisfit - +import os from simpeg.utils import validate_list_of_types +OUTFILE = os.getcwd() + "/update.txt" + + +def write_message(message, mode="a"): + with open(OUTFILE, mode) as f: + f.write(message + "\n") + + +write_message("Starting", mode="w+") + def _calc_fields(objfct, model): return objfct.simulation.fields(m=objfct.simulation.model) @@ -210,14 +220,16 @@ def deriv(self, m, f=None): # if f is None: # f = self.fields(m) - derivs = [] + derivs = 0.0 count = 0 + write_message("Calculating deriv") for futures in self._futures: + future_deriv = [] for objfct, worker in zip(futures, self._workers): if self.multipliers[count] == 0.0: # don't evaluate the fct continue - derivs.append( + future_deriv.append( client.submit( _deriv, objfct, @@ -226,10 +238,12 @@ def deriv(self, m, f=None): workers=worker, ) ) + count += 1 + future_deriv = client.gather(future_deriv) + derivs += np.sum(future_deriv, axis=0) - derivs = self.client.gather(derivs) - return np.sum(derivs, axis=0) + return derivs def deriv2(self, m, v=None, f=None): """ @@ -246,14 +260,17 @@ def deriv2(self, m, v=None, f=None): m_future = self._m_as_future [v_future] = client.scatter([v], broadcast=True) - derivs = [] + derivs = 0.0 count = 0 + write_message("Calculating deriv2") for futures in self._futures: + + future_derivs = [] for objfct, worker in zip(futures, self._workers): if self.multipliers[count] == 0.0: # don't evaluate the fct continue - derivs.append( + future_derivs.append( client.submit( _deriv2, objfct, @@ -266,8 +283,8 @@ def deriv2(self, m, v=None, f=None): ) count += 1 - derivs = self.client.gather(derivs) - derivs = np.sum(derivs, axis=0) + future_derivs = self.client.gather(future_derivs) + derivs += np.sum(future_derivs, axis=0) return derivs @@ -277,6 +294,7 @@ def get_dpred(self, m, f=None): client = self.client m_future = self._m_as_future dpred = [] + write_message("Calculating dpred") for futures in self._futures: future_preds = [] for objfct, worker in zip(futures, self._workers): @@ -299,10 +317,13 @@ def getJtJdiag(self, m, f=None): jtj_diag = 0.0 client = self.client + + write_message("Calculating JtJdiag") # if f is None: # f = self.fields(m) - for futures in self._futures: + for ii, futures in enumerate(self._futures): work = [] + write_message(f"Future {ii} of {len(self._futures)}") for objfct, worker in zip(futures, self._workers): work.append( client.submit( @@ -328,6 +349,7 @@ def fields(self, m): return self._stashed_fields # The above should pass the model to all the internal simulations. f = [] + write_message("Calculating fields") for futures in self._futures: f.append([]) for objfct, worker in zip(futures, self._workers): @@ -406,6 +428,7 @@ def residuals(self, m, f=None): client = self.client m_future = self._m_as_future residuals = [] + write_message("Calculating residuals") for futures in self._futures: future_residuals = [] for objfct, worker in zip(futures, self._workers): From d08e90e60fb197bf8b599168fb15a5aca3b49197 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 22 Jan 2025 12:08:27 -0800 Subject: [PATCH 58/84] Simplify getSource --- .../frequency_domain/simulation.py | 27 +++---------------- 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index e6a10aef4e..5ea14ced6f 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -100,35 +100,14 @@ def getSourceTerm(self, freq, source=None): if source is None: source_list = self.survey.get_sources_by_frequency(freq) - source_block = np.array_split(source_list, self.n_threads) - block_compute = [] - if self.client: - sim = self.client.scatter(self, workers=self.worker) - source_block = self.client.scatter(source_block, workers=self.worker) - else: - delayed_source_eval = delayed(source_eval) + for block in [source_list]: + block_compute.append(source_eval(self, block)) - for block in source_block: - if self.client: - block_compute.append( - self.client.submit(source_eval, sim, block, workers=self.worker) - ) - else: - block_compute.append(delayed_source_eval(self, block)) - - if self.client: - blocks = self.client.gather(block_compute) - else: - blocks = compute(block_compute)[0] s_m, s_e = [], [] - for block in blocks: + for block in block_compute: if block[0]: - # for source in self.survey.get_sources_by_frequency(freq): - # sm, se = source.eval(self) - # s_m.append(sm) - # s_e.append(se) s_m += block[0] s_e += block[1] From c69531ea1fd52c998d815664dfc6256fe52ca2cd Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 22 Jan 2025 12:39:38 -0800 Subject: [PATCH 59/84] Try again using local get_client --- .../frequency_domain/simulation.py | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 5ea14ced6f..9d12944294 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -10,6 +10,7 @@ from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary import zarr +from dask.distributed import get_client def receivers_eval(block, mesh, fields): @@ -228,6 +229,8 @@ def fields(self, m=None, return_Ainv=False): def compute_J(self, m, f=None): self.model = m + client = get_client() + if f is None: f, Ainv = self.fields(m=m, return_Ainv=True) @@ -256,16 +259,14 @@ def compute_J(self, m, f=None): ) fields_array = f[:, self._solutionType] blocks_receiver_derivs = [] - if self.client: - fields_array = self.client.scatter( - f[:, self._solutionType], workers=self.worker - ) - fields = self.client.scatter(f, workers=self.worker) - survey = self.client.scatter(self.survey, workers=self.worker) - mesh = self.client.scatter(self.mesh, workers=self.worker) + if client: + fields_array = client.scatter(f[:, self._solutionType], workers=self.worker) + fields = client.scatter(f, workers=self.worker) + survey = client.scatter(self.survey, workers=self.worker) + mesh = client.scatter(self.mesh, workers=self.worker) for block in blocks: blocks_receiver_derivs.append( - self.client.submit( + client.submit( receiver_derivs, survey, mesh, @@ -292,14 +293,14 @@ def compute_J(self, m, f=None): ) # Dask process for all derivatives - if self.client: - blocks_receiver_derivs = self.client.gather(blocks_receiver_derivs) + if client: + blocks_receiver_derivs = client.gather(blocks_receiver_derivs) else: blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] for block_derivs_chunks, addresses_chunks in zip(blocks_receiver_derivs, blocks): Jmatrix = self.parallel_block_compute( - m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks + m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks, client ) for A in Ainv.values(): @@ -315,14 +316,14 @@ def compute_J(self, m, f=None): def parallel_block_compute( - self, m, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses + self, m, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses, client ): m_size = m.size block_stack = sp.hstack(blocks_receiver_derivs).toarray() ATinvdf_duT = A_i * block_stack - if self.client: - ATinvdf_duT = self.client.scatter(ATinvdf_duT, workers=self.worker) + if client: + ATinvdf_duT = client.scatter(ATinvdf_duT, workers=self.worker) else: ATinvdf_duT = delayed(ATinvdf_duT) count = 0 @@ -333,10 +334,10 @@ def parallel_block_compute( n_cols = dfduT.shape[1] n_rows = address[1][2] - if self.client: - sim = self.client.scatter(self, workers=self.worker) + if client: + sim = client.scatter(self, workers=self.worker) block_delayed.append( - self.client.submit( + client.submit( eval_block, sim, ATinvdf_duT, @@ -368,8 +369,8 @@ def parallel_block_compute( indices = np.hstack(rows) - if self.client: - block_delayed = self.client.gather(block_delayed) + if client: + block_delayed = client.gather(block_delayed) block = np.vstack(block_delayed) else: block = compute(array.vstack(block_delayed))[0] From 97fbfd683dffb842ae511802dd9f24b196067a01 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 22 Jan 2025 12:56:35 -0800 Subject: [PATCH 60/84] Revert "Try again using local get_client" This reverts commit c69531ea1fd52c998d815664dfc6256fe52ca2cd. --- .../frequency_domain/simulation.py | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 9d12944294..5ea14ced6f 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -10,7 +10,6 @@ from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary import zarr -from dask.distributed import get_client def receivers_eval(block, mesh, fields): @@ -229,8 +228,6 @@ def fields(self, m=None, return_Ainv=False): def compute_J(self, m, f=None): self.model = m - client = get_client() - if f is None: f, Ainv = self.fields(m=m, return_Ainv=True) @@ -259,14 +256,16 @@ def compute_J(self, m, f=None): ) fields_array = f[:, self._solutionType] blocks_receiver_derivs = [] - if client: - fields_array = client.scatter(f[:, self._solutionType], workers=self.worker) - fields = client.scatter(f, workers=self.worker) - survey = client.scatter(self.survey, workers=self.worker) - mesh = client.scatter(self.mesh, workers=self.worker) + if self.client: + fields_array = self.client.scatter( + f[:, self._solutionType], workers=self.worker + ) + fields = self.client.scatter(f, workers=self.worker) + survey = self.client.scatter(self.survey, workers=self.worker) + mesh = self.client.scatter(self.mesh, workers=self.worker) for block in blocks: blocks_receiver_derivs.append( - client.submit( + self.client.submit( receiver_derivs, survey, mesh, @@ -293,14 +292,14 @@ def compute_J(self, m, f=None): ) # Dask process for all derivatives - if client: - blocks_receiver_derivs = client.gather(blocks_receiver_derivs) + if self.client: + blocks_receiver_derivs = self.client.gather(blocks_receiver_derivs) else: blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] for block_derivs_chunks, addresses_chunks in zip(blocks_receiver_derivs, blocks): Jmatrix = self.parallel_block_compute( - m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks, client + m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks ) for A in Ainv.values(): @@ -316,14 +315,14 @@ def compute_J(self, m, f=None): def parallel_block_compute( - self, m, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses, client + self, m, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses ): m_size = m.size block_stack = sp.hstack(blocks_receiver_derivs).toarray() ATinvdf_duT = A_i * block_stack - if client: - ATinvdf_duT = client.scatter(ATinvdf_duT, workers=self.worker) + if self.client: + ATinvdf_duT = self.client.scatter(ATinvdf_duT, workers=self.worker) else: ATinvdf_duT = delayed(ATinvdf_duT) count = 0 @@ -334,10 +333,10 @@ def parallel_block_compute( n_cols = dfduT.shape[1] n_rows = address[1][2] - if client: - sim = client.scatter(self, workers=self.worker) + if self.client: + sim = self.client.scatter(self, workers=self.worker) block_delayed.append( - client.submit( + self.client.submit( eval_block, sim, ATinvdf_duT, @@ -369,8 +368,8 @@ def parallel_block_compute( indices = np.hstack(rows) - if client: - block_delayed = client.gather(block_delayed) + if self.client: + block_delayed = self.client.gather(block_delayed) block = np.vstack(block_delayed) else: block = compute(array.vstack(block_delayed))[0] From 8140c41a1cf78a5dd679a9c403a5bac1df2e7852 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 23 Jan 2025 16:40:37 -0800 Subject: [PATCH 61/84] Fix issue with client not stored on FEM class --- .../frequency_domain/simulation.py | 71 ++++++++++++------- 1 file changed, 46 insertions(+), 25 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 5ea14ced6f..fdd0ed1e19 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -7,6 +7,7 @@ import scipy.sparse as sp from dask import array, compute, delayed +from dask.distributed import get_client from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary import zarr @@ -256,16 +257,18 @@ def compute_J(self, m, f=None): ) fields_array = f[:, self._solutionType] blocks_receiver_derivs = [] - if self.client: - fields_array = self.client.scatter( - f[:, self._solutionType], workers=self.worker - ) - fields = self.client.scatter(f, workers=self.worker) - survey = self.client.scatter(self.survey, workers=self.worker) - mesh = self.client.scatter(self.mesh, workers=self.worker) + + client = get_client() + + if client: + fields_array = client.scatter(f[:, self._solutionType], workers=self.worker) + fields = client.scatter(f, workers=self.worker) + survey = client.scatter(self.survey, workers=self.worker) + mesh = client.scatter(self.mesh, workers=self.worker) + simulation = client.scatter(self, workers=self.worker) for block in blocks: blocks_receiver_derivs.append( - self.client.submit( + client.submit( receiver_derivs, survey, mesh, @@ -279,7 +282,7 @@ def compute_J(self, m, f=None): fields = delayed(f) survey = delayed(self.survey) mesh = delayed(self.mesh) - + simulation = delayed(self) delayed_derivs = delayed(receiver_derivs) for block in blocks: blocks_receiver_derivs.append( @@ -292,14 +295,23 @@ def compute_J(self, m, f=None): ) # Dask process for all derivatives - if self.client: - blocks_receiver_derivs = self.client.gather(blocks_receiver_derivs) + if client: + blocks_receiver_derivs = client.gather(blocks_receiver_derivs) else: blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] for block_derivs_chunks, addresses_chunks in zip(blocks_receiver_derivs, blocks): - Jmatrix = self.parallel_block_compute( - m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks + Jmatrix = parallel_block_compute( + simulation, + m, + Jmatrix, + block_derivs_chunks, + A_i, + fields_array, + addresses_chunks, + client, + self.worker, + store_sensitivities=self.store_sensitivities, ) for A in Ainv.values(): @@ -315,14 +327,24 @@ def compute_J(self, m, f=None): def parallel_block_compute( - self, m, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses + simulation, + m, + Jmatrix, + blocks_receiver_derivs, + A_i, + fields_array, + addresses, + client, + worker, + store_sensitivities="disk", ): m_size = m.size block_stack = sp.hstack(blocks_receiver_derivs).toarray() ATinvdf_duT = A_i * block_stack - if self.client: - ATinvdf_duT = self.client.scatter(ATinvdf_duT, workers=self.worker) + + if client: + ATinvdf_duT = client.scatter(ATinvdf_duT, workers=worker) else: ATinvdf_duT = delayed(ATinvdf_duT) count = 0 @@ -333,18 +355,17 @@ def parallel_block_compute( n_cols = dfduT.shape[1] n_rows = address[1][2] - if self.client: - sim = self.client.scatter(self, workers=self.worker) + if client: block_delayed.append( - self.client.submit( + client.submit( eval_block, - sim, + simulation, ATinvdf_duT, np.arange(count, count + n_cols), Zero(), fields_array, address, - workers=self.worker, + workers=worker, ) ) else: @@ -352,7 +373,7 @@ def parallel_block_compute( block_delayed.append( array.from_delayed( delayed_eval( - self, + simulation, ATinvdf_duT, np.arange(count, count + n_cols), Zero(), @@ -368,13 +389,13 @@ def parallel_block_compute( indices = np.hstack(rows) - if self.client: - block_delayed = self.client.gather(block_delayed) + if client: + block_delayed = client.gather(block_delayed) block = np.vstack(block_delayed) else: block = compute(array.vstack(block_delayed))[0] - if self.store_sensitivities == "disk": + if store_sensitivities == "disk": Jmatrix.set_orthogonal_selection( (indices, slice(None)), block, From 39aa97a65a10df91627934d0fb0e16ac191bd815 Mon Sep 17 00:00:00 2001 From: domfournier Date: Sat, 25 Jan 2025 08:21:48 -0800 Subject: [PATCH 62/84] Remove client from property of simulation --- .../frequency_domain/simulation.py | 130 +++++++------- .../time_domain/simulation.py | 162 +++++++++--------- simpeg/dask/objective_function.py | 7 +- simpeg/dask/simulation.py | 22 +-- 4 files changed, 157 insertions(+), 164 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index fdd0ed1e19..5cbce13297 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -132,71 +132,71 @@ def getSourceTerm(self, freq, source=None): return s_m, s_e -def dpred(self, m=None, f=None, compute_J=False): - r""" - dpred(m, f=None) - Create the projected data from a model. - The fields, f, (if provided) will be used for the predicted data - instead of recalculating the fields (which may be expensive!). - - .. math:: - - d_\\text{pred} = P(f(m)) - - Where P is a projection of the fields onto the data space. - """ - if self.survey is None: - raise AttributeError( - "The survey has not yet been set and is required to compute " - "data. Please set the survey for the simulation: " - "simulation.survey = survey" - ) - - if f is None: - if m is None: - m = self.model - f = self.fields(m) - - all_receivers = [] - - for ind, src in enumerate(self.survey.source_list): - for rx in src.receiver_list: - all_receivers.append((src, ind, rx)) - - receiver_blocks = np.array_split(np.asarray(all_receivers), self.n_threads) - rows = [] - - if self.client: - f = self.client.scatter(f, workers=self.worker) - mesh = self.client.scatter(self.mesh, workers=self.worker) - else: - delayed_receivers_eval = delayed(receivers_eval) - mesh = delayed(self.mesh) - - for block in receiver_blocks: - n_data = np.sum([rec.nD for _, _, rec in block]) - if n_data == 0: - continue - - if self.client: - rows.append( - self.client.submit(receivers_eval, block, mesh, f, workers=self.worker) - ) - else: - rows.append( - array.from_delayed( - delayed_receivers_eval(block, mesh, f), - dtype=np.float64, - shape=(n_data,), - ) - ) - - if self.client: - rows = np.hstack(self.client.gather(rows)) - else: - rows = compute(array.hstack(rows))[0] - - return rows +# def dpred(self, m=None, f=None, compute_J=False): +# r""" +# dpred(m, f=None) +# Create the projected data from a model. +# The fields, f, (if provided) will be used for the predicted data +# instead of recalculating the fields (which may be expensive!). +# +# .. math:: +# +# d_\\text{pred} = P(f(m)) +# +# Where P is a projection of the fields onto the data space. +# """ +# if self.survey is None: +# raise AttributeError( +# "The survey has not yet been set and is required to compute " +# "data. Please set the survey for the simulation: " +# "simulation.survey = survey" +# ) +# +# if f is None: +# if m is None: +# m = self.model +# f = self.fields(m) +# +# all_receivers = [] +# +# for ind, src in enumerate(self.survey.source_list): +# for rx in src.receiver_list: +# all_receivers.append((src, ind, rx)) +# +# receiver_blocks = np.array_split(np.asarray(all_receivers), self.n_threads) +# rows = [] +# +# if self.client: +# f = self.client.scatter(f, workers=self.worker) +# mesh = self.client.scatter(self.mesh, workers=self.worker) +# else: +# delayed_receivers_eval = delayed(receivers_eval) +# mesh = delayed(self.mesh) +# +# for block in receiver_blocks: +# n_data = np.sum([rec.nD for _, _, rec in block]) +# if n_data == 0: +# continue +# +# if self.client: +# rows.append( +# self.client.submit(receivers_eval, block, mesh, f, workers=self.worker) +# ) +# else: +# rows.append( +# array.from_delayed( +# delayed_receivers_eval(block, mesh, f), +# dtype=np.float64, +# shape=(n_data,), +# ) +# ) +# +# if self.client: +# rows = np.hstack(self.client.gather(rows)) +# else: +# rows = compute(array.hstack(rows))[0] +# +# return rows def fields(self, m=None, return_Ainv=False): diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 8a50032ee8..c9ce885ca2 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -4,12 +4,14 @@ from ....electromagnetics.time_domain.simulation import BaseTDEMSimulation as Sim from ....utils import Zero -from ...simulation import client, getJtJdiag, Jvec, Jtvec, Jmatrix -from simpeg.fields import TimeFields -from multiprocessing import cpu_count +from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix + +# from simpeg.fields import TimeFields +# from multiprocessing import cpu_count import numpy as np import scipy.sparse as sp from dask import array, delayed +from dask.distributed import get_client from simpeg.dask.utils import get_parallel_blocks from simpeg.utils import mkvc @@ -50,49 +52,49 @@ def fields(self, m=None): return f -def getSourceTerm(self, tInd): - """ - Assemble the source term. This ensures that the RHS is a vector / array - of the correct size - """ - source_list = self.survey.source_list - source_block = np.array_split(source_list, cpu_count()) - - block_compute = [] - - if self.client: - sim = self.client.scatter(self) - else: - delayed_source_eval = delayed(source_evaluation) - - for block in source_block: - if self.client: - block_compute.append( - self.client.submit( - source_evaluation, - sim, - block, - self.times[tInd], - ) - ) - else: - block_compute.append(delayed_source_eval(self, block, self.times[tInd])) - - if self.client: - blocks = self.client.gather(block_compute) - else: - blocks = dask.compute(block_compute)[0] - - s_m, s_e = [], [] - for block in blocks: - if block[0]: - s_m.append(block[0]) - s_e.append(block[1]) - - if isinstance(s_m[0][0], Zero): - return Zero(), np.vstack(s_e).T - - return np.vstack(s_m).T, np.vstack(s_e).T +# # def getSourceTerm(self, tInd): +# """ +# Assemble the source term. This ensures that the RHS is a vector / array +# of the correct size +# """ +# source_list = self.survey.source_list +# source_block = np.array_split(source_list, cpu_count()) +# +# block_compute = [] +# +# if client: +# sim = client.scatter(self, workers=self.worker) +# else: +# delayed_source_eval = delayed(source_evaluation) +# +# for block in source_block: +# if client: +# block_compute.append( +# client.submit( +# source_evaluation, +# sim, +# block, +# self.times[tInd], +# ) +# ) +# else: +# block_compute.append(delayed_source_eval(self, block, self.times[tInd])) +# +# if client: +# blocks = client.gather(block_compute) +# else: +# blocks = dask.compute(block_compute)[0] +# +# s_m, s_e = [], [] +# for block in blocks: +# if block[0]: +# s_m.append(block[0]) +# s_e.append(block[1]) +# +# if isinstance(s_m[0][0], Zero): +# return Zero(), np.vstack(s_e).T +# +# return np.vstack(s_m).T, np.vstack(s_e).T # def dpred(self, m=None, f=None): @@ -183,19 +185,19 @@ def compute_J(self, m, f=None): compute_row_size = np.ceil(self.max_chunk_size / (m.shape[0] * 8.0 * 1e-6)) blocks = get_parallel_blocks(self.survey.source_list, compute_row_size) fields_array = f[:, ftype, :] - + client = get_client() if len(self.survey.source_list) == 1: fields_array = fields_array[:, np.newaxis, :] times_field_derivs, Jmatrix = compute_field_derivs( - self, f, blocks, Jmatrix, fields_array.shape + self, f, blocks, Jmatrix, fields_array.shape, client ) ATinv_df_duT_v = {} - if self.client: - fields_array = self.client.scatter(fields_array) - sim = self.client.scatter(self) + if client: + fields_array = client.scatter(fields_array, workers=self.worker) + sim = client.scatter(self, workers=self.worker) else: delayed_compute_rows = delayed(compute_rows) @@ -215,9 +217,9 @@ def compute_J(self, m, f=None): if len(block) == 0: continue - if self.client: + if client: j_row_updates.append( - self.client.submit( + client.submit( compute_rows, sim, tInd, @@ -225,6 +227,7 @@ def compute_J(self, m, f=None): ATinv_df_duT_v, fields_array, time_mask, + workers=self.worker, ) ) else: @@ -246,8 +249,8 @@ def compute_J(self, m, f=None): ) ) - if self.client: - j_row_updates = np.vstack(self.client.gather(j_row_updates)) + if client: + j_row_updates = np.vstack(client.gather(j_row_updates)) else: j_row_updates = array.vstack(j_row_updates).compute() @@ -312,24 +315,25 @@ def _getField(self, name, ind, src_list): else: # loop over the time steps arrays = [] - if self.client: - pointerFields = self.client.scatter(pointerFields) - src_list = self.client.scatter(src_list) - func = self.client.scatter(func) + if client: + pointerFields = client.scatter(pointerFields, workers=self.worker) + src_list = client.scatter(src_list, workers=self.worker) + func = client.scatter(func, workers=self.worker) else: delayed_field_comp = delayed(field_projection) for i, TIND_i in enumerate(timeII): # Need to parallelize this - if self.client: + if client: arrays.append( - self.client.submit( + client.submit( field_projection, pointerFields, src_list, i, TIND_i, func, + workers=self.worker, ) ) else: @@ -343,8 +347,8 @@ def _getField(self, name, ind, src_list): ) ) - if self.client: - arrays = self.client.gather(arrays) + if client: + arrays = client.gather(arrays) out = np.dstack(arrays) else: out = array.dstack(arrays).compute() @@ -354,7 +358,6 @@ def _getField(self, name, ind, src_list): # TimeFields._getField = _getField -TimeFields.client = client def field_projection(field_array, src_list, array_ind, time_ind, func): @@ -392,16 +395,16 @@ def evaluate_receivers(block, mesh, time_mesh, fields, fields_array): return np.hstack(data) -def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape): +def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape, client): """ Compute the derivative of the fields """ delayed_chunks = [] - if self.client: - mesh = self.client.scatter(self.mesh) - time_mesh = self.client.scatter(self.time_mesh) - fields = self.client.scatter(fields) + if client: + mesh = client.scatter(self.mesh, workers=self.worker) + time_mesh = client.scatter(self.time_mesh, workers=self.worker) + fields = client.scatter(fields, workers=self.worker) else: mesh = self.mesh time_mesh = self.time_mesh @@ -411,9 +414,9 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape): if len(chunks) == 0: continue - if self.client: + if client: delayed_chunks.append( - self.client.submit( + client.submit( block_deriv, self.nT, chunks, @@ -423,6 +426,7 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape): time_mesh, fields, self.model.size, + workers=self.worker, ) ) else: @@ -439,8 +443,8 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape): ) ) - if self.client: - result = self.client.gather(delayed_chunks) + if client: + result = client.gather(delayed_chunks) else: result = dask.compute(delayed_chunks)[0] @@ -519,9 +523,9 @@ def get_field_deriv_block( ) count += len(local_ind) - if self.client: + if client: stacked_blocks.append( - self.client.submit( + client.submit( deriv_block, s_id, r_id, @@ -531,6 +535,7 @@ def get_field_deriv_block( local_ind, field_deriv, tInd, + workers=self.worker, ) ) else: @@ -556,8 +561,8 @@ def get_field_deriv_block( ) if len(stacked_blocks) > 0: - if self.client: - blocks = np.hstack(self.client.gather(stacked_blocks)) + if client: + blocks = np.hstack(client.gather(stacked_blocks)) else: blocks = array.hstack(stacked_blocks).compute() @@ -692,10 +697,9 @@ def compute_rows( return np.vstack(rows) -Sim.client = client Sim.fields = fields Sim.getJtJdiag = getJtJdiag -Sim.getSourceTerm = getSourceTerm +# Sim.getSourceTerm = getSourceTerm # Sim.dpred = dpred Sim.compute_J = compute_J Sim.getJtJdiag = getJtJdiag diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index a375e35647..c8ef35d06b 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -5,6 +5,7 @@ from ..data_misfit import L2DataMisfit import os from simpeg.utils import validate_list_of_types +from time import time OUTFILE = os.getcwd() + "/update.txt" @@ -323,7 +324,7 @@ def getJtJdiag(self, m, f=None): # f = self.fields(m) for ii, futures in enumerate(self._futures): work = [] - write_message(f"Future {ii} of {len(self._futures)}") + ct = time() for objfct, worker in zip(futures, self._workers): work.append( client.submit( @@ -334,7 +335,11 @@ def getJtJdiag(self, m, f=None): workers=worker, ) ) + work = client.gather(work) + write_message( + f"Future {ii} of {len(self._futures)} in {time() - ct:.3f} sec" + ) jtj_diag += np.sum(work, axis=0) self._jtjdiag = jtj_diag diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index 7c4b2dbaab..ef2e64fdb1 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -1,7 +1,6 @@ from ..simulation import BaseSimulation as Sim from dask import array -from dask.distributed import get_client import numpy as np Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] @@ -42,20 +41,6 @@ def max_chunk_size(self, other): Sim.max_chunk_size = max_chunk_size -@property -def client(self): - if getattr(self, "_client", None) is None: - try: - self._client = get_client() - except ValueError: - self._client = False - - return self._client - - -Sim.client = client - - def getJtJdiag(self, m, W=None, f=None): """ Return the diagonal of JtJ @@ -122,14 +107,13 @@ def Jmatrix(self): Sim.Jmatrix = Jmatrix -@property -def n_threads(self): +def n_threads(self, client=None): """ Number of threads used by Dask """ if getattr(self, "_n_threads", None) is None: - if self.client: - self._n_threads = self.client.nthreads()[self.worker[0]] + if client: + self._n_threads = client.nthreads()[self.worker[0]] else: self._n_threads = cpu_count() From f92e750745b618a6ba4378216f8a3d2e2c229b42 Mon Sep 17 00:00:00 2001 From: domfournier Date: Sat, 25 Jan 2025 08:27:27 -0800 Subject: [PATCH 63/84] Fix for non-client run --- .../frequency_domain/simulation.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 5cbce13297..37e1e4e88b 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -258,14 +258,19 @@ def compute_J(self, m, f=None): fields_array = f[:, self._solutionType] blocks_receiver_derivs = [] - client = get_client() + try: + client = get_client() + worker = self.worker + except ValueError: + client = None + worker = None if client: - fields_array = client.scatter(f[:, self._solutionType], workers=self.worker) - fields = client.scatter(f, workers=self.worker) - survey = client.scatter(self.survey, workers=self.worker) - mesh = client.scatter(self.mesh, workers=self.worker) - simulation = client.scatter(self, workers=self.worker) + fields_array = client.scatter(f[:, self._solutionType], workers=worker) + fields = client.scatter(f, workers=worker) + survey = client.scatter(self.survey, workers=worker) + mesh = client.scatter(self.mesh, workers=worker) + simulation = client.scatter(self, workers=worker) for block in blocks: blocks_receiver_derivs.append( client.submit( @@ -274,7 +279,7 @@ def compute_J(self, m, f=None): mesh, fields, block, - workers=self.worker, + workers=worker, ) ) else: @@ -310,7 +315,7 @@ def compute_J(self, m, f=None): fields_array, addresses_chunks, client, - self.worker, + worker, store_sensitivities=self.store_sensitivities, ) @@ -335,7 +340,7 @@ def parallel_block_compute( fields_array, addresses, client, - worker, + worker=None, store_sensitivities="disk", ): m_size = m.size From 61403e45ec78046594f555f439399dd9fd1b69ee Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 26 Jan 2025 09:14:42 -0800 Subject: [PATCH 64/84] Fix potential fields for client not on simulaiton --- simpeg/dask/potential_fields/base.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index e2767ce81a..03f9007363 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -1,7 +1,7 @@ import numpy as np from ...potential_fields.base import BasePFSimulation as Sim - +from dask.distributed import get_client import os from dask import delayed, array, config from ..utils import compute_chunk_sizes @@ -59,16 +59,21 @@ def linear_operator(self): ) block_split = np.array_split(self.survey.receiver_locations, n_blocks) - if self.client: - sim = self.client.scatter(self, workers=self.worker) + try: + client = get_client() + except ValueError: + client = None + + if client: + sim = client.scatter(self, workers=self.worker) else: delayed_compute = delayed(block_compute) rows = [] for block in block_split: - if self.client: + if client: rows.append( - self.client.submit( + client.submit( block_compute, sim, block, @@ -90,10 +95,10 @@ def linear_operator(self): ) ) - if self.client: + if client: if forward_only: - return np.hstack(self.client.gather(rows)) - return np.vstack(self.client.gather(rows)) + return np.hstack(client.gather(rows)) + return np.vstack(client.gather(rows)) if forward_only: stack = array.concatenate(rows) From 7d0ceba0db9492fe6c552e676ae4ee4bb5c54117 Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 26 Jan 2025 09:15:22 -0800 Subject: [PATCH 65/84] Don't use mesh object to index Projection --- simpeg/electromagnetics/frequency_domain/receivers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/simpeg/electromagnetics/frequency_domain/receivers.py b/simpeg/electromagnetics/frequency_domain/receivers.py index b424135a20..ef1510446a 100644 --- a/simpeg/electromagnetics/frequency_domain/receivers.py +++ b/simpeg/electromagnetics/frequency_domain/receivers.py @@ -153,8 +153,8 @@ def getP(self, mesh, projected_grid): scipy.sparse.csr_matrix P, the interpolation matrix """ - if (mesh, projected_grid) in self._Ps: - return self._Ps[(mesh, projected_grid)] + if projected_grid in self._Ps: + return self._Ps[projected_grid] P = Zero() for strength, comp in zip(self.orientation, ["x", "y", "z"]): @@ -164,7 +164,7 @@ def getP(self, mesh, projected_grid): ) if self.storeProjections: - self._Ps[(mesh, projected_grid)] = P + self._Ps[projected_grid] = P return P def eval(self, src, mesh, f): # noqa: A003 From 0b1709a7dbc420e0c6abe7c7c6d9e95550effbda Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 26 Jan 2025 17:32:52 -0800 Subject: [PATCH 66/84] Fix handling of spatialP differently on FEM and TEM receivers --- .../frequency_domain/simulation.py | 91 +++++-------------- .../time_domain/simulation.py | 17 +++- simpeg/dask/simulation.py | 38 ++++++++ .../frequency_domain/receivers.py | 6 +- 4 files changed, 77 insertions(+), 75 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 37e1e4e88b..2ab6c3322f 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -98,13 +98,32 @@ def getSourceTerm(self, freq, source=None): Assemble the source term. This ensures that the RHS is a vector / array of the correct size """ + try: + client = get_client() + sim = client.scatter(self, workers=self.worker) + except ValueError: + client = None + sim = self + if source is None: source_list = self.survey.get_sources_by_frequency(freq) + source_blocks = np.array_split(source_list, self.n_threads(client=client)) block_compute = [] - for block in [source_list]: - block_compute.append(source_eval(self, block)) + for block in source_blocks: + if len(block) == 0: + continue + + if client: + block_compute.append( + client.submit(source_eval, sim, block, workers=self.worker) + ) + else: + block_compute.append(source_eval(sim, block)) + + if client: + block_compute = client.gather(block_compute) s_m, s_e = [], [] for block in block_compute: @@ -132,73 +151,6 @@ def getSourceTerm(self, freq, source=None): return s_m, s_e -# def dpred(self, m=None, f=None, compute_J=False): -# r""" -# dpred(m, f=None) -# Create the projected data from a model. -# The fields, f, (if provided) will be used for the predicted data -# instead of recalculating the fields (which may be expensive!). -# -# .. math:: -# -# d_\\text{pred} = P(f(m)) -# -# Where P is a projection of the fields onto the data space. -# """ -# if self.survey is None: -# raise AttributeError( -# "The survey has not yet been set and is required to compute " -# "data. Please set the survey for the simulation: " -# "simulation.survey = survey" -# ) -# -# if f is None: -# if m is None: -# m = self.model -# f = self.fields(m) -# -# all_receivers = [] -# -# for ind, src in enumerate(self.survey.source_list): -# for rx in src.receiver_list: -# all_receivers.append((src, ind, rx)) -# -# receiver_blocks = np.array_split(np.asarray(all_receivers), self.n_threads) -# rows = [] -# -# if self.client: -# f = self.client.scatter(f, workers=self.worker) -# mesh = self.client.scatter(self.mesh, workers=self.worker) -# else: -# delayed_receivers_eval = delayed(receivers_eval) -# mesh = delayed(self.mesh) -# -# for block in receiver_blocks: -# n_data = np.sum([rec.nD for _, _, rec in block]) -# if n_data == 0: -# continue -# -# if self.client: -# rows.append( -# self.client.submit(receivers_eval, block, mesh, f, workers=self.worker) -# ) -# else: -# rows.append( -# array.from_delayed( -# delayed_receivers_eval(block, mesh, f), -# dtype=np.float64, -# shape=(n_data,), -# ) -# ) -# -# if self.client: -# rows = np.hstack(self.client.gather(rows)) -# else: -# rows = compute(array.hstack(rows))[0] -# -# return rows - - def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m @@ -419,5 +371,4 @@ def parallel_block_compute( Sim.Jtvec = Jtvec Sim.Jmatrix = Jmatrix Sim.fields = fields -# Sim.dpred = dpred Sim.getSourceTerm = getSourceTerm diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index c9ce885ca2..269838c102 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -185,7 +185,12 @@ def compute_J(self, m, f=None): compute_row_size = np.ceil(self.max_chunk_size / (m.shape[0] * 8.0 * 1e-6)) blocks = get_parallel_blocks(self.survey.source_list, compute_row_size) fields_array = f[:, ftype, :] - client = get_client() + + try: + client = get_client() + except ValueError: + client = None + if len(self.survey.source_list) == 1: fields_array = fields_array[:, np.newaxis, :] @@ -211,7 +216,14 @@ def compute_J(self, m, f=None): for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): ATinv_df_duT_v = get_field_deriv_block( - self, block, field_deriv, tInd, AdiagTinv, ATinv_df_duT_v, time_mask + self, + block, + field_deriv, + tInd, + AdiagTinv, + ATinv_df_duT_v, + time_mask, + client, ) if len(block) == 0: @@ -494,6 +506,7 @@ def get_field_deriv_block( AdiagTinv, ATinv_df_duT_v: dict, time_mask, + client, ): """ Stack the blocks of field derivatives for a given timestep and call the direct solver. diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index ef2e64fdb1..d04cfb0a32 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -2,6 +2,7 @@ from dask import array import numpy as np +from multiprocessing import cpu_count Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag", "_stashed_fields"] Sim.sensitivity_path = "./sensitivity/" @@ -121,3 +122,40 @@ def n_threads(self, client=None): Sim.n_threads = n_threads + + +# TODO: Make dpred parallel +def dpred(self, m=None, f=None): + r"""Predicted data for the model provided. + + Parameters + ---------- + m : (n_param,) numpy.ndarray + The model parameters. + f : simpeg.fields.Fields, optional + If provided, will be used to compute the predicted data + without recalculating the fields. + + Returns + ------- + (n_data, ) numpy.ndarray + The predicted data vector. + """ + if self.survey is None: + raise AttributeError( + "The survey has not yet been set and is required to compute " + "data. Please set the survey for the simulation: " + "simulation.survey = survey" + ) + + if f is None: + if m is None: + m = self.model + + f = self.fields(m) + + data = Data(self.survey) + for src in self.survey.source_list: + for rx in src.receiver_list: + data[src, rx] = rx.eval(src, self.mesh, f) + return mkvc(data) diff --git a/simpeg/electromagnetics/frequency_domain/receivers.py b/simpeg/electromagnetics/frequency_domain/receivers.py index ef1510446a..a9eab28f3e 100644 --- a/simpeg/electromagnetics/frequency_domain/receivers.py +++ b/simpeg/electromagnetics/frequency_domain/receivers.py @@ -153,8 +153,8 @@ def getP(self, mesh, projected_grid): scipy.sparse.csr_matrix P, the interpolation matrix """ - if projected_grid in self._Ps: - return self._Ps[projected_grid] + if getattr(self, "spatialP", None) is not None: + return self.spatialP P = Zero() for strength, comp in zip(self.orientation, ["x", "y", "z"]): @@ -164,7 +164,7 @@ def getP(self, mesh, projected_grid): ) if self.storeProjections: - self._Ps[projected_grid] = P + self.spatialP = P return P def eval(self, src, mesh, f): # noqa: A003 From 3a29a23d1b844de11ae5c2ca9943031dc9e98175 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 27 Jan 2025 08:17:57 -0800 Subject: [PATCH 67/84] Use n_cell instead of mesh --- simpeg/electromagnetics/time_domain/receivers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/simpeg/electromagnetics/time_domain/receivers.py b/simpeg/electromagnetics/time_domain/receivers.py index 98e4a5053e..dc2cbb0255 100644 --- a/simpeg/electromagnetics/time_domain/receivers.py +++ b/simpeg/electromagnetics/time_domain/receivers.py @@ -148,15 +148,15 @@ def getP(self, mesh, time_mesh, f): ----- Projection matrices are stored as a dictionary (mesh, time_mesh) if storeProjections is True """ - if (mesh, time_mesh) in self._Ps: - return self._Ps[(mesh, time_mesh)] + if (mesh.n_cells, time_mesh.n_cells) in self._Ps: + return self._Ps[(mesh.n_cells, time_mesh.n_cells)] Ps = self.getSpatialP(mesh, f) Pt = self.getTimeP(time_mesh, f) P = sp.kron(Pt, Ps) if self.storeProjections: - self._Ps[(mesh, time_mesh)] = P + self._Ps[(mesh.n_cells, time_mesh.n_cells)] = P return P From 406b25d71ce678b5c947f6323bb721ad9dc07b69 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 28 Jan 2025 08:24:59 -0800 Subject: [PATCH 68/84] Re-implement parallel getSourceTerm --- .../time_domain/simulation.py | 154 ++++++------------ 1 file changed, 51 insertions(+), 103 deletions(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 269838c102..ccd4384bb4 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -52,107 +52,56 @@ def fields(self, m=None): return f -# # def getSourceTerm(self, tInd): -# """ -# Assemble the source term. This ensures that the RHS is a vector / array -# of the correct size -# """ -# source_list = self.survey.source_list -# source_block = np.array_split(source_list, cpu_count()) -# -# block_compute = [] -# -# if client: -# sim = client.scatter(self, workers=self.worker) -# else: -# delayed_source_eval = delayed(source_evaluation) -# -# for block in source_block: -# if client: -# block_compute.append( -# client.submit( -# source_evaluation, -# sim, -# block, -# self.times[tInd], -# ) -# ) -# else: -# block_compute.append(delayed_source_eval(self, block, self.times[tInd])) -# -# if client: -# blocks = client.gather(block_compute) -# else: -# blocks = dask.compute(block_compute)[0] -# -# s_m, s_e = [], [] -# for block in blocks: -# if block[0]: -# s_m.append(block[0]) -# s_e.append(block[1]) -# -# if isinstance(s_m[0][0], Zero): -# return Zero(), np.vstack(s_e).T -# -# return np.vstack(s_m).T, np.vstack(s_e).T - - -# def dpred(self, m=None, f=None): -# r""" -# dpred(m, f=None) -# Create the projected data from a model. -# The fields, f, (if provided) will be used for the predicted data -# instead of recalculating the fields (which may be expensive!). -# -# .. math:: -# -# d_\\text{pred} = P(f(m)) -# -# Where P is a projection of the fields onto the data space. -# """ -# if self.survey is None: -# raise AttributeError( -# "The survey has not yet been set and is required to compute " -# "data. Please set the survey for the simulation: " -# "simulation.survey = survey" -# ) -# -# if f is None: -# if m is None: -# m = self.model -# f = self.fields(m) -# -# rows = [] -# receiver_projection = self.survey.source_list[0].receiver_list[0].projField -# fields_array = f[:, receiver_projection, :] -# -# if len(self.survey.source_list) == 1: -# fields_array = fields_array[:, np.newaxis, :] -# -# all_receivers = [] -# -# for ind, src in enumerate(self.survey.source_list): -# for rx in src.receiver_list: -# all_receivers.append((src, ind, rx)) -# -# receiver_blocks = np.array_split(all_receivers, cpu_count()) -# -# for block in receiver_blocks: -# n_data = np.sum([rec.nD for _, _, rec in block]) -# if n_data == 0: -# continue -# -# rows.append( -# array.from_delayed( -# evaluate_receivers(block, self.mesh, self.time_mesh, f, fields_array), -# dtype=np.float64, -# shape=(n_data,), -# ) -# ) -# -# data = array.hstack(rows).compute() -# -# return data +def getSourceTerm(self, tInd): + """ + Assemble the source term. This ensures that the RHS is a vector / array + of the correct size + """ + try: + client = get_client() + sim = client.scatter(self, workers=self.worker) + except ValueError: + client = None + sim = self + + source_list = self.survey.source_list + source_block = np.array_split(source_list, self.n_threads(client=client)) + + if client: + sim = client.scatter(self, workers=self.worker) + else: + delayed_source_eval = delayed(source_evaluation) + sim = self + + block_compute = [] + for block in source_block: + if client: + block_compute.append( + client.submit( + source_evaluation, + sim, + block, + self.times[tInd], + ) + ) + else: + block_compute.append(delayed_source_eval(self, block, self.times[tInd])) + + if client: + blocks = client.gather(block_compute) + else: + blocks = dask.compute(block_compute)[0] + + s_m, s_e = [], [] + for block in blocks: + if block[0]: + s_m.append(block[0]) + s_e.append(block[1]) + + if isinstance(s_m[0][0], Zero): + return Zero(), np.vstack(s_e).T + + return np.vstack(s_m).T, np.vstack(s_e).T def compute_J(self, m, f=None): @@ -712,8 +661,7 @@ def compute_rows( Sim.fields = fields Sim.getJtJdiag = getJtJdiag -# Sim.getSourceTerm = getSourceTerm -# Sim.dpred = dpred +Sim.getSourceTerm = getSourceTerm Sim.compute_J = compute_J Sim.getJtJdiag = getJtJdiag Sim.Jvec = Jvec From d4b78ef737a569b99ed74b72570235b2fb056b02 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 29 Jan 2025 13:36:34 -0800 Subject: [PATCH 69/84] Don't use the mesh itself to index projections on NS receivers. Scatter sources to avoid large data on graph --- .../frequency_domain/simulation.py | 54 +++++++++++++------ .../natural_source/receivers.py | 6 +-- 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 2ab6c3322f..2d78ea7c7a 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -1,5 +1,5 @@ import gc - +import os from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim from ....utils import Zero from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix @@ -11,6 +11,14 @@ from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary import zarr +from time import time + +OUTFILE = os.getcwd() + "/update.txt" + + +def write_message(message, mode="a"): + with open(OUTFILE, mode) as f: + f.write(message + "\n") def receivers_eval(block, mesh, fields): @@ -21,10 +29,10 @@ def receivers_eval(block, mesh, fields): return np.hstack(data) -def source_eval(simulation, sources): +def source_eval(simulation, sources, indices): s_m, s_e = [], [] - for source in sources: - sm, se = source.eval(simulation) + for ind in indices: + sm, se = sources[ind].eval(simulation) s_m.append(sm) s_e.append(se) @@ -55,6 +63,7 @@ def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address """ Evaluate the sensitivities for the block or data """ + if Ainv_deriv_u.ndim == 1: deriv_columns = Ainv_deriv_u[:, np.newaxis] else: @@ -78,12 +87,14 @@ def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address deriv_columns, adjoint=True, ) + dRHS_dmT = simulation.getRHSDeriv( source.frequency, source, deriv_columns, adjoint=True, ) + du_dmT = -dA_dmT if not isinstance(dRHS_dmT, Zero): du_dmT += dRHS_dmT @@ -98,30 +109,42 @@ def getSourceTerm(self, freq, source=None): Assemble the source term. This ensures that the RHS is a vector / array of the correct size """ - try: - client = get_client() - sim = client.scatter(self, workers=self.worker) - except ValueError: - client = None - sim = self + ct = time() if source is None: + ct = time() + try: + client = get_client() + sim = client.scatter(self, workers=self.worker) + except ValueError: + client = None + sim = self + + write_message("Time to scatter simulation: {}".format(time() - ct)) source_list = self.survey.get_sources_by_frequency(freq) - source_blocks = np.array_split(source_list, self.n_threads(client=client)) - block_compute = [] + source_blocks = np.array_split( + np.arange(len(source_list)), self.n_threads(client=client) + ) + if client: + source_list = client.scatter(source_list, workers=self.worker) + + block_compute = [] + ct = time() for block in source_blocks: if len(block) == 0: continue if client: block_compute.append( - client.submit(source_eval, sim, block, workers=self.worker) + client.submit( + source_eval, sim, source_list, block, workers=self.worker + ) ) else: - block_compute.append(source_eval(sim, block)) - + block_compute.append(source_eval(sim, source_list, block)) + write_message("Time to submit source terms: {}".format(time() - ct)) if client: block_compute = client.gather(block_compute) @@ -148,6 +171,7 @@ def getSourceTerm(self, freq, source=None): s_e = np.vstack(s_e) if s_e.shape[0] < s_e.shape[1]: s_e = s_e.T + return s_m, s_e diff --git a/simpeg/electromagnetics/natural_source/receivers.py b/simpeg/electromagnetics/natural_source/receivers.py index 930fff879e..75498271f5 100644 --- a/simpeg/electromagnetics/natural_source/receivers.py +++ b/simpeg/electromagnetics/natural_source/receivers.py @@ -175,8 +175,8 @@ def getP(self, mesh, projected_grid, field="e"): if mesh.dim < 3: return super().getP(mesh, projected_grid) - if (mesh, projected_grid, field) in self._Ps: - return self._Ps[(mesh, projected_grid, field)] + if (mesh.n_cells, projected_grid, field) in self._Ps: + return self._Ps[(mesh.n_cells, projected_grid, field)] if field == "e": locs = self.locations_e @@ -184,7 +184,7 @@ def getP(self, mesh, projected_grid, field="e"): locs = self.locations_h P = mesh.get_interpolation_matrix(locs, projected_grid) if self.storeProjections: - self._Ps[(mesh, projected_grid, field)] = P + self._Ps[(mesh.n_cells, projected_grid, field)] = P return P def _eval_impedance(self, src, mesh, f): From 1f4b2232763e4113a13f5f8a988db035473d0232 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 30 Jan 2025 09:02:54 -0800 Subject: [PATCH 70/84] Clean out prints. Update time domain --- .../frequency_domain/simulation.py | 21 +- .../time_domain/simulation.py | 203 +++++++++--------- 2 files changed, 107 insertions(+), 117 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 2d78ea7c7a..692eabd63a 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -1,5 +1,5 @@ import gc -import os + from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim from ....utils import Zero from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix @@ -11,14 +11,6 @@ from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary import zarr -from time import time - -OUTFILE = os.getcwd() + "/update.txt" - - -def write_message(message, mode="a"): - with open(OUTFILE, mode) as f: - f.write(message + "\n") def receivers_eval(block, mesh, fields): @@ -109,10 +101,9 @@ def getSourceTerm(self, freq, source=None): Assemble the source term. This ensures that the RHS is a vector / array of the correct size """ - ct = time() if source is None: - ct = time() + try: client = get_client() sim = client.scatter(self, workers=self.worker) @@ -120,8 +111,6 @@ def getSourceTerm(self, freq, source=None): client = None sim = self - write_message("Time to scatter simulation: {}".format(time() - ct)) - source_list = self.survey.get_sources_by_frequency(freq) source_blocks = np.array_split( np.arange(len(source_list)), self.n_threads(client=client) @@ -131,7 +120,7 @@ def getSourceTerm(self, freq, source=None): source_list = client.scatter(source_list, workers=self.worker) block_compute = [] - ct = time() + for block in source_blocks: if len(block) == 0: continue @@ -144,7 +133,7 @@ def getSourceTerm(self, freq, source=None): ) else: block_compute.append(source_eval(sim, source_list, block)) - write_message("Time to submit source terms: {}".format(time() - ct)) + if client: block_compute = client.gather(block_compute) @@ -193,8 +182,6 @@ def fields(self, m=None, return_Ainv=False): f[sources, self._solutionType] = u Ainv[freq] = Ainv_solve - # self.Ainv = Ainv - self._stashed_fields = f if return_Ainv: diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index ccd4384bb4..0a86d55436 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -7,7 +7,7 @@ from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix # from simpeg.fields import TimeFields -# from multiprocessing import cpu_count + import numpy as np import scipy.sparse as sp from dask import array, delayed @@ -19,11 +19,11 @@ from time import time -def fields(self, m=None): +def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m - if getattr(self, "_stashed_fields", None) is not None: + if getattr(self, "_stashed_fields", None) is not None and not return_Ainv: return self._stashed_fields f = self.fieldsPair(self) @@ -47,8 +47,9 @@ def fields(self, m=None): sol = Ainv[dt] * rhs f[:, self._fieldType + "Solution", tInd + 1] = sol - self.Ainv = Ainv self._stashed_fields = f + if return_Ainv: + return f, Ainv return f @@ -65,10 +66,13 @@ def getSourceTerm(self, tInd): sim = self source_list = self.survey.source_list - source_block = np.array_split(source_list, self.n_threads(client=client)) + source_block = np.array_split( + np.arange(len(source_list)), self.n_threads(client=client) + ) if client: sim = client.scatter(self, workers=self.worker) + source_list = client.scatter(source_list, workers=self.worker) else: delayed_source_eval = delayed(source_evaluation) sim = self @@ -78,14 +82,13 @@ def getSourceTerm(self, tInd): if client: block_compute.append( client.submit( - source_evaluation, - sim, - block, - self.times[tInd], + source_evaluation, sim, block, self.times[tInd], source_list ) ) else: - block_compute.append(delayed_source_eval(self, block, self.times[tInd])) + block_compute.append( + delayed_source_eval(self, block, self.times[tInd], source_list) + ) if client: blocks = client.gather(block_compute) @@ -109,7 +112,12 @@ def compute_J(self, m, f=None): Compute the rows for the sensitivity matrix. """ if f is None: - f = self.fields(m) + f, Ainv = self.fields(m=m, return_Ainv=True) + + try: + client = get_client() + except ValueError: + client = None ftype = self._fieldType + "Solution" sens_name = self.sensitivity_path[:-5] @@ -135,11 +143,6 @@ def compute_J(self, m, f=None): blocks = get_parallel_blocks(self.survey.source_list, compute_row_size) fields_array = f[:, ftype, :] - try: - client = get_client() - except ValueError: - client = None - if len(self.survey.source_list) == 1: fields_array = fields_array[:, np.newaxis, :] @@ -154,9 +157,10 @@ def compute_J(self, m, f=None): sim = client.scatter(self, workers=self.worker) else: delayed_compute_rows = delayed(compute_rows) + sim = self for tInd, dt in zip(reversed(range(self.nT)), reversed(self.time_steps)): - AdiagTinv = self.Ainv[dt] + AdiagTinv = Ainv[dt] j_row_updates = [] time_mask = data_times > simulation_times[tInd] @@ -195,7 +199,7 @@ def compute_J(self, m, f=None): j_row_updates.append( array.from_delayed( delayed_compute_rows( - self, + sim, tInd, block, ATinv_df_duT_v, @@ -212,7 +216,6 @@ def compute_J(self, m, f=None): if client: j_row_updates = np.vstack(client.gather(j_row_updates)) - else: j_row_updates = array.vstack(j_row_updates).compute() @@ -228,7 +231,7 @@ def compute_J(self, m, f=None): else: Jmatrix += j_row_updates - for A in self.Ainv.values(): + for A in Ainv.values(): A.clean() if self.store_sensitivities == "ram": @@ -239,83 +242,83 @@ def compute_J(self, m, f=None): return self._Jmatrix -def _getField(self, name, ind, src_list): - srcInd, timeInd = ind - - if name in self._fields: - out = self._fields[name][:, srcInd, timeInd] - else: - # Aliased fields - alias, loc, func = self.aliasFields[name] - if isinstance(func, str): - assert hasattr(self, func), ( - "The alias field function is a string, but it does " - "not exist in the Fields class." - ) - func = getattr(self, func) - pointerFields = self._fields[alias][:, srcInd, timeInd] - pointerShape = self._correctShape(alias, ind) - pointerFields = pointerFields.reshape(pointerShape, order="F") - - # First try to return the function as three arguments (without timeInd) - if timeInd == slice(None, None, None): - try: - # assume it will take care of integrating over all times - return func(pointerFields, srcInd) - except TypeError: - pass - - timeII = np.arange(self.simulation.nT + 1)[timeInd] - if not isinstance(src_list, list): - src_list = [src_list] - - if timeII.size == 1: - pointerShapeDeflated = self._correctShape(alias, ind, deflate=True) - pointerFields = pointerFields.reshape(pointerShapeDeflated, order="F") - out = func(pointerFields, src_list, timeII) - else: # loop over the time steps - arrays = [] - - if client: - pointerFields = client.scatter(pointerFields, workers=self.worker) - src_list = client.scatter(src_list, workers=self.worker) - func = client.scatter(func, workers=self.worker) - else: - delayed_field_comp = delayed(field_projection) - - for i, TIND_i in enumerate(timeII): # Need to parallelize this - - if client: - arrays.append( - client.submit( - field_projection, - pointerFields, - src_list, - i, - TIND_i, - func, - workers=self.worker, - ) - ) - else: - arrays.append( - array.from_delayed( - delayed_field_comp( - pointerFields, src_list, i, TIND_i, func - ), - dtype=np.float32, - shape=(pointerShape[0], pointerShape[1], 1), - ) - ) - - if client: - arrays = client.gather(arrays) - out = np.dstack(arrays) - else: - out = array.dstack(arrays).compute() - - shape = self._correctShape(name, ind, deflate=True) - return out.reshape(shape, order="F") +# def _getField(self, name, ind, src_list): +# srcInd, timeInd = ind +# +# if name in self._fields: +# out = self._fields[name][:, srcInd, timeInd] +# else: +# # Aliased fields +# alias, loc, func = self.aliasFields[name] +# if isinstance(func, str): +# assert hasattr(self, func), ( +# "The alias field function is a string, but it does " +# "not exist in the Fields class." +# ) +# func = getattr(self, func) +# pointerFields = self._fields[alias][:, srcInd, timeInd] +# pointerShape = self._correctShape(alias, ind) +# pointerFields = pointerFields.reshape(pointerShape, order="F") +# +# # First try to return the function as three arguments (without timeInd) +# if timeInd == slice(None, None, None): +# try: +# # assume it will take care of integrating over all times +# return func(pointerFields, srcInd) +# except TypeError: +# pass +# +# timeII = np.arange(self.simulation.nT + 1)[timeInd] +# if not isinstance(src_list, list): +# src_list = [src_list] +# +# if timeII.size == 1: +# pointerShapeDeflated = self._correctShape(alias, ind, deflate=True) +# pointerFields = pointerFields.reshape(pointerShapeDeflated, order="F") +# out = func(pointerFields, src_list, timeII) +# else: # loop over the time steps +# arrays = [] +# +# if client: +# pointerFields = client.scatter(pointerFields, workers=self.worker) +# src_list = client.scatter(src_list, workers=self.worker) +# func = client.scatter(func, workers=self.worker) +# else: +# delayed_field_comp = delayed(field_projection) +# +# for i, TIND_i in enumerate(timeII): # Need to parallelize this +# +# if client: +# arrays.append( +# client.submit( +# field_projection, +# pointerFields, +# src_list, +# i, +# TIND_i, +# func, +# workers=self.worker, +# ) +# ) +# else: +# arrays.append( +# array.from_delayed( +# delayed_field_comp( +# pointerFields, src_list, i, TIND_i, func +# ), +# dtype=np.float32, +# shape=(pointerShape[0], pointerShape[1], 1), +# ) +# ) +# +# if client: +# arrays = client.gather(arrays) +# out = np.dstack(arrays) +# else: +# out = array.dstack(arrays).compute() +# +# shape = self._correctShape(name, ind, deflate=True) +# return out.reshape(shape, order="F") # TimeFields._getField = _getField @@ -334,10 +337,10 @@ def field_projection(field_array, src_list, array_ind, time_ind, func): return new_array -def source_evaluation(simulation, sources, time_channel): +def source_evaluation(simulation, indices, time_channel, sources): s_m, s_e = [], [] - for source in sources: - sm, se = source.eval(simulation, time_channel) + for ind in indices: + sm, se = sources[ind].eval(simulation, time_channel) s_m.append(sm) s_e.append(se) From 8f58f51dda8cf07cf90070b03be64a1e2e9dae4c Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 30 Jan 2025 15:53:03 -0800 Subject: [PATCH 71/84] Clean up tem simulation for distributed process --- .../time_domain/simulation.py | 203 ++++++------------ 1 file changed, 65 insertions(+), 138 deletions(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 0a86d55436..42c1580c47 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -6,8 +6,6 @@ from ....utils import Zero from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix -# from simpeg.fields import TimeFields - import numpy as np import scipy.sparse as sp from dask import array, delayed @@ -18,6 +16,13 @@ from time import time +OUTFILE = os.getcwd() + "/update.txt" + + +def write_message(message, mode="a"): + with open(OUTFILE, mode) as f: + f.write(message + "\n") + def fields(self, m=None, return_Ainv=False): if m is not None: @@ -77,12 +82,18 @@ def getSourceTerm(self, tInd): delayed_source_eval = delayed(source_evaluation) sim = self + ct = time() block_compute = [] for block in source_block: if client: block_compute.append( client.submit( - source_evaluation, sim, block, self.times[tInd], source_list + source_evaluation, + sim, + block, + self.times[tInd], + source_list, + workers=self.worker, ) ) else: @@ -95,6 +106,7 @@ def getSourceTerm(self, tInd): else: blocks = dask.compute(block_compute)[0] + write_message(f"Source term computation: {time() - ct:.3e} sec") s_m, s_e = [], [] for block in blocks: if block[0]: @@ -140,7 +152,11 @@ def compute_J(self, m, f=None): simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 data_times = self.survey.source_list[0].receiver_list[0].times compute_row_size = np.ceil(self.max_chunk_size / (m.shape[0] * 8.0 * 1e-6)) - blocks = get_parallel_blocks(self.survey.source_list, compute_row_size) + blocks = get_parallel_blocks( + self.survey.source_list, + compute_row_size, + thread_count=self.n_threads(client=client), + ) fields_array = f[:, ftype, :] if len(self.survey.source_list) == 1: @@ -150,7 +166,7 @@ def compute_J(self, m, f=None): self, f, blocks, Jmatrix, fields_array.shape, client ) - ATinv_df_duT_v = {} + ATinv_df_duT_v = [[] for _ in blocks] if client: fields_array = client.scatter(fields_array, workers=self.worker) @@ -167,29 +183,37 @@ def compute_J(self, m, f=None): if not np.any(time_mask): continue - for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): - ATinv_df_duT_v = get_field_deriv_block( + for ind, (block, field_deriv) in enumerate( + zip(blocks, times_field_derivs[tInd + 1]) + ): + ct = time() + ATinv_df_duT_v[ind] = get_field_deriv_block( self, block, field_deriv, tInd, AdiagTinv, - ATinv_df_duT_v, + ATinv_df_duT_v[ind], time_mask, client, ) + write_message(f"Field deriv block computation: {time() - ct:.3e} sec") if len(block) == 0: continue + ct = time() if client: + field_derivatives = client.scatter( + ATinv_df_duT_v[ind], workers=self.worker + ) j_row_updates.append( client.submit( compute_rows, sim, tInd, block, - ATinv_df_duT_v, + field_derivatives, fields_array, time_mask, workers=self.worker, @@ -202,7 +226,7 @@ def compute_J(self, m, f=None): sim, tInd, block, - ATinv_df_duT_v, + ATinv_df_duT_v[ind], fields_array, time_mask, ), @@ -242,88 +266,6 @@ def compute_J(self, m, f=None): return self._Jmatrix -# def _getField(self, name, ind, src_list): -# srcInd, timeInd = ind -# -# if name in self._fields: -# out = self._fields[name][:, srcInd, timeInd] -# else: -# # Aliased fields -# alias, loc, func = self.aliasFields[name] -# if isinstance(func, str): -# assert hasattr(self, func), ( -# "The alias field function is a string, but it does " -# "not exist in the Fields class." -# ) -# func = getattr(self, func) -# pointerFields = self._fields[alias][:, srcInd, timeInd] -# pointerShape = self._correctShape(alias, ind) -# pointerFields = pointerFields.reshape(pointerShape, order="F") -# -# # First try to return the function as three arguments (without timeInd) -# if timeInd == slice(None, None, None): -# try: -# # assume it will take care of integrating over all times -# return func(pointerFields, srcInd) -# except TypeError: -# pass -# -# timeII = np.arange(self.simulation.nT + 1)[timeInd] -# if not isinstance(src_list, list): -# src_list = [src_list] -# -# if timeII.size == 1: -# pointerShapeDeflated = self._correctShape(alias, ind, deflate=True) -# pointerFields = pointerFields.reshape(pointerShapeDeflated, order="F") -# out = func(pointerFields, src_list, timeII) -# else: # loop over the time steps -# arrays = [] -# -# if client: -# pointerFields = client.scatter(pointerFields, workers=self.worker) -# src_list = client.scatter(src_list, workers=self.worker) -# func = client.scatter(func, workers=self.worker) -# else: -# delayed_field_comp = delayed(field_projection) -# -# for i, TIND_i in enumerate(timeII): # Need to parallelize this -# -# if client: -# arrays.append( -# client.submit( -# field_projection, -# pointerFields, -# src_list, -# i, -# TIND_i, -# func, -# workers=self.worker, -# ) -# ) -# else: -# arrays.append( -# array.from_delayed( -# delayed_field_comp( -# pointerFields, src_list, i, TIND_i, func -# ), -# dtype=np.float32, -# shape=(pointerShape[0], pointerShape[1], 1), -# ) -# ) -# -# if client: -# arrays = client.gather(arrays) -# out = np.dstack(arrays) -# else: -# out = array.dstack(arrays).compute() -# -# shape = self._correctShape(name, ind, deflate=True) -# return out.reshape(shape, order="F") - - -# TimeFields._getField = _getField - - def field_projection(field_array, src_list, array_ind, time_ind, func): fieldI = field_array[:, :, array_ind] if fieldI.shape[0] == fieldI.size: @@ -436,27 +378,13 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape, client): return df_duT, Jmatrix -def update_deriv_blocks(address, indices, derivatives, solve, shape): - if address not in derivatives: - deriv_array = np.zeros(shape) - else: - deriv_array = derivatives[address] - - if address in indices: - columns, local_ind = indices[address] - if solve is not None: - deriv_array[:, local_ind] = solve[:, columns] - - derivatives[address] = deriv_array - - def get_field_deriv_block( self, block: list, field_derivs: list, tInd: int, AdiagTinv, - ATinv_df_duT_v: dict, + ATinv_df_duT_v, time_mask, client, ): @@ -464,7 +392,9 @@ def get_field_deriv_block( Stack the blocks of field derivatives for a given timestep and call the direct solver. """ stacked_blocks = [] - indices = {} + if len(ATinv_df_duT_v) == 0: + ATinv_df_duT_v = [[] for _ in block] + indices = [] count = 0 Asubdiag = None @@ -472,8 +402,8 @@ def get_field_deriv_block( Asubdiag = self.getAsubdiag(tInd + 1) delayed_deriv = delayed(deriv_block) - for ((s_id, r_id, b_id), (rx_ind, _, shape)), field_deriv in zip( - block, field_derivs + for (_, (rx_ind, _, shape)), field_deriv, ATinv_chunk in zip( + block, field_derivs, ATinv_df_duT_v ): # Cut out early data time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind] @@ -482,9 +412,8 @@ def get_field_deriv_block( if len(local_ind) < 1: continue - indices[(s_id, r_id, b_id)] = ( - np.arange(count, count + len(local_ind)), - local_ind, + indices.append( + (np.arange(count, count + len(local_ind)), local_ind), ) count += len(local_ind) @@ -492,27 +421,19 @@ def get_field_deriv_block( stacked_blocks.append( client.submit( deriv_block, - s_id, - r_id, - b_id, - ATinv_df_duT_v, + ATinv_chunk, Asubdiag, local_ind, field_deriv, - tInd, workers=self.worker, ) ) else: deriv_comp = delayed_deriv( - s_id, - r_id, - b_id, - ATinv_df_duT_v, + ATinv_chunk, Asubdiag, local_ind, field_deriv, - tInd, ) stacked_blocks.append( array.from_delayed( @@ -535,15 +456,25 @@ def get_field_deriv_block( else: solve = None - for (address, arrays), field_deriv in zip(block, field_derivs): - shape = ( - field_deriv.shape[0], - len(arrays[0]), - ) + updated_ATinv_df_duT_v = [] + for (_, arrays), field_deriv, ATinv_chunk, (columns, local_ind) in zip( + block, field_derivs, ATinv_df_duT_v, indices + ): + + if len(ATinv_chunk) == 0: + shape = ( + field_deriv.shape[0], + len(arrays[0]), + ) + ATinv_chunk = np.zeros(shape, dtype=np.float32) - update_deriv_blocks(address, indices, ATinv_df_duT_v, solve, shape) + if solve is None: + continue + + ATinv_chunk[:, local_ind] = solve[:, columns] + updated_ATinv_df_duT_v.append(ATinv_chunk) - return ATinv_df_duT_v + return updated_ATinv_df_duT_v def block_deriv( @@ -593,17 +524,14 @@ def block_deriv( return df_duT, j_updates -def deriv_block( - s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, field_derivs, tInd -): - if (s_id, r_id, b_id) not in ATinv_df_duT_v: +def deriv_block(ATinv_df_duT_v, Asubdiag, local_ind, field_derivs): + if len(ATinv_df_duT_v) == 0: # last timestep (first to be solved) stacked_block = field_derivs.toarray()[:, local_ind] else: stacked_block = np.asarray( - field_derivs[:, local_ind] - - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] + field_derivs[:, local_ind] - Asubdiag.T * ATinv_df_duT_v[:, local_ind] ) return stacked_block @@ -622,7 +550,7 @@ def compute_rows( """ rows = [] - for address, ind_array in chunks: + for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v): src = simulation.survey.source_list[address[0]] time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] local_ind = np.arange(len(ind_array[0]))[time_check] @@ -634,7 +562,6 @@ def compute_rows( rows.append(row_block) continue - field_derivs = ATinv_df_duT_v[address] dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( tInd, fields[:, address[0], tInd], From 925f3e4735060af4e992b1e3711dbad3d20a7f70 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 3 Feb 2025 10:03:16 -0800 Subject: [PATCH 72/84] Add small quantity to avoid zero division --- simpeg/electromagnetics/static/resistivity/simulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simpeg/electromagnetics/static/resistivity/simulation.py b/simpeg/electromagnetics/static/resistivity/simulation.py index 4e65a8d07a..bdd09e1666 100644 --- a/simpeg/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/electromagnetics/static/resistivity/simulation.py @@ -606,7 +606,7 @@ def setBC(self): # TODO: Implement Zhang et al. (1995) r_vec = boundary_faces - source_point - r = np.linalg.norm(r_vec, axis=-1) + r = np.linalg.norm(r_vec, axis=-1) + 1e-12 r_hat = r_vec / r[:, None] r_dot_n = np.einsum("ij,ij->i", r_hat, boundary_normals) From 72d0fcfea6e30a0fdc777b951f64aa86485d6cc3 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 3 Feb 2025 10:55:26 -0800 Subject: [PATCH 73/84] Stash RHS as sparse arrays. Index receiver projection without mesh --- .../static/induced_polarization/simulation.py | 2 +- .../static/resistivity/simulation.py | 50 +++++++++++++------ .../static/resistivity/receivers.py | 12 ++--- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py index 8768265021..1a1171cffe 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py @@ -26,7 +26,7 @@ def fields(self, m=None): RHS = self.getRHS() f = self.fieldsPair(self) - f[:, self._solutionType] = Ainv * RHS + f[:, self._solutionType] = Ainv * np.asarray(RHS.todense()) if self._scale is None: scale = Data(self.survey, np.ones(self.survey.nD)) diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index bf8e11a674..6b48520d6f 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -1,8 +1,10 @@ from .....electromagnetics.static.resistivity.simulation import Simulation3DNodal as Sim from .....utils import Zero +from dask.distributed import get_client import dask.array as da import numpy as np +from scipy import sparse as sp import zarr import numcodecs @@ -26,7 +28,7 @@ def fields(self, m=None): RHS = self.getRHS() f = self.fieldsPair(self) - f[:, self._solutionType] = Ainv * RHS + f[:, self._solutionType] = Ainv * np.asarray(RHS.todense()) self.Ainv = Ainv @@ -135,6 +137,17 @@ def compute_J(self, m, f=None): return self._Jmatrix +def source_eval(simulation, sources, indices): + """ + Evaluate the source term for the given source and index + """ + blocks = [] + for ind in indices: + blocks.append(sources[ind].eval(simulation)) + + return sp.csr_matrix(np.vstack(blocks).T) + + def getSourceTerm(self): """ Evaluates the sources, and puts them in matrix form @@ -145,23 +158,30 @@ def getSourceTerm(self): if getattr(self, "_q", None) is None: if self._mini_survey is not None: - Srcs = self._mini_survey.source_list + source_list = self._mini_survey.source_list else: - Srcs = self.survey.source_list - - if self._formulation == "EB": - n = self.mesh.nN - # return NotImplementedError - - elif self._formulation == "HJ": - n = self.mesh.nC - - q = np.zeros((n, len(Srcs)), order="F") + source_list = self.survey.source_list + + indices = np.arange(len(source_list)) + try: + + client = get_client() + sim = client.scatter(self, workers=self.worker) + future_list = client.scatter(source_list, workers=self.worker) + indices = np.array_split(indices, self.n_threads(client=client)) + blocks = [] + for ind in indices: + blocks.append( + client.submit( + source_eval, sim, future_list, ind, workers=self.worker + ) + ) - for i, source in enumerate(Srcs): - q[:, i] = source.eval(self) + blocks = sp.hstack(client.gather(blocks)) + except ValueError: + blocks = source_eval(self, source_list, indices) - self._q = q + self._q = blocks return self._q diff --git a/simpeg/electromagnetics/static/resistivity/receivers.py b/simpeg/electromagnetics/static/resistivity/receivers.py index 53c2614bb7..ab1324d8ed 100644 --- a/simpeg/electromagnetics/static/resistivity/receivers.py +++ b/simpeg/electromagnetics/static/resistivity/receivers.py @@ -410,15 +410,15 @@ def getP(self, mesh, projected_grid, transpose=False): P, the interpolation matrix """ - if mesh in self._Ps: - return self._Ps[mesh] + if mesh.n_cells in self._Ps: + return self._Ps[mesh.n_cells] P0 = mesh.get_interpolation_matrix(self.locations[0], projected_grid) P1 = mesh.get_interpolation_matrix(self.locations[1], projected_grid) P = P0 - P1 if self.storeProjections: - self._Ps[mesh] = P + self._Ps[mesh.n_cells] = P if transpose: P = P.toarray().T @@ -489,12 +489,12 @@ def getP(self, mesh, projected_grid): P, the interpolation matrix """ - if mesh in self._Ps: - return self._Ps[mesh] + if mesh.n_cells in self._Ps: + return self._Ps[mesh.n_cells] P = mesh.get_interpolation_matrix(self.locations, projected_grid) if self.storeProjections: - self._Ps[mesh] = P + self._Ps[mesh.n_cells] = P return P From 76efdfd09b97725bd56292ad05e0f12add7c16d4 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 3 Feb 2025 10:58:38 -0800 Subject: [PATCH 74/84] Stash sources on TEM simulation as sparse array --- .../time_domain/simulation.py | 82 +++++++------------ 1 file changed, 29 insertions(+), 53 deletions(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 42c1580c47..8870622719 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -14,15 +14,6 @@ from simpeg.dask.utils import get_parallel_blocks from simpeg.utils import mkvc -from time import time - -OUTFILE = os.getcwd() + "/update.txt" - - -def write_message(message, mode="a"): - with open(OUTFILE, mode) as f: - f.write(message + "\n") - def fields(self, m=None, return_Ainv=False): if m is not None: @@ -42,14 +33,13 @@ def fields(self, m=None, return_Ainv=False): Asubdiag = self.getAsubdiag(tInd) rhs = -Asubdiag * f[:, (self._fieldType + "Solution"), tInd] - if ( np.abs(self.survey.source_list[0].waveform.eval(self.times[tInd + 1])) > 1e-8 ): rhs += self.getRHS(tInd + 1) - sol = Ainv[dt] * rhs + sol = Ainv[dt] * np.asarray(rhs) f[:, self._fieldType + "Solution", tInd + 1] = sol self._stashed_fields = f @@ -63,6 +53,14 @@ def getSourceTerm(self, tInd): Assemble the source term. This ensures that the RHS is a vector / array of the correct size """ + if ( + getattr(self, "_stashed_sources", None) is not None + and tInd in self._stashed_sources + ): + return self._stashed_sources[tInd] + elif getattr(self, "_stashed_sources", None) is None: + self._stashed_sources = {} + try: client = get_client() sim = client.scatter(self, workers=self.worker) @@ -82,7 +80,6 @@ def getSourceTerm(self, tInd): delayed_source_eval = delayed(source_evaluation) sim = self - ct = time() block_compute = [] for block in source_block: if client: @@ -106,7 +103,6 @@ def getSourceTerm(self, tInd): else: blocks = dask.compute(block_compute)[0] - write_message(f"Source term computation: {time() - ct:.3e} sec") s_m, s_e = [], [] for block in blocks: if block[0]: @@ -114,9 +110,13 @@ def getSourceTerm(self, tInd): s_e.append(block[1]) if isinstance(s_m[0][0], Zero): - return Zero(), np.vstack(s_e).T + self._stashed_sources[tInd] = Zero(), sp.csr_matrix(np.vstack(s_e).T) + else: + self._stashed_sources[tInd] = sp.csr_matrix(np.vstack(s_m).T), sp.csr_matrix( + np.vstack(s_e).T + ) - return np.vstack(s_m).T, np.vstack(s_e).T + return self._stashed_sources[tInd] def compute_J(self, m, f=None): @@ -174,7 +174,6 @@ def compute_J(self, m, f=None): else: delayed_compute_rows = delayed(compute_rows) sim = self - for tInd, dt in zip(reversed(range(self.nT)), reversed(self.time_steps)): AdiagTinv = Ainv[dt] j_row_updates = [] @@ -186,7 +185,6 @@ def compute_J(self, m, f=None): for ind, (block, field_deriv) in enumerate( zip(blocks, times_field_derivs[tInd + 1]) ): - ct = time() ATinv_df_duT_v[ind] = get_field_deriv_block( self, block, @@ -197,12 +195,10 @@ def compute_J(self, m, f=None): time_mask, client, ) - write_message(f"Field deriv block computation: {time() - ct:.3e} sec") if len(block) == 0: continue - ct = time() if client: field_derivatives = client.scatter( ATinv_df_duT_v[ind], workers=self.worker @@ -311,10 +307,12 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape, client): mesh = client.scatter(self.mesh, workers=self.worker) time_mesh = client.scatter(self.time_mesh, workers=self.worker) fields = client.scatter(fields, workers=self.worker) + source_list = client.scatter(self.survey.source_list, workers=self.worker) else: mesh = self.mesh time_mesh = self.time_mesh delayed_block_deriv = delayed(block_deriv) + source_list = self.survey.source_list for chunks in blocks: if len(chunks) == 0: @@ -327,7 +325,7 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape, client): self.nT, chunks, fields_shape[0], - self.survey.source_list, + source_list, mesh, time_mesh, fields, @@ -341,7 +339,7 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape, client): self.nT, chunks, fields_shape[0], - self.survey.source_list, + source_list, self.mesh, self.time_mesh, fields, @@ -401,7 +399,6 @@ def get_field_deriv_block( if tInd < self.nT - 1: Asubdiag = self.getAsubdiag(tInd + 1) - delayed_deriv = delayed(deriv_block) for (_, (rx_ind, _, shape)), field_deriv, ATinv_chunk in zip( block, field_derivs, ATinv_df_duT_v ): @@ -417,40 +414,19 @@ def get_field_deriv_block( ) count += len(local_ind) - if client: - stacked_blocks.append( - client.submit( - deriv_block, - ATinv_chunk, - Asubdiag, - local_ind, - field_deriv, - workers=self.worker, - ) - ) + if len(ATinv_chunk) == 0: + # last timestep (first to be solved) + stacked_block = field_deriv.toarray()[:, local_ind] + else: - deriv_comp = delayed_deriv( - ATinv_chunk, - Asubdiag, - local_ind, - field_deriv, + stacked_block = np.asarray( + field_deriv[:, local_ind] - Asubdiag.T * ATinv_chunk[:, local_ind] ) - stacked_blocks.append( - array.from_delayed( - deriv_comp, - dtype=float, - shape=( - field_deriv.shape[0], - len(local_ind), - ), - ) - ) - if len(stacked_blocks) > 0: - if client: - blocks = np.hstack(client.gather(stacked_blocks)) - else: - blocks = array.hstack(stacked_blocks).compute() + stacked_blocks.append(stacked_block) + + if len(stacked_blocks) > 0: + blocks = np.hstack(stacked_blocks) solve = (AdiagTinv * blocks).reshape(blocks.shape) else: From 50aa18e8ecb67766cf63281c421c14dff54edb7c Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 3 Feb 2025 12:45:50 -0800 Subject: [PATCH 75/84] Don't store Ainv on simulation. Fix imports of Jvec, Jtvec --- .../static/resistivity/simulation.py | 25 +++++++++++-------- .../time_domain/simulation.py | 1 - simpeg/dask/objective_function.py | 1 + 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index 6b48520d6f..3605218046 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -1,5 +1,7 @@ from .....electromagnetics.static.resistivity.simulation import Simulation3DNodal as Sim +from ....simulation import getJtJdiag, Jvec, Jtvec, Jmatrix + from .....utils import Zero from dask.distributed import get_client import dask.array as da @@ -16,11 +18,11 @@ numcodecs.blosc.use_threads = False -def fields(self, m=None): +def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m - if getattr(self, "_stashed_fields", None) is not None: + if getattr(self, "_stashed_fields", None) is not None and not return_Ainv: return self._stashed_fields A = self.getA() @@ -30,17 +32,15 @@ def fields(self, m=None): f = self.fieldsPair(self) f[:, self._solutionType] = Ainv * np.asarray(RHS.todense()) - self.Ainv = Ainv - self._stashed_fields = f - + if return_Ainv: + return f, Ainv return f def compute_J(self, m, f=None): - if f is None: - f = self.fields(m) + f, Ainv = self.fields(m=m, return_Ainv=True) m_size = m.size row_chunks = int( @@ -67,7 +67,7 @@ def compute_J(self, m, f=None): for rx in source.receiver_list: - if rx.orientation is not None: + if getattr(rx, "orientation", None) is not None: projected_grid = f._GLoc(rx.projField) + rx.orientation else: projected_grid = f._GLoc(rx.projField) @@ -82,7 +82,7 @@ def compute_J(self, m, f=None): df_duT, df_dmT = df_duTFun( source, None, PTv[:, start:end], adjoint=True ) - ATinvdf_duT = self.Ainv * df_duT + ATinvdf_duT = Ainv * df_duT dA_dmT = self.getADeriv(u_source, ATinvdf_duT, adjoint=True) dRHS_dmT = self.getRHSDeriv(source, ATinvdf_duT, adjoint=True) du_dmT = -dA_dmT @@ -126,7 +126,7 @@ def compute_J(self, m, f=None): else: Jmatrix[count : self.survey.nD, :] = blocks.astype(np.float32) - self.Ainv.clean() + Ainv.clean() if self.store_sensitivities == "disk": del Jmatrix @@ -189,3 +189,8 @@ def getSourceTerm(self): Sim.getSourceTerm = getSourceTerm Sim.fields = fields Sim.compute_J = compute_J + +Sim.getJtJdiag = getJtJdiag +Sim.Jvec = Jvec +Sim.Jtvec = Jtvec +Sim.Jmatrix = Jmatrix diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 8870622719..5509417eb4 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -566,7 +566,6 @@ def compute_rows( Sim.fields = fields -Sim.getJtJdiag = getJtJdiag Sim.getSourceTerm = getSourceTerm Sim.compute_J = compute_J Sim.getJtJdiag = getJtJdiag diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index c8ef35d06b..7c1a7a3d2e 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -43,6 +43,7 @@ def _deriv2(objfct, multiplier, model, v): # if fields is not None and objfct.has_fields: # return multiplier * objfct.deriv2(objfct.simulation.model, v) # else: + print("Calculating deriv2") return multiplier * objfct.deriv2(objfct.simulation.model, v) From 6811443057e23672c3a940b70891b76b2e395239 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 3 Feb 2025 12:47:57 -0800 Subject: [PATCH 76/84] Remove test print --- simpeg/dask/objective_function.py | 1 - 1 file changed, 1 deletion(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 7c1a7a3d2e..c8ef35d06b 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -43,7 +43,6 @@ def _deriv2(objfct, multiplier, model, v): # if fields is not None and objfct.has_fields: # return multiplier * objfct.deriv2(objfct.simulation.model, v) # else: - print("Calculating deriv2") return multiplier * objfct.deriv2(objfct.simulation.model, v) From eea390cc0187e8a3ff8e0d5417c15139cb605d1f Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 3 Feb 2025 14:29:59 -0800 Subject: [PATCH 77/84] Fix 2D simulation --- .../static/resistivity/simulation_2d.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py index bb5bd0e6c6..bb430ec2c1 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation_2d.py @@ -1,6 +1,7 @@ from .....electromagnetics.static.resistivity.simulation_2d import ( Simulation2DNodal as Sim, ) +from ....simulation import getJtJdiag, Jvec, Jtvec, Jmatrix import dask.array as da import numpy as np @@ -10,11 +11,11 @@ numcodecs.blosc.use_threads = False -def fields(self, m=None): +def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m - if getattr(self, "_stashed_fields", None) is not None: + if getattr(self, "_stashed_fields", None) is not None and not return_Ainv: return self._stashed_fields kys = self._quad_points @@ -29,9 +30,9 @@ def fields(self, m=None): RHS = self.getRHS(ky) f[:, self._solutionType, iky] = Ainv[iky] * RHS - self.Ainv = Ainv - self._stashed_fields = f + if return_Ainv: + return f, Ainv return f @@ -39,8 +40,7 @@ def compute_J(self, m, f=None): kys = self._quad_points weights = self._quad_weights - if f is None: - f = self.fields(m) + f, Ainv = self.fields(m, return_Ainv=True) m_size = m.size row_chunks = int( @@ -72,7 +72,7 @@ def compute_J(self, m, f=None): for i_src, source in enumerate(self.survey.source_list): for rx in source.receiver_list: - if rx.orientation is not None: + if getattr(rx, "orientation", None) is not None: projected_grid = f._GLoc(rx.projField) + rx.orientation else: projected_grid = f._GLoc(rx.projField) @@ -88,7 +88,7 @@ def compute_J(self, m, f=None): u_ky = f[:, self._solutionType, iky] u_source = u_ky[:, i_src] - ATinvdf_duT = self.Ainv[iky] * PTv[:, start:end] + ATinvdf_duT = Ainv[iky] * PTv[:, start:end] dA_dmT = self.getADeriv(ky, u_source, ATinvdf_duT, adjoint=True) du_dmT = -weights[iky] * dA_dmT block += du_dmT.T.reshape((-1, m_size)) @@ -124,7 +124,7 @@ def compute_J(self, m, f=None): Jmatrix[count : self.survey.nD, :] = blocks.astype(np.float32) for iky, _ in enumerate(kys): - self.Ainv[iky].clean() + Ainv[iky].clean() if self.store_sensitivities == "disk": del Jmatrix @@ -212,3 +212,8 @@ def getSourceTerm(self, _): Sim.compute_J = compute_J Sim.dpred = dpred Sim.getSourceTerm = getSourceTerm + +Sim.getJtJdiag = getJtJdiag +Sim.Jvec = Jvec +Sim.Jtvec = Jtvec +Sim.Jmatrix = Jmatrix From ad5b5f1345b3f2122330ceaf9d35ee1e2903da34 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 3 Feb 2025 14:56:16 -0800 Subject: [PATCH 78/84] Clean out prints --- simpeg/dask/objective_function.py | 94 +++++++++++-------------------- 1 file changed, 34 insertions(+), 60 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index c8ef35d06b..375f5d71a1 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -3,19 +3,21 @@ import numpy as np from dask.distributed import Client from ..data_misfit import L2DataMisfit -import os -from simpeg.utils import validate_list_of_types -from time import time - -OUTFILE = os.getcwd() + "/update.txt" - - -def write_message(message, mode="a"): - with open(OUTFILE, mode) as f: - f.write(message + "\n") +from simpeg.utils import validate_list_of_types -write_message("Starting", mode="w+") +# import os +# from time import time +# +# OUTFILE = os.getcwd() + "/update.txt" +# +# +# def write_message(message, mode="a"): +# with open(OUTFILE, mode) as f: +# f.write(message + "\n") +# +# +# write_message("Starting", mode="w+") def _calc_fields(objfct, model): @@ -33,16 +35,10 @@ def _calc_residual(objfct, model): def _deriv(objfct, multiplier, model): - # if fields is not None and objfct.has_fields: - # return multiplier * objfct.deriv(objfct.simulation.model) - # else: return multiplier * objfct.deriv(objfct.simulation.model) def _deriv2(objfct, multiplier, model, v): - # if fields is not None and objfct.has_fields: - # return multiplier * objfct.deriv2(objfct.simulation.model, v) - # else: return multiplier * objfct.deriv2(objfct.simulation.model, v) @@ -84,35 +80,6 @@ def _validate_type_or_future_of_type( workload[-1].append(future) count += 1 - # objects[0].simulation.simulations[0].worker = workers[0] - # if workers is None: - # objects = client.scatter(objects) - # else: - # tmp = [] - # for obj, worker in zip(objects, workers): - # tmp.append(client.scatter([obj], workers=worker)[0]) - # objects = tmp - # except TypeError: - # pass - # ensure list of futures - # objects = validate_list_of_types( - # property_name, - # objects, - # Future, - # ) - # Figure out where everything lives - - # who = client.who_has(workload) - # # if workers is None: - # # workers = [] - # for ii, worker in enumerate(who.values()): - # if worker != workers[ii % len(workers)]: - # warnings.warn( - # f"{property_name} {i} is not on the expected worker.", stacklevel=2 - # ) - # # obj = client.submit(_set_worker, obj, worker) - - # Ensure this runs on the expected worker futures = [] for work in workload: @@ -223,7 +190,7 @@ def deriv(self, m, f=None): derivs = 0.0 count = 0 - write_message("Calculating deriv") + for futures in self._futures: future_deriv = [] for objfct, worker in zip(futures, self._workers): @@ -263,7 +230,7 @@ def deriv2(self, m, v=None, f=None): derivs = 0.0 count = 0 - write_message("Calculating deriv2") + for futures in self._futures: future_derivs = [] @@ -290,12 +257,15 @@ def deriv2(self, m, v=None, f=None): return derivs def get_dpred(self, m, f=None): + """ + Request calculation of predicted data from all simulations. + """ self.model = m client = self.client m_future = self._m_as_future dpred = [] - write_message("Calculating dpred") + for futures in self._futures: future_preds = [] for objfct, worker in zip(futures, self._workers): @@ -312,6 +282,9 @@ def get_dpred(self, m, f=None): return dpred def getJtJdiag(self, m, f=None): + """ + Request calculation of the diagonal of JtJ from all simulations. + """ self.model = m m_future = self._m_as_future if getattr(self, "_jtjdiag", None) is None: @@ -319,27 +292,20 @@ def getJtJdiag(self, m, f=None): jtj_diag = 0.0 client = self.client - write_message("Calculating JtJdiag") - # if f is None: - # f = self.fields(m) - for ii, futures in enumerate(self._futures): + for futures in self._futures: work = [] - ct = time() + for objfct, worker in zip(futures, self._workers): work.append( client.submit( _get_jtj_diag, objfct, m_future, - # field, workers=worker, ) ) work = client.gather(work) - write_message( - f"Future {ii} of {len(self._futures)} in {time() - ct:.3f} sec" - ) jtj_diag += np.sum(work, axis=0) self._jtjdiag = jtj_diag @@ -347,6 +313,11 @@ def getJtJdiag(self, m, f=None): return self._jtjdiag def fields(self, m): + """ + Request calculation of fields from all simulations. + + Store list of futures for fields in self._stashed_fields. + """ self.model = m client = self.client m_future = self._m_as_future @@ -354,7 +325,7 @@ def fields(self, m): return self._stashed_fields # The above should pass the model to all the internal simulations. f = [] - write_message("Calculating fields") + for futures in self._futures: f.append([]) for objfct, worker in zip(futures, self._workers): @@ -405,6 +376,9 @@ def model(self, value): @property def objfcts(self): + """ + List of objective functions associated with the data misfit. + """ return self._objfcts @objfcts.setter @@ -433,7 +407,7 @@ def residuals(self, m, f=None): client = self.client m_future = self._m_as_future residuals = [] - write_message("Calculating residuals") + for futures in self._futures: future_residuals = [] for objfct, worker in zip(futures, self._workers): From 39317c8eae5198ac3f131df26a28231fa48bdaae Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 3 Feb 2025 22:22:37 -0800 Subject: [PATCH 79/84] Fix indentation error --- simpeg/dask/objective_function.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 375f5d71a1..7d46b966d7 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -207,8 +207,9 @@ def deriv(self, m, f=None): ) ) - count += 1 + count += 1 future_deriv = client.gather(future_deriv) + derivs += np.sum(future_deriv, axis=0) return derivs From f807df00340719c4325b9b644e91f0a38ca81b17 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 3 Feb 2025 22:23:46 -0800 Subject: [PATCH 80/84] Missing count increment --- simpeg/dask/objective_function.py | 1 + 1 file changed, 1 insertion(+) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 7d46b966d7..24e2c769fa 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -140,6 +140,7 @@ def __call__(self, m, f=None): workers=worker, ) ) + count += 1 values = self.client.gather(values) return np.sum(values) From 997bbe041edf665fcf76622c54c609e70d35b5de Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 4 Feb 2025 08:48:32 -0800 Subject: [PATCH 81/84] Remove self.Ainv in IP inversions --- .../static/induced_polarization/simulation.py | 7 ++++--- .../static/induced_polarization/simulation_2d.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py index 1a1171cffe..28f3c12205 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py @@ -17,7 +17,7 @@ numcodecs.blosc.use_threads = False -def fields(self, m=None): +def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m @@ -40,8 +40,9 @@ def fields(self, m=None): scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f) self._scale = scale.dobs - self.Ainv = Ainv - + self._stashed_fields = f + if return_Ainv: + return f, Ainv return f diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py index 0a535b2af6..9d9b4b2657 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation_2d.py @@ -10,7 +10,7 @@ from ..resistivity.simulation_2d import compute_J, getSourceTerm -def fields(self, m=None): +def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m @@ -39,8 +39,9 @@ def fields(self, m=None): scale[src, rx] = 1.0 / rx.eval(src, self.mesh, f_fwd) self._scale = scale.dobs - self.Ainv = Ainv - + self._stashed_fields = f + if return_Ainv: + return f, Ainv return f From 1035c9339c368100ad3227106d3010ff8b56f073 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 4 Feb 2025 16:00:32 -0800 Subject: [PATCH 82/84] Add ProgressBar to imports --- simpeg/dask/potential_fields/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/simpeg/dask/potential_fields/base.py b/simpeg/dask/potential_fields/base.py index 0455d90f0d..87c6cd4219 100644 --- a/simpeg/dask/potential_fields/base.py +++ b/simpeg/dask/potential_fields/base.py @@ -4,6 +4,7 @@ from dask.distributed import get_client import os from dask import delayed, array, config +from dask.diagnostics import ProgressBar from ..utils import compute_chunk_sizes From 4166931b3224456d2e33513065f98d8530b4793c Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 5 Feb 2025 09:48:53 -0800 Subject: [PATCH 83/84] Cleanups from review --- .../frequency_domain/simulation.py | 4 +- .../time_domain/simulation.py | 4 +- simpeg/dask/objective_function.py | 43 +++----- simpeg/meta/dask_sim.py | 98 +------------------ simpeg/meta/simulation.py | 2 - 5 files changed, 22 insertions(+), 129 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 692eabd63a..09a50e61a2 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -268,7 +268,9 @@ def compute_J(self, m, f=None): else: blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] - for block_derivs_chunks, addresses_chunks in zip(blocks_receiver_derivs, blocks): + for block_derivs_chunks, addresses_chunks in zip( + blocks_receiver_derivs, blocks, strict=True + ): Jmatrix = parallel_block_compute( simulation, m, diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 5509417eb4..fe68e234f1 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -183,7 +183,7 @@ def compute_J(self, m, f=None): continue for ind, (block, field_deriv) in enumerate( - zip(blocks, times_field_derivs[tInd + 1]) + zip(blocks, times_field_derivs[tInd + 1], strict=True) ): ATinv_df_duT_v[ind] = get_field_deriv_block( self, @@ -434,7 +434,7 @@ def get_field_deriv_block( updated_ATinv_df_duT_v = [] for (_, arrays), field_deriv, ATinv_chunk, (columns, local_ind) in zip( - block, field_derivs, ATinv_df_duT_v, indices + block, field_derivs, ATinv_df_duT_v, indices, strict=True ): if len(ATinv_chunk) == 0: diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 24e2c769fa..cf5f70fa0f 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -6,39 +6,26 @@ from simpeg.utils import validate_list_of_types -# import os -# from time import time -# -# OUTFILE = os.getcwd() + "/update.txt" -# -# -# def write_message(message, mode="a"): -# with open(OUTFILE, mode) as f: -# f.write(message + "\n") -# -# -# write_message("Starting", mode="w+") - - -def _calc_fields(objfct, model): + +def _calc_fields(objfct, _): return objfct.simulation.fields(m=objfct.simulation.model) -def _calc_dpred(objfct, model): +def _calc_dpred(objfct, _): return objfct.simulation.dpred(m=objfct.simulation.model) -def _calc_residual(objfct, model): +def _calc_residual(objfct, _): return objfct.W * ( objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model) ) -def _deriv(objfct, multiplier, model): +def _deriv(objfct, multiplier, _): return multiplier * objfct.deriv(objfct.simulation.model) -def _deriv2(objfct, multiplier, model, v): +def _deriv2(objfct, multiplier, _, v): return multiplier * objfct.deriv2(objfct.simulation.model, v) @@ -46,7 +33,7 @@ def _store_model(objfct, model): objfct.simulation.model = model -def _get_jtj_diag(objfct, model): +def _get_jtj_diag(objfct, _): jtj = objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W) return jtj.flatten() @@ -56,12 +43,15 @@ def _validate_type_or_future_of_type( objects, obj_type, client, - workers=None, + workers: list[str] | None = None, return_workers=False, ): - # try: - # # validate as a list of things that need to be sent. - workers = [(worker.worker_address,) for worker in client.cluster.workers.values()] + + if workers is None: + workers = [ + (worker.worker_address,) for worker in client.cluster.workers.values() + ] + objects = validate_list_of_types( property_name, objects, obj_type, ensure_unique=True ) @@ -126,7 +116,7 @@ def __call__(self, m, f=None): values = [] count = 0 for futures in self._futures: - for objfct, worker in zip(futures, self._workers): + for objfct, worker in zip(futures, self._workers, strict=True): if self.multipliers[count] == 0.0: continue @@ -186,9 +176,6 @@ def deriv(self, m, f=None): client = self.client m_future = self._m_as_future - # if f is None: - # f = self.fields(m) - derivs = 0.0 count = 0 diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index 53389e43a2..df3a714db1 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -70,11 +70,6 @@ def _reduce(client, operation, items): return client.gather(items[0]) -# def _set_worker(obj, worker): -# obj.worker = worker -# return obj - - def _validate_type_or_future_of_type( property_name, objects, @@ -85,11 +80,10 @@ def _validate_type_or_future_of_type( ): try: # validate as a list of things that need to be sent. - # workers = [(worker.worker_address,) for worker in client.cluster.workers.values()] objects = validate_list_of_types( property_name, objects, obj_type, ensure_unique=True ) - # objects[0].simulation.simulations[0].worker = workers[0] + if workers is None: objects = client.scatter(objects) else: @@ -261,7 +255,7 @@ def check_mapping(mapping, sim, model_len): raise ValueError("All mappings must have the same input length") if np.any(error_checks == 2): raise ValueError( - f"Simulations and mappings at indices {np.where(error_checks==2)}" + f"Simulations and mappings at indices {np.where(error_checks == 2)}" f" are inconsistent." ) @@ -454,94 +448,6 @@ def getJtJdiag(self, m, W=None, f=None): return self._jtjdiag -# -# def _compute_j(sim, model): -# sim.model = model -# jmatrix = getattr(sim, "_Jmatrix", None) -# -# if jmatrix is None: -# jmatrix = sim.compute_J(model) -# -# return jmatrix -# -# -# def set_jmatrix(sim, jmatrix): -# sim._Jmatrix = jmatrix -# return sim - - -# class DaskMetaSimulationExplicit(DaskMetaSimulation): -# clean_on_model_update = ["_Jmatrix", "_stashed_fields"] -# -# def fields(self, m): -# self.model = m -# -# if getattr(self, "_stashed_fields", None) is not None: -# return self._stashed_fields -# -# client = self.client -# m_future = self._m_as_future -# # The above should pass the model to all the internal simulations. -# f = [] -# simulations = [] -# for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers): -# # jmatrix = client.submit( -# # _compute_j, -# # sim, -# # m_future, -# # workers=worker, -# # ) -# # sim = client.submit(set_jmatrix, sim, jmatrix, workers=worker) -# f.append( -# client.submit( -# _calc_fields, -# mapping, -# sim, -# m_future, -# self._repeat_sim, -# workers=worker, -# ) -# ) -# simulations.append(sim) -# -# self._stashed_fields = f -# # self.simulations = simulations -# return f -# -# def getJtJdiag(self, m, W=None, f=None): -# self.model = m -# m_future = self._m_as_future -# if getattr(self, "_jtjdiag", None) is None: -# if W is None: -# W = np.ones(self.survey.nD) -# else: -# W = W.diagonal() -# jtj_diag = [] -# client = self.client -# if f is None: -# f = self.fields(m) -# for i, (mapping, sim, worker, field) in enumerate( -# zip(self.mappings, self.simulations, self._workers, f) -# ): -# sim_w = W[self._data_offsets[i] : self._data_offsets[i + 1]] -# -# jtj_diag.append( -# client.submit( -# _get_jtj_diag, -# mapping, -# sim, -# m_future, -# field, -# sim_w, -# self._repeat_sim, -# workers=worker, -# ) -# ) -# self._jtjdiag = _reduce(client, add, jtj_diag) -# -# return self._jtjdiag - - class DaskSumMetaSimulation(DaskMetaSimulation, SumMetaSimulation): """A dask distributed version of :class:`.SumMetaSimulation`. diff --git a/simpeg/meta/simulation.py b/simpeg/meta/simulation.py index f0ddd27c96..5ede0799c5 100644 --- a/simpeg/meta/simulation.py +++ b/simpeg/meta/simulation.py @@ -307,8 +307,6 @@ def getJtJdiag(self, m, W=None, f=None): # (i.e. projections, multipliers, etc.). # It is usually close within a scaling factor for others, whose accuracy is controlled # by how diagonally dominant JtJ is. - # if f is None: - # f = self.fields(m) for i, (mapping, sim) in enumerate(zip(self.mappings, self.simulations)): if self._repeat_sim: sim.model = mapping * self.model From b13da0fb95dadfa94a4191b6e9ccafbb6ea06654 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 5 Feb 2025 09:59:03 -0800 Subject: [PATCH 84/84] Expose public attribute spatialP on dc receivers --- .../static/resistivity/receivers.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/simpeg/electromagnetics/static/resistivity/receivers.py b/simpeg/electromagnetics/static/resistivity/receivers.py index ab1324d8ed..3607a4049b 100644 --- a/simpeg/electromagnetics/static/resistivity/receivers.py +++ b/simpeg/electromagnetics/static/resistivity/receivers.py @@ -30,6 +30,7 @@ def __init__( projField="phi", **kwargs, ): + self.spatialP = None super(BaseRx, self).__init__(locations=locations, **kwargs) self.orientation = orientation @@ -410,15 +411,15 @@ def getP(self, mesh, projected_grid, transpose=False): P, the interpolation matrix """ - if mesh.n_cells in self._Ps: - return self._Ps[mesh.n_cells] + if getattr(self, "spatialP", None) is not None: + return self.spatialP P0 = mesh.get_interpolation_matrix(self.locations[0], projected_grid) P1 = mesh.get_interpolation_matrix(self.locations[1], projected_grid) P = P0 - P1 if self.storeProjections: - self._Ps[mesh.n_cells] = P + self.spatialP = P if transpose: P = P.toarray().T @@ -489,12 +490,12 @@ def getP(self, mesh, projected_grid): P, the interpolation matrix """ - if mesh.n_cells in self._Ps: - return self._Ps[mesh.n_cells] + if getattr(self, "spatialP", None) is not None: + return self.spatialP P = mesh.get_interpolation_matrix(self.locations, projected_grid) if self.storeProjections: - self._Ps[mesh.n_cells] = P + self.spatialP = P return P