diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 54c1fbb0e4..9905a1cc03 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -176,8 +176,9 @@ def compute_J(self, m, f=None): delayed_compute_rows = delayed(compute_rows) sim = self for tInd, dt in zip(reversed(range(self.nT)), reversed(self.time_steps)): + AdiagTinv = Ainv[dt] - j_row_updates = [] + future_updates = [] time_mask = data_times > simulation_times[tInd] if not np.any(time_mask): @@ -197,56 +198,54 @@ def compute_J(self, m, f=None): client, ) + if client: + field_derivatives = client.scatter(ATinv_df_duT_v, workers=self.worker) + else: + field_derivatives = ATinv_df_duT_v + + for block_ind in range(len(blocks)): + if len(block) == 0: continue - field_derivatives = ATinv_df_duT_v[ind] if client: - field_derivatives = client.scatter( - ATinv_df_duT_v[ind], workers=self.worker + future_updates.append( + client.submit( + compute_rows, + sim, + tInd, + block_ind, + blocks, + field_derivatives, + fields_array, + time_mask, + workers=self.worker, + ) ) - for bb, row in enumerate(block): - if client: - # field_derivatives = client.scatter( - # ATinv_df_duT_v[ind], workers=self.worker - # ) - j_row_updates.append( - client.submit( - compute_rows, + else: + future_updates.append( + array.from_delayed( + delayed_compute_rows( sim, tInd, - row, - bb, + block_ind, + blocks, field_derivatives, fields_array, time_mask, - workers=self.worker, - ) - ) - else: - j_row_updates.append( - array.from_delayed( - delayed_compute_rows( - sim, - tInd, - row, - bb, - field_derivatives, - fields_array, - time_mask, - ), - dtype=np.float32, - shape=( - np.sum([len(chunk[1][0]) for chunk in block]), - m.size, - ), - ) + ), + dtype=np.float32, + shape=( + np.sum([len(chunk[1][0]) for chunk in block]), + m.size, + ), ) + ) if client: - j_row_updates = np.vstack(client.gather(j_row_updates)) + j_row_updates = np.vstack(client.gather(future_updates)) else: - j_row_updates = array.vstack(j_row_updates).compute() + j_row_updates = array.vstack(future_updates).compute() if self.store_sensitivities == "disk": sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" @@ -500,8 +499,8 @@ def deriv_block(ATinv_df_duT_v, Asubdiag, local_ind, field_derivs): def compute_rows( simulation, tInd, - chunks, - ind, + block_ind, + blocks, field_derivs, fields, time_mask, @@ -509,40 +508,45 @@ def compute_rows( """ Compute the rows of the sensitivity matrix for a given source and receiver. """ - (address, ind_array) = chunks - # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v): - src = simulation.survey.source_list[address[0]] - time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] - local_ind = np.arange(len(ind_array[0]))[time_check] + rows = [] + for ind, (address, ind_array) in enumerate(blocks[block_ind]): + # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v): + src = simulation.survey.source_list[address[0]] + time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] + local_ind = np.arange(len(ind_array[0]))[time_check] + + if len(local_ind) < 1: + row_block = np.zeros( + (len(ind_array[1]), simulation.model.size), dtype=np.float32 + ) + rows.append(row_block) + continue - if len(local_ind) < 1: - row_block = np.zeros( - (len(ind_array[1]), simulation.model.size), dtype=np.float32 + dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( + tInd, + fields[:, address[0], tInd], + field_derivs[block_ind][ind][:, local_ind], + adjoint=True, ) - return row_block - - dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( - tInd, - fields[:, address[0], tInd], - field_derivs[ind][:, local_ind], - adjoint=True, - ) - dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, src, field_derivs[ind][:, local_ind], adjoint=True - ) # on nodes of time mesh + dRHST_dm_v = simulation.getRHSDeriv( + tInd + 1, src, field_derivs[block_ind][ind][:, local_ind], adjoint=True + ) # on nodes of time mesh - un_src = fields[:, address[0], tInd + 1] - # cell centered on time mesh - dAT_dm_v = simulation.getAdiagDeriv( - tInd, un_src, field_derivs[ind][:, local_ind], adjoint=True - ) - row_block = np.zeros((len(ind_array[1]), simulation.model.size), dtype=np.float32) - row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype( - np.float32 - ) + un_src = fields[:, address[0], tInd + 1] + # cell centered on time mesh + dAT_dm_v = simulation.getAdiagDeriv( + tInd, un_src, field_derivs[block_ind][ind][:, local_ind], adjoint=True + ) + row_block = np.zeros( + (len(ind_array[1]), simulation.model.size), dtype=np.float32 + ) + row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype( + np.float32 + ) + rows.append(row_block) - return row_block + return np.vstack(rows) def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields): diff --git a/simpeg/dask/utils.py b/simpeg/dask/utils.py index e17c130924..d7100b604a 100644 --- a/simpeg/dask/utils.py +++ b/simpeg/dask/utils.py @@ -76,14 +76,14 @@ def get_parallel_blocks( row_count += chunk_size # # Re-split over cpu_count if too few blocks - # if len(blocks) < thread_count and optimize: - # flatten_blocks = [] - # for block in blocks: - # flatten_blocks += block - # - # chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) - # return [ - # [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0 - # ] + if len(blocks) < thread_count and optimize: + flatten_blocks = [] + for block in blocks: + flatten_blocks += block + + chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) + return [ + [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0 + ] return blocks