11//! Provides the functions related to [Chi-Squared tests](https://en.wikipedia.org/wiki/Chi-squared_test)
22
33use crate :: distribution:: { ChiSquared , ContinuousCDF } ;
4+ use crate :: prec;
45
56/// Represents the errors that can occur when computing the chisquare function
67#[ derive( Copy , Clone , PartialEq , Eq , Debug , Hash ) ]
@@ -70,25 +71,40 @@ pub fn chisquare(
7071 if n <= 1 {
7172 return Err ( ChiSquareTestError :: FObsInvalid ) ;
7273 }
73- let total_samples: usize = f_obs. iter ( ) . sum ( ) ;
74- let f_obs: Vec < f64 > = f_obs. iter ( ) . map ( |x| * x as f64 ) . collect ( ) ;
75-
76- let f_exp = match f_exp {
77- Some ( f_to_validate) => {
78- // same length check
79- if f_to_validate. len ( ) != n {
80- return Err ( ChiSquareTestError :: FExpInvalid ) ;
81- }
82- // same sum check
83- if f_to_validate. iter ( ) . sum :: < f64 > ( ) as usize != total_samples {
84- return Err ( ChiSquareTestError :: FExpInvalid ) ;
85- }
86- f_to_validate. to_vec ( )
74+ let stat = if let Some ( f_exp) = f_exp {
75+ if f_exp. len ( ) != n {
76+ return Err ( ChiSquareTestError :: FExpInvalid ) ;
8777 }
88- None => {
89- // make the expected assuming equal frequency
90- vec ! [ total_samples as f64 / n as f64 ; n]
78+
79+ let mut total_samples = 0.0 ;
80+ let mut sum_expected = 0.0 ;
81+
82+ let mut stat = 0.0 ;
83+
84+ for ( obs, exp) in f_obs. iter ( ) . zip ( f_exp) {
85+ let obs = * obs as f64 ;
86+
87+ stat += ( obs - exp) . powi ( 2 ) / exp;
88+
89+ total_samples += obs;
90+ sum_expected += exp;
91+ }
92+
93+ if !prec:: relative_eq!( total_samples, sum_expected) {
94+ return Err ( ChiSquareTestError :: FExpInvalid ) ;
9195 }
96+
97+ stat
98+ } else {
99+ let total_samples: usize = f_obs. iter ( ) . sum ( ) ;
100+ // Assume all frequencies are equally likely
101+ let exp = total_samples as f64 / n as f64 ;
102+
103+ f_obs
104+ . iter ( )
105+ . map ( |obs| * obs as f64 )
106+ . map ( |obs| ( obs - exp) . powi ( 2 ) / exp)
107+ . sum ( )
92108 } ;
93109
94110 let ddof = match ddof {
@@ -102,12 +118,6 @@ pub fn chisquare(
102118 } ;
103119 let dof = n - 1 - ddof;
104120
105- let stat = f_obs
106- . into_iter ( )
107- . zip ( f_exp)
108- . map ( |( o, e) | ( o - e) . powi ( 2 ) / e)
109- . sum :: < f64 > ( ) ;
110-
111121 let chi_dist = ChiSquared :: new ( dof as f64 ) . expect ( "ddof validity should already be checked" ) ;
112122 let pvalue = 1.0 - chi_dist. cdf ( stat) ;
113123
0 commit comments