@@ -11,14 +11,14 @@ namespace at {
1111namespace native {
1212namespace {
1313
14- const char lerp_tensor_name[] = " lerp_tensor_kernel " ;
14+ const char lerp_tensor_name[] = " lerp_tensor " ;
1515void lerp_tensor_kernel (at::TensorIteratorBase& iter) {
1616 auto dtype = iter.common_dtype ();
1717 if (at::isComplexType (dtype)) {
1818#if AT_USE_JITERATOR()
1919 static const auto lerp_tensor_string = jiterator_stringify (
2020 template <typename T>
21- T lerp_tensor_kernel (T self_val, T end_val, T weight_val) {
21+ T lerp_tensor (T self_val, T end_val, T weight_val) {
2222 return (std::abs (weight_val) < 0.5 )
2323 ? self_val + weight_val * (end_val - self_val)
2424 : end_val -
@@ -78,14 +78,14 @@ void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
7878 }
7979}
8080
81- const char lerp_scalar_name[] = " lerp_scalar_kernel " ;
81+ const char lerp_scalar_name[] = " lerp_scalar " ;
8282void lerp_scalar_kernel (at::TensorIteratorBase& iter, const c10::Scalar& weight) {
8383 auto dtype = iter.common_dtype ();
8484 if (at::isComplexType (dtype)) {
85- #if false // AT_USE_JITERATOR()
85+ #if AT_USE_JITERATOR()
8686 static const auto lerp_scalar_string = jiterator_stringify (
8787 template <typename T>
88- T lerp_scalar_kernel (T self_val, T end_val, int weight) {
88+ T lerp_scalar (T self_val, T end_val, float weight) {
8989 auto weight_val = weight.to <T>();
9090 return (std::abs (weight_val) < 0.5 )
9191 ? self_val + weight_val * (end_val - self_val)
@@ -103,7 +103,7 @@ void lerp_scalar_kernel(at::TensorIteratorBase& iter, const c10::Scalar& weight)
103103 lerp_scalar_string,
104104 /* scalar_pos=*/ at::cuda::jit::BinaryFuncVariant::NoScalar,
105105 /* scalar_val=*/ 0 ,
106- /* extra_args=*/ std::make_tuple (weight));
106+ /* extra_args=*/ std::make_tuple (weight. to < scalar_t >() ));
107107 });
108108#else
109109 AT_DISPATCH_COMPLEX_TYPES_AND (kComplexHalf , dtype, " lerp_cuda" , [&] {
0 commit comments