41 changes: 41 additions & 0 deletions test/test_jit.py
@@ -158,6 +158,47 @@ def f(x, y):
        trace, z = torch.jit.trace(f, (x, y), nderivs=0)
        self.assertExpectedTrace(trace)

        class Net(nn.Module):

            def forward(self, x):
                return F.log_softmax(x, dim=0)

        net = Net()
        t = Variable(torch.ones(2), requires_grad=True)
        trace, _ = torch.jit.trace(net, (t, ))
        torch.onnx._optimize_trace(trace, False)
        g = torch._C._jit_get_graph(trace)
        for node in g.nodes():
            self.assertTrue(node.scopeName() == 'Net')

        class Net(nn.Module):

            def __init__(self):
                super(Net, self).__init__()
                self.features = nn.Sequential(
                    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2),
                )

            def forward(self, x):
                x = self.features(x)
                return x

        model = Net()

        t = Variable(torch.ones(1, 3, 227, 227), requires_grad=True)

        with torch.onnx.set_training(model, False):
            trace, _ = torch.jit.trace(model, (t, ))

        torch.onnx._optimize_trace(trace, False)
        graph = torch._C._jit_get_graph(trace)
        nodes = list(graph.nodes())

        self.assertTrue(nodes[0].scopeName() == 'Net/Sequential[features]/Conv2d[0]')
        self.assertTrue(nodes[1].scopeName() == 'Net/Sequential[features]/ReLU[1]')
        self.assertTrue(nodes[2].scopeName() == 'Net/Sequential[features]/MaxPool2d[2]')

    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_lstm_fusion(self):
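Note on the assertions above: each scopeName() encodes the module path recorded during tracing, one segment per module, with the root module rendered as its class name (Net) and each submodule as ChildClassName[attribute_name]; the 0 in Conv2d[0] is the child's key inside the Sequential container.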
11 changes: 11 additions & 0 deletions torch/csrc/jit/ir.h
@@ -711,6 +711,17 @@ friend struct Value;
  Scope * current_scope() {
    return current_scope_;
  }
  void set_current_scope(Scope* scope) {
    if (scope->getRoot() != scope_root_.get()) {
      throw std::runtime_error("trying to set a scope as current that does not belong to the Graph's scope trie");
    }
    current_scope_ = scope;
  }
  ResourceGuard set_current_scope_temporary(Scope* scope) {
    auto prev_scope = current_scope_;
    this->set_current_scope(scope);
    return ResourceGuard([prev_scope, this]() { this->current_scope_ = prev_scope; });
  }
  std::shared_ptr<Scope> scope_root() {
    return scope_root_;
  }
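The returned ResourceGuard gives set_current_scope_temporary RAII semantics: whatever scope was current is reinstated automatically when the guard is destroyed, even on early return or exception. Below is a minimal, self-contained sketch of that pattern, using a plain string in place of the real Scope trie and a simplified stand-in for torch::jit's ResourceGuard (assumes C++17 for guaranteed copy elision of the returned guard):

// Sketch only: simplified stand-ins, not the real torch::jit types.
#include <functional>
#include <iostream>
#include <string>
#include <utility>

// Stand-in for torch::jit::ResourceGuard: runs a callback on destruction.
struct ResourceGuard {
  explicit ResourceGuard(std::function<void()> fn) : fn_(std::move(fn)) {}
  ~ResourceGuard() { fn_(); }
  ResourceGuard(const ResourceGuard&) = delete;
  ResourceGuard& operator=(const ResourceGuard&) = delete;
 private:
  std::function<void()> fn_;
};

struct Graph {
  std::string current_scope = "top";
  // Mirrors set_current_scope_temporary: install `scope`, return a guard
  // that restores the previous scope when it is destroyed.
  ResourceGuard set_current_scope_temporary(std::string scope) {
    auto prev = current_scope;
    current_scope = std::move(scope);
    return ResourceGuard([prev, this] { current_scope = prev; });
  }
};

int main() {
  Graph g;
  {
    auto guard = g.set_current_scope_temporary("Net/Sequential[features]");
    std::cout << g.current_scope << "\n";  // prints Net/Sequential[features]
  }  // guard destroyed here: previous scope restored
  std::cout << g.current_scope << "\n";    // prints top
}

This is the shape of use in passes/onnx.cpp below: install the node's scope, run the symbolic, and let the guard put the old scope back.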
3 changes: 2 additions & 1 deletion torch/csrc/jit/passes/onnx.cpp
@@ -86,7 +86,6 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state, bool aten) {
        // Copy over source location information to all nodes created by
        // the symbolic
        outputs[i]->node()->setSourceLocation(node->getSourceLocation());
        outputs[i]->node()->setScope(node->scope());
        env[old] = outputs[i];
      } else {
        // Null output means that the ONNX op doesn't have outputs corresponding
@@ -151,6 +150,8 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state, bool aten) {
      py_inputs[input_nr++] = py::cast(envFn(input));
    }

    auto scope_guard = ctx.graph->set_current_scope_temporary(n->scope());

    py::object raw_output = onnx.attr("_run_symbolic_function")(ctx.graph, n, py_inputs, aten);

    processSymbolicOutput(symbolToString(n->kind()), n, raw_output);
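Taken together, the two hunks in this file move scope propagation onto the graph itself: rather than stamping each symbolic output with setScope after the fact (the line dropped in the first hunk), the current scope is temporarily switched to n->scope() before _run_symbolic_function runs, so nodes created by the symbolic should pick up that scope at creation time, and the guard restores the previous scope once the call returns.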