35 changes: 31 additions & 4 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
@@ -17,8 +17,9 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte(
     bool sparse,
     const c10::optional<at::Tensor>& per_sample_weights_,
     bool include_last_offset) {
-
-  TORCH_CHECK(offsets_in.has_value(), "embedding_bag_byte_rowwise_offsets expects offsets to be set");
+  TORCH_CHECK(
+      offsets_in.has_value(),
+      "embedding_bag_byte_rowwise_offsets expects offsets to be set");
   auto offsets = offsets_in.value();
   auto offsets_data = offsets.data_ptr<int64_t>();
   const auto indices_data = indices.data_ptr<int64_t>();
@@ -123,7 +124,9 @@ Tensor embedding_bag_byte_rowwise_offsets(
     bool include_last_offset) {
   TORCH_CHECK(weight.scalar_type() == at::kByte);
   TORCH_CHECK(weight.ndimension() == 2);
-  TORCH_CHECK(offsets_in.has_value(), "embedding_bag_byte_rowwise_offsets expects offsets to be set");
+  TORCH_CHECK(
+      offsets_in.has_value(),
+      "embedding_bag_byte_rowwise_offsets expects offsets to be set");
 
   auto offsets = offsets_in.value();
   auto offsets_data = offsets.data_ptr<int64_t>();
@@ -221,7 +224,9 @@ Tensor embedding_bag_4bit_rowwise_offsets(
     const c10::optional<Tensor>& per_sample_weights_,
     const c10::optional<Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
-  TORCH_CHECK(offsets_in.has_value(), "embedding_bag_4bit_rowwise_offsets expects offsets to be set");
+  TORCH_CHECK(
+      offsets_in.has_value(),
+      "embedding_bag_4bit_rowwise_offsets expects offsets to be set");
 
   TORCH_CHECK(weight.ndimension() == 2);
   TORCH_CHECK(indices.ndimension() == 1);
@@ -423,9 +428,31 @@ class QEmbeddingBag final {
   }
 };
 
+template <int bit_rate>
+class QEmbedding final {
+ public:
+  static at::Tensor run(
+      const c10::intrusive_ptr<EmbeddingPackedParamsBase>& packed_weight,
+      const Tensor& indices,
+      bool sparse) {
+    const auto offsets_size = indices.numel();
+    at::Tensor offsets = at::arange(0, offsets_size, at::kLong);
+    at::Tensor output;
+    if (bit_rate == 8) {
+      return packed_weight->embeddingbag_byte(
+          indices, offsets, sparse, c10::nullopt, false);
+    } else {
+      TORCH_INTERNAL_ASSERT(
+          false, "Currently only support 8-bit embedding quantization");
+    }
+    return output;
+  }
+};

 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   // Function that works on TorchBind packed weights.
   m.impl("embedding_bag_byte", TORCH_FN(QEmbeddingBag<8>::run));
+  m.impl("embedding_byte", TORCH_FN(QEmbedding<8>::run));
 
   // Functions that work on at::Tensor packed weight.
   m.impl(
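The QEmbedding wrapper added above turns a plain per-index lookup into an embedding_bag call by generating one offset per index. A minimal Python sketch of that equivalence, assuming a packed weight produced by the existing torch.ops.quantized.embedding_bag_prepack op (as the tests below do):

# Rough illustration, not part of this PR: emulate QEmbedding<8>::run from Python.
import torch

def embedding_byte_via_bag(packed_weight, indices):
    # One bag per index: offsets = [0, 1, ..., N-1], so the (sum-mode) reduction
    # over each single-element bag is just the corresponding embedding row.
    offsets = torch.arange(indices.numel(), dtype=torch.long)
    return torch.ops.quantized.embedding_bag_byte(packed_weight, indices, offsets)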
1 change: 1 addition & 0 deletions aten/src/ATen/native/quantized/library.cpp
@@ -110,6 +110,7 @@ TORCH_LIBRARY(quantized, m) {
   m.def("embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> Tensor");
   m.def("embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor");
   m.def("embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor");
+  m.def("embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool sparse=False) -> Tensor");
   m.def("celu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1) -> Tensor");
   m.def("hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor");
   m.def("group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor");
36 changes: 33 additions & 3 deletions test/quantization/test_quantized_op.py
@@ -22,6 +22,7 @@
 from torch.testing._internal.common_quantization import skipIfNoFBGEMM
 from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
     override_quantized_engine, supported_qengines, override_qengines
+from torch.quantization import PerChannelMinMaxObserver
 
 np_dtype = {
     torch.quint8 : np.uint8,
@@ -2716,7 +2717,7 @@ def test_qlinear_unpack(self, W, use_channelwise):
 
 
 @unittest.skipIf(sys.platform == "darwin", "Known test failure on Mac.")
-class TestQuantizedEmbeddingBag(TestCase):
+class TestQuantizedEmbeddingOps(TestCase):
     def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate):
         weights = torch.from_numpy((np.random.random_sample((
             num_embeddings, embedding_dim)) + 1).astype(np.float32))
@@ -2727,7 +2728,6 @@ def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate):
 
         if bit_rate == 8:
             # Check numerics of prepack function that accepts qtensor as input.
             # We use min-max observer to mimic the quantization performed in the original function.
-            from torch.quantization import PerChannelMinMaxObserver
             obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
             obs(weights)
             # Get the scale and zero point for the weight tensor
@@ -2884,7 +2884,6 @@ def get_reference_result(
 
         if bit_rate == 8:
             # Test operator that accepts TorchBind packed weights.
-            from torch.quantization import PerChannelMinMaxObserver
             obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
             obs(weights)
             # Get the scale and zero point for the weight tensor
@@ -2931,6 +2930,37 @@ def test_embedding_bag_4bit_rowwise_offsets(self, num_embeddings,
                                           include_last_offset, atol=0.1,
                                           rtol=1e-2)

""" Tests the correctness of the quantized embedding lookup operator """
@given(num_embeddings=st.integers(10, 100),
embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0))
def test_embedding_byte(self, num_embeddings, embedding_dim):
quant_op = torch.ops.quantized.embedding_byte
prepack_op = torch.ops.quantized.embedding_bag_prepack

weights = torch.from_numpy((np.random.random_sample((
num_embeddings, embedding_dim)) + 1).astype(np.float32))

obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
obs(weights)
# Get the scale and zero point for the weight tensor
qparams = obs.calculate_qparams()

# Quantize the weights to 8bits
qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
max_segments = 5
max_segment_length = 20
num_lengths = np.random.randint(1, max_segments + 1)
lengths = np.random.randint(1, max_segment_length + 1,
size=num_lengths).astype(np.int32)
num_indices = np.sum(lengths)
indices = torch.from_numpy(np.random.randint(
low=0, high=num_embeddings, size=num_indices, dtype=np.int64))

packed_weight = prepack_op(qweight)
qresult = quant_op(packed_weight, indices, sparse=False)

ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False)
Review comment (Contributor): If you use qweight.dequantize(), this should be an exact match.
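A minimal sketch of that suggestion, reusing the qweight, indices, and qresult defined in test_embedding_byte above:

# Sketch of the reviewer's suggestion, not part of this PR: compare against the
# dequantized weights so the lookup no longer needs loosened tolerances.
dequant_weights = qweight.dequantize()
ref_exact = torch.embedding(dequant_weights, indices)
torch.testing.assert_allclose(ref_exact, qresult)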

+        torch.testing.assert_allclose(ref, qresult, atol=0.005, rtol=1e-3)
Review comment (Contributor): Are the modified atol and rtol values only because of quantization error?

Author reply (Contributor): Yes.
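A back-of-the-envelope check of that answer, assuming the value ranges used in test_embedding_byte above (weights drawn from [1, 2), 8-bit per-channel affine quantization):

# Per-channel scale is at most (max - min) / 255 < 1.0 / 255 ~= 0.0039, so the
# worst-case rounding error per element is roughly scale / 2 ~= 0.002, which
# sits comfortably inside atol=0.005.
max_scale = (2.0 - 1.0) / 255
worst_case_error = max_scale / 2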


 class TestQuantizedConv(unittest.TestCase):
     def _test_qconv_unpack_impl(self, qconv_prepack_fn, qconv_unpack_fn, inputs,
2 changes: 1 addition & 1 deletion test/test_quantization.py
@@ -13,7 +13,7 @@
 from quantization.test_quantized_op import TestDynamicQuantizedLinear  # noqa: F401
 from quantization.test_quantized_op import TestComparatorOps  # noqa: F401
 from quantization.test_quantized_op import TestPadding  # noqa: F401
-from quantization.test_quantized_op import TestQuantizedEmbeddingBag  # noqa: F401
+from quantization.test_quantized_op import TestQuantizedEmbeddingOps  # noqa: F401
 
 # Quantized Functional
 from quantization.test_quantized_functional import TestQuantizedFunctional  # noqa: F401