From acd4d16e4ae42efa1a8d5db6ef2a2bd862d33cd1 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 24 Apr 2024 15:53:29 -0700 Subject: [PATCH 1/3] Fix for large loop --- .../time_domain/simulation.py | 38 +++++++++++++------ SimPEG/dask/utils.py | 2 +- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index fee66abde2..7d24657fe4 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -11,6 +11,7 @@ from SimPEG.dask.simulation import dask_Jvec, dask_Jtvec, dask_getJtJdiag from SimPEG.dask.utils import get_parallel_blocks +from SimPEG.utils import mkvc import zarr from time import time from tqdm import tqdm @@ -215,6 +216,10 @@ def dask_dpred(self, m=None, f=None, compute_J=False): rows = [] receiver_projection = self.survey.source_list[0].receiver_list[0].projField fields_array = f[:, receiver_projection, :] + + if len(self.survey.source_list) == 1: + fields_array = fields_array[:, np.newaxis, :] + all_receivers = [] for ind, src in enumerate(self.survey.source_list): @@ -283,7 +288,7 @@ def delayed_block_deriv( projection.T, adjoint=True, ) - time_derivs.append(cur[0]) + time_derivs.append(cur[0][:, arrays[0]]) if not isinstance(cur[1], Zero): j_update += cur[1].T @@ -344,15 +349,15 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): @delayed def deriv_block( - s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, sub_ind, field_derivs, tInd + s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, field_derivs, tInd ): if (s_id, r_id, b_id) not in ATinv_df_duT_v: # last timestep (first to be solved) - stacked_block = field_derivs.toarray()[:, sub_ind] + stacked_block = field_derivs.toarray()[:, local_ind] else: stacked_block = np.asarray( - field_derivs[:, sub_ind] + field_derivs[:, local_ind] - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] ) @@ -367,7 +372,8 @@ def update_deriv_blocks(address, indices, derivatives, solve, shape): if address in indices: columns, local_ind = indices[address] - deriv_array[:, local_ind] = solve[:, columns] + if solve is not None: + deriv_array[:, local_ind] = solve[:, columns] derivatives[address] = deriv_array @@ -397,17 +403,16 @@ def get_field_deriv_block( ): # Cut out early data time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind] - sub_ind = rx_ind[time_check] local_ind = np.arange(rx_ind.shape[0])[time_check] - if len(sub_ind) < 1: + if len(local_ind) < 1: continue indices[(s_id, r_id, b_id)] = ( - np.arange(count, count + len(sub_ind)), + np.arange(count, count + len(local_ind)), local_ind, ) - count += len(sub_ind) + count += len(local_ind) deriv_comp = deriv_block( s_id, r_id, @@ -415,11 +420,9 @@ def get_field_deriv_block( ATinv_df_duT_v, Asubdiag, local_ind, - sub_ind, field_deriv, tInd, ) - stacked_blocks.append( array.from_delayed( deriv_comp, @@ -469,7 +472,11 @@ def compute_rows( local_ind = np.arange(len(ind_array[0]))[time_check] if len(local_ind) < 1: - return + row_block = np.zeros( + (len(ind_array[1]), simulation.model.size), dtype=np.float32 + ) + rows.append(row_block) + continue field_derivs = ATinv_df_duT_v[address] dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( @@ -530,6 +537,10 @@ def compute_J(self, f=None, Ainv=None): self.survey.source_list, self.model.shape[0], self.max_chunk_size ) fields_array = f[:, ftype, :] + + if len(self.survey.source_list) == 1: + fields_array = fields_array[:, np.newaxis, :] + times_field_derivs, Jmatrix = compute_field_derivs( self, f, blocks, Jmatrix, fields_array.shape ) @@ -540,6 +551,9 @@ def compute_J(self, f=None, Ainv=None): j_row_updates = [] time_mask = data_times > simulation_times[tInd] + if not np.any(time_mask): + continue + tc = time() for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): ATinv_df_duT_v = get_field_deriv_block( diff --git a/SimPEG/dask/utils.py b/SimPEG/dask/utils.py index c287bc4ad6..558fb08aa3 100644 --- a/SimPEG/dask/utils.py +++ b/SimPEG/dask/utils.py @@ -68,7 +68,7 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int) -> if (row_count + chunk_size) > (data_block_size * cpu_count()): row_count = 0 block_count += 1 - blocks.append = [] + blocks.append([]) blocks[block_count].append( ( From e4dc6e73042dbb592924b0e44b58ad585c36daa7 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 25 Apr 2024 10:26:37 -0700 Subject: [PATCH 2/3] Add check for coo matrix --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 7d24657fe4..94344f098b 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -288,6 +288,11 @@ def delayed_block_deriv( projection.T, adjoint=True, ) + + derivatives = cur[0] + if isinstance(derivatives, sp.coo_array): + derivatives = derivatives.tocsr() + time_derivs.append(cur[0][:, arrays[0]]) if not isinstance(cur[1], Zero): From ce4439698d1bf6243a943f72153f64a8684851b7 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 25 Apr 2024 10:54:30 -0700 Subject: [PATCH 3/3] Simplify logic. Bump version --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 6 +----- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 94344f098b..7307531eb8 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -280,7 +280,7 @@ def delayed_block_deriv( j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) continue - projection = sp.kron(timeP[:, time_index], spatialP) + projection = sp.kron(timeP[:, time_index], spatialP, format="csr") cur = derivative_fun( time_index, source, @@ -289,10 +289,6 @@ def delayed_block_deriv( adjoint=True, ) - derivatives = cur[0] - if isinstance(derivatives, sp.coo_array): - derivatives = derivatives.tocsr() - time_derivs.append(cur[0][:, arrays[0]]) if not isinstance(cur[1], Zero): diff --git a/pyproject.toml b/pyproject.toml index 09c0693eff..87c0040a2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ [tool.poetry] name = "Mira-SimPEG" -version = "0.19.0.dev7" +version = "0.19.0.dev8" license = "MIT" description = "Mira Geoscience fork of SimPEG: Simulation and Parameter Estimation in Geophysics"