Skip to content

Commit b3d57ec

Browse files
committed
[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification
ghstack-source-id: 7d04f95 Pull Request resolved: #138214
1 parent ab40a39 commit b3d57ec

File tree

21 files changed

+313
-145
lines changed

21 files changed

+313
-145
lines changed

test/export/test_export.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,13 @@
8686
from torch.testing._internal.triton_utils import requires_cuda, requires_gpu
8787
from torch.testing._internal.two_tensor import TwoTensor
8888
from torch.utils._pytree import (
89-
LeafSpec,
9089
register_constant,
9190
tree_flatten,
9291
tree_map,
9392
tree_unflatten,
9493
TreeSpec,
9594
treespec_dumps,
95+
treespec_leaf,
9696
treespec_loads,
9797
)
9898

@@ -6300,7 +6300,7 @@ class MyDataClass:
63006300

63016301
dt = MyDataClass(x=3, y=4)
63026302
flat, spec = tree_flatten(dt)
6303-
self.assertTrue(spec, LeafSpec())
6303+
self.assertTrue(spec, treespec_leaf())
63046304
self.assertTrue(len(flat) == 1)
63056305

63066306
torch.export.register_dataclass(
@@ -6311,7 +6311,9 @@ class MyDataClass:
63116311
flat, spec = tree_flatten(dt)
63126312
self.assertEqual(
63136313
spec,
6314-
TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]),
6314+
TreeSpec(
6315+
MyDataClass, [["x", "y"], ["z"]], [treespec_leaf(), treespec_leaf()]
6316+
),
63156317
)
63166318
self.assertEqual(flat, [3, 4])
63176319

@@ -6344,7 +6346,7 @@ class MyOtherDataClass: # the pytree registration don't allow registering the s
63446346
TreeSpec(
63456347
MyOtherDataClass,
63466348
[["x", "y", "z"], []],
6347-
[LeafSpec(), LeafSpec(), LeafSpec()],
6349+
[treespec_leaf(), treespec_leaf(), treespec_leaf()],
63486350
),
63496351
)
63506352
self.assertEqual(flat, [3, 4, None])

test/test_pytree.py

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ class TestEnum(enum.Enum):
6565
A = auto()
6666

6767

68-
python_leafspec = python_pytree.LeafSpec()
69-
70-
7168
class TestGenericPytree(TestCase):
7269
def test_aligned_public_apis(self):
7370
public_apis = python_pytree.__all__
@@ -197,7 +194,7 @@ def test_flatten_unflatten_leaf(self, pytree):
197194
def run_test_with_leaf(leaf):
198195
values, treespec = pytree.tree_flatten(leaf)
199196
self.assertEqual(values, [leaf])
200-
self.assertEqual(treespec, pytree.LeafSpec())
197+
self.assertEqual(treespec, pytree.treespec_leaf())
201198

202199
unflattened = pytree.tree_unflatten(values, treespec)
203200
self.assertEqual(unflattened, leaf)
@@ -215,7 +212,7 @@ def run_test_with_leaf(leaf):
215212
(
216213
python_pytree,
217214
lambda tup: python_pytree.TreeSpec(
218-
tuple, None, [python_leafspec for _ in tup]
215+
tuple, None, [python_pytree.treespec_leaf() for _ in tup]
219216
),
220217
),
221218
name="python",
@@ -250,7 +247,7 @@ def run_test(tup):
250247
(
251248
python_pytree,
252249
lambda lst: python_pytree.TreeSpec(
253-
list, None, [python_leafspec for _ in lst]
250+
list, None, [python_pytree.treespec_leaf() for _ in lst]
254251
),
255252
),
256253
name="python",
@@ -286,7 +283,7 @@ def run_test(lst):
286283
lambda dct: python_pytree.TreeSpec(
287284
dict,
288285
list(dct.keys()),
289-
[python_leafspec for _ in dct.values()],
286+
[python_pytree.treespec_leaf() for _ in dct.values()],
290287
),
291288
),
292289
name="python",
@@ -327,7 +324,7 @@ def run_test(dct):
327324
lambda odict: python_pytree.TreeSpec(
328325
OrderedDict,
329326
list(odict.keys()),
330-
[python_leafspec for _ in odict.values()],
327+
[python_pytree.treespec_leaf() for _ in odict.values()],
331328
),
332329
),
333330
name="python",
@@ -371,7 +368,7 @@ def run_test(odict):
371368
lambda ddct: python_pytree.TreeSpec(
372369
defaultdict,
373370
[ddct.default_factory, list(ddct.keys())],
374-
[python_leafspec for _ in ddct.values()],
371+
[python_pytree.treespec_leaf() for _ in ddct.values()],
375372
),
376373
),
377374
name="python",
@@ -413,7 +410,7 @@ def run_test(ddct):
413410
(
414411
python_pytree,
415412
lambda deq: python_pytree.TreeSpec(
416-
deque, deq.maxlen, [python_leafspec for _ in deq]
413+
deque, deq.maxlen, [python_pytree.treespec_leaf() for _ in deq]
417414
),
418415
),
419416
name="python",
@@ -453,7 +450,7 @@ def test_flatten_unflatten_namedtuple(self, pytree):
453450
def run_test(tup):
454451
if pytree is python_pytree:
455452
expected_spec = python_pytree.TreeSpec(
456-
namedtuple, Point, [python_leafspec for _ in tup]
453+
namedtuple, Point, [python_pytree.treespec_leaf() for _ in tup]
457454
)
458455
else:
459456
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
@@ -848,16 +845,16 @@ def test_import_pytree_doesnt_import_optree(self):
848845

849846
def test_treespec_equality(self):
850847
self.assertEqual(
851-
python_pytree.LeafSpec(),
852-
python_pytree.LeafSpec(),
848+
python_pytree.treespec_leaf(),
849+
python_pytree.treespec_leaf(),
853850
)
854851
self.assertEqual(
855852
python_pytree.TreeSpec(list, None, []),
856853
python_pytree.TreeSpec(list, None, []),
857854
)
858855
self.assertEqual(
859-
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
860-
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
856+
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
857+
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
861858
)
862859
self.assertFalse(
863860
python_pytree.TreeSpec(tuple, None, [])
@@ -892,24 +889,32 @@ def test_treespec_repr(self):
892889
# python_pytree.tree_structure({})
893890
python_pytree.TreeSpec(dict, [], []),
894891
# python_pytree.tree_structure([0])
895-
python_pytree.TreeSpec(list, None, [python_leafspec]),
892+
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
896893
# python_pytree.tree_structure([0, 1])
897894
python_pytree.TreeSpec(
898895
list,
899896
None,
900-
[python_leafspec, python_leafspec],
897+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
901898
),
902899
# python_pytree.tree_structure((0, 1, 2))
903900
python_pytree.TreeSpec(
904901
tuple,
905902
None,
906-
[python_leafspec, python_leafspec, python_leafspec],
903+
[
904+
python_pytree.treespec_leaf(),
905+
python_pytree.treespec_leaf(),
906+
python_pytree.treespec_leaf(),
907+
],
907908
),
908909
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
909910
python_pytree.TreeSpec(
910911
dict,
911912
["a", "b", "c"],
912-
[python_leafspec, python_leafspec, python_leafspec],
913+
[
914+
python_pytree.treespec_leaf(),
915+
python_pytree.treespec_leaf(),
916+
python_pytree.treespec_leaf(),
917+
],
913918
),
914919
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
915920
python_pytree.TreeSpec(
@@ -919,13 +924,17 @@ def test_treespec_repr(self):
919924
python_pytree.TreeSpec(
920925
tuple,
921926
None,
922-
[python_leafspec, python_leafspec],
927+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
923928
),
924-
python_leafspec,
929+
python_pytree.treespec_leaf(),
925930
python_pytree.TreeSpec(
926931
dict,
927932
["a", "b", "c"],
928-
[python_leafspec, python_leafspec, python_leafspec],
933+
[
934+
python_pytree.treespec_leaf(),
935+
python_pytree.treespec_leaf(),
936+
python_pytree.treespec_leaf(),
937+
],
929938
),
930939
],
931940
),
@@ -938,12 +947,15 @@ def test_treespec_repr(self):
938947
tuple,
939948
None,
940949
[
941-
python_leafspec,
942-
python_leafspec,
950+
python_pytree.treespec_leaf(),
951+
python_pytree.treespec_leaf(),
943952
python_pytree.TreeSpec(
944953
list,
945954
None,
946-
[python_leafspec, python_leafspec],
955+
[
956+
python_pytree.treespec_leaf(),
957+
python_pytree.treespec_leaf(),
958+
],
947959
),
948960
],
949961
),
@@ -957,12 +969,12 @@ def test_treespec_repr(self):
957969
python_pytree.TreeSpec(
958970
list,
959971
None,
960-
[python_leafspec, python_leafspec],
972+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
961973
),
962974
python_pytree.TreeSpec(
963975
list,
964976
None,
965-
[python_leafspec, python_leafspec],
977+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
966978
),
967979
python_pytree.TreeSpec(dict, [], []),
968980
],
@@ -971,7 +983,7 @@ def test_treespec_repr(self):
971983
python_pytree.TreeSpec(
972984
python_pytree.structseq,
973985
torch.return_types.sort,
974-
[python_leafspec, python_leafspec],
986+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
975987
),
976988
],
977989
)
@@ -997,7 +1009,7 @@ def test_pytree_serialize_defaultdict_enum(self):
9971009
list,
9981010
None,
9991011
[
1000-
python_leafspec,
1012+
python_pytree.treespec_leaf(),
10011013
],
10021014
),
10031015
],
@@ -1006,7 +1018,7 @@ def test_pytree_serialize_defaultdict_enum(self):
10061018
self.assertIsInstance(serialized_spec, str)
10071019

10081020
def test_pytree_serialize_enum(self):
1009-
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec])
1021+
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_pytree.treespec_leaf()])
10101022

10111023
serialized_spec = python_pytree.treespec_dumps(spec)
10121024
self.assertIsInstance(serialized_spec, str)
@@ -1169,12 +1181,20 @@ def test_saved_serialized(self):
11691181
OrderedDict,
11701182
[1, 2, 3],
11711183
[
1172-
python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]),
1173-
python_leafspec,
1184+
python_pytree.TreeSpec(
1185+
tuple,
1186+
None,
1187+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
1188+
),
1189+
python_pytree.treespec_leaf(),
11741190
python_pytree.TreeSpec(
11751191
dict,
11761192
[4, 5, 6],
1177-
[python_leafspec, python_leafspec, python_leafspec],
1193+
[
1194+
python_pytree.treespec_leaf(),
1195+
python_pytree.treespec_leaf(),
1196+
python_pytree.treespec_leaf(),
1197+
],
11781198
),
11791199
],
11801200
)
@@ -1459,7 +1479,7 @@ def setUp(self):
14591479
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
14601480

14611481
def test_treespec_equality(self):
1462-
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
1482+
self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
14631483

14641484
def test_treespec_repr(self):
14651485
# Check that it looks sane

torch/_dynamo/polyfills/pytree.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
if TYPE_CHECKING:
1919
import builtins
20-
from collections.abc import Iterable
20+
from collections.abc import Iterable, Mapping
2121
from typing_extensions import Self
2222

2323

@@ -324,6 +324,61 @@ def unflatten(self, leaves: Iterable[Any]) -> PyTree:
324324
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
325325
return isinstance(obj, PyTreeSpec)
326326

327+
@substitute_in_graph( # type: ignore[arg-type]
328+
cxx_pytree.treespec_leaf,
329+
# We need to disable constant folding here because we want the function to reference the
330+
# PyTreeSpec class defined above, not the one in the C++ module.
331+
can_constant_fold_through=False,
332+
)
333+
def treespec_leaf() -> PyTreeSpec:
334+
return _LEAF_SPEC
335+
336+
@substitute_in_graph( # type: ignore[arg-type]
337+
cxx_pytree.treespec_tuple,
338+
# We need to disable constant folding here because we want the function to reference the
339+
# PyTreeSpec class defined above, not the one in the C++ module.
340+
can_constant_fold_through=False,
341+
)
342+
def treespec_tuple(iterable: Iterable[PyTreeSpec] = (), /) -> PyTreeSpec:
343+
children = tuple(iterable)
344+
if any(not _is_pytreespec_instance(child) for child in children):
345+
raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.")
346+
handler = optree.register_pytree_node.get(tuple, namespace="torch") # type: ignore[attr-defined]
347+
return PyTreeSpec(
348+
tuple(children),
349+
tuple,
350+
None,
351+
tuple(range(len(children))),
352+
handler.unflatten_func,
353+
)
354+
355+
@substitute_in_graph( # type: ignore[arg-type]
356+
cxx_pytree.treespec_dict,
357+
# We need to disable constant folding here because we want the function to reference the
358+
# PyTreeSpec class defined above, not the one in the C++ module.
359+
can_constant_fold_through=False,
360+
)
361+
def treespec_dict(
362+
mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
363+
/,
364+
**kwargs: PyTreeSpec,
365+
) -> PyTreeSpec:
366+
dct = dict(mapping, **kwargs)
367+
if any(not _is_pytreespec_instance(child) for child in dct.values()):
368+
raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.")
369+
370+
(
371+
children,
372+
metadata,
373+
entries,
374+
unflatten_func,
375+
) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated]
376+
dct, # type: ignore[arg-type]
377+
none_is_leaf=True,
378+
namespace="torch",
379+
)
380+
return PyTreeSpec(tuple(children), dict, metadata, entries, unflatten_func) # type: ignore[arg-type]
381+
327382
@substitute_in_graph( # type: ignore[arg-type]
328383
cxx_pytree.tree_flatten,
329384
# We need to disable constant folding here because we want the function to reference the

torch/_dynamo/variables/builder.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3507,9 +3507,7 @@ def create(tx: "InstructionTranslator", value) -> VariableTracker:
35073507
pass # failthrough to unimplemented branch
35083508
elif isinstance(value, torch.fx.graph_module.GraphModule):
35093509
return SourcelessGraphModuleVariable(value)
3510-
elif isinstance(
3511-
value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
3512-
):
3510+
elif isinstance(value, torch.utils._pytree.TreeSpec):
35133511
return UserDefinedObjectVariable(value)
35143512
elif PlacementVariable.is_placement(value):
35153513
return PlacementVariable(value)

torch/_functorch/_aot_autograd/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def set(self, spec: pytree.TreeSpec) -> None:
150150
assert spec is not None
151151
self.spec: pytree.TreeSpec = spec
152152
if self.spec.type in {tuple, list} and all(
153-
child.is_leaf() for child in spec.children_specs
153+
child.is_leaf() for child in spec.children()
154154
):
155155
self.is_simple = True
156156
if self.spec.is_leaf():

0 commit comments

Comments
 (0)