Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 123 additions & 58 deletions aten/src/ATen/native/EmbeddingBag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
#include <omp.h>
#endif

namespace {
// Reduction modes for embedding_bag, compared against the int64_t `mode`
// argument threaded through the functions below.
// NOTE(review): presumably mirrors the caller-side mode encoding — confirm
// against the dispatching Python/ATen layer.
constexpr int MODE_SUM = 0;
constexpr int MODE_MEAN = 1;
constexpr int MODE_MAX = 2;
}

namespace at {
namespace native {

Expand Down Expand Up @@ -50,7 +56,7 @@ static void index_select_add(const Tensor &select_indices,
auto src_data = src.data<T>();
auto output_data = output.data<T>();
auto numel = add_indices.numel();
int64_t ddim = src.sizes()[1];
int64_t ddim = src.size(1);
for (int64_t i = 0; i < numel; i++) {
axpy<T>(ddim, 1, src_data + ddim * select_indices_data[i], 1,
output_data + ddim * add_indices_data[i], 1);
Expand All @@ -60,11 +66,11 @@ static void index_select_add(const Tensor &select_indices,
static void make_bag_size(const Tensor &offsets, const Tensor &indices,
const int64_t mode, Tensor &bag_size) {
if (mode == 1) { // MODE_MEAN
if (offsets.sizes()[0] != 1) {
bag_size.slice(0, 0, bag_size.sizes()[0] - 1, 1) =
offsets.slice(0, 1, offsets.sizes()[0], 1) -
offsets.slice(0, 0, offsets.sizes()[0] - 1, 1);
bag_size[-1] = indices.sizes()[0] - offsets[-1];
if (offsets.size(0) != 1) {
bag_size.slice(0, 0, bag_size.size(0) - 1, 1) =
offsets.slice(0, 1, offsets.size(0), 1) -
offsets.slice(0, 0, offsets.size(0) - 1, 1);
bag_size[-1] = indices.size(0) - offsets[-1];
}
}
}
Expand All @@ -73,8 +79,8 @@ static Tensor apply_bag_size(const Tensor &offsets, const Tensor &indices,
const int64_t mode, Tensor &output,
const Tensor &bag_size) {
if (mode == 1) { // MODE_MEAN
if (offsets.sizes()[0] == 1) {
auto bag_size_ = indices.sizes()[0];
if (offsets.size(0) == 1) {
auto bag_size_ = indices.size(0);
output /= bag_size_;
} else {
auto bag_size_ =
Expand All @@ -90,8 +96,8 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
Tensor &output, const Tensor &offset2bag,
const Tensor &bag_size) {
if (mode == 1) { // MODE_MEAN
if (offsets.sizes()[0] == 1) {
auto bag_size_ = indices.sizes()[0];
if (offsets.size(0) == 1) {
auto bag_size_ = indices.size(0);
output /= bag_size_;
} else {
auto inv_bag_size_ = (1 / bag_size.toType(output.type()))
Expand All @@ -103,7 +109,48 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
return output;
}

std::tuple<Tensor, Tensor, Tensor>

// Forward pass for embedding_bag in MODE_MAX: for each bag, computes the
// element-wise maximum over the bag's embedding rows into `output`, and
// records in `max_indices` which word index produced the maximum for every
// output dimension (consumed by the backward pass).
//
// `output` arrives zero-initialized; the first row of each bag overwrites it
// unconditionally (is_first_for_bag), so the max is taken only over rows that
// actually belong to the bag, and empty bags stay zero.
// Assumes `offset2bag` maps each index position to its (non-decreasing) bag id.
template <typename scalar_t>
std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
    const Tensor& weight, const Tensor &indices, const Tensor& offset2bag, const Tensor& output, const Tensor& bag_size, const Tensor& offsets) {

  auto max_indices = at::zeros(indices.type(), {offsets.size(0), weight.size(1)});

  int64_t numel = indices.numel();
  int64_t dims = weight.size(1);
  auto indices_data = indices.data<int64_t>();
  auto offset2bag_data = offset2bag.data<int64_t>();

  auto max_indices_data = max_indices.data<int64_t>();
  auto max_indices_stride = max_indices.stride(0);

  auto weight_data = weight.data<scalar_t>();
  auto output_data = output.data<scalar_t>();
  auto weight_stride = weight.stride(0);
  auto output_stride = output.stride(0);

  // int64_t counters: numel and dims are int64_t, so `int` would truncate
  // (undefined overflow) for tensors with more than INT_MAX elements.
  for (int64_t i = 0; i < numel; i++) {
    auto bag = offset2bag_data[i];
    auto word_idx = indices_data[i];

    for (int64_t dim = 0; dim < dims; dim++) {
      auto& current_item = output_data[output_stride * bag + dim];
      auto weight_item = weight_data[weight_stride * word_idx + dim];

      // A change in bag id relative to the previous entry marks the first
      // row of a new bag (offset2bag is non-decreasing).
      bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag;

      if (is_first_for_bag || weight_item > current_item) {
        current_item = weight_item;
        max_indices_data[max_indices_stride * bag + dim] = word_idx;
      }
    }
  }

  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, offset2bag, bag_size, max_indices);
}

std::tuple<Tensor, Tensor, Tensor, Tensor>
embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
const Tensor &offsets__, const bool scale_grad_by_freq,
const int64_t mode, bool sparse) {
Expand All @@ -118,23 +165,34 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,

auto bag_size = at::zeros(indices.type(), offsets.sizes());
auto offset2bag =
at::zeros(indices__.type(), {indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]
at::zeros(indices__.type(), {indices.size(0)}); // offset2bag = [0 0 0 0 0]
make_offset2bag(offsets, indices, offset2bag);
auto output = at::zeros(weight.type(), {offsets.sizes()[0], weight.sizes()[1]});
if (weight.type().scalarType() == kFloat) {
index_select_add<float>(indices, offset2bag, weight, output);
} else if (weight.type().scalarType() == kDouble) {
index_select_add<double>(indices, offset2bag, weight, output);
auto output = at::zeros(weight.type(), {offsets.size(0), weight.size(1)});

if (mode == MODE_MEAN || mode == MODE_SUM) {
if (weight.type().scalarType() == kFloat) {
index_select_add<float>(indices, offset2bag, weight, output);
} else if (weight.type().scalarType() == kDouble) {
index_select_add<double>(indices, offset2bag, weight, output);
}
make_bag_size(offsets, indices, mode, bag_size);
auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
} else { // MODE_MAX
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(
weight.type(), "embedding_bag_cpu_max", [&]() {
return embedding_bag_cpu_max<scalar_t>(weight, indices, offset2bag, output, bag_size, offsets);
}
);
}
make_bag_size(offsets, indices, mode, bag_size);
auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
return std::tuple<Tensor, Tensor, Tensor>(ret, offset2bag, bag_size);
}

Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__,
const Tensor &offsets__,
const Tensor &offset2bag__,
const Tensor &bag_size_, int64_t num_weights,
const Tensor &bag_size_,
const Tensor &max_indices_,
int64_t num_weights,
bool scale_grad_by_freq, int64_t mode,
bool sparse) {
auto indices_arg = TensorArg(indices__, "indices__", 1);
Expand All @@ -153,15 +211,16 @@ Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__,
scale_grad_by_freq, mode);
} else {
return at::embedding_bag_dense_backward(
grad_, indices, offsets, offset2bag__, bag_size_, num_weights,
grad_, indices, offsets, offset2bag__, bag_size_, max_indices_, num_weights,
scale_grad_by_freq, mode);
}
}

Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
const Tensor &offsets__,
const Tensor &offset2bag__,
const Tensor &bag_size_, int64_t num_weights,
const Tensor &bag_size_,
const Tensor& max_indices_, int64_t num_weights,
bool scale_grad_by_freq, int64_t mode) {
auto grad = grad_.contiguous();
auto grad_arg = TensorArg(grad, "grad_", 1);
Expand Down Expand Up @@ -196,6 +255,9 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
counts[indices_data[i]]++;
}

auto index_grad_weight =
at::zeros(grad.type(), {num_weights, grad.size(1)}).contiguous();

std::vector<int64_t> counts_uniq;
counts_uniq.reserve(num_weights);
int64_t o = 0;
Expand All @@ -207,43 +269,46 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
o++;
}

auto index_grad_weight =
at::zeros(grad.type(), {num_weights, grad.sizes()[1]}).contiguous();

#pragma omp parallel for if (numel > 1000)
for (int64_t i = 0; i < (int64_t)counts_uniq.size(); i++) {
int64_t start = i == 0 ? 0 : counts_uniq[i - 1];
int64_t index = indices_data[start];
for (int64_t j = start; j < counts_uniq[i]; j++) {
int64_t source = offset2bag_data[j];
double scale = 1.0;
if (scale_grad_by_freq) {
scale /= counts[indices_data[i]];
}
if (mode == 1) { // MODE_MEAN
if (offsets_.sizes()[0] == 1) {
auto bag_size = indices.sizes()[0];
scale /= bag_size;
} else {
if (source == offsets_.sizes()[0] - 1) {
scale /= indices.sizes()[0] - offsets_data[offsets_.sizes()[0] - 1];
} else {
scale /= offsets_data[source + 1] - offsets_data[source];
if (mode == MODE_MEAN || mode == MODE_SUM) {
#pragma omp parallel for if (numel > 1000)
for (int64_t i = 0; i < (int64_t)counts_uniq.size(); i++) {
int64_t start = i == 0 ? 0 : counts_uniq[i - 1];
int64_t index = indices_data[start];
for (int64_t j = start; j < counts_uniq[i]; j++) {
int64_t source = offset2bag_data[j];
double scale = 1.0;
if (scale_grad_by_freq) {
scale /= counts[indices_data[i]];
}
if (mode == 1) { // MODE_MEAN
if (offsets_.size(0) == 1) {
auto bag_size = indices.size(0);
scale /= bag_size;
} else {
if (source == offsets_.size(0) - 1) {
scale /= indices.size(0) - offsets_data[offsets_.size(0) - 1];
} else {
scale /= offsets_data[source + 1] - offsets_data[source];
}
}
}
int64_t ddim = grad.size(1);
if (grad.type().scalarType() == kFloat) {
auto igwd = index_grad_weight.data<float>();
auto gd = grad.data<float>();
axpy<float>(ddim, (float)scale, gd + ddim * source, 1,
igwd + ddim * index, 1);
} else if (grad.type().scalarType() == kDouble) {
auto igwd = index_grad_weight.data<double>();
auto gd = grad.data<double>();
axpy<double>(ddim, (double)scale, gd + ddim * source, 1,
igwd + ddim * index, 1);
}
}
}
int64_t ddim = grad.sizes()[1];
if (grad.type().scalarType() == kFloat) {
auto igwd = index_grad_weight.data<float>();
auto gd = grad.data<float>();
axpy<float>(ddim, (float)scale, gd + ddim * source, 1,
igwd + ddim * index, 1);
} else if (grad.type().scalarType() == kDouble) {
auto igwd = index_grad_weight.data<double>();
auto gd = grad.data<double>();
axpy<double>(ddim, (double)scale, gd + ddim * source, 1,
igwd + ddim * index, 1);
}
}
} else if (mode == MODE_MAX) {
for (int64_t dim = 0; dim < grad.size(1); dim++) {
index_grad_weight.select(1, dim).index_add_(0, max_indices_.select(1, dim), grad_.select(1, dim));
}
}

Expand Down
Loading