@@ -1050,5 +1050,77 @@ void EliminateExtraPermuteOps(std::shared_ptr<Graph>& graph) {
10501050 fuse.runOnGraph (graph, dims_are_valid_constants);
10511051}
10521052
1053+ namespace {
1054+
1055+ Node* maybeUserWithKind (Value* value, c10::Symbol kind) {
1056+ auto & uses = value->uses ();
1057+ if (uses.size () != 1 ) {
1058+ return nullptr ;
1059+ }
1060+ auto * user = uses[0 ].user ;
1061+ if (user->kind () != kind) {
1062+ return nullptr ;
1063+ }
1064+ return user;
1065+ }
1066+
1067+ } // namespace
1068+
1069+ void UseSplitAndSqueeze (std::shared_ptr<Graph>& graph) {
1070+ std::vector<Node*> to_erase;
1071+ for (auto * node : graph->nodes ()) {
1072+ if (node->kind () != aten::split) {
1073+ continue ;
1074+ }
1075+ auto axis_opt = toIValue (node->input (2 ));
1076+ if (!axis_opt) {
1077+ continue ;
1078+ }
1079+ auto axis = *axis_opt;
1080+ auto * split_node_output = node->output ();
1081+ auto * list_unpack_node =
1082+ maybeUserWithKind (split_node_output, prim::ListUnpack);
1083+ if (list_unpack_node == nullptr ) {
1084+ continue ;
1085+ }
1086+ std::vector<Node*> squeeze_nodes;
1087+ squeeze_nodes.reserve (list_unpack_node->outputs ().size ());
1088+ for (auto * output : list_unpack_node->outputs ()) {
1089+ auto * squeeze_node = maybeUserWithKind (output, aten::squeeze);
1090+ if (squeeze_node == nullptr ) {
1091+ break ;
1092+ }
1093+ auto dim_opt = toIValue (squeeze_node->input (1 ));
1094+ if (!dim_opt || *dim_opt != axis) {
1095+ break ;
1096+ }
1097+ squeeze_nodes.push_back (squeeze_node);
1098+ }
1099+ auto num_outputs = list_unpack_node->outputs ().size ();
1100+ if (squeeze_nodes.size () != num_outputs) {
1101+ continue ;
1102+ }
1103+ auto * split_and_squeeze_node = graph->create (
1104+ c10::Symbol::fromQualString (" static_runtime::fused_split_and_squeeze" ),
1105+ num_outputs);
1106+ split_and_squeeze_node->addInput (node->input (0 ));
1107+ split_and_squeeze_node->addInput (node->input (1 ));
1108+ split_and_squeeze_node->addInput (node->input (2 ));
1109+ split_and_squeeze_node->insertBefore (node);
1110+ for (const auto i : c10::irange (num_outputs)) {
1111+ auto * squeeze_node = squeeze_nodes[i];
1112+ split_and_squeeze_node->output (i)->copyMetadata (squeeze_node->output ());
1113+ squeeze_node->output ()->replaceAllUsesWith (
1114+ split_and_squeeze_node->output (i));
1115+ }
1116+ to_erase.insert (to_erase.end (), squeeze_nodes.begin (), squeeze_nodes.end ());
1117+ to_erase.push_back (list_unpack_node);
1118+ to_erase.push_back (node);
1119+ }
1120+ for (auto * node : to_erase) {
1121+ node->destroy ();
1122+ }
1123+ }
1124+
10531125} // namespace jit
10541126} // namespace torch
0 commit comments