Skip to content

Commit 013bb58

Browse files
authored
Merge pull request #336 from FreezyLemon/optimize-chisquare-test
perf: Various improvements to chisquare
2 parents 2a21f02 + 0af8efb commit 013bb58

File tree

1 file changed

+33
-23
lines changed

1 file changed

+33
-23
lines changed

src/stats_tests/chisquare.rs

Lines changed: 33 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
//! Provides the functions related to [Chi-Squared tests](https://en.wikipedia.org/wiki/Chi-squared_test)
22
33
use crate::distribution::{ChiSquared, ContinuousCDF};
4+
use crate::prec;
45

56
/// Represents the errors that can occur when computing the chisquare function
67
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
@@ -70,25 +71,40 @@ pub fn chisquare(
7071
if n <= 1 {
7172
return Err(ChiSquareTestError::FObsInvalid);
7273
}
73-
let total_samples: usize = f_obs.iter().sum();
74-
let f_obs: Vec<f64> = f_obs.iter().map(|x| *x as f64).collect();
75-
76-
let f_exp = match f_exp {
77-
Some(f_to_validate) => {
78-
// same length check
79-
if f_to_validate.len() != n {
80-
return Err(ChiSquareTestError::FExpInvalid);
81-
}
82-
// same sum check
83-
if f_to_validate.iter().sum::<f64>() as usize != total_samples {
84-
return Err(ChiSquareTestError::FExpInvalid);
85-
}
86-
f_to_validate.to_vec()
74+
let stat = if let Some(f_exp) = f_exp {
75+
if f_exp.len() != n {
76+
return Err(ChiSquareTestError::FExpInvalid);
8777
}
88-
None => {
89-
// make the expected assuming equal frequency
90-
vec![total_samples as f64 / n as f64; n]
78+
79+
let mut total_samples = 0.0;
80+
let mut sum_expected = 0.0;
81+
82+
let mut stat = 0.0;
83+
84+
for (obs, exp) in f_obs.iter().zip(f_exp) {
85+
let obs = *obs as f64;
86+
87+
stat += (obs - exp).powi(2) / exp;
88+
89+
total_samples += obs;
90+
sum_expected += exp;
91+
}
92+
93+
if !prec::relative_eq!(total_samples, sum_expected) {
94+
return Err(ChiSquareTestError::FExpInvalid);
9195
}
96+
97+
stat
98+
} else {
99+
let total_samples: usize = f_obs.iter().sum();
100+
// Assume all frequencies are equally likely
101+
let exp = total_samples as f64 / n as f64;
102+
103+
f_obs
104+
.iter()
105+
.map(|obs| *obs as f64)
106+
.map(|obs| (obs - exp).powi(2) / exp)
107+
.sum()
92108
};
93109

94110
let ddof = match ddof {
@@ -102,12 +118,6 @@ pub fn chisquare(
102118
};
103119
let dof = n - 1 - ddof;
104120

105-
let stat = f_obs
106-
.into_iter()
107-
.zip(f_exp)
108-
.map(|(o, e)| (o - e).powi(2) / e)
109-
.sum::<f64>();
110-
111121
let chi_dist = ChiSquared::new(dof as f64).expect("ddof validity should already be checked");
112122
let pvalue = 1.0 - chi_dist.cdf(stat);
113123

0 commit comments

Comments
 (0)