
Commit ee00a80

EthanSteinberg authored and soumith committed
Add max pooling support to EmbeddingBag (#5725)
* Add max mode support to EmbeddingBag
* Lint fix
* Fix compilation issue on other platforms
* Rebase + don't waste memory when not in max mode
* Oops, missed a spot
* Fix whitespace from merge
* less precision
* Lower precision to avoid spurious failures
* Minor typo
* Switch to size()
1 parent 49f8732 commit ee00a80
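In user-facing terms, this change adds a 'max' reduction mode to EmbeddingBag alongside the existing 'sum' and 'mean' modes: each bag's output is the elementwise maximum over the embeddings of the indices in that bag. A minimal Python sketch of the resulting API (illustrative usage, not part of this diff):

    import torch
    import torch.nn as nn

    # Two bags over a 10-word vocabulary with 3-dim embeddings.
    bag = nn.EmbeddingBag(10, 3, mode='max')
    input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
    offsets = torch.tensor([0, 4], dtype=torch.long)  # start of each bag
    out = bag(input, offsets)  # shape (2, 3): per-bag elementwise max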

File tree

8 files changed: +341, -138 lines changed


aten/src/ATen/native/EmbeddingBag.cpp

Lines changed: 123 additions & 58 deletions
@@ -14,6 +14,12 @@
 #include <omp.h>
 #endif
 
+namespace {
+const int MODE_SUM = 0;
+const int MODE_MEAN = 1;
+const int MODE_MAX = 2;
+}
+
 namespace at {
 namespace native {
 
@@ -50,7 +56,7 @@ static void index_select_add(const Tensor &select_indices,
   auto src_data = src.data<T>();
   auto output_data = output.data<T>();
   auto numel = add_indices.numel();
-  int64_t ddim = src.sizes()[1];
+  int64_t ddim = src.size(1);
   for (int64_t i = 0; i < numel; i++) {
     axpy<T>(ddim, 1, src_data + ddim * select_indices_data[i], 1,
             output_data + ddim * add_indices_data[i], 1);
@@ -60,11 +66,11 @@ static void index_select_add(const Tensor &select_indices,
 static void make_bag_size(const Tensor &offsets, const Tensor &indices,
                           const int64_t mode, Tensor &bag_size) {
   if (mode == 1) { // MODE_MEAN
-    if (offsets.sizes()[0] != 1) {
-      bag_size.slice(0, 0, bag_size.sizes()[0] - 1, 1) =
-          offsets.slice(0, 1, offsets.sizes()[0], 1) -
-          offsets.slice(0, 0, offsets.sizes()[0] - 1, 1);
-      bag_size[-1] = indices.sizes()[0] - offsets[-1];
+    if (offsets.size(0) != 1) {
+      bag_size.slice(0, 0, bag_size.size(0) - 1, 1) =
+          offsets.slice(0, 1, offsets.size(0), 1) -
+          offsets.slice(0, 0, offsets.size(0) - 1, 1);
+      bag_size[-1] = indices.size(0) - offsets[-1];
     }
   }
 }
@@ -73,8 +79,8 @@ static Tensor apply_bag_size(const Tensor &offsets, const Tensor &indices,
                              const int64_t mode, Tensor &output,
                              const Tensor &bag_size) {
   if (mode == 1) { // MODE_MEAN
-    if (offsets.sizes()[0] == 1) {
-      auto bag_size_ = indices.sizes()[0];
+    if (offsets.size(0) == 1) {
+      auto bag_size_ = indices.size(0);
       output /= bag_size_;
     } else {
       auto bag_size_ =
@@ -90,8 +96,8 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
                                       Tensor &output, const Tensor &offset2bag,
                                       const Tensor &bag_size) {
   if (mode == 1) { // MODE_MEAN
-    if (offsets.sizes()[0] == 1) {
-      auto bag_size_ = indices.sizes()[0];
+    if (offsets.size(0) == 1) {
+      auto bag_size_ = indices.size(0);
       output /= bag_size_;
     } else {
       auto inv_bag_size_ = (1 / bag_size.toType(output.type()))
@@ -103,7 +109,48 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
   return output;
 }
 
-std::tuple<Tensor, Tensor, Tensor>
+
+template <typename scalar_t>
+std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
+    const Tensor& weight, const Tensor &indices, const Tensor& offset2bag, const Tensor& output, const Tensor& bag_size, const Tensor& offsets) {
+
+  auto max_indices = at::zeros(indices.type(), {offsets.size(0), weight.size(1)});
+
+  int64_t numel = indices.numel();
+  int64_t dims = weight.size(1);
+  auto indices_data = indices.data<int64_t>();
+  auto offset2bag_data = offset2bag.data<int64_t>();
+
+  auto max_indices_data = max_indices.data<int64_t>();
+  auto max_indices_stride = max_indices.stride(0);
+
+  auto weight_data = weight.data<scalar_t>();
+  auto output_data = output.data<scalar_t>();
+  auto weight_stride = weight.stride(0);
+  auto output_stride = output.stride(0);
+
+  for (int i = 0; i < numel; i++) {
+    auto bag = offset2bag_data[i];
+    auto word_idx = indices_data[i];
+
+
+    for (int dim = 0; dim < dims; dim++) {
+      auto& current_item = output_data[output_stride * bag + dim];
+      auto weight_item = weight_data[weight_stride * word_idx + dim];
+
+      bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag;
+
+      if (is_first_for_bag || weight_item > current_item) {
+        current_item = weight_item;
+        max_indices_data[max_indices_stride * bag + dim] = word_idx;
+      }
+    }
+  }
+
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, offset2bag, bag_size, max_indices);
+}
+
+std::tuple<Tensor, Tensor, Tensor, Tensor>
 embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
                   const Tensor &offsets__, const bool scale_grad_by_freq,
                   const int64_t mode, bool sparse) {
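The new embedding_bag_cpu_max kernel above keeps a running per-(bag, dimension) maximum and records which input index produced it, so the backward pass can route gradients. A plain-Python rendering of the same loop (a sketch for clarity, not code from the commit):

    # weight: list of embedding rows; offset2bag[i] = bag id of indices[i]
    def embedding_bag_max(weight, indices, offset2bag, num_bags):
        dims = len(weight[0])
        output = [[0.0] * dims for _ in range(num_bags)]
        max_indices = [[0] * dims for _ in range(num_bags)]
        for i, word_idx in enumerate(indices):
            bag = offset2bag[i]
            # the first entry of a bag always wins, so stale output slots
            # are overwritten rather than compared against
            is_first_for_bag = (i == 0) or offset2bag[i - 1] != bag
            for dim in range(dims):
                w = weight[word_idx][dim]
                if is_first_for_bag or w > output[bag][dim]:
                    output[bag][dim] = w              # running max
                    max_indices[bag][dim] = word_idx  # argmax for backward
        return output, max_indices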
@@ -118,23 +165,34 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
 
   auto bag_size = at::zeros(indices.type(), offsets.sizes());
   auto offset2bag =
-      at::zeros(indices__.type(), {indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]
+      at::zeros(indices__.type(), {indices.size(0)}); // offset2bag = [0 0 0 0 0]
   make_offset2bag(offsets, indices, offset2bag);
-  auto output = at::zeros(weight.type(), {offsets.sizes()[0], weight.sizes()[1]});
-  if (weight.type().scalarType() == kFloat) {
-    index_select_add<float>(indices, offset2bag, weight, output);
-  } else if (weight.type().scalarType() == kDouble) {
-    index_select_add<double>(indices, offset2bag, weight, output);
+  auto output = at::zeros(weight.type(), {offsets.size(0), weight.size(1)});
+
+  if (mode == MODE_MEAN || mode == MODE_SUM) {
+    if (weight.type().scalarType() == kFloat) {
+      index_select_add<float>(indices, offset2bag, weight, output);
+    } else if (weight.type().scalarType() == kDouble) {
+      index_select_add<double>(indices, offset2bag, weight, output);
+    }
+    make_bag_size(offsets, indices, mode, bag_size);
+    auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
+    return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
+  } else { // MODE_MAX
+    return AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      weight.type(), "embedding_bag_cpu_max", [&]() {
+        return embedding_bag_cpu_max<scalar_t>(weight, indices, offset2bag, output, bag_size, offsets);
+      }
+    );
   }
-  make_bag_size(offsets, indices, mode, bag_size);
-  auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
-  return std::tuple<Tensor, Tensor, Tensor>(ret, offset2bag, bag_size);
 }
 
 Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__,
                               const Tensor &offsets__,
                               const Tensor &offset2bag__,
-                              const Tensor &bag_size_, int64_t num_weights,
+                              const Tensor &bag_size_,
+                              const Tensor &max_indices_,
+                              int64_t num_weights,
                               bool scale_grad_by_freq, int64_t mode,
                               bool sparse) {
   auto indices_arg = TensorArg(indices__, "indices__", 1);
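The forward dispatch above relies on offset2bag, which maps each position of indices to the bag it belongs to (the inline comment shows offset2bag = [0 0 0 0 0] for a single bag). A small Python sketch of that mapping under the same conventions (an assumed helper for illustration; the real make_offset2bag is not shown in this diff):

    # offsets hold the start position of each bag within `indices`,
    # e.g. offsets=[0, 4] with 8 indices -> offset2bag=[0,0,0,0,1,1,1,1]
    def make_offset2bag(offsets, num_indices):
        offset2bag = [0] * num_indices
        for bag in range(len(offsets)):
            end = offsets[bag + 1] if bag + 1 < len(offsets) else num_indices
            for pos in range(offsets[bag], end):
                offset2bag[pos] = bag
        return offset2bag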
@@ -153,15 +211,16 @@ Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__,
                               scale_grad_by_freq, mode);
   } else {
     return at::embedding_bag_dense_backward(
-        grad_, indices, offsets, offset2bag__, bag_size_, num_weights,
+        grad_, indices, offsets, offset2bag__, bag_size_, max_indices_, num_weights,
         scale_grad_by_freq, mode);
   }
 }
 
 Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
                                   const Tensor &offsets__,
                                   const Tensor &offset2bag__,
-                                  const Tensor &bag_size_, int64_t num_weights,
+                                  const Tensor &bag_size_,
+                                  const Tensor& max_indices_, int64_t num_weights,
                                   bool scale_grad_by_freq, int64_t mode) {
   auto grad = grad_.contiguous();
   auto grad_arg = TensorArg(grad, "grad_", 1);
@@ -196,6 +255,9 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
     counts[indices_data[i]]++;
   }
 
+  auto index_grad_weight =
+      at::zeros(grad.type(), {num_weights, grad.size(1)}).contiguous();
+
   std::vector<int64_t> counts_uniq;
   counts_uniq.reserve(num_weights);
   int64_t o = 0;
@@ -207,43 +269,46 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
     o++;
   }
 
-  auto index_grad_weight =
-      at::zeros(grad.type(), {num_weights, grad.sizes()[1]}).contiguous();
-
-#pragma omp parallel for if (numel > 1000)
-  for (int64_t i = 0; i < (int64_t)counts_uniq.size(); i++) {
-    int64_t start = i == 0 ? 0 : counts_uniq[i - 1];
-    int64_t index = indices_data[start];
-    for (int64_t j = start; j < counts_uniq[i]; j++) {
-      int64_t source = offset2bag_data[j];
-      double scale = 1.0;
-      if (scale_grad_by_freq) {
-        scale /= counts[indices_data[i]];
-      }
-      if (mode == 1) { // MODE_MEAN
-        if (offsets_.sizes()[0] == 1) {
-          auto bag_size = indices.sizes()[0];
-          scale /= bag_size;
-        } else {
-          if (source == offsets_.sizes()[0] - 1) {
-            scale /= indices.sizes()[0] - offsets_data[offsets_.sizes()[0] - 1];
-          } else {
-            scale /= offsets_data[source + 1] - offsets_data[source];
+  if (mode == MODE_MEAN || mode == MODE_SUM) {
+    #pragma omp parallel for if (numel > 1000)
+    for (int64_t i = 0; i < (int64_t)counts_uniq.size(); i++) {
+      int64_t start = i == 0 ? 0 : counts_uniq[i - 1];
+      int64_t index = indices_data[start];
+      for (int64_t j = start; j < counts_uniq[i]; j++) {
+        int64_t source = offset2bag_data[j];
+        double scale = 1.0;
+        if (scale_grad_by_freq) {
+          scale /= counts[indices_data[i]];
+        }
+        if (mode == 1) { // MODE_MEAN
+          if (offsets_.size(0) == 1) {
+            auto bag_size = indices.size(0);
+            scale /= bag_size;
+          } else {
+            if (source == offsets_.size(0) - 1) {
+              scale /= indices.size(0) - offsets_data[offsets_.size(0) - 1];
+            } else {
+              scale /= offsets_data[source + 1] - offsets_data[source];
+            }
+          }
+        }
+        int64_t ddim = grad.size(1);
+        if (grad.type().scalarType() == kFloat) {
+          auto igwd = index_grad_weight.data<float>();
+          auto gd = grad.data<float>();
+          axpy<float>(ddim, (float)scale, gd + ddim * source, 1,
+                      igwd + ddim * index, 1);
+        } else if (grad.type().scalarType() == kDouble) {
+          auto igwd = index_grad_weight.data<double>();
+          auto gd = grad.data<double>();
+          axpy<double>(ddim, (double)scale, gd + ddim * source, 1,
+                      igwd + ddim * index, 1);
         }
       }
-      }
-      int64_t ddim = grad.sizes()[1];
-      if (grad.type().scalarType() == kFloat) {
-        auto igwd = index_grad_weight.data<float>();
-        auto gd = grad.data<float>();
-        axpy<float>(ddim, (float)scale, gd + ddim * source, 1,
-                    igwd + ddim * index, 1);
-      } else if (grad.type().scalarType() == kDouble) {
-        auto igwd = index_grad_weight.data<double>();
-        auto gd = grad.data<double>();
-        axpy<double>(ddim, (double)scale, gd + ddim * source, 1,
-                     igwd + ddim * index, 1);
-      }
+    }
+  } else if (mode == MODE_MAX) {
+    for (int64_t dim = 0; dim < grad.size(1); dim++) {
+      index_grad_weight.select(1, dim).index_add_(0, max_indices_.select(1, dim), grad_.select(1, dim));
     }
   }
 
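For MODE_MAX, the backward pass above scatters each bag's gradient to the row that won the max in each dimension, via index_add_ on a per-dimension view. The same routing expressed with torch ops (a sketch under assumed shapes, not code from the commit; grad and max_indices are [num_bags, dim], both index tensors int64):

    import torch

    def embedding_bag_max_backward(grad, max_indices, num_weights):
        index_grad_weight = torch.zeros(num_weights, grad.size(1))
        for dim in range(grad.size(1)):
            # column dim of every bag's gradient flows to the word index
            # recorded as the argmax for that (bag, dim) slot
            index_grad_weight[:, dim].index_add_(
                0, max_indices[:, dim], grad[:, dim])
        return index_grad_weight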