Skip to content

Commit d03fceb

Browse files
authored
feat: combine graph by prefixing with unique name (#4334)
Previously, graph plugin combined multiple graphs traced by creating one monolith of a GraphDef. In doing so, we checked whether, for example, node names are unique to detect a case when our graph vis can result in faulty UI. To alleviate the poor UX, we decided, instead, to duplicate all nodes in one giant GraphDef container prefixing all names. While this creates some bloat, (1) users should see the confusing error less and (2) combined graphs make it very clear that we have traced multiple graphs.
1 parent 29ef9e6 commit d03fceb

File tree

4 files changed

+246
-587
lines changed

4 files changed

+246
-587
lines changed

tensorboard/plugins/graph/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ py_library(
123123
srcs = ["graph_util.py"],
124124
srcs_version = "PY2AND3",
125125
visibility = ["//visibility:private"],
126+
deps = [
127+
"//tensorboard/compat/proto:protos_all_py_pb2",
128+
],
126129
)
127130

128131
py_test(
@@ -136,7 +139,6 @@ py_test(
136139
"//tensorboard:expect_tensorflow_installed",
137140
"//tensorboard/compat/proto:protos_all_py_pb2",
138141
"@com_google_protobuf//:protobuf_python",
139-
"@org_pythonhosted_six",
140142
],
141143
)
142144

tensorboard/plugins/graph/graph_util.py

Lines changed: 86 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -14,152 +14,111 @@
1414
# ==============================================================================
1515
"""Utilities for graph plugin."""
1616

17+
from tensorboard.compat.proto import graph_pb2
1718

18-
class _ProtoListDuplicateKeyError(Exception):
19-
pass
2019

20+
def _prefixed_op_name(prefix, op_name):
21+
return "%s/%s" % (prefix, op_name)
2122

22-
class _SameKeyDiffContentError(Exception):
23-
pass
2423

24+
def _prefixed_func_name(prefix, func_name):
25+
"""Returns function name prefixed with `prefix`.
2526
26-
def _safe_copy_proto_list_values(dst_proto_list, src_proto_list, get_key):
27-
"""Safely merge values from `src_proto_list` into `dst_proto_list`.
27+
For function libraries, which are often created out of autographed Python
28+
function, are factored out in the graph vis. They are grouped under a
29+
function name which often has a shape of
30+
`__inference_[py_func_name]_[numeric_suffix]`.
2831
29-
Each element in `dst_proto_list` must be mapped by `get_key` to a key
30-
value that is unique within that list; likewise for `src_proto_list`.
31-
If an element of `src_proto_list` has the same key as an existing
32-
element in `dst_proto_list`, then the elements must also be equal.
32+
While it does not have some unique information about which graph it is from,
33+
creating another wrapping structure with graph prefix and "/" is less than
34+
ideal so we join the prefix and func_name using underscore.
3335
34-
Args:
35-
dst_proto_list: A `RepeatedCompositeContainer` or
36-
`RepeatedScalarContainer` into which values should be copied.
37-
src_proto_list: A container holding the same kind of values as in
38-
`dst_proto_list` from which values should be copied.
39-
get_key: A function that takes an element of `dst_proto_list` or
40-
`src_proto_list` and returns a key, such that if two elements have
41-
the same key then it is required that they be deep-equal. For
42-
instance, if `dst_proto_list` is a list of nodes, then `get_key`
43-
might be `lambda node: node.name` to indicate that if two nodes
44-
have the same name then they must be the same node. All keys must
45-
be hashable.
46-
47-
Raises:
48-
_ProtoListDuplicateKeyError: A proto_list contains items with duplicate
49-
keys.
50-
_SameKeyDiffContentError: An item with the same key has different contents.
36+
TODO(stephanwlee): add business logic to strip "__inference_" for more user
37+
friendlier name
5138
"""
39+
return "%s_%s" % (prefix, func_name)
40+
5241

53-
def _assert_proto_container_unique_keys(proto_list, get_key):
54-
"""Asserts proto_list to only contains unique keys.
55-
56-
Args:
57-
proto_list: A `RepeatedCompositeContainer` or `RepeatedScalarContainer`.
58-
get_key: A function that takes an element of `proto_list` and returns a
59-
hashable key.
60-
61-
Raises:
62-
_ProtoListDuplicateKeyError: A proto_list contains items with duplicate
63-
keys.
64-
"""
65-
keys = set()
66-
for item in proto_list:
67-
key = get_key(item)
68-
if key in keys:
69-
raise _ProtoListDuplicateKeyError(key)
70-
keys.add(key)
71-
72-
_assert_proto_container_unique_keys(dst_proto_list, get_key)
73-
_assert_proto_container_unique_keys(src_proto_list, get_key)
74-
75-
key_to_proto = {}
76-
for proto in dst_proto_list:
77-
key = get_key(proto)
78-
key_to_proto[key] = proto
79-
80-
for proto in src_proto_list:
81-
key = get_key(proto)
82-
if key in key_to_proto:
83-
if proto != key_to_proto.get(key):
84-
raise _SameKeyDiffContentError(key)
85-
else:
86-
dst_proto_list.add().CopyFrom(proto)
87-
88-
89-
def combine_graph_defs(to_proto, from_proto):
90-
"""Combines two GraphDefs by adding nodes from from_proto into to_proto.
42+
def _add_with_prepended_names(prefix, graph_to_add, destination_graph):
43+
for node in graph_to_add.node:
44+
new_node = destination_graph.node.add()
45+
new_node.CopyFrom(node)
46+
new_node.name = _prefixed_op_name(prefix, node.name)
47+
new_node.input[:] = [
48+
_prefixed_op_name(prefix, input_name) for input_name in node.input
49+
]
50+
51+
# Remap tf.function method name in the PartitionedCall. 'f' is short for
52+
# function.
53+
if new_node.op == "PartitionedCall" and new_node.attr["f"]:
54+
55+
new_node.attr["f"].func.name = _prefixed_func_name(
56+
prefix, new_node.attr["f"].func.name,
57+
)
58+
59+
for func in graph_to_add.library.function:
60+
new_func = destination_graph.library.function.add()
61+
new_func.CopyFrom(func)
62+
new_func.signature.name = _prefixed_func_name(
63+
prefix, new_func.signature.name
64+
)
65+
66+
for gradient in graph_to_add.library.gradient:
67+
new_gradient = destination_graph.library.gradient.add()
68+
new_gradient.CopyFrom(gradient)
69+
new_gradient.function_name = _prefixed_func_name(
70+
prefix, new_gradient.function_name,
71+
)
72+
new_gradient.gradient_func = _prefixed_func_name(
73+
prefix, new_gradient.gradient_func,
74+
)
75+
76+
77+
def merge_graph_defs(graph_defs):
78+
"""Merges GraphDefs by adding unique prefix, `graph_{ind}`, to names.
9179
9280
All GraphDefs are expected to be of TensorBoard's.
93-
It assumes node names are unique across GraphDefs if contents differ. The
94-
names can be the same if the NodeDef content are exactly the same.
81+
82+
When collecting graphs using the `tf.summary.trace` API, node names are not
83+
guranteed to be unique. When non-unique names are not considered, it can
84+
lead to graph visualization showing them as one which creates inaccurate
85+
depiction of the flow of the graph (e.g., if there are A -> B -> C and D ->
86+
B -> E, you may see {A, D} -> B -> E). To prevent such graph, we checked
87+
for uniquenss while merging but it resulted in
88+
https://github.com/tensorflow/tensorboard/issues/1929.
89+
90+
To remedy these issues, we simply "apply name scope" on each graph by
91+
prefixing it with unique name (with a chance of collision) to create
92+
unconnected group of graphs.
93+
94+
In case there is only one graph def passed, it returns the original
95+
graph_def. In case no graph defs are passed, it returns an empty GraphDef.
9596
9697
Args:
97-
to_proto: A destination TensorBoard GraphDef.
98-
from_proto: A TensorBoard GraphDef to copy contents from.
98+
graph_defs: TensorBoard GraphDefs to merge.
9999
100100
Returns:
101-
to_proto
101+
TensorBoard GraphDef that merges all graph_defs with unique prefixes.
102102
103103
Raises:
104-
ValueError in case any assumption about GraphDef is violated: A
105-
GraphDef should have unique node, function, and gradient function
106-
names. Also, when merging GraphDefs, they should have not have nodes,
107-
functions, or gradient function mappings that share the name but details
108-
do not match.
104+
ValueError in case GraphDef versions mismatch.
109105
"""
110-
if from_proto.version != to_proto.version:
111-
raise ValueError("Cannot combine GraphDefs of different versions.")
106+
if len(graph_defs) == 1:
107+
return graph_defs[0]
108+
elif len(graph_defs) == 0:
109+
return graph_pb2.GraphDef()
112110

113-
try:
114-
_safe_copy_proto_list_values(
115-
to_proto.node, from_proto.node, lambda n: n.name
116-
)
117-
except _ProtoListDuplicateKeyError as exc:
118-
raise ValueError("A GraphDef contains non-unique node names: %s" % exc)
119-
except _SameKeyDiffContentError as exc:
120-
raise ValueError(
121-
(
122-
"Cannot combine GraphDefs because nodes share a name "
123-
"but contents are different: %s"
124-
)
125-
% exc
126-
)
127-
try:
128-
_safe_copy_proto_list_values(
129-
to_proto.library.function,
130-
from_proto.library.function,
131-
lambda n: n.signature.name,
132-
)
133-
except _ProtoListDuplicateKeyError as exc:
134-
raise ValueError(
135-
"A GraphDef contains non-unique function names: %s" % exc
136-
)
137-
except _SameKeyDiffContentError as exc:
138-
raise ValueError(
139-
(
140-
"Cannot combine GraphDefs because functions share a name "
141-
"but are different: %s"
142-
)
143-
% exc
144-
)
111+
dst_graph_def = graph_pb2.GraphDef()
145112

146-
try:
147-
_safe_copy_proto_list_values(
148-
to_proto.library.gradient,
149-
from_proto.library.gradient,
150-
lambda g: g.gradient_func,
151-
)
152-
except _ProtoListDuplicateKeyError as exc:
153-
raise ValueError(
154-
"A GraphDef contains non-unique gradient function names: %s" % exc
155-
)
156-
except _SameKeyDiffContentError as exc:
157-
raise ValueError(
158-
(
159-
"Cannot combine GraphDefs because gradients share a gradient_func name "
160-
"but map to different functions: %s"
161-
)
162-
% exc
113+
if graph_defs[0].versions.producer:
114+
dst_graph_def.versions.CopyFrom(graph_defs[0].versions)
115+
116+
for index, graph_def in enumerate(graph_defs):
117+
if dst_graph_def.versions.producer != graph_def.versions.producer:
118+
raise ValueError("Cannot combine GraphDefs of different versions.")
119+
120+
_add_with_prepended_names(
121+
"graph_%d" % (index + 1), graph_def, dst_graph_def,
163122
)
164123

165-
return to_proto
124+
return dst_graph_def

0 commit comments

Comments
 (0)