Skip to content

Commit c8bd62c

Browse files
t-vi authored and facebook-github-bot committed
Only accept contiguous tensors in TopK for cuda (#9441)
Summary: Fixes: #9421 I don't think it is easy to deal with non-contiguous array in cuda topk, so I'm adding a check. The argument number is a bit confusing when it shows in PyTorch but it is consistent with the other checks. (Not sure whether it would make sense to eliminate argument numbers from the error TH/THC error messages given that they're probably off more than once...) Do we need a test that it indeed refuses non-contiguous? Pull Request resolved: pytorch/pytorch#9441 Reviewed By: soumith Differential Revision: D8850719 Pulled By: ezyang fbshipit-source-id: d50561bb37ed50ab97aeaf54d8e3fc6c765bdc7c
1 parent 5b446fe commit c8bd62c

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

aten/src/THC/generic/THCTensorTopK.cu

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,23 @@
55
THC_API void THCTensor_(topk)(THCState* state,
66
THCTensor *topK,
77
THCudaLongTensor *indices,
8-
THCTensor *input,
8+
THCTensor *input_,
99
int64_t k, int dim, int dir, int sorted) {
10-
THAssert(topK != NULL && indices != NULL && input != NULL);
11-
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, topK, indices, input));
10+
THAssert(topK != NULL && indices != NULL && input_ != NULL);
11+
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, topK, indices, input_));
1212
THArgCheck(THCTensor_(_nDimension)(state, topK) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
1313
int64_t dims = THCudaLongTensor__nDimension(state, indices);
1414
THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING);
15-
int numDims = THCTensor_(_nDimension)(state, input);
15+
int numDims = THCTensor_(_nDimension)(state, input_);
1616
THArgCheck(numDims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);
1717

1818
THArgCheck(dim >= 0 && dim < numDims, 6, "dim not in range");
1919

20-
int64_t sliceSize = THCTensor_(size)(state, input, dim);
20+
int64_t sliceSize = THCTensor_(size)(state, input_, dim);
2121
THArgCheck(k > 0 && k <= sliceSize, 5, "k not in range for dimension");
2222

23+
THCTensor *input = THCTensor_(newContiguous)(state, input_);
24+
2325
// Build the output size, which is the dim being selected set to
2426
// size k
2527
THLongStorage* topKSize = THCTensor_(newSizeOf)(state, input);
@@ -155,6 +157,8 @@ THC_API void THCTensor_(topk)(THCState* state,
155157
}
156158
}
157159

160+
THCudaLongTensor_free(state, input);
161+
158162
THCudaCheck(cudaGetLastError());
159163
}
160164

0 commit comments

Comments (0)