Batched Dataloader #26957

@gkossakowski

Description

🚀 Feature

Add a mode to Dataset that enables fetching data in batches instead of item-by-item.

Motivation

If model training consumes relatively small individual examples as input, as in the case of training on tabular data, the Python interpreter overhead of fetching data becomes so large that it hinders training performance. In other words, training becomes CPU-bound (even with multiprocessing enabled).

This came up in a real scenario of the StarSpace model from FAIR.

Pitch

Add an optional __getbatch__ method to Dataset that is analogous to __getitem__ but takes a collection of indices as input. Make the Dataloader aware of such a BatchedDataset: once the Dataloader detects that __getbatch__ is present, it uses that method to fetch data, one batch at a time.
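A minimal sketch of what such a dataset could look like. Note that __getbatch__ is the proposed hook, not an existing PyTorch API, and the class and field names here are invented for illustration:

```python
import torch

class TabularDataset(torch.utils.data.Dataset):
    """Hypothetical dataset implementing the proposed batched protocol."""

    def __init__(self, features, labels):
        self.features = features  # e.g. an (N, D) tensor
        self.labels = labels      # e.g. an (N,) tensor

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # Item-by-item path: one Python call per example.
        return self.features[index], self.labels[index]

    def __getbatch__(self, indices):
        # Batched path: one Python call and one vectorized slice per batch.
        idx = torch.as_tensor(indices)
        return self.features[idx], self.labels[idx]
```

A batch-aware DataLoader would then call `ds.__getbatch__(batch_indices)` once per batch instead of `ds[i]` once per example, which is where the per-example interpreter overhead disappears.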

As a result, the user gains the ability to pass data around in batches end-to-end and avoids the high per-byte cost of the Python interpreter.

I implemented a variant of batch loading for the aforementioned StarSpace model and brought training down from 5.5 days to under 24 hours. The person who originally implemented it used the standard PyTorch data loading abstractions and fell into the low-performance trap.

This is the type of issue anybody working on, e.g., tabular data will run into. Unfortunately, there's no natural way out given the current PyTorch abstractions.

Alternatives

Implement this on top of the existing abstractions by "smuggling" batches of values wrapped as a single value and unwrapping them in a custom collate function. The code I provide below is fairly subtle and a bit hacky (it abuses the current abstractions). It is fully functional and used in production, though.
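The trick can be sketched roughly as follows. This is a simplified illustration of the smuggling idea, not the production code referred to above, and all names are invented: each "item" of the dataset is secretly a whole batch, and with batch_size=1 the collate function just unwraps it.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class SmuggledBatchDataset(Dataset):
    """Each "item" of this dataset is secretly a whole batch of examples."""

    def __init__(self, features, labels, batch_size):
        self.features = features
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, not number of examples.
        return (len(self.labels) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        # Fetch a whole batch with one vectorized slice.
        start = index * self.batch_size
        stop = start + self.batch_size
        return self.features[start:stop], self.labels[start:stop]

def unwrap(samples):
    # With batch_size=1 the DataLoader hands the collate function a
    # one-element list containing the already-formed batch; unwrap it
    # instead of collating.
    return samples[0]

ds = SmuggledBatchDataset(torch.randn(10, 3), torch.arange(10), batch_size=4)
loader = DataLoader(ds, batch_size=1, collate_fn=unwrap)
```

Iterating `loader` then yields feature batches of shape (4, 3), (4, 3), and (2, 3). The hack works, but shuffling now has to shuffle batch indices rather than example indices, which is part of why it abuses the abstractions.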

Edit: I also found #19228, which is a different way of implementing what I need. The downside of IterableDataset is that it essentially throws out the window the nice decomposition into Dataset, Sampler, and Dataloader: suddenly you're responsible for implementing all of that logic yourself. Having said that, it is a big improvement over the rather hacky solution I posted below.
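For reference, the IterableDataset route looks roughly like this (a sketch with invented names; passing batch_size=None disables the DataLoader's automatic batching so the pre-formed batches pass through untouched):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

class BatchedIterable(IterableDataset):
    """Yields ready-made batches; sampling logic now lives in the dataset."""

    def __init__(self, features, labels, batch_size):
        self.features = features
        self.labels = labels
        self.batch_size = batch_size

    def __iter__(self):
        # All sampling/shuffling responsibility moves in here, which is
        # exactly the loss of decomposition noted above.
        for start in range(0, len(self.labels), self.batch_size):
            stop = start + self.batch_size
            yield self.features[start:stop], self.labels[start:stop]

ds = BatchedIterable(torch.randn(10, 3), torch.arange(10), batch_size=4)
# batch_size=None tells the DataLoader not to batch again.
loader = DataLoader(ds, batch_size=None)
```

This yields the same batch shapes as the smuggling workaround, but without a Sampler there is no built-in shuffling, distributed sharding, or per-worker splitting; the dataset has to reimplement whichever of those it needs.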

cc @ssnl


Labels

feature, module: dataloader, triaged
