
Commit 0e55cc4

ysiraichi authored and pytorchmergebot committed
[HigherOrderOp] Flatten outputs of wrap. (#109433)
Fix: #109247

This PR flattens `wrap` outputs by inlining the `pytree.tree_flatten` function after calling the inner function.

Pull Request resolved: #109433
Approved by: https://github.com/zou3519
ghstack dependencies: #110290
1 parent f68f49c commit 0e55cc4
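The idea behind the change can be sketched in plain Python. This is a minimal illustration, not the actual Dynamo implementation: the helper names `body` and `flattened_body` are made up, while `torch.utils._pytree.tree_flatten` / `tree_unflatten` are the real utilities the PR refers to. The body's pytree output is flattened into a flat tuple of tensors inside the wrapped subgraph, and the nested structure is rebuilt outside it from the recorded `TreeSpec`; that is why the expected graphs in the tests below gain extra `getitem` nodes and the expected op counts go up.

```python
# Minimal sketch of the flattening idea; not the actual Dynamo implementation.
# `body` and `flattened_body` are hypothetical helpers used only for illustration.
import torch
import torch.utils._pytree as pytree


def body(x):
    # A wrap body with a nested (pytree) output, as in test_nested_tuple_output.
    return ((x.sin(), x.cos()),)


def flattened_body(x):
    # Flatten the pytree output into a flat list of tensors plus its TreeSpec.
    flat, spec = pytree.tree_flatten(body(x))
    return tuple(flat), spec


x = torch.randn(2, 3)
flat_out, spec = flattened_body(x)  # the subgraph now returns (sin, cos)
# Outside the subgraph, the original nesting is restored from the TreeSpec.
((a, b),) = pytree.tree_unflatten(list(flat_out), spec)
assert torch.allclose(a + b, x.sin() + x.cos())
```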

File tree

3 files changed (+133, -70 lines)


test/dynamo/test_higher_order_ops.py

Lines changed: 83 additions & 51 deletions
@@ -146,7 +146,7 @@ def _test_wrap_simple(
         func,
         args_generator,
         expected_num_wrap_args,
-        expected_opcount=1,
+        expected_opcount=2,
         return_graph=False,
     ):
         # Given a `func` that has a single call to `wrap`,
@@ -267,7 +267,7 @@ def f(x):
             f,
             default_args_generator((x,)),
             ifdynstaticdefault(2, 3),
-            expected_opcount=ifdynstaticdefault(1, 2),
+            expected_opcount=ifdynstaticdefault(2, 3),
         )
 
     def test_wrap_pytree_args_nested(self):
@@ -304,7 +304,8 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor, L_z_ : torch.Tensor)
 
         wrap_body_0 = self.wrap_body_0
         wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_, l_y_, l_z_); wrap_body_0 = l_x_ = l_y_ = l_z_ = None
-        return (wrap,)
+        getitem = wrap[0]; wrap = None
+        return (getitem,)
 
     class GraphModule(torch.nn.Module):
         def forward(self, l_x_, l_y_, l_z_):
@@ -313,7 +314,7 @@ def forward(self, l_x_, l_y_, l_z_):
             add = sin + cos; sin = cos = None
             sin_1 = l_z_.sin(); l_z_ = None
             sub = add - sin_1; add = sin_1 = None
-            return sub
+            return (sub,)
 """,
         )
 
@@ -328,7 +329,7 @@ def f(x, y):
             f,
             default_args_generator((x, y)),
             ifdynstaticdefault(2, 3),
-            expected_opcount=ifdynstaticdefault(1, 2),
+            expected_opcount=ifdynstaticdefault(2, 3),
             return_graph=True,
         )
         if torch._dynamo.config.assume_static_by_default:
@@ -341,13 +342,14 @@ def forward(self, L_x_ : torch.Tensor):
 
         wrap_body_0 = self.wrap_body_0
         wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
-        return (wrap,)
+        getitem = wrap[0]; wrap = None
+        return (getitem,)
 
     class GraphModule(torch.nn.Module):
         def forward(self, l_x_):
             view = l_x_.view(3); l_x_ = None
             add = view + 0.5; view = None
-            return add
+            return (add,)
 """,
             )
         else:
@@ -362,13 +364,14 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
 
         wrap_body_0 = self.wrap_body_0
         wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_, size); wrap_body_0 = l_x_ = size = None
-        return (wrap,)
+        getitem = wrap[0]; wrap = None
+        return (getitem,)
 
     class GraphModule(torch.nn.Module):
         def forward(self, l_x_, size):
             view = l_x_.view(size); l_x_ = size = None
             add = view + 0.5; view = None
-            return add
+            return (add,)
 """,
             )
 
@@ -443,14 +446,14 @@ def f(x):
 
         self.assertEqual(result, x + global_var)
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(cnt.op_count, 1)
+        self.assertEqual(cnt.op_count, 2)
 
         self.assertEqual(len(backend.graphs), 1)
         wrap_node = find_first_node(backend.graphs[0], wrap)
         self.assertTrue(len(wrap_node.args), 3)
 
         body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
-        self.assertEqual(op_count(body_function), 1)
+        self.assertEqual(op_count(body_function), 2)
         inner_wrap_node = find_first_node(body_function, wrap)
         self.assertTrue(len(inner_wrap_node.args), 3)
 
@@ -532,7 +535,7 @@ def f(x, y):
 
         self.assertEqual(result, x + y + x)
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(cnt.op_count, 1)
+        self.assertEqual(cnt.op_count, 2)
         self.assertEqual(len(backend.graphs), 1)
 
         # No changes to args of outer wrap
@@ -542,14 +545,14 @@ def f(x, y):
 
         # z was lifted to arg of inner wrap
         body_function = getattr(gm, wrap_node.args[0].name)
-        # addition + wrap
-        self.assertEqual(op_count(body_function), 2)
+        # addition + wrap + getitem
+        self.assertEqual(op_count(body_function), 3)
         inner_wrap_node = find_first_node(body_function, wrap)
         self.assertTrue(len(inner_wrap_node.args), 3)
 
         # Innermost body function: z was also lifted to arg
         body_function = getattr(body_function, inner_wrap_node.args[0].name)
-        self.assertEqual(op_count(body_function), 1)
+        self.assertEqual(op_count(body_function), 2)
         inner_wrap_node = find_first_node(body_function, wrap)
         self.assertTrue(len(inner_wrap_node.args), 3)
 
@@ -1048,14 +1051,14 @@ def fn(*, x, y, z=None):
         counters.clear()
         opt = torch.compile(f, backend="eager", fullgraph=True)
         opt(x, y)
-        self.assertEqual(counters["stats"]["calls_captured"], 1)
+        self.assertEqual(counters["stats"]["calls_captured"], 2)
 
         # verify that we `don't` recompile
        opt(x, y)
-        self.assertEqual(counters["stats"]["calls_captured"], 1)
+        self.assertEqual(counters["stats"]["calls_captured"], 2)
 
         output = opt(x, y, 8)
-        self.assertEqual(counters["stats"]["calls_captured"], 2)
+        self.assertEqual(counters["stats"]["calls_captured"], 4)
         self.assertEqual(output, 2 * x)
 
     def test_wrap_kwarg_default_else_branch(self):
@@ -1666,46 +1669,77 @@ def f(x):
             {".*HigherOrderOperator body's output must consist of tensors only": 1},
         )
 
-    def test_fallback_on_nested_tuple_output(self):
-        counters.clear()
-
-        backend = EagerAndRecordGraphs()
-        cnt = CompileCounterWithBackend(backend)
-
-        @torch.compile(backend=cnt)
+    def test_nested_tuple_output(self):
         def f(x):
             ((a, b),) = wrap(lambda x: ((x.sin(), x.cos()),), x)
             return a + b
 
         x = torch.randn(2, 3)
-        result = f(x)
 
-        self.assertEqual(result, x.sin() + x.cos())
-        self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(len(backend.graphs), 1)
-        wrap_node = find_first_node(backend.graphs[0], wrap)
-        self.assertTrue(len(wrap_node.args), 1)
-        body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
-        self.assertEqual(op_count(body_function), 2)
-
-    def test_fallback_on_output_with_dict(self):
-        # We can likely support this in the future, I just don't want to deal
-        # with it right now
         counters.clear()
-        cnt = CompileCounter()
+        graph = self._test_wrap_simple(
+            f, default_args_generator((x,)), 2, 4, return_graph=True
+        )
+        self.assertEqual(len(counters["graph_break"]), 0)
 
-        @torch.compile(backend=cnt)
+        if check_dynamic_shape_capture():
+            return
+
+        self.assertExpectedInline(
+            graph,
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_ : torch.Tensor):
+        l_x_ = L_x_
+
+        wrap_body_0 = self.wrap_body_0
+        wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
+        a = wrap[0]
+        b = wrap[1]; wrap = None
+
+        add = a + b; a = b = None
+        return (add,)
+
+    class GraphModule(torch.nn.Module):
+        def forward(self, l_x_):
+            child = l_x_.sin()
+            child_1 = l_x_.cos(); l_x_ = None
+            return (child, child_1)
+""",
+        )
+
+    def test_output_with_dict(self):
         def f(x):
             return wrap(lambda x: [{"a": -x}], x)
 
         x = torch.randn(3)
-        result = f(x)
-        self.assertEqual(result, [{"a": -x}])
-        self.assertEqual(cnt.frame_count, 0)
-        assert_dict_matches_regex(
-            self,
-            dict(counters["graph_break"]),
-            {".*torch.* op returned non-Tensor dict call_function": 1},
+
+        counters.clear()
+        graph = self._test_wrap_simple(
+            f, default_args_generator((x,)), 2, 2, return_graph=True
+        )
+        self.assertEqual(len(counters["graph_break"]), 0)
+
+        if check_dynamic_shape_capture():
+            return
+
+        self.assertExpectedInline(
+            graph,
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_ : torch.Tensor):
+        l_x_ = L_x_
+
+        wrap_body_0 = self.wrap_body_0
+        wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
+        getitem = wrap[0]; wrap = None
+        return (getitem,)
+
+    class GraphModule(torch.nn.Module):
+        def forward(self, l_x_):
+            child = -l_x_; l_x_ = None
+            return (child,)
+""",
         )
 
     def test_access_module_attr(self):
@@ -1805,7 +1839,7 @@ def f(x):
             return wrap(lambda x: x + y, x)
 
         x = torch.randn(3)
-        self._test_wrap_simple(f, default_args_generator((x,)), 3, expected_opcount=2)
+        self._test_wrap_simple(f, default_args_generator((x,)), 3, expected_opcount=3)
 
     def test_nested_wrap(self):
         class MockModule(torch.nn.Module):
@@ -1825,16 +1859,14 @@ def gn(x):
         def fn(x):
             return wrap(gn, x)
 
-        self._test_wrap_simple(
-            fn, default_args_generator((torch.randn(10, 10),)), 4, expected_opcount=1
-        )
+        self._test_wrap_simple(fn, default_args_generator((torch.randn(10, 10),)), 4)
 
     def test_fn_with_kwargs_in_torch_ops(self):
         def fn(x):
             return wrap(lambda z: torch.cos(input=z), x)
 
         x = torch.randn(3)
-        self._test_wrap_simple(fn, default_args_generator((x,)), 2, expected_opcount=1)
+        self._test_wrap_simple(fn, default_args_generator((x,)), 2)
 
     def test_hooks(self):
         class ToyModel(torch.nn.Module):

test/dynamo/test_subclasses.py

Lines changed: 6 additions & 5 deletions
@@ -312,19 +312,20 @@ def forward(self, L_x_ : torch.Tensor):
 
         wrap_body_0 = self.wrap_body_0
         wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
-        return (wrap,)
+        getitem = wrap[0]; wrap = None
+        return (getitem,)
 
     class GraphModule(torch.nn.Module):
         def forward(self, l_x_):
             add_ = l_x_.add_(1.0); l_x_ = None
-            return add_
+            return (add_,)
 """
-        check_count_and_graph(1, 1, 1, expected_graph)
+        check_count_and_graph(1, 2, 1, expected_graph)
 
         ff = torch.func.functionalize(f)
         ff_out = ff(t_clone)
         # frame count and op count are incremented due to re-compilation
-        check_count_and_graph(2, 2, 2, expected_graph)
+        check_count_and_graph(2, 4, 2, expected_graph)
 
         try:
             x = torch._to_functional_tensor(t_clone2)
@@ -335,7 +336,7 @@ def forward(self, l_x_):
             torch._disable_functionalization()
 
         # frame count and op count are incremented due to re-compilation
-        check_count_and_graph(3, 3, 3, expected_graph)
+        check_count_and_graph(3, 6, 3, expected_graph)
 
     def test_has_torch_function(self):
         class MyTensor:
