Skip to content

Commit c93c884

Browse files
goelhardiksoumith
authored andcommitted
Add negative dimension to transpose and tests (#792)
1 parent 490c15f commit c93c884

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

test/test_torch.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import tempfile
88
import unittest
99
import warnings
10-
from itertools import product
10+
from itertools import product, combinations
1111
from common import TestCase, iter_indices, TEST_NUMPY, run_tests, download_file, skipIfNoLapack
1212

1313
if TEST_NUMPY:
@@ -2987,6 +2987,23 @@ def test_Size(self):
29872987
self.assertIsInstance(x[:-1], torch.Size)
29882988
self.assertIsInstance(x + x, torch.Size)
29892989

2990+
def test_transpose_neg(self):
2991+
x = torch.randn(10, 20, 30)
2992+
ndim = 3
2993+
2994+
for i, j in combinations(range(ndim), 2):
2995+
a = x.transpose(i, j)
2996+
b = x.transpose(i - ndim, j - ndim)
2997+
self.assertEqual(a, b)
2998+
2999+
a = torch.transpose(x, i, j)
3000+
b = torch.transpose(x, i - ndim, j - ndim)
3001+
self.assertEqual(a, b)
3002+
3003+
a = x.clone()
3004+
x.transpose_(i, j)
3005+
x.transpose_(i - ndim, j - ndim)
3006+
self.assertEqual(a, x)
29903007

29913008
if __name__ == '__main__':
29923009
run_tests()

torch/csrc/generic/methods/Tensor.cwrap

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,21 +342,41 @@ PyObject * THPTensor_(stride)(PyObject *self, PyObject *args, PyObject *kwargs)
342342
- THBoolTensor* mask
343343
]]
344344

345+
#if IS_CUDA
346+
THTensor* THTensor_(transpose_neg)(THCState* state, THTensor *self, THTensor *src, int dim0, int dim1)
347+
#else
348+
THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int dim1)
349+
#endif
350+
{
351+
int ndim = self->nDimension;
352+
if (dim0 < 0)
353+
dim0 += ndim;
354+
if (dim1 < 0)
355+
dim1 += ndim;
356+
if (src != NULL) {
357+
THTensor_(transpose)(LIBRARY_STATE self, src, dim0, dim1);
358+
return NULL;
359+
} else {
360+
return THTensor_(newTranspose)(LIBRARY_STATE self, dim0, dim1);
361+
}
362+
}
363+
345364
[[
346365
name: transpose
347366
with_stateless: True
348-
cname: newTranspose
367+
cname: transpose_neg
349368
cpu_half: True
350369
return: THTensor*
351370
arguments:
352371
- THTensor* self
372+
- CONSTANT NULL
353373
- long dim0
354374
- long dim1
355375
]]
356376

357377
[[
358378
name: transpose_
359-
cname: transpose
379+
cname: transpose_neg
360380
cpu_half: True
361381
return: self
362382
arguments:

0 commit comments

Comments
 (0)