
Commit 691a44f

andrewor14 authored and pytorchmergebot committed
[Quant][fx][bc-breaking] Add simpler BackendConfig pattern format (#90698)
Summary: The existing BackendConfig fusion pattern uses a "reversed nested tuple" format that is highly unintuitive. For example:

```
linear-relu -> (nn.ReLU, nn.Linear)
conv-bn-relu -> (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
```

This pattern format also complicates the signatures of the user-specified "fuser methods", which needed to accept arguments in reverse nested order to match the patterns:

```
def fuse_linear_relu(is_qat, relu, linear):
    ...

def fuse_conv_bn_relu(is_qat, relu, bn_conv):
    (bn, conv) = bn_conv
    ...
```

Instead, this commit introduces a new pattern format that simply specifies the ops in forward order with no nesting:

```
linear-relu -> (nn.Linear, nn.ReLU)
conv-bn-relu -> (nn.Conv2d, nn.BatchNorm2d, nn.ReLU)

def fuse_linear_relu(is_qat, linear, relu):
    ...

def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    ...
```

Note that the legacy "reversed nested tuple" format is still used internally since it is more general. In the future, we should replace it with the format used in the subgraph rewriter in `torch.fx`, and simplify the existing pattern matching code to handle the new format added in this commit.

BC-breaking Notes:

Before:

```
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
from torch.ao.quantization.backend_config import BackendPatternConfig

def fuse_conv_bn_relu(is_qat, relu, bn_conv):
    (bn, conv) = bn_conv
    return nni.ConvBnReLU2d(conv, bn, relu)

config = BackendPatternConfig((nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))) \
    .set_dtype_configs(...) \
    .set_fuser_method(fuse_conv_bn_relu) \
    .set_fused_module(nni.ConvBnReLU2d)
```

After:

```
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    return nni.ConvBnReLU2d(conv, bn, relu)

config = BackendPatternConfig((nn.Conv2d, nn.BatchNorm2d, nn.ReLU)) \
    .set_dtype_configs(...) \
    .set_fuser_method(fuse_conv_bn_relu) \
    .set_fused_module(nni.ConvBnReLU2d)
```

OR (for backward compatibility):

```
def fuse_conv_bn_relu(is_qat, relu, bn_conv):
    (bn, conv) = bn_conv
    return nni.ConvBnReLU2d(conv, bn, relu)

config = BackendPatternConfig() \
    ._set_pattern_complex_format((nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))) \
    .set_dtype_configs(...) \
    .set_fuser_method(fuse_conv_bn_relu) \
    .set_fused_module(nni.ConvBnReLU2d) \
    ._set_use_legacy_pattern_format(True)
```

Before:

```
backend_config.configs  # returns Dict[Pattern, BackendPatternConfig]
```

After:

```
backend_config.configs  # returns List[BackendPatternConfig]
```

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestBackendConfig

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

Differential Revision: [D41954553](https://our.internmc.facebook.com/intern/diff/D41954553)

Pull Request resolved: #90698

Approved by: https://github.com/vkuzo, https://github.com/jerryzh168
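For callers that indexed `backend_config.configs` by pattern, a minimal migration sketch is shown below (the `pattern_to_config` helper dict and the `(nn.Conv2d, nn.ReLU)` lookup are illustrative and not part of this commit; they assume the native backend still registers a conv-relu pattern in the new forward-order format):

```python
import torch
from torch.ao.quantization.backend_config import get_native_backend_config

backend_config = get_native_backend_config()

# backend_config.configs now returns List[BackendPatternConfig] instead of
# Dict[Pattern, BackendPatternConfig], so pattern-keyed lookups must be
# rebuilt from each config's `pattern` attribute. Configs that use the
# complex format may have pattern=None, so filter those out.
pattern_to_config = {
    cfg.pattern: cfg for cfg in backend_config.configs if cfg.pattern is not None
}
conv_relu_config = pattern_to_config.get((torch.nn.Conv2d, torch.nn.ReLU))
```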
1 parent 1e347b7 commit 691a44f

File tree

14 files changed: +285 -161 lines changed


test/quantization/core/test_backend_config.py

Lines changed: 14 additions & 14 deletions
@@ -13,7 +13,7 @@
     DTypeWithConstraints,
     ObservationType,
 )
-from torch.ao.quantization.fuser_method_mappings import _reverse_sequential_wrapper2
+from torch.ao.quantization.fuser_method_mappings import _sequential_wrapper2
 from torch.ao.quantization.fx.quantize_handler import _default_root_node_getter
 
 
@@ -104,7 +104,7 @@ def test_dtype_config_to_dict(self):
     # BackendPatternConfig
     # ======================
 
-    _fuser_method = _reverse_sequential_wrapper2(nni.LinearReLU)
+    _fuser_method = _sequential_wrapper2(nni.LinearReLU)
 
     _num_tensor_args_to_observation_type = {
         0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
@@ -121,7 +121,7 @@ def _extra_inputs_getter(self, p):
         return (torch.rand(3, 3),)
 
     def _get_backend_op_config1(self):
-        return BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) \
+        return BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) \
             .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
             .add_dtype_config(self.dtype_config1) \
             .add_dtype_config(self.dtype_config2) \
@@ -142,7 +142,7 @@ def _get_backend_op_config2(self):
 
     def _get_backend_pattern_config_dict1(self):
         return {
-            "pattern": (torch.nn.ReLU, torch.nn.Linear),
+            "pattern": (torch.nn.Linear, torch.nn.ReLU),
             "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
             "dtype_configs": [self.dtype_config_dict1, self.dtype_config_dict2],
             "root_module": torch.nn.Linear,
@@ -198,19 +198,19 @@ def test_backend_op_config_set_reference_quantized_module(self):
         self.assertEqual(conf.reference_quantized_module, nnqr.Linear)
 
     def test_backend_op_config_set_fused_module(self):
-        conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
+        conf = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
         self.assertTrue(conf.fused_module is None)
         conf.set_fused_module(nni.LinearReLU)
         self.assertEqual(conf.fused_module, nni.LinearReLU)
 
     def test_backend_op_config_set_fuser_method(self):
-        conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
+        conf = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
         self.assertTrue(conf.fuser_method is None)
         conf.set_fuser_method(self._fuser_method)
         self.assertEqual(conf.fuser_method, self._fuser_method)
 
     def test_backend_op_config_set_root_node_getter(self):
-        conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
+        conf = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
         self.assertTrue(conf._root_node_getter is None)
         conf._set_root_node_getter(_default_root_node_getter)
         self.assertEqual(conf._root_node_getter, _default_root_node_getter)
@@ -242,7 +242,7 @@ def test_backend_op_config_set_input_output_observed(self):
     def test_backend_op_config_from_dict(self):
         conf_dict1 = self._get_backend_pattern_config_dict1()
         conf1 = BackendPatternConfig.from_dict(conf_dict1)
-        self.assertEqual(conf1.pattern, (torch.nn.ReLU, torch.nn.Linear))
+        self.assertEqual(conf1.pattern, (torch.nn.Linear, torch.nn.ReLU))
         self.assertEqual(conf1.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
         self.assertEqual(conf1.root_module, torch.nn.Linear)
         self.assertEqual(conf1.qat_module, nnqat.Linear)
@@ -294,11 +294,11 @@ def test_backend_config_set_backend_pattern_config(self):
         backend_op_config1 = self._get_backend_op_config1()
         backend_op_config2 = self._get_backend_op_config2()
         conf.set_backend_pattern_config(backend_op_config1)
-        self.assertEqual(conf.configs, {
+        self.assertEqual(conf._pattern_complex_format_to_config, {
            (torch.nn.ReLU, torch.nn.Linear): backend_op_config1,
         })
         conf.set_backend_pattern_config(backend_op_config2)
-        self.assertEqual(conf.configs, {
+        self.assertEqual(conf._pattern_complex_format_to_config, {
            (torch.nn.ReLU, torch.nn.Linear): backend_op_config1,
            torch.add: backend_op_config2
         })
@@ -317,10 +317,10 @@ def test_backend_config_from_dict(self):
         self.assertEqual(len(conf.configs), 2)
         key1 = (torch.nn.ReLU, torch.nn.Linear)
         key2 = torch.add
-        self.assertTrue(key1 in conf.configs)
-        self.assertTrue(key2 in conf.configs)
-        self.assertEqual(conf.configs[key1].to_dict(), op_dict1)
-        self.assertEqual(conf.configs[key2].to_dict(), op_dict2)
+        self.assertTrue(key1 in conf._pattern_complex_format_to_config)
+        self.assertTrue(key2 in conf._pattern_complex_format_to_config)
+        self.assertEqual(conf._pattern_complex_format_to_config[key1].to_dict(), op_dict1)
+        self.assertEqual(conf._pattern_complex_format_to_config[key2].to_dict(), op_dict2)
 
     def test_backend_config_to_dict(self):
         op1 = self._get_backend_op_config1()
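The swap from `_reverse_sequential_wrapper2` to `_sequential_wrapper2` mirrors the argument-order change: the generated fuser method now receives modules in forward order. A plausible sketch of such a wrapper follows; this is an assumption about its shape, not necessarily the exact PyTorch implementation:

```python
import torch
import torch.ao.nn.intrinsic as nni

def _sequential_wrapper2(sequential):
    """Return a fuser method that combines two modules, given in forward
    order, into the provided fused (sequential) class."""
    def fuser_method(is_qat, m1, m2):
        return sequential(m1, m2)
    return fuser_method

# Mirrors the test above: fuse Linear + ReLU into nni.LinearReLU.
_fuser_method = _sequential_wrapper2(nni.LinearReLU)
fused = _fuser_method(False, torch.nn.Linear(3, 3), torch.nn.ReLU())
```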

test/quantization/fx/test_quantize_fx.py

Lines changed: 15 additions & 9 deletions
@@ -546,9 +546,11 @@ def fuse_conv_bn_relu(is_qat, relu, add_pattern):
             bn, conv = bn_pattern
             return conv
 
-        conv_bn_res_relu_config1 = BackendPatternConfig((nn.ReLU, (torch.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
+        conv_bn_res_relu_config1 = BackendPatternConfig() \
+            ._set_pattern_complex_format((nn.ReLU, (torch.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
             .set_fuser_method(fuse_conv_bn_relu)
-        conv_bn_res_relu_config2 = BackendPatternConfig((nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
+        conv_bn_res_relu_config2 = BackendPatternConfig() \
+            ._set_pattern_complex_format((nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
             .set_fuser_method(fuse_conv_bn_relu)
         backend_config = BackendConfig() \
             .set_backend_pattern_config(conv_bn_res_relu_config1) \
@@ -606,7 +608,8 @@ def conv_bn_res_relu_extra_inputs_getter(pattern):
             bn, conv = bn_pattern
             return [extra_input]
 
-        conv_bn_res_relu_config = BackendPatternConfig((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
+        conv_bn_res_relu_config = BackendPatternConfig() \
+            ._set_pattern_complex_format((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
             .set_fuser_method(fuse_conv_bn_relu) \
             ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
             ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
@@ -654,7 +657,7 @@ def forward(self, x):
 
         m = M().eval()
 
-        def fuse_conv_relu(is_qat, relu, conv):
+        def fuse_conv_relu(is_qat, conv, relu):
             return conv
 
         def fuse_conv_res_relu(is_qat, relu, add_pattern):
@@ -669,9 +672,10 @@ def conv_res_relu_extra_inputs_getter(pattern):
             relu, (_, _, extra_input) = pattern
             return [extra_input]
 
-        conv_relu_config = BackendPatternConfig((nn.ReLU, nn.Conv2d)) \
+        conv_relu_config = BackendPatternConfig((nn.Conv2d, nn.ReLU)) \
             .set_fuser_method(fuse_conv_relu)
-        conv_res_relu_config = BackendPatternConfig((nn.ReLU, (torch.add, nn.Conv2d, MatchAllNode))) \
+        conv_res_relu_config = BackendPatternConfig() \
+            ._set_pattern_complex_format((nn.ReLU, (torch.add, nn.Conv2d, MatchAllNode))) \
             .set_fuser_method(fuse_conv_res_relu) \
             ._set_root_node_getter(conv_res_relu_root_node_getter) \
             ._set_extra_inputs_getter(conv_res_relu_extra_inputs_getter)
@@ -5545,10 +5549,12 @@ def root_node_getter(node_pattern):
                 return transpose
 
             backend_pattern_configs.append(
-                BackendPatternConfig((torch.reshape, torch.transpose, MatchAllNode))
-                .set_observation_type(observation_type)  # noqa: E131
+                BackendPatternConfig()
+                ._set_pattern_complex_format((torch.reshape, torch.transpose, MatchAllNode))  # noqa: E131
+                .set_observation_type(observation_type)
                 .set_dtype_configs(dtype_configs)
-                ._set_root_node_getter(root_node_getter))
+                ._set_root_node_getter(root_node_getter)
+            )
             return backend_pattern_configs
 
         backend_config = BackendConfig().set_backend_pattern_configs(_get_pattern_configs())
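The root node and extra inputs getters referenced in the hunks above destructure the reverse nested pattern. A standalone sketch of such getters for the `(nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))` pattern is shown below; the function bodies are illustrative, not the literal test code:

```python
def conv_bn_res_relu_root_node_getter(pattern):
    # The pattern arrives in "reverse nested tuple" form:
    # (relu, (add, (bn, conv), extra_input))
    relu, add_pattern = pattern
    _, bn_pattern, extra_input = add_pattern
    bn, conv = bn_pattern
    return conv  # the node that replaces the matched subgraph

def conv_bn_res_relu_extra_inputs_getter(pattern):
    relu, add_pattern = pattern
    _, bn_pattern, extra_input = add_pattern
    return [extra_input]  # appended to the fused module's arguments
```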

torch/ao/ns/fx/mappings.py

Lines changed: 15 additions & 21 deletions
@@ -13,18 +13,18 @@
 import torch.nn.intrinsic as nni
 import torch.ao.nn.qat as nnqat
 import torch.ao.nn.qat.dynamic as nnqatd
-from torch.ao.quantization.backend_config import get_native_backend_config_dict
+from torch.ao.quantization.backend_config import get_native_backend_config
 import torch.ao.quantization.fx._lower_to_native_backend as \
     _lower_to_native_backend
 import torch.ao.quantization.quantization_mappings as quantization_mappings
 
 from .ns_types import NSNodeTargetType
 
-from typing import Set, Dict, List, Optional
+from typing import Callable, Dict, List, Optional, Set, Tuple
 
 
 def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
-    # note: this set is modified below by items from backend_config_dict
+    # note: this set is modified below by items from backend_config
     sets_of_related_ops: List[Set[NSNodeTargetType]] = [
         # conv modules
         set([
@@ -327,42 +327,36 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
     ]
 
     # for each floating point op, add versions of the op added by
-    # backend_config_dict
-    backend_config_dict = get_native_backend_config_dict()
+    # backend_config
+    backend_config = get_native_backend_config()
 
-    new_connections = [
+    new_connections: List[Tuple[Callable, Callable]] = [
         # technical debt edge case
         (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
     ]
 
-    for config in backend_config_dict['configs']:
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
 
-        if 'pattern' not in config:
-            continue
-
-        # format: (c, (b, a))
-        pattern = config['pattern']
+        # pattern format: (c, (b, a))
         first_element = pattern
         # look from the end, because pattern is in reverse order
         while isinstance(first_element, (list, tuple)):
             first_element = first_element[-1]
 
-        if 'fused_module' in config:
+        if config.fused_module is not None:
             # case 1: pattern fuses a pattern of ops into an op
             # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
-            new_connections.append((first_element, config['fused_module']))
+            new_connections.append((first_element, config.fused_module))
 
-        if 'qat_module' in config:
+        if config.qat_module is not None:
             # case 2: pattern swaps a module into a QAT module
             # example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d
-            new_connections.append((first_element, config['qat_module']))
+            new_connections.append((first_element, config.qat_module))
 
-        if 'reference_quantized_module_for_root' in config:
+        if config.reference_quantized_module is not None:
             # case 3: reference version of floating point module, such as
             # nn.Conv2d and nnqr.Conv2d
-            new_connections.append(
-                (first_element, config['reference_quantized_module_for_root'])
-            )
+            new_connections.append((first_element, config.reference_quantized_module))
 
     #
     # Add reference module swaps from default lowering path
@@ -413,7 +407,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
             new_connections.append((source, target))
 
 
-    # add the new connections from backend_config_dict
+    # add the new connections from backend_config
    for item1, item2 in new_connections:
        for set_of_related_ops in sets_of_related_ops:
            if item1 in set_of_related_ops or item2 in set_of_related_ops:
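The loop above extracts the first op in execution order from a reverse nested pattern by repeatedly taking the last element. A small standalone illustration of that walk (a sketch written for this note, not code from the commit):

```python
import torch.nn as nn

def first_op_in_forward_order(pattern):
    # Reverse nested patterns list ops backwards, e.g.
    # (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)) for conv -> bn -> relu,
    # so the earliest op sits at the innermost last position.
    first_element = pattern
    while isinstance(first_element, (list, tuple)):
        first_element = first_element[-1]
    return first_element

assert first_op_in_forward_order((nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))) is nn.Conv2d
assert first_op_in_forward_order(nn.Linear) is nn.Linear
```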

torch/ao/quantization/backend_config/README.md

Lines changed: 16 additions & 4 deletions
@@ -22,7 +22,19 @@ Instead of hardcoding the fusion mappings, float to reference quantized module m
 
 ## Pattern Specification
 
-The operator patterns used in BackendConfig are float modules, functional operators and pytorch operators specified in reverse order:
+The operator patterns used in BackendConfig are float modules, functional operators, pytorch operators, or a tuple combination of the above. For example:
+* torch.nn.Linear
+* torch.nn.functional.linear
+* torch.add
+* operator.add
+* (torch.nn.functional.linear, torch.nn.functional.relu)
+* (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU)
+
+Tuple patterns are treated as sequential patterns, and currently only tuples of 2 or 3 elements are supported.
+
+### Advanced Pattern Specification
+
+The above format should satisfy the vast majority of use cases. However, it does not handle more complex scenarios such as graph patterns. For these use cases, the BackendConfig API offers an alternative "reverse nested tuple" pattern format, enabled through `BackendPatternConfig()._set_pattern_complex_format(...)`. Note that this format is deprecated and will be replaced in a future version of PyTorch.
 ```
 operator = module_type | functional | torch op | native op | MatchAllNode
 Pattern = (operator, Pattern, Pattern, ...) | operator
@@ -62,7 +74,7 @@ weighted_int8_dtype_config = DTypeConfig(
     weight_dtype=torch.qint8,
     bias_dtype=torch.float)
 
-def fuse_conv2d_relu(is_qat, relu, conv):
+def fuse_conv2d_relu(is_qat, conv, relu):
     """Return a fused ConvReLU2d from individual conv and relu modules."""
     return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu)
 
@@ -75,7 +87,7 @@ linear_config = BackendPatternConfig(torch.nn.Linear) \
     .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)
 
 # For fusing Conv2d + ReLU into ConvReLU2d
-conv_relu_config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Conv2d)) \
+conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \
     .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
     .add_dtype_config(weighted_int8_dtype_config) \
     .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \
@@ -118,7 +130,7 @@ Relevant APIs:
 * `_set_root_node_getter`
 * `_set_extra_inputs_getter`
 
-As an optimization, operator patterns such as (`torch.nn.ReLU`, `torch.nn.Linear`) may be fused into `nni.LinearReLU`. This is performed during the prepare phase according to the function specified in `set_fuser_method`, which replaces the pattern with the fused module. During the convert phase, these fused modules (identified by `set_fused_module`) will then be converted to the reference quantized versions of the modules.
+As an optimization, operator patterns such as (`torch.nn.Linear`, `torch.nn.ReLU`) may be fused into `nni.LinearReLU`. This is performed during the prepare phase according to the function specified in `set_fuser_method`, which replaces the pattern with the fused module. During the convert phase, these fused modules (identified by `set_fused_module`) will then be converted to the reference quantized versions of the modules.
 
 In FX graph mode quantization, we replace the corresponding nodes in the graph using two helper functions set by the user: `root_node_getter`, which returns the root node (typically the weighted module in the pattern like `torch.nn.Linear`) to replace the matched pattern in the graph, and `extra_inputs_getter`, which returns a list of extra input arguments that will be appended to the existing arguments of the fused module (copied over from the root node). See [this snippet](https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6) for an example usage.
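To make the two README formats concrete, the same conv-bn-relu fusion can be declared either way. A minimal sketch follows (dtype configs omitted for brevity; the variable and helper names are illustrative):

```python
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
from torch.ao.quantization.backend_config import BackendPatternConfig

# Simple format: ops listed in forward execution order,
# fuser method arguments also arrive in forward order.
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    return nni.ConvBnReLU2d(conv, bn, relu)

simple_config = BackendPatternConfig((nn.Conv2d, nn.BatchNorm2d, nn.ReLU)) \
    .set_fuser_method(fuse_conv_bn_relu) \
    .set_fused_module(nni.ConvBnReLU2d)

# Advanced "reverse nested tuple" format: kept for graph patterns,
# fuser method arguments arrive in reverse nested order.
def fuse_conv_bn_relu_legacy(is_qat, relu, bn_conv):
    bn, conv = bn_conv
    return nni.ConvBnReLU2d(conv, bn, relu)

complex_config = BackendPatternConfig() \
    ._set_pattern_complex_format((nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))) \
    .set_fuser_method(fuse_conv_bn_relu_legacy) \
    .set_fused_module(nni.ConvBnReLU2d)
```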
