Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 21 additions & 32 deletions torch/csrc/jit/passes/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,35 +92,27 @@ bool AliasDb::hasWriters(const Value* v) const {
return writeCache_.intersects(el->getMemoryLocations());
}

void AliasDb::getWritesImpl(Block* b, MemoryLocations& ret, bool recurseBlocks)
const {
for (auto node : b->nodes()) {
getWritesImpl(node, ret, recurseBlocks);
}
}

void AliasDb::getWritesImpl(Node* n, MemoryLocations& ret, bool recurseBlocks)
const {
void AliasDb::getWritesImpl(Node* n, MemoryLocations& ret) const {
if (writeIndex_.count(n)) {
const auto& writes = writeIndex_.at(n);
ret |= writes;
}

if (recurseBlocks) {
for (auto block : n->blocks()) {
getWritesImpl(block, ret, recurseBlocks);
for (auto node : block->nodes()) {
getWritesImpl(node, ret);
}
}
}
}

// Does `n` write to an alias of one of the values in `vs`?
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks)
const {
const auto writtenTo = getWrites(n, recurseBlocks);
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs) const {
const auto writtenTo = getWrites(n);
if (writtenTo.empty()) {
return false;
}

MemoryLocations locs;
for (const auto v : vs) {
auto it = elementMap_.find(v);
if (it != elementMap_.end()) {
Expand All @@ -134,14 +126,13 @@ bool AliasDb::writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks)
return false;
}

MemoryLocations AliasDb::getWrites(Node* n, bool recurseBlocks) const {
MemoryLocations AliasDb::getWrites(Node* n) const {
MemoryLocations writes;
getWritesImpl(n, writes, recurseBlocks);
getWritesImpl(n, writes);
return writes;
}

void AliasDb::getReadsImpl(Node* n, MemoryLocations& ret, bool recurseBlocks)
const {
void AliasDb::getReadsImpl(Node* n, MemoryLocations& ret) const {
for (const auto input : n->inputs()) {
auto it = elementMap_.find(input);
if (it != elementMap_.end()) {
Expand All @@ -155,18 +146,16 @@ void AliasDb::getReadsImpl(Node* n, MemoryLocations& ret, bool recurseBlocks)
}
}

if (recurseBlocks) {
for (auto block : n->blocks()) {
for (auto node : block->nodes()) {
getReadsImpl(node, ret, recurseBlocks);
}
for (auto block : n->blocks()) {
for (auto node : block->nodes()) {
getReadsImpl(node, ret);
}
}
}

MemoryLocations AliasDb::getReads(Node* n, bool recurseBlocks) const {
MemoryLocations AliasDb::getReads(Node* n) const {
MemoryLocations reads;
getReadsImpl(n, reads, recurseBlocks);
getReadsImpl(n, reads);
return reads;
}

Expand Down Expand Up @@ -828,8 +817,8 @@ class AliasDb::WorkingSet {
for (const auto user : getUsersSameBlock(mover_)) {
moverUsers_.insert(user);
}
moverWrites_ |= aliasDb_.getWrites(mover_, /*recurseBlocks=*/true);
moverReads_ |= aliasDb_.getReads(mover_, /*recurseBlocks=*/true);
moverWrites_ |= aliasDb_.getWrites(mover_);
moverReads_ |= aliasDb_.getReads(mover_);
}

// Add `n` to the working set
Expand All @@ -839,8 +828,8 @@ class AliasDb::WorkingSet {
users_.insert(user);
}

writes_ |= aliasDb_.getWrites(n, /*recurseBlocks=*/true);
reads_ |= aliasDb_.getReads(n, /*recurseBlocks=*/true);
writes_ |= aliasDb_.getWrites(n);
reads_ |= aliasDb_.getReads(n);
}

void eraseMover() {
Expand Down Expand Up @@ -878,7 +867,7 @@ class AliasDb::WorkingSet {

bool hasMutabilityDependency(Node* n) const {
// Check that `n` does not write to anything used by the working set
const auto& nWrites = aliasDb_.getWrites(n, /*recurseBlocks=*/true);
const auto& nWrites = aliasDb_.getWrites(n);
if (reads_.intersects(nWrites)) {
return true;
}
Expand All @@ -887,7 +876,7 @@ class AliasDb::WorkingSet {
}

// Check that the working set doesn't write to anything that `n` uses.
const auto& nReads = aliasDb_.getReads(n, /*recurseBlocks=*/true);
const auto& nReads = aliasDb_.getReads(n);
if (writes_.intersects(nReads)) {
return true;
}
Expand Down
17 changes: 5 additions & 12 deletions torch/csrc/jit/passes/alias_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ class AliasDb {

// Does `n` write to an alias of one of the values in `vs`?
// if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
TORCH_API bool writesToAlias(
Node* n,
const ValueSet& vs,
bool recurseBlocks = false) const;
TORCH_API bool writesToAlias(Node* n, const ValueSet& vs) const;

// Does `a` and `b` potentially share a memory location or do either
// hold in memory any element that exists in the other
Expand Down Expand Up @@ -100,21 +97,17 @@ class AliasDb {
// NOTE: this only returns values directly written to, not aliases thereof
//
// if `recurseBlocks` is true, gather writes on the nodes in `n`s sub-blocks
MemoryLocations getWrites(Node* n, bool recurseBlocks = false) const;
void getWritesImpl(Block* b, MemoryLocations& ret, bool recurseBlocks = false)
const;
void getWritesImpl(Node* n, MemoryLocations& ret, bool recurseBlocks = false)
const;
MemoryLocations getWrites(Node* n) const;
void getWritesImpl(Node* n, MemoryLocations& ret) const;
// Do any nodes write to `v`s memory location?
TORCH_API bool hasWriters(const Value* v) const;
// Register the fact that `n` writes to `v`.
void registerWrite(const Value* v, Node* n);
void registerWrite(const Element* e, Node* n);
// Get all the values that `n` reads from.
// if `recurseBlocks` is true, gather reads on the nodes in `n`s sub-blocks
MemoryLocations getReads(Node* n, bool recurseBlocks = false) const;
void getReadsImpl(Node* n, MemoryLocations& ret, bool recurseBlocks = false)
const;
MemoryLocations getReads(Node* n) const;
void getReadsImpl(Node* n, MemoryLocations& ret) const;

/**
* Wildcard methods
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/dead_code_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class DeadCodeEliminator {
}

if (aliasDb_) {
if (aliasDb_->writesToAlias(node, liveValues_, /*recurseBlocks=*/false)) {
if (aliasDb_->writesToAlias(node, liveValues_)) {
return mark(node);
}
}
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/shape_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class ShapePropagator {
}
if (resizesInput(n)) {
for (const auto input : n->inputs()) {
if (aliasDb_.writesToAlias(n, {input}, /*recurseBlocks*/ false)) {
if (aliasDb_.writesToAlias(n, {input})) {
resized_alias_set.insert(input);
}
}
Expand Down