
Commit eb314f9

zou3519 authored and pytorchmergebot committed
Add setup_context staticmethod to autograd.Function (#89859)
Adds a setup_context staticmethod to autograd.Function. If it exists, the user splits the ctx-specific logic out of forward() and puts it into the setup_context staticmethod. Docs will come later, when we remove the feature flag.

Test Plan:
- some light tests

Pull Request resolved: #89859
Approved by: https://github.com/soulitzer
1 parent 103be1f commit eb314f9
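
For reference, a minimal sketch of the split described in the commit message, modeled on the tests added below. MyExp is a hypothetical example class, not part of this commit, and the new calling convention only runs under the private feature flag the commit adds:

import torch
from torch.autograd import Function

class MyExp(Function):  # hypothetical example class
    @staticmethod
    def forward(x):
        # New-style forward: pure computation, no ctx argument.
        return x.exp()

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        # ctx-specific logic (saving tensors, stashing attributes) moves here.
        ctx.save_for_backward(outputs)

    @staticmethod
    def backward(ctx, gO):
        result, = ctx.saved_tensors
        return gO * result  # d/dx exp(x) = exp(x)

with torch.autograd.function._set_autograd_function_extension_enabled(True):
    x = torch.randn(3, requires_grad=True)
    y = MyExp.apply(x)
    gx, = torch.autograd.grad(y.sum(), x)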

3 files changed: +174 −15 lines changed


test/test_autograd.py

Lines changed: 88 additions & 0 deletions
@@ -544,6 +544,94 @@ def fn(x):
         with self.assertRaisesRegex(RuntimeError, "expects an grad_fn"):
             torch._C._will_engine_execute_node(out)
 
+    def test_custom_function_setup_context_simple(self):
+        class MySquare(Function):
+            @staticmethod
+            def forward(x):
+                return x ** 2
+
+            @staticmethod
+            def setup_context(ctx, inputs, outputs):
+                x, = inputs
+                ctx.save_for_backward(x)
+
+            @staticmethod
+            def backward(ctx, gO):
+                x, = ctx.saved_tensors
+                return gO * 2 * x
+
+        with torch.autograd.function._set_autograd_function_extension_enabled(True):
+            x = torch.randn([], requires_grad=True)
+            y = MySquare.apply(x)
+            gx, = torch.autograd.grad(y, x)
+            self.assertEqual(gx, 2 * x)
+
+    def test_custom_function_setup_context_multi_output(self):
+        # Multiple outputs with some non-Tensor outputs.
+        class MySquare(Function):
+            @staticmethod
+            def forward(x):
+                two_x = x.item() * 2
+                return x ** 2, two_x
+
+            @staticmethod
+            def setup_context(ctx, inputs, outputs):
+                x, = inputs
+                _, two_x = outputs
+                ctx.two_x = two_x
+
+            @staticmethod
+            @once_differentiable
+            def backward(ctx, gO, _):
+                return gO * ctx.two_x
+
+        with torch.autograd.function._set_autograd_function_extension_enabled(True):
+            x = torch.randn([], requires_grad=True)
+            y, _ = MySquare.apply(x)
+            gx, = torch.autograd.grad(y, x)
+            self.assertEqual(gx, 2 * x)
+
+    def test_custom_function_setup_context_multi_input(self):
+        class MyReshape(Function):
+            @staticmethod
+            def forward(x, shape, scale_forward, scale_backward):
+                return x.reshape(shape) * scale_forward
+
+            @staticmethod
+            def setup_context(ctx, inputs, outputs):
+                x, shape, scale_forward, scale_backward = inputs
+                ctx.scale_backward = scale_backward
+                ctx.x_shape = x.shape
+
+            @staticmethod
+            def backward(ctx, gO):
+                return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None
+
+        class MyReshapeRef(Function):
+            @staticmethod
+            def forward(ctx, x, shape, scale_forward, scale_backward):
+                ctx.scale_backward = scale_backward
+                ctx.x_shape = x.shape
+                return x.reshape(shape) * scale_forward
+
+            @staticmethod
+            def backward(ctx, gO):
+                return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None
+
+        def test(x, shape, scale_forward, scale_backward):
+            y = MyReshape.apply(x, shape, scale_forward, scale_backward).sum()
+            gx, = torch.autograd.grad(y, x)
+
+            y_expected = MyReshapeRef.apply(x, shape, scale_forward, scale_backward).sum()
+            gx_expected, = torch.autograd.grad(y_expected, x)
+
+            self.assertEqual(y_expected, y)
+            self.assertEqual(gx_expected, gx)
+
+        with torch.autograd.function._set_autograd_function_extension_enabled(True):
+            test(torch.randn(24, requires_grad=True), (3, 8), 7, 11)
+            test(torch.randn(2, 3, 4, requires_grad=True), (6, 4), -1, 2)
+
     def test_accumulate_grad(self):
         grad_output = torch.ones(5, 5)

torch/autograd/function.py

Lines changed: 12 additions & 0 deletions
@@ -1,3 +1,4 @@
+import contextlib
 import torch
 import torch._C as _C
 from torch._C import _functions
@@ -468,6 +469,17 @@ def traceable(fn_cls):
     return fn_cls
 
 
+# Private feature flag. Not user-facing.
+@contextlib.contextmanager
+def _set_autograd_function_extension_enabled(enabled=True):
+    try:
+        prev_state = torch._C._is_autograd_function_extension_enabled()
+        torch._C._set_autograd_function_extension_enabled(enabled)
+        yield
+    finally:
+        torch._C._set_autograd_function_extension_enabled(prev_state)
+
+
 class InplaceFunction(Function):
 
     def __init__(self, inplace=False):
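
As the tests above show, callers opt in via this context manager. A small sketch of its save-and-restore behavior, assuming the flag defaults to off (the gated tests suggest it does):

import torch

assert not torch._C._is_autograd_function_extension_enabled()  # assumed default

with torch.autograd.function._set_autograd_function_extension_enabled(True):
    # Enabled inside the block.
    assert torch._C._is_autograd_function_extension_enabled()

# The try/finally restores the previous state on exit, even on error.
assert not torch._C._is_autograd_function_extension_enabled()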

torch/csrc/autograd/python_function.cpp

Lines changed: 74 additions & 15 deletions
@@ -821,6 +821,43 @@ PyObject* THPFunction_maybe_clear_saved_tensors(
   END_HANDLE_TH_ERRORS
 }
 
+namespace {
+
+THPObjectPtr make_ctx_input_tuple(
+    THPFunction* ctx,
+    const UnpackedInput& unpacked_input,
+    int64_t num_args) {
+  THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
+  if (!ctx_input_tuple)
+    return {};
+  Py_INCREF(ctx);
+  PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx);
+  for (const auto i : c10::irange(num_args)) {
+    PyObject* arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i);
+    Py_INCREF(arg);
+    PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg);
+  }
+  return ctx_input_tuple;
+}
+
+THPObjectPtr make_ctx_input_output_tuple(
+    THPFunction* ctx,
+    UnpackedInput& unpacked_input,
+    PyObject* outputs) {
+  THPObjectPtr result(PyTuple_New(3));
+  if (!result)
+    return {};
+  Py_INCREF(ctx);
+  Py_INCREF(unpacked_input.input_tuple.get());
+  Py_INCREF(outputs);
+  PyTuple_SET_ITEM(result.get(), 0, (PyObject*)ctx);
+  PyTuple_SET_ITEM(result.get(), 1, unpacked_input.input_tuple.get());
+  PyTuple_SET_ITEM(result.get(), 2, outputs);
+  return result;
+}
+
+} // namespace
+
 PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
   HANDLE_TH_ERRORS
 
@@ -865,29 +902,51 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
   ctx->needs_input_grad = input_info.needs_input_grad.release();
   ctx->is_variable_input = std::move(input_info.is_variable_input);
 
-  // Prepend ctx to input_tuple, in preparation for static method call
+  // autograd.Function may optionally contain a setup_context staticmethod.
+  // In this case, autograd.Function.forward does NOT accept a ctx object.
+  bool has_separate_setup_context_fn =
+      (isAutogradFunctionExtensionEnabled() &&
+       PyObject_HasAttrString(cls, "setup_context"));
+
   auto num_args = PyTuple_GET_SIZE(inputs);
-  THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
-  if (!ctx_input_tuple)
-    return nullptr;
-  Py_INCREF(ctx);
-  PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx);
-  for (const auto i : c10::irange(num_args)) {
-    PyObject* arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i);
-    Py_INCREF(arg);
-    PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg);
-  }
 
   // Call forward
-  THPObjectPtr tensor_outputs;
+  THPObjectPtr outputs;
   {
     AutoGradMode grad_mode(false);
     at::AutoFwGradMode fw_grad_mode(false);
    THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
     if (!forward_fn)
       return nullptr;
-    tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple);
-    if (!tensor_outputs)
+    if (has_separate_setup_context_fn) {
+      // call forward followed by setup_context
+      outputs = PyObject_CallObject(forward_fn, unpacked_input.input_tuple);
+      if (!outputs) {
+        return nullptr;
+      }
+      // signature is setup_context(ctx, inputs, outputs)
+      auto ctx_input_output_tuple =
+          make_ctx_input_output_tuple(ctx, unpacked_input, outputs);
+      if (!ctx_input_output_tuple) {
+        return nullptr;
+      }
+      THPObjectPtr setup_context_fn(
+          PyObject_GetAttrString(cls, "setup_context"));
+      auto result =
+          PyObject_CallObject(setup_context_fn, ctx_input_output_tuple);
+      if (!result) {
+        return nullptr;
+      }
+    } else {
+      // call forward
+      auto ctx_input_tuple =
+          make_ctx_input_tuple(ctx, unpacked_input, num_args);
+      if (!ctx_input_tuple) {
+        return nullptr;
+      }
+      outputs = PyObject_CallObject(forward_fn, ctx_input_tuple);
+    }
+    if (!outputs)
       return nullptr;
   }
 
@@ -897,7 +956,7 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
       ctx,
       unpacked_input,
       inputs,
-      std::move(tensor_outputs),
+      std::move(outputs),
       is_executable,
       node);
   END_HANDLE_TH_ERRORS
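
In rough Python pseudocode, the dispatch that THPFunction_apply now performs looks like the following; extension_enabled, make_ctx, and process_outputs are stand-in names for the surrounding C++ machinery, not real functions:

def apply(cls, *inputs):
    ctx = make_ctx()  # stand-in for the C++-side context object
    with torch.no_grad():  # mirrors AutoGradMode / AutoFwGradMode (false)
        if extension_enabled() and hasattr(cls, "setup_context"):
            # New convention: forward never sees ctx;
            # setup_context(ctx, inputs, outputs) runs right after it.
            outputs = cls.forward(*inputs)
            cls.setup_context(ctx, inputs, outputs)
        else:
            # Old convention: ctx is prepended to the forward arguments.
            outputs = cls.forward(ctx, *inputs)
    return process_outputs(ctx, inputs, outputs)  # stand-in for the epilogue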
