
Commit 5cc26c0

ssnl authored and soumith committed
Add default PyTorch seeding and worker_init_fn to DataLoader (#4018)
* Add default PyTorch seeding and worker_init_fn to DataLoader
* generate seed using current RNG each time
* worker_seed <- main_proc_RNG_generated_seed + worker_id
1 parent 30e6898 commit 5cc26c0
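
The last bullet summarizes the core mechanism: the main process draws one seed from its current RNG state, and each worker seeds its own PyTorch RNG with that value plus its worker id. A standalone sketch of that derivation, outside of DataLoader, is below; the loop merely stands in for the worker processes, and apart from the ``base_seed + worker_id`` rule taken from the diff, the details are illustrative.

import torch

# The main process draws one base seed from its current RNG state, mirroring
# the new `base_seed = torch.LongTensor(1).random_()[0]` in dataloader.py.
base_seed = int(torch.LongTensor(1).random_()[0])

num_workers = 4
for worker_id in range(num_workers):
    # Each worker process then calls torch.manual_seed(base_seed + worker_id)
    # before worker_init_fn (if any) runs.
    worker_seed = base_seed + worker_id
    torch.manual_seed(worker_seed)
    # Inside a worker, torch.initial_seed() therefore reports this value.
    assert torch.initial_seed() == worker_seed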

2 files changed: 87 additions, 9 deletions

test/test_dataloader.py

Lines changed: 56 additions & 0 deletions
@@ -112,6 +112,41 @@ def __len__(self):
         return self.size
 
 
+class SeedDataset(Dataset):
+
+    def __init__(self, size):
+        self.size = size
+
+    def __getitem__(self, idx):
+        return torch.initial_seed()
+
+    def __len__(self):
+        return self.size
+
+
+# Inspired by https://stackoverflow.com/a/26703365
+# This will ensure that each worker processes at least one sample
+class SynchronizedSeedDataset(Dataset):
+
+    def __init__(self, size, num_workers):
+        assert size >= num_workers
+        self.count = multiprocessing.Value('i', 0)
+        self.barrier = multiprocessing.Semaphore(0)
+        self.num_workers = num_workers
+        self.size = size
+
+    def __getitem__(self, idx):
+        self.count.value += 1
+        if self.count.value == self.num_workers:
+            self.barrier.release()
+        self.barrier.acquire()
+        self.barrier.release()
+        return torch.initial_seed()
+
+    def __len__(self):
+        return self.size
+
+
 class TestDataLoader(TestCase):
 
     def setUp(self):

@@ -222,6 +257,27 @@ def _test_timeout():
         finally:
             p.terminate()
 
+    def test_worker_seed(self):
+        num_workers = 6
+        dataset = SynchronizedSeedDataset(num_workers, num_workers)
+        dataloader = DataLoader(dataset, batch_size=1, num_workers=num_workers)
+        seeds = set()
+        for batch in dataloader:
+            seeds.add(batch[0])
+        self.assertEqual(len(seeds), num_workers)
+
+    def test_worker_init_fn(self):
+        # test a custom worker init function
+        def init_fn(worker_id):
+            torch.manual_seed(12345)
+
+        dataset = SeedDataset(4)
+        dataloader = DataLoader(dataset, batch_size=2, num_workers=2,
+                                worker_init_fn=init_fn)
+        for batch in dataloader:
+            self.assertEqual(12345, batch[0])
+            self.assertEqual(12345, batch[1])
+
     def test_shuffle(self):
         self._test_shuffle(DataLoader(self.dataset, shuffle=True))
 

torch/utils/data/dataloader.py

Lines changed: 31 additions & 9 deletions
@@ -4,6 +4,7 @@
     _remove_worker_pids, _error_if_any_worker_fails
 from .sampler import SequentialSampler, RandomSampler, BatchSampler
 import signal
+import functools
 import collections
 import re
 import sys

@@ -30,7 +31,7 @@ def __init__(self, exc_info):
         self.exc_msg = "".join(traceback.format_exception(*exc_info))
 
 
-def _worker_loop(dataset, index_queue, data_queue, collate_fn):
+def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
     global _use_shared_memory
     _use_shared_memory = True
 
@@ -41,6 +42,11 @@ def _worker_loop(dataset, index_queue, data_queue, collate_fn):
     _set_worker_signal_handlers()
 
     torch.set_num_threads(1)
+    torch.manual_seed(seed)
+
+    if init_fn is not None:
+        init_fn(worker_id)
+
     while True:
         r = index_queue.get()
         if r is None:

@@ -183,6 +189,7 @@ def __init__(self, loader):
         self.sample_iter = iter(self.batch_sampler)
 
         if self.num_workers > 0:
+            self.worker_init_fn = loader.worker_init_fn
             self.index_queue = multiprocessing.SimpleQueue()
             self.worker_result_queue = multiprocessing.SimpleQueue()
             self.batches_outstanding = 0

@@ -192,15 +199,13 @@ def __init__(self, loader):
             self.rcvd_idx = 0
             self.reorder_dict = {}
 
+            base_seed = torch.LongTensor(1).random_()[0]
             self.workers = [
                 multiprocessing.Process(
                     target=_worker_loop,
-                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn))
-                for _ in range(self.num_workers)]
-
-            for w in self.workers:
-                w.daemon = True  # ensure that the worker exits on process exit
-                w.start()
+                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
+                          base_seed + i, self.worker_init_fn, i))
+                for i in range(self.num_workers)]
 
             if self.pin_memory or self.timeout > 0:
                 self.data_queue = queue.Queue()

@@ -212,6 +217,10 @@ def __init__(self, loader):
             else:
                 self.data_queue = self.worker_result_queue
 
+            for w in self.workers:
+                w.daemon = True  # ensure that the worker exits on process exit
+                w.start()
+
             _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
             _set_SIGCHLD_handler()
             self.worker_pids_set = True

@@ -326,7 +335,7 @@ class DataLoader(object):
             indices at a time. Mutually exclusive with batch_size, shuffle,
             sampler, and drop_last.
         num_workers (int, optional): how many subprocesses to use for data
-            loading. 0 means that the data will be loaded in the main process
+            loading. 0 means that the data will be loaded in the main process.
             (default: 0)
         collate_fn (callable, optional): merges a list of samples to form a mini-batch.
         pin_memory (bool, optional): If ``True``, the data loader will copy tensors

@@ -337,18 +346,31 @@ class DataLoader(object):
             will be smaller. (default: False)
         timeout (numeric, optional): if positive, the timeout value for collecting a batch
             from workers. Should always be non-negative. (default: 0)
+        worker_init_fn (callable, optional): If not None, this will be called on each
+            worker subprocess with the worker id as input, after seeding and before data
+            loading. (default: None)
+
+    .. note:: By default, each worker will have its PyTorch seed set to
+              ``base_seed + worker_id``, where ``base_seed`` is a long generated
+              by the main process using its RNG. You may use ``torch.initial_seed()`` to access
+              this value in :attr:`worker_init_fn`, which can be used to set other seeds
+              (e.g. NumPy) before data loading.
+
+    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
+                 unpicklable object, e.g., a lambda function.
     """
 
     def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                  num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
-                 timeout=0):
+                 timeout=0, worker_init_fn=None):
         self.dataset = dataset
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.collate_fn = collate_fn
         self.pin_memory = pin_memory
         self.drop_last = drop_last
         self.timeout = timeout
+        self.worker_init_fn = worker_init_fn
 
         if timeout < 0:
             raise ValueError('timeout option should be non-negative')
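
The docstring note above suggests using ``torch.initial_seed()`` inside :attr:`worker_init_fn` to seed libraries whose RNGs PyTorch does not manage. A minimal sketch of that pattern follows; the dataset, the function names, and the modulo reduction to NumPy's 32-bit seed range are assumptions of this sketch, not part of the commit.

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class NumpyRandomDataset(Dataset):
    # Illustrative dataset whose samples come from NumPy's RNG, which the
    # new per-worker PyTorch seeding does not touch on its own.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return np.random.rand()


def seed_numpy(worker_id):
    # In each worker, torch.initial_seed() already equals base_seed + worker_id,
    # so reuse it to give NumPy a distinct per-worker seed as well.
    np.random.seed(torch.initial_seed() % (2 ** 32))


if __name__ == "__main__":
    loader = DataLoader(NumpyRandomDataset(), batch_size=2, num_workers=2,
                        worker_init_fn=seed_numpy)
    for batch in loader:
        print(batch)

Without such a hook, workers forked from the same parent start with identical NumPy RNG state, so np.random-based sampling or augmentation would repeat across workers; seed_numpy is an ordinary module-level function, so it also satisfies the picklability requirement noted for the ``spawn`` start method.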
