Skip to content

Commit dd0faf4

Browse files
Zachary DeVitofacebook-github-bot
authored andcommitted
clean up the TracingState API (#21514)
Summary: Pull Request resolved: #21514 ghimport-source-id: 6a9b6fd Reviewed By: jamesr66a Differential Revision: D15719980 Pulled By: zdevito fbshipit-source-id: 3de2746c3f3c3de4111b4cb73f4c4acedbf28862
1 parent 8c5f3ac commit dd0faf4

File tree

2 files changed

+90
-76
lines changed

2 files changed

+90
-76
lines changed

torch/csrc/jit/tracer.cpp

Lines changed: 64 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,21 @@ std::function<void()> pauseTracing() {
5252
return [state]() { tracer::setTracingState(state); };
5353
}
5454

55-
void delValueTrace(const Variable& var) {
56-
AT_ASSERT(var.defined());
57-
auto& env_stack = getTracingState()->env_stack;
55+
void delValueTrace(const IValue& var) {
56+
getTracingState()->delValue(var);
57+
}
58+
void TracingState::delValue(const IValue& var) {
59+
at::Tensor t = var.toTensor();
60+
AT_ASSERT(t.defined());
5861
for (size_t i = 0; i < env_stack.size(); ++i) {
5962
auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map;
6063

61-
auto it = value_map.find(var);
64+
auto it = value_map.find(t);
6265
if (it == value_map.end()) {
6366
continue;
6467
}
6568
value_map.erase(it);
6669
}
67-
getTracingState()->env_stack.back().value_map.erase(var);
6870
}
6971

7072
// Given a IValue 'var', return the 'node' which represents the instruction
@@ -82,29 +84,29 @@ void delValueTrace(const Variable& var) {
8284
// zero. This is one of the cases where a Variable can be created inside of a
8385
// trace, and if we treat it as a constant, everything will work out.
8486
Value* getValueTrace(const IValue& var) {
85-
auto& state = getTracingState();
86-
auto& env_stack = getTracingState()->env_stack;
87-
87+
return getTracingState()->getValue(var);
88+
}
89+
Value* TracingState::getValue(const IValue& var) {
8890
// allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] arguments
8991
if (var.isTensorList()) {
90-
return state->graph
91-
->insertNode(state->graph->createList(
92+
return graph
93+
->insertNode(graph->createList(
9294
TensorType::get(),
9395
fmap(
9496
var.toTensorListRef(),
9597
[](const IValue& val) { return getValueTrace(val); })))
9698
->output();
9799
} else if (var.isTuple()) {
98-
return state->graph
99-
->insertNode(state->graph->createTuple(fmap(
100+
return graph
101+
->insertNode(graph->createTuple(fmap(
100102
var.toTuple()->elements(),
101103
[](const IValue& val) { return getValueTrace(val); })))
102104
->output();
103105
} if (var.isTensor()) {
104106
auto ten = var.toTensor();
105107
if (!ten.defined()) {
106-
Node* n = state->graph->createNone(TensorType::get());
107-
return state->graph->insertNode(n)->output();
108+
Node* n = graph->createNone(TensorType::get());
109+
return graph->insertNode(n)->output();
108110
}
109111
for (size_t i = 0; i < env_stack.size(); ++i) {
110112
auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map;
@@ -132,7 +134,7 @@ Value* getValueTrace(const IValue& var) {
132134
throw std::runtime_error(oss.str());
133135
}
134136

135-
Value* constant = state->graph->insertConstant(ten);
137+
Value* constant = graph->insertConstant(ten);
136138
recordSourceLocation(constant->node());
137139
constant->inferTypeFrom(ten);
138140
auto it = env_stack.back().value_map.find(ten);
@@ -155,7 +157,7 @@ Value* getValueTrace(const IValue& var) {
155157
} else {
156158
// If the values are non-tensors, we try to create constants
157159
// and bake those constants into the traced graph
158-
auto constant = tryInsertConstant(*state->graph, var);
160+
auto constant = tryInsertConstant(*graph, var);
159161
if (constant) {
160162
recordSourceLocation(constant.value()->node());
161163
return *constant;
@@ -167,39 +169,45 @@ Value* getValueTrace(const IValue& var) {
167169
throw std::runtime_error(os.str());
168170
}
169171
}
170-
171-
Value* getOutputTrace(
172-
const std::shared_ptr<TracingState>& state,
173-
const Variable& var) {
174-
if (!var.defined()) {
175-
Node* n = state->graph->createNone(TensorType::get());
176-
return state->graph->insertNode(n)->output();
177-
}
178-
179-
auto& value_map = getTracingState()->env_stack.back().value_map;
180-
auto it = value_map.find(var);
181-
if (it == value_map.end()) {
182-
std::ostringstream os;
183-
os << "output of traced region did not have observable "
184-
<< "data dependence with trace inputs; this probably indicates your program "
185-
<< "cannot be understood by the tracer.";
186-
throw std::runtime_error(os.str());
172+
bool TracingState::hasValue(const IValue& var) const {
173+
if (var.isTensor()) {
174+
at::Tensor t = var.toTensor();
175+
for(const auto & frame : env_stack) {
176+
if (frame.value_map.count(t)) {
177+
return true;
178+
}
179+
}
187180
}
188-
return it->second;
189-
}
190-
191-
Value* getNestedOutputTrace(
192-
const std::shared_ptr<TracingState>& state,
193-
const IValue& iv) {
194-
if (iv.isTensor()) {
195-
return getOutputTrace(state, iv.toTensor());
181+
return false;
182+
}
183+
184+
185+
Value* TracingState::getOutput(const IValue& iv) {
186+
if (iv.isTensor()) {
187+
at::Tensor var = iv.toTensor();
188+
if (!var.defined()) {
189+
Node *n = graph->createNone(TensorType::get());
190+
return graph->insertNode(n)->output();
191+
}
192+
193+
auto &value_map = getTracingState()->env_stack.back().value_map;
194+
auto it = value_map.find(var);
195+
if (it == value_map.end()) {
196+
std::ostringstream os;
197+
os << "output of traced region did not have observable "
198+
<< "data dependence with trace inputs; this probably indicates your "
199+
"program "
200+
<< "cannot be understood by the tracer.";
201+
throw std::runtime_error(os.str());
202+
}
203+
return it->second;
196204
} else if (iv.isTuple()) {
197205
const auto& elems = iv.toTuple()->elements();
198206
auto tuple_node =
199-
state->graph->createTuple(fmap(elems, [&state](const IValue& ival) {
200-
return getNestedOutputTrace(state, ival);
207+
graph->createTuple(fmap(elems, [&](const IValue& ival) {
208+
return getOutput(ival);
201209
}));
202-
state->graph->insertNode(tuple_node);
210+
graph->insertNode(tuple_node);
203211
return tuple_node->output();
204212
} else {
205213
AT_ERROR(
@@ -213,12 +221,11 @@ static IValue addInput(const std::shared_ptr<TracingState> & state, const IValue
213221
if (type->isSubtypeOf(TensorType::get())) {
214222
auto input_tensor = input.toTensor();
215223
auto name = Variable(input_tensor).name();
216-
auto& value_map = state->env_stack.back().value_map;
217-
if (value_map.find(input_tensor) != value_map.end()) {
224+
if (state->hasValue(input)) {
218225
input_tensor = input_tensor.view(input_tensor.sizes());
219226
}
220227
value->setUniqueName(name);
221-
value_map[input_tensor] = value;
228+
state->setValue(input_tensor, value);
222229
return input_tensor;
223230
} else if (auto tuple_type = type->cast<TupleType>()) {
224231
auto unpack_node =
@@ -330,7 +337,7 @@ void exit(const Stack& outputs) {
330337
auto& state = getTracingState();
331338
size_t i = 0;
332339
for (auto& output : outputs) {
333-
state->graph->registerOutput(getNestedOutputTrace(state, output));
340+
state->graph->registerOutput(state->getOutput(output));
334341
i++;
335342
}
336343
setTracingState(nullptr);
@@ -342,36 +349,36 @@ void abandon() {
342349
}
343350

344351
void setValueTrace(const IValue& v, Value* value) {
352+
return getTracingState()->setValue(v, value);
353+
}
354+
void TracingState::setValue(const IValue& v, Value* value) {
345355
if (v.isTensor()) {
346356
auto var = v.toTensor();
347357
AT_ASSERT(var.defined());
348-
getTracingState()->env_stack.back().value_map[var] = value;
358+
env_stack.back().value_map[var] = value;
349359
} else if (v.isTensorList()) {
350360
auto& outputs = v.toTensorList()->elements();
351-
auto graph = getTracingState()->graph;
352361
Node* unpack_node =
353362
graph->insertNode(graph->createListUnpack(value, outputs.size()));
354363
for (size_t i = 0; i < outputs.size(); ++i) {
355-
setValueTrace(outputs[i], unpack_node->outputs()[i]);
364+
setValue(outputs[i], unpack_node->outputs()[i]);
356365
}
357366
} else if (v.isTuple()) {
358367
auto& outputs = v.toTuple()->elements();
359-
auto graph = getTracingState()->graph;
360368
Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
361369
for (size_t i = 0; i < outputs.size(); ++i) {
362-
setValueTrace(outputs[i], unpack_node->outputs()[i]);
370+
setValue(outputs[i], unpack_node->outputs()[i]);
363371
}
364372
} else if (v.isGenericList()) {
365373
auto elements = v.toGenericListRef();
366-
auto graph = getTracingState()->graph;
367374
Node* unpack_node =
368375
graph->insertNode(graph->createListUnpack(value, elements.size()));
369376
for (size_t i = 0; i < elements.size(); ++i) {
370-
setValueTrace(elements[i], unpack_node->outputs()[i]);
377+
setValue(elements[i], unpack_node->outputs()[i]);
371378
}
372379
} else if (v.isFuture()) {
373380
auto fut = v.toFuture();
374-
getTracingState()->env_stack.back().future_map[fut] = value;
381+
env_stack.back().future_map[fut] = value;
375382
} else {
376383
std::ostringstream os;
377384
os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
@@ -560,7 +567,7 @@ void setTracingState(std::shared_ptr<TracingState> state) {
560567
}
561568

562569
TracingState::TracingState()
563-
: env_stack{TracingEnvironmentFrame()}, graph(new Graph()) {}
570+
: graph(new Graph()), env_stack{Frame()} {}
564571

565572
TracingState::~TracingState() = default;
566573

torch/csrc/jit/tracer.h

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,27 @@ struct TORCH_API TracingState
5252
TracingState();
5353
~TracingState();
5454

55+
std::shared_ptr<Graph> graph;
56+
bool warn = true;
57+
bool force_outplace = false;
58+
std::function<std::string(const Variable& var)> lookup_var_name_fn =
59+
[](const Variable& var) { return ""; };
60+
61+
void enterFrame() {
62+
env_stack.emplace_back();
63+
}
64+
65+
void leaveFrame() {
66+
env_stack.pop_back();
67+
}
68+
69+
void setValue(const IValue& v, Value* value);
70+
void delValue(const IValue& var);
71+
Value* getValue(const IValue& var);
72+
Value* getOutput(const IValue& var);
73+
bool hasValue(const IValue& var) const;
74+
75+
private:
5576
using WeakTensor = at::WeakTensor;
5677

5778
struct WeakTensorHasher {
@@ -66,22 +87,16 @@ struct TORCH_API TracingState
6687
}
6788
};
6889

69-
struct TracingEnvironmentFrame {
90+
struct Frame {
7091
std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq>
7192
value_map;
7293
// TODO weak refcount
7394
std::unordered_map<c10::intrusive_ptr<c10::ivalue::Future>, Value*>
7495
future_map;
7596
};
7697

77-
using TracingEnvironmentStack = std::vector<TracingEnvironmentFrame>;
98+
std::vector<Frame> env_stack;
7899

79-
TracingEnvironmentStack env_stack;
80-
std::shared_ptr<Graph> graph;
81-
bool warn = true;
82-
bool force_outplace = false;
83-
std::function<std::string(const Variable& var)> lookup_var_name_fn =
84-
[](const Variable& var) { return ""; };
85100
};
86101

87102
// This is meant to be used as a thread local place, where we can store extra
@@ -182,11 +197,11 @@ struct TORCH_API NoWarn {
182197

183198
struct WithNestedTracingFrame {
184199
WithNestedTracingFrame() {
185-
getTracingState()->env_stack.emplace_back();
200+
getTracingState()->enterFrame();
186201
}
187202

188203
~WithNestedTracingFrame() {
189-
getTracingState()->env_stack.pop_back();
204+
getTracingState()->leaveFrame();
190205
}
191206
};
192207
TORCH_API void recordSourceLocation(Node* n);
@@ -197,20 +212,12 @@ TORCH_API void setRecordSourceLocation(void (*v)(Node*));
197212
// involving this variable know which node in the IR to reference.
198213
TORCH_API void setValueTrace(const IValue& v, Value* value);
199214

200-
TORCH_API void delValueTrace(const Variable& var);
215+
TORCH_API void delValueTrace(const IValue& var);
201216

202217
TORCH_API std::function<void()> pauseTracing();
203218

204219
TORCH_API Value* getValueTrace(const IValue& var);
205220

206-
TORCH_API Value* getOutputTrace(
207-
const std::shared_ptr<TracingState>& state,
208-
const Variable& var);
209-
210-
TORCH_API Value* getNestedOutputTrace(
211-
const std::shared_ptr<TracingState>& state,
212-
const IValue& iv);
213-
214221
struct TypedStack : public std::pair<Stack, TupleTypePtr>
215222
{
216223
using pair::pair;

0 commit comments

Comments
 (0)