|
| 1 | +import warnings |
1 | 2 | from collections import OrderedDict, Iterable |
2 | | -import string |
| 3 | +from itertools import islice |
| 4 | + |
3 | 5 | import torch |
4 | | -import warnings |
5 | 6 | from .module import Module |
6 | 7 |
|
7 | 8 |
|
@@ -49,18 +50,23 @@ def __init__(self, *args): |
49 | 50 | for idx, module in enumerate(args): |
50 | 51 | self.add_module(str(idx), module) |
51 | 52 |
|
| 53 | + def _get_item_by_idx(self, iterator, idx): |
| 54 | + """Get the idx-th item of the iterator""" |
| 55 | + size = len(self) |
| 56 | + if not -size <= idx < size: |
| 57 | + raise IndexError('index {} is out of range'.format(idx)) |
| 58 | + idx %= size |
| 59 | + return next(islice(iterator, idx, None)) |
| 60 | + |
52 | 61 | def __getitem__(self, idx): |
53 | 62 | if isinstance(idx, slice): |
54 | 63 | return Sequential(OrderedDict(list(self._modules.items())[idx])) |
55 | 64 | else: |
56 | | - if not (-len(self) <= idx < len(self)): |
57 | | - raise IndexError('index {} is out of range'.format(idx)) |
58 | | - if idx < 0: |
59 | | - idx += len(self) |
60 | | - it = iter(self._modules.values()) |
61 | | - for i in range(idx): |
62 | | - next(it) |
63 | | - return next(it) |
| 65 | + return self._get_item_by_idx(self._modules.values(), idx) |
| 66 | + |
| 67 | + def __setitem__(self, idx, module): |
| 68 | + key = self._get_item_by_idx(self._modules.keys(), idx) |
| 69 | + return setattr(self, key, module) |
64 | 70 |
|
65 | 71 | def __len__(self): |
66 | 72 | return len(self._modules) |
|
0 commit comments