
Commit f9c3c37

Add quantized CELU operator by adding additional parameters to quantized ELU
ghstack-source-id: f2649e0
Pull Request resolved: #39199

Updated ELU to accept additional parameters
ghstack-source-id: f2649e0
Pull Request resolved: #39200

Added tests
ghstack-source-id: f2649e0
Pull Request resolved: #39201

Improved tests to fail when formula is wrong
ghstack-source-id: f2649e0
Pull Request resolved: #39202
1 parent 016cf7d commit f9c3c37

File tree: 10 files changed, +89 −6 lines

aten/src/ATen/native/Activation.cpp

Lines changed: 4 additions & 0 deletions
@@ -181,11 +181,15 @@ Tensor & selu_(Tensor & self) {
 }
 
 Tensor celu(const Tensor & self, Scalar alpha) {
+  TORCH_CHECK(alpha.to<double>() != 0,
+              "ZeroDivisionError: alpha cannot be 0 for CELU");
   double inv_alpha = 1. / alpha.to<double>();
   return at::elu(self, alpha, Scalar(1.0), Scalar(inv_alpha));
 }
 
 Tensor & celu_(Tensor & self, Scalar alpha) {
+  TORCH_CHECK(alpha.to<double>() != 0,
+              "ZeroDivisionError: alpha cannot be 0 for CELU");
   double inv_alpha = 1. / alpha.to<double>();
   return at::elu_(self, alpha, Scalar(1.0), Scalar(inv_alpha));
 }
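For reference, the reduction above follows directly from the CELU definition, CELU(x, alpha) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), which is exactly the generalized ELU with scale = 1 and input_scale = 1 / alpha; that is also why alpha = 0 must be rejected. A minimal Python sketch of the equivalence (illustrative only, not part of this commit; the helper name generalized_elu is made up):

import torch

def generalized_elu(x, alpha=1.0, scale=1.0, input_scale=1.0):
    # Generalized ELU as described by the kernel comment in the next file:
    #   x >= 0: x * scale
    #   x <  0: alpha * (exp(x * input_scale) - 1) * scale
    return torch.where(x >= 0,
                       x * scale,
                       alpha * (torch.exp(x * input_scale) - 1) * scale)

x = torch.randn(1000)
alpha = 0.7
# CELU(x, alpha) matches the generalized ELU with scale=1 and input_scale=1/alpha.
assert torch.allclose(torch.celu(x, alpha), generalized_elu(x, alpha, 1.0, 1.0 / alpha))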

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp

Lines changed: 21 additions & 3 deletions
@@ -793,7 +793,17 @@ void qtanh_kernel(const Tensor& qx, Tensor& qy) {
   });
 }
 
-void qelu_kernel(const Tensor& qx, Scalar alpha, Tensor& qy) {
+void qelu_kernel(
+    const Tensor& qx,
+    Scalar alpha,
+    Scalar scale,
+    Scalar input_scale,
+    Tensor& qy) {
+  // scale and input_scale arguments refer to a generalized ELU formula
+  // if x >= 0, ELU(x) = x * scale
+  // if x <= 0, ELU(x) = (exp(x * input_scale) - 1) * scale
+  // in the normal ELU formula, both are equal to 1
+  // they are NOT related to the quantization scale term
 
   int64_t i_zp = qx.q_zero_point();
   float i_scale = qx.q_scale();
@@ -805,6 +815,8 @@ void qelu_kernel(const Tensor& qx, Scalar alpha, Tensor& qy) {
   float inv_o_scale = 1.0 / o_scale;
 
   float alpha_float = alpha.to<float>();
+  float scale_coef = scale.to<float>();
+  float input_scale_coef = input_scale.to<float>();
 
   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qelu_kernel", [&] {
 
@@ -817,6 +829,8 @@ void qelu_kernel(const Tensor& qx, Scalar alpha, Tensor& qy) {
     Vec zero_vec = Vec(0.0f);
     Vec one_vec = Vec(1.0f);
     Vec alpha_vec = Vec(alpha_float);
+    Vec scale_coef_vec = Vec(scale_coef);
+    Vec input_scale_coef_vec = Vec(input_scale_coef);
     Vec i_scale_vec = Vec(i_scale);
     Vec i_zero_point_vec = Vec((float)i_zp);
     Vec i_scale_neg_zp_premul_vec = i_scale_vec * i_zero_point_vec.neg();
@@ -828,8 +842,9 @@ void qelu_kernel(const Tensor& qx, Scalar alpha, Tensor& qy) {
        const auto x = at::native::dequantize_val(i_scale, i_zp, value_qx);
        // ELU
        const auto y = x >= 0
-          ? x
-          : (alpha_float * (std::exp(x) - 1));
+          ? x * scale_coef
+          : ((std::exp(x * input_scale_coef) - 1) * alpha_float * scale_coef);
+
        // quantize
        return at::native::quantize_val<scalar_t>(o_scale, o_zp, y);
      },
@@ -846,13 +861,16 @@ void qelu_kernel(const Tensor& qx, Scalar alpha, Tensor& qy) {
 
        Vec dx_vec_copy_neg_elu = dx_vec_vec[idx] * one_vec;
        // calculate the negative part of ELU on the copy
+       dx_vec_copy_neg_elu = dx_vec_copy_neg_elu * input_scale_coef_vec;
        dx_vec_copy_neg_elu = dx_vec_copy_neg_elu.exp();
        dx_vec_copy_neg_elu = dx_vec_copy_neg_elu - one_vec;
        dx_vec_copy_neg_elu = dx_vec_copy_neg_elu * alpha_vec;
        // blend
        dx_vec_vec[idx] = Vec::blendv(dx_vec_copy_neg_elu, dx_vec_vec[idx],
                                      dx_vec_vec[idx] > zero_vec);
      }
+
+     dx_vec_vec[idx] = dx_vec_vec[idx] * scale_coef_vec;
    }
    // quantize
    return qVec::quantize(dx_vec_vec, o_scale, o_zp, inv_o_scale);
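The scalar path of this kernel amounts to: dequantize, apply the generalized ELU, requantize with the caller-supplied output quantization parameters. A plain-Python reference of that flow can help when sanity-checking the vectorized path; this is an illustrative sketch, not code from the commit, and the qelu_reference name and qparams are made up:

import torch

def qelu_reference(qx, output_scale, output_zero_point,
                   alpha=1.0, scale=1.0, input_scale=1.0):
    # Mirror the kernel: dequantize, generalized ELU, requantize.
    x = qx.dequantize()
    y = torch.where(x >= 0,
                    x * scale,
                    (torch.exp(x * input_scale) - 1) * alpha * scale)
    return torch.quantize_per_tensor(y, output_scale, output_zero_point, qx.dtype)

qx = torch.quantize_per_tensor(torch.randn(16), scale=0.05, zero_point=64,
                               dtype=torch.quint8)
qy = qelu_reference(qx, output_scale=0.05, output_zero_point=64, alpha=1.0)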

aten/src/ATen/native/quantized/cpu/qelu.cpp

Lines changed: 10 additions & 3 deletions
@@ -11,14 +11,21 @@ DEFINE_DISPATCH(qelu_stub);
 
 Tensor quantized_elu(
     const Tensor& qx, double output_scale, int64_t output_zero_point, Scalar alpha, Scalar scale, Scalar input_scale) {
-  Tensor qy = at::_empty_affine_quantized(qx.sizes(), qx.options(),
-      output_scale, output_zero_point);
-  qelu_stub(qx.device().type(), qx, alpha, qy);
+  Tensor qy = at::_empty_affine_quantized(qx.sizes(), qx.options(), output_scale, output_zero_point);
+  qelu_stub(qx.device().type(), qx, alpha, scale, input_scale, qy);
   return qy;
 }
 
+Tensor quantized_celu(const Tensor& qx, double output_scale, int64_t output_zero_point, Scalar alpha) {
+  TORCH_CHECK(alpha.to<double>() != 0,
+              "ZeroDivisionError: alpha cannot be 0 for CELU");
+  double inv_alpha = 1. / alpha.to<double>();
+  return quantized_elu(qx, output_scale, output_zero_point, alpha, Scalar(1.0), Scalar(inv_alpha));
+}
+
 TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl("elu", quantized_elu);
+  m.impl("celu", quantized_celu);
 }
 
 }} // namespace at::native

aten/src/ATen/native/quantized/cpu/quantized_ops.h

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@ using qtanh_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
 using qelu_fn = void(*)(
     const at::Tensor& /*qx*/,
     Scalar /*alpha*/,
+    Scalar /*scale*/,
+    Scalar /*input_scale*/,
     at::Tensor& /*qy*/);
 using qbinary_fn =
     void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/);

aten/src/ATen/native/quantized/library.cpp

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ TORCH_LIBRARY(quantized, m) {
   m.def("conv3d_dilation(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]");
   m.def("conv3d_groups(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int");
   m.def("elu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor");
+  m.def("celu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1) -> Tensor");
   m.def("hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor");
   m.def("group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor");
   m.def("instance_norm(Tensor input, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor");

benchmarks/operator_benchmark/pt/qactivation_test.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@
     ('functional.hardtanh', nnq.functional.hardtanh),
     ('functional.hardswish', nnq.functional.hardswish),
     ('functional.elu', nnq.functional.elu),
+    ('functional.celu', nnq.functional.celu),
     ('functional.hardsigmoid', nnq.functional.hardsigmoid),
     ('functional.leaky_relu', nnq.functional.leaky_relu),
     ('functional.sigmoid', torch.nn.functional.sigmoid),

test/quantization/test_quantized_op.py

Lines changed: 27 additions & 0 deletions
@@ -324,6 +324,33 @@ def test_qelu(self, X, alpha):
         self.assertEqual(qY, qY_hat,
                          msg="F.elu failed ({} vs {})".format(qY, qY_hat))
 
+
+    """Tests the correctness of the quantized::celu op."""
+    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
+                       elements=hu.floats(-1e2, 1e2, allow_nan=False, allow_infinity=False),
+                       qparams=hu.qparams(scale_max=9.999999747378752e-06)),
+           alpha=st.floats(0.01, 100.0, allow_nan=False, allow_infinity=False))
+    def test_qcelu(self, X, alpha):
+        X, (scale, zero_point, torch_type) = X
+
+        X = torch.from_numpy(X)
+        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
+                                       dtype=torch_type)
+        op = torch.nn.quantized.functional.celu
+
+        # calculate CELU(dqX) and quantize
+        dqX = qX.dequantize()
+        dqY_hat = dqX.clone()
+        dqY_hat[dqX < 0] = alpha * (torch.exp(dqY_hat[dqX < 0] / alpha) - 1.)
+        qY_hat = torch.quantize_per_tensor(dqY_hat, scale=scale, zero_point=zero_point,
+                                           dtype=torch_type)
+
+        # test regular
+        qY = op(qX, alpha=alpha)
+        self.assertEqual(qY, qY_hat,
+                         msg="F.celu failed ({} vs {})".format(qY, qY_hat))
+
+
     """Tests the correctness of the quantized::qlayer_norm op."""
     @skipIfNoFBGEMM
     def test_qlayer_norm(self):

tools/autograd/derivatives.yaml

Lines changed: 3 additions & 0 deletions
@@ -1178,6 +1178,9 @@
 - name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
   self: elu_backward(grad, alpha, scale, input_scale, result)
 
+- name: celu(Tensor self, Scalar alpha=1.0) -> Tensor
+  self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), result)
+
 - name: gelu(Tensor self) -> Tensor
   self: "GradMode::is_enabled() ? infinitely_differentiable_gelu_backward(grad, self) : gelu_backward(grad, self)"
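The backward entry reuses elu_backward with scale = 1 and input_scale = 1/alpha, which corresponds to dCELU/dx = 1 for x > 0 and exp(x / alpha) for x <= 0. A quick autograd check of that formula (illustrative only, not part of the commit):

import torch

alpha = 1.5
x = torch.randn(100, dtype=torch.double, requires_grad=True)
torch.celu(x, alpha).sum().backward()

# dCELU/dx = 1 for x > 0, exp(x / alpha) for x <= 0.
expected = torch.where(x > 0, torch.ones_like(x), torch.exp(x / alpha))
assert torch.allclose(x.grad, expected.detach())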

torch/csrc/jit/passes/quantization/helper.cpp

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,7 @@ std::vector<std::string> _static_quantizable_call_funcs = {
     "batch_norm",
     "hardswish",
     "elu",
+    "celu",
     "layer_norm",
     "group_norm",
     "instance_norm",
@@ -37,6 +38,8 @@ std::vector<std::string> _static_quantizable_aten_funcs = {
     "hardswish_",
     "elu",
     "elu_",
+    "celu",
+    "celu_",
     "batch_norm",
     "layer_norm",
     "group_norm",

torch/nn/quantized/functional.py

Lines changed: 17 additions & 0 deletions
@@ -362,6 +362,23 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
     return torch.nn.functional.max_pool2d(input, kernel_size, stride, padding,
                                           dilation, ceil_mode, return_indices)
 
+def celu(input, alpha=1.):
+    # type: (Tensor, float) -> Tensor
+    r"""celu(input, alpha=1.) -> Tensor
+
+    Applies the quantized CELU function element-wise.
+    .. math::
+        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x / \alpha) - 1))
+
+    Args:
+        input: quantized input
+        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.celu' must be quantized!")
+    return torch.celu(input, alpha)
+
+
 def relu(input, inplace=False):
     # type: (Tensor, bool) -> Tensor
     r"""relu(input, inplace=False) -> Tensor
