@@ -222,7 +222,7 @@ auto PyFunction::name() const -> std::string {
   AutoGIL gil;
   auto f = (THPFunction*) obj;
   auto name = std::string(Py_TYPE(f)->tp_name);
-  THPObjectPtr _legacy(PyObject_GetAttrString(obj, "_is_legacy"));
+  THPObjectPtr _legacy(PyObject_GetAttrString((PyObject*)obj, "_is_legacy")); // CONST?!
   if (_legacy == Py_True) {
     name += "LegacyBackward";
   }
@@ -238,14 +238,19 @@ auto PyFunction::get_shared_ptr() -> std::shared_ptr<Function> {
 // Traverse and clear are required for supporting Python's GC cycle handling.
 static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg)
 {
-  for (const auto& hook : self->cdata.pre_hooks()) {
-    if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
-      Py_VISIT(pyhook->dict);
+  auto cdata = self->cdata.lock();
+  // cdata could be null if someone constructed a legacy function but hasn't
+  // actually called backward() on it yet.
+  if (cdata) {
+    for (const auto& hook : cdata->pre_hooks()) {
+      if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
+        Py_VISIT(pyhook->dict);
+      }
     }
-  }
-  for (const auto& hook : self->cdata.post_hooks()) {
-    if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
-      Py_VISIT(pyhook->dict);
+    for (const auto& hook : cdata->post_hooks()) {
+      if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
+        Py_VISIT(pyhook->dict);
+      }
     }
   }
   Py_VISIT(self->to_save);
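The pattern introduced here, lock the `weak_ptr` and then branch on whether the promotion succeeded, recurs throughout this diff. A minimal standalone sketch of the idiom; the names `Node` and `visit_hooks` are illustrative, not from the PyTorch source:

```cpp
#include <iostream>
#include <memory>

struct Node {
  int id;
};

void visit_hooks(const std::weak_ptr<Node>& weak) {
  // lock() atomically promotes the weak reference to an owning shared_ptr;
  // it returns an empty shared_ptr if the Node has already been destroyed.
  if (auto node = weak.lock()) {
    std::cout << "node " << node->id << " is alive\n";
  } else {
    std::cout << "node is gone; skip traversal\n";
  }
}

int main() {
  std::weak_ptr<Node> weak;
  {
    auto strong = std::make_shared<Node>(Node{42});
    weak = strong;
    visit_hooks(weak);  // prints "node 42 is alive"
  }                     // strong goes out of scope; Node is destroyed
  visit_hooks(weak);    // prints "node is gone; skip traversal"
}
```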
@@ -256,7 +261,11 @@ static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg)
 
 static int THPFunction_clear(THPFunction *self)
 {
-  self->cdata.clear_input_metadata();
+  // Why is this guaranteed to be true? Suppose that self->cdata is non-null
+  // (otherwise the condition is trivially true). Then there is a PyFunction
+  // which contains an owning reference to this object. But we are only
+  // allowed to clear if all owning references are gone! Contradiction.
+  TORCH_INTERNAL_ASSERT(!self->cdata.lock());
 
   Py_CLEAR(self->needs_input_grad);
 
@@ -269,22 +278,14 @@ static int THPFunction_clear(THPFunction *self)
   self->saved_variables.clear();
   self->is_variable_input.clear();
 
-  // Moving the hooks out makes sure to first disassociate them from the
-  // function, but without destroying any of them. They will get deleted when
-  // exiting this scope. This is important, because deleting Python objects can
-  // trigger deletion of other objects, and they can reference this function,
-  // seeing it in a half-deleted state.
-  auto pre_hooks = std::move(self->cdata.pre_hooks());
-  auto post_hooks = std::move(self->cdata.post_hooks());
-
   return 0;
 }
 
 static void THPFunction_dealloc(THPFunction* self)
 {
   PyObject_GC_UnTrack(self);
   THPFunction_clear(self);
-  self->cdata.~PyFunction();
+  self->cdata.~weak_ptr<PyFunction>();
   self->output_info.~vector();
   self->input_info.~vector();
   self->saved_variables.~vector();
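Because CPython allocates the struct as raw zeroed memory, `tp_new` and `tp_dealloc` must construct and destroy the C++ members by hand; that is why the destructor call above changes from `~PyFunction()` to `~weak_ptr<PyFunction>()` when the field's type changes. A minimal sketch of that placement-new/explicit-destructor pattern, under illustrative names (`MyObject`, `Payload`):

```cpp
#include <memory>
#include <new>

struct Payload { int x; };

struct MyObject {
  // PyObject_HEAD would go here in a real extension type.
  std::weak_ptr<Payload> cdata;  // raw storage until placement-new runs
};

int main() {
  // Emulate CPython's tp_alloc, which hands back zero-filled memory.
  alignas(MyObject) unsigned char storage[sizeof(MyObject)] = {};
  auto* self = reinterpret_cast<MyObject*>(storage);

  // tp_new: placement-new each C++ member into the zeroed memory.
  new (&self->cdata) std::weak_ptr<Payload>();

  // tp_dealloc: destroy members explicitly; naming the destructor of a
  // class template specialization requires the template argument, exactly
  // as in the diff's ~weak_ptr<PyFunction>() call.
  self->cdata.~weak_ptr<Payload>();
}
```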
@@ -299,7 +300,8 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
   // Python zero-initializes the object memory, so there's no need to initialize
   // most fields
   THPFunction* self = (THPFunction*)obj;
-  new (&self->cdata) PyFunction(obj);
+  // Set up the PyFunction later; we can't keep it live here
+  new (&self->cdata) std::weak_ptr<PyFunction>();
   new (&self->output_info) std::vector<VariableInfo>();
   new (&self->input_info) std::vector<VariableInfo>();
   new (&self->saved_variables) std::vector<SavedVariable>();
@@ -411,15 +413,16 @@ static void _save_variables(THPFunction* self)
   Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
   self->saved_variables.clear();
   self->saved_variables.reserve(num_saved);
-  auto cdata_ptr = &self->cdata;
+  auto cdata_ptr = self->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata_ptr);
   for (int i = 0; i < num_saved; i++) {
     PyObject *obj = PyTuple_GET_ITEM(self->to_save, i);
     if (obj == Py_None) {
       self->saved_variables.emplace_back();
       continue;
     } else if (THPVariable_Check(obj)) {
       auto variable = (THPVariable*)obj;
-      bool is_output = variable->cdata.grad_fn().get() == cdata_ptr;
+      bool is_output = variable->cdata.grad_fn().get() == cdata_ptr.get();
       self->saved_variables.emplace_back(variable->cdata, is_output);
     } else {
       throw TypeError(
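The added `.get()` on the right-hand side is needed because `cdata_ptr` is now a `shared_ptr` rather than a raw pointer, and there is no `operator==` mixing `shared_ptr<T>` with `T*`; identity is compared on the raw pointers. A tiny illustration:

```cpp
#include <memory>

struct Node {};

int main() {
  auto owner = std::make_shared<Node>();
  Node* raw = owner.get();

  // bool same = (raw == owner);    // ill-formed: no operator==(T*, shared_ptr<T>)
  bool same = (raw == owner.get()); // fine: compares raw pointer identity
  return same ? 0 : 1;
}
```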
@@ -588,7 +591,9 @@ PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const Unpacked
   THPObjectPtr outputs(PyTuple_New(num_outputs));
   if (!outputs) throw python_error();
 
-  grad_fn->cdata.clear_input_metadata();
+  auto cdata = grad_fn->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata);
+  cdata->clear_input_metadata();
 
   // Record type, device, and size information about inputs
   if (is_executable) {
@@ -635,7 +640,36 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs)
   auto& unpacked_input = info_pair.first;
   auto& input_info = info_pair.second;
   bool is_executable = input_info.is_executable;
-  self->cdata.set_next_edges(std::move(input_info.next_edges));
+  Py_INCREF(self);
+  // Needs to be an owning reference to keep it live until the end
+  // of this function, since THPFunction won't keep it live (eventually,
+  // we'll take out an owning reference when we process_outputs).
+  std::shared_ptr<PyFunction> cdata;
+  if ((cdata = self->cdata.lock())) {
+    // In some pathological cases, self->cdata can already be set on entry to
+    // this function. This occurs on misuse of the legacy autograd API in the
+    // following way:
+    //
+    //   f = MyFunction()
+    //   y1 = f(x1)
+    //   y2 = f(x2)  # bad!!
+    //
+    // Historically, we did something very nutty: we set y1.grad_fn ==
+    // y2.grad_fn (even though these variables really have nothing to do with
+    // each other.) At least now we have a warning. All of this hoo-ha will
+    // go away when we delete the implementation of legacy autograd.
+    TORCH_WARN(
+        "Legacy autograd function object was called twice. You will probably "
+        "get incorrect gradients from this computation, as the saved tensors "
+        "from the second invocation will clobber the saved tensors from the "
+        "first invocation. Please consider rewriting your autograd function "
+        "in the modern style; for information on the new format, please see: "
+        "https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd");
+  } else {
+    cdata = std::shared_ptr<PyFunction>(new PyFunction(THPObjectPtr((PyObject*)self)), deleteFunction);
+    self->cdata = cdata;
+  }
+  cdata->set_next_edges(std::move(input_info.next_edges));
   self->needs_input_grad = input_info.needs_input_grad.release();
 
   // We don't support tracing in the legacy code path
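The control flow above amounts to: lazily materialize an owning `shared_ptr` from a weakly cached one, warn on re-entry, and keep the owning reference alive for the rest of the function. A minimal sketch of that shape, with `Fn`, `destroy_fn`, and `get_or_create` as illustrative stand-ins for `PyFunction` and `deleteFunction` (and without the Python refcounting):

```cpp
#include <cstdio>
#include <memory>

struct Fn { int uses = 0; };

void destroy_fn(Fn* fn) {
  std::puts("custom deleter ran");
  delete fn;
}

struct Ctx {
  std::weak_ptr<Fn> cache;  // non-owning, like THPFunction::cdata
};

std::shared_ptr<Fn> get_or_create(Ctx& ctx) {
  std::shared_ptr<Fn> fn;
  if ((fn = ctx.cache.lock())) {
    // Re-entry: reuse the existing function (the PR warns in this case).
    fn->uses++;
  } else {
    fn = std::shared_ptr<Fn>(new Fn(), destroy_fn);
    ctx.cache = fn;  // cache only weakly; callers own the function
  }
  return fn;
}

int main() {
  Ctx ctx;
  auto a = get_or_create(ctx);
  auto b = get_or_create(ctx);  // same object; uses == 1
  std::printf("same: %d, uses: %d\n", a.get() == b.get(), b->uses);
}  // last owner dies here; the custom deleter runs exactly once
```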
@@ -670,6 +704,9 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
   if (!ctx_obj) return nullptr;
   THPFunction* ctx = (THPFunction*)ctx_obj.get();
 
+  auto cdata = std::shared_ptr<PyFunction>(new PyFunction(std::move(ctx_obj)), deleteFunction);
+  ctx->cdata = cdata;
+
   // Prepare inputs and allocate context (grad fn)
   auto info_pair = unpack_input<false>(inputs);
   UnpackedInput& unpacked_input = info_pair.first;
@@ -680,14 +717,15 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
 
   // Initialize backward function (and ctx)
   bool is_executable = input_info.is_executable;
-  ctx->cdata.set_next_edges(std::move(input_info.next_edges));
+  cdata->set_next_edges(std::move(input_info.next_edges));
   ctx->needs_input_grad = input_info.needs_input_grad.release();
   ctx->is_variable_input = std::move(input_info.is_variable_input);
 
   // Prepend ctx to input_tuple, in preparation for static method call
   auto num_args = PyTuple_GET_SIZE(inputs);
   THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
-  PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, ctx_obj.release());
+  Py_INCREF(ctx);
+  PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx);
   for (int i = 0; i < num_args; ++i) {
     PyObject *arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i);
     Py_INCREF(arg);
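The new `Py_INCREF(ctx)` compensates for `PyTuple_SET_ITEM` *stealing* a reference: previously the tuple consumed `ctx_obj.release()`, but now `ctx_obj` has been moved into the `PyFunction`, and `apply` still needs `ctx` alive after the call. A sketch of the stealing semantics against the plain CPython C API (build against Python headers; not PyTorch-specific):

```cpp
#include <Python.h>

int main() {
  Py_Initialize();

  PyObject* obj = PyLong_FromLong(7);  // refcount: 1 (ours)
  PyObject* tup = PyTuple_New(1);

  Py_INCREF(obj);                      // refcount: 2
  PyTuple_SET_ITEM(tup, 0, obj);       // steals one reference; the tuple owns it

  // We still hold our own reference, so obj stays valid even after the
  // tuple (and the reference it owns) goes away.
  Py_DECREF(tup);
  Py_DECREF(obj);                      // drop our remaining reference

  Py_FinalizeEx();
  return 0;
}
```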
@@ -749,7 +787,9 @@ static void _prepare_grads(THPFunction *self, THPObjectPtr& raw_grads, bool is_g
 static void _trim_grad_input(THPFunction *self, THPObjectPtr& grad_input)
 {
   int num_grads = PyTuple_GET_SIZE(grad_input.get());
-  const int num_outputs = self->cdata.num_outputs();
+  auto cdata = self->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata);
+  const int num_outputs = cdata->num_outputs();
   if (num_grads > num_outputs) {
     // Check that all extra grads are none
     bool all_none = true;
@@ -777,9 +817,11 @@ PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args)
     THPUtils_invalidArguments(args, nullptr, "_do_backward", 1, "(tuple, bool)");
     return nullptr;
   }
-  THPUtils_assert(PyTuple_GET_SIZE(raw_grad_output) == self->cdata.num_inputs(),
+  auto cdata = self->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata);
+  THPUtils_assert(PyTuple_GET_SIZE(raw_grad_output) == cdata->num_inputs(),
                   "%s got an invalid number of gradients (expected %d got %d)",
-                  THPUtils_typename(self), self->cdata.num_inputs(),
+                  THPUtils_typename(self), cdata->num_inputs(),
                   PyTuple_GET_SIZE(raw_grad_output));
 
   // Some of the output might have been unused, so we have to allocate
@@ -800,7 +842,7 @@ PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args)
   // if and only if the additional ones are all None
   _trim_grad_input(self, grad_input);
   int num_grads = PyTuple_GET_SIZE(grad_input.get());
-  int num_outputs = self->cdata.num_outputs();
+  int num_outputs = cdata->num_outputs();
   THPUtils_assert(num_grads == num_outputs, "%s returned an invalid number of "
                   "gradient tensors (expected %d, but got %d)", THPUtils_typename(self),
                   num_outputs, num_grads);
@@ -827,13 +869,17 @@ PyObject* THPFunction__register_hook_dict(THPFunction *self, PyObject *_var)
   THPVariable *var = (THPVariable*)_var;
   std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook(
       var->backward_hooks, var->cdata.output_nr()));
-  self->cdata.add_pre_hook(std::move(hook));
+  auto cdata = self->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata);
+  cdata->add_pre_hook(std::move(hook));
   Py_RETURN_NONE;
 }
 
 PyObject* THPFunction_register_hook(THPFunction *self, PyObject *hook)
 {
-  return torch::autograd::registerFunctionHook(self->cdata, hook);
+  auto cdata = self->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata);
+  return torch::autograd::registerFunctionHook(*cdata, hook);
 }
 
 static PyObject *unpack_saved_variables(
@@ -887,14 +933,16 @@ PyObject *THPFunction_saved_variables(THPFunction *self, void *_unused)
 
 PyObject *THPFunction_next_functions(THPFunction *self, void *_unused)
 {
-  const auto num_outputs = self->cdata.num_outputs();
+  auto cdata = self->cdata.lock();
+  TORCH_INTERNAL_ASSERT(cdata);
+  const auto num_outputs = cdata->num_outputs();
   THPObjectPtr result(PyTuple_New(num_outputs));
   if (!result)
     return nullptr;
   for (uint32_t i = 0; i < num_outputs; i++) {
     THPObjectPtr fn_tuple(PyTuple_New(2));
     if (!fn_tuple) return nullptr;
-    const auto& edge = self->cdata.next_edge(i);
+    const auto& edge = cdata->next_edge(i);
     PyObject* fn = functionToPyObject(edge.function);
     if (!fn) return nullptr;
     PyTuple_SET_ITEM(fn_tuple.get(), 0, fn);
@@ -906,7 +954,8 @@ PyObject *THPFunction_next_functions(THPFunction *self, void *_unused)
 
 PyObject *THPFunction_metadata(THPFunction *self, void *_unused)
 {
-  auto metadata = static_cast<PyAnomalyMetadata*>(self->cdata.metadata())->dict();
+  auto cdata = self->cdata.lock();
+  auto metadata = static_cast<PyAnomalyMetadata*>(cdata->metadata())->dict();
 
   Py_INCREF(metadata);
   return metadata;
@@ -1051,6 +1100,5 @@ std::shared_ptr<PyFunction> THPFunction_asFunction(THPFunction* self)
     return std::shared_ptr<PyFunction>();
   }
 
-  Py_INCREF((PyObject*)self);
-  return std::shared_ptr<PyFunction>(&self->cdata, Decref());
+  return self->cdata.lock();
 }
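The deleted `Py_INCREF` + `Decref()` trick built a fresh `shared_ptr` control block around `&self->cdata` on every call; independent control blocks around one object mean `use_count` and `weak_ptr` expiry cannot agree, which is exactly what the weak_ptr redesign removes. A sketch of the hazard in isolation (illustrative only; the old PyTorch code additionally kept the object alive via Python refcounting):

```cpp
#include <iostream>
#include <memory>

struct Fn {};

int main() {
  auto canonical = std::make_shared<Fn>();

  // Old-style wrapper: a second, unrelated control block around the same raw
  // pointer, with a no-op deleter so it never double-frees.
  std::shared_ptr<Fn> wrapper(canonical.get(), [](Fn*) {});

  std::weak_ptr<Fn> weak = canonical;
  std::cout << canonical.use_count() << " vs " << wrapper.use_count() << "\n";  // 1 vs 1

  canonical.reset();
  // The object is destroyed even though `wrapper` still claims ownership:
  std::cout << "expired: " << weak.expired()
            << ", wrapper non-null: " << (wrapper != nullptr) << "\n";  // 1, 1 (dangling!)
}
```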