
Commit 037e036

[Quant][FX] Lower QLinearLeakyReLU for onednn backend
ghstack-source-id: 0700231
Pull Request resolved: #88668
1 parent a2ba401 commit 037e036
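
For context, a minimal sketch of the end-to-end path this commit completes (not part of the diff below): with the onednn backend config, FX graph mode quantization fuses Linear + LeakyReLU and, on convert, lowers the reference pattern to the quantized fused module nniq.LinearLeakyReLU. The model and variable names here are illustrative; assumes a oneDNN-enabled PyTorch build.

import torch
import torch.nn as nn
import torch.ao.nn.intrinsic.quantized as nniq
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.backend_config import get_onednn_backend_config
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

model = nn.Sequential(nn.Linear(5, 5), nn.LeakyReLU(0.01)).eval()
example_inputs = (torch.rand(1, 5),)

qconfig_mapping = get_default_qconfig_mapping("onednn")
prepared = prepare_fx(model, qconfig_mapping, example_inputs,
                      backend_config=get_onednn_backend_config())
prepared(*example_inputs)  # calibration pass
quantized = convert_fx(prepared, backend_config=get_onednn_backend_config())
# The Linear + LeakyReLU pair should now be a single quantized fused module
assert any(isinstance(m, nniq.LinearLeakyReLU) for m in quantized.modules())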

File tree: 6 files changed, +68 −0 lines

test/quantization/fx/test_quantize_fx.py
Lines changed: 44 additions & 0 deletions

@@ -157,6 +157,7 @@
 from torch.testing._internal.common_quantization import (
     LinearReluLinearModel,
     LinearReluModel,
+    LinearBnLeakyReluModel,
     QuantizationTestCase,
     skipIfNoFBGEMM,
     skip_if_no_torchvision,
@@ -166,6 +167,7 @@
     test_only_train_fn,
     ModelForConvTransposeBNFusion,
     get_supported_device_types,
+    skipIfNoONEDNN,
 )

 from torch.testing._internal.common_quantization import (
@@ -363,6 +365,48 @@ def forward(self, x):
             expected_node_list=expected_nodes,
             expected_node_occurrence=expected_occurrence)

+    @skipIfNoONEDNN
+    def test_fuse_linear_bn_leaky_relu_eval(self):
+        # linear - bn - leaky_relu is fused for onednn backend only
+        from torch.ao.quantization.backend_config import get_onednn_backend_config
+        expected_nodes = [
+            ns.call_module(nni.LinearLeakyReLU),
+        ]
+        expected_occurrence = {
+            ns.call_module(nn.BatchNorm1d): 0,
+            ns.call_module(nn.LeakyReLU): 0,
+        }
+
+        for with_bn in [True, False]:
+            # test eval mode
+            m = LinearBnLeakyReluModel(with_bn).eval()
+            # fuse_fx is a top level api and only supports eval mode
+            m = fuse_fx(m,
+                        backend_config=get_onednn_backend_config())
+            self.checkGraphModuleNodes(
+                m,
+                expected_node_list=expected_nodes,
+                expected_node_occurrence=expected_occurrence)
+
+    def test_no_fuse_linear_bn_leaky_relu_eval(self):
+        # Make sure linear - bn - leaky_relu is not fused by default
+        for with_bn in [True, False]:
+            # test eval mode
+            m = LinearBnLeakyReluModel(with_bn).eval()
+            # fuse_fx is a top level api and only supports eval mode
+            m = fuse_fx(m)
+            expected_nodes = [
+                ns.call_module(nn.Linear),
+                ns.call_module(nn.LeakyReLU),
+            ]
+            expected_occurrence = {
+                ns.call_module(nni.LinearLeakyReLU): 0,
+            }
+            self.checkGraphModuleNodes(
+                m,
+                expected_node_list=expected_nodes,
+                expected_node_occurrence=expected_occurrence)
+
     def test_fuse_convtranspose_bn_eval(self):

         m = ModelForConvTransposeBNFusion().eval()

torch/ao/ns/fx/mappings.py
Lines changed: 2 additions & 0 deletions

@@ -607,6 +607,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nniqat.LinearReLU,
         nniqat.LinearBn1d,
         nniqd.LinearReLU,
+        nni.LinearLeakyReLU,
     ])

     MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = set([
@@ -637,6 +638,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nniq.ConvReLU2d,
         nniq.ConvReLU3d,
         nniq.LinearReLU,
+        nniq.LinearLeakyReLU,
     ])

     MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = set([

torch/ao/quantization/backend_config/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,7 @@
 from .qnnpack import get_qnnpack_backend_config
 from .tensorrt import get_tensorrt_backend_config, get_tensorrt_backend_config_dict
 from .executorch import get_executorch_backend_config
+from .onednn import get_onednn_backend_config

 __all__ = [
     "get_fbgemm_backend_config",
@@ -17,4 +18,5 @@
     "BackendPatternConfig",
     "DTypeConfig",
     "ObservationType",
+    "get_onednn_backend_config",
 ]

torch/ao/quantization/fx/_lower_to_native_backend.py
Lines changed: 1 addition & 0 deletions

@@ -253,6 +253,7 @@ def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: Dict[str, QConfigAny]):
 # 2) The replacement static quantized module class for lowering
 STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]] = {
     nni.LinearReLU: (nnqr.Linear, nniq.LinearReLU),
+    nni.LinearLeakyReLU: (nnqr.Linear, nniq.LinearLeakyReLU),
     nni.ConvReLU1d: (nnqr.Conv1d, nniq.ConvReLU1d),
     nni.ConvReLU2d: (nnqr.Conv2d, nniq.ConvReLU2d),
     nni.ConvReLU3d: (nnqr.Conv3d, nniq.ConvReLU3d),
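
Each entry in this map pairs a fused float module type with the reference module expected inside it and the quantized fused module that replaces it. A rough illustration of the swap the lowering pass performs (hypothetical helper for exposition; the real pass in this file does more graph bookkeeping):

def lower_fused_module(fused_module, output_scale, output_zero_point):
    # Look up the (reference inner module, quantized fused module) pair,
    # e.g. nni.LinearLeakyReLU -> (nnqr.Linear, nniq.LinearLeakyReLU)
    inner_ref_cls, q_cls = STATIC_LOWER_FUSED_MODULE_MAP[type(fused_module)]
    assert type(fused_module[0]) == inner_ref_cls
    # Build the quantized fused module from the reference one
    return q_cls.from_reference(fused_module, output_scale, output_zero_point)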

torch/ao/quantization/quantization_mappings.py
Lines changed: 1 addition & 0 deletions

@@ -110,6 +110,7 @@
     nni.ConvReLU2d: nniq.ConvReLU2d,
     nni.ConvReLU3d: nniq.ConvReLU3d,
     nni.LinearReLU: nniq.LinearReLU,
+    nni.LinearLeakyReLU: nniq.LinearLeakyReLU,
     nniqat.ConvBn1d: nnq.Conv1d,
     nniqat.ConvBn2d: nnq.Conv2d,
     nniqat.ConvBn3d: nnq.Conv3d,

torch/testing/_internal/common_quantization.py
Lines changed: 18 additions & 0 deletions

@@ -1388,6 +1388,24 @@ def forward(self, x):
     def get_example_inputs(self) -> Tuple[Any, ...]:
         return (torch.rand(1, 5),)

+class LinearBnLeakyReluModel(torch.nn.Module):
+    def __init__(self, with_bn=True):
+        super().__init__()
+        self.linear = nn.Linear(5, 5)
+        self.bn1d = nn.BatchNorm1d(5)
+        self.leaky_relu = nn.LeakyReLU(0.01)
+        self.with_bn = with_bn
+
+    def forward(self, x):
+        x = self.linear(x)
+        if self.with_bn:
+            x = self.bn1d(x)
+        x = self.leaky_relu(x)
+        return x
+
+    def get_example_inputs(self) -> Tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
 # TODO: self.fc should be self.conv
 class ConvReluModel(torch.nn.Module):
     def __init__(self):
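
A quick sanity check with the new helper (a sketch mirroring the test above; assumes a oneDNN-enabled build):

import torch.ao.nn.intrinsic as nni
from torch.ao.quantization.backend_config import get_onednn_backend_config
from torch.ao.quantization.quantize_fx import fuse_fx
from torch.testing._internal.common_quantization import LinearBnLeakyReluModel

m = LinearBnLeakyReluModel(with_bn=True).eval()
fused = fuse_fx(m, backend_config=get_onednn_backend_config())
# BatchNorm folds into Linear, and LeakyReLU merges into nni.LinearLeakyReLU
assert any(isinstance(mod, nni.LinearLeakyReLU) for mod in fused.modules())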
