
Commit 0140a75

mal authored and facebook-github-bot committed
Prioritize reentrant tasks and execute them recursively until close to limit
Summary:
Pull Request resolved: #22397

Test Plan: Added a test for reentrant backwards with checkpoint, a test for a recursive backwards function (which would fail if we ran all the reentrant tasks recursively in the same thread), and a test for the priority of reentrant tasks. ~~Will add a test for priority of reentrant tasks in a future PR.~~

Imported from OSS

Differential Revision: D16131955

fbshipit-source-id: 18301d45c1ec9fbeb566b1016dbaf7a84a09c7ac
1 parent e5d6403 commit 0140a75
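
A backward call is "reentrant" when backward() runs while another backward pass is already executing, which is exactly what torch.utils.checkpoint does: its backward re-runs the forward and differentiates the recomputed subgraph inside the outer pass. A minimal sketch of how such a nested call arises (illustrative usage, not part of this commit; the module and sizes are arbitrary):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
# checkpoint() drops intermediate activations in forward; during the outer
# backward it re-runs forward and calls backward on the recomputed subgraph,
# handing the autograd engine a nested (reentrant) graph task.
y = checkpoint(layer, x)
y.sum().backward()

Previously every such nested task was pushed to the thread pool (see the removed branch in Engine::execute below); with this change the current thread services it recursively, switching threads only once its nesting depth nears the limit.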

File tree

3 files changed: +146 -16 lines changed

test/test_autograd.py
torch/csrc/autograd/engine.cpp
torch/csrc/autograd/engine.h

test/test_autograd.py

Lines changed: 95 additions & 1 deletion
@@ -12,13 +12,14 @@
 from itertools import product
 from operator import mul
 from functools import reduce
+from torch import nn
 from torch._six import inf, nan, istuple
 from torch.autograd.gradcheck import gradgradcheck, gradcheck
 from torch.autograd.function import once_differentiable
 from torch.autograd.profiler import profile, format_time, EventList, FunctionEvent, emit_nvtx
 from torch.utils.checkpoint import checkpoint
 from common_utils import (TEST_MKL, TestCase, run_tests, skipIfNoLapack,
-                          suppress_warnings, skipIfRocm,
+                          suppress_warnings, skipIfRocm, slowTest,
                           load_tests, random_symmetric_pd_matrix, IS_WINDOWS)
 from common_cuda import TEST_CUDA
 from torch.autograd import Variable, Function, detect_anomaly
@@ -3195,6 +3196,99 @@ def backward(ctx, grad):
         s = TestCase.runWithPytorchAPIUsageStderr(code)
         self.assertRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")

+    def test_deep_reentrant(self):
+
+        class DeepReentrant(Function):
+            @staticmethod
+            def forward(ctx, x):
+                with torch.enable_grad():
+                    ctx.x = Variable(x.data, requires_grad=True)
+                    ctx.x = ctx.x - 1
+                return ctx.x.detach()
+
+            @staticmethod
+            def backward(ctx, x):
+                if ctx.x < 0:
+                    return x
+                with torch.enable_grad():
+                    DeepReentrant.apply(ctx.x).sum().backward()
+                return x
+
+        v = torch.tensor(2000.0, requires_grad=True)
+        # This will cause stack overflow if reentrant calls are handled
+        # in the same thread recursively
+        DeepReentrant.apply(v).sum().backward()
+
+    def test_reentrant_priority(self):
+        order = []
+
+        class MyFunction(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x
+
+            @staticmethod
+            def backward(ctx, x):
+                order.append("MyFunction")
+                return x
+
+        class Reentrant(Function):
+            @staticmethod
+            def forward(ctx, x):
+                with torch.enable_grad():
+                    ctx.x = Variable(x.data, requires_grad=True)
+                    ctx.x = ctx.x - 1
+                return ctx.x.detach()
+
+            @staticmethod
+            def backward(ctx, x):
+                order.append("Reentrant")
+                if ctx.x < 0:
+                    return x
+                with torch.enable_grad():
+                    Reentrant.apply(ctx.x).backward()
+                return x
+
+        a = MyFunction.apply(torch.tensor(6.0, requires_grad=True))
+        b = Reentrant.apply(torch.tensor(9.0, requires_grad=True))
+        v = a * b
+        v.backward()
+        # The tasks for the Reentrant and MyFunction backward() will be added
+        # to the queue in the autograd engine at the same time. The backward
+        # for Reentrant will be executed first, which will then add other
+        # backward tasks to the queue. We want to ensure all the reentrant tasks
+        # are prioritized over the MyFunction backward task regardless of their
+        # sequence numbers
+        self.assertEqual(len(order), 11)
+        self.assertEqual(order.count("Reentrant"), 10)
+        self.assertEqual(order[-1], "MyFunction")
+
+    @slowTest
+    def test_checkpointing(self):
+        num_inp = 2000
+        nz_inp = 10
+        nz_out = 10
+        nz_bottleneck = 1000
+
+        # small proxy network for some complex reasoning we want to do per input
+        module = nn.Sequential(
+            nn.Linear(nz_inp, nz_bottleneck),
+            nn.ReLU(),
+            nn.Linear(nz_bottleneck, nz_inp)
+        )
+
+        feat_combined = []
+        for r in range(num_inp):
+            data_r = torch.Tensor(1, nz_inp)
+            data_r.uniform_()
+            data_r.requires_grad = True
+            feat_r = checkpoint(module, data_r)
+            feat_combined.append(feat_r)
+
+        # compute mean as a proxy for some joint reasoning
+        mean_combined = torch.stack(feat_combined).mean()
+        mean_combined.backward()
+
 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):
         shape = (shape,)

torch/csrc/autograd/engine.cpp

Lines changed: 48 additions & 15 deletions
@@ -51,6 +51,11 @@ static thread_local bool checkpoint_valid = true;
 // engine thread affinity to the device can break this invariant, and we depend
 // on it in a few places (e.g. AccumulateGrad function).

+// Number of nested reentrant backwards calls currently on this thread
+static thread_local int current_depth = 0;
+// Total nested reentrant backwards calls over all threads for worker_device
+static thread_local int total_depth = 0;
+
 struct FunctionTask {
   GraphTask* base_;
   std::shared_ptr<Function> fn_;
@@ -62,6 +67,8 @@ struct FunctionTask {
   // exit. The engine sends a shutdown task to every queue upon its destruction.
   bool isShutdownTask_;

+  int getReentrantDepth() const;
+
   FunctionTask(GraphTask* base, std::shared_ptr<Function> fn, InputBuffer inputs, bool isShutdownTask = false)
     : base_(base)
     , fn_(std::move(fn))
@@ -79,8 +86,10 @@ struct CompareFunctionTaskTime {
       return false;
     } else if (!t2.fn_) {
       return true;
-    } else {
+    } else if (t1.getReentrantDepth() == t2.getReentrantDepth()) {
       return t1.fn_->sequence_nr() < t2.fn_->sequence_nr();
+    } else {
+      return t1.getReentrantDepth() < t2.getReentrantDepth();
     }
   }
 };
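
This comparator orders the engine's ready queue, so it encodes the new scheduling policy: tasks belonging to deeper reentrant calls sort ahead of shallower ones, and the existing sequence-number ordering remains as the tie-break within a depth. A rough Python model of the same two-level priority (hypothetical task tuples for illustration, not engine code):

import heapq

# Each task is (name, reentrant_depth, sequence_nr). std::priority_queue pops
# its greatest element under the comparator above, so negate both keys to get
# the same behavior from Python's min-heap: deepest reentrant depth first,
# then highest sequence number within a depth.
tasks = [
    ("MyFunction.backward", 0, 7),
    ("Reentrant.backward", 1, 3),
    ("Reentrant.backward", 1, 5),
]
heap = [(-depth, -seq, name) for name, depth, seq in tasks]
heapq.heapify(heap)
while heap:
    _, neg_seq, name = heapq.heappop(heap)
    print(name, -neg_seq)
# Prints both depth-1 Reentrant tasks (seq 5, then 3) before the depth-0
# MyFunction task, which is the ordering test_reentrant_priority asserts.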
@@ -174,21 +183,28 @@ struct GraphTask {

   // The value of worker_device in the thread that created this task.
   // See Note [Reentrant backwards]
-  // Safe to read owner_ without synchronization
+  // Safe to read owner_ and reentrant_depth_ without synchronization
   int owner_;
+  // The number of parent graph tasks for this graph task
+  const int reentrant_depth_;

   bool can_checkpoint() {
     return exec_info_.empty();
   }

-  GraphTask(bool keep_graph, bool grad_mode)
+  GraphTask(bool keep_graph, bool grad_mode, int reentrant_depth)
     : has_error_(false)
     , outstanding_tasks_(0)
     , keep_graph_(keep_graph)
     , grad_mode_(grad_mode)
-    , owner_(NO_DEVICE) {}
+    , owner_(NO_DEVICE)
+    , reentrant_depth_(reentrant_depth) {}
 };

+int FunctionTask::getReentrantDepth() const {
+  return base_->reentrant_depth_;
+}
+
 auto ReadyQueue::push(FunctionTask item) -> void {
   {
     // Lock mutex for writing to heap_
@@ -216,7 +232,8 @@ auto ReadyQueue::pop() -> FunctionTask {
   return task;
 }

-Engine::Engine() = default;
+// This limit is based on the default Python recursion limit, which is 1000
+Engine::Engine() : max_recursion_depth_(100) {}

 // Send shutdown tasks to all ReadyQueues if no backward tasks are running
 // Even though readyQueue should be empty, shutdown tasks have the highest
@@ -342,7 +359,13 @@ auto Engine::thread_main(GraphTask *graph_task) -> void {
     }
   }

-  if (graph_task) {
+  // When current_depth is 0 this worker thread is done and we need to notify
+  // the parent thread waiting on the graph_task.
+  // NOTE: An edge case for this is when reentrant calls are repeatedly made in
+  // a thread which is at its maximum stack depth and they keep exiting right
+  // after. We will always switch to a new thread for each call, so we'll keep
+  // oscillating between the two threads.
+  if (graph_task && current_depth == 0) {
     graph_task->not_done_.notify_all();
   }
 }
@@ -359,6 +382,7 @@ void Engine::reentrant_thread_init() {
     tp_shared->graphtasks_queue_.pop();
     lk.unlock();
     set_device(graph_task->owner_);
+    total_depth = graph_task->reentrant_depth_;
     thread_main(graph_task);
   }
 }
@@ -632,7 +656,7 @@ auto Engine::execute(const edge_list& roots,
   // Lock post_callbacks_lock_ before clearing final_callbacks_
   ClearCallbacks _cb_guard(final_callbacks_, post_callbacks_lock_);

-  GraphTask graph_task(keep_graph, create_graph);
+  GraphTask graph_task(keep_graph, create_graph, worker_device == NO_DEVICE ? 0 : total_depth + 1);
   // Lock mutex while GraphTask is being set up
   std::unique_lock<std::mutex> lock(graph_task.mutex_);

@@ -651,15 +675,24 @@ auto Engine::execute(const edge_list& roots,
       return graph_task.outstanding_tasks_.load() == 0;
     });
   } else {
-    // Get back to work while we wait for our new graph_task to
-    // complete!
-    // See Note [Reentrant backwards]
-    // If no extra threads remaining, create a new thread for reentrant call
     graph_task.owner_ = worker_device;
-    add_thread_pool_task(&graph_task);
-    graph_task.not_done_.wait(lock, [&graph_task]{
-      return graph_task.outstanding_tasks_.load() == 0;
-    });
+    ++total_depth;
+    if (current_depth >= max_recursion_depth_) {
+      // See Note [Reentrant backwards]
+      // If we reached the max depth, switch to a different thread
+      add_thread_pool_task(&graph_task);
+      graph_task.not_done_.wait(lock, [&graph_task]{
+        return graph_task.outstanding_tasks_.load() == 0;
+      });
+    } else {
+      // Get back to work while we wait for our new graph_task to
+      // complete!
+      ++current_depth;
+      lock.unlock();
+      thread_main(&graph_task);
+      --current_depth;
+    }
+    --total_depth;
   }

   // Check for an exception while running backwards
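
Summing up the control flow: current_depth counts nested reentrant calls on the current thread's stack, total_depth counts nesting across the whole chain of threads (a pooled worker inherits it in reentrant_thread_init above), and a nested backward recurses into thread_main() in place until current_depth hits max_recursion_depth_, only then hopping to a pooled thread. A schematic Python rendering of that decision (the helpers are stubs; names mirror the C++ but this is not the engine's code):

import threading

MAX_RECURSION_DEPTH = 100  # the commit's per-thread budget, well under Python's default 1000

class WorkerState:
    def __init__(self):
        self.current_depth = 0  # nested reentrant calls on this thread's stack
        self.total_depth = 0    # nesting across the whole chain of threads

def thread_main(graph_task):
    """Stub for the worker loop that drains the ready queue."""

def offload_to_pool(graph_task):
    """Stub: run the task on a fresh thread and block until it is done,
    standing in for add_thread_pool_task() plus the not_done_ wait."""
    t = threading.Thread(target=thread_main, args=(graph_task,))
    t.start()
    t.join()

def run_reentrant_backward(graph_task, state):
    state.total_depth += 1
    if state.current_depth >= MAX_RECURSION_DEPTH:
        # Too deep on this stack: switch threads instead of recursing further.
        offload_to_pool(graph_task)
    else:
        # Service the nested backward inline, one recursion level deeper.
        state.current_depth += 1
        thread_main(graph_task)
        state.current_depth -= 1
    state.total_depth -= 1

With this budget, test_deep_reentrant's chain of roughly 2000 nested backward calls spreads over about 20 threads (100 recursion levels each) instead of overflowing a single stack.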

torch/csrc/autograd/engine.h

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,7 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
+#include <thread>

 namespace torch { namespace autograd {
 struct ReadyQueue;
@@ -72,6 +73,8 @@ struct TORCH_API Engine {
   std::vector<std::function<void()>> final_callbacks_;
   // To protect reads and writes to final_callbacks_
   std::mutex post_callbacks_lock_;
+  // How many nested reentrant calls are allowed until a new thread is used
+  int max_recursion_depth_;

   struct ThreadPoolShared {
     // Data structures used by the threads for executing reentrant backwards
