11#include < sys/wait.h>
2+ #include < map>
23#include < set>
34#include < atomic>
45#include < signal.h>
1213// to do our best provide some error message to users when such unfortunate
1314// events happen.
1415
15- // TODO: The following don't work on Windows. Specifically, waitid calls and
16- // SIGCHLD handler. Currently, dummy implementation is provided for Windows.
16+ // TODO: The following don't work on Windows. Specifically, sigaction, waitid
17+ // calls ,and SIGCHLD handler. Currently, dummy implementations are provided
18+ // for Windows.
1719
1820#ifndef _WIN32
1921
@@ -37,7 +39,7 @@ static void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *
3739 sigemptyset (&sa.sa_mask );
3840 if (sigaction (signal, &sa, old_sa_ptr) != 0 ) {
3941 std::ostringstream oss;
40- oss << " An error occurred while setting handler for " << strsignal (signal);
42+ oss << " An error occurred while setting handler for " << strsignal (signal) << " . " ;
4143 throw std::runtime_error (oss.str ());
4244 }
4345}
@@ -54,166 +56,75 @@ PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *arg) {
5456 END_HANDLE_TH_ERRORS
5557}
5658
57- static std::set<pid_t > worker_pid_set = {};
58- // The following are needed since std::set is not asynchronous safe.
59- static std::atomic<pid_t *> worker_pids;
60- static std::atomic<size_t > num_worker_pids (0 );
61- // Pipe used as a lock to avoid update of the above and SIGCHLD handler in parallel.
62- static int comm_pipe[2 ] = {-1 , -1 };
63-
64- static void updatePIDsArray () {
65- size_t new_size = worker_pid_set.size ();
66- auto new_ptr = (pid_t *)malloc (sizeof (pid_t ) * new_size);
67- size_t idx = 0 ;
68- for (auto it = worker_pid_set.begin (); it != worker_pid_set.end (); it++, idx++) {
69- new_ptr[idx] = *it;
70- }
71-
72- // Block SIGCHLD handler for this thread so SIGCHLD handler can't interrupt
73- // from this thread
74- sigset_t sigset, old_sigset;
75- sigemptyset (&sigset);
76- sigaddset (&sigset, SIGCHLD);
77- if (sigprocmask (SIG_BLOCK, &sigset, &old_sigset) != 0 ) {
78- throw std::runtime_error (" An error occurred while setting worker information "
79- " for DataLoader SIGCHLD handler" );
80- }
81- // Acquire ``lock'' so handlers on other threads can't interrupt
82- char c;
83- read (comm_pipe[0 ], &c, 1 );
84-
85- pid_t *old_ptr = worker_pids;
86- num_worker_pids = new_size;
87- worker_pids = new_ptr;
88- free (old_ptr);
89-
90- // Release ``lock''
91- write (comm_pipe[1 ], &c, 1 );
92- // Restore handler for this thread.
93- if (sigprocmask (SIG_SETMASK, &old_sigset, NULL ) != 0 ) {
94- throw std::runtime_error (" An error occurred while setting DataLoader SIGCHLD handler" );
95- }
96- }
97-
98- static struct sigaction orig_SIGCHLD_sa;
99-
100- // SIGCHLD hander should be registered on main loader process to catch any
101- // worker failing.
102- // Python handles are _set_main_signal_handers_for_workers() and
103- // _remove_main_signal_handers_for_workers().
104- static void handler_SIGCHLD_main (int sig, siginfo_t *info, void *ctx) {
105- // Acquire ``lock'' so make sure that worker_pids won't change
106- char c;
107- read (comm_pipe[0 ], &c, 1 );
59+ static std::map<int64_t , std::set<pid_t >> worker_pids = {};
10860
61+ PyObject *THPModule_errorIfAnyWorkerFails (PyObject *module ) {
62+ HANDLE_TH_ERRORS
10963 int error;
64+ std::set<pid_t > pid_set;
65+ pid_t pid;
11066 siginfo_t infop;
11167
112- // Only check the pids we care about so that Python can see other processes'
113- // status.
114- for (size_t i = 0 ; i < num_worker_pids; i++) {
115- // Use waitid rather than waitpid so that we can set NOWAIT, and that Python
116- // can get whatever info it wants about the child process.
117- error = waitid (P_PID, worker_pids[i], &infop, WEXITED|WNOHANG|WNOWAIT);
118- if (error < 0 ) // ignore errors
119- continue ;
120- if ((infop.si_code == CLD_EXITED && infop.si_status != 0 ) || // exit with error
121- (infop.si_code == CLD_KILLED) ||
122- (infop.si_code == CLD_DUMPED)) {
123- _exit (EXIT_FAILURE);
124- }
125- }
126-
127- // Release ``lock''
128- write (comm_pipe[1 ], &c, 1 );
129-
130- // Call the overridden handler.
131- if ((orig_SIGCHLD_sa.sa_flags | SA_SIGINFO) != 0 ) {
132- // handler is sa_sigaction, this shouldn't really be SIG_IGN or SIG_DFL, but
133- // sa_sigaction and sa_handler happen to be a union, and this fact is
134- // apparently used by Python, so check here.
135- // https://stackoverflow.com/a/24080440
136- if (orig_SIGCHLD_sa.sa_sigaction == (void (*)(int , siginfo_t *, void *)) SIG_IGN) {
137- // SIG_IGN for SIGCHLD is to reap the child and do nothing else.
138- while (waitpid (-1 , 0 , WNOHANG) > 0 ) {}
139- } else if (orig_SIGCHLD_sa.sa_sigaction != (void (*)(int , siginfo_t *, void *)) SIG_DFL) {
140- // SIG_DFL for SIGCHLD is to leave the child as a zombie (do nothing)
141- orig_SIGCHLD_sa.sa_sigaction (sig, info, ctx);
142- }
143- } else {
144- // handler is sa_handler
145- if (orig_SIGCHLD_sa.sa_handler == SIG_IGN) {
146- while (waitpid (-1 , 0 , WNOHANG) > 0 ) {}
147- } else if (orig_SIGCHLD_sa.sa_handler != SIG_DFL) {
148- orig_SIGCHLD_sa.sa_handler (sig);
68+ // Only check the pids we care about
69+ for (auto it = worker_pids.begin (); it != worker_pids.end (); ++it) {
70+ pid_set = it->second ;
71+ for (auto pid_it = pid_set.begin (); pid_it != pid_set.end (); ++pid_it) {
72+ pid = *pid_it;
73+ // Use waitid rather than waitpid so that we can set NOWAIT, and that Python
74+ // and other handlers can get whatever info they want about the child.
75+ error = waitid (P_PID, pid, &infop, WEXITED|WNOHANG|WNOWAIT);
76+ if (error < 0 ) // ignore errors
77+ continue ;
78+ if ((infop.si_code == CLD_EXITED && infop.si_status != 0 ) || // exit with error
79+ (infop.si_code == CLD_KILLED) ||
80+ (infop.si_code == CLD_DUMPED)) {
81+ std::ostringstream oss;
82+ oss << " DataLoader worker (pid " << pid << " ) exited unexpectedly." ;
83+ pid_set.clear ();
84+ throw std::runtime_error (oss.str ());
85+ }
14986 }
15087 }
151- }
152-
153- static int isSIGCHLDHanderSet () {
154- struct sigaction sa;
155- int error = sigaction (SIGCHLD, NULL , &sa);
156- if (error == 0 ) {
157- return ((sa.sa_flags | SA_SIGINFO) != 0 ) && (sa.sa_sigaction == &handler_SIGCHLD_main);
158- } else {
159- throw std::runtime_error (" An error occurred while checking DataLoader SIGCHLD handler" );
160- }
88+ Py_RETURN_NONE;
89+ END_HANDLE_TH_ERRORS
16190}
16291
16392// We don't want to exit on any SIGCHLD from any child. child_pids is a tuple
16493// of pids we are interested in.
165- PyObject *THPModule_setMainSignalHandlers (PyObject *module , PyObject *child_pids ) {
94+ PyObject *THPModule_updateWorkerPIDs (PyObject *module , PyObject *args ) {
16695 HANDLE_TH_ERRORS
167- // assert these types are lock free, just to be safe
168- THPUtils_assert (worker_pids.is_lock_free (), " worker_pids is not lock free" );
169- THPUtils_assert (num_worker_pids.is_lock_free (), " num_worker_pids is not lock free" );
170-
171- THPUtils_assert (PyTuple_Check (child_pids), " _set_main_signal_handlers_for_workers "
172- " expects a tuple, but got %s" , THPUtils_typename (child_pids));
173-
174- if (comm_pipe[0 ] == -1 ) {
175- // we have GIL here so we are fine
176- if (pipe (comm_pipe) != 0 ) {
177- throw std::runtime_error (" An error occurred while setting DataLoader SIGCHLD handler" );
178- }
179- char c = ' _' ;
180- write (comm_pipe[1 ], &c, 1 );
181- }
182-
96+ Py_ssize_t num_args = args ? (Py_ssize_t) PyTuple_Size (args) : 0 ;
97+ THPUtils_assert (num_args == 2 , " _update_worker_pids expectes exactly 2 arguments." );
98+ int64_t key = THPUtils_unpackLong (PyTuple_GET_ITEM (args, 0 ));
99+ THPUtils_assert (worker_pids.find (key) == worker_pids.end (), " _update_worker_pids "
100+ " should be called only once for each DataLoader." );
101+ PyObject *child_pids = PyTuple_GET_ITEM (args, 1 );
102+ THPUtils_assert (PyTuple_Check (child_pids), " _update_worker_pids "
103+ " expects a tuple for child_pids, but got %s." , THPUtils_typename (child_pids));
104+
105+ std::set<pid_t > pids_set = {};
183106 auto size = PyTuple_GET_SIZE (child_pids);
184107 for (int idx = 0 ; idx < size; idx++) {
185108 PyObject* obj = PyTuple_GET_ITEM (child_pids, idx);
186- worker_pid_set .insert ((pid_t ) THPUtils_unpackLong (obj));
109+ pids_set .insert ((pid_t ) THPUtils_unpackLong (obj));
187110 }
188- updatePIDsArray ();
189111
190- // To avoid chain calling our handler, check if the current handler is already
191- // set as ours.
192- if (!isSIGCHLDHanderSet ()) {
193- setSignalHandler (SIGCHLD, &handler_SIGCHLD_main, &orig_SIGCHLD_sa);
194- }
195- Py_RETURN_TRUE;
112+ worker_pids[key] = pids_set;
113+
114+ Py_RETURN_NONE;
196115 END_HANDLE_TH_ERRORS
197116}
198117
199- PyObject *THPModule_removeMainSignalHandlers (PyObject *module , PyObject *child_pids ) {
118+ PyObject *THPModule_removeWorkerPIDs (PyObject *module , PyObject *loader_id ) {
200119 HANDLE_TH_ERRORS
201- THPUtils_assert (PyTuple_Check (child_pids), " _remove_main_signal_handlers_for_workers "
202- " expects a tuple, but got %s" , THPUtils_typename (child_pids));
203120
204- auto size = PyTuple_GET_SIZE (child_pids);
205- for (int idx = 0 ; idx < size; idx++) {
206- PyObject* obj = PyTuple_GET_ITEM (child_pids, idx);
207- worker_pid_set.erase ((pid_t ) THPUtils_unpackLong (obj));
208- }
209- updatePIDsArray ();
121+ int64_t key = THPUtils_unpackLong (loader_id);
122+ THPUtils_assert (worker_pids.find (key) != worker_pids.end (), " Cannot find worker "
123+ " information for DataLoader with id %ld." , key);
210124
211- if (isSIGCHLDHanderSet ()) {
212- if (sigaction (SIGCHLD, &orig_SIGCHLD_sa, NULL ) != 0 ) {
213- throw std::runtime_error (" An error occurred while restoring DataLoader SIGCHLD handler" );
214- }
215- }
216- Py_RETURN_TRUE;
125+ worker_pids.erase (key);
126+
127+ Py_RETURN_NONE;
217128 END_HANDLE_TH_ERRORS
218129}
219130
@@ -226,19 +137,24 @@ PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *_ignored
226137 Py_RETURN_TRUE;
227138}
228139
229- PyObject *THPModule_setMainSignalHandlers (PyObject *module , PyObject *_ignored) {
140+ PyObject *THPModule_updateWorkerPIDs (PyObject *module , PyObject *_ignored) {
230141 Py_RETURN_TRUE;
231142}
232143
233- PyObject *THPModule_removeMainSignalHandlers (PyObject *module , PyObject *_ignored) {
234- Py_RETURN_TRUE;
144+ PyObject *THPModule_removeWorkerPIDs (PyObject *module , PyObject *_ignored) {
145+ Py_RETURN_NONE;
146+ }
147+
148+ PyObject *THPModule_exitIfAnyWorkerFails (PyObject *module , PyObject *_ignored) {
149+ Py_RETURN_NONE;
235150}
236151
237152#endif
238153
239154PyMethodDef DataLoaderMethods[] = {
240- {" _set_worker_signal_handlers" , (PyCFunction)THPModule_setWorkerSignalHandlers, METH_NOARGS, NULL },
241- {" _set_main_signal_handlers_for_workers" , (PyCFunction)THPModule_setMainSignalHandlers, METH_O, NULL },
242- {" _remove_main_signal_handlers_for_workers" , (PyCFunction)THPModule_removeMainSignalHandlers, METH_O, NULL },
155+ {" _set_worker_signal_handlers" , (PyCFunction)THPModule_setWorkerSignalHandlers, METH_NOARGS, NULL },
156+ {" _update_worker_pids" , (PyCFunction)THPModule_updateWorkerPIDs, METH_VARARGS, NULL },
157+ {" _remove_worker_pids" , (PyCFunction)THPModule_removeWorkerPIDs, METH_O, NULL },
158+ {" _error_if_any_worker_fails" , (PyCFunction)THPModule_errorIfAnyWorkerFails, METH_NOARGS, NULL },
243159 {NULL , NULL , 0 , NULL }
244160};
0 commit comments