Skip to content

Commit 09f2c6a

Browse files
louisfengfacebook-github-bot
authored andcommitted
Back out "Revert D23494065: Refactor CallbackManager as a friend class of RecordFunction." (#44699)
Summary: Pull Request resolved: #44699 Original commit changeset: 3b1ec928e3db Previous revert (D23698861) was on the wrong diff stack. Backing out the revert. Test Plan: Passed unit tests and previously landed. Reviewed By: mruberry Differential Revision: D23702258 fbshipit-source-id: 5c3e197bca412f454db5a7e86251ec85faf621c1
1 parent 174cbff commit 09f2c6a

File tree

2 files changed

+43
-38
lines changed

2 files changed

+43
-38
lines changed

aten/src/ATen/record_function.cpp

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,34 @@ thread_local RecordFunctionCallbacks sorted_tls_callbacks_;
2424

2525
std::atomic<int64_t> defaultNodeId(-1);
2626

27+
// Enumerates thread ids logically;
28+
// note: std::this_thread::get_id may return potentially
29+
// reused thread id
30+
std::atomic<uint64_t> next_thread_id_ {0};
31+
thread_local uint64_t current_thread_id_ = 0;
32+
33+
thread_local bool tls_record_function_enabled_ = true;
34+
35+
// Low probability constant
36+
const double kLowProb = 0.001;
37+
thread_local int tries_left_ = 0;
38+
39+
int sample_geometric() {
40+
static thread_local auto gen =
41+
std::make_unique<std::mt19937>(std::random_device()());
42+
std::geometric_distribution<int> dist(kLowProb);
43+
return dist(*gen);
44+
}
45+
46+
double sample_zero_one() {
47+
static thread_local auto gen =
48+
std::make_unique<std::mt19937>(std::random_device()());
49+
std::uniform_real_distribution<double> dist(0.0, 1.0);
50+
return dist(*gen);
51+
}
52+
53+
} // namespace
54+
2755
class CallbackManager {
2856
public:
2957
CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb) {
@@ -212,37 +240,12 @@ class CallbackManager {
212240
RecordFunctionCallbacks sorted_global_callbacks_;
213241
};
214242

215-
// Enumerates thread ids logically;
216-
// note: std::this_thread::get_id may return potentially
217-
// reused thread id
218-
std::atomic<uint64_t> next_thread_id_ {0};
219-
thread_local uint64_t current_thread_id_ = 0;
220-
221-
inline CallbackManager& manager() {
222-
static CallbackManager _manager;
223-
return _manager;
224-
}
225-
226-
thread_local bool tls_record_function_enabled_ = true;
227-
228-
// Low probability constant
229-
const double kLowProb = 0.001;
230-
thread_local int tries_left_ = 0;
231-
232-
int sample_geometric() {
233-
static thread_local auto gen =
234-
std::make_unique<std::mt19937>(std::random_device()());
235-
std::geometric_distribution<int> dist(kLowProb);
236-
return dist(*gen);
237-
}
238-
239-
double sample_zero_one() {
240-
static thread_local auto gen =
241-
std::make_unique<std::mt19937>(std::random_device()());
242-
std::uniform_real_distribution<double> dist(0.0, 1.0);
243-
return dist(*gen);
244-
}
245-
243+
namespace {
244+
// Keeping this static manager local.
245+
CallbackManager& manager() {
246+
static CallbackManager _manager;
247+
return _manager;
248+
}
246249
} // namespace
247250

248251
bool RecordFunctionCallback::shouldRun(RecordScope scope) const {

aten/src/ATen/record_function.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,17 @@ struct TORCH_API RecordFunction {
168168
handle_ = handle;
169169
}
170170

171+
// Whether this RecordFunction runs any callbacks
172+
bool active = false;
173+
// Whether any of the picked callbacks require inputs
174+
bool needs_inputs = false;
175+
176+
private:
177+
// Allows the modification of some internal states for callbacks.
178+
friend class CallbackManager;
179+
171180
// Used internally to keep track of thread local and global callbacks
172181
// that were picked to run; must be sorted;
173-
// public because of anonymous "friend" class
174182
CallbackHandles sorted_active_tls_handles_;
175183
CallbackHandles sorted_active_global_handles_;
176184

@@ -182,17 +190,11 @@ struct TORCH_API RecordFunction {
182190
// callbacks.
183191
ObserverContextList global_ctx_;
184192

185-
// Whether this RecordFunction runs any callbacks
186-
bool active = false;
187-
/// Whether any of the picked callbacks require inputs
188-
bool needs_inputs = false;
189-
190193
// In cases when RecordFunction might be active but we chose not to
191194
// use the observers (e.g. operator is not observed), this boolean
192195
// flag is used to check whether the start callbacks were called
193196
bool called_start_callbacks_ = false;
194197

195-
private:
196198
StringView name_;
197199
int64_t sequence_nr_ = -1;
198200
std::vector<c10::IValue> inputs_;

0 commit comments

Comments
 (0)