Skip to content

Commit 490c15f

Browse files
apaszkesoumith
authored andcommitted
Fix slicing with step (#905)
1 parent f2d72ba commit 490c15f

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

test/test_torch.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,6 +1903,25 @@ def test_index(self):
19031903
self.assertEqual(reference[:, 2, 1:6:2],
19041904
torch.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1))
19051905

1906+
lst = [list(range(i, i + 10)) for i in range(0, 100, 10)]
1907+
tensor = torch.DoubleTensor(lst)
1908+
for i in range(100):
1909+
idx1_start = random.randrange(10)
1910+
idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1)
1911+
idx1_step = random.randrange(1, 8)
1912+
idx1 = slice(idx1_start, idx1_end, idx1_step)
1913+
if random.randrange(2) == 0:
1914+
idx2_start = random.randrange(10)
1915+
idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1)
1916+
idx2_step = random.randrange(1, 8)
1917+
idx2 = slice(idx2_start, idx2_end, idx2_step)
1918+
lst_indexed = list(map(lambda l: l[idx2], lst[idx1]))
1919+
tensor_indexed = tensor[idx1, idx2]
1920+
else:
1921+
lst_indexed = lst[idx1]
1922+
tensor_indexed = tensor[idx1]
1923+
self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed)
1924+
19061925
self.assertRaises(ValueError, lambda: reference[1:9:0])
19071926
self.assertRaises(ValueError, lambda: reference[1:9:-1])
19081927

torch/csrc/generic/Tensor.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,13 @@ static bool THPTensor_(_indexOnce)(PyObject *index, int &indexed_dim,
477477
PyErr_SetString(PyExc_ValueError, "slice step has to be greater than 0");
478478
throw python_error();
479479
}
480-
THTensor_(narrow)(LIBRARY_STATE tresult.get(), NULL, indexed_dim, start, length * step);
480+
if (length == 0) {
481+
PyErr_SetString(PyExc_ValueError, "result of slicing is an empty tensor");
482+
throw python_error();
483+
}
484+
tresult->storageOffset += tresult->stride[indexed_dim] * start;
481485
tresult->stride[indexed_dim] *= step;
482-
tresult->size[indexed_dim] /= step;
486+
tresult->size[indexed_dim] = length;
483487
indexed_dim++;
484488
} else {
485489
return false;

0 commit comments

Comments
 (0)