
Commit 8444e16

t-vi authored and facebook-github-bot committed
Only accept contiguous tensors in TopK for CUDA (#9441)
Summary: Fixes: #9421. I don't think it is easy to deal with non-contiguous arrays in CUDA topk, so I'm adding a check. The argument number is a bit confusing when it shows up in PyTorch, but it is consistent with the other checks. (Not sure whether it would make sense to eliminate argument numbers from the TH/THC error messages, given that they're probably off more than once...) Do we need a test that it indeed refuses non-contiguous input?

Pull Request resolved: #9441
Reviewed By: soumith
Differential Revision: D8850719
Pulled By: ezyang
fbshipit-source-id: d50561bb37ed50ab97aeaf54d8e3fc6c765bdc7c
1 parent 8814648 commit 8444e16
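
For context (not part of the commit), a minimal Python sketch of the behavior this change enables; it mirrors the test added below and assumes a CUDA-enabled PyTorch build that includes this fix:

import torch

# A step-2 slice of a 1-D CUDA tensor is non-contiguous; this is the case
# reported in #9421 that the CUDA topk path previously mishandled.
t = torch.randn(20, device="cuda")[::2]

top_strided, idx_strided = t.topk(5)
top_contig, idx_contig = t.contiguous().topk(5)

# With the fix, topk works on a contiguous copy internally, so the strided
# and explicitly-contiguous calls agree.
assert torch.equal(top_strided, top_contig)
assert torch.equal(idx_strided, idx_contig)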

File tree

aten/src/THC/generic/THCTensorTopK.cu
test/test_torch.py

2 files changed: +17 -5 lines


aten/src/THC/generic/THCTensorTopK.cu

Lines changed: 9 additions & 5 deletions

@@ -5,21 +5,23 @@
 THC_API void THCTensor_(topk)(THCState* state,
                               THCTensor *topK,
                               THCudaLongTensor *indices,
-                              THCTensor *input,
+                              THCTensor *input_,
                               int64_t k, int dim, int dir, int sorted) {
-  THAssert(topK != NULL && indices != NULL && input != NULL);
-  THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, topK, indices, input));
+  THAssert(topK != NULL && indices != NULL && input_ != NULL);
+  THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, topK, indices, input_));
   THArgCheck(THCTensor_(_nDimension)(state, topK) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
   int64_t dims = THCudaLongTensor__nDimension(state, indices);
   THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING);
-  int numDims = THCTensor_(_nDimension)(state, input);
+  int numDims = THCTensor_(_nDimension)(state, input_);
   THArgCheck(numDims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);

   THArgCheck(dim >= 0 && dim < numDims, 6, "dim not in range");

-  int64_t sliceSize = THCTensor_(size)(state, input, dim);
+  int64_t sliceSize = THCTensor_(size)(state, input_, dim);
   THArgCheck(k > 0 && k <= sliceSize, 5, "k not in range for dimension");

+  THCTensor *input = THCTensor_(newContiguous)(state, input_);
+
   // Build the output size, which is the dim being selected set to
   // size k
   THLongStorage* topKSize = THCTensor_(newSizeOf)(state, input);

@@ -155,6 +157,8 @@ THC_API void THCTensor_(topk)(THCState* state,
     }
   }

+  THCudaLongTensor_free(state, input);
+
   THCudaCheck(cudaGetLastError());
 }

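A small Python illustration (also not from the commit) of why a sliced tensor is non-contiguous and what the newly added THCTensor_(newContiguous) call does conceptually; it runs on CPU, where stride semantics match CUDA:

import torch

t = torch.randn(20)       # CPU is enough to show the layout the kernel sees
s = t[::2]                # a strided view over t's storage (every second element)
print(s.is_contiguous())  # False: the case the CUDA kernel could not handle
print(s.stride())         # (2,)

c = s.contiguous()        # roughly what newContiguous does: a densely packed copy
print(c.is_contiguous())  # True
print(c.stride())         # (1,)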
test/test_torch.py

Lines changed: 8 additions & 0 deletions

@@ -3194,6 +3194,14 @@ def test_topk_arguments(self):
         # Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1)
         self.assertRaises(TypeError, lambda: q.topk(4, True))

+    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
+    def test_topk_noncontiguous_gpu(self):
+        t = torch.randn(20, device="cuda")[::2]
+        top1, idx1 = t.topk(5)
+        top2, idx2 = t.contiguous().topk(5)
+        self.assertEqual(top1, top2)
+        self.assertEqual(idx1, idx2)
+
     def test_kthvalue(self):
         SIZE = 50
         x = torch.rand(SIZE, SIZE, SIZE)
