Skip to content

Commit fe580e8

Browse files
VitalyFedyunin authored and facebook-github-bot committed
Rewrite lerp operator to use TensorIterator and support compile-time vectorization. (#22038)
Summary: Get benefit from the compile time vectorization and multi-threading. Before: ```python In [1]: import torch In [2]: x = torch.randn(1000000) In [3]: y = torch.randn(1000000) In [4]: w = 0.7 In [5]: timeit torch.lerp(x, y, w) 2.29 ms ± 23.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) ``` After: ```python In [1]: import torch In [2]: x = torch.randn(1000000) In [3]: y = torch.randn(1000000) In [4]: w = 0.7 In [5]: timeit torch.lerp(x, y, w) 452 µs ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) ``` After with multi-threading: ```python In [1]: import torch In [2]: x = torch.randn(1000000) In [3]: y = torch.randn(1000000) In [4]: w = 0.7 In [5]: timeit torch.lerp(x, y, w) 167 µs ± 48.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) ``` Pull Request resolved: #22038 Differential Revision: D15941468 Pulled By: VitalyFedyunin fbshipit-source-id: fa8a5126187df4e6c849452e035b00b22be25739
1 parent 2863052 commit fe580e8

File tree

3 files changed

+100
-68
lines changed

3 files changed

+100
-68
lines changed

aten/src/ATen/native/Lerp.cpp

Lines changed: 13 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,56 @@
1+
#include <ATen/native/Lerp.h>
2+
13
#include <ATen/ATen.h>
24
#include <ATen/CPUApplyUtils.h>
35
#include <ATen/NativeFunctions.h>
46
#include <ATen/Dispatch.h>
57
#include <ATen/ExpandUtils.h>
68

7-
namespace {
8-
template <typename scalar_t>
9-
void lerp_cpu(at::Tensor& ret, const at::Tensor& self, const at::Tensor& end, const at::Tensor& weight) {
10-
at::CPU_tensor_apply4<scalar_t, scalar_t, scalar_t, scalar_t>(
11-
ret, self, end, weight,
12-
[](scalar_t& ret_val,
13-
const scalar_t& self_val,
14-
const scalar_t& end_val,
15-
const scalar_t& weight_val) {
16-
ret_val = (weight_val < 0.5) ?
17-
self_val + weight_val * (end_val - self_val) : end_val - (end_val - self_val) * (1 - weight_val);
18-
});
19-
}
20-
21-
template <typename scalar_t>
22-
void lerp_cpu(at::Tensor& ret, const at::Tensor& self, const at::Tensor& end, scalar_t weight_val) {
23-
at::CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(
24-
ret, self, end,
25-
[=](scalar_t& ret_val,
26-
const scalar_t& self_val,
27-
const scalar_t& end_val) {
28-
ret_val = (weight_val < 0.5) ?
29-
self_val + weight_val * (end_val - self_val) : end_val - (end_val - self_val) * (1 - weight_val);
30-
});
31-
}
32-
33-
} // namespace
34-
359
namespace at {
3610
namespace native {
3711

3812
// Out-of-place lerp with a tensor weight, writing into `result`.
// Elementwise: result = self + weight * (end - self) (see the kernel lambda
// in LerpKernel.cpp). Broadcasting of self/end/weight is handled by the
// TensorIterator inside the dispatched kernel.
Tensor& lerp_cpu_tensor_out(Tensor& result, const Tensor& self,
                            const Tensor& end, const Tensor& weight) {
  // Disallow a weight with more dims than the inputs: that would broadcast
  // self/end *up* to the weight's shape, which lerp does not permit.
  TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
              "weight should be of dimension max(self.dim(), end.dim()) or lesser");
  lerp_kernel_tensor_weight(kCPU, result, self, end, weight);
  return result;
}
5019

5120
// Out-of-place lerp with a single Scalar weight, writing into `result`.
// Elementwise: result = self + weight * (end - self); self and end are
// broadcast against each other by the dispatched kernel.
Tensor& lerp_cpu_scalar_out(Tensor& result, const Tensor& self,
                            const Tensor& end, Scalar weight) {
  lerp_kernel_scalar_weight(kCPU, result, self, end, weight);
  return result;
}
6125

6226
// In-place lerp with a tensor weight: self = self + weight * (end - self).
Tensor& lerp_cpu_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) {
  TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
              "weight should be of dimension max(self.dim(), end.dim()) or lesser");
  // `self` is passed as both the output and the first input; the kernel's
  // TensorIterator presumably rejects broadcasts that would need to resize
  // self — TODO confirm against TensorIterator's output-shape checks.
  lerp_kernel_tensor_weight(kCPU, self, self, end, weight);
  return self;
}
7532

7633
// In-place lerp with a Scalar weight: self = self + weight * (end - self).
Tensor& lerp_cpu_scalar_(Tensor& self, const Tensor& end, Scalar weight) {
  // `self` is both output and first input of the dispatched kernel.
  lerp_kernel_scalar_weight(kCPU, self, self, end, weight);
  return self;
}
8737

8838
// Functional lerp with a tensor weight: returns self + weight * (end - self).
Tensor lerp_cpu_tensor(const Tensor& self, const Tensor& end, const Tensor& weight) {
  TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
              "weight should be of dimension max(self.dim(), end.dim()) or lesser");
  // Start from an empty tensor with self's dtype/device; presumably the
  // kernel's TensorIterator resizes it to the broadcast result shape —
  // TODO confirm.
  Tensor result = at::empty({0}, self.options());
  lerp_kernel_tensor_weight(kCPU, result, self, end, weight);
  return result;
}
9945

10046
// Functional lerp with a Scalar weight: returns self + weight * (end - self).
Tensor lerp_cpu_scalar(const Tensor& self, const Tensor& end, Scalar weight) {
  // Empty output with self's options; presumably resized to the broadcast
  // shape of self/end by the kernel's TensorIterator — TODO confirm.
  Tensor result = at::empty({0}, self.options());
  lerp_kernel_scalar_weight(kCPU, result, self, end, weight);
  return result;
}
10951

52+
// Define the dispatch-stub storage declared in Lerp.h; the CPU
// implementations are registered from LerpKernel.cpp via REGISTER_DISPATCH.
DEFINE_DISPATCH(lerp_kernel_scalar_weight);
DEFINE_DISPATCH(lerp_kernel_tensor_weight);
54+
11055
} // namespace native
11156
} // namespace at

aten/src/ATen/native/Lerp.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {

// Kernel signature for lerp with a single Scalar weight:
// ret = self + weight * (end - self), elementwise.
using lerp_fn_scalar = void (*)(
    at::Tensor& ret,
    const at::Tensor& self,
    const at::Tensor& end,
    Scalar weight);

// Kernel signature for lerp with a per-element tensor of weights.
using lerp_fn_tensor = void (*)(
    at::Tensor& ret,
    const at::Tensor& self,
    const at::Tensor& end,
    const at::Tensor& weights);

// Dispatch stubs: defined in Lerp.cpp (DEFINE_DISPATCH) and populated per
// backend via REGISTER_DISPATCH (CPU registration in LerpKernel.cpp).
DECLARE_DISPATCH(lerp_fn_scalar, lerp_kernel_scalar_weight);
DECLARE_DISPATCH(lerp_fn_tensor, lerp_kernel_tensor_weight);

} // namespace native
} // namespace at
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#include <ATen/ATen.h>
2+
3+
#include <ATen/Dispatch.h>
4+
#include <ATen/native/Lerp.h>
5+
#include <ATen/native/TensorIterator.h>
6+
#include <ATen/native/cpu/Loops.h>
7+
8+
namespace at {
9+
namespace native {
10+
namespace {
11+
12+
// CPU lerp kernel with a single Scalar weight.
// Two-branch form: interpolate from `self` when w < 0.5, otherwise from
// `end` — presumably for better accuracy near the endpoints; TODO confirm.
static void lerp_kernel_scalar(
    Tensor& ret,
    const Tensor& self,
    const Tensor& end,
    Scalar weight) {
  // Assemble a three-operand iterator: one output, two inputs.
  auto builder = at::TensorIterator::Builder();
  builder.add_output(ret);
  builder.add_input(self);
  builder.add_input(end);
  auto iter = builder.build();
  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "lerp_kernel_scalar", [&] {
    // Convert the weight once, outside the per-element loop.
    const scalar_t w = weight.to<scalar_t>();
    at::native::cpu_kernel(
        *iter,
        [w](scalar_t start, scalar_t stop) {
          if (w < 0.5) {
            return start + w * (stop - start);
          }
          return stop - (stop - start) * (1 - w);
        });
  });
}
33+
34+
// CPU lerp kernel with a per-element tensor of weights.
// Same two-branch formula as the scalar kernel, with the weight read
// elementwise as a third iterator input.
static void lerp_kernel_tensor(
    Tensor& ret,
    const Tensor& self,
    const Tensor& end,
    const Tensor& weights) {
  // Assemble a four-operand iterator: one output, three inputs.
  auto builder = at::TensorIterator::Builder();
  builder.add_output(ret);
  builder.add_input(self);
  builder.add_input(end);
  builder.add_input(weights);
  auto iter = builder.build();
  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "lerp_kernel_tensor", [&] {
    at::native::cpu_kernel(
        *iter,
        [](scalar_t start, scalar_t stop, scalar_t w) {
          if (w < 0.5) {
            return start + w * (stop - start);
          }
          return stop - (stop - start) * (1 - w);
        });
  });
}
55+
56+
} // anonymous namespace
57+
58+
// Register the CPU implementations with the dispatch stubs declared in
// Lerp.h and defined in Lerp.cpp.
REGISTER_DISPATCH(lerp_kernel_scalar_weight, &lerp_kernel_scalar);
REGISTER_DISPATCH(lerp_kernel_tensor_weight, &lerp_kernel_tensor);
60+
61+
} // namespace native
62+
} // namespace at

0 commit comments

Comments (0)