diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs
index 4f03d3aac..b2ef92e43 100644
--- a/src/dimension/dimension_trait.rs
+++ b/src/dimension/dimension_trait.rs
@@ -540,6 +540,14 @@ impl Dimension for Dim<[Ix; 1]> {
     fn try_remove_axis(&self, axis: Axis) -> Self::Smaller {
         self.remove_axis(axis)
     }
+
+    fn from_dimension<D2: Dimension>(d: &D2) -> Option<Self> {
+        if 1 == d.ndim() {
+            Some(Ix1(d[0]))
+        } else {
+            None
+        }
+    }
     private_impl! {}
 }
 
diff --git a/src/parallel/mod.rs b/src/parallel/mod.rs
index 82cd1cba6..979238aab 100644
--- a/src/parallel/mod.rs
+++ b/src/parallel/mod.rs
@@ -28,9 +28,12 @@
 //! Note that you can use the parallel iterator for [Zip] to access all other
 //! rayon parallel iterator methods.
 //!
-//! Only the axis iterators are indexed parallel iterators, the rest are all
-//! “unindexed”. Use ndarray’s [Zip] for lock step parallel iteration of
-//! multiple arrays or producers at a time.
+//! Only the axis iterators and one-dimensional array views make indexed parallel iterators, the
+//! rest are all “unindexed” in rayon terms.
+//!
+//! Ndarray’s [Zip] is specially made to handle lock step parallel iteration of multiple arrays or
+//! producers at a time, and it can handle multidimensional inputs efficiently, even with
+//! multidimensional indexing. [Zip] is always recommended over rayon `.zip()` for performance.
 //!
 //! # Examples
 //!
diff --git a/src/parallel/par.rs b/src/parallel/par.rs
index efd761acf..2bf95671d 100644
--- a/src/parallel/par.rs
+++ b/src/parallel/par.rs
@@ -1,4 +1,5 @@
 use rayon::iter::plumbing::bridge;
+use rayon::iter::plumbing::bridge_producer_consumer;
 use rayon::iter::plumbing::bridge_unindexed;
 use rayon::iter::plumbing::Folder;
 use rayon::iter::plumbing::Producer;
@@ -13,7 +14,7 @@ use crate::iter::AxisChunksIter;
 use crate::iter::AxisChunksIterMut;
 use crate::iter::AxisIter;
 use crate::iter::AxisIterMut;
-use crate::Dimension;
+use crate::{Dimension, Ix1, Axis};
 use crate::{ArrayView, ArrayViewMut};
 
 /// Parallel iterator wrapper.
@@ -143,11 +144,23 @@ macro_rules! par_iter_view_wrapper {
         fn drive_unindexed<C>(self, consumer: C) -> C::Result
             where C: UnindexedConsumer<Self::Item>
         {
-            bridge_unindexed(ParallelProducer(self.iter), consumer)
+            // Self is an IndexedParallelIterator when dimension is Ix1: use the indexed driver in that
+            // case, otherwise the general `bridge_unindexed`.
+            if let Some(1) = D::NDIM {
+                let iter = self.iter.into_dimensionality::<Ix1>().unwrap();
+                bridge_producer_consumer(iter.len(), ParallelProducer(iter), consumer)
+            } else {
+                bridge_unindexed(ParallelProducer(self.iter), consumer)
+            }
         }
 
         fn opt_len(&self) -> Option<usize> {
-            None
+            // The length is known for Ix1 because `drive_unindexed` uses the indexed bridge there.
+            if let Some(1) = D::NDIM {
+                Some(self.iter.len())
+            } else {
+                None
+            }
         }
     }
 
@@ -156,6 +169,7 @@ macro_rules! par_iter_view_wrapper {
               A: $($thread_bounds)*,
     {
         type Item = <$view_name<'a, A, D> as IntoIterator>::Item;
+
         fn split(self) -> (Self, Option<Self>) {
             if self.0.len() <= 1 {
                 return (self, None)
@@ -170,7 +184,56 @@ macro_rules! par_iter_view_wrapper {
         fn fold_with<F>(self, folder: F) -> F
             where F: Folder<Self::Item>,
         {
-            self.into_iter().fold(folder, move |f, elt| f.consume(elt))
+            Zip::from(self.0).fold_while(folder, |mut folder, elt| {
+                folder = folder.consume(elt);
+                if folder.full() {
+                    FoldWhile::Done(folder)
+                } else {
+                    FoldWhile::Continue(folder)
+                }
+            }).into_inner()
+        }
+    }
+
+    impl<'a, A> Producer for ParallelProducer<$view_name<'a, A, Ix1>>
+        where A: $($thread_bounds)*,
+    {
+        type Item = <$view_name<'a, A, Ix1> as IntoIterator>::Item;
+        type IntoIter = <$view_name<'a, A, Ix1> as IntoIterator>::IntoIter;
+
+        fn into_iter(self) -> Self::IntoIter {
+            self.0.into_iter()
+        }
+
+        fn split_at(self, index: usize) -> (Self, Self) {
+            let (a, b) = self.0.split_at(Axis(0), index);
+            (ParallelProducer(a), ParallelProducer(b))
+        }
+
+        fn fold_with<F>(self, folder: F) -> F
+            where F: Folder<Self::Item>,
+        {
+            UnindexedProducer::fold_with(self, folder)
+        }
+    }
+
+    impl<'a, A> IndexedParallelIterator for Parallel<$view_name<'a, A, Ix1>>
+        where A: $($thread_bounds)*,
+    {
+        fn with_producer<Cb>(self, callback: Cb) -> Cb::Output
+            where Cb: ProducerCallback<Self::Item>
+        {
+            callback.callback(ParallelProducer(self.iter))
+        }
+
+        fn len(&self) -> usize {
+            self.iter.len()
+        }
+
+        fn drive<C>(self, consumer: C) -> C::Result
+            where C: Consumer<Self::Item>
+        {
+            bridge(self, consumer)
         }
     }
 
@@ -180,6 +243,7 @@ macro_rules! par_iter_view_wrapper {
     {
         type Item = <$view_name<'a, A, D> as IntoIterator>::Item;
         type IntoIter = <$view_name<'a, A, D> as IntoIterator>::IntoIter;
+
         fn into_iter(self) -> Self::IntoIter {
             self.0.into_iter()
         }
diff --git a/tests/par_rayon.rs b/tests/par_rayon.rs
index 4d5a8f1a9..f927ca0fe 100644
--- a/tests/par_rayon.rs
+++ b/tests/par_rayon.rs
@@ -2,6 +2,8 @@
 
 use ndarray::parallel::prelude::*;
 use ndarray::prelude::*;
+use std::iter::FromIterator;
+use std::iter::repeat;
 
 const M: usize = 1024 * 10;
 const N: usize = 100;
@@ -86,3 +88,21 @@ fn test_axis_chunks_iter_mut() {
     println!("{:?}", a.slice(s![..10, ..5]));
     assert_abs_diff_eq!(a, b, epsilon = 0.001);
 }
+
+#[test]
+fn view_1d_indexeded() {
+    // test that .zip() can be used on 1D ArrayViews.
+    let mut a = Array::from_iter(0..((M * N) as i64)).into_shape((M, N)).unwrap();
+
+    // For columns A0 and A1, compute A0 = A1 - A0 (== 1)
+    let (a0, a1) = a.multi_slice_mut((s![.., 0], s![.., 1]));
+    let a1_items = a1.view().into_par_iter().cloned().collect::<Vec<_>>();
+
+    a0.into_par_iter()
+        .zip(a1.view())
+        .for_each(|(x, &y)| *x = y - *x);
+
+    assert_eq!(a.column(0), Array::from_iter(repeat(1).take(M)));
+
+    assert_eq!(a.column(1), Array::from(a1_items));
+}