@@ -35,45 +35,82 @@ namespace Stockfish::Eval::NNUE {
3535 // vector registers.
3636 #define VECTOR
3737
38+ static_assert (PSQTBuckets == 8 , " Assumed by the current choice of constants." );
39+
3840 #ifdef USE_AVX512
3941 typedef __m512i vec_t ;
42+ typedef __m256i psqt_vec_t ;
4043 #define vec_load (a ) _mm512_load_si512(a)
4144 #define vec_store (a,b ) _mm512_store_si512(a,b)
4245 #define vec_add_16 (a,b ) _mm512_add_epi16(a,b)
4346 #define vec_sub_16 (a,b ) _mm512_sub_epi16(a,b)
47+ #define vec_load_psqt (a ) _mm256_load_si256(a)
48+ #define vec_store_psqt (a,b ) _mm256_store_si256(a,b)
49+ #define vec_add_psqt_32 (a,b ) _mm256_add_epi32(a,b)
50+ #define vec_sub_psqt_32 (a,b ) _mm256_sub_epi32(a,b)
51+ #define vec_zero_psqt () _mm256_setzero_si256()
4452 static constexpr IndexType NumRegs = 8 ; // only 8 are needed
53+ static constexpr IndexType NumPsqtRegs = 1 ;
4554
4655 #elif USE_AVX2
4756 typedef __m256i vec_t ;
57+ typedef __m256i psqt_vec_t ;
4858 #define vec_load (a ) _mm256_load_si256(a)
4959 #define vec_store (a,b ) _mm256_store_si256(a,b)
5060 #define vec_add_16 (a,b ) _mm256_add_epi16(a,b)
5161 #define vec_sub_16 (a,b ) _mm256_sub_epi16(a,b)
62+ #define vec_load_psqt (a ) _mm256_load_si256(a)
63+ #define vec_store_psqt (a,b ) _mm256_store_si256(a,b)
64+ #define vec_add_psqt_32 (a,b ) _mm256_add_epi32(a,b)
65+ #define vec_sub_psqt_32 (a,b ) _mm256_sub_epi32(a,b)
66+ #define vec_zero_psqt () _mm256_setzero_si256()
5267 static constexpr IndexType NumRegs = 16 ;
68+ static constexpr IndexType NumPsqtRegs = 1 ;
5369
5470 #elif USE_SSE2
5571 typedef __m128i vec_t ;
72+ typedef __m128i psqt_vec_t ;
5673 #define vec_load (a ) (*(a))
5774 #define vec_store (a,b ) *(a)=(b)
5875 #define vec_add_16 (a,b ) _mm_add_epi16(a,b)
5976 #define vec_sub_16 (a,b ) _mm_sub_epi16(a,b)
77+ #define vec_load_psqt (a ) (*(a))
78+ #define vec_store_psqt (a,b ) *(a)=(b)
79+ #define vec_add_psqt_32 (a,b ) _mm_add_epi32(a,b)
80+ #define vec_sub_psqt_32 (a,b ) _mm_sub_epi32(a,b)
81+ #define vec_zero_psqt () _mm_setzero_si128()
6082 static constexpr IndexType NumRegs = Is64Bit ? 16 : 8 ;
83+ static constexpr IndexType NumPsqtRegs = 2 ;
6184
6285 #elif USE_MMX
6386 typedef __m64 vec_t ;
87+ typedef std::int32_t psqt_vec_t ;
6488 #define vec_load (a ) (*(a))
6589 #define vec_store (a,b ) *(a)=(b)
6690 #define vec_add_16 (a,b ) _mm_add_pi16(a,b)
6791 #define vec_sub_16 (a,b ) _mm_sub_pi16(a,b)
92+ #define vec_load_psqt (a ) (*(a))
93+ #define vec_store_psqt (a,b ) *(a)=(b)
94+ #define vec_add_psqt_32 (a,b ) a+b
95+ #define vec_sub_psqt_32 (a,b ) a-b
96+ #define vec_zero_psqt () 0
6897 static constexpr IndexType NumRegs = 8 ;
98+ static constexpr IndexType NumPsqtRegs = 8 ;
6999
70100 #elif USE_NEON
71101 typedef int16x8_t vec_t ;
102+ typedef int32x4_t psqt_vec_t ;
72103 #define vec_load (a ) (*(a))
73104 #define vec_store (a,b ) *(a)=(b)
74105 #define vec_add_16 (a,b ) vaddq_s16(a,b)
75106 #define vec_sub_16 (a,b ) vsubq_s16(a,b)
107+ #define vec_load_psqt (a ) (*(a))
108+ #define vec_store_psqt (a,b ) *(a)=(b)
109+ #define vec_add_psqt_32 (a,b ) vaddq_s32(a,b)
110+ #define vec_sub_psqt_32 (a,b ) vsubq_s32(a,b)
111+ #define vec_zero_psqt () psqt_vec_t {0 }
76112 static constexpr IndexType NumRegs = 16 ;
113+ static constexpr IndexType NumPsqtRegs = 2 ;
77114
78115 #else
79116 #undef VECTOR
@@ -89,7 +126,9 @@ namespace Stockfish::Eval::NNUE {
89126
90127 #ifdef VECTOR
91128 static constexpr IndexType TileHeight = NumRegs * sizeof (vec_t ) / 2 ;
129+ static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof (psqt_vec_t ) / 4 ;
92130 static_assert (HalfDimensions % TileHeight == 0 , " TileHeight must divide HalfDimensions" );
131+ static_assert (PSQTBuckets % PsqtTileHeight == 0 , " PsqtTileHeight must divide PSQTBuckets" );
93132 #endif
94133
95134 public:
@@ -115,15 +154,18 @@ namespace Stockfish::Eval::NNUE {
115154 biases[i] = read_little_endian<BiasType>(stream);
116155 for (std::size_t i = 0 ; i < HalfDimensions * InputDimensions; ++i)
117156 weights[i] = read_little_endian<WeightType>(stream);
157+ for (std::size_t i = 0 ; i < PSQTBuckets * InputDimensions; ++i)
158+ psqtWeights[i] = read_little_endian<PSQTWeightType>(stream);
118159 return !stream.fail ();
119160 }
120161
121162 // Convert input features
122- void transform (const Position& pos, OutputType* output) const {
163+ std:: int32_t transform (const Position& pos, OutputType* output, int bucket ) const {
123164 update_accumulator (pos, WHITE);
124165 update_accumulator (pos, BLACK);
125166
126167 const auto & accumulation = pos.state ()->accumulator .accumulation ;
168+ const auto & psqtAccumulation = pos.state ()->accumulator .psqtAccumulation ;
127169
128170 #if defined(USE_AVX512)
129171 constexpr IndexType NumChunks = HalfDimensions / (SimdWidth * 2 );
@@ -231,6 +273,12 @@ namespace Stockfish::Eval::NNUE {
231273 #if defined(USE_MMX)
232274 _mm_empty ();
233275 #endif
276+
277+ const auto psqt = (
278+ psqtAccumulation[static_cast <int >(perspectives[0 ])][bucket]
279+ - psqtAccumulation[static_cast <int >(perspectives[1 ])][bucket]
280+ ) / 2 ;
281+ return psqt;
234282 }
235283
236284 private:
@@ -246,6 +294,7 @@ namespace Stockfish::Eval::NNUE {
246294 // Gcc-10.2 unnecessarily spills AVX2 registers if this array
247295 // is defined in the VECTOR code below, once in each branch
248296 vec_t acc[NumRegs];
297+ psqt_vec_t psqt[NumPsqtRegs];
249298 #endif
250299
251300 // Look for a usable accumulator of an earlier position. We keep track
@@ -324,12 +373,52 @@ namespace Stockfish::Eval::NNUE {
324373 }
325374 }
326375
376+ for (IndexType j = 0 ; j < PSQTBuckets / PsqtTileHeight; ++j)
377+ {
378+ // Load accumulator
379+ auto accTilePsqt = reinterpret_cast <psqt_vec_t *>(
380+ &st->accumulator .psqtAccumulation [perspective][j * PsqtTileHeight]);
381+ for (std::size_t k = 0 ; k < NumPsqtRegs; ++k)
382+ psqt[k] = vec_load_psqt (&accTilePsqt[k]);
383+
384+ for (IndexType i = 0 ; states_to_update[i]; ++i)
385+ {
386+ // Difference calculation for the deactivated features
387+ for (const auto index : removed[i])
388+ {
389+ const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
390+ auto columnPsqt = reinterpret_cast <const psqt_vec_t *>(&psqtWeights[offset]);
391+ for (std::size_t k = 0 ; k < NumPsqtRegs; ++k)
392+ psqt[k] = vec_sub_psqt_32 (psqt[k], columnPsqt[k]);
393+ }
394+
395+ // Difference calculation for the activated features
396+ for (const auto index : added[i])
397+ {
398+ const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
399+ auto columnPsqt = reinterpret_cast <const psqt_vec_t *>(&psqtWeights[offset]);
400+ for (std::size_t k = 0 ; k < NumPsqtRegs; ++k)
401+ psqt[k] = vec_add_psqt_32 (psqt[k], columnPsqt[k]);
402+ }
403+
404+ // Store accumulator
405+ accTilePsqt = reinterpret_cast <psqt_vec_t *>(
406+ &states_to_update[i]->accumulator .psqtAccumulation [perspective][j * PsqtTileHeight]);
407+ for (std::size_t k = 0 ; k < NumPsqtRegs; ++k)
408+ vec_store_psqt (&accTilePsqt[k], psqt[k]);
409+ }
410+ }
411+
327412 #else
328413 for (IndexType i = 0 ; states_to_update[i]; ++i)
329414 {
330415 std::memcpy (states_to_update[i]->accumulator .accumulation [perspective],
331416 st->accumulator .accumulation [perspective],
332417 HalfDimensions * sizeof (BiasType));
418+
419+ for (std::size_t k = 0 ; k < PSQTBuckets; ++k)
420+ states_to_update[i]->accumulator .psqtAccumulation [perspective][k] = st->accumulator .psqtAccumulation [perspective][k];
421+
333422 st = states_to_update[i];
334423
335424 // Difference calculation for the deactivated features
@@ -339,6 +428,9 @@ namespace Stockfish::Eval::NNUE {
339428
340429 for (IndexType j = 0 ; j < HalfDimensions; ++j)
341430 st->accumulator .accumulation [perspective][j] -= weights[offset + j];
431+
432+ for (std::size_t k = 0 ; k < PSQTBuckets; ++k)
433+ st->accumulator .psqtAccumulation [perspective][k] -= psqtWeights[index * PSQTBuckets + k];
342434 }
343435
344436 // Difference calculation for the activated features
@@ -348,6 +440,9 @@ namespace Stockfish::Eval::NNUE {
348440
349441 for (IndexType j = 0 ; j < HalfDimensions; ++j)
350442 st->accumulator .accumulation [perspective][j] += weights[offset + j];
443+
444+ for (std::size_t k = 0 ; k < PSQTBuckets; ++k)
445+ st->accumulator .psqtAccumulation [perspective][k] += psqtWeights[index * PSQTBuckets + k];
351446 }
352447 }
353448 #endif
@@ -383,16 +478,42 @@ namespace Stockfish::Eval::NNUE {
383478 vec_store (&accTile[k], acc[k]);
384479 }
385480
481+ for (IndexType j = 0 ; j < PSQTBuckets / PsqtTileHeight; ++j)
482+ {
483+ for (std::size_t k = 0 ; k < NumPsqtRegs; ++k)
484+ psqt[k] = vec_zero_psqt ();
485+
486+ for (const auto index : active)
487+ {
488+ const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
489+ auto columnPsqt = reinterpret_cast <const psqt_vec_t *>(&psqtWeights[offset]);
490+
491+ for (std::size_t k = 0 ; k < NumPsqtRegs; ++k)
492+ psqt[k] = vec_add_psqt_32 (psqt[k], columnPsqt[k]);
493+ }
494+
495+ auto accTilePsqt = reinterpret_cast <psqt_vec_t *>(
496+ &accumulator.psqtAccumulation [perspective][j * PsqtTileHeight]);
497+ for (std::size_t k = 0 ; k < NumPsqtRegs; ++k)
498+ vec_store_psqt (&accTilePsqt[k], psqt[k]);
499+ }
500+
386501 #else
387502 std::memcpy (accumulator.accumulation [perspective], biases,
388503 HalfDimensions * sizeof (BiasType));
389504
505+ for (std::size_t k = 0 ; k < PSQTBuckets; ++k)
506+ accumulator.psqtAccumulation [perspective][k] = 0 ;
507+
390508 for (const auto index : active)
391509 {
392510 const IndexType offset = HalfDimensions * index;
393511
394512 for (IndexType j = 0 ; j < HalfDimensions; ++j)
395513 accumulator.accumulation [perspective][j] += weights[offset + j];
514+
515+ for (std::size_t k = 0 ; k < PSQTBuckets; ++k)
516+ accumulator.psqtAccumulation [perspective][k] += psqtWeights[index * PSQTBuckets + k];
396517 }
397518 #endif
398519 }
@@ -404,9 +525,11 @@ namespace Stockfish::Eval::NNUE {
404525
405526 using BiasType = std::int16_t ;
406527 using WeightType = std::int16_t ;
528+ using PSQTWeightType = std::int32_t ;
407529
408530 alignas (CacheLineSize) BiasType biases[HalfDimensions];
409531 alignas (CacheLineSize) WeightType weights[HalfDimensions * InputDimensions];
532+ alignas (CacheLineSize) PSQTWeightType psqtWeights[InputDimensions * PSQTBuckets];
410533 };
411534
412535} // namespace Stockfish::Eval::NNUE
0 commit comments