Skip to content

Commit 99674d1

Browse files
committed
Responded to more suggestions
1 parent 9bb9a7e commit 99674d1

File tree

3 files changed

+16
-21
lines changed

3 files changed

+16
-21
lines changed

c10/util/sparse_bitset.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,12 @@
154154
Copy &= ~0UL << BitPos;
155155

156156
if (Copy != 0)
157-
return WordPos * BITWORD_SIZE + __builtin_ctzl(Copy);
157+
return WordPos * BITWORD_SIZE + llvm::countTrailingZeros(Copy);
158158

159159
// Check subsequent words.
160160
for (unsigned i = WordPos+1; i < BITWORDS_PER_ELEMENT; ++i)
161161
if (Bits[i] != 0)
162-
return i * BITWORD_SIZE + __builtin_ctzl(Bits[i]);
162+
return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
163163
return -1;
164164
}
165165

torch/csrc/jit/passes/utils/memory_dag.cpp

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,11 @@
88
namespace torch {
99
namespace jit {
1010
namespace {
11-
ska::flat_hash_map<const Element*, unsigned> comprMap;
1211
std::vector<const Element*> decomprMap;
1312
} // namespace
14-
15-
unsigned Element::toIndex(const Element* x) {
16-
if (comprMap.count(x)) {
17-
return comprMap[x];
18-
}
19-
comprMap[x] = comprMap.size();
20-
decomprMap.push_back(x);
21-
return comprMap[x];
13+
unsigned Element::indexCount = 0;
14+
Element::Element(const Value* value_) : value(value_), index(indexCount++) {
15+
decomprMap.push_back(this);
2216
}
2317

2418
const Element* Element::toElement(unsigned x) {
@@ -54,7 +48,7 @@ void collectAllContainedMemoryLocations(
5448
const Element* elem,
5549
MemoryLocations& cont) {
5650
// we have already recursed on this element
57-
unsigned compIdx = Element::toIndex(elem);
51+
unsigned compIdx = elem->index;
5852
if (cont.test(compIdx)) {
5953
return;
6054
}
@@ -101,18 +95,17 @@ bool MemoryDAG::mayContainAlias(
10195

10296
// Make `v` point at `to`.
10397
void MemoryDAG::makePointerTo(Element* from, Element* to) {
104-
from->pointsTo.set(Element::toIndex(to));
105-
to->pointedFrom.set(Element::toIndex(from));
98+
from->pointsTo.set(to->index);
99+
to->pointedFrom.set(from->index);
106100
}
107101

108102
void MemoryDAG::addToContainedElements(Element* elem, Element* container) {
109-
container->contained_elements.set(Element::toIndex(elem));
103+
container->contained_elements.set(elem->index);
110104
}
111105

112106
// Give `v` a fresh alias (i.e. it does not point to any value)
113107
Element* MemoryDAG::makeFreshValue(const Value* v) {
114-
auto el = torch::make_unique<Element>();
115-
el->value = v;
108+
auto el = torch::make_unique<Element>(v);
116109

117110
auto rawPtr = el.get();
118111
elements_.emplace(rawPtr, std::move(el));
@@ -135,8 +128,8 @@ const MemoryLocations& Element::getMemoryLocations() const {
135128
// traversing in the direction `dir`.`fn` will be run on each element.
136129
void Element::bfs(BfsDirection dir, MemoryLocations& res) const {
137130
std::queue<unsigned> queue;
138-
std::unordered_set<int> seen;
139-
queue.push(Element::toIndex(this));
131+
ska::flat_hash_set<int> seen;
132+
queue.push(this->index);
140133
while (!queue.empty()) {
141134
const auto el = queue.front();
142135
queue.pop();

torch/csrc/jit/passes/utils/memory_dag.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ struct Element {
126126
MemoryLocations pointedFrom;
127127

128128
MemoryLocations contained_elements;
129+
static unsigned indexCount;
130+
signed index;
131+
Element(const Value* value_);
129132

130133
// Return the unique memory locations that `Element` might represent.
131134
TORCH_API const MemoryLocations& getMemoryLocations() const;
@@ -138,8 +141,7 @@ struct Element {
138141
// traversing in the direction `dir`.`fn` will be run on each element.
139142
void bfs(BfsDirection dir, MemoryLocations& res) const;
140143

141-
// Converts to and from the compressed index representation
142-
static unsigned toIndex(const Element* x);
144+
// Converts from the compressed index representation
143145
static const Element* toElement(unsigned x);
144146
};
145147
} // namespace jit

0 commit comments

Comments
 (0)