 namespace at {
 namespace cuda {

+static bool _cuda_graphs_debug = false;
+
 MempoolId_t graph_pool_handle() {
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
   // uuid count starts at 1. 0 is reserved to mean "wasn't set by graph_pool_handle".
@@ -16,7 +18,7 @@ MempoolId_t graph_pool_handle() {
   // cudaStreamGetCaptureInfo id_s in capture_begin.
   return {0, uuid++};
 #else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
   return {0, 0};
 #endif
 }
@@ -46,7 +48,7 @@ CUDAGraph::CUDAGraph()
   // CUDAStreams may not be default-constructed.
   : capture_stream_(at::cuda::getCurrentCUDAStream()) {
 #if (defined(CUDA_VERSION) && CUDA_VERSION < 11000) || defined(USE_ROCM)
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
 #endif
 }

@@ -122,7 +124,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
   // kernel will end up as part of the capture or not.
   c10::cuda::CUDACachingAllocator::notifyCaptureBegin(capture_dev_, id_, mempool_id_);
 #else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
 #endif
 }

@@ -186,12 +188,17 @@ void CUDAGraph::capture_end() {
               " attempted to be captured on wrong device or stream.");
   }

-  // Now that we've instantiated graph_ into graph_exec_,
-  // we don't need graph_ anymore.
-  AT_CUDA_CHECK(cudaGraphDestroy(graph_));
-  has_graph_ = false;
+  // check if debug path is set
+  if (!_cuda_graphs_debug) {
+    // Now that we've instantiated graph_ into graph_exec_,
+    // we don't need graph_ anymore.
+    AT_CUDA_CHECK(cudaGraphDestroy(graph_));
+    has_graph_ = false;
+  } else {
+    TORCH_WARN("DEBUG: TORCH_CUDAGRAPHS_DEBUG_PATH detected. graph_ will not be freed until debug_dump is called.");
+  }
 #else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
 #endif
 }

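Taken together with debug_dump() (added further down in this diff), the change above defines the debug lifecycle: capture_end() instantiates graph_exec_ but leaves graph_ alive, and debug_dump() later serializes and frees it. Below is a minimal usage sketch, not part of this commit, assuming a CUDA >= 11.0 build; the stream setup and the captured work are placeholder scaffolding:

    #include <ATen/cuda/CUDAGraph.h>
    #include <c10/cuda/CUDAGuard.h>
    #include <c10/cuda/CUDAStream.h>

    void capture_and_dump() {
      // CUDA graph capture must run on a non-default stream.
      c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool();
      c10::cuda::CUDAStreamGuard guard(stream);

      at::cuda::CUDAGraph graph;
      graph.enable_debug_mode();                // sets _cuda_graphs_debug = true
      graph.capture_begin();
      // ... enqueue the GPU work to be captured on `stream` ...
      graph.capture_end();                      // graph_ survives because debug mode is on
      graph.debug_dump("/tmp/cuda_graph.dot");  // writes the DOT file, then frees graph_
    }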
@@ -226,7 +233,33 @@ void CUDAGraph::replay() {
     AT_CUDA_CHECK(cudaDeviceSynchronize());
   }
 #else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
+#endif
+}
+
+void CUDAGraph::enable_debug_mode() {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+  _cuda_graphs_debug = true;
+#else
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
+#endif
+}
+
+void CUDAGraph::debug_dump(const std::string& debug_path) {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+  if (_cuda_graphs_debug) {
+    TORCH_WARN("DEBUG: calling debug_dump()");
+    if (has_graph_) {
+      TORCH_WARN("DEBUG: calling cudaGraphDebugDotPrint() with ", debug_path);
+      C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), 1<<10)); // most verbose output
+      AT_CUDA_CHECK(cudaGraphDestroy(graph_));
+      has_graph_ = false;  // graph_ is gone; keep the flag consistent so reset() doesn't double-free
+    }
+  } else {
+    TORCH_WARN("CUDA Graphs debug not enabled, set with torch._C._cuda_enable_graphs_debug_mode");
+  }
+#else
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
 #endif
 }

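debug_dump() delegates the serialization to cudaGraphDebugDotPrint (a CUDA 11.3+ runtime API), which writes a Graphviz DOT description of the graph to the given path; the flags argument selects the level of detail, and per the inline comment above, 1<<10 is the most verbose setting. A self-contained sketch of the underlying call, not part of this commit; the empty graph and the output path are placeholders:

    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
      cudaGraph_t graph = nullptr;
      if (cudaGraphCreate(&graph, 0) != cudaSuccess) {
        std::printf("cudaGraphCreate failed\n");
        return 1;
      }
      // Dump a DOT description of the (here empty) graph; flags = 0 keeps
      // the default level of detail.
      if (cudaGraphDebugDotPrint(graph, "/tmp/empty_graph.dot", 0) != cudaSuccess) {
        std::printf("cudaGraphDebugDotPrint failed\n");
      }
      cudaGraphDestroy(graph);
      return 0;
    }

The resulting file can be rendered with Graphviz, e.g. dot -Tsvg /tmp/empty_graph.dot -o graph.svg.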
@@ -262,7 +295,7 @@ void CUDAGraph::reset() {
     C10_CUDA_CHECK_WARN(cudaGraphExecDestroy(graph_exec_));
   }
 #else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
 #endif
 }

@@ -272,7 +305,7 @@ MempoolId_t CUDAGraph::pool() {
   TORCH_CHECK(has_graph_exec_,
               "Called CUDAGraph::pool() without a preceding successful capture.");
 #else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
 #endif
   return mempool_id_;
 }