Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r
raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim))


def poisson_nll_loss(input, target, log_input=True, full=False, size_average=True):
def poisson_nll_loss(input, target, log_input=True, full=False, size_average=True, eps=1e-8):
r"""Poisson negative log likelihood loss.

See :class:`~torch.nn.PoissonNLLLoss` for details.
Expand All @@ -934,18 +934,20 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Tru
target: random sample :math:`target \sim Pois(input)`.
log_input: if True the loss is computed as
`exp(input) - target * input`, if False then loss is
`input - target * log(input)`. Default: True
`input - target * log(input+eps)`. Default: True
full: whether to compute full loss, i. e. to add the Stirling
approximation term. Default: False
`target * log(target) - target + 0.5 * log(2 * pi * target)`.
size_average: By default, the losses are averaged over observations for
each minibatch. However, if the field sizeAverage is set to False,
the losses are instead summed for each minibatch. Default: True
eps (float, optional): Small value to avoid evaluation of log(0) when
log_input=False. Default: 1e-8
"""
if log_input:
loss = torch.exp(input) - target * input
else:
loss = input - target * torch.log(input)
loss = input - target * torch.log(input + eps)
if full:
mask = target > 1
loss[mask] += (target * torch.log(target) - target + 0.5 * torch.log(2 * math.pi * target))[mask]
Expand Down
9 changes: 6 additions & 3 deletions torch/nn/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,15 @@ class PoissonNLLLoss(_Loss):
Args:
log_input (bool, optional): if True the loss is computed as
`exp(input) - target * input`, if False the loss is
`input - target * log(input)`.
`input - target * log(input+eps)`.
full (bool, optional): whether to compute full loss, i. e. to add the
Stirling approximation term
`target * log(target) - target + 0.5 * log(2 * pi * target)`.
size_average (bool, optional): By default, the losses are averaged over
observations for each minibatch. However, if the field size_average
is set to False, the losses are instead summed for each minibatch.
eps (float, optional): Small value to avoid evaluation of log(0) when
log_input=False. Default: 1e-8

Examples::

Expand All @@ -206,15 +208,16 @@ class PoissonNLLLoss(_Loss):
>>> output = loss(log_input, target)
>>> output.backward()
"""
def __init__(self, log_input=True, full=False, size_average=True, eps=1e-8):
    """Configure the Poisson NLL loss.

    Args:
        log_input (bool, optional): if True the loss is computed as
            ``exp(input) - target * input``; if False as
            ``input - target * log(input + eps)``. Default: True
        full (bool, optional): whether to add the Stirling approximation
            term to the loss. Default: False
        size_average (bool, optional): average (True) or sum (False) the
            per-element losses over the minibatch. Default: True
        eps (float, optional): small constant added inside the log to
            avoid evaluating log(0) when ``log_input=False``. Default: 1e-8
    """
    super(PoissonNLLLoss, self).__init__()
    # Stash the configuration; forward() forwards it to the functional op.
    self.eps = eps
    self.size_average = size_average
    self.full = full
    self.log_input = log_input

def forward(self, log_input, target):
    """Compute the Poisson negative log likelihood loss.

    Args:
        log_input: predicted (log-)rate of the Poisson distribution.
        target: observed counts, ``target ~ Pois(input)``; must not
            require gradients.

    Returns:
        The loss reduced according to ``self.size_average``.
    """
    _assert_no_grad(target)
    # Delegate to the functional implementation with the stored options.
    loss = F.poisson_nll_loss(
        log_input,
        target,
        self.log_input,
        self.full,
        self.size_average,
        self.eps,
    )
    return loss


class KLDivLoss(_Loss):
Expand Down