
Commit 4eb8e12

lantiga authored and apaszke committed
Introduce scopes during tracing (#3016)
1 parent 7ddcb91 commit 4eb8e12
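
For orientation, here is a minimal sketch of the user-facing API this commit introduces. It mirrors the new test_scopes test further down; the scope names 'Foo' and 'Bar' are arbitrary labels.

# Sketch only: nodes traced inside torch.jit.scope(...) blocks are annotated
# with the nested scope name (e.g. "scope: Foo/Bar") when the graph is printed.
import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)

def f(x, y):
    out = x + y                            # no scope
    with torch.jit.scope('Foo', out):
        out = x * out                      # scope: Foo
        with torch.jit.scope('Bar', out):
            out = torch.tanh(out)          # scope: Foo/Bar
        out = torch.sigmoid(out)           # scope: Foo
    return out

trace, z = torch.jit.trace(f, (x, y), nderivs=0)
print(trace)  # each printed node carries a trailing ", scope: ..." where applicable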

11 files changed (+242, -11 lines)

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 graph(%0 : Double(20, 16, 50, 40)
       %1 : Double(13, 16, 3, 3)) {
-  %2 : UNKNOWN_TYPE = Undefined()
-  %3 : Double(20, 13, 48, 38), %4 : Handle = CppOp[ConvForward](%0, %1, %2)
+  %2 : UNKNOWN_TYPE = Undefined(), scope: Conv2d
+  %3 : Double(20, 13, 48, 38), %4 : Handle = CppOp[ConvForward](%0, %1, %2), scope: Conv2d
   return (%3);
 }
Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 graph(%0 : Double(2, 2)) {
-  %1 : Double(2, 2), %2 : Handle = ^Dropout(0.6, True, False)(%0)
+  %1 : Double(2, 2), %2 : Handle = ^Dropout(0.6, True, False)(%0), scope: Dropout
   return (%1);
 }
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+graph(%0 : Double(1)
+      %1 : Double(1)) {
+  %2 : Double(1) = add[alpha={1}](%0, %1)
+  %3 : Double(1) = mul(%0, %2), scope: Foo
+  %4 : Double(1) = tanh(%3), scope: Foo/Bar
+  %5 : Double(1) = sigmoid(%4), scope: Foo
+  return (%5);
+}

test/test_jit.py

Lines changed: 17 additions & 0 deletions
@@ -69,6 +69,23 @@ def f(x, y):
         torch._C._jit_pass_lint(trace)
         self.assertExpected(str(trace))
 
+    def test_scopes(self):
+        x = Variable(torch.Tensor([0.4]), requires_grad=True)
+        y = Variable(torch.Tensor([0.7]), requires_grad=True)
+
+        def f(x, y):
+            out = x + y
+            with torch.jit.scope('Foo', out):
+                out = x * out
+                with torch.jit.scope('Bar', out):
+                    out = torch.tanh(out)
+                out = torch.sigmoid(out)
+            return out
+
+        trace, z = torch.jit.trace(f, (x, y), nderivs=0)
+        torch._C._jit_pass_lint(trace)
+        self.assertExpected(str(trace))
+
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_lstm_fusion(self):
         input = Variable(torch.randn(3, 10).float().cuda())

torch/csrc/jit/ir.cpp

Lines changed: 9 additions & 1 deletion
@@ -239,7 +239,15 @@ std::ostream& printNode(std::ostream & out, const Node * n, std::vector<const No
     printAttributes(out,n);
   }
   IR_END()
-  out << "(" << n->inputs() << ")\n";
+  out << "(" << n->inputs() << ")";
+  std::string scopeName = n->scopeName();
+  if (scopeName.empty()) {
+    out << "\n";
+  }
+  else {
+    out << ", ";
+    out << "scope: " << scopeName << "\n";
+  }
   return out;
 }
 

torch/csrc/jit/ir.h

Lines changed: 113 additions & 3 deletions
@@ -70,6 +70,64 @@ struct SourceLocation {
   std::string python_traceback;
 };
 
+// Scope is a node of a trie that represents the tree of nested scopes.
+// Individual scopes are pushed and popped from Graph, which holds a
+// pointer to the current scope. Each Node in Graph holds a pointer
+// to the scope that was current when the node was created.
+// The trie never needs to shrink, it only grows until it is disposed
+// of when Graph is deallocated. Hence, pointers to scopes held by nodes
+// will always be valid as long as Graph is alive.
+struct Scope {
+private:
+  Scope* parent_;
+  Symbol name_;
+  std::vector<std::unique_ptr<Scope> > children_;
+public:
+  Scope() {
+    name_ = stringToSymbol("");
+    parent_ = NULL;
+  }
+  Scope(Scope* parent, Symbol name) {
+    name_ = name;
+    parent_ = parent;
+  }
+  Scope* push(Symbol name) {
+    children_.push_back(std::unique_ptr<Scope>(new Scope(this, name)));
+    return children_.back().get();
+  }
+  Scope* parent() {
+    if (parent_ == NULL) {
+      throw std::runtime_error("Cannot get parent from Scope with no parent");
+    }
+    return parent_;
+  }
+  bool isRoot() {
+    return parent_ == NULL;
+  }
+  Scope* getRoot() {
+    Scope* current = this;
+    while (current->parent_) {
+      current = current->parent_;
+    }
+    return current;
+  }
+  Symbol name() {
+    return name_;
+  }
+  std::string namesFromRoot(const std::string& separator="/") {
+    std::string out = std::string(symbolToString(this->name_));
+    if (this->isRoot()) {
+      return out;
+    }
+    Scope* parent = this->parent_;
+    while (!parent->isRoot()) {
+      out = std::string(symbolToString(parent->name_)) + separator + out;
+      parent = parent->parent_;
+    }
+    return out;
+  }
+};
+
 // the list types are intentionally simple, but we type-def
 // them here so if we need to change them, refactoring will be easier
 using node_list = std::vector<Node*>;
@@ -139,6 +197,9 @@ struct Value {
   const Node * node() const {
     return node_;
   }
+  Scope* scope();
+  void setScope(Scope* scope);
+  std::string scopeName() const;
   Graph * owningGraph();
   const Graph * owningGraph() const;
   // TODO: make this more const correct
@@ -197,6 +258,7 @@ struct Node : public Attributes<Node> {
   Graph* graph_;
   std::shared_ptr<SourceLocation> source_location_;
   size_t stage_;
+  Scope* scope_;
protected:
  Node(Graph * graph_, NodeKind kind_); //defined after graph
public:
@@ -223,6 +285,18 @@ struct Node : public Attributes<Node> {
    stage_ = s;
    return this;
  }
+  Scope* scope() {
+    return scope_;
+  }
+  void setScope(Scope* scope) {
+    scope_ = scope;
+  }
+  std::string scopeName() const {
+    if (scope_ == NULL) {
+      return "";
+    }
+    return scope_->namesFromRoot();
+  }
  // NB: This returns an ArrayRef; that means that it will
  // get invalidated if you resize inputs (e.g., using addInput)
  // We can't return a std::vector<Node*>& because there's no
@@ -534,6 +608,7 @@ struct Node : public Attributes<Node> {
  // if you are going to preserve it.
  virtual void cloneFrom(Node * s) {
    setSourceLocation(s->getSourceLocation());
+    scope_ = s->scope_;
    copyAttributes(*s);
  }
};
@@ -556,6 +631,9 @@ friend struct Value;
 
  size_t new_node_stage_;
 
+  std::shared_ptr<Scope> scope_root_;
+  Scope * current_scope_;
+
  // holds outputs in a way that can be reflected
  // as a Use object
  // also used as the beginning/end of the circular node list to avoid
@@ -564,11 +642,17 @@ friend struct Value;
  Node * const input_;
 
public:
-  Graph()
+
+  Graph(std::shared_ptr<Scope> scope_root)
  : next_unique_(0)
  , new_node_stage_(0)
+  , scope_root_(scope_root)
+  , current_scope_(scope_root_.get())
  , output_(initOutput(create(kReturn, 0))), input_(create(kParam, 0)) {}
 
+  Graph()
+  : Graph( std::make_shared<Scope>()) {}
+
  at::ArrayRef<Value*> inputs() {
    return input_->outputs();
  }
@@ -621,6 +705,18 @@ friend struct Value;
  const Node * return_node() const {
    return output_;
  }
+  void push_scope(const std::string& scope_name) {
+    current_scope_ = current_scope_->push(stringToSymbol(scope_name));
+  }
+  void pop_scope() {
+    current_scope_ = current_scope_->parent();
+  }
+  Scope * current_scope() {
+    return current_scope_;
+  }
+  std::shared_ptr<Scope> scope_root() {
+    return scope_root_;
+  }
  Value * addInput(std::string name="") {
    Value * v = input_->addOutput();
    if (name != "") v->setUniqueName(name);
@@ -676,7 +772,8 @@ friend struct Value;
  }
  Node * createFusionGroup() {
    auto n = create(kFusionGroup, 0);
-    n->g_(kSubgraph,std::make_shared<Graph>());
+    auto subgraph = std::make_shared<Graph>(scope_root_);
+    n->g_(kSubgraph, subgraph);
    return n;
  }
  Node * createPythonOp(THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, std::vector<VariableFlags> && var_flags, pyobj_list&& scalar_args);
@@ -759,6 +856,18 @@ inline Value::Value(Node * node_, size_t offset_)
  node_->graph_->all_values.emplace(this);
}
 
+inline Scope* Value::scope() {
+  return node()->scope();
+}
+
+inline void Value::setScope(Scope* scope) {
+  node()->setScope(scope);
+}
+
+inline std::string Value::scopeName() const {
+  return node()->scopeName();
+}
+
inline Graph * Value::owningGraph() {
  return node()->owningGraph();
}
@@ -779,7 +888,8 @@ inline void Value::replaceAllUsesWith(Value * newValue) {
inline Node::Node(Graph * graph_, NodeKind kind_) :
  kind_(kind_),
  graph_(graph_),
-  stage_(graph_->new_node_stage_) {
+  stage_(graph_->new_node_stage_),
+  scope_(graph_->current_scope_) {
  graph_->all_nodes.emplace(this);
}
 
torch/csrc/jit/passes/onnx.cpp

Lines changed: 8 additions & 1 deletion
@@ -31,7 +31,7 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
     throw std::logic_error("ToONNX: tracing state is expired");
   }
 
-  auto new_graph = std::make_shared<Graph>();
+  auto new_graph = std::make_shared<Graph>(state->graph->scope_root());
   std::unordered_map<void*, Value*> new_buffer_map;
 
   torch::autograd::SymbolicContext ctx;
@@ -137,6 +137,10 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
       throw std::runtime_error(ss.str());
     }
 
+    for (auto& el: outputs) {
+      el->setScope(n->scope());
+    }
+
    setOutputs(op_name, n, outputs);
  };
 
@@ -208,6 +212,9 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
    IR_IFM(node, CppOp)
      if (auto fn = std::dynamic_pointer_cast<autograd::HasSymbolic>(value->fn)) {
        auto outputs = fn->symbolic(&ctx, fmap(node->inputs(), envFn), node->getSourceLocation());
+        for (auto& el: outputs) {
+          el->setScope(node->scope());
+        }
        setOutputs(value->name(), node, outputs);
      } else {
        cloneNode(node);

torch/csrc/jit/python_tracer.cpp

Lines changed: 15 additions & 1 deletion
@@ -19,7 +19,7 @@ namespace torch { namespace jit {
 
 void initPythonTracerBindings(PyObject* module_) {
   auto m = py::handle(module_).cast<py::module>();
-  py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState")
+  py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
     // NB: no constructor; you have to get it from C++ code
     .def("__repr__", [](const TracingState& s) {
       std::ostringstream ss;
@@ -32,6 +32,14 @@ void initPythonTracerBindings(PyObject* module_) {
      ss << *s.graph;
      return ss.str();
    })
+    .def("push_scope", [](TracingState& s, const std::string& scope_name) {
+      ASSERT_UNEXPIRED("push_scope");
+      s.push_scope(scope_name);
+    })
+    .def("pop_scope", [](TracingState& s) {
+      ASSERT_UNEXPIRED("pop_scope");
+      s.pop_scope();
+    })
    .def("export", [](TracingState& s, const std::vector<at::Tensor>& initializers, int64_t onnx_opset_version) {
      ASSERT_UNEXPIRED("export");
      return py::bytes(ExportGraph(s.graph, initializers, onnx_opset_version));
@@ -52,6 +60,12 @@ void initPythonTracerBindings(PyObject* module_) {
  m.def("_tracer_exit", [](variable_list var_outputs) {
    tracer::exit(var_outputs);
  });
+  m.def("_get_tracing_state", [](const variable_list& vars) {
+    return getTracingState(vars);
+  });
+  m.def("_is_tracing", [](const variable_list& vars) {
+    return isTracing(vars);
+  });
}
 
}} // namespace torch::jit

torch/csrc/jit/tracer_state.h

Lines changed: 8 additions & 0 deletions
@@ -74,6 +74,14 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
   bool is_complete() const {
     return !is_expired() && graph->stage() == num_stages - 1;
   }
+
+  void push_scope(const std::string& scope_name) {
+    graph->push_scope(scope_name);
+  }
+
+  void pop_scope() {
+    graph->pop_scope();
+  }
 };
 
 struct ValueTracingStateElem {

torch/jit/__init__.py

Lines changed: 27 additions & 0 deletions
@@ -19,6 +19,30 @@
 _flatten = torch._C._jit_flatten
 
 
+# This global variable is set when we are tracing a *forwards* computation.
+# It is intended to be a cheap way to test if tracing has occurred, before
+# doing the slower path using `get_tracing_state` (below.)
+_tracing = False
+
+
+def get_tracing_state(args):
+    if not torch._C._is_tracing(args):
+        return None
+    return torch._C._get_tracing_state(args)
+
+
+@contextlib.contextmanager
+def scope(scope_name, *vars):
+    tracing_state = get_tracing_state(vars)
+    if tracing_state:
+        tracing_state.push_scope(scope_name)
+    try:
+        yield
+    finally:
+        if tracing_state:
+            tracing_state.pop_scope()
+
+
 def compile(arg=None, nderivs=1, optimize=True, enabled=True):
     """
     Decorator which marks a function or module class as eligible for
@@ -237,13 +261,16 @@ def __init__(self, inner, nderivs=0):
         self.nderivs = nderivs
 
     def forward(self, *args):
+        global _tracing
         in_vars = _flatten(args)
         # NOTE: use full state, because we need it for BatchNorm export
         # This differs from the compiler path, which doesn't support it at the moment.
         module_state = list(self.state_dict(keep_vars=True).values())
         trace = torch._C._tracer_enter(in_vars + module_state, self.nderivs)
+        _tracing = True
         out = self.inner(*args)
         out_vars = _flatten(out)
+        _tracing = False
        torch._C._tracer_exit(out_vars)
        return trace, out
 

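One detail of the scope context manager added above is worth calling out: get_tracing_state returns None when the given variables are not being traced, so torch.jit.scope degrades to a plain no-op outside of torch.jit.trace. A small sketch (illustration only; the values are arbitrary):

# Outside of tracing, push_scope/pop_scope are never called, so the block
# below runs as ordinary eager code with no side effects from the scope.
import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([0.4]))
y = Variable(torch.Tensor([0.7]))

out = x + y
with torch.jit.scope('Foo', out):
    out = torch.tanh(out)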