2 changes: 1 addition & 1 deletion aten/src/ATen/core/ivalue.cpp
@@ -770,7 +770,7 @@ CAFFE2_API intrusive_ptr<ivalue::Future> collectAny(
   ctx->srcFutures =
       List<intrusive_ptr<ivalue::Future>>(ctx->srcFutures.elementType());
   if (src->hasError()) {
-    dst->setError(*src->error());
+    dst->setError(src->exception_ptr());
   } else {
     dst->markCompleted(src->constValue());
   }
59 changes: 37 additions & 22 deletions aten/src/ATen/core/ivalue_inl.h
@@ -300,35 +300,32 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
     markCompleted(IValue {});
   }
 
-  virtual void setError(std::string err) {
-    setError(FutureError(std::move(err)));
-  }
-
-  void setError(FutureError&& error) {
+  void setError(std::exception_ptr eptr) {
     std::unique_lock<std::mutex> lock(mutex_);
-    setErrorInternal(std::move(error), lock);
+    setErrorInternal(std::move(eptr), lock);
   }
 
-  void setErrorIfNeeded(std::string errorMsg) {
+  void setErrorIfNeeded(std::exception_ptr eptr) {
     std::unique_lock<std::mutex> lock(mutex_);
     if (completed_) {
       // This should be rare and shouldn't cause log spew. Its important to
       // log errors and thats why we have this log here.
-      LOG(INFO) << "Skipping setting following error on the Future since " <<
-        "it is already marked completed (this is not neccessarily an error): "
-        << errorMsg;
+      LOG(INFO)
+          << "Skipping setting following error on the Future since "
+          << "it is already marked completed (this is not necessarily an error): "
+          << tryRetrieveErrorMessageInternal(eptr);
       return;
     } else {
-      setErrorInternal(FutureError(std::move(errorMsg)), lock);
+      setErrorInternal(std::move(eptr), lock);
     }
   }
 
   // Get the result of the current future.
   virtual IValue value() {
     std::unique_lock<std::mutex> lock(mutex_);
     AT_ASSERT(completed());
-    if (error_) {
-      throw *error_;
+    if (eptr_) {
+      std::rethrow_exception(eptr_);
     }
     return value_;
   }
@@ -338,7 +335,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
   virtual const IValue& constValue() {
     std::unique_lock<std::mutex> lock(mutex_);
     AT_ASSERT(completed());
-    AT_ASSERT(!error_);
+    AT_ASSERT(!eptr_);
     return value_;
   }
 
@@ -375,31 +372,38 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
       try {
         fut->markCompleted(cb());
       } catch (std::exception& e) {
-        fut->setError(e.what());
+        fut->setError(std::current_exception());
       }
     },
     std::move(callback)));
     return fut;
   }
 
+  // Tries to retrieve the error message from std::exception_ptr.
+  std::string tryRetrieveErrorMessage() {
+    TORCH_CHECK(hasError(), "No error present on the future.");
+    std::unique_lock<std::mutex> lock(mutex_);
+    return tryRetrieveErrorMessageInternal(eptr_);
+  }
+
   // Check if the current future has completed
   virtual bool completed() const{
     return completed_;
   }
 
   virtual bool hasValue() const {
     std::unique_lock<std::mutex> lock(mutex_);
-    return completed_ && !error_;
+    return completed_ && !eptr_;
   }
 
   bool hasError() const {
     std::unique_lock<std::mutex> lock(mutex_);
-    return error_ ? true : false;
+    return eptr_ ? true : false;
   }
 
-  c10::optional<FutureError> error() const {
+  std::exception_ptr exception_ptr() const {
     std::unique_lock<std::mutex> lock(mutex_);
-    return error_;
+    return eptr_;
   }
 
   CAFFE2_API friend std::ostream& operator<<(
@@ -412,11 +416,11 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
 
  private:
   void setErrorInternal(
-      FutureError error,
+      std::exception_ptr eptr,
       std::unique_lock<std::mutex>& lock) {
     AT_ASSERT(!completed());
     completed_ = true;
-    error_ = std::move(error);
+    eptr_ = std::move(eptr);
 
     std::vector<std::function<void(void)>> cbs;
     cbs.swap(callbacks_);
@@ -428,14 +432,25 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
     }
   }
 
+  // Tries to retrieve the error message from std::exception_ptr.
+  std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) {
+    try {
+      std::rethrow_exception(eptr);
+    } catch (const std::exception& e) {
+      return e.what();
+    } catch (...) {
+      return "Unknown Exception Type";
+    }
+  }
+
   mutable std::mutex mutex_;
   std::atomic_bool completed_ = {false}; // is this future complete
   std::condition_variable finished_cv_;
 
   IValue value_; // when finished the value
   TypePtr type_;
   std::vector<std::function<void(void)>> callbacks_;
-  c10::optional<FutureError> error_;
+  std::exception_ptr eptr_;
 };
 
 // Input is a list of Futures with the same target type.
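For orientation, a minimal sketch of how the reworked error API fits together; the Future construction with an IntType element type is an assumption for illustration, not part of this diff:

    #include <ATen/core/ivalue.h>
    #include <stdexcept>

    void example() {
      // Producer side: capture the in-flight exception, not just its message.
      auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::IntType::get());
      try {
        throw std::runtime_error("worker failed");
      } catch (...) {
        fut->setError(std::current_exception()); // dynamic type is preserved
      }

      // Consumer side: value() rethrows the original exception object.
      if (fut->hasError()) {
        std::string msg = fut->tryRetrieveErrorMessage(); // "worker failed"
      }
      try {
        fut->value();
      } catch (const std::runtime_error& e) {
        // Caught as std::runtime_error; a string-based error() could only
        // have surfaced e.what(), not the exception's type.
      }
    }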
4 changes: 2 additions & 2 deletions aten/src/ATen/test/ivalue_test.cpp
@@ -139,10 +139,10 @@ TEST(IValueTest, FutureExceptions) {
     }
   });
   ivalue::Future::FutureError err("My Error");
-  f3->setError(std::move(err));
+  f3->setError(std::make_exception_ptr(err));
   ASSERT_EQ(calledTimes, 1);
   ASSERT_TRUE(f3->hasError());
-  ASSERT_EQ(std::string(f3->error()->what()), std::string("My Error"));
+  ASSERT_EQ(f3->tryRetrieveErrorMessage(), std::string("My Error"));
 }
 
 TEST(IValueTest, ValueEquality) {
11 changes: 7 additions & 4 deletions test/cpp/jit/test_misc.cpp
@@ -1987,7 +1987,8 @@ void testFutures() {
   int sat1 = 0;
   int sat2 = 0;
   f1->addCallback([&]() { ++sat1; });
-  f1->setError("Failed");
+  f1->setError(
+      std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
   ASSERT_EQ(sat1, 1);
   ASSERT_TRUE(f1->completed());
   ASSERT_TRUE(f1->hasError());
@@ -2001,8 +2002,9 @@ void testFutures() {
   f1->addCallback([&]() { ++sat2; });
   ASSERT_EQ(sat1, 1);
   ASSERT_EQ(sat2, 1);
-  f1->setErrorIfNeeded("Dup");
-  ASSERT_TRUE(strcmp(f1->error()->what(), "Failed") == 0);
+  f1->setErrorIfNeeded(
+      std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup")));
+  ASSERT_TRUE(strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed") == 0);
   ASSERT_EQ(sat1, 1);
   ASSERT_EQ(sat2, 1);
 }
@@ -2082,7 +2084,8 @@ void testFutures() {
   futures.push_back(s4);
   auto c5 = collectAll(futures);
   ASSERT_FALSE(c5->completed());
-  s4->setError("Failed");
+  s4->setError(
+      std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
   ASSERT_TRUE(c5->completed());
   ASSERT_EQ(c5->value().toList().size(), 4);
   try {
24 changes: 21 additions & 3 deletions test/test_autograd.py
@@ -253,7 +253,7 @@ def test_custom_function_exception(self):
 
         tmp = (t1 + t2) * (t1 + t2)
         t3 = TestAutograd.SimulateBackwardError.apply(tmp)
-        with self.assertRaisesRegex(RuntimeError, "Simulate error on backward pass"):
+        with self.assertRaisesRegex(Exception, "Simulate error on backward pass"):
             t3.sum().backward()
 
     def test_invalid_gradients(self):
@@ -2313,7 +2313,7 @@ def backward(ctx, grad):
             return grad
 
         d = ReentrantFunc.apply(c)
-        with self.assertRaisesRegex(RuntimeError, 'Simulate error'):
+        with self.assertRaisesRegex(Exception, 'Simulate error'):
             d.sum().backward()
 
     def test_broadcast_tensors(self):
@@ -6168,7 +6168,7 @@ def backward(ctx, grad):
         t7 = t6 * t6
 
         # Parent graph will error out first, while child graph will continue executing.
-        with self.assertRaisesRegex(RuntimeError, "Simulate error"):
+        with self.assertRaisesRegex(Exception, "Simulate error"):
             torch.autograd.backward([t5.sum(), t7.sum()])
 
         # No grads should be accumulated since child graph will stop execution
@@ -6964,6 +6964,24 @@ def train_fn_fork_join_calls_retain(x):
         self.assertEqual(grad, grad1)
         self.assertEqual(grad, grad2)
 
+    def test_preserve_backtrace(self):
+        class Foo(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, input):
+                return input
+
+            @staticmethod
+            def backward(ctx, *grad):
+                raise ValueError("something")
+
+        t = torch.rand(10, requires_grad=True)
+        try:
+            Foo.apply(t).sum().backward()
+        except Exception:
+            import traceback
+            tb = sys.exc_info()[2]
+            tb_str = "\n".join(traceback.format_tb(tb))
+            self.assertTrue('raise ValueError("something")' in tb_str)
+
 for test in method_tests():
     add_test(*test)
8 changes: 4 additions & 4 deletions torch/csrc/autograd/engine.cpp
@@ -442,7 +442,7 @@ void Engine::thread_on_exception(
     std::shared_ptr<GraphTask> graph_task,
     const std::shared_ptr<Node>& fn,
     std::exception& e) {
-  graph_task->set_exception(e, fn);
+  graph_task->set_exception(std::current_exception(), fn);
 }
 
 bool GraphTask::completed() {
@@ -473,7 +473,7 @@ void GraphTask::mark_as_completed_and_run_post_processing() {
     lock.unlock();
     future_result_->markCompleted(std::move(vars));
   } catch (std::exception& e) {
-    future_result_->setErrorIfNeeded(e.what());
+    future_result_->setErrorIfNeeded(std::current_exception());
   }
 }
 
@@ -523,11 +523,11 @@ void GraphTask::set_exception_without_signal(const std::shared_ptr<Node>& fn) {
 }
 
 void GraphTask::set_exception(
-    std::exception& e,
+    std::exception_ptr eptr,
     const std::shared_ptr<Node>& fn) {
   set_exception_without_signal(fn);
   if (!future_completed_.exchange(true)) {
-    future_result_->setError(e.what());
+    future_result_->setError(std::move(eptr));
   }
 }
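A subtlety in the engine change above: std::current_exception() only returns a non-null pointer while an exception is being handled, which is why thread_on_exception captures it inside the engine's catch block. A standalone sketch of that pattern (run_task and set_error are hypothetical names):

    #include <exception>
    #include <functional>

    void run_task(
        std::function<void()> task,
        std::function<void(std::exception_ptr)> set_error) {
      try {
        task();
      } catch (std::exception& e) {
        // Must run inside the catch block: captures the active exception with
        // its dynamic type so another thread can std::rethrow_exception() it.
        set_error(std::current_exception());
      }
    }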
2 changes: 1 addition & 1 deletion torch/csrc/autograd/engine.h
@@ -129,7 +129,7 @@ struct GraphTask: std::enable_shared_from_this<GraphTask> {
 
   // Set an appropriate exception on this graph_task which was encountered while
   // running the provided function.
-  void set_exception(std::exception& e, const std::shared_ptr<Node>& fn);
+  void set_exception(std::exception_ptr eptr, const std::shared_ptr<Node>& fn);
 
   // Set an appropriate exception on this graph_task which was encountered while
   // running the provided function. But doesn't signal completion on
6 changes: 3 additions & 3 deletions torch/csrc/distributed/autograd/context/context.cpp
@@ -127,17 +127,17 @@ void DistAutogradContext::addOutstandingRpc(
   futureMessage->addCallback([this](const rpc::FutureMessage& futureMessage) {
     if (futureMessage.hasError()) {
       // If we have an error, let the local autograd engine know about it.
-      std::runtime_error err((*futureMessage.error()).what());
       std::unique_lock<std::mutex> lock(lock_);
       if (graphTask_) {
         graphTask_->set_exception_without_signal(nullptr);
         lock.unlock();
         if (!graphTask_->future_completed_.exchange(true)) {
-          graphTask_->future_result_->setErrorIfNeeded(err.what());
+          graphTask_->future_result_->setErrorIfNeeded(
+              std::make_exception_ptr(*futureMessage.error()));
         }
       } else {
         LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: "
-                     << err.what();
+                     << (*futureMessage.error()).what();
       }
     }
   });
46 changes: 24 additions & 22 deletions torch/csrc/distributed/autograd/engine/dist_engine.cpp
@@ -389,29 +389,31 @@ std::shared_ptr<rpc::FutureMessage> DistEngine::runEngineAndAccumulateGradients(
   // future that waits for all gradient accumulation to finish.
   auto accumulateGradFuture = std::make_shared<rpc::FutureMessage>();
 
-  futureGrads->addCallback([autogradContext, outputEdges, accumulateGradFuture, &futureGrads]() {
-    if (futureGrads->hasError()) {
-      // Don't accumulate gradients if we receive an error.
-      // We must add the node information here since DistEngine::execute
-      // waits on accumulateGradFuture and will throw an exception once we
-      // set the error below.
-      std::string errorMsg = c10::str(
-          "Error on Node ",
-          DistAutogradContainer::getInstance().getWorkerId(),
-          ": ",
-          futureGrads->error()->what());
-      accumulateGradFuture->setError(errorMsg);
-      return;
-    }
+  futureGrads->addCallback(
+      [autogradContext, outputEdges, accumulateGradFuture, &futureGrads]() {
+        if (futureGrads->hasError()) {
+          // Don't accumulate gradients if we receive an error.
+          // We must add the node information here since DistEngine::execute
+          // waits on accumulateGradFuture and will throw an exception once we
+          // set the error below.
+          std::string errorMsg = c10::str(
+              "Error on Node ",
+              DistAutogradContainer::getInstance().getWorkerId(),
+              ": ",
+              futureGrads->tryRetrieveErrorMessage());
+          accumulateGradFuture->setError(errorMsg);
+          return;
+        }
 
-    try {
-      const variable_list& grads = futureGrads->constValue().toTensorVector();
-      TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size());
-      accumulateGradFuture->markCompleted(rpc::Message());
-    } catch (std::exception& e) {
-      accumulateGradFuture->setErrorIfNeeded(e.what());
-    }
-  });
+        try {
+          const variable_list& grads =
+              futureGrads->constValue().toTensorVector();
+          TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size());
+          accumulateGradFuture->markCompleted(rpc::Message());
+        } catch (std::exception& e) {
+          accumulateGradFuture->setErrorIfNeeded(e.what());
+        }
+      });
 
   return accumulateGradFuture;
 }
6 changes: 4 additions & 2 deletions torch/csrc/distributed/rpc/python_functions.cpp
@@ -138,7 +138,8 @@ c10::intrusive_ptr<JitFuture> wrapFutureMessageInJitFuture(
       at::wrapPropagateTLSState<void>([jitFuture, wp]() {
         auto futureResponseMessage = wp.lock();
         if (futureResponseMessage->hasError()) {
-          jitFuture->setError(futureResponseMessage->error()->what());
+          jitFuture->setError(
+              std::make_exception_ptr(*futureResponseMessage->error()));
         } else {
           jitFuture->markCompleted(
               toIValue(futureResponseMessage->constValue()));
@@ -154,7 +155,8 @@ c10::intrusive_ptr<JitFuture> wrapFutureMessageInJitFuture(
       at::wrapPropagateTLSState<void>([wp, jitFuture]() {
         auto futureResponseMessage = wp.lock();
         if (futureResponseMessage->hasError()) {
-          jitFuture->setError(futureResponseMessage->error()->what());
+          jitFuture->setError(
+              std::make_exception_ptr(*futureResponseMessage->error()));
         } else {
           jitFuture->markCompleted(IValue());
         }
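Unlike the engine call sites, no exception is in flight here, so the stored error object is wrapped with std::make_exception_ptr, which copies it into a rethrowable pointer. A minimal sketch of the distinction between the two capture routes:

    #include <exception>
    #include <stdexcept>

    void sketch() {
      std::runtime_error stored("rpc failed");

      // No active exception: make_exception_ptr copies the object so it can
      // later be rethrown with its original type.
      std::exception_ptr a = std::make_exception_ptr(stored);

      // current_exception() instead captures the exception currently being
      // handled; outside a catch block it returns a null exception_ptr.
      std::exception_ptr b;
      try {
        throw stored;
      } catch (...) {
        b = std::current_exception();
      }
    }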