53 changes: 44 additions & 9 deletions c10/cuda/CUDACachingAllocator.cpp
@@ -12,6 +12,7 @@
#include <cuda_runtime_api.h>
#include <algorithm>
#include <bitset>
#include <cstdint>
#include <deque>
#include <iterator>
#include <map>
@@ -183,6 +184,7 @@ struct Block {
cudaStream_t stream; // allocation stream
stream_set stream_uses; // streams on which the block was used
size_t size; // block size in bytes
size_t requested_size; // memory originally requested
BlockPool* pool{nullptr}; // owning memory pool
void* ptr{nullptr}; // memory address
bool allocated{false}; // in-use flag
@@ -204,12 +206,17 @@ struct Block {
stream(stream),
stream_uses(),
size(size),
requested_size(0),
pool(pool),
ptr(ptr) {}

// constructor for search key
Block(int device, cudaStream_t stream, size_t size)
: device(device), stream(stream), stream_uses(), size(size) {}
: device(device),
stream(stream),
stream_uses(),
size(size),
requested_size(0) {}

bool is_split() const {
return (prev != nullptr) || (next != nullptr);
@@ -963,6 +970,32 @@ class DeviceCachingAllocator {
if (already_split) {
// An already-split inactive block is being shrunk by size bytes.
update_stat_array(
stats.inactive_split_bytes, -block->size, params.stat_types);
stats.inactive_split_bytes,
-static_cast<std::int64_t>(block->size),
params.stat_types);
} else {
// A new split inactive block is being created from a previously unsplit
// block, size remaining->size bytes.
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
update_stat(stats.inactive_split_bytes[stat_type], remaining->size);
update_stat(
stats.inactive_split_bytes[stat_type],
static_cast<std::int64_t>(remaining->size));
update_stat(stats.inactive_split[stat_type], 1);
});
}

} else if (already_split) {
// An already-split block is becoming active
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
update_stat(stats.inactive_split_bytes[stat_type], -block->size);
update_stat(
stats.inactive_split_bytes[stat_type],
-static_cast<std::int64_t>(block->size));
update_stat(stats.inactive_split[stat_type], -1);
});
}

block->allocated = true;
block->requested_size = orig_size;
if (record_history) {
trimHistoryBefore(block, (char*)block->ptr + size);
block->history = std::make_unique<HistoryChain>(HistoryChain{
@@ -1003,9 +1017,16 @@

for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
update_stat(stats.allocation[stat_type], 1);
update_stat(stats.allocated_bytes[stat_type], block->size);
update_stat(
stats.allocated_bytes[stat_type],
static_cast<std::int64_t>(block->size));
update_stat(stats.active[stat_type], 1);
update_stat(stats.active_bytes[stat_type], block->size);
update_stat(
stats.active_bytes[stat_type],
static_cast<std::int64_t>(block->size));
update_stat(
stats.requested_bytes[stat_type],
static_cast<std::int64_t>(block->requested_size));
});
if (block->size >= CachingAllocatorConfig::max_split_size())
update_stat(stats.oversize_allocations, 1);
@@ -1036,7 +1057,9 @@
true;
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
update_stat(stats.allocation[stat_type], -1);
update_stat(stats.allocated_bytes[stat_type], -block->size);
update_stat(
stats.allocated_bytes[stat_type],
-static_cast<std::int64_t>(block->size));
});
if (block->history) {
record_trace(
@@ -1151,6 +1174,7 @@ class DeviceCachingAllocator {
reset_accumulated_stat(stats.reserved_bytes[statType]);
reset_accumulated_stat(stats.active_bytes[statType]);
reset_accumulated_stat(stats.inactive_split_bytes[statType]);
reset_accumulated_stat(stats.requested_bytes[statType]);
}

stats.num_alloc_retries = 0;
@@ -1173,6 +1197,7 @@
reset_peak_stat(stats.reserved_bytes[statType]);
reset_peak_stat(stats.active_bytes[statType]);
reset_peak_stat(stats.inactive_split_bytes[statType]);
reset_peak_stat(stats.requested_bytes[statType]);
}
reset_peak_stat(stats.oversize_allocations);
reset_peak_stat(stats.oversize_segments);
@@ -1203,6 +1228,7 @@
BlockInfo& block_info = segment_info.blocks.back();

block_info.size = block->size;
block_info.requested_size = block->requested_size;
block_info.allocated = block->allocated;
block_info.active = block->allocated || (block->event_count > 0) ||
!block->stream_uses.empty();
@@ -1213,6 +1239,7 @@
}
if (block_info.active) {
segment_info.active_size += block_info.size;
segment_info.requested_size += block_info.requested_size;
}
HistoryChain* h = block->history.get();
while (h) {
@@ -1388,6 +1415,7 @@ class DeviceCachingAllocator {
block->history->h.context);
}
size_t original_block_size = block->size;
size_t requested_size = block->requested_size;

auto& pool = *block->pool;
int64_t net_change_inactive_split_blocks = 0;
@@ -1424,7 +1452,12 @@
stats.inactive_split_bytes[stat_type],
net_change_inactive_split_size);
update_stat(stats.active[stat_type], -1);
update_stat(stats.active_bytes[stat_type], -original_block_size);
update_stat(
stats.active_bytes[stat_type],
-static_cast<std::int64_t>(original_block_size));
update_stat(
stats.requested_bytes[stat_type],
-static_cast<std::int64_t>(requested_size));
});
}

@@ -1775,7 +1808,9 @@
stat_types[static_cast<size_t>(get_stat_type_for_pool(*pool))] = true;
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
update_stat(stats.segment[stat_type], -1);
update_stat(stats.reserved_bytes[stat_type], -block->size);
update_stat(
stats.reserved_bytes[stat_type],
-static_cast<std::int64_t>(block->size));
});
if (block->size >= CachingAllocatorConfig::max_split_size())
update_stat(stats.oversize_segments, -1);
6 changes: 5 additions & 1 deletion c10/cuda/CUDACachingAllocator.h
@@ -68,14 +68,16 @@ struct DeviceStats {
// released via cudaFree)
StatArray inactive_split;

// SUM: bytes requested by client code
// SUM: bytes allocated by this memory allocator
StatArray allocated_bytes;
// SUM: bytes reserved by this memory allocator (both free and used)
StatArray reserved_bytes;
// SUM: bytes within active memory blocks
StatArray active_bytes;
// SUM: bytes within inactive, split memory blocks
StatArray inactive_split_bytes;
// SUM: bytes requested by client code
StatArray requested_bytes;

// COUNT: total number of failed calls to CUDA malloc necessitating cache
// flushes.
@@ -110,6 +112,7 @@ struct History {
// cudaMalloc)..
struct BlockInfo {
int64_t size = 0;
int64_t requested_size = 0;
int32_t gc_counter = 0;
bool allocated = false;
bool active = false;
@@ -121,6 +124,7 @@ struct SegmentInfo {
int64_t device = 0;
int64_t address = 0;
int64_t total_size = 0;
int64_t requested_size = 0;
int64_t allocated_size = 0;
int64_t active_size = 0;
cudaStream_t stream = 0;
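Note: the new `requested_bytes` StatArray declared above is surfaced to Python through `torch.cuda.memory_stats()` as flattened `<stat>.<pool>.<metric>` keys, mirroring the existing stat arrays (see the `Module.cpp` and `torch/cuda/memory.py` changes below). A minimal sketch of how the new array fans out into stat keys; the pool and metric names follow the key scheme documented later in this diff:

```python
import torch

# Pool and metric dimensions used by the flattened memory_stats() keys.
pools = ("all", "large_pool", "small_pool")
metrics = ("current", "peak", "allocated", "freed")

torch.rand(1024, device="cuda")  # allocate something so the stats are non-zero
stats = torch.cuda.memory_stats()

for pool in pools:
    for metric in metrics:
        key = f"requested_bytes.{pool}.{metric}"
        print(key, stats.get(key, 0))
```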
31 changes: 21 additions & 10 deletions test/test_cuda.py
@@ -107,13 +107,18 @@ def _check_memory_stat_consistency(self):
expected["active_bytes.all.current"] += segment["active_size"]
expected["active_bytes." + pool_str + ".current"] += segment["active_size"]

expected["requested_bytes.all.current"] += segment["requested_size"]
expected["requested_bytes." + pool_str + ".current"] += segment["requested_size"]

sum_requested = 0
is_split = len(segment["blocks"]) > 1
for block in segment["blocks"]:
if block["state"] == "active_allocated":
expected["allocation.all.current"] += 1
expected["allocation." + pool_str + ".current"] += 1

if block["state"].startswith("active_"):
sum_requested += block["requested_size"]
expected["active.all.current"] += 1
expected["active." + pool_str + ".current"] += 1

@@ -123,6 +128,8 @@ def _check_memory_stat_consistency(self):
expected["inactive_split_bytes.all.current"] += block["size"]
expected["inactive_split_bytes." + pool_str + ".current"] += block["size"]

self.assertEqual(sum_requested, segment["requested_size"])

for device, expected in expected_each_device.items():
stats = torch.cuda.memory_stats(device)
for k, v in expected.items():
@@ -5028,57 +5035,61 @@ def power2_div(size, div_factor):
return ret

torch.cuda.memory.empty_cache()
key = 'active_bytes.all.allocated' if not TEST_CUDAMALLOCASYNC else 'allocated_bytes.all.current'
key_allocated = 'active_bytes.all.allocated' if not TEST_CUDAMALLOCASYNC else 'allocated_bytes.all.current'
key_requested = 'requested_bytes.all.allocated'

nelems = 21 * 1024 * 1024
nbytes = 4 * nelems # floats are 4 bytes

nelems_big = 100 * 1024 * 1024
nbytes_big = 4 * nelems_big # floats are 4 bytes

start_mem = torch.cuda.memory_stats()[key]
start_mem = torch.cuda.memory_stats()[key_allocated]
torch.cuda.memory._set_allocator_settings("")
x = torch.rand(nelems, device='cuda')

# test roundup_power2_divisions single value syntax
reg_mem = torch.cuda.memory_stats()[key]
reg_mem = torch.cuda.memory_stats()[key_allocated]
start_requested = torch.cuda.memory_stats()[key_requested]
torch.cuda.memory._set_allocator_settings("roundup_power2_divisions:4")
y = torch.rand(nelems, device='cuda')

pow2_div4_mem = torch.cuda.memory_stats()[key]
pow2_div4_mem = torch.cuda.memory_stats()[key_allocated]
current_requested = torch.cuda.memory_stats()[key_requested]

self.assertTrue(reg_mem - start_mem == nbytes)
if not TEST_CUDAMALLOCASYNC:
# not supported with the cudaMallocAsync backend
self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4))
self.assertTrue(current_requested - start_requested == nbytes)

torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5")
torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5,max_split_size_mb:40")

# should have reset the power2 divisions now
torch.cuda.memory.empty_cache()
start_mem = torch.cuda.memory_stats()[key]
start_mem = torch.cuda.memory_stats()[key_allocated]
z = torch.rand(nelems, device='cuda')
reg_mem = torch.cuda.memory_stats()[key]
reg_mem = torch.cuda.memory_stats()[key_allocated]
self.assertTrue(reg_mem - start_mem == nbytes)

# roundup_power2_divisions knob array syntax
torch.cuda.memory.empty_cache()
torch.cuda.memory._set_allocator_settings(
"garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,128:2,256:2,512:2,1024:1,>:1]")
start_mem = torch.cuda.memory_stats()[key]
start_mem = torch.cuda.memory_stats()[key_allocated]
w = torch.rand(nelems, device='cuda')

pow2_div8_mem = torch.cuda.memory_stats()[key]
pow2_div8_mem = torch.cuda.memory_stats()[key_allocated]
if not TEST_CUDAMALLOCASYNC:
# not supported with the cudaMallocAsync backend
self.assertTrue(pow2_div8_mem - start_mem == power2_div(nbytes, 8))

torch.cuda.memory.empty_cache()
start_mem = torch.cuda.memory_stats()[key]
start_mem = torch.cuda.memory_stats()[key_allocated]
v = torch.rand(nelems_big, device='cuda')

pow2_div2_mem = torch.cuda.memory_stats()[key]
pow2_div2_mem = torch.cuda.memory_stats()[key_allocated]
if not TEST_CUDAMALLOCASYNC:
# not supported with the cudaMallocAsync backend
self.assertTrue(pow2_div2_mem - start_mem == power2_div(nbytes_big, 2))
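The consistency check added to the test above encodes the main invariant of this change: within a segment, the sum of `requested_size` over active blocks equals the segment-level `requested_size`. A minimal standalone sketch of the same check against `torch.cuda.memory_snapshot()`, assuming a build that includes this change:

```python
import torch

x = torch.rand(4 * 1024 * 1024, device="cuda")  # make sure at least one segment exists

for segment in torch.cuda.memory_snapshot():
    # Only active blocks contribute to the segment-level requested_size.
    sum_requested = sum(
        block["requested_size"]
        for block in segment["blocks"]
        if block["state"].startswith("active_")
    )
    assert sum_requested == segment["requested_size"]
```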
4 changes: 4 additions & 0 deletions torch/csrc/cuda/Module.cpp
@@ -565,6 +565,7 @@ PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) {
result["reserved_bytes"] = statArrayToDict(stats.reserved_bytes);
result["active_bytes"] = statArrayToDict(stats.active_bytes);
result["inactive_split_bytes"] = statArrayToDict(stats.inactive_split_bytes);
result["requested_bytes"] = statArrayToDict(stats.requested_bytes);
result["oversize_allocations"] = statToDict(stats.oversize_allocations);
result["oversize_segments"] = statToDict(stats.oversize_segments);

@@ -646,6 +647,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
py::str total_size_s = "total_size";
py::str allocated_size_s = "allocated_size";
py::str active_size_s = "active_size";
py::str requested_size_s = "requested_size";
py::str stream_s = "stream";
py::str segment_type_s = "segment_type";
py::str large_s = "large";
@@ -691,6 +693,7 @@
segmentDict[total_size_s] = segmentInfo.total_size;
segmentDict[allocated_size_s] = segmentInfo.allocated_size;
segmentDict[active_size_s] = segmentInfo.active_size;
segmentDict[requested_size_s] = segmentInfo.requested_size;
// we want the python objects to pickle easily so use an int to
// represent the stream rather than a torch.cuda.stream object
segmentDict[stream_s] = int64_t(segmentInfo.stream);
@@ -700,6 +703,7 @@
for (const auto& blockInfo : segmentInfo.blocks) {
py::dict blockDict;
blockDict[size_s] = blockInfo.size;
blockDict[requested_size_s] = blockInfo.requested_size;
blockDict[state_s] =
(blockInfo.allocated
? active_allocated_s
3 changes: 3 additions & 0 deletions torch/csrc/cuda/memory_snapshot.cpp
@@ -44,6 +44,7 @@ std::string _memory_snapshot_pickled() {
IValue total_size_s = "total_size";
IValue allocated_size_s = "allocated_size";
IValue active_size_s = "active_size";
IValue requested_size_s = "requested_size";
IValue stream_s = "stream";
IValue segment_type_s = "segment_type";
IValue large_s = "large";
@@ -71,6 +72,7 @@
segmentDict.insert(total_size_s, segmentInfo.total_size);
segmentDict.insert(allocated_size_s, segmentInfo.allocated_size);
segmentDict.insert(active_size_s, segmentInfo.active_size);
segmentDict.insert(requested_size_s, segmentInfo.requested_size);
segmentDict.insert(stream_s, int64_t(segmentInfo.stream));
segmentDict.insert(
segment_type_s, (segmentInfo.is_large ? large_s : small_s));
@@ -79,6 +81,7 @@
for (const auto& blockInfo : segmentInfo.blocks) {
auto blockDict = new_dict();
blockDict.insert(size_s, blockInfo.size);
blockDict.insert(requested_size_s, blockInfo.requested_size);
blockDict.insert(
state_s,
(blockInfo.allocated
10 changes: 10 additions & 0 deletions torch/cuda/memory.py
@@ -194,6 +194,15 @@ def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
- ``"oversize_segments.{current,peak,allocated,freed}"``:
number of over-size reserved segments from ``cudaMalloc()``.

The caching allocator can be configured via the ``PYTORCH_CUDA_ALLOC_CONF`` environment
variable to round memory allocations in order to reduce fragmentation. Sometimes the
overhead from rounding can be higher than the fragmentation it helps reduce. The
following stat can be used to check whether rounding adds too much overhead:

- ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
memory requested by client code; compare this with ``allocated_bytes`` to check
whether allocation rounding adds too much overhead.

Args:
device (torch.device or int, optional): selected device. Returns
statistics for the current device, given by :func:`~torch.cuda.current_device`,
@@ -477,6 +486,7 @@ def _format_count(cnt, pref_cnt):
metrics_to_display = [
("allocated_bytes", "Allocated memory", _format_size),
("active_bytes", "Active memory", _format_size),
("requested_bytes", "Requested memory", _format_size),
("reserved_bytes", "GPU reserved memory", _format_size),
("inactive_split_bytes", "Non-releasable memory", _format_size),
("allocation", "Allocations", _format_count),
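A minimal sketch of the overhead check described in the new docstring text: compare the cumulative `requested_bytes` and `allocated_bytes` counters to see how much the allocator's rounding adds on top of what client code asked for. The workload and the printed percentage are illustrative only:

```python
import torch

_ = torch.rand(21 * 1024 * 1024, device="cuda")  # any workload of interest

stats = torch.cuda.memory_stats()
requested = stats["requested_bytes.all.allocated"]
allocated = stats["allocated_bytes.all.allocated"]

# Bytes handed out by the allocator beyond what was actually requested.
overhead = allocated - requested
print(f"requested={requested} allocated={allocated} "
      f"rounding overhead={overhead} ({overhead / allocated:.1%})")
```

The same quantity also appears as the new "Requested memory" row added to `torch.cuda.memory_summary()` in this diff.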