@@ -146,7 +146,7 @@ def _test_wrap_simple(
         func,
         args_generator,
         expected_num_wrap_args,
-        expected_opcount=1,
+        expected_opcount=2,
         return_graph=False,
     ):
         # Given a `func` that has a single call to `wrap`,
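The new default of 2 reflects the tuple-output convention exercised throughout this file: a single `wrap` call in the captured graph is now paired with a `getitem` node that unpacks its result. A minimal sketch of that accounting, using a hypothetical recording backend (`recording_backend` is not part of this PR):

```python
# Hedged sketch: count the call_function nodes Dynamo records for a single
# wrap() call. Under the tuple-output convention these tests assume, the
# graph holds the wrap call plus a getitem, i.e. the new default opcount of 2.
import torch
from torch._higher_order_ops.wrap import wrap

recorded = []

def recording_backend(gm, example_inputs):  # hypothetical recording backend
    recorded.append(gm)
    return gm.forward

@torch.compile(backend=recording_backend, fullgraph=True)
def g(x):
    return wrap(lambda t: t.sin(), x)

g(torch.randn(3))
ops = [n for n in recorded[0].graph.nodes if n.op == "call_function"]
print(len(ops))  # expected: 2 (wrap + operator.getitem)
```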
@@ -267,7 +267,7 @@ def f(x):
             f,
             default_args_generator((x,)),
             ifdynstaticdefault(2, 3),
-            expected_opcount=ifdynstaticdefault(1, 2),
+            expected_opcount=ifdynstaticdefault(2, 3),
         )

     def test_wrap_pytree_args_nested(self):
@@ -304,7 +304,8 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor, L_z_ : torch.Tensor)
 
         wrap_body_0 = self.wrap_body_0
         wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_, l_y_, l_z_); wrap_body_0 = l_x_ = l_y_ = l_z_ = None
-        return (wrap,)
+        getitem = wrap[0]; wrap = None
+        return (getitem,)
 
     class GraphModule(torch.nn.Module):
         def forward(self, l_x_, l_y_, l_z_):
@@ -313,7 +314,7 @@ def forward(self, l_x_, l_y_, l_z_):
             add = sin + cos; sin = cos = None
             sin_1 = l_z_.sin(); l_z_ = None
             sub = add - sin_1; add = sin_1 = None
-            return sub
+            return (sub,)
 """,
         )
 
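An eager-level sketch of the convention the expected graph above encodes (assuming `wrap` in eager simply invokes its body, as the tests elsewhere rely on): the body returns a tuple, and the caller indexes into the result, which is what the `getitem` node in the printout corresponds to.

```python
# Hedged eager-level sketch of the printout above: the wrap body returns a
# one-element tuple, and the caller unpacks it by index (the getitem step).
import torch
from torch._higher_order_ops.wrap import wrap

x, y, z = torch.randn(2, 3), torch.randn(2, 3), torch.randn(2, 3)

out = wrap(lambda a, b, c: (a.sin() + b.cos() - c.sin(),), x, y, z)
sub = out[0]  # mirrors: getitem = wrap[0]
print(torch.allclose(sub, x.sin() + y.cos() - z.sin()))
```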
@@ -328,7 +329,7 @@ def f(x, y):
             f,
             default_args_generator((x, y)),
             ifdynstaticdefault(2, 3),
-            expected_opcount=ifdynstaticdefault(1, 2),
+            expected_opcount=ifdynstaticdefault(2, 3),
             return_graph=True,
         )
         if torch._dynamo.config.assume_static_by_default:
@@ -341,13 +342,14 @@ def forward(self, L_x_ : torch.Tensor):
 
         wrap_body_0 = self.wrap_body_0
         wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
-        return (wrap,)
+        getitem = wrap[0]; wrap = None
+        return (getitem,)
 
     class GraphModule(torch.nn.Module):
         def forward(self, l_x_):
             view = l_x_.view(3); l_x_ = None
             add = view + 0.5; view = None
-            return add
+            return (add,)
 """,
             )
         else:
@@ -362,13 +364,14 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
 
         wrap_body_0 = self.wrap_body_0
         wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_, size); wrap_body_0 = l_x_ = size = None
-        return (wrap,)
+        getitem = wrap[0]; wrap = None
+        return (getitem,)
 
     class GraphModule(torch.nn.Module):
         def forward(self, l_x_, size):
             view = l_x_.view(size); l_x_ = size = None
             add = view + 0.5; view = None
-            return add
+            return (add,)
 """,
             )
 
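For readers unfamiliar with `ifdynstaticdefault`, a sketch of its assumed semantics (the real helper lives in the test utilities; this is an illustration, not its implementation): pick the first expectation when Dynamo assumes static shapes by default, the second otherwise, since dynamic tracing lifts a symbolic size into the graph and adds an argument and an op.

```python
# Hedged sketch (assumed semantics, hypothetical name): choose the static
# expectation unless Dynamo is tracing with dynamic shapes, in which case
# the lifted symbolic size accounts for the larger arg/op counts above.
import torch

def ifdynstaticdefault_sketch(static_val, dynamic_val):
    if torch._dynamo.config.assume_static_by_default:
        return static_val
    return dynamic_val
```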
@@ -443,14 +446,14 @@ def f(x):
 
         self.assertEqual(result, x + global_var)
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(cnt.op_count, 1)
+        self.assertEqual(cnt.op_count, 2)
 
         self.assertEqual(len(backend.graphs), 1)
         wrap_node = find_first_node(backend.graphs[0], wrap)
         self.assertTrue(len(wrap_node.args), 3)
 
         body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
-        self.assertEqual(op_count(body_function), 1)
+        self.assertEqual(op_count(body_function), 2)
         inner_wrap_node = find_first_node(body_function, wrap)
         self.assertTrue(len(inner_wrap_node.args), 3)
 
@@ -532,7 +535,7 @@ def f(x, y):
 
         self.assertEqual(result, x + y + x)
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(cnt.op_count, 1)
+        self.assertEqual(cnt.op_count, 2)
         self.assertEqual(len(backend.graphs), 1)
 
         # No changes to args of outer wrap
@@ -542,14 +545,14 @@ def f(x, y):
 
         # z was lifted to arg of inner wrap
         body_function = getattr(gm, wrap_node.args[0].name)
-        # addition + wrap
-        self.assertEqual(op_count(body_function), 2)
+        # addition + wrap + getitem
+        self.assertEqual(op_count(body_function), 3)
         inner_wrap_node = find_first_node(body_function, wrap)
         self.assertTrue(len(inner_wrap_node.args), 3)
 
         # Innermost body function: z was also lifted to arg
         body_function = getattr(body_function, inner_wrap_node.args[0].name)
-        self.assertEqual(op_count(body_function), 1)
+        self.assertEqual(op_count(body_function), 2)
         inner_wrap_node = find_first_node(body_function, wrap)
         self.assertTrue(len(inner_wrap_node.args), 3)
 
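The arg and op counts asserted above follow from free-variable lifting: a tensor closed over by the wrapped callable is hoisted to an explicit input of the generated body graph. A hedged sketch with a hypothetical recording backend:

```python
# Hedged sketch of free-variable lifting: `y` is closed over by the lambda,
# so Dynamo hoists it to an explicit input of the wrap body graph and the
# outer wrap call passes it alongside x (hence 3 args: body, x, lifted y).
import torch
from torch._higher_order_ops.wrap import wrap

graphs = []

def record(gm, example_inputs):  # hypothetical recording backend
    graphs.append(gm)
    return gm.forward

y = torch.randn(3)

@torch.compile(backend=record, fullgraph=True)
def f(x):
    return wrap(lambda t: t + y, x)  # y is a free variable here

f(torch.randn(3))
wrap_node = next(n for n in graphs[0].graph.nodes if n.op == "call_function")
print(len(wrap_node.args))  # expected: 3 (body module, x, lifted y)
```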
@@ -1048,14 +1051,14 @@ def fn(*, x, y, z=None):
         counters.clear()
         opt = torch.compile(f, backend="eager", fullgraph=True)
         opt(x, y)
-        self.assertEqual(counters["stats"]["calls_captured"], 1)
+        self.assertEqual(counters["stats"]["calls_captured"], 2)
 
         # verify that we `don't` recompile
         opt(x, y)
-        self.assertEqual(counters["stats"]["calls_captured"], 1)
+        self.assertEqual(counters["stats"]["calls_captured"], 2)
 
         output = opt(x, y, 8)
-        self.assertEqual(counters["stats"]["calls_captured"], 2)
+        self.assertEqual(counters["stats"]["calls_captured"], 4)
         self.assertEqual(output, 2 * x)
 
     def test_wrap_kwarg_default_else_branch(self):
@@ -1666,46 +1669,77 @@ def f(x):
             {".*HigherOrderOperator body's output must consist of tensors only": 1},
         )
 
-    def test_fallback_on_nested_tuple_output(self):
-        counters.clear()
-
-        backend = EagerAndRecordGraphs()
-        cnt = CompileCounterWithBackend(backend)
-
-        @torch.compile(backend=cnt)
+    def test_nested_tuple_output(self):
         def f(x):
             ((a, b),) = wrap(lambda x: ((x.sin(), x.cos()),), x)
             return a + b
 
         x = torch.randn(2, 3)
-        result = f(x)
 
-        self.assertEqual(result, x.sin() + x.cos())
-        self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(len(backend.graphs), 1)
-        wrap_node = find_first_node(backend.graphs[0], wrap)
-        self.assertTrue(len(wrap_node.args), 1)
-        body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
-        self.assertEqual(op_count(body_function), 2)
-
-    def test_fallback_on_output_with_dict(self):
-        # We can likely support this in the future, I just don't want to deal
-        # with it right now
         counters.clear()
-        cnt = CompileCounter()
+        graph = self._test_wrap_simple(
+            f, default_args_generator((x,)), 2, 4, return_graph=True
+        )
+        self.assertEqual(len(counters["graph_break"]), 0)
 
-        @torch.compile(backend=cnt)
+        if check_dynamic_shape_capture():
+            return
+
+        self.assertExpectedInline(
+            graph,
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_ : torch.Tensor):
+        l_x_ = L_x_
+
+        wrap_body_0 = self.wrap_body_0
+        wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
+        a = wrap[0]
+        b = wrap[1]; wrap = None
+
+        add = a + b; a = b = None
+        return (add,)
+
+    class GraphModule(torch.nn.Module):
+        def forward(self, l_x_):
+            child = l_x_.sin()
+            child_1 = l_x_.cos(); l_x_ = None
+            return (child, child_1)
+""",
+        )
+
+    def test_output_with_dict(self):
         def f(x):
             return wrap(lambda x: [{"a": -x}], x)
 
         x = torch.randn(3)
-        result = f(x)
-        self.assertEqual(result, [{"a": -x}])
-        self.assertEqual(cnt.frame_count, 0)
-        assert_dict_matches_regex(
-            self,
-            dict(counters["graph_break"]),
-            {".*torch.* op returned non-Tensor dict call_function": 1},
+
+        counters.clear()
+        graph = self._test_wrap_simple(
+            f, default_args_generator((x,)), 2, 2, return_graph=True
+        )
+        self.assertEqual(len(counters["graph_break"]), 0)
+
+        if check_dynamic_shape_capture():
+            return
+
+        self.assertExpectedInline(
+            graph,
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_ : torch.Tensor):
+        l_x_ = L_x_
+
+        wrap_body_0 = self.wrap_body_0
+        wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
+        getitem = wrap[0]; wrap = None
+        return (getitem,)
+
+    class GraphModule(torch.nn.Module):
+        def forward(self, l_x_):
+            child = -l_x_; l_x_ = None
+            return (child,)
+""",
         )
 
     def test_access_module_attr(self):
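An eager-level sketch of what the rewritten dict test checks (relying on `wrap` forwarding its body's output unchanged outside of compilation, as the old fallback test demonstrated): structured pytree outputs are now captured, with the body graph flattened to a tensor tuple as in the printout above.

```python
# Hedged eager-level sketch: the body may return nested containers of
# tensors. In eager, wrap forwards the structure as-is; under compile, the
# body graph is flattened to a tensor tuple and reassembled for the caller.
import torch
from torch._higher_order_ops.wrap import wrap

x = torch.randn(3)
out = wrap(lambda t: [{"a": -t}], x)
print(torch.equal(out[0]["a"], -x))  # True: structured output preserved
```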
@@ -1805,7 +1839,7 @@ def f(x):
             return wrap(lambda x: x + y, x)
 
         x = torch.randn(3)
-        self._test_wrap_simple(f, default_args_generator((x,)), 3, expected_opcount=2)
+        self._test_wrap_simple(f, default_args_generator((x,)), 3, expected_opcount=3)
 
     def test_nested_wrap(self):
         class MockModule(torch.nn.Module):
@@ -1825,16 +1859,14 @@ def gn(x):
         def fn(x):
             return wrap(gn, x)
 
-        self._test_wrap_simple(
-            fn, default_args_generator((torch.randn(10, 10),)), 4, expected_opcount=1
-        )
+        self._test_wrap_simple(fn, default_args_generator((torch.randn(10, 10),)), 4)
 
     def test_fn_with_kwargs_in_torch_ops(self):
         def fn(x):
             return wrap(lambda z: torch.cos(input=z), x)
 
         x = torch.randn(3)
-        self._test_wrap_simple(fn, default_args_generator((x,)), 2, expected_opcount=1)
+        self._test_wrap_simple(fn, default_args_generator((x,)), 2)
 
     def test_hooks(self):
         class ToyModel(torch.nn.Module):