From 1dfafa5b013d610fdc1d67589400b8fcc5ad8c5f Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 23 Jul 2025 10:26:48 -0700 Subject: [PATCH 1/7] Allow for list of int. Only project active times --- .../electromagnetics/time_domain/receivers.py | 21 ++++++++++++++++--- simpeg/fields.py | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/simpeg/electromagnetics/time_domain/receivers.py b/simpeg/electromagnetics/time_domain/receivers.py index dc2cbb0255..0a0fa0966c 100644 --- a/simpeg/electromagnetics/time_domain/receivers.py +++ b/simpeg/electromagnetics/time_domain/receivers.py @@ -1,5 +1,5 @@ import scipy.sparse as sp - +import numpy as np from ...utils import mkvc, validate_type, validate_direction from discretize.utils import Zero from ...survey import BaseTimeRx @@ -128,6 +128,20 @@ def getTimeP(self, time_mesh, f): return self.timeP + def active_times(self, projection): + """Get active times for the receiver. + + Parameters + ---------- + projection : Sparse matrix + + Returns + ------- + numpy.ndarray + Active times for the receiver. + """ + return np.unique(sp.find(projection)[1]) + def getP(self, mesh, time_mesh, f): """Returns projection matrices as a list for all components collected by the receivers. @@ -153,6 +167,7 @@ def getP(self, mesh, time_mesh, f): Ps = self.getSpatialP(mesh, f) Pt = self.getTimeP(time_mesh, f) + Pt = Pt[:, self.active_times(Pt)] P = sp.kron(Pt, Ps) if self.storeProjections: @@ -180,7 +195,7 @@ def eval(self, src, mesh, time_mesh, f): # noqa: A003 Fields projected to the receiver(s) """ P = self.getP(mesh, time_mesh, f) - f_part = mkvc(f[src, self.projField, :]) + f_part = mkvc(f[src, self.projField, self.active_times(self.timeP)]) return P * f_part def evalDeriv(self, src, mesh, time_mesh, f, v, adjoint=False): @@ -301,7 +316,7 @@ def eval(self, src, mesh, time_mesh, f): # noqa: A003 ) P = self.getP(mesh, time_mesh, f) - f_part = mkvc(f[src, "b", :]) + f_part = mkvc(f[src, "b", self.active_times(self.timeP)]) return P * f_part def getTimeP(self, time_mesh, f): diff --git a/simpeg/fields.py b/simpeg/fields.py index f78c994219..009d076c6e 100644 --- a/simpeg/fields.py +++ b/simpeg/fields.py @@ -501,7 +501,7 @@ def _getField(self, name, ind, src_list): pointerFields = pointerFields.reshape(pointerShape, order="F") # First try to return the function as three arguments (without timeInd) - if timeInd == slice(None, None, None): + if isinstance(timeInd, slice) and timeInd == slice(None, None, None): try: # assume it will take care of integrating over all times return func(pointerFields, srcInd) From de4f98b55bd9047e4ac6d448a06c9949218bd1ca Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 23 Jul 2025 10:58:27 -0700 Subject: [PATCH 2/7] Add parallel compute for dpred on tem simulation --- .../time_domain/simulation.py | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 823840564b..3fd1a42697 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -561,6 +561,87 @@ def compute_rows( return np.vstack(rows) +def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields): + """ + Evaluate the data prediction for a block of sources. + """ + data = [] + for ind in indices: + + receiver_list = sources[ind].receiver_list + if len(receiver_list) == 0: + continue + + for receiver in receiver_list: + data.append(receiver.eval(sources[ind], mesh, time_mesh, fields)) + + return np.hstack(data) + + +def dpred(self, m=None, f=None): + # Docstring inherited from BaseSimulation. + if self.survey is None: + raise AttributeError( + "The survey has not yet been set and is required to compute " + "data. Please set the survey for the simulation: " + "simulation.survey = survey" + ) + + try: + client = get_client() + except ValueError: + client = None + + if f is None: + f = self.fields(m) + + delayed_chunks = [] + + source_block = np.array_split( + np.arange(len(self.survey.source_list)), self.n_threads(client=client) + ) + if client: + mesh = client.scatter(self.mesh, workers=self.worker) + time_mesh = client.scatter(self.time_mesh, workers=self.worker) + fields = client.scatter(f, workers=self.worker) + source_list = client.scatter(self.survey.source_list, workers=self.worker) + else: + mesh = self.mesh + time_mesh = self.time_mesh + delayed_eval = delayed(evaluate_dpred_block) + source_list = self.survey.source_list + fields = f + + for block in source_block: + if len(block) == 0: + continue + + if client: + delayed_chunks.append( + client.submit( + evaluate_dpred_block, + block, + source_list, + mesh, + time_mesh, + fields, + workers=self.worker, + ) + ) + else: + delayed_chunks.append( + delayed_eval(block, source_list, mesh, time_mesh, fields) + ) + + if client: + result = client.gather(delayed_chunks) + else: + result = dask.compute(delayed_chunks)[0] + + return np.hstack(result) + + +Sim.dpred = dpred Sim.fields = fields Sim.getSourceTerm = getSourceTerm Sim.compute_J = compute_J From 599cf2489858631026e494883dd949ed4525e1e0 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 23 Jul 2025 20:16:54 -0700 Subject: [PATCH 3/7] Temp remove of block optimize --- simpeg/dask/utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/simpeg/dask/utils.py b/simpeg/dask/utils.py index 19ba96f1a6..e17c130924 100644 --- a/simpeg/dask/utils.py +++ b/simpeg/dask/utils.py @@ -75,15 +75,15 @@ def get_parallel_blocks( row_index += chunk_size row_count += chunk_size - # Re-split over cpu_count if too few blocks - if len(blocks) < thread_count and optimize: - flatten_blocks = [] - for block in blocks: - flatten_blocks += block - - chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) - return [ - [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0 - ] + # # Re-split over cpu_count if too few blocks + # if len(blocks) < thread_count and optimize: + # flatten_blocks = [] + # for block in blocks: + # flatten_blocks += block + # + # chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) + # return [ + # [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0 + # ] return blocks From af0d5690a1a16b14983039c0ffdd6c930f9dfa77 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 24 Jul 2025 14:53:06 -0700 Subject: [PATCH 4/7] Move workers load to internal serial process --- simpeg/dask/objective_function.py | 374 ++++++++++++++++-------------- 1 file changed, 202 insertions(+), 172 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 4b9addd8d7..8a178995ec 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -5,36 +5,76 @@ from ..data_misfit import L2DataMisfit from simpeg.utils import validate_list_of_types +from simpeg.objective_function import ( + _validate_multiplier, + _check_length_objective_funcs_multipliers, +) -def _calc_fields(objfct, _): - return objfct.simulation.fields(m=objfct.simulation.model) +def _calc_fields(objfcts, _): + blocks = [] + for objfct in objfcts: + blocks.append(objfct.simulation.fields(m=objfct.simulation.model)) + return blocks -def _calc_dpred(objfct, _): - return objfct.simulation.dpred(m=objfct.simulation.model) +def _calc_dpred(objfcts, _): + blocks = [] + for objfct in objfcts: + blocks.append(objfct.simulation.dpred(m=objfct.simulation.model)) -def _calc_objective(objfct, multiplier, model): - return multiplier * objfct(model) + return np.hstack(blocks) -def _calc_residual(objfct, _): - return objfct.W * ( - objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model) - ) +def _calc_objective(objfcts, multipliers, model): + blocks = [] + for multiplier, objfct in zip(multipliers, objfcts): + if multiplier == 0.0: + continue + + blocks.append(multiplier * objfct(model)) + + return np.sum(blocks) + + +def _calc_residual(objfcts, _): + blocks = [] + for objfct in objfcts: + blocks.append( + objfct.W + * (objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model)) + ) + + return np.hstack(blocks) + + +def _deriv(objfcts, multipliers, _): + blocks = [] + for multiplier, objfct in zip(multipliers, objfcts): + if multiplier == 0.0: + continue + + blocks.append(multiplier * objfct.deriv(objfct.simulation.model)) + + return np.sum(blocks, axis=0) + +def _deriv2(objfcts, multipliers, _, v): + blocks = [] + for multiplier, objfct in zip(multipliers, objfcts): -def _deriv(objfct, multiplier, _): - return multiplier * objfct.deriv(objfct.simulation.model) + if multiplier == 0.0: + continue + blocks.append(multiplier * objfct.deriv2(objfct.simulation.model, v)) -def _deriv2(objfct, multiplier, _, v): - return multiplier * objfct.deriv2(objfct.simulation.model, v) + return np.sum(blocks, axis=0) -def _store_model(objfct, model): - objfct.simulation.model = model +def _store_model(objfcts, model): + for objfct in objfcts: + objfct.simulation.model = model def _setter_broadcast(objfct, key, value): @@ -49,9 +89,11 @@ def _setter_broadcast(objfct, key, value): setattr(sim, key, value) -def _get_jtj_diag(objfct, _): - jtj = objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W) - return jtj.flatten() +def _get_jtj_diag(objfcts, _): + arrays = [] + for objfct in objfcts: + arrays.append(objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W)) + return np.sum(arrays, axis=0) def _validate_type_or_future_of_type( @@ -60,7 +102,7 @@ def _validate_type_or_future_of_type( obj_type, client, workers: list[str] | None = None, - return_workers=False, + return_lookup=False, ): if workers is None: @@ -71,26 +113,24 @@ def _validate_type_or_future_of_type( objects = validate_list_of_types( property_name, objects, obj_type, ensure_unique=True ) - workload = [[]] + funs_split = np.array_split(objects, len(workers)) + workload = {worker: [] for worker in workers} lookup = {} - count = 0 - for obj in objects: - if count == len(workers): - count = 0 - workload.append([]) - obj.simulation.simulations[0].worker = workers[count] - future = client.scatter([obj], workers=workers[count])[0] - lookup[obj] = (future, workers[count]) - if hasattr(obj, "name"): - future.name = obj.name - - workload[-1].append(future) - count += 1 + for objfcts, worker in zip(funs_split, workers): + for obj in objfcts: + obj.simulation.simulations[0].worker = worker + future = client.scatter([obj], workers=worker)[0] + + if hasattr(obj, "name"): + future.name = obj.name + + workload[worker].append(future) + lookup[obj] = (future, worker) futures = [] - for work in workload: + for worker, future_list in workload.items(): - for obj, worker in zip(work, workers): + for obj in future_list: futures.append( client.submit( lambda v: not isinstance(v, obj_type), obj, workers=worker @@ -100,8 +140,8 @@ def _validate_type_or_future_of_type( if np.any(is_not_obj): raise TypeError(f"{property_name} futures must be an instance of {obj_type}") - if return_workers: - return workload, workers, lookup + if return_lookup: + return workload, lookup else: return workload @@ -130,26 +170,18 @@ def __call__(self, m, f=None): client = self.client m_future = self._m_as_future - values = [] - count = 0 - for futures in self._futures: - for objfct, worker in zip(futures, self._workers, strict=True): - - if self.multipliers[count] == 0.0: - continue - - values.append( - client.submit( - _calc_objective, - objfct, - self.multipliers[count], - m_future, - workers=worker, - ) + future_values = [] + for worker, futures in self._futures.items(): + future_values.append( + client.submit( + _calc_objective, + futures, + self.multipliers[worker], + m_future, + workers=worker, ) - count += 1 - - values = self.client.gather(values) + ) + values = self.client.gather(future_values) return np.sum(values) @property @@ -192,32 +224,21 @@ def deriv(self, m, f=None): self.model = m client = self.client m_future = self._m_as_future - - derivs = 0.0 - count = 0 - - for futures in self._futures: - future_deriv = [] - for objfct, worker in zip(futures, self._workers): - if self.multipliers[count] == 0.0: # don't evaluate the fct - continue - - future_deriv.append( - client.submit( - _deriv, - objfct, - self.multipliers[count], - m_future, - workers=worker, - ) + future_derivs = [] + for worker, futures in self._futures.items(): + future_derivs.append( + client.submit( + _deriv, + futures, + self.multipliers[worker], + m_future, + workers=worker, ) + ) - count += 1 - future_deriv = client.gather(future_deriv) - - derivs += np.sum(future_deriv, axis=0) + derivs = self.client.gather(future_derivs) - return derivs + return np.sum(derivs, axis=0) def deriv2(self, m, v=None, f=None): """ @@ -233,34 +254,22 @@ def deriv2(self, m, v=None, f=None): client = self.client m_future = self._m_as_future [v_future] = client.scatter([v], broadcast=True) - - derivs = 0.0 - count = 0 - - for futures in self._futures: - - future_derivs = [] - for objfct, worker in zip(futures, self._workers): - if self.multipliers[count] == 0.0: # don't evaluate the fct - continue - - future_derivs.append( - client.submit( - _deriv2, - objfct, - self.multipliers[count], - m_future, - v_future, - # field, - workers=worker, - ) + future_derivs = [] + for worker, futures in self._futures.items(): + future_derivs.append( + client.submit( + _deriv2, + futures, + self.multipliers[worker], + m_future, + v_future, + # field, + workers=worker, ) - count += 1 - - future_derivs = self.client.gather(future_derivs) - derivs += np.sum(future_derivs, axis=0) + ) - return derivs + derivs = self.client.gather(future_derivs) + return np.sum(derivs, axis=0) def get_dpred(self, m, f=None): """ @@ -270,22 +279,20 @@ def get_dpred(self, m, f=None): client = self.client m_future = self._m_as_future - dpred = [] + future_preds = [] - for futures in self._futures: - future_preds = [] - for objfct, worker in zip(futures, self._workers): - future_preds.append( - client.submit( - _calc_dpred, - objfct, - m_future, - workers=worker, - ) + for worker, futures in self._futures.items(): + future_preds.append( + client.submit( + _calc_dpred, + futures, + m_future, + workers=worker, ) - dpred += client.gather(future_preds) + ) + dpreds = client.gather(future_preds) - return dpred + return np.hstack(dpreds) def getJtJdiag(self, m, f=None): """ @@ -295,26 +302,20 @@ def getJtJdiag(self, m, f=None): m_future = self._m_as_future if getattr(self, "_jtjdiag", None) is None: - jtj_diag = 0.0 client = self.client - - for futures in self._futures: - work = [] - - for objfct, worker in zip(futures, self._workers): - work.append( - client.submit( - _get_jtj_diag, - objfct, - m_future, - workers=worker, - ) + work = [] + for worker, futures in self._futures.items(): + work.append( + client.submit( + _get_jtj_diag, + futures, + m_future, + workers=worker, ) + ) - work = client.gather(work) - jtj_diag += np.sum(work, axis=0) - - self._jtjdiag = jtj_diag + jtj_diag = client.gather(work) + self._jtjdiag = np.sum(jtj_diag, axis=0) return self._jtjdiag @@ -332,17 +333,15 @@ def fields(self, m): # The above should pass the model to all the internal simulations. f = [] - for futures in self._futures: - f.append([]) - for objfct, worker in zip(futures, self._workers): - f[-1].append( - client.submit( - _calc_fields, - objfct, - m_future, - workers=worker, - ) + for worker, futures in self._futures.items(): + f.append( + client.submit( + _calc_fields, + futures, + m_future, + workers=worker, ) + ) self._stashed_fields = f return f @@ -367,16 +366,15 @@ def model(self, value): [self._m_as_future] = client.scatter([value], broadcast=True) stores = [] - for futures in self._futures: - for objfct, worker in zip(futures, self._workers): - stores.append( - client.submit( - _store_model, - objfct, - self._m_as_future, - workers=worker, - ) + for worker, futures in self._futures.items(): + stores.append( + client.submit( + _store_model, + futures, + self._m_as_future, + workers=worker, ) + ) self.client.gather(stores) # blocking call to ensure all models were stored self._model = value @@ -391,18 +389,18 @@ def objfcts(self): def objfcts(self, objfcts): client = self.client - futures, workers, lookup = _validate_type_or_future_of_type( + workload, lookup = _validate_type_or_future_of_type( "objfcts", objfcts, L2DataMisfit, client, workers=self.workers, - return_workers=True, + return_lookup=True, ) self._objfcts = objfcts - self._futures = futures - self._workers = workers + self._futures = workload + self._workers = list(workload) self._lookup = { misfit.simulation: (future, worker) @@ -417,22 +415,20 @@ def residuals(self, m, f=None): client = self.client m_future = self._m_as_future - residuals = [] - for futures in self._futures: - future_residuals = [] - for objfct, worker in zip(futures, self._workers): - future_residuals.append( - client.submit( - _calc_residual, - objfct, - m_future, - workers=worker, - ) + future_residuals = [] + for worker, futures in self._futures.items(): + future_residuals.append( + client.submit( + _calc_residual, + futures, + m_future, + workers=worker, ) - residuals += client.gather(future_residuals) + ) + residuals = client.gather(future_residuals) - return residuals + return np.hstack(residuals) def broadcast_updates(self, updates: dict): """ @@ -456,3 +452,37 @@ def broadcast_updates(self, updates: dict): ) ) self.client.gather(stores) # blocking call to ensure all models were stored + + @property + def multipliers(self): + r"""Multipliers for the objective functions. + + For a composite objective function :math:`\phi`, that is, a weighted sum of + objective functions :math:`\phi_i` with multipliers :math:`c_i` such that + + .. math:: + \phi = \sum_{i = 1}^N c_i \phi_i, + + this method returns the multipliers :math:`c_i` in + the same order of the ``objfcts``. + + Returns + ------- + list of int + Multipliers for the objective functions. + """ + + return { + worker: multipliers + for worker, multipliers in zip( + self._workers, np.array_split(self._multipliers, len(self._workers)) + ) + } + + @multipliers.setter + def multipliers(self, value): + """Set multipliers attribute after checking if they are valid.""" + for multiplier in value: + _validate_multiplier(multiplier) + _check_length_objective_funcs_multipliers(self.objfcts, value) + self._multipliers = value From 74c3cb86066ab6ee59d570b8f5ab410554315091 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 25 Jul 2025 07:39:45 -0700 Subject: [PATCH 5/7] Revert "Move workers load to internal serial process" This reverts commit af0d5690a1a16b14983039c0ffdd6c930f9dfa77. --- simpeg/dask/objective_function.py | 374 ++++++++++++++---------------- 1 file changed, 172 insertions(+), 202 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 8a178995ec..4b9addd8d7 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -5,76 +5,36 @@ from ..data_misfit import L2DataMisfit from simpeg.utils import validate_list_of_types -from simpeg.objective_function import ( - _validate_multiplier, - _check_length_objective_funcs_multipliers, -) -def _calc_fields(objfcts, _): - blocks = [] - for objfct in objfcts: - blocks.append(objfct.simulation.fields(m=objfct.simulation.model)) +def _calc_fields(objfct, _): + return objfct.simulation.fields(m=objfct.simulation.model) - return blocks +def _calc_dpred(objfct, _): + return objfct.simulation.dpred(m=objfct.simulation.model) -def _calc_dpred(objfcts, _): - blocks = [] - for objfct in objfcts: - blocks.append(objfct.simulation.dpred(m=objfct.simulation.model)) - return np.hstack(blocks) +def _calc_objective(objfct, multiplier, model): + return multiplier * objfct(model) -def _calc_objective(objfcts, multipliers, model): - blocks = [] - for multiplier, objfct in zip(multipliers, objfcts): - if multiplier == 0.0: - continue - - blocks.append(multiplier * objfct(model)) - - return np.sum(blocks) - - -def _calc_residual(objfcts, _): - blocks = [] - for objfct in objfcts: - blocks.append( - objfct.W - * (objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model)) - ) - - return np.hstack(blocks) - - -def _deriv(objfcts, multipliers, _): - blocks = [] - for multiplier, objfct in zip(multipliers, objfcts): - if multiplier == 0.0: - continue - - blocks.append(multiplier * objfct.deriv(objfct.simulation.model)) - - return np.sum(blocks, axis=0) - +def _calc_residual(objfct, _): + return objfct.W * ( + objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model) + ) -def _deriv2(objfcts, multipliers, _, v): - blocks = [] - for multiplier, objfct in zip(multipliers, objfcts): - if multiplier == 0.0: - continue +def _deriv(objfct, multiplier, _): + return multiplier * objfct.deriv(objfct.simulation.model) - blocks.append(multiplier * objfct.deriv2(objfct.simulation.model, v)) - return np.sum(blocks, axis=0) +def _deriv2(objfct, multiplier, _, v): + return multiplier * objfct.deriv2(objfct.simulation.model, v) -def _store_model(objfcts, model): - for objfct in objfcts: - objfct.simulation.model = model +def _store_model(objfct, model): + objfct.simulation.model = model def _setter_broadcast(objfct, key, value): @@ -89,11 +49,9 @@ def _setter_broadcast(objfct, key, value): setattr(sim, key, value) -def _get_jtj_diag(objfcts, _): - arrays = [] - for objfct in objfcts: - arrays.append(objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W)) - return np.sum(arrays, axis=0) +def _get_jtj_diag(objfct, _): + jtj = objfct.simulation.getJtJdiag(objfct.simulation.model, objfct.W) + return jtj.flatten() def _validate_type_or_future_of_type( @@ -102,7 +60,7 @@ def _validate_type_or_future_of_type( obj_type, client, workers: list[str] | None = None, - return_lookup=False, + return_workers=False, ): if workers is None: @@ -113,24 +71,26 @@ def _validate_type_or_future_of_type( objects = validate_list_of_types( property_name, objects, obj_type, ensure_unique=True ) - funs_split = np.array_split(objects, len(workers)) - workload = {worker: [] for worker in workers} + workload = [[]] lookup = {} - for objfcts, worker in zip(funs_split, workers): - for obj in objfcts: - obj.simulation.simulations[0].worker = worker - future = client.scatter([obj], workers=worker)[0] - - if hasattr(obj, "name"): - future.name = obj.name - - workload[worker].append(future) - lookup[obj] = (future, worker) + count = 0 + for obj in objects: + if count == len(workers): + count = 0 + workload.append([]) + obj.simulation.simulations[0].worker = workers[count] + future = client.scatter([obj], workers=workers[count])[0] + lookup[obj] = (future, workers[count]) + if hasattr(obj, "name"): + future.name = obj.name + + workload[-1].append(future) + count += 1 futures = [] - for worker, future_list in workload.items(): + for work in workload: - for obj in future_list: + for obj, worker in zip(work, workers): futures.append( client.submit( lambda v: not isinstance(v, obj_type), obj, workers=worker @@ -140,8 +100,8 @@ def _validate_type_or_future_of_type( if np.any(is_not_obj): raise TypeError(f"{property_name} futures must be an instance of {obj_type}") - if return_lookup: - return workload, lookup + if return_workers: + return workload, workers, lookup else: return workload @@ -170,18 +130,26 @@ def __call__(self, m, f=None): client = self.client m_future = self._m_as_future - future_values = [] - for worker, futures in self._futures.items(): - future_values.append( - client.submit( - _calc_objective, - futures, - self.multipliers[worker], - m_future, - workers=worker, + values = [] + count = 0 + for futures in self._futures: + for objfct, worker in zip(futures, self._workers, strict=True): + + if self.multipliers[count] == 0.0: + continue + + values.append( + client.submit( + _calc_objective, + objfct, + self.multipliers[count], + m_future, + workers=worker, + ) ) - ) - values = self.client.gather(future_values) + count += 1 + + values = self.client.gather(values) return np.sum(values) @property @@ -224,21 +192,32 @@ def deriv(self, m, f=None): self.model = m client = self.client m_future = self._m_as_future - future_derivs = [] - for worker, futures in self._futures.items(): - future_derivs.append( - client.submit( - _deriv, - futures, - self.multipliers[worker], - m_future, - workers=worker, + + derivs = 0.0 + count = 0 + + for futures in self._futures: + future_deriv = [] + for objfct, worker in zip(futures, self._workers): + if self.multipliers[count] == 0.0: # don't evaluate the fct + continue + + future_deriv.append( + client.submit( + _deriv, + objfct, + self.multipliers[count], + m_future, + workers=worker, + ) ) - ) - derivs = self.client.gather(future_derivs) + count += 1 + future_deriv = client.gather(future_deriv) + + derivs += np.sum(future_deriv, axis=0) - return np.sum(derivs, axis=0) + return derivs def deriv2(self, m, v=None, f=None): """ @@ -254,22 +233,34 @@ def deriv2(self, m, v=None, f=None): client = self.client m_future = self._m_as_future [v_future] = client.scatter([v], broadcast=True) - future_derivs = [] - for worker, futures in self._futures.items(): - future_derivs.append( - client.submit( - _deriv2, - futures, - self.multipliers[worker], - m_future, - v_future, - # field, - workers=worker, + + derivs = 0.0 + count = 0 + + for futures in self._futures: + + future_derivs = [] + for objfct, worker in zip(futures, self._workers): + if self.multipliers[count] == 0.0: # don't evaluate the fct + continue + + future_derivs.append( + client.submit( + _deriv2, + objfct, + self.multipliers[count], + m_future, + v_future, + # field, + workers=worker, + ) ) - ) + count += 1 + + future_derivs = self.client.gather(future_derivs) + derivs += np.sum(future_derivs, axis=0) - derivs = self.client.gather(future_derivs) - return np.sum(derivs, axis=0) + return derivs def get_dpred(self, m, f=None): """ @@ -279,20 +270,22 @@ def get_dpred(self, m, f=None): client = self.client m_future = self._m_as_future - future_preds = [] + dpred = [] - for worker, futures in self._futures.items(): - future_preds.append( - client.submit( - _calc_dpred, - futures, - m_future, - workers=worker, + for futures in self._futures: + future_preds = [] + for objfct, worker in zip(futures, self._workers): + future_preds.append( + client.submit( + _calc_dpred, + objfct, + m_future, + workers=worker, + ) ) - ) - dpreds = client.gather(future_preds) + dpred += client.gather(future_preds) - return np.hstack(dpreds) + return dpred def getJtJdiag(self, m, f=None): """ @@ -302,20 +295,26 @@ def getJtJdiag(self, m, f=None): m_future = self._m_as_future if getattr(self, "_jtjdiag", None) is None: + jtj_diag = 0.0 client = self.client - work = [] - for worker, futures in self._futures.items(): - work.append( - client.submit( - _get_jtj_diag, - futures, - m_future, - workers=worker, + + for futures in self._futures: + work = [] + + for objfct, worker in zip(futures, self._workers): + work.append( + client.submit( + _get_jtj_diag, + objfct, + m_future, + workers=worker, + ) ) - ) - jtj_diag = client.gather(work) - self._jtjdiag = np.sum(jtj_diag, axis=0) + work = client.gather(work) + jtj_diag += np.sum(work, axis=0) + + self._jtjdiag = jtj_diag return self._jtjdiag @@ -333,15 +332,17 @@ def fields(self, m): # The above should pass the model to all the internal simulations. f = [] - for worker, futures in self._futures.items(): - f.append( - client.submit( - _calc_fields, - futures, - m_future, - workers=worker, + for futures in self._futures: + f.append([]) + for objfct, worker in zip(futures, self._workers): + f[-1].append( + client.submit( + _calc_fields, + objfct, + m_future, + workers=worker, + ) ) - ) self._stashed_fields = f return f @@ -366,15 +367,16 @@ def model(self, value): [self._m_as_future] = client.scatter([value], broadcast=True) stores = [] - for worker, futures in self._futures.items(): - stores.append( - client.submit( - _store_model, - futures, - self._m_as_future, - workers=worker, + for futures in self._futures: + for objfct, worker in zip(futures, self._workers): + stores.append( + client.submit( + _store_model, + objfct, + self._m_as_future, + workers=worker, + ) ) - ) self.client.gather(stores) # blocking call to ensure all models were stored self._model = value @@ -389,18 +391,18 @@ def objfcts(self): def objfcts(self, objfcts): client = self.client - workload, lookup = _validate_type_or_future_of_type( + futures, workers, lookup = _validate_type_or_future_of_type( "objfcts", objfcts, L2DataMisfit, client, workers=self.workers, - return_lookup=True, + return_workers=True, ) self._objfcts = objfcts - self._futures = workload - self._workers = list(workload) + self._futures = futures + self._workers = workers self._lookup = { misfit.simulation: (future, worker) @@ -415,20 +417,22 @@ def residuals(self, m, f=None): client = self.client m_future = self._m_as_future + residuals = [] - future_residuals = [] - for worker, futures in self._futures.items(): - future_residuals.append( - client.submit( - _calc_residual, - futures, - m_future, - workers=worker, + for futures in self._futures: + future_residuals = [] + for objfct, worker in zip(futures, self._workers): + future_residuals.append( + client.submit( + _calc_residual, + objfct, + m_future, + workers=worker, + ) ) - ) - residuals = client.gather(future_residuals) + residuals += client.gather(future_residuals) - return np.hstack(residuals) + return residuals def broadcast_updates(self, updates: dict): """ @@ -452,37 +456,3 @@ def broadcast_updates(self, updates: dict): ) ) self.client.gather(stores) # blocking call to ensure all models were stored - - @property - def multipliers(self): - r"""Multipliers for the objective functions. - - For a composite objective function :math:`\phi`, that is, a weighted sum of - objective functions :math:`\phi_i` with multipliers :math:`c_i` such that - - .. math:: - \phi = \sum_{i = 1}^N c_i \phi_i, - - this method returns the multipliers :math:`c_i` in - the same order of the ``objfcts``. - - Returns - ------- - list of int - Multipliers for the objective functions. - """ - - return { - worker: multipliers - for worker, multipliers in zip( - self._workers, np.array_split(self._multipliers, len(self._workers)) - ) - } - - @multipliers.setter - def multipliers(self, value): - """Set multipliers attribute after checking if they are valid.""" - for multiplier in value: - _validate_multiplier(multiplier) - _check_length_objective_funcs_multipliers(self.objfcts, value) - self._multipliers = value From 13bf9cee0c18909735e706fdd2136b234efb8634 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 25 Jul 2025 08:19:57 -0700 Subject: [PATCH 6/7] Use priority to send queue all jobs on workers --- simpeg/dask/objective_function.py | 75 ++++++++++++++++++------------- 1 file changed, 45 insertions(+), 30 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 4b9addd8d7..2583927a7b 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -132,7 +132,9 @@ def __call__(self, m, f=None): values = [] count = 0 - for futures in self._futures: + for ind, futures in enumerate(self._futures): + + priority = len(self._futures) - ind # reverse order for priority for objfct, worker in zip(futures, self._workers, strict=True): if self.multipliers[count] == 0.0: @@ -145,6 +147,7 @@ def __call__(self, m, f=None): self.multipliers[count], m_future, workers=worker, + priority=priority, ) ) count += 1 @@ -193,11 +196,12 @@ def deriv(self, m, f=None): client = self.client m_future = self._m_as_future - derivs = 0.0 count = 0 + future_deriv = [] + for ind, futures in enumerate(self._futures): + + priority = len(self._futures) - ind # reverse order for priority - for futures in self._futures: - future_deriv = [] for objfct, worker in zip(futures, self._workers): if self.multipliers[count] == 0.0: # don't evaluate the fct continue @@ -209,13 +213,14 @@ def deriv(self, m, f=None): self.multipliers[count], m_future, workers=worker, + priority=priority, ) ) count += 1 - future_deriv = client.gather(future_deriv) + future_deriv = client.gather(future_deriv) - derivs += np.sum(future_deriv, axis=0) + derivs = np.sum(future_deriv, axis=0) return derivs @@ -234,12 +239,11 @@ def deriv2(self, m, v=None, f=None): m_future = self._m_as_future [v_future] = client.scatter([v], broadcast=True) - derivs = 0.0 count = 0 + future_derivs = [] + for ind, futures in enumerate(self._futures): - for futures in self._futures: - - future_derivs = [] + priority = len(self._futures) - ind # reverse order for priority for objfct, worker in zip(futures, self._workers): if self.multipliers[count] == 0.0: # don't evaluate the fct continue @@ -253,12 +257,13 @@ def deriv2(self, m, v=None, f=None): v_future, # field, workers=worker, + priority=priority, ) ) count += 1 - future_derivs = self.client.gather(future_derivs) - derivs += np.sum(future_derivs, axis=0) + future_derivs = self.client.gather(future_derivs) + derivs = np.sum(future_derivs, axis=0) return derivs @@ -270,10 +275,11 @@ def get_dpred(self, m, f=None): client = self.client m_future = self._m_as_future - dpred = [] + # dpred = [] + future_preds = [] + for ind, futures in enumerate(self._futures): - for futures in self._futures: - future_preds = [] + priority = len(self._futures) - ind # reverse order for priority for objfct, worker in zip(futures, self._workers): future_preds.append( client.submit( @@ -281,9 +287,10 @@ def get_dpred(self, m, f=None): objfct, m_future, workers=worker, + priority=priority, ) ) - dpred += client.gather(future_preds) + dpred = client.gather(future_preds) return dpred @@ -294,13 +301,11 @@ def getJtJdiag(self, m, f=None): self.model = m m_future = self._m_as_future if getattr(self, "_jtjdiag", None) is None: - - jtj_diag = 0.0 client = self.client + work = [] + for ind, futures in enumerate(self._futures): - for futures in self._futures: - work = [] - + priority = len(self._futures) - ind # reverse order for priority for objfct, worker in zip(futures, self._workers): work.append( client.submit( @@ -308,11 +313,12 @@ def getJtJdiag(self, m, f=None): objfct, m_future, workers=worker, + priority=priority, ) ) - work = client.gather(work) - jtj_diag += np.sum(work, axis=0) + work = client.gather(work) + jtj_diag = np.sum(work, axis=0) self._jtjdiag = jtj_diag @@ -332,15 +338,18 @@ def fields(self, m): # The above should pass the model to all the internal simulations. f = [] - for futures in self._futures: - f.append([]) + for ind, futures in enumerate(self._futures): + + priority = len(self._futures) - ind # reverse order for priority + for objfct, worker in zip(futures, self._workers): - f[-1].append( + f.append( client.submit( _calc_fields, objfct, m_future, workers=worker, + priority=priority, ) ) self._stashed_fields = f @@ -367,7 +376,9 @@ def model(self, value): [self._m_as_future] = client.scatter([value], broadcast=True) stores = [] - for futures in self._futures: + for ind, futures in enumerate(self._futures): + + priority = len(self._futures) - ind # reverse order for priority for objfct, worker in zip(futures, self._workers): stores.append( client.submit( @@ -375,6 +386,7 @@ def model(self, value): objfct, self._m_as_future, workers=worker, + priority=priority, ) ) self.client.gather(stores) # blocking call to ensure all models were stored @@ -418,9 +430,11 @@ def residuals(self, m, f=None): client = self.client m_future = self._m_as_future residuals = [] + future_residuals = [] + for ind, futures in enumerate(self._futures): + + priority = len(self._futures) - ind # reverse order for priority - for futures in self._futures: - future_residuals = [] for objfct, worker in zip(futures, self._workers): future_residuals.append( client.submit( @@ -428,9 +442,10 @@ def residuals(self, m, f=None): objfct, m_future, workers=worker, + priority=priority, ) ) - residuals += client.gather(future_residuals) + residuals = client.gather(future_residuals) return residuals From ccd501825faaac92ba88991fb07219af13100017 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 25 Jul 2025 09:45:10 -0700 Subject: [PATCH 7/7] Revert "Use priority to send queue all jobs on workers" This reverts commit 13bf9cee0c18909735e706fdd2136b234efb8634. --- simpeg/dask/objective_function.py | 75 +++++++++++++------------------ 1 file changed, 30 insertions(+), 45 deletions(-) diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 2583927a7b..4b9addd8d7 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -132,9 +132,7 @@ def __call__(self, m, f=None): values = [] count = 0 - for ind, futures in enumerate(self._futures): - - priority = len(self._futures) - ind # reverse order for priority + for futures in self._futures: for objfct, worker in zip(futures, self._workers, strict=True): if self.multipliers[count] == 0.0: @@ -147,7 +145,6 @@ def __call__(self, m, f=None): self.multipliers[count], m_future, workers=worker, - priority=priority, ) ) count += 1 @@ -196,12 +193,11 @@ def deriv(self, m, f=None): client = self.client m_future = self._m_as_future + derivs = 0.0 count = 0 - future_deriv = [] - for ind, futures in enumerate(self._futures): - - priority = len(self._futures) - ind # reverse order for priority + for futures in self._futures: + future_deriv = [] for objfct, worker in zip(futures, self._workers): if self.multipliers[count] == 0.0: # don't evaluate the fct continue @@ -213,14 +209,13 @@ def deriv(self, m, f=None): self.multipliers[count], m_future, workers=worker, - priority=priority, ) ) count += 1 - future_deriv = client.gather(future_deriv) + future_deriv = client.gather(future_deriv) - derivs = np.sum(future_deriv, axis=0) + derivs += np.sum(future_deriv, axis=0) return derivs @@ -239,11 +234,12 @@ def deriv2(self, m, v=None, f=None): m_future = self._m_as_future [v_future] = client.scatter([v], broadcast=True) + derivs = 0.0 count = 0 - future_derivs = [] - for ind, futures in enumerate(self._futures): - priority = len(self._futures) - ind # reverse order for priority + for futures in self._futures: + + future_derivs = [] for objfct, worker in zip(futures, self._workers): if self.multipliers[count] == 0.0: # don't evaluate the fct continue @@ -257,13 +253,12 @@ def deriv2(self, m, v=None, f=None): v_future, # field, workers=worker, - priority=priority, ) ) count += 1 - future_derivs = self.client.gather(future_derivs) - derivs = np.sum(future_derivs, axis=0) + future_derivs = self.client.gather(future_derivs) + derivs += np.sum(future_derivs, axis=0) return derivs @@ -275,11 +270,10 @@ def get_dpred(self, m, f=None): client = self.client m_future = self._m_as_future - # dpred = [] - future_preds = [] - for ind, futures in enumerate(self._futures): + dpred = [] - priority = len(self._futures) - ind # reverse order for priority + for futures in self._futures: + future_preds = [] for objfct, worker in zip(futures, self._workers): future_preds.append( client.submit( @@ -287,10 +281,9 @@ def get_dpred(self, m, f=None): objfct, m_future, workers=worker, - priority=priority, ) ) - dpred = client.gather(future_preds) + dpred += client.gather(future_preds) return dpred @@ -301,11 +294,13 @@ def getJtJdiag(self, m, f=None): self.model = m m_future = self._m_as_future if getattr(self, "_jtjdiag", None) is None: + + jtj_diag = 0.0 client = self.client - work = [] - for ind, futures in enumerate(self._futures): - priority = len(self._futures) - ind # reverse order for priority + for futures in self._futures: + work = [] + for objfct, worker in zip(futures, self._workers): work.append( client.submit( @@ -313,12 +308,11 @@ def getJtJdiag(self, m, f=None): objfct, m_future, workers=worker, - priority=priority, ) ) - work = client.gather(work) - jtj_diag = np.sum(work, axis=0) + work = client.gather(work) + jtj_diag += np.sum(work, axis=0) self._jtjdiag = jtj_diag @@ -338,18 +332,15 @@ def fields(self, m): # The above should pass the model to all the internal simulations. f = [] - for ind, futures in enumerate(self._futures): - - priority = len(self._futures) - ind # reverse order for priority - + for futures in self._futures: + f.append([]) for objfct, worker in zip(futures, self._workers): - f.append( + f[-1].append( client.submit( _calc_fields, objfct, m_future, workers=worker, - priority=priority, ) ) self._stashed_fields = f @@ -376,9 +367,7 @@ def model(self, value): [self._m_as_future] = client.scatter([value], broadcast=True) stores = [] - for ind, futures in enumerate(self._futures): - - priority = len(self._futures) - ind # reverse order for priority + for futures in self._futures: for objfct, worker in zip(futures, self._workers): stores.append( client.submit( @@ -386,7 +375,6 @@ def model(self, value): objfct, self._m_as_future, workers=worker, - priority=priority, ) ) self.client.gather(stores) # blocking call to ensure all models were stored @@ -430,11 +418,9 @@ def residuals(self, m, f=None): client = self.client m_future = self._m_as_future residuals = [] - future_residuals = [] - for ind, futures in enumerate(self._futures): - - priority = len(self._futures) - ind # reverse order for priority + for futures in self._futures: + future_residuals = [] for objfct, worker in zip(futures, self._workers): future_residuals.append( client.submit( @@ -442,10 +428,9 @@ def residuals(self, m, f=None): objfct, m_future, workers=worker, - priority=priority, ) ) - residuals = client.gather(future_residuals) + residuals += client.gather(future_residuals) return residuals