diff --git a/Cargo.toml b/Cargo.toml index b1394aa75..cd6939a33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ itertools = { version = "0.7.0", default-features = false } rayon = { version = "1.0.3", optional = true } -approx = { version = "0.3", optional = true } +approx = { version = "0.3.2", optional = true } # Use via the `blas` crate feature! cblas-sys = { version = "0.1.4", optional = true, default-features = false } @@ -49,6 +49,8 @@ serde = { version = "1.0", optional = true } defmac = "0.2" quickcheck = { version = "0.7.2", default-features = false } rawpointer = "0.1" +itertools = { version = "0.7.0", default-features = false, features = ["use_std"] } +approx = "0.3.2" [features] # Enable blas usage diff --git a/examples/column_standardize.rs b/examples/column_standardize.rs index 032c520e3..80ca7332a 100644 --- a/examples/column_standardize.rs +++ b/examples/column_standardize.rs @@ -23,9 +23,9 @@ fn main() { [ 2., 2., 2.]]; println!("{:8.4}", data); - println!("{:8.4} (Mean axis=0)", data.mean_axis(Axis(0))); + println!("{:8.4} (Mean axis=0)", data.mean_axis(Axis(0)).unwrap()); - data -= &data.mean_axis(Axis(0)); + data -= &data.mean_axis(Axis(0)).unwrap(); println!("{:8.4}", data); data /= &std(&data, Axis(0)); diff --git a/src/array_approx.rs b/src/array_approx.rs index f8adeb09e..eb0a32fde 100644 --- a/src/array_approx.rs +++ b/src/array_approx.rs @@ -1,13 +1,14 @@ use crate::imp_prelude::*; -use crate::{FoldWhile, Zip}; +use crate::Zip; use approx::{AbsDiffEq, RelativeEq, UlpsEq}; /// **Requires crate feature `"approx"`** -impl AbsDiffEq for ArrayBase +impl AbsDiffEq> for ArrayBase where A: AbsDiffEq, A::Epsilon: Clone, S: Data, + T: Data, D: Dimension, { type Epsilon = A::Epsilon; @@ -16,29 +17,23 @@ where A::default_epsilon() } - fn abs_diff_eq(&self, other: &ArrayBase, epsilon: A::Epsilon) -> bool { + fn abs_diff_eq(&self, other: &ArrayBase, epsilon: A::Epsilon) -> bool { if self.shape() != other.shape() { return false; } Zip::from(self) .and(other) - .fold_while(true, |_, a, b| { - if A::abs_diff_ne(a, b, epsilon.clone()) { - FoldWhile::Done(false) - } else { - FoldWhile::Continue(true) - } - }) - .into_inner() + .all(|a, b| A::abs_diff_eq(a, b, epsilon.clone())) } } /// **Requires crate feature `"approx"`** -impl RelativeEq for ArrayBase +impl RelativeEq> for ArrayBase where A: RelativeEq, A::Epsilon: Clone, S: Data, + T: Data, D: Dimension, { fn default_max_relative() -> A::Epsilon { @@ -47,7 +42,7 @@ where fn relative_eq( &self, - other: &ArrayBase, + other: &ArrayBase, epsilon: A::Epsilon, max_relative: A::Epsilon, ) -> bool { @@ -56,43 +51,30 @@ where } Zip::from(self) .and(other) - .fold_while(true, |_, a, b| { - if A::relative_ne(a, b, epsilon.clone(), max_relative.clone()) { - FoldWhile::Done(false) - } else { - FoldWhile::Continue(true) - } - }) - .into_inner() + .all(|a, b| A::relative_eq(a, b, epsilon.clone(), max_relative.clone())) } } /// **Requires crate feature `"approx"`** -impl UlpsEq for ArrayBase +impl UlpsEq> for ArrayBase where A: UlpsEq, A::Epsilon: Clone, S: Data, + T: Data, D: Dimension, { fn default_max_ulps() -> u32 { A::default_max_ulps() } - fn ulps_eq(&self, other: &ArrayBase, epsilon: A::Epsilon, max_ulps: u32) -> bool { + fn ulps_eq(&self, other: &ArrayBase, epsilon: A::Epsilon, max_ulps: u32) -> bool { if self.shape() != other.shape() { return false; } Zip::from(self) .and(other) - .fold_while(true, |_, a, b| { - if A::ulps_ne(a, b, epsilon.clone(), max_ulps) { - FoldWhile::Done(false) - } else { - FoldWhile::Continue(true) - } - }) - .into_inner() + .all(|a, b| A::ulps_eq(a, b, epsilon.clone(), max_ulps)) } } diff --git a/src/arrayformat.rs b/src/arrayformat.rs index d13add8f7..894215d8c 100644 --- a/src/arrayformat.rs +++ b/src/arrayformat.rs @@ -8,77 +8,115 @@ use std::fmt; use super::{ ArrayBase, + Axis, Data, Dimension, NdProducer, + Ix }; -use crate::dimension::IntoDimension; - -fn format_array(view: &ArrayBase, f: &mut fmt::Formatter, - mut format: F) - -> fmt::Result - where F: FnMut(&A, &mut fmt::Formatter) -> fmt::Result, - D: Dimension, - S: Data, +use crate::aliases::Ix1; + +const PRINT_ELEMENTS_LIMIT: Ix = 3; + +fn format_1d_array( + view: &ArrayBase, + f: &mut fmt::Formatter, + mut format: F, + limit: Ix) -> fmt::Result + where + F: FnMut(&A, &mut fmt::Formatter) -> fmt::Result, + S: Data, { - let ndim = view.dim.slice().len(); - /* private nowadays - if ndim > 0 && f.width.is_none() { - f.width = Some(4) - } - */ - // None will be an empty iter. - let mut last_index = match view.dim.first_index() { - None => view.dim.clone(), - Some(ix) => ix, - }; - for _ in 0..ndim { - write!(f, "[")?; - } - let mut first = true; - // Simply use the indexed iterator, and take the index wraparounds - // as cues for when to add []'s and how many to add. - for (index, elt) in view.indexed_iter() { - let index = index.into_dimension(); - let take_n = if ndim == 0 { 1 } else { ndim - 1 }; - let mut update_index = false; - for (i, (a, b)) in index.slice() - .iter() - .take(take_n) - .zip(last_index.slice().iter()) - .enumerate() { - if a != b { - // New row. - // # of ['s needed - let n = ndim - i - 1; - for _ in 0..n { - write!(f, "]")?; - } - write!(f, ",")?; - write!(f, "\n")?; - for _ in 0..ndim - n { - write!(f, " ")?; - } - for _ in 0..n { - write!(f, "[")?; + let to_be_printed = to_be_printed(view.len(), limit); + + let n_to_be_printed = to_be_printed.len(); + + write!(f, "[")?; + for (j, index) in to_be_printed.into_iter().enumerate() { + match index { + PrintableCell::ElementIndex(i) => { + format(&view[i], f)?; + if j != n_to_be_printed - 1 { + write!(f, ", ")?; } - first = true; - update_index = true; - break; - } - } - if !first { - write!(f, ", ")?; + }, + PrintableCell::Ellipses => write!(f, "..., ")?, } - first = false; - format(elt, f)?; + } + write!(f, "]")?; + Ok(()) +} - if update_index { - last_index = index; - } +enum PrintableCell { + ElementIndex(usize), + Ellipses, +} + +// Returns what indexes should be printed for a certain axis. +// If the axis is longer than 2 * limit, a `Ellipses` is inserted +// where indexes are being omitted. +fn to_be_printed(length: usize, limit: usize) -> Vec { + if length <= 2 * limit { + (0..length).map(|x| PrintableCell::ElementIndex(x)).collect() + } else { + let mut v: Vec = (0..limit).map(|x| PrintableCell::ElementIndex(x)).collect(); + v.push(PrintableCell::Ellipses); + v.extend((length-limit..length).map(|x| PrintableCell::ElementIndex(x))); + v } - for _ in 0..ndim { - write!(f, "]")?; +} + +fn format_array( + view: &ArrayBase, + f: &mut fmt::Formatter, + mut format: F, + limit: Ix) -> fmt::Result +where + F: FnMut(&A, &mut fmt::Formatter) -> fmt::Result + Clone, + D: Dimension, + S: Data, +{ + // If any of the axes has 0 length, we return the same empty array representation + // e.g. [[]] for 2-d arrays + if view.shape().iter().any(|&x| x == 0) { + write!(f, "{}{}", "[".repeat(view.ndim()), "]".repeat(view.ndim()))?; + return Ok(()) + } + match view.shape() { + // If it's 0 dimensional, we just print out the scalar + [] => format(view.iter().next().unwrap(), f)?, + // We delegate 1-dimensional arrays to a specialized function + [_] => format_1d_array(&view.view().into_dimensionality::().unwrap(), f, format, limit)?, + // For n-dimensional arrays, we proceed recursively + shape => { + // Cast into a dynamically dimensioned view + // This is required to be able to use `index_axis` + let view = view.view().into_dyn(); + // We start by checking what indexes from the first axis should be printed + // We put a `None` in the middle if we are omitting elements + let to_be_printed = to_be_printed(shape[0], limit); + + let n_to_be_printed = to_be_printed.len(); + + write!(f, "[")?; + for (j, index) in to_be_printed.into_iter().enumerate() { + match index { + PrintableCell::ElementIndex(i) => { + // Proceed recursively with the (n-1)-dimensional slice + format_array( + &view.index_axis(Axis(0), i), f, format.clone(), limit + )?; + // We need to add a separator after each slice, + // apart from the last one + if j != n_to_be_printed - 1 { + write!(f, ",\n ")? + } + }, + PrintableCell::Ellipses => write!(f, "...,\n ")? + } + } + write!(f, "]")?; + } } Ok(()) } @@ -92,7 +130,7 @@ impl<'a, A: fmt::Display, S, D: Dimension> fmt::Display for ArrayBase where S: Data, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - format_array(self, f, <_>::fmt) + format_array(self, f, <_>::fmt, PRINT_ELEMENTS_LIMIT) } } @@ -105,7 +143,7 @@ impl<'a, A: fmt::Debug, S, D: Dimension> fmt::Debug for ArrayBase { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // Add extra information for Debug - format_array(self, f, <_>::fmt)?; + format_array(self, f, <_>::fmt, PRINT_ELEMENTS_LIMIT)?; write!(f, " shape={:?}, strides={:?}, layout={:?}", self.shape(), self.strides(), layout=self.view().layout())?; match D::NDIM { @@ -124,7 +162,7 @@ impl<'a, A: fmt::LowerExp, S, D: Dimension> fmt::LowerExp for ArrayBase where S: Data, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - format_array(self, f, <_>::fmt) + format_array(self, f, <_>::fmt, PRINT_ELEMENTS_LIMIT) } } @@ -136,7 +174,7 @@ impl<'a, A: fmt::UpperExp, S, D: Dimension> fmt::UpperExp for ArrayBase where S: Data, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - format_array(self, f, <_>::fmt) + format_array(self, f, <_>::fmt, PRINT_ELEMENTS_LIMIT) } } /// Format the array using `LowerHex` and apply the formatting parameters used @@ -147,7 +185,7 @@ impl<'a, A: fmt::LowerHex, S, D: Dimension> fmt::LowerHex for ArrayBase where S: Data, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - format_array(self, f, <_>::fmt) + format_array(self, f, <_>::fmt, PRINT_ELEMENTS_LIMIT) } } @@ -159,6 +197,161 @@ impl<'a, A: fmt::Binary, S, D: Dimension> fmt::Binary for ArrayBase where S: Data, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - format_array(self, f, <_>::fmt) + format_array(self, f, <_>::fmt, PRINT_ELEMENTS_LIMIT) + } +} + +#[cfg(test)] +mod formatting_with_omit { + use crate::prelude::*; + use super::*; + + fn print_output_diff(expected: &str, actual: &str) { + println!("Expected output:\n{}\nActual output:\n{}", expected, actual); + } + + #[test] + fn empty_arrays() { + let a: Array2 = arr2(&[[], []]); + let actual_output = format!("{}", a); + let expected_output = String::from("[[]]"); + print_output_diff(&expected_output, &actual_output); + assert_eq!(expected_output, actual_output); + } + + #[test] + fn zero_length_axes() { + let a = Array3::::zeros((3, 0, 4)); + let actual_output = format!("{}", a); + let expected_output = String::from("[[[]]]"); + print_output_diff(&expected_output, &actual_output); + assert_eq!(expected_output, actual_output); + } + + #[test] + fn dim_0() { + let element = 12; + let a = arr0(element); + let actual_output = format!("{}", a); + let expected_output = format!("{}", element); + print_output_diff(&expected_output, &actual_output); + assert_eq!(expected_output, actual_output); + } + + #[test] + fn dim_1() { + let overflow: usize = 5; + let a = Array1::from_elem((PRINT_ELEMENTS_LIMIT * 2 + overflow, ), 1); + let mut expected_output = String::from("["); + a.iter() + .take(PRINT_ELEMENTS_LIMIT) + .for_each(|elem| { expected_output.push_str(format!("{}, ", elem).as_str()) }); + expected_output.push_str("..."); + a.iter() + .skip(PRINT_ELEMENTS_LIMIT + overflow) + .for_each(|elem| { expected_output.push_str(format!(", {}", elem).as_str()) }); + expected_output.push(']'); + let actual_output = format!("{}", a); + + print_output_diff(&expected_output, &actual_output); + assert_eq!(actual_output, expected_output); + } + + #[test] + fn dim_2_last_axis_overflow() { + let overflow: usize = 3; + let a = Array2::from_elem((PRINT_ELEMENTS_LIMIT, PRINT_ELEMENTS_LIMIT * 2 + overflow), 1); + let mut expected_output = String::from("["); + + for i in 0..PRINT_ELEMENTS_LIMIT { + expected_output.push_str(format!("[{}", a[(i, 0)]).as_str()); + for j in 1..PRINT_ELEMENTS_LIMIT { + expected_output.push_str(format!(", {}", a[(i, j)]).as_str()); + } + expected_output.push_str(", ..."); + for j in PRINT_ELEMENTS_LIMIT + overflow..PRINT_ELEMENTS_LIMIT * 2 + overflow { + expected_output.push_str(format!(", {}", a[(i, j)]).as_str()); + } + expected_output.push_str(if i < PRINT_ELEMENTS_LIMIT - 1 { "],\n " } else { "]" }); + } + expected_output.push(']'); + let actual_output = format!("{}", a); + + print_output_diff(&expected_output, &actual_output); + assert_eq!(actual_output, expected_output); + } + + #[test] + fn dim_2_non_last_axis_overflow() { + let overflow: usize = 5; + let a = Array2::from_elem((PRINT_ELEMENTS_LIMIT * 2 + overflow, PRINT_ELEMENTS_LIMIT), 1); + let mut expected_output = String::from("["); + + for i in 0..PRINT_ELEMENTS_LIMIT { + expected_output.push_str(format!("[{}", a[(i, 0)]).as_str()); + for j in 1..PRINT_ELEMENTS_LIMIT { + expected_output.push_str(format!(", {}", a[(i, j)]).as_str()); + } + expected_output.push_str("],\n "); + } + expected_output.push_str("...,\n "); + for i in PRINT_ELEMENTS_LIMIT + overflow..PRINT_ELEMENTS_LIMIT * 2 + overflow { + expected_output.push_str(format!("[{}", a[(i, 0)]).as_str()); + for j in 1..PRINT_ELEMENTS_LIMIT { + expected_output.push_str(format!(", {}", a[(i, j)]).as_str()); + } + expected_output.push_str(if i == PRINT_ELEMENTS_LIMIT * 2 + overflow - 1 { + "]" + } else { + "],\n " + }); + } + expected_output.push(']'); + let actual_output = format!("{}", a); + + print_output_diff(&expected_output, &actual_output); + assert_eq!(actual_output, expected_output); + } + + #[test] + fn dim_2_multi_directional_overflow() { + let overflow: usize = 5; + let a = Array2::from_elem( + (PRINT_ELEMENTS_LIMIT * 2 + overflow, PRINT_ELEMENTS_LIMIT * 2 + overflow), 1 + ); + let mut expected_output = String::from("["); + + for i in 0..PRINT_ELEMENTS_LIMIT { + expected_output.push_str(format!("[{}", a[(i, 0)]).as_str()); + for j in 1..PRINT_ELEMENTS_LIMIT { + expected_output.push_str(format!(", {}", a[(i, j)]).as_str()); + } + expected_output.push_str(", ..."); + for j in PRINT_ELEMENTS_LIMIT + overflow..PRINT_ELEMENTS_LIMIT * 2 + overflow { + expected_output.push_str(format!(", {}", a[(i, j)]).as_str()); + } + expected_output.push_str("],\n "); + } + expected_output.push_str("...,\n "); + for i in PRINT_ELEMENTS_LIMIT + overflow..PRINT_ELEMENTS_LIMIT * 2 + overflow { + expected_output.push_str(format!("[{}", a[(i, 0)]).as_str()); + for j in 1..PRINT_ELEMENTS_LIMIT { + expected_output.push_str(format!(", {}", a[(i, j)]).as_str()); + } + expected_output.push_str(", ..."); + for j in PRINT_ELEMENTS_LIMIT + overflow..PRINT_ELEMENTS_LIMIT * 2 + overflow { + expected_output.push_str(format!(", {}", a[(i, j)]).as_str()); + } + expected_output.push_str(if i == PRINT_ELEMENTS_LIMIT * 2 + overflow - 1 { + "]" + } else { + "],\n " + }); + } + expected_output.push(']'); + let actual_output = format!("{}", a); + + print_output_diff(&expected_output, &actual_output); + assert_eq!(actual_output, expected_output); } } diff --git a/src/doc/ndarray_for_numpy_users/mod.rs b/src/doc/ndarray_for_numpy_users/mod.rs index b268c9b24..79acad788 100644 --- a/src/doc/ndarray_for_numpy_users/mod.rs +++ b/src/doc/ndarray_for_numpy_users/mod.rs @@ -473,7 +473,7 @@ //! //! //! -//! [`a.all_close(&b, 1e-8)`][.all_close()] +//! [`a.abs_diff_eq(&b, 1e-8)`][.abs_diff_eq()] //! //! //! @@ -557,7 +557,7 @@ //! `a[:,4]` | [`a.column(4)`][.column()] or [`a.column_mut(4)`][.column_mut()] | view (or mutable view) of column 4 in a 2-D array //! `a.shape[0] == a.shape[1]` | [`a.is_square()`][.is_square()] | check if the array is square //! -//! [.all_close()]: ../../struct.ArrayBase.html#method.all_close +//! [.abs_diff_eq()]: ../../struct.ArrayBase.html#impl-AbsDiffEq> //! [ArcArray]: ../../type.ArcArray.html //! [arr2()]: ../../fn.arr2.html //! [array!]: ../../macro.array.html diff --git a/src/geomspace.rs b/src/geomspace.rs new file mode 100644 index 000000000..93b73329b --- /dev/null +++ b/src/geomspace.rs @@ -0,0 +1,174 @@ +// Copyright 2014-2016 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. +use num_traits::Float; + +/// An iterator of a sequence of geometrically spaced floats. +/// +/// Iterator element type is `F`. +pub struct Geomspace { + sign: F, + start: F, + step: F, + index: usize, + len: usize, +} + +impl Iterator for Geomspace +where + F: Float, +{ + type Item = F; + + #[inline] + fn next(&mut self) -> Option { + if self.index >= self.len { + None + } else { + // Calculate the value just like numpy.linspace does + let i = self.index; + self.index += 1; + let exponent = self.start + self.step * F::from(i).unwrap(); + Some(self.sign * exponent.exp()) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let n = self.len - self.index; + (n, Some(n)) + } +} + +impl DoubleEndedIterator for Geomspace +where + F: Float, +{ + #[inline] + fn next_back(&mut self) -> Option { + if self.index >= self.len { + None + } else { + // Calculate the value just like numpy.linspace does + self.len -= 1; + let i = self.len; + let exponent = self.start + self.step * F::from(i).unwrap(); + Some(self.sign * exponent.exp()) + } + } +} + +impl ExactSizeIterator for Geomspace where Geomspace: Iterator {} + +/// An iterator of a sequence of geometrically spaced values. +/// +/// The `Geomspace` has `n` elements, where the first element is `a` and the +/// last element is `b`. +/// +/// Iterator element type is `F`, where `F` must be either `f32` or `f64`. +/// +/// **Panics** if the interval `[a, b]` contains zero (including the end points). +#[inline] +pub fn geomspace(a: F, b: F, n: usize) -> Geomspace +where + F: Float, +{ + assert!( + a != F::zero() && b != F::zero(), + "Start and/or end of geomspace cannot be zero.", + ); + assert!( + a.is_sign_negative() == b.is_sign_negative(), + "Logarithmic interval cannot cross 0." + ); + + let log_a = a.abs().ln(); + let log_b = b.abs().ln(); + let step = if n > 1 { + let nf: F = F::from(n).unwrap(); + (log_b - log_a) / (nf - F::one()) + } else { + F::zero() + }; + Geomspace { + sign: a.signum(), + start: log_a, + step: step, + index: 0, + len: n, + } +} + +#[cfg(test)] +mod tests { + use super::geomspace; + use crate::{arr1, Array1}; + + #[test] + #[cfg(approx)] + fn valid() { + let array: Array1<_> = geomspace(1e0, 1e3, 4).collect(); + assert!(array.abs_diff_eq(&arr1(&[1e0, 1e1, 1e2, 1e3]), 1e-5)); + + let array: Array1<_> = geomspace(1e3, 1e0, 4).collect(); + assert!(array.abs_diff_eq(&arr1(&[1e3, 1e2, 1e1, 1e0]), 1e-5)); + + let array: Array1<_> = geomspace(-1e3, -1e0, 4).collect(); + assert!(array.abs_diff_eq(&arr1(&[-1e3, -1e2, -1e1, -1e0]), 1e-5)); + + let array: Array1<_> = geomspace(-1e0, -1e3, 4).collect(); + assert!(array.abs_diff_eq(&arr1(&[-1e0, -1e1, -1e2, -1e3]), 1e-5)); + } + + #[test] + fn iter_forward() { + let mut iter = geomspace(1.0f64, 1e3, 4); + + assert!(iter.size_hint() == (4, Some(4))); + + assert!((iter.next().unwrap() - 1e0).abs() < 1e-5); + assert!((iter.next().unwrap() - 1e1).abs() < 1e-5); + assert!((iter.next().unwrap() - 1e2).abs() < 1e-5); + assert!((iter.next().unwrap() - 1e3).abs() < 1e-5); + assert!(iter.next().is_none()); + + assert!(iter.size_hint() == (0, Some(0))); + } + + #[test] + fn iter_backward() { + let mut iter = geomspace(1.0f64, 1e3, 4); + + assert!(iter.size_hint() == (4, Some(4))); + + assert!((iter.next_back().unwrap() - 1e3).abs() < 1e-5); + assert!((iter.next_back().unwrap() - 1e2).abs() < 1e-5); + assert!((iter.next_back().unwrap() - 1e1).abs() < 1e-5); + assert!((iter.next_back().unwrap() - 1e0).abs() < 1e-5); + assert!(iter.next_back().is_none()); + + assert!(iter.size_hint() == (0, Some(0))); + } + + #[test] + #[should_panic] + fn zero_lower() { + geomspace(0.0, 1.0, 4); + } + + #[test] + #[should_panic] + fn zero_upper() { + geomspace(1.0, 0.0, 4); + } + + #[test] + #[should_panic] + fn zero_included() { + geomspace(-1.0, 1.0, 4); + } +} diff --git a/src/impl_constructors.rs b/src/impl_constructors.rs index 1c2b84439..fca687a0c 100644 --- a/src/impl_constructors.rs +++ b/src/impl_constructors.rs @@ -10,18 +10,18 @@ //! //! -use num_traits::{Zero, One, Float}; +use num_traits::{Float, One, Zero}; use std::isize; use std::mem; -use crate::imp_prelude::*; -use crate::StrideShape; use crate::dimension; -use crate::linspace; use crate::error::{self, ShapeError}; -use crate::indices; +use crate::imp_prelude::*; use crate::indexes; +use crate::indices; use crate::iterators::{to_vec, to_vec_mapped}; +use crate::StrideShape; +use crate::{linspace, geomspace, logspace}; /// # Constructor Methods for Owned Arrays /// @@ -101,6 +101,61 @@ impl ArrayBase { Self::from_vec(to_vec(linspace::range(start, end, step))) } + + /// Create a one-dimensional array with `n` elements logarithmically spaced, + /// with the starting value being `base.powf(start)` and the final one being + /// `base.powf(end)`. `A` must be a floating point type. + /// + /// If `base` is negative, all values will be negative. + /// + /// **Panics** if the length is greater than `isize::MAX`. + /// + /// ```rust + /// use ndarray::{Array, arr1}; + /// use approx::AbsDiffEq; + /// + /// # #[cfg(feature = "approx")] { + /// let array = Array::logspace(10.0, 0.0, 3.0, 4); + /// assert!(array.abs_diff_eq(&arr1(&[1e0, 1e1, 1e2, 1e3]), 1e-5)); + /// + /// let array = Array::logspace(-10.0, 3.0, 0.0, 4); + /// assert!(array.abs_diff_eq(&arr1(&[-1e3, -1e2, -1e1, -1e0]), 1e-5)); + /// # } + /// ``` + pub fn logspace(base: A, start: A, end: A, n: usize) -> Self + where + A: Float, + { + Self::from_vec(to_vec(logspace::logspace(base, start, end, n))) + } + + /// Create a one-dimensional array from the inclusive interval `[start, + /// end]` with `n` elements geometrically spaced. `A` must be a floating + /// point type. + /// + /// The interval can be either all positive or all negative; however, it + /// cannot contain 0 (including the end points). + /// + /// **Panics** if `n` is greater than `isize::MAX`. + /// + /// ```rust + /// use ndarray::{Array, arr1}; + /// use approx::AbsDiffEq; + /// + /// # #[cfg(feature = "approx")] { + /// let array = Array::geomspace(1e0, 1e3, 4); + /// assert!(array.abs_diff_eq(&arr1(&[1e0, 1e1, 1e2, 1e3]), 1e-5)); + /// + /// let array = Array::geomspace(-1e3, -1e0, 4); + /// assert!(array.abs_diff_eq(&arr1(&[-1e3, -1e2, -1e1, -1e0]), 1e-5)); + /// # } + /// ``` + pub fn geomspace(start: A, end: A, n: usize) -> Self + where + A: Float, + { + Self::from_vec(to_vec(geomspace::geomspace(start, end, n))) + } } /// ## Constructor methods for two-dimensional arrays. diff --git a/src/impl_methods.rs b/src/impl_methods.rs index de74af795..5887bebcb 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -83,17 +83,53 @@ where } /// Return the shape of the array as it stored in the array. + /// + /// This is primarily useful for passing to other `ArrayBase` + /// functions, such as when creating another array of the same + /// shape and dimensionality. + /// + /// ``` + /// use ndarray::Array; + /// + /// let a = Array::from_elem((2, 3), 5.); + /// + /// // Create an array of zeros that's the same shape and dimensionality as `a`. + /// let b = Array::::zeros(a.raw_dim()); + /// ``` pub fn raw_dim(&self) -> D { self.dim.clone() } /// Return the shape of the array as a slice. - pub fn shape(&self) -> &[Ix] { + /// + /// Note that you probably don't want to use this to create an array of the + /// same shape as another array because creating an array with e.g. + /// [`Array::zeros()`](ArrayBase::zeros) using a shape of type `&[usize]` + /// results in a dynamic-dimensional array. If you want to create an array + /// that has the same shape and dimensionality as another array, use + /// [`.raw_dim()`](ArrayBase::raw_dim) instead: + /// + /// ```rust + /// use ndarray::{Array, Array2}; + /// + /// let a = Array2::::zeros((3, 4)); + /// let shape = a.shape(); + /// assert_eq!(shape, &[3, 4]); + /// + /// // Since `a.shape()` returned `&[usize]`, we get an `ArrayD` instance: + /// let b = Array::zeros(shape); + /// assert_eq!(a.clone().into_dyn(), b); + /// + /// // To get the same dimension type, use `.raw_dim()` instead: + /// let c = Array::zeros(a.raw_dim()); + /// assert_eq!(a, c); + /// ``` + pub fn shape(&self) -> &[usize] { self.dim.slice() } - /// Return the strides of the array as a slice - pub fn strides(&self) -> &[Ixs] { + /// Return the strides of the array as a slice. + pub fn strides(&self) -> &[isize] { let s = self.strides.slice(); // reinterpret unsigned integer as signed unsafe { @@ -1476,6 +1512,12 @@ where /// **Note:** Cannot be used for mutable iterators, since repeating /// elements would create aliasing pointers. fn upcast(to: &D, from: &E, stride: &E) -> Option { + // Make sure the product of non-zero axis lengths does not exceed + // `isize::MAX`. This is the only safety check we need to perform + // because all the other constraints of `ArrayBase` are guaranteed + // to be met since we're starting from a valid `ArrayBase`. + let _ = size_of_shape_checked(to).ok()?; + let mut new_stride = to.clone(); // begin at the back (the least significant dimension) // size of the axis has to either agree or `from` has to be 1 @@ -1998,14 +2040,17 @@ where /// /// ``` /// use ndarray::arr2; + /// use approx::AbsDiffEq; /// + /// # #[cfg(feature = "approx")] { /// let mut a = arr2(&[[ 0., 1.], /// [-1., 2.]]); /// a.mapv_inplace(f32::exp); /// assert!( - /// a.all_close(&arr2(&[[1.00000, 2.71828], - /// [0.36788, 7.38906]]), 1e-5) + /// a.abs_diff_eq(&arr2(&[[1.00000, 2.71828], + /// [0.36788, 7.38906]]), 1e-5) /// ); + /// # } /// ``` pub fn mapv_inplace(&mut self, mut f: F) where S: DataMut, @@ -2066,13 +2111,18 @@ where { let view_len = self.len_of(axis); let view_stride = self.strides.axis(axis); - // use the 0th subview as a map to each 1d array view extended from - // the 0th element. - self.index_axis(axis, 0).map(|first_elt| { - unsafe { - mapping(ArrayView::new_(first_elt, Ix1(view_len), Ix1(view_stride))) - } - }) + if view_len == 0 { + let new_dim = self.dim.remove_axis(axis); + Array::from_shape_fn(new_dim, move |_| mapping(ArrayView::from(&[]))) + } else { + // use the 0th subview as a map to each 1d array view extended from + // the 0th element. + self.index_axis(axis, 0).map(|first_elt| { + unsafe { + mapping(ArrayView::new_(first_elt, Ix1(view_len), Ix1(view_stride))) + } + }) + } } /// Reduce the values along an axis into just one value, producing a new @@ -2094,12 +2144,17 @@ where { let view_len = self.len_of(axis); let view_stride = self.strides.axis(axis); - // use the 0th subview as a map to each 1d array view extended from - // the 0th element. - self.index_axis_mut(axis, 0).map_mut(|first_elt: &mut A| { - unsafe { - mapping(ArrayViewMut::new_(first_elt, Ix1(view_len), Ix1(view_stride))) - } - }) + if view_len == 0 { + let new_dim = self.dim.remove_axis(axis); + Array::from_shape_fn(new_dim, move |_| mapping(ArrayViewMut::from(&mut []))) + } else { + // use the 0th subview as a map to each 1d array view extended from + // the 0th element. + self.index_axis_mut(axis, 0).map_mut(|first_elt| { + unsafe { + mapping(ArrayViewMut::new_(first_elt, Ix1(view_len), Ix1(view_stride))) + } + }) + } } } diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index b9aab1aa9..959d51092 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -13,6 +13,7 @@ mod windows; mod lanes; pub mod iter; +use std::iter::FromIterator; use std::marker::PhantomData; use std::ptr; @@ -388,6 +389,54 @@ impl<'a, A, D: Dimension> Iterator for Iter<'a, A, D> { { either!(self.inner, iter => iter.fold(init, g)) } + + fn nth(&mut self, n: usize) -> Option { + either_mut!(self.inner, iter => iter.nth(n)) + } + + fn collect(self) -> B + where B: FromIterator + { + either!(self.inner, iter => iter.collect()) + } + + fn all(&mut self, f: F) -> bool + where F: FnMut(Self::Item) -> bool + { + either_mut!(self.inner, iter => iter.all(f)) + } + + fn any(&mut self, f: F) -> bool + where F: FnMut(Self::Item) -> bool + { + either_mut!(self.inner, iter => iter.any(f)) + } + + fn find

(&mut self, predicate: P) -> Option + where P: FnMut(&Self::Item) -> bool + { + either_mut!(self.inner, iter => iter.find(predicate)) + } + + fn find_map(&mut self, f: F) -> Option + where F: FnMut(Self::Item) -> Option + { + either_mut!(self.inner, iter => iter.find_map(f)) + } + + fn count(self) -> usize { + either!(self.inner, iter => iter.count()) + } + + fn last(self) -> Option { + either!(self.inner, iter => iter.last()) + } + + fn position

(&mut self, predicate: P) -> Option + where P: FnMut(Self::Item) -> bool, + { + either_mut!(self.inner, iter => iter.position(predicate)) + } } impl<'a, A> DoubleEndedIterator for Iter<'a, A, Ix1> { @@ -455,6 +504,54 @@ impl<'a, A, D: Dimension> Iterator for IterMut<'a, A, D> { { either!(self.inner, iter => iter.fold(init, g)) } + + fn nth(&mut self, n: usize) -> Option { + either_mut!(self.inner, iter => iter.nth(n)) + } + + fn collect(self) -> B + where B: FromIterator + { + either!(self.inner, iter => iter.collect()) + } + + fn all(&mut self, f: F) -> bool + where F: FnMut(Self::Item) -> bool + { + either_mut!(self.inner, iter => iter.all(f)) + } + + fn any(&mut self, f: F) -> bool + where F: FnMut(Self::Item) -> bool + { + either_mut!(self.inner, iter => iter.any(f)) + } + + fn find

(&mut self, predicate: P) -> Option + where P: FnMut(&Self::Item) -> bool + { + either_mut!(self.inner, iter => iter.find(predicate)) + } + + fn find_map(&mut self, f: F) -> Option + where F: FnMut(Self::Item) -> Option + { + either_mut!(self.inner, iter => iter.find_map(f)) + } + + fn count(self) -> usize { + either!(self.inner, iter => iter.count()) + } + + fn last(self) -> Option { + either!(self.inner, iter => iter.last()) + } + + fn position

(&mut self, predicate: P) -> Option + where P: FnMut(Self::Item) -> bool, + { + either_mut!(self.inner, iter => iter.position(predicate)) + } } impl<'a, A> DoubleEndedIterator for IterMut<'a, A, Ix1> { @@ -1214,25 +1311,25 @@ send_sync_read_write!(ElementsBaseMut); /// (Trait used internally) An iterator that we trust /// to deliver exactly as many items as it said it would. -pub unsafe trait TrustedIterator { } +pub unsafe trait TrustedIterator {} -use std; -use crate::linspace::Linspace; -use crate::iter::IndicesIter; use crate::indexes::IndicesIterF; +use crate::iter::IndicesIter; +use crate::{geomspace::Geomspace, linspace::Linspace, logspace::Logspace}; +use std; -unsafe impl TrustedIterator for Linspace { } -unsafe impl<'a, A, D> TrustedIterator for Iter<'a, A, D> { } -unsafe impl<'a, A, D> TrustedIterator for IterMut<'a, A, D> { } -unsafe impl TrustedIterator for std::iter::Map - where I: TrustedIterator { } -unsafe impl<'a, A> TrustedIterator for slice::Iter<'a, A> { } -unsafe impl<'a, A> TrustedIterator for slice::IterMut<'a, A> { } -unsafe impl TrustedIterator for ::std::ops::Range { } +unsafe impl TrustedIterator for Geomspace {} +unsafe impl TrustedIterator for Linspace {} +unsafe impl TrustedIterator for Logspace {} +unsafe impl<'a, A, D> TrustedIterator for Iter<'a, A, D> {} +unsafe impl<'a, A, D> TrustedIterator for IterMut<'a, A, D> {} +unsafe impl TrustedIterator for std::iter::Map where I: TrustedIterator {} +unsafe impl<'a, A> TrustedIterator for slice::Iter<'a, A> {} +unsafe impl<'a, A> TrustedIterator for slice::IterMut<'a, A> {} +unsafe impl TrustedIterator for ::std::ops::Range {} // FIXME: These indices iter are dubious -- size needs to be checked up front. -unsafe impl TrustedIterator for IndicesIter where D: Dimension { } -unsafe impl TrustedIterator for IndicesIterF where D: Dimension { } - +unsafe impl TrustedIterator for IndicesIter where D: Dimension {} +unsafe impl TrustedIterator for IndicesIterF where D: Dimension {} /// Like Iterator::collect, but only for trusted length iterators pub fn to_vec(iter: I) -> Vec diff --git a/src/lib.rs b/src/lib.rs index 9ffab0665..297b37af7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -175,15 +175,18 @@ mod free_functions; pub use crate::free_functions::*; pub use crate::iterators::iter; -#[macro_use] mod slice; -mod layout; +mod error; +mod geomspace; mod indexes; mod iterators; +mod layout; mod linalg_traits; mod linspace; +mod logspace; mod numeric_util; -mod error; mod shape_builder; +#[macro_use] +mod slice; mod stacking; #[macro_use] mod zip; @@ -646,6 +649,25 @@ pub type Ixs = isize; /// - `B @ &A` which consumes `B`, updates it with the result, and returns it /// - `C @= &A` which performs an arithmetic operation in place /// +/// Note that the element type needs to implement the operator trait and the +/// `Clone` trait. +/// +/// ``` +/// use ndarray::{array, ArrayView1}; +/// +/// let owned1 = array![1, 2]; +/// let owned2 = array![3, 4]; +/// let view1 = ArrayView1::from(&[5, 6]); +/// let view2 = ArrayView1::from(&[7, 8]); +/// let mut mutable = array![9, 10]; +/// +/// let sum1 = &view1 + &view2; // Allocates a new array. Note the explicit `&`. +/// // let sum2 = view1 + &view2; // This doesn't work because `view1` is not an owned array. +/// let sum3 = owned1 + view1; // Consumes `owned1`, updates it, and returns it. +/// let sum4 = owned2 + &view2; // Consumes `owned2`, updates it, and returns it. +/// mutable += &view2; // Updates `mutable` in-place. +/// ``` +/// /// ### Binary Operators with Array and Scalar /// /// The trait [`ScalarOperand`](trait.ScalarOperand.html) marks types that can be used in arithmetic @@ -928,7 +950,8 @@ pub type Ixs = isize; /// 3Works only if the array is contiguous. /// /// The table above does not include all the constructors; it only shows -/// conversions to/from `Vec`s/slices. See below for more constructors. +/// conversions to/from `Vec`s/slices. See +/// [below](#constructor-methods-for-owned-arrays) for more constructors. /// /// [ArrayView::reborrow()]: type.ArrayView.html#method.reborrow /// [ArrayViewMut::reborrow()]: type.ArrayViewMut.html#method.reborrow @@ -941,6 +964,101 @@ pub type Ixs = isize; /// [.view()]: #method.view /// [.view_mut()]: #method.view_mut /// +/// ### Conversions from Nested `Vec`s/`Array`s +/// +/// It's generally a good idea to avoid nested `Vec`/`Array` types, such as +/// `Vec>` or `Vec>` because: +/// +/// * they require extra heap allocations compared to a single `Array`, +/// +/// * they can scatter data all over memory (because of multiple allocations), +/// +/// * they cause unnecessary indirection (traversing multiple pointers to reach +/// the data), +/// +/// * they don't enforce consistent shape within the nested +/// `Vec`s/`ArrayBase`s, and +/// +/// * they are generally more difficult to work with. +/// +/// The most common case where users might consider using nested +/// `Vec`s/`Array`s is when creating an array by appending rows/subviews in a +/// loop, where the rows/subviews are computed within the loop. However, there +/// are better ways than using nested `Vec`s/`Array`s. +/// +/// If you know ahead-of-time the shape of the final array, the cleanest +/// solution is to allocate the final array before the loop, and then assign +/// the data to it within the loop, like this: +/// +/// ```rust +/// use ndarray::{array, Array2, Axis}; +/// +/// let mut arr = Array2::zeros((2, 3)); +/// for (i, mut row) in arr.axis_iter_mut(Axis(0)).enumerate() { +/// // Perform calculations and assign to `row`; this is a trivial example: +/// row.fill(i); +/// } +/// assert_eq!(arr, array![[0, 0, 0], [1, 1, 1]]); +/// ``` +/// +/// If you don't know ahead-of-time the shape of the final array, then the +/// cleanest solution is generally to append the data to a flat `Vec`, and then +/// convert it to an `Array` at the end with +/// [`::from_shape_vec()`](#method.from_shape_vec). You just have to be careful +/// that the layout of the data (the order of the elements in the flat `Vec`) +/// is correct. +/// +/// ```rust +/// use ndarray::{array, Array2}; +/// +/// # fn main() -> Result<(), Box> { +/// let ncols = 3; +/// let mut data = Vec::new(); +/// let mut nrows = 0; +/// for i in 0..2 { +/// // Compute `row` and append it to `data`; this is a trivial example: +/// let row = vec![i; ncols]; +/// data.extend_from_slice(&row); +/// nrows += 1; +/// } +/// let arr = Array2::from_shape_vec((nrows, ncols), data)?; +/// assert_eq!(arr, array![[0, 0, 0], [1, 1, 1]]); +/// # Ok(()) +/// # } +/// ``` +/// +/// If neither of these options works for you, and you really need to convert +/// nested `Vec`/`Array` instances to an `Array`, the cleanest solution is +/// generally to use +/// [`Iterator::flatten()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.flatten) +/// to get a flat `Vec`, and then convert the `Vec` to an `Array` with +/// [`::from_shape_vec()`](#method.from_shape_vec), like this: +/// +/// ```rust +/// use ndarray::{array, Array2, Array3}; +/// +/// # fn main() -> Result<(), Box> { +/// let nested: Vec> = vec![ +/// array![[1, 2, 3], [4, 5, 6]], +/// array![[7, 8, 9], [10, 11, 12]], +/// ]; +/// let inner_shape = nested[0].dim(); +/// let shape = (nested.len(), inner_shape.0, inner_shape.1); +/// let flat: Vec = nested.iter().flatten().cloned().collect(); +/// let arr = Array3::from_shape_vec(shape, flat)?; +/// assert_eq!(arr, array![ +/// [[1, 2, 3], [4, 5, 6]], +/// [[7, 8, 9], [10, 11, 12]], +/// ]); +/// # Ok(()) +/// # } +/// ``` +/// +/// Note that this implementation assumes that the nested `Vec`s are all the +/// same shape and that the `Vec` is non-empty. Depending on your application, +/// it may be a good idea to add checks for these assumptions and possibly +/// choose a different way to handle the empty case. +/// // # For implementors // // All methods must uphold the following constraints: diff --git a/src/logspace.rs b/src/logspace.rs new file mode 100644 index 000000000..7aa6d11e2 --- /dev/null +++ b/src/logspace.rs @@ -0,0 +1,146 @@ +// Copyright 2014-2016 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. +use num_traits::Float; + +/// An iterator of a sequence of logarithmically spaced number. +/// +/// Iterator element type is `F`. +pub struct Logspace { + sign: F, + base: F, + start: F, + step: F, + index: usize, + len: usize, +} + +impl Iterator for Logspace +where + F: Float, +{ + type Item = F; + + #[inline] + fn next(&mut self) -> Option { + if self.index >= self.len { + None + } else { + // Calculate the value just like numpy.linspace does + let i = self.index; + self.index += 1; + let exponent = self.start + self.step * F::from(i).unwrap(); + Some(self.sign * self.base.powf(exponent)) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let n = self.len - self.index; + (n, Some(n)) + } +} + +impl DoubleEndedIterator for Logspace +where + F: Float, +{ + #[inline] + fn next_back(&mut self) -> Option { + if self.index >= self.len { + None + } else { + // Calculate the value just like numpy.linspace does + self.len -= 1; + let i = self.len; + let exponent = self.start + self.step * F::from(i).unwrap(); + Some(self.sign * self.base.powf(exponent)) + } + } +} + +impl ExactSizeIterator for Logspace where Logspace: Iterator {} + +/// An iterator of a sequence of logarithmically spaced number. +/// +/// The `Logspace` has `n` elements, where the first element is `base.powf(a)` +/// and the last element is `base.powf(b)`. If `base` is negative, this +/// iterator will return all negative values. +/// +/// Iterator element type is `F`, where `F` must be either `f32` or `f64`. +#[inline] +pub fn logspace(base: F, a: F, b: F, n: usize) -> Logspace +where + F: Float, +{ + let step = if n > 1 { + let nf: F = F::from(n).unwrap(); + (b - a) / (nf - F::one()) + } else { + F::zero() + }; + Logspace { + sign: base.signum(), + base: base.abs(), + start: a, + step: step, + index: 0, + len: n, + } +} + +#[cfg(test)] +mod tests { + use super::logspace; + use crate::{arr1, Array1}; + + #[test] + #[cfg(approx)] + fn valid() { + let array: Array1<_> = logspace(10.0, 0.0, 3.0, 4).collect(); + assert!(array.abs_diff_eq(&arr1(&[1e0, 1e1, 1e2, 1e3]), 1e-5)); + + let array: Array1<_> = logspace(10.0, 3.0, 0.0, 4).collect(); + assert!(array.abs_diff_eq(&arr1(&[1e3, 1e2, 1e1, 1e0]), 1e-5)); + + let array: Array1<_> = logspace(-10.0, 3.0, 0.0, 4).collect(); + assert!(array.abs_diff_eq(&arr1(&[-1e3, -1e2, -1e1, -1e0]), 1e-5)); + + let array: Array1<_> = logspace(-10.0, 0.0, 3.0, 4).collect(); + assert!(array.abs_diff_eq(&arr1(&[-1e0, -1e1, -1e2, -1e3]), 1e-5)); + } + + #[test] + fn iter_forward() { + let mut iter = logspace(10.0f64, 0.0, 3.0, 4); + + assert!(iter.size_hint() == (4, Some(4))); + + assert!((iter.next().unwrap() - 1e0).abs() < 1e-5); + assert!((iter.next().unwrap() - 1e1).abs() < 1e-5); + assert!((iter.next().unwrap() - 1e2).abs() < 1e-5); + assert!((iter.next().unwrap() - 1e3).abs() < 1e-5); + assert!(iter.next().is_none()); + + assert!(iter.size_hint() == (0, Some(0))); + } + + #[test] + fn iter_backward() { + let mut iter = logspace(10.0f64, 0.0, 3.0, 4); + + assert!(iter.size_hint() == (4, Some(4))); + + assert!((iter.next_back().unwrap() - 1e3).abs() < 1e-5); + assert!((iter.next_back().unwrap() - 1e2).abs() < 1e-5); + assert!((iter.next_back().unwrap() - 1e1).abs() < 1e-5); + assert!((iter.next_back().unwrap() - 1e0).abs() < 1e-5); + assert!(iter.next_back().is_none()); + + assert!(iter.size_hint() == (0, Some(0))); + } +} diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 0016e21c5..849bd2b65 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -46,6 +46,33 @@ impl ArrayBase sum } + /// Returns the [arithmetic mean] x̅ of all elements in the array: + /// + /// ```text + /// 1 n + /// x̅ = ― ∑ xᵢ + /// n i=1 + /// ``` + /// + /// If the array is empty, `None` is returned. + /// + /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. + /// + /// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean + pub fn mean(&self) -> Option + where + A: Clone + FromPrimitive + Add + Div + Zero + { + let n_elements = self.len(); + if n_elements == 0 { + None + } else { + let n_elements = A::from_usize(n_elements) + .expect("Converting number of elements to `A` must not fail."); + Some(self.sum() / n_elements) + } + } + /// Return the sum of all elements in the array. /// /// *This method has been renamed to `.sum()` and will be deprecated in the @@ -88,13 +115,13 @@ impl ArrayBase /// ``` /// use ndarray::{aview0, aview1, arr2, Axis}; /// - /// let a = arr2(&[[1., 2.], - /// [3., 4.]]); + /// let a = arr2(&[[1., 2., 3.], + /// [4., 5., 6.]]); /// assert!( - /// a.sum_axis(Axis(0)) == aview1(&[4., 6.]) && - /// a.sum_axis(Axis(1)) == aview1(&[3., 7.]) && + /// a.sum_axis(Axis(0)) == aview1(&[5., 7., 9.]) && + /// a.sum_axis(Axis(1)) == aview1(&[6., 15.]) && /// - /// a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&10.) + /// a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&21.) /// ); /// ``` /// @@ -123,27 +150,36 @@ impl ArrayBase /// Return mean along `axis`. /// - /// **Panics** if `axis` is out of bounds, if the length of the axis is - /// zero and division by zero panics for type `A`, or if `A::from_usize()` + /// Return `None` if the length of the axis is zero. + /// + /// **Panics** if `axis` is out of bounds or if `A::from_usize()` /// fails for the axis length. /// /// ``` - /// use ndarray::{aview1, arr2, Axis}; + /// use ndarray::{aview0, aview1, arr2, Axis}; /// - /// let a = arr2(&[[1., 2.], - /// [3., 4.]]); + /// let a = arr2(&[[1., 2., 3.], + /// [4., 5., 6.]]); /// assert!( - /// a.mean_axis(Axis(0)) == aview1(&[2.0, 3.0]) && - /// a.mean_axis(Axis(1)) == aview1(&[1.5, 3.5]) + /// a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) && + /// a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) && + /// + /// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5) /// ); /// ``` - pub fn mean_axis(&self, axis: Axis) -> Array + pub fn mean_axis(&self, axis: Axis) -> Option> where A: Clone + Zero + FromPrimitive + Add + Div, D: RemoveAxis, { - let n = A::from_usize(self.len_of(axis)).expect("Converting axis length to `A` must not fail."); - let sum = self.sum_axis(axis); - sum / &aview0(&n) + let axis_length = self.len_of(axis); + if axis_length == 0 { + None + } else { + let axis_length = A::from_usize(axis_length) + .expect("Converting axis length to `A` must not fail."); + let sum = self.sum_axis(axis); + Some(sum / &aview0(&axis_length)) + } } /// Return variance along `axis`. @@ -226,9 +262,9 @@ impl ArrayBase /// The standard deviation is defined as: /// /// ```text - /// 1 n - /// stddev = sqrt ( ―――――――― ∑ (xᵢ - x̅)² ) - /// n - ddof i=1 + /// ⎛ 1 n ⎞ + /// stddev = sqrt ⎜ ―――――――― ∑ (xᵢ - x̅)²⎟ + /// ⎝ n - ddof i=1 ⎠ /// ``` /// /// where @@ -270,6 +306,7 @@ impl ArrayBase /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. /// /// **Panics** if broadcasting to the same shape isn’t possible. + #[deprecated(note="Use `abs_diff_eq` - it requires the `approx` crate feature", since="0.13")] pub fn all_close(&self, rhs: &ArrayBase, tol: A) -> bool where A: Float, S2: Data, diff --git a/src/zip/mod.rs b/src/zip/mod.rs index b26e2aebb..fc8b99c7d 100644 --- a/src/zip/mod.rs +++ b/src/zip/mod.rs @@ -738,6 +738,32 @@ macro_rules! map_impl { }) } + /// Tests if every element of the iterator matches a predicate. + /// + /// Returns `true` if `predicate` evaluates to `true` for all elements. + /// Returns `true` if the input arrays are empty. + /// + /// Example: + /// + /// ``` + /// use ndarray::{array, Zip}; + /// let a = array![1, 2, 3]; + /// let b = array![1, 4, 9]; + /// assert!(Zip::from(&a).and(&b).all(|&a, &b| a * a == b)); + /// ``` + pub fn all(mut self, mut predicate: F) -> bool + where F: FnMut($($p::Item),*) -> bool + { + self.apply_core(true, move |_, args| { + let ($($p,)*) = args; + if predicate($($p),*) { + FoldWhile::Continue(true) + } else { + FoldWhile::Done(false) + } + }).into_inner() + } + expand_if!(@bool [$notlast] /// Include the producer `p` in the Zip. diff --git a/tests/array-construct.rs b/tests/array-construct.rs index 7e320c314..e14b9f9fd 100644 --- a/tests/array-construct.rs +++ b/tests/array-construct.rs @@ -21,6 +21,7 @@ fn test_dimension_zero() { } #[test] +#[cfg(features = "approx")] fn test_arc_into_owned() { let a = Array2::from_elem((5, 5), 1.).into_shared(); let mut b = a.clone(); @@ -28,7 +29,7 @@ fn test_arc_into_owned() { let mut c = b.into_owned(); c.fill(2.); // test that they are unshared - assert!(!a.all_close(&c, 0.01)); + assert_abs_diff_ne!(a, &c, 0.01); } #[test] diff --git a/tests/array.rs b/tests/array.rs index 28e2e7fbc..f9d42638d 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -12,8 +12,9 @@ use ndarray::{ multislice, }; use ndarray::indices; +use approx::AbsDiffEq; use defmac::defmac; -use itertools::{enumerate, zip}; +use itertools::{enumerate, zip, Itertools}; macro_rules! assert_panics { ($body:expr) => { @@ -680,6 +681,7 @@ fn test_sub_oob_1() { #[test] +#[cfg(feature = "approx")] fn test_select(){ // test for 2-d array let x = arr2(&[[0., 1.], [1.,0.],[1.,0.],[1.,0.],[1.,0.],[0., 1.],[0., 1.]]); @@ -687,8 +689,8 @@ fn test_select(){ let c = x.select(Axis(1),&[1]); let r_target = arr2(&[[1.,0.],[1.,0.],[0., 1.]]); let c_target = arr2(&[[1.,0.,0.,0.,0., 1., 1.]]); - assert!(r.all_close(&r_target,1e-8)); - assert!(c.all_close(&c_target.t(),1e-8)); + assert!(r.abs_diff_eq(&r_target,1e-8)); + assert!(c.abs_diff_eq(&c_target.t(),1e-8)); // test for 3-d array let y = arr3(&[[[1., 2., 3.], @@ -699,8 +701,8 @@ fn test_select(){ let c = y.select(Axis(2),&[1]); let r_target = arr3(&[[[1.5, 1.5, 3.]], [[1., 2.5, 3.]]]); let c_target = arr3(&[[[2.],[1.5]],[[2.],[2.5]]]); - assert!(r.all_close(&r_target,1e-8)); - assert!(c.all_close(&c_target,1e-8)); + assert!(r.abs_diff_eq(&r_target,1e-8)); + assert!(c.abs_diff_eq(&c_target,1e-8)); } @@ -925,175 +927,6 @@ fn assign() assert_eq!(a, arr2(&[[0, 0], [3, 4]])); } -#[test] -fn sum_mean() -{ - let a = arr2(&[[1., 2.], [3., 4.]]); - assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.])); - assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.])); - assert_eq!(a.mean_axis(Axis(0)), arr1(&[2., 3.])); - assert_eq!(a.mean_axis(Axis(1)), arr1(&[1.5, 3.5])); - assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.)); - assert_eq!(a.view().mean_axis(Axis(1)), aview1(&[1.5, 3.5])); - assert_eq!(a.sum(), 10.); -} - -#[test] -fn sum_mean_empty() { - assert_eq!(Array3::::ones((2, 0, 3)).sum(), 0.); - assert_eq!(Array1::::ones(0).sum_axis(Axis(0)), arr0(0.)); - assert_eq!( - Array3::::ones((2, 0, 3)).sum_axis(Axis(1)), - Array::zeros((2, 3)), - ); - let a = Array1::::ones(0).mean_axis(Axis(0)); - assert_eq!(a.shape(), &[]); - assert!(a[()].is_nan()); - let a = Array3::::ones((2, 0, 3)).mean_axis(Axis(1)); - assert_eq!(a.shape(), &[2, 3]); - a.mapv(|x| assert!(x.is_nan())); -} - -#[test] -fn var_axis() { - let a = array![ - [ - [-9.76, -0.38, 1.59, 6.23], - [-8.57, -9.27, 5.76, 6.01], - [-9.54, 5.09, 3.21, 6.56], - ], - [ - [ 8.23, -9.63, 3.76, -3.48], - [-5.46, 5.86, -2.81, 1.35], - [-1.08, 4.66, 8.34, -0.73], - ], - ]; - assert!(a.var_axis(Axis(0), 1.5).all_close( - &aview2(&[ - [3.236401e+02, 8.556250e+01, 4.708900e+00, 9.428410e+01], - [9.672100e+00, 2.289169e+02, 7.344490e+01, 2.171560e+01], - [7.157160e+01, 1.849000e-01, 2.631690e+01, 5.314410e+01] - ]), - 1e-4, - )); - assert!(a.var_axis(Axis(1), 1.7).all_close( - &aview2(&[ - [0.61676923, 80.81092308, 6.79892308, 0.11789744], - [75.19912821, 114.25235897, 48.32405128, 9.03020513], - ]), - 1e-8, - )); - assert!(a.var_axis(Axis(2), 2.3).all_close( - &aview2(&[ - [ 79.64552941, 129.09663235, 95.98929412], - [109.64952941, 43.28758824, 36.27439706], - ]), - 1e-8, - )); - - let b = array![[1.1, 2.3, 4.7]]; - assert!(b.var_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12)); - assert!(b.var_axis(Axis(1), 0.).all_close(&aview1(&[2.24]), 1e-12)); - - let c = array![[], []]; - assert_eq!(c.var_axis(Axis(0), 0.), aview1(&[])); - - let d = array![1.1, 2.7, 3.5, 4.9]; - assert!(d.var_axis(Axis(0), 0.).all_close(&aview0(&1.8875), 1e-12)); -} - -#[test] -fn std_axis() { - let a = array![ - [ - [ 0.22935481, 0.08030619, 0.60827517, 0.73684379], - [ 0.90339851, 0.82859436, 0.64020362, 0.2774583 ], - [ 0.44485313, 0.63316367, 0.11005111, 0.08656246] - ], - [ - [ 0.28924665, 0.44082454, 0.59837736, 0.41014531], - [ 0.08382316, 0.43259439, 0.1428889 , 0.44830176], - [ 0.51529756, 0.70111616, 0.20799415, 0.91851457] - ], - ]; - assert!(a.std_axis(Axis(0), 1.5).all_close( - &aview2(&[ - [ 0.05989184, 0.36051836, 0.00989781, 0.32669847], - [ 0.81957535, 0.39599997, 0.49731472, 0.17084346], - [ 0.07044443, 0.06795249, 0.09794304, 0.83195211], - ]), - 1e-4, - )); - assert!(a.std_axis(Axis(1), 1.7).all_close( - &aview2(&[ - [ 0.42698655, 0.48139215, 0.36874991, 0.41458724], - [ 0.26769097, 0.18941435, 0.30555015, 0.35118674], - ]), - 1e-8, - )); - assert!(a.std_axis(Axis(2), 2.3).all_close( - &aview2(&[ - [ 0.41117907, 0.37130425, 0.35332388], - [ 0.16905862, 0.25304841, 0.39978276], - ]), - 1e-8, - )); - - let b = array![[100000., 1., 0.01]]; - assert!(b.std_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12)); - assert!( - b.std_axis(Axis(1), 0.).all_close(&aview1(&[47140.214021552769]), 1e-6), - ); - - let c = array![[], []]; - assert_eq!(c.std_axis(Axis(0), 0.), aview1(&[])); -} - -#[test] -#[should_panic] -fn var_axis_negative_ddof() { - let a = array![1., 2., 3.]; - a.var_axis(Axis(0), -1.); -} - -#[test] -#[should_panic] -fn var_axis_too_large_ddof() { - let a = array![1., 2., 3.]; - a.var_axis(Axis(0), 4.); -} - -#[test] -fn var_axis_nan_ddof() { - let a = Array2::::zeros((2, 3)); - let v = a.var_axis(Axis(1), ::std::f64::NAN); - assert_eq!(v.shape(), &[2]); - v.mapv(|x| assert!(x.is_nan())); -} - -#[test] -fn var_axis_empty_axis() { - let a = Array2::::zeros((2, 0)); - let v = a.var_axis(Axis(1), 0.); - assert_eq!(v.shape(), &[2]); - v.mapv(|x| assert!(x.is_nan())); -} - -#[test] -#[should_panic] -fn std_axis_bad_dof() { - let a = array![1., 2., 3.]; - a.std_axis(Axis(0), 4.); -} - -#[test] -fn std_axis_empty_axis() { - let a = Array2::::zeros((2, 0)); - let v = a.std_axis(Axis(1), 0.); - assert_eq!(v.shape(), &[2]); - v.mapv(|x| assert!(x.is_nan())); -} - #[test] fn iter_size_hint() { @@ -1902,13 +1735,18 @@ fn test_contiguous() { } #[test] +#[cfg(feature = "approx")] fn test_all_close() { let c = arr3(&[[[1., 2., 3.], [1.5, 1.5, 3.]], [[1., 2., 3.], [1., 2.5, 3.]]]); - assert!(c.all_close(&aview1(&[1., 2., 3.]), 1.)); - assert!(!c.all_close(&aview1(&[1., 2., 3.]), 0.1)); + assert!( + c.abs_diff_eq(&aview1(&[1., 2., 3.]).broadcast(c.raw_dim()).unwrap(), 1.) + ); + assert!( + c.abs_diff_ne(&aview1(&[1., 2., 3.]).broadcast(c.raw_dim()).unwrap(), 0.1) + ); } #[test] @@ -2002,6 +1840,27 @@ fn test_map_axis() { let c = a.map_axis(Axis(1), |view| view.sum()); let answer2 = arr1(&[6, 15, 24, 33]); assert_eq!(c, answer2); + + // Test zero-length axis case + let arr = Array3::::zeros((3, 0, 4)); + let mut counter = 0; + let result = arr.map_axis(Axis(1), |x| { + assert_eq!(x.shape(), &[0]); + counter += 1; + counter + }); + assert_eq!(result.shape(), &[3, 4]); + itertools::assert_equal(result.iter().cloned().sorted(), 1..=3 * 4); + + let mut arr = Array3::::zeros((3, 0, 4)); + let mut counter = 0; + let result = arr.map_axis_mut(Axis(1), |x| { + assert_eq!(x.shape(), &[0]); + counter += 1; + counter + }); + assert_eq!(result.shape(), &[3, 4]); + itertools::assert_equal(result.iter().cloned().sorted(), 1..=3 * 4); } #[test] diff --git a/tests/azip.rs b/tests/azip.rs index a5e53e0e0..9e4efabce 100644 --- a/tests/azip.rs +++ b/tests/azip.rs @@ -4,7 +4,7 @@ extern crate itertools; use ndarray::prelude::*; use ndarray::Zip; -use itertools::{assert_equal, cloned, enumerate}; +use itertools::{assert_equal, cloned}; use std::mem::swap; @@ -45,17 +45,19 @@ fn test_azip2_3() { } #[test] +#[cfg(features = "approx")] fn test_azip2_sum() { let c = Array::from_shape_fn((5, 10), |(i, j)| f32::exp((i + j) as f32)); for i in 0..2 { let ax = Axis(i); let mut b = Array::zeros(c.len_of(ax)); azip!(mut b, ref c (c.axis_iter(ax)) in { *b = c.sum() }); - assert!(b.all_close(&c.sum_axis(Axis(1 - i)), 1e-6)); + assert_abs_diff_eq!(b, c.sum_axis(Axis(1 - i)), 1e-6); } } #[test] +#[cfg(features = "approx")] fn test_azip3_slices() { let mut a = [0.; 32]; let mut b = [0.; 32]; @@ -69,10 +71,11 @@ fn test_azip3_slices() { *c = a.sin(); }); let res = Array::linspace(0., 3.1, 32).mapv_into(f32::sin); - assert!(res.all_close(&ArrayView::from(&c), 1e-4)); + assert_abs_diff_eq!(res, &ArrayView::from(&c), 1e-4); } #[test] +#[cfg(features = "approx")] fn test_broadcast() { let n = 16; let mut a = Array::::zeros((n, n)); @@ -90,7 +93,7 @@ fn test_broadcast() { .and_broadcast(&e); z.apply(|x, &y, &z, &w| *x = y + z + w); } - assert!(a.all_close(&(&b + &d + &e), 1e-4)); + assert_abs_diff_eq!(a, &(&b + &d + &e), 1e-4); } #[should_panic] @@ -271,3 +274,22 @@ fn test_indices_split_1() { } } } + +#[test] +fn test_zip_all() { + let a = Array::::zeros(62); + let b = Array::::ones(62); + let mut c = Array::::ones(62); + c[5] = 0.0; + assert_eq!(true, Zip::from(&a).and(&b).all(|&x, &y| x + y == 1.0)); + assert_eq!(false, Zip::from(&a).and(&b).all(|&x, &y| x == y)); + assert_eq!(false, Zip::from(&a).and(&c).all(|&x, &y| x + y == 1.0)); +} + +#[test] +fn test_zip_all_empty_array() { + let a = Array::::zeros(0); + let b = Array::::ones(0); + assert_eq!(true, Zip::from(&a).and(&b).all(|&_x, &_y| true)); + assert_eq!(true, Zip::from(&a).and(&b).all(|&_x, &_y| false)); +} diff --git a/tests/complex.rs b/tests/complex.rs index 8da721c4d..a7449e9f8 100644 --- a/tests/complex.rs +++ b/tests/complex.rs @@ -22,5 +22,5 @@ fn complex_mat_mul() let r = a.dot(&e); println!("{}", a); assert_eq!(r, a); - assert_eq!(a.mean_axis(Axis(0)), arr1(&[c(1.5, 1.), c(2.5, 0.)])); + assert_eq!(a.mean_axis(Axis(0)).unwrap(), arr1(&[c(1.5, 1.), c(2.5, 0.)])); } diff --git a/tests/numeric.rs b/tests/numeric.rs new file mode 100644 index 000000000..a9aaf7c91 --- /dev/null +++ b/tests/numeric.rs @@ -0,0 +1,206 @@ +extern crate approx; +use std::f64; +use ndarray::{array, Axis, aview1, arr0, arr1, arr2, Array, Array1, Array2, Array3}; +use approx::abs_diff_eq; + +#[test] +fn test_mean_with_nan_values() { + let a = array![f64::NAN, 1.]; + assert!(a.mean().unwrap().is_nan()); +} + +#[test] +fn test_mean_with_empty_array_of_floats() { + let a: Array1 = array![]; + assert!(a.mean().is_none()); +} + +#[test] +fn test_mean_with_array_of_floats() { + let a: Array1 = array![ + 0.99889651, 0.0150731 , 0.28492482, 0.83819218, 0.48413156, + 0.80710412, 0.41762936, 0.22879429, 0.43997224, 0.23831807, + 0.02416466, 0.6269962 , 0.47420614, 0.56275487, 0.78995021, + 0.16060581, 0.64635041, 0.34876609, 0.78543249, 0.19938356, + 0.34429457, 0.88072369, 0.17638164, 0.60819363, 0.250392 , + 0.69912532, 0.78855523, 0.79140914, 0.85084218, 0.31839879, + 0.63381769, 0.22421048, 0.70760302, 0.99216018, 0.80199153, + 0.19239188, 0.61356023, 0.31505352, 0.06120481, 0.66417377, + 0.63608897, 0.84959691, 0.43599069, 0.77867775, 0.88267754, + 0.83003623, 0.67016118, 0.67547638, 0.65220036, 0.68043427 + ]; + // Computed using NumPy + let expected_mean = 0.5475494059146699; + abs_diff_eq!(a.mean().unwrap(), expected_mean, epsilon = f64::EPSILON); +} + +#[test] +fn sum_mean() +{ + let a = arr2(&[[1., 2.], [3., 4.]]); + assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.])); + assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.])); + assert_eq!(a.mean_axis(Axis(0)), Some(arr1(&[2., 3.]))); + assert_eq!(a.mean_axis(Axis(1)), Some(arr1(&[1.5, 3.5]))); + assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.)); + assert_eq!(a.view().mean_axis(Axis(1)).unwrap(), aview1(&[1.5, 3.5])); + assert_eq!(a.sum(), 10.); +} + +#[test] +fn sum_mean_empty() { + assert_eq!(Array3::::ones((2, 0, 3)).sum(), 0.); + assert_eq!(Array1::::ones(0).sum_axis(Axis(0)), arr0(0.)); + assert_eq!( + Array3::::ones((2, 0, 3)).sum_axis(Axis(1)), + Array::zeros((2, 3)), + ); + let a = Array1::::ones(0).mean_axis(Axis(0)); + assert_eq!(a, None); + let a = Array3::::ones((2, 0, 3)).mean_axis(Axis(1)); + assert_eq!(a, None); +} + +#[test] +#[cfg(features = "approx")] +fn var_axis() { + let a = array![ + [ + [-9.76, -0.38, 1.59, 6.23], + [-8.57, -9.27, 5.76, 6.01], + [-9.54, 5.09, 3.21, 6.56], + ], + [ + [ 8.23, -9.63, 3.76, -3.48], + [-5.46, 5.86, -2.81, 1.35], + [-1.08, 4.66, 8.34, -0.73], + ], + ]; + assert_abs_diff_eq!( + a.var_axis(Axis(0), 1.5), + &aview2(&[ + [3.236401e+02, 8.556250e+01, 4.708900e+00, 9.428410e+01], + [9.672100e+00, 2.289169e+02, 7.344490e+01, 2.171560e+01], + [7.157160e+01, 1.849000e-01, 2.631690e+01, 5.314410e+01] + ]), + 1e-4, + ); + assert_abs_diff_eq!(a.var_axis(Axis(1), 1.7), + &aview2(&[ + [0.61676923, 80.81092308, 6.79892308, 0.11789744], + [75.19912821, 114.25235897, 48.32405128, 9.03020513], + ]), + 1e-8, + ); + assert_abs_diff_eq!(a.var_axis(Axis(2), 2.3), + &aview2(&[ + [ 79.64552941, 129.09663235, 95.98929412], + [109.64952941, 43.28758824, 36.27439706], + ]), + 1e-8, + ); + + let b = array![[1.1, 2.3, 4.7]]; + assert_abs_diff_eq!(b.var_axis(Axis(0), 0.), &aview1(&[0., 0., 0.]), 1e-12); + assert_abs_diff_eq!(b.var_axis(Axis(1), 0.), &aview1(&[2.24]), 1e-12); + + let c = array![[], []]; + assert_eq!(c.var_axis(Axis(0), 0.), aview1(&[])); + + let d = array![1.1, 2.7, 3.5, 4.9]; + assert_abs_diff_eq!(d.var_axis(Axis(0), 0.), &aview0(&1.8875), 1e-12); +} + +#[test] +#[cfg(features = "approx")] +fn std_axis() { + let a = array![ + [ + [ 0.22935481, 0.08030619, 0.60827517, 0.73684379], + [ 0.90339851, 0.82859436, 0.64020362, 0.2774583 ], + [ 0.44485313, 0.63316367, 0.11005111, 0.08656246] + ], + [ + [ 0.28924665, 0.44082454, 0.59837736, 0.41014531], + [ 0.08382316, 0.43259439, 0.1428889 , 0.44830176], + [ 0.51529756, 0.70111616, 0.20799415, 0.91851457] + ], + ]; + assert_abs_diff_eq!(a.std_axis(Axis(0), 1.5), + &aview2(&[ + [ 0.05989184, 0.36051836, 0.00989781, 0.32669847], + [ 0.81957535, 0.39599997, 0.49731472, 0.17084346], + [ 0.07044443, 0.06795249, 0.09794304, 0.83195211], + ]), + 1e-4, + ); + assert_abs_diff_eq!(a.std_axis(Axis(1), 1.7), + &aview2(&[ + [ 0.42698655, 0.48139215, 0.36874991, 0.41458724], + [ 0.26769097, 0.18941435, 0.30555015, 0.35118674], + ]), + 1e-8, + ); + assert_abs_diff_eq!(a.std_axis(Axis(2), 2.3), + &aview2(&[ + [ 0.41117907, 0.37130425, 0.35332388], + [ 0.16905862, 0.25304841, 0.39978276], + ]), + 1e-8, + ); + + let b = array![[100000., 1., 0.01]]; + assert_abs_diff_eq!(b.std_axis(Axis(0), 0.), &aview1(&[0., 0., 0.]), 1e-12); + assert_abs_diff_eq!( + b.std_axis(Axis(1), 0.), &aview1(&[47140.214021552769]), 1e-6, + ); + + let c = array![[], []]; + assert_eq!(c.std_axis(Axis(0), 0.), aview1(&[])); +} + +#[test] +#[should_panic] +fn var_axis_negative_ddof() { + let a = array![1., 2., 3.]; + a.var_axis(Axis(0), -1.); +} + +#[test] +#[should_panic] +fn var_axis_too_large_ddof() { + let a = array![1., 2., 3.]; + a.var_axis(Axis(0), 4.); +} + +#[test] +fn var_axis_nan_ddof() { + let a = Array2::::zeros((2, 3)); + let v = a.var_axis(Axis(1), ::std::f64::NAN); + assert_eq!(v.shape(), &[2]); + v.mapv(|x| assert!(x.is_nan())); +} + +#[test] +fn var_axis_empty_axis() { + let a = Array2::::zeros((2, 0)); + let v = a.var_axis(Axis(1), 0.); + assert_eq!(v.shape(), &[2]); + v.mapv(|x| assert!(x.is_nan())); +} + +#[test] +#[should_panic] +fn std_axis_bad_dof() { + let a = array![1., 2., 3.]; + a.std_axis(Axis(0), 4.); +} + +#[test] +fn std_axis_empty_axis() { + let a = Array2::::zeros((2, 0)); + let v = a.std_axis(Axis(1), 0.); + assert_eq!(v.shape(), &[2]); + v.mapv(|x| assert!(x.is_nan())); +} + diff --git a/tests/par_azip.rs b/tests/par_azip.rs index 4ffe5b347..d373345e4 100644 --- a/tests/par_azip.rs +++ b/tests/par_azip.rs @@ -25,6 +25,7 @@ fn test_par_azip2() { } #[test] +#[cfg(features = "approx")] fn test_par_azip3() { let mut a = [0.; 32]; let mut b = [0.; 32]; @@ -38,7 +39,7 @@ fn test_par_azip3() { *c = a.sin(); }); let res = Array::linspace(0., 3.1, 32).mapv_into(f32::sin); - assert!(res.all_close(&ArrayView::from(&c), 1e-4)); + assert_abs_diff_eq!(res, &ArrayView::from(&c), 1e-4); } #[should_panic] diff --git a/tests/par_rayon.rs b/tests/par_rayon.rs index 70b69e0eb..6dda0a1d1 100644 --- a/tests/par_rayon.rs +++ b/tests/par_rayon.rs @@ -24,12 +24,13 @@ fn test_axis_iter() { } #[test] +#[cfg(features = "approx")] fn test_axis_iter_mut() { let mut a = Array::linspace(0., 1.0f64, M * N).into_shape((M, N)).unwrap(); let b = a.mapv(|x| x.exp()); a.axis_iter_mut(Axis(0)).into_par_iter().for_each(|mut v| v.mapv_inplace(|x| x.exp())); println!("{:?}", a.slice(s![..10, ..5])); - assert!(a.all_close(&b, 0.001)); + assert_abs_diff_eq!(a, &b, 0.001); } #[test]