Skip to content

Commit 8b61ee5

Browse files
committed
Merge commit 'aec182ae72d51dad0f46cdfe7ff9a41380d7da35'
2 parents 76ca3eb + aec182a commit 8b61ee5

File tree

5 files changed

+53
-32
lines changed

5 files changed

+53
-32
lines changed

torch/lib/THC/THCReduceAll.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ bool THC_reduceAll(THCState* state,
331331
// If our destination is not on the device, copy the value back to
332332
// the host (synchronous!)
333333
if (!outOnDevice) {
334-
cudaMemcpy(out, devOut, sizeof(AccT), cudaMemcpyDeviceToHost);
334+
THCudaCheck(cudaMemcpy(out, devOut, sizeof(AccT), cudaMemcpyDeviceToHost));
335335
}
336336

337337
if (freeDevOut) {

torch/lib/THC/THCScanUtils.cuh

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
// Collection of in-kernel scan / prefix sum utilities
77

88
// Inclusive prefix sum using shared memory
9-
template <typename T, bool KillWARDependency>
10-
__device__ void inclusivePrefixSum(T* smem, T in, T* out) {
9+
template <typename T, bool KillWARDependency, class BinaryFunction>
10+
__device__ void inclusivePrefixScan(T* smem, T in, T* out, BinaryFunction binop) {
1111
// FIXME: this is a slow, simple implementation; need up/down sweep,
1212
// prevent smem conflicts
1313
smem[threadIdx.x] = in;
@@ -18,7 +18,7 @@ __device__ void inclusivePrefixSum(T* smem, T in, T* out) {
1818
T val = 0;
1919

2020
if (threadIdx.x >= offset) {
21-
val = smem[threadIdx.x - offset] + smem[threadIdx.x];
21+
val = binop(smem[threadIdx.x - offset], smem[threadIdx.x]);
2222
}
2323

2424
__syncthreads();
@@ -38,11 +38,11 @@ __device__ void inclusivePrefixSum(T* smem, T in, T* out) {
3838
}
3939

4040
// Exclusive prefix sum using shared memory
41-
template <typename T, bool KillWARDependency>
42-
__device__ void exclusivePrefixSum(T* smem, T in, T* out, T* carry) {
41+
template <typename T, bool KillWARDependency, class BinaryFunction>
42+
__device__ void exclusivePrefixScan(T* smem, T in, T* out, T* carry, BinaryFunction binop) {
4343
// FIXME: crappy implementation
4444
// We kill write-after-read dependencies separately below, hence the `false`
45-
inclusivePrefixSum<T, false>(smem, in, out);
45+
inclusivePrefixScan<T, false, BinaryFunction>(smem, in, out, binop);
4646

4747
*out -= in;
4848
*carry = smem[blockDim.x - 1];
@@ -55,8 +55,8 @@ __device__ void exclusivePrefixSum(T* smem, T in, T* out, T* carry) {
5555

5656
// Inclusive prefix sum for binary vars using intra-warp voting +
5757
// shared memory
58-
template <typename T, bool KillWARDependency>
59-
__device__ void inclusiveBinaryPrefixSum(T* smem, bool in, T* out) {
58+
template <typename T, bool KillWARDependency, class BinaryFunction>
59+
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
6060
// Within-warp, we use warp voting.
6161
T vote = __ballot(in);
6262
T index = __popc(getLaneMaskLe() & vote);
@@ -77,16 +77,16 @@ __device__ void inclusiveBinaryPrefixSum(T* smem, bool in, T* out) {
7777
int current = 0;
7878
for (int i = 0; i < blockDim.x / 32; ++i) {
7979
T v = smem[i];
80-
smem[i] += current;
81-
current += v;
80+
smem[i] = binop(smem[i], current);
81+
current = binop(current, v);
8282
}
8383
}
8484

8585
__syncthreads();
8686

8787
// load the carry from the preceding warp
8888
if (warp >= 1) {
89-
index += smem[warp - 1];
89+
index = binop(index, smem[warp - 1]);
9090
}
9191

9292
*out = index;
@@ -98,9 +98,9 @@ __device__ void inclusiveBinaryPrefixSum(T* smem, bool in, T* out) {
9898

9999
// Exclusive prefix sum for binary vars using intra-warp voting +
100100
// shared memory
101-
template <typename T, bool KillWARDependency>
102-
__device__ void exclusiveBinaryPrefixSum(T* smem, bool in, T* out, T* carry) {
103-
inclusiveBinaryPrefixSum<T, false>(smem, in, out);
101+
template <typename T, bool KillWARDependency, class BinaryFunction>
102+
__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
103+
inclusiveBinaryPrefixScan<T, false, BinaryFunction>(smem, in, out, binop);
104104

105105
// Inclusive to exclusive
106106
*out -= (T) in;

torch/lib/THC/THCTensorTopK.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "THCAsmUtils.cuh"
66
#include "THCScanUtils.cuh"
77
#include "THCTensorTypeUtils.cuh"
8+
#include "THCTensorMathReduce.cuh"
89
#include <algorithm> // for std::min
910

1011
#if CUDA_VERSION >= 7000
@@ -322,7 +323,7 @@ __global__ void gatherTopK(TensorInfo<float, IndexType> input,
322323

323324
int index;
324325
int carry;
325-
exclusiveBinaryPrefixSum<int, true>(smem, hasTopK, &index, &carry);
326+
exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());
326327

327328
if (hasTopK) {
328329
int writeIndex = writeIndexStart + index;
@@ -354,7 +355,7 @@ __global__ void gatherTopK(TensorInfo<float, IndexType> input,
354355

355356
int index;
356357
int carry;
357-
exclusiveBinaryPrefixSum<int, true>(smem, hasTopK, &index, &carry);
358+
exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());
358359

359360
if (hasTopK && index < topKRemaining) {
360361
int writeIndex = writeIndexStart + index;

torch/lib/THC/generic/THCTensorMath.cu

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
8787
// loop below will overwrite the value
8888
int maxDim = dimension + 1;
8989

90-
// ldimension is the actual dimension we cat along (minus 1, for 0-based indexing)
91-
int ldimension = dimension;
90+
// cat_dimension is the actual dimension we cat along
91+
int cat_dimension = dimension;
9292

9393
for (i = 0; i < numInputs; i++)
9494
{
@@ -100,13 +100,13 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
100100
// In the event that the user specified -1 as the concat dimension, then
101101
// we want to pick the maxDim as dimension to cat along (and thus maxDim - 1 as the
102102
// value due to 0-based indexing). If the maxDim is // 0 (i.e. we are catting all
103-
// empty tensors), then we set ldimension to be 0
103+
// empty tensors), then we set cat_dimension to be 0
104104
if (dimension + TH_INDEX_BASE == -1) {
105-
ldimension = maxDim ? (maxDim - 1) : 0;
105+
cat_dimension = maxDim ? (maxDim - 1) : 0;
106106
}
107107

108108
THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
109-
THArgCheck(ldimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);
109+
THArgCheck(cat_dimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);
110110

111111
size = THLongStorage_newWithSize(maxDim);
112112
for(i = 0; i < maxDim; i++)
@@ -115,7 +115,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
115115
long dimSize = i < THCTensor_(nDimension)(state, inputs[0])
116116
? THCTensor_(size)(state, inputs[0], i)
117117
: THMin(THCTensor_(nDimension)(state, inputs[0]), 1);
118-
if (i == ldimension)
118+
if (i == cat_dimension)
119119
{
120120
for (j = 1; j < numInputs; j++)
121121
{
@@ -203,15 +203,15 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
203203

204204
// Template Declarations for dim = 1, 2, 3, 4
205205
#define HANDLE_CASE(DIMS) \
206-
CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock>>>(data, d_inputs, param, ldimension, param.outputStride[dimension]);
206+
CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]);
207207

208208
// Now we loop
209209
offset = 0;
210210
for (i = 0; i < numInputs; i += CAT_ARRAY_BATCH_SIZE) {
211211
cohortMax = 0;
212212
for (j = 0; j < CAT_ARRAY_BATCH_SIZE && (i+j) < numInputs; ++j) {
213-
long dimSize = ldimension < THCTensor_(nDimension)(state, inputs[i+j])
214-
? THCTensor_(size)(state, inputs[i+j], ldimension)
213+
long dimSize = cat_dimension < THCTensor_(nDimension)(state, inputs[i+j])
214+
? THCTensor_(size)(state, inputs[i+j], cat_dimension)
215215
: 1;
216216

217217
stackInputs[j].input = THCTensor_(data)(state, inputs[i+j]);
@@ -223,7 +223,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
223223
// update offset
224224
offset += dimSize;
225225
}
226-
cudaMemcpy(d_inputs, stackInputs, j * sizeof(CatArrInputTensor<real, unsigned int>), cudaMemcpyHostToDevice);
226+
THCudaCheck(cudaMemcpy(d_inputs, stackInputs, j * sizeof(CatArrInputTensor<real, unsigned int>), cudaMemcpyHostToDevice));
227227

228228
// Next, let's consider how we set our kernel launch parameters.
229229
// We borrow from THCApply, which the kernel's internal indexing
@@ -267,12 +267,12 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
267267
// No reason to copy when input is empty
268268
if (!THCTensor_(nDimension)(state, inputs[j])) continue;
269269

270-
long dimSize = ldimension < THCTensor_(nDimension)(state, inputs[j])
271-
? THCTensor_(size)(state, inputs[j], ldimension)
270+
long dimSize = cat_dimension < THCTensor_(nDimension)(state, inputs[j])
271+
? THCTensor_(size)(state, inputs[j], cat_dimension)
272272
: 1;
273273

274274
THCTensor *nt = THCTensor_(newWithTensor)(state, result);
275-
THCTensor_(narrow)(state, nt, NULL, ldimension, offset, dimSize);
275+
THCTensor_(narrow)(state, nt, NULL, cat_dimension, offset, dimSize);
276276
THCTensor_(copy)(state, nt, inputs[j]);
277277
THCTensor_(free)(state, nt);
278278
offset += dimSize;

torch/lib/THC/generic/THCTensorMathBlas.cu

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ __global__ void createBatchGemmBuffer(const real** buffer, real* data,
430430
THC_API void
431431
THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
432432
real alpha, THCTensor *batch1, THCTensor *batch2) {
433-
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
433+
#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
434434
THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, result, t, batch1, batch2));
435435
THArgCheck(THCTensor_(nDimension)(state, t) == 3, 4, "expected 3D tensor");
436436
THArgCheck(THCTensor_(nDimension)(state, batch1) == 3, 6, "expected 3D tensor");
@@ -522,8 +522,10 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
522522
ldb = batch2_->stride[1];
523523
}
524524

525-
// Compute pointers to matrices in each batch.
526525
long num_batches = result_->size[0];
526+
527+
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
528+
// Compute pointers to matrices in each batch.
527529
size_t matrices_size = num_batches * sizeof(real*);
528530

529531
// Copy pointers to device.
@@ -580,6 +582,24 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
580582
THCudaFree(state, d_matrices2);
581583
THCudaFree(state, d_result_matrices);
582584

585+
#elif defined(THC_REAL_IS_HALF)
586+
// Currently no HgemmBatched in Cublas
587+
for (long i = 0; i < num_batches; ++i) {
588+
THCudaBlas_Hgemm(
589+
state,
590+
transpose_batch1,
591+
transpose_batch2,
592+
result_->size[transpose_result ? 2 : 1],
593+
result_->size[transpose_result ? 1 : 2],
594+
batch1_->size[transpose_result ? 1 : 2],
595+
alpha,
596+
THCTensor_(data)(state, batch1_) + i * batch1_->stride[0], lda,
597+
THCTensor_(data)(state, batch2_) + i * batch2_->stride[0], ldb,
598+
beta,
599+
THCTensor_(data)(state, result_) + i * result_->stride[0], ldc);
600+
}
601+
#endif
602+
583603
if (batch1_ != batch1) {
584604
THCTensor_(free)(state, batch1_);
585605
}

0 commit comments

Comments
 (0)