 import math
 import sys
+import errno
 import os
 import ctypes
+import signal
 import torch
 import time
 import traceback
 import unittest
+import socket
 from torch import multiprocessing
 from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
 from torch.utils.data.dataset import random_split
-from torch.utils.data.dataloader import default_collate
+from torch.utils.data.dataloader import default_collate, ExceptionWrapper
 from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS
 from common_nn import TEST_CUDA
 
+
+if sys.version_info[0] == 2:
+    import Queue as queue
+else:
+    import queue
+
+
 JOIN_TIMEOUT = 14.0 if IS_WINDOWS else 1.5
 
 
@@ -103,6 +113,46 @@ def test_add_dataset(self):
         self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())
 
 
+# Stores the first encountered exception in .exception.
+# Inspired by https://stackoverflow.com/a/33599967
+class ErrorTrackingProcess(multiprocessing.Process):
+
+    def __init__(self, *args, **kwargs):
+        super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
+        self._pconn, self._cconn = multiprocessing.Pipe()
+        self._exception = None
+
+    def run(self):
+        # Disable stderr printing at the OS level so that workers do not
+        # print to stderr.
+        # Can't use sys.stderr.close(), otherwise Python `raise` will error with
+        # ValueError: I/O operation on closed file.
+        os.close(sys.stderr.fileno())
+        try:
+            super(ErrorTrackingProcess, self).run()
+            self._cconn.send(None)
+        except Exception:
+            self._cconn.send(ExceptionWrapper(sys.exc_info()))
+            raise
+
+    @property
+    def exception(self):
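+        # The worker process sends either None (clean exit) or an
+        # ExceptionWrapper over the pipe; rebuild a comparable exception
+        # object here from the wrapped type and message.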
+        if self._pconn.poll():
+            self._exception = self._pconn.recv()
+        if self._exception is None:
+            return None
+        else:
+            return self._exception.exc_type(self._exception.exc_msg)
+
+    # ESRCH means that os.kill can't find a live process with that pid
+    def send_signal(self, signum, ignore_ESRCH=False):
+        try:
+            os.kill(self.pid, signum)
+        except OSError as e:
+            if not ignore_ESRCH or e.errno != errno.ESRCH:
+                raise
+
+
 class ErrorDataset(Dataset):
 
     def __init__(self, size):
@@ -175,21 +225,65 @@ def __len__(self):
 
 
 def _test_timeout():
-    os.close(sys.stderr.fileno())
-    sys.stderr.close()
     dataset = SleepDataset(10, 10)
     dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
     _ = next(iter(dataloader))
 
 
 def _test_segfault():
-    os.close(sys.stderr.fileno())
-    sys.stderr.close()
     dataset = SegfaultDataset(10)
     dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
     _ = next(iter(dataloader))
 
 
+def _test_interrupt_retry(timeout=0):
+    dataset = TensorDataset(torch.randn(1, 1), torch.randn(1, 1))
+    dataloader = DataLoader(dataset, batch_size=1, num_workers=1, timeout=timeout)
+    dataloaderiter = iter(dataloader)
+
+    # Install a no-op SIGUSR1 handler so the signal interrupts blocking
+    # syscalls instead of killing the process.
+    def handler(signum, frame):
+        pass
+    signal.signal(signal.SIGUSR1, handler)
+
+    # Replace the data queue's get method with a wrapper that reliably makes an
+    # interruptible blocking recv syscall, simulating an interruption during
+    # recv in queue.get.
+    # The socket.recv call in the replacement function below is quite dangerous
+    # because it blocks everything on the Python side, including the cleanup in
+    # dataloader.__del__ when this process exits. To prevent orphaned worker
+    # children, we manually terminate the worker processes here. Conveniently,
+    # the worker has a SIGTERM handler installed, so SIGTERM from the loader
+    # process won't cause a loader error.
+    data = dataloaderiter.data_queue.get()  # ensure that worker handlers are installed
+    for w in dataloaderiter.workers:
+        w.terminate()
+
+    def interruptable_get(*args, **kwargs):
+        if dataloaderiter.shutdown:
+            return data
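+        # The socket timeout below mirrors queue.get's timeout semantics: a
+        # timed-out recv is converted to queue.Empty so the loader's timeout
+        # handling is exercised exactly as with the real queue.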
+        # get and config timeout if the argument is present
+        if timeout >= 0:
+            if 'timeout' in kwargs:
+                timeout_val = kwargs['timeout']
+            elif len(args) > 1:
+                timeout_val = args[1]  # first arg is `block`
+            else:
+                timeout_val = None
+            socket.setdefaulttimeout(timeout_val)
+        s = socket.socket(socket.AF_INET, type=socket.SOCK_DGRAM)
+        s.bind(("127.0.0.1", 0))
+        try:
+            return s.recv(1024)
+        except socket.timeout:
+            raise queue.Empty
+        finally:
+            s.close()
+
+    dataloaderiter.data_queue.get = interruptable_get
+    _ = next(dataloaderiter)
+
+
 # test custom init function
 def init_fn(worker_id):
     torch.manual_seed(12345)
@@ -272,22 +366,30 @@ def test_multiple_dataloaders(self):
             next(loader2_it)
 
     def test_segfault(self):
-        p = multiprocessing.Process(target=_test_segfault)
+        p = ErrorTrackingProcess(target=_test_segfault)
         p.start()
-        p.join(JOIN_TIMEOUT)
+        p.join(3.0 + JOIN_TIMEOUT)
         try:
             self.assertFalse(p.is_alive())
             self.assertNotEqual(p.exitcode, 0)
+            if IS_WINDOWS:
+                self.assertIsInstance(p.exception, OSError)
+                self.assertRegex(str(p.exception), r'access violation reading ')
+            else:
+                self.assertIsInstance(p.exception, RuntimeError)
+                self.assertRegex(str(p.exception), r'DataLoader worker \(worker_pid \d+\) is killed by signal: ')
         finally:
             p.terminate()
 
     def test_timeout(self):
-        p = multiprocessing.Process(target=_test_timeout)
+        p = ErrorTrackingProcess(target=_test_timeout)
         p.start()
         p.join(3.0 + JOIN_TIMEOUT)
         try:
             self.assertFalse(p.is_alive())
             self.assertNotEqual(p.exitcode, 0)
+            self.assertIsInstance(p.exception, RuntimeError)
+            self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
         finally:
             p.terminate()
 
@@ -308,6 +410,52 @@ def test_worker_init_fn(self):
         self.assertEqual(12345, batch[0])
         self.assertEqual(12345, batch[1])
 
+    @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
+    def test_interrupt_retry(self):
+        # time.sleep doesn't seem to work in the main process when running
+        # unittest with all tests together. Use Process.join as a sleep function.
+        # Case 1: interrupt, then check that the loader is still alive
+        p = ErrorTrackingProcess(target=_test_interrupt_retry)
+        p.start()
+        p.join(JOIN_TIMEOUT)  # give it some time to reach get
+        try:
+            self.assertTrue(p.is_alive())
+            for _ in range(3):
+                p.send_signal(signal.SIGUSR1)
+                p.join(0.5)
+                self.assertTrue(p.is_alive())
+            p.join(JOIN_TIMEOUT)
+        except OSError:
+            self.fail("DataLoader shouldn't fail due to interrupted syscall")
+        try:
+            self.assertTrue(p.is_alive())
+            self.assertIsNone(p.exception)
+        finally:
+            p.terminate()
+        # Case 2: timeout
+        timeout = 2
+        p = ErrorTrackingProcess(target=lambda: _test_interrupt_retry(timeout))
+        p.start()
+        p.join(JOIN_TIMEOUT)  # give some time to reach get
+        try:
+            self.assertTrue(p.is_alive())
+            for _ in range(5):
+                p.send_signal(signal.SIGUSR1)
+                p.join(0.5)
+            p.join(2.0 + JOIN_TIMEOUT)
+        except OSError as e:
+            if e.errno != errno.ESRCH:
+                # ESRCH means os.kill targeted an already-dead process, which
+                # can happen here if the timeout has already fired
+                self.fail("DataLoader shouldn't fail due to interrupted syscall")
+        try:
+            self.assertFalse(p.is_alive())
+            self.assertNotEqual(p.exitcode, 0)
+            self.assertIsInstance(p.exception, RuntimeError)
+            self.assertRegex(str(p.exception), r'DataLoader timed out after {} seconds'.format(timeout))
+        finally:
+            p.terminate()
+
     def test_shuffle(self):
         self._test_shuffle(DataLoader(self.dataset, shuffle=True))
 