
Commit abda069

b-koopman authored and facebook-github-bot committed
[quant][embedding qat] Support Embedding QAT via FX API (#68296)
Summary:
Pull Request resolved: #68296

Support the QAT workflow via the torch.fx QAT API, e.g. `prepare_qat_fx` and `convert_fx`.

Test Plan: `pytest test/quantization/fx/test_quantize_fx.py -v -k "test_qat_embedding_linear"`

Imported from OSS

Reviewed By: jingsh, supriyar

Differential Revision: D32404517

fbshipit-source-id: 0484df8c826b823b60dfecd9def77bf8cffe0527
1 parent 3157371 commit abda069
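
For context, the end-to-end workflow the summary refers to looks roughly like the sketch below, adapted from the test this commit adds. It assumes the qconfig_dict-based prepare_qat_fx/convert_fx signatures of this era's torch.ao.quantization API (later releases changed these interfaces), so treat it as illustrative rather than canonical.

import torch
import torch.nn.qat as nnqat
import torch.nn.quantized as nnq
from torch.ao.quantization import get_default_qat_qconfig, default_embedding_qat_qconfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

class EmbeddingLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
        self.linear = torch.nn.Linear(12, 1)

    def forward(self, indices):
        return self.linear(torch.sum(self.emb(indices), dim=1))

# Global QAT qconfig plus an embedding-specific override.
qconfig_dict = {
    "": get_default_qat_qconfig("qnnpack"),
    "object_type": [(torch.nn.Embedding, default_embedding_qat_qconfig)],
}

prepared = prepare_qat_fx(EmbeddingLinear().train(), qconfig_dict)
# ... run training steps on `prepared` here ...

# Map the QAT embedding onto its quantized counterpart at convert time,
# as the new test does.
convert_custom_config_dict = {
    "additional_object_mapping": {"static": {nnqat.Embedding: nnq.Embedding}},
}
quantized = convert_fx(prepared,
                       convert_custom_config_dict=convert_custom_config_dict,
                       qconfig_dict=qconfig_dict)
out = quantized(torch.randint(0, 10, (12, 12)))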

File tree

3 files changed: +49 −2


test/quantization/fx/test_quantize_fx.py

Lines changed: 46 additions & 0 deletions
@@ -62,6 +62,7 @@
     MovingAverageMinMaxObserver,
     HistogramObserver,
     QConfig,
+    default_embedding_qat_qconfig,
 )
 
 from torch.ao.quantization.fx.pattern_utils import (
@@ -5745,6 +5746,51 @@ def test_resnet18_ddp(self):
         self._test_model_impl(
             'ddp', 'resnet18', model, eager_quantizable_model)
 
+    def test_qat_embedding_linear(self):
+        for device in get_supported_device_types():
+            class EmbeddingLinear(torch.nn.Module):
+                def __init__(self):
+                    super(EmbeddingLinear, self).__init__()
+                    self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
+                    self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float)
+
+                def forward(self, input: torch.Tensor):
+                    x = torch.sum(self.emb(input), dim=1)
+                    x = self.linear(x)
+                    return x
+
+            qconfig_dict = {"": get_default_qat_qconfig("qnnpack"),
+                            "object_type": [(torch.nn.Embedding, default_embedding_qat_qconfig)]}
+
+
+            train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)]
+            eval_output = [[torch.randint(0, 10, (12, 1))]]
+
+            model = EmbeddingLinear().train()
+            prepared_fx_model = prepare_qat_fx(model, qconfig_dict)
+            test_only_train_fn(prepared_fx_model, train_indices)
+            convert_custom_config_dict = {
+                "additional_object_mapping": {
+                    "static": {
+                        torch.nn.qat.Embedding: nn.quantized.Embedding,
+                    }
+                }
+            }
+            quant_model = convert_fx(prepared_fx_model,
+                                     convert_custom_config_dict=convert_custom_config_dict,
+                                     qconfig_dict=qconfig_dict)
+
+            def checkQuantized(model):
+                # Make sure Embedding is now a quantized Embedding.
+                self.assertEqual(type(model.emb), nn.quantized.Embedding)
+                # Also test that Linear has been quantized.
+                self.assertEqual(type(model.linear), nnq.Linear)
+
+                test_only_eval_fn(model, eval_output)
+                self.checkScriptable(model, eval_output)
+                self.checkNoQconfig(model)
+            checkQuantized(quant_model)
+
     @given(
         device=st.sampled_from(
             ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]

torch/ao/quantization/fx/quantization_patterns.py

Lines changed: 1 addition & 0 deletions
@@ -1161,6 +1161,7 @@ def convert(self,
                 load_arg(quantized=[0])(self.bn_node.args),
                 load_arg(quantized=torch.float)(self.bn_node.kwargs))
 
+@register_quant_pattern(torch.nn.qat.Embedding)
 @register_quant_pattern(torch.nn.Embedding)
 @register_quant_pattern(torch.nn.EmbeddingBag)
 class EmbeddingQuantizeHandler(QuantizeHandler):
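
The register_quant_pattern decorator used above maps a pattern (here, a module class) to the QuantizeHandler that convert_fx dispatches to when it meets that pattern in the traced graph; registering torch.nn.qat.Embedding is what lets prepared QAT embeddings reach EmbeddingQuantizeHandler. Below is a toy sketch of the decorator-registry idiom, not the actual pattern_utils implementation:

import torch
import torch.nn.qat  # ensure the qat namespace is loaded
from typing import Any, Callable, Dict, Type

# Toy registry in the spirit of pattern_utils (not the real code):
QUANTIZATION_PATTERNS: Dict[Any, Type] = {}

def register_quant_pattern(pattern: Any) -> Callable[[Type], Type]:
    """Record a pattern -> handler mapping when the class is defined."""
    def insert(handler_cls: Type) -> Type:
        QUANTIZATION_PATTERNS[pattern] = handler_cls
        return handler_cls
    return insert

@register_quant_pattern(torch.nn.qat.Embedding)
@register_quant_pattern(torch.nn.Embedding)
class ToyEmbeddingHandler:
    """Stand-in for EmbeddingQuantizeHandler in this sketch."""

assert QUANTIZATION_PATTERNS[torch.nn.qat.Embedding] is ToyEmbeddingHandler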

torch/ao/quantization/qconfig.py

Lines changed: 2 additions & 2 deletions
@@ -192,10 +192,10 @@ def get_default_qconfig(backend='fbgemm'):
         qconfig = default_qconfig
     return qconfig
 
-default_embedding_qat_qconfig = QConfig(activation=NoopObserver,
+default_embedding_qat_qconfig = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
                                         weight=default_embedding_fake_quant)
 
-default_embedding_qat_qconfig_4bit = QConfig(activation=NoopObserver,
+default_embedding_qat_qconfig_4bit = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
                                              weight=default_embedding_fake_quant_4bit)
 
 def get_default_qat_qconfig(backend='fbgemm', version=1):
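
This fix matters because QConfig fields are observer factories that the framework instantiates later; passing the bare NoopObserver class left the activation dtype at the class default, whereas with_args pins it to torch.float32 so embedding activations stay in floating point during QAT. A small sketch under those assumptions (import paths follow this era's module layout):

import torch
from torch.ao.quantization.observer import NoopObserver
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.fake_quantize import default_embedding_fake_quant

# QConfig holds observer *factories*: the framework later calls
# qconfig.activation() / qconfig.weight() to instantiate them, and
# .with_args(...) returns a factory with those ctor args baked in.
qconfig = QConfig(
    activation=NoopObserver.with_args(dtype=torch.float32),  # activations stay float
    weight=default_embedding_fake_quant,                     # only weights are fake-quantized
)

act = qconfig.activation()          # instantiate the activation observer
assert act.dtype == torch.float32   # previously: whatever NoopObserver defaults to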
