Commit 0c77bd7

zetyquickly authored and facebook-github-bot committed
Quantization: preserving pre and post forward hooks (#37233)
Summary:
1. convert() now preserves a module's **pre and post forward** hooks.
2. Fusion preserves only a module's **pre forward** hooks (because after fusion the output is no longer the same).

Pull Request resolved: #37233
Differential Revision: D22425141
Pulled By: jerryzh168
fbshipit-source-id: e69b81821d507dcd110d2ff3594ba94b9593c8da
1 parent c451dda commit 0c77bd7
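
For orientation, a minimal sketch of the behavior this change guarantees on the static post-training path. The toy model and hook bodies below are illustrative only and assume an x86 build where the fbgemm backend is available; prepare, convert, and the default qconfig are the existing torch.quantization APIs exercised by the tests in this commit.

import torch
import torch.nn as nn
import torch.quantization as tq

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.fc = nn.Linear(5, 5)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

def log_pre(module, input):            # forward pre hook signature: (module, input)
    print('pre forward on', type(module).__name__)

def log_post(module, input, output):   # forward hook signature: (module, input, output)
    print('post forward on', type(module).__name__)

model = M().eval()
model.fc.register_forward_pre_hook(log_pre)
model.fc.register_forward_hook(log_post)

model.qconfig = tq.get_default_qconfig('fbgemm')
tq.prepare(model, inplace=True)
model(torch.rand(2, 5))        # calibration; hooks fire on the float nn.Linear
tq.convert(model, inplace=True)
model(torch.rand(2, 5))        # with this patch the hooks still fire, now on the quantized Linear

Before this change, convert() swapped nn.Linear for its quantized counterpart without copying the user hooks, so the second inference would run silently without them.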

File tree

3 files changed: +196, -1 lines

test/quantization/test_quantize.py

Lines changed: 175 additions & 0 deletions
@@ -61,6 +61,7 @@
 from torch.testing._internal.common_quantized import (
     override_quantized_engine,
     supported_qengines,
+    override_qengines,
 )
 from torch.testing._internal.common_utils import TemporaryFileName
 from torch.testing._internal.common_utils import suppress_warnings
@@ -445,6 +446,55 @@ def checkQuantized(model):
                                self.calib_data)
             checkQuantized(model_oneline)

+    @override_qengines
+    def test_forward_hooks_preserved(self):
+        r"""Test post-training static quantization on preserving
+        pre forward and post forward hooks of original model
+        """
+        qengine = torch.backends.quantized.engine
+        model = QuantStubModel()
+        counter = {
+            'pre_forwards': 0,
+            'forwards': 0,
+        }
+
+        def fw_pre_hook(h_module, input):
+            counter['pre_forwards'] += 1
+
+        def fw_hook(h_module, input, output):
+            counter['forwards'] += 1
+
+        model.fc.register_forward_pre_hook(fw_pre_hook)
+        model.fc.register_forward_hook(fw_hook)
+
+        model.qconfig = torch.quantization.get_default_qconfig(qengine)
+        model = prepare(model)
+
+        def checkHooksIsPresent(model, before_convert=True):
+            num_fwd_hooks = 1
+            if before_convert:
+                self.assertEqual(len(model.quant._forward_hooks.values()), 1,
+                                 "Quantization observer hook has disappeared")
+                num_fwd_hooks = 2
+
+            self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
+            self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
+            self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
+                             "Extra pre forward hooks have appeared on a layer")
+            # During static quantization non stub layers are provided with quantization observer hook too
+            self.assertEqual(len(model.fc._forward_hooks.values()), num_fwd_hooks,
+                             "Extra post forward hooks have appeared on a layer")
+            # Implicitly check that fw_hook goes after _observer_forward_hook
+            self.assertEqual(list(model.fc._forward_hooks.values())[-1], fw_hook,
+                             "_observer_forward_hook is not a first entry of the hooks list")
+
+        checkHooksIsPresent(model, True)
+        test_only_eval_fn(model, self.calib_data)
+        torch.quantization.convert(model, inplace=True)
+        checkHooksIsPresent(model, False)
+
+
+
 @skipIfNoFBGEMM
 class TestPostTrainingDynamic(QuantizationTestCase):
     def test_single_layer(self):
@@ -752,6 +802,46 @@ def checkQuantized(model, module_type):
         self.checkScriptable(model_quantized, [[x]], check_save_load=True)


+    def test_forward_hooks_preserved(self):
+        r"""Test post-training dynamic quantization on preserving
+        pre forward and post forward hooks of original model
+        """
+        for dtype in [torch.qint8, torch.float16]:
+            model = SingleLayerLinearDynamicModel().eval()
+            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
+            qconfig_dict = {
+                'fc1': qconfig
+            }
+            convert_dynamic(model)
+
+            counter = {
+                'pre_forwards': 0,
+                'forwards': 0,
+            }
+
+            def fw_pre_hook(h_module, input):
+                counter['pre_forwards'] += 1
+
+            def fw_hook(h_module, input, output):
+                counter['forwards'] += 1
+
+            model.fc1.register_forward_pre_hook(fw_pre_hook)
+            model.fc1.register_forward_hook(fw_hook)
+            prepare_dynamic(model, qconfig_dict)
+
+            def checkHooksIsPresent(model):
+                self.assertObjectIn(fw_pre_hook, model.fc1._forward_pre_hooks.values())
+                self.assertObjectIn(fw_hook, model.fc1._forward_hooks.values())
+                self.assertEqual(len(model.fc1._forward_pre_hooks.values()), 1,
+                                 "Extra pre forward hooks have appeared on a layer")
+                self.assertEqual(len(model.fc1._forward_hooks.values()), 1,
+                                 "Extra post forward hooks have appeared on a layer")
+
+            checkHooksIsPresent(model)
+            test_only_eval_fn(model, self.calib_data)
+            convert_dynamic(model)
+            checkHooksIsPresent(model)
+
 class TestQuantizationAwareTraining(QuantizationTestCase):
     def test_manual(self):
         for qengine in supported_qengines:
@@ -864,6 +954,45 @@ def test_train_save_load_eval(self):
         out = model(x)
         self.assertEqual(ref, out)

+    @override_qengines
+    def test_forward_hooks_preserved(self):
+        r"""Test QAT on preserving pre forward and post forward hooks of original model
+        """
+        qengine = torch.backends.quantized.engine
+        model = QuantStubModel()
+        counter = {
+            'pre_forwards': 0,
+            'forwards': 0,
+        }
+
+        def fw_pre_hook(h_module, input):
+            counter['pre_forwards'] += 1
+
+        def fw_hook(h_module, input, output):
+            counter['forwards'] += 1
+
+        model.fc.register_forward_pre_hook(fw_pre_hook)
+        model.fc.register_forward_hook(fw_hook)
+
+        model.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
+        model = prepare_qat(model)
+
+        def checkHooksIsPresent(model, before_convert=True):
+            if before_convert:
+                self.assertEqual(len(model.quant._forward_hooks.values()), 1,
+                                 "Quantization observer hook has disappeared")
+            self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
+            self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
+            self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
+                             "Extra pre forward hooks have appeared on a layer")
+            self.assertEqual(len(model.fc._forward_hooks.values()), 1,
+                             "Extra post forward hooks have appeared on a layer")
+
+        checkHooksIsPresent(model, True)
+        x = torch.rand(2, 5, dtype=torch.float)
+        model(x)
+        torch.quantization.convert(model, inplace=True)
+        checkHooksIsPresent(model, False)

 class TestFunctionalModule(QuantizationTestCase):
     # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out
@@ -1156,6 +1285,52 @@ def checkQAT(model):

         checkQAT(model)

+    def test_forward_hooks_preserved(self):
+        r"""Test case that checks whether forward pre hooks of the first module and
+        post forward hooks of the last module in modules list passed to fusion function preserved.
+        (e.g. before fusion: [nn.Conv2d (with pre forward hooks), nn.BatchNorm2d, nn.ReLU (with post forward hooks)]
+        after fusion: [nni.ConvBnReLU2d (with pre and post hooks), nn.Identity, nn.Identity])
+        """
+        model = ModelForFusion(default_qat_qconfig).train()
+
+        counter = {
+            'pre_forwards': 0,
+            'forwards': 0,
+        }
+        fused = False
+
+        def fw_pre_hook(fused_module_class, h_module, input):
+            if fused:
+                self.assertEqual(type(h_module), fused_module_class,
+                                 "After fusion owner of the first module's forward pre hook is not a fused module")
+            counter['pre_forwards'] += 1
+
+        def fw_hook(fused_module_class, h_module, input, output):
+            if fused:
+                self.assertEqual(type(h_module), fused_module_class,
+                                 "After fusion owner of the last module's forward hook is not a fused module")
+            counter['forwards'] += 1
+
+        # Registering two pre and two post forward hooks, thus expecting counter increment by two each inference
+        model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args))
+        model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args))
+        model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args))
+        model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args))
+
+        test_only_eval_fn(model, self.img_data_1d)
+        self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
+        self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))
+
+        model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
+        model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
+
+        fused = True
+        before_fusion_pre_count = counter['pre_forwards']
+        before_fusion_post_count = counter['forwards']
+        test_only_eval_fn(model, self.img_data_1d)
+        self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d))
+        self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d))
+
 class TestModelNumerics(QuantizationTestCase):
     def test_float_quant_compare_per_tensor(self):
         for qengine in supported_qengines:

torch/quantization/fuse_modules.py

Lines changed: 9 additions & 0 deletions
@@ -124,6 +124,15 @@ def fuse_known_modules(mod_list):
         raise NotImplementedError("Cannot fuse modules: {}".format(types))
     new_mod = [None] * len(mod_list)
     new_mod[0] = fuser_method(*mod_list)
+    # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
+    # Move pre forward hooks of the base module to resulting fused module
+    for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
+        new_mod[0].register_forward_pre_hook(pre_hook_fn)
+        del mod_list[0]._forward_pre_hooks[handle_id]
+    # Move post forward hooks of the last module to resulting fused module
+    for handle_id, hook_fn in mod_list[-1]._forward_hooks.items():
+        new_mod[0].register_forward_hook(hook_fn)
+        del mod_list[-1]._forward_hooks[handle_id]

     for i in range(1, len(mod_list)):
         new_mod[i] = torch.nn.Identity()
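
The effect of these two loops can be seen in a small standalone sketch; the toy model and hook names are made up for illustration, while fuse_modules and nn.intrinsic.ConvReLU2d are the real APIs involved.

import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.quantization import fuse_modules

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

def pre_hook(module, input):
    print('pre forward on', type(module).__name__)

def post_hook(module, input, output):
    print('post forward on', type(module).__name__)

m = M().eval()
m.conv.register_forward_pre_hook(pre_hook)   # pre hook on the first module in the fusion list
m.relu.register_forward_hook(post_hook)      # post hook on the last module in the fusion list

m = fuse_modules(m, ['conv', 'relu'])
assert isinstance(m.conv, nni.ConvReLU2d)    # the fused module now owns both hooks
assert isinstance(m.relu, nn.Identity)       # the replaced module no longer runs any hook
m(torch.rand(1, 3, 4, 4))                    # prints both messages, attributed to ConvReLU2d

Hooks registered on modules in the middle of the fusion list, post hooks on the first module, and pre hooks on the last module are not moved by these loops and are therefore lost, as the NOTE in the patch warns.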

torch/quantization/quantize.py

Lines changed: 12 additions & 1 deletion
@@ -119,7 +119,10 @@ def add_observer_(module, non_leaf_module_list=None, device=None):
         if device is not None:
             activation.to(device)
         module.add_module('activation_post_process', activation)
-        module.register_forward_hook(_observer_forward_hook)
+        # Register observer as the first entry in the hook list
+        # All post forward hooks are preserved and will be executed after the observer before convert
+        handle = module.register_forward_hook(_observer_forward_hook)
+        module._forward_hooks.move_to_end(handle.id, last=False)

 def get_unique_devices_(module):
     return {p.device for p in module.parameters()} | \
@@ -393,6 +396,14 @@ def swap_module(mod, mapping):
             )
             device = next(iter(devices)) if len(devices) > 0 else None
             new_mod = mapping[type(mod)].from_float(mod)
+            # Preserve module's pre forward hooks. They'll be called on quantized input
+            for pre_hook_fn in mod._forward_pre_hooks.values():
+                new_mod.register_forward_pre_hook(pre_hook_fn)
+            # Preserve module's post forward hooks except _observer_forward_hook
+            # After convert they'll work with quantized output
+            for hook_fn in mod._forward_hooks.values():
+                if hook_fn is not _observer_forward_hook:
+                    new_mod.register_forward_hook(hook_fn)
             if device:
                 new_mod.to(device)
     return new_mod
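
The move_to_end(handle.id, last=False) call works because nn.Module keeps forward hooks in an OrderedDict keyed by handle id. A standalone illustration of the reordering follows; the layer and hook names are arbitrary, only the mechanism matches the patch.

import torch.nn as nn

fc = nn.Linear(4, 4)

def user_hook(module, input, output):
    pass

def observer_like_hook(module, input, output):
    pass

fc.register_forward_hook(user_hook)                    # user hook registered first
handle = fc.register_forward_hook(observer_like_hook)  # observer-style hook registered later
# Promote the later registration to the front of the OrderedDict,
# so the observer-style hook runs before any user hook during calibration.
fc._forward_hooks.move_to_end(handle.id, last=False)
assert list(fc._forward_hooks.values()) == [observer_like_hook, user_hook]

This ordering is what lets the static test above assert that fw_hook is the last entry of model.fc._forward_hooks before convert, and it lets swap_module skip _observer_forward_hook by identity when copying the remaining hooks onto the quantized module.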
