diff --git a/simpeg/dask/inverse_problem.py b/simpeg/dask/inverse_problem.py index e5e3c2d092..5caf1faec5 100644 --- a/simpeg/dask/inverse_problem.py +++ b/simpeg/dask/inverse_problem.py @@ -46,9 +46,9 @@ def get_dpred(self, m, f=None, return_residuals=False): ) if return_residuals: - return np.hstack(results[0]), np.hstack(results[1]) + return results[0], results[1] - return np.hstack(results) + return results BaseInvProblem.get_dpred = get_dpred @@ -61,7 +61,7 @@ def dask_evalFunction(self, m, return_g=True, return_H=True): self.model = m self.dpred, self.residuals = self.get_dpred(m, return_residuals=True) - phi_d = np.vdot(self.residuals, self.residuals) + phi_d = (np.hstack(self.residuals) ** 2.0).sum() reg2Deriv = [] if isinstance(self.reg, ComboObjectiveFunction): diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index a52040ce70..190fdf393c 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -149,32 +149,34 @@ def _validate_type_or_future_of_type( objects = validate_list_of_types( property_name, objects, obj_type, ensure_unique=True ) - workload = [[]] + workloads = {} + for worker in workers: + workloads[worker] = [] count = 0 - for obj in objects: - if count == len(workers): - count = 0 - workload.append([]) + for ii, obj in enumerate(objects): + count = ii % len(workers) if isinstance(obj, Future): future = obj + count = workers.index(client.who_has(obj)[obj.key]) else: future = client.scatter([obj], workers=workers[count])[0] - workload[-1].append(future) - count += 1 + workloads[workers[count]].append(future) futures = [] assignments = [] - for work in workload: - for obj, worker in zip(work, workers): + for worker, work in workloads.items(): + for future in work: futures.append( client.submit( - lambda v: not isinstance(v, obj_type), obj, workers=worker + lambda v: not isinstance(v, obj_type), future, workers=worker ) ) - assignments.append(client.submit(_set_worker, obj, worker, workers=worker)) + assignments.append( + client.submit(_set_worker, future, worker, workers=worker) + ) client.gather(assignments) @@ -182,6 +184,14 @@ def _validate_type_or_future_of_type( if np.any(is_not_obj): raise TypeError(f"{property_name} futures must be an instance of {obj_type}") + # Re-distribute the workload to ensure all workers are equally loaded + workload = [] + for work in workloads.values(): + for ii, future in enumerate(work): + if len(workload) <= ii: + workload.append([]) + workload[ii].append(future) + if return_workers: return workload, workers else: @@ -382,9 +392,9 @@ def get_dpred(self, m, f=None, return_residuals=False): dpred += [result] if return_residuals: - return np.hstack(dpred), np.hstack(residuals) + return dpred, residuals - return np.hstack(dpred) + return dpred def getJtJdiag(self, m, f=None): """ diff --git a/simpeg/directives/_regularization.py b/simpeg/directives/_regularization.py index 16dca5a93e..d3fb7be2b3 100644 --- a/simpeg/directives/_regularization.py +++ b/simpeg/directives/_regularization.py @@ -218,7 +218,7 @@ def misfit_from_chi_factor(self, chi_factor: float) -> float: chi_factor : float Chi factor to compute the target misfit from. """ - return self.invProb.dpred.shape[0] * chi_factor + return np.hstack(self.invProb.dpred).shape[0] * chi_factor def adjust_cooling_schedule(self): """ diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index 44a183184d..d88902f69c 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -3058,10 +3058,9 @@ def initialize(self): def endIter(self): ratio = self.invProb.beta / self.last_beta chi_factors = [] - for objfct, pred in zip(self.invProb.dmisfit.objfcts, self.invProb.dpred): - residual = objfct.W * (objfct.data.dobs - pred) + for residual in self.invProb.residuals: phi_d = np.vdot(residual, residual) - chi_factors.append(phi_d / objfct.nD) + chi_factors.append(phi_d / len(residual)) self.chi_factors = np.asarray(chi_factors)