
Commit c89c294

alvgaona authored and facebook-github-bot committed
Add Unflatten Module (#41564)
Summary: This PR implements the feature extension discussed in #41516, following the approach of #22245 to add the new module. While at it, I also added the missing `extra_repr()` method to `Flatten`. I see there are no unit tests for these modules. Should I add those too? If so, where is the best place to put them?

Pull Request resolved: #41564
Reviewed By: gchanan
Differential Revision: D22636766
Pulled By: albanD
fbshipit-source-id: f9efdefd3ffe7d9af9482087625344af8f990943
1 parent fe41558 commit c89c294
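
For context, a minimal sketch of the round trip the new module enables alongside the existing `Flatten` (the shapes mirror the examples in the diff below; the batch size of 32 is illustrative):

import torch
import torch.nn as nn

x = torch.randn(32, 2, 5, 5)
flat = nn.Flatten(start_dim=1, end_dim=-1)(x)                     # -> shape (32, 50)
restored = nn.Unflatten(dim=1, unflattened_size=(2, 5, 5))(flat)  # -> shape (32, 2, 5, 5)
assert restored.shape == x.shape  # Unflatten inverts the Flatten above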

File tree

4 files changed: +154 -2 lines changed


docs/source/nn.rst

Lines changed: 1 addition & 0 deletions

@@ -355,6 +355,7 @@ Utility functions in other modules
     nn.utils.rnn.pack_sequence
 
     nn.Flatten
+    nn.Unflatten
 
 Quantized Functions
 --------------------

test/test_nn.py

Lines changed: 45 additions & 0 deletions

@@ -8326,6 +8326,51 @@ def test_functional_grad_conv(self):
         torch.nn.grad._grad_input_padding(torch.rand(1, 2, 3), [1, 2, 5], (1,), (0,), (3,))
         self.assertEqual(len(w), 1)
 
+    def test_unflatten(self):
+        tensor_input = torch.randn(2, 50)
+
+        # Unflatten Tensor
+        unflatten = nn.Unflatten(dim=1, unflattened_size=(2, 5, 5))
+        tensor_output = unflatten(tensor_input)
+        self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5]))
+
+        # Unflatten NamedTensor
+        unflatten = nn.Unflatten(dim='features', unflattened_size=(('C', 2), ('H', 5), ('W', 5)))
+        named_tensor_input = tensor_input.refine_names('N', 'features')
+        named_tensor_output = unflatten(named_tensor_input)
+        self.assertEqual(named_tensor_output.size(), torch.Size([2, 2, 5, 5]))
+
+    def test_unflatten_invalid_arg(self):
+        # Wrong type for unflattened_size (tuple of floats)
+        with self.assertRaisesRegex(
+                TypeError,
+                r"unflattened_size must be tuple of ints, but found element of type float at pos 2"):
+            nn.Unflatten(dim=1, unflattened_size=(2, 5, 5.0))
+
+        # Wrong type for unflattened_size (tuple of lists)
+        with self.assertRaisesRegex(
+                TypeError,
+                r"unflattened_size must be tuple of tuples, but found element of type list at pos 0"):
+            nn.Unflatten(dim='features', unflattened_size=(['C', 2], ['W', 5], ['H', 5]))
+
+        # Wrong type for unflattened_size (list of ints)
+        with self.assertRaisesRegex(
+                TypeError,
+                r"unflattened_size must be a tuple of ints, but found type list"):
+            nn.Unflatten(dim=1, unflattened_size=[2, 5, 5])
+
+        # Wrong type for unflattened_size (list of lists)
+        with self.assertRaisesRegex(
+                TypeError,
+                r"unflattened_size must be a tuple of tuples, but found type list"):
+            nn.Unflatten(dim='features', unflattened_size=[['C', 2], ['W', 5], ['H', 5]])
+
 
 class TestNNInit(TestCase):
     def setUp(self):
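
Outside the test harness, the named-tensor path exercised above behaves as follows; a sketch grounded in the test (the printed values are what the assertions imply, assuming the named-tensor API of this release):

import torch
import torch.nn as nn

t = torch.randn(2, 50).refine_names('N', 'features')
m = nn.Unflatten(dim='features', unflattened_size=(('C', 2), ('H', 5), ('W', 5)))
out = m(t)
print(out.shape)  # torch.Size([2, 2, 5, 5])
print(out.names)  # ('N', 'C', 'H', 'W'), taken from the namedshape passed above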

torch/nn/modules/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -30,7 +30,7 @@
 from .adaptive import AdaptiveLogSoftmaxWithLoss
 from .transformer import TransformerEncoder, TransformerDecoder, \
     TransformerEncoderLayer, TransformerDecoderLayer, Transformer
-from .flatten import Flatten
+from .flatten import Flatten, Unflatten
 
 __all__ = [
     'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
@@ -54,5 +54,5 @@
     'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
     'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
     'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
-    'Flatten', 'Hardsigmoid', 'Hardswish', 'SiLU',
+    'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU',
 ]
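
The re-export above is what makes the module reachable from the package root; a quick sanity check (nothing here beyond what the diff exports):

import torch.nn as nn
from torch.nn import Unflatten  # now importable alongside Flatten

assert nn.Unflatten is Unflatten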

torch/nn/modules/flatten.py

Lines changed: 106 additions & 0 deletions

@@ -1,6 +1,9 @@
 from .module import Module
 
+from typing import Union
 from torch import Tensor
+from torch import Size
+
 
 class Flatten(Module):
     r"""
@@ -31,3 +34,106 @@ def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
 
     def forward(self, input: Tensor) -> Tensor:
         return input.flatten(self.start_dim, self.end_dim)
+
+    def extra_repr(self) -> str:
+        return 'start_dim={}, end_dim={}'.format(
+            self.start_dim, self.end_dim
+        )
+
+
+class Unflatten(Module):
+    r"""
+    Unflattens a tensor dim, expanding it to a desired shape. For use with :class:`~nn.Sequential`.
+
+    * :attr:`dim` specifies the dimension of the input tensor to be unflattened; it is
+      an `int` for a `Tensor` and a `str` for a `NamedTensor`.
+
+    * :attr:`unflattened_size` is the new shape of the unflattened dimension; it can be
+      a `namedshape` (a `tuple` of `(name, size)` tuples) if :attr:`dim` is a `str`, or a
+      `tuple` of ints or a `torch.Size` if :attr:`dim` is an `int`.
+
+    Shape:
+        - Input: :math:`(N, *dims)`
+        - Output: :math:`(N, *dims_{out})`, where the :attr:`dim` entry of the input shape
+          is replaced by :attr:`unflattened_size`
+
+    Args:
+        dim (Union[int, str]): Dimension to be unflattened
+        unflattened_size (Union[tuple, torch.Size]): New shape of the unflattened dimension
+
+    Examples:
+        >>> input = torch.randn(2, 50)
+        >>> # With tuple of ints
+        >>> m = nn.Sequential(
+        >>>     nn.Linear(50, 50),
+        >>>     nn.Unflatten(1, (2, 5, 5))
+        >>> )
+        >>> output = m(input)
+        >>> output.size()
+        torch.Size([2, 2, 5, 5])
+        >>> # With torch.Size
+        >>> m = nn.Sequential(
+        >>>     nn.Linear(50, 50),
+        >>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
+        >>> )
+        >>> output = m(input)
+        >>> output.size()
+        torch.Size([2, 2, 5, 5])
+        >>> # With namedshape (tuple of tuples)
+        >>> input = torch.randn(2, 50, names=('N', 'features'))
+        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
+        >>> output = unflatten(input)
+        >>> output.size()
+        torch.Size([2, 2, 5, 5])
+    """
+    __constants__ = ['dim', 'unflattened_size']
+    dim: Union[int, str]
+    unflattened_size: Union[tuple, Size]
+
+    def __init__(self, dim: Union[int, str], unflattened_size: Union[tuple, Size]) -> None:
+        super(Unflatten, self).__init__()
+        if isinstance(dim, int):
+            self._require_tuple_int(unflattened_size)
+            self.named = False
+        else:
+            self._require_tuple_tuple(unflattened_size)
+            self.named = True
+
+        self.dim = dim
+        self.unflattened_size = unflattened_size
+
+    def _require_tuple_tuple(self, input):
+        if isinstance(input, tuple):
+            for idx, elem in enumerate(input):
+                if not isinstance(elem, tuple):
+                    raise TypeError("unflattened_size must be tuple of tuples, " +
+                                    "but found element of type {} at pos {}".format(type(elem).__name__, idx))
+            return
+        raise TypeError("unflattened_size must be a tuple of tuples, " +
+                        "but found type {}".format(type(input).__name__))
+
+    def _require_tuple_int(self, input):
+        if isinstance(input, tuple):
+            for idx, elem in enumerate(input):
+                if not isinstance(elem, int):
+                    raise TypeError("unflattened_size must be tuple of ints, " +
+                                    "but found element of type {} at pos {}".format(type(elem).__name__, idx))
+            return
+        raise TypeError("unflattened_size must be a tuple of ints, but found type {}".format(type(input).__name__))
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self.named:
+            return input.unflatten(self.dim, self.unflattened_size)
+        else:
+            # Resolve a negative dim, then splice the new sizes into the shape.
+            dim = int(self.dim)
+            if dim < 0:
+                dim += input.dim()
+            inp_size = list(input.size())
+            new_size = inp_size[:dim] + list(self.unflattened_size) + inp_size[dim + 1:]
+            return input.view(new_size)
+
+    def extra_repr(self) -> str:
+        return 'dim={}, unflattened_size={}'.format(
+            self.dim, self.unflattened_size
+        )
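
The two `extra_repr()` methods added here are what make these modules print usefully inside a model, and for an `int` dim the `forward` above reduces to a plain `view`. A sketch of both behaviors (the `Linear` line in the printed output follows its standard repr):

import torch
import torch.nn as nn

m = nn.Sequential(nn.Linear(50, 50), nn.Unflatten(1, (2, 5, 5)))
print(m)
# Sequential(
#   (0): Linear(in_features=50, out_features=50, bias=True)
#   (1): Unflatten(dim=1, unflattened_size=(2, 5, 5))
# )

# For an int dim, forward is equivalent to splicing the new sizes into the shape:
x = torch.randn(2, 50)
assert nn.Unflatten(1, (2, 5, 5))(x).equal(x.view(2, 2, 5, 5))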

0 commit comments