@@ -70,6 +70,64 @@ struct SourceLocation {
7070 std::string python_traceback;
7171};
7272
73+ // Scope is a node of a trie that represents the tree of nested scopes.
74+ // Individual scopes are pushed and popped from Graph, which holds a
75+ // pointer to the current scope. Each Node in Graph holds a pointer
76+ // to the scope that was current when the node was created.
77+ // The trie never needs to shrink, it only grows until it is disposed
78+ // of when Graph is deallocated. Hence, pointers to scopes held by nodes
79+ // will always be valid as long as Graph is alive.
80+ struct Scope {
81+ private:
82+ Scope* parent_;
83+ Symbol name_;
84+ std::vector<std::unique_ptr<Scope> > children_;
85+ public:
86+ Scope () {
87+ name_ = stringToSymbol (" " );
88+ parent_ = NULL ;
89+ }
90+ Scope (Scope* parent, Symbol name) {
91+ name_ = name;
92+ parent_ = parent;
93+ }
94+ Scope* push (Symbol name) {
95+ children_.push_back (std::unique_ptr<Scope>(new Scope (this , name)));
96+ return children_.back ().get ();
97+ }
98+ Scope* parent () {
99+ if (parent_ == NULL ) {
100+ throw std::runtime_error (" Cannot get parent from Scope with no parent" );
101+ }
102+ return parent_;
103+ }
104+ bool isRoot () {
105+ return parent_ == NULL ;
106+ }
107+ Scope* getRoot () {
108+ Scope* current = this ;
109+ while (current->parent_ ) {
110+ current = current->parent_ ;
111+ }
112+ return current;
113+ }
114+ Symbol name () {
115+ return name_;
116+ }
117+ std::string namesFromRoot (const std::string& separator=" /" ) {
118+ std::string out = std::string (symbolToString (this ->name_ ));
119+ if (this ->isRoot ()) {
120+ return out;
121+ }
122+ Scope* parent = this ->parent_ ;
123+ while (!parent->isRoot ()) {
124+ out = std::string (symbolToString (parent->name_ )) + separator + out;
125+ parent = parent->parent_ ;
126+ }
127+ return out;
128+ }
129+ };
130+
73131// the list types are intentionally simple, but we type-def
74132// them here so if we need to change them, refactoring will be easier
75133using node_list = std::vector<Node*>;
@@ -139,6 +197,9 @@ struct Value {
139197 const Node * node () const {
140198 return node_;
141199 }
200+ Scope* scope ();
201+ void setScope (Scope* scope);
202+ std::string scopeName () const ;
142203 Graph * owningGraph ();
143204 const Graph * owningGraph () const ;
144205 // TODO: make this more const correct
@@ -197,6 +258,7 @@ struct Node : public Attributes<Node> {
197258 Graph* graph_;
198259 std::shared_ptr<SourceLocation> source_location_;
199260 size_t stage_;
261+ Scope* scope_;
200262protected:
201263 Node (Graph * graph_, NodeKind kind_); // defined after graph
202264public:
@@ -223,6 +285,18 @@ struct Node : public Attributes<Node> {
223285 stage_ = s;
224286 return this ;
225287 }
288+ Scope* scope () {
289+ return scope_;
290+ }
291+ void setScope (Scope* scope) {
292+ scope_ = scope;
293+ }
294+ std::string scopeName () const {
295+ if (scope_ == NULL ) {
296+ return " " ;
297+ }
298+ return scope_->namesFromRoot ();
299+ }
226300 // NB: This returns an ArrayRef; that means that it will
227301 // get invalidated if you resize inputs (e.g., using addInput)
228302 // We can't return a std::vector<Node*>& because there's no
@@ -534,6 +608,7 @@ struct Node : public Attributes<Node> {
534608 // if you are going to preserve it.
535609 virtual void cloneFrom (Node * s) {
536610 setSourceLocation (s->getSourceLocation ());
611+ scope_ = s->scope_ ;
537612 copyAttributes (*s);
538613 }
539614};
@@ -556,6 +631,9 @@ friend struct Value;
556631
557632 size_t new_node_stage_;
558633
634+ std::shared_ptr<Scope> scope_root_;
635+ Scope * current_scope_;
636+
559637 // holds outputs in a way that can be reflected
560638 // as a Use object
561639 // also used as the beginning/end of the circular node list to avoid
@@ -564,11 +642,17 @@ friend struct Value;
564642 Node * const input_;
565643
566644public:
567- Graph ()
645+
646+ Graph (std::shared_ptr<Scope> scope_root)
568647 : next_unique_(0 )
569648 , new_node_stage_(0 )
649+ , scope_root_(scope_root)
650+ , current_scope_(scope_root_.get())
570651 , output_(initOutput(create(kReturn , 0 ))), input_(create(kParam , 0 )) {}
571652
653+ Graph ()
654+ : Graph( std::make_shared<Scope>()) {}
655+
572656 at::ArrayRef<Value*> inputs () {
573657 return input_->outputs ();
574658 }
@@ -621,6 +705,18 @@ friend struct Value;
621705 const Node * return_node () const {
622706 return output_;
623707 }
708+ void push_scope (const std::string& scope_name) {
709+ current_scope_ = current_scope_->push (stringToSymbol (scope_name));
710+ }
711+ void pop_scope () {
712+ current_scope_ = current_scope_->parent ();
713+ }
714+ Scope * current_scope () {
715+ return current_scope_;
716+ }
717+ std::shared_ptr<Scope> scope_root () {
718+ return scope_root_;
719+ }
624720 Value * addInput (std::string name=" " ) {
625721 Value * v = input_->addOutput ();
626722 if (name != " " ) v->setUniqueName (name);
@@ -676,7 +772,8 @@ friend struct Value;
676772 }
677773 Node * createFusionGroup () {
678774 auto n = create (kFusionGroup , 0 );
679- n->g_ (kSubgraph ,std::make_shared<Graph>());
775+ auto subgraph = std::make_shared<Graph>(scope_root_);
776+ n->g_ (kSubgraph , subgraph);
680777 return n;
681778 }
682779 Node * createPythonOp (THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, std::vector<VariableFlags> && var_flags, pyobj_list&& scalar_args);
@@ -759,6 +856,18 @@ inline Value::Value(Node * node_, size_t offset_)
759856 node_->graph_ ->all_values .emplace (this );
760857}
761858
859+ inline Scope* Value::scope () {
860+ return node ()->scope ();
861+ }
862+
863+ inline void Value::setScope (Scope* scope) {
864+ node ()->setScope (scope);
865+ }
866+
867+ inline std::string Value::scopeName () const {
868+ return node ()->scopeName ();
869+ }
870+
762871inline Graph * Value::owningGraph () {
763872 return node ()->owningGraph ();
764873}
@@ -779,7 +888,8 @@ inline void Value::replaceAllUsesWith(Value * newValue) {
779888inline Node::Node (Graph * graph_, NodeKind kind_) :
780889 kind_(kind_),
781890 graph_(graph_),
782- stage_(graph_->new_node_stage_) {
891+ stage_(graph_->new_node_stage_),
892+ scope_(graph_->current_scope_) {
783893 graph_->all_nodes .emplace (this );
784894}
785895
0 commit comments