11 changes: 0 additions & 11 deletions aten/src/ATen/Declarations.cwrap
@@ -4066,17 +4066,6 @@
default: 0
]]

[[
name: reshape_
cname: resize
cpu_half: True
return: self
arguments:
- THTensor* self
- arg: THSize* size
- arg: THStride* stride
]]

[[
name: _sparse_mask
return: argument 0
89 changes: 89 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
@@ -132,6 +132,95 @@ Tensor repeat(const Tensor& self, IntList repeats) {
return result;
}

// Infers the size of a dim with size -1, if it exists. Also checks that the
// new shape is compatible with the number of elements.
static std::vector<int64_t> infer_size(IntList shape, int64_t numel) {
auto res = shape.vec();
int64_t newsize = 1;
auto infer_dim = at::optional<int64_t>();
for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
if (shape[dim] == -1) {
if (infer_dim) {
throw std::runtime_error("only one dimension can be inferred");
}
infer_dim = dim;
} else if (shape[dim] >= 0) {
newsize *= shape[dim];
} else {
runtime_error("invalid shape dimension %zd", shape[dim]);
}
}

if (numel == newsize || (infer_dim && newsize > 0 && numel % newsize == 0)) {
if (infer_dim) {
res[*infer_dim] = numel / newsize;
}
if (numel == 0) {
// Collapse zero-element shapes into one dimension because TH handles zeros
// in sizes strangely: x.resize_(1, 0) has shape (1,). TODO: remove this
// once we have multi-dimensional empty tensors.
return {0};
}
return res;
}

std::ostringstream ss;
ss << "shape '" << shape << "' is invalid for input of size " << numel;
throw std::runtime_error(ss.str());
}

static at::optional<std::vector<int64_t>>
compute_stride(const Tensor& self, IntList newshape) {
auto oldstride = self.strides();
auto oldshape = self.sizes();
if (oldshape.empty()) {
return std::vector<int64_t>(newshape.size(), 1);
}

std::vector<int64_t> newstride(newshape.size());
int64_t view_d = newshape.size() - 1;
// stride for each subspace in the chunk
int64_t chunk_base_stride = oldstride.back();
// numel in current chunk
int64_t tensor_numel = 1;
int64_t view_numel = 1;
for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
tensor_numel *= oldshape[tensor_d];
// at the boundary of a contiguous chunk of the tensor's shape, match it against the view dims
if ((tensor_d == 0) ||
(oldshape[tensor_d - 1] != 1 && oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
while (view_d >= 0 && (view_numel < tensor_numel || newshape[view_d] == 1)) {
newstride[view_d] = view_numel * chunk_base_stride;
view_numel *= newshape[view_d];
view_d--;
}
if (view_numel != tensor_numel) {
return {};
}
if (tensor_d > 0) {
chunk_base_stride = oldstride[tensor_d - 1];
tensor_numel = 1;
view_numel = 1;
}
}
}
if (view_d != -1) {
return {};
}
return newstride;
}

Tensor reshape(const Tensor& self, IntList proposed_shape) {
if (self.type().is_sparse()) {
runtime_error("reshape is not implemented for sparse tensors");
}
auto shape = infer_size(proposed_shape, self.numel());
if (auto stride = compute_stride(self, shape)) {
return self.as_strided(shape, *stride);
}
return at::_unsafe_view(self.clone(), shape);
}

Tensor select(const Tensor& self, int64_t dim, int64_t index) {
int64_t ndim = self.dim();
AT_ASSERT(ndim > 0, "select() cannot be applied to a 0-dim tensor.");
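The chunk-matching loop in compute_stride above is the heart of reshape's view-vs-copy decision: a view is possible only if every contiguous chunk of the source tensor can be covered exactly by a run of the requested dimensions. As a reading aid, here is a minimal Python sketch of the same algorithm (an illustrative translation, not code from this PR); it returns strides for a view with new_shape, or None when reshape must fall back to a copy:

def compute_stride_sketch(old_shape, old_stride, new_shape):
    # Scalars (empty shape) can be viewed with any unit strides.
    if not old_shape:
        return [1] * len(new_shape)
    new_stride = [0] * len(new_shape)
    view_d = len(new_shape) - 1
    chunk_base_stride = old_stride[-1]  # stride of the current contiguous chunk
    tensor_numel = 1                    # elements accumulated in the current chunk
    view_numel = 1
    for tensor_d in range(len(old_shape) - 1, -1, -1):
        tensor_numel *= old_shape[tensor_d]
        # At a chunk boundary, cover the chunk with view dimensions.
        if tensor_d == 0 or (old_shape[tensor_d - 1] != 1 and
                             old_stride[tensor_d - 1] != tensor_numel * chunk_base_stride):
            while view_d >= 0 and (view_numel < tensor_numel or new_shape[view_d] == 1):
                new_stride[view_d] = view_numel * chunk_base_stride
                view_numel *= new_shape[view_d]
                view_d -= 1
            if view_numel != tensor_numel:
                return None  # chunk cannot be expressed as a view
            if tensor_d > 0:
                chunk_base_stride = old_stride[tensor_d - 1]
                tensor_numel = 1
                view_numel = 1
    return new_stride if view_d == -1 else None

For example, compute_stride_sketch([4, 4], [4, 1], [16]) returns [1] (a contiguous tensor flattens as a view), while compute_stride_sketch([4, 4], [1, 4], [16]) returns None (a transposed tensor must be copied).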
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -378,6 +378,8 @@
- func: repeat(Tensor self, IntList repeats) -> Tensor
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.

- func: reshape(Tensor self, IntList shape) -> Tensor

- func: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale) -> (Tensor, Tensor)
variants: function
dispatch:
1 change: 1 addition & 0 deletions docs/source/tensors.rst
@@ -252,6 +252,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: renorm
.. automethod:: renorm_
.. automethod:: repeat
.. automethod:: reshape
.. automethod:: resize_
.. automethod:: resize_as_
.. automethod:: round
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -34,6 +34,7 @@ Indexing, Slicing, Joining, Mutating Ops
.. autofunction:: index_select
.. autofunction:: masked_select
.. autofunction:: nonzero
.. autofunction:: reshape
.. autofunction:: split
.. autofunction:: squeeze
.. autofunction:: stack
6 changes: 6 additions & 0 deletions test/test_autograd.py
@@ -2213,6 +2213,11 @@ class dont_convert(tuple):
('view', (S,), (S,), '1d'),
('view', (), (dont_convert(()),), 'scalar_to_scalar'),
('view', (), (1,), 'scalar_to_1d'),
('reshape', (S, S, S), (S * S, S),),
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size'),
('reshape', (S,), (S,), '1d'),
('reshape', (), (dont_convert(()),), 'scalar_to_scalar'),
('reshape', (), (1,), 'scalar_to_1d'),
('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
@@ -2748,6 +2753,7 @@ def unpack_variables(args):
'addmv_',
'addr',
'addr_',
'reshape',
'where' # argument order
}
EXCLUDE_GRADCHECK = {
25 changes: 25 additions & 0 deletions test/test_torch.py
@@ -4425,6 +4425,31 @@ def _test_view(self, cast):
def test_view(self):
TestTorch._test_view(self, lambda x: x)

def test_reshape(self):
x = torch.randn(3, 3)
self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))

y = torch.randn(4, 4, 4)[:, 0, :]
self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())

s = torch.randn(())
self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
self.assertEqual(s.reshape(-1).shape, (1,))
self.assertRaises(RuntimeError, lambda: s.reshape(2))

empty = torch.tensor([])
self.assertEqual(empty, empty.reshape(-1))
self.assertEqual(empty, empty.reshape([0]))
# TODO: fix these once we have multi-dimensional empty tensors
self.assertEqual(empty.reshape([0, 1]).shape, (0,))
self.assertEqual(empty.reshape([1, -1]).shape, (0,))
self.assertRaises(RuntimeError, lambda: empty.reshape(1))

def test_expand(self):
tensor = torch.rand(1, 8, 1)
tensor2 = torch.rand(5)
13 changes: 13 additions & 0 deletions torch/_tensor_docs.py
@@ -1290,6 +1290,19 @@ def callable(a, b) -> number

""")

add_docstr_all('reshape',
r"""
reshape(*shape) -> Tensor

Returns a tensor with the same data and number of elements as :attr:`self`,
but with the specified shape.

Args:
shape (tuple of ints or int...): the desired shape

See :func:`torch.reshape`
""")

add_docstr_all('resize_',
r"""
resize_(*sizes) -> Tensor
35 changes: 35 additions & 0 deletions torch/_torch_docs.py
@@ -4295,6 +4295,41 @@

""")

add_docstr(torch.reshape,
r"""
reshape(input, shape) -> Tensor

Returns a tensor with the same data and number of elements as :attr:`input`,
but with the specified shape. When possible, the returned tensor will be a view
of :attr:`input`. Otherwise, it will be a copy. Contiguous inputs and inputs
with compatible strides can be reshaped without copying, but you should not
depend on the copying vs. viewing behavior.

A single dimension may be -1, in which case it's inferred from the remaining
dimensions and the number of elements in :attr:`input`.

Args:
input (Tensor): the tensor to be reshaped
shape (tuple of ints): the new shape

Example::

>>> a = torch.arange(4)
>>> torch.reshape(a, (2, 2))
0 1
2 3
[torch.FloatTensor of size (2,2)]

>>> b = torch.tensor([[0, 1], [2, 3]])
>>> torch.reshape(b, (-1,))
0
1
2
3
[torch.FloatTensor of size (4,)]
""")


add_docstr(torch.round,
r"""
round(input, out=None) -> Tensor
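A short demonstration of the behavior the docstring describes — a view when strides allow, a copy otherwise, and at most one inferred dimension. This is a sketch against the API added in this PR; comparing data_ptr() is only a heuristic for detecting a view:

import torch

x = torch.randn(4, 4)
v = torch.reshape(x, (-1,))    # contiguous input: a view; -1 inferred as 16
assert v.data_ptr() == x.data_ptr()

t = x.t()                      # transposed: strides incompatible with (16,)
c = torch.reshape(t, (16,))    # falls back to a clone: a copy
assert c.data_ptr() != t.data_ptr()

# Only one dimension may be -1:
# torch.reshape(x, (-1, -1))   # RuntimeError: only one dimension can be inferred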
5 changes: 1 addition & 4 deletions torch/lib/THD/master_worker/worker/dispatch/Tensor.cpp
@@ -123,13 +123,10 @@ static void tensorNewClone(rpc::RPCMessage& raw_message) {
static void tensorResize(rpc::RPCMessage& raw_message) {
at::Tensor tensor = unpackRetrieveTensor(raw_message);
THLongStorage *size = unpackTHLongStorage(raw_message);
THLongStorage *stride = unpackTHLongStorage(raw_message);
finalize(raw_message);
at::ArrayRef<int64_t> sizeRef(size->data, size->size);
at::ArrayRef<int64_t> strideRef(stride->data, stride->size);
tensor.reshape_(sizeRef, strideRef);
tensor.resize_(sizeRef);
THLongStorage_free(size);
THLongStorage_free(stride);
}

static void tensorResizeAs(rpc::RPCMessage& raw_message) {