Skip to content

Commit 4e6aefa

Browse files
terrychenism
authored and pytorchmergebot committed
[Quant] Refactor reference module mapping (#72755)
Summary: Pull Request resolved: #72755 Add is_reference flag in convert function Test Plan: python3 test/test_quantization.py TestQuantizeEagerOps.test_conv_transpose_2d Imported from OSS Reviewed By: mruberry Differential Revision: D34188856 fbshipit-source-id: 291014a7b3b4d4b40ca0ca76a80711097dcc4b58 (cherry picked from commit cfba3b8)
1 parent 5993f48 commit 4e6aefa

File tree

4 files changed

+26
-17
lines changed

4 files changed

+26
-17
lines changed

test/quantization/eager/test_quantize_eager_ptq.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import torch
44
import torch.nn as nn
55
import torch.nn.quantized as nnq
6-
import torch.nn.quantized._reference as nnqr
76
from torch.nn.utils.rnn import PackedSequence
87
from torch.ao.quantization import (
98
quantize,
@@ -140,17 +139,7 @@ def forward(self, x):
140139

141140
ref_m = prepare(original_ref_m)
142141
ref_m(data)
143-
reference_module_mapping = {
144-
QuantStub: nnq.Quantize,
145-
DeQuantStub: nnq.DeQuantize,
146-
nn.Conv1d: nnqr.Conv1d,
147-
nn.Conv2d: nnqr.Conv2d,
148-
nn.Conv3d: nnqr.Conv3d,
149-
nn.ConvTranspose1d: nnqr.ConvTranspose1d,
150-
nn.ConvTranspose2d: nnqr.ConvTranspose2d,
151-
nn.ConvTranspose3d: nnqr.ConvTranspose3d,
152-
}
153-
ref_m = convert(ref_m, mapping=reference_module_mapping)
142+
ref_m = convert(ref_m, is_reference=True)
154143
ref_res = ref_m(data)
155144
self.assertEqual(res, ref_res)
156145

@@ -202,6 +191,14 @@ def test_conv_transpose_3d(self):
202191
(16, 1, 10, 10, 10)
203192
)
204193

194+
def test_linear(self):
195+
self._test_reference_module_impl(
196+
nn.Linear,
197+
nnq.Linear,
198+
{'in_features': 5, 'out_features': 10},
199+
(16, 5)
200+
)
201+
205202
def _test_activation_op_impl(
206203
self, float_module_class, quantized_module_class, extra_module_kwargs):
207204
""" Implementation for testing common activation ops like leaky relu

torch/ao/quantization/quantization_mappings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
# Default map for swapping float module to reference quantized modules
2828
DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
29+
QuantStub: nnq.Quantize,
30+
DeQuantStub: nnq.DeQuantize,
2931
nn.Linear: nnqr.Linear,
3032
nn.Conv1d: nnqr.Conv1d,
3133
nn.Conv2d: nnqr.Conv2d,
@@ -175,6 +177,11 @@ def get_default_static_quant_module_mappings() -> Dict[Callable, Any]:
175177
'''
176178
return copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
177179

180+
def get_default_static_quant_reference_module_mappings() -> Dict[Callable, Any]:
181+
''' Get reference module mapping for post training static quantization
182+
'''
183+
return copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS)
184+
178185
def get_embedding_static_quant_module_mappings() -> Dict[Callable, Any]:
179186
''' Get module mapping, including mapping for embedding QAT
180187
'''

torch/ao/quantization/quantize.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from torch.ao.quantization.quantization_mappings import (
1111
get_default_dynamic_quant_module_mappings,
1212
get_default_static_quant_module_mappings,
13+
get_default_static_quant_reference_module_mappings,
1314
get_default_qat_module_mappings,
1415
get_default_qconfig_propagation_list,
1516
no_observer_set,
@@ -472,7 +473,7 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
472473

473474
def convert(
474475
module, mapping=None, inplace=False, remove_qconfig=True,
475-
convert_custom_config_dict=None):
476+
is_reference=False, convert_custom_config_dict=None):
476477
r"""Converts submodules in input module to a different module according to `mapping`
477478
by calling `from_float` method on the target module class. And remove qconfig at the
478479
end if remove_qconfig is set to True.
@@ -503,15 +504,15 @@ def convert(
503504
if not inplace:
504505
module = copy.deepcopy(module)
505506
_convert(
506-
module, mapping, inplace=True,
507+
module, mapping, inplace=True, is_reference=is_reference,
507508
convert_custom_config_dict=convert_custom_config_dict)
508509
if remove_qconfig:
509510
_remove_qconfig(module)
510511
return module
511512

512513
def _convert(
513514
module, mapping=None, inplace=False,
514-
convert_custom_config_dict=None):
515+
is_reference=False, convert_custom_config_dict=None):
515516
r"""Converts submodules in input module to a different module according to `mapping`
516517
by calling `from_float` method on the target module class
517518
@@ -522,10 +523,12 @@ def _convert(
522523
Modules
523524
inplace: carry out model transformations in-place, the original module
524525
is mutated
526+
is_reference: a flag to enable quantized reference module
525527
526528
"""
527529
if mapping is None:
528-
mapping = get_default_static_quant_module_mappings()
530+
mapping = get_default_static_quant_reference_module_mappings() if is_reference \
531+
else get_default_static_quant_module_mappings()
529532
if convert_custom_config_dict is None:
530533
convert_custom_config_dict = {}
531534
custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
@@ -539,7 +542,7 @@ def _convert(
539542
if not isinstance(mod, _FusedModule) and \
540543
type(mod) not in custom_module_class_mapping:
541544
_convert(mod, mapping, True, # inplace
542-
convert_custom_config_dict)
545+
is_reference, convert_custom_config_dict)
543546
reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
544547

545548
for key, value in reassign.items():

torch/nn/quantized/_reference/modules/linear.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ class Linear(nn.Linear, ReferenceQuantizedModule):
1212
and dequantize the weight before running the floating point functional
1313
linear operator.
1414
"""
15+
_IS_REFERENCE = True
16+
1517
def __init__(
1618
self,
1719
in_features: int,

0 commit comments

Comments
 (0)