File tree Expand file tree Collapse file tree 2 files changed +18
-2
lines changed
Expand file tree Collapse file tree 2 files changed +18
-2
lines changed Original file line number Diff line number Diff line change @@ -7903,6 +7903,20 @@ def test_ctor_with_numpy_array(self):
79037903 for i in range (len (array )):
79047904 self .assertEqual (tensor [i ], array [i ])
79057905
7906+ @unittest .skipIf (not TEST_NUMPY , "Numpy not found" )
7907+ def test_ctor_with_numpy_scalar_ctor (self ):
7908+ dtypes = [
7909+ np .double ,
7910+ np .float ,
7911+ np .float16 ,
7912+ np .int64 ,
7913+ np .int32 ,
7914+ np .int16 ,
7915+ np .uint8
7916+ ]
7917+ for dtype in dtypes :
7918+ self .assertEqual (dtype (42 ), torch .tensor (dtype (42 )).item ())
7919+
79067920 @unittest .skipIf (not TEST_NUMPY , "Numpy not found" )
79077921 def test_numpy_index (self ):
79087922 i = np .int32 ([0 , 1 , 2 ])
Original file line number Diff line number Diff line change @@ -139,8 +139,10 @@ ScalarType infer_scalar_type(PyObject *obj) {
139139 }
140140#ifdef USE_NUMPY
141141 if (PyArray_Check (obj)) {
142- auto array = (PyArrayObject*)obj;
143- return numpy_dtype_to_aten (PyArray_TYPE (array));
142+ return numpy_dtype_to_aten (PyArray_TYPE ((PyArrayObject*)obj));
143+ }
144+ if (PyArray_CheckScalar (obj)) {
145+ return numpy_dtype_to_aten (PyArray_TYPE ((PyArrayObject*)(PyArray_FromScalar (obj, NULL ))));
144146 }
145147#endif
146148 if (PySequence_Check (obj)) {
You can’t perform that action at this time.
0 commit comments