Add Unflatten Module #41564
Conversation
albanD
left a comment
Hi,
Thanks for the PR.
I actually wasn't aware that Tensor.unflatten only works with named dimensions. My bad.
This is a bit unsatisfying, as the nn.Unflatten() layer wouldn't be able to do the opposite of nn.Flatten() in many cases (only NCHW inputs would work).
I wonder whether we should let the user provide dim and unflattened_size as plain ints, so that we get the same general API as unflatten but don't have to recreate named Tensors in the middle.
One thing I could see as well is to accept either of the following as input:
- dim as a string, with the second argument a named shape.
- dim as an int, with the second argument a tuple of ints or a torch.Size.
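The dual calling convention described above can be sketched as a plain argument-validation helper (pure Python, no torch dependency; the helper name `classify_unflatten_args` is hypothetical, just to illustrate the dispatch):

```python
def classify_unflatten_args(dim, unflattened_size):
    # Return which convention the caller used, raising on
    # mixed or malformed arguments.
    if isinstance(dim, int):
        # int dim pairs with a plain shape: (3, 32, 32)
        if not all(isinstance(s, int) for s in unflattened_size):
            raise TypeError("int dim requires a tuple of ints")
        return "positional"
    if isinstance(dim, str):
        # str dim pairs with a named shape: (('C', 3), ('H', 32), ('W', 32))
        if not all(isinstance(p, tuple) and len(p) == 2 for p in unflattened_size):
            raise TypeError("str dim requires a tuple of (name, size) tuples")
        return "named"
    raise TypeError("dim must be an int or a str")

print(classify_unflatten_args(1, (3, 32, 32)))                    # positional
print(classify_unflatten_args('features', (('C', 3), ('H', 32)))) # named
```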
Thanks for your feedback. No worries about that, what you've said is very reasonable. So you're saying that you'd like something like this?

```python
def __init__(self, dim: int, unflattened_size: Size) -> None:
    super(Unflatten, self).__init__()
    self.dim = dim
    self.unflattened_size = unflattened_size

def forward(self, input: Tensor) -> Tensor:
    return input.unflatten(
        self.dim,
        (('C', self.unflattened_size[0]), ('H', self.unflattened_size[1]), ('W', self.unflattened_size[2]))
    )
```
For the arguments, yes. But your forward here would only support specific inputs. What about something like:

```python
class Unflatten(nn.Module):
    def __init__(self, dim: Union[int, str], unflattened_size: Union[tuple, Size]) -> None:
        super(Unflatten, self).__init__()
        if isinstance(dim, int):
            self.named = False
            # Make sure unflattened_size is a tuple of ints
        else:
            self.named = True
            # Make sure unflattened_size is a tuple of (name, size) tuples
        self.dim = dim
        self.unflattened_size = unflattened_size

    def forward(self, input: Tensor) -> Tensor:
        if self.named:
            return input.unflatten(self.dim, self.unflattened_size)
        else:
            dim = self.dim
            if dim < 0:
                dim += input.dim()
            inp_size = list(input.size())
            new_size = inp_size[:dim] + list(self.unflattened_size) + inp_size[dim + 1:]
            return input.view(new_size)
```

Also, can you add a simple test in test/test_nn.py that makes sure we get errors when wrong inputs are provided and that it does what we expect?
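The unnamed branch of that forward boils down to splicing the unflattened sizes into the input shape. As a quick sanity check of the index arithmetic, the shape computation can be exercised on its own (pure Python, no torch needed; `spliced_shape` is a hypothetical helper mirroring the sketch above):

```python
def spliced_shape(inp_size, dim, unflattened_size):
    # Normalize a negative dim, then replace inp_size[dim]
    # with the unflattened sizes, as the forward() sketch does
    # before calling input.view(new_size).
    if dim < 0:
        dim += len(inp_size)
    return inp_size[:dim] + list(unflattened_size) + inp_size[dim + 1:]

# Undoing nn.Flatten on a batch of flattened 3x32x32 images:
print(spliced_shape([8, 3072], 1, (3, 32, 32)))   # [8, 3, 32, 32]
print(spliced_shape([8, 3072], -1, (3, 32, 32)))  # [8, 3, 32, 32]
```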
💊 CI failures summary: as of commit 3854ba7, there are no failures (automated Dr. CI comment).
albanD
left a comment
This looks good. Just small comments for the doc/test.
Oh right, yes.
Of course. Absolutely.
```python
class Unflatten(Module):
    r"""
    Unflattens a tensor into another tensor of a desired shape. For use with :class:`~nn.Sequential`.
```
@albanD I think this should cover all possible ways to call this module. 👍
albanD
left a comment
Perfect! Thanks for making the change.
facebook-github-bot
left a comment
@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
What should we do with that check that failed, @albanD?
These are just flaky internal tests :'( I'll take care of it!
This PR implements a feature extension discussed in #41516.
I followed this other PR #22245 to add this new module. While I was at it, I also added the extra_repr() method to Flatten, which was missing. I see there are no unit tests for these modules. Should I add those too? If so, where is the best place to put them?
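For context, the extra_repr() addition mentioned above amounts to reporting the module's constructor arguments when it is printed. A minimal stand-alone sketch of the idea (plain Python stand-in class, not the actual nn.Flatten implementation):

```python
class Flatten:
    # Stripped-down stand-in for nn.Flatten, just to
    # illustrate what extra_repr() contributes.
    def __init__(self, start_dim=1, end_dim=-1):
        self.start_dim = start_dim
        self.end_dim = end_dim

    def extra_repr(self):
        # This string is rendered inside the parentheses when the
        # module is printed, e.g. Flatten(start_dim=1, end_dim=-1),
        # which makes model summaries self-describing.
        return 'start_dim={}, end_dim={}'.format(self.start_dim, self.end_dim)

print(Flatten().extra_repr())        # start_dim=1, end_dim=-1
print(Flatten(0, 2).extra_repr())    # start_dim=0, end_dim=2
```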