66// Collection of in-kernel scan / prefix sum utilities
77
88// Inclusive prefix sum using shared memory
9- template <typename T, bool KillWARDependency>
10- __device__ void inclusivePrefixSum (T* smem, T in, T* out) {
9+ template <typename T, bool KillWARDependency, class BinaryFunction >
10+ __device__ void inclusivePrefixScan (T* smem, T in, T* out, BinaryFunction binop ) {
1111 // FIXME: this is a slow, simple implementation; need up/down sweep,
1212 // prevent smem conflicts
1313 smem[threadIdx .x ] = in;
@@ -18,7 +18,7 @@ __device__ void inclusivePrefixSum(T* smem, T in, T* out) {
1818 T val = 0 ;
1919
2020 if (threadIdx .x >= offset) {
21- val = smem[threadIdx .x - offset] + smem[threadIdx .x ];
21+ val = binop ( smem[threadIdx .x - offset], smem[threadIdx .x ]) ;
2222 }
2323
2424 __syncthreads ();
@@ -38,11 +38,11 @@ __device__ void inclusivePrefixSum(T* smem, T in, T* out) {
3838}
3939
4040// Exclusive prefix sum using shared memory
41- template <typename T, bool KillWARDependency>
42- __device__ void exclusivePrefixSum (T* smem, T in, T* out, T* carry) {
41+ template <typename T, bool KillWARDependency, class BinaryFunction >
42+ __device__ void exclusivePrefixScan (T* smem, T in, T* out, T* carry, BinaryFunction binop ) {
4343 // FIXME: crappy implementation
4444 // We kill write-after-read dependencies separately below, hence the `false`
45- inclusivePrefixSum <T, false >(smem, in, out);
45+ inclusivePrefixScan <T, false , BinaryFunction >(smem, in, out, binop );
4646
4747 *out -= in;
4848 *carry = smem[blockDim .x - 1 ];
@@ -55,8 +55,8 @@ __device__ void exclusivePrefixSum(T* smem, T in, T* out, T* carry) {
5555
5656// Inclusive prefix sum for binary vars using intra-warp voting +
5757// shared memory
58- template <typename T, bool KillWARDependency>
59- __device__ void inclusiveBinaryPrefixSum (T* smem, bool in, T* out) {
58+ template <typename T, bool KillWARDependency, class BinaryFunction >
59+ __device__ void inclusiveBinaryPrefixScan (T* smem, bool in, T* out, BinaryFunction binop ) {
6060 // Within-warp, we use warp voting.
6161 T vote = __ballot (in);
6262 T index = __popc (getLaneMaskLe () & vote);
@@ -77,16 +77,16 @@ __device__ void inclusiveBinaryPrefixSum(T* smem, bool in, T* out) {
7777 int current = 0 ;
7878 for (int i = 0 ; i < blockDim .x / 32 ; ++i) {
7979 T v = smem[i];
80- smem[i] += current;
81- current += v ;
80+ smem[i] = binop (smem[i], current) ;
81+ current = binop (current, v) ;
8282 }
8383 }
8484
8585 __syncthreads ();
8686
8787 // load the carry from the preceding warp
8888 if (warp >= 1 ) {
89- index += smem[warp - 1 ];
89+ index = binop (index, smem[warp - 1 ]) ;
9090 }
9191
9292 *out = index;
@@ -98,9 +98,9 @@ __device__ void inclusiveBinaryPrefixSum(T* smem, bool in, T* out) {
9898
9999// Exclusive prefix sum for binary vars using intra-warp voting +
100100// shared memory
101- template <typename T, bool KillWARDependency>
102- __device__ void exclusiveBinaryPrefixSum (T* smem, bool in, T* out, T* carry) {
103- inclusiveBinaryPrefixSum <T, false >(smem, in, out);
101+ template <typename T, bool KillWARDependency, class BinaryFunction >
102+ __device__ void exclusiveBinaryPrefixScan (T* smem, bool in, T* out, T* carry, BinaryFunction binop ) {
103+ inclusiveBinaryPrefixScan <T, false , BinaryFunction >(smem, in, out, binop );
104104
105105 // Inclusive to exclusive
106106 *out -= (T) in;
0 commit comments