Prioritize reentrant tasks and execute them recursively until close to limit #22397
Changes from all commits
3e2bbb5
6d8a960
a7079cc
1bcd8b3
59d281d
4ea502a
6d3ab30
00e9b83
b6e118d
affc526
@@ -51,6 +51,11 @@ static thread_local bool checkpoint_valid = true;
 // engine thread affinity to the device can break this invariant, and we depend
 // on it in a few places (e.g. AccumulateGrad function).

+// Number of nested reentrant backwards calls currently on this thread
+static thread_local int current_depth = 0;
+// Total nested reentrant backwards calls over all threads for worker_device
+static thread_local int total_depth = 0;
+
 struct FunctionTask {
   GraphTask* base_;
   std::shared_ptr<Function> fn_;

@@ -62,6 +67,8 @@ struct FunctionTask {
   // exit. The engine sends a shutdown task to every queue upon its destruction.
   bool isShutdownTask_;

+  int getReentrantDepth() const;
+
   FunctionTask(GraphTask* base, std::shared_ptr<Function> fn, InputBuffer inputs, bool isShutdownTask = false)
     : base_(base)
     , fn_(std::move(fn))

@@ -79,8 +86,10 @@ struct CompareFunctionTaskTime {
       return false;
     } else if (!t2.fn_) {
       return true;
-    } else {
+    } else if (t1.getReentrantDepth() == t2.getReentrantDepth()) {
       return t1.fn_->sequence_nr() < t2.fn_->sequence_nr();
+    } else {
+      return t1.getReentrantDepth() < t2.getReentrantDepth();
     }
   }
 };

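The new comparator means the ready queue orders work primarily by reentrant depth and only falls back to sequence number for ties, so tasks belonging to deeper reentrant backward calls are popped first. A minimal Python model of that ordering, assuming a made-up `Task` type and using `heapq` in place of the engine's `std::priority_queue` (negated keys turn the min-heap into the required max-heap):

```python
import heapq
from dataclasses import dataclass

@dataclass
class Task:
    reentrant_depth: int  # depth of the GraphTask that owns this function
    sequence_nr: int      # creation order of the autograd Function
    name: str

def push(heap, task):
    # Deepest reentrant depth wins; ties are broken by highest sequence number,
    # mirroring CompareFunctionTaskTime above. heapq is a min-heap, so negate.
    heapq.heappush(heap, (-task.reentrant_depth, -task.sequence_nr, task.name))

def pop(heap):
    return heapq.heappop(heap)[2]

heap = []
for t in (Task(0, 5, "outer-a"), Task(1, 2, "inner"), Task(0, 7, "outer-b")):
    push(heap, t)

print([pop(heap) for _ in range(3)])  # ['inner', 'outer-b', 'outer-a']
```
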
@@ -174,21 +183,28 @@ struct GraphTask {

   // The value of worker_device in the thread that created this task.
   // See Note [Reentrant backwards]
-  // Safe to read owner_ without synchronization
+  // Safe to read owner_ and reentrant_depth_ without synchronization
   int owner_;
+  // The number of parent graph tasks for this graph task
+  const int reentrant_depth_;

   bool can_checkpoint() {
     return exec_info_.empty();
   }

-  GraphTask(bool keep_graph, bool grad_mode)
+  GraphTask(bool keep_graph, bool grad_mode, int reentrant_depth)
     : has_error_(false)
     , outstanding_tasks_(0)
     , keep_graph_(keep_graph)
     , grad_mode_(grad_mode)
-    , owner_(NO_DEVICE) {}
+    , owner_(NO_DEVICE)
+    , reentrant_depth_(reentrant_depth) {}
 };

+int FunctionTask::getReentrantDepth() const {
+  return base_->reentrant_depth_;
+}
+
 auto ReadyQueue::push(FunctionTask item) -> void {
   {
     // Lock mutex for writing to heap_

@@ -216,7 +232,8 @@ auto ReadyQueue::pop() -> FunctionTask {
   return task;
 }

-Engine::Engine() = default;
+// This limit is based on the default python recursion limit which is 1000
+Engine::Engine() : max_recursion_depth_(100) {}

 // Send shutdown tasks to all ReadyQueues if no backward tasks are running
 // Even though readyQueue should be empty, shutdown tasks have the highest

Contributor: The comment does not match the code (1000 vs. 100).

Author: We need to set it lower than the actual Python limit to take into account the function calls within Python.

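On the 1000-vs-100 point above: each reentrant backward that is executed recursively on the worker thread also consumes several Python stack frames, because the user-defined backward is arbitrary Python, so the engine's cap has to sit well below `sys.getrecursionlimit()`. As a reminder of what a reentrant call looks like from the Python side, here is a hedged sketch (the `Reentrant` class and tensor shapes are illustrative, not taken from this PR's tests):

```python
import sys
import torch

class Reentrant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Stash a fresh leaf so backward() has its own small graph to run.
        ctx.saved = x.detach().requires_grad_(True)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Calling backward() here re-enters the engine: this is the case the
        # current_depth / total_depth counters and max_recursion_depth_ govern.
        with torch.enable_grad():
            (ctx.saved * 2).sum().backward()
        return grad_output

print(sys.getrecursionlimit())  # 1000 by default
out = Reentrant.apply(torch.ones(3, requires_grad=True))
out.sum().backward()            # triggers one nested (reentrant) backward
```
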
@@ -342,7 +359,13 @@ auto Engine::thread_main(GraphTask *graph_task) -> void {
     }
   }

-  if (graph_task) {
+  // When current_depth is 0 this worker thread is done and we need to notify
+  // the parent thread waiting on the graph_task
+  // NOTE: An edge case for this is when reentrant calls are repeatedly made in
+  // a thread which is at its maximum stack depth and they keep exiting right
+  // after. We will always switch to a new thread for each call, so, we'll keep
+  // oscillating between the two threads.
+  if (graph_task && current_depth == 0) {
     graph_task->not_done_.notify_all();
   }
 }

@@ -359,6 +382,7 @@ void Engine::reentrant_thread_init() {
       tp_shared->graphtasks_queue_.pop();
       lk.unlock();
       set_device(graph_task->owner_);
+      total_depth = graph_task->reentrant_depth_;
       thread_main(graph_task);
     }
   }

@@ -632,7 +656,7 @@ auto Engine::execute(const edge_list& roots,
   // Lock post_callbacks_lock_ before clearing final_callbacks_
   ClearCallbacks _cb_guard(final_callbacks_, post_callbacks_lock_);

-  GraphTask graph_task(keep_graph, create_graph);
+  GraphTask graph_task(keep_graph, create_graph, worker_device == NO_DEVICE ? 0 : total_depth+1);
   // Lock mutex while GraphTask is being set up
   std::unique_lock<std::mutex> lock(graph_task.mutex_);

@@ -651,15 +675,24 @@ auto Engine::execute(const edge_list& roots,
       return graph_task.outstanding_tasks_.load() == 0;
     });
   } else {
-    // Get back to work while we wait for our new graph_task to
-    // complete!
-    // See Note [Reentrant backwards]
-    // If no extra threads remaining, create a new thread for reentrant call
     graph_task.owner_ = worker_device;
-    add_thread_pool_task(&graph_task);
-    graph_task.not_done_.wait(lock, [&graph_task]{
-      return graph_task.outstanding_tasks_.load() == 0;
-    });
+    ++total_depth;
+    if(current_depth >= max_recursion_depth_){
+      // See Note [Reentrant backwards]
+      // If reached the max depth, switch to a different thread
+      add_thread_pool_task(&graph_task);
+      graph_task.not_done_.wait(lock, [&graph_task]{
+        return graph_task.outstanding_tasks_.load() == 0;
+      });
+    } else {
+      // Get back to work while we wait for our new graph_task to
+      // complete!
+      ++current_depth;
+      lock.unlock();
+      thread_main(&graph_task);
+      --current_depth;
+    }
+    --total_depth;
   }

   // Check for an exception while running backwards

Contributor: What will happen if we exceed thread_pool_shared_->graphtasks_queue_.size() here?

Author: std::queue should resize dynamically, or am I misunderstanding your question?

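The execute() change above boils down to a reusable pattern: keep running nested work recursively on the current thread while the per-thread depth is below the cap, and once the cap is hit, hand the work to another thread and block until it finishes. A self-contained sketch of just that control flow (names like `run_task` and `MAX_DEPTH` are made up; this is not the engine code, and it spawns a thread per handoff rather than reusing a pool):

```python
import threading

MAX_DEPTH = 100               # analogue of max_recursion_depth_
_local = threading.local()    # per-thread nesting counter, like current_depth

def _depth():
    return getattr(_local, "depth", 0)

def run_task(task):
    """Run a (possibly re-entrant) task either inline or on a fresh thread."""
    if _depth() < MAX_DEPTH:
        # Shallow enough: recurse on this thread, like calling thread_main().
        _local.depth = _depth() + 1
        try:
            task()
        finally:
            _local.depth -= 1
    else:
        # Too deep: hand off and wait, like add_thread_pool_task() followed
        # by not_done_.wait().
        done = threading.Event()
        def worker():
            try:
                task()        # the new thread starts again at depth 0
            finally:
                done.set()
        threading.Thread(target=worker).start()
        done.wait()

def nested(n):
    # A task that re-enters run_task, standing in for a reentrant backward.
    if n > 0:
        run_task(lambda: nested(n - 1))

run_task(lambda: nested(500))  # nests past MAX_DEPTH without blowing the stack
```
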
Reviewer: This is a pretty funky forward function. (It's funky because we don't normally do autograd operations inside a forward function.) It seems like it's doing two things: you want to return `x` (forward is just the identity), but you also want to create a leaf variable on the context with some non-trivial autograd history. Is there a reason `x` has to be used in both cases? I'll keep reading and see if I can figure out why you create leaf variables in forward ;)

Reviewer: Oh, I see: you're also using `ctx.x` to keep track of how many times you recurse.
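For readers without the test file in front of them: the function being discussed is a Python test helper that is not shown in this diff. Below is a purely hypothetical reconstruction, based only on the comments above, of a forward that returns `x` unchanged while also stashing a fresh leaf on `ctx` whose value doubles as the remaining recursion count; every name and number here is a guess, not the PR's actual test code:

```python
import torch

class DeepReentrant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Identity on the outside, but also create a new leaf on the context;
        # its value doubles as "how many more times to recurse" (the two jobs
        # the review comment is pointing at).
        ctx.x = x.detach().requires_grad_(True)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.x.item() > 0:
            with torch.enable_grad():
                # Each nested backward() is a reentrant call into the engine.
                DeepReentrant.apply(ctx.x - 1).sum().backward()
        return grad_output

# Roughly 200 nested backward calls: deeper than max_recursion_depth_, so the
# engine has to mix recursive execution with handoffs to pool threads.
DeepReentrant.apply(torch.tensor(200.0, requires_grad=True)).sum().backward()
```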