|
14 | 14 | # ============================================================================== |
15 | 15 | """Utilities for graph plugin.""" |
16 | 16 |
|
| 17 | +from tensorboard.compat.proto import graph_pb2 |
17 | 18 |
|
18 | | -class _ProtoListDuplicateKeyError(Exception): |
19 | | - pass |
20 | 19 |
|
| 20 | +def _prefixed_op_name(prefix, op_name): |
| 21 | + return "%s/%s" % (prefix, op_name) |
21 | 22 |
|
22 | | -class _SameKeyDiffContentError(Exception): |
23 | | - pass |
24 | 23 |
|
| 24 | +def _prefixed_func_name(prefix, func_name): |
| 25 | + """Returns function name prefixed with `prefix`. |
25 | 26 |
|
26 | | -def _safe_copy_proto_list_values(dst_proto_list, src_proto_list, get_key): |
27 | | - """Safely merge values from `src_proto_list` into `dst_proto_list`. |
| 27 | + For function libraries, which are often created out of autographed Python |
| 28 | + function, are factored out in the graph vis. They are grouped under a |
| 29 | + function name which often has a shape of |
| 30 | + `__inference_[py_func_name]_[numeric_suffix]`. |
28 | 31 |
|
29 | | - Each element in `dst_proto_list` must be mapped by `get_key` to a key |
30 | | - value that is unique within that list; likewise for `src_proto_list`. |
31 | | - If an element of `src_proto_list` has the same key as an existing |
32 | | - element in `dst_proto_list`, then the elements must also be equal. |
| 32 | + While it does not have some unique information about which graph it is from, |
| 33 | + creating another wrapping structure with graph prefix and "/" is less than |
| 34 | + ideal so we join the prefix and func_name using underscore. |
33 | 35 |
|
34 | | - Args: |
35 | | - dst_proto_list: A `RepeatedCompositeContainer` or |
36 | | - `RepeatedScalarContainer` into which values should be copied. |
37 | | - src_proto_list: A container holding the same kind of values as in |
38 | | - `dst_proto_list` from which values should be copied. |
39 | | - get_key: A function that takes an element of `dst_proto_list` or |
40 | | - `src_proto_list` and returns a key, such that if two elements have |
41 | | - the same key then it is required that they be deep-equal. For |
42 | | - instance, if `dst_proto_list` is a list of nodes, then `get_key` |
43 | | - might be `lambda node: node.name` to indicate that if two nodes |
44 | | - have the same name then they must be the same node. All keys must |
45 | | - be hashable. |
46 | | -
|
47 | | - Raises: |
48 | | - _ProtoListDuplicateKeyError: A proto_list contains items with duplicate |
49 | | - keys. |
50 | | - _SameKeyDiffContentError: An item with the same key has different contents. |
| 36 | + TODO(stephanwlee): add business logic to strip "__inference_" for more user |
| 37 | + friendlier name |
51 | 38 | """ |
| 39 | + return "%s_%s" % (prefix, func_name) |
| 40 | + |
52 | 41 |
|
53 | | - def _assert_proto_container_unique_keys(proto_list, get_key): |
54 | | - """Asserts proto_list to only contains unique keys. |
55 | | -
|
56 | | - Args: |
57 | | - proto_list: A `RepeatedCompositeContainer` or `RepeatedScalarContainer`. |
58 | | - get_key: A function that takes an element of `proto_list` and returns a |
59 | | - hashable key. |
60 | | -
|
61 | | - Raises: |
62 | | - _ProtoListDuplicateKeyError: A proto_list contains items with duplicate |
63 | | - keys. |
64 | | - """ |
65 | | - keys = set() |
66 | | - for item in proto_list: |
67 | | - key = get_key(item) |
68 | | - if key in keys: |
69 | | - raise _ProtoListDuplicateKeyError(key) |
70 | | - keys.add(key) |
71 | | - |
72 | | - _assert_proto_container_unique_keys(dst_proto_list, get_key) |
73 | | - _assert_proto_container_unique_keys(src_proto_list, get_key) |
74 | | - |
75 | | - key_to_proto = {} |
76 | | - for proto in dst_proto_list: |
77 | | - key = get_key(proto) |
78 | | - key_to_proto[key] = proto |
79 | | - |
80 | | - for proto in src_proto_list: |
81 | | - key = get_key(proto) |
82 | | - if key in key_to_proto: |
83 | | - if proto != key_to_proto.get(key): |
84 | | - raise _SameKeyDiffContentError(key) |
85 | | - else: |
86 | | - dst_proto_list.add().CopyFrom(proto) |
87 | | - |
88 | | - |
89 | | -def combine_graph_defs(to_proto, from_proto): |
90 | | - """Combines two GraphDefs by adding nodes from from_proto into to_proto. |
| 42 | +def _add_with_prepended_names(prefix, graph_to_add, destination_graph): |
| 43 | + for node in graph_to_add.node: |
| 44 | + new_node = destination_graph.node.add() |
| 45 | + new_node.CopyFrom(node) |
| 46 | + new_node.name = _prefixed_op_name(prefix, node.name) |
| 47 | + new_node.input[:] = [ |
| 48 | + _prefixed_op_name(prefix, input_name) for input_name in node.input |
| 49 | + ] |
| 50 | + |
| 51 | + # Remap tf.function method name in the PartitionedCall. 'f' is short for |
| 52 | + # function. |
| 53 | + if new_node.op == "PartitionedCall" and new_node.attr["f"]: |
| 54 | + |
| 55 | + new_node.attr["f"].func.name = _prefixed_func_name( |
| 56 | + prefix, new_node.attr["f"].func.name, |
| 57 | + ) |
| 58 | + |
| 59 | + for func in graph_to_add.library.function: |
| 60 | + new_func = destination_graph.library.function.add() |
| 61 | + new_func.CopyFrom(func) |
| 62 | + new_func.signature.name = _prefixed_func_name( |
| 63 | + prefix, new_func.signature.name |
| 64 | + ) |
| 65 | + |
| 66 | + for gradient in graph_to_add.library.gradient: |
| 67 | + new_gradient = destination_graph.library.gradient.add() |
| 68 | + new_gradient.CopyFrom(gradient) |
| 69 | + new_gradient.function_name = _prefixed_func_name( |
| 70 | + prefix, new_gradient.function_name, |
| 71 | + ) |
| 72 | + new_gradient.gradient_func = _prefixed_func_name( |
| 73 | + prefix, new_gradient.gradient_func, |
| 74 | + ) |
| 75 | + |
| 76 | + |
| 77 | +def merge_graph_defs(graph_defs): |
| 78 | + """Merges GraphDefs by adding unique prefix, `graph_{ind}`, to names. |
91 | 79 |
|
92 | 80 | All GraphDefs are expected to be of TensorBoard's. |
93 | | - It assumes node names are unique across GraphDefs if contents differ. The |
94 | | - names can be the same if the NodeDef content are exactly the same. |
| 81 | +
|
| 82 | + When collecting graphs using the `tf.summary.trace` API, node names are not |
| 83 | + guranteed to be unique. When non-unique names are not considered, it can |
| 84 | + lead to graph visualization showing them as one which creates inaccurate |
| 85 | + depiction of the flow of the graph (e.g., if there are A -> B -> C and D -> |
| 86 | + B -> E, you may see {A, D} -> B -> E). To prevent such graph, we checked |
| 87 | + for uniquenss while merging but it resulted in |
| 88 | + https://github.com/tensorflow/tensorboard/issues/1929. |
| 89 | +
|
| 90 | + To remedy these issues, we simply "apply name scope" on each graph by |
| 91 | + prefixing it with unique name (with a chance of collision) to create |
| 92 | + unconnected group of graphs. |
| 93 | +
|
| 94 | + In case there is only one graph def passed, it returns the original |
| 95 | + graph_def. In case no graph defs are passed, it returns an empty GraphDef. |
95 | 96 |
|
96 | 97 | Args: |
97 | | - to_proto: A destination TensorBoard GraphDef. |
98 | | - from_proto: A TensorBoard GraphDef to copy contents from. |
| 98 | + graph_defs: TensorBoard GraphDefs to merge. |
99 | 99 |
|
100 | 100 | Returns: |
101 | | - to_proto |
| 101 | + TensorBoard GraphDef that merges all graph_defs with unique prefixes. |
102 | 102 |
|
103 | 103 | Raises: |
104 | | - ValueError in case any assumption about GraphDef is violated: A |
105 | | - GraphDef should have unique node, function, and gradient function |
106 | | - names. Also, when merging GraphDefs, they should have not have nodes, |
107 | | - functions, or gradient function mappings that share the name but details |
108 | | - do not match. |
| 104 | + ValueError in case GraphDef versions mismatch. |
109 | 105 | """ |
110 | | - if from_proto.version != to_proto.version: |
111 | | - raise ValueError("Cannot combine GraphDefs of different versions.") |
| 106 | + if len(graph_defs) == 1: |
| 107 | + return graph_defs[0] |
| 108 | + elif len(graph_defs) == 0: |
| 109 | + return graph_pb2.GraphDef() |
112 | 110 |
|
113 | | - try: |
114 | | - _safe_copy_proto_list_values( |
115 | | - to_proto.node, from_proto.node, lambda n: n.name |
116 | | - ) |
117 | | - except _ProtoListDuplicateKeyError as exc: |
118 | | - raise ValueError("A GraphDef contains non-unique node names: %s" % exc) |
119 | | - except _SameKeyDiffContentError as exc: |
120 | | - raise ValueError( |
121 | | - ( |
122 | | - "Cannot combine GraphDefs because nodes share a name " |
123 | | - "but contents are different: %s" |
124 | | - ) |
125 | | - % exc |
126 | | - ) |
127 | | - try: |
128 | | - _safe_copy_proto_list_values( |
129 | | - to_proto.library.function, |
130 | | - from_proto.library.function, |
131 | | - lambda n: n.signature.name, |
132 | | - ) |
133 | | - except _ProtoListDuplicateKeyError as exc: |
134 | | - raise ValueError( |
135 | | - "A GraphDef contains non-unique function names: %s" % exc |
136 | | - ) |
137 | | - except _SameKeyDiffContentError as exc: |
138 | | - raise ValueError( |
139 | | - ( |
140 | | - "Cannot combine GraphDefs because functions share a name " |
141 | | - "but are different: %s" |
142 | | - ) |
143 | | - % exc |
144 | | - ) |
| 111 | + dst_graph_def = graph_pb2.GraphDef() |
145 | 112 |
|
146 | | - try: |
147 | | - _safe_copy_proto_list_values( |
148 | | - to_proto.library.gradient, |
149 | | - from_proto.library.gradient, |
150 | | - lambda g: g.gradient_func, |
151 | | - ) |
152 | | - except _ProtoListDuplicateKeyError as exc: |
153 | | - raise ValueError( |
154 | | - "A GraphDef contains non-unique gradient function names: %s" % exc |
155 | | - ) |
156 | | - except _SameKeyDiffContentError as exc: |
157 | | - raise ValueError( |
158 | | - ( |
159 | | - "Cannot combine GraphDefs because gradients share a gradient_func name " |
160 | | - "but map to different functions: %s" |
161 | | - ) |
162 | | - % exc |
| 113 | + if graph_defs[0].versions.producer: |
| 114 | + dst_graph_def.versions.CopyFrom(graph_defs[0].versions) |
| 115 | + |
| 116 | + for index, graph_def in enumerate(graph_defs): |
| 117 | + if dst_graph_def.versions.producer != graph_def.versions.producer: |
| 118 | + raise ValueError("Cannot combine GraphDefs of different versions.") |
| 119 | + |
| 120 | + _add_with_prepended_names( |
| 121 | + "graph_%d" % (index + 1), graph_def, dst_graph_def, |
163 | 122 | ) |
164 | 123 |
|
165 | | - return to_proto |
| 124 | + return dst_graph_def |
0 commit comments