Skip to content

Commit 97008a6

Browse files
karandwivedi42facebook-github-bot
authored andcommitted
Add ModuleDict and ParameterDict containers (#8463)
Summary: Addresses: #4048 and #5297 (comment) Pull Request resolved: #8463 Reviewed By: SsnL Differential Revision: D8689291 Pulled By: ezyang fbshipit-source-id: 47e67d9bae1b64ec10771a2c00c56229463b1598
1 parent cffca29 commit 97008a6

File tree

4 files changed

+423
-4
lines changed

4 files changed

+423
-4
lines changed

docs/source/nn.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,24 @@ Containers
3434
.. autoclass:: ModuleList
3535
:members:
3636

37+
:hidden:`ModuleDict`
38+
~~~~~~~~~~~~~~~~~~~~
39+
40+
.. autoclass:: ModuleDict
41+
:members:
42+
3743
:hidden:`ParameterList`
3844
~~~~~~~~~~~~~~~~~~~~~~~
3945

4046
.. autoclass:: ParameterList
4147
:members:
4248

49+
:hidden:`ParameterDict`
50+
~~~~~~~~~~~~~~~~~~~~~~~
51+
52+
.. autoclass:: ParameterDict
53+
:members:
54+
4355
Convolution layers
4456
----------------------------------
4557

test/test_nn.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,91 @@ def check():
11101110
module_list.extend(s.modules())
11111111
check()
11121112

1113+
def test_ModuleDict(self):
1114+
modules = OrderedDict([
1115+
('act', nn.ReLU()),
1116+
('conv', nn.Conv2d(10, 10, 5)),
1117+
('fc', nn.Linear(5, 5)),
1118+
])
1119+
1120+
module_dict = nn.ModuleDict(modules)
1121+
1122+
def check():
1123+
self.assertEqual(len(module_dict), len(modules))
1124+
for k1, m2 in zip(modules, module_dict.children()):
1125+
self.assertIs(modules[k1], m2)
1126+
for k1, k2 in zip(modules, module_dict):
1127+
self.assertIs(modules[k1], module_dict[k2])
1128+
for k in module_dict:
1129+
self.assertIs(module_dict[k], modules[k])
1130+
for k in module_dict.keys():
1131+
self.assertIs(module_dict[k], modules[k])
1132+
for k, v in module_dict.items():
1133+
self.assertIs(modules[k], v)
1134+
for k1, m2 in zip(modules, module_dict.values()):
1135+
self.assertIs(modules[k1], m2)
1136+
for k in modules.keys():
1137+
self.assertTrue(k in module_dict)
1138+
check()
1139+
1140+
modules['conv'] = nn.Conv2d(3, 4, 3)
1141+
module_dict['conv'] = modules['conv']
1142+
check()
1143+
1144+
next_modules = [
1145+
('fc2', nn.Linear(5, 5)),
1146+
('act', nn.Sigmoid()),
1147+
]
1148+
modules.update(next_modules)
1149+
module_dict.update(next_modules)
1150+
check()
1151+
1152+
next_modules = OrderedDict([
1153+
('fc3', nn.Linear(5, 5)),
1154+
('act2', nn.Sigmoid()),
1155+
])
1156+
modules.update(next_modules)
1157+
module_dict.update(next_modules)
1158+
check()
1159+
1160+
next_modules = {
1161+
'fc4': nn.Linear(5, 5),
1162+
'act3': nn.Sigmoid()
1163+
}
1164+
modules.update(sorted(next_modules.items()))
1165+
module_dict.update(next_modules)
1166+
check()
1167+
1168+
del module_dict['fc']
1169+
del modules['fc']
1170+
check()
1171+
1172+
with self.assertRaises(TypeError):
1173+
module_dict.update(nn.ReLU())
1174+
1175+
with self.assertRaises(TypeError):
1176+
module_dict.update([nn.ReLU()])
1177+
1178+
with self.assertRaises(ValueError):
1179+
module_dict.update([[nn.ReLU()]])
1180+
1181+
with self.assertRaises(TypeError):
1182+
module_dict[1] = nn.ReLU()
1183+
1184+
s = nn.Sequential(modules)
1185+
module_dict = nn.ModuleDict(s.named_children())
1186+
check()
1187+
1188+
c = module_dict.pop('conv')
1189+
self.assertIs(c, modules['conv'])
1190+
modules.pop('conv')
1191+
check()
1192+
1193+
module_dict.clear()
1194+
self.assertEqual(len(module_dict), 0)
1195+
modules.clear()
1196+
check()
1197+
11131198
def test_ParameterList(self):
11141199
def make_param():
11151200
return Parameter(torch.randn(10, 10))
@@ -1174,6 +1259,88 @@ def check():
11741259
param_list.extend(s.parameters())
11751260
check()
11761261

1262+
def test_ParameterDict(self):
1263+
parameters = OrderedDict([
1264+
('p1', Parameter(torch.randn(10, 10))),
1265+
('p2', Parameter(torch.randn(10, 10))),
1266+
('p3', Parameter(torch.randn(10, 10))),
1267+
])
1268+
1269+
parameter_dict = nn.ParameterDict(parameters)
1270+
1271+
def check():
1272+
self.assertEqual(len(parameter_dict), len(parameters))
1273+
for k1, m2 in zip(parameters, parameter_dict.parameters()):
1274+
self.assertIs(parameters[k1], m2)
1275+
for k1, k2 in zip(parameters, parameter_dict):
1276+
self.assertIs(parameters[k1], parameter_dict[k2])
1277+
for k in parameter_dict:
1278+
self.assertIs(parameter_dict[k], parameters[k])
1279+
for k in parameter_dict.keys():
1280+
self.assertIs(parameter_dict[k], parameters[k])
1281+
for k, v in parameter_dict.items():
1282+
self.assertIs(v, parameters[k])
1283+
for k1, m2 in zip(parameters, parameter_dict.values()):
1284+
self.assertIs(parameters[k1], m2)
1285+
for k in parameters.keys():
1286+
self.assertTrue(k in parameter_dict)
1287+
1288+
check()
1289+
1290+
parameters['p4'] = Parameter(torch.randn(10, 10))
1291+
parameter_dict['p4'] = parameters['p4']
1292+
check()
1293+
1294+
next_parameters = [
1295+
('p5', Parameter(torch.randn(10, 10))),
1296+
('p2', Parameter(torch.randn(10, 10))),
1297+
]
1298+
parameters.update(next_parameters)
1299+
parameter_dict.update(next_parameters)
1300+
check()
1301+
1302+
next_parameters = OrderedDict([
1303+
('p6', Parameter(torch.randn(10, 10))),
1304+
('p5', Parameter(torch.randn(10, 10))),
1305+
])
1306+
parameters.update(next_parameters)
1307+
parameter_dict.update(next_parameters)
1308+
check()
1309+
1310+
next_parameters = {
1311+
'p8': Parameter(torch.randn(10, 10)),
1312+
'p7': Parameter(torch.randn(10, 10))
1313+
}
1314+
parameters.update(sorted(next_parameters.items()))
1315+
parameter_dict.update(next_parameters)
1316+
check()
1317+
1318+
del parameter_dict['p3']
1319+
del parameters['p3']
1320+
check()
1321+
1322+
with self.assertRaises(TypeError):
1323+
parameter_dict.update(1)
1324+
1325+
with self.assertRaises(TypeError):
1326+
parameter_dict.update([1])
1327+
1328+
with self.assertRaises(ValueError):
1329+
parameter_dict.update(Parameter(torch.randn(10, 10)))
1330+
1331+
with self.assertRaises(TypeError):
1332+
parameter_dict[1] = Parameter(torch.randn(10, 10))
1333+
1334+
p_pop = parameter_dict.pop('p4')
1335+
self.assertIs(p_pop, parameters['p4'])
1336+
parameters.pop('p4')
1337+
check()
1338+
1339+
parameter_dict.clear()
1340+
self.assertEqual(len(parameter_dict), 0)
1341+
parameters.clear()
1342+
check()
1343+
11771344
def test_add_module(self):
11781345
l = nn.Linear(10, 20)
11791346
net = nn.Module()

torch/nn/modules/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
CosineEmbeddingLoss, HingeEmbeddingLoss, MarginRankingLoss, \
1010
MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, \
1111
SmoothL1Loss, SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, PoissonNLLLoss
12-
from .container import Container, Sequential, ModuleList, ParameterList
12+
from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
1313
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
1414
MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, LPPool1d, LPPool2d, AdaptiveMaxPool1d, \
1515
AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
@@ -36,8 +36,8 @@
3636
'Tanhshrink', 'RReLU', 'L1Loss', 'NLLLoss', 'KLDivLoss', 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss',
3737
'NLLLoss2d', 'PoissonNLLLoss', 'CosineEmbeddingLoss', 'HingeEmbeddingLoss', 'MarginRankingLoss',
3838
'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'SmoothL1Loss',
39-
'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList',
40-
'ParameterList', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
39+
'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList', 'ModuleDict',
40+
'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
4141
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d',
4242
'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
4343
'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout',

0 commit comments

Comments
 (0)