Skip to content

Commit 769f5f7

Browse files
authored
Handling of scalars in torch.Size (#5676)
* Handling of scalars in torch.Size torch.Size() constructor uses python_arg_parser IntList in python_arg_parser can take iter/range Have IntList take python iterables and ranges. Address comments: don't use python_arg_parser and instead call __index__ in THPSize_pynew Address comments Address comments * Rebased * Address nit
1 parent d102f9e commit 769f5f7

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

test/test_torch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7100,6 +7100,19 @@ def test_Size(self):
71007100
self.assertIsInstance(x[:-1], torch.Size)
71017101
self.assertIsInstance(x + x, torch.Size)
71027102

7103+
def test_Size_scalar(self):
7104+
three = torch.tensor(3)
7105+
two = torch.tensor(2)
7106+
x = torch.Size([0, 1, two, three, 4])
7107+
for i in range(1, 5):
7108+
self.assertEqual(x[i], i)
7109+
7110+
def test_Size_iter(self):
7111+
for sizes in [iter([1, 2, 3, 4, 5]), range(1, 6)]:
7112+
x = torch.Size(sizes)
7113+
for i in range(0, 5):
7114+
self.assertEqual(x[i], i + 1)
7115+
71037116
def test_t_not_2d_error(self):
71047117
self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t())
71057118
self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t_())

torch/csrc/Size.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,25 @@ static PyObject * THPSize_pynew(PyTypeObject *type, PyObject *args, PyObject *kw
5050
if (self) {
5151
for (Py_ssize_t i = 0; i < PyTuple_Size(self); ++i) {
5252
PyObject *item = PyTuple_GET_ITEM(self.get(), i);
53-
if (!THPUtils_checkLong(item) && !isTracedVar(item)) {
54-
return PyErr_Format(PyExc_TypeError, "torch.Size() takes an iterable of 'int' (item %zd is '%s')",
55-
i, Py_TYPE(item)->tp_name);
53+
if (isTracedVar(item)) {
54+
continue;
5655
}
56+
if (THPUtils_checkLong(item)) {
57+
continue;
58+
}
59+
// item.__index__() works with 0-dim tensors and tensors with one element
60+
THPObjectPtr number(PyNumber_Index(item));
61+
if (number && THPUtils_checkLong(number.get())) {
62+
Py_INCREF(number.get());
63+
auto status = PyTuple_SetItem(self, i, number.get());
64+
if (status != 0) {
65+
throw python_error();
66+
}
67+
continue;
68+
}
69+
return PyErr_Format(PyExc_TypeError,
70+
"torch.Size() takes an iterable of 'int' (item %zd is '%s')",
71+
i, Py_TYPE(item)->tp_name);
5772
}
5873
}
5974
return self.release();

0 commit comments

Comments
 (0)