Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions aten/src/ATen/native/Loss.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
// define constants like M_PI and C keywords for MSVC
#ifdef _MSC_VER
#define _USE_MATH_DEFINES
#include <math.h>
#endif
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/CPUApplyUtils.h>

#define EPSILON 1e-12
#define _USE_MATH_DEFINES

namespace {
static inline at::Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
Expand Down Expand Up @@ -125,4 +131,21 @@ Tensor binary_cross_entropy_with_logits_backward(const Tensor& grad, const Tenso

return grad_input;
}

Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool log_input, const bool full, const double eps, const int64_t reduction)
{
Tensor loss;
if (log_input) {
loss = at::exp(input) - target * input;
} else {
loss = input - target * at::log(input + eps);
}

if (full) {
auto mask1 = (target > 1);
loss.masked_select(mask1) += (target * at::log(target) - target + 0.5 * at::log(2 * M_PI * target)).masked_select(mask1);
}

return apply_loss_reduction(loss, reduction);
}
}} // namespace at::native
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,9 @@
- func: pinverse(Tensor self, float rcond=1e-15) -> Tensor
variants: function, method

- func: poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor
variants: function

- func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

- func: rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
Expand Down
17 changes: 3 additions & 14 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1930,22 +1930,11 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non
"""
if size_average is not None or reduce is not None:
reduction = _Reduction.legacy_get_string(size_average, reduce)
if log_input:
loss = torch.exp(input) - target * input
else:
loss = input - target * torch.log(input + eps)
if full:
mask = target > 1
loss[mask] += (target * torch.log(target) - target + 0.5 * torch.log(2 * math.pi * target))[mask]
if reduction == 'none':
ret = loss
elif reduction == 'mean':
ret = torch.mean(loss)
elif reduction == 'sum':
ret = torch.sum(loss)
else:
if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
ret = input
raise ValueError(reduction + " is not valid")

ret = torch.poisson_nll_loss(input, target, log_input, full, eps, _Reduction.get_enum(reduction))
return ret


Expand Down