@@ -231,7 +231,12 @@ std::pair<std::int32_t, bool> Transformer::transform(Position & pos, std::uint8_
231231 if (abs (pt) > PSQT_THRESHOLD * WEIGHTS_SCALE)
232232 return { pt, true };
233233
234- #if defined(USE_AVX2)
234+ #if defined(USE_AVX512)
235+ std::uint32_t chunks = HalfDimensions / (SIMD_WIDTH * 2 );
236+ static_assert (HalfDimensions % (SIMD_WIDTH * 2 ) == 0 );
237+ const __m512i control = _mm512_setr_epi64 (0 , 2 , 4 , 6 , 1 , 3 , 5 , 7 );
238+ const __m512i zero = _mm512_setzero_si512 ();
239+ #elif defined(USE_AVX2)
235240 std::uint32_t chunks = HalfDimensions / SIMD_WIDTH;
236241 constexpr int control = 0b11011000 ;
237242 const __m256i zero = _mm256_setzero_si256 ();
@@ -240,7 +245,15 @@ std::pair<std::int32_t, bool> Transformer::transform(Position & pos, std::uint8_
240245 for (auto side : sides) {
241246 std::uint32_t offset = HalfDimensions * side;
242247
243- #if defined(USE_AVX2)
248+ #if defined(USE_AVX512)
249+ auto out = reinterpret_cast <__m512i*>(&outBuffer[offset]);
250+ for (std::uint32_t j = 0 ; j < chunks; ++j) {
251+ __m512i sum0 = _mm512_load_si512 (&reinterpret_cast <const __m512i*>(acc[sides[side]])[j * 2 + 0 ]);
252+ __m512i sum1 = _mm512_load_si512 (&reinterpret_cast <const __m512i*>(acc[sides[side]])[j * 2 + 1 ]);
253+ _mm512_store_si512 (&out[j], _mm512_permutexvar_epi64 (control,_mm512_max_epi8 (_mm512_packs_epi16 (sum0, sum1), zero)));
254+ }
255+
256+ #elif defined(USE_AVX2)
244257 auto out = reinterpret_cast <__m256i*>(&outBuffer[offset]);
245258 for (std::uint32_t j = 0 ; j < chunks; ++j) {
246259 __m256i sum0 = _mm256_loadA_si256 (&reinterpret_cast <const __m256i*>(acc[sides[side]][0 ])[j * 2 + 0 ]);
@@ -339,7 +352,13 @@ inline void Transformer::refresh(Position & pos) {
339352 for (std::size_t k = 0 ; k < PSQT_BUCKETS; ++k)
340353 accumulator.psqtAccumulation [c][k] += psqts[indexes[index] * PSQT_BUCKETS + k];
341354
342- #if defined(USE_AVX2)
355+ #if defined(USE_AVX512)
356+ auto accumulation = reinterpret_cast <__m512i*>(&accumulator.accumulation [c][0 ][0 ]);
357+ auto column = reinterpret_cast <const __m512i*>(&weights[offset]);
358+ constexpr std::uint32_t chunks = HalfDimensions / SIMD_WIDTH;
359+ for (std::uint32_t j = 0 ; j < chunks; ++j)
360+ _mm512_storeA_si512 (&accumulation[j], _mm512_add_epi16 (_mm512_loadA_si512 (&accumulation[j]), column[j]));
361+ #elif defined(USE_AVX2)
343362 auto accumulation = reinterpret_cast <__m256i*>(&accumulator.accumulation [c][0 ][0 ]);
344363 auto column = reinterpret_cast <const __m256i*>(&weights[offset]);
345364
@@ -362,7 +381,13 @@ inline std::int32_t * Layer<OutputDimensions, InputDimensions>::propagate(std::u
362381
363382 auto output = reinterpret_cast <std::int32_t *>(outBuffer);
364383
365- #if defined(USE_AVX2)
384+ #if defined(USE_AVX512)
385+ std::uint32_t chunks = InputDimensions / (SIMD_WIDTH * 2 );
386+ const auto input_vector = reinterpret_cast <const __m512i*>(features);
387+ #if !defined(USE_VNNI)
388+ const __m512i ones = _mm512_set1_epi16 (1 );
389+ #endif
390+ #elif defined(USE_AVX2)
366391 std::uint32_t chunks = InputDimensions / SIMD_WIDTH;
367392 const __m256i ones = _mm256_set1_epi16 (1 );
368393 const auto input_vector = reinterpret_cast <const __m256i*>(features);
@@ -371,7 +396,34 @@ inline std::int32_t * Layer<OutputDimensions, InputDimensions>::propagate(std::u
371396 for (std::uint32_t i = 0 ; i < OutputDimensions; ++i) {
372397 const std::uint32_t offset = i * InputDimensions;
373398
374- #if defined(USE_AVX2)
399+ #if defined(USE_AVX512)
400+ __m512i sum = _mm512_setzero_si512 ();
401+ const auto row = reinterpret_cast <const __m512i*>(&weights[offset]);
402+ for (std::uint32_t j = 0 ; j < chunks; ++j) {
403+ #if defined(USE_VNNI)
404+ sum = _mm512_dpbusd_epi32 (sum, _mm512_loadA_si512 (&input_vector[j]), _mm512_load_si512 (&row[j]));
405+ #else
406+ __m512i product = _mm512_maddubs_epi16 (_mm512_loadA_si512 (&input_vector[j]), _mm512_load_si512 (&row[j]));
407+ product = _mm512_madd_epi16 (product, ones);
408+ sum = _mm512_add_epi32 (sum, product);
409+ #endif
410+ }
411+
412+ if (InputDimensions != chunks * SIMD_WIDTH * 2 ) {
413+ const auto iv256 = reinterpret_cast <const __m256i*>(&input_vector[chunks]);
414+ const auto row256 = reinterpret_cast <const __m256i*>(&row[chunks]);
415+ #if defined(USE_VNNI)
416+ __m256i product256 = _mm256_dpbusd_epi32 (_mm512_castsi512_si256 (sum), _mm256_loadA_si256 (&iv256[0 ]), _mm256_load_si256 (&row256[0 ]));
417+ sum = _mm512_inserti32x8 (sum, product256, 0 );
418+ #else
419+ __m256i product256 = _mm256_maddubs_epi16 (_mm256_loadA_si256 (&iv256[0 ]), _mm256_load_si256 (&row256[0 ]));
420+ sum = _mm512_add_epi32 (sum, _mm512_cvtepi16_epi32 (product256));
421+ #endif
422+ }
423+
424+ output[i] = _mm512_reduce_add_epi32 (sum) + biases[i];
425+
426+ #elif defined(USE_AVX2)
375427 __m256i sum = _mm256_setzero_si256 ();
376428 const auto row = reinterpret_cast <const __m256i*>(&weights[offset]);
377429 for (std::uint32_t j = 0 ; j < chunks; ++j) {
0 commit comments