Skip to content
85 changes: 85 additions & 0 deletions test/cpp/jit/test_alias_analysis.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/irparser.h>
#include "test/cpp/jit/test_base.h"
#include "torch/csrc/jit/custom_operator.h"
Expand Down Expand Up @@ -751,6 +752,90 @@ graph():
auto tensor = vmap["11"];
ASSERT_TRUE(aliasDb.writesToAlias(conservativeOp, ValueSet{tensor}));
}
{
auto ops = torch::RegisterOperators().op(
"uses::list",
torch::RegisterOperators::options()
.catchAllKernel([](std::vector<at::Tensor> in) {
return torch::rand({2, 3});
})
.aliasAnalysis(AliasAnalysisKind::PURE));
// Write to the inside of a list. Check that we can't reorder a
// print across it.
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
script::parseIR(
R"IR(
graph():
%35 : int = prim::Constant[value=1]()
%10 : bool? = prim::Constant()
%8 : Device? = prim::Constant()
%4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%23 : int = prim::Constant[value=0]()
%2 : int[] = prim::ListConstruct(%0, %1)
%11 : Tensor = aten::rand(%2, %4, %4, %8, %10)
%12 : int[] = prim::ListConstruct(%0, %1)
%21 : Tensor = aten::rand(%12, %4, %4, %8, %10)
%l : Tensor[] = prim::ListConstruct(%11, %21)
%24 : Tensor = aten::select(%l, %23)
%25 : int[] = prim::ListConstruct(%0, %1)
%34 : Tensor = aten::rand(%25, %4, %4, %8, %10)
%36 : Tensor = aten::add_(%24, %34, %35)
%37 : Tensor = uses::list(%l)
return (%37)
)IR",
graph.get(),
vmap);
AliasDb aliasDb(graph);
auto listUse = vmap["37"]->node();
auto internalWrite = vmap["36"]->node();
ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(listUse, internalWrite));
}
{
// The same as above, but with a nested list
auto ops = torch::RegisterOperators().op(
"uses::list",
torch::RegisterOperators::options()
.catchAllKernel([](std::vector<at::Tensor> in) {
return torch::rand({2, 3});
})
.aliasAnalysis(AliasAnalysisKind::PURE));
// Write to the inside of a list. Check that we can't reorder a
// print across it.
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
script::parseIR(
R"IR(
graph():
%38 : int = prim::Constant[value=1]()
%10 : bool? = prim::Constant()
%8 : Device? = prim::Constant()
%4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%24 : int = prim::Constant[value=0]()
%2 : int[] = prim::ListConstruct(%0, %1)
%11 : Tensor = aten::rand(%2, %4, %4, %8, %10)
%12 : int[] = prim::ListConstruct(%0, %1)
%21 : Tensor = aten::rand(%12, %4, %4, %8, %10)
%l : Tensor[] = prim::ListConstruct(%11, %21)
%25 : Tensor = aten::select(%l, %24)
%27 : Tensor = aten::select(%25, %24, %24)
%28 : int[] = prim::ListConstruct(%0, %1)
%37 : Tensor = aten::rand(%28, %4, %4, %8, %10)
%39 : Tensor = aten::add_(%27, %37, %38)
%40 : Tensor = uses::list(%l)
return (%40)
)IR",
graph.get(),
vmap);
AliasDb aliasDb(graph);
auto listUse = vmap["40"]->node();
auto internalWrite = vmap["39"]->node();
ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(listUse, internalWrite));
}
}

void testWildcards() {
Expand Down
28 changes: 15 additions & 13 deletions torch/csrc/jit/passes/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,16 @@ void AliasDb::getReadsImpl(Node* n, MemoryLocations& ret) const {
for (const auto input : n->inputs()) {
auto it = elementMap_.find(input);
if (it != elementMap_.end()) {
ret |= it->second->getMemoryLocations();
}
}
for (const auto output : n->outputs()) {
auto it = elementMap_.find(output);
if (it != elementMap_.end()) {
ret |= it->second->getMemoryLocations();
auto el = it->second;
// Add all memory locations this element may alias.
ret |= el->getMemoryLocations();

// We also consider memory locations of contained values to be "read".
for (const auto& type : input->type()->containedTypes()) {
if (auto wildcard = getWildcard(type)) {
ret |= wildcard->getMemoryLocations();
}
}
}
}

Expand Down Expand Up @@ -606,9 +609,10 @@ void AliasDb::analyzeContainerConstruct(Node* node) {
for (auto input : node->inputs()) {
setWildcard(input);
}
for (auto output : node->outputs()) {
giveFreshAlias(output);
}

TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
auto container = node->output();
giveFreshAlias(container);
}

// BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
Expand Down Expand Up @@ -1172,10 +1176,8 @@ Element* AliasDb::getOrCreateWildcard(const TypePtr& type) {
// Search the wildcard index for an element that corresponds to the given type.
// Const version returns nullptr
Element* AliasDb::getWildcard(const TypePtr& type) const {
TORCH_INTERNAL_ASSERT(shouldAnnotate(type));
const auto kind = getMutableTypeKind(type);
TORCH_INTERNAL_ASSERT(kind);
if (!wildcardIndex_.count(*kind)) {
if (!kind || !wildcardIndex_.count(*kind)) {
return nullptr;
}
return wildcardIndex_.at(*kind);
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/passes/utils/memory_dag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,14 @@ bool MemoryDAG::mayContainAlias(
return all_a_mlocs.intersects(all_b_mlocs);
}

// Make `v` point at `to`.
void MemoryDAG::makePointerTo(Element* from, Element* to) {
from->pointsTo.set(to->index);
to->pointedFrom.set(from->index);
}

void MemoryDAG::addToContainedElements(Element* elem, Element* container) {
TORCH_INTERNAL_ASSERT(
elem != container, "Elements cannot contain themselves");
container->containedElements.set(elem->index);
}

Expand Down