|
| 1 | +#include <sys/wait.h> |
| 2 | +#include <map> |
| 3 | +#include <set> |
| 4 | +#include <atomic> |
| 5 | +#include <signal.h> |
| 6 | +#include "THP.h" |
| 7 | + |
| 8 | +// In cases like DataLoader, if a worker process die due to bus error/segfault |
| 9 | +// or just hang, the main process, if implemented with |
| 10 | +// multiprocessing.queue.SimpleQueue, will hang waiting for data. This is |
| 11 | +// difficult to avoid on PyTorch side as it can be caused by limited shm, or |
| 12 | +// other libraries users call in the workers. The following methods is an effort |
| 13 | +// to do our best provide some error message to users when such unfortunate |
| 14 | +// events happen. |
| 15 | + |
| 16 | +// TODO: The following don't work on Windows. Specifically, sigaction, waitid |
| 17 | +// calls ,and SIGCHLD handler. Currently, dummy implementations are provided |
| 18 | +// for Windows. |
| 19 | + |
| 20 | +#ifndef _WIN32 |
| 21 | + |
| 22 | +// Critical signal handlers should be registered on worker processes before |
| 23 | +// doing work. |
| 24 | +// The handler will raise default handler so that the kill information will be |
| 25 | +// retrieved from main process. |
| 26 | +// Python handle is _set_worker_signal_handlers(). |
| 27 | +#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) \ |
| 28 | +static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) \ |
| 29 | +{ \ |
| 30 | + write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char)); \ |
| 31 | + struct sigaction sa; \ |
| 32 | + sa.sa_handler = SIG_DFL; \ |
| 33 | + sa.sa_flags = 0; \ |
| 34 | + if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGNAL, &sa, NULL) != 0) { \ |
| 35 | + _exit(EXIT_FAILURE); \ |
| 36 | + } else { \ |
| 37 | + raise(SIGNAL); \ |
| 38 | + } \ |
| 39 | +} |
| 40 | + |
| 41 | +// signal(2) is really not portable. So use sigaction. |
| 42 | +// http://man7.org/linux/man-pages/man2/signal.2.html |
| 43 | +static void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr) |
| 44 | +{ |
| 45 | + struct sigaction sa; |
| 46 | + sa.sa_sigaction = handler; |
| 47 | + sa.sa_flags = SA_RESTART|SA_SIGINFO|SA_NOCLDSTOP|SA_NODEFER; |
| 48 | + if (sigemptyset(&sa.sa_mask) != 0 || sigaction(signal, &sa, old_sa_ptr) != 0) { |
| 49 | + std::ostringstream oss; |
| 50 | + oss << "An error occurred while setting handler for " << strsignal(signal) << "."; |
| 51 | + throw std::runtime_error(oss.str()); |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | +SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered in worker. " |
| 56 | + "This might be caused by insufficient shared memory (shm).\n"); |
| 57 | +SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n"); |
| 58 | + |
| 59 | +PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *arg) { |
| 60 | + HANDLE_TH_ERRORS |
| 61 | + setSignalHandler(SIGBUS, &handler_SIGBUS, NULL); |
| 62 | + setSignalHandler(SIGSEGV, &handler_SIGSEGV, NULL); |
| 63 | + Py_RETURN_TRUE; |
| 64 | + END_HANDLE_TH_ERRORS |
| 65 | +} |
| 66 | + |
| 67 | +static std::map<int64_t, std::set<pid_t>> worker_pids = {}; |
| 68 | + |
| 69 | +PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) { |
| 70 | + HANDLE_TH_ERRORS |
| 71 | + int error; |
| 72 | + std::set<pid_t> *pid_set; |
| 73 | + pid_t pid; |
| 74 | + siginfo_t infop; |
| 75 | + |
| 76 | + // Only check the pids we care about |
| 77 | + for (auto it = worker_pids.begin(); it != worker_pids.end(); ++it) { |
| 78 | + pid_set = &(it->second); |
| 79 | + for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) { |
| 80 | + pid = *pid_it; |
| 81 | + // Use waitid rather than waitpid so that we can set NOWAIT, and that Python |
| 82 | + // and other handlers can get whatever info they want about the child. |
| 83 | + infop.si_pid = 0; |
| 84 | + error = waitid(P_PID, pid, &infop, WEXITED|WNOHANG|WNOWAIT); |
| 85 | + // ignore errors and case with no waitable child |
| 86 | + if (error < 0 || infop.si_pid == 0) |
| 87 | + continue; |
| 88 | + if (infop.si_code == CLD_EXITED && infop.si_status != 0) { // exit with error |
| 89 | + std::ostringstream oss; |
| 90 | + oss << "DataLoader worker (pid " << pid << ") exited unexpectedly " |
| 91 | + << "with exit code " << infop.si_status << "."; |
| 92 | + // This is necessary. Otherwise, the runtime error will kill the other |
| 93 | + // workers, and trigger this again. |
| 94 | + pid_set->clear(); |
| 95 | + throw std::runtime_error(oss.str()); |
| 96 | + } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal |
| 97 | + std::ostringstream oss; |
| 98 | + oss << "DataLoader worker (pid " << pid << ") is killed by signal: " |
| 99 | + << strsignal(infop.si_status) << "."; |
| 100 | + // This is necessary. Otherwise, the runtime error will kill the other |
| 101 | + // workers, and trigger this again. |
| 102 | + pid_set->clear(); |
| 103 | + throw std::runtime_error(oss.str()); |
| 104 | + } |
| 105 | + } |
| 106 | + } |
| 107 | + Py_RETURN_NONE; |
| 108 | + END_HANDLE_TH_ERRORS |
| 109 | +} |
| 110 | + |
| 111 | +// We don't want to exit on any SIGCHLD from any child. child_pids is a tuple |
| 112 | +// of pids we are interested in. |
| 113 | +PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *args) { |
| 114 | + HANDLE_TH_ERRORS |
| 115 | + Py_ssize_t num_args = args ? (Py_ssize_t) PyTuple_Size(args) : 0; |
| 116 | + THPUtils_assert(num_args == 2, "_update_worker_pids expectes exactly 2 arguments."); |
| 117 | + int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0)); |
| 118 | + THPUtils_assert(worker_pids.find(key) == worker_pids.end(), "_update_worker_pids " |
| 119 | + "should be called only once for each DataLoader."); |
| 120 | + PyObject *child_pids = PyTuple_GET_ITEM(args, 1); |
| 121 | + THPUtils_assert(PyTuple_Check(child_pids), "_update_worker_pids " |
| 122 | + "expects a tuple for child_pids, but got %s.", THPUtils_typename(child_pids)); |
| 123 | + |
| 124 | + std::set<pid_t> pids_set = {}; |
| 125 | + auto size = PyTuple_GET_SIZE(child_pids); |
| 126 | + for (int idx = 0; idx < size; idx++) { |
| 127 | + PyObject* obj = PyTuple_GET_ITEM(child_pids, idx); |
| 128 | + pids_set.insert((pid_t) THPUtils_unpackLong(obj)); |
| 129 | + } |
| 130 | + |
| 131 | + worker_pids[key] = pids_set; |
| 132 | + |
| 133 | + Py_RETURN_NONE; |
| 134 | + END_HANDLE_TH_ERRORS |
| 135 | +} |
| 136 | + |
| 137 | +PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *loader_id) { |
| 138 | + HANDLE_TH_ERRORS |
| 139 | + |
| 140 | + int64_t key = THPUtils_unpackLong(loader_id); |
| 141 | + THPUtils_assert(worker_pids.find(key) != worker_pids.end(), "Cannot find worker " |
| 142 | + "information for DataLoader with id %ld.", key); |
| 143 | + |
| 144 | + worker_pids.erase(key); |
| 145 | + |
| 146 | + Py_RETURN_NONE; |
| 147 | + END_HANDLE_TH_ERRORS |
| 148 | +} |
| 149 | + |
| 150 | +#undef SIGNAL_HANDLER |
| 151 | + |
| 152 | +#else |
| 153 | +// dummy implementations for windows |
| 154 | + |
| 155 | +PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *_ignored) { |
| 156 | + Py_RETURN_TRUE; |
| 157 | +} |
| 158 | + |
| 159 | +PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *_ignored) { |
| 160 | + Py_RETURN_TRUE; |
| 161 | +} |
| 162 | + |
| 163 | +PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *_ignored) { |
| 164 | + Py_RETURN_NONE; |
| 165 | +} |
| 166 | + |
| 167 | +PyObject *THPModule_exitIfAnyWorkerFails(PyObject *module, PyObject *_ignored) { |
| 168 | + Py_RETURN_NONE; |
| 169 | +} |
| 170 | + |
| 171 | +#endif |
| 172 | + |
| 173 | +PyMethodDef DataLoaderMethods[] = { |
| 174 | + {"_set_worker_signal_handlers", (PyCFunction)THPModule_setWorkerSignalHandlers, METH_NOARGS, NULL}, |
| 175 | + {"_update_worker_pids", (PyCFunction)THPModule_updateWorkerPIDs, METH_VARARGS, NULL}, |
| 176 | + {"_remove_worker_pids", (PyCFunction)THPModule_removeWorkerPIDs, METH_O, NULL}, |
| 177 | + {"_error_if_any_worker_fails", (PyCFunction)THPModule_errorIfAnyWorkerFails, METH_NOARGS, NULL}, |
| 178 | + {NULL, NULL, 0, NULL} |
| 179 | +}; |
0 commit comments