Skip to content

Commit 00d108d

Browse files
committed
Preserve python backtrace in autograd engine errors.
Pull Request resolved: #43684 This PR attempts to address #42560 by capturing the appropriate exception_ptr in the autograd engine and passing it over to the Future. As part of this change, there is a significant change the Future API where we now only accept an exception_ptr as part of setError. For the example in #42560, the exception trace would now look like: ``` > Traceback (most recent call last): > File "test_autograd.py", line 6914, in test_preserve_backtrace > Foo.apply(t).sum().backward() > File "torch/tensor.py", line 214, in backward > torch.autograd.backward(self, gradient, retain_graph, create_graph) > File "torch/autograd/__init__.py", line 127, in backward > allow_unreachable=True) # allow_unreachable flag > File "torch/autograd/function.py", line 87, in apply > return self._forward_cls.backward(self, *args) > File "test_autograd.py", line 6910, in backward > raise ValueError("something") > ValueError: something ``` ghstack-source-id: 111109637 Differential Revision: [D23365408](https://our.internmc.facebook.com/intern/diff/D23365408/)
1 parent 0394c5a commit 00d108d

File tree

17 files changed

+119
-82
lines changed

17 files changed

+119
-82
lines changed

aten/src/ATen/core/ivalue.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ CAFFE2_API intrusive_ptr<ivalue::Future> collectAny(
770770
ctx->srcFutures =
771771
List<intrusive_ptr<ivalue::Future>>(ctx->srcFutures.elementType());
772772
if (src->hasError()) {
773-
dst->setError(*src->error());
773+
dst->setError(src->exception_ptr());
774774
} else {
775775
dst->markCompleted(src->constValue());
776776
}

aten/src/ATen/core/ivalue_inl.h

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -300,35 +300,32 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
300300
markCompleted(IValue {});
301301
}
302302

303-
virtual void setError(std::string err) {
304-
setError(FutureError(std::move(err)));
305-
}
306-
307-
void setError(FutureError&& error) {
303+
void setError(std::exception_ptr eptr) {
308304
std::unique_lock<std::mutex> lock(mutex_);
309-
setErrorInternal(std::move(error), lock);
305+
setErrorInternal(std::move(eptr), lock);
310306
}
311307

312-
void setErrorIfNeeded(std::string errorMsg) {
308+
void setErrorIfNeeded(std::exception_ptr eptr) {
313309
std::unique_lock<std::mutex> lock(mutex_);
314310
if (completed_) {
315311
// This should be rare and shouldn't cause log spew. Its important to
316312
// log errors and thats why we have this log here.
317-
LOG(INFO) << "Skipping setting following error on the Future since " <<
318-
"it is already marked completed (this is not neccessarily an error): "
319-
<< errorMsg;
313+
LOG(INFO)
314+
<< "Skipping setting following error on the Future since "
315+
<< "it is already marked completed (this is not neccessarily an error): "
316+
<< tryRetrieveErrorMessageInternal(eptr);
320317
return;
321318
} else {
322-
setErrorInternal(FutureError(std::move(errorMsg)), lock);
319+
setErrorInternal(std::move(eptr), lock);
323320
}
324321
}
325322

326323
// Get the result of the current future.
327324
virtual IValue value() {
328325
std::unique_lock<std::mutex> lock(mutex_);
329326
AT_ASSERT(completed());
330-
if (error_) {
331-
throw *error_;
327+
if (eptr_) {
328+
std::rethrow_exception(eptr_);
332329
}
333330
return value_;
334331
}
@@ -338,7 +335,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
338335
virtual const IValue& constValue() {
339336
std::unique_lock<std::mutex> lock(mutex_);
340337
AT_ASSERT(completed());
341-
AT_ASSERT(!error_);
338+
AT_ASSERT(!eptr_);
342339
return value_;
343340
}
344341

@@ -375,31 +372,38 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
375372
try {
376373
fut->markCompleted(cb());
377374
} catch (std::exception& e) {
378-
fut->setError(e.what());
375+
fut->setError(std::current_exception());
379376
}
380377
},
381378
std::move(callback)));
382379
return fut;
383380
}
384381

382+
// Tries to retrieve the error message from std::exception_ptr.
383+
std::string tryRetrieveErrorMessage() {
384+
TORCH_CHECK(hasError(), "No error present on the future.");
385+
std::unique_lock<std::mutex> lock(mutex_);
386+
return tryRetrieveErrorMessageInternal(eptr_);
387+
}
388+
385389
// Check if the current future has completed
386390
virtual bool completed() const{
387391
return completed_;
388392
}
389393

390394
virtual bool hasValue() const {
391395
std::unique_lock<std::mutex> lock(mutex_);
392-
return completed_ && !error_;
396+
return completed_ && !eptr_;
393397
}
394398

395399
bool hasError() const {
396400
std::unique_lock<std::mutex> lock(mutex_);
397-
return error_ ? true : false;
401+
return eptr_ ? true : false;
398402
}
399403

400-
c10::optional<FutureError> error() const {
404+
std::exception_ptr exception_ptr() const {
401405
std::unique_lock<std::mutex> lock(mutex_);
402-
return error_;
406+
return eptr_;
403407
}
404408

405409
CAFFE2_API friend std::ostream& operator<<(
@@ -412,11 +416,11 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
412416

413417
private:
414418
void setErrorInternal(
415-
FutureError error,
419+
std::exception_ptr eptr,
416420
std::unique_lock<std::mutex>& lock) {
417421
AT_ASSERT(!completed());
418422
completed_ = true;
419-
error_ = std::move(error);
423+
eptr_ = std::move(eptr);
420424

421425
std::vector<std::function<void(void)>> cbs;
422426
cbs.swap(callbacks_);
@@ -428,14 +432,25 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
428432
}
429433
}
430434

435+
// Tries to retrieve the error message from std::exception_ptr.
436+
std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) {
437+
try {
438+
std::rethrow_exception(eptr);
439+
} catch (const std::exception& e) {
440+
return e.what();
441+
} catch (...) {
442+
return "Unknown Exception Type";
443+
}
444+
}
445+
431446
mutable std::mutex mutex_;
432447
std::atomic_bool completed_ = {false}; // is this future complete
433448
std::condition_variable finished_cv_;
434449

435450
IValue value_; // when finished the value
436451
TypePtr type_;
437452
std::vector<std::function<void(void)>> callbacks_;
438-
c10::optional<FutureError> error_;
453+
std::exception_ptr eptr_;
439454
};
440455

441456
// Input is a list of Futures with the same target type.

aten/src/ATen/test/ivalue_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,10 @@ TEST(IValueTest, FutureExceptions) {
139139
}
140140
});
141141
ivalue::Future::FutureError err("My Error");
142-
f3->setError(std::move(err));
142+
f3->setError(std::make_exception_ptr(err));
143143
ASSERT_EQ(calledTimes, 1);
144144
ASSERT_TRUE(f3->hasError());
145-
ASSERT_EQ(std::string(f3->error()->what()), std::string("My Error"));
145+
ASSERT_EQ(f3->tryRetrieveErrorMessage(), std::string("My Error"));
146146
}
147147

148148
TEST(IValueTest, ValueEquality) {

test/cpp/jit/test_misc.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1987,7 +1987,8 @@ void testFutures() {
19871987
int sat1 = 0;
19881988
int sat2 = 0;
19891989
f1->addCallback([&]() { ++sat1; });
1990-
f1->setError("Failed");
1990+
f1->setError(
1991+
std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
19911992
ASSERT_EQ(sat1, 1);
19921993
ASSERT_TRUE(f1->completed());
19931994
ASSERT_TRUE(f1->hasError());
@@ -2001,8 +2002,9 @@ void testFutures() {
20012002
f1->addCallback([&]() { ++sat2; });
20022003
ASSERT_EQ(sat1, 1);
20032004
ASSERT_EQ(sat2, 1);
2004-
f1->setErrorIfNeeded("Dup");
2005-
ASSERT_TRUE(strcmp(f1->error()->what(), "Failed") == 0);
2005+
f1->setErrorIfNeeded(
2006+
std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup")));
2007+
ASSERT_TRUE(strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed") == 0);
20062008
ASSERT_EQ(sat1, 1);
20072009
ASSERT_EQ(sat2, 1);
20082010
}
@@ -2082,7 +2084,8 @@ void testFutures() {
20822084
futures.push_back(s4);
20832085
auto c5 = collectAll(futures);
20842086
ASSERT_FALSE(c5->completed());
2085-
s4->setError("Failed");
2087+
s4->setError(
2088+
std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
20862089
ASSERT_TRUE(c5->completed());
20872090
ASSERT_EQ(c5->value().toList().size(), 4);
20882091
try {

test/test_autograd.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def test_custom_function_exception(self):
253253

254254
tmp = (t1 + t2) * (t1 + t2)
255255
t3 = TestAutograd.SimulateBackwardError.apply(tmp)
256-
with self.assertRaisesRegex(RuntimeError, "Simulate error on backward pass"):
256+
with self.assertRaisesRegex(Exception, "Simulate error on backward pass"):
257257
t3.sum().backward()
258258

259259
def test_invalid_gradients(self):
@@ -2313,7 +2313,7 @@ def backward(ctx, grad):
23132313
return grad
23142314

23152315
d = ReentrantFunc.apply(c)
2316-
with self.assertRaisesRegex(RuntimeError, 'Simulate error'):
2316+
with self.assertRaisesRegex(Exception, 'Simulate error'):
23172317
d.sum().backward()
23182318

23192319
def test_broadcast_tensors(self):
@@ -6168,7 +6168,7 @@ def backward(ctx, grad):
61686168
t7 = t6 * t6
61696169

61706170
# Parent graph will error out first, while child graph will continue executing.
6171-
with self.assertRaisesRegex(RuntimeError, "Simulate error"):
6171+
with self.assertRaisesRegex(Exception, "Simulate error"):
61726172
torch.autograd.backward([t5.sum(), t7.sum()])
61736173

61746174
# No grads should be accumulated since child graph will stop execution
@@ -6964,6 +6964,24 @@ def train_fn_fork_join_calls_retain(x):
69646964
self.assertEqual(grad, grad1)
69656965
self.assertEqual(grad, grad2)
69666966

6967+
def test_preserve_backtrace(self):
6968+
class Foo(torch.autograd.Function):
6969+
@staticmethod
6970+
def forward(ctx, input):
6971+
return input
6972+
6973+
@staticmethod
6974+
def backward(ctx, *grad):
6975+
raise ValueError("something")
6976+
6977+
t = torch.rand(10, requires_grad=True)
6978+
try:
6979+
Foo.apply(t).sum().backward()
6980+
except Exception:
6981+
import traceback
6982+
tb = sys.exc_info()[2]
6983+
tb_str = "\n".join(traceback.format_tb(tb))
6984+
self.assertTrue('raise ValueError("something")' in tb_str)
69676985

69686986
for test in method_tests():
69696987
add_test(*test)

torch/csrc/autograd/engine.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ void Engine::thread_on_exception(
442442
std::shared_ptr<GraphTask> graph_task,
443443
const std::shared_ptr<Node>& fn,
444444
std::exception& e) {
445-
graph_task->set_exception(e, fn);
445+
graph_task->set_exception(std::current_exception(), fn);
446446
}
447447

448448
bool GraphTask::completed() {
@@ -473,7 +473,7 @@ void GraphTask::mark_as_completed_and_run_post_processing() {
473473
lock.unlock();
474474
future_result_->markCompleted(std::move(vars));
475475
} catch (std::exception& e) {
476-
future_result_->setErrorIfNeeded(e.what());
476+
future_result_->setErrorIfNeeded(std::current_exception());
477477
}
478478
}
479479

@@ -523,11 +523,11 @@ void GraphTask::set_exception_without_signal(const std::shared_ptr<Node>& fn) {
523523
}
524524

525525
void GraphTask::set_exception(
526-
std::exception& e,
526+
std::exception_ptr eptr,
527527
const std::shared_ptr<Node>& fn) {
528528
set_exception_without_signal(fn);
529529
if (!future_completed_.exchange(true)) {
530-
future_result_->setError(e.what());
530+
future_result_->setError(std::move(eptr));
531531
}
532532
}
533533

torch/csrc/autograd/engine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ struct GraphTask: std::enable_shared_from_this<GraphTask> {
129129

130130
// Set an appropriate exception on this graph_task which was encountered while
131131
// running the provided function.
132-
void set_exception(std::exception& e, const std::shared_ptr<Node>& fn);
132+
void set_exception(std::exception_ptr eptr, const std::shared_ptr<Node>& fn);
133133

134134
// Set an appropriate exception on this graph_task which was encountered while
135135
// running the provided function. But doesn't signal completion on

torch/csrc/distributed/autograd/context/context.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,17 @@ void DistAutogradContext::addOutstandingRpc(
127127
futureMessage->addCallback([this](const rpc::FutureMessage& futureMessage) {
128128
if (futureMessage.hasError()) {
129129
// If we have an error, let the local autograd engine know about it.
130-
std::runtime_error err((*futureMessage.error()).what());
131130
std::unique_lock<std::mutex> lock(lock_);
132131
if (graphTask_) {
133132
graphTask_->set_exception_without_signal(nullptr);
134133
lock.unlock();
135134
if (!graphTask_->future_completed_.exchange(true)) {
136-
graphTask_->future_result_->setErrorIfNeeded(err.what());
135+
graphTask_->future_result_->setErrorIfNeeded(
136+
std::make_exception_ptr(*futureMessage.error()));
137137
}
138138
} else {
139139
LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: "
140-
<< err.what();
140+
<< (*futureMessage.error()).what();
141141
}
142142
}
143143
});

torch/csrc/distributed/autograd/engine/dist_engine.cpp

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -389,29 +389,31 @@ std::shared_ptr<rpc::FutureMessage> DistEngine::runEngineAndAccumulateGradients(
389389
// future that waits for all gradient accumulation to finish.
390390
auto accumulateGradFuture = std::make_shared<rpc::FutureMessage>();
391391

392-
futureGrads->addCallback([autogradContext, outputEdges, accumulateGradFuture, &futureGrads]() {
393-
if (futureGrads->hasError()) {
394-
// Don't accumulate gradients if we receive an error.
395-
// We must add the node information here since DistEngine::execute
396-
// waits on accumulateGradFuture and will throw an exception once we
397-
// set the error below.
398-
std::string errorMsg = c10::str(
399-
"Error on Node ",
400-
DistAutogradContainer::getInstance().getWorkerId(),
401-
": ",
402-
futureGrads->error()->what());
403-
accumulateGradFuture->setError(errorMsg);
404-
return;
405-
}
392+
futureGrads->addCallback(
393+
[autogradContext, outputEdges, accumulateGradFuture, &futureGrads]() {
394+
if (futureGrads->hasError()) {
395+
// Don't accumulate gradients if we receive an error.
396+
// We must add the node information here since DistEngine::execute
397+
// waits on accumulateGradFuture and will throw an exception once we
398+
// set the error below.
399+
std::string errorMsg = c10::str(
400+
"Error on Node ",
401+
DistAutogradContainer::getInstance().getWorkerId(),
402+
": ",
403+
futureGrads->tryRetrieveErrorMessage());
404+
accumulateGradFuture->setError(errorMsg);
405+
return;
406+
}
406407

407-
try {
408-
const variable_list& grads = futureGrads->constValue().toTensorVector();
409-
TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size());
410-
accumulateGradFuture->markCompleted(rpc::Message());
411-
} catch (std::exception& e) {
412-
accumulateGradFuture->setErrorIfNeeded(e.what());
413-
}
414-
});
408+
try {
409+
const variable_list& grads =
410+
futureGrads->constValue().toTensorVector();
411+
TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size());
412+
accumulateGradFuture->markCompleted(rpc::Message());
413+
} catch (std::exception& e) {
414+
accumulateGradFuture->setErrorIfNeeded(e.what());
415+
}
416+
});
415417

416418
return accumulateGradFuture;
417419
}

torch/csrc/distributed/rpc/python_functions.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ c10::intrusive_ptr<JitFuture> wrapFutureMessageInJitFuture(
138138
at::wrapPropagateTLSState<void>([jitFuture, wp]() {
139139
auto futureResponseMessage = wp.lock();
140140
if (futureResponseMessage->hasError()) {
141-
jitFuture->setError(futureResponseMessage->error()->what());
141+
jitFuture->setError(
142+
std::make_exception_ptr(*futureResponseMessage->error()));
142143
} else {
143144
jitFuture->markCompleted(
144145
toIValue(futureResponseMessage->constValue()));
@@ -154,7 +155,8 @@ c10::intrusive_ptr<JitFuture> wrapFutureMessageInJitFuture(
154155
at::wrapPropagateTLSState<void>([wp, jitFuture]() {
155156
auto futureResponseMessage = wp.lock();
156157
if (futureResponseMessage->hasError()) {
157-
jitFuture->setError(futureResponseMessage->error()->what());
158+
jitFuture->setError(
159+
std::make_exception_ptr(*futureResponseMessage->error()));
158160
} else {
159161
jitFuture->markCompleted(IValue());
160162
}

0 commit comments

Comments
 (0)