Skip to content

Commit 164b96c

Browse files
supriyar authored and facebook-github-bot committed
[quant][pyper] make embedding_bag quantization static (#44008)
Summary: Pull Request resolved: #44008. embedding_bag requires only quantization of weights (no dynamic quantization of inputs), so the type of quantization is essentially static (without calibration). This will enable pyper to do fc and embedding_bag quantization using the same API call. Test Plan: python test/test_quantization.py test_embedding_bag Imported from OSS Reviewed By: vkuzo Differential Revision: D23467019 fbshipit-source-id: 41a61a17ee34bcb737ba5b4e19fb7a576d4aeaf9
1 parent a0ae416 commit 164b96c

File tree

5 files changed

+30
-12
lines changed

5 files changed

+30
-12
lines changed

test/quantization/test_quantize_jit.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
fuse_modules,
2424
quantize_jit,
2525
quantize_dynamic_jit,
26+
PlaceholderObserver,
2627
)
2728

2829
# torch.quantization.quantize_jit
@@ -2947,14 +2948,14 @@ def forward(self, indices1, offsets1, indices2, offsets2):
29472948
m = torch.jit.trace(module, dummy_inputs)
29482949
else:
29492950
m = torch.jit.script(module)
2950-
from torch.quantization import QConfigDynamic, PlaceholderObserver
2951-
int4_dynamic_qconfig = QConfigDynamic(activation=PlaceholderObserver.with_args(dtype=torch.float,
2952-
custom_op_name="embedding_bag_4bit"),
2953-
weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_4bit"))
2954-
int8_dynamic_qconfig = QConfigDynamic(activation=PlaceholderObserver.with_args(dtype=torch.float,
2955-
custom_op_name="embedding_bag_byte"),
2956-
weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_byte"))
2957-
m = quantize_dynamic_jit(m, {'embedding1' : int4_dynamic_qconfig, 'embedding2' : int8_dynamic_qconfig})
2951+
int4_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float,
2952+
custom_op_name="embedding_bag_4bit"),
2953+
weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_4bit"))
2954+
int8_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float,
2955+
custom_op_name="embedding_bag_byte"),
2956+
weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_byte"))
2957+
m = prepare_jit(m, {'embedding1' : int4_qconfig, 'embedding2' : int8_qconfig})
2958+
m = convert_jit(m)
29582959
FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets") \
29592960
.check_next("quantized::embedding_bag_byte_rowwise_offsets") \
29602961
.run(m.graph)

torch/csrc/jit/passes/quantization/helper.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ std::vector<std::string> _static_quantizable_call_funcs = {
2525
"layer_norm",
2626
"group_norm",
2727
"instance_norm",
28+
"embedding_bag",
2829
};
2930

3031
std::vector<std::string> _static_quantizable_aten_funcs = {
@@ -42,15 +43,21 @@ std::vector<std::string> _static_quantizable_aten_funcs = {
4243
"layer_norm",
4344
"group_norm",
4445
"instance_norm",
46+
"embedding_bag",
4547
};
4648

4749
std::vector<std::string> _dynamic_quantizable_call_funcs = {
4850
"linear",
49-
"embedding_bag",
5051
};
5152

5253
std::vector<std::string> _dynamic_quantizable_aten_funcs = {
5354
"linear",
55+
};
56+
57+
std::vector<std::string> _static_weight_only_quant_aten_funcs = {
58+
"embedding_bag",
59+
};
60+
std::vector<std::string> _static_weight_only_quant_call_funcs = {
5461
"embedding_bag",
5562
};
5663

@@ -469,6 +476,13 @@ bool userDefinedCallFunction(Node* n) {
469476
!isFunctionNode(n, _static_quantizable_call_funcs, {});
470477
}
471478

479+
bool isWeightOnlyStaticQuantOp(Node* n) {
480+
return isFunctionNode(
481+
n,
482+
_static_weight_only_quant_call_funcs,
483+
_static_weight_only_quant_aten_funcs);
484+
}
485+
472486
bool nodeQuantizable(Node* n, QuantType quant_type) {
473487
bool is_dynamic = quant_type == QuantType::DYNAMIC;
474488
return isFunctionNode(

torch/csrc/jit/passes/quantization/helper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ TORCH_API bool nodeQuantizable(
100100
Node* n,
101101
QuantType quant_type = QuantType::STATIC);
102102

103+
// Nodes which only require quantization of weight value, eg. embedding_bag
104+
bool isWeightOnlyStaticQuantOp(Node* n);
105+
103106
// Check if a use of the value is quantizable, this depends on
104107
// both the use node and the offset
105108
TORCH_API bool useQuantizable(const Use& use, QuantType quant_type);

torch/csrc/jit/passes/quantization/insert_observers.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1170,7 +1170,8 @@ bool InsertObserversHelper::valueNeedsToBeQuantized(
11701170
// of the quantizable function.
11711171
if (quant_type_ == QuantType::STATIC) {
11721172
// Check whether producer is quantizable
1173-
if (nodeQuantizable(v->node()) || isPropagateQuantOp(v->node())) {
1173+
if (!isWeightOnlyStaticQuantOp(v->node()) &&
1174+
(nodeQuantizable(v->node()) || isPropagateQuantOp(v->node()))) {
11741175
return true;
11751176
}
11761177
}

torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,7 @@ void insertQuantizationOps(
403403
// Temporary solution to quantize embedding_bag operators. Will be re-written
404404
// once we support quantization of embedding_bag weights.
405405
auto embedding_bag_name = getEmbeddingBagObsName(module, observer);
406-
if (quant_type == QuantType::DYNAMIC &&
407-
isEmbeddingBagOp(observer, embedding_bag_name)) {
406+
if (isEmbeddingBagOp(observer, embedding_bag_name)) {
408407
if (isWeight(module, observer_out)) {
409408
auto op_name = embedding_bag_name.value();
410409
Node* dequant = insertEmbeddingBagOps(observer, op_name);

0 commit comments

Comments (0)