@@ -1690,12 +1690,196 @@ def test_gc_in_destructor(self):
         segfault.
         """
         class CollectOnDelete(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_output):
+                return grad_output

             def __del__(self):
                 gc.collect()

         for _ in range(10):
-            Variable(torch.randn(10, 10), _grad_fn=CollectOnDelete())
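+            # Variable(..., _grad_fn=...) now raises (see test_naughty_legacy_variable_grad_fn);
+            # applying the legacy Function directly still exercises __del__ and its gc.collect().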
+            CollectOnDelete()(torch.randn(1, requires_grad=True)).backward()
+
+    def test_call_legacy_twice(self):
+        class Id(Function):
+            def forward(self, x):
+                self.save_for_backward(x)
+                return x
+
+            def backward(self, grad_x):
+                x = self.saved_tensors
+                return x
+
+        f = Id()
+        x1 = torch.zeros(1, requires_grad=True)
+        x2 = torch.ones(1, requires_grad=True)
+        y = f(x1)
+        with warnings.catch_warnings(record=True) as w:
+            z = f(x2)
+        self.assertIn('extending-torch-autograd', str(w[1].message))
+        # I don't really care about the functional correctness of this
+        # part of the test: if you make a change that causes this test
+        # to fail, it's probably OK to just fix this test case to follow
+        # it. I'm mostly making sure we don't segfault here.
+        y.backward()
+        self.assertEqual(x2.grad, x2)
+
+    # Delete this test when legacy custom autograd functions are deleted.
+    def test_naughty_legacy_variable_grad_fn(self):
+        class Id(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_x):
+                return grad_x
+
+        self.assertRaises(RuntimeError, lambda: Variable(torch.zeros(1), _grad_fn=Id()))
+
+    # Delete this test when legacy custom autograd functions are deleted.
+    def test_naughty_legacy_function_backward_before_forward(self):
+        class Id(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_x):
+                return grad_x
+
+        f = Id()
+        self.assertRaises(RuntimeError, lambda: f._do_backward((torch.zeros(0),), False))
+
+    # Delete this test when legacy custom autograd functions are deleted.
+    def test_naughty_legacy_function_early_access(self):
+        class Id(Function):
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_x):
+                return grad_x
+
+        f = Id()
+        # A legacy autograd function is not fully initialized until you actually
+        # apply it. That means a lot of accessors on them don't actually work.
+        # Test that we properly error in this case.
+        self.assertRaises(RuntimeError, lambda: f.register_hook(lambda x, y: None))
+        self.assertRaises(RuntimeError, lambda: f.next_functions)
+        self.assertRaises(RuntimeError, lambda: f.metadata)
+
+    @unittest.expectedFailure
+    def test_naughty_anomaly_access(self):
+        class MyFunction(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x
+
+            @staticmethod
+            def backward(ctx, g):
+                return g
+
+        x = torch.zeros(1, requires_grad=True)
+        y = MyFunction.apply(x)
+        y.backward()
+        y.grad_fn.metadata
+        g = y.grad_fn
+        del y
+        g.metadata  # this currently fails, but shouldn't
+
+    def test_naughty_autograd_function_stashing_ctx(self):
+        saved_ctx = []
+
+        class Id(Function):
+            @staticmethod
+            def forward(ctx, x):
+                ctx.save_for_backward(x)
+                return x
+
+            @staticmethod
+            def backward(ctx, grad_x):
+                saved_ctx.append(ctx)
+                return ctx.saved_tensors
+
+        p = torch.zeros(1, requires_grad=True)
+        loss = Id.apply(p)
+        loss.backward(retain_graph=True)
+        del loss
+        # At this point, accessing the saved tensors complains that the graph
+        # has been freed (which is indeed true, although a somewhat indirect
+        # way of stating the problem).
+        self.assertRaises(RuntimeError, lambda: saved_ctx[0].saved_tensors)
+
+    def test_custom_autograd_repeated_grad_grad(self):
+        # This test failed the equality check in PR #22983; it's an interesting
+        # and distinct test case worth enshrining. mult1 is not testing anything
+        # very interesting on its own; mult2 is the interesting case.
+
+        def mult1(x):
+            return x.prod(dim=-1).prod(dim=-1)
+
+        class Mult(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                y = mult1(x)
+                ctx.save_for_backward(x, y)
+                return y
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, y = ctx.saved_tensors
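+                # y[i] is the product of every entry of x[i], so dy[i]/dx[i, j, k] == y[i] / x[i, j, k]
+                # (safe here because the all-ones input used below has no zeros)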
+                return (grad_output * y)[:, None, None] / x
+
+        mult2 = Mult.apply
+
+        def check_gradgrad_repeated(x, y):
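+            # Take the same second derivative twice, rebuilding the first-derivative
+            # graph each time, and check that the repeated runs agree.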
+            gy, = torch.autograd.grad(y[0], x, create_graph=True)
+            ggy_1, = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
+            gy, = torch.autograd.grad(y[0], x, create_graph=True)
+            ggy_2, = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
+            self.assertEqual(ggy_1[0, 0, 1], ggy_2[0, 0, 1])
+
+        x = torch.ones(2, 4, 4).requires_grad_()
+        check_gradgrad_repeated(x, mult1(x))
+        check_gradgrad_repeated(x, mult2(x))
+
+    def test_custom_autograd_no_early_free(self):
+        # Prior to #22983, this test failed with a complaint that buffers had
+        # already been freed. Also a pretty interesting test case.
+        class Double(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                y = x ** 2
+                ctx.save_for_backward(x, y)
+                return y
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, _ = ctx.saved_tensors
+                return grad_output * 2 * x
+
+        # this is equivalent, but uses the output of .forward() in .backward()
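+        # (since y == x ** 2, grad_output * 2 * y / x is the same 2 * x gradient,
+        # but the backward graph now depends on the saved output y)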
+        class Double2(Double):
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, y = ctx.saved_tensors
+                return grad_output * 2 * y / x
+
+        double = Double.apply
+        double2 = Double2.apply
+
+        x = torch.tensor(2).double().requires_grad_()
+
+        self.assertTrue(torch.autograd.gradcheck(double, x))
+        self.assertTrue(torch.autograd.gradgradcheck(double, x))
+        self.assertTrue(torch.autograd.gradcheck(double2, x))
+        self.assertTrue(torch.autograd.gradgradcheck(double2, x))
+
+        y = double(x)
+        torch.autograd.grad(y, x, create_graph=True)
+        torch.autograd.grad(y, x)
+
+        y = double2(x)
+        torch.autograd.grad(y, x, create_graph=True)
+        torch.autograd.grad(y, x)  # should not error!
 
     @unittest.skipIf(torch.cuda.device_count() < 2, "no multi-GPU")
     @skipIfRocm