Add length and padding keywords to DistributedSampler #28841
Conversation
The current implementation of `DistributedSampler` is ideal for distributed training with map-style datasets, as they fit in memory and have a known size. However, it does not support distributed training with `IterableDataset` datasets, because those classes do not implement `__len__`. To fix that, a `length` keyword was added to `DistributedSampler`, which takes precedence when set. An extra `padding=True` parameter was also added to give finer control over whether the (returned) index list should be padded by the sampler. This is useful for preventing duplicate reads on `IterableDataset` datasets that do not fit in memory or whose data reading or transformation is expensive. Finally, a `set_rank` method was added, similar to the existing `set_epoch`, to ease distributed training: when a `DataLoader` is created with `num_workers > 0` and `dataset` is an instance of `ChunkDataset`, the copy of `DistributedSampler` in each worker needs to be configured with its new rank. There is no backward-compatibility break with this change.
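For illustration only, here is a sketch of how the proposed keywords might look in use. Note that `length`, `padding`, and `set_rank` are the additions proposed in this PR and do not exist in released PyTorch; `StreamDataset` is a hypothetical placeholder, and the process group is assumed to be initialized already.

```python
import torch.distributed as dist
from torch.utils.data import IterableDataset
from torch.utils.data.distributed import DistributedSampler

class StreamDataset(IterableDataset):
    """Hypothetical placeholder: a streamed dataset without __len__."""
    def __iter__(self):
        return iter(range(100_000))

dataset = StreamDataset()

sampler = DistributedSampler(
    dataset,
    num_replicas=dist.get_world_size(),
    rank=dist.get_rank(),
    shuffle=True,
    length=100_000,   # proposed in this PR: explicit number of indices to generate
    padding=False,    # proposed in this PR: do not pad shards to equal size
)

for epoch in range(10):
    sampler.set_epoch(epoch)   # existing API: reshuffle per epoch
    # sampler.set_rank(...)    # proposed in this PR: re-rank a per-worker copy
    for idx in sampler:
        pass
```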
Out of curiosity, are the …

Manually :)

Thanks for splitting this out!

@cpuhrsch Did you have a chance to check my last replies regarding using …

ping @fmassa :)
fmassa left a comment
Hi,
Sorry for the delay in reviewing.
I don't think that DistributedSampler should necessarily work with IterableDataset. The indices returned by the DistributedSampler are by definition useless to the general IterableDataset, and trying to use it to fit a particular use-case would add unnecessary constraints to IterableDataset.
My understanding is that this PR is trying to accomplish two things:
- add some meta-information to the `DistributedSampler` to work on `IterableDataset`s
- (potentially) simplify the use of `IterableDataset`s in distributed mode (not present in this PR)
The concept of a sampler is not valid for IterableDataset in general -- we have no guarantees on the order of the examples that will be returned. I believe we should keep this as is in general.
Does this mean that no IterableDataset can know its iteration order ahead of time? No, they can, but this is a special case, and should be handled by the application. I don't think this should live in PyTorch.
But then, how to make it easier for users to write their own IterableDataset that works on distributed?
What are the things we need to keep in mind in DDP in this case?
- `sampler.set_epoch`, so that we can split the dataset between different machines
- in particular, we need to handle the `idx` of the worker ourselves
How can we accomplish both, without having to change the APIs of DistributedSampler or DataLoader?
Here is one example.
    import torch
    import torch.distributed as dist
    from torch.utils.data import IterableDataset

    class MyIterableDataset(IterableDataset):
        def __init__(self):
            # each dataset, when constructed, knows its rank
            # and they are constructed once per (GPU) process
            self.rank = dist.get_rank()
            # how many are we?
            self.world_size = dist.get_world_size()
            # have a counter on how many epochs have passed
            # can be incremented by get_chunk_iterator
            # if the user wants a different shuffle per epoch
            self.epoch = 0

        def __iter__(self):
            worker_info = torch.utils.data.get_worker_info()
            chunk_id = self.rank
            total_chunks = self.world_size
            if worker_info is not None:
                chunk_id = self.rank * worker_info.num_workers + worker_info.id
                total_chunks = self.world_size * worker_info.num_workers
            # now return the chunks in the dataset accordingly
            return iter(self.get_chunk_iterator(chunk_id, total_chunks))

        def get_chunk_iterator(self, chunk_id, total_chunks):
            # user implements this
            pass

Then, all the logic specific to your application stays restricted to your Dataset implementation, via `get_chunk_iterator`, which can handle buffering / shuffling / etc., and can also increment the epoch counter if the user wants.
I might be missing something, but let me know if the above implementation doesn't address your use-cases.
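As a follow-up to the snippet above, here is a minimal usage sketch, assuming `get_chunk_iterator` has been implemented, the process group is already initialized (e.g. by `torch.distributed.launch`), and the default non-persistent DataLoader workers are used. No sampler is passed because the dataset shards itself by rank and worker id.

```python
from torch.utils.data import DataLoader

dataset = MyIterableDataset()

# No sampler: __iter__ above already restricts each worker of each
# (GPU) process to its own disjoint set of chunks.
loader = DataLoader(dataset, batch_size=32, num_workers=4)

for epoch in range(10):
    # Workers are (re)spawned for each iteration over the loader by default,
    # so they pick up the updated epoch counter and get_chunk_iterator
    # can reshuffle accordingly.
    dataset.epoch = epoch
    for batch in loader:
        pass  # training step goes here
```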
@fmassa - samplers are definitely relevant to IterableDatasets - there are various sampling techniques that apply to streams such as https://en.wikipedia.org/wiki/Reservoir_sampling
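For reference, a minimal sketch of the reservoir-sampling idea mentioned above (Algorithm R), not part of this PR: it keeps a uniform sample of fixed size from a stream of unknown length, deciding per element whether to keep it rather than producing indices up front.

```python
import random

def reservoir_sample(stream, k, rng=None):
    """Return k items drawn uniformly at random from an iterable of unknown length."""
    rng = rng or random.Random()
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)
        else:
            # Keep the new item with probability k / (i + 1).
            j = rng.randint(0, i)
            if j < k:
                reservoir[j] = item
    return reservoir

# Example: sample 5 elements from a stream we only see once.
print(reservoir_sample(iter(range(10_000)), k=5))
```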
@cpuhrsch I believe this should be an implementation detail of the specialization of the …
Thanks for the comment, Francisco! I agree that snippet works in some cases, but note the amount of boilerplate that was needed just to enable distributed training. More code would be needed for cases where … As @cpuhrsch mentioned, there are several use cases for sampling within …
ssnl left a comment
Thanks for the PR!
After reading through the thread, I am still a bit confused about how this helps the distributed sampler for IterableDatasets. Is this supposed to work already, or will there be follow-up patches?
The current PyTorch Sampler is different from the "sampling" in methods like reservoir sampling, in the sense that it only specifies the "indices" to sample, rather than whether a sample should be kept. Due to this, in DataLoader, we currently disallow using IterableDataset with a custom sampler or batch_sampler. An example snippet would be really useful!
Finally, this needs a test to get in.
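To make that distinction concrete, here is an illustrative sketch (not part of this PR) contrasting the two notions of "sampling": a PyTorch `Sampler` yields indices into a dataset with a known length, while stream-style sampling decides per element whether to keep it.

```python
from torch.utils.data import Sampler

class EveryOtherSampler(Sampler):
    """Index-based sampling: requires len(data_source) and yields indices."""
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(0, len(self.data_source), 2))

    def __len__(self):
        return (len(self.data_source) + 1) // 2

def keep_every_other(stream):
    """Stream-style sampling: no indices, just a per-element keep/drop decision."""
    for i, item in enumerate(stream):
        if i % 2 == 0:
            yield item
```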
     """
-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, length=None, padding=True):
IterableDataset can still implement `__len__`. I think it makes more sense for this sampler to assume that the dataset has `__len__` than to take an explicit input argument.
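As an illustration of that suggestion, an `IterableDataset` may implement `__len__` when its size happens to be known up front; a minimal sketch with a hypothetical dataset:

```python
from torch.utils.data import IterableDataset

class CountingDataset(IterableDataset):
    """Streams its items lazily but still advertises a known size."""
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self):
        # Known up front, so len(dataset) works even though the data
        # itself is produced on the fly.
        return self.n

ds = CountingDataset(1_000)
assert len(ds) == 1_000
```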
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ssnl if the dataset can implement `__len__` then it makes more sense to extend Dataset as opposed to IterableDataset.
This PR tries to add the ability to do distributed training when the number of samples is unknown. IterableDataset allows this concept, but for distributed training the sampler needs some hint about how many chunks this unknown number of samples is split into.
Closing this for now.
Hi, regarding `def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):`, could you assist me with how I can implement it for my case?
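For reference, helpers with this name (for example in Hugging Face Transformers) typically just wrap `DistributedSampler` with the XLA rank and world size; a sketch under the assumption that `torch_xla` is installed and the dataset is map-style with a known `__len__`:

```python
import torch
import torch_xla.core.xla_model as xm
from torch.utils.data import RandomSampler
from torch.utils.data.distributed import DistributedSampler

def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
    if xm.xrt_world_size() <= 1:
        # Single TPU core: a plain random sampler is enough.
        return RandomSampler(dataset)
    # Several cores: shard the dataset by the XLA ordinal (this core's rank).
    return DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
    )
```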
Current implementation of `DistributedSampler` is ideal for distributed training using map datasets, as they fit in memory and have known size.

However, it doesn't support distributed training using `IterableDataset` datasets, as these classes do not implement `__len__`.

To fix that, a `length` keyword was added to `DistributedSampler`, which has precedence when set.

An extra `padding=True` parameter was also added to give finer control on whether the (returned) index list should be padded by the sampler. This is useful for preventing duplicate reading on `IterableDataset` datasets that do not fit in memory or whose data reading or transformation is expensive.

Finally, a `set_rank` method was added to ease distributed training, allowing `DataLoader`'s worker processes to register themselves on `DistributedSampler` instances through the `worker_init_fn` method. This is useful when worker processes want to change the sampling behavior based not only on the process rank, but also on their worker ID. The `IterableDataset` documentation mentions this issue, but its examples only handle datasets with known size, which is not always the case for `IterableDataset` datasets.

There is no backward-compatibility break with this change.
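Finally, a sketch of how that registration could look, assuming the `set_rank` method and the per-worker sampler copy behave as described in this PR (none of this exists in released PyTorch, and the wiring below is hypothetical):

```python
import torch.distributed as dist

def make_worker_init_fn(sampler, num_workers):
    """Build a worker_init_fn that re-ranks each DataLoader worker's copy
    of the proposed DistributedSampler via the proposed set_rank method."""
    base_rank = dist.get_rank()  # rank of the owning (GPU) process

    def worker_init_fn(worker_id):
        # Runs inside the worker process; `sampler` here is that worker's copy.
        sampler.set_rank(base_rank * num_workers + worker_id)

    return worker_init_fn

# Hypothetical wiring with a ChunkDataset-style dataset:
# loader = DataLoader(chunk_dataset, num_workers=4,
#                     worker_init_fn=make_worker_init_fn(sampler, num_workers=4))
```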