Commit 23eef2a

Register Saved Tensors hooks
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 42155f0
Pull Request resolved: #60663
Parent: 37aca0f

File tree: 6 files changed (+76, −3 lines)


torch/csrc/autograd/init.cpp

Lines changed: 4 additions & 1 deletion
@@ -11,6 +11,7 @@
 #include <torch/csrc/autograd/python_function.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/saved_variable.h>
+#include <torch/csrc/autograd/python_saved_variable_hooks.h>
 #include <torch/csrc/autograd/utils/wrap_outputs.h>
 #include <torch/csrc/autograd/utils/python_arg_parsing.h>
 #include <torch/csrc/utils/pycfunction_helpers.h>
@@ -261,7 +262,9 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
     }))
     .def("register_hooks", [](torch::autograd::SavedVariable &s, py::function &pack_hook, py::function &unpack_hook) {
       s.register_hooks();
-    });
+      // s.register_hook(std::make_unique<torch::autograd::PySavedVariableHooks>(pack_hook, unpack_hook));
+      // s.register_hook(std::make_unique<torch::autograd::PySavedVariableHooks>(pack_hook.release().ptr(), unpack_hook.release().ptr()));
+    });

   Py_RETURN_TRUE;
 }
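As committed, the lambda ignores its pack_hook and unpack_hook arguments and calls the argument-less register_hooks(); the two commented-out lines (note they say register_hook, singular) experiment with handing the Python callables over to a PySavedVariableHooks. A sketch of where this appears to be headed, assuming the register_hooks overload that is still commented out in saved_variable.h eventually lands:

// Sketch only, not what this commit does: assumes SavedVariable gains the
// commented-out register_hooks(std::unique_ptr<SavedVariableHooks>&&) overload.
.def("register_hooks", [](torch::autograd::SavedVariable &s, py::function &pack_hook, py::function &unpack_hook) {
  // release().ptr() leaks each py::function into a raw PyObject*,
  // transferring ownership of the reference to PySavedVariableHooks.
  s.register_hooks(std::make_unique<torch::autograd::PySavedVariableHooks>(
      pack_hook.release().ptr(), unpack_hook.release().ptr()));
});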
torch/csrc/autograd/python_saved_variable_hooks.cpp

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+#include <torch/csrc/autograd/python_saved_variable_hooks.h>
+#include <torch/csrc/autograd/python_variable.h>
+
+namespace torch { namespace autograd {
+// PySavedVariableHooks::PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook) : pack_hook_(pack_hook), unpack_hook_(unpack_hook){};
+PySavedVariableHooks::PySavedVariableHooks(PyObject *pack_hook, PyObject *unpack_hook) : pack_hook_(pack_hook), unpack_hook_(unpack_hook){};
+// PySavedVariableHooks::PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook) {
+//   pack_hook_ = pack_hook;
+//   unpack_hook_ = unpack_hook;
+// }
+
+PyObject* PySavedVariableHooks::call_pack_hook(at::Tensor tensor) {
+  return pack_hook_(tensor).release().ptr();
+};
+
+at::Tensor PySavedVariableHooks::call_unpack_hook(PyObject* obj) {
+  return THPVariable_Unpack(unpack_hook_(obj).release().ptr());
+};
+
+
+}}
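Note a mismatch in this WIP file: the call syntax pack_hook_(tensor) only compiles against the commented-out py::function members, not against the raw PyObject* members the active constructor initializes. A hypothetical sketch of call_pack_hook written for raw PyObject* members, using only the CPython API plus THPVariable_Wrap from the python_variable.h include above:

// Hypothetical PyObject*-based variant, not the committed code. Takes the
// GIL because the hook is arbitrary Python code; returns a new reference.
PyObject* PySavedVariableHooks::call_pack_hook(at::Tensor tensor) {
  py::gil_scoped_acquire gil;
  PyObject* arg = THPVariable_Wrap(std::move(tensor));  // Tensor -> THPVariable
  PyObject* packed = PyObject_CallFunctionObjArgs(pack_hook_, arg, nullptr);
  Py_DECREF(arg);
  if (packed == nullptr) {
    throw py::error_already_set();  // surface the Python exception to C++
  }
  return packed;
}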
torch/csrc/autograd/python_saved_variable_hooks.h

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <pybind11/pybind11.h>
+#include <torch/csrc/autograd/saved_variable_hooks.h>
+#include <torch/csrc/python_headers.h>
+#include <torch/csrc/utils/auto_gil.h>
+
+#include <ATen/ATen.h>
+
+namespace py = pybind11;
+
+namespace torch { namespace autograd {
+
+struct TORCH_API PySavedVariableHooks : public SavedVariableHooks {
+  // PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook);
+  PySavedVariableHooks(PyObject* pack_hook, PyObject* unpack_hook);
+  PyObject* call_pack_hook(at::Tensor tensor) override;
+  at::Tensor call_unpack_hook(PyObject* obj) override;
+
+private:
+  // py::function pack_hook_;
+  // py::function unpack_hook_;
+  PyObject* pack_hook_;
+  PyObject* unpack_hook_;
+};
+
+}}
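One loose end, if the raw PyObject* members win out: they would own the references that pack_hook.release().ptr() hands over in the binding, and nothing here ever releases them. A hypothetical destructor (not in this commit) might look like:

// Hypothetical: give back the owned references; hold the GIL because
// Py_XDECREF can trigger arbitrary Python finalizers.
~PySavedVariableHooks() {
  py::gil_scoped_acquire gil;
  Py_XDECREF(pack_hook_);
  Py_XDECREF(unpack_hook_);
}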

torch/csrc/autograd/saved_variable.cpp

Lines changed: 5 additions & 2 deletions
@@ -153,6 +153,7 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
   return var;
 }

+// void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks);
 void SavedVariable::register_hooks() {
   if (!data_.defined()) {
     if (!was_default_constructed_) {
@@ -166,14 +167,16 @@ void SavedVariable::register_hooks() {
       TORCH_CHECK(false, "Calling register_hook on a tensor with value None is forbidden");
     }
   }
+  // hooks_ = std::move(hooks);
+  // data_ = hooks_->call_unpack_hook(hooks_->call_pack_hook(data_));
 }

 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 const char* ERR_BACKWARD_TWICE =
     "Trying to backward through the graph a second time (or directly access saved "
-    "variables after they have already been freed). Saved intermediate values "
+    "tensors after they have already been freed). Saved intermediate values "
     "of the graph are freed when you call .backward() or autograd.grad(). Specify "
     "retain_graph=True if you need to backward through the graph a second time or "
-    "if you need to access saved variables after calling backward.";
+    "if you need to access saved tensors after calling backward.";

 }} // namespace torch::autograd
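Read together with the commented-out declaration in saved_variable.h, the two commented-out lines above suggest the overload will store the hooks and then round-trip the saved tensor through them. Pieced together, under that assumption:

// Pieced-together sketch of the commented-out overload. The immediate
// pack/unpack round-trip only makes sense as a smoke test: the real feature
// would presumably keep the packed object and defer call_unpack_hook until
// backward actually needs the tensor.
void SavedVariable::register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks) {
  hooks_ = std::move(hooks);
  data_ = hooks_->call_unpack_hook(hooks_->call_pack_hook(data_));
}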

torch/csrc/autograd/saved_variable.h

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,7 @@

 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/autograd/forward_grad.h>
+#include <torch/csrc/autograd/saved_variable_hooks.h>

 #include <ATen/ATen.h>

@@ -39,6 +40,8 @@ class TORCH_API SavedVariable {

   void register_hooks();

+  // void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks);
+
   void reset_data() {
     return data_.reset();
   }
@@ -70,6 +73,8 @@ class TORCH_API SavedVariable {
   std::weak_ptr<Node> weak_grad_fn_;
   c10::VariableVersion version_counter_;

+  std::unique_ptr<SavedVariableHooks> hooks_;
+
   uint32_t saved_version_ = 0;
   uint32_t output_nr_ = 0;
   bool was_default_constructed_ = true;
torch/csrc/autograd/saved_variable_hooks.h

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <torch/csrc/python_headers.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <ATen/ATen.h>
+
+namespace torch { namespace autograd {
+
+struct TORCH_API SavedVariableHooks {
+  virtual PyObject* call_pack_hook(at::Tensor tensor) = 0;
+  virtual at::Tensor call_unpack_hook(PyObject* obj) = 0;
+};
+
+}}
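Two remarks on this interface. Since SavedVariable stores the hooks behind a std::unique_ptr<SavedVariableHooks>, the base struct will also need a virtual destructor before deleting through that pointer is well-defined. And nothing requires an implementation to call into Python; a minimal hypothetical pass-through (IdentityHooks is an invented name; THPVariable_Wrap and THPVariable_Unpack come from torch/csrc/autograd/python_variable.h):

// Hypothetical pass-through hooks: "pack" wraps the tensor as a THPVariable,
// "unpack" recovers it, so a round trip yields the original tensor.
struct IdentityHooks : SavedVariableHooks {
  PyObject* call_pack_hook(at::Tensor tensor) override {
    return THPVariable_Wrap(std::move(tensor));  // new reference for the caller
  }
  at::Tensor call_unpack_hook(PyObject* obj) override {
    at::Tensor t = THPVariable_Unpack(obj);  // borrow, then copy the tensor
    Py_DECREF(obj);                          // consume the packed reference
    return t;
  }
};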
