Skip to content

Commit 48b0d07

Browse files
Rebase + don't waste memory when not in max mode
1 parent fccbe57 commit 48b0d07

File tree

2 files changed

+12
-25
lines changed

2 files changed

+12
-25
lines changed

aten/src/ATen/native/EmbeddingBag.cpp

Lines changed: 5 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -177,7 +177,7 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
177177
}
178178
make_bag_size(offsets, indices, mode, bag_size);
179179
auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
180-
return std::tuple<Tensor, Tensor, Tensor>(ret, offset2bag, bag_size);
180+
return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
181181
} else { // MODE_MAX
182182
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(
183183
weight.type(), "embedding_bag_cpu_max", [&]() {
@@ -296,37 +296,19 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
296296
if (grad.type().scalarType() == kFloat) {
297297
auto igwd = index_grad_weight.data<float>();
298298
auto gd = grad.data<float>();
299-
THFloatBlas_axpy(ddim, (float)scale, gd + ddim * source, 1,
300-
igwd + ddim * index, 1);
299+
axpy<float>(ddim, (float)scale, gd + ddim * source, 1,
300+
igwd + ddim * index, 1);
301301
} else if (grad.type().scalarType() == kDouble) {
302302
auto igwd = index_grad_weight.data<double>();
303303
auto gd = grad.data<double>();
304-
THDoubleBlas_axpy(ddim, (double)scale, gd + ddim * source, 1,
305-
igwd + ddim * index, 1);
306-
} else {
307-
index_grad_weight[index].add_(grad[source], scale);
304+
axpy<double>(ddim, (double)scale, gd + ddim * source, 1,
305+
igwd + ddim * index, 1);
308306
}
309307
}
310-
<<<<<<< HEAD
311-
}
312-
int64_t ddim = grad.sizes()[1];
313-
if (grad.type().scalarType() == kFloat) {
314-
auto igwd = index_grad_weight.data<float>();
315-
auto gd = grad.data<float>();
316-
axpy<float>(ddim, (float)scale, gd + ddim * source, 1,
317-
igwd + ddim * index, 1);
318-
} else if (grad.type().scalarType() == kDouble) {
319-
auto igwd = index_grad_weight.data<double>();
320-
auto gd = grad.data<double>();
321-
axpy<double>(ddim, (double)scale, gd + ddim * source, 1,
322-
igwd + ddim * index, 1);
323-
}
324-
=======
325308
}
326309
} else if (mode == MODE_MAX) {
327310
for (int64_t dim = 0; dim < grad.sizes()[1]; dim++) {
328311
index_grad_weight.select(1, dim).index_add_(0, max_indices_.select(1, dim), grad_.select(1, dim));
329-
>>>>>>> Add max mode support to EmbeddingBag
330312
}
331313
}
332314

aten/src/ATen/native/cuda/EmbeddingBag.cu

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -330,7 +330,12 @@ embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
330330
cudaStream_t stream = globalContext().getCurrentCUDAStream();
331331

332332
auto output = at::zeros(weight.type(), {offsets.sizes()[0], weight.sizes()[1]});
333-
auto max_indices = at::zeros(indices.type(), {offsets.sizes()[0], weight.sizes()[1]});
333+
334+
Tensor max_indices;
335+
336+
if (mode == MODE_MAX) {
337+
max_indices = at::zeros(indices.type(), {offsets.sizes()[0], weight.sizes()[1]});
338+
}
334339

335340
dim3 block = dim3(32, 8);
336341
int grid = 1024;
@@ -340,7 +345,7 @@ embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
340345
indices.data<int64_t>(), offsets.data<int64_t>(),
341346
weight.data<cuda_scalar_t>(), output.data<cuda_scalar_t>(),
342347
offset2bag.data<int64_t>(), numIndices, numBags, stride, mode,
343-
bag_size.data<int64_t>(), max_indices.data<int64_t>());
348+
bag_size.data<int64_t>(), mode == MODE_MAX ? max_indices.data<int64_t>() : NULL);
344349
});
345350

346351
THCudaCheck(cudaGetLastError());

0 commit comments

Comments (0)