Skip to content

Commit a17c011

Browse files
Ailing Zhang authored and facebook-github-bot committed
fix stability in bce with pos_weight formula (#13863)
Summary: Fixes #13773.
Pull Request resolved: #13863
Differential Revision: D13031803
Pulled By: ailzhang
fbshipit-source-id: 6c9e044f0450eebf4555bbc02c125713d9378e2f
1 parent 0bfbdca commit a17c011

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

aten/src/ATen/native/Loss.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& targe
9393
if (pos_weight.defined()) {
9494
// pos_weight need to be broadcasted, thus mul(target) is not inplace.
9595
auto log_weight = (pos_weight - 1).mul(target).add_(1);
96-
loss = (1 - target).mul_(input).add_(log_weight.mul_((-max_val).exp_().mul_(1 + (-input).exp_()).log_().add_(max_val)));
96+
loss = (1 - target).mul_(input).add_(log_weight.mul_(((-max_val).exp_().add_((-input - max_val).exp_())).log_().add_(max_val)));
9797
} else {
9898
loss = (1 - target).mul_(input).add_(max_val).add_((-max_val).exp_().add_((-input -max_val).exp_()).log_());
9999
}

test/test_nn.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4939,6 +4939,17 @@ def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
49394939
grad = output.grad
49404940
self.assertEqual(grad, expected_grad)
49414941

4942+
def test_bce_with_logits_stability(self):
4943+
output = torch.tensor([0., -120.])
4944+
target = torch.tensor([0., 1.])
4945+
pos_weight = torch.tensor([1., 1.])
4946+
4947+
out1 = nn.BCEWithLogitsLoss()(output, target)
4948+
self.assertTrue(torch.isfinite(out1).all().item())
4949+
4950+
out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
4951+
self.assertTrue(torch.isfinite(out2).all().item())
4952+
49424953
def test_bce_loss_broadcasts_weights(self):
49434954
sigmoid = nn.Sigmoid()
49444955
target = torch.rand(16, 4)

0 commit comments

Comments
 (0)