44 changes: 2 additions & 42 deletions aten/src/ATen/native/TensorShape.cpp
@@ -4,6 +4,7 @@
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtils.h"
#include "ATen/optional.h"
#include <TH/THTensor.hpp>

#include <algorithm>

@@ -216,53 +217,12 @@ static std::vector<int64_t> infer_size(IntList shape, int64_t numel) {
throw std::runtime_error(ss.str());
}

static at::optional<std::vector<int64_t>>
compute_stride(const Tensor& self, IntList newshape) {
auto oldstride = self.strides();
auto oldshape = self.sizes();
if (oldshape.empty()) {
return std::vector<int64_t>(newshape.size(), 1);
}

std::vector<int64_t> newstride(newshape.size());
int64_t view_d = newshape.size() - 1;
// stride for each subspace in the chunk
int64_t chunk_base_stride = oldstride.back();
// numel in current chunk
int64_t tensor_numel = 1;
int64_t view_numel = 1;
for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
tensor_numel *= oldshape[tensor_d];
// if end of tensor size chunk, check view
if ((tensor_d == 0) ||
(oldshape[tensor_d - 1] != 1 && oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
while (view_d >= 0 && (view_numel < tensor_numel || newshape[view_d] == 1)) {
newstride[view_d] = view_numel * chunk_base_stride;
view_numel *= newshape[view_d];
view_d--;
}
if (view_numel != tensor_numel) {
return {};
}
if (tensor_d > 0) {
chunk_base_stride = oldstride[tensor_d - 1];
tensor_numel = 1;
view_numel = 1;
}
}
}
if (view_d != -1) {
return {};
}
return newstride;
}

Tensor reshape(const Tensor& self, IntList proposed_shape) {
if (self.type().is_sparse()) {
AT_ERROR("reshape is not implemented for sparse tensors");
}
auto shape = infer_size(proposed_shape, self.numel());
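// If the requested shape can be expressed over the existing storage, return a strided view;
// otherwise fall back to copying the data.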
if (auto stride = compute_stride(self, shape)) {
if (auto stride = THTensor_compute_stride(self.sizes(), self.strides(), shape)) {
return self.as_strided(shape, *stride);
}
return at::_unsafe_view(self.clone(), shape);
68 changes: 68 additions & 0 deletions aten/src/TH/THTensor.cpp
@@ -36,6 +36,8 @@
#include "generic/THTensorLapack.cpp"
#include "THGenerateFloatTypes.h"

#include <numeric>

void THTensor_free(THTensor *self)
{
if(!self)
@@ -54,3 +56,69 @@ void THTensor_free(THTensor *self)
}
}
}

// At a high level,
// 1. separate oldshape into chunks of dimensions, where the dimensions are
// ``contiguous'' in each chunk, i.e., oldstride[i] = oldshape[i+1] * oldstride[i+1]
// 2. newshape must be able to be separated into the same number of chunks as oldshape was
// separated into, where each chunk of newshape has matching ``numel'', i.e., number of
// subspaces, as the corresponding chunk of oldshape.
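//
// For example, a contiguous tensor of shape (4, 3) with strides (3, 1) forms a single chunk,
// so viewing it as (2, 6) yields strides (6, 1); a transposed tensor of shape (3, 2) with
// strides (1, 3) splits into two chunks, so flattening it to (6) returns at::nullopt because
// one view dimension would have to span both chunks.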
at::optional<std::vector<int64_t>>
THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride, at::IntList newshape) {
if (oldshape.empty()) {
return std::vector<int64_t>(newshape.size(), 1);
}

// NOTE: stride is somewhat arbitrary in the numel() == 0 case;
// to match NumPy behavior we copy the strides if the size matches, otherwise
// we use the stride as if it were computed via resize.
// This could perhaps be combined with the code below, but the complexity didn't seem worth it.
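// For example, reshaping an empty tensor of shape (0, 4) to a different shape (0, 2, 2)
// falls into the resize-style branch below and yields strides (4, 2, 1).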
int64_t numel = std::accumulate(oldshape.begin(), oldshape.end(), 1, std::multiplies<int64_t>());
if (numel == 0 && oldshape.equals(newshape)) {
return std::vector<int64_t>(oldstride);
}

std::vector<int64_t> newstride(newshape.size());
if (numel == 0) {
for (int64_t view_d = newshape.size() - 1; view_d >= 0; view_d--) {
if (view_d == (int64_t)(newshape.size() - 1)) {
newstride[view_d] = 1;
} else {
newstride[view_d] = std::max<int64_t>(newshape[view_d+1], 1) * newstride[view_d+1];
}
}
return newstride;
}

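// General case (numel > 0): walk oldshape from the last dimension backwards, grouping
// dimensions into contiguous chunks and assigning strides to the matching newshape dimensions.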
int64_t view_d = newshape.size() - 1;
// stride for each subspace in the chunk
int64_t chunk_base_stride = oldstride.back();
// numel in current chunk
int64_t tensor_numel = 1;
int64_t view_numel = 1;
for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
tensor_numel *= oldshape[tensor_d];
// if end of tensor size chunk, check view
if ((tensor_d == 0) ||
(oldshape[tensor_d - 1] != 1 && oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
while (view_d >= 0 && (view_numel < tensor_numel || newshape[view_d] == 1)) {
newstride[view_d] = view_numel * chunk_base_stride;
view_numel *= newshape[view_d];
view_d--;
}
if (view_numel != tensor_numel) {
return at::nullopt;
}
if (tensor_d > 0) {
chunk_base_stride = oldstride[tensor_d - 1];
tensor_numel = 1;
view_numel = 1;
}
}
}
if (view_d != -1) {
return at::nullopt;
}
return newstride;
}
2 changes: 2 additions & 0 deletions aten/src/TH/THTensor.hpp
@@ -63,3 +63,5 @@ typedef struct THTensor
#include "THGenerateAllTypes.h"

TH_API void THTensor_free(THTensor *self);
at::optional<std::vector<int64_t>> THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride,
at::IntList newshape);
52 changes: 7 additions & 45 deletions aten/src/TH/generic/THTensor.cpp
@@ -243,58 +243,20 @@ THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, int64_t size_,
return self;
}

// Also sets new_stride if viewable.
//
// On a high level,
// 1. separate tensor->size into chunks of dimensions, where the dimensions are
// ``contiguous'' in each chunk, i.e., stride[i] = size[i+1] * stride[i+1]
// 2. view_size must be able to be separated into same number of chunks, where
// each chunk pair has matching ``numel'', i.e., number of subspaces.
static int THTensor_(isViewable)(THTensor *tensor, THLongStorage *view_size, THLongStorage *new_stride) {
// dim indices
int64_t tensor_d = tensor->_dim() - 1;
if (tensor_d < 0) {
return 1;
}
int64_t view_d = view_size->size - 1;
// stride for each subspace in the chunk
int64_t chunk_base_stride = tensor->stride[tensor_d];
// numel in current chunk
int64_t tensor_numel = 1;
int64_t view_numel = 1;
for (; tensor_d >= 0; tensor_d--) {
tensor_numel *= tensor->size[tensor_d];
// if end of tensor size chunk, check view
if ((tensor_d == 0) ||
(tensor->size[tensor_d - 1] != 1 && tensor->stride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
while (view_d >= 0 && (view_numel < tensor_numel || THLongStorage_data(view_size)[view_d] == 1)) {
THLongStorage_data(new_stride)[view_d] = view_numel * chunk_base_stride;
view_numel *= THLongStorage_data(view_size)[view_d];
view_d--;
}
if (view_numel != tensor_numel) {
return 0;
}
if (tensor_d > 0) {
chunk_base_stride = tensor->stride[tensor_d - 1];
tensor_numel = 1;
view_numel = 1;
}
}
}
// check that we iterated through all view size
return view_d == -1;
}

THTensor *THTensor_(newView)(THTensor *tensor, THLongStorage *size)
{
ptrdiff_t numel = THTensor_(nElement)(tensor);
THTensor *self = THTensor_(new)();
THLongStorage *inferred_size = THLongStorage_newInferSize(size, numel);
THLongStorage *new_stride = THLongStorage_newWithSize(size->size);
THArgCheck(THTensor_(isViewable)(tensor, inferred_size, new_stride), 2, "view size is "
auto stride = THTensor_compute_stride(at::IntList(tensor->size, tensor->dim()),
at::IntList(tensor->stride, tensor->dim()),
at::IntList(inferred_size->data<int64_t>(), inferred_size->size));
THArgCheck(stride.has_value(), 2, "view size is "
"not compatible with input tensor's size and stride (at least one dimension spans "
"across two contiguous subspaces). Call .contiguous() before .view().");
auto stride_value = *stride;
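// Copy the computed strides into a THLongStorage so THTensor_(setStorage) can consume them.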
THLongStorage *new_stride = THLongStorage_newWithSize(stride_value.size());
THLongStorage_rawCopy(new_stride, stride_value.data());
THTensor_(setStorage)(self, tensor->storage, tensor->storageOffset, inferred_size, new_stride);
THLongStorage_free(inferred_size);
THLongStorage_free(new_stride);
52 changes: 7 additions & 45 deletions aten/src/THC/generic/THCTensor.cpp
@@ -232,58 +232,20 @@ THCTensor *THCTensor_(newUnfold)(THCState *state, THCTensor *tensor, int dimensi
return self;
}

// Also sets new_stride if viewable.
//
// On a high level,
// 1. separate tensor->size into chunks of dimensions, where the dimensions are
// ``contiguous'' in each chunk, i.e., stride[i] = size[i+1] * stride[i+1]
// 2. view_size must be able to be separated into same number of chunks, where
// each chunk pair has matching ``numel'', i.e., number of subspaces.
static int THCTensor_(isViewable)(THCState *state, THCTensor *tensor, THLongStorage *view_size, THLongStorage *new_stride) {
// dim indices
int64_t tensor_d = tensor->_dim() - 1;
if (tensor_d < 0) {
return 1;
}
int64_t view_d = view_size->size - 1;
// stride for each subspace in the chunk
int64_t chunk_base_stride = tensor->stride[tensor_d];
// numel in current chunk
int64_t tensor_numel = 1;
int64_t view_numel = 1;
for (; tensor_d >= 0; tensor_d--) {
tensor_numel *= tensor->size[tensor_d];
// if end of tensor size chunk, check view
if ((tensor_d == 0) ||
(tensor->size[tensor_d - 1] != 1 && tensor->stride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
while (view_d >= 0 && (view_numel < tensor_numel || THLongStorage_data(view_size)[view_d] == 1)) {
THLongStorage_data(new_stride)[view_d] = view_numel * chunk_base_stride;
view_numel *= THLongStorage_data(view_size)[view_d];
view_d--;
}
if (view_numel != tensor_numel) {
return 0;
}
if (tensor_d > 0) {
chunk_base_stride = tensor->stride[tensor_d - 1];
tensor_numel = 1;
view_numel = 1;
}
}
}
// check that we iterated through all view size
return view_d == -1;
}

THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLongStorage *size)
{
ptrdiff_t numel = THCTensor_(nElement)(state, tensor);
THCTensor *self = THCTensor_(new)(state);
THLongStorage *inferred_size = THLongStorage_newInferSize(size, numel);
THLongStorage *new_stride = THLongStorage_newWithSize(size->size);
THArgCheck(THCTensor_(isViewable)(state, tensor, inferred_size, new_stride), 2, "View size is "
auto stride = THTensor_compute_stride(at::IntList(tensor->size, tensor->dim()),
at::IntList(tensor->stride, tensor->dim()),
at::IntList(inferred_size->data<int64_t>(), inferred_size->size));
THArgCheck(stride.has_value(), 2, "view size is "
"not compatible with input tensor's size and stride (at least one dimension spans "
"across two contiguous subspaces). Call .contiguous() before .view().");
auto stride_value = *stride;
THLongStorage *new_stride = THLongStorage_newWithSize(stride_value.size());
THLongStorage_rawCopy(new_stride, stride_value.data());
THCTensor_(setStorage)(state, self, tensor->storage, tensor->storageOffset, inferred_size, new_stride);
THLongStorage_free(inferred_size);
THLongStorage_free(new_stride);