Skip to content

Commit 73966f6

Browse files
li-roysoumith
authored and committed
Stop BCELoss from returning negative results (#8147)
* Stop BCELoss from returning negative results
* check explicitly for 0 before taking log
* add tests
* fix lint
* address comments
1 parent e2be77e commit 73966f6

File tree

3 files changed

+34
-8
lines changed

3 files changed

+34
-8
lines changed

aten/src/THCUNN/BCECriterion.cu

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ inline __host__ __device__ float eps() { return 1e-12f; }
2121
template <>
2222
inline __host__ __device__ double eps() { return 1e-12; }
2323

24+
// Logarithm helper called by the BCE functors in this file.
// Returns THCNumerics<T>::log(a), except when `a` compares equal to zero, in
// which case it returns THCNumerics<T>::log(eps<T>()) rather than evaluating
// the logarithm at 0. Only an exact zero is special-cased; every other value
// is passed to log unchanged (no epsilon is added to it).
template <typename T>
25+
inline __host__ __device__ T safe_log(T a) {
26+
if (a == 0.)
27+
{
28+
return THCNumerics<T>::log(eps<T>());
29+
}
30+
return THCNumerics<T>::log(a);
31+
}
32+
2433
template <typename Dtype, typename Acctype>
2534
struct bce_functor
2635
{
@@ -31,7 +40,8 @@ struct bce_functor
3140
Dtype input = thrust::get<0>(x);
3241
Dtype t = thrust::get<1>(x);
3342
assert(input >= 0. && input <= 1.);
34-
return - (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1)- t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
43+
return - (t * safe_log<Acctype>(ScalarConvert<Dtype, Acctype>::to(input))
44+
+ (Acctype(1) - t) * safe_log<Acctype>(Acctype(1) - input));
3545
}
3646
};
3747

@@ -46,8 +56,8 @@ struct bce_updateOutput_no_reduce_functor
4656
{
4757
assert(*input >= 0. && *input <= 1.);
4858
*output = ScalarConvert<Acctype, Dtype>::to(
49-
-(*target * THCNumerics<Acctype>::log(*input + eps<Acctype>()) +
50-
(Acctype(1) - *target) * THCNumerics<Acctype>::log(Acctype(1) - *input + eps<Acctype>())));
59+
-(*target * safe_log<Acctype>(ScalarConvert<Dtype, Acctype>::to(*input)) +
60+
(Acctype(1) - *target) * safe_log<Acctype>(Acctype(1) - *input)));
5161
}
5262
};
5363

@@ -62,8 +72,8 @@ struct bce_functor_weights
6272
Dtype t = thrust::get<1>(x);
6373
Dtype w = thrust::get<2>(x);
6474
assert(input >= 0. && input <= 1.);
65-
return - w * (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) +
66-
(Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
75+
return - w * (t * safe_log<Acctype>(ScalarConvert<Dtype, Acctype>::to(input)) +
76+
(Acctype(1) - t) * safe_log<Acctype>(Acctype(1) - input));
6777
}
6878
};
6979

aten/src/THNN/generic/BCECriterion.c

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44

55
#define EPS 1e-12
66

7+
/* Logarithm helper for the BCE loss expressions in this file.
 * Returns log(a), except when `a` compares equal to zero, in which case it
 * returns log(EPS) (EPS is the 1e-12 macro defined above) rather than
 * evaluating log at 0. Only an exact zero is special-cased; every other
 * value is passed to log unchanged (no epsilon is added to it). */
static inline real safe_log(real a) {
8+
if (a == 0.) {
9+
return log(EPS);
10+
}
11+
return log(a);
12+
}
13+
714
void THNN_(BCECriterion_updateOutput)(
815
THNNState *state,
916
THTensor *input,
@@ -24,7 +31,7 @@ void THNN_(BCECriterion_updateOutput)(
2431
THAssertMsg(x >= 0. && x <= 1.,
2532
"input value should be between 0~1, but got %f",
2633
(double) x);
27-
*output_data = -(log(x + EPS) * y + log(1. - x + EPS) * (1. - y));
34+
*output_data = -(safe_log(x) * y + safe_log(1. - x) * (1. - y));
2835
);
2936
if (weights) {
3037
THTensor_(cmul)(output, output, weights);
@@ -43,7 +50,7 @@ void THNN_(BCECriterion_updateOutput)(
4350
THAssertMsg(x >= 0. && x <= 1.,
4451
"input value should be between 0~1, but got %f",
4552
(double) x);
46-
sum -= (log(x + EPS) * y + log(1. - x + EPS) * (1. - y)) * w;
53+
sum -= (safe_log(x) * y + safe_log(1. - x) * (1. - y)) * w;
4754
);
4855
} else {
4956
TH_TENSOR_APPLY2(real, input, real, target,
@@ -52,7 +59,7 @@ void THNN_(BCECriterion_updateOutput)(
5259
THAssertMsg(x >= 0. && x <= 1.,
5360
"input value should be between 0~1, but got %f",
5461
(double) x);
55-
sum -= log(x + EPS) * y + log(1. - x + EPS) * (1. - y);
62+
sum -= safe_log(x) * y + safe_log(1. - x) * (1. - y);
5663
);
5764
}
5865

test/test_nn.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4116,6 +4116,15 @@ def func(root):
41164116
gradcheck(func, [v])
41174117
gradgradcheck(func, [v])
41184118

4119+
# Regression check that nn.BCELoss() yields no negative values in two cases:
# input and target both all ones, and input and target both all zeros.
# Each case counts the loss entries that are < 0 and asserts the count is 0.
def test_bce_loss_always_nonnegative(self):
4120+
target = torch.ones(5)
4121+
input = torch.ones(5)
4122+
self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
4123+
4124+
target = torch.zeros(5)
4125+
input = torch.zeros(5)
4126+
self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
4127+
41194128
def test_bce_with_logits_raises_if_target_and_input_are_different_size(self):
41204129
target = torch.rand(5)
41214130
input = torch.rand(5, 1)

0 commit comments

Comments
 (0)