
Commit 1661370

ssnl authored and apaszke committed
Signal handling in DataLoader workers; Timeout option (#3474)
1 parent 3c709f5 commit 1661370

File tree

4 files changed: +336 -18 lines changed


test/test_dataloader.py

Lines changed: 74 additions & 3 deletions
@@ -1,8 +1,11 @@
 import math
 import sys
+import ctypes
 import torch
+import time
 import traceback
 import unittest
+from torch import multiprocessing
 from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
 from torch.utils.data.dataloader import default_collate
 from common import TestCase, run_tests, TEST_NUMPY
@@ -83,6 +86,32 @@ def __len__(self):
         return self.size


+class SegfaultDataset(Dataset):
+
+    def __init__(self, size):
+        self.size = size
+
+    def __getitem__(self, idx):
+        return ctypes.string_at(0)
+
+    def __len__(self):
+        return self.size
+
+
+class SleepDataset(Dataset):
+
+    def __init__(self, size, sleep_sec):
+        self.size = size
+        self.sleep_sec = sleep_sec
+
+    def __getitem__(self, idx):
+        time.sleep(self.sleep_sec)
+        return idx
+
+    def __len__(self):
+        return self.size
+
+
 class TestDataLoader(TestCase):

     def setUp(self):
@@ -149,6 +178,48 @@ def test_sequential_pin_memory(self):
         self.assertTrue(input.is_pinned())
         self.assertTrue(target.is_pinned())

+    def test_multiple_dataloaders(self):
+        loader1_it = iter(DataLoader(self.dataset, num_workers=1))
+        loader2_it = iter(DataLoader(self.dataset, num_workers=2))
+        next(loader1_it)
+        next(loader1_it)
+        next(loader2_it)
+        next(loader2_it)
+        next(loader1_it)
+        next(loader2_it)
+
+    def test_segfault(self):
+        def _test_segfault():
+            sys.stderr.close()
+            dataset = SegfaultDataset(10)
+            dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
+            _ = next(iter(dataloader))
+
+        p = multiprocessing.Process(target=_test_segfault)
+        p.start()
+        p.join(1.0)
+        try:
+            self.assertFalse(p.is_alive())
+            self.assertNotEqual(p.exitcode, 0)
+        finally:
+            p.terminate()
+
+    def test_timeout(self):
+        def _test_timeout():
+            sys.stderr.close()
+            dataset = SleepDataset(10, 10)
+            dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
+            _ = next(iter(dataloader))
+
+        p = multiprocessing.Process(target=_test_timeout)
+        p.start()
+        p.join(3.0)
+        try:
+            self.assertFalse(p.is_alive())
+            self.assertNotEqual(p.exitcode, 0)
+        finally:
+            p.terminate()
+
     def test_shuffle(self):
         self._test_shuffle(DataLoader(self.dataset, shuffle=True))

@@ -224,7 +295,7 @@ def test_partial_workers(self):
         "check that workers exit even if the iterator is not exhausted"
         loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=True))
         workers = loader.workers
-        pin_thread = loader.pin_thread
+        worker_manager_thread = loader.worker_manager_thread
         for i, sample in enumerate(loader):
             if i == 3:
                 break
@@ -233,8 +304,8 @@
             w.join(1.0)  # timeout of one second
             self.assertFalse(w.is_alive(), 'subprocess not terminated')
             self.assertEqual(w.exitcode, 0)
-        pin_thread.join(1.0)
-        self.assertFalse(pin_thread.is_alive())
+        worker_manager_thread.join(1.0)
+        self.assertFalse(worker_manager_thread.is_alive())

     def test_len(self):
         def check_len(dl, expected):

torch/csrc/DataLoader.cpp

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
+#include <sys/wait.h>
+#include <map>
+#include <set>
+#include <atomic>
+#include <signal.h>
+#include "THP.h"
+
+// In cases like DataLoader, if a worker process dies due to a bus
+// error/segfault or just hangs, the main process, if implemented with
+// multiprocessing.queue.SimpleQueue, will hang waiting for data. This is
+// difficult to avoid on the PyTorch side, as it can be caused by limited shm
+// or by other libraries that users call in the workers. The following methods
+// are a best-effort attempt to provide some error message to users when such
+// unfortunate events happen.
+
+// TODO: The following don't work on Windows. Specifically, the sigaction and
+// waitid calls, and the SIGCHLD handler. Currently, dummy implementations are
+// provided for Windows.
+
+#ifndef _WIN32
+
+// Critical signal handlers should be registered on worker processes before
+// doing work.
+// The handler restores the default handler and re-raises the signal, so that
+// the main process can retrieve the kill information.
+// The Python binding is _set_worker_signal_handlers().
+#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG)                       \
+static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx)                 \
+{                                                                             \
+  write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char));          \
+  struct sigaction sa;                                                        \
+  sa.sa_handler = SIG_DFL;                                                    \
+  sa.sa_flags = 0;                                                            \
+  if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGNAL, &sa, NULL) != 0) {   \
+    _exit(EXIT_FAILURE);                                                      \
+  } else {                                                                    \
+    raise(SIGNAL);                                                            \
+  }                                                                           \
+}
+
+// signal(2) is really not portable, so use sigaction.
+// http://man7.org/linux/man-pages/man2/signal.2.html
+static void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
+{
+  struct sigaction sa;
+  sa.sa_sigaction = handler;
+  sa.sa_flags = SA_RESTART|SA_SIGINFO|SA_NOCLDSTOP|SA_NODEFER;
+  if (sigemptyset(&sa.sa_mask) != 0 || sigaction(signal, &sa, old_sa_ptr) != 0) {
+    std::ostringstream oss;
+    oss << "An error occurred while setting handler for " << strsignal(signal) << ".";
+    throw std::runtime_error(oss.str());
+  }
+}
+
+SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered in worker. "
+  "This might be caused by insufficient shared memory (shm).\n");
+SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n");
+
+PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *arg) {
+  HANDLE_TH_ERRORS
+  setSignalHandler(SIGBUS, &handler_SIGBUS, NULL);
+  setSignalHandler(SIGSEGV, &handler_SIGSEGV, NULL);
+  Py_RETURN_TRUE;
+  END_HANDLE_TH_ERRORS
+}
+
+static std::map<int64_t, std::set<pid_t>> worker_pids = {};
+
+PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
+  HANDLE_TH_ERRORS
+  int error;
+  std::set<pid_t> *pid_set;
+  pid_t pid;
+  siginfo_t infop;
+
+  // Only check the pids we care about
+  for (auto it = worker_pids.begin(); it != worker_pids.end(); ++it) {
+    pid_set = &(it->second);
+    for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) {
+      pid = *pid_it;
+      // Use waitid rather than waitpid so that we can set NOWAIT, and so that
+      // Python and other handlers can get whatever info they want about the
+      // child.
+      infop.si_pid = 0;
+      error = waitid(P_PID, pid, &infop, WEXITED|WNOHANG|WNOWAIT);
+      // ignore errors and the case with no waitable child
+      if (error < 0 || infop.si_pid == 0)
+        continue;
+      if (infop.si_code == CLD_EXITED && infop.si_status != 0) {  // exit with error
+        std::ostringstream oss;
+        oss << "DataLoader worker (pid " << pid << ") exited unexpectedly "
+            << "with exit code " << infop.si_status << ".";
+        // This is necessary. Otherwise, the runtime error would kill the
+        // other workers, and trigger this again.
+        pid_set->clear();
+        throw std::runtime_error(oss.str());
+      } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) {  // killed by signal
+        std::ostringstream oss;
+        oss << "DataLoader worker (pid " << pid << ") was killed by signal: "
+            << strsignal(infop.si_status) << ".";
+        // This is necessary. Otherwise, the runtime error would kill the
+        // other workers, and trigger this again.
+        pid_set->clear();
+        throw std::runtime_error(oss.str());
+      }
+    }
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+// We don't want to error out on every SIGCHLD from every child process.
+// child_pids is a tuple of the pids we are interested in.
+PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *args) {
+  HANDLE_TH_ERRORS
+  Py_ssize_t num_args = args ? (Py_ssize_t) PyTuple_Size(args) : 0;
+  THPUtils_assert(num_args == 2, "_update_worker_pids expects exactly 2 arguments.");
+  int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
+  THPUtils_assert(worker_pids.find(key) == worker_pids.end(), "_update_worker_pids "
+      "should be called only once for each DataLoader.");
+  PyObject *child_pids = PyTuple_GET_ITEM(args, 1);
+  THPUtils_assert(PyTuple_Check(child_pids), "_update_worker_pids "
+      "expects a tuple for child_pids, but got %s.", THPUtils_typename(child_pids));
+
+  std::set<pid_t> pids_set = {};
+  auto size = PyTuple_GET_SIZE(child_pids);
+  for (int idx = 0; idx < size; idx++) {
+    PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
+    pids_set.insert((pid_t) THPUtils_unpackLong(obj));
+  }
+
+  worker_pids[key] = pids_set;
+
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *loader_id) {
+  HANDLE_TH_ERRORS
+
+  int64_t key = THPUtils_unpackLong(loader_id);
+  THPUtils_assert(worker_pids.find(key) != worker_pids.end(), "Cannot find worker "
+      "information for DataLoader with id %ld.", key);
+
+  worker_pids.erase(key);
+
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+#undef SIGNAL_HANDLER
+
+#else
+// dummy implementations for Windows
+
+PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *_ignored) {
+  Py_RETURN_TRUE;
+}
+
+PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *_ignored) {
+  Py_RETURN_TRUE;
+}
+
+PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *_ignored) {
+  Py_RETURN_NONE;
+}
+
+PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *_ignored) {
+  Py_RETURN_NONE;
+}
+
+#endif
+
+PyMethodDef DataLoaderMethods[] = {
+  {"_set_worker_signal_handlers", (PyCFunction)THPModule_setWorkerSignalHandlers, METH_NOARGS,  NULL},
+  {"_update_worker_pids",         (PyCFunction)THPModule_updateWorkerPIDs,        METH_VARARGS, NULL},
+  {"_remove_worker_pids",         (PyCFunction)THPModule_removeWorkerPIDs,        METH_O,       NULL},
+  {"_error_if_any_worker_fails",  (PyCFunction)THPModule_errorIfAnyWorkerFails,   METH_NOARGS,  NULL},
+  {NULL, NULL, 0, NULL}
+};

torch/csrc/Module.cpp

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 #include "THP.h"

 #include "ModuleSparse.cpp"
+#include "DataLoader.cpp"

 PyObject* module;
 PyObject* tensor_classes;
@@ -797,6 +798,7 @@ static PyObject* initModule() {
 #define ASSERT_TRUE(cmd) if (!(cmd)) return NULL

   THPUtils_addPyMethodDefs(methods, TorchMethods);
+  THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
 #ifdef WITH_CUDA
   THPUtils_addPyMethodDefs(methods, THCPModule_methods());
 #endif
