Commit bf94d10

use opmath_t in jiterator
1 parent 2603e07 commit bf94d10

File tree

aten/src/ATen/cuda/llvm_complex.cpp
aten/src/ATen/native/cuda/Lerp.cu

2 files changed: +18 −10 lines changed
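The commit message is terse, so one piece of context helps: at::opmath_type<T> (declared in ATen/OpMathType.h) names the type ATen uses for intermediate arithmetic, widening reduced-precision storage types while leaving full-precision ones unchanged. The standalone trait below is only an illustrative re-creation of that idea, with a hypothetical Half16 stand-in for at::Half; it is not ATen's actual definition.

#include <complex>
#include <cstdint>
#include <type_traits>

// Hypothetical 16-bit storage type standing in for at::Half.
struct Half16 { std::uint16_t bits; };

// Illustrative opmath-style trait: storage dtype -> type used for math.
template <typename T> struct OpMath { using type = T; };      // default: math in the storage type
template <> struct OpMath<Half16>   { using type = float; };  // half precision widens to float
template <> struct OpMath<std::complex<Half16>> {             // complex<half> widens componentwise
  using type = std::complex<float>;
};

template <typename T>
using opmath_t = typename OpMath<T>::type;

static_assert(std::is_same_v<opmath_t<double>, double>, "full precision is unchanged");
static_assert(std::is_same_v<opmath_t<Half16>, float>, "half is widened for intermediate math");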

aten/src/ATen/cuda/llvm_complex.cpp

Lines changed: 1 addition & 1 deletion

@@ -874,7 +874,7 @@ lerp(const complex<_Tp>& self_val, const complex<_Tp>& end_val, const complex<_T
          ? self_val.real() + weight_val.real() * (end_val.real() - self_val.real())
          : end_val.real() -
              (end_val.real() - self_val.real()) * (static_cast<_Tp>(1) - weight_val.real()))
-     + ((std::abs(weight_val.imag()) < 0.5)
+     + ((std::abs(weight_val.imag()) < 0.5)
          ? self_val.imag() + weight_val.imag() * (end_val.imag() - self_val.imag())
          : end_val.imag() -
              (end_val.imag() - self_val.imag()) * (static_cast<_Tp>(1) - weight_val.imag()));
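For reference, the hunk above shows the branch-select lerp formula that llvm_complex.cpp applies independently to the real and imaginary components: a + w*(b - a) when |w| < 0.5, and b - (b - a)*(1 - w) otherwise, so each endpoint is reproduced exactly. A minimal sketch of that selection for one real component (names are illustrative, not taken from the file):

#include <cmath>
#include <cstdio>

// Branch-select lerp for one real component, mirroring the formula in the
// hunk above. For weights in [0, 1] the factor multiplying (end - self)
// stays at or below 0.5, and weight == 0 / weight == 1 reproduce the
// endpoints exactly.
template <typename T>
T lerp_component(T self, T end, T weight) {
  return (std::abs(weight) < T(0.5))
      ? self + weight * (end - self)
      : end - (end - self) * (T(1) - weight);
}

int main() {
  std::printf("%.17g\n", lerp_component(0.1, 0.3, 0.25)); // first branch
  std::printf("%.17g\n", lerp_component(0.1, 0.3, 1.0));  // second branch: exactly 0.3
  return 0;
}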

aten/src/ATen/native/cuda/Lerp.cu

Lines changed: 17 additions & 9 deletions

@@ -26,6 +26,7 @@ void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
   }
   ); // lerp_tensor_string
   AT_DISPATCH_COMPLEX_TYPES(dtype, "lerp_cuda", [&] {
+    using opmath_t = at::opmath_type<scalar_t>;
     jitted_gpu_kernel<
         /*name=*/ lerp_tensor_name,
         /*return_dtype=*/ scalar_t,
@@ -34,16 +35,20 @@ void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
   });
 #else
   AT_DISPATCH_COMPLEX_TYPES(dtype, "lerp_cuda", [&] {
+    using opmath_t = at::opmath_type<scalar_t>;
     at::native::gpu_kernel(
         iter,
         [] GPU_LAMBDA(
             scalar_t self_val,
             scalar_t end_val,
             scalar_t weight_val) -> scalar_t {
-          return (std::abs(weight_val) < 0.5)
-              ? self_val + weight_val * (end_val - self_val)
-              : end_val -
-                  (end_val - self_val) * (static_cast<scalar_t>(1) - weight_val);
+          opmath_t self_val_f = self_val;
+          opmath_t end_val_f = end_val;
+          opmath_t weight_val_f = weight_val;
+          return (std::abs(weight_val_f) < 0.5)
+              ? self_val_f + weight_val_f * (end_val_f - self_val_f)
+              : end_val_f -
+                  (end_val_f - self_val_f) * (opmath_t{1} - weight_val_f);
         });
   });
 #endif
@@ -96,14 +101,17 @@ void lerp_scalar_kernel(at::TensorIteratorBase& iter, const c10::Scalar& weight)
   });
 #else
   AT_DISPATCH_COMPLEX_TYPES(dtype, "lerp_cuda", [&] {
-    auto weight_val = weight.to<scalar_t>();
+    using opmath_t = at::opmath_type<scalar_t>;
+    auto weight_val = weight.to<opmath_t>();
     gpu_kernel(
         iter,
         [=] GPU_LAMBDA(scalar_t self_val, scalar_t end_val) {
-          return (std::abs(weight_val) < 0.5)
-              ? self_val + weight_val * (end_val - self_val)
-              : end_val -
-                  (end_val - self_val) * (static_cast<scalar_t>(1) - weight_val);
+          opmath_t self_val_f = self_val;
+          opmath_t end_val_f = end_val;
+          return (std::abs(weight_val) < 0.5)
+              ? self_val_f + weight_val * (end_val_f - self_val_f)
+              : end_val_f -
+                  (end_val_f - self_val_f) * (opmath_t{1} - weight_val);
         });
   });
 #endif
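Both non-jiterator hunks follow the same pattern: the lambda still receives and returns scalar_t, but each tensor operand is converted to opmath_t before the arithmetic and the result narrows back on return. A minimal host-side sketch of that pattern, using float/double as stand-ins for scalar_t/opmath_t (the function name is illustrative, not from the kernel):

#include <cmath>

using scalar_t = float;   // storage type, as loaded from and stored to memory
using opmath_t = double;  // wider type used only for the intermediate math

scalar_t lerp_widened(scalar_t self_val, scalar_t end_val, scalar_t weight_val) {
  // Widen each operand once, compute in opmath_t, narrow on return.
  opmath_t self_val_f = self_val;
  opmath_t end_val_f = end_val;
  opmath_t weight_val_f = weight_val;
  return (std::abs(weight_val_f) < 0.5)
      ? self_val_f + weight_val_f * (end_val_f - self_val_f)
      : end_val_f - (end_val_f - self_val_f) * (opmath_t{1} - weight_val_f);
}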
