forked from rapidsai/rapidsmpf
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscoped_memory_record.cpp
More file actions
106 lines (92 loc) · 3.74 KB
/
scoped_memory_record.cpp
File metadata and controls
106 lines (92 loc) · 3.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/**
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: Apache-2.0
*/
#include <numeric>
#include <rapidsmpf/error.hpp>
#include <rapidsmpf/memory/scoped_memory_record.hpp>
namespace rapidsmpf {
namespace {
/**
 * @brief Reads one allocator type's statistic, or sums the statistic over all entries.
 *
 * @param arr Per-allocator-type statistics, indexed by the numeric value of the
 * allocator type.
 * @param alloc_type Which entry to read. `AllocType::ALL` selects the sum of every
 * entry in `arr` instead of a single slot.
 * @return The entry for `alloc_type`, or the sum of all entries for `AllocType::ALL`.
 */
std::int64_t get_or_accumulate(
    ScopedMemoryRecord::AllocTypeArray const& arr,
    ScopedMemoryRecord::AllocType alloc_type
) noexcept {
    // Common case first: a concrete allocator type maps directly to one slot.
    if (alloc_type != ScopedMemoryRecord::AllocType::ALL) {
        return arr[static_cast<std::size_t>(alloc_type)];
    }

    // `ALL` is not a slot of its own; it means "add up every slot".
    std::int64_t sum{0};
    for (auto const value : arr) {
        sum += value;
    }
    return sum;
}
} // namespace
/**
 * @brief Reports the allocation counter kept in `num_total_allocs_`.
 *
 * @param alloc_type Allocator type to query; `AllocType::ALL` yields the sum over
 * all entries.
 * @return The counter for `alloc_type`, or the summed counters for `AllocType::ALL`.
 */
std::int64_t ScopedMemoryRecord::num_total_allocs(AllocType alloc_type) const noexcept {
    auto const& counters = num_total_allocs_;
    return get_or_accumulate(counters, alloc_type);
}
/**
 * @brief Reports the allocation counter kept in `num_current_allocs_`.
 *
 * @param alloc_type Allocator type to query; `AllocType::ALL` yields the sum over
 * all entries.
 * @return The counter for `alloc_type`, or the summed counters for `AllocType::ALL`.
 */
std::int64_t ScopedMemoryRecord::num_current_allocs(AllocType alloc_type) const noexcept {
    auto const& counters = num_current_allocs_;
    return get_or_accumulate(counters, alloc_type);
}
/**
 * @brief Reports the byte count kept in `current_`.
 *
 * @param alloc_type Allocator type to query; `AllocType::ALL` yields the sum over
 * all entries.
 * @return The value for `alloc_type`, or the summed values for `AllocType::ALL`.
 */
std::int64_t ScopedMemoryRecord::current(AllocType alloc_type) const noexcept {
    auto const& bytes = current_;
    return get_or_accumulate(bytes, alloc_type);
}
/**
 * @brief Reports the byte count kept in `total_`.
 *
 * @param alloc_type Allocator type to query; `AllocType::ALL` yields the sum over
 * all entries.
 * @return The value for `alloc_type`, or the summed values for `AllocType::ALL`.
 */
std::int64_t ScopedMemoryRecord::total(AllocType alloc_type) const noexcept {
    auto const& bytes = total_;
    return get_or_accumulate(bytes, alloc_type);
}
/**
 * @brief Reports the recorded peak for an allocator type.
 *
 * Unlike the other accessors, `AllocType::ALL` is not answered by summing the
 * per-type entries: it returns `highest_peak_`, which is tracked separately.
 *
 * @param alloc_type Allocator type to query, or `AllocType::ALL`.
 * @return `peak_` entry for `alloc_type`, or `highest_peak_` for `AllocType::ALL`.
 */
std::int64_t ScopedMemoryRecord::peak(AllocType alloc_type) const noexcept {
    if (alloc_type != AllocType::ALL) {
        auto const slot = static_cast<std::size_t>(alloc_type);
        return peak_[slot];
    }
    return highest_peak_;
}
/**
 * @brief Records an allocation of `nbytes` against one allocator type.
 *
 * Increments both allocation counters, adds `nbytes` to the current and total
 * byte counts, and raises the per-type peak and the overall peak if the new
 * current usage exceeds them.
 *
 * @param alloc_type Allocator type the allocation belongs to. Must not be
 * `AllocType::ALL`; this is enforced via `RAPIDSMPF_EXPECTS`.
 * @param nbytes Number of bytes allocated.
 */
void ScopedMemoryRecord::record_allocation(AllocType alloc_type, std::int64_t nbytes) {
    RAPIDSMPF_EXPECTS(
        alloc_type != AllocType::ALL,
        "AllocType::ALL may not be used to record allocation"
    );
    auto const slot = static_cast<std::size_t>(alloc_type);

    // Allocation counts.
    num_total_allocs_[slot] += 1;
    num_current_allocs_[slot] += 1;

    // Byte counts.
    current_[slot] += nbytes;
    total_[slot] += nbytes;

    // Per-type peak follows this type's current usage.
    if (peak_[slot] < current_[slot]) {
        peak_[slot] = current_[slot];
    }

    // Overall peak follows the current usage summed over all allocator types.
    auto const in_use_all_types = current();
    if (highest_peak_ < in_use_all_types) {
        highest_peak_ = in_use_all_types;
    }
}
/**
 * @brief Records a deallocation of `nbytes` against one allocator type.
 *
 * Only the current byte count and the current allocation counter are reduced;
 * the totals and the peaks are left untouched.
 *
 * @param alloc_type Allocator type the deallocation belongs to. Must not be
 * `AllocType::ALL`; this is enforced via `RAPIDSMPF_EXPECTS`.
 * @param nbytes Number of bytes released.
 */
void ScopedMemoryRecord::record_deallocation(AllocType alloc_type, std::int64_t nbytes) {
    RAPIDSMPF_EXPECTS(
        alloc_type != AllocType::ALL,
        "AllocType::ALL may not be used to record deallocation"
    );
    auto const slot = static_cast<std::size_t>(alloc_type);
    num_current_allocs_[slot] -= 1;
    current_[slot] -= nbytes;
}
/**
 * @brief Merges the statistics of `subscope` into this record.
 *
 * Peaks are combined by adding the subscope's peak to this record's current
 * usage as it was *before* the merge, and keeping the larger of that sum and
 * the existing peak. Counters and byte counts are summed.
 *
 * @param subscope The record whose statistics are folded into this one.
 * @return A reference to this record.
 */
ScopedMemoryRecord& ScopedMemoryRecord::add_subscope(ScopedMemoryRecord const& subscope) {
    // Must run before the loop: current() has to reflect usage prior to the merge.
    auto const stacked_overall_peak = current() + subscope.highest_peak_;
    highest_peak_ = std::max(highest_peak_, stacked_overall_peak);

    for (auto const kind : {AllocType::PRIMARY, AllocType::FALLBACK}) {
        auto const slot = static_cast<std::size_t>(kind);

        // Reads the pre-merge current_[slot], so it stays ahead of the update below.
        auto const stacked_peak = current_[slot] + subscope.peak_[slot];
        peak_[slot] = std::max(peak_[slot], stacked_peak);

        current_[slot] += subscope.current_[slot];
        total_[slot] += subscope.total_[slot];
        num_current_allocs_[slot] += subscope.num_current_allocs_[slot];
        num_total_allocs_[slot] += subscope.num_total_allocs_[slot];
    }
    return *this;
}
/**
 * @brief Merges the statistics of `scope` into this record.
 *
 * Peaks are combined by taking the larger of the two records' peaks; this
 * record's current usage is not added to the incoming peak. Counters and byte
 * counts are summed.
 *
 * @param scope The record whose statistics are folded into this one.
 * @return A reference to this record.
 */
ScopedMemoryRecord& ScopedMemoryRecord::add_scope(ScopedMemoryRecord const& scope) {
    auto const incoming_overall_peak = scope.highest_peak_;
    highest_peak_ = std::max(highest_peak_, incoming_overall_peak);

    for (auto const kind : {AllocType::PRIMARY, AllocType::FALLBACK}) {
        auto const slot = static_cast<std::size_t>(kind);

        auto const incoming_peak = scope.peak_[slot];
        peak_[slot] = std::max(peak_[slot], incoming_peak);

        num_total_allocs_[slot] += scope.num_total_allocs_[slot];
        num_current_allocs_[slot] += scope.num_current_allocs_[slot];
        total_[slot] += scope.total_[slot];
        current_[slot] += scope.current_[slot];
    }
    return *this;
}
} // namespace rapidsmpf