
Commit d85b098

Author: Will Feng (committed)
Remove usage of legacy autograd function
1 parent 7ed82ea commit d85b098
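
This commit migrates test/test_autograd.py away from legacy, instance-style torch.autograd.Function subclasses to the new-style API, in which forward and backward are staticmethods that receive a ctx object and the function is invoked through .apply(). As a minimal sketch of the target style (the Square class below is illustrative only and is not taken from this diff):

import torch
from torch.autograd import Function

class Square(Function):
    # New-style autograd function: no instance state; everything the
    # backward pass needs is stashed on the ctx object that both
    # staticmethods receive.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.randn(5, requires_grad=True)
y = Square.apply(x)   # call through .apply(); never instantiate the class
y.sum().backward()
assert torch.allclose(x.grad, 2 * x)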

3 files changed (+107, -104 lines)

test/test_autograd.py

Lines changed: 59 additions & 56 deletions
@@ -504,31 +504,28 @@ def test_hook_none(self):
         # WARNING: this is a test for autograd internals.
         # You should never have to use such things in your code.
         class NoneGradientFunction(Function):
-
-            def forward(self, x, y):
-                assert self.needs_input_grad[0]
-                assert not self.needs_input_grad[1]
+            @staticmethod
+            def forward(ctx, x, y):
+                assert ctx.needs_input_grad[0]
+                assert not ctx.needs_input_grad[1]
                 return x, y
 
-            def backward(self, grad_x, grad_y):
+            @staticmethod
+            def backward(ctx, grad_x, grad_y):
                 return grad_x, None
 
-        fn = NoneGradientFunction()
         was_called = [False]
 
-        def hook(grad_input, grad_output):
-            self.assertIsInstance(grad_input, tuple)
-            self.assertIsInstance(grad_output, tuple)
-            self.assertIsNotNone(grad_input[0])
-            self.assertIsNotNone(grad_input[1])
-            self.assertIsNotNone(grad_output[0])
-            self.assertIsNotNone(grad_output[1])
+        def hook(grad):
+            self.assertIsNotNone(grad)
             was_called[0] = True
-        fn.register_hook(hook)
 
         x = torch.randn(5, 5, requires_grad=True)
         y = torch.randn(5, 5)
-        sum(fn(x, y)).sum().backward()
+        rx, ry = NoneGradientFunction.apply(x, y)
+        rx.register_hook(hook)
+        ry.register_hook(hook)
+        sum(rx, ry).sum().backward()
         self.assertTrue(was_called[0])
 
     def test_retain_grad(self):
@@ -601,14 +598,15 @@ def test_backward(self):
 
     def test_sparse_backward(self):
         class FixedGradientFunction(Function):
-            def __init__(self, grad):
-                self.grad = grad
-
-            def forward(self, x):
+            @staticmethod
+            def forward(ctx, x, grad_x):
+                ctx.save_for_backward(grad_x)
                 return x
 
-            def backward(self, grad_x):
-                return self.grad
+            @staticmethod
+            def backward(ctx, grad_x):
+                saved_grad_x, = ctx.saved_tensors
+                return saved_grad_x, None
 
         size = torch.Size([6, 3, 2])
         i1 = torch.LongTensor([
@@ -624,21 +622,19 @@ def backward(self, grad_x):
         v2 = torch.DoubleTensor([[1, 2], [4, 3], [4, 5], [7, 8]])
         sparse_grad2 = torch.sparse.DoubleTensor(i2, v2, size)
         dense_grad = torch.rand(size).double()
-        sparse_fn1 = FixedGradientFunction(sparse_grad1)
-        sparse_fn2 = FixedGradientFunction(sparse_grad2)
-        dense_fn = FixedGradientFunction(dense_grad)
+        fn = FixedGradientFunction
 
         # sparse first
         x = torch.randn(size, requires_grad=True)
-        (sparse_fn1(x) + dense_fn(x) + sparse_fn2(x)).sum().backward()
+        (fn.apply(x, sparse_grad1) + fn.apply(x, dense_grad) + fn.apply(x, sparse_grad2)).sum().backward()
         self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
         # dense first
         x = torch.randn(size, requires_grad=True)
-        (dense_fn(x) + sparse_fn1(x) + sparse_fn2(x)).sum().backward()
+        (fn.apply(x, dense_grad) + fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().backward()
         self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
         # sparse only
         x = torch.randn(size, requires_grad=True)
-        (sparse_fn1(x) + sparse_fn2(x)).sum().backward()
+        (fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().backward()
         self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)
 
     def test_sparse_mm_backward(self):
@@ -1913,18 +1909,19 @@ def test_numpy_requires_grad(self):
 
     def test_return_leaf(self):
         class Identity(Function):
-
-            def forward(self, a, b):
+            @staticmethod
+            def forward(ctx, a, b):
                 return a, a + b
 
-            def backward(self, grad_a, grad_b):
+            @staticmethod
+            def backward(ctx, grad_a, grad_b):
                 return grad_a + grad_b, grad_b
 
         hook_called = [False]
         x = torch.randn(5, 5, requires_grad=True)
         y = torch.randn(5, 5, requires_grad=True)
 
-        q, p = Identity()(x, y)
+        q, p = Identity.apply(x, y)
 
         # Make sure hooks only receive grad from usage of q, not x.
         def hook(grad):
@@ -1939,21 +1936,22 @@ def hook(grad):
 
     def test_return_leaf_inplace(self):
         class Inplace(InplaceFunction):
-
-            def forward(self, a, b):
-                self.mark_dirty(a)
+            @staticmethod
+            def forward(ctx, a, b):
+                ctx.mark_dirty(a)
                 return a.add_(b), b + 2
 
-            def backward(self, grad_a, grad_b):
+            @staticmethod
+            def backward(ctx, grad_a, grad_b):
                 return grad_a, grad_a + grad_b
 
         x = torch.randn(5, 5)
         y = torch.randn(5, 5, requires_grad=True)
 
         fn = Inplace(True)
-        q, p = fn(x, y)
+        q, p = fn.apply(x, y)
         self.assertIs(q, x)
-        self.assertIs(q.grad_fn, fn)
+        self.assertIs(q.grad_fn.__class__, fn._backward_cls)
         self.assertTrue(q.requires_grad)
         q.sum().backward()
         self.assertEqual(y.grad.data, torch.ones(5, 5))
@@ -2052,33 +2050,35 @@ def test_save_none_for_backward(self):
         test_case = self
 
         class MyFn(Function):
-
-            def forward(self, input):
-                self.save_for_backward(None, input, None)
+            @staticmethod
+            def forward(ctx, input):
+                ctx.save_for_backward(None, input, None)
                 return input * input
 
-            def backward(self, grad_output):
-                n1, input, n2 = self.saved_tensors
+            @staticmethod
+            def backward(ctx, grad_output):
+                n1, input, n2 = ctx.saved_tensors
                 test_case.assertIsNone(n1)
                 test_case.assertIsNone(n2)
                 return 2 * input * grad_output
 
         x = torch.randn(5, 5, requires_grad=True)
-        y = MyFn()(x)
+        y = MyFn.apply(x)
         y.sum().backward()
         self.assertEqual(x.grad, 2 * x)
 
     def test_too_many_grads(self):
         class MyFn(Function):
-
-            def forward(self, input):
+            @staticmethod
+            def forward(ctx, input):
                 return input
 
-            def backward(self, grad_output):
+            @staticmethod
+            def backward(ctx, grad_output):
                 return grad_output, None, None
 
         x = torch.randn(5, 5, requires_grad=True)
-        y = MyFn()(x)
+        y = MyFn.apply(x)
         y.sum().backward()
         self.assertEqual(x.grad, torch.ones_like(x))
 
@@ -2098,29 +2098,32 @@ def assert_strict_equal(var1, var2):
 
     def test_dep_nograd(self):
         class F1(Function):
-
-            def forward(self, input):
+            @staticmethod
+            def forward(ctx, input):
                 out = torch.randn(input.size())
-                self.mark_non_differentiable(out)
+                ctx.mark_non_differentiable(out)
                 return input, out
 
-            def backward(self, grad_output, ignored):
+            @staticmethod
+            def backward(ctx, grad_output, ignored):
                 return grad_output
 
         class F2(Function):
-
-            def forward(self, input, ignored):
+            @staticmethod
+            def forward(ctx, input, ignored):
                 return input
 
-            def backward(self, grad_output):
+            @staticmethod
+            def backward(ctx, grad_output):
                 return grad_output, None
 
         x = torch.randn(5, requires_grad=True)
-        a, b = F1()(x)
+        a, b = F1.apply(x)
         b = b + 1  # separate F1 from F2 by another op
         self.assertTrue(a.requires_grad)
         self.assertFalse(b.requires_grad)
-        c = F2()(a, b)
+        c = F2.apply(a, b)
+        print(c.grad_fn)
         c.backward(torch.ones(c.size()))
         self.assertEqual(x.grad.data, torch.ones(x.size()))
 