@@ -45,7 +45,7 @@ def test_unsupported(self):
4545 with self .assertRaises (RuntimeError ) as context :
4646 torch .randn (1 , 2 , 3 , 4 , dtype = torch .float , device = torch .device ('cuda' )).to_mkldnn ()
4747 # some factory functions
48- for creator in [torch .empty , torch . ones , torch .zeros , torch .randn , torch .rand ]:
48+ for creator in [torch .ones , torch .zeros , torch .randn , torch .rand ]:
4949 with self .assertRaises (RuntimeError ) as context :
5050 creator (1 , 2 , 3 , 4 , dtype = torch .float , device = torch .device ('cpu' ), layout = torch ._mkldnn )
5151
@@ -289,6 +289,11 @@ def test_set_data_tensorimpl_type(self):
289289 with self .assertRaisesRegex (RuntimeError , 'different types of TensorImpl' ):
290290 x .data = x_mkldnn
291291
292+ def test_empty (self ):
293+ x1 = torch .empty (4 , 5 , 2 , 3 , dtype = torch .float32 )
294+ x2 = torch .empty (4 , 5 , 2 , 3 , dtype = torch .float32 , layout = torch ._mkldnn )
295+ self .assertEqual (x1 .size (), x2 .to_dense ().size ())
296+ self .assertEqual (x1 .dtype , x2 .to_dense ().dtype )
292297
293298if __name__ == '__main__' :
294299 run_tests ()
0 commit comments