Skip to content

Commit 123156a

Browse files
committed
EINTR and kill by loader fix
1 parent b42f163 commit 123156a

File tree

4 files changed

+264
-38
lines changed

4 files changed

+264
-38
lines changed

test/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,8 @@ def accept_output(update_type):
378378
self.assertEqual(s, expected)
379379

380380
if sys.version_info < (3, 2):
381+
# assertRegexpMatches renamed assertRegex in 3.2
382+
assertRegex = unittest.TestCase.assertRegexpMatches
381383
# assertRaisesRegexp renamed assertRaisesRegex in 3.2
382384
assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
383385

test/test_dataloader.py

Lines changed: 156 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
11
import math
22
import sys
3+
import errno
34
import os
45
import ctypes
6+
import signal
57
import torch
68
import time
79
import traceback
810
import unittest
11+
import socket
912
from torch import multiprocessing
1013
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
1114
from torch.utils.data.dataset import random_split
12-
from torch.utils.data.dataloader import default_collate
15+
from torch.utils.data.dataloader import default_collate, ExceptionWrapper
1316
from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS
1417
from common_nn import TEST_CUDA
1518

19+
20+
if sys.version_info[0] == 2:
21+
import Queue as queue
22+
else:
23+
import queue
24+
25+
1626
JOIN_TIMEOUT = 14.0 if IS_WINDOWS else 1.5
1727

1828

@@ -103,6 +113,46 @@ def test_add_dataset(self):
103113
self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())
104114

105115

116+
# Stores the first encountered exception in .exception.
117+
# Inspired by https://stackoverflow.com/a/33599967
118+
class ErrorTrackingProcess(multiprocessing.Process):
119+
120+
def __init__(self, *args, **kwargs):
121+
super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
122+
self._pconn, self._cconn = multiprocessing.Pipe()
123+
self._exception = None
124+
125+
def run(self):
126+
    # Disable stderr printing at the OS level, and make workers not print
127+
# to stderr.
128+
# Can't use sys.stderr.close, otherwise Python `raise` will error with
129+
# ValueError: I/O operation on closed file.
130+
os.close(sys.stderr.fileno())
131+
try:
132+
super(ErrorTrackingProcess, self).run()
133+
self._cconn.send(None)
134+
except Exception as e:
135+
self._cconn.send(ExceptionWrapper(sys.exc_info()))
136+
raise
137+
138+
@property
139+
def exception(self):
140+
if self._pconn.poll():
141+
self._exception = self._pconn.recv()
142+
if self._exception is None:
143+
return None
144+
else:
145+
return self._exception.exc_type(self._exception.exc_msg)
146+
147+
    # ESRCH means that os.kill can't find an alive proc
148+
def send_signal(self, signum, ignore_ESRCH=False):
149+
try:
150+
os.kill(self.pid, signum)
151+
except OSError as e:
152+
if not ignore_ESRCH or e.errno != errno.ESRCH:
153+
raise
154+
155+
106156
class ErrorDataset(Dataset):
107157

108158
def __init__(self, size):
@@ -175,21 +225,65 @@ def __len__(self):
175225

176226

177227
def _test_timeout():
178-
os.close(sys.stderr.fileno())
179-
sys.stderr.close()
180228
dataset = SleepDataset(10, 10)
181229
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
182230
_ = next(iter(dataloader))
183231

184232

185233
def _test_segfault():
186-
os.close(sys.stderr.fileno())
187-
sys.stderr.close()
188234
dataset = SegfaultDataset(10)
189235
dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
190236
_ = next(iter(dataloader))
191237

192238

239+
def _test_interrupt_retry(timeout=0):
240+
dataset = TensorDataset(torch.randn(1, 1), torch.randn(1, 1))
241+
dataloader = DataLoader(dataset, batch_size=1, num_workers=1, timeout=timeout)
242+
dataloaderiter = iter(dataloader)
243+
244+
# make SIGUSR1 interrupt
245+
def handler(signum, frame):
246+
pass
247+
signal.signal(signal.SIGUSR1, handler)
248+
249+
# Replace iterator getter with a wrapper that reliably calls an
250+
# interruptable blocking recv syscall to simulate interruption during recv
251+
# in queue.get.
252+
# The used socket.recv call below in the replacing function is quite
253+
# dangerous because it blocks everything on Python side, including the
254+
# cleaning up in dataloder.__del__ when this process exits. To prevent
255+
# orphan worker child, we manually terminate worker process here.
256+
# Conveniently, the worker has SIGTERM handler installed so SIGTERM from
257+
# loader process won't cause loader error.
258+
data = dataloaderiter.data_queue.get() # ensure that worker handlers are installed
259+
for w in dataloaderiter.workers:
260+
w.terminate()
261+
262+
def interruptable_get(*args, **kwargs):
263+
if dataloaderiter.shutdown:
264+
return data
265+
# get and config timeout if the argument is present
266+
if timeout >= 0:
267+
if 'timeout' in kwargs:
268+
timeout_val = kwargs['timeout']
269+
elif len(args) > 1:
270+
timeout_val = args[1] # first arg is `block`
271+
else:
272+
timeout_val = None
273+
socket.setdefaulttimeout(timeout_val)
274+
s = socket.socket(socket.AF_INET, type=socket.SOCK_DGRAM)
275+
s.bind(("127.0.0.1", 0))
276+
try:
277+
return s.recv(1024)
278+
except socket.timeout:
279+
raise queue.Empty
280+
finally:
281+
s.close()
282+
283+
dataloaderiter.data_queue.get = interruptable_get
284+
_ = next(dataloaderiter)
285+
286+
193287
# test custom init function
194288
def init_fn(worker_id):
195289
torch.manual_seed(12345)
@@ -272,22 +366,30 @@ def test_multiple_dataloaders(self):
272366
next(loader2_it)
273367

274368
def test_segfault(self):
275-
p = multiprocessing.Process(target=_test_segfault)
369+
p = ErrorTrackingProcess(target=_test_segfault)
276370
p.start()
277-
p.join(JOIN_TIMEOUT)
371+
p.join(3.0 + JOIN_TIMEOUT)
278372
try:
279373
self.assertFalse(p.is_alive())
280374
self.assertNotEqual(p.exitcode, 0)
375+
if IS_WINDOWS:
376+
self.assertIsInstance(p.exception, OSError)
377+
self.assertRegex(str(p.exception), r'access violation reading ')
378+
else:
379+
self.assertIsInstance(p.exception, RuntimeError)
380+
self.assertRegex(str(p.exception), r'DataLoader worker \(worker_pid \d+\) is killed by signal: ')
281381
finally:
282382
p.terminate()
283383

284384
def test_timeout(self):
285-
p = multiprocessing.Process(target=_test_timeout)
385+
p = ErrorTrackingProcess(target=_test_timeout)
286386
p.start()
287387
p.join(3.0 + JOIN_TIMEOUT)
288388
try:
289389
self.assertFalse(p.is_alive())
290390
self.assertNotEqual(p.exitcode, 0)
391+
self.assertIsInstance(p.exception, RuntimeError)
392+
self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
291393
finally:
292394
p.terminate()
293395

@@ -308,6 +410,52 @@ def test_worker_init_fn(self):
308410
self.assertEqual(12345, batch[0])
309411
self.assertEqual(12345, batch[1])
310412

413+
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
414+
def test_interrupt_retry(self):
415+
# time.sleep doesn't seem to work on main process when running unittest
416+
# for all tests together. Use Process.join as a sleep function.
417+
        # Case 1: interrupt, check that it is still alive
418+
p = ErrorTrackingProcess(target=_test_interrupt_retry)
419+
p.start()
420+
p.join(JOIN_TIMEOUT) # give it some time to reach get
421+
try:
422+
self.assertTrue(p.is_alive())
423+
for i in range(3):
424+
p.send_signal(signal.SIGUSR1)
425+
p.join(0.5)
426+
self.assertTrue(p.is_alive())
427+
p.join(JOIN_TIMEOUT)
428+
except OSError as e:
429+
self.fail("DataLoader shouldn't fail due to interrupted syscall")
430+
try:
431+
self.assertTrue(p.is_alive())
432+
self.assertIsNone(p.exception)
433+
finally:
434+
p.terminate()
435+
# Case 2: timeout
436+
timeout = 2
437+
p = ErrorTrackingProcess(target=lambda: _test_interrupt_retry(timeout))
438+
p.start()
439+
p.join(JOIN_TIMEOUT) # give some time to reach get
440+
try:
441+
self.assertTrue(p.is_alive())
442+
for _ in range(5):
443+
p.send_signal(signal.SIGUSR1)
444+
p.join(0.5)
445+
p.join(2.0 + JOIN_TIMEOUT)
446+
except OSError as e:
447+
if e.errno != errno.ESRCH:
448+
                # ESRCH means that os.kill found a dead proc, which can happen if
449+
# timeout triggers
450+
self.fail("DataLoader shouldn't fail due to interrupted syscall")
451+
try:
452+
self.assertFalse(p.is_alive())
453+
self.assertNotEqual(p.exitcode, 0)
454+
self.assertIsInstance(p.exception, RuntimeError)
455+
self.assertRegex(str(p.exception), r'DataLoader timed out after {} seconds'.format(timeout))
456+
finally:
457+
p.terminate()
458+
311459
def test_shuffle(self):
312460
self._test_shuffle(DataLoader(self.dataset, shuffle=True))
313461

torch/csrc/DataLoader.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) \
4343

4444
// signal(2) is really not portable. So use sigaction.
4545
// http://man7.org/linux/man-pages/man2/signal.2.html
46-
static void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
46+
static inline void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
4747
{
4848
struct sigaction sa;
4949
sa.sa_sigaction = handler;
@@ -59,10 +59,32 @@ SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered
5959
"This might be caused by insufficient shared memory (shm).\n");
6060
SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n");
6161

62+
// When an error happened in DataLoader methods and Python starts to exit, the
63+
// error trace will keep the loader alive, and Python may kill the children
64+
// processes first before deleting the loader object. Then the cleaning up
65+
// methods in DataLoader.__del__ are not yet called, and SIGCHLD will print an
66+
// error saying a worker is killed by SIGTERM. So we suppress SIGTERM from main
67+
// loader process here to avoid this.
68+
static void handler_SIGTERM(int sig, siginfo_t *info, void *ctx)
69+
{
70+
if (info->si_pid == getppid()) {
71+
_exit(EXIT_SUCCESS);
72+
}
73+
struct sigaction sa;
74+
sa.sa_handler = SIG_DFL;
75+
sa.sa_flags = 0;
76+
if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGTERM, &sa, NULL) != 0) {
77+
_exit(EXIT_FAILURE);
78+
} else {
79+
raise(SIGTERM);
80+
}
81+
}
82+
6283
PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *arg) {
6384
HANDLE_TH_ERRORS
6485
setSignalHandler(SIGBUS, &handler_SIGBUS, NULL);
6586
setSignalHandler(SIGSEGV, &handler_SIGSEGV, NULL);
87+
setSignalHandler(SIGTERM, &handler_SIGTERM, NULL);
6688
Py_RETURN_TRUE;
6789
END_HANDLE_TH_ERRORS
6890
}
@@ -73,33 +95,33 @@ PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
7395
HANDLE_TH_ERRORS
7496
int error;
7597
std::set<pid_t> *pid_set;
76-
pid_t pid;
98+
pid_t worker_pid;
7799
siginfo_t infop;
78100

79101
// Only check the pids we care about
80102
for (auto it = worker_pids.begin(); it != worker_pids.end(); ++it) {
81103
pid_set = &(it->second);
82104
for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) {
83-
pid = *pid_it;
105+
worker_pid = *pid_it;
84106
// Use waitid rather than waitpid so that we can set NOWAIT, and that Python
85107
// and other handlers can get whatever info they want about the child.
86108
infop.si_pid = 0;
87-
error = waitid(P_PID, pid, &infop, WEXITED|WNOHANG|WNOWAIT);
109+
error = waitid(P_PID, worker_pid, &infop, WEXITED|WNOHANG|WNOWAIT);
88110
// ignore errors and case with no waitable child
89111
if (error < 0 || infop.si_pid == 0)
90112
continue;
91113
if (infop.si_code == CLD_EXITED && infop.si_status != 0) { // exit with error
92114
std::ostringstream oss;
93-
oss << "DataLoader worker (pid " << pid << ") exited unexpectedly "
94-
<< "with exit code " << infop.si_status << ".";
115+
oss << "DataLoader worker (worker_pid " << worker_pid << ") exited "
116+
<< "unexpectedly with exit code " << infop.si_status << ".";
95117
// This is necessary. Otherwise, the runtime error will kill the other
96118
// workers, and trigger this again.
97119
pid_set->clear();
98120
throw std::runtime_error(oss.str());
99121
} else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal
100122
std::ostringstream oss;
101-
oss << "DataLoader worker (pid " << pid << ") is killed by signal: "
102-
<< strsignal(infop.si_status) << ".";
123+
oss << "DataLoader worker (worker_pid " << worker_pid << ") is killed "
124+
<< "by signal: " << strsignal(infop.si_status) << ".";
103125
// This is necessary. Otherwise, the runtime error will kill the other
104126
// workers, and trigger this again.
105127
pid_set->clear();

0 commit comments

Comments
 (0)