Skip to content

Commit b3a77cc

Browse files
committed
fix sparse embedding backward when input contains only padding_idx
1 parent 4c81282 commit b3a77cc

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

aten/src/ATen/native/Embedding.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,19 @@ Tensor embedding_sparse_backward(
6666
grad = grad.index(c);
6767
}
6868

69-
int64_t num_features = grad.size(-1);
69+
int64_t num_features = grad_.size(-1);
7070
auto weight_size = std::array<int64_t, 2>{{ num_weights, num_features }};
71+
auto& dense_type = grad.type();
72+
auto& sparse_type = dense_type.toBackend(grad.is_cuda() ? kSparseCUDA : kSparseCPU);
73+
74+
// check if all our grad come from padding_idx
75+
if (grad.numel() == 0) {
76+
return sparse_type.sparse_coo_tensor(indices_.type().tensor(),
77+
dense_type.tensor(), weight_size);
78+
}
7179

7280
auto index = indices.view({1, -1});
7381
auto values = grad.contiguous().view({-1, num_features});
74-
75-
auto& sparse_type = grad.type().toBackend(grad.is_cuda() ? kSparseCUDA : kSparseCPU);
7682
return sparse_type.sparse_coo_tensor(index, values, weight_size);
7783
}
7884

test/test_nn.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,18 @@ def test_embedding_padding_idx(self):
12911291
self.assertRaises(AssertionError, nn.Embedding, num_embeddings=10, embedding_dim=20, padding_idx=25)
12921292
self.assertRaises(AssertionError, nn.Embedding, num_embeddings=10, embedding_dim=20, padding_idx=-25)
12931293

1294+
# test backward when input contains padding_idx
1295+
padding_idx = 0
1296+
embedding = nn.Embedding(5, 2, padding_idx=padding_idx)
1297+
for n in (1, 2):
1298+
for other_indices in ([], [1, 3], [2]):
1299+
indices = torch.LongTensor(other_indices + [padding_idx] * n)
1300+
pre = embedding.weight[padding_idx].clone()
1301+
embedding(indices).sum().backward()
1302+
after = (embedding.weight + embedding.weight.grad)[padding_idx]
1303+
embedding.zero_grad()
1304+
self.assertEqual(after, pre)
1305+
12941306
def test_embedding_max_norm(self):
12951307
embedding = nn.Embedding(22, 5, max_norm=1.0)
12961308
input = Variable(torch.LongTensor([2, 8, 8, 6]))

torch/nn/modules/sparse.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,17 @@ class Embedding(Module):
1919
padding_idx (int, optional): If given, pads the output with zeros whenever it encounters the index.
2020
max_norm (float, optional): If given, will renormalize the embeddings to always have a norm lesser than this
2121
norm_type (float, optional): The p of the p-norm to compute for the max_norm option
22-
scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the frequency of
22+
scale_grad_by_freq (bool, optional): if given, this will scale gradients by the frequency of
2323
the words in the mini-batch.
24-
sparse (boolean, optional): if ``True``, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for
24+
sparse (bool, optional): if ``True``, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for
2525
more details regarding sparse gradients.
2626
2727
Attributes:
2828
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
2929
3030
Shape:
31-
- Input: LongTensor `(N, W)`, N = mini-batch, W = number of indices to extract per mini-batch
32-
- Output: `(N, W, embedding_dim)`
31+
- Input: LongTensor of arbitrary shape containing the indices to extract
32+
- Output: `(*, embedding_dim)`, where `*` is the input shape
3333
3434
Notes:
3535
Keep in mind that only a limited number of optimizers support
@@ -166,10 +166,10 @@ class EmbeddingBag(Module):
166166
embedding_dim (int): the size of each embedding vector
167167
max_norm (float, optional): If given, will renormalize the embeddings to always have a norm lesser than this
168168
norm_type (float, optional): The p of the p-norm to compute for the max_norm option
169-
scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the frequency of
169+
scale_grad_by_freq (bool, optional): if given, this will scale gradients by the frequency of
170170
the words in the dictionary.
171171
mode (string, optional): 'sum' | 'mean'. Specifies the way to reduce the bag. Default: 'mean'
172-
sparse (boolean, optional): if ``True``, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for
172+
sparse (bool, optional): if ``True``, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for
173173
more details regarding sparse gradients.
174174
175175
Attributes:

0 commit comments

Comments (0)