66import warnings
77
88import torch ._jit_internal as _jit_internal
9- from torch .jit .frontend import get_default_args , get_jit_def
9+ from torch .jit .frontend import get_default_args , get_jit_def , get_class_properties
1010from torch .jit ._builtins import _find_builtin
1111from torch .nn import Module
1212from torch ._six import get_function_from_type , bind_method
1313
1414
1515ScriptMethodStub = collections .namedtuple ('ScriptMethodStub' , ('resolution_callback' , 'def_' , 'original_method' ))
16+ PropertyStub = collections .namedtuple ('Property' , ('resolution_callback' , 'def_' ))
17+
1618
1719# TODO: there should be a more principled way of doing this.
1820ignored_attributes = [
@@ -48,6 +50,7 @@ def make_stub_from_method(nn_module, method_name):
4850 # even though we requested a stub for `forward`.
4951 return make_stub (func , method_name )
5052
53+
5154# base types that can be constants
5255# in addition, tuples and lists of these base types are also considered constants
5356# If you edit this list, then you also need to edit the handlers in
@@ -239,14 +242,6 @@ def infer_type(name, item):
239242 "to a TorchScript type.)" ).format (torch .typename (type (value )))
240243 concrete_type_builder .add_failed_attribute (name , hint )
241244
242- # Add @property methods as failed attributes, to give a better error message.
243- for name , value in type (nn_module ).__dict__ .items ():
244- if isinstance (value , property ):
245- hint = ("\n (This attribute exists on the Python module, but it's an @property "
246- "method. @property methods are not yet supported in TorchScript. "
247- "Please file a feature request on Github)" )
248- concrete_type_builder .add_failed_attribute (name , hint )
249-
250245 return concrete_type_builder
251246
252247class ConcreteTypeStore (object ):
@@ -285,11 +280,17 @@ def get_or_create_concrete_type(self, nn_module):
285280
286281concrete_type_store = ConcreteTypeStore ()
287282
288- def create_methods_from_stubs (concrete_type , stubs ):
289- defs = [m .def_ for m in stubs ]
290- rcbs = [m .resolution_callback for m in stubs ]
291- defaults = [get_default_args (m .original_method ) for m in stubs ]
292- concrete_type ._create_methods (defs , rcbs , defaults )
283+
284+ def create_methods_and_properties_from_stubs (concrete_type , method_stubs , property_stubs ):
285+ method_defs = [m .def_ for m in method_stubs ]
286+ method_rcbs = [m .resolution_callback for m in method_stubs ]
287+ method_defaults = [get_default_args (m .original_method ) for m in method_stubs ]
288+
289+ property_defs = [p .def_ for p in property_stubs ]
290+ property_rcbs = [p .resolution_callback for p in property_stubs ]
291+
292+ concrete_type ._create_methods_and_properties (property_defs , property_rcbs , method_defs , method_rcbs , method_defaults )
293+
293294
294295def create_script_module (nn_module , stubs_fn , share_types = True ):
295296 """
@@ -326,7 +327,8 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
326327 stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
327328 """
328329 cpp_module = torch ._C ._create_module_with_type (concrete_type .jit_type )
329- stubs = stubs_fn (nn_module )
330+ method_stubs = stubs_fn (nn_module )
331+ property_stubs = get_property_stubs (nn_module )
330332
331333 def init_fn (script_module ):
332334 # Initialize the ScriptModule:
@@ -354,13 +356,11 @@ def init_fn(script_module):
354356 cpp_module .setattr (name , scripted )
355357 script_module ._modules [name ] = scripted
356358
357- # 3. Copy @ignored/@unused methods from the original `nn_module` to the new ScriptModule.
359+ # 3. Copy @ignored/@unused methods and properties from the original `nn_module` to the new ScriptModule.
358360 # This ensures we can access these Python methods on the ScriptModule.
359361 for name in dir (nn_module ):
360362 item = getattr (nn_module , name , None )
361- if not inspect .ismethod (item ):
362- continue
363- if _jit_internal .is_ignored_fn (item ):
363+ if inspect .ismethod (item ) and _jit_internal .is_ignored_fn (item ):
364364 unbound_function = getattr (type (nn_module ), name )
365365 bound_method = unbound_function .__get__ (script_module )
366366 setattr (script_module , name , bound_method )
@@ -373,7 +373,7 @@ def init_fn(script_module):
373373
374374 # Compile methods if necessary
375375 if concrete_type not in concrete_type_store .methods_compiled :
376- create_methods_from_stubs (concrete_type , stubs )
376+ create_methods_and_properties_from_stubs (concrete_type , method_stubs , property_stubs )
377377 torch ._C ._run_emit_module_hook (cpp_module )
378378 concrete_type_store .methods_compiled .add (concrete_type )
379379
@@ -391,14 +391,14 @@ def init_fn(script_module):
391391
392392
393393 # Make the compiled methods available to the Python ScriptModule class.
394- for stub in stubs :
395- if stub .original_method is None :
394+ for method_stub in method_stubs :
395+ if method_stub .original_method is None :
396396 # define()'d methods don't have an Python original_method, so we
397397 # don't need to do any Python re-wrapping stuff
398398 continue
399399
400- name = stub .original_method .__name__
401- if name != stub .def_ .name ().name :
400+ name = method_stub .original_method .__name__
401+ if name != method_stub .def_ .name ().name :
402402 # TODO: Why skip this? Because @torch.jit._overload_method will
403403 # mangle the name of the function.
404404 continue
@@ -407,14 +407,23 @@ def init_fn(script_module):
407407 # Wrap the original to propagate docstrings and such.
408408 # TODO: we don't currently do this functions that are recursively
409409 # compiled, we should.
410- script_method = functools .wraps (stub .original_method )(script_method )
410+ script_method = functools .wraps (method_stub .original_method )(script_method )
411411
412412 # Add the methods to the script_module directly. This ensures they will
413413 # be found first when `name` is looked up (as opposed to the stubs or
414414 # nn.Module.forward)
415415 script_module .__dict__ [name ] = script_method
416416
417417
418+ # Make module properties available on the Python ScriptModule class.
419+ for property_stub in property_stubs :
420+ property_name = property_stub .def_ .name ().name
421+ fget = cpp_module ._get_method (property_stub .def_ .getter_name ().name )
422+ # Setter is optional, so it may not exist.
423+ setter_name = property_stub .def_ .setter_name ()
424+ fset = cpp_module ._get_method (setter_name .name ) if setter_name else None
425+ script_module .__dict__ [property_name ] = property (property_name , fget , fset )
426+
418427 # copy over python methods to script module if they aren't defined on the script module
419428 # this is currently an internal api used only on module containers
420429 for name in dir (nn_module ):
@@ -548,6 +557,28 @@ def ignore_overloaded(method_name):
548557 stubs .append (make_stub_from_method (nn_module , method ))
549558 return overload_stubs + stubs
550559
560+
561+ def get_property_stubs (nn_module ):
562+ """
563+ Create property stubs for the properties of the module by creating method
564+ stubs for the getter and setter.
565+ """
566+ module_ty = type (nn_module )
567+ properties_asts = get_class_properties (module_ty , self_name = "RecursiveScriptModule" )
568+ rcbs = {}
569+
570+ for name in dir (module_ty ):
571+ item = getattr (module_ty , name , None )
572+ if isinstance (item , property ):
573+ if not item .fget :
574+ raise RuntimeError (f'Property { name } of { nn_module .__name__ } must have a getter' )
575+
576+ rcbs [name ] = _jit_internal .createResolutionCallbackFromClosure (item .fget )
577+
578+ stubs = [PropertyStub (rcbs [ast .name ().name ], ast ) for ast in properties_asts ]
579+ return stubs
580+
581+
551582def interface_script (mod_interface , nn_module ):
552583 """
553584 Makes a ScriptModule from an nn.Module, using the interface methods rule for
@@ -612,7 +643,7 @@ def compile_unbound_method(concrete_type, fn):
612643 with torch ._jit_internal ._disable_emit_hooks ():
613644 # We don't want to call the hooks here since the graph that is calling
614645 # this function is not yet complete
615- create_methods_from_stubs (concrete_type , (stub ,))
646+ create_methods_and_properties_from_stubs (concrete_type , (stub ,), ( ))
616647 return stub
617648
618649def lazy_bind (concrete_type , unbound_method ):
0 commit comments