@@ -29,7 +29,7 @@ THC_API void THCTensor_(topk)(THCState* state,
2929 THLongStorage_free (topKSize);
3030
3131#define RUN_K (INDEX_T, DIM, DIR ) \
32- gatherTopK<real, INDEX_T, DIM, DIR> \
32+ gatherTopK<real, INDEX_T, DIM, DIR> \
3333 <<<grid, block, 0 , THCState_getCurrentStream(state)>>> ( \
3434 inputInfo, \
3535 sliceSize, \
@@ -63,10 +63,10 @@ THC_API void THCTensor_(topk)(THCState* state,
6363 }
6464
6565#define RUN_T (INDEX_T ) \
66- TensorInfo<real, INDEX_T> inputInfo = \
67- getTensorInfo<THCTensor, INDEX_T>(state, input); \
68- TensorInfo<real, INDEX_T> topKInfo = \
69- getTensorInfo<THCTensor, INDEX_T>(state, topK); \
66+ TensorInfo<real, INDEX_T> inputInfo = \
67+ getTensorInfo<THCTensor, INDEX_T>(state, input); \
68+ TensorInfo<real, INDEX_T> topKInfo = \
69+ getTensorInfo<THCTensor, INDEX_T>(state, topK); \
7070 TensorInfo<int64_t , INDEX_T> indicesInfo = \
7171 getTensorInfo<THCudaLongTensor, INDEX_T>(state, indices); \
7272 \
@@ -82,9 +82,11 @@ THC_API void THCTensor_(topk)(THCState* state,
8282 int collapseIndicesDim = indicesInfo.collapseDims (dim); \
8383 \
8484 int64_t inputSlices = 1 ; \
85- int64_t topKSlices = 1 ; \
86- for (int i = 0 ; i < numDims; ++i) { \
85+ for (int i = 0 ; i < inputInfo.dims ; ++i) { \
8786 inputSlices *= inputInfo.sizes [i]; \
87+ } \
88+ int64_t topKSlices = 1 ; \
89+ for (int i = 0 ; i < topKInfo.dims ; ++i) { \
8890 topKSlices *= topKInfo.sizes [i]; \
8991 } \
9092 \
0 commit comments