Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions aten/src/THC/generic/THCTensorTopK.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,23 @@
THC_API void THCTensor_(topk)(THCState* state,
THCTensor *topK,
THCudaLongTensor *indices,
THCTensor *input,
THCTensor *input_,
int64_t k, int dim, int dir, int sorted) {
THAssert(topK != NULL && indices != NULL && input != NULL);
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, topK, indices, input));
THAssert(topK != NULL && indices != NULL && input_ != NULL);
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, topK, indices, input_));
THArgCheck(THCTensor_(_nDimension)(state, topK) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
int64_t dims = THCudaLongTensor__nDimension(state, indices);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING);
int numDims = THCTensor_(_nDimension)(state, input);
int numDims = THCTensor_(_nDimension)(state, input_);
THArgCheck(numDims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);

THArgCheck(dim >= 0 && dim < numDims, 6, "dim not in range");

int64_t sliceSize = THCTensor_(size)(state, input, dim);
int64_t sliceSize = THCTensor_(size)(state, input_, dim);
THArgCheck(k > 0 && k <= sliceSize, 5, "k not in range for dimension");

THCTensor *input = THCTensor_(newContiguous)(state, input_);

// Build the output size, which is the dim being selected set to
// size k
THLongStorage* topKSize = THCTensor_(newSizeOf)(state, input);
Expand Down Expand Up @@ -155,6 +157,8 @@ THC_API void THCTensor_(topk)(THCState* state,
}
}

THCudaLongTensor_free(state, input);

THCudaCheck(cudaGetLastError());
}

Expand Down
8 changes: 8 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3194,6 +3194,14 @@ def test_topk_arguments(self):
# Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1)
self.assertRaises(TypeError, lambda: q.topk(4, True))

@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
def test_topk_noncontiguous_gpu(self):
t = torch.randn(20, device="cuda")[::2]
top1, idx1 = t.topk(5)
top2, idx2 = t.contiguous().topk(5)
self.assertEqual(top1, top2)
self.assertEqual(idx1, idx2)

def test_kthvalue(self):
SIZE = 50
x = torch.rand(SIZE, SIZE, SIZE)
Expand Down