Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3235,33 +3235,40 @@ def pad(tensor, length):
return torch.cat(
[tensor.data, tensor.data.new(
length - tensor.size(0), *tensor.size()[1:]).zero_()])

# single dimensional
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
c = torch.tensor([6])

# batch_first = true
expected = torch.tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]])
padded = rnn_utils.pad_sequence([a, b, c], True)
expected = torch.tensor([[4, 5, 0], [1, 2, 3], [6, 0, 0]])
padded = rnn_utils.pad_sequence([b, a, c], True)
self.assertEqual(padded, expected)

# batch_first = false
padded = rnn_utils.pad_sequence([a, b, c])
padded = rnn_utils.pad_sequence([b, a, c])
self.assertEqual(padded, expected.transpose(0, 1))

# pad with non-zero value
expected = torch.tensor([[1, 2, 3], [4, 5, 1], [6, 1, 1]])
padded = rnn_utils.pad_sequence([a, b, c], True, 1)
expected = torch.tensor([[4, 5, 1], [1, 2, 3], [6, 1, 1]])
padded = rnn_utils.pad_sequence([b, a, c], True, 1)
self.assertEqual(padded, expected)

# Test pad sorted sequence
expected = torch.tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]])
padded = rnn_utils.pad_sequence([a, b, c], True)
self.assertEqual(padded, expected)

# more dimensional
# more dimensions
maxlen = 9
for num_dim in (0, 1, 2, 3):
sequences = []
trailing_dims = [4] * num_dim
for i in range(maxlen, 0, -1):
for i in range(1, maxlen + 1):
seq_len = i * i
sequences.append(torch.rand(seq_len, 5, *trailing_dims))
random.shuffle(sequences)
expected = []
for seq in sequences:
expected.append(pad(seq, maxlen * maxlen))
Expand All @@ -3274,10 +3281,6 @@ def pad(tensor, length):
padded = rnn_utils.pad_sequence(sequences)
self.assertEqual(padded, expected.transpose(0, 1))

# unsorted sequences should raise exception
self.assertRaises(
ValueError, lambda: rnn_utils.pad_sequence([b, a, c], [2, 3, 1]))

def test_pack_sequence(self):
def _compatibility_test(sequences, lengths, batch_first):
padded = rnn_utils.pad_sequence(sequences, batch_first)
Expand Down
18 changes: 6 additions & 12 deletions torch/nn/utils/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,9 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):
``pad_sequence`` stacks a list of Tensors along a new dimension,
and pads them to equal length. For example, if the input is list of
sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
otherwise. The list of sequences should be sorted in the order of
decreasing length.
otherwise.

`B` is batch size. It's equal to the number of elements in ``sequences``.
`B` is batch size. It is equal to the number of elements in ``sequences``.
`T` is length of the longest sequence.
`L` is length of the sequence.
`*` is any number of trailing dimensions, including none.
Expand All @@ -292,15 +291,15 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):

Note:
This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` where `T` is the
length of longest sequence.
length of the longest sequence.
Function assumes trailing dimensions and type of all the Tensors
in sequences are same.

Arguments:
sequences (list[Tensor]): list of variable length sequences.
batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
``T x B x *`` otherwise
padding_value (float, optional): value for padded elements.
padding_value (float, optional): value for padded elements. Default: 0.

Returns:
Tensor of size ``T x B x *`` if batch_first is False
Expand All @@ -310,8 +309,8 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
max_size = sequences[0].size()
max_len, trailing_dims = max_size[0], max_size[1:]
prev_l = max_len
trailing_dims = max_size[1:]
max_len = max([s.size(0) for s in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
Expand All @@ -320,11 +319,6 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):
out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
for i, tensor in enumerate(sequences):
length = tensor.size(0)
# temporary sort check, can be removed when we handle sorting internally
if prev_l < length:
raise ValueError(
"sequences must be sorted in the order of decreasing length")
prev_l = length
# use index notation to prevent duplicate references to the tensor
if batch_first:
out_tensor[i, :length, ...] = tensor
Expand Down