Skip to content

Commit ea3c36b

Browse files
vishwakftwfacebook-github-bot
authored andcommitted
NumPy Scalar to PyTorch Scalar (#9225)
Summary: Fixes #4985 . Pull Request resolved: #9225 Differential Revision: D8769317 Pulled By: ezyang fbshipit-source-id: eeaeaf0749c9dc9e372634da68b4bd23e6e3ad28
1 parent c9eab34 commit ea3c36b

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

test/test_torch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7903,6 +7903,20 @@ def test_ctor_with_numpy_array(self):
79037903
for i in range(len(array)):
79047904
self.assertEqual(tensor[i], array[i])
79057905

7906+
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
7907+
def test_ctor_with_numpy_scalar_ctor(self):
7908+
dtypes = [
7909+
np.double,
7910+
np.float,
7911+
np.float16,
7912+
np.int64,
7913+
np.int32,
7914+
np.int16,
7915+
np.uint8
7916+
]
7917+
for dtype in dtypes:
7918+
self.assertEqual(dtype(42), torch.tensor(dtype(42)).item())
7919+
79067920
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
79077921
def test_numpy_index(self):
79087922
i = np.int32([0, 1, 2])

torch/csrc/utils/tensor_new.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,10 @@ ScalarType infer_scalar_type(PyObject *obj) {
139139
}
140140
#ifdef USE_NUMPY
141141
if (PyArray_Check(obj)) {
142-
auto array = (PyArrayObject*)obj;
143-
return numpy_dtype_to_aten(PyArray_TYPE(array));
142+
return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)obj));
143+
}
144+
if (PyArray_CheckScalar(obj)) {
145+
return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)(PyArray_FromScalar(obj, NULL))));
144146
}
145147
#endif
146148
if (PySequence_Check(obj)) {

0 commit comments

Comments
 (0)