diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 823840564b..3fd1a42697 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -561,6 +561,87 @@ def compute_rows( return np.vstack(rows) +def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields): + """ + Evaluate the data prediction for a block of sources. + """ + data = [] + for ind in indices: + + receiver_list = sources[ind].receiver_list + if len(receiver_list) == 0: + continue + + for receiver in receiver_list: + data.append(receiver.eval(sources[ind], mesh, time_mesh, fields)) + + return np.hstack(data) + + +def dpred(self, m=None, f=None): + # Docstring inherited from BaseSimulation. + if self.survey is None: + raise AttributeError( + "The survey has not yet been set and is required to compute " + "data. Please set the survey for the simulation: " + "simulation.survey = survey" + ) + + try: + client = get_client() + except ValueError: + client = None + + if f is None: + f = self.fields(m) + + delayed_chunks = [] + + source_block = np.array_split( + np.arange(len(self.survey.source_list)), self.n_threads(client=client) + ) + if client: + mesh = client.scatter(self.mesh, workers=self.worker) + time_mesh = client.scatter(self.time_mesh, workers=self.worker) + fields = client.scatter(f, workers=self.worker) + source_list = client.scatter(self.survey.source_list, workers=self.worker) + else: + mesh = self.mesh + time_mesh = self.time_mesh + delayed_eval = delayed(evaluate_dpred_block) + source_list = self.survey.source_list + fields = f + + for block in source_block: + if len(block) == 0: + continue + + if client: + delayed_chunks.append( + client.submit( + evaluate_dpred_block, + block, + source_list, + mesh, + time_mesh, + fields, + workers=self.worker, + ) + ) + else: + delayed_chunks.append( + delayed_eval(block, source_list, mesh, time_mesh, fields) + ) + + if client: + result = client.gather(delayed_chunks) + else: + result = dask.compute(delayed_chunks)[0] + + return np.hstack(result) + + +Sim.dpred = dpred Sim.fields = fields Sim.getSourceTerm = getSourceTerm Sim.compute_J = compute_J diff --git a/simpeg/dask/utils.py b/simpeg/dask/utils.py index 19ba96f1a6..e17c130924 100644 --- a/simpeg/dask/utils.py +++ b/simpeg/dask/utils.py @@ -75,15 +75,15 @@ def get_parallel_blocks( row_index += chunk_size row_count += chunk_size - # Re-split over cpu_count if too few blocks - if len(blocks) < thread_count and optimize: - flatten_blocks = [] - for block in blocks: - flatten_blocks += block - - chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) - return [ - [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0 - ] + # # Re-split over cpu_count if too few blocks + # if len(blocks) < thread_count and optimize: + # flatten_blocks = [] + # for block in blocks: + # flatten_blocks += block + # + # chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) + # return [ + # [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0 + # ] return blocks diff --git a/simpeg/electromagnetics/time_domain/receivers.py b/simpeg/electromagnetics/time_domain/receivers.py index dc2cbb0255..0a0fa0966c 100644 --- a/simpeg/electromagnetics/time_domain/receivers.py +++ b/simpeg/electromagnetics/time_domain/receivers.py @@ -1,5 +1,5 @@ import scipy.sparse as sp - +import numpy as np from ...utils import mkvc, validate_type, validate_direction from discretize.utils import Zero from ...survey import BaseTimeRx @@ -128,6 +128,20 @@ def getTimeP(self, time_mesh, f): return self.timeP + def active_times(self, projection): + """Get active times for the receiver. + + Parameters + ---------- + projection : Sparse matrix + + Returns + ------- + numpy.ndarray + Active times for the receiver. + """ + return np.unique(sp.find(projection)[1]) + def getP(self, mesh, time_mesh, f): """Returns projection matrices as a list for all components collected by the receivers. @@ -153,6 +167,7 @@ def getP(self, mesh, time_mesh, f): Ps = self.getSpatialP(mesh, f) Pt = self.getTimeP(time_mesh, f) + Pt = Pt[:, self.active_times(Pt)] P = sp.kron(Pt, Ps) if self.storeProjections: @@ -180,7 +195,7 @@ def eval(self, src, mesh, time_mesh, f): # noqa: A003 Fields projected to the receiver(s) """ P = self.getP(mesh, time_mesh, f) - f_part = mkvc(f[src, self.projField, :]) + f_part = mkvc(f[src, self.projField, self.active_times(self.timeP)]) return P * f_part def evalDeriv(self, src, mesh, time_mesh, f, v, adjoint=False): @@ -301,7 +316,7 @@ def eval(self, src, mesh, time_mesh, f): # noqa: A003 ) P = self.getP(mesh, time_mesh, f) - f_part = mkvc(f[src, "b", :]) + f_part = mkvc(f[src, "b", self.active_times(self.timeP)]) return P * f_part def getTimeP(self, time_mesh, f): diff --git a/simpeg/fields.py b/simpeg/fields.py index f78c994219..009d076c6e 100644 --- a/simpeg/fields.py +++ b/simpeg/fields.py @@ -501,7 +501,7 @@ def _getField(self, name, ind, src_list): pointerFields = pointerFields.reshape(pointerShape, order="F") # First try to return the function as three arguments (without timeInd) - if timeInd == slice(None, None, None): + if isinstance(timeInd, slice) and timeInd == slice(None, None, None): try: # assume it will take care of integrating over all times return func(pointerFields, srcInd)