Skip to content

Commit 86eeeab

Browse files
vishwakftwfacebook-github-bot
authored andcommitted
Fix segmentation fault in grad_fn (#9292)
Summary: Fixes #8774 . Reviewed By: soumith Differential Revision: D8836478 Pulled By: apaszke fbshipit-source-id: f113bf47fe493be9f095a5a5490caf08dbb44e38
1 parent bcd20f9 commit 86eeeab

File tree

4 files changed

+18
-8
lines changed

4 files changed

+18
-8
lines changed

torch/autograd/function.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,13 @@ def wrapper(ctx, *args):
204204
if not requires_grad:
205205
return outputs
206206

207-
err_fn = torch._C._functions.DelayedError(
208-
b"trying to differentiate twice a function that was marked"
209-
b"with @once_differentiable")
210-
211207
if not isinstance(outputs, tuple):
212208
outputs = (outputs,)
213209

210+
err_fn = torch._C._functions.DelayedError(
211+
b"trying to differentiate twice a function that was marked"
212+
b"with @once_differentiable", len(outputs))
213+
214214
# Create aliases of each output that has requires_grad=True. We need
215215
# at least one of the inputs to err_fn to require grad so that the
216216
# output will have a grad_fn.

torch/csrc/autograd/functions/basic_ops.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@ struct Error : public Function {
2525

2626
// Identity in forward, Error in backward. Used to implement @once_differentiable
2727
struct DelayedError : public Function {
28-
DelayedError(std::string msg)
29-
: msg(std::move(msg)) {};
28+
DelayedError(std::string msg, int num_inputs)
29+
: msg(std::move(msg)) {
30+
for (int i = 0; i < num_inputs; i++)
31+
add_input_metadata(Function::undefined_input());
32+
}
3033

3134
variable_list apply(variable_list&& inputs) override;
3235

torch/csrc/autograd/functions/init.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ using torch::TupleParser;
1515
struct DelayedErrorCtor {
1616
DelayedError* operator()(PyObject* args) {
1717
std::string msg;
18+
int num_inputs;
1819

19-
TupleParser parser(args, 1);
20+
TupleParser parser(args, 2);
2021
parser.parse(msg, "msg");
22+
parser.parse(num_inputs, "num_inputs");
2123

22-
return new DelayedError(msg);
24+
return new DelayedError(msg, num_inputs);
2325
}
2426
};
2527

torch/csrc/autograd/python_cpp_function.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ PyObject* THPCppFunction_call(PyObject* self, PyObject* args, PyObject *kwargs)
2727
}
2828

2929
int num_inputs = PyTuple_GET_SIZE(args);
30+
int num_inputs_required = ((THPCppFunction*)self)->cdata->num_inputs();
31+
if (num_inputs != num_inputs_required) {
32+
return PyErr_Format(PyExc_TypeError, "expected %d arguments, got %d instead",
33+
num_inputs_required, num_inputs);
34+
}
3035
variable_list vars(num_inputs);
3136
for (int i = 0; i != num_inputs; ++i) {
3237
PyObject* arg = PyTuple_GET_ITEM(args, i);

0 commit comments

Comments
 (0)