diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index f10bc965cd..4b9addd8d7 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -72,6 +72,7 @@ def _validate_type_or_future_of_type( property_name, objects, obj_type, ensure_unique=True ) workload = [[]] + lookup = {} count = 0 for obj in objects: if count == len(workers): @@ -79,7 +80,7 @@ def _validate_type_or_future_of_type( workload.append([]) obj.simulation.simulations[0].worker = workers[count] future = client.scatter([obj], workers=workers[count])[0] - + lookup[obj] = (future, workers[count]) if hasattr(obj, "name"): future.name = obj.name @@ -100,7 +101,7 @@ def _validate_type_or_future_of_type( raise TypeError(f"{property_name} futures must be an instance of {obj_type}") if return_workers: - return workload, workers + return workload, workers, lookup else: return workload @@ -390,7 +391,7 @@ def objfcts(self): def objfcts(self, objfcts): client = self.client - futures, workers = _validate_type_or_future_of_type( + futures, workers, lookup = _validate_type_or_future_of_type( "objfcts", objfcts, L2DataMisfit, @@ -404,8 +405,8 @@ def objfcts(self, objfcts): self._workers = workers self._lookup = { - obj.simulation: (future, worker) - for future, worker, obj in zip(futures[0], workers, objfcts) + misfit.simulation: (future, worker) + for misfit, (future, worker) in lookup.items() } def residuals(self, m, f=None): @@ -443,12 +444,12 @@ def broadcast_updates(self, updates: dict): if fun not in self._lookup: continue - objfct, worker = self._lookup[fun] + future, worker = self._lookup[fun] stores.append( client.submit( _setter_broadcast, - objfct, + future, key, value, workers=worker,