Skip to content

Commit b79d74a

Browse files
colesbury authored and soumith committed
Re-initialize autograd engine in child processes (#4158)
* Re-initialize autograd engine in child processes

  The autograd engine uses threads for backwards. These don't exist after forks and they were not being re-initialized because the Engine::start_threads_flag was already set. This re-initializes the engine in child processes, which will cause it to re-create threads when backwards() is called in the child process.

  Note that we only attempt to handle the common case where fork() is called while the backwards threads are idle.

  Fixes #3966

* Avoid non-async-signal-safe functions in fork handler
1 parent 5c46427 commit b79d74a

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

test/test_multiprocessing.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ def send_tensor(queue, event, tp):
5050
event.wait()
5151

5252

53+
# Minimal autograd workload: build a 3x3 random Variable that requires grad
# and run backward() on its sum. Called directly in the parent and used as the
# mp.Process target in test_backwards_fork, so the same backward pass runs on
# both sides of the process boundary.
def call_backward():
54+
x = torch.autograd.Variable(torch.randn(3, 3), requires_grad=True)
55+
x.sum().backward()
56+
57+
5358
def sum_tensors(inq, outq):
5459
with torch.cuda.device(1):
5560
tensors = inq.get()
@@ -417,6 +422,14 @@ def test_is_shared_cuda(self):
417422
t = torch.randn(5, 5).cuda()
418423
self.assertTrue(t.is_shared())
419424

425+
# Regression test for running backward() in a child process after the parent
# has already run it (per the commit message: the engine's backward threads do
# not exist in the child after a fork).
def test_backwards_fork(self):
426+
r"backwards() should succeed when called before and after a fork"
427+
# Run backward once in the parent first, so the child starts from a process
# that has already used the autograd engine.
call_backward()
428+
# NOTE(review): presumably mp.Process forks on the platforms where this
# test is meaningful -- confirm the start method used by `mp`.
p = mp.Process(target=call_backward)
429+
p.start()
430+
# Wait at most 1 second for the child to finish.
# NOTE(review): the fixed 1s timeout may be tight on a loaded machine -- confirm.
p.join(1)
431+
# A child that is still alive after the timeout is treated as a failure
# (it did not complete backward() in time).
self.assertFalse(p.is_alive())
432+
420433

421434
if __name__ == '__main__':
422435
run_tests()

torch/csrc/autograd/python_engine.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
#include "torch/csrc/PtrWrapper.h"
88
#include "torch/csrc/utils/auto_gil.h"
99

10+
#ifndef _WIN32
11+
#include <pthread.h>
12+
#endif
13+
1014
#include <unordered_set>
1115

1216
using namespace torch::autograd;
@@ -130,10 +134,27 @@ void compute_partial_exec_callbacks(const function_list& roots,
130134
}
131135
}
132136

137+
// Set to true by child_atfork() (the child-side pthread_atfork handler) to
// request that the global `engine` be rebuilt; cleared again by
// _maybe_reinitialize_engine_after_fork() once the rebuild has happened.
static bool _reinitialize_engine = false;
138+
139+
// Rebuilds the global `engine` in place if a fork has been flagged since the
// last call; otherwise does nothing. Invoked at the top of
// THPEngine_run_backward and THPEngine_queue_callback, i.e. lazily on the
// first engine use in the child rather than inside the fork handler itself.
static void _maybe_reinitialize_engine_after_fork() {
140+
// This is "probably" thread-safe because the flag is set in a fork handler
141+
// before any threads are created, and this function is only called with the
142+
// GIL held. However, using fork + threads is playing with fire so this is
143+
// more of a "best effort" thing. For example, if the fork occurs while the
144+
// backwards threads hold a lock, we'll probably deadlock in the engine
145+
// destructor.
146+
if (_reinitialize_engine) {
147+
// Explicitly destroy the existing engine object, then construct a fresh
// PythonEngine in the same storage with placement new.
// NOTE(review): if the PythonEngine constructor throws, `engine` is left
// destroyed while the flag is still set, so the next call would run the
// destructor again -- confirm the constructor cannot throw.
engine.~PythonEngine();
148+
new (&engine) torch::autograd::python::PythonEngine();
149+
// Clear the request so later calls in this process are no-ops.
_reinitialize_engine = false;
150+
}
151+
}
152+
133153
// Implementation of torch._C._EngineBase.run_backward
134154
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
135155
{
136156
HANDLE_TH_ERRORS
157+
_maybe_reinitialize_engine_after_fork();
137158
PyObject *variables = NULL;
138159
PyObject *grad_variables = NULL;
139160
unsigned char keep_graph = 0;
@@ -263,6 +284,8 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
263284
}
264285

265286
PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) {
287+
HANDLE_TH_ERRORS
288+
_maybe_reinitialize_engine_after_fork();
266289
std::shared_ptr<PyObject> callback(_callback, [](PyObject *obj) { AutoGIL gil; Py_DECREF(obj); });
267290
Py_INCREF(_callback);
268291
engine.queue_callback([callback]() {
@@ -271,6 +294,7 @@ PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) {
271294
if (!result) throw python_error();
272295
});
273296
Py_RETURN_NONE;
297+
END_HANDLE_TH_ERRORS
274298
}
275299

276300
PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
@@ -326,8 +350,17 @@ PyTypeObject THPEngineType = {
326350
THPEngine_new /* tp_new */
327351
};
328352

353+
// Child-side fork handler, registered through pthread_atfork() in
// THPEngine_initModule (non-Windows builds only). It only raises a flag; the
// actual engine rebuild is deferred to _maybe_reinitialize_engine_after_fork()
// so that, per the commit message, no non-async-signal-safe functions are
// called from inside the fork handler.
static void child_atfork() {
354+
_reinitialize_engine = true;
355+
}
356+
329357
bool THPEngine_initModule(PyObject *module)
330358
{
359+
#ifndef _WIN32
360+
if (pthread_atfork(NULL, NULL, child_atfork) != 0) {
361+
throw std::runtime_error("unable to set pthread_atfork handler");
362+
}
363+
#endif
331364
if (PyType_Ready(&THPEngineType) < 0)
332365
return false;
333366
Py_INCREF(&THPEngineType);

0 commit comments

Comments
 (0)