Skip to content

Commit 35fed93

Browse files
interesaaatfacebook-github-bot
authored andcommitted
Adding Poisson NLL loss to libtorch (#19316)
Summary: This PR add Poisson NLL loss to aten and substitute the python implementation with a call to the c++. Fixes #19186. Pull Request resolved: #19316 Differential Revision: D15012957 Pulled By: ezyang fbshipit-source-id: 0a3f56e8307969c2f9cc321b5357a496c3d1784e
1 parent ed25b8a commit 35fed93

File tree

3 files changed

+29
-14
lines changed

3 files changed

+29
-14
lines changed

aten/src/ATen/native/Loss.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
// define constants like M_PI and C keywords for MSVC
2+
#ifdef _MSC_VER
3+
#define _USE_MATH_DEFINES
4+
#include <math.h>
5+
#endif
16
#include <ATen/ATen.h>
27
#include <ATen/NativeFunctions.h>
38
#include <ATen/Dispatch.h>
49
#include <ATen/CPUApplyUtils.h>
510

611
#define EPSILON 1e-12
12+
#define _USE_MATH_DEFINES
713

814
namespace {
915
static inline at::Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
@@ -125,4 +131,21 @@ Tensor binary_cross_entropy_with_logits_backward(const Tensor& grad, const Tenso
125131

126132
return grad_input;
127133
}
134+
135+
Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool log_input, const bool full, const double eps, const int64_t reduction)
136+
{
137+
Tensor loss;
138+
if (log_input) {
139+
loss = at::exp(input) - target * input;
140+
} else {
141+
loss = input - target * at::log(input + eps);
142+
}
143+
144+
if (full) {
145+
auto mask1 = (target > 1);
146+
loss.masked_select(mask1) += (target * at::log(target) - target + 0.5 * at::log(2 * M_PI * target)).masked_select(mask1);
147+
}
148+
149+
return apply_loss_reduction(loss, reduction);
150+
}
128151
}} // namespace at::native

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,6 +1411,9 @@
14111411
- func: pinverse(Tensor self, float rcond=1e-15) -> Tensor
14121412
variants: function, method
14131413

1414+
- func: poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor
1415+
variants: function
1416+
14141417
- func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
14151418

14161419
- func: rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

torch/nn/functional.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1930,22 +1930,11 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non
19301930
"""
19311931
if size_average is not None or reduce is not None:
19321932
reduction = _Reduction.legacy_get_string(size_average, reduce)
1933-
if log_input:
1934-
loss = torch.exp(input) - target * input
1935-
else:
1936-
loss = input - target * torch.log(input + eps)
1937-
if full:
1938-
mask = target > 1
1939-
loss[mask] += (target * torch.log(target) - target + 0.5 * torch.log(2 * math.pi * target))[mask]
1940-
if reduction == 'none':
1941-
ret = loss
1942-
elif reduction == 'mean':
1943-
ret = torch.mean(loss)
1944-
elif reduction == 'sum':
1945-
ret = torch.sum(loss)
1946-
else:
1933+
if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
19471934
ret = input
19481935
raise ValueError(reduction + " is not valid")
1936+
1937+
ret = torch.poisson_nll_loss(input, target, log_input, full, eps, _Reduction.get_enum(reduction))
19491938
return ret
19501939

19511940

0 commit comments

Comments
 (0)