11#include " memory_dag.h"
22
3+ #include < c10/util/flat_hash_map.h>
34#include < torch/csrc/utils/memory.h>
45#include < algorithm>
56#include < queue>
6- #include < iostream>
77
88namespace torch {
99namespace jit {
@@ -18,9 +18,10 @@ int getCompressed(const Element* x) {
1818 decomprMap[comprMap.size ()] = x;
1919 return comprMap[x];
2020}
21- const Element * getDecompressed (int x) {
22- assert (decomprMap.count (x));
23- return decomprMap[x];
21+ const Element* getDecompressed (int x) {
22+ auto res = decomprMap[x];
23+ AT_ASSERT (res);
24+ return res;
2425}
2526bool MemoryDAG::mayAlias (Element* a, Element* b) const {
2627 return mayAliasImpl (a, b);
@@ -31,8 +32,8 @@ bool MemoryDAG::mayAlias(const Element* a, const Element* b) const {
3132}
3233
3334bool MemoryDAG::memoryLocationOverlap (
34- const MemoryLocations & aMemLoc,
35- const MemoryLocations & bMemLoc) const {
35+ const MemoryLocations& aMemLoc,
36+ const MemoryLocations& bMemLoc) const {
3637 return aMemLoc.intersects (bMemLoc);
3738}
3839
@@ -53,15 +54,13 @@ bool MemoryDAG::mayContainAlias(Element* a, Element* b) const {
5354
5455void collectAllContainedMemoryLocations (
5556 const Element* elem,
56- MemoryLocations & cont) {
57+ MemoryLocations& cont) {
5758 // we have already recursed on this element
5859 int compIdx = getCompressed (elem);
59- if (cont.test (compIdx)) {
60+ if (cont.test_and_set (compIdx)) {
6061 return ;
6162 }
6263
63- cont.set (compIdx);
64-
6564 for (const auto & mem_loc : elem->getMemoryLocations ()) {
6665 collectAllContainedMemoryLocations (getDecompressed (mem_loc), cont);
6766 }
@@ -72,8 +71,8 @@ void collectAllContainedMemoryLocations(
7271}
7372
7473bool MemoryDAG::mayContainAliasImpl (const Element* a, const Element* b) const {
75- MemoryLocations all_a_mlocs;
76- MemoryLocations all_b_mlocs;
74+ MemoryLocations all_a_mlocs;
75+ MemoryLocations all_b_mlocs;
7776
7877 collectAllContainedMemoryLocations (a, all_a_mlocs);
7978 collectAllContainedMemoryLocations (b, all_b_mlocs);
@@ -88,12 +87,12 @@ bool MemoryDAG::mayContainAlias(
8887 return false ;
8988 }
9089
91- MemoryLocations all_a_mlocs;
90+ MemoryLocations all_a_mlocs;
9291 for (const auto & elem : a) {
9392 collectAllContainedMemoryLocations (elem, all_a_mlocs);
9493 }
9594
96- MemoryLocations all_b_mlocs;
95+ MemoryLocations all_b_mlocs;
9796 for (const auto & elem : b) {
9897 collectAllContainedMemoryLocations (elem, all_b_mlocs);
9998 }
@@ -123,7 +122,7 @@ Element* MemoryDAG::makeFreshValue(const Value* v) {
123122
124123TORCH_API std::unordered_set<const Element*> convert (MemoryLocations bits) {
125124 std::unordered_set<const Element*> res;
126- for (auto i: bits) {
125+ for (auto i : bits) {
127126 res.insert (getDecompressed (i));
128127 }
129128 return res;
@@ -134,51 +133,45 @@ const MemoryLocations& Element::getMemoryLocations() const {
134133 }
135134
136135 // Do a BFS in the `points-to` direction, collecting all memory locations
137- MemoryLocations ret;
138- this ->bfs (
139- [&](const Element* el) {
140- if (el->pointsTo .empty ()) {
141- ret.set (getCompressed (el));
142- }
143- },
144- BfsDirection::POINTS_TO);
136+ MemoryLocations ret;
137+ this ->bfs (BfsDirection::POINTS_TO, ret);
145138 cachedMemoryLocations_ = ret;
146- return ret ;
139+ return cachedMemoryLocations_ ;
147140}
148141
149142// Do a breadth-first search over the graph, starting at `this` and
150143// traversing in the direction `dir`.`fn` will be run on each element.
151- template <typename Fn>
152- bool Element::bfs (Fn fn, BfsDirection dir) const {
153- std::queue<const Element*> queue;
154- MemoryLocations seen;
155- queue.push (this );
144+ void Element::bfs (BfsDirection dir, MemoryLocations& res) const {
145+ std::queue<int > queue;
146+ MemoryLocations seen;
147+ queue.push (getCompressed (this ));
156148 while (!queue.empty ()) {
157149 const auto el = queue.front ();
158150 queue.pop ();
159- seen.set (getCompressed (el));
160-
161- fn (el);
151+ seen.set (el);
152+ auto decompEl = getDecompressed (el);
153+ if (decompEl->pointsTo .empty ()) {
154+ res.set (el);
155+ }
162156
163157 switch (dir) {
164158 case BfsDirection::POINTS_TO: {
165- for (auto ptr : el ->pointsTo ) {
159+ for (auto ptr : decompEl ->pointsTo ) {
166160 if (!seen.test (ptr)) {
167- queue.push (getDecompressed ( ptr) );
161+ queue.push (ptr);
168162 }
169163 }
170164 } break ;
171165
172166 case BfsDirection::POINTED_FROM: {
173- for (auto ptr : el ->pointedFrom ) {
167+ for (auto ptr : decompEl ->pointedFrom ) {
174168 if (!seen.test (ptr)) {
175- queue.push (getDecompressed ( ptr) );
169+ queue.push (ptr);
176170 }
177171 }
178172 } break ;
179173 }
180174 }
181- return false ;
182175}
183176} // namespace jit
184177} // namespace torch
0 commit comments