Skip to content

Commit 4797f98

Browse files
kevinzakka authored and soumith committed
add reduce arg to PoissonNLLLoss (#3770)
* add reduce arg to PoissonNLLLoss
* fixed comments except reference function
* fixed unit test
* small indentation fix
* fixed last comments by richard
* lint check
* another linting issue
1 parent 840760c commit 4797f98

File tree

3 files changed

+30
-9
lines changed

3 files changed

+30
-9
lines changed

test/test_nn.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3749,18 +3749,28 @@ def forward(self, *args):
37493749
module_name='PoissonNLLLoss',
37503750
input_size=(2, 3, 4, 5),
37513751
target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
3752-
desc='reduced_loss',
3752+
desc='no_full_loss', # without Stirling approximation
37533753
),
37543754
dict(
37553755
module_name='PoissonNLLLoss',
37563756
constructor_args=(False, True, True),
37573757
input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001),
37583758
target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
3759-
desc='full_loss',
3759+
desc='full_loss', # with Stirling approximation
37603760
),
37613761
]
37623762

37633763

3764+
def poissonnllloss_no_reduce_test():
    """Build a module-test spec exercising F.poisson_nll_loss with reduce=False.

    With ``reduce=False`` the functional returns a per-element loss tensor
    instead of a scalar mean/sum, so this spec verifies the new code path
    added by this commit.
    """
    t = Variable(torch.randn(10, 10))
    return dict(
        # Fixed typo: original said 'PoissonNLLLLoss_no_reduce' (four L's).
        fullname='PoissonNLLLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.poisson_nll_loss(i, t.type_as(i), reduce=False)),
        # Inputs are positive rates; target is cast to the input's type.
        input_fn=lambda: torch.rand(10, 10),
        pickle=False)
3772+
3773+
37643774
def kldivloss_no_reduce_test():
37653775
t = Variable(torch.randn(10, 10))
37663776
return dict(
@@ -3931,6 +3941,7 @@ def smoothl1loss_no_reduce_test():
39313941

39323942

39333943
new_module_tests = [
3944+
poissonnllloss_no_reduce_test(),
39343945
kldivloss_no_reduce_test(),
39353946
l1loss_no_reduce_test(),
39363947
mseloss_no_reduce_test(),

torch/nn/functional.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,7 +1053,7 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r
10531053
raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim))
10541054

10551055

1056-
def poisson_nll_loss(input, target, log_input=True, full=False, size_average=True, eps=1e-8):
1056+
def poisson_nll_loss(input, target, log_input=True, full=False, size_average=True, eps=1e-8, reduce=True):
10571057
r"""Poisson negative log likelihood loss.
10581058
10591059
See :class:`~torch.nn.PoissonNLLLoss` for details.
@@ -1072,6 +1072,10 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Tru
10721072
the losses are instead summed for each minibatch. Default: ``True``
10731073
eps (float, optional): Small value to avoid evaluation of log(0) when
10741074
log_input=False. Default: 1e-8
1075+
reduce (bool, optional): By default, the losses are averaged
1076+
over observations for each minibatch, or summed, depending on
1077+
size_average. When reduce is ``False``, returns a loss per batch
1078+
element instead and ignores size_average. Default: ``True``
10751079
"""
10761080
if log_input:
10771081
loss = torch.exp(input) - target * input
@@ -1080,10 +1084,11 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Tru
10801084
if full:
10811085
mask = target > 1
10821086
loss[mask] += (target * torch.log(target) - target + 0.5 * torch.log(2 * math.pi * target))[mask]
1087+
if not reduce:
1088+
return loss
10831089
if size_average:
10841090
return torch.mean(loss)
1085-
else:
1086-
return torch.sum(loss)
1091+
return torch.sum(loss)
10871092

10881093

10891094
kl_div = _add_docstr(torch._C._nn.kl_div, r"""

torch/nn/modules/loss.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ class PoissonNLLLoss(_Loss):
208208
is set to ``False``, the losses are instead summed for each minibatch.
209209
eps (float, optional): Small value to avoid evaluation of log(0) when
210210
log_input==``False``. Default: 1e-8
211+
reduce (bool, optional): By default, the losses are averaged
212+
over observations for each minibatch, or summed, depending on
213+
size_average. When reduce is ``False``, returns a loss per batch
214+
element instead and ignores size_average. Default: ``True``
211215
212216
Examples::
213217
@@ -217,16 +221,17 @@ class PoissonNLLLoss(_Loss):
217221
>>> output = loss(log_input, target)
218222
>>> output.backward()
219223
"""
220-
def __init__(self, log_input=True, full=False, size_average=True, eps=1e-8):
221-
super(PoissonNLLLoss, self).__init__()
224+
def __init__(self, log_input=True, full=False, size_average=True, eps=1e-8, reduce=True):
    """Configure the Poisson NLL loss.

    ``size_average`` is handed to the ``_Loss`` base class, which stores it;
    the remaining options are kept on the instance for ``forward``.
    """
    super(PoissonNLLLoss, self).__init__(size_average)
    # Stash the per-call options consumed by F.poisson_nll_loss.
    self.log_input, self.full = log_input, full
    self.eps, self.reduce = eps, reduce
226230

227231
def forward(self, log_input, target):
    """Delegate to the functional form with the options stored at construction.

    ``target`` must not require grad — the Poisson NLL is only differentiated
    with respect to the input.
    """
    _assert_no_grad(target)
    return F.poisson_nll_loss(
        log_input, target, self.log_input,
        full=self.full, size_average=self.size_average,
        eps=self.eps, reduce=self.reduce)
230235

231236

232237
class KLDivLoss(_Loss):

0 commit comments

Comments
 (0)