@@ -21,6 +21,15 @@ inline __host__ __device__ float eps() { return 1e-12f; }
2121template <>
2222inline __host__ __device__ double eps () { return 1e-12 ; }
2323
24+ template <typename T>
25+ inline __host__ __device__ T safe_log (T a) {
26+ if (a == 0 .)
27+ {
28+ return THCNumerics<T>::log (eps<T>());
29+ }
30+ return THCNumerics<T>::log (a);
31+ }
32+
2433template <typename Dtype, typename Acctype>
2534struct bce_functor
2635{
@@ -31,7 +40,8 @@ struct bce_functor
3140 Dtype input = thrust::get<0 >(x);
3241 Dtype t = thrust::get<1 >(x);
3342 assert (input >= 0 . && input <= 1 .);
34- return - (t * THCNumerics<Acctype>::log (input + eps<Acctype>()) + (Acctype (1 )- t) * THCNumerics<Acctype>::log (Acctype (1 ) - input + eps<Acctype>()));
43+ return - (t * safe_log<Acctype>(ScalarConvert<Dtype, Acctype>::to (input))
44+ + (Acctype (1 ) - t) * safe_log<Acctype>(Acctype (1 ) - input));
3545 }
3646};
3747
@@ -46,8 +56,8 @@ struct bce_updateOutput_no_reduce_functor
4656 {
4757 assert (*input >= 0 . && *input <= 1 .);
4858 *output = ScalarConvert<Acctype, Dtype>::to (
49- -(*target * THCNumerics <Acctype>:: log (*input + eps< Acctype>( )) +
50- (Acctype (1 ) - *target) * THCNumerics <Acctype>:: log (Acctype (1 ) - *input + eps<Acctype>() )));
59+ -(*target * safe_log <Acctype>(ScalarConvert<Dtype, Acctype>:: to (*input )) +
60+ (Acctype (1 ) - *target) * safe_log <Acctype>(Acctype (1 ) - *input)));
5161 }
5262};
5363
@@ -62,8 +72,8 @@ struct bce_functor_weights
6272 Dtype t = thrust::get<1 >(x);
6373 Dtype w = thrust::get<2 >(x);
6474 assert (input >= 0 . && input <= 1 .);
65- return - w * (t * THCNumerics <Acctype>:: log (input + eps< Acctype>( )) +
66- (Acctype (1 ) - t) * THCNumerics <Acctype>:: log (Acctype (1 ) - input + eps<Acctype>() ));
75+ return - w * (t * safe_log <Acctype>(ScalarConvert<Dtype, Acctype>:: to (input )) +
76+ (Acctype (1 ) - t) * safe_log <Acctype>(Acctype (1 ) - input));
6777 }
6878};
6979
0 commit comments