@@ -15,7 +15,7 @@
 from itertools import product, permutations
 
 from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \
-    backward_graph, get_lstm_inputs, get_milstm_inputs, LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
+    backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
 
 
 class TestFuser(JitTestCase):
@@ -275,7 +275,7 @@ def funcOptMax(a, b):
         for f, inputs in product(funcs, [[a, b], [a, nan]]):
             inp1, inp2 = inputs
             s = self.checkScript(f, (inp1, inp2))
-            self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size'})
+            self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
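+            # aten::_size_if_not_equal is emitted by autodiff to record input sizes
+            # for _grad_sum_to_size in the backward pass and sits outside the fusion group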
 
             c = s(inp1, inp2)
             c.sum().backward()
@@ -350,7 +350,8 @@ def f(x, y):
         self.assertAllFused(ge.graph_for(x, y))
         x.requires_grad_(True)
         y.requires_grad_(True)
-        self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes"))
+        self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
+                                                            "aten::_size_if_not_equal"))
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -522,7 +523,8 @@ def fn_test_scalar_arg(x, p):
         self.assertAllFused(scripted.graph_for(x, p))
         x.requires_grad_(True)
         out = scripted(x, p)
-        self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes"))
+        self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes",
+                                                                  "aten::_size_if_not_equal"))
 
     @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
     @enable_cpu_fuser
@@ -535,7 +537,7 @@ def f(x, y):
         b = torch.randn(5, 5, requires_grad=True)
         a = torch.randn(5, 5, requires_grad=True)
         s = self.checkScript(f, (a, b))
-        self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'})
+        self.assertAllFused(s.graph_for(a, b), except_for={'aten::size', 'aten::_size_if_not_equal', 'prim::BroadcastSizes'})
 
         c = s(a, b)
         ga, gb = torch.autograd.grad(c.sum(), [a, b])
@@ -578,12 +580,12 @@ def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
 
         s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
         self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
-                            except_for={'aten::size', 'prim::BroadcastSizes'})
+                            except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
 
         c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
         torch.autograd.grad(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
         graph = backward_graph(s)
-        self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes'})
+        self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -670,8 +672,8 @@ def test_lstm_cuda(self):
         hy, cy = module(*inputs)
         (hy + cy).sum().backward()
         backward = backward_graph(module)
-        FileCheck().check("FusionGroup_0").check_next("FusionGroup_1") \
-            .check_not("FusionGroup_2").run(str(backward))
+        self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
+                                                  "aten::_grad_sum_to_size"))
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -801,7 +803,8 @@ def fn_test_erf(x):
         ge = self.checkTrace(fn_test_erf, (x,))
         self.assertAllFused(ge.graph_for(x))
         x.requires_grad_(True)
-        self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes"))
+        self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes",
+                                                         "aten::_size_if_not_equal"))
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -818,7 +821,8 @@ def fn_test_rand(x, y):
         self.assertAllFused(script_f.graph_for(x, y))
         x.requires_grad_(True)
         out = script_f(x, y)
-        self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes"))
+        self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
+                                                                  "aten::_size_if_not_equal"))
         # test that broadcasting random produces correct results
         x = torch.ones(4, 4, dtype=torch.float, device='cuda')
         y = torch.ones(4, dtype=torch.float, device='cuda')
@@ -894,6 +898,44 @@ def f(x, y):
         self.assertEqual(result2, expected2)
         self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
 
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_grad_sum_to_size_elimination(self):
+
+        def my_broadcasted_cell(a, b, c):
+            return (a + b) + c
+
+        s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
+        s2 = torch.randn(5, 5, requires_grad=True, device='cuda')
+
+        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1))
+        forward_graph = module.graph_for(s1, s1, s1)
+        self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
+                                                       "aten::_size_if_not_equal"))
+
+        old_plans = set()
+        for i in range(3):
+            # if we have s2, then the s1 are _grad_sum_to_size'd
+            args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
+            args = [a.detach_().requires_grad_() for a in args]
+            res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
+            grads = torch.autograd.grad(res.sum(), args)
+            for inp, gr in zip(args, grads):
+                self.assertEqual(inp.shape, gr.shape)
+            backward = None
+            # this is a workaround for the backward graphs not being
+            # in order for Python 2
+            for g in all_backward_graphs(module):
+                if str(g) not in old_plans:
+                    assert backward is None
+                    backward = g
+            old_plans.add(str(backward))
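+            # i of the three inputs have shape (5, 1), so the backward graph should
+            # produce exactly i outputs via aten::_grad_sum_to_size; the remaining
+            # 3 - i gradients are graph inputs passed straight through (prim::Param)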
+            self.assertEqual(len([1 for o in backward.outputs() if o.node().kind() == "aten::_grad_sum_to_size"]), i)
+            self.assertEqual(len([1 for o in backward.outputs() if o.node().kind() == "prim::Param"]), 3 - i)
+
+
     @unittest.skipIf(not IS_WINDOWS, "Test that the fuser is disabled on Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_windows_cuda(self):