
Commit e0b9c1d

[Quant][fx][bc-breaking] Rename fx/*patterns.py
Summary: This commit renames fx/quantization_patterns.py to fx/quantize_handler.py, and fx/fusion_patterns.py to fx/fuse_handler.py. This is because these files contain only QuantizeHandler and FuseHandler respectively, so the new names are more descriptive. A future commit will further break BC by removing all the empty *QuantizeHandler classes.

BC-breaking notes: The following classes under the `torch.ao.quantization.fx.quantization_patterns` namespace are migrated to the `torch.ao.quantization.fx.quantize_handler` namespace:

```
QuantizeHandler
BinaryOpQuantizeHandler
CatQuantizeHandler
ConvReluQuantizeHandler
LinearReLUQuantizeHandler
BatchNormQuantizeHandler
EmbeddingQuantizeHandler
RNNDynamicQuantizeHandler
DefaultNodeQuantizeHandler
FixedQParamsOpQuantizeHandler
CopyNodeQuantizeHandler
GeneralTensorShapeOpQuantizeHandler
CustomModuleQuantizeHandler
StandaloneModuleQuantizeHandler
```

The following classes under the `torch.ao.quantization.fx.fusion_patterns` namespace are migrated to the `torch.ao.quantization.fx.fuse_handler` namespace:

```
DefaultFuseHandler
FuseHandler
```

Test Plan:

```
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
```

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

ghstack-source-id: 5e688fd
Pull Request resolved: #89872
1 parent 76c6dfe commit e0b9c1d
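For downstream code, the rename is a one-line import change: the handler class names are unchanged, only the module paths move. A minimal before/after sketch of the migration, mirroring the import updates in the diffs below:

```python
# Before this commit (old module paths, now BC-broken):
# from torch.ao.quantization.fx.quantization_patterns import QuantizeHandler
# from torch.ao.quantization.fx.fusion_patterns import DefaultFuseHandler

# After this commit (new module paths):
from torch.ao.quantization.fx.quantize_handler import QuantizeHandler
from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler
```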

File tree

17 files changed: +73 additions, -58 deletions

.github/scripts/gql_mocks.json

Lines changed: 2 additions & 2 deletions
(Generated file; diff not rendered by default.)

test/quantization/ao_migration/common.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -6,7 +6,8 @@
 class AOMigrationTestCase(TestCase):
     def _test_package_import(self, package_name: str,
                              base: Optional[str] = None,
-                             skip: List[str] = None):
+                             skip: List[str] = None,
+                             new_package_name: Optional[str] = None):
         r"""Tests the module import by making sure that all the internals match
         (except the dunder methods).
 
@@ -19,8 +20,10 @@ def _test_package_import(self, package_name: str,
         base = base or 'quantization'
         old_base = 'torch.' + base
         new_base = 'torch.ao.' + base
+        if new_package_name is None:
+            new_package_name = package_name
         old_module = importlib.import_module(f'{old_base}.{package_name}')
-        new_module = importlib.import_module(f'{new_base}.{package_name}')
+        new_module = importlib.import_module(f'{new_base}.{new_package_name}')
         old_module_dir = set(dir(old_module))
         new_module_dir = set(dir(new_module))
         # Remove magic modules from checking in subsets
@@ -36,15 +39,17 @@ def _test_package_import(self, package_name: str,
             f"{old_module_dir - new_module_dir}"
 
     def _test_function_import(self, package_name: str, function_list: List[str],
-                              base: Optional[str] = None):
+                              base: Optional[str] = None, new_package_name: Optional[str] = None):
         r"""Tests individual function list import by comparing the functions
         and their hashes."""
         if base is None:
             base = 'quantization'
         old_base = 'torch.' + base
         new_base = 'torch.ao.' + base
+        if new_package_name is None:
+            new_package_name = package_name
         old_location = importlib.import_module(f'{old_base}.{package_name}')
-        new_location = importlib.import_module(f'{new_base}.{package_name}')
+        new_location = importlib.import_module(f'{new_base}.{new_package_name}')
         for fn_name in function_list:
             old_function = getattr(old_location, fn_name)
             new_function = getattr(new_location, fn_name)
```
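
The new `new_package_name` parameter lets these migration tests compare an old `torch.quantization` module against a `torch.ao.quantization` module whose name changed during the migration, instead of assuming the two names match. A minimal usage sketch, mirroring the test updates in the next file:

```python
# Compare torch.quantization.fx.quantization_patterns against the renamed
# torch.ao.quantization.fx.quantize_handler module.
self._test_package_import(
    'fx.quantization_patterns',
    new_package_name='fx.quantize_handler',
)
```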

test/quantization/ao_migration/test_quantization_fx.py

Lines changed: 22 additions & 5 deletions
```diff
@@ -26,7 +26,10 @@ def test_function_import_quantize_fx(self):
         self._test_function_import('quantize_fx', function_list)
 
     def test_package_import_fx(self):
-        self._test_package_import('fx')
+        self._test_package_import('fx', skip=[
+            'fusion_patterns',
+            'quantization_patterns',
+        ])
 
     def test_function_import_fx(self):
         function_list = [
@@ -99,7 +102,10 @@ def test_function_import_fx_equalize(self):
         self._test_function_import('fx._equalize', function_list)
 
     def test_package_import_fx_quantization_patterns(self):
-        self._test_package_import('fx.quantization_patterns')
+        self._test_package_import(
+            'fx.quantization_patterns',
+            new_package_name='fx.quantize_handler',
+        )
 
     def test_function_import_fx_quantization_patterns(self):
         function_list = [
@@ -118,7 +124,11 @@ def test_function_import_fx_quantization_patterns(self):
             'GeneralTensorShapeOpQuantizeHandler',
             'StandaloneModuleQuantizeHandler'
         ]
-        self._test_function_import('fx.quantization_patterns', function_list)
+        self._test_function_import(
+            'fx.quantization_patterns',
+            function_list,
+            new_package_name='fx.quantize_handler',
+        )
 
     def test_package_import_fx_match_utils(self):
         self._test_package_import('fx.match_utils')
@@ -158,14 +168,21 @@ def test_function_import_fx_fuse(self):
         self._test_function_import('fx.fuse', function_list)
 
     def test_package_import_fx_fusion_patterns(self):
-        self._test_package_import('fx.fusion_patterns')
+        self._test_package_import(
+            'fx.fusion_patterns',
+            new_package_name='fx.fuse_handler',
+        )
 
     def test_function_import_fx_fusion_patterns(self):
         function_list = [
             'FuseHandler',
             'DefaultFuseHandler'
         ]
-        self._test_function_import('fx.fusion_patterns', function_list)
+        self._test_function_import(
+            'fx.fusion_patterns',
+            function_list,
+            new_package_name='fx.fuse_handler',
+        )
 
     # we removed matching test for torch.quantization.fx.quantization_types
     # old: torch.quantization.fx.quantization_types
```

test/quantization/core/test_backend_config.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -14,7 +14,7 @@
     ObservationType,
 )
 from torch.ao.quantization.fuser_method_mappings import _reverse_sequential_wrapper2
-from torch.ao.quantization.fx.quantization_patterns import _default_root_node_getter
+from torch.ao.quantization.fx.quantize_handler import _default_root_node_getter
 
 
 class TestBackendConfig(QuantizationTestCase):
```

test/quantization/fx/test_numeric_suite_fx.py

Lines changed: 25 additions & 25 deletions
```diff
@@ -41,7 +41,7 @@
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_quantization import NodeSpec as ns
 from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns
-import torch.ao.quantization.fx.quantization_patterns as qp
+import torch.ao.quantization.fx.quantize_handler as qh
 from torch.ao.ns.fx.pattern_utils import (
     get_type_a_related_to_b,
 )
@@ -85,7 +85,7 @@
 )
 from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
 from torch.ao.quantization.backend_config import get_native_backend_config
-from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers
+from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
 
 
 # Note: these models are not for use outside of this file. While it's good
@@ -669,21 +669,21 @@ def _op_is_unmatchable(op):
                 base_op = pattern
 
             qhandler_cls_all_ops_quantizeable = [
-                qp.CatQuantizeHandler,
-                qp.ConvReluQuantizeHandler,
-                qp.LinearReLUQuantizeHandler,
-                qp.BatchNormQuantizeHandler,
-                qp.EmbeddingQuantizeHandler,
-                qp.RNNDynamicQuantizeHandler,
+                qh.CatQuantizeHandler,
+                qh.ConvReluQuantizeHandler,
+                qh.LinearReLUQuantizeHandler,
+                qh.BatchNormQuantizeHandler,
+                qh.EmbeddingQuantizeHandler,
+                qh.RNNDynamicQuantizeHandler,
             ]
 
             qhandler_cls_quant_op_same_signature = [
-                qp.FixedQParamsOpQuantizeHandler,
-                qp.CopyNodeQuantizeHandler,
-                qp.GeneralTensorShapeOpQuantizeHandler,
+                qh.FixedQParamsOpQuantizeHandler,
+                qh.CopyNodeQuantizeHandler,
+                qh.GeneralTensorShapeOpQuantizeHandler,
             ]
 
-            if qhandler_cls == qp.BinaryOpQuantizeHandler:
+            if qhandler_cls == qh.BinaryOpQuantizeHandler:
                 # these ops do not have quantized equivalents
                 ops_to_skip = [
                     torch.bmm,
@@ -697,11 +697,11 @@ def _op_is_unmatchable(op):
                 self.assertTrue(
                     _op_in_base_sets_of_related_ops(base_op),
                     f"{base_op} not in sets of related ops")
-            elif qhandler_cls == qp.RNNDynamicQuantizeHandler:
+            elif qhandler_cls == qh.RNNDynamicQuantizeHandler:
                 # TODO(future PR): add support for all classes in
                 # RNNDynamicQuantizeHandler
                 pass
-            elif qhandler_cls == qp.DefaultNodeQuantizeHandler:
+            elif qhandler_cls == qh.DefaultNodeQuantizeHandler:
                 self.assertTrue(
                     _op_in_base_sets_of_related_ops(base_op),
                     f"{base_op} not in sets of related ops")
@@ -1606,33 +1606,33 @@ def test_op_io_dtype_coverage(self):
 
             if (
                 qhandler_cls in (
-                    qp.BinaryOpQuantizeHandler,
-                    qp.RNNDynamicQuantizeHandler,
+                    qh.BinaryOpQuantizeHandler,
+                    qh.RNNDynamicQuantizeHandler,
                 )
             ):
                 # TODO(future PR): implement shadowing for binary ops
                 # TODO(future PR): implement shadowing for RNN ops
                 continue
-            elif qhandler_cls == qp.CatQuantizeHandler:
+            elif qhandler_cls == qh.CatQuantizeHandler:
                 self.assertTrue(
                     base_op in FUNS_IO_TYPE_FP32_OR_INT8,
                     f"missing IO type handling for {base_op}")
             elif (
                 qhandler_cls in (
-                    qp.ConvReluQuantizeHandler,
-                    qp.LinearReLUQuantizeHandler,
-                    qp.BatchNormQuantizeHandler,
-                    qp.DefaultNodeQuantizeHandler,
+                    qh.ConvReluQuantizeHandler,
+                    qh.LinearReLUQuantizeHandler,
+                    qh.BatchNormQuantizeHandler,
+                    qh.DefaultNodeQuantizeHandler,
                 )
             ):
                 self.assertTrue(
                     (base_op in FUNS_IO_TYPE_FP32) or (base_op in MODS_IO_TYPE_FP32),
                     f"missing IO type handling for {base_op}")
             elif (
                 qhandler_cls in (
-                    qp.FixedQParamsOpQuantizeHandler,
-                    qp.CopyNodeQuantizeHandler,
-                    qp.GeneralTensorShapeOpQuantizeHandler,
+                    qh.FixedQParamsOpQuantizeHandler,
+                    qh.CopyNodeQuantizeHandler,
+                    qh.GeneralTensorShapeOpQuantizeHandler,
                 )
             ):
                 if (
@@ -1650,7 +1650,7 @@ def test_op_io_dtype_coverage(self):
                     # version, so it does not fit into the cases above.
                     (base_op is torch.nn.Softmax),
                     f"missing IO type handling for {base_op}")
-            elif qhandler_cls == qp.EmbeddingQuantizeHandler:
+            elif qhandler_cls == qh.EmbeddingQuantizeHandler:
                 # embedding shadowing is not implemented, for now
                 continue
             else:
```

test/quantization/fx/test_quantize_fx.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -24,7 +24,7 @@
 )
 
 
-from torch.ao.quantization.fx.quantization_patterns import DefaultNodeQuantizeHandler
+from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler
 
 from torch.ao.quantization.fx.match_utils import (
     is_match,
```

torch/ao/ns/_numeric_suite_fx.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -123,7 +123,7 @@
 from torch.ao.quantization.backend_config import BackendConfig
 from torch.ao.quantization.fx.match_utils import find_matches
 from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig
-from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers
+from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.ns.fx.n_shadows_utils import (
     OutputProp,
```

torch/ao/ns/fx/pattern_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -7,7 +7,7 @@
 from torch.fx.graph import Node
 
 from torch.ao.quantization.backend_config import get_native_backend_config
-from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers
+from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
 from torch.ao.quantization.utils import getattr_from_fqn
 from .ns_types import NSNodeTargetType
 from torch.ao.quantization import (
```

torch/ao/quantization/fx/fuse.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -27,7 +27,7 @@
 
 from .custom_config import FuseCustomConfig
 
-from .fusion_patterns import (
+from .fuse_handler import (
     _get_fusion_pattern_to_fuse_handler_cls,
     FuseHandler,
 )
@@ -40,6 +40,9 @@
 
 __all__ = [
     "fuse",
+    # TODO: We should make this private in the future
+    # This is currently needed for test_public_bindings for some reason
+    "FuseHandler",
 ]
 
 
```

torch/ao/quantization/fx/fusion_patterns.py renamed to torch/ao/quantization/fx/fuse_handler.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -39,7 +39,6 @@ def fuse(self,
              is_qat: bool) -> Node:
         pass
 
-# TODO: move this to backend_config_utils
 class DefaultFuseHandler(FuseHandler):
     def __init__(
         self,
```
