
Commit a1382b8

[quant][fx] Support overriding observers and fake quantize modules in backend_config_dict

Summary: Some operators have fixed quantization parameters. This PR adds support for overriding the qconfig's observer and fake quantize constructors via backend_config_dict.

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

ghstack-source-id: 1dbd088
Pull Request resolved: #75135
1 parent 4f78ca2 commit a1382b8
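
For context, the override is expressed through two new backend_config_dict keys, as in the _HARDSIGMOID_MODULE_CONFIG entry added to native.py below. A minimal sketch of such an entry (key names come from this diff; the absolute import paths mirror the relative imports in native.py):

import torch
from torch.ao.quantization.observer import default_affine_fixed_qparams_observer
from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize

# Sketch: pin the output quantization parameters of torch.nn.Hardsigmoid,
# regardless of which observer the user's qconfig specifies.
hardsigmoid_config = {
    "pattern": torch.nn.Hardsigmoid,
    # used during QAT (is_training=True)
    "overwrite_output_fake_quantizer": FixedQParamsFakeQuantize.with_args(
        observer=default_affine_fixed_qparams_observer),
    # used during PTQ (is_training=False)
    "overwrite_output_observer": default_affine_fixed_qparams_observer,
}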

5 files changed: +56 −6 lines changed

test/quantization/fx/test_quantize_fx.py

Lines changed: 1 addition & 1 deletion
@@ -4034,7 +4034,7 @@ def _assertFixedQParamsFakeQuantizeEqual(self, fq1, fq2):
         self.assertEqual(fq1()._observer_ctr, fq2()._observer_ctr)
 
     def test_fixed_qparams_patterns(self):
-        hard_sigmoid_keys = [torch.nn.Hardsigmoid, torch.nn.functional.hardsigmoid, "hardsigmoid", "hardsigmoid_"]
+        hard_sigmoid_keys = [torch.nn.functional.hardsigmoid, "hardsigmoid", "hardsigmoid_"]
         sigmoid_keys = [torch.nn.Sigmoid, torch.sigmoid, "sigmoid", "sigmoid_"]
         tanh_keys = [torch.nn.Tanh, torch.tanh, "tanh", "tanh_"]
         for k in hard_sigmoid_keys + sigmoid_keys:

torch/ao/quantization/fx/backend_config/native.py

Lines changed: 12 additions & 0 deletions
@@ -2,6 +2,8 @@
 import torch.nn.qat as nnqat
 import operator
 from .observation_type import ObservationType
+from ...observer import default_affine_fixed_qparams_observer
+from ...fake_quantize import FixedQParamsFakeQuantize
 
 def _get_default_op_backend_config(op, dtype_configs):
     return {
@@ -91,6 +93,15 @@ def _get_default_op_backend_config(op, dtype_configs):
     ],
 }
 
+_HARDSIGMOID_MODULE_CONFIG = {
+    "pattern": torch.nn.Hardsigmoid,
+    "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+    "overwrite_output_fake_quantizer": FixedQParamsFakeQuantize.with_args(observer=default_affine_fixed_qparams_observer),
+    "overwrite_output_observer": default_affine_fixed_qparams_observer,
+    "dtype_configs": [
+        weighted_op_int8_dtype_config,
+    ],
+}
 
 def get_native_backend_config_dict():
     """ Get backend for PyTorch Native backend_config_dict (fbgemm/qnnpack)
@@ -102,5 +113,6 @@ def get_native_backend_config_dict():
         _LINEAR_MODULE_CONFIG,
         *_DEFAULT_OP_INT8_CONFIGS,
         _ADD_CONFIG,
+        _HARDSIGMOID_MODULE_CONFIG,
     ],
 }
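
Why a fixed-qparams observer works here: Hardsigmoid's output always lies in [0, 1], so the output scale and zero point can be chosen up front instead of being estimated from calibration data. To my understanding, default_affine_fixed_qparams_observer is roughly equivalent to the following sketch:

import torch
from torch.ao.quantization.observer import FixedQParamsObserver

# [0, 1] mapped onto quint8's [0, 255] gives scale = 1/256, zero_point = 0,
# so no calibration pass is needed for this output.
fixed_observer = FixedQParamsObserver.with_args(
    scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8,
    quant_min=0, quant_max=255)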

torch/ao/quantization/fx/backend_config/quantize_handler.py

Lines changed: 35 additions & 3 deletions
@@ -1,11 +1,18 @@
 import torch
-from typing import Dict, Callable
+from typing import Dict, Callable, Any, Optional
 from .observation_type import ObservationType
 from ..quantization_patterns import QuantizeHandler
-from ..quantization_types import NodePattern
+from ..quantization_types import Pattern, NodePattern
+from ...utils import (
+    activation_dtype,
+)
 
 def get_quantize_handler_cls(
-        observation_type, dtype_configs, num_tensor_args_to_observation_type):
+        observation_type,
+        dtype_configs,
+        num_tensor_args_to_observation_type,
+        overwrite_output_fake_quantizer,
+        overwrite_output_observer):
 
     class ConfigurableQuantizeHandler(QuantizeHandler):
         def __init__(
@@ -22,8 +29,33 @@ def __init__(
             else:
                 self.observation_type = observation_type
             self.dtype_configs = dtype_configs
+            self.overwrite_output_fake_quantizer = overwrite_output_fake_quantizer
+            self.overwrite_output_observer = overwrite_output_observer
 
         def is_general_tensor_value_op(self) -> bool:
             return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
 
+        # TODO: change this to output activation
+        def get_activation_ctr(
+            self,
+            qconfig: Any,
+            pattern: Pattern,
+            is_training: bool,
+        ) -> Optional[Callable]:
+            """
+            Returns the constructor for the activation observer which should be
+            used for the pattern matched to this handler. Some handlers override
+            this to a different value than what is specified in the qconfig.
+            """
+            act_dtype = activation_dtype(qconfig)
+            # TODO: change to is_qat
+            if is_training:
+                if act_dtype == torch.quint8 and self.overwrite_output_fake_quantizer is not None:
+                    return self.overwrite_output_fake_quantizer
+            else:
+                if act_dtype == torch.quint8 and self.overwrite_output_observer is not None:
+                    return self.overwrite_output_observer
+            return qconfig.activation
+
     return ConfigurableQuantizeHandler
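
The selection logic in get_activation_ctr boils down to the following standalone paraphrase (not the actual handler method; activation_dtype(qconfig) simply instantiates qconfig.activation and reads its dtype):

import torch

def pick_activation_ctr(qconfig, is_training, overwrite_fq, overwrite_obs):
    # Paraphrase of ConfigurableQuantizeHandler.get_activation_ctr: the
    # overrides only apply to quint8 activations; in QAT the fake quantize
    # override wins, in PTQ the observer override wins, else the qconfig.
    act_dtype = qconfig.activation().dtype
    if is_training:
        if act_dtype == torch.quint8 and overwrite_fq is not None:
            return overwrite_fq
    else:
        if act_dtype == torch.quint8 and overwrite_obs is not None:
            return overwrite_obs
    return qconfig.activation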

torch/ao/quantization/fx/backend_config/utils.py

Lines changed: 8 additions & 1 deletion
@@ -19,8 +19,15 @@ def get_pattern_to_quantize_handlers(
         observation_type = config.get("observation_type", None)
         dtype_configs = config["dtype_configs"]
         num_tensor_args_to_observation_type = config.get("num_tensor_args_to_observation_type", {})
+        overwrite_fake_quantizer = config.get("overwrite_output_fake_quantizer", None)
+        overwrite_observer = config.get("overwrite_output_observer", None)
         pattern_to_quantize_handlers[pattern] = \
-            get_quantize_handler_cls(observation_type, dtype_configs, num_tensor_args_to_observation_type)
+            get_quantize_handler_cls(
+                observation_type,
+                dtype_configs,
+                num_tensor_args_to_observation_type,
+                overwrite_fake_quantizer,
+                overwrite_observer)
 
     return pattern_to_quantize_handlers
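
Because both keys are read with .get(..., None), existing backend configs without overrides keep working unchanged. A hypothetical end-to-end check (the prepare_fx signature shown matches this era of the API and has since changed; treat it as a sketch, not the canonical test):

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx
from torch.ao.quantization.fx.backend_config.native import get_native_backend_config_dict

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hs = torch.nn.Hardsigmoid()

    def forward(self, x):
        return self.hs(x)

# After prepare_fx, the observer attached to the Hardsigmoid output should be
# the fixed-qparams observer, not the qconfig's default histogram observer.
prepared = prepare_fx(
    M().eval(),
    {"": get_default_qconfig("fbgemm")},
    backend_config_dict=get_native_backend_config_dict())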

torch/ao/quantization/fx/quantization_patterns.py

Lines changed: 0 additions & 1 deletion
@@ -239,7 +239,6 @@ class DefaultNodeQuantizeHandler(QuantizeHandler):
     """
     pass
 
-@register_quant_pattern(torch.nn.Hardsigmoid, default_affine_fixed_qparams_observer)
 @register_quant_pattern(torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_observer)
 @register_quant_pattern('hardsigmoid', default_affine_fixed_qparams_observer)
 @register_quant_pattern('hardsigmoid_', default_affine_fixed_qparams_observer)
