55 changes: 44 additions & 11 deletions aten/src/ATen/cuda/CUDAGraph.cpp
@@ -8,6 +8,8 @@
namespace at {
namespace cuda {

+static bool _cuda_graphs_debug = false;
+
MempoolId_t graph_pool_handle() {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// uuid count starts at 1. 0 is reserved to mean "wasn't set by graph_pool_handle".
@@ -16,7 +18,7 @@ MempoolId_t graph_pool_handle() {
// cudaStreamGetCaptureInfo id_s in capture_begin.
return {0, uuid++};
#else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
return {0, 0};
#endif
}
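
For context, the {0, uuid} handle minted above surfaces in Python as torch.cuda.graph_pool_handle(), which lets two captures share one memory pool. A minimal sketch (shapes are arbitrary, and the warm-up that real captures need is elided):

    import torch

    pool = torch.cuda.graph_pool_handle()  # -> the (0, uuid) pair from the C++ above

    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
    x = torch.zeros(8, device="cuda")

    # Passing the same handle to both captures makes g2 reuse g1's pool.
    with torch.cuda.graph(g1, pool=pool):
        y1 = x * 2
    with torch.cuda.graph(g2, pool=pool):
        y2 = x + 1
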
@@ -46,7 +48,7 @@ CUDAGraph::CUDAGraph()
// CUDAStreams may not be default-constructed.
: capture_stream_(at::cuda::getCurrentCUDAStream()) {
#if (defined(CUDA_VERSION) && CUDA_VERSION < 11000) || defined(USE_ROCM)
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
#endif
}

@@ -122,7 +124,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
// kernel will end up as part of the capture or not.
c10::cuda::CUDACachingAllocator::notifyCaptureBegin(capture_dev_, id_, mempool_id_);
#else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
#endif
}

@@ -186,12 +188,17 @@ void CUDAGraph::capture_end() {
"attempted to be captured on wrong device or stream.");
}

-  // Now that we've instantiated graph_ into graph_exec_,
-  // we don't need graph_ anymore.
-  AT_CUDA_CHECK(cudaGraphDestroy(graph_));
-  has_graph_ = false;
+  // check if debug path is set
+  if (!_cuda_graphs_debug) {
+    // Now that we've instantiated graph_ into graph_exec_,
+    // we don't need graph_ anymore.
+    AT_CUDA_CHECK(cudaGraphDestroy(graph_));
+    has_graph_ = false;
+  } else {
+    TORCH_WARN("DEBUG: TORCH_CUDAGRAPHS_DEBUG_PATH detected. graph_ will not be freed until debug_dump is called.");
+  }
#else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
#endif
}
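
This branch is the heart of the PR: with the debug flag set, capture_end() skips cudaGraphDestroy(), so the uninstantiated graph_ stays alive alongside graph_exec_ until debug_dump() consumes it. A hedged sketch of the resulting flow from Python (the .dot path is arbitrary; warm-up omitted):

    import torch

    g = torch.cuda.CUDAGraph()
    g.enable_debug_mode()        # flips the process-wide _cuda_graphs_debug flag

    x = torch.zeros(8, device="cuda")
    with torch.cuda.graph(g):    # drives capture_begin()/capture_end()
        y = x * 2

    # Because of the branch above, graph_ was not freed at capture_end(),
    # so it can still be serialized here.
    g.debug_dump("capture_end_graph.dot")
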

@@ -226,7 +233,33 @@ void CUDAGraph::replay() {
AT_CUDA_CHECK(cudaDeviceSynchronize());
}
#else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
#endif
}

+void CUDAGraph::enable_debug_mode() {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+  _cuda_graphs_debug = true;
+#else
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
+#endif
+
+}
+
+void CUDAGraph::debug_dump(const std::string& debug_path) {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+  if (_cuda_graphs_debug) {
+    TORCH_WARN("DEBUG: calling debug_dump()");
+    if (has_graph_) {
+      TORCH_WARN("DEBUG: calling cudaGraphDebugDotPrint() with ", debug_path);
+      C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), 1<<10)); // most verbose output
+      AT_CUDA_CHECK(cudaGraphDestroy(graph_));
+    }
+  } else {
+    TORCH_WARN("CUDA Graphs debug not enabled, set with torch._C._cuda_enable_graphs_debug_mode");
+  }
+#else
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
+#endif
+}
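
debug_dump() hands the retained graph_ to cudaGraphDebugDotPrint(), which writes a Graphviz .dot file; the 1<<10 flag selects a single bit of the cudaGraphDebugDotFlags enum rather than an all-bits verbose mask, so "most verbose output" is approximate. Note also that has_graph_ is left set after the cudaGraphDestroy() here, so any later teardown path that tests has_graph_ could touch a dead handle; clearing it after the destroy would be safer. Assuming Graphviz is installed, a small hypothetical helper renders the dump:

    import subprocess

    def render_graph_dump(dot_path: str, out_path: str = "graph.svg") -> None:
        # Graphviz's `dot` turns debug_dump()'s output into an image.
        subprocess.run(["dot", "-Tsvg", dot_path, "-o", out_path], check=True)

    render_graph_dump("capture_end_graph.dot")
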

@@ -262,7 +295,7 @@ void CUDAGraph::reset() {
C10_CUDA_CHECK_WARN(cudaGraphExecDestroy(graph_exec_));
}
#else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
#endif
}

@@ -272,7 +305,7 @@ MempoolId_t CUDAGraph::pool() {
TORCH_CHECK(has_graph_exec_,
"Called CUDAGraph::pool() without a preceding successful capture.");
#else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
+  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM");
#endif
return mempool_id_;
}
2 changes: 2 additions & 0 deletions aten/src/ATen/cuda/CUDAGraph.h
@@ -24,6 +24,8 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
void replay();
void reset();
MempoolId_t pool();
+  void enable_debug_mode();
+  void debug_dump(const std::string& debug_path);

protected:
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
3 changes: 3 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -1306,6 +1306,9 @@ class _CUDAGraph:
def replay(self) -> None: ...
def reset(self) -> None: ...
def pool(self) -> Tuple[_int, _int]: ...
+    def enable_debug_mode(self) -> None: ...
+    def debug_dump(self,
+                   debug_path: str) -> None: ...

def _cuda_isCurrentStreamCapturing() -> _bool: ...

16 changes: 15 additions & 1 deletion torch/csrc/cuda/Graph.cpp
@@ -48,5 +48,19 @@ void THCPGraph_init(PyObject* module) {
.def(
"pool",
torch::wrap_pybind_function(&at::cuda::CUDAGraph::pool),
-          py::call_guard<py::gil_scoped_release>());
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "debug_dump",
+          torch::wrap_pybind_function(&::at::cuda::CUDAGraph::debug_dump),
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "enable_debug_mode",
+          torch::wrap_pybind_function(
+              &::at::cuda::CUDAGraph::enable_debug_mode),
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "debug_dump",
+          torch::wrap_pybind_function(&::at::cuda::CUDAGraph::debug_dump),
+          py::call_guard<py::gil_scoped_release>(),
+          py::arg("debug_path"));
}
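
Note that debug_dump is registered twice in this block, and only the second registration declares py::arg("debug_path"); pybind11 tries overloads in definition order, so positional calls can be served by the first binding while keyword calls fall through to the second, which makes the first registration redundant. A quick sketch of both call forms against the wrapper (no capture has happened, so debug_dump only warns):

    import torch

    g = torch.cuda.CUDAGraph()   # subclasses the torch._C._CUDAGraph bound above
    g.enable_debug_mode()
    g.debug_dump("unused.dot")             # positional: first registration suffices
    g.debug_dump(debug_path="unused.dot")  # keyword: needs the py::arg overload
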
16 changes: 16 additions & 0 deletions torch/cuda/graphs.py
@@ -100,6 +100,22 @@ def pool(self):
"""
return super(CUDAGraph, self).pool()

+    def enable_debug_mode(self):
+        r"""
+        Enables debugging mode for CUDAGraph.debug_dump.
+        """
+        return super(CUDAGraph, self).enable_debug_mode()
+
+    def debug_dump(self, debug_path):
+        r"""
+        Calls a debugging function to dump the graph if debugging is
+        enabled via CUDAGraph.enable_debug_mode().
+
+        Arguments:
+            debug_path (required): Path to dump the graph to.
+        """
+        return super(CUDAGraph, self).debug_dump(debug_path)
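
Putting the pieces together, the intended end-to-end use of the new methods looks roughly like this; the warm-up on a side stream follows the usual CUDA-graphs recipe and, like the output path, is illustrative rather than part of this diff:

    import torch

    g = torch.cuda.CUDAGraph()
    g.enable_debug_mode()            # must happen before capture ends

    static_in = torch.randn(64, device="cuda")

    # Warm up on a side stream so lazy initialization stays out of the capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out = static_in * 2.0
    torch.cuda.current_stream().wait_stream(s)

    with torch.cuda.graph(g):
        static_out = static_in * 2.0

    g.replay()                       # run the captured kernels
    g.debug_dump("mul_graph.dot")    # write the Graphviz dump and free graph_
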


class graph(object):
r"""