
Commit c8cac64

dskhudia authored and facebook-github-bot committed
add missing instantiation for float bias for gconv (#127)
Summary:
Pull Request resolved: #127

Float bias was going through a slow path. Adding a missing specialization.

Reviewed By: protonu, jianyuh

Differential Revision: D17346881

fbshipit-source-id: dd6b40d80c3c429b438ea6b4e1520b935e582c4a
1 parent ea787e8 commit c8cac64
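
For context, a minimal sketch of the requantization functor types involved (the alias names below are illustrative, not part of FBGEMM): before this change, fbgemmGroupwiseConv only accepted a ReQuantizeOutput with the default std::int32_t bias type, so a functor carrying an FP32 bias fell back to the generic slow path. With the new BIAS_TYPE template parameter, the float-bias functor can be passed to the groupwise-convolution kernel directly.

    #include "fbgemm/Fbgemm.h"

    using fbgemm::QuantizationGranularity;
    using fbgemm::ReQuantizeOutput;

    // Requantization functor carrying an FP32 bias (no fused ReLU,
    // per-tensor quantization granularity).
    using ReqOutFloatBias =
        ReQuantizeOutput<false, QuantizationGranularity::TENSOR, float>;

    // Functor with the default std::int32_t bias; existing callers keep
    // compiling unchanged because BIAS_TYPE defaults to std::int32_t.
    using ReqOutInt32Bias =
        ReQuantizeOutput<false, QuantizationGranularity::TENSOR>;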

File tree: 3 files changed (+95 -34 lines)


include/fbgemm/Fbgemm.h

Lines changed: 3 additions & 2 deletions
@@ -1388,7 +1388,8 @@ template <
     typename outType,
     bool FUSE_RELU,
     QuantizationGranularity Q_GRAN,
-    int SPATIAL_DIM = 2>
+    int SPATIAL_DIM = 2,
+    typename BIAS_TYPE = std::int32_t>
 FBGEMM_API void fbgemmGroupwiseConv(
     const conv_param_t<SPATIAL_DIM>& conv_param,
     const std::uint8_t* activations,
@@ -1397,7 +1398,7 @@ FBGEMM_API void fbgemmGroupwiseConv(
     packed_W& packed_weights,
     outType* out,
     std::int32_t* outBuffer,
-    const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess,
+    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
     int thread_id,
     int num_threads);
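
The default template argument keeps this header change source-compatible: existing uses of fbgemmGroupwiseConv still resolve to BIAS_TYPE = std::int32_t, while a float-bias caller names a distinct specialization that the library must also explicitly instantiate in its .cc files (the "missing instantiation" of the commit title; the instantiation lists themselves are not shown in this diff). A hedged stand-in illustrating the C++ mechanism only, not FBGEMM's actual instantiation list:

    #include <cstdint>

    // Stand-in template mirroring the shape of the change: a defaulted
    // bias-type parameter on a function template.
    template <typename BIAS_TYPE = std::int32_t>
    void requantize(const BIAS_TYPE* bias, int n) {
      // Kernel body elided in this sketch.
      (void)bias;
      (void)n;
    }

    // Explicit instantiations: when the real definition lives in a .cc
    // file, each bias type used by callers needs its own instantiation
    // there; leaving one out pushes callers onto a fallback path (or fails
    // to link if no fallback exists).
    template void requantize<std::int32_t>(const std::int32_t*, int);
    template void requantize<float>(const float*, int);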

src/GroupwiseConvAcc32Avx2.cc

Lines changed: 4 additions & 3 deletions
@@ -1769,7 +1769,8 @@ template <
     typename outType,
     bool FUSE_RELU,
     QuantizationGranularity Q_GRAN,
-    int SPATIAL_DIM>
+    int SPATIAL_DIM,
+    typename BIAS_TYPE>
 void fbgemmGroupwiseConv(
     const conv_param_t<SPATIAL_DIM>& conv_param,
     const std::uint8_t* activations,
@@ -1778,10 +1779,10 @@ void fbgemmGroupwiseConv(
     packed_W& packed_weights,
     outType* out,
     int32_t* outBuffer,
-    const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess,
+    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
     int thread_id,
     int num_threads) {
-  typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN> processOutputType;
+  typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE> processOutputType;

   if (!cpuinfo_initialize()) {
     throw std::runtime_error("Failed to initialize cpuinfo!");

src/QuantUtilsAvx2.cc

Lines changed: 88 additions & 29 deletions
@@ -1222,41 +1222,100 @@ void requantizeOutputProcessingGConvAvx2(
       __m256 xf_v, yf_v, zf_v, wf_v;
       if (HAS_BIAS) {
         if (is_same<BIAS_TYPE, float>::value) {
-          __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+          __m256 x_bias_v = _mm256_loadu_ps(
+              reinterpret_cast<const float*>(r.bias + j + 0 * VLEN));
+          __m256 y_bias_v = _mm256_loadu_ps(
+              reinterpret_cast<const float*>(r.bias + j + 1 * VLEN));
+          __m256 z_bias_v = _mm256_loadu_ps(
+              reinterpret_cast<const float*>(r.bias + j + 2 * VLEN));
+          __m256 w_bias_v = _mm256_loadu_ps(
+              reinterpret_cast<const float*>(r.bias + j + 3 * VLEN));
           if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
             x_bias_v = _mm256_div_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
-                _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+                x_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
             y_bias_v = _mm256_div_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
-                _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+                y_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
             z_bias_v = _mm256_div_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
-                _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+                z_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
             w_bias_v = _mm256_div_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
-                _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+                w_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+          } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+            __m256 diviser_v;
+            if (C_PER_G == 4) {
+              diviser_v = _mm256_insertf128_ps(
+                  _mm256_castps128_ps256(
+                      _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 0])),
+                  _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 1]),
+                  1);
+              x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+
+              diviser_v = _mm256_insertf128_ps(
+                  _mm256_castps128_ps256(
+                      _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 2])),
+                  _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 3]),
+                  1);
+              y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+              diviser_v = _mm256_insertf128_ps(
+                  _mm256_castps128_ps256(
+                      _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 4])),
+                  _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 5]),
+                  1);
+              z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+
+              diviser_v = _mm256_insertf128_ps(
+                  _mm256_castps128_ps256(
+                      _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 6])),
+                  _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 7]),
+                  1);
+              w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+
+            } else if (C_PER_G == 8) {
+              diviser_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 4 + 0]);
+              x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+
+              diviser_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+              y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+              diviser_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+              z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+
+              diviser_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+              w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+
+            } else {
+              assert(C_PER_G == 16);
+              diviser_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 2 + 0]);
+              x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+              y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+              diviser_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 2 + 1]);
+              z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+              w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+            }
           } else {
-            x_bias_v = _mm256_mul_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
-                act_times_w_rcp_v);
-            y_bias_v = _mm256_mul_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
-                act_times_w_rcp_v);
-            z_bias_v = _mm256_mul_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
-                act_times_w_rcp_v);
-            w_bias_v = _mm256_mul_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
-                act_times_w_rcp_v);
+            x_bias_v = _mm256_mul_ps(x_bias_v, act_times_w_rcp_v);
+            y_bias_v = _mm256_mul_ps(y_bias_v, act_times_w_rcp_v);
+            z_bias_v = _mm256_mul_ps(z_bias_v, act_times_w_rcp_v);
+            w_bias_v = _mm256_mul_ps(w_bias_v, act_times_w_rcp_v);
           }
           xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
           yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
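
The arithmetic behind the new branches, in scalar form: the FP32 bias is pre-divided by the activation-times-weight scale of its output channel (OUT_CHANNEL), of its group (GROUP, where the index math above selects the right group for each 8-float vector given C_PER_G channels per group), or of the whole tensor (the final else, which multiplies by the precomputed reciprocal act_times_w_rcp_v), so that it can be added to the int32 accumulator before the common rescale. Below is a hedged scalar reference for that math; the parameter names and the final rescale step are assumptions based on FBGEMM's usual requantization parameters, not part of this diff.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Scalar sketch: requantize one int32 accumulator with an FP32 bias.
    // Dividing the bias by Sa*Sw up front makes
    //   (acc + bias / (Sa*Sw)) * C_multiplier
    // equal to rescaling acc first and then adding the bias in the output's
    // floating-point domain, which is what the vectorized code computes.
    inline std::uint8_t requantize_with_float_bias(
        std::int32_t acc,         // accumulator (already offset-corrected)
        float bias,               // FP32 bias for this output channel
        float act_times_w_scale,  // Sa * Sw for this channel/group/tensor
        float C_multiplier,       // assumed to be Sa * Sw / output scale
        std::int32_t C_zero_point) {
      float xf = static_cast<float>(acc) + bias / act_times_w_scale;
      std::int32_t q =
          static_cast<std::int32_t>(std::nearbyint(xf * C_multiplier)) +
          C_zero_point;
      return static_cast<std::uint8_t>(std::min(255, std::max(0, q)));
    }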
