Skip to content

Commit ba03087

Browse files
committed
Preserve python backtrace in autograd engine errors.
This PR attempts to address #42560 by capturing the appropriate exception_ptr in the autograd engine and passing it over to the Future. As part of this change, there is a significant change the Future API where we now only accept an exception_ptr as part of setError. For the example in #42560, the exception trace would now look like: ``` > Traceback (most recent call last): > File "test_autograd.py", line 6914, in test_preserve_backtrace > Foo.apply(t).sum().backward() > File "torch/tensor.py", line 214, in backward > torch.autograd.backward(self, gradient, retain_graph, create_graph) > File "torch/autograd/__init__.py", line 127, in backward > allow_unreachable=True) # allow_unreachable flag > File "torch/autograd/function.py", line 87, in apply > return self._forward_cls.backward(self, *args) > File "test_autograd.py", line 6910, in backward > raise ValueError("something") > ValueError: something ``` Differential Revision: [D23365408](https://our.internmc.facebook.com/intern/diff/D23365408/) ghstack-source-id: 110820151 Pull Request resolved: #43684
1 parent 77a2ae6 commit ba03087

File tree

16 files changed

+556
-429
lines changed

16 files changed

+556
-429
lines changed

aten/src/ATen/core/ivalue.cpp

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@ namespace ivalue {
2222
// is declared in jit_type.h
2323
void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) {
2424
// NB: doing pointer comparison here
25-
// If in the future there ever arises a need to call operator== on custom class
26-
// Type's, this needs to be changed!
27-
TORCH_CHECK(actual_type == expected_type,
28-
"Tried to convert an IValue of type ",
29-
actual_type->repr_str(),
30-
" to custom class type ",
31-
expected_type->repr_str());
25+
// If in the future there ever arises a need to call operator== on custom
26+
// class Type's, this needs to be changed!
27+
TORCH_CHECK(
28+
actual_type == expected_type,
29+
"Tried to convert an IValue of type ",
30+
actual_type->repr_str(),
31+
" to custom class type ",
32+
expected_type->repr_str());
3233
}
3334

3435
CAFFE2_API c10::intrusive_ptr<ConstantString> ConstantString::create(
@@ -117,7 +118,7 @@ TypePtr IValue::type() const {
117118
TORCH_INTERNAL_ASSERT(false, "unhandled case in IValue::type()");
118119
}
119120

120-
void IValue::visit(const std::function<bool (const IValue &)>& visitor) const {
121+
void IValue::visit(const std::function<bool(const IValue&)>& visitor) const {
121122
if (visitor(*this)) {
122123
// Short cut.
123124
return;
@@ -146,15 +147,15 @@ void IValue::visit(const std::function<bool (const IValue &)>& visitor) const {
146147
auto obj_type = type()->expect<ClassType>();
147148
auto obj_value = toObject();
148149
auto attributes = obj_type->getAttributes();
149-
for (const auto& attr: attributes) {
150+
for (const auto& attr : attributes) {
150151
auto attribute = obj_value->getAttr(attr.getName());
151152
attribute.visit(visitor);
152153
}
153154
break;
154155
}
155156
default:
156157
break;
157-
}
158+
}
158159
}
159160

160161
void IValue::getSubValues(HashAliasedIValues& subValues) const {
@@ -189,7 +190,7 @@ void IValue::getSubValues(HashAliasedIValues& subValues) const {
189190
auto obj_type = type()->expect<ClassType>();
190191
auto obj_value = toObject();
191192
auto attributes = obj_type->getAttributes();
192-
for (const auto& attr: attributes) {
193+
for (const auto& attr : attributes) {
193194
auto attribute = obj_value->getAttr(attr.getName());
194195
attribute.getSubValues(subValues);
195196
}
@@ -381,7 +382,7 @@ std::ostream& printDict(
381382
out << "}";
382383
return out;
383384
}
384-
}
385+
} // namespace
385386

386387
// Properly disambiguate the type of an empty dict
387388
std::ostream& printMaybeAnnotatedDict(
@@ -401,9 +402,9 @@ std::ostream& printMaybeAnnotatedDict(
401402

402403
std::ostream& IValue::repr(
403404
std::ostream& out,
404-
std::function<bool(std::ostream&, const IValue& v)>
405-
customFormatter) const {
406-
// First check if the caller has provided a custom formatter. Use that if possible.
405+
std::function<bool(std::ostream&, const IValue& v)> customFormatter) const {
406+
// First check if the caller has provided a custom formatter. Use that if
407+
// possible.
407408
if (customFormatter(out, *this)) {
408409
return out;
409410
}
@@ -419,7 +420,7 @@ std::ostream& IValue::repr(
419420
case IValue::Tag::Double: {
420421
double d = v.toDouble();
421422
int c = std::fpclassify(d);
422-
if ((c == FP_NORMAL || c == FP_ZERO ) && std::abs(d) < 1e10) {
423+
if ((c == FP_NORMAL || c == FP_ZERO) && std::abs(d) < 1e10) {
423424
int64_t i = int64_t(d);
424425
if (double(i) == d) {
425426
return out << i << ".";
@@ -455,8 +456,8 @@ std::ostream& IValue::repr(
455456
return printMaybeAnnotatedDict(out, v, formatter);
456457
case IValue::Tag::Enum: {
457458
auto enum_holder = v.toEnumHolder();
458-
return out << enum_holder->qualifiedClassName() << "." <<
459-
enum_holder->name();
459+
return out << enum_holder->qualifiedClassName() << "."
460+
<< enum_holder->name();
460461
}
461462
default:
462463
TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind());
@@ -468,11 +469,9 @@ std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
468469
return out;
469470
}
470471

471-
std::ostream& operator<<(std::ostream & out, const IValue & v) {
472-
auto formatter = [&](std::ostream& out, const IValue& v) {
473-
out << v;
474-
};
475-
switch(v.tag) {
472+
std::ostream& operator<<(std::ostream& out, const IValue& v) {
473+
auto formatter = [&](std::ostream& out, const IValue& v) { out << v; };
474+
switch (v.tag) {
476475
case IValue::Tag::None:
477476
return out << v.toNone();
478477
case IValue::Tag::Tensor:
@@ -487,11 +486,10 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
487486
}
488487
}
489488
auto orig_prec = out.precision();
490-
return out
491-
<< std::setprecision(std::numeric_limits<double>::max_digits10)
492-
<< v.toDouble()
493-
<< std::setprecision(orig_prec);
494-
} case IValue::Tag::Int:
489+
return out << std::setprecision(std::numeric_limits<double>::max_digits10)
490+
<< v.toDouble() << std::setprecision(orig_prec);
491+
}
492+
case IValue::Tag::Int:
495493
return out << v.toInt();
496494
case IValue::Tag::Bool:
497495
return out << (v.toBool() ? "True" : "False");
@@ -534,10 +532,9 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
534532
}
535533
case IValue::Tag::Enum: {
536534
auto enum_holder = v.toEnumHolder();
537-
return out << "Enum<" << enum_holder->unqualifiedClassName() << "." <<
538-
enum_holder->name() << ">";
535+
return out << "Enum<" << enum_holder->unqualifiedClassName() << "."
536+
<< enum_holder->name() << ">";
539537
}
540-
541538
}
542539
AT_ERROR("Tag not found: ", v.tagKind());
543540
}
@@ -557,13 +554,12 @@ IValue IValue::deepcopy() const {
557554
return deepcopy(memo);
558555
}
559556

560-
IValue IValue::deepcopy(
561-
IValue::HashAliasedIValueMap& memo) const {
557+
IValue IValue::deepcopy(IValue::HashAliasedIValueMap& memo) const {
562558
if (memo.count(*this)) {
563559
return memo.at(*this);
564560
}
565561
IValue copy;
566-
switch(tag) {
562+
switch (tag) {
567563
case IValue::Tag::Tensor:
568564
copy = IValue(toTensor().clone());
569565
break;
@@ -573,26 +569,25 @@ IValue IValue::deepcopy(
573569
copied_tuple.push_back(e.deepcopy(memo));
574570
}
575571
copy = IValue(ivalue::Tuple::create(copied_tuple));
576-
}
577-
break;
572+
} break;
578573
case IValue::Tag::GenericList: {
579574
auto list = toList();
580575
auto copied_list = c10::impl::GenericList(list.elementType());
581576
for (IValue v : list) {
582577
copied_list.push_back(v.deepcopy(memo));
583578
}
584579
copy = IValue(copied_list);
585-
}
586-
break;
580+
} break;
587581
case IValue::Tag::GenericDict: {
588582
auto dict = toGenericDict();
589-
auto copied_dict = c10::impl::GenericDict(dict.keyType(), dict.valueType());
583+
auto copied_dict =
584+
c10::impl::GenericDict(dict.keyType(), dict.valueType());
590585
for (const auto& entry : dict) {
591-
copied_dict.insert(entry.key().deepcopy(memo), entry.value().deepcopy(memo));
586+
copied_dict.insert(
587+
entry.key().deepcopy(memo), entry.value().deepcopy(memo));
592588
}
593589
copy = IValue(copied_dict);
594-
}
595-
break;
590+
} break;
596591
case IValue::Tag::Object: {
597592
auto class_type = type()->expect<ClassType>();
598593
if (class_type->hasMethod("__getstate__") &&
@@ -653,7 +648,8 @@ void ivalue::Object::resizeObject(size_t slot) {
653648
}
654649

655650
c10::intrusive_ptr<ivalue::Object> ivalue::Object::copy() const {
656-
auto object = ivalue::Object::create(c10::StrongTypePtr(type_.cu_, type()), type()->numAttributes());
651+
auto object = ivalue::Object::create(
652+
c10::StrongTypePtr(type_.cu_, type()), type()->numAttributes());
657653
for (auto i = 0; i < slots_.size(); ++i) {
658654
object->setSlot(i, slots_[i]);
659655
}
@@ -665,8 +661,10 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy() const {
665661
return deepcopy(memo);
666662
}
667663

668-
c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(IValue::HashAliasedIValueMap& memo) const {
669-
auto object = ivalue::Object::create(c10::StrongTypePtr(type_.cu_, type()), type()->numAttributes());
664+
c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
665+
IValue::HashAliasedIValueMap& memo) const {
666+
auto object = ivalue::Object::create(
667+
c10::StrongTypePtr(type_.cu_, type()), type()->numAttributes());
670668
for (size_t i = 0; i < slots_.size(); ++i) {
671669
if (slots_[i].type() == c10::CapsuleType::get()) {
672670
// If we've gotten here, it means that we have *not* copied this
@@ -679,7 +677,7 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(IValue::HashAliasedI
679677
err << " " << qualname->qualifiedName();
680678
}
681679
err << ". Please define serialization methods via def_pickle() for "
682-
"this class.";
680+
"this class.";
683681
AT_ERROR(err.str());
684682
}
685683
object->setSlot(i, slots_[i].deepcopy(memo));
@@ -696,8 +694,8 @@ StrongTypePtr::StrongTypePtr(
696694
}
697695

698696
ska::flat_hash_map<std::type_index, c10::ClassTypePtr>& getCustomClassTypeMap() {
699-
static ska::flat_hash_map<std::type_index, c10::ClassTypePtr> tmap;
700-
return tmap;
697+
static ska::flat_hash_map<std::type_index, c10::ClassTypePtr> tmap;
698+
return tmap;
701699
}
702700

703701
std::unordered_map<std::string, std::function<PyObject*(void*)>>&
@@ -770,7 +768,7 @@ CAFFE2_API intrusive_ptr<ivalue::Future> collectAny(
770768
ctx->srcFutures =
771769
List<intrusive_ptr<ivalue::Future>>(ctx->srcFutures.elementType());
772770
if (src->hasError()) {
773-
dst->setError(*src->error());
771+
dst->setError(src->exception_ptr());
774772
} else {
775773
dst->markCompleted(src->constValue());
776774
}

0 commit comments

Comments
 (0)