Skip to content

Commit 30bd9e6

Browse files
apaszkesoumith
authored andcommitted
Improve CUDA softmax performance
1 parent 4e549e9 commit 30bd9e6

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

torch/lib/THCUNN/SoftMaxCommon.cuh

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@ void SpatialSoftMax_getLaunchSizes(
5757
grid = SpatialSoftMax_getGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
5858
}
5959

60+
inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
61+
uint64_t block_size = 1;
62+
uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(1024));
63+
while (block_size < max_block_size) block_size *= 2;
64+
// Launch at least a single warp - the kernel assumes that.
65+
block_size = std::max(block_size, static_cast<uint64_t>(32));
66+
return dim3(block_size);
67+
}
68+
6069
template<typename T>
6170
struct Add {
6271
__device__ __forceinline__ T operator()(T a, T b) const {
@@ -392,13 +401,14 @@ void HostSoftMaxForward(
392401
uint64_t outer_size, uint64_t dim_size, uint64_t inner_size,
393402
int dim)
394403
{
395-
// This kernel spawns a block of 1024 threads per each element in the batch.
404+
// This kernel spawns a block per each element in the batch.
396405
// XXX: it assumes that inner_size == 1
397-
if (inner_size == 1 && dim_size >= 64) {
406+
if (inner_size == 1) {
407+
const int ILP = 2;
398408
dim3 grid(outer_size);
399-
dim3 block(1024);
409+
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
400410

401-
cunn_SoftMaxForward<2, T, AccumT, Epilogue>
411+
cunn_SoftMaxForward<ILP, T, AccumT, Epilogue>
402412
<<<grid, block, block.x * sizeof(AccumT), THCState_getCurrentStream(state)>>>(
403413
output, input, dim_size
404414
);
@@ -429,11 +439,12 @@ void HostSoftMaxBackward(
429439
int dim)
430440
{
431441
// See descriptions of kernels above.
432-
if (inner_size == 1 && dim_size >= 64) {
442+
if (inner_size == 1) {
443+
const int ILP = 2;
433444
dim3 grid(outer_size);
434-
dim3 block(1024);
445+
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
435446

436-
cunn_SoftMaxBackward<2, T, AccumT, Epilogue>
447+
cunn_SoftMaxBackward<ILP, T, AccumT, Epilogue>
437448
<<<grid, block, block.x * sizeof(AccumT), THCState_getCurrentStream(state)>>>(
438449
gradInput, output, gradOutput, dim_size
439450
);

0 commit comments

Comments
 (0)