@@ -405,8 +405,9 @@ def _nll_loss_backward(
405405 grad_output = grad_output / total_weight
406406
407407 target = target .unsqueeze (channel_dim )
408+ safe_target = torch .where (target != ignore_index , target , 0 )
408409 grad_input = torch .zeros_like (self )
409- grad_input = torch .scatter (grad_input , channel_dim , target , - 1.0 )
410+ grad_input = torch .scatter (grad_input , channel_dim , safe_target , - 1.0 )
410411
411412 if grad_input .dim () > grad_output .dim () > 0 :
412413 grad_output = grad_output .unsqueeze (channel_dim )
@@ -417,9 +418,7 @@ def _nll_loss_backward(
417418 weight = weight .reshape (new_shape )
418419 grad_output = grad_output * weight
419420
420- has_ignore_index = ignore_index >= 0
421- if has_ignore_index :
422- grad_output = torch .where (target != ignore_index , grad_output , 0 )
421+ grad_output = torch .where (target != ignore_index , grad_output , 0 )
423422
424423 return grad_input * grad_output
425424
@@ -2798,37 +2797,30 @@ def nll_loss_forward(
27982797 if weight is not None :
27992798 w = weight .unsqueeze (0 ) if n_dims > 1 else weight
28002799 self = self * w
2801-
2802- target_ = target .unsqueeze (channel_dim )
2800+ safe_target = torch . where ( target != ignore_index , target , 0 )
2801+ safe_target_ = safe_target .unsqueeze (channel_dim )
28032802 # target can be [N, 1] or [1]
28042803
2805- result = - torch .gather (self , channel_dim , target_ ).squeeze (channel_dim )
2804+ result = - torch .gather (self , channel_dim , safe_target_ ).squeeze (channel_dim )
28062805
2807- if ignore_index >= 0 :
2808- result = torch .where (target != ignore_index , result , 0 )
2806+ result = torch .where (target != ignore_index , result , 0 )
28092807
28102808 if reduction == Reduction .NONE .value and n_dims > 1 :
28112809 total_weight = self .new_full ((), 0.0 )
28122810 return result , total_weight
28132811
28142812 if weight is not None :
28152813 w = weight .unsqueeze (0 ).expand (self .shape ) if n_dims > 1 else weight
2816- wsum = torch .gather (w , channel_dim , target_ ).squeeze (channel_dim )
2817- if ignore_index >= 0 :
2818- wsum = torch .where (target != ignore_index , wsum , 0 )
2814+ wsum = torch .gather (w , channel_dim , safe_target_ ).squeeze (channel_dim )
2815+ wsum = torch .where (target != ignore_index , wsum , 0 )
28192816 total_weight = wsum .sum ()
2820- elif ignore_index >= 0 :
2821- total_weight = (target != ignore_index ).sum ().to (self )
28222817 else :
2823- total_weight = self . new_full ((), 1.0 * result . numel () )
2818+ total_weight = ( target != ignore_index ). sum (). to ( self )
28242819
28252820 if reduction == Reduction .SUM .value :
28262821 result = result .sum ()
28272822 elif reduction == Reduction .MEAN .value :
2828- if weight is None :
2829- result = result .sum () / total_weight if ignore_index >= 0 else result .mean ()
2830- else :
2831- result = result .sum () / total_weight
2823+ result = result .sum () / total_weight
28322824
28332825 return result , total_weight
28342826
0 commit comments