@@ -1110,6 +1110,91 @@ def check():
11101110 module_list .extend (s .modules ())
11111111 check ()
11121112
1113+ def test_ModuleDict (self ):
1114+ modules = OrderedDict ([
1115+ ('act' , nn .ReLU ()),
1116+ ('conv' , nn .Conv2d (10 , 10 , 5 )),
1117+ ('fc' , nn .Linear (5 , 5 )),
1118+ ])
1119+
1120+ module_dict = nn .ModuleDict (modules )
1121+
1122+ def check ():
1123+ self .assertEqual (len (module_dict ), len (modules ))
1124+ for k1 , m2 in zip (modules , module_dict .children ()):
1125+ self .assertIs (modules [k1 ], m2 )
1126+ for k1 , k2 in zip (modules , module_dict ):
1127+ self .assertIs (modules [k1 ], module_dict [k2 ])
1128+ for k in module_dict :
1129+ self .assertIs (module_dict [k ], modules [k ])
1130+ for k in module_dict .keys ():
1131+ self .assertIs (module_dict [k ], modules [k ])
1132+ for k , v in module_dict .items ():
1133+ self .assertIs (modules [k ], v )
1134+ for k1 , m2 in zip (modules , module_dict .values ()):
1135+ self .assertIs (modules [k1 ], m2 )
1136+ for k in modules .keys ():
1137+ self .assertTrue (k in module_dict )
1138+ check ()
1139+
1140+ modules ['conv' ] = nn .Conv2d (3 , 4 , 3 )
1141+ module_dict ['conv' ] = modules ['conv' ]
1142+ check ()
1143+
1144+ next_modules = [
1145+ ('fc2' , nn .Linear (5 , 5 )),
1146+ ('act' , nn .Sigmoid ()),
1147+ ]
1148+ modules .update (next_modules )
1149+ module_dict .update (next_modules )
1150+ check ()
1151+
1152+ next_modules = OrderedDict ([
1153+ ('fc3' , nn .Linear (5 , 5 )),
1154+ ('act2' , nn .Sigmoid ()),
1155+ ])
1156+ modules .update (next_modules )
1157+ module_dict .update (next_modules )
1158+ check ()
1159+
1160+ next_modules = {
1161+ 'fc4' : nn .Linear (5 , 5 ),
1162+ 'act3' : nn .Sigmoid ()
1163+ }
1164+ modules .update (sorted (next_modules .items ()))
1165+ module_dict .update (next_modules )
1166+ check ()
1167+
1168+ del module_dict ['fc' ]
1169+ del modules ['fc' ]
1170+ check ()
1171+
1172+ with self .assertRaises (TypeError ):
1173+ module_dict .update (nn .ReLU ())
1174+
1175+ with self .assertRaises (TypeError ):
1176+ module_dict .update ([nn .ReLU ()])
1177+
1178+ with self .assertRaises (ValueError ):
1179+ module_dict .update ([[nn .ReLU ()]])
1180+
1181+ with self .assertRaises (TypeError ):
1182+ module_dict [1 ] = nn .ReLU ()
1183+
1184+ s = nn .Sequential (modules )
1185+ module_dict = nn .ModuleDict (s .named_children ())
1186+ check ()
1187+
1188+ c = module_dict .pop ('conv' )
1189+ self .assertIs (c , modules ['conv' ])
1190+ modules .pop ('conv' )
1191+ check ()
1192+
1193+ module_dict .clear ()
1194+ self .assertEqual (len (module_dict ), 0 )
1195+ modules .clear ()
1196+ check ()
1197+
11131198 def test_ParameterList (self ):
11141199 def make_param ():
11151200 return Parameter (torch .randn (10 , 10 ))
@@ -1174,6 +1259,88 @@ def check():
11741259 param_list .extend (s .parameters ())
11751260 check ()
11761261
1262+ def test_ParameterDict (self ):
1263+ parameters = OrderedDict ([
1264+ ('p1' , Parameter (torch .randn (10 , 10 ))),
1265+ ('p2' , Parameter (torch .randn (10 , 10 ))),
1266+ ('p3' , Parameter (torch .randn (10 , 10 ))),
1267+ ])
1268+
1269+ parameter_dict = nn .ParameterDict (parameters )
1270+
1271+ def check ():
1272+ self .assertEqual (len (parameter_dict ), len (parameters ))
1273+ for k1 , m2 in zip (parameters , parameter_dict .parameters ()):
1274+ self .assertIs (parameters [k1 ], m2 )
1275+ for k1 , k2 in zip (parameters , parameter_dict ):
1276+ self .assertIs (parameters [k1 ], parameter_dict [k2 ])
1277+ for k in parameter_dict :
1278+ self .assertIs (parameter_dict [k ], parameters [k ])
1279+ for k in parameter_dict .keys ():
1280+ self .assertIs (parameter_dict [k ], parameters [k ])
1281+ for k , v in parameter_dict .items ():
1282+ self .assertIs (v , parameters [k ])
1283+ for k1 , m2 in zip (parameters , parameter_dict .values ()):
1284+ self .assertIs (parameters [k1 ], m2 )
1285+ for k in parameters .keys ():
1286+ self .assertTrue (k in parameter_dict )
1287+
1288+ check ()
1289+
1290+ parameters ['p4' ] = Parameter (torch .randn (10 , 10 ))
1291+ parameter_dict ['p4' ] = parameters ['p4' ]
1292+ check ()
1293+
1294+ next_parameters = [
1295+ ('p5' , Parameter (torch .randn (10 , 10 ))),
1296+ ('p2' , Parameter (torch .randn (10 , 10 ))),
1297+ ]
1298+ parameters .update (next_parameters )
1299+ parameter_dict .update (next_parameters )
1300+ check ()
1301+
1302+ next_parameters = OrderedDict ([
1303+ ('p6' , Parameter (torch .randn (10 , 10 ))),
1304+ ('p5' , Parameter (torch .randn (10 , 10 ))),
1305+ ])
1306+ parameters .update (next_parameters )
1307+ parameter_dict .update (next_parameters )
1308+ check ()
1309+
1310+ next_parameters = {
1311+ 'p8' : Parameter (torch .randn (10 , 10 )),
1312+ 'p7' : Parameter (torch .randn (10 , 10 ))
1313+ }
1314+ parameters .update (sorted (next_parameters .items ()))
1315+ parameter_dict .update (next_parameters )
1316+ check ()
1317+
1318+ del parameter_dict ['p3' ]
1319+ del parameters ['p3' ]
1320+ check ()
1321+
1322+ with self .assertRaises (TypeError ):
1323+ parameter_dict .update (1 )
1324+
1325+ with self .assertRaises (TypeError ):
1326+ parameter_dict .update ([1 ])
1327+
1328+ with self .assertRaises (ValueError ):
1329+ parameter_dict .update (Parameter (torch .randn (10 , 10 )))
1330+
1331+ with self .assertRaises (TypeError ):
1332+ parameter_dict [1 ] = Parameter (torch .randn (10 , 10 ))
1333+
1334+ p_pop = parameter_dict .pop ('p4' )
1335+ self .assertIs (p_pop , parameters ['p4' ])
1336+ parameters .pop ('p4' )
1337+ check ()
1338+
1339+ parameter_dict .clear ()
1340+ self .assertEqual (len (parameter_dict ), 0 )
1341+ parameters .clear ()
1342+ check ()
1343+
11771344 def test_add_module (self ):
11781345 l = nn .Linear (10 , 20 )
11791346 net = nn .Module ()
0 commit comments