Skip to content

Commit 17f94d2

Browse files
albanD authored and soumith committed
Fix topk work size computation (#5053)
* fix grid computation for topk kernel
* backslash alignment, no change in code
1 parent ac0b41e commit 17f94d2

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

torch/lib/THC/generic/THCTensorTopK.cu

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ THC_API void THCTensor_(topk)(THCState* state,
2929
THLongStorage_free(topKSize);
3030

3131
#define RUN_K(INDEX_T, DIM, DIR) \
32-
gatherTopK<real, INDEX_T, DIM, DIR> \
32+
gatherTopK<real, INDEX_T, DIM, DIR> \
3333
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
3434
inputInfo, \
3535
sliceSize, \
@@ -63,10 +63,10 @@ THC_API void THCTensor_(topk)(THCState* state,
6363
}
6464

6565
#define RUN_T(INDEX_T) \
66-
TensorInfo<real, INDEX_T> inputInfo = \
67-
getTensorInfo<THCTensor, INDEX_T>(state, input); \
68-
TensorInfo<real, INDEX_T> topKInfo = \
69-
getTensorInfo<THCTensor, INDEX_T>(state, topK); \
66+
TensorInfo<real, INDEX_T> inputInfo = \
67+
getTensorInfo<THCTensor, INDEX_T>(state, input); \
68+
TensorInfo<real, INDEX_T> topKInfo = \
69+
getTensorInfo<THCTensor, INDEX_T>(state, topK); \
7070
TensorInfo<int64_t, INDEX_T> indicesInfo = \
7171
getTensorInfo<THCudaLongTensor, INDEX_T>(state, indices); \
7272
\
@@ -82,9 +82,11 @@ THC_API void THCTensor_(topk)(THCState* state,
8282
int collapseIndicesDim = indicesInfo.collapseDims(dim); \
8383
\
8484
int64_t inputSlices = 1; \
85-
int64_t topKSlices = 1; \
86-
for (int i = 0; i < numDims; ++i) { \
85+
for (int i = 0; i < inputInfo.dims; ++i) { \
8786
inputSlices *= inputInfo.sizes[i]; \
87+
} \
88+
int64_t topKSlices = 1; \
89+
for (int i = 0; i < topKInfo.dims; ++i) { \
8890
topKSlices *= topKInfo.sizes[i]; \
8991
} \
9092
\

0 commit comments

Comments
 (0)