77from typing import Dict , List , Set , Type
88
99import torch ._jit_internal as _jit_internal
10- from torch .jit .frontend import get_default_args , get_jit_def
10+ from torch .jit .frontend import get_default_args , get_jit_def , get_class_properties
1111from torch .jit ._builtins import _find_builtin
1212from torch .nn import Module
1313from torch ._six import get_function_from_type , bind_method
1414
1515
1616ScriptMethodStub = collections .namedtuple ('ScriptMethodStub' , ('resolution_callback' , 'def_' , 'original_method' ))
17+ PropertyStub = collections .namedtuple ('Property' , ('resolution_callback' , 'def_' ))
18+
1719
1820# TODO: there should be a more principled way of doing this.
1921ignored_attributes = [
@@ -49,6 +51,7 @@ def make_stub_from_method(nn_module, method_name):
4951 # even though we requested a stub for `forward`.
5052 return make_stub (func , method_name )
5153
54+
5255# base types that can be constants
5356# in addition, tuples and lists of these base types are also considered constants
5457# If you edit this list, then you also need to edit the handlers in
@@ -240,14 +243,6 @@ def infer_type(name, item):
240243 "to a TorchScript type.)" ).format (torch .typename (type (value )))
241244 concrete_type_builder .add_failed_attribute (name , hint )
242245
243- # Add @property methods as failed attributes, to give a better error message.
244- for name , value in type (nn_module ).__dict__ .items ():
245- if isinstance (value , property ):
246- hint = ("\n (This attribute exists on the Python module, but it's an @property "
247- "method. @property methods are not yet supported in TorchScript. "
248- "Please file a feature request on Github)" )
249- concrete_type_builder .add_failed_attribute (name , hint )
250-
251246 return concrete_type_builder
252247
253248class ConcreteTypeStore (object ):
@@ -284,11 +279,17 @@ def get_or_create_concrete_type(self, nn_module):
284279
285280concrete_type_store = ConcreteTypeStore ()
286281
287- def create_methods_from_stubs (concrete_type , stubs ):
288- defs = [m .def_ for m in stubs ]
289- rcbs = [m .resolution_callback for m in stubs ]
290- defaults = [get_default_args (m .original_method ) for m in stubs ]
291- concrete_type ._create_methods (defs , rcbs , defaults )
282+
283+ def create_methods_and_properties_from_stubs (concrete_type , method_stubs , property_stubs ):
284+ method_defs = [m .def_ for m in method_stubs ]
285+ method_rcbs = [m .resolution_callback for m in method_stubs ]
286+ method_defaults = [get_default_args (m .original_method ) for m in method_stubs ]
287+
288+ property_defs = [p .def_ for p in property_stubs ]
289+ property_rcbs = [p .resolution_callback for p in property_stubs ]
290+
291+ concrete_type ._create_methods_and_properties (property_defs , property_rcbs , method_defs , method_rcbs , method_defaults )
292+
292293
293294def get_module_concrete_type (nn_module , share_types = True ):
294295 """
@@ -347,7 +348,8 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
347348 stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
348349 """
349350 cpp_module = torch ._C ._create_module_with_type (concrete_type .jit_type )
350- stubs = stubs_fn (nn_module )
351+ method_stubs = stubs_fn (nn_module )
352+ property_stubs = get_property_stubs (nn_module )
351353
352354 def init_fn (script_module ):
353355 # Initialize the ScriptModule:
@@ -379,9 +381,7 @@ def init_fn(script_module):
379381 # This ensures we can access these Python methods on the ScriptModule.
380382 for name in dir (nn_module ):
381383 item = getattr (nn_module , name , None )
382- if not inspect .ismethod (item ):
383- continue
384- if _jit_internal .is_ignored_fn (item ):
384+ if inspect .ismethod (item ) and _jit_internal .is_ignored_fn (item ):
385385 unbound_function = getattr (type (nn_module ), name )
386386 bound_method = unbound_function .__get__ (script_module )
387387 setattr (script_module , name , bound_method )
@@ -394,7 +394,7 @@ def init_fn(script_module):
394394
395395 # Compile methods if necessary
396396 if concrete_type not in concrete_type_store .methods_compiled :
397- create_methods_from_stubs (concrete_type , stubs )
397+ create_methods_and_properties_from_stubs (concrete_type , method_stubs , property_stubs )
398398 torch ._C ._run_emit_module_hook (cpp_module )
399399 concrete_type_store .methods_compiled .add (concrete_type )
400400
@@ -412,14 +412,14 @@ def init_fn(script_module):
412412
413413
414414 # Make the compiled methods available to the Python ScriptModule class.
415- for stub in stubs :
416- if stub .original_method is None :
415+ for method_stub in method_stubs :
416+ if method_stub .original_method is None :
417417 # define()'d methods don't have an Python original_method, so we
418418 # don't need to do any Python re-wrapping stuff
419419 continue
420420
421- name = stub .original_method .__name__
422- if name != stub .def_ .name ().name :
421+ name = method_stub .original_method .__name__
422+ if name != method_stub .def_ .name ().name :
423423 # TODO: Why skip this? Because @torch.jit._overload_method will
424424 # mangle the name of the function.
425425 continue
@@ -428,14 +428,23 @@ def init_fn(script_module):
428428 # Wrap the original to propagate docstrings and such.
429429 # TODO: we don't currently do this functions that are recursively
430430 # compiled, we should.
431- wrapped_script_method = functools .wraps (stub .original_method )(script_method ) # type: ignore
431+ wrapped_script_method = functools .wraps (method_stub .original_method )(script_method ) # type: ignore
432432
433433 # Add the methods to the script_module directly. This ensures they will
434434 # be found first when `name` is looked up (as opposed to the stubs or
435435 # nn.Module.forward)
436436 script_module .__dict__ [name ] = wrapped_script_method
437437
438438
439+ # Make module properties available on the Python ScriptModule class.
440+ for property_stub in property_stubs :
441+ property_name = property_stub .def_ .name ().name
442+ fget = cpp_module ._get_method (property_stub .def_ .getter_name ().name )
443+ # Setter is optional, so it may not exist.
444+ setter_name = property_stub .def_ .setter_name ()
445+ fset = cpp_module ._get_method (setter_name .name ) if setter_name else None
446+ script_module .__dict__ [property_name ] = property (property_name , fget , fset ) # type: ignore
447+
439448 # copy over python methods to script module if they aren't defined on the script module
440449 # this is currently an internal api used only on module containers
441450 for name in dir (nn_module ):
@@ -569,6 +578,28 @@ def ignore_overloaded(method_name):
569578 stubs .append (make_stub_from_method (nn_module , method ))
570579 return overload_stubs + stubs
571580
581+
582+ def get_property_stubs (nn_module ):
583+ """
584+ Create property stubs for the properties of the module by creating method
585+ stubs for the getter and setter.
586+ """
587+ module_ty = type (nn_module )
588+ properties_asts = get_class_properties (module_ty , self_name = "RecursiveScriptModule" )
589+ rcbs = {}
590+
591+ for name in dir (module_ty ):
592+ item = getattr (module_ty , name , None )
593+ if isinstance (item , property ):
594+ if not item .fget :
595+ raise RuntimeError (f'Property { name } of { nn_module .__name__ } must have a getter' )
596+
597+ rcbs [name ] = _jit_internal .createResolutionCallbackFromClosure (item .fget )
598+
599+ stubs = [PropertyStub (rcbs [ast .name ().name ], ast ) for ast in properties_asts ]
600+ return stubs
601+
602+
572603def interface_script (mod_interface , nn_module ):
573604 """
574605 Makes a ScriptModule from an nn.Module, using the interface methods rule for
@@ -633,7 +664,7 @@ def compile_unbound_method(concrete_type, fn):
633664 with torch ._jit_internal ._disable_emit_hooks ():
634665 # We don't want to call the hooks here since the graph that is calling
635666 # this function is not yet complete
636- create_methods_from_stubs (concrete_type , (stub ,))
667+ create_methods_and_properties_from_stubs (concrete_type , (stub ,), ( ))
637668 return stub
638669
639670def lazy_bind (concrete_type , unbound_method ):
0 commit comments