Skip to content

Commit a63397a

Browse files
author
Meghan Lele
committed
[JIT] Add property support for ScriptModules
**Summary** This commit extends support for properties to include ScriptModules. **Test Plan** This commit adds a unit test that has a ScriptModule with a user-defined property. `python test/test_jit_py3.py TestScriptPy3.test_module_properties` ghstack-source-id: 4f44e0e Pull Request resolved: #42390
1 parent 5d7c217 commit a63397a

File tree

6 files changed

+137
-52
lines changed

6 files changed

+137
-52
lines changed

test/jit/test_recursive_script.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -607,24 +607,6 @@ def forward(self, x):
607607
m = M()
608608
self.checkModule(m, (torch.randn(5, 5), ))
609609

610-
def test_property(self):
611-
class M(nn.Module):
612-
def __init__(self):
613-
super(M, self).__init__()
614-
self.x = 0
615-
616-
@property
617-
def x_and_1(self):
618-
return self.x + 1
619-
620-
def forward(self, new_x):
621-
# type: (int) -> int
622-
self.x = new_x
623-
return self.x_and_1
624-
625-
with self.assertRaisesRegex(RuntimeError, "property"):
626-
torch.jit.script(M())
627-
628610
def test_inner_traced_module(self):
629611
class Dummy(nn.Module):
630612
def forward(self, x):

test/test_jit_py3.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,30 @@ def forward(self, x):
533533
# Check that ignored method is still intact.
534534
self.assertEqual(inp, n.ignored_method(inp))
535535

536+
def test_module_properties(self):
537+
class ModuleWithProperties(torch.nn.Module):
538+
def __init__(self, a: int):
539+
super().__init__()
540+
self.a = a
541+
542+
def forward(self, a: int, b: int):
543+
self.attr = a + b
544+
return self.attr
545+
546+
@property
547+
def attr(self):
548+
return self.a
549+
550+
@attr.setter
551+
def attr(self, a: int):
552+
if a > 0:
553+
self.a = a
554+
else:
555+
self.a = 0
556+
557+
self.checkModule(ModuleWithProperties(5), (5, 6,))
558+
self.checkModule(ModuleWithProperties(5), (-5, -6,))
559+
536560
def test_export_opnames_interface(self):
537561
global OneTwoModule
538562

torch/csrc/jit/python/python_sugared_value.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,14 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
589589
return attr;
590590
}
591591

592+
// Check if it's a property.
593+
auto prop =
594+
concreteType_->getJitType()->expect<ClassType>()->getProperty(field);
595+
if (prop) {
596+
return MethodValue(self_, prop->getter->name())
597+
.call(loc, m, {}, {}, /*n_binders=*/1);
598+
}
599+
592600
// We don't define this attr. Bailout with a hint to the user.
593601
std::string hint;
594602
if (auto failureReason = concreteType_->findFailedAttribute(field)) {

torch/csrc/jit/python/python_tree_views.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,18 @@ void initTreeViewBindings(PyObject* module) {
162162
const Def& getter,
163163
Def* setter) {
164164
return Property::create(r, name, getter, wrap_maybe(r, setter));
165-
}));
165+
}))
166+
.def("name", [](const Property& property) { return property.name(); })
167+
.def(
168+
"getter_name",
169+
[](const Property& property) { return property.getter().name(); })
170+
.def("setter_name", [](const Property& property) {
171+
if (property.setter().present()) {
172+
return property.setter().get().name();
173+
}
174+
175+
return Ident::create(property.range(), "");
176+
});
166177

167178
py::class_<ClassDef, TreeView>(m, "ClassDef")
168179
.def(py::init([](const Ident& name,

torch/csrc/jit/python/script_init.cpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,27 +1559,43 @@ void initJitScriptBindings(PyObject* module) {
15591559
return self.equals(other);
15601560
})
15611561
.def(
1562-
"_create_methods",
1562+
"_create_methods_and_properties",
15631563
[](std::shared_ptr<ConcreteModuleType> concreteType,
1564-
const std::vector<Def>& defs,
1565-
const std::vector<ResolutionCallback>& rcbs,
1564+
const std::vector<Property>& properties,
1565+
const std::vector<ResolutionCallback>& propertyRcbs,
1566+
const std::vector<Def>& methodDefs,
1567+
const std::vector<ResolutionCallback>& methodRcbs,
15661568
const std::vector<FunctionDefaults>& defaults) {
1567-
TORCH_INTERNAL_ASSERT(defs.size() == rcbs.size());
1568-
std::vector<ResolverPtr> resolvers;
1569-
resolvers.reserve(rcbs.size());
1570-
for (auto& callback : rcbs) {
1571-
resolvers.push_back(pythonResolver(callback));
1569+
TORCH_INTERNAL_ASSERT(methodDefs.size() == methodRcbs.size());
1570+
TORCH_INTERNAL_ASSERT(properties.size() == propertyRcbs.size());
1571+
1572+
std::vector<ResolverPtr> methodResolvers, propertyResolvers;
1573+
methodResolvers.reserve(methodRcbs.size());
1574+
for (auto& callback : methodRcbs) {
1575+
methodResolvers.push_back(pythonResolver(callback));
1576+
}
1577+
1578+
propertyResolvers.reserve(propertyRcbs.size());
1579+
for (auto& callback : propertyRcbs) {
1580+
propertyResolvers.push_back(pythonResolver(callback));
15721581
}
1582+
15731583
const auto& selfType =
15741584
concreteType->getJitType()->expect<ClassType>();
15751585
const auto& prefix = selfType->name().value();
15761586
const auto self = ModuleSelf(std::move(concreteType));
15771587
auto cu = selfType->compilation_unit();
1578-
cu->define(prefix, defs, resolvers, &self);
1588+
cu->define(
1589+
prefix,
1590+
properties,
1591+
propertyResolvers,
1592+
methodDefs,
1593+
methodResolvers,
1594+
&self);
15791595
// Stitch in default arguments for each Def if provided
15801596
auto defaults_it = defaults.begin();
1581-
auto defs_it = defs.begin();
1582-
while (defs_it != defs.end()) {
1597+
auto defs_it = methodDefs.begin();
1598+
while (defs_it != methodDefs.end()) {
15831599
const auto method_name =
15841600
QualifiedName(prefix, (*defs_it).name().name());
15851601
auto& method = cu->get_function(method_name);

torch/jit/_recursive.py

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
import warnings
77

88
import torch._jit_internal as _jit_internal
9-
from torch.jit.frontend import get_default_args, get_jit_def
9+
from torch.jit.frontend import get_default_args, get_jit_def, get_class_properties
1010
from torch.jit._builtins import _find_builtin
1111
from torch.nn import Module
1212
from torch._six import get_function_from_type, bind_method
1313

1414

1515
ScriptMethodStub = collections.namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
16+
PropertyStub = collections.namedtuple('Property', ('resolution_callback', 'def_'))
17+
1618

1719
# TODO: there should be a more principled way of doing this.
1820
ignored_attributes = [
@@ -29,6 +31,16 @@
2931
"dump_patches",
3032
]
3133

34+
ignored_properties = [
35+
# Temporary fix for RNN module property named 'all_weights' being scripted
36+
"all_weights",
37+
"original_name",
38+
"graph",
39+
"inlined_graph",
40+
"code",
41+
"code_with_constants",
42+
]
43+
3244
def make_stub(func, name):
3345
rcb = _jit_internal.createResolutionCallbackFromClosure(func)
3446
ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
@@ -48,6 +60,7 @@ def make_stub_from_method(nn_module, method_name):
4860
# even though we requested a stub for `forward`.
4961
return make_stub(func, method_name)
5062

63+
5164
# base types that can be constants
5265
# in addition, tuples and lists of these base types are also considered constants
5366
# If you edit this list, then you also need to edit the handlers in
@@ -239,14 +252,6 @@ def infer_type(name, item):
239252
"to a TorchScript type.)").format(torch.typename(type(value)))
240253
concrete_type_builder.add_failed_attribute(name, hint)
241254

242-
# Add @property methods as failed attributes, to give a better error message.
243-
for name, value in type(nn_module).__dict__.items():
244-
if isinstance(value, property):
245-
hint = ("\n(This attribute exists on the Python module, but it's an @property "
246-
"method. @property methods are not yet supported in TorchScript. "
247-
"Please file a feature request on Github)")
248-
concrete_type_builder.add_failed_attribute(name, hint)
249-
250255
return concrete_type_builder
251256

252257
class ConcreteTypeStore(object):
@@ -285,11 +290,17 @@ def get_or_create_concrete_type(self, nn_module):
285290

286291
concrete_type_store = ConcreteTypeStore()
287292

288-
def create_methods_from_stubs(concrete_type, stubs):
289-
defs = [m.def_ for m in stubs]
290-
rcbs = [m.resolution_callback for m in stubs]
291-
defaults = [get_default_args(m.original_method) for m in stubs]
292-
concrete_type._create_methods(defs, rcbs, defaults)
293+
294+
def create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs):
295+
method_defs = [m.def_ for m in method_stubs]
296+
method_rcbs = [m.resolution_callback for m in method_stubs]
297+
method_defaults = [get_default_args(m.original_method) for m in method_stubs]
298+
299+
property_defs = [p.def_ for p in property_stubs]
300+
property_rcbs = [p.resolution_callback for p in property_stubs]
301+
302+
concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
303+
293304

294305
def create_script_module(nn_module, stubs_fn, share_types=True):
295306
"""
@@ -326,7 +337,8 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
326337
stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
327338
"""
328339
cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
329-
stubs = stubs_fn(nn_module)
340+
method_stubs = stubs_fn(nn_module)
341+
property_stubs = get_property_stubs(nn_module)
330342

331343
def init_fn(script_module):
332344
# Initialize the ScriptModule:
@@ -373,7 +385,7 @@ def init_fn(script_module):
373385

374386
# Compile methods if necessary
375387
if concrete_type not in concrete_type_store.methods_compiled:
376-
create_methods_from_stubs(concrete_type, stubs)
388+
create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
377389
torch._C._run_emit_module_hook(cpp_module)
378390
concrete_type_store.methods_compiled.add(concrete_type)
379391

@@ -391,14 +403,14 @@ def init_fn(script_module):
391403

392404

393405
# Make the compiled methods available to the Python ScriptModule class.
394-
for stub in stubs:
395-
if stub.original_method is None:
406+
for method_stub in method_stubs:
407+
if method_stub.original_method is None:
396408
# define()'d methods don't have an Python original_method, so we
397409
# don't need to do any Python re-wrapping stuff
398410
continue
399411

400-
name = stub.original_method.__name__
401-
if name != stub.def_.name().name:
412+
name = method_stub.original_method.__name__
413+
if name != method_stub.def_.name().name:
402414
# TODO: Why skip this? Because @torch.jit._overload_method will
403415
# mangle the name of the function.
404416
continue
@@ -407,14 +419,20 @@ def init_fn(script_module):
407419
# Wrap the original to propagate docstrings and such.
408420
# TODO: we don't currently do this functions that are recursively
409421
# compiled, we should.
410-
script_method = functools.wraps(stub.original_method)(script_method)
422+
script_method = functools.wraps(method_stub.original_method)(script_method)
411423

412424
# Add the methods to the script_module directly. This ensures they will
413425
# be found first when `name` is looked up (as opposed to the stubs or
414426
# nn.Module.forward)
415427
script_module.__dict__[name] = script_method
416428

417429

430+
for property_stub in property_stubs:
431+
property_name = property_stub.def_.name().name
432+
fget = cpp_module._get_method(property_stub.def_.getter_name().name)
433+
fset = cpp_module._get_method(property_stub.def_.setter_name().name)
434+
script_module.__dict__[property_name] = property(property_name, fget, fset)
435+
418436
# copy over python methods to script module if they aren't defined on the script module
419437
# this is currently an internal api used only on module containers
420438
for name in dir(nn_module):
@@ -548,6 +566,32 @@ def ignore_overloaded(method_name):
548566
stubs.append(make_stub_from_method(nn_module, method))
549567
return overload_stubs + stubs
550568

569+
570+
def get_property_stubs(nn_module):
571+
"""
572+
Create property stubs for the properties of the module by creating method
573+
stubs for the getter and setter.
574+
"""
575+
module_ty = type(nn_module)
576+
properties_asts = get_class_properties(module_ty, self_name="RecursiveScriptModule")
577+
rcbs = {}
578+
579+
for name in dir(module_ty):
580+
item = getattr(module_ty, name, None)
581+
if isinstance(item, property) and name not in ignored_properties:
582+
if not item.fget:
583+
raise RuntimeError(f'Property {name} of {nn_module.__name__} must have a getter')
584+
585+
rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget)
586+
587+
stubs = []
588+
for ast in properties_asts:
589+
if ast.name().name not in ignored_properties:
590+
stubs.append(PropertyStub(rcbs[ast.name().name], ast))
591+
592+
return stubs
593+
594+
551595
def interface_script(mod_interface, nn_module):
552596
"""
553597
Makes a ScriptModule from an nn.Module, using the interface methods rule for
@@ -612,7 +656,7 @@ def compile_unbound_method(concrete_type, fn):
612656
with torch._jit_internal._disable_emit_hooks():
613657
# We don't want to call the hooks here since the graph that is calling
614658
# this function is not yet complete
615-
create_methods_from_stubs(concrete_type, (stub,))
659+
create_methods_and_properties_from_stubs(concrete_type, (stub,), ())
616660
return stub
617661

618662
def lazy_bind(concrete_type, unbound_method):

0 commit comments

Comments
 (0)