Skip to content
8 changes: 4 additions & 4 deletions torch/autograd/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,13 @@ def wrapper(ctx, *args):
if not requires_grad:
return outputs

err_fn = torch._C._functions.DelayedError(
b"trying to differentiate twice a function that was marked"
b"with @once_differentiable")

if not isinstance(outputs, tuple):
outputs = (outputs,)

err_fn = torch._C._functions.DelayedError(
b"trying to differentiate twice a function that was marked"
b"with @once_differentiable", len(outputs))

# Create aliases of each output that has requires_grad=True. We need
# at least one of the inputs to err_fn to require grad so that the
# output will have a grad_fn.
Expand Down
7 changes: 5 additions & 2 deletions torch/csrc/autograd/functions/basic_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ struct Error : public Function {

// Identity in forward, Error in backward. Used to implement @once_differentiable
struct DelayedError : public Function {
DelayedError(std::string msg)
: msg(std::move(msg)) {};
DelayedError(std::string msg, int num_inputs)
: msg(std::move(msg)) {
for (int i = 0; i < num_inputs; i++)
add_input_metadata(Function::undefined_input());
}

virtual variable_list apply(const variable_list& inputs) override;

Expand Down
6 changes: 4 additions & 2 deletions torch/csrc/autograd/functions/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ using torch::TupleParser;
struct DelayedErrorCtor {
DelayedError* operator()(PyObject* args) {
std::string msg;
int num_inputs;

TupleParser parser(args, 1);
TupleParser parser(args, 2);
parser.parse(msg, "msg");
parser.parse(num_inputs, "num_inputs");

return new DelayedError(msg);
return new DelayedError(msg, num_inputs);
}
};

Expand Down
5 changes: 5 additions & 0 deletions torch/csrc/autograd/python_cpp_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ PyObject* THPCppFunction_call(PyObject* self, PyObject* args, PyObject *kwargs)
}

int num_inputs = PyTuple_GET_SIZE(args);
int num_inputs_required = ((THPCppFunction*)self)->cdata->num_inputs();
if (num_inputs != num_inputs_required) {
return PyErr_Format(PyExc_TypeError, "expected %d arguments, got %d instead",
num_inputs_required, num_inputs);
}
variable_list vars(num_inputs);
for (int i = 0; i != num_inputs; ++i) {
PyObject* arg = PyTuple_GET_ITEM(args, i);
Expand Down