Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions aten/src/ATen/record_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,14 @@ class CallbackManager {
bool found_needs_ids = false;
auto init_handles = [
scope, &found_active_cb, &found_needs_inputs, &found_needs_ids](
CallbackHandles& handles, RecordFunctionCallbacks& cbs) {
CallbackHandles& handles, RecordFunctionCallbacks& cbs, ObserverContextList& ctx_list) {
handles.clear();

size_t num_callbacks = 0;
for (const auto& cb : cbs) {
if (cb.first.shouldRun(scope)) {
handles.push_back(cb.second);
++num_callbacks;
found_active_cb = true;
if (cb.first.needsInputs()) {
found_needs_inputs = true;
Expand All @@ -106,10 +109,12 @@ class CallbackManager {
}
}
}
// Pre-allocate observer context list with nullptr.
ctx_list.resize(num_callbacks);
};

init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_);
init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_);
init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_, rec_fn.tls_ctx_);
init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_, rec_fn.global_ctx_);
rec_fn.active = found_active_cb;
rec_fn.needs_inputs = found_needs_inputs;
if (found_needs_ids && found_active_cb) {
Expand All @@ -121,11 +126,13 @@ class CallbackManager {
mergeRunCallbacks(
sorted_global_callbacks_,
rf.sorted_active_global_handles_,
rf.global_ctx_,
/* is_start */ true,
rf);
mergeRunCallbacks(
sorted_tls_callbacks_,
rf.sorted_active_tls_handles_,
rf.tls_ctx_,
/* is_start */ true,
rf);
rf.called_start_callbacks_ = true;
Expand All @@ -135,21 +142,30 @@ class CallbackManager {
mergeRunCallbacks(
sorted_global_callbacks_,
rf.sorted_active_global_handles_,
rf.global_ctx_,
/* is_start */ false,
rf);
mergeRunCallbacks(
sorted_tls_callbacks_,
rf.sorted_active_tls_handles_,
rf.tls_ctx_,
/* is_start */ false,
rf);
}

private:
bool tryRunCallback(
const std::function<void(const RecordFunction&)>& fn,
RecordFunction& rf) {
const RecordFunctionCallback& rfcb,
RecordFunction& rf,
std::unique_ptr<ObserverContext>& ctx,
bool is_start) {
try {
fn(rf);
if (is_start) {
ctx = rfcb.start()(rf);
}
else {
rfcb.end()(rf, ctx.get());
}
return true;
} catch (const std::exception &e) {
LOG(WARNING) << "Exception in RecordFunction callback: "
Expand All @@ -165,11 +181,12 @@ class CallbackManager {
void mergeRunCallbacks(
const RecordFunctionCallbacks& sorted_callbacks,
const CallbackHandles& sorted_handles,
ObserverContextList& ctx_list,
bool is_start,
RecordFunction& rf) {
size_t num_executed = 0;
size_t idx_c = 0;
for (size_t idx_h = 0; idx_h < sorted_handles.size(); ++idx_h) {
for (size_t idx_h = 0; idx_h < sorted_handles.size() && idx_h < ctx_list.size(); ++idx_h) {
while (idx_c < sorted_callbacks.size() &&
sorted_callbacks[idx_c].second < sorted_handles[idx_h]) {
++idx_c;
Expand All @@ -178,11 +195,7 @@ class CallbackManager {
break;
}
if (sorted_callbacks[idx_c].second == sorted_handles[idx_h]) {
if (is_start) {
tryRunCallback(sorted_callbacks[idx_c].first.start(), rf);
} else {
tryRunCallback(sorted_callbacks[idx_c].first.end(), rf);
}
tryRunCallback(sorted_callbacks[idx_c].first, rf, ctx_list[idx_h], is_start);
++num_executed;
}
}
Expand Down
45 changes: 39 additions & 6 deletions aten/src/ATen/record_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,16 @@ struct TORCH_API StringView {
// Soft limit on the number of callbacks to use.
constexpr std::size_t kSoftLimitCallbacks = 4;

// Base type for the observer contexts that can be attached to a
// RecordFunction. It holds no data itself; concrete observers derive from it.
struct ObserverContext {
  // Virtual so that a derived context is destroyed correctly when it is
  // deleted through a pointer to this base (e.g. std::unique_ptr<ObserverContext>).
  virtual ~ObserverContext() = default;

 protected:
  // Protected: the base can only be constructed as part of a derived type.
  ObserverContext() = default;
};

// List of callback handles (uint64_t values); its SmallVector size parameter
// is kSoftLimitCallbacks.
using CallbackHandles = c10::SmallVector<uint64_t, kSoftLimitCallbacks>;
// Owning list of observer contexts; individual entries may be null.
using ObserverContextList = std::vector<std::unique_ptr<ObserverContext>>;
// Integer handle type for a RecordFunction.
using RecordFunctionHandle = uint64_t;

struct TORCH_API RecordFunction {
Expand Down Expand Up @@ -164,6 +173,15 @@ struct TORCH_API RecordFunction {
// public because of anonymous "friend" class
CallbackHandles sorted_active_tls_handles_;
CallbackHandles sorted_active_global_handles_;

// Stores various ObserverContext objects with event metadata for thread local
// callbacks.
ObserverContextList tls_ctx_;

// Stores various ObserverContext objects with event metadata for global
// callbacks.
ObserverContextList global_ctx_;

// Whether this RecordFunction runs any callbacks
bool active = false;
/// Whether any of the picked callbacks require inputs
Expand Down Expand Up @@ -198,6 +216,8 @@ struct TORCH_API RecordFunction {
* RecordFunctionCallback represents a pair of callbacks to be used with
* RecordFunction, members:
* start, end - the callbacks to run when entering and exiting the scope;
* optionally, the start callback may return an ObserverContext, which is then
* passed to the end callback; use the matching constructor overload for this.
* needs_inputs - whether the callbacks need the inputs passed from the observed
* function/range; NOTE: passing the inputs incurs an additional overhead;
* sampling_probability - if not 1.0, then the callback is probabilistically sampled
Expand All @@ -211,12 +231,25 @@ struct TORCH_API RecordFunction {
*/
class TORCH_API RecordFunctionCallback {
public:
// This interface supports observers that require passing an ObserverContext
// between start and end callbacks.
explicit RecordFunctionCallback(
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
std::function<void(const RecordFunction&, ObserverContext*)> end =
[](const RecordFunction&, ObserverContext*) {}):
start_(std::move(start)),
end_(std::move(end)) {
scopes_.fill(true);
}

// This interface is for observers that do not pass an ObserverContext object
// between start and end callbacks.
explicit RecordFunctionCallback(
std::function<void(const RecordFunction&)> start,
std::function<void(const RecordFunction&)> end =
[](const RecordFunction&) {}):
start_(std::move(start)),
end_(std::move(end)) {
start_{[start](const RecordFunction& rf) { start(rf); return nullptr; }},
end_{[end](const RecordFunction& rf, ObserverContext*) { end(rf); }} {
scopes_.fill(true);
}

Expand Down Expand Up @@ -272,20 +305,20 @@ class TORCH_API RecordFunctionCallback {
return scopes_[(size_t)sc];
}

inline const std::function<void(const RecordFunction&)>& start() const {
inline const std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)>& start() const {
return start_;
}

inline const std::function<void(const RecordFunction&)>& end() const {
inline const std::function<void(const RecordFunction&, ObserverContext*)>& end() const {
return end_;
}

// whether the callbacks should run in the given scope
bool shouldRun(RecordScope scope) const;

private:
std::function<void(const RecordFunction&)> start_;
std::function<void(const RecordFunction&)> end_;
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start_;
std::function<void(const RecordFunction&, ObserverContext*)> end_;
std::function<bool(const RecordFunctionCallback&)> should_run_;
bool needs_inputs_ = false;
bool needs_ids_ = false;
Expand Down
3 changes: 3 additions & 0 deletions c10/macros/Macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,9 @@ __host__ __device__
#endif // ANDROID / IOS

// Portably determine if a type T is trivially copyable or not.
// Warning: on GCC < 5 the __has_trivial_copy fallback may not always detect
// non-POD types correctly. For example, T = std::unique_ptr may evaluate to
// true and be treated as POD, which can cause unexpected behavior.
#if defined(__GNUG__) && __GNUC__ < 5
#define C10_IS_TRIVIALLY_COPYABLE(T) __has_trivial_copy(T)
#else
Expand Down
3 changes: 3 additions & 0 deletions c10/util/SmallVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,9 @@ class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> {

/// This class consists of common code factored out of the SmallVector class to
/// reduce code duplication based on the SmallVector 'N' template parameter.
/// Warning: C10_IS_TRIVIALLY_COPYABLE may not always detect non-POD types
/// correctly. For example, std::unique_ptr may be treated as POD, which can
/// cause memory leaks.
template <typename T>
class SmallVectorImpl
: public SmallVectorTemplateBase<T, C10_IS_TRIVIALLY_COPYABLE(T)> {
Expand Down