Skip to content

Commit 98775b6

Browse files
authored
Merge pull request #718 from killeent/templatize-scan
genericize PrefixSum --> PrefixScan via binary operator template parameter
2 parents dfca8df + b7cc2a5 commit 98775b6

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

THCScanUtils.cuh

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
// Collection of in-kernel scan / prefix sum utilities
77

88
// Inclusive prefix scan using shared memory; `binop` combines adjacent elements
9-
template <typename T, bool KillWARDependency>
10-
__device__ void inclusivePrefixSum(T* smem, T in, T* out) {
9+
template <typename T, bool KillWARDependency, class BinaryFunction>
10+
__device__ void inclusivePrefixScan(T* smem, T in, T* out, BinaryFunction binop) {
1111
// FIXME: this is a slow, simple implementation; need up/down sweep,
1212
// prevent smem conflicts
1313
smem[threadIdx.x] = in;
@@ -18,7 +18,7 @@ __device__ void inclusivePrefixSum(T* smem, T in, T* out) {
1818
T val = 0;
1919

2020
if (threadIdx.x >= offset) {
21-
val = smem[threadIdx.x - offset] + smem[threadIdx.x];
21+
val = binop(smem[threadIdx.x - offset], smem[threadIdx.x]);
2222
}
2323

2424
__syncthreads();
@@ -38,11 +38,11 @@ __device__ void inclusivePrefixSum(T* smem, T in, T* out) {
3838
}
3939

4040
// Exclusive prefix scan using shared memory (NOTE: the `*out -= in` step below is only valid when `binop` is addition)
41-
template <typename T, bool KillWARDependency>
42-
__device__ void exclusivePrefixSum(T* smem, T in, T* out, T* carry) {
41+
template <typename T, bool KillWARDependency, class BinaryFunction>
42+
__device__ void exclusivePrefixScan(T* smem, T in, T* out, T* carry, BinaryFunction binop) {
4343
// FIXME: crappy implementation
4444
// We kill write-after-read dependencies separately below, hence the `false`
45-
inclusivePrefixSum<T, false>(smem, in, out);
45+
inclusivePrefixScan<T, false, BinaryFunction>(smem, in, out, binop);
4646

4747
*out -= in;
4848
*carry = smem[blockDim.x - 1];
@@ -55,8 +55,8 @@ __device__ void exclusivePrefixSum(T* smem, T in, T* out, T* carry) {
5555

5656
// Inclusive prefix scan for binary vars using intra-warp voting +
5757
// shared memory
58-
template <typename T, bool KillWARDependency>
59-
__device__ void inclusiveBinaryPrefixSum(T* smem, bool in, T* out) {
58+
template <typename T, bool KillWARDependency, class BinaryFunction>
59+
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
6060
// Within-warp, we use warp voting.
6161
T vote = __ballot(in);
6262
T index = __popc(getLaneMaskLe() & vote);
@@ -77,16 +77,16 @@ __device__ void inclusiveBinaryPrefixSum(T* smem, bool in, T* out) {
7777
int current = 0;
7878
for (int i = 0; i < blockDim.x / 32; ++i) {
7979
T v = smem[i];
80-
smem[i] += current;
81-
current += v;
80+
smem[i] = binop(smem[i], current);
81+
current = binop(current, v);
8282
}
8383
}
8484

8585
__syncthreads();
8686

8787
// load the carry from the preceding warp
8888
if (warp >= 1) {
89-
index += smem[warp - 1];
89+
index = binop(index, smem[warp - 1]);
9090
}
9191

9292
*out = index;
@@ -98,9 +98,9 @@ __device__ void inclusiveBinaryPrefixSum(T* smem, bool in, T* out) {
9898

9999
// Exclusive prefix scan for binary vars using intra-warp voting +
100100
// shared memory
101-
template <typename T, bool KillWARDependency>
102-
__device__ void exclusiveBinaryPrefixSum(T* smem, bool in, T* out, T* carry) {
103-
inclusiveBinaryPrefixSum<T, false>(smem, in, out);
101+
template <typename T, bool KillWARDependency, class BinaryFunction>
102+
__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
103+
inclusiveBinaryPrefixScan<T, false, BinaryFunction>(smem, in, out, binop);
104104

105105
// Inclusive to exclusive
106106
*out -= (T) in;

THCTensorTopK.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "THCAsmUtils.cuh"
66
#include "THCScanUtils.cuh"
77
#include "THCTensorTypeUtils.cuh"
8+
#include "THCTensorMathReduce.cuh"
89
#include <algorithm> // for std::min
910

1011
#if CUDA_VERSION >= 7000
@@ -322,7 +323,7 @@ __global__ void gatherTopK(TensorInfo<float, IndexType> input,
322323

323324
int index;
324325
int carry;
325-
exclusiveBinaryPrefixSum<int, true>(smem, hasTopK, &index, &carry);
326+
exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());
326327

327328
if (hasTopK) {
328329
int writeIndex = writeIndexStart + index;
@@ -354,7 +355,7 @@ __global__ void gatherTopK(TensorInfo<float, IndexType> input,
354355

355356
int index;
356357
int carry;
357-
exclusiveBinaryPrefixSum<int, true>(smem, hasTopK, &index, &carry);
358+
exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());
358359

359360
if (hasTopK && index < topKRemaining) {
360361
int writeIndex = writeIndexStart + index;

0 commit comments

Comments
 (0)