@@ -1903,6 +1903,25 @@ def test_index(self):
19031903 self .assertEqual (reference [:, 2 , 1 :6 :2 ],
19041904 torch .stack ([reference [:, 2 , 1 ], reference [:, 2 , 3 ], reference [:, 2 , 5 ]], 1 ))
19051905
1906+ lst = [list (range (i , i + 10 )) for i in range (0 , 100 , 10 )]
1907+ tensor = torch .DoubleTensor (lst )
1908+ for i in range (100 ):
1909+ idx1_start = random .randrange (10 )
1910+ idx1_end = idx1_start + random .randrange (1 , 10 - idx1_start + 1 )
1911+ idx1_step = random .randrange (1 , 8 )
1912+ idx1 = slice (idx1_start , idx1_end , idx1_step )
1913+ if random .randrange (2 ) == 0 :
1914+ idx2_start = random .randrange (10 )
1915+ idx2_end = idx2_start + random .randrange (1 , 10 - idx2_start + 1 )
1916+ idx2_step = random .randrange (1 , 8 )
1917+ idx2 = slice (idx2_start , idx2_end , idx2_step )
1918+ lst_indexed = list (map (lambda l : l [idx2 ], lst [idx1 ]))
1919+ tensor_indexed = tensor [idx1 , idx2 ]
1920+ else :
1921+ lst_indexed = lst [idx1 ]
1922+ tensor_indexed = tensor [idx1 ]
1923+ self .assertEqual (torch .DoubleTensor (lst_indexed ), tensor_indexed )
1924+
19061925 self .assertRaises (ValueError , lambda : reference [1 :9 :0 ])
19071926 self .assertRaises (ValueError , lambda : reference [1 :9 :- 1 ])
19081927
0 commit comments