@@ -69,6 +69,64 @@ struct SourceLocation {
6969 std::string python_traceback;
7070};
7171
72+ // Scope is a node of a trie that represents the tree of nested scopes.
73+ // Individual scopes are pushed and popped from Graph, which holds a
74+ // pointer to the current scope. Each Node in Graph holds a pointer
75+ // to the scope that was current when the node was created.
76+ // The trie never needs to shrink, it only grows until it is disposed
77+ // of when Graph is deallocated. Hence, pointers to scopes held by nodes
78+ // will always be valid as long as Graph is alive.
79+ struct Scope {
80+ private:
81+ Scope* parent_;
82+ Symbol name_;
83+ std::vector<std::unique_ptr<Scope> > children_;
84+ public:
85+ Scope () {
86+ name_ = stringToSymbol (" " );
87+ parent_ = NULL ;
88+ }
89+ Scope (Scope* parent, Symbol name) {
90+ name_ = name;
91+ parent_ = parent;
92+ }
93+ Scope* push (Symbol name) {
94+ children_.push_back (std::unique_ptr<Scope>(new Scope (this , name)));
95+ return children_.back ().get ();
96+ }
97+ Scope* parent () {
98+ if (parent_ == NULL ) {
99+ throw std::runtime_error (" Cannot get parent from Scope with no parent" );
100+ }
101+ return parent_;
102+ }
103+ bool isRoot () {
104+ return parent_ == NULL ;
105+ }
106+ Scope* getRoot () {
107+ Scope* current = this ;
108+ while (current->parent_ ) {
109+ current = current->parent_ ;
110+ }
111+ return current;
112+ }
113+ Symbol name () {
114+ return name_;
115+ }
116+ std::string namesFromRoot (const std::string& separator=" /" ) {
117+ std::string out = std::string (symbolToString (this ->name_ ));
118+ if (this ->isRoot ()) {
119+ return out;
120+ }
121+ Scope* parent = this ->parent_ ;
122+ while (!parent->isRoot ()) {
123+ out = std::string (symbolToString (parent->name_ )) + separator + out;
124+ parent = parent->parent_ ;
125+ }
126+ return out;
127+ }
128+ };
129+
72130// the list types are intentionally simple, but we type-def
73131// them here so if we need to change them, refactoring will be easier
74132using node_list = std::vector<Node*>;
@@ -123,6 +181,7 @@ struct Node : public Attributes<Node> {
123181 size_t stage_ = 0 ; // 0-forward, 1-backward, 2-double-backward,...
124182 std::string debug_name_;
125183 std::shared_ptr<SourceLocation> source_location_;
184+ Scope* scope_;
126185protected:
127186 TypePtr type_;
128187 Node (Graph * graph_, NodeKind kind_); // defined after graph
@@ -188,6 +247,18 @@ struct Node : public Attributes<Node> {
188247 size_t stage () const {
189248 return stage_;
190249 }
250+ Scope* scope () {
251+ return scope_;
252+ }
253+ void setScope (Scope* scope) {
254+ scope_ = scope;
255+ }
256+ std::string scopeName () const {
257+ if (scope_ == NULL ) {
258+ return " " ;
259+ }
260+ return scope_->namesFromRoot ();
261+ }
191262 // NB: This returns an ArrayRef; that means that it will
192263 // get invalidated if you resize inputs (e.g., using addInput)
193264 // We can't return a std::vector<Node*>& because there's no
@@ -528,12 +599,7 @@ struct Node : public Attributes<Node> {
528599 //
529600 // NB: This does NOT clone stages. You're expected to set the stage correctly
530601 // if you are going to preserve it.
531- virtual void cloneFrom (Node * s) {
532- if (s->hasType ()) setType (s->type ());
533- setDebugName (s->debugName ());
534- setSourceLocation (s->getSourceLocation ());
535- copyAttributes (*s);
536- }
602+ virtual void cloneFrom (Node * s);
537603};
538604
539605struct Graph {
@@ -551,18 +617,27 @@ friend struct Node;
551617
552618 size_t new_node_stage_;
553619
620+ std::shared_ptr<Scope> scope_root_;
621+ Scope * current_scope_;
622+
554623 // holds outputs in a way that can be reflected
555624 // as a Use object
556625 // also used as the beginning/end of the circular node list to avoid
557626 // having corner cases where the list is empty.
558627 Node * const output_;
559628
560629public:
561- Graph ()
630+
631+ Graph (std::shared_ptr<Scope> scope_root)
562632 : next_unique_(0 )
563633 , new_node_stage_(0 )
634+ , scope_root_(scope_root)
635+ , current_scope_(scope_root_.get())
564636 , output_(initOutput(create(kReturn ))) {}
565637
638+ Graph ()
639+ : Graph( std::make_shared<Scope>()) {}
640+
566641 at::ArrayRef<Node*> inputs () {
567642 return inputs_;
568643 }
@@ -618,6 +693,29 @@ friend struct Node;
618693 Node * addInput () {
619694 return addInput (create (kParam ));
620695 }
696+ void push_scope (const std::string& scope_name) {
697+ current_scope_ = current_scope_->push (stringToSymbol (scope_name));
698+ }
699+ void pop_scope () {
700+ current_scope_ = current_scope_->parent ();
701+ }
702+ Scope * current_scope () {
703+ return current_scope_;
704+ }
705+ void set_current_scope (Scope* scope) {
706+ if (scope->getRoot () != scope_root_.get ()) {
707+ throw std::runtime_error (" trying to set a scope as current that does not belong to the Graph's scope trie" );
708+ }
709+ current_scope_ = scope;
710+ }
711+ ResourceGuard set_current_scope_temporary (Scope* scope) {
712+ auto prev_scope = current_scope_;
713+ this ->set_current_scope (scope);
714+ return ResourceGuard ([prev_scope, this ]() { this ->current_scope_ = prev_scope; });
715+ }
716+ std::shared_ptr<Scope> scope_root () {
717+ return scope_root_;
718+ }
621719
622720 Node * addInput (Node* n) {
623721 JIT_ASSERT (n->kind () == kParam );
@@ -694,7 +792,8 @@ friend struct Node;
694792 }
695793 Node * createFusionGroup () {
696794 auto n = create (kFusionGroup );
697- n->g_ (kSubgraph ,std::make_shared<Graph>());
795+ auto subgraph = std::make_shared<Graph>(scope_root_);
796+ n->g_ (kSubgraph , subgraph);
698797 return n;
699798 }
700799 Node * createPythonOp (THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, pyobj_list&& scalar_args);
@@ -764,9 +863,10 @@ inline Node::Node(Graph * graph_, NodeKind kind_) :
764863 graph_(graph_),
765864 unique_(graph_->next_unique_++),
766865 stage_(graph_->new_node_stage_),
866+ scope_(graph_->current_scope_) ,
767867 type_(getInitialType(kind_)) {
768- graph_->all_nodes .emplace (this );
769- }
868+ graph_->all_nodes .emplace (this );
869+ }
770870
771871inline void Node::destroy () {
772872 JIT_ASSERT (inGraphList ());
@@ -788,6 +888,16 @@ inline Node* Node::makeMultireturn() {
788888 return select;
789889}
790890
891+ inline void Node::cloneFrom (Node * s) {
892+ if (s->hasType ()) setType (s->type ());
893+ setDebugName (s->debugName ());
894+ setSourceLocation (s->getSourceLocation ());
895+ if (s->owningGraph ()->scope_root_ == owningGraph ()->scope_root_ ) {
896+ scope_ = s->scope_ ;
897+ }
898+ copyAttributes (*s);
899+ }
900+
791901// Helper macros for constructing switch statements over Node types
792902// instead of heavy-weight visitors
793903// read 'between' these defines to see how they turn into a big switch
0 commit comments