@@ -1162,10 +1162,10 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r
11621162
11631163 Args:
11641164 input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
1165- in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K > 2 `
1165+ in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K > 1 `
11661166 in the case of K-dimensional loss.
11671167 target: :math:`(N)` where each value is `0 <= targets[i] <= C-1`,
1168- or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K >= 2 ` for
1168+ or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K >= 1 ` for
11691169 K-dimensional loss.
11701170 weight (Tensor, optional): a manual rescaling weight given to each
11711171 class. If given, has to be a Tensor of size `C`
@@ -1192,7 +1192,7 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r
11921192 return torch ._C ._nn .nll_loss (input , target , weight , size_average , ignore_index , reduce )
11931193 elif dim == 4 :
11941194 return torch ._C ._nn .nll_loss2d (input , target , weight , size_average , ignore_index , reduce )
1195- elif dim > 4 :
1195+ elif dim == 3 or dim > 4 :
11961196 n = input .size (0 )
11971197 c = input .size (1 )
11981198 out_size = (n ,) + input .size ()[2 :]
@@ -1206,7 +1206,7 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r
12061206 out = torch ._C ._nn .nll_loss2d (input , target , weight , size_average , ignore_index , reduce )
12071207 return out .view (out_size )
12081208 else :
1209- raise ValueError ('Expected 2, 4, or more than 4 dimensions (got {})' .format (dim ))
1209+ raise ValueError ('Expected 2 or more dimensions (got {})' .format (dim ))
12101210
12111211
12121212def poisson_nll_loss (input , target , log_input = True , full = False , size_average = True , eps = 1e-8 , reduce = True ):
0 commit comments