Skip to content

Commit 05908e8

Browse files
dhpollackapaszke
authored andcommitted
current code works with dim = 3, so I added it to dim checks
1 parent 9b6441e commit 05908e8

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

test/common_nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def kldivloss_reference(input, target, size_average=True, reduce=True):
259259

260260
def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
261261
size_average=True, reduce=True):
262-
assert input.dim() >= 4
262+
assert input.dim() >= 3
263263
N = input.size(0)
264264
C = input.size(1)
265265
out_size = (N,) + input.size()[2:]

test/test_nn.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4266,6 +4266,15 @@ def forward(self, *args):
42664266
check_no_size_average=True,
42674267
desc='higher_dim'
42684268
),
4269+
dict(
4270+
module_name='NLLLoss',
4271+
input_size=(2, 3, 5),
4272+
target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
4273+
reference_fn=lambda i, t, m:
4274+
loss_reference_fns['NLLLossNd'](i, t, size_average=get_size_average(m)),
4275+
check_no_size_average=True,
4276+
desc='dim_is_3'
4277+
),
42694278
dict(
42704279
module_name='PoissonNLLLoss',
42714280
input_size=(2, 3, 4, 5),

torch/nn/functional.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,10 +1162,10 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r
11621162
11631163
Args:
11641164
input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
1165-
in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K > 2`
1165+
in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K > 1`
11661166
in the case of K-dimensional loss.
11671167
target: :math:`(N)` where each value is `0 <= targets[i] <= C-1`,
1168-
or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K >= 2` for
1168+
or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K >= 1` for
11691169
K-dimensional loss.
11701170
weight (Tensor, optional): a manual rescaling weight given to each
11711171
class. If given, has to be a Tensor of size `C`
@@ -1192,7 +1192,7 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r
11921192
return torch._C._nn.nll_loss(input, target, weight, size_average, ignore_index, reduce)
11931193
elif dim == 4:
11941194
return torch._C._nn.nll_loss2d(input, target, weight, size_average, ignore_index, reduce)
1195-
elif dim > 4:
1195+
elif dim == 3 or dim > 4:
11961196
n = input.size(0)
11971197
c = input.size(1)
11981198
out_size = (n,) + input.size()[2:]
@@ -1206,7 +1206,7 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r
12061206
out = torch._C._nn.nll_loss2d(input, target, weight, size_average, ignore_index, reduce)
12071207
return out.view(out_size)
12081208
else:
1209-
raise ValueError('Expected 2, 4, or more than 4 dimensions (got {})'.format(dim))
1209+
raise ValueError('Expected 2 or more dimensions (got {})'.format(dim))
12101210

12111211

12121212
def poisson_nll_loss(input, target, log_input=True, full=False, size_average=True, eps=1e-8, reduce=True):

0 commit comments

Comments
 (0)