Commit b66513b

Add avx2 integer horizontal sum and sum of squares to vec256 qint types
Summary: Adds utility functions to the quantized int types of vec256 to calculate horizontal sums and sums of squares using avx2 intrinsics. This is useful for quantized implementations of various normalization layers (LayerNorm, GroupNorm, InstanceNorm), where we need to calculate the mean and variance of a layer of quantized ints.

Test Plan: Ad-hoc C++ tester for the correctness of the avx2 functions: https://gist.github.com/vkuzo/0380f450793cd5c05abbeacb6d3883ae

Run with:

```
-lstdc++ -mavx2 -lm -ldl -o main main.cpp && ./main
```

The integration and performance will be tested in the next PR in the stack, where we will hook up quantized LayerNorm to use this.

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 0f3eaf8
Pull Request resolved: #35693
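For context, here is a minimal sketch of how these helpers could feed a mean/variance computation over a row of quantized uint8 values. This is not the LayerNorm integration (that lands in the next PR in the stack); `mean_var_of_quantized_row` is a hypothetical name used only for illustration, and the sketch assumes the `hsum`/`hsum_sq` overloads added in this commit are linked in.

```cpp
#include <cstdint>

// Declarations of two of the overloads added in this commit
// (defined in QuantizedOpKernels.cpp).
int64_t hsum(const uint8_t* A, int len);
int64_t hsum_sq(const uint8_t* A, int len);

// Hypothetical helper (illustration only): mean and biased variance of a row
// of quantized uint8 values, using the affine mapping x = scale * (q - zero_point).
void mean_var_of_quantized_row(
    const uint8_t* X,
    int len,
    float scale,
    int32_t zero_point,
    float* out_mean,
    float* out_var) {
  // sum(q_i) and sum(q_i^2) over the raw quantized values
  const int64_t s = hsum(X, len);
  const int64_t ss = hsum_sq(X, len);
  // E[q] and E[q^2]
  const float mean_q = static_cast<float>(s) / len;
  const float mean_q_sq = static_cast<float>(ss) / len;
  // var(q) = E[q^2] - E[q]^2; the affine transform shifts the mean by
  // zero_point and scales the variance by scale^2.
  *out_mean = scale * (mean_q - zero_point);
  *out_var = scale * scale * (mean_q_sq - mean_q * mean_q);
}
```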
1 parent 83abd7f commit b66513b

File tree

1 file changed: +220 −0


aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp

Lines changed: 220 additions & 0 deletions
```diff
@@ -152,6 +152,226 @@ Tensor qcat_nhwc_kernel(
   return output;
 }
 
+// horizontal sum over a range of uint8_t
+int64_t hsum(const uint8_t* A, int len) {
+  int64_t row_sum = 0;
+  int i = 0;
+
+#ifdef __AVX2__
+  __m256i sum_v = _mm256_setzero_si256();
+  __m256i one_epi16_v = _mm256_set1_epi16(1);
+  __m256i one_epi8_v = _mm256_set1_epi8(1);
+  // vectorized
+  for (; i < len / 32 * 32; i += 32) {
+    __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
+    sum_v = _mm256_add_epi32(
+      sum_v,
+      _mm256_madd_epi16(
+        // first argument is unsigned, second is signed
+        _mm256_maddubs_epi16(src_v, one_epi8_v),
+        one_epi16_v)
+    );
+  }
+
+  alignas(64) int32_t temp[8];
+  _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
+  for (int k = 0; k < 8; ++k) {
+    row_sum += temp[k];
+  }
+#endif // __AVX2__
+
+  // scalar
+  for (; i < len; ++i) {
+    row_sum += A[i];
+  }
+
+  return row_sum;
+}
+
+// horizontal sum over a range of int8_t
+int64_t hsum(const int8_t* A, int len) {
+  int64_t row_sum = 0;
+  int i = 0;
+
+#ifdef __AVX2__
+  __m256i sum_v = _mm256_setzero_si256();
+  __m256i one_epi16_v = _mm256_set1_epi16(1);
+  __m256i one_epi8_v = _mm256_set1_epi8(1);
+  // vectorized
+  for (; i < len / 32 * 32; i += 32) {
+    __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
+    sum_v = _mm256_add_epi32(
+      sum_v,
+      _mm256_madd_epi16(
+        // first argument is unsigned, second is signed
+        _mm256_maddubs_epi16(one_epi8_v, src_v),
+        one_epi16_v)
+    );
+  }
+
+  alignas(64) int32_t temp[8];
+  _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
+  for (int k = 0; k < 8; ++k) {
+    row_sum += temp[k];
+  }
+#endif // __AVX2__
+
+  // scalar
+  for (; i < len; ++i) {
+    row_sum += A[i];
+  }
+
+  return row_sum;
+}
+
+// horizontal sum over a range of int32_t
+int64_t hsum(const int32_t* A, int len) {
+  int64_t row_sum = 0;
+  int i = 0;
+
+#ifdef __AVX2__
+  __m256i sum_epi64 = _mm256_setzero_si256();
+  // vectorized
+  for (; i < len / 8 * 8; i += 8) {
+    __m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
+    // widen
+    __m128i src_lo_epi32 = _mm256_castsi256_si128(src_epi32);
+    __m128i src_hi_epi32 = _mm256_extractf128_si256(src_epi32, 1);
+    __m256i src_lo_epi64 = _mm256_cvtepi32_epi64(src_lo_epi32);
+    __m256i src_hi_epi64 = _mm256_cvtepi32_epi64(src_hi_epi32);
+    // add
+    sum_epi64 = _mm256_add_epi64(sum_epi64, src_lo_epi64);
+    sum_epi64 = _mm256_add_epi64(sum_epi64, src_hi_epi64);
+  }
+
+  alignas(64) int64_t temp[4];
+  _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_epi64);
+  for (int k = 0; k < 4; ++k) {
+    row_sum += temp[k];
+  }
+#endif // __AVX2__
+
+  // scalar
+  for (; i < len; ++i) {
+    row_sum += A[i];
+  }
+
+  return row_sum;
+}
+
+// horizontal sum of squares over a range of uint8_t
+int64_t hsum_sq(const uint8_t* A, int len) {
+  int64_t row_sum = 0;
+  int i = 0;
+
+#ifdef __AVX2__
+  __m256i sum_v_epu32 = _mm256_setzero_si256();
+  // vectorized
+  for (; i < len / 16 * 16; i += 16) {
+    // (i15, ..., i0)
+    __m128i src_epu8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(A + i));
+    __m256i src_epu16 = _mm256_cvtepu8_epi16(src_epu8);
+    // (i15 ^ 2, ..., i0 ^ 2)
+    __m256i sq_epu16 = _mm256_mullo_epi16(src_epu16, src_epu16);
+    // (i7 ^ 2, ..., i0 ^ 2)
+    __m128i sq_lo_epu16 = _mm256_castsi256_si128(sq_epu16);
+    // (i15 ^ 2, ..., i8 ^ 2)
+    __m128i sq_hi_epu16 = _mm256_extractf128_si256(sq_epu16, 1);
+    // widen to epu32
+    __m256i sq_lo_epu32 = _mm256_cvtepu16_epi32(sq_lo_epu16);
+    __m256i sq_hi_epu32 = _mm256_cvtepu16_epi32(sq_hi_epu16);
+    // add to running sum
+    sum_v_epu32 = _mm256_add_epi32(sum_v_epu32, sq_lo_epu32);
+    sum_v_epu32 = _mm256_add_epi32(sum_v_epu32, sq_hi_epu32);
+  }
+
+  alignas(64) int32_t temp[8];
+  _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v_epu32);
+  for (int k = 0; k < 8; ++k) {
+    row_sum += temp[k];
+  }
+#endif // __AVX2__
+
+  // scalar
+  for (; i < len; ++i) {
+    row_sum += A[i] * A[i];
+  }
+
+  return row_sum;
+}
+
+// horizontal sum of squares over a range of int8_t
+int64_t hsum_sq(const int8_t* A, int len) {
+  int64_t row_sum = 0;
+  int i = 0;
+
+#ifdef __AVX2__
+  __m256i sum_v_epi32 = _mm256_setzero_si256();
+  // vectorized
+  for (; i < len / 16 * 16; i += 16) {
+    // (i15, ..., i0)
+    __m128i src_epi8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(A + i));
+    __m256i src_epi16 = _mm256_cvtepi8_epi16(src_epi8);
+    // (i15 ^ 2, ..., i0 ^ 2)
+    __m256i sq_epi16 = _mm256_mullo_epi16(src_epi16, src_epi16);
+    // (i7 ^ 2, ..., i0 ^ 2)
+    __m128i sq_lo_epi16 = _mm256_castsi256_si128(sq_epi16);
+    // (i15 ^ 2, ..., i8 ^ 2)
+    __m128i sq_hi_epi16 = _mm256_extractf128_si256(sq_epi16, 1);
+    // widen to epi32
+    __m256i sq_lo_epi32 = _mm256_cvtepi16_epi32(sq_lo_epi16);
+    __m256i sq_hi_epi32 = _mm256_cvtepi16_epi32(sq_hi_epi16);
+    // add to running sum
+    sum_v_epi32 = _mm256_add_epi32(sum_v_epi32, sq_lo_epi32);
+    sum_v_epi32 = _mm256_add_epi32(sum_v_epi32, sq_hi_epi32);
+  }
+
+  alignas(64) int32_t temp[8];
+  _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v_epi32);
+  for (int k = 0; k < 8; ++k) {
+    row_sum += temp[k];
+  }
+#endif // __AVX2__
+
+  // scalar
+  for (; i < len; ++i) {
+    row_sum += A[i] * A[i];
+  }
+
+  return row_sum;
+}
+
+// horizontal sum of squares over a range of int32_t
+// floats throughout are necessary to prevent overflow
+float hsum_sq(const int32_t* A, int len) {
+  float row_sum = 0;
+  int i = 0;
+
+#ifdef __AVX2__
+  __m256 sum_ps = _mm256_setzero_ps();
+  // vectorized
+  for (; i < len / 8 * 8; i += 8) {
+    __m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
+    __m256 src_ps = _mm256_cvtepi32_ps(src_epi32);
+    sum_ps = _mm256_add_ps(sum_ps, _mm256_mul_ps(src_ps, src_ps));
+  }
+
+  alignas(64) float temp[8];
+  _mm256_store_ps(temp, sum_ps);
+  for (int k = 0; k < 8; ++k) {
+    row_sum += static_cast<float>(temp[k]);
+  }
+#endif // __AVX2__
+
+  // scalar
+  for (; i < len; ++i) {
+    int64_t cur = static_cast<int64_t>(A[i]);
+    row_sum += (float)cur * (float)cur;
+  }
+
+  return row_sum;
+}
+
 void qrelu_kernel(const Tensor& qx, Tensor& qy) {
   const auto zero_point = qx.q_zero_point();
   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
```
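In the spirit of the ad-hoc tester mentioned in the Test Plan (the real one lives in the linked gist), a minimal standalone check could compare the AVX2 paths against scalar reference sums. The `main` below is an assumption-laden sketch, not the committed test; it assumes the `hsum`/`hsum_sq` overloads above are compiled into the same binary.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

// Declarations of the uint8_t overloads added above
// (defined in QuantizedOpKernels.cpp).
int64_t hsum(const uint8_t* A, int len);
int64_t hsum_sq(const uint8_t* A, int len);

int main() {
  std::mt19937 rng(0);
  std::uniform_int_distribution<int> dist(0, 255);

  // use a length that is not a multiple of 32 to also exercise the scalar tail
  std::vector<uint8_t> A(1007);
  for (auto& v : A) {
    v = static_cast<uint8_t>(dist(rng));
  }

  // scalar reference values
  int64_t ref_sum = 0;
  int64_t ref_sum_sq = 0;
  for (uint8_t v : A) {
    ref_sum += v;
    ref_sum_sq += static_cast<int64_t>(v) * v;
  }

  assert(hsum(A.data(), static_cast<int>(A.size())) == ref_sum);
  assert(hsum_sq(A.data(), static_cast<int>(A.size())) == ref_sum_sq);
  std::printf("ok\n");
  return 0;
}
```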

0 commit comments
