
Commit 3e3ca34

[Quant][fx][bc-breaking] Remove backend_config_utils.py
Summary: Previously under torch/ao/quantization we had backend_config/utils.py and fx/backend_config_utils.py, which was confusing. This commit deletes the latter and moves everything in it to more suitable util files.

BC-breaking note: The following public APIs under the `torch.ao.quantization.fx.backend_config_utils` namespace are removed in this commit.

```
get_quantize_handler_cls
get_fusion_pattern_to_fuse_handler_cls
get_native_quant_patterns
get_pattern_to_quantize_handlers
```

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

ghstack-source-id: 60fbdf6
Pull Request resolved: #89810
1 parent 22e7514 commit 3e3ca34
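For downstream code that imported the removed helpers, here is a minimal migration sketch based on the call-site updates in this diff (valid at this commit; the replacements are private, underscore-prefixed helpers and are not part of the public API):

```python
# Old imports, removed by this commit:
#   from torch.ao.quantization.fx.backend_config_utils import (
#       get_quantize_handler_cls,
#       get_fusion_pattern_to_fuse_handler_cls,
#       get_native_quant_patterns,
#       get_pattern_to_quantize_handlers,
#   )

# New locations of the equivalent private helpers:
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.fx.quantization_patterns import (
    _get_quantize_handler_cls,
    _get_pattern_to_quantize_handlers,
)
from torch.ao.quantization.fx.fusion_patterns import (
    _get_fusion_pattern_to_fuse_handler_cls,
)

# get_native_quant_patterns() has no one-to-one replacement; the updated
# call sites in this diff compose the two remaining pieces instead:
native_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config())
```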

8 files changed (+93 lines, -129 lines)


test/quantization/fx/test_numeric_suite_fx.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -85,7 +85,7 @@
 )
 from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
 from torch.ao.quantization.backend_config import get_native_backend_config
-from torch.ao.quantization.fx.backend_config_utils import get_pattern_to_quantize_handlers
+from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers


 # Note: these models are not for use outside of this file. While it's good
@@ -299,7 +299,7 @@ def get_all_quant_patterns():
     all_quant_patterns = get_default_quant_patterns()
     # some of the patterns are moved to (native) backend_config_dict so we need to
     # add them back here
-    for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config()).items():
+    for pattern, quantize_handler in _get_pattern_to_quantize_handlers(get_native_backend_config()).items():
         all_quant_patterns[pattern] = quantize_handler
     return all_quant_patterns
```

torch/ao/ns/_numeric_suite_fx.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -121,9 +121,9 @@
 )
 from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
 from torch.ao.quantization.backend_config import BackendConfig
-from torch.ao.quantization.fx.backend_config_utils import get_pattern_to_quantize_handlers
 from torch.ao.quantization.fx.match_utils import find_matches
 from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig
+from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.ns.fx.n_shadows_utils import (
     OutputProp,
@@ -803,7 +803,7 @@ def prepare_n_shadows_model(
     # Find the set of subgraphs in the original graph which we need to
     # consider.
     modules = dict(mt.named_modules(remove_duplicate=False))
-    patterns = get_pattern_to_quantize_handlers(backend_config)
+    patterns = _get_pattern_to_quantize_handlers(backend_config)
     root_node_getter_mapping = \
         get_fusion_pattern_to_root_node_getter(backend_config)
     standalone_module_names: List[str] = []
```

torch/ao/ns/fx/pattern_utils.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -6,9 +6,10 @@
 from torch.fx import GraphModule
 from torch.fx.graph import Node

+from torch.ao.quantization.backend_config import get_native_backend_config
+from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers
 from torch.ao.quantization.utils import getattr_from_fqn
 from .ns_types import NSNodeTargetType
-from torch.ao.quantization.fx.backend_config_utils import get_native_quant_patterns
 from torch.ao.quantization import (
     ObserverBase,
     FakeQuantizeBase,
@@ -66,7 +67,7 @@ def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
     # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
     # For fusions, we only care about patterns composed of multiple ops.
     # TODO(future PR): allow customizations from default patterns.
-    all_quant_patterns = get_native_quant_patterns()
+    all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config())

     default_base_op_idx = 0
     for quant_pattern, _quant_handler in all_quant_patterns.items():
```

torch/ao/quantization/fx/backend_config_utils.py

Lines changed: 0 additions & 113 deletions
This file was deleted.

torch/ao/quantization/fx/fuse.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -24,11 +24,13 @@
     get_fusion_pattern_to_root_node_getter,
     get_fusion_pattern_to_extra_inputs_getter,
 )
-from .backend_config_utils import get_fusion_pattern_to_fuse_handler_cls

 from .custom_config import FuseCustomConfig

-from .fusion_patterns import *  # noqa: F401,F403
+from .fusion_patterns import (
+    _get_fusion_pattern_to_fuse_handler_cls,
+    FuseHandler,
+)

 from typing import Any, Callable, Dict, List, Tuple, Union
 import warnings
@@ -69,7 +71,7 @@ def fuse(
     if backend_config is None:
         backend_config = get_native_backend_config()

-    fusion_pattern_to_fuse_handler_cls = sorted_patterns_dict(get_fusion_pattern_to_fuse_handler_cls(backend_config))
+    fusion_pattern_to_fuse_handler_cls = sorted_patterns_dict(_get_fusion_pattern_to_fuse_handler_cls(backend_config))
     fuser_method_mapping = get_fuser_method_mapping(backend_config)
     fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
     fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(backend_config)
```

torch/ao/quantization/fx/fusion_patterns.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -1,4 +1,5 @@
 import torch
+from torch.ao.quantization.backend_config import BackendConfig
 from torch.fx.graph import Node, Graph
 from ..utils import _parent_name, NodePattern, Pattern
 from ..fuser_method_mappings import get_fuser_method_new
@@ -108,3 +109,12 @@ def get_matched_types(m):
         args.extend(extra_args)
         node.args = tuple(args)
         return node
+
+def _get_fusion_pattern_to_fuse_handler_cls(
+        backend_config: BackendConfig) -> Dict[Pattern, Callable]:
+    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
+    for pattern, config in backend_config.configs.items():
+        if config.fuser_method is not None:
+            # TODO: is this logic right?
+            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
+    return fusion_pattern_to_fuse_handlers
```
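As a quick illustration of the relocated fusion helper, here is a minimal usage sketch mirroring the `fuse.py` change above (names as introduced at this commit):

```python
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.fx.fusion_patterns import (
    DefaultFuseHandler,
    _get_fusion_pattern_to_fuse_handler_cls,
)

# Every pattern whose BackendPatternConfig defines a fuser_method is mapped
# to DefaultFuseHandler, per the function body added above.
pattern_to_handler = _get_fusion_pattern_to_fuse_handler_cls(get_native_backend_config())
assert all(cls is DefaultFuseHandler for cls in pattern_to_handler.values())
```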

torch/ao/quantization/fx/prepare.py

Lines changed: 2 additions & 5 deletions
```diff
@@ -34,6 +34,7 @@
 )

 from .quantization_patterns import (
+    _get_pattern_to_quantize_handlers,
     QuantizeHandler,
 )

@@ -98,10 +99,6 @@
     DTypeConfig,
     get_native_backend_config,
 )
-from .backend_config_utils import (
-    get_pattern_to_quantize_handlers,
-)
-
 from .custom_config import (
     PrepareCustomConfig,
     StandaloneModuleConfigEntry,
@@ -1520,7 +1517,7 @@ def prepare(
     pattern_to_quantize_handler: Dict[Pattern, QuantizeHandler] = {}
     if backend_config is None:
         backend_config = get_native_backend_config()
-    pattern_to_quantize_handler = get_pattern_to_quantize_handlers(backend_config)
+    pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config)
     pattern_to_quantize_handler = sorted_patterns_dict(pattern_to_quantize_handler)

     root_node_getter_mapping = \
```

torch/ao/quantization/fx/quantization_patterns.py

Lines changed: 69 additions & 2 deletions
```diff
@@ -6,10 +6,19 @@
 from .utils import (
     all_node_args_have_no_tensors,
 )
-from torch.ao.quantization.utils import NodePattern
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+    DTypeConfig,
+    ObservationType,
+)
+from torch.ao.quantization.utils import (
+    NodePattern,
+    Pattern,
+    QuantizerCls,
+)

 from abc import ABC
-from typing import Callable, Dict
+from typing import Callable, Dict, List, Type

 __all__ = [
     "QuantizeHandler",
@@ -101,6 +110,64 @@ def is_custom_module(self):
     def is_standalone_module(self):
         return self.is_standalone_module_

+def _get_quantize_handler_cls(
+        observation_type: ObservationType,
+        dtype_configs: List[DTypeConfig],
+        num_tensor_args_to_observation_type: Dict[int, ObservationType],
+        input_output_observed: bool) -> Type[QuantizeHandler]:
+    """
+    Return a configurable QuantizeHandler that matches the given specifications from the backend.
+    """
+
+    class ConfigurableQuantizeHandler(QuantizeHandler):
+        def __init__(
+                self,
+                node_pattern: NodePattern,
+                modules: Dict[str, torch.nn.Module],
+                root_node_getter: Callable = None):
+            super().__init__(node_pattern, modules, root_node_getter)
+            if num_tensor_args_to_observation_type:
+                assert self.num_tensor_args in num_tensor_args_to_observation_type, \
+                    f"Must provide observation_type config for tensor number {self.num_tensor_args}" \
+                    f" in num_tensor_args_to_observation_type for {node_pattern}"
+                self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
+            else:
+                self.observation_type = observation_type
+            self.dtype_configs = dtype_configs
+            self.input_output_observed_ = input_output_observed
+
+        def is_general_tensor_value_op(self) -> bool:
+            return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
+
+        # This is temporary, and will be removed soon
+        def input_output_observed(self):
+            return self.input_output_observed_
+
+    return ConfigurableQuantizeHandler
+
+def _get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]:
+    """
+    Note: Quantize handler is just a holder for some check methods like
+    (should_insert_observer_for_output), maybe this can be a enum as well,
+    we can refactor this after we convert the path for fbgemm/qnnpack fully to the
+    new path, this is not exposed to backend developers
+    """
+    pattern_to_quantize_handlers = {}
+    for pattern, config in backend_config.configs.items():
+        observation_type = config.observation_type
+        dtype_configs = config.dtype_configs
+        num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type
+        input_output_observed = config._input_output_observed
+        if input_output_observed is None:
+            input_output_observed = True
+        pattern_to_quantize_handlers[pattern] = \
+            _get_quantize_handler_cls(
+                observation_type,
+                dtype_configs,
+                num_tensor_args_to_observation_type,
+                input_output_observed)
+    return pattern_to_quantize_handlers
+
 # TODO: remove this class, this is still exposed in torch.quantization
 # but we should be able to break bc
 class BinaryOpQuantizeHandler(QuantizeHandler):
```
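Similarly, a short sketch of the relocated quantize-handler helper: each pattern maps to a per-pattern handler class generated by `_get_quantize_handler_cls`, mirroring the `prepare.py` change above (valid at this commit):

```python
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers

pattern_to_handler_cls = _get_pattern_to_quantize_handlers(get_native_backend_config())
for pattern, handler_cls in list(pattern_to_handler_cls.items())[:5]:
    # Each value is a class (not an instance) that closes over the pattern's
    # observation type and dtype configs from the BackendConfig.
    print(pattern, handler_cls.__name__)
```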
