Skip to content

Commit 62e450d

Browse files
eqy authored and pytorchmergebot committed
[CUDA Graphs] Add option to dump a captured graph for debugging (#85519)
CC @xwang233 @ptrblck @ngimel Pull Request resolved: #85519 Approved by: https://github.com/ngimel
1 parent 1abe264 commit 62e450d

File tree

5 files changed

+80
-12
lines changed

5 files changed

+80
-12
lines changed

aten/src/ATen/cuda/CUDAGraph.cpp

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
namespace at {
99
namespace cuda {
1010

11+
static bool _cuda_graphs_debug = false;
12+
1113
MempoolId_t graph_pool_handle() {
1214
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
1315
// uuid count starts at 1. 0 is reserved to mean "wasn't set by graph_pool_handle".
@@ -16,7 +18,7 @@ MempoolId_t graph_pool_handle() {
1618
// cudaStreamGetCaptureInfo id_s in capture_begin.
1719
return {0, uuid++};
1820
#else
19-
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
21+
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
2022
return {0, 0};
2123
#endif
2224
}
@@ -46,7 +48,7 @@ CUDAGraph::CUDAGraph()
4648
// CUDAStreams may not be default-constructed.
4749
: capture_stream_(at::cuda::getCurrentCUDAStream()) {
4850
#if (defined(CUDA_VERSION) && CUDA_VERSION < 11000) || defined(USE_ROCM)
49-
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
51+
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
5052
#endif
5153
}
5254

@@ -122,7 +124,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
122124
// kernel will end up as part of the capture or not.
123125
c10::cuda::CUDACachingAllocator::notifyCaptureBegin(capture_dev_, id_, mempool_id_);
124126
#else
125-
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
127+
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
126128
#endif
127129
}
128130

@@ -186,12 +188,17 @@ void CUDAGraph::capture_end() {
186188
"attempted to be captured on wrong device or stream.");
187189
}
188190

189-
// Now that we've instantiated graph_ into graph_exec_,
190-
// we don't need graph_ anymore.
191-
AT_CUDA_CHECK(cudaGraphDestroy(graph_));
192-
has_graph_ = false;
191+
// check if debug path is set
192+
if (!_cuda_graphs_debug) {
193+
// Now that we've instantiated graph_ into graph_exec_,
194+
// we don't need graph_ anymore.
195+
AT_CUDA_CHECK(cudaGraphDestroy(graph_));
196+
has_graph_ = false;
197+
} else {
198+
TORCH_WARN("DEBUG: TORCH_CUDAGRAPHS_DEBUG_PATH detected. graph_ will not be freed until debug_dump is called.");
199+
}
193200
#else
194-
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
201+
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
195202
#endif
196203
}
197204

@@ -226,7 +233,33 @@ void CUDAGraph::replay() {
226233
AT_CUDA_CHECK(cudaDeviceSynchronize());
227234
}
228235
#else
229-
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
236+
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
237+
#endif
238+
}
239+
240+
// Turns on CUDA Graphs debug mode for this process. Once set, capture_end()
// keeps the captured cudaGraph_t alive (instead of destroying it right after
// instantiation) so that debug_dump() can write it out later.
// Only available on builds with CUDA >= 11.0; errors out otherwise.
void CUDAGraph::enable_debug_mode() {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
  _cuda_graphs_debug = true;
#else
  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
#endif
}
248+
249+
// Writes the retained captured graph to `debug_path` as a Graphviz DOT file
// via cudaGraphDebugDotPrint(), then destroys the retained graph.
//
// Requires debug mode to have been enabled (enable_debug_mode()) before the
// capture; otherwise capture_end() has already freed graph_ and this only
// emits a warning. Only available on builds with CUDA >= 11.0.
//
// @param debug_path  Filesystem path the DOT dump is written to.
void CUDAGraph::debug_dump(const std::string& debug_path) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
  if (_cuda_graphs_debug) {
    TORCH_WARN("DEBUG: calling debug_dump()");
    if (has_graph_) {
      TORCH_WARN("DEBUG: calling cudaGraphDebugDotPrint() with ", debug_path);
      C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), 1<<10)); // most verbose output
      AT_CUDA_CHECK(cudaGraphDestroy(graph_));
      // BUGFIX: record that graph_ has been freed. Without this, a second
      // debug_dump() call (or any later path guarded by has_graph_) would
      // operate on the destroyed cudaGraph_t handle.
      has_graph_ = false;
    }
  } else {
    TORCH_WARN("CUDA Graphs debug not enabled, set with torch._C._cuda_enable_graphs_debug_mode");
  }
#else
  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
#endif
}
232265

@@ -262,7 +295,7 @@ void CUDAGraph::reset() {
262295
C10_CUDA_CHECK_WARN(cudaGraphExecDestroy(graph_exec_));
263296
}
264297
#else
265-
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
298+
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
266299
#endif
267300
}
268301

@@ -272,7 +305,7 @@ MempoolId_t CUDAGraph::pool() {
272305
TORCH_CHECK(has_graph_exec_,
273306
"Called CUDAGraph::pool() without a preceding successful capture.");
274307
#else
275-
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
308+
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
276309
#endif
277310
return mempool_id_;
278311
}

aten/src/ATen/cuda/CUDAGraph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
2424
void replay();
2525
void reset();
2626
MempoolId_t pool();
27+
void enable_debug_mode();
28+
void debug_dump(const std::string& debug_path);
2729

2830
protected:
2931
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000

torch/_C/__init__.pyi.in

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,9 @@ class _CUDAGraph:
13061306
def replay(self) -> None: ...
13071307
def reset(self) -> None: ...
13081308
def pool(self) -> Tuple[_int, _int]: ...
1309+
def enable_debug_mode(self) -> None: ...
1310+
def debug_dump(self,
1311+
debug_path: str) -> None: ...
13091312

13101313
def _cuda_isCurrentStreamCapturing() -> _bool: ...
13111314

torch/csrc/cuda/Graph.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,19 @@ void THCPGraph_init(PyObject* module) {
4848
.def(
4949
"pool",
5050
torch::wrap_pybind_function(&at::cuda::CUDAGraph::pool),
51-
py::call_guard<py::gil_scoped_release>());
51+
py::call_guard<py::gil_scoped_release>())
52+
.def(
53+
"debug_dump",
54+
torch::wrap_pybind_function(&::at::cuda::CUDAGraph::debug_dump),
55+
py::call_guard<py::gil_scoped_release>())
56+
.def(
57+
"enable_debug_mode",
58+
torch::wrap_pybind_function(
59+
&::at::cuda::CUDAGraph::enable_debug_mode),
60+
py::call_guard<py::gil_scoped_release>())
61+
.def(
62+
"debug_dump",
63+
torch::wrap_pybind_function(&::at::cuda::CUDAGraph::debug_dump),
64+
py::call_guard<py::gil_scoped_release>(),
65+
py::arg("debug_path"));
5266
}

torch/cuda/graphs.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,22 @@ def pool(self):
100100
"""
101101
return super(CUDAGraph, self).pool()
102102

103+
def enable_debug_mode(self):
    r"""Turns on debugging mode so that :meth:`debug_dump` can later write
    the captured graph out for inspection.
    """
    return super(CUDAGraph, self).enable_debug_mode()
108+
109+
def debug_dump(self, debug_path):
    r"""Dumps the captured graph for debugging, provided debugging was
    previously enabled via :meth:`CUDAGraph.enable_debug_mode`.

    Arguments:
        debug_path (required): Path to dump the graph to.
    """
    return super(CUDAGraph, self).debug_dump(debug_path)
118+
103119

104120
class graph(object):
105121
r"""

0 commit comments

Comments
 (0)