66import warnings
77
88import torch ._jit_internal as _jit_internal
9- from torch .jit .frontend import get_default_args , get_jit_def
9+ from torch .jit .frontend import get_default_args , get_jit_def , get_class_properties
1010from torch .jit ._builtins import _find_builtin
1111from torch .nn import Module
1212from torch ._six import get_function_from_type , bind_method
1313
1414
1515ScriptMethodStub = collections .namedtuple ('ScriptMethodStub' , ('resolution_callback' , 'def_' , 'original_method' ))
16+ PropertyStub = collections .namedtuple ('Property' , ('resolution_callback' , 'def_' ))
17+
1618
1719# TODO: there should be a more principled way of doing this.
1820ignored_attributes = [
2931 "dump_patches" ,
3032]
3133
34+ ignored_properties = [
35+ # Temporary fix for RNN module property named 'all_weights' being scripted
36+ "all_weights" ,
37+ "original_name" ,
38+ "graph" ,
39+ "inlined_graph" ,
40+ "code" ,
41+ "code_with_constants" ,
42+ ]
43+
3244def make_stub (func , name ):
3345 rcb = _jit_internal .createResolutionCallbackFromClosure (func )
3446 ast = get_jit_def (func , name , self_name = "RecursiveScriptModule" )
@@ -48,6 +60,7 @@ def make_stub_from_method(nn_module, method_name):
4860 # even though we requested a stub for `forward`.
4961 return make_stub (func , method_name )
5062
63+
5164# base types that can be constants
5265# in addition, tuples and lists of these base types are also considered constants
5366# If you edit this list, then you also need to edit the handlers in
@@ -239,14 +252,6 @@ def infer_type(name, item):
239252 "to a TorchScript type.)" ).format (torch .typename (type (value )))
240253 concrete_type_builder .add_failed_attribute (name , hint )
241254
242- # Add @property methods as failed attributes, to give a better error message.
243- for name , value in type (nn_module ).__dict__ .items ():
244- if isinstance (value , property ):
245- hint = ("\n (This attribute exists on the Python module, but it's an @property "
246- "method. @property methods are not yet supported in TorchScript. "
247- "Please file a feature request on Github)" )
248- concrete_type_builder .add_failed_attribute (name , hint )
249-
250255 return concrete_type_builder
251256
252257class ConcreteTypeStore (object ):
@@ -285,11 +290,17 @@ def get_or_create_concrete_type(self, nn_module):
285290
286291concrete_type_store = ConcreteTypeStore ()
287292
288- def create_methods_from_stubs (concrete_type , stubs ):
289- defs = [m .def_ for m in stubs ]
290- rcbs = [m .resolution_callback for m in stubs ]
291- defaults = [get_default_args (m .original_method ) for m in stubs ]
292- concrete_type ._create_methods (defs , rcbs , defaults )
293+
294+ def create_methods_and_properties_from_stubs (concrete_type , method_stubs , property_stubs ):
295+ method_defs = [m .def_ for m in method_stubs ]
296+ method_rcbs = [m .resolution_callback for m in method_stubs ]
297+ method_defaults = [get_default_args (m .original_method ) for m in method_stubs ]
298+
299+ property_defs = [p .def_ for p in property_stubs ]
300+ property_rcbs = [p .resolution_callback for p in property_stubs ]
301+
302+ concrete_type ._create_methods_and_properties (property_defs , property_rcbs , method_defs , method_rcbs , method_defaults )
303+
293304
294305def create_script_module (nn_module , stubs_fn , share_types = True ):
295306 """
@@ -326,7 +337,8 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
326337 stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
327338 """
328339 cpp_module = torch ._C ._create_module_with_type (concrete_type .jit_type )
329- stubs = stubs_fn (nn_module )
340+ method_stubs = stubs_fn (nn_module )
341+ property_stubs = get_property_stubs (nn_module )
330342
331343 def init_fn (script_module ):
332344 # Initialize the ScriptModule:
@@ -373,7 +385,7 @@ def init_fn(script_module):
373385
374386 # Compile methods if necessary
375387 if concrete_type not in concrete_type_store .methods_compiled :
376- create_methods_from_stubs (concrete_type , stubs )
388+ create_methods_and_properties_from_stubs (concrete_type , method_stubs , property_stubs )
377389 torch ._C ._run_emit_module_hook (cpp_module )
378390 concrete_type_store .methods_compiled .add (concrete_type )
379391
@@ -391,14 +403,14 @@ def init_fn(script_module):
391403
392404
393405 # Make the compiled methods available to the Python ScriptModule class.
394- for stub in stubs :
395- if stub .original_method is None :
406+ for method_stub in method_stubs :
407+ if method_stub .original_method is None :
396408 # define()'d methods don't have an Python original_method, so we
397409 # don't need to do any Python re-wrapping stuff
398410 continue
399411
400- name = stub .original_method .__name__
401- if name != stub .def_ .name ().name :
412+ name = method_stub .original_method .__name__
413+ if name != method_stub .def_ .name ().name :
402414 # TODO: Why skip this? Because @torch.jit._overload_method will
403415 # mangle the name of the function.
404416 continue
@@ -407,14 +419,20 @@ def init_fn(script_module):
407419 # Wrap the original to propagate docstrings and such.
408420 # TODO: we don't currently do this functions that are recursively
409421 # compiled, we should.
410- script_method = functools .wraps (stub .original_method )(script_method )
422+ script_method = functools .wraps (method_stub .original_method )(script_method )
411423
412424 # Add the methods to the script_module directly. This ensures they will
413425 # be found first when `name` is looked up (as opposed to the stubs or
414426 # nn.Module.forward)
415427 script_module .__dict__ [name ] = script_method
416428
417429
430+ for property_stub in property_stubs :
431+ property_name = property_stub .def_ .name ().name
432+ fget = cpp_module ._get_method (property_stub .def_ .getter_name ().name )
433+ fset = cpp_module ._get_method (property_stub .def_ .setter_name ().name )
434+ script_module .__dict__ [property_name ] = property (property_name , fget , fset )
435+
418436 # copy over python methods to script module if they aren't defined on the script module
419437 # this is currently an internal api used only on module containers
420438 for name in dir (nn_module ):
@@ -548,6 +566,32 @@ def ignore_overloaded(method_name):
548566 stubs .append (make_stub_from_method (nn_module , method ))
549567 return overload_stubs + stubs
550568
569+
570+ def get_property_stubs (nn_module ):
571+ """
572+ Create property stubs for the properties of the module by creating method
573+ stubs for the getter and setter.
574+ """
575+ module_ty = type (nn_module )
576+ properties_asts = get_class_properties (module_ty , self_name = "RecursiveScriptModule" )
577+ rcbs = {}
578+
579+ for name in dir (module_ty ):
580+ item = getattr (module_ty , name , None )
581+ if isinstance (item , property ) and name not in ignored_properties :
582+ if not item .fget :
583+ raise RuntimeError (f'Property { name } of { nn_module .__name__ } must have a getter' )
584+
585+ rcbs [name ] = _jit_internal .createResolutionCallbackFromClosure (item .fget )
586+
587+ stubs = []
588+ for ast in properties_asts :
589+ if ast .name ().name not in ignored_properties :
590+ stubs .append (PropertyStub (rcbs [ast .name ().name ], ast ))
591+
592+ return stubs
593+
594+
551595def interface_script (mod_interface , nn_module ):
552596 """
553597 Makes a ScriptModule from an nn.Module, using the interface methods rule for
@@ -612,7 +656,7 @@ def compile_unbound_method(concrete_type, fn):
612656 with torch ._jit_internal ._disable_emit_hooks ():
613657 # We don't want to call the hooks here since the graph that is calling
614658 # this function is not yet complete
615- create_methods_from_stubs (concrete_type , (stub ,))
659+ create_methods_and_properties_from_stubs (concrete_type , (stub ,), ( ))
616660 return stub
617661
618662def lazy_bind (concrete_type , unbound_method ):
0 commit comments