Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/test_jit_fuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ def test_abs_cpu(self):
def test_abs_cuda(self):
self._test_fused_abs(device="cuda")

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@skipIfRocm
def test_zero_element_tensors(self):
def decode(sin_t, cos_t):
theta = torch.atan2(sin_t.float(), cos_t.float())
return theta

sin = torch.zeros(0, device="cuda")
cos = torch.zeros(0, device="cuda")
inputs = [sin, cos]
ge = self.checkScript(decode, inputs)

@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_arg_configurations_smoke_cuda(self):
# A smoke test to make sure we won't use the same kernel for contiguous
Expand Down
7 changes: 5 additions & 2 deletions torch/csrc/jit/fuser/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,11 @@ void launchFusion(
}
}
}

fusion.launch_raw(numel, arguments);
// Skip launching the kernel for zero-element tensor inputs
// launches are skipped, empty zero-sized output is returned
if (numel > 0) {
fusion.launch_raw(numel, arguments);
}
}

bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
Expand Down