Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions simpeg/dask/inverse_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def get_dpred(self, m, f=None, return_residuals=False):
)

if return_residuals:
return np.hstack(results[0]), np.hstack(results[1])
return results[0], results[1]

return np.hstack(results)
return results


BaseInvProblem.get_dpred = get_dpred
Expand All @@ -61,7 +61,7 @@ def dask_evalFunction(self, m, return_g=True, return_H=True):
self.model = m
self.dpred, self.residuals = self.get_dpred(m, return_residuals=True)

phi_d = np.vdot(self.residuals, self.residuals)
phi_d = (np.hstack(self.residuals) ** 2.0).sum()

reg2Deriv = []
if isinstance(self.reg, ComboObjectiveFunction):
Expand Down
36 changes: 23 additions & 13 deletions simpeg/dask/objective_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,39 +149,49 @@ def _validate_type_or_future_of_type(
objects = validate_list_of_types(
property_name, objects, obj_type, ensure_unique=True
)
workload = [[]]
workloads = {}
for worker in workers:
workloads[worker] = []

count = 0
for obj in objects:
if count == len(workers):
count = 0
workload.append([])
for ii, obj in enumerate(objects):
count = ii % len(workers)

if isinstance(obj, Future):
future = obj
count = workers.index(client.who_has(obj)[obj.key])
else:
future = client.scatter([obj], workers=workers[count])[0]

workload[-1].append(future)
count += 1
workloads[workers[count]].append(future)

futures = []
assignments = []
for work in workload:
for obj, worker in zip(work, workers):
for worker, work in workloads.items():
for future in work:
futures.append(
client.submit(
lambda v: not isinstance(v, obj_type), obj, workers=worker
lambda v: not isinstance(v, obj_type), future, workers=worker
)
)
assignments.append(client.submit(_set_worker, obj, worker, workers=worker))
assignments.append(
client.submit(_set_worker, future, worker, workers=worker)
)

client.gather(assignments)

is_not_obj = np.array(client.gather(futures))
if np.any(is_not_obj):
raise TypeError(f"{property_name} futures must be an instance of {obj_type}")

# Re-distribute the workload to ensure all workers are equally loaded
workload = []
for work in workloads.values():
for ii, future in enumerate(work):
if len(workload) <= ii:
workload.append([])
workload[ii].append(future)

if return_workers:
return workload, workers
else:
Expand Down Expand Up @@ -382,9 +392,9 @@ def get_dpred(self, m, f=None, return_residuals=False):
dpred += [result]

if return_residuals:
return np.hstack(dpred), np.hstack(residuals)
return dpred, residuals

return np.hstack(dpred)
return dpred

def getJtJdiag(self, m, f=None):
"""
Expand Down
2 changes: 1 addition & 1 deletion simpeg/directives/_regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def misfit_from_chi_factor(self, chi_factor: float) -> float:
chi_factor : float
Chi factor to compute the target misfit from.
"""
return self.invProb.dpred.shape[0] * chi_factor
return np.hstack(self.invProb.dpred).shape[0] * chi_factor

def adjust_cooling_schedule(self):
"""
Expand Down
5 changes: 2 additions & 3 deletions simpeg/directives/directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3058,10 +3058,9 @@ def initialize(self):
def endIter(self):
ratio = self.invProb.beta / self.last_beta
chi_factors = []
for objfct, pred in zip(self.invProb.dmisfit.objfcts, self.invProb.dpred):
residual = objfct.W * (objfct.data.dobs - pred)
for residual in self.invProb.residuals:
phi_d = np.vdot(residual, residual)
chi_factors.append(phi_d / objfct.nD)
chi_factors.append(phi_d / len(residual))

self.chi_factors = np.asarray(chi_factors)

Expand Down