Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 36 additions & 17 deletions aten/src/ATen/native/EmbeddingBag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ namespace native {

// Turns the bag start positions in `offsets` into a per-position bag id and
// stores the result in `offset2bag`. `indices` is not read here.
//
// `offset2bag` is expected to arrive zero-filled (the visible caller in this
// file allocates it with at::zeros). The trailing comments trace an example
// with three bags starting at positions 0, 2 and 4.
//
// index_add_ accumulates rather than overwrites, so an offset that occurs more
// than once (i.e. an empty bag) adds one per occurrence and the cumulative sum
// below skips over the empty bag's id instead of collapsing it.
static void make_offset2bag(const Tensor &offsets, const Tensor &indices,
                            Tensor &offset2bag) {
  offset2bag.index_add_(
      0, offsets, at::ones_like(offsets)); // offset2bag = [1 0 1 0 1]
  // Subtract instead of assigning 0: the first bag has id 0, but any repeated
  // leading offsets must keep their extra counts.
  offset2bag[0] -= 1;                // offset2bag = [0 0 1 0 1]
  offset2bag = offset2bag.cumsum(0); // offset2bag = [0 0 1 1 2]
}

Expand Down Expand Up @@ -65,13 +66,14 @@ static void index_select_add(const Tensor &select_indices,

// Fills `bag_size` with the number of entries of `indices` that belong to each
// bag described by `offsets`. Does nothing unless `mode` is 1 or 2.
//
// Each bag except the last spans from its own offset to the next one; the last
// bag spans from its offset to the end of `indices`.
static void make_bag_size(const Tensor &offsets, const Tensor &indices,
                          const int64_t mode, Tensor &bag_size) {
  if (mode == 1 || mode == 2) {
    // Compute this for MODE_MEAN and MODE_MAX (latter needed for backwards)
    if (offsets.size(0) != 1) {
      // Sizes of all bags but the last: difference of adjacent offsets.
      bag_size.slice(0, 0, bag_size.size(0) - 1, 1) =
          offsets.slice(0, 1, offsets.size(0), 1) -
          offsets.slice(0, 0, offsets.size(0) - 1, 1);
    }
    // Set outside the branch above so the single-offset case is filled too.
    bag_size[-1] = indices.size(0) - offsets[-1];
  }
}

Expand All @@ -83,8 +85,12 @@ static Tensor apply_bag_size(const Tensor &offsets, const Tensor &indices,
auto bag_size_ = indices.size(0);
output /= bag_size_;
} else {
auto bag_size_ =
bag_size.toType(output.type()).unsqueeze(1).expand_as(output);
// Avoid dividing by 0 for empty bags.
// Instead we want empty bags to return all 0s
auto bag_size_ = at::max(bag_size, at::ones_like(bag_size))
.toType(output.type())
.unsqueeze(1)
.expand_as(output);
output /= bag_size_;
}
}
Expand Down Expand Up @@ -113,7 +119,7 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
template <typename scalar_t>
std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
const Tensor& weight, const Tensor &indices, const Tensor& offset2bag, const Tensor& output, const Tensor& bag_size, const Tensor& offsets) {

auto max_indices = at::zeros(indices.type(), {offsets.size(0), weight.size(1)});

int64_t numel = indices.numel();
Expand All @@ -125,9 +131,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
auto max_indices_stride = max_indices.stride(0);

auto weight_data = weight.data<scalar_t>();
auto output_data = output.data<scalar_t>();
auto weight_stride = weight.stride(0);
auto output_stride = output.stride(0);
auto output_data = output.data<scalar_t>();
auto weight_stride = weight.stride(0);
auto output_stride = output.stride(0);

for (int i = 0; i < numel; i++) {
auto bag = offset2bag_data[i];
Expand All @@ -137,7 +143,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
for (int dim = 0; dim < dims; dim++) {
auto& current_item = output_data[output_stride * bag + dim];
auto weight_item = weight_data[weight_stride * word_idx + dim];

bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag;

if (is_first_for_bag || weight_item > current_item) {
Expand All @@ -164,9 +170,19 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble});

auto bag_size = at::zeros(indices.type(), offsets.sizes());
auto offset2bag =
at::zeros(indices__.type(), {indices.size(0)}); // offset2bag = [0 0 0 0 0]
make_bag_size(offsets, indices, mode, bag_size);

// If the last bags are empty, then their offsets are irrelevant: they do not
// change the assignment of ID -> bag. However, index_add_ would throw an
// out-of-bounds error for them. To keep it simple, we add one extra entry at
// the end and drop it again after make_offset2bag.
auto offset2bag = at::zeros(
indices__.type(), {indices.sizes()[0] + 1}); // offset2bag = [0 0 0 0 0]

make_offset2bag(offsets, indices, offset2bag);

offset2bag.resize_({indices.sizes()[0]});

auto output = at::zeros(weight.type(), {offsets.size(0), weight.size(1)});

if (mode == MODE_MEAN || mode == MODE_SUM) {
Expand All @@ -175,7 +191,6 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
} else if (weight.type().scalarType() == kDouble) {
index_select_add<double>(indices, offset2bag, weight, output);
}
make_bag_size(offsets, indices, mode, bag_size);
auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
} else { // MODE_MAX
Expand All @@ -190,7 +205,7 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__,
const Tensor &offsets__,
const Tensor &offset2bag__,
const Tensor &bag_size_,
const Tensor &bag_size_,
const Tensor &max_indices_,
int64_t num_weights,
bool scale_grad_by_freq, int64_t mode,
Expand Down Expand Up @@ -305,10 +320,14 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
igwd + ddim * index, 1);
}
}
}
}
} else if (mode == MODE_MAX) {
auto nonempty_max_indices = max_indices_.index_select(0, bag_size_.nonzero().view(-1));
auto nonempty_grad = grad_.index_select(0, bag_size_.nonzero().view(-1));

for (int64_t dim = 0; dim < grad.size(1); dim++) {
index_grad_weight.select(1, dim).index_add_(0, max_indices_.select(1, dim), grad_.select(1, dim));
index_grad_weight.select(1, dim).index_add_(
0, nonempty_max_indices.select(1, dim), nonempty_grad.select(1, dim));
}
}

Expand Down
57 changes: 39 additions & 18 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,33 +1530,52 @@ def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, dtype)
es.weight.data.copy_(torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight))
input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long)
offsets = torch.tensor([0, 3], device=device, dtype=torch.long)
grad_output = torch.arange(1, 5, device=device, dtype=dtype).view(2, 2)

if mode == 'sum':
# Empty list is only handled in CPU for now
offsets = torch.tensor([0, 3], device=device, dtype=torch.long) if cuda \
else torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long)

grad_output = torch.tensor(
[1, 2,
3, 4], device=device, dtype=dtype).view(2, 2)
grad_output_with_empty = torch.tensor(
[99, 99,
1, 2,
99, 99,
3, 4,
99, 99], device=device, dtype=dtype).view(5, 2)

if mode == "sum" or mode == "mean":
denominator = 1 if mode == "sum" else 3
expected_output = torch.tensor(
[[13, 16],
[13, 16]], device=device, dtype=dtype)
[13, 16]], device=device, dtype=dtype) / denominator

expected_output_with_empty = torch.tensor(
[[0, 0],
[13, 16],
[0, 0],
[13, 16],
[0, 0]], device=device, dtype=dtype) / denominator

expected_grad_weight = torch.tensor(
[[3, 4],
[5, 8],
[0, 0],
[1, 2],
[3, 4]], device=device, dtype=dtype)
elif mode == 'mean':
expected_output = torch.tensor(
[[13. / 3, 16. / 3],
[13. / 3, 16. / 3]], device=device, dtype=dtype)
expected_grad_weight = torch.tensor(
[[3. / 3, 4. / 3],
[1. / 3 + 1. / 3 + 3. / 3, 2. / 3 + 2. / 3 + 4. / 3],
[0., 0.],
[1. / 3, 2. / 3],
[3. / 3, 4. / 3]], device=device, dtype=dtype)
elif mode == 'max':
[3, 4]], device=device, dtype=dtype) / denominator
elif mode == "max":
expected_output = torch.tensor(
[[7, 8],
[9, 10]], device=device, dtype=dtype)

expected_output_with_empty = torch.tensor(
[[0, 0],
[7, 8],
[0, 0],
[9, 10],
[0, 0]], device=device, dtype=dtype)

expected_grad_weight = torch.tensor(
[[0, 0],
[0, 0],
Expand All @@ -1565,12 +1584,14 @@ def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
[3, 4]], device=device, dtype=dtype)

output = es(input, offsets)
output.backward(grad_output)
output.backward(grad_output if cuda else grad_output_with_empty)

es_weight_grad = es.weight.grad.data
if sparse:
es_weight_grad = es.weight.grad.data.to_dense()
self.assertEqual(output.data, expected_output)
self.assertEqual(
output.data,
expected_output if cuda else expected_output_with_empty)
self.assertEqual(es_weight_grad, expected_grad_weight, dtype2prec[dtype])

# check same example except as 2D (2 x 3)
Expand Down