Skip to content

Commit 4fa34ae

Browse files
committed
Initial implementation of ParallelIterator for AxisIter
1 parent 6f75925 commit 4fa34ae

File tree

5 files changed

+103
-0
lines changed

5 files changed

+103
-0
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ optional = true
3939
blas-sys = { version = "0.6.3", optional = true, default-features = false }
4040
matrixmultiply = { version = "0.1.11" }
4141

42+
rayon = { version = "0.5.0", optional = true, default-features = false }
43+
4244
[dependencies.serde]
4345
version = "0.8"
4446
optional = true

src/iterators/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ use super::{
2121
Axis,
2222
};
2323

24+
#[cfg(feature = "rayon")]
25+
mod par;
26+
2427
/// Base for array iterators
2528
///
2629
/// Iterator element type is `&'a A`.

src/iterators/par.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
2+
3+
use rayon::par_iter::ParallelIterator;
4+
use rayon::par_iter::IndexedParallelIterator;
5+
use rayon::par_iter::ExactParallelIterator;
6+
use rayon::par_iter::BoundedParallelIterator;
7+
use rayon::par_iter::internal::{Consumer, UnindexedConsumer};
8+
use rayon::par_iter::internal::bridge;
9+
use rayon::par_iter::internal::ProducerCallback;
10+
use rayon::par_iter::internal::Producer;
11+
12+
use super::AxisIter;
13+
use imp_prelude::*;
14+
15+
16+
17+
impl<'a, A, D> ParallelIterator for AxisIter<'a, A, D>
18+
where D: Dimension,
19+
A: Sync,
20+
{
21+
type Item = <Self as Iterator>::Item;
22+
fn drive_unindexed<C>(self, consumer: C) -> C::Result
23+
where C: UnindexedConsumer<Self::Item>
24+
{
25+
bridge(self, consumer)
26+
}
27+
}
28+
29+
impl<'a, A, D> IndexedParallelIterator for AxisIter<'a, A, D>
30+
where D: Dimension,
31+
A: Sync,
32+
{
33+
fn with_producer<Cb>(self, callback: Cb) -> Cb::Output
34+
where Cb: ProducerCallback<Self::Item>
35+
{
36+
callback.callback(self)
37+
}
38+
}
39+
40+
impl<'a, A, D> ExactParallelIterator for AxisIter<'a, A, D>
41+
where D: Dimension,
42+
A: Sync,
43+
{
44+
fn len(&mut self) -> usize {
45+
self.size_hint().0
46+
}
47+
}
48+
49+
impl<'a, A, D> BoundedParallelIterator for AxisIter<'a, A, D>
50+
where D: Dimension,
51+
A: Sync,
52+
{
53+
fn upper_bound(&mut self) -> usize {
54+
ExactParallelIterator::len(self)
55+
}
56+
57+
fn drive<C>(self, consumer: C) -> C::Result
58+
where C: Consumer<Self::Item>
59+
{
60+
bridge(self, consumer)
61+
}
62+
}
63+
64+
// This is the real magic, I guess
65+
66+
impl<'a, A, D> Producer for AxisIter<'a, A, D>
67+
where D: Dimension,
68+
A: Sync,
69+
{
70+
fn cost(&mut self, len: usize) -> f64 {
71+
// FIXME: No idea about what this is
72+
len as f64
73+
}
74+
75+
fn split_at(self, i: usize) -> (Self, Self) {
76+
self.split_at(i)
77+
}
78+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ extern crate matrixmultiply;
7777
extern crate itertools;
7878
extern crate num_traits as libnum;
7979
extern crate num_complex;
80+
#[cfg(feature = "rayon")]
81+
extern crate rayon;
8082

8183
use std::iter::Zip;
8284
use std::marker::PhantomData;

tests/rayon.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#![cfg(feature = "rayon")]
2+
3+
extern crate rayon;
4+
extern crate ndarray;
5+
6+
use ndarray::prelude::*;
7+
8+
use rayon::par_iter::ParallelIterator;
9+
10+
#[test]
11+
fn test_axis_iter() {
12+
let mut a = Array2::<u32>::zeros((10240, 10240));
13+
for (i, mut v) in a.axis_iter_mut(Axis(0)).enumerate() {
14+
v.fill(i as _);
15+
}
16+
let s = ParallelIterator::map(a.axis_iter(Axis(0)), |x| x.scalar_sum()).sum();
17+
assert_eq!(s, a.scalar_sum());
18+
}

0 commit comments

Comments
 (0)