Skip to content

Commit d423163

Browse files
committed
get c++ type out of weight
1 parent b4e465f commit d423163

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

aten/src/ATen/native/cuda/Lerp.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ namespace at {
1111
namespace native {
1212
namespace {
1313

14-
const char lerp_tensor_name[] = "lerp_tensor_kernel";
14+
const char lerp_tensor_name[] = "lerp_tensor";
1515
void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
1616
auto dtype = iter.common_dtype();
1717
if(at::isComplexType(dtype)) {
1818
#if AT_USE_JITERATOR()
1919
static const auto lerp_tensor_string = jiterator_stringify(
2020
template <typename T>
21-
T lerp_tensor_kernel(T self_val, T end_val, T weight_val) {
21+
T lerp_tensor(T self_val, T end_val, T weight_val) {
2222
return (std::abs(weight_val) < 0.5)
2323
? self_val + weight_val * (end_val - self_val)
2424
: end_val -
@@ -78,14 +78,14 @@ void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
7878
}
7979
}
8080

81-
const char lerp_scalar_name[] = "lerp_scalar_kernel";
81+
const char lerp_scalar_name[] = "lerp_scalar";
8282
void lerp_scalar_kernel(at::TensorIteratorBase& iter, const c10::Scalar& weight) {
8383
auto dtype = iter.common_dtype();
8484
if (at::isComplexType(dtype)) {
85-
#if false // AT_USE_JITERATOR()
85+
#if AT_USE_JITERATOR()
8686
static const auto lerp_scalar_string = jiterator_stringify(
8787
template <typename T>
88-
T lerp_scalar_kernel(T self_val, T end_val, int weight) {
88+
T lerp_scalar(T self_val, T end_val, float weight) {
8989
auto weight_val = weight.to<T>();
9090
return (std::abs(weight_val) < 0.5)
9191
? self_val + weight_val * (end_val - self_val)
@@ -103,7 +103,7 @@ void lerp_scalar_kernel(at::TensorIteratorBase& iter, const c10::Scalar& weight)
103103
lerp_scalar_string,
104104
/*scalar_pos=*/ at::cuda::jit::BinaryFuncVariant::NoScalar,
105105
/*scalar_val=*/ 0,
106-
/*extra_args=*/ std::make_tuple(weight));
106+
/*extra_args=*/ std::make_tuple(weight.to<scalar_t>()));
107107
});
108108
#else
109109
AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] {

0 commit comments

Comments
 (0)