Enable len(dataloader) for iterable dataset #23587
Conversation
torch/utils/data/dataloader.py
Naively, I would have expected `__len__` and `__iter__` to agree here. But `__iter__` never returns. So how can there be a length? Am I misunderstanding the relationship between these magic methods?
Ah, sorry for the confusion. I didn't explain this clearly enough.
`len(data_loader)` directly returns `len(self.index_sampler)`. Previously this just threw when using an `IterableDataset`, because an `_InfiniteConstantSampler` is used for such a dataset and its `__len__` throws.
However, the `IterableDataset` provided by the user may already implement `__len__` itself. So what this patch does is make `_InfiniteConstantSampler.__len__` grab the `__len__` of the dataset (and throw if it is not implemented), so the data loader exposes the user-provided `__len__`.
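For illustration, a minimal sketch of the delegation described above (not the actual PyTorch source; names and signatures are simplified):

```python
# Sketch of the described change: the infinite sampler used for
# iterable-style datasets forwards __len__ to the user's dataset.
class _InfiniteConstantSampler:
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        while True:          # one constant index per fetched sample
            yield None

    def __len__(self):
        # Raises TypeError if the dataset defines no __len__, so
        # len(data_loader) still throws in that case.
        return len(self.dataset)
```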
OK, that makes sense, but as far as I can tell it doesn't explain what I was confused about in the first place, which is how `__iter__` relates to `__len__`.
The relation is not enforced strictly. We trust users to provide a `__len__` that matches what their `__iter__` does, similar to what we do for old map-style datasets, i.e., trusting that the dataset's `__len__` matches what its `__getitem__` supports.
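As a concrete example of that contract (a hypothetical dataset, not from this PR), an `IterableDataset` whose `__len__` agrees with its `__iter__` might look like:

```python
from torch.utils.data import IterableDataset

class RangeIterableDataset(IterableDataset):
    """Yields the integers in [start, end); __len__ matches __iter__."""

    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        return iter(range(self.start, self.end))

    def __len__(self):
        # Trusted, not enforced: must equal the number of items
        # that __iter__ yields.
        return self.end - self.start
```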
ezyang left a comment
`WorkerSpecificIterableDataset.__len__` seems reasonable, but I'm not so sure about the other one.
This looks reasonable, but I am not too confident about this corner of the code. If apaszke doesn't get to looking at this in the near future, bug me again.
apaszke left a comment
I don't understand why we are supposed to make this change. Having a finite length on an infinite sampler just seems wrong, and the weird interaction with multiprocessing opens up a huge possibility for user errors. I don't think it's worth it, given that you can just do `len(dataset)` if you really want to use that number.
@apaszke I can make it so this is not implemented via the sampler class (by moving the logic into the `DataLoader` itself). This only changes behavior when the user-provided dataset implements `__len__`.
@apaszke Thanks for reviewing. I fully understand your argument, and would love to hear what you think about my comment above. :)
Force-pushed ac30c62 to 6c407d2.
The problem is that for map-style datasets the reported `__len__` has always been exact, so users can rely on it. This is definitely not the case for iterable datasets, which are generally harder to parallelize, and it will generally be hard to keep the length exactly equal to what `len(dataset)` reports once the dataset is replicated across workers. I agree that the number can still be useful as a rough estimate, though.
@apaszke That is a fair point. I can add some quite cheap checks that the iterator yields exactly as many samples as the dataset's `__len__` reports.
I'm just worried that this might end up being too strict... Like, maybe someone wants a rough idea of the number of samples that will be returned, so they implement that on the dataset, but at the same time they don't care enough to ensure that it matches exactly. On the other hand, maybe having a warning would be a good way to highlight potential user errors in the way iterable datasets are replicated... I'm quite conflicted on this one.
@apaszke We could go extra fancy and only warn if `__len__` was actually called on the `DataLoader`.
Yeah, I thought about that as well, but at this point it's just getting super complicated... But it makes more sense as an API.
Finally got around to updating this. @apaszke, I've implemented a warning that fires if `__len__` was called on the `DataLoader` and the iterator then yields more samples than the value returned.
facebook-github-bot left a comment
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Copy-paste comment from code for reasoning:
```
# NOTE [ IterableDataset and __len__ ]
#
# For `IterableDataset`, `__len__` could be inaccurate when one naively
# does multi-processing data loading, since the samples will be duplicated.
# However, no real use case should be actually using that behavior, so
# it should count as a user error. We should generally trust user
# code to do the proper thing (e.g., configure each replica differently
# in `__iter__`), and give us the correct `__len__` if they choose to
# implement it (this will still throw if the dataset does not implement
# a `__len__`).
#
# To provide a further warning, we track if `__len__` was called on the
# `DataLoader`, save the returned value in `self._len_called`, and warn
# if the iterator ends up yielding more than this number of samples.
```
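A rough sketch of that tracking logic (a simplified stand-in class; only `_len_called` is named in the note above, the rest is illustrative):

```python
import warnings

class _LenTrackingLoader:
    """Simplified stand-in for the DataLoader's length-warning behavior."""

    def __init__(self, dataset):
        self.dataset = dataset
        self._len_called = None

    def __len__(self):
        # Remember what we told the user so iteration can check it.
        self._len_called = len(self.dataset)
        return self._len_called

    def __iter__(self):
        yielded = 0
        for sample in self.dataset:
            yielded += 1
            if self._len_called is not None and yielded > self._len_called:
                warnings.warn(
                    f"Length of IterableDataset was reported as "
                    f"{self._len_called}, but {yielded} samples have "
                    f"been fetched.")
            yield sample
```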
Fixes pytorch#30184
Pull Request resolved: pytorch#23587
Differential Revision: D18852625
Pulled By: ailzhang
fbshipit-source-id: aea8d4d70c7f21aaa69b35908a6f43026493d826