11import torch
22from torch .fx import GraphModule # type: ignore
3- from torch .fx import symbolic_trace # type: ignore
43from torch .fx .symbolic_trace import Tracer # type: ignore
54from .fx import Fuser # noqa: F401
65from .fx import Quantizer # noqa: F401
@@ -24,16 +23,16 @@ def _fuse_fx(graph_module, inplace=False):
2423 return fuser .fuse (graph_module , inplace )
2524
2625class CustomTracer (Tracer ):
27- def __init__ (self , standalone_modules , custom_module_classes ):
26+ def __init__ (self , skipped_module_names , skipped_module_classes ):
2827 super ().__init__ ()
29- self .standalone_modules = standalone_modules
30- self .custom_module_classes = custom_module_classes
28+ self .skipped_module_names = skipped_module_names
29+ self .skipped_module_classes = skipped_module_classes
3130
3231 def is_leaf_module (self , m , module_qualified_name ):
3332 return (m .__module__ .startswith ('torch.nn' ) and
3433 not isinstance (m , torch .nn .Sequential )) or \
35- module_qualified_name in self .standalone_modules or \
36- type (m ) in self .custom_module_classes
34+ module_qualified_name in self .skipped_module_names or \
35+ type (m ) in self .skipped_module_classes
3736
3837
3938def _prepare_fx (model , qconfig_dict , inplace , prepare_custom_config_dict = None , is_standalone_module = False ):
@@ -50,17 +49,19 @@ def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, i
5049 if prepare_custom_config_dict is None :
5150 prepare_custom_config_dict = {}
5251
52+ skipped_module_names = prepare_custom_config_dict .get ("non_traceable_module_name" , [])
53+ skipped_module_classes = prepare_custom_config_dict .get ("non_traceable_module_class" , [])
54+
5355 # symbolically trace the model
54- if is_standalone_module :
55- # standlone module is traced before quantizing standalone modules
56- graph_module = symbolic_trace (model )
57- else :
58- standalone_modules = prepare_custom_config_dict .get ('standalone_module_name' , [])
56+ if not is_standalone_module :
57+ # standalone module and custom module config are applied in top level module
58+ standalone_module_names = prepare_custom_config_dict .get ('standalone_module_name' , [])
59+ skipped_module_names += standalone_module_names
5960 custom_module_config = prepare_custom_config_dict .get ('float_to_observed_custom_module_class' , {})
6061 custom_module_classes = list (custom_module_config .keys ())
61- # skipping tracing standalone modules when tracing top level module
62- tracer = CustomTracer (standalone_modules , custom_module_classes )
63- graph_module = GraphModule (model , tracer .trace (model ))
62+ skipped_module_classes += custom_module_classes
63+ tracer = CustomTracer (skipped_module_names , skipped_module_classes )
64+ graph_module = GraphModule (model , tracer .trace (model ))
6465 graph_module = _fuse_fx (graph_module , inplace )
6566 quantizer = Quantizer ()
6667 return quantizer .prepare (
@@ -156,7 +157,17 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No
156157 # float custom module to observed custom module
157158 "float_to_observed_custom_module_class": {
158159 CustomModule: ObservedCustomModule
159- }
160+ },
161+
162+ # the qualified names for the submodule that are not symbolically traceable
163+ "non_traceable_module_name": [
164+ "non_traceable_module"
165+ ],
166+
167+ # the module classes that are not symbolically traceable
168+ "non_traceable_module_class": [
169+ NonTraceableModule
170+ ]
160171 }
161172
162173
0 commit comments