Commit 9f96415
Invert ownership between PyFunction and THPFunction.
Fixes #16532 and #14960.

This patch is a massive hack. The way I constructed it was to flip the ownership between PyFunction and THPFunction, while maintaining a weak pointer from THPFunction to PyFunction so that all existing code keeps working. Essentially, this patch assumes that PyFunction stays live as long as you have a THPFunction: intuitively, this makes sense, since the ctx object should only really stay live as long as you're actually going to execute the backwards, which will keep the PyFunction live. But as you can see from the presently skipped tests (specifically, test_hook_none), this is not always true. It does seem to be true for the code we care about, and that's enough for me!

Some subtleties:

- PyFunction is a C++ object that refers to a PyObject. This means it needs a custom deleter to handle deleting the PyObject, since you can't assume you have the GIL when it dies.
- The old test_gc_in_destructor failed our internal assert because we never actually ran a backwards, and thus never actually materialized the PyFunction. I'm chalking this up as "misuse of the API" and rewrote the test to not have this problem.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: b00842b
Pull Request resolved: #22983
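For orientation, the ownership shape after this patch can be sketched independently of the PyTorch sources. This is a minimal illustration only, with made-up names (GraphNodeSketch, PyCtxSketch); the real code uses PyFunction/THPFunction, AutoGIL, and the deleteFunction deleter seen in the diff below.

#include <Python.h>
#include <memory>

// C++ side: the autograd graph node owns the Python ctx object.
struct GraphNodeSketch {
  PyObject* obj;  // owning reference
  explicit GraphNodeSketch(PyObject* o) : obj(o) {}
  ~GraphNodeSketch() {
    // The last owning reference can be dropped from a thread that does not
    // hold the GIL, so acquire it before touching Python refcounts.
    PyGILState_STATE gil = PyGILState_Ensure();
    Py_DECREF(obj);
    PyGILState_Release(gil);
  }
};

// Python side: the ctx struct only weakly observes the graph node.
struct PyCtxSketch {
  std::weak_ptr<GraphNodeSketch> cdata;  // empty until a graph node is materialized
};

// Materializing a node (compare THPFunction_apply / THPFunction_do_forward):
//   Py_INCREF(py_ctx);                                    // ownership moves to C++
//   auto node = std::make_shared<GraphNodeSketch>(py_ctx);
//   ctx->cdata = node;                                     // weak back-pointer only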
1 parent cd11109 commit 9f96415

File tree: 4 files changed, +153 −42 lines

test/test_autograd.py

Lines changed: 28 additions & 1 deletion
@@ -500,6 +500,7 @@ def bw_hook(grad):
         self.assertEqual(counter[0], 1, 'bw_hook not called')
         self.assertEqual(x.grad.data, torch.ones(5, 5) * 2)

+    @unittest.skip("Currently broken, will be fixed in https://github.com/pytorch/pytorch/pull/22925")
     def test_hook_none(self):
         # WARNING: this is a test for autograd internals.
         # You should never have to use such things in your code.
@@ -1674,12 +1675,38 @@ def test_gc_in_destructor(self):
         segfault.
         """
         class CollectOnDelete(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_output):
+                return grad_output

             def __del__(self):
                 gc.collect()

         for _ in range(10):
-            Variable(torch.randn(10, 10), _grad_fn=CollectOnDelete())
+            CollectOnDelete()(torch.randn(1, requires_grad=True)).backward()
+
+    def test_call_legacy_twice(self):
+        class Id(Function):
+            def forward(self, x):
+                self.save_for_backward(x)
+                return x
+
+            def backward(self, grad_x):
+                x = self.saved_tensors
+                return x
+
+        f = Id()
+        x1 = torch.zeros(1, requires_grad=True)
+        x2 = torch.ones(1, requires_grad=True)
+        y = f(x1)
+        with warnings.catch_warnings(record=True) as w:
+            z = f(x2)
+        self.assertIn('extending-torch-autograd', str(w[0].message))
+        y.backward()
+        # Yeah, uh, this is totally nuts
+        self.assertEqual(x2.grad, x2)

     @unittest.skipIf(torch.cuda.device_count() < 2, "no multi-GPU")
     @skipIfRocm
test/test_nn.py

Lines changed: 26 additions & 0 deletions
@@ -4635,6 +4635,32 @@ def test_data_parallel_device_args(self):
         out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
         self.assertEqual(out, l(i))

+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    @skipIfRocm
+    def test_data_parallel_function_deletion(self):
+        # this test case is originated from #16532
+        def gradient_penalty(net, x):
+            output = net(x)
+            loss = torch.autograd.grad(
+                outputs=output, inputs=x,
+                grad_outputs=x.new_ones(output.size()),
+                create_graph=True, retain_graph=True)[0].mean()
+            return loss
+
+        net = nn.Linear(4, 1).cuda()
+        dpn = nn.DataParallel(net, [0, 1])
+        x = torch.ones(2, 4, requires_grad=True).cuda()
+
+        dpn.zero_grad()
+        loss = gradient_penalty(dpn, x)
+        loss.backward()
+        grads = [p.grad for p in net.parameters()]
+        self.assertEqual(2, len(grads))
+        self.assertEqual(
+            torch.tensor([[0.25, 0.25, 0.25, 0.25]], device='cuda:0'),
+            grads[0])
+        self.assertEqual(torch.tensor([0.0], device='cuda:0'), grads[1])
+
     def test_state_dict(self):
         l = nn.Linear(5, 5)
         block = nn.Module()
torch/csrc/autograd/python_function.cpp

Lines changed: 84 additions & 36 deletions
@@ -222,7 +222,7 @@ auto PyFunction::name() const -> std::string {
   AutoGIL gil;
   auto f = (THPFunction*) obj;
   auto name = std::string(Py_TYPE(f)->tp_name);
-  THPObjectPtr _legacy(PyObject_GetAttrString(obj, "_is_legacy"));
+  THPObjectPtr _legacy(PyObject_GetAttrString((PyObject*)obj, "_is_legacy")); // CONST?!
   if (_legacy == Py_True) {
     name += "LegacyBackward";
   }
@@ -238,14 +238,19 @@ auto PyFunction::get_shared_ptr() -> std::shared_ptr<Function> {
 // Traverse and clear are required for supporting Python's GC cycle handling.
 static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg)
 {
-  for (const auto& hook : self->cdata.pre_hooks()) {
-    if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
-      Py_VISIT(pyhook->dict);
+  auto cdata = self->cdata.lock();
+  // cdata could be null if someone constructed a legacy function but haven't
+  // actually called backward() on it yet.
+  if (cdata) {
+    for (const auto& hook : cdata->pre_hooks()) {
+      if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
+        Py_VISIT(pyhook->dict);
+      }
     }
-  }
-  for (const auto& hook : self->cdata.post_hooks()) {
-    if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
-      Py_VISIT(pyhook->dict);
+    for (const auto& hook : cdata->post_hooks()) {
+      if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
+        Py_VISIT(pyhook->dict);
+      }
     }
   }
   Py_VISIT(self->to_save);
@@ -256,7 +261,11 @@ static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg)

 static int THPFunction_clear(THPFunction *self)
 {
-  self->cdata.clear_input_metadata();
+  // Why is this guaranteed to be true? Suppose that self->cdata is non-null
+  // (otherwise the condition is trivially true). Then there is a PyFunction
+  // which contains an owning reference to this object. But we are only
+  // allowed to clear if all owning references are gone! Contradiction.
+  TORCH_INTERNAL_ASSERT(!self->cdata.lock());

   Py_CLEAR(self->needs_input_grad);

@@ -269,22 +278,14 @@ static int THPFunction_clear(THPFunction *self)
   self->saved_variables.clear();
   self->is_variable_input.clear();

-  // Moving the hooks out makes sure to first disassociate them from the
-  // function, but without destroying any of them. They will get deleted when
-  // exiting this scope. This is important, because deleting Python objects can
-  // trigger deletion of other objects, and they can reference this function,
-  // seeing it in a half-deleted state.
-  auto pre_hooks = std::move(self->cdata.pre_hooks());
-  auto post_hooks = std::move(self->cdata.post_hooks());
-
   return 0;
 }

 static void THPFunction_dealloc(THPFunction* self)
 {
   PyObject_GC_UnTrack(self);
   THPFunction_clear(self);
-  self->cdata.~PyFunction();
+  self->cdata.~weak_ptr<PyFunction>();
   self->output_info.~vector();
   self->input_info.~vector();
   self->saved_variables.~vector();
@@ -299,7 +300,8 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
   // Python zero-initializes the object memory, so there's no need to initialize
   // most fields
   THPFunction* self = (THPFunction*)obj;
-  new (&self->cdata) PyFunction(obj);
+  // Setup the PyFunction later; we can't keep it live here
+  new (&self->cdata) std::weak_ptr<PyFunction>();
   new (&self->output_info) std::vector<VariableInfo>();
   new (&self->input_info) std::vector<VariableInfo>();
   new (&self->saved_variables) std::vector<SavedVariable>();
@@ -411,15 +413,16 @@ static void _save_variables(THPFunction* self)
   Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
   self->saved_variables.clear();
   self->saved_variables.reserve(num_saved);
-  auto cdata_ptr = &self->cdata;
+  auto cdata_ptr = self->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata_ptr);
   for (int i = 0; i < num_saved; i++) {
     PyObject *obj = PyTuple_GET_ITEM(self->to_save, i);
     if (obj == Py_None) {
       self->saved_variables.emplace_back();
       continue;
     } else if (THPVariable_Check(obj)) {
       auto variable = (THPVariable*)obj;
-      bool is_output = variable->cdata.grad_fn().get() == cdata_ptr;
+      bool is_output = variable->cdata.grad_fn().get() == cdata_ptr.get();
       self->saved_variables.emplace_back(variable->cdata, is_output);
     } else {
       throw TypeError(
@@ -588,7 +591,9 @@ PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const Unpacked
   THPObjectPtr outputs(PyTuple_New(num_outputs));
   if (!outputs) throw python_error();

-  grad_fn->cdata.clear_input_metadata();
+  auto cdata = grad_fn->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata);
+  cdata->clear_input_metadata();

   // Record type, device, and size information about inputs
   if (is_executable) {
@@ -635,7 +640,36 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs)
   auto& unpacked_input = info_pair.first;
   auto& input_info = info_pair.second;
   bool is_executable = input_info.is_executable;
-  self->cdata.set_next_edges(std::move(input_info.next_edges));
+  Py_INCREF(self);
+  // Needs to be an owning reference to keep it live until the end
+  // of this function, since THPFunction won't keep it live (eventually,
+  // we'll take out an owning reference when we process_outputs).
+  std::shared_ptr<PyFunction> cdata;
+  if (cdata = self->cdata.lock()) {
+    // In some pathological cases, self->cdata can already be set on entry to
+    // this function. This occurs on misuse of the legacy autograd API in the
+    // following way:
+    //
+    //    f = MyFunction()
+    //    y1 = f(x1)
+    //    y2 = f(x2)  # bad!!
+    //
+    // Historically, we did something very nutty: we set y1.grad_fn ==
+    // y2.grad_fn (even though these variables really have nothing to do with
+    // each other.) At least now we have a warning. All of this hoo-ha will
+    // go away when we delete the implementation of legacy autograd.
+    TORCH_WARN(
+        "Legacy autograd function object was called twice. You will probably "
+        "get incorrect gradients from this computation, as the saved tensors "
+        "from the second invocation will clobber the saved tensors from the "
+        "first invocation. Please consider rewriting your autograd function "
+        "in the modern style; for information on the new format, please see: "
+        "https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd");
+  } else {
+    cdata = std::shared_ptr<PyFunction>(new PyFunction(THPObjectPtr((PyObject*)self)), deleteFunction);
+    self->cdata = cdata;
+  }
+  cdata->set_next_edges(std::move(input_info.next_edges));
   self->needs_input_grad = input_info.needs_input_grad.release();

   // We don't support tracing in the legacy code path
@@ -670,6 +704,9 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
   if (!ctx_obj) return nullptr;
   THPFunction* ctx = (THPFunction*)ctx_obj.get();

+  auto cdata = std::shared_ptr<PyFunction>(new PyFunction(std::move(ctx_obj)), deleteFunction);
+  ctx->cdata = cdata;
+
   // Prepare inputs and allocate context (grad fn)
   auto info_pair = unpack_input<false>(inputs);
   UnpackedInput& unpacked_input = info_pair.first;
@@ -680,14 +717,15 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)

   // Initialize backward function (and ctx)
   bool is_executable = input_info.is_executable;
-  ctx->cdata.set_next_edges(std::move(input_info.next_edges));
+  cdata->set_next_edges(std::move(input_info.next_edges));
   ctx->needs_input_grad = input_info.needs_input_grad.release();
   ctx->is_variable_input = std::move(input_info.is_variable_input);

   // Prepend ctx to input_tuple, in preparation for static method call
   auto num_args = PyTuple_GET_SIZE(inputs);
   THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
-  PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, ctx_obj.release());
+  Py_INCREF(ctx);
+  PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx);
   for (int i = 0; i < num_args; ++i) {
     PyObject *arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i);
     Py_INCREF(arg);
@@ -749,7 +787,9 @@ static void _prepare_grads(THPFunction *self, THPObjectPtr& raw_grads, bool is_g
 static void _trim_grad_input(THPFunction *self, THPObjectPtr& grad_input)
 {
   int num_grads = PyTuple_GET_SIZE(grad_input.get());
-  const int num_outputs = self->cdata.num_outputs();
+  auto cdata = self->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata);
+  const int num_outputs = cdata->num_outputs();
   if (num_grads > num_outputs) {
     // Check that all extra grads are none
     bool all_none = true;
@@ -777,9 +817,11 @@ PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args)
     THPUtils_invalidArguments(args, nullptr, "_do_backward", 1, "(tuple, bool)");
     return nullptr;
   }
-  THPUtils_assert(PyTuple_GET_SIZE(raw_grad_output) == self->cdata.num_inputs(),
+  auto cdata = self->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata);
+  THPUtils_assert(PyTuple_GET_SIZE(raw_grad_output) == cdata->num_inputs(),
                   "%s got an invalid number of gradients (expected %d got %d)",
-                  THPUtils_typename(self), self->cdata.num_inputs(),
+                  THPUtils_typename(self), cdata->num_inputs(),
                   PyTuple_GET_SIZE(raw_grad_output));

   // Some of the output might have been unused, so we have to allocate
@@ -800,7 +842,7 @@ PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args)
   // if and only if the additional ones are all None
   _trim_grad_input(self, grad_input);
   int num_grads = PyTuple_GET_SIZE(grad_input.get());
-  int num_outputs = self->cdata.num_outputs();
+  int num_outputs = cdata->num_outputs();
   THPUtils_assert(num_grads == num_outputs, "%s returned an invalid number of "
                   "gradient tensors (expected %d, but got %d)", THPUtils_typename(self),
                   num_outputs, num_grads);
@@ -827,13 +869,17 @@ PyObject* THPFunction__register_hook_dict(THPFunction *self, PyObject *_var)
   THPVariable *var = (THPVariable*)_var;
   std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook(
       var->backward_hooks, var->cdata.output_nr()));
-  self->cdata.add_pre_hook(std::move(hook));
+  auto cdata = self->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata);
+  cdata->add_pre_hook(std::move(hook));
   Py_RETURN_NONE;
 }

 PyObject* THPFunction_register_hook(THPFunction *self, PyObject *hook)
 {
-  return torch::autograd::registerFunctionHook(self->cdata, hook);
+  auto cdata = self->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata);
+  return torch::autograd::registerFunctionHook(*cdata, hook);
 }

 static PyObject *unpack_saved_variables(
@@ -887,14 +933,16 @@ PyObject *THPFunction_saved_variables(THPFunction *self, void *_unused)

 PyObject *THPFunction_next_functions(THPFunction *self, void *_unused)
 {
-  const auto num_outputs = self->cdata.num_outputs();
+  auto cdata = self->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata);
+  const auto num_outputs = cdata->num_outputs();
   THPObjectPtr result(PyTuple_New(num_outputs));
   if (!result)
     return nullptr;
   for (uint32_t i = 0; i < num_outputs; i++) {
     THPObjectPtr fn_tuple(PyTuple_New(2));
     if (!fn_tuple) return nullptr;
-    const auto& edge = self->cdata.next_edge(i);
+    const auto& edge = cdata->next_edge(i);
     PyObject* fn = functionToPyObject(edge.function);
     if (!fn) return nullptr;
     PyTuple_SET_ITEM(fn_tuple.get(), 0, fn);
@@ -906,7 +954,8 @@ PyObject *THPFunction_next_functions(THPFunction *self, void *_unused)

 PyObject *THPFunction_metadata(THPFunction *self, void *_unused)
 {
-  auto metadata = static_cast<PyAnomalyMetadata*>(self->cdata.metadata())->dict();
+  auto cdata = self->cdata.lock();
+  auto metadata = static_cast<PyAnomalyMetadata*>(cdata->metadata())->dict();

   Py_INCREF(metadata);
   return metadata;
@@ -1051,6 +1100,5 @@ std::shared_ptr<PyFunction> THPFunction_asFunction(THPFunction* self)
     return std::shared_ptr<PyFunction>();
   }

-  Py_INCREF((PyObject*)self);
-  return std::shared_ptr<PyFunction>(&self->cdata, Decref());
+  return self->cdata.lock();
 }

torch/csrc/autograd/python_function.h

Lines changed: 15 additions & 5 deletions
@@ -33,7 +33,7 @@ struct VariableInfo {
 // A Function which is implemented by a Python object (i.e., a THPFunction).
 // Calls to 'apply' are forwarded to the Python method implementation.
 struct PyFunction : public Function {
-  PyFunction(PyObject* obj) : obj(obj) {}
+  PyFunction(THPObjectPtr obj) : obj(obj.release()) {}

   variable_list apply(variable_list&& inputs) override;
   variable_list legacy_apply(const variable_list& inputs);
@@ -43,8 +43,16 @@ struct PyFunction : public Function {
   std::shared_ptr<Function> get_shared_ptr() override;
   bool is_traceable() override;

-  // THPFunction this Function is wrapping.
+  // THPFunction this Function is wrapping. Owning!
   PyObject* obj;
+
+  ~PyFunction() {
+    // Can't use THPObjectPtr as a field in this class; destructor won't take
+    // out GIL! When I forgot to do this by hand
+    // TestAutograd.test_inplace_view_python called me out about it.
+    AutoGIL g;
+    Py_DECREF(obj);
+  }
 };

 /**
@@ -89,9 +97,11 @@ struct THPFunction {
   std::vector<bool> is_variable_input;
   char has_freed_buffers;

-  // The C++ wrapper for this Python function.
-  // See a comment in THPFunction_asFunction for details about this field.
-  torch::autograd::PyFunction cdata;
+  // The actual PyFunction (in the autograd graph) that this data was
+  // saved for. This field may be NULL (because a user can construct
+  // a THPFunction directly from Python), but when this field is non-NULL,
+  // it is guaranteed that cdata.lock()->obj == this
+  std::weak_ptr<torch::autograd::PyFunction> cdata;
 };

 bool THPFunction_initModule(PyObject *module);