Skip to content

Commit 0e428f4

Browse files
author
Meghan Lele
committed
[JIT] Add property support for ScriptModules
**Summary** This commit extends support for properties to include ScriptModules. **Test Plan** This commit adds a unit test that has a ScriptModule with a user-defined property. `python test/test_jit_py3.py TestScriptPy3.test_module_properties` ghstack-source-id: 4d97cd7 Pull Request resolved: #42390
1 parent 06aaf8c commit 0e428f4

File tree

10 files changed

+173
-63
lines changed

10 files changed

+173
-63
lines changed

test/jit/test_recursive_script.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -625,24 +625,6 @@ def forward(self, x):
625625
m = M()
626626
self.checkModule(m, (torch.randn(5, 5), ))
627627

628-
def test_property(self):
629-
class M(nn.Module):
630-
def __init__(self):
631-
super(M, self).__init__()
632-
self.x = 0
633-
634-
@property
635-
def x_and_1(self):
636-
return self.x + 1
637-
638-
def forward(self, new_x):
639-
# type: (int) -> int
640-
self.x = new_x
641-
return self.x_and_1
642-
643-
with self.assertRaisesRegex(RuntimeError, "property"):
644-
torch.jit.script(M())
645-
646628
def test_inner_traced_module(self):
647629
class Dummy(nn.Module):
648630
def forward(self, x):

test/test_jit_py3.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,55 @@ def forward(self, x):
533533
# Check that ignored method is still intact.
534534
self.assertEqual(inp, n.ignored_method(inp))
535535

536+
def test_module_properties(self):
537+
class ModuleWithProperties(torch.nn.Module):
538+
def __init__(self, a: int):
539+
super().__init__()
540+
self.a = a
541+
542+
def forward(self, a: int, b: int):
543+
self.attr = a + b
544+
return self.attr
545+
546+
@property
547+
def attr(self):
548+
return self.a
549+
550+
@torch.jit.ignore
551+
@property
552+
def ignored_attr(self):
553+
return sum([self.a])
554+
555+
@attr.setter
556+
def attr(self, a: int):
557+
if a > 0:
558+
self.a = a
559+
else:
560+
self.a = 0
561+
562+
class ModuleWithNoSetter(torch.nn.Module):
563+
def __init__(self, a: int):
564+
super().__init__()
565+
self.a = a
566+
567+
def forward(self, a: int, b: int):
568+
self.attr + a + b
569+
570+
@property
571+
def attr(self):
572+
return self.a + 1
573+
574+
self.checkModule(ModuleWithProperties(5), (5, 6,))
575+
self.checkModule(ModuleWithProperties(5), (-5, -6,))
576+
self.checkModule(ModuleWithNoSetter(5), (5, 6,))
577+
self.checkModule(ModuleWithNoSetter(5), (-5, -6,))
578+
579+
mod = ModuleWithProperties(3)
580+
scripted_mod = torch.jit.script(mod)
581+
582+
with self.assertRaisesRegex(torch.nn.modules.module.ModuleAttributeError, "has no attribute"):
583+
scripted_mod.ignored_attr
584+
536585
def test_export_opnames_interface(self):
537586
global OneTwoModule
538587

torch/_jit_internal.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,15 @@ def forward(self, x):
463463
fn._torchscript_modifier = FunctionModifiers.IGNORE
464464
return fn
465465

466+
if isinstance(drop, property):
467+
prop = drop
468+
setattr(prop.fget, "_torchscript_modifier", FunctionModifiers.IGNORE)
469+
470+
if prop.fset:
471+
setattr(prop.fset, "_torchscript_modifier", FunctionModifiers.IGNORE)
472+
473+
return prop
474+
466475
if not isinstance(drop, bool):
467476
raise RuntimeError("Argument to @torch.jit.ignore must be a bool or "
468477
"a function but got {}".format(drop))

torch/csrc/jit/python/python_sugared_value.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,14 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
599599
return attr;
600600
}
601601

602+
// Check if it's a property.
603+
auto prop =
604+
concreteType_->getJitType()->expect<ClassType>()->getProperty(field);
605+
if (prop) {
606+
return MethodValue(self_, prop->getter->name())
607+
.call(loc, m, {}, {}, /*n_binders=*/1);
608+
}
609+
602610
// We don't define this attr. Bailout with a hint to the user.
603611
std::string hint;
604612
if (auto failureReason = concreteType_->findFailedAttribute(field)) {

torch/csrc/jit/python/python_tree_views.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <pybind11/pybind11.h>
66
#include <pybind11/stl.h>
7+
#include <torch/csrc/utils/pybind.h>
78

89
#include <sstream>
910

@@ -162,7 +163,18 @@ void initTreeViewBindings(PyObject* module) {
162163
const Def& getter,
163164
Def* setter) {
164165
return Property::create(r, name, getter, wrap_maybe(r, setter));
165-
}));
166+
}))
167+
.def("name", [](const Property& property) { return property.name(); })
168+
.def(
169+
"getter_name",
170+
[](const Property& property) { return property.getter().name(); })
171+
.def("setter_name", [](const Property& property) {
172+
if (property.setter().present()) {
173+
return c10::optional<Ident>(property.setter().get().name());
174+
}
175+
176+
return c10::optional<Ident>(c10::nullopt);
177+
});
166178

167179
py::class_<ClassDef, TreeView>(m, "ClassDef")
168180
.def(py::init([](const Ident& name,

torch/csrc/jit/python/script_init.cpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1569,33 +1569,43 @@ void initJitScriptBindings(PyObject* module) {
15691569
return self.equals(other);
15701570
})
15711571
.def(
1572-
"_create_methods",
1572+
"_create_methods_and_properties",
15731573
[](std::shared_ptr<ConcreteModuleType> concreteType,
1574-
const std::vector<Def>& defs,
1575-
const std::vector<ResolutionCallback>& rcbs,
1574+
const std::vector<Property>& properties,
1575+
const std::vector<ResolutionCallback>& propertyRcbs,
1576+
const std::vector<Def>& methodDefs,
1577+
const std::vector<ResolutionCallback>& methodRcbs,
15761578
const std::vector<FunctionDefaults>& defaults) {
1577-
TORCH_INTERNAL_ASSERT(defs.size() == rcbs.size());
1578-
std::vector<ResolverPtr> resolvers;
1579-
resolvers.reserve(rcbs.size());
1580-
for (auto& callback : rcbs) {
1581-
resolvers.push_back(pythonResolver(callback));
1579+
TORCH_INTERNAL_ASSERT(methodDefs.size() == methodRcbs.size());
1580+
TORCH_INTERNAL_ASSERT(properties.size() == propertyRcbs.size());
1581+
1582+
std::vector<ResolverPtr> methodResolvers, propertyResolvers;
1583+
methodResolvers.reserve(methodRcbs.size());
1584+
for (auto& callback : methodRcbs) {
1585+
methodResolvers.push_back(pythonResolver(callback));
1586+
}
1587+
1588+
propertyResolvers.reserve(propertyRcbs.size());
1589+
for (auto& callback : propertyRcbs) {
1590+
propertyResolvers.push_back(pythonResolver(callback));
15821591
}
1592+
15831593
const auto& selfType =
15841594
concreteType->getJitType()->expect<ClassType>();
15851595
const auto& prefix = selfType->name().value();
15861596
const auto self = ModuleSelf(std::move(concreteType));
15871597
auto cu = selfType->compilation_unit();
15881598
cu->define(
15891599
prefix,
1590-
/*properties=*/{},
1591-
/*propResolvers=*/{},
1592-
defs,
1593-
resolvers,
1600+
properties,
1601+
propertyResolvers,
1602+
methodDefs,
1603+
methodResolvers,
15941604
&self);
15951605
// Stitch in default arguments for each Def if provided
15961606
auto defaults_it = defaults.begin();
1597-
auto defs_it = defs.begin();
1598-
while (defs_it != defs.end()) {
1607+
auto defs_it = methodDefs.begin();
1608+
while (defs_it != methodDefs.end()) {
15991609
const auto method_name =
16001610
QualifiedName(prefix, (*defs_it).name().name());
16011611
auto& method = cu->get_function(method_name);

torch/jit/_recursive.py

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
import warnings
77

88
import torch._jit_internal as _jit_internal
9-
from torch.jit.frontend import get_default_args, get_jit_def
9+
from torch.jit.frontend import get_default_args, get_jit_def, get_class_properties
1010
from torch.jit._builtins import _find_builtin
1111
from torch.nn import Module
1212
from torch._six import get_function_from_type, bind_method
1313

1414

1515
ScriptMethodStub = collections.namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
16+
PropertyStub = collections.namedtuple('Property', ('resolution_callback', 'def_'))
17+
1618

1719
# TODO: there should be a more principled way of doing this.
1820
ignored_attributes = [
@@ -48,6 +50,7 @@ def make_stub_from_method(nn_module, method_name):
4850
# even though we requested a stub for `forward`.
4951
return make_stub(func, method_name)
5052

53+
5154
# base types that can be constants
5255
# in addition, tuples and lists of these base types are also considered constants
5356
# If you edit this list, then you also need to edit the handlers in
@@ -239,14 +242,6 @@ def infer_type(name, item):
239242
"to a TorchScript type.)").format(torch.typename(type(value)))
240243
concrete_type_builder.add_failed_attribute(name, hint)
241244

242-
# Add @property methods as failed attributes, to give a better error message.
243-
for name, value in type(nn_module).__dict__.items():
244-
if isinstance(value, property):
245-
hint = ("\n(This attribute exists on the Python module, but it's an @property "
246-
"method. @property methods are not yet supported in TorchScript. "
247-
"Please file a feature request on Github)")
248-
concrete_type_builder.add_failed_attribute(name, hint)
249-
250245
return concrete_type_builder
251246

252247
class ConcreteTypeStore(object):
@@ -285,11 +280,17 @@ def get_or_create_concrete_type(self, nn_module):
285280

286281
concrete_type_store = ConcreteTypeStore()
287282

288-
def create_methods_from_stubs(concrete_type, stubs):
289-
defs = [m.def_ for m in stubs]
290-
rcbs = [m.resolution_callback for m in stubs]
291-
defaults = [get_default_args(m.original_method) for m in stubs]
292-
concrete_type._create_methods(defs, rcbs, defaults)
283+
284+
def create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs):
285+
method_defs = [m.def_ for m in method_stubs]
286+
method_rcbs = [m.resolution_callback for m in method_stubs]
287+
method_defaults = [get_default_args(m.original_method) for m in method_stubs]
288+
289+
property_defs = [p.def_ for p in property_stubs]
290+
property_rcbs = [p.resolution_callback for p in property_stubs]
291+
292+
concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
293+
293294

294295
def create_script_module(nn_module, stubs_fn, share_types=True):
295296
"""
@@ -326,7 +327,8 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
326327
stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
327328
"""
328329
cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
329-
stubs = stubs_fn(nn_module)
330+
method_stubs = stubs_fn(nn_module)
331+
property_stubs = get_property_stubs(nn_module)
330332

331333
def init_fn(script_module):
332334
# Initialize the ScriptModule:
@@ -354,13 +356,11 @@ def init_fn(script_module):
354356
cpp_module.setattr(name, scripted)
355357
script_module._modules[name] = scripted
356358

357-
# 3. Copy @ignored/@unused methods from the original `nn_module` to the new ScriptModule.
359+
# 3. Copy @ignored/@unused methods and properties from the original `nn_module` to the new ScriptModule.
358360
# This ensures we can access these Python methods on the ScriptModule.
359361
for name in dir(nn_module):
360362
item = getattr(nn_module, name, None)
361-
if not inspect.ismethod(item):
362-
continue
363-
if _jit_internal.is_ignored_fn(item):
363+
if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
364364
unbound_function = getattr(type(nn_module), name)
365365
bound_method = unbound_function.__get__(script_module)
366366
setattr(script_module, name, bound_method)
@@ -373,7 +373,7 @@ def init_fn(script_module):
373373

374374
# Compile methods if necessary
375375
if concrete_type not in concrete_type_store.methods_compiled:
376-
create_methods_from_stubs(concrete_type, stubs)
376+
create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
377377
torch._C._run_emit_module_hook(cpp_module)
378378
concrete_type_store.methods_compiled.add(concrete_type)
379379

@@ -391,14 +391,14 @@ def init_fn(script_module):
391391

392392

393393
# Make the compiled methods available to the Python ScriptModule class.
394-
for stub in stubs:
395-
if stub.original_method is None:
394+
for method_stub in method_stubs:
395+
if method_stub.original_method is None:
396396
# define()'d methods don't have an Python original_method, so we
397397
# don't need to do any Python re-wrapping stuff
398398
continue
399399

400-
name = stub.original_method.__name__
401-
if name != stub.def_.name().name:
400+
name = method_stub.original_method.__name__
401+
if name != method_stub.def_.name().name:
402402
# TODO: Why skip this? Because @torch.jit._overload_method will
403403
# mangle the name of the function.
404404
continue
@@ -407,14 +407,23 @@ def init_fn(script_module):
407407
# Wrap the original to propagate docstrings and such.
408408
# TODO: we don't currently do this functions that are recursively
409409
# compiled, we should.
410-
script_method = functools.wraps(stub.original_method)(script_method)
410+
script_method = functools.wraps(method_stub.original_method)(script_method)
411411

412412
# Add the methods to the script_module directly. This ensures they will
413413
# be found first when `name` is looked up (as opposed to the stubs or
414414
# nn.Module.forward)
415415
script_module.__dict__[name] = script_method
416416

417417

418+
# Make module properties available on the Python ScriptModule class.
419+
for property_stub in property_stubs:
420+
property_name = property_stub.def_.name().name
421+
fget = cpp_module._get_method(property_stub.def_.getter_name().name)
422+
# Setter is optional, so it may not exist.
423+
setter_name = property_stub.def_.setter_name()
424+
fset = cpp_module._get_method(setter_name.name) if setter_name else None
425+
script_module.__dict__[property_name] = property(property_name, fget, fset)
426+
418427
# copy over python methods to script module if they aren't defined on the script module
419428
# this is currently an internal api used only on module containers
420429
for name in dir(nn_module):
@@ -548,6 +557,28 @@ def ignore_overloaded(method_name):
548557
stubs.append(make_stub_from_method(nn_module, method))
549558
return overload_stubs + stubs
550559

560+
561+
def get_property_stubs(nn_module):
562+
"""
563+
Create property stubs for the properties of the module by creating method
564+
stubs for the getter and setter.
565+
"""
566+
module_ty = type(nn_module)
567+
properties_asts = get_class_properties(module_ty, self_name="RecursiveScriptModule")
568+
rcbs = {}
569+
570+
for name in dir(module_ty):
571+
item = getattr(module_ty, name, None)
572+
if isinstance(item, property):
573+
if not item.fget:
574+
raise RuntimeError(f'Property {name} of {nn_module.__name__} must have a getter')
575+
576+
rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget)
577+
578+
stubs = [PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts]
579+
return stubs
580+
581+
551582
def interface_script(mod_interface, nn_module):
552583
"""
553584
Makes a ScriptModule from an nn.Module, using the interface methods rule for
@@ -612,7 +643,7 @@ def compile_unbound_method(concrete_type, fn):
612643
with torch._jit_internal._disable_emit_hooks():
613644
# We don't want to call the hooks here since the graph that is calling
614645
# this function is not yet complete
615-
create_methods_from_stubs(concrete_type, (stub,))
646+
create_methods_and_properties_from_stubs(concrete_type, (stub,), ())
616647
return stub
617648

618649
def lazy_bind(concrete_type, unbound_method):

0 commit comments

Comments
 (0)