Skip to content

Commit 5b06a24

Browse files
committed
use python side SIGCHLD
1 parent 44fff08 commit 5b06a24

File tree

2 files changed

+96
-162
lines changed

2 files changed

+96
-162
lines changed

torch/csrc/DataLoader.cpp

Lines changed: 63 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <sys/wait.h>
2+
#include <map>
23
#include <set>
34
#include <atomic>
45
#include <signal.h>
@@ -12,8 +13,9 @@
1213
// to do our best provide some error message to users when such unfortunate
1314
// events happen.
1415

15-
// TODO: The following don't work on Windows. Specifically, waitid calls and
16-
// SIGCHLD handler. Currently, dummy implementation is provided for Windows.
16+
// TODO: The following don't work on Windows. Specifically, sigaction, waitid
17+
// calls, and SIGCHLD handler. Currently, dummy implementations are provided
18+
// for Windows.
1719

1820
#ifndef _WIN32
1921

@@ -37,7 +39,7 @@ static void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *
3739
sigemptyset(&sa.sa_mask);
3840
if (sigaction(signal, &sa, old_sa_ptr) != 0) {
3941
std::ostringstream oss;
40-
oss << "An error occurred while setting handler for " << strsignal(signal);
42+
oss << "An error occurred while setting handler for " << strsignal(signal) << ".";
4143
throw std::runtime_error(oss.str());
4244
}
4345
}
@@ -54,166 +56,75 @@ PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *arg) {
5456
END_HANDLE_TH_ERRORS
5557
}
5658

57-
static std::set<pid_t> worker_pid_set = {};
58-
// The following are needed since std::set is not asynchronous safe.
59-
static std::atomic<pid_t *> worker_pids;
60-
static std::atomic<size_t> num_worker_pids(0);
61-
// Pipe used as a lock to avoid update of the above and SIGCHLD handler in parallel.
62-
static int comm_pipe[2] = {-1, -1};
63-
64-
static void updatePIDsArray() {
65-
size_t new_size = worker_pid_set.size();
66-
auto new_ptr = (pid_t *)malloc(sizeof(pid_t) * new_size);
67-
size_t idx = 0;
68-
for (auto it = worker_pid_set.begin(); it != worker_pid_set.end(); it++, idx++) {
69-
new_ptr[idx] = *it;
70-
}
71-
72-
// Block SIGCHLD handler for this thread so SIGCHLD handler can't interrupt
73-
// from this thread
74-
sigset_t sigset, old_sigset;
75-
sigemptyset(&sigset);
76-
sigaddset(&sigset, SIGCHLD);
77-
if (sigprocmask(SIG_BLOCK, &sigset, &old_sigset) != 0) {
78-
throw std::runtime_error("An error occurred while setting worker information "
79-
"for DataLoader SIGCHLD handler");
80-
}
81-
// Acquire ``lock'' so handlers on other threads can't interrupt
82-
char c;
83-
read(comm_pipe[0], &c, 1);
84-
85-
pid_t *old_ptr = worker_pids;
86-
num_worker_pids = new_size;
87-
worker_pids = new_ptr;
88-
free(old_ptr);
89-
90-
// Release ``lock''
91-
write(comm_pipe[1], &c, 1);
92-
// Restore handler for this thread.
93-
if (sigprocmask(SIG_SETMASK, &old_sigset, NULL) != 0) {
94-
throw std::runtime_error("An error occurred while setting DataLoader SIGCHLD handler");
95-
}
96-
}
97-
98-
static struct sigaction orig_SIGCHLD_sa;
99-
100-
// SIGCHLD hander should be registered on main loader process to catch any
101-
// worker failing.
102-
// Python handles are _set_main_signal_handers_for_workers() and
103-
// _remove_main_signal_handers_for_workers().
104-
static void handler_SIGCHLD_main(int sig, siginfo_t *info, void *ctx) {
105-
// Acquire ``lock'' so make sure that worker_pids won't change
106-
char c;
107-
read(comm_pipe[0], &c, 1);
59+
static std::map<int64_t, std::set<pid_t>> worker_pids = {};
10860

61+
PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
62+
HANDLE_TH_ERRORS
10963
int error;
64+
std::set<pid_t> pid_set;
65+
pid_t pid;
11066
siginfo_t infop;
11167

112-
// Only check the pids we care about so that Python can see other processes'
113-
// status.
114-
for (size_t i = 0; i < num_worker_pids; i++) {
115-
// Use waitid rather than waitpid so that we can set NOWAIT, and that Python
116-
// can get whatever info it wants about the child process.
117-
error = waitid(P_PID, worker_pids[i], &infop, WEXITED|WNOHANG|WNOWAIT);
118-
if (error < 0) // ignore errors
119-
continue;
120-
if ((infop.si_code == CLD_EXITED && infop.si_status != 0) || // exit with error
121-
(infop.si_code == CLD_KILLED) ||
122-
(infop.si_code == CLD_DUMPED)) {
123-
_exit(EXIT_FAILURE);
124-
}
125-
}
126-
127-
// Release ``lock''
128-
write(comm_pipe[1], &c, 1);
129-
130-
// Call the overridden handler.
131-
if ((orig_SIGCHLD_sa.sa_flags | SA_SIGINFO) != 0) {
132-
// handler is sa_sigaction, this shouldn't really be SIG_IGN or SIG_DFL, but
133-
// sa_sigaction and sa_handler happen to be a union, and this fact is
134-
// apparently used by Python, so check here.
135-
// https://stackoverflow.com/a/24080440
136-
if (orig_SIGCHLD_sa.sa_sigaction == (void (*)(int, siginfo_t *, void *)) SIG_IGN) {
137-
// SIG_IGN for SIGCHLD is to reap the child and do nothing else.
138-
while (waitpid(-1, 0, WNOHANG) > 0) {}
139-
} else if (orig_SIGCHLD_sa.sa_sigaction != (void (*)(int, siginfo_t *, void *)) SIG_DFL) {
140-
// SIG_DFL for SIGCHLD is to leave the child as a zombie (do nothing)
141-
orig_SIGCHLD_sa.sa_sigaction(sig, info, ctx);
142-
}
143-
} else {
144-
// handler is sa_handler
145-
if (orig_SIGCHLD_sa.sa_handler == SIG_IGN) {
146-
while (waitpid(-1, 0, WNOHANG) > 0) {}
147-
} else if (orig_SIGCHLD_sa.sa_handler != SIG_DFL) {
148-
orig_SIGCHLD_sa.sa_handler(sig);
68+
// Only check the pids we care about
69+
for (auto it = worker_pids.begin(); it != worker_pids.end(); ++it) {
70+
pid_set = it->second;
71+
for (auto pid_it = pid_set.begin(); pid_it != pid_set.end(); ++pid_it) {
72+
pid = *pid_it;
73+
// Use waitid rather than waitpid so that we can set NOWAIT, and that Python
74+
// and other handlers can get whatever info they want about the child.
75+
error = waitid(P_PID, pid, &infop, WEXITED|WNOHANG|WNOWAIT);
76+
if (error < 0) // ignore errors
77+
continue;
78+
if ((infop.si_code == CLD_EXITED && infop.si_status != 0) || // exit with error
79+
(infop.si_code == CLD_KILLED) ||
80+
(infop.si_code == CLD_DUMPED)) {
81+
std::ostringstream oss;
82+
oss << "DataLoader worker (pid " << pid << ") exited unexpectedly.";
83+
pid_set.clear();
84+
throw std::runtime_error(oss.str());
85+
}
14986
}
15087
}
151-
}
152-
153-
static int isSIGCHLDHanderSet() {
154-
struct sigaction sa;
155-
int error = sigaction(SIGCHLD, NULL, &sa);
156-
if (error == 0) {
157-
return ((sa.sa_flags | SA_SIGINFO) != 0) && (sa.sa_sigaction == &handler_SIGCHLD_main);
158-
} else {
159-
throw std::runtime_error("An error occurred while checking DataLoader SIGCHLD handler");
160-
}
88+
Py_RETURN_NONE;
89+
END_HANDLE_TH_ERRORS
16190
}
16291

16392
// We don't want to exit on any SIGCHLD from any child. child_pids is a tuple
16493
// of pids we are interested in.
165-
PyObject *THPModule_setMainSignalHandlers(PyObject *module, PyObject *child_pids) {
94+
PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *args) {
16695
HANDLE_TH_ERRORS
167-
// assert these types are lock free, just to be safe
168-
THPUtils_assert(worker_pids.is_lock_free(), "worker_pids is not lock free");
169-
THPUtils_assert(num_worker_pids.is_lock_free(), "num_worker_pids is not lock free");
170-
171-
THPUtils_assert(PyTuple_Check(child_pids), "_set_main_signal_handlers_for_workers "
172-
"expects a tuple, but got %s", THPUtils_typename(child_pids));
173-
174-
if (comm_pipe[0] == -1) {
175-
// we have GIL here so we are fine
176-
if (pipe(comm_pipe) != 0) {
177-
throw std::runtime_error("An error occurred while setting DataLoader SIGCHLD handler");
178-
}
179-
char c = '_';
180-
write(comm_pipe[1], &c, 1);
181-
}
182-
96+
Py_ssize_t num_args = args ? (Py_ssize_t) PyTuple_Size(args) : 0;
97+
THPUtils_assert(num_args == 2, "_update_worker_pids expects exactly 2 arguments.");
98+
int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
99+
THPUtils_assert(worker_pids.find(key) == worker_pids.end(), "_update_worker_pids "
100+
"should be called only once for each DataLoader.");
101+
PyObject *child_pids = PyTuple_GET_ITEM(args, 1);
102+
THPUtils_assert(PyTuple_Check(child_pids), "_update_worker_pids "
103+
"expects a tuple for child_pids, but got %s.", THPUtils_typename(child_pids));
104+
105+
std::set<pid_t> pids_set = {};
183106
auto size = PyTuple_GET_SIZE(child_pids);
184107
for (int idx = 0; idx < size; idx++) {
185108
PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
186-
worker_pid_set.insert((pid_t) THPUtils_unpackLong(obj));
109+
pids_set.insert((pid_t) THPUtils_unpackLong(obj));
187110
}
188-
updatePIDsArray();
189111

190-
// To avoid chain calling our handler, check if the current handler is already
191-
// set as ours.
192-
if (!isSIGCHLDHanderSet()) {
193-
setSignalHandler(SIGCHLD, &handler_SIGCHLD_main, &orig_SIGCHLD_sa);
194-
}
195-
Py_RETURN_TRUE;
112+
worker_pids[key] = pids_set;
113+
114+
Py_RETURN_NONE;
196115
END_HANDLE_TH_ERRORS
197116
}
198117

199-
PyObject *THPModule_removeMainSignalHandlers(PyObject *module, PyObject *child_pids) {
118+
PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *loader_id) {
200119
HANDLE_TH_ERRORS
201-
THPUtils_assert(PyTuple_Check(child_pids), "_remove_main_signal_handlers_for_workers "
202-
"expects a tuple, but got %s", THPUtils_typename(child_pids));
203120

204-
auto size = PyTuple_GET_SIZE(child_pids);
205-
for (int idx = 0; idx < size; idx++) {
206-
PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
207-
worker_pid_set.erase((pid_t) THPUtils_unpackLong(obj));
208-
}
209-
updatePIDsArray();
121+
int64_t key = THPUtils_unpackLong(loader_id);
122+
THPUtils_assert(worker_pids.find(key) != worker_pids.end(), "Cannot find worker "
123+
"information for DataLoader with id %ld.", key);
210124

211-
if (isSIGCHLDHanderSet()) {
212-
if (sigaction(SIGCHLD, &orig_SIGCHLD_sa, NULL) != 0) {
213-
throw std::runtime_error("An error occurred while restoring DataLoader SIGCHLD handler");
214-
}
215-
}
216-
Py_RETURN_TRUE;
125+
worker_pids.erase(key);
126+
127+
Py_RETURN_NONE;
217128
END_HANDLE_TH_ERRORS
218129
}
219130

@@ -226,19 +137,24 @@ PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *_ignored
226137
Py_RETURN_TRUE;
227138
}
228139

229-
PyObject *THPModule_setMainSignalHandlers(PyObject *module, PyObject *_ignored) {
140+
PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *_ignored) {
230141
Py_RETURN_TRUE;
231142
}
232143

233-
PyObject *THPModule_removeMainSignalHandlers(PyObject *module, PyObject *_ignored) {
234-
Py_RETURN_TRUE;
144+
PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *_ignored) {
145+
Py_RETURN_NONE;
146+
}
147+
148+
PyObject *THPModule_exitIfAnyWorkerFails(PyObject *module, PyObject *_ignored) {
149+
Py_RETURN_NONE;
235150
}
236151

237152
#endif
238153

239154
PyMethodDef DataLoaderMethods[] = {
240-
{"_set_worker_signal_handlers", (PyCFunction)THPModule_setWorkerSignalHandlers, METH_NOARGS, NULL},
241-
{"_set_main_signal_handlers_for_workers", (PyCFunction)THPModule_setMainSignalHandlers, METH_O, NULL},
242-
{"_remove_main_signal_handlers_for_workers", (PyCFunction)THPModule_removeMainSignalHandlers, METH_O, NULL},
155+
{"_set_worker_signal_handlers", (PyCFunction)THPModule_setWorkerSignalHandlers, METH_NOARGS, NULL},
156+
{"_update_worker_pids", (PyCFunction)THPModule_updateWorkerPIDs, METH_VARARGS, NULL},
157+
{"_remove_worker_pids", (PyCFunction)THPModule_removeWorkerPIDs, METH_O, NULL},
158+
{"_error_if_any_worker_fails", (PyCFunction)THPModule_errorIfAnyWorkerFails, METH_NOARGS, NULL},
243159
{NULL, NULL, 0, NULL}
244160
};

torch/utils/data/dataloader.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import torch
22
import torch.multiprocessing as multiprocessing
3-
from torch._C import _set_worker_signal_handlers, \
4-
_set_main_signal_handlers_for_workers, \
5-
_remove_main_signal_handlers_for_workers
3+
from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
4+
_remove_worker_pids, _error_if_any_worker_fails
65
from .sampler import SequentialSampler, RandomSampler, BatchSampler
6+
import signal
77
import collections
88
import re
99
import sys
@@ -20,6 +20,7 @@
2020

2121

2222
_use_shared_memory = False
23+
_SIGCHLD_handler_set = False
2324
"""Whether to use shared memory in default_collate"""
2425

2526

@@ -143,6 +144,21 @@ def pin_memory_batch(batch):
143144
else:
144145
return batch
145146

147+
def _set_SIGCHLD_handler():
148+
global _SIGCHLD_handler_set
149+
if _SIGCHLD_handler_set:
150+
return
151+
previous_handler = signal.getsignal(signal.SIGCHLD)
152+
def handler(signum, frame):
153+
_error_if_any_worker_fails()
154+
if callable(previous_handler):
155+
previous_handler(signum, frame)
156+
try:
157+
signal.signal(signal.SIGCHLD, handler)
158+
except ValueError as _:
159+
return # Windows doesn't support this
160+
_SIGCHLD_handler_set = True
161+
146162

147163
class DataLoaderIter(object):
148164
"Iterates once over the DataLoader's dataset, as specified by the sampler"
@@ -162,7 +178,7 @@ def __init__(self, loader):
162178
self.index_queue = multiprocessing.SimpleQueue()
163179
self.worker_result_queue = multiprocessing.SimpleQueue()
164180
self.batches_outstanding = 0
165-
self.handlers_set = False
181+
self.worker_pids_set = False
166182
self.shutdown = False
167183
self.send_idx = 0
168184
self.rcvd_idx = 0
@@ -178,10 +194,6 @@ def __init__(self, loader):
178194
w.daemon = True # ensure that the worker exits on process exit
179195
w.start()
180196

181-
self.worker_pids = tuple(w.pid for w in self.workers)
182-
183-
self.handlers_set = _set_main_signal_handlers_for_workers(self.worker_pids)
184-
185197
if self.pin_memory or self.timeout > 0:
186198
self.data_queue = queue.Queue()
187199
self.worker_manager_thread = threading.Thread(
@@ -192,6 +204,10 @@ def __init__(self, loader):
192204
else:
193205
self.data_queue = self.worker_result_queue
194206

207+
_update_worker_pids(id(self), tuple(w.pid for w in self.workers))
208+
_set_SIGCHLD_handler()
209+
self.worker_pids_set = True
210+
195211
# prime the prefetch loop
196212
for _ in range(2 * self.num_workers):
197213
self._put_indices()
@@ -222,7 +238,7 @@ def __next__(self):
222238
return self._process_next_batch(batch)
223239

224240
if self.batches_outstanding == 0:
225-
self._remove_handers()
241+
self._remove_worker_pids_information()
226242
self._shutdown_workers()
227243
raise StopIteration
228244

@@ -274,18 +290,20 @@ def _shutdown_workers(self):
274290
self.data_queue.get()
275291
for _ in self.workers:
276292
self.index_queue.put(None)
277-
# if all workers hang, no None is showed to worker_manager_thread,
278-
# we put None to let worker_manager_thread exit
293+
# if all workers hang, no None is sent to worker_manager_thread, we
294+
# put None to let worker_manager_thread exit
295+
# empty check prevents put from hanging
279296
if self.worker_result_queue.empty():
280297
self.worker_result_queue.put(None)
281298

282-
def _remove_handers(self):
283-
if self.handlers_set:
284-
self.handlers_set = not _remove_main_signal_handlers_for_workers(self.worker_pids)
299+
def _remove_worker_pids_information(self):
300+
if self.worker_pids_set:
301+
_remove_worker_pids(id(self))
302+
self.worker_pids_set = False
285303

286304
def __del__(self):
287305
if self.num_workers > 0:
288-
self._remove_handers()
306+
self._remove_worker_pids_information()
289307
self._shutdown_workers()
290308

291309

0 commit comments

Comments
 (0)