diff --git a/Cargo.toml b/Cargo.toml index 34eacae..df7a50c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,11 @@ name = "vs_linalg" harness = false required-features = [ "bench" ] +[[bench]] +name = "exact" +harness = false +required-features = [ "bench", "exact" ] + [profile.release] lto = "fat" codegen-units = 1 diff --git a/benches/exact.rs b/benches/exact.rs new file mode 100644 index 0000000..3fd7846 --- /dev/null +++ b/benches/exact.rs @@ -0,0 +1,166 @@ +//! Benchmarks for exact arithmetic operations. +//! +//! These benchmarks measure the performance of the `exact` feature's +//! arbitrary-precision methods across dimensions D=2..5 (the primary +//! target for geometric predicates). + +use criterion::Criterion; +use la_stack::{Matrix, Vector}; +use pastey::paste; +use std::hint::black_box; + +#[inline] +#[allow(clippy::cast_precision_loss)] +const fn matrix_entry(r: usize, c: usize) -> f64 { + if r == c { + (r as f64).mul_add(1.0e-3, (D as f64) + 1.0) + } else { + 0.1 / ((r + c + 1) as f64) + } +} + +#[inline] +const fn make_matrix_rows() -> [[f64; D]; D] { + let mut rows = [[0.0; D]; D]; + let mut r = 0; + while r < D { + let mut c = 0; + while c < D { + rows[r][c] = matrix_entry::(r, c); + c += 1; + } + r += 1; + } + rows +} + +#[inline] +#[allow(clippy::cast_precision_loss)] +fn make_vector_array() -> [f64; D] { + let mut data = [0.0; D]; + let mut i = 0; + while i < D { + data[i] = (i as f64) + 1.0; + i += 1; + } + data +} + +/// Near-singular matrix: base singular matrix + tiny perturbation. +/// This forces the exact Bareiss fallback in `det_sign_exact` (the fast +/// f64 filter cannot resolve the sign). +#[inline] +fn near_singular_3x3() -> Matrix<3> { + let perturbation = f64::from_bits(0x3CD0_0000_0000_0000); // 2^-50 + Matrix::<3>::from_rows([ + [1.0 + perturbation, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ]) +} + +macro_rules! gen_exact_benches_for_dim { + ($c:expr, $d:literal) => { + paste! {{ + let a = Matrix::<$d>::from_rows(make_matrix_rows::<$d>()); + let rhs = Vector::<$d>::new(make_vector_array::<$d>()); + + let mut [] = ($c).benchmark_group(concat!("exact_d", stringify!($d))); + + // === f64 baselines === + [].bench_function("det", |bencher| { + bencher.iter(|| { + let det = black_box(a) + .det(la_stack::DEFAULT_PIVOT_TOL) + .expect("should not fail"); + black_box(det); + }); + }); + + [].bench_function("det_direct", |bencher| { + bencher.iter(|| { + let det = black_box(a).det_direct(); + black_box(det); + }); + }); + + // === det_exact (BigRational result) === + [].bench_function("det_exact", |bencher| { + bencher.iter(|| { + let det = black_box(a).det_exact().expect("should not fail"); + black_box(det); + }); + }); + + // === det_exact_f64 (exact → f64) === + [].bench_function("det_exact_f64", |bencher| { + bencher.iter(|| { + let det = black_box(a).det_exact_f64().expect("should not fail"); + black_box(det); + }); + }); + + // === det_sign_exact (adaptive: fast filter + exact fallback) === + [].bench_function("det_sign_exact", |bencher| { + bencher.iter(|| { + let sign = black_box(a).det_sign_exact().expect("should not fail"); + black_box(sign); + }); + }); + + // === solve_exact (BigRational result) === + [].bench_function("solve_exact", |bencher| { + bencher.iter(|| { + let x = black_box(a).solve_exact(black_box(rhs)).expect("should not fail"); + black_box(x); + }); + }); + + // === solve_exact_f64 (exact → f64) === + [].bench_function("solve_exact_f64", |bencher| { + bencher.iter(|| { + let x = black_box(a).solve_exact_f64(black_box(rhs)).expect("should not fail"); + black_box(x); + }); + }); + + [].finish(); + }}; + }; +} + +fn main() { + let mut c = Criterion::default().configure_from_args(); + + #[allow(unused_must_use)] + { + gen_exact_benches_for_dim!(&mut c, 2); + gen_exact_benches_for_dim!(&mut c, 3); + gen_exact_benches_for_dim!(&mut c, 4); + gen_exact_benches_for_dim!(&mut c, 5); + } + + // Near-singular 3×3: forces Bareiss fallback in det_sign_exact. + { + let m = near_singular_3x3(); + let mut group = c.benchmark_group("exact_near_singular_3x3"); + + group.bench_function("det_sign_exact", |bencher| { + bencher.iter(|| { + let sign = black_box(m).det_sign_exact().expect("should not fail"); + black_box(sign); + }); + }); + + group.bench_function("det_exact", |bencher| { + bencher.iter(|| { + let det = black_box(m).det_exact().expect("should not fail"); + black_box(det); + }); + }); + + group.finish(); + } + + c.final_summary(); +} diff --git a/src/exact.rs b/src/exact.rs index 9f0c28d..ce164db 100644 --- a/src/exact.rs +++ b/src/exact.rs @@ -44,7 +44,10 @@ fn validate_finite(m: &Matrix) -> Result<(), LaError> { for r in 0..D { for c in 0..D { if !m.rows[r][c].is_finite() { - return Err(LaError::NonFinite { col: c }); + return Err(LaError::NonFinite { + row: Some(r), + col: c, + }); } } } @@ -58,7 +61,7 @@ fn validate_finite(m: &Matrix) -> Result<(), LaError> { fn validate_finite_vec(v: &Vector) -> Result<(), LaError> { for (i, &x) in v.data.iter().enumerate() { if !x.is_finite() { - return Err(LaError::NonFinite { col: i }); + return Err(LaError::NonFinite { row: None, col: i }); } } Ok(()) @@ -88,14 +91,8 @@ fn bareiss_det(m: &Matrix) -> BigRational { } // Convert f64 entries to exact BigRational. - let mut a: Vec> = Vec::with_capacity(D); - for r in 0..D { - let mut row = Vec::with_capacity(D); - for c in 0..D { - row.push(f64_to_bigrational(m.rows[r][c])); - } - a.push(row); - } + let mut a: [[BigRational; D]; D] = + std::array::from_fn(|r| std::array::from_fn(|c| f64_to_bigrational(m.rows[r][c]))); let zero = BigRational::from_integer(BigInt::from(0)); let mut prev_pivot = BigRational::from_integer(BigInt::from(1)); @@ -144,62 +141,56 @@ fn bareiss_det(m: &Matrix) -> BigRational { /// (no numerical stability concern). This matches the pivoting strategy used /// by `bareiss_det`. /// -/// Returns the exact solution as a `Vec` of length `D`. +/// Returns the exact solution as `[BigRational; D]`. /// Returns `Err(LaError::Singular)` if the matrix is exactly singular. -fn gauss_solve(m: &Matrix, b: &Vector) -> Result, LaError> { - if D == 0 { - return Ok(Vec::new()); - } - +fn gauss_solve(m: &Matrix, b: &Vector) -> Result<[BigRational; D], LaError> { let zero = BigRational::from_integer(BigInt::from(0)); - // Build augmented matrix [A | b] as D × (D+1). - let mut aug: Vec> = Vec::with_capacity(D); - for r in 0..D { - let mut row = Vec::with_capacity(D + 1); - for c in 0..D { - row.push(f64_to_bigrational(m.rows[r][c])); - } - row.push(f64_to_bigrational(b.data[r])); - aug.push(row); - } + // Build matrix and RHS separately (cannot use [BigRational; D+1] augmented + // columns because const-generic expressions are unstable). + let mut mat: [[BigRational; D]; D] = + std::array::from_fn(|r| std::array::from_fn(|c| f64_to_bigrational(m.rows[r][c]))); + let mut rhs: [BigRational; D] = std::array::from_fn(|r| f64_to_bigrational(b.data[r])); // Forward elimination with first-non-zero pivoting. for k in 0..D { // Find first non-zero pivot in column k at or below row k. - if aug[k][k] == zero { - if let Some(swap_row) = ((k + 1)..D).find(|&i| aug[i][k] != zero) { - aug.swap(k, swap_row); + if mat[k][k] == zero { + if let Some(swap_row) = ((k + 1)..D).find(|&i| mat[i][k] != zero) { + mat.swap(k, swap_row); + rhs.swap(k, swap_row); } else { return Err(LaError::Singular { pivot_col: k }); } } // Eliminate below pivot. - let pivot = aug[k][k].clone(); + let pivot = mat[k][k].clone(); for i in (k + 1)..D { - if aug[i][k] != zero { - let factor = &aug[i][k] / &pivot; - // We need index `j` to read aug[k][j] and write aug[i][j] + if mat[i][k] != zero { + let factor = &mat[i][k] / &pivot; + // We need index `j` to read mat[k][j] and write mat[i][j] // (two distinct rows) — iterators can't borrow both. #[allow(clippy::needless_range_loop)] - for j in (k + 1)..=D { - let term = &factor * &aug[k][j]; - aug[i][j] -= term; + for j in (k + 1)..D { + let term = &factor * &mat[k][j]; + mat[i][j] -= term; } - aug[i][k] = zero.clone(); + let rhs_term = &factor * &rhs[k]; + rhs[i] -= rhs_term; + mat[i][k] = zero.clone(); } } } // Back-substitution. - let mut x: Vec = vec![zero; D]; + let mut x: [BigRational; D] = std::array::from_fn(|_| zero.clone()); for i in (0..D).rev() { - let mut sum = aug[i][D].clone(); + let mut sum = rhs[i].clone(); for j in (i + 1)..D { - sum -= &aug[i][j] * &x[j]; + sum -= &mat[i][j] * &x[j]; } - x[i] = sum / &aug[i][i]; + x[i] = sum / &mat[i][i]; } Ok(x) @@ -264,7 +255,7 @@ impl Matrix { if val.is_finite() { Ok(val) } else { - Err(LaError::Overflow) + Err(LaError::Overflow { index: None }) } } @@ -301,8 +292,7 @@ impl Matrix { pub fn solve_exact(&self, b: Vector) -> Result<[BigRational; D], LaError> { validate_finite(self)?; validate_finite_vec(&b)?; - let solution = gauss_solve(self, &b)?; - Ok(std::array::from_fn(|i| solution[i].clone())) + gauss_solve(self, &b) } /// Exact linear system solve converted to `f64`. @@ -336,7 +326,7 @@ impl Matrix { for (i, val) in exact.iter().enumerate() { let f = val.to_f64().unwrap_or(f64::INFINITY); if !f.is_finite() { - return Err(LaError::Overflow); + return Err(LaError::Overflow { index: Some(i) }); } result[i] = f; } @@ -431,14 +421,14 @@ mod tests { fn []() { let mut m = Matrix::<$d>::identity(); m.set(0, 0, f64::NAN); - assert_eq!(m.det_exact(), Err(LaError::NonFinite { col: 0 })); + assert_eq!(m.det_exact(), Err(LaError::NonFinite { row: Some(0), col: 0 })); } #[test] fn []() { let mut m = Matrix::<$d>::identity(); m.set(0, 0, f64::INFINITY); - assert_eq!(m.det_exact(), Err(LaError::NonFinite { col: 0 })); + assert_eq!(m.det_exact(), Err(LaError::NonFinite { row: Some(0), col: 0 })); } } }; @@ -462,7 +452,7 @@ mod tests { fn []() { let mut m = Matrix::<$d>::identity(); m.set(0, 0, f64::NAN); - assert_eq!(m.det_exact_f64(), Err(LaError::NonFinite { col: 0 })); + assert_eq!(m.det_exact_f64(), Err(LaError::NonFinite { row: Some(0), col: 0 })); } } }; @@ -647,13 +637,25 @@ mod tests { #[test] fn det_sign_exact_returns_err_on_nan() { let m = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]); - assert_eq!(m.det_sign_exact(), Err(LaError::NonFinite { col: 0 })); + assert_eq!( + m.det_sign_exact(), + Err(LaError::NonFinite { + row: Some(0), + col: 0 + }) + ); } #[test] fn det_sign_exact_returns_err_on_infinity() { let m = Matrix::<2>::from_rows([[f64::INFINITY, 0.0], [0.0, 1.0]]); - assert_eq!(m.det_sign_exact(), Err(LaError::NonFinite { col: 0 })); + assert_eq!( + m.det_sign_exact(), + Err(LaError::NonFinite { + row: Some(0), + col: 0 + }) + ); } #[test] @@ -661,14 +663,26 @@ mod tests { // D ≥ 5 bypasses the fast filter, exercising the bareiss_det path. let mut m = Matrix::<5>::identity(); m.set(2, 3, f64::NAN); - assert_eq!(m.det_sign_exact(), Err(LaError::NonFinite { col: 3 })); + assert_eq!( + m.det_sign_exact(), + Err(LaError::NonFinite { + row: Some(2), + col: 3 + }) + ); } #[test] fn det_sign_exact_returns_err_on_infinity_5x5() { let mut m = Matrix::<5>::identity(); m.set(0, 0, f64::INFINITY); - assert_eq!(m.det_sign_exact(), Err(LaError::NonFinite { col: 0 })); + assert_eq!( + m.det_sign_exact(), + Err(LaError::NonFinite { + row: Some(0), + col: 0 + }) + ); } #[test] @@ -875,7 +889,7 @@ mod tests { let big = f64::MAX / 2.0; let m = Matrix::<3>::from_rows([[0.0, 0.0, 1.0], [big, 0.0, 1.0], [0.0, big, 1.0]]); // det = big^2, which overflows f64. - assert_eq!(m.det_exact_f64(), Err(LaError::Overflow)); + assert_eq!(m.det_exact_f64(), Err(LaError::Overflow { index: None })); } // ----------------------------------------------------------------------- @@ -910,7 +924,7 @@ mod tests { let mut a = Matrix::<$d>::identity(); a.set(0, 0, f64::NAN); let b = arbitrary_rhs::<$d>(); - assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { col: 0 })); + assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { row: Some(0), col: 0 })); } #[test] @@ -918,7 +932,7 @@ mod tests { let mut a = Matrix::<$d>::identity(); a.set(0, 0, f64::INFINITY); let b = arbitrary_rhs::<$d>(); - assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { col: 0 })); + assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { row: Some(0), col: 0 })); } #[test] @@ -927,7 +941,7 @@ mod tests { let mut b_arr = [1.0f64; $d]; b_arr[0] = f64::NAN; let b = Vector::<$d>::new(b_arr); - assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { col: 0 })); + assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { row: None, col: 0 })); } #[test] @@ -936,7 +950,7 @@ mod tests { let mut b_arr = [1.0f64; $d]; b_arr[$d - 1] = f64::INFINITY; let b = Vector::<$d>::new(b_arr); - assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { col: $d - 1 })); + assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { row: None, col: $d - 1 })); } #[test] @@ -973,7 +987,7 @@ mod tests { let mut a = Matrix::<$d>::identity(); a.set(0, 0, f64::NAN); let b = arbitrary_rhs::<$d>(); - assert_eq!(a.solve_exact_f64(b), Err(LaError::NonFinite { col: 0 })); + assert_eq!(a.solve_exact_f64(b), Err(LaError::NonFinite { row: Some(0), col: 0 })); } } }; @@ -1112,7 +1126,10 @@ mod tests { let big = f64::MAX / 2.0; let a = Matrix::<2>::from_rows([[1.0 / big, 0.0], [0.0, 1.0 / big]]); let b = Vector::<2>::new([big, big]); - assert_eq!(a.solve_exact_f64(b), Err(LaError::Overflow)); + assert_eq!( + a.solve_exact_f64(b), + Err(LaError::Overflow { index: Some(0) }) + ); } // ----------------------------------------------------------------------- @@ -1154,7 +1171,7 @@ mod tests { fn validate_finite_vec_err_on_nan() { assert_eq!( validate_finite_vec(&Vector::<2>::new([f64::NAN, 1.0])), - Err(LaError::NonFinite { col: 0 }) + Err(LaError::NonFinite { row: None, col: 0 }) ); } @@ -1162,7 +1179,7 @@ mod tests { fn validate_finite_vec_err_on_inf() { assert_eq!( validate_finite_vec(&Vector::<2>::new([1.0, f64::NEG_INFINITY])), - Err(LaError::NonFinite { col: 1 }) + Err(LaError::NonFinite { row: None, col: 1 }) ); } @@ -1179,13 +1196,25 @@ mod tests { fn validate_finite_err_on_nan() { let mut m = Matrix::<2>::identity(); m.set(1, 0, f64::NAN); - assert_eq!(validate_finite(&m), Err(LaError::NonFinite { col: 0 })); + assert_eq!( + validate_finite(&m), + Err(LaError::NonFinite { + row: Some(1), + col: 0 + }) + ); } #[test] fn validate_finite_err_on_inf() { let mut m = Matrix::<2>::identity(); m.set(0, 1, f64::NEG_INFINITY); - assert_eq!(validate_finite(&m), Err(LaError::NonFinite { col: 1 })); + assert_eq!( + validate_finite(&m), + Err(LaError::NonFinite { + row: Some(0), + col: 1 + }) + ); } } diff --git a/src/ldlt.rs b/src/ldlt.rs index e8de3e0..0dfc028 100644 --- a/src/ldlt.rs +++ b/src/ldlt.rs @@ -39,7 +39,10 @@ impl Ldlt { for j in 0..D { let d = f.rows[j][j]; if !d.is_finite() { - return Err(LaError::NonFinite { col: j }); + return Err(LaError::NonFinite { + row: Some(j), + col: j, + }); } if d <= tol { return Err(LaError::Singular { pivot_col: j }); @@ -49,7 +52,10 @@ impl Ldlt { for i in (j + 1)..D { let l = f.rows[i][j] / d; if !l.is_finite() { - return Err(LaError::NonFinite { col: j }); + return Err(LaError::NonFinite { + row: Some(i), + col: j, + }); } f.rows[i][j] = l; } @@ -63,7 +69,10 @@ impl Ldlt { let l_k = f.rows[k][j]; let new_val = (-l_i_d).mul_add(l_k, f.rows[i][k]); if !new_val.is_finite() { - return Err(LaError::NonFinite { col: j }); + return Err(LaError::NonFinite { + row: Some(i), + col: k, + }); } f.rows[i][k] = new_val; } @@ -132,7 +141,7 @@ impl Ldlt { sum = (-row[j]).mul_add(*x_j, sum); } if !sum.is_finite() { - return Err(LaError::NonFinite { col: i }); + return Err(LaError::NonFinite { row: None, col: i }); } x[i] = sum; } @@ -141,7 +150,7 @@ impl Ldlt { for (i, x_i) in x.iter_mut().enumerate().take(D) { let diag = self.factors.rows[i][i]; if !diag.is_finite() { - return Err(LaError::NonFinite { col: i }); + return Err(LaError::NonFinite { row: None, col: i }); } if diag <= self.tol { return Err(LaError::Singular { pivot_col: i }); @@ -149,7 +158,7 @@ impl Ldlt { let v = *x_i / diag; if !v.is_finite() { - return Err(LaError::NonFinite { col: i }); + return Err(LaError::NonFinite { row: None, col: i }); } *x_i = v; } @@ -162,7 +171,7 @@ impl Ldlt { sum = (-self.factors.rows[j][i]).mul_add(*x_j, sum); } if !sum.is_finite() { - return Err(LaError::NonFinite { col: i }); + return Err(LaError::NonFinite { row: None, col: i }); } x[i] = sum; } @@ -331,7 +340,13 @@ mod tests { fn nonfinite_detected() { let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]); let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(); - assert_eq!(err, LaError::NonFinite { col: 0 }); + assert_eq!( + err, + LaError::NonFinite { + row: Some(0), + col: 0 + } + ); } #[test] @@ -339,7 +354,13 @@ mod tests { // d = 1e-11 > tol, but l = 1e300 / 1e-11 = 1e311 overflows f64. let a = Matrix::<2>::from_rows([[1e-11, 1e300], [1e300, 1.0]]); let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(); - assert_eq!(err, LaError::NonFinite { col: 0 }); + assert_eq!( + err, + LaError::NonFinite { + row: Some(1), + col: 0 + } + ); } #[test] @@ -348,7 +369,13 @@ mod tests { // (-1e200 * 1.0) * 1e200 + 1.0 overflows. let a = Matrix::<2>::from_rows([[1.0, 1e200], [1e200, 1.0]]); let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(); - assert_eq!(err, LaError::NonFinite { col: 0 }); + assert_eq!( + err, + LaError::NonFinite { + row: Some(1), + col: 1 + } + ); } #[test] @@ -364,6 +391,20 @@ mod tests { let b = Vector::<3>::new([1e156, 0.0, 0.0]); let err = ldlt.solve_vec(b).unwrap_err(); - assert_eq!(err, LaError::NonFinite { col: 1 }); + assert_eq!(err, LaError::NonFinite { row: None, col: 1 }); + } + + #[test] + fn nonfinite_solve_vec_back_substitution_overflow() { + // SPD matrix: [[1,0,0],[0,1,2],[0,2,5]] has LDLT factors + // D=[1,1,1], L[2,1]=2. Forward sub and diagonal solve produce + // z=[0,0,1e308]. Back-substitution: x[2]=1e308 then + // x[1] = 0 - 2*1e308 = -inf (overflows f64). + let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 2.0], [0.0, 2.0, 5.0]]); + let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); + + let b = Vector::<3>::new([0.0, 0.0, 1e308]); + let err = ldlt.solve_vec(b).unwrap_err(); + assert_eq!(err, LaError::NonFinite { row: None, col: 1 }); } } diff --git a/src/lib.rs b/src/lib.rs index 8c1583b..ae0ea84 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -118,19 +118,25 @@ pub enum LaError { /// The factorization column/step where a suitable pivot/diagonal could not be found. pivot_col: usize, }, - /// A non-finite value (NaN/∞) was encountered in the input. + /// A non-finite value (NaN/∞) was encountered. NonFinite { - /// The column where a non-finite value was detected. + /// Row of the non-finite entry (for matrix inputs), or `None` when + /// the error originates from a vector input or a computed intermediate. + row: Option, + /// Column index (for matrix inputs), vector index, or factorization + /// step where the non-finite value was detected. col: usize, }, /// The exact result overflows the target representation (e.g. `f64`). /// - /// This is returned by `Matrix::det_exact_f64` (requires `exact` feature) - /// when the exact `BigRational` determinant is too large to represent as - /// a finite `f64`. - /// - /// *Added in 0.3.0.* - Overflow, + /// Returned by `Matrix::det_exact_f64` and `Matrix::solve_exact_f64` + /// (requires `exact` feature) when an exact value is too large to + /// represent as a finite `f64`. + Overflow { + /// For vector results (e.g. `solve_exact_f64`), the index of the + /// component that overflowed. `None` for scalar results. + index: Option, + }, } impl fmt::Display for LaError { @@ -139,10 +145,19 @@ impl fmt::Display for LaError { Self::Singular { pivot_col } => { write!(f, "singular matrix at pivot column {pivot_col}") } - Self::NonFinite { col } => { - write!(f, "non-finite value encountered at column {col}") + Self::NonFinite { row: Some(r), col } => { + write!(f, "non-finite value at ({r}, {col})") + } + Self::NonFinite { row: None, col } => { + write!(f, "non-finite value at index {col}") + } + Self::Overflow { index: Some(i) } => { + write!( + f, + "exact result overflows the target representation at index {i}" + ) } - Self::Overflow => { + Self::Overflow { index: None } => { write!(f, "exact result overflows the target representation") } } @@ -193,20 +208,38 @@ mod tests { } #[test] - fn laerror_display_formats_nonfinite() { - let err = LaError::NonFinite { col: 2 }; - assert_eq!(err.to_string(), "non-finite value encountered at column 2"); + fn laerror_display_formats_nonfinite_with_row() { + let err = LaError::NonFinite { + row: Some(1), + col: 2, + }; + assert_eq!(err.to_string(), "non-finite value at (1, 2)"); + } + + #[test] + fn laerror_display_formats_nonfinite_without_row() { + let err = LaError::NonFinite { row: None, col: 3 }; + assert_eq!(err.to_string(), "non-finite value at index 3"); } #[test] fn laerror_display_formats_overflow() { - let err = LaError::Overflow; + let err = LaError::Overflow { index: None }; assert_eq!( err.to_string(), "exact result overflows the target representation" ); } + #[test] + fn laerror_display_formats_overflow_with_index() { + let err = LaError::Overflow { index: Some(2) }; + assert_eq!( + err.to_string(), + "exact result overflows the target representation at index 2" + ); + } + #[test] fn laerror_is_std_error_with_no_source() { let err = LaError::Singular { pivot_col: 0 }; diff --git a/src/lu.rs b/src/lu.rs index 34797be..5128bab 100644 --- a/src/lu.rs +++ b/src/lu.rs @@ -31,13 +31,19 @@ impl Lu { let mut pivot_row = k; let mut pivot_abs = lu.rows[k][k].abs(); if !pivot_abs.is_finite() { - return Err(LaError::NonFinite { col: k }); + return Err(LaError::NonFinite { + row: Some(k), + col: k, + }); } for r in (k + 1)..D { let v = lu.rows[r][k].abs(); if !v.is_finite() { - return Err(LaError::NonFinite { col: k }); + return Err(LaError::NonFinite { + row: Some(r), + col: k, + }); } if v > pivot_abs { pivot_abs = v; @@ -57,14 +63,20 @@ impl Lu { let pivot = lu.rows[k][k]; if !pivot.is_finite() { - return Err(LaError::NonFinite { col: k }); + return Err(LaError::NonFinite { + row: Some(k), + col: k, + }); } // Eliminate below pivot. for r in (k + 1)..D { let mult = lu.rows[r][k] / pivot; if !mult.is_finite() { - return Err(LaError::NonFinite { col: k }); + return Err(LaError::NonFinite { + row: Some(r), + col: k, + }); } lu.rows[r][k] = mult; @@ -120,7 +132,7 @@ impl Lu { sum = (-row[j]).mul_add(*x_j, sum); } if !sum.is_finite() { - return Err(LaError::NonFinite { col: i }); + return Err(LaError::NonFinite { row: None, col: i }); } x[i] = sum; } @@ -136,13 +148,17 @@ impl Lu { let diag = row[i]; if !diag.is_finite() || !sum.is_finite() { - return Err(LaError::NonFinite { col: i }); + return Err(LaError::NonFinite { row: None, col: i }); } if diag.abs() <= self.tol { return Err(LaError::Singular { pivot_col: i }); } - x[i] = sum / diag; + let q = sum / diag; + if !q.is_finite() { + return Err(LaError::NonFinite { row: None, col: i }); + } + x[i] = q; } Ok(Vector::new(x)) @@ -415,14 +431,26 @@ mod tests { fn nonfinite_detected_on_pivot_entry() { let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]); let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err(); - assert_eq!(err, LaError::NonFinite { col: 0 }); + assert_eq!( + err, + LaError::NonFinite { + row: Some(0), + col: 0 + } + ); } #[test] fn nonfinite_detected_in_pivot_column_scan() { let a = Matrix::<2>::from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]); let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err(); - assert_eq!(err, LaError::NonFinite { col: 0 }); + assert_eq!( + err, + LaError::NonFinite { + row: Some(1), + col: 0 + } + ); } #[test] @@ -433,7 +461,7 @@ mod tests { let b = Vector::<3>::new([1.0e308, 1.0e308, 0.0]); let err = lu.solve_vec(b).unwrap_err(); - assert_eq!(err, LaError::NonFinite { col: 1 }); + assert_eq!(err, LaError::NonFinite { row: None, col: 1 }); } #[test] @@ -444,6 +472,6 @@ mod tests { let b = Vector::<2>::new([0.0, 1.0e300]); let err = lu.solve_vec(b).unwrap_err(); - assert_eq!(err, LaError::NonFinite { col: 0 }); + assert_eq!(err, LaError::NonFinite { row: None, col: 1 }); } } diff --git a/src/matrix.rs b/src/matrix.rs index c6bc0a5..85da5a3 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -296,7 +296,19 @@ impl Matrix { return if d.is_finite() { Ok(d) } else { - Err(LaError::NonFinite { col: 0 }) + // Scan for the first non-finite entry to preserve coordinates. + for r in 0..D { + for c in 0..D { + if !self.rows[r][c].is_finite() { + return Err(LaError::NonFinite { + row: Some(r), + col: c, + }); + } + } + } + // All entries are finite but the determinant overflowed. + Err(LaError::NonFinite { row: None, col: 0 }) }; } self.lu(tol).map(|lu| lu.det()) @@ -669,14 +681,26 @@ mod tests { #[test] fn det_returns_nonfinite_error_for_nan_d2() { let m = Matrix::<2>::from_rows([[f64::NAN, 1.0], [1.0, 1.0]]); - assert_eq!(m.det(DEFAULT_PIVOT_TOL), Err(LaError::NonFinite { col: 0 })); + assert_eq!( + m.det(DEFAULT_PIVOT_TOL), + Err(LaError::NonFinite { + row: Some(0), + col: 0 + }) + ); } #[test] fn det_returns_nonfinite_error_for_inf_d3() { let m = Matrix::<3>::from_rows([[f64::INFINITY, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]); - assert_eq!(m.det(DEFAULT_PIVOT_TOL), Err(LaError::NonFinite { col: 0 })); + assert_eq!( + m.det(DEFAULT_PIVOT_TOL), + Err(LaError::NonFinite { + row: Some(0), + col: 0 + }) + ); } #[test]