@@ -73,87 +73,69 @@ Status DebugNodeInserter::InsertNodes(
7373 }
7474
7575 DeviceType device_type = DeviceType{device->device_type ()};
76- // 1. Record existing edges in the graph.
77- std::vector<const Edge*> existing_edges;
78- for (const Edge* edge : graph->edges ()) {
79- existing_edges.push_back (edge);
80- }
81-
82- // A map from tensor names to edges to be removed
83- std::unordered_map<string, std::vector<const Edge*>> edges_to_remove;
84- // A map from tensor names to newly added debug nodes (maybe more than one
85- // for a given tensor).
86- std::unordered_map<string, std::vector<Node*>> added_debug_nodes;
87- std::unordered_map<string, Node*> added_copy_nodes;
8876
89- // 2. Iterate through the edges, look for edges that match the tensor watch
90- // list.
91- for (const Edge* edge : existing_edges) {
92- Node* src_node = edge->src ();
93- Node* dst_node = edge->dst ();
94-
95- if (edge->IsControlEdge ()) {
96- continue ;
77+ // Keep track of all edges to be removed.
78+ std::vector<const Edge*> edges_to_remove;
79+
80+ for (Node* src_node : graph->nodes ()) {
81+ // Make a map from output slot to outgoing edges from the slot.
82+ std::unordered_map<int , std::vector<const Edge*>> output_slot_to_edges;
83+ for (const Edge* edge : src_node->out_edges ()) {
84+ const int src_output = edge->src_output ();
85+ if (output_slot_to_edges.find (src_output) == output_slot_to_edges.end ()) {
86+ output_slot_to_edges[src_output] = {edge};
87+ } else {
88+ output_slot_to_edges[src_output].push_back (edge);
89+ }
9790 }
9891
99- const bool is_ref = IsRefType (dst_node->input_type (edge->dst_input ()));
100- MemoryType memory_type;
101- MemoryTypeForOutput (device_type, graph, src_node, edge->src_output (),
102- &memory_type);
103-
104- const string tensor_name =
105- strings::StrCat (src_node->name (), " :" , edge->src_output ());
106- if (tensor_watches.find (tensor_name) == tensor_watches.end ()) {
107- // Add debug nodes only for edges with matching source node and source
108- // output slot.
109- continue ;
110- }
92+ // Iterate through all output slots of the node.
93+ for (int src_output_slot = 0 ; src_output_slot < src_node->num_outputs ();
94+ ++src_output_slot) {
95+ const string tensor_name =
96+ strings::StrCat (src_node->name (), " :" , src_output_slot);
97+ if (tensor_watches.find (tensor_name) == tensor_watches.end ()) {
98+ // Add debug nodes only for edges with matching source node and source
99+ // output slot.
100+ continue ;
101+ }
111102
112- if (added_copy_nodes.find (tensor_name) == added_copy_nodes.end ()) {
113- // It is the first time an edge with this source tensor is encountered:
114- // we will:
115- // 1) Mark this edge as to be removed, iff the destination node has
116- // non-Ref input
117- // 2) Create a Copy node
103+ // Now we have encountered a watched tensor. We will:
104+ // 1) Mark this edge as to be removed, iff this is a non-Reference
105+ // tensor
106+ // 2) Create a Copy node for the tensor
118107 // 3) Add a new edge, from the source tensor to the Copy node
119108 // 4) Add a new edge, from the Copy node to the destination node, iff
120- // the destination node has non-Ref input
109+ // this is a non-Reference tensor.
121110 // 5) Create all the requested debug nodes and their edges to the Copy
122111 // node.
123- if (!is_ref) {
124- std::vector<const Edge*> node_edges_to_remove;
125- node_edges_to_remove.push_back (edge);
126- edges_to_remove[tensor_name] = node_edges_to_remove;
127- }
112+ // 6) Add control edges from the debug nodes to the destination nodes
113+ // to ensure that the tensor values exported by the debug nodes
114+ // to the debug URLs reflect the values before the execution of
115+ // the destination nodes.
128116
129- const DataType src_dt = src_node->output_type (edge->src_output ());
117+ const DataType src_dt = src_node->output_type (src_output_slot);
118+ MemoryType memory_type;
119+ MemoryTypeForOutput (device_type, graph, src_node, src_output_slot,
120+ &memory_type);
130121
131- // Create the copy node.
122+ // Create the copy node for the watched tensor .
132123 Node* copy_node;
133124 Status copy_s = CreateCopyNode (
134125 graph, device_type, memory_type == HOST_MEMORY , src_node->name (),
135- edge-> src_output () , src_dt, tensor_name, ©_node);
126+ src_output_slot , src_dt, tensor_name, ©_node);
136127 if (!copy_s.ok ()) {
137128 return Status (
138129 error::FAILED_PRECONDITION ,
139130 strings::StrCat (" Failed to create Copy/CopyHost node for tensor " ,
140131 tensor_name, " , due to: " , copy_s.error_message ()));
141132 }
142133
143- // Record the added copy node for later use.
144- added_copy_nodes[tensor_name] = copy_node;
145-
146134 // Add edge from watched tensor to the copy node.
147- graph->AddEdge (src_node, edge->src_output (), copy_node, 0 );
148-
149- // Add edge from the copy node to the destination node, iff the
150- // destination node has non-Ref input.
151- if (!is_ref) {
152- graph->AddEdge (copy_node, 0 , dst_node, edge->dst_input ());
153- }
135+ graph->AddEdge (src_node, src_output_slot, copy_node, 0 );
154136
155137 // Create all requested debug nodes and their edges to the Copy node.
156- std::vector<Node*> node_added_debug_nodes ;
138+ std::vector<Node*> debug_nodes ;
157139 for (size_t i = 0 ; i < tensor_watches[tensor_name].size (); ++i) {
158140 const string& debug_op_name = tensor_watches[tensor_name][i];
159141
@@ -169,47 +151,37 @@ Status DebugNodeInserter::InsertNodes(
169151 debug_s.error_message ()));
170152 }
171153
172- node_added_debug_nodes.push_back (debug_node);
173-
174154 // Create edges from the Copy node to the debug node.
175155 graph->AddEdge (copy_node, 0 , debug_node, 0 );
176156
157+ debug_nodes.push_back (debug_node);
158+ }
159+
160+ // Is the output a reference?
161+ const bool is_ref = IsRefType (src_node->output_type (src_output_slot));
162+
163+ // Iterate through all outgoing edges attached to the slot.
164+ for (const Edge* edge : output_slot_to_edges[src_output_slot]) {
165+ // Mark the edge for removal.
166+ if (!is_ref) {
167+ edges_to_remove.push_back (edge);
168+ graph->AddEdge (copy_node, 0 , edge->dst (), edge->dst_input ());
169+ }
170+
177171 // Add control edges from the debug nodes to the destination node
178172 // to ensure that the debug nodes are executed before the destination
179173 // node.
180- graph->AddEdge (debug_node, Graph::kControlSlot , dst_node,
181- Graph::kControlSlot );
182- }
183- added_debug_nodes[tensor_name] = node_added_debug_nodes;
184- } else {
185- // It is not the first time an edge with this source is encountered.
186- // We will do the following iff the destination node has non-Ref input
187- // 1) Mark the edge for removal
188- // 2) Create an edge from the copy node to the destination node
189- // Iff the destination has Ref-input, the edge will not change.
190- // Regardless of whether the destination has Ref-inpt, we will
191- // 3) Add control edges from the already-created debug node(s) for the
192- // watched tensor to the destination node.
193- if (!is_ref) {
194- edges_to_remove[tensor_name].push_back (edge);
195- graph->AddEdge (added_copy_nodes[tensor_name], 0 , dst_node,
196- edge->dst_input ());
197- }
198-
199- for (Node* debug_node : added_debug_nodes[tensor_name]) {
200- graph->AddEdge (debug_node, Graph::kControlSlot , dst_node,
201- Graph::kControlSlot );
174+ for (Node* debug_node : debug_nodes) {
175+ graph->AddEdge (debug_node, Graph::kControlSlot , edge->dst (),
176+ Graph::kControlSlot );
177+ }
202178 }
203179 }
204180 }
205181
206182 // Remove all edges marked for removal.
207- for (auto it : edges_to_remove) {
208- std::vector<const Edge*> edges = it.second ;
209-
210- for (const Edge* edge : edges) {
211- graph->RemoveEdge (edge);
212- }
183+ for (const Edge* edge : edges_to_remove) {
184+ graph->RemoveEdge (edge);
213185 }
214186
215187 return Status::OK ();
0 commit comments