Commit aa91a65

nickgg authored and facebook-github-bot committed
[TensorExpr] Fix propagation of loop options when splitting loops (#40035)
Summary: Fix a bug in SplitWithTail and SplitWithMask where loop_options such as Cuda block/thread bindings are overwritten by the split. This PR fixes the bug by propagating the loop options to the outer loop, which for axis bindings should be equivalent.

Pull Request resolved: #40035
Reviewed By: ZolotukhinM
Differential Revision: D22080263
Pulled By: nickgg
fbshipit-source-id: b8a9583fd90f69319fc4bb4db644e91f6ffa8e67
1 parent 9c7ca89 commit aa91a65
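
To make the bug concrete, here is a minimal sketch of the failure mode, written against the tensorexpr API exercised in the diffs below. The setup mirrors the new tests; it is an illustration, not code from the commit:

// Minimal sketch (mirrors the new tests below; not code from this commit).
KernelScope kernel_scope;
Buffer a_buf("a", kFloat, {21});
Tensor* tensor = Compute("f", {{21, "m"}}, [&](const ExprHandle& m) {
  return a_buf(m) + 1.0f;
});
LoopNest l({tensor});
auto loops = NodeFinder<For>::find(l.root_stmt());

// Bind the loop axis to a GPU block index, then split the loop.
l.setGPUBlockIndex(loops[0], LoopOptions::IDX_Y);
For *outer, *inner, *tail;
l.splitWithTail(loops[0], 4, &outer, &inner, &tail);

// Before this commit, the split rebuilt both loops with default
// loop_options, silently dropping the IDX_Y binding. After the fix,
// the outer loop inherits it:
ASSERT_TRUE(outer->loop_options().is_gpu_block_index());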

File tree:
  test/cpp/tensorexpr/test_loopnest.cpp
  test/cpp/tensorexpr/tests.h
  torch/csrc/jit/tensorexpr/loopnest.cpp

3 files changed: +59 -2

test/cpp/tensorexpr/test_loopnest.cpp
Lines changed: 53 additions & 0 deletions

@@ -265,6 +265,59 @@ void testExprSplitWithMask01() {
   ExpectAllNear(c_v, c_ref, 1e-5);
 }
 
+void testSplitWithTailWithLoopOptions() {
+  KernelScope kernel_scope;
+  const int M = 21;
+  Buffer a_buf("a", kFloat, {M});
+  Buffer b_buf("b", kFloat, {M});
+  Tensor* tensor = Compute("f", {{M, "m"}}, [&](const ExprHandle& m) {
+    return a_buf(m) + b_buf(m) + 1.0f;
+  });
+  For *outer, *inner, *tail;
+
+  LoopNest l({tensor});
+  auto loops = NodeFinder<For>::find(l.root_stmt());
+  ASSERT_GT(loops.size(), 0);
+  l.setGPUBlockIndex(loops[0], LoopOptions::IDX_Y);
+  l.splitWithTail(loops[0], 4, &outer, &inner, &tail);
+  ASSERT_NE(outer, nullptr);
+  ASSERT_NE(inner, nullptr);
+  ASSERT_NE(tail, nullptr);
+
+  // Outer loop carries loop axis bindings.
+  ASSERT_TRUE(outer->loop_options().is_gpu_block_index());
+  ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y);
+
+  // Inner loop has none.
+  ASSERT_TRUE(inner->loop_options().isDefault());
+
+  // Tail loop has none.
+  ASSERT_TRUE(tail->loop_options().isDefault());
+}
+
+void testSplitWithMaskWithLoopOptions() {
+  KernelScope kernel_scope;
+  const int M = 21;
+  Buffer a_buf("a", kFloat, {M});
+  Buffer b_buf("b", kFloat, {M});
+  Tensor* tensor = Compute("f", {{M, "m"}}, [&](const ExprHandle& m) {
+    return a_buf(m) + b_buf(m) + 1.0f;
+  });
+  For *outer, *inner;
+
+  LoopNest l({tensor});
+  auto loops = NodeFinder<For>::find(l.root_stmt());
+  l.setGPUBlockIndex(loops[0], LoopOptions::IDX_Y);
+  l.splitWithMask(loops[0], 4, &outer, &inner);
+
+  // Outer loop carries loop axis bindings.
+  ASSERT_TRUE(outer->loop_options().is_gpu_block_index());
+  ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y);
+
+  // Inner loop has none.
+  ASSERT_TRUE(inner->loop_options().isDefault());
+}
+
 void testScheduleBroadcastAddBuffer() {
   KernelScope kernel_scope;
   const int M = 4;

test/cpp/tensorexpr/tests.h
Lines changed: 2 additions & 0 deletions

@@ -41,6 +41,8 @@ namespace jit {
   _(ExprSplitWithTail) \
   _(ExprSplitWithTailNone) \
   _(ExprSplitWithMask01) \
+  _(SplitWithTailWithLoopOptions) \
+  _(SplitWithMaskWithLoopOptions) \
   _(ScheduleBroadcastAddBuffer) \
   _(ScheduleFunctionCall01) \
   _(ScheduleInlineFunc01) \
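
For context, each _(Name) entry in tests.h is consumed by an X-macro list that expands into per-test declarations and registrations, which is why adding a test needs only the single _(Name) \ line above. A generic sketch of the pattern (the macro names here are illustrative, not copied from the file):

// Illustrative X-macro pattern; the actual macro names in tests.h may differ.
#define FORALL_TESTS(_)           \
  _(SplitWithTailWithLoopOptions) \
  _(SplitWithMaskWithLoopOptions)

// Each entry expands into a declaration of the matching test function,
// e.g. void testSplitWithTailWithLoopOptions();
#define DECLARE_TEST(name) void test##name();
FORALL_TESTS(DECLARE_TEST)
#undef DECLARE_TEST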

torch/csrc/jit/tensorexpr/loopnest.cpp
Lines changed: 4 additions & 2 deletions

@@ -946,7 +946,8 @@ void LoopNest::splitWithTail(
       Substitute(Stmt::clone(f->body()), {{f->var(), combined_index1}});
 
   *inner = new For(i_inner, new IntImm(0), factor_expr, body_inner);
-  *outer = new For(i_outer, new IntImm(0), split_count, *inner);
+  *outer =
+      new For(i_outer, new IntImm(0), split_count, *inner, f->loop_options());
 
   // TODO: cleanup API for adding/removing statements
   p->replace_stmt(f, *outer);

@@ -1020,7 +1021,8 @@ void LoopNest::splitWithMask(For* f, int factor, For** outer, For** inner) {
   body_inner = Substitute(body_inner, {{f->var(), combined_index}});
 
   *inner = new For(i_inner, new IntImm(0), factor_expr, body_inner);
-  *outer = new For(i_outer, new IntImm(0), split_count, *inner);
+  *outer =
+      new For(i_outer, new IntImm(0), split_count, *inner, f->loop_options());
 
   // TODO: cleanup API for adding/removing statements
   p->replace_stmt(f, *outer);
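
Propagating the options to the outer loop only is sound for axis bindings: the outer loop still walks the whole original iteration space in chunks of factor, while the tail (splitWithTail) or the masked remainder (splitWithMask) keeps default options. A standalone sketch of that index arithmetic, assuming M = 21 and factor = 4 as in the new tests (illustration only, not part of the commit):

#include <cassert>

// Why moving the axis binding to the outer loop preserves coverage:
// after the split, the original index is recovered as
// m = m_outer * factor + m_inner, so m_outer (which carries the GPU
// binding) together with m_inner still visits every m exactly once.
int main() {
  const int M = 21;                    // original loop extent
  const int factor = 4;                // split factor
  const int split_count = M / factor;  // 5 full outer iterations
  const int tail_size = M % factor;    // 1 leftover iteration

  bool visited[M] = {};
  for (int m_outer = 0; m_outer < split_count; m_outer++)  // binding lives here
    for (int m_inner = 0; m_inner < factor; m_inner++)
      visited[m_outer * factor + m_inner] = true;
  for (int m_tail = 0; m_tail < tail_size; m_tail++)  // tail: default options
    visited[split_count * factor + m_tail] = true;

  for (int m = 0; m < M; m++)
    assert(visited[m]);  // every original index covered
  return 0;
}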
