diff --git a/news/norm-fixes.rst b/news/norm-fixes.rst new file mode 100644 index 00000000..45e07883 --- /dev/null +++ b/news/norm-fixes.rst @@ -0,0 +1,23 @@ +**Added:** + +* Fill component tail with zeros during normalization + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/src/diffpy/stretched_nmf/snmf_class.py b/src/diffpy/stretched_nmf/snmf_class.py index 3269570d..443bc272 100644 --- a/src/diffpy/stretched_nmf/snmf_class.py +++ b/src/diffpy/stretched_nmf/snmf_class.py @@ -133,6 +133,7 @@ def __init__( self._rng = np.random.default_rng(self.random_state) self._plotter = SNMFPlotter() if self.show_plots else None + self._fill_tail_zero = False def _initialize_factors( self, @@ -212,6 +213,7 @@ def _initialize_factors( self._init_components = self.components_.copy() self._init_weights = self.weights_.copy() self._init_stretch = self.stretch_.copy() + self._fill_tail_zero = False # Second-order spline: Tridiagonal (-2 on diags, 1 on sub/superdiags) self._spline_smooth_operator = 0.25 * diags( @@ -409,54 +411,58 @@ def _normalize_results(self): self._prev_grad_components = np.zeros_like( self.components_ ) # Previous gradient of X (zeros for now) - self.residuals_ = self._get_residual_matrix() - self.objective_function_ = self._get_objective_function() - self.objective_difference_ = None - self._objective_history = [self.objective_function_] - self._outer_iter = 0 - self._inner_iter = 0 - for outiter in range(self.max_iter): - self._outer_iter = outiter - if outiter == 1: - self._inner_iter = ( - 1 # So step size can adapt without an inner loop - ) - self._update_components() + self._fill_tail_zero = True + try: self.residuals_ = self._get_residual_matrix() self.objective_function_ = self._get_objective_function() - self.objective_log.append( - { - "step": "c_norm", - "iteration": outiter, - "objective": self.objective_function_, - "timestamp": time.time(), - } - ) - self.objective_difference_ = ( - self.objective_log[-2]["objective"] - - self.objective_log[-1]["objective"] - ) - if self._plotter is not None: - self._plotter.update( - components=self.components_, - weights=self.weights_, - stretch=self.stretch_, - update_tag="normalize components", + self.objective_difference_ = None + self._objective_history = [self.objective_function_] + self._outer_iter = 0 + self._inner_iter = 0 + for outiter in range(self.max_iter): + self._outer_iter = outiter + if outiter == 1: + self._inner_iter = ( + 1 # So step size can adapt without an inner loop + ) + self._update_components() + self.residuals_ = self._get_residual_matrix() + self.objective_function_ = self._get_objective_function() + self.objective_log.append( + { + "step": "c_norm", + "iteration": outiter, + "objective": self.objective_function_, + "timestamp": time.time(), + } ) - convergence_threshold = self.objective_function_ * self.tol - if self.verbose: - print( - f"\n--- Iteration {outiter} after normalization---" - f"\nTotal Objective : {self.objective_function_:.5e}" - "\nConvergence Check : Δ " - f"({self.objective_difference_:.2e})" - f" < Threshold ({convergence_threshold:.2e})\n" + self.objective_difference_ = ( + self.objective_log[-2]["objective"] + - self.objective_log[-1]["objective"] ) - if ( - self.objective_difference_ < convergence_threshold - and outiter >= 7 - ): - break + if self._plotter is not None: + self._plotter.update( + components=self.components_, + weights=self.weights_, + stretch=self.stretch_, + update_tag="normalize components", + ) + convergence_threshold = self.objective_function_ * self.tol + if self.verbose: + print( + f"\n--- Iteration {outiter} after normalization---" + f"\nTotal Objective : {self.objective_function_:.5e}" + "\nConvergence Check : Δ " + f"({self.objective_difference_:.2e})" + f" < Threshold ({convergence_threshold:.2e})\n" + ) + if ( + self.objective_difference_ < convergence_threshold + and outiter >= 7 + ): + break + finally: + self._fill_tail_zero = False def _outer_loop(self): if self.verbose: @@ -591,13 +597,37 @@ def _get_residual_matrix( if stretch is None: stretch = self.stretch_ - reconstructed_matrix = _reconstruct_matrix( - components, weights, stretch - ) + if self._fill_tail_zero: + reconstructed_matrix = self._reconstruct_from_stretched_components( + components=components, + weights=weights, + stretch=stretch, + ) + else: + reconstructed_matrix = _reconstruct_matrix( + components, weights, stretch + ) residuals = reconstructed_matrix - self._source_matrix return residuals + def _reconstruct_from_stretched_components( + self, components=None, weights=None, stretch=None + ): + stretched_components, _, _ = self._compute_stretched_components( + components=components, + weights=weights, + stretch=stretch, + ) + intermediate = stretched_components.flatten(order="F").reshape( + (self.signal_length_ * self.n_signals_, self.n_components_), + order="F", + ) + return intermediate.sum(axis=1).reshape( + (self.signal_length_, self.n_signals_), + order="F", + ) + def _get_objective_function(self, residuals=None, stretch=None): """Return the objective value, passing stored attributes or overrides to _compute_objective_function(). @@ -680,13 +710,14 @@ def _compute_stretched_components( # For each stretched coordinate, find its prior integer (original) # index and their difference - i0 = np.floor(t).astype(np.int64) # prior original index - alpha = t - i0.astype(float) # fractional distance between left/right + i0_raw = np.floor(t).astype(np.int64) # prior original index + alpha = t - i0_raw.astype(float) + i1_raw = i0_raw + 1 - # Clip indices to range (0, signal_len - 1) to maintain original size + # Clip indices to range (0, signal_len - 1) to gather safely. max_idx = signal_len - 1 - i0 = np.clip(i0, 0, max_idx) - i1 = np.clip(i0 + 1, 0, max_idx) + i0 = np.clip(i0_raw, 0, max_idx) + i1 = np.clip(i1_raw, 0, max_idx) # Gather sample values comps_3d = components[ @@ -694,6 +725,9 @@ def _compute_stretched_components( ] # expand components by a dim for broadcasting across n_signals c0 = np.take_along_axis(comps_3d, i0, axis=0) # left sample values c1 = np.take_along_axis(comps_3d, i1, axis=0) # right sample values + if self._fill_tail_zero: + c0 = np.where(i0_raw < signal_len, c0, 0.0) + c1 = np.where(i1_raw < signal_len, c1, 0.0) # Linear interpolation to determine stretched sample values interp = c0 * (1.0 - alpha) + c1 * alpha @@ -795,6 +829,42 @@ def _apply_transformation_matrix( return stretch_transformed + def _compute_component_gradient_zero_tail( + self, stretch=None, weights=None, residuals=None + ): + if stretch is None: + stretch = self.stretch_ + if weights is None: + weights = self.weights_ + if residuals is None: + residuals = self.residuals_ + + gradient = np.zeros_like(self.components_) + sample_indices = np.arange(self.signal_length_, dtype=float) + for signal in range(self.n_signals_): + for comp in range(self.n_components_): + positions = sample_indices / stretch[comp, signal] + left_indices = np.floor(positions).astype(int) + alpha = positions - left_indices + right_indices = left_indices + 1 + scaled_residual = residuals[:, signal] * weights[comp, signal] + + left_mask = left_indices < self.signal_length_ + np.add.at( + gradient[:, comp], + left_indices[left_mask], + scaled_residual[left_mask] * (1.0 - alpha[left_mask]), + ) + + right_mask = right_indices < self.signal_length_ + np.add.at( + gradient[:, comp], + right_indices[right_mask], + scaled_residual[right_mask] * alpha[right_mask], + ) + + return gradient + def _solve_quadratic_program(self, t, m): """ Solves the quadratic program for updating y in stretched NMF: @@ -870,9 +940,14 @@ def _update_components(self): reshaped_stretched_components - self._source_matrix ) # Compute gradient - self._grad_components = self._apply_transformation_matrix( - residuals=component_residuals - ).toarray() # toarray equivalent of full, make non-sparse + if self._fill_tail_zero: + self._grad_components = self._compute_component_gradient_zero_tail( + residuals=component_residuals + ) + else: + self._grad_components = self._apply_transformation_matrix( + residuals=component_residuals + ).toarray() # toarray equivalent of full, make non-sparse # Compute initial step size `initial_step_size` initial_step_size = np.linalg.eigvalsh( @@ -916,10 +991,11 @@ def _update_components(self): ) self.components_ = mask * self.components_ - objective_improvement = self.objective_log[-1][ - "objective" - ] - self._get_objective_function( - residuals=self._get_residual_matrix() + objective_improvement = ( + self.objective_function_ + - self._get_objective_function( + residuals=self._get_residual_matrix() + ) ) # Check if objective function improves