@@ -83,7 +83,21 @@ namespace Eval::NNUE::Layers {
         return _mm512_reduce_add_epi32(sum) + bias;
       };
 
-      [[maybe_unused]] auto m512_haddx4 = [](__m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
+      // This function takes
+      //   sum0 = [xmm0a, xmm0b, xmm0c, xmm0d]
+      //   sum1 = [xmm1a, xmm1b, xmm1c, xmm1d]
+      //   sum2 = [xmm2a, xmm2b, xmm2c, xmm2d]
+      //   sum3 = [xmm3a, xmm3b, xmm3c, xmm3d]
+      // and returns
+      //   ret = [
+      //     reduce_add_epi32(xmm0a), reduce_add_epi32(xmm1a), reduce_add_epi32(xmm2a), reduce_add_epi32(xmm3a),
+      //     reduce_add_epi32(xmm0b), reduce_add_epi32(xmm1b), reduce_add_epi32(xmm2b), reduce_add_epi32(xmm3b),
+      //     reduce_add_epi32(xmm0c), reduce_add_epi32(xmm1c), reduce_add_epi32(xmm2c), reduce_add_epi32(xmm3c),
+      //     reduce_add_epi32(xmm0d), reduce_add_epi32(xmm1d), reduce_add_epi32(xmm2d), reduce_add_epi32(xmm3d)
+      //   ]
+      [[maybe_unused]] auto m512_hadd128x16_interleave = [](
+          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) -> __m512i {
+
         __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
         __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
 
@@ -96,7 +110,13 @@ namespace Eval::NNUE::Layers {
         __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
         __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
 
-        __m512i sum = _mm512_add_epi32(sum0123a, sum0123b);
+        return _mm512_add_epi32(sum0123a, sum0123b);
+      };
+
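+      // Adds the full 512-bit horizontal sums of sum0..sum3 to the four
+      // 32-bit lanes of bias and returns the result as a __m128i.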
+      [[maybe_unused]] auto m512_haddx4 = [m512_hadd128x16_interleave](
+          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
+
+        __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
 
         __m256i sum256lo = _mm512_castsi512_si256(sum);
         __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
@@ -109,6 +129,58 @@ namespace Eval::NNUE::Layers {
         return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
       };
 
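+      // Like m512_haddx4, but for eight accumulators: returns a __m256i whose
+      // eight 32-bit lanes are the horizontal sums of sum0..sum7 plus bias.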
+      [[maybe_unused]] auto m512_haddx8 = [m512_hadd128x16_interleave](
+          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
+          __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m256i bias) -> __m256i {
+
+        __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+        __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);
+
+        __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
+        __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
+        __m512i x = _mm512_add_epi32(
+            _mm512_permutex2var_epi64(suma, indices0, sumb),
+            _mm512_permutex2var_epi64(suma, indices1, sumb));
+
+        __m256i sum256lo = _mm512_castsi512_si256(x);
+        __m256i sum256hi = _mm512_extracti64x4_epi64(x, 1);
+
+        return _mm256_add_epi32(_mm256_add_epi32(sum256lo, sum256hi), bias);
+      };
+
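+      // Treats each 512-bit accumulator as two independent 256-bit halves and
+      // returns the eight per-half horizontal sums (low then high half of
+      // sum0, sum1, sum2, sum3) added to bias.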
+      [[maybe_unused]] auto m512_hadd256x8 = [m512_hadd128x16_interleave](
+          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m256i bias) -> __m256i {
+
+        __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+
+        __m512i indices = _mm512_setr_epi32(
+            0, 4, 8, 12, 2, 6, 10, 14,
+            1, 5, 9, 13, 3, 7, 11, 15);
+        sum = _mm512_permutexvar_epi32(indices, sum);
+
+        __m256i sum256lo = _mm512_castsi512_si256(sum);
+        __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
+
+        return _mm256_add_epi32(_mm256_hadd_epi32(sum256lo, sum256hi), bias);
+      };
+
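+      // Same idea for eight 512-bit accumulators: each one is reduced per
+      // 256-bit half, yielding sixteen 32-bit sums (low/high half of sum0,
+      // then sum1, ...) which are added to bias.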
+      [[maybe_unused]] auto m512_hadd256x16 = [m512_hadd128x16_interleave](
+          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
+          __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m512i bias) -> __m512i {
+
+        __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+        __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);
+
+        __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
+        __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
+        __m512i x = _mm512_add_epi32(
+            _mm512_permutex2var_epi64(suma, indices0, sumb),
+            _mm512_permutex2var_epi64(suma, indices1, sumb));
+
+        __m512i indices = _mm512_setr_epi32(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
+        return _mm512_add_epi32(_mm512_permutexvar_epi32(indices, x), bias);
+      };
+
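+      // m512_add_dpbusd_epi32: for every 32-bit lane of acc, accumulate the dot
+      // product of four unsigned 8-bit values from a with four signed 8-bit
+      // values from b (a single VNNI instruction when USE_VNNI is defined).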
       [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
 #if defined (USE_VNNI)
         acc = _mm512_dpbusd_epi32(acc, a, b);
@@ -205,7 +277,58 @@ namespace Eval::NNUE::Layers {
 
       // kOutputDimensions is either 1 or a multiple of kSimdWidth
       // because then it is also an input dimension.
-      if constexpr (kOutputDimensions % 4 == 0)
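+      // Special case: the whole input fits in a single 256-bit chunk
+      // (kNumChunks256 == 1), so it can be broadcast to both halves of a
+      // 512-bit register and each 512-bit weight load covers two adjacent
+      // rows, producing 16 outputs per iteration.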
+      if constexpr (kOutputDimensions % 16 == 0 && kNumChunks256 == 1)
+      {
+        for (IndexType i = 0; i < kOutputDimensions; i += 16)
+        {
+          const IndexType offset01a = (i + 0) * kPaddedInputDimensions;
+          const IndexType offset23a = (i + 2) * kPaddedInputDimensions;
+          const IndexType offset45a = (i + 4) * kPaddedInputDimensions;
+          const IndexType offset67a = (i + 6) * kPaddedInputDimensions;
+          const IndexType offset01b = (i + 8) * kPaddedInputDimensions;
+          const IndexType offset23b = (i + 10) * kPaddedInputDimensions;
+          const IndexType offset45b = (i + 12) * kPaddedInputDimensions;
+          const IndexType offset67b = (i + 14) * kPaddedInputDimensions;
+
+          const __m512i bias = *reinterpret_cast<const __m512i*>(&biases_[i]);
+          __m512i* outptr = reinterpret_cast<__m512i*>(&output[i]);
+
+          __m512i sum01a = _mm512_setzero_si512();
+          __m512i sum23a = _mm512_setzero_si512();
+          __m512i sum45a = _mm512_setzero_si512();
+          __m512i sum67a = _mm512_setzero_si512();
+          __m512i sum01b = _mm512_setzero_si512();
+          __m512i sum23b = _mm512_setzero_si512();
+          __m512i sum45b = _mm512_setzero_si512();
+          __m512i sum67b = _mm512_setzero_si512();
+
+          const auto row01a = *reinterpret_cast<const __m512i*>(&weights_[offset01a]);
+          const auto row23a = *reinterpret_cast<const __m512i*>(&weights_[offset23a]);
+          const auto row45a = *reinterpret_cast<const __m512i*>(&weights_[offset45a]);
+          const auto row67a = *reinterpret_cast<const __m512i*>(&weights_[offset67a]);
+          const auto row01b = *reinterpret_cast<const __m512i*>(&weights_[offset01b]);
+          const auto row23b = *reinterpret_cast<const __m512i*>(&weights_[offset23b]);
+          const auto row45b = *reinterpret_cast<const __m512i*>(&weights_[offset45b]);
+          const auto row67b = *reinterpret_cast<const __m512i*>(&weights_[offset67b]);
+
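+          // Broadcast the single 256-bit input chunk into both halves of a
+          // 512-bit register, so each 512-bit row load above feeds two rows.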
+          const __m256i in256 = input_vector256[0];
+          const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1);
+
+          m512_add_dpbusd_epi32(sum01a, in, row01a);
+          m512_add_dpbusd_epi32(sum23a, in, row23a);
+          m512_add_dpbusd_epi32(sum45a, in, row45a);
+          m512_add_dpbusd_epi32(sum67a, in, row67a);
+          m512_add_dpbusd_epi32(sum01b, in, row01b);
+          m512_add_dpbusd_epi32(sum23b, in, row23b);
+          m512_add_dpbusd_epi32(sum45b, in, row45b);
+          m512_add_dpbusd_epi32(sum67b, in, row67b);
+
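+          // Each accumulator now holds partial sums for two rows, one per
+          // 256-bit half; m512_hadd256x16 reduces all sixteen rows and adds
+          // the biases in a single step.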
+          *outptr = m512_hadd256x16(
+              sum01a, sum23a, sum45a, sum67a,
+              sum01b, sum23b, sum45b, sum67b, bias);
+        }
+      }
+      else if constexpr (kOutputDimensions % 4 == 0)
       {
         for (IndexType i = 0; i < kOutputDimensions; i += 4)
         {