Skip to content

Commit ef14590

Browse files
anderspapittoezyang
authored andcommitted
Support calling pack_padedd_sequence with a Variable lengths (#5113)
This was accidentally lost while addressing review comments on #4695 pack_padded_sequence may be called either with a list or with a Variable. If called with a list we convert to Variable internally. I added to test_nn to test the new codepath. The bug was also caught by the onnx-fb-universe tests (which rely on passing in Variable).
1 parent bf60329 commit ef14590

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

test/test_nn.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3053,7 +3053,9 @@ def compare_cpu_gpu(outputs_cpu, outputs_gpu):
30533053
self.assertEqual(cpu_weight.grad.data, gpu_weight.grad.data, prec=5e-5)
30543054

30553055
for module in (nn.RNN, nn.LSTM, nn.GRU):
3056-
for bias, bidirectional, batch_first, contig, variable_len in product((True, False), repeat=5):
3056+
for bias, bidirectional, batch_first, contig, variable_len, lens_as_variable \
3057+
in product((True, False), repeat=6):
3058+
30573059
num_directions = 2 if bidirectional else 1
30583060
if batch_first:
30593061
input_val = torch.randn(batch, seq_length, input_size)
@@ -3073,6 +3075,8 @@ def compare_cpu_gpu(outputs_cpu, outputs_gpu):
30733075

30743076
if variable_len:
30753077
lengths = [7, 5, 5, 2, 1, 1]
3078+
if lens_as_variable:
3079+
lengths = Variable(torch.LongTensor(lengths))
30763080
input_val = Variable(input_val)
30773081
grad_output = Variable(grad_output)
30783082
input_val = rnn_utils.pack_padded_sequence(input_val, lengths, batch_first=batch_first)

torch/nn/_functions/packing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def forward(ctx, input, lengths, batch_first):
1515
steps = []
1616
batch_sizes = []
1717

18-
# lengths is a Tensor, so we must convert to list before reversed()
19-
lengths_iter = reversed(list(lengths))
18+
# lengths is a Tensor, so we must convert to [int] before reversed()
19+
lengths_iter = reversed(lengths.tolist())
2020

2121
batch_size = input.size(1)
2222

torch/nn/utils/rnn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ def is_cuda(self):
9191

9292

9393
def _pack_padded_sequence(input, lengths, batch_first=False):
94+
if isinstance(lengths, list):
95+
lengths = Variable(torch.LongTensor(lengths))
96+
9497
data, batch_sizes = PackPadded.apply(input, lengths, batch_first)
9598

9699
return PackedSequence(data, batch_sizes)

0 commit comments

Comments
 (0)