Skip to content

Commit f366e5f

Browse files
Teaonlysoumith
authored andcommitted
Support int16 numpy conversions
issue #891
1 parent 7ad948f commit f366e5f

File tree

3 files changed

+6
-0
lines changed

3 files changed

+6
-0
lines changed

test/test_torch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2849,6 +2849,7 @@ def test_from_numpy(self):
28492849
np.float,
28502850
np.int64,
28512851
np.int32,
2852+
np.int16,
28522853
np.uint8
28532854
]
28542855
for dtype in dtypes:

torch/csrc/Module.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ PyObject * THPModule_fromNumpy(PyObject *_unused, PyObject *array)
141141
return PyObject_CallFunctionObjArgs(THPLongTensorClass, array, NULL);
142142
} else if (type == NPY_INT32) {
143143
return PyObject_CallFunctionObjArgs(THPIntTensorClass, array, NULL);
144+
} else if (type == NPY_INT16) {
145+
return PyObject_CallFunctionObjArgs(THPShortTensorClass, array, NULL);
144146
} else if (type == NPY_UINT8) {
145147
return PyObject_CallFunctionObjArgs(THPByteTensorClass, array, NULL);
146148
}

torch/csrc/generic/Tensor.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
#ifdef TH_REAL_IS_INT
1717
#define NUMPY_TYPE_ENUM NPY_INT32
1818
#endif
19+
#ifdef TH_REAL_IS_SHORT
20+
#define NUMPY_TYPE_ENUM NPY_INT16
21+
#endif
1922
#ifdef TH_REAL_IS_BYTE
2023
#define NUMPY_TYPE_ENUM NPY_UINT8
2124
#endif

0 commit comments

Comments
 (0)