@@ -52,19 +52,21 @@ std::function<void()> pauseTracing() {
5252 return [state]() { tracer::setTracingState (state); };
5353}
5454
55- void delValueTrace (const Variable& var) {
56- AT_ASSERT (var.defined ());
57- auto & env_stack = getTracingState ()->env_stack ;
55+ void delValueTrace (const IValue& var) {
56+ getTracingState ()->delValue (var);
57+ }
58+ void TracingState::delValue (const IValue& var) {
59+ at::Tensor t = var.toTensor ();
60+ AT_ASSERT (t.defined ());
5861 for (size_t i = 0 ; i < env_stack.size (); ++i) {
5962 auto & value_map = env_stack.at (env_stack.size () - 1 - i).value_map ;
6063
61- auto it = value_map.find (var );
64+ auto it = value_map.find (t );
6265 if (it == value_map.end ()) {
6366 continue ;
6467 }
6568 value_map.erase (it);
6669 }
67- getTracingState ()->env_stack .back ().value_map .erase (var);
6870}
6971
7072// Given a IValue 'var', return the 'node' which represents the instruction
@@ -82,29 +84,29 @@ void delValueTrace(const Variable& var) {
8284// zero. This is one of the cases where a Variable can be created inside of a
8385// trace, and if we treat it as a constant, everything will work out.
8486Value* getValueTrace (const IValue& var) {
85- auto & state = getTracingState ();
86- auto & env_stack = getTracingState ()-> env_stack ;
87-
87+ return getTracingState ()-> getValue (var );
88+ }
89+ Value* TracingState::getValue ( const IValue& var) {
8890 // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] arguments
8991 if (var.isTensorList ()) {
90- return state-> graph
91- ->insertNode (state-> graph ->createList (
92+ return graph
93+ ->insertNode (graph->createList (
9294 TensorType::get (),
9395 fmap (
9496 var.toTensorListRef (),
9597 [](const IValue& val) { return getValueTrace (val); })))
9698 ->output ();
9799 } else if (var.isTuple ()) {
98- return state-> graph
99- ->insertNode (state-> graph ->createTuple (fmap (
100+ return graph
101+ ->insertNode (graph->createTuple (fmap (
100102 var.toTuple ()->elements (),
101103 [](const IValue& val) { return getValueTrace (val); })))
102104 ->output ();
103105 } if (var.isTensor ()) {
104106 auto ten = var.toTensor ();
105107 if (!ten.defined ()) {
106- Node* n = state-> graph ->createNone (TensorType::get ());
107- return state-> graph ->insertNode (n)->output ();
108+ Node* n = graph->createNone (TensorType::get ());
109+ return graph->insertNode (n)->output ();
108110 }
109111 for (size_t i = 0 ; i < env_stack.size (); ++i) {
110112 auto & value_map = env_stack.at (env_stack.size () - 1 - i).value_map ;
@@ -132,7 +134,7 @@ Value* getValueTrace(const IValue& var) {
132134 throw std::runtime_error (oss.str ());
133135 }
134136
135- Value* constant = state-> graph ->insertConstant (ten);
137+ Value* constant = graph->insertConstant (ten);
136138 recordSourceLocation (constant->node ());
137139 constant->inferTypeFrom (ten);
138140 auto it = env_stack.back ().value_map .find (ten);
@@ -155,7 +157,7 @@ Value* getValueTrace(const IValue& var) {
155157 } else {
156158 // If the values are non-tensors, we try to create constants
157159 // and bake those constants into the traced graph
158- auto constant = tryInsertConstant (*state-> graph , var);
160+ auto constant = tryInsertConstant (*graph, var);
159161 if (constant) {
160162 recordSourceLocation (constant.value ()->node ());
161163 return *constant;
@@ -167,39 +169,45 @@ Value* getValueTrace(const IValue& var) {
167169 throw std::runtime_error (os.str ());
168170 }
169171}
170-
171- Value* getOutputTrace (
172- const std::shared_ptr<TracingState>& state,
173- const Variable& var) {
174- if (!var.defined ()) {
175- Node* n = state->graph ->createNone (TensorType::get ());
176- return state->graph ->insertNode (n)->output ();
177- }
178-
179- auto & value_map = getTracingState ()->env_stack .back ().value_map ;
180- auto it = value_map.find (var);
181- if (it == value_map.end ()) {
182- std::ostringstream os;
183- os << " output of traced region did not have observable "
184- << " data dependence with trace inputs; this probably indicates your program "
185- << " cannot be understood by the tracer." ;
186- throw std::runtime_error (os.str ());
172+ bool TracingState::hasValue (const IValue& var) const {
173+ if (var.isTensor ()) {
174+ at::Tensor t = var.toTensor ();
175+ for (const auto & frame : env_stack) {
176+ if (frame.value_map .count (t)) {
177+ return true ;
178+ }
179+ }
187180 }
188- return it->second ;
189- }
190-
191- Value* getNestedOutputTrace (
192- const std::shared_ptr<TracingState>& state,
193- const IValue& iv) {
194- if (iv.isTensor ()) {
195- return getOutputTrace (state, iv.toTensor ());
181+ return false ;
182+ }
183+
184+
185+ Value* TracingState::getOutput (const IValue& iv) {
186+ if (iv.isTensor ()) {
187+ at::Tensor var = iv.toTensor ();
188+ if (!var.defined ()) {
189+ Node *n = graph->createNone (TensorType::get ());
190+ return graph->insertNode (n)->output ();
191+ }
192+
193+ auto &value_map = getTracingState ()->env_stack .back ().value_map ;
194+ auto it = value_map.find (var);
195+ if (it == value_map.end ()) {
196+ std::ostringstream os;
197+ os << " output of traced region did not have observable "
198+ << " data dependence with trace inputs; this probably indicates your "
199+ " program "
200+ << " cannot be understood by the tracer." ;
201+ throw std::runtime_error (os.str ());
202+ }
203+ return it->second ;
196204 } else if (iv.isTuple ()) {
197205 const auto & elems = iv.toTuple ()->elements ();
198206 auto tuple_node =
199- state-> graph ->createTuple (fmap (elems, [&state ](const IValue& ival) {
200- return getNestedOutputTrace (state, ival);
207+ graph->createTuple (fmap (elems, [&](const IValue& ival) {
208+ return getOutput ( ival);
201209 }));
202- state-> graph ->insertNode (tuple_node);
210+ graph->insertNode (tuple_node);
203211 return tuple_node->output ();
204212 } else {
205213 AT_ERROR (
@@ -213,12 +221,11 @@ static IValue addInput(const std::shared_ptr<TracingState> & state, const IValue
213221 if (type->isSubtypeOf (TensorType::get ())) {
214222 auto input_tensor = input.toTensor ();
215223 auto name = Variable (input_tensor).name ();
216- auto & value_map = state->env_stack .back ().value_map ;
217- if (value_map.find (input_tensor) != value_map.end ()) {
224+ if (state->hasValue (input)) {
218225 input_tensor = input_tensor.view (input_tensor.sizes ());
219226 }
220227 value->setUniqueName (name);
221- value_map[ input_tensor] = value;
228+ state-> setValue ( input_tensor, value) ;
222229 return input_tensor;
223230 } else if (auto tuple_type = type->cast <TupleType>()) {
224231 auto unpack_node =
@@ -330,7 +337,7 @@ void exit(const Stack& outputs) {
330337 auto & state = getTracingState ();
331338 size_t i = 0 ;
332339 for (auto & output : outputs) {
333- state->graph ->registerOutput (getNestedOutputTrace ( state, output));
340+ state->graph ->registerOutput (state-> getOutput ( output));
334341 i++;
335342 }
336343 setTracingState (nullptr );
@@ -342,36 +349,36 @@ void abandon() {
342349}
343350
344351void setValueTrace (const IValue& v, Value* value) {
352+ return getTracingState ()->setValue (v, value);
353+ }
354+ void TracingState::setValue (const IValue& v, Value* value) {
345355 if (v.isTensor ()) {
346356 auto var = v.toTensor ();
347357 AT_ASSERT (var.defined ());
348- getTracingState ()-> env_stack .back ().value_map [var] = value;
358+ env_stack.back ().value_map [var] = value;
349359 } else if (v.isTensorList ()) {
350360 auto & outputs = v.toTensorList ()->elements ();
351- auto graph = getTracingState ()->graph ;
352361 Node* unpack_node =
353362 graph->insertNode (graph->createListUnpack (value, outputs.size ()));
354363 for (size_t i = 0 ; i < outputs.size (); ++i) {
355- setValueTrace (outputs[i], unpack_node->outputs ()[i]);
364+ setValue (outputs[i], unpack_node->outputs ()[i]);
356365 }
357366 } else if (v.isTuple ()) {
358367 auto & outputs = v.toTuple ()->elements ();
359- auto graph = getTracingState ()->graph ;
360368 Node* unpack_node = graph->insertNode (graph->createTupleUnpack (value));
361369 for (size_t i = 0 ; i < outputs.size (); ++i) {
362- setValueTrace (outputs[i], unpack_node->outputs ()[i]);
370+ setValue (outputs[i], unpack_node->outputs ()[i]);
363371 }
364372 } else if (v.isGenericList ()) {
365373 auto elements = v.toGenericListRef ();
366- auto graph = getTracingState ()->graph ;
367374 Node* unpack_node =
368375 graph->insertNode (graph->createListUnpack (value, elements.size ()));
369376 for (size_t i = 0 ; i < elements.size (); ++i) {
370- setValueTrace (elements[i], unpack_node->outputs ()[i]);
377+ setValue (elements[i], unpack_node->outputs ()[i]);
371378 }
372379 } else if (v.isFuture ()) {
373380 auto fut = v.toFuture ();
374- getTracingState ()-> env_stack .back ().future_map [fut] = value;
381+ env_stack.back ().future_map [fut] = value;
375382 } else {
376383 std::ostringstream os;
377384 os << " Tracer cannot set value trace for type " << v.tagKind () << " . "
@@ -560,7 +567,7 @@ void setTracingState(std::shared_ptr<TracingState> state) {
560567}
561568
562569TracingState::TracingState ()
563- : env_stack{ TracingEnvironmentFrame ()}, graph(new Graph()) {}
570+ : graph(new Graph()), env_stack{ Frame ()} {}
564571
565572TracingState::~TracingState () = default ;
566573
0 commit comments