Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 13 additions & 68 deletions aten/src/ATen/native/Lerp.cpp
Original file line number Diff line number Diff line change
@@ -1,111 +1,56 @@
#include <ATen/native/Lerp.h>

#include <ATen/ATen.h>
#include <ATen/CPUApplyUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>

namespace {
// Element-wise linear interpolation with a per-element weight tensor.
// For each element: out = a + w * (b - a), computed from whichever end of
// the interval is closer so the endpoints are reproduced exactly.
template <typename scalar_t>
void lerp_cpu(at::Tensor& ret, const at::Tensor& self, const at::Tensor& end, const at::Tensor& weight) {
  at::CPU_tensor_apply4<scalar_t, scalar_t, scalar_t, scalar_t>(
      ret, self, end, weight,
      [](scalar_t& out,
         const scalar_t& a,
         const scalar_t& b,
         const scalar_t& w) {
        if (w < 0.5) {
          out = a + w * (b - a);
        } else {
          out = b - (b - a) * (1 - w);
        }
      });
}

// Element-wise linear interpolation with a single scalar weight shared by
// all elements: out = a + w * (b - a), evaluated from the nearer endpoint.
template <typename scalar_t>
void lerp_cpu(at::Tensor& ret, const at::Tensor& self, const at::Tensor& end, scalar_t weight_val) {
  at::CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(
      ret, self, end,
      [=](scalar_t& out,
          const scalar_t& a,
          const scalar_t& b) {
        if (weight_val < 0.5) {
          out = a + weight_val * (b - a);
        } else {
          out = b - (b - a) * (1 - weight_val);
        }
      });
}

} // namespace

namespace at {
namespace native {

// out-variant lerp with a tensor weight: result <- lerp(self, end, weight).
// The merged diff left the superseded CPU_tensor_apply path in place, so the
// result was computed twice; only the dispatch-stub kernel call is kept.
// NOTE(review): broadcasting and output resizing are handled inside the
// TensorIterator-based kernel — confirm no explicit resize is needed here.
Tensor& lerp_cpu_tensor_out(Tensor& result, const Tensor& self,
                            const Tensor& end, const Tensor& weight) {
  // weight may broadcast against self/end but must not add dimensions.
  TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
              "weight should be of dimension max(self.dim(), end.dim()) or lesser");
  lerp_kernel_tensor_weight(kCPU, result, self, end, weight);
  return result;
}

// out-variant lerp with a scalar weight: result <- lerp(self, end, weight).
// The merged diff left the old expand/AT_DISPATCH path in front of the new
// kernel call, computing the result twice; only the kernel call is kept.
// NOTE(review): the dispatched kernel builds a TensorIterator over
// (result, self, end), which handles broadcasting/resizing — confirm.
Tensor& lerp_cpu_scalar_out(Tensor& result, const Tensor& self,
                            const Tensor& end, Scalar weight) {
  lerp_kernel_scalar_weight(kCPU, result, self, end, weight);
  return result;
}

// In-place tensor-weight lerp: self <- lerp(self, end, weight).
// The merged diff left the superseded AT_DISPATCH/lerp_cpu compute path in
// front of the new kernel call (double computation); only the validation and
// the dispatch-stub kernel call are kept.
Tensor& lerp_cpu_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) {
  // Compute the broadcast shape solely to validate that the in-place output
  // already has it — an in-place op must never resize `self`.
  Tensor b_self, b_end, b_weight;
  std::tie(b_self, b_end, b_weight) =
      expand_outplace(self, end, weight, "lerp__cpu");
  TORCH_CHECK(b_self.sizes() == self.sizes(),
              "output with shape ", self.sizes(),
              " doesn't match the broadcast shape ", b_self.sizes());
  TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
              "weight should be of dimension max(self.dim(), end.dim()) or lesser");
  lerp_kernel_tensor_weight(kCPU, self, self, end, weight);
  return self;
}

// In-place scalar-weight lerp: self <- lerp(self, end, weight).
// The merged diff left the superseded AT_DISPATCH/lerp_cpu compute path in
// front of the new kernel call; only validation + kernel call are kept.
Tensor& lerp_cpu_scalar_(Tensor& self, const Tensor& end, Scalar weight) {
  // Broadcast check only: an in-place op must never resize `self`.
  Tensor b_self, b_end;
  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp__cpu");
  TORCH_CHECK(b_self.sizes() == self.sizes(),
              "output with shape ", self.sizes(),
              " doesn't match the broadcast shape ", b_self.sizes());
  lerp_kernel_scalar_weight(kCPU, self, self, end, weight);
  return self;
}

// Functional tensor-weight lerp: returns lerp(self, end, weight).
// DEFECT in the merged diff: `result` was declared twice (empty_like and
// empty({0})) — a redeclaration error — and the old AT_DISPATCH path computed
// the value before the kernel recomputed it. Keep one declaration + kernel.
Tensor lerp_cpu_tensor(const Tensor& self, const Tensor& end, const Tensor& weight) {
  // weight may broadcast against self/end but must not add dimensions.
  TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
              "weight should be of dimension max(self.dim(), end.dim()) or lesser");
  // Start from an empty tensor; the TensorIterator-based kernel allocates
  // the broadcast output shape.
  Tensor result = at::empty({0}, self.options());
  lerp_kernel_tensor_weight(kCPU, result, self, end, weight);
  return result;
}

// Functional scalar-weight lerp: returns lerp(self, end, weight).
// DEFECT in the merged diff: `result` was declared twice (empty_like and
// empty({0})) — a redeclaration error. Keep the single empty({0}) allocation;
// the dispatched TensorIterator kernel resizes it to the broadcast shape.
Tensor lerp_cpu_scalar(const Tensor& self, const Tensor& end, Scalar weight) {
  Tensor result = at::empty({0}, self.options());
  lerp_kernel_scalar_weight(kCPU, result, self, end, weight);
  return result;
}

DEFINE_DISPATCH(lerp_kernel_scalar_weight);
DEFINE_DISPATCH(lerp_kernel_tensor_weight);

} // namespace native
} // namespace at
25 changes: 25 additions & 0 deletions aten/src/ATen/native/Lerp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {

// Signature of a lerp kernel whose interpolation weight is a single Scalar
// applied to every element. `ret` is the output tensor.
using lerp_fn_scalar = void (*)(
at::Tensor& ret,
const at::Tensor& self,
const at::Tensor& end,
Scalar weight);

// Signature of a lerp kernel whose interpolation weight is itself a tensor
// (one weight per element after broadcasting).
using lerp_fn_tensor = void (*)(
at::Tensor& ret,
const at::Tensor& self,
const at::Tensor& end,
const at::Tensor& weights);

// Per-backend dispatch stubs; the CPU implementations are registered via
// REGISTER_DISPATCH in cpu/LerpKernel.cpp.
DECLARE_DISPATCH(lerp_fn_scalar, lerp_kernel_scalar_weight);
DECLARE_DISPATCH(lerp_fn_tensor, lerp_kernel_tensor_weight);

} // namespace native
} // namespace at
62 changes: 62 additions & 0 deletions aten/src/ATen/native/cpu/LerpKernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include <ATen/ATen.h>

#include <ATen/Dispatch.h>
#include <ATen/native/Lerp.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>

namespace at {
namespace native {
namespace {

// CPU kernel for lerp with a scalar weight, built on TensorIterator.
// The iterator broadcasts `self` and `end` into `ret`; the per-element
// formula picks the algebraic form anchored at the nearer endpoint.
static void lerp_kernel_scalar(
    Tensor& ret,
    const Tensor& self,
    const Tensor& end,
    Scalar weight) {
  auto builder = at::TensorIterator::Builder();
  builder.add_output(ret);
  builder.add_input(self);
  builder.add_input(end);
  auto iter = builder.build();
  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "lerp_kernel_scalar", [&] {
    const scalar_t w = weight.to<scalar_t>();
    at::native::cpu_kernel(*iter, [w](scalar_t a, scalar_t b) {
      if (w < 0.5) {
        return a + w * (b - a);
      }
      return b - (b - a) * (1 - w);
    });
  });
}

// CPU kernel for lerp with a per-element weight tensor, built on
// TensorIterator. `weights` is added as a third input so it broadcasts
// together with `self` and `end`.
static void lerp_kernel_tensor(
    Tensor& ret,
    const Tensor& self,
    const Tensor& end,
    const Tensor& weights) {
  auto builder = at::TensorIterator::Builder();
  builder.add_output(ret);
  builder.add_input(self);
  builder.add_input(end);
  builder.add_input(weights);
  auto iter = builder.build();
  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "lerp_kernel_tensor", [&] {
    at::native::cpu_kernel(*iter, [](scalar_t a, scalar_t b, scalar_t w) {
      if (w < 0.5) {
        return a + w * (b - a);
      }
      return b - (b - a) * (1 - w);
    });
  });
}

} // anonymous namespace

REGISTER_DISPATCH(lerp_kernel_scalar_weight, &lerp_kernel_scalar);
REGISTER_DISPATCH(lerp_kernel_tensor_weight, &lerp_kernel_tensor);

} // namespace native
} // namespace at