@@ -3670,6 +3670,72 @@ def forward(self, input, target):
 
         self.run_test(CrossEntropyLossMeanWeightIgnoreIndex(), input=(x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_kldiv_loss(self):
+
+        x = torch.randn(5)
+        y = torch.randn(5)
+        self._kldiv_loss(x, y)
+
+        x = torch.randn(2, 3, 5)
+        y = torch.randn(2, 3, 5)
+        self._kldiv_loss(x, y)
+
+        x = torch.randn(2, 3, 5, 7)
+        y = torch.randn(2, 3, 5, 7)
+        self._kldiv_loss(x, y)
+
+    def _kldiv_loss(self, x, y):
+        class KLDivLossNone(torch.nn.Module):
+            def __init__(self):
+                super(KLDivLossNone, self).__init__()
+                self.loss = torch.nn.KLDivLoss(reduction='none', log_target=True)
+
+            def forward(self, input, target):
+                return self.loss(input, target)
+
+        self.run_test(KLDivLossNone(), input=(x, y))
+
+        class KLDivLossMean(torch.nn.Module):
+            def __init__(self):
+                super(KLDivLossMean, self).__init__()
+                self.loss = torch.nn.KLDivLoss(reduction='mean', log_target=False)
+
+            def forward(self, input, target):
+                return self.loss(input, target)
+
+        self.run_test(KLDivLossMean(), input=(x, y))
+
+        class KLDivLossSum(torch.nn.Module):
+            def __init__(self):
+                super(KLDivLossSum, self).__init__()
+                self.loss = torch.nn.KLDivLoss(reduction='sum', log_target=True)
+
+            def forward(self, input, target):
+                return self.loss(input, target)
+
+        self.run_test(KLDivLossSum(), input=(x, y))
+
+        class KLDivLossBatchMean(torch.nn.Module):
+            def __init__(self):
+                super(KLDivLossBatchMean, self).__init__()
+                self.loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=False)
+
+            def forward(self, input, target):
+                return self.loss(input, target)
+
+        self.run_test(KLDivLossBatchMean(), input=(x, y))
+
+        class KLDivLossMiniBatchMean(torch.nn.Module):
+            def __init__(self):
+                super(KLDivLossMiniBatchMean, self).__init__()
+                self.loss = torch.nn.KLDivLoss(reduction='batchmean', size_average=False, log_target=True)
+
+            def forward(self, input, target):
+                return self.loss(input, target)
+
+        self.run_test(KLDivLossMiniBatchMean(), input=(x, y))
+
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_nllloss(self):
         class NLLModel(torch.nn.Module):
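
For context, `run_test` in this suite drives the usual export-and-compare path for each wrapped `KLDivLoss` module. A minimal standalone sketch of that flow is below; it is an illustration only, assuming ONNX Runtime is installed, and the opset version and tolerances are assumptions rather than the harness defaults.

```python
import io

import numpy as np
import onnxruntime
import torch

# Standalone check mirroring what run_test exercises: export the module to
# ONNX, run it under ONNX Runtime, and compare against the eager-mode result.
model = torch.nn.KLDivLoss(reduction='mean', log_target=False)
inp, target = torch.randn(2, 3, 5), torch.randn(2, 3, 5)

f = io.BytesIO()
torch.onnx.export(model, (inp, target), f, opset_version=9,
                  input_names=['input', 'target'])

sess = onnxruntime.InferenceSession(f.getvalue(), providers=['CPUExecutionProvider'])
ort_out = sess.run(None, {'input': inp.numpy(), 'target': target.numpy()})[0]

# Tolerances here are illustrative assumptions, not the test-suite values.
np.testing.assert_allclose(model(inp, target).numpy(), ort_out, rtol=1e-3, atol=1e-5)
```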