@@ -1175,6 +1175,28 @@ def test_add_module(self):
11751175 self .assertEqual (net .l , l3 )
11761176 self .assertRaises (TypeError , lambda : net .add_module ('x' , 'non-module' ))
11771177
1178+ def test_module_to_argparse (self ):
1179+ net = nn .Sequential (nn .Linear (3 , 3 ))
1180+ cpu = torch .device ('cpu' )
1181+ with self .assertRaises (TypeError ):
1182+ net .to (cpu , True )
1183+ with self .assertRaises (TypeError ):
1184+ net .to (torch .long )
1185+ with self .assertRaises (TypeError ):
1186+ net .to (None , True )
1187+ with self .assertRaises (TypeError ):
1188+ net .to (cpu , torch .long , True )
1189+ with self .assertRaises (TypeError ):
1190+ net .to (cpu , dtype = torch .long , non_blocking = True )
1191+ with self .assertRaises (TypeError ):
1192+ net .to ([])
1193+ with self .assertRaises (TypeError ):
1194+ net .to ({}, non_blocking = True )
1195+ with self .assertRaises (TypeError ):
1196+ net .to (torch .tensor (3 , dtype = torch .long ), non_blocking = True )
1197+ with self .assertRaises (TypeError ):
1198+ net .to (cpu , torch .tensor (3 , dtype = torch .long ), non_blocking = True )
1199+
11781200 def test_type (self ):
11791201 l = nn .Linear (10 , 20 )
11801202 net = nn .Module ()
@@ -1203,22 +1225,22 @@ def test_type(self):
12031225 self .assertIsInstance (l .weight .data , torch .FloatTensor )
12041226 self .assertIsInstance (l .bias .data , torch .FloatTensor )
12051227 self .assertIsInstance (net .indices , torch .LongTensor )
1206- net .to ("cuda" , torch .double )
1228+ net .to ("cuda" , torch .double , True )
12071229 self .assertIsInstance (l .weight .data , torch .cuda .DoubleTensor )
12081230 self .assertIsInstance (l .bias .data , torch .cuda .DoubleTensor )
12091231 self .assertIsInstance (net .indices , torch .cuda .LongTensor )
1210- net .to (device = "cuda:0" , dtype = torch .half )
1232+ net .to (torch . empty ( 1 , device = "cuda:0" , dtype = torch .half ) )
12111233 self .assertIsInstance (l .weight .data , torch .cuda .HalfTensor )
12121234 self .assertIsInstance (l .bias .data , torch .cuda .HalfTensor )
12131235 self .assertIsInstance (net .indices , torch .cuda .LongTensor )
1214- net .to (torch .device ("cpu" ))
1236+ net .to (torch .device ("cpu" ), non_blocking = True )
12151237 self .assertIsInstance (l .weight .data , torch .HalfTensor )
12161238 self .assertIsInstance (l .bias .data , torch .HalfTensor )
12171239 self .assertIsInstance (net .indices , torch .LongTensor )
12181240 net .type (torch .FloatTensor )
12191241 self .assertIsInstance (l .weight .data , torch .FloatTensor )
12201242 self .assertIsInstance (l .bias .data , torch .FloatTensor )
1221- net .type (torch .DoubleTensor )
1243+ net .to (torch .DoubleTensor ( 1 ) )
12221244 self .assertIsInstance (l .weight .data , torch .DoubleTensor )
12231245 self .assertIsInstance (l .bias .data , torch .DoubleTensor )
12241246 if TEST_CUDA :
0 commit comments