
Commit 302adb7

weiyangfb authored and facebook-github-bot committed
added torch.rot90() to ATen (#8628)
Summary:
1. Fixes #6271
2. Implemented torch.rot90() following [numpy.rot90()](https://github.com/numpy/numpy/blob/6a58e25703cbecb6786faa09a04ae2ec8221348b/numpy/lib/function_base.py#L54-L138)

Pull Request resolved: #8628
Reviewed By: ezyang
Differential Revision: D8987860
Pulled By: weiyangfb
fbshipit-source-id: 8dac3b2a1f6d3288672977aba8b547706ce97fe9
1 parent 2f5c0c3 commit 302adb7
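
For orientation, here is a minimal usage sketch of the new operator; it assumes a PyTorch build that includes this commit, and the expected values mirror the tests added below.

import torch

x = torch.arange(1, 5).view(2, 2)   # [[1, 2], [3, 4]]
x.rot90()                           # default k=1, dims=[0, 1]: [[2, 4], [1, 3]]
x.rot90(2, [0, 1])                  # two 90-degree turns: [[4, 3], [2, 1]]
x.rot90(-1, [0, 1])                 # negative k rotates the other way: [[3, 1], [4, 2]]

As in numpy.rot90(), the rotation goes from the first towards the second of the two dims, and k is reduced modulo 4.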

File tree

14 files changed, +204 -29 lines

aten/src/ATen/WrapDimUtils.h

Lines changed: 7 additions & 0 deletions
@@ -86,4 +86,11 @@ static inline int64_t legacy_cat_wrap_dim(int64_t dim, TensorList tensors) {
   return dim;
 }
 
+// wrap negative dims in a vector
+static inline void wrap_all_dims(std::vector<int64_t>& dims_to_wrap, int64_t tensor_total_dims) {
+  for (size_t i = 0; i < dims_to_wrap.size(); i++) {
+    dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims);
+  }
+}
+
 }
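
This helper applies the standard wrapping rule (a negative dim d on an n-dim tensor maps to d + n) to every entry of a dim list. A minimal sketch of the behavior it enables at the Python level, assuming a build that includes this commit:

import torch

t = torch.arange(8).view(2, 2, 2)
# -1 wraps to 2 on a 3-dim tensor, so both calls flip the last dimension
assert torch.equal(t.flip(-1), t.flip(2))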

aten/src/ATen/native/TensorTransformations.cpp

Lines changed: 37 additions & 1 deletion
@@ -11,9 +11,10 @@ namespace native {
 
 Tensor flip_cpu(const Tensor& self, IntList dims) {
   const int64_t total_dims = self.dim(), flip_dims_size = dims.size();
-  check_errors(total_dims, flip_dims_size, dims);
+  flip_check_errors(total_dims, flip_dims_size, dims);
 
   auto flip_dims_v = std::vector<int64_t>(dims);
+  wrap_all_dims(flip_dims_v, total_dims);
   std::sort(flip_dims_v.begin(), flip_dims_v.end());
   auto final_indices = std::vector<at::Tensor>(total_dims);
 
@@ -56,4 +57,39 @@ Tensor flip_cpu(const Tensor& self, IntList dims) {
   return out_tensor;
 }
 
+Tensor rot90(const Tensor& self, int64_t k, IntList dims) {
+  const int64_t total_dims = self.dim(), total_rot_dims = dims.size();
+
+  AT_CHECK(total_rot_dims == 2,
+    "expected total rotation dims == 2, but got dims = ", total_rot_dims);
+
+  AT_CHECK(total_dims >= 2,
+    "expected total dims >= 2, but got total dims = ", total_dims);
+
+  AT_CHECK(dims[0] != dims[1] && std::abs(dims[0] - dims[1]) != total_dims,
+    "expected rotation dims to be different, but got dim0 = ", dims[0],
+    " and dim1 = ", dims[1]);
+
+  // check range of dims
+  AT_CHECK(dims[0] < total_dims && dims[0] >= -total_dims,
+    "Rotation dim0 out of range, dim0 = ", dims[0]);
+
+  AT_CHECK(dims[1] < total_dims && dims[1] >= -total_dims,
+    "Rotation dim1 out of range, dim1 = ", dims[1]);
+
+  // handle modulo with negative k
+  k = (4 + (k % 4)) % 4;
+
+  switch(k) {
+    case 1:
+      return self.flip({dims[1]}).transpose_(dims[0], dims[1]);
+    case 2:
+      return self.flip(dims);
+    case 3:
+      return self.flip({dims[0]}).transpose_(dims[0], dims[1]);
+    default:
+      return self.clone();
+  }
+}
+
 }} // namespace at::native
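
rot90 is built from flip plus an in-place transpose of the two rotation dims rather than a dedicated kernel. A short sketch of that decomposition, assuming a build that includes this commit:

import torch

x = torch.arange(1, 5).view(2, 2)            # [[1, 2], [3, 4]]
# k=1: flip along dims[1], then swap the two rotation dims
assert torch.equal(x.rot90(1, [0, 1]), x.flip(1).transpose(0, 1))
# k=2: flip along both rotation dims, no transpose needed
assert torch.equal(x.rot90(2, [0, 1]), x.flip(0, 1))

Because k is normalized to {0, 1, 2, 3}, k=0 (or any multiple of 4) simply returns a clone.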
Lines changed: 15 additions & 18 deletions
@@ -1,39 +1,36 @@
 #include "ATen/ATen.h"
 
 #include <ATen/Error.h>
+#include <ATen/WrapDimUtils.h>
 
 #include <algorithm>
 #include <vector>
 
 namespace at {
 namespace native {
 
-static inline void check_errors(int64_t total_dims, int64_t flip_dims_size, IntList dims) {
+static inline void flip_check_errors(int64_t total_dims, int64_t flip_dims_size, IntList dims) {
   // check if number of axis in dim is valid
-  AT_CHECK(flip_dims_size > 0,
-    "expected input tensor dims > 0, but got tensor dims size=", flip_dims_size);
+  AT_CHECK(flip_dims_size > 0 && flip_dims_size <= total_dims,
+    "flip dims size out of range, got flip dims size=", flip_dims_size);
 
-  // check duplicates in dims
   auto flip_dims_v = std::vector<int64_t>(dims);
-  flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end());
-  AT_CHECK((int64_t)flip_dims_v.size() == flip_dims_size,
-    "dims has duplicates, original flip dims size=", flip_dims_size,
-    ", but unique flip dims size=", flip_dims_v.size());
-
-  // check len of dims
-  AT_CHECK(flip_dims_size <= total_dims,
-    "expected flip dims size <= tensor total dims, but got flip dims size=",
-    flip_dims_size, " and tensor total dim=", total_dims);
 
   // check if dims axis within range
   auto min_max_d = std::minmax_element(flip_dims_v.begin(), flip_dims_v.end());
 
-  AT_CHECK(*min_max_d.first >= 0,
-    "expected flip dims axis >= 0, but got min flip dims=", *min_max_d.first);
+  AT_CHECK(*min_max_d.first < total_dims && *min_max_d.first >= -total_dims,
+    "The min flip dims out of range, got min flip dims=", *min_max_d.first);
 
-  AT_CHECK(*min_max_d.second < total_dims,
-    "expected flip dims axis < tensor total dims, but got max flip dims=",
-    *min_max_d.second, " and tensor total dim=", total_dims);
+  AT_CHECK(*min_max_d.second < total_dims && *min_max_d.second >= -total_dims,
+    "The max flip dims out of range, got max flip dims=", *min_max_d.second);
+
+  // check duplicates in dims
+  wrap_all_dims(flip_dims_v, total_dims);
+  flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end());
+  AT_CHECK((int64_t)flip_dims_v.size() == flip_dims_size,
+    "dims has duplicates, original flip dims size=", flip_dims_size,
+    ", but unique flip dims size=", flip_dims_v.size());
 }
 
 }} // namespace at::native
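
With the checks reordered, negative dims are wrapped before the duplicate check, so a dim like -1 passes validation while out-of-range or repeated dims still raise. A quick sketch of the resulting Python-level behavior, assuming a build that includes this commit:

import torch

t = torch.arange(8).view(2, 2, 2)
t.flip(-1)        # accepted: -1 wraps to 2
try:
    t.flip(3)     # rejected: out of range for a 3-dim tensor
except RuntimeError as e:
    print("out of range:", e)
try:
    t.flip(0, 0)  # rejected: duplicate dims after wrapping
except RuntimeError as e:
    print("duplicates:", e)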

aten/src/ATen/native/cuda/TensorTransformations.cu

Lines changed: 7 additions & 5 deletions
@@ -69,7 +69,7 @@ void flip_cuda_kernel(scalar_t* in_tensor, scalar_t* out_tensor, int64_t N, int6
 Tensor flip_cuda(const Tensor& self, IntList dims) {
   auto in_tensor = self;
   const int64_t flip_dims_size = dims.size(), total_dims = in_tensor.dim(), N = in_tensor.numel();
-  check_errors(total_dims, flip_dims_size, dims);
+  flip_check_errors(total_dims, flip_dims_size, dims);
 
   int64_t block_size = 512;
   dim3 dim_block(block_size);
@@ -80,21 +80,23 @@ Tensor flip_cuda(const Tensor& self, IntList dims) {
     return out_tensor;
   }
 
+  auto flip_dims = std::vector<int64_t>(dims);
+  wrap_all_dims(flip_dims, total_dims);
+
   // use kernel_pointwise_flip_apply2 only when to-flip dim is the 1st or last dim, where collapseDims can reduce the amount of work
-  if (flip_dims_size == 1 && in_tensor.is_contiguous() && (dims[0] == 0 || dims[0] == total_dims - 1)) {
+  if (flip_dims_size == 1 && in_tensor.is_contiguous() && (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) {
     AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "flip_cuda", [&] {
       auto in_tensor_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(in_tensor);
       auto out_tensor_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(out_tensor);
-      int flip_dim = in_tensor_info.collapseDims(dims[0]);
-      out_tensor_info.collapseDims(dims[0]);
+      int flip_dim = in_tensor_info.collapseDims(flip_dims[0]);
+      out_tensor_info.collapseDims(flip_dims[0]);
       kernel_pointwise_flip_apply2<scalar_t, int64_t>
        <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          in_tensor_info, out_tensor_info, N, flip_dim, total_dims);
     });
     return out_tensor;
   }
 
-  auto flip_dims = std::vector<int64_t>(dims);
   auto flip_dims_t = at::CPU(kLong).tensorFromBlob(flip_dims.data(), {static_cast<int64_t>(flip_dims.size())});
 
   auto shape = std::vector<int64_t>(in_tensor.sizes());
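
On CUDA the dims are wrapped up front, so a negative dim that resolves to the first or last dimension can still take the fast single-dim kernel path. A minimal check, assuming a CUDA build that includes this commit:

import torch

if torch.cuda.is_available():
    t = torch.arange(8, device="cuda").view(2, 2, 2)
    # -1 wraps to 2 (the last, contiguous dim) and matches the positive-index result
    assert torch.equal(t.flip(-1), t.flip(2))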

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 0 deletions
@@ -1450,6 +1450,9 @@
     CPU: flip_cpu
     CUDA: flip_cuda
 
+# default IntList value {0,1} should not add space after comma, since native_parse.py uses ', ' to split args
+- func: rot90(Tensor self, int64_t k=1, IntList dims={0,1}) -> Tensor
+
 - func: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntList expand1, IntList expand2, IntList expand3, IntList sumdim, int64_t unroll_dim=1) -> Tensor
   variants: function

aten/src/ATen/native_parse.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ def parse_default(s):
         return s
     elif s == '{}':
         return '{}'
+    elif re.match(r'{.*}', s):
+        return s
     elif s == 'nullopt':
         return s
     try:
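
The new branch lets a braced default such as {0,1} pass through the codegen parser verbatim; previously only the empty {} was recognized. A stand-alone sketch of just the matching rule (a hypothetical helper, not the real native_parse.parse_default):

import re

def keeps_braced_default(s):
    # mirrors the added branch: any '{...}' literal is returned unchanged
    return s if re.match(r'{.*}', s) else None

print(keeps_braced_default('{0,1}'))   # -> '{0,1}', so IntList dims={0,1} survives parsing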

test/test_autograd.py

Lines changed: 5 additions & 0 deletions
@@ -2654,6 +2654,11 @@ class dont_convert(tuple):
     ('flip', (S, S, S), ([0, 1, 2],), 'd012'),
     ('flip', (S, S, S), ([0, 2],), 'd02'),
     ('flip', (S, S, S), ([2, 0],), 'd20'),
+    ('flip', (S, S, S), ([-1],), 'neg_d'),
+    ('rot90', (S, S, S), (1, [0, 1],), 'k1_d01'),
+    ('rot90', (S, S, S), (1, [1, 2],), 'k1_d12'),
+    ('rot90', (S, S, S), (1, [1, -1],), 'k1_neg_d'),
+    ('rot90', (S, S, S), (), 'default'),
     ('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
     ('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
     ('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),

test/test_cuda.py

Lines changed: 8 additions & 0 deletions
@@ -422,6 +422,11 @@ def tmp(t):
     ('flip', small_3d, lambda t: [0, 1, 2], 'd012', types, True),
     ('flip', small_3d, lambda t: [0, 2], 'd02', types, True),
     ('flip', small_3d, lambda t: [2, 0], 'd20', types, True),
+    ('flip', small_3d, lambda t: [-1], 'neg_d', types, True),
+    ('rot90', small_2d, lambda t: [1, [0, 1]], 'k1_d01', types, True),
+    ('rot90', small_3d, lambda t: [1, [1, 2]], 'k1_d12', types, True),
+    ('rot90', small_3d, lambda t: [1, [1, -1]], 'k1_neg_d', types, True),
+    ('rot90', small_3d, lambda t: [], 'default', types, True),
     ('rsqrt', lambda t: constant_tensor_add(1, small_3d(t)), lambda t: [], None, float_types),
     ('sinh', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types),
     ('tan', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types),
@@ -1417,6 +1422,9 @@ def test_view(self):
     def test_flip(self):
         TestTorch._test_flip(self, use_cuda=True)
 
+    def test_rot90(self):
+        TestTorch._test_rot90(self, use_cuda=True)
+
     def test_signal_window_functions(self):
         TestTorch._test_signal_window_functions(self, device=torch.device('cuda'))
 

test/test_torch.py

Lines changed: 44 additions & 2 deletions
@@ -6720,6 +6720,8 @@ def _test_flip(self, use_cuda=False):
         self.assertEqual(torch.tensor([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2), data.flip(0, 1))
         self.assertEqual(torch.tensor([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2), data.flip(0, 1, 2))
 
+        # check for wrap dim
+        self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(-1))
         # check for permute
         self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(0, 2))
         self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(2, 0))
@@ -6730,8 +6732,6 @@ def _test_flip(self, use_cuda=False):
         self.assertRaises(TypeError, lambda: data.flip())
         # not allow size of flip dim > total dims
         self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 2, 3))
-        # not allow dim < 0
-        self.assertRaises(RuntimeError, lambda: data.flip(-1))
         # not allow dim > max dim
         self.assertRaises(RuntimeError, lambda: data.flip(3))
 
@@ -6756,6 +6756,10 @@ def _test_flip(self, use_cuda=False):
         self.assertEqual(flip0_result, data.flip(0))
         self.assertEqual(flip1_result, data.flip(1))
 
+        # test empty tensor, should just return an empty tensor of the same shape
+        data = torch.tensor([])
+        self.assertEqual(data, data.flip(0))
+
     def test_flip(self):
         self._test_flip(self, use_cuda=False)
 
@@ -6769,6 +6773,44 @@ def test_reversed(self):
         val = torch.tensor(42)
         self.assertEqual(reversed(val), torch.tensor(42))
 
+    @staticmethod
+    def _test_rot90(self, use_cuda=False):
+        device = torch.device("cuda" if use_cuda else "cpu")
+        data = torch.arange(1, 5, device=device).view(2, 2)
+        self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1]))
+        self.assertEqual(torch.tensor([2, 4, 1, 3]).view(2, 2), data.rot90(1, [0, 1]))
+        self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1]))
+        self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1]))
+
+        # test for default args k=1, dims=[0, 1]
+        self.assertEqual(data.rot90(), data.rot90(1, [0, 1]))
+
+        # test for reversed order of dims
+        self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0]))
+
+        # test for modulo of k
+        self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1]))
+        self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1]))
+        self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1]))
+
+        # test for dims out-of-range error
+        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3]))
+        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2]))
+
+        # test tensor with more than 2D
+        data = torch.arange(1, 9, device=device).view(2, 2, 2)
+        self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2]))
+        self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2]))
+
+        # test for errors
+        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3]))
+        self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1]))
+        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2]))
+        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0]))
+
+    def test_rot90(self):
+        self._test_rot90(self, use_cuda=False)
+
     def test_storage(self):
         v = torch.randn(3, 5)
         self.assertEqual(v.storage()[0], v.data[0][0])

tools/autograd/derivatives.yaml

Lines changed: 3 additions & 0 deletions
@@ -688,6 +688,9 @@
 - name: flip(Tensor self, IntList dims)
   self: grad.flip(dims)
 
+- name: rot90(Tensor self, int64_t k, IntList dims)
+  self: grad.rot90(-k, dims)
+
 - name: take(Tensor self, Tensor index)
   self: zeros_like(self).put_(index, grad, true)
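
Since rot90 is a pure index permutation, its backward is the inverse rotation applied to the incoming gradient, which is what grad.rot90(-k, dims) expresses. A small autograd check, assuming a build that includes this commit:

import torch

x = torch.arange(4., requires_grad=True)
y = x.view(2, 2).rot90(1, [0, 1])
g = torch.arange(4.).view(2, 2) + 1   # arbitrary upstream gradient
y.backward(g)
# the gradient w.r.t. x is the upstream gradient rotated back by -k
assert torch.equal(x.grad.view(2, 2), g.rot90(-1, [0, 1]))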
