Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
51409e1
Add ModuleDict
karandwivedi42 Jun 13, 2018
cd017bd
Better clear
karandwivedi42 Jun 13, 2018
ce95158
Fix flake8
karandwivedi42 Jun 13, 2018
e67eefe
Fix update from dict
karandwivedi42 Jun 14, 2018
55308e3
Fix tests
karandwivedi42 Jun 14, 2018
6637306
Trigger build
karandwivedi42 Jun 14, 2018
9bc5d67
Fix test and remove str
karandwivedi42 Jun 14, 2018
6cd94df
oops
karandwivedi42 Jun 14, 2018
7430b25
trigger build
karandwivedi42 Jun 14, 2018
b22491b
Check for valid iterable in update
karandwivedi42 Jun 15, 2018
8c6290f
lint
karandwivedi42 Jun 15, 2018
7f29793
Check for valid iterable in update
karandwivedi42 Jun 15, 2018
c94f908
Trigger build
karandwivedi42 Jun 18, 2018
79578f9
Remove initialization from named_paramters and other fixes
karandwivedi42 Jun 18, 2018
bc7a2a2
Trigger build
karandwivedi42 Jun 18, 2018
9b6b56c
trigger build
karandwivedi42 Jun 19, 2018
376d52b
Trigger build
karandwivedi42 Jun 20, 2018
ca8746d
Disallow update from dict
karandwivedi42 Jun 21, 2018
a5ee8c6
Add sorting instead for more friendly initialization
karandwivedi42 Jun 21, 2018
fc1ec62
remove redundant checks
karandwivedi42 Jun 21, 2018
326419a
trigger build
karandwivedi42 Jun 22, 2018
df82ee8
Only sort if not ordereddict
karandwivedi42 Jun 22, 2018
c8c38dc
fix test
karandwivedi42 Jun 22, 2018
544816e
Add key should be string
karandwivedi42 Jun 26, 2018
731a13c
oops
karandwivedi42 Jun 26, 2018
594d20b
oops
karandwivedi42 Jun 26, 2018
302e5ed
Remove checks because they will be added to module.py
karandwivedi42 Jun 26, 2018
b0c28ec
oops
karandwivedi42 Jun 26, 2018
683841f
Remove extra newline at end
karandwivedi42 Jun 29, 2018
008f808
Merge branch 'master' into moduledict
karandwivedi42 Jun 29, 2018
80112e1
trigger build
karandwivedi42 Jul 2, 2018
e67716c
Remove unicode
karandwivedi42 Jul 8, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,24 @@ Containers
.. autoclass:: ModuleList
:members:

:hidden:`ModuleDict`
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ModuleDict
:members:

:hidden:`ParameterList`
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ParameterList
:members:

:hidden:`ParameterDict`
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ParameterDict
:members:

Convolution layers
----------------------------------

Expand Down
167 changes: 167 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,91 @@ def check():
module_list.extend(s.modules())
check()

def test_ModuleDict(self):
modules = OrderedDict([
('act', nn.ReLU()),
('conv', nn.Conv2d(10, 10, 5)),
('fc', nn.Linear(5, 5)),
])

module_dict = nn.ModuleDict(modules)

def check():
self.assertEqual(len(module_dict), len(modules))
for k1, m2 in zip(modules, module_dict.children()):
self.assertIs(modules[k1], m2)
for k1, k2 in zip(modules, module_dict):
self.assertIs(modules[k1], module_dict[k2])
for k in module_dict:
self.assertIs(module_dict[k], modules[k])
for k in module_dict.keys():
self.assertIs(module_dict[k], modules[k])
for k, v in module_dict.items():
self.assertIs(modules[k], v)
for k1, m2 in zip(modules, module_dict.values()):
self.assertIs(modules[k1], m2)
for k in modules.keys():
self.assertTrue(k in module_dict)
check()

modules['conv'] = nn.Conv2d(3, 4, 3)
module_dict['conv'] = modules['conv']
check()

next_modules = [
('fc2', nn.Linear(5, 5)),
('act', nn.Sigmoid()),
]
modules.update(next_modules)
module_dict.update(next_modules)
check()

next_modules = OrderedDict([
('fc3', nn.Linear(5, 5)),
('act2', nn.Sigmoid()),
])
modules.update(next_modules)
module_dict.update(next_modules)
check()

next_modules = {
'fc4': nn.Linear(5, 5),
'act3': nn.Sigmoid()
}
modules.update(sorted(next_modules.items()))
module_dict.update(next_modules)
check()

del module_dict['fc']
del modules['fc']
check()

with self.assertRaises(TypeError):
module_dict.update(nn.ReLU())

with self.assertRaises(TypeError):
module_dict.update([nn.ReLU()])

with self.assertRaises(ValueError):
module_dict.update([[nn.ReLU()]])

with self.assertRaises(TypeError):
module_dict[1] = nn.ReLU()

s = nn.Sequential(modules)
module_dict = nn.ModuleDict(s.named_children())
check()

c = module_dict.pop('conv')
self.assertIs(c, modules['conv'])
modules.pop('conv')
check()

module_dict.clear()
self.assertEqual(len(module_dict), 0)
modules.clear()
check()

def test_ParameterList(self):
def make_param():
return Parameter(torch.randn(10, 10))
Expand Down Expand Up @@ -1176,6 +1261,88 @@ def check():
param_list.extend(s.parameters())
check()

def test_ParameterDict(self):
parameters = OrderedDict([
('p1', Parameter(torch.randn(10, 10))),
('p2', Parameter(torch.randn(10, 10))),
('p3', Parameter(torch.randn(10, 10))),
])

parameter_dict = nn.ParameterDict(parameters)

def check():
self.assertEqual(len(parameter_dict), len(parameters))
for k1, m2 in zip(parameters, parameter_dict.parameters()):
self.assertIs(parameters[k1], m2)
for k1, k2 in zip(parameters, parameter_dict):
self.assertIs(parameters[k1], parameter_dict[k2])
for k in parameter_dict:
self.assertIs(parameter_dict[k], parameters[k])
for k in parameter_dict.keys():
self.assertIs(parameter_dict[k], parameters[k])
for k, v in parameter_dict.items():
self.assertIs(v, parameters[k])
for k1, m2 in zip(parameters, parameter_dict.values()):
self.assertIs(parameters[k1], m2)
for k in parameters.keys():
self.assertTrue(k in parameter_dict)

check()

parameters['p4'] = Parameter(torch.randn(10, 10))
parameter_dict['p4'] = parameters['p4']
check()

next_parameters = [
('p5', Parameter(torch.randn(10, 10))),
('p2', Parameter(torch.randn(10, 10))),
]
parameters.update(next_parameters)
parameter_dict.update(next_parameters)
check()

next_parameters = OrderedDict([
('p6', Parameter(torch.randn(10, 10))),
('p5', Parameter(torch.randn(10, 10))),
])
parameters.update(next_parameters)
parameter_dict.update(next_parameters)
check()

next_parameters = {
'p8': Parameter(torch.randn(10, 10)),
'p7': Parameter(torch.randn(10, 10))
}
parameters.update(sorted(next_parameters.items()))
parameter_dict.update(next_parameters)
check()

del parameter_dict['p3']
del parameters['p3']
check()

with self.assertRaises(TypeError):
parameter_dict.update(1)

with self.assertRaises(TypeError):
parameter_dict.update([1])

with self.assertRaises(ValueError):
parameter_dict.update(Parameter(torch.randn(10, 10)))

with self.assertRaises(TypeError):
parameter_dict[1] = Parameter(torch.randn(10, 10))

p_pop = parameter_dict.pop('p4')
self.assertIs(p_pop, parameters['p4'])
parameters.pop('p4')
check()

parameter_dict.clear()
self.assertEqual(len(parameter_dict), 0)
parameters.clear()
check()

def test_add_module(self):
l = nn.Linear(10, 20)
net = nn.Module()
Expand Down
6 changes: 3 additions & 3 deletions torch/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
CosineEmbeddingLoss, HingeEmbeddingLoss, MarginRankingLoss, \
MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, \
SmoothL1Loss, SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, PoissonNLLLoss
from .container import Container, Sequential, ModuleList, ParameterList
from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, LPPool1d, LPPool2d, AdaptiveMaxPool1d, \
AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
Expand All @@ -36,8 +36,8 @@
'Tanhshrink', 'RReLU', 'L1Loss', 'NLLLoss', 'KLDivLoss', 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss',
'NLLLoss2d', 'PoissonNLLLoss', 'CosineEmbeddingLoss', 'HingeEmbeddingLoss', 'MarginRankingLoss',
'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'SmoothL1Loss',
'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList',
'ParameterList', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList', 'ModuleDict',
'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d',
'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout',
Expand Down
Loading