Skip to content

Commit 0656ef4

Browse files
zou3519 authored and soumith committed
remove sort requirement from pad-sequence (#7928)
* pad-sequence no longer requires sorting entries. pad_sequence can get the max_len from the list of sequences; entries only need to be sorted if the output will be used for pack_padded_sequence, which can throw the error itself.

* remove sort requirement from pad-sequence. Picks up from #5974. Removes the requirement that input sequences to pad_sequence have to be sorted. Addressed the comments in the PR:
  - Updated docstring for pad_sequence
  - Removed sort requirement in pad_sequence test
  - Test unsorted and sorted sequences in pad_sequence test
1 parent c5b895a commit 0656ef4

File tree

2 files changed

+20
-23
lines changed

2 files changed

+20
-23
lines changed

test/test_nn.py

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -3235,33 +3235,40 @@ def pad(tensor, length):
32353235
return torch.cat(
32363236
[tensor.data, tensor.data.new(
32373237
length - tensor.size(0), *tensor.size()[1:]).zero_()])
3238+
32383239
# single dimensional
32393240
a = torch.tensor([1, 2, 3])
32403241
b = torch.tensor([4, 5])
32413242
c = torch.tensor([6])
32423243

32433244
# batch_first = true
3244-
expected = torch.tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]])
3245-
padded = rnn_utils.pad_sequence([a, b, c], True)
3245+
expected = torch.tensor([[4, 5, 0], [1, 2, 3], [6, 0, 0]])
3246+
padded = rnn_utils.pad_sequence([b, a, c], True)
32463247
self.assertEqual(padded, expected)
32473248

32483249
# batch_first = false
3249-
padded = rnn_utils.pad_sequence([a, b, c])
3250+
padded = rnn_utils.pad_sequence([b, a, c])
32503251
self.assertEqual(padded, expected.transpose(0, 1))
32513252

32523253
# pad with non-zero value
3253-
expected = torch.tensor([[1, 2, 3], [4, 5, 1], [6, 1, 1]])
3254-
padded = rnn_utils.pad_sequence([a, b, c], True, 1)
3254+
expected = torch.tensor([[4, 5, 1], [1, 2, 3], [6, 1, 1]])
3255+
padded = rnn_utils.pad_sequence([b, a, c], True, 1)
3256+
self.assertEqual(padded, expected)
3257+
3258+
# Test pad sorted sequence
3259+
expected = torch.tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]])
3260+
padded = rnn_utils.pad_sequence([a, b, c], True)
32553261
self.assertEqual(padded, expected)
32563262

3257-
# more dimensional
3263+
# more dimensions
32583264
maxlen = 9
32593265
for num_dim in (0, 1, 2, 3):
32603266
sequences = []
32613267
trailing_dims = [4] * num_dim
3262-
for i in range(maxlen, 0, -1):
3268+
for i in range(1, maxlen + 1):
32633269
seq_len = i * i
32643270
sequences.append(torch.rand(seq_len, 5, *trailing_dims))
3271+
random.shuffle(sequences)
32653272
expected = []
32663273
for seq in sequences:
32673274
expected.append(pad(seq, maxlen * maxlen))
@@ -3274,10 +3281,6 @@ def pad(tensor, length):
32743281
padded = rnn_utils.pad_sequence(sequences)
32753282
self.assertEqual(padded, expected.transpose(0, 1))
32763283

3277-
# unsorted sequences should raise exception
3278-
self.assertRaises(
3279-
ValueError, lambda: rnn_utils.pad_sequence([b, a, c], [2, 3, 1]))
3280-
32813284
def test_pack_sequence(self):
32823285
def _compatibility_test(sequences, lengths, batch_first):
32833286
padded = rnn_utils.pad_sequence(sequences, batch_first)

torch/nn/utils/rnn.py

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -274,10 +274,9 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):
274274
``pad_sequence`` stacks a list of Tensors along a new dimension,
275275
and pads them to equal length. For example, if the input is list of
276276
sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
277-
otherwise. The list of sequences should be sorted in the order of
278-
decreasing length.
277+
otherwise.
279278
280-
`B` is batch size. It's equal to the number of elements in ``sequences``.
279+
`B` is batch size. It is equal to the number of elements in ``sequences``.
281280
`T` is length of the longest sequence.
282281
`L` is length of the sequence.
283282
`*` is any number of trailing dimensions, including none.
@@ -292,15 +291,15 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):
292291
293292
Note:
294293
This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` where `T` is the
295-
length of longest sequence.
294+
length of the longest sequence.
296295
Function assumes trailing dimensions and type of all the Tensors
297296
in sequences are same.
298297
299298
Arguments:
300299
sequences (list[Tensor]): list of variable length sequences.
301300
batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
302301
``T x B x *`` otherwise
303-
padding_value (float, optional): value for padded elements.
302+
padding_value (float, optional): value for padded elements. Default: 0.
304303
305304
Returns:
306305
Tensor of size ``T x B x *`` if batch_first is False
@@ -310,8 +309,8 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):
310309
# assuming trailing dimensions and type of all the Tensors
311310
# in sequences are same and fetching those from sequences[0]
312311
max_size = sequences[0].size()
313-
max_len, trailing_dims = max_size[0], max_size[1:]
314-
prev_l = max_len
312+
trailing_dims = max_size[1:]
313+
max_len = max([s.size(0) for s in sequences])
315314
if batch_first:
316315
out_dims = (len(sequences), max_len) + trailing_dims
317316
else:
@@ -320,11 +319,6 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):
320319
out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
321320
for i, tensor in enumerate(sequences):
322321
length = tensor.size(0)
323-
# temporary sort check, can be removed when we handle sorting internally
324-
if prev_l < length:
325-
raise ValueError(
326-
"sequences must be sorted in the order of decreasing length")
327-
prev_l = length
328322
# use index notation to prevent duplicate references to the tensor
329323
if batch_first:
330324
out_tensor[i, :length, ...] = tensor

0 commit comments

Comments (0)