78 changes: 78 additions & 0 deletions aten/src/ATen/native/TensorTransformations.cpp
@@ -0,0 +1,78 @@
#include "ATen/ATen.h"
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"
#include <algorithm>
#include <functional>
#include <numeric>
#include <sstream>
#include <vector>


namespace at {
namespace native {

// Reverse a single dimension via index_select with a descending index
// (a Python sketch of the overall algorithm follows this file's diff).
Tensor reverse_dim(const Tensor& t, int64_t dim) {
Tensor index = at::arange(t.type().toScalarType(at::ScalarType::Long), t.size(dim) - 1, -1, -1);
return t.index_select(dim, index);
}

Tensor flip_cpu(const Tensor& self, IntList dims) {

int64_t total_dims = self.dim(), flip_dims_size = dims.size();

// check that the dims argument is non-empty
if (flip_dims_size == 0) {
std::stringstream ss;
ss << "expected flip dims to be non-empty, "
<< "but got flip dims size=" << flip_dims_size;
throw std::runtime_error(ss.str());
}

// check duplicates in dims (sort a copy first: std::unique only removes adjacent duplicates)
auto flip_dims_v = std::vector<int64_t>(dims);
std::sort(flip_dims_v.begin(), flip_dims_v.end());
flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end());
if ((int64_t)flip_dims_v.size() < flip_dims_size) {
std::stringstream ss;
ss << "dims has duplicates, "
<< "original flip dims size=" << flip_dims_size << ", "
<< "but unique flip dims size= " << flip_dims_v.size();
throw std::runtime_error(ss.str());
}

// check len of dims
if (flip_dims_size > total_dims) {
std::stringstream ss;
ss << "expected flip dims size <= tensor total dims, "
<< "but got flip dims size=" << flip_dims_size << " and "
<< "tensor total dim=" << total_dims;
throw std::runtime_error(ss.str());
}

// check that every flip axis is within range
int64_t min_d = total_dims, max_d = 0;
for (auto d : dims) {
min_d = std::min(min_d, d);
max_d = std::max(max_d, d);
}

if (min_d < 0) {
std::stringstream ss;
ss << "expected flip dims axis >= 0, "
<< "but got min flip dims=" << min_d;
throw std::runtime_error(ss.str());
}

if (max_d >= total_dims) {
std::stringstream ss;
ss << "expected flip dims axis < tensor total dims, "
<< "but got max flip dims=" << max_d << " and "
<< "tensor total dim=" << total_dims;
throw std::runtime_error(ss.str());
}

Tensor out_t = self.clone();
for (auto d : dims) {
out_t = reverse_dim(out_t, d);
}
return out_t;
}

}} // namespace at::native
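
For reference, a minimal Python sketch of what the CPU path above does: clone the input, then reverse each requested dimension with a descending index passed to index_select. The helper names (reverse_dim_py, flip_reference) are illustrative only and not part of this diff.

import torch

def reverse_dim_py(t, dim):
    # index_select with a descending LongTensor index reverses one dimension,
    # mirroring reverse_dim() above
    index = torch.arange(t.size(dim) - 1, -1, -1, dtype=torch.long)
    return t.index_select(dim, index)

def flip_reference(t, dims):
    # flip_cpu clones the input and reverses each requested dimension in turn
    out = t.clone()
    for d in dims:
        out = reverse_dim_py(out, d)
    return out

x = torch.arange(1, 9).view(2, 2, 2)
print(flip_reference(x, [0]))        # same result as x.flip(0)
print(flip_reference(x, [0, 1, 2]))  # same result as x.flip(0, 1, 2)
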
151 changes: 151 additions & 0 deletions aten/src/ATen/native/cuda/TensorTransformations.cu
@@ -0,0 +1,151 @@
#include "ATen/NativeFunctions.h"
#include "ATen/ATen.h"
#include <algorithm>
#include <sstream>
#include <vector>

#include "ATen/cuda/CUDATensorMethods.cuh"
#include "ATen/cuda/CUDATypeConversion.cuh"


namespace at {
namespace native {

// Map the index of an element in a tensor from its 1D (flattened) position to its nD coordinates
__device__ __forceinline__
void linear_index_to_indices(int64_t linear_index, int64_t* strides, int64_t total_dims, int64_t* indices) {
int64_t res = linear_index;
for (int64_t i = 0; i < total_dims; i++) {
int64_t indices_i = linear_index * total_dims + i;
indices[indices_i] = res / strides[i];
res = res % strides[i];
}
}

/*
Map the index of an element in a tensor from nD to 1D. The tensor is originally in nD shape, and 1D is its unfolded (flattened) version.

Example: given the 3D tensor
[
[ [1, 2], [3, 4] ],
[ [5, 6], [7, 8] ],
[ [9, 10], [11, 12] ],
]

the element 3 has nD index (0, 1, 0), and the strides are (4, 2, 1). To map nD to 1D, use the formula sum(index[i] * stride[i]): here 0 * 4 + 1 * 2 + 0 * 1 = 2, so the 1D index is 2. (A short Python sketch of this mapping follows this file's diff.)
*/
__device__ __forceinline__
int64_t indices_to_linear_index(int64_t* indices, int64_t total_dims, int64_t* strides, int64_t src_linear_index) {
int64_t dest_linear_index = 0;
for (int64_t i = 0; i < total_dims; i++) {
int64_t indices_i = src_linear_index * total_dims + i;
dest_linear_index += indices[indices_i] * strides[i];
}
return dest_linear_index;
}

template <typename scalar_t>
__global__
void flip_cuda_kernel(scalar_t* in_t, scalar_t* out_t, int64_t N, int64_t* dims, int64_t* indices,
int64_t flip_dims_size, int64_t* strides, int64_t* shape, int64_t total_dims) {

int64_t linear_index = blockIdx.x * blockDim.x + threadIdx.x;
if (linear_index >= N) {
return;
}

linear_index_to_indices(linear_index, strides, total_dims, indices);

// Flip nD index along each desired dimension
for (int64_t i = 0 ; i < flip_dims_size; i++) {
int64_t dim = dims[i];
int64_t indices_dim = linear_index * total_dims + dim;
indices[indices_dim] = shape[dim] - 1 - indices[indices_dim];
}
int64_t dest_linear_index = indices_to_linear_index(indices, total_dims, strides, linear_index);
out_t[linear_index] = in_t[dest_linear_index];
}

// Flip tensor given a list of dims
Tensor flip_cuda(const Tensor& self, IntList dims) {

// TODO: support non-contiguous tensors
auto in_t = self.contiguous();

int64_t flip_dims_size = dims.size(), total_dims = in_t.dim(), N = in_t.numel();

// check that the dims argument is non-empty
if (flip_dims_size == 0) {
std::stringstream ss;
ss << "expected flip dims to be non-empty, "
<< "but got flip dims size=" << flip_dims_size;
throw std::runtime_error(ss.str());
}

// check duplicates in dims (sort a copy first: std::unique only removes adjacent duplicates)
auto flip_dims_v = std::vector<int64_t>(dims);
std::sort(flip_dims_v.begin(), flip_dims_v.end());
flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end());
if ((int64_t)flip_dims_v.size() < flip_dims_size) {
std::stringstream ss;
ss << "dims has duplicates, "
<< "original flip dims size=" << flip_dims_size << ", "
<< "but unique flip dims size= " << flip_dims_v.size();
throw std::runtime_error(ss.str());
}

// check len of dims
if (flip_dims_size > total_dims) {
std::stringstream ss;
ss << "expected flip dims size <= tensor total dims, "
<< "but got flip dims size=" << flip_dims_size << " and "
<< "tensor total dim=" << total_dims;
throw std::runtime_error(ss.str());
}

// check that every flip axis is within range
int64_t min_d = total_dims, max_d = 0;
for (auto d : dims) {
min_d = std::min(min_d, d);
max_d = std::max(max_d, d);
}

if (min_d < 0) {
std::stringstream ss;
ss << "expected flip dims axis >= 0, "
<< "but got min flip dims=" << min_d;
throw std::runtime_error(ss.str());
}

if (max_d >= total_dims) {
std::stringstream ss;
ss << "expected flip dims axis < tensor total dims, "
<< "but got max flip dims=" << max_d << " and "
<< "tensor total dim=" << total_dims;
throw std::runtime_error(ss.str());
}

auto flip_dims = std::vector<int64_t>(dims);
auto flip_dims_t = at::CPU(kLong).tensorFromBlob(flip_dims.data(), {static_cast<int64_t>(flip_dims.size())});

auto shape = std::vector<int64_t>(in_t.sizes());
auto shape_t = at::CPU(kLong).tensorFromBlob(shape.data(), {static_cast<int64_t>(shape.size())});

auto strides = std::vector<int64_t>(in_t.strides());
auto strides_t = at::CPU(kLong).tensorFromBlob(strides.data(), {static_cast<int64_t>(strides.size())});

auto indices = at::zeros(CUDA(kLong), {N, total_dims});
auto out_t = at::zeros_like(in_t);

int64_t block_size = 512;
dim3 dim_block(block_size);
dim3 dim_grid((N + block_size - 1) / block_size);

AT_DISPATCH_ALL_TYPES_AND_HALF(in_t.type(), "flip_cuda", [&] {
using cuda_scalar_t = cuda::type<scalar_t>;
flip_cuda_kernel<<<dim_grid, dim_block, 0, globalContext().getCurrentCUDAStream()>>>(
in_t.data<cuda_scalar_t>(), out_t.data<cuda_scalar_t>(), N,
flip_dims_t.toType(CUDA(kLong)).data<int64_t>(), indices.data<int64_t>(), flip_dims_size,
strides_t.toType(CUDA(kLong)).data<int64_t>(), shape_t.toType(CUDA(kLong)).data<int64_t>(), total_dims);
});

return out_t;
}

}} // namespace at::native
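
A short Python sketch of the index arithmetic used by the kernel above, assuming a contiguous tensor; the helper names are illustrative and mirror linear_index_to_indices / indices_to_linear_index.

def contiguous_strides(shape):
    # stride[i] = product of the sizes of all dimensions after i
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def linear_to_nd(linear_index, strides):
    # repeated division / modulo, as in linear_index_to_indices()
    nd, rem = [], linear_index
    for s in strides:
        nd.append(rem // s)
        rem = rem % s
    return nd

def nd_to_linear(nd_index, strides):
    # sum(index[i] * stride[i]), as in indices_to_linear_index()
    return sum(i * s for i, s in zip(nd_index, strides))

shape = [3, 2, 2]                              # the 3D example from the comment above
strides = contiguous_strides(shape)            # gives (4, 2, 1)
assert linear_to_nd(2, strides) == [0, 1, 0]   # element "3" in the example
assert nd_to_linear([0, 1, 0], strides) == 2
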
9 changes: 7 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -535,13 +535,13 @@
variants: function

- func: randint_out(Tensor result, int64_t low, int64_t high, IntList size, *, Generator* generator=nullptr) -> Tensor
variants: function

- func: randint_like(Tensor self, int64_t high) -> Tensor
variants: function

- func: randint_like(Tensor self, int64_t low, int64_t high) -> Tensor
variants: function

- func: randint_like(Tensor self, int64_t high, *, Type dtype) -> Tensor
variants: function
@@ -760,6 +760,11 @@
- func: transpose_(Tensor self, int64_t dim0, int64_t dim1) -> Tensor
variants: method

- func: flip(Tensor self, IntList dims) -> Tensor
dispatch:
CPU: flip_cpu
CUDA: flip_cuda

- func: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntList expand1, IntList expand2, IntList expand3, IntList sumdim, int64_t unroll_dim=1) -> Tensor
variants: function
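
The flip entry above dispatches to flip_cpu and flip_cuda. A quick usage sketch of the resulting tensor method; the outputs follow the semantics exercised in the tests below.

import torch

x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2)
print(x.flip(0))        # reverses the outermost dimension: [5, 6, 7, 8, 1, 2, 3, 4]
print(x.flip(0, 1, 2))  # reverses every dimension: [8, 7, 6, 5, 4, 3, 2, 1]

if torch.cuda.is_available():
    print(x.cuda().flip(0, 1))  # dispatches to flip_cuda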

2 changes: 2 additions & 0 deletions test/test_autograd.py
@@ -2348,6 +2348,8 @@ class dont_convert(tuple):
('reshape', (S,), (S,), '1d'),
('reshape', (), (dont_convert(()),), 'scalar_to_scalar'),
('reshape', (), (1,), 'scalar_to_1d'),
('flip', (S, S, S), ([0],), 'd0'),
('flip', (S, S, S), ([0, 1, 2],), 'd012'),
('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
5 changes: 5 additions & 0 deletions test/test_cuda.py
@@ -396,6 +396,8 @@ def tmp(t):
('zero', small_3d, lambda t: [],),
('zeros', small_3d, lambda t: [1, 2, 3, 4],),
('eye', small_2d, lambda t: [3, 4],),
('flip', small_3d, lambda t: [0], 'd0', types, True),
('flip', small_3d, lambda t: [0, 1, 2], 'd012', types, True),
('rsqrt', lambda t: constant_tensor_add(1, small_3d(t)), lambda t: [], None, float_types),
('sinh', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types),
('tan', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types),
@@ -1309,6 +1311,9 @@ def test_det_logdet_slogdet(self):
def test_view(self):
TestTorch._test_view(self, lambda t: t.cuda())

def test_flip(self):
TestTorch._test_flip(self, use_cuda=True)

def test_fft_ifft_rfft_irfft(self):
def cuda_randn_double(*sizes):
return torch.cuda.DoubleTensor(*sizes).normal_()
36 changes: 36 additions & 0 deletions test/test_torch.py
@@ -5519,6 +5519,42 @@ def test_permute(self):
self.assertEqual(perm, new)
self.assertEqual(x.size(), orig)

@staticmethod
def _test_flip(self, use_cuda=False):

if use_cuda:
cuda = torch.device("cuda")
data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=cuda).view(2, 2, 2)
# large data testing
large_data = torch.arange(0, 100000000, device=cuda).view(10000, 10000)
large_data.flip([0, 1])
else:
data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2)

self.assertEqual(torch.tensor([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2), data.flip(0))
self.assertEqual(torch.tensor([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2), data.flip(0, 1))
self.assertEqual(torch.tensor([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2), data.flip(0, 1, 2))

# flipping the same dim more than once is not allowed
self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1))
# an empty dims list is not allowed
self.assertRaises(TypeError, lambda: data.flip())
# the number of flip dims must not exceed the number of tensor dims
self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 2, 3))
# dim < 0 is not allowed
self.assertRaises(RuntimeError, lambda: data.flip(-1))
# dim >= the number of tensor dims is not allowed
self.assertRaises(RuntimeError, lambda: data.flip(3))

# test for non-contiguous case
if use_cuda:
data_to_expand = torch.arange(1, 4, device=cuda).view(3, 1)
else:
data_to_expand = torch.arange(1, 4).view(3, 1)
self.assertEqual(torch.tensor([3, 3, 2, 2, 1, 1]).view(3, 2), data_to_expand.expand(3, 2).flip(0))

def test_flip(self):
self._test_flip(self, use_cuda=False)

def test_storage(self):
v = torch.randn(3, 5)
self.assertEqual(v.storage()[0], v.data[0][0])
3 changes: 3 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -616,6 +616,9 @@
- name: t(Tensor self)
self: grad.t()

- name: flip(Tensor self, IntList dims)
self: grad.flip(dims)

- name: take(Tensor self, Tensor index)
self: zeros_like(self).put_(index, grad, true)
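
A small numerical sanity check of the flip derivative entry above, i.e. that the gradient of flip is the upstream gradient flipped along the same dims; a sketch only, using the public autograd API.

import torch

x = torch.randn(2, 3, requires_grad=True)
y = x.flip(0, 1)
grad_out = torch.randn(2, 3)
y.backward(grad_out)
# matches the derivatives.yaml entry: self = grad.flip(dims)
assert torch.equal(x.grad, grad_out.flip(0, 1))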
