@@ -4,6 +4,7 @@
     _remove_worker_pids, _error_if_any_worker_fails
 from .sampler import SequentialSampler, RandomSampler, BatchSampler
 import signal
+import functools
 import collections
 import re
 import sys
@@ -30,7 +31,7 @@ def __init__(self, exc_info):
         self.exc_msg = "".join(traceback.format_exception(*exc_info))
 
 
-def _worker_loop(dataset, index_queue, data_queue, collate_fn):
+def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
     global _use_shared_memory
     _use_shared_memory = True
 
@@ -41,6 +42,11 @@ def _worker_loop(dataset, index_queue, data_queue, collate_fn):
     _set_worker_signal_handlers()
 
     torch.set_num_threads(1)
+    torch.manual_seed(seed)
+
+    if init_fn is not None:
+        init_fn(worker_id)
+
     while True:
         r = index_queue.get()
         if r is None:
@@ -183,6 +189,7 @@ def __init__(self, loader):
         self.sample_iter = iter(self.batch_sampler)
 
         if self.num_workers > 0:
+            self.worker_init_fn = loader.worker_init_fn
             self.index_queue = multiprocessing.SimpleQueue()
             self.worker_result_queue = multiprocessing.SimpleQueue()
             self.batches_outstanding = 0
@@ -192,15 +199,13 @@ def __init__(self, loader):
             self.rcvd_idx = 0
             self.reorder_dict = {}
 
+            base_seed = torch.LongTensor(1).random_()[0]
             self.workers = [
                 multiprocessing.Process(
                     target=_worker_loop,
-                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn))
-                for _ in range(self.num_workers)]
-
-            for w in self.workers:
-                w.daemon = True  # ensure that the worker exits on process exit
-                w.start()
+                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
+                          base_seed + i, self.worker_init_fn, i))
+                for i in range(self.num_workers)]
 
             if self.pin_memory or self.timeout > 0:
                 self.data_queue = queue.Queue()
@@ -212,6 +217,10 @@ def __init__(self, loader):
             else:
                 self.data_queue = self.worker_result_queue
 
+            for w in self.workers:
+                w.daemon = True  # ensure that the worker exits on process exit
+                w.start()
+
             _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
             _set_SIGCHLD_handler()
             self.worker_pids_set = True
@@ -326,7 +335,7 @@ class DataLoader(object):
             indices at a time. Mutually exclusive with batch_size, shuffle,
             sampler, and drop_last.
         num_workers (int, optional): how many subprocesses to use for data
-            loading. 0 means that the data will be loaded in the main process
+            loading. 0 means that the data will be loaded in the main process.
             (default: 0)
         collate_fn (callable, optional): merges a list of samples to form a mini-batch.
         pin_memory (bool, optional): If ``True``, the data loader will copy tensors
@@ -337,18 +346,31 @@ class DataLoader(object):
         will be smaller. (default: False)
         timeout (numeric, optional): if positive, the timeout value for collecting a batch
             from workers. Should always be non-negative. (default: 0)
+        worker_init_fn (callable, optional): If not None, this will be called on each
+            worker subprocess with the worker id as input, after seeding and before data
+            loading. (default: None)
+
+    .. note:: By default, each worker will have its PyTorch seed set to
+              ``base_seed + worker_id``, where ``base_seed`` is a long generated by the
+              main process using its RNG. You may use ``torch.initial_seed()`` to access
+              this value in :attr:`worker_init_fn`, which can be used to set other seeds
+              (e.g. NumPy) before data loading.
+
+    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` cannot be
+                 an unpicklable object, e.g., a lambda function.
340361 """
341362
342363 def __init__ (self , dataset , batch_size = 1 , shuffle = False , sampler = None , batch_sampler = None ,
343364 num_workers = 0 , collate_fn = default_collate , pin_memory = False , drop_last = False ,
344- timeout = 0 ):
365+ timeout = 0 , worker_init_fn = None ):
345366 self .dataset = dataset
346367 self .batch_size = batch_size
347368 self .num_workers = num_workers
348369 self .collate_fn = collate_fn
349370 self .pin_memory = pin_memory
350371 self .drop_last = drop_last
351372 self .timeout = timeout
373+ self .worker_init_fn = worker_init_fn
352374
353375 if timeout < 0 :
354376 raise ValueError ('timeout option should be non-negative' )
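
For illustration, a minimal sketch of how the new ``worker_init_fn`` hook might be used to give each worker its own NumPy seed; the ``NoisyDataset`` and ``seed_numpy_worker`` names are hypothetical and not part of this commit:

    import numpy as np
    import torch
    from torch.utils.data import Dataset, DataLoader

    class NoisyDataset(Dataset):
        # Hypothetical dataset whose augmentation draws from NumPy's RNG.
        def __len__(self):
            return 100

        def __getitem__(self, idx):
            # Without per-worker seeding, forked workers can inherit the same
            # NumPy state and draw identical noise.
            return idx + np.random.rand()

    def seed_numpy_worker(worker_id):
        # torch.initial_seed() returns this worker's seed (base_seed + worker_id);
        # NumPy seeds must fit in 32 bits, hence the modulo.
        np.random.seed(torch.initial_seed() % 2 ** 32)

    loader = DataLoader(NoisyDataset(), batch_size=10, num_workers=2,
                        worker_init_fn=seed_numpy_worker)

Because ``seed_numpy_worker`` is a module-level function rather than a lambda, it remains picklable and so also works under the ``spawn`` start method, as the warning above requires.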