
Reinitialize the random generator in worker processes of DataLoader #3880

@netheril96

Description

torch.utils.data.DataLoader can use multiprocessing to load and preprocess data, which is commonly used to overlap GPU computation and data loading. It has a minor flaw, though. On Unix, the worker processes are created with fork by default, so the parent's global random state is copied into every worker. If each worker does something pseudo-random (common in data augmentation), they all do exactly the same thing. For example, random crop is commonly used in image classification tasks, and in that case all DataLoader workers will crop images in the same way, reducing the effectiveness of the augmentation.
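
To see the problem in isolation, here is a minimal standalone sketch (plain multiprocessing, not DataLoader itself) showing that forked children inherit identical copies of NumPy's global generator state:

```python
import multiprocessing as mp

import numpy as np

def augment(q):
    # Each forked child starts from a byte-for-byte copy of the parent's
    # NumPy RNG state, so every child draws the same "random" value.
    q.put(np.random.randint(0, 100))

if __name__ == "__main__":
    np.random.seed(0)
    ctx = mp.get_context("fork")  # the default start method on Unix
    q = ctx.Queue()
    procs = [ctx.Process(target=augment, args=(q,)) for _ in range(4)]
    for p in procs:
        p.start()
    print([q.get() for _ in range(4)])  # the same value, four times
    for p in procs:
        p.join()
```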

I propose that a new option, worker_init_fn, be added to the DataLoader constructor, with the function called with the process ID right at the beginning of each worker loop. By default, the function ignores the ID and reseeds both the PyTorch and NumPy default random generators. Users can override it to do anything else that each worker needs to do differently.
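
As a sketch of how the option could be used, assuming the callback is passed a small integer identifying the worker; the ToyAugmented dataset and the seed-derivation scheme below are illustrative, not part of the proposal:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class ToyAugmented(Dataset):
    """Toy dataset whose __getitem__ draws from NumPy's global RNG,
    standing in for a random-crop style augmentation."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return np.random.randint(0, 100)

def worker_init_fn(worker_id):
    # Give every worker a distinct NumPy seed; without this, forked
    # workers share identical copies of the parent's RNG state.
    np.random.seed((torch.initial_seed() + worker_id) % 2**32)

if __name__ == "__main__":
    loader = DataLoader(ToyAugmented(), num_workers=4,
                        worker_init_fn=worker_init_fn)
    for batch in loader:
        print(batch)
```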
