
Commit fdfc676

ezyang authored and facebook-github-bot committed
Invert ownership between PyFunction and THPFunction.
Summary: Pull Request resolved: #22983

Test Plan: Imported from OSS

Differential Revision: D16422209

Pulled By: ezyang

fbshipit-source-id: d6e41a1606484fbbd7a95a547b83a4199151be68
1 parent ae5b520 commit fdfc676
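
A note on the title for readers skimming the diff: before this change the Python wrapper object (THPFunction) owned the C++ graph node (PyFunction), whose lifetime was tied to the Python object rather than to a shared_ptr; afterwards the relationship is reversed, with the C++ node managed by shared_ptr like every other autograd Function and holding an owning reference that keeps its Python wrapper alive. Below is a minimal, self-contained C++ sketch of the two ownership shapes. The types are toy stand-ins, not the actual PyTorch classes, and the owning handle (a Python reference count in the real code) is modeled here with another shared_ptr.

#include <iostream>
#include <memory>

// Toy stand-ins: "Node" plays the role of the C++ autograd function
// (PyFunction), "PyWrapper" the role of the Python-side object (THPFunction).

// Old shape: the wrapper owns the node, and the node's lifetime is simply
// the wrapper's lifetime (no shared_ptr management of the node).
struct NodeOld {
  void backward() const { std::cout << "NodeOld::backward\n"; }
};
struct PyWrapperOld {
  NodeOld cdata;  // the node lives inside the wrapper
};

// New shape: the node is shared_ptr-managed like any other graph node and
// keeps its wrapper alive through an owning handle.
struct PyWrapperNew {
  ~PyWrapperNew() { std::cout << "wrapper destroyed\n"; }
};
struct NodeNew : std::enable_shared_from_this<NodeNew> {  // mirrors the real base class
  std::shared_ptr<PyWrapperNew> obj;  // strong reference to the wrapper
  void backward() const { std::cout << "NodeNew::backward\n"; }
};

int main() {
  // Old: the graph can only use the node while the wrapper stays alive.
  PyWrapperOld wrapper;
  wrapper.cdata.backward();

  // New: the graph holds shared_ptr<NodeNew>; the wrapper survives for as
  // long as some node (or other owner) still references it.
  auto node = std::make_shared<NodeNew>();
  node->obj = std::make_shared<PyWrapperNew>();
  node->backward();
  return 0;
}  // "wrapper destroyed" prints when the last owner of the node goes away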

File tree

8 files changed: +390 −123 lines


test/test_autograd.py

Lines changed: 185 additions & 1 deletion
@@ -1690,12 +1690,196 @@ def test_gc_in_destructor(self):
         segfault.
         """
         class CollectOnDelete(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_output):
+                return grad_output
 
             def __del__(self):
                 gc.collect()
 
         for _ in range(10):
-            Variable(torch.randn(10, 10), _grad_fn=CollectOnDelete())
+            CollectOnDelete()(torch.randn(1, requires_grad=True)).backward()
+
+    def test_call_legacy_twice(self):
+        class Id(Function):
+            def forward(self, x):
+                self.save_for_backward(x)
+                return x
+
+            def backward(self, grad_x):
+                x = self.saved_tensors
+                return x
+
+        f = Id()
+        x1 = torch.zeros(1, requires_grad=True)
+        x2 = torch.ones(1, requires_grad=True)
+        y = f(x1)
+        with warnings.catch_warnings(record=True) as w:
+            z = f(x2)
+        self.assertIn('extending-torch-autograd', str(w[1].message))
+        # I don't really care about the functional correctness of this
+        # part of the test: if you make a change that causes this test
+        # to fail, it's probably OK to just fix this test case to follow
+        # it. I'm mostly making sure we don't segfault here.
+        y.backward()
+        self.assertEqual(x2.grad, x2)
+
+    # Delete this test when legacy custom autograd functions are deleted.
+    def test_naughty_legacy_variable_grad_fn(self):
+        class Id(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_x):
+                return grad_x
+
+        self.assertRaises(RuntimeError, lambda: Variable(torch.zeros(1), _grad_fn=Id()))
+
+    # Delete this test when legacy custom autograd functions are deleted.
+    def test_naughty_legacy_function_backward_before_forward(self):
+        class Id(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_x):
+                return grad_x
+
+        f = Id()
+        self.assertRaises(RuntimeError, lambda: f._do_backward((torch.zeros(0), ), False))
+
+    # Delete this test when legacy custom autograd functions are deleted.
+    def test_naughty_legacy_function_early_access(self):
+        class Id(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_x):
+                return grad_x
+
+        f = Id()
+        # A legacy autograd function is not fully initialized until you actually
+        # apply it. That means a lot of accessors on them don't actually work.
+        # Test that we properly error in this case.
+        self.assertRaises(RuntimeError, lambda: f.register_hook(lambda x, y: None))
+        self.assertRaises(RuntimeError, lambda: f.next_functions)
+        self.assertRaises(RuntimeError, lambda: f.metadata)
+
+    @unittest.expectedFailure
+    def test_naughty_anomaly_access(self):
+        class MyFunction(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x
+
+            @staticmethod
+            def backward(ctx, g):
+                return g
+
+        x = torch.zeros(1, requires_grad=True)
+        y = MyFunction.apply(x)
+        y.backward()
+        y.grad_fn.metadata
+        g = y.grad_fn
+        del y
+        g.metadata  # this currently fails, but shouldn't
+
+    def test_naughty_autograd_function_stashing_ctx(self):
+        saved_ctx = []
+
+        class Id(Function):
+            @staticmethod
+            def forward(ctx, x):
+                ctx.save_for_backward(x)
+                return x
+
+            @staticmethod
+            def backward(ctx, grad_x):
+                saved_ctx.append(ctx)
+                return ctx.saved_tensors
+
+        p = torch.zeros(1, requires_grad=True)
+        loss = Id.apply(p)
+        loss.backward(retain_graph=True)
+        del loss
+        # At this point in time, it complains that the graph has been freed
+        # (which indeed true, although a somewhat indirect way of stating the
+        # problem).
+        self.assertRaises(RuntimeError, lambda: saved_ctx[0].saved_tensors)
+
+    def test_custom_autograd_repeated_grad_grad(self):
+        # This test failed the equality check in PR #22983; it's an interesting
+        # and different test case worth enshrining. mult1 is not testing
+        # anything that interesting, but mult2 is the interesting case.
+
+        def mult1(x):
+            return x.prod(dim=-1).prod(dim=-1)
+
+        class Mult(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                y = mult1(x)
+                ctx.save_for_backward(x, y)
+                return y
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, y = ctx.saved_tensors
+                return (grad_output * y)[:, None, None] / x
+
+        mult2 = Mult.apply
+
+        def check_gradgrad_repeated(x, y):
+            gy, = torch.autograd.grad(y[0], x, create_graph=True)
+            ggy_1, = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
+            gy, = torch.autograd.grad(y[0], x, create_graph=True)
+            ggy_2, = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
+            self.assertEqual(ggy_1[0, 0, 1], ggy_2[0, 0, 1])
+
+        x = torch.ones(2, 4, 4).requires_grad_()
+        check_gradgrad_repeated(x, mult1(x))
+        check_gradgrad_repeated(x, mult2(x))
+
+    def test_custom_autograd_no_early_free(self):
+        # This test failed complaining that buffers had already been freed
+        # prior to #22983. Also pretty interesting test case.
+        class Double(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                y = x ** 2
+                ctx.save_for_backward(x, y)
+                return y
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, _ = ctx.saved_tensors
+                return grad_output * 2 * x
+
+        # this is equivalent, but uses the output of .forward() in .backward()
+        class Double2(Double):
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, y = ctx.saved_tensors
+                return grad_output * 2 * y / x
+
+        double = Double.apply
+        double2 = Double2.apply
+
+        x = torch.tensor(2).double().requires_grad_()
+
+        self.assertTrue(torch.autograd.gradcheck(double, x))
+        self.assertTrue(torch.autograd.gradgradcheck(double, x))
+        self.assertTrue(torch.autograd.gradcheck(double2, x))
+        self.assertTrue(torch.autograd.gradgradcheck(double2, x))
+
+        y = double(x)
+        torch.autograd.grad(y, x, create_graph=True)
+        torch.autograd.grad(y, x)
+
+        y = double2(x)
+        torch.autograd.grad(y, x, create_graph=True)
+        torch.autograd.grad(y, x)  # should not error!
 
     @unittest.skipIf(torch.cuda.device_count() < 2, "no multi-GPU")
     @skipIfRocm

test/test_nn.py

Lines changed: 26 additions & 0 deletions
@@ -4734,6 +4734,32 @@ def test_data_parallel_device_args(self):
         out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
         self.assertEqual(out, l(i))
 
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    @skipIfRocm
+    def test_data_parallel_function_deletion(self):
+        # this test case is originated from #16532
+        def gradient_penalty(net, x):
+            output = net(x)
+            loss = torch.autograd.grad(
+                outputs=output, inputs=x,
+                grad_outputs=x.new_ones(output.size()),
+                create_graph=True, retain_graph=True)[0].mean()
+            return loss
+
+        net = nn.Linear(4, 1).cuda()
+        dpn = nn.DataParallel(net, [0, 1])
+        x = torch.ones(2, 4, requires_grad=True).cuda()
+
+        dpn.zero_grad()
+        loss = gradient_penalty(dpn, x)
+        loss.backward()
+        grads = [p.grad for p in net.parameters()]
+        self.assertEqual(2, len(grads))
+        self.assertEqual(
+            torch.tensor([[0.25, 0.25, 0.25, 0.25]], device='cuda:0'),
+            grads[0])
+        self.assertEqual(torch.tensor([0.0], device='cuda:0'), grads[1])
+
     def test_state_dict(self):
         l = nn.Linear(5, 5)
         block = nn.Module()

torch/csrc/autograd/function.h

Lines changed: 0 additions & 7 deletions
@@ -201,13 +201,6 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
     return sequence_nr_;
   }
 
-  /// Returns a shared pointer to `this`. `PyFunction`s are not managed by
-  /// `shared_ptr`s by default, but are bound to the lifetime of their Python
-  /// object instead.
-  virtual std::shared_ptr<Function> get_shared_ptr() {
-    return shared_from_this();
-  }
-
   /// Returns the name of the dynamic type of the function, for debugging.
   virtual std::string name() const;
 
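The virtual get_shared_ptr() removed above existed as an escape hatch: the default implementation calls shared_from_this(), which only works for objects already owned by some shared_ptr, and per the deleted comment a PyFunction was bound to the lifetime of its Python object instead. With PyFunction now shared_ptr-managed like any other node, the base behaviour suffices and the hook can go. A standalone illustration of the underlying constraint (plain C++, nothing PyTorch-specific; under C++17 semantics shared_from_this() throws std::bad_weak_ptr when no shared_ptr owns the object, and the call is undefined behaviour in earlier standards):

#include <iostream>
#include <memory>

// Toy node type; nothing PyTorch-specific.
struct Node : std::enable_shared_from_this<Node> {};

int main() {
  // Owned by a shared_ptr: shared_from_this() hands out another owner.
  auto owned = std::make_shared<Node>();
  std::shared_ptr<Node> alias = owned->shared_from_this();
  std::cout << "use_count: " << owned.use_count() << "\n";  // prints 2

  // Not owned by any shared_ptr: shared_from_this() cannot succeed.
  // (Throws std::bad_weak_ptr under C++17; undefined behaviour earlier.)
  Node stack_node;
  try {
    stack_node.shared_from_this();
  } catch (const std::bad_weak_ptr&) {
    std::cout << "bad_weak_ptr: object has no shared_ptr owner\n";
  }
  return 0;
}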
torch/csrc/autograd/functions/pybind.h

Lines changed: 0 additions & 17 deletions
@@ -11,21 +11,4 @@ namespace py = pybind11;
 
 namespace pybind11 { namespace detail {
 
-// handle Python <-> torch::autograd::Function conversions
-template <> struct type_caster<std::shared_ptr<torch::autograd::Function>> {
-public:
-  PYBIND11_TYPE_CASTER(std::shared_ptr<torch::autograd::Function>, _("std::shared_ptr<torch::autograd::Function>"));
-
-  bool load(handle src, bool) {
-    if (!THPFunction_Check(src.ptr())) return false;
-    value = THPFunction_asFunction((THPFunction*)src.ptr());
-    return true;
-  }
-  static handle cast(std::shared_ptr<torch::autograd::Function> src, return_value_policy /* policy */, handle /* parent */) {
-    auto fn = functionToPyObject(std::move(src));
-    return handle(fn);
-  }
-};
-
-
 }} // namespace pybind11::detail
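
The block removed above was a custom pybind11 type caster translating THPFunction Python objects to and from std::shared_ptr<torch::autograd::Function> at binding boundaries, via THPFunction_asFunction and functionToPyObject; it is dropped together with the old ownership scheme. For readers who have not seen the mechanism before, here is the general shape of such a caster, written for a made-up IntBox struct (purely illustrative, not a PyTorch type) and following the pattern documented by pybind11:

#include <pybind11/pybind11.h>

struct IntBox { long value; };

namespace pybind11 { namespace detail {

// Minimal custom caster: converts Python ints <-> IntBox.
template <> struct type_caster<IntBox> {
 public:
  PYBIND11_TYPE_CASTER(IntBox, _("IntBox"));

  // Python -> C++: succeed only for Python ints.
  bool load(handle src, bool /* implicit_conversions */) {
    if (!PyLong_Check(src.ptr())) return false;
    value.value = PyLong_AsLong(src.ptr());
    if (PyErr_Occurred()) {
      PyErr_Clear();
      return false;
    }
    return true;
  }

  // C++ -> Python: wrap back into a Python int.
  static handle cast(IntBox src, return_value_policy /* policy */, handle /* parent */) {
    return PyLong_FromLong(src.value);
  }
};

}}  // namespace pybind11::detail

PYBIND11_MODULE(intbox_example, m) {
  // A trivial bound function that exercises the caster in both directions.
  m.def("increment", [](IntBox b) { return IntBox{b.value + 1}; });
}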
