@@ -26,6 +26,7 @@ void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
2626 }
2727 ); // lerp_tensor_string
2828 AT_DISPATCH_COMPLEX_TYPES (dtype, " lerp_cuda" , [&] {
29+ using opmath_t = at::opmath_type<scalar_t >;
2930 jitted_gpu_kernel<
3031 /* name=*/ lerp_tensor_name,
3132 /* return_dtype=*/ scalar_t ,
@@ -34,16 +35,20 @@ void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
3435 });
3536#else
3637 AT_DISPATCH_COMPLEX_TYPES (dtype, " lerp_cuda" , [&] {
38+ using opmath_t = at::opmath_type<scalar_t >;
3739 at::native::gpu_kernel (
3840 iter,
3941 [] GPU_LAMBDA (
4042 scalar_t self_val,
4143 scalar_t end_val,
4244 scalar_t weight_val) -> scalar_t {
43- return (std:abs (weight_val) < 0.5 )
44- ? self_val + weight_val * (end_val - self_val)
45- : end_val -
46- (end_val - self_val) * (static_cast <scalar_t >(1 ) - weight_val);
45+ opmath_t self_val_f = self_val;
46+ opmath_t end_val_f = end_val;
47+ opmath_t weight_val_f = weight_val;
48+ return (std:abs (weight_val_f) < 0.5 )
49+ ? self_val_f + weight_val_f * (end_val_f - self_val_f)
50+ : end_val_f -
51+ (end_val_f - self_val_f) * (opmath_t {1 } - weight_val_f);
4752 });
4853 });
4954#endif
@@ -96,14 +101,17 @@ void lerp_scalar_kernel(at::TensorIteratorBase& iter, const c10::Scalar& weight)
96101 });
97102#else
98103 AT_DISPATCH_COMPLEX_TYPES (dtype, " lerp_cuda" , [&] {
99- auto weight_val = weight.to <scalar_t >();
104+ using opmath_t = at::opmath_type<scalar_t >;
105+ auto weight_val = weight.to <opmath_t >();
100106 gpu_kernel (
101107 iter,
102108 [=] GPU_LAMBDA (scalar_t self_val, scalar_t end_val) {
103- return (std::abs (weight_val) < 0.5 )
104- ? self_val + weight_val * (end_val - self_val)
105- : end_val -
106- (end_val - self_val) * (static_cast <scalar_t >(1 ) - weight_val);
109+ opmath_t self_val_f = self_val;
110+ opmath_t end_val_f = end_val;
111+ return (std::abs (weight_val_f) < 0.5 )
112+ ? self_val_f + weight_val_f * (end_val_f - self_val_f)
113+ : end_val_f -
114+ (end_val_f - self_val_f) * (opmath_t {1 } - weight_val_f);
107115 });
108116 });
109117#endif
0 commit comments