Sim.field_derivs = None


@delayed
def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jmatrix):
    """
    Compute the adjoint field derivatives of every receiver for one time index.

    The function is wrapped in ``delayed``: calling it creates a lazy task.

    Parameters
    ----------
    time_index :
        Index of the time step whose slice of the projection is differentiated.
    field_type :
        Key of the solution field, used as ``fields[source, field_type, 0]``
        to measure the length of one time slice.
    source_list :
        Sources to loop over; each must expose ``receiver_list``.
    mesh, time_mesh :
        Forwarded unchanged to ``rx.getP``.
    fields :
        Fields object; must expose a ``_<projField>Deriv`` method for every
        receiver's ``projField``.
    Jmatrix :
        Array updated **in place**: ``cur[1].T`` is added to the rows that
        belong to each receiver.

    Returns
    -------
    list of list
        One inner list per source, holding the first element returned by the
        derivative function for each receiver of that source.
    """
    # Length of a single time slice of the solution vector.
    field_len = len(fields[source_list[0], field_type, 0])
    df_duT = []
    rx_count = 0  # running row offset of the current receiver in Jmatrix
    for source in source_list:
        sources_block = []

        for rx in source.receiver_list:
            PTv = rx.getP(mesh, time_mesh, fields).tocsr()
            derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None)
            rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int)
            cur = derivative_fun(
                time_index,
                source,
                None,
                # Columns of the projection belonging to this time index.
                PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T,
                adjoint=True,
            )
            sources_block.append(cur[0])
            # NOTE(review): tasks created for different time indices all add
            # into the same rows of Jmatrix; confirm the scheduler in use makes
            # these in-place updates safe.
            Jmatrix[rx_ind, :] += cur[1].T
            rx_count += rx.nD

        df_duT.append(sources_block)

    return df_duT


def compute_field_derivs(simulation, Jmatrix, fields):
    """
    Compute the derivative of the fields for every time index.

    One ``block_deriv`` task is created per time index in
    ``range(simulation.nT + 1)`` and all tasks are evaluated with a single
    ``dask.compute`` call.

    Parameters
    ----------
    simulation :
        Object providing ``nT``, ``_fieldType``, ``survey.source_list``,
        ``mesh`` and ``time_mesh``.
    Jmatrix :
        Sensitivity array handed to every ``block_deriv`` task.
    fields :
        Fields object handed to every ``block_deriv`` task.

    Returns
    -------
    Sequence indexed as ``[time_index][source][receiver]``.
    """
    tasks = [
        block_deriv(
            time_index,
            simulation._fieldType + "Solution",
            simulation.survey.source_list,
            simulation.mesh,
            simulation.time_mesh,
            fields,
            Jmatrix,
        )
        for time_index in range(simulation.nT + 1)
    ]

    return dask.compute(tasks)[0]


def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int):
    """
    Get the blocks of sources and receivers to be computed in parallel.

    Parameters
    ----------
    source_list :
        Sources; each must expose ``receiver_list`` whose items expose ``nD``.
    m_size :
        Number of model parameters (columns of the sensitivity matrix).
    max_chunk_size :
        Size budget of one chunk. The ``8.0 * 1e-6`` factor presumably converts
        float64 values to megabytes -- TODO confirm the unit with the caller.

    Returns
    -------
    dict
        ``{block_id: {(source_id, receiver_id, chunk_id): (chunk, rows)}}``
        where ``chunk`` holds the receiver-local data indices of the chunk and
        ``rows`` the matching global row indices of the sensitivity matrix.
        Block ``0`` always exists, even when there is no data.
    """
    data_block_size = np.ceil(max_chunk_size / (m_size * 8.0 * 1e-6))
    # A block is closed once it would hold more rows than this.
    rows_per_block = data_block_size * cpu_count()
    row_count = 0  # rows in the block currently being filled
    row_index = 0  # next free global row of the sensitivity matrix
    block_count = 0
    blocks = {0: {}}
    for s_id, src in enumerate(source_list):
        for r_id, rx in enumerate(src.receiver_list):
            if rx.nD == 0:
                # np.array_split refuses zero sections; a receiver without
                # data contributes no rows, so there is nothing to register.
                continue

            indices = np.arange(rx.nD).astype(int)
            chunks = np.array_split(
                indices, int(np.ceil(len(indices) / data_block_size))
            )

            for ind, chunk in enumerate(chunks):
                chunk_size = len(chunk)

                # Condition to start a new block
                if (row_count + chunk_size) > rows_per_block:
                    row_count = 0
                    block_count += 1
                    blocks[block_count] = {}

                blocks[block_count][(s_id, r_id, ind)] = chunk, np.arange(
                    row_index, row_index + chunk_size
                ).astype(int)
                row_index += chunk_size
                row_count += chunk_size

    return blocks
@delayed
def deriv_block(
    s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, sub_ind, simulation, tInd
):
    """
    Build the right-hand-side block of one (source, receiver, chunk) address.

    When the address has no entry in ``ATinv_df_duT_v`` yet, the block is the
    dense field derivative at ``tInd + 1`` restricted to columns ``sub_ind``.
    Otherwise ``Asubdiag.T`` applied to columns ``local_ind`` of the stored
    entry is subtracted from that field derivative.

    The function is wrapped in ``delayed``: calling it creates a lazy task.
    """
    key = (s_id, r_id, b_id)
    field_deriv = simulation.field_derivs[tInd + 1][s_id][r_id]

    if key in ATinv_df_duT_v:
        previous = ATinv_df_duT_v[key][:, local_ind]
        return np.asarray(field_deriv[:, sub_ind] - Asubdiag.T * previous)

    # last timestep (first to be solved)
    return field_deriv.toarray()[:, sub_ind]


def update_deriv_blocks(address, tInd, indices, derivatives, solve, shape):
    """
    Store the solved columns of one address back into ``derivatives``.

    The existing entry is evaluated with ``.compute()``, or a zero array of
    ``shape`` is created when the address is new. If the address appears in
    ``indices``, its ``local_ind`` columns are overwritten with the matching
    ``columns`` of ``solve``. The array is then re-wrapped with ``delayed`` and
    written to ``derivatives[address]``. Nothing is returned and ``tInd`` is
    not used.
    """
    target = (
        derivatives[address].compute()
        if address in derivatives
        else np.zeros(shape)
    )

    if address in indices:
        solve_columns, target_columns = indices[address]
        target[:, target_columns] = solve[:, solve_columns]

    derivatives[address] = delayed(target)
def get_field_deriv_block(
    simulation, block: dict, tInd: int, AdiagTinv, ATinv_df_duT_v: dict, time_mask
):
    """
    Stack the blocks of field derivatives for a given timestep and call the direct solver.

    Parameters
    ----------
    simulation :
        Object providing ``nT``, ``getAsubdiag``, ``survey`` and ``field_derivs``.
    block :
        ``{(source_id, receiver_id, chunk_id): (rx_ind, j_ind)}`` as produced
        by ``get_parallel_blocks``.
    tInd :
        Index of the time step being solved.
    AdiagTinv :
        Solver object applied with ``AdiagTinv * rhs``.
    ATinv_df_duT_v :
        Per-address storage of the previous solves; updated **in place**.
    time_mask :
        Boolean mask over data times; data failing the mask are skipped.

    Returns
    -------
    dict
        The same ``ATinv_df_duT_v`` object, after update.
    """
    stacked_blocks = []
    indices = {}
    count = 0  # running column offset inside the stacked right-hand side

    Asubdiag = None
    if tInd < simulation.nT - 1:
        Asubdiag = simulation.getAsubdiag(tInd + 1)

    for (s_id, r_id, b_id), (rx_ind, _) in block.items():
        # Cut out early data
        rx = simulation.survey.source_list[s_id].receiver_list[r_id]
        time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[
            rx_ind
        ]
        sub_ind = rx_ind[time_check]
        local_ind = np.arange(rx_ind.shape[0])[time_check]

        if len(sub_ind) < 1:
            continue

        # (columns in the stacked solve, columns in the per-address array)
        indices[(s_id, r_id, b_id)] = (
            np.arange(count, count + len(sub_ind)),
            local_ind,
        )
        count += len(sub_ind)
        deriv_comp = deriv_block(
            s_id,
            r_id,
            b_id,
            ATinv_df_duT_v,
            Asubdiag,
            local_ind,
            sub_ind,
            simulation,
            tInd,
        )

        stacked_blocks.append(
            array.from_delayed(
                deriv_comp,
                dtype=float,
                shape=(
                    simulation.field_derivs[tInd][s_id][r_id].shape[0],
                    len(local_ind),
                ),
            )
        )

    solve = None
    if stacked_blocks:
        # One solver call for all active columns of the block.
        solve = AdiagTinv * array.hstack(stacked_blocks).compute()

    # update_deriv_blocks works eagerly and returns None, so a plain loop is
    # enough; no dask.compute call is needed here.
    for address, (rx_ind, _) in block.items():
        shape = (
            simulation.field_derivs[tInd][address[0]][address[1]].shape[0],
            len(rx_ind),
        )
        update_deriv_blocks(address, tInd, indices, ATinv_df_duT_v, solve, shape)

    return ATinv_df_duT_v


@delayed
def compute_rows(
    simulation,
    tInd,
    address,  # (s_id, r_id, b_id)
    indices,  # (rx_ind, j_ind),
    ATinv_df_duT_v,
    fields,
    Jmatrix,
    ftype,
    time_mask,
):
    """
    Compute the rows of the sensitivity matrix for a given source and receiver.

    Adds ``(-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T`` **in place** to the
    rows ``indices[1][time_check]`` of ``Jmatrix``. Returns ``None``; when no
    datum of the chunk passes ``time_mask`` nothing is modified.

    The function is wrapped in ``delayed``: calling it creates a lazy task.
    """
    src = simulation.survey.source_list[address[0]]
    rx = src.receiver_list[address[1]]
    time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[
        indices[0]
    ]
    local_ind = np.arange(indices[0].shape[0])[time_check]

    if len(local_ind) < 1:
        return

    # Same column selection feeds all three derivative calls; slice it once.
    active_solve = ATinv_df_duT_v[address][:, local_ind]

    dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
        tInd,
        fields[src, ftype, tInd],
        active_solve,
        adjoint=True,
    )

    dRHST_dm_v = simulation.getRHSDeriv(
        tInd + 1, src, active_solve, adjoint=True
    )  # on nodes of time mesh

    un_src = fields[src, ftype, tInd + 1]
    # cell centered on time mesh
    dAT_dm_v = simulation.getAdiagDeriv(tInd, un_src, active_solve, adjoint=True)

    Jmatrix[indices[1][time_check], :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T


def compute_J(self, f=None, Ainv=None):
    """
    Compute the rows for the sensitivity matrix.

    Parameters
    ----------
    f : optional
        Fields for ``self.model``. Computed with ``self.fields`` when missing.
    Ainv : optional
        Mapping from time-step length to the solver object for that step.
        When it is missing the fields are recomputed so that ``f`` and
        ``Ainv`` are available together.

    Returns
    -------
    The sensitivity matrix of shape ``(survey.nD, model.size)``: the in-memory
    array, or a dask array backed by ``<sensitivity_path>J.zarr`` when
    ``self.store_sensitivities == "disk"``.
    """

    if f is None or Ainv is None:
        # The time loop below indexes Ainv, so both are required.
        f, Ainv = self.fields(self.model, return_Ainv=True)

    ftype = self._fieldType + "Solution"
    Jmatrix = delayed(np.zeros((self.survey.nD, self.model.size), dtype=np.float32))
    simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0
    # NOTE(review): the time channels of the first receiver are used for the
    # whole survey -- confirm all receivers share the same times.
    data_times = self.survey.source_list[0].receiver_list[0].times
    blocks = get_parallel_blocks(
        self.survey.source_list, self.model.shape[0], self.max_chunk_size
    )
    self.field_derivs = compute_field_derivs(self, Jmatrix, f)
    ATinv_df_duT_v = {}

    # March backwards through the time steps.
    for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))):
        AdiagTinv = Ainv[dt]
        j_row_updates = []
        time_mask = data_times > simulation_times[tInd]

        for block in blocks.values():
            ATinv_df_duT_v = get_field_deriv_block(
                self, block, tInd, AdiagTinv, ATinv_df_duT_v, time_mask
            )

            for address, indices in block.items():
                j_row_updates.append(
                    compute_rows(
                        self,
                        tInd,
                        address,
                        indices,
                        ATinv_df_duT_v,
                        f,
                        Jmatrix,
                        ftype,
                        time_mask,
                    )
                )
        # Apply the row updates of this time step before moving on.
        dask.compute(j_row_updates)

    for A in Ainv.values():
        A.clean()

    J = Jmatrix.compute()

    if self.store_sensitivities == "disk":
        # Nothing above writes the store, so persist J before handing back a
        # disk-backed view of it.
        path = self.sensitivity_path + "J.zarr"
        array.from_array(J).to_zarr(path, overwrite=True)
        return array.from_zarr(path)

    return J


Sim.compute_J = compute_J

# ---------------------------------------------------------------------------
# Companion changes recorded in the same change set (other files):
#
# SimPEG/directives/directives.py, ``save_components``, "predicted" branch,
# when ``self.invProb.dpred`` is None:
#     fields = self.invProb.dmisfit.objfcts[0].simulation.fields(self.invProb.model)
#     dpred = self.invProb.get_dpred(self.invProb.model, [fields])
# (previously: dpred = self.invProb.get_dpred(self.invProb.model))
#
# pyproject.toml: version "0.19.0.dev4" -> "0.19.0.dev5".
# ---------------------------------------------------------------------------