Skip to content

Commit 65f37b1

Browse files
author
mikeiovine
committed
[SR] Force split_and_squeeze usage via graph transformation
Pull Request resolved: #74274 Differential Revision: [D34913889](https://our.internmc.facebook.com/intern/diff/D34913889/) **NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34913889/)! ghstack-source-id: 152119802
1 parent a1e284d commit 65f37b1

File tree

4 files changed

+95
-0
lines changed

4 files changed

+95
-0
lines changed

benchmarks/static_runtime/test_static_module.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,3 +1588,23 @@ TEST(EliminateExtraPermuteOps, DoesNotFuseNonConstantDim) {
15881588

15891589
EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
15901590
}
1591+
1592+
TEST(UseSplitAndSqueeze, Fusion) {
1593+
const auto src = R"IR(
1594+
graph(%x: Tensor):
1595+
%dim: int = prim::Constant[value=1]()
1596+
%split_size: int = prim::Constant[value=1]()
1597+
%split: Tensor[] = aten::split(%x, %split_size, %dim)
1598+
%a: Tensor, %b: Tensor = prim::ListUnpack(%split)
1599+
%c: Tensor = aten::squeeze(%a, %dim)
1600+
%d: Tensor = aten::squeeze(%b, %dim)
1601+
return (%c, %d)
1602+
)IR";
1603+
auto graph = getGraphFromIR(src);
1604+
UseSplitAndSqueeze(graph);
1605+
EXPECT_TRUE(
1606+
hasNodeWithKind(graph, "static_runtime::fused_split_and_squeeze"));
1607+
EXPECT_FALSE(hasNodeWithKind(graph, "aten::split"));
1608+
EXPECT_FALSE(hasNodeWithKind(graph, "aten::squeeze"));
1609+
EXPECT_FALSE(hasNodeWithKind(graph, "prim::ListUnpack"));
1610+
}

torch/csrc/jit/runtime/static/impl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ void OptimizeGraph(
175175
EliminateNoOps(
176176
graph, /* custom_ops */ {fromQualString("fb::scale_gradient")});
177177
AddIfThenElseOp(graph);
178+
UseSplitAndSqueeze(graph);
178179
GRAPH_DUMP("Final graph after optimizations: ", graph);
179180
}
180181

torch/csrc/jit/runtime/static/passes.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,5 +1050,77 @@ void EliminateExtraPermuteOps(std::shared_ptr<Graph>& graph) {
10501050
fuse.runOnGraph(graph, dims_are_valid_constants);
10511051
}
10521052

1053+
namespace {
1054+
1055+
Node* maybeUserWithKind(Value* value, c10::Symbol kind) {
1056+
auto& uses = value->uses();
1057+
if (uses.size() != 1) {
1058+
return nullptr;
1059+
}
1060+
auto* user = uses[0].user;
1061+
if (user->kind() != kind) {
1062+
return nullptr;
1063+
}
1064+
return user;
1065+
}
1066+
1067+
} // namespace
1068+
1069+
void UseSplitAndSqueeze(std::shared_ptr<Graph>& graph) {
1070+
std::vector<Node*> to_erase;
1071+
for (auto* node : graph->nodes()) {
1072+
if (node->kind() != aten::split) {
1073+
continue;
1074+
}
1075+
auto axis_opt = toIValue(node->input(2));
1076+
if (!axis_opt) {
1077+
continue;
1078+
}
1079+
auto axis = *axis_opt;
1080+
auto* split_node_output = node->output();
1081+
auto* list_unpack_node =
1082+
maybeUserWithKind(split_node_output, prim::ListUnpack);
1083+
if (list_unpack_node == nullptr) {
1084+
continue;
1085+
}
1086+
std::vector<Node*> squeeze_nodes;
1087+
squeeze_nodes.reserve(list_unpack_node->outputs().size());
1088+
for (auto* output : list_unpack_node->outputs()) {
1089+
auto* squeeze_node = maybeUserWithKind(output, aten::squeeze);
1090+
if (squeeze_node == nullptr) {
1091+
break;
1092+
}
1093+
auto dim_opt = toIValue(squeeze_node->input(1));
1094+
if (!dim_opt || *dim_opt != axis) {
1095+
break;
1096+
}
1097+
squeeze_nodes.push_back(squeeze_node);
1098+
}
1099+
auto num_outputs = list_unpack_node->outputs().size();
1100+
if (squeeze_nodes.size() != num_outputs) {
1101+
continue;
1102+
}
1103+
auto* split_and_squeeze_node = graph->create(
1104+
c10::Symbol::fromQualString("static_runtime::fused_split_and_squeeze"),
1105+
num_outputs);
1106+
split_and_squeeze_node->addInput(node->input(0));
1107+
split_and_squeeze_node->addInput(node->input(1));
1108+
split_and_squeeze_node->addInput(node->input(2));
1109+
split_and_squeeze_node->insertBefore(node);
1110+
for (const auto i : c10::irange(num_outputs)) {
1111+
auto* squeeze_node = squeeze_nodes[i];
1112+
split_and_squeeze_node->output(i)->copyMetadata(squeeze_node->output());
1113+
squeeze_node->output()->replaceAllUsesWith(
1114+
split_and_squeeze_node->output(i));
1115+
}
1116+
to_erase.insert(to_erase.end(), squeeze_nodes.begin(), squeeze_nodes.end());
1117+
to_erase.push_back(list_unpack_node);
1118+
to_erase.push_back(node);
1119+
}
1120+
for (auto* node : to_erase) {
1121+
node->destroy();
1122+
}
1123+
}
1124+
10531125
} // namespace jit
10541126
} // namespace torch

torch/csrc/jit/runtime/static/passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,7 @@ TORCH_API void UseVariadicGroupedAccessor(const std::shared_ptr<Graph>& graph);
6363

6464
TORCH_API void EliminateExtraPermuteOps(std::shared_ptr<Graph>& graph);
6565

66+
TORCH_API void UseSplitAndSqueeze(std::shared_ptr<Graph>& graph);
67+
6668
} // namespace jit
6769
} // namespace torch

0 commit comments

Comments
 (0)