Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
117 commits
Select commit Hold shift + click to select a range
055b558
[pytree] add APIs to determine a class is a namedtuple or PyStructSeq…
XuehaiPan Nov 8, 2023
262a9b4
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 8, 2023
6f6e4fc
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 11, 2023
ffd22ad
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 11, 2023
f76b0da
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 11, 2023
1cc8266
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 11, 2023
1c383c7
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 18, 2023
60b8bf7
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 18, 2023
191337c
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 18, 2023
d7efa32
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 22, 2023
1013048
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 22, 2023
803db31
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 28, 2023
cc690af
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 30, 2023
3cd31a9
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 30, 2023
5be6748
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 30, 2023
7185ec0
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 30, 2023
e96f815
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 30, 2023
b4eb8c4
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Nov 30, 2023
635856a
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Dec 1, 2023
8e44087
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Dec 1, 2023
9e0f304
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Dec 7, 2023
9c35fc4
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Dec 8, 2023
352ec51
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Dec 24, 2023
5db1b78
Update on "[pytree] add APIs to determine a class is a namedtuple or …
XuehaiPan Jan 20, 2024
e6fb4ac
Update
XuehaiPan Mar 20, 2024
51f2fd6
Update
XuehaiPan Mar 22, 2024
ad4c02e
Update
XuehaiPan Mar 24, 2024
debb991
Update
XuehaiPan Apr 21, 2024
26e9b9f
Update
XuehaiPan Apr 21, 2024
089755c
Update
XuehaiPan Jun 21, 2024
1fa4833
Update
XuehaiPan Jul 22, 2024
0ddbed0
Update
XuehaiPan Aug 12, 2024
d134d37
Update
XuehaiPan Oct 20, 2024
7b81fac
Update
XuehaiPan Oct 20, 2024
b326ca8
Update
XuehaiPan Oct 20, 2024
fc27cc0
Update
XuehaiPan Oct 20, 2024
558cc9e
Update
XuehaiPan Oct 20, 2024
9eb0d12
Update
XuehaiPan Oct 20, 2024
40241fb
Update
XuehaiPan Oct 21, 2024
b1c56ee
Update
XuehaiPan Oct 21, 2024
ce8dd05
Update
XuehaiPan Oct 21, 2024
768cb28
Update
XuehaiPan Oct 21, 2024
0915929
Update
XuehaiPan Oct 21, 2024
431db69
Update
XuehaiPan Oct 21, 2024
52f8591
Update
XuehaiPan Oct 21, 2024
dd684e8
Update
XuehaiPan Oct 21, 2024
cfc02f7
Update
XuehaiPan Oct 21, 2024
ae9abfc
Update
XuehaiPan Oct 21, 2024
493eb16
Update
XuehaiPan Oct 21, 2024
d3b9d71
Update
XuehaiPan Oct 22, 2024
7cdbae0
Update
XuehaiPan Oct 22, 2024
18292cd
Update
XuehaiPan Oct 22, 2024
27b2675
Update
XuehaiPan Oct 22, 2024
be4efe6
Update
XuehaiPan Oct 22, 2024
02720bc
Update
XuehaiPan Oct 24, 2024
3c02cfa
Update
XuehaiPan Oct 24, 2024
6bc3bc0
Update
XuehaiPan Oct 24, 2024
cb69cac
Update
XuehaiPan Oct 25, 2024
5e92f2c
Update
XuehaiPan Oct 29, 2024
f4ce844
Update
XuehaiPan Oct 29, 2024
a043f30
Update
XuehaiPan Oct 29, 2024
99df03c
Update
XuehaiPan Oct 29, 2024
f393e8e
Update
XuehaiPan Oct 30, 2024
9a583da
Update
XuehaiPan Oct 30, 2024
6c08038
Update
XuehaiPan Nov 5, 2024
79cf1d8
Update
XuehaiPan Nov 11, 2024
dcc50a3
Update
XuehaiPan Nov 17, 2024
c055f62
Update
XuehaiPan Nov 20, 2024
2a58ee9
Update
XuehaiPan Nov 20, 2024
d0f0043
Update
XuehaiPan Nov 20, 2024
538a7ca
Update
XuehaiPan Nov 20, 2024
242123d
Update
XuehaiPan Nov 20, 2024
9c9be5c
Update
XuehaiPan Nov 20, 2024
c796b81
Update
XuehaiPan Nov 20, 2024
873f357
Update
XuehaiPan Nov 21, 2024
a1784b7
Update
XuehaiPan Nov 21, 2024
7467291
Update
XuehaiPan Nov 21, 2024
926849f
Update
XuehaiPan Nov 21, 2024
ed58779
Update
XuehaiPan Nov 22, 2024
7c6a82d
Update
XuehaiPan Nov 22, 2024
317aec5
Update
XuehaiPan Nov 26, 2024
07cb1eb
Update
XuehaiPan Nov 26, 2024
d235cfd
Update
XuehaiPan Nov 27, 2024
c21ab62
Update
XuehaiPan Dec 2, 2024
b4ebd5d
Update
XuehaiPan Dec 2, 2024
c162536
Update
XuehaiPan Dec 7, 2024
0c8714d
Update
XuehaiPan Dec 9, 2024
f1e5777
Update
XuehaiPan Dec 13, 2024
51b744f
Update
XuehaiPan Jan 13, 2025
d2576e0
Update
XuehaiPan Feb 4, 2025
9f67769
Update
XuehaiPan Feb 25, 2025
174bdbd
Update
XuehaiPan Feb 25, 2025
6736b4e
Update
XuehaiPan Feb 25, 2025
58bf959
Update
XuehaiPan Feb 25, 2025
b2e0ddc
Update
XuehaiPan Feb 25, 2025
a3d53df
Update
XuehaiPan Feb 25, 2025
c5c8d77
Update
XuehaiPan Feb 25, 2025
e77c132
Update
XuehaiPan Feb 25, 2025
7398bb1
Update
XuehaiPan Feb 25, 2025
226a2ba
Update
XuehaiPan Feb 26, 2025
b90d3a6
Update
XuehaiPan Feb 26, 2025
d184fa2
Update
XuehaiPan Feb 26, 2025
4edc9e3
Update
XuehaiPan Feb 26, 2025
8338e6f
Update
XuehaiPan Feb 26, 2025
ec932cb
Update
XuehaiPan Feb 26, 2025
0d429df
Update
XuehaiPan Feb 26, 2025
de2da6d
Update
XuehaiPan Feb 28, 2025
cf3ca87
Update
XuehaiPan Mar 3, 2025
db02c48
Update
XuehaiPan Mar 6, 2025
79964c0
Update
XuehaiPan Mar 11, 2025
d2707e8
Update
XuehaiPan Mar 12, 2025
e493147
Update
XuehaiPan Mar 13, 2025
43fb2e2
Update
XuehaiPan Mar 14, 2025
7b48138
Update
XuehaiPan Mar 20, 2025
9e67c64
Update
XuehaiPan Mar 31, 2025
0b034ed
Update
XuehaiPan Mar 31, 2025
1979e85
Update
XuehaiPan Mar 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/dynamo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,7 @@ def load(cls, model, example_inputs):
# see https://github.com/pytorch/pytorch/issues/113029
example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)

if pytree._is_namedtuple_instance(example_outputs):
if pytree.is_namedtuple_instance(example_outputs):
typ = type(example_outputs)
pytree._register_namedtuple(
typ,
Expand Down
169 changes: 149 additions & 20 deletions test/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import subprocess
import sys
import time
import unittest
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
from dataclasses import dataclass
Expand Down Expand Up @@ -731,6 +732,133 @@ def test_pytree_serialize_bad_input(self, pytree_impl):
with self.assertRaises(TypeError):
pytree_impl.treespec_dumps("random_blurb")

@parametrize(
"pytree",
[
subtest(py_pytree, name="py"),
subtest(cxx_pytree, name="cxx"),
],
)
def test_is_namedtuple(self, pytree):
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])

class DirectNamedTuple2(NamedTuple):
x: int
y: int

class IndirectNamedTuple1(DirectNamedTuple1):
pass

class IndirectNamedTuple2(DirectNamedTuple2):
pass

self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1(0, 1)))
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2(0, 1)))
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1(0, 1)))
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2(0, 1)))
self.assertFalse(pytree.is_namedtuple(time.gmtime()))
self.assertFalse(pytree.is_namedtuple((0, 1)))
self.assertFalse(pytree.is_namedtuple([0, 1]))
self.assertFalse(pytree.is_namedtuple({0: 1, 1: 2}))
self.assertFalse(pytree.is_namedtuple({0, 1}))
self.assertFalse(pytree.is_namedtuple(1))

self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1))
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2))
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1))
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2))
self.assertFalse(pytree.is_namedtuple(time.struct_time))
self.assertFalse(pytree.is_namedtuple(tuple))
self.assertFalse(pytree.is_namedtuple(list))

self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple1))
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple2))
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple1))
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple2))
self.assertFalse(pytree.is_namedtuple_class(time.struct_time))
self.assertFalse(pytree.is_namedtuple_class(tuple))
self.assertFalse(pytree.is_namedtuple_class(list))

@parametrize(
"pytree",
[
subtest(py_pytree, name="py"),
subtest(cxx_pytree, name="cxx"),
],
)
def test_is_structseq(self, pytree):
class FakeStructSeq(tuple):
n_fields = 2
n_sequence_fields = 2
n_unnamed_fields = 0

__slots__ = ()
__match_args__ = ("x", "y")

def __new__(cls, sequence):
return super().__new__(cls, sequence)

@property
def x(self):
return self[0]

@property
def y(self):
return self[1]

DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])

class DirectNamedTuple2(NamedTuple):
x: int
y: int

self.assertFalse(pytree.is_structseq(FakeStructSeq((0, 1))))
self.assertTrue(pytree.is_structseq(time.gmtime()))
self.assertFalse(pytree.is_structseq(DirectNamedTuple1(0, 1)))
self.assertFalse(pytree.is_structseq(DirectNamedTuple2(0, 1)))
self.assertFalse(pytree.is_structseq((0, 1)))
self.assertFalse(pytree.is_structseq([0, 1]))
self.assertFalse(pytree.is_structseq({0: 1, 1: 2}))
self.assertFalse(pytree.is_structseq({0, 1}))
self.assertFalse(pytree.is_structseq(1))

self.assertFalse(pytree.is_structseq(FakeStructSeq))
self.assertTrue(pytree.is_structseq(time.struct_time))
self.assertFalse(pytree.is_structseq(DirectNamedTuple1))
self.assertFalse(pytree.is_structseq(DirectNamedTuple2))
self.assertFalse(pytree.is_structseq(tuple))
self.assertFalse(pytree.is_structseq(list))

self.assertFalse(pytree.is_structseq_class(FakeStructSeq))
self.assertTrue(
pytree.is_structseq_class(time.struct_time),
)
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple1))
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple2))
self.assertFalse(pytree.is_structseq_class(tuple))
self.assertFalse(pytree.is_structseq_class(list))

# torch.return_types.* are all PyStructSequence types
for cls in vars(torch.return_types).values():
if isinstance(cls, type) and issubclass(cls, tuple):
self.assertTrue(pytree.is_structseq(cls))
self.assertTrue(pytree.is_structseq_class(cls))
self.assertFalse(pytree.is_namedtuple(cls))
self.assertFalse(pytree.is_namedtuple_class(cls))

inst = cls(range(cls.n_sequence_fields))
self.assertTrue(pytree.is_structseq(inst))
self.assertTrue(pytree.is_structseq(type(inst)))
self.assertFalse(pytree.is_structseq_class(inst))
self.assertTrue(pytree.is_structseq_class(type(inst)))
self.assertFalse(pytree.is_namedtuple(inst))
self.assertFalse(pytree.is_namedtuple_class(inst))
else:
self.assertFalse(pytree.is_structseq(cls))
self.assertFalse(pytree.is_structseq_class(cls))
self.assertFalse(pytree.is_namedtuple(cls))
self.assertFalse(pytree.is_namedtuple_class(cls))


class TestPythonPytree(TestCase):
def test_deprecated_register_pytree_node(self):
Expand Down Expand Up @@ -975,9 +1103,8 @@ def test_pytree_serialize_namedtuple(self):
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1",
)

spec = py_pytree.TreeSpec(
namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
spec = py_pytree.tree_structure(Point1(1, 2))
self.assertIs(spec.type, namedtuple)
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
self.assertEqual(spec, roundtrip_spec)

Expand All @@ -990,18 +1117,28 @@ class Point2(NamedTuple):
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2",
)

spec = py_pytree.TreeSpec(
namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
spec = py_pytree.tree_structure(Point2(1, 2))
self.assertIs(spec.type, namedtuple)
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
self.assertEqual(spec, roundtrip_spec)

class Point3(Point2):
pass

py_pytree._register_namedtuple(
Point3,
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point3",
)

spec = py_pytree.tree_structure(Point3(1, 2))
self.assertIs(spec.type, namedtuple)
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
self.assertEqual(spec, roundtrip_spec)

def test_pytree_serialize_namedtuple_bad(self):
DummyType = namedtuple("DummyType", ["x", "y"])

spec = py_pytree.TreeSpec(
namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
spec = py_pytree.tree_structure(DummyType(1, 2))

with self.assertRaisesRegex(
NotImplementedError, "Please register using `_register_namedtuple`"
Expand All @@ -1020,9 +1157,7 @@ def __init__(self, x, y):
lambda xs, _: DummyType(*xs),
)

spec = py_pytree.TreeSpec(
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
spec = py_pytree.tree_structure(DummyType(1, 2))
with self.assertRaisesRegex(
NotImplementedError, "No registered serialization name"
):
Expand All @@ -1042,9 +1177,7 @@ def __init__(self, x, y):
to_dumpable_context=lambda context: "moo",
from_dumpable_context=lambda dumpable_context: None,
)
spec = py_pytree.TreeSpec(
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
spec = py_pytree.tree_structure(DummyType(1, 2))
serialized_spec = py_pytree.treespec_dumps(spec, 1)
self.assertIn("moo", serialized_spec)
roundtrip_spec = py_pytree.treespec_loads(serialized_spec)
Expand Down Expand Up @@ -1082,9 +1215,7 @@ def __init__(self, x, y):
from_dumpable_context=lambda dumpable_context: None,
)

spec = py_pytree.TreeSpec(
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
spec = py_pytree.tree_structure(DummyType(1, 2))

with self.assertRaisesRegex(
TypeError, "Object of type type is not JSON serializable"
Expand All @@ -1095,9 +1226,7 @@ def test_pytree_serialize_bad_protocol(self):
import json

Point = namedtuple("Point", ["x", "y"])
spec = py_pytree.TreeSpec(
namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
spec = py_pytree.tree_structure(Point(1, 2))
py_pytree._register_namedtuple(
Point,
serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",
Expand Down
3 changes: 2 additions & 1 deletion torch/_dynamo/polyfills/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ def _(*args: Any, **kwargs: Any) -> bool:
"structseq_fields",
):
__func = getattr(optree, __name)
substitute_in_graph(__func, can_constant_fold_through=True)(
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
__func.__python_implementation__
)
__all__ += [__name] # noqa: PLE0604
del __func
del __name

Expand Down
2 changes: 1 addition & 1 deletion torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ def serialize_treespec(self, treespec):
def store_namedtuple_fields(ts):
if ts.type is None:
return
if ts.type == namedtuple:
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NB: Due to this, previously, all tests passed in OSS while break internally. With the latest commit, this can be reverted.

serialized_type_name = pytree.SUPPORTED_SERIALIZED_TYPES[ts.context].serialized_type_name
if serialized_type_name in self.treespec_namedtuple_fields:
field_names = self.treespec_namedtuple_fields[serialized_type_name].field_names
Expand Down
12 changes: 5 additions & 7 deletions torch/autograd/forward_ad.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import os
from collections import namedtuple
from typing import Any
from typing import Any, NamedTuple, Optional

import torch

Expand Down Expand Up @@ -129,16 +128,15 @@ def make_dual(tensor, tangent, *, level=None):
return torch._VF._make_dual(tensor, tangent, level=level)


_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"])


class UnpackedDualTensor(_UnpackedDualTensor):
class UnpackedDualTensor(NamedTuple):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@XuehaiPan did you need to subclass NamedTuple? Why did we need to change the original code?

Copy link
Collaborator Author

@XuehaiPan XuehaiPan Mar 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can keep it as it is. I can revert this if that is preferred.

Subclassing typing.NamedTuple is a more modern way to create a namedtuple type. The original code was previously changed here https://github.com/pytorch/pytorch/pull/76492/files#diff-a4efad4ba18fffdfb264c21e5475997a24a743089a899f8ec1a5ff962c6738d9 3 years ago. That was a workaround to add __doc__ to a namedtuple type created by collections.namedtuple.

Copy link
Contributor

@zou3519 zou3519 Mar 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no this is good, thank you for explaining

r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor.

See :func:`unpack_dual` for more details.

"""

primal: torch.Tensor
tangent: Optional[torch.Tensor]


def unpack_dual(tensor, *, level=None):
r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient.
Expand Down
24 changes: 20 additions & 4 deletions torch/testing/_internal/composite_compliance.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,16 @@ def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs):

expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs)
expected = tree_map(fwAD.unpack_dual, expected)
expected_primals = tree_map(lambda x: x.primal, expected)
expected_tangents = tree_map(lambda x: x.tangent, expected)
expected_primals = tree_map(
lambda x: x.primal,
expected,
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
)
expected_tangents = tree_map(
lambda x: x.tangent,
expected,
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
)

# Permutations of arg and kwargs in CCT.
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
Expand Down Expand Up @@ -586,7 +594,15 @@ def unwrap(e):
return e.elem if isinstance(e, CCT) else e

actual = tree_map(fwAD.unpack_dual, actual)
actual_primals = tree_map(lambda x: unwrap(x.primal), actual)
actual_tangents = tree_map(lambda x: unwrap(x.tangent), actual)
actual_primals = tree_map(
lambda x: unwrap(x.primal),
actual,
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
)
actual_tangents = tree_map(
lambda x: unwrap(x.tangent),
actual,
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
)
assert_equal_fn(actual_primals, expected_primals, equal_nan=True)
assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True)
17 changes: 16 additions & 1 deletion torch/utils/_cxx_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,15 @@
from optree import PyTreeSpec as TreeSpec # direct import for type annotations

import torch.utils._pytree as python_pytree
from torch.utils._pytree import KeyEntry as KeyEntry
from torch.utils._pytree import (
is_namedtuple as is_namedtuple,
is_namedtuple_class as is_namedtuple_class,
is_namedtuple_instance as is_namedtuple_instance,
is_structseq as is_structseq,
is_structseq_class as is_structseq_class,
is_structseq_instance as is_structseq_instance,
KeyEntry as KeyEntry,
)


__all__ = [
Expand All @@ -39,6 +47,7 @@
"keystr",
"key_get",
"register_pytree_node",
"tree_is_leaf",
"tree_flatten",
"tree_flatten_with_path",
"tree_unflatten",
Expand All @@ -58,6 +67,12 @@
"treespec_dumps",
"treespec_loads",
"treespec_pprint",
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
]


Expand Down
Loading
Loading