@@ -613,8 +613,8 @@ def forward(self, primals_1):
613613 t_1 = torch.ops.aten.t.default(clone); clone = None
614614 select_scatter = torch.ops.aten.select_scatter.default(t_1, mul, 0, 0); t_1 = mul = None
615615 t_2 = torch.ops.aten.t.default(select_scatter); select_scatter = None
616- t_3 = torch.ops.aten.t.default(t_2); t_2 = None
617- return [t_3 , 3, 3, 1, 3, 0]""" )
616+ t_4 = torch.ops.aten.t.default(t_2); t_2 = None
617+ return [t_4 , 3, 3, 1, 3, 0]""" )
618618
619619 def test_view_and_inplace_view (self ):
620620 def f (a , b ):
@@ -683,11 +683,12 @@ def forward(self, primals_1):
683683 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
684684 as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
685685 mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
686- as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = None
687- as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
688- t_1 = torch.ops.aten.t.default(as_strided_5); as_strided_5 = None
686+ as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
687+ as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
688+ as_strided_6 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
689+ t_1 = torch.ops.aten.t.default(as_strided_6); as_strided_6 = None
689690 mul_1 = torch.ops.aten.mul.Tensor(t_1, 3); t_1 = None
690- return [mul , mul_1, 4, 1, 0]""" )
691+ return [as_strided_3 , mul_1, 4, 1, 0]""" )
691692
692693 def test_input_mutation_aliases_other_input (self ):
693694 def f (a , b ):
@@ -712,10 +713,11 @@ def forward(self, primals_1):
712713 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
713714 as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0)
714715 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
715- as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = None
716- as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 2); as_strided_scatter = None
717- add_1 = torch.ops.aten.add.Tensor(add, as_strided_4); as_strided_4 = None
718- return [add, add_1]""" )
716+ as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
717+ as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
718+ as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 2); as_strided_scatter = None
719+ add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_5 = None
720+ return [as_strided_2, add_1]""" )
719721
720722 def test_input_mutation_aliases_other_input2 (self ):
721723 def f (a , b ):
@@ -736,10 +738,11 @@ def forward(self, primals_1):
736738 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
737739 as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0)
738740 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
739- as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = None
740- as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0); as_strided_scatter = None
741- add_1 = torch.ops.aten.add.Tensor(add, as_strided_4); as_strided_4 = None
742- return [add, add_1]""" )
741+ as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
742+ as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
743+ as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0); as_strided_scatter = None
744+ add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_5 = None
745+ return [as_strided_2, add_1]""" )
743746
744747 def test_input_mutation_aliases_and_output_alias (self ):
745748 def f (a , b ):
@@ -758,9 +761,11 @@ def inp_callable():
758761 self .assertExpectedInline (fw_graph .code .strip (), """\
759762 def forward(self, primals_1):
760763 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
761- as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0); clone = None
764+ as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
762765 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
763- return [add, 4, 1, 0]""" )
766+ as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
767+ as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
768+ return [as_strided_2, 4, 1, 0]""" )
764769
765770 def test_input_aliased_with_mutation_output_alias (self ):
766771 def f (a , b , c ):
@@ -783,10 +788,12 @@ def inp_callable():
783788 self .assertExpectedInline (fw_graph .code .strip (), """\
784789 def forward(self, primals_1, primals_2):
785790 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
786- as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0); clone = None
791+ as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
787792 mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
793+ as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
794+ as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
788795 add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
789- return [mul , add, 4, 1, 0]""" )
796+ return [as_strided_2 , add, 4, 1, 0]""" )
790797
791798 def test_input_metadata_mutation_aliases (self ):
792799 def f (a , b ):
@@ -829,11 +836,12 @@ def forward(self, primals_1, primals_2):
829836 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
830837 as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
831838 mul = torch.ops.aten.mul.Tensor(as_strided, 2); as_strided = None
832- as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = None
833- as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
834- add = torch.ops.aten.add.Tensor(as_strided_2, 1); as_strided_2 = None
839+ as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
840+ as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
841+ as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
842+ add = torch.ops.aten.add.Tensor(as_strided_3, 1); as_strided_3 = None
835843 add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
836- return [mul , add, add_1]""" )
844+ return [as_strided_2 , add, add_1]""" )
837845
838846 def test_input_mutation_aliases_bases_out_of_order (self ):
839847 # This tests our calling convention: if b and d are aliased, then the outer calling convention
@@ -864,12 +872,13 @@ def forward(self, primals_1, primals_2, primals_3):
864872 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
865873 as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
866874 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
875+ as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
876+ as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
867877 add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
868- as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = None
869- as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
870- t_1 = torch.ops.aten.t.default(as_strided_4); as_strided_4 = None
878+ as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
879+ t_1 = torch.ops.aten.t.default(as_strided_5); as_strided_5 = None
871880 add_2 = torch.ops.aten.add.Tensor(add_1, t_1); add_1 = t_1 = None
872- return [add , add_2, 4, 1, 0, 4, 1, 0]""" )
881+ return [as_strided_2 , add_2, 4, 1, 0, 4, 1, 0]""" )
873882
874883 # Mondo test that tests a combination of:
875884 # input is mutated, that aliases another input (so we make a synthetic base)
@@ -913,10 +922,11 @@ def forward(self, primals_1, primals_2):
913922 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
914923 as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
915924 mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
916- as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = None
917- as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
918- add = torch.ops.aten.add.Tensor(as_strided_4, mul); as_strided_4 = None
919- return [mul, add, 2, 2, 1, 2, 0, 2, 2, 2, 1, 0]""" )
925+ as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
926+ as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
927+ as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
928+ add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = None
929+ return [as_strided_2, add, 2, 2, 1, 2, 0, 2, 2, 2, 1, 0]""" )
920930
921931 def test_no_grad_input_output (self ):
922932 def f (a , b ):
0 commit comments