
Commit afa198d

Invert ownership between PyFunction and THPFunction.
Fixes #16532 and #14960.

This patch is a massive hack. The way I constructed it was to flip the ownership between PyFunction and THPFunction, while maintaining a weak pointer from THPFunction to PyFunction so that all existing code keeps working. Essentially, this patch assumes that PyFunction stays live as long as you have a THPFunction: intuitively this makes sense, since the ctx object should only really stay live as long as you're actually going to execute the backwards, which will keep the PyFunction live. But as you can see from the presently skipped tests (specifically, test_hook_none), this is not always true. It does seem to be true for the code we care about, and that's enough for me!

Some subtleties:

- PyFunction is a C++ object that refers to a PyObject. This means it needs a custom deleter to handle deleting the PyObject, since you can't assume you hold the GIL when it dies.
- The old test_gc_in_destructor failed our internal assert because we never actually ran a backwards, and thus never actually materialized a PyFunction. I'm chalking this up as "misuse of API" and rewrote the test to not have this problem.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: a25e840
Pull Request resolved: #22983
1 parent aeee49d commit afa198d
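
The "custom deleter" subtlety above is the crux of the flipped lifetime: the C++-side PyFunction now holds the strong reference to its Python peer, and that reference can be dropped from a thread that does not currently hold the GIL. As a rough illustration only (a minimal sketch using the plain CPython C API; the names are hypothetical, not the declarations from this patch):

    #include <Python.h>
    #include <memory>

    // Sketch: an owning handle to a PyObject whose deleter re-acquires the GIL
    // before decref'ing, because the last C++ owner may die on an arbitrary thread.
    struct GilSafePyObjectDeleter {
      void operator()(PyObject* obj) const {
        PyGILState_STATE state = PyGILState_Ensure();  // safe even if the GIL is not held
        Py_DECREF(obj);
        PyGILState_Release(state);
      }
    };

    // Hypothetical alias; the real patch manages the THPFunction PyObject inside PyFunction.
    using OwnedPyObject = std::unique_ptr<PyObject, GilSafePyObjectDeleter>;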

8 files changed (+385, -123 lines)


test/test_autograd.py

Lines changed: 185 additions & 1 deletion
@@ -1690,12 +1690,196 @@ def test_gc_in_destructor(self):
         segfault.
         """
         class CollectOnDelete(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_output):
+                return grad_output
 
             def __del__(self):
                 gc.collect()
 
         for _ in range(10):
-            Variable(torch.randn(10, 10), _grad_fn=CollectOnDelete())
+            CollectOnDelete()(torch.randn(1, requires_grad=True)).backward()
+
+    def test_call_legacy_twice(self):
+        class Id(Function):
+            def forward(self, x):
+                self.save_for_backward(x)
+                return x
+
+            def backward(self, grad_x):
+                x = self.saved_tensors
+                return x
+
+        f = Id()
+        x1 = torch.zeros(1, requires_grad=True)
+        x2 = torch.ones(1, requires_grad=True)
+        y = f(x1)
+        with warnings.catch_warnings(record=True) as w:
+            z = f(x2)
+        self.assertIn('extending-torch-autograd', str(w[1].message))
+        # I don't really care about the functional correctness of this
+        # part of the test: if you make a change that causes this test
+        # to fail, it's probably OK to just fix this test case to follow
+        # it. I'm mostly making sure we don't segfault here.
+        y.backward()
+        self.assertEqual(x2.grad, x2)
+
+    # Delete this test when legacy custom autograd functions are deleted.
+    def test_naughty_legacy_variable_grad_fn(self):
+        class Id(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_x):
+                return grad_x
+
+        self.assertRaises(RuntimeError, lambda: Variable(torch.zeros(1), _grad_fn=Id()))
+
+    # Delete this test when legacy custom autograd functions are deleted.
+    def test_naughty_legacy_function_backward_before_forward(self):
+        class Id(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_x):
+                return grad_x
+
+        f = Id()
+        self.assertRaises(RuntimeError, lambda: f._do_backward((torch.zeros(0), ), False))
+
+    # Delete this test when legacy custom autograd functions are deleted.
+    def test_naughty_legacy_function_early_access(self):
+        class Id(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_x):
+                return grad_x
+
+        f = Id()
+        # A legacy autograd function is not fully initialized until you actually
+        # apply it. That means a lot of accessors on them don't actually work.
+        # Test that we properly error in this case.
+        self.assertRaises(RuntimeError, lambda: f.register_hook(lambda x, y: None))
+        self.assertRaises(RuntimeError, lambda: f.next_functions)
+        self.assertRaises(RuntimeError, lambda: f.metadata)
+
+    @unittest.expectedFailure
+    def test_naughty_anomaly_access(self):
+        class MyFunction(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x
+
+            @staticmethod
+            def backward(ctx, g):
+                return g
+
+        x = torch.zeros(1, requires_grad=True)
+        y = MyFunction.apply(x)
+        y.backward()
+        y.grad_fn.metadata
+        g = y.grad_fn
+        del y
+        g.metadata  # this currently fails, but shouldn't
+
+    def test_naughty_autograd_function_stashing_ctx(self):
+        saved_ctx = []
+
+        class Id(Function):
+            @staticmethod
+            def forward(ctx, x):
+                ctx.save_for_backward(x)
+                return x
+
+            @staticmethod
+            def backward(ctx, grad_x):
+                saved_ctx.append(ctx)
+                return ctx.saved_tensors
+
+        p = torch.zeros(1, requires_grad=True)
+        loss = Id.apply(p)
+        loss.backward(retain_graph=True)
+        del loss
+        # At this point in time, it complains that the graph has been freed
+        # (which is indeed true, although a somewhat indirect way of stating the
+        # problem).
+        self.assertRaises(RuntimeError, lambda: saved_ctx[0].saved_tensors)
+
+    def test_custom_autograd_repeated_grad_grad(self):
+        # This test failed the equality check in PR #22983; it's an interesting
+        # and different test case worth enshrining. mult1 is not testing
+        # anything that interesting, but mult2 is the interesting case.
+
+        def mult1(x):
+            return x.prod(dim=-1).prod(dim=-1)
+
+        class Mult(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                y = mult1(x)
+                ctx.save_for_backward(x, y)
+                return y
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, y = ctx.saved_tensors
+                return (grad_output * y)[:, None, None] / x
+
+        mult2 = Mult.apply
+
+        def check_gradgrad_repeated(x, y):
+            gy, = torch.autograd.grad(y[0], x, create_graph=True)
+            ggy_1, = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
+            gy, = torch.autograd.grad(y[0], x, create_graph=True)
+            ggy_2, = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
+            self.assertEqual(ggy_1[0, 0, 1], ggy_2[0, 0, 1])
+
+        x = torch.ones(2, 4, 4).requires_grad_()
+        check_gradgrad_repeated(x, mult1(x))
+        check_gradgrad_repeated(x, mult2(x))
+
+    def test_custom_autograd_no_early_free(self):
+        # This test failed complaining that buffers had already been freed
+        # prior to #22983. Also pretty interesting test case.
+        class Double(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                y = x ** 2
+                ctx.save_for_backward(x, y)
+                return y
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, _ = ctx.saved_tensors
+                return grad_output * 2 * x
+
+        # this is equivalent, but uses the output of .forward() in .backward()
+        class Double2(Double):
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, y = ctx.saved_tensors
+                return grad_output * 2 * y / x
+
+        double = Double.apply
+        double2 = Double2.apply
+
+        x = torch.tensor(2).double().requires_grad_()
+
+        self.assertTrue(torch.autograd.gradcheck(double, x))
+        self.assertTrue(torch.autograd.gradgradcheck(double, x))
+        self.assertTrue(torch.autograd.gradcheck(double2, x))
+        self.assertTrue(torch.autograd.gradgradcheck(double2, x))
+
+        y = double(x)
+        torch.autograd.grad(y, x, create_graph=True)
+        torch.autograd.grad(y, x)
+
+        y = double2(x)
+        torch.autograd.grad(y, x, create_graph=True)
+        torch.autograd.grad(y, x)  # should not error!
 
     @unittest.skipIf(torch.cuda.device_count() < 2, "no multi-GPU")
     @skipIfRocm

test/test_nn.py

Lines changed: 26 additions & 0 deletions
@@ -4734,6 +4734,32 @@ def test_data_parallel_device_args(self):
         out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
         self.assertEqual(out, l(i))
 
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    @skipIfRocm
+    def test_data_parallel_function_deletion(self):
+        # this test case originated from #16532
+        def gradient_penalty(net, x):
+            output = net(x)
+            loss = torch.autograd.grad(
+                outputs=output, inputs=x,
+                grad_outputs=x.new_ones(output.size()),
+                create_graph=True, retain_graph=True)[0].mean()
+            return loss
+
+        net = nn.Linear(4, 1).cuda()
+        dpn = nn.DataParallel(net, [0, 1])
+        x = torch.ones(2, 4, requires_grad=True).cuda()
+
+        dpn.zero_grad()
+        loss = gradient_penalty(dpn, x)
+        loss.backward()
+        grads = [p.grad for p in net.parameters()]
+        self.assertEqual(2, len(grads))
+        self.assertEqual(
+            torch.tensor([[0.25, 0.25, 0.25, 0.25]], device='cuda:0'),
+            grads[0])
+        self.assertEqual(torch.tensor([0.0], device='cuda:0'), grads[1])
+
     def test_state_dict(self):
         l = nn.Linear(5, 5)
         block = nn.Module()

torch/csrc/autograd/function.h

Lines changed: 0 additions & 7 deletions
@@ -201,13 +201,6 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
201201
return sequence_nr_;
202202
}
203203

204-
/// Returns a shared pointer to `this`. `PyFunction`s are not managed by
205-
/// `shared_ptr`s by default, but are bound to the lifetime of their Python
206-
/// object instead.
207-
virtual std::shared_ptr<Function> get_shared_ptr() {
208-
return shared_from_this();
209-
}
210-
211204
/// Returns the name of the dynamic type of the function, for debugging.
212205
virtual std::string name() const;
213206
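
The get_shared_ptr() escape hatch removed here existed because, under the old scheme, a PyFunction's lifetime was tied to its Python object rather than to shared_ptr ownership. A hypothetical sketch of the inverted arrangement the commit message describes (illustrative names, not the actual PyTorch declarations): the autograd graph owns the C++ node through shared_ptr, the node keeps its Python ctx alive, and the Python wrapper retains only a weak back-pointer.

    #include <Python.h>
    #include <memory>

    // Stand-in for torch::autograd::PyFunction: owned by the graph via shared_ptr,
    // and in turn holding a strong reference to its Python ctx object (released
    // with a GIL-safe deleter, as sketched after the commit message).
    struct PyFunctionSketch {
      PyObject* obj;  // strong reference to the THPFunction PyObject
    };

    // Stand-in for THPFunction: no longer the owner of the C++ node; it keeps a
    // weak pointer, so the node stays alive only while a backward pass needs it.
    struct THPFunctionSketch {
      std::weak_ptr<PyFunctionSketch> cdata;
    };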

torch/csrc/autograd/functions/pybind.h

Lines changed: 0 additions & 17 deletions
@@ -11,21 +11,4 @@ namespace py = pybind11;
 
 namespace pybind11 { namespace detail {
 
-// handle Python <-> torch::autograd::Function conversions
-template <> struct type_caster<std::shared_ptr<torch::autograd::Function>> {
-public:
-  PYBIND11_TYPE_CASTER(std::shared_ptr<torch::autograd::Function>, _("std::shared_ptr<torch::autograd::Function>"));
-
-  bool load(handle src, bool) {
-    if (!THPFunction_Check(src.ptr())) return false;
-    value = THPFunction_asFunction((THPFunction*)src.ptr());
-    return true;
-  }
-  static handle cast(std::shared_ptr<torch::autograd::Function> src, return_value_policy /* policy */, handle /* parent */) {
-    auto fn = functionToPyObject(std::move(src));
-    return handle(fn);
-  }
-};
-
-
 }} // namespace pybind11::detail
