
Commit 2b47480

lantiga authored and soumith committed
Scopes 0.3.1 backport (#5153)
* Introduce scopes during tracing (#3016)
* Fix segfault during ONNX export
* Further fix to tracing scope (#4558)
* Set missing temporary scope in callPySymbolicMethod
* Use expected traces in all scope tests
* Fix tracking of tracing scopes during ONNX pass (#4524)
* Fix tracking of tracing scopes during ONNX pass
* Use ResourceGuard to manage setting a temporary current scope in Graph
* Add tests for ONNX pass scopes
* Remove unused num_classes argument
* Expose node scopeName to python (#4200)
* Inherit JIT scopes when cloning only when it's correct. It's correct only when the new graph owns the same scope tree as the original one; we can end up with dangling pointers otherwise.
* Fixes after cherry-picking, still one test to go
* Fix for last failing test after scope cherry-pick
* Fix linting issue
1 parent 902d57b commit 2b47480

15 files changed: +310 -19 lines
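A note on the design before the per-file diffs: each Graph now owns a trie of Scope objects, every Node records a pointer to the scope that was current when it was created, and the printer joins the scope names from the root, which is where the scope: Foo/Bar style annotations in the expect files below come from. A minimal standalone sketch of that idea follows; it uses simplified names and types and is not the actual torch/csrc/jit/ir.h code.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy model of the scope trie: each Scope owns its children, so the raw
// pointer a node would store stays valid for as long as the trie lives.
struct Scope {
  Scope* parent;
  std::string name;
  std::vector<std::unique_ptr<Scope>> children;

  explicit Scope(Scope* p = nullptr, std::string n = "")
      : parent(p), name(std::move(n)) {}

  // Create a child scope and return a pointer into the trie (cf. Scope::push).
  Scope* push(const std::string& child) {
    children.emplace_back(new Scope(this, child));
    return children.back().get();
  }

  // Join names from the root, e.g. "Foo/Bar" (cf. Scope::namesFromRoot).
  std::string namesFromRoot(const std::string& sep = "/") const {
    if (parent == nullptr) return name;              // root: empty name
    if (parent->parent == nullptr) return name;      // direct child of the root
    return parent->namesFromRoot(sep) + sep + name;  // prepend the parent path
  }
};

int main() {
  Scope root;                       // owned by the Graph in the real code
  Scope* current = &root;           // plays the role of Graph::current_scope_

  current = current->push("Foo");   // Graph::push_scope("Foo")
  Scope* foo = current;             // a node created now would record this pointer
  current = current->push("Bar");   // nested Graph::push_scope("Bar")
  std::cout << current->namesFromRoot() << "\n";   // prints Foo/Bar
  current = current->parent;        // Graph::pop_scope()
  std::cout << foo->namesFromRoot() << "\n";       // prints Foo
  return 0;
}

Because each scope owns its children and the trie only grows, a node's raw Scope pointer stays valid for the lifetime of the Graph, which is the invariant the comment added to ir.h below relies on.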

test/expect/TestJit.test_batchnorm.expect

Lines changed: 1 addition & 1 deletion
@@ -3,6 +3,6 @@ graph(%1 : Double(2, 2)
       %3 : Double(2)
       %4 : Double(2)
       %5 : Double(2)) {
-  %7 : Double(2, 2), %8 : Handle = CppOp[N5torch8autograd16BatchNormForwardE](%1, %2, %3), uses = [[%0.i0], []];
+  %7 : Double(2, 2), %8 : Handle = CppOp[N5torch8autograd16BatchNormForwardE](%1, %2, %3), uses = [[%0.i0], []], scope: BatchNorm2d;
   return (%7);
 }
Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 graph(%1 : Double(20, 16, 50, 40)
       %2 : Double(13, 16, 3, 3)) {
-  %4 : UNKNOWN_TYPE = Undefined(), uses = [%3.i2];
-  %5 : Double(20, 13, 48, 38), %6 : Handle = CppOp[ConvForward](%1, %2, %4), uses = [[%0.i0], []];
+  %4 : UNKNOWN_TYPE = Undefined(), uses = [%3.i2], scope: Conv2d;
+  %5 : Double(20, 13, 48, 38), %6 : Handle = CppOp[ConvForward](%1, %2, %4), uses = [[%0.i0], []], scope: Conv2d;
   return (%5);
 }
Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 graph(%1 : Double(2, 2)) {
-  %3 : Double(2, 2), %4 : Handle = ^Dropout(0.6, True, False)(%1), uses = [[%0.i0], []];
+  %3 : Double(2, 2), %4 : Handle = ^Dropout(0.6, True, False)(%1), uses = [[%0.i0], []], scope: Dropout;
   return (%3);
 }
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+graph(%1 : Double(1)
+      %2 : Double(1)) {
+  %3 : Double(1) = add[alpha={1}](%1, %2), uses = [%4.i1];
+  %4 : Double(1) = mul(%1, %3), uses = [%5.i0], scope: Foo;
+  %5 : Double(1) = tanh(%4), uses = [%6.i0], scope: Foo/Bar;
+  %6 : Double(1) = sigmoid(%5), uses = [%0.i0], scope: Foo;
+  return (%6);
+}
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+graph(%1 : Double(1, 3, 227, 227)
+      %2 : Double(64, 3, 11, 11)
+      %3 : Double(64)) {
+  %5 : UNKNOWN_TYPE = Conv[kernel_shape=[11, 11], strides=[4, 4], pads=[2, 2, 2, 2], dilations=[1, 1], group=1](%1, %2), uses = [[%6.i0]], scope: Net/Sequential[features]/Conv2d[0];
+  %6 : Double(1, 64, 56, 56) = Add[broadcast=1, axis=1](%5, %3), uses = [%7.i0], scope: Net/Sequential[features]/Conv2d[0];
+  %7 : Double(1, 64, 56, 56) = Relu(%6), uses = [%8.i0], scope: Net/Sequential[features]/ReLU[1];
+  %8 : Double(1, 64, 27, 27) = MaxPool[kernel_shape=[3, 3], pads=[0, 0], strides=[2, 2]](%7), uses = [%0.i0], scope: Net/Sequential[features]/MaxPool2d[2];
+  return (%8);
+}
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+graph(%1 : Double(2)) {
+  %2 : Double(2) = Softmax[axis=0](%1), uses = [%3.i0], scope: Net;
+  %3 : Double(2) = Log(%2), uses = [%0.i0], scope: Net;
+  return (%3);
+}

test/test_jit.py

Lines changed: 63 additions & 0 deletions
@@ -50,6 +50,12 @@ def LSTMCellC(*args, **kwargs):
 class TestJit(TestCase):
     maxDiff = None
 
+    def assertExpectedTrace(self, trace, *args, **kwargs):
+        torch._C._jit_pass_lint(trace)
+        torch._C._jit_pass_dce(trace)
+        torch._C._jit_pass_lint(trace)
+        self.assertExpected(str(trace), *args, **kwargs)
+
     def test_simple(self):
         x = Variable(torch.Tensor([0.4]), requires_grad=True)
         y = Variable(torch.Tensor([0.7]), requires_grad=True)
@@ -61,6 +67,63 @@ def f(x, y):
         torch._C._jit_pass_lint(trace)
         self.assertExpected(str(trace))
 
+    def test_scopes(self):
+        x = Variable(torch.Tensor([0.4]), requires_grad=True)
+        y = Variable(torch.Tensor([0.7]), requires_grad=True)
+
+        def f(x, y):
+            out = x + y
+            with torch.jit.scope('Foo', out):
+                out = x * out
+                with torch.jit.scope('Bar', out):
+                    out = torch.tanh(out)
+                out = torch.sigmoid(out)
+            return out
+
+        trace, z = torch.jit.trace(f, (x, y), nderivs=0)
+        torch._C._jit_pass_lint(trace)
+        self.assertExpected(str(trace))
+
+    def test_scopes_intermediate_node(self):
+
+        class Net(nn.Module):
+            def forward(self, x):
+                return F.log_softmax(x, dim=0)
+
+        net = Net()
+        t = Variable(torch.ones(2), requires_grad=True)
+        trace, _ = torch.jit.trace(net, (t, ))
+        torch.onnx._optimize_trace(trace)
+
+        self.assertExpectedTrace(trace)
+
+    def test_scopes_identity_node(self):
+
+        class Net(nn.Module):
+
+            def __init__(self):
+                super(Net, self).__init__()
+                self.features = nn.Sequential(
+                    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+                    nn.ReLU(inplace=True),
+                    nn.MaxPool2d(kernel_size=3, stride=2),
+                )
+
+            def forward(self, x):
+                x = self.features(x)
+                return x
+
+        model = Net()
+
+        t = Variable(torch.ones(1, 3, 227, 227), requires_grad=True)
+
+        with torch.onnx.set_training(model, False):
+            trace, _ = torch.jit.trace(model, (t, ))
+
+        torch.onnx._optimize_trace(trace)
+
+        self.assertExpectedTrace(trace)
+
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_lstm_fusion(self):
         input = Variable(torch.randn(3, 10).cuda())

torch/csrc/jit/ir.cpp

Lines changed: 9 additions & 1 deletion
@@ -263,7 +263,15 @@ std::ostream& printNode(std::ostream & out, const Node * n, std::vector<const No
   } else {
     emitUses(out,n);
   }
-  out << "];\n";
+  out << "]";
+  std::string scopeName = n->scopeName();
+  if (scopeName.empty()) {
+    out << ";\n";
+  }
+  else {
+    out << ", ";
+    out << "scope: " << scopeName << ";\n";
+  }
   return out;
 }

torch/csrc/jit/ir.h

Lines changed: 120 additions & 10 deletions
@@ -69,6 +69,64 @@ struct SourceLocation {
   std::string python_traceback;
 };
 
+// Scope is a node of a trie that represents the tree of nested scopes.
+// Individual scopes are pushed and popped from Graph, which holds a
+// pointer to the current scope. Each Node in Graph holds a pointer
+// to the scope that was current when the node was created.
+// The trie never needs to shrink, it only grows until it is disposed
+// of when Graph is deallocated. Hence, pointers to scopes held by nodes
+// will always be valid as long as Graph is alive.
+struct Scope {
+private:
+  Scope* parent_;
+  Symbol name_;
+  std::vector<std::unique_ptr<Scope> > children_;
+public:
+  Scope() {
+    name_ = stringToSymbol("");
+    parent_ = NULL;
+  }
+  Scope(Scope* parent, Symbol name) {
+    name_ = name;
+    parent_ = parent;
+  }
+  Scope* push(Symbol name) {
+    children_.push_back(std::unique_ptr<Scope>(new Scope(this, name)));
+    return children_.back().get();
+  }
+  Scope* parent() {
+    if (parent_ == NULL) {
+      throw std::runtime_error("Cannot get parent from Scope with no parent");
+    }
+    return parent_;
+  }
+  bool isRoot() {
+    return parent_ == NULL;
+  }
+  Scope* getRoot() {
+    Scope* current = this;
+    while (current->parent_) {
+      current = current->parent_;
+    }
+    return current;
+  }
+  Symbol name() {
+    return name_;
+  }
+  std::string namesFromRoot(const std::string& separator="/") {
+    std::string out = std::string(symbolToString(this->name_));
+    if (this->isRoot()) {
+      return out;
+    }
+    Scope* parent = this->parent_;
+    while (!parent->isRoot()) {
+      out = std::string(symbolToString(parent->name_)) + separator + out;
+      parent = parent->parent_;
+    }
+    return out;
+  }
+};
+
 // the list types are intentionally simple, but we type-def
 // them here so if we need to change them, refactoring will be easier
 using node_list = std::vector<Node*>;
@@ -123,6 +181,7 @@ struct Node : public Attributes<Node> {
   size_t stage_ = 0; // 0-forward, 1-backward, 2-double-backward,...
   std::string debug_name_;
   std::shared_ptr<SourceLocation> source_location_;
+  Scope* scope_;
 protected:
   TypePtr type_;
   Node(Graph * graph_, NodeKind kind_); //defined after graph
@@ -188,6 +247,18 @@ struct Node : public Attributes<Node> {
   size_t stage() const {
     return stage_;
   }
+  Scope* scope() {
+    return scope_;
+  }
+  void setScope(Scope* scope) {
+    scope_ = scope;
+  }
+  std::string scopeName() const {
+    if (scope_ == NULL) {
+      return "";
+    }
+    return scope_->namesFromRoot();
+  }
   // NB: This returns an ArrayRef; that means that it will
   // get invalidated if you resize inputs (e.g., using addInput)
   // We can't return a std::vector<Node*>& because there's no
@@ -528,12 +599,7 @@ struct Node : public Attributes<Node> {
   //
   // NB: This does NOT clone stages. You're expected to set the stage correctly
   // if you are going to preserve it.
-  virtual void cloneFrom(Node * s) {
-    if (s->hasType()) setType(s->type());
-    setDebugName(s->debugName());
-    setSourceLocation(s->getSourceLocation());
-    copyAttributes(*s);
-  }
+  virtual void cloneFrom(Node * s);
 };
 
 struct Graph {
@@ -551,18 +617,27 @@ friend struct Node;
 
   size_t new_node_stage_;
 
+  std::shared_ptr<Scope> scope_root_;
+  Scope * current_scope_;
+
   // holds outputs in a way that can be reflected
   // as a Use object
   // also used as the beginning/end of the circular node list to avoid
   // having corner cases where the list is empty.
   Node * const output_;
 
 public:
-  Graph()
+
+  Graph(std::shared_ptr<Scope> scope_root)
   : next_unique_(0)
   , new_node_stage_(0)
+  , scope_root_(scope_root)
+  , current_scope_(scope_root_.get())
   , output_(initOutput(create(kReturn))) {}
 
+  Graph()
+  : Graph( std::make_shared<Scope>()) {}
+
   at::ArrayRef<Node*> inputs() {
     return inputs_;
   }
@@ -618,6 +693,29 @@ friend struct Node;
   Node * addInput() {
     return addInput(create(kParam));
   }
+  void push_scope(const std::string& scope_name) {
+    current_scope_ = current_scope_->push(stringToSymbol(scope_name));
+  }
+  void pop_scope() {
+    current_scope_ = current_scope_->parent();
+  }
+  Scope * current_scope() {
+    return current_scope_;
+  }
+  void set_current_scope(Scope* scope) {
+    if (scope->getRoot() != scope_root_.get()) {
+      throw std::runtime_error("trying to set a scope as current that does not belong to the Graph's scope trie");
+    }
+    current_scope_ = scope;
+  }
+  ResourceGuard set_current_scope_temporary(Scope* scope) {
+    auto prev_scope = current_scope_;
+    this->set_current_scope(scope);
+    return ResourceGuard([prev_scope, this]() { this->current_scope_ = prev_scope; });
+  }
+  std::shared_ptr<Scope> scope_root() {
+    return scope_root_;
+  }
 
   Node * addInput(Node* n) {
     JIT_ASSERT(n->kind() == kParam);
@@ -694,7 +792,8 @@ friend struct Node;
   }
   Node * createFusionGroup() {
     auto n = create(kFusionGroup);
-    n->g_(kSubgraph,std::make_shared<Graph>());
+    auto subgraph = std::make_shared<Graph>(scope_root_);
+    n->g_(kSubgraph, subgraph);
     return n;
   }
   Node * createPythonOp(THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, pyobj_list&& scalar_args);
@@ -764,9 +863,10 @@ inline Node::Node(Graph * graph_, NodeKind kind_) :
   graph_(graph_),
   unique_(graph_->next_unique_++),
   stage_(graph_->new_node_stage_),
+  scope_(graph_->current_scope_) ,
   type_(getInitialType(kind_)) {
-  graph_->all_nodes.emplace(this);
-}
+    graph_->all_nodes.emplace(this);
+  }
 
 inline void Node::destroy() {
   JIT_ASSERT(inGraphList());
@@ -788,6 +888,16 @@ inline Node* Node::makeMultireturn() {
   return select;
 }
 
+inline void Node::cloneFrom(Node * s) {
+  if (s->hasType()) setType(s->type());
+  setDebugName(s->debugName());
+  setSourceLocation(s->getSourceLocation());
+  if (s->owningGraph()->scope_root_ == owningGraph()->scope_root_) {
+    scope_ = s->scope_;
+  }
+  copyAttributes(*s);
+}
+
 // Helper macros for constructing switch statements over Node types
 // instead of heavy-weight visitors
 // read 'between' these defines to see how they turn into a big switch
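A note on set_current_scope_temporary above: it returns a ResourceGuard so the caller can switch the graph's current scope for the duration of a C++ block and have the previous scope restored automatically. Below is a self-contained sketch of that RAII pattern, with an assumed stand-in guard class (torch's real ResourceGuard is not part of this diff) and the graph reduced to just its current-scope field.

#include <cassert>
#include <functional>
#include <string>
#include <utility>

// Assumed stand-in for torch::jit::ResourceGuard: run a callback on destruction.
struct ScopeGuard {
  std::function<void()> on_exit;
  ~ScopeGuard() { if (on_exit) on_exit(); }
};

// Toy graph keeping only the piece that matters here: the current scope.
struct ToyGraph {
  std::string current_scope;

  // Mirrors Graph::set_current_scope_temporary: switch now, restore on guard exit.
  // Build as C++17 so the returned guard is constructed in place (no extra copy).
  ScopeGuard set_current_scope_temporary(std::string scope) {
    std::string prev = current_scope;
    current_scope = std::move(scope);
    return ScopeGuard{[this, prev]() { current_scope = prev; }};
  }
};

int main() {
  ToyGraph g;
  {
    auto guard = g.set_current_scope_temporary("Foo/Bar");
    assert(g.current_scope == "Foo/Bar");  // nodes created here would get Foo/Bar
  }                                         // guard destroyed: previous scope restored
  assert(g.current_scope.empty());
  return 0;
}

The ONNX pass below uses exactly this shape of guard so that a traced node's scope is adopted only while that node is being translated.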

torch/csrc/jit/passes/onnx.cpp

Lines changed: 6 additions & 1 deletion
@@ -33,7 +33,7 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
     throw std::logic_error("ToONNX: tracing state is expired");
   }
 
-  auto new_graph = std::make_shared<Graph>();
+  auto new_graph = std::make_shared<Graph>(state->graph->scope_root());
   std::unordered_map<void*, Node*> new_buffer_map;
 
   torch::autograd::SymbolicContext ctx;
@@ -159,6 +159,8 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
       py_inputs[input_nr++] = py::cast(envFn(input));
     }
 
+    auto scope_guard = ctx.graph->set_current_scope_temporary(n->scope());
+
     py::object raw_output = onnx.attr("_run_symbolic_function")(ctx.graph, n, py_inputs);
 
     processSymbolicOutput(symbolToString(n->kind()), n, raw_output);
@@ -195,6 +197,8 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
      py_symbolic_args[input_nr++] = obj;
     }
 
+    auto scope_guard = ctx.graph->set_current_scope_temporary(op->scope());
+
     // Call the symbolic function
     // Use a little trampoline function so we can give good error messages
     // upon argument mismatch
@@ -218,6 +222,7 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
       // Selects are translated by multi-return nodes.
       JIT_ASSERT(env.count(value) > 0);
     IR_ELSEIFM(CppOp)
+      auto scope_guard = new_graph->set_current_scope_temporary(node->scope());
      if (auto fn = std::dynamic_pointer_cast<autograd::HasSymbolic>(value->fn)) {
        auto outputs = fn->symbolic(&ctx, fmap(node->inputs(), envFn));
        setOutputs(value->name(), node, outputs);
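The guards added above all follow the same per-node pattern: before a traced node is handed to its symbolic translation, the destination graph temporarily adopts that node's scope, so every node the translation emits carries the same annotation, and the guard restores the previous scope before the next node is processed. A hypothetical miniature of that loop follows; the scope is modelled as a plain string and the guard is an assumed helper, not torch's ResourceGuard.

#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Assumed RAII helper (not torch's ResourceGuard): undo the change on destruction.
struct Restore {
  std::function<void()> undo;
  ~Restore() { if (undo) undo(); }
};

struct MiniGraph {
  std::string current_scope;              // plays the role of Graph::current_scope_
  std::vector<std::string> node_scopes;   // scope recorded by each "created" node

  // Like set_current_scope_temporary(n->scope()); build as C++17 for in-place return.
  Restore adopt_scope(std::string s) {
    std::string prev = current_scope;
    current_scope = std::move(s);
    return Restore{[this, prev]() { current_scope = prev; }};
  }

  void create_node() { node_scopes.push_back(current_scope); }
};

int main() {
  // Scopes of the traced nodes being converted, in order (hypothetical values).
  std::vector<std::string> traced = {"Net/Conv2d[0]", "Net/ReLU[1]"};

  MiniGraph onnx_graph;
  for (const std::string& scope : traced) {
    auto guard = onnx_graph.adopt_scope(scope);  // adopt the traced node's scope
    onnx_graph.create_node();                    // symbolic translation emits node(s)
  }                                              // guard restores the previous scope

  assert(onnx_graph.node_scopes[0] == "Net/Conv2d[0]");
  assert(onnx_graph.node_scopes[1] == "Net/ReLU[1]");
  assert(onnx_graph.current_scope.empty());
  return 0;
}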
