@@ -3742,138 +3742,102 @@ def forward(self, *tensor_list):
37423742 @skipIfUnsupportedMinOpsetVersion (12 )
37433743 @disableScriptTest ()
37443744 def test_crossentropyloss (self ):
3745- x = torch .randn (3 , 5 )
3746- y = torch .empty (3 , dtype = torch .long ).random_ (5 )
3747- self ._crossentropyloss (x , y )
3745+ for ignore_index in [- 100 , 1 ]:
3746+ x = torch .randn (3 , 5 )
3747+ y = torch .empty (3 , dtype = torch .long ).random_ (5 )
3748+ y [y == 1 ] = ignore_index
37483749
3749- x = torch .randn (3 , 5 , 2 )
3750- y = torch .empty (3 , 2 , dtype = torch .long ).random_ (5 )
3751- self ._crossentropyloss (x , y )
3750+ self ._crossentropyloss (x , y , ignore_index )
37523751
3753- x = torch .randn (3 , 5 , 2 , 7 )
3754- y = torch .empty (3 , 2 , 7 , dtype = torch .long ).random_ (5 )
3755- self ._crossentropyloss (x , y )
3752+ x = torch .randn (3 , 5 , 2 )
3753+ y = torch .empty (3 , 2 , dtype = torch .long ).random_ (5 )
3754+ y [y == 1 ] = ignore_index
3755+ self ._crossentropyloss (x , y , ignore_index )
37563756
3757- def _crossentropyloss (self , x , y ):
3757+ x = torch .randn (3 , 5 , 2 , 7 )
3758+ y = torch .empty (3 , 2 , 7 , dtype = torch .long ).random_ (5 )
3759+ y [y == 1 ] = ignore_index
3760+ self ._crossentropyloss (x , y , ignore_index )
3761+
3762+ def _crossentropyloss (self , x , y , ignore_index ):
37583763 class CrossEntropyLossNone (torch .nn .Module ):
3759- def __init__ (self ):
3764+ def __init__ (self , ignore_index ):
37603765 super (CrossEntropyLossNone , self ).__init__ ()
3761- self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' )
3766+ if ignore_index == - 100 :
3767+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' )
3768+ else :
3769+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' , ignore_index = ignore_index )
37623770
37633771 def forward (self , input , target ):
37643772 return self .loss (input , target )
37653773
3766- self .run_test (CrossEntropyLossNone (), input = (x , y ))
3774+ self .run_test (CrossEntropyLossNone (ignore_index ), input = (x , y ))
37673775
37683776 class CrossEntropyLossNoneWeight (torch .nn .Module ):
3769- def __init__ (self ):
3777+ def __init__ (self , ignore_index ):
37703778 super (CrossEntropyLossNoneWeight , self ).__init__ ()
3771- self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' , weight = torch .randn (5 ))
3779+ if ignore_index == - 100 :
3780+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' , weight = torch .randn (5 ))
3781+ else :
3782+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' , weight = torch .randn (5 ), ignore_index = ignore_index )
37723783
37733784 def forward (self , input , target ):
37743785 return self .loss (input , target )
37753786
3776- self .run_test (CrossEntropyLossNoneWeight (), input = (x , y ))
3787+ self .run_test (CrossEntropyLossNoneWeight (ignore_index ), input = (x , y ))
37773788
37783789 class CrossEntropyLossSum (torch .nn .Module ):
3779- def __init__ (self ):
3790+ def __init__ (self , ignore_index ):
37803791 super (CrossEntropyLossSum , self ).__init__ ()
3781- self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' )
3792+ if ignore_index == - 100 :
3793+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' )
3794+ else :
3795+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' , ignore_index = ignore_index )
37823796
37833797 def forward (self , input , target ):
37843798 return self .loss (input , target )
37853799
3786- self .run_test (CrossEntropyLossSum (), input = (x , y ))
3800+ self .run_test (CrossEntropyLossSum (ignore_index ), input = (x , y ))
37873801
37883802 class CrossEntropyLossSumWeight (torch .nn .Module ):
3789- def __init__ (self ):
3803+ def __init__ (self , ignore_index ):
37903804 super (CrossEntropyLossSumWeight , self ).__init__ ()
3791- self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' , weight = torch .randn (5 ))
3805+ if ignore_index == - 100 :
3806+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' , weight = torch .randn (5 ))
3807+ else :
3808+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' , weight = torch .randn (5 ), ignore_index = ignore_index )
37923809
37933810 def forward (self , input , target ):
37943811 return self .loss (input , target )
37953812
3796- self .run_test (CrossEntropyLossSumWeight (), input = (x , y ))
3813+ self .run_test (CrossEntropyLossSumWeight (ignore_index ), input = (x , y ))
37973814
37983815 class CrossEntropyLossMean (torch .nn .Module ):
3799- def __init__ (self ):
3816+ def __init__ (self , ignore_index ):
38003817 super (CrossEntropyLossMean , self ).__init__ ()
3801- self .loss = torch .nn .CrossEntropyLoss ()
3818+ if ignore_index == - 100 :
3819+ self .loss = torch .nn .CrossEntropyLoss ()
3820+ else :
3821+ self .loss = torch .nn .CrossEntropyLoss (ignore_index = ignore_index )
38023822
38033823 def forward (self , input , target ):
38043824 return self .loss (input , target )
38053825
3806- self .run_test (CrossEntropyLossMean (), input = (x , y ))
3826+ self .run_test (CrossEntropyLossMean (ignore_index ), input = (x , y ))
38073827
38083828 class CrossEntropyLossMeanWeight (torch .nn .Module ):
3809- def __init__ (self ):
3829+ def __init__ (self , ignore_index ):
38103830 super (CrossEntropyLossMeanWeight , self ).__init__ ()
3811- self .loss = torch .nn .CrossEntropyLoss (weight = torch .randn (5 ))
3812-
3813- def forward (self , input , target ):
3814- return self .loss (input , target )
3815-
3816- self .run_test (CrossEntropyLossMeanWeight (), input = (x , y ))
3817-
3818- class CrossEntropyLossNoneIgnoreIndex (torch .nn .Module ):
3819- def __init__ (self ):
3820- super (CrossEntropyLossNoneIgnoreIndex , self ).__init__ ()
3821- self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' , ignore_index = 1 )
3822-
3823- def forward (self , input , target ):
3824- return self .loss (input , target )
3825-
3826- self .run_test (CrossEntropyLossNoneIgnoreIndex (), input = (x , y ))
3827-
3828- class CrossEntropyLossNoneWeightIgnoreIndex (torch .nn .Module ):
3829- def __init__ (self ):
3830- super (CrossEntropyLossNoneWeightIgnoreIndex , self ).__init__ ()
3831- self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' , weight = torch .randn (5 ), ignore_index = 1 )
3832-
3833- def forward (self , input , target ):
3834- return self .loss (input , target )
3835-
3836- self .run_test (CrossEntropyLossNoneWeightIgnoreIndex (), input = (x , y ))
3837-
3838- class CrossEntropyLossSumIgnoreIndex (torch .nn .Module ):
3839- def __init__ (self ):
3840- super (CrossEntropyLossSumIgnoreIndex , self ).__init__ ()
3841- self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' , ignore_index = 1 )
3842-
3843- def forward (self , input , target ):
3844- return self .loss (input , target )
3845-
3846- self .run_test (CrossEntropyLossSumIgnoreIndex (), input = (x , y ))
3847-
3848- class CrossEntropyLossSumWeightIgnoreIndex (torch .nn .Module ):
3849- def __init__ (self ):
3850- super (CrossEntropyLossSumWeightIgnoreIndex , self ).__init__ ()
3851- self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' , weight = torch .randn (5 ), ignore_index = 1 )
3852-
3853- def forward (self , input , target ):
3854- return self .loss (input , target )
3855-
3856- self .run_test (CrossEntropyLossSumWeightIgnoreIndex (), input = (x , y ))
3857-
3858- class CrossEntropyLossMeanIgnoreIndex (torch .nn .Module ):
3859- def __init__ (self ):
3860- super (CrossEntropyLossMeanIgnoreIndex , self ).__init__ ()
3861- self .loss = torch .nn .CrossEntropyLoss (ignore_index = 1 )
3831+ if ignore_index == - 100 :
3832+ self .loss = torch .nn .CrossEntropyLoss (weight = torch .randn (5 ))
3833+ else :
3834+ self .loss = torch .nn .CrossEntropyLoss (weight = torch .randn (5 ), ignore_index = ignore_index )
38623835
38633836 def forward (self , input , target ):
38643837 return self .loss (input , target )
38653838
3866- self .run_test (CrossEntropyLossMeanIgnoreIndex ( ), input = (x , y ))
3839+ self .run_test (CrossEntropyLossMeanWeight ( ignore_index ), input = (x , y ))
38673840
3868- class CrossEntropyLossMeanWeightIgnoreIndex (torch .nn .Module ):
3869- def __init__ (self ):
3870- super (CrossEntropyLossMeanWeightIgnoreIndex , self ).__init__ ()
3871- self .loss = torch .nn .CrossEntropyLoss (weight = torch .randn (5 ), ignore_index = 1 )
3872-
3873- def forward (self , input , target ):
3874- return self .loss (input , target )
3875-
3876- self .run_test (CrossEntropyLossMeanWeightIgnoreIndex (), input = (x , y ))
38773841
38783842 @skipIfUnsupportedMinOpsetVersion (9 )
38793843 def test_kldiv_loss (self ):
@@ -3957,6 +3921,9 @@ def forward(self, input, target):
39573921 N , C = 5 , 4
39583922 input = torch .randn (N , 16 )
39593923 target = torch .empty (N , dtype = torch .long ).random_ (0 , C )
3924+
3925+ # using test data containing default ignore_index=-100
3926+ target [target == 1 ] = - 100
39603927 self .run_test (NLLModel (), (input , target ))
39613928
39623929 @skipIfUnsupportedMinOpsetVersion (12 )
@@ -3976,6 +3943,9 @@ def forward(self, input, target):
39763943 N , C = 5 , 4
39773944 input = torch .randn (N , 16 , 10 , 10 )
39783945 target = torch .empty (N , 8 , 8 , dtype = torch .long ).random_ (0 , C )
3946+
3947+ # using test data containing default ignore_index=-100
3948+ target [target == 1 ] = - 100
39793949 self .run_test (NLLModel (), (input , target ))
39803950
39813951 @skipIfUnsupportedMinOpsetVersion (12 )
@@ -3995,6 +3965,9 @@ def forward(self, input, target):
39953965 N , C = 5 , 4
39963966 input = torch .randn (N , 16 , 10 , 10 )
39973967 target = torch .empty (N , 8 , 8 , dtype = torch .long ).random_ (0 , C )
3968+
3969+ # using test data containing default ignore_index=-100
3970+ target [target == 1 ] = - 100
39983971 self .run_test (NLLModel (), (input , target ))
39993972
40003973 @skipIfUnsupportedMinOpsetVersion (12 )
@@ -4014,6 +3987,9 @@ def forward(self, input, target):
40143987 N , C = 5 , 4
40153988 input = torch .randn (N , 16 , 10 , 10 )
40163989 target = torch .empty (N , 8 , 8 , dtype = torch .long ).random_ (0 , C )
3990+
3991+ # using test data containing default ignore_index=-100
3992+ target [target == 1 ] = - 100
40173993 self .run_test (NLLModel (), (input , target ))
40183994
40193995 @skipIfUnsupportedMinOpsetVersion (12 )
@@ -4033,6 +4009,9 @@ def forward(self, input, target):
40334009 N , C = 5 , 4
40344010 input = torch .randn (N , 16 , 10 , 10 )
40354011 target = torch .empty (N , 8 , 8 , dtype = torch .long ).random_ (0 , C )
4012+
4013+ # using test data containing default ignore_index=-100
4014+ target [target == 1 ] = - 100
40364015 self .run_test (NLLModel (), (input , target ))
40374016
40384017 @skipIfUnsupportedMinOpsetVersion (12 )
0 commit comments