@@ -1222,41 +1222,100 @@ void requantizeOutputProcessingGConvAvx2(
12221222 __m256 xf_v, yf_v, zf_v, wf_v;
12231223 if (HAS_BIAS) {
12241224 if (is_same<BIAS_TYPE, float >::value) {
1225- __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
1225+ __m256 x_bias_v = _mm256_loadu_ps (
1226+ reinterpret_cast <const float *>(r.bias + j + 0 * VLEN));
1227+ __m256 y_bias_v = _mm256_loadu_ps (
1228+ reinterpret_cast <const float *>(r.bias + j + 1 * VLEN));
1229+ __m256 z_bias_v = _mm256_loadu_ps (
1230+ reinterpret_cast <const float *>(r.bias + j + 2 * VLEN));
1231+ __m256 w_bias_v = _mm256_loadu_ps (
1232+ reinterpret_cast <const float *>(r.bias + j + 3 * VLEN));
12261233 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
12271234 x_bias_v = _mm256_div_ps (
1228- _mm256_loadu_ps (
1229- reinterpret_cast <const float *>(r.bias + j + 0 * VLEN)),
1230- _mm256_loadu_ps (r.act_times_w_scale + j + 0 * VLEN));
1235+ x_bias_v, _mm256_loadu_ps (r.act_times_w_scale + j + 0 * VLEN));
12311236 y_bias_v = _mm256_div_ps (
1232- _mm256_loadu_ps (
1233- reinterpret_cast <const float *>(r.bias + j + 1 * VLEN)),
1234- _mm256_loadu_ps (r.act_times_w_scale + j + 1 * VLEN));
1237+ y_bias_v, _mm256_loadu_ps (r.act_times_w_scale + j + 1 * VLEN));
12351238 z_bias_v = _mm256_div_ps (
1236- _mm256_loadu_ps (
1237- reinterpret_cast <const float *>(r.bias + j + 2 * VLEN)),
1238- _mm256_loadu_ps (r.act_times_w_scale + j + 2 * VLEN));
1239+ z_bias_v, _mm256_loadu_ps (r.act_times_w_scale + j + 2 * VLEN));
12391240 w_bias_v = _mm256_div_ps (
1240- _mm256_loadu_ps (
1241- reinterpret_cast <const float *>(r.bias + j + 3 * VLEN)),
1242- _mm256_loadu_ps (r.act_times_w_scale + j + 3 * VLEN));
1241+ w_bias_v, _mm256_loadu_ps (r.act_times_w_scale + j + 3 * VLEN));
1242+ } else if (Q_GRAN == QuantizationGranularity::GROUP) {
1243+ __m256 diviser_v;
1244+ if (C_PER_G == 4 ) {
1245+ diviser_v = _mm256_insertf128_ps (
1246+ _mm256_castps128_ps256 (
1247+ _mm_set1_ps (r.act_times_w_scale [quant_param_idx + 0 ])),
1248+ _mm_set1_ps (r.act_times_w_scale [quant_param_idx + 1 ]),
1249+ 1 );
1250+ x_bias_v = _mm256_div_ps (x_bias_v, diviser_v);
1251+
1252+ diviser_v = _mm256_insertf128_ps (
1253+ _mm256_castps128_ps256 (
1254+ _mm_set1_ps (r.act_times_w_scale [quant_param_idx + 2 ])),
1255+ _mm_set1_ps (r.act_times_w_scale [quant_param_idx + 3 ]),
1256+ 1 );
1257+ y_bias_v = _mm256_div_ps (y_bias_v, diviser_v);
1258+
1259+ diviser_v = _mm256_insertf128_ps (
1260+ _mm256_castps128_ps256 (
1261+ _mm_set1_ps (r.act_times_w_scale [quant_param_idx + 4 ])),
1262+ _mm_set1_ps (r.act_times_w_scale [quant_param_idx + 5 ]),
1263+ 1 );
1264+ z_bias_v = _mm256_div_ps (z_bias_v, diviser_v);
1265+
1266+ diviser_v = _mm256_insertf128_ps (
1267+ _mm256_castps128_ps256 (
1268+ _mm_set1_ps (r.act_times_w_scale [quant_param_idx + 6 ])),
1269+ _mm_set1_ps (r.act_times_w_scale [quant_param_idx + 7 ]),
1270+ 1 );
1271+ w_bias_v = _mm256_div_ps (w_bias_v, diviser_v);
1272+
1273+ } else if (C_PER_G == 8 ) {
1274+ diviser_v = _mm256_set1_ps (
1275+ r.act_times_w_scale
1276+ [quant_param_idx +
1277+ (j - block.col_start ) / (VLEN * 4 ) * 4 + 0 ]);
1278+ x_bias_v = _mm256_div_ps (x_bias_v, diviser_v);
1279+
1280+ diviser_v = _mm256_set1_ps (
1281+ r.act_times_w_scale
1282+ [quant_param_idx +
1283+ (j - block.col_start ) / (VLEN * 4 ) * 4 + 1 ]);
1284+ y_bias_v = _mm256_div_ps (y_bias_v, diviser_v);
1285+
1286+ diviser_v = _mm256_set1_ps (
1287+ r.act_times_w_scale
1288+ [quant_param_idx +
1289+ (j - block.col_start ) / (VLEN * 4 ) * 4 + 2 ]);
1290+ z_bias_v = _mm256_div_ps (z_bias_v, diviser_v);
1291+
1292+ diviser_v = _mm256_set1_ps (
1293+ r.act_times_w_scale
1294+ [quant_param_idx +
1295+ (j - block.col_start ) / (VLEN * 4 ) * 4 + 3 ]);
1296+ w_bias_v = _mm256_div_ps (w_bias_v, diviser_v);
1297+
1298+ } else {
1299+ assert (C_PER_G == 16 );
1300+ diviser_v = _mm256_set1_ps (
1301+ r.act_times_w_scale
1302+ [quant_param_idx +
1303+ (j - block.col_start ) / (VLEN * 4 ) * 2 + 0 ]);
1304+ x_bias_v = _mm256_div_ps (x_bias_v, diviser_v);
1305+ y_bias_v = _mm256_div_ps (y_bias_v, diviser_v);
1306+
1307+ diviser_v = _mm256_set1_ps (
1308+ r.act_times_w_scale
1309+ [quant_param_idx +
1310+ (j - block.col_start ) / (VLEN * 4 ) * 2 + 1 ]);
1311+ z_bias_v = _mm256_div_ps (z_bias_v, diviser_v);
1312+ w_bias_v = _mm256_div_ps (w_bias_v, diviser_v);
1313+ }
12431314 } else {
1244- x_bias_v = _mm256_mul_ps (
1245- _mm256_loadu_ps (
1246- reinterpret_cast <const float *>(r.bias + j + 0 * VLEN)),
1247- act_times_w_rcp_v);
1248- y_bias_v = _mm256_mul_ps (
1249- _mm256_loadu_ps (
1250- reinterpret_cast <const float *>(r.bias + j + 1 * VLEN)),
1251- act_times_w_rcp_v);
1252- z_bias_v = _mm256_mul_ps (
1253- _mm256_loadu_ps (
1254- reinterpret_cast <const float *>(r.bias + j + 2 * VLEN)),
1255- act_times_w_rcp_v);
1256- w_bias_v = _mm256_mul_ps (
1257- _mm256_loadu_ps (
1258- reinterpret_cast <const float *>(r.bias + j + 3 * VLEN)),
1259- act_times_w_rcp_v);
1315+ x_bias_v = _mm256_mul_ps (x_bias_v, act_times_w_rcp_v);
1316+ y_bias_v = _mm256_mul_ps (y_bias_v, act_times_w_rcp_v);
1317+ z_bias_v = _mm256_mul_ps (z_bias_v, act_times_w_rcp_v);
1318+ w_bias_v = _mm256_mul_ps (w_bias_v, act_times_w_rcp_v);
12601319 }
12611320 xf_v = _mm256_add_ps (_mm256_cvtepi32_ps (x_v), x_bias_v);
12621321 yf_v = _mm256_add_ps (_mm256_cvtepi32_ps (y_v), y_bias_v);
0 commit comments