@@ -428,6 +428,7 @@ class TensorExprFuser {
428428 void createFusionGroups (Block* block) {
429429 std::vector<Node*> fusion_groups;
430430 auto reverse_iter = block->nodes ().reverse ();
431+ Node* prev_fusion_group = nullptr ;
431432 for (auto it = reverse_iter.begin (); it != reverse_iter.end ();) {
432433 Node* n = *it;
433434 GRAPH_DEBUG (" Considering node:" , *n)
@@ -450,11 +451,40 @@ class TensorExprFuser {
450451 }
451452
452453 Node* fusion_group = createFusionGroup (n);
453- fusion_groups.push_back (fusion_group);
454- it = fusion_group->reverseIterator ();
454+ debugDumpFusionGroup (" Fusion group constructed: " , fusion_group);
455+
456+ // Try merging the just created fusion group into the previous one.
457+ // If it did not work, then put the previous fusion group into
458+ // fusion_groups vector - we will not touch it anymore in this loop.
459+ // If merging suceeded, save the merged group as the "previous" fusion
460+ // group so that we can try to merge the next one into it.
461+ if (prev_fusion_group) {
462+ debugDumpFusionGroup (
463+ " Trying to merge into the previous fusion group: " ,
464+ prev_fusion_group);
465+ if (canMerge (prev_fusion_group, fusion_group)) {
466+ prev_fusion_group = tryMerge (prev_fusion_group, fusion_group);
467+ debugDumpFusionGroup (
468+ " Successfully merged into the previous fusion group: " ,
469+ prev_fusion_group);
470+ } else {
471+ GRAPH_DEBUG (" Cannot merge into the previous fusion group" );
472+ fusion_groups.push_back (prev_fusion_group);
473+ prev_fusion_group = fusion_group;
474+ }
475+ } else {
476+ prev_fusion_group = fusion_group;
477+ }
478+ it = prev_fusion_group->reverseIterator ();
455479 it++;
456480 }
457481
482+ // We were adding groups into the vector lagging by one - catch up with
483+ // adding the last one
484+ if (prev_fusion_group) {
485+ fusion_groups.push_back (prev_fusion_group);
486+ }
487+
458488 for (Node* n : fusion_groups) {
459489 inlineIfTooSmall (n);
460490 }
@@ -617,7 +647,7 @@ class TensorExprFuser {
617647 REQ (consumer->owningBlock () == producer->owningBlock ());
618648
619649 // Symbolic checks
620- REQ (canHandle (producer));
650+ REQ (canHandle (producer) || producer-> kind () == prim::TensorExprGroup );
621651 TORCH_INTERNAL_ASSERT (
622652 consumer->kind () == prim::TensorExprGroup || canHandle (consumer));
623653
0 commit comments