
Conversation

@ssnl ssnl commented Jul 31, 2019

Copy-paste comment from code for reasoning:

```
            # NOTE [ IterableDataset and __len__ ]
            #
            # For `IterableDataset`, `__len__` could be inaccurate when one naively
            # does multi-processing data loading, since the samples will be duplicated.
            # However, no real use case should be actually using that behavior, so
            # it should count as a user error. We should generally trust user
            # code to do the proper thing (e.g., configure each replica differently
            # in `__iter__`), and give us the correct `__len__` if they choose to
            # implement it (this will still throw if the dataset does not implement
            # a `__len__`).
            #
            # To provide a further warning, we track if `__len__` was called on the
            # `DataLoader`, save the returned value in `self._len_called`, and warn
            # if the iterator ends up yielding more than this number of samples.
```
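To make the mechanism in the note concrete, here is a minimal, self-contained sketch (not the actual `DataLoader` code; the class name and structure are illustrative) of recording the value returned by `len()` and warning once the iterator yields more samples than that:

```python
import warnings

class _LengthCheckedLoader:
    """Toy stand-in for a DataLoader wrapping an iterable-style dataset."""

    def __init__(self, dataset):
        self.dataset = dataset
        self._len_called = None  # set only if the user ever calls len(loader)

    def __len__(self):
        # Raises TypeError if the dataset does not implement __len__.
        self._len_called = len(self.dataset)
        return self._len_called

    def __iter__(self):
        num_yielded = 0
        for sample in self.dataset:
            num_yielded += 1
            if self._len_called is not None and num_yielded > self._len_called:
                warnings.warn(
                    f"Dataset reported length {self._len_called}, but "
                    f"{num_yielded} samples have already been fetched."
                )
            yield sample
```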

Fixes #30184

@pytorchbot pytorchbot added the module: dataloader Related to torch.utils.data.DataLoader and Sampler label Jul 31, 2019
@ssnl ssnl mentioned this pull request Jul 31, 2019
@ssnl ssnl requested a review from apaszke July 31, 2019 15:10
Contributor

Naively, I would have expected `__len__` and `__iter__` to agree here. But `__iter__` never terminates, so how can there be a length? Am I misunderstanding the relationship between these magic methods?

Collaborator Author

Ah, sorry for the confusion. I didn't explain this clearly enough.

`len(data_loader)` directly returns `len(self.index_sampler)`. Previously, it would just throw when using an `IterableDataset`, because an `_InfiniteConstantSampler` is used for such a dataset and its `__len__` throws.

However, the `IterableDataset` provided by users may already implement `__len__` itself. So what this patch does is make `_InfiniteConstantSampler.__len__` grab the `__len__` of the dataset (and throw if it is not implemented), so the data loader can expose the provided `__len__`.
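Roughly, the behavior described above could be sketched like this (simplified; treat the names and structure as illustrative, not the exact PyTorch sampler code):

```python
import itertools

class _InfiniteConstantSampler:
    """Simplified sketch: yields dummy indices forever, but forwards __len__
    to the wrapped dataset so len(data_loader) can work."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Indices are unused for iterable-style datasets.
        return itertools.repeat(None)

    def __len__(self):
        # Raises TypeError if the user's IterableDataset has no __len__.
        return len(self.data_source)
```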

Contributor

OK, that makes sense, but it doesn't (as far as I can tell) explain what I was confused about in the first place, which is how `__iter__` relates to `__len__`.

Collaborator Author

@ssnl ssnl Jul 31, 2019

The relation is not enforced strictly. We trust users to give a correct `__len__` that matches what their `__iter__` does, similar to what we did for old map-style datasets, i.e., trusting that the dataset's `__len__` matches what its `__getitem__` supports.
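For example (an illustrative class, not from this PR), a user-defined `IterableDataset` keeps that promise simply by making `__len__` agree with what `__iter__` produces:

```python
from torch.utils.data import IterableDataset

class RangeIterableDataset(IterableDataset):
    """Yields the integers in [start, end); __len__ matches __iter__ by construction."""

    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        return iter(range(self.start, self.end))

    def __len__(self):
        return self.end - self.start
```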

ezyang previously requested changes Jul 31, 2019

@ezyang ezyang left a comment

WorkerSpecificIterableDataset.__len__ seems reasonable, but I'm not so sure about the other one.

@ezyang ezyang dismissed their stale review July 31, 2019 17:18

This looks reasonable, but I am not too confident about this corner of the code. If apaszke doesn't get to looking at this in the near future, bug me again.

Contributor

@apaszke apaszke left a comment

I don't understand why we are supposed to make this change. Having a finite length on an infinite sampler just seems wrong, and the weird interaction with multiprocessing opens up a huge possibility for user errors. I don't think it's worth it, given that you can just do `len(dataset)` if you really want to use that number.


ssnl commented Aug 1, 2019

@apaszke I can make it so this is not implemented via the sampler class (by moving it to `DataLoader.__len__`). The reasoning for this change is that users then don't have to do hacks like the following if they want a generic interface:

```python
dataloader_len = len(dataloader.dataset) if isinstance(dataloader.dataset, IterableDataset) else len(dataloader)
```

This only changes behavior when the user-provided `IterableDataset` has a `__len__`, and I think that if they implement a `__len__`, we should use it, similar to how we use it to compute `DataLoader.__len__` for map-style datasets (with default samplers).


ssnl commented Aug 1, 2019

@apaszke Thanks for reviewing. I fully understand your argument, and would love to hear what you think about my comment above. :)

@ssnl ssnl force-pushed the dl_itd_len branch 2 times, most recently from ac30c62 to 6c407d2 on August 1, 2019 17:11

apaszke commented Aug 2, 2019

The problem is that `__len__` for map-style datasets is very nicely handled, such that no matter whether you run it serially or with 10, 20, or 100 workers, the data loader iterator will always return exactly as many batches/examples as its `__len__` indicates.

This is definitely not the case for iterable datasets, which are generally harder to parallelize, and it will generally be hard to keep the actual number of yielded samples exactly as `__len__` reports it. Hence, I don't think this is a good idea.

I agree that the `isinstance` part is quite ugly, and I'm not sure what the right solution is here, but I'd never like us to have iterators inside PyTorch that declare a different number of elements to be returned than they actually yield (and this is impossible to ensure in the iterable case, because it all depends on user code).
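For context, the "configure each replica differently in `__iter__`" approach mentioned in the note would look roughly like this: each worker shards the range using `torch.utils.data.get_worker_info()`, so the total across workers still matches `__len__`. This is a sketch; the class name and sharding scheme are illustrative, not part of this PR.

```python
import math
from torch.utils.data import IterableDataset, get_worker_info

class ShardedRangeDataset(IterableDataset):
    """Splits [start, end) across DataLoader workers inside __iter__."""

    def __init__(self, start, end):
        self.start, self.end = start, end

    def __len__(self):
        return self.end - self.start

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Single-process loading: yield everything.
            return iter(range(self.start, self.end))
        # Multi-process loading: each worker gets a contiguous shard.
        per_worker = math.ceil((self.end - self.start) / info.num_workers)
        lo = self.start + info.id * per_worker
        hi = min(lo + per_worker, self.end)
        return iter(range(lo, hi))
```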


ssnl commented Aug 2, 2019

@apaszke That is a fair point. I can add some quite cheap checks that `__len__` and the actual number of yielded samples match if `__len__` is implemented (i.e., for multiprocessing loading, the yield count should always be <= len, and == len once all workers finish) and warn (raise?) otherwise. Would that sound like a solution to you?


apaszke commented Aug 2, 2019

I'm just worried that this might end up being too strict... Like, maybe someone wants to have a rough idea about the number of samples that will be returned, so they implement that on the dataset, but at the same time they don't care enough to ensure that it matches exactly.

On the other hand, maybe having a warning would be a good way to highlight potential user errors in the way iterable datasets are replicated... I'm quite conflicted on this one.


ssnl commented Aug 2, 2019

@apaszke We could go extra fancy and only warn if `len(dataloader)` is ever called. I'm not sure myself whether it is a good idea, haha.


apaszke commented Aug 2, 2019

Yeah I thought about that as well, but at this point it's just getting super complicated... But it makes more sense as an API.

@vincentqb vincentqb added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Aug 2, 2019

ssnl commented Dec 5, 2019

Finally got around to updating this. @apaszke, I've implemented a warning for when `len` was called and more than `len` samples are yielded. Does this look good now?

@ssnl ssnl changed the title enable len(dataloader) for iterable dataset Enable len(dataloader) for iterable dataset Dec 5, 2019
@facebook-github-bot facebook-github-bot left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@ailzhang merged this pull request in c37de32.

wuhuikx pushed a commit to wuhuikx/pytorch that referenced this pull request Jan 30, 2020
Summary:
Copy-paste comment from code for reasoning:

```
            # NOTE [ IterableDataset and __len__ ]
            #
            # For `IterableDataset`, `__len__` could be inaccurate when one naively
            # does multi-processing data loading, since the samples will be duplicated.
            # However, no real use case should be actually using that behavior, so
            # it should count as a user error. We should generally trust user
            # code to do the proper thing (e.g., configure each replica differently
            # in `__iter__`), and give us the correct `__len__` if they choose to
            # implement it (this will still throw if the dataset does not implement
            # a `__len__`).
            #
            # To provide a further warning, we track if `__len__` was called on the
            # `DataLoader`, save the returned value in `self._len_called`, and warn
            # if the iterator ends up yielding more than this number of samples.
```

Fixes pytorch#30184
Pull Request resolved: pytorch#23587

Differential Revision: D18852625

Pulled By: ailzhang

fbshipit-source-id: aea8d4d70c7f21aaa69b35908a6f43026493d826

kuraga commented Apr 2, 2024

# NOTE [ IterableDataset and __len__ ]

Related: #120139.
On torch.utils.data.Dataset itself and __len__, see #122410.
