
Commit 7159e85

fix as_strided backward when input is overlapping
- check for input overlapping too
- [doc] clarify gradcheck behavior when input is overlapping
- longer note
1 parent 4a9eccd commit 7159e85
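
For context (not part of the commit): a tensor's memory "overlaps" when more than one index maps to the same storage location, either through a stride of 0 (as produced by expand()) or through strides smaller than the extent they span. A rough illustration of both kinds, mirroring cases exercised in the test diff below:

    import torch

    # Stride-0 overlap: expand() repeats the same memory along a new dim.
    x = torch.randn(2, 3)
    y = x.expand(10, 2, 3)
    print(y.stride())  # (0, 3, 1) -- the leading dim has stride 0

    # Non-expand overlap: each 6-element row starts only 5 elements after
    # the previous one, so adjacent rows share a storage element.
    z = torch.randn(35).as_strided([6, 6], [5, 1], 2)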

4 files changed (+316, -156 lines)


test/test_autograd.py

Lines changed: 15 additions & 13 deletions
@@ -2054,39 +2054,41 @@ def test_dir(self):
 
     def test_as_strided(self):
 
-        def test(x, *args, **kwargs):
+        def test(x, repro_fn, *args):
             def closure(x):
-                return x.as_strided(*args, **kwargs)
+                if repro_fn is not None:
+                    x = repro_fn(x)
+                return x.as_strided(*args)
 
             x = x.to(torch.double).detach().requires_grad_()
             gradcheck(closure, [x])
             gradgradcheck(closure, [x])
 
         # test
-        test(torch.arange(0, 25).view(5, 5), [3, 3], [6, 2], 2)
+        test(torch.arange(0, 25), lambda x: x.view(5, 5), [3, 3], [6, 2], 2)
 
         # test crazy stride at dim with size 1 case
-        test(torch.randn(10), [1, 2, 1, 5], [0, 5, 100, 1], 2)
+        test(torch.randn(10), None, [1, 2, 1, 5], [0, 5, 100, 1], 2)
 
         # test expand case
-        test(torch.randn(5), [3, 3, 3], [0, 1, 0], 2)
-        test(torch.randn(5), [3, 3, 3], [0, 0, 0], 4)
-        test(torch.randn(5).expand(5, 5), [5, 5], [0, 1], 0)
+        test(torch.randn(5), None, [3, 3, 3], [0, 1, 0], 2)
+        test(torch.randn(5), None, [3, 3, 3], [0, 0, 0], 4)
+        test(torch.randn(5), lambda x: x.expand(5, 5), [5, 5], [0, 1], 0)
 
         # test non-expand overlapping case
-        test(torch.randn(35), [6, 6], [5, 1], 2)
-        test(torch.randn(15), [3, 2], [3, 6], 2)
+        test(torch.randn(35), None, [6, 6], [5, 1], 2)
+        test(torch.randn(15), None, [3, 2], [3, 6], 2)
 
         # test transpose case
-        test(torch.randn(3, 4), [4, 3], [1, 4])
+        test(torch.randn(3, 4), None, [4, 3], [1, 4])
 
         # test "getting things outside the input" case
         x = torch.randn(6, 2)
-        test(x[3:], [3, 2], [2, 1])
+        test(x[3:], None, [3, 2], [2, 1], 0)  # should be all zeros
         self.assertEqual(x[3:].as_strided([3, 2], [2, 1], 0), x[:3])
 
-        # test input expanded case
-        test(torch.randn(2, 3).expand(10, 2, 3), [2, 3], [3, 1], 0)
+        # test select on expanded input case
+        test(torch.randn(2, 3), lambda x: x.expand(10, 2, 3), [2, 3], [3, 1], 0)
 
     def _test_where_functional(self, t):
         x = Variable(t(torch.randn(5, 5)), requires_grad=True)
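
The final case above, as a standalone sketch (my paraphrase, not code from the commit): the repro_fn refactor keeps gradcheck's own input non-overlapping and creates the overlap inside the closure being differentiated.

    import torch
    from torch.autograd import gradcheck

    # Non-overlapping leaf; double precision keeps the numerical Jacobian
    # that gradcheck computes accurate.
    base = torch.randn(2, 3, dtype=torch.double, requires_grad=True)

    def closure(x):
        # Overlap (stride-0 expand) is created inside the closure, so the
        # gradcheck input itself stays non-overlapping.
        return x.expand(10, 2, 3).as_strided([2, 3], [3, 1], 0)

    # Exercises as_strided backward on an overlapping input -- the case
    # this commit fixes.
    gradcheck(closure, [base])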

tools/autograd/derivatives.yaml

Lines changed: 2 additions & 2 deletions
@@ -578,8 +578,8 @@
 # DO NOT define a backward for reshape!
 # reshape is special in that it sometimes returns a view, and sometimes not.
 # Defining a backward will make codegen spit out the forward call as
-# as_variable(baseType->reshape(self))
-# , making it impossible (hard) to detect when it is actually a view.
+# as_variable(baseType->reshape(self)),
+# making it impossible (hard) to detect when it is actually a view.
 # - name: reshape(Tensor self, IntList shape)
 
 - name: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale)
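
As a quick illustration of the conditional aliasing that comment describes (plain PyTorch, not part of the commit), reshape returns a view when the input's layout allows it and a copy otherwise, so no single backward formula can be generated for it:

    import torch

    x = torch.randn(4, 4)
    v = x.reshape(16)      # contiguous input: reshape can return a view
    c = x.t().reshape(16)  # non-contiguous input: reshape must copy

    print(v.data_ptr() == x.data_ptr())  # True  -> v aliases x's storage
    print(c.data_ptr() == x.data_ptr())  # False -> c has its own storage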
