Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
f7eda41
Update
XuehaiPan Oct 5, 2024
d6de4bf
Update
XuehaiPan Oct 5, 2024
83ff2c2
Update
XuehaiPan Oct 5, 2024
2c55fce
Update
XuehaiPan Oct 5, 2024
23b3379
Update
XuehaiPan Oct 5, 2024
6edfa81
Update
XuehaiPan Oct 5, 2024
fa386bb
Update
XuehaiPan Oct 5, 2024
97d1310
Update
XuehaiPan Oct 6, 2024
1d1b4d9
Update
XuehaiPan Oct 6, 2024
5d20ac4
Update
XuehaiPan Oct 6, 2024
b730a54
Update
XuehaiPan Oct 6, 2024
a4aaf07
Update
XuehaiPan Oct 6, 2024
ddbe690
Update
XuehaiPan Oct 6, 2024
ce450b5
Update
XuehaiPan Oct 6, 2024
08d01a5
Update
XuehaiPan Oct 6, 2024
47e5bc7
Update
XuehaiPan Oct 6, 2024
4e83cb8
Update
XuehaiPan Oct 6, 2024
b9ceeed
Update
XuehaiPan Oct 7, 2024
932776b
Update
XuehaiPan Oct 7, 2024
e10ade1
Update
XuehaiPan Oct 13, 2024
11bada8
Update
XuehaiPan Oct 13, 2024
0326b4b
Update
XuehaiPan Oct 16, 2024
3c8698d
Update
XuehaiPan Oct 16, 2024
e983ed4
Update
XuehaiPan Oct 16, 2024
200004a
Update
XuehaiPan Oct 16, 2024
fb4c78e
Update
XuehaiPan Oct 16, 2024
c72cc80
Update
XuehaiPan Oct 16, 2024
0970fb7
Update
XuehaiPan Oct 16, 2024
815d030
Update
XuehaiPan Oct 16, 2024
b5f7032
Update
XuehaiPan Oct 16, 2024
801514c
Update
XuehaiPan Oct 16, 2024
8771df0
Update
XuehaiPan Oct 16, 2024
724400a
Update
XuehaiPan Oct 16, 2024
3065c51
Update
XuehaiPan Oct 17, 2024
159d4bb
Update
XuehaiPan Oct 17, 2024
ba29eed
Update
XuehaiPan Oct 25, 2024
4764114
Update
XuehaiPan Oct 26, 2024
eae8204
Update
XuehaiPan Oct 29, 2024
914de47
Update
XuehaiPan Oct 29, 2024
e1f453f
Update
XuehaiPan Oct 30, 2024
bd6a2be
Update
XuehaiPan Oct 30, 2024
d59e9b6
Update
XuehaiPan Nov 2, 2024
3f315e3
Update
XuehaiPan Nov 5, 2024
5f9ef0f
Update
XuehaiPan Nov 11, 2024
56993dc
Update
XuehaiPan Nov 17, 2024
01edb01
Update
XuehaiPan Nov 20, 2024
2711657
Update
XuehaiPan Nov 20, 2024
8807abe
Update
XuehaiPan Nov 20, 2024
bf6e61e
Update
XuehaiPan Nov 20, 2024
9812a3b
Update
XuehaiPan Nov 20, 2024
dba520b
Update
XuehaiPan Nov 20, 2024
75d327e
Update
XuehaiPan Nov 20, 2024
410d7e3
Update
XuehaiPan Nov 21, 2024
a73afbc
Update
XuehaiPan Nov 21, 2024
f3930a0
Update
XuehaiPan Nov 21, 2024
5beb07d
Update
XuehaiPan Nov 21, 2024
794dc1e
Update
XuehaiPan Nov 22, 2024
73bc5a9
Update
XuehaiPan Nov 22, 2024
63ea0fd
Update
XuehaiPan Nov 26, 2024
5076f40
Update
XuehaiPan Nov 26, 2024
cf4d14e
Update
XuehaiPan Nov 27, 2024
c20d1d4
Update
XuehaiPan Dec 2, 2024
3e962b6
Update
XuehaiPan Dec 2, 2024
452fedf
Update
XuehaiPan Dec 7, 2024
91778a9
Update
XuehaiPan Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10314,6 +10314,8 @@ def fn(x, y):

def test_pytree_tree_map(self):
implemtations = [("python", python_pytree)]
if cxx_pytree is not None:
implemtations.append(("cxx", cxx_pytree))

for name, module in implemtations:
with self.subTest(f"pytree implement: {name}"):
Expand Down
119 changes: 118 additions & 1 deletion torch/_dynamo/polyfills/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Iterable, Literal, TYPE_CHECKING
from typing_extensions import TypeIs

import torch.utils._pytree as python_pytree
from torch.utils._pytree import BUILTIN_TYPES
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES

from ..decorators import substitute_in_graph

Expand Down Expand Up @@ -200,6 +201,95 @@ def entries(self) -> list[Any]:
def entry(self, index: int) -> Any:
return self._entries[index]

def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
def helper(
treespec: PyTreeSpec,
node: PyTree,
subtrees: list[PyTree],
) -> None:
if treespec.is_leaf():
subtrees.append(node)
return

node_type = type(node)
if treespec.type not in BUILTIN_TYPES:
# Always require custom node types to match exactly
if node_type != treespec.type:
raise ValueError(
f"Type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)

children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=True,
namespace="torch",
)
if len(children) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(children)}.",
)
if metadata != treespec._metadata:
raise ValueError(
f"Node context mismatch for custom node type {treespec.type!r}.",
)
else:
# For builtin dictionary types, we allow some flexibility
# Otherwise, we require exact matches
both_standard_dict = (
treespec.type in STANDARD_DICT_TYPES
and node_type in STANDARD_DICT_TYPES
)
if not both_standard_dict and node_type != treespec.type:
raise ValueError(
f"Node type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)
if len(node) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(node)}.",
)

if both_standard_dict:
# dictionary types are compatible with each other
expected_keys = treespec.entries()
got_key_set = set(node)
expected_key_set = set(expected_keys)
if got_key_set != expected_key_set:
missing_keys = expected_key_set.difference(got_key_set)
extra_keys = got_key_set.difference(expected_key_set)
message = ""
if missing_keys:
message += f"; missing key(s): {missing_keys}"
if extra_keys:
message += f"; extra key(s): {extra_keys}"
raise ValueError(f"Node keys mismatch{message}.")
children = [node[key] for key in expected_keys]
else:
# node_type is treespec.type
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=True,
namespace="torch",
)
if (
node_type
is not deque # ignore mismatch of `maxlen` for deque
) and metadata != treespec._metadata:
raise ValueError(
f"Node metadata mismatch for node type {treespec.type!r}; "
f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
)

for subtree, subspec in zip(children, treespec._children):
helper(subspec, subtree, subtrees)

subtrees: list[PyTree] = []
helper(self, tree, subtrees)
return subtrees

def unflatten(self, leaves: Iterable[Any]) -> PyTree:
if not isinstance(leaves, (list, tuple)):
leaves = list(leaves)
Expand Down Expand Up @@ -295,3 +385,30 @@ def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
return treespec.unflatten(leaves)

__all__ += ["tree_unflatten"]

@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
def tree_map(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
return treespec.unflatten(map(func, *flat_args))

__all__ += ["tree_map"]

@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
def tree_map_(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
return tree

__all__ += ["tree_map_"]
40 changes: 25 additions & 15 deletions torch/utils/_cxx_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def tree_flatten(
The flattening order (i.e., the order of elements in the output list) is deterministic,
corresponding to a left-to-right depth-first tree traversal.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> tree_flatten(tree)
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
>>> tree_flatten(1)
Expand All @@ -306,7 +306,7 @@ def tree_flatten(
if you want to keep the keys in the insertion order.

>>> from collections import OrderedDict
>>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
>>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)])
>>> tree_flatten(tree)
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))

Expand Down Expand Up @@ -335,7 +335,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:

The inverse of :func:`tree_flatten`.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> leaves, treespec = tree_flatten(tree)
>>> tree == tree_unflatten(leaves, treespec)
True
Expand Down Expand Up @@ -365,7 +365,7 @@ def tree_iter(

See also :func:`tree_flatten`.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> list(tree_iter(tree))
[1, 2, 3, 4, None, 5]
>>> list(tree_iter(1))
Expand Down Expand Up @@ -400,7 +400,7 @@ def tree_leaves(

See also :func:`tree_flatten`.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> tree_leaves(tree)
[1, 2, 3, 4, None, 5]
>>> tree_leaves(1)
Expand Down Expand Up @@ -435,7 +435,7 @@ def tree_structure(

See also :func:`tree_flatten`.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> tree_structure(tree)
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
>>> tree_structure(1)
Expand Down Expand Up @@ -472,9 +472,9 @@ def tree_map(

See also :func:`tree_map_`.

>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
>>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)})
{'x': 8, 'y': (43, 65)}
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
>>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None})
{'x': False, 'y': (False, False), 'z': True}

If multiple inputs are given, the structure of the tree is taken from the first input;
Expand Down Expand Up @@ -572,7 +572,9 @@ def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:


@overload
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
def map_only(
__type_or_types_or_pred: Type3[T, S, U],
) -> MapOnlyFn[Fn3[T, S, U, Any]]:
...


Expand All @@ -588,12 +590,14 @@ def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:


@overload
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
def map_only(
__type_or_types_or_pred: Callable[[Any], bool],
) -> MapOnlyFn[FnAny[Any]]:
...


def map_only(
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
) -> MapOnlyFn[FnAny[Any]]:
"""
Suppose you are writing a tree_map over tensors, leaving everything
Expand Down Expand Up @@ -858,7 +862,7 @@ def broadcast_prefix(
ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
>>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
[1, 2, 3, 3]
>>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
>>> broadcast_prefix([1, 2, 3], [1, 2, {"a": 3, "b": 4, "c": (None, 5)}])
[1, 2, 3, 3, 3, 3]

Args:
Expand All @@ -873,13 +877,19 @@ def broadcast_prefix(
Returns:
A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
"""
return optree.broadcast_prefix(
result: List[Any] = []

def add_leaves(x: Any, subtree: PyTree) -> None:
subtreespec = tree_structure(subtree, is_leaf=is_leaf)
result.extend([x] * subtreespec.num_leaves)

tree_map_(
add_leaves,
prefix_tree,
full_tree,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)
return result


# Broadcasts a pytree to the provided TreeSpec and returns the flattened
Expand Down
Loading
Loading