Skip to content

Commit c40b99f

Browse files
martinraison authored and soumith committed
speed up CPU EmbeddingBag (indexSelectAdd op) (#5433)
* speed up CPU EmbeddingBag (indexSelectAdd op) * keep operator inside EmbeddingBag + speedup * comment * update checkScalarTypes signature * enforce type in embedding_bag_backward_cpu
1 parent ecffe53 commit c40b99f

File tree

3 files changed

+71
-13
lines changed

3 files changed

+71
-13
lines changed

aten/src/ATen/TensorUtils.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,26 @@ void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType ty) {
162162
}
163163
}
164164

165+
// Verify that tensor argument `t` has one of the scalar types listed in `l`;
// otherwise raise a runtime_error naming the accepted types, the actual type,
// and the calling context `c`.
void checkScalarTypes(CheckedFrom c, const TensorArg& t,
                      at::ArrayRef<ScalarType> l) {
  if (std::find(l.begin(), l.end(), t->type().scalarType()) == l.end()) {
    std::ostringstream oss;
    oss << "Expected tensor for " << t << " to have one of the following "
        << "scalar types: ";
    // Emit the accepted types as a comma-separated list; the separator is
    // empty before the first element and ", " thereafter.
    const char *sep = "";
    for (auto ty : l) {
      oss << sep << toString(ty);
      sep = ", ";
    }
    oss << "; but got " << t->toString()
        << " instead (while checking arguments for " << c << ")";
    throw std::runtime_error(oss.str());
  }
}
184+
165185
void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors) {
166186
checkAllSame(c, tensors, checkSameType);
167187
}

aten/src/ATen/TensorUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ void checkNumel(CheckedFrom c, const TensorGeometryArg& t, int64_t numel);
6262
void checkSameNumel(CheckedFrom c, const TensorGeometryArg& t1, const TensorGeometryArg& t2);
6363
void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors);
6464
void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s);
65+
void checkScalarTypes(CheckedFrom c, const TensorArg& t, at::ArrayRef<ScalarType> l);
6566
void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2);
6667
void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors);
6768
void checkSameType(CheckedFrom c, const TensorArg& t1, const TensorArg& t2);
@@ -78,4 +79,3 @@ void * maybe_data_ptr(const Tensor& tensor);
7879
void * maybe_data_ptr(const TensorArg& tensor);
7980

8081
}
81-

aten/src/ATen/native/EmbeddingBag.cpp

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,39 @@ static void make_offset2bag(const Tensor &offsets, const Tensor &indices,
2424
offset2bag = offset2bag.cumsum(0); // offset2bag = [0 0 1 1 2]
2525
}
2626

27+
template<typename T>
28+
static void axpy(int64_t n, T a, T *x, int64_t incx, T *y, int64_t incy);
29+
template<>
30+
void axpy<float>(int64_t n, float a, float *x, int64_t incx,
31+
float *y, int64_t incy) {
32+
THFloatBlas_axpy(n, a, x, incx, y, incy);
33+
}
34+
template<>
35+
void axpy<double>(int64_t n, double a, double *x, int64_t incx,
36+
double *y, int64_t incy) {
37+
THDoubleBlas_axpy(n, a, x, incx, y, incy);
38+
}
39+
40+
// This function combines index_select (using select_indices as the index) and
41+
// index_add (using add_indices as the index), without creating an intermediary
42+
// tensor to hold the selected embeddings
43+
template<typename T>
44+
static void index_select_add(const Tensor &select_indices,
45+
const Tensor &add_indices,
46+
const Tensor &src,
47+
Tensor &output) {
48+
auto add_indices_data = add_indices.data<int64_t>();
49+
auto select_indices_data = select_indices.data<int64_t>();
50+
auto src_data = src.data<T>();
51+
auto output_data = output.data<T>();
52+
auto numel = add_indices.numel();
53+
int64_t ddim = src.sizes()[1];
54+
for (int64_t i = 0; i < numel; i++) {
55+
axpy<T>(ddim, 1, src_data + ddim * select_indices_data[i], 1,
56+
output_data + ddim * add_indices_data[i], 1);
57+
}
58+
}
59+
2760
static void make_bag_size(const Tensor &offsets, const Tensor &indices,
2861
const int64_t mode, Tensor &bag_size) {
2962
if (mode == 1) { // MODE_MEAN
@@ -61,10 +94,10 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
6194
auto bag_size_ = indices.sizes()[0];
6295
output /= bag_size_;
6396
} else {
64-
auto bag_size_ = bag_size.toType(output.type())
65-
.unsqueeze(1)
66-
.index_select(0, offset2bag);
67-
output /= bag_size_;
97+
auto inv_bag_size_ = (1 / bag_size.toType(output.type()))
98+
.unsqueeze(1)
99+
.index_select(0, offset2bag);
100+
output *= inv_bag_size_;
68101
}
69102
}
70103
return output;
@@ -80,14 +113,19 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
80113
checkScalarType("embedding_bag", offsets_arg, kLong);
81114
Tensor indices = indices__.contiguous();
82115
Tensor offsets = offsets__.contiguous();
116+
auto weight_arg = TensorArg(weight, "weight", 1);
117+
checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble});
83118

84119
auto bag_size = at::zeros(indices.type(), offsets.sizes());
85120
auto offset2bag =
86121
at::zeros(indices__.type(), {indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]
87122
make_offset2bag(offsets, indices, offset2bag);
88123
auto output = at::zeros(weight.type(), {offsets.sizes()[0], weight.sizes()[1]});
89-
auto index_output = weight.index_select(0, indices);
90-
output.index_add_(0, offset2bag, index_output);
124+
if (weight.type().scalarType() == kFloat) {
125+
index_select_add<float>(indices, offset2bag, weight, output);
126+
} else if (weight.type().scalarType() == kDouble) {
127+
index_select_add<double>(indices, offset2bag, weight, output);
128+
}
91129
make_bag_size(offsets, indices, mode, bag_size);
92130
auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
93131
return std::tuple<Tensor, Tensor, Tensor>(ret, offset2bag, bag_size);
@@ -126,6 +164,8 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
126164
const Tensor &bag_size_, int64_t num_weights,
127165
bool scale_grad_by_freq, int64_t mode) {
128166
auto grad = grad_.contiguous();
167+
auto grad_arg = TensorArg(grad, "grad_", 1);
168+
checkScalarTypes("embedding_bag", grad_arg, {kFloat, kDouble});
129169
auto indices_arg = TensorArg(indices__, "indices__", 1);
130170
checkScalarType("embedding_bag", indices_arg, kLong);
131171
auto offsets_arg = TensorArg(offsets__, "offsets__", 1);
@@ -196,15 +236,13 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
196236
if (grad.type().scalarType() == kFloat) {
197237
auto igwd = index_grad_weight.data<float>();
198238
auto gd = grad.data<float>();
199-
THFloatBlas_axpy(ddim, (float)scale, gd + ddim * source, 1,
200-
igwd + ddim * index, 1);
239+
axpy<float>(ddim, (float)scale, gd + ddim * source, 1,
240+
igwd + ddim * index, 1);
201241
} else if (grad.type().scalarType() == kDouble) {
202242
auto igwd = index_grad_weight.data<double>();
203243
auto gd = grad.data<double>();
204-
THDoubleBlas_axpy(ddim, (double)scale, gd + ddim * source, 1,
205-
igwd + ddim * index, 1);
206-
} else {
207-
index_grad_weight[index].add_(grad[source], scale);
244+
axpy<double>(ddim, (double)scale, gd + ddim * source, 1,
245+
igwd + ddim * index, 1);
208246
}
209247
}
210248
}

0 commit comments

Comments (0)