Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion simpeg/dask/data_misfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def dask_call(self, m, f=None):
Distributed :obj:`simpeg.data_misfit.L2DataMisfit.__call__`
"""
R = self.W * self.residual(m, f=f)
phi_d = 0.5 * da.dot(R, R)
phi_d = da.dot(R, R)
if not isinstance(phi_d, np.ndarray):
return compute(self, phi_d)
return phi_d
Expand Down
2 changes: 1 addition & 1 deletion simpeg/dask/inverse_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def dask_evalFunction(self, m, return_g=True, return_H=True):
phi_d = 0
for (mult, objfct), pred in zip(self.dmisfit, self.dpred):
residual = objfct.W * (objfct.data.dobs - pred)
phi_d += 0.5 * mult * np.vdot(residual, residual)
phi_d += mult * np.vdot(residual, residual)

phi_d = np.asarray(phi_d)
# print(self.dpred[0])
Expand Down
1 change: 1 addition & 0 deletions simpeg/dask/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def Jmatrix(self):
if getattr(self, "_Jmatrix", None) is None:
if self.workers is None:
self._Jmatrix = self.compute_J()
self._G = self._Jmatrix
else:
client = get_client() # Assumes a Client already exists

Expand Down
1 change: 1 addition & 0 deletions simpeg/directives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
VectorInversion,
SaveIterationsGeoH5,
ProjectSphericalBounds,
ScaleMisfitMultipliers,
)

from .pgi_directives import (
Expand Down
79 changes: 78 additions & 1 deletion simpeg/directives/directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3169,7 +3169,7 @@ def save_log(self):
with fetch_active_workspace(self._geoh5, mode="r+") as w_s:
h5_object = w_s.get_entity(self.h5_object)[0]

for file in ["SimPEG.out", "SimPEG.log"]:
for file in ["SimPEG.out", "SimPEG.log", "ChiFactors.log"]:
filepath = dirpath / file

if not filepath.is_file():
Expand Down Expand Up @@ -3426,3 +3426,80 @@ def endIter(self):
for directive in directiveList:
if not isinstance(directive, SaveIterationsGeoH5):
directive.endIter()


class ScaleMisfitMultipliers(InversionDirective):
"""
Scale the misfits by the relative chi-factors of multiple misfit functions.

The goal is to reduce the relative influence of the misfit functions with
lowest chi-factors so that all functions reach a similar level of fit at
convergence to the global target.

Parameters
----------

path : str
Path to save the chi-factors log file.
"""

def __init__(self, path: Path | None, **kwargs):
self.last_beta = None

if path is None:
path = Path()

self.filepath = path / "ChiFactors.log"

super().__init__(**kwargs)

def initialize(self):
self.last_beta = self.invProb.beta
self.multipliers = self.invProb.dmisfit.multipliers

with open(self.filepath, "w", encoding="utf-8") as f:
f.write(
"Iterations\t"
+ "\t".join(
f"[{objfct.name}]" for objfct in self.invProb.dmisfit.objfcts
)
)
f.write("\n")

def endIter(self):
Comment thread
benk-mira marked this conversation as resolved.
ratio = self.invProb.beta / self.last_beta
chi_factors = []
phi_ds = []
for objfct, pred in zip(self.invProb.dmisfit.objfcts, self.invProb.dpred):
residual = objfct.W * (objfct.data.dobs - pred)
phi_d = np.vdot(residual, residual)
chi_factors.append(phi_d / objfct.nD)
phi_ds.append(phi_d)

phi_ds = np.asarray(phi_ds)
chi_factors = np.asarray(chi_factors)
scalings = chi_factors / chi_factors.max()

# Force beta ratio scaling if below target
scalings[chi_factors < 1] *= ratio

# Normalize total phi_d with scalings
multipliers = (
self.multipliers
* scalings
* phi_ds.sum()
/ (self.multipliers * phi_ds * scalings).sum()
)

with open(self.filepath, "a", encoding="utf-8") as f:
f.write(
f"{self.opt.iter}\t"
+ "\t".join(
f"{multi:.2e}*{chi:.2e}"
for multi, chi in zip(multipliers, chi_factors)
)
+ "\n"
)

self.invProb.dmisfit.multipliers = multipliers.tolist()
self.last_beta = self.invProb.beta