Skip to content

Commit 52cc073

Browse files
vishwakftwfacebook-github-bot
authored andcommitted
Implement reshape_as (#9452)
Summary: 1. Added tests 2. Added doc string 3. Remove view_as redundant definition from tensor.py Closes #9416 Pull Request resolved: #9452 Differential Revision: D8851794 Pulled By: ezyang fbshipit-source-id: 0aa0430dd0a174e1a5caddbc50a7e2c9eb7802bc
1 parent 11fc16d commit 52cc073

File tree

8 files changed

+59
-13
lines changed

8 files changed

+59
-13
lines changed

aten/src/ATen/native/TensorShape.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,10 @@ Tensor reshape(const Tensor& self, IntList proposed_shape) {
265265
return at::_unsafe_view(self.clone(), shape);
266266
}
267267

268+
Tensor reshape_as(const Tensor& self, const Tensor& other) {
269+
return self.reshape(other.sizes());
270+
}
271+
268272
Tensor select(const Tensor& self, int64_t dim, int64_t index) {
269273
int64_t ndim = self.dim();
270274
AT_CHECK(ndim > 0, "select() cannot be applied to a 0-dim tensor.");

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,9 @@
10961096

10971097
- func: reshape(Tensor self, IntList shape) -> Tensor
10981098

1099+
- func: reshape_as(Tensor self, Tensor other) -> Tensor
1100+
variants: method
1101+
10991102
- func: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale) -> (Tensor, Tensor)
11001103
variants: function
11011104
dispatch:

docs/source/tensors.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ view of a storage and defines numeric operations on it.
329329
.. automethod:: repeat
330330
.. automethod:: requires_grad_
331331
.. automethod:: reshape
332+
.. automethod:: reshape_as
332333
.. automethod:: resize_
333334
.. automethod:: resize_as_
334335
.. automethod:: round

test/test_autograd.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2631,6 +2631,9 @@ class dont_convert(tuple):
26312631
('reshape', (S,), (S,), '1d'),
26322632
('reshape', (), (dont_convert(()),), 'scalar_to_scalar'),
26332633
('reshape', (), (1,), 'scalar_to_1d'),
2634+
('reshape_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
2635+
('reshape_as', (), (non_differentiable(torch.tensor(42.)),), 'scalar'),
2636+
('reshape_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
26342637
('flip', (S, S, S), ([0],), 'd0'),
26352638
('flip', (S, S, S), ([0, 1, 2],), 'd012'),
26362639
('flip', (S, S, S), ([0, 2],), 'd02'),

test/test_torch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6031,6 +6031,11 @@ def test_reshape(self):
60316031
self.assertEqual(empty.reshape([1, -1]).shape, (0,))
60326032
self.assertRaises(RuntimeError, lambda: empty.reshape(1))
60336033

6034+
x = torch.randn(3, 3)
6035+
self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
6036+
self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
6037+
self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10)))
6038+
60346039
@skipIfNoZeroSize
60356040
def test_empty_reshape(self):
60366041
x = torch.randn(0, 6)

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@
579579
self: repeat_backward(grad, self.dim(), repeats)
580580

581581
# DO NOT define a backward for reshape!
582-
# reshape is special in that it sometimes returns a view, and somtimes not.
582+
# reshape is special in that it sometimes returns a view, and sometimes not.
583583
# Defining a backward will make codegen spit out the forward call as
584584
# as_variable(baseType->reshape(self)),
585585
# making it impossible (hard) to detect when it is actually a view.

torch/_tensor_docs.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,6 +1666,20 @@ def callable(a, b) -> number
16661666
See :func:`torch.reshape`
16671667
""")
16681668

1669+
add_docstr_all('reshape_as',
1670+
r"""
1671+
reshape_as(other) -> Tensor
1672+
1673+
Returns this tensor as the same shape as :attr:`other`.
1674+
``self.reshape_as(other)`` is equivalent to ``self.reshape(other.sizes())``.
1675+
1676+
Please see :meth:`~Tensor.reshape` for more information about ``reshape``.
1677+
1678+
Args:
1679+
other (:class:`torch.Tensor`): The result tensor has the same shape
1680+
as :attr:`other`.
1681+
""")
1682+
16691683
add_docstr_all('resize_',
16701684
r"""
16711685
resize_(*sizes) -> Tensor
@@ -2407,6 +2421,20 @@ def callable(a, b) -> number
24072421
24082422
""")
24092423

2424+
add_docstr_all('view_as',
2425+
r"""
2426+
view_as(other) -> Tensor
2427+
2428+
View this tensor as the same size as :attr:`other`.
2429+
``self.view_as(other)`` is equivalent to ``self.view(other.size())``.
2430+
2431+
Please see :meth:`~Tensor.view` for more information about ``view``.
2432+
2433+
Args:
2434+
other (:class:`torch.Tensor`): The result tensor has the same size
2435+
as :attr:`other`.
2436+
""")
2437+
24102438
add_docstr_all('expand',
24112439
r"""
24122440
expand(*sizes) -> Tensor
@@ -2445,6 +2473,20 @@ def callable(a, b) -> number
24452473
[ 3, 3, 3, 3]])
24462474
""")
24472475

2476+
add_docstr_all('expand_as',
2477+
r"""
2478+
expand_as(other) -> Tensor
2479+
2480+
Expand this tensor to the same size as :attr:`other`.
2481+
``self.expand_as(other)`` is equivalent to ``self.expand(other.size())``.
2482+
2483+
Please see :meth:`~Tensor.expand` for more information about ``expand``.
2484+
2485+
Args:
2486+
other (:class:`torch.Tensor`): The result tensor has the same size
2487+
as :attr:`other`.
2488+
""")
2489+
24482490
add_docstr_all('zero_',
24492491
r"""
24502492
zero_() -> Tensor

torch/tensor.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -219,18 +219,6 @@ def share_memory_(self):
219219
self.storage().share_memory_()
220220
return self
221221

222-
def view_as(self, tensor):
223-
r"""view_as(other) -> Tensor
224-
225-
View this tensor as the same size as :attr:`other`.
226-
``self.view_as(other)`` is equivalent to ``self.view(other.size())``.
227-
228-
Args:
229-
other (:class:`torch.Tensor`): The result tensor has the same size
230-
as :attr:`other.size()`.
231-
"""
232-
return self.view(tensor.size())
233-
234222
def __reversed__(self):
235223
r"""Reverses the tensor along dimension 0."""
236224
if self.dim() == 0:

0 commit comments

Comments
 (0)