Skip to content

Commit a0a74f0

Browse files
committed
[jit] fix segfault in attribute lookup on loaded ScriptModules
The IR emitter looks for attributes on modules like: 1. Check the JIT type for the attribute 2. Check the originating Python class, in order to fulfill requests for, e.g. static methods or ignored methods. In the case where you do: ``` inner_module = torch.jit.load("inner.pt") wrapped = Wrapper(inner_module) # wrap the loaded ScriptModule in an nn.Module torch.jit.script(wrapped) ``` The IR emitter may check for attributes on `inner_module`. There is no originating Python class for `inner_module`, since it was directly compiled from the serialized format. Due to a bug in the code, we don't guard for this case an a segfault results if the wrapper asks for an undefined attribute. The lookup in this case looks like: 1. Check the JIT type for the attribute (not there!) 2. Check the originating Python class (this is a nullptr! segfault!) This PR guards this case and properly just raises an attribute missing compiler error instead of segfaulting. ghstack-source-id: a27e4a4 Pull Request resolved: #43284
1 parent d467ac8 commit a0a74f0

File tree

4 files changed

+44
-11
lines changed

4 files changed

+44
-11
lines changed

test/test_jit.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15158,6 +15158,30 @@ def parameter_script(x: torch.nn.Parameter):
1515815158
input = torch.ones(2, 2)
1515915159
self.assertEqual(input, parameter_script(input))
1516015160

15161+
def test_save_load_attr_error(self):
15162+
class Inner(nn.Module):
15163+
def __init__(self):
15164+
super().__init__()
15165+
15166+
def forward(self, x):
15167+
return x
15168+
15169+
class Wrapper(nn.Module):
15170+
def __init__(self, inner):
15171+
super().__init__()
15172+
self.inner = inner
15173+
15174+
def forward(self, x):
15175+
# this attribute doesn't exist on `Inner`
15176+
return self.inner.b(x)
15177+
15178+
inner_module = torch.jit.script(Inner())
15179+
inner_module = self.getExportImportCopy(inner_module)
15180+
wrapped = Wrapper(inner_module)
15181+
# This should properly complain that `self.inner` doesn't have the attribute `b`
15182+
with self.assertRaisesRegex(RuntimeError, 'has no attribute'):
15183+
torch.jit.script(wrapped)
15184+
1516115185

1516215186
# known to be failing in tracer
1516315187
EXCLUDE_TRACED = {

torch/csrc/jit/frontend/concrete_module_type.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,10 @@ TypePtr ConcreteModuleType::getJitType() const {
128128
return jitType_;
129129
}
130130

131-
py::object ConcreteModuleType::getPyClass() const {
131+
c10::optional<py::object> ConcreteModuleType::getPyClass() const {
132+
if (!data_.pyClass_) {
133+
return c10::nullopt;
134+
}
132135
return data_.pyClass_;
133136
}
134137

torch/csrc/jit/frontend/concrete_module_type.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class ConcreteModuleType;
5858
class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
5959
public:
6060
explicit ConcreteModuleTypeBuilder(py::object pyClass) {
61+
TORCH_INTERNAL_ASSERT(pyClass);
6162
pyClass_ = std::move(pyClass);
6263
}
6364
void addConstant(std::string name, py::object value);
@@ -192,7 +193,7 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
192193
static std::shared_ptr<ConcreteModuleType> fromJitType(TypePtr type);
193194

194195
TypePtr getJitType() const;
195-
py::object getPyClass() const;
196+
c10::optional<py::object> getPyClass() const;
196197
IterableModuleKind getIterableModuleKind() const;
197198
c10::optional<std::vector<std::string>> findOverloads(
198199
const std::string& name) const;

torch/csrc/jit/python/python_sugared_value.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -535,21 +535,26 @@ std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
535535
// 5. Check if it's an attribute of the original Python class that this
536536
// ScriptModule was derived from. The only class attributes we handle are
537537
// methods.
538+
const auto maybePyClass = concreteType_->getPyClass();
539+
if (!maybePyClass) {
540+
// ConcreteType doesn't always have an originating Python class, e.g. if it
541+
// was derived from a serialized ScriptModule. In this case, we've exhausted
542+
// our options for attr lookup.
543+
return nullptr;
544+
}
538545
py::object unboundMethod = py::getattr(
539-
concreteType_->getPyClass(),
540-
field.c_str(),
541-
pybind11::cast<pybind11::none>(Py_None));
546+
*maybePyClass, field.c_str(), pybind11::cast<pybind11::none>(Py_None));
542547

543548
if (py::isinstance<py::function>(unboundMethod)) {
544-
bool isStaticFn = py::cast<bool>(
545-
py::module::import("torch._jit_internal")
546-
.attr("is_static_fn")(concreteType_->getPyClass(), field.c_str()));
549+
bool isStaticFn =
550+
py::cast<bool>(py::module::import("torch._jit_internal")
551+
.attr("is_static_fn")(*maybePyClass, field.c_str()));
547552
if (isStaticFn) {
548553
// Functions within the module annotated with @staticmethod do not need
549554
// binding.
550-
py::object staticFn = py::module::import("torch._jit_internal")
551-
.attr("get_static_fn")(
552-
concreteType_->getPyClass(), field.c_str());
555+
py::object staticFn =
556+
py::module::import("torch._jit_internal")
557+
.attr("get_static_fn")(*maybePyClass, field.c_str());
553558
return toSugaredValue(staticFn, m, loc);
554559
}
555560
// For Python methods that we're trying to call directly, we need to bind

0 commit comments

Comments
 (0)