@@ -152,6 +152,226 @@ Tensor qcat_nhwc_kernel(
   return output;
 }

+// horizontal sum over a range of uint8_t
+int64_t hsum(const uint8_t* A, int len) {
+  int64_t row_sum = 0;
+  int i = 0;
+
+#ifdef __AVX2__
+  __m256i sum_v = _mm256_setzero_si256();
+  __m256i one_epi16_v = _mm256_set1_epi16(1);
+  __m256i one_epi8_v = _mm256_set1_epi8(1);
+  // vectorized
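+  // each iteration handles one 32-byte block (len / 32 * 32 rounds len down to
+  // a whole number of blocks): maddubs multiplies the unsigned bytes by signed
+  // 1s and adds adjacent pairs to 16-bit, then madd with 1s adds pairs again,
+  // so each 32-bit lane of sum_v accumulates the sum of four input bytes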
+  for (; i < len / 32 * 32; i += 32) {
+    __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
+    sum_v = _mm256_add_epi32(
+        sum_v,
+        _mm256_madd_epi16(
+            // first argument is unsigned, second is signed
+            _mm256_maddubs_epi16(src_v, one_epi8_v),
+            one_epi16_v)
+    );
+  }
+
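+  // spill the vector accumulator to memory and finish the reduction in scalar code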
+  alignas(64) int32_t temp[8];
+  _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
+  for (int k = 0; k < 8; ++k) {
+    row_sum += temp[k];
+  }
+#endif // __AVX2__
+
+  // scalar
+  for (; i < len; ++i) {
+    row_sum += A[i];
+  }
+
+  return row_sum;
+}
+
+// horizontal sum over a range of int8_t
+int64_t hsum(const int8_t* A, int len) {
+  int64_t row_sum = 0;
+  int i = 0;
+
+#ifdef __AVX2__
+  __m256i sum_v = _mm256_setzero_si256();
+  __m256i one_epi16_v = _mm256_set1_epi16(1);
+  __m256i one_epi8_v = _mm256_set1_epi8(1);
+  // vectorized
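+  // same maddubs/madd reduction as the uint8_t overload, but with the operands
+  // swapped: the all-ones vector is the unsigned argument and A supplies the
+  // signed bytes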
+  for (; i < len / 32 * 32; i += 32) {
+    __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
+    sum_v = _mm256_add_epi32(
+        sum_v,
+        _mm256_madd_epi16(
+            // first argument is unsigned, second is signed
+            _mm256_maddubs_epi16(one_epi8_v, src_v),
+            one_epi16_v)
+    );
+  }
+
+  alignas(64) int32_t temp[8];
+  _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
+  for (int k = 0; k < 8; ++k) {
+    row_sum += temp[k];
+  }
+#endif // __AVX2__
+
+  // scalar
+  for (; i < len; ++i) {
+    row_sum += A[i];
+  }
+
+  return row_sum;
+}
+
+// horizontal sum over a range of int32_t
+int64_t hsum(const int32_t* A, int len) {
+  int64_t row_sum = 0;
+  int i = 0;
+
+#ifdef __AVX2__
+  __m256i sum_epi64 = _mm256_setzero_si256();
+  // vectorized
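+  // sign-extend each 32-bit lane to 64-bit before adding, so the running sum
+  // is accumulated in four 64-bit lanes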
+  for (; i < len / 8 * 8; i += 8) {
+    __m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
+    // widen
+    __m128i src_lo_epi32 = _mm256_castsi256_si128(src_epi32);
+    __m128i src_hi_epi32 = _mm256_extractf128_si256(src_epi32, 1);
+    __m256i src_lo_epi64 = _mm256_cvtepi32_epi64(src_lo_epi32);
+    __m256i src_hi_epi64 = _mm256_cvtepi32_epi64(src_hi_epi32);
+    // add
+    sum_epi64 = _mm256_add_epi64(sum_epi64, src_lo_epi64);
+    sum_epi64 = _mm256_add_epi64(sum_epi64, src_hi_epi64);
+  }
+
+  alignas(64) int64_t temp[4];
+  _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_epi64);
+  for (int k = 0; k < 4; ++k) {
+    row_sum += temp[k];
+  }
+#endif // __AVX2__
+
+  // scalar
+  for (; i < len; ++i) {
+    row_sum += A[i];
+  }
+
+  return row_sum;
+}
+
+// horizontal sum of squares over a range of uint8_t
+int64_t hsum_sq(const uint8_t* A, int len) {
+  int64_t row_sum = 0;
+  int i = 0;
+
+#ifdef __AVX2__
+  __m256i sum_v_epu32 = _mm256_setzero_si256();
+  // vectorized
+  for (; i < len / 16 * 16; i += 16) {
+    // (i15, ..., i0)
+    __m128i src_epu8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(A + i));
+    __m256i src_epu16 = _mm256_cvtepu8_epi16(src_epu8);
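+    // zero-extended uint8 values are at most 255, so their squares (<= 65025)
+    // fit in 16 bits and the low halves kept by mullo are exact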
+    // (i15 ^ 2, ..., i0 ^ 2)
+    __m256i sq_epu16 = _mm256_mullo_epi16(src_epu16, src_epu16);
+    // (i7 ^ 2, ..., i0 ^ 2)
+    __m128i sq_lo_epu16 = _mm256_castsi256_si128(sq_epu16);
+    // (i15 ^ 2, ..., i8 ^ 2)
+    __m128i sq_hi_epu16 = _mm256_extractf128_si256(sq_epu16, 1);
+    // widen to epu32
+    __m256i sq_lo_epu32 = _mm256_cvtepu16_epi32(sq_lo_epu16);
+    __m256i sq_hi_epu32 = _mm256_cvtepu16_epi32(sq_hi_epu16);
+    // add to running sum
+    sum_v_epu32 = _mm256_add_epi32(sum_v_epu32, sq_lo_epu32);
+    sum_v_epu32 = _mm256_add_epi32(sum_v_epu32, sq_hi_epu32);
+  }
+
+  alignas(64) int32_t temp[8];
+  _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v_epu32);
+  for (int k = 0; k < 8; ++k) {
+    row_sum += temp[k];
+  }
+#endif // __AVX2__
+
+  // scalar
+  for (; i < len; ++i) {
+    row_sum += A[i] * A[i];
+  }
+
+  return row_sum;
+}
+
+// horizontal sum of squares over a range of int8_t
+int64_t hsum_sq(const int8_t* A, int len) {
+  int64_t row_sum = 0;
+  int i = 0;
+
+#ifdef __AVX2__
+  __m256i sum_v_epi32 = _mm256_setzero_si256();
+  // vectorized
+  for (; i < len / 16 * 16; i += 16) {
+    // (i15, ..., i0)
+    __m128i src_epi8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(A + i));
+    __m256i src_epi16 = _mm256_cvtepi8_epi16(src_epi8);
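+    // sign-extended int8 values are in [-128, 127], so their squares (<= 16384)
+    // fit in 16 bits and the low halves kept by mullo are exact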
+    // (i15 ^ 2, ..., i0 ^ 2)
+    __m256i sq_epi16 = _mm256_mullo_epi16(src_epi16, src_epi16);
+    // (i7 ^ 2, ..., i0 ^ 2)
+    __m128i sq_lo_epi16 = _mm256_castsi256_si128(sq_epi16);
+    // (i15 ^ 2, ..., i8 ^ 2)
+    __m128i sq_hi_epi16 = _mm256_extractf128_si256(sq_epi16, 1);
+    // widen to epi32
+    __m256i sq_lo_epi32 = _mm256_cvtepi16_epi32(sq_lo_epi16);
+    __m256i sq_hi_epi32 = _mm256_cvtepi16_epi32(sq_hi_epi16);
+    // add to running sum
+    sum_v_epi32 = _mm256_add_epi32(sum_v_epi32, sq_lo_epi32);
+    sum_v_epi32 = _mm256_add_epi32(sum_v_epi32, sq_hi_epi32);
+  }
+
+  alignas(64) int32_t temp[8];
+  _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v_epi32);
+  for (int k = 0; k < 8; ++k) {
+    row_sum += temp[k];
+  }
+#endif // __AVX2__
+
+  // scalar
+  for (; i < len; ++i) {
+    row_sum += A[i] * A[i];
+  }
+
+  return row_sum;
+}
+
+// horizontal sum of squares over a range of int32_t
+// floats throughout are necessary to prevent overflow
+float hsum_sq(const int32_t* A, int len) {
+  float row_sum = 0;
+  int i = 0;
+
+#ifdef __AVX2__
+  __m256 sum_ps = _mm256_setzero_ps();
+  // vectorized
+  for (; i < len / 8 * 8; i += 8) {
+    __m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
+    __m256 src_ps = _mm256_cvtepi32_ps(src_epi32);
+    sum_ps = _mm256_add_ps(sum_ps, _mm256_mul_ps(src_ps, src_ps));
+  }
+
+  alignas(64) float temp[8];
+  _mm256_store_ps(temp, sum_ps);
+  for (int k = 0; k < 8; ++k) {
+    row_sum += static_cast<float>(temp[k]);
+  }
+#endif // __AVX2__
+
+  // scalar
+  for (; i < len; ++i) {
+    int64_t cur = static_cast<int64_t>(A[i]);
+    row_sum += (float)cur * (float)cur;
+  }
+
+  return row_sum;
+}
+
 void qrelu_kernel(const Tensor& qx, Tensor& qy) {
   const auto zero_point = qx.q_zero_point();
   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {