Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions simpeg/dask/electromagnetics/time_domain/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,87 @@ def compute_rows(
return np.vstack(rows)


def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields):
"""
Evaluate the data prediction for a block of sources.
"""
data = []
for ind in indices:

receiver_list = sources[ind].receiver_list
if len(receiver_list) == 0:
continue

for receiver in receiver_list:
data.append(receiver.eval(sources[ind], mesh, time_mesh, fields))

return np.hstack(data)


def dpred(self, m=None, f=None):
# Docstring inherited from BaseSimulation.
if self.survey is None:
raise AttributeError(
"The survey has not yet been set and is required to compute "
"data. Please set the survey for the simulation: "
"simulation.survey = survey"
)

try:
client = get_client()
except ValueError:
client = None

if f is None:
f = self.fields(m)

delayed_chunks = []

source_block = np.array_split(
np.arange(len(self.survey.source_list)), self.n_threads(client=client)
)
if client:
mesh = client.scatter(self.mesh, workers=self.worker)
time_mesh = client.scatter(self.time_mesh, workers=self.worker)
fields = client.scatter(f, workers=self.worker)
source_list = client.scatter(self.survey.source_list, workers=self.worker)
else:
mesh = self.mesh
time_mesh = self.time_mesh
delayed_eval = delayed(evaluate_dpred_block)
source_list = self.survey.source_list
fields = f

for block in source_block:
if len(block) == 0:
continue

if client:
delayed_chunks.append(
client.submit(
evaluate_dpred_block,
block,
source_list,
mesh,
time_mesh,
fields,
workers=self.worker,
)
)
else:
delayed_chunks.append(
delayed_eval(block, source_list, mesh, time_mesh, fields)
)

if client:
result = client.gather(delayed_chunks)
else:
result = dask.compute(delayed_chunks)[0]

return np.hstack(result)


Sim.dpred = dpred
Sim.fields = fields
Sim.getSourceTerm = getSourceTerm
Sim.compute_J = compute_J
Expand Down
20 changes: 10 additions & 10 deletions simpeg/dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ def get_parallel_blocks(
row_index += chunk_size
row_count += chunk_size

# Re-split over cpu_count if too few blocks
if len(blocks) < thread_count and optimize:
flatten_blocks = []
for block in blocks:
flatten_blocks += block

chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count())
return [
[flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0
]
# # Re-split over cpu_count if too few blocks
# if len(blocks) < thread_count and optimize:
# flatten_blocks = []
# for block in blocks:
# flatten_blocks += block
#
# chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count())
# return [
# [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0
# ]

return blocks
21 changes: 18 additions & 3 deletions simpeg/electromagnetics/time_domain/receivers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import scipy.sparse as sp

import numpy as np
from ...utils import mkvc, validate_type, validate_direction
from discretize.utils import Zero
from ...survey import BaseTimeRx
Expand Down Expand Up @@ -128,6 +128,20 @@ def getTimeP(self, time_mesh, f):

return self.timeP

def active_times(self, projection):
"""Get active times for the receiver.

Parameters
----------
projection : Sparse matrix

Returns
-------
numpy.ndarray
Active times for the receiver.
"""
return np.unique(sp.find(projection)[1])

def getP(self, mesh, time_mesh, f):
"""Returns projection matrices as a list for all components collected by the receivers.

Expand All @@ -153,6 +167,7 @@ def getP(self, mesh, time_mesh, f):

Ps = self.getSpatialP(mesh, f)
Pt = self.getTimeP(time_mesh, f)
Pt = Pt[:, self.active_times(Pt)]
P = sp.kron(Pt, Ps)

if self.storeProjections:
Expand Down Expand Up @@ -180,7 +195,7 @@ def eval(self, src, mesh, time_mesh, f): # noqa: A003
Fields projected to the receiver(s)
"""
P = self.getP(mesh, time_mesh, f)
f_part = mkvc(f[src, self.projField, :])
f_part = mkvc(f[src, self.projField, self.active_times(self.timeP)])
return P * f_part

def evalDeriv(self, src, mesh, time_mesh, f, v, adjoint=False):
Expand Down Expand Up @@ -301,7 +316,7 @@ def eval(self, src, mesh, time_mesh, f): # noqa: A003
)

P = self.getP(mesh, time_mesh, f)
f_part = mkvc(f[src, "b", :])
f_part = mkvc(f[src, "b", self.active_times(self.timeP)])
return P * f_part

def getTimeP(self, time_mesh, f):
Expand Down
2 changes: 1 addition & 1 deletion simpeg/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def _getField(self, name, ind, src_list):
pointerFields = pointerFields.reshape(pointerShape, order="F")

# First try to return the function as three arguments (without timeInd)
if timeInd == slice(None, None, None):
if isinstance(timeInd, slice) and timeInd == slice(None, None, None):
try:
# assume it will take care of integrating over all times
return func(pointerFields, srcInd)
Expand Down