@@ -168,6 +168,12 @@ struct BlockPool {
   PrivatePool* owner_PrivatePool;
 };
 
+struct HistoryChain {
+  History h;
+  std::unique_ptr<HistoryChain> next; // when blocks are merged we keep records
+                                      // of what used to be in the block
+};
+
 struct Block {
   int device; // gpu
   cudaStream_t stream; // allocation stream
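
The new `HistoryChain` wraps the existing `History` record in an intrusive singly linked list, so a coalesced block can retain the records of every block it absorbed. A minimal standalone sketch of the pattern, with a simplified `History` that keeps only the two fields this diff reads (the real struct also carries a `std::shared_ptr<Context>`):

#include <cstddef>
#include <cstdio>
#include <memory>

// Simplified stand-ins, not the real definitions.
struct History {
  void* addr;       // start of the allocation this record describes
  size_t real_size; // size the client originally requested
};

struct HistoryChain {
  History h;
  std::unique_ptr<HistoryChain> next; // older records from merged blocks
};

int main() {
  char buf[64];
  // Simulate merging two freed blocks: the survivor prepends its own record,
  // keeping the absorbed block's record reachable through `next`.
  auto chain = std::make_unique<HistoryChain>(
      HistoryChain{History{buf + 32, 32}, nullptr});
  chain = std::make_unique<HistoryChain>(
      HistoryChain{History{buf, 32}, std::move(chain)});
  for (HistoryChain* p = chain.get(); p; p = p->next.get())
    std::printf("record at %p, size %zu\n", p->h.addr, p->h.real_size);
}
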
@@ -181,8 +187,8 @@ struct Block {
   int event_count; // number of outstanding CUDA events
   int gc_count; // counter for prioritizing older / less useful blocks for
                 // garbage collection
-  std::unique_ptr<History> history;
-  History* history_last;
+  std::unique_ptr<HistoryChain> history;
+  HistoryChain* history_last;
 
   Block(
       int device,
@@ -284,7 +290,7 @@ struct AllocParams {
 
 int trimHistoryBefore(Block* block, void* point) {
   int n = 0;
-  while (block->history && block->history->addr < point) {
+  while (block->history && block->history->h.addr < point) {
     block->history = std::move(block->history->next);
     ++n;
  }
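
`trimHistoryBefore` pops records off the head of the chain while they start before `point`; after a split, records whose start address precedes the split no longer describe the block they are attached to. The same loop against the simplified types from the sketch above:

// Drop leading records that start before `point`; returns how many were
// dropped. Uses the simplified HistoryChain from the previous sketch.
int trimHistoryBefore(std::unique_ptr<HistoryChain>& history, void* point) {
  int n = 0;
  while (history && history->h.addr < point) {
    history = std::move(history->next); // unlink and free the head record
    ++n;
  }
  return n;
}
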
@@ -549,6 +555,16 @@ class DeviceCachingAllocator {
 
   bool set_fraction = false;
 
+  bool record_history = false;
+  std::atomic<CreateContextFn> context_recorder_;
+  size_t alloc_trace_next = 0;
+  bool alloc_trace_record_context = false;
+  size_t alloc_trace_max_entries = 1;
+  std::vector<TraceEntry>*
+      alloc_trace; // pointer because we need to intentionally leak this; on
+                   // deallocation it can hold references to Python state which
+                   // will already be destroyed when we are in exit handlers
+
   // Members specific to CUDA graphs
 
   // Private pools for CUDA graphs
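
Taken together, `alloc_trace`, `alloc_trace_next`, and `alloc_trace_max_entries` implement a fixed-capacity ring buffer: the vector grows by appending until it reaches capacity, after which `alloc_trace_next` wraps around and the oldest entries are overwritten in place. The same bookkeeping in isolation (a sketch with a generic `Entry`, not the allocator's type):

#include <cstddef>
#include <vector>

// Minimal ring buffer mirroring the diff's bookkeeping: grow until capacity,
// then overwrite in place, with `next` marking the oldest surviving entry.
template <typename Entry>
struct RingTrace {
  size_t next = 0;
  size_t max_entries;
  std::vector<Entry> buf;

  explicit RingTrace(size_t cap) : max_entries(cap) { buf.reserve(cap); }

  void record(Entry e) {
    if (buf.size() < max_entries) {
      buf.push_back(std::move(e));
    } else {
      buf[next++] = std::move(e);
      if (next == max_entries)
        next = 0; // wrap: `next` now indexes the oldest entry
    }
  }
};
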
@@ -564,18 +580,36 @@ class DeviceCachingAllocator {
   // Maps a capturing stream to its assigned private pool,
   // in case we want multiple captures to share the same pool
   ska::flat_hash_map<CaptureId_t, MempoolId_t> capture_to_pool_map;
-  std::atomic<CreateContextFn> context_recorder_;
+
+  // XXX - maybe we should generalize and have multiple events
+  std::vector<OutOfMemoryObserver> oom_observers_;
 
  public:
   DeviceCachingAllocator()
       : large_blocks(BlockComparator, /*is_small=*/false),
-        small_blocks(BlockComparator, /*is_small=*/true) {
+        small_blocks(BlockComparator, /*is_small=*/true),
+        alloc_trace(new std::vector<TraceEntry>()) {
     stats.max_split_size = CachingAllocatorConfig::max_split_size();
     context_recorder_.store(nullptr);
   }
 
-  void setContextRecorder(CreateContextFn c) {
-    context_recorder_.store(c);
+  void recordHistory(
+      bool enabled,
+      CreateContextFn context_recorder,
+      size_t alloc_trace_max_entries,
+      bool alloc_trace_record_context) {
+    std::unique_lock<std::recursive_mutex> lock(mutex);
+    this->record_history = enabled;
+    this->context_recorder_.store(context_recorder);
+    this->alloc_trace_max_entries =
+        std::max(size_t(1), alloc_trace_max_entries);
+    this->alloc_trace_record_context = alloc_trace_record_context;
+    alloc_trace_next = 0;
+    alloc_trace->clear();
+  }
+
+  void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
+    oom_observers_.emplace_back(std::move(observer));
   }
 
   // All public methods (except the above) acquire the allocator mutex.
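
How a caller might turn tracing on, sketched against the signature above. The `gather_context` callback is a hypothetical stand-in; in PyTorch proper the recorder would capture a Python traceback, but any `CreateContextFn` whose result converts to `std::shared_ptr<Context>` fits the call sites in this file. Illustrative only, not compilable outside this file:

// `alloc` is some DeviceCachingAllocator instance.
std::shared_ptr<Context> gather_context() {
  return nullptr; // a real recorder would capture a stack trace here
}

void enable_memory_history(DeviceCachingAllocator& alloc) {
  alloc.recordHistory(
      /*enabled=*/true,
      /*context_recorder=*/gather_context,
      /*alloc_trace_max_entries=*/10000, // ring-buffer capacity; clamped to >= 1
      /*alloc_trace_record_context=*/true); // also attach contexts to TraceEntry
}
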
@@ -585,7 +619,7 @@ class DeviceCachingAllocator {
     // done outside the lock because we don't know what locks the recorder needs
     // to have...
     CreateContextFn context_recorder = context_recorder_.load();
-    std::unique_ptr<Context> context =
+    std::shared_ptr<Context> context =
         context_recorder ? context_recorder() : nullptr;
 
     std::unique_lock<std::recursive_mutex> lock(mutex);
@@ -603,7 +637,6 @@ class DeviceCachingAllocator {
       // effect on memory use during capture should be small.
       process_events();
     }
-
     size_t size = round_size(orig_size);
     auto& pool = get_pool(size, stream);
     const size_t alloc_size = get_allocation_size(size);
@@ -635,6 +668,14 @@ class DeviceCachingAllocator {
           // Free all non-split cached blocks and retry alloc.
          || (C10_LIKELY(captures_underway == 0) && release_cached_blocks() &&
              alloc_block(params, true));
+      if (record_history && block_found) {
+        record_trace(
+            TraceEntry::SEGMENT_ALLOC,
+            int64_t(params.block->ptr),
+            params.block->size,
+            params.stream(),
+            context);
+      }
     }
 
     if (!block_found) {
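
`TraceEntry` itself is declared in the header half of this change, which is not shown here. Purely from the call sites in this file, its shape is roughly the following; every name below is inferred, not the actual declaration:

// Inferred from usage only; the real declaration lives in the header and may
// differ. Field names here are hypothetical.
struct TraceEntry {
  enum Action { ALLOC, FREE, SEGMENT_ALLOC, SEGMENT_FREE, OOM, SNAPSHOT };
  TraceEntry(
      Action action,
      int64_t addr, // note: the OOM call site passes bytes free, not an address
      size_t size,
      cudaStream_t stream,
      std::shared_ptr<Context> context)
      : action_(action),
        addr_(addr),
        size_(size),
        stream_(stream),
        context_(std::move(context)) {}
  Action action_;
  int64_t addr_;
  size_t size_;
  cudaStream_t stream_;
  std::shared_ptr<Context> context_;
};
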
@@ -651,6 +692,14 @@ class DeviceCachingAllocator {
         allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
       }
 
+      if (record_history) {
+        record_trace(
+            TraceEntry::OOM,
+            device_free,
+            params.size(),
+            params.stream(),
+            context);
+      }
       stats.num_ooms += 1;
 
       c10::reportOutOfMemoryToProfiler(
@@ -660,6 +709,12 @@ class DeviceCachingAllocator {
           stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
               .current,
           c10::Device(c10::DeviceType::CUDA, static_cast<DeviceIndex>(device)));
+      for (const auto& obs : oom_observers_) {
+        obs(device,
+            alloc_size,
+            set_fraction ? allowed_memory_maximum : device_total,
+            device_free);
+      }
       // "total capacity": total global memory on GPU
       // "allowed": memory the program is allowed to use, which is set by fraction.
       // "already allocated": memory allocated by the program using the
727782 bool inserted = pool.blocks .insert (remaining).second ;
728783 TORCH_INTERNAL_ASSERT_DEBUG_ONLY (inserted);
729784
730- if (context ) {
785+ if (record_history ) {
731786 trimHistoryBefore (remaining, (char *)block->ptr + size);
732787 }
733788
@@ -753,17 +808,22 @@ class DeviceCachingAllocator {
     }
 
     block->allocated = true;
-    if (context) {
+    if (record_history) {
       trimHistoryBefore(block, (char*)block->ptr + size);
-      block->history = std::make_unique<History>(History{
-          block->ptr,
-          orig_size,
-          std::move(context),
+      block->history = std::make_unique<HistoryChain>(HistoryChain{
+          History{block->ptr, orig_size, std::move(context)},
           std::move(block->history)});
       if (!block->history_last) {
         block->history_last = block->history.get();
       }
+      record_trace(
+          TraceEntry::ALLOC,
+          int64_t(block->ptr),
+          orig_size,
+          block->stream,
+          block->history->h.context);
     }
+
     bool inserted = active_blocks.insert(block).second;
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
 
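
Worth noting in the hunk above: new records are pushed at the head of the chain, so `block->history` is always the newest record, while `history_last` is assigned only when the chain was empty and therefore keeps pointing at the oldest one. The prepend in isolation, reusing the simplified types from the first sketch:

// Prepend a record: head = newest. history_last is initialized only when the
// chain was empty, so it remains the tail (oldest record) thereafter.
void push_record(
    std::unique_ptr<HistoryChain>& history,
    HistoryChain*& history_last,
    History h) {
  history = std::make_unique<HistoryChain>(
      HistoryChain{h, std::move(history)});
  if (!history_last) {
    history_last = history.get();
  }
}
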
@@ -804,6 +864,14 @@ class DeviceCachingAllocator {
       update_stat(stats.allocation[stat_type], -1);
       update_stat(stats.allocated_bytes[stat_type], -block->size);
     });
+    if (block->history) {
+      record_trace(
+          TraceEntry::FREE,
+          int64_t(block->ptr),
+          block->history->h.real_size,
+          block->stream,
+          block->history->h.context);
+    }
     if (block->size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_allocations, -1);
 
@@ -938,12 +1006,12 @@ class DeviceCachingAllocator {
 
   /** Dump a complete snapshot of the memory held by the allocator. Potentially
    * VERY expensive. **/
-  std::vector<SegmentInfo> snapshot() const {
+  std::vector<SegmentInfo> snapshot() {
     std::lock_guard<std::recursive_mutex> lock(mutex);
 
+    size_t total_active = 0;
     std::vector<SegmentInfo> result;
     const auto all_blocks = get_all_blocks();
-
     for (const Block* const head_block : all_blocks) {
       if (head_block->prev != nullptr) {
         continue;
@@ -972,9 +1040,14 @@ class DeviceCachingAllocator {
         if (block_info.active) {
           segment_info.active_size += block_info.size;
         }
-        block_info.history = block->history.get();
+        HistoryChain* h = block->history.get();
+        while (h) {
+          block_info.history.push_back(h->h);
+          h = h->next.get();
+        }
         block = block->next;
       }
+      total_active += segment_info.active_size;
     }
 
     std::sort(
@@ -984,6 +1057,24 @@ class DeviceCachingAllocator {
           return a.address < b.address;
         });
 
+    if (record_history) {
+      record_trace(TraceEntry::SNAPSHOT, 0, total_active, 0, nullptr);
+    }
+    return result;
+  }
+
+  std::vector<TraceEntry> trace() {
+    std::lock_guard<std::recursive_mutex> lock(mutex);
+    std::vector<TraceEntry> result;
+    result.reserve(alloc_trace->size());
+    result.insert(
+        result.end(),
+        alloc_trace->begin() + alloc_trace_next,
+        alloc_trace->end());
+    result.insert(
+        result.end(),
+        alloc_trace->begin(),
+        alloc_trace->begin() + alloc_trace_next);
     return result;
   }
 
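
`trace()` relies on the ring-buffer invariant that `alloc_trace_next` indexes the oldest entry once the buffer has wrapped (and is 0 before it wraps, which makes the second insert a no-op). Copying `[next, end)` followed by `[begin, next)` therefore yields the entries oldest-first. A self-contained check of that stitching, with plain ints standing in for `TraceEntry`:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const size_t cap = 4;
  std::vector<int> buf;
  size_t next = 0;
  // Record 6 entries into a capacity-4 ring: 0..3 fill it; 4 and 5 overwrite
  // the two oldest slots, leaving next == 2 pointing at the oldest entry (2).
  for (int e = 0; e < 6; ++e) {
    if (buf.size() < cap) {
      buf.push_back(e);
    } else {
      buf[next++] = e;
      if (next == cap)
        next = 0;
    }
  }
  std::vector<int> out;
  out.insert(out.end(), buf.begin() + next, buf.end());   // oldest run
  out.insert(out.end(), buf.begin(), buf.begin() + next); // newest run
  assert((out == std::vector<int>{2, 3, 4, 5})); // oldest-first order
}
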
@@ -1510,7 +1601,14 @@ class DeviceCachingAllocator {
         });
     if (block->size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_segments, -1);
-
+    if (block->history) {
+      record_trace(
+          TraceEntry::SEGMENT_FREE,
+          int64_t(block->ptr),
+          block->size,
+          block->stream,
+          block->history->h.context);
+    }
     pool->blocks.erase(block);
     delete block;
   }
@@ -1641,6 +1739,28 @@ class DeviceCachingAllocator {
       }
     }
   }
+
+  void record_trace(
+      TraceEntry::Action action,
+      int64_t addr,
+      size_t size,
+      cudaStream_t stream,
+      std::shared_ptr<Context> context) {
+    auto te = TraceEntry(
+        action,
+        addr,
+        size,
+        stream,
+        alloc_trace_record_context ? std::move(context) : nullptr);
+    if (alloc_trace->size() < alloc_trace_max_entries) {
+      alloc_trace->emplace_back(te);
+    } else {
+      (*alloc_trace)[alloc_trace_next++] = te;
+      if (alloc_trace_next == alloc_trace_max_entries) {
+        alloc_trace_next = 0;
+      }
+    }
+  }
 };
 
 class THCCachingAllocator {
@@ -1740,10 +1860,24 @@ class THCCachingAllocator {
     device_allocator[device]->setMemoryFraction(fraction);
   }
 
-  void setContextRecorder(CreateContextFn recorder) {
+  void recordHistory(
+      bool enabled,
+      CreateContextFn context_recorder,
+      size_t alloc_trace_max_entries,
+      bool alloc_trace_record_context) {
+    int device;
+    C10_CUDA_CHECK(cudaGetDevice(&device));
+    device_allocator[device]->recordHistory(
+        enabled,
+        std::move(context_recorder),
+        alloc_trace_max_entries,
+        alloc_trace_record_context);
+  }
+
+  void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
     int device;
     C10_CUDA_CHECK(cudaGetDevice(&device));
-    device_allocator[device]->setContextRecorder(std::move(recorder));
+    device_allocator[device]->attachOutOfMemoryObserver(std::move(observer));
   }
 
   void emptyCache() {
@@ -1780,13 +1914,13 @@ class THCCachingAllocator {
     device_allocator[block->device]->recordStream(block, stream);
   }
 
-  std::vector<SegmentInfo> snapshot() {
-    std::vector<SegmentInfo> result;
+  SnapshotInfo snapshot() {
+    SnapshotInfo result;
     for (auto& da : device_allocator) {
+      result.device_traces.emplace_back(da->trace());
       auto snap = da->snapshot();
-      result.insert(result.end(), snap.begin(), snap.end());
+      result.segments.insert(result.segments.end(), snap.begin(), snap.end());
     }
-
     return result;
   }
 };
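
`SnapshotInfo`'s declaration is also in the header, not shown here. From this method it must at least pair the flattened segment list with one trace per device; field names are as used above, anything further about the real struct is unknown:

// Inferred from usage in snapshot() above; the real declaration may differ.
struct SnapshotInfo {
  std::vector<SegmentInfo> segments; // every device's segments, concatenated
  std::vector<std::vector<TraceEntry>> device_traces; // indexed by device id
};
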
@@ -1862,14 +1996,26 @@ void setMemoryFraction(double fraction, int device) {
   caching_allocator.setMemoryFraction(fraction, device);
 }
 
-void setContextRecorder(CreateContextFn recorder) {
-  caching_allocator.setContextRecorder(std::move(recorder));
+void recordHistory(
+    bool enabled,
+    CreateContextFn context_recorder,
+    size_t alloc_trace_max_entries,
+    bool alloc_trace_record_context) {
+  caching_allocator.recordHistory(
+      enabled,
+      std::move(context_recorder),
+      alloc_trace_max_entries,
+      alloc_trace_record_context);
 }
 
 void setAllocatorSettings(const std::string& env) {
   CachingAllocatorConfig::instance().parseArgs(env.c_str());
 }
 
+void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
+  caching_allocator.attachOutOfMemoryObserver(std::move(observer));
+}
+
 void emptyCache(void) {
   caching_allocator.emptyCache();
 }
@@ -1915,7 +2061,7 @@ void resetPeakStats(int device) {
   caching_allocator.device_allocator[device]->resetPeakStats();
 }
 
-std::vector<SegmentInfo> snapshot() {
+SnapshotInfo snapshot() {
   return caching_allocator.snapshot();
 }
 