Skip to content

Commit 78bce48

Browse files
louisfengfacebook-github-bot
authored andcommitted
DPP Async Tracing
Summary: Add tracing to DPP client. Because DPP requests are async, we need to be able to start a trace event in one thread and potentially end in a different thread. RecordFunction and LibgpumonObserver previously assume each trace event starts and finishes in the same thread. So they use a thread local context to track enter and exit call backs. Async events breaks this assumption. This change attaches the event context to the RecordFunction object so we do not need to use thread local context. Test Plan: Tested with dpp perf test and able to collect trace. {F307824044} Differential Revision: D23323486 fbshipit-source-id: 7700fddee73c0e6e89d2e8b3a863cf56316999d6
1 parent 15a7368 commit 78bce48

File tree

4 files changed

+70
-18
lines changed

4 files changed

+70
-18
lines changed

aten/src/ATen/record_function.cpp

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,14 @@ class CallbackManager {
9292
bool found_needs_ids = false;
9393
auto init_handles = [
9494
scope, &found_active_cb, &found_needs_inputs, &found_needs_ids](
95-
CallbackHandles& handles, RecordFunctionCallbacks& cbs) {
95+
CallbackHandles& handles, RecordFunctionCallbacks& cbs, ObserverContextList& ctx_list) {
9696
handles.clear();
97+
98+
size_t num_callbacks = 0;
9799
for (const auto& cb : cbs) {
98100
if (cb.first.shouldRun(scope)) {
99101
handles.push_back(cb.second);
102+
++num_callbacks;
100103
found_active_cb = true;
101104
if (cb.first.needsInputs()) {
102105
found_needs_inputs = true;
@@ -106,10 +109,12 @@ class CallbackManager {
106109
}
107110
}
108111
}
112+
// Pre-allocate observer context list with nullptr.
113+
ctx_list.resize(num_callbacks);
109114
};
110115

111-
init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_);
112-
init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_);
116+
init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_, rec_fn.tls_ctx_);
117+
init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_, rec_fn.global_ctx_);
113118
rec_fn.active = found_active_cb;
114119
rec_fn.needs_inputs = found_needs_inputs;
115120
if (found_needs_ids && found_active_cb) {
@@ -121,11 +126,13 @@ class CallbackManager {
121126
mergeRunCallbacks(
122127
sorted_global_callbacks_,
123128
rf.sorted_active_global_handles_,
129+
rf.global_ctx_,
124130
/* is_start */ true,
125131
rf);
126132
mergeRunCallbacks(
127133
sorted_tls_callbacks_,
128134
rf.sorted_active_tls_handles_,
135+
rf.tls_ctx_,
129136
/* is_start */ true,
130137
rf);
131138
rf.called_start_callbacks_ = true;
@@ -135,21 +142,30 @@ class CallbackManager {
135142
mergeRunCallbacks(
136143
sorted_global_callbacks_,
137144
rf.sorted_active_global_handles_,
145+
rf.global_ctx_,
138146
/* is_start */ false,
139147
rf);
140148
mergeRunCallbacks(
141149
sorted_tls_callbacks_,
142150
rf.sorted_active_tls_handles_,
151+
rf.tls_ctx_,
143152
/* is_start */ false,
144153
rf);
145154
}
146155

147156
private:
148157
bool tryRunCallback(
149-
const std::function<void(const RecordFunction&)>& fn,
150-
RecordFunction& rf) {
158+
const RecordFunctionCallback& rfcb,
159+
RecordFunction& rf,
160+
std::unique_ptr<ObserverContext>& ctx,
161+
bool is_start) {
151162
try {
152-
fn(rf);
163+
if (is_start) {
164+
ctx = rfcb.start()(rf);
165+
}
166+
else {
167+
rfcb.end()(rf, ctx.get());
168+
}
153169
return true;
154170
} catch (const std::exception &e) {
155171
LOG(WARNING) << "Exception in RecordFunction callback: "
@@ -165,11 +181,12 @@ class CallbackManager {
165181
void mergeRunCallbacks(
166182
const RecordFunctionCallbacks& sorted_callbacks,
167183
const CallbackHandles& sorted_handles,
184+
ObserverContextList& ctx_list,
168185
bool is_start,
169186
RecordFunction& rf) {
170187
size_t num_executed = 0;
171188
size_t idx_c = 0;
172-
for (size_t idx_h = 0; idx_h < sorted_handles.size(); ++idx_h) {
189+
for (size_t idx_h = 0; idx_h < sorted_handles.size() && idx_h < ctx_list.size(); ++idx_h) {
173190
while (idx_c < sorted_callbacks.size() &&
174191
sorted_callbacks[idx_c].second < sorted_handles[idx_h]) {
175192
++idx_c;
@@ -178,11 +195,7 @@ class CallbackManager {
178195
break;
179196
}
180197
if (sorted_callbacks[idx_c].second == sorted_handles[idx_h]) {
181-
if (is_start) {
182-
tryRunCallback(sorted_callbacks[idx_c].first.start(), rf);
183-
} else {
184-
tryRunCallback(sorted_callbacks[idx_c].first.end(), rf);
185-
}
198+
tryRunCallback(sorted_callbacks[idx_c].first, rf, ctx_list[idx_h], is_start);
186199
++num_executed;
187200
}
188201
}

aten/src/ATen/record_function.h

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,16 @@ struct TORCH_API StringView {
6767
// Soft limit on the number of callbacks to use;
6868
constexpr std::size_t kSoftLimitCallbacks = 4;
6969

70+
// An abstract base class for various observer contexts that can be attached to
71+
// the RecordFunction.
72+
struct ObserverContext {
73+
virtual ~ObserverContext() {}
74+
protected:
75+
ObserverContext() {}
76+
};
77+
7078
typedef c10::SmallVector<uint64_t, kSoftLimitCallbacks> CallbackHandles;
79+
typedef std::vector<std::unique_ptr<ObserverContext>> ObserverContextList;
7180
typedef uint64_t RecordFunctionHandle;
7281

7382
struct TORCH_API RecordFunction {
@@ -164,6 +173,15 @@ struct TORCH_API RecordFunction {
164173
// public because of anonymous "friend" class
165174
CallbackHandles sorted_active_tls_handles_;
166175
CallbackHandles sorted_active_global_handles_;
176+
177+
// Stores various ObserverContext objects with event metadata for thread local
178+
// callbacks.
179+
ObserverContextList tls_ctx_;
180+
181+
// Stores various ObserverContext objects with event metadata for global
182+
// callbacks.
183+
ObserverContextList global_ctx_;
184+
167185
// Whether this RecordFunction runs any callbacks
168186
bool active = false;
169187
/// Whether any of the picked callbacks require inputs
@@ -198,6 +216,8 @@ struct TORCH_API RecordFunction {
198216
* RecordFunctionCallback represents a pair of callbacks to be used with
199217
* RecordFunction, members:
200218
* start, end - the callbacks to run when entering and exiting the scope;
219+
* optionally, the start callback may return an ObserverContext which will
220+
* be passed to the end callback, use appropriate constructor accordingly.
201221
* needs_inputs - whether the callbacks need the inputs passed from the observed
202222
* function/range; NOTE: passing the inputs incurs an additional overhead;
203223
* sampling_probability - if not 1.0, then the callback is probabilistically sampled
@@ -211,12 +231,25 @@ struct TORCH_API RecordFunction {
211231
*/
212232
class TORCH_API RecordFunctionCallback {
213233
public:
234+
// This interface supports observers that require passing an ObserverContext
235+
// between start and end callbacks.
236+
explicit RecordFunctionCallback(
237+
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
238+
std::function<void(const RecordFunction&, ObserverContext*)> end =
239+
[](const RecordFunction&, ObserverContext*) {}):
240+
start_(std::move(start)),
241+
end_(std::move(end)) {
242+
scopes_.fill(true);
243+
}
244+
245+
// This interface is for observers that do not pass an ObserverContext object
246+
// between start and end callbacks.
214247
explicit RecordFunctionCallback(
215248
std::function<void(const RecordFunction&)> start,
216249
std::function<void(const RecordFunction&)> end =
217250
[](const RecordFunction&) {}):
218-
start_(std::move(start)),
219-
end_(std::move(end)) {
251+
start_{[start](const RecordFunction& rf) { start(rf); return nullptr; }},
252+
end_{[end](const RecordFunction& rf, ObserverContext*) { end(rf); }} {
220253
scopes_.fill(true);
221254
}
222255

@@ -272,20 +305,20 @@ class TORCH_API RecordFunctionCallback {
272305
return scopes_[(size_t)sc];
273306
}
274307

275-
inline const std::function<void(const RecordFunction&)>& start() const {
308+
inline const std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)>& start() const {
276309
return start_;
277310
}
278311

279-
inline const std::function<void(const RecordFunction&)>& end() const {
312+
inline const std::function<void(const RecordFunction&, ObserverContext*)>& end() const {
280313
return end_;
281314
}
282315

283316
// whether the callbacks should run in the given scope
284317
bool shouldRun(RecordScope scope) const;
285318

286319
private:
287-
std::function<void(const RecordFunction&)> start_;
288-
std::function<void(const RecordFunction&)> end_;
320+
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start_;
321+
std::function<void(const RecordFunction&, ObserverContext*)> end_;
289322
std::function<bool(const RecordFunctionCallback&)> should_run_;
290323
bool needs_inputs_ = false;
291324
bool needs_ids_ = false;

c10/macros/Macros.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,9 @@ __host__ __device__
294294
#endif // ANDROID / IOS
295295

296296
// Portably determine if a type T is trivially copyable or not.
297+
// Warning: __has_trivial_copy for GCC may not always detect the non-POD
298+
// correctly. For example, T = std::unique_ptr may evaluate to true and be
299+
// treated as POD. This can cause unexpected behavior.
297300
#if defined(__GNUG__) && __GNUC__ < 5
298301
#define C10_IS_TRIVIALLY_COPYABLE(T) __has_trivial_copy(T)
299302
#else

c10/util/SmallVector.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,9 @@ class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> {
378378

379379
/// This class consists of common code factored out of the SmallVector class to
380380
/// reduce code duplication based on the SmallVector 'N' template parameter.
381+
/// Warning: C10_IS_TRIVIALLY_COPYABLE may not always detect non-POD
382+
/// type correctly. For example, std::unique_ptr may be treated as POD and cause
383+
/// memory leaks.
381384
template <typename T>
382385
class SmallVectorImpl
383386
: public SmallVectorTemplateBase<T, C10_IS_TRIVIALLY_COPYABLE(T)> {

0 commit comments

Comments
 (0)