Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions torch/nn/modules/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@
T = TypeVar('T', bound=Module)


# Copied from torch.nn.modules.module, required for a cusom __repr__ for ModuleList
def _addindent(s_, numSpaces):
s = s_.split('\n')
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s


class Container(Module):

def __init__(self, **kwargs: Any) -> None:
Expand Down Expand Up @@ -312,6 +325,38 @@ def __add__(self, other: Iterable[Module]) -> 'ModuleList':
combined.add_module(str(i), module)
return combined

def __repr__(self):
"""A custom repr for ModuleList that compresses repeated module representations"""
list_of_reprs = [repr(item) for item in self]
if len(list_of_reprs) == 0:
return self._get_name() + '()'

start_end_indices = [[0, 0]]
repeated_blocks = [list_of_reprs[0]]
for i, r in enumerate(list_of_reprs[1:], 1):
if r == repeated_blocks[-1]:
start_end_indices[-1][1] += 1
continue

start_end_indices.append([i, i])
repeated_blocks.append(r)

lines = []
main_str = self._get_name() + '('
for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
local_repr = f"({start_id}): {b}" # default repr

if start_id != end_id:
n = end_id - start_id + 1
local_repr = f"({start_id}-{end_id}): {n} x {b}"

local_repr = _addindent(local_repr, 2)
lines.append(local_repr)

main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
return main_str

@_copy_to_script_wrapper
def __dir__(self):
keys = super(ModuleList, self).__dir__()
Expand Down