4 changes: 2 additions & 2 deletions test/expect/TestJit.test_conv.expect
@@ -1,6 +1,6 @@
graph(%0 : Double(20, 16, 50, 40)
%1 : Double(13, 16, 3, 3)) {
%2 : UNKNOWN_TYPE = Undefined()
%3 : Double(20, 13, 48, 38), %4 : Handle = CppOp[ConvForward](%0, %1, %2)
%2 : UNKNOWN_TYPE = Undefined(), scope: Conv2d
%3 : Double(20, 13, 48, 38), %4 : Handle = CppOp[ConvForward](%0, %1, %2), scope: Conv2d
return (%3);
}
2 changes: 1 addition & 1 deletion test/expect/TestJit.test_dropout.expect
@@ -1,4 +1,4 @@
graph(%0 : Double(2, 2)) {
%1 : Double(2, 2), %2 : Handle = ^Dropout(0.6, True, False)(%0)
%1 : Double(2, 2), %2 : Handle = ^Dropout(0.6, True, False)(%0), scope: Dropout
return (%1);
}
8 changes: 8 additions & 0 deletions test/expect/TestJit.test_scopes.expect
@@ -0,0 +1,8 @@
graph(%0 : Double(1)
%1 : Double(1)) {
%2 : Double(1) = add[alpha={1}](%0, %1)
%3 : Double(1) = mul(%0, %2), scope: Foo
%4 : Double(1) = tanh(%3), scope: Foo/Bar
%5 : Double(1) = sigmoid(%4), scope: Foo
return (%5);
}
17 changes: 17 additions & 0 deletions test/test_jit.py
@@ -69,6 +69,23 @@ def f(x, y):
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))

def test_scopes(self):
x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)

def f(x, y):
out = x + y
with torch.jit.scope('Foo', out):
out = x * out
with torch.jit.scope('Bar', out):
out = torch.tanh(out)
out = torch.sigmoid(out)
return out

trace, z = torch.jit.trace(f, (x, y), nderivs=0)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))

@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_lstm_fusion(self):
input = Variable(torch.randn(3, 10).float().cuda())
10 changes: 9 additions & 1 deletion torch/csrc/jit/ir.cpp
@@ -239,7 +239,15 @@ std::ostream& printNode(std::ostream & out, const Node * n, std::vector<const No
printAttributes(out,n);
}
IR_END()
out << "(" << n->inputs() << ")\n";
out << "(" << n->inputs() << ")";
std::string scopeName = n->scopeName();
if (scopeName.empty()) {
out << "\n";
}
else {
out << ", ";
out << "scope: " << scopeName << "\n";
}
return out;
}

116 changes: 113 additions & 3 deletions torch/csrc/jit/ir.h
@@ -70,6 +70,64 @@ struct SourceLocation {
std::string python_traceback;
};

// Scope is a node of a trie that represents the tree of nested scopes.
// Individual scopes are pushed and popped from Graph, which holds a
// pointer to the current scope. Each Node in Graph holds a pointer
// to the scope that was current when the node was created.
// The trie never needs to shrink; it only grows until it is disposed
// of when Graph is deallocated. Hence, pointers to scopes held by nodes
// will always be valid as long as Graph is alive.
struct Scope {
private:
Scope* parent_;
Symbol name_;
std::vector<std::unique_ptr<Scope> > children_;
public:
Scope() {
name_ = stringToSymbol("");
parent_ = NULL;
}
Scope(Scope* parent, Symbol name) {
name_ = name;
parent_ = parent;
}
Scope* push(Symbol name) {
children_.push_back(std::unique_ptr<Scope>(new Scope(this, name)));
return children_.back().get();
}
Scope* parent() {
if (parent_ == NULL) {
throw std::runtime_error("Cannot get parent from Scope with no parent");
}
return parent_;
}
bool isRoot() {
return parent_ == NULL;
}
Scope* getRoot() {
Scope* current = this;
while (current->parent_) {
current = current->parent_;
}
return current;
}
Symbol name() {
return name_;
}
std::string namesFromRoot(const std::string& separator="/") {
std::string out = std::string(symbolToString(this->name_));
if (this->isRoot()) {
return out;
}
Scope* parent = this->parent_;
while (!parent->isRoot()) {
out = std::string(symbolToString(parent->name_)) + separator + out;
parent = parent->parent_;
}
return out;
}
};
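As a rough illustration of the trie described in the comment above, the following is a minimal standalone sketch of the same idea in Python (not part of this diff): plain strings stand in for Symbols, the hypothetical PyScope class mirrors push()/namesFromRoot(), and the comments map each call to the corresponding Graph operation.

# Sketch only: a Python mirror of the scope trie above. It shows how nested
# pushes compose into "Foo/Bar"-style names by walking parent pointers back
# to the (unnamed) root scope.
class PyScope(object):
    def __init__(self, parent=None, name=""):
        self.parent = parent
        self.name = name
        self.children = []            # owned children; the trie only grows

    def push(self, name):
        child = PyScope(self, name)
        self.children.append(child)
        return child

    def is_root(self):
        return self.parent is None

    def names_from_root(self, separator="/"):
        out = self.name
        node = self.parent
        while node is not None and not node.is_root():
            out = node.name + separator + out
            node = node.parent
        return out

root = PyScope()
foo = root.push("Foo")                # cf. Graph::push_scope("Foo")
bar = foo.push("Bar")                 # cf. Graph::push_scope("Bar")
print(bar.names_from_root())          # "Foo/Bar"
print(foo.names_from_root())          # "Foo" (the current scope after pop_scope)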

// the list types are intentionally simple, but we type-def
// them here so if we need to change them, refactoring will be easier
using node_list = std::vector<Node*>;
@@ -139,6 +197,9 @@ struct Value {
const Node * node() const {
return node_;
}
Scope* scope();
void setScope(Scope* scope);
std::string scopeName() const;
Graph * owningGraph();
const Graph * owningGraph() const;
// TODO: make this more const correct
@@ -197,6 +258,7 @@ struct Node : public Attributes<Node> {
Graph* graph_;
std::shared_ptr<SourceLocation> source_location_;
size_t stage_;
Scope* scope_;
protected:
Node(Graph * graph_, NodeKind kind_); //defined after graph
public:
@@ -223,6 +285,18 @@ struct Node : public Attributes<Node> {
stage_ = s;
return this;
}
Scope* scope() {
return scope_;
}
void setScope(Scope* scope) {
scope_ = scope;
}
std::string scopeName() const {
if (scope_ == NULL) {
return "";
}
return scope_->namesFromRoot();
}
// NB: This returns an ArrayRef; that means that it will
// get invalidated if you resize inputs (e.g., using addInput)
// We can't return a std::vector<Node*>& because there's no
@@ -534,6 +608,7 @@ struct Node : public Attributes<Node> {
// if you are going to preserve it.
virtual void cloneFrom(Node * s) {
setSourceLocation(s->getSourceLocation());
scope_ = s->scope_;
copyAttributes(*s);
}
};
@@ -556,6 +631,9 @@ friend struct Value;

size_t new_node_stage_;

std::shared_ptr<Scope> scope_root_;

Scope * current_scope_;

// holds outputs in a way that can be reflected
// as a Use object
// also used as the beginning/end of the circular node list to avoid
@@ -564,11 +642,17 @@ friend struct Value;
Node * const input_;

public:
Graph()

Graph(std::shared_ptr<Scope> scope_root)
: next_unique_(0)
, new_node_stage_(0)
, scope_root_(scope_root)
, current_scope_(scope_root_.get())
, output_(initOutput(create(kReturn, 0))), input_(create(kParam, 0)) {}

Graph()
: Graph( std::make_shared<Scope>()) {}

at::ArrayRef<Value*> inputs() {
return input_->outputs();
}
@@ -621,6 +705,18 @@ friend struct Value;
const Node * return_node() const {
return output_;
}
void push_scope(const std::string& scope_name) {
current_scope_ = current_scope_->push(stringToSymbol(scope_name));
}
void pop_scope() {
current_scope_ = current_scope_->parent();
}
Scope * current_scope() {
return current_scope_;
}
std::shared_ptr<Scope> scope_root() {
return scope_root_;
}
Value * addInput(std::string name="") {
Value * v = input_->addOutput();
if (name != "") v->setUniqueName(name);
@@ -676,7 +772,8 @@ friend struct Value;
}
Node * createFusionGroup() {
auto n = create(kFusionGroup, 0);
n->g_(kSubgraph,std::make_shared<Graph>());
auto subgraph = std::make_shared<Graph>(scope_root_);
n->g_(kSubgraph, subgraph);
return n;
}
Node * createPythonOp(THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, std::vector<VariableFlags> && var_flags, pyobj_list&& scalar_args);
@@ -759,6 +856,18 @@ inline Value::Value(Node * node_, size_t offset_)
node_->graph_->all_values.emplace(this);
}

inline Scope* Value::scope() {
return node()->scope();
}

inline void Value::setScope(Scope* scope) {
node()->setScope(scope);
}

inline std::string Value::scopeName() const {
return node()->scopeName();
}

inline Graph * Value::owningGraph() {
return node()->owningGraph();
}
@@ -779,7 +888,8 @@ inline void Value::replaceAllUsesWith(Value * newValue) {
inline Node::Node(Graph * graph_, NodeKind kind_) :
kind_(kind_),
graph_(graph_),
stage_(graph_->new_node_stage_) {
stage_(graph_->new_node_stage_),
scope_(graph_->current_scope_) {
graph_->all_nodes.emplace(this);
}

9 changes: 8 additions & 1 deletion torch/csrc/jit/passes/onnx.cpp
@@ -31,7 +31,7 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
throw std::logic_error("ToONNX: tracing state is expired");
}

auto new_graph = std::make_shared<Graph>();
auto new_graph = std::make_shared<Graph>(state->graph->scope_root());
std::unordered_map<void*, Value*> new_buffer_map;

torch::autograd::SymbolicContext ctx;
@@ -137,6 +137,10 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
throw std::runtime_error(ss.str());
}

for (auto& el: outputs) {
el->setScope(n->scope());
}

setOutputs(op_name, n, outputs);
};

@@ -208,6 +212,9 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
IR_IFM(node, CppOp)
if (auto fn = std::dynamic_pointer_cast<autograd::HasSymbolic>(value->fn)) {
auto outputs = fn->symbolic(&ctx, fmap(node->inputs(), envFn), node->getSourceLocation());
for (auto& el: outputs) {
el->setScope(node->scope());
}
setOutputs(value->name(), node, outputs);
} else {
cloneNode(node);
16 changes: 15 additions & 1 deletion torch/csrc/jit/python_tracer.cpp
@@ -19,7 +19,7 @@ namespace torch { namespace jit {

void initPythonTracerBindings(PyObject* module_) {
auto m = py::handle(module_).cast<py::module>();
py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState")
py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
// NB: no constructor; you have to get it from C++ code
.def("__repr__", [](const TracingState& s) {
std::ostringstream ss;
@@ -32,6 +32,14 @@ void initPythonTracerBindings(PyObject* module_) {
ss << *s.graph;
return ss.str();
})
.def("push_scope", [](TracingState& s, const std::string& scope_name) {
ASSERT_UNEXPIRED("push_scope");
s.push_scope(scope_name);
})
.def("pop_scope", [](TracingState& s) {
ASSERT_UNEXPIRED("pop_scope");
s.pop_scope();
})
.def("export", [](TracingState& s, const std::vector<at::Tensor>& initializers, int64_t onnx_opset_version) {
ASSERT_UNEXPIRED("export");
return py::bytes(ExportGraph(s.graph, initializers, onnx_opset_version));
@@ -52,6 +60,12 @@ void initPythonTracerBindings(PyObject* module_) {
m.def("_tracer_exit", [](variable_list var_outputs) {
tracer::exit(var_outputs);
});
m.def("_get_tracing_state", [](const variable_list& vars) {
return getTracingState(vars);
});
m.def("_is_tracing", [](const variable_list& vars) {
return isTracing(vars);
});
}

}} // namespace torch::jit
8 changes: 8 additions & 0 deletions torch/csrc/jit/tracer_state.h
@@ -74,6 +74,14 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
bool is_complete() const {
return !is_expired() && graph->stage() == num_stages - 1;
}

void push_scope(const std::string& scope_name) {
graph->push_scope(scope_name);
}

void pop_scope() {
graph->pop_scope();
}
};

struct ValueTracingStateElem {
27 changes: 27 additions & 0 deletions torch/jit/__init__.py
@@ -19,6 +19,30 @@
_flatten = torch._C._jit_flatten


# This global variable is set when we are tracing a *forwards* computation.
# It is intended to be a cheap way to test if tracing has occurred, before
# doing the slower path using `get_tracing_state` (below).
_tracing = False


def get_tracing_state(args):
if not torch._C._is_tracing(args):
return None
return torch._C._get_tracing_state(args)


@contextlib.contextmanager
def scope(scope_name, *vars):
tracing_state = get_tracing_state(vars)
if tracing_state:
tracing_state.push_scope(scope_name)
try:
yield
finally:
if tracing_state:
tracing_state.pop_scope()
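
A hedged usage sketch (not part of this diff) of how the new scope() context manager and the cheap _tracing flag are meant to combine; the Labeled wrapper and its name argument are hypothetical, chosen only to show how scope names could be attached to a submodule's nodes while a trace is active.

# Sketch only: a hypothetical wrapper that labels the nodes produced by an
# inner module during tracing. The cheap module-level `_tracing` flag is
# checked first; `scope()` itself already degrades to a no-op when no
# tracing state is found for the given variables.
import torch
import torch.jit
import torch.nn as nn


class Labeled(nn.Module):
    def __init__(self, inner, name):
        super(Labeled, self).__init__()
        self.inner = inner
        self.name = name

    def forward(self, *args):
        # Assumes *args are Variables, since scope() hands them to the tracer.
        if not torch.jit._tracing:
            return self.inner(*args)
        with torch.jit.scope(self.name, *args):
            return self.inner(*args)

Wrapping a submodule as Labeled(conv, 'Conv2d') and tracing the enclosing model would be expected to yield nodes annotated like the "scope: Conv2d" entries in the updated expect files above.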


def compile(arg=None, nderivs=1, optimize=True, enabled=True):
"""
Decorator which marks a function or module class as eligible for
@@ -237,13 +261,16 @@ def __init__(self, inner, nderivs=0):
self.nderivs = nderivs

def forward(self, *args):
global _tracing
in_vars = _flatten(args)
# NOTE: use full state, because we need it for BatchNorm export
# This differs from the compiler path, which doesn't support it at the moment.
module_state = list(self.state_dict(keep_vars=True).values())
trace = torch._C._tracer_enter(in_vars + module_state, self.nderivs)
_tracing = True
out = self.inner(*args)
out_vars = _flatten(out)
_tracing = False
torch._C._tracer_exit(out_vars)
return trace, out
