aten/src/ATen/WrapDimUtilsMulti.h (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@ namespace at {

constexpr size_t dim_bitset_size = 64;

static inline std::bitset<dim_bitset_size> dim_list_to_bitset(IntList dims, int64_t ndims, bool wrap_scalar=true) {
static inline std::bitset<dim_bitset_size> dim_list_to_bitset(IntList dims, int64_t ndims) {
AT_CHECK(ndims <= (int64_t) dim_bitset_size, "only tensors with up to ", dim_bitset_size, " dims are supported");
std::bitset<dim_bitset_size> seen;
for (size_t i = 0; i < dims.size(); i++) {
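For context on the helper whose signature changes above: below is a minimal standalone sketch, with hypothetical names (kDimBitsetSize, dims_to_bitset_sketch), of how a list of dims can be wrapped and collected into a fixed-size bitset. It is not the ATen implementation, and the wrap_scalar behavior removed by this change is not modeled here.

#include <bitset>
#include <cstdint>
#include <stdexcept>
#include <vector>

constexpr size_t kDimBitsetSize = 64;  // mirrors dim_bitset_size above

std::bitset<kDimBitsetSize> dims_to_bitset_sketch(const std::vector<int64_t>& dims, int64_t ndims) {
  if (ndims > static_cast<int64_t>(kDimBitsetSize)) {
    throw std::runtime_error("only tensors with up to 64 dims are supported");
  }
  std::bitset<kDimBitsetSize> seen;
  for (int64_t dim : dims) {
    int64_t wrapped = dim < 0 ? dim + ndims : dim;  // wrap negative dims, e.g. -1 -> ndims - 1
    if (wrapped < 0 || wrapped >= ndims) {
      throw std::runtime_error("dim out of range");
    }
    if (seen[wrapped]) {
      throw std::runtime_error("dim appears more than once");
    }
    seen[wrapped] = true;
  }
  return seen;
}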
aten/src/ATen/native/TensorTransformations.cpp (94 changes: 55 additions & 39 deletions)
@@ -1,4 +1,5 @@
#include "ATen/native/TensorTransformations.h"
#include "ATen/WrapDimUtilsMulti.h"

#include <ATen/NativeFunctions.h>
#include <c10/util/Exception.h>
@@ -9,51 +10,66 @@
namespace at {
namespace native {

Tensor flip_cpu(const Tensor& self, IntList dims) {
const int64_t total_dims = self.dim(), flip_dims_size = dims.size();
flip_check_errors(total_dims, flip_dims_size, dims);

auto flip_dims_v = dims.vec();
wrap_all_dims(flip_dims_v, total_dims);
std::sort(flip_dims_v.begin(), flip_dims_v.end());
auto final_indices = std::vector<at::Tensor>(total_dims);

auto indices = std::vector<at::Tensor>(flip_dims_size);
for (int64_t i = 0; i < flip_dims_size; i++) {
indices[i] = at::arange(self.size(flip_dims_v[i]) - 1, -1, -1, self.type().toScalarType(at::kLong));
// creates a meshgrid
auto temp = std::vector<int64_t>(flip_dims_size, 1);
temp[i] = indices[i].size(0);
indices[i] = indices[i].view(IntList(temp));
final_indices[flip_dims_v[i]] = indices[i];
}

// check whether the distance between any two adjacent flip dims is >= 2; if so, the output
// tensor needs to be permuted, because advanced indexing moves all non-consecutively indexed
// dims to the front of the tensor
bool to_permute = false;
int64_t first = flip_dims_v[0], second = flip_dims_v[0];
for (int64_t i = 1; i < flip_dims_size; i++) {
second = flip_dims_v[i];
if (second - first >= 2) {
to_permute = true;
break;
constexpr size_t dim_bitset_size = 64;

template <typename scalar_t>
void inline flip_cpu_kernel(
const int64_t total_dims,
const std::vector<int64_t>& stride_contiguous_v,
const std::bitset<dim_bitset_size>& flip_dims_b,
const Tensor& in_tensor,
Tensor& out_tensor
){
int64_t i;
const int64_t numel = in_tensor.numel();
const scalar_t* in_tensor_d = in_tensor.data<scalar_t>();
scalar_t* out_tensor_d = out_tensor.data<scalar_t>();
auto sizes_v = in_tensor.sizes().vec();
auto strides_v = in_tensor.strides().vec();

#pragma omp parallel for private(i) if (numel > 1000)
for (i = 0; i < numel; i++) {
int64_t cur_indices = i;
int64_t rem = 0;
int64_t dst_offset = 0;

for (int64_t d = 0; d < total_dims; d++) {
int64_t temp = cur_indices;
cur_indices = cur_indices / stride_contiguous_v[d];
rem = temp - cur_indices * stride_contiguous_v[d];
dst_offset += flip_dims_b[d] ? (sizes_v[d] - 1 - cur_indices) * strides_v[d] : cur_indices * strides_v[d];
cur_indices = rem;
}
first = second;
out_tensor_d[i] = in_tensor_d[dst_offset];
}
}
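
To make the index arithmetic above concrete, here is a standalone sketch in plain C++ (hypothetical fixed sizes, not ATen code) that flips dim 0 of a 2x3 row-major array using the same decomposition of each linear output index into per-dim coordinates via contiguous strides.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int64_t> sizes   = {2, 3};
  const std::vector<int64_t> strides = {3, 1};        // input strides (row-major here)
  const std::vector<bool> flip_dim   = {true, false}; // flip dim 0 only
  const int64_t in[6] = {1, 2, 3, 4, 5, 6};
  int64_t out[6];

  // contiguous strides of the output: stride[d] = product of sizes after dim d
  std::vector<int64_t> stride_contiguous(sizes.size());
  for (int64_t d = sizes.size() - 1; d >= 0; d--) {
    stride_contiguous[d] = (d == (int64_t)sizes.size() - 1) ? 1 : stride_contiguous[d + 1] * sizes[d + 1];
  }

  for (int64_t i = 0; i < 6; i++) {
    int64_t cur = i;
    int64_t dst_offset = 0;
    for (size_t d = 0; d < sizes.size(); d++) {
      int64_t coord = cur / stride_contiguous[d];    // coordinate along dim d
      cur -= coord * stride_contiguous[d];           // remainder feeds the next dim
      if (flip_dim[d]) coord = sizes[d] - 1 - coord; // mirror the coordinate on flipped dims
      dst_offset += coord * strides[d];              // offset into the (possibly non-contiguous) input
    }
    out[i] = in[dst_offset];
  }

  for (int64_t v : out) std::printf("%lld ", (long long)v); // prints: 4 5 6 1 2 3
  std::printf("\n");
  return 0;
}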

if (to_permute) {
// permute output tensor
auto permute_order = std::vector<int64_t>(flip_dims_v);
for (int64_t i = 0; i < total_dims; i++) {
if (std::find(flip_dims_v.begin(), flip_dims_v.end(), i) == flip_dims_v.end()) {
permute_order.emplace_back(i);
}
Tensor flip_cpu(const Tensor& self, IntList dims) {
auto in_tensor = self;
const int64_t total_dims = in_tensor.dim();
auto flip_dims_b = dim_list_to_bitset(dims, total_dims);
Tensor out_tensor = at::empty_like(in_tensor);

// create contiguous strides for input tensor
auto stride_contiguous_v = std::vector<int64_t>(total_dims);
for (int64_t i = total_dims - 1; i >= 0; i--) {
if (i == total_dims - 1) {
stride_contiguous_v[i] = 1;
} else {
stride_contiguous_v[i] = std::max<int64_t>(in_tensor.size(i + 1), 1) * stride_contiguous_v[i + 1];
}
auto out_tensor = self.index(TensorList(final_indices));
return out_tensor.permute(IntList(permute_order));
}

auto out_tensor = self.index(TensorList(final_indices));
AT_DISPATCH_ALL_TYPES(in_tensor.type(), "flip_cpu", [&] {
flip_cpu_kernel<scalar_t>(
total_dims,
stride_contiguous_v,
flip_dims_b,
in_tensor,
out_tensor
);
});

return out_tensor;
}
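
The AT_DISPATCH_ALL_TYPES call above picks the kernel instantiation at runtime based on the tensor's scalar type. A rough sketch of that idea, with hypothetical names (ScalarKind, copy_impl, dispatch_copy) rather than the real macro: switch on a runtime type tag and call a templated function with scalar_t bound to the matching C++ type.

#include <cstdint>

enum class ScalarKind { Float, Double, Long };  // stand-in for a runtime scalar-type tag

template <typename scalar_t>
void copy_impl(const void* in, void* out, int64_t numel) {
  const scalar_t* src = static_cast<const scalar_t*>(in);
  scalar_t* dst = static_cast<scalar_t*>(out);
  for (int64_t i = 0; i < numel; i++) dst[i] = src[i];  // placeholder element-wise body
}

void dispatch_copy(ScalarKind kind, const void* in, void* out, int64_t numel) {
  switch (kind) {
    case ScalarKind::Float:  copy_impl<float>(in, out, numel);   break;
    case ScalarKind::Double: copy_impl<double>(in, out, numel);  break;
    case ScalarKind::Long:   copy_impl<int64_t>(in, out, numel); break;
  }
}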

test/test_torch.py (28 changes: 14 additions & 14 deletions)
@@ -7206,14 +7206,8 @@ def test_permute(self):

@staticmethod
def _test_flip(self, use_cuda=False):
if use_cuda:
cuda = torch.device("cuda")
data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=cuda).view(2, 2, 2)
# large data testing
large_data = torch.arange(0, 100000000, device=cuda).view(10000, 10000)
large_data.flip([0, 1])
else:
data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2)
device = torch.device('cuda') if use_cuda else torch.device('cpu')
data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2)

self.assertEqual(torch.tensor([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2), data.flip(0))
self.assertEqual(torch.tensor([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2), data.flip(1))
@@ -7237,15 +7231,21 @@ def _test_flip(self, use_cuda=False):
self.assertRaises(RuntimeError, lambda: data.flip(3))

# test for non-contiguous case
if use_cuda:
expanded_data = torch.arange(1, 4, device=cuda).view(3, 1).expand(3, 2)
tranposed_data = torch.arange(1, 9, device=cuda).view(2, 2, 2).transpose(0, 1)
else:
expanded_data = torch.arange(1, 4).view(3, 1).expand(3, 2)
tranposed_data = torch.arange(1, 9).view(2, 2, 2).transpose(0, 1)
expanded_data = torch.arange(1, 4, device=device).view(3, 1).expand(3, 2)
tranposed_data = torch.arange(1, 9, device=device).view(2, 2, 2).transpose(0, 1)
self.assertEqual(torch.tensor([3, 3, 2, 2, 1, 1]).view(3, 2), expanded_data.flip(0))
self.assertEqual(torch.tensor([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2), tranposed_data.flip(0, 1, 2))

# test for shape
data = torch.randn(2, 3, 4, device=device)
size = [2, 3, 4]
test_dims = []
for i in range(1, 3):
test_dims += combinations(range(len(size)), i)

for ds in test_dims:
self.assertEqual(size, list(data.flip(ds).size()))

# test rectangular case
data = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
flip0_result = torch.tensor([[4, 5, 6], [1, 2, 3]])