Commit fa6b34b

Radhakrishnan Venkataramani authored and facebook-github-bot committed
2 Bit Embedding Conversion Operator support. (#43077)
Summary:
Pull Request resolved: #43077

2-bit embedding weight conversion is quite similar to 4-bit embedding weight conversion. This diff contains both:
1. The 2-bit packing op `embedding_bag_2bit_prepack`.
2. The 2-bit unpacking op `embedding_bag_2bit_unpack`.

Comments about each op are inline with its definition.

Test Plan: buck test caffe2/test:quantization -- test_embedding_bag_2bit_unpack

Reviewed By: supriyar

Differential Revision: D23143262

fbshipit-source-id: fd8877f049ac1f7eb4bc580e588dc95f8b1edef0
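For reference, a minimal usage sketch of the two new ops (assuming a PyTorch build that includes this commit; the table size below is illustrative):

import torch

# Illustrative float embedding table: 10 rows, 16 columns. For 2-bit packing
# the number of columns must be a multiple of 4 (i.e. 8 / BIT_RATE).
weight = torch.rand(10, 16, dtype=torch.float32)

# Pack: each row stores 2-bit quantized values followed by a 2-byte fp16 scale
# and a 2-byte zero_offset.
packed = torch.ops.quantized.embedding_bag_2bit_prepack(weight)

# Unpack: de-quantizes back to float; values only approximate the originals.
unpacked = torch.ops.quantized.embedding_bag_2bit_unpack(packed)
print(weight.shape, packed.shape, unpacked.shape)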
1 parent ab366d0 commit fa6b34b

File tree

4 files changed: +80 -10 lines changed


aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp

Lines changed: 35 additions & 5 deletions
@@ -137,19 +137,24 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
   return output;
 }
 
-Tensor qembeddingbag_4bit_prepack(const Tensor& weight) {
+Tensor _qembeddingbag_nbit_prepack_helper(const Tensor& weight, int BIT_RATE) {
   int64_t embedding_rows = weight.size(0);
   int64_t embedding_cols = weight.size(1);
 
   Tensor weight_contig = weight.contiguous(weight.suggest_memory_format());
 
   const auto weight_data = weight.data_ptr<float>();
-  constexpr int BIT_RATE = 4;
-  constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;
+  TORCH_CHECK(
+      BIT_RATE == 4 || BIT_RATE == 2,
+      "BIT_RATE must be either 2 or 4 to use 'qembeddingbag_nbit_prepack'."
+      "For 8bit, consider using 'embedding_bag_byte_prepack'.");
+
+  int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;
   TORCH_CHECK(
       weight_contig.size(weight.dim() - 1) % NUM_ELEM_PER_BYTE == 0,
-      "FloatToFused4BitRowwiseQuantizedOp only works for the number of "
-      "columns a multiple of 2");
+      "qembeddingbag_" + c10::to_string(BIT_RATE) +
+          "bit_prepack only works for the number of columns a multiple of "
+          + c10::to_string(NUM_ELEM_PER_BYTE));
 
   // The "fused" representation stores the scale and bias with the
   // row-wise quantized data in one tensor.
@@ -219,6 +224,29 @@ Tensor qembeddingbag_4bit_prepack(const Tensor& weight) {
   return output;
 }
 
+// Applies 4-bit row-wise quantization by determining the range
+// (maximum - minimum) and bias (minimum value) of each row in the input
+// matrix, and then scaling each element to a 4-bit number between 0 and
+// 15.
+// To later de-quantize values, the scale (range / 15) and zero_point
+// are stored alongside the data. More precisely, each row first has quantized
+// values, and then 2-byte fp16 scale and 2-byte zero_offset.
+Tensor qembeddingbag_4bit_prepack(const Tensor& weight) {
+  return _qembeddingbag_nbit_prepack_helper(weight, 4 /*BIT_RATE*/);
+}
+
+// Applies 2-bit row-wise quantization by determining the range
+// (maximum - minimum) and bias (minimum value) of each row in the input
+// matrix, and then scaling each element to a 2-bit number between 0 and
+// 3.
+// To later de-quantize values, the scale (range / 3) and zero_point
+// are stored alongside the data. More precisely, each row first has quantized
+// values, and then 2-byte fp16 scale and 2-byte zero_offset.
+// TODO() - Add 2Bit Embedding Lookup operator.
+Tensor qembeddingbag_2bit_prepack(const Tensor& weight) {
+  return _qembeddingbag_nbit_prepack_helper(weight, 2 /*BIT_RATE*/);
+}
+
 class QEmbeddingPackWeights final {
  public:
   static c10::intrusive_ptr<EmbeddingPackedParamsBase> run(at::Tensor weight) {
@@ -229,7 +257,9 @@ class QEmbeddingPackWeights final {
 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   m.impl("embedding_bag_byte_prepack", qembeddingbag_byte_prepack);
   m.impl("embedding_bag_4bit_prepack", qembeddingbag_4bit_prepack);
+  m.impl("embedding_bag_2bit_prepack", qembeddingbag_2bit_prepack);
 }
+
 TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl("embedding_bag_prepack", TORCH_FN(QEmbeddingPackWeights::run));
 }
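As a sanity check on the fused layout described in the comments above (quantized values first, then a 2-byte fp16 scale and a 2-byte zero_offset per row), the packed width can be predicted directly. This is a sketch, assuming the 2-bit op follows the same output-shape convention as the existing 4-bit op; the dimensions are illustrative:

import torch

def expected_packed_cols(embedding_cols, bit_rate):
    # Each byte holds 8 / bit_rate quantized values; every row is followed by
    # 2 bytes of fp16 scale and 2 bytes of zero_offset.
    num_elem_per_byte = 8 // bit_rate
    return embedding_cols // num_elem_per_byte + 4

weight = torch.rand(3, 8, dtype=torch.float32)  # 8 columns, a multiple of 4
packed = torch.ops.quantized.embedding_bag_2bit_prepack(weight)
assert packed.shape == (3, expected_packed_cols(8, bit_rate=2))  # (3, 6)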

aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp

Lines changed: 27 additions & 3 deletions
@@ -87,12 +87,11 @@ Tensor qembeddingbag_byte_unpack(const Tensor& packed_weight) {
   return output;
 }
 
-Tensor qembeddingbag_4bit_unpack(const Tensor& packed_weight) {
+Tensor _qembeddingbag_nbit_unpack_helper(const Tensor& packed_weight, int BIT_RATE) {
   const auto input_rows = packed_weight.size(0);
   const auto input_columns = packed_weight.size(1);
   const auto* input_data = packed_weight.data_ptr<uint8_t>();
-  constexpr int NUM_ELEM_PER_BYTE = 2;
-  constexpr int BIT_RATE = 4;
+  int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;
 
   // The last 4 bytes per row are two fp16 scale and zero_point.
   // The rest of input_columns is the number of values in the original row.
@@ -126,6 +125,30 @@ Tensor qembeddingbag_4bit_unpack(const Tensor& packed_weight) {
   return output;
 }
 
+// De-quantizes the result of the qembeddingbag_4bit_prepack operator.
+// The input is expected to first have quantized values,
+// then 2-byte fp16 scale and 2-byte zero_offset.
+// The output is a matrix containing only the values, but de-quantized.
+// De-quantization is performed using each row's scale and
+// zero_point parameters. The de-quantized values
+// will thus not be exactly equal to the original, un-quantized
+// floating point values.
+Tensor qembeddingbag_4bit_unpack(const Tensor& packed_weight) {
+  return _qembeddingbag_nbit_unpack_helper(packed_weight, 4 /*BIT_RATE*/);
+}
+
+// De-quantizes the result of the qembeddingbag_2bit_prepack operator.
+// The input is expected to first have quantized values,
+// then 2-byte fp16 scale and 2-byte zero_offset.
+// The output is a matrix containing only the values, but de-quantized.
+// De-quantization is performed using each row's scale and
+// zero_point parameters. The de-quantized values
+// will thus not be exactly equal to the original, un-quantized
+// floating point values.
+Tensor qembeddingbag_2bit_unpack(const Tensor& packed_weight) {
+  return _qembeddingbag_nbit_unpack_helper(packed_weight, 2 /*BIT_RATE*/);
+}
+
 class QEmbeddingUnpackWeights final {
  public:
   static at::Tensor run(
@@ -137,6 +160,7 @@ class QEmbeddingUnpackWeights final {
 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   m.impl("embedding_bag_byte_unpack", qembeddingbag_byte_unpack);
   m.impl("embedding_bag_4bit_unpack", qembeddingbag_4bit_unpack);
+  m.impl("embedding_bag_2bit_unpack", qembeddingbag_2bit_unpack);
 }
 
 TORCH_LIBRARY_IMPL(quantized, CatchAll, m) {
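The comments above stress that de-quantized values only approximate the original floats. Below is a quick round-trip sketch; the expectation that the error is on the order of the per-row range divided by 3 follows from the stated scale = range / 3, not from the kernel itself:

import torch

torch.manual_seed(0)
weight = torch.rand(5, 12, dtype=torch.float32)

packed = torch.ops.quantized.embedding_bag_2bit_prepack(weight)
unpacked = torch.ops.quantized.embedding_bag_2bit_unpack(packed)

# With 2-bit quantization each row is mapped onto 4 levels, so the per-element
# round-trip error should be roughly bounded by the row range / 3 (plus a small
# contribution from storing the scale and zero_point in fp16).
row_range = weight.max(dim=1).values - weight.min(dim=1).values
max_err = (unpacked - weight).abs().max()
print("max round-trip error:", max_err.item(), "max row range / 3:", (row_range.max() / 3).item())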

aten/src/ATen/native/quantized/library.cpp

Lines changed: 2 additions & 0 deletions
@@ -95,6 +95,8 @@ TORCH_LIBRARY(quantized, m) {
   m.def("embedding_bag_byte_unpack(Tensor weight) -> Tensor");
   m.def("embedding_bag_4bit_prepack(Tensor weight) -> Tensor");
   m.def("embedding_bag_4bit_unpack(Tensor weight) -> Tensor");
+  m.def("embedding_bag_2bit_prepack(Tensor weight) -> Tensor");
+  m.def("embedding_bag_2bit_unpack(Tensor weight) -> Tensor");
   m.def("embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> Tensor");
   m.def("embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor");
   m.def("embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor");

test/quantization/test_quantized_op.py

Lines changed: 16 additions & 2 deletions
@@ -2743,8 +2743,13 @@ def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embe
         # compare against C2 to ensure numerical equivalency.
         from caffe2.python import core, workspace
         conversion_op = "FloatToFused8BitRowwiseQuantized"
+        reverse_conversion_op = None
         if bit_rate == 4:
             conversion_op = "FloatToFused4BitRowwiseQuantized"
+            reverse_conversion_op = "Fused4BitRowwiseQuantizedToFloat"
+        elif bit_rate == 2:
+            conversion_op = "FloatToFused2BitRowwiseQuantized"
+            reverse_conversion_op = "Fused2BitRowwiseQuantizedToFloat"
 
         def get_c2_weights(weights):
             workspace.ResetWorkspace()
@@ -2756,10 +2761,10 @@ def get_c2_weights(weights):
                 )
             )
             emb_q = workspace.FetchBlob("quantized_weights")
-            if bit_rate == 4:
+            if bit_rate == 4 or bit_rate == 2:
                 workspace.RunOperatorOnce(
                     core.CreateOperator(
-                        "Fused4BitRowwiseQuantizedToFloat", ["quantized_weights"], ["dequantized_weights"]
+                        reverse_conversion_op, ["quantized_weights"], ["dequantized_weights"]
                     )
                 )
             dequantized_data = torch.from_numpy(workspace.FetchBlob("dequantized_weights"))
@@ -2794,6 +2799,15 @@ def test_embedding_bag_4bit_unpack(self, num_embeddings, embedding_dim):
 
         self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate=4)
 
+    """ Tests the correctness of the embedding_bag_2bit pack/unpack op against C2 """
+    @given(num_embeddings=st.integers(10, 100),
+           embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0),)
+    def test_embedding_bag_2bit_unpack(self, num_embeddings, embedding_dim):
+        pack_fn = torch.ops.quantized.embedding_bag_2bit_prepack
+        unpack_fn = torch.ops.quantized.embedding_bag_2bit_unpack
+
+        self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate=2)
+
     def embedding_bag_rowwise_offsets_run(
         self, bit_rate, num_embeddings,
         embedding_dim, num_offsets, enable_per_sample_weights,
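For readers who want to reproduce the C2 comparison outside the test harness, here is a condensed sketch of what `_test_embedding_bag_unpack_fn` does for `bit_rate=2`. It assumes a Caffe2-enabled build; the operator names come from the diff above, while the blob names, table size, and comparison are illustrative rather than the exact checks in the test:

import numpy as np
import torch
from caffe2.python import core, workspace

weights = np.random.uniform(size=(10, 16)).astype(np.float32)

# Reference packing via the Caffe2 operator used by the test.
workspace.ResetWorkspace()
workspace.FeedBlob("weights", weights)
workspace.RunOperatorOnce(
    core.CreateOperator(
        "FloatToFused2BitRowwiseQuantized", ["weights"], ["quantized_weights"]
    )
)
c2_packed = workspace.FetchBlob("quantized_weights")

# Packing via the new PyTorch op.
pt_packed = torch.ops.quantized.embedding_bag_2bit_prepack(torch.from_numpy(weights))

# The unit test checks numerical equivalency between the two paths
# (it also compares the de-quantized floats).
np.testing.assert_allclose(pt_packed.numpy(), c2_packed)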
