Skip to content

Commit ce71f97

Browse files
Mikhail Zolotukhin authored and bertmaher committed
[TensorExpr] Fuser: try merging adjacent fusion groups.
ghstack-source-id: 6673ea6 Pull Request resolved: #43671
1 parent aedce77 commit ce71f97

File tree

4 files changed

+60
-3
lines changed

4 files changed

+60
-3
lines changed

test/cpp/tensorexpr/test_te_fuser_pass.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,5 +287,29 @@ void testFuserPass_Multidevice() {
287287
testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
288288
}
289289
}
290+
291+
void testFuserPass_MergeGroups() {
292+
WithCPUFuser cf;
293+
KernelScope kernel_scope;
294+
const auto graph_string = R"IR(
295+
graph(%a : Float(128:1, device=cpu),
296+
%b : Float(128:1, device=cpu)):
297+
%x : Float(128:1, device=cpu) = aten::mul(%a, %a)
298+
%y : Float(128:1, device=cpu) = aten::mul(%b, %b)
299+
return (%x, %y))IR";
300+
auto g = std::make_shared<Graph>();
301+
torch::jit::parseIR(graph_string, g.get());
302+
303+
g->lint();
304+
FuseTensorExprs(g, /* min_group_size= */ 1);
305+
306+
// The %x and %y computations are completely independent and yet we should put
307+
// them into a single fusion group rather than having two separate ones.
308+
testing::FileCheck()
309+
.check("= prim::TensorExprGroup_")
310+
->check_not("= prim::TensorExprGroup_")
311+
->run(*g);
312+
}
313+
290314
} // namespace jit
291315
} // namespace torch

test/cpp/tensorexpr/tests.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ namespace jit {
264264
_(FuserPass_UnfusibleDevice) \
265265
_(FuserPass_UnknownShapes) \
266266
_(FuserPass_Multidevice) \
267+
_(FuserPass_MergeGroups) \
267268
_(TrainBasic)
268269

269270
#define TH_FORALL_TENSOREXPR_TESTS_LLVM(_) \

test/test_jit_fuser_te.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,8 @@ def fn(x, y):
774774
ge(*inputs_cuda0)
775775
ge(*inputs_cuda1)
776776

777+
# TODO: we're currently not checking 'device' in the type info when pulling
778+
# nodes into a fusion group. We should fix that and re-enable this test.
777779
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
778780
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
779781
def test_kernel_cache_multi_gpu(self):

torch/csrc/jit/passes/tensorexpr_fuser.cpp

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ class TensorExprFuser {
428428
void createFusionGroups(Block* block) {
429429
std::vector<Node*> fusion_groups;
430430
auto reverse_iter = block->nodes().reverse();
431+
Node* prev_fusion_group = nullptr;
431432
for (auto it = reverse_iter.begin(); it != reverse_iter.end();) {
432433
Node* n = *it;
433434
GRAPH_DEBUG("Considering node:", *n)
@@ -450,11 +451,40 @@ class TensorExprFuser {
450451
}
451452

452453
Node* fusion_group = createFusionGroup(n);
453-
fusion_groups.push_back(fusion_group);
454-
it = fusion_group->reverseIterator();
454+
debugDumpFusionGroup("Fusion group constructed: ", fusion_group);
455+
456+
// Try merging the just created fusion group into the previous one.
457+
// If it did not work, then put the previous fusion group into
458+
// fusion_groups vector - we will not touch it anymore in this loop.
459+
// If merging succeeded, save the merged group as the "previous" fusion
460+
// group so that we can try to merge the next one into it.
461+
if (prev_fusion_group) {
462+
debugDumpFusionGroup(
463+
"Trying to merge into the previous fusion group: ",
464+
prev_fusion_group);
465+
if (canMerge(prev_fusion_group, fusion_group)) {
466+
prev_fusion_group = tryMerge(prev_fusion_group, fusion_group);
467+
debugDumpFusionGroup(
468+
"Successfully merged into the previous fusion group: ",
469+
prev_fusion_group);
470+
} else {
471+
GRAPH_DEBUG("Cannot merge into the previous fusion group");
472+
fusion_groups.push_back(prev_fusion_group);
473+
prev_fusion_group = fusion_group;
474+
}
475+
} else {
476+
prev_fusion_group = fusion_group;
477+
}
478+
it = prev_fusion_group->reverseIterator();
455479
it++;
456480
}
457481

482+
// We were adding groups into the vector lagging by one - catch up with
483+
// adding the last one
484+
if (prev_fusion_group) {
485+
fusion_groups.push_back(prev_fusion_group);
486+
}
487+
458488
for (Node* n : fusion_groups) {
459489
inlineIfTooSmall(n);
460490
}
@@ -617,7 +647,7 @@ class TensorExprFuser {
617647
REQ(consumer->owningBlock() == producer->owningBlock());
618648

619649
// Symbolic checks
620-
REQ(canHandle(producer));
650+
REQ(canHandle(producer) || producer->kind() == prim::TensorExprGroup);
621651
TORCH_INTERNAL_ASSERT(
622652
consumer->kind() == prim::TensorExprGroup || canHandle(consumer));
623653

0 commit comments

Comments
 (0)