Skip to content

Commit d65f194

Browse files
XuehaiPanpytorchmergebot
authored andcommitted
[dynamo] support operator.attrgetter and operator.itemgetter (#141122)
Pull Request resolved: #141122 Approved by: https://github.com/Skylion007, https://github.com/jansel
1 parent fb529c2 commit d65f194

File tree

4 files changed

+137
-0
lines changed

4 files changed

+137
-0
lines changed

test/dynamo/test_functions.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3071,6 +3071,44 @@ def fn():
30713071
opt_fn = torch.compile(fn, fullgraph=True)
30723072
self.assertEqual(opt_fn(), fn())
30733073

3074+
def test_attrgetter(self):
3075+
for attrs in (
3076+
("shape",),
3077+
("data.shape",),
3078+
("device", "shape"),
3079+
("device", "shape", "data.shape"),
3080+
):
3081+
with self.subTest(attrs=attrs):
3082+
3083+
def fn(x, y):
3084+
getter = operator.attrgetter(*attrs)
3085+
return getter(x), getter(y)
3086+
3087+
opt_fn = torch.compile(fullgraph=True)(fn)
3088+
3089+
x = torch.randn(3, 4)
3090+
y = torch.randn(3, 4)
3091+
self.assertEqual(opt_fn(x, y), fn(x, y))
3092+
3093+
def test_itemgetter(self):
3094+
for items in (
3095+
(0,),
3096+
(slice(1, 3),),
3097+
(0, 1),
3098+
(slice(1, 3), 0, 1),
3099+
):
3100+
with self.subTest(items=items):
3101+
3102+
def fn(x, y):
3103+
getter = operator.itemgetter(*items)
3104+
return getter(x), getter(y)
3105+
3106+
opt_fn = torch.compile(fullgraph=True)(fn)
3107+
3108+
x = torch.randn(3, 4)
3109+
y = torch.randn(3, 4)
3110+
self.assertEqual(opt_fn(x, y), fn(x, y))
3111+
30743112
def gen_random_range_args(self):
30753113
args_count = random.randint(1, 3)
30763114
args = [random.randint(-10, 10) for _ in range(args_count)]

torch/_dynamo/polyfills/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
builtins as builtins,
2222
functools as functools,
2323
itertools as itertools,
24+
operator as operator,
2425
os as os,
2526
sys as sys,
2627
)

torch/_dynamo/polyfills/loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"builtins",
1717
"functools",
1818
"itertools",
19+
"operator",
1920
"os",
2021
"sys",
2122
)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""
2+
Python polyfills for operator
3+
"""
4+
5+
from __future__ import annotations
6+
7+
import operator
8+
from typing import Any, Callable, overload, TypeVar
9+
from typing_extensions import TypeVarTuple, Unpack
10+
11+
from ..decorators import substitute_in_graph
12+
13+
14+
# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`)
15+
__all__ = ["attrgetter", "itemgetter"]
16+
17+
18+
_T = TypeVar("_T")
19+
_T1 = TypeVar("_T1")
20+
_T2 = TypeVar("_T2")
21+
_Ts = TypeVarTuple("_Ts")
22+
_U = TypeVar("_U")
23+
_U1 = TypeVar("_U1")
24+
_U2 = TypeVar("_U2")
25+
_Us = TypeVarTuple("_Us")
26+
27+
28+
@overload
29+
def attrgetter(attr: str, /) -> Callable[[Any], _U]:
30+
...
31+
32+
33+
@overload
34+
def attrgetter(
35+
attr1: str, attr2: str, /, *attrs: str
36+
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]:
37+
...
38+
39+
40+
# Reference: https://docs.python.org/3/library/operator.html#operator.attrgetter
41+
@substitute_in_graph(operator.attrgetter, is_embedded_type=True) # type: ignore[arg-type,misc]
42+
def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]:
43+
if len(attrs) == 0:
44+
raise TypeError("attrgetter expected 1 argument, got 0")
45+
46+
if any(not isinstance(attr, str) for attr in attrs):
47+
raise TypeError("attribute name must be a string")
48+
49+
def resolve_attr(obj: Any, attr: str) -> Any:
50+
for name in attr.split("."):
51+
obj = getattr(obj, name)
52+
return obj
53+
54+
if len(attrs) == 1:
55+
attr = attrs[0]
56+
57+
def getter(obj: Any) -> Any:
58+
return resolve_attr(obj, attr)
59+
60+
else:
61+
62+
def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc]
63+
return tuple(resolve_attr(obj, attr) for attr in attrs)
64+
65+
return getter
66+
67+
68+
@overload
69+
def itemgetter(item: _T, /) -> Callable[[Any], _U]:
70+
...
71+
72+
73+
@overload
74+
def itemgetter(
75+
item1: _T1, item2: _T2, /, *items: Unpack[_Ts]
76+
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]:
77+
...
78+
79+
80+
# Reference: https://docs.python.org/3/library/operator.html#operator.itemgetter
81+
@substitute_in_graph(operator.itemgetter, is_embedded_type=True) # type: ignore[arg-type,misc]
82+
def itemgetter(*items: Any) -> Callable[[Any], Any | tuple[Any, ...]]:
83+
if len(items) == 0:
84+
raise TypeError("itemgetter expected 1 argument, got 0")
85+
86+
if len(items) == 1:
87+
item = items[0]
88+
89+
def getter(obj: Any) -> Any:
90+
return obj[item]
91+
92+
else:
93+
94+
def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc]
95+
return tuple(obj[item] for item in items)
96+
97+
return getter

0 commit comments

Comments
 (0)