Skip to content

Commit f596d5b

Browse files
committed
[jit] Better match behavior of loaded ScriptModules vs. freshly created ones
IR emitter uses `ModuleValue` to represent ScriptModules and emit IR for attribute access, submodule access, etc. `ModuleValue` relies on two pieces of information, the JIT type of the module, and the `ConcreteModuleType`, which encapsulates Python-only information about the module. ScriptModules loaded from a package used to create a dummy ConcreteModuleType without any info in it. This led to divergences in behavior during compilation. This PR makes the two ways of constructing a ConcreteModuleType equivalent, modulo any py-only information (which, by definition, is never present in packaged files anyway). ghstack-source-id: 6c20331 Pull Request resolved: #43298
1 parent acbc573 commit f596d5b

File tree

5 files changed

+106
-36
lines changed

5 files changed

+106
-36
lines changed

test/test_jit.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15192,6 +15192,47 @@ def forward(self, x):
1519215192
with self.assertRaisesRegex(RuntimeError, 'has no attribute'):
1519315193
torch.jit.script(wrapped)
1519415194

15195+
def test_rescripting_loaded_modules(self):
15196+
class InnerSubmod(nn.Module):
15197+
my_constant: torch.jit.Final[int]
15198+
15199+
def __init__(self):
15200+
super().__init__()
15201+
self.register_buffer("foo", torch.ones(1))
15202+
self.register_parameter("bar", torch.nn.Parameter(torch.ones(1)))
15203+
self.baz = torch.ones(1)
15204+
self.my_constant = 1
15205+
15206+
def forward(self, x):
15207+
return x + x
15208+
15209+
class Inner(nn.Module):
15210+
def __init__(self):
15211+
super().__init__()
15212+
self.submod = InnerSubmod()
15213+
15214+
def forward(self, x):
15215+
return self.submod(x)
15216+
15217+
class Wrapper(nn.Module):
15218+
def __init__(self, inner):
15219+
super().__init__()
15220+
self.inner = inner
15221+
15222+
def forward(self, x):
15223+
# access inner elements
15224+
ret = self.inner.submod(x) + self.inner.submod.foo + self.inner.submod.bar + self.inner.submod.baz
15225+
ret = ret + self.inner.submod.my_constant
15226+
return ret
15227+
15228+
inner_module = torch.jit.script(Inner())
15229+
wrapped = Wrapper(inner_module)
15230+
self.checkModule(wrapped, torch.ones(1))
15231+
15232+
inner_module_loaded = self.getExportImportCopy(inner_module)
15233+
wrapped_loaded = Wrapper(inner_module_loaded)
15234+
self.assertEqual(wrapped(torch.ones(1)), wrapped_loaded(torch.ones(1)))
15235+
1519515236

1519615237
# known to be failing in tracer
1519715238
EXCLUDE_TRACED = {

torch/csrc/jit/frontend/concrete_module_type.cpp

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,7 @@ ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
2929
}
3030

3131
for (const auto& pr : constants_) {
32-
const auto& name = pr.first;
33-
const auto& val = pr.second.v_;
34-
auto match = tryToInferType(val);
35-
if (!match.success()) {
36-
TORCH_INTERNAL_ASSERT(
37-
false,
38-
"We need to infer the type of constant to convert the python value to IValue, but failed to infer type of ",
39-
py::str(val),
40-
"\n:",
41-
match.reason());
42-
}
43-
// Validation and conversion to make sure `val` is a valid constant
44-
// is done in python, see `torch/jit/_recursive.py`
45-
cls->addConstant(name, toIValue(val, match.type()));
32+
cls->addConstant(pr.first, pr.second);
4633
}
4734

4835
for (const auto& moduleInfo : modules_) {
@@ -57,15 +44,43 @@ ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
5744

5845
std::shared_ptr<ConcreteModuleType> ConcreteModuleType::fromJitType(
5946
TypePtr type) {
47+
ConcreteModuleTypeBuilder builder;
48+
builder.setPoisoned();
49+
6050
// `type` should either be a module interface or a class type
6151
if (auto interface = type->cast<InterfaceType>()) {
6252
TORCH_INTERNAL_ASSERT(interface->is_module());
6353
} else {
64-
TORCH_INTERNAL_ASSERT(type->cast<ClassType>());
54+
const auto classType = type->expect<ClassType>();
55+
56+
// Populate the builder metadata from the JIT type. This is to ensure
57+
// ConcreteModuleTypes produced from Python and ones produced from a JIT
58+
// type directly behave the same to the rest of the system.
59+
for (size_t i = 0; i < classType->numAttributes(); i++) {
60+
const auto& attrName = classType->getAttributeName(i);
61+
const auto& attrType = classType->getAttribute(i);
62+
if (attrType->is_module()) {
63+
builder.addModule(attrName, ConcreteModuleType::fromJitType(attrType));
64+
} else {
65+
builder.addAttribute(
66+
attrName,
67+
attrType,
68+
classType->is_parameter(i),
69+
classType->is_buffer(i));
70+
}
71+
}
72+
73+
for (size_t i = 0; i < classType->numConstants(); i++) {
74+
builder.addConstant(
75+
classType->getConstantName(i), classType->getConstant(i));
76+
}
6577
}
78+
79+
// Not make_shared because the constructor is private.
6680
auto ret = std::shared_ptr<ConcreteModuleType>(new ConcreteModuleType());
6781
ret->jitType_ = std::move(type);
68-
ret->data_.setPoisoned();
82+
ret->data_ = builder;
83+
6984
return ret;
7085
}
7186

@@ -198,6 +213,20 @@ void ConcreteModuleTypeBuilder::setPoisoned() {
198213
void ConcreteModuleTypeBuilder::addConstant(
199214
std::string name,
200215
py::object value) {
216+
auto match = tryToInferType(value);
217+
if (!match.success()) {
218+
TORCH_INTERNAL_ASSERT(
219+
false,
220+
"We need to infer the type of constant to convert the python value to IValue,"
221+
" but failed to infer type of ",
222+
py::str(value),
223+
"\n:",
224+
match.reason());
225+
}
226+
constants_.emplace(std::move(name), toIValue(value, match.type()));
227+
}
228+
229+
void ConcreteModuleTypeBuilder::addConstant(std::string name, IValue value) {
201230
constants_.emplace(std::move(name), std::move(value));
202231
}
203232

@@ -257,7 +286,7 @@ void ConcreteModuleType::dump() const {
257286
<< py::getattr(data_.pyClass_, "__name__") << "\n";
258287
std::cout << "Constants: \n";
259288
for (const auto& pr : data_.constants_) {
260-
std::cout << "\t" << pr.first << ": " << pr.second.v_ << "\n";
289+
std::cout << "\t" << pr.first << ": " << pr.second << "\n";
261290
}
262291
std::cout << "\nAttributes: \n";
263292
for (const auto& pr : data_.attributes_) {
@@ -286,7 +315,7 @@ std::unordered_map<std::string, py::object> ConcreteModuleType::getConstantsPy()
286315
// need to bind ConcreteModuleType::Constant as well.
287316
std::unordered_map<std::string, py::object> ret;
288317
for (const auto& pr : data_.constants_) {
289-
ret.emplace(pr.first, pr.second.v_);
318+
ret.emplace(pr.first, toPyObject(pr.second));
290319
}
291320
return ret;
292321
}

torch/csrc/jit/frontend/concrete_module_type.h

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
6161
TORCH_INTERNAL_ASSERT(pyClass);
6262
pyClass_ = std::move(pyClass);
6363
}
64+
6465
void addConstant(std::string name, py::object value);
66+
void addConstant(std::string name, IValue value);
6567
void addAttribute(
6668
std::string name,
6769
TypePtr type,
@@ -94,19 +96,6 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
9496
// implements a meaningful comparison in that context.
9597
bool equals(const ConcreteModuleTypeBuilder& other) const;
9698

97-
struct Constant {
98-
/* implicit */ Constant(py::object v) : v_(std::move(v)) {}
99-
friend bool operator==(const Constant& lhs, const Constant& rhs) {
100-
// Perform the equivalent of `lhs == rhs` in Python.
101-
int rv = PyObject_RichCompareBool(lhs.v_.ptr(), rhs.v_.ptr(), Py_EQ);
102-
if (rv == -1) {
103-
throw py::error_already_set();
104-
}
105-
return rv == 1;
106-
}
107-
py::object v_;
108-
};
109-
11099
struct FunctionAttribute {
111100
FunctionTypePtr function_;
112101
py::object pyFunction_;
@@ -153,7 +142,7 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
153142
bool isPoisoned_ = false;
154143

155144
// The value of any constants defined by the module.
156-
std::unordered_map<std::string, Constant> constants_;
145+
std::unordered_map<std::string, IValue> constants_;
157146
// The types of any attributes
158147
OrderedDict<std::string, Attribute> attributes_;
159148
// Overloads, in the same format as `__overloads__` in Python

torch/csrc/jit/python/python_sugared_value.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -451,10 +451,15 @@ std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
451451
if (selfType->hasAttribute(field) &&
452452
selfType->getAttribute(field)->is_module()) {
453453
// ...if it's a submodule, return it as a new ModuleValue.
454-
const auto submoduleConcreteType =
455-
concreteType_->findSubmoduleConcreteType(field);
454+
if (const auto submoduleConcreteType =
455+
concreteType_->findSubmoduleConcreteType(field)) {
456+
return std::make_shared<ModuleValue>(
457+
m.graph()->insertGetAttr(self_, field), submoduleConcreteType);
458+
}
459+
456460
return std::make_shared<ModuleValue>(
457-
m.graph()->insertGetAttr(self_, field), submoduleConcreteType);
461+
m.graph()->insertGetAttr(self_, field),
462+
ConcreteModuleType::fromJitType(selfType->getAttribute(field)));
458463
} else if (selfType->hasAttribute(field) || selfType->findMethod(field)) {
459464
// ...otherwise, methods, parameters, attributes, and buffers are all
460465
// first class so they get returned as SimpleValues

torch/csrc/jit/python/script_init.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1533,7 +1533,13 @@ void initJitScriptBindings(PyObject* module) {
15331533
std::shared_ptr<ConcreteModuleTypeBuilder>>(
15341534
m, "ConcreteModuleTypeBuilder")
15351535
.def(py::init<py::object>())
1536-
.def("add_constant", &ConcreteModuleTypeBuilder::addConstant)
1536+
.def(
1537+
"add_constant",
1538+
[](ConcreteModuleTypeBuilder& self,
1539+
std::string name,
1540+
py::object value) {
1541+
self.addConstant(std::move(name), std::move(value));
1542+
})
15371543
.def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute)
15381544
.def(
15391545
"add_function_attribute",

0 commit comments

Comments
 (0)