@@ -793,7 +793,17 @@ void qtanh_kernel(const Tensor& qx, Tensor& qy) {
793793 });
794794}
795795
796- void qelu_kernel (const Tensor& qx, Scalar alpha, Tensor& qy) {
796+ void qelu_kernel (
797+ const Tensor& qx,
798+ Scalar alpha,
799+ Scalar scale,
800+ Scalar input_scale,
801+ Tensor& qy) {
802+ // scale and input_scale arguments refer to a generalized ELU formula
803+ // if x >= 0, ELU(x) = x * scale
804+ // if x < 0, ELU(x) = (exp(x * input_scale) - 1) * alpha * scale
805+ // in the normal ELU formula, both are equal to 1
806+ // they are NOT related to the quantization scale term
797807
798808 int64_t i_zp = qx.q_zero_point ();
799809 float i_scale = qx.q_scale ();
@@ -805,6 +815,8 @@ void qelu_kernel(const Tensor& qx, Scalar alpha, Tensor& qy) {
805815 float inv_o_scale = 1.0 / o_scale;
806816
807817 float alpha_float = alpha.to <float >();
818+ float scale_coef = scale.to <float >();
819+ float input_scale_coef = input_scale.to <float >();
808820
809821 AT_DISPATCH_QINT_TYPES (qx.scalar_type (), " qelu_kernel" , [&] {
810822
@@ -817,6 +829,8 @@ void qelu_kernel(const Tensor& qx, Scalar alpha, Tensor& qy) {
817829 Vec zero_vec = Vec (0 .0f );
818830 Vec one_vec = Vec (1 .0f );
819831 Vec alpha_vec = Vec (alpha_float);
832+ Vec scale_coef_vec = Vec (scale_coef);
833+ Vec input_scale_coef_vec = Vec (input_scale_coef);
820834 Vec i_scale_vec = Vec (i_scale);
821835 Vec i_zero_point_vec = Vec ((float )i_zp);
822836 Vec i_scale_neg_zp_premul_vec = i_scale_vec * i_zero_point_vec.neg ();
@@ -828,8 +842,9 @@ void qelu_kernel(const Tensor& qx, Scalar alpha, Tensor& qy) {
828842 const auto x = at::native::dequantize_val (i_scale, i_zp, value_qx);
829843 // ELU
830844 const auto y = x >= 0
831- ? x
832- : (alpha_float * (std::exp (x) - 1 ));
845+ ? x * scale_coef
846+ : ((std::exp (x * input_scale_coef) - 1 ) * alpha_float * scale_coef);
847+
833848 // quantize
834849 return at::native::quantize_val<scalar_t >(o_scale, o_zp, y);
835850 },
@@ -846,13 +861,16 @@ void qelu_kernel(const Tensor& qx, Scalar alpha, Tensor& qy) {
846861
847862 Vec dx_vec_copy_neg_elu = dx_vec_vec[idx] * one_vec;
848863 // calculate the negative part of ELU on the copy
864+ dx_vec_copy_neg_elu = dx_vec_copy_neg_elu * input_scale_coef_vec;
849865 dx_vec_copy_neg_elu = dx_vec_copy_neg_elu.exp ();
850866 dx_vec_copy_neg_elu = dx_vec_copy_neg_elu - one_vec;
851867 dx_vec_copy_neg_elu = dx_vec_copy_neg_elu * alpha_vec;
852868 // blend
853869 dx_vec_vec[idx] = Vec::blendv (dx_vec_copy_neg_elu, dx_vec_vec[idx],
854870 dx_vec_vec[idx] > zero_vec);
855871 }
872+
873+ dx_vec_vec[idx] = dx_vec_vec[idx] * scale_coef_vec;
856874 }
857875 // quantize
858876 return qVec::quantize (dx_vec_vec, o_scale, o_zp, inv_o_scale);
0 commit comments