Skip to content

Commit a06b95b

Browse files
jerryzh168facebook-github-bot
authored andcommitted
[quant][graphmode][fx] Support non_traceable_module/module_class (#46298)
Summary: Pull Request resolved: #46298 Allow user to specify a list of qualified names for non traceable submodule or type of the non traceable submodule See quantize_fx.py for api Test Plan: Imported from OSS Reviewed By: vkuzo Differential Revision: D24294210 fbshipit-source-id: eb1e309065e3dfbf31e63507aaed73587f0dae29
1 parent 5b0f400 commit a06b95b

File tree

2 files changed

+79
-15
lines changed

2 files changed

+79
-15
lines changed

test/quantization/test_quantize_fx.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,59 @@ def forward(self, x):
761761
ref_res = ref_m(data)
762762
self.assertEqual(res, ref_res)
763763

764+
@skipIfNoFBGEMM
765+
def test_non_traceable_module(self):
766+
class NonTraceable(torch.nn.Module):
767+
def __init__(self):
768+
super().__init__()
769+
770+
def forward(self, x):
771+
for k in x.keys():
772+
print(x[k])
773+
return x
774+
775+
class NonTraceable2(torch.nn.Module):
776+
def __init__(self):
777+
super().__init__()
778+
779+
def forward(self, x):
780+
# data dependent control flow is not traceable
781+
for i in x:
782+
print(i)
783+
return x
784+
785+
class M(torch.nn.Module):
786+
def __init__(self):
787+
super().__init__()
788+
self.m1 = NonTraceable()
789+
self.m2 = NonTraceable2()
790+
791+
def forward(self, x):
792+
x = self.m1(x)
793+
x = self.m2(x)
794+
return x
795+
796+
m = M().eval()
797+
qconfig_dict = {"": default_qconfig}
798+
prepare_custom_config_dict = {
799+
"non_traceable_module_name": [
800+
"m1"
801+
],
802+
"non_traceable_module_class": [
803+
NonTraceable2
804+
]
805+
}
806+
m = prepare_fx(
807+
m, qconfig_dict,
808+
prepare_custom_config_dict=prepare_custom_config_dict)
809+
810+
node_occurrence = {
811+
ns.call_module(NonTraceable) : 1,
812+
ns.call_module(NonTraceable2) : 1,
813+
}
814+
# make sure these modules are not traced
815+
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
816+
764817
class TestQuantizeFxOps(QuantizationTestCase):
765818
"""Unit tests for individual ops
766819
"""

torch/quantization/quantize_fx.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
22
from torch.fx import GraphModule # type: ignore
3-
from torch.fx import symbolic_trace # type: ignore
43
from torch.fx.symbolic_trace import Tracer # type: ignore
54
from .fx import Fuser # noqa: F401
65
from .fx import Quantizer # noqa: F401
@@ -24,16 +23,16 @@ def _fuse_fx(graph_module, inplace=False):
2423
return fuser.fuse(graph_module, inplace)
2524

2625
class CustomTracer(Tracer):
27-
def __init__(self, standalone_modules, custom_module_classes):
26+
def __init__(self, skipped_module_names, skipped_module_classes):
2827
super().__init__()
29-
self.standalone_modules = standalone_modules
30-
self.custom_module_classes = custom_module_classes
28+
self.skipped_module_names = skipped_module_names
29+
self.skipped_module_classes = skipped_module_classes
3130

3231
def is_leaf_module(self, m, module_qualified_name):
3332
return (m.__module__.startswith('torch.nn') and
3433
not isinstance(m, torch.nn.Sequential)) or \
35-
module_qualified_name in self.standalone_modules or \
36-
type(m) in self.custom_module_classes
34+
module_qualified_name in self.skipped_module_names or \
35+
type(m) in self.skipped_module_classes
3736

3837

3938
def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, is_standalone_module=False):
@@ -50,17 +49,19 @@ def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, i
5049
if prepare_custom_config_dict is None:
5150
prepare_custom_config_dict = {}
5251

52+
skipped_module_names = prepare_custom_config_dict.get("non_traceable_module_name", [])
53+
skipped_module_classes = prepare_custom_config_dict.get("non_traceable_module_class", [])
54+
5355
# symbolically trace the model
54-
if is_standalone_module:
55-
# standlone module is traced before quantizing standalone modules
56-
graph_module = symbolic_trace(model)
57-
else:
58-
standalone_modules = prepare_custom_config_dict.get('standalone_module_name', [])
56+
if not is_standalone_module:
57+
# standalone module and custom module config are applied in top level module
58+
standalone_module_names = prepare_custom_config_dict.get('standalone_module_name', [])
59+
skipped_module_names += standalone_module_names
5960
custom_module_config = prepare_custom_config_dict.get('float_to_observed_custom_module_class', {})
6061
custom_module_classes = list(custom_module_config.keys())
61-
# skipping tracing standalone modules when tracing top level module
62-
tracer = CustomTracer(standalone_modules, custom_module_classes)
63-
graph_module = GraphModule(model, tracer.trace(model))
62+
skipped_module_classes += custom_module_classes
63+
tracer = CustomTracer(skipped_module_names, skipped_module_classes)
64+
graph_module = GraphModule(model, tracer.trace(model))
6465
graph_module = _fuse_fx(graph_module, inplace)
6566
quantizer = Quantizer()
6667
return quantizer.prepare(
@@ -156,7 +157,17 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No
156157
# float custom module to observed custom module
157158
"float_to_observed_custom_module_class": {
158159
CustomModule: ObservedCustomModule
159-
}
160+
},
161+
162+
# the qualified names for the submodule that are not symbolically traceable
163+
"non_traceable_module_name": [
164+
"non_traceable_module"
165+
],
166+
167+
# the module classes that are not symbolically traceable
168+
"non_traceable_module_class": [
169+
NonTraceableModule
170+
]
160171
}
161172
162173

0 commit comments

Comments
 (0)