Skip to content

Commit 4e549e9

Browse files
albanDsoumith
authored andcommitted
fix triu and tril for zero-strided inputs on gpu (#4962)
1 parent 00f9da7 commit 4e549e9

File tree

3 files changed

+43
-30
lines changed

3 files changed

+43
-30
lines changed

test/test_cuda.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ def medium_2d(t):
9797
return make_tensor(t, M, M)
9898

9999

100+
def medium_2d_expanded(t):
101+
return t(1).expand(M, M)
102+
103+
100104
def medium_2d_scaled(t, scale=10):
101105
return make_tensor(t, M, M).mul(scale)
102106

@@ -143,6 +147,13 @@ def tmp(t):
143147
return t(*sizes).copy_(torch.randn(*sizes))
144148
return tmp
145149

150+
# Content of each tuple:
151+
# - function name
152+
# - constructor for the tensor, signature: fn(tensor_type) -> tensor
153+
# - constructor for the arguments, signature: fn(tensor_type) -> list
154+
# - postfix name for the test (must be unique for a given function) (default='')
155+
# - tensor types to use (default=types)
156+
# - disable inplace test, if set to True, no inplace test will be done (default=False)
146157
tests = [
147158
('add', small_3d, lambda t: [number(3.14, 3, t)]),
148159
('add', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
@@ -295,9 +306,11 @@ def tmp(t):
295306
('topk', small_3d_unique, lambda t: [2, 1, True, True], 'dim_desc_sort'),
296307
('trace', medium_2d, lambda t: [],),
297308
('tril', medium_2d, lambda t: [],),
309+
('tril', medium_2d_expanded, lambda t: [], 'zero_stride', types, True),
298310
('tril', medium_2d, lambda t: [2], 'positive'),
299311
('tril', medium_2d, lambda t: [-2], 'negative'),
300312
('triu', medium_2d, lambda t: [],),
313+
('triu', medium_2d_expanded, lambda t: [], 'zero_stride', types, True),
301314
('triu', medium_2d, lambda t: [2], 'positive'),
302315
('triu', medium_2d, lambda t: [-2], 'negative'),
303316
('unsqueeze', new_t(2, 3, 4), lambda t: [2],),
@@ -1106,18 +1119,27 @@ def test_nvtx(self):
11061119
for t in types:
11071120
tensor = t()
11081121
gpu_tensor = get_gpu_type(t)()
1122+
1123+
# Default values
1124+
desc = ''
1125+
type_subset = types
1126+
no_inplace = False
11091127
if len(decl) == 3:
11101128
name, constr, arg_constr = decl
1111-
desc = ''
11121129
elif len(decl) == 4:
11131130
name, constr, arg_constr, desc = decl
11141131
elif len(decl) == 5:
11151132
name, constr, arg_constr, desc, type_subset = decl
1116-
if t not in type_subset:
1117-
continue
1133+
elif len(decl) == 6:
1134+
name, constr, arg_constr, desc, type_subset, no_inplace = decl
1135+
1136+
if t not in type_subset:
1137+
continue
11181138

11191139
precision = custom_precision.get(name, TestCuda.precision)
11201140
for inplace in (True, False):
1141+
if inplace and no_inplace:
1142+
continue
11211143
if inplace:
11221144
name_inner = name + '_'
11231145
else:

torch/lib/THC/THCTensorMathPairwise.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,8 @@ struct TensorTriOp {
375375
TensorTriOp(T *start_, int64_t stride0_, int64_t stride1_, int64_t k_)
376376
: start(start_), stride0(stride0_), stride1(stride1_), k(k_) {}
377377

378-
__device__ __forceinline__ int mask(T *in) {
379-
ptrdiff_t n = in - start;
378+
__device__ __forceinline__ int mask(T *out) {
379+
ptrdiff_t n = out - start;
380380
int64_t row, col;
381381
if (stride0 > stride1)
382382
{
@@ -393,7 +393,7 @@ struct TensorTriOp {
393393
}
394394

395395
__device__ __forceinline__ void operator()(T* out, T* in) {
396-
*out = mask(in) ? *in : ScalarConvert<int, T>::to(0);
396+
*out = mask(out) ? *in : ScalarConvert<int, T>::to(0);
397397
}
398398

399399
__device__ __forceinline__ void operator()(T* v) {

torch/lib/THC/generic/THCTensorMathPairwise.cu

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -193,31 +193,27 @@ void THCTensor_(tril)(THCState *state, THCTensor *self_, THCTensor *src_, int64_
193193
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
194194
THArgCheck(src_->nDimension == 2, 1, "expected a matrix");
195195

196-
THCTensor *src = src_;
197-
if (self_ == src_)
198-
src = THCTensor_(newContiguous)(state, src_);
196+
if (self_ != src_)
197+
THCTensor_(resizeAs)(state, self_, src_);
199198

200-
int64_t stride0 = src->stride[0];
201-
int64_t stride1 = src->stride[1];
202-
real *start = THCTensor_(data)(state, src);
199+
int64_t stride0 = self_->stride[0];
200+
int64_t stride1 = self_->stride[1];
201+
real *start = THCTensor_(data)(state, self_);
203202

204203
TensorTriOp<real, 0> op(start, stride0, stride1, k);
205204

206205
if (self_ == src_) {
207-
if (!THC_pointwiseApply1(state, src, op)) {
206+
if (!THC_pointwiseApply1(state, src_, op)) {
208207
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
209208
}
210209
} else {
211-
THCTensor_(resizeAs)(state, self_, src);
210+
THCTensor_(resizeAs)(state, self_, src_);
212211

213-
if (!THC_pointwiseApply2(state, self_, src, op)) {
212+
if (!THC_pointwiseApply2(state, self_, src_, op)) {
214213
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
215214
}
216215
}
217216

218-
if (self_ == src_)
219-
THCTensor_(freeCopyTo)(state, src, src_);
220-
221217
THCudaCheck(cudaGetLastError());
222218
}
223219

@@ -226,31 +222,26 @@ void THCTensor_(triu)(THCState *state, THCTensor *self_, THCTensor *src_, int64_
226222
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
227223
THArgCheck(src_->nDimension == 2, 1, "expected a matrix");
228224

229-
THCTensor *src = src_;
230-
if (self_ == src_)
231-
src = THCTensor_(newContiguous)(state, src_);
225+
if (self_ != src_)
226+
THCTensor_(resizeAs)(state, self_, src_);
232227

233-
int64_t stride0 = src->stride[0];
234-
int64_t stride1 = src->stride[1];
235-
real *start = THCTensor_(data)(state, src);
228+
int64_t stride0 = self_->stride[0];
229+
int64_t stride1 = self_->stride[1];
230+
real *start = THCTensor_(data)(state, self_);
236231

237232
TensorTriOp<real, 1> op(start, stride0, stride1, k);
238233

239234
if (self_ == src_) {
240-
if (!THC_pointwiseApply1(state, src, op)) {
235+
if (!THC_pointwiseApply1(state, src_, op)) {
241236
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
242237
}
243238
} else {
244-
THCTensor_(resizeAs)(state, self_, src);
245239

246-
if (!THC_pointwiseApply2(state, self_, src, op)) {
240+
if (!THC_pointwiseApply2(state, self_, src_, op)) {
247241
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
248242
}
249243
}
250244

251-
if (self_ == src_)
252-
THCTensor_(freeCopyTo)(state, src, src_);
253-
254245
THCudaCheck(cudaGetLastError());
255246
}
256247

0 commit comments

Comments
 (0)