From 60bcbe68b4dfb7a7316fa44173d8af292b8ce9c2 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 18 Apr 2025 13:25:43 -0700 Subject: [PATCH 1/2] Add broadcast update to future simulations --- simpeg/dask/objective_function.py | 44 +++++++++++++++++++++++++++++ simpeg/directives/_vector_models.py | 10 +++++-- simpeg/maps/_base.py | 9 +++--- 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index cf5f70fa0f..f10bc965cd 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -15,6 +15,10 @@ def _calc_dpred(objfct, _): return objfct.simulation.dpred(m=objfct.simulation.model) +def _calc_objective(objfct, multiplier, model): + return multiplier * objfct(model) + + def _calc_residual(objfct, _): return objfct.W * ( objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model) @@ -33,6 +37,18 @@ def _store_model(objfct, model): objfct.simulation.model = model +def _setter_broadcast(objfct, key, value): + """ + Broadcast a value to all workers. + """ + if hasattr(objfct, key): + setattr(objfct, key, value) + + for sim in objfct.simulation.simulations: + if hasattr(sim, key): + setattr(sim, key, value) + + def _get_jtj_diag(objfct, _): jtj = objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W) return jtj.flatten() @@ -387,6 +403,11 @@ def objfcts(self, objfcts): self._futures = futures self._workers = workers + self._lookup = { + obj.simulation: (future, worker) + for future, worker, obj in zip(futures[0], workers, objfcts) + } + def residuals(self, m, f=None): """ Compute the residual for the data misfit. @@ -411,3 +432,26 @@ def residuals(self, m, f=None): residuals += client.gather(future_residuals) return residuals + + def broadcast_updates(self, updates: dict): + """ + Set the attributes of the objective functions and simulations + """ + stores = [] + client = self.client + for fun, (key, value) in updates.items(): + if fun not in self._lookup: + continue + + objfct, worker = self._lookup[fun] + + stores.append( + client.submit( + _setter_broadcast, + objfct, + key, + value, + workers=worker, + ) + ) + self.client.gather(stores) # blocking call to ensure all models were stored diff --git a/simpeg/directives/_vector_models.py b/simpeg/directives/_vector_models.py index 66678ae127..9f219d4be8 100644 --- a/simpeg/directives/_vector_models.py +++ b/simpeg/directives/_vector_models.py @@ -214,13 +214,19 @@ def endIter(self): ) self.opt.upper[indices[nC:]] = np.inf + updates = {} for simulation in self.simulations: if isinstance(simulation, MetaSimulation): - for sim in simulation.simulations: - sim.chiMap = SphericalSystem() * sim.chiMap + updates[simulation] = ( + "chiMap", + SphericalSystem() * simulation.simulations[0].chiMap, + ) else: simulation.chiMap = SphericalSystem() * simulation.chiMap + if hasattr(self.dmisfit, "client"): + self.dmisfit.broadcast_updates(updates) + # Add and update directives for directive in self.inversion.directiveList.dList: if ( diff --git a/simpeg/maps/_base.py b/simpeg/maps/_base.py index 2671b2a87f..40de5d503c 100644 --- a/simpeg/maps/_base.py +++ b/simpeg/maps/_base.py @@ -925,11 +925,10 @@ def __init__(self, mesh=None, nP=None, **kwargs): self.model = None def sphericalDeriv(self, model): - if getattr(self, "model", None) is None: - self.model = model - - if getattr(self, "_sphericalDeriv", None) is None or not all( - self.model == model + if ( + getattr(self, "_sphericalDeriv", None) is None + or getattr(self, "model", None) is None + or not all(self.model == model) ): self.model = model From 730fa00953165f6fd5a87c3942d05cf8e0a7e164 Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 20 Apr 2025 20:29:48 -0700 Subject: [PATCH 2/2] Fix for non distributed mvi --- simpeg/directives/_vector_models.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/simpeg/directives/_vector_models.py b/simpeg/directives/_vector_models.py index 9f219d4be8..603777954e 100644 --- a/simpeg/directives/_vector_models.py +++ b/simpeg/directives/_vector_models.py @@ -217,10 +217,16 @@ def endIter(self): updates = {} for simulation in self.simulations: if isinstance(simulation, MetaSimulation): - updates[simulation] = ( - "chiMap", - SphericalSystem() * simulation.simulations[0].chiMap, - ) + + if hasattr(self.dmisfit, "client"): + updates[simulation] = ( + "chiMap", + SphericalSystem() * simulation.simulations[0].chiMap, + ) + else: + simulation.simulations[0].chiMap = ( + SphericalSystem() * simulation.simulations[0].chiMap + ) else: simulation.chiMap = SphericalSystem() * simulation.chiMap