Skip to content

Commit 123f49b

Browse files
vishwakftwapaszke
authored andcommitted
Add Slicing capabilities for Sequential, ModuleList and ParameterList (#4491)
1 parent 9c2561e commit 123f49b

File tree

2 files changed

+42
-18
lines changed

2 files changed

+42
-18
lines changed

test/test_nn.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,11 @@ def test_Sequential_getitem(self):
714714
self.assertEqual(n[1], l2)
715715
self.assertEqual(n[2], l3)
716716
self.assertEqual(n[3], l4)
717+
self.assertEqual(n[1:], nn.Sequential(l2, l3, l4))
718+
self.assertEqual(n[3:], nn.Sequential(l4))
719+
self.assertEqual(n[:-1], nn.Sequential(l1, l2, l3))
720+
self.assertEqual(n[:-3], nn.Sequential(l1))
721+
self.assertEqual(n[::-1], nn.Sequential(l4, l3, l2, l1))
717722

718723
def test_ModuleList(self):
719724
modules = [nn.ReLU(), nn.Linear(5, 5)]
@@ -742,6 +747,11 @@ def check():
742747
modules[2] = nn.Conv2d(5, 3, 2)
743748
module_list[2] = modules[2]
744749
check()
750+
self.assertEqual(module_list[1:], nn.ModuleList(modules[1:]))
751+
self.assertEqual(module_list[3:], nn.ModuleList(modules[3:]))
752+
self.assertEqual(module_list[:-1], nn.ModuleList(modules[:-1]))
753+
self.assertEqual(module_list[:-3], nn.ModuleList(modules[:-3]))
754+
self.assertEqual(module_list[::-1], nn.ModuleList(modules[::-1]))
745755

746756
with self.assertRaises(TypeError):
747757
module_list += nn.ReLU()
@@ -796,6 +806,11 @@ def check():
796806
parameters[2] = make_param()
797807
param_list[2] = parameters[2]
798808
check()
809+
self.assertEqual(param_list[1:], nn.ParameterList(parameters[1:]))
810+
self.assertEqual(param_list[3:], nn.ParameterList(parameters[3:]))
811+
self.assertEqual(param_list[:-1], nn.ParameterList(parameters[:-1]))
812+
self.assertEqual(param_list[:-3], nn.ParameterList(parameters[:-3]))
813+
self.assertEqual(param_list[::-1], nn.ParameterList(parameters[::-1]))
799814

800815
with self.assertRaises(TypeError):
801816
param_list += make_param()

torch/nn/modules/container.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,17 @@ def __init__(self, *args):
5050
self.add_module(str(idx), module)
5151

5252
def __getitem__(self, idx):
53-
if not (-len(self) <= idx < len(self)):
54-
raise IndexError('index {} is out of range'.format(idx))
55-
if idx < 0:
56-
idx += len(self)
57-
it = iter(self._modules.values())
58-
for i in range(idx):
59-
next(it)
60-
return next(it)
53+
if isinstance(idx, slice):
54+
return Sequential(OrderedDict(list(self._modules.items())[idx]))
55+
else:
56+
if not (-len(self) <= idx < len(self)):
57+
raise IndexError('index {} is out of range'.format(idx))
58+
if idx < 0:
59+
idx += len(self)
60+
it = iter(self._modules.values())
61+
for i in range(idx):
62+
next(it)
63+
return next(it)
6164

6265
def __len__(self):
6366
return len(self._modules)
@@ -102,11 +105,14 @@ def __init__(self, modules=None):
102105
self += modules
103106

104107
def __getitem__(self, idx):
105-
if not (-len(self) <= idx < len(self)):
106-
raise IndexError('index {} is out of range'.format(idx))
107-
if idx < 0:
108-
idx += len(self)
109-
return self._modules[str(idx)]
108+
if isinstance(idx, slice):
109+
return ModuleList(list(self._modules.values())[idx])
110+
else:
111+
if not (-len(self) <= idx < len(self)):
112+
raise IndexError('index {} is out of range'.format(idx))
113+
if idx < 0:
114+
idx += len(self)
115+
return self._modules[str(idx)]
110116

111117
def __setitem__(self, idx, module):
112118
return setattr(self, str(idx), module)
@@ -178,11 +184,14 @@ def __init__(self, parameters=None):
178184
self += parameters
179185

180186
def __getitem__(self, idx):
181-
if not (-len(self) <= idx < len(self)):
182-
raise IndexError('index {} is out of range'.format(idx))
183-
if idx < 0:
184-
idx += len(self)
185-
return self._parameters[str(idx)]
187+
if isinstance(idx, slice):
188+
return ParameterList(list(self._parameters.values())[idx])
189+
else:
190+
if not (-len(self) <= idx < len(self)):
191+
raise IndexError('index {} is out of range'.format(idx))
192+
if idx < 0:
193+
idx += len(self)
194+
return self._parameters[str(idx)]
186195

187196
def __setitem__(self, idx, param):
188197
return self.register_parameter(str(idx), param)

0 commit comments

Comments
 (0)