Skip to content

Commit ff3dd72

Browse files
bwastifacebook-github-bot
authored andcommitted
Add in-place check to AliasDb
Summary: Pull Request resolved: #23210 Test Plan: Imported from OSS Differential Revision: D16444529 Pulled By: bwasti fbshipit-source-id: 83af54d423989a2a726158b521093660584ee9c2
1 parent 336c9be commit ff3dd72

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

test/cpp/jit/test_alias_analysis.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,51 @@ void testWriteTracking() {
539539
ASSERT_TRUE(aliasDb.writesToAlias(
540540
writingNode, std::unordered_set<const Value*>{aAlias}));
541541
}
542+
{
543+
auto graph = std::make_shared<Graph>();
544+
script::parseIR(
545+
R"IR(
546+
graph(%x: Tensor):
547+
%b : (Tensor) = aten::relu_(%x)
548+
return (%b)
549+
)IR",
550+
&*graph);
551+
auto node_iter = graph->block()->nodes().begin();
552+
auto relu = *node_iter;
553+
AliasDb aliasDb(graph);
554+
AT_ASSERT(aliasDb.isMutable(relu));
555+
}
556+
{
557+
auto graph = std::make_shared<Graph>();
558+
script::parseIR(
559+
R"IR(
560+
graph(%x: Tensor, %y : Tensor):
561+
%b : (Tensor) = aten::mul(%x, %y)
562+
return (%b)
563+
)IR",
564+
&*graph);
565+
auto node_iter = graph->block()->nodes().begin();
566+
auto mul = *node_iter;
567+
AliasDb aliasDb(graph);
568+
AT_ASSERT(!aliasDb.isMutable(mul));
569+
}
570+
{
571+
auto graph = std::make_shared<Graph>();
572+
std::unordered_map<std::string, Value*> vmap;
573+
script::parseIR(
574+
R"IR(
575+
graph(%x: Tensor, %y : Tensor):
576+
%c1 : int = prim::Constant[value=1]()
577+
%b : (Tensor) = aten::add_(%x, %y, %c1)
578+
return (%b)
579+
)IR",
580+
&*graph,
581+
vmap);
582+
auto add = vmap["b"]->node();
583+
AliasDb aliasDb(graph);
584+
AT_ASSERT(aliasDb.hasWriters(add));
585+
AT_ASSERT(aliasDb.isMutable(add));
586+
}
542587
}
543588

544589
void testContainerAliasing() {

torch/csrc/jit/passes/alias_analysis.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,24 @@ AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
6060
GRAPH_DEBUG(toString());
6161
}
6262

63-
bool AliasDb::hasWriters(const Node* n) const {
63+
bool AliasDb::isMutable(Node* n) const {
64+
ValueSet vs;
65+
for (const auto input : n->inputs()) {
66+
vs.insert(input);
67+
}
68+
return writesToAlias(n, vs);
69+
}
70+
71+
bool AliasDb::hasInputWriters(const Node* n) const {
6472
for (const auto input : n->inputs()) {
6573
if (hasWriters(input)) {
6674
return true;
6775
}
6876
}
77+
return false;
78+
}
79+
80+
bool AliasDb::hasOutputWriters(const Node* n) const {
6981
for (const auto output : n->outputs()) {
7082
if (hasWriters(output)) {
7183
return true;
@@ -74,6 +86,10 @@ bool AliasDb::hasWriters(const Node* n) const {
7486
return false;
7587
}
7688

89+
bool AliasDb::hasWriters(const Node* n) const {
90+
return hasInputWriters(n) || hasOutputWriters(n);
91+
}
92+
7793
bool AliasDb::hasWriters(const Value* v) const {
7894
if (v->mustBeNone()) {
7995
return false;

torch/csrc/jit/passes/alias_analysis.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,19 @@ class AliasDb {
6161
// value in group `b`? i.e. may they overlap?
6262
TORCH_API bool mayAlias(const ValueSet& a, const ValueSet& b) const;
6363

64+
// Do any nodes write to an alias set input to `n`?
65+
TORCH_API bool hasInputWriters(const Node* n) const;
66+
67+
// Do any nodes write to an alias set output by `n`?
68+
TORCH_API bool hasOutputWriters(const Node* n) const;
69+
6470
// Do any nodes write to an alias set inputed/outputed by `n`?
6571
TORCH_API bool hasWriters(const Node* n) const;
6672

73+
// Is the operation in-place? i.e. doesn't write anywhere but locations it
74+
// reads from.
75+
TORCH_API bool isMutable(Node* n) const;
76+
6777
// Move 'n' (already in the graph) after 'movePoint' in the topological order.
6878
//
6979
// Tries to preserve value dependencies, so other nodes might be moved. We

0 commit comments

Comments
 (0)