From a910d829f9302730e23c08cbf5cd168cf135a8bb Mon Sep 17 00:00:00 2001 From: fourndo Date: Wed, 22 Nov 2023 14:55:10 -0800 Subject: [PATCH 01/13] Work in progress --- .../time_domain/simulation.py | 31 ++++++++++--------- SimPEG/directives/directives.py | 3 +- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 5ce1b66869..25a550c1d0 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -241,10 +241,11 @@ def compute_J(self, f=None, Ainv=None): # tc_loop = time() # print(f"Loop sources for {tInd}") for isrc, src in enumerate(self.survey.source_list): - column_inds = np.kron( - np.ones(int(src.vnD / n_times), dtype=bool), data_bool - ) - # for block in range(len(self.field_derivs[tInd][isrc])): + + column_inds = np.hstack([ + np.kron(np.ones(rec.locations.shape[0], dtype=bool), data_bool + ) for rec in src.receiver_list]) + if isrc not in field_derivs_t: field_derivs[(isrc, src)] = self.field_derivs[tInd + 1][isrc].toarray()[ :, column_inds @@ -252,7 +253,6 @@ def compute_J(self, f=None, Ainv=None): else: field_derivs[(isrc, src)] = field_derivs_t[isrc][:, column_inds] - n_data = self.field_derivs[tInd + 1][isrc].shape[1] d_count += column_inds.sum() if d_count > d_block_size: @@ -330,9 +330,11 @@ def block_append( count = 0 for (isrc, src), block in field_derivs.items(): - column_inds = np.kron( - np.ones(int(src.vnD / len(data_bool)), dtype=bool), data_bool - ) + + column_inds = np.hstack([ + np.kron(np.ones(rec.locations.shape[0], dtype=bool), data_bool + ) for rec in src.receiver_list]) + n_rows = column_inds.sum() source_blocks.append( dask.array.from_delayed( @@ -373,21 +375,20 @@ def block_deriv(simulation, src, tInd, f, block_size, row_count): simulation.nT, src, None, - PT_v[tInd * block_size : (tInd + 1) * block_size, :], + PT_v[tInd * block_size: (tInd + 1) * block_size, :], adjoint=True, ) if not isinstance(cur[1], Zero): - simulation.J_initializer[row_count : row_count + rx.nD, :] += cur[1].T + simulation.J_initializer[row_count: row_count + rx.nD, :] += cur[1].T if src_field_derivs is None: src_field_derivs = cur[0] else: - src_field_derivs += cur[0] + src_field_derivs = sp.hstack([src_field_derivs, cur[0]]) + + row_count += rx.nD - # n_blocks = int(np.ceil(np.prod(src_field_derivs.shape) * 8. * 1e-6 / 128.)) - # ind_col = np.array_split(np.arange(src_field_derivs.shape[1]), col_blocks) - # return [src_field_derivs[:, ind] for ind in ind_col] return src_field_derivs @@ -404,7 +405,7 @@ def parallel_block_compute( field_derivs, data_bool, ): - rows = np.arange(row_count, row_count + len(data_bool))[data_bool] + rows = row_count + np.where(data_bool)[0] field_derivs_t = np.asarray(field_derivs.todense()) field_derivs_t[:, data_bool] -= Asubdiag.T * ATinv_df_duT_v diff --git a/SimPEG/directives/directives.py b/SimPEG/directives/directives.py index 8150dec3c8..d5daf81a1a 100644 --- a/SimPEG/directives/directives.py +++ b/SimPEG/directives/directives.py @@ -2934,7 +2934,8 @@ def save_components(self, iteration: int, values: list[np.ndarray] = None): elif self.attribute_type == "predicted": dpred = getattr(self.invProb, "dpred", None) if dpred is None: - dpred = self.invProb.get_dpred(self.invProb.model) + fields = self.invProb.dmisfit.objfcts[0].simulation.fields(self.invProb.model) + dpred = self.invProb.get_dpred(self.invProb.model, [fields]) self.invProb.dpred = dpred if self.joint_index is not None: From 50fcc52af3ea6a443d1e2605564d4530a6d9a81a Mon Sep 17 00:00:00 2001 From: fourndo Date: Fri, 24 Nov 2023 15:38:16 -0800 Subject: [PATCH 02/13] Re-implement compute_J nonparallel --- .../time_domain/simulation.py | 412 ++++++++++++------ 1 file changed, 280 insertions(+), 132 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 25a550c1d0..36c98b4ae0 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -156,157 +156,305 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): def compute_J(self, f=None, Ainv=None): +# def Jtvec(self, m, v, f=None): + r""" + Jvec computes the adjoint of the sensitivity times a vector + + .. math:: + + \mathbf{J}^\top \mathbf{v} = + \left( + \frac{d\mathbf{u}}{d\mathbf{m}} ^ \top + \frac{d\mathbf{F}}{d\mathbf{u}} ^ \top + + \frac{\partial\mathbf{F}}{\partial\mathbf{m}} ^ \top + \right) + \frac{d\mathbf{P}}{d\mathbf{F}} ^ \top + \mathbf{v} + + where + + .. math:: + + \frac{d\mathbf{u}}{d\mathbf{m}} ^\top \mathbf{A}^\top + + \frac{d\mathbf{A}(\mathbf{u})}{d\mathbf{m}} ^ \top = + \frac{d \mathbf{RHS}}{d \mathbf{m}} ^ \top + """ + if f is None: f, Ainv = self.fields(self.model, return_Ainv=True) + ftype = self._fieldType + "Solution" # the thing we solved for + + # Ensure v is a data object. + # if not isinstance(v, Data): + # v = Data(self.survey, v) + + df_duT_v = {} + field_len = len(f[self.survey.source_list[0], ftype, 0]) + # same size as fields at a single timestep + ATinv_df_duT_v = {} + m_size = self.model.size - row_chunks = int( - np.ceil( - float(self.survey.nD) - / np.ceil(float(m_size) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size) - ) - ) + Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) + # Loop over sources and receivers to create a fields object: + # PT_v, df_duT_v, df_dmT_v + # initialize storage for PT_v (don't need to preserve over sources) + PT_v = self.Fields_Derivs(self) + rx_count = 0 + for src in self.survey.source_list: + # Looping over initializing field class is appending memory! + # PT_v = Fields_Derivs(self.mesh) # initialize storage + # #for PT_v (don't need to preserve over sources) + # initialize size + df_duT_v[src] = {} - if self.store_sensitivities == "disk": - self.J_initializer = zarr.open( - self.sensitivity_path + f"J_initializer.zarr", - mode="w", - shape=(self.survey.nD, m_size), - chunks=(row_chunks, m_size), - ) - else: - self.J_initializer = np.zeros((self.survey.nD, m_size), dtype=np.float32) - solution_type = self._fieldType + "Solution" # the thing we solved for - - if self.field_derivs is None: - # print("Start loop for field derivs") - block_size = len(f[self.survey.source_list[0], solution_type, 0]) - - field_derivs = [] - for tInd in range(self.nT + 1): - d_count = 0 - df_duT_v = [] - for i_s, src in enumerate(self.survey.source_list): - src_field_derivs = delayed(block_deriv, pure=True)( - self, src, tInd, f, block_size, d_count + for rx in src.receiver_list: + df_duT_v[src][rx] = {} + PTv = np.asarray( + rx.getP(self.mesh, self.time_mesh, f).todense().T + ).reshape((field_len, self.nT + 1, -1), order="F") + + n_rec_comp = rx.locations.shape[0] * (self.nT + 1) + + # PT_v[src, "{}Deriv".format(rx.projField), :] = rx.evalDeriv( + # src, self.mesh, self.time_mesh, f, mkvc(v[src, rx]), adjoint=True + # ) # this is += + + # PT_v = np.reshape(curPT_v,(len(curPT_v)/self.time_mesh.nN, + # self.time_mesh.nN), order='F') + df_duTFun = getattr(f, "_{}Deriv".format(rx.projField), None) + + rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) + for tInd in range(self.nT + 1): + cur = df_duTFun( + tInd, + src, + None, + PTv[:, tInd, :], + adjoint=True, ) - df_duT_v += [src_field_derivs] - d_count += np.sum([rx.nD for rx in src.receiver_list]) - - field_derivs += [df_duT_v] - # print("Dask loop field derivs") - # tc = time() - - self.field_derivs = dask.compute(field_derivs)[0] - # print(f"Done in {time() - tc} seconds") - - if self.store_sensitivities == "disk": - Jmatrix = ( - zarr.open( - self.sensitivity_path + f"J.zarr", - mode="w", - shape=(self.survey.nD, m_size), - chunks=(row_chunks, m_size), - ) - + self.J_initializer - ) - else: - Jmatrix = dask.delayed( - np.zeros((self.survey.nD, m_size), dtype=np.float32) + self.J_initializer - ) - f = dask.delayed(f) - field_derivs_t = {} - d_block_size = np.ceil(128.0 / (m_size * 8.0 * 1e-6)) + df_duT_v[src][rx][tInd] = cur[0] + Jmatrix[rx_ind, :] += cur[1].T + + rx_count += rx.nD + + del PT_v # no longer need this - # Check which time steps we need to compute - simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 - data_times = self.survey.source_list[0].receiver_list[0].times - n_times = len(data_times) + AdiagTinv = None + + # Do the back-solve through time + # if the previous timestep is the same: no need to refactor the matrix + # for tInd, dt in zip(range(self.nT), self.time_steps): for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): + # tInd = tIndP - 1 AdiagTinv = Ainv[dt] - Asubdiag = self.getAsubdiag(tInd) - row_count = 0 - row_blocks = [] - field_derivs = {} - source_blocks = [] - d_count = 0 + # Asubdiag = self.getAsubdiag(tInd) - data_bool = data_times > simulation_times[tInd] + if tInd < self.nT - 1: + Asubdiag = self.getAsubdiag(tInd + 1) + rx_count = 0 + for isrc, src in enumerate(self.survey.source_list): - if data_bool.sum() == 0: - continue + if isrc not in ATinv_df_duT_v: + ATinv_df_duT_v[isrc] = {} - # tc_loop = time() - # print(f"Loop sources for {tInd}") - for isrc, src in enumerate(self.survey.source_list): + for rx in src.receiver_list: + if rx not in ATinv_df_duT_v[isrc]: + ATinv_df_duT_v[isrc][rx] = {} - column_inds = np.hstack([ - np.kron(np.ones(rec.locations.shape[0], dtype=bool), data_bool - ) for rec in src.receiver_list]) - - if isrc not in field_derivs_t: - field_derivs[(isrc, src)] = self.field_derivs[tInd + 1][isrc].toarray()[ - :, column_inds - ] - else: - field_derivs[(isrc, src)] = field_derivs_t[isrc][:, column_inds] - - d_count += column_inds.sum() - - if d_count > d_block_size: - source_blocks = block_append( - self, - f, - AdiagTinv, - field_derivs, - m_size, - row_count, - tInd, - solution_type, - Jmatrix, - Asubdiag, - source_blocks, - data_bool, + rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) + # solve against df_duT_v + if tInd >= self.nT - 1: + + + # last timestep (first to be solved) + ATinv_df_duT_v[isrc][rx][tInd] = ( + AdiagTinv + * df_duT_v[src][rx][tInd+1] + ) + elif tInd > -1: + ATinv_df_duT_v[isrc][rx][tInd] = AdiagTinv * ( + df_duT_v[src][rx][tInd+1] + - Asubdiag.T * ATinv_df_duT_v[isrc][rx][tInd+1] + ) + + dAsubdiagT_dm_v = self.getAsubdiagDeriv( + tInd, f[src, ftype, tInd], ATinv_df_duT_v[isrc][rx][tInd], adjoint=True ) - field_derivs = {} - row_count = d_count - d_count = 0 - - if field_derivs: - source_blocks = block_append( - self, - f, - AdiagTinv, - field_derivs, - m_size, - row_count, - tInd, - solution_type, - Jmatrix, - Asubdiag, - source_blocks, - data_bool, - ) - # print(f"Done in {time() - tc_loop} seconds") - # tc = time() - # print(f"Compute field derivs for {tInd}") - del field_derivs_t - field_derivs_t = { - isrc: elem for isrc, elem in enumerate(dask.compute(source_blocks)[0]) - } - # print(f"Done in {time() - tc} seconds") + dRHST_dm_v = self.getRHSDeriv( + tInd + 1, src, ATinv_df_duT_v[isrc][rx][tInd], adjoint=True + ) # on nodes of time mesh - for A in Ainv.values(): - A.clean() + un_src = f[src, ftype, tInd + 1] + # cell centered on time mesh + dAT_dm_v = self.getAdiagDeriv( + tInd, un_src, ATinv_df_duT_v[isrc][rx][tInd], adjoint=True + ) - if self.store_sensitivities == "disk": - del Jmatrix - return array.from_zarr(self.sensitivity_path + f"J.zarr") - else: - return Jmatrix.compute() + Jmatrix[rx_ind, :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T + + rx_count += rx.nD + # Treat the initial condition + + # del df_duT_v, ATinv_df_duT_v, A, Asubdiag + if AdiagTinv is not None: + AdiagTinv.clean() + + return Jmatrix + # if f is None: + # f, Ainv = self.fields(self.model, return_Ainv=True) + # + # m_size = self.model.size + # row_chunks = int( + # np.ceil( + # float(self.survey.nD) + # / np.ceil(float(m_size) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size) + # ) + # ) + # + # if self.store_sensitivities == "disk": + # self.J_initializer = zarr.open( + # self.sensitivity_path + f"J_initializer.zarr", + # mode="w", + # shape=(self.survey.nD, m_size), + # chunks=(row_chunks, m_size), + # ) + # else: + # self.J_initializer = np.zeros((self.survey.nD, m_size), dtype=np.float32) + # solution_type = self._fieldType + "Solution" # the thing we solved for + # + # if self.field_derivs is None: + # # print("Start loop for field derivs") + # block_size = len(f[self.survey.source_list[0], solution_type, 0]) + # + # field_derivs = [] + # for tInd in range(self.nT + 1): + # d_count = 0 + # df_duT_v = [] + # for i_s, src in enumerate(self.survey.source_list): + # src_field_derivs = delayed(block_deriv, pure=True)( + # self, src, tInd, f, block_size, d_count + # ) + # df_duT_v += [src_field_derivs] + # d_count += np.sum([rx.nD for rx in src.receiver_list]) + # + # field_derivs += [df_duT_v] + # # print("Dask loop field derivs") + # # tc = time() + # + # self.field_derivs = dask.compute(field_derivs)[0] + # # print(f"Done in {time() - tc} seconds") + # + # if self.store_sensitivities == "disk": + # Jmatrix = ( + # zarr.open( + # self.sensitivity_path + f"J.zarr", + # mode="w", + # shape=(self.survey.nD, m_size), + # chunks=(row_chunks, m_size), + # ) + # + self.J_initializer + # ) + # else: + # Jmatrix = dask.delayed( + # np.zeros((self.survey.nD, m_size), dtype=np.float32) + self.J_initializer + # ) + # + # f = dask.delayed(f) + # field_derivs_t = {} + # d_block_size = np.ceil(128.0 / (m_size * 8.0 * 1e-6)) + # + # # Check which time steps we need to compute + # simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 + # data_times = self.survey.source_list[0].receiver_list[0].times + # n_times = len(data_times) + # + # for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): + # AdiagTinv = Ainv[dt] + # Asubdiag = self.getAsubdiag(tInd) + # row_count = 0 + # row_blocks = [] + # field_derivs = {} + # source_blocks = [] + # d_count = 0 + # + # data_bool = data_times > simulation_times[tInd] + # + # if data_bool.sum() == 0: + # continue + # + # # tc_loop = time() + # # print(f"Loop sources for {tInd}") + # for isrc, src in enumerate(self.survey.source_list): + # + # column_inds = np.hstack([ + # np.kron(np.ones(rec.locations.shape[0], dtype=bool), data_bool + # ) for rec in src.receiver_list]) + # + # if isrc not in field_derivs_t: + # field_derivs[(isrc, src)] = self.field_derivs[tInd + 1][isrc].toarray()[ + # :, column_inds + # ] + # else: + # field_derivs[(isrc, src)] = field_derivs_t[isrc][:, column_inds] + # + # d_count += column_inds.sum() + # + # if d_count > d_block_size: + # source_blocks = block_append( + # self, + # f, + # AdiagTinv, + # field_derivs, + # m_size, + # row_count, + # tInd, + # solution_type, + # Jmatrix, + # Asubdiag, + # source_blocks, + # data_bool, + # ) + # field_derivs = {} + # row_count = d_count + # d_count = 0 + # + # if field_derivs: + # source_blocks = block_append( + # self, + # f, + # AdiagTinv, + # field_derivs, + # m_size, + # row_count, + # tInd, + # solution_type, + # Jmatrix, + # Asubdiag, + # source_blocks, + # data_bool, + # ) + # + # # print(f"Done in {time() - tc_loop} seconds") + # # tc = time() + # # print(f"Compute field derivs for {tInd}") + # del field_derivs_t + # field_derivs_t = { + # isrc: elem for isrc, elem in enumerate(dask.compute(source_blocks)[0]) + # } + # # print(f"Done in {time() - tc} seconds") + # + # for A in Ainv.values(): + # A.clean() + # + # if self.store_sensitivities == "disk": + # del Jmatrix + # return array.from_zarr(self.sensitivity_path + f"J.zarr") + # else: + # return Jmatrix.compute() Sim.compute_J = compute_J From 36949ad94ed37d9b50aa7887219093d60348d03b Mon Sep 17 00:00:00 2001 From: fourndo Date: Sat, 25 Nov 2023 09:40:38 -0800 Subject: [PATCH 03/13] Parallel the field derivs --- .../time_domain/simulation.py | 211 +++++++++--------- 1 file changed, 104 insertions(+), 107 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 36c98b4ae0..0328232885 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -155,8 +155,61 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): Sim.field_derivs = None +def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jmatrix): + """Compute derivatives for sources and receivers in a block""" + field_len = len(fields[source_list[0], field_type, 0]) + df_duT = {src: {} for src in source_list} + + rx_count = 0 + for src in source_list: + df_duT[src] = {rx: {} for rx in src.receiver_list} + + for rx in src.receiver_list: + PTv = np.asarray( + rx.getP(mesh, time_mesh, fields).todense().T + ).reshape((field_len, time_mesh.n_faces, -1), order="F") + derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) + rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) + + cur = derivative_fun( + time_index, + src, + None, + sp.csr_matrix(PTv[:, time_index, :]), + adjoint=True, + ) + df_duT[src][rx] = cur[0] + Jmatrix[rx_ind, :] += cur[1].T + + rx_count += rx.nD + + return df_duT + + +def compute_field_derivs(simulation, Jmatrix, fields): + """ + Compute the derivative of the fields + """ + + df_duT = [] + + for time_index in range(simulation.nT + 1): + df_duT.append(delayed(block_deriv, pure=True)( + time_index, + simulation._fieldType + "Solution", + simulation.survey.source_list, + simulation.mesh, + simulation.time_mesh, + fields, + Jmatrix + )) + + + df_duT = dask.compute(df_duT)[0] + + return df_duT + def compute_J(self, f=None, Ainv=None): -# def Jtvec(self, m, v, f=None): r""" Jvec computes the adjoint of the sensitivity times a vector @@ -183,74 +236,15 @@ def compute_J(self, f=None, Ainv=None): if f is None: f, Ainv = self.fields(self.model, return_Ainv=True) - ftype = self._fieldType + "Solution" # the thing we solved for + ftype = self._fieldType + "Solution" + Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float32) - # Ensure v is a data object. - # if not isinstance(v, Data): - # v = Data(self.survey, v) + self.field_derivs = compute_field_derivs(self, Jmatrix, f) - df_duT_v = {} - field_len = len(f[self.survey.source_list[0], ftype, 0]) - # same size as fields at a single timestep ATinv_df_duT_v = {} - - m_size = self.model.size - Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) - # Loop over sources and receivers to create a fields object: - # PT_v, df_duT_v, df_dmT_v - # initialize storage for PT_v (don't need to preserve over sources) - PT_v = self.Fields_Derivs(self) - rx_count = 0 - for src in self.survey.source_list: - # Looping over initializing field class is appending memory! - # PT_v = Fields_Derivs(self.mesh) # initialize storage - # #for PT_v (don't need to preserve over sources) - # initialize size - df_duT_v[src] = {} - - for rx in src.receiver_list: - df_duT_v[src][rx] = {} - PTv = np.asarray( - rx.getP(self.mesh, self.time_mesh, f).todense().T - ).reshape((field_len, self.nT + 1, -1), order="F") - - n_rec_comp = rx.locations.shape[0] * (self.nT + 1) - - # PT_v[src, "{}Deriv".format(rx.projField), :] = rx.evalDeriv( - # src, self.mesh, self.time_mesh, f, mkvc(v[src, rx]), adjoint=True - # ) # this is += - - # PT_v = np.reshape(curPT_v,(len(curPT_v)/self.time_mesh.nN, - # self.time_mesh.nN), order='F') - df_duTFun = getattr(f, "_{}Deriv".format(rx.projField), None) - - rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) - for tInd in range(self.nT + 1): - cur = df_duTFun( - tInd, - src, - None, - PTv[:, tInd, :], - adjoint=True, - ) - - df_duT_v[src][rx][tInd] = cur[0] - Jmatrix[rx_ind, :] += cur[1].T - - rx_count += rx.nD - - del PT_v # no longer need this - - AdiagTinv = None - - # Do the back-solve through time - # if the previous timestep is the same: no need to refactor the matrix - # for tInd, dt in zip(range(self.nT), self.time_steps): - for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): - # tInd = tIndP - 1 + AdiagTinv = Ainv[dt] - # Asubdiag = self.getAsubdiag(tInd) if tInd < self.nT - 1: Asubdiag = self.getAsubdiag(tInd + 1) @@ -267,43 +261,46 @@ def compute_J(self, f=None, Ainv=None): rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) # solve against df_duT_v if tInd >= self.nT - 1: - - # last timestep (first to be solved) - ATinv_df_duT_v[isrc][rx][tInd] = ( + ATinv_df_duT_v[isrc][rx] = ( AdiagTinv - * df_duT_v[src][rx][tInd+1] + * self.field_derivs[tInd+1][src][rx].toarray() ) elif tInd > -1: - ATinv_df_duT_v[isrc][rx][tInd] = AdiagTinv * ( - df_duT_v[src][rx][tInd+1] - - Asubdiag.T * ATinv_df_duT_v[isrc][rx][tInd+1] + ATinv_df_duT_v[isrc][rx] = AdiagTinv * np.asarray( + self.field_derivs[tInd+1][src][rx] + - Asubdiag.T * ATinv_df_duT_v[isrc][rx] ) dAsubdiagT_dm_v = self.getAsubdiagDeriv( - tInd, f[src, ftype, tInd], ATinv_df_duT_v[isrc][rx][tInd], adjoint=True + tInd, f[src, ftype, tInd], ATinv_df_duT_v[isrc][rx], adjoint=True ) dRHST_dm_v = self.getRHSDeriv( - tInd + 1, src, ATinv_df_duT_v[isrc][rx][tInd], adjoint=True + tInd + 1, src, ATinv_df_duT_v[isrc][rx], adjoint=True ) # on nodes of time mesh un_src = f[src, ftype, tInd + 1] # cell centered on time mesh dAT_dm_v = self.getAdiagDeriv( - tInd, un_src, ATinv_df_duT_v[isrc][rx][tInd], adjoint=True + tInd, un_src, ATinv_df_duT_v[isrc][rx], adjoint=True ) Jmatrix[rx_ind, :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T rx_count += rx.nD - # Treat the initial condition - # del df_duT_v, ATinv_df_duT_v, A, Asubdiag - if AdiagTinv is not None: - AdiagTinv.clean() - return Jmatrix + for A in Ainv.values(): + A.clean() + + if self.store_sensitivities == "disk": + del Jmatrix + return array.from_zarr(self.sensitivity_path + f"J.zarr") + else: + return Jmatrix + + # if f is None: # f, Ainv = self.fields(self.model, return_Ainv=True) # @@ -510,34 +507,34 @@ def block_append( return source_blocks -def block_deriv(simulation, src, tInd, f, block_size, row_count): - src_field_derivs = None - for rx in src.receiver_list: - v = sp.eye(rx.nD, dtype=float) - PT_v = rx.evalDeriv( - src, simulation.mesh, simulation.time_mesh, f, v, adjoint=True - ) - df_duTFun = getattr(f, "_{}Deriv".format(rx.projField), None) - - cur = df_duTFun( - simulation.nT, - src, - None, - PT_v[tInd * block_size: (tInd + 1) * block_size, :], - adjoint=True, - ) - - if not isinstance(cur[1], Zero): - simulation.J_initializer[row_count: row_count + rx.nD, :] += cur[1].T - - if src_field_derivs is None: - src_field_derivs = cur[0] - else: - src_field_derivs = sp.hstack([src_field_derivs, cur[0]]) - - row_count += rx.nD - - return src_field_derivs +# def block_deriv(simulation, src, tInd, f, block_size, row_count): +# src_field_derivs = None +# for rx in src.receiver_list: +# v = sp.eye(rx.nD, dtype=float) +# PT_v = rx.evalDeriv( +# src, simulation.mesh, simulation.time_mesh, f, v, adjoint=True +# ) +# df_duTFun = getattr(f, "_{}Deriv".format(rx.projField), None) +# +# cur = df_duTFun( +# simulation.nT, +# src, +# None, +# PT_v[tInd * block_size: (tInd + 1) * block_size, :], +# adjoint=True, +# ) +# +# if not isinstance(cur[1], Zero): +# simulation.J_initializer[row_count: row_count + rx.nD, :] += cur[1].T +# +# if src_field_derivs is None: +# src_field_derivs = cur[0] +# else: +# src_field_derivs = sp.hstack([src_field_derivs, cur[0]]) +# +# row_count += rx.nD +# +# return src_field_derivs def parallel_block_compute( From 2f4e4de1cc886bd8b5a3d6ca934c5485ab53a7ce Mon Sep 17 00:00:00 2001 From: fourndo Date: Mon, 27 Nov 2023 20:03:11 -0800 Subject: [PATCH 04/13] Full parallel run. Needs validation --- .../time_domain/simulation.py | 211 +++++++++++++----- 1 file changed, 152 insertions(+), 59 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 0328232885..e9aa6c2372 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -209,28 +209,100 @@ def compute_field_derivs(simulation, Jmatrix, fields): return df_duT -def compute_J(self, f=None, Ainv=None): - r""" - Jvec computes the adjoint of the sensitivity times a vector +def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): + """ + Get the blocks of sources and receivers to be computed in parallel. - .. math:: + Stored as a dictionary of source, receiver pairs index. The value is an array of indices + for the rows of the sensitivity matrix. + """ + data_block_size = np.ceil(max_chunk_size / (m_size * 8.0 * 1e-6)) + row_count = 0 + row_index = 0 + block_count = 0 + blocks = {0: {}} + for src in source_list: + for rx in src.receiver_list: - \mathbf{J}^\top \mathbf{v} = - \left( - \frac{d\mathbf{u}}{d\mathbf{m}} ^ \top - \frac{d\mathbf{F}}{d\mathbf{u}} ^ \top - + \frac{\partial\mathbf{F}}{\partial\mathbf{m}} ^ \top - \right) - \frac{d\mathbf{P}}{d\mathbf{F}} ^ \top - \mathbf{v} + indices = np.arange(rx.nD).astype(int) + chunks = np.split(indices, int(np.ceil(len(indices)/data_block_size))) - where + for ind, chunk in enumerate(chunks): + chunk_size = len(chunk) - .. math:: + # Condition to start a new block + if (row_count + chunk_size) > (data_block_size * cpu_count() / 2): + row_count = 0 + block_count += 1 + blocks[block_count] = {} + + blocks[block_count][(src, rx, ind)] = chunk, np.arange(row_index, row_index + chunk_size).astype(int) + row_index += chunk_size + row_count += chunk_size + + return blocks + + +def get_field_deriv_block(simulation, block: dict, tInd: int, AdiagTinv, ATinv_df_duT_v: dict): + """ + Stack the blocks of field derivatives for a given timestep and call the direct solver. + """ + stacked_blocks = [] + indices = [] + count = 0 + for (src, rx, ind), (rx_ind, j_ind) in block.items(): + indices.append( + np.arange(count, count + len(rx_ind)) + ) + count += len(rx_ind) + if (src, rx, ind) not in ATinv_df_duT_v: + # last timestep (first to be solved) + stacked_blocks.append( + simulation.field_derivs[tInd + 1][src][rx].toarray()[:, rx_ind] + ) + + else: + Asubdiag = simulation.getAsubdiag(tInd + 1) + stacked_blocks.append( + np.asarray( + simulation.field_derivs[tInd + 1][src][rx][:, rx_ind] + - Asubdiag.T * ATinv_df_duT_v[(src, rx, ind)] + ) + ) + + solve = AdiagTinv * np.hstack(stacked_blocks) + + for (src, rx, ind), columns in zip(block, indices): + ATinv_df_duT_v[(src, rx, ind)] = solve[:, columns] + + return ATinv_df_duT_v - \frac{d\mathbf{u}}{d\mathbf{m}} ^\top \mathbf{A}^\top + - \frac{d\mathbf{A}(\mathbf{u})}{d\mathbf{m}} ^ \top = - \frac{d \mathbf{RHS}}{d \mathbf{m}} ^ \top + +def compute_rows(simulation, tInd, src, rx_ind, j_ind, ATinv_df_duT_v, f, Jmatrix, ftype): + """ + Compute the rows of the sensitivity matrix for a given source and receiver. + """ + dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( + tInd, f[src, ftype, tInd], ATinv_df_duT_v, adjoint=True + ) + + dRHST_dm_v = simulation.getRHSDeriv( + tInd + 1, src, ATinv_df_duT_v, adjoint=True + ) # on nodes of time mesh + + un_src = f[src, ftype, tInd + 1] + # cell centered on time mesh + dAT_dm_v = simulation.getAdiagDeriv( + tInd, un_src, ATinv_df_duT_v, adjoint=True + ) + + Jmatrix[j_ind, :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T + + + +def compute_J(self, f=None, Ainv=None): + """ + Compute the rows for the sensitivity matrix. """ if f is None: @@ -238,6 +310,7 @@ def compute_J(self, f=None, Ainv=None): ftype = self._fieldType + "Solution" Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float32) + blocks = get_parallel_blocks(self.survey.source_list, self.model.shape[0], self.max_chunk_size) self.field_derivs = compute_field_derivs(self, Jmatrix, f) @@ -246,49 +319,69 @@ def compute_J(self, f=None, Ainv=None): AdiagTinv = Ainv[dt] - if tInd < self.nT - 1: - Asubdiag = self.getAsubdiag(tInd + 1) - rx_count = 0 - for isrc, src in enumerate(self.survey.source_list): - - if isrc not in ATinv_df_duT_v: - ATinv_df_duT_v[isrc] = {} - - for rx in src.receiver_list: - if rx not in ATinv_df_duT_v[isrc]: - ATinv_df_duT_v[isrc][rx] = {} - - rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) - # solve against df_duT_v - if tInd >= self.nT - 1: - # last timestep (first to be solved) - ATinv_df_duT_v[isrc][rx] = ( - AdiagTinv - * self.field_derivs[tInd+1][src][rx].toarray() - ) - elif tInd > -1: - ATinv_df_duT_v[isrc][rx] = AdiagTinv * np.asarray( - self.field_derivs[tInd+1][src][rx] - - Asubdiag.T * ATinv_df_duT_v[isrc][rx] - ) - - dAsubdiagT_dm_v = self.getAsubdiagDeriv( - tInd, f[src, ftype, tInd], ATinv_df_duT_v[isrc][rx], adjoint=True - ) - - dRHST_dm_v = self.getRHSDeriv( - tInd + 1, src, ATinv_df_duT_v[isrc][rx], adjoint=True - ) # on nodes of time mesh + ATinv_df_duT_v = {} + j_row_updates = [] + for block in blocks.values(): + ATinv_df_duT_v = get_field_deriv_block(self, block, tInd, AdiagTinv, ATinv_df_duT_v) - un_src = f[src, ftype, tInd + 1] - # cell centered on time mesh - dAT_dm_v = self.getAdiagDeriv( - tInd, un_src, ATinv_df_duT_v[isrc][rx], adjoint=True - ) - - Jmatrix[rx_ind, :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T - - rx_count += rx.nD + for (src, rx, ind), (rx_ind, j_ind) in block.items(): + j_row_updates.append(delayed(compute_rows, pure=True)( + self, + tInd, + src, + rx_ind, + j_ind, + ATinv_df_duT_v[(src, rx, ind)], + f, + Jmatrix, + ftype + )) + # for (src, rx, ind), (rx_ind, j_ind) in block.items(): + + + dask.compute(j_row_updates) + + # rx_count = 0 + # for isrc, src in enumerate(self.survey.source_list): + # + # if isrc not in ATinv_df_duT_v: + # ATinv_df_duT_v[isrc] = {} + # + # for rx in src.receiver_list: + # if rx not in ATinv_df_duT_v[isrc]: + # ATinv_df_duT_v[isrc][rx] = {} + # + # rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) + # # solve against df_duT_v + # if tInd >= self.nT - 1: + # # last timestep (first to be solved) + # ATinv_df_duT_v[isrc][rx] = ( + # AdiagTinv + # * self.field_derivs[tInd+1][src][rx].toarray() + # ) + # elif tInd > -1: + # ATinv_df_duT_v[isrc][rx] = AdiagTinv * np.asarray( + # self.field_derivs[tInd+1][src][rx] + # - Asubdiag.T * ATinv_df_duT_v[isrc][rx] + # ) + # + # dAsubdiagT_dm_v = self.getAsubdiagDeriv( + # tInd, f[src, ftype, tInd], ATinv_df_duT_v[isrc][rx], adjoint=True + # ) + # + # dRHST_dm_v = self.getRHSDeriv( + # tInd + 1, src, ATinv_df_duT_v[isrc][rx], adjoint=True + # ) # on nodes of time mesh + # + # un_src = f[src, ftype, tInd + 1] + # # cell centered on time mesh + # dAT_dm_v = self.getAdiagDeriv( + # tInd, un_src, ATinv_df_duT_v[isrc][rx], adjoint=True + # ) + # + # Jmatrix[rx_ind, :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T + + # rx_count += rx.nD for A in Ainv.values(): From 4988515066edbf60c5b624cfa1e766eaaab561b6 Mon Sep 17 00:00:00 2001 From: fourndo Date: Mon, 27 Nov 2023 20:58:54 -0800 Subject: [PATCH 05/13] Fix parallel process. Passes benchmark --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index e9aa6c2372..359ca13167 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -316,11 +316,9 @@ def compute_J(self, f=None, Ainv=None): ATinv_df_duT_v = {} for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): - AdiagTinv = Ainv[dt] - - ATinv_df_duT_v = {} j_row_updates = [] + for block in blocks.values(): ATinv_df_duT_v = get_field_deriv_block(self, block, tInd, AdiagTinv, ATinv_df_duT_v) From dded48e7337b7a862e848bd91270543a68849933 Mon Sep 17 00:00:00 2001 From: fourndo Date: Tue, 28 Nov 2023 11:27:56 -0800 Subject: [PATCH 06/13] Add time masking. Good benchmark --- .../time_domain/simulation.py | 451 ++++-------------- 1 file changed, 97 insertions(+), 354 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 359ca13167..997908bebf 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -165,9 +165,7 @@ def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jm df_duT[src] = {rx: {} for rx in src.receiver_list} for rx in src.receiver_list: - PTv = np.asarray( - rx.getP(mesh, time_mesh, fields).todense().T - ).reshape((field_len, time_mesh.n_faces, -1), order="F") + PTv = rx.getP(mesh, time_mesh, fields).tocsr() derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) @@ -175,7 +173,7 @@ def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jm time_index, src, None, - sp.csr_matrix(PTv[:, time_index, :]), + PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, adjoint=True, ) df_duT[src][rx] = cur[0] @@ -194,21 +192,23 @@ def compute_field_derivs(simulation, Jmatrix, fields): df_duT = [] for time_index in range(simulation.nT + 1): - df_duT.append(delayed(block_deriv, pure=True)( - time_index, - simulation._fieldType + "Solution", - simulation.survey.source_list, - simulation.mesh, - simulation.time_mesh, - fields, - Jmatrix - )) - + df_duT.append( + delayed(block_deriv, pure=True)( + time_index, + simulation._fieldType + "Solution", + simulation.survey.source_list, + simulation.mesh, + simulation.time_mesh, + fields, + Jmatrix, + ) + ) df_duT = dask.compute(df_duT)[0] return df_duT + def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): """ Get the blocks of sources and receivers to be computed in parallel. @@ -223,9 +223,8 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): blocks = {0: {}} for src in source_list: for rx in src.receiver_list: - indices = np.arange(rx.nD).astype(int) - chunks = np.split(indices, int(np.ceil(len(indices)/data_block_size))) + chunks = np.split(indices, int(np.ceil(len(indices) / data_block_size))) for ind, chunk in enumerate(chunks): chunk_size = len(chunk) @@ -236,68 +235,107 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): block_count += 1 blocks[block_count] = {} - blocks[block_count][(src, rx, ind)] = chunk, np.arange(row_index, row_index + chunk_size).astype(int) + blocks[block_count][(src, rx, ind)] = chunk, np.arange( + row_index, row_index + chunk_size + ).astype(int) row_index += chunk_size row_count += chunk_size return blocks -def get_field_deriv_block(simulation, block: dict, tInd: int, AdiagTinv, ATinv_df_duT_v: dict): +def get_field_deriv_block( + simulation, block: dict, tInd: int, AdiagTinv, ATinv_df_duT_v: dict, time_mask +): """ Stack the blocks of field derivatives for a given timestep and call the direct solver. """ stacked_blocks = [] - indices = [] + indices = {} count = 0 for (src, rx, ind), (rx_ind, j_ind) in block.items(): - indices.append( - np.arange(count, count + len(rx_ind)) - ) - count += len(rx_ind) + # Cut out early data + time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[ + rx_ind + ] + sub_ind = rx_ind[time_check] + + if len(sub_ind) < 1: + continue + + indices[(src, rx, ind)] = (np.arange(count, count + len(sub_ind)), sub_ind) + count += len(sub_ind) + if (src, rx, ind) not in ATinv_df_duT_v: # last timestep (first to be solved) stacked_blocks.append( - simulation.field_derivs[tInd + 1][src][rx].toarray()[:, rx_ind] + simulation.field_derivs[tInd + 1][src][rx].toarray()[:, sub_ind] ) else: Asubdiag = simulation.getAsubdiag(tInd + 1) stacked_blocks.append( np.asarray( - simulation.field_derivs[tInd + 1][src][rx][:, rx_ind] - - Asubdiag.T * ATinv_df_duT_v[(src, rx, ind)] + simulation.field_derivs[tInd + 1][src][rx][:, sub_ind] + - Asubdiag.T * ATinv_df_duT_v[(src, rx, ind)][:, sub_ind] ) ) - solve = AdiagTinv * np.hstack(stacked_blocks) + if len(stacked_blocks) > 1: + solve = AdiagTinv * np.hstack(stacked_blocks) + + for src, rx, ind in block: + ATinv_df_duT_v[(src, rx, ind)] = np.zeros( + ( + simulation.field_derivs[tInd][src][rx].shape[0], + len(block[(src, rx, ind)][0]), + ) + ) - for (src, rx, ind), columns in zip(block, indices): - ATinv_df_duT_v[(src, rx, ind)] = solve[:, columns] + if (src, rx, ind) in indices: + columns, sub_ind = indices[(src, rx, ind)] + ATinv_df_duT_v[(src, rx, ind)][:, sub_ind] = solve[:, columns] return ATinv_df_duT_v -def compute_rows(simulation, tInd, src, rx_ind, j_ind, ATinv_df_duT_v, f, Jmatrix, ftype): +def compute_rows( + simulation, + tInd, + src, + rx, + rx_ind, + j_ind, + ATinv_df_duT_v, + f, + Jmatrix, + ftype, + time_mask, +): """ Compute the rows of the sensitivity matrix for a given source and receiver. """ + time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[rx_ind] + sub_ind = rx_ind[time_check] + + if len(sub_ind) < 1: + return + dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( - tInd, f[src, ftype, tInd], ATinv_df_duT_v, adjoint=True + tInd, f[src, ftype, tInd], ATinv_df_duT_v[:, sub_ind], adjoint=True ) dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, src, ATinv_df_duT_v, adjoint=True + tInd + 1, src, ATinv_df_duT_v[:, sub_ind], adjoint=True ) # on nodes of time mesh un_src = f[src, ftype, tInd + 1] # cell centered on time mesh dAT_dm_v = simulation.getAdiagDeriv( - tInd, un_src, ATinv_df_duT_v, adjoint=True + tInd, un_src, ATinv_df_duT_v[:, sub_ind], adjoint=True ) - Jmatrix[j_ind, :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T - + Jmatrix[j_ind[time_check], :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T def compute_J(self, f=None, Ainv=None): @@ -310,7 +348,11 @@ def compute_J(self, f=None, Ainv=None): ftype = self._fieldType + "Solution" Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float32) - blocks = get_parallel_blocks(self.survey.source_list, self.model.shape[0], self.max_chunk_size) + simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 + data_times = self.survey.source_list[0].receiver_list[0].times + blocks = get_parallel_blocks( + self.survey.source_list, self.model.shape[0], self.max_chunk_size + ) self.field_derivs = compute_field_derivs(self, Jmatrix, f) @@ -318,70 +360,32 @@ def compute_J(self, f=None, Ainv=None): for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): AdiagTinv = Ainv[dt] j_row_updates = [] + time_mask = data_times > simulation_times[tInd] for block in blocks.values(): - ATinv_df_duT_v = get_field_deriv_block(self, block, tInd, AdiagTinv, ATinv_df_duT_v) + ATinv_df_duT_v = get_field_deriv_block( + self, block, tInd, AdiagTinv, ATinv_df_duT_v, time_mask + ) for (src, rx, ind), (rx_ind, j_ind) in block.items(): - j_row_updates.append(delayed(compute_rows, pure=True)( - self, - tInd, - src, - rx_ind, - j_ind, - ATinv_df_duT_v[(src, rx, ind)], - f, - Jmatrix, - ftype - )) - # for (src, rx, ind), (rx_ind, j_ind) in block.items(): - + j_row_updates.append( + delayed(compute_rows, pure=True)( + self, + tInd, + src, + rx, + rx_ind, + j_ind, + ATinv_df_duT_v[(src, rx, ind)], + f, + Jmatrix, + ftype, + time_mask, + ) + ) dask.compute(j_row_updates) - # rx_count = 0 - # for isrc, src in enumerate(self.survey.source_list): - # - # if isrc not in ATinv_df_duT_v: - # ATinv_df_duT_v[isrc] = {} - # - # for rx in src.receiver_list: - # if rx not in ATinv_df_duT_v[isrc]: - # ATinv_df_duT_v[isrc][rx] = {} - # - # rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) - # # solve against df_duT_v - # if tInd >= self.nT - 1: - # # last timestep (first to be solved) - # ATinv_df_duT_v[isrc][rx] = ( - # AdiagTinv - # * self.field_derivs[tInd+1][src][rx].toarray() - # ) - # elif tInd > -1: - # ATinv_df_duT_v[isrc][rx] = AdiagTinv * np.asarray( - # self.field_derivs[tInd+1][src][rx] - # - Asubdiag.T * ATinv_df_duT_v[isrc][rx] - # ) - # - # dAsubdiagT_dm_v = self.getAsubdiagDeriv( - # tInd, f[src, ftype, tInd], ATinv_df_duT_v[isrc][rx], adjoint=True - # ) - # - # dRHST_dm_v = self.getRHSDeriv( - # tInd + 1, src, ATinv_df_duT_v[isrc][rx], adjoint=True - # ) # on nodes of time mesh - # - # un_src = f[src, ftype, tInd + 1] - # # cell centered on time mesh - # dAT_dm_v = self.getAdiagDeriv( - # tInd, un_src, ATinv_df_duT_v[isrc][rx], adjoint=True - # ) - # - # Jmatrix[rx_ind, :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T - - # rx_count += rx.nD - - for A in Ainv.values(): A.clean() @@ -392,265 +396,4 @@ def compute_J(self, f=None, Ainv=None): return Jmatrix - # if f is None: - # f, Ainv = self.fields(self.model, return_Ainv=True) - # - # m_size = self.model.size - # row_chunks = int( - # np.ceil( - # float(self.survey.nD) - # / np.ceil(float(m_size) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size) - # ) - # ) - # - # if self.store_sensitivities == "disk": - # self.J_initializer = zarr.open( - # self.sensitivity_path + f"J_initializer.zarr", - # mode="w", - # shape=(self.survey.nD, m_size), - # chunks=(row_chunks, m_size), - # ) - # else: - # self.J_initializer = np.zeros((self.survey.nD, m_size), dtype=np.float32) - # solution_type = self._fieldType + "Solution" # the thing we solved for - # - # if self.field_derivs is None: - # # print("Start loop for field derivs") - # block_size = len(f[self.survey.source_list[0], solution_type, 0]) - # - # field_derivs = [] - # for tInd in range(self.nT + 1): - # d_count = 0 - # df_duT_v = [] - # for i_s, src in enumerate(self.survey.source_list): - # src_field_derivs = delayed(block_deriv, pure=True)( - # self, src, tInd, f, block_size, d_count - # ) - # df_duT_v += [src_field_derivs] - # d_count += np.sum([rx.nD for rx in src.receiver_list]) - # - # field_derivs += [df_duT_v] - # # print("Dask loop field derivs") - # # tc = time() - # - # self.field_derivs = dask.compute(field_derivs)[0] - # # print(f"Done in {time() - tc} seconds") - # - # if self.store_sensitivities == "disk": - # Jmatrix = ( - # zarr.open( - # self.sensitivity_path + f"J.zarr", - # mode="w", - # shape=(self.survey.nD, m_size), - # chunks=(row_chunks, m_size), - # ) - # + self.J_initializer - # ) - # else: - # Jmatrix = dask.delayed( - # np.zeros((self.survey.nD, m_size), dtype=np.float32) + self.J_initializer - # ) - # - # f = dask.delayed(f) - # field_derivs_t = {} - # d_block_size = np.ceil(128.0 / (m_size * 8.0 * 1e-6)) - # - # # Check which time steps we need to compute - # simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 - # data_times = self.survey.source_list[0].receiver_list[0].times - # n_times = len(data_times) - # - # for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): - # AdiagTinv = Ainv[dt] - # Asubdiag = self.getAsubdiag(tInd) - # row_count = 0 - # row_blocks = [] - # field_derivs = {} - # source_blocks = [] - # d_count = 0 - # - # data_bool = data_times > simulation_times[tInd] - # - # if data_bool.sum() == 0: - # continue - # - # # tc_loop = time() - # # print(f"Loop sources for {tInd}") - # for isrc, src in enumerate(self.survey.source_list): - # - # column_inds = np.hstack([ - # np.kron(np.ones(rec.locations.shape[0], dtype=bool), data_bool - # ) for rec in src.receiver_list]) - # - # if isrc not in field_derivs_t: - # field_derivs[(isrc, src)] = self.field_derivs[tInd + 1][isrc].toarray()[ - # :, column_inds - # ] - # else: - # field_derivs[(isrc, src)] = field_derivs_t[isrc][:, column_inds] - # - # d_count += column_inds.sum() - # - # if d_count > d_block_size: - # source_blocks = block_append( - # self, - # f, - # AdiagTinv, - # field_derivs, - # m_size, - # row_count, - # tInd, - # solution_type, - # Jmatrix, - # Asubdiag, - # source_blocks, - # data_bool, - # ) - # field_derivs = {} - # row_count = d_count - # d_count = 0 - # - # if field_derivs: - # source_blocks = block_append( - # self, - # f, - # AdiagTinv, - # field_derivs, - # m_size, - # row_count, - # tInd, - # solution_type, - # Jmatrix, - # Asubdiag, - # source_blocks, - # data_bool, - # ) - # - # # print(f"Done in {time() - tc_loop} seconds") - # # tc = time() - # # print(f"Compute field derivs for {tInd}") - # del field_derivs_t - # field_derivs_t = { - # isrc: elem for isrc, elem in enumerate(dask.compute(source_blocks)[0]) - # } - # # print(f"Done in {time() - tc} seconds") - # - # for A in Ainv.values(): - # A.clean() - # - # if self.store_sensitivities == "disk": - # del Jmatrix - # return array.from_zarr(self.sensitivity_path + f"J.zarr") - # else: - # return Jmatrix.compute() - - Sim.compute_J = compute_J - - -def block_append( - simulation, - fields, - AdiagTinv, - field_derivs, - m_size, - row_count, - tInd, - solution_type, - Jmatrix, - Asubdiag, - source_blocks, - data_bool, -): - solves = AdiagTinv * np.hstack(list(field_derivs.values())) - count = 0 - - for (isrc, src), block in field_derivs.items(): - - column_inds = np.hstack([ - np.kron(np.ones(rec.locations.shape[0], dtype=bool), data_bool - ) for rec in src.receiver_list]) - - n_rows = column_inds.sum() - source_blocks.append( - dask.array.from_delayed( - delayed(parallel_block_compute, pure=True)( - simulation, - fields, - src, - solves[:, count : count + n_rows], - row_count, - tInd, - solution_type, - Jmatrix, - Asubdiag, - simulation.field_derivs[tInd][isrc], - column_inds, - ), - shape=simulation.field_derivs[tInd + 1][isrc].shape, - dtype=np.float32, - ) - ) - count += n_rows - # print(f"Appending block {isrc} in {time() - tc} seconds") - row_count += len(column_inds) - - return source_blocks - - -# def block_deriv(simulation, src, tInd, f, block_size, row_count): -# src_field_derivs = None -# for rx in src.receiver_list: -# v = sp.eye(rx.nD, dtype=float) -# PT_v = rx.evalDeriv( -# src, simulation.mesh, simulation.time_mesh, f, v, adjoint=True -# ) -# df_duTFun = getattr(f, "_{}Deriv".format(rx.projField), None) -# -# cur = df_duTFun( -# simulation.nT, -# src, -# None, -# PT_v[tInd * block_size: (tInd + 1) * block_size, :], -# adjoint=True, -# ) -# -# if not isinstance(cur[1], Zero): -# simulation.J_initializer[row_count: row_count + rx.nD, :] += cur[1].T -# -# if src_field_derivs is None: -# src_field_derivs = cur[0] -# else: -# src_field_derivs = sp.hstack([src_field_derivs, cur[0]]) -# -# row_count += rx.nD -# -# return src_field_derivs - - -def parallel_block_compute( - simulation, - f, - src, - ATinv_df_duT_v, - row_count, - tInd, - solution_type, - Jmatrix, - Asubdiag, - field_derivs, - data_bool, -): - rows = row_count + np.where(data_bool)[0] - field_derivs_t = np.asarray(field_derivs.todense()) - field_derivs_t[:, data_bool] -= Asubdiag.T * ATinv_df_duT_v - - dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( - tInd, f[src, solution_type, tInd], ATinv_df_duT_v, adjoint=True - ) - dRHST_dm_v = simulation.getRHSDeriv(tInd + 1, src, ATinv_df_duT_v, adjoint=True) - un_src = f[src, solution_type, tInd + 1] - dAT_dm_v = simulation.getAdiagDeriv(tInd, un_src, ATinv_df_duT_v, adjoint=True) - Jmatrix[rows, :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T - - return field_derivs_t From 2061306ba14debb5c82e5bc0c8ad5ef229cc7e58 Mon Sep 17 00:00:00 2001 From: fourndo Date: Wed, 29 Nov 2023 10:45:37 -0800 Subject: [PATCH 07/13] More optimization --- .../time_domain/simulation.py | 157 ++++++++++-------- 1 file changed, 86 insertions(+), 71 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 997908bebf..fe9b7cca08 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -155,31 +155,27 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): Sim.field_derivs = None -def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jmatrix): +@delayed +def block_deriv( + time_index, field_type, source, rx_count, mesh, time_mesh, fields, Jmatrix +): """Compute derivatives for sources and receivers in a block""" - field_len = len(fields[source_list[0], field_type, 0]) - df_duT = {src: {} for src in source_list} - - rx_count = 0 - for src in source_list: - df_duT[src] = {rx: {} for rx in src.receiver_list} - - for rx in src.receiver_list: - PTv = rx.getP(mesh, time_mesh, fields).tocsr() - derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) - rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) - - cur = derivative_fun( - time_index, - src, - None, - PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, - adjoint=True, - ) - df_duT[src][rx] = cur[0] - Jmatrix[rx_ind, :] += cur[1].T + field_len = len(fields[source, field_type, 0]) + df_duT = [] - rx_count += rx.nD + for rx in source.receiver_list: + PTv = rx.getP(mesh, time_mesh, fields).tocsr() + derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) + rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) + cur = derivative_fun( + time_index, + source, + None, + PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, + adjoint=True, + ) + df_duT.append(cur[0]) + Jmatrix[rx_ind, :] += cur[1].T return df_duT @@ -192,17 +188,24 @@ def compute_field_derivs(simulation, Jmatrix, fields): df_duT = [] for time_index in range(simulation.nT + 1): - df_duT.append( - delayed(block_deriv, pure=True)( - time_index, - simulation._fieldType + "Solution", - simulation.survey.source_list, - simulation.mesh, - simulation.time_mesh, - fields, - Jmatrix, + rx_count = 0 + sources_block = [] + for source in simulation.survey.source_list: + sources_block.append( + block_deriv( + time_index, + simulation._fieldType + "Solution", + source, + rx_count, + simulation.mesh, + simulation.time_mesh, + fields, + Jmatrix, + ) ) - ) + rx_count += source.nD + + df_duT.append(sources_block) df_duT = dask.compute(df_duT)[0] @@ -221,8 +224,8 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): row_index = 0 block_count = 0 blocks = {0: {}} - for src in source_list: - for rx in src.receiver_list: + for s_id, src in enumerate(source_list): + for r_id, rx in enumerate(src.receiver_list): indices = np.arange(rx.nD).astype(int) chunks = np.split(indices, int(np.ceil(len(indices) / data_block_size))) @@ -235,7 +238,7 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): block_count += 1 blocks[block_count] = {} - blocks[block_count][(src, rx, ind)] = chunk, np.arange( + blocks[block_count][(s_id, r_id, ind)] = chunk, np.arange( row_index, row_index + chunk_size ).astype(int) row_index += chunk_size @@ -253,61 +256,68 @@ def get_field_deriv_block( stacked_blocks = [] indices = {} count = 0 - for (src, rx, ind), (rx_ind, j_ind) in block.items(): + for (s_id, r_id, b_id), (rx_ind, j_ind) in block.items(): # Cut out early data + rx = simulation.survey.source_list[s_id].receiver_list[r_id] time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[ rx_ind ] sub_ind = rx_ind[time_check] + local_ind = np.arange(rx_ind.shape[0])[time_check] if len(sub_ind) < 1: continue - indices[(src, rx, ind)] = (np.arange(count, count + len(sub_ind)), sub_ind) + indices[(s_id, r_id, b_id)] = ( + np.arange(count, count + len(sub_ind)), + local_ind, + ) count += len(sub_ind) - if (src, rx, ind) not in ATinv_df_duT_v: + if (s_id, r_id, b_id) not in ATinv_df_duT_v: # last timestep (first to be solved) stacked_blocks.append( - simulation.field_derivs[tInd + 1][src][rx].toarray()[:, sub_ind] + simulation.field_derivs[tInd + 1][s_id][r_id].toarray()[:, sub_ind] ) else: Asubdiag = simulation.getAsubdiag(tInd + 1) stacked_blocks.append( np.asarray( - simulation.field_derivs[tInd + 1][src][rx][:, sub_ind] - - Asubdiag.T * ATinv_df_duT_v[(src, rx, ind)][:, sub_ind] + simulation.field_derivs[tInd + 1][s_id][r_id][:, sub_ind] + - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] ) ) - if len(stacked_blocks) > 1: + if len(stacked_blocks) > 0: solve = AdiagTinv * np.hstack(stacked_blocks) - for src, rx, ind in block: - ATinv_df_duT_v[(src, rx, ind)] = np.zeros( + for s_id, r_id, b_id in block: + ATinv_df_duT_v[(s_id, r_id, b_id)] = np.zeros( ( - simulation.field_derivs[tInd][src][rx].shape[0], - len(block[(src, rx, ind)][0]), + simulation.field_derivs[tInd][s_id][r_id].shape[0], + len(block[(s_id, r_id, b_id)][0]), ) ) - if (src, rx, ind) in indices: - columns, sub_ind = indices[(src, rx, ind)] - ATinv_df_duT_v[(src, rx, ind)][:, sub_ind] = solve[:, columns] + if (s_id, r_id, b_id) in indices: + try: + columns, local_ind = indices[(s_id, r_id, b_id)] + ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] = solve[:, columns] + except: + print("ouch") return ATinv_df_duT_v +@delayed def compute_rows( simulation, tInd, - src, - rx, - rx_ind, - j_ind, + address, # (s_id, r_id, b_id) + indices, # (rx_ind, j_ind), ATinv_df_duT_v, - f, + fields, Jmatrix, ftype, time_mask, @@ -315,27 +325,34 @@ def compute_rows( """ Compute the rows of the sensitivity matrix for a given source and receiver. """ - time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[rx_ind] - sub_ind = rx_ind[time_check] - - if len(sub_ind) < 1: + src = simulation.survey.source_list[address[0]] + rx = src.receiver_list[address[1]] + time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[ + indices[0] + ] + local_ind = np.arange(indices[0].shape[0])[time_check] + + if len(local_ind) < 1: return dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( - tInd, f[src, ftype, tInd], ATinv_df_duT_v[:, sub_ind], adjoint=True + tInd, + fields[src, ftype, tInd], + ATinv_df_duT_v[address][:, local_ind], + adjoint=True, ) dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, src, ATinv_df_duT_v[:, sub_ind], adjoint=True + tInd + 1, src, ATinv_df_duT_v[address][:, local_ind], adjoint=True ) # on nodes of time mesh - un_src = f[src, ftype, tInd + 1] + un_src = fields[src, ftype, tInd + 1] # cell centered on time mesh dAT_dm_v = simulation.getAdiagDeriv( - tInd, un_src, ATinv_df_duT_v[:, sub_ind], adjoint=True + tInd, un_src, ATinv_df_duT_v[address][:, local_ind], adjoint=True ) - Jmatrix[j_ind[time_check], :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T + Jmatrix[indices[1][time_check], :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T def compute_J(self, f=None, Ainv=None): @@ -367,16 +384,14 @@ def compute_J(self, f=None, Ainv=None): self, block, tInd, AdiagTinv, ATinv_df_duT_v, time_mask ) - for (src, rx, ind), (rx_ind, j_ind) in block.items(): + for address, indices in block.items(): j_row_updates.append( - delayed(compute_rows, pure=True)( + compute_rows( self, tInd, - src, - rx, - rx_ind, - j_ind, - ATinv_df_duT_v[(src, rx, ind)], + address, + indices, + ATinv_df_duT_v, f, Jmatrix, ftype, From 9ce095661029463567f14a55ad827a97c268225c Mon Sep 17 00:00:00 2001 From: fourndo Date: Wed, 29 Nov 2023 15:14:28 -0800 Subject: [PATCH 08/13] More parallel blocks --- .../time_domain/simulation.py | 151 +++++++++++------- 1 file changed, 90 insertions(+), 61 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index fe9b7cca08..2039fce1f3 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -156,26 +156,30 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): @delayed -def block_deriv( - time_index, field_type, source, rx_count, mesh, time_mesh, fields, Jmatrix -): +def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jmatrix): """Compute derivatives for sources and receivers in a block""" - field_len = len(fields[source, field_type, 0]) + field_len = len(fields[source_list[0], field_type, 0]) df_duT = [] + rx_count = 0 + for source in source_list: + sources_block = [] - for rx in source.receiver_list: - PTv = rx.getP(mesh, time_mesh, fields).tocsr() - derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) - rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) - cur = derivative_fun( - time_index, - source, - None, - PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, - adjoint=True, - ) - df_duT.append(cur[0]) - Jmatrix[rx_ind, :] += cur[1].T + for rx in source.receiver_list: + PTv = rx.getP(mesh, time_mesh, fields).tocsr() + derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) + rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) + cur = derivative_fun( + time_index, + source, + None, + PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, + adjoint=True, + ) + sources_block.append(cur[0]) + Jmatrix[rx_ind, :] += cur[1].T + rx_count += rx.nD + + df_duT.append(sources_block) return df_duT @@ -184,28 +188,20 @@ def compute_field_derivs(simulation, Jmatrix, fields): """ Compute the derivative of the fields """ - df_duT = [] for time_index in range(simulation.nT + 1): - rx_count = 0 - sources_block = [] - for source in simulation.survey.source_list: - sources_block.append( - block_deriv( - time_index, - simulation._fieldType + "Solution", - source, - rx_count, - simulation.mesh, - simulation.time_mesh, - fields, - Jmatrix, - ) + df_duT.append( + block_deriv( + time_index, + simulation._fieldType + "Solution", + simulation.survey.source_list, + simulation.mesh, + simulation.time_mesh, + fields, + Jmatrix, ) - rx_count += source.nD - - df_duT.append(sources_block) + ) df_duT = dask.compute(df_duT)[0] @@ -247,6 +243,30 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): return blocks +@delayed +def deriv_block( + s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, sub_ind, simulation, tInd +): + if (s_id, r_id, b_id) not in ATinv_df_duT_v: + # last timestep (first to be solved) + stacked_block = simulation.field_derivs[tInd + 1][s_id][r_id].toarray()[ + :, sub_ind + ] + + else: + stacked_block = np.asarray( + simulation.field_derivs[tInd + 1][s_id][r_id][:, sub_ind] + - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] + ) + + return stacked_block + + +def update_deriv_blocks(address, indices, derivatives, solve): + columns, local_ind = indices[address] + derivatives[:, local_ind] = solve[:, columns] + + def get_field_deriv_block( simulation, block: dict, tInd: int, AdiagTinv, ATinv_df_duT_v: dict, time_mask ): @@ -256,6 +276,11 @@ def get_field_deriv_block( stacked_blocks = [] indices = {} count = 0 + + Asubdiag = None + if tInd < simulation.nT - 1: + Asubdiag = simulation.getAsubdiag(tInd + 1) + for (s_id, r_id, b_id), (rx_ind, j_ind) in block.items(): # Cut out early data rx = simulation.survey.source_list[s_id].receiver_list[r_id] @@ -274,38 +299,44 @@ def get_field_deriv_block( ) count += len(sub_ind) - if (s_id, r_id, b_id) not in ATinv_df_duT_v: - # last timestep (first to be solved) - stacked_blocks.append( - simulation.field_derivs[tInd + 1][s_id][r_id].toarray()[:, sub_ind] - ) - - else: - Asubdiag = simulation.getAsubdiag(tInd + 1) - stacked_blocks.append( - np.asarray( - simulation.field_derivs[tInd + 1][s_id][r_id][:, sub_ind] - - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] - ) + stacked_blocks.append( + deriv_block( + s_id, + r_id, + b_id, + ATinv_df_duT_v, + Asubdiag, + local_ind, + sub_ind, + simulation, + tInd, ) + ) if len(stacked_blocks) > 0: - solve = AdiagTinv * np.hstack(stacked_blocks) + solve = AdiagTinv * np.hstack(dask.compute(stacked_blocks)[0]) + update_list = [] for s_id, r_id, b_id in block: - ATinv_df_duT_v[(s_id, r_id, b_id)] = np.zeros( - ( - simulation.field_derivs[tInd][s_id][r_id].shape[0], - len(block[(s_id, r_id, b_id)][0]), + if (s_id, r_id, b_id) not in ATinv_df_duT_v: + ATinv_df_duT_v[(s_id, r_id, b_id)] = np.zeros( + ( + simulation.field_derivs[tInd][s_id][r_id].shape[0], + len(block[(s_id, r_id, b_id)][0]), + ) ) - ) if (s_id, r_id, b_id) in indices: - try: - columns, local_ind = indices[(s_id, r_id, b_id)] - ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] = solve[:, columns] - except: - print("ouch") + update_list.append( + update_deriv_blocks( + (s_id, r_id, b_id), + indices, + ATinv_df_duT_v[(s_id, r_id, b_id)], + solve, + ) + ) + + dask.compute(update_list) return ATinv_df_duT_v @@ -370,9 +401,7 @@ def compute_J(self, f=None, Ainv=None): blocks = get_parallel_blocks( self.survey.source_list, self.model.shape[0], self.max_chunk_size ) - self.field_derivs = compute_field_derivs(self, Jmatrix, f) - ATinv_df_duT_v = {} for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): AdiagTinv = Ainv[dt] From 2d98d0a1d1e670b84666536555b71629a1b71a5a Mon Sep 17 00:00:00 2001 From: fourndo Date: Wed, 29 Nov 2023 16:05:59 -0800 Subject: [PATCH 09/13] Slight change --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 2039fce1f3..85d8daf66d 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -264,7 +264,7 @@ def deriv_block( def update_deriv_blocks(address, indices, derivatives, solve): columns, local_ind = indices[address] - derivatives[:, local_ind] = solve[:, columns] + derivatives[address][:, local_ind] = solve[:, columns] def get_field_deriv_block( @@ -331,7 +331,7 @@ def get_field_deriv_block( update_deriv_blocks( (s_id, r_id, b_id), indices, - ATinv_df_duT_v[(s_id, r_id, b_id)], + ATinv_df_duT_v, solve, ) ) From aa7d2c242a09c2860cf9d03e1cb46c6738b54087 Mon Sep 17 00:00:00 2001 From: fourndo Date: Wed, 29 Nov 2023 23:24:42 -0800 Subject: [PATCH 10/13] Delay large arrays --- .../time_domain/simulation.py | 81 ++++++++++--------- 1 file changed, 42 insertions(+), 39 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 85d8daf66d..f72a223a95 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -229,7 +229,7 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): chunk_size = len(chunk) # Condition to start a new block - if (row_count + chunk_size) > (data_block_size * cpu_count() / 2): + if (row_count + chunk_size) > (data_block_size * cpu_count()): row_count = 0 block_count += 1 blocks[block_count] = {} @@ -262,9 +262,17 @@ def deriv_block( return stacked_block -def update_deriv_blocks(address, indices, derivatives, solve): - columns, local_ind = indices[address] - derivatives[address][:, local_ind] = solve[:, columns] +def update_deriv_blocks(address, tInd, indices, derivatives, solve, shape): + if address not in derivatives: + deriv_array = np.zeros(shape) + else: + deriv_array = derivatives[address].compute() + + if address in indices: + columns, local_ind = indices[address] + deriv_array[:, local_ind] = solve[:, columns] + + derivatives[address] = delayed(deriv_array) def get_field_deriv_block( @@ -298,44 +306,41 @@ def get_field_deriv_block( local_ind, ) count += len(sub_ind) + deriv_comp = deriv_block( + s_id, + r_id, + b_id, + ATinv_df_duT_v, + Asubdiag, + local_ind, + sub_ind, + simulation, + tInd, + ) stacked_blocks.append( - deriv_block( - s_id, - r_id, - b_id, - ATinv_df_duT_v, - Asubdiag, - local_ind, - sub_ind, - simulation, - tInd, + array.from_delayed( + deriv_comp, + dtype=float, + shape=( + simulation.field_derivs[tInd][s_id][r_id].shape[0], + len(local_ind), + ), ) ) - if len(stacked_blocks) > 0: - solve = AdiagTinv * np.hstack(dask.compute(stacked_blocks)[0]) + blocks = array.hstack(stacked_blocks).compute() + solve = AdiagTinv * blocks update_list = [] - for s_id, r_id, b_id in block: - if (s_id, r_id, b_id) not in ATinv_df_duT_v: - ATinv_df_duT_v[(s_id, r_id, b_id)] = np.zeros( - ( - simulation.field_derivs[tInd][s_id][r_id].shape[0], - len(block[(s_id, r_id, b_id)][0]), - ) - ) - - if (s_id, r_id, b_id) in indices: - update_list.append( - update_deriv_blocks( - (s_id, r_id, b_id), - indices, - ATinv_df_duT_v, - solve, - ) - ) - + for address in block: + shape = ( + simulation.field_derivs[tInd][address[0]][address[1]].shape[0], + len(block[address][0]), + ) + update_list.append( + update_deriv_blocks(address, tInd, indices, ATinv_df_duT_v, solve, shape) + ) dask.compute(update_list) return ATinv_df_duT_v @@ -395,7 +400,7 @@ def compute_J(self, f=None, Ainv=None): f, Ainv = self.fields(self.model, return_Ainv=True) ftype = self._fieldType + "Solution" - Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float32) + Jmatrix = delayed(np.zeros((self.survey.nD, self.model.size), dtype=np.float32)) simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 data_times = self.survey.source_list[0].receiver_list[0].times blocks = get_parallel_blocks( @@ -427,9 +432,7 @@ def compute_J(self, f=None, Ainv=None): time_mask, ) ) - dask.compute(j_row_updates) - for A in Ainv.values(): A.clean() @@ -437,7 +440,7 @@ def compute_J(self, f=None, Ainv=None): del Jmatrix return array.from_zarr(self.sensitivity_path + f"J.zarr") else: - return Jmatrix + return Jmatrix.compute() Sim.compute_J = compute_J From 3e26bb32aa3f0e5eeb04f5c08c7b044bd1b00776 Mon Sep 17 00:00:00 2001 From: fourndo Date: Thu, 30 Nov 2023 10:44:58 -0800 Subject: [PATCH 11/13] Fix array splitting --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index f72a223a95..7d4e399b04 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -223,7 +223,9 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): for s_id, src in enumerate(source_list): for r_id, rx in enumerate(src.receiver_list): indices = np.arange(rx.nD).astype(int) - chunks = np.split(indices, int(np.ceil(len(indices) / data_block_size))) + chunks = np.array_split( + indices, int(np.ceil(len(indices) / data_block_size)) + ) for ind, chunk in enumerate(chunks): chunk_size = len(chunk) From adb6d579db66dedb73e72afc332c9118f4091709 Mon Sep 17 00:00:00 2001 From: fourndo Date: Thu, 30 Nov 2023 10:48:02 -0800 Subject: [PATCH 12/13] Bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 160c6169d5..a26dbc830e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ [tool.poetry] name = "Mira-SimPEG" -version = "0.19.0.dev4" +version = "0.19.0.dev5" license = "MIT" description = "Mira Geoscience fork of SimPEG: Simulation and Parameter Estimation in Geophysics" From fa7b526e359e5f59332836dd54707a62c4486c17 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 30 Nov 2023 12:42:17 -0800 Subject: [PATCH 13/13] Fix for empty set --- SimPEG/dask/electromagnetics/time_domain/simulation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 7d4e399b04..d483e57436 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -333,6 +333,8 @@ def get_field_deriv_block( if len(stacked_blocks) > 0: blocks = array.hstack(stacked_blocks).compute() solve = AdiagTinv * blocks + else: + solve = None update_list = [] for address in block: