From d1428c3a01cec6c1b7fb216b72f15d540d17ba1d Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 4 Sep 2025 14:58:32 -0700 Subject: [PATCH 1/5] Change scatter assignment --- simpeg/dask/electromagnetics/time_domain/simulation.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 54c1fbb0e4..42bcba2be5 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -200,16 +200,15 @@ def compute_J(self, m, f=None): if len(block) == 0: continue - field_derivatives = ATinv_df_duT_v[ind] if client: field_derivatives = client.scatter( ATinv_df_duT_v[ind], workers=self.worker ) + else: + field_derivatives = ATinv_df_duT_v[ind] + for bb, row in enumerate(block): if client: - # field_derivatives = client.scatter( - # ATinv_df_duT_v[ind], workers=self.worker - # ) j_row_updates.append( client.submit( compute_rows, From 2d0044246ff17fd65ed112dbd8df4f46b96689ff Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 4 Sep 2025 15:26:11 -0700 Subject: [PATCH 2/5] Flatten the loop over derivatives --- .../time_domain/simulation.py | 126 +++++++++--------- 1 file changed, 64 insertions(+), 62 deletions(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 42bcba2be5..1da4616049 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -186,7 +186,7 @@ def compute_J(self, m, f=None): for ind, (block, field_deriv) in enumerate( zip(blocks, times_field_derivs[tInd + 1], strict=True) ): - ATinv_df_duT_v[ind] = get_field_deriv_block( + atinv_block_deriv = get_field_deriv_block( self, block, field_deriv, @@ -202,45 +202,43 @@ def compute_J(self, m, f=None): if client: field_derivatives = client.scatter( - ATinv_df_duT_v[ind], workers=self.worker + atinv_block_deriv, workers=self.worker ) else: - field_derivatives = ATinv_df_duT_v[ind] + field_derivatives = atinv_block_deriv - for bb, row in enumerate(block): - if client: - j_row_updates.append( - client.submit( - compute_rows, + if client: + j_row_updates.append( + client.submit( + compute_rows, + sim, + tInd, + block, + field_derivatives, + fields_array, + time_mask, + workers=self.worker, + ) + ) + else: + j_row_updates.append( + array.from_delayed( + delayed_compute_rows( sim, tInd, - row, - bb, + block, field_derivatives, fields_array, time_mask, - workers=self.worker, - ) - ) - else: - j_row_updates.append( - array.from_delayed( - delayed_compute_rows( - sim, - tInd, - row, - bb, - field_derivatives, - fields_array, - time_mask, - ), - dtype=np.float32, - shape=( - np.sum([len(chunk[1][0]) for chunk in block]), - m.size, - ), - ) + ), + dtype=np.float32, + shape=( + np.sum([len(chunk[1][0]) for chunk in block]), + m.size, + ), ) + ) + ATinv_df_duT_v[ind] = atinv_block_deriv if client: j_row_updates = np.vstack(client.gather(j_row_updates)) @@ -499,8 +497,7 @@ def deriv_block(ATinv_df_duT_v, Asubdiag, local_ind, field_derivs): def compute_rows( simulation, tInd, - chunks, - ind, + block, field_derivs, fields, time_mask, @@ -508,40 +505,45 @@ def compute_rows( """ Compute the rows of the sensitivity matrix for a given source and receiver. """ - (address, ind_array) = chunks - # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v): - src = simulation.survey.source_list[address[0]] - time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] - local_ind = np.arange(len(ind_array[0]))[time_check] + rows = [] + for ind, (address, ind_array) in enumerate(block): + # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v): + src = simulation.survey.source_list[address[0]] + time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] + local_ind = np.arange(len(ind_array[0]))[time_check] + + if len(local_ind) < 1: + row_block = np.zeros( + (len(ind_array[1]), simulation.model.size), dtype=np.float32 + ) + rows.append(row_block) + continue - if len(local_ind) < 1: - row_block = np.zeros( - (len(ind_array[1]), simulation.model.size), dtype=np.float32 + dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( + tInd, + fields[:, address[0], tInd], + field_derivs[ind][:, local_ind], + adjoint=True, ) - return row_block - dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( - tInd, - fields[:, address[0], tInd], - field_derivs[ind][:, local_ind], - adjoint=True, - ) - - dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, src, field_derivs[ind][:, local_ind], adjoint=True - ) # on nodes of time mesh + dRHST_dm_v = simulation.getRHSDeriv( + tInd + 1, src, field_derivs[ind][:, local_ind], adjoint=True + ) # on nodes of time mesh - un_src = fields[:, address[0], tInd + 1] - # cell centered on time mesh - dAT_dm_v = simulation.getAdiagDeriv( - tInd, un_src, field_derivs[ind][:, local_ind], adjoint=True - ) - row_block = np.zeros((len(ind_array[1]), simulation.model.size), dtype=np.float32) - row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype( - np.float32 - ) + un_src = fields[:, address[0], tInd + 1] + # cell centered on time mesh + dAT_dm_v = simulation.getAdiagDeriv( + tInd, un_src, field_derivs[ind][:, local_ind], adjoint=True + ) + row_block = np.zeros( + (len(ind_array[1]), simulation.model.size), dtype=np.float32 + ) + row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype( + np.float32 + ) + rows.append(row_block) - return row_block + return np.vstack(rows) def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields): From ec553c36c5898b0f951fa0e017c5eb1aca3314f5 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 4 Sep 2025 15:33:13 -0700 Subject: [PATCH 3/5] Bring back optimizaing blocks --- simpeg/dask/utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/simpeg/dask/utils.py b/simpeg/dask/utils.py index e17c130924..d7100b604a 100644 --- a/simpeg/dask/utils.py +++ b/simpeg/dask/utils.py @@ -76,14 +76,14 @@ def get_parallel_blocks( row_count += chunk_size # # Re-split over cpu_count if too few blocks - # if len(blocks) < thread_count and optimize: - # flatten_blocks = [] - # for block in blocks: - # flatten_blocks += block - # - # chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) - # return [ - # [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0 - # ] + if len(blocks) < thread_count and optimize: + flatten_blocks = [] + for block in blocks: + flatten_blocks += block + + chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) + return [ + [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0 + ] return blocks From 385458039bf2a1f1803466b86fd08def984e1471 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 4 Sep 2025 17:48:02 -0700 Subject: [PATCH 4/5] TRy without scattering derivs --- .../time_domain/simulation.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 1da4616049..ef765b71b0 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -176,8 +176,9 @@ def compute_J(self, m, f=None): delayed_compute_rows = delayed(compute_rows) sim = self for tInd, dt in zip(reversed(range(self.nT)), reversed(self.time_steps)): + AdiagTinv = Ainv[dt] - j_row_updates = [] + future_updates = [] time_mask = data_times > simulation_times[tInd] if not np.any(time_mask): @@ -200,15 +201,15 @@ def compute_J(self, m, f=None): if len(block) == 0: continue - if client: - field_derivatives = client.scatter( - atinv_block_deriv, workers=self.worker - ) - else: - field_derivatives = atinv_block_deriv + # if client: + # field_derivatives = client.scatter( + # atinv_block_deriv, workers=self.worker + # ) + # else: + field_derivatives = atinv_block_deriv if client: - j_row_updates.append( + future_updates.append( client.submit( compute_rows, sim, @@ -221,7 +222,7 @@ def compute_J(self, m, f=None): ) ) else: - j_row_updates.append( + future_updates.append( array.from_delayed( delayed_compute_rows( sim, @@ -241,9 +242,9 @@ def compute_J(self, m, f=None): ATinv_df_duT_v[ind] = atinv_block_deriv if client: - j_row_updates = np.vstack(client.gather(j_row_updates)) + j_row_updates = np.vstack(client.gather(future_updates)) else: - j_row_updates = array.vstack(j_row_updates).compute() + j_row_updates = array.vstack(future_updates).compute() if self.store_sensitivities == "disk": sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" From 9f88d68aea88547501e7675bd68b51506cf6ab55 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 4 Sep 2025 19:03:53 -0700 Subject: [PATCH 5/5] But scatter on large array. Move indexing --- .../time_domain/simulation.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index ef765b71b0..9905a1cc03 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -187,7 +187,7 @@ def compute_J(self, m, f=None): for ind, (block, field_deriv) in enumerate( zip(blocks, times_field_derivs[tInd + 1], strict=True) ): - atinv_block_deriv = get_field_deriv_block( + ATinv_df_duT_v[ind] = get_field_deriv_block( self, block, field_deriv, @@ -198,23 +198,24 @@ def compute_J(self, m, f=None): client, ) + if client: + field_derivatives = client.scatter(ATinv_df_duT_v, workers=self.worker) + else: + field_derivatives = ATinv_df_duT_v + + for block_ind in range(len(blocks)): + if len(block) == 0: continue - # if client: - # field_derivatives = client.scatter( - # atinv_block_deriv, workers=self.worker - # ) - # else: - field_derivatives = atinv_block_deriv - if client: future_updates.append( client.submit( compute_rows, sim, tInd, - block, + block_ind, + blocks, field_derivatives, fields_array, time_mask, @@ -227,7 +228,8 @@ def compute_J(self, m, f=None): delayed_compute_rows( sim, tInd, - block, + block_ind, + blocks, field_derivatives, fields_array, time_mask, @@ -239,7 +241,6 @@ def compute_J(self, m, f=None): ), ) ) - ATinv_df_duT_v[ind] = atinv_block_deriv if client: j_row_updates = np.vstack(client.gather(future_updates)) @@ -498,7 +499,8 @@ def deriv_block(ATinv_df_duT_v, Asubdiag, local_ind, field_derivs): def compute_rows( simulation, tInd, - block, + block_ind, + blocks, field_derivs, fields, time_mask, @@ -507,7 +509,7 @@ def compute_rows( Compute the rows of the sensitivity matrix for a given source and receiver. """ rows = [] - for ind, (address, ind_array) in enumerate(block): + for ind, (address, ind_array) in enumerate(blocks[block_ind]): # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v): src = simulation.survey.source_list[address[0]] time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] @@ -523,18 +525,18 @@ def compute_rows( dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( tInd, fields[:, address[0], tInd], - field_derivs[ind][:, local_ind], + field_derivs[block_ind][ind][:, local_ind], adjoint=True, ) dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, src, field_derivs[ind][:, local_ind], adjoint=True + tInd + 1, src, field_derivs[block_ind][ind][:, local_ind], adjoint=True ) # on nodes of time mesh un_src = fields[:, address[0], tInd + 1] # cell centered on time mesh dAT_dm_v = simulation.getAdiagDeriv( - tInd, un_src, field_derivs[ind][:, local_ind], adjoint=True + tInd, un_src, field_derivs[block_ind][ind][:, local_ind], adjoint=True ) row_block = np.zeros( (len(ind_array[1]), simulation.model.size), dtype=np.float32