Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions simpeg/dask/objective_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def _calc_dpred(objfct, _):
return objfct.simulation.dpred(m=objfct.simulation.model)


def _calc_objective(objfct, multiplier, model):
return multiplier * objfct(model)


def _calc_residual(objfct, _):
return objfct.W * (
objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model)
Expand All @@ -33,6 +37,18 @@ def _store_model(objfct, model):
objfct.simulation.model = model


def _setter_broadcast(objfct, key, value):
"""
Broadcast a value to all workers.
"""
if hasattr(objfct, key):
setattr(objfct, key, value)

for sim in objfct.simulation.simulations:
if hasattr(sim, key):
setattr(sim, key, value)


def _get_jtj_diag(objfct, _):
jtj = objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W)
return jtj.flatten()
Expand Down Expand Up @@ -387,6 +403,11 @@ def objfcts(self, objfcts):
self._futures = futures
self._workers = workers

self._lookup = {
obj.simulation: (future, worker)
for future, worker, obj in zip(futures[0], workers, objfcts)
}

def residuals(self, m, f=None):
"""
Compute the residual for the data misfit.
Expand All @@ -411,3 +432,26 @@ def residuals(self, m, f=None):
residuals += client.gather(future_residuals)

return residuals

def broadcast_updates(self, updates: dict):
"""
Set the attributes of the objective functions and simulations
"""
stores = []
client = self.client
for fun, (key, value) in updates.items():
if fun not in self._lookup:
continue

objfct, worker = self._lookup[fun]

stores.append(
client.submit(
_setter_broadcast,
objfct,
key,
value,
workers=worker,
)
)
self.client.gather(stores) # blocking call to ensure all models were stored
16 changes: 14 additions & 2 deletions simpeg/directives/_vector_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,25 @@ def endIter(self):
)
self.opt.upper[indices[nC:]] = np.inf

updates = {}
for simulation in self.simulations:
if isinstance(simulation, MetaSimulation):
for sim in simulation.simulations:
sim.chiMap = SphericalSystem() * sim.chiMap

if hasattr(self.dmisfit, "client"):
updates[simulation] = (
"chiMap",
SphericalSystem() * simulation.simulations[0].chiMap,
)
else:
simulation.simulations[0].chiMap = (
SphericalSystem() * simulation.simulations[0].chiMap
)
else:
simulation.chiMap = SphericalSystem() * simulation.chiMap

if hasattr(self.dmisfit, "client"):
self.dmisfit.broadcast_updates(updates)

# Add and update directives
for directive in self.inversion.directiveList.dList:
if (
Expand Down
9 changes: 4 additions & 5 deletions simpeg/maps/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,11 +925,10 @@ def __init__(self, mesh=None, nP=None, **kwargs):
self.model = None

def sphericalDeriv(self, model):
if getattr(self, "model", None) is None:
self.model = model

if getattr(self, "_sphericalDeriv", None) is None or not all(
self.model == model
if (
getattr(self, "_sphericalDeriv", None) is None
or getattr(self, "model", None) is None
or not all(self.model == model)
):
self.model = model

Expand Down
Loading