-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[pytree] add APIs to determine a class is a namedtuple or PyStructSequence #113257
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
055b558
262a9b4
6f6e4fc
ffd22ad
f76b0da
1cc8266
1c383c7
60b8bf7
191337c
d7efa32
1013048
803db31
cc690af
3cd31a9
5be6748
7185ec0
e96f815
b4eb8c4
635856a
8e44087
9e0f304
9c35fc4
352ec51
5db1b78
e6fb4ac
51f2fd6
ad4c02e
debb991
26e9b9f
089755c
1fa4833
0ddbed0
d134d37
7b81fac
b326ca8
fc27cc0
558cc9e
9eb0d12
40241fb
b1c56ee
ce8dd05
768cb28
0915929
431db69
52f8591
dd684e8
cfc02f7
ae9abfc
493eb16
d3b9d71
7cdbae0
18292cd
27b2675
be4efe6
02720bc
3c02cfa
6bc3bc0
cb69cac
5e92f2c
f4ce844
a043f30
99df03c
f393e8e
9a583da
6c08038
79cf1d8
dcc50a3
c055f62
2a58ee9
d0f0043
538a7ca
242123d
9c9be5c
c796b81
873f357
a1784b7
7467291
926849f
ed58779
7c6a82d
317aec5
07cb1eb
d235cfd
c21ab62
b4ebd5d
c162536
0c8714d
f1e5777
51b744f
d2576e0
9f67769
174bdbd
6736b4e
58bf959
b2e0ddc
a3d53df
c5c8d77
e77c132
7398bb1
226a2ba
b90d3a6
d184fa2
4edc9e3
8338e6f
ec932cb
0d429df
de2da6d
cf3ca87
db02c48
79964c0
d2707e8
e493147
43fb2e2
7b48138
9e67c64
0b034ed
1979e85
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1243,7 +1243,7 @@ def serialize_treespec(self, treespec): | |
| def store_namedtuple_fields(ts): | ||
| if ts.type is None: | ||
| return | ||
| if ts.type == namedtuple: | ||
| if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type): | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NB: Due to this, previously, all tests passed in OSS while break internally. With the latest commit, this can be reverted. |
||
| serialized_type_name = pytree.SUPPORTED_SERIALIZED_TYPES[ts.context].serialized_type_name | ||
| if serialized_type_name in self.treespec_namedtuple_fields: | ||
| field_names = self.treespec_namedtuple_fields[serialized_type_name].field_names | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,6 @@ | ||
| # mypy: allow-untyped-defs | ||
| import os | ||
| from collections import namedtuple | ||
| from typing import Any | ||
| from typing import Any, NamedTuple, Optional | ||
|
|
||
| import torch | ||
|
|
||
|
|
@@ -129,16 +128,15 @@ def make_dual(tensor, tangent, *, level=None): | |
| return torch._VF._make_dual(tensor, tangent, level=level) | ||
|
|
||
|
|
||
| _UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"]) | ||
|
|
||
|
|
||
| class UnpackedDualTensor(_UnpackedDualTensor): | ||
| class UnpackedDualTensor(NamedTuple): | ||
XuehaiPan marked this conversation as resolved.
Show resolved
Hide resolved
zou3519 marked this conversation as resolved.
Show resolved
Hide resolved
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @XuehaiPan did you need to subclass NamedTuple? Why did we need to change the original code?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can keep it as it is. I can revert this if that is preferred. Subclassing
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no this is good, thank you for explaining |
||
| r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor. | ||
|
|
||
| See :func:`unpack_dual` for more details. | ||
|
|
||
| """ | ||
|
|
||
| primal: torch.Tensor | ||
| tangent: Optional[torch.Tensor] | ||
|
|
||
|
|
||
| def unpack_dual(tensor, *, level=None): | ||
| r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.