Skip to content

Commit 4be1efb

Browse files
committed
Simplified bfs
1 parent b27345c commit 4be1efb

File tree

3 files changed

+49
-61
lines changed

3 files changed

+49
-61
lines changed

c10/util/sparse_bitset.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,15 @@
101101
size_type count() const {
102102
unsigned NumBits = 0;
103103
for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
104-
NumBits += __builtin_popcountll(Bits[i]);
104+
NumBits += __builtin_popcountl(Bits[i]);
105105
return NumBits;
106106
}
107107

108108
/// find_first - Returns the index of the first set bit.
109109
int find_first() const {
110110
for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
111111
if (Bits[i] != 0)
112-
return i * BITWORD_SIZE + __builtin_ctzll(Bits[i]);
112+
return i * BITWORD_SIZE + __builtin_ctzl(Bits[i]);
113113
throw std::runtime_error("Illegal empty element");
114114
}
115115

@@ -119,7 +119,7 @@
119119
unsigned Idx = BITWORDS_PER_ELEMENT - I - 1;
120120
if (Bits[Idx] != 0)
121121
return Idx * BITWORD_SIZE + BITWORD_SIZE -
122-
__builtin_clzll(Bits[Idx]) - 1;
122+
__builtin_clzl(Bits[Idx]) - 1;
123123
}
124124
throw std::runtime_error("Illegal empty element");
125125
}
@@ -140,12 +140,12 @@
140140
Copy &= ~0UL << BitPos;
141141

142142
if (Copy != 0)
143-
return WordPos * BITWORD_SIZE + __builtin_ctzll(Copy);
143+
return WordPos * BITWORD_SIZE + __builtin_ctzl(Copy);
144144

145145
// Check subsequent words.
146146
for (unsigned i = WordPos+1; i < BITWORDS_PER_ELEMENT; ++i)
147147
if (Bits[i] != 0)
148-
return i * BITWORD_SIZE + __builtin_ctzll(Bits[i]);
148+
return i * BITWORD_SIZE + __builtin_ctzl(Bits[i]);
149149
return -1;
150150
}
151151

torch/csrc/jit/passes/utils/memory_dag.cpp

Lines changed: 30 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#include "memory_dag.h"
22

3+
#include <c10/util/flat_hash_map.h>
34
#include <torch/csrc/utils/memory.h>
45
#include <algorithm>
56
#include <queue>
6-
#include <iostream>
77

88
namespace torch {
99
namespace jit {
@@ -18,9 +18,10 @@ int getCompressed(const Element* x) {
1818
decomprMap[comprMap.size()] = x;
1919
return comprMap[x];
2020
}
21-
const Element * getDecompressed(int x) {
22-
assert(decomprMap.count(x));
23-
return decomprMap[x];
21+
const Element* getDecompressed(int x) {
22+
auto res = decomprMap[x];
23+
AT_ASSERT(res);
24+
return res;
2425
}
2526
bool MemoryDAG::mayAlias(Element* a, Element* b) const {
2627
return mayAliasImpl(a, b);
@@ -31,8 +32,8 @@ bool MemoryDAG::mayAlias(const Element* a, const Element* b) const {
3132
}
3233

3334
bool MemoryDAG::memoryLocationOverlap(
34-
const MemoryLocations & aMemLoc,
35-
const MemoryLocations & bMemLoc) const {
35+
const MemoryLocations& aMemLoc,
36+
const MemoryLocations& bMemLoc) const {
3637
return aMemLoc.intersects(bMemLoc);
3738
}
3839

@@ -53,15 +54,13 @@ bool MemoryDAG::mayContainAlias(Element* a, Element* b) const {
5354

5455
void collectAllContainedMemoryLocations(
5556
const Element* elem,
56-
MemoryLocations & cont) {
57+
MemoryLocations& cont) {
5758
// we have already recursed on this element
5859
int compIdx = getCompressed(elem);
59-
if (cont.test(compIdx)) {
60+
if (cont.test_and_set(compIdx)) {
6061
return;
6162
}
6263

63-
cont.set(compIdx);
64-
6564
for (const auto& mem_loc : elem->getMemoryLocations()) {
6665
collectAllContainedMemoryLocations(getDecompressed(mem_loc), cont);
6766
}
@@ -72,8 +71,8 @@ void collectAllContainedMemoryLocations(
7271
}
7372

7473
bool MemoryDAG::mayContainAliasImpl(const Element* a, const Element* b) const {
75-
MemoryLocations all_a_mlocs;
76-
MemoryLocations all_b_mlocs;
74+
MemoryLocations all_a_mlocs;
75+
MemoryLocations all_b_mlocs;
7776

7877
collectAllContainedMemoryLocations(a, all_a_mlocs);
7978
collectAllContainedMemoryLocations(b, all_b_mlocs);
@@ -88,12 +87,12 @@ bool MemoryDAG::mayContainAlias(
8887
return false;
8988
}
9089

91-
MemoryLocations all_a_mlocs;
90+
MemoryLocations all_a_mlocs;
9291
for (const auto& elem : a) {
9392
collectAllContainedMemoryLocations(elem, all_a_mlocs);
9493
}
9594

96-
MemoryLocations all_b_mlocs;
95+
MemoryLocations all_b_mlocs;
9796
for (const auto& elem : b) {
9897
collectAllContainedMemoryLocations(elem, all_b_mlocs);
9998
}
@@ -123,7 +122,7 @@ Element* MemoryDAG::makeFreshValue(const Value* v) {
123122

124123
TORCH_API std::unordered_set<const Element*> convert(MemoryLocations bits) {
125124
std::unordered_set<const Element*> res;
126-
for (auto i: bits) {
125+
for (auto i : bits) {
127126
res.insert(getDecompressed(i));
128127
}
129128
return res;
@@ -134,51 +133,45 @@ const MemoryLocations& Element::getMemoryLocations() const {
134133
}
135134

136135
// Do a BFS in the `points-to` direction, collecting all memory locations
137-
MemoryLocations ret;
138-
this->bfs(
139-
[&](const Element* el) {
140-
if (el->pointsTo.empty()) {
141-
ret.set(getCompressed(el));
142-
}
143-
},
144-
BfsDirection::POINTS_TO);
136+
MemoryLocations ret;
137+
this->bfs(BfsDirection::POINTS_TO, ret);
145138
cachedMemoryLocations_ = ret;
146-
return ret;
139+
return cachedMemoryLocations_;
147140
}
148141

149142
// Do a breadth-first search over the graph, starting at `this` and
150143
// traversing in the direction `dir`. Visited elements whose `pointsTo` is empty are set in `res`.
151-
template <typename Fn>
152-
bool Element::bfs(Fn fn, BfsDirection dir) const {
153-
std::queue<const Element*> queue;
154-
MemoryLocations seen;
155-
queue.push(this);
144+
void Element::bfs(BfsDirection dir, MemoryLocations& res) const {
145+
std::queue<int> queue;
146+
MemoryLocations seen;
147+
queue.push(getCompressed(this));
156148
while (!queue.empty()) {
157149
const auto el = queue.front();
158150
queue.pop();
159-
seen.set(getCompressed(el));
160-
161-
fn(el);
151+
seen.set(el);
152+
auto decompEl = getDecompressed(el);
153+
if (decompEl->pointsTo.empty()) {
154+
res.set(el);
155+
}
162156

163157
switch (dir) {
164158
case BfsDirection::POINTS_TO: {
165-
for (auto ptr : el->pointsTo) {
159+
for (auto ptr : decompEl->pointsTo) {
166160
if (!seen.test(ptr)) {
167-
queue.push(getDecompressed(ptr));
161+
queue.push(ptr);
168162
}
169163
}
170164
} break;
171165

172166
case BfsDirection::POINTED_FROM: {
173-
for (auto ptr : el->pointedFrom) {
167+
for (auto ptr : decompEl->pointedFrom) {
174168
if (!seen.test(ptr)) {
175-
queue.push(getDecompressed(ptr));
169+
queue.push(ptr);
176170
}
177171
}
178172
} break;
179173
}
180174
}
181-
return false;
182175
}
183176
} // namespace jit
184177
} // namespace torch

torch/csrc/jit/passes/utils/memory_dag.h

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
#pragma once
22

33
#include <c10/util/ArrayRef.h>
4+
#include <c10/util/sparse_bitset.h>
45
#include <memory>
56
#include <unordered_map>
67
#include <unordered_set>
78
#include <vector>
8-
#include <c10/util/flat_hash_map.h>
9-
#include <c10/util/sparse_bitset.h>
10-
// #include <c10/util/dense_bitset.h>
119

1210
#include <torch/csrc/WindowsTorchApiMacro.h>
1311

1412
typedef llvm::SparseBitVector<128> MemoryLocations;
15-
// typedef llvm::BitVector MemoryLocations;
1613
namespace torch {
1714
namespace jit {
1815

@@ -36,12 +33,12 @@ struct Value;
3633
// which memory locations an element may point to.
3734
class TORCH_API MemoryDAG {
3835
public:
39-
40-
// explicitly delete copy constructor because otherwise windows build is confused for an exported class
41-
// see https://stackoverflow.com/a/51033485/105137
36+
// Explicitly delete copy constructor because otherwise Windows build is
37+
// confused for an exported class. See
38+
// https://stackoverflow.com/a/51033485/105137
4239
MemoryDAG() {}
43-
MemoryDAG(const MemoryDAG&)=delete;
44-
MemoryDAG& operator=(const MemoryDAG&)=delete;
40+
MemoryDAG(const MemoryDAG&) = delete;
41+
MemoryDAG& operator=(const MemoryDAG&) = delete;
4542

4643
// Make `from` point at `to`.
4744
void makePointerTo(Element* from, Element* to);
@@ -75,7 +72,7 @@ class TORCH_API MemoryDAG {
7572
}
7673

7774
// Record all memory locations from group `a`
78-
MemoryLocations memoryLocations;
75+
MemoryLocations memoryLocations;
7976
for (auto it = a.cbegin(); it != a.cend();) {
8077
const auto element = *it;
8178

@@ -99,9 +96,8 @@ class TORCH_API MemoryDAG {
9996
}
10097

10198
private:
102-
bool memoryLocationOverlap(
103-
const MemoryLocations & a,
104-
const MemoryLocations & b) const;
99+
bool memoryLocationOverlap(const MemoryLocations& a, const MemoryLocations& b)
100+
const;
105101
bool mayAliasImpl(const Element* a, const Element* b) const;
106102
bool mayContainAliasImpl(const Element* contained, const Element* container)
107103
const;
@@ -126,23 +122,22 @@ struct Element {
126122

127123
// All elements that this element *may* point to. It's possible to have
128124
// multiple elements that you might point to due to control flow/complex ops
129-
MemoryLocations pointsTo;
125+
MemoryLocations pointsTo;
130126
// Backreference for points-to.
131-
MemoryLocations pointedFrom;
127+
MemoryLocations pointedFrom;
132128

133-
MemoryLocations contained_elements;
129+
MemoryLocations contained_elements;
134130

135131
// Return the unique memory locations that `Element` might represent.
136132
TORCH_API const MemoryLocations& getMemoryLocations() const;
137133
// We do path compression to make repeated memory location queries faster.
138134
// An empty cache means it is invalidated (it can never be empty otherwise,
139135
// since every element must point to at least one memory location).
140-
mutable MemoryLocations cachedMemoryLocations_;
136+
mutable MemoryLocations cachedMemoryLocations_;
141137

142138
// Do a breadth-first search over the graph, starting at `this` and
143139
// traversing in the direction `dir`. Visited elements whose `pointsTo` is empty are set in `res`.
144-
template <typename Fn>
145-
bool bfs(Fn fn, BfsDirection dir) const;
140+
void bfs(BfsDirection dir, MemoryLocations& res) const;
146141
};
147142
} // namespace jit
148143
} // namespace torch

0 commit comments

Comments
 (0)