
Commit 0b255b3

Guitaricet authored and pytorchmergebot committed
Better __repr__ for ModuleList (#90452)
## Problem

When models contain many complex repeated layers, the `print(module)` output becomes unwieldy to work with. For example, the current `__repr__` output for `t5-small` is `715` lines long.

## Solution

With the better `__repr__` it becomes `135` lines. For `t5-large`, the current `__repr__` prints `1411` lines; the better `__repr__` prints `135`, the same number as for `t5-small`, because most of the layers are simply repeated. For `EleutherAI/gpt-j-6B` the number of lines drops from `483` to just `24`.

Here's how it works: when consecutive ModuleList items have exactly the same `__repr__`, instead of printing each of them it prints a single `N x {repr(item)}` entry. The current code supports cases where the same ModuleList contains multiple repeated runs, which is especially useful when the first/last layer of a block is different from the rest of them. The better `__repr__` should make model prints smaller, more readable, and significantly more useful by highlighting the difference between repeated blocks instead of losing it in a wall of text.

## Motivating real-life example

You can try it out in this [colab notebook](https://colab.research.google.com/drive/1PscpX_K1UemIDotl2raC4QMy_pTqDq7p?usp=sharing). The current `__repr__` output for `gpt-j-6B` is too big to add to this PR description in full:

```
GPTJModel(
  (wte): Embedding(50400, 4096)
  (drop): Dropout(p=0.0, inplace=False)
  (h): ModuleList(
    (0): GPTJBlock(
      (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (attn): GPTJAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (mlp): GPTJMLP(
        (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
        (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (1): GPTJBlock(
      (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (attn): GPTJAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (mlp): GPTJMLP(
        (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
        (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (2): GPTJBlock(
    ...
```

Better `__repr__` output looks like this:

```
GPTJModel(
  (wte): Embedding(50400, 4096)
  (drop): Dropout(p=0.0, inplace=False)
  (h): ModuleList(
    (0-27): 28 x GPTJBlock(
      (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (attn): GPTJAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (mlp): GPTJMLP(
        (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
        (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
)
```

Pull Request resolved: #90452
Approved by: https://github.com/albanD
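As a quick illustration of the new behavior, here is a minimal sketch (not part of the PR; it assumes a PyTorch build that includes this change, and the commented output is approximate):

```python
import torch.nn as nn

# Three identical Linear layers followed by one different layer: the identical
# ones are collapsed into a single "(0-2): 3 x ..." entry by the new __repr__.
layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)] + [nn.Linear(4, 2)])
print(layers)
# ModuleList(
#   (0-2): 3 x Linear(in_features=4, out_features=4, bias=True)
#   (3): Linear(in_features=4, out_features=2, bias=True)
# )
```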
1 parent 57dcd93 commit 0b255b3


torch/nn/modules/container.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -15,6 +15,19 @@
 T = TypeVar('T', bound=Module)
 
 
+# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
+def _addindent(s_, numSpaces):
+    s = s_.split('\n')
+    # don't do anything for single-line stuff
+    if len(s) == 1:
+        return s_
+    first = s.pop(0)
+    s = [(numSpaces * ' ') + line for line in s]
+    s = '\n'.join(s)
+    s = first + '\n' + s
+    return s
+
+
 class Container(Module):
 
     def __init__(self, **kwargs: Any) -> None:
@@ -312,6 +325,38 @@ def __add__(self, other: Iterable[Module]) -> 'ModuleList':
             combined.add_module(str(i), module)
         return combined
 
+    def __repr__(self):
+        """A custom repr for ModuleList that compresses repeated module representations."""
+        list_of_reprs = [repr(item) for item in self]
+        if len(list_of_reprs) == 0:
+            return self._get_name() + '()'
+
+        start_end_indices = [[0, 0]]
+        repeated_blocks = [list_of_reprs[0]]
+        for i, r in enumerate(list_of_reprs[1:], 1):
+            if r == repeated_blocks[-1]:
+                start_end_indices[-1][1] += 1
+                continue
+
+            start_end_indices.append([i, i])
+            repeated_blocks.append(r)
+
+        lines = []
+        main_str = self._get_name() + '('
+        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
+            local_repr = f"({start_id}): {b}"  # default repr
+
+            if start_id != end_id:
+                n = end_id - start_id + 1
+                local_repr = f"({start_id}-{end_id}): {n} x {b}"
+
+            local_repr = _addindent(local_repr, 2)
+            lines.append(local_repr)
+
+        main_str += '\n  ' + '\n  '.join(lines) + '\n'
+        main_str += ')'
+        return main_str
+
     @_copy_to_script_wrapper
     def __dir__(self):
         keys = super(ModuleList, self).__dir__()
```
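For reference, the run-compression loop in the `__repr__` above can be exercised in isolation. This is a minimal sketch; the `group_runs` helper name and the string inputs are illustrative and not part of the PR:

```python
# Sketch of the grouping step: consecutive equal reprs are merged into
# [start, end] index runs, each paired with the repeated repr string.
def group_runs(list_of_reprs):
    start_end_indices = [[0, 0]]
    repeated_blocks = [list_of_reprs[0]]
    for i, r in enumerate(list_of_reprs[1:], 1):
        if r == repeated_blocks[-1]:
            start_end_indices[-1][1] += 1
            continue
        start_end_indices.append([i, i])
        repeated_blocks.append(r)
    return list(zip(start_end_indices, repeated_blocks))

print(group_runs(["Linear(4, 4)"] * 3 + ["ReLU()"]))
# [([0, 2], 'Linear(4, 4)'), ([3, 3], 'ReLU()')]
```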
