Skip to content

Commit 6b20e22

Browse files
committed
fix: remove unneeded trait bounds and use f64::total_cmp
also adds a comment demonstrating success of unwrap
1 parent 7b4180c commit 6b20e22

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/distribution/multinomial.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,10 @@ where
188188
fn sample_generic<D, R, T>(p: &[f64], n: u64, dim: D, rng: &mut R) -> OVector<T, D>
189189
where
190190
D: Dim,
191-
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
192191
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<T, D>,
193192
R: ::rand::Rng + ?Sized,
194-
T: ::num_traits::Num
195-
+ ::nalgebra::Scalar
193+
T: nalgebra::Scalar
194+
+ num_traits::Zero
196195
+ num_traits::AsPrimitive<u64>
197196
+ num_traits::FromPrimitive,
198197
super::Binomial: rand::distributions::Distribution<T>,
@@ -203,9 +202,7 @@ where
203202
let mut samples_left = n;
204203

205204
let mut p_sorted_inds: Vec<_> = (0..p.len()).collect();
206-
207-
// unwrap because NAN elements not allowed from this struct's `new`
208-
p_sorted_inds.sort_unstable_by(|&i, &j| p[j].partial_cmp(&p[i]).unwrap());
205+
p_sorted_inds.sort_unstable_by(|&i, &j| p[j].total_cmp(&p[i]));
209206

210207
for ind in p_sorted_inds.into_iter().take(p.len() - 1) {
211208
let pi = p[ind];
@@ -215,6 +212,11 @@ where
215212
if !(0.0..=1.0).contains(&probs_not_taken) || samples_left == 0 {
216213
break;
217214
}
215+
// since $p_j \le p_i \forall j < i$ and $\vec{p}$ is normalized, then
216+
// $1 - sum(p_j, j, 0, i-1) = sum(p_j, j, i, n) = p_i + sum(p_j, j, i+1, n) > p_i$
217+
// this guarantees that logically p_binom on [0,1]
218+
// TODO: demonstrate that this behavior also behaves well with floating point
219+
// the logical reasoning provides that `unwrap` of Binomial::new will typically succeed
218220
let p_binom = pi / probs_not_taken;
219221
res[ind] = super::Binomial::new(p_binom, samples_left)
220222
.unwrap()

0 commit comments

Comments
 (0)