From 9153281400f37a82df0b51b00feaab774333aa9e Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 8 Aug 2025 12:20:51 -0700 Subject: [PATCH 01/14] Create ConcurrentComboMisfit for parallel computes --- simpeg/dask/objective_function.py | 210 ++++++++++++++++++++++++++++++ 1 file changed, 210 insertions(+) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 4b9addd8d7..9b0d70b76f 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -1,6 +1,8 @@ from ..objective_function import ComboObjectiveFunction, BaseObjectiveFunction import numpy as np +from concurrent.futures import ProcessPoolExecutor, as_completed + from dask.distributed import Client from ..data_misfit import L2DataMisfit @@ -456,3 +458,211 @@ def broadcast_updates(self, updates: dict): ) ) self.client.gather(stores) # blocking call to ensure all models were stored + + +class ConcurrentComboMisfits(ComboObjectiveFunction): + """ + A composite objective function for distributed computing. + """ + + def __init__( + self, + objfcts: list[BaseObjectiveFunction], + multipliers=None, + **kwargs, + ): + self._model: np.ndarray | None = None + + super().__init__(objfcts=objfcts, multipliers=multipliers, **kwargs) + + def __call__(self, m, f=None): + self.model = m + + futures = [] + values = [] + count = 0 + with ProcessPoolExecutor() as executor: + for objfct in self.objfcts: + if self.multipliers[count] == 0.0: + continue + + values.append( + executor.submit(_calc_objective, objfct, self.multipliers[count], m) + ) + count += 1 + + for future in as_completed(futures): + values.append(future.result()) + + return np.sum(values) + + def deriv(self, m, f=None): + """ + First derivative of the composite objective function is the sum of the + derivatives of each objective function in the list, weighted by their + respective multplier. + + :param numpy.ndarray m: model + :param SimPEG.Fields f: Fields object (if applicable) + """ + self.model = m + + futures = [] + derivs = 0.0 + count = 0 + with ProcessPoolExecutor() as executor: + for objfct in self.objfcts: + if self.multipliers[count] == 0.0: # don't evaluate the fct + continue + + futures.append( + executor.submit( + _deriv, + objfct, + self.multipliers[count], + m, + ) + ) + + count += 1 + + for future in as_completed(futures): + derivs += future.result() + + return derivs + + def deriv2(self, m, v=None, f=None): + """ + Second derivative of the composite objective function is the sum of the + second derivatives of each objective function in the list, weighted by + their respective multplier. + + :param numpy.ndarray m: model + :param numpy.ndarray v: vector we are multiplying by + :param SimPEG.Fields f: Fields object (if applicable) + """ + self.model = m + + futures = [] + derivs = 0.0 + count = 0 + + with ProcessPoolExecutor() as executor: + for objfct in self.objfcts: + if self.multipliers[count] == 0.0: # don't evaluate the fct + continue + + futures.append( + executor.submit(_deriv2, objfct, self.multipliers[count], m, v) + ) + count += 1 + + for future in as_completed(futures): + derivs += future.result() + + return derivs + + def get_dpred(self, m, f=None): + """ + Request calculation of predicted data from all simulations. + """ + self.model = m + + dpred = [] + futures = [] + with ProcessPoolExecutor() as executor: + for objfct in self.objfcts: + futures.append(executor.submit(_calc_dpred, objfct, m)) + + for future in as_completed(futures): + dpred.append(future.result()) + + return dpred + + def getJtJdiag(self, m, f=None): + """ + Request calculation of the diagonal of JtJ from all simulations. + """ + self.model = m + + if getattr(self, "_jtjdiag", None) is None: + + jtj_diag = 0.0 + futures = [] + with ProcessPoolExecutor() as executor: + for objfct in self.objfcts: + futures.append(executor.submit(_get_jtj_diag, objfct, m)) + + for future in as_completed(futures): + jtj_diag += future.result() + + self._jtjdiag = jtj_diag + + return self._jtjdiag + + def fields(self, m): + """ + Request calculation of fields from all simulations. + + Store list of futures for fields in self._stashed_fields. + """ + self.model = m + + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields + # The above should pass the model to all the internal simulations. + futures = [] + fields = [] + with ProcessPoolExecutor() as executor: + for objfct in self.objfcts: + futures.append(executor.submit(_calc_fields, objfct, m)) + for future in as_completed(futures): + fields.append(future.result()) + + self._stashed_fields = fields + return fields + + @property + def model(self): + return self._model + + @model.setter + def model(self, value): + # Only send the model to the internal simulations if it was updated. + if ( + isinstance(value, np.ndarray) + and isinstance(self.model, np.ndarray) + and np.allclose(value, self.model) + ): + return + + self._stashed_fields = None + self._jtjdiag = None + + stores = [] + with ProcessPoolExecutor() as executor: + for objfct in self.objfcts: + stores.append(executor.submit(_store_model, objfct, value)) + + # blocking call to ensure all models were stored + for store in as_completed(stores): + store.result() + + self._model = value + + def residuals(self, m, f=None): + """ + Compute the residual for the data misfit. + """ + self.model = m + + futures = [] + residuals = [] + with ProcessPoolExecutor() as executor: + for objfct in self.objfcts: + futures.append(executor.submit(_calc_residual, objfct, m)) + + for future in as_completed(futures): + residuals.append(future.result()) + + return residuals From 245f620eae3bca3434f07608dd5e051e605e9055 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 13 Aug 2025 08:22:37 -0700 Subject: [PATCH 02/14] Flatten the values --- simpeg/directives/_save_geoh5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simpeg/directives/_save_geoh5.py b/simpeg/directives/_save_geoh5.py index bc289b3f46..15329d0f95 100644 --- a/simpeg/directives/_save_geoh5.py +++ b/simpeg/directives/_save_geoh5.py @@ -287,7 +287,7 @@ def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noq values = prop[ii, cc, :] if self.sorting is not None: - values = values[self.sorting] + values = values[self.sorting].flatten() label = self._channel_label(ii, channel) channel_name, base_name = self.get_names( From 42bcb54a3f18dd824401dd9292b204ae0636eb35 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 21 Aug 2025 13:44:53 -0700 Subject: [PATCH 03/14] Switch to dask.delayed operations. Rename classes --- simpeg/dask/objective_function.py | 167 ++++++++++-------------------- 1 file changed, 54 insertions(+), 113 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 9b0d70b76f..806ff01039 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -1,9 +1,9 @@ from ..objective_function import ComboObjectiveFunction, BaseObjectiveFunction import numpy as np -from concurrent.futures import ProcessPoolExecutor, as_completed from dask.distributed import Client +from dask import array, delayed, compute from ..data_misfit import L2DataMisfit from simpeg.utils import validate_list_of_types @@ -91,7 +91,6 @@ def _validate_type_or_future_of_type( futures = [] for work in workload: - for obj, worker in zip(work, workers): futures.append( client.submit( @@ -108,7 +107,7 @@ def _validate_type_or_future_of_type( return workload -class DaskComboMisfits(ComboObjectiveFunction): +class DistributedComboMisfits(ComboObjectiveFunction): """ A composite objective function for distributed computing. """ @@ -460,7 +459,7 @@ def broadcast_updates(self, updates: dict): self.client.gather(stores) # blocking call to ensure all models were stored -class ConcurrentComboMisfits(ComboObjectiveFunction): +class DaskComboMisfits(ComboObjectiveFunction): """ A composite objective function for distributed computing. """ @@ -469,6 +468,7 @@ def __init__( self, objfcts: list[BaseObjectiveFunction], multipliers=None, + worker: str | None = None, **kwargs, ): self._model: np.ndarray | None = None @@ -479,22 +479,17 @@ def __call__(self, m, f=None): self.model = m futures = [] - values = [] count = 0 - with ProcessPoolExecutor() as executor: - for objfct in self.objfcts: - if self.multipliers[count] == 0.0: - continue - values.append( - executor.submit(_calc_objective, objfct, self.multipliers[count], m) - ) - count += 1 + delayed_call = delayed(_calc_objective) + for objfct in self.objfcts: + if self.multipliers[count] == 0.0: + continue - for future in as_completed(futures): - values.append(future.result()) + futures.append(delayed_call(objfct, self.multipliers[count], m)) + count += 1 - return np.sum(values) + return np.sum(compute(futures)[0]) def deriv(self, m, f=None): """ @@ -508,28 +503,29 @@ def deriv(self, m, f=None): self.model = m futures = [] - derivs = 0.0 + count = 0 - with ProcessPoolExecutor() as executor: - for objfct in self.objfcts: - if self.multipliers[count] == 0.0: # don't evaluate the fct - continue - futures.append( - executor.submit( - _deriv, + delayed_call = delayed(_deriv) + for objfct in self.objfcts: + if self.multipliers[count] == 0.0: # don't evaluate the fct + continue + + futures.append( + array.from_delayed( + delayed_call( objfct, self.multipliers[count], m, - ) + ), + shape=m.shape, + dtype=float, ) + ) - count += 1 - - for future in as_completed(futures): - derivs += future.result() + count += 1 - return derivs + return array.sum(futures, axis=0).compute() def deriv2(self, m, v=None, f=None): """ @@ -544,23 +540,23 @@ def deriv2(self, m, v=None, f=None): self.model = m futures = [] - derivs = 0.0 count = 0 - with ProcessPoolExecutor() as executor: - for objfct in self.objfcts: - if self.multipliers[count] == 0.0: # don't evaluate the fct - continue + delayed_call = delayed(_deriv2) + for objfct in self.objfcts: + if self.multipliers[count] == 0.0: # don't evaluate the fct + continue - futures.append( - executor.submit(_deriv2, objfct, self.multipliers[count], m, v) + futures.append( + array.from_delayed( + delayed_call(objfct, self.multipliers[count], m, v), + shape=m.shape, + dtype=float, ) - count += 1 - - for future in as_completed(futures): - derivs += future.result() + ) + count += 1 - return derivs + return array.sum(futures, axis=0).compute() def get_dpred(self, m, f=None): """ @@ -568,16 +564,13 @@ def get_dpred(self, m, f=None): """ self.model = m - dpred = [] futures = [] - with ProcessPoolExecutor() as executor: - for objfct in self.objfcts: - futures.append(executor.submit(_calc_dpred, objfct, m)) + delayed_call = delayed(_calc_dpred) - for future in as_completed(futures): - dpred.append(future.result()) + for objfct in self.objfcts: + futures.append(delayed_call(objfct, m)) - return dpred + return compute(futures)[0] def getJtJdiag(self, m, f=None): """ @@ -587,68 +580,19 @@ def getJtJdiag(self, m, f=None): if getattr(self, "_jtjdiag", None) is None: - jtj_diag = 0.0 futures = [] - with ProcessPoolExecutor() as executor: - for objfct in self.objfcts: - futures.append(executor.submit(_get_jtj_diag, objfct, m)) - - for future in as_completed(futures): - jtj_diag += future.result() - - self._jtjdiag = jtj_diag - - return self._jtjdiag - - def fields(self, m): - """ - Request calculation of fields from all simulations. - - Store list of futures for fields in self._stashed_fields. - """ - self.model = m - - if getattr(self, "_stashed_fields", None) is not None: - return self._stashed_fields - # The above should pass the model to all the internal simulations. - futures = [] - fields = [] - with ProcessPoolExecutor() as executor: - for objfct in self.objfcts: - futures.append(executor.submit(_calc_fields, objfct, m)) - for future in as_completed(futures): - fields.append(future.result()) - - self._stashed_fields = fields - return fields - - @property - def model(self): - return self._model + delayed_call = delayed(_get_jtj_diag) - @model.setter - def model(self, value): - # Only send the model to the internal simulations if it was updated. - if ( - isinstance(value, np.ndarray) - and isinstance(self.model, np.ndarray) - and np.allclose(value, self.model) - ): - return - - self._stashed_fields = None - self._jtjdiag = None - - stores = [] - with ProcessPoolExecutor() as executor: for objfct in self.objfcts: - stores.append(executor.submit(_store_model, objfct, value)) + futures.append( + array.from_delayed( + delayed_call(objfct, m), shape=m.shape, dtype=float + ) + ) - # blocking call to ensure all models were stored - for store in as_completed(stores): - store.result() + self._jtjdiag = array.sum(futures, axis=0).compute() - self._model = value + return self._jtjdiag def residuals(self, m, f=None): """ @@ -657,12 +601,9 @@ def residuals(self, m, f=None): self.model = m futures = [] - residuals = [] - with ProcessPoolExecutor() as executor: - for objfct in self.objfcts: - futures.append(executor.submit(_calc_residual, objfct, m)) - for future in as_completed(futures): - residuals.append(future.result()) + delayed_call = delayed(_calc_residual) + for objfct in self.objfcts: + futures.append(delayed_call(objfct, m)) - return residuals + return compute(futures)[0] From bb705cd2cf29f3a9b7547e2db92cb3f2bdaaef64 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 21 Aug 2025 13:51:19 -0700 Subject: [PATCH 04/14] Bring back broadcasting of model --- simpeg/dask/objective_function.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 806ff01039..baeb7da280 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -607,3 +607,26 @@ def residuals(self, m, f=None): futures.append(delayed_call(objfct, m)) return compute(futures)[0] + + @property + def model(self): + return self._model + + @model.setter + def model(self, value): + # Only send the model to the internal simulations if it was updated. + if ( + isinstance(value, np.ndarray) + and isinstance(self.model, np.ndarray) + and np.allclose(value, self.model) + ): + return + + self._jtjdiag = None + + stores = [] + delayed_call = delayed(_store_model) + for objfct in self.objfcts: + stores.append(delayed_call(objfct, value)) + compute(stores) + self._model = value From 3ac48e22eb2c2d3736987fd638fae972564f9868 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 21 Aug 2025 15:43:34 -0700 Subject: [PATCH 05/14] Fix DistributedCombo --- simpeg/dask/inverse_problem.py | 6 +- simpeg/dask/objective_function.py | 134 ++++++++++++++++++++---------- 2 files changed, 93 insertions(+), 47 deletions(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index aeb2da9878..7d015f8b2f 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -1,7 +1,7 @@ from ..inverse_problem import BaseInvProblem import numpy as np -from .objective_function import DaskComboMisfits +from .objective_function import DaskComboMisfits, DistributedComboMisfits from scipy.sparse.linalg import LinearOperator from ..regularization import WeightedLeastSquares, Sparse from ..objective_function import ComboObjectiveFunction @@ -12,7 +12,7 @@ def get_dpred(self, m, f=None): dpreds = [] - if isinstance(self.dmisfit, DaskComboMisfits): + if isinstance(self.dmisfit, DaskComboMisfits | DistributedComboMisfits): return self.dmisfit.get_dpred(m, f=f) for objfct in self.dmisfit.objfcts: @@ -31,7 +31,7 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): self.dpred = self.get_dpred(m) residuals = [] - if isinstance(self.dmisfit, DaskComboMisfits): + if isinstance(self.dmisfit, DaskComboMisfits | DistributedComboMisfits): residuals = self.dmisfit.residuals(m) else: for (_, objfct), pred in zip(self.dmisfit, self.dpred): diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index baeb7da280..6c40609289 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -1,8 +1,13 @@ -from ..objective_function import ComboObjectiveFunction, BaseObjectiveFunction +from ..objective_function import ( + ComboObjectiveFunction, + BaseObjectiveFunction, + _validate_multiplier, + _check_length_objective_funcs_multipliers, +) import numpy as np -from dask.distributed import Client +from dask.distributed import Client, Future from dask import array, delayed, compute from ..data_misfit import L2DataMisfit @@ -47,8 +52,7 @@ def _setter_broadcast(objfct, key, value): setattr(objfct, key, value) for sim in objfct.simulation.simulations: - if hasattr(sim, key): - setattr(sim, key, value) + setattr(sim, key, value) def _get_jtj_diag(objfct, _): @@ -56,6 +60,14 @@ def _get_jtj_diag(objfct, _): return jtj.flatten() +def _set_worker(objfct, worker): + """ + Set the worker for the objective function. + """ + for sim in objfct.simulation.simulations: + sim.worker = worker + + def _validate_type_or_future_of_type( property_name, objects, @@ -74,22 +86,23 @@ def _validate_type_or_future_of_type( property_name, objects, obj_type, ensure_unique=True ) workload = [[]] - lookup = {} + count = 0 for obj in objects: if count == len(workers): count = 0 workload.append([]) - obj.simulation.simulations[0].worker = workers[count] - future = client.scatter([obj], workers=workers[count])[0] - lookup[obj] = (future, workers[count]) - if hasattr(obj, "name"): - future.name = obj.name + + if isinstance(obj, Future): + future = obj + else: + future = client.scatter([obj], workers=workers[count])[0] workload[-1].append(future) count += 1 futures = [] + assignments = [] for work in workload: for obj, worker in zip(work, workers): futures.append( @@ -97,12 +110,15 @@ def _validate_type_or_future_of_type( lambda v: not isinstance(v, obj_type), obj, workers=worker ) ) + assignments.append(client.submit(_set_worker, obj, worker)) + + client.gather(assignments) is_not_obj = np.array(client.gather(futures)) if np.any(is_not_obj): raise TypeError(f"{property_name} futures must be an instance of {obj_type}") if return_workers: - return workload, workers, lookup + return workload, workers else: return workload @@ -114,7 +130,7 @@ class DistributedComboMisfits(ComboObjectiveFunction): def __init__( self, - objfcts: list[BaseObjectiveFunction], + objfcts: list[BaseObjectiveFunction] | list[Future], multipliers=None, client: Client | None = None, workers: list[str] | None = None, @@ -124,7 +140,13 @@ def __init__( self.client = client self.workers = workers - super().__init__(objfcts=objfcts, multipliers=multipliers, **kwargs) + if multipliers is None: + multipliers = len(objfcts) * [1] + + super().__init__(**kwargs) + + self.objfcts = objfcts + self.multipliers = np.array(multipliers, dtype=float) def __call__(self, m, f=None): self.model = m @@ -392,10 +414,10 @@ def objfcts(self): def objfcts(self, objfcts): client = self.client - futures, workers, lookup = _validate_type_or_future_of_type( + futures, workers = _validate_type_or_future_of_type( "objfcts", objfcts, - L2DataMisfit, + (L2DataMisfit, Future), client, workers=self.workers, return_workers=True, @@ -405,10 +427,33 @@ def objfcts(self, objfcts): self._futures = futures self._workers = workers - self._lookup = { - misfit.simulation: (future, worker) - for misfit, (future, worker) in lookup.items() - } + @property + def multipliers(self): + r"""Multipliers for the objective functions. + + For a composite objective function :math:`\phi`, that is, a weighted sum of + objective functions :math:`\phi_i` with multipliers :math:`c_i` such that + + .. math:: + \phi = \sum_{i = 1}^N c_i \phi_i, + + this method returns the multipliers :math:`c_i` in + the same order of the ``objfcts``. + + Returns + ------- + list of int + Multipliers for the objective functions. + """ + return self._multipliers + + @multipliers.setter + def multipliers(self, value): + """Set multipliers attribute after checking if they are valid.""" + for multiplier in value: + _validate_multiplier(multiplier) + _check_length_objective_funcs_multipliers(self.objfcts, value) + self._multipliers = value def residuals(self, m, f=None): """ @@ -435,28 +480,29 @@ def residuals(self, m, f=None): return residuals - def broadcast_updates(self, updates: dict): - """ - Set the attributes of the objective functions and simulations - """ - stores = [] - client = self.client - for fun, (key, value) in updates.items(): - if fun not in self._lookup: - continue - - future, worker = self._lookup[fun] - - stores.append( - client.submit( - _setter_broadcast, - future, - key, - value, - workers=worker, - ) - ) - self.client.gather(stores) # blocking call to ensure all models were stored + # + # def broadcast_updates(self, updates: dict): + # """ + # Set the attributes of the objective functions and simulations + # """ + # stores = [] + # client = self.client + # for fun, (key, value) in updates.items(): + # if fun not in self._lookup: + # continue + # + # future, worker = self._lookup[fun] + # + # stores.append( + # client.submit( + # _setter_broadcast, + # future, + # key, + # value, + # workers=worker, + # ) + # ) + # self.client.gather(stores) # blocking call to ensure all models were stored class DaskComboMisfits(ComboObjectiveFunction): @@ -525,7 +571,7 @@ def deriv(self, m, f=None): count += 1 - return array.sum(futures, axis=0).compute() + return array.vstack(futures).sum(axis=0).compute() def deriv2(self, m, v=None, f=None): """ @@ -556,7 +602,7 @@ def deriv2(self, m, v=None, f=None): ) count += 1 - return array.sum(futures, axis=0).compute() + return array.vstack(futures).sum(axis=0).compute() def get_dpred(self, m, f=None): """ @@ -590,7 +636,7 @@ def getJtJdiag(self, m, f=None): ) ) - self._jtjdiag = array.sum(futures, axis=0).compute() + self._jtjdiag = array.vstack(futures).sum(axis=0).compute() return self._jtjdiag From 2f17ab883182e1d88bd4cc6628695b0b80673e8b Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 22 Aug 2025 15:12:34 -0700 Subject: [PATCH 06/14] Generalize functions for combos. --- simpeg/dask/objective_function.py | 63 ++++++++++++++++++++++++++++--- simpeg/directives/directives.py | 11 +++++- 2 files changed, 66 insertions(+), 8 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 6c40609289..13792768f5 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -15,10 +15,24 @@ def _calc_fields(objfct, _): + if isinstance(objfct, ComboObjectiveFunction): + fields = [] + for objfct_ in objfct.objfcts: + fields.append(_calc_fields(objfct_, _)) + + return fields + return objfct.simulation.fields(m=objfct.simulation.model) def _calc_dpred(objfct, _): + if isinstance(objfct, ComboObjectiveFunction): + dpreds = [] + for objfct_ in objfct.objfcts: + dpreds.append(_calc_dpred(objfct_, _)) + + return np.hstack(dpreds) + return objfct.simulation.dpred(m=objfct.simulation.model) @@ -27,21 +41,46 @@ def _calc_objective(objfct, multiplier, model): def _calc_residual(objfct, _): + if isinstance(objfct, ComboObjectiveFunction): + residuals = 0.0 + for objfct_ in objfct.objfcts: + residuals += _calc_residual(objfct_, _) + + return np.hstack(residuals) + return objfct.W * ( objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model) ) def _deriv(objfct, multiplier, _): - return multiplier * objfct.deriv(objfct.simulation.model) + if isinstance(objfct, ComboObjectiveFunction): + deriv = 0.0 + for multiplier_, objfct_ in objfct: + deriv += _deriv(objfct_, multiplier_, _) + else: + deriv = objfct.deriv(objfct.simulation.model) + return multiplier * deriv def _deriv2(objfct, multiplier, _, v): - return multiplier * objfct.deriv2(objfct.simulation.model, v) + + if isinstance(objfct, ComboObjectiveFunction): + deriv2 = 0.0 + for multiplier_, objfct_ in objfct: + deriv2 += _deriv2(objfct_, multiplier_, _, v) + else: + deriv2 = objfct.deriv2(objfct.simulation.model, v) + return multiplier * deriv2 def _store_model(objfct, model): - objfct.simulation.model = model + + if isinstance(objfct, ComboObjectiveFunction): + for objfct_ in objfct.objfcts: + _store_model(objfct_, model) + else: + objfct.simulation.model = model def _setter_broadcast(objfct, key, value): @@ -56,6 +95,13 @@ def _setter_broadcast(objfct, key, value): def _get_jtj_diag(objfct, _): + if isinstance(objfct, ComboObjectiveFunction): + jtj = 0.0 + for objfct_ in objfct.objfcts: + jtj += _get_jtj_diag(objfct_, _) + + return jtj + jtj = objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W) return jtj.flatten() @@ -64,8 +110,13 @@ def _set_worker(objfct, worker): """ Set the worker for the objective function. """ - for sim in objfct.simulation.simulations: - sim.worker = worker + if isinstance(objfct, ComboObjectiveFunction): + for objfct_ in objfct.objfcts: + _set_worker(objfct_, worker) + + else: + for sim in objfct.simulation.simulations: + sim.worker = worker def _validate_type_or_future_of_type( @@ -417,7 +468,7 @@ def objfcts(self, objfcts): futures, workers = _validate_type_or_future_of_type( "objfcts", objfcts, - (L2DataMisfit, Future), + (L2DataMisfit, Future, ComboObjectiveFunction), client, workers=self.workers, return_workers=True, diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 21fade5e26..cd4a0a7564 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -230,7 +230,7 @@ def survey(self) -> list["BaseSurvey"]: list of simpeg.survey.Survey Survey for all data misfits. """ - return [objfcts.simulation.survey for objfcts in self.dmisfit.objfcts] + return [simulation.survey for simulation in self.simulation] @property def simulation(self) -> list["BaseSimulation"]: @@ -245,7 +245,14 @@ def simulation(self) -> list["BaseSimulation"]: list of simpeg.simulation.BaseSimulation Simulation for all data misfits. """ - return [objfcts.simulation for objfcts in self.dmisfit.objfcts] + simulations = [] + for objfct in self.dmisfit.objfcts: + if isinstance(objfct, ComboObjectiveFunction): + simulations += [o.simulation for o in objfct.objfcts] + + else: + simulations.append(objfct.simulation) + return simulations def initialize(self): """Initialize inversion parameter(s) according to directive.""" From 6fee94d8fda811450b34c525b534eccfaa100bc8 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 22 Aug 2025 15:41:23 -0700 Subject: [PATCH 07/14] Revert changes --- simpeg/directives/directives.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index cd4a0a7564..21fade5e26 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -230,7 +230,7 @@ def survey(self) -> list["BaseSurvey"]: list of simpeg.survey.Survey Survey for all data misfits. """ - return [simulation.survey for simulation in self.simulation] + return [objfcts.simulation.survey for objfcts in self.dmisfit.objfcts] @property def simulation(self) -> list["BaseSimulation"]: @@ -245,14 +245,7 @@ def simulation(self) -> list["BaseSimulation"]: list of simpeg.simulation.BaseSimulation Simulation for all data misfits. """ - simulations = [] - for objfct in self.dmisfit.objfcts: - if isinstance(objfct, ComboObjectiveFunction): - simulations += [o.simulation for o in objfct.objfcts] - - else: - simulations.append(objfct.simulation) - return simulations + return [objfcts.simulation for objfcts in self.dmisfit.objfcts] def initialize(self): """Initialize inversion parameter(s) according to directive.""" From bb2bd8f2740563c78e0b9ec2dfa92473559a0754 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 25 Aug 2025 08:24:17 -0700 Subject: [PATCH 08/14] Adjust directives for Combo of Combos --- simpeg/dask/inverse_problem.py | 21 +++++-- simpeg/dask/objective_function.py | 89 +++++++++++++++++++--------- simpeg/directives/_regularization.py | 7 +-- simpeg/directives/_save_geoh5.py | 16 +---- simpeg/directives/directives.py | 13 ++-- 5 files changed, 90 insertions(+), 56 deletions(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 7d015f8b2f..2bfb5f9a7a 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -9,15 +9,28 @@ from simpeg.version import __version__ as simpeg_version +def get_nested_predicted(objfcts, m, f=None): + dpreds = [] + + for objfct in objfcts: + + if isinstance(objfct, ComboObjectiveFunction): + dpreds += get_nested_predicted(objfct.objfcts, m, f=f) + + else: + dpred = objfct.simulation.dpred(m, f=f) + dpreds += [np.asarray(dpred)] + + return dpreds + + def get_dpred(self, m, f=None): dpreds = [] - if isinstance(self.dmisfit, DaskComboMisfits | DistributedComboMisfits): + if isinstance(self.dmisfit, DistributedComboMisfits): return self.dmisfit.get_dpred(m, f=f) - for objfct in self.dmisfit.objfcts: - dpred = objfct.simulation.dpred(m, f=f) - dpreds += [np.asarray(dpred)] + dpreds = get_nested_predicted(self.dmisfit.objfcts, m, f=f) return dpreds diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 13792768f5..755e975f5b 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -74,6 +74,16 @@ def _deriv2(objfct, multiplier, _, v): return multiplier * deriv2 +def _get_attr(objfct, key): + if isinstance(objfct, ComboObjectiveFunction): + attr = [] + for objfct_ in objfct.objfcts: + attr.append(_get_attr(objfct_, key)) + return attr + + return objfct.nP + + def _store_model(objfct, model): if isinstance(objfct, ComboObjectiveFunction): @@ -161,9 +171,10 @@ def _validate_type_or_future_of_type( lambda v: not isinstance(v, obj_type), obj, workers=worker ) ) - assignments.append(client.submit(_set_worker, obj, worker)) + assignments.append(client.submit(_set_worker, obj, worker, workers=worker)) client.gather(assignments) + is_not_obj = np.array(client.gather(futures)) if np.any(is_not_obj): raise TypeError(f"{property_name} futures must be an instance of {obj_type}") @@ -195,7 +206,7 @@ def __init__( multipliers = len(objfcts) * [1] super().__init__(**kwargs) - + self._nP = None self.objfcts = objfcts self.multipliers = np.array(multipliers, dtype=float) @@ -206,8 +217,8 @@ def __call__(self, m, f=None): values = [] count = 0 - for futures in self._futures: - for objfct, worker in zip(futures, self._workers, strict=True): + for futures in self._workloads: + for future, worker in zip(futures, self._workers, strict=True): if self.multipliers[count] == 0.0: continue @@ -215,7 +226,7 @@ def __call__(self, m, f=None): values.append( client.submit( _calc_objective, - objfct, + future, self.multipliers[count], m_future, workers=worker, @@ -226,6 +237,30 @@ def __call__(self, m, f=None): values = self.client.gather(values) return np.sum(values) + @property + def nP(self): + """Number of model parameters. + + Returns + ------- + int + Number of model parameters. + """ + if self._nP is None: + nP = [] + for futures in self._workloads: + for future, worker in zip(futures, self._workers, strict=True): + nP.append( + self.client.submit( + _get_attr, + future, + "nP", + workers=worker, + ) + ) + self._nP = np.sum(self.client.gather(nP)) + return self._nP + @property def client(self): """ @@ -270,16 +305,16 @@ def deriv(self, m, f=None): derivs = 0.0 count = 0 - for futures in self._futures: + for futures in self._workloads: future_deriv = [] - for objfct, worker in zip(futures, self._workers): + for future, worker in zip(futures, self._workers, strict=True): if self.multipliers[count] == 0.0: # don't evaluate the fct continue future_deriv.append( client.submit( _deriv, - objfct, + future, self.multipliers[count], m_future, workers=worker, @@ -311,17 +346,17 @@ def deriv2(self, m, v=None, f=None): derivs = 0.0 count = 0 - for futures in self._futures: + for futures in self._workloads: future_derivs = [] - for objfct, worker in zip(futures, self._workers): + for future, worker in zip(futures, self._workers, strict=True): if self.multipliers[count] == 0.0: # don't evaluate the fct continue future_derivs.append( client.submit( _deriv2, - objfct, + future, self.multipliers[count], m_future, v_future, @@ -346,13 +381,13 @@ def get_dpred(self, m, f=None): m_future = self._m_as_future dpred = [] - for futures in self._futures: + for futures in self._workloads: future_preds = [] - for objfct, worker in zip(futures, self._workers): + for future, worker in zip(futures, self._workers, strict=True): future_preds.append( client.submit( _calc_dpred, - objfct, + future, m_future, workers=worker, ) @@ -372,14 +407,14 @@ def getJtJdiag(self, m, f=None): jtj_diag = 0.0 client = self.client - for futures in self._futures: + for futures in self._workloads: work = [] - for objfct, worker in zip(futures, self._workers): + for future, worker in zip(futures, self._workers, strict=True): work.append( client.submit( _get_jtj_diag, - objfct, + future, m_future, workers=worker, ) @@ -406,13 +441,13 @@ def fields(self, m): # The above should pass the model to all the internal simulations. f = [] - for futures in self._futures: + for futures in self._workloads: f.append([]) - for objfct, worker in zip(futures, self._workers): + for future, worker in zip(futures, self._workers, strict=True): f[-1].append( client.submit( _calc_fields, - objfct, + future, m_future, workers=worker, ) @@ -441,12 +476,12 @@ def model(self, value): [self._m_as_future] = client.scatter([value], broadcast=True) stores = [] - for futures in self._futures: - for objfct, worker in zip(futures, self._workers): + for futures in self._workloads: + for future, worker in zip(futures, self._workers, strict=True): stores.append( client.submit( _store_model, - objfct, + future, self._m_as_future, workers=worker, ) @@ -475,7 +510,7 @@ def objfcts(self, objfcts): ) self._objfcts = objfcts - self._futures = futures + self._workloads = futures self._workers = workers @property @@ -516,13 +551,13 @@ def residuals(self, m, f=None): m_future = self._m_as_future residuals = [] - for futures in self._futures: + for futures in self._workloads: future_residuals = [] - for objfct, worker in zip(futures, self._workers): + for future, worker in zip(futures, self._workers, strict=True): future_residuals.append( client.submit( _calc_residual, - objfct, + future, m_future, workers=worker, ) diff --git a/simpeg/directives/_regularization.py b/simpeg/directives/_regularization.py index ec14c5b6d6..53a561fa09 100644 --- a/simpeg/directives/_regularization.py +++ b/simpeg/directives/_regularization.py @@ -218,12 +218,7 @@ def misfit_from_chi_factor(self, chi_factor: float) -> float: chi_factor : float Chi factor to compute the target misfit from. """ - value = 0 - - for survey in self.survey: - value += survey.nD * chi_factor - - return value + return self.dmisfit.nP def adjust_cooling_schedule(self): """ diff --git a/simpeg/directives/_save_geoh5.py b/simpeg/directives/_save_geoh5.py index 15329d0f95..ace59e3d34 100644 --- a/simpeg/directives/_save_geoh5.py +++ b/simpeg/directives/_save_geoh5.py @@ -16,21 +16,7 @@ from geoh5py.groups import UIJsonGroup from geoh5py.objects import ObjectBase from geoh5py.ui_json.utils import fetch_active_workspace - - -def compute_JtJdiags(data_misfit, m): - if hasattr(data_misfit, "getJtJdiag"): - return data_misfit.getJtJdiag(m) - else: - jtj_diags = [] - for dmisfit in data_misfit.objfcts: - jtj_diags.append(dmisfit.getJtJdiag(m)) - - jtj_diag = np.zeros_like(jtj_diags[0]) - for multiplier, diag in zip(data_misfit.multipliers, jtj_diags): - jtj_diag += multiplier * diag - - return np.asarray(jtj_diag) +from simpeg.directives.directives import compute_JtJdiags class BaseSaveGeoH5(InversionDirective, ABC): diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 21fade5e26..a0626f4aed 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -54,12 +54,17 @@ def compute_JtJdiags(data_misfit, m): if hasattr(data_misfit, "getJtJdiag"): return data_misfit.getJtJdiag(m) else: - jtj_diags = [] + jtj_diag_list = [] + jtj_diag = np.zeros_like(m) + for dmisfit in data_misfit.objfcts: - jtj_diags.append(dmisfit.getJtJdiag(m)) + if isinstance(dmisfit, ComboObjectiveFunction): + jtj_diag += compute_JtJdiags(dmisfit, m) + + else: + jtj_diag_list.append(dmisfit.getJtJdiag(m)) - jtj_diag = np.zeros_like(jtj_diags[0]) - for multiplier, diag in zip(data_misfit.multipliers, jtj_diags): + for multiplier, diag in zip(data_misfit.multipliers, jtj_diag_list): jtj_diag += multiplier * diag return np.asarray(jtj_diag) From 87ba0792cb8be4d7d9c82a7b584e829ffb5d00c3 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 25 Aug 2025 09:12:01 -0700 Subject: [PATCH 09/14] Adapt inverse_problem with Combo of Combos --- simpeg/dask/inverse_problem.py | 32 ++++++++++++++++++++++------ simpeg/directives/_regularization.py | 2 +- simpeg/directives/directives.py | 2 +- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index 2bfb5f9a7a..c09490faab 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -1,10 +1,10 @@ from ..inverse_problem import BaseInvProblem import numpy as np -from .objective_function import DaskComboMisfits, DistributedComboMisfits +from .objective_function import DistributedComboMisfits from scipy.sparse.linalg import LinearOperator from ..regularization import WeightedLeastSquares, Sparse -from ..objective_function import ComboObjectiveFunction +from ..objective_function import ComboObjectiveFunction, BaseObjectiveFunction from simpeg.utils import call_hooks from simpeg.version import __version__ as simpeg_version @@ -32,23 +32,43 @@ def get_dpred(self, m, f=None): dpreds = get_nested_predicted(self.dmisfit.objfcts, m, f=f) - return dpreds + return np.hstack(dpreds) BaseInvProblem.get_dpred = get_dpred +def get_nested_residuals( + objfcts: list[BaseObjectiveFunction, ComboObjectiveFunction], dpreds, start=0 +): + residuals = [] + + for objfct in objfcts: + + if isinstance(objfct, ComboObjectiveFunction): + res = get_nested_residuals(objfct.objfcts, dpreds, start=start) + residuals.append(res) + start += res.shape[0] + + else: + residuals.append( + objfct.W * (objfct.data.dobs - dpreds[start : start + objfct.data.nD]) + ) + start += objfct.data.nD + + return np.hstack(residuals) + + def dask_evalFunction(self, m, return_g=True, return_H=True): """evalFunction(m, return_g=True, return_H=True)""" self.model = m self.dpred = self.get_dpred(m) residuals = [] - if isinstance(self.dmisfit, DaskComboMisfits | DistributedComboMisfits): + if isinstance(self.dmisfit, DistributedComboMisfits): residuals = self.dmisfit.residuals(m) else: - for (_, objfct), pred in zip(self.dmisfit, self.dpred): - residuals.append(objfct.W * (objfct.data.dobs - pred)) + residuals = get_nested_residuals(self.dmisfit.objfcts, self.dpred) phi_d = 0.0 for residual in residuals: diff --git a/simpeg/directives/_regularization.py b/simpeg/directives/_regularization.py index 53a561fa09..16dca5a93e 100644 --- a/simpeg/directives/_regularization.py +++ b/simpeg/directives/_regularization.py @@ -218,7 +218,7 @@ def misfit_from_chi_factor(self, chi_factor: float) -> float: chi_factor : float Chi factor to compute the target misfit from. """ - return self.dmisfit.nP + return self.invProb.dpred.shape[0] * chi_factor def adjust_cooling_schedule(self): """ diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index a0626f4aed..44a183184d 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -208,7 +208,7 @@ def dmisfit(self) -> BaseObjectiveFunction: The data misfit associated with the directive. """ if getattr(self, "_dmisfit", None) is None: - self.dmisfit = self.invProb.dmisfit # go through the setter + self._dmisfit = self.invProb.dmisfit # go through the setter return self._dmisfit @dmisfit.setter From b5a31e6c09c6197b35203529c6561317d697bd96 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 25 Aug 2025 13:34:21 -0700 Subject: [PATCH 10/14] Make dpred return residuals --- simpeg/dask/inverse_problem.py | 71 +++++++++++++------------------ simpeg/dask/objective_function.py | 39 ++++++++++++++--- simpeg/directives/_save_geoh5.py | 5 ++- 3 files changed, 66 insertions(+), 49 deletions(-) diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index c09490faab..e5e3c2d092 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -4,75 +4,64 @@ from .objective_function import DistributedComboMisfits from scipy.sparse.linalg import LinearOperator from ..regularization import WeightedLeastSquares, Sparse -from ..objective_function import ComboObjectiveFunction, BaseObjectiveFunction +from ..objective_function import ComboObjectiveFunction from simpeg.utils import call_hooks from simpeg.version import __version__ as simpeg_version -def get_nested_predicted(objfcts, m, f=None): +def get_nested_predicted(objfcts, m, f=None, return_residuals=False): dpreds = [] - + residuals = [] for objfct in objfcts: if isinstance(objfct, ComboObjectiveFunction): - dpreds += get_nested_predicted(objfct.objfcts, m, f=f) + nesting = get_nested_predicted( + objfct.objfcts, m, f=f, return_residuals=return_residuals + ) + if return_residuals: + dpreds += nesting[0] + residuals += nesting[1] + else: + dpreds += nesting else: dpred = objfct.simulation.dpred(m, f=f) dpreds += [np.asarray(dpred)] + if return_residuals: + residual = objfct.W * (objfct.data.dobs - dpred) + residuals += [np.asarray(residual)] + + if return_residuals: + return dpreds, residuals return dpreds -def get_dpred(self, m, f=None): - dpreds = [] - +def get_dpred(self, m, f=None, return_residuals=False): if isinstance(self.dmisfit, DistributedComboMisfits): - return self.dmisfit.get_dpred(m, f=f) + results = self.dmisfit.get_dpred(m, f=f, return_residuals=return_residuals) + else: + results = get_nested_predicted( + self.dmisfit.objfcts, m, f=f, return_residuals=return_residuals + ) - dpreds = get_nested_predicted(self.dmisfit.objfcts, m, f=f) + if return_residuals: + return np.hstack(results[0]), np.hstack(results[1]) - return np.hstack(dpreds) + return np.hstack(results) BaseInvProblem.get_dpred = get_dpred -def get_nested_residuals( - objfcts: list[BaseObjectiveFunction, ComboObjectiveFunction], dpreds, start=0 -): - residuals = [] - - for objfct in objfcts: - - if isinstance(objfct, ComboObjectiveFunction): - res = get_nested_residuals(objfct.objfcts, dpreds, start=start) - residuals.append(res) - start += res.shape[0] - - else: - residuals.append( - objfct.W * (objfct.data.dobs - dpreds[start : start + objfct.data.nD]) - ) - start += objfct.data.nD - - return np.hstack(residuals) - - def dask_evalFunction(self, m, return_g=True, return_H=True): """evalFunction(m, return_g=True, return_H=True)""" - self.model = m - self.dpred = self.get_dpred(m) - residuals = [] - if isinstance(self.dmisfit, DistributedComboMisfits): - residuals = self.dmisfit.residuals(m) - else: - residuals = get_nested_residuals(self.dmisfit.objfcts, self.dpred) + if not np.allclose(self.model, m): + self.model = m + self.dpred, self.residuals = self.get_dpred(m, return_residuals=True) - phi_d = 0.0 - for residual in residuals: - phi_d += np.vdot(residual, residual) + phi_d = np.vdot(self.residuals, self.residuals) reg2Deriv = [] if isinstance(self.reg, ComboObjectiveFunction): diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 755e975f5b..6465a292eb 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -25,15 +25,29 @@ def _calc_fields(objfct, _): return objfct.simulation.fields(m=objfct.simulation.model) -def _calc_dpred(objfct, _): +def _calc_dpred(objfct, _, return_residuals=False): if isinstance(objfct, ComboObjectiveFunction): dpreds = [] + residuals = [] for objfct_ in objfct.objfcts: - dpreds.append(_calc_dpred(objfct_, _)) + if return_residuals: + dpred_, residual_ = _calc_dpred(objfct_, _, return_residuals) + dpreds.append(dpred_) + residuals.append(residual_) + else: + dpreds.append(_calc_dpred(objfct_, _)) + + if return_residuals: + return np.hstack(dpreds), np.hstack(residuals) return np.hstack(dpreds) - return objfct.simulation.dpred(m=objfct.simulation.model) + dpred = objfct.simulation.dpred(m=objfct.simulation.model) + if return_residuals: + residual = objfct.W * (objfct.data.dobs - dpred) + return dpred, residual + + return dpred def _calc_objective(objfct, multiplier, model): @@ -371,7 +385,7 @@ def deriv2(self, m, v=None, f=None): return derivs - def get_dpred(self, m, f=None): + def get_dpred(self, m, f=None, return_residuals=False): """ Request calculation of predicted data from all simulations. """ @@ -380,7 +394,7 @@ def get_dpred(self, m, f=None): client = self.client m_future = self._m_as_future dpred = [] - + residuals = [] for futures in self._workloads: future_preds = [] for future, worker in zip(futures, self._workers, strict=True): @@ -389,12 +403,23 @@ def get_dpred(self, m, f=None): _calc_dpred, future, m_future, + return_residuals, workers=worker, ) ) - dpred += client.gather(future_preds) + results = client.gather(future_preds) + + for result in results: + if return_residuals: + dpred += [result[0]] + residuals += [result[1]] + else: + dpred += [result] + + if return_residuals: + return np.hstack(dpred), np.hstack(residuals) - return dpred + return np.hstack(dpred) def getJtJdiag(self, m, f=None): """ diff --git a/simpeg/directives/_save_geoh5.py b/simpeg/directives/_save_geoh5.py index ace59e3d34..9ad798ab67 100644 --- a/simpeg/directives/_save_geoh5.py +++ b/simpeg/directives/_save_geoh5.py @@ -355,8 +355,11 @@ def get_values(self, values: list[np.ndarray] | None): else: dpred = getattr(self.invProb, "dpred", None) if dpred is None: - dpred = self.invProb.get_dpred(self.invProb.model) + dpred, residuals = self.invProb.get_dpred( + self.invProb.model, return_residuals=True + ) self.invProb.dpred = dpred + self.invProb.residuals = residuals if self.joint_index is not None: dpred = [dpred[ind] for ind in self.joint_index] From 5b3a8ca0d1a87cb716f39169e56778edb2c2fe36 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 25 Aug 2025 13:57:06 -0700 Subject: [PATCH 11/14] Remove unused methods --- simpeg/dask/objective_function.py | 232 ------------------------------ 1 file changed, 232 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 6465a292eb..a52040ce70 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -8,7 +8,6 @@ import numpy as np from dask.distributed import Client, Future -from dask import array, delayed, compute from ..data_misfit import L2DataMisfit from simpeg.utils import validate_list_of_types @@ -88,16 +87,6 @@ def _deriv2(objfct, multiplier, _, v): return multiplier * deriv2 -def _get_attr(objfct, key): - if isinstance(objfct, ComboObjectiveFunction): - attr = [] - for objfct_ in objfct.objfcts: - attr.append(_get_attr(objfct_, key)) - return attr - - return objfct.nP - - def _store_model(objfct, model): if isinstance(objfct, ComboObjectiveFunction): @@ -251,30 +240,6 @@ def __call__(self, m, f=None): values = self.client.gather(values) return np.sum(values) - @property - def nP(self): - """Number of model parameters. - - Returns - ------- - int - Number of model parameters. - """ - if self._nP is None: - nP = [] - for futures in self._workloads: - for future, worker in zip(futures, self._workers, strict=True): - nP.append( - self.client.submit( - _get_attr, - future, - "nP", - workers=worker, - ) - ) - self._nP = np.sum(self.client.gather(nP)) - return self._nP - @property def client(self): """ @@ -590,200 +555,3 @@ def residuals(self, m, f=None): residuals += client.gather(future_residuals) return residuals - - # - # def broadcast_updates(self, updates: dict): - # """ - # Set the attributes of the objective functions and simulations - # """ - # stores = [] - # client = self.client - # for fun, (key, value) in updates.items(): - # if fun not in self._lookup: - # continue - # - # future, worker = self._lookup[fun] - # - # stores.append( - # client.submit( - # _setter_broadcast, - # future, - # key, - # value, - # workers=worker, - # ) - # ) - # self.client.gather(stores) # blocking call to ensure all models were stored - - -class DaskComboMisfits(ComboObjectiveFunction): - """ - A composite objective function for distributed computing. - """ - - def __init__( - self, - objfcts: list[BaseObjectiveFunction], - multipliers=None, - worker: str | None = None, - **kwargs, - ): - self._model: np.ndarray | None = None - - super().__init__(objfcts=objfcts, multipliers=multipliers, **kwargs) - - def __call__(self, m, f=None): - self.model = m - - futures = [] - count = 0 - - delayed_call = delayed(_calc_objective) - for objfct in self.objfcts: - if self.multipliers[count] == 0.0: - continue - - futures.append(delayed_call(objfct, self.multipliers[count], m)) - count += 1 - - return np.sum(compute(futures)[0]) - - def deriv(self, m, f=None): - """ - First derivative of the composite objective function is the sum of the - derivatives of each objective function in the list, weighted by their - respective multplier. - - :param numpy.ndarray m: model - :param SimPEG.Fields f: Fields object (if applicable) - """ - self.model = m - - futures = [] - - count = 0 - - delayed_call = delayed(_deriv) - for objfct in self.objfcts: - if self.multipliers[count] == 0.0: # don't evaluate the fct - continue - - futures.append( - array.from_delayed( - delayed_call( - objfct, - self.multipliers[count], - m, - ), - shape=m.shape, - dtype=float, - ) - ) - - count += 1 - - return array.vstack(futures).sum(axis=0).compute() - - def deriv2(self, m, v=None, f=None): - """ - Second derivative of the composite objective function is the sum of the - second derivatives of each objective function in the list, weighted by - their respective multplier. - - :param numpy.ndarray m: model - :param numpy.ndarray v: vector we are multiplying by - :param SimPEG.Fields f: Fields object (if applicable) - """ - self.model = m - - futures = [] - count = 0 - - delayed_call = delayed(_deriv2) - for objfct in self.objfcts: - if self.multipliers[count] == 0.0: # don't evaluate the fct - continue - - futures.append( - array.from_delayed( - delayed_call(objfct, self.multipliers[count], m, v), - shape=m.shape, - dtype=float, - ) - ) - count += 1 - - return array.vstack(futures).sum(axis=0).compute() - - def get_dpred(self, m, f=None): - """ - Request calculation of predicted data from all simulations. - """ - self.model = m - - futures = [] - delayed_call = delayed(_calc_dpred) - - for objfct in self.objfcts: - futures.append(delayed_call(objfct, m)) - - return compute(futures)[0] - - def getJtJdiag(self, m, f=None): - """ - Request calculation of the diagonal of JtJ from all simulations. - """ - self.model = m - - if getattr(self, "_jtjdiag", None) is None: - - futures = [] - delayed_call = delayed(_get_jtj_diag) - - for objfct in self.objfcts: - futures.append( - array.from_delayed( - delayed_call(objfct, m), shape=m.shape, dtype=float - ) - ) - - self._jtjdiag = array.vstack(futures).sum(axis=0).compute() - - return self._jtjdiag - - def residuals(self, m, f=None): - """ - Compute the residual for the data misfit. - """ - self.model = m - - futures = [] - - delayed_call = delayed(_calc_residual) - for objfct in self.objfcts: - futures.append(delayed_call(objfct, m)) - - return compute(futures)[0] - - @property - def model(self): - return self._model - - @model.setter - def model(self, value): - # Only send the model to the internal simulations if it was updated. - if ( - isinstance(value, np.ndarray) - and isinstance(self.model, np.ndarray) - and np.allclose(value, self.model) - ): - return - - self._jtjdiag = None - - stores = [] - delayed_call = delayed(_store_model) - for objfct in self.objfcts: - stores.append(delayed_call(objfct, value)) - compute(stores) - self._model = value From 83ca7e66cec5fbf05b627b7a139b3fb3129160d7 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 26 Aug 2025 08:32:27 -0700 Subject: [PATCH 12/14] Bump geoh5py --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 90f0cafecf..6db8372fe4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ keywords = ["geophysics", "inverse problem"] dependencies = [ "discretize>=0.11", "geoana>=0.7.0", - "geoh5py>=0.12.0a1, <0.13.dev", + "geoh5py>=0.13.0a1, <0.14.dev", "libdlf", "matplotlib", "numpy>=1.22", From 85488c25ff5e3933afbd3faac399da5075ac2bb8 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 26 Aug 2025 16:06:02 -0700 Subject: [PATCH 13/14] Revert geoh5py version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6db8372fe4..90f0cafecf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ keywords = ["geophysics", "inverse problem"] dependencies = [ "discretize>=0.11", "geoana>=0.7.0", - "geoh5py>=0.13.0a1, <0.14.dev", + "geoh5py>=0.12.0a1, <0.13.dev", "libdlf", "matplotlib", "numpy>=1.22", From 3b31bb140a6015fa9463099ab7dc1126df80adae Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 28 Aug 2025 17:34:39 -0700 Subject: [PATCH 14/14] Add mapping derivatives back --- simpeg/dask/electromagnetics/time_domain/simulation_1d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation_1d.py b/simpeg/dask/electromagnetics/time_domain/simulation_1d.py index 8eeb6058d2..4729756ea0 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation_1d.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation_1d.py @@ -13,7 +13,7 @@ def Jmatrix(self): """ if getattr(self, "_Jmatrix", None) is None: Jmat = self.getJ(self.model) - self._Jmatrix = Jmat["ds"] + self._Jmatrix = Jmat["ds"] * self.sigmaDeriv return self._Jmatrix