From e87f69bcfb9b2e83ac815652c6f98409d1c25520 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 2 Sep 2025 14:53:56 -0700 Subject: [PATCH] Fix warning with large graph --- .../electromagnetics/time_domain/simulation.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 612896cb03..54c1fbb0e4 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -200,7 +200,12 @@ def compute_J(self, m, f=None): if len(block) == 0: continue - for row, field_derivatives in zip(block, ATinv_df_duT_v[ind]): + field_derivatives = ATinv_df_duT_v[ind] + if client: + field_derivatives = client.scatter( + ATinv_df_duT_v[ind], workers=self.worker + ) + for bb, row in enumerate(block): if client: # field_derivatives = client.scatter( # ATinv_df_duT_v[ind], workers=self.worker @@ -211,6 +216,7 @@ def compute_J(self, m, f=None): sim, tInd, row, + bb, field_derivatives, fields_array, time_mask, @@ -224,6 +230,7 @@ def compute_J(self, m, f=None): sim, tInd, row, + bb, field_derivatives, fields_array, time_mask, @@ -494,6 +501,7 @@ def compute_rows( simulation, tInd, chunks, + ind, field_derivs, fields, time_mask, @@ -516,18 +524,18 @@ def compute_rows( dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( tInd, fields[:, address[0], tInd], - field_derivs[:, local_ind], + field_derivs[ind][:, local_ind], adjoint=True, ) dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, src, field_derivs[:, local_ind], adjoint=True + tInd + 1, src, field_derivs[ind][:, local_ind], adjoint=True ) # on nodes of time mesh un_src = fields[:, address[0], tInd + 1] # cell centered on time mesh dAT_dm_v = simulation.getAdiagDeriv( - tInd, un_src, field_derivs[:, local_ind], adjoint=True + tInd, un_src, field_derivs[ind][:, local_ind], adjoint=True ) row_block = np.zeros((len(ind_array[1]), simulation.model.size), dtype=np.float32) row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype(