diff --git a/simpeg/dask/electromagnetics/time_domain/simulation_1d.py b/simpeg/dask/electromagnetics/time_domain/simulation_1d.py index 8eeb6058d2..4729756ea0 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation_1d.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation_1d.py @@ -13,7 +13,7 @@ def Jmatrix(self): """ if getattr(self, "_Jmatrix", None) is None: Jmat = self.getJ(self.model) - self._Jmatrix = Jmat["ds"] + self._Jmatrix = Jmat["ds"] * self.sigmaDeriv return self._Jmatrix diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index aeb2da9878..e5e3c2d092 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -1,7 +1,7 @@ from ..inverse_problem import BaseInvProblem import numpy as np -from .objective_function import DaskComboMisfits +from .objective_function import DistributedComboMisfits from scipy.sparse.linalg import LinearOperator from ..regularization import WeightedLeastSquares, Sparse from ..objective_function import ComboObjectiveFunction @@ -9,37 +9,59 @@ from simpeg.version import __version__ as simpeg_version -def get_dpred(self, m, f=None): +def get_nested_predicted(objfcts, m, f=None, return_residuals=False): dpreds = [] + residuals = [] + for objfct in objfcts: + + if isinstance(objfct, ComboObjectiveFunction): + nesting = get_nested_predicted( + objfct.objfcts, m, f=f, return_residuals=return_residuals + ) - if isinstance(self.dmisfit, DaskComboMisfits): - return self.dmisfit.get_dpred(m, f=f) + if return_residuals: + dpreds += nesting[0] + residuals += nesting[1] + else: + dpreds += nesting + else: + dpred = objfct.simulation.dpred(m, f=f) + dpreds += [np.asarray(dpred)] - for objfct in self.dmisfit.objfcts: - dpred = objfct.simulation.dpred(m, f=f) - dpreds += [np.asarray(dpred)] + if return_residuals: + residual = objfct.W * (objfct.data.dobs - dpred) + residuals += [np.asarray(residual)] + if return_residuals: + return dpreds, residuals return dpreds +def 
get_dpred(self, m, f=None, return_residuals=False): + if isinstance(self.dmisfit, DistributedComboMisfits): + results = self.dmisfit.get_dpred(m, f=f, return_residuals=return_residuals) + else: + results = get_nested_predicted( + self.dmisfit.objfcts, m, f=f, return_residuals=return_residuals + ) + + if return_residuals: + return np.hstack(results[0]), np.hstack(results[1]) + + return np.hstack(results) + + BaseInvProblem.get_dpred = get_dpred def dask_evalFunction(self, m, return_g=True, return_H=True): """evalFunction(m, return_g=True, return_H=True)""" - self.model = m - self.dpred = self.get_dpred(m) - residuals = [] - if isinstance(self.dmisfit, DaskComboMisfits): - residuals = self.dmisfit.residuals(m) - else: - for (_, objfct), pred in zip(self.dmisfit, self.dpred): - residuals.append(objfct.W * (objfct.data.dobs - pred)) + if not np.allclose(self.model, m): + self.model = m + self.dpred, self.residuals = self.get_dpred(m, return_residuals=True) - phi_d = 0.0 - for residual in residuals: - phi_d += np.vdot(residual, residual) + phi_d = np.vdot(self.residuals, self.residuals) reg2Deriv = [] if isinstance(self.reg, ComboObjectiveFunction): diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 4b9addd8d7..a52040ce70 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -1,18 +1,52 @@ -from ..objective_function import ComboObjectiveFunction, BaseObjectiveFunction +from ..objective_function import ( + ComboObjectiveFunction, + BaseObjectiveFunction, + _validate_multiplier, + _check_length_objective_funcs_multipliers, +) import numpy as np -from dask.distributed import Client + +from dask.distributed import Client, Future from ..data_misfit import L2DataMisfit from simpeg.utils import validate_list_of_types def _calc_fields(objfct, _): + if isinstance(objfct, ComboObjectiveFunction): + fields = [] + for objfct_ in objfct.objfcts: + fields.append(_calc_fields(objfct_, _)) + + return fields + 
return objfct.simulation.fields(m=objfct.simulation.model) -def _calc_dpred(objfct, _): - return objfct.simulation.dpred(m=objfct.simulation.model) +def _calc_dpred(objfct, _, return_residuals=False): + if isinstance(objfct, ComboObjectiveFunction): + dpreds = [] + residuals = [] + for objfct_ in objfct.objfcts: + + if return_residuals: + dpred_, residual_ = _calc_dpred(objfct_, _, return_residuals) + dpreds.append(dpred_) + residuals.append(residual_) + else: + dpreds.append(_calc_dpred(objfct_, _)) + + if return_residuals: + return np.hstack(dpreds), np.hstack(residuals) + return np.hstack(dpreds) + + dpred = objfct.simulation.dpred(m=objfct.simulation.model) + if return_residuals: + residual = objfct.W * (objfct.data.dobs - dpred) + return dpred, residual + + return dpred def _calc_objective(objfct, multiplier, model): @@ -20,21 +54,46 @@ def _calc_objective(objfct, multiplier, model): def _calc_residual(objfct, _): + if isinstance(objfct, ComboObjectiveFunction): + residuals = [] + for objfct_ in objfct.objfcts: + residuals.append(_calc_residual(objfct_, _)) + + return np.hstack(residuals) + return objfct.W * ( objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model) ) def _deriv(objfct, multiplier, _): - return multiplier * objfct.deriv(objfct.simulation.model) + if isinstance(objfct, ComboObjectiveFunction): + deriv = 0.0 + for multiplier_, objfct_ in objfct: + deriv += _deriv(objfct_, multiplier_, _) + else: + deriv = objfct.deriv(objfct.simulation.model) + return multiplier * deriv def _deriv2(objfct, multiplier, _, v): - return multiplier * objfct.deriv2(objfct.simulation.model, v) + + if isinstance(objfct, ComboObjectiveFunction): + deriv2 = 0.0 + for multiplier_, objfct_ in objfct: + deriv2 += _deriv2(objfct_, multiplier_, _, v) + else: + deriv2 = objfct.deriv2(objfct.simulation.model, v) + return multiplier * deriv2 def _store_model(objfct, model): - objfct.simulation.model = model + + if isinstance(objfct, ComboObjectiveFunction): + for
objfct_ in objfct.objfcts: + _store_model(objfct_, model) + else: + objfct.simulation.model = model def _setter_broadcast(objfct, key, value): @@ -45,15 +104,34 @@ def _setter_broadcast(objfct, key, value): setattr(objfct, key, value) for sim in objfct.simulation.simulations: - if hasattr(sim, key): - setattr(sim, key, value) + setattr(sim, key, value) def _get_jtj_diag(objfct, _): + if isinstance(objfct, ComboObjectiveFunction): + jtj = 0.0 + for multiplier_, objfct_ in objfct: + jtj += multiplier_ * _get_jtj_diag(objfct_, _) + + return jtj + jtj = objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W) return jtj.flatten() +def _set_worker(objfct, worker): + """ + Set the worker for the objective function. + """ + if isinstance(objfct, ComboObjectiveFunction): + for objfct_ in objfct.objfcts: + _set_worker(objfct_, worker) + + else: + for sim in objfct.simulation.simulations: + sim.worker = worker + + def _validate_type_or_future_of_type( property_name, objects, @@ -72,48 +150,52 @@ def _validate_type_or_future_of_type( property_name, objects, obj_type, ensure_unique=True ) workload = [[]] - lookup = {} + count = 0 for obj in objects: if count == len(workers): count = 0 workload.append([]) - obj.simulation.simulations[0].worker = workers[count] - future = client.scatter([obj], workers=workers[count])[0] - lookup[obj] = (future, workers[count]) - if hasattr(obj, "name"): - future.name = obj.name + + if isinstance(obj, Future): + future = obj + else: + future = client.scatter([obj], workers=workers[count])[0] workload[-1].append(future) count += 1 futures = [] + assignments = [] for work in workload: - for obj, worker in zip(work, workers): futures.append( client.submit( lambda v: not isinstance(v, obj_type), obj, workers=worker ) ) + assignments.append(client.submit(_set_worker, obj, worker, workers=worker)) + + client.gather(assignments) + is_not_obj = np.array(client.gather(futures)) if np.any(is_not_obj): raise TypeError(f"{property_name} futures must be an instance of
{obj_type}") if return_workers: - return workload, workers, lookup + return workload, workers else: return workload -class DaskComboMisfits(ComboObjectiveFunction): +class DistributedComboMisfits(ComboObjectiveFunction): """ A composite objective function for distributed computing. """ def __init__( self, - objfcts: list[BaseObjectiveFunction], + objfcts: list[BaseObjectiveFunction] | list[Future], multipliers=None, client: Client | None = None, workers: list[str] | None = None, @@ -123,7 +205,13 @@ def __init__( self.client = client self.workers = workers - super().__init__(objfcts=objfcts, multipliers=multipliers, **kwargs) + if multipliers is None: + multipliers = len(objfcts) * [1] + + super().__init__(**kwargs) + self._nP = None + self.objfcts = objfcts + self.multipliers = np.array(multipliers, dtype=float) def __call__(self, m, f=None): self.model = m @@ -132,8 +220,8 @@ def __call__(self, m, f=None): values = [] count = 0 - for futures in self._futures: - for objfct, worker in zip(futures, self._workers, strict=True): + for futures in self._workloads: + for future, worker in zip(futures, self._workers, strict=True): if self.multipliers[count] == 0.0: continue @@ -141,7 +229,7 @@ def __call__(self, m, f=None): values.append( client.submit( _calc_objective, - objfct, + future, self.multipliers[count], m_future, workers=worker, @@ -196,16 +284,16 @@ def deriv(self, m, f=None): derivs = 0.0 count = 0 - for futures in self._futures: + for futures in self._workloads: future_deriv = [] - for objfct, worker in zip(futures, self._workers): + for future, worker in zip(futures, self._workers, strict=True): if self.multipliers[count] == 0.0: # don't evaluate the fct continue future_deriv.append( client.submit( _deriv, - objfct, + future, self.multipliers[count], m_future, workers=worker, @@ -237,17 +325,17 @@ def deriv2(self, m, v=None, f=None): derivs = 0.0 count = 0 - for futures in self._futures: + for futures in self._workloads: future_derivs = [] - for objfct, 
worker in zip(futures, self._workers): + for future, worker in zip(futures, self._workers, strict=True): if self.multipliers[count] == 0.0: # don't evaluate the fct continue future_derivs.append( client.submit( _deriv2, - objfct, + future, self.multipliers[count], m_future, v_future, @@ -262,7 +350,7 @@ def deriv2(self, m, v=None, f=None): return derivs - def get_dpred(self, m, f=None): + def get_dpred(self, m, f=None, return_residuals=False): """ Request calculation of predicted data from all simulations. """ @@ -271,21 +359,32 @@ def get_dpred(self, m, f=None): client = self.client m_future = self._m_as_future dpred = [] - - for futures in self._futures: + residuals = [] + for futures in self._workloads: future_preds = [] - for objfct, worker in zip(futures, self._workers): + for future, worker in zip(futures, self._workers, strict=True): future_preds.append( client.submit( _calc_dpred, - objfct, + future, m_future, + return_residuals, workers=worker, ) ) - dpred += client.gather(future_preds) + results = client.gather(future_preds) + + for result in results: + if return_residuals: + dpred += [result[0]] + residuals += [result[1]] + else: + dpred += [result] + + if return_residuals: + return np.hstack(dpred), np.hstack(residuals) - return dpred + return np.hstack(dpred) def getJtJdiag(self, m, f=None): """ @@ -298,14 +397,14 @@ def getJtJdiag(self, m, f=None): jtj_diag = 0.0 client = self.client - for futures in self._futures: + for futures in self._workloads: work = [] - for objfct, worker in zip(futures, self._workers): + for future, worker in zip(futures, self._workers, strict=True): work.append( client.submit( _get_jtj_diag, - objfct, + future, m_future, workers=worker, ) @@ -332,13 +431,13 @@ def fields(self, m): # The above should pass the model to all the internal simulations. 
f = [] - for futures in self._futures: + for futures in self._workloads: f.append([]) - for objfct, worker in zip(futures, self._workers): + for future, worker in zip(futures, self._workers, strict=True): f[-1].append( client.submit( _calc_fields, - objfct, + future, m_future, workers=worker, ) @@ -367,12 +466,12 @@ def model(self, value): [self._m_as_future] = client.scatter([value], broadcast=True) stores = [] - for futures in self._futures: - for objfct, worker in zip(futures, self._workers): + for futures in self._workloads: + for future, worker in zip(futures, self._workers, strict=True): stores.append( client.submit( _store_model, - objfct, + future, self._m_as_future, workers=worker, ) @@ -391,23 +490,46 @@ def objfcts(self): def objfcts(self, objfcts): client = self.client - futures, workers, lookup = _validate_type_or_future_of_type( + futures, workers = _validate_type_or_future_of_type( "objfcts", objfcts, - L2DataMisfit, + (L2DataMisfit, Future, ComboObjectiveFunction), client, workers=self.workers, return_workers=True, ) self._objfcts = objfcts - self._futures = futures + self._workloads = futures self._workers = workers - self._lookup = { - misfit.simulation: (future, worker) - for misfit, (future, worker) in lookup.items() - } + @property + def multipliers(self): + r"""Multipliers for the objective functions. + + For a composite objective function :math:`\phi`, that is, a weighted sum of + objective functions :math:`\phi_i` with multipliers :math:`c_i` such that + + .. math:: + \phi = \sum_{i = 1}^N c_i \phi_i, + + this method returns the multipliers :math:`c_i` in + the same order of the ``objfcts``. + + Returns + ------- + list of int + Multipliers for the objective functions. 
+ """ + return self._multipliers + + @multipliers.setter + def multipliers(self, value): + """Set multipliers attribute after checking if they are valid.""" + for multiplier in value: + _validate_multiplier(multiplier) + _check_length_objective_funcs_multipliers(self.objfcts, value) + self._multipliers = value def residuals(self, m, f=None): """ @@ -419,13 +541,13 @@ def residuals(self, m, f=None): m_future = self._m_as_future residuals = [] - for futures in self._futures: + for futures in self._workloads: future_residuals = [] - for objfct, worker in zip(futures, self._workers): + for future, worker in zip(futures, self._workers, strict=True): future_residuals.append( client.submit( _calc_residual, - objfct, + future, m_future, workers=worker, ) @@ -433,26 +555,3 @@ def residuals(self, m, f=None): residuals += client.gather(future_residuals) return residuals - - def broadcast_updates(self, updates: dict): - """ - Set the attributes of the objective functions and simulations - """ - stores = [] - client = self.client - for fun, (key, value) in updates.items(): - if fun not in self._lookup: - continue - - future, worker = self._lookup[fun] - - stores.append( - client.submit( - _setter_broadcast, - future, - key, - value, - workers=worker, - ) - ) - self.client.gather(stores) # blocking call to ensure all models were stored diff --git a/simpeg/directives/_regularization.py b/simpeg/directives/_regularization.py index ec14c5b6d6..16dca5a93e 100644 --- a/simpeg/directives/_regularization.py +++ b/simpeg/directives/_regularization.py @@ -218,12 +218,7 @@ def misfit_from_chi_factor(self, chi_factor: float) -> float: chi_factor : float Chi factor to compute the target misfit from. 
""" - value = 0 - - for survey in self.survey: - value += survey.nD * chi_factor - - return value + return self.invProb.dpred.shape[0] * chi_factor def adjust_cooling_schedule(self): """ diff --git a/simpeg/directives/_save_geoh5.py b/simpeg/directives/_save_geoh5.py index 15329d0f95..9ad798ab67 100644 --- a/simpeg/directives/_save_geoh5.py +++ b/simpeg/directives/_save_geoh5.py @@ -16,21 +16,7 @@ from geoh5py.groups import UIJsonGroup from geoh5py.objects import ObjectBase from geoh5py.ui_json.utils import fetch_active_workspace - - -def compute_JtJdiags(data_misfit, m): - if hasattr(data_misfit, "getJtJdiag"): - return data_misfit.getJtJdiag(m) - else: - jtj_diags = [] - for dmisfit in data_misfit.objfcts: - jtj_diags.append(dmisfit.getJtJdiag(m)) - - jtj_diag = np.zeros_like(jtj_diags[0]) - for multiplier, diag in zip(data_misfit.multipliers, jtj_diags): - jtj_diag += multiplier * diag - - return np.asarray(jtj_diag) +from simpeg.directives.directives import compute_JtJdiags class BaseSaveGeoH5(InversionDirective, ABC): @@ -369,8 +355,11 @@ def get_values(self, values: list[np.ndarray] | None): else: dpred = getattr(self.invProb, "dpred", None) if dpred is None: - dpred = self.invProb.get_dpred(self.invProb.model) + dpred, residuals = self.invProb.get_dpred( + self.invProb.model, return_residuals=True + ) self.invProb.dpred = dpred + self.invProb.residuals = residuals if self.joint_index is not None: dpred = [dpred[ind] for ind in self.joint_index] diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 21fade5e26..44a183184d 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -54,12 +54,17 @@ def compute_JtJdiags(data_misfit, m): if hasattr(data_misfit, "getJtJdiag"): return data_misfit.getJtJdiag(m) else: - jtj_diags = [] + jtj_diag_list = [] + jtj_diag = np.zeros_like(m) + for dmisfit in data_misfit.objfcts: - jtj_diags.append(dmisfit.getJtJdiag(m)) + if isinstance(dmisfit, 
ComboObjectiveFunction): + jtj_diag_list.append(compute_JtJdiags(dmisfit, m)) + + else: + jtj_diag_list.append(dmisfit.getJtJdiag(m)) - jtj_diag = np.zeros_like(jtj_diags[0]) - for multiplier, diag in zip(data_misfit.multipliers, jtj_diags): + for multiplier, diag in zip(data_misfit.multipliers, jtj_diag_list): jtj_diag += multiplier * diag return np.asarray(jtj_diag) @@ -203,7 +208,7 @@ def dmisfit(self) -> BaseObjectiveFunction: The data misfit associated with the directive. """ if getattr(self, "_dmisfit", None) is None: - self.dmisfit = self.invProb.dmisfit # go through the setter + self._dmisfit = self.invProb.dmisfit # set directly; bypasses the setter return self._dmisfit @dmisfit.setter