Commit b6adf68

Browse files
danielsimigsoumith
authored andcommitted
EmbeddingBag to handle empty bags in all modes (#7389)
1 parent 3f02922 commit b6adf68

File tree: 2 files changed, +75 -35 lines
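
For context, a minimal sketch of the behavior this commit enables, reusing the weights, input and offsets from the updated test below (the test notes that empty bags are handled on CPU only for now). Offsets that repeat, or a trailing offset equal to len(input), denote empty bags, which now come back as all-zero rows in every mode:

import torch
import torch.nn as nn

es = nn.EmbeddingBag(5, 2, mode='sum')
es.weight.data.copy_(torch.arange(1, 11, dtype=torch.float).view(5, 2))
input = torch.tensor([3, 1, 1, 1, 4, 0])
offsets = torch.tensor([0, 0, 3, 3, 6])   # bags 0, 2 and 4 are empty
print(es(input, offsets).detach())
# tensor([[ 0.,  0.],
#         [13., 16.],
#         [ 0.,  0.],
#         [13., 16.],
#         [ 0.,  0.]])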

aten/src/ATen/native/EmbeddingBag.cpp

Lines changed: 36 additions & 17 deletions
@@ -25,8 +25,9 @@ namespace native {
 
 static void make_offset2bag(const Tensor &offsets, const Tensor &indices,
                             Tensor &offset2bag) {
-  offset2bag.index_fill_(0, offsets, 1); // offset2bag = [1 0 1 0 1]
-  offset2bag[0] = 0; // offset2bag = [0 0 1 0 1]
+  offset2bag.index_add_(
+      0, offsets, at::ones_like(offsets)); // offset2bag = [1 0 1 0 1]
+  offset2bag[0] -= 1; // offset2bag = [0 0 1 0 1]
   offset2bag = offset2bag.cumsum(0); // offset2bag = [0 0 1 1 2]
 }
 
@@ -65,13 +66,14 @@ static void index_select_add(const Tensor &select_indices,
 
 static void make_bag_size(const Tensor &offsets, const Tensor &indices,
                           const int64_t mode, Tensor &bag_size) {
-  if (mode == 1) { // MODE_MEAN
+  if (mode == 1 || mode == 2) {
+    // Compute this for MODE_MEAN and MODE_MAX (latter needed for backwards)
     if (offsets.size(0) != 1) {
       bag_size.slice(0, 0, bag_size.size(0) - 1, 1) =
           offsets.slice(0, 1, offsets.size(0), 1) -
           offsets.slice(0, 0, offsets.size(0) - 1, 1);
-      bag_size[-1] = indices.size(0) - offsets[-1];
     }
+    bag_size[-1] = indices.size(0) - offsets[-1];
   }
 }
 
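bag_size is now filled in for MODE_MAX as well (the backward pass uses its zeros to spot empty bags), and the last-bag assignment moved out of the if so it also runs when there is a single offset. In Python terms, a rough equivalent on the test's data:

import torch

indices = torch.tensor([3, 1, 1, 1, 4, 0])
offsets = torch.tensor([0, 0, 3, 3, 6])

bag_size = torch.zeros_like(offsets)
bag_size[:-1] = offsets[1:] - offsets[:-1]    # size of every bag but the last
bag_size[-1] = indices.size(0) - offsets[-1]  # last bag runs to the end of indices
print(bag_size)                               # tensor([0, 3, 0, 3, 0]); zeros mark empty bags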
@@ -83,8 +85,12 @@ static Tensor apply_bag_size(const Tensor &offsets, const Tensor &indices,
     auto bag_size_ = indices.size(0);
     output /= bag_size_;
   } else {
-    auto bag_size_ =
-        bag_size.toType(output.type()).unsqueeze(1).expand_as(output);
+    // Avoid dividing by 0 for empty bags.
+    // Instead we want empty bags to return all 0s
+    auto bag_size_ = at::max(bag_size, at::ones_like(bag_size))
+                         .toType(output.type())
+                         .unsqueeze(1)
+                         .expand_as(output);
     output /= bag_size_;
   }
 }
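Clamping the divisor to 1 keeps MODE_MEAN from computing 0/0 on an empty bag: its accumulated sum is already all zeros, and dividing by 1 leaves it that way. A small, illustrative Python analogue:

import torch

bag_sums = torch.tensor([[0., 0.], [13., 16.], [0., 0.]])  # bags 0 and 2 are empty
bag_size = torch.tensor([0, 3, 0])
denom = torch.max(bag_size, torch.ones_like(bag_size)).to(bag_sums.dtype).unsqueeze(1)
print(bag_sums / denom)  # empty bags stay [0, 0]; the non-empty bag becomes its mean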
@@ -113,7 +119,7 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
 template <typename scalar_t>
 std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
     const Tensor& weight, const Tensor &indices, const Tensor& offset2bag, const Tensor& output, const Tensor& bag_size, const Tensor& offsets) {
-
+
   auto max_indices = at::zeros(indices.type(), {offsets.size(0), weight.size(1)});
 
   int64_t numel = indices.numel();
@@ -125,9 +131,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
   auto max_indices_stride = max_indices.stride(0);
 
   auto weight_data = weight.data<scalar_t>();
-  auto output_data = output.data<scalar_t>();
-  auto weight_stride = weight.stride(0);
-  auto output_stride = output.stride(0);
+  auto output_data = output.data<scalar_t>();
+  auto weight_stride = weight.stride(0);
+  auto output_stride = output.stride(0);
 
   for (int i = 0; i < numel; i++) {
     auto bag = offset2bag_data[i];
@@ -137,7 +143,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
     for (int dim = 0; dim < dims; dim++) {
       auto& current_item = output_data[output_stride * bag + dim];
       auto weight_item = weight_data[weight_stride * word_idx + dim];
-
+
       bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag;
 
       if (is_first_for_bag || weight_item > current_item) {
@@ -164,9 +170,19 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
   checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble});
 
   auto bag_size = at::zeros(indices.type(), offsets.sizes());
-  auto offset2bag =
-      at::zeros(indices__.type(), {indices.size(0)}); // offset2bag = [0 0 0 0 0]
+  make_bag_size(offsets, indices, mode, bag_size);
+
+  // If the last entries are empty, that the last offsets are irrelevant as they
+  // won't change anything in the assignment of ID -> bag, but index_add would
+  // throw out of bounds error. So to keep it simple we just add one more
+  // entry to the end then get rid of it after make_offset2bag.
+  auto offset2bag = at::zeros(
+      indices__.type(), {indices.sizes()[0] + 1}); // offset2bag = [0 0 0 0 0]
+
   make_offset2bag(offsets, indices, offset2bag);
+
+  offset2bag.resize_({indices.sizes()[0]});
+
   auto output = at::zeros(weight.type(), {offsets.size(0), weight.size(1)});
 
   if (mode == MODE_MEAN || mode == MODE_SUM) {
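Putting the index_add_-based make_offset2bag and the extra padding slot together, here is a rough Python rendering of the index-to-bag mapping on the test's data (an illustrative sketch, not the ATen code itself). index_fill_ could not express a repeated offset, since the bag counter has to advance by more than one at that position:

import torch

indices = torch.tensor([3, 1, 1, 1, 4, 0])
offsets = torch.tensor([0, 0, 3, 3, 6])      # bags 0, 2 and 4 are empty

# One spare slot so the trailing offset (== len(indices)) stays in bounds for index_add_.
offset2bag = torch.zeros(indices.size(0) + 1, dtype=torch.long)
offset2bag.index_add_(0, offsets, torch.ones_like(offsets))  # [2, 0, 0, 2, 0, 0, 1]
offset2bag[0] -= 1                                           # [1, 0, 0, 2, 0, 0, 1]
offset2bag = offset2bag.cumsum(0)[:indices.size(0)]          # [1, 1, 1, 3, 3, 3]
# Indices 0-2 land in bag 1, indices 3-5 in bag 3; bags 0, 2 and 4 receive nothing.
# The old index_fill_ would have collapsed the repeated offsets to a single 1 and
# silently skipped the empty bags.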
@@ -175,7 +191,6 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
     } else if (weight.type().scalarType() == kDouble) {
       index_select_add<double>(indices, offset2bag, weight, output);
     }
-    make_bag_size(offsets, indices, mode, bag_size);
     auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
     return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
   } else { // MODE_MAX
@@ -190,7 +205,7 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
 Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__,
                               const Tensor &offsets__,
                               const Tensor &offset2bag__,
-                              const Tensor &bag_size_,
+                              const Tensor &bag_size_,
                               const Tensor &max_indices_,
                               int64_t num_weights,
                               bool scale_grad_by_freq, int64_t mode,
@@ -305,10 +320,14 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
                                igwd + ddim * index, 1);
         }
       }
-    }
+    }
   } else if (mode == MODE_MAX) {
+    auto nonempty_max_indices = max_indices_.index_select(0, bag_size_.nonzero().view(-1));
+    auto nonempty_grad = grad_.index_select(0, bag_size_.nonzero().view(-1));
+
     for (int64_t dim = 0; dim < grad.size(1); dim++) {
-      index_grad_weight.select(1, dim).index_add_(0, max_indices_.select(1, dim), grad_.select(1, dim));
+      index_grad_weight.select(1, dim).index_add_(
+        0, nonempty_max_indices.select(1, dim), nonempty_grad.select(1, dim));
     }
   }
 
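In the max-mode backward, an empty bag's row of max_indices is just the zeros it was initialized with, so scattering its incoming gradient would wrongly credit embedding row 0. Selecting only the bags with a nonzero bag_size drops those contributions. A Python sketch using the test's values (the max_indices rows follow from the forward pass on the test's weights):

import torch

bag_size = torch.tensor([0, 3, 0, 3, 0])
max_indices = torch.tensor([[0, 0], [3, 3], [0, 0], [4, 4], [0, 0]])
grad = torch.tensor([[99., 99.], [1., 2.], [99., 99.], [3., 4.], [99., 99.]])

keep = bag_size.nonzero().view(-1)   # bags that actually hold indices
grad_weight = torch.zeros(5, 2)
for dim in range(grad.size(1)):
    grad_weight.select(1, dim).index_add_(
        0, max_indices.index_select(0, keep).select(1, dim),
        grad.index_select(0, keep).select(1, dim))
print(grad_weight)  # only rows 3 and 4 receive gradient; the 99s never land anywhere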
test/test_nn.py

Lines changed: 39 additions & 18 deletions
@@ -1530,33 +1530,52 @@ def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
         es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, dtype)
         es.weight.data.copy_(torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight))
         input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long)
-        offsets = torch.tensor([0, 3], device=device, dtype=torch.long)
-        grad_output = torch.arange(1, 5, device=device, dtype=dtype).view(2, 2)
 
-        if mode == 'sum':
+        # Empty list is only handled in CPU for now
+        offsets = torch.tensor([0, 3], device=device, dtype=torch.long) if cuda \
+            else torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long)
+
+        grad_output = torch.tensor(
+            [1, 2,
+             3, 4], device=device, dtype=dtype).view(2, 2)
+        grad_output_with_empty = torch.tensor(
+            [99, 99,
+             1, 2,
+             99, 99,
+             3, 4,
+             99, 99], device=device, dtype=dtype).view(5, 2)
+
+        if mode == "sum" or mode == "mean":
+            denominator = 1 if mode == "sum" else 3
             expected_output = torch.tensor(
                 [[13, 16],
-                 [13, 16]], device=device, dtype=dtype)
+                 [13, 16]], device=device, dtype=dtype) / denominator
+
+            expected_output_with_empty = torch.tensor(
+                [[0, 0],
+                 [13, 16],
+                 [0, 0],
+                 [13, 16],
+                 [0, 0]], device=device, dtype=dtype) / denominator
+
             expected_grad_weight = torch.tensor(
                 [[3, 4],
                  [5, 8],
                  [0, 0],
                  [1, 2],
-                 [3, 4]], device=device, dtype=dtype)
-        elif mode == 'mean':
-            expected_output = torch.tensor(
-                [[13. / 3, 16. / 3],
-                 [13. / 3, 16. / 3]], device=device, dtype=dtype)
-            expected_grad_weight = torch.tensor(
-                [[3. / 3, 4. / 3],
-                 [1. / 3 + 1. / 3 + 3. / 3, 2. / 3 + 2. / 3 + 4. / 3],
-                 [0., 0.],
-                 [1. / 3, 2. / 3],
-                 [3. / 3, 4. / 3]], device=device, dtype=dtype)
-        elif mode == 'max':
+                 [3, 4]], device=device, dtype=dtype) / denominator
+        elif mode == "max":
             expected_output = torch.tensor(
                 [[7, 8],
                  [9, 10]], device=device, dtype=dtype)
+
+            expected_output_with_empty = torch.tensor(
+                [[0, 0],
+                 [7, 8],
+                 [0, 0],
+                 [9, 10],
+                 [0, 0]], device=device, dtype=dtype)
+
             expected_grad_weight = torch.tensor(
                 [[0, 0],
                  [0, 0],
@@ -1565,12 +1584,14 @@ def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
                 [3, 4]], device=device, dtype=dtype)
 
         output = es(input, offsets)
-        output.backward(grad_output)
+        output.backward(grad_output if cuda else grad_output_with_empty)
 
         es_weight_grad = es.weight.grad.data
         if sparse:
             es_weight_grad = es.weight.grad.data.to_dense()
-        self.assertEqual(output.data, expected_output)
+        self.assertEqual(
+            output.data,
+            expected_output if cuda else expected_output_with_empty)
         self.assertEqual(es_weight_grad, expected_grad_weight, dtype2prec[dtype])
 
         # check same example except as 2D (2 x 3)
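
The 99s in grad_output_with_empty act as tripwires: if an empty bag leaked any gradient, expected_grad_weight could not match. The same idea as a quick standalone check (a hypothetical script, not part of the test suite; CPU, max mode):

import torch
import torch.nn as nn

es = nn.EmbeddingBag(5, 2, mode='max')
es.weight.data.copy_(torch.arange(1, 11, dtype=torch.float).view(5, 2))
input = torch.tensor([3, 1, 1, 1, 4, 0])
offsets = torch.tensor([0, 0, 3, 3, 6])

out = es(input, offsets)
out.backward(torch.tensor([[99., 99.], [1., 2.], [99., 99.], [3., 4.], [99., 99.]]))
print(es.weight.grad)  # rows 3 and 4 get [1, 2] and [3, 4]; the 99s never appear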
