test/test_nn.py — 18 additions, 0 deletions

@@ -985,6 +985,24 @@ def test_embedding_padding_idx(self):
         self.assertEqual(output[0][0].sum().data[0], 0)
         self.assertEqual(output[1][2].sum().data[0], 0)

+        # negative indexing check for padding_idx
+        # padding_idx=-2, num_embeddings=10 ==> index 8 padded
+        embedding = nn.Embedding(10, 20, padding_idx=-2)
+        input = Variable(torch.LongTensor([[0, 2, 8, 5], [4, 8, 0, 9]]))
+        output = embedding(input)
+        self.assertEqual(output[0][2].sum().data[0], 0)
+        self.assertEqual(output[1][1].sum().data[0], 0)
+
+        embedding = nn.Embedding(10, 20, padding_idx=-2, sparse=True)
+        input = Variable(torch.LongTensor([[0, 2, 8, 5], [4, 8, 0, 9]]))
+        output = embedding(input)
+        self.assertEqual(output[0][2].sum().data[0], 0)
+        self.assertEqual(output[1][1].sum().data[0], 0)
+
+        # out of bounds check for padding_idx
+        self.assertRaises(AssertionError, nn.Embedding, num_embeddings=10, embedding_dim=20, padding_idx=25)
+        self.assertRaises(AssertionError, nn.Embedding, num_embeddings=10, embedding_dim=20, padding_idx=-25)
+
     def test_embedding_max_norm(self):
         embedding = nn.Embedding(22, 5, max_norm=1.0)
         input = Variable(torch.LongTensor([2, 8, 8, 6]))
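The new assertions are easy to reproduce outside the test harness. Below is a minimal standalone sketch of the same checks, written against the modern tensor API (plain tensors and .item() rather than the deprecated Variable wrapper and .data[0] used in the diff); the behavior under test is identical.

    import torch
    import torch.nn as nn

    # padding_idx=-2 with num_embeddings=10 resolves to index 8.
    embedding = nn.Embedding(10, 20, padding_idx=-2)
    inp = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]])
    out = embedding(inp)

    # Every occurrence of index 8 maps to the all-zero padding row,
    # since nn.Embedding zero-initializes the padding_idx row.
    assert out[0][2].sum().item() == 0
    assert out[1][1].sum().item() == 0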
torch/nn/functional.py — 8 additions, 2 deletions

@@ -1048,8 +1048,14 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2,

     """
     input = input.contiguous()
-    if padding_idx is None:
-        padding_idx = -1
+    if padding_idx is not None:
+        if padding_idx > 0:
+            assert padding_idx < weight.size(0), 'Padding_idx must be within num_embeddings'
+        elif padding_idx < 0:
+            assert padding_idx >= -weight.size(0), 'Padding_idx must be within num_embeddings'
+            padding_idx = weight.size(0) + padding_idx
+    elif padding_idx is None:
+        padding_idx = -1
     if max_norm is not None:
         with torch.no_grad():
             torch._C._VariableBase.embedding_renorm_(weight, input, max_norm, norm_type)
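The branch structure above reads most clearly as a standalone helper. The sketch below extracts the same logic; normalize_padding_idx is an illustrative name for this write-up, not a function that exists in torch.nn.functional. A positive index must be strictly below num_embeddings, a negative one is translated by adding num_embeddings (Python's usual negative-indexing rule), and None falls through to -1, the sentinel the backend treats as "no padding".

    def normalize_padding_idx(padding_idx, num_embeddings):
        # Illustrative helper mirroring the checks added to F.embedding.
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < num_embeddings, \
                    'Padding_idx must be within num_embeddings'
            elif padding_idx < 0:
                assert padding_idx >= -num_embeddings, \
                    'Padding_idx must be within num_embeddings'
                # Translate the negative index to its positive equivalent.
                padding_idx = num_embeddings + padding_idx
        else:
            # Sentinel the underlying kernel reads as "no padding".
            padding_idx = -1
        return padding_idx

    assert normalize_padding_idx(-2, 10) == 8     # negative index resolved
    assert normalize_padding_idx(None, 10) == -1  # no padding requested
    assert normalize_padding_idx(3, 10) == 3      # in-range index unchanged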
torch/nn/modules/sparse.py — 6 additions, 0 deletions

@@ -79,6 +79,12 @@ def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
         super(Embedding, self).__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
+        if padding_idx is not None:
+            if padding_idx > 0:
+                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
+            elif padding_idx < 0:
+                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
+                padding_idx = self.num_embeddings + padding_idx
         self.padding_idx = padding_idx
         self.max_norm = max_norm
         self.norm_type = norm_type
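Because the module normalizes at construction time, self.padding_idx always holds the resolved non-negative index once __init__ returns. A short usage sketch, assuming an Embedding built with the same shapes as the new tests:

    import torch.nn as nn

    emb = nn.Embedding(10, 20, padding_idx=-2)
    print(emb.padding_idx)  # 8 -- resolved during __init__

    # Out-of-range values fail fast with the assertion added above.
    try:
        nn.Embedding(10, 20, padding_idx=25)
    except AssertionError as err:
        print(err)  # Padding_idx must be within num_embeddings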