diff --git a/src/primitives.rs b/src/primitives.rs index 5d79a5f6c..d1ace2dfe 100644 --- a/src/primitives.rs +++ b/src/primitives.rs @@ -66,8 +66,18 @@ pub(crate) const fn mac(a: Word, b: Word, c: Word, carry: Word) -> (Word, Word) (ret as Word, (ret >> Word::BITS) as Word) } -/// Computes `(a * b) % d`. +/// Computes `(a * b) % d`. Not constant time in `d`. #[inline(always)] -pub(crate) const fn mul_rem(a: Word, b: Word, d: Word) -> Word { +pub(crate) const fn mul_rem_vartime(a: Word, b: Word, d: Word) -> Word { ((a as WideWord * b as WideWord) % (d as WideWord)) as Word } + +/// Returns `(hi * 2^Word::BITS + lo) / d`. +/// Not constant-time in `d`. +/// Assumes that the result fits in `Word`. +#[inline(always)] +pub(crate) const fn div_wide_vartime(hi: Word, lo: Word, d: Word) -> Word { + let q = (((hi as WideWord) << Word::BITS) + (lo as WideWord)) / (d as WideWord); + debug_assert!(q <= Word::MAX as WideWord); + q as Word +} diff --git a/src/uint/boxed/mul_mod.rs b/src/uint/boxed/mul_mod.rs index 8c6b7ee15..65b08b93b 100644 --- a/src/uint/boxed/mul_mod.rs +++ b/src/uint/boxed/mul_mod.rs @@ -2,7 +2,7 @@ use crate::{ modular::{BoxedMontyForm, BoxedMontyParams}, - primitives::mul_rem, + primitives::mul_rem_vartime, BoxedUint, Limb, MulMod, Odd, WideWord, Word, }; @@ -42,7 +42,7 @@ impl BoxedUint { // We implicitly assume `LIMBS > 0`, because `Uint<0>` doesn't compile. // Still the case `LIMBS == 1` needs special handling. if self.nlimbs() == 1 { - let reduced = mul_rem(self.limbs[0].0, rhs.limbs[0].0, Word::MIN.wrapping_sub(c.0)); + let reduced = mul_rem_vartime(self.limbs[0].0, rhs.limbs[0].0, Word::MIN.wrapping_sub(c.0)); return Self::from(reduced); } diff --git a/src/uint/div.rs b/src/uint/div.rs index d650235d8..f090843af 100644 --- a/src/uint/div.rs +++ b/src/uint/div.rs @@ -1,7 +1,7 @@ //! [`Uint`] division operations. use super::div_limb::{div_rem_limb_with_reciprocal, Reciprocal}; -use crate::{CheckedDiv, ConstChoice, Limb, NonZero, Uint, Word, Wrapping}; +use crate::{CheckedDiv, ConstChoice, Limb, NonZero, Uint, Word, Wrapping, primitives::div_wide_vartime}; use core::ops::{Div, DivAssign, Rem, RemAssign}; use subtle::CtOption; @@ -58,32 +58,72 @@ impl Uint { /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. #[allow(trivial_numeric_casts)] - pub const fn div_rem_vartime(&self, rhs: &NonZero) -> (Self, Self) { - let mb = rhs.0.bits_vartime(); - let mut bd = Self::BITS - mb; - let mut rem = *self; - let mut quo = Self::ZERO; - // If there is overflow, it means `mb == 0`, so `rhs == 0`. - let mut c = rhs.0.wrapping_shl_vartime(bd); + pub fn div_rem_vartime(&self, rhs: &NonZero) -> (Self, Self) { + let rhs = &rhs.0; + let n = ((self.bits_vartime() + Limb::BITS - 1) / Limb::BITS) as usize; + let t = ((rhs.bits_vartime() + Limb::BITS - 1) / Limb::BITS) as usize; - loop { - let (mut r, borrow) = rem.sbb(&c, Limb::ZERO); - rem = Self::select(&r, &rem, ConstChoice::from_word_mask(borrow.0)); - r = quo.bitor(&Self::ONE); - quo = Self::select(&r, &quo, ConstChoice::from_word_mask(borrow.0)); - if bd == 0 { - break; + extern crate std; + + if n < t { + return (Self::ZERO, *self); + } + + let mut x = *self; + let y = *rhs; + let mut q = Self::ZERO; + + let shift = (n - t) as u32 * Limb::BITS; + std::println!("{}", x); + std::println!("{}", y.wrapping_shl_vartime(shift)); + while x >= y.wrapping_shl_vartime(shift) { + q.limbs[n-t].0 += 1; + x = x - y.wrapping_shl_vartime(shift); + } + + std::println!("shifted"); + + for i in (t+1..=n).rev() { + if x.limbs[i] == y.limbs[t] { + q.limbs[i-t-1] = Limb::MAX; + } + else { + q.limbs[i-t-1] = Limb(div_wide_vartime(x.limbs[i].0, x.limbs[i-1].0, y.limbs[t].0)); + } + + loop { + let (limb2, limb1_1) = q.limbs[i-t-1].mul_wide(y.limbs[t]); + let (limb1_2, limb0) = q.limbs[i-t-1].mul_wide(y.limbs[t-1]); + let (limb1, carry) = limb1_1.overflowing_add(limb1_2); + let limb2 = limb2 + carry; + + let u1 = Uint::<3>::new([limb0, limb1, limb2]); + let u2 = Uint::<3>::new([x.limbs[i-2], x.limbs[i-1], x.limbs[i]]); + + if u1 <= u2 { + break; + } + + q.limbs[i-t-1].0 -= 1; + } + + let shift = (i - t - 1) as u32 * Limb::BITS; + let (yq, _) = y.mul_limb_vartime(q.limbs[i-t-1]); + let (new_x, borrow) = x.sbb(&yq.wrapping_shl_vartime(shift), Limb::ZERO); + if borrow != Limb::ZERO { + x = new_x.wrapping_add(&y.wrapping_shl_vartime(shift)); + q.limbs[i-t-1].0 -= 1; + } + else { + x = new_x; } - bd -= 1; - c = c.shr1(); - quo = quo.shl1(); } - (quo, rem) + (q, x) } /// Computes `self` % `rhs`, returns the remainder. - pub const fn rem(&self, rhs: &NonZero) -> Self { + pub fn rem(&self, rhs: &NonZero) -> Self { self.div_rem_vartime(rhs).1 } @@ -184,7 +224,7 @@ impl Uint { /// /// There’s no way wrapping could ever happen. /// This function exists, so that all operations are accounted for in the wrapping operations. - pub const fn wrapping_div_vartime(&self, rhs: &NonZero) -> Self { + pub fn wrapping_div_vartime(&self, rhs: &NonZero) -> Self { let (q, _) = self.div_rem_vartime(rhs); q } @@ -639,6 +679,22 @@ mod tests { } } + #[test] + fn div_rem_vartime2() { + let a = + U256::from_be_hex("0001630DA85CF425FC0DC8C5D061F3CCB6857377C47B4C4453C3B1802C8EB9D4"); + let b = + U256::from_be_hex("0000000000000000000000000000000371504F1CD33B86D7C773D1754918AE29"); + + let (q, r) = a.div_rem_vartime(&NonZero::new(b).unwrap()); + + assert_eq!(q, + U256::from_be_hex("000000000000000000000000000000000000672260662b797714a02d29017bce")); + assert_eq!(r, + U256::from_be_hex("000000000000000000000000000000030b36e3a65057473ca45f50ce3fdbe1d6")); + + } + #[test] fn div_max() { let mut a = U256::ZERO; diff --git a/src/uint/mul.rs b/src/uint/mul.rs index fb3175e6e..fe2250b20 100644 --- a/src/uint/mul.rs +++ b/src/uint/mul.rs @@ -54,6 +54,18 @@ macro_rules! impl_schoolbook_multiplication { } impl Uint { + pub(crate) fn mul_limb_vartime(&self, rhs: Limb) -> (Self, Limb) { + let mut carry = Limb::ZERO; + let mut result = Self::ZERO; + let limbs = ((self.bits_vartime() + Limb::BITS - 1) / Limb::BITS) as usize; + for i in 0..limbs { + let (x, new_carry) = self.limbs[i].mul_wide(rhs); + result.limbs[i] = x + carry; + carry = new_carry; + } + (result, carry) + } + /// Multiply `self` by `rhs`, returning a concatenated "wide" result. pub fn widening_mul( &self, diff --git a/src/uint/mul_mod.rs b/src/uint/mul_mod.rs index bbfd38097..02ae4c893 100644 --- a/src/uint/mul_mod.rs +++ b/src/uint/mul_mod.rs @@ -2,7 +2,7 @@ use crate::{ modular::{MontyForm, MontyParams}, - primitives::mul_rem, + primitives::mul_rem_vartime, Limb, MulMod, Uint, WideWord, Word, }; @@ -40,7 +40,7 @@ impl Uint { // We implicitly assume `LIMBS > 0`, because `Uint<0>` doesn't compile. // Still the case `LIMBS == 1` needs special handling. if LIMBS == 1 { - let reduced = mul_rem(self.limbs[0].0, rhs.limbs[0].0, Word::MIN.wrapping_sub(c.0)); + let reduced = mul_rem_vartime(self.limbs[0].0, rhs.limbs[0].0, Word::MIN.wrapping_sub(c.0)); return Self::from_word(reduced); } diff --git a/src/uint/sqrt.rs b/src/uint/sqrt.rs index 43dfd2f72..0ec28c754 100644 --- a/src/uint/sqrt.rs +++ b/src/uint/sqrt.rs @@ -49,7 +49,7 @@ impl Uint { /// Computes √(`self`) /// /// Callers can check if `self` is a square by squaring the result - pub const fn sqrt_vartime(&self) -> Self { + pub fn sqrt_vartime(&self) -> Self { // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. @@ -92,7 +92,7 @@ impl Uint { /// Wrapped sqrt is just normal √(`self`) /// There’s no way wrapping could ever happen. /// This function exists so that all operations are accounted for in the wrapping operations. - pub const fn wrapping_sqrt_vartime(&self) -> Self { + pub fn wrapping_sqrt_vartime(&self) -> Self { self.sqrt_vartime() } diff --git a/tests/uint_proptests.rs b/tests/uint_proptests.rs index da4c5827f..7a97d0a7f 100644 --- a/tests/uint_proptests.rs +++ b/tests/uint_proptests.rs @@ -32,6 +32,12 @@ prop_compose! { U256::from_le_slice(&bytes) } } +prop_compose! { + fn varsize_uint()(bytes in any::<[u8; 32]>(), size in any::()) -> U256 { + let size = size % 255 + 1; + U256::from_le_slice(&bytes) >> (256 - size) + } +} prop_compose! { fn uint_mod_p(p: Odd)(a in uint()) -> U256 { a.wrapping_rem(&p) @@ -275,6 +281,24 @@ proptest! { } } + #[test] + fn div_rem_vartime(a in varsize_uint(), b in varsize_uint()) { + let a_bi = to_biguint(&a); + let b_bi = to_biguint(&b); + + println!("\n\n--- {} {}", a_bi, b_bi); + println!("--- {} {}", a, b); + + if !b_bi.is_zero() { + let expected_q = to_uint(&a_bi / &b_bi); + let expected_r = to_uint(&a_bi % &b_bi); + let b_nz = NonZero::new(b).unwrap(); + let (actual_q, actual_r) = a.div_rem_vartime(&b_nz); + assert_eq!(expected_q, actual_q); + assert_eq!(expected_r, actual_r); + } + } + #[test] fn gcd(mut f in uint(), g in uint()) { if f.is_even().into() {