Skip to content

Commit 0057be3

Browse files
Aidyn-Apytorchmergebot
authored andcommitted
[CUDA graphs] Add warning if captured graph is empty (#88754)
Fixes #87894 This PR adds a warning if captured graph is empty (consists of zero nodes). The example snippet where would it be useful: ```python import torch x = torch.randn(10) z = torch.zeros(10) g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): z = x * x # Warn user ``` and in #87894 Pull Request resolved: #88754 Approved by: https://github.com/ezyang
1 parent c18da59 commit 0057be3

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

aten/src/ATen/cuda/CUDAGraph.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,13 @@ void CUDAGraph::capture_end() {
179179
"when capture began");
180180
wholegraph_increment_ = gen->capture_epilogue();
181181

182+
size_t numCUDAGraphNodes = 0;
183+
AT_CUDA_CHECK(cudaGraphGetNodes(graph_, NULL, &numCUDAGraphNodes));
184+
if (numCUDAGraphNodes == 0) {
185+
TORCH_WARN("The CUDA Graph is empty. This ususally means that the graph was ",
186+
"attempted to be captured on wrong device or stream.");
187+
}
188+
182189
// Now that we've instantiated graph_ into graph_exec_,
183190
// we don't need graph_ anymore.
184191
AT_CUDA_CHECK(cudaGraphDestroy(graph_));

test/test_cuda.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import tempfile
1616
import threading
1717
import unittest
18+
import warnings
1819
from random import randint
1920

2021
import torch
@@ -3291,6 +3292,18 @@ def test_graph_capture_simple(self):
32913292

32923293
self.assertTrue(b.sum().item() == 11000.)
32933294

3295+
@unittest.skipIf((not TEST_CUDA) or
3296+
TEST_WITH_ROCM or
3297+
int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
3298+
def test_graph_warn_if_has_zero_nodes(self):
3299+
with warnings.catch_warnings(record=True) as caught:
3300+
g = torch.cuda.CUDAGraph()
3301+
s = torch.cuda.Stream()
3302+
with torch.cuda.stream(s):
3303+
g.capture_begin()
3304+
g.capture_end()
3305+
self.assertTrue(any("The CUDA Graph is empty" in str(w.message) for w in caught))
3306+
32943307
@unittest.skipIf((not TEST_CUDA) or
32953308
TEST_WITH_ROCM or
32963309
int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")

0 commit comments

Comments
 (0)