2 changes: 1 addition & 1 deletion test/test_nn.py
@@ -2795,7 +2795,7 @@ def forward_backward(cuda, rnn, input_val, hx_val, grad_output, grad_hy, weights

         if isinstance(input_val, rnn_utils.PackedSequence):
             input = rnn_utils.PackedSequence(
-                Variable(input_val.data, requires_grad=True), input_val.batch_sizes)
+                Variable(input_val.data.data, requires_grad=True), input_val.batch_sizes)
             input_var = input.data
         else:
             input = Variable(input_val.clone(), requires_grad=True)
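The extra .data here follows from the change below: after this PR, the data field of a PackedSequence returned by pack_padded_sequence is itself a Variable (the output of PackPadded.apply) rather than a raw tensor, so the test has to unwrap one more level to build a fresh leaf. A minimal sketch of the distinction, with illustrative shapes and lengths not taken from the PR:

import torch
from torch.autograd import Variable
import torch.nn.utils.rnn as rnn_utils

# After this PR, packed.data is a Variable, not a raw tensor.
packed = rnn_utils.pack_padded_sequence(
    Variable(torch.randn(3, 2, 4)), [3, 1])
raw = packed.data.data                    # Variable -> underlying tensor
leaf = Variable(raw, requires_grad=True)  # fresh leaf, as in the test above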
53 changes: 53 additions & 0 deletions torch/nn/_functions/packing.py
@@ -0,0 +1,53 @@
import torch
from torch.autograd import Function


class PackPadded(Function):
    @staticmethod
    def forward(ctx, input, lengths, batch_first):
        if batch_first:
            input = input.transpose(0, 1)

        if lengths[-1] <= 0:
            raise ValueError("Length of all samples has to be greater than 0, "
                             "but found an element in 'lengths' that is <= 0")

        steps = []
        batch_sizes = []

        # 'lengths' is sorted in decreasing order, so iterating it in reverse
        # visits the shortest sequence first; each increase in length opens a
        # new slab of timesteps shared by the remaining (longer) sequences.
        lengths_iter = reversed(lengths)
        batch_size = input.size(1)

        if len(lengths) != batch_size:
            raise ValueError("Expected `len(lengths)` to be equal to batch_size, but got "
                             "{} (batch_size={}).".format(len(lengths), batch_size))

        prev_l = 0
        for i, l in enumerate(lengths_iter):
            if l > prev_l:
                c_batch_size = batch_size - i
                steps.append(input[prev_l:l, :c_batch_size].contiguous().view(-1, *input.size()[2:]))
                batch_sizes.extend([c_batch_size] * (l - prev_l))

                prev_l = l
            elif prev_l > l:
                raise ValueError("'lengths' array has to be sorted in decreasing order")

        ctx.batch_sizes = batch_sizes
        ctx.batch_first = batch_first
        ctx.input_size = input.size()

        return torch.cat(steps), torch.LongTensor(batch_sizes)

    @staticmethod
    def backward(ctx, grad_steps, grad_batch_sizes):
        # Scatter the gradient of the packed rows back into a zero-filled
        # tensor of the original (seq-major) padded shape.
        grad_input = grad_steps.new(*ctx.input_size).zero_()

        offset = 0
        for i, bs in enumerate(ctx.batch_sizes):
            grad_input[i, :bs] = grad_steps[offset:offset + bs]
            offset += bs

        if ctx.batch_first:
            grad_input = grad_input.transpose(0, 1)

        # 'lengths' and 'batch_first' do not require gradients.
        return grad_input, None, None
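For reference, a minimal sketch of what the new function computes, with illustrative shapes and lengths that are not part of the PR:

import torch
from torch.autograd import Variable
from torch.nn._functions.packing import PackPadded

# Two sequences of lengths 3 and 1, zero-padded to T=3, in seq-major
# layout (T=3, B=2, features=4); lengths must be sorted in decreasing order.
padded = Variable(torch.randn(3, 2, 4), requires_grad=True)

data, batch_sizes = PackPadded.apply(padded, [3, 1], False)
# data stacks the non-padding rows step by step: 2 rows at t=0 (both
# sequences active), then 1 row each at t=1 and t=2, so data has
# sum(lengths) == 4 rows and batch_sizes holds [2, 1, 1].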
29 changes: 5 additions & 24 deletions torch/nn/utils/rnn.py
@@ -4,6 +4,8 @@
 from torch.autograd import Variable
 
 
+from .._functions.packing import PackPadded
+
 PackedSequence_ = namedtuple('PackedSequence', ['data', 'batch_sizes'])


@@ -56,30 +58,9 @@ def pack_padded_sequence(input, lengths, batch_first=False):
     Returns:
         a :class:`PackedSequence` object
     """
-    if lengths[-1] <= 0:
-        raise ValueError("length of all samples has to be greater than 0, "
-                         "but found an element in 'lengths' that is <=0")
-    if batch_first:
-        input = input.transpose(0, 1)
-
-    steps = []
-    batch_sizes = []
-    lengths_iter = reversed(lengths)
-    batch_size = input.size(1)
-    if len(lengths) != batch_size:
-        raise ValueError("lengths array has incorrect size")
-
-    prev_l = 0
-    for i, l in enumerate(lengths_iter):
-        if l > prev_l:
-            c_batch_size = batch_size - i
-            steps.append(input[prev_l:l, :c_batch_size].contiguous().view(-1, *input.size()[2:]))
-            batch_sizes.extend([c_batch_size] * (l - prev_l))
-            prev_l = l
-        elif prev_l > l:  # remember that new_length is the preceding length in the array
-            raise ValueError("lengths array has to be sorted in decreasing order")
-
-    return PackedSequence(torch.cat(steps), batch_sizes)
+    data, batch_sizes = PackPadded.apply(input, lengths, batch_first)
+
+    return PackedSequence(data, list(batch_sizes.data))
 
 
 def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0):
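Since packing is now routed through an autograd Function, gradients flow back through pack_padded_sequence to the padded input, with padding positions receiving zero gradient. A quick illustrative check, not taken from the PR's tests:

import torch
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence

padded = Variable(torch.randn(3, 2, 4), requires_grad=True)
packed = pack_padded_sequence(padded, [3, 1])

# Backprop an all-ones gradient through the packed rows: real timesteps
# of `padded` get gradient 1, padding positions stay 0.
packed.data.backward(torch.ones(packed.data.size()))
assert padded.grad.data[1:, 1].abs().sum() == 0  # padding of the length-1 sequence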