|
157 | 157 | from torch.testing._internal.common_quantization import ( |
158 | 158 | LinearReluLinearModel, |
159 | 159 | LinearReluModel, |
| 160 | + LinearBnLeakyReluModel, |
160 | 161 | QuantizationTestCase, |
161 | 162 | skipIfNoFBGEMM, |
162 | 163 | skip_if_no_torchvision, |
|
166 | 167 | test_only_train_fn, |
167 | 168 | ModelForConvTransposeBNFusion, |
168 | 169 | get_supported_device_types, |
| 170 | + skipIfNoONEDNN, |
169 | 171 | ) |
170 | 172 |
|
171 | 173 | from torch.testing._internal.common_quantization import ( |
@@ -363,6 +365,48 @@ def forward(self, x): |
363 | 365 | expected_node_list=expected_nodes, |
364 | 366 | expected_node_occurrence=expected_occurrence) |
365 | 367 |
|
| 368 | + @skipIfNoONEDNN |
| 369 | + def test_fuse_linear_bn_leaky_relu_eval(self): |
| 370 | + # linear - bn - leaky_relu is fused for onednn backend only |
| 371 | + from torch.ao.quantization.backend_config import get_onednn_backend_config |
| 372 | + expected_nodes = [ |
| 373 | + ns.call_module(nni.LinearLeakyReLU), |
| 374 | + ] |
| 375 | + expected_occurrence = { |
| 376 | + ns.call_module(nn.BatchNorm1d): 0, |
| 377 | + ns.call_module(nn.LeakyReLU): 0, |
| 378 | + } |
| 379 | + |
| 380 | + for with_bn in [True, False]: |
| 381 | + # test eval mode |
| 382 | + m = LinearBnLeakyReluModel(with_bn).eval() |
| 383 | + # fuse_fx is a top level api and only supports eval mode |
| 384 | + m = fuse_fx(m, |
| 385 | + backend_config=get_onednn_backend_config()) |
| 386 | + self.checkGraphModuleNodes( |
| 387 | + m, |
| 388 | + expected_node_list=expected_nodes, |
| 389 | + expected_node_occurrence=expected_occurrence) |
| 390 | + |
| 391 | + def test_no_fuse_linear_bn_leaky_relu_eval(self): |
| 392 | + # Make sure linear - bn - leaky_relu is not fused by default |
| 393 | + for with_bn in [True, False]: |
| 394 | + # test eval mode |
| 395 | + m = LinearBnLeakyReluModel(with_bn).eval() |
| 396 | + # fuse_fx is a top level api and only supports eval mode |
| 397 | + m = fuse_fx(m) |
| 398 | + expected_nodes = [ |
| 399 | + ns.call_module(nn.Linear), |
| 400 | + ns.call_module(nn.LeakyReLU), |
| 401 | + ] |
| 402 | + expected_occurrence = { |
| 403 | + ns.call_module(nni.LinearLeakyReLU): 0, |
| 404 | + } |
| 405 | + self.checkGraphModuleNodes( |
| 406 | + m, |
| 407 | + expected_node_list=expected_nodes, |
| 408 | + expected_node_occurrence=expected_occurrence) |
| 409 | + |
366 | 410 | def test_fuse_convtranspose_bn_eval(self): |
367 | 411 |
|
368 | 412 | m = ModelForConvTransposeBNFusion().eval() |
|
0 commit comments