From 271ac298269307dfd5df2563130c17a5b4ec2664 Mon Sep 17 00:00:00 2001 From: domfournier Date: Sat, 30 Mar 2024 09:31:08 -0700 Subject: [PATCH 01/33] block task on compute_rows --- .../time_domain/simulation.py | 63 ++++++++++--------- 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 14eab5c957..f9d9c037ea 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -321,7 +321,7 @@ def get_field_deriv_block( def compute_rows( simulation, tInd, - address, # (s_id, r_id, b_id) + addresses, # (s_id, r_id, b_id) indices, # (rx_ind, j_ind), ATinv_df_duT_v, fields, @@ -332,34 +332,37 @@ def compute_rows( """ Compute the rows of the sensitivity matrix for a given source and receiver. """ - src = simulation.survey.source_list[address[0]] - rx = src.receiver_list[address[1]] - time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[ - indices[0] - ] - local_ind = np.arange(indices[0].shape[0])[time_check] - - if len(local_ind) < 1: - return - - dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( - tInd, - fields[src, ftype, tInd], - ATinv_df_duT_v[address][:, local_ind], - adjoint=True, - ) + for address, ind_array in zip(addresses, indices): + src = simulation.survey.source_list[address[0]] + rx = src.receiver_list[address[1]] + time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[ + ind_array[0] + ] + local_ind = np.arange(ind_array[0].shape[0])[time_check] - dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, src, ATinv_df_duT_v[address][:, local_ind], adjoint=True - ) # on nodes of time mesh + if len(local_ind) < 1: + return - un_src = fields[src, ftype, tInd + 1] - # cell centered on time mesh - dAT_dm_v = simulation.getAdiagDeriv( - tInd, un_src, ATinv_df_duT_v[address][:, local_ind], adjoint=True - ) + dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( + tInd, + fields[src, ftype, tInd], + ATinv_df_duT_v[address][:, local_ind], + adjoint=True, + ) + + dRHST_dm_v = simulation.getRHSDeriv( + tInd + 1, src, ATinv_df_duT_v[address][:, local_ind], adjoint=True + ) # on nodes of time mesh - Jmatrix[indices[1][time_check], :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T + un_src = fields[src, ftype, tInd + 1] + # cell centered on time mesh + dAT_dm_v = simulation.getAdiagDeriv( + tInd, un_src, ATinv_df_duT_v[address][:, local_ind], adjoint=True + ) + + Jmatrix[ind_array[1][time_check], :] += ( + -dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v + ).T def compute_J(self, f=None, Ainv=None): @@ -388,14 +391,14 @@ def compute_J(self, f=None, Ainv=None): ATinv_df_duT_v = get_field_deriv_block( self, block, tInd, AdiagTinv, ATinv_df_duT_v, time_mask ) - - for address, indices in block.items(): + split_blocks = np.array_split(list(block.items()), cpu_count()) + for arrays in split_blocks: j_row_updates.append( compute_rows( self, tInd, - address, - indices, + arrays[:, 0], + arrays[:, 1], ATinv_df_duT_v, f, Jmatrix, From 71d620bf94bbe9db95183e7efc66f500d47bd42f Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 1 Apr 2024 16:50:20 -0700 Subject: [PATCH 02/33] New implementation --- .../time_domain/simulation.py | 211 +++++++++++------- SimPEG/dask/utils.py | 31 ++- 2 files changed, 160 insertions(+), 82 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index f9d9c037ea..b9d657b010 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -157,74 +157,127 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): @delayed -def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jmatrix): +def delayed_block_deriv( + time_index, chunks, field_type, source_list, mesh, time_mesh, fields, shape +): """Compute derivatives for sources and receivers in a block""" field_len = len(fields[source_list[0], field_type, 0]) df_duT = [] - rx_count = 0 - for source in source_list: - sources_block = [] - - for rx in source.receiver_list: - PTv = rx.getP(mesh, time_mesh, fields).tocsr() - derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) - rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) - cur = derivative_fun( - time_index, - source, - None, - PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, - adjoint=True, + j_update = [] + # rx_count = 0 + for indices, arrays in chunks: + # for source in source_list: + source = source_list[indices[0]] + receiver = source.receiver_list[indices[1]] + PTv = receiver.getP(mesh, time_mesh, fields).tocsr() + derivative_fun = getattr(fields, "_{}Deriv".format(receiver.projField), None) + # rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) + cur = derivative_fun( + time_index, + source, + None, + PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, + adjoint=True, + ) + df_duT.append(cur[0]) + + if isinstance(cur[1], Zero): + j_update.append( + sp.csc_matrix((arrays[0].shape[0], shape), dtype=np.float32) ) - sources_block.append(cur[0]) + else: + j_update.append(cur[1].T) + + # rx_count += rx.nD + + # df_duT.append(sources_block) - if not isinstance(cur[1], Zero): - Jmatrix[rx_ind, :] += cur[1].T + return df_duT, sp.vstack(j_update) - rx_count += rx.nD - df_duT.append(sources_block) +@delayed +def sens_update(delayed_chunks_tuple, shape): + arrays = [] + for chunk in delayed_chunks_tuple: + if not isinstance(chunk[1], Zero): + arrays.append(chunk[1].T) + else: + arrays.append( + sp.csc_matrix((chunk[1][0].shape[0], shape), dtype=np.float32) + ) + + return np.vstack(arrays) + + +@delayed +def deriv_update(delayed_chunks_tuple): + arrays = [] + for chunk in delayed_chunks_tuple: + arrays.append(chunk[0]) - return df_duT + return arrays -def compute_field_derivs(simulation, Jmatrix, fields): +def compute_field_derivs(simulation, fields, blocks, Jmatrix): """ Compute the derivative of the fields """ df_duT = [] for time_index in range(simulation.nT + 1): - df_duT.append( - block_deriv( + block_derivs = [] + j_updates = [] + delayed_chunks = [] + for chunks in blocks: + if len(chunks) == 0: + continue + + delayed_block = delayed_block_deriv( time_index, + chunks, simulation._fieldType + "Solution", simulation.survey.source_list, simulation.mesh, simulation.time_mesh, fields, - Jmatrix, + simulation.model.size, ) - ) + delayed_chunks.append(delayed_block) + + block_derivs = dask.compute(delayed_chunks)[0] + j_updates = sp.vstack([item[1] for item in block_derivs], dtype=np.float32) + df_duT.append([item[0] for item in block_derivs]) + # chunk_shape = np.sum(chunk[1][0].shape[0] for chunk in chunks), simulation.model.size + # j_updates.append(array.from_delayed( + # sens_update(delayed_block, simulation.model.size), + # shape=chunk_shape, dtype=np.float32 + # )) + # block_derivs.append(deriv_update(delayed_block)) + # delayed_chunks = dask.compute(delayed_chunks)[0] + # for delayed_chunk in delayed_chunks: + + # df_duT.append(block_derivs) + # + + # update = dask.compute(array.vstack(j_updates))[0] + Jmatrix = Jmatrix + j_updates - df_duT = dask.compute(df_duT)[0] + # df_duT = dask.compute(df_duT)[0] - return df_duT + return df_duT, Jmatrix @delayed def deriv_block( - s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, sub_ind, simulation, tInd + s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, sub_ind, field_derivs, tInd ): if (s_id, r_id, b_id) not in ATinv_df_duT_v: # last timestep (first to be solved) - stacked_block = simulation.field_derivs[tInd + 1][s_id][r_id].toarray()[ - :, sub_ind - ] + stacked_block = field_derivs.toarray()[:, sub_ind] else: stacked_block = np.asarray( - simulation.field_derivs[tInd + 1][s_id][r_id][:, sub_ind] + field_derivs[:, sub_ind] - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] ) @@ -245,7 +298,13 @@ def update_deriv_blocks(address, tInd, indices, derivatives, solve, shape): def get_field_deriv_block( - simulation, block: dict, tInd: int, AdiagTinv, ATinv_df_duT_v: dict, time_mask + simulation, + block: list, + field_derivs: list, + tInd: int, + AdiagTinv, + ATinv_df_duT_v: dict, + time_mask, ): """ Stack the blocks of field derivatives for a given timestep and call the direct solver. @@ -258,12 +317,11 @@ def get_field_deriv_block( if tInd < simulation.nT - 1: Asubdiag = simulation.getAsubdiag(tInd + 1) - for (s_id, r_id, b_id), (rx_ind, j_ind) in block.items(): + for ((s_id, r_id, b_id), (rx_ind, j_ind, shape)), field_deriv in zip( + block, field_derivs + ): # Cut out early data - rx = simulation.survey.source_list[s_id].receiver_list[r_id] - time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[ - rx_ind - ] + time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind] sub_ind = rx_ind[time_check] local_ind = np.arange(rx_ind.shape[0])[time_check] @@ -283,7 +341,7 @@ def get_field_deriv_block( Asubdiag, local_ind, sub_ind, - simulation, + field_deriv, tInd, ) @@ -292,22 +350,22 @@ def get_field_deriv_block( deriv_comp, dtype=float, shape=( - simulation.field_derivs[tInd][s_id][r_id].shape[0], + field_deriv.shape[0], len(local_ind), ), ) ) if len(stacked_blocks) > 0: blocks = array.hstack(stacked_blocks).compute() - solve = AdiagTinv * blocks + solve = (AdiagTinv * blocks).reshape(blocks.shape) else: solve = None update_list = [] - for address in block: + for (address, arrays), field_deriv in zip(block, field_derivs): shape = ( - simulation.field_derivs[tInd][address[0]][address[1]].shape[0], - len(block[address][0]), + field_deriv.shape[0], + len(arrays[0]), ) update_list.append( update_deriv_blocks(address, tInd, indices, ATinv_df_duT_v, solve, shape) @@ -321,31 +379,27 @@ def get_field_deriv_block( def compute_rows( simulation, tInd, - addresses, # (s_id, r_id, b_id) - indices, # (rx_ind, j_ind), + chunks, ATinv_df_duT_v, fields, - Jmatrix, - ftype, time_mask, ): """ Compute the rows of the sensitivity matrix for a given source and receiver. """ - for address, ind_array in zip(addresses, indices): + n_rows = np.sum(len(chunk[1][0]) for chunk in chunks) + rows = np.zeros((n_rows, simulation.model.size), dtype=np.float32) + for address, ind_array in chunks: src = simulation.survey.source_list[address[0]] - rx = src.receiver_list[address[1]] - time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[ - ind_array[0] - ] - local_ind = np.arange(ind_array[0].shape[0])[time_check] + time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] + local_ind = np.arange(len(ind_array[0]))[time_check] if len(local_ind) < 1: return dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( tInd, - fields[src, ftype, tInd], + fields[:, address[0], tInd], ATinv_df_duT_v[address][:, local_ind], adjoint=True, ) @@ -354,15 +408,15 @@ def compute_rows( tInd + 1, src, ATinv_df_duT_v[address][:, local_ind], adjoint=True ) # on nodes of time mesh - un_src = fields[src, ftype, tInd + 1] + un_src = fields[:, address[0], tInd + 1] # cell centered on time mesh dAT_dm_v = simulation.getAdiagDeriv( tInd, un_src, ATinv_df_duT_v[address][:, local_ind], adjoint=True ) - Jmatrix[ind_array[1][time_check], :] += ( - -dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v - ).T + rows[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T + + return rows def compute_J(self, f=None, Ainv=None): @@ -374,39 +428,46 @@ def compute_J(self, f=None, Ainv=None): f, Ainv = self.fields(self.model, return_Ainv=True) ftype = self._fieldType + "Solution" - Jmatrix = delayed(np.zeros((self.survey.nD, self.model.size), dtype=np.float32)) + Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float32) simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 data_times = self.survey.source_list[0].receiver_list[0].times blocks = get_parallel_blocks( self.survey.source_list, self.model.shape[0], self.max_chunk_size ) - self.field_derivs = compute_field_derivs(self, Jmatrix, f) + times_field_derivs, Jmatrix = compute_field_derivs(self, f, blocks, Jmatrix) + fields_array = delayed(f[:, ftype, :]) ATinv_df_duT_v = {} for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): AdiagTinv = Ainv[dt] j_row_updates = [] time_mask = data_times > simulation_times[tInd] - for block in blocks.values(): + for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): ATinv_df_duT_v = get_field_deriv_block( - self, block, tInd, AdiagTinv, ATinv_df_duT_v, time_mask + self, block, field_deriv, tInd, AdiagTinv, ATinv_df_duT_v, time_mask ) - split_blocks = np.array_split(list(block.items()), cpu_count()) - for arrays in split_blocks: - j_row_updates.append( + + if len(block) == 0: + continue + n_rows = np.sum(len(chunk[1][0]) for chunk in block) + j_row_updates.append( + array.from_delayed( compute_rows( self, tInd, - arrays[:, 0], - arrays[:, 1], + block, ATinv_df_duT_v, - f, - Jmatrix, - ftype, + fields_array, time_mask, - ) + ), + dtype=np.float32, + shape=(n_rows, self.model.size), ) - dask.compute(j_row_updates) + ) + + update = dask.compute(array.vstack(j_row_updates))[0] + Jmatrix += update + for A in Ainv.values(): A.clean() @@ -414,7 +475,7 @@ def compute_J(self, f=None, Ainv=None): del Jmatrix return array.from_zarr(self.sensitivity_path + f"J.zarr") else: - return Jmatrix.compute() + return np.asarray(Jmatrix) Sim.compute_J = compute_J diff --git a/SimPEG/dask/utils.py b/SimPEG/dask/utils.py index 037961c1ab..c287bc4ad6 100644 --- a/SimPEG/dask/utils.py +++ b/SimPEG/dask/utils.py @@ -40,18 +40,20 @@ def compute(self, job): return job.compute() -def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): +def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int) -> list: """ Get the blocks of sources and receivers to be computed in parallel. - Stored as a dictionary of source, receiver pairs index. The value is an array of indices + Stored as a list of tuples for + (source, receiver, block index) and array of indices for the rows of the sensitivity matrix. """ data_block_size = np.ceil(max_chunk_size / (m_size * 8.0 * 1e-6)) row_count = 0 row_index = 0 block_count = 0 - blocks = {0: {}} + blocks = [[]] + for s_id, src in enumerate(source_list): for r_id, rx in enumerate(src.receiver_list): indices = np.arange(rx.nD).astype(int) @@ -66,12 +68,27 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): if (row_count + chunk_size) > (data_block_size * cpu_count()): row_count = 0 block_count += 1 - blocks[block_count] = {} + blocks.append = [] - blocks[block_count][(s_id, r_id, ind)] = chunk, np.arange( - row_index, row_index + chunk_size - ).astype(int) + blocks[block_count].append( + ( + (s_id, r_id, ind), + ( + chunk, + np.arange(row_index, row_index + chunk_size).astype(int), + rx.locations.shape[0], + ), + ) + ) row_index += chunk_size row_count += chunk_size + # Re-split over cpu_count if too few blocks + if len(blocks) < cpu_count(): + flatten_blocks = [] + for block in blocks: + flatten_blocks += block + + chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) + return [[flatten_blocks[i] for i in chunk] for chunk in chunks] return blocks From e09be83ee022488c45592975e263f0fdb4760063 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 2 Apr 2024 15:35:31 -0700 Subject: [PATCH 03/33] Allow disk storage of TEM sens --- .../time_domain/simulation.py | 138 +++++++++--------- 1 file changed, 67 insertions(+), 71 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index b9d657b010..1e9e0972a7 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -1,6 +1,6 @@ import dask import dask.array - +import os from ....electromagnetics.time_domain.simulation import BaseTDEMSimulation as Sim from ....utils import Zero from multiprocessing import cpu_count @@ -138,7 +138,7 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): rows.append( array.from_delayed( row(src, rx, self.mesh, self.time_mesh, f), - dtype=np.float32, + dtype=np.float64, shape=(rx.nD,), ) ) @@ -158,20 +158,18 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): @delayed def delayed_block_deriv( - time_index, chunks, field_type, source_list, mesh, time_mesh, fields, shape + time_index, chunks, field_type, source_list, mesh, time_mesh, fields, shape, Jmatrix ): """Compute derivatives for sources and receivers in a block""" field_len = len(fields[source_list[0], field_type, 0]) df_duT = [] j_update = [] - # rx_count = 0 + for indices, arrays in chunks: - # for source in source_list: source = source_list[indices[0]] receiver = source.receiver_list[indices[1]] PTv = receiver.getP(mesh, time_mesh, fields).tocsr() derivative_fun = getattr(fields, "_{}Deriv".format(receiver.projField), None) - # rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) cur = derivative_fun( time_index, source, @@ -181,41 +179,14 @@ def delayed_block_deriv( ) df_duT.append(cur[0]) - if isinstance(cur[1], Zero): - j_update.append( - sp.csc_matrix((arrays[0].shape[0], shape), dtype=np.float32) - ) - else: + if not isinstance(cur[1], Zero) and not len(cur[1].data) == 0: j_update.append(cur[1].T) - - # rx_count += rx.nD - - # df_duT.append(sources_block) - - return df_duT, sp.vstack(j_update) - - -@delayed -def sens_update(delayed_chunks_tuple, shape): - arrays = [] - for chunk in delayed_chunks_tuple: - if not isinstance(chunk[1], Zero): - arrays.append(chunk[1].T) else: - arrays.append( - sp.csc_matrix((chunk[1][0].shape[0], shape), dtype=np.float32) + j_update.append( + sp.csr_matrix((arrays[0].shape[0], Jmatrix.shape[1]), dtype=np.float32) ) - return np.vstack(arrays) - - -@delayed -def deriv_update(delayed_chunks_tuple): - arrays = [] - for chunk in delayed_chunks_tuple: - arrays.append(chunk[0]) - - return arrays + return df_duT, j_update def compute_field_derivs(simulation, fields, blocks, Jmatrix): @@ -241,28 +212,23 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): simulation.time_mesh, fields, simulation.model.size, + Jmatrix, ) delayed_chunks.append(delayed_block) - block_derivs = dask.compute(delayed_chunks)[0] - j_updates = sp.vstack([item[1] for item in block_derivs], dtype=np.float32) - df_duT.append([item[0] for item in block_derivs]) - # chunk_shape = np.sum(chunk[1][0].shape[0] for chunk in chunks), simulation.model.size - # j_updates.append(array.from_delayed( - # sens_update(delayed_block, simulation.model.size), - # shape=chunk_shape, dtype=np.float32 - # )) - # block_derivs.append(deriv_update(delayed_block)) - # delayed_chunks = dask.compute(delayed_chunks)[0] - # for delayed_chunk in delayed_chunks: + for chunk in dask.compute(delayed_chunks)[0]: + block_derivs += chunk[0] + j_updates += chunk[1] - # df_duT.append(block_derivs) - # - - # update = dask.compute(array.vstack(j_updates))[0] - Jmatrix = Jmatrix + j_updates + Jmatrix += sp.vstack(j_updates) + if simulation.store_sensitivities == "disk": + sens_name = simulation.sensitivity_path[:-5] + f"_{time_index % 2}.zarr" + array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) + Jmatrix = array.from_zarr(sens_name) + else: + dask.compute(Jmatrix) - # df_duT = dask.compute(df_duT)[0] + df_duT.append(block_derivs) return df_duT, Jmatrix @@ -388,7 +354,8 @@ def compute_rows( Compute the rows of the sensitivity matrix for a given source and receiver. """ n_rows = np.sum(len(chunk[1][0]) for chunk in chunks) - rows = np.zeros((n_rows, simulation.model.size), dtype=np.float32) + rows = [] + for address, ind_array in chunks: src = simulation.survey.source_list[address[0]] time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] @@ -397,26 +364,35 @@ def compute_rows( if len(local_ind) < 1: return + field_derivs = ATinv_df_duT_v[address] dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( tInd, fields[:, address[0], tInd], - ATinv_df_duT_v[address][:, local_ind], + field_derivs[:, local_ind], adjoint=True, ) dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, src, ATinv_df_duT_v[address][:, local_ind], adjoint=True + tInd + 1, src, field_derivs[:, local_ind], adjoint=True ) # on nodes of time mesh un_src = fields[:, address[0], tInd + 1] # cell centered on time mesh dAT_dm_v = simulation.getAdiagDeriv( - tInd, un_src, ATinv_df_duT_v[address][:, local_ind], adjoint=True + tInd, un_src, field_derivs[:, local_ind], adjoint=True ) + # if isinstance(Jmatrix, zarr.core.Array): + # Jmatrix.oindex[ind_array[1][time_check].tolist(), :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype(np.float64) + # else: + row_block = np.zeros( + (len(ind_array[1]), simulation.model.size), dtype=np.float32 + ) + row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype( + np.float32 + ) + rows.append(row_block) - rows[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T - - return rows + return np.vstack(rows) def compute_J(self, f=None, Ainv=None): @@ -428,7 +404,23 @@ def compute_J(self, f=None, Ainv=None): f, Ainv = self.fields(self.model, return_Ainv=True) ftype = self._fieldType + "Solution" - Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float32) + sens_name = self.sensitivity_path[:-5] + if self.store_sensitivities == "disk": + rows = array.zeros( + (self.survey.nD, self.model.size), + chunks=(self.max_chunk_size, self.model.size), + dtype=np.float32, + ) + Jmatrix = array.to_zarr( + rows, + os.path.join(sens_name + "_1.zarr"), + compute=True, + return_stored=True, + overwrite=True, + ) + else: + Jmatrix = array.zeros((self.survey.nD, self.model.size), dtype=np.float64) + simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 data_times = self.survey.source_list[0].receiver_list[0].times blocks = get_parallel_blocks( @@ -449,7 +441,7 @@ def compute_J(self, f=None, Ainv=None): if len(block) == 0: continue - n_rows = np.sum(len(chunk[1][0]) for chunk in block) + j_row_updates.append( array.from_delayed( compute_rows( @@ -461,21 +453,25 @@ def compute_J(self, f=None, Ainv=None): time_mask, ), dtype=np.float32, - shape=(n_rows, self.model.size), + shape=( + np.sum(len(chunk[1][0]) for chunk in block), + self.model.size, + ), ) ) - update = dask.compute(array.vstack(j_row_updates))[0] - Jmatrix += update + Jmatrix += array.vstack(j_row_updates) + if self.store_sensitivities == "disk": + sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" + array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) + Jmatrix = array.from_zarr(sens_name) + else: + dask.compute(Jmatrix) for A in Ainv.values(): A.clean() - if self.store_sensitivities == "disk": - del Jmatrix - return array.from_zarr(self.sensitivity_path + f"J.zarr") - else: - return np.asarray(Jmatrix) + return Jmatrix Sim.compute_J = compute_J From 1750a32b7db55174328f1689f5c4680b4a62c310 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 2 Apr 2024 15:45:24 -0700 Subject: [PATCH 04/33] Fix issue --- .../time_domain/simulation.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 1e9e0972a7..d5bc40ae84 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -179,7 +179,7 @@ def delayed_block_deriv( ) df_duT.append(cur[0]) - if not isinstance(cur[1], Zero) and not len(cur[1].data) == 0: + if not isinstance(cur[1], Zero): j_update.append(cur[1].T) else: j_update.append( @@ -197,7 +197,7 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): for time_index in range(simulation.nT + 1): block_derivs = [] - j_updates = [] + block_updates = [] delayed_chunks = [] for chunks in blocks: if len(chunks) == 0: @@ -217,16 +217,19 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): delayed_chunks.append(delayed_block) for chunk in dask.compute(delayed_chunks)[0]: - block_derivs += chunk[0] - j_updates += chunk[1] - - Jmatrix += sp.vstack(j_updates) - if simulation.store_sensitivities == "disk": - sens_name = simulation.sensitivity_path[:-5] + f"_{time_index % 2}.zarr" - array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) - Jmatrix = array.from_zarr(sens_name) - else: - dask.compute(Jmatrix) + block_derivs.append(chunk[0]) + block_updates += chunk[1] + + j_updates = sp.vstack(block_updates) + + if len(j_updates.data) > 0: + Jmatrix += sp.vstack(j_updates) + if simulation.store_sensitivities == "disk": + sens_name = simulation.sensitivity_path[:-5] + f"_{time_index % 2}.zarr" + array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) + Jmatrix = array.from_zarr(sens_name) + else: + dask.compute(Jmatrix) df_duT.append(block_derivs) From 30c78af15da3d058fb3d382259a51becb267ca7c Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 2 Apr 2024 16:22:30 -0700 Subject: [PATCH 05/33] Use numpy array if in RAM --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index d5bc40ae84..8b2feaee3c 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -422,7 +422,7 @@ def compute_J(self, f=None, Ainv=None): overwrite=True, ) else: - Jmatrix = array.zeros((self.survey.nD, self.model.size), dtype=np.float64) + Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float64) simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 data_times = self.survey.source_list[0].receiver_list[0].times @@ -463,7 +463,7 @@ def compute_J(self, f=None, Ainv=None): ) ) - Jmatrix += array.vstack(j_row_updates) + Jmatrix = Jmatrix + array.vstack(j_row_updates) if self.store_sensitivities == "disk": sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) @@ -474,6 +474,9 @@ def compute_J(self, f=None, Ainv=None): for A in Ainv.values(): A.clean() + if self.store_sensitivities == "ram": + return Jmatrix.compute() + return Jmatrix From 9e2a98c4b653c314a32c8756dafb83ff1b63c10c Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 2 Apr 2024 16:38:40 -0700 Subject: [PATCH 06/33] Add some prints --- .../time_domain/simulation.py | 4 +- SimPEG/dask/simulation.py | 38 ++++++++++--------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 8b2feaee3c..4f703c658c 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -195,7 +195,7 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): """ df_duT = [] - for time_index in range(simulation.nT + 1): + for time_index in tqdm(range(simulation.nT + 1)): block_derivs = [] block_updates = [] delayed_chunks = [] @@ -429,7 +429,9 @@ def compute_J(self, f=None, Ainv=None): blocks = get_parallel_blocks( self.survey.source_list, self.model.shape[0], self.max_chunk_size ) + print("Computing field derivatives") times_field_derivs, Jmatrix = compute_field_derivs(self, f, blocks, Jmatrix) + print("Done") fields_array = delayed(f[:, ftype, :]) ATinv_df_duT_v = {} for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): diff --git a/SimPEG/dask/simulation.py b/SimPEG/dask/simulation.py index b60cb47d3f..1cde16c857 100644 --- a/SimPEG/dask/simulation.py +++ b/SimPEG/dask/simulation.py @@ -9,6 +9,7 @@ Sim._max_ram = 16 + @property def max_ram(self): "Maximum ram in (Gb)" @@ -62,7 +63,7 @@ def n_cpu(self, other): def make_synthetic_data( - self, m, relative_error=0.05, noise_floor=0.0, f=None, add_noise=False, **kwargs + self, m, relative_error=0.05, noise_floor=0.0, f=None, add_noise=False, **kwargs ): """ Make synthetic data given a model, and a standard deviation. @@ -106,12 +107,13 @@ def make_synthetic_data( noise_floor=noise_floor, ) + Sim.make_synthetic_data = make_synthetic_data @property def workers(self): - if getattr(self, '_workers', None) is None: + if getattr(self, "_workers", None) is None: self._workers = None return self._workers @@ -127,7 +129,7 @@ def workers(self, workers): def dask_Jvec(self, m, v): """ - Compute sensitivity matrix (J) and vector (v) product. + Compute sensitivity matrix (J) and vector (v) product. """ self.model = m @@ -145,7 +147,7 @@ def dask_Jvec(self, m, v): def dask_Jtvec(self, m, v): """ - Compute adjoint sensitivity matrix (J^T) and vector (v) product. + Compute adjoint sensitivity matrix (J^T) and vector (v) product. """ self.model = m @@ -174,13 +176,11 @@ def Jmatrix(self): if self.store_sensitivities == "ram": self._Jmatrix = client.persist( - delayed(self.compute_J)(), - workers=self.workers + delayed(self.compute_J)(), workers=self.workers ) else: self._Jmatrix = client.compute( - delayed(self.compute_J)(), - workers=self.workers + delayed(self.compute_J)(), workers=self.workers ) elif isinstance(self._Jmatrix, Future): @@ -226,14 +226,16 @@ def evaluate_receiver(source, receiver, mesh, fields): rows = [] for src in self.survey.source_list: for rx in src.receiver_list: - rows.append(array.from_delayed( - row(src, rx, self.mesh, f), - dtype=np.float32, - shape=(rx.nD,), - )) - + rows.append( + array.from_delayed( + row(src, rx, self.mesh, f), + dtype=np.float32, + shape=(rx.nD,), + ) + ) + print("Computing data") data = array.hstack(rows).compute() - + print("Computing data done") if compute_J and self._Jmatrix is None: Jmatrix = self.compute_J(f=f, Ainv=Ainv) return data, Jmatrix @@ -246,7 +248,7 @@ def evaluate_receiver(source, receiver, mesh, fields): def dask_getJtJdiag(self, m, W=None): """ - Return the diagonal of JtJ + Return the diagonal of JtJ """ self.model = m if getattr(self, "_jtjdiag", None) is None: @@ -258,11 +260,11 @@ def dask_getJtJdiag(self, m, W=None): else: W = W.diagonal() ** 2.0 - diag = array.einsum('i,ij,ij->j', W, self.Jmatrix, self.Jmatrix) + diag = array.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) if isinstance(diag, array.Array): diag = np.asarray(diag.compute()) self._jtjdiag = diag - return self._jtjdiag \ No newline at end of file + return self._jtjdiag From 635225d507ca89bd9130f77882361f11b38bb2a5 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 2 Apr 2024 16:51:14 -0700 Subject: [PATCH 07/33] Move prints to TEM class --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 9 ++++++--- SimPEG/dask/simulation.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 4f703c658c..a5bab46810 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -7,6 +7,7 @@ import numpy as np import scipy.sparse as sp from dask import array, delayed +from dask.diagnostics import ProgressBar from SimPEG.dask.simulation import dask_Jvec, dask_Jtvec, dask_getJtJdiag from SimPEG.dask.utils import get_parallel_blocks import zarr @@ -31,7 +32,7 @@ def fields(self, m=None, return_Ainv=False): Ainv = {} ATinv = {} - for tInd, dt in enumerate(self.time_steps): + for tInd, dt in tqdm(enumerate(self.time_steps)): if dt not in Ainv: A = self.getAdiag(tInd) Ainv[dt] = self.solver(sp.csr_matrix(A), **self.solver_opts) @@ -133,7 +134,8 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): row = delayed(evaluate_receiver, pure=True) rows = [] - for src in self.survey.source_list: + + for src in tqdm(self.survey.source_list): for rx in src.receiver_list: rows.append( array.from_delayed( @@ -143,7 +145,8 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): ) ) - data = array.hstack(rows).compute() + with ProgressBar(): + data = array.hstack(rows).compute() if compute_J and self._Jmatrix is None: Jmatrix = self.compute_J(f=f, Ainv=Ainv) diff --git a/SimPEG/dask/simulation.py b/SimPEG/dask/simulation.py index 1cde16c857..06b69a0ae9 100644 --- a/SimPEG/dask/simulation.py +++ b/SimPEG/dask/simulation.py @@ -233,9 +233,9 @@ def evaluate_receiver(source, receiver, mesh, fields): shape=(rx.nD,), ) ) - print("Computing data") + data = array.hstack(rows).compute() - print("Computing data done") + if compute_J and self._Jmatrix is None: Jmatrix = self.compute_J(f=f, Ainv=Ainv) return data, Jmatrix From caa63a30cbb465c965ef1be63b9d59d7efe76d81 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 2 Apr 2024 16:58:38 -0700 Subject: [PATCH 08/33] More progress bars for profiling --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index a5bab46810..c2aef80cd3 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -202,6 +202,7 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): block_derivs = [] block_updates = [] delayed_chunks = [] + print(len(blocks), len(blocks[0])) for chunks in blocks: if len(chunks) == 0: continue @@ -219,7 +220,10 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): ) delayed_chunks.append(delayed_block) - for chunk in dask.compute(delayed_chunks)[0]: + with ProgressBar(): + result = dask.compute(delayed_chunks) + + for chunk in result[0]: block_derivs.append(chunk[0]) block_updates += chunk[1] From b5c89d98ec4e019e0e8bd2513011a2ee464e28b2 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 2 Apr 2024 21:16:27 -0700 Subject: [PATCH 09/33] Move prints --- .../electromagnetics/time_domain/simulation.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index c2aef80cd3..a548e14399 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -198,11 +198,11 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): """ df_duT = [] - for time_index in tqdm(range(simulation.nT + 1)): + for time_index in range(simulation.nT + 1): block_derivs = [] block_updates = [] delayed_chunks = [] - print(len(blocks), len(blocks[0])) + for chunks in blocks: if len(chunks) == 0: continue @@ -220,8 +220,7 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): ) delayed_chunks.append(delayed_block) - with ProgressBar(): - result = dask.compute(delayed_chunks) + result = dask.compute(delayed_chunks) for chunk in result[0]: block_derivs.append(chunk[0]) @@ -436,9 +435,10 @@ def compute_J(self, f=None, Ainv=None): blocks = get_parallel_blocks( self.survey.source_list, self.model.shape[0], self.max_chunk_size ) + tc = time() print("Computing field derivatives") times_field_derivs, Jmatrix = compute_field_derivs(self, f, blocks, Jmatrix) - print("Done") + print(f"Done {time() -tc}") fields_array = delayed(f[:, ftype, :]) ATinv_df_duT_v = {} for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): @@ -447,9 +447,12 @@ def compute_J(self, f=None, Ainv=None): time_mask = data_times > simulation_times[tInd] for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): + tc = time() + print("Computing derivative block") ATinv_df_duT_v = get_field_deriv_block( self, block, field_deriv, tInd, AdiagTinv, ATinv_df_duT_v, time_mask ) + print(f"Done {time() - tc}") if len(block) == 0: continue @@ -478,7 +481,10 @@ def compute_J(self, f=None, Ainv=None): array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) Jmatrix = array.from_zarr(sens_name) else: + tc = time() + print("Computing J update") dask.compute(Jmatrix) + print(f"Done {time() - tc}") for A in Ainv.values(): A.clean() From 92bc3874cb87e243204682f6620380a9a2d2ecd8 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 2 Apr 2024 21:45:07 -0700 Subject: [PATCH 10/33] Deal only with numpy array --- .../time_domain/simulation.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index a548e14399..c2e0e66888 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -445,14 +445,12 @@ def compute_J(self, f=None, Ainv=None): AdiagTinv = Ainv[dt] j_row_updates = [] time_mask = data_times > simulation_times[tInd] - + tc = time() + print("Computing derivative block") for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): - tc = time() - print("Computing derivative block") ATinv_df_duT_v = get_field_deriv_block( self, block, field_deriv, tInd, AdiagTinv, ATinv_df_duT_v, time_mask ) - print(f"Done {time() - tc}") if len(block) == 0: continue @@ -474,17 +472,23 @@ def compute_J(self, f=None, Ainv=None): ), ) ) - - Jmatrix = Jmatrix + array.vstack(j_row_updates) + print(f"Done {time() - tc}") + # Jmatrix = Jmatrix + array.vstack(j_row_updates) if self.store_sensitivities == "disk": sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" - array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) + array.to_zarr( + Jmatrix + array.vstack(j_row_updates), + sens_name, + compute=True, + overwrite=True, + ) Jmatrix = array.from_zarr(sens_name) else: tc = time() print("Computing J update") - dask.compute(Jmatrix) + Jmatrix += array.vstack(j_row_updates).compute() print(f"Done {time() - tc}") + print(type(Jmatrix)) for A in Ainv.values(): A.clean() From f35b424a5659dbc5823f8f55165ff8f97fc5d8d0 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 2 Apr 2024 22:21:52 -0700 Subject: [PATCH 11/33] Remove Jmatrix from deriv calc --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index c2e0e66888..b43935d2a5 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -161,7 +161,7 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): @delayed def delayed_block_deriv( - time_index, chunks, field_type, source_list, mesh, time_mesh, fields, shape, Jmatrix + time_index, chunks, field_type, source_list, mesh, time_mesh, fields, shape ): """Compute derivatives for sources and receivers in a block""" field_len = len(fields[source_list[0], field_type, 0]) @@ -186,7 +186,7 @@ def delayed_block_deriv( j_update.append(cur[1].T) else: j_update.append( - sp.csr_matrix((arrays[0].shape[0], Jmatrix.shape[1]), dtype=np.float32) + sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) ) return df_duT, j_update @@ -216,7 +216,6 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): simulation.time_mesh, fields, simulation.model.size, - Jmatrix, ) delayed_chunks.append(delayed_block) From 53295f2579824ed51cd6098ba76cb25ab8b74019 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 2 Apr 2024 22:48:37 -0700 Subject: [PATCH 12/33] More prints --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index b43935d2a5..a6fb69bec6 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -203,6 +203,8 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): block_updates = [] delayed_chunks = [] + tc = time() + print("Prepping blocks") for chunks in blocks: if len(chunks) == 0: continue @@ -218,9 +220,11 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): simulation.model.size, ) delayed_chunks.append(delayed_block) - + print(f"Done {time() - tc}") + tc = time() + print("Computing blocks") result = dask.compute(delayed_chunks) - + print(f"Done {time() - tc}") for chunk in result[0]: block_derivs.append(chunk[0]) block_updates += chunk[1] From 254086b5289f19293a8155f4684bf936dd487da2 Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 3 Apr 2024 14:00:28 -0700 Subject: [PATCH 13/33] Change loop of dpred --- .../time_domain/simulation.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index a6fb69bec6..1b34bd0032 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -129,21 +129,30 @@ def dask_dpred(self, m=None, f=None, compute_J=False): f, Ainv = self.fields(m, return_Ainv=compute_J) # print(f"took {time() - ct} s to compute fields") - def evaluate_receiver(source, receiver, mesh, time_mesh, fields): - return receiver.eval(source, mesh, time_mesh, fields).flatten() + @delayed + def evaluate_receivers(source_list, indices, mesh, time_mesh, fields): + data = [] + for ind in indices: + source = source_list[ind] + for receiver in source.receiver_list: + data.append(receiver.eval(source, mesh, time_mesh, fields).flatten()) - row = delayed(evaluate_receiver, pure=True) - rows = [] + return np.hstack(data) - for src in tqdm(self.survey.source_list): - for rx in src.receiver_list: - rows.append( - array.from_delayed( - row(src, rx, self.mesh, self.time_mesh, f), - dtype=np.float64, - shape=(rx.nD,), - ) + rows = [] + fields = delayed(f) + indices = np.array_split(np.arange(len(self.survey.source_list)), cpu_count()) + for block in indices: + n_data = np.sum(self.survey.source_list[ind].nD for ind in block) + rows.append( + array.from_delayed( + evaluate_receivers( + self.survey.source_list, block, self.mesh, self.time_mesh, fields + ), + dtype=np.float64, + shape=(n_data,), ) + ) with ProgressBar(): data = array.hstack(rows).compute() @@ -488,10 +497,7 @@ def compute_J(self, f=None, Ainv=None): Jmatrix = array.from_zarr(sens_name) else: tc = time() - print("Computing J update") Jmatrix += array.vstack(j_row_updates).compute() - print(f"Done {time() - tc}") - print(type(Jmatrix)) for A in Ainv.values(): A.clean() From 5a8df3fa18bfd782374b86831468f34e0d16f33e Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 3 Apr 2024 22:23:30 -0700 Subject: [PATCH 14/33] Speed up data calc without reshaping of fields --- .../time_domain/simulation.py | 64 +++++++++++-------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 1b34bd0032..dcc1b3ad5e 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -103,6 +103,20 @@ def source_evaluation(simulation, sources, time): Sim.getSourceTerm = dask_getSourceTerm +@delayed +def evaluate_receivers(block, mesh, time_mesh, fields, fields_array): + data = [] + for source, ind, receiver in block: + # proj = receiver.getP(mesh, time_mesh, fields) + Ps = receiver.getSpatialP(mesh, fields) + Pt = receiver.getTimeP(time_mesh, fields) + vector = (Pt * (Ps * fields_array[:, ind, :]).T).flatten() + + data.append(vector) + + return np.hstack(data) + + def dask_dpred(self, m=None, f=None, compute_J=False): """ dpred(m, f=None) @@ -128,27 +142,27 @@ def dask_dpred(self, m=None, f=None, compute_J=False): m = self.model f, Ainv = self.fields(m, return_Ainv=compute_J) - # print(f"took {time() - ct} s to compute fields") - @delayed - def evaluate_receivers(source_list, indices, mesh, time_mesh, fields): - data = [] - for ind in indices: - source = source_list[ind] - for receiver in source.receiver_list: - data.append(receiver.eval(source, mesh, time_mesh, fields).flatten()) - - return np.hstack(data) - rows = [] fields = delayed(f) - indices = np.array_split(np.arange(len(self.survey.source_list)), cpu_count()) - for block in indices: - n_data = np.sum(self.survey.source_list[ind].nD for ind in block) + fields_array = delayed( + f[:, self.survey.source_list[0].receiver_list[0].projField, :] + ) + mesh = delayed(self.mesh) + time_mesh = delayed(self.time_mesh) + all_receivers = [] + for ind, src in enumerate(self.survey.source_list): + for rx in src.receiver_list: + all_receivers.append((src, ind, rx)) + + receiver_blocks = np.array_split(all_receivers, cpu_count()) + for block in receiver_blocks: + n_data = np.sum(rec.nD for _, _, rec in block) + if n_data == 0: + continue + rows.append( array.from_delayed( - evaluate_receivers( - self.survey.source_list, block, self.mesh, self.time_mesh, fields - ), + evaluate_receivers(block, mesh, time_mesh, fields, fields_array), dtype=np.float64, shape=(n_data,), ) @@ -212,8 +226,6 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): block_updates = [] delayed_chunks = [] - tc = time() - print("Prepping blocks") for chunks in blocks: if len(chunks) == 0: continue @@ -229,11 +241,9 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): simulation.model.size, ) delayed_chunks.append(delayed_block) - print(f"Done {time() - tc}") - tc = time() - print("Computing blocks") + result = dask.compute(delayed_chunks) - print(f"Done {time() - tc}") + for chunk in result[0]: block_derivs.append(chunk[0]) block_updates += chunk[1] @@ -457,8 +467,7 @@ def compute_J(self, f=None, Ainv=None): AdiagTinv = Ainv[dt] j_row_updates = [] time_mask = data_times > simulation_times[tInd] - tc = time() - print("Computing derivative block") + for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): ATinv_df_duT_v = get_field_deriv_block( self, block, field_deriv, tInd, AdiagTinv, ATinv_df_duT_v, time_mask @@ -484,7 +493,7 @@ def compute_J(self, f=None, Ainv=None): ), ) ) - print(f"Done {time() - tc}") + # Jmatrix = Jmatrix + array.vstack(j_row_updates) if self.store_sensitivities == "disk": sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" @@ -496,14 +505,13 @@ def compute_J(self, f=None, Ainv=None): ) Jmatrix = array.from_zarr(sens_name) else: - tc = time() Jmatrix += array.vstack(j_row_updates).compute() for A in Ainv.values(): A.clean() if self.store_sensitivities == "ram": - return Jmatrix.compute() + return np.asarray(Jmatrix) return Jmatrix From 5bc6ecbcb1d898e3adb6f775b826b34e8418cbeb Mon Sep 17 00:00:00 2001 From: domfournier Date: Wed, 3 Apr 2024 22:29:14 -0700 Subject: [PATCH 15/33] Revert re-splitting of blocks --- .../electromagnetics/time_domain/simulation.py | 2 -- SimPEG/dask/utils.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index dcc1b3ad5e..e5afc0dfb8 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -256,8 +256,6 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): sens_name = simulation.sensitivity_path[:-5] + f"_{time_index % 2}.zarr" array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) Jmatrix = array.from_zarr(sens_name) - else: - dask.compute(Jmatrix) df_duT.append(block_derivs) diff --git a/SimPEG/dask/utils.py b/SimPEG/dask/utils.py index c287bc4ad6..2a1900edce 100644 --- a/SimPEG/dask/utils.py +++ b/SimPEG/dask/utils.py @@ -83,12 +83,12 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int) -> row_index += chunk_size row_count += chunk_size - # Re-split over cpu_count if too few blocks - if len(blocks) < cpu_count(): - flatten_blocks = [] - for block in blocks: - flatten_blocks += block - - chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) - return [[flatten_blocks[i] for i in chunk] for chunk in chunks] + # # Re-split over cpu_count if too few blocks + # if len(blocks) < cpu_count(): + # flatten_blocks = [] + # for block in blocks: + # flatten_blocks += block + # + # chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) + # return [[flatten_blocks[i] for i in chunk] for chunk in chunks] return blocks From 32ae829d94cf464336e662514bfd4de78861b5ef Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 4 Apr 2024 14:44:40 -0700 Subject: [PATCH 16/33] remove prints --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index e5afc0dfb8..d903877541 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -32,7 +32,7 @@ def fields(self, m=None, return_Ainv=False): Ainv = {} ATinv = {} - for tInd, dt in tqdm(enumerate(self.time_steps)): + for tInd, dt in enumerate(self.time_steps): if dt not in Ainv: A = self.getAdiag(tInd) Ainv[dt] = self.solver(sp.csr_matrix(A), **self.solver_opts) @@ -456,9 +456,9 @@ def compute_J(self, f=None, Ainv=None): self.survey.source_list, self.model.shape[0], self.max_chunk_size ) tc = time() - print("Computing field derivatives") + times_field_derivs, Jmatrix = compute_field_derivs(self, f, blocks, Jmatrix) - print(f"Done {time() -tc}") + fields_array = delayed(f[:, ftype, :]) ATinv_df_duT_v = {} for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): From c8bf1bde7685630e5f6c53d051897c14ae2d1b52 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 5 Apr 2024 11:29:39 -0700 Subject: [PATCH 17/33] Move prints --- .../dask/electromagnetics/time_domain/simulation.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index d903877541..3c03fbbda2 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -168,8 +168,7 @@ def dask_dpred(self, m=None, f=None, compute_J=False): ) ) - with ProgressBar(): - data = array.hstack(rows).compute() + data = array.hstack(rows).compute() if compute_J and self._Jmatrix is None: Jmatrix = self.compute_J(f=f, Ainv=Ainv) @@ -351,8 +350,12 @@ def get_field_deriv_block( ) ) if len(stacked_blocks) > 0: - blocks = array.hstack(stacked_blocks).compute() + with ProgressBar(): + blocks = array.hstack(stacked_blocks).compute() + + tc = time() solve = (AdiagTinv * blocks).reshape(blocks.shape) + print("Solve time: ", time() - tc) else: solve = None @@ -466,6 +469,7 @@ def compute_J(self, f=None, Ainv=None): j_row_updates = [] time_mask = data_times > simulation_times[tInd] + tc = time() for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): ATinv_df_duT_v = get_field_deriv_block( self, block, field_deriv, tInd, AdiagTinv, ATinv_df_duT_v, time_mask @@ -492,6 +496,7 @@ def compute_J(self, f=None, Ainv=None): ) ) + print("Prepping blocks: ", time() - tc) # Jmatrix = Jmatrix + array.vstack(j_row_updates) if self.store_sensitivities == "disk": sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" From 8aed506f8d2ca61692a250f9bf56327349a554b9 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 5 Apr 2024 11:55:29 -0700 Subject: [PATCH 18/33] Add prints --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 3c03fbbda2..b451eee96f 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -167,8 +167,8 @@ def dask_dpred(self, m=None, f=None, compute_J=False): shape=(n_data,), ) ) - - data = array.hstack(rows).compute() + with ProgressBar(): + data = array.hstack(rows).compute() if compute_J and self._Jmatrix is None: Jmatrix = self.compute_J(f=f, Ainv=Ainv) @@ -431,7 +431,7 @@ def compute_J(self, f=None, Ainv=None): """ Compute the rows for the sensitivity matrix. """ - + print("Computing fields") if f is None: f, Ainv = self.fields(self.model, return_Ainv=True) @@ -460,7 +460,9 @@ def compute_J(self, f=None, Ainv=None): ) tc = time() + print("COmputing field derivs") times_field_derivs, Jmatrix = compute_field_derivs(self, f, blocks, Jmatrix) + print("Field derivs: ", time() - tc) fields_array = delayed(f[:, ftype, :]) ATinv_df_duT_v = {} From 5275e59479e74d14e235c4c3473a7544aec76d07 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 5 Apr 2024 12:44:26 -0700 Subject: [PATCH 19/33] MOre testing prints --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index b451eee96f..3f79868135 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -150,12 +150,14 @@ def dask_dpred(self, m=None, f=None, compute_J=False): mesh = delayed(self.mesh) time_mesh = delayed(self.time_mesh) all_receivers = [] - for ind, src in enumerate(self.survey.source_list): + print("Prepping receivers") + for ind, src in tqdm(enumerate(self.survey.source_list)): for rx in src.receiver_list: all_receivers.append((src, ind, rx)) receiver_blocks = np.array_split(all_receivers, cpu_count()) - for block in receiver_blocks: + print("Creatint parallel blocks") + for block in tqdm(receiver_blocks): n_data = np.sum(rec.nD for _, _, rec in block) if n_data == 0: continue From f8e4437e3b592db1558773fda0e6600835d3ac5b Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 5 Apr 2024 14:12:14 -0700 Subject: [PATCH 20/33] More prints --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 3f79868135..c9a2ce630f 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -512,7 +512,10 @@ def compute_J(self, f=None, Ainv=None): ) Jmatrix = array.from_zarr(sens_name) else: + tc = time() + print("Adding to Jmatrix") Jmatrix += array.vstack(j_row_updates).compute() + print("Add time: ", time() - tc) for A in Ainv.values(): A.clean() From 5641fb9ea9c68eb9b3722c9e84689a818b27c246 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 5 Apr 2024 14:25:40 -0700 Subject: [PATCH 21/33] Again --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index c9a2ce630f..7c484033bd 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -143,10 +143,16 @@ def dask_dpred(self, m=None, f=None, compute_J=False): f, Ainv = self.fields(m, return_Ainv=compute_J) rows = [] + tc = time() + print("delaying fields") fields = delayed(f) + print(f"Complet {time()-tc}") + tc = time() + print("delaying fields array") fields_array = delayed( f[:, self.survey.source_list[0].receiver_list[0].projField, :] ) + print(f"Complet {time()-tc}") mesh = delayed(self.mesh) time_mesh = delayed(self.time_mesh) all_receivers = [] @@ -243,7 +249,8 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): ) delayed_chunks.append(delayed_block) - result = dask.compute(delayed_chunks) + with ProgressBar(): + result = dask.compute(delayed_chunks) for chunk in result[0]: block_derivs.append(chunk[0]) From 35ab291dec4d52bd16b5a25567efde649e2540fc Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 5 Apr 2024 14:37:39 -0700 Subject: [PATCH 22/33] Remove delayed on large array --- .../time_domain/simulation.py | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 7c484033bd..9e632dc19c 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -144,17 +144,10 @@ def dask_dpred(self, m=None, f=None, compute_J=False): rows = [] tc = time() - print("delaying fields") - fields = delayed(f) - print(f"Complet {time()-tc}") - tc = time() print("delaying fields array") - fields_array = delayed( - f[:, self.survey.source_list[0].receiver_list[0].projField, :] - ) + fields_array = f[:, self.survey.source_list[0].receiver_list[0].projField, :] print(f"Complet {time()-tc}") - mesh = delayed(self.mesh) - time_mesh = delayed(self.time_mesh) + all_receivers = [] print("Prepping receivers") for ind, src in tqdm(enumerate(self.survey.source_list)): @@ -162,7 +155,7 @@ def dask_dpred(self, m=None, f=None, compute_J=False): all_receivers.append((src, ind, rx)) receiver_blocks = np.array_split(all_receivers, cpu_count()) - print("Creatint parallel blocks") + print("Creating parallel blocks") for block in tqdm(receiver_blocks): n_data = np.sum(rec.nD for _, _, rec in block) if n_data == 0: @@ -170,7 +163,7 @@ def dask_dpred(self, m=None, f=None, compute_J=False): rows.append( array.from_delayed( - evaluate_receivers(block, mesh, time_mesh, fields, fields_array), + evaluate_receivers(block, self.mesh, self.time_mesh, f, fields_array), dtype=np.float64, shape=(n_data,), ) @@ -287,7 +280,7 @@ def deriv_block( return stacked_block -def update_deriv_blocks(address, tInd, indices, derivatives, solve, shape): +def update_deriv_blocks(address, indices, derivatives, solve, shape): if address not in derivatives: deriv_array = np.zeros(shape) else: @@ -297,7 +290,7 @@ def update_deriv_blocks(address, tInd, indices, derivatives, solve, shape): columns, local_ind = indices[address] deriv_array[:, local_ind] = solve[:, columns] - derivatives[address] = delayed(deriv_array) + derivatives[address] = deriv_array def get_field_deriv_block( @@ -374,10 +367,10 @@ def get_field_deriv_block( field_deriv.shape[0], len(arrays[0]), ) - update_list.append( - update_deriv_blocks(address, tInd, indices, ATinv_df_duT_v, solve, shape) - ) - dask.compute(update_list) + + update_deriv_blocks(address, tInd, indices, ATinv_df_duT_v, solve, shape) + + # dask.compute(update_list) return ATinv_df_duT_v @@ -473,7 +466,7 @@ def compute_J(self, f=None, Ainv=None): times_field_derivs, Jmatrix = compute_field_derivs(self, f, blocks, Jmatrix) print("Field derivs: ", time() - tc) - fields_array = delayed(f[:, ftype, :]) + fields_array = f[:, ftype, :] ATinv_df_duT_v = {} for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): AdiagTinv = Ainv[dt] From 14088563d1602deec16ab6d04ba7e4ef74b720da Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 5 Apr 2024 16:19:10 -0700 Subject: [PATCH 23/33] Improve parallelization on field_derivs --- .../time_domain/simulation.py | 115 +++++++++++++----- 1 file changed, 85 insertions(+), 30 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 9e632dc19c..18ef9731b3 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -3,6 +3,7 @@ import os from ....electromagnetics.time_domain.simulation import BaseTDEMSimulation as Sim from ....utils import Zero +from SimPEG.fields import TimeFields from multiprocessing import cpu_count import numpy as np import scipy.sparse as sp @@ -23,6 +24,61 @@ Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] +def _getField(self, name, ind, src_list): + srcInd, timeInd = ind + + if name in self._fields: + out = self._fields[name][:, srcInd, timeInd] + else: + # Aliased fields + alias, loc, func = self.aliasFields[name] + if isinstance(func, str): + assert hasattr(self, func), ( + "The alias field function is a string, but it does " + "not exist in the Fields class." + ) + func = getattr(self, func) + pointerFields = self._fields[alias][:, srcInd, timeInd] + pointerShape = self._correctShape(alias, ind) + pointerFields = pointerFields.reshape(pointerShape, order="F") + + # First try to return the function as three arguments (without timeInd) + if timeInd == slice(None, None, None): + try: + # assume it will take care of integrating over all times + return func(pointerFields, srcInd) + except TypeError: + pass + + timeII = np.arange(self.simulation.nT + 1)[timeInd] + if not isinstance(src_list, list): + src_list = [src_list] + + if timeII.size == 1: + pointerShapeDeflated = self._correctShape(alias, ind, deflate=True) + pointerFields = pointerFields.reshape(pointerShapeDeflated, order="F") + out = func(pointerFields, src_list, timeII) + else: # loop over the time steps + nT = pointerShape[2] + out = list(range(nT)) + for i, TIND_i in enumerate(timeII): # Need to parallelize this + fieldI = pointerFields[:, :, i] + if fieldI.shape[0] == fieldI.size: + fieldI = mkvc(fieldI, 2) + out[i] = func(fieldI, src_list, TIND_i) + if out[i].ndim == 1: + out[i] = out[i][:, np.newaxis, np.newaxis] + elif out[i].ndim == 2: + out[i] = out[i][:, :, np.newaxis] + out = np.concatenate(out, axis=2) + + shape = self._correctShape(name, ind, deflate=True) + return out.reshape(shape, order="F") + + +TimeFields._getField = _getField + + def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m @@ -145,7 +201,9 @@ def dask_dpred(self, m=None, f=None, compute_J=False): rows = [] tc = time() print("delaying fields array") - fields_array = f[:, self.survey.source_list[0].receiver_list[0].projField, :] + receiver_projection = self.survey.source_list[0].receiver_list[0].projField + fields_array = f[:, receiver_projection, :] + print(f"Complet {time()-tc}") all_receivers = [] @@ -184,10 +242,9 @@ def dask_dpred(self, m=None, f=None, compute_J=False): @delayed def delayed_block_deriv( - time_index, chunks, field_type, source_list, mesh, time_mesh, fields, shape + time_index, chunks, field_len, source_list, mesh, time_mesh, fields, shape ): """Compute derivatives for sources and receivers in a block""" - field_len = len(fields[source_list[0], field_type, 0]) df_duT = [] j_update = [] @@ -215,17 +272,13 @@ def delayed_block_deriv( return df_duT, j_update -def compute_field_derivs(simulation, fields, blocks, Jmatrix): +def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): """ Compute the derivative of the fields """ - df_duT = [] - + delayed_blocks = [] for time_index in range(simulation.nT + 1): - block_derivs = [] - block_updates = [] delayed_chunks = [] - for chunks in blocks: if len(chunks) == 0: continue @@ -233,7 +286,7 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): delayed_block = delayed_block_deriv( time_index, chunks, - simulation._fieldType + "Solution", + fields_shape[0], simulation.survey.source_list, simulation.mesh, simulation.time_mesh, @@ -242,23 +295,24 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix): ) delayed_chunks.append(delayed_block) - with ProgressBar(): - result = dask.compute(delayed_chunks) + delayed_blocks.append(delayed_chunks) - for chunk in result[0]: - block_derivs.append(chunk[0]) - block_updates += chunk[1] - - j_updates = sp.vstack(block_updates) - - if len(j_updates.data) > 0: - Jmatrix += sp.vstack(j_updates) - if simulation.store_sensitivities == "disk": - sens_name = simulation.sensitivity_path[:-5] + f"_{time_index % 2}.zarr" - array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) - Jmatrix = array.from_zarr(sens_name) + with ProgressBar(): + result = dask.compute(delayed_blocks)[0] - df_duT.append(block_derivs) + df_duT = [] + j_updates = 0.0 + for time_block in result: + for chunk in time_block: + df_duT.append([chunk[0]]) + j_updates += sp.vstack(chunk[1]) + + if len(j_updates.data) > 0: + Jmatrix += sp.vstack(j_updates) + if simulation.store_sensitivities == "disk": + sens_name = simulation.sensitivity_path[:-5] + f"_{time_index % 2}.zarr" + array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) + Jmatrix = array.from_zarr(sens_name) return df_duT, Jmatrix @@ -284,7 +338,7 @@ def update_deriv_blocks(address, indices, derivatives, solve, shape): if address not in derivatives: deriv_array = np.zeros(shape) else: - deriv_array = derivatives[address].compute() + deriv_array = derivatives[address] if address in indices: columns, local_ind = indices[address] @@ -368,7 +422,7 @@ def get_field_deriv_block( len(arrays[0]), ) - update_deriv_blocks(address, tInd, indices, ATinv_df_duT_v, solve, shape) + update_deriv_blocks(address, indices, ATinv_df_duT_v, solve, shape) # dask.compute(update_list) @@ -460,13 +514,14 @@ def compute_J(self, f=None, Ainv=None): blocks = get_parallel_blocks( self.survey.source_list, self.model.shape[0], self.max_chunk_size ) + fields_array = f[:, ftype, :] tc = time() - print("COmputing field derivs") - times_field_derivs, Jmatrix = compute_field_derivs(self, f, blocks, Jmatrix) + times_field_derivs, Jmatrix = compute_field_derivs( + self, f, blocks, Jmatrix, fields_array.shape + ) print("Field derivs: ", time() - tc) - fields_array = f[:, ftype, :] ATinv_df_duT_v = {} for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): AdiagTinv = Ainv[dt] From 15fcf8bbc4868ed603c7d173e0cfbe1e3681dde6 Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 7 Apr 2024 07:37:19 -0700 Subject: [PATCH 24/33] Parallel alias field computes --- .../time_domain/simulation.py | 35 +++++++++++++------ 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 18ef9731b3..cfc4ce85de 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -24,6 +24,20 @@ Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] +@delayed +def field_projection(field_array, src_list, array_ind, time_ind, func): + fieldI = field_array[:, :, array_ind] + if fieldI.shape[0] == fieldI.size: + fieldI = mkvc(fieldI, 2) + new_array = func(fieldI, src_list, time_ind) + if new_array.ndim == 1: + new_array = new_array[:, np.newaxis, np.newaxis] + elif new_array.ndim == 2: + new_array = new_array[:, :, np.newaxis] + + return new_array + + def _getField(self, name, ind, src_list): srcInd, timeInd = ind @@ -60,17 +74,18 @@ def _getField(self, name, ind, src_list): out = func(pointerFields, src_list, timeII) else: # loop over the time steps nT = pointerShape[2] - out = list(range(nT)) + arrays = [] + for i, TIND_i in enumerate(timeII): # Need to parallelize this - fieldI = pointerFields[:, :, i] - if fieldI.shape[0] == fieldI.size: - fieldI = mkvc(fieldI, 2) - out[i] = func(fieldI, src_list, TIND_i) - if out[i].ndim == 1: - out[i] = out[i][:, np.newaxis, np.newaxis] - elif out[i].ndim == 2: - out[i] = out[i][:, :, np.newaxis] - out = np.concatenate(out, axis=2) + arrays.append( + array.from_delayed( + field_projection(pointerFields, src_list, i, TIND_i, func), + dtype=np.float32, + shape=(pointerShape[0], pointerShape[1], 1), + ) + ) + + out = array.dstack(arrays).compute() shape = self._correctShape(name, ind, deflate=True) return out.reshape(shape, order="F") From bb098ec88c7c3ba541d4706974204fd2dd84003b Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 7 Apr 2024 11:10:57 -0700 Subject: [PATCH 25/33] Bring back re-splitting --- .../electromagnetics/time_domain/simulation.py | 4 ++-- SimPEG/dask/utils.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index cfc4ce85de..787d3bb01b 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -382,8 +382,8 @@ def get_field_deriv_block( if tInd < simulation.nT - 1: Asubdiag = simulation.getAsubdiag(tInd + 1) - for ((s_id, r_id, b_id), (rx_ind, j_ind, shape)), field_deriv in zip( - block, field_derivs + for ((s_id, r_id, b_id), (rx_ind, j_ind, shape)), field_deriv in tqdm( + zip(block, field_derivs) ): # Cut out early data time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind] diff --git a/SimPEG/dask/utils.py b/SimPEG/dask/utils.py index 2a1900edce..c287bc4ad6 100644 --- a/SimPEG/dask/utils.py +++ b/SimPEG/dask/utils.py @@ -83,12 +83,12 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int) -> row_index += chunk_size row_count += chunk_size - # # Re-split over cpu_count if too few blocks - # if len(blocks) < cpu_count(): - # flatten_blocks = [] - # for block in blocks: - # flatten_blocks += block - # - # chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) - # return [[flatten_blocks[i] for i in chunk] for chunk in chunks] + # Re-split over cpu_count if too few blocks + if len(blocks) < cpu_count(): + flatten_blocks = [] + for block in blocks: + flatten_blocks += block + + chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) + return [[flatten_blocks[i] for i in chunk] for chunk in chunks] return blocks From 86687f0e3c2aa2e4024acdebbc3e3c9d8487f355 Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 7 Apr 2024 13:58:33 -0700 Subject: [PATCH 26/33] Invert loop over flied_derivs --- .../time_domain/simulation.py | 95 ++++++++++--------- 1 file changed, 52 insertions(+), 43 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 787d3bb01b..127d64679c 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -257,73 +257,82 @@ def dask_dpred(self, m=None, f=None, compute_J=False): @delayed def delayed_block_deriv( - time_index, chunks, field_len, source_list, mesh, time_mesh, fields, shape + n_times, chunks, field_len, source_list, mesh, time_mesh, fields, shape ): """Compute derivatives for sources and receivers in a block""" df_duT = [] - j_update = [] + j_updates = [] for indices, arrays in chunks: + j_update = 0.0 source = source_list[indices[0]] receiver = source.receiver_list[indices[1]] PTv = receiver.getP(mesh, time_mesh, fields).tocsr() derivative_fun = getattr(fields, "_{}Deriv".format(receiver.projField), None) - cur = derivative_fun( - time_index, - source, - None, - PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, - adjoint=True, - ) - df_duT.append(cur[0]) - - if not isinstance(cur[1], Zero): - j_update.append(cur[1].T) - else: - j_update.append( - sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) + time_derivs = [] + for time_index in range(n_times + 1): + cur = derivative_fun( + time_index, + source, + None, + PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, + adjoint=True, ) + time_derivs.append(cur[0]) + + if not isinstance(cur[1], Zero): + j_update += cur[1].T + else: + j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) - return df_duT, j_update + j_updates.append(j_update) + df_duT.append(time_derivs) + + return df_duT, j_updates def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): """ Compute the derivative of the fields """ - delayed_blocks = [] - for time_index in range(simulation.nT + 1): - delayed_chunks = [] - for chunks in blocks: - if len(chunks) == 0: - continue + # delayed_blocks = [] + # for time_index in range(simulation.nT + 1): + delayed_chunks = [] + for chunks in blocks: + if len(chunks) == 0: + continue - delayed_block = delayed_block_deriv( - time_index, - chunks, - fields_shape[0], - simulation.survey.source_list, - simulation.mesh, - simulation.time_mesh, - fields, - simulation.model.size, - ) - delayed_chunks.append(delayed_block) + delayed_block = delayed_block_deriv( + simulation.nT, + chunks, + fields_shape[0], + simulation.survey.source_list, + simulation.mesh, + simulation.time_mesh, + fields, + simulation.model.size, + ) + delayed_chunks.append(delayed_block) - delayed_blocks.append(delayed_chunks) + # delayed_blocks.append(delayed_chunks) with ProgressBar(): - result = dask.compute(delayed_blocks)[0] + result = dask.compute(delayed_chunks)[0] - df_duT = [] - j_updates = 0.0 - for time_block in result: - for chunk in time_block: - df_duT.append([chunk[0]]) - j_updates += sp.vstack(chunk[1]) + len_blocks = [[] * len(block) for block in blocks if len(block) > 0] + df_duT = [len_blocks.copy() for _ in range(simulation.nT + 1)] + j_updates = [] + + for block in result: + j_updates += block[1] + for bb, chunk in enumerate(block[0]): + for ind, time_block in enumerate(chunk): + df_duT[ind][bb] += time_block + + j_updates = sp.vstack(j_updates) if len(j_updates.data) > 0: - Jmatrix += sp.vstack(j_updates) + Jmatrix += j_updates if simulation.store_sensitivities == "disk": sens_name = simulation.sensitivity_path[:-5] + f"_{time_index % 2}.zarr" array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) From b21e9302509338aa556f9f292e54638672e0f33d Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 7 Apr 2024 15:15:38 -0700 Subject: [PATCH 27/33] Workout new indexing --- .../dask/electromagnetics/time_domain/simulation.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 127d64679c..c78ae9e28a 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -319,15 +319,18 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): with ProgressBar(): result = dask.compute(delayed_chunks)[0] - len_blocks = [[] * len(block) for block in blocks if len(block) > 0] - df_duT = [len_blocks.copy() for _ in range(simulation.nT + 1)] + len_blocks = [[[] * len(block)] for block in blocks if len(block) > 0] + df_duT = [ + [[[] * len(block)] for block in blocks if len(block) > 0] + for _ in range(simulation.nT + 1) + ] j_updates = [] - for block in result: + for bb, block in enumerate(result): j_updates += block[1] - for bb, chunk in enumerate(block[0]): + for cc, chunk in enumerate(block[0]): for ind, time_block in enumerate(chunk): - df_duT[ind][bb] += time_block + df_duT[ind][bb][cc] = time_block j_updates = sp.vstack(j_updates) From 68f07f1b85ce568381afdc7532fe92205ea537ac Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 7 Apr 2024 20:44:28 -0700 Subject: [PATCH 28/33] Fix list len --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index c78ae9e28a..7263c36bb8 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -319,9 +319,9 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): with ProgressBar(): result = dask.compute(delayed_chunks)[0] - len_blocks = [[[] * len(block)] for block in blocks if len(block) > 0] + # len_blocks = [[[] for _ in block] for block in blocks if len(block) > 0] df_duT = [ - [[[] * len(block)] for block in blocks if len(block) > 0] + [[[] for _ in block] for block in blocks if len(block) > 0] for _ in range(simulation.nT + 1) ] j_updates = [] From aa5eb7e686035240a92503cbe6b16994d9a0872b Mon Sep 17 00:00:00 2001 From: domfournier Date: Sun, 7 Apr 2024 21:07:17 -0700 Subject: [PATCH 29/33] Remove debug prints --- .../time_domain/simulation.py | 41 ++++++------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 7263c36bb8..10ed0377af 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -8,7 +8,8 @@ import numpy as np import scipy.sparse as sp from dask import array, delayed -from dask.diagnostics import ProgressBar + +# from dask.diagnostics import ProgressBar from SimPEG.dask.simulation import dask_Jvec, dask_Jtvec, dask_getJtJdiag from SimPEG.dask.utils import get_parallel_blocks import zarr @@ -214,22 +215,17 @@ def dask_dpred(self, m=None, f=None, compute_J=False): f, Ainv = self.fields(m, return_Ainv=compute_J) rows = [] - tc = time() - print("delaying fields array") receiver_projection = self.survey.source_list[0].receiver_list[0].projField fields_array = f[:, receiver_projection, :] - - print(f"Complet {time()-tc}") - all_receivers = [] - print("Prepping receivers") - for ind, src in tqdm(enumerate(self.survey.source_list)): + + for ind, src in enumerate(self.survey.source_list): for rx in src.receiver_list: all_receivers.append((src, ind, rx)) receiver_blocks = np.array_split(all_receivers, cpu_count()) - print("Creating parallel blocks") - for block in tqdm(receiver_blocks): + + for block in receiver_blocks: n_data = np.sum(rec.nD for _, _, rec in block) if n_data == 0: continue @@ -241,8 +237,8 @@ def dask_dpred(self, m=None, f=None, compute_J=False): shape=(n_data,), ) ) - with ProgressBar(): - data = array.hstack(rows).compute() + + data = array.hstack(rows).compute() if compute_J and self._Jmatrix is None: Jmatrix = self.compute_J(f=f, Ainv=Ainv) @@ -316,8 +312,7 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): # delayed_blocks.append(delayed_chunks) - with ProgressBar(): - result = dask.compute(delayed_chunks)[0] + result = dask.compute(delayed_chunks)[0] # len_blocks = [[[] for _ in block] for block in blocks if len(block) > 0] df_duT = [ @@ -394,8 +389,8 @@ def get_field_deriv_block( if tInd < simulation.nT - 1: Asubdiag = simulation.getAsubdiag(tInd + 1) - for ((s_id, r_id, b_id), (rx_ind, j_ind, shape)), field_deriv in tqdm( - zip(block, field_derivs) + for ((s_id, r_id, b_id), (rx_ind, j_ind, shape)), field_deriv in zip( + block, field_derivs ): # Cut out early data time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind] @@ -433,12 +428,8 @@ def get_field_deriv_block( ) ) if len(stacked_blocks) > 0: - with ProgressBar(): - blocks = array.hstack(stacked_blocks).compute() - - tc = time() + blocks = array.hstack(stacked_blocks).compute() solve = (AdiagTinv * blocks).reshape(blocks.shape) - print("Solve time: ", time() - tc) else: solve = None @@ -514,7 +505,6 @@ def compute_J(self, f=None, Ainv=None): """ Compute the rows for the sensitivity matrix. """ - print("Computing fields") if f is None: f, Ainv = self.fields(self.model, return_Ainv=True) @@ -542,12 +532,9 @@ def compute_J(self, f=None, Ainv=None): self.survey.source_list, self.model.shape[0], self.max_chunk_size ) fields_array = f[:, ftype, :] - tc = time() - print("COmputing field derivs") times_field_derivs, Jmatrix = compute_field_derivs( self, f, blocks, Jmatrix, fields_array.shape ) - print("Field derivs: ", time() - tc) ATinv_df_duT_v = {} for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): @@ -582,7 +569,6 @@ def compute_J(self, f=None, Ainv=None): ) ) - print("Prepping blocks: ", time() - tc) # Jmatrix = Jmatrix + array.vstack(j_row_updates) if self.store_sensitivities == "disk": sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" @@ -594,10 +580,7 @@ def compute_J(self, f=None, Ainv=None): ) Jmatrix = array.from_zarr(sens_name) else: - tc = time() - print("Adding to Jmatrix") Jmatrix += array.vstack(j_row_updates).compute() - print("Add time: ", time() - tc) for A in Ainv.values(): A.clean() From 229be4cd17afcf4bb45587b2572c7862067e2460 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 8 Apr 2024 06:26:47 -0700 Subject: [PATCH 30/33] TEst with Zero on dbdtDeriv_m --- SimPEG/electromagnetics/time_domain/fields.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/SimPEG/electromagnetics/time_domain/fields.py b/SimPEG/electromagnetics/time_domain/fields.py index 384432c736..7a0b45e3b0 100644 --- a/SimPEG/electromagnetics/time_domain/fields.py +++ b/SimPEG/electromagnetics/time_domain/fields.py @@ -56,7 +56,8 @@ def _dbdtDeriv(self, tInd, src, dun_dm_v, v, adjoint=False): if adjoint is True: return ( self._dbdtDeriv_u(tInd, src, v, adjoint), - self._dbdtDeriv_m(tInd, src, v, adjoint), + Zero() + # self._dbdtDeriv_m(tInd, src, v, adjoint), ) return self._dbdtDeriv_u(tInd, src, dun_dm_v) + self._dbdtDeriv_m(tInd, src, v) @@ -160,7 +161,7 @@ def _dbdt(self, bSolution, source_list, tInd): def _dbdtDeriv_u(self, tInd, src, dun_dm_v, adjoint=False): if adjoint is True: - return -self._eDeriv_u(tInd, src, self._edgeCurl.T * dun_dm_v, adjoint) + return -self._eDeriv_u(tInd, src, self._edgeCurl.T @ dun_dm_v, adjoint) return -(self._edgeCurl * self._eDeriv_u(tInd, src, dun_dm_v)) def _dbdtDeriv_m(self, tInd, src, v, adjoint=False): @@ -179,7 +180,7 @@ def _e(self, bSolution, source_list, tInd): def _eDeriv_u(self, tInd, src, dun_dm_v, adjoint=False): if adjoint is True: - return self._MfMui.T * (self._edgeCurl * (self._MeSigmaI.T * dun_dm_v)) + return self._MfMui.T @ (self._edgeCurl @ (self._MeSigmaI.T @ dun_dm_v)) return self._MeSigmaI * (self._edgeCurl.T * (self._MfMui * dun_dm_v)) def _eDeriv_m(self, tInd, src, v, adjoint=False): From 7dbd1659a8238146880a4b06eb8a5b274c302cd7 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 8 Apr 2024 07:16:06 -0700 Subject: [PATCH 31/33] Store spatial and time projections on receivers. Skip deriv calcs for empty times --- .../time_domain/simulation.py | 19 +++++++++++-- .../electromagnetics/time_domain/receivers.py | 28 ++++++++++++------- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 10ed0377af..94720156ed 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -263,15 +263,26 @@ def delayed_block_deriv( j_update = 0.0 source = source_list[indices[0]] receiver = source.receiver_list[indices[1]] - PTv = receiver.getP(mesh, time_mesh, fields).tocsr() + # PTv = receiver.getP(mesh, time_mesh, fields).tocsr() + spatialP = receiver.getSpatialP(mesh, fields) + timeP = receiver.getTimeP(time_mesh, fields) + derivative_fun = getattr(fields, "_{}Deriv".format(receiver.projField), None) time_derivs = [] for time_index in range(n_times + 1): + if len(timeP[:, time_index].data) == 0: + time_derivs.append( + sp.csr_matrix((field_len, len(arrays[0])), dtype=np.float32) + ) + j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) + continue + + projection = sp.kron(timeP[:, time_index], spatialP) cur = derivative_fun( time_index, source, None, - PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, + projection.T, adjoint=True, ) time_derivs.append(cur[0]) @@ -312,8 +323,10 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): # delayed_blocks.append(delayed_chunks) + tc = time() + print("Computing field derivatives") result = dask.compute(delayed_chunks)[0] - + print(f"Field derivatives computed in {time() - tc:.2f}s") # len_blocks = [[[] for _ in block] for block in blocks if len(block) > 0] df_duT = [ [[[] for _ in block] for block in blocks if len(block) > 0] diff --git a/SimPEG/electromagnetics/time_domain/receivers.py b/SimPEG/electromagnetics/time_domain/receivers.py index 49e9e78c32..d2a3a1f8cb 100644 --- a/SimPEG/electromagnetics/time_domain/receivers.py +++ b/SimPEG/electromagnetics/time_domain/receivers.py @@ -99,14 +99,17 @@ def getSpatialP(self, mesh, f): scipy.sparse.csr_matrix P, the interpolation matrix """ - P = Zero() - field = f._GLoc(self.projField) - for strength, comp in zip(self.orientation, ["x", "y", "z"]): - if strength != 0.0: - P = P + strength * mesh.get_interpolation_matrix( - self.locations, field + comp - ) - return P + if getattr(self, "spatialP", None) is None: + P = Zero() + field = f._GLoc(self.projField) + for strength, comp in zip(self.orientation, ["x", "y", "z"]): + if strength != 0.0: + P = P + strength * mesh.get_interpolation_matrix( + self.locations, field + comp + ) + self.spatialP = P + + return self.spatialP def getTimeP(self, time_mesh, f): """Get time projection matrix from mesh to receivers. @@ -124,8 +127,13 @@ def getTimeP(self, time_mesh, f): scipy.sparse.csr_matrix P, the interpolation matrix """ - projected_time_grid = f._TLoc(self.projField) - return time_mesh.get_interpolation_matrix(self.times, projected_time_grid) + if getattr(self, "timeP", None) is None: + projected_time_grid = f._TLoc(self.projField) + self.timeP = time_mesh.get_interpolation_matrix( + self.times, projected_time_grid + ) + + return self.timeP def getP(self, mesh, time_mesh, f): """Returns projection matrices as a list for all components collected by the receivers. From c31801bfc4ea7beacfcc64e8c1e3821704f1b7f0 Mon Sep 17 00:00:00 2001 From: domfournier Date: Mon, 8 Apr 2024 07:28:26 -0700 Subject: [PATCH 32/33] Remove prints --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 94720156ed..5c764bf6eb 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -323,10 +323,10 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): # delayed_blocks.append(delayed_chunks) - tc = time() - print("Computing field derivatives") + # tc = time() + # print("Computing field derivatives") result = dask.compute(delayed_chunks)[0] - print(f"Field derivatives computed in {time() - tc:.2f}s") + # print(f"Field derivatives computed in {time() - tc:.2f}s") # len_blocks = [[[] for _ in block] for block in blocks if len(block) > 0] df_duT = [ [[[] for _ in block] for block in blocks if len(block) > 0] From abe57ff121e90a7982eb3b1649e77a262602b94b Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 9 Apr 2024 14:27:38 -0400 Subject: [PATCH 33/33] Clean out commented lines --- .../time_domain/simulation.py | 20 ++----------------- 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 5c764bf6eb..fee66abde2 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -9,7 +9,6 @@ import scipy.sparse as sp from dask import array, delayed -# from dask.diagnostics import ProgressBar from SimPEG.dask.simulation import dask_Jvec, dask_Jtvec, dask_getJtJdiag from SimPEG.dask.utils import get_parallel_blocks import zarr @@ -179,7 +178,6 @@ def source_evaluation(simulation, sources, time): def evaluate_receivers(block, mesh, time_mesh, fields, fields_array): data = [] for source, ind, receiver in block: - # proj = receiver.getP(mesh, time_mesh, fields) Ps = receiver.getSpatialP(mesh, fields) Pt = receiver.getTimeP(time_mesh, fields) vector = (Pt * (Ps * fields_array[:, ind, :]).T).flatten() @@ -208,7 +206,7 @@ def dask_dpred(self, m=None, f=None, compute_J=False): "data. Please set the survey for the simulation: " "simulation.survey = survey" ) - # ct = time() + if f is None: if m is None: m = self.model @@ -263,7 +261,7 @@ def delayed_block_deriv( j_update = 0.0 source = source_list[indices[0]] receiver = source.receiver_list[indices[1]] - # PTv = receiver.getP(mesh, time_mesh, fields).tocsr() + spatialP = receiver.getSpatialP(mesh, fields) timeP = receiver.getTimeP(time_mesh, fields) @@ -302,8 +300,6 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): """ Compute the derivative of the fields """ - # delayed_blocks = [] - # for time_index in range(simulation.nT + 1): delayed_chunks = [] for chunks in blocks: if len(chunks) == 0: @@ -321,13 +317,7 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): ) delayed_chunks.append(delayed_block) - # delayed_blocks.append(delayed_chunks) - - # tc = time() - # print("Computing field derivatives") result = dask.compute(delayed_chunks)[0] - # print(f"Field derivatives computed in {time() - tc:.2f}s") - # len_blocks = [[[] for _ in block] for block in blocks if len(block) > 0] df_duT = [ [[[] for _ in block] for block in blocks if len(block) > 0] for _ in range(simulation.nT + 1) @@ -455,8 +445,6 @@ def get_field_deriv_block( update_deriv_blocks(address, indices, ATinv_df_duT_v, solve, shape) - # dask.compute(update_list) - return ATinv_df_duT_v @@ -500,9 +488,6 @@ def compute_rows( dAT_dm_v = simulation.getAdiagDeriv( tInd, un_src, field_derivs[:, local_ind], adjoint=True ) - # if isinstance(Jmatrix, zarr.core.Array): - # Jmatrix.oindex[ind_array[1][time_check].tolist(), :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype(np.float64) - # else: row_block = np.zeros( (len(ind_array[1]), simulation.model.size), dtype=np.float32 ) @@ -582,7 +567,6 @@ def compute_J(self, f=None, Ainv=None): ) ) - # Jmatrix = Jmatrix + array.vstack(j_row_updates) if self.store_sensitivities == "disk": sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" array.to_zarr(