15 changes: 13 additions & 2 deletions test/test_nn.py
@@ -3728,18 +3728,28 @@ def forward(self, *args):
module_name='PoissonNLLLoss',
input_size=(2, 3, 4, 5),
target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
desc='reduced_loss',
desc='no_full_loss', # without Stirling approx
),
dict(
module_name='PoissonNLLLoss',
constructor_args=(False, True, True),
input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001),
target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
desc='full_loss',
desc='full_loss', # with Stirling approx
),
]


def poissonnllloss_no_reduce_test():
t = Variable(torch.randn(10, 10))
return dict(
fullname='PoissonNLLLoss_no_reduce',
constructor=wrap_functional(
lambda i: F.poisson_nll_loss(i, t.type_as(i), reduce=False)),
input_fn=lambda: torch.rand(10, 10),
pickle=False)


def kldivloss_no_reduce_test():
t = Variable(torch.randn(10, 10))
return dict(
@@ -3910,6 +3920,7 @@ def smoothl1loss_no_reduce_test():


new_module_tests = [
poissonnllloss_no_reduce_test(),
kldivloss_no_reduce_test(),
l1loss_no_reduce_test(),
mseloss_no_reduce_test(),
11 changes: 8 additions & 3 deletions torch/nn/functional.py
@@ -1103,7 +1103,7 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r
raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim))


def poisson_nll_loss(input, target, log_input=True, full=False, size_average=True, eps=1e-8):
def poisson_nll_loss(input, target, log_input=True, full=False, size_average=True, eps=1e-8, reduce=True):
r"""Poisson negative log likelihood loss.

See :class:`~torch.nn.PoissonNLLLoss` for details.
@@ -1122,6 +1122,10 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Tru
the losses are instead summed for each minibatch. Default: ``True``
eps (float, optional): Small value to avoid evaluation of log(0) when
log_input=False. Default: 1e-8
reduce (bool, optional): By default, the losses are averaged
over observations for each minibatch, or summed, depending on
size_average. When reduce is ``False``, returns a loss per batch
element instead and ignores size_average. Default: ``True``
"""
if log_input:
loss = torch.exp(input) - target * input
@@ -1130,10 +1134,11 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Tru
if full:
mask = target > 1
loss[mask] += (target * torch.log(target) - target + 0.5 * torch.log(2 * math.pi * target))[mask]
if not reduce:
return loss
if size_average:
return torch.mean(loss)
else:
return torch.sum(loss)
return torch.sum(loss)


kl_div = _add_docstr(torch._C._nn.kl_div, r"""
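Not part of the diff: a minimal sketch of what the new reduce flag does at the functional level, assuming a PyTorch build that includes this change (tensor shapes and values are illustrative only):

import torch
import torch.nn.functional as F

# Log-rate predictions (log_input=True is the default) and non-negative count targets.
log_rates = torch.randn(4, 5)
targets = torch.randn(4, 5).floor_().abs_()

# Default behaviour: a single scalar, averaged over elements
# (or summed when size_average=False).
scalar_loss = F.poisson_nll_loss(log_rates, targets)

# reduce=False: the unreduced element-wise loss exp(input) - target * input;
# size_average is ignored in this case.
elementwise = F.poisson_nll_loss(log_rates, targets, reduce=False)
assert elementwise.size() == log_rates.size()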
13 changes: 9 additions & 4 deletions torch/nn/modules/loss.py
@@ -208,6 +208,10 @@ class PoissonNLLLoss(_Loss):
is set to ``False``, the losses are instead summed for each minibatch.
eps (float, optional): Small value to avoid evaluation of log(0) when
log_input==``False``. Default: 1e-8
reduce (bool, optional): By default, the losses are averaged
over observations for each minibatch, or summed, depending on
size_average. When reduce is ``False``, returns a loss per batch
element instead and ignores size_average. Default: ``True``

Examples::

@@ -217,16 +221,17 @@ class PoissonNLLLoss(_Loss):
>>> output = loss(log_input, target)
>>> output.backward()
"""
def __init__(self, log_input=True, full=False, size_average=True, eps=1e-8):
super(PoissonNLLLoss, self).__init__()
def __init__(self, log_input=True, full=False, size_average=True, eps=1e-8, reduce=True):
super(PoissonNLLLoss, self).__init__(size_average)
self.log_input = log_input
self.full = full
self.size_average = size_average
self.eps = eps
self.reduce = reduce

def forward(self, log_input, target):
_assert_no_grad(target)
return F.poisson_nll_loss(log_input, target, self.log_input, self.full, self.size_average, self.eps)
return F.poisson_nll_loss(log_input, target, self.log_input, self.full,
self.size_average, self.eps, self.reduce)


class KLDivLoss(_Loss):
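And a rough usage sketch of the module-level counterpart, mirroring the docstring's Examples block (the Variable wrapper follows the existing example and the API of the codebase at the time of this change):

import torch
from torch import nn
from torch.autograd import Variable

loss = nn.PoissonNLLLoss(reduce=False)
log_input = Variable(torch.randn(5, 2), requires_grad=True)
target = Variable(torch.randn(5, 2))

output = loss(log_input, target)  # same shape as log_input: one loss value per element
output.sum().backward()           # reduce the per-element losses manually before backward()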