Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions news/norm-fixes.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* Fill component tail with zeros during normalization

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
194 changes: 135 additions & 59 deletions src/diffpy/stretched_nmf/snmf_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(

self._rng = np.random.default_rng(self.random_state)
self._plotter = SNMFPlotter() if self.show_plots else None
self._fill_tail_zero = False

def _initialize_factors(
self,
Expand Down Expand Up @@ -212,6 +213,7 @@ def _initialize_factors(
self._init_components = self.components_.copy()
self._init_weights = self.weights_.copy()
self._init_stretch = self.stretch_.copy()
self._fill_tail_zero = False

# Second-order spline: Tridiagonal (-2 on diags, 1 on sub/superdiags)
self._spline_smooth_operator = 0.25 * diags(
Expand Down Expand Up @@ -409,54 +411,58 @@ def _normalize_results(self):
self._prev_grad_components = np.zeros_like(
self.components_
) # Previous gradient of X (zeros for now)
self.residuals_ = self._get_residual_matrix()
self.objective_function_ = self._get_objective_function()
self.objective_difference_ = None
self._objective_history = [self.objective_function_]
self._outer_iter = 0
self._inner_iter = 0
for outiter in range(self.max_iter):
self._outer_iter = outiter
if outiter == 1:
self._inner_iter = (
1 # So step size can adapt without an inner loop
)
self._update_components()
self._fill_tail_zero = True
try:
self.residuals_ = self._get_residual_matrix()
self.objective_function_ = self._get_objective_function()
self.objective_log.append(
{
"step": "c_norm",
"iteration": outiter,
"objective": self.objective_function_,
"timestamp": time.time(),
}
)
self.objective_difference_ = (
self.objective_log[-2]["objective"]
- self.objective_log[-1]["objective"]
)
if self._plotter is not None:
self._plotter.update(
components=self.components_,
weights=self.weights_,
stretch=self.stretch_,
update_tag="normalize components",
self.objective_difference_ = None
self._objective_history = [self.objective_function_]
self._outer_iter = 0
self._inner_iter = 0
for outiter in range(self.max_iter):
self._outer_iter = outiter
if outiter == 1:
self._inner_iter = (
1 # So step size can adapt without an inner loop
)
self._update_components()
self.residuals_ = self._get_residual_matrix()
self.objective_function_ = self._get_objective_function()
self.objective_log.append(
{
"step": "c_norm",
"iteration": outiter,
"objective": self.objective_function_,
"timestamp": time.time(),
}
)
convergence_threshold = self.objective_function_ * self.tol
if self.verbose:
print(
f"\n--- Iteration {outiter} after normalization---"
f"\nTotal Objective : {self.objective_function_:.5e}"
"\nConvergence Check : Δ "
f"({self.objective_difference_:.2e})"
f" < Threshold ({convergence_threshold:.2e})\n"
self.objective_difference_ = (
self.objective_log[-2]["objective"]
- self.objective_log[-1]["objective"]
)
if (
self.objective_difference_ < convergence_threshold
and outiter >= 7
):
break
if self._plotter is not None:
self._plotter.update(
components=self.components_,
weights=self.weights_,
stretch=self.stretch_,
update_tag="normalize components",
)
convergence_threshold = self.objective_function_ * self.tol
if self.verbose:
print(
f"\n--- Iteration {outiter} after normalization---"
f"\nTotal Objective : {self.objective_function_:.5e}"
"\nConvergence Check : Δ "
f"({self.objective_difference_:.2e})"
f" < Threshold ({convergence_threshold:.2e})\n"
)
if (
self.objective_difference_ < convergence_threshold
and outiter >= 7
):
break
finally:
self._fill_tail_zero = False

def _outer_loop(self):
if self.verbose:
Expand Down Expand Up @@ -591,13 +597,37 @@ def _get_residual_matrix(
if stretch is None:
stretch = self.stretch_

reconstructed_matrix = _reconstruct_matrix(
components, weights, stretch
)
if self._fill_tail_zero:
reconstructed_matrix = self._reconstruct_from_stretched_components(
components=components,
weights=weights,
stretch=stretch,
)
else:
reconstructed_matrix = _reconstruct_matrix(
components, weights, stretch
)
residuals = reconstructed_matrix - self._source_matrix

return residuals

def _reconstruct_from_stretched_components(
self, components=None, weights=None, stretch=None
):
stretched_components, _, _ = self._compute_stretched_components(
components=components,
weights=weights,
stretch=stretch,
)
intermediate = stretched_components.flatten(order="F").reshape(
(self.signal_length_ * self.n_signals_, self.n_components_),
order="F",
)
return intermediate.sum(axis=1).reshape(
(self.signal_length_, self.n_signals_),
order="F",
)

def _get_objective_function(self, residuals=None, stretch=None):
"""Return the objective value, passing stored attributes or
overrides to _compute_objective_function().
Expand Down Expand Up @@ -680,20 +710,24 @@ def _compute_stretched_components(

# For each stretched coordinate, find its prior integer (original)
# index and their difference
i0 = np.floor(t).astype(np.int64) # prior original index
alpha = t - i0.astype(float) # fractional distance between left/right
i0_raw = np.floor(t).astype(np.int64) # prior original index
alpha = t - i0_raw.astype(float)
i1_raw = i0_raw + 1

# Clip indices to range (0, signal_len - 1) to maintain original size
# Clip indices to range (0, signal_len - 1) to gather safely.
max_idx = signal_len - 1
i0 = np.clip(i0, 0, max_idx)
i1 = np.clip(i0 + 1, 0, max_idx)
i0 = np.clip(i0_raw, 0, max_idx)
i1 = np.clip(i1_raw, 0, max_idx)

# Gather sample values
comps_3d = components[
:, :, None
] # expand components by a dim for broadcasting across n_signals
c0 = np.take_along_axis(comps_3d, i0, axis=0) # left sample values
c1 = np.take_along_axis(comps_3d, i1, axis=0) # right sample values
if self._fill_tail_zero:
c0 = np.where(i0_raw < signal_len, c0, 0.0)
c1 = np.where(i1_raw < signal_len, c1, 0.0)

# Linear interpolation to determine stretched sample values
interp = c0 * (1.0 - alpha) + c1 * alpha
Expand Down Expand Up @@ -795,6 +829,42 @@ def _apply_transformation_matrix(

return stretch_transformed

def _compute_component_gradient_zero_tail(
self, stretch=None, weights=None, residuals=None
):
if stretch is None:
stretch = self.stretch_
if weights is None:
weights = self.weights_
if residuals is None:
residuals = self.residuals_

gradient = np.zeros_like(self.components_)
sample_indices = np.arange(self.signal_length_, dtype=float)
for signal in range(self.n_signals_):
for comp in range(self.n_components_):
positions = sample_indices / stretch[comp, signal]
left_indices = np.floor(positions).astype(int)
alpha = positions - left_indices
right_indices = left_indices + 1
scaled_residual = residuals[:, signal] * weights[comp, signal]

left_mask = left_indices < self.signal_length_
np.add.at(
gradient[:, comp],
left_indices[left_mask],
scaled_residual[left_mask] * (1.0 - alpha[left_mask]),
)

right_mask = right_indices < self.signal_length_
np.add.at(
gradient[:, comp],
right_indices[right_mask],
scaled_residual[right_mask] * alpha[right_mask],
)

return gradient

def _solve_quadratic_program(self, t, m):
"""
Solves the quadratic program for updating y in stretched NMF:
Expand Down Expand Up @@ -870,9 +940,14 @@ def _update_components(self):
reshaped_stretched_components - self._source_matrix
)
# Compute gradient
self._grad_components = self._apply_transformation_matrix(
residuals=component_residuals
).toarray() # toarray equivalent of full, make non-sparse
if self._fill_tail_zero:
self._grad_components = self._compute_component_gradient_zero_tail(
residuals=component_residuals
)
else:
self._grad_components = self._apply_transformation_matrix(
residuals=component_residuals
).toarray() # toarray equivalent of full, make non-sparse

# Compute initial step size `initial_step_size`
initial_step_size = np.linalg.eigvalsh(
Expand Down Expand Up @@ -916,10 +991,11 @@ def _update_components(self):
)
self.components_ = mask * self.components_

objective_improvement = self.objective_log[-1][
"objective"
] - self._get_objective_function(
residuals=self._get_residual_matrix()
objective_improvement = (
self.objective_function_
- self._get_objective_function(
residuals=self._get_residual_matrix()
)
)

# Check if objective function improves
Expand Down
Loading