diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 3fd1a42697..612896cb03 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -200,40 +200,41 @@ def compute_J(self, m, f=None): if len(block) == 0: continue - if client: - field_derivatives = client.scatter( - ATinv_df_duT_v[ind], workers=self.worker - ) - j_row_updates.append( - client.submit( - compute_rows, - sim, - tInd, - block, - field_derivatives, - fields_array, - time_mask, - workers=self.worker, - ) - ) - else: - j_row_updates.append( - array.from_delayed( - delayed_compute_rows( + for row, field_derivatives in zip(block, ATinv_df_duT_v[ind]): + if client: + # field_derivatives = client.scatter( + # ATinv_df_duT_v[ind], workers=self.worker + # ) + j_row_updates.append( + client.submit( + compute_rows, sim, tInd, - block, - ATinv_df_duT_v[ind], + row, + field_derivatives, fields_array, time_mask, - ), - dtype=np.float32, - shape=( - np.sum([len(chunk[1][0]) for chunk in block]), - m.size, - ), + workers=self.worker, + ) + ) + else: + j_row_updates.append( + array.from_delayed( + delayed_compute_rows( + sim, + tInd, + row, + field_derivatives, + fields_array, + time_mask, + ), + dtype=np.float32, + shape=( + np.sum([len(chunk[1][0]) for chunk in block]), + m.size, + ), + ) ) - ) if client: j_row_updates = np.vstack(client.gather(j_row_updates)) @@ -390,59 +391,39 @@ def get_field_deriv_block( """ Stack the blocks of field derivatives for a given timestep and call the direct solver. """ - stacked_blocks = [] if len(ATinv_df_duT_v) == 0: ATinv_df_duT_v = [[] for _ in block] - indices = [] - count = 0 Asubdiag = None if tInd < self.nT - 1: Asubdiag = self.getAsubdiag(tInd + 1) + updated_ATinv_df_duT_v = [] + for (_, (rx_ind, _, shape)), field_deriv, ATinv_chunk in zip( block, field_derivs, ATinv_df_duT_v ): + # Cut out early data time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind] local_ind = np.arange(rx_ind.shape[0])[time_check] - indices.append( - (np.arange(count, count + len(local_ind)), local_ind), - ) - count += len(local_ind) if len(ATinv_chunk) == 0: # last timestep (first to be solved) - stacked_block = field_deriv.toarray()[:, local_ind] - - else: - stacked_block = np.asarray( - field_deriv[:, local_ind] - Asubdiag.T * ATinv_chunk[:, local_ind] - ) - - stacked_blocks.append(stacked_block) - - blocks = np.hstack(stacked_blocks) - if blocks.ndim == 2 and blocks.shape[1] > 0: - solve = (AdiagTinv * blocks).reshape(blocks.shape) - else: - solve = None - - updated_ATinv_df_duT_v = [] - - for (_, arrays), field_deriv, ATinv_chunk, (columns, local_ind) in zip( - block, field_derivs, ATinv_df_duT_v, indices, strict=True - ): - - if len(ATinv_chunk) == 0: + time_block = field_deriv.toarray()[:, local_ind] shape = ( field_deriv.shape[0], - len(arrays[0]), + len(rx_ind), ) ATinv_chunk = np.zeros(shape, dtype=np.float32) + else: + time_block = np.asarray( + field_deriv[:, local_ind] - Asubdiag.T * ATinv_chunk[:, local_ind] + ) - if solve is not None: - ATinv_chunk[:, local_ind] = solve[:, columns] + if time_block.ndim == 2 and time_block.shape[1] > 0: + solve = (AdiagTinv * time_block).reshape(time_block.shape) + ATinv_chunk[:, local_ind] = solve updated_ATinv_df_duT_v.append(ATinv_chunk) @@ -513,52 +494,47 @@ def compute_rows( simulation, tInd, chunks, - ATinv_df_duT_v, + field_derivs, fields, time_mask, ): """ Compute the rows of the sensitivity matrix for a given source and receiver. """ - rows = [] + (address, ind_array) = chunks + # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v): + src = simulation.survey.source_list[address[0]] + time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] + local_ind = np.arange(len(ind_array[0]))[time_check] - for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v): - src = simulation.survey.source_list[address[0]] - time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] - local_ind = np.arange(len(ind_array[0]))[time_check] - - if len(local_ind) < 1: - row_block = np.zeros( - (len(ind_array[1]), simulation.model.size), dtype=np.float32 - ) - rows.append(row_block) - continue - - dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( - tInd, - fields[:, address[0], tInd], - field_derivs[:, local_ind], - adjoint=True, - ) - - dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, src, field_derivs[:, local_ind], adjoint=True - ) # on nodes of time mesh - - un_src = fields[:, address[0], tInd + 1] - # cell centered on time mesh - dAT_dm_v = simulation.getAdiagDeriv( - tInd, un_src, field_derivs[:, local_ind], adjoint=True - ) + if len(local_ind) < 1: row_block = np.zeros( (len(ind_array[1]), simulation.model.size), dtype=np.float32 ) - row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype( - np.float32 - ) - rows.append(row_block) + return row_block + + dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( + tInd, + fields[:, address[0], tInd], + field_derivs[:, local_ind], + adjoint=True, + ) + + dRHST_dm_v = simulation.getRHSDeriv( + tInd + 1, src, field_derivs[:, local_ind], adjoint=True + ) # on nodes of time mesh + + un_src = fields[:, address[0], tInd + 1] + # cell centered on time mesh + dAT_dm_v = simulation.getAdiagDeriv( + tInd, un_src, field_derivs[:, local_ind], adjoint=True + ) + row_block = np.zeros((len(ind_array[1]), simulation.model.size), dtype=np.float32) + row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype( + np.float32 + ) - return np.vstack(rows) + return row_block def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields):