Skip to content

Commit 40fa5b7

Browse files
Meghan Lelexuzhao9
authored andcommitted
[JIT] Add property support for ScriptModules (#42390)
Summary: Pull Request resolved: #42390 **Summary** This commit extends support for properties to include ScriptModules. **Test Plan** This commit adds a unit test that has a ScriptModule with a user-defined property. `python test/test_jit_py3.py TestScriptPy3.test_module_properties` Test Plan: Imported from OSS Reviewed By: eellison, mannatsingh Differential Revision: D22880298 Pulled By: SplitInfinity fbshipit-source-id: 74f6cb80f716084339e2151ca25092b6341a1560
1 parent 1489da5 commit 40fa5b7

File tree

9 files changed

+159
-61
lines changed

9 files changed

+159
-61
lines changed

test/jit/test_recursive_script.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -625,24 +625,6 @@ def forward(self, x):
625625
m = M()
626626
self.checkModule(m, (torch.randn(5, 5), ))
627627

628-
def test_property(self):
629-
class M(nn.Module):
630-
def __init__(self):
631-
super(M, self).__init__()
632-
self.x = 0
633-
634-
@property
635-
def x_and_1(self):
636-
return self.x + 1
637-
638-
def forward(self, new_x):
639-
# type: (int) -> int
640-
self.x = new_x
641-
return self.x_and_1
642-
643-
with self.assertRaisesRegex(RuntimeError, "property"):
644-
torch.jit.script(M())
645-
646628
def test_inner_traced_module(self):
647629
class Dummy(nn.Module):
648630
def forward(self, x):

test/test_jit_py3.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,56 @@ def if_function(inp: torch.Tensor) -> Any:
556556

557557
self.checkScript(if_function, (torch.randn(5),))
558558

559+
def test_module_properties(self):
560+
class ModuleWithProperties(torch.nn.Module):
561+
__ignored_properties__ = ["ignored_attr"]
562+
563+
def __init__(self, a: int):
564+
super().__init__()
565+
self.a = a
566+
567+
def forward(self, a: int, b: int):
568+
self.attr = a + b
569+
return self.attr
570+
571+
@property
572+
def attr(self):
573+
return self.a
574+
575+
@property
576+
def ignored_attr(self):
577+
return sum([self.a])
578+
579+
@attr.setter
580+
def attr(self, a: int):
581+
if a > 0:
582+
self.a = a
583+
else:
584+
self.a = 0
585+
586+
class ModuleWithNoSetter(torch.nn.Module):
587+
def __init__(self, a: int):
588+
super().__init__()
589+
self.a = a
590+
591+
def forward(self, a: int, b: int):
592+
self.attr + a + b
593+
594+
@property
595+
def attr(self):
596+
return self.a + 1
597+
598+
self.checkModule(ModuleWithProperties(5), (5, 6,))
599+
self.checkModule(ModuleWithProperties(5), (-5, -6,))
600+
self.checkModule(ModuleWithNoSetter(5), (5, 6,))
601+
self.checkModule(ModuleWithNoSetter(5), (-5, -6,))
602+
603+
mod = ModuleWithProperties(3)
604+
scripted_mod = torch.jit.script(mod)
605+
606+
with self.assertRaisesRegex(torch.nn.modules.module.ModuleAttributeError, "has no attribute"):
607+
scripted_mod.ignored_attr
608+
559609
def test_export_opnames_interface(self):
560610
global OneTwoModule
561611

torch/csrc/jit/python/python_sugared_value.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,14 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
573573
return attr;
574574
}
575575

576+
// Check if it's a property.
577+
auto prop =
578+
concreteType_->getJitType()->expect<ClassType>()->getProperty(field);
579+
if (prop) {
580+
return MethodValue(self_, prop->getter->name())
581+
.call(loc, m, {}, {}, /*n_binders=*/1);
582+
}
583+
576584
// We don't define this attr. Bailout with a hint to the user.
577585
std::string hint;
578586
if (auto failureReason = concreteType_->findFailedAttribute(field)) {

torch/csrc/jit/python/python_tree_views.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <pybind11/pybind11.h>
66
#include <pybind11/stl.h>
7+
#include <torch/csrc/utils/pybind.h>
78

89
#include <sstream>
910

@@ -168,7 +169,18 @@ void initTreeViewBindings(PyObject* module) {
168169
const Def& getter,
169170
Def* setter) {
170171
return Property::create(r, name, getter, wrap_maybe(r, setter));
171-
}));
172+
}))
173+
.def("name", [](const Property& property) { return property.name(); })
174+
.def(
175+
"getter_name",
176+
[](const Property& property) { return property.getter().name(); })
177+
.def("setter_name", [](const Property& property) {
178+
if (property.setter().present()) {
179+
return c10::optional<Ident>(property.setter().get().name());
180+
}
181+
182+
return c10::optional<Ident>(c10::nullopt);
183+
});
172184

173185
py::class_<ClassDef, TreeView>(m, "ClassDef")
174186
.def(py::init([](const Ident& name,

torch/csrc/jit/python/script_init.cpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,33 +1590,43 @@ void initJitScriptBindings(PyObject* module) {
15901590
return self.equals(other);
15911591
})
15921592
.def(
1593-
"_create_methods",
1593+
"_create_methods_and_properties",
15941594
[](std::shared_ptr<ConcreteModuleType> concreteType,
1595-
const std::vector<Def>& defs,
1596-
const std::vector<ResolutionCallback>& rcbs,
1595+
const std::vector<Property>& properties,
1596+
const std::vector<ResolutionCallback>& propertyRcbs,
1597+
const std::vector<Def>& methodDefs,
1598+
const std::vector<ResolutionCallback>& methodRcbs,
15971599
const std::vector<FunctionDefaults>& defaults) {
1598-
TORCH_INTERNAL_ASSERT(defs.size() == rcbs.size());
1599-
std::vector<ResolverPtr> resolvers;
1600-
resolvers.reserve(rcbs.size());
1601-
for (auto& callback : rcbs) {
1602-
resolvers.push_back(pythonResolver(callback));
1600+
TORCH_INTERNAL_ASSERT(methodDefs.size() == methodRcbs.size());
1601+
TORCH_INTERNAL_ASSERT(properties.size() == propertyRcbs.size());
1602+
1603+
std::vector<ResolverPtr> methodResolvers, propertyResolvers;
1604+
methodResolvers.reserve(methodRcbs.size());
1605+
for (auto& callback : methodRcbs) {
1606+
methodResolvers.push_back(pythonResolver(callback));
1607+
}
1608+
1609+
propertyResolvers.reserve(propertyRcbs.size());
1610+
for (auto& callback : propertyRcbs) {
1611+
propertyResolvers.push_back(pythonResolver(callback));
16031612
}
1613+
16041614
const auto& selfType =
16051615
concreteType->getJitType()->expect<ClassType>();
16061616
const auto& prefix = selfType->name().value();
16071617
const auto self = ModuleSelf(std::move(concreteType));
16081618
auto cu = selfType->compilation_unit();
16091619
cu->define(
16101620
prefix,
1611-
/*properties=*/{},
1612-
/*propResolvers=*/{},
1613-
defs,
1614-
resolvers,
1621+
properties,
1622+
propertyResolvers,
1623+
methodDefs,
1624+
methodResolvers,
16151625
&self);
16161626
// Stitch in default arguments for each Def if provided
16171627
auto defaults_it = defaults.begin();
1618-
auto defs_it = defs.begin();
1619-
while (defs_it != defs.end()) {
1628+
auto defs_it = methodDefs.begin();
1629+
while (defs_it != methodDefs.end()) {
16201630
const auto method_name =
16211631
QualifiedName(prefix, (*defs_it).name().name());
16221632
auto& method = cu->get_function(method_name);

torch/jit/_recursive.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77
from typing import Dict, List, Set, Type
88

99
import torch._jit_internal as _jit_internal
10-
from torch.jit.frontend import get_default_args, get_jit_def
10+
from torch.jit.frontend import get_default_args, get_jit_def, get_class_properties
1111
from torch.jit._builtins import _find_builtin
1212
from torch.nn import Module
1313
from torch._six import get_function_from_type, bind_method
1414

1515

1616
ScriptMethodStub = collections.namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
17+
PropertyStub = collections.namedtuple('Property', ('resolution_callback', 'def_'))
18+
1719

1820
# TODO: there should be a more principled way of doing this.
1921
ignored_attributes = [
@@ -49,6 +51,7 @@ def make_stub_from_method(nn_module, method_name):
4951
# even though we requested a stub for `forward`.
5052
return make_stub(func, method_name)
5153

54+
5255
# base types that can be constants
5356
# in addition, tuples and lists of these base types are also considered constants
5457
# If you edit this list, then you also need to edit the handlers in
@@ -240,14 +243,6 @@ def infer_type(name, item):
240243
"to a TorchScript type.)").format(torch.typename(type(value)))
241244
concrete_type_builder.add_failed_attribute(name, hint)
242245

243-
# Add @property methods as failed attributes, to give a better error message.
244-
for name, value in type(nn_module).__dict__.items():
245-
if isinstance(value, property):
246-
hint = ("\n(This attribute exists on the Python module, but it's an @property "
247-
"method. @property methods are not yet supported in TorchScript. "
248-
"Please file a feature request on Github)")
249-
concrete_type_builder.add_failed_attribute(name, hint)
250-
251246
return concrete_type_builder
252247

253248
class ConcreteTypeStore(object):
@@ -284,11 +279,17 @@ def get_or_create_concrete_type(self, nn_module):
284279

285280
concrete_type_store = ConcreteTypeStore()
286281

287-
def create_methods_from_stubs(concrete_type, stubs):
288-
defs = [m.def_ for m in stubs]
289-
rcbs = [m.resolution_callback for m in stubs]
290-
defaults = [get_default_args(m.original_method) for m in stubs]
291-
concrete_type._create_methods(defs, rcbs, defaults)
282+
283+
def create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs):
284+
method_defs = [m.def_ for m in method_stubs]
285+
method_rcbs = [m.resolution_callback for m in method_stubs]
286+
method_defaults = [get_default_args(m.original_method) for m in method_stubs]
287+
288+
property_defs = [p.def_ for p in property_stubs]
289+
property_rcbs = [p.resolution_callback for p in property_stubs]
290+
291+
concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
292+
292293

293294
def get_module_concrete_type(nn_module, share_types=True):
294295
"""
@@ -347,7 +348,8 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
347348
stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
348349
"""
349350
cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
350-
stubs = stubs_fn(nn_module)
351+
method_stubs = stubs_fn(nn_module)
352+
property_stubs = get_property_stubs(nn_module)
351353

352354
def init_fn(script_module):
353355
# Initialize the ScriptModule:
@@ -379,9 +381,7 @@ def init_fn(script_module):
379381
# This ensures we can access these Python methods on the ScriptModule.
380382
for name in dir(nn_module):
381383
item = getattr(nn_module, name, None)
382-
if not inspect.ismethod(item):
383-
continue
384-
if _jit_internal.is_ignored_fn(item):
384+
if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
385385
unbound_function = getattr(type(nn_module), name)
386386
bound_method = unbound_function.__get__(script_module)
387387
setattr(script_module, name, bound_method)
@@ -394,7 +394,7 @@ def init_fn(script_module):
394394

395395
# Compile methods if necessary
396396
if concrete_type not in concrete_type_store.methods_compiled:
397-
create_methods_from_stubs(concrete_type, stubs)
397+
create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
398398
torch._C._run_emit_module_hook(cpp_module)
399399
concrete_type_store.methods_compiled.add(concrete_type)
400400

@@ -412,14 +412,14 @@ def init_fn(script_module):
412412

413413

414414
# Make the compiled methods available to the Python ScriptModule class.
415-
for stub in stubs:
416-
if stub.original_method is None:
415+
for method_stub in method_stubs:
416+
if method_stub.original_method is None:
417417
# define()'d methods don't have an Python original_method, so we
418418
# don't need to do any Python re-wrapping stuff
419419
continue
420420

421-
name = stub.original_method.__name__
422-
if name != stub.def_.name().name:
421+
name = method_stub.original_method.__name__
422+
if name != method_stub.def_.name().name:
423423
# TODO: Why skip this? Because @torch.jit._overload_method will
424424
# mangle the name of the function.
425425
continue
@@ -428,14 +428,23 @@ def init_fn(script_module):
428428
# Wrap the original to propagate docstrings and such.
429429
# TODO: we don't currently do this functions that are recursively
430430
# compiled, we should.
431-
wrapped_script_method = functools.wraps(stub.original_method)(script_method) # type: ignore
431+
wrapped_script_method = functools.wraps(method_stub.original_method)(script_method) # type: ignore
432432

433433
# Add the methods to the script_module directly. This ensures they will
434434
# be found first when `name` is looked up (as opposed to the stubs or
435435
# nn.Module.forward)
436436
script_module.__dict__[name] = wrapped_script_method
437437

438438

439+
# Make module properties available on the Python ScriptModule class.
440+
for property_stub in property_stubs:
441+
property_name = property_stub.def_.name().name
442+
fget = cpp_module._get_method(property_stub.def_.getter_name().name)
443+
# Setter is optional, so it may not exist.
444+
setter_name = property_stub.def_.setter_name()
445+
fset = cpp_module._get_method(setter_name.name) if setter_name else None
446+
script_module.__dict__[property_name] = property(property_name, fget, fset) # type: ignore
447+
439448
# copy over python methods to script module if they aren't defined on the script module
440449
# this is currently an internal api used only on module containers
441450
for name in dir(nn_module):
@@ -569,6 +578,28 @@ def ignore_overloaded(method_name):
569578
stubs.append(make_stub_from_method(nn_module, method))
570579
return overload_stubs + stubs
571580

581+
582+
def get_property_stubs(nn_module):
583+
"""
584+
Create property stubs for the properties of the module by creating method
585+
stubs for the getter and setter.
586+
"""
587+
module_ty = type(nn_module)
588+
properties_asts = get_class_properties(module_ty, self_name="RecursiveScriptModule")
589+
rcbs = {}
590+
591+
for name in dir(module_ty):
592+
item = getattr(module_ty, name, None)
593+
if isinstance(item, property):
594+
if not item.fget:
595+
raise RuntimeError(f'Property {name} of {nn_module.__name__} must have a getter')
596+
597+
rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget)
598+
599+
stubs = [PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts]
600+
return stubs
601+
602+
572603
def interface_script(mod_interface, nn_module):
573604
"""
574605
Makes a ScriptModule from an nn.Module, using the interface methods rule for
@@ -633,7 +664,7 @@ def compile_unbound_method(concrete_type, fn):
633664
with torch._jit_internal._disable_emit_hooks():
634665
# We don't want to call the hooks here since the graph that is calling
635666
# this function is not yet complete
636-
create_methods_from_stubs(concrete_type, (stub,))
667+
create_methods_and_properties_from_stubs(concrete_type, (stub,), ())
637668
return stub
638669

639670
def lazy_bind(concrete_type, unbound_method):

torch/jit/_script.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ class ScriptModule(with_metaclass(ScriptMeta, Module)): # type: ignore
272272
contain methods, attributes, parameters, and
273273
constants. These can be accessed the same as on a normal ``nn.Module``.
274274
"""
275+
__ignored_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name']
275276

276277
def __init__(self):
277278
super(ScriptModule, self).__init__()

torch/jit/frontend.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,16 @@ def get_class_properties(cls, self_name):
141141
"""
142142
props = inspect.getmembers(
143143
cls, predicate=lambda m: isinstance(m, property))
144+
# Any property that should not compiled must be in this list on the Module.
145+
ignored_properties = getattr(cls, "__ignored_properties__", [])
144146

145147
# Create Property TreeView objects from inspected property objects.
146148
properties = []
147149
for prop in props:
148-
getter = get_jit_def(prop[1].fget, f"__{prop[0]}_getter", self_name=self_name)
149-
setter = get_jit_def(prop[1].fset, f"__{prop[0]}_setter", self_name=self_name) if prop[1].fset else None
150-
properties.append(Property(getter.range(), Ident(getter.range(), prop[0]), getter, setter))
150+
if prop[0] not in ignored_properties:
151+
getter = get_jit_def(prop[1].fget, f"__{prop[0]}_getter", self_name=self_name)
152+
setter = get_jit_def(prop[1].fset, f"__{prop[0]}_setter", self_name=self_name) if prop[1].fset else None
153+
properties.append(Property(getter.range(), Ident(getter.range(), prop[0]), getter, setter))
151154

152155
return properties
153156

torch/nn/modules/rnn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tens
2424
class RNNBase(Module):
2525
__constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
2626
'batch_first', 'dropout', 'bidirectional']
27+
__ignored_properties__ = ['all_weights']
2728

2829
mode: str
2930
input_size: int

0 commit comments

Comments
 (0)