13 changes: 13 additions & 0 deletions test/test_multiprocessing.py
@@ -50,6 +50,11 @@ def send_tensor(queue, event, tp):
    event.wait()


def call_backward():
    x = torch.autograd.Variable(torch.randn(3, 3), requires_grad=True)
    x.sum().backward()


def sum_tensors(inq, outq):
    with torch.cuda.device(1):
        tensors = inq.get()
@@ -417,6 +422,14 @@ def test_is_shared_cuda(self):
        t = torch.randn(5, 5).cuda()
        self.assertTrue(t.is_shared())

    def test_backwards_fork(self):
        r"backwards() should succeed when called before and after a fork"
        call_backward()
        p = mp.Process(target=call_backward)
        p.start()
        p.join(1)
        self.assertFalse(p.is_alive())


if __name__ == '__main__':
    run_tests()
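
As a rough standalone illustration of what the new test exercises (not part of the PR; it assumes a fork-based start method, as on Linux, and a recent torch where requires_grad can be passed to torch.randn), the same scenario outside the test harness looks like this:

import torch
import torch.multiprocessing as mp


def call_backward():
    # A tiny autograd graph; before the at-fork fix this could hang when run
    # in a child forked from a parent that had already started backward threads.
    x = torch.randn(3, 3, requires_grad=True)
    x.sum().backward()


if __name__ == '__main__':
    call_backward()                       # backward in the parent first
    p = mp.Process(target=call_backward)  # then backward in a forked child
    p.start()
    p.join(1)
    assert not p.is_alive(), "child hung calling backward() after fork"
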
33 changes: 33 additions & 0 deletions torch/csrc/autograd/python_engine.cpp
@@ -7,6 +7,10 @@
#include "torch/csrc/PtrWrapper.h"
#include "torch/csrc/utils/auto_gil.h"

#ifndef _WIN32
#include <pthread.h>
#endif

#include <unordered_set>

using namespace torch::autograd;
@@ -130,10 +134,27 @@ void compute_partial_exec_callbacks(const function_list& roots,
}
}

static bool _reinitialize_engine = false;

static void _maybe_reinitialize_engine_after_fork() {
  // This is "probably" thread-safe because the flag is set in a fork handler
  // before any threads are created, and this function is only called with the
  // GIL held. However, using fork + threads is playing with fire so this is
  // more of a "best effort" thing. For example, if the fork occurs while the
  // backwards threads hold a lock, we'll probably deadlock in the engine
  // destructor.
  if (_reinitialize_engine) {
    engine.~PythonEngine();
    new (&engine) torch::autograd::python::PythonEngine();
    _reinitialize_engine = false;
  }
}

// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  _maybe_reinitialize_engine_after_fork();
  PyObject *variables = NULL;
  PyObject *grad_variables = NULL;
  unsigned char keep_graph = 0;
@@ -263,6 +284,8 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
}

PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) {
  HANDLE_TH_ERRORS
  _maybe_reinitialize_engine_after_fork();
  std::shared_ptr<PyObject> callback(_callback, [](PyObject *obj) { AutoGIL gil; Py_DECREF(obj); });
  Py_INCREF(_callback);
  engine.queue_callback([callback]() {
@@ -271,6 +294,7 @@ PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) {
    if (!result) throw python_error();
  });
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
@@ -326,8 +350,17 @@ PyTypeObject THPEngineType = {
  THPEngine_new /* tp_new */
};

static void child_atfork() {
  _reinitialize_engine = true;
}

bool THPEngine_initModule(PyObject *module)
{
#ifndef _WIN32
  if (pthread_atfork(NULL, NULL, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
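
For illustration only: the patch above uses pthread_atfork plus an explicit destructor call and placement new to rebuild the C++ engine singleton lazily after a fork. The same "set a flag in the child, rebuild on next use" pattern can be sketched in pure Python with os.register_at_fork (Python 3.7+); the Engine class and function names below are hypothetical stand-ins, not PyTorch APIs:

import os


class Engine:
    """Stand-in for a singleton that owns worker threads which do not survive fork()."""
    def __init__(self):
        self.ready = True


engine = Engine()
_reinitialize_engine = False


def _child_atfork():
    # Runs in the child immediately after fork(); do no real work here,
    # just record that the inherited engine must be rebuilt.
    global _reinitialize_engine
    _reinitialize_engine = True


os.register_at_fork(after_in_child=_child_atfork)


def run_backward():
    # Lazily rebuild the singleton on first use after a fork, mirroring
    # _maybe_reinitialize_engine_after_fork() in python_engine.cpp.
    global engine, _reinitialize_engine
    if _reinitialize_engine:
        engine = Engine()
        _reinitialize_engine = False
    # ... dispatch the actual backward pass to `engine` ...
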