|
99 | 99 | from contextlib import contextmanager |
100 | 100 | from functools import partial, reduce |
101 | 101 | from copy import deepcopy |
| 102 | +import itertools |
102 | 103 | from math import floor, log10 |
103 | 104 | import os |
104 | 105 | import re |
@@ -223,10 +224,13 @@ def _shift_labels(obj, label_type, args): |
223 | 224 | return obj.add_labels(label_type, args) |
224 | 225 |
|
225 | 226 | if all(_is_label(arg) for arg in args): |
226 | | - return obj.add_labels(label_type, [arg for arg in args]) |
| 227 | + return obj.add_labels(label_type, args) |
227 | 228 |
|
228 | | - assert len(args) == 2 and _is_label(args[0]) |
229 | | - return obj.add_labels(label_type, {args[1]: args[0]}) |
| 229 | + is_dst = label_type == "dst" |
| 230 | + assert len(args) == 2 and _is_label(args[0 if is_dst else 1]) |
| 231 | + return obj.add_labels( |
| 232 | + label_type, {obj._resolve_index(is_dst, args[is_dst]): args[not is_dst]} |
| 233 | + ) |
230 | 234 |
|
231 | 235 |
|
232 | 236 | ################################################################################################### |
@@ -2543,9 +2547,16 @@ def add_labels(self, pad_type, labels): |
2543 | 2547 | pad = fg._resolve_index(is_input, None) |
2544 | 2548 | fg.add_label(label, **{pad_type: pad}) |
2545 | 2549 | else: |
2546 | | - for label in labels: |
2547 | | - pad = fg._resolve_index(is_input, None) |
2548 | | - fg.add_label(label, **{pad_type: pad}) |
| 2550 | + pads = list( |
| 2551 | + itertools.islice( |
| 2552 | + fg.iter_input_pads(exclude_named=True) |
| 2553 | + if pad_type == "dst" |
| 2554 | + else fg.iter_output_pads(exclude_named=True), |
| 2555 | + len(labels), |
| 2556 | + ) |
| 2557 | + ) |
| 2558 | + for label, pad in zip(labels, pads): |
| 2559 | + fg.add_label(label, **{pad_type: pad[0]}) |
2549 | 2560 | return fg |
2550 | 2561 |
|
2551 | 2562 | @contextmanager |
|
0 commit comments