Commit 9903d85

Update on "Caching allocator tracing"
We can currently take snapshots of the state of allocated CUDA memory, but we have no way to correlate these snapshots with the actions the allocator took between them. This PR adds a simple fixed-size buffer that records the major actions the allocator takes (ALLOC, FREE, SEGMENT_ALLOC, SEGMENT_FREE, OOM, SNAPSHOT) and includes these with the snapshot information. Capturing periodic snapshots with a big enough trace buffer makes it possible to see how the allocator state changes over time.

We plan to use this functionality to guide how settings in the allocator can be adjusted and, eventually, to build a more robust overall algorithm.

As a component of this functionality, we also add the ability to get a callback when the allocator is about to throw an OOM, primarily so that snapshots can be taken immediately to see why the program ran out of memory (most programs have some C++ state that would free tensors before the OutOfMemory exception can be caught).

This PR also updates the _memory_viz.py script to pretty-print the trace information and to provide a better textual summary of snapshots, distinguishing between internal and external fragmentation.

[ghstack-poisoned]
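The fixed-size trace buffer described above can be illustrated with a minimal Python sketch. All names here are hypothetical; the real buffer lives inside the C++ CUDA caching allocator and records richer per-action data:

```python
from collections import deque

# Action names taken from the commit message above.
ACTIONS = ("ALLOC", "FREE", "SEGMENT_ALLOC", "SEGMENT_FREE", "OOM", "SNAPSHOT")

class TraceBuffer:
    """Hypothetical sketch of a fixed-size allocator action trace."""

    def __init__(self, max_entries):
        # A deque with maxlen drops the oldest entry once full, so the
        # buffer always holds the most recent allocator actions.
        self.entries = deque(maxlen=max_entries)

    def record(self, action, addr, size):
        assert action in ACTIONS
        self.entries.append((action, addr, size))

    def snapshot(self):
        # A snapshot marker goes into the trace itself, so successive
        # snapshots can be correlated with the actions between them.
        self.record("SNAPSHOT", 0, 0)
        return list(self.entries)

buf = TraceBuffer(max_entries=4)
buf.record("ALLOC", 0x1000, 512)
buf.record("FREE", 0x1000, 512)
trace = buf.snapshot()  # [ALLOC, FREE, SNAPSHOT]
```

The key design point mirrored here is that the buffer is bounded: with a large enough `max_entries`, every action between two snapshots is retained; with a small one, only the most recent actions survive.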
2 parents 06a84d7 + d255818 commit 9903d85

File tree

2 files changed: 18 additions, 9 deletions

torch/csrc/cuda/Module.cpp

Lines changed: 16 additions & 8 deletions
```diff
@@ -646,7 +646,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
       history_entry[addr_s] = (int64_t)h.addr;
       history_entry[real_size_s] = h.real_size;
       if (h.context) {
-        auto sc = (StackContext*) h.context.get();
+        auto sc = (StackContext*)h.context.get();
         history_entry[frames_s] = get_frames(sc);
         if (!sc->cpp_frames.empty()) {
           history_entry[cpp_frames_s] = py::cast(sc->cpp_frames);
@@ -829,13 +829,21 @@ static void registerCudaDeviceProperties(PyObject* module) {
         return stream.str();
       });

-  m.def("_cuda_recordMemoryHistory", [](bool enabled, bool record_context, bool record_context_cpp, Py_ssize_t alloc_trace_max_entries, bool alloc_trace_record_context) {
-    c10::cuda::CUDACachingAllocator::recordHistory(
-        enabled,
-        record_context ? (record_context_cpp ? StackContext::gather_with_cpp : StackContext::gather) : nullptr,
-        alloc_trace_max_entries,
-        alloc_trace_record_context);
-  });
+  m.def(
+      "_cuda_recordMemoryHistory",
+      [](bool enabled,
+         bool record_context,
+         bool record_context_cpp,
+         Py_ssize_t alloc_trace_max_entries,
+         bool alloc_trace_record_context) {
+        c10::cuda::CUDACachingAllocator::recordHistory(
+            enabled,
+            record_context ? (record_context_cpp ? StackContext::gather_with_cpp
+                                                 : StackContext::gather)
+                           : nullptr,
+            alloc_trace_max_entries,
+            alloc_trace_record_context);
+      });
 }

 static void bindGetDeviceProperties(PyObject* module) {
```
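The binding passes a context-gathering callback to `recordHistory` only when `record_context` is set, and gathers C++ frames only on top of that. A minimal Python sketch of that nested-ternary selection logic (the stub functions stand in for `StackContext::gather` and `StackContext::gather_with_cpp`):

```python
# Stubs standing in for the C++ StackContext gather functions.
def gather():
    return ["python-frames"]

def gather_with_cpp():
    return ["python-frames", "cpp-frames"]

def pick_context_recorder(record_context, record_context_cpp):
    """Mirror of the selection in the _cuda_recordMemoryHistory binding."""
    if not record_context:
        return None  # corresponds to passing nullptr: no stack capture
    return gather_with_cpp if record_context_cpp else gather
```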

torch/cuda/memory.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -615,7 +615,8 @@ def _record_memory_history(enabled: bool, record_context=True,
     stack trace collection; file an issue with us if you need it.
     """
     with torch.cuda.device(device):
-        _C._cuda_recordMemoryHistory(enabled, record_context, _enable_expensive_cpp, trace_alloc_max_entries, trace_alloc_record_context)
+        _C._cuda_recordMemoryHistory(enabled, record_context, _enable_expensive_cpp,
+                                     trace_alloc_max_entries, trace_alloc_record_context)

 def _snapshot(device: Union[Device, int] = None):
     with torch.cuda.device(device):
```
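Once recording is enabled, snapshots carry the trace alongside segment state, which is what `_memory_viz.py` pretty-prints. A rough sketch of how such trace entries could be summarized, using a hypothetical `(action, size)` entry format (the real snapshot schema from `_snapshot()` is richer and may differ):

```python
from collections import Counter

def summarize_trace(trace):
    """Count actions and compute net bytes allocated from (action, size) pairs."""
    counts = Counter(action for action, _ in trace)
    net = sum(size if action == "ALLOC" else -size
              for action, size in trace
              if action in ("ALLOC", "FREE"))
    return counts, net

trace = [("SEGMENT_ALLOC", 2 << 20),
         ("ALLOC", 1024),
         ("ALLOC", 2048),
         ("FREE", 1024)]
counts, net = summarize_trace(trace)
# counts["ALLOC"] == 2; net allocated bytes == 2048
```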
