@@ -57,6 +57,15 @@ void SpatialSoftMax_getLaunchSizes(
5757 grid = SpatialSoftMax_getGridSize (block, max_active_blocks, outer_size, dim_size, inner_size);
5858}
5959
60+ inline dim3 SoftMax_getBlockSize (int ILP, uint64_t dim_size) {
61+ uint64_t block_size = 1 ;
62+ uint64_t max_block_size = std::min (dim_size / ILP, static_cast <uint64_t >(1024 ));
63+ while (block_size < max_block_size) block_size *= 2 ;
64+ // Launch at least a single warp - the kernel assumes that.
65+ block_size = std::max (block_size, static_cast <uint64_t >(32 ));
66+ return dim3 (block_size);
67+ }
68+
6069template <typename T>
6170struct Add {
6271 __device__ __forceinline__ T operator ()(T a, T b) const {
@@ -392,13 +401,14 @@ void HostSoftMaxForward(
392401 uint64_t outer_size, uint64_t dim_size, uint64_t inner_size,
393402 int dim)
394403{
395- // This kernel spawns a block of 1024 threads per each element in the batch.
404+ // This kernel spawns a block per each element in the batch.
396405 // XXX: it assumes that inner_size == 1
397- if (inner_size == 1 && dim_size >= 64 ) {
406+ if (inner_size == 1 ) {
407+ const int ILP = 2 ;
398408 dim3 grid (outer_size);
399- dim3 block ( 1024 );
409+ dim3 block = SoftMax_getBlockSize (ILP, dim_size );
400410
401- cunn_SoftMaxForward<2 , T, AccumT, Epilogue>
411+ cunn_SoftMaxForward<ILP , T, AccumT, Epilogue>
402412 <<<grid, block, block.x * sizeof (AccumT), THCState_getCurrentStream(state)>>> (
403413 output, input, dim_size
404414 );
@@ -429,11 +439,12 @@ void HostSoftMaxBackward(
429439 int dim)
430440{
431441 // See descriptions of kernels above.
432- if (inner_size == 1 && dim_size >= 64 ) {
442+ if (inner_size == 1 ) {
443+ const int ILP = 2 ;
433444 dim3 grid (outer_size);
434- dim3 block ( 1024 );
445+ dim3 block = SoftMax_getBlockSize (ILP, dim_size );
435446
436- cunn_SoftMaxBackward<2 , T, AccumT, Epilogue>
447+ cunn_SoftMaxBackward<ILP , T, AccumT, Epilogue>
437448 <<<grid, block, block.x * sizeof (AccumT), THCState_getCurrentStream(state)>>> (
438449 gradInput, output, gradOutput, dim_size
439450 );
0 commit comments