import math
import sys
+import errno
+import os
import ctypes
+import signal
import torch
import time
import traceback
import unittest
from torch import multiprocessing
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data.dataset import random_split
-from torch.utils.data.dataloader import default_collate
-from common import TestCase, run_tests, TEST_NUMPY
+from torch.utils.data.dataloader import default_collate, ExceptionWrapper
+from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS
from common_nn import TEST_CUDA


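+# Timeout (in seconds) used when joining the test processes below; Windows gets
+# extra headroom because worker processes are slower to start and exit there.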
+JOIN_TIMEOUT = 17.0 if IS_WINDOWS else 4.5
+
+
class TestDatasetRandomSplit(TestCase):
    def test_lengths_must_equal_datset_size(self):
        with self.assertRaises(ValueError):
@@ -100,6 +106,46 @@ def test_add_dataset(self):
        self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())


+# Stores the first encountered exception in .exception.
+# Inspired by https://stackoverflow.com/a/33599967
+class ErrorTrackingProcess(multiprocessing.Process):
+
+    def __init__(self, *args, **kwargs):
+        super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
+        self._pconn, self._cconn = multiprocessing.Pipe()
+        self._exception = None
+
+    def run(self):
+        # Disable stderr printing at the OS level so that the workers don't
+        # write to stderr.
+        # Can't use sys.stderr.close(), otherwise Python `raise` will error with
+        # ValueError: I/O operation on closed file.
+        os.close(sys.stderr.fileno())
+        try:
+            super(ErrorTrackingProcess, self).run()
+            self._cconn.send(None)
+        except Exception:
+            self._cconn.send(ExceptionWrapper(sys.exc_info()))
+            raise
+
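+    # None if the child exited cleanly; otherwise the child's exception is
+    # rebuilt from the ExceptionWrapper received over the pipe.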
+    @property
+    def exception(self):
+        if self._pconn.poll():
+            self._exception = self._pconn.recv()
+        if self._exception is None:
+            return None
+        else:
+            return self._exception.exc_type(self._exception.exc_msg)
+
+    # ESRCH means that os.kill could not find a live process with that pid.
+    def send_signal(self, signum, ignore_ESRCH=False):
+        try:
+            os.kill(self.pid, signum)
+        except OSError as e:
+            if not ignore_ESRCH or e.errno != errno.ESRCH:
+                raise
+
+
class ErrorDataset(Dataset):

    def __init__(self, size):
@@ -170,6 +216,23 @@ def __len__(self):
        return self.size


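+# The helpers below live at module level rather than as closures inside the test
+# methods so they stay picklable (closures can't be pickled), which the spawn
+# start method used on Windows requires of process targets.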
+def _test_timeout():
+    dataset = SleepDataset(10, 10)
+    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
+    _ = next(iter(dataloader))
+
+
+def _test_segfault():
+    dataset = SegfaultDataset(10)
+    dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
+    _ = next(iter(dataloader))
+
+
+# test custom init function
+def init_fn(worker_id):
+    torch.manual_seed(12345)
+
+
class TestDataLoader(TestCase):

    def setUp(self):
@@ -248,34 +311,30 @@ def test_multiple_dataloaders(self):

    @unittest.skipIf(True, "flaky test")
    def test_segfault(self):
-        def _test_segfault():
-            sys.stderr.close()
-            dataset = SegfaultDataset(10)
-            dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
-            _ = next(iter(dataloader))
-
-        p = multiprocessing.Process(target=_test_segfault)
+        p = ErrorTrackingProcess(target=_test_segfault)
        p.start()
-        p.join(1.0)
+        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
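+            # A crashed worker surfaces differently per platform: an access-violation
+            # OSError on Windows, and a "killed by signal" RuntimeError raised by the
+            # DataLoader elsewhere.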
+            if IS_WINDOWS:
+                self.assertIsInstance(p.exception, OSError)
+                self.assertRegex(str(p.exception), r'access violation reading ')
+            else:
+                self.assertIsInstance(p.exception, RuntimeError)
+                self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
        finally:
            p.terminate()

    def test_timeout(self):
-        def _test_timeout():
-            sys.stderr.close()
-            dataset = SleepDataset(10, 10)
-            dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
-            _ = next(iter(dataloader))
-
-        p = multiprocessing.Process(target=_test_timeout)
+        p = ErrorTrackingProcess(target=_test_timeout)
        p.start()
-        p.join(3.0)
+        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
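+            # The workers sleep well past the 1-second DataLoader timeout, so the
+            # parent should have raised the timeout RuntimeError checked below.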
+            self.assertIsInstance(p.exception, RuntimeError)
+            self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
        finally:
            p.terminate()

@@ -289,10 +348,6 @@ def test_worker_seed(self):
        self.assertEqual(len(seeds), num_workers)

    def test_worker_init_fn(self):
-        # test custom init function
-        def init_fn(worker_id):
-            torch.manual_seed(12345)
-
        dataset = SeedDataset(4)
        dataloader = DataLoader(dataset, batch_size=2, num_workers=2,
                                worker_init_fn=init_fn)
@@ -381,10 +436,10 @@ def test_partial_workers(self):
                break
        del loader
        for w in workers:
-            w.join(1.0)  # timeout of one second
+            w.join(JOIN_TIMEOUT)
            self.assertFalse(w.is_alive(), 'subprocess not terminated')
            self.assertEqual(w.exitcode, 0)
-        worker_manager_thread.join(1.0)
+        worker_manager_thread.join(JOIN_TIMEOUT)
        self.assertFalse(worker_manager_thread.is_alive())

    def test_len(self):