diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index fe68e234f1..823840564b 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -11,6 +11,7 @@ from dask import array, delayed from dask.distributed import get_client +from time import time from simpeg.dask.utils import get_parallel_blocks from simpeg.utils import mkvc @@ -405,10 +406,6 @@ def get_field_deriv_block( # Cut out early data time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind] local_ind = np.arange(rx_ind.shape[0])[time_check] - - if len(local_ind) < 1: - continue - indices.append( (np.arange(count, count + len(local_ind)), local_ind), ) @@ -425,14 +422,14 @@ def get_field_deriv_block( stacked_blocks.append(stacked_block) - if len(stacked_blocks) > 0: - blocks = np.hstack(stacked_blocks) - + blocks = np.hstack(stacked_blocks) + if blocks.ndim == 2 and blocks.shape[1] > 0: solve = (AdiagTinv * blocks).reshape(blocks.shape) else: solve = None updated_ATinv_df_duT_v = [] + for (_, arrays), field_deriv, ATinv_chunk, (columns, local_ind) in zip( block, field_derivs, ATinv_df_duT_v, indices, strict=True ): @@ -444,10 +441,9 @@ def get_field_deriv_block( ) ATinv_chunk = np.zeros(shape, dtype=np.float32) - if solve is None: - continue + if solve is not None: + ATinv_chunk[:, local_ind] = solve[:, columns] - ATinv_chunk[:, local_ind] = solve[:, columns] updated_ATinv_df_duT_v.append(ATinv_chunk) return updated_ATinv_df_duT_v