Skip to content

Commit 02bc06a

Browse files
wanchaol authored and facebook-github-bot committed
avoid kernel launches for zero-sized tensor inputs
Summary: Pull Request resolved: #22790

Test Plan: Imported from OSS

Differential Revision: D16226168

Pulled By: wanchaol

fbshipit-source-id: 081607c9acc1540c753b080c5f727dc4e8c22acc
1 parent b1b65f3 commit 02bc06a

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

test/test_jit_fuser.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,18 @@ def test_abs_cpu(self):
4646
def test_abs_cuda(self):
4747
self._test_fused_abs(device="cuda")
4848

49+
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
50+
@skipIfRocm
51+
def test_zero_element_tensors(self):
52+
def decode(sin_t, cos_t):
53+
theta = torch.atan2(sin_t.float(), cos_t.float())
54+
return theta
55+
56+
sin = torch.zeros(0, device="cuda")
57+
cos = torch.zeros(0, device="cuda")
58+
inputs = [sin, cos]
59+
ge = self.checkScript(decode, inputs)
60+
4961
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
5062
def test_arg_configurations_smoke_cuda(self):
5163
# A smoke test to make sure we won't use the same kernel for contiguous

torch/csrc/jit/fuser/executor.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,11 @@ void launchFusion(
313313
}
314314
}
315315
}
316-
317-
fusion.launch_raw(numel, arguments);
316+
// Skip launching the kernel for zero-element tensor inputs
317+
// launches are skipped, empty zero-sized output is returned
318+
if (numel > 0) {
319+
fusion.launch_raw(numel, arguments);
320+
}
318321
}
319322

320323
bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {

0 commit comments

Comments
 (0)