diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 14eab5c957..fee66abde2 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -1,12 +1,14 @@ import dask import dask.array - +import os from ....electromagnetics.time_domain.simulation import BaseTDEMSimulation as Sim from ....utils import Zero +from SimPEG.fields import TimeFields from multiprocessing import cpu_count import numpy as np import scipy.sparse as sp from dask import array, delayed + from SimPEG.dask.simulation import dask_Jvec, dask_Jtvec, dask_getJtJdiag from SimPEG.dask.utils import get_parallel_blocks import zarr @@ -22,6 +24,76 @@ Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] +@delayed +def field_projection(field_array, src_list, array_ind, time_ind, func): + fieldI = field_array[:, :, array_ind] + if fieldI.shape[0] == fieldI.size: + fieldI = mkvc(fieldI, 2) + new_array = func(fieldI, src_list, time_ind) + if new_array.ndim == 1: + new_array = new_array[:, np.newaxis, np.newaxis] + elif new_array.ndim == 2: + new_array = new_array[:, :, np.newaxis] + + return new_array + + +def _getField(self, name, ind, src_list): + srcInd, timeInd = ind + + if name in self._fields: + out = self._fields[name][:, srcInd, timeInd] + else: + # Aliased fields + alias, loc, func = self.aliasFields[name] + if isinstance(func, str): + assert hasattr(self, func), ( + "The alias field function is a string, but it does " + "not exist in the Fields class." + ) + func = getattr(self, func) + pointerFields = self._fields[alias][:, srcInd, timeInd] + pointerShape = self._correctShape(alias, ind) + pointerFields = pointerFields.reshape(pointerShape, order="F") + + # First try to return the function as three arguments (without timeInd) + if timeInd == slice(None, None, None): + try: + # assume it will take care of integrating over all times + return func(pointerFields, srcInd) + except TypeError: + pass + + timeII = np.arange(self.simulation.nT + 1)[timeInd] + if not isinstance(src_list, list): + src_list = [src_list] + + if timeII.size == 1: + pointerShapeDeflated = self._correctShape(alias, ind, deflate=True) + pointerFields = pointerFields.reshape(pointerShapeDeflated, order="F") + out = func(pointerFields, src_list, timeII) + else: # loop over the time steps + nT = pointerShape[2] + arrays = [] + + for i, TIND_i in enumerate(timeII): # Need to parallelize this + arrays.append( + array.from_delayed( + field_projection(pointerFields, src_list, i, TIND_i, func), + dtype=np.float32, + shape=(pointerShape[0], pointerShape[1], 1), + ) + ) + + out = array.dstack(arrays).compute() + + shape = self._correctShape(name, ind, deflate=True) + return out.reshape(shape, order="F") + + +TimeFields._getField = _getField + + def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m @@ -102,6 +174,19 @@ def source_evaluation(simulation, sources, time): Sim.getSourceTerm = dask_getSourceTerm +@delayed +def evaluate_receivers(block, mesh, time_mesh, fields, fields_array): + data = [] + for source, ind, receiver in block: + Ps = receiver.getSpatialP(mesh, fields) + Pt = receiver.getTimeP(time_mesh, fields) + vector = (Pt * (Ps * fields_array[:, ind, :]).T).flatten() + + data.append(vector) + + return np.hstack(data) + + def dask_dpred(self, m=None, f=None, compute_J=False): """ dpred(m, f=None) @@ -121,27 +206,35 @@ def dask_dpred(self, m=None, f=None, compute_J=False): "data. Please set the survey for the simulation: " "simulation.survey = survey" ) - # ct = time() + if f is None: if m is None: m = self.model f, Ainv = self.fields(m, return_Ainv=compute_J) - # print(f"took {time() - ct} s to compute fields") - def evaluate_receiver(source, receiver, mesh, time_mesh, fields): - return receiver.eval(source, mesh, time_mesh, fields).flatten() - - row = delayed(evaluate_receiver, pure=True) rows = [] - for src in self.survey.source_list: + receiver_projection = self.survey.source_list[0].receiver_list[0].projField + fields_array = f[:, receiver_projection, :] + all_receivers = [] + + for ind, src in enumerate(self.survey.source_list): for rx in src.receiver_list: - rows.append( - array.from_delayed( - row(src, rx, self.mesh, self.time_mesh, f), - dtype=np.float32, - shape=(rx.nD,), - ) + all_receivers.append((src, ind, rx)) + + receiver_blocks = np.array_split(all_receivers, cpu_count()) + + for block in receiver_blocks: + n_data = np.sum(rec.nD for _, _, rec in block) + if n_data == 0: + continue + + rows.append( + array.from_delayed( + evaluate_receivers(block, self.mesh, self.time_mesh, f, fields_array), + dtype=np.float64, + shape=(n_data,), ) + ) data = array.hstack(rows).compute() @@ -157,95 +250,136 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): @delayed -def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jmatrix): +def delayed_block_deriv( + n_times, chunks, field_len, source_list, mesh, time_mesh, fields, shape +): """Compute derivatives for sources and receivers in a block""" - field_len = len(fields[source_list[0], field_type, 0]) df_duT = [] - rx_count = 0 - for source in source_list: - sources_block = [] - - for rx in source.receiver_list: - PTv = rx.getP(mesh, time_mesh, fields).tocsr() - derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) - rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) + j_updates = [] + + for indices, arrays in chunks: + j_update = 0.0 + source = source_list[indices[0]] + receiver = source.receiver_list[indices[1]] + + spatialP = receiver.getSpatialP(mesh, fields) + timeP = receiver.getTimeP(time_mesh, fields) + + derivative_fun = getattr(fields, "_{}Deriv".format(receiver.projField), None) + time_derivs = [] + for time_index in range(n_times + 1): + if len(timeP[:, time_index].data) == 0: + time_derivs.append( + sp.csr_matrix((field_len, len(arrays[0])), dtype=np.float32) + ) + j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) + continue + + projection = sp.kron(timeP[:, time_index], spatialP) cur = derivative_fun( time_index, source, None, - PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, + projection.T, adjoint=True, ) - sources_block.append(cur[0]) + time_derivs.append(cur[0]) if not isinstance(cur[1], Zero): - Jmatrix[rx_ind, :] += cur[1].T - - rx_count += rx.nD + j_update += cur[1].T + else: + j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32) - df_duT.append(sources_block) + j_updates.append(j_update) + df_duT.append(time_derivs) - return df_duT + return df_duT, j_updates -def compute_field_derivs(simulation, Jmatrix, fields): +def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape): """ Compute the derivative of the fields """ - df_duT = [] + delayed_chunks = [] + for chunks in blocks: + if len(chunks) == 0: + continue - for time_index in range(simulation.nT + 1): - df_duT.append( - block_deriv( - time_index, - simulation._fieldType + "Solution", - simulation.survey.source_list, - simulation.mesh, - simulation.time_mesh, - fields, - Jmatrix, - ) + delayed_block = delayed_block_deriv( + simulation.nT, + chunks, + fields_shape[0], + simulation.survey.source_list, + simulation.mesh, + simulation.time_mesh, + fields, + simulation.model.size, ) + delayed_chunks.append(delayed_block) + + result = dask.compute(delayed_chunks)[0] + df_duT = [ + [[[] for _ in block] for block in blocks if len(block) > 0] + for _ in range(simulation.nT + 1) + ] + j_updates = [] + + for bb, block in enumerate(result): + j_updates += block[1] + for cc, chunk in enumerate(block[0]): + for ind, time_block in enumerate(chunk): + df_duT[ind][bb][cc] = time_block - df_duT = dask.compute(df_duT)[0] + j_updates = sp.vstack(j_updates) - return df_duT + if len(j_updates.data) > 0: + Jmatrix += j_updates + if simulation.store_sensitivities == "disk": + sens_name = simulation.sensitivity_path[:-5] + f"_{time_index % 2}.zarr" + array.to_zarr(Jmatrix, sens_name, compute=True, overwrite=True) + Jmatrix = array.from_zarr(sens_name) + + return df_duT, Jmatrix @delayed def deriv_block( - s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, sub_ind, simulation, tInd + s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, sub_ind, field_derivs, tInd ): if (s_id, r_id, b_id) not in ATinv_df_duT_v: # last timestep (first to be solved) - stacked_block = simulation.field_derivs[tInd + 1][s_id][r_id].toarray()[ - :, sub_ind - ] + stacked_block = field_derivs.toarray()[:, sub_ind] else: stacked_block = np.asarray( - simulation.field_derivs[tInd + 1][s_id][r_id][:, sub_ind] + field_derivs[:, sub_ind] - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] ) return stacked_block -def update_deriv_blocks(address, tInd, indices, derivatives, solve, shape): +def update_deriv_blocks(address, indices, derivatives, solve, shape): if address not in derivatives: deriv_array = np.zeros(shape) else: - deriv_array = derivatives[address].compute() + deriv_array = derivatives[address] if address in indices: columns, local_ind = indices[address] deriv_array[:, local_ind] = solve[:, columns] - derivatives[address] = delayed(deriv_array) + derivatives[address] = deriv_array def get_field_deriv_block( - simulation, block: dict, tInd: int, AdiagTinv, ATinv_df_duT_v: dict, time_mask + simulation, + block: list, + field_derivs: list, + tInd: int, + AdiagTinv, + ATinv_df_duT_v: dict, + time_mask, ): """ Stack the blocks of field derivatives for a given timestep and call the direct solver. @@ -258,12 +392,11 @@ def get_field_deriv_block( if tInd < simulation.nT - 1: Asubdiag = simulation.getAsubdiag(tInd + 1) - for (s_id, r_id, b_id), (rx_ind, j_ind) in block.items(): + for ((s_id, r_id, b_id), (rx_ind, j_ind, shape)), field_deriv in zip( + block, field_derivs + ): # Cut out early data - rx = simulation.survey.source_list[s_id].receiver_list[r_id] - time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[ - rx_ind - ] + time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind] sub_ind = rx_ind[time_check] local_ind = np.arange(rx_ind.shape[0])[time_check] @@ -283,7 +416,7 @@ def get_field_deriv_block( Asubdiag, local_ind, sub_ind, - simulation, + field_deriv, tInd, ) @@ -292,27 +425,25 @@ def get_field_deriv_block( deriv_comp, dtype=float, shape=( - simulation.field_derivs[tInd][s_id][r_id].shape[0], + field_deriv.shape[0], len(local_ind), ), ) ) if len(stacked_blocks) > 0: blocks = array.hstack(stacked_blocks).compute() - solve = AdiagTinv * blocks + solve = (AdiagTinv * blocks).reshape(blocks.shape) else: solve = None update_list = [] - for address in block: + for (address, arrays), field_deriv in zip(block, field_derivs): shape = ( - simulation.field_derivs[tInd][address[0]][address[1]].shape[0], - len(block[address][0]), + field_deriv.shape[0], + len(arrays[0]), ) - update_list.append( - update_deriv_blocks(address, tInd, indices, ATinv_df_duT_v, solve, shape) - ) - dask.compute(update_list) + + update_deriv_blocks(address, indices, ATinv_df_duT_v, solve, shape) return ATinv_df_duT_v @@ -321,97 +452,140 @@ def get_field_deriv_block( def compute_rows( simulation, tInd, - address, # (s_id, r_id, b_id) - indices, # (rx_ind, j_ind), + chunks, ATinv_df_duT_v, fields, - Jmatrix, - ftype, time_mask, ): """ Compute the rows of the sensitivity matrix for a given source and receiver. """ - src = simulation.survey.source_list[address[0]] - rx = src.receiver_list[address[1]] - time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[ - indices[0] - ] - local_ind = np.arange(indices[0].shape[0])[time_check] + n_rows = np.sum(len(chunk[1][0]) for chunk in chunks) + rows = [] - if len(local_ind) < 1: - return + for address, ind_array in chunks: + src = simulation.survey.source_list[address[0]] + time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] + local_ind = np.arange(len(ind_array[0]))[time_check] - dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( - tInd, - fields[src, ftype, tInd], - ATinv_df_duT_v[address][:, local_ind], - adjoint=True, - ) + if len(local_ind) < 1: + return + + field_derivs = ATinv_df_duT_v[address] + dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( + tInd, + fields[:, address[0], tInd], + field_derivs[:, local_ind], + adjoint=True, + ) - dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, src, ATinv_df_duT_v[address][:, local_ind], adjoint=True - ) # on nodes of time mesh + dRHST_dm_v = simulation.getRHSDeriv( + tInd + 1, src, field_derivs[:, local_ind], adjoint=True + ) # on nodes of time mesh - un_src = fields[src, ftype, tInd + 1] - # cell centered on time mesh - dAT_dm_v = simulation.getAdiagDeriv( - tInd, un_src, ATinv_df_duT_v[address][:, local_ind], adjoint=True - ) + un_src = fields[:, address[0], tInd + 1] + # cell centered on time mesh + dAT_dm_v = simulation.getAdiagDeriv( + tInd, un_src, field_derivs[:, local_ind], adjoint=True + ) + row_block = np.zeros( + (len(ind_array[1]), simulation.model.size), dtype=np.float32 + ) + row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype( + np.float32 + ) + rows.append(row_block) - Jmatrix[indices[1][time_check], :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T + return np.vstack(rows) def compute_J(self, f=None, Ainv=None): """ Compute the rows for the sensitivity matrix. """ - if f is None: f, Ainv = self.fields(self.model, return_Ainv=True) ftype = self._fieldType + "Solution" - Jmatrix = delayed(np.zeros((self.survey.nD, self.model.size), dtype=np.float32)) + sens_name = self.sensitivity_path[:-5] + if self.store_sensitivities == "disk": + rows = array.zeros( + (self.survey.nD, self.model.size), + chunks=(self.max_chunk_size, self.model.size), + dtype=np.float32, + ) + Jmatrix = array.to_zarr( + rows, + os.path.join(sens_name + "_1.zarr"), + compute=True, + return_stored=True, + overwrite=True, + ) + else: + Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float64) + simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 data_times = self.survey.source_list[0].receiver_list[0].times blocks = get_parallel_blocks( self.survey.source_list, self.model.shape[0], self.max_chunk_size ) - self.field_derivs = compute_field_derivs(self, Jmatrix, f) + fields_array = f[:, ftype, :] + times_field_derivs, Jmatrix = compute_field_derivs( + self, f, blocks, Jmatrix, fields_array.shape + ) + ATinv_df_duT_v = {} for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): AdiagTinv = Ainv[dt] j_row_updates = [] time_mask = data_times > simulation_times[tInd] - for block in blocks.values(): + tc = time() + for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]): ATinv_df_duT_v = get_field_deriv_block( - self, block, tInd, AdiagTinv, ATinv_df_duT_v, time_mask + self, block, field_deriv, tInd, AdiagTinv, ATinv_df_duT_v, time_mask ) - for address, indices in block.items(): - j_row_updates.append( + if len(block) == 0: + continue + + j_row_updates.append( + array.from_delayed( compute_rows( self, tInd, - address, - indices, + block, ATinv_df_duT_v, - f, - Jmatrix, - ftype, + fields_array, time_mask, - ) + ), + dtype=np.float32, + shape=( + np.sum(len(chunk[1][0]) for chunk in block), + self.model.size, + ), ) - dask.compute(j_row_updates) + ) + + if self.store_sensitivities == "disk": + sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr" + array.to_zarr( + Jmatrix + array.vstack(j_row_updates), + sens_name, + compute=True, + overwrite=True, + ) + Jmatrix = array.from_zarr(sens_name) + else: + Jmatrix += array.vstack(j_row_updates).compute() + for A in Ainv.values(): A.clean() - if self.store_sensitivities == "disk": - del Jmatrix - return array.from_zarr(self.sensitivity_path + f"J.zarr") - else: - return Jmatrix.compute() + if self.store_sensitivities == "ram": + return np.asarray(Jmatrix) + + return Jmatrix Sim.compute_J = compute_J diff --git a/SimPEG/dask/simulation.py b/SimPEG/dask/simulation.py index b60cb47d3f..06b69a0ae9 100644 --- a/SimPEG/dask/simulation.py +++ b/SimPEG/dask/simulation.py @@ -9,6 +9,7 @@ Sim._max_ram = 16 + @property def max_ram(self): "Maximum ram in (Gb)" @@ -62,7 +63,7 @@ def n_cpu(self, other): def make_synthetic_data( - self, m, relative_error=0.05, noise_floor=0.0, f=None, add_noise=False, **kwargs + self, m, relative_error=0.05, noise_floor=0.0, f=None, add_noise=False, **kwargs ): """ Make synthetic data given a model, and a standard deviation. @@ -106,12 +107,13 @@ def make_synthetic_data( noise_floor=noise_floor, ) + Sim.make_synthetic_data = make_synthetic_data @property def workers(self): - if getattr(self, '_workers', None) is None: + if getattr(self, "_workers", None) is None: self._workers = None return self._workers @@ -127,7 +129,7 @@ def workers(self, workers): def dask_Jvec(self, m, v): """ - Compute sensitivity matrix (J) and vector (v) product. + Compute sensitivity matrix (J) and vector (v) product. """ self.model = m @@ -145,7 +147,7 @@ def dask_Jvec(self, m, v): def dask_Jtvec(self, m, v): """ - Compute adjoint sensitivity matrix (J^T) and vector (v) product. + Compute adjoint sensitivity matrix (J^T) and vector (v) product. """ self.model = m @@ -174,13 +176,11 @@ def Jmatrix(self): if self.store_sensitivities == "ram": self._Jmatrix = client.persist( - delayed(self.compute_J)(), - workers=self.workers + delayed(self.compute_J)(), workers=self.workers ) else: self._Jmatrix = client.compute( - delayed(self.compute_J)(), - workers=self.workers + delayed(self.compute_J)(), workers=self.workers ) elif isinstance(self._Jmatrix, Future): @@ -226,11 +226,13 @@ def evaluate_receiver(source, receiver, mesh, fields): rows = [] for src in self.survey.source_list: for rx in src.receiver_list: - rows.append(array.from_delayed( - row(src, rx, self.mesh, f), - dtype=np.float32, - shape=(rx.nD,), - )) + rows.append( + array.from_delayed( + row(src, rx, self.mesh, f), + dtype=np.float32, + shape=(rx.nD,), + ) + ) data = array.hstack(rows).compute() @@ -246,7 +248,7 @@ def evaluate_receiver(source, receiver, mesh, fields): def dask_getJtJdiag(self, m, W=None): """ - Return the diagonal of JtJ + Return the diagonal of JtJ """ self.model = m if getattr(self, "_jtjdiag", None) is None: @@ -258,11 +260,11 @@ def dask_getJtJdiag(self, m, W=None): else: W = W.diagonal() ** 2.0 - diag = array.einsum('i,ij,ij->j', W, self.Jmatrix, self.Jmatrix) + diag = array.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) if isinstance(diag, array.Array): diag = np.asarray(diag.compute()) self._jtjdiag = diag - return self._jtjdiag \ No newline at end of file + return self._jtjdiag diff --git a/SimPEG/dask/utils.py b/SimPEG/dask/utils.py index 037961c1ab..c287bc4ad6 100644 --- a/SimPEG/dask/utils.py +++ b/SimPEG/dask/utils.py @@ -40,18 +40,20 @@ def compute(self, job): return job.compute() -def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): +def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int) -> list: """ Get the blocks of sources and receivers to be computed in parallel. - Stored as a dictionary of source, receiver pairs index. The value is an array of indices + Stored as a list of tuples for + (source, receiver, block index) and array of indices for the rows of the sensitivity matrix. """ data_block_size = np.ceil(max_chunk_size / (m_size * 8.0 * 1e-6)) row_count = 0 row_index = 0 block_count = 0 - blocks = {0: {}} + blocks = [[]] + for s_id, src in enumerate(source_list): for r_id, rx in enumerate(src.receiver_list): indices = np.arange(rx.nD).astype(int) @@ -66,12 +68,27 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): if (row_count + chunk_size) > (data_block_size * cpu_count()): row_count = 0 block_count += 1 - blocks[block_count] = {} + blocks.append = [] - blocks[block_count][(s_id, r_id, ind)] = chunk, np.arange( - row_index, row_index + chunk_size - ).astype(int) + blocks[block_count].append( + ( + (s_id, r_id, ind), + ( + chunk, + np.arange(row_index, row_index + chunk_size).astype(int), + rx.locations.shape[0], + ), + ) + ) row_index += chunk_size row_count += chunk_size + # Re-split over cpu_count if too few blocks + if len(blocks) < cpu_count(): + flatten_blocks = [] + for block in blocks: + flatten_blocks += block + + chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) + return [[flatten_blocks[i] for i in chunk] for chunk in chunks] return blocks diff --git a/SimPEG/electromagnetics/time_domain/fields.py b/SimPEG/electromagnetics/time_domain/fields.py index 384432c736..7a0b45e3b0 100644 --- a/SimPEG/electromagnetics/time_domain/fields.py +++ b/SimPEG/electromagnetics/time_domain/fields.py @@ -56,7 +56,8 @@ def _dbdtDeriv(self, tInd, src, dun_dm_v, v, adjoint=False): if adjoint is True: return ( self._dbdtDeriv_u(tInd, src, v, adjoint), - self._dbdtDeriv_m(tInd, src, v, adjoint), + Zero() + # self._dbdtDeriv_m(tInd, src, v, adjoint), ) return self._dbdtDeriv_u(tInd, src, dun_dm_v) + self._dbdtDeriv_m(tInd, src, v) @@ -160,7 +161,7 @@ def _dbdt(self, bSolution, source_list, tInd): def _dbdtDeriv_u(self, tInd, src, dun_dm_v, adjoint=False): if adjoint is True: - return -self._eDeriv_u(tInd, src, self._edgeCurl.T * dun_dm_v, adjoint) + return -self._eDeriv_u(tInd, src, self._edgeCurl.T @ dun_dm_v, adjoint) return -(self._edgeCurl * self._eDeriv_u(tInd, src, dun_dm_v)) def _dbdtDeriv_m(self, tInd, src, v, adjoint=False): @@ -179,7 +180,7 @@ def _e(self, bSolution, source_list, tInd): def _eDeriv_u(self, tInd, src, dun_dm_v, adjoint=False): if adjoint is True: - return self._MfMui.T * (self._edgeCurl * (self._MeSigmaI.T * dun_dm_v)) + return self._MfMui.T @ (self._edgeCurl @ (self._MeSigmaI.T @ dun_dm_v)) return self._MeSigmaI * (self._edgeCurl.T * (self._MfMui * dun_dm_v)) def _eDeriv_m(self, tInd, src, v, adjoint=False): diff --git a/SimPEG/electromagnetics/time_domain/receivers.py b/SimPEG/electromagnetics/time_domain/receivers.py index 49e9e78c32..d2a3a1f8cb 100644 --- a/SimPEG/electromagnetics/time_domain/receivers.py +++ b/SimPEG/electromagnetics/time_domain/receivers.py @@ -99,14 +99,17 @@ def getSpatialP(self, mesh, f): scipy.sparse.csr_matrix P, the interpolation matrix """ - P = Zero() - field = f._GLoc(self.projField) - for strength, comp in zip(self.orientation, ["x", "y", "z"]): - if strength != 0.0: - P = P + strength * mesh.get_interpolation_matrix( - self.locations, field + comp - ) - return P + if getattr(self, "spatialP", None) is None: + P = Zero() + field = f._GLoc(self.projField) + for strength, comp in zip(self.orientation, ["x", "y", "z"]): + if strength != 0.0: + P = P + strength * mesh.get_interpolation_matrix( + self.locations, field + comp + ) + self.spatialP = P + + return self.spatialP def getTimeP(self, time_mesh, f): """Get time projection matrix from mesh to receivers. @@ -124,8 +127,13 @@ def getTimeP(self, time_mesh, f): scipy.sparse.csr_matrix P, the interpolation matrix """ - projected_time_grid = f._TLoc(self.projField) - return time_mesh.get_interpolation_matrix(self.times, projected_time_grid) + if getattr(self, "timeP", None) is None: + projected_time_grid = f._TLoc(self.projField) + self.timeP = time_mesh.get_interpolation_matrix( + self.times, projected_time_grid + ) + + return self.timeP def getP(self, mesh, time_mesh, f): """Returns projection matrices as a list for all components collected by the receivers.