@@ -2054,39 +2054,41 @@ def test_dir(self):
20542054
20552055 def test_as_strided (self ):
20562056
2057- def test (x , * args , ** kwargs ):
2057+ def test (x , repro_fn , * args ):
20582058 def closure (x ):
2059- return x .as_strided (* args , ** kwargs )
2059+ if repro_fn is not None :
2060+ x = repro_fn (x )
2061+ return x .as_strided (* args )
20602062
20612063 x = x .to (torch .double ).detach ().requires_grad_ ()
20622064 gradcheck (closure , [x ])
20632065 gradgradcheck (closure , [x ])
20642066
20652067 # test
2066- test (torch .arange (0 , 25 ).view (5 , 5 ), [3 , 3 ], [6 , 2 ], 2 )
2068+ test (torch .arange (0 , 25 ), lambda x : x .view (5 , 5 ), [3 , 3 ], [6 , 2 ], 2 )
20672069
20682070 # test crazy stride at dim with size 1 case
2069- test (torch .randn (10 ), [1 , 2 , 1 , 5 ], [0 , 5 , 100 , 1 ], 2 )
2071+ test (torch .randn (10 ), None , [1 , 2 , 1 , 5 ], [0 , 5 , 100 , 1 ], 2 )
20702072
20712073 # test expand case
2072- test (torch .randn (5 ), [3 , 3 , 3 ], [0 , 1 , 0 ], 2 )
2073- test (torch .randn (5 ), [3 , 3 , 3 ], [0 , 0 , 0 ], 4 )
2074- test (torch .randn (5 ).expand (5 , 5 ), [5 , 5 ], [0 , 1 ], 0 )
2074+ test (torch .randn (5 ), None , [3 , 3 , 3 ], [0 , 1 , 0 ], 2 )
2075+ test (torch .randn (5 ), None , [3 , 3 , 3 ], [0 , 0 , 0 ], 4 )
2076+ test (torch .randn (5 ), lambda x : x .expand (5 , 5 ), [5 , 5 ], [0 , 1 ], 0 )
20752077
20762078 # test non-expand overlapping case
2077- test (torch .randn (35 ), [6 , 6 ], [5 , 1 ], 2 )
2078- test (torch .randn (15 ), [3 , 2 ], [3 , 6 ], 2 )
2079+ test (torch .randn (35 ), None , [6 , 6 ], [5 , 1 ], 2 )
2080+ test (torch .randn (15 ), None , [3 , 2 ], [3 , 6 ], 2 )
20792081
20802082 # test transpose case
2081- test (torch .randn (3 , 4 ), [4 , 3 ], [1 , 4 ])
2083+ test (torch .randn (3 , 4 ), None , [4 , 3 ], [1 , 4 ])
20822084
20832085 # test "getting things outside the input" case
20842086 x = torch .randn (6 , 2 )
2085- test (x [3 :], [3 , 2 ], [2 , 1 ])
2087+ test (x [3 :], None , [3 , 2 ], [2 , 1 ], 0 ) # should be all zeros
20862088 self .assertEqual (x [3 :].as_strided ([3 , 2 ], [2 , 1 ], 0 ), x [:3 ])
20872089
2088- # test input expanded case
2089- test (torch .randn (2 , 3 ).expand (10 , 2 , 3 ), [2 , 3 ], [3 , 1 ], 0 )
2090+ # test select on expanded input case
2091+ test (torch .randn (2 , 3 ), lambda x : x .expand (10 , 2 , 3 ), [2 , 3 ], [3 , 1 ], 0 )
20902092
20912093 def _test_where_functional (self , t ):
20922094 x = Variable (t (torch .randn (5 , 5 )), requires_grad = True )
0 commit comments