Commit 0e3b5ea

jerryzh168 authored and pytorchmergebot committed
[quant][fx] Add _convert_to_reference_decomposed (#87094)
Summary: _convert_to_reference_decomposed is a private convert function in the FX graph mode quantization flow that converts a calibrated/trained model to a reference quantized model with a decomposed quantized tensor representation, i.e. quantize/dequantize expressed as explicit torch.ops.quantized_decomposed ops on plain integer tensors rather than as quantized Tensors.

Test Plan: python test/test_quantization.py TestQuantizeFx.test__convert_to_reference_decomposed_fx

Pull Request resolved: #87094
Approved by: https://github.com/andrewor14
1 parent a12d3d6 commit 0e3b5ea

File tree:
- test/quantization/fx/test_quantize_fx.py
- torch/ao/quantization/fx/convert.py
- torch/ao/quantization/fx/utils.py
- torch/ao/quantization/quantize_fx.py
- torch/ao/quantization/utils.py

5 files changed: +177 −28 lines changed

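Before the per-file diffs, here is a minimal end-to-end sketch of where the new private API fits in the FX graph mode flow, modeled on the test added in this commit. The model, qconfig mapping, and calibration input below are illustrative choices, not part of the commit, and _convert_to_reference_decomposed_fx is private and may change:

    import torch
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import prepare_fx, _convert_to_reference_decomposed_fx

    model = torch.nn.Sequential(torch.nn.Linear(5, 10)).eval()
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    example_inputs = (torch.randn(1, 5),)

    # insert observers and calibrate
    prepared = prepare_fx(model, qconfig_mapping, example_inputs)
    prepared(*example_inputs)

    # produce a reference quantized model whose quantize/dequantize steps are
    # decomposed ops (torch.ops.quantized_decomposed.*) instead of quantized Tensors
    reference = _convert_to_reference_decomposed_fx(prepared)
    reference(*example_inputs)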

test/quantization/fx/test_quantize_fx.py

Lines changed: 26 additions & 0 deletions

@@ -18,10 +18,12 @@
     prepare_fx,
     convert_fx,
     convert_to_reference_fx,
+    _convert_to_reference_decomposed_fx,
     prepare_qat_fx,
     fuse_fx,
 )

+
 from torch.ao.quantization.fx.quantization_patterns import DefaultNodeQuantizeHandler

 from torch.ao.quantization.fx.match_utils import (

@@ -5237,6 +5239,30 @@ def test_get_default_qconfig_valid_backend(self):
         with self.assertRaisesRegex(AssertionError, "not supported"):
             qconfig_mapping = get_default_qat_qconfig_mapping(invalid_backend)

+    def test__convert_to_reference_decomposed_fx(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 10)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        m = M().eval()
+        qconfig_mapping = get_default_qconfig_mapping("fbgemm")
+        example_inputs = (torch.randn(1, 5),)
+        m = prepare_fx(m, qconfig_mapping, example_inputs)
+        m = _convert_to_reference_decomposed_fx(m)
+        expected_occurrence = {
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2,
+        }
+        self.checkGraphModuleNodes(
+            m,
+            expected_node_occurrence=expected_occurrence)
+        # make sure it runs
+        m(*example_inputs)
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     def setUp(self):

torch/ao/quantization/fx/convert.py

Lines changed: 38 additions & 17 deletions

@@ -69,6 +69,8 @@
     PrepareCustomConfig,
 )
 from .lower_to_fbgemm import lower_to_fbgemm
+# importing the lib so that the quantized_decomposed ops are registered
+from ._decomposed import quantized_decomposed_lib  # noqa: F401


 # TODO: revisit this list. Many helper methods shouldn't be public

@@ -485,7 +487,8 @@ def convert(
         is_standalone_module: bool = False,
         _remove_qconfig_flag: bool = True,
         qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
-        backend_config: Union[BackendConfig, Dict[str, Any], None] = None) -> torch.nn.Module:
+        backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+        is_decomposed: bool = False) -> torch.nn.Module:
     """
     We will convert an observed model (a module with observer calls) to a reference
     quantized model, the rule is simple:

@@ -497,13 +500,21 @@ def convert(
     is stored in observed_node_names, we can decide whether we need to swap the
     module based on this set

-    standalone_module means it a submodule that is not inlined in
-    parent module, and will be quantized separately as one unit.
-
-    Returns a quantized standalone module, whether input/output is quantized is
-    specified by prepare_custom_config, with
-    input_quantized_idxs, output_quantized_idxs, please
-    see docs for prepare_fx for details
+    Args:
+       * `is_standalone_module`: when this flag is True, it means we are quantizing
+       a submodule that is not inlined in parent module, and will be quantized
+       separately as one unit.
+
+       * `is_decomposed`: a boolean flag to indicate whether we want to use the
+       quantize operator for decomposed quantized tensor
+       (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
+       quantized tensor (torch.quantize_per_tensor)
+
+    Returns:
+         a quantized standalone module, whether input/output is quantized is
+         specified by prepare_custom_config, with
+         input_quantized_idxs, output_quantized_idxs, please
+         see docs for :func:`~torch.ao.quantization.prepare_fx` for details
     """
     if convert_custom_config is None:
         convert_custom_config = ConvertCustomConfig()

@@ -595,7 +606,8 @@ def replace_observer_with_quantize_dequantize_node(
         node: Node,
         modules: Dict[str, torch.nn.Module],
         node_name_to_scope: Dict[str, Tuple[str, type]],
-        node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
+        node_name_to_qconfig: Dict[str, QConfigAny],
+        is_decomposed: bool) -> None:
     """ Replace activation_post_process module call node with quantize and
     dequantize node


@@ -608,7 +620,7 @@ def replace_observer_with_quantize_dequantize_node(
     assert isinstance(node.target, str)
     module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
     observer_module = modules[node.target]
-    maybe_quantize_node_info = get_quantize_node_info(observer_module)
+    maybe_quantize_node_info = get_quantize_node_info(observer_module, is_decomposed)
     # Skip replacing observers to quant/dequant nodes if the qconfigs of all
     # consumers and producers of this observer are None
     skip_replacement = all([

@@ -626,21 +638,30 @@
     # replace observer node with quant - dequant node
     with graph.inserting_before(node):
         input_node = node.args[0]
-        inputs = [input_node]
+        quantize_op_inputs = [input_node]
         for key, value in qparams.items():
             # TODO: we can add the information of whether a value needs to
             # be registered as an attribute in qparams dict itself
             if key in ['_scale_', '_zero_point_']:
                 # For scale and zero_point values we register them as buffers in the root module.
                 # TODO: maybe need more complex attr name here
                 qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
-                inputs.append(qparam_node)
+                quantize_op_inputs.append(qparam_node)
             else:
                 # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
-                inputs.append(value)
-
-        quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
-        dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
+                quantize_op_inputs.append(value)
+
+        quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
+        if is_decomposed:
+            # use the same qparams from quantize op
+            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
+            dequantized_node = graph.call_function(
+                torch.ops.quantized_decomposed.dequantize_per_tensor,
+                tuple(dq_inputs),
+                {}
+            )
+        else:
+            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
         node.replace_all_uses_with(dequantized_node)
         graph.erase_node(node)


@@ -711,7 +732,7 @@ def replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Gra
             else:
                 replace_observer_with_quantize_dequantize_node(
                     model, model.graph, node, modules, node_name_to_scope,
-                    node_name_to_qconfig)
+                    node_name_to_qconfig, is_decomposed)
         elif isinstance(mod, DeQuantStub):
             replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
         elif is_observed_standalone_module(mod):
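For intuition, the decomposed path above swaps each observer for a quantize_per_tensor / dequantize_per_tensor pair that carries its qparams as explicit arguments, instead of producing a quantized Tensor and calling .dequantize(). A standalone sketch of that pattern follows; the scale/zero_point values are made-up placeholders, and the argument order simply mirrors the qparams dict built in get_quantize_node_info (scale, zero_point, quant_min, quant_max, dtype):

    import torch
    # importing the lib so that the quantized_decomposed ops are registered
    from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

    x = torch.randn(1, 5)
    scale, zero_point = 0.02, 128      # placeholder per-tensor qparams
    quant_min, quant_max = 0, 255      # range of the underlying torch.uint8

    xq = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, torch.uint8)
    # the dequantize node reuses the same qparams (see `dq_inputs` above)
    xdq = torch.ops.quantized_decomposed.dequantize_per_tensor(
        xq, scale, zero_point, quant_min, quant_max, torch.uint8)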

torch/ao/quantization/fx/utils.py

Lines changed: 46 additions & 11 deletions

@@ -17,6 +17,7 @@
     activation_is_statically_quantized,
     is_per_tensor,
     is_per_channel,
+    to_underlying_dtype,
 )
 from torch.ao.quantization.quantize import is_activation_post_process


@@ -27,6 +28,8 @@
     Node,
 )
 from .custom_config import PrepareCustomConfig
+# importing the lib so that the quantized_decomposed ops are registered
+from ._decomposed import quantized_decomposed_lib  # noqa: F401

 from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
 from collections import namedtuple

@@ -160,11 +163,22 @@ def get_per_tensor_qparams(activation_post_process):
     dtype = activation_post_process.dtype
     return scale, zero_point, dtype

-def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[str, Union[Callable, str], Dict[str, Any]]]:
-    ''' Given an activation_post_process module,
-    return node_type(e.g. call_function), quantize op(e.g. quantize_per_tensor) and a dictionary
-    of extracted qparams from the module
-    '''
+def get_quantize_node_info(
+    activation_post_process: Callable,
+    is_decomposed: bool
+) -> Optional[Tuple[str, Union[Callable[..., Any], str], Dict[str, Any]]]:
+    """ Extract information about quantize op from activation_post_process module
+    Args:
+      * `activation_post_process`: observer module instance or fake quant module instance
+        after calibration/QAT
+      * `is_decomposed`: a boolean flag to indicate whether we want to use the
+        quantize operator for decomposed quantized tensor (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
+        quantized tensor (torch.quantize_per_tensor)
+
+    Returns
+      node_type(e.g. call_function), quantize op(e.g. quantize_per_tensor) and a dictionary
+      of extracted qparams from the module
+    """
     dtype = activation_post_process.dtype  # type: ignore[attr-defined]
     compute_dtype = None
     if hasattr(activation_post_process, "compute_dtype"):

@@ -177,17 +191,36 @@ def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[
         if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
             ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined]
             qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype}
-            quantize_op = torch.quantize_per_channel
+            if is_decomposed:
+                raise NotImplementedError("decomposed quantize_per_channel op not implemented yet")
+            else:
+                quantize_op = torch.quantize_per_channel
         else:
             scale = float(scale)
             zero_point = int(zero_point)
-            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
-            quantize_op = torch.quantize_per_tensor
+            if is_decomposed:
+                quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
+                quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
+                dtype = to_underlying_dtype(dtype)
+                qparams = {
+                    "_scale_": scale,
+                    "_zero_point_": zero_point,
+                    "_quant_min": quant_max,
+                    "_quant_max": quant_max,
+                    "_dtype_": dtype
+                }
+                quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor
+            else:
+                qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
+                quantize_op = torch.quantize_per_tensor
     elif compute_dtype in [torch.quint8, torch.qint8, torch.float16]:
         # TODO(future PR): switch compute_dtype to is_dynamic
         # dynamic quantization
         node_type = "call_function"
-        quantize_op = torch.quantize_per_tensor_dynamic
+        if is_decomposed:
+            raise NotImplementedError("decomposed quantize_per_tensor_dynamic op not implemented yet")
+        else:
+            quantize_op = torch.quantize_per_tensor_dynamic
         # TODO: get reduce range from observer
         # reduce_range = activation_post_process.reduce_range
         reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")

@@ -199,8 +232,9 @@ def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[
     else:
         warnings.warn(f"Unsupported activation_post_process in get_quantize_node_info: {activation_post_process}")
         return None
-    return node_type, quantize_op, qparams
+    return node_type, quantize_op, qparams  # type: ignore[return-value]

+# TODO: looks like this is not used, remove
 def quantize_node(
         in_node: Node,
         obs_module: torch.nn.Module,

@@ -247,7 +281,8 @@ def quantize_node(
     module_path = ""
     root_module = modules['']
     graph = quantized_graph
-    maybe_quantize_node_info = get_quantize_node_info(obs_module)
+    is_decomposed_qtensor = False
+    maybe_quantize_node_info = get_quantize_node_info(obs_module, is_decomposed_qtensor)
     assert maybe_quantize_node_info is not None, \
         f"Expecting quantize node info not to be None, observer: {obs_module}"
     node_type, quantize_op, qparams = maybe_quantize_node_info
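The scale, zero_point, quant_min, and quant_max that the decomposed branch packs into qparams all come from the calibrated observer. A small observer-only sketch of where those values live (MinMaxObserver is just one example of an activation_post_process module; the random input stands in for calibration data):

    import torch
    from torch.ao.quantization.observer import MinMaxObserver

    obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
    obs(torch.randn(4, 5))                       # calibration pass
    scale, zero_point = obs.calculate_qparams()  # tensors holding the qparams
    print(float(scale), int(zero_point))
    print(obs.quant_min, obs.quant_max)          # typically 0, 255 for quint8
    print(obs.dtype)                             # torch.quint8, mapped via to_underlying_dtype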

torch/ao/quantization/quantize_fx.py

Lines changed: 55 additions & 0 deletions

@@ -530,6 +530,7 @@ def _convert_fx(
     _remove_qconfig: bool = True,
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+    is_decomposed: bool = False,
 ) -> torch.nn.Module:
     """ `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`
     """

@@ -552,6 +553,7 @@ def _convert_fx(
         _remove_qconfig_flag=_remove_qconfig,
         qconfig_mapping=qconfig_mapping,
         backend_config=backend_config,
+        is_decomposed=is_decomposed,
     )

     preserved_attributes = convert_custom_config.preserved_attributes

@@ -676,6 +678,59 @@ def convert_to_reference_fx(
         backend_config=backend_config,
     )

+def _convert_to_reference_decomposed_fx(
+    graph_module: GraphModule,
+    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
+    _remove_qconfig: bool = True,
+    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+) -> torch.nn.Module:
+    r""" Convert a calibrated or trained model to a reference quantized model, with
+    decomposed representation for quantized Tensor
+    see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
+    reference quantzied model is a standard representation of a quantized model provided
+    by FX Graph Mode Quantization, it can be further lowered to run on the target
+    hardware, like accelerators
+
+    Note: this is not public API
+
+    Args:
+        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)
+
+        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
+            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
+
+        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.
+
+        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
+            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
+
+        * `backend_config` (BackendConfig): A configuration for the backend which describes how
+            operators should be quantized in the backend. See
+            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
+
+    Return:
+        A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor
+
+    Example::
+
+        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
+        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
+        # e.g. backend_config = get_default_backend_config("fbgemm")
+        reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_fx._convert_to_reference_decomposed_fx")
+    return _convert_fx(
+        graph_module,
+        is_reference=True,
+        convert_custom_config=convert_custom_config,
+        _remove_qconfig=_remove_qconfig,
+        qconfig_mapping=qconfig_mapping,
+        backend_config=backend_config,
+        is_decomposed=True,
+    )
+

 def _convert_standalone_module_fx(
     graph_module: GraphModule,
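Continuing the docstring example, one way to see the decomposed representation is to scan the converted graph for the new call_function nodes. This is a sketch only: `reference_quantized_model` is assumed to be the output of _convert_to_reference_decomposed_fx as in the Example above.

    import torch

    for node in reference_quantized_model.graph.nodes:
        if node.op == "call_function" and node.target in (
            torch.ops.quantized_decomposed.quantize_per_tensor,
            torch.ops.quantized_decomposed.dequantize_per_tensor,
        ):
            # scale/zero_point appear as get_attr nodes; quant_min/quant_max/dtype as literals
            print(node.format_node())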

torch/ao/quantization/utils.py

Lines changed: 12 additions & 0 deletions

@@ -140,6 +140,17 @@ def getattr_from_fqn(obj: Any, fqn: str) -> Any:
     """
     return functools.reduce(getattr, fqn.split("."), obj)

+def to_underlying_dtype(qdtype):
+    DTYPE_MAPPING = {
+        torch.quint8: torch.uint8,
+        torch.qint8: torch.int8,
+        torch.qint32: torch.int32,
+        torch.quint4x2: torch.uint8,
+        torch.quint2x4: torch.uint8,
+    }
+    assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + qdtype
+    return DTYPE_MAPPING[qdtype]
+
 def get_qparam_dict(observer_or_fake_quant):
     qscheme = observer_or_fake_quant.qscheme if hasattr(observer_or_fake_quant, "qscheme") else None
     dtype = observer_or_fake_quant.dtype

@@ -562,4 +573,5 @@ def _patched_module_call(self, *args, **kwargs):
     "calculate_qmin_qmax",
     "has_no_children_ignoring_parametrizations",
     "get_fqn_to_example_inputs",
+    "to_underlying_dtype",
 ]
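A quick sanity check of the new helper, following the mapping table above (run from a Python session with this commit applied):

    import torch
    from torch.ao.quantization.utils import to_underlying_dtype

    # quantized dtypes map onto the plain integer dtypes used by the
    # decomposed quantized tensor representation
    assert to_underlying_dtype(torch.quint8) == torch.uint8
    assert to_underlying_dtype(torch.qint8) == torch.int8
    assert to_underlying_dtype(torch.qint32) == torch.int32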
