7 changes: 7 additions & 0 deletions aten/src/ATen/WrapDimUtils.h
@@ -86,4 +86,11 @@ static inline int64_t legacy_cat_wrap_dim(int64_t dim, TensorList tensors) {
return dim;
}

// wrap negative dims in a vector
static inline void wrap_all_dims(std::vector<int64_t>& dims_to_wrap, int64_t tensor_total_dims) {
for (size_t i = 0; i < dims_to_wrap.size(); i++) {
dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims);
}
}

}
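
A rough Python illustration of the wrapping semantics this helper provides (the range checking done by maybe_wrap_dim is omitted; names here are only for illustration):

# illustrative sketch only: a negative dim d on a tensor with total_dims dims resolves to d + total_dims
def wrap_all_dims(dims, total_dims):
    return [d + total_dims if d < 0 else d for d in dims]

wrap_all_dims([-1, 0], 3)  # -> [2, 0]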
38 changes: 37 additions & 1 deletion aten/src/ATen/native/TensorTransformations.cpp
@@ -11,9 +11,10 @@ namespace native {

Tensor flip_cpu(const Tensor& self, IntList dims) {
const int64_t total_dims = self.dim(), flip_dims_size = dims.size();
check_errors(total_dims, flip_dims_size, dims);
flip_check_errors(total_dims, flip_dims_size, dims);

auto flip_dims_v = std::vector<int64_t>(dims);
wrap_all_dims(flip_dims_v, total_dims);
std::sort(flip_dims_v.begin(), flip_dims_v.end());
auto final_indices = std::vector<at::Tensor>(total_dims);

@@ -56,4 +57,39 @@ Tensor flip_cpu(const Tensor& self, IntList dims) {
return out_tensor;
}

Tensor rot90(const Tensor& self, int64_t k, IntList dims) {
const int64_t total_dims = self.dim(), total_rot_dims = dims.size();

AT_CHECK(total_rot_dims == 2,
"expected total rotation dims == 2, but got dims = ", total_rot_dims);

AT_CHECK(total_dims >= 2,
"expected total dims >= 2, but got total dims = ", total_dims);

AT_CHECK(dims[0] != dims[1] && std::abs(dims[0] - dims[1]) != total_dims,
"expected rotation dims to be different, but got dim0 = ", dims[0],
" and dim1 = ", dims[1]);

// check range of dims
AT_CHECK(dims[0] < total_dims && dims[0] >= -total_dims,
"Rotation dim0 out of range, dim0 = ", dims[0]);

AT_CHECK(dims[1] < total_dims && dims[1] >= -total_dims,
"Rotation dim1 out of range, dim1 = ", dims[1]);

// handle modulo with negative k
k = (4 + (k % 4)) % 4;

switch(k) {
case 1:
return self.flip({dims[1]}).transpose_(dims[0], dims[1]);
case 2:
return self.flip(dims);
case 3:
return self.flip({dims[0]}).transpose_(dims[0], dims[1]);
default:
return self.clone();
}
}

}} // namespace at::native
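
A quick Python check of the flip/transpose composition used for the k = 1..3 cases and of the negative-k wrap above (assumes a build where both flip and rot90 are available):

import torch

x = torch.arange(1, 5).view(2, 2)                      # [[1, 2], [3, 4]]
# k = 1: flip the second rotation dim, then swap the two dims
assert torch.equal(x.rot90(1, [0, 1]), x.flip(1).transpose(0, 1))
# k = 2: flip both rotation dims
assert torch.equal(x.rot90(2, [0, 1]), x.flip(0, 1))
# negative k wraps modulo 4, matching k = (4 + (k % 4)) % 4
assert torch.equal(x.rot90(-1, [0, 1]), x.rot90(3, [0, 1]))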
33 changes: 15 additions & 18 deletions aten/src/ATen/native/TensorTransformations.h
@@ -1,39 +1,36 @@
#include "ATen/ATen.h"

#include <ATen/Error.h>
#include <ATen/WrapDimUtils.h>

#include <algorithm>
#include <vector>

namespace at {
namespace native {

static inline void check_errors(int64_t total_dims, int64_t flip_dims_size, IntList dims) {
static inline void flip_check_errors(int64_t total_dims, int64_t flip_dims_size, IntList dims) {
// check if the number of axes in dims is valid
AT_CHECK(flip_dims_size > 0,
"expected input tensor dims > 0, but got tensor dims size=", flip_dims_size);
AT_CHECK(flip_dims_size > 0 && flip_dims_size <= total_dims,
"flip dims size out of range, got flip dims size=", flip_dims_size);

// check duplicates in dims
auto flip_dims_v = std::vector<int64_t>(dims);
flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end());
AT_CHECK((int64_t)flip_dims_v.size() == flip_dims_size,
"dims has duplicates, original flip dims size=", flip_dims_size,
", but unique flip dims size=", flip_dims_v.size());

// check len of dims
AT_CHECK(flip_dims_size <= total_dims,
"expected flip dims size <= tensor total dims, but got flip dims size=",
flip_dims_size, " and tensor total dim=", total_dims);

// check if dims axis within range
auto min_max_d = std::minmax_element(flip_dims_v.begin(), flip_dims_v.end());

AT_CHECK(*min_max_d.first >= 0,
"expected flip dims axis >= 0, but got min flip dims=", *min_max_d.first);
AT_CHECK(*min_max_d.first < total_dims && *min_max_d.first >= -total_dims,
"The min flip dims out of range, got min flip dims=", *min_max_d.first);

AT_CHECK(*min_max_d.second < total_dims,
"expected flip dims axis < tensor total dims, but got max flip dims=",
*min_max_d.second, " and tensor total dim=", total_dims);
AT_CHECK(*min_max_d.second < total_dims && *min_max_d.second >= -total_dims,
"The max flip dims out of range, got max flip dims=", *min_max_d.second);

// check duplicates in dims
wrap_all_dims(flip_dims_v, total_dims);
flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end());
AT_CHECK((int64_t)flip_dims_v.size() == flip_dims_size,
"dims has duplicates, original flip dims size=", flip_dims_size,
", but unique flip dims size=", flip_dims_v.size());
}

}} // namespace at::native
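
The behavioral change from the rewritten check (dims are wrapped before the range and duplicate tests) is easiest to see from the flip calls the updated tests exercise; a rough sketch:

import torch

data = torch.arange(1, 9).view(2, 2, 2)
data.flip(-1)              # now accepted: -1 wraps to dim 2
try:
    data.flip(3)           # still rejected: out of range for a 3-d tensor
except RuntimeError:
    pass
try:
    data.flip(0, 1, 2, 3)  # still rejected: more flip dims than tensor dims
except RuntimeError:
    pass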
12 changes: 7 additions & 5 deletions aten/src/ATen/native/cuda/TensorTransformations.cu
@@ -69,7 +69,7 @@ void flip_cuda_kernel(scalar_t* in_tensor, scalar_t* out_tensor, int64_t N, int6
Tensor flip_cuda(const Tensor& self, IntList dims) {
auto in_tensor = self;
const int64_t flip_dims_size = dims.size(), total_dims = in_tensor.dim(), N = in_tensor.numel();
check_errors(total_dims, flip_dims_size, dims);
flip_check_errors(total_dims, flip_dims_size, dims);

int64_t block_size = 512;
dim3 dim_block(block_size);
@@ -80,21 +80,23 @@ Tensor flip_cuda(const Tensor& self, IntList dims) {
return out_tensor;
}

auto flip_dims = std::vector<int64_t>(dims);
wrap_all_dims(flip_dims, total_dims);

// use kernel_pointwise_flip_apply2 only when the dim to flip is the first or last dim, where collapseDims can reduce the amount of work
if (flip_dims_size == 1 && in_tensor.is_contiguous() && (dims[0] == 0 || dims[0] == total_dims - 1)) {
if (flip_dims_size == 1 && in_tensor.is_contiguous() && (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) {
AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "flip_cuda", [&] {
auto in_tensor_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(in_tensor);
auto out_tensor_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(out_tensor);
int flip_dim = in_tensor_info.collapseDims(dims[0]);
out_tensor_info.collapseDims(dims[0]);
int flip_dim = in_tensor_info.collapseDims(flip_dims[0]);
out_tensor_info.collapseDims(flip_dims[0]);
kernel_pointwise_flip_apply2<scalar_t, int64_t>
<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
in_tensor_info, out_tensor_info, N, flip_dim, total_dims);
});
return out_tensor;
}

auto flip_dims = std::vector<int64_t>(dims);
auto flip_dims_t = at::CPU(kLong).tensorFromBlob(flip_dims.data(), {static_cast<int64_t>(flip_dims.size())});

auto shape = std::vector<int64_t>(in_tensor.sizes());
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1445,6 +1445,9 @@
CPU: flip_cpu
CUDA: flip_cuda

# the default IntList value {0,1} must not contain a space after the comma, since native_parse.py splits args on ', '
- func: rot90(Tensor self, int64_t k=1, IntList dims={0,1}) -> Tensor

- func: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntList expand1, IntList expand2, IntList expand3, IntList sumdim, int64_t unroll_dim=1) -> Tensor
variants: function

2 changes: 2 additions & 0 deletions aten/src/ATen/native_parse.py
@@ -20,6 +20,8 @@ def parse_default(s):
return s
elif s == '{}':
return '{}'
elif re.match(r'{.*}', s):
return s
elif s == 'nullopt':
return s
try:
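
Roughly, the new branch lets brace defaults such as the {0,1} used by rot90 pass through unchanged; a simplified sketch of parse_default, not the full function:

import re

def parse_default(s):
    # simplified: the real parser also handles bools, numbers, and 'nullopt'
    if s == '{}':
        return '{}'
    elif re.match(r'{.*}', s):
        # new branch: brace-list defaults like '{0,1}' are returned as-is;
        # they must not contain ', ' since argument strings are split on ', '
        return s
    raise RuntimeError('unsupported default: ' + s)

parse_default('{0,1}')  # -> '{0,1}'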
5 changes: 5 additions & 0 deletions test/test_autograd.py
@@ -2648,6 +2648,11 @@ class dont_convert(tuple):
('flip', (S, S, S), ([0, 1, 2],), 'd012'),
('flip', (S, S, S), ([0, 2],), 'd02'),
('flip', (S, S, S), ([2, 0],), 'd20'),
('flip', (S, S, S), ([-1],), 'neg_d'),
('rot90', (S, S, S), (1, [0, 1],), 'k1_d01'),
('rot90', (S, S, S), (1, [1, 2],), 'k1_d12'),
('rot90', (S, S, S), (1, [1, -1],), 'k1_neg_d'),
('rot90', (S, S, S), (), 'default'),
('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
8 changes: 8 additions & 0 deletions test/test_cuda.py
@@ -420,6 +420,11 @@ def tmp(t):
('flip', small_3d, lambda t: [0, 1, 2], 'd012', types, True),
('flip', small_3d, lambda t: [0, 2], 'd02', types, True),
('flip', small_3d, lambda t: [2, 0], 'd20', types, True),
('flip', small_3d, lambda t: [-1], 'neg_d', types, True),
('rot90', small_2d, lambda t: [1, [0, 1]], 'k1_d01', types, True),
('rot90', small_3d, lambda t: [1, [1, 2]], 'k1_d12', types, True),
('rot90', small_3d, lambda t: [1, [1, -1]], 'k1_neg_d', types, True),
('rot90', small_3d, lambda t: [], 'default', types, True),
('rsqrt', lambda t: constant_tensor_add(1, small_3d(t)), lambda t: [], None, float_types),
('sinh', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types),
('tan', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types),
@@ -1415,6 +1420,9 @@ def test_view(self):
def test_flip(self):
TestTorch._test_flip(self, use_cuda=True)

def test_rot90(self):
TestTorch._test_rot90(self, use_cuda=True)

def test_signal_window_functions(self):
TestTorch._test_signal_window_functions(self, device=torch.device('cuda'))

46 changes: 44 additions & 2 deletions test/test_torch.py
@@ -6686,6 +6686,8 @@ def _test_flip(self, use_cuda=False):
self.assertEqual(torch.tensor([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2), data.flip(0, 1))
self.assertEqual(torch.tensor([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2), data.flip(0, 1, 2))

# check for wrap dim
self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(-1))
# check for permute
self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(0, 2))
self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(2, 0))
@@ -6696,8 +6698,6 @@ def _test_flip(self, use_cuda=False):
self.assertRaises(TypeError, lambda: data.flip())
# not allow size of flip dim > total dims
self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 2, 3))
# not allow dim < 0
self.assertRaises(RuntimeError, lambda: data.flip(-1))
# not allow dim > max dim
self.assertRaises(RuntimeError, lambda: data.flip(3))

@@ -6722,6 +6722,10 @@ def _test_flip(self, use_cuda=False):
self.assertEqual(flip0_result, data.flip(0))
self.assertEqual(flip1_result, data.flip(1))

# test empty tensor, should just return an empty tensor of the same shape
data = torch.tensor([])
self.assertEqual(data, data.flip(0))

def test_flip(self):
self._test_flip(self, use_cuda=False)

@@ -6735,6 +6739,44 @@ def test_reversed(self):
val = torch.tensor(42)
self.assertEqual(reversed(val), torch.tensor(42))

@staticmethod
def _test_rot90(self, use_cuda=False):
device = torch.device("cuda" if use_cuda else "cpu")
data = torch.arange(1, 5, device=device).view(2, 2)
self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1]))
self.assertEqual(torch.tensor([2, 4, 1, 3]).view(2, 2), data.rot90(1, [0, 1]))
self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1]))
self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1]))

# test for default args k=1, dims=[0, 1]
self.assertEqual(data.rot90(), data.rot90(1, [0, 1]))

# test for reversed order of dims
self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0]))

# test for modulo of k
self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1]))
self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1]))
self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1]))

# test for dims out-of-range error
self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3]))
self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2]))

# test tensor with more than 2D
data = torch.arange(1, 9, device=device).view(2, 2, 2)
self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2]))
self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2]))

# test for errors
self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3]))
self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1]))
self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2]))
self.assertRaises(RuntimeError, lambda: data.rot90(1, [0]))

def test_rot90(self):
self._test_rot90(self, use_cuda=False)

def test_storage(self):
v = torch.randn(3, 5)
self.assertEqual(v.storage()[0], v.data[0][0])
3 changes: 3 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -685,6 +685,9 @@
- name: flip(Tensor self, IntList dims)
self: grad.flip(dims)

- name: rot90(Tensor self, int64_t k, IntList dims)
self: grad.rot90(-k, dims)

- name: take(Tensor self, Tensor index)
self: zeros_like(self).put_(index, grad, true)

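
The backward entry relies on rot90(-k, dims) being the inverse of rot90(k, dims), so gradients are simply rotated back; a quick sanity check one could run (hedged sketch, double precision for gradcheck):

import torch
from torch.autograd import gradcheck

x = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
# rotating by -k undoes rotating by k
assert torch.equal(x.rot90(1, [0, 1]).rot90(-1, [0, 1]), x)
# numerical vs. analytical gradients for the registered derivative
assert gradcheck(lambda t: t.rot90(1, [0, 1]), (x,))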
4 changes: 3 additions & 1 deletion tools/jit/gen_jit_dispatch.py
@@ -419,7 +419,9 @@ def format_arg(arg):
.replace('false', 'False') \
.replace('nullptr', 'None') \
.replace('Reduction::ElementwiseMean', 'ElementwiseMean') \
.replace('{}', 'None' if is_tensor_arg(arg) else '[]')
.replace('{}', 'None' if is_tensor_arg(arg) else '[]') \
.replace('{', '[') \
.replace('}', ']')

default = default_map.get(default, default)
decl = '{}={}'.format(decl, default)
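
The extra .replace calls just translate a C++ brace default into Python list syntax for the generated signature; a minimal illustration (hypothetical values, not actual generator output):

default = '{0,1}'                                      # as declared in native_functions.yaml
default = default.replace('{', '[').replace('}', ']')
print(default)                                         # '[0,1]', roughly what the generated rot90 signature shows for dims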
37 changes: 37 additions & 0 deletions torch/_torch_docs.py
@@ -4498,6 +4498,43 @@ def parse_kwargs(desc):
[ 0, 1]]])
""")

add_docstr(torch.rot90,
r"""
rot90(input, k, dims) -> Tensor

Rotate an n-D tensor by 90 degrees in the plane specified by the dims axes.
Rotation direction is from the first towards the second axis if k > 0, and from the second towards the first for k < 0.

Args:
input (Tensor): the input tensor
k (int): number of times to rotate
dims (a list or tuple): axes to rotate

Example::

>>> x = torch.arange(4).view(2, 2)
>>> x
tensor([[0, 1],
[2, 3]])
>>> torch.rot90(x, 1, [0, 1])
tensor([[1, 3],
[0, 2]])

>>> x = torch.arange(8).view(2, 2, 2)
>>> x
tensor([[[0, 1],
[2, 3]],

[[4, 5],
[6, 7]]])
>>> torch.rot90(x, 1, [1, 2])
tensor([[[1, 3],
[0, 2]],

[[5, 7],
[4, 6]]])
""")

add_docstr(torch.take,
r"""
take(input, indices) -> Tensor
2 changes: 1 addition & 1 deletion torch/csrc/jit/script/compiler.cpp
@@ -373,7 +373,7 @@ static bool isTensorSubtype(Value* v) {
at::optional<std::vector<int64_t>> getIntListAttribute(at::optional<int32_t> N, Value* input) {
auto list = constant_as<Shared<jit::IntList>>(input);
if(list)
return std::vector<int64_t>(*list);
return std::vector<int64_t>(list.value()->elements());
// broadcast IntList[3] with value 4 -> {4, 4, 4}
if(!N)
return at::nullopt;