|
| 1 | +#include <ATen/native/Lerp.h> |
| 2 | + |
1 | 3 | #include <ATen/ATen.h> |
2 | 4 | #include <ATen/CPUApplyUtils.h> |
3 | 5 | #include <ATen/NativeFunctions.h> |
4 | 6 | #include <ATen/Dispatch.h> |
5 | 7 | #include <ATen/ExpandUtils.h> |
6 | 8 |
|
7 | | -namespace { |
8 | | -template <typename scalar_t> |
9 | | -void lerp_cpu(at::Tensor& ret, const at::Tensor& self, const at::Tensor& end, const at::Tensor& weight) { |
10 | | - at::CPU_tensor_apply4<scalar_t, scalar_t, scalar_t, scalar_t>( |
11 | | - ret, self, end, weight, |
12 | | - [](scalar_t& ret_val, |
13 | | - const scalar_t& self_val, |
14 | | - const scalar_t& end_val, |
15 | | - const scalar_t& weight_val) { |
16 | | - ret_val = (weight_val < 0.5) ? |
17 | | - self_val + weight_val * (end_val - self_val) : end_val - (end_val - self_val) * (1 - weight_val); |
18 | | - }); |
19 | | -} |
20 | | - |
21 | | -template <typename scalar_t> |
22 | | -void lerp_cpu(at::Tensor& ret, const at::Tensor& self, const at::Tensor& end, scalar_t weight_val) { |
23 | | - at::CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>( |
24 | | - ret, self, end, |
25 | | - [=](scalar_t& ret_val, |
26 | | - const scalar_t& self_val, |
27 | | - const scalar_t& end_val) { |
28 | | - ret_val = (weight_val < 0.5) ? |
29 | | - self_val + weight_val * (end_val - self_val) : end_val - (end_val - self_val) * (1 - weight_val); |
30 | | - }); |
31 | | -} |
32 | | - |
33 | | -} // namespace |
34 | | - |
35 | 9 | namespace at { |
36 | 10 | namespace native { |
37 | 11 |
|
38 | 12 | Tensor& lerp_cpu_tensor_out(Tensor& result, const Tensor& self, |
39 | 13 | const Tensor& end, const Tensor& weight) { |
40 | | - Tensor b_self, b_end, b_weight; |
41 | 14 | TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()), |
42 | 15 | "weight should be of dimension max(self.dim(), end.dim()) or lesser"); |
43 | | - std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_out_cpu"); |
44 | | - result.resize_as_(b_self); |
45 | | - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_out_cpu", [&]{ |
46 | | - lerp_cpu<scalar_t>(result, b_self, b_end, b_weight); |
47 | | - }); |
| 16 | + lerp_kernel_tensor_weight(kCPU, result, self, end, weight); |
48 | 17 | return result; |
49 | 18 | } |
50 | 19 |
|
51 | 20 | Tensor& lerp_cpu_scalar_out(Tensor& result, const Tensor& self, |
52 | 21 | const Tensor& end, Scalar weight) { |
53 | | - Tensor b_self, b_end; |
54 | | - std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_out_cpu"); |
55 | | - result.resize_as_(b_self); |
56 | | - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_out_cpu", [&]{ |
57 | | - lerp_cpu<scalar_t>(result, b_self, b_end, weight.to<scalar_t>()); |
58 | | - }); |
| 22 | + lerp_kernel_scalar_weight(kCPU, result, self, end, weight); |
59 | 23 | return result; |
60 | 24 | } |
61 | 25 |
|
62 | 26 | Tensor& lerp_cpu_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) { |
63 | | - Tensor b_self, b_end, b_weight; |
64 | | - std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp__cpu"); |
65 | | - TORCH_CHECK(b_self.sizes() == self.sizes(), |
66 | | - "output with shape ", self.sizes(), |
67 | | - " doesn't match the broadcast shape ", b_self.sizes()); |
68 | 27 | TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()), |
69 | 28 | "weight should be of dimension max(self.dim(), end.dim()) or lesser"); |
70 | | - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp__cpu", [&]{ |
71 | | - lerp_cpu<scalar_t>(self, b_self, b_end, b_weight); |
72 | | - }); |
| 29 | + lerp_kernel_tensor_weight(kCPU, self, self, end, weight); |
73 | 30 | return self; |
74 | 31 | } |
75 | 32 |
|
76 | 33 | Tensor& lerp_cpu_scalar_(Tensor& self, const Tensor& end, Scalar weight) { |
77 | | - Tensor b_self, b_end; |
78 | | - std::tie(b_self, b_end) = expand_outplace(self, end, "lerp__cpu"); |
79 | | - TORCH_CHECK(b_self.sizes() == self.sizes(), |
80 | | - "output with shape ", self.sizes(), |
81 | | - " doesn't match the broadcast shape ", b_self.sizes()); |
82 | | - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp__cpu", [&]{ |
83 | | - lerp_cpu<scalar_t>(self, b_self, b_end, weight.to<scalar_t>()); |
84 | | - }); |
| 34 | + lerp_kernel_scalar_weight(kCPU, self, self, end, weight); |
85 | 35 | return self; |
86 | 36 | } |
87 | 37 |
|
88 | 38 | Tensor lerp_cpu_tensor(const Tensor& self, const Tensor& end, const Tensor& weight) { |
89 | | - Tensor b_self, b_end, b_weight; |
90 | 39 | TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()), |
91 | 40 | "weight should be of dimension max(self.dim(), end.dim()) or lesser"); |
92 | | - std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_cpu"); |
93 | | - Tensor result = at::empty_like(b_self); |
94 | | - AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "lerp_cpu", [&]{ |
95 | | - lerp_cpu<scalar_t>(result, b_self, b_end, b_weight); |
96 | | - }); |
| 41 | + Tensor result = at::empty({0}, self.options()); |
| 42 | + lerp_kernel_tensor_weight(kCPU, result, self, end, weight); |
97 | 43 | return result; |
98 | 44 | } |
99 | 45 |
|
100 | 46 | Tensor lerp_cpu_scalar(const Tensor& self, const Tensor& end, Scalar weight) { |
101 | | - Tensor b_self, b_end; |
102 | | - std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_cpu"); |
103 | | - Tensor result = at::empty_like(b_self); |
104 | | - AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "lerp_cpu", [&]{ |
105 | | - lerp_cpu<scalar_t>(result, b_self, b_end, weight.to<scalar_t>()); |
106 | | - }); |
| 47 | + Tensor result = at::empty({0}, self.options()); |
| 48 | + lerp_kernel_scalar_weight(kCPU, result, self, end, weight); |
107 | 49 | return result; |
108 | 50 | } |
109 | 51 |
|
| 52 | +DEFINE_DISPATCH(lerp_kernel_scalar_weight); |
| 53 | +DEFINE_DISPATCH(lerp_kernel_tensor_weight); |
| 54 | + |
110 | 55 | } // namespace native |
111 | 56 | } // namespace at |
0 commit comments