96 changes: 95 additions & 1 deletion test/test_autograd.py
@@ -12,13 +12,14 @@
from itertools import product
from operator import mul
from functools import reduce
from torch import nn
from torch._six import inf, nan, istuple
from torch.autograd.gradcheck import gradgradcheck, gradcheck
from torch.autograd.function import once_differentiable
from torch.autograd.profiler import profile, format_time, EventList, FunctionEvent, emit_nvtx
from torch.utils.checkpoint import checkpoint
from common_utils import (TEST_MKL, TestCase, run_tests, skipIfNoLapack,
suppress_warnings, skipIfRocm,
suppress_warnings, skipIfRocm, slowTest,
load_tests, random_symmetric_pd_matrix, IS_WINDOWS)
from common_cuda import TEST_CUDA
from torch.autograd import Variable, Function, detect_anomaly
@@ -3195,6 +3196,99 @@ def backward(ctx, grad):
s = TestCase.runWithPytorchAPIUsageStderr(code)
self.assertRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")

def test_deep_reentrant(self):

class DeepReentrant(Function):
@staticmethod
def forward(ctx, x):
with torch.enable_grad():
ctx.x = Variable(x.data, requires_grad=True)
ctx.x = ctx.x - 1
return ctx.x.detach()

@staticmethod
def backward(ctx, x):
if ctx.x < 0:
return x
with torch.enable_grad():
DeepReentrant.apply(ctx.x).sum().backward()
return x

v = torch.tensor(2000.0, requires_grad=True)
# This will cause stack overflow if reentrant calls are handled
# in the same thread recursively
DeepReentrant.apply(v).sum().backward()

def test_reentrant_priority(self):
order = []

class MyFunction(Function):
@staticmethod
def forward(ctx, x):
return x

@staticmethod
def backward(ctx, x):
order.append("MyFunction")
return x

class Reentrant(Function):
@staticmethod
def forward(ctx, x):
with torch.enable_grad():
ctx.x = Variable(x.data, requires_grad=True)
ctx.x = ctx.x - 1
return ctx.x.detach()
Contributor:
This is a pretty funky forward function. (It's funky because we don't normally do autograd operations inside of a forward function.) It seems like it's doing two things: you want to return x (forward is just identity), but you also want to create a leaf variable on context with some non-trivial autograd history. Is there a reason x has to be used in both cases? I'll keep reading and see if I can figure out why you create leaf variables in forward ;)

Contributor:
Oh I see, you're also using ctx.x to keep track of how many times you recurse.
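
For context, the conventional pattern the comment alludes to looks like the sketch below (illustrative only, not part of this PR; the PlainIdentity name is made up): forward performs no autograd operations and stashes whatever backward needs via ctx.save_for_backward. The test above deliberately departs from this because it needs autograd history built inside forward to drive the reentrant backward.

```python
import torch
from torch.autograd import Function

class PlainIdentity(Function):
    # Conventional style: no autograd work inside forward; anything backward
    # needs is saved with ctx.save_for_backward.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors  # saved input (unused for a pure identity)
        return grad_output

v = torch.tensor(3.0, requires_grad=True)
PlainIdentity.apply(v).backward()  # identity gradient: v.grad is 1.0
```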


@staticmethod
def backward(ctx, x):
order.append("Reentrant")
if ctx.x < 0:
return x
with torch.enable_grad():
Reentrant.apply(ctx.x).backward()
return x

a = MyFunction.apply(torch.tensor(6.0, requires_grad=True))
b = Reentrant.apply(torch.tensor(9.0, requires_grad=True))
v = a * b
v.backward()
# The tasks for the Reentrant and MyFunction backward() will be added
# to the queue in the autograd engine at the same time. The backward
# for Reentrant will be executed first, which will then add other
# backward tasks to the queue. We want to ensure all the reentrant tasks
# are prioritized over the MyFunction backward task regardless of their
# sequence numbers
self.assertEqual(len(order), 11)
self.assertEqual(order.count("Reentrant"), 10)
self.assertEqual(order[-1], "MyFunction")

@slowTest
def test_checkpointing(self):
num_inp = 2000
nz_inp = 10
nz_out = 10
nz_bottleneck = 1000

# small proxy network for some complex reasoning we want to do per input
module = nn.Sequential(
nn.Linear(nz_inp, nz_bottleneck),
nn.ReLU(),
nn.Linear(nz_bottleneck, nz_inp)
)

feat_combined = []
for r in range(num_inp):
data_r = torch.Tensor(1, nz_inp)
data_r.uniform_()
data_r.requires_grad = True
feat_r = checkpoint(module, data_r)
feat_combined.append(feat_r)

# compute mean as a proxy for some joint reasoning
mean_combined = torch.stack(feat_combined).mean()
mean_combined.backward()

def index_variable(shape, max_indices):
if not isinstance(shape, tuple):
shape = (shape,)
63 changes: 48 additions & 15 deletions torch/csrc/autograd/engine.cpp
@@ -51,6 +51,11 @@ static thread_local bool checkpoint_valid = true;
// engine thread affinity to the device can break this invariant, and we depend
// on it in a few places (e.g. AccumulateGrad function).

// Number of nested reentrant backwards calls currently on this thread
static thread_local int current_depth = 0;
// Total nested reentrant backwards calls over all threads for worker_device
static thread_local int total_depth = 0;

struct FunctionTask {
GraphTask* base_;
std::shared_ptr<Function> fn_;
@@ -62,6 +67,8 @@ struct FunctionTask {
// exit. The engine sends a shutdown task to every queue upon its destruction.
bool isShutdownTask_;

int getReentrantDepth() const;

FunctionTask(GraphTask* base, std::shared_ptr<Function> fn, InputBuffer inputs, bool isShutdownTask = false)
: base_(base)
, fn_(std::move(fn))
@@ -79,8 +86,10 @@ struct CompareFunctionTaskTime {
return false;
} else if (!t2.fn_) {
return true;
} else {
} else if (t1.getReentrantDepth() == t2.getReentrantDepth()) {
return t1.fn_->sequence_nr() < t2.fn_->sequence_nr();
} else {
return t1.getReentrantDepth() < t2.getReentrantDepth();
}
}
};
@@ -174,21 +183,28 @@ struct GraphTask {

// The value of worker_device in the thread that created this task.
// See Note [Reentrant backwards]
// Safe to read owner_ without synchronization
// Safe to read owner_ and reentrant_depth_ without synchronization
int owner_;
// The number of parent graph tasks for this graph task
const int reentrant_depth_;

bool can_checkpoint() {
return exec_info_.empty();
}

GraphTask(bool keep_graph, bool grad_mode)
GraphTask(bool keep_graph, bool grad_mode, int reentrant_depth)
: has_error_(false)
, outstanding_tasks_(0)
, keep_graph_(keep_graph)
, grad_mode_(grad_mode)
, owner_(NO_DEVICE) {}
, owner_(NO_DEVICE)
, reentrant_depth_(reentrant_depth) {}
};

int FunctionTask::getReentrantDepth() const {
return base_->reentrant_depth_;
}

auto ReadyQueue::push(FunctionTask item) -> void {
{
// Lock mutex for writing to heap_
@@ -216,7 +232,8 @@ auto ReadyQueue::pop() -> FunctionTask {
return task;
}

Engine::Engine() = default;
// This limit is based on the default python recursion limit which is 1000
Contributor:
Comment does not match the code: 1000 vs 100.

Author:
We need to set it lower than the actual Python limit to account for the function calls within Python.
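
To make the 1000-vs-100 relationship concrete, a small illustration (not code from this PR): CPython's default recursion limit is 1000 frames, and each nested reentrant backward() consumes more than one Python/C++ frame before it re-enters the engine, so the per-thread cap stays well below the interpreter limit.

```python
import sys

# CPython's default recursion limit, typically 1000 frames.
print(sys.getrecursionlimit())

# Each nested reentrant backward() consumes several frames before reaching
# the engine again, so a per-thread cap of 100 reentrant levels keeps the
# total stack depth comfortably under the interpreter's limit.
```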

Engine::Engine() : max_recursion_depth_(100) {}

// Send shutdown tasks to all ReadyQueues if no backward tasks are running
// Even though readyQueue should be empty, shutdown tasks have the highest
@@ -342,7 +359,13 @@ auto Engine::thread_main(GraphTask *graph_task) -> void {
}
}

if (graph_task) {
// When current_depth is 0 this worker thread is done and we need to notify
// the parent thread waiting on the graph_task
// NOTE: An edge case for this is when reentrant calls are repeatedly made in
// a thread which is at its maximum stack depth and they keep exiting right
// after. We will always switch to a new thread for each call, so we'll keep
// oscillating between the two threads.
if (graph_task && current_depth == 0) {
graph_task->not_done_.notify_all();
}
}
@@ -359,6 +382,7 @@ void Engine::reentrant_thread_init() {
tp_shared->graphtasks_queue_.pop();
lk.unlock();
set_device(graph_task->owner_);
total_depth = graph_task->reentrant_depth_;
thread_main(graph_task);
}
}
@@ -632,7 +656,7 @@ auto Engine::execute(const edge_list& roots,
// Lock post_callbacks_lock_ before clearing final_callbacks_
ClearCallbacks _cb_guard(final_callbacks_, post_callbacks_lock_);

GraphTask graph_task(keep_graph, create_graph);
GraphTask graph_task(keep_graph, create_graph, worker_device == NO_DEVICE ? 0 : total_depth+1);
// Lock mutex while GraphTask is being set up
std::unique_lock<std::mutex> lock(graph_task.mutex_);

@@ -651,15 +675,24 @@
return graph_task.outstanding_tasks_.load() == 0;
});
} else {
// Get back to work while we wait for our new graph_task to
// complete!
// See Note [Reentrant backwards]
// If no extra threads remaining, create a new thread for reentrant call
graph_task.owner_ = worker_device;
add_thread_pool_task(&graph_task);
graph_task.not_done_.wait(lock, [&graph_task]{
return graph_task.outstanding_tasks_.load() == 0;
});
++total_depth;
if (current_depth >= max_recursion_depth_) {
// See Note [Reentrant backwards]
// If reached the max depth, switch to a different thread
add_thread_pool_task(&graph_task);
Contributor:
What will happen if we exceed thread_pool_shared_->graphtasks_queue_.size() here?

Author:
std::queue should resize dynamically, or am I misunderstanding your question?
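
A toy illustration of that point, using Python's queue.Queue as a stand-in for the unbounded std::queue behind graphtasks_queue_ (an analogy, not the engine code): with no maximum size, pushing more tasks than any previous size simply grows the queue, and queued graph tasks wait until a worker thread frees up.

```python
import queue

tasks = queue.Queue()      # maxsize=0 means unbounded, like std::queue
for i in range(10000):     # far more items than any previous "size"
    tasks.put(i)           # never fails or blocks for capacity
print(tasks.qsize())       # 10000
```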

graph_task.not_done_.wait(lock, [&graph_task]{
return graph_task.outstanding_tasks_.load() == 0;
});
} else {
// Get back to work while we wait for our new graph_task to
// complete!
++current_depth;
lock.unlock();
thread_main(&graph_task);
--current_depth;
}
--total_depth;
}

// Check for an exception while running backwards
3 changes: 3 additions & 0 deletions torch/csrc/autograd/engine.h
@@ -16,6 +16,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include <thread>

namespace torch { namespace autograd {
struct ReadyQueue;
@@ -72,6 +73,8 @@ struct TORCH_API Engine {
std::vector<std::function<void()>> final_callbacks_;
// To protect reads and writes to final_callbacks_
std::mutex post_callbacks_lock_;
// How many nested reentrant calls are allowed until a new thread is used
int max_recursion_depth_;

struct ThreadPoolShared {
// Data structures used by the threads for executing reentrant backwards