Skip to content

Commit f796080

Browse files
Stonesjtuapaszke
authored andcommitted
Add assignment support for Sequential (#4931)
1 parent f160e55 commit f796080

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

test/test_nn.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,33 @@ def test_Sequential_getitem(self):
788788
self.assertEqual(n[:-3], nn.Sequential(l1))
789789
self.assertEqual(n[::-1], nn.Sequential(l4, l3, l2, l1))
790790

791+
def test_Sequential_setitem(self):
792+
l1 = nn.Linear(10, 20)
793+
l2 = nn.Linear(20, 30)
794+
l3 = nn.Linear(30, 40)
795+
l4 = nn.Linear(40, 50)
796+
n = nn.Sequential(l1, l2, l3)
797+
n[0] = l4
798+
n[-1] = l4
799+
self.assertEqual(n[0], l4)
800+
self.assertEqual(n[2], l4)
801+
802+
def test_Sequential_setitem_named(self):
803+
l1 = nn.Linear(10, 20)
804+
l2 = nn.Linear(20, 30)
805+
l3 = nn.Linear(30, 40)
806+
l4 = nn.Linear(40, 50)
807+
n = nn.Sequential(OrderedDict([
808+
('linear1', l1),
809+
('linear2', l2),
810+
('linear3', l3),
811+
]))
812+
813+
n[0] = l4
814+
n[-1] = l4
815+
self.assertEqual(n.linear1, l4)
816+
self.assertEqual(n.linear3, l4)
817+
791818
def test_ModuleList(self):
792819
modules = [nn.ReLU(), nn.Linear(5, 5)]
793820
module_list = nn.ModuleList(modules)

torch/nn/modules/container.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import warnings
12
from collections import OrderedDict, Iterable
2-
import string
3+
from itertools import islice
4+
35
import torch
4-
import warnings
56
from .module import Module
67

78

@@ -49,18 +50,23 @@ def __init__(self, *args):
4950
for idx, module in enumerate(args):
5051
self.add_module(str(idx), module)
5152

53+
def _get_item_by_idx(self, iterator, idx):
54+
"""Get the idx-th item of the iterator"""
55+
size = len(self)
56+
if not -size <= idx < size:
57+
raise IndexError('index {} is out of range'.format(idx))
58+
idx %= size
59+
return next(islice(iterator, idx, None))
60+
5261
def __getitem__(self, idx):
5362
if isinstance(idx, slice):
5463
return Sequential(OrderedDict(list(self._modules.items())[idx]))
5564
else:
56-
if not (-len(self) <= idx < len(self)):
57-
raise IndexError('index {} is out of range'.format(idx))
58-
if idx < 0:
59-
idx += len(self)
60-
it = iter(self._modules.values())
61-
for i in range(idx):
62-
next(it)
63-
return next(it)
65+
return self._get_item_by_idx(self._modules.values(), idx)
66+
67+
def __setitem__(self, idx, module):
68+
key = self._get_item_by_idx(self._modules.keys(), idx)
69+
return setattr(self, key, module)
6470

6571
def __len__(self):
6672
return len(self._modules)

0 commit comments

Comments
 (0)