Skip to content

Commit fe6b87b

Browse files
committed
more helpers for NA bitmasking
1 parent 352f344 commit fe6b87b

File tree

4 files changed

+95
-32
lines changed

4 files changed

+95
-32
lines changed

inst/examples/example-na-handling-variance.cpp

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,11 @@ using namespace RcppNT2;
99
#include <Rcpp.h>
1010
using namespace Rcpp;
1111

12-
// Functor from the previous version of this example; this commit
// removes it (the '-' lines) in favour of the 'na::mask' helper.
//
// For the given 'data' it returns 'bitwise::ones()' where
// 'data == data' holds and 'bitwise::zeroes()' otherwise (assuming
// 'nt2::if_else(c, a, b)' selects 'a' where 'c' is true). Note that,
// despite the name, the "ones" result therefore marks the elements
// that compare equal to themselves, i.e. the ones that are not NaN.
struct IsNaN
13-
{
14-
template <typename T>
15-
T operator()(const T& data)
16-
{
17-
return nt2::if_else(data == data, bitwise::ones(), bitwise::zeroes());
18-
}
19-
};
20-
21-
// Accumulating functor from the previous version of this example;
// this commit removes it (the '-' lines).
//
// Each call adds to 'result_' the sum of a 1.0 / 0.0 indicator of
// 't != 0.0'. The running total is read back through the
// conversion to 'double'.
struct NumNonZero
22-
{
23-
template <typename T>
24-
void operator()(const T& t)
25-
{
26-
result_ += nt2::sum(nt2::if_else(t != 0.0, 1.0, 0.0));
27-
}
28-
29-
operator double() const { return result_; }
30-
31-
double result_ = 0.0;
32-
};
33-
12+
// A 'mask-aware' Sum class. The call operator accepts both
13+
// 'data' and a 'mask' of the same type; one can use
14+
// 'nt2::bitwise_and' to apply the mask. In our case, we'll
15+
// have the mask set as bitwise '0' for NA, and bitwise '1'
16+
// otherwise.
3417
class Sum
3518
{
3619
public:
@@ -50,6 +33,9 @@ class Sum
5033

5134
};
5235

36+
// A 'mask-aware' Sum of Squares class. As above, the
37+
// call operator accepts both a 'data' and a 'mask'.
38+
// We use 'nt2::bitwise_and' to apply the mask.
5339
class SumOfSquares
5440
{
5541
public:
@@ -73,16 +59,14 @@ class SumOfSquares
7359
// [[Rcpp::export]]
7460
double simdVariance(NumericVector x)
7561
{
76-
// Compute our NA bitmask. It's a vector of 'double's,
77-
// as we want to ensure SIMD 'double' x 'double'
78-
// instructions can be emitted when interacting with
79-
// our bitmask. Note that we use our 'begin' and 'end'
80-
// helpers
62+
// Use helpers to compute the NA bitmask, as well as
63+
// the number of non-NA elements in the vector.
8164
std::vector<double> naMask(x.size());
82-
simdTransform(&x[0], &x[0] + x.size(), &naMask[0], IsNaN());
65+
std::size_t n;
66+
na::mask(pbegin(x), pend(x), pbegin(naMask), &n);
8367

84-
// Using our bitmask, compute the number of non-NA elements.
85-
double N = simdFor(&naMask[0], &naMask[0] + naMask.size(), NumNonZero());
68+
// Compute the number of non-NA elements.
69+
std::size_t N = x.size() - n;
8670

8771
// Compute the sum of our 'x' vector, discarding NAs.
8872
double total = simdFor(&x[0], &x[0] + x.size(), &naMask[0], Sum());

inst/include/RcppNT2.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include <RcppNT2/bitwise/bitwise.h>
4545
#include <RcppNT2/convert/convert.h>
4646
#include <RcppNT2/functor/functor.h>
47+
#include <RcppNT2/na/na.h>
4748
#include <RcppNT2/traits/traits.h>
4849
#include <RcppNT2/variadic/variadic.h>
4950

inst/include/RcppNT2/core/core.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
namespace RcppNT2 {
77

88
// Returns the address of the first element of 't', i.e. '&t[0]',
// with 't' perfectly forwarded into the subscript expression.
// The '-' / '+' signature pair below records this commit's rename
// of the helper from 'begin' to 'pbegin'.
template <typename T>
9-
auto begin(T&& t) -> decltype(&std::forward<T>(t)[0])
9+
auto pbegin(T&& t) -> decltype(&std::forward<T>(t)[0])
1010
{
1111
return &std::forward<T>(t)[0];
1212
}
1313

1414
// Returns the address one past the last element of 't', computed
// as '&t[0] + t.size()'. The '-' / '+' signature pair below records
// this commit's rename of the helper from 'end' to 'pend'.
template <typename T>
15-
auto end(T&& t) -> decltype(&std::forward<T>(t)[0])
15+
auto pend(T&& t) -> decltype(&std::forward<T>(t)[0])
1616
{
1717
// 't' is forwarded twice; neither the subscript nor 'size()'
// is expected to consume it -- NOTE(review): confirm for any
// rvalue containers passed here.
return &std::forward<T>(t)[0] + std::forward<T>(t).size();
1818
}

inst/include/RcppNT2/na/na.h

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#ifndef RCPP_NT2_NA_NA_H
2+
#define RCPP_NT2_NA_NA_H
3+
4+
#include <cstddef>
#include <functional>

#include <RcppNT2/bitwise/bitwise.h>
5+
6+
namespace RcppNT2 {
7+
namespace na {
8+
9+
namespace detail {
10+
11+
struct Masker
12+
{
13+
template <typename T>
14+
T operator()(const T& data)
15+
{
16+
static const double ones = bitwise::ones();
17+
static const double zeroes = bitwise::zeroes();
18+
return nt2::if_else(data == data, ones, zeroes);
19+
}
20+
};
21+
22+
struct MaskerWithCount
23+
{
24+
25+
template <typename T>
26+
T operator()(const T& data)
27+
{
28+
static const double ones = bitwise::ones();
29+
static const double zeroes = bitwise::zeroes();
30+
auto&& compare = data == data;
31+
count_ += nt2::sum(nt2::if_else(compare, 1.0, 0.0));
32+
return nt2::if_else(compare, ones, zeroes);
33+
}
34+
35+
std::size_t count() { return count_; }
36+
std::size_t count_ = 0;
37+
};
38+
39+
struct Counter
40+
{
41+
template <typename T>
42+
void operator()(const T& t)
43+
{
44+
count_ += nt2::sum(t != t);
45+
}
46+
47+
template <typename T>
48+
operator T() const { return T(count_); }
49+
50+
std::size_t count_ = 0;
51+
};
52+
53+
} // namespace detail
54+
55+
template <typename T, typename U>
56+
void mask(const T* begin, const T* end, U* out)
57+
{
58+
simdTransform(begin, end, out, detail::Masker());
59+
}
60+
61+
template <typename T, typename U, typename V>
62+
void mask(const T* begin, const T* end, U* out, V* count)
63+
{
64+
detail::MaskerWithCount masker;
65+
simdTransform(begin, end, out, std::reference_wrapper<detail::MaskerWithCount>(masker));
66+
*count = masker.count();
67+
}
68+
69+
template <typename T>
70+
std::size_t count(const T& t)
71+
{
72+
return simdReduce(pbegin(t), pend(t), std::size_t(0), detail::Counter());
73+
}
74+
75+
} // namespace na
76+
} // namespace RcppNT2
77+
78+
#endif /* RCPP_NT2_NA_NA_H */

0 commit comments

Comments
 (0)