Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 72 additions & 68 deletions simpeg/dask/electromagnetics/time_domain/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,9 @@ def compute_J(self, m, f=None):
delayed_compute_rows = delayed(compute_rows)
sim = self
for tInd, dt in zip(reversed(range(self.nT)), reversed(self.time_steps)):

AdiagTinv = Ainv[dt]
j_row_updates = []
future_updates = []
time_mask = data_times > simulation_times[tInd]

if not np.any(time_mask):
Expand All @@ -197,56 +198,54 @@ def compute_J(self, m, f=None):
client,
)

if client:
field_derivatives = client.scatter(ATinv_df_duT_v, workers=self.worker)
else:
field_derivatives = ATinv_df_duT_v

for block_ind in range(len(blocks)):

if len(block) == 0:
continue

field_derivatives = ATinv_df_duT_v[ind]
if client:
field_derivatives = client.scatter(
ATinv_df_duT_v[ind], workers=self.worker
future_updates.append(
client.submit(
compute_rows,
sim,
tInd,
block_ind,
blocks,
field_derivatives,
fields_array,
time_mask,
workers=self.worker,
)
)
for bb, row in enumerate(block):
if client:
# field_derivatives = client.scatter(
# ATinv_df_duT_v[ind], workers=self.worker
# )
j_row_updates.append(
client.submit(
compute_rows,
else:
future_updates.append(
array.from_delayed(
delayed_compute_rows(
sim,
tInd,
row,
bb,
block_ind,
blocks,
field_derivatives,
fields_array,
time_mask,
workers=self.worker,
)
)
else:
j_row_updates.append(
array.from_delayed(
delayed_compute_rows(
sim,
tInd,
row,
bb,
field_derivatives,
fields_array,
time_mask,
),
dtype=np.float32,
shape=(
np.sum([len(chunk[1][0]) for chunk in block]),
m.size,
),
)
),
dtype=np.float32,
shape=(
np.sum([len(chunk[1][0]) for chunk in block]),
m.size,
),
)
)

if client:
j_row_updates = np.vstack(client.gather(j_row_updates))
j_row_updates = np.vstack(client.gather(future_updates))
else:
j_row_updates = array.vstack(j_row_updates).compute()
j_row_updates = array.vstack(future_updates).compute()

if self.store_sensitivities == "disk":
sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr"
Expand Down Expand Up @@ -500,49 +499,54 @@ def deriv_block(ATinv_df_duT_v, Asubdiag, local_ind, field_derivs):
def compute_rows(
simulation,
tInd,
chunks,
ind,
block_ind,
blocks,
field_derivs,
fields,
time_mask,
):
"""
Compute the rows of the sensitivity matrix for a given source and receiver.
"""
(address, ind_array) = chunks
# for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v):
src = simulation.survey.source_list[address[0]]
time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]]
local_ind = np.arange(len(ind_array[0]))[time_check]
rows = []
for ind, (address, ind_array) in enumerate(blocks[block_ind]):
# for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v):
src = simulation.survey.source_list[address[0]]
time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]]
local_ind = np.arange(len(ind_array[0]))[time_check]

if len(local_ind) < 1:
row_block = np.zeros(
(len(ind_array[1]), simulation.model.size), dtype=np.float32
)
rows.append(row_block)
continue

if len(local_ind) < 1:
row_block = np.zeros(
(len(ind_array[1]), simulation.model.size), dtype=np.float32
dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
tInd,
fields[:, address[0], tInd],
field_derivs[block_ind][ind][:, local_ind],
adjoint=True,
)
return row_block

dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
tInd,
fields[:, address[0], tInd],
field_derivs[ind][:, local_ind],
adjoint=True,
)

dRHST_dm_v = simulation.getRHSDeriv(
tInd + 1, src, field_derivs[ind][:, local_ind], adjoint=True
) # on nodes of time mesh
dRHST_dm_v = simulation.getRHSDeriv(
tInd + 1, src, field_derivs[block_ind][ind][:, local_ind], adjoint=True
) # on nodes of time mesh

un_src = fields[:, address[0], tInd + 1]
# cell centered on time mesh
dAT_dm_v = simulation.getAdiagDeriv(
tInd, un_src, field_derivs[ind][:, local_ind], adjoint=True
)
row_block = np.zeros((len(ind_array[1]), simulation.model.size), dtype=np.float32)
row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype(
np.float32
)
un_src = fields[:, address[0], tInd + 1]
# cell centered on time mesh
dAT_dm_v = simulation.getAdiagDeriv(
tInd, un_src, field_derivs[block_ind][ind][:, local_ind], adjoint=True
)
row_block = np.zeros(
(len(ind_array[1]), simulation.model.size), dtype=np.float32
)
row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype(
np.float32
)
rows.append(row_block)

return row_block
return np.vstack(rows)


def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields):
Expand Down
18 changes: 9 additions & 9 deletions simpeg/dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ def get_parallel_blocks(
row_count += chunk_size

# # Re-split over cpu_count if too few blocks
# if len(blocks) < thread_count and optimize:
# flatten_blocks = []
# for block in blocks:
# flatten_blocks += block
#
# chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count())
# return [
# [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0
# ]
if len(blocks) < thread_count and optimize:
flatten_blocks = []
for block in blocks:
flatten_blocks += block

chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count())
return [
[flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0
]

return blocks