@@ -504,31 +504,28 @@ def test_hook_none(self):
504504 # WARNING: this is a test for autograd internals.
505505 # You should never have to use such things in your code.
506506 class NoneGradientFunction (Function ):
507-
508- def forward (self , x , y ):
509- assert self .needs_input_grad [0 ]
510- assert not self .needs_input_grad [1 ]
507+ @ staticmethod
508+ def forward (ctx , x , y ):
509+ assert ctx .needs_input_grad [0 ]
510+ assert not ctx .needs_input_grad [1 ]
511511 return x , y
512512
513- def backward (self , grad_x , grad_y ):
513+ @staticmethod
514+ def backward (ctx , grad_x , grad_y ):
514515 return grad_x , None
515516
516- fn = NoneGradientFunction ()
517517 was_called = [False ]
518518
519- def hook (grad_input , grad_output ):
520- self .assertIsInstance (grad_input , tuple )
521- self .assertIsInstance (grad_output , tuple )
522- self .assertIsNotNone (grad_input [0 ])
523- self .assertIsNotNone (grad_input [1 ])
524- self .assertIsNotNone (grad_output [0 ])
525- self .assertIsNotNone (grad_output [1 ])
519+ def hook (grad ):
520+ self .assertIsNotNone (grad )
526521 was_called [0 ] = True
527- fn .register_hook (hook )
528522
529523 x = torch .randn (5 , 5 , requires_grad = True )
530524 y = torch .randn (5 , 5 )
531- sum (fn (x , y )).sum ().backward ()
525+ rx , ry = NoneGradientFunction .apply (x , y )
526+ rx .register_hook (hook )
527+ ry .register_hook (hook )
528+ sum (rx , ry ).sum ().backward ()
532529 self .assertTrue (was_called [0 ])
533530
534531 def test_retain_grad (self ):
@@ -601,14 +598,15 @@ def test_backward(self):
601598
602599 def test_sparse_backward (self ):
603600 class FixedGradientFunction (Function ):
604- def __init__ (self , grad ):
605- self .grad = grad
606-
607- def forward (self , x ):
601+ @staticmethod
602+ def forward (ctx , x , grad_x ):
603+ ctx .save_for_backward (grad_x )
608604 return x
609605
610- def backward (self , grad_x ):
611- return self .grad
606+ @staticmethod
607+ def backward (ctx , grad_x ):
608+ saved_grad_x , = ctx .saved_tensors
609+ return saved_grad_x , None
612610
613611 size = torch .Size ([6 , 3 , 2 ])
614612 i1 = torch .LongTensor ([
@@ -624,21 +622,19 @@ def backward(self, grad_x):
624622 v2 = torch .DoubleTensor ([[1 , 2 ], [4 , 3 ], [4 , 5 ], [7 , 8 ]])
625623 sparse_grad2 = torch .sparse .DoubleTensor (i2 , v2 , size )
626624 dense_grad = torch .rand (size ).double ()
627- sparse_fn1 = FixedGradientFunction (sparse_grad1 )
628- sparse_fn2 = FixedGradientFunction (sparse_grad2 )
629- dense_fn = FixedGradientFunction (dense_grad )
625+ fn = FixedGradientFunction
630626
631627 # sparse first
632628 x = torch .randn (size , requires_grad = True )
633- (sparse_fn1 ( x ) + dense_fn ( x ) + sparse_fn2 ( x )).sum ().backward ()
629+ (fn . apply ( x , sparse_grad1 ) + fn . apply ( x , dense_grad ) + fn . apply ( x , sparse_grad2 )).sum ().backward ()
634630 self .assertEqual (x .grad , dense_grad + sparse_grad1 + sparse_grad2 )
635631 # dense first
636632 x = torch .randn (size , requires_grad = True )
637- (dense_fn ( x ) + sparse_fn1 ( x ) + sparse_fn2 ( x )).sum ().backward ()
633+ (fn . apply ( x , dense_grad ) + fn . apply ( x , sparse_grad1 ) + fn . apply ( x , sparse_grad2 )).sum ().backward ()
638634 self .assertEqual (x .grad , dense_grad + sparse_grad1 + sparse_grad2 )
639635 # sparse only
640636 x = torch .randn (size , requires_grad = True )
641- (sparse_fn1 ( x ) + sparse_fn2 ( x )).sum ().backward ()
637+ (fn . apply ( x , sparse_grad1 ) + fn . apply ( x , sparse_grad2 )).sum ().backward ()
642638 self .assertEqual (x .grad , sparse_grad1 + sparse_grad2 )
643639
644640 def test_sparse_mm_backward (self ):
@@ -1913,18 +1909,19 @@ def test_numpy_requires_grad(self):
19131909
19141910 def test_return_leaf (self ):
19151911 class Identity (Function ):
1916-
1917- def forward (self , a , b ):
1912+ @ staticmethod
1913+ def forward (ctx , a , b ):
19181914 return a , a + b
19191915
1920- def backward (self , grad_a , grad_b ):
1916+ @staticmethod
1917+ def backward (ctx , grad_a , grad_b ):
19211918 return grad_a + grad_b , grad_b
19221919
19231920 hook_called = [False ]
19241921 x = torch .randn (5 , 5 , requires_grad = True )
19251922 y = torch .randn (5 , 5 , requires_grad = True )
19261923
1927- q , p = Identity () (x , y )
1924+ q , p = Identity . apply (x , y )
19281925
19291926 # Make sure hooks only receive grad from usage of q, not x.
19301927 def hook (grad ):
@@ -1939,21 +1936,22 @@ def hook(grad):
19391936
19401937 def test_return_leaf_inplace (self ):
19411938 class Inplace (InplaceFunction ):
1942-
1943- def forward (self , a , b ):
1944- self .mark_dirty (a )
1939+ @ staticmethod
1940+ def forward (ctx , a , b ):
1941+ ctx .mark_dirty (a )
19451942 return a .add_ (b ), b + 2
19461943
1947- def backward (self , grad_a , grad_b ):
1944+ @staticmethod
1945+ def backward (ctx , grad_a , grad_b ):
19481946 return grad_a , grad_a + grad_b
19491947
19501948 x = torch .randn (5 , 5 )
19511949 y = torch .randn (5 , 5 , requires_grad = True )
19521950
19531951 fn = Inplace (True )
1954- q , p = fn (x , y )
1952+ q , p = fn . apply (x , y )
19551953 self .assertIs (q , x )
1956- self .assertIs (q .grad_fn , fn )
1954+ self .assertIs (q .grad_fn . __class__ , fn . _backward_cls )
19571955 self .assertTrue (q .requires_grad )
19581956 q .sum ().backward ()
19591957 self .assertEqual (y .grad .data , torch .ones (5 , 5 ))
@@ -2052,33 +2050,35 @@ def test_save_none_for_backward(self):
20522050 test_case = self
20532051
20542052 class MyFn (Function ):
2055-
2056- def forward (self , input ):
2057- self .save_for_backward (None , input , None )
2053+ @ staticmethod
2054+ def forward (ctx , input ):
2055+ ctx .save_for_backward (None , input , None )
20582056 return input * input
20592057
2060- def backward (self , grad_output ):
2061- n1 , input , n2 = self .saved_tensors
2058+ @staticmethod
2059+ def backward (ctx , grad_output ):
2060+ n1 , input , n2 = ctx .saved_tensors
20622061 test_case .assertIsNone (n1 )
20632062 test_case .assertIsNone (n2 )
20642063 return 2 * input * grad_output
20652064
20662065 x = torch .randn (5 , 5 , requires_grad = True )
2067- y = MyFn () (x )
2066+ y = MyFn . apply (x )
20682067 y .sum ().backward ()
20692068 self .assertEqual (x .grad , 2 * x )
20702069
20712070 def test_too_many_grads (self ):
20722071 class MyFn (Function ):
2073-
2074- def forward (self , input ):
2072+ @ staticmethod
2073+ def forward (ctx , input ):
20752074 return input
20762075
2077- def backward (self , grad_output ):
2076+ @staticmethod
2077+ def backward (ctx , grad_output ):
20782078 return grad_output , None , None
20792079
20802080 x = torch .randn (5 , 5 , requires_grad = True )
2081- y = MyFn () (x )
2081+ y = MyFn . apply (x )
20822082 y .sum ().backward ()
20832083 self .assertEqual (x .grad , torch .ones_like (x ))
20842084
@@ -2098,29 +2098,32 @@ def assert_strict_equal(var1, var2):
20982098
20992099 def test_dep_nograd (self ):
21002100 class F1 (Function ):
2101-
2102- def forward (self , input ):
2101+ @ staticmethod
2102+ def forward (ctx , input ):
21032103 out = torch .randn (input .size ())
2104- self .mark_non_differentiable (out )
2104+ ctx .mark_non_differentiable (out )
21052105 return input , out
21062106
2107- def backward (self , grad_output , ignored ):
2107+ @staticmethod
2108+ def backward (ctx , grad_output , ignored ):
21082109 return grad_output
21092110
21102111 class F2 (Function ):
2111-
2112- def forward (self , input , ignored ):
2112+ @ staticmethod
2113+ def forward (ctx , input , ignored ):
21132114 return input
21142115
2115- def backward (self , grad_output ):
2116+ @staticmethod
2117+ def backward (ctx , grad_output ):
21162118 return grad_output , None
21172119
21182120 x = torch .randn (5 , requires_grad = True )
2119- a , b = F1 () (x )
2121+ a , b = F1 . apply (x )
21202122 b = b + 1 # separate F1 from F2 by another op
21212123 self .assertTrue (a .requires_grad )
21222124 self .assertFalse (b .requires_grad )
2123- c = F2 ()(a , b )
2125+ c = F2 .apply (a , b )
2126+ print (c .grad_fn )
21242127 c .backward (torch .ones (c .size ()))
21252128 self .assertEqual (x .grad .data , torch .ones (x .size ()))
21262129
0 commit comments