Skip to content

Commit 91b1bae

Browse files
zdevito authored and pytorchmergebot committed
Caching allocator tracing (#86241)
We currently can take snapshots of the state of the allocated cuda memory, but we do not have a way to correlate these snapshots with the actions the allocator that were taken between snapshots. This PR adds a simple fixed-sized buffer that records the major actions that the allocator takes (ALLOC, FREE, SEGMENT_ALLOC, SEGMENT_FREE, OOM, SNAPSHOT) and includes these with the snapshot information. Capturing period snapshots with a big enough trace buffer makes it possible to see how the allocator state changes over time. We plan to use this functionality to guide how settings in the allocator can be adjusted and eventually have a more robust overall algorithm. As a component of this functionality, we also add the ability to get a callback when the allocator will throw an OOM, primarily so that snapshots can be taken immediately to see why the program ran out of memory (most programs have some C++ state that would free tensors before the OutOfMemory exception can be caught). This PR also updates the _memory_viz.py script to pretty-print the trace information and provide a better textual summary of snapshots distinguishing between internal and external fragmentation. Pull Request resolved: #86241 Approved by: https://github.com/ngimel
1 parent 8a3a54e commit 91b1bae

File tree

9 files changed

+694
-135
lines changed

9 files changed

+694
-135
lines changed

c10/cuda/CUDACachingAllocator.cpp

Lines changed: 182 additions & 29 deletions
Large diffs are not rendered by default.

c10/cuda/CUDACachingAllocator.h

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,12 @@ struct Context {
9898
virtual ~Context() {}
9999
};
100100

101-
typedef std::unique_ptr<Context> (*CreateContextFn)(void);
101+
typedef std::shared_ptr<Context> (*CreateContextFn)(void);
102102

103103
struct History {
104104
void* addr;
105105
size_t real_size; // unrounded, actually requested size
106-
std::unique_ptr<Context> context; // per-watcher context
107-
std::unique_ptr<History> next; // when blocks are merged we keep records of
108-
// what used to be in the block
106+
std::shared_ptr<Context> context; // per-watcher context
109107
};
110108

111109
// Struct containing info of an allocation block (i.e. a fractional part of a
@@ -115,8 +113,7 @@ struct BlockInfo {
115113
int32_t gc_counter = 0;
116114
bool allocated = false;
117115
bool active = false;
118-
History* history =
119-
nullptr; // borrowed reference because it is owned by the allocator
116+
std::vector<History> history;
120117
};
121118

122119
// Struct containing info of a memory segment (i.e. one contiguous cudaMalloc).
@@ -131,6 +128,44 @@ struct SegmentInfo {
131128
std::vector<BlockInfo> blocks;
132129
};
133130

131+
struct TraceEntry {
132+
enum Action {
133+
ALLOC, // API made to the caching allocator for new memory
134+
FREE_REQUESTED, // API call made to the caching allocator to free memory
135+
FREE_COMPLETED, // The allocator might have to delay a free because
136+
// it is still in use on another stream via record_stream
137+
// This event is generated when a free actually completes.
138+
SEGMENT_ALLOC, // a call to cudaMalloc to get more memory from the OS
139+
SEGMENT_FREE, // a call to cudaFree to return memory to the OS (e.g. to
140+
// defragment or empty_caches)
141+
SNAPSHOT, // a call to snapshot, used to correlate memory snapshots to trace
142+
// events
143+
OOM // the allocator threw an OutOfMemoryError (addr_ is the amount of free
144+
// bytes reported by cuda)
145+
};
146+
TraceEntry(
147+
Action action,
148+
int64_t addr,
149+
size_t size,
150+
cudaStream_t stream,
151+
std::shared_ptr<Context> context = nullptr)
152+
: action_(action),
153+
addr_(addr),
154+
context_(context),
155+
stream_(stream),
156+
size_(size) {}
157+
Action action_;
158+
int64_t addr_; // for OOM, this is the amount of free bytes reported by cuda
159+
std::shared_ptr<Context> context_;
160+
cudaStream_t stream_;
161+
int64_t size_;
162+
};
163+
164+
struct SnapshotInfo {
165+
std::vector<SegmentInfo> segments;
166+
std::vector<std::vector<TraceEntry>> device_traces;
167+
};
168+
134169
C10_CUDA_API void* raw_alloc(size_t nbytes);
135170
C10_CUDA_API void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream);
136171
C10_CUDA_API void raw_delete(void* ptr);
@@ -149,7 +184,7 @@ C10_CUDA_API void recordStream(const DataPtr&, CUDAStream stream);
149184
C10_CUDA_API DeviceStats getDeviceStats(int device);
150185
C10_CUDA_API void resetAccumulatedStats(int device);
151186
C10_CUDA_API void resetPeakStats(int device);
152-
C10_CUDA_API std::vector<SegmentInfo> snapshot();
187+
C10_CUDA_API SnapshotInfo snapshot();
153188

154189
// CUDAGraph interactions
155190
C10_CUDA_API void notifyCaptureBegin(
@@ -161,7 +196,17 @@ C10_CUDA_API void notifyCaptureDestroy(int device, MempoolId_t mempool_id);
161196

162197
C10_CUDA_API std::mutex* getFreeMutex();
163198

164-
C10_CUDA_API void setContextRecorder(CreateContextFn recorder);
199+
C10_CUDA_API void recordHistory(
200+
bool enabled,
201+
CreateContextFn context_recorder,
202+
size_t alloc_trace_max_entries,
203+
bool alloc_trace_record_context);
204+
using OutOfMemoryObserver = std::function<void(
205+
int64_t device,
206+
int64_t allocated,
207+
int64_t device_total,
208+
int64_t device_free)>;
209+
C10_CUDA_API void attachOutOfMemoryObserver(OutOfMemoryObserver observer);
165210

166211
C10_CUDA_API std::shared_ptr<void> getIpcDevPtr(std::string handle);
167212
} // namespace CUDACachingAllocator

test/test_cuda.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4619,20 +4619,26 @@ def test_memory_snapshot(self):
46194619

46204620
ss = torch.cuda.memory._snapshot()
46214621
found_it = False
4622-
for seg in ss:
4622+
for seg in ss['segments']:
46234623
for b in seg['blocks']:
46244624
if 'history' in b:
46254625
for h in b['history']:
46264626
if h['real_size'] == 311 * 411 * 4:
46274627
self.assertTrue('test_cuda' in h['frames'][0]['filename'])
46284628
found_it = True
46294629
self.assertTrue(found_it)
4630+
46304631
if not IS_WINDOWS:
46314632
with tempfile.NamedTemporaryFile() as f:
46324633
torch.cuda.memory._save_segment_usage(f.name)
46334634
with open(f.name, 'r') as f2:
46344635
self.assertTrue('test_cuda.py' in f2.read())
46354636

4637+
del x
4638+
torch.cuda.empty_cache()
4639+
ss = torch.cuda.memory._snapshot()
4640+
self.assertTrue(ss['device_traces'][0][-1]['action'] == 'segment_free')
4641+
46364642
finally:
46374643
torch.cuda.memory._record_memory_history(False)
46384644

@@ -4643,7 +4649,7 @@ def test_memory_snapshot_with_cpp(self):
46434649
torch.cuda.memory._record_memory_history(True, _enable_expensive_cpp=True)
46444650
x = torch.rand(311, 411, device='cuda')
46454651

4646-
ss = torch.cuda.memory._snapshot()
4652+
ss = torch.cuda.memory._snapshot()['segments']
46474653
found_it = False
46484654
for seg in ss:
46494655
for b in seg['blocks']:
@@ -4734,16 +4740,31 @@ def test_cpp_memory_snapshot_pickle(self):
47344740
t = torch.rand(311, 411, device='cuda')
47354741
mem = pickle.loads(m.do_snapshot())
47364742
found = False
4737-
for s in mem:
4743+
for s in mem['segments']:
47384744
for b in s['blocks']:
47394745
if b['state'] == 'active_allocated' and 'history' in b:
47404746
history = b['history']
47414747
if history and history[0]['real_size'] == 311 * 411 * 4:
47424748
found = True
4749+
last_action = mem['device_traces'][0][-1]
4750+
self.assertTrue(last_action['action'] == 'alloc')
4751+
self.assertTrue(last_action['size'] == 311 * 411 * 4)
47434752
self.assertTrue(found)
47444753
finally:
47454754
m.record(False)
47464755

4756+
def test_notifies_oom(self):
4757+
x = False
4758+
4759+
def cb(device, alloc, device_alloc, device_free):
4760+
nonlocal x
4761+
x = True
4762+
torch._C._cuda_attach_out_of_memory_observer(cb)
4763+
with self.assertRaises(torch.cuda.OutOfMemoryError):
4764+
torch.empty(1024 * 1024 * 1024 * 1024, device='cuda')
4765+
self.assertTrue(x)
4766+
4767+
47474768
instantiate_parametrized_tests(TestCuda)
47484769

47494770
if __name__ == '__main__':

torch/_C/__init__.pyi.in

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,8 +1178,8 @@ def _cuda_emptyCache() -> None: ...
11781178
def _cuda_memoryStats(device: _int) -> Dict[str, Any]: ...
11791179
def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ...
11801180
def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
1181-
def _cuda_memorySnapshot() -> List[Dict[str, Any]]: ...
1182-
def _cuda_recordMemoryHistory(enabled: _bool, cpp: _bool) -> None: ...
1181+
def _cuda_memorySnapshot() -> Dict[str, Any]: ...
1182+
def _cuda_recordMemoryHistory(enabled: _bool, record_context: _bool, record_context_cpp: _bool, alloc_trace_max_entries: _int, alloc_trace_record_context: _bool) -> None: ...
11831183
def _cuda_lock_mutex() -> None: ...
11841184
def _cuda_unlock_mutex() -> None: ...
11851185
def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ...

0 commit comments

Comments
 (0)