Skip to content

Commit 8d7242a

Browse files
salilsdesai authored and pytorchmergebot committed
[PyTorch Edge] Add Quantized Softmax Op (Naive Implementation) (#75017)
Summary: Pull Request resolved: #75017 This version just does dequantize, fp32 softmax, quantize. Another version of actual quantized softmax using qnnpack will be added next Test Plan: From fbcode: ```buck test caffe2/test:quantization -- test_qsoftmax``` Benchmarking: See summary of D34996486 Reviewed By: kimishpatel Differential Revision: D34943147 fbshipit-source-id: 426a0780803597a21460139c67960891d6e9cc81 (cherry picked from commit 524eede)
1 parent 8b8f3e8 commit 8d7242a

File tree

4 files changed

+63
-0
lines changed

4 files changed

+63
-0
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include <ATen/ATen.h>
#include <torch/library.h>

namespace at {
namespace native {

namespace {

// Naive quantized softmax kernel: dequantize the input, run the regular
// fp32 softmax, then re-quantize the result with the caller-supplied
// output scale / zero point. The output tensor keeps the input's
// quantized scalar type (qx.scalar_type()).
//
// qx                - quantized input tensor
// dim               - dimension along which softmax is computed
// output_scale      - quantization scale for the returned tensor
// output_zero_point - quantization zero point for the returned tensor
Tensor qsoftmax(
    const Tensor& qx,
    const int64_t dim,
    const double output_scale,
    const int64_t output_zero_point) {
  // Compute in fp32, then re-quantize with the requested parameters.
  const Tensor result = at::softmax(at::dequantize(qx), dim);
  return at::quantize_per_tensor(
      result, output_scale, output_zero_point, qx.scalar_type());
}

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::softmax"), TORCH_FN(qsoftmax));
}

} // namespace

} // namespace native
} // namespace at

aten/src/ATen/native/quantized/library.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ TORCH_LIBRARY(quantized, m) {
188188
m.def(TORCH_SELECTIVE_SCHEMA("quantized::relu6(Tensor qx, bool inplace=False) -> Tensor"));
189189
m.def(TORCH_SELECTIVE_SCHEMA("quantized::leaky_relu(Tensor qx, Scalar negative_slope, bool inplace, float output_scale, int output_zero_point) -> Tensor"));
190190
m.def(TORCH_SELECTIVE_SCHEMA("quantized::sigmoid(Tensor qx, float output_scale, int output_zero_point) -> Tensor"));
191+
m.def(TORCH_SELECTIVE_SCHEMA("quantized::softmax(Tensor qx, int dim, float output_scale, int output_zero_point) -> Tensor"));
191192
}
192193

193194
// According to #33294: The "_" prefix registration will be

test/quantization/core/test_quantized_op.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,40 @@ def test_qmatmul(self, num_dims, outer_dims, m, k, n, dtypes):
11011101
scale_C,
11021102
zero_point_C)
11031103

1104+
"""Tests the correctness of the quantized softmax op."""
1105+
@given(num_dims=st.integers(2, 4),
1106+
dims=st.lists(st.integers(2, 5), min_size=5, max_size=5))
1107+
def test_qsoftmax(self, num_dims, dims):
1108+
size = dims[:num_dims]
1109+
torch_dtype = torch.quint8
1110+
np_dtype = np.uint8
1111+
dim = num_dims - 1
1112+
1113+
scale_X = 1.3
1114+
zero_point_X = 0
1115+
X = torch.rand(size=size, dtype=torch.float32) * 8 + zero_point_X
1116+
1117+
scale_Y = 1 / 256
1118+
zero_point_Y = 0
1119+
1120+
qX = torch.quantize_per_tensor(X,
1121+
scale=scale_X,
1122+
zero_point=zero_point_X,
1123+
dtype=torch_dtype)
1124+
1125+
1126+
# softmax ground truth
1127+
Y = torch.softmax(qX.dequantize(), dim=dim).numpy()
1128+
qY = _quantize(Y, scale_Y, zero_point_Y, dtype=np_dtype)
1129+
qY_hat = torch.ops.quantized.softmax(qX,
1130+
dim=dim,
1131+
output_scale=scale_Y,
1132+
output_zero_point=zero_point_Y)
1133+
1134+
np.testing.assert_equal(qY, qY_hat.int_repr(),
1135+
"Quantized softmax failed.")
1136+
1137+
11041138
"""Tests the correctness of the mul and mul_relu op."""
11051139
def test_qmul_broadcast(self):
11061140
mul_relu = torch.ops.quantized.mul_relu

tools/build_variables.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,7 @@ aten_native_source_non_codegen_list = [
12201220
"aten/src/ATen/native/quantized/cpu/qreduction.cpp",
12211221
"aten/src/ATen/native/quantized/cpu/qrelu.cpp",
12221222
"aten/src/ATen/native/quantized/cpu/qsigmoid.cpp",
1223+
"aten/src/ATen/native/quantized/cpu/qsoftmax.cpp",
12231224
"aten/src/ATen/native/quantized/cpu/qsort.cpp",
12241225
"aten/src/ATen/native/quantized/cpu/qtanh.cpp",
12251226
"aten/src/ATen/native/quantized/cpu/qthreshold.cpp",

0 commit comments

Comments
 (0)