|
14 | 14 | # ============================================================================== |
15 | 15 | """Utilities for graph plugin.""" |
16 | 16 |
|
| 17 | +from tensorboard.compat.proto import graph_pb2 |
17 | 18 |
|
18 | | -class _ProtoListDuplicateKeyError(Exception): |
19 | | - pass |
20 | 19 |
|
| 20 | +def _prefixed_op_name(prefix, op_name): |
| 21 | + return "%s/%s" % (prefix, op_name) |
21 | 22 |
|
22 | | -class _SameKeyDiffContentError(Exception): |
23 | | - pass |
24 | 23 |
|
| 24 | +def _prefixed_func_name(prefix, func_name): |
| 25 | + # TODO(stephanwlee): add business logic to strip "__inference_". |
| 26 | + return "%s_%s" % (prefix, func_name) |
25 | 27 |
|
26 | | -def _safe_copy_proto_list_values(dst_proto_list, src_proto_list, get_key): |
27 | | - """Safely merge values from `src_proto_list` into `dst_proto_list`. |
28 | 28 |
|
29 | | - Each element in `dst_proto_list` must be mapped by `get_key` to a key |
30 | | - value that is unique within that list; likewise for `src_proto_list`. |
31 | | - If an element of `src_proto_list` has the same key as an existing |
32 | | - element in `dst_proto_list`, then the elements must also be equal. |
| 29 | +def _prepend_names(prefix, orig_graph_def): |
| 30 | + mut_graph_def = graph_pb2.GraphDef() |
| 31 | + for node in orig_graph_def.node: |
| 32 | + new_node = mut_graph_def.node.add() |
| 33 | + new_node.CopyFrom(node) |
| 34 | + new_node.name = _prefixed_op_name(prefix, node.name) |
| 35 | + new_node.input[:] = [ |
| 36 | + _prefixed_op_name(prefix, input_name) for input_name in node.input |
| 37 | + ] |
33 | 38 |
|
34 | | - Args: |
35 | | - dst_proto_list: A `RepeatedCompositeContainer` or |
36 | | - `RepeatedScalarContainer` into which values should be copied. |
37 | | - src_proto_list: A container holding the same kind of values as in |
38 | | - `dst_proto_list` from which values should be copied. |
39 | | - get_key: A function that takes an element of `dst_proto_list` or |
40 | | - `src_proto_list` and returns a key, such that if two elements have |
41 | | - the same key then it is required that they be deep-equal. For |
42 | | - instance, if `dst_proto_list` is a list of nodes, then `get_key` |
43 | | - might be `lambda node: node.name` to indicate that if two nodes |
44 | | - have the same name then they must be the same node. All keys must |
45 | | - be hashable. |
| 39 | + # Remap tf.function method name in the PartitionedCall. 'f' is short for |
| 40 | + # function. |
| 41 | + if new_node.op == "PartitionedCall" and new_node.attr["f"]: |
| 42 | + |
| 43 | + new_node.attr["f"].func.name = _prefixed_func_name( |
| 44 | + prefix, new_node.attr["f"].func.name, |
| 45 | + ) |
| 46 | + |
| 47 | + for func in orig_graph_def.library.function: |
| 48 | + new_func = mut_graph_def.library.function.add() |
| 49 | + new_func.CopyFrom(func) |
| 50 | + # Not creating a structure out of factored out function. They already |
| 51 | + # create an awkward hierarchy and one for each graph. |
| 52 | + new_func.signature.name = _prefixed_func_name( |
| 53 | + prefix, new_func.signature.name |
| 54 | + ) |
| 55 | + |
| 56 | + for gradient in orig_graph_def.library.gradient: |
| 57 | + new_gradient = mut_graph_def.library.gradient.add() |
| 58 | + new_gradient.CopyFrom(gradient) |
| 59 | + new_gradient.function_name = _prefixed_func_name( |
| 60 | + prefix, new_gradient.function_name, |
| 61 | + ) |
| 62 | + new_gradient.gradient_func = _prefixed_func_name( |
| 63 | + prefix, new_gradient.gradient_func, |
| 64 | + ) |
| 65 | + |
| 66 | + return mut_graph_def |
46 | 67 |
|
47 | | - Raises: |
48 | | - _ProtoListDuplicateKeyError: A proto_list contains items with duplicate |
49 | | - keys. |
50 | | - _SameKeyDiffContentError: An item with the same key has different contents. |
51 | | - """ |
52 | 68 |
|
53 | | - def _assert_proto_container_unique_keys(proto_list, get_key): |
54 | | - """Asserts proto_list to only contains unique keys. |
55 | | -
|
56 | | - Args: |
57 | | - proto_list: A `RepeatedCompositeContainer` or `RepeatedScalarContainer`. |
58 | | - get_key: A function that takes an element of `proto_list` and returns a |
59 | | - hashable key. |
60 | | -
|
61 | | - Raises: |
62 | | - _ProtoListDuplicateKeyError: A proto_list contains items with duplicate |
63 | | - keys. |
64 | | - """ |
65 | | - keys = set() |
66 | | - for item in proto_list: |
67 | | - key = get_key(item) |
68 | | - if key in keys: |
69 | | - raise _ProtoListDuplicateKeyError(key) |
70 | | - keys.add(key) |
71 | | - |
72 | | - _assert_proto_container_unique_keys(dst_proto_list, get_key) |
73 | | - _assert_proto_container_unique_keys(src_proto_list, get_key) |
74 | | - |
75 | | - key_to_proto = {} |
76 | | - for proto in dst_proto_list: |
77 | | - key = get_key(proto) |
78 | | - key_to_proto[key] = proto |
79 | | - |
80 | | - for proto in src_proto_list: |
81 | | - key = get_key(proto) |
82 | | - if key in key_to_proto: |
83 | | - if proto != key_to_proto.get(key): |
84 | | - raise _SameKeyDiffContentError(key) |
85 | | - else: |
86 | | - dst_proto_list.add().CopyFrom(proto) |
87 | | - |
88 | | - |
89 | | -def combine_graph_defs(to_proto, from_proto): |
90 | | - """Combines two GraphDefs by adding nodes from from_proto into to_proto. |
| 69 | +def merge_graph_defs(graph_defs): |
| 70 | + """Merges GraphDefs by adding unique prefix, `graph_{ind}`, to names. |
91 | 71 |
|
92 | 72 | All GraphDefs are expected to be of TensorBoard's. |
93 | | - It assumes node names are unique across GraphDefs if contents differ. The |
94 | | - names can be the same if the NodeDef content are exactly the same. |
| 73 | +
|
| 74 | + When collecting graphs using the `tf.summary.trace` API, node names are not |
| 75 | + guranteed to be unique. When non-unique names are not considered, it can |
| 76 | + lead to graph visualization showing them as one which creates inaccurate |
| 77 | + depiction of the flow of the graph (e.g., if there are A -> B -> C and D -> |
| 78 | + B -> E, you may see {A, D} -> B -> E). To prevent such graph, we checked |
| 79 | + for uniquenss while merging but it resulted in |
| 80 | + https://github.com/tensorflow/tensorboard/issues/1929. |
| 81 | +
|
| 82 | + To remedy these issues, we simply "apply name scope" on each graph by |
| 83 | + prefixing it with unique name (with a chance of collision) to create |
| 84 | + unconnected group of graphs. |
| 85 | +
|
| 86 | + In case there is only one graph def passed, it returns the original |
| 87 | + graph_def. In case no graph defs are passed, it returns an empty GraphDef. |
95 | 88 |
|
96 | 89 | Args: |
97 | | - to_proto: A destination TensorBoard GraphDef. |
98 | | - from_proto: A TensorBoard GraphDef to copy contents from. |
| 90 | + graph_defs: TensorBoard GraphDefs to merge. |
99 | 91 |
|
100 | 92 | Returns: |
101 | | - to_proto |
| 93 | + TensorBoard GraphDef that merges all graph_defs with unique prefixes. |
102 | 94 |
|
103 | 95 | Raises: |
104 | | - ValueError in case any assumption about GraphDef is violated: A |
105 | | - GraphDef should have unique node, function, and gradient function |
106 | | - names. Also, when merging GraphDefs, they should have not have nodes, |
107 | | - functions, or gradient function mappings that share the name but details |
108 | | - do not match. |
| 96 | + ValueError in case GraphDef versions mismatch. |
109 | 97 | """ |
110 | | - if from_proto.version != to_proto.version: |
111 | | - raise ValueError("Cannot combine GraphDefs of different versions.") |
| 98 | + if len(graph_defs) == 1: |
| 99 | + return graph_defs[0] |
| 100 | + elif len(graph_defs) == 0: |
| 101 | + return graph_pb2.GraphDef() |
112 | 102 |
|
113 | | - try: |
114 | | - _safe_copy_proto_list_values( |
115 | | - to_proto.node, from_proto.node, lambda n: n.name |
116 | | - ) |
117 | | - except _ProtoListDuplicateKeyError as exc: |
118 | | - raise ValueError("A GraphDef contains non-unique node names: %s" % exc) |
119 | | - except _SameKeyDiffContentError as exc: |
120 | | - raise ValueError( |
121 | | - ( |
122 | | - "Cannot combine GraphDefs because nodes share a name " |
123 | | - "but contents are different: %s" |
124 | | - ) |
125 | | - % exc |
126 | | - ) |
127 | | - try: |
128 | | - _safe_copy_proto_list_values( |
129 | | - to_proto.library.function, |
130 | | - from_proto.library.function, |
131 | | - lambda n: n.signature.name, |
132 | | - ) |
133 | | - except _ProtoListDuplicateKeyError as exc: |
134 | | - raise ValueError( |
135 | | - "A GraphDef contains non-unique function names: %s" % exc |
136 | | - ) |
137 | | - except _SameKeyDiffContentError as exc: |
138 | | - raise ValueError( |
139 | | - ( |
140 | | - "Cannot combine GraphDefs because functions share a name " |
141 | | - "but are different: %s" |
142 | | - ) |
143 | | - % exc |
144 | | - ) |
| 103 | + dst_graph_def = graph_pb2.GraphDef() |
145 | 104 |
|
146 | | - try: |
147 | | - _safe_copy_proto_list_values( |
148 | | - to_proto.library.gradient, |
149 | | - from_proto.library.gradient, |
150 | | - lambda g: g.gradient_func, |
151 | | - ) |
152 | | - except _ProtoListDuplicateKeyError as exc: |
153 | | - raise ValueError( |
154 | | - "A GraphDef contains non-unique gradient function names: %s" % exc |
155 | | - ) |
156 | | - except _SameKeyDiffContentError as exc: |
157 | | - raise ValueError( |
158 | | - ( |
159 | | - "Cannot combine GraphDefs because gradients share a gradient_func name " |
160 | | - "but map to different functions: %s" |
| 105 | + if graph_defs[0].versions.producer: |
| 106 | + dst_graph_def.versions.CopyFrom(graph_defs[0].versions) |
| 107 | + |
| 108 | + for index, graph_def in enumerate(graph_defs): |
| 109 | + if dst_graph_def.versions.producer != graph_def.versions.producer: |
| 110 | + raise ValueError("Cannot combine GraphDefs of different versions.") |
| 111 | + |
| 112 | + mapped_graph_def = _prepend_names("graph_%d" % (index + 1), graph_def) |
| 113 | + dst_graph_def.node.extend(mapped_graph_def.node) |
| 114 | + if mapped_graph_def.library.function: |
| 115 | + dst_graph_def.library.function.extend( |
| 116 | + mapped_graph_def.library.function |
| 117 | + ) |
| 118 | + if mapped_graph_def.library.gradient: |
| 119 | + dst_graph_def.library.gradient.extend( |
| 120 | + mapped_graph_def.library.gradient |
161 | 121 | ) |
162 | | - % exc |
163 | | - ) |
164 | 122 |
|
165 | | - return to_proto |
| 123 | + return dst_graph_def |
0 commit comments