Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
37ddaa3
[JIT] Add property support for ScriptModules
Jul 31, 2020
d755d14
Update on "[JIT] Add property support for ScriptModules"
Aug 4, 2020
0f29e08
Update on "[JIT] Add property support for ScriptModules"
Aug 5, 2020
20e5794
Update on "[JIT] Add property support for ScriptModules"
Aug 6, 2020
0f51010
Update on "[JIT] Add property support for ScriptModules"
Aug 11, 2020
251e3fb
Update on "[JIT] Add property support for ScriptModules"
Aug 11, 2020
67f2e03
Update on "[JIT] Add property support for ScriptModules"
Aug 11, 2020
09907d2
Update on "[JIT] Add property support for ScriptModules"
Aug 12, 2020
43461f3
Update on "[JIT] Add property support for ScriptModules"
Aug 12, 2020
74cb4cb
Update on "[JIT] Add property support for ScriptModules"
Aug 12, 2020
5d4e583
Update on "[JIT] Add property support for ScriptModules"
Aug 13, 2020
e921537
Update on "[JIT] Add property support for ScriptModules"
Aug 13, 2020
12930f2
Update on "[JIT] Add property support for ScriptModules"
Aug 14, 2020
1a8fcf9
Update on "[JIT] Add property support for ScriptModules"
Aug 14, 2020
a3ca37d
Update on "[JIT] Add property support for ScriptModules"
Aug 14, 2020
80b678b
Update on "[JIT] Add property support for ScriptModules"
Aug 15, 2020
7faa808
Update on "[JIT] Add property support for ScriptModules"
Aug 15, 2020
fbbae46
Update on "[JIT] Add property support for ScriptModules"
Aug 24, 2020
667851c
Update on "[JIT] Add property support for ScriptModules"
Sep 10, 2020
31b98c1
Update on "[JIT] Add property support for ScriptModules"
Sep 10, 2020
36fe6bb
Update on "[JIT] Add property support for ScriptModules"
Sep 10, 2020
c2e2680
Update on "[JIT] Add property support for ScriptModules"
Sep 11, 2020
2bb610e
Update on "[JIT] Add property support for ScriptModules"
Sep 11, 2020
d22e82a
Update on "[JIT] Add property support for ScriptModules"
Sep 11, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions test/jit/test_recursive_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,24 +625,6 @@ def forward(self, x):
m = M()
self.checkModule(m, (torch.randn(5, 5), ))

def test_property(self):
    # At the time this test was written, @property methods on an nn.Module
    # were rejected by the scripting frontend; verify that the error
    # message mentions "property".
    class M(nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.x = 0

        def forward(self, new_x):
            # type: (int) -> int
            self.x = new_x
            return self.x_and_1

        @property
        def x_and_1(self):
            # Derived value: one more than the stored attribute.
            return self.x + 1

    # Scripting a module with a property must fail loudly.
    with self.assertRaisesRegex(RuntimeError, "property"):
        torch.jit.script(M())

def test_inner_traced_module(self):
class Dummy(nn.Module):
def forward(self, x):
Expand Down
50 changes: 50 additions & 0 deletions test/test_jit_py3.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,56 @@ def if_function(inp: torch.Tensor) -> Any:

self.checkScript(if_function, (torch.randn(5),))

def test_module_properties(self):
    # Fixture with a read/write property (`attr`) plus a property that is
    # explicitly excluded from compilation via __ignored_properties__.
    class ModuleWithProperties(torch.nn.Module):
        __ignored_properties__ = ["ignored_attr"]

        def __init__(self, a: int):
            super().__init__()
            self.a = a

        def forward(self, a: int, b: int):
            self.attr = a + b
            return self.attr

        @property
        def attr(self):
            return self.a

        @property
        def ignored_attr(self):
            return sum([self.a])

        @attr.setter
        def attr(self, a: int):
            # Clamp negative values to zero.
            if a > 0:
                self.a = a
            else:
                self.a = 0

    # Fixture with a getter-only (read-only) property.
    class ModuleWithNoSetter(torch.nn.Module):
        def __init__(self, a: int):
            super().__init__()
            self.a = a

        def forward(self, a: int, b: int):
            self.attr + a + b

        @property
        def attr(self):
            return self.a + 1

    # Check scripted vs. eager behavior for positive and negative inputs.
    for fwd_args in [(5, 6,), (-5, -6,)]:
        self.checkModule(ModuleWithProperties(5), fwd_args)
        self.checkModule(ModuleWithNoSetter(5), fwd_args)

    # Ignored properties must not be accessible on the scripted module.
    scripted_mod = torch.jit.script(ModuleWithProperties(3))

    with self.assertRaisesRegex(torch.nn.modules.module.ModuleAttributeError, "has no attribute"):
        scripted_mod.ignored_attr

def test_export_opnames_interface(self):
global OneTwoModule

Expand Down
8 changes: 8 additions & 0 deletions torch/csrc/jit/python/python_sugared_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,14 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
return attr;
}

// Check if it's a property.
auto prop =
concreteType_->getJitType()->expect<ClassType>()->getProperty(field);
if (prop) {
return MethodValue(self_, prop->getter->name())
.call(loc, m, {}, {}, /*n_binders=*/1);
}

// We don't define this attr. Bailout with a hint to the user.
std::string hint;
if (auto failureReason = concreteType_->findFailedAttribute(field)) {
Expand Down
14 changes: 13 additions & 1 deletion torch/csrc/jit/python/python_tree_views.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/csrc/utils/pybind.h>

#include <sstream>

Expand Down Expand Up @@ -168,7 +169,18 @@ void initTreeViewBindings(PyObject* module) {
const Def& getter,
Def* setter) {
return Property::create(r, name, getter, wrap_maybe(r, setter));
}));
}))
.def("name", [](const Property& property) { return property.name(); })
.def(
"getter_name",
[](const Property& property) { return property.getter().name(); })
.def("setter_name", [](const Property& property) {
if (property.setter().present()) {
return c10::optional<Ident>(property.setter().get().name());
}

return c10::optional<Ident>(c10::nullopt);
});

py::class_<ClassDef, TreeView>(m, "ClassDef")
.def(py::init([](const Ident& name,
Expand Down
38 changes: 24 additions & 14 deletions torch/csrc/jit/python/script_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1590,33 +1590,43 @@ void initJitScriptBindings(PyObject* module) {
return self.equals(other);
})
.def(
"_create_methods",
"_create_methods_and_properties",
[](std::shared_ptr<ConcreteModuleType> concreteType,
const std::vector<Def>& defs,
const std::vector<ResolutionCallback>& rcbs,
const std::vector<Property>& properties,
const std::vector<ResolutionCallback>& propertyRcbs,
const std::vector<Def>& methodDefs,
const std::vector<ResolutionCallback>& methodRcbs,
const std::vector<FunctionDefaults>& defaults) {
TORCH_INTERNAL_ASSERT(defs.size() == rcbs.size());
std::vector<ResolverPtr> resolvers;
resolvers.reserve(rcbs.size());
for (auto& callback : rcbs) {
resolvers.push_back(pythonResolver(callback));
TORCH_INTERNAL_ASSERT(methodDefs.size() == methodRcbs.size());
TORCH_INTERNAL_ASSERT(properties.size() == propertyRcbs.size());

std::vector<ResolverPtr> methodResolvers, propertyResolvers;
methodResolvers.reserve(methodRcbs.size());
for (auto& callback : methodRcbs) {
methodResolvers.push_back(pythonResolver(callback));
}

propertyResolvers.reserve(propertyRcbs.size());
for (auto& callback : propertyRcbs) {
propertyResolvers.push_back(pythonResolver(callback));
}

const auto& selfType =
concreteType->getJitType()->expect<ClassType>();
const auto& prefix = selfType->name().value();
const auto self = ModuleSelf(std::move(concreteType));
auto cu = selfType->compilation_unit();
cu->define(
prefix,
/*properties=*/{},
/*propResolvers=*/{},
defs,
resolvers,
properties,
propertyResolvers,
methodDefs,
methodResolvers,
&self);
// Stitch in default arguments for each Def if provided
auto defaults_it = defaults.begin();
auto defs_it = defs.begin();
while (defs_it != defs.end()) {
auto defs_it = methodDefs.begin();
while (defs_it != methodDefs.end()) {
const auto method_name =
QualifiedName(prefix, (*defs_it).name().name());
auto& method = cu->get_function(method_name);
Expand Down
81 changes: 56 additions & 25 deletions torch/jit/_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from typing import Dict, List, Set, Type

import torch._jit_internal as _jit_internal
from torch.jit.frontend import get_default_args, get_jit_def
from torch.jit.frontend import get_default_args, get_jit_def, get_class_properties
from torch.jit._builtins import _find_builtin
from torch.nn import Module
from torch._six import get_function_from_type, bind_method


ScriptMethodStub = collections.namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
PropertyStub = collections.namedtuple('Property', ('resolution_callback', 'def_'))


# TODO: there should be a more principled way of doing this.
ignored_attributes = [
Expand Down Expand Up @@ -49,6 +51,7 @@ def make_stub_from_method(nn_module, method_name):
# even though we requested a stub for `forward`.
return make_stub(func, method_name)


# base types that can be constants
# in addition, tuples and lists of these base types are also considered constants
# If you edit this list, then you also need to edit the handlers in
Expand Down Expand Up @@ -240,14 +243,6 @@ def infer_type(name, item):
"to a TorchScript type.)").format(torch.typename(type(value)))
concrete_type_builder.add_failed_attribute(name, hint)

# Add @property methods as failed attributes, to give a better error message.
for name, value in type(nn_module).__dict__.items():
if isinstance(value, property):
hint = ("\n(This attribute exists on the Python module, but it's an @property "
"method. @property methods are not yet supported in TorchScript. "
"Please file a feature request on Github)")
concrete_type_builder.add_failed_attribute(name, hint)
Comment on lines -243 to -249
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't quite sure what to do here. Properties are really just syntactic sugar for setter and getter methods. Do they belong on the concrete type? If yes, how do we represent them? If not, what happens if two modules have the same underlying data attributes but different property getters/setters for them?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, concrete type already returns false if two classes differ: pyClass_.is(other.pyClass_). As long as we only take properties from methods, and not from attributes, we don't need to do any additional work here. If we do take properties as attributes then we'll need to add stuff to the ConcreteType.

(i think it's fine to not take properties that are passed in as attributes, no one really does that)


return concrete_type_builder

class ConcreteTypeStore(object):
Expand Down Expand Up @@ -284,11 +279,17 @@ def get_or_create_concrete_type(self, nn_module):

concrete_type_store = ConcreteTypeStore()

def create_methods_from_stubs(concrete_type, stubs):
    """Compile the given method stubs into ``concrete_type``.

    Unzips each ScriptMethodStub into parallel lists of method ASTs,
    resolution callbacks, and default-argument dicts, then hands them to
    the C++ binding in one call.
    """
    defs = []
    rcbs = []
    defaults = []
    for stub in stubs:
        defs.append(stub.def_)
        rcbs.append(stub.resolution_callback)
        defaults.append(get_default_args(stub.original_method))
    concrete_type._create_methods(defs, rcbs, defaults)

def create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs):
    """Compile method and property stubs into ``concrete_type``.

    Method stubs are unzipped into parallel lists of ASTs, resolution
    callbacks, and default-argument dicts; property stubs into ASTs and
    resolution callbacks. Everything is forwarded to the C++ binding in a
    single call.
    """
    method_defs = []
    method_rcbs = []
    method_defaults = []
    for stub in method_stubs:
        method_defs.append(stub.def_)
        method_rcbs.append(stub.resolution_callback)
        method_defaults.append(get_default_args(stub.original_method))

    property_defs = []
    property_rcbs = []
    for stub in property_stubs:
        property_defs.append(stub.def_)
        property_rcbs.append(stub.resolution_callback)

    concrete_type._create_methods_and_properties(
        property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)


def get_module_concrete_type(nn_module, share_types=True):
"""
Expand Down Expand Up @@ -347,7 +348,8 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
"""
cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
stubs = stubs_fn(nn_module)
method_stubs = stubs_fn(nn_module)
property_stubs = get_property_stubs(nn_module)

def init_fn(script_module):
# Initialize the ScriptModule:
Expand Down Expand Up @@ -379,9 +381,7 @@ def init_fn(script_module):
# This ensures we can access these Python methods on the ScriptModule.
for name in dir(nn_module):
item = getattr(nn_module, name, None)
if not inspect.ismethod(item):
continue
if _jit_internal.is_ignored_fn(item):
if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
unbound_function = getattr(type(nn_module), name)
bound_method = unbound_function.__get__(script_module)
setattr(script_module, name, bound_method)
Expand All @@ -394,7 +394,7 @@ def init_fn(script_module):

# Compile methods if necessary
if concrete_type not in concrete_type_store.methods_compiled:
create_methods_from_stubs(concrete_type, stubs)
create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
torch._C._run_emit_module_hook(cpp_module)
concrete_type_store.methods_compiled.add(concrete_type)

Expand All @@ -412,14 +412,14 @@ def init_fn(script_module):


# Make the compiled methods available to the Python ScriptModule class.
for stub in stubs:
if stub.original_method is None:
for method_stub in method_stubs:
if method_stub.original_method is None:
# define()'d methods don't have an Python original_method, so we
# don't need to do any Python re-wrapping stuff
continue

name = stub.original_method.__name__
if name != stub.def_.name().name:
name = method_stub.original_method.__name__
if name != method_stub.def_.name().name:
# TODO: Why skip this? Because @torch.jit._overload_method will
# mangle the name of the function.
continue
Expand All @@ -428,14 +428,23 @@ def init_fn(script_module):
# Wrap the original to propagate docstrings and such.
# TODO: we don't currently do this functions that are recursively
# compiled, we should.
wrapped_script_method = functools.wraps(stub.original_method)(script_method) # type: ignore
wrapped_script_method = functools.wraps(method_stub.original_method)(script_method) # type: ignore

# Add the methods to the script_module directly. This ensures they will
# be found first when `name` is looked up (as opposed to the stubs or
# nn.Module.forward)
script_module.__dict__[name] = wrapped_script_method


# Make module properties available on the Python ScriptModule class.
for property_stub in property_stubs:
property_name = property_stub.def_.name().name
fget = cpp_module._get_method(property_stub.def_.getter_name().name)
# Setter is optional, so it may not exist.
setter_name = property_stub.def_.setter_name()
fset = cpp_module._get_method(setter_name.name) if setter_name else None
script_module.__dict__[property_name] = property(property_name, fget, fset) # type: ignore

# copy over python methods to script module if they aren't defined on the script module
# this is currently an internal api used only on module containers
for name in dir(nn_module):
Expand Down Expand Up @@ -569,6 +578,28 @@ def ignore_overloaded(method_name):
stubs.append(make_stub_from_method(nn_module, method))
return overload_stubs + stubs


def get_property_stubs(nn_module):
    """
    Create property stubs for the properties of the module by creating method
    stubs for the getter and setter.

    Raises:
        RuntimeError: if any property on the module's class lacks a getter.
    """
    module_ty = type(nn_module)
    properties_asts = get_class_properties(module_ty, self_name="RecursiveScriptModule")
    rcbs = {}

    for name in dir(module_ty):
        item = getattr(module_ty, name, None)
        if isinstance(item, property):
            if not item.fget:
                # Bug fix: use the class name here. `nn.Module` instances do
                # not have a `__name__` attribute, so the original
                # `nn_module.__name__` raised AttributeError while trying to
                # report this error, masking the real RuntimeError.
                raise RuntimeError(f'Property {name} of {module_ty.__name__} must have a getter')

            rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget)

    stubs = [PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts]
    return stubs


def interface_script(mod_interface, nn_module):
"""
Makes a ScriptModule from an nn.Module, using the interface methods rule for
Expand Down Expand Up @@ -633,7 +664,7 @@ def compile_unbound_method(concrete_type, fn):
with torch._jit_internal._disable_emit_hooks():
# We don't want to call the hooks here since the graph that is calling
# this function is not yet complete
create_methods_from_stubs(concrete_type, (stub,))
create_methods_and_properties_from_stubs(concrete_type, (stub,), ())
return stub

def lazy_bind(concrete_type, unbound_method):
Expand Down
1 change: 1 addition & 0 deletions torch/jit/_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ class ScriptModule(with_metaclass(ScriptMeta, Module)): # type: ignore
contain methods, attributes, parameters, and
constants. These can be accessed the same as on a normal ``nn.Module``.
"""
__ignored_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name']

def __init__(self):
super(ScriptModule, self).__init__()
Expand Down
9 changes: 6 additions & 3 deletions torch/jit/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,16 @@ def get_class_properties(cls, self_name):
"""
props = inspect.getmembers(
cls, predicate=lambda m: isinstance(m, property))
# Any property that should not compiled must be in this list on the Module.
ignored_properties = getattr(cls, "__ignored_properties__", [])

# Create Property TreeView objects from inspected property objects.
properties = []
for prop in props:
getter = get_jit_def(prop[1].fget, f"__{prop[0]}_getter", self_name=self_name)
setter = get_jit_def(prop[1].fset, f"__{prop[0]}_setter", self_name=self_name) if prop[1].fset else None
properties.append(Property(getter.range(), Ident(getter.range(), prop[0]), getter, setter))
if prop[0] not in ignored_properties:
getter = get_jit_def(prop[1].fget, f"__{prop[0]}_getter", self_name=self_name)
setter = get_jit_def(prop[1].fset, f"__{prop[0]}_setter", self_name=self_name) if prop[1].fset else None
properties.append(Property(getter.range(), Ident(getter.range(), prop[0]), getter, setter))

return properties

Expand Down
1 change: 1 addition & 0 deletions torch/nn/modules/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tens
class RNNBase(Module):
__constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
'batch_first', 'dropout', 'bidirectional']
__ignored_properties__ = ['all_weights']

mode: str
input_size: int
Expand Down