11import math
22import sys
3+ import os
34import ctypes
45import torch
56import time
@@ -153,15 +154,16 @@ class SynchronizedSeedDataset(Dataset):
153154
154155 def __init__ (self , size , num_workers ):
155156 assert size >= num_workers
156- self .count = multiprocessing .Value ('i' , 0 )
157+ self .count = multiprocessing .Value ('i' , 0 , lock = True )
157158 self .barrier = multiprocessing .Semaphore (0 )
158159 self .num_workers = num_workers
159160 self .size = size
160161
161162 def __getitem__ (self , idx ):
162- self .count .value += 1
163- if self .count .value == self .num_workers :
164- self .barrier .release ()
163+ with self .count .get_lock ():
164+ self .count .value += 1
165+ if self .count .value == self .num_workers :
166+ self .barrier .release ()
165167 self .barrier .acquire ()
166168 self .barrier .release ()
167169 return torch .initial_seed ()
@@ -249,6 +251,7 @@ def test_multiple_dataloaders(self):
249251 @unittest .skipIf (IS_WINDOWS , "TODO: need to fix this test case for Windows" )
250252 def test_segfault (self ):
251253 def _test_segfault ():
254+ os .close (sys .stderr .fileno ())
252255 sys .stderr .close ()
253256 dataset = SegfaultDataset (10 )
254257 dataloader = DataLoader (dataset , batch_size = 2 , num_workers = 2 )
@@ -266,6 +269,7 @@ def _test_segfault():
266269 @unittest .skipIf (IS_WINDOWS , "TODO: need to fix this test case for Windows" )
267270 def test_timeout (self ):
268271 def _test_timeout ():
272+ os .close (sys .stderr .fileno ())
269273 sys .stderr .close ()
270274 dataset = SleepDataset (10 , 10 )
271275 dataloader = DataLoader (dataset , batch_size = 2 , num_workers = 2 , timeout = 1 )
0 commit comments