collate_fn returns subclass of torch.Tensor, but DataLoader transforms back to torch.Tensor #17716

@rubenvereecken

🐛 Bug / Unexpected Behaviour

I have a custom subclass of torch.Tensor, PackedSequences, which packs together sequences of variable length. I give my DataLoader a custom collate_fn, which returns PackedSequences objects just fine. But when I sample a batch, the PackedSequences objects have been converted to plain torch.Tensor, losing the extra data that PackedSequences carries (its lengths attribute).
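If the goal is only to carry the lengths alongside the data, one option that does not depend on a Tensor subclass surviving the DataLoader is to return a plain container from the collate_fn. This is a sketch under assumptions: it uses 1-D sequences, PackedBatch and collate_packed are hypothetical names (not part of PyTorch), and it has not been checked against 0.4.1. The lengths live in their own tensor field, so nothing depends on attributes attached to the data tensor.

```python
from typing import List, NamedTuple

import torch
from torch.utils.data import DataLoader


class PackedBatch(NamedTuple):
    # Hypothetical container: concatenated data plus per-sequence lengths
    data: torch.Tensor
    lengths: torch.Tensor


def collate_packed(batch: List[torch.Tensor]) -> PackedBatch:
    # Concatenate variable-length sequences into one flat tensor and keep
    # the lengths as a separate tensor instead of an attribute on the data
    lengths = torch.tensor([len(t) for t in batch], dtype=torch.long)
    return PackedBatch(torch.cat(batch), lengths)


# Placeholder dataset of variable-length 1-D sequences
dataset = [torch.arange(n, dtype=torch.float32) for n in (3, 1, 4, 2)]
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_packed)

for packed in loader:
    # Recover the individual sequences from the flat tensor
    seqs = torch.split(packed.data, packed.lengths.tolist())
    print(type(packed).__name__, [s.tolist() for s in seqs])
```

torch.split with a list of sizes undoes the concatenation, so downstream code can still work per sequence.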

To Reproduce

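I'm unsure how to create a minimal example with my real data, so the snippet below uses a small placeholder dataset of variable-length sequences.

import torch
from torch.utils.data import DataLoader


class PackedSequences(torch.Tensor):
    # Concatenates variable-length sequences into one flat tensor
    def __new__(cls, tensors):
        flat = torch.cat(tensors)
        # Wrap `flat` as an instance of cls (the same pattern nn.Parameter uses)
        return torch.Tensor._make_subclass(cls, flat)

    def __init__(self, tensors):
        # Extra data that should travel with the batch
        self.lengths = [len(t) for t in tensors]


def collate_seq(batch):
    batch = PackedSequences(batch)
    assert type(batch) == PackedSequences
    return batch


# Placeholder dataset: four 1-D sequences of different lengths
dataset = [torch.arange(n) for n in (3, 1, 4, 2)]
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_seq)
batch = next(iter(dataloader))  # DataLoader is iterable, not indexable
# type(batch) == torch.Tensor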
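To narrow down where the type is lost, the same collate_fn can be run through DataLoader with different settings and the result inspected. This is a sketch under assumptions: batch_type is a hypothetical helper, and the report does not say which loader settings were in use, so whether worker processes (num_workers > 0) or pinned memory (pin_memory=True) trigger the conversion is an open question here. In a single process without pinning, recent PyTorch versions appear to hand back the collate_fn's return value unchanged.

```python
import torch
from torch.utils.data import DataLoader


class PackedSequences(torch.Tensor):
    # Same idea as the reproduction: a flat tensor plus a `lengths` attribute
    def __new__(cls, tensors):
        return torch.Tensor._make_subclass(cls, torch.cat(tensors))

    def __init__(self, tensors):
        self.lengths = [len(t) for t in tensors]


def collate_seq(batch):
    return PackedSequences(batch)


def batch_type(**loader_kwargs):
    # Hypothetical helper: fetch one batch with the given DataLoader settings
    # and report its type name and whether `lengths` survived
    dataset = [torch.arange(n) for n in (3, 1, 4, 2)]
    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_seq, **loader_kwargs)
    batch = next(iter(loader))
    return type(batch).__name__, getattr(batch, "lengths", None)


if __name__ == "__main__":
    print(batch_type())  # single process, no pinning
    # print(batch_type(num_workers=2))    # worker processes
    # print(batch_type(pin_memory=True))  # pinned memory; needs a CUDA device
```

Uncommenting one setting at a time shows which path, if any, returns a plain torch.Tensor or drops lengths.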

Expected behavior

I'd expect the batch to keep its type:

batch = next(iter(dataloader))
# type(batch) == PackedSequences

Environment

PyTorch version: 0.4.1
OS: Linux Mint 18.3 Sylvia
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.9) 5.4.0 20160609
CMake version: version 3.5.1
Python version: 3.6

cc @ssnl @VitalyFedyunin @ejguan @hameerabbasi @rgommers @peterbell10

Labels: module: __torch_function__, module: dataloader, triaged
