@@ -61,6 +61,7 @@
 from torch.testing._internal.common_quantized import (
     override_quantized_engine,
     supported_qengines,
+    override_qengines,
 )
 from torch.testing._internal.common_utils import TemporaryFileName
 from torch.testing._internal.common_utils import suppress_warnings
@@ -445,6 +446,55 @@ def checkQuantized(model):
                                  self.calib_data)
         checkQuantized(model_oneline)
 
+    @override_qengines
+    def test_forward_hooks_preserved(self):
+        r"""Test that post-training static quantization preserves the
+        pre-forward and post-forward hooks of the original model
+        """
+        qengine = torch.backends.quantized.engine
+        model = QuantStubModel()
+        counter = {
+            'pre_forwards': 0,
+            'forwards': 0,
+        }
+
+        def fw_pre_hook(h_module, input):
+            counter['pre_forwards'] += 1
+
+        def fw_hook(h_module, input, output):
+            counter['forwards'] += 1
+
+        model.fc.register_forward_pre_hook(fw_pre_hook)
+        model.fc.register_forward_hook(fw_hook)
+
+        model.qconfig = torch.quantization.get_default_qconfig(qengine)
+        model = prepare(model)
+
+        def checkHooksIsPresent(model, before_convert=True):
+            num_fwd_hooks = 1
+            if before_convert:
+                self.assertEqual(len(model.quant._forward_hooks.values()), 1,
+                                 "Quantization observer hook has disappeared")
+                num_fwd_hooks = 2
+
+            self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
+            self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
+            self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
+                             "Extra pre forward hooks have appeared on a layer")
+            # During static quantization, non-stub layers also receive a quantization observer hook
+            self.assertEqual(len(model.fc._forward_hooks.values()), num_fwd_hooks,
+                             "Extra post forward hooks have appeared on a layer")
+            # Implicitly check that fw_hook runs after _observer_forward_hook
+            self.assertEqual(list(model.fc._forward_hooks.values())[-1], fw_hook,
+                             "_observer_forward_hook is not the first entry of the hooks list")
+
+        checkHooksIsPresent(model, True)
+        test_only_eval_fn(model, self.calib_data)
+        torch.quantization.convert(model, inplace=True)
+        checkHooksIsPresent(model, False)
+
+
+
 @skipIfNoFBGEMM
 class TestPostTrainingDynamic(QuantizationTestCase):
     def test_single_layer(self):
@@ -752,6 +802,46 @@ def checkQuantized(model, module_type):
         self.checkScriptable(model_quantized, [[x]], check_save_load=True)
 
 
+    def test_forward_hooks_preserved(self):
+        r"""Test that post-training dynamic quantization preserves the
+        pre-forward and post-forward hooks of the original model
+        """
+        for dtype in [torch.qint8, torch.float16]:
+            model = SingleLayerLinearDynamicModel().eval()
+            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
+            qconfig_dict = {
+                'fc1': qconfig
+            }
+            convert_dynamic(model)
+
+            counter = {
+                'pre_forwards': 0,
+                'forwards': 0,
+            }
+
+            def fw_pre_hook(h_module, input):
+                counter['pre_forwards'] += 1
+
+            def fw_hook(h_module, input, output):
+                counter['forwards'] += 1
+
+            model.fc1.register_forward_pre_hook(fw_pre_hook)
+            model.fc1.register_forward_hook(fw_hook)
+            prepare_dynamic(model, qconfig_dict)
+
+            def checkHooksIsPresent(model):
+                self.assertObjectIn(fw_pre_hook, model.fc1._forward_pre_hooks.values())
+                self.assertObjectIn(fw_hook, model.fc1._forward_hooks.values())
+                self.assertEqual(len(model.fc1._forward_pre_hooks.values()), 1,
+                                 "Extra pre forward hooks have appeared on a layer")
+                self.assertEqual(len(model.fc1._forward_hooks.values()), 1,
+                                 "Extra post forward hooks have appeared on a layer")
+
+            checkHooksIsPresent(model)
+            test_only_eval_fn(model, self.calib_data)
+            convert_dynamic(model)
+            checkHooksIsPresent(model)
+
 class TestQuantizationAwareTraining(QuantizationTestCase):
     def test_manual(self):
         for qengine in supported_qengines:
@@ -864,6 +954,45 @@ def test_train_save_load_eval(self):
             out = model(x)
             self.assertEqual(ref, out)
 
+    @override_qengines
+    def test_forward_hooks_preserved(self):
+        r"""Test that QAT preserves the pre-forward and post-forward hooks of the original model
+        """
+        qengine = torch.backends.quantized.engine
+        model = QuantStubModel()
+        counter = {
+            'pre_forwards': 0,
+            'forwards': 0,
+        }
+
+        def fw_pre_hook(h_module, input):
+            counter['pre_forwards'] += 1
+
+        def fw_hook(h_module, input, output):
+            counter['forwards'] += 1
+
+        model.fc.register_forward_pre_hook(fw_pre_hook)
+        model.fc.register_forward_hook(fw_hook)
+
+        model.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
+        model = prepare_qat(model)
+
+        def checkHooksIsPresent(model, before_convert=True):
+            if before_convert:
+                self.assertEqual(len(model.quant._forward_hooks.values()), 1,
+                                 "Quantization observer hook has disappeared")
+            self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
+            self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
+            self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
+                             "Extra pre forward hooks have appeared on a layer")
+            self.assertEqual(len(model.fc._forward_hooks.values()), 1,
+                             "Extra post forward hooks have appeared on a layer")
+
+        checkHooksIsPresent(model, True)
+        x = torch.rand(2, 5, dtype=torch.float)
+        model(x)
+        torch.quantization.convert(model, inplace=True)
+        checkHooksIsPresent(model, False)
 
 class TestFunctionalModule(QuantizationTestCase):
     # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out
@@ -1156,6 +1285,52 @@ def checkQAT(model):
 
         checkQAT(model)
 
+    def test_forward_hooks_preserved(self):
+        r"""Test that fusion preserves the forward pre hooks of the first module and the
+        forward hooks of the last module in the module list passed to the fusion function.
+        (e.g. before fusion: [nn.Conv2d (with pre forward hooks), nn.BatchNorm2d, nn.ReLU (with post forward hooks)]
+        after fusion: [nni.ConvBnReLU2d (with pre and post hooks), nn.Identity, nn.Identity])
+        """
+        model = ModelForFusion(default_qat_qconfig).train()
+
+        counter = {
+            'pre_forwards': 0,
+            'forwards': 0,
+        }
+        fused = False
+
+        def fw_pre_hook(fused_module_class, h_module, input):
+            if fused:
+                self.assertEqual(type(h_module), fused_module_class,
+                                 "After fusion, the owner of the first module's forward pre hook is not the fused module")
+            counter['pre_forwards'] += 1
+
+        def fw_hook(fused_module_class, h_module, input, output):
+            if fused:
+                self.assertEqual(type(h_module), fused_module_class,
+                                 "After fusion, the owner of the last module's forward hook is not the fused module")
+            counter['forwards'] += 1
+
+        # Two pre and two post forward hooks are registered, so each inference should increment both counters by two
+        model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args))
+        model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args))
+        model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args))
+        model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args))
+
+        test_only_eval_fn(model, self.img_data_1d)
+        self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
+        self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))
+
+        model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
+        model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
+
+        fused = True
+        before_fusion_pre_count = counter['pre_forwards']
+        before_fusion_post_count = counter['forwards']
+        test_only_eval_fn(model, self.img_data_1d)
+        self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d))
+        self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d))
+
 class TestModelNumerics(QuantizationTestCase):
     def test_float_quant_compare_per_tensor(self):
         for qengine in supported_qengines: