Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions aten/src/THCUNN/BCECriterion.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ inline __host__ __device__ float eps() { return 1e-12f; }
template <>
inline __host__ __device__ double eps() { return 1e-12; }

template <typename T>
inline __host__ __device__ T safe_log(T a) {
if (a == 0.)
{
return THCNumerics<T>::log(eps<T>());
}
return THCNumerics<T>::log(a);
}

template <typename Dtype, typename Acctype>
struct bce_functor
{
Expand All @@ -31,7 +40,8 @@ struct bce_functor
Dtype input = thrust::get<0>(x);
Dtype t = thrust::get<1>(x);
assert(input >= 0. && input <= 1.);
return - (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1)- t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
return - (t * safe_log<Acctype>(ScalarConvert<Dtype, Acctype>::to(input))
+ (Acctype(1) - t) * safe_log<Acctype>(Acctype(1) - input));
}
};

Expand All @@ -46,8 +56,8 @@ struct bce_updateOutput_no_reduce_functor
{
assert(*input >= 0. && *input <= 1.);
*output = ScalarConvert<Acctype, Dtype>::to(
-(*target * THCNumerics<Acctype>::log(*input + eps<Acctype>()) +
(Acctype(1) - *target) * THCNumerics<Acctype>::log(Acctype(1) - *input + eps<Acctype>())));
-(*target * safe_log<Acctype>(ScalarConvert<Dtype, Acctype>::to(*input)) +
(Acctype(1) - *target) * safe_log<Acctype>(Acctype(1) - *input)));
}
};

Expand All @@ -62,8 +72,8 @@ struct bce_functor_weights
Dtype t = thrust::get<1>(x);
Dtype w = thrust::get<2>(x);
assert(input >= 0. && input <= 1.);
return - w * (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) +
(Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
return - w * (t * safe_log<Acctype>(ScalarConvert<Dtype, Acctype>::to(input)) +
(Acctype(1) - t) * safe_log<Acctype>(Acctype(1) - input));
}
};

Expand Down
13 changes: 10 additions & 3 deletions aten/src/THNN/generic/BCECriterion.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@

#define EPS 1e-12

static inline real safe_log(real a) {
if (a == 0.) {
return log(EPS);
}
return log(a);

This comment was marked as off-topic.

This comment was marked as off-topic.

}

void THNN_(BCECriterion_updateOutput)(
THNNState *state,
THTensor *input,
Expand All @@ -24,7 +31,7 @@ void THNN_(BCECriterion_updateOutput)(
THAssertMsg(x >= 0. && x <= 1.,
"input value should be between 0~1, but got %f",
(double) x);
*output_data = -(log(x + EPS) * y + log(1. - x + EPS) * (1. - y));
*output_data = -(safe_log(x) * y + safe_log(1. - x) * (1. - y));
);
if (weights) {
THTensor_(cmul)(output, output, weights);
Expand All @@ -43,7 +50,7 @@ void THNN_(BCECriterion_updateOutput)(
THAssertMsg(x >= 0. && x <= 1.,
"input value should be between 0~1, but got %f",
(double) x);
sum -= (log(x + EPS) * y + log(1. - x + EPS) * (1. - y)) * w;
sum -= (safe_log(x) * y + safe_log(1. - x) * (1. - y)) * w;
);
} else {
TH_TENSOR_APPLY2(real, input, real, target,
Expand All @@ -52,7 +59,7 @@ void THNN_(BCECriterion_updateOutput)(
THAssertMsg(x >= 0. && x <= 1.,
"input value should be between 0~1, but got %f",
(double) x);
sum -= log(x + EPS) * y + log(1. - x + EPS) * (1. - y);
sum -= safe_log(x) * y + safe_log(1. - x) * (1. - y);
);
}

Expand Down
9 changes: 9 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4116,6 +4116,15 @@ def func(root):
gradcheck(func, [v])
gradgradcheck(func, [v])

def test_bce_loss_always_nonnegative(self):
target = torch.ones(5)
input = torch.ones(5)
self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

target = torch.zeros(5)
input = torch.zeros(5)
self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

def test_bce_with_logits_raises_if_target_and_input_are_different_size(self):
target = torch.rand(5)
input = torch.rand(5, 1)
Expand Down