diff --git a/SimPEG/dask/electromagnetics/frequency_domain/simulation.py b/SimPEG/dask/electromagnetics/frequency_domain/simulation.py index d87040bbbe..a93884556d 100644 --- a/SimPEG/dask/electromagnetics/frequency_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/frequency_domain/simulation.py @@ -2,11 +2,14 @@ from ....utils import Zero import numpy as np import scipy.sparse as sp - -from dask import array, compute, delayed +from multiprocessing import cpu_count +from dask import array, compute, delayed, config +from dask.distributed import get_client, Client, performance_report from SimPEG.dask.simulation import dask_Jvec, dask_Jtvec, dask_getJtJdiag from SimPEG.dask.utils import get_parallel_blocks +from SimPEG.electromagnetics.natural_source.sources import PlanewaveXYPrimary import zarr +from tqdm import tqdm Sim.sensitivity_path = "./sensitivity/" Sim.gtgdiag = None @@ -18,25 +21,150 @@ Sim.clean_on_model_update = ["_Jmatrix", "_jtjdiag"] +@delayed +def source_evaluation(simulation, sources): + s_m, s_e = [], [] + for source in sources: + sm, se = source.eval(simulation) + s_m.append(sm) + s_e.append(se) + + return s_m, s_e + + +def dask_getSourceTerm(self, freq, source=None): + """ + Assemble the source term. This ensures that the RHS is a vector / array + of the correct size + """ + if source is None: + source_list = self.survey.get_sources_by_frequency(freq) + source_block = np.array_split(source_list, cpu_count()) + + block_compute = [] + for block in source_block: + if len(block) == 0: + continue + + block_compute.append(source_evaluation(self, block)) + + eval = compute(block_compute)[0] + s_m, s_e = [], [] + for block in eval: + if block[0]: + s_m += block[0] + s_e += block[1] + + else: + sm, se = source.eval(self) + s_m, s_e = [sm], [se] + + if isinstance(s_m[0][0], Zero): # Assume the rest is all Zero + s_m = Zero() + else: + s_m = np.vstack(s_m) + if s_m.shape[0] < s_m.shape[1]: + s_m = s_m.T + + if isinstance(s_e[0][0], Zero): # Assume the rest is all Zero + s_e = Zero() + else: + s_e = np.vstack(s_e) + if s_e.shape[0] < s_e.shape[1]: + s_e = s_e.T + return s_m, s_e + + +Sim.getSourceTerm = dask_getSourceTerm + + +@delayed +def evaluate_receivers(block, mesh, fields): + data = [] + for source, ind, receiver in block: + data.append(receiver.eval(source, mesh, fields).flatten()) + + return np.hstack(data) + + +def dask_dpred(self, m=None, f=None, compute_J=False): + """ + dpred(m, f=None) + Create the projected data from a model. + The fields, f, (if provided) will be used for the predicted data + instead of recalculating the fields (which may be expensive!). + + .. math:: + + d_\\text{pred} = P(f(m)) + + Where P is a projection of the fields onto the data space. + """ + if self.survey is None: + raise AttributeError( + "The survey has not yet been set and is required to compute " + "data. Please set the survey for the simulation: " + "simulation.survey = survey" + ) + + if f is None: + if m is None: + m = self.model + f, Ainv = self.fields(m, return_Ainv=compute_J) + + all_receivers = [] + + for ind, src in enumerate(self.survey.source_list): + for rx in src.receiver_list: + all_receivers.append((src, ind, rx)) + + receiver_blocks = np.array_split(all_receivers, cpu_count()) + rows = [] + mesh = delayed(self.mesh) + for block in receiver_blocks: + n_data = np.sum(rec.nD for _, _, rec in block) + if n_data == 0: + continue + + rows.append( + array.from_delayed( + evaluate_receivers(block, mesh, f), + dtype=np.float64, + shape=(n_data,), + ) + ) + + data = compute(array.hstack(rows))[0] + + if compute_J and self._Jmatrix is None: + Jmatrix = self.compute_J(f=f, Ainv=Ainv) + return data, Jmatrix + + return data + + +Sim.dpred = dask_dpred +Sim.field_derivs = None + + def fields(self, m=None, return_Ainv=False): if m is not None: self.model = m f = self.fieldsPair(self) - - Ainv = [] + Ainv = {} for freq in self.survey.frequencies: A = self.getA(freq) rhs = self.getRHS(freq) Ainv_solve = self.solver(sp.csr_matrix(A), **self.solver_opts) u = Ainv_solve * rhs - Srcs = self.survey.get_sources_by_frequency(freq) - f[Srcs, self._solutionType] = u - - Ainv_solve.clean() + sources = self.survey.get_sources_by_frequency(freq) + f[sources, self._solutionType] = u if return_Ainv: - Ainv += [self.solver(sp.csr_matrix(A.T), **self.solver_opts)] + Ainv[freq] = Ainv_solve + else: + Ainv_solve.clean() if return_Ainv: return f, Ainv @@ -51,100 +179,66 @@ def compute_J(self, f=None, Ainv=None): if f is None: f, Ainv = self.fields(self.model, return_Ainv=True) + if len(Ainv) > 1: + raise NotImplementedError( + "Current implementation of parallelization assumes a single frequency per simulation. " + "Consider creating one misfit per frequency." + ) + + A_i = list(Ainv.values())[0] m_size = self.model.size - Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float32) + + if self.store_sensitivities == "disk": + Jmatrix = zarr.open( + self.sensitivity_path, + mode="w", + shape=(self.survey.nD, m_size), + chunks=(self.max_chunk_size, m_size), + ) + else: + Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32) + + compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6)) blocks = get_parallel_blocks( - self.survey.source_list, self.model.shape[0], self.max_chunk_size + self.survey.source_list, compute_row_size, optimize=False ) - count = 0 - block_count = 0 - col_chunks = None - for A_i, freq in zip(Ainv, self.survey.frequencies): - sources = [] - blocks_dfduT = [] - blocks_dfdmT = [] - block_count = 0 - - for ss, src in enumerate(self.survey.get_sources_by_frequency(freq)): - u_src = f[src, self._solutionType] - - if col_chunks is None: - col_chunks = int( - np.ceil( - float(self.survey.nD) - / np.ceil( - float(u_src.shape[0]) - * self.survey.nD - * 8.0 - * 1e-6 - / self.max_chunk_size - ) - ) - ) - - for rx in src.receiver_list: - v = np.eye(rx.nD, dtype=float) - n_blocs = np.ceil(u_src.shape[1] * rx.nD / col_chunks) - - for block in np.array_split(v, n_blocs, axis=1): - if block.shape[1] == 0: - continue - - block_count += block.shape[1] * u_src.shape[1] - blocks_dfduT.append( - array.from_delayed( - dfduT(src, rx, self.mesh, f, block), - dtype=np.float32, - shape=(u_src.shape[0], block.shape[1] * u_src.shape[1]), - ) - ) - blocks_dfdmT.append( - dfdmT(src, rx, self.mesh, f, block), - ) - sources.append(src) - - if block_count >= (col_chunks): - count = parallel_block_compute( - self, - A_i, - Jmatrix, - freq, - f, - sources, - blocks_dfduT, - blocks_dfdmT, - count, - m_size, - u_src.shape, - self._solutionType, - ) - blocks_dfduT = [] - blocks_dfdmT = [] - sources = [] - block_count = 0 - - if blocks_dfduT: - count = parallel_block_compute( - self, - A_i, - Jmatrix, - freq, - f, - sources, - blocks_dfduT, - blocks_dfdmT, - count, - m_size, - u_src.shape, - self._solutionType, + fields_array = delayed(f[:, self._solutionType]) + fields = delayed(f) + survey = delayed(self.survey) + mesh = delayed(self.mesh) + blocks_receiver_derivs = [] + + for block in blocks: + blocks_receiver_derivs.append( + receiver_derivs( + survey, + mesh, + fields, + block, ) + ) + + # with Client(processes=False) as client: + # with performance_report(filename="dask-report.html"): + + # Dask process for all derivatives + blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] + + for block_derivs_chunks, addresses_chunks in tqdm( + zip(blocks_receiver_derivs, blocks), + ncols=len(blocks_receiver_derivs), + desc=f"Sensitivities at {list(Ainv)[0]} Hz", + ): + Jmatrix = parallel_block_compute( + self, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks + ) - for A in Ainv: + for A in Ainv.values(): A.clean() if self.store_sensitivities == "disk": del Jmatrix - return array.from_zarr(self.sensitivity_path + f"J.zarr") + return array.from_zarr(self.sensitivity_path) else: return Jmatrix @@ -152,26 +246,105 @@ def compute_J(self, f=None, Ainv=None): Sim.compute_J = compute_J -@delayed -def dfduT(source, receiver, mesh, fields, block): - dfduT, _ = receiver.evalDeriv(source, mesh, fields, v=block, adjoint=True) +def parallel_block_compute( + self, Jmatrix, blocks_receiver_derivs, A_i, fields_array, addresses +): + m_size = self.model.size + block_stack = sp.hstack(blocks_receiver_derivs).toarray() + ATinvdf_duT = delayed(A_i * block_stack) + count = 0 + rows = [] + block_delayed = [] - return dfduT + for address, dfduT in zip(addresses, blocks_receiver_derivs): + n_cols = dfduT.shape[1] + n_rows = address[1][2] + block_delayed.append( + array.from_delayed( + eval_block( + self, + ATinvdf_duT, + np.arange(count, count + n_cols), + Zero(), + fields_array, + address, + ), + dtype=np.float32, + shape=(n_rows, m_size), + ) + ) + count += n_cols + rows += address[1][1].tolist() + + indices = np.hstack(rows) + + if self.store_sensitivities == "disk": + Jmatrix.set_orthogonal_selection( + (indices, slice(None)), + compute(array.vstack(block_delayed))[0], + ) + else: + # Dask process to compute row and store + Jmatrix[indices, :] = compute(array.vstack(block_delayed))[0] + + return Jmatrix @delayed -def dfdmT(source, receiver, mesh, fields, block): - _, dfdmT = receiver.evalDeriv(source, mesh, fields, v=block, adjoint=True) +def receiver_derivs(survey, mesh, fields, blocks): + field_derivatives = [] + for address in blocks: + source = survey.source_list[address[0][0]] + receiver = source.receiver_list[address[0][1]] + + if isinstance(source, PlanewaveXYPrimary): + v = np.eye(receiver.nD, dtype=float) + else: + v = sp.csr_matrix(np.ones(receiver.nD), dtype=float) + + # Assume the derivatives in terms of model are Zero (seems to always be case) + dfduT, _ = receiver.evalDeriv( + source, mesh, fields, v=v[:, address[1][0]], adjoint=True + ) + field_derivatives.append(dfduT) - return dfdmT + return field_derivatives -def eval_block(simulation, Ainv_deriv_u, frequency, deriv_m, fields, source): +@delayed +def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address): """ - Evaluate the sensitivities for the block or data and store to zarr + Evaluate the sensitivities for the block or data """ - dA_dmT = simulation.getADeriv(frequency, fields, Ainv_deriv_u, adjoint=True) - dRHS_dmT = simulation.getRHSDeriv(frequency, source, Ainv_deriv_u, adjoint=True) + if Ainv_deriv_u.ndim == 1: + deriv_columns = Ainv_deriv_u[:, np.newaxis] + else: + deriv_columns = Ainv_deriv_u[:, deriv_indices] + + n_receivers = address[1][2] + source = simulation.survey.source_list[address[0][0]] + + if isinstance(source, PlanewaveXYPrimary): + source_fields = fields + n_cols = 2 + else: + source_fields = fields[:, address[0][0]] + n_cols = 1 + + n_cols *= n_receivers + + dA_dmT = simulation.getADeriv( + source.frequency, + source_fields, + deriv_columns, + adjoint=True, + ) + dRHS_dmT = simulation.getRHSDeriv( + source.frequency, + source, + deriv_columns, + adjoint=True, + ) du_dmT = -dA_dmT if not isinstance(dRHS_dmT, Zero): du_dmT += dRHS_dmT @@ -179,72 +352,3 @@ def eval_block(simulation, Ainv_deriv_u, frequency, deriv_m, fields, source): du_dmT += deriv_m return np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T - - -def parallel_block_compute( - simulation, - A_i, - Jmatrix, - freq, - fields, - sources, - blocks_deriv_u, - blocks_deriv_m, - counter, - m_size, - f_shape, - solution_type, -): - field_derivs = array.hstack(blocks_deriv_u).compute() - - # Direct-solver call - ATinvdf_duT = (A_i * field_derivs).reshape((f_shape[0], -1)) - - # Even split - split = np.cumsum([block.shape[1] for block in blocks_deriv_u])[:-1] - sub_blocks_deriv_u = np.array_split(ATinvdf_duT, split, axis=1) - - if isinstance(compute(blocks_deriv_m[0])[0], Zero): - sub_blocks_dfdmt = [Zero()] * len(sub_blocks_deriv_u) - else: - compute_blocks_deriv_m = array.hstack( - [ - array.from_delayed( - dfdmT_block, - dtype=np.float32, - shape=(f_shape[0], dfdmT_block.shape[1] * f_shape[1]), - ) - for dfdmT_block in blocks_deriv_m - ] - ).compute() - sub_blocks_dfdmt = np.array_split(compute_blocks_deriv_m, split, axis=1) - - sub_process = [] - - for sub_block_dfduT, sub_block_dfdmT, src in zip( - sub_blocks_deriv_u, sub_blocks_dfdmt, sources - ): - u_src = fields[src, solution_type] - row_size = int(sub_block_dfduT.shape[1] / f_shape[1]) - sub_process.append( - array.from_delayed( - delayed(eval_block, pure=True)( - simulation, sub_block_dfduT, freq, sub_block_dfdmT, u_src, src - ), - dtype=np.float32, - shape=(row_size, m_size), - ) - ) - - block = array.vstack(sub_process).compute() - - if simulation.store_sensitivities == "disk": - Jmatrix.set_orthogonal_selection( - (np.arange(counter, counter + block.shape[0]), slice(None)), - block.astype(np.float32), - ) - else: - Jmatrix[counter : counter + block.shape[0], :] = block.astype(np.float32) - - counter += block.shape[0] - return counter diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 7307531eb8..bcca19adbe 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -135,6 +135,17 @@ def fields(self, m=None, return_Ainv=False): Sim.fields = fields +@delayed +def source_evaluation(simulation, sources, time): + s_m, s_e = [], [] + for source in sources: + sm, se = source.eval(simulation, time) + s_m.append(sm) + s_e.append(se) + + return s_m, s_e + + def dask_getSourceTerm(self, tInd): """ Assemble the source term. This ensures that the RHS is a vector / array @@ -143,20 +154,9 @@ def dask_getSourceTerm(self, tInd): source_list = self.survey.source_list source_block = np.array_split(source_list, cpu_count()) - def source_evaluation(simulation, sources, time): - s_m, s_e = [], [] - for source in sources: - sm, se = source.eval(simulation, time) - s_m.append(sm) - s_e.append(se) - - return s_m, s_e - block_compute = [] for block in source_block: - block_compute.append( - delayed(source_evaluation, pure=True)(self, block, self.times[tInd]) - ) + block_compute.append(source_evaluation(self, block, self.times[tInd])) eval = dask.compute(block_compute)[0] @@ -534,9 +534,8 @@ def compute_J(self, f=None, Ainv=None): simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 data_times = self.survey.source_list[0].receiver_list[0].times - blocks = get_parallel_blocks( - self.survey.source_list, self.model.shape[0], self.max_chunk_size - ) + compute_row_size = np.ceil(self.max_chunk_size / (self.model.shape[0] * 8.0 * 1e-6)) + blocks = get_parallel_blocks(self.survey.source_list, compute_row_size) fields_array = f[:, ftype, :] if len(self.survey.source_list) == 1: diff --git a/SimPEG/dask/utils.py b/SimPEG/dask/utils.py index 558fb08aa3..ad292dc9a2 100644 --- a/SimPEG/dask/utils.py +++ b/SimPEG/dask/utils.py @@ -40,7 +40,7 @@ def compute(self, job): return job.compute() -def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int) -> list: +def get_parallel_blocks(source_list: list, data_block_size, optimize=True) -> list: """ Get the blocks of sources and receivers to be computed in parallel. @@ -48,7 +48,6 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int) -> (source, receiver, block index) and array of indices for the rows of the sensitivity matrix. """ - data_block_size = np.ceil(max_chunk_size / (m_size * 8.0 * 1e-6)) row_count = 0 row_index = 0 block_count = 0 @@ -84,7 +83,7 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int) -> row_count += chunk_size # Re-split over cpu_count if too few blocks - if len(blocks) < cpu_count(): + if len(blocks) < cpu_count() and optimize: flatten_blocks = [] for block in blocks: flatten_blocks += block diff --git a/SimPEG/electromagnetics/frequency_domain/fields.py b/SimPEG/electromagnetics/frequency_domain/fields.py index d39dd3b039..740dd10dac 100644 --- a/SimPEG/electromagnetics/frequency_domain/fields.py +++ b/SimPEG/electromagnetics/frequency_domain/fields.py @@ -255,6 +255,13 @@ def _hDeriv(self, src, du_dm_v, v, adjoint=False): ) if adjoint: + if ( + isinstance(src.s_mDeriv(self.simulation, v, adjoint), Zero) + and isinstance(src.bPrimaryDeriv(self.simulation, v, adjoint), Zero) + and isinstance(self._MfMuiDeriv(v), Zero) + ): + return self._hDeriv_u(src, v, adjoint), Zero() + return (self._hDeriv_u(src, v, adjoint), self._hDeriv_m(src, v, adjoint)) return np.array( self._hDeriv_u(src, du_dm_v, adjoint) + self._hDeriv_m(src, v, adjoint), diff --git a/SimPEG/electromagnetics/frequency_domain/simulation.py b/SimPEG/electromagnetics/frequency_domain/simulation.py index 9c1372b80f..4a13430f6a 100644 --- a/SimPEG/electromagnetics/frequency_domain/simulation.py +++ b/SimPEG/electromagnetics/frequency_domain/simulation.py @@ -200,7 +200,7 @@ def Jtvec(self, m, v, f=None): return mkvc(Jtv) # @profile - def getSourceTerm(self, freq): + def getSourceTerm(self, freq, source=None): """ Evaluates the sources for a given frequency and puts them in matrix form @@ -209,7 +209,11 @@ def getSourceTerm(self, freq): :rtype: tuple :return: (s_m, s_e) (nE or nF, nSrc) """ - Srcs = self.survey.get_sources_by_frequency(freq) + if source is not None: + Srcs = [source] + else: + Srcs = self.survey.get_sources_by_frequency(freq) + n_fields = sum(src._fields_per_source for src in Srcs) if self._formulation == "EB": s_m = np.zeros((self.mesh.nF, n_fields), dtype=complex, order="F") @@ -362,7 +366,7 @@ def getRHSDeriv(self, freq, src, v, adjoint=False): C = self.mesh.edge_curl MfMui = self.MfMui - s_m, s_e = self.getSourceTerm(freq) + s_m, s_e = self.getSourceTerm(freq, source=src) s_mDeriv, s_eDeriv = src.evalDeriv(self, adjoint=adjoint) MfMuiDeriv = self.MfMuiDeriv(s_m) @@ -716,7 +720,7 @@ def getRHSDeriv(self, freq, src, v, adjoint=False): MeMuI = self.MeMuI MeMuIDeriv = self.MeMuIDeriv s_mDeriv, s_eDeriv = src.evalDeriv(self, adjoint=adjoint) - s_m, _ = self.getSourceTerm(freq) + s_m, _ = self.getSourceTerm(freq, source=src) if adjoint: if self._makeASymmetric: diff --git a/SimPEG/electromagnetics/natural_source/receivers.py b/SimPEG/electromagnetics/natural_source/receivers.py index 9aadd894a9..b9e822e4fc 100644 --- a/SimPEG/electromagnetics/natural_source/receivers.py +++ b/SimPEG/electromagnetics/natural_source/receivers.py @@ -2,7 +2,7 @@ import numpy as np from scipy.constants import mu_0 - +import scipy.sparse as sp from ...survey import BaseRx @@ -315,8 +315,8 @@ def _eval_impedance_deriv(self, src, mesh, f, du_dm_v=None, v=None, adjoint=Fals else: ghx_v -= gh_v - gh_v = Phx.T @ ghx_v + Phy.T @ ghy_v - ge_v = Pe.T @ ge_v + gh_v = Phx.T @ sp.csr_matrix(ghx_v) + Phy.T @ sp.csr_matrix(ghy_v) + ge_v = Pe.T @ sp.csr_matrix(ge_v) else: if mesh.dim == 1 and self.orientation != f.field_directions: gbot_v = -gbot_v @@ -489,7 +489,11 @@ def orientation(self, var): def _eval_tipper(self, src, mesh, f): # will grab both primary and secondary and sum them! - h = f[src, "h"] + + if not isinstance(f, np.ndarray): + h = f[src, "h"] + else: + h = f hx = self.getP(mesh, "Fx", "h") @ h hy = self.getP(mesh, "Fy", "h") @ h @@ -506,7 +510,11 @@ def _eval_tipper(self, src, mesh, f): def _eval_tipper_deriv(self, src, mesh, f, du_dm_v=None, v=None, adjoint=False): # will grab both primary and secondary and sum them! - h = f[src, "h"] + + if not isinstance(f, np.ndarray): + h = f[src, "h"] + else: + h = f Phx = self.getP(mesh, "Fx", "h") Phy = self.getP(mesh, "Fy", "h") @@ -547,7 +555,11 @@ def _eval_tipper_deriv(self, src, mesh, f, du_dm_v=None, v=None, adjoint=False): else: ghx_v += gh_v - gh_v = Phx.T @ ghx_v + Phy.T @ ghy_v + Phz.T @ ghz_v + gh_v = ( + Phx.T @ sp.csr_matrix(ghx_v) + + Phy.T @ sp.csr_matrix(ghy_v) + + Phz.T @ sp.csr_matrix(ghz_v) + ) return f._hDeriv(src, None, gh_v, adjoint=True) dh_v = f._hDeriv(src, du_dm_v, v, adjoint=False)